# Source code for xuance.torch.communications.base_comm
import torch
import torch.nn as nn
from xuance.common import Optional, Callable, Union, Sequence
from xuance.torch import Module, Tensor
from xuance.torch.utils import mlp_block, ModuleType
# [docs] -- documentation-link marker left over from the HTML source listing; not code.
class BaseComm(Module):
    """Message encoder assembled from a chain of ``mlp_block`` layers.

    The encoder is stored as ``self.msg_encoder`` (an ``nn.Sequential``) and
    is applied directly to the input of :meth:`forward`.

    Args:
        state_dim: Input dimension handed to the first block.
        n_agents: Stored on the instance; not otherwise used by this class.
        hidden_sizes_comm: One hidden block is created per entry, in order.
        msg_dim: Output dimension handed to the final block.
        normalize: Forwarded to every hidden block (the final block gets None).
        initialize: Forwarded to every block, including the final one.
        activation: Forwarded to every hidden block (the final block gets None).
        device: Forwarded to every block.
        **kwargs: Accepted and ignored.
    """

    def __init__(self,
                 state_dim: int,
                 n_agents: int,
                 hidden_sizes_comm: Sequence[int],
                 msg_dim: int,
                 normalize: Optional[ModuleType] = None,
                 initialize: Optional[Callable[..., Tensor]] = None,
                 activation: Optional[ModuleType] = None,
                 device: Optional[Union[str, int, torch.device]] = None,
                 **kwargs):
        super().__init__()
        self.n_agents = n_agents
        self.msg_dim = msg_dim
        self.hidden_sizes_comm = hidden_sizes_comm

        # NOTE(review): mlp_block is unpacked here as (layers, output_shape),
        # with output_shape[0] feeding the next block -- confirm against its
        # definition in xuance.torch.utils.
        encoder_layers = []
        current_shape = (state_dim,)
        for width in hidden_sizes_comm:
            block_layers, current_shape = mlp_block(
                current_shape[0], width, normalize, activation, initialize, device
            )
            encoder_layers += block_layers

        # Final projection to the message size: no normalisation or activation
        # is passed for this block.
        output_layers = mlp_block(
            current_shape[0], msg_dim, None, None, initialize, device
        )[0]
        encoder_layers += output_layers

        self.msg_encoder = nn.Sequential(*encoder_layers)

    def forward(self, hidden_features: Tensor):
        """Run ``hidden_features`` through ``self.msg_encoder`` and return the result."""
        return self.msg_encoder(hidden_features)
# [docs] -- documentation-link marker left over from the HTML source listing; not code.
class NoneComm(Module):
    """Pass-through communication module: ``forward`` returns its input as-is."""

    def __init__(self):
        # No parameters or sub-modules are registered; only the base
        # initialiser runs.
        super().__init__()

    def forward(self, msg: Tensor, **kwargs) -> Tensor:
        """Return ``msg`` unchanged.

        Args:
            msg: The message to pass through.
            **kwargs: Accepted and ignored.

        Returns:
            The very same ``msg`` object that was passed in.
        """
        return msg