import numpy as np
from argparse import Namespace
from operator import itemgetter
from gymnasium.spaces import Space
from xuance.common import List, Optional, MultiAgentBaseCallback
from xuance.environment import DummyVecMultiAgentEnv, SubprocVecMultiAgentEnv
from xuance.mindspore import ops, Module
from xuance.mindspore.utils import NormalizeFunctions, InitializeFunctions, ActivationFunctions
from xuance.mindspore.policies import REGISTRY_Policy, VDN_mixer, QMIX_mixer
from xuance.mindspore.agents import OnPolicyMARLAgents
class VDAC_Agents(OnPolicyMARLAgents):
    """
    Agents for VDAC (value-decomposition actor-critic), implemented with MindSpore.

    Every agent has an actor and a critic. The per-agent critic values are combined into a single
    team value by the mixer selected through ``config.mixer`` ("VDN" or "QMIX"); that team value is
    what gets stored for every agent.

    Parameters:
        config (Namespace): Hyper-parameters and settings (policy name, mixer, network sizes, ...).
        envs (Optional[DummyVecMultiAgentEnv | SubprocVecMultiAgentEnv]): The vectorized environments.
        num_agents (Optional[int]): Number of agents.
        agent_keys (Optional[List[str]]): Keys identifying each agent.
        state_space (Optional[Space]): The global state space, used when ``envs`` is None.
        observation_space (Optional[Space]): The observation space(s).
        action_space (Optional[Space]): The action space(s).
        callback (Optional[MultiAgentBaseCallback]): Optional training callback.
    """

    def __init__(
            self,
            config: Namespace,
            envs: Optional[DummyVecMultiAgentEnv | SubprocVecMultiAgentEnv] = None,
            num_agents: Optional[int] = None,
            agent_keys: Optional[List[str]] = None,
            state_space: Optional[Space] = None,
            observation_space: Optional[Space] = None,
            action_space: Optional[Space] = None,
            callback: Optional[MultiAgentBaseCallback] = None
    ):
        super(VDAC_Agents, self).__init__(
            config, envs, num_agents, agent_keys, state_space, observation_space, action_space, callback
        )
        # Prefer the environments' global state space, but fall back to the explicit argument:
        # ``envs`` defaults to None, and dereferencing it unconditionally would raise AttributeError.
        self.state_space = envs.state_space if envs is not None else state_space
        self.mixer = config.mixer
        self.policy = self._build_policy()  # build policy
        self.memory = self._build_memory()  # build memory
        self.learner = self._build_learner(self.config, self.model_keys, self.agent_keys, self.policy, self.callback)

    def _build_policy(self) -> Module:
        """
        Build the representations, the value mixer and the actor-critic policy for the agents.

        Returns:
            policy (Module): The multi-agent actor-critic policy that owns the selected mixer.

        Raises:
            AttributeError: If ``config.mixer`` or ``config.policy`` names an unsupported option.
        """
        normalize_fn = NormalizeFunctions[self.config.normalize] if hasattr(self.config, "normalize") else None
        initializer = InitializeFunctions[self.config.initialize] if hasattr(self.config, "initialize") else None
        activation = ActivationFunctions[self.config.activation]
        agent = self.config.agent

        # Actor and critic each get their own representation network (same config, separate instances).
        A_representation = self._build_representation(self.config.representation, self.observation_space, self.config)
        C_representation = self._build_representation(self.config.representation, self.observation_space, self.config)

        # Create the mixer that combines per-agent critic values into a team value.
        if self.mixer == "VDN":
            mixer = VDN_mixer()
        elif self.mixer == "QMIX":
            dim_state = self.state_space.shape[-1]
            mixer = QMIX_mixer(dim_state, self.config.hidden_dim_mixing_net, self.config.hidden_dim_hyper_net,
                               self.n_agents)
            # QMIX conditions on the global state, so states must be stored with each transition.
            self.use_global_state = True
        elif self.mixer == "Independent":
            # NOTE(review): this builds without a mixer, but get_actions()/values_next() raise
            # NotImplementedError for any mixer other than "VDN"/"QMIX".
            mixer = None
        else:
            raise AttributeError(f"Mixer named {self.mixer} is not supported in XuanCe!")

        # Build the policy for discrete (categorical) or continuous (Gaussian) actions.
        if self.config.policy == "Categorical_MAAC_Policy":
            policy = REGISTRY_Policy["Categorical_MAAC_Policy"](
                action_space=self.action_space, n_agents=self.n_agents,
                representation_actor=A_representation, representation_critic=C_representation, mixer=mixer,
                actor_hidden_size=self.config.actor_hidden_size, critic_hidden_size=self.config.critic_hidden_size,
                normalize=normalize_fn, initialize=initializer, activation=activation,
                use_distributed_training=self.distributed_training,
                use_parameter_sharing=self.use_parameter_sharing, model_keys=self.model_keys,
                use_rnn=self.use_rnn, rnn=self.config.rnn if self.use_rnn else None)
            self.continuous_control = False
        elif self.config.policy == "Gaussian_MAAC_Policy":
            policy = REGISTRY_Policy["Gaussian_MAAC_Policy"](
                action_space=self.action_space, n_agents=self.n_agents,
                representation_actor=A_representation, representation_critic=C_representation, mixer=mixer,
                actor_hidden_size=self.config.actor_hidden_size, critic_hidden_size=self.config.critic_hidden_size,
                normalize=normalize_fn, initialize=initializer, activation=activation,
                activation_action=ActivationFunctions[self.config.activation_action],
                use_distributed_training=self.distributed_training,
                use_parameter_sharing=self.use_parameter_sharing, model_keys=self.model_keys,
                use_rnn=self.use_rnn, rnn=self.config.rnn if self.use_rnn else None)
            self.continuous_control = True
        else:
            raise AttributeError(f"{agent} currently does not support the policy named {self.config.policy}.")
        return policy

    def store_experience(self, obs_dict, avail_actions, actions_dict, log_pi_a, rewards_dict, values_dict,
                         terminals_dict, info, **kwargs):
        """
        Store one step of experience from all environments into the memory.

        Parameters:
            obs_dict (List[dict]): Observations for each agent in self.agent_keys, one dict per environment.
            avail_actions (List[dict]): Actions mask values for each agent in self.agent_keys.
            actions_dict (List[dict]): Actions for each agent in self.agent_keys.
            log_pi_a (dict): The log of pi. Accepted for interface compatibility; it is not stored.
            rewards_dict (List[dict]): Rewards for each agent in self.agent_keys.
            values_dict (dict): Critic (team) values for each agent in self.agent_keys.
            terminals_dict (List[dict]): Terminated values for each agent in self.agent_keys.
            info (List[dict]): Other information for the environment at current step.
            **kwargs: Other inputs; ``state`` is required when self.use_global_state is True.
        """
        experience_data = {
            'obs': {k: np.array([data[k] for data in obs_dict]) for k in self.agent_keys},
            'actions': {k: np.array([data[k] for data in actions_dict]) for k in self.agent_keys},
            # Every agent is stored with the mean of all agents' rewards in that environment.
            'rewards': {k: np.array([np.array(list(data.values())).mean() for data in rewards_dict])
                        for k in self.agent_keys},
            'values': values_dict,
            'terminals': {k: np.array([data[k] for data in terminals_dict]) for k in self.agent_keys},
            'agent_mask': {k: np.array([data['agent_mask'][k] for data in info]) for k in self.agent_keys},
        }
        if self.use_rnn:
            experience_data['episode_steps'] = np.array([data['episode_step'] - 1 for data in info])
        if self.use_global_state:
            experience_data['state'] = np.array(kwargs['state'])
        if self.use_actions_mask:
            experience_data['avail_actions'] = {k: np.array([data[k] for data in avail_actions])
                                                for k in self.agent_keys}
        self.memory.store(**experience_data)

    def _mix_critic_values(self, values_out: dict, n_env: int, state: Optional[np.ndarray] = None) -> np.ndarray:
        """
        Combine per-agent critic outputs into the team value using the configured mixer.

        Parameters:
            values_out (dict): Critic outputs returned by ``self.policy.get_values``.
            n_env (int): Number of environments represented in ``values_out``.
            state (Optional[np.ndarray]): The global state, passed to the QMIX mixer.

        Returns:
            values_tot (np.ndarray): The mixed team values, reshaped to (n_env,).

        Raises:
            NotImplementedError: If ``config.mixer`` is neither "VDN" nor "QMIX".
        """
        if self.use_parameter_sharing:
            values_n = values_out[self.model_keys[0]].reshape([n_env, self.n_agents])
        else:
            # A list comprehension always yields a sequence; ``itemgetter(*keys)`` would return the
            # bare value (not a tuple) when there is only one agent key.
            values_n = ops.stack([values_out[k] for k in self.agent_keys], axis=-1).reshape([n_env, self.n_agents])
        if self.config.mixer == "VDN":
            values_tot = self.policy.value_tot(values_n)
        elif self.config.mixer == "QMIX":
            values_tot = self.policy.value_tot(values_n, state)
        else:
            raise NotImplementedError(f"Mixer {self.config.mixer} for VDAC is not implemented.")
        return values_tot.asnumpy().reshape(n_env)

    def get_actions(self,
                    obs_dict: List[dict],
                    state: Optional[np.ndarray] = None,
                    avail_actions_dict: Optional[List[dict]] = None,
                    rnn_hidden_actor: Optional[dict] = None,
                    rnn_hidden_critic: Optional[dict] = None,
                    test_mode: Optional[bool] = False,
                    **kwargs):
        """
        Sample actions for all agents and, outside test mode, evaluate the mixed critic values.

        Parameters:
            obs_dict (List[dict]): Observations for each agent in self.agent_keys, one dict per environment.
            state (Optional[np.ndarray]): The global state.
            avail_actions_dict (Optional[List[dict]]): Actions mask values, default is None.
            rnn_hidden_actor (Optional[dict]): The RNN hidden states of actor representation.
            rnn_hidden_critic (Optional[dict]): The RNN hidden states of critic representation.
            test_mode (Optional[bool]): If True, the critic is not evaluated.

        Returns:
            dict: A dict with the keys
                "rnn_hidden_actor": the new RNN hidden states of the actor representation;
                "rnn_hidden_critic": the new RNN hidden states of the critic representation ({} in test mode);
                "actions": a list (one entry per environment) of dicts mapping agent keys to actions;
                "log_pi": always None;
                "values": the team value for each agent key ({} in test mode).
        """
        n_env = len(obs_dict)
        rnn_hidden_critic_new, values_dict = {}, {}
        obs_input, agents_id, avail_actions_input = self._build_inputs(obs_dict, avail_actions_dict)
        if self.continuous_control:
            rnn_hidden_actor_new, pi_mu, pi_std = self.policy(observation=obs_input,
                                                              agent_ids=agents_id,
                                                              avail_actions=avail_actions_input,
                                                              rnn_hidden=rnn_hidden_actor)
        else:
            rnn_hidden_actor_new, pi_logits = self.policy(observation=obs_input,
                                                          agent_ids=agents_id,
                                                          avail_actions=avail_actions_input,
                                                          rnn_hidden=rnn_hidden_actor)

        if not test_mode:
            rnn_hidden_critic_new, values_out = self.policy.get_values(observation=obs_input,
                                                                       agent_ids=agents_id,
                                                                       rnn_hidden=rnn_hidden_critic)
            values_tot = self._mix_critic_values(values_out, n_env, state)
            # All agents share the same team value.
            values_dict = {k: values_tot for k in self.agent_keys}

        if self.use_parameter_sharing:
            # One shared actor produces outputs for all agents; reshape them per (env, agent).
            key = self.agent_keys[0]
            if self.continuous_control:
                pi_dists = self.policy.actor[key].distribution(mu=pi_mu[key], std=pi_std[key])
                actions_sample = pi_dists.stochastic_sample()
                actions_out = actions_sample.asnumpy().reshape(n_env, self.n_agents, -1)
            else:
                pi_dists = self.policy.actor[key].distribution(logits=pi_logits[key])
                actions_sample = pi_dists.stochastic_sample()
                actions_out = actions_sample.asnumpy().reshape(n_env, self.n_agents)
            actions_dict = [{k: actions_out[e, i] for i, k in enumerate(self.agent_keys)} for e in range(n_env)]
        else:
            if self.continuous_control:
                # Keyword arguments match the parameter-sharing branch above.
                pi_dists = {k: self.policy.actor[k].distribution(mu=pi_mu[k], std=pi_std[k])
                            for k in self.agent_keys}
                actions_sample = {k: pi_dists[k].stochastic_sample() for k in self.agent_keys}
                actions_dict = [{k: actions_sample[k].asnumpy()[e].reshape([-1]) for k in self.agent_keys}
                                for e in range(n_env)]
            else:
                pi_dists = {k: self.policy.actor[k].distribution(logits=pi_logits[k]) for k in self.agent_keys}
                actions_sample = {k: pi_dists[k].stochastic_sample() for k in self.agent_keys}
                actions_dict = [{k: actions_sample[k].asnumpy()[e].reshape([]) for k in self.agent_keys}
                                for e in range(n_env)]

        return {"rnn_hidden_actor": rnn_hidden_actor_new, "rnn_hidden_critic": rnn_hidden_critic_new,
                "actions": actions_dict, "log_pi": None, "values": values_dict}

    def values_next(self,
                    i_env: int,
                    obs_dict: dict,
                    state: Optional[np.ndarray] = None,
                    rnn_hidden_critic: Optional[dict] = None):
        """
        Return the mixed critic value of one environment that finished an episode.

        Parameters:
            i_env (int): The index of environment.
            obs_dict (dict): Observations for each agent in self.agent_keys, for this one environment.
            state (Optional[np.ndarray]): The global state.
            rnn_hidden_critic (Optional[dict]): The RNN hidden states of critic representation.

        Returns:
            rnn_hidden_critic_new (dict): The new RNN hidden states of critic representation.
            values_dict (dict): The team value (a 0-d array) for each agent key.
        """
        # NOTE(review): ``i_env`` is unused, so with use_rnn=True the full ``rnn_hidden_critic`` is
        # passed together with a single environment's observation. Confirm this matches the policy.
        n_env = 1
        obs_input, agents_id, _ = self._build_inputs([obs_dict])
        rnn_hidden_critic_new, values_out = self.policy.get_values(observation=obs_input,
                                                                   agent_ids=agents_id,
                                                                   rnn_hidden=rnn_hidden_critic)
        values_tot = self._mix_critic_values(values_out, n_env, state).reshape([])
        values_dict = {k: values_tot for k in self.agent_keys}
        return rnn_hidden_critic_new, values_dict