Source code for xuance.torch.policies.categorical_marl

import os
import torch
from copy import deepcopy
from operator import itemgetter
from gymnasium.spaces import Discrete
from torch.distributions import Categorical
from torch.nn.functional import one_hot
from xuance.common import Sequence, Optional, Callable, Union, Dict, List
from xuance.torch.policies import CategoricalActorNet, ActorNet
from xuance.torch.policies.core import CriticNet, BasicQhead
from xuance.torch.utils import ModuleType
from xuance.torch.representations import Basic_MLP
from xuance.torch import nn, Tensor, Module, ModuleDict, DistributedDataParallel
from .core import CategoricalActorNet_SAC as Actor_SAC


class MAAC_Policy(Module):
    """Multi-agent actor-critic policy with categorical (discrete) action distributions.

    For every entry of ``model_keys`` the policy owns one actor head and one critic head,
    each fed by its own representation module. An optional mixer can combine the
    individual values into a total value.

    Args:
        action_space (Optional[Dict[str, Discrete]]): The discrete action space of every model key.
        n_agents (int): The number of agents.
        representation_actor (ModuleDict): Representation modules for the actors, indexed by model key.
        representation_critic (ModuleDict): Representation modules for the critics, indexed by model key.
        mixer (Optional[Module]): Module that mixes the individual values into a total value.
        actor_hidden_size (Sequence[int]): Hidden layer sizes handed to each actor network.
        critic_hidden_size (Sequence[int]): Hidden layer sizes handed to each critic network.
        normalize (Optional[ModuleType]): Normalization layer handed to the actor/critic networks.
        initialize (Optional[Callable[..., Tensor]]): Parameter initializer handed to the actor/critic networks.
        activation (Optional[ModuleType]): Activation function handed to the actor/critic networks.
        device (Optional[Union[str, int, torch.device]]): The calculating device.
        use_distributed_training (bool): Whether to wrap the sub-modules in DistributedDataParallel.
        **kwargs: Must provide ``use_parameter_sharing``, ``model_keys``, ``rnn`` and ``use_rnn``.
    """

    def __init__(self,
                 action_space: Optional[Dict[str, Discrete]],
                 n_agents: int,
                 representation_actor: ModuleDict,
                 representation_critic: ModuleDict,
                 mixer: Optional[Module] = None,
                 actor_hidden_size: Sequence[int] = None,
                 critic_hidden_size: Sequence[int] = None,
                 normalize: Optional[ModuleType] = None,
                 initialize: Optional[Callable[..., Tensor]] = None,
                 activation: Optional[ModuleType] = None,
                 device: Optional[Union[str, int, torch.device]] = None,
                 use_distributed_training: bool = False,
                 **kwargs):
        super().__init__()
        self.device = device
        self.action_space = action_space
        self.n_agents = n_agents
        self.use_parameter_sharing = kwargs['use_parameter_sharing']
        self.model_keys = kwargs['model_keys']
        self.lstm = kwargs["rnn"] == "LSTM"
        self.use_rnn = bool(kwargs["use_rnn"])

        self.actor_representation = representation_actor
        self.critic_representation = representation_critic

        self.dim_input_critic, self.n_actions = {}, {}
        self.actor, self.critic = ModuleDict(), ModuleDict()
        for key in self.model_keys:
            self.n_actions[key] = self.action_space[key].n
            actor_rep_dim = self.actor_representation[key].output_shapes['state'][0]
            critic_rep_dim = self.critic_representation[key].output_shapes['state'][0]
            dim_actor_in, dim_actor_out, dim_critic_in, _ = self._get_actor_critic_input(
                self.n_actions[key], actor_rep_dim, critic_rep_dim, n_agents)
            # The actor is constructed before the critic for every key.
            self.actor[key] = CategoricalActorNet(dim_actor_in, dim_actor_out, actor_hidden_size,
                                                  normalize, initialize, activation, device)
            self.critic[key] = CriticNet(dim_critic_in, critic_hidden_size,
                                         normalize, initialize, activation, device)
        self.mixer = mixer

        # Prepare DDP module.
        self.distributed_training = use_distributed_training
        if self.distributed_training:
            self.rank = int(os.environ["RANK"])
            for key in self.model_keys:
                # Representations named "Basic_Identical" are left unwrapped.
                if self.actor_representation[key]._get_name() != "Basic_Identical":
                    self.actor_representation[key] = DistributedDataParallel(
                        self.actor_representation[key], device_ids=[self.rank])
                if self.critic_representation[key]._get_name() != "Basic_Identical":
                    self.critic_representation[key] = DistributedDataParallel(
                        self.critic_representation[key], device_ids=[self.rank])
                self.actor[key] = DistributedDataParallel(module=self.actor[key], device_ids=[self.rank])
                self.critic[key] = DistributedDataParallel(module=self.critic[key], device_ids=[self.rank])
            if self.mixer is not None:
                self.mixer = DistributedDataParallel(module=self.mixer, device_ids=[self.rank])

    @property
    def parameters_model(self):
        """List of all parameters: actor representation, actor, critic representation, critic, then mixer."""
        collected = []
        for module in (self.actor_representation, self.actor, self.critic_representation, self.critic):
            collected.extend(module.parameters())
        if self.mixer is not None:
            collected.extend(self.mixer.parameters())
        return collected

    def _get_actor_critic_input(self, dim_action, dim_actor_rep, dim_critic_rep, n_agents):
        """
        Returns the input/output dimensions of the actor and critic networks.

        Parameters:
            dim_action: The dimension of actions.
            dim_actor_rep: The dimension of the output of actor representation.
            dim_critic_rep: The dimension of the output of critic representation.
            n_agents: The number of agents.

        Returns:
            dim_actor_in: The dimension of input of the actor networks.
            dim_actor_out: The dimension of output of the actor networks.
            dim_critic_in: The dimension of the input of critic networks.
            dim_critic_out: The dimension of the output of critic networks.
        """
        # With parameter sharing, n_agents extra inputs are reserved for the agent ids.
        id_width = n_agents if self.use_parameter_sharing else 0
        return dim_actor_rep + id_width, dim_action, dim_critic_rep + id_width, dim_action

    def forward(self, observation: Dict[str, Tensor], agent_ids: Optional[Tensor] = None,
                avail_actions: Dict[str, Tensor] = None, agent_key: str = None,
                rnn_hidden: Optional[Dict[str, List[Tensor]]] = None):
        """
        Returns actions of the policy.

        Parameters:
            observation (Dict[str, Tensor]): The input observations for the policies.
            agent_ids (Tensor): The agents' ids (concatenated to the features under parameter sharing).
            avail_actions (Dict[str, Tensor]): Actions mask values, default is None.
            agent_key (str): Calculate actions for specified agent only; None means all model keys.
            rnn_hidden (Optional[Dict[str, List[Tensor]]]): The RNN hidden states of actor representation.

        Returns:
            rnn_hidden_new (Optional[Dict[str, List[Tensor]]]): The new RNN hidden states of actor
                representation ([None, None] per key when no RNN is used).
            pi_dists (dict): The outputs of the actor heads, one entry per evaluated key.
        """
        agent_list = [agent_key] if agent_key is not None else self.model_keys
        masks = None
        if avail_actions is not None:
            masks = {k: Tensor(avail_actions[k]) for k in agent_list}

        hidden_out, dists = {}, {}
        for k in agent_list:
            encoder = self.actor_representation[k]
            if self.use_rnn:
                rep = encoder(observation[k], *rnn_hidden[k])
                hidden_out[k] = (rep['rnn_hidden'], rep['rnn_cell'])
            else:
                rep = encoder(observation[k])
                hidden_out[k] = [None, None]

            features = rep['state']
            if self.use_parameter_sharing:
                features = torch.concat([features, agent_ids], dim=-1)
            dists[k] = self.actor[k](features, None if masks is None else masks[k])
        return hidden_out, dists

    def get_values(self, observation: Dict[str, Tensor], agent_ids: Tensor = None, agent_key: str = None,
                   rnn_hidden: Optional[Dict[str, List[Tensor]]] = None):
        """
        Get critic values via critic networks.

        Parameters:
            observation (Dict[str, Tensor]): The input observations for the policies.
            agent_ids (Tensor): The agents' ids (concatenated to the features under parameter sharing).
            agent_key (str): Calculate values for specified agent only; None means all model keys.
            rnn_hidden (Optional[Dict[str, List[Tensor]]]): The RNN hidden states of critic representation.

        Returns:
            rnn_hidden_new (Optional[Dict[str, List[Tensor]]]): The new RNN hidden states of critic
                representation ([None, None] per key when no RNN is used).
            values (dict): The evaluated critic values.
        """
        agent_list = [agent_key] if agent_key is not None else self.model_keys
        hidden_out, values = {}, {}
        for k in agent_list:
            encoder = self.critic_representation[k]
            if self.use_rnn:
                rep = encoder(observation[k], *rnn_hidden[k])
                hidden_out[k] = (rep['rnn_hidden'], rep['rnn_cell'])
            else:
                rep = encoder(observation[k])
                hidden_out[k] = [None, None]

            features = rep['state']
            if self.use_parameter_sharing:
                features = torch.concat([features, agent_ids], dim=-1)
            values[k] = self.critic[k](features)
        return hidden_out, values

    def value_tot(self, values_n: Tensor, global_state=None):
        """
        Mix the individual values into a total value.

        Parameters:
            values_n (Tensor): The individual values.
            global_state: Optional global state; converted to a tensor on ``self.device`` when given.

        Returns:
            ``values_n`` unchanged when no mixer is set, otherwise ``mixer(values_n, global_state)``.
        """
        if global_state is not None:
            global_state = torch.as_tensor(global_state).to(self.device)
        if self.mixer is None:
            return values_n
        return self.mixer(values_n, global_state)
class MAAC_Policy_Share(MAAC_Policy):
    """Multi-agent actor-critic policy whose actor and critic share one representation per model key.

    Args:
        action_space (Discrete): The discrete action space; it is indexed by model key in ``__init__``.
        n_agents (int): The number of agents.
        representation (Module): Representation modules shared by actor and critic; indexed by model key.
        mixer (Optional[Module]): Module that mixes the individual values into a total value.
        actor_hidden_size (Sequence[int]): Hidden layer sizes handed to each actor network.
        critic_hidden_size (Sequence[int]): Hidden layer sizes handed to each critic network.
        normalize (Optional[ModuleType]): Normalization layer handed to the actor/critic networks.
        initialize (Optional[Callable[..., Tensor]]): Parameter initializer handed to the actor/critic networks.
        activation (Optional[ModuleType]): Activation function handed to the actor/critic networks.
        device (Optional[Union[str, int, torch.device]]): The calculating device.
        use_distributed_training (bool): Whether to wrap the sub-modules in DistributedDataParallel.
        **kwargs: Must provide ``use_parameter_sharing``, ``model_keys``, ``rnn`` and ``use_rnn``.
    """

    def __init__(self,
                 action_space: Discrete,
                 n_agents: int,
                 representation: Module,
                 mixer: Optional[Module] = None,
                 actor_hidden_size: Sequence[int] = None,
                 critic_hidden_size: Sequence[int] = None,
                 normalize: Optional[ModuleType] = None,
                 initialize: Optional[Callable[..., Tensor]] = None,
                 activation: Optional[ModuleType] = None,
                 device: Optional[Union[str, int, torch.device]] = None,
                 use_distributed_training: bool = False,
                 **kwargs):
        # Start the MRO lookup *after* MAAC_Policy: its __init__ expects separate actor/critic
        # representations, so it is skipped on purpose.
        super(MAAC_Policy, self).__init__()
        self.device = device
        self.action_space = action_space
        self.n_agents = n_agents
        self.use_parameter_sharing = kwargs['use_parameter_sharing']
        self.model_keys = kwargs['model_keys']
        self.lstm = kwargs["rnn"] == "LSTM"
        self.use_rnn = bool(kwargs["use_rnn"])

        self.representation = representation

        self.dim_input_critic, self.n_actions = {}, {}
        self.actor, self.critic = ModuleDict(), ModuleDict()
        for key in self.model_keys:
            self.n_actions[key] = self.action_space[key].n
            dim_rep = self.representation[key].output_shapes['state'][0]
            dim_actor_in, dim_actor_out, dim_critic_in, _ = self._get_actor_critic_input(
                self.n_actions[key], dim_rep, dim_rep, n_agents)
            self.actor[key] = CategoricalActorNet(dim_actor_in, dim_actor_out, actor_hidden_size,
                                                  normalize, initialize, activation, device)
            self.critic[key] = CriticNet(dim_critic_in, critic_hidden_size,
                                         normalize, initialize, activation, device)
        self.mixer = mixer

        # Prepare DDP module.
        self.distributed_training = use_distributed_training
        if self.distributed_training:
            self.rank = int(os.environ["RANK"])
            for key in self.model_keys:
                # Representations named "Basic_Identical" are left unwrapped.
                if self.representation[key]._get_name() != "Basic_Identical":
                    self.representation[key] = DistributedDataParallel(self.representation[key],
                                                                       device_ids=[self.rank])
                self.actor[key] = DistributedDataParallel(module=self.actor[key], device_ids=[self.rank])
                self.critic[key] = DistributedDataParallel(module=self.critic[key], device_ids=[self.rank])
            if self.mixer is not None:
                self.mixer = DistributedDataParallel(module=self.mixer, device_ids=[self.rank])

    @property
    def parameters_model(self):
        """List of all parameters: shared representation, actor, critic, then mixer (if any)."""
        collected = []
        for module in (self.representation, self.actor, self.critic):
            collected.extend(module.parameters())
        if self.mixer is not None:
            collected.extend(self.mixer.parameters())
        return collected

    def forward(self, observation: Dict[str, Tensor], agent_ids: Optional[Tensor] = None,
                avail_actions: Dict[str, Tensor] = None, agent_key: str = None,
                rnn_hidden: Optional[Dict[str, List[Tensor]]] = None):
        """
        Returns actions of the policy.

        Parameters:
            observation (Dict[str, Tensor]): The input observations for the policies.
            agent_ids (Tensor): The agents' ids (concatenated to the features under parameter sharing).
            avail_actions (Dict[str, Tensor]): Actions mask values, default is None.
            agent_key (str): Calculate actions for specified agent only; None means all model keys.
            rnn_hidden (Optional[Dict[str, List[Tensor]]]): The RNN hidden states of the shared representation.

        Returns:
            rnn_hidden_new (Optional[Dict[str, List[Tensor]]]): The new RNN hidden states of the shared
                representation ([None, None] per key when no RNN is used).
            pi_dists (dict): The outputs of the actor heads, one entry per evaluated key.
        """
        agent_list = [agent_key] if agent_key is not None else self.model_keys
        masks = None
        if avail_actions is not None:
            masks = {k: Tensor(avail_actions[k]) for k in agent_list}

        hidden_out, dists = {}, {}
        for k in agent_list:
            encoder = self.representation[k]
            if self.use_rnn:
                rep = encoder(observation[k], *rnn_hidden[k])
                hidden_out[k] = (rep['rnn_hidden'], rep['rnn_cell'])
            else:
                rep = encoder(observation[k])
                hidden_out[k] = [None, None]

            features = rep['state']
            if self.use_parameter_sharing:
                features = torch.concat([features, agent_ids], dim=-1)
            dists[k] = self.actor[k](features, None if masks is None else masks[k])
        return hidden_out, dists

    def get_values(self, observation: Dict[str, Tensor], agent_ids: Tensor = None, agent_key: str = None,
                   rnn_hidden: Optional[Dict[str, List[Tensor]]] = None):
        """
        Get critic values via critic networks.

        Parameters:
            observation (Dict[str, Tensor]): The input observations for the policies.
            agent_ids (Tensor): The agents' ids (concatenated to the features under parameter sharing).
            agent_key (str): Calculate values for specified agent only; None means all model keys.
            rnn_hidden (Optional[Dict[str, List[Tensor]]]): The RNN hidden states of the shared representation.

        Returns:
            rnn_hidden_new (Optional[Dict[str, List[Tensor]]]): The new RNN hidden states of the shared
                representation ([None, None] per key when no RNN is used).
            values (dict): The evaluated critic values.
        """
        agent_list = [agent_key] if agent_key is not None else self.model_keys
        hidden_out, values = {}, {}
        for k in agent_list:
            encoder = self.representation[k]
            if self.use_rnn:
                rep = encoder(observation[k], *rnn_hidden[k])
                hidden_out[k] = (rep['rnn_hidden'], rep['rnn_cell'])
            else:
                rep = encoder(observation[k])
                hidden_out[k] = [None, None]

            features = rep['state']
            if self.use_parameter_sharing:
                features = torch.concat([features, agent_ids], dim=-1)
            values[k] = self.critic[k](features)
        return hidden_out, values
class CommNet_Policy(MAAC_Policy):
    """MAAC policy whose actor inputs are refined by rounds of inter-agent communication.

    Takes the same arguments as ``MAAC_Policy``; ``**kwargs`` must additionally provide
    ``communicator`` (indexed by model key), ``agent_keys`` and ``comm_passes``.
    """

    def __init__(self,
                 action_space: Optional[Dict[str, Discrete]],
                 n_agents: int,
                 representation_actor: ModuleDict,
                 representation_critic: ModuleDict,
                 mixer: Optional[Module] = None,
                 actor_hidden_size: Sequence[int] = None,
                 critic_hidden_size: Sequence[int] = None,
                 normalize: Optional[ModuleType] = None,
                 initialize: Optional[Callable[..., Tensor]] = None,
                 activation: Optional[ModuleType] = None,
                 device: Optional[Union[str, int, torch.device]] = None,
                 use_distributed_training: bool = False,
                 **kwargs):
        super().__init__(action_space=action_space,
                         n_agents=n_agents,
                         representation_actor=representation_actor,
                         representation_critic=representation_critic,
                         mixer=mixer,
                         actor_hidden_size=actor_hidden_size,
                         critic_hidden_size=critic_hidden_size,
                         normalize=normalize,
                         initialize=initialize,
                         activation=activation,
                         device=device,
                         use_distributed_training=use_distributed_training,
                         **kwargs)
        self.communicator = kwargs['communicator']
        self.agent_keys = kwargs['agent_keys']
        self.comm_passes = kwargs['comm_passes']

    @property
    def parameters_model(self):
        """List of all parameters: actor rep., actor, critic rep., critic, communicator, then mixer."""
        collected = []
        for module in (self.actor_representation, self.actor, self.critic_representation,
                       self.critic, self.communicator):
            collected.extend(module.parameters())
        if self.mixer is not None:
            collected.extend(self.mixer.parameters())
        return collected

    def forward(self, observation: Dict[str, Tensor], agent_ids: Optional[Tensor] = None,
                avail_actions: Dict[str, Tensor] = None, agent_key: str = None,
                rnn_hidden: Optional[Dict[str, List[Tensor]]] = None, alive_ally: Optional[dict] = None):
        """
        Returns actions of the policy after ``comm_passes`` rounds of communication per time step.

        Parameters:
            observation (Dict[str, Tensor]): The input observations; sliced along dim 1, one step at a time.
            agent_ids (Tensor): The agents' ids (sliced per step and used under parameter sharing).
            avail_actions (Dict[str, Tensor]): Actions mask values, default is None.
            agent_key (str): Calculate actions for specified agent only; None means all model keys.
            rnn_hidden (Optional[Dict[str, List[Tensor]]]): The RNN hidden states of actor representation;
                read unconditionally to build the first messages.
            alive_ally (Optional[dict]): Per-agent tensors sliced per step and handed to the communicator;
                indexed unconditionally, so None is not usable despite the default.

        Returns:
            rnn_hidden_new (Optional[Dict[str, List[Tensor]]]): The hidden states produced at the last step.
            pi_dists (dict): The outputs of the actor heads, one entry per evaluated key.
        """
        hidden_out, dists = {}, {}
        agent_list = self.model_keys if agent_key is None else [agent_key]
        seq_length = observation[self.model_keys[0]].shape[1]

        masks = None
        if avail_actions is not None:
            masks = {k: Tensor(avail_actions[k]) for k in agent_list}

        encoded = {k: self.communicator[k].obs_encode(observation[k]) for k in agent_list}
        step_features = {k: [] for k in agent_list}
        for t in range(seq_length):
            step = slice(t, t + 1)  # keeps the time dimension, same as [t:t + 1]
            alive_t = {k: alive_ally[k][:, step, :] for k in self.agent_keys}
            obs_t = {k: encoded[k][:, step, :] for k in agent_list}
            messages = {k: rnn_hidden[k][0].transpose(0, 1) for k in self.model_keys}
            for _ in range(self.comm_passes):
                received = {k: self.communicator[k](obs_t[k], messages, alive_t) for k in self.model_keys}
                messages = {k: obs_t[k] + received[k] for k in self.model_keys}
            # NOTE(review): assumed to run once after all passes rather than inside the
            # communication loop -- confirm against upstream.
            obs_t = messages

            for key in agent_list:
                encoder = self.actor_representation[key]
                if self.use_rnn:
                    rep = encoder(obs_t[key], *rnn_hidden[key])
                    hidden_out[key] = (rep['rnn_hidden'], rep['rnn_cell'])
                else:
                    # NOTE(review): this branch feeds the whole encoded sequence, not the
                    # per-step slice -- kept as is; confirm whether that is intended.
                    rep = encoder(encoded[key])
                    hidden_out[key] = [None, None]

                if self.use_parameter_sharing:
                    ids_t = agent_ids[:, step, :]
                    feature = torch.concat([rep['state'], ids_t], dim=-1)
                else:
                    feature = rep['state']
                step_features[key].append(feature)
            # Carry the fresh hidden states into the next time step.
            rnn_hidden = deepcopy(hidden_out)

        for key in agent_list:
            sequence = torch.cat(step_features[key], dim=1)
            dists[key] = self.actor[key](sequence, None if masks is None else masks[key])
        return hidden_out, dists
class IC3Net_Policy(CommNet_Policy):
    """CommNet-style policy with a sampled binary gate per agent on every communication pass.

    Takes the same arguments as ``CommNet_Policy``; ``**kwargs`` must additionally provide
    ``config`` with ``recurrent_hidden_size`` and ``gate_hidden_size``.
    """

    def __init__(self,
                 action_space: Optional[Dict[str, Discrete]],
                 n_agents: int,
                 representation_actor: ModuleDict,
                 representation_critic: ModuleDict,
                 mixer: Optional[Module] = None,
                 actor_hidden_size: Sequence[int] = None,
                 critic_hidden_size: Sequence[int] = None,
                 normalize: Optional[ModuleType] = None,
                 initialize: Optional[Callable[..., Tensor]] = None,
                 activation: Optional[ModuleType] = None,
                 device: Optional[Union[str, int, torch.device]] = None,
                 use_distributed_training: bool = False,
                 **kwargs):
        super().__init__(action_space=action_space,
                         n_agents=n_agents,
                         representation_actor=representation_actor,
                         representation_critic=representation_critic,
                         mixer=mixer,
                         actor_hidden_size=actor_hidden_size,
                         critic_hidden_size=critic_hidden_size,
                         normalize=normalize,
                         initialize=initialize,
                         activation=activation,
                         device=device,
                         use_distributed_training=use_distributed_training,
                         **kwargs)
        self.config = kwargs['config']
        # NOTE(review): the gates live in a plain dict, so they are not registered as
        # sub-modules and are absent from ``parameters_model`` -- confirm this is intended.
        self.gate = {}
        for k in self.model_keys:
            self.gate[k] = self.communicator[k].create_mlp(self.config.recurrent_hidden_size,
                                                           self.config.gate_hidden_size,
                                                           2, nn.LeakyReLU(), self.device)

    def forward(self, observation: Dict[str, Tensor], agent_ids: Optional[Tensor] = None,
                avail_actions: Dict[str, Tensor] = None, agent_key: str = None,
                rnn_hidden: Optional[Dict[str, List[Tensor]]] = None, alive_ally: Optional[dict] = None):
        """
        Returns actions of the policy together with the log-probabilities of the sampled gates.

        Parameters:
            observation (Dict[str, Tensor]): The input observations; sliced along dim 1, one step at a time.
            agent_ids (Tensor): The agents' ids (sliced per step and used under parameter sharing).
            avail_actions (Dict[str, Tensor]): Actions mask values, default is None.
            agent_key (str): Calculate actions for specified agent only; None means all model keys.
            rnn_hidden (Optional[Dict[str, List[Tensor]]]): The RNN hidden states of actor representation;
                read unconditionally to build the first messages.
            alive_ally (Optional[dict]): Per-agent tensors sliced per step and handed to the communicator.

        Returns:
            rnn_hidden_new (Optional[Dict[str, List[Tensor]]]): The hidden states produced at the last step.
            pi_dists (dict): The outputs of the actor heads, one entry per evaluated key.
            gate_log_probs (dict): Log-probabilities of the gate samples drawn on the last
                communication pass of every step, concatenated along dim 1.
        """
        hidden_out, dists = {}, {}
        agent_list = self.model_keys if agent_key is None else [agent_key]
        seq_length = observation[self.model_keys[0]].shape[1]

        masks = None
        if avail_actions is not None:
            masks = {k: Tensor(avail_actions[k]) for k in agent_list}

        encoded = {k: self.communicator[k].obs_encode(observation[k]) for k in agent_list}
        step_features = {k: [] for k in agent_list}
        step_gate_log_probs = {k: [] for k in agent_list}
        last_pass = self.comm_passes - 1
        for t in range(seq_length):
            step = slice(t, t + 1)  # keeps the time dimension, same as [t:t + 1]
            alive_t = {k: alive_ally[k][:, step, :] for k in self.agent_keys}
            obs_t = {k: encoded[k][:, step, :] for k in agent_list}
            messages = {k: rnn_hidden[k][0].transpose(0, 1) for k in self.model_keys}
            for comm_time in range(self.comm_passes):
                # calculate gate_control
                gate_logits = {k: self.gate[k](messages[k]) for k in agent_list}
                gate_dist = {k: Categorical(logits=gate_logits[k]) for k in agent_list}
                gate_control = {k: gate_dist[k].sample() for k in agent_list}
                gate_log_prob = {k: gate_dist[k].log_prob(gate_control[k]) for k in agent_list}
                comm_out = {k: self.communicator[k](obs_t[k], messages, alive_t, gate_control)
                            for k in self.model_keys}
                messages = {k: obs_t[k] + comm_out[k] for k in self.model_keys}
                if comm_time == last_pass:
                    for k in agent_list:
                        step_gate_log_probs[k].append(gate_log_prob[k])
            # NOTE(review): assumed to run once after all passes rather than inside the
            # communication loop -- confirm against upstream.
            obs_t = messages

            for key in agent_list:
                encoder = self.actor_representation[key]
                if self.use_rnn:
                    rep = encoder(obs_t[key], *rnn_hidden[key])
                    hidden_out[key] = (rep['rnn_hidden'], rep['rnn_cell'])
                else:
                    # NOTE(review): this branch feeds the whole encoded sequence, not the
                    # per-step slice -- kept as is; confirm whether that is intended.
                    rep = encoder(encoded[key])
                    hidden_out[key] = [None, None]

                if self.use_parameter_sharing:
                    ids_t = agent_ids[:, step, :]
                    feature = torch.concat([rep['state'], ids_t], dim=-1)
                else:
                    feature = rep['state']
                step_features[key].append(feature)
            # Carry the fresh hidden states into the next time step.
            rnn_hidden = deepcopy(hidden_out)

        for key in agent_list:
            sequence = torch.cat(step_features[key], dim=1)
            dists[key] = self.actor[key](sequence, None if masks is None else masks[key])

        # NOTE(review): iterates ``self.model_keys`` although the lists are keyed by
        # ``agent_list`` -- kept as is; confirm.
        gate_log_probs = {k: torch.cat(step_gate_log_probs[k], dim=1) for k in self.model_keys}
        return hidden_out, dists, gate_log_probs
class TarMAC_Policy(IC3Net_Policy):
    """IC3Net-style gated communication policy that also passes the receiving agent's key
    to the communicator on every pass.

    Takes the same arguments as ``IC3Net_Policy``.
    """

    def __init__(self,
                 action_space: Optional[Dict[str, Discrete]],
                 n_agents: int,
                 representation_actor: ModuleDict,
                 representation_critic: ModuleDict,
                 mixer: Optional[Module] = None,
                 actor_hidden_size: Sequence[int] = None,
                 critic_hidden_size: Sequence[int] = None,
                 normalize: Optional[ModuleType] = None,
                 initialize: Optional[Callable[..., Tensor]] = None,
                 activation: Optional[ModuleType] = None,
                 device: Optional[Union[str, int, torch.device]] = None,
                 use_distributed_training: bool = False,
                 **kwargs):
        super().__init__(action_space=action_space,
                         n_agents=n_agents,
                         representation_actor=representation_actor,
                         representation_critic=representation_critic,
                         mixer=mixer,
                         actor_hidden_size=actor_hidden_size,
                         critic_hidden_size=critic_hidden_size,
                         normalize=normalize,
                         initialize=initialize,
                         activation=activation,
                         device=device,
                         use_distributed_training=use_distributed_training,
                         **kwargs)

    def forward(self, observation: Dict[str, Tensor], agent_ids: Optional[Tensor] = None,
                avail_actions: Dict[str, Tensor] = None, agent_key: str = None,
                rnn_hidden: Optional[Dict[str, List[Tensor]]] = None, alive_ally: Optional[dict] = None):
        """
        Returns actions of the policy together with the log-probabilities of the sampled gates.

        Parameters:
            observation (Dict[str, Tensor]): The input observations; sliced along dim 1, one step at a time.
            agent_ids (Tensor): The agents' ids (sliced per step and used under parameter sharing).
            avail_actions (Dict[str, Tensor]): Actions mask values, default is None.
            agent_key (str): Calculate actions for specified agent only; None means all model keys.
            rnn_hidden (Optional[Dict[str, List[Tensor]]]): The RNN hidden states of actor representation;
                read unconditionally to build the first messages.
            alive_ally (Optional[dict]): Per-agent tensors sliced per step and handed to the communicator.

        Returns:
            rnn_hidden_new (Optional[Dict[str, List[Tensor]]]): The hidden states produced at the last step.
            pi_dists (dict): The outputs of the actor heads, one entry per evaluated key.
            gate_log_probs (dict): Log-probabilities of the gate samples drawn on the last
                communication pass of every step, concatenated along dim 1.
        """
        hidden_out, dists = {}, {}
        agent_list = self.model_keys if agent_key is None else [agent_key]
        first_key = self.model_keys[0]
        seq_length = observation[first_key].shape[1]

        masks = None
        if avail_actions is not None:
            masks = {k: Tensor(avail_actions[k]) for k in agent_list}

        encoded = {k: self.communicator[k].obs_encode(observation[k]) for k in agent_list}
        step_features = {k: [] for k in agent_list}
        step_gate_log_probs = {k: [] for k in agent_list}
        last_pass = self.comm_passes - 1
        for t in range(seq_length):
            step = slice(t, t + 1)  # keeps the time dimension, same as [t:t + 1]
            alive_t = {k: alive_ally[k][:, step, :] for k in self.agent_keys}
            obs_t = {k: encoded[k][:, step, :] for k in agent_list}
            messages = {k: rnn_hidden[k][0].transpose(0, 1) for k in self.model_keys}
            for comm_time in range(self.comm_passes):
                # calculate gate_control
                gate_logits = {k: self.gate[k](messages[k]) for k in agent_list}
                gate_dist = {k: Categorical(logits=gate_logits[k]) for k in agent_list}
                gate_control = {k: gate_dist[k].sample() for k in agent_list}
                gate_log_prob = {k: gate_dist[k].log_prob(gate_control[k]) for k in agent_list}
                # Unlike IC3Net, the receiving agent's key is passed as a fifth argument.
                comm_out = {k: self.communicator[k](obs_t[k], messages, alive_t, gate_control, k)
                            for k in self.model_keys}
                messages = {k: obs_t[k] + comm_out[k] for k in self.model_keys}
                if comm_time == last_pass:
                    for k in agent_list:
                        step_gate_log_probs[k].append(gate_log_prob[k])
            # NOTE(review): assumed to run once after all passes rather than inside the
            # communication loop -- confirm against upstream.
            obs_t = messages

            for key in agent_list:
                encoder = self.actor_representation[key]
                if self.use_rnn:
                    rep = encoder(obs_t[key], *rnn_hidden[key])
                    hidden_out[key] = (rep['rnn_hidden'], rep['rnn_cell'])
                else:
                    # NOTE(review): this branch feeds the whole encoded sequence, not the
                    # per-step slice -- kept as is; confirm whether that is intended.
                    rep = encoder(encoded[key])
                    hidden_out[key] = [None, None]

                if self.use_parameter_sharing:
                    ids_t = agent_ids[:, step, :]
                    feature = torch.concat([rep['state'], ids_t], dim=-1)
                else:
                    feature = rep['state']
                step_features[key].append(feature)
            # Carry the fresh hidden states into the next time step.
            rnn_hidden = deepcopy(hidden_out)

        for key in agent_list:
            sequence = torch.cat(step_features[key], dim=1)
            dists[key] = self.actor[key](sequence, None if masks is None else masks[key])

        # NOTE(review): iterates ``self.model_keys`` although the lists are keyed by
        # ``agent_list`` -- kept as is; confirm.
        gate_log_probs = {k: torch.cat(step_gate_log_probs[k], dim=1) for k in self.model_keys}
        return hidden_out, dists, gate_log_probs
[docs] class COMA_Policy(Module): """ COMA_Policy: Counterfactual Multi-Agent Actor-Critic Policy with categorical distributions. Args: action_space (Optional[Dict[str, Discrete]]): The discrete action space. n_agents (int): The number of agents. representation_actor (ModuleDict): A dict of representation modules for each agent's actor. representation_critic (ModuleDict): A dict of representation modules for each agent's critic. mixer (Module): The mixer module that mix together the individual values to the total value. actor_hidden_size (Sequence[int]): A list of hidden layer sizes for actor network. critic_hidden_size (Sequence[int]): A list of hidden layer sizes for critic network. normalize (Optional[ModuleType]): The layer normalization over a minibatch of inputs. initialize (Optional[Callable[..., Tensor]]): The parameters initializer. activation (Optional[ModuleType]): The activation function for each layer. device (Optional[Union[str, int, torch.device]]): The calculating device. use_distributed_training (bool): Whether to use multi-GPU for distributed training. **kwargs: The other args. 
""" def __init__(self, action_space: Optional[Dict[str, Discrete]], n_agents: int, representation_actor: Module, representation_critic: Module, actor_hidden_size: Sequence[int] = None, critic_hidden_size: Sequence[int] = None, normalize: Optional[ModuleType] = None, initialize: Optional[Callable[..., Tensor]] = None, activation: Optional[ModuleType] = None, device: Optional[Union[str, int, torch.device]] = None, use_distributed_training: bool = False, **kwargs): super(COMA_Policy, self).__init__() self.device = device self.action_space = action_space self.n_agents = n_agents self.use_parameter_sharing = kwargs['use_parameter_sharing'] self.model_keys = kwargs['model_keys'] self.lstm = True if kwargs["rnn"] == "LSTM" else False self.use_rnn = True if kwargs["use_rnn"] else False self.actor_representation = representation_actor self.critic_representation = representation_critic self.target_critic_representation = deepcopy(self.critic_representation) # create actor self.n_actions = {k: space.n for k, space in self.action_space.items()} self.actor = ModuleDict() for key in self.model_keys: dim_actor_input = self.actor_representation[key].output_shapes['state'][0] if self.use_parameter_sharing: dim_actor_input += self.n_agents self.actor[key] = ActorNet(dim_actor_input, self.n_actions[key], actor_hidden_size, normalize, initialize, activation, None, device) dim_input_critic = kwargs['dim_global_state'] dim_input_critic += self.critic_representation[self.model_keys[0]].output_shapes['state'][0] dim_input_critic += sum(self.n_actions.values()) dim_input_critic += self.n_agents self.n_actions_max = max(self.n_actions.values()) self.critic = BasicQhead(dim_input_critic, self.n_actions_max, critic_hidden_size, normalize, initialize, activation, device) self.target_critic = deepcopy(self.critic) # Prepare DDP module. 
self.distributed_training = use_distributed_training if self.distributed_training: self.rank = int(os.environ['RANK']) for key in self.model_keys: if self.actor_representation[key]._get_name() != "Basic_Identical": self.actor_representation[key] = DistributedDataParallel(self.actor_representation[key], device_ids=[self.rank]) if self.critic_representation[key]._get_name() != "Basic_Identical": self.critic_representation[key] = DistributedDataParallel(self.critic_representation[key], device_ids=[self.rank]) self.actor[key] = DistributedDataParallel(module=self.actor[key], device_ids=[self.rank]) self.critic = DistributedDataParallel(module=self.critic, device_ids=[self.rank]) @property def parameters_actor(self): return list(self.actor_representation.parameters()) + list(self.actor.parameters()) @property def parameters_critic(self): return list(self.critic_representation.parameters()) + list(self.critic.parameters())
[docs] def forward(self, observation: Dict[str, Tensor], agent_ids: Optional[Tensor] = None, avail_actions: Dict[str, Tensor] = None, agent_key: str = None, rnn_hidden: Optional[Dict[str, List[Tensor]]] = None, epsilon=0.0, test_mode=False): """ Returns actions of the policy. Parameters: observation (Dict[str, Tensor]): The input observations for the policies. agent_ids (Tensor): The agents' ids (for parameter sharing). avail_actions (Dict[str, Tensor]): Actions mask values, default is None. agent_key (str): Calculate actions for specified agent. rnn_hidden (Optional[Dict[str, List[Tensor]]]): The RNN hidden states of actor representation. epsilon: The epsilon. Returns: rnn_hidden_new (Optional[Dict[str, List[Tensor]]]): The new RNN hidden states of actor representation. act_probs (dict): The probabilities of the actions. """ rnn_hidden_new, pi_logits, act_probs = {}, {}, {} agent_list = self.model_keys if agent_key is None else [agent_key] if avail_actions is not None: avail_actions = {key: Tensor(avail_actions[key]) for key in agent_list} for key in agent_list: if self.use_rnn: outputs = self.actor_representation[key](observation[key], *rnn_hidden[key]) rnn_hidden_new[key] = (outputs['rnn_hidden'], outputs['rnn_cell']) else: outputs = self.actor_representation[key](observation[key]) rnn_hidden_new[key] = [None, None] if self.use_parameter_sharing: actor_input = torch.concat([outputs['state'], agent_ids], dim=-1) else: actor_input = outputs['state'] avail_actions_input = None if avail_actions is None else avail_actions[key] pi_logits[key] = self.actor[key](actor_input, avail_actions_input) act_probs[key] = nn.functional.softmax(pi_logits[key], dim=-1) if not test_mode: act_probs[key] = (1 - epsilon) * act_probs[key] + epsilon * 1 / self.n_actions[key] return rnn_hidden_new, act_probs
def get_values(self, state: Tensor, observation: Dict[str, Tensor], actions: Dict[str, Tensor],
               agent_ids: Tensor = None, rnn_hidden: Optional[Dict[str, List[Tensor]]] = None, target=False):
    """
    Get evaluated critic values.

    The critic input is the concatenation (along the last axis) of the global state
    repeated per agent, the critic representation of each agent's observation, the
    joint actions with each agent's own action slot zeroed out, and the agent ids.

    Parameters:
        state: Tensor: The global state.
        observation (Dict[str, Tensor]): The input observations for the policies.
        actions (Dict[str, Tensor]): The input actions.
            NOTE(review): presumably one-hot encoded, since they are multiplied by a
            mask of width n_actions per agent - confirm against the caller.
        agent_ids (Tensor): The agents' ids. They are reshaped and concatenated in both
            the parameter-sharing and the non-sharing branch, so the default of None
            is not usable here.
        rnn_hidden (Optional[Dict[str, List[Tensor]]]): The RNN hidden states of critic representation.
        target: If to use target critic network to calculate the critic values.

    Returns:
        rnn_hidden_new (Optional[Dict[str, List[Tensor]]]): The new RNN hidden states of critic representation.
        values (dict): The evaluated critic values.
    """
    # NOTE(review): critic_input is never read after this assignment.
    rnn_hidden_new, critic_input = {}, {}
    batch_size = state.shape[0]
    seq_len = state.shape[1] if self.use_rnn else 1
    critic_inputs = []
    # Insert an agent axis into the global state and repeat it once per agent.
    if self.use_rnn:
        critic_inputs.append(state.unsqueeze(-2).repeat(1, 1, self.n_agents, 1))  # batch * T * N * dim_S
    else:
        critic_inputs.append(state.unsqueeze(-2).repeat(1, self.n_agents, 1))  # batch * N * dim_S
    # Encode every agent's observation with the (target) critic representation.
    obs_rep = {}
    for key in self.model_keys:
        if self.use_rnn:
            if target:
                outputs = self.target_critic_representation[key](observation[key], *rnn_hidden[key])
            else:
                outputs = self.critic_representation[key](observation[key], *rnn_hidden[key])
            rnn_hidden_new[key] = (outputs['rnn_hidden'], outputs['rnn_cell'])
        else:
            if target:
                outputs = self.target_critic_representation[key](observation[key])
            else:
                outputs = self.critic_representation[key](observation[key])
            rnn_hidden_new[key] = [None, None]
        obs_rep[key] = outputs['state']
    # (N, N, 1) mask that is 0 on the diagonal and 1 elsewhere: row i hides agent i's own action.
    agent_mask = (1 - torch.eye(self.n_agents, dtype=torch.float32, device=self.device)).unsqueeze(-1)
    if self.use_parameter_sharing:
        # All agents share one model, so only the first model key is used.
        key = self.model_keys[0]
        # Expand each mask entry over that agent's action slots -> (1, N, N * n_actions).
        agent_mask = agent_mask.repeat(1, 1, self.n_actions[key]).reshape(self.n_agents, -1).unsqueeze(0)
        if self.use_rnn:
            actions_input = actions[key].reshape(batch_size, seq_len, 1, -1).repeat(1, 1, self.n_agents, 1)
            # The reshape treats the leading axis as (batch, agent)-flattened - TODO confirm with caller.
            critic_inputs.append(obs_rep[key].reshape(batch_size, self.n_agents, seq_len, -1).transpose(1, 2))
            critic_inputs.append(actions_input * agent_mask.unsqueeze(0))
            critic_inputs.append(agent_ids.reshape(batch_size, self.n_agents, seq_len, -1).transpose(1, 2))
        else:
            actions_input = actions[key].reshape(batch_size, 1, -1).repeat(1, self.n_agents, 1)
            critic_inputs.append(obs_rep[key].reshape(batch_size, self.n_agents, -1))
            critic_inputs.append(actions_input * agent_mask)
            critic_inputs.append(agent_ids.reshape(batch_size, self.n_agents, -1))
        critic_inputs = torch.cat(critic_inputs, dim=-1)
    else:
        # Per-agent action widths may differ, so build the mask one agent block at a time.
        agent_mask = torch.cat([agent_mask[i].repeat(1, self.n_actions[k])
                                for i, k in enumerate(self.model_keys)], dim=-1).unsqueeze(0)
        if self.use_rnn:
            agent_mask = agent_mask.unsqueeze(1)
            actions_input = torch.cat(itemgetter(*self.model_keys)(actions),
                                      dim=-1).unsqueeze(-2).repeat(1, 1, self.n_agents, 1)  # batch * T * N * A
            agent_ids = agent_ids.reshape(batch_size, self.n_agents, seq_len, -1).transpose(1, 2)
        else:
            actions_input = torch.cat(itemgetter(*self.model_keys)(actions),
                                      dim=-1).unsqueeze(1).repeat(1, self.n_agents, 1)  # batch_size * N * A
            agent_ids = agent_ids.reshape(batch_size, self.n_agents, -1)  # batch_size * N * N
        # Stack the per-agent representations along a new agent axis.
        critic_inputs.append(torch.stack(itemgetter(*self.model_keys)(obs_rep), dim=-2))
        critic_inputs.append(actions_input * agent_mask)
        critic_inputs.append(agent_ids)
        critic_inputs = torch.cat(critic_inputs, dim=-1)
    values = self.target_critic(critic_inputs) if target else self.critic(critic_inputs)
    return rnn_hidden_new, values
def copy_target(self):
    """Hard-copy the critic representation and critic parameters into their target networks."""
    module_pairs = (
        (self.critic_representation, self.target_critic_representation),
        (self.critic, self.target_critic),
    )
    for online, frozen in module_pairs:
        for src, dst in zip(online.parameters(), frozen.parameters()):
            dst.data.copy_(src)
class MeanFieldActorCriticPolicy(Module):
    """Mean-field actor-critic policy.

    This policy maintains separate actor and critic networks for each agent type
    (model key), embeds the mean action of neighboring agents, and produces
    Boltzmann policies.

    Args:
        action_space (Dict[str, Discrete]): A mapping from model keys to discrete action spaces.
        n_agents (int): Total number of agents in the environment.
        representation_actor (ModuleDict): Actor state encoder modules for each model key.
        representation_critic (ModuleDict): Critic state encoder modules for each model key.
        actor_hidden_size (Sequence[int], optional): Hidden layer sizes for actor networks.
        critic_hidden_size (Sequence[int], optional): Hidden layer sizes for critic networks.
        normalize (ModuleType, optional): Normalization layer to apply after each hidden layer.
        initialize (Callable[..., Tensor], optional): Weight initialization function.
        activation (ModuleType, optional): Activation function class for hidden layers.
        device (str|int|torch.device, optional): Device identifier for module placement.
        use_distributed_training (bool): If True, wrap components in DistributedDataParallel.
        **kwargs: Additional keyword arguments:
            use_parameter_sharing (bool): Whether to share parameters across agent types.
            model_keys (List[str]): Keys identifying different agent types.
            rnn (str): RNN type, e.g., "LSTM" or "GRU".
            use_rnn (bool): Flag indicating whether to include RNN layers.
            action_embedding_hidden_size (Sequence[int]): Hidden sizes for action mean embedding.
            temperature (float): Temperature parameter for Boltzmann policy.
    """

    def __init__(self,
                 action_space: Dict[str, Discrete],
                 n_agents: int,
                 representation_actor: ModuleDict,
                 representation_critic: ModuleDict,
                 actor_hidden_size: Sequence[int] = None,
                 critic_hidden_size: Sequence[int] = None,
                 normalize: Optional[ModuleType] = None,
                 initialize: Optional[Callable[..., Tensor]] = None,
                 activation: Optional[ModuleType] = None,
                 device: Optional[Union[str, int, torch.device]] = None,
                 use_distributed_training: bool = False,
                 **kwargs):
        super(MeanFieldActorCriticPolicy, self).__init__()
        self.device = device
        self.action_space = action_space
        self.n_agents = n_agents
        self.n_actions_list = [a_space.n for a_space in self.action_space.values()]
        # One-hot encodings and mean actions are padded to the largest action count.
        self.n_actions_max = max(self.n_actions_list)
        self.use_parameter_sharing = kwargs['use_parameter_sharing']
        self.model_keys = kwargs['model_keys']
        self.lstm = True if kwargs["rnn"] == "LSTM" else False
        self.use_rnn = True if kwargs["use_rnn"] else False

        self.actor_representation = representation_actor
        self.critic_representation = representation_critic

        self.dim_input_critic, self.n_actions = {}, {}
        self.action_mean_embedding = ModuleDict()
        self.actor, self.critic = ModuleDict(), ModuleDict()
        # With parameter sharing the agent ids are appended to the mean action
        # before it is embedded (see get_values).
        dim_action_embedding = self.n_actions_max + self.n_agents if self.use_parameter_sharing else self.n_actions_max
        for key in self.model_keys:
            self.n_actions[key] = self.action_space[key].n
            dim_actor_in, dim_actor_out, dim_critic_in, dim_critic_out = self._get_actor_critic_input(
                self.n_actions[key],
                self.actor_representation[key].output_shapes['state'][0],
                self.critic_representation[key].output_shapes['state'][0],
                n_agents)
            # The critic additionally consumes the embedded mean action.
            dim_critic_in += kwargs['action_embedding_hidden_size'][-1]

            self.action_mean_embedding[key] = Basic_MLP((dim_action_embedding,),
                                                        kwargs['action_embedding_hidden_size'],
                                                        normalize, initialize, activation, device)
            self.actor[key] = ActorNet(dim_actor_in, dim_actor_out, actor_hidden_size,
                                       normalize, initialize, activation, None, device)
            self.critic[key] = CriticNet(dim_critic_in, critic_hidden_size,
                                         normalize, initialize, activation, device)
        self.softmax = torch.nn.Softmax(dim=-1)
        self.temperature = kwargs['temperature']

        # Prepare DDP module.
        self.distributed_training = use_distributed_training
        if self.distributed_training:
            self.rank = int(os.environ["RANK"])
            for key in self.model_keys:
                # "Basic_Identical" representations are left unwrapped.
                if self.actor_representation[key]._get_name() != "Basic_Identical":
                    self.actor_representation[key] = DistributedDataParallel(self.actor_representation[key],
                                                                             device_ids=[self.rank])
                if self.critic_representation[key]._get_name() != "Basic_Identical":
                    self.critic_representation[key] = DistributedDataParallel(self.critic_representation[key],
                                                                              device_ids=[self.rank])
                self.action_mean_embedding[key] = DistributedDataParallel(module=self.action_mean_embedding[key],
                                                                          device_ids=[self.rank])
                self.actor[key] = DistributedDataParallel(module=self.actor[key], device_ids=[self.rank])
                self.critic[key] = DistributedDataParallel(module=self.critic[key], device_ids=[self.rank])

    @property
    def parameters_model(self):
        """List of all trainable parameters: representations, actors, critics and action embeddings."""
        parameters = list(self.actor_representation.parameters()) + list(self.actor.parameters()) + list(
            self.critic_representation.parameters()) + list(self.critic.parameters()) + list(
            self.action_mean_embedding.parameters())
        return parameters

    def forward(self, observation: Dict[str, Tensor], agent_ids: Optional[Tensor] = None,
                avail_actions: Dict[str, Tensor] = None, agent_key: str = None,
                rnn_hidden: Optional[Dict[str, List[Tensor]]] = None):
        """
        Returns the stochastic policy distributions of the actors.

        Args:
            observation (Dict[str, Tensor]): The input observations for the policies.
            agent_ids (Tensor): The agents' ids (for parameter sharing).
            avail_actions (Dict[str, Tensor]): Actions mask values, default is None.
            agent_key (str): Calculate actions for specified agent.
            rnn_hidden (Optional[Dict[str, List[Tensor]]]): The RNN hidden states of actor representation.

        Returns:
            rnn_hidden_new (Optional[Dict[str, List[Tensor]]]): The new RNN hidden states of actor representation.
            pi_dists (dict): The stochastic policy distributions.
        """
        rnn_hidden_new, pi_logits, pi_dists = {}, {}, {}
        agent_list = self.model_keys if agent_key is None else [agent_key]

        if avail_actions is not None:
            avail_actions = {key: Tensor(avail_actions[key]) for key in agent_list}

        for key in agent_list:
            if self.use_rnn:
                outputs = self.actor_representation[key](observation[key], *rnn_hidden[key])
                rnn_hidden_new[key] = (outputs['rnn_hidden'], outputs['rnn_cell'])
            else:
                outputs = self.actor_representation[key](observation[key])
                rnn_hidden_new[key] = [None, None]

            if self.use_parameter_sharing:
                actor_input = torch.concat([outputs['state'], agent_ids], dim=-1)
            else:
                actor_input = outputs['state']

            avail_actions_input = None if avail_actions is None else avail_actions[key]
            pi_logits[key] = self.actor[key](actor_input, avail_actions_input)
            actions_prob = self.get_boltzmann_policy(pi_logits[key])
            pi_dists[key] = Categorical(probs=actions_prob)
        return rnn_hidden_new, pi_dists

    def get_boltzmann_policy(self, q):
        """Convert Q-values to a Boltzmann (softmax) policy distribution.

        Args:
            q (Tensor): Q-value tensor of shape [..., n_actions].

        Returns:
            Tensor: Probability distribution over actions with same shape as `q`.
        """
        actions_prob = self.softmax(q / self.temperature)
        return actions_prob

    def get_mean_actions(self, actions: Dict[str, Tensor], agent_mask_tensor: Tensor, batch_size: int):
        """Compute mean one-hot action vectors of each agent's neighbors.

        For each batch and agent, exclude the agent's own action and average the
        one-hot action encodings of its alive neighbors. Dead agents, and alive
        agents without any alive neighbor, get an all-zero mean action.

        Args:
            actions (Dict[str, Tensor]): Mapping from model keys to chosen action indices.
                They are reshaped to [-1, n_agents] and cast to int64 before one-hot encoding.
            agent_mask_tensor (Tensor): Binary mask of shape [batch_size, n_agents]
                indicating alive (1) or dead (0) agents.
            batch_size (int): Number of samples in the batch. Kept for backward
                compatibility; the batch dimension is taken from the inputs.

        Returns:
            Tensor: Mean one-hot action tensor of shape [batch_size, n_agents, n_actions_max].
        """
        if self.use_parameter_sharing:
            actions_tensor = actions[self.model_keys[0]].reshape([-1, self.n_agents])
        else:
            actions_tensor = torch.stack(itemgetter(*self.model_keys)(actions), dim=-1).reshape([-1, self.n_agents])
        # one_hot only accepts integer index tensors, so cast defensively.
        actions_onehot = one_hot(actions_tensor.long(), num_classes=self.n_actions_max)

        alive_mask = agent_mask_tensor.unsqueeze(-1)  # [batch, n_agents, 1]
        if not alive_mask.is_floating_point():
            alive_mask = alive_mask.float()

        # Alive neighbors of agent i = (all alive agents) - (agent i itself, if alive).
        # The previous implementation broadcast the mask along the wrong axis, which
        # produced mask_i * (n_agents - 1): a wrong denominator when some neighbors are
        # dead and a 0 / 0 = NaN for dead agents.
        alive_neighbors = alive_mask.sum(dim=-2, keepdim=True) - alive_mask
        # The numerator is zero whenever the count is zero, so clamping yields 0, not NaN.
        alive_neighbors = alive_neighbors.clamp(min=1)

        # Drop the actions of dead agents, then subtract each agent's own action
        # from the total to obtain the sum over its neighbors.
        masked_onehot = actions_onehot * alive_mask
        neighbors_sum = masked_onehot.sum(dim=-2, keepdim=True) - masked_onehot
        actions_mean_masked = neighbors_sum * alive_mask / alive_neighbors
        return actions_mean_masked

    def _get_actor_critic_input(self, dim_action, dim_actor_rep, dim_critic_rep, n_agents):
        """
        Returns the input/output dimensions of the actor and critic networks.

        Args:
            dim_action: The dimension of actions.
            dim_actor_rep: The dimension of the output of actor representation.
            dim_critic_rep: The dimension of the output of critic representation.
            n_agents: The number of agents.

        Returns:
            dim_actor_in: The dimension of input of the actor networks.
            dim_actor_out: The dimension of output of the actor networks.
            dim_critic_in: The dimension of the input of critic networks.
            dim_critic_out: The dimension of the output of critic networks.
        """
        dim_actor_in, dim_actor_out = dim_actor_rep, dim_action
        dim_critic_in, dim_critic_out = dim_critic_rep, dim_action
        if self.use_parameter_sharing:
            # Agent ids are appended to the inputs of shared networks.
            dim_actor_in += n_agents
            dim_critic_in += n_agents
        return dim_actor_in, dim_actor_out, dim_critic_in, dim_critic_out

    def get_values(self, observation: Dict[str, Tensor], actions_mean: Dict[str, Tensor] = None,
                   agent_ids: Tensor = None, agent_key: str = None,
                   rnn_hidden: Optional[Dict[str, List[Tensor]]] = None):
        """
        Get critic values via critic networks.

        Args:
            observation (Dict[str, Tensor]): The input observations for the policies.
            actions_mean (Dict[str, Tensor]): The mean actions of each agent's neighbors.
                It is converted per key unconditionally, so it must be provided
                despite the None default.
            agent_ids (Tensor): The agents' ids (for parameter sharing).
            agent_key (str): Calculate values for specified agent.
            rnn_hidden (Optional[Dict[str, List[Tensor]]]): The RNN hidden states of critic representation.

        Returns:
            rnn_hidden_new (Optional[Dict[str, List[Tensor]]]): The new RNN hidden states of critic representation.
            values (dict): The evaluated critic values.
        """
        rnn_hidden_new, values = {}, {}
        agent_list = self.model_keys if agent_key is None else [agent_key]
        actions_mean = {key: Tensor(actions_mean[key]).to(self.device) for key in agent_list}

        for key in agent_list:
            if self.use_rnn:
                outputs = self.critic_representation[key](observation[key], *rnn_hidden[key])
                rnn_hidden_new[key] = (outputs['rnn_hidden'], outputs['rnn_cell'])
            else:
                outputs = self.critic_representation[key](observation[key])
                rnn_hidden_new[key] = [None, None]

            if self.use_parameter_sharing:
                # Shared networks see the agent ids both in the embedding and in the critic input.
                action_embedding_input = torch.cat([actions_mean[key], agent_ids], dim=-1)
                act_embedding = self.action_mean_embedding[key](action_embedding_input)
                critic_input = torch.concat([outputs['state'], act_embedding['state'], agent_ids], dim=-1)
            else:
                act_embedding = self.action_mean_embedding[key](actions_mean[key])
                critic_input = torch.cat([outputs['state'], act_embedding['state']], dim=-1)

            values[key] = self.critic[key](critic_input)
        return rnn_hidden_new, values
class Basic_ISAC_Policy(Module):
    """
    Basic_ISAC_Policy: The basic policy for independent soft actor-critic with discrete actions.

    Each model key owns one actor, two critics (each with its own representation)
    and a target copy of both critics.

    Args:
        action_space (Optional[Dict[str, Discrete]]): The discrete action space of each model key.
        n_agents (int): The number of agents.
        actor_representation (ModuleDict): A dict of representation modules for each agent's actor.
        critic_representation (ModuleDict): A dict of representation modules for each agent's critic.
        actor_hidden_size (Sequence[int]): A list of hidden layer sizes for actor network.
        critic_hidden_size (Sequence[int]): A list of hidden layer sizes for critic network.
        normalize (Optional[ModuleType]): The layer normalization over a minibatch of inputs.
        initialize (Optional[Callable[..., Tensor]]): The parameters initializer.
        activation (Optional[ModuleType]): The activation function for each layer.
        device (Optional[Union[str, int, torch.device]]): The calculating device.
        use_distributed_training (bool): Whether to use multi-GPU for distributed training.
        **kwargs: Other arguments; 'use_parameter_sharing', 'model_keys', 'rnn' and 'use_rnn' are read.
    """

    def __init__(self,
                 action_space: Optional[Dict[str, Discrete]],
                 n_agents: int,
                 actor_representation: ModuleDict,
                 critic_representation: ModuleDict,
                 actor_hidden_size: Sequence[int],
                 critic_hidden_size: Sequence[int],
                 normalize: Optional[ModuleType] = None,
                 initialize: Optional[Callable[..., Tensor]] = None,
                 activation: Optional[ModuleType] = None,
                 device: Optional[Union[str, int, torch.device]] = None,
                 use_distributed_training: bool = False,
                 **kwargs):
        super(Basic_ISAC_Policy, self).__init__()
        self.device = device
        self.action_space = action_space
        self.n_agents = n_agents
        self.use_parameter_sharing = kwargs['use_parameter_sharing']
        self.model_keys = kwargs['model_keys']
        self.lstm = bool(kwargs["rnn"] == "LSTM")
        self.use_rnn = bool(kwargs["use_rnn"])

        # Representations: one for the actor, two independent ones for the critics,
        # plus a target copy of each critic representation.
        self.actor_representation = actor_representation
        self.critic_1_representation = critic_representation
        self.critic_2_representation = deepcopy(critic_representation)
        self.target_critic_1_representation = deepcopy(self.critic_1_representation)
        self.target_critic_2_representation = deepcopy(self.critic_2_representation)

        self.actor = ModuleDict()
        self.critic_1 = ModuleDict()
        self.critic_2 = ModuleDict()
        for key in self.model_keys:
            n_actions = self.action_space[key].n
            dim_actor_in, dim_actor_out, dim_critic_in = self._get_actor_critic_input(
                self.actor_representation[key].output_shapes['state'][0],
                n_actions,
                self.critic_1_representation[key].output_shapes['state'][0],
                n_agents)
            self.actor[key] = Actor_SAC(dim_actor_in, dim_actor_out, actor_hidden_size,
                                        normalize, initialize, activation, device)
            self.critic_1[key] = BasicQhead(dim_critic_in, n_actions, critic_hidden_size,
                                            normalize, initialize, activation, device)
            self.critic_2[key] = BasicQhead(dim_critic_in, n_actions, critic_hidden_size,
                                            normalize, initialize, activation, device)
        self.target_critic_1 = deepcopy(self.critic_1)
        self.target_critic_2 = deepcopy(self.critic_2)

        # Prepare DDP module.
        self.distributed_training = use_distributed_training
        if self.distributed_training:
            self.rank = int(os.environ["RANK"])
            for key in self.model_keys:
                # "Basic_Identical" representations are left unwrapped.
                if self.actor_representation[key]._get_name() != "Basic_Identical":
                    self.actor_representation[key] = DistributedDataParallel(
                        self.actor_representation[key], device_ids=[self.rank])
                if self.critic_1_representation[key]._get_name() != "Basic_Identical":
                    self.critic_1_representation[key] = DistributedDataParallel(
                        self.critic_1_representation[key], device_ids=[self.rank])
                if self.critic_2_representation[key]._get_name() != "Basic_Identical":
                    self.critic_2_representation[key] = DistributedDataParallel(
                        self.critic_2_representation[key], device_ids=[self.rank])
                self.actor[key] = DistributedDataParallel(module=self.actor[key], device_ids=[self.rank])
                self.critic_1[key] = DistributedDataParallel(module=self.critic_1[key], device_ids=[self.rank])
                self.critic_2[key] = DistributedDataParallel(module=self.critic_2[key], device_ids=[self.rank])

    @property
    def parameters_actor(self):
        """Per-key list of actor representation parameters followed by actor parameters."""
        return {
            key: list(self.actor_representation[key].parameters()) + list(self.actor[key].parameters())
            for key in self.model_keys
        }

    @property
    def parameters_critic(self):
        """Per-key list of the parameters of both critics and their representations."""
        return {
            key: (list(self.critic_1_representation[key].parameters())
                  + list(self.critic_1[key].parameters())
                  + list(self.critic_2_representation[key].parameters())
                  + list(self.critic_2[key].parameters()))
            for key in self.model_keys
        }

    @staticmethod
    def _log_probs(probs):
        """Return log(probs), adding 1e-8 to entries that are <= 1e-20 to avoid log(0)."""
        offset = (probs <= 1e-20).float() * 1e-8
        return torch.log(probs + offset)

    @staticmethod
    def _polyak_update(online, target, tau):
        """In-place update target <- (1 - tau) * target + tau * online, parameter by parameter."""
        for src, dst in zip(online.parameters(), target.parameters()):
            dst.data.mul_(1 - tau)
            dst.data.add_(tau * src.data)

    def _get_actor_critic_input(self, dim_actor_rep, dim_action, dim_critic_rep, n_agents):
        """
        Returns the input dimensions of actor network and critic networks.

        Parameters:
            dim_actor_rep: The dimension of the output of actor representation.
            dim_action: The number of discrete actions.
            dim_critic_rep: The dimension of the output of critic representation.
            n_agents: The number of agents.

        Returns:
            dim_actor_in: The dimension of input of the actor networks.
            dim_actor_out: The dimension of output of the actor networks.
            dim_critic_in: The dimension of the input of critic networks.
        """
        # Shared networks receive the agent ids as extra input features.
        extra = n_agents if self.use_parameter_sharing else 0
        return dim_actor_rep + extra, dim_action, dim_critic_rep + extra

    def forward(self, observation: Dict[str, Tensor], agent_ids: Tensor = None,
                avail_actions: Dict[str, Tensor] = None, agent_key: str = None,
                rnn_hidden: Optional[Dict[str, List[Tensor]]] = None):
        """
        Sample actions from the actors.

        Parameters:
            observation (Dict[Tensor]): The input observations for the policies.
            agent_ids (Tensor): The agents' ids (for parameter sharing).
            avail_actions (Dict[str, Tensor]): Actions mask values, default is None.
            agent_key (str): Calculate actions for specified agent.
            rnn_hidden (Optional[Dict[str, List[Tensor]]]): The hidden variables of the RNN.

        Returns:
            rnn_hidden_new (Optional[Dict[str, List[Tensor]]]): A deep copy of rnn_hidden with
                the evaluated keys replaced by the new states (when RNNs are enabled).
            actions (Dict[Tensor]): The actions sampled by the policies.
            None: Placeholder third element; always None.
        """
        hidden_out = deepcopy(rnn_hidden)
        keys = [agent_key] if agent_key is not None else self.model_keys
        masks = None
        if avail_actions is not None:
            masks = {k: Tensor(avail_actions[k]) for k in keys}

        sampled = {}
        for k in keys:
            if self.use_rnn:
                rep = self.actor_representation[k](observation[k], *rnn_hidden[k])
                hidden_out[k] = (rep['rnn_hidden'], rep['rnn_cell'])
            else:
                rep = self.actor_representation[k](observation[k])

            features = rep['state']
            if self.use_parameter_sharing:
                features = torch.concat([features, agent_ids], dim=-1)

            mask = masks[k] if masks is not None else None
            dist = self.actor[k](features, mask)
            sampled[k] = dist.stochastic_sample()
        return hidden_out, sampled, None

    def Qpolicy(self, observation: Dict[str, Tensor],
                agent_ids: Tensor = None, avail_actions: Dict[str, Tensor] = None, agent_key: str = None,
                rnn_hidden_actor: Optional[Dict[str, List[Tensor]]] = None,
                rnn_hidden_critic_1: Optional[Dict[str, List[Tensor]]] = None,
                rnn_hidden_critic_2: Optional[Dict[str, List[Tensor]]] = None):
        """
        Returns Q^policy of current observations and actions pairs.

        Parameters:
            observation (Dict[Tensor]): The observations.
            agent_ids (Dict[Tensor]): The agents' ids (for parameter sharing).
            avail_actions (Dict[str, Tensor]): Actions mask values, default is None.
            agent_key (str): Calculate actions for specified agent.
            rnn_hidden_actor (Optional[Dict[str, List[Tensor]]]): The RNN hidden states for actor representation.
            rnn_hidden_critic_1 (Optional[Dict[str, List[Tensor]]]): The RNN hidden states for critic_1 representation.
            rnn_hidden_critic_2 (Optional[Dict[str, List[Tensor]]]): The RNN hidden states for critic_2 representation.

        Returns:
            rnn_hidden_actor_new: The updated rnn states for actor_representation.
            rnn_hidden_critic_new_1: The updated rnn states for critic_1_representation.
            rnn_hidden_critic_new_2: The updated rnn states for critic_2_representation.
            act_prob_dict: The probabilities of actions.
            log_action_prob_dict: The log of action probabilities.
            q_1: The evaluation of Q values with critic 1.
            q_2: The evaluation of Q values with critic 2.
        """
        hidden_pi_out = deepcopy(rnn_hidden_actor)
        hidden_q1_out = deepcopy(rnn_hidden_critic_1)
        hidden_q2_out = deepcopy(rnn_hidden_critic_2)
        keys = [agent_key] if agent_key is not None else self.model_keys
        masks = None
        if avail_actions is not None:
            masks = {k: Tensor(avail_actions[k]) for k in keys}

        probs_out, log_probs_out, q1_out, q2_out = {}, {}, {}, {}
        for k in keys:
            obs = observation[k]
            if self.use_rnn:
                rep_pi = self.actor_representation[k](obs, *rnn_hidden_actor[k])
                rep_q1 = self.critic_1_representation[k](obs, *rnn_hidden_critic_1[k])
                rep_q2 = self.critic_2_representation[k](obs, *rnn_hidden_critic_2[k])
                hidden_pi_out[k] = (rep_pi['rnn_hidden'], rep_pi['rnn_cell'])
                hidden_q1_out[k] = (rep_q1['rnn_hidden'], rep_q1['rnn_cell'])
                hidden_q2_out[k] = (rep_q2['rnn_hidden'], rep_q2['rnn_cell'])
            else:
                rep_pi = self.actor_representation[k](obs)
                rep_q1 = self.critic_1_representation[k](obs)
                rep_q2 = self.critic_2_representation[k](obs)

            feat_pi, feat_q1, feat_q2 = rep_pi['state'], rep_q1['state'], rep_q2['state']
            if self.use_parameter_sharing:
                feat_pi = torch.concat([feat_pi, agent_ids], dim=-1)
                feat_q1 = torch.concat([feat_q1, agent_ids], dim=-1)
                feat_q2 = torch.concat([feat_q2, agent_ids], dim=-1)

            mask = masks[k] if masks is not None else None
            dist = self.actor[k](feat_pi, mask)
            probs_out[k] = dist.probs
            log_probs_out[k] = self._log_probs(probs_out[k])
            q1_out[k] = self.critic_1[k](feat_q1)
            q2_out[k] = self.critic_2[k](feat_q2)
        return (hidden_pi_out, hidden_q1_out, hidden_q2_out,
                probs_out, log_probs_out, q1_out, q2_out)

    def Qtarget(self, next_observation: Dict[str, Tensor],
                agent_ids: Tensor = None, avail_actions: Dict[str, Tensor] = None, agent_key: str = None,
                rnn_hidden_actor: Optional[Dict[str, List[Tensor]]] = None,
                rnn_hidden_critic_1: Optional[Dict[str, List[Tensor]]] = None,
                rnn_hidden_critic_2: Optional[Dict[str, List[Tensor]]] = None):
        """
        Returns the Q^target of next observations and actions pairs.

        Parameters:
            next_observation (Dict[Tensor]): The observations of next step.
            agent_ids (Dict[Tensor]): The agents' ids (for parameter sharing).
            avail_actions (Dict[str, Tensor]): Actions mask values, default is None.
            agent_key (str): Calculate actions for specified agent.
            rnn_hidden_actor (Optional[Dict[str, List[Tensor]]]): The RNN hidden states for actor representation.
            rnn_hidden_critic_1 (Optional[Dict[str, List[Tensor]]]): The RNN hidden states for critic_1 representation.
            rnn_hidden_critic_2 (Optional[Dict[str, List[Tensor]]]): The RNN hidden states for critic_2 representation.

        Returns:
            rnn_hidden_actor_new: The updated rnn states for actor_representation.
            rnn_hidden_critic_new_1: The updated rnn states for the target critic_1 representation.
            rnn_hidden_critic_new_2: The updated rnn states for the target critic_2 representation.
            new_act_prob_dict: The action probabilities at the next observations.
            log_action_prob_dict: The log of those action probabilities.
            target_q: The element-wise minimum of the two target critics' outputs.
        """
        hidden_pi_out = deepcopy(rnn_hidden_actor)
        hidden_q1_out = deepcopy(rnn_hidden_critic_1)
        hidden_q2_out = deepcopy(rnn_hidden_critic_2)
        keys = [agent_key] if agent_key is not None else self.model_keys
        masks = None
        if avail_actions is not None:
            masks = {k: Tensor(avail_actions[k]) for k in keys}

        probs_out, log_probs_out, q_min_out = {}, {}, {}
        for k in keys:
            next_obs = next_observation[k]
            if self.use_rnn:
                rep_pi = self.actor_representation[k](next_obs, *rnn_hidden_actor[k])
                rep_q1 = self.target_critic_1_representation[k](next_obs, *rnn_hidden_critic_1[k])
                rep_q2 = self.target_critic_2_representation[k](next_obs, *rnn_hidden_critic_2[k])
                hidden_pi_out[k] = (rep_pi['rnn_hidden'], rep_pi['rnn_cell'])
                hidden_q1_out[k] = (rep_q1['rnn_hidden'], rep_q1['rnn_cell'])
                hidden_q2_out[k] = (rep_q2['rnn_hidden'], rep_q2['rnn_cell'])
            else:
                rep_pi = self.actor_representation[k](next_obs)
                rep_q1 = self.target_critic_1_representation[k](next_obs)
                rep_q2 = self.target_critic_2_representation[k](next_obs)

            feat_pi, feat_q1, feat_q2 = rep_pi['state'], rep_q1['state'], rep_q2['state']
            if self.use_parameter_sharing:
                feat_pi = torch.concat([feat_pi, agent_ids], dim=-1)
                feat_q1 = torch.concat([feat_q1, agent_ids], dim=-1)
                feat_q2 = torch.concat([feat_q2, agent_ids], dim=-1)

            mask = masks[k] if masks is not None else None
            dist = self.actor[k](feat_pi, mask)
            probs_out[k] = dist.probs
            log_probs_out[k] = self._log_probs(probs_out[k])
            q1_next = self.target_critic_1[k](feat_q1)
            q2_next = self.target_critic_2[k](feat_q2)
            q_min_out[k] = torch.min(q1_next, q2_next)
        return (hidden_pi_out, hidden_q1_out, hidden_q2_out,
                probs_out, log_probs_out, q_min_out)

    def Qaction(self, observation: Union[Tensor, dict], agent_ids: Tensor, agent_key: str = None,
                rnn_hidden_critic_1: Optional[Dict[str, List[Tensor]]] = None,
                rnn_hidden_critic_2: Optional[Dict[str, List[Tensor]]] = None):
        """
        Returns the evaluated Q-values of both critics for the current observations.

        Parameters:
            observation (Union[Tensor, dict]): The original observation.
            agent_ids (Tensor): The agents' ids (for parameter sharing).
            agent_key (str): Calculate actions for specified agent.
            rnn_hidden_critic_1 (Optional[Dict[str, List[Tensor]]]): The RNN hidden states for critic_1 representation.
            rnn_hidden_critic_2 (Optional[Dict[str, List[Tensor]]]): The RNN hidden states for critic_2 representation.

        Returns:
            rnn_hidden_critic_new_1: The updated rnn states for critic_1_representation.
            rnn_hidden_critic_new_2: The updated rnn states for critic_2_representation.
            q_1: The Q-value calculated by the first critic network.
            q_2: The Q-value calculated by the other critic network.
        """
        hidden_q1_out = deepcopy(rnn_hidden_critic_1)
        hidden_q2_out = deepcopy(rnn_hidden_critic_2)
        keys = [agent_key] if agent_key is not None else self.model_keys

        q1_out, q2_out = {}, {}
        for k in keys:
            obs = observation[k]
            if self.use_rnn:
                rep_q1 = self.critic_1_representation[k](obs, *rnn_hidden_critic_1[k])
                rep_q2 = self.critic_2_representation[k](obs, *rnn_hidden_critic_2[k])
                hidden_q1_out[k] = (rep_q1['rnn_hidden'], rep_q1['rnn_cell'])
                hidden_q2_out[k] = (rep_q2['rnn_hidden'], rep_q2['rnn_cell'])
            else:
                rep_q1 = self.critic_1_representation[k](obs)
                rep_q2 = self.critic_2_representation[k](obs)

            feat_q1, feat_q2 = rep_q1['state'], rep_q2['state']
            if self.use_parameter_sharing:
                feat_q1 = torch.concat([feat_q1, agent_ids], dim=-1)
                feat_q2 = torch.concat([feat_q2, agent_ids], dim=-1)

            q1_out[k] = self.critic_1[k](feat_q1)
            q2_out[k] = self.critic_2[k](feat_q2)
        return hidden_q1_out, hidden_q2_out, q1_out, q2_out

    def soft_update(self, tau=0.005):
        """Polyak-average the critic representations and critic heads into their target copies."""
        online_target_pairs = (
            (self.critic_1_representation, self.target_critic_1_representation),
            (self.critic_1, self.target_critic_1),
            (self.critic_2_representation, self.target_critic_2_representation),
            (self.critic_2, self.target_critic_2),
        )
        for online, target in online_target_pairs:
            self._polyak_update(online, target, tau)
[docs] class MASAC_Policy(Basic_ISAC_Policy): """ Basic_ISAC_Policy: The basic policy for independent soft actor-critic. Args: action_space (Box): The continuous action space. n_agents (int): The number of agents. actor_representation (ModuleDict): A dict of representation modules for each agent's actor. critic_representation (ModuleDict): A dict of representation modules for each agent's critic. actor_hidden_size (Sequence[int]): A list of hidden layer sizes for actor network. critic_hidden_size (Sequence[int]): A list of hidden layer sizes for critic network. normalize (Optional[ModuleType]): The layer normalization over a minibatch of inputs. initialize (Optional[Callable[..., Tensor]]): The parameters initializer. activation (Optional[ModuleType]): The activation function for each layer. activation_action (Optional[ModuleType]): The activation of final layer to bound the actions. device (Optional[Union[str, int, torch.device]]): The calculating device. use_distributed_training (bool): Whether to use multi-GPU for distributed training. **kwargs: Other arguments. """ def __init__(self, action_space: Optional[Dict[str, Discrete]], n_agents: int, actor_representation: ModuleDict, critic_representation: ModuleDict, actor_hidden_size: Sequence[int], critic_hidden_size: Sequence[int], normalize: Optional[ModuleType] = None, initialize: Optional[Callable[..., Tensor]] = None, activation: Optional[ModuleType] = None, activation_action: Optional[ModuleType] = None, device: Optional[Union[str, int, torch.device]] = None, use_distributed_training: bool = False, **kwargs): super(MASAC_Policy, self).__init__(action_space, n_agents, actor_representation, critic_representation, actor_hidden_size, critic_hidden_size, normalize, initialize, activation, device, use_distributed_training, **kwargs) def _get_actor_critic_input(self, dim_actor_rep, dim_action, dim_critic_rep, n_agents): """ Returns the input dimensions of actor network and critic networks. 
Parameters: dim_actor_rep: The dimension of the output of actor representation. dim_action: The dimension of actions (continuous), or the number of actions (discrete). dim_critic_rep: The dimension of the output of critic representation. n_agents: The number of agents. Returns: dim_actor_in: The dimension of input of the actor networks. dim_actor_out: The dimension of output of the actor networks. dim_critic_in: The dimension of the input of critic networks. dim_critic_out: The dimension of the output of critic networks. """ dim_actor_in, dim_actor_out = dim_actor_rep, dim_action dim_critic_in = dim_critic_rep if self.use_parameter_sharing: dim_actor_in += n_agents dim_critic_in += n_agents return dim_actor_in, dim_actor_out, dim_critic_in
def Qpolicy(self, joint_observation: Optional[Tensor] = None, joint_actions: Optional[Tensor] = None,
            agent_ids: Tensor = None, agent_key: str = None,
            rnn_hidden_critic_1: Optional[Dict[str, List[Tensor]]] = None,
            rnn_hidden_critic_2: Optional[Dict[str, List[Tensor]]] = None):
    """
    Returns Q^policy of current observations and actions pairs, evaluated by both critics.

    The joint observations and joint actions are concatenated on the last axis and fed to every
    selected agent's two critic representations; the resulting 'state' features are then passed
    to the matching critic heads.

    Parameters:
        joint_observation (Optional[Tensor]): The joint observations of the team. Axis 0 is read
            as the batch size and, when RNNs are used, axis 1 as the sequence length.
        joint_actions (Optional[Tensor]): The joint actions of the team.
        agent_ids (Tensor): The agents' ids; only used with parameter sharing, where they are
            concatenated to the tiled representation output on the last axis.
        agent_key (str): Evaluate only the specified agent; None evaluates all model keys.
        rnn_hidden_critic_1 (Optional[Dict[str, List[Tensor]]]): The RNN hidden states for critic_1 representation.
        rnn_hidden_critic_2 (Optional[Dict[str, List[Tensor]]]): The RNN hidden states for critic_2 representation.

    Returns:
        rnn_hidden_critic_new_1: The updated rnn states for critic_1_representation.
        rnn_hidden_critic_new_2: The updated rnn states for critic_2_representation.
        q_1: The evaluations of Q^policy with critic 1 (dict keyed by agent).
        q_2: The evaluations of Q^policy with critic 2 (dict keyed by agent).
    """
    # Work on copies so the caller's hidden-state containers are never mutated.
    hidden_out_1 = deepcopy(rnn_hidden_critic_1)
    hidden_out_2 = deepcopy(rnn_hidden_critic_2)

    keys = [agent_key] if agent_key is not None else self.model_keys
    n_batch = joint_observation.shape[0]
    n_steps = joint_observation.shape[1] if self.use_rnn else 1

    rep_input = torch.concat([joint_observation, joint_actions], dim=-1)

    # Representation pass: every critic-1 module first, then every critic-2 module.
    reps_1, reps_2 = {}, {}
    if self.use_rnn:
        for k in keys:
            reps_1[k] = self.critic_1_representation[k](rep_input, *rnn_hidden_critic_1[k])
        for k in keys:
            reps_2[k] = self.critic_2_representation[k](rep_input, *rnn_hidden_critic_2[k])
        hidden_out_1.update({k: (out['rnn_hidden'], out['rnn_cell']) for k, out in reps_1.items()})
        hidden_out_2.update({k: (out['rnn_hidden'], out['rnn_cell']) for k, out in reps_2.items()})
    else:
        for k in keys:
            reps_1[k] = self.critic_1_representation[k](rep_input)
        for k in keys:
            reps_2[k] = self.critic_2_representation[k](rep_input)

    share = self.use_parameter_sharing
    rows = n_batch * self.n_agents if share else n_batch
    flat_shape = (rows, n_steps, -1) if self.use_rnn else (rows, -1)

    def as_head_input(state):
        # Without sharing the features only need flattening to the head's expected layout.
        if not share:
            return state.reshape(*flat_shape)
        # With sharing, repeat the joint features once per agent and append the agent ids.
        if self.use_rnn:
            tiled = state.unsqueeze(1).expand(-1, self.n_agents, -1, -1)
        else:
            tiled = state.unsqueeze(1).expand(-1, self.n_agents, -1)
        return torch.concat([tiled.reshape(*flat_shape), agent_ids], dim=-1)

    q_1, q_2 = {}, {}
    for k in keys:
        head_in_1 = as_head_input(reps_1[k]['state'])
        head_in_2 = as_head_input(reps_2[k]['state'])
        q_1[k] = self.critic_1[k](head_in_1)
        q_2[k] = self.critic_2[k](head_in_2)

    return hidden_out_1, hidden_out_2, q_1, q_2
def Qtarget(self, joint_observation: Optional[Tensor] = None, joint_actions: Optional[Tensor] = None,
            agent_ids: Tensor = None, agent_key: str = None,
            rnn_hidden_critic_1: Optional[Dict[str, List[Tensor]]] = None,
            rnn_hidden_critic_2: Optional[Dict[str, List[Tensor]]] = None):
    """
    Returns the Q^target of next observations and actions pairs.

    Both target critics are evaluated on the concatenated joint observations and joint actions,
    and the element-wise minimum of the two estimates is returned for each selected agent.

    Parameters:
        joint_observation (Optional[Tensor]): The joint observations of the team. Axis 0 is read
            as the batch size and, when RNNs are used, axis 1 as the sequence length.
        joint_actions (Optional[Tensor]): The joint actions of the team.
        agent_ids (Tensor): The agents' ids; only used with parameter sharing, where they are
            concatenated to the tiled representation output on the last axis.
        agent_key (str): Evaluate only the specified agent; None evaluates all model keys.
        rnn_hidden_critic_1 (Optional[Dict[str, List[Tensor]]]): The RNN hidden states for critic_1 representation.
        rnn_hidden_critic_2 (Optional[Dict[str, List[Tensor]]]): The RNN hidden states for critic_2 representation.

    Returns:
        rnn_hidden_critic_new_1: The updated rnn states for the target critic_1 representation.
        rnn_hidden_critic_new_2: The updated rnn states for the target critic_2 representation.
        q_target: The evaluations of Q^target (dict keyed by agent).
    """
    # Copies keep the caller's hidden-state containers untouched.
    hidden_out_1 = deepcopy(rnn_hidden_critic_1)
    hidden_out_2 = deepcopy(rnn_hidden_critic_2)

    keys = [agent_key] if agent_key is not None else self.model_keys
    n_batch = joint_observation.shape[0]
    n_steps = joint_observation.shape[1] if self.use_rnn else 1

    rep_input = torch.concat([joint_observation, joint_actions], dim=-1)

    # Target representation pass: all critic-1 modules first, then all critic-2 modules.
    reps_1, reps_2 = {}, {}
    if self.use_rnn:
        for k in keys:
            reps_1[k] = self.target_critic_1_representation[k](rep_input, *rnn_hidden_critic_1[k])
        for k in keys:
            reps_2[k] = self.target_critic_2_representation[k](rep_input, *rnn_hidden_critic_2[k])
        hidden_out_1.update({k: (out['rnn_hidden'], out['rnn_cell']) for k, out in reps_1.items()})
        hidden_out_2.update({k: (out['rnn_hidden'], out['rnn_cell']) for k, out in reps_2.items()})
    else:
        for k in keys:
            reps_1[k] = self.target_critic_1_representation[k](rep_input)
        for k in keys:
            reps_2[k] = self.target_critic_2_representation[k](rep_input)

    share = self.use_parameter_sharing
    rows = n_batch * self.n_agents if share else n_batch
    flat_shape = (rows, n_steps, -1) if self.use_rnn else (rows, -1)

    def as_head_input(state):
        # Without sharing the features only need flattening to the head's expected layout.
        if not share:
            return state.reshape(*flat_shape)
        # With sharing, repeat the joint features once per agent and append the agent ids.
        if self.use_rnn:
            tiled = state.unsqueeze(1).expand(-1, self.n_agents, -1, -1)
        else:
            tiled = state.unsqueeze(1).expand(-1, self.n_agents, -1)
        return torch.concat([tiled.reshape(*flat_shape), agent_ids], dim=-1)

    target_q = {}
    for k in keys:
        head_in_1 = as_head_input(reps_1[k]['state'])
        head_in_2 = as_head_input(reps_2[k]['state'])
        estimate_1 = self.target_critic_1[k](head_in_1)
        estimate_2 = self.target_critic_2[k](head_in_2)
        # Keep the smaller of the two target estimates, element by element.
        target_q[k] = torch.min(estimate_1, estimate_2)

    return hidden_out_1, hidden_out_2, target_q
def Qaction(self, joint_observation: Optional[Tensor] = None, joint_actions: Optional[Tensor] = None,
            agent_ids: Optional[Tensor] = None, agent_key: str = None,
            rnn_hidden_critic_1: Optional[Dict[str, List[Tensor]]] = None,
            rnn_hidden_critic_2: Optional[Dict[str, List[Tensor]]] = None):
    """
    Returns the evaluated Q-values for current observation-action pairs.

    The joint observations and joint actions are concatenated on the last axis, run through each
    selected agent's two critic representations, and the 'state' features are scored by the
    corresponding critic heads.

    Parameters:
        joint_observation (Optional[Tensor]): The joint observations of the team. Axis 0 is read
            as the batch size and, when RNNs are used, axis 1 as the sequence length.
        joint_actions (Tensor): The joint actions of the team.
        agent_ids (Optional[Tensor]): The agents' ids; only used with parameter sharing, where
            they are concatenated to the tiled representation output on the last axis.
        agent_key (str): Evaluate only the specified agent; None evaluates all model keys.
        rnn_hidden_critic_1 (Optional[Dict[str, List[Tensor]]]): The RNN hidden states for critic_1 representation.
        rnn_hidden_critic_2 (Optional[Dict[str, List[Tensor]]]): The RNN hidden states for critic_2 representation.

    Returns:
        rnn_hidden_critic_new_1: The updated rnn states for critic_1_representation.
        rnn_hidden_critic_new_2: The updated rnn states for critic_2_representation.
        q_1: The Q-value calculated by the first critic network (dict keyed by agent).
        q_2: The Q-value calculated by the other critic network (dict keyed by agent).
    """
    # Deep copies are returned so that the hidden states passed in stay unmodified.
    next_hidden_1 = deepcopy(rnn_hidden_critic_1)
    next_hidden_2 = deepcopy(rnn_hidden_critic_2)

    selected = [agent_key] if agent_key is not None else self.model_keys
    batch = joint_observation.shape[0]
    steps = joint_observation.shape[1] if self.use_rnn else 1

    features_in = torch.concat([joint_observation, joint_actions], dim=-1)

    # Run the first critic's representations for every agent before touching the second critic's.
    features_1, features_2 = {}, {}
    if self.use_rnn:
        for name in selected:
            features_1[name] = self.critic_1_representation[name](features_in, *rnn_hidden_critic_1[name])
        for name in selected:
            features_2[name] = self.critic_2_representation[name](features_in, *rnn_hidden_critic_2[name])
        next_hidden_1.update({name: (out['rnn_hidden'], out['rnn_cell']) for name, out in features_1.items()})
        next_hidden_2.update({name: (out['rnn_hidden'], out['rnn_cell']) for name, out in features_2.items()})
    else:
        for name in selected:
            features_1[name] = self.critic_1_representation[name](features_in)
        for name in selected:
            features_2[name] = self.critic_2_representation[name](features_in)

    sharing = self.use_parameter_sharing
    n_rows = batch * self.n_agents if sharing else batch
    layout = (n_rows, steps, -1) if self.use_rnn else (n_rows, -1)

    def prepare(state):
        # Independent networks: just flatten to the layout the head consumes.
        if not sharing:
            return state.reshape(*layout)
        # Shared networks: replicate the joint features per agent, then tag rows with agent ids.
        if self.use_rnn:
            replicated = state.unsqueeze(1).expand(-1, self.n_agents, -1, -1)
        else:
            replicated = state.unsqueeze(1).expand(-1, self.n_agents, -1)
        return torch.concat([replicated.reshape(*layout), agent_ids], dim=-1)

    q_1, q_2 = {}, {}
    for name in selected:
        input_1 = prepare(features_1[name]['state'])
        input_2 = prepare(features_2[name]['state'])
        q_1[name] = self.critic_1[name](input_1)
        q_2[name] = self.critic_2[name](input_2)

    return next_hidden_1, next_hidden_2, q_1, q_2