import numpy as np
from tqdm import tqdm
from copy import deepcopy
from argparse import Namespace
from gymnasium.spaces import Space
from xuance.common import Optional, DummyOffPolicyBuffer, DummyOffPolicyBuffer_Atari, BaseCallback
from xuance.environment import DummyVecEnv, SubprocVecEnv
from xuance.tensorflow import Module
from xuance.tensorflow.utils import NormalizeFunctions, ActivationFunctions, InitializeFunctions
from xuance.tensorflow.policies import REGISTRY_Policy
from xuance.tensorflow.agents import Agent
[docs]
class NoisyDQN_Agent(Agent):
    """Off-policy agent that pairs a ``Noisy_Q_network`` policy with a replay buffer.

    A scalar ``noise_scale`` is annealed linearly from ``config.start_noise``
    down to ``config.end_noise`` while training and is written to
    ``policy.noise_scale`` before every forward pass and learner update.

    Args:
        config: Experiment configuration. This class reads ``start_noise``,
            ``end_noise``, ``decay_step_noise``, ``buffer_size``, ``batch_size``,
            ``env_name``, ``policy``, ``representation``, ``activation``,
            ``q_hidden_size``, ``render_mode`` and, optionally, ``normalize``
            and ``initialize``.
        envs: Vectorized training environments, forwarded to the base ``Agent``.
        observation_space: Observation space, forwarded to the base ``Agent``.
        action_space: Action space, forwarded to the base ``Agent``.
        callback: Optional user callback, forwarded to the base ``Agent``.
    """

    def __init__(
            self,
            config: Namespace,
            envs: Optional[DummyVecEnv | SubprocVecEnv] = None,
            observation_space: Optional[Space] = None,
            action_space: Optional[Space] = None,
            callback: Optional[BaseCallback] = None
    ):
        super(NoisyDQN_Agent, self).__init__(config, envs, observation_space, action_space, callback)

        # Linear noise schedule. `current_step` advances by `n_envs` per
        # training iteration, so the per-iteration decrement is scaled such
        # that `end_noise` is reached after `decay_step_noise` environment steps.
        self.start_noise, self.end_noise = config.start_noise, config.end_noise
        self.noise_scale = config.start_noise
        self.delta_noise = (self.start_noise - self.end_noise) / (self.config.decay_step_noise / self.n_envs)

        # Build policy, optimizer, scheduler.
        self.policy = self._build_policy()

        # Create experience replay buffer.
        self.auxiliary_info_shape = {}
        self.atari = config.env_name == "Atari"
        Buffer = DummyOffPolicyBuffer_Atari if self.atari else DummyOffPolicyBuffer
        self.memory = Buffer(observation_space=self.observation_space,
                             action_space=self.action_space,
                             auxiliary_shape={},
                             n_envs=self.n_envs,
                             buffer_size=self.config.buffer_size,
                             batch_size=self.config.batch_size)

        self.learner = self._build_learner(self.config, self.policy, self.callback)

    def _build_policy(self) -> Module:
        """Build the representation and the ``Noisy_Q_network`` policy.

        Returns:
            The policy module created from ``REGISTRY_Policy``.

        Raises:
            AttributeError: If ``config.policy`` is not ``"Noisy_Q_network"``.
        """
        normalize_fn = NormalizeFunctions[self.config.normalize] if hasattr(self.config, "normalize") else None
        initializer = InitializeFunctions[self.config.initialize] if hasattr(self.config, "initialize") else None
        activation = ActivationFunctions[self.config.activation]

        # build representation.
        representation = self._build_representation(self.config.representation, self.observation_space, self.config)

        # build policy.
        if self.config.policy != "Noisy_Q_network":
            raise AttributeError(
                f"{self.config.agent} currently does not support the policy named {self.config.policy}.")
        return REGISTRY_Policy["Noisy_Q_network"](
            action_space=self.action_space, representation=representation, hidden_size=self.config.q_hidden_size,
            normalize=normalize_fn, initialize=initializer, activation=activation,
            use_distributed_training=self.distributed_training)

    def get_actions(self, obs, noise_scale: Optional[float] = None):
        """Return the policy's argmax actions for a batch of observations.

        Args:
            obs: Observations passed directly to the policy.
            noise_scale: Noise scale written to ``policy.noise_scale`` for this
                call. ``None`` (the default) uses the agent's current training
                noise scale; pass ``0.0`` for noise-free evaluation.

        Returns:
            The second output of the policy (argmax action) converted with ``.numpy()``.
        """
        self.policy.noise_scale = self.noise_scale if noise_scale is None else noise_scale
        _, argmax_action, _ = self.policy(obs)
        return argmax_action.numpy()

    def train_epochs(self, n_epochs=1):
        """Run ``n_epochs`` learner updates on batches sampled from the replay buffer.

        Args:
            n_epochs: Number of sample/update rounds.

        Returns:
            The info dict of the last learner update (empty if ``n_epochs`` is 0).
        """
        train_info = {}
        for _ in range(n_epochs):
            samples = self.memory.sample()
            self.policy.noise_scale = self.noise_scale
            train_info = self.learner.update(**samples)
        return train_info

    def train(self, train_steps: int) -> dict:
        """Interact with the training environments and update the policy.

        Args:
            train_steps: Number of vectorized environment steps to run.

        Returns:
            A dict accumulating the latest learner-update and episode infos.
        """
        train_info = {}
        obs = self.train_envs.buf_obs
        for _ in tqdm(range(train_steps)):
            self.obs_rms.update(obs)
            obs = self._process_observation(obs)
            acts = self.get_actions(obs)
            next_obs, rewards, terminals, truncations, infos = self.train_envs.step(acts)

            self.callback.on_train_step(self.current_step, envs=self.train_envs, policy=self.policy,
                                        obs=obs, acts=acts, next_obs=next_obs, rewards=rewards,
                                        terminals=terminals, truncations=truncations, infos=infos,
                                        train_steps=train_steps)

            self.memory.store(obs, acts, self._process_reward(rewards), terminals,
                              self._process_observation(next_obs))

            if self.current_step > self.start_training and self.current_step % self.training_frequency == 0:
                update_info = self.train_epochs(n_epochs=self.n_epochs)
                self.log_infos(update_info, self.current_step)
                train_info.update(update_info)
                self.callback.on_train_epochs_end(self.current_step, policy=self.policy, memory=self.memory,
                                                  current_episode=self.current_episode, train_steps=train_steps,
                                                  update_info=update_info)

            obs = deepcopy(next_obs)
            for i in range(self.n_envs):
                if not (terminals[i] or truncations[i]):
                    continue
                # On Atari, a termination without truncation is not treated as
                # an episode end here. `not` (rather than bitwise `~`) keeps
                # this correct for plain Python bools as well as NumPy bools.
                if self.atari and not truncations[i]:
                    continue

                obs[i] = infos[i]["reset_obs"]
                self.train_envs.buf_obs[i] = obs[i]
                self.current_episode[i] += 1
                if self.use_wandb:
                    episode_info = {
                        f"Episode-Steps/env-{i}": infos[i]["episode_step"],
                        f"Train-Episode-Rewards/env-{i}": infos[i]["episode_score"]
                    }
                else:
                    episode_info = {
                        "Episode-Steps": {f"env-{i}": infos[i]["episode_step"]},
                        "Train-Episode-Rewards": {f"env-{i}": infos[i]["episode_score"]}
                    }
                self.log_infos(episode_info, self.current_step)
                train_info.update(episode_info)
                self.callback.on_train_episode_info(envs=self.train_envs, policy=self.policy, env_id=i,
                                                    infos=infos, use_wandb=self.use_wandb,
                                                    current_step=self.current_step,
                                                    current_episode=self.current_episode,
                                                    train_steps=train_steps)

            self.current_step += self.n_envs

            # Anneal the noise, clamping so the final decrement cannot push
            # the scale below `end_noise`.
            if self.noise_scale > self.end_noise:
                self.noise_scale = max(self.end_noise, self.noise_scale - self.delta_noise)

            # NOTE(review): the original nesting level of this statement could
            # not be recovered; it is kept at train-step level -- confirm.
            if terminals[0]:
                self.policy.update_noise(self.noise_scale)

            self.callback.on_train_step_end(self.current_step, envs=self.train_envs, policy=self.policy,
                                            train_steps=train_steps, train_info=train_info)
        return train_info

    def test(self,
             test_episodes: int,
             test_envs: Optional[DummyVecEnv | SubprocVecEnv] = None,
             close_envs: bool = True) -> list:
        """Evaluate the policy with the noise scale forced to zero.

        Args:
            test_episodes: Number of finished episodes to collect.
            test_envs: Vectorized evaluation environments (required).
            close_envs: Whether to close ``test_envs`` when evaluation ends.

        Returns:
            The list of episode scores, one entry per finished episode.

        Raises:
            ValueError: If ``test_envs`` is ``None``.
        """
        if test_envs is None:
            raise ValueError("`test_envs` must be provided for evaluation.")
        num_envs = test_envs.num_envs
        record_video = self.config.render_mode == "rgb_array" and self.render
        videos, episode_videos, images = [[] for _ in range(num_envs)], [], None
        current_episode, current_step, scores, best_score = 0, 0, [], -np.inf

        obs, infos = test_envs.reset()
        if record_video:
            images = test_envs.render(self.config.render_mode)
            for idx, img in enumerate(images):
                videos[idx].append(img)

        self.policy.noise_scale = 0.0
        while current_episode < test_episodes:
            self.obs_rms.update(obs)
            obs = self._process_observation(obs)
            # Pass the zero noise scale explicitly: get_actions() would
            # otherwise restore the training noise scale on every call.
            acts = self.get_actions(obs, noise_scale=0.0)
            next_obs, rewards, terminals, truncations, infos = test_envs.step(acts)
            if record_video:
                images = test_envs.render(self.config.render_mode)
                for idx, img in enumerate(images):
                    videos[idx].append(img)

            self.callback.on_test_step(envs=test_envs, policy=self.policy, images=images,
                                       obs=obs, acts=acts, next_obs=next_obs, rewards=rewards,
                                       terminals=terminals, truncations=truncations, infos=infos,
                                       current_train_step=self.current_step,
                                       current_step=current_step, current_episode=current_episode)

            obs = deepcopy(next_obs)
            for i in range(num_envs):
                if not (terminals[i] or truncations[i]):
                    continue
                # Same Atari rule as in train(): only truncation ends an episode.
                if self.atari and not truncations[i]:
                    continue
                obs[i] = infos[i]["reset_obs"]
                episode_score = infos[i]["episode_score"]
                scores.append(episode_score)
                current_episode += 1
                if best_score < episode_score:
                    best_score = episode_score
                    episode_videos = videos[i].copy()
            current_step += num_envs

        if record_video:
            # time, height, width, channel -> time, channel, height, width
            videos_info = {"Videos_Test": np.array([episode_videos], dtype=np.uint8).transpose((0, 1, 4, 2, 3))}
            self.log_videos(info=videos_info, fps=self.fps, x_index=self.current_step)

        test_info = {
            "Test-Episode-Rewards/Mean-Score": np.mean(scores),
            "Test-Episode-Rewards/Std-Score": np.std(scores)
        }
        self.log_infos(test_info, self.current_step)

        self.callback.on_test_end(envs=test_envs, policy=self.policy,
                                  current_train_step=self.current_step,
                                  current_step=current_step, current_episode=current_episode,
                                  scores=scores, best_score=best_score)

        if close_envs:
            test_envs.close()
        return scores