Source code for xuance.torch.utils.layers4dreamer

"""
Adapted from: https://github.com/thu-ml/tianshou/blob/master/tianshou/utils/net/common.py
"""

from typing import Any, Dict, List, Optional, Tuple, Type, Union

import torch
from torch import Tensor, nn

ModuleType = Optional[Type[nn.Module]]
ArgType = Union[Tuple[Any, ...], Dict[Any, Any], None]
ArgsType = Union[ArgType, List[ArgType]]


def create_layer_with_args(layer_type: ModuleType, layer_args: Optional[ArgType]) -> nn.Module:
    """Instantiate one layer from its class and an optional bundle of arguments.

    Args:
        layer_type (ModuleType): the class of the layer to build.
        layer_args (ArgType, optional): the constructor arguments. A tuple is expanded
            positionally, a dict is expanded as keyword arguments and None means
            that the layer is built without arguments.

    Returns:
        nn.Module: the newly created layer.

    Raises:
        ValueError: if `layer_type` is None or if `layer_args` is neither None,
            a tuple nor a dict.
    """
    if layer_type is None:
        raise ValueError("`layer_type` must be not None")
    if layer_args is None:
        return layer_type()
    if isinstance(layer_args, tuple):
        return layer_type(*layer_args)
    if isinstance(layer_args, dict):
        return layer_type(**layer_args)
    raise ValueError(f"`layer_args` must be None, tuple or dict, got {type(layer_args)}")
def miniblock(
    input_size: int,
    output_size: int,
    layer_type: Type[nn.Module] = nn.Linear,
    layer_args: ArgType = None,
    dropout_layer: ModuleType = None,
    dropout_args: ArgType = None,
    norm_layer: ModuleType = None,
    norm_args: ArgType = None,
    activation: ModuleType = None,
    act_args: ArgType = None,
) -> List[nn.Module]:
    """Build one block: a main layer optionally followed by dropout, norm and activation.

    Based on Tianshou's miniblock function
    (https://github.com/thu-ml/tianshou/blob/master/tianshou/utils/net/common.py).

    Args:
        input_size (int): the input size of the main layer (in_features for Linear and
            in_channels for Conv2d).
        output_size (int): the output size of the main layer.
        layer_type (Type[nn.Module], optional): the class of the main layer. It is called as
            ``layer_type(input_size, output_size, ...)``. Defaults to nn.Linear.
        layer_args (ArgType, optional): extra arguments for the main layer, appended after
            the two sizes (tuple) or passed as keywords (dict). Defaults to None.
        dropout_layer (ModuleType, optional): the class of the dropout layer. Defaults to None.
        dropout_args (ArgType, optional): the arguments of the dropout layer. Defaults to None.
        norm_layer (ModuleType, optional): the class of the norm layer. Defaults to None.
        norm_args (ArgType, optional): the arguments of the norm layer. Defaults to None.
        activation (ModuleType, optional): the class of the activation. Defaults to None.
        act_args (ArgType, optional): the arguments of the activation. Defaults to None.

    Returns:
        List[nn.Module]: the layers of the block, in the order
        main layer, dropout, norm, activation (missing ones are skipped).

    Raises:
        ValueError: if `layer_args` is neither None, a tuple nor a dict.
    """
    # Normalise the extra arguments of the main layer into positional + keyword parts.
    if layer_args is None:
        extra_positional, extra_keywords = (), {}
    elif isinstance(layer_args, tuple):
        extra_positional, extra_keywords = layer_args, {}
    elif isinstance(layer_args, dict):
        extra_positional, extra_keywords = (), layer_args
    else:
        raise ValueError(f"layer_args must be None, tuple or dict, got {type(layer_args)}")

    block: List[nn.Module] = [layer_type(input_size, output_size, *extra_positional, **extra_keywords)]

    # The optional layers are always appended in this fixed order.
    optional_parts = (
        (dropout_layer, dropout_args),
        (norm_layer, norm_args),
        (activation, act_args),
    )
    for part_type, part_args in optional_parts:
        if part_type is not None:
            block.append(create_layer_with_args(part_type, part_args))
    return block
def create_layers(
    layer_type: Union[ModuleType, List[ModuleType]], layer_args: Optional[ArgsType], num_layers: int
) -> Tuple[List[ModuleType], ArgsType]:
    """Expand a layer specification into per-layer lists of types and arguments.

    If a layer_type is not specified, then the lists will be filled with None.
    If the layer type or the layer arguments are specified only once,
    they will be cast to a sequence of length num_layers.

    Args:
        layer_type (Union[ModuleType, Sequence[ModuleType]]): the type of the layer to be created,
            either a single class shared by all the layers or a list/tuple with one class per layer.
        layer_args (ArgsType, optional): the arguments to be passed to the layer.
            A tuple, a dict or None is shared by all the layers (a tuple is always read as the
            positional arguments of one layer); a list must hold one entry per layer.
        num_layers (int): the number of layers to be created.

    Returns:
        Tuple[Sequence[ModuleType], ArgsType]: a list of layers and a list of args,
        both of length num_layers.

    Raises:
        ValueError: if a sequence of layer types or a list of layer arguments
            does not contain exactly num_layers entries.

    Examples:
        >>> create_layers(nn.Linear, None, 3)
        ([nn.Linear, nn.Linear, nn.Linear], [None, None, None])

        >>> create_layers(nn.Linear, {"arg1":3, "arg2": "foo"}, 3)
        (
            [nn.Linear, nn.Linear, nn.Linear],
            [{'arg1': 3, 'arg2': 'foo'}, {'arg1': 3, 'arg2': 'foo'}, {'arg1': 3, 'arg2': 'foo'}]
        )

        >>> create_layers([nn.Linear, nn.Conv2d], [{"bias":False}, {"kernel_size": 5, "bias": True}], 2)
        ([nn.Linear, nn.Conv2d], [{'bias': False}, {'kernel_size':5, 'bias': True}])

        >>> create_layers([nn.Linear, nn.Linear], (64, 10), 2)
        ([nn.Linear, nn.Linear], [(64, 10), (64, 10)])

        >>> create_layers(nn.Dropout, [{"p": 0.1}, {"p": 0.2}], 2)
        ([nn.Dropout, nn.Dropout], [{'p': 0.1}, {'p': 0.2}])
    """
    if layer_type is None:
        return [None] * num_layers, [None] * num_layers

    # A layer class is never a list/tuple itself, so both can safely be read as
    # "one type per layer" (callers annotate this argument as a Sequence).
    if isinstance(layer_type, (list, tuple)):
        if len(layer_type) != num_layers:
            raise ValueError(
                f"Expected {num_layers} layer types (one per layer), got {len(layer_type)}"
            )
        layers_list = list(layer_type)
    else:
        layers_list = [layer_type] * num_layers

    # Only a list means "one entry per layer": a tuple is the positional arguments of a single layer.
    if isinstance(layer_args, list):
        if len(layer_args) != num_layers:
            raise ValueError(
                f"Expected {num_layers} layer arguments (one per layer), got {len(layer_args)}"
            )
        args_list = list(layer_args)
    else:
        args_list = [layer_args] * num_layers
    return layers_list, args_list
def per_layer_ortho_init_weights(module: nn.Module, gain: float = 1.0, bias: float = 0.0):
    """Apply an orthogonal initialization to the weights of a module, in place.

    Linear layers get orthogonal weights and a constant bias; LSTM layers get the same
    treatment for every parameter whose name contains "weight" or "bias";
    ``nn.Sequential`` and ``nn.ModuleList`` containers are visited recursively.
    Any other module is left untouched.

    Args:
        module (nn.Module): module to initialize
        gain (float, optional): gain of the orthogonal initialization.
            Defaults to 1.0.
        bias (float, optional): constant value written into the biases.
            Defaults to 0.0.
    """
    if isinstance(module, nn.Linear):
        nn.init.orthogonal_(module.weight, gain=gain)
        if module.bias is not None:
            module.bias.data.fill_(bias)
        return

    if isinstance(module, nn.LSTM):
        for param_name, tensor in module.named_parameters():
            if "bias" in param_name:
                nn.init.constant_(tensor, val=bias)
            elif "weight" in param_name:
                nn.init.orthogonal_(tensor, gain=gain)
        return

    if isinstance(module, (nn.Sequential, nn.ModuleList)):
        for child in module:
            per_layer_ortho_init_weights(child, gain=gain, bias=bias)
def cnn_forward(
    model: nn.Module,
    input: Tensor,
    input_dim: Union[torch.Size, Tuple[int, ...], int],
    output_dim: Union[torch.Size, Tuple[int, ...], int],
) -> Tensor:
    """
    Compute the forward of a Convolutional neural network.
    It flattens all the dimensions before the model input_size, i.e.,
    the dimensions before the (C_in, H, W) dimensions for the encoder
    and the dimensions before the (feature_size,) dimension for the decoder.

    Args:
        model (nn.Module): the model.
        input (Tensor): the input tensor of dimension `(*, C_in, H, W)` or `(*, feature_size)`,
            where `*` means any number of dimensions including None.
        input_dim (Union[torch.Size, Tuple[int, ...], int]): the input dimensions,
            i.e., either (C_in, H, W) or (feature_size,). A bare int is read as a 1-tuple.
        output_dim (Union[torch.Size, Tuple[int, ...], int]): the desired dimensions in output.
            A bare int (e.g. -1) is read as a 1-tuple.

    Returns:
        The output of dimensions (*, *output_dim).

    Examples:
        >>> encoder
        CNN(
            (network): Sequential(
                (0): Conv2d(3, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
                (1): ReLU()
                (2): Conv2d(4, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
                (3): ReLU()
                (4): Flatten(start_dim=1, end_dim=-1)
                (5): Linear(in_features=128, out_features=25, bias=True)
            )
        )
        >>> input = torch.rand(10, 20, 3, 4, 4)
        >>> cnn_forward(encoder, input, (3, 4, 4), -1).shape
        torch.Size([10, 20, 25])

        >>> decoder
        Sequential(
            (0): Linear(in_features=230, out_features=1024, bias=True)
            (1): Unflatten(dim=-1, unflattened_size=(1024, 1, 1))
            (2): ConvTranspose2d(1024, 128, kernel_size=(5, 5), stride=(2, 2))
            (3): ReLU()
            (4): ConvTranspose2d(128, 64, kernel_size=(5, 5), stride=(2, 2))
            (5): ReLU()
            (6): ConvTranspose2d(64, 32, kernel_size=(6, 6), stride=(2, 2))
            (7): ReLU()
            (8): ConvTranspose2d(32, 3, kernel_size=(6, 6), stride=(2, 2))
        )
        >>> input = torch.rand(10, 20, 230)
        >>> cnn_forward(decoder, input, (230,), (3, 64, 64)).shape
        torch.Size([10, 20, 3, 64, 64])
    """
    # Accept a bare int (as the first example above does with `-1`): without this,
    # `len(input_dim)` / `*output_dim` raise a TypeError on an int.
    if isinstance(input_dim, int):
        input_dim = (input_dim,)
    if isinstance(output_dim, int):
        output_dim = (output_dim,)

    # Everything in front of the model's own input dimensions is treated as batch.
    batch_shapes = input.shape[: -len(input_dim)]
    flatten_input = input.reshape(-1, *input_dim)
    model_out = model(flatten_input)
    # Restore the leading batch dimensions around the requested output shape.
    return model_out.reshape(*batch_shapes, *output_dim)
""" Adapted from: https://github.com/thu-ml/tianshou/blob/master/tianshou/utils/net/common.py """ import warnings from math import prod from typing import Any, Callable, Dict, Optional, Sequence, Union, no_type_check import torch import torch.nn.functional as F from torch import Tensor, nn
class MLP(nn.Module):
    """Simple MLP backbone.

    Args:
        input_dims (Union[int, Sequence[int]]): dimensions of the input vector.
            A sequence is collapsed to the product of its entries.
        output_dim (int, optional): dimension of the output vector. If set to None, there
            is no final linear layer. Else, a final linear layer is added.
            Defaults to None.
        hidden_sizes (Sequence[int], optional): shape of MLP passed in as a list, not including
            input_dims and output_dim.
        layer_args (ArgsType, optional): extra arguments of the hidden linear layers.
            A single tuple/dict/None is shared by all the hidden layers, while a list must
            contain one entry per hidden layer. Defaults to None.
        dropout_layer (Union[ModuleType, Sequence[ModuleType]], optional): which dropout layer to be used
            before activation (possibly before the normalization layer), e.g., ``nn.Dropout``.
            You can also pass a list of dropout modules with the same length
            of hidden_sizes to use different dropout modules in different layers.
            If None, then no dropout layer is used.
            Defaults to None.
        dropout_args (ArgsType, optional): the arguments of the dropout layer(s). Defaults to None.
        norm_layer (Union[ModuleType, Sequence[ModuleType]], optional): which normalization layer to be used
            before activation, e.g., ``nn.LayerNorm`` and ``nn.BatchNorm1d``.
            You can also pass a list of normalization modules with the same length
            of hidden_sizes to use different normalization modules in different layers.
            If None, then no normalization layer is used.
            Defaults to None.
        norm_args (ArgsType, optional): the arguments of the normalization layer(s). Defaults to None.
        activation (Union[ModuleType, Sequence[ModuleType]], optional): which activation to use after each layer,
            can be both the same activation for all layers if a single ``nn.Module`` is passed, or different
            activations for different layers if a list is passed.
            Defaults to ``nn.ReLU``.
        act_args (ArgsType, optional): the arguments of the activation(s). Defaults to None.
        flatten_dim (int, optional): the first dimension to flatten in the input of ``forward``
            (passed to ``Tensor.flatten``). If None, the input is not flattened.
            Defaults to None.

    Raises:
        ValueError: if there are no hidden layers and no output layer, or if `layer_args`
            is a list whose length differs from the number of hidden layers.
    """

    def __init__(
        self,
        input_dims: Union[int, Sequence[int]],
        output_dim: Optional[int] = None,
        hidden_sizes: Sequence[int] = (),
        layer_args: Optional[ArgsType] = None,
        dropout_layer: Optional[Union[ModuleType, Sequence[ModuleType]]] = None,
        dropout_args: Optional[ArgsType] = None,
        norm_layer: Optional[Union[ModuleType, Sequence[ModuleType]]] = None,
        norm_args: Optional[ArgsType] = None,
        activation: Optional[Union[ModuleType, Sequence[ModuleType]]] = nn.ReLU,
        act_args: Optional[ArgsType] = None,
        flatten_dim: Optional[int] = None,
    ) -> None:
        super().__init__()
        num_layers = len(hidden_sizes)
        if num_layers < 1 and output_dim is None:
            raise ValueError("The number of layers should be at least 1.")

        if isinstance(input_dims, Sequence) and flatten_dim is None:
            warnings.warn(
                "input_dims is a sequence, but flatten_dim is not specified. "
                "Be careful to flatten the input data correctly before the forward."
            )

        dropout_layer_list, dropout_args_list = create_layers(dropout_layer, dropout_args, num_layers)
        norm_layer_list, norm_args_list = create_layers(norm_layer, norm_args, num_layers)
        activation_list, act_args_list = create_layers(activation, act_args, num_layers)

        if isinstance(layer_args, list):
            # The zip below stops at the shortest sequence: a list of the wrong length
            # would silently build fewer layers than requested.
            if len(layer_args) != num_layers:
                raise ValueError(
                    f"`layer_args` must contain one entry per hidden layer ({num_layers}), "
                    f"got {len(layer_args)}"
                )
            layer_args_list = layer_args
        else:
            layer_args_list = [layer_args] * num_layers

        if isinstance(input_dims, int):
            input_dims = [input_dims]

        # Sizes of the chain of linear layers: flattened input followed by the hidden sizes.
        layer_sizes = [prod(input_dims)] + list(hidden_sizes)
        model = []
        for in_dim, out_dim, l_args, drop, drop_args, norm, n_args, activ, a_args in zip(
            layer_sizes[:-1],
            layer_sizes[1:],
            layer_args_list,
            dropout_layer_list,
            dropout_args_list,
            norm_layer_list,
            norm_args_list,
            activation_list,
            act_args_list,
        ):
            model += miniblock(in_dim, out_dim, nn.Linear, l_args, drop, drop_args, norm, n_args, activ, a_args)
        if output_dim is not None:
            model += [nn.Linear(layer_sizes[-1], output_dim)]

        # `is not None` rather than `or`, so that an explicit output_dim is always reported as is.
        self._output_dim = output_dim if output_dim is not None else layer_sizes[-1]
        self._model = nn.Sequential(*model)
        self._flatten_dim = flatten_dim

    @property
    def model(self) -> nn.Module:
        """The underlying ``nn.Sequential`` network."""
        return self._model

    @property
    def output_dim(self) -> int:
        """Size of the last dimension produced by the network."""
        return self._output_dim

    @property
    def flatten_dim(self) -> Optional[int]:
        """First dimension flattened by ``forward``, or None if no flattening is done."""
        return self._flatten_dim

    @no_type_check
    def forward(self, obs: Tensor) -> Tensor:
        """Optionally flatten `obs` from `flatten_dim` onwards, then run the network."""
        if self.flatten_dim is not None:
            obs = obs.flatten(self.flatten_dim)
        return self.model(obs)
class CNN(nn.Module):
    """Simple CNN backbone.

    Args:
        input_channels (int): dimensions of the input channels.
        hidden_channels (Sequence[int]): intermediate number of channels for the CNN,
            including the output channels. Must contain at least one entry.
        cnn_layer (ModuleType, optional): the class of the convolutional layers.
            Defaults to ``nn.Conv2d``.
        layer_args (ArgsType, optional): extra arguments of the convolutional layers.
            A single tuple/dict/None is shared by all the layers, while a list must
            contain one entry per layer. Defaults to None.
        dropout_layer (Union[ModuleType, Sequence[ModuleType]], optional): which dropout layer to be used
            before activation (possibly before the normalization layer), e.g., ``nn.Dropout``.
            You can also pass a list of dropout modules with the same length
            of hidden_channels to use different dropout modules in different layers.
            If None, then no dropout layer is used.
            Defaults to None.
        dropout_args (ArgsType, optional): the arguments of the dropout layer(s). Defaults to None.
        norm_layer (Union[ModuleType, Sequence[ModuleType]], optional): which normalization layer to be used
            before activation, e.g., ``nn.LayerNorm`` and ``nn.BatchNorm1d``.
            You can also pass a list of normalization modules with the same length
            of hidden_channels to use different normalization modules in different layers.
            If None, then no normalization layer is used.
            Defaults to None.
        norm_args (ArgsType, optional): the arguments of the normalization layer(s). Defaults to None.
        activation (Union[ModuleType, Sequence[ModuleType]], optional): which activation to use after each layer,
            can be both the same activation for all layers if a single ``nn.Module`` is passed, or different
            activations for different layers if a list is passed.
            Defaults to ``nn.ReLU``.
        act_args (ArgsType, optional): the arguments of the activation(s). Defaults to None.

    Raises:
        ValueError: if `hidden_channels` is empty, or if `layer_args` is a list whose length
            differs from the number of layers.
    """

    def __init__(
        self,
        input_channels: int,
        hidden_channels: Sequence[int],
        cnn_layer: ModuleType = nn.Conv2d,
        layer_args: ArgsType = None,
        dropout_layer: Optional[Union[ModuleType, Sequence[ModuleType]]] = None,
        dropout_args: Optional[ArgsType] = None,
        norm_layer: Optional[Union[ModuleType, Sequence[ModuleType]]] = None,
        norm_args: Optional[ArgsType] = None,
        activation: Optional[Union[ModuleType, Sequence[ModuleType]]] = nn.ReLU,
        act_args: Optional[ArgsType] = None,
    ) -> None:
        super().__init__()
        num_layers = len(hidden_channels)
        if num_layers < 1:
            raise ValueError("The number of layers should be at least 1.")

        dropout_layer_list, dropout_args_list = create_layers(dropout_layer, dropout_args, num_layers)
        norm_layer_list, norm_args_list = create_layers(norm_layer, norm_args, num_layers)
        activation_list, act_args_list = create_layers(activation, act_args, num_layers)

        if isinstance(layer_args, list):
            # The zip below stops at the shortest sequence: a list of the wrong length
            # would silently build fewer layers than requested.
            if len(layer_args) != num_layers:
                raise ValueError(
                    f"`layer_args` must contain one entry per layer ({num_layers}), got {len(layer_args)}"
                )
            layer_args_list = layer_args
        else:
            layer_args_list = [layer_args] * num_layers

        # Channel counts along the network: input channels followed by every layer's output channels.
        channels = [input_channels] + list(hidden_channels)
        model = []
        for in_dim, out_dim, l_args, drop, drop_args, norm, n_args, activ, a_args in zip(
            channels[:-1],
            channels[1:],
            layer_args_list,
            dropout_layer_list,
            dropout_args_list,
            norm_layer_list,
            norm_args_list,
            activation_list,
            act_args_list,
        ):
            model += miniblock(in_dim, out_dim, cnn_layer, l_args, drop, drop_args, norm, n_args, activ, a_args)

        self._output_dim = channels[-1]
        self._model = nn.Sequential(*model)

    @property
    def model(self) -> nn.Module:
        """The underlying ``nn.Sequential`` network."""
        return self._model

    @property
    def output_dim(self) -> int:
        """Number of channels produced by the last layer."""
        return self._output_dim

    @no_type_check
    def forward(self, obs: Tensor) -> Tensor:
        """Run the network on `obs`."""
        return self.model(obs)
class DeCNN(nn.Module):
    """Simple DeCNN backbone.

    Args:
        input_channels (int): dimensions of the input channels.
        hidden_channels (Sequence[int]): intermediate number of channels for the CNN,
            including the output channels. Although the default is an empty tuple,
            at least one entry is required.
        cnn_layer (ModuleType, optional): the class of the (transposed) convolutional layers.
            Defaults to ``nn.ConvTranspose2d``.
        layer_args (ArgsType, optional): extra arguments of the convolutional layers.
            A single tuple/dict/None is shared by all the layers, while a list must
            contain one entry per layer. Defaults to None.
        dropout_layer (Union[ModuleType, Sequence[ModuleType]], optional): which dropout layer to be used
            before activation (possibly before the normalization layer), e.g., ``nn.Dropout``.
            You can also pass a list of dropout modules with the same length
            of hidden_channels to use different dropout modules in different layers.
            If None, then no dropout layer is used.
            Defaults to None.
        dropout_args (ArgsType, optional): the arguments of the dropout layer(s). Defaults to None.
        norm_layer (Union[ModuleType, Sequence[ModuleType]], optional): which normalization layer to be used
            before activation, e.g., ``nn.LayerNorm`` and ``nn.BatchNorm1d``.
            You can also pass a list of normalization modules with the same length
            of hidden_channels to use different normalization modules in different layers.
            If None, then no normalization layer is used.
            Defaults to None.
        norm_args (ArgsType, optional): the arguments of the normalization layer(s). Defaults to None.
        activation (Union[ModuleType, Sequence[ModuleType]], optional): which activation to use after each layer,
            can be both the same activation for all layers if a single ``nn.Module`` is passed, or different
            activations for different layers if a list is passed.
            Defaults to ``nn.ReLU``.
        act_args (ArgsType, optional): the arguments of the activation(s). Defaults to None.

    Raises:
        ValueError: if `hidden_channels` is empty, or if `layer_args` is a list whose length
            differs from the number of layers.
    """

    def __init__(
        self,
        input_channels: int,
        hidden_channels: Sequence[int] = (),
        cnn_layer: ModuleType = nn.ConvTranspose2d,
        layer_args: ArgsType = None,
        dropout_layer: Optional[Union[ModuleType, Sequence[ModuleType]]] = None,
        dropout_args: Optional[ArgsType] = None,
        norm_layer: Optional[Union[ModuleType, Sequence[ModuleType]]] = None,
        norm_args: Optional[ArgsType] = None,
        activation: Optional[Union[ModuleType, Sequence[ModuleType]]] = nn.ReLU,
        act_args: Optional[ArgsType] = None,
    ) -> None:
        super().__init__()
        num_layers = len(hidden_channels)
        if num_layers < 1:
            raise ValueError("The number of layers should be at least 1.")

        dropout_layer_list, dropout_args_list = create_layers(dropout_layer, dropout_args, num_layers)
        norm_layer_list, norm_args_list = create_layers(norm_layer, norm_args, num_layers)
        activation_list, act_args_list = create_layers(activation, act_args, num_layers)

        if isinstance(layer_args, list):
            # The zip below stops at the shortest sequence: a list of the wrong length
            # would silently build fewer layers than requested.
            if len(layer_args) != num_layers:
                raise ValueError(
                    f"`layer_args` must contain one entry per layer ({num_layers}), got {len(layer_args)}"
                )
            layer_args_list = layer_args
        else:
            layer_args_list = [layer_args] * num_layers

        # Channel counts along the network: input channels followed by every layer's output channels.
        channels = [input_channels] + list(hidden_channels)
        model = []
        for in_dim, out_dim, l_args, drop, drop_args, norm, n_args, activ, a_args in zip(
            channels[:-1],
            channels[1:],
            layer_args_list,
            dropout_layer_list,
            dropout_args_list,
            norm_layer_list,
            norm_args_list,
            activation_list,
            act_args_list,
        ):
            model += miniblock(in_dim, out_dim, cnn_layer, l_args, drop, drop_args, norm, n_args, activ, a_args)

        self._output_dim = channels[-1]
        self._model = nn.Sequential(*model)

    @property
    def model(self) -> nn.Module:
        """The underlying ``nn.Sequential`` network."""
        return self._model

    @property
    def output_dim(self) -> int:
        """Number of channels produced by the last layer."""
        return self._output_dim

    @no_type_check
    def forward(self, obs: Tensor) -> Tensor:
        """Run the network on `obs`."""
        return self.model(obs)
class NatureCNN(CNN):
    """CNN from DQN Nature paper: Mnih, Volodymyr, et al.
    "Human-level control through deep reinforcement learning."
    Nature 518.7540 (2015): 529-533.

    Args:
        in_channels (int): the input channels to the first convolutional layer
        features_dim (int, optional): the size of the features produced by the final linear
            layer. If None, no linear layer is added and the flattened output of the last
            convolutional layer is returned.
        screen_size (int, optional): the dimension of the input image as a single integer.
            Needed to extract the features and compute the output dimension after all the
            convolutional layers.
            Defaults to 64.
    """

    def __init__(self, in_channels: int, features_dim: Optional[int], screen_size: int = 64):
        super().__init__(
            in_channels,
            [32, 64, 64],
            layer_args=[
                {"kernel_size": 8, "stride": 4},
                {"kernel_size": 4, "stride": 2},
                {"kernel_size": 3, "stride": 1},
            ],
        )

        # Dry run on a dummy image to discover the flattened size after the convolutions.
        with torch.no_grad():
            x = self.model(torch.rand(1, in_channels, screen_size, screen_size, device=self.model[0].weight.device))
            out_dim = x.flatten(1).shape[1]
        self._output_dim = out_dim
        self.fc = None
        if features_dim is not None:
            self._output_dim = features_dim
            self.fc = nn.Linear(out_dim, features_dim)

    @property
    def output_dim(self) -> int:
        """Size of the last dimension returned by ``forward``."""
        return self._output_dim

    def forward(self, x: Tensor) -> Tensor:
        """Encode a batch of images of shape `(*, C, H, W)` into features of shape `(*, output_dim)`."""
        x = cnn_forward(self.model, x, input_dim=x.shape[-3:], output_dim=(-1,))
        # `fc` is None when the module is built with `features_dim=None`: calling it
        # unconditionally raised a TypeError, so the linear head is applied only if present.
        if self.fc is not None:
            x = F.relu(self.fc(x))
        return x
class LayerNormGRUCell(nn.Module):
    """A GRU cell with a LayerNorm, taken
    from https://github.com/danijar/dreamerv2/blob/main/dreamerv2/common/nets.py#L317.

    This particular GRU cell accepts 3-D inputs, with a sequence of length 1, and applies
    a LayerNorm after the projection of the inputs.

    Args:
        input_size (int): the input size.
        hidden_size (int): the hidden state size
        bias (bool, optional): whether to apply a bias to the input projection.
            Defaults to True.
        batch_first (bool, optional): whether the first dimension represent the batch dimension or not.
            Defaults to False.
        layer_norm_cls (Callable[..., nn.Module]): the layer norm to apply after the input projection.
            Defaults to nn.Identity.
        layer_norm_kw (Dict[str, Any], optional): the kwargs of the layer norm. The dictionary is
            copied, so the one of the caller is never modified; a `normalized_shape` entry is
            ignored because the shape is always `3 * hidden_size`.
            Defaults to None, i.e., no kwargs.
    """

    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        bias: bool = True,
        batch_first: bool = False,
        layer_norm_cls: Callable[..., nn.Module] = nn.Identity,
        layer_norm_kw: Optional[Dict[str, Any]] = None,
    ) -> None:
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.bias = bias
        self.batch_first = batch_first
        # A single projection of [hx, input] produces the reset, candidate and update parts.
        self.linear = nn.Linear(input_size + hidden_size, 3 * hidden_size, bias=self.bias)
        # Work on a private copy: popping from the caller's dict (or from a shared default)
        # would leak the modification outside of this constructor.
        norm_kwargs = dict(layer_norm_kw) if layer_norm_kw else {}
        # Avoid multiple values for the `normalized_shape` argument
        norm_kwargs.pop("normalized_shape", None)
        self.layer_norm = layer_norm_cls(3 * hidden_size, **norm_kwargs)

    def forward(self, input: Tensor, hx: Optional[Tensor] = None) -> Tensor:
        """Compute the next hidden state.

        Args:
            input (Tensor): a 1-D `(input_size,)`, 2-D `(batch, input_size)` or 3-D input whose
                sequence dimension (dim 1 if `batch_first`, dim 0 otherwise) has length 1.
            hx (Tensor, optional): the previous hidden state. If None, a zero state is used.
                With a 3-D input, a 3-D `hx` has its first dimension squeezed.

        Returns:
            Tensor: the new hidden state: 1-D for a 1-D input, 2-D for a 2-D input and
            with an extra leading dimension of size 1 for a 3-D input.

        Raises:
            AssertionError: if a 3-D input has a sequence length different from 1,
                or if the input is neither 1-D, 2-D nor 3-D.
        """
        is_3d = input.dim() == 3
        if is_3d:
            seq_axis = int(self.batch_first)
            if input.shape[seq_axis] != 1:
                raise AssertionError(
                    "LayerNormGRUCell: Expected input to be 3-D with sequence length equal to 1 but received "
                    f"a sequence of length {input.shape[seq_axis]}"
                )
            input = input.squeeze(seq_axis)
            # `hx` is optional: only inspect it when it has been provided.
            if hx is not None and hx.dim() == 3:
                hx = hx.squeeze(0)
        # Explicit raise (instead of `assert`) so the check survives `python -O`.
        if input.dim() not in (1, 2):
            raise AssertionError(
                f"LayerNormGRUCell: Expected input to be 1-D or 2-D but received {input.dim()}-D tensor"
            )

        is_batched = input.dim() == 2
        if not is_batched:
            input = input.unsqueeze(0)

        if hx is None:
            hx = torch.zeros(input.size(0), self.hidden_size, dtype=input.dtype, device=input.device)
        elif not is_batched:
            hx = hx.unsqueeze(0)

        input = torch.cat((hx, input), -1)
        x = self.linear(input)
        x = self.layer_norm(x)
        reset, cand, update = torch.chunk(x, 3, -1)
        reset = torch.sigmoid(reset)
        cand = torch.tanh(reset * cand)
        update = torch.sigmoid(update - 1)
        hx = update * cand + (1 - update) * hx

        # Give the state back with the same layout as the input.
        if not is_batched:
            hx = hx.squeeze(0)
        elif is_3d:
            hx = hx.unsqueeze(0)

        return hx
class MultiEncoder(nn.Module):
    """Combine an optional CNN encoder and an optional MLP encoder into a single module.

    Every encoder that is given must expose non-None `input_dim` and `output_dim` attributes.
    When both encoders are present their outputs are concatenated along the last dimension,
    and `output_dim` is the sum of the two output dimensions.

    Args:
        cnn_encoder (ModuleType): the encoder fed with the observations to produce the
            first part of the output, or None.
        mlp_encoder (ModuleType): the encoder fed with the observations to produce the
            second part of the output, or None.

    Raises:
        ValueError: if both the encoders are None.
        AttributeError: if a given encoder lacks the `input_dim` or `output_dim` attribute.
    """

    def __init__(
        self,
        cnn_encoder: ModuleType,
        mlp_encoder: ModuleType,
    ) -> None:
        super().__init__()
        if cnn_encoder is None and mlp_encoder is None:
            raise ValueError("There must be at least one encoder, both cnn and mlp encoders are None")

        # Validate the encoders in order (cnn first), checking the input side before the output side.
        required_attrs = (("input_dim", "input"), ("output_dim", "output"))
        for label, encoder in (("cnn_encoder", cnn_encoder), ("mlp_encoder", mlp_encoder)):
            if encoder is None:
                continue
            for attr_name, side in required_attrs:
                if getattr(encoder, attr_name, None) is None:
                    raise AttributeError(
                        f"`{label}` must contain the `{attr_name}` attribute representing "
                        f"the dimension of the {side} tensor"
                    )

        self.has_cnn_encoder = cnn_encoder is not None
        self.has_mlp_encoder = mlp_encoder is not None
        self.has_both_encoders = self.has_cnn_encoder and self.has_mlp_encoder
        self.cnn_encoder = cnn_encoder
        self.mlp_encoder = mlp_encoder

        # A missing encoder contributes no input description and zero output features.
        if self.has_cnn_encoder:
            self.cnn_input_dim = cnn_encoder.input_dim
            self.cnn_output_dim = cnn_encoder.output_dim
        else:
            self.cnn_input_dim = None
            self.cnn_output_dim = 0
        if self.has_mlp_encoder:
            self.mlp_input_dim = mlp_encoder.input_dim
            self.mlp_output_dim = mlp_encoder.output_dim
        else:
            self.mlp_input_dim = None
            self.mlp_output_dim = 0
        self.output_dim = self.cnn_output_dim + self.mlp_output_dim

    @property
    def cnn_keys(self) -> Sequence[str]:
        """The `keys` of the CNN encoder, or an empty list if there is none."""
        if self.cnn_encoder is None:
            return []
        return self.cnn_encoder.keys

    @property
    def mlp_keys(self) -> Sequence[str]:
        """The `keys` of the MLP encoder, or an empty list if there is none."""
        if self.mlp_encoder is None:
            return []
        return self.mlp_encoder.keys

    def forward(self, obs: Dict[str, Tensor], *args, **kwargs) -> Tensor:
        """Encode `obs` with every available encoder (cnn first) and join the results."""
        features = []
        if self.has_cnn_encoder:
            features.append(self.cnn_encoder(obs, *args, **kwargs))
        if self.has_mlp_encoder:
            features.append(self.mlp_encoder(obs, *args, **kwargs))
        if len(features) == 1:
            return features[0]
        return torch.cat(features, -1)
class MultiDecoder(nn.Module):
    """Wrap an optional CNN decoder and an optional MLP decoder.

    Args:
        cnn_decoder (ModuleType): the decoder whose first output is returned when
            there is no MLP decoder, or None.
        mlp_decoder (ModuleType): the decoder whose output is returned whenever it is
            present, or None.

    Raises:
        ValueError: if both the decoders are None.
    """

    def __init__(
        self,
        cnn_decoder: ModuleType,
        mlp_decoder: ModuleType,
    ) -> None:
        super().__init__()
        if all(decoder is None for decoder in (cnn_decoder, mlp_decoder)):
            raise ValueError("There must be an decoder, both cnn and mlp decoders are None")
        self.cnn_decoder = cnn_decoder
        self.mlp_decoder = mlp_decoder

    @property
    def cnn_keys(self) -> Sequence[str]:
        """The `keys` of the CNN decoder, or an empty list if there is none."""
        if self.cnn_decoder is None:
            return []
        return self.cnn_decoder.keys

    @property
    def mlp_keys(self) -> Sequence[str]:
        """The `keys` of the MLP decoder, or an empty list if there is none."""
        if self.mlp_decoder is None:
            return []
        return self.mlp_decoder.keys

    # adapt to xuance: the return Dict[str, Tensor] modified to -> Tensor
    def forward(self, x: Tensor) -> Tensor:
        """Decode `x`.

        The CNN decoder (if any) is always run and element 0 of its output is kept;
        the MLP decoder (if any) is run afterwards and its output replaces the CNN one.
        """
        cnn_dec, mlp_dec = self.cnn_decoder, self.mlp_decoder
        reconstruction = None if cnn_dec is None else cnn_dec(x)[0]
        if mlp_dec is not None:
            reconstruction = mlp_dec(x)
        return reconstruction
class LayerNormChannelLast(nn.LayerNorm):
    """``nn.LayerNorm`` for 4-D channel-first tensors.

    The input is permuted so that dimension 1 becomes the last one, normalized by
    ``nn.LayerNorm``, permuted back and cast to the dtype of the input.
    """

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)

    def forward(self, x: Tensor) -> Tensor:
        """Normalize a 4-D tensor over its dimension 1.

        Raises:
            ValueError: if `x` is not 4-D.
        """
        if x.dim() != 4:
            raise ValueError(f"Input tensor must be 4D (NCHW), received {len(x.shape)}D instead: {x.shape}")
        # Move dimension 1 to the end, normalize, then restore the original layout.
        channel_last = x.permute(0, 2, 3, 1)
        normalized = super().forward(channel_last).permute(0, 3, 1, 2)
        return normalized.to(x.dtype)
class LayerNorm(nn.LayerNorm):
    """``nn.LayerNorm`` whose output is always cast back to the dtype of its input."""

    def forward(self, x: Tensor) -> Tensor:
        """Normalize `x` and return the result with the same dtype as `x`."""
        return super().forward(x).to(x.dtype)