import torch
import numpy as np
from tqdm import tqdm
from copy import deepcopy
from argparse import Namespace
from gymnasium.spaces import Space
from xuance.common import Optional, BaseCallback
from xuance.environment import DummyVecEnv, SubprocVecEnv
from xuance.torch import Module
from xuance.torch.utils import NormalizeFunctions, ActivationFunctions, split_distributions
from xuance.torch.policies import REGISTRY_Policy
from xuance.torch.agents import OnPolicyAgent
class PPG_Agent(OnPolicyAgent):
    """The implementation of PPG (Phasic Policy Gradient) agent.

    Args:
        config: the Namespace variable that provides hyperparameters and other settings.
        envs: the vectorized environments.
        observation_space: the observation space, forwarded to ``OnPolicyAgent``.
        action_space: the action space, forwarded to ``OnPolicyAgent``.
        callback: A user-defined callback function object to inject custom logic during training.
    """

    def __init__(
            self,
            config: Namespace,
            envs: Optional[DummyVecEnv | SubprocVecEnv] = None,
            observation_space: Optional[Space] = None,
            action_space: Optional[Space] = None,
            callback: Optional[BaseCallback] = None
    ):
        super().__init__(config, envs, observation_space, action_space, callback)
        # Number of epochs for each of the three PPG update phases.
        self.policy_nepoch = config.policy_nepoch
        self.value_nepoch = config.value_nepoch
        self.aux_nepoch = config.aux_nepoch
        # The buffer keeps an extra "old_dist" field used by the auxiliary phase.
        self.auxiliary_info_shape = {"old_dist": None}
        self.memory = self._build_memory(self.auxiliary_info_shape)  # build memory
        self.policy = self._build_policy()  # build policy
        self.learner = self._build_learner(self.config, self.policy, self.callback)  # build learner.

    def _build_policy(self) -> Module:
        """Build the PPG policy network described by ``self.config``.

        Returns:
            Module: a ``Categorical_PPG`` or ``Gaussian_PPG`` policy from ``REGISTRY_Policy``.

        Raises:
            AttributeError: if ``config.policy`` names a policy PPG does not support.
        """
        normalize_fn = NormalizeFunctions[self.config.normalize] if hasattr(self.config, "normalize") else None
        initializer = torch.nn.init.orthogonal_
        activation = ActivationFunctions[self.config.activation]
        device = self.device
        # build representation.
        representation = self._build_representation(self.config.representation, self.observation_space, self.config)
        # build policy.
        policy_name = self.config.policy
        if policy_name not in ("Categorical_PPG", "Gaussian_PPG"):
            raise AttributeError(f"PPG currently does not support the policy named {self.config.policy}.")
        # Keyword arguments shared by every supported PPG policy.
        policy_kwargs = dict(
            action_space=self.action_space, representation=representation,
            actor_hidden_size=self.config.actor_hidden_size, critic_hidden_size=self.config.critic_hidden_size,
            normalize=normalize_fn, initialize=initializer, activation=activation, device=device,
            use_distributed_training=self.distributed_training)
        if policy_name == "Gaussian_PPG":
            # Continuous policies additionally need an output activation for the action head.
            policy_kwargs["activation_action"] = ActivationFunctions[self.config.activation_action]
        return REGISTRY_Policy[policy_name](**policy_kwargs)

    def get_actions(self, observations: np.ndarray,
                    return_dists: bool = False, return_logpi: bool = False):
        """Sample actions from the current policy and evaluate state values.

        Parameters:
            observations (np.ndarray): The observations fed to the policy.
            return_dists (bool): Whether to return the split policy distributions.
            return_logpi (bool): Whether to return the log-probabilities of the sampled actions.

        Returns:
            dict: with keys
                "actions" (np.ndarray): the stochastically sampled actions;
                "values" (np.ndarray): the critic's value estimates;
                "dists": ``split_distributions(policy_dists)`` if ``return_dists`` else None;
                "log_pi": the log-probabilities (not detached, not converted to numpy)
                    if ``return_logpi`` else None.
        """
        _, policy_dists, values, _ = self.policy(observations)
        actions = policy_dists.stochastic_sample()
        log_pi = policy_dists.log_prob(actions) if return_logpi else None
        dists = split_distributions(policy_dists) if return_dists else None
        actions = actions.detach().cpu().numpy()
        values = values.detach().cpu().numpy()
        return {"actions": actions, "values": values, "dists": dists, "log_pi": log_pi}

    def get_aux_info(self, policy_output: dict = None):
        """Returns auxiliary information to be stored in the buffer.

        Parameters:
            policy_output (dict): The output of ``get_actions``; must contain the key "dists".

        Returns:
            aux_info (dict): ``{"old_dist": policy_output["dists"]}``.
        """
        aux_info = {"old_dist": policy_output["dists"]}
        return aux_info

    def _run_minibatch_epochs(self, n_epochs, update_fn, indexes, update_info):
        """Run ``n_epochs`` shuffled mini-batch passes over the buffer with ``update_fn``.

        ``indexes`` is shuffled in place at the start of every epoch and split into chunks of
        ``self.batch_size``. The dict returned by each ``update_fn(**samples)`` call is merged
        into ``update_info``.
        """
        for _ in range(n_epochs):
            np.random.shuffle(indexes)
            for start in range(0, self.buffer_size, self.batch_size):
                samples = self.memory.sample(indexes[start:start + self.batch_size])
                update_info.update(update_fn(**samples))

    def train(self, train_steps):
        """Interact with the training environments and run PPG updates.

        Whenever the buffer is full, three phases run over it:
            1. the policy phase (``learner.update_policy``, ``policy_nepoch`` epochs);
            2. the value phase (``learner.update_critic``, ``value_nepoch`` epochs);
            3. after refreshing the stored "old_dist" with the updated policy, the auxiliary
               phase (``learner.update_auxiliary``, ``aux_nepoch`` epochs).

        Args:
            train_steps: number of ``train_envs.step`` calls; ``current_step`` advances by
                ``n_envs`` per call.

        Returns:
            dict: the merged update and episode information logged during this call.
        """
        train_info = {}
        obs = self.train_envs.buf_obs
        for _ in tqdm(range(train_steps)):
            self.obs_rms.update(obs)
            obs = self._process_observation(obs)
            policy_out = self.get_actions(obs, return_dists=True, return_logpi=False)
            acts, rets = policy_out['actions'], policy_out['values']
            next_obs, rewards, terminals, truncations, infos = self.train_envs.step(acts)
            aux_info = self.get_aux_info(policy_out)
            self.callback.on_train_step(self.current_step, envs=self.train_envs, policy=self.policy,
                                        obs=obs, policy_out=policy_out, acts=acts, next_obs=next_obs, rewards=rewards,
                                        terminals=terminals, truncations=truncations, infos=infos,
                                        train_steps=train_steps, rets=rets, aux_info=aux_info)
            self.memory.store(obs, acts, self._process_reward(rewards), rets, terminals, aux_info)
            if self.memory.full:
                # Close every open path: bootstrap from next_obs unless the env terminated.
                vals = self.get_terminated_values(next_obs, rewards)
                for i in range(self.n_envs):
                    if terminals[i]:
                        self.memory.finish_path(0.0, i)
                    else:
                        self.memory.finish_path(vals[i], i)
                update_info = {}
                # One index array is shared (and shuffled in place) by all three phases.
                indexes = np.arange(self.buffer_size)
                # policy update
                self._run_minibatch_epochs(self.policy_nepoch, self.learner.update_policy, indexes, update_info)
                # critic update
                self._run_minibatch_epochs(self.value_nepoch, self.learner.update_critic, indexes, update_info)
                # update old_prob: record the post-update policy as "old_dist" for the auxiliary phase.
                # No gradients are needed here, so avoid building a graph over the whole buffer.
                with torch.no_grad():
                    new_policy_out = self.get_actions(self.memory.observations, return_dists=True)
                self.memory.auxiliary_infos.update(self.get_aux_info(new_policy_out))
                # auxiliary update
                self._run_minibatch_epochs(self.aux_nepoch, self.learner.update_auxiliary, indexes, update_info)
                self.log_infos(update_info, self.current_step)
                train_info.update(update_info)
                self.callback.on_train_epochs_end(self.current_step, policy=self.policy, memory=self.memory,
                                                  current_episode=self.current_episode, train_steps=train_steps,
                                                  update_info=update_info)
                self.memory.clear()
            self.returns = self.gamma * self.returns + rewards
            obs = deepcopy(next_obs)
            for i in range(self.n_envs):
                if terminals[i] or truncations[i]:
                    self.ret_rms.update(self.returns[i:i + 1])
                    self.returns[i] = 0.0
                    # On Atari, the path is only finalized here when the episode was truncated.
                    # Logical `not` is used (the original used `~`): for plain Python bools
                    # `~False == -1` and `~True == -2` are both truthy, which broke the check.
                    if not self.atari or truncations[i]:
                        if terminals[i]:
                            self.memory.finish_path(0, i)
                        else:
                            vals = self.get_terminated_values(next_obs, rewards)
                            self.memory.finish_path(vals[i], i)
                        obs[i] = infos[i]["reset_obs"]
                        self.train_envs.buf_obs[i] = obs[i]
                        self.current_episode[i] += 1
                        if self.use_wandb:
                            episode_info = {
                                f"Episode-Steps/rank_{self.rank}/env-{i}": infos[i]["episode_step"],
                                f"Train-Episode-Rewards/rank_{self.rank}/env-{i}": infos[i]["episode_score"]
                            }
                        else:
                            episode_info = {
                                f"Episode-Steps/rank_{self.rank}": {f"env-{i}": infos[i]["episode_step"]},
                                f"Train-Episode-Rewards/rank_{self.rank}": {f"env-{i}": infos[i]["episode_score"]}
                            }
                        self.log_infos(episode_info, self.current_step)
                        train_info.update(episode_info)
                        self.callback.on_train_episode_info(envs=self.train_envs, policy=self.policy, env_id=i,
                                                            infos=infos, rank=self.rank, use_wandb=self.use_wandb,
                                                            current_step=self.current_step,
                                                            current_episode=self.current_episode,
                                                            train_steps=train_steps)
            self.current_step += self.n_envs
            self.callback.on_train_step_end(self.current_step, envs=self.train_envs, policy=self.policy,
                                            train_steps=train_steps, train_info=train_info)
        return train_info