# Source code for xuance.mindspore.agents.base.agents_marl

import os.path
import wandb
import socket
import xuance
import numpy as np
from abc import ABC, abstractmethod
from pathlib import Path
from argparse import Namespace
from operator import itemgetter
from gymnasium.spaces import Space
from torch.utils.tensorboard import SummaryWriter
from mindspore.communication import init, get_rank, get_group_size
from xuance.common import get_time_string, create_directory, Optional, List, Dict, Union, MultiAgentBaseCallback
from xuance.environment import DummyVecMultiAgentEnv, SubprocVecMultiAgentEnv, space2shape
from xuance.mindspore import ms, Tensor, Module, ModuleDict, REGISTRY_Representation, REGISTRY_Learners, ops
from xuance.mindspore.learners import learner
from xuance.mindspore.utils import NormalizeFunctions, ActivationFunctions, InitializeFunctions, set_seed, set_device


class MARLAgents(ABC):
    """Base class for Multi-Agent Reinforcement Learning (MARL) agents.

    This class defines the common interface and shared functionalities for all MARL agent
    implementations in XuanCe. It handles environment interaction, logging, model
    saving/loading, distributed training setup, and representation construction, while
    leaving algorithm-specific logic to subclasses.

    Subclasses should implement the abstract methods to define:
        - how experiences are stored,
        - how actions are selected,
        - how training and evaluation are performed.

    Args:
        config (Namespace): A configuration object that contains hyperparameters and runtime
            settings, such as algorithm name, environment name, learning rates, device, seed,
            and logging options.
        envs (Optional[DummyVecMultiAgentEnv | SubprocVecMultiAgentEnv]): Vectorized
            multi-agent environments for training. If not provided, environment-related
            attributes (e.g., observation/action spaces) must be specified explicitly.
        num_agents (Optional[int]): Number of agents in the environment. Required if `envs`
            is None.
        agent_keys (Optional[List[str]]): Unique identifiers for each agent. Required if
            `envs` is None.
        state_space (Optional[Space]): Global state space used by centralized critics or
            state-based representations. Required when `use_global_state` is enabled and
            `envs` is None.
        observation_space (Optional[Space]): Observation space for each agent. Required if
            `envs` is None.
        action_space (Optional[Space]): Action space for each agent. Required if `envs` is
            None.
        callback (Optional[MultiAgentBaseCallback]): A user-defined callback object for
            injecting custom logic during training and evaluation (e.g., logging, early
            stopping, debugging).

    Raises:
        ValueError: If `envs` is None and the spaces / agent information needed to build the
            networks are missing.
        AttributeError: If `config.logger` is neither "tensorboard" nor "wandb".
    """

    def __init__(
            self,
            config: Namespace,
            envs: Optional[DummyVecMultiAgentEnv | SubprocVecMultiAgentEnv] = None,
            num_agents: Optional[int] = None,
            agent_keys: Optional[List[str]] = None,
            state_space: Optional[Space] = None,
            observation_space: Optional[Space] = None,
            action_space: Optional[Space] = None,
            callback: Optional[MultiAgentBaseCallback] = None
    ):
        # Training settings.
        self.config = config
        self.use_rnn = getattr(config, "use_rnn", False)
        self.use_parameter_sharing = config.use_parameter_sharing
        self.use_actions_mask = getattr(config, "use_actions_mask", False)
        self.use_global_state = getattr(config, "use_global_state", False)
        self.distributed_training = getattr(config, "distributed_training", False)

        self.static_graph = getattr(config, "static_graph", True)
        if self.static_graph:
            ms.set_context(mode=ms.GRAPH_MODE)  # Static graph mode (accelerating the calculation)
            print("Running mode: Static Graph. (Also known as Graph mode)")
        else:
            ms.set_context(mode=ms.PYNATIVE_MODE)  # Dynamic graph mode (default mode)
            print("Running mode: Dynamic Graph.")

        if self.distributed_training:
            # Previously this re-printed the "Static Graph" message, which contradicted the
            # line above whenever dynamic-graph mode was selected.
            print("Running mode: Distributed training (data parallel).")
            init()
            self.world_size = get_group_size()
            self.rank = get_rank()
            ms.context.set_auto_parallel_context(
                parallel_mode=ms.ParallelMode.DATA_PARALLEL,
                gradients_mean=True  # Calculate mean gradient automatically (like DDP).
            )
        else:
            self.world_size = 1
            self.rank = 0

        # Offset the seed per rank so that parallel workers do not draw identical samples.
        set_seed(config.seed + self.rank * 1000)

        self.gamma = config.gamma
        self.start_training = getattr(config, "start_training", 1)
        self.training_frequency = getattr(config, "training_frequency", 1)
        self.n_epochs = getattr(config, "n_epochs", 1)
        self.device = self.config.device = set_device(self.config.dl_toolbox, self.config.device)

        # Environment attributes.
        self.train_envs = envs
        self.render = config.render
        self.fps = config.fps
        if self.train_envs is None:
            if observation_space is None or action_space is None or agent_keys is None or num_agents is None:
                raise ValueError(
                    "Please provide the num_agents, agent_keys, observation_space, and action_space when the envs is "
                    "not provided. Or the networks cannot be built. "
                    "You can get them from test_envs.num_agents, test_envs.agents, test_envs.observation_space, "
                    "and test_envs.action_space.")
            if self.use_global_state and state_space is None:
                raise ValueError("Please provide the state_space when the envs is not provided.")
            self.n_envs = self.config.parallels
            self.n_agents = self.config.n_agents = num_agents
            self.agent_keys = agent_keys
            self.state_space = state_space if self.use_global_state else None
            self.observation_space = observation_space
            self.action_space = action_space
            self.episode_length = None
        else:
            try:
                self.train_envs.reset()
            except Exception:
                # Best-effort reset; some wrappers cannot be reset at construction time.
                pass
            self.n_agents = self.config.n_agents = self.train_envs.num_agents
            self.n_envs = self.train_envs.num_envs
            self.agent_keys = self.train_envs.agents
            self.state_space = self.train_envs.state_space if self.use_global_state else None
            self.observation_space = self.train_envs.observation_space
            self.action_space = self.train_envs.action_space
            self.episode_length = getattr(config, "episode_length", self.train_envs.max_episode_steps)
            self.config.episode_length = self.episode_length
        self.current_step = 0
        self.current_episode = np.zeros((self.n_envs,), np.int32)

        # Prepare directories.
        if self.distributed_training and self.world_size > 1:
            # Rank 0 chooses the run's time string and broadcasts it (as a zero-padded int32
            # byte buffer of length 32) so that every rank writes into the same directories.
            if self.rank == 0:
                time_string = get_time_string()
                time_bytes = list(time_string.encode('utf-8'))
                time_array = np.zeros(32, dtype=np.int32)
                time_array[:len(time_bytes)] = time_bytes
                time_string = Tensor(time_array, dtype=ms.int32)
            else:
                time_string = Tensor(np.zeros(32, dtype=np.int32), dtype=ms.int32)
            broadcast_op = ops.Broadcast(root_rank=0)
            time_tensor = broadcast_op((time_string,))[0]
            time_bytes_list = [int(x) for x in time_tensor.asnumpy().tolist() if x != 0]
            time_string = bytes(time_bytes_list).decode('utf-8')
        else:
            time_string = get_time_string()
        seed = f"seed_{config.seed}_"
        self.model_dir_load = config.model_dir
        self.model_dir_save = os.path.join(os.getcwd(), config.model_dir, seed + time_string)

        # Create logger.
        if config.logger == "tensorboard":
            log_dir = os.path.join(os.getcwd(), config.log_dir, seed + time_string)
            create_directory(log_dir)
            self.writer = SummaryWriter(log_dir)
            self.use_wandb = False
        elif config.logger == "wandb":
            config_dict = vars(config)
            log_dir = config.log_dir
            wandb_dir = Path(os.path.join(os.getcwd(), config.log_dir))
            create_directory(str(wandb_dir))
            wandb.init(config=config_dict,
                       project=config.project_name,
                       entity=config.wandb_user_name,
                       notes=socket.gethostname(),
                       dir=wandb_dir,
                       group=config.env_id,
                       job_type=config.agent,
                       name=time_string,
                       reinit=True,
                       settings=wandb.Settings(start_method="fork")
                       )
            self.use_wandb = True
        else:
            raise AttributeError("No logger is implemented.")
        self.log_dir = log_dir

        # Predefine necessary components (filled in by subclasses).
        # With parameter sharing, all agents share the model stored under the first agent key.
        self.model_keys = [self.agent_keys[0]] if self.use_parameter_sharing else self.agent_keys
        self.policy: Optional[Module] = None
        self.learner: Optional[learner] = None
        self.memory: Optional[object] = None
        self.callback = callback or MultiAgentBaseCallback()
        self.eye = ops.Eye()

        self.meta_data = dict(algo=self.config.agent, env=self.config.env_name, env_id=self.config.env_id,
                              dl_toolbox=self.config.dl_toolbox, device=self.device,
                              seed=self.config.seed, xuance_version=xuance.__version__)

    def store_experience(self, *args, **kwargs):
        """Store transitions into the replay/rollout memory. Must be overridden by subclasses."""
        raise NotImplementedError

    def save_model(self, model_name, model_path=None):
        """Save the learner's networks to ``<model_path>/<model_name>``.

        Args:
            model_name (str): File name of the checkpoint.
            model_path (Optional[str]): Target directory; defaults to ``self.model_dir_save``.
                The directory (and any missing parents) is created if needed.
        """
        model_path = self.model_dir_save if model_path is None else model_path
        # exist_ok avoids the check-then-create race between exists() and makedirs().
        os.makedirs(model_path, exist_ok=True)
        self.learner.save_model(os.path.join(model_path, model_name))

    def load_model(self, path, model=None):
        """Load networks through the learner.

        Args:
            path (str): Path passed through to ``self.learner.load_model``.
            model: Optional model selector passed through to ``self.learner.load_model``.
        """
        self.learner.load_model(path, model)

    def log_infos(self, info: dict, x_index: int):
        """Log scalar information to wandb or TensorBoard.

        Args:
            info (dict): Information to be visualized. Entries whose value is None are skipped.
            x_index (int): The x-axis value (e.g. current step) for the logged points.
        """
        if self.use_wandb:
            for k, v in info.items():
                if v is None:
                    continue
                wandb.log({k: v}, step=x_index)
        else:
            for k, v in info.items():
                if v is None:
                    continue
                try:
                    self.writer.add_scalar(k, v, x_index)
                except Exception:
                    # Values that are not single scalars (e.g. a dict of per-agent values)
                    # are logged as a group instead.
                    self.writer.add_scalars(k, v, x_index)

    def log_videos(self, info: dict, fps: int, x_index: int = 0):
        """Log videos to wandb (as GIF) or TensorBoard.

        Args:
            info (dict): Mapping from tag to video data. Entries whose value is None are skipped.
            fps (int): Frames per second of the videos.
            x_index (int): The x-axis value (e.g. current step) for the logged videos.
        """
        if self.use_wandb:
            for k, v in info.items():
                if v is None:
                    continue
                wandb.log({k: wandb.Video(v, fps=fps, format='gif')}, step=x_index)
        else:
            for k, v in info.items():
                if v is None:
                    continue
                self.writer.add_video(k, v, fps=fps, global_step=x_index)

    def _build_representation(self, representation_key: str,
                              input_space: Union[Dict[str, Space], tuple],
                              config: Namespace):
        """Build representation modules for policies, one per key in ``self.model_keys``.

        Parameters:
            representation_key (str): The selection of representation, e.g., "Basic_MLP", "Basic_RNN", etc.
            input_space (Union[Dict[str, Space], tuple]): Input spaces indexed by model key.
            config: The configurations for creating the representation module.

        Returns:
            representation (ModuleDict): The representation modules, indexed by model key.

        Raises:
            AttributeError: If ``representation_key`` is not registered in REGISTRY_Representation.
        """
        # Validate up front: the original check ran after the registry had already been
        # indexed, so an unknown key surfaced as a KeyError and this check was unreachable.
        if representation_key not in REGISTRY_Representation:
            raise AttributeError(f"{representation_key} is not registered in REGISTRY_Representation.")

        # build representations
        representation = ModuleDict()
        for key in self.model_keys:
            if self.use_rnn:
                hidden_sizes = {'fc_hidden_sizes': self.config.fc_hidden_sizes,
                                'recurrent_hidden_size': self.config.recurrent_hidden_size}
            else:
                hidden_sizes = getattr(config, "representation_hidden_size", None)
            input_representations = dict(
                input_shape=space2shape(input_space[key]),
                hidden_sizes=hidden_sizes,
                normalize=NormalizeFunctions[config.normalize] if hasattr(config, "normalize") else None,
                # Check the same object that is read (previously checked self.config).
                initialize=InitializeFunctions[config.initialize] if hasattr(config, "initialize") else None,
                activation=ActivationFunctions[config.activation],
                kernels=getattr(config, "kernels", None),
                strides=getattr(config, "strides", None),
                filters=getattr(config, "filters", None),
                fc_hidden_sizes=getattr(config, "fc_hidden_sizes", None),
                N_recurrent_layers=getattr(config, "N_recurrent_layers", None),
                rnn=getattr(config, "rnn", None),
                dropout=getattr(config, "dropout", None),
                device=self.device)
            representation[key] = REGISTRY_Representation[representation_key](**input_representations)
        return representation

    @abstractmethod
    def _build_policy(self) -> Module:
        """Build the policy module. Must be implemented by subclasses."""
        raise NotImplementedError

    def _build_learner(self, *args):
        """Instantiate the learner registered under ``self.config.learner`` with ``*args``."""
        return REGISTRY_Learners[self.config.learner](*args)

    def _build_inputs(self,
                      obs_dict: List[dict],
                      avail_actions_dict: Optional[List[dict]] = None):
        """Build inputs for representations before calculating actions.

        Parameters:
            obs_dict (List[dict]): Observations for each agent in self.agent_keys, one dict per
                environment.
            avail_actions_dict (Optional[List[dict]]): Actions mask values, default is None.
                Only read when ``self.use_actions_mask`` is True.

        Returns:
            obs_input (dict): Observations keyed by model key. With parameter sharing, a single
                entry under ``self.agent_keys[0]`` holding all agents stacked into the batch
                dimension (batch_size * n_agents rows); otherwise one entry per agent key.
            agents_id: One-hot agent ids (flattened like ``obs_input``) when parameter sharing
                is used, else None.
            avail_actions_input (Optional[dict]): Action masks shaped like ``obs_input``, or
                None when action masking is disabled.
        """
        batch_size = len(obs_dict)
        bs = batch_size * self.n_agents if self.use_parameter_sharing else batch_size
        avail_actions_input = None

        if self.use_parameter_sharing:
            key = self.agent_keys[0]
            obs_array = Tensor([itemgetter(*self.agent_keys)(data) for data in obs_dict])
            agents_id = Tensor(np.eye(self.n_agents, dtype=np.float32)[None].repeat(batch_size, axis=0))
            avail_actions_array = Tensor([itemgetter(*self.agent_keys)(data)
                                          for data in avail_actions_dict]) if self.use_actions_mask else None
            if self.use_rnn:
                # RNN inputs carry an extra sequence axis of length 1.
                obs_input = {key: obs_array.reshape([bs, 1, -1])}
                agents_id = agents_id.reshape(bs, 1, -1)
                if self.use_actions_mask:
                    avail_actions_input = {key: avail_actions_array.reshape([bs, 1, -1])}
            else:
                obs_input = {key: obs_array.reshape([bs, -1])}
                agents_id = agents_id.reshape(bs, -1)
                if self.use_actions_mask:
                    avail_actions_input = {key: avail_actions_array.reshape([bs, -1])}
        else:
            agents_id = None
            if self.use_rnn:
                obs_input = {k: Tensor(np.stack([data[k] for data in obs_dict]).reshape([bs, 1, -1]))
                             for k in self.agent_keys}
                if self.use_actions_mask:
                    avail_actions_input = {
                        k: Tensor(np.stack([data[k] for data in avail_actions_dict]).reshape([bs, 1, -1]))
                        for k in self.agent_keys}
            else:
                obs_input = {k: Tensor(np.stack([data[k] for data in obs_dict]).reshape(bs, -1))
                             for k in self.agent_keys}
                if self.use_actions_mask:
                    avail_actions_input = {
                        k: Tensor(np.stack([data[k] for data in avail_actions_dict]).reshape([bs, -1]))
                        for k in self.agent_keys}
        return obs_input, agents_id, avail_actions_input

    @abstractmethod
    def get_actions(self, **kwargs):
        """Select actions for all agents. Must be implemented by subclasses."""
        raise NotImplementedError

    @abstractmethod
    def train_epochs(self, *args, **kwargs):
        """Run the learner's update epochs. Must be implemented by subclasses."""
        raise NotImplementedError

    @abstractmethod
    def train(self, **kwargs):
        """Interact with the training environments and update the policy. Must be implemented by subclasses."""
        raise NotImplementedError

    @abstractmethod
    def test(self, **kwargs):
        """Evaluate the current policy. Must be implemented by subclasses."""
        raise NotImplementedError

    def finish(self):
        """Close the logger and, on rank 0 of a distributed run, remove the learner's snapshot files."""
        if self.use_wandb:
            wandb.finish()
        else:
            self.writer.close()
        if self.distributed_training and self.rank == 0:
            snapshot_path = self.learner.snapshot_path
            if os.path.exists(snapshot_path):
                snapshot_file = os.path.join(snapshot_path, "snapshot.pt")
                if os.path.exists(snapshot_file):
                    os.remove(snapshot_file)
                os.removedirs(snapshot_path)
class RandomAgents(object):
    """Baseline group of agents that ignores observations and samples random actions.

    Args:
        args: Namespace-like object providing ``n_agents``, ``agent_keys`` and
            ``action_space`` (indexed by agent key; each entry must expose ``sample()``).
        envs: Vectorized environments; only ``envs.num_envs`` is read.
        device: Unused; accepted for interface compatibility with learning agents.
    """

    def __init__(self, args, envs, device=None):
        self.args = args
        self.n_agents = args.n_agents
        self.agent_keys = args.agent_keys
        self.action_space = args.action_space
        self.nenvs = envs.num_envs

    def get_actions(self, obs_n, episode, test_mode, noise=False):
        """Sample one action per agent for every environment.

        All arguments are ignored; they exist to match the signature used by learning agents.

        Returns:
            np.ndarray: Sampled actions, indexed first by environment, then by agent
                (in ``self.agent_keys`` order).
        """
        per_env_samples = []
        for _ in range(self.nenvs):
            per_env_samples.append([self.action_space[agent_key].sample() for agent_key in self.agent_keys])
        return np.array(per_env_samples)

    def load_model(self, model_dir):
        """No-op: random agents have no parameters to restore."""
        return None