import torch
from copy import deepcopy
from xuance.common import List, Union, SequentialReplayBuffer, BaseCallback
from xuance.environment import DummyVecEnv, SubprocVecEnv
from xuance.torch.agents import OffPolicyAgent
from xuance.torch import REGISTRY_Representation, REGISTRY_Policy
from xuance.torch.representations.world_model_v2 import DreamerV2WorldModel, PlayerDV2
from xuance.torch.policies import DreamerV2Policy
import numpy as np
from tqdm import tqdm
import gymnasium as gym
from argparse import Namespace
from xuance.common import Optional
class DreamerV2Agent(OffPolicyAgent):
    """DreamerV2 agent built on ``OffPolicyAgent``.

    It builds a ``DreamerV2WorldModel`` representation, a ``DreamerV2Policy`` and a
    ``SequentialReplayBuffer``. It collects experience with the world model's recurrent
    ``PlayerDV2`` and runs learner updates on sampled sequences, so that the number of
    gradient steps tracks ``config.replay_ratio`` times the environment steps taken after
    ``start_training``.

    Args:
        config: Hyper-parameters. This class reads ``env_name``, ``pixel``, ``replay_ratio``,
            ``seq_len`` and ``render_mode`` and writes ``is_continuous`` / ``act_shape`` into it.
        envs: Vectorized training environments.
        callback: Optional hook object notified during training and testing.
    """

    def __init__(self,
                 config: Namespace,
                 envs: Union[DummyVecEnv, SubprocVecEnv],
                 callback: Optional[BaseCallback] = None):
        super(DreamerV2Agent, self).__init__(config, envs, callback)
        # Special handling for Atari: episodes are only ended on truncation (see train/test).
        self.atari = self.config.env_name == "Atari"
        # Action-space type flags.
        self.is_continuous = isinstance(self.train_envs.action_space, gym.spaces.Box)
        self.is_multidiscrete = isinstance(self.train_envs.action_space, gym.spaces.MultiDiscrete)
        self.config.is_continuous = self.is_continuous  # add to config
        # obs_shape & act_shape.
        # hwc -> chw: the agent and the memory both keep observations as 'hwc'; pixel
        # observations are transposed to 'chw' and normalized right before sampling / acting
        # (see get_actions and train_epochs).
        self.obs_shape = self.observation_space.shape
        if self.config.pixel:
            self.obs_shape = (self.obs_shape[2], ) + self.obs_shape[:2]
        self.act_shape = self.action_space.n if not self.is_continuous else self.action_space.shape
        self.config.act_shape = self.act_shape  # add to config
        # Target ratio between gradient steps and environment steps.
        self.replay_ratio = self.config.replay_ratio
        self.current_step, self.gradient_step = 0, 0
        # Register & create: representation (world model), policy, memory, learner.
        REGISTRY_Representation['DreamerV2WorldModel'] = DreamerV2WorldModel
        self.model = self._build_representation(representation_key="DreamerV2WorldModel",
                                                config=None, input_space=None)
        REGISTRY_Policy["DreamerV2Policy"] = DreamerV2Policy
        self.policy = self._build_policy()
        self.memory = self._build_memory()
        self.learner = self._build_learner(self.config, self.policy, self.act_shape, self.callback)
        # Player used to act in the training envs. test() deep-copies it so that the
        # training and testing recurrent states stay independent.
        self.train_player: PlayerDV2 = self.model.player
        self.train_player.init_states()
        # State carried across train() calls: [obs, rews, terms, truncs, is_first].
        self.train_states: List[np.ndarray] = [
            self.train_envs.buf_obs,  # obs: (envs, *obs_shape)
            np.zeros((self.train_envs.num_envs, )),  # rews
            np.zeros((self.train_envs.num_envs, )),  # terms
            np.zeros((self.train_envs.num_envs, )),  # truncs
            np.ones((self.train_envs.num_envs, ))  # is_first
        ]

    def _build_representation(self, representation_key: str,
                              input_space: Optional[gym.spaces.Space],
                              config: Optional[Namespace]) -> DreamerV2WorldModel:
        """Create the world model registered under ``representation_key``.

        ``input_space`` and ``config`` are accepted for interface compatibility but are not used.
        The model is configured from ``self.config`` and the training envs' spaces.

        Raises:
            AttributeError: If ``representation_key`` is not registered in REGISTRY_Representation.
        """
        # Validate before indexing; otherwise a KeyError would escape instead of this error.
        if representation_key not in REGISTRY_Representation:
            raise AttributeError(f"{representation_key} is not registered in REGISTRY_Representation.")
        action_space = self.train_envs.action_space
        if self.is_continuous:
            actions_dim = tuple(action_space.shape)
        elif self.is_multidiscrete:
            actions_dim = tuple(action_space.nvec.tolist())
        else:
            actions_dim = (action_space.n, )
        input_representations = dict(
            actions_dim=actions_dim,
            is_continuous=self.is_continuous,
            config=self.config,
            obs_space=self.train_envs.observation_space
        )
        return REGISTRY_Representation[representation_key](**input_representations)

    def _build_memory(self, auxiliary_info_shape=None) -> SequentialReplayBuffer:
        """Create the sequential replay buffer used to sample training sequences."""
        input_buffer = dict(observation_space=self.observation_space,
                            action_space=self.action_space,
                            auxiliary_shape=auxiliary_info_shape,
                            n_envs=self.n_envs,
                            buffer_size=self.buffer_size,
                            batch_size=self.batch_size)
        return SequentialReplayBuffer(**input_buffer)

    def _build_policy(self) -> DreamerV2Policy:
        """Create the registered DreamerV2 policy around the world model."""
        return REGISTRY_Policy["DreamerV2Policy"](self.model, self.config)

    def get_actions(self,
                    obs: np.ndarray,
                    test_mode: Optional[bool] = False,
                    player: Optional[PlayerDV2] = None) -> np.ndarray:
        """Return the actions chosen by ``player`` for a batch of observations.

        Parameters:
            obs (np.ndarray): Observations with a leading env dimension. When ``config.pixel``
                is set, they are transposed with axes (0, 3, 1, 2) and scaled to [-0.5, 0.5] here.
            test_mode (Optional[bool]): True for testing without noises (greedy actions).
            player (Optional[PlayerDV2]): The player whose action is taken, default is train_player.

        Returns:
            actions: The real actions to be executed, shape ``(envs, )`` for discrete action
                spaces or ``(envs, *act_shape)`` for continuous ones.
        """
        if self.config.pixel:
            obs = obs.transpose(0, 3, 1, 2) / 255.0 - 0.5
        player = player if player is not None else self.train_player
        # [envs, *obs_shape] -> [1 (batch), envs, *obs_shape]
        obs = torch.as_tensor(obs, device=self.device, dtype=torch.float32).unsqueeze(0)
        with torch.no_grad():
            actions = player.get_actions(obs, greedy=test_mode, mask=None)[0][0]
        if not self.is_continuous:
            # one-hot -> index actions
            actions = actions.argmax(dim=1).detach().cpu().numpy()
        else:  # [1, envs, *act_shape]
            actions = actions.reshape(obs.shape[1], *self.act_shape).detach().cpu().numpy()
            # Scaling to the env's action bounds is done in
            # xuance.environment.utils.wrapper.XuanCeEnvWrapper.step.
        return actions

    def train_epochs(self, n_epochs: int = 1):
        """Run ``n_epochs`` learner updates on one sampled batch of sequences.

        One sample of layout ``(envs, seq, batch, ...)`` is drawn. Update ``k`` uses the slice
        belonging to environment ``k % num_envs``, which spreads the gradient steps across envs.

        Returns:
            dict: Info of the last learner update, or ``{}`` when ``n_epochs <= 0``.
        """
        train_info = {}
        if n_epochs <= 0:  # nothing to update; skip the replay sample entirely
            return train_info
        samples = self.memory.sample(self.config.seq_len)  # (envs, seq, batch, ~)
        if self.config.pixel:
            samples['obs'] = samples['obs'].transpose(0, 1, 2, 5, 3, 4) / 255.0 - 0.5
        num_envs = self.train_envs.num_envs
        for epoch in range(n_epochs):
            cur_samples = {k: v[epoch % num_envs] for k, v in samples.items()}
            train_info = self.learner.update(**cur_samples)
        return train_info

    def train(self, train_steps):
        """Step the training envs ``train_steps`` times, storing data and training the model.

        Recurrent player states and ``self.train_states`` persist across calls. Only the
        environments whose episode ended are reset. Once ``current_step`` exceeds
        ``start_training``, gradient steps keep
        ``gradient_step ~= (current_step - start_training) * replay_ratio``.

        Args:
            train_steps: Number of vectorized environment steps to run.

        Returns:
            dict: Episode info and learner update info gathered during this call.
        """
        train_info = {}
        obs, rews, terms, truncs, is_first = self.train_states
        for _ in tqdm(range(train_steps)):
            self.obs_rms.update(obs)
            obs = self._process_observation(obs)
            if self.current_step < self.start_training:  # random actions before training starts
                acts = np.array([self.train_envs.action_space.sample() for _ in range(self.train_envs.num_envs)])
            else:
                acts = self.get_actions(obs)
            if self.atari:  # use truncs as terminals to train in xc_atari
                terms = deepcopy(truncs)
            # Store (o_t, a_t, r_t, term_t, trunc_t, is_first_t); acts are the real actions.
            self.memory.store(obs, acts, self._process_reward(rews), terms, truncs, is_first)
            next_obs, rews, terms, truncs, infos = self.train_envs.step(acts)
            self.callback.on_train_step(self.current_step, envs=self.train_envs, policy=self.policy,
                                        obs=obs, acts=acts, next_obs=next_obs, rewards=rews,
                                        terminals=terms, truncations=truncs, infos=infos,
                                        train_steps=train_steps)
            # is_first is cleared after a step; it is set again below for reset envs.
            is_first = np.zeros_like(terms)
            obs = next_obs
            self.returns = self.gamma * self.returns + rews
            done_idxes = []
            for i in range(self.n_envs):
                if not (terms[i] or truncs[i]):
                    continue
                # `not` rather than `~`: `~` on a Python bool is always truthy, on floats it raises.
                if self.atari and not truncs[i]:  # Atari: do not end the episode until truncation
                    continue
                # The reset itself is carried out after this loop for all finished envs at once.
                done_idxes.append(i)
                self.ret_rms.update(self.returns[i:i + 1])
                self.returns[i] = 0.0
                self.current_episode[i] += 1
                if self.use_wandb:
                    episode_info = {
                        f"Episode-Steps/rank_{self.rank}/env-{i}": infos[i]["episode_step"],
                        f"Train-Episode-Rewards/rank_{self.rank}/env-{i}": infos[i]["episode_score"]
                    }
                else:
                    episode_info = {
                        f"Episode-Steps/rank_{self.rank}": {f"env-{i}": infos[i]["episode_step"]},
                        f"Train-Episode-Rewards/rank_{self.rank}": {f"env-{i}": infos[i]["episode_score"]}
                    }
                self.log_infos(episode_info, self.current_step)
                train_info.update(episode_info)
                self.callback.on_train_episode_info(envs=self.train_envs, policy=self.policy, env_id=i,
                                                    infos=infos, rank=self.rank, use_wandb=self.use_wandb,
                                                    current_step=self.current_step,
                                                    current_episode=self.current_episode,
                                                    train_steps=train_steps)
            self.current_step += self.n_envs
            if len(done_idxes) > 0:
                # Store the last data of the finished episodes and reset them:
                # (o_t, a_t = 0 for dones, r_t, term_t, trunc_t, is_first_t)
                extra_shape = () if not self.is_continuous else self.act_shape
                acts[done_idxes] = np.zeros((len(done_idxes),) + extra_shape)
                if self.atari:  # use truncs as terminals to train in xc_atari
                    terms = deepcopy(truncs)
                self.memory.store(obs, acts, self._process_reward(rews), terms, truncs, is_first)
                # Reset the finished envs' observations, flags and the player's recurrent states.
                obs[done_idxes] = np.stack([infos[idx]["reset_obs"] for idx in done_idxes])
                self.train_envs.buf_obs[done_idxes] = obs[done_idxes]
                rews[done_idxes] = np.zeros((len(done_idxes), ))
                terms[done_idxes] = np.zeros((len(done_idxes), ))
                truncs[done_idxes] = np.zeros((len(done_idxes), ))
                is_first[done_idxes] = np.ones_like(terms[done_idxes])
                self.train_player.init_states(done_idxes)
            # Start training: keep gradient_step ~= (current_step - start_training) * replay_ratio.
            if self.current_step > self.start_training:
                n_epochs = max(int((self.current_step - self.start_training) * self.replay_ratio - self.gradient_step), 0)
                update_info = self.train_epochs(n_epochs=n_epochs)
                self.gradient_step += n_epochs
                if update_info is not None:
                    self.log_infos(update_info, self.current_step)
                    # Previously `train_info.update(train_info)`, which dropped the update info.
                    train_info.update(update_info)
                self.callback.on_train_epochs_end(self.current_step, policy=self.policy, memory=self.memory,
                                                  current_episode=self.current_episode, train_steps=train_steps,
                                                  update_info=update_info)
            self.callback.on_train_step_end(self.current_step, envs=self.train_envs, policy=self.policy,
                                            train_steps=train_steps, train_info=train_info)
        # Save the train_states for the next train() call.
        self.train_states = [obs, rews, terms, truncs, is_first]
        return train_info

    def test(self,
             test_episodes: int,
             test_envs: Optional[Union[DummyVecEnv, SubprocVecEnv]] = None,
             close_envs: bool = True) -> list:
        """Evaluate a deep copy of the training player until ``test_episodes`` episodes finish.

        Args:
            test_episodes: Number of finished episodes to collect.
            test_envs: Vectorized evaluation environments (required).
            close_envs: Whether to close ``test_envs`` after evaluation.

        Returns:
            list: The ``episode_score`` of every finished episode.

        Raises:
            ValueError: If ``test_envs`` is None.
        """
        if test_envs is None:
            raise ValueError("`test_envs` must be provided for evaluation.")
        num_envs = test_envs.num_envs
        # Copy the player so evaluation does not disturb the training recurrent states.
        test_player = deepcopy(self.train_player)
        test_player.init_states(num_envs=num_envs)
        videos, episode_videos, images = [[] for _ in range(num_envs)], [], None
        current_episode, current_step, scores, best_score = 0, 0, [], -np.inf
        obs, infos = test_envs.reset()
        if self.config.render_mode == "rgb_array" and self.render:
            images = test_envs.render(self.config.render_mode)
            for idx, img in enumerate(images):
                videos[idx].append(img)
        while current_episode < test_episodes:
            self.obs_rms.update(obs)
            obs = self._process_observation(obs)
            acts = self.get_actions(obs, test_mode=True, player=test_player)
            next_obs, rews, terms, truncs, infos = test_envs.step(acts)
            if self.config.render_mode == "rgb_array" and self.render:
                images = test_envs.render(self.config.render_mode)
                for idx, img in enumerate(images):
                    videos[idx].append(img)
            self.callback.on_test_step(envs=test_envs, policy=self.policy, images=images,
                                       obs=obs, next_obs=next_obs, rewards=rews,
                                       terminals=terms, truncations=truncs, infos=infos,
                                       current_train_step=self.current_step,
                                       current_step=current_step, current_episode=current_episode)
            obs = deepcopy(next_obs)
            done_idxes = []
            for i in range(num_envs):
                if not (terms[i] or truncs[i]):
                    continue
                if self.atari and not truncs[i]:  # Atari: do not end the episode until truncation
                    continue
                done_idxes.append(i)
                obs[i] = infos[i]["reset_obs"]
                scores.append(infos[i]["episode_score"])
                current_episode += 1
                if best_score < infos[i]["episode_score"]:
                    best_score = infos[i]["episode_score"]
                    episode_videos = videos[i].copy()
            current_step += num_envs
            if len(done_idxes) > 0:
                test_player.init_states(reset_envs=done_idxes, num_envs=num_envs)
        if self.config.render_mode == "rgb_array" and self.render:
            # time, height, width, channel -> time, channel, height, width
            videos_info = {"Videos_Test": np.array([episode_videos], dtype=np.uint8).transpose((0, 1, 4, 2, 3))}
            self.log_videos(info=videos_info, fps=self.fps, x_index=self.current_step)  # fps cannot work
        test_info = {
            "Test-Episode-Rewards/Mean-Score": np.mean(scores),
            "Test-Episode-Rewards/Std-Score": np.std(scores)
        }
        self.log_infos(test_info, self.current_step)
        # The callback sees the envs still open; they are closed only if requested.
        self.callback.on_test_end(envs=test_envs, policy=self.policy,
                                  current_train_step=self.current_step,
                                  current_step=current_step, current_episode=current_episode,
                                  scores=scores, best_score=best_score)
        if close_envs:
            test_envs.close()
        return scores