import os.path
import wandb
import socket
import xuance
import numpy as np
from abc import ABC, abstractmethod
from pathlib import Path
from argparse import Namespace
from operator import itemgetter
from gymnasium.spaces import Space
from torch.utils.tensorboard import SummaryWriter
from mindspore.communication import init, get_rank, get_group_size
from xuance.common import get_time_string, create_directory, Optional, List, Dict, Union, MultiAgentBaseCallback
from xuance.environment import DummyVecMultiAgentEnv, SubprocVecMultiAgentEnv, space2shape
from xuance.mindspore import ms, Tensor, Module, ModuleDict, REGISTRY_Representation, REGISTRY_Learners, ops
from xuance.mindspore.learners import learner
from xuance.mindspore.utils import NormalizeFunctions, ActivationFunctions, InitializeFunctions, set_seed, set_device
class MARLAgents(ABC):
    """Base class for Multi-Agent Reinforcement Learning (MARL) agents.

    This class defines the common interface and shared functionalities for all
    MARL agent implementations in XuanCe. It handles environment interaction,
    logging, model saving/loading, distributed training setup, and representation
    construction, while leaving algorithm-specific logic to subclasses.

    Subclasses should implement the abstract methods to define:
        - how experiences are stored,
        - how actions are selected,
        - how training and evaluation are performed.

    Args:
        config (Namespace):
            A configuration object that contains hyperparameters and runtime
            settings, such as algorithm name, environment name, learning rates,
            device, seed, and logging options.
        envs (Optional[DummyVecMultiAgentEnv | SubprocVecMultiAgentEnv]):
            Vectorized multi-agent environments for training. If not provided,
            environment-related attributes (e.g., observation/action spaces)
            must be specified explicitly.
        num_agents (Optional[int]):
            Number of agents in the environment. Required if `envs` is None.
        agent_keys (Optional[List[str]]):
            Unique identifiers for each agent. Required if `envs` is None.
        state_space (Optional[Space]):
            Global state space used by centralized critics or state-based
            representations. Required when `use_global_state` is enabled and
            `envs` is None.
        observation_space (Optional[Space]):
            Observation space for each agent. Required if `envs` is None.
        action_space (Optional[Space]):
            Action space for each agent. Required if `envs` is None.
        callback (Optional[MultiAgentBaseCallback]):
            A user-defined callback object for injecting custom logic during
            training and evaluation (e.g., logging, early stopping, debugging).

    Raises:
        ValueError: If `envs` is None and the spaces / agent information needed
            to build the networks are missing.
        AttributeError: If `config.logger` is neither "tensorboard" nor "wandb".
    """

    def __init__(
            self,
            config: Namespace,
            envs: Optional[DummyVecMultiAgentEnv | SubprocVecMultiAgentEnv] = None,
            num_agents: Optional[int] = None,
            agent_keys: Optional[List[str]] = None,
            state_space: Optional[Space] = None,
            observation_space: Optional[Space] = None,
            action_space: Optional[Space] = None,
            callback: Optional[MultiAgentBaseCallback] = None
    ):
        # Training settings.
        self.config = config
        self.use_rnn = getattr(config, "use_rnn", False)
        self.use_parameter_sharing = config.use_parameter_sharing
        self.use_actions_mask = getattr(config, "use_actions_mask", False)
        self.use_global_state = getattr(config, "use_global_state", False)
        self.distributed_training = getattr(config, "distributed_training", False)
        self.static_graph = getattr(config, "static_graph", True)
        if self.static_graph:
            ms.set_context(mode=ms.GRAPH_MODE)  # Static graph mode (accelerating the calculation).
            print("Running mode: Static Graph. (Also known as Graph mode)")
        else:
            ms.set_context(mode=ms.PYNATIVE_MODE)  # Dynamic graph mode (default mode).
            print("Running mode: Dynamic Graph.")
        if self.distributed_training:
            print("Distributed training: data parallel.")
            init()
            self.world_size = get_group_size()
            self.rank = get_rank()
            ms.context.set_auto_parallel_context(
                parallel_mode=ms.ParallelMode.DATA_PARALLEL,
                gradients_mean=True  # Calculate mean gradient automatically (like DDP).
            )
        else:
            self.world_size = 1
            self.rank = 0
        # Offset the seed by rank so each process gets a different random stream.
        set_seed(config.seed + self.rank * 1000)
        self.gamma = config.gamma
        self.start_training = getattr(config, "start_training", 1)
        self.training_frequency = getattr(config, "training_frequency", 1)
        self.n_epochs = getattr(config, "n_epochs", 1)
        self.device = self.config.device = set_device(self.config.dl_toolbox, self.config.device)

        # Environment attributes.
        self.train_envs = envs
        self.render = config.render
        self.fps = config.fps
        if self.train_envs is None:
            required = (observation_space, action_space, agent_keys, num_agents)
            if any(item is None for item in required):
                raise ValueError(
                    "Please provide the num_agents, agent_keys, observation_space, and action_space when the envs is not provided. Or the networks cannot be built. "
                    "You can get them from test_envs.num_agents, test_envs.agents, test_envs.observation_space, and test_envs.action_space.")
            if self.use_global_state and state_space is None:
                raise ValueError("Please provide the state_space when the envs is not provided.")
            self.n_envs = self.config.parallels
            self.n_agents = self.config.n_agents = num_agents
            self.agent_keys = agent_keys
            self.state_space = state_space if self.use_global_state else None
            self.observation_space = observation_space
            self.action_space = action_space
            self.episode_length = None
        else:
            try:
                self.train_envs.reset()
            except Exception:
                # Deliberately best-effort: a failing reset here is ignored.
                pass
            self.n_agents = self.config.n_agents = self.train_envs.num_agents
            self.n_envs = self.train_envs.num_envs
            self.agent_keys = self.train_envs.agents
            self.state_space = self.train_envs.state_space if self.use_global_state else None
            self.observation_space = self.train_envs.observation_space
            self.action_space = self.train_envs.action_space
            self.episode_length = getattr(config, "episode_length", self.train_envs.max_episode_steps)
        self.config.episode_length = self.episode_length
        self.current_step = 0
        self.current_episode = np.zeros((self.n_envs,), np.int32)

        # Prepare directories (every rank must agree on the same time string).
        time_string = self._shared_time_string()
        seed = f"seed_{config.seed}_"
        self.model_dir_load = config.model_dir
        self.model_dir_save = os.path.join(os.getcwd(), config.model_dir, seed + time_string)

        # Create logger.
        if config.logger == "tensorboard":
            log_dir = os.path.join(os.getcwd(), config.log_dir, seed + time_string)
            create_directory(log_dir)
            self.writer = SummaryWriter(log_dir)
            self.use_wandb = False
        elif config.logger == "wandb":
            config_dict = vars(config)
            log_dir = config.log_dir
            wandb_dir = Path(os.path.join(os.getcwd(), config.log_dir))
            create_directory(str(wandb_dir))
            wandb.init(config=config_dict,
                       project=config.project_name,
                       entity=config.wandb_user_name,
                       notes=socket.gethostname(),
                       dir=wandb_dir,
                       group=config.env_id,
                       job_type=config.agent,
                       name=time_string,
                       reinit=True,
                       settings=wandb.Settings(start_method="fork")
                       )
            self.use_wandb = True
        else:
            raise AttributeError("No logger is implemented.")
        self.log_dir = log_dir

        # Predefine necessary components (filled in by subclasses).
        # With parameter sharing, all agents use the model stored under the first agent key.
        self.model_keys = [self.agent_keys[0]] if self.use_parameter_sharing else self.agent_keys
        self.policy: Optional[Module] = None
        self.learner: Optional[learner] = None
        self.memory: Optional[object] = None
        self.callback = callback or MultiAgentBaseCallback()
        self.eye = ops.Eye()
        self.meta_data = dict(algo=self.config.agent, env=self.config.env_name, env_id=self.config.env_id,
                              dl_toolbox=self.config.dl_toolbox, device=self.device, seed=self.config.seed,
                              xuance_version=xuance.__version__)

    def _shared_time_string(self) -> str:
        """Return a run time-stamp string that is identical on every process.

        Without multi-process distributed training this is just `get_time_string()`.
        Otherwise rank 0 encodes its time string as UTF-8 bytes into a zero-padded
        int32 array of length 32. That array is broadcast from rank 0 and decoded on
        every rank, so all ranks derive the same model/log directory names.
        """
        if not (self.distributed_training and self.world_size > 1):
            return get_time_string()
        buffer = np.zeros(32, dtype=np.int32)
        if self.rank == 0:
            time_bytes = list(get_time_string().encode('utf-8'))
            buffer[:len(time_bytes)] = time_bytes
        broadcast_op = ops.Broadcast(root_rank=0)
        time_tensor = broadcast_op((Tensor(buffer, dtype=ms.int32),))[0]
        # Zero entries are padding; drop them before decoding.
        time_bytes_list = [int(x) for x in time_tensor.asnumpy().tolist() if x != 0]
        return bytes(time_bytes_list).decode('utf-8')

    def store_experience(self, *args, **kwargs):
        """Store experience data; subclasses must override.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError

    def save_model(self, model_name, model_path=None):
        """Save the neural networks through the learner.

        Args:
            model_name: File name passed to the learner, joined onto `model_path`.
            model_path: Target directory; defaults to `self.model_dir_save`.
                The directory is created if it does not exist.
        """
        model_path = self.model_dir_save if model_path is None else model_path
        os.makedirs(model_path, exist_ok=True)
        self.learner.save_model(os.path.join(model_path, model_name))

    def load_model(self, path, model=None):
        """Load model parameters by delegating to `self.learner.load_model(path, model)`."""
        self.learner.load_model(path, model)

    def log_infos(self, info: dict, x_index: int):
        """Log scalar information to the configured logger.

        Args:
            info (dict): Mapping from tag to value; entries whose value is None are skipped.
            x_index (int): The step (x-axis index) to log the values at.
        """
        if self.use_wandb:
            for k, v in info.items():
                if v is None:
                    continue
                wandb.log({k: v}, step=x_index)
        else:
            for k, v in info.items():
                if v is None:
                    continue
                try:
                    self.writer.add_scalar(k, v, x_index)
                except Exception:
                    # Fall back to grouped scalars (presumably for dict-valued entries).
                    self.writer.add_scalars(k, v, x_index)

    def log_videos(self, info: dict, fps: int, x_index: int = 0):
        """Log videos to the configured logger.

        Args:
            info (dict): Mapping from tag to video data; entries whose value is None are skipped.
            fps (int): Frames per second of the videos.
            x_index (int): The step (x-axis index) to log the videos at.
        """
        if self.use_wandb:
            for k, v in info.items():
                if v is None:
                    continue
                wandb.log({k: wandb.Video(v, fps=fps, format='gif')}, step=x_index)
        else:
            for k, v in info.items():
                if v is None:
                    continue
                self.writer.add_video(k, v, fps=fps, global_step=x_index)

    def _build_representation(self, representation_key: str,
                              input_space: Union[Dict[str, Space], tuple],
                              config: Namespace):
        """Build one representation module per model key.

        Parameters:
            representation_key (str): The selection of representation, e.g., "Basic_MLP", "Basic_RNN", etc.
            input_space: Input spaces indexed by model key (see `self.model_keys`).
            config: The configurations for creating the representation module.

        Returns:
            representation (ModuleDict): Maps each key in `self.model_keys` to its representation module.

        Raises:
            AttributeError: If `representation_key` is not registered in REGISTRY_Representation.
        """
        # Validate before building: otherwise the registry lookup below fails first with a bare KeyError.
        if representation_key not in REGISTRY_Representation:
            raise AttributeError(f"{representation_key} is not registered in REGISTRY_Representation.")
        representation_cls = REGISTRY_Representation[representation_key]
        representation = ModuleDict()
        for key in self.model_keys:
            if self.use_rnn:
                # NOTE(review): reads self.config rather than the `config` argument; confirm this is intended.
                hidden_sizes = {'fc_hidden_sizes': self.config.fc_hidden_sizes,
                                'recurrent_hidden_size': self.config.recurrent_hidden_size}
            else:
                hidden_sizes = getattr(config, "representation_hidden_size", None)
            input_representations = dict(
                input_shape=space2shape(input_space[key]),
                hidden_sizes=hidden_sizes,
                normalize=NormalizeFunctions[config.normalize] if hasattr(config, "normalize") else None,
                initialize=InitializeFunctions[config.initialize] if hasattr(config, "initialize") else None,
                activation=ActivationFunctions[config.activation],
                kernels=getattr(config, "kernels", None),
                strides=getattr(config, "strides", None),
                filters=getattr(config, "filters", None),
                fc_hidden_sizes=getattr(config, "fc_hidden_sizes", None),
                N_recurrent_layers=getattr(config, "N_recurrent_layers", None),
                rnn=getattr(config, "rnn", None),
                dropout=getattr(config, "dropout", None),
                device=self.device)
            representation[key] = representation_cls(**input_representations)
        return representation

    @abstractmethod
    def _build_policy(self) -> Module:
        """Build and return the policy module; must be implemented by subclasses."""
        raise NotImplementedError

    def _build_learner(self, *args):
        """Instantiate the learner registered under `self.config.learner` with `*args`."""
        return REGISTRY_Learners[self.config.learner](*args)

    def _build_inputs(self,
                      obs_dict: List[dict],
                      avail_actions_dict: Optional[List[dict]] = None):
        """Build inputs for representations before calculating actions.

        Parameters:
            obs_dict (List[dict]): One dict per environment, mapping each key in self.agent_keys to its observation.
            avail_actions_dict (Optional[List[dict]]): Action-mask values in the same layout; read only
                when `self.use_actions_mask` is True.

        Returns:
            obs_input: Dict of observation Tensors, keyed by the first agent key under parameter
                sharing, otherwise by every agent key.
            agents_id: One-hot agent ids (Tensor) under parameter sharing, otherwise None.
            avail_actions_input: Dict of action-mask Tensors in the same layout as obs_input,
                or None when action masks are disabled.
        """
        batch_size = len(obs_dict)
        # Under parameter sharing all agents are folded into the batch dimension.
        bs = batch_size * self.n_agents if self.use_parameter_sharing else batch_size
        # RNN representations take an extra sequence axis of length 1.
        target_shape = (bs, 1, -1) if self.use_rnn else (bs, -1)
        avail_actions_input = None
        if self.use_parameter_sharing:
            key = self.agent_keys[0]
            obs_array = Tensor([itemgetter(*self.agent_keys)(data) for data in obs_dict])
            agents_id = Tensor(np.eye(self.n_agents, dtype=np.float32)[None].repeat(batch_size, axis=0))
            obs_input = {key: obs_array.reshape(target_shape)}
            agents_id = agents_id.reshape(target_shape)
            if self.use_actions_mask:
                avail_actions_array = Tensor([itemgetter(*self.agent_keys)(data)
                                              for data in avail_actions_dict])
                avail_actions_input = {key: avail_actions_array.reshape(target_shape)}
        else:
            agents_id = None
            obs_input = {k: Tensor(np.stack([data[k] for data in obs_dict]).reshape(target_shape))
                         for k in self.agent_keys}
            if self.use_actions_mask:
                avail_actions_input = {
                    k: Tensor(np.stack([data[k] for data in avail_actions_dict]).reshape(target_shape))
                    for k in self.agent_keys}
        return obs_input, agents_id, avail_actions_input

    @abstractmethod
    def get_actions(self, **kwargs):
        """Select actions; must be implemented by subclasses."""
        raise NotImplementedError

    @abstractmethod
    def train_epochs(self, *args, **kwargs):
        """Run training epochs; must be implemented by subclasses."""
        raise NotImplementedError

    @abstractmethod
    def train(self, **kwargs):
        """Train the agents; must be implemented by subclasses."""
        raise NotImplementedError

    @abstractmethod
    def test(self, **kwargs):
        """Evaluate the agents; must be implemented by subclasses."""
        raise NotImplementedError

    def finish(self):
        """Close the logger and, on rank 0 in distributed training, clean up the learner's snapshot.

        The snapshot cleanup removes `snapshot.pt` if present and then calls
        `os.removedirs` on `self.learner.snapshot_path`. That also removes any
        parent directories that become empty.
        """
        if self.use_wandb:
            wandb.finish()
        else:
            self.writer.close()
        if self.distributed_training and self.rank == 0:
            snapshot_path = self.learner.snapshot_path
            if os.path.exists(snapshot_path):
                snapshot_file = os.path.join(snapshot_path, "snapshot.pt")
                if os.path.exists(snapshot_file):
                    os.remove(snapshot_file)
                os.removedirs(snapshot_path)
class RandomAgents(object):
    """Baseline agents that ignore observations and sample random actions.

    Args:
        args: Configuration object providing `n_agents`, `agent_keys` and
            `action_space`. `action_space` is indexed by agent key, and each
            entry must provide a `sample()` method.
        envs: Vectorized environments; only `envs.num_envs` is read.
        device: Unused by this class.
    """

    def __init__(self, args, envs, device=None):
        self.args = args
        self.n_agents = self.args.n_agents
        self.agent_keys = args.agent_keys
        self.action_space = self.args.action_space
        self.nenvs = envs.num_envs

    def get_actions(self, obs_n, episode, test_mode, noise=False):
        """Sample one action per agent for each parallel environment.

        All arguments are ignored; they exist to match the agent calling convention.

        Returns:
            np.ndarray: Array built from a nested list indexed as [env][agent].
        """
        rand_a = [[self.action_space[agent].sample() for agent in self.agent_keys]
                  for _ in range(self.nenvs)]
        return np.array(rand_a)

    def load_model(self, model_dir):
        """No-op: random agents have no parameters to load."""
        return