from argparse import Namespace
from gymnasium.spaces import Space
from xuance.common import List, Optional, MultiAgentBaseCallback
from xuance.environment import DummyVecMultiAgentEnv, SubprocVecMultiAgentEnv
from xuance.mindspore import Module
from xuance.mindspore.utils import NormalizeFunctions, ActivationFunctions, InitializeFunctions
from xuance.mindspore.policies import REGISTRY_Policy, QMIX_mixer
from xuance.mindspore.agents import OffPolicyMARLAgents
[docs]
class QMIX_Agents(OffPolicyMARLAgents):
    """The implementation of QMIX agents.

    Args:
        config: the Namespace variable that provides hyperparameters and other settings.
        envs: the vectorized environments. May be None, in which case the
            global state space is taken from ``state_space`` instead.
        num_agents: forwarded unchanged to ``OffPolicyMARLAgents``
            (presumably the number of agents when ``envs`` is not given - TODO confirm).
        agent_keys: forwarded unchanged to ``OffPolicyMARLAgents``
            (presumably the per-agent identifiers - TODO confirm).
        state_space: the global state space; used by this class only when
            ``envs`` is None, and always forwarded to ``OffPolicyMARLAgents``.
        observation_space: forwarded unchanged to ``OffPolicyMARLAgents``.
        action_space: forwarded unchanged to ``OffPolicyMARLAgents``.
        callback: A user-defined callback function object to inject custom logic during training.
    """

    def __init__(
            self,
            config: Namespace,
            envs: Optional[DummyVecMultiAgentEnv | SubprocVecMultiAgentEnv] = None,
            num_agents: Optional[int] = None,
            agent_keys: Optional[List[str]] = None,
            state_space: Optional[Space] = None,
            observation_space: Optional[Space] = None,
            action_space: Optional[Space] = None,
            callback: Optional[MultiAgentBaseCallback] = None
    ):
        super(QMIX_Agents, self).__init__(
            config, envs, num_agents, agent_keys, state_space, observation_space, action_space, callback
        )
        # ``envs`` is optional: only read the state space from it when it was
        # actually provided, otherwise fall back to the explicit argument.
        self.state_space = envs.state_space if envs is not None else state_space
        # The mixer built in ``_build_policy`` is sized from the global state.
        self.use_global_state = True

        # Epsilon-greedy settings: epsilon starts at ``start_greedy`` and
        # ``delta_egreedy`` is the decrement whose denominator is the decay
        # horizon divided by the number of environments.
        self.start_greedy, self.end_greedy = config.start_greedy, config.end_greedy
        self.e_greedy = self.start_greedy
        self.delta_egreedy = (self.start_greedy - self.end_greedy) / (config.decay_step_greedy / self.n_envs)

        # build policy, memory and learner (in this order: the learner receives the policy)
        self.policy = self._build_policy()  # build policy
        self.memory = self._build_memory()  # build memory
        self.learner = self._build_learner(self.config, self.model_keys, self.agent_keys, self.policy, self.callback)

    def _build_policy(self) -> Module:
        """
        Build representation(s) and policy(ies) for agent(s)

        Returns:
            policy (Module): the policy created by ``REGISTRY_Policy["Mixing_Q_network"]``,
                holding the agents' representation and the QMIX mixer.

        Raises:
            AttributeError: if ``config.policy`` is not ``"Mixing_Q_network"``.
        """
        # ``normalize`` and ``initialize`` are optional config entries;
        # ``activation`` is required.
        normalize_fn = NormalizeFunctions[self.config.normalize] if hasattr(self.config, "normalize") else None
        initializer = InitializeFunctions[self.config.initialize] if hasattr(self.config, "initialize") else None
        activation = ActivationFunctions[self.config.activation]

        # build representations
        representation = self._build_representation(self.config.representation, self.observation_space, self.config)

        # build the mixing network; its input width is the last dimension of the global state space
        dim_state = self.state_space.shape[-1]
        mixer = QMIX_mixer(dim_state, self.config.hidden_dim_mixing_net,
                           self.config.hidden_dim_hyper_net, self.n_agents)

        # build policies
        if self.config.policy != "Mixing_Q_network":
            raise AttributeError(f"QMIX currently does not support the policy named {self.config.policy}.")

        policy = REGISTRY_Policy["Mixing_Q_network"](
            action_space=self.action_space, n_agents=self.n_agents, representation=representation,
            mixer=mixer, hidden_size=self.config.q_hidden_size,
            normalize=normalize_fn, initialize=initializer, activation=activation,
            use_parameter_sharing=self.use_parameter_sharing, model_keys=self.model_keys,
            use_rnn=self.use_rnn, rnn=self.config.rnn if self.use_rnn else None)
        return policy