Source code for xuance.torch.representations.mlp
from xuance.common import Sequence, Optional, Union, Callable
from xuance.torch import Module, Tensor
from xuance.torch.utils import torch, nn, mlp_block, ModuleType
# Returns the original observation unchanged, converted to a float32 tensor.
[docs]
class Basic_Identical(Module):
    """Pass-through representation that exposes the observation itself as the state.

    The module builds no layers; ``forward`` only converts the observation
    into a ``float32`` tensor on the configured device.

    Args:
        input_shape: Shape of a single observation. Must have exactly one
            dimension.
        device: Device passed to ``torch.as_tensor`` in ``forward``.
        **kwargs: Accepted and ignored.

    Attributes:
        output_shapes: ``{'state': (input_shape[0],)}``.
        device: The ``device`` argument, stored unchanged.
    """

    def __init__(self,
                 input_shape: Sequence[int],
                 device: Optional[Union[str, int, torch.device]] = None,
                 **kwargs):
        super().__init__()
        # Only flat (rank-1) observation shapes are accepted.
        assert len(input_shape) == 1
        feature_dim = input_shape[0]
        self.device = device
        self.output_shapes = {'state': (feature_dim,)}

    def forward(self, observations: Tensor) -> dict[str, Tensor]:
        """Return the observations as a float32 tensor under the key ``'state'``."""
        state = torch.as_tensor(observations, dtype=torch.float32, device=self.device)
        return {'state': state}
# Processes the input observations with a stack of MLP layers.
[docs]
class Basic_MLP(Module):
    """Representation network that encodes observations with a stack of MLP blocks.

    One block is created per entry of ``hidden_sizes`` through ``mlp_block``,
    and all resulting layers are chained in a single ``nn.Sequential``. When
    ``hidden_sizes`` is empty the model contains no layers, so ``forward``
    returns the float32-converted observation unchanged.

    Args:
        input_shape: Shape of a single observation. Only ``input_shape[0]``
            is used, as the input width of the first block.
        hidden_sizes: Output width of each block, in order. May be empty.
        normalize: Forwarded to ``mlp_block`` for every block.
        initialize: Forwarded to ``mlp_block`` for every block.
        activation: Forwarded to ``mlp_block`` for every block.
        device: Forwarded to ``mlp_block`` and used by ``forward`` when
            converting observations to a tensor.
        **kwargs: Accepted and ignored.

    Attributes:
        output_shapes: ``{'state': (hidden_sizes[-1],)}``, or
            ``{'state': (input_shape[0],)}`` when ``hidden_sizes`` is empty.
        model: The ``nn.Sequential`` built by ``_create_network``.
    """

    def __init__(self,
                 input_shape: Sequence[int],
                 hidden_sizes: Sequence[int],
                 normalize: Optional[ModuleType] = None,
                 initialize: Optional[Callable[..., Tensor]] = None,
                 activation: Optional[ModuleType] = None,
                 device: Optional[Union[str, int, torch.device]] = None,
                 **kwargs):
        super(Basic_MLP, self).__init__()
        self.input_shape = input_shape
        self.hidden_sizes = hidden_sizes
        self.normalize = normalize
        self.initialize = initialize
        self.activation = activation
        self.device = device
        # With no hidden layers the network is an identity map, so the state
        # width equals the input width. Previously this case raised IndexError
        # on hidden_sizes[-1]. len() is used instead of truthiness so that
        # array-like size containers are handled as well.
        if len(hidden_sizes) > 0:
            state_dim = hidden_sizes[-1]
        else:
            state_dim = input_shape[0]
        self.output_shapes = {'state': (state_dim,)}
        self.model = self._create_network()

    def _create_network(self):
        """Build the sequential model, adding one ``mlp_block`` per hidden size."""
        layers = []
        current_shape = self.input_shape
        for width in self.hidden_sizes:
            # mlp_block's second return value is indexed with [0] on the next
            # iteration to obtain the following block's input width.
            block_layers, current_shape = mlp_block(current_shape[0], width, self.normalize,
                                                    self.activation, self.initialize,
                                                    device=self.device)
            layers.extend(block_layers)
        # An empty layer list yields an empty Sequential, which returns its input.
        return nn.Sequential(*layers)

    def forward(self, observations: Tensor) -> dict[str, Tensor]:
        """Encode the observations and return the result under the key ``'state'``."""
        tensor_observation = torch.as_tensor(observations, dtype=torch.float32, device=self.device)
        return {'state': self.model(tensor_observation)}