# Source code for xuance.torch.agents.core.on_policy

import numpy as np
from tqdm import tqdm
from copy import deepcopy
from argparse import Namespace
from gymnasium.spaces import Space
from xuance.common import Optional, DummyOnPolicyBuffer, DummyOnPolicyBuffer_Atari, BaseCallback
from xuance.environment import DummyVecEnv, SubprocVecEnv
from xuance.torch import Module
from xuance.torch.utils import split_distributions
from xuance.torch.agents.base import Agent


class OnPolicyAgent(Agent):
    """Base class for single-agent on-policy reinforcement learning algorithms.

    This class implements the common logic shared by on-policy algorithms
    (e.g., A2C, PPO, TRPO) in XuanCe. It extends the generic `Agent`
    abstraction with on-policy-specific components such as trajectory buffers,
    rollout collection, and multi-epoch policy/value updates.

    The agent can be used in both training and evaluation-only scenarios.
    When initialized without training environments (`envs=None`), the agent
    relies on explicitly provided observation and action spaces to construct
    policy networks, which is useful for inference or standalone evaluation.

    Args:
        config (Namespace): Configuration object containing hyperparameters,
            algorithm settings, and runtime options.
        envs (Optional[DummyVecEnv | SubprocVecEnv]): Vectorized environments
            used for training. If None, the agent will not initialize training
            environments and must be provided with `observation_space` and
            `action_space`.
        observation_space (Optional[gymnasium.spaces.Space]): Observation
            space specification used to build policy and value networks when
            `envs` is None. Typically obtained from
            `test_envs.observation_space`.
        action_space (Optional[gymnasium.spaces.Space]): Action space
            specification used to build policy and value networks when `envs`
            is None. Typically obtained from `test_envs.action_space`.
        callback (Optional[BaseCallback]): Optional callback object for
            injecting custom logic during training or evaluation, such as
            logging, early stopping, model checkpointing, or visualization.

    Notes:
        - On-policy agents collect fresh trajectories from the current policy
          and update the policy using rollouts stored in a trajectory buffer.
        - Training and evaluation environments are conceptually separated;
          evaluation environments may be created and managed externally.
        - In evaluation mode, actions are sampled without exploration
          schedules specific to training (e.g., no epsilon-greedy / action
          noise).
    """

    def __init__(
            self,
            config: Namespace,
            envs: Optional[DummyVecEnv | SubprocVecEnv] = None,
            observation_space: Optional[Space] = None,
            action_space: Optional[Space] = None,
            callback: Optional[BaseCallback] = None
    ):
        super().__init__(config, envs, observation_space, action_space, callback)
        # Rollout / optimisation hyperparameters, all read from the config.
        self.horizon_size = config.horizon_size
        self.n_epochs = config.n_epochs
        self.n_minibatch = config.n_minibatch
        self.gae_lam = config.gae_lambda
        self.auxiliary_info_shape = None
        # The buffer is not created here; see `_build_memory`.
        self.memory: Optional[DummyOnPolicyBuffer] = None

    def _build_memory(self, auxiliary_info_shape=None) -> DummyOnPolicyBuffer:
        """Build and initialize the on-policy trajectory buffer.

        This method creates a trajectory buffer instance used to store
        rollouts collected from the current policy. For Atari environments
        (`config.env_name == "Atari"`), `DummyOnPolicyBuffer_Atari` is used;
        otherwise, `DummyOnPolicyBuffer` is constructed.

        As a side effect this sets `self.atari`, `self.buffer_size`
        (`n_envs * horizon_size`) and `self.batch_size`
        (`buffer_size // n_minibatch`).

        Args:
            auxiliary_info_shape (Optional[tuple]): Shape of auxiliary
                information to be stored alongside transitions in the buffer
                (e.g., additional state features or metadata). If None, no
                auxiliary information is stored.

        Returns:
            DummyOnPolicyBuffer: An initialized trajectory buffer instance
            configured with the current observation space, action space,
            number of parallel environments, horizon size, and GAE/advantage
            settings.

        Notes:
            - The buffer type is selected automatically based on whether the
              environment is an Atari environment.
            - The buffer stores rollouts of length `horizon_size` for each
              parallel environment and is cleared after each update cycle.
            - When `use_gae` is enabled, the buffer computes advantages using
              `gamma` and `gae_lam`; when `use_advnorm` is enabled, advantages
              are normalized before updates.
        """
        self.atari = self.config.env_name == "Atari"
        Buffer = DummyOnPolicyBuffer_Atari if self.atari else DummyOnPolicyBuffer
        self.buffer_size = self.n_envs * self.horizon_size
        self.batch_size = self.buffer_size // self.n_minibatch
        input_buffer = dict(
            observation_space=self.observation_space,
            action_space=self.action_space,
            auxiliary_shape=auxiliary_info_shape,
            n_envs=self.n_envs,
            horizon_size=self.horizon_size,
            use_gae=self.config.use_gae,
            use_advnorm=self.config.use_advnorm,
            gamma=self.gamma,
            gae_lam=self.gae_lam,
        )
        return Buffer(**input_buffer)

    def _build_policy(self) -> Module:
        """Build the policy network. Must be implemented by subclasses."""
        raise NotImplementedError

    def get_terminated_values(self, observations_next: np.ndarray, rewards: np.ndarray = None) -> np.ndarray:
        """Compute value estimates for terminal/terminated states.

        This method evaluates the value function on the given observations and
        returns the value estimates used for bootstrapping (e.g., when
        finishing a trajectory segment).

        Args:
            observations_next (np.ndarray): Observations at the terminal step
                (or the next observations used for bootstrapping). They are
                passed through `_process_observation` before the policy is
                queried.
            rewards (Optional[np.ndarray]): Rewards corresponding to the
                terminal transitions. This argument is reserved for
                algorithm-specific implementations and is unused here.

        Returns:
            np.ndarray: Value estimates for the provided observations, i.e.
            the `'values'` entry returned by `get_actions`.
        """
        policy_out = self.get_actions(self._process_observation(observations_next))
        return policy_out['values']

    def get_actions(self,
                    observations: np.ndarray,
                    return_dists: bool = False,
                    return_logpi: bool = False) -> dict:
        """Compute actions and value estimates for a batch of observations.

        This method performs a forward pass through the current policy to
        obtain action distributions and value predictions. Actions are sampled
        stochastically from the policy distribution.

        Args:
            observations (np.ndarray): Batch of observations. The array is
                expected to have shape compatible with the underlying policy.
            return_dists (bool): Whether to return the action distributions
                (split via `split_distributions`).
            return_logpi (bool): Whether to return the log-probabilities of
                the sampled actions.

        Returns:
            dict: A dictionary containing:
                - actions (np.ndarray): Sampled actions to execute in the
                  environment(s).
                - values (np.ndarray): Value estimates for the input
                  observations. If the policy does not produce values, this
                  is set to 0.
                - dists (Optional[Any]): Action distributions (when
                  `return_dists=True`); otherwise None.
                - log_pi (Optional[np.ndarray]): Log-probabilities of sampled
                  actions (when `return_logpi=True`); otherwise None.
        """
        _, policy_dists, values = self.policy(observations)
        actions = policy_dists.stochastic_sample()
        # log-probs must be computed while `actions` is still a tensor.
        log_pi = policy_dists.log_prob(actions).detach().cpu().numpy() if return_logpi else None
        dists = split_distributions(policy_dists) if return_dists else None
        actions = actions.detach().cpu().numpy()
        values = 0 if values is None else values.detach().cpu().numpy()
        return {"actions": actions, "values": values, "dists": dists, "log_pi": log_pi}

    def get_aux_info(self, policy_output: dict = None) -> dict:
        """Returns auxiliary information.

        The base implementation stores nothing; subclasses may override this
        hook to derive extra per-transition data from the policy output.

        Args:
            policy_output (dict): The output information of the policy.

        Returns:
            aux_info (dict): The auxiliary information (empty here).
        """
        return {}

    def train_epochs(self, n_epochs: int = 1) -> dict:
        """Update the policy for multiple epochs using samples from the rollout buffer.

        This method performs multiple passes over the collected rollout data
        in `self.memory`. For each epoch, it shuffles transition indices and
        iterates over mini-batches of `self.batch_size` indices to compute
        gradient updates via the learner.

        Args:
            n_epochs (int): Number of optimization epochs to perform over the
                current rollout buffer.

        Returns:
            dict: A dictionary of training metrics returned by the learner
            from the last mini-batch update (e.g., policy loss, value loss,
            entropy, KL divergence). Empty if no update was performed.
        """
        indexes = np.arange(self.buffer_size)
        train_info = {}
        for _ in range(n_epochs):
            np.random.shuffle(indexes)
            for start in range(0, self.buffer_size, self.batch_size):
                sample_idx = indexes[start:start + self.batch_size]
                samples = self.memory.sample(sample_idx)
                # Only the metrics of the most recent mini-batch are kept.
                train_info = self.learner.update(**samples)
        return train_info

    def train(self, train_steps: int) -> dict:
        """Run the main on-policy training loop.

        This method interacts with the training environments to collect
        rollouts from the current policy, stores transitions in the on-policy
        trajectory buffer, and triggers policy/value updates when the buffer
        is full. The loop advances in vectorized steps (one iteration
        corresponds to `self.n_envs` environment steps).

        Args:
            train_steps (int): Number of rollout collection iterations to run.
                Each iteration steps all vectorized environments once, so the
                total number of environment steps is approximately
                `train_steps * self.n_envs`.

        Returns:
            dict: A dictionary containing aggregated training information and
            logged metrics collected during training.

        Notes:
            - This method assumes that training environments
              (`self.train_envs`) and the trajectory buffer (`self.memory`)
              have already been initialized.
            - After collecting `horizon_size` steps per environment, the
              buffer becomes full and the agent computes bootstrapped terminal
              values, finalizes trajectory segments via `finish_path`, and
              performs `n_epochs` optimization passes over mini-batches using
              `train_epochs`.
            - The policy output of each step is handed to `get_aux_info`, so
              subclasses overriding that hook receive the data they are
              documented to receive.
            - Episode termination and reset logic are handled per environment,
              and episode-level statistics are reported via callbacks.
        """
        train_info = {}
        obs = self.train_envs.buf_obs
        for _ in tqdm(range(train_steps)):
            self.obs_rms.update(obs)
            obs = self._process_observation(obs)
            policy_out = self.get_actions(obs, return_dists=False, return_logpi=False)
            acts, vals = policy_out['actions'], policy_out['values']
            next_obs, rewards, terminals, truncations, infos = self.train_envs.step(acts)
            # Pass the policy output to the hook; the base implementation ignores it.
            aux_info = self.get_aux_info(policy_out)
            self.callback.on_train_step(self.current_step, envs=self.train_envs, policy=self.policy,
                                        obs=obs, policy_out=policy_out, acts=acts, vals=vals,
                                        next_obs=next_obs, rewards=rewards, terminals=terminals,
                                        truncations=truncations, infos=infos, aux_info=aux_info,
                                        train_steps=train_steps)
            self.memory.store(obs, acts, self._process_reward(rewards), vals, terminals, aux_info)

            if self.memory.full:
                # Close every env's open segment: bootstrap with 0 when the env
                # terminated, otherwise with the value of the next observation.
                vals = self.get_terminated_values(next_obs, rewards)
                for i in range(self.n_envs):
                    if terminals[i]:
                        self.memory.finish_path(0.0, i)
                    else:
                        self.memory.finish_path(vals[i], i)
                update_info = self.train_epochs(self.n_epochs)
                self.log_infos(update_info, self.current_step)
                train_info.update(update_info)
                self.callback.on_train_epochs_end(self.current_step, policy=self.policy,
                                                  memory=self.memory,
                                                  current_episode=self.current_episode,
                                                  train_steps=train_steps,
                                                  update_info=update_info)
                self.memory.clear()

            self.returns = self.gamma * self.returns + rewards
            obs = deepcopy(next_obs)
            for i in range(self.n_envs):
                if not (terminals[i] or truncations[i]):
                    continue
                self.ret_rms.update(self.returns[i:i + 1])
                self.returns[i] = 0.0
                # `not` instead of `~`: bitwise invert is always truthy for a
                # Python bool / int flag, which would skip every episode end.
                if self.atari and not truncations[i]:
                    # NOTE(review): presumably a life-loss signal rather than a
                    # real episode end - confirm against the Atari wrapper.
                    continue
                if terminals[i]:
                    self.memory.finish_path(0, i)
                else:
                    # NOTE: evaluated once per finished env, as in the original;
                    # kept here so the sampling RNG stream is unchanged.
                    vals = self.get_terminated_values(next_obs, rewards)
                    self.memory.finish_path(vals[i], i)
                # Continue from the observation of the freshly reset env.
                obs[i] = infos[i]["reset_obs"]
                self.train_envs.buf_obs[i] = obs[i]
                self.current_episode[i] += 1
                if self.use_wandb:
                    episode_info = {
                        f"Episode-Steps/rank_{self.rank}/env-{i}": infos[i]["episode_step"],
                        f"Train-Episode-Rewards/rank_{self.rank}/env-{i}": infos[i]["episode_score"]
                    }
                else:
                    episode_info = {
                        f"Episode-Steps/rank_{self.rank}": {f"env-{i}": infos[i]["episode_step"]},
                        f"Train-Episode-Rewards/rank_{self.rank}": {f"env-{i}": infos[i]["episode_score"]}
                    }
                self.log_infos(episode_info, self.current_step)
                train_info.update(episode_info)
                self.callback.on_train_episode_info(envs=self.train_envs, policy=self.policy,
                                                    env_id=i, infos=infos, rank=self.rank,
                                                    use_wandb=self.use_wandb,
                                                    current_step=self.current_step,
                                                    current_episode=self.current_episode,
                                                    train_steps=train_steps)

            self.current_step += self.n_envs
            self.callback.on_train_step_end(self.current_step, envs=self.train_envs,
                                            policy=self.policy, train_steps=train_steps,
                                            train_info=train_info)
        return train_info

    def test(self,
             test_episodes: int,
             test_envs: Optional[DummyVecEnv | SubprocVecEnv] = None,
             close_envs: bool = True) -> list:
        """Evaluate the current policy in a vectorized environment.

        This method runs evaluation episodes using `test_envs` and returns the
        per-episode scores. Actions are produced by the current policy
        (sampled from the policy distribution via `get_actions`), and optional
        RGB-array frames can be recorded for video logging when rendering is
        enabled.

        Args:
            test_episodes (int): Total number of evaluation episodes to run
                across all vectorized environments.
            test_envs (Optional[DummyVecEnv | SubprocVecEnv]): Vectorized
                environments used for evaluation. Must not be None.
            close_envs (bool): Whether to close `test_envs` before returning.
                Set this to False if `test_envs` is managed externally and
                will be reused after evaluation.

        Returns:
            list: A list of episode scores collected during evaluation. It may
            hold more than `test_episodes` entries when several environments
            finish in the same step.

        Raises:
            ValueError: If `test_envs` is None.

        Notes:
            - This method resets the evaluation environments at the beginning
              of testing and steps them until `test_episodes` episodes are
              completed.
            - When `render_mode == "rgb_array"` and `self.render` is True, the
              method records frames per environment, restarts the frame list
              at each episode boundary, and logs the best-scoring episode as a
              video. Nothing is logged if no frames were captured.
            - This implementation updates `obs_rms` during testing. If you
              want to avoid contaminating training statistics, consider
              guarding this update with a dedicated flag (e.g.,
              `update_rms=False`).
        """
        if test_envs is None:
            raise ValueError("`test_envs` must be provided for evaluation.")
        num_envs = test_envs.num_envs
        record_video = self.config.render_mode == "rgb_array" and self.render
        videos, episode_videos, images = [[] for _ in range(num_envs)], [], None
        current_episode, current_step, scores, best_score = 0, 0, [], -np.inf
        obs, infos = test_envs.reset()
        if record_video:
            images = test_envs.render(self.config.render_mode)
            for idx, img in enumerate(images):
                videos[idx].append(img)

        while current_episode < test_episodes:
            self.obs_rms.update(obs)
            obs = self._process_observation(obs)
            policy_out = self.get_actions(obs)
            next_obs, rewards, terminals, truncations, infos = test_envs.step(policy_out['actions'])
            if record_video:
                images = test_envs.render(self.config.render_mode)
                for idx, img in enumerate(images):
                    videos[idx].append(img)

            self.callback.on_test_step(envs=test_envs, policy=self.policy, images=images,
                                       obs=obs, policy_out=policy_out, next_obs=next_obs,
                                       rewards=rewards, terminals=terminals,
                                       truncations=truncations, infos=infos,
                                       current_train_step=self.current_step,
                                       current_step=current_step,
                                       current_episode=current_episode)

            obs = deepcopy(next_obs)
            for i in range(num_envs):
                if not (terminals[i] or truncations[i]):
                    continue
                # `not` instead of `~`: bitwise invert is always truthy for a
                # Python bool / int flag.
                if self.atari and not truncations[i]:
                    continue
                obs[i] = infos[i]["reset_obs"]
                episode_score = infos[i]["episode_score"]
                scores.append(episode_score)
                current_episode += 1
                if best_score < episode_score:
                    best_score = episode_score
                    episode_videos = videos[i]
                # Start a new frame list so the next episode's video does not
                # include frames of earlier episodes from this env.
                videos[i] = []
            current_step += num_envs

        # Skip logging when no frames were captured; an empty array cannot be
        # transposed with the 5-axis permutation below.
        if record_video and episode_videos:
            # time, height, width, channel -> time, channel, height, width
            videos_info = {"Videos_Test": np.array([episode_videos], dtype=np.uint8).transpose((0, 1, 4, 2, 3))}
            self.log_videos(info=videos_info, fps=self.fps, x_index=self.current_step)

        test_info = {
            "Test-Episode-Rewards/Mean-Score": np.mean(scores),
            "Test-Episode-Rewards/Std-Score": np.std(scores)
        }
        self.log_infos(test_info, self.current_step)
        self.callback.on_test_end(envs=test_envs, policy=self.policy,
                                  current_train_step=self.current_step,
                                  current_step=current_step,
                                  current_episode=current_episode,
                                  scores=scores, best_score=best_score)
        if close_envs:
            test_envs.close()
        return scores