# Source code for xuance.tensorflow.representations.mlp
import numpy as np
from xuance.common import Sequence, Optional, Union
from xuance.tensorflow import tf, tk, Module, Tensor
from xuance.tensorflow.utils.layers import mlp_block
from xuance.tensorflow.utils import ModuleType
# Identity representation: returns the original observation, flattened per sample.
class Basic_Identical(Module):
    """Representation that passes observations through unchanged except for flattening.

    Args:
        input_shape: Shape of a single observation.
        **kwargs: Accepted but unused.

    Attributes:
        input_shapes: The ``input_shape`` given at construction.
        output_shapes: ``{'state': (prod(input_shape),)}`` -- the flattened width.
        model: A ``Sequential`` holding a single ``Flatten`` layer.
    """

    def __init__(self, input_shape: Sequence[int], **kwargs):
        super().__init__()
        self.input_shapes = input_shape
        flat_width = np.prod(input_shape)
        self.output_shapes = {'state': (flat_width,)}
        flatten_layer = tk.layers.Flatten()
        self.model = tk.Sequential([flatten_layer])

    @tf.function
    def call(self, x: Union[Tensor, np.ndarray], **kwargs):
        """Flatten ``x`` (all axes after the first) and return it under key ``'state'``."""
        flattened = self.model(x)
        return {'state': flattened}
class Basic_MLP(Module):
    """Multi-layer perceptron representation.

    Builds a ``Flatten`` layer followed by one ``mlp_block`` per entry of
    ``hidden_sizes``. At call time the MLP is applied over the LAST axis of
    the input; every leading axis is preserved in the output.

    Args:
        input_shape: Shape of a single observation.
        hidden_sizes: Output width of each hidden block, in order. The last
            entry is the width of the returned ``'state'`` features. If empty,
            no blocks are built and the features keep the input's last-axis
            width.
        normalize: Optional normalization, forwarded to ``mlp_block``.
        initializer: Optional kernel initializer, forwarded to ``mlp_block``.
        activation: Optional activation, forwarded to ``mlp_block``.
        **kwargs: Accepted but unused.
    """

    def __init__(self,
                 input_shape: Sequence[int],
                 hidden_sizes: Sequence[int],
                 normalize: Optional[ModuleType] = None,
                 initializer: Optional[tk.initializers.Initializer] = None,
                 activation: Optional[ModuleType] = None,
                 **kwargs):
        super(Basic_MLP, self).__init__()
        self.input_shapes = input_shape
        self.hidden_sizes = hidden_sizes
        self.normalize = normalize
        self.initializer = initializer
        self.activation = activation
        # With no hidden blocks the network is only Flatten, so the feature
        # width is the input's last axis (previously this raised IndexError).
        # len() is used rather than truthiness so numpy arrays also work.
        out_dim = hidden_sizes[-1] if len(hidden_sizes) > 0 else input_shape[-1]
        self.output_shapes = {'state': (out_dim,)}
        self.model = self._create_network()

    def _create_network(self):
        """Return ``Sequential([Flatten, *mlp_block(...) for each hidden size])``."""
        layers = [tk.layers.Flatten()]
        input_shape = self.input_shapes
        for h in self.hidden_sizes:
            # mlp_block appears to return (list_of_layers, next_input_shape).
            # NOTE(review): call() feeds the network the observation's LAST
            # axis as features, while this passes input_shape[0]; the two agree
            # only for 1-D observation shapes -- confirm against mlp_block.
            mlp, input_shape = mlp_block(input_shape[0], h, self.normalize, self.activation, self.initializer)
            layers.extend(mlp)
        return tk.Sequential(layers)

    @tf.function
    def call(self, x: Union[Tensor, np.ndarray], **kwargs):
        """Encode ``x`` into ``{'state': features}``.

        All axes of ``x`` except the last are folded into one batch axis for
        the MLP and restored afterwards, so an input of shape ``(*lead, d)``
        yields features of shape ``(*lead, output_shapes['state'][0])``.
        """
        x = tf.convert_to_tensor(x)
        # Keep the feature width static when it is known, so the Dense layers
        # can build their kernels; otherwise fall back to the dynamic value.
        in_dim = x.shape[-1]
        if in_dim is None:
            in_dim = tf.shape(x)[-1]
        y_flat = self.model(tf.reshape(x, (-1, in_dim)))
        # Take the leading axes from the dynamic shape. The static shape may
        # contain None (e.g. an unknown batch size while tracing), and
        # tf.reshape cannot accept a partially-known TensorShape.
        out_shape = tf.concat(
            [tf.shape(x)[:-1],
             tf.constant(self.output_shapes['state'], dtype=tf.int32)],
            axis=0)
        return {'state': tf.reshape(y_flat, out_shape)}