Source code for xuance.torch.agents.base.agent

import os
import torch
import wandb
import socket
import xuance
import numpy as np
import torch.distributed as dist
from abc import ABC, abstractmethod
from pathlib import Path
from argparse import Namespace
from gymnasium.spaces import Dict, Space
from torch.utils.tensorboard import SummaryWriter
from torch.distributed import destroy_process_group
from xuance.common import get_time_string, create_directory, RunningMeanStd, EPS, Optional, BaseCallback
from xuance.environment import DummyVecEnv, SubprocVecEnv, space2shape
from xuance.torch import REGISTRY_Representation, REGISTRY_Learners, Module
from xuance.torch.utils import (nn, NormalizeFunctions, ActivationFunctions, init_distributed_mode, set_seed,
                                set_device,
                                TensorEnvWrapper, TensorRunningMeanStd)


class Agent(ABC):
    """Base class for single-agent Deep Reinforcement Learning (DRL).

    This class defines the common interface and shared infrastructure for
    single-agent DRL algorithms in XuanCe. An Agent encapsulates the policy,
    learner, and training/testing logic, while environments are managed
    externally by the runner or provided explicitly by the user.

    The agent can be initialized either with training environments (`envs`)
    or, for inference/testing-only scenarios, without environments but with
    explicit observation and action spaces.

    Args:
        config (Namespace): Configuration object containing hyperparameters,
            runtime settings, and environment specifications.
        envs (Optional[DummyVecEnv | SubprocVecEnv]): Vectorized environments
            used for training. If None, the agent will not initialize training
            environments and must be provided with `observation_space` and
            `action_space` to build networks.
        observation_space (Optional[gymnasium.spaces.Space]): Observation
            space specification used to construct policy networks when `envs`
            is None. Typically obtained from `test_envs.observation_space`.
        action_space (Optional[gymnasium.spaces.Space]): Action space
            specification used to construct policy networks when `envs` is
            None. Typically obtained from `test_envs.action_space`.
        callback (Optional[BaseCallback]): Optional callback object for
            injecting custom logic during training or evaluation (e.g.,
            logging, early stopping, or custom hooks).

    Raises:
        ValueError: If `envs` is None and either `observation_space` or
            `action_space` is missing.
        AttributeError: If `config.logger` is neither "tensorboard" nor
            "wandb".

    Notes:
        - When `envs` is provided, the agent assumes a training context and
          derives observation/action spaces from the environments.
        - When `envs` is None, the agent can still be used for evaluation or
          inference as long as the corresponding spaces are explicitly given.
        - Environment creation and lifecycle management are intentionally
          decoupled from the agent and handled by the runner or user code.
    """

    # Project types in signatures are annotated as strings (forward
    # references) so they are not evaluated when the class body executes.
    def __init__(
            self,
            config: Namespace,
            envs: "Optional[DummyVecEnv | SubprocVecEnv]" = None,
            observation_space: "Optional[Space]" = None,
            action_space: "Optional[Space]" = None,
            callback: "Optional[BaseCallback]" = None
    ):
        set_seed(config.seed)

        # Training settings.
        self.config = config
        self.use_rnn = getattr(config, "use_rnn", False)
        self.use_actions_mask = getattr(config, "use_actions_mask", False)
        self.is_tensor_memory = getattr(self.config, "use_tensor_memory", False)
        self.distributed_training = getattr(config, "distributed_training", False)
        if self.distributed_training:
            # World size and rank are read from the launcher's environment.
            self.world_size = int(os.environ['WORLD_SIZE'])
            self.rank = int(os.environ['RANK'])
            master_port = getattr(config, "master_port", None)
            init_distributed_mode(master_port=master_port)
        else:
            self.world_size = 1
            self.rank = 0

        self.gamma = config.gamma
        self.start_training = getattr(config, "start_training", 1)
        self.training_frequency = getattr(config, "training_frequency", 1)
        self.n_epochs = getattr(config, "n_epochs", 1)
        # The resolved device is written back to the config as well.
        self.device = self.config.device = set_device(self.config.device)

        # Environment attributes.
        # Only wrap real environments: wrapping `None` would make the
        # `train_envs is None` test below fail and break the documented
        # environment-free (inference-only) construction path.
        if self.is_tensor_memory and envs is not None:
            self.train_envs = TensorEnvWrapper(envs, self.device)
        else:
            self.train_envs = envs
        self.render = config.render
        self.fps = config.fps
        if self.train_envs is None:
            if observation_space is None or action_space is None:
                raise ValueError(
                    "Please provide the observation_space and action_space when the envs is not provided. "
                    "Or the networks cannot be built. "
                    "You can get them from test_envs.observation_space and test_envs.action_space.")
            self.n_envs = self.config.parallels
            self.observation_space = observation_space
            self.action_space = action_space
            self.episode_length = self.config.episode_length = None
        else:
            self.train_envs.reset()
            self.n_envs = self.train_envs.num_envs
            self.episode_length = self.config.episode_length = self.train_envs.max_episode_steps
            self.observation_space = self.train_envs.observation_space
            self.action_space = self.train_envs.action_space
        self.current_step = 0
        self.current_episode = np.zeros((self.n_envs,), np.int32)

        # Set normalizations for observations and rewards.
        if self.is_tensor_memory:
            self.obs_rms = TensorRunningMeanStd(shape=space2shape(self.observation_space),
                                                device=self.device,
                                                distributed=self.distributed_training)
            self.ret_rms = TensorRunningMeanStd(shape=(),
                                                device=self.device,
                                                distributed=self.distributed_training)
            self.returns = torch.zeros(size=(self.n_envs,), dtype=torch.float32, device=self.device)
        else:
            self.obs_rms = RunningMeanStd(shape=space2shape(self.observation_space))
            self.ret_rms = RunningMeanStd(shape=())
            self.returns = np.zeros((self.n_envs,), np.float32)
        self.use_obsnorm = config.use_obsnorm
        self.use_rewnorm = config.use_rewnorm
        self.obsnorm_range = config.obsnorm_range
        self.rewnorm_range = config.rewnorm_range

        # Prepare directories.
        if self.distributed_training and self.world_size > 1:
            # Rank 0 generates the time string and broadcasts it so that every
            # process uses the same directory names.
            # NOTE(review): the receive buffer is fixed at 16 bytes, so this
            # assumes get_time_string() encodes to exactly 16 bytes - confirm.
            # NOTE(review): `.to(self.rank)` uses the rank as the device index
            # - confirm this is intended for multi-node runs.
            if self.rank == 0:
                time_string = get_time_string()
                time_string_tensor = torch.tensor(list(time_string.encode('utf-8')),
                                                  dtype=torch.uint8).to(self.rank)
            else:
                time_string_tensor = torch.zeros(16, dtype=torch.uint8).to(self.rank)
            dist.broadcast(time_string_tensor, src=0)
            time_string = bytes(time_string_tensor.cpu().tolist()).decode('utf-8').rstrip('\x00')
        else:
            time_string = get_time_string()
        seed = f"seed_{self.config.seed}_"
        self.model_dir_load = config.model_dir
        self.model_dir_save = os.path.join(os.getcwd(), config.model_dir, seed + time_string)

        # Create logger.
        if config.logger == "tensorboard":
            log_dir = os.path.join(os.getcwd(), config.log_dir, seed + time_string)
            if self.rank == 0:
                create_directory(log_dir)
            else:
                # Wait until the master process finishes creating directory.
                self._wait_for_directory(log_dir)
            self.writer = SummaryWriter(log_dir)
            self.use_wandb = False
        elif config.logger == "wandb":
            config_dict = vars(config)
            log_dir = config.log_dir
            wandb_dir = Path(os.path.join(os.getcwd(), config.log_dir))
            if self.rank == 0:
                create_directory(str(wandb_dir))
            else:
                # Wait until the master process finishes creating directory.
                self._wait_for_directory(str(wandb_dir))
            wandb.init(
                config=config_dict,
                project=config.project_name,
                entity=config.wandb_user_name,
                notes=socket.gethostname(),
                dir=wandb_dir,
                group=config.env_id,
                job_type=config.agent,
                name=time_string,
                reinit=True,
                settings=wandb.Settings(start_method="fork")
            )
            # os.environ["WANDB_SILENT"] = "True"
            self.use_wandb = True
        else:
            raise AttributeError("No logger is implemented.")
        self.log_dir = log_dir

        # Prepare necessary components.
        self.policy: Optional[Module] = None
        self.learner: Optional[Module] = None
        self.memory: Optional[object] = None
        self.callback = callback or BaseCallback()
        self.meta_data = dict(algo=self.config.agent,
                              env=self.config.env_name,
                              env_id=self.config.env_id,
                              dl_toolbox=self.config.dl_toolbox,
                              device=self.device,
                              seed=self.config.seed,
                              xuance_version=xuance.__version__)

    @staticmethod
    def _wait_for_directory(path: str, poll_interval: float = 0.05) -> None:
        """Block until `path` exists, sleeping between checks instead of spinning.

        Args:
            path (str): File-system path expected to be created by another process.
            poll_interval (float): Seconds to sleep between existence checks.
        """
        import time  # Local import: only this helper needs it.
        while not os.path.exists(path):
            time.sleep(poll_interval)

    def save_model(self, model_name, model_path=None):
        """Save the learner's model and, if enabled, the observation statistics.

        In distributed training only rank 0 writes; other ranks return
        immediately.

        Args:
            model_name: File name handed to `self.learner.save_model`, joined
                onto the target directory.
            model_path: Target directory. Defaults to `self.model_dir_save`
                when None.
        """
        if self.distributed_training and self.rank > 0:
            return

        # save the neural networks
        model_path = self.model_dir_save if model_path is None else model_path
        # exist_ok avoids the check-then-create race of a separate exists() test.
        os.makedirs(model_path, exist_ok=True)
        self.learner.save_model(os.path.join(model_path, model_name))

        # save the observation status
        if self.use_obsnorm:
            obs_norm_path = os.path.join(model_path, "obs_rms.npy")
            observation_stat = {'count': self.obs_rms.count,
                                'mean': self.obs_rms.mean,
                                'var': self.obs_rms.var}
            np.save(obs_norm_path, observation_stat)

    def load_model(self, path, model=None):
        """Load the learner's model and, if enabled, the observation statistics.

        Args:
            path: Location passed to `self.learner.load_model`.
            model: Optional model selector passed to `self.learner.load_model`.

        Raises:
            RuntimeError: If observation normalization is enabled but
                'obs_rms.npy' is not found under the path returned by the
                learner.
        """
        # load neural networks
        path_loaded = self.learner.load_model(path, model)

        # recover observation status
        if self.use_obsnorm:
            obs_norm_path = os.path.join(path_loaded, "obs_rms.npy")
            if os.path.exists(obs_norm_path):
                # SECURITY NOTE: allow_pickle=True unpickles the file content;
                # only load checkpoints from trusted sources.
                observation_stat = np.load(obs_norm_path, allow_pickle=True).item()
                self.obs_rms.count = observation_stat['count']
                self.obs_rms.mean = observation_stat['mean']
                self.obs_rms.var = observation_stat['var']
            else:
                raise RuntimeError(f"Failed to load observation status file 'obs_rms.npy' from {obs_norm_path}!")

    def log_infos(self, info: dict, x_index: int):
        """Log the entries of `info` to the active logger.

        Entries whose value is None are skipped. With TensorBoard, dict values
        are written with `add_scalars` and everything else with `add_scalar`.

        Args:
            info (dict): Information to be visualized, keyed by tag.
            x_index (int): Step index used as the x-axis value.
        """
        if self.use_wandb:
            for k, v in info.items():
                if v is None:
                    continue
                wandb.log({k: v}, step=x_index)
            return

        for k, v in info.items():
            if v is None:
                continue
            if isinstance(v, dict):
                # A group of scalars under one main tag.
                self.writer.add_scalars(k, v, x_index)
                continue
            try:
                self.writer.add_scalar(k, v, x_index)
            except Exception:
                # Keep the original fallback for values add_scalar rejects,
                # without swallowing KeyboardInterrupt/SystemExit.
                self.writer.add_scalars(k, v, x_index)

    def log_videos(self, info: dict, fps: int, x_index: int = 0):
        """Log the entries of `info` as videos to the active logger.

        Entries whose value is None are skipped.

        Args:
            info (dict): Video data keyed by tag.
            fps (int): Frames per second passed to the logger.
            x_index (int): Step index attached to the logged videos.
        """
        if self.use_wandb:
            for k, v in info.items():
                if v is None:
                    continue
                wandb.log({k: wandb.Video(v, fps=fps, format='gif')}, step=x_index)
        else:
            for k, v in info.items():
                if v is None:
                    continue
                self.writer.add_video(k, v, fps=fps, global_step=x_index)

    def _process_observation(self, observations):
        """Normalize and clip observations when observation normalization is on.

        For a `Dict` observation space each key is normalized separately and
        written back into `observations` in place. Without normalization the
        input is returned unchanged.
        """
        if not self.use_obsnorm:
            return observations

        # Same arithmetic for both back ends; only the clip function differs.
        clip = torch.clip if self.is_tensor_memory else np.clip
        bound = self.obsnorm_range
        if isinstance(self.observation_space, Dict):
            for key in self.observation_space.spaces.keys():
                observations[key] = clip(
                    (observations[key] - self.obs_rms.mean[key]) / (self.obs_rms.std[key] + EPS),
                    -bound, bound)
            return observations
        return clip((observations - self.obs_rms.mean) / (self.obs_rms.std + EPS), -bound, bound)

    def _process_reward(self, rewards):
        """Scale rewards by the clipped return std and clip to `rewnorm_range`.

        Without reward normalization the input is returned unchanged.
        """
        if not self.use_rewnorm:
            return rewards
        clip = torch.clip if self.is_tensor_memory else np.clip
        # The std is bounded to [0.1, 100] before dividing.
        std = clip(self.ret_rms.std, 0.1, 100)
        return clip(rewards / std, -self.rewnorm_range, self.rewnorm_range)

    def _to_tensor(self, x):
        """Convert `x` to a tensor on `self.device`; None is passed through."""
        return None if x is None else torch.as_tensor(x, device=self.device)

    def _build_representation(self, representation_key: str,
                              input_space: "Optional[Space]",
                              config: Namespace) -> "Module":
        """
        Build representation for policies.

        Parameters:
            representation_key (str): The selection of representation, e.g., "Basic_MLP", "Basic_RNN", etc.
            input_space (Optional[Space]): The space of input tensors.
            config: The configurations for creating the representation module.

        Returns:
            representation (Module): The representation Module.

        Raises:
            AttributeError: If `representation_key` is not registered in
                REGISTRY_Representation.
        """
        # Validate the key before the lookup so an unknown key produces the
        # descriptive error instead of a bare KeyError.
        if representation_key not in REGISTRY_Representation:
            raise AttributeError(f"{representation_key} is not registered in REGISTRY_Representation.")

        input_representations = dict(
            input_shape=space2shape(input_space),
            hidden_sizes=getattr(config, "representation_hidden_size", None),
            normalize=NormalizeFunctions[config.normalize] if hasattr(config, "normalize") else None,
            initialize=nn.init.orthogonal_,
            activation=ActivationFunctions[config.activation],
            kernels=getattr(config, "kernels", None),
            strides=getattr(config, "strides", None),
            filters=getattr(config, "filters", None),
            fc_hidden_sizes=getattr(config, "fc_hidden_sizes", None),
            image_patch_size=getattr(config, "image_patch_size", None),
            frame_patch_size=getattr(config, "frame_patch_size", None),
            final_dim=getattr(config, "final_dim", None),
            embedding_dim=getattr(config, "embedding_dim", None),
            depth=getattr(config, "depth", None),
            heads=getattr(config, "heads", None),
            FFN_dim=getattr(config, "FFN_dim", None),
            device=self.device)
        return REGISTRY_Representation[representation_key](**input_representations)

    @abstractmethod
    def _build_policy(self) -> "Module":
        """Build the policy module; must be implemented by subclasses."""
        raise NotImplementedError

    def _build_learner(self, *args):
        """Instantiate the learner registered under `self.config.learner`."""
        return REGISTRY_Learners[self.config.learner](*args)

    @abstractmethod
    def get_actions(self, observations):
        """Select actions for `observations`; must be implemented by subclasses."""
        raise NotImplementedError

    @abstractmethod
    def train(self, train_steps: int) -> dict:
        """Train for `train_steps`; must be implemented by subclasses."""
        raise NotImplementedError

    @abstractmethod
    def test(self,
             test_episodes: int,
             test_envs: "Optional[DummyVecEnv | SubprocVecEnv]" = None,
             close_envs: bool = True):
        """Evaluate for `test_episodes`; must be implemented by subclasses."""
        raise NotImplementedError

    def finish(self):
        """Close the logger and, in distributed training, clean up.

        On rank 0 the learner's "snapshot.pt" file and its directory are
        removed if present; every rank then destroys the process group.
        """
        if self.use_wandb:
            wandb.finish()
        else:
            self.writer.close()

        if self.distributed_training:
            if dist.get_rank() == 0:
                snapshot_dir = self.learner.snapshot_path
                if os.path.exists(snapshot_dir):
                    snapshot_file = os.path.join(snapshot_dir, "snapshot.pt")
                    if os.path.exists(snapshot_file):
                        os.remove(snapshot_file)
                    os.removedirs(snapshot_dir)
            destroy_process_group()