import os
import torch
import wandb
import socket
import xuance
import numpy as np
import torch.distributed as dist
from abc import ABC, abstractmethod
from pathlib import Path
from argparse import Namespace
from gymnasium.spaces import Dict, Space
from torch.utils.tensorboard import SummaryWriter
from torch.distributed import destroy_process_group
from xuance.common import get_time_string, create_directory, RunningMeanStd, EPS, Optional, BaseCallback
from xuance.environment import DummyVecEnv, SubprocVecEnv, space2shape
from xuance.torch import REGISTRY_Representation, REGISTRY_Learners, Module
from xuance.torch.utils import (nn, NormalizeFunctions, ActivationFunctions, init_distributed_mode, set_seed,
set_device,
TensorEnvWrapper, TensorRunningMeanStd)
class Agent(ABC):
    """Base class for single-agent Deep Reinforcement Learning (DRL).

    This class defines the common interface and shared infrastructure for
    single-agent DRL algorithms in XuanCe. An Agent encapsulates the policy,
    learner, and training/testing logic, while environments are managed
    externally by the runner or provided explicitly by the user.

    The agent can be initialized either with training environments (`envs`)
    or, for inference/testing-only scenarios, without environments but with
    explicit observation and action spaces.

    Args:
        config (Namespace): Configuration object containing hyperparameters,
            runtime settings, and environment specifications.
        envs (Optional[DummyVecEnv | SubprocVecEnv]): Vectorized environments
            used for training. If None, the agent will not initialize training
            environments and must be provided with `observation_space` and
            `action_space` to build networks.
        observation_space (Optional[gymnasium.spaces.Space]): Observation space
            specification used to construct policy networks when `envs` is None.
            Typically obtained from `test_envs.observation_space`.
        action_space (Optional[gymnasium.spaces.Space]): Action space
            specification used to construct policy networks when `envs` is None.
            Typically obtained from `test_envs.action_space`.
        callback (Optional[BaseCallback]): Optional callback object for injecting
            custom logic during training or evaluation (e.g., logging, early
            stopping, or custom hooks).

    Raises:
        ValueError: If `envs` is None and `observation_space` or `action_space`
            is missing.
        AttributeError: If `config.logger` is neither "tensorboard" nor "wandb".

    Notes:
        - When `envs` is provided, the agent assumes a training context and
          derives observation/action spaces from the environments.
        - When `envs` is None, the agent can still be used for evaluation or
          inference as long as the corresponding spaces are explicitly given.
        - Environment creation and lifecycle management are intentionally
          decoupled from the agent and handled by the runner or user code.
    """

    def __init__(
            self,
            config: Namespace,
            envs: Optional[DummyVecEnv | SubprocVecEnv] = None,
            observation_space: Optional[Space] = None,
            action_space: Optional[Space] = None,
            callback: Optional[BaseCallback] = None
    ):
        set_seed(config.seed)

        # Training settings.
        self.config = config
        self.use_rnn = getattr(config, "use_rnn", False)
        self.use_actions_mask = getattr(config, "use_actions_mask", False)
        self.is_tensor_memory = getattr(self.config, "use_tensor_memory", False)
        self.distributed_training = getattr(config, "distributed_training", False)
        if self.distributed_training:
            # WORLD_SIZE / RANK are read from the environment of the launched process.
            self.world_size = int(os.environ['WORLD_SIZE'])
            self.rank = int(os.environ['RANK'])
            master_port = getattr(config, "master_port", None)
            init_distributed_mode(master_port=master_port)
        else:
            self.world_size = 1
            self.rank = 0

        self.gamma = config.gamma
        self.start_training = getattr(config, "start_training", 1)
        self.training_frequency = getattr(config, "training_frequency", 1)
        self.n_epochs = getattr(config, "n_epochs", 1)
        self.device = self.config.device = set_device(self.config.device)

        # Environment attributes.
        # Only wrap real environments. Wrapping ``None`` would presumably produce
        # a non-None wrapper, which skips the "no training envs" branch below and
        # then fails on ``reset()``.
        if self.is_tensor_memory and envs is not None:
            self.train_envs = TensorEnvWrapper(envs, self.device)
        else:
            self.train_envs = envs
        self.render = config.render
        self.fps = config.fps
        if self.train_envs is None:
            if observation_space is None or action_space is None:
                raise ValueError(
                    "Please provide the observation_space and action_space when the envs is not provided. "
                    "Or the networks cannot be built. "
                    "You can get them from test_envs.observation_space and test_envs.action_space.")
            self.n_envs = self.config.parallels
            self.observation_space = observation_space
            self.action_space = action_space
            self.episode_length = self.config.episode_length = None
        else:
            self.train_envs.reset()
            self.n_envs = self.train_envs.num_envs
            self.episode_length = self.config.episode_length = self.train_envs.max_episode_steps
            self.observation_space = self.train_envs.observation_space
            self.action_space = self.train_envs.action_space
        self.current_step = 0
        self.current_episode = np.zeros((self.n_envs,), np.int32)

        # Set normalizations for observations and rewards.
        if self.is_tensor_memory:
            self.obs_rms = TensorRunningMeanStd(shape=space2shape(self.observation_space),
                                                device=self.device, distributed=self.distributed_training)
            self.ret_rms = TensorRunningMeanStd(shape=(), device=self.device, distributed=self.distributed_training)
            self.returns = torch.zeros(size=(self.n_envs,), dtype=torch.float32, device=self.device)
        else:
            self.obs_rms = RunningMeanStd(shape=space2shape(self.observation_space))
            self.ret_rms = RunningMeanStd(shape=())
            self.returns = np.zeros((self.n_envs,), np.float32)
        self.use_obsnorm = config.use_obsnorm
        self.use_rewnorm = config.use_rewnorm
        self.obsnorm_range = config.obsnorm_range
        self.rewnorm_range = config.rewnorm_range

        # Prepare directories.
        if self.distributed_training and self.world_size > 1:
            # Rank 0 generates the time stamp and broadcasts it, so every process
            # derives the same model/log directory names.
            if self.rank == 0:
                time_string = get_time_string()
                time_string_tensor = torch.tensor(list(time_string.encode('utf-8')), dtype=torch.uint8).to(self.rank)
            else:
                # NOTE(review): 16 bytes presumably matches the fixed length of
                # get_time_string() - confirm against its implementation.
                time_string_tensor = torch.zeros(16, dtype=torch.uint8).to(self.rank)
            dist.broadcast(time_string_tensor, src=0)
            time_string = bytes(time_string_tensor.cpu().tolist()).decode('utf-8').rstrip('\x00')
        else:
            time_string = get_time_string()
        seed = f"seed_{self.config.seed}_"
        self.model_dir_load = config.model_dir
        self.model_dir_save = os.path.join(os.getcwd(), config.model_dir, seed + time_string)

        # Create logger.
        if config.logger == "tensorboard":
            log_dir = os.path.join(os.getcwd(), config.log_dir, seed + time_string)
            if self.rank == 0:
                create_directory(log_dir)
            else:
                # Wait until the master process finishes creating directory.
                self._wait_for_directory(log_dir)
            self.writer = SummaryWriter(log_dir)
            self.use_wandb = False
        elif config.logger == "wandb":
            config_dict = vars(config)
            log_dir = config.log_dir
            wandb_dir = Path(os.path.join(os.getcwd(), config.log_dir))
            if self.rank == 0:
                create_directory(str(wandb_dir))
            else:
                # Wait until the master process finishes creating directory.
                self._wait_for_directory(str(wandb_dir))
            wandb.init(
                config=config_dict,
                project=config.project_name,
                entity=config.wandb_user_name,
                notes=socket.gethostname(),
                dir=wandb_dir,
                group=config.env_id,
                job_type=config.agent,
                name=time_string,
                reinit=True,
                settings=wandb.Settings(start_method="fork")
            )
            # os.environ["WANDB_SILENT"] = "True"
            self.use_wandb = True
        else:
            raise AttributeError("No logger is implemented.")
        self.log_dir = log_dir

        # Prepare necessary components (filled in by concrete agents).
        self.policy: Optional[Module] = None
        self.learner: Optional[Module] = None
        self.memory: Optional[object] = None
        self.callback = callback or BaseCallback()
        self.meta_data = dict(algo=self.config.agent, env=self.config.env_name, env_id=self.config.env_id,
                              dl_toolbox=self.config.dl_toolbox, device=self.device, seed=self.config.seed,
                              xuance_version=xuance.__version__)

    @staticmethod
    def _wait_for_directory(path: str, poll_interval: float = 0.01) -> None:
        """Block until ``path`` exists, sleeping between filesystem checks.

        Args:
            path: Directory another process is expected to create.
            poll_interval: Seconds to sleep between existence checks, so the
                waiting process does not spin at full CPU.
        """
        import time  # Local import: the module header does not import ``time``.

        while not os.path.exists(path):
            time.sleep(poll_interval)

    def save_model(self, model_name, model_path=None):
        """Save the learner's networks and, if enabled, the observation statistics.

        In distributed training only rank 0 writes; other ranks return immediately.

        Args:
            model_name: File name passed to the learner, joined onto the target directory.
            model_path: Target directory. Defaults to ``self.model_dir_save``.
        """
        if self.distributed_training and self.rank > 0:
            return

        # Save the neural networks.
        model_path = self.model_dir_save if model_path is None else model_path
        # exist_ok avoids the check-then-create race of a separate exists() test.
        os.makedirs(model_path, exist_ok=True)
        self.learner.save_model(os.path.join(model_path, model_name))

        # Save the observation status.
        if self.use_obsnorm:
            obs_norm_path = os.path.join(model_path, "obs_rms.npy")
            observation_stat = {'count': self.obs_rms.count,
                                'mean': self.obs_rms.mean,
                                'var': self.obs_rms.var}
            # A dict is stored as a pickled object array; load_model reads it
            # back with allow_pickle=True.
            np.save(obs_norm_path, observation_stat)

    def load_model(self, path, model=None):
        """Load the learner's networks and, if enabled, the observation statistics.

        Args:
            path: Location forwarded to ``self.learner.load_model``.
            model: Optional model selector forwarded to ``self.learner.load_model``.

        Raises:
            RuntimeError: If observation normalization is enabled but no
                ``obs_rms.npy`` exists in the directory returned by the learner.
        """
        # Load neural networks.
        path_loaded = self.learner.load_model(path, model)

        # Recover observation status.
        if self.use_obsnorm:
            obs_norm_path = os.path.join(path_loaded, "obs_rms.npy")
            if not os.path.exists(obs_norm_path):
                raise RuntimeError(f"Failed to load observation status file 'obs_rms.npy' from {obs_norm_path}!")
            # NOTE(review): allow_pickle=True executes pickled content; only load
            # checkpoints from trusted sources.
            observation_stat = np.load(obs_norm_path, allow_pickle=True).item()
            self.obs_rms.count = observation_stat['count']
            self.obs_rms.mean = observation_stat['mean']
            self.obs_rms.var = observation_stat['var']

    def log_infos(self, info: dict, x_index: int):
        """Log scalar information to the active logger.

        Args:
            info (dict): Information to be visualized. ``None`` values are skipped.
                With the TensorBoard writer, dict values are logged as a group of
                scalars and every other value as a single scalar.
            x_index (int): Step used as the x-axis index of the logged values.
        """
        if self.use_wandb:
            for key, value in info.items():
                if value is None:
                    continue
                wandb.log({key: value}, step=x_index)
            return

        for key, value in info.items():
            if value is None:
                continue
            # Dispatch explicitly instead of relying on a bare ``except`` around
            # add_scalar, which also swallowed KeyboardInterrupt and real errors.
            if isinstance(value, dict):
                self.writer.add_scalars(key, value, x_index)
            else:
                self.writer.add_scalar(key, value, x_index)

    def log_videos(self, info: dict, fps: int, x_index: int = 0):
        """Log videos to the active logger.

        Args:
            info (dict): Mapping from tag to video data. ``None`` values are skipped.
            fps (int): Frames per second of the logged videos.
            x_index (int): Step used as the x-axis index of the logged videos.
        """
        if self.use_wandb:
            for key, value in info.items():
                if value is None:
                    continue
                wandb.log({key: wandb.Video(value, fps=fps, format='gif')}, step=x_index)
        else:
            for key, value in info.items():
                if value is None:
                    continue
                self.writer.add_video(key, value, fps=fps, global_step=x_index)

    def _process_observation(self, observations):
        """Normalize and clip observations when observation normalization is enabled.

        For ``Dict`` observation spaces each entry is normalized with its own
        statistics and written back into the given container.

        Args:
            observations: Raw observations.

        Returns:
            The normalized observations, or the input unchanged when
            ``self.use_obsnorm`` is False.
        """
        if not self.use_obsnorm:
            return observations

        # Tensor memory uses torch; otherwise the data are numpy arrays.
        clip = torch.clip if self.is_tensor_memory else np.clip
        bound = self.obsnorm_range
        if isinstance(self.observation_space, Dict):
            for key in self.observation_space.spaces.keys():
                observations[key] = clip(
                    (observations[key] - self.obs_rms.mean[key]) / (self.obs_rms.std[key] + EPS),
                    -bound, bound)
        else:
            observations = clip((observations - self.obs_rms.mean) / (self.obs_rms.std + EPS),
                                -bound, bound)
        return observations

    def _process_reward(self, rewards):
        """Scale rewards by the running return std and clip them.

        Args:
            rewards: Raw rewards.

        Returns:
            The scaled rewards, or the input unchanged when ``self.use_rewnorm``
            is False.
        """
        if not self.use_rewnorm:
            return rewards

        clip = torch.clip if self.is_tensor_memory else np.clip
        # Bound the divisor so tiny or huge return statistics cannot blow up
        # or flatten the rewards.
        std = clip(self.ret_rms.std, 0.1, 100)
        return clip(rewards / std, -self.rewnorm_range, self.rewnorm_range)

    def _to_tensor(self, x):
        """Convert ``x`` to a tensor on ``self.device``; ``None`` is passed through."""
        return None if x is None else torch.as_tensor(x, device=self.device)

    def _build_representation(self, representation_key: str,
                              input_space: Optional[Space],
                              config: Namespace) -> Module:
        """Build representation for policies.

        Parameters:
            representation_key (str): The selection of representation, e.g., "Basic_MLP", "Basic_RNN", etc.
            input_space (Optional[Space]): The space of input tensors.
            config: The configurations for creating the representation module.

        Returns:
            representation (Module): The representation Module.

        Raises:
            AttributeError: If ``representation_key`` is not registered.
        """
        # Validate before the lookup; otherwise an unknown key surfaces as a bare
        # KeyError and this error can never be raised.
        if representation_key not in REGISTRY_Representation:
            raise AttributeError(f"{representation_key} is not registered in REGISTRY_Representation.")

        input_representations = dict(
            input_shape=space2shape(input_space),
            hidden_sizes=getattr(config, "representation_hidden_size", None),
            normalize=NormalizeFunctions[config.normalize] if hasattr(config, "normalize") else None,
            initialize=nn.init.orthogonal_,
            activation=ActivationFunctions[config.activation],
            kernels=getattr(config, "kernels", None),
            strides=getattr(config, "strides", None),
            filters=getattr(config, "filters", None),
            fc_hidden_sizes=getattr(config, "fc_hidden_sizes", None),
            image_patch_size=getattr(config, "image_patch_size", None),
            frame_patch_size=getattr(config, "frame_patch_size", None),
            final_dim=getattr(config, "final_dim", None),
            embedding_dim=getattr(config, "embedding_dim", None),
            depth=getattr(config, "depth", None),
            heads=getattr(config, "heads", None),
            FFN_dim=getattr(config, "FFN_dim", None),
            device=self.device)
        return REGISTRY_Representation[representation_key](**input_representations)

    @abstractmethod
    def _build_policy(self) -> Module:
        """Build and return the policy module. Implemented by concrete agents."""
        raise NotImplementedError

    def _build_learner(self, *args):
        """Instantiate the learner registered under ``self.config.learner`` with ``args``."""
        return REGISTRY_Learners[self.config.learner](*args)

    @abstractmethod
    def get_actions(self, observations):
        """Return actions for the given observations. Implemented by concrete agents."""
        raise NotImplementedError

    @abstractmethod
    def train(self, train_steps: int) -> dict:
        """Train for ``train_steps`` steps. Implemented by concrete agents."""
        raise NotImplementedError

    @abstractmethod
    def test(self,
             test_episodes: int,
             test_envs: Optional[DummyVecEnv | SubprocVecEnv] = None,
             close_envs: bool = True):
        """Evaluate for ``test_episodes`` episodes. Implemented by concrete agents."""
        raise NotImplementedError

    def finish(self):
        """Close the logger and, in distributed training, clean up and leave the process group."""
        if self.use_wandb:
            wandb.finish()
        else:
            self.writer.close()

        if self.distributed_training:
            try:
                if dist.get_rank() == 0:
                    snapshot_dir = self.learner.snapshot_path
                    if os.path.exists(snapshot_dir):
                        snapshot_file = os.path.join(snapshot_dir, "snapshot.pt")
                        if os.path.exists(snapshot_file):
                            os.remove(snapshot_file)
                        os.removedirs(snapshot_dir)
            finally:
                # Always leave the process group, even if snapshot cleanup fails.
                destroy_process_group()