# Source code for xuance.mindspore.agents.core.off_policy

import numpy as np
from tqdm import tqdm
from copy import deepcopy
from argparse import Namespace
from gymnasium.spaces import Space
from xuance.common import Optional, DummyOffPolicyBuffer, DummyOffPolicyBuffer_Atari, BaseCallback
from xuance.environment import DummyVecEnv, SubprocVecEnv
from xuance.mindspore import Module, Tensor
from xuance.mindspore.agents.base import Agent


class OffPolicyAgent(Agent):
    """Base class for single-agent off-policy reinforcement learning algorithms.

    This class implements the common logic shared by off-policy algorithms
    (e.g., DQN, DDPG, TD3, SAC) in XuanCe. It extends the generic `Agent`
    abstraction with off-policy-specific components such as replay buffers,
    exploration strategies, and update schedules.

    The agent can be used in both training and evaluation-only scenarios.
    When initialized without training environments (`envs=None`), the agent
    relies on explicitly provided observation and action spaces to construct
    policy networks, which is useful for inference or standalone evaluation.

    Args:
        config (Namespace): Configuration object containing hyperparameters,
            algorithm settings, and runtime options.
        envs (Optional[DummyVecEnv | SubprocVecEnv]): Vectorized environments
            used for training. If None, the agent will not initialize training
            environments and must be provided with `observation_space` and
            `action_space`.
        observation_space (Optional[gymnasium.spaces.Space]): Observation
            space specification used to build policy and value networks when
            `envs` is None. Typically obtained from
            `test_envs.observation_space`.
        action_space (Optional[gymnasium.spaces.Space]): Action space
            specification used to build policy and value networks when `envs`
            is None. Typically obtained from `test_envs.action_space`.
        callback (Optional[BaseCallback]): Optional callback object for
            injecting custom logic during training or evaluation, such as
            logging, early stopping, model checkpointing, or visualization.

    Notes:
        - Off-policy agents maintain a replay buffer to reuse past experience.
        - Training and evaluation environments are conceptually separated;
          evaluation environments may be created and managed externally.
        - In evaluation mode, exploration noise is disabled and policy
          actions are executed deterministically by default.
    """

    def __init__(
            self,
            config: Namespace,
            envs: Optional[DummyVecEnv | SubprocVecEnv] = None,
            observation_space: Optional[Space] = None,
            action_space: Optional[Space] = None,
            callback: Optional[BaseCallback] = None
    ):
        super(OffPolicyAgent, self).__init__(config, envs, observation_space, action_space, callback)

        # Epsilon-greedy schedule. Each bound is looked up under its own name;
        # previously `end_greedy` was guarded by the presence of `start_greedy`.
        self.start_greedy = getattr(config, "start_greedy", None)
        self.end_greedy = getattr(config, "end_greedy", None)
        # `delta_egreedy` / `e_greedy` are left unset here; this class only reads them
        # (presumably assigned by concrete agents - confirm against subclasses).
        self.delta_egreedy: Optional[float] = None
        self.e_greedy: Optional[float] = None

        # Action-noise schedule, same conventions as the epsilon-greedy one.
        self.start_noise = getattr(config, "start_noise", None)
        self.end_noise = getattr(config, "end_noise", None)
        self.delta_noise: Optional[float] = None
        self.noise_scale: Optional[float] = None

        # Bounds used to clip noisy actions; None when the action space has no `low`/`high`.
        self.actions_low = getattr(self.action_space, "low", None)
        self.actions_high = getattr(self.action_space, "high", None)

        # `train()` and `test()` read `self.atari`. Set it here as well as in `_build_memory`
        # so an agent that never builds a replay buffer can still run `test()`.
        self.atari = getattr(config, "env_name", None) == "Atari"

        self.auxiliary_info_shape = None
        self.memory: Optional[DummyOffPolicyBuffer] = None

    def _build_memory(self, auxiliary_info_shape=None) -> DummyOffPolicyBuffer:
        """Build and initialize the replay buffer for off-policy training.

        This method creates a replay buffer instance based on the environment
        type and agent configuration. For Atari environments
        (`config.env_name == "Atari"`), `DummyOffPolicyBuffer_Atari` is used;
        otherwise `DummyOffPolicyBuffer` is constructed.

        Args:
            auxiliary_info_shape (Optional[tuple]): Shape of auxiliary
                information, forwarded to the buffer as `auxiliary_shape`.

        Returns:
            DummyOffPolicyBuffer: A replay buffer configured with the current
            observation space, action space, number of parallel environments,
            buffer size, and batch size.

        Notes:
            - As a side effect, `self.atari` is (re)computed from the config.
        """
        self.atari = self.config.env_name == "Atari"
        buffer_cls = DummyOffPolicyBuffer_Atari if self.atari else DummyOffPolicyBuffer
        return buffer_cls(
            observation_space=self.observation_space,
            action_space=self.action_space,
            auxiliary_shape=auxiliary_info_shape,
            n_envs=self.n_envs,
            buffer_size=self.config.buffer_size,
            batch_size=self.config.batch_size,
        )

    def _build_policy(self) -> Module:
        """Build the policy network; must be implemented by concrete agents."""
        raise NotImplementedError

    def _update_explore_factor(self):
        """Linearly anneal the active exploration factor using `self.current_step`.

        Epsilon-greedy takes precedence when `self.e_greedy` is set; otherwise
        the noise scale is annealed when `self.noise_scale` is set. If neither
        is set this is a no-op. The new value is clamped at the configured end
        value so the last update cannot undershoot it.
        """
        if self.e_greedy is not None:
            if self.e_greedy > self.end_greedy:
                decayed = self.start_greedy - self.current_step * self.delta_egreedy
                self.e_greedy = max(self.end_greedy, decayed)
        elif self.noise_scale is not None:
            if self.noise_scale >= self.end_noise:
                decayed = self.start_noise - self.current_step * self.delta_noise
                self.noise_scale = max(self.end_noise, decayed)

    def exploration(self, pi_actions):
        """Returns the actions for exploration.

        Parameters:
            pi_actions: The original output actions. Objects exposing
                `asnumpy()` (as used elsewhere in this class on policy
                outputs) are converted first; anything else is used as-is.

        Returns:
            explore_actions: With epsilon-greedy enabled, either random
                discrete actions for all environments (with probability
                `self.e_greedy`) or the policy actions. With action noise
                enabled, the policy actions plus Gaussian noise scaled by
                `self.noise_scale`, clipped to the action bounds. Otherwise
                the policy actions unchanged.
        """
        # Convert once up front so every branch works on, and returns, the same array type.
        if hasattr(pi_actions, "asnumpy"):
            pi_actions = pi_actions.asnumpy()

        if self.e_greedy is not None:
            # Drawn unconditionally to keep the random-number consumption order unchanged.
            random_actions = np.random.choice(self.action_space.n, self.n_envs)
            if np.random.rand() < self.e_greedy:
                return random_actions
            return pi_actions

        if self.noise_scale is not None:
            noisy_actions = pi_actions + np.random.normal(size=pi_actions.shape) * self.noise_scale
            return np.clip(noisy_actions, self.actions_low, self.actions_high)

        return pi_actions

    def get_actions(self,
                    observations: np.ndarray,
                    test_mode: Optional[bool] = False) -> dict:
        """Returns the actions for the given observations.

        Parameters:
            observations (np.ndarray): The observation.
            test_mode (Optional[bool]): True for testing without noises.

        Returns:
            dict: A dictionary with a single key, ``"actions"``, holding the
            actions to be executed. In test mode these are the policy outputs
            converted with `asnumpy()`; otherwise they come from
            `self.exploration`.
        """
        # The policy returns a 3-tuple; only the second element (the actions) is used here.
        _, actions_output, _ = self.policy(Tensor(observations))
        if test_mode:
            actions = actions_output.asnumpy()
        else:
            actions = self.exploration(actions_output)
        return {"actions": actions}

    def train_epochs(self, n_epochs=1) -> dict:
        """Run `n_epochs` learner updates on batches sampled from the replay buffer.

        Parameters:
            n_epochs (int): Number of sample-and-update iterations.

        Returns:
            dict: The info returned by the last `self.learner.update` call
            (empty if `n_epochs` is 0), extended with the current
            ``"epsilon-greedy"`` and ``"noise_scale"`` values.
        """
        train_info = {}
        for _ in range(n_epochs):
            samples = self.memory.sample()
            train_info = self.learner.update(**samples)
        train_info["epsilon-greedy"] = self.e_greedy
        train_info["noise_scale"] = self.noise_scale
        return train_info

    def train(self, train_steps: int) -> dict:
        """Run the main off-policy training loop.

        This method interacts with the training environments, collects
        transitions, stores them in the replay buffer, and periodically
        updates the policy using sampled mini-batches. Training proceeds in a
        step-based manner (not episode-based): at each iteration, the agent
        selects actions, steps the environments, logs intermediate information
        via callbacks, and performs policy updates when the configured
        conditions are met.

        Args:
            train_steps (int): Number of loop iterations to run. Each
                iteration steps all vectorized environments once and advances
                `self.current_step` by `self.n_envs`.

        Returns:
            dict: A dictionary containing the training information and
            episode metrics accumulated during the loop.

        Notes:
            - This method assumes that training environments
              (`self.train_envs`) and the replay buffer (`self.memory`) have
              already been initialized.
            - Exploration behavior (epsilon-greedy or action noise) is applied
              during training and updated after every iteration.
            - Policy updates are triggered when
              `self.current_step > self.start_training` and
              `self.current_step % self.training_frequency == 0`.
            - Episode termination and reset logic are handled per environment
              instance, and episode-level statistics are logged via callbacks.
        """
        train_info = {}
        obs = self.train_envs.buf_obs
        for _ in tqdm(range(train_steps)):
            self.obs_rms.update(obs)
            obs = self._process_observation(obs)
            policy_out = self.get_actions(obs, test_mode=False)
            acts = policy_out['actions']
            next_obs, rewards, terminals, truncations, infos = self.train_envs.step(acts)

            self.callback.on_train_step(self.current_step, envs=self.train_envs, policy=self.policy,
                                        obs=obs, policy_out=policy_out, acts=acts, next_obs=next_obs,
                                        rewards=rewards, terminals=terminals, truncations=truncations,
                                        infos=infos, train_steps=train_steps)

            self.memory.store(obs, acts, self._process_reward(rewards), terminals,
                              self._process_observation(next_obs))

            if self.current_step > self.start_training and self.current_step % self.training_frequency == 0:
                update_info = self.train_epochs(n_epochs=self.n_epochs)
                self.log_infos(update_info, self.current_step)
                train_info.update(update_info)
                self.callback.on_train_epochs_end(self.current_step, policy=self.policy, memory=self.memory,
                                                  current_episode=self.current_episode,
                                                  train_steps=train_steps, update_info=update_info)

            self.returns = self.gamma * self.returns + rewards
            obs = deepcopy(next_obs)
            for i in range(self.n_envs):
                if not (terminals[i] or truncations[i]):
                    continue
                # For Atari, a terminal without truncation is skipped (no reset bookkeeping).
                # Logical `not` is required: `~` on a plain Python bool is always truthy.
                if self.atari and not truncations[i]:
                    continue

                obs[i] = infos[i]["reset_obs"]
                self.train_envs.buf_obs[i] = obs[i]
                self.ret_rms.update(self.returns[i:i + 1])
                self.returns[i] = 0.0
                self.current_episode[i] += 1
                if self.use_wandb:
                    episode_info = {
                        f"Episode-Steps/env-{i}": infos[i]["episode_step"],
                        f"Train-Episode-Rewards/env-{i}": infos[i]["episode_score"]
                    }
                else:
                    episode_info = {
                        "Episode-Steps": {f"env-{i}": infos[i]["episode_step"]},
                        "Train-Episode-Rewards": {f"env-{i}": infos[i]["episode_score"]}
                    }
                self.log_infos(episode_info, self.current_step)
                train_info.update(episode_info)
                self.callback.on_train_episode_info(envs=self.train_envs, policy=self.policy, env_id=i,
                                                    infos=infos, use_wandb=self.use_wandb,
                                                    current_step=self.current_step,
                                                    current_episode=self.current_episode,
                                                    train_steps=train_steps)

            self.current_step += self.n_envs
            self._update_explore_factor()
            self.callback.on_train_step_end(self.current_step, envs=self.train_envs, policy=self.policy,
                                            train_steps=train_steps, train_info=train_info)
        return train_info

    def test(self,
             test_episodes: int,
             test_envs: Optional[DummyVecEnv | SubprocVecEnv] = None,
             close_envs: bool = True) -> list:
        """Evaluate the current policy in a vectorized environment.

        This method runs evaluation episodes using `test_envs` and returns the
        per-episode scores. During evaluation, actions are selected in test
        mode and optional RGB-array frames can be recorded for video logging
        when rendering is enabled.

        Args:
            test_episodes (int): Total number of evaluation episodes to run
                across all vectorized environments.
            test_envs (Optional[DummyVecEnv | SubprocVecEnv]): Vectorized
                environments used for evaluation. Must not be None.
            close_envs (bool): Whether to close `test_envs` before returning.
                Set this to False if `test_envs` is managed externally and
                will be reused after evaluation.

        Returns:
            list: A list of episode scores collected during evaluation. It may
            hold more than `test_episodes` entries when several environments
            finish in the same step.

        Raises:
            ValueError: If `test_envs` is None.

        Notes:
            - This method resets the evaluation environments at the beginning
              of testing and steps them until `test_episodes` episodes are
              completed.
            - When `render_mode == "rgb_array"` and `self.render` is True, the
              method records frames per environment and logs the frames of
              the best-scoring episode as a video. Frames are cleared when an
              episode ends so each recording covers a single episode.
            - This implementation updates `obs_rms` during testing. If you
              want to avoid contaminating training statistics, consider
              guarding this update with a dedicated flag.
        """
        if test_envs is None:
            raise ValueError("`test_envs` must be provided for evaluation.")

        num_envs = test_envs.num_envs
        record_video = self.config.render_mode == "rgb_array" and self.render
        videos, episode_videos, images = [[] for _ in range(num_envs)], [], None
        current_episode, current_step, scores, best_score = 0, 0, [], -np.inf

        obs, infos = test_envs.reset()
        if record_video:
            images = test_envs.render(self.config.render_mode)
            for idx, img in enumerate(images):
                videos[idx].append(img)

        while current_episode < test_episodes:
            self.obs_rms.update(obs)
            obs = self._process_observation(obs)
            policy_out = self.get_actions(obs, test_mode=True)
            next_obs, rewards, terminals, truncations, infos = test_envs.step(policy_out['actions'])
            if record_video:
                images = test_envs.render(self.config.render_mode)
                for idx, img in enumerate(images):
                    videos[idx].append(img)

            self.callback.on_test_step(envs=test_envs, policy=self.policy, images=images,
                                       obs=obs, policy_out=policy_out, next_obs=next_obs, rewards=rewards,
                                       terminals=terminals, truncations=truncations, infos=infos,
                                       current_train_step=self.current_step,
                                       current_step=current_step, current_episode=current_episode)

            obs = deepcopy(next_obs)
            for i in range(num_envs):
                if not (terminals[i] or truncations[i]):
                    continue
                # Same Atari rule as in `train`; `not` instead of `~`, which is
                # always truthy on a plain Python bool.
                if self.atari and not truncations[i]:
                    continue

                obs[i] = infos[i]["reset_obs"]
                episode_score = infos[i]["episode_score"]
                scores.append(episode_score)
                current_episode += 1
                if best_score < episode_score:
                    best_score = episode_score
                    episode_videos = videos[i].copy()
                # Start a fresh recording so the next episode's video does not
                # include this episode's frames and memory stays bounded.
                videos[i] = []
            current_step += num_envs

        # Skip video logging when no frames were captured (e.g. no episode ran);
        # the 5-axis transpose below would fail on an empty array.
        if record_video and episode_videos:
            # time, height, width, channel -> time, channel, height, width
            videos_info = {"Videos_Test": np.array([episode_videos], dtype=np.uint8).transpose((0, 1, 4, 2, 3))}
            self.log_videos(info=videos_info, fps=self.fps, x_index=self.current_step)

        test_info = {
            "Test-Episode-Rewards/Mean-Score": np.mean(scores),
            "Test-Episode-Rewards/Std-Score": np.std(scores)
        }
        self.log_infos(test_info, self.current_step)

        self.callback.on_test_end(envs=test_envs, policy=self.policy,
                                  current_train_step=self.current_step,
                                  current_step=current_step, current_episode=current_episode,
                                  scores=scores, best_score=best_score)
        if close_envs:
            test_envs.close()
        return scores