import os.path
import socket
import time
from abc import ABC, abstractmethod
from argparse import Namespace
from operator import itemgetter
from pathlib import Path

import numpy as np
import torch
import torch.distributed as dist
import wandb
from gymnasium.spaces import Space
from torch import nn
from torch.distributed import destroy_process_group
from torch.utils.tensorboard import SummaryWriter

import xuance
from xuance.common import get_time_string, create_directory, Optional, List, Dict, Union, MultiAgentBaseCallback
from xuance.environment import DummyVecMultiAgentEnv, SubprocVecMultiAgentEnv, space2shape
from xuance.torch import ModuleDict, REGISTRY_Representation, REGISTRY_Learners, Module
from xuance.torch.learners import learner
from xuance.torch.utils import NormalizeFunctions, ActivationFunctions, init_distributed_mode, set_seed, set_device
[docs]
class MARLAgents(ABC):
    """Base class for Multi-Agent Reinforcement Learning (MARL) agents.

    This class defines the common interface and shared functionalities for all
    MARL agent implementations in XuanCe. It handles environment interaction,
    logging, model saving/loading, distributed training setup, and representation
    construction, while leaving algorithm-specific logic to subclasses.

    Subclasses should implement the abstract methods to define:
        - how experiences are stored,
        - how actions are selected,
        - how training and evaluation are performed.

    Args:
        config (Namespace):
            A configuration object that contains hyperparameters and runtime
            settings, such as algorithm name, environment name, learning rates,
            device, seed, and logging options.
        envs (Optional[DummyVecMultiAgentEnv | SubprocVecMultiAgentEnv]):
            Vectorized multi-agent environments for training. If not provided,
            environment-related attributes (e.g., observation/action spaces)
            must be specified explicitly.
        num_agents (Optional[int]):
            Number of agents in the environment. Required if `envs` is None.
        agent_keys (Optional[List[str]]):
            Unique identifiers for each agent. Required if `envs` is None.
        state_space (Optional[Space]):
            Global state space used by centralized critics or state-based
            representations. Required when `use_global_state` is enabled and
            `envs` is None.
        observation_space (Optional[Space]):
            Observation space for each agent. Required if `envs` is None.
        action_space (Optional[Space]):
            Action space for each agent. Required if `envs` is None.
        callback (Optional[MultiAgentBaseCallback]):
            A user-defined callback object for injecting custom logic during
            training and evaluation (e.g., logging, early stopping, debugging).
    """

    # Size in bytes of the NUL-padded buffer used to share rank 0's time string with the
    # other ranks. Every rank must broadcast a tensor of identical size, so the size is fixed
    # here rather than derived from the (rank-0-only) string itself.
    _TIME_STRING_BUFFER_BYTES = 64

    def __init__(
            self,
            config: Namespace,
            envs: Optional[DummyVecMultiAgentEnv | SubprocVecMultiAgentEnv] = None,
            num_agents: Optional[int] = None,
            agent_keys: Optional[List[str]] = None,
            state_space: Optional[Space] = None,
            observation_space: Optional[Space] = None,
            action_space: Optional[Space] = None,
            callback: Optional[MultiAgentBaseCallback] = None
    ):
        set_seed(config.seed)
        # Training settings.
        self.config = config
        self.use_cnn = getattr(config, "use_cnn", False)
        self.use_rnn = getattr(config, "use_rnn", False)
        self.use_parameter_sharing = config.use_parameter_sharing
        self.use_actions_mask = getattr(config, "use_actions_mask", False)
        self.use_global_state = getattr(config, "use_global_state", False)
        self.distributed_training = config.distributed_training
        if self.distributed_training:
            # WORLD_SIZE / RANK are read from the environment (e.g. as set by a launcher).
            self.world_size = int(os.environ['WORLD_SIZE'])
            self.rank = int(os.environ['RANK'])
            master_port = getattr(config, "master_port", None)
            init_distributed_mode(master_port=master_port)
        else:
            self.world_size = 1
            self.rank = 0
        self.gamma = config.gamma
        self.start_training = getattr(config, "start_training", 1)
        self.training_frequency = getattr(config, "training_frequency", 1)
        self.n_epochs = getattr(config, "n_epochs", 1)
        self.device = self.config.device = set_device(self.config.device)

        # Environment attributes.
        self.train_envs = envs
        self.render = config.render
        self.fps = config.fps
        if self.train_envs is None:
            if observation_space is None or action_space is None or agent_keys is None or num_agents is None:
                raise ValueError(
                    "Please provide the num_agents, agent_keys, observation_space, and action_space when the envs is not provided. Or the networks cannot be built."
                    "You can get them from test_envs.num_agents, test_envs.agents, test_envs.observation_space, and test_envs.action_space.")
            if self.use_global_state and state_space is None:
                raise ValueError("Please provide the state_space when the envs is not provided.")
            self.n_envs = self.config.parallels
            self.n_agents = self.config.n_agents = num_agents
            self.agent_keys = agent_keys
            self.state_space = state_space if self.use_global_state else None
            self.observation_space = observation_space
            self.action_space = action_space
            self.episode_length = None
        else:
            try:
                self.train_envs.reset()
            except Exception:
                # Deliberately best-effort: a failing reset must not prevent reading the
                # environment attributes below.
                pass
            self.n_agents = self.config.n_agents = self.train_envs.num_agents
            self.n_envs = self.train_envs.num_envs
            self.agent_keys = self.train_envs.agents
            self.state_space = self.train_envs.state_space if self.use_global_state else None
            self.observation_space = self.train_envs.observation_space
            self.action_space = self.train_envs.action_space
            self.episode_length = getattr(config, "episode_length", self.train_envs.max_episode_steps)
            self.config.episode_length = self.episode_length
        self.current_step = 0
        self.current_episode = np.zeros((self.n_envs,), np.int32)

        # Prepare directories. All ranks must use rank 0's time string so that they agree on
        # the model/log directory names.
        if self.distributed_training and self.world_size > 1:
            time_string = self._broadcast_time_string()
        else:
            time_string = get_time_string()
        seed = f"seed_{config.seed}_"
        self.model_dir_load = config.model_dir
        self.model_dir_save = os.path.join(os.getcwd(), config.model_dir, seed + time_string)

        # Create logger. Rank 0 creates the directory; other ranks wait until it exists.
        if config.logger == "tensorboard":
            log_dir = os.path.join(os.getcwd(), config.log_dir, seed + time_string)
            if self.rank == 0:
                create_directory(log_dir)
            else:
                self._wait_for_directory(log_dir)
            self.writer = SummaryWriter(log_dir)
            self.use_wandb = False
        elif config.logger == "wandb":
            config_dict = vars(config)
            log_dir = config.log_dir
            wandb_dir = Path(os.path.join(os.getcwd(), config.log_dir))
            if self.rank == 0:
                create_directory(str(wandb_dir))
            else:
                self._wait_for_directory(str(wandb_dir))
            wandb.init(config=config_dict,
                       project=config.project_name,
                       entity=config.wandb_user_name,
                       notes=socket.gethostname(),
                       dir=wandb_dir,
                       group=config.env_id,
                       job_type=config.agent,
                       name=time_string,
                       reinit=True,
                       settings=wandb.Settings(start_method="fork")
                       )
            self.use_wandb = True
        else:
            raise AttributeError("No logger is implemented.")
        self.log_dir = log_dir

        # Predefine necessary components; subclasses are expected to fill these in.
        # With parameter sharing a single model (keyed by the first agent) serves all agents.
        self.model_keys = [self.agent_keys[0]] if self.use_parameter_sharing else self.agent_keys
        self.policy: Optional[nn.Module] = None
        self.learner: Optional[learner] = None
        self.memory: Optional[object] = None
        self.callback = callback or MultiAgentBaseCallback()
        self.meta_data = dict(algo=self.config.agent, env=self.config.env_name, env_id=self.config.env_id,
                              dl_toolbox=self.config.dl_toolbox, device=self.device, seed=self.config.seed,
                              xuance_version=xuance.__version__)

    def _broadcast_time_string(self) -> str:
        """Return rank 0's ``get_time_string()`` on every rank of the process group.

        Rank 0 writes the UTF-8 bytes into a fixed-size, NUL-padded uint8 buffer so that all
        ranks broadcast tensors of the same size regardless of the time-string format.

        Raises:
            ValueError: If the encoded time string does not fit in the buffer.
        """
        buffer = torch.zeros(self._TIME_STRING_BUFFER_BYTES, dtype=torch.uint8)
        if self.rank == 0:
            encoded = get_time_string().encode('utf-8')
            if len(encoded) > self._TIME_STRING_BUFFER_BYTES:
                raise ValueError(f"Time string is longer than {self._TIME_STRING_BUFFER_BYTES} bytes.")
            buffer[:len(encoded)] = torch.tensor(list(encoded), dtype=torch.uint8)
        buffer = buffer.to(self.rank)
        dist.broadcast(buffer, src=0)
        return bytes(buffer.cpu().tolist()).decode('utf-8').rstrip('\x00')

    @staticmethod
    def _wait_for_directory(path: str, poll_interval: float = 0.1) -> None:
        """Block until ``path`` exists, polling every ``poll_interval`` seconds instead of spinning."""
        while not os.path.exists(path):
            time.sleep(poll_interval)

    @abstractmethod
    def store_experience(self, *args, **kwargs):
        """Store transitions into the agent's memory (implemented by subclasses)."""
        raise NotImplementedError

    def save_model(self, model_name, model_path=None):
        """Save the learner's networks as ``model_path/model_name``.

        Only rank 0 saves when training is distributed. ``model_path`` defaults to
        ``self.model_dir_save`` and is created if missing.
        """
        if self.distributed_training and self.rank > 0:
            return
        model_path = self.model_dir_save if model_path is None else model_path
        # exist_ok avoids a race between the existence check and the creation.
        os.makedirs(model_path, exist_ok=True)
        self.learner.save_model(os.path.join(model_path, model_name))

    def load_model(self, path, model=None):
        """Load neural networks by delegating to ``self.learner.load_model(path, model)``."""
        self.learner.load_model(path, model)

    def log_infos(self, info: dict, x_index: int):
        """Log scalar information to the active logger.

        Args:
            info (dict): Information to be visualized, mapping tag -> value. ``None`` values are
                skipped. With TensorBoard, dict values are written with ``add_scalars`` (several
                curves under one main tag) and all other values with ``add_scalar``.
            x_index (int): The step (x-axis value) at which the values are logged.
        """
        items = {k: v for k, v in info.items() if v is not None}
        if self.use_wandb:
            if items:
                wandb.log(items, step=x_index)
        else:
            for k, v in items.items():
                if isinstance(v, dict):
                    self.writer.add_scalars(k, v, x_index)
                else:
                    self.writer.add_scalar(k, v, x_index)

    def log_videos(self, info: dict, fps: int, x_index: int = 0):
        """Log videos to the active logger.

        Args:
            info (dict): Mapping tag -> video data; ``None`` values are skipped.
            fps (int): Frames per second of the videos.
            x_index (int): The step at which the videos are logged.
        """
        videos = {k: v for k, v in info.items() if v is not None}
        if self.use_wandb:
            if videos:
                wandb.log({k: wandb.Video(v, fps=fps, format='gif') for k, v in videos.items()}, step=x_index)
        else:
            for k, v in videos.items():
                self.writer.add_video(k, v, fps=fps, global_step=x_index)

    def _build_representation(self, representation_key: str,
                              input_space: Union[Dict[str, Space], Dict[str, tuple]],
                              config: Namespace) -> Module:
        """
        Build representation for policies.

        Parameters:
            representation_key (str): The selection of representation, e.g., "Basic_MLP", "Basic_RNN", etc.
            input_space: Per-model-key input spaces (or shapes); indexed by each key in ``self.model_keys``.
            config: The configurations for creating the representation module.

        Returns:
            representation (Module): A ModuleDict with one representation per key in ``self.model_keys``.

        Raises:
            AttributeError: If ``representation_key`` is not registered.
        """
        # Validate before indexing the registry, so an unknown key gets this message rather than a KeyError.
        if representation_key not in REGISTRY_Representation:
            raise AttributeError(f"{representation_key} is not registered in REGISTRY_Representation.")
        representation_cls = REGISTRY_Representation[representation_key]
        if self.use_rnn:
            hidden_sizes = {'fc_hidden_sizes': config.fc_hidden_sizes,
                            'recurrent_hidden_size': config.recurrent_hidden_size}
        else:
            hidden_sizes = getattr(config, "representation_hidden_size", None)
        # Arguments shared by every model key; only the input shape differs per key.
        shared_kwargs = dict(
            hidden_sizes=hidden_sizes,
            normalize=NormalizeFunctions[config.normalize] if hasattr(config, "normalize") else None,
            initialize=nn.init.orthogonal_,
            activation=ActivationFunctions[config.activation],
            kernels=getattr(config, "kernels", None),
            strides=getattr(config, "strides", None),
            filters=getattr(config, "filters", None),
            fc_hidden_sizes=getattr(config, "fc_hidden_sizes", None),
            N_recurrent_layers=getattr(config, "N_recurrent_layers", None),
            rnn=getattr(config, "rnn", None),
            dropout=getattr(config, "dropout", None),
            device=self.device)
        representation = ModuleDict()
        for key in self.model_keys:
            representation[key] = representation_cls(input_shape=space2shape(input_space[key]), **shared_kwargs)
        return representation

    @abstractmethod
    def _build_policy(self) -> Module:
        """Build the policy module (implemented by subclasses)."""
        raise NotImplementedError

    def _build_learner(self, *args):
        """Instantiate the learner registered under ``self.config.learner`` with ``args``."""
        return REGISTRY_Learners[self.config.learner](*args)

    def _build_inputs(self,
                      obs_dict: List[dict],
                      avail_actions_dict: Optional[List[dict]] = None):
        """
        Build inputs for representations before calculating actions.

        Parameters:
            obs_dict (List[dict]): One dict per environment, mapping each key in self.agent_keys to its observation.
            avail_actions_dict (Optional[List[dict]]): Actions mask values in the same layout; required when
                ``self.use_actions_mask`` is enabled, default is None.

        Returns:
            obs_input: Dict of observation arrays keyed by model key. With parameter sharing the agents are
                folded into the batch dimension (batch * n_agents rows). With RNNs a sequence axis of length 1
                follows the batch axis.
            agents_id: The agent id (one-hot) tensor with parameter sharing, otherwise None.
            avail_actions_input: Dict of action-mask arrays (same layout) or None when masks are disabled.

        Raises:
            ValueError: If action masks are enabled but ``avail_actions_dict`` is None.
        """
        if self.use_actions_mask and avail_actions_dict is None:
            raise ValueError("avail_actions_dict must be provided when use_actions_mask is enabled.")
        batch_size = len(obs_dict)
        bs = batch_size * self.n_agents if self.use_parameter_sharing else batch_size
        # RNN inputs carry an extra sequence axis of length 1 after the batch axis.
        seq_dims = (1,) if self.use_rnn else ()
        obs_input = {}
        avail_actions_input = {} if self.use_actions_mask else None
        if self.use_parameter_sharing:
            key = self.agent_keys[0]
            # Always (batch, n_agents, ...) -- also with a single agent, where itemgetter would
            # return the bare item and drop the agent axis.
            obs_array = np.array([[data[k] for k in self.agent_keys] for data in obs_dict])
            if self.use_cnn and len(obs_array.shape) > 3:  # batch * n_agent * height * width * channels (images)
                obs_shape_item = obs_array.shape[2:]
            else:
                obs_shape_item = (-1,)
            agents_id = torch.eye(self.n_agents).unsqueeze(0).expand(batch_size, -1, -1).to(self.device)
            agents_id = agents_id.reshape(bs, *seq_dims, -1)
            obs_input[key] = obs_array.reshape([bs, *seq_dims, *obs_shape_item])
            if self.use_actions_mask:
                avail_actions_array = np.array([[data[k] for k in self.agent_keys] for data in avail_actions_dict])
                avail_actions_input[key] = avail_actions_array.reshape([bs, *seq_dims, -1])
        else:
            agents_id = None
            for key in self.agent_keys:
                obs_array = np.stack([data[key] for data in obs_dict])
                if self.use_cnn and len(obs_array.shape) > 3:  # batch * height * width * channels (images)
                    obs_shape_item = obs_array.shape[1:]
                else:
                    obs_shape_item = (-1,)
                obs_input[key] = obs_array.reshape([bs, *seq_dims, *obs_shape_item])
                if self.use_actions_mask:
                    avail_actions_input[key] = np.stack(
                        [data[key] for data in avail_actions_dict]).reshape([bs, *seq_dims, -1])
        return obs_input, agents_id, avail_actions_input

    @abstractmethod
    def get_actions(self, **kwargs):
        """Select actions for all agents (implemented by subclasses)."""
        raise NotImplementedError

    @abstractmethod
    def train_epochs(self, *args, **kwargs):
        """Run training epochs on stored experience (implemented by subclasses)."""
        raise NotImplementedError

    @abstractmethod
    def train(self, **kwargs):
        """Main training loop (implemented by subclasses)."""
        raise NotImplementedError

    @abstractmethod
    def test(self, **kwargs):
        """Evaluation loop (implemented by subclasses)."""
        raise NotImplementedError

    def finish(self):
        """Close the logger; in distributed training also clean up the learner's snapshot and the process group."""
        if self.use_wandb:
            wandb.finish()
        else:
            self.writer.close()
        if self.distributed_training:
            if dist.get_rank() == 0:
                snapshot_path = self.learner.snapshot_path
                if os.path.exists(snapshot_path):
                    snapshot_file = os.path.join(snapshot_path, "snapshot.pt")
                    if os.path.exists(snapshot_file):
                        os.remove(snapshot_file)
                    # NOTE: os.removedirs also removes parent directories that become empty.
                    os.removedirs(snapshot_path)
            # Every rank tears down its own process group.
            destroy_process_group()
[docs]
class RandomAgents(object):
    """Baseline agents whose actions are drawn from each agent's action space via ``sample()``.

    Args:
        args: Namespace providing ``n_agents``, ``agent_keys`` and ``action_space``
            (indexable by agent key; each entry must expose ``sample()``).
        envs: Vectorized environments; only ``num_envs`` is read.
        device: Unused; accepted for interface compatibility.
    """

    def __init__(self, args, envs, device=None):
        self.args = args
        self.n_agents = args.n_agents
        self.agent_keys = args.agent_keys
        self.action_space = args.action_space
        self.nenvs = envs.num_envs

    def get_actions(self, *args, **kwargs):
        """Return sampled actions as an array with one row per environment and one column per agent."""
        per_env_actions = []
        for _ in range(self.nenvs):
            row = [self.action_space[agent_key].sample() for agent_key in self.agent_keys]
            per_env_actions.append(row)
        return np.array(per_env_actions)

    def load_model(self, model_dir):
        """No-op: random agents have nothing to load."""
        return None