import numpy as np
from tqdm import tqdm
from copy import deepcopy
from argparse import Namespace
from gymnasium.spaces import Space
from xuance.common import Optional, DummyOffPolicyBuffer, DummyOffPolicyBuffer_Atari, BaseCallback
from xuance.environment import DummyVecEnv, SubprocVecEnv
from xuance.mindspore import Module, Tensor
from xuance.mindspore.agents.base import Agent
# Base class for single-agent off-policy agents (MindSpore backend).
class OffPolicyAgent(Agent):
    """Base class for single-agent off-policy reinforcement learning algorithms.

    This class implements the common logic shared by off-policy algorithms (e.g., DQN, DDPG, TD3, SAC) in XuanCe.
    It extends the generic `Agent` abstraction with off-policy–specific components such as replay buffers,
    exploration strategies, and update schedules.

    The agent can be used in both training and evaluation-only scenarios. When initialized without training
    environments (`envs=None`), the agent relies on explicitly provided observation and action spaces to construct
    policy networks, which is useful for inference or standalone evaluation.

    Args:
        config (Namespace): Configuration object containing hyperparameters, algorithm settings, and runtime options.
        envs (Optional[DummyVecEnv | SubprocVecEnv]): Vectorized environments used for training.
            If None, the agent will not initialize training environments and must be provided with
            `observation_space` and `action_space`.
        observation_space (Optional[gymnasium.spaces.Space]): Observation space specification used to build policy
            and value networks when `envs` is None. Typically obtained from `test_envs.observation_space`.
        action_space (Optional[gymnasium.spaces.Space]): Action space specification used to build policy and
            value networks when `envs` is None. Typically obtained from `test_envs.action_space`.
        callback (Optional[BaseCallback]): Optional callback object for injecting custom logic during training
            or evaluation, such as logging, early stopping, model checkpointing, or visualization.

    Notes:
        - Off-policy agents maintain a replay buffer to reuse past experience.
        - Training and evaluation environments are conceptually separated;
          evaluation environments may be created and managed externally.
        - In evaluation mode, exploration noise is disabled and policy actions
          are executed deterministically by default.
    """

    def __init__(
            self,
            config: Namespace,
            envs: Optional[DummyVecEnv | SubprocVecEnv] = None,
            observation_space: Optional[Space] = None,
            action_space: Optional[Space] = None,
            callback: Optional[BaseCallback] = None
    ):
        super(OffPolicyAgent, self).__init__(config, envs, observation_space, action_space, callback)

        # Epsilon-greedy schedule. Each endpoint is read independently so that a config
        # defining only one of them yields None for the other instead of raising.
        self.start_greedy = getattr(config, "start_greedy", None)
        self.end_greedy = getattr(config, "end_greedy", None)
        # Per-step decrement and current value; left as None here
        # (presumably filled in by concrete agents - verify against subclasses).
        self.delta_egreedy: Optional[float] = None
        self.e_greedy: Optional[float] = None

        # Action-noise schedule, mirroring the epsilon-greedy fields above.
        self.start_noise = getattr(config, "start_noise", None)
        self.end_noise = getattr(config, "end_noise", None)
        self.delta_noise: Optional[float] = None
        self.noise_scale: Optional[float] = None

        # Bounds used to clip noisy actions; None when the action space exposes no `low`/`high`.
        self.actions_low = getattr(self.action_space, "low", None)
        self.actions_high = getattr(self.action_space, "high", None)

        # `train()` and `test()` both read `self.atari`; define it up front so that an agent
        # which never calls `_build_memory()` does not hit an AttributeError.
        self.atari = getattr(config, "env_name", None) == "Atari"

        self.auxiliary_info_shape = None
        self.memory: Optional[DummyOffPolicyBuffer] = None

    @staticmethod
    def _actions_to_numpy(actions):
        """Return `actions` as a NumPy array when it exposes `asnumpy()`, otherwise return it unchanged."""
        return actions.asnumpy() if hasattr(actions, "asnumpy") else actions

    def _build_memory(self, auxiliary_info_shape=None) -> DummyOffPolicyBuffer:
        """Build and initialize the replay buffer for off-policy training.

        The Atari-specific buffer class is used when `config.env_name == "Atari"`; otherwise the standard
        off-policy buffer is constructed. As a side effect, `self.atari` is (re)assigned.

        Args:
            auxiliary_info_shape (Optional[tuple]): Shape of auxiliary information to be stored alongside
                transitions in the replay buffer. Forwarded to the buffer as `auxiliary_shape`.

        Returns:
            DummyOffPolicyBuffer: A buffer constructed from the agent's observation space, action space,
            number of parallel environments, `config.buffer_size` and `config.batch_size`.
        """
        self.atari = self.config.env_name == "Atari"
        buffer_cls = DummyOffPolicyBuffer_Atari if self.atari else DummyOffPolicyBuffer
        return buffer_cls(
            observation_space=self.observation_space,
            action_space=self.action_space,
            auxiliary_shape=auxiliary_info_shape,
            n_envs=self.n_envs,
            buffer_size=self.config.buffer_size,
            batch_size=self.config.batch_size,
        )

    def _build_policy(self) -> Module:
        """Build the policy module. Must be implemented by concrete agents."""
        raise NotImplementedError

    def _update_explore_factor(self):
        """Linearly decay the active exploration factor based on `self.current_step`.

        Epsilon-greedy takes precedence over action noise when both are configured. The decayed value is
        clamped at its configured floor (`end_greedy` / `end_noise`) so that it never overshoots below it.
        """
        if self.e_greedy is not None:
            if self.e_greedy > self.end_greedy:
                decayed = self.start_greedy - self.current_step * self.delta_egreedy
                self.e_greedy = max(self.end_greedy, decayed)
        elif self.noise_scale is not None:
            if self.noise_scale > self.end_noise:
                decayed = self.start_noise - self.current_step * self.delta_noise
                self.noise_scale = max(self.end_noise, decayed)

    def exploration(self, pi_actions):
        """Returns the actions for exploration.

        Parameters:
            pi_actions: The original output actions of the policy (a framework tensor or a NumPy array).

        Returns:
            explore_actions: With epsilon-greedy enabled, either uniformly random discrete actions for all
                environments or the policy actions as a NumPy array. With action noise enabled, the policy
                actions plus Gaussian noise, clipped to the action bounds. Otherwise `pi_actions` unchanged.
        """
        if self.e_greedy is not None:
            # Random actions are always drawn first (before the coin flip) to keep the
            # consumption order of NumPy's global RNG unchanged.
            random_actions = np.random.choice(self.action_space.n, self.n_envs)
            # A single draw decides for all parallel environments at once.
            if np.random.rand() < self.e_greedy:
                return random_actions
            return self._actions_to_numpy(pi_actions)

        if self.noise_scale is not None:
            # Work on a NumPy array so the noise addition and clipping are pure NumPy operations.
            base_actions = self._actions_to_numpy(pi_actions)
            noisy_actions = base_actions + np.random.normal(size=base_actions.shape) * self.noise_scale
            return np.clip(noisy_actions, self.actions_low, self.actions_high)

        return pi_actions

    def get_actions(self, observations: np.ndarray,
                    test_mode: Optional[bool] = False) -> dict:
        """Compute the actions for the given observations.

        Parameters:
            observations (np.ndarray): The observations fed to the policy.
            test_mode (Optional[bool]): True for testing without exploration.

        Returns:
            dict: A dictionary with a single key ``"actions"`` holding the actions to be executed.
        """
        # The policy returns a 3-tuple; only its second element (the action output) is used here.
        _, actions_output, _ = self.policy(Tensor(observations))
        if test_mode:
            actions = actions_output.asnumpy()
        else:
            actions = self.exploration(actions_output)
        return {"actions": actions}

    def train_epochs(self, n_epochs=1) -> dict:
        """Run `n_epochs` learner updates, each on a fresh batch sampled from the replay buffer.

        Parameters:
            n_epochs (int): Number of sample-and-update rounds.

        Returns:
            dict: The info returned by the last learner update (empty if `n_epochs` is 0), extended with the
            current ``"epsilon-greedy"`` and ``"noise_scale"`` values.
        """
        train_info = {}
        for _ in range(n_epochs):
            batch = self.memory.sample()
            train_info = self.learner.update(**batch)
        train_info["epsilon-greedy"] = self.e_greedy
        train_info["noise_scale"] = self.noise_scale
        return train_info

    def train(self, train_steps: int) -> dict:
        """Run the main off-policy training loop.

        At each iteration the agent selects actions, steps the training environments, stores the transitions
        in the replay buffer, notifies the callback, and performs policy updates when the configured
        conditions are met.

        Args:
            train_steps (int): Number of loop iterations to run. Each iteration steps all `self.n_envs`
                parallel environments once and advances `self.current_step` by `self.n_envs`.

        Returns:
            dict: The update and episode information logged during the loop, merged into one dictionary
            (later entries overwrite earlier ones with the same key).

        Notes:
            - This method assumes that training environments (`self.train_envs`)
              and the replay buffer (`self.memory`) have already been initialized.
            - Policy updates run when `self.current_step > self.start_training` and
              `self.current_step` is a multiple of `self.training_frequency`.
            - For Atari, a terminal signal without truncation does not trigger the episode-end handling.
        """
        train_info = {}
        obs = self.train_envs.buf_obs
        for _ in tqdm(range(train_steps)):
            self.obs_rms.update(obs)
            obs = self._process_observation(obs)
            policy_out = self.get_actions(obs, test_mode=False)
            acts = policy_out['actions']
            next_obs, rewards, terminals, truncations, infos = self.train_envs.step(acts)

            self.callback.on_train_step(self.current_step, envs=self.train_envs, policy=self.policy,
                                        obs=obs, policy_out=policy_out, acts=acts, next_obs=next_obs,
                                        rewards=rewards, terminals=terminals, truncations=truncations,
                                        infos=infos, train_steps=train_steps)

            self.memory.store(obs, acts, self._process_reward(rewards), terminals,
                              self._process_observation(next_obs))

            if self.current_step > self.start_training and self.current_step % self.training_frequency == 0:
                update_info = self.train_epochs(n_epochs=self.n_epochs)
                self.log_infos(update_info, self.current_step)
                train_info.update(update_info)
                self.callback.on_train_epochs_end(self.current_step, policy=self.policy, memory=self.memory,
                                                  current_episode=self.current_episode,
                                                  train_steps=train_steps, update_info=update_info)

            self.returns = self.gamma * self.returns + rewards
            obs = deepcopy(next_obs)
            for i in range(self.n_envs):
                if not (terminals[i] or truncations[i]):
                    continue
                # Logical `not` (rather than bitwise `~`) so plain Python bools are handled correctly.
                if self.atari and not truncations[i]:
                    continue

                # Episode finished: continue from the reset observation supplied by the env.
                obs[i] = infos[i]["reset_obs"]
                self.train_envs.buf_obs[i] = obs[i]
                self.ret_rms.update(self.returns[i:i + 1])
                self.returns[i] = 0.0
                self.current_episode[i] += 1

                if self.use_wandb:
                    episode_info = {
                        f"Episode-Steps/env-{i}": infos[i]["episode_step"],
                        f"Train-Episode-Rewards/env-{i}": infos[i]["episode_score"]
                    }
                else:
                    episode_info = {
                        "Episode-Steps": {f"env-{i}": infos[i]["episode_step"]},
                        "Train-Episode-Rewards": {f"env-{i}": infos[i]["episode_score"]}
                    }
                self.log_infos(episode_info, self.current_step)
                train_info.update(episode_info)
                self.callback.on_train_episode_info(envs=self.train_envs, policy=self.policy, env_id=i,
                                                    infos=infos, use_wandb=self.use_wandb,
                                                    current_step=self.current_step,
                                                    current_episode=self.current_episode,
                                                    train_steps=train_steps)

            self.current_step += self.n_envs
            self._update_explore_factor()
            self.callback.on_train_step_end(self.current_step, envs=self.train_envs, policy=self.policy,
                                            train_steps=train_steps, train_info=train_info)
        return train_info

    def test(self,
             test_episodes: int,
             test_envs: Optional[DummyVecEnv | SubprocVecEnv] = None,
             close_envs: bool = True) -> list:
        """Evaluate the current policy in a vectorized environment.

        This method runs evaluation episodes using `test_envs` and returns the per-episode scores. During
        evaluation, actions are selected in test mode and RGB-array frames are recorded for video logging
        when rendering is enabled.

        Args:
            test_episodes (int): Total number of evaluation episodes to run across all vectorized environments.
            test_envs (Optional[DummyVecEnv | SubprocVecEnv]): Vectorized environments used for evaluation.
                Must not be None.
            close_envs (bool): Whether to close `test_envs` before returning.
                Set this to False if `test_envs` is managed externally and will be reused after evaluation.

        Returns:
            list: A list of episode scores collected during evaluation.

        Raises:
            ValueError: If `test_envs` is None.

        Notes:
            - This method resets the evaluation environments at the beginning of testing and steps them
              until `test_episodes` episodes are completed.
            - When `render_mode == "rgb_array"` and `self.render` is True, the method records frames and logs
              the frames of the environment that produced the best-scoring episode as a video.
            - This implementation updates `obs_rms` during testing. If you want to avoid contaminating
              training statistics, consider guarding this update with a dedicated flag.
        """
        if test_envs is None:
            raise ValueError("`test_envs` must be provided for evaluation.")

        num_envs = test_envs.num_envs
        record_video = self.config.render_mode == "rgb_array" and self.render
        videos, episode_videos, images = [[] for _ in range(num_envs)], [], None
        current_episode, current_step, scores, best_score = 0, 0, [], -np.inf

        obs, infos = test_envs.reset()
        if record_video:
            images = test_envs.render(self.config.render_mode)
            for idx, img in enumerate(images):
                videos[idx].append(img)

        while current_episode < test_episodes:
            self.obs_rms.update(obs)
            obs = self._process_observation(obs)
            policy_out = self.get_actions(obs, test_mode=True)
            next_obs, rewards, terminals, truncations, infos = test_envs.step(policy_out['actions'])
            if record_video:
                images = test_envs.render(self.config.render_mode)
                for idx, img in enumerate(images):
                    videos[idx].append(img)

            self.callback.on_test_step(envs=test_envs, policy=self.policy, images=images,
                                       obs=obs, policy_out=policy_out, next_obs=next_obs, rewards=rewards,
                                       terminals=terminals, truncations=truncations, infos=infos,
                                       current_train_step=self.current_step,
                                       current_step=current_step, current_episode=current_episode)

            obs = deepcopy(next_obs)
            for i in range(num_envs):
                if not (terminals[i] or truncations[i]):
                    continue
                # Logical `not` (rather than bitwise `~`) so plain Python bools are handled correctly.
                if self.atari and not truncations[i]:
                    continue

                obs[i] = infos[i]["reset_obs"]
                episode_score = infos[i]["episode_score"]
                scores.append(episode_score)
                current_episode += 1
                if best_score < episode_score:
                    best_score = episode_score
                    # NOTE(review): `videos[i]` is never cleared, so this copy also contains the frames of
                    # earlier episodes played in environment `i` - confirm whether that is intended.
                    episode_videos = videos[i].copy()
            current_step += num_envs

        # Skip video logging when no frames were captured: the 5-axis transpose below
        # would fail on an empty frame list.
        if record_video and len(episode_videos) > 0:
            # time, height, width, channel -> time, channel, height, width
            videos_info = {"Videos_Test": np.array([episode_videos], dtype=np.uint8).transpose((0, 1, 4, 2, 3))}
            self.log_videos(info=videos_info, fps=self.fps, x_index=self.current_step)

        test_info = {
            "Test-Episode-Rewards/Mean-Score": np.mean(scores),
            "Test-Episode-Rewards/Std-Score": np.std(scores)
        }
        self.log_infos(test_info, self.current_step)
        self.callback.on_test_end(envs=test_envs, policy=self.policy,
                                  current_train_step=self.current_step,
                                  current_step=current_step, current_episode=current_episode,
                                  scores=scores, best_score=best_score)
        if close_envs:
            test_envs.close()
        return scores