# Source code for xuance.mindspore.representations.mlp
from mindspore import nn
from xuance.common import Sequence, Optional, Callable
from xuance.mindspore import Module, Tensor
from xuance.mindspore.utils import mlp_block, ModuleType
# Identity representation: returns the original observations unchanged.
class Basic_Identical(Module):
    """Pass-through representation that hands observations back untouched.

    Args:
        input_shape: Shape of a single observation. It must contain exactly
            one dimension; otherwise the ``assert`` in ``__init__`` fails.
        **kwargs: Accepted so the constructor tolerates extra keyword
            arguments; they are not used here.

    Attributes:
        output_shapes: Dict with the single key ``'state'`` mapped to
            ``(input_shape[0],)``.
    """

    def __init__(self,
                 input_shape: Sequence[int],
                 **kwargs):
        super().__init__()
        # Only flat (one-dimensional) observation shapes are accepted.
        assert len(input_shape) == 1
        feature_dim = input_shape[0]
        self.output_shapes = {'state': (feature_dim,)}

    def construct(self, observations: Tensor):
        """Return ``observations`` exactly as received."""
        return observations
# MLP representation: processes the input observations with a stack of MLP layers.
class Basic_MLP(Module):
    """Representation that feeds observations through stacked MLP blocks.

    Args:
        input_shape: Shape of a single observation; only ``input_shape[0]``
            is used as the input width of the first block.
        hidden_sizes: Output width of each successive block. The last entry
            determines the reported output shape, so it must be non-empty
            (``hidden_sizes[-1]`` is read in ``__init__``).
        normalize: Optional normalisation layer type forwarded to
            ``mlp_block``.
        initialize: Optional initialiser callable forwarded to ``mlp_block``.
        activation: Optional activation layer type forwarded to ``mlp_block``.
        **kwargs: Accepted so the constructor tolerates extra keyword
            arguments; they are not used here.

    Attributes:
        output_shapes: Dict with the single key ``'state'`` mapped to
            ``(hidden_sizes[-1],)``.
        model: ``nn.SequentialCell`` holding every layer produced by
            ``_create_network``.
    """

    def __init__(self,
                 input_shape: Sequence[int],
                 hidden_sizes: Sequence[int],
                 normalize: Optional[ModuleType] = None,
                 initialize: Optional[Callable[..., Tensor]] = None,
                 activation: Optional[ModuleType] = None,
                 **kwargs):
        super().__init__()
        # Keep the configuration on the instance; _create_network reads it.
        self.input_shape = input_shape
        self.hidden_sizes = hidden_sizes
        self.normalize = normalize
        self.initialize = initialize
        self.activation = activation
        last_width = hidden_sizes[-1]
        self.output_shapes = {'state': (last_width,)}
        self.model = self._create_network()

    def _create_network(self):
        """Assemble one ``mlp_block`` per hidden size into a sequential cell.

        Each call to ``mlp_block`` returns the layers of that block together
        with the shape to feed into the next block.
        NOTE(review): the exact layers inside a block are defined in
        ``xuance.mindspore.utils.mlp_block`` -- confirm there.
        """
        cells = []
        current_shape = self.input_shape
        for width in self.hidden_sizes:
            block_layers, current_shape = mlp_block(
                current_shape[0],
                width,
                self.normalize,
                self.activation,
                self.initialize,
            )
            cells += block_layers
        return nn.SequentialCell(*cells)

    def construct(self, observations: Tensor):
        """Run ``observations`` through the stacked layers and return the result."""
        features = self.model(observations)
        return features