from xuance.common import Sequence, Optional, Callable, Tuple
from xuance.mindspore import Module, Tensor
from xuance.mindspore.utils import ms, nn, mlp_block, gru_block, lstm_block, ModuleType
[docs]
class Basic_RNN(Module):
    """Representation network: fully connected blocks followed by a recurrent block.

    The fully connected part is assembled from ``mlp_block`` and wrapped in an
    ``nn.SequentialCell``; the recurrent part comes from ``lstm_block`` when
    ``kwargs["rnn"] == "LSTM"`` and from ``gru_block`` for any other value.

    Args:
        input_shape: Shape of a single observation. Only ``input_shape[0]`` is
            used as the width of the first layer.
        hidden_sizes: Dict with the keys ``"fc_hidden_sizes"`` (iterable of
            fully connected layer widths) and ``"recurrent_hidden_size"``.
        normalize: Optional normalisation layer factory. When given, it is
            forwarded to ``mlp_block`` and also called to build the input and
            output normalisation layers.
        initialize: Optional initialiser forwarded to the block builders.
        activation: Optional activation forwarded to ``mlp_block``.
        **kwargs: Must contain ``"N_recurrent_layers"``, ``"dropout"`` and
            ``"rnn"``.

    Raises:
        KeyError: If a required key is missing from ``hidden_sizes`` or ``kwargs``.
    """

    def __init__(self,
                 input_shape: Sequence[int],
                 hidden_sizes: dict,
                 normalize: Optional[Module] = None,
                 initialize: Optional[Callable[..., Tensor]] = None,
                 activation: Optional[ModuleType] = None,
                 **kwargs):
        super(Basic_RNN, self).__init__()
        self.input_shape = input_shape

        # Required configuration; a missing key raises KeyError right here.
        self.fc_hidden_sizes = hidden_sizes["fc_hidden_sizes"]
        self.recurrent_hidden_size = hidden_sizes["recurrent_hidden_size"]
        self.N_recurrent_layer = kwargs["N_recurrent_layers"]
        self.dropout = kwargs["dropout"]
        self.lstm = kwargs["rnn"] == "LSTM"  # every other value selects the GRU block

        self.normalize = normalize
        self.initialize = initialize
        self.activation = activation
        self.output_shapes = {'state': (self.recurrent_hidden_size,)}

        mlp, rnn, output_dim = self._create_network()
        self.mlp = mlp
        self.rnn = rnn

        # Normalisation layers exist only when a factory was supplied.
        self.use_normalize = self.normalize is not None
        if self.use_normalize:
            self.input_norm = self.normalize(input_shape)
            self.norm_rnn = self.normalize(output_dim)

    def _create_network(self) -> Tuple[Module, Module, int]:
        """Build the fully connected stack and the recurrent block.

        Returns:
            A tuple ``(mlp, rnn, output_dim)``: the ``nn.SequentialCell`` holding
            the fully connected layers, the recurrent layer, and the second
            value returned by ``lstm_block`` / ``gru_block``.
        """
        fc_layers = []
        current_shape = self.input_shape
        for width in self.fc_hidden_sizes:
            block, current_shape = mlp_block(current_shape[0], width,
                                             self.normalize, self.activation, self.initialize)
            fc_layers.extend(block)

        # Both builders are called with the same arguments; only the cell type differs.
        build_recurrent = lstm_block if self.lstm else gru_block
        recurrent_layer, output_dim = build_recurrent(current_shape[0], self.recurrent_hidden_size,
                                                      self.N_recurrent_layer, self.dropout, self.initialize)
        return nn.SequentialCell(*fc_layers), recurrent_layer, output_dim

    def construct(self, x: Tensor, h: Tensor, c: Tensor = None):
        """Run the fully connected stack and then the recurrent block.

        Args:
            x: Input tensor fed to the fully connected stack.
            h: Recurrent hidden state.
            c: Recurrent cell state; only read on the LSTM path.

        Returns:
            Dict with ``"state"`` (recurrent output, normalised when a
            normaliser is configured), ``"rnn_hidden"`` (detached new hidden
            state) and ``"rnn_cell"`` (detached new cell state for LSTM,
            ``None`` otherwise).
        """
        if self.use_normalize:
            features = self.mlp(self.input_norm(x))
        else:
            features = self.mlp(x)

        # NOTE(review): flatten_parameters() and Tensor.detach() are PyTorch-style
        # calls; confirm the MindSpore objects used here actually expose them.
        self.rnn.flatten_parameters()

        if not self.lstm:
            output, hn = self.rnn(features, h)
            if self.use_normalize:
                output = self.norm_rnn(output)
            return {"state": output, "rnn_hidden": hn.detach(), "rnn_cell": None}

        output, (hn, cn) = self.rnn(features, (h, c))
        if self.use_normalize:
            output = self.norm_rnn(output)
        return {"state": output, "rnn_hidden": hn.detach(), "rnn_cell": cn.detach()}

    def init_hidden(self, batch):
        """Create zero-filled recurrent states for ``batch`` sequences.

        Returns:
            ``(hidden_states, cell_states)``; both have shape
            ``(N_recurrent_layer, batch, recurrent_hidden_size)`` and
            ``cell_states`` is ``None`` unless the LSTM variant is used.
        """
        state_shape = (self.N_recurrent_layer, batch, self.recurrent_hidden_size)
        hidden_states = ms.ops.zeros(size=state_shape)
        if self.lstm:
            return hidden_states, ms.ops.zeros_like(hidden_states)
        return hidden_states, None

    def init_hidden_item(self, indexes: list, *rnn_hidden):
        """Zero the recurrent states at ``indexes`` along the second axis.

        Args:
            indexes: Positions on the second axis to reset.
            *rnn_hidden: Hidden state first; for LSTM the cell state second.

        Returns:
            The ``rnn_hidden`` tuple that was passed in, after item assignment.
        """
        reset_shape = (self.N_recurrent_layer, len(indexes), self.recurrent_hidden_size)
        rnn_hidden[0][:, indexes] = ms.ops.zeros(size=reset_shape)
        if self.lstm:
            # The cell state is only touched for the LSTM variant.
            rnn_hidden[1][:, indexes] = ms.ops.zeros(size=reset_shape)
        return rnn_hidden

    def get_hidden_item(self, i, *rnn_hidden):
        """Select entry ``i`` along the second axis of the recurrent states.

        Returns:
            ``(hidden_i, cell_i)`` for LSTM, ``(hidden_i, None)`` otherwise.
        """
        hidden_item = rnn_hidden[0][:, i]
        if self.lstm:
            return hidden_item, rnn_hidden[1][:, i]
        return hidden_item, None