# Source code for xuance.torch.agents.qlearning_family.noisydqn_agent

import torch
import numpy as np
from tqdm import tqdm
from copy import deepcopy
from argparse import Namespace
from gymnasium.spaces import Space
from xuance.common import Optional, DummyOffPolicyBuffer, DummyOffPolicyBuffer_Atari, BaseCallback
from xuance.environment import DummyVecEnv, SubprocVecEnv
from xuance.torch import Module
from xuance.torch.utils import NormalizeFunctions, ActivationFunctions
from xuance.torch.policies import REGISTRY_Policy
from xuance.torch.agents import Agent


class NoisyDQN_Agent(Agent):
    """DQN agent that explores through a noise scale applied to its Q-network policy.

    The agent keeps a scalar ``noise_scale`` that starts at ``config.start_noise``
    and is reduced by ``delta_noise`` after every training iteration until it
    reaches ``config.end_noise``. The current scale is written to
    ``policy.noise_scale`` before the policy is queried for actions and before
    every learner update. During evaluation the scale is forced to zero.

    Args:
        config: Hyper-parameters. This class reads ``start_noise``, ``end_noise``,
            ``decay_step_noise``, ``buffer_size``, ``batch_size``, ``env_name``,
            ``policy``, ``representation``, ``activation``, ``q_hidden_size``,
            ``render_mode`` and, optionally, ``normalize``.
        envs: Vectorized training environments, forwarded to the base ``Agent``.
        observation_space: Observation space, forwarded to the base ``Agent``.
        action_space: Action space, forwarded to the base ``Agent``.
        callback: Optional user callback, forwarded to the base ``Agent``.
    """

    def __init__(
            self,
            config: Namespace,
            envs: Optional[DummyVecEnv | SubprocVecEnv] = None,
            observation_space: Optional[Space] = None,
            action_space: Optional[Space] = None,
            callback: Optional[BaseCallback] = None
    ):
        super(NoisyDQN_Agent, self).__init__(config, envs, observation_space, action_space, callback)

        # Noise schedule. Every training iteration advances `current_step` by
        # `n_envs`, so the per-iteration decrement is scaled such that the decay
        # spans roughly `decay_step_noise` environment steps.
        self.start_noise, self.end_noise = config.start_noise, config.end_noise
        self.noise_scale = config.start_noise
        self.delta_noise = (self.start_noise - self.end_noise) / (self.config.decay_step_noise / self.n_envs)

        # Build policy, optimizer, scheduler.
        self.policy = self._build_policy()

        # Create experience replay buffer.
        input_buffer = dict(observation_space=self.observation_space,
                            action_space=self.action_space,
                            auxiliary_shape={},
                            n_envs=self.n_envs,
                            buffer_size=self.config.buffer_size,
                            batch_size=self.config.batch_size)
        self.auxiliary_info_shape = {}
        self.atari = config.env_name == "Atari"
        Buffer = DummyOffPolicyBuffer_Atari if self.atari else DummyOffPolicyBuffer
        self.memory = Buffer(**input_buffer)
        self.learner = self._build_learner(self.config, self.policy, self.callback)

    def _build_policy(self) -> Module:
        """Build the representation and the ``Noisy_Q_network`` policy.

        Returns:
            The policy created from ``REGISTRY_Policy["Noisy_Q_network"]``.

        Raises:
            AttributeError: If ``config.policy`` is not ``"Noisy_Q_network"``.
        """
        normalize_fn = NormalizeFunctions[self.config.normalize] if hasattr(self.config, "normalize") else None
        initializer = torch.nn.init.orthogonal_
        activation = ActivationFunctions[self.config.activation]

        # build representation.
        representation = self._build_representation(self.config.representation, self.observation_space, self.config)

        # build policy.
        if self.config.policy != "Noisy_Q_network":
            raise AttributeError(f"{self.config.agent} currently does not support the policy named {self.config.policy}.")

        return REGISTRY_Policy["Noisy_Q_network"](
            action_space=self.action_space,
            representation=representation,
            hidden_size=self.config.q_hidden_size,
            normalize=normalize_fn,
            initialize=initializer,
            activation=activation,
            device=self.device,
            use_distributed_training=self.distributed_training)

    def get_actions(self, obs, test_mode: bool = False):
        """Query the policy for greedy actions.

        Args:
            obs: Batch of (already processed) observations passed to the policy.
            test_mode: When True the policy's noise scale is set to ``0.0`` for
                this call; otherwise the agent's current ``noise_scale`` is used.

        Returns:
            The policy's arg-max actions as a NumPy array.
        """
        # The scale is written on every call, so evaluation must opt out here;
        # setting `policy.noise_scale` elsewhere would be overwritten.
        self.policy.noise_scale = 0.0 if test_mode else self.noise_scale
        _, argmax_action, _ = self.policy(obs)
        return argmax_action.detach().cpu().numpy()

    def train_epochs(self, n_epochs=1):
        """Run ``n_epochs`` learner updates on batches sampled from the replay buffer.

        Args:
            n_epochs: Number of sample/update rounds to perform.

        Returns:
            The info dict returned by the last learner update, or an empty dict
            when ``n_epochs`` is not positive.
        """
        train_info = {}
        for _ in range(n_epochs):
            samples = self.memory.sample()
            self.policy.noise_scale = self.noise_scale
            train_info = self.learner.update(**samples)
        return train_info

    def train(self, train_steps: int) -> dict:
        """Interact with the training environments and update the policy.

        Args:
            train_steps: Number of vectorized environment steps to run. Each one
                advances ``current_step`` by ``n_envs``.

        Returns:
            A dict accumulating the update and episode information logged during
            this call.
        """
        train_info = {}
        obs = self.train_envs.buf_obs
        for _ in tqdm(range(train_steps)):
            self.obs_rms.update(obs)
            obs = self._process_observation(obs)
            acts = self.get_actions(obs)
            next_obs, rewards, terminals, truncations, infos = self.train_envs.step(acts)

            self.callback.on_train_step(self.current_step, envs=self.train_envs, policy=self.policy,
                                        obs=obs, acts=acts, next_obs=next_obs, rewards=rewards,
                                        terminals=terminals, truncations=truncations, infos=infos,
                                        train_steps=train_steps)

            self.memory.store(obs, acts, self._process_reward(rewards), terminals,
                              self._process_observation(next_obs))

            if self.current_step > self.start_training and self.current_step % self.training_frequency == 0:
                update_info = self.train_epochs(n_epochs=self.n_epochs)
                self.log_infos(update_info, self.current_step)
                train_info.update(update_info)
                self.callback.on_train_epochs_end(self.current_step, policy=self.policy, memory=self.memory,
                                                  current_episode=self.current_episode, train_steps=train_steps,
                                                  update_info=update_info)

            obs = deepcopy(next_obs)
            for i in range(self.n_envs):
                if not (terminals[i] or truncations[i]):
                    continue
                # `not` instead of `~`: bitwise inversion of a plain Python bool
                # is always truthy, which would skip every finished episode.
                if self.atari and not truncations[i]:
                    # Atari terminal without truncation is not treated as an
                    # episode end (presumably a life-loss signal — confirm
                    # against the Atari wrapper).
                    continue

                obs[i] = infos[i]["reset_obs"]
                self.train_envs.buf_obs[i] = obs[i]
                self.current_episode[i] += 1
                if self.use_wandb:
                    episode_info = {
                        f"Episode-Steps/rank_{self.rank}/env-{i}": infos[i]["episode_step"],
                        f"Train-Episode-Rewards/rank_{self.rank}/env-{i}": infos[i]["episode_score"]
                    }
                else:
                    episode_info = {
                        f"Episode-Steps/rank_{self.rank}": {f"env-{i}": infos[i]["episode_step"]},
                        f"Train-Episode-Rewards/rank_{self.rank}": {f"env-{i}": infos[i]["episode_score"]}
                    }
                self.log_infos(episode_info, self.current_step)
                train_info.update(episode_info)
                self.callback.on_train_episode_info(envs=self.train_envs, policy=self.policy, env_id=i,
                                                    infos=infos, rank=self.rank, use_wandb=self.use_wandb,
                                                    current_step=self.current_step,
                                                    current_episode=self.current_episode,
                                                    train_steps=train_steps)

            self.current_step += self.n_envs

            # Decay the noise scale, clamped so the final decrement cannot push
            # it below the configured floor.
            if self.noise_scale > self.end_noise:
                self.noise_scale = max(self.end_noise, self.noise_scale - self.delta_noise)
            # NOTE(review): only the first environment's terminal flag is checked
            # here, and the nesting of this check relative to the decay above is
            # reconstructed from flattened source — confirm both.
            if terminals[0]:
                self.policy.update_noise(self.noise_scale)

            self.callback.on_train_step_end(self.current_step, envs=self.train_envs, policy=self.policy,
                                            train_steps=train_steps, train_info=train_info)
        return train_info

    def test(self,
             test_episodes: int,
             test_envs: Optional[DummyVecEnv | SubprocVecEnv] = None,
             close_envs: bool = True) -> list:
        """Evaluate the policy with the noise scale set to zero.

        Args:
            test_episodes: Minimum number of finished episodes to collect.
            test_envs: Vectorized environments used for evaluation (required).
            close_envs: Whether to close ``test_envs`` before returning.

        Returns:
            The list of episode scores collected during evaluation.

        Raises:
            ValueError: If ``test_envs`` is None.
        """
        if test_envs is None:
            raise ValueError("`test_envs` must be provided for evaluation.")

        num_envs = test_envs.num_envs
        videos, episode_videos, images = [[] for _ in range(num_envs)], [], None
        current_episode, current_step, scores, best_score = 0, 0, [], -np.inf
        record_video = self.config.render_mode == "rgb_array" and self.render

        obs, infos = test_envs.reset()
        if record_video:
            images = test_envs.render(self.config.render_mode)
            for idx, img in enumerate(images):
                videos[idx].append(img)

        while current_episode < test_episodes:
            self.obs_rms.update(obs)
            obs = self._process_observation(obs)
            # test_mode=True keeps the policy noise at zero for every step.
            acts = self.get_actions(obs, test_mode=True)
            next_obs, rewards, terminals, truncations, infos = test_envs.step(acts)
            if record_video:
                images = test_envs.render(self.config.render_mode)
                for idx, img in enumerate(images):
                    videos[idx].append(img)

            self.callback.on_test_step(envs=test_envs, policy=self.policy, images=images,
                                       obs=obs, acts=acts, next_obs=next_obs, rewards=rewards,
                                       terminals=terminals, truncations=truncations, infos=infos,
                                       current_train_step=self.current_step,
                                       current_step=current_step, current_episode=current_episode)

            obs = deepcopy(next_obs)
            for i in range(num_envs):
                if not (terminals[i] or truncations[i]):
                    continue
                if self.atari and not truncations[i]:
                    # Same Atari rule as in training: terminal without
                    # truncation is not counted as a finished episode.
                    continue

                obs[i] = infos[i]["reset_obs"]
                episode_score = infos[i]["episode_score"]
                scores.append(episode_score)
                current_episode += 1
                if best_score < episode_score:
                    best_score = episode_score
                    episode_videos = videos[i].copy()
            current_step += num_envs

        if record_video:
            # time, height, width, channel -> time, channel, height, width
            videos_info = {"Videos_Test": np.array([episode_videos], dtype=np.uint8).transpose((0, 1, 4, 2, 3))}
            self.log_videos(info=videos_info, fps=self.fps, x_index=self.current_step)

        test_info = {
            "Test-Episode-Rewards/Mean-Score": np.mean(scores),
            "Test-Episode-Rewards/Std-Score": np.std(scores)
        }
        self.log_infos(test_info, self.current_step)

        self.callback.on_test_end(envs=test_envs, policy=self.policy,
                                  current_train_step=self.current_step,
                                  current_step=current_step, current_episode=current_episode,
                                  scores=scores, best_score=best_score)

        if close_envs:
            test_envs.close()

        return scores