# Source code for xuance.torch.agents.qlearning_family.drqn_agent

import torch
import numpy as np
from tqdm import tqdm
from copy import deepcopy
from argparse import Namespace
from gymnasium.spaces import Space
from xuance.common import Optional, RecurrentOffPolicyBuffer, EpisodeBuffer, BaseCallback
from xuance.environment import DummyVecEnv, SubprocVecEnv
from xuance.torch import Module
from xuance.torch.utils import NormalizeFunctions, ActivationFunctions
from xuance.torch.policies import REGISTRY_Policy
from xuance.torch.agents import OffPolicyAgent


class DRQN_Agent(OffPolicyAgent):
    """Recurrent Q-learning (DRQN) agent with epsilon-greedy exploration.

    Args:
        config: the Namespace that provides hyper-parameters and other settings.
        envs: the vectorized environments used for training.
        observation_space: the observation space.
        action_space: the (discrete) action space; ``action_space.n`` is used for random actions.
        callback: user-defined callback hooks invoked during training and testing.
    """

    def __init__(
        self,
        config: Namespace,
        envs: Optional[DummyVecEnv | SubprocVecEnv] = None,
        observation_space: Optional[Space] = None,
        action_space: Optional[Space] = None,
        callback: Optional[BaseCallback] = None,
    ):
        super(DRQN_Agent, self).__init__(config, envs, observation_space, action_space, callback)
        # Epsilon-greedy schedule: epsilon starts at start_greedy and is decreased by
        # delta_egreedy once per vectorized step in train(). Each vectorized step advances
        # current_step by n_envs, so the decay spans roughly config.decay_step_greedy env steps.
        self.start_greedy, self.end_greedy = config.start_greedy, config.end_greedy
        self.egreedy = config.start_greedy
        self.delta_egreedy = (self.start_greedy - self.end_greedy) / (config.decay_step_greedy / self.n_envs)

        self.policy = self._build_policy()  # build policy
        self.auxiliary_info_shape = {}
        self.memory = self._build_memory(auxiliary_info_shape=self.auxiliary_info_shape)  # build memory
        self.learner = self._build_learner(self.config, self.policy, self.callback)  # build learner
        self.lstm = config.rnn == "LSTM"

    def _build_memory(self, auxiliary_info_shape=None):
        """Create the episode-based recurrent replay buffer.

        Also sets ``self.atari``, which train() and test() read to decide whether a
        terminal-but-not-truncated step should end the episode.
        """
        self.atari = self.config.env_name == "Atari"
        return RecurrentOffPolicyBuffer(
            self.observation_space,
            self.action_space,
            auxiliary_info_shape,
            self.n_envs,
            self.config.buffer_size,
            self.config.batch_size,
            episode_length=self.episode_length,
            lookup_length=self.config.lookup_length,
        )

    def _build_policy(self) -> Module:
        """Build the representation network and the DRQN policy on top of it.

        Raises:
            AttributeError: if ``config.policy`` is not ``"DRQN_Policy"``.
        """
        normalize_fn = NormalizeFunctions[self.config.normalize] if hasattr(self.config, "normalize") else None
        initializer = torch.nn.init.orthogonal_
        activation = ActivationFunctions[self.config.activation]
        device = self.device

        # build representation.
        representation = self._build_representation(self.config.representation, self.observation_space, self.config)

        # build policy.
        if self.config.policy == "DRQN_Policy":
            policy = REGISTRY_Policy["DRQN_Policy"](
                action_space=self.action_space,
                representation=representation,
                rnn=self.config.rnn,
                recurrent_hidden_size=self.config.recurrent_hidden_size,
                recurrent_layer_N=self.config.recurrent_layer_N,
                dropout=self.config.dropout,
                normalize=normalize_fn,
                initialize=initializer,
                activation=activation,
                device=device,
                use_distributed_training=self.distributed_training)
        else:
            raise AttributeError(
                f"{self.config.agent} currently does not support the policy named {self.config.policy}.")
        return policy

    def get_actions(self, obs, egreedy=0.0, rnn_hidden=None):
        """Select actions for a batch of observations with epsilon-greedy exploration.

        Args:
            obs: processed observations, one row per environment.
            egreedy: probability that the whole batch takes uniformly random actions
                instead of the greedy (argmax) actions.
            rnn_hidden: recurrent state returned by the previous call. If None, a fresh
                state is created with ``self.policy.init_hidden``.

        Returns:
            dict: ``"actions"`` (numpy array, one action per environment) and
            ``"rnn_hidden_next"`` (the updated recurrent state).
        """
        batch_size = len(obs)
        if rnn_hidden is None:
            rnn_hidden = self.policy.init_hidden(batch_size)
        # Insert an axis of length 1 at dim 1 (presumably the sequence axis the recurrent
        # policy expects; one step is processed per call).
        _, argmax_action, _, rnn_hidden_next = self.policy(obs[:, None], *rnn_hidden)
        # Size random actions by the actual batch, so test envs with a different count work.
        random_action = np.random.choice(self.action_space.n, batch_size)
        # A single coin flip decides exploration for all environments at once.
        if np.random.rand() < egreedy:
            actions = random_action
        else:
            actions = argmax_action.detach().cpu().numpy()
        return {"actions": actions, "rnn_hidden_next": rnn_hidden_next}

    def train(self, train_steps: int) -> dict:
        """Collect experience and update the policy for ``train_steps`` vectorized steps.

        Each step:
            1. Selects actions with the current epsilon and steps the training envs.
            2. Runs one training epoch every ``training_frequency`` steps once
               ``current_step`` exceeds ``start_training``.
            3. Appends the transition to the per-env episode buffer.

        A finished episode is stored in replay memory and logged. The env's observation
        is taken from ``infos[i]["reset_obs"]``, and its recurrent state is re-initialized.

        Returns:
            dict: the most recent training-update and episode-logging entries.
        """
        train_info = {}
        obs = self.train_envs.buf_obs
        episode_data = [EpisodeBuffer() for _ in range(self.n_envs)]
        for i_env in range(self.n_envs):
            episode_data[i_env].obs.append(self._process_observation(obs[i_env]))
        self.rnn_hidden = self.policy.init_hidden(self.n_envs)

        for _ in tqdm(range(train_steps)):
            self.obs_rms.update(obs)
            obs = self._process_observation(obs)
            policy_out = self.get_actions(obs, self.egreedy, self.rnn_hidden)
            acts, self.rnn_hidden = policy_out['actions'], policy_out['rnn_hidden_next']
            next_obs, rewards, terminals, truncations, infos = self.train_envs.step(acts)

            self.callback.on_train_step(self.current_step, envs=self.train_envs, policy=self.policy,
                                        obs=obs, policy_out=policy_out, acts=acts, next_obs=next_obs,
                                        rewards=rewards, terminals=terminals, truncations=truncations,
                                        infos=infos, train_steps=train_steps, rnn_hidden=self.rnn_hidden)

            if (self.current_step > self.start_training) and (self.current_step % self.training_frequency == 0):
                # training
                update_info = self.train_epochs(n_epochs=1)
                self.log_infos(update_info, self.current_step)
                train_info.update(update_info)

            obs = deepcopy(next_obs)
            for i in range(self.n_envs):
                episode_data[i].put(
                    [self._process_observation(obs[i]), acts[i], self._process_reward(rewards[i]), terminals[i]])
                if not (terminals[i] or truncations[i]):
                    continue
                # On Atari, a terminal step that is not a truncation does not end the episode here.
                # `not` (rather than bitwise `~`) is correct for both Python and numpy bools.
                if self.atari and not truncations[i]:
                    continue
                # presumably re-initializes only env i's slice of the recurrent state
                self.rnn_hidden = self.policy.init_hidden_item(self.rnn_hidden, i)
                self.current_episode[i] += 1
                if self.use_wandb:
                    episode_info = {
                        f"Episode-Steps/rank_{self.rank}/env-{i}": infos[i]["episode_step"],
                        f"Train-Episode-Rewards/rank_{self.rank}/env-{i}": infos[i]["episode_score"]
                    }
                else:
                    episode_info = {
                        f"Episode-Steps/rank_{self.rank}": {f"env-{i}": infos[i]["episode_step"]},
                        f"Train-Episode-Rewards/rank_{self.rank}": {f"env-{i}": infos[i]["episode_score"]}
                    }
                self.log_infos(episode_info, self.current_step)
                train_info.update(episode_info)
                # Store the finished episode and start a new one from the env's reset observation.
                self.memory.store(episode_data[i])
                episode_data[i] = EpisodeBuffer()
                obs[i] = infos[i]["reset_obs"]
                self.train_envs.buf_obs[i] = obs[i]
                episode_data[i].obs.append(self._process_observation(obs[i]))
                self.callback.on_train_episode_info(envs=self.train_envs, policy=self.policy, env_id=i,
                                                    memory=self.memory, infos=infos, rank=self.rank,
                                                    use_wandb=self.use_wandb, current_step=self.current_step,
                                                    current_episode=self.current_episode,
                                                    train_steps=train_steps)

            self.current_step += self.n_envs
            if self.egreedy > self.end_greedy:
                self.egreedy = self.egreedy - self.delta_egreedy

            self.callback.on_train_step_end(self.current_step, envs=self.train_envs, policy=self.policy,
                                            train_steps=train_steps, train_info=train_info)

        return train_info

    def test(self, test_episodes: int, test_envs: Optional[DummyVecEnv | SubprocVecEnv] = None,
             close_envs: bool = True) -> list:
        """Evaluate the greedy policy (epsilon = 0) for at least ``test_episodes`` episodes.

        Args:
            test_episodes: number of finished episodes to collect before stopping.
            test_envs: the vectorized evaluation environments (required).
            close_envs: whether to close ``test_envs`` when done.

        Returns:
            list: the score of every finished episode.

        Raises:
            ValueError: if ``test_envs`` is None.
        """
        if test_envs is None:
            raise ValueError("`test_envs` must be provided for evaluation.")
        num_envs = test_envs.num_envs
        videos, episode_videos, images = [[] for _ in range(num_envs)], [], None
        current_episode, current_step, scores, best_score = 0, 0, [], -np.inf
        obs, infos = test_envs.reset()
        if self.config.render_mode == "rgb_array" and self.render:
            images = test_envs.render(self.config.render_mode)
            for idx, img in enumerate(images):
                videos[idx].append(img)
        rnn_hidden = self.policy.init_hidden(num_envs)

        while current_episode < test_episodes:
            self.obs_rms.update(obs)
            obs = self._process_observation(obs)
            policy_out = self.get_actions(obs, egreedy=0.0, rnn_hidden=rnn_hidden)
            acts, rnn_hidden = policy_out['actions'], policy_out['rnn_hidden_next']
            next_obs, rewards, terminals, truncations, infos = test_envs.step(acts)
            if self.config.render_mode == "rgb_array" and self.render:
                images = test_envs.render(self.config.render_mode)
                for idx, img in enumerate(images):
                    videos[idx].append(img)

            self.callback.on_test_step(envs=test_envs, policy=self.policy, images=images, rnn_hidden=rnn_hidden,
                                       obs=obs, policy_out=policy_out, next_obs=next_obs, rewards=rewards,
                                       terminals=terminals, truncations=truncations, infos=infos,
                                       current_train_step=self.current_step, current_step=current_step,
                                       current_episode=current_episode)

            obs = deepcopy(next_obs)
            for i in range(num_envs):
                if not (terminals[i] or truncations[i]):
                    continue
                # Same Atari rule as in train(); `not` works for Python and numpy bools alike.
                if self.atari and not truncations[i]:
                    continue
                obs[i] = infos[i]["reset_obs"]
                rnn_hidden = self.policy.init_hidden_item(rnn_hidden, i)
                scores.append(infos[i]["episode_score"])
                current_episode += 1
                if best_score < infos[i]["episode_score"]:
                    best_score = infos[i]["episode_score"]
                    # NOTE(review): videos[i] is never cleared, so this copy also contains frames
                    # from earlier episodes of env i — confirm this is intended.
                    episode_videos = videos[i].copy()
            current_step += num_envs

        if self.config.render_mode == "rgb_array" and self.render:
            # time, height, width, channel -> time, channel, height, width
            videos_info = {"Videos_Test": np.array([episode_videos], dtype=np.uint8).transpose((0, 1, 4, 2, 3))}
            self.log_videos(info=videos_info, fps=self.fps, x_index=self.current_step)

        test_info = {
            "Test-Episode-Rewards/Mean-Score": np.mean(scores),
            "Test-Episode-Rewards/Std-Score": np.std(scores)
        }
        self.log_infos(test_info, self.current_step)

        self.callback.on_test_end(envs=test_envs, policy=self.policy, current_train_step=self.current_step,
                                  current_step=current_step, current_episode=current_episode,
                                  scores=scores, best_score=best_score)

        if close_envs:
            test_envs.close()

        return scores