Source code for xuance.torch.agents.multi_agent_rl.tarmac_agents

from argparse import Namespace
from typing import Union, Optional, Dict, List
import gymnasium as gym
import numpy as np
import torch
from gymnasium import Space
from torch.nn import Module, ModuleDict

from xuance.torch import REGISTRY_Policy
from xuance.torch.communications.attention_comm import TarMAC
from xuance.torch.utils import NormalizeFunctions, ActivationFunctions
from xuance.common import MultiAgentBaseCallback

from xuance.environment import DummyVecMultiAgentEnv, SubprocVecMultiAgentEnv, space2shape
from xuance.torch.agents.multi_agent_rl.ic3net_agents import IC3Net_Agents


class TarMAC_Agents(IC3Net_Agents):
    """IC3Net-derived multi-agent agents whose policy uses TarMAC communicators.

    On top of the parent initialisation, the constructor builds the policy,
    the memory and the learner through the ``_build_*`` hooks. This class
    overrides ``_build_communicator`` and ``_build_policy``; the remaining
    hooks (``_build_memory``, ``_build_learner``, ``_build_representation``)
    are not defined here and are expected to come from the parent classes.
    """

    def __init__(
            self,
            config: Namespace,
            envs: Optional[DummyVecMultiAgentEnv | SubprocVecMultiAgentEnv] = None,
            num_agents: Optional[int] = None,
            agent_keys: Optional[List[str]] = None,
            state_space: Optional[Space] = None,
            observation_space: Optional[Space] = None,
            action_space: Optional[Space] = None,
            callback: Optional[MultiAgentBaseCallback] = None
    ):
        """Initialise the parent agent, then build policy, memory and learner.

        Args:
            config: Experiment configuration namespace.
            envs: Vectorised multi-agent environments, if any.
            num_agents: Number of agents.
            agent_keys: Names of the agents.
            state_space: Global state space.
            observation_space: Observation space.
            action_space: Action space.
            callback: Optional callback; also handed to the learner builder.

        All arguments are forwarded positionally, in this order, to
        ``IC3Net_Agents.__init__``.
        """
        super().__init__(
            config,
            envs,
            num_agents,
            agent_keys,
            state_space,
            observation_space,
            action_space,
            callback,
        )
        self.policy = self._build_policy()
        # Memory is created after the policy and before the learner.
        self.memory = self._build_memory()
        self.learner = self._build_learner(
            self.config, self.model_keys, self.agent_keys, self.policy, callback
        )

    def _build_communicator(
            self,
            input_space: Union[Dict[str, Space], Dict[str, tuple]],
    ) -> Module:
        """Create one TarMAC module per model key.

        Args:
            input_space: Mapping indexed by model key; each entry is converted
                to a shape with ``space2shape``.

        Returns:
            A ``ModuleDict`` mapping every key in ``self.model_keys`` to its
            ``TarMAC`` instance.
        """
        comm_modules = ModuleDict()
        # The same sizes dict object is handed to every TarMAC instance.
        layer_sizes = {'fc_hidden_sizes': self.config.fc_hidden_sizes,
                       'recurrent_hidden_size': self.config.recurrent_hidden_size}
        for model_key in self.model_keys:
            comm_modules[model_key] = TarMAC(
                input_shape=space2shape(input_space[model_key]),
                hidden_sizes=layer_sizes,
                comm_passes=self.config.comm_passes,
                model_keys=self.model_keys,
                agent_keys=self.agent_keys,
                n_agents=self.n_agents,
                device=self.device,
                config=self.config,
            )
        return comm_modules

    def _build_policy(self) -> Module:
        """Build the TarMAC policy together with its communicator and representations.

        Side effects:
            * ``self.observation_space`` is replaced by a dict of unbounded
              float32 ``Box`` spaces, all with the length of the largest
              first dimension among the original observation spaces.
            * ``self.continuous_control`` is set to ``False`` once the policy
              has been created.

        Returns:
            The policy instance looked up in ``REGISTRY_Policy``.

        Raises:
            AttributeError: If ``self.config.policy`` is not ``"TarMAC_Policy"``.
        """
        if hasattr(self.config, "normalize"):
            norm_layer = NormalizeFunctions[self.config.normalize]
        else:
            norm_layer = None
        weight_init = torch.nn.init.orthogonal_
        act_fn = ActivationFunctions[self.config.activation]
        target_device = self.device
        agent_name = self.config.agent

        # Give every agent an observation space of the same (maximum) length.
        longest_obs = max(obs_space.shape[0] for obs_space in self.observation_space.values())
        self.observation_space = {
            key: gym.spaces.Box(-np.inf, np.inf, (longest_obs,), dtype=np.float32)
            for key in self.observation_space
        }

        # build representations
        communicator = self._build_communicator(self.observation_space)
        # The actor representation input has length config.recurrent_hidden_size.
        actor_input_space = {
            key: gym.spaces.Box(-np.inf, np.inf, (self.config.recurrent_hidden_size,), dtype=np.float32)
            for key in self.observation_space
        }
        # The critic input is either the global state or all observations combined.
        if self.use_global_state:
            critic_input_dim = sum(self.state_space.shape)
        else:
            critic_input_dim = sum(
                sum(self.observation_space[key].shape) for key in self.agent_keys
            )
        critic_input_space = {key: (critic_input_dim,) for key in self.agent_keys}
        actor_representation = self._build_representation(
            self.config.representation, actor_input_space, self.config
        )
        critic_representation = self._build_representation(
            self.config.representation, critic_input_space, self.config
        )

        # build policies
        if self.config.policy != "TarMAC_Policy":
            raise AttributeError(f"{agent_name} currently does not support the policy named {self.config.policy}.")

        policy_cls = REGISTRY_Policy[self.config.policy]
        policy = policy_cls(
            action_space=self.action_space,
            n_agents=self.n_agents,
            representation_actor=actor_representation,
            representation_critic=critic_representation,
            actor_hidden_size=self.config.actor_hidden_size,
            critic_hidden_size=self.config.critic_hidden_size,
            normalize=norm_layer,
            initialize=weight_init,
            activation=act_fn,
            device=target_device,
            use_distributed_training=self.distributed_training,
            use_parameter_sharing=self.use_parameter_sharing,
            model_keys=self.model_keys,
            use_rnn=self.use_rnn,
            rnn=self.config.rnn if self.use_rnn else None,
            communicator=communicator,
            agent_keys=self.agent_keys,
            comm_passes=self.config.comm_passes,
            config=self.config,
        )
        self.continuous_control = False
        return policy