Source code for xuance.torch.communications.comm_net

from argparse import Namespace

import torch
import torch.nn as nn

from xuance.common import Optional, Union, Sequence
from xuance.torch import Module


class CommNet(Module):
    """Communication module that pools agents' messages by an alive-masked mean.

    Two single-layer linear encoders are built at construction time:

    * ``obs_encoder``: ``input_shape[0]`` -> ``recurrent_hidden_size``
    * ``msg_encoder``: ``recurrent_hidden_size`` -> ``recurrent_hidden_size``

    ``forward`` masks every outgoing message with its sender's alive flag,
    pools the masked messages across agents, divides by the number of alive
    agents (clamped to at least 1) and passes the result through
    ``msg_encoder``.

    Args:
        input_shape: Observation shape; only ``input_shape[0]`` is used, as
            the input width of ``obs_encoder``.
        hidden_sizes: Mapping that must contain the keys ``"fc_hidden_sizes"``
            and ``"recurrent_hidden_size"``.
        comm_passes: Stored on the instance; not read by any method of this
            class.
        model_keys: Keys of the models; ``model_keys[0]`` selects the message
            tensor when parameter sharing is enabled.
        agent_keys: Keys of the agents, used to index the ``alive_ally`` (and,
            without parameter sharing, the ``msg_send``) dictionaries.
        n_agents: Number of agents; used to split the batch axis when
            parameter sharing is enabled.
        device: Device on which the encoders are placed and to which inputs
            are converted.
        config: Namespace providing ``use_parameter_sharing``. Although the
            default is ``None``, the attribute is read unconditionally, so
            omitting ``config`` raises ``AttributeError``.
        **kwargs: Accepted and ignored.
    """

    def __init__(self,
                 input_shape: Sequence[int],
                 hidden_sizes: dict,
                 comm_passes: Optional[int] = 1,
                 model_keys: dict = None,
                 agent_keys: dict = None,
                 n_agents: int = 1,
                 device: Optional[Union[str, int, torch.device]] = None,
                 config: Optional[Namespace] = None,
                 **kwargs):
        super().__init__()

        # Plain bookkeeping attributes.
        self.input_shape = input_shape
        self.device = device
        self.fc_hidden_sizes = hidden_sizes["fc_hidden_sizes"]
        self.recurrent_hidden_size = hidden_sizes["recurrent_hidden_size"]
        self.comm_passes = comm_passes
        self.model_keys = model_keys
        self.agent_keys = agent_keys
        self.n_agents = n_agents
        self.config = config
        self.use_parameter_sharing = self.config.use_parameter_sharing

        # Both encoders are a single linear layer (no hidden layers requested),
        # so the activation handed to create_mlp never ends up in the network.
        # The observation encoder is created first, then the message encoder.
        width = self.recurrent_hidden_size
        self.obs_encoder = self.create_mlp(input_shape[0], [], width,
                                           nn.LeakyReLU(), self.device)
        self.msg_encoder = self.create_mlp(width, [], width,
                                           nn.LeakyReLU(), self.device)

    def create_mlp(self,
                   input_shape: int,
                   layers: list,
                   out_shape: int,
                   activation: nn.Module,
                   device: Union[str, torch.device]) -> nn.Sequential:
        """Build a fully connected network and move it to ``device``.

        Args:
            input_shape: Width of the input features.
            layers: Widths of the hidden layers; may be empty.
            out_shape: Width of the output features.
            activation: Module inserted between consecutive linear layers.
                The same instance is reused at every position.
            device: Target device of the returned network.

        Returns:
            ``nn.Sequential`` of linear layers separated by ``activation``;
            no activation follows the last linear layer.
        """
        widths = [input_shape, *layers, out_shape]
        modules = []
        for position, (fan_in, fan_out) in enumerate(zip(widths, widths[1:])):
            if position:
                # Activation sits between linear layers, never after the last.
                modules.append(activation)
            modules.append(nn.Linear(fan_in, fan_out))
        return nn.Sequential(*modules).to(device=device)

    def message_encode(self, message: torch.Tensor) -> torch.Tensor:
        """Apply ``msg_encoder`` to an already pooled message tensor."""
        encoded = self.msg_encoder(message)
        return encoded

    def obs_encode(self, observation):
        """Convert ``observation`` to a float32 tensor on ``self.device`` and encode it.

        Args:
            observation: Anything ``torch.as_tensor`` accepts.

        Returns:
            The output of ``obs_encoder`` for the converted observation.
        """
        as_float = torch.as_tensor(observation, dtype=torch.float32,
                                   device=self.device)
        return self.obs_encoder(as_float)

    def forward(self, obs: torch.Tensor, msg_send: dict, alive_ally: dict) -> torch.Tensor:
        """Pool the agents' messages and encode what each receiver gets.

        Args:
            obs: Only its first two dimensions are read, as ``batch_size`` and
                ``seq_length``.
            msg_send: Outgoing messages. With parameter sharing, a single
                tensor under ``model_keys[0]`` that is reshaped to
                ``(batch_size // n_agents, n_agents, seq_length, -1)``.
                Otherwise one tensor per entry of ``agent_keys``.
            alive_ally: One alive mask per entry of ``agent_keys``; each is
                converted to float32 on ``self.device`` and must broadcast
                against the corresponding message.
                NOTE(review): presumably 1.0 = alive, 0.0 = dead - confirm
                with the caller.

        Returns:
            The encoded incoming messages. With parameter sharing the result
            is viewed as ``(batch_size, seq_length, -1)``.
        """
        alive = {
            agent: torch.as_tensor(alive_ally[agent], dtype=torch.float32,
                                   device=self.device)
            for agent in self.agent_keys
        }
        batch_size = obs.shape[0]
        seq_length = obs.shape[1]

        if self.use_parameter_sharing:
            shared_key = self.model_keys[0]
            # The agents are folded into the batch axis; unfold them.
            grouped = msg_send[shared_key].view(
                batch_size // self.n_agents, self.n_agents, seq_length, -1)
            mask = torch.stack(list(alive.values()), dim=1)
            # Alive count with a singleton agent axis so it broadcasts against
            # `grouped`; clamped so an all-dead group cannot divide by zero.
            n_alive = mask.sum(dim=1).unsqueeze(1).clamp(min=1.0)
            grouped = grouped * mask
            # Each agent receives the sum of everybody else's masked message.
            # The divisor is the total alive count (the receiver included).
            pooled = (grouped.sum(dim=1, keepdim=True) - grouped) / n_alive
            encoded = self.message_encode(pooled)
            # Fold the agents back into the batch axis.
            return encoded.view(batch_size, seq_length, -1)

        # Without parameter sharing every agent gets the same pooled message:
        # the masked mean over all agents.
        masked = torch.stack(
            [msg_send[agent] * flag for agent, flag in alive.items()], dim=1)
        mask = torch.stack(list(alive.values()), dim=1)
        n_alive = mask.sum(dim=1).clamp(min=1.0)
        pooled = masked.sum(dim=1) / n_alive
        return self.message_encode(pooled)