Source code for xuance.torch.agents.qlearning_family.qrdqn_agent
import torch
from argparse import Namespace
from gymnasium.spaces import Space
from xuance.common import Optional, BaseCallback
from xuance.environment import DummyVecEnv, SubprocVecEnv
from xuance.torch import Module
from xuance.torch.utils import NormalizeFunctions, ActivationFunctions
from xuance.torch.policies import REGISTRY_Policy
from xuance.torch.agents.qlearning_family.dqn_agent import DQN_Agent
[docs]
class QRDQN_Agent(DQN_Agent):
    """Quantile-regression DQN (QR-DQN) agent.

    All machinery comes from ``DQN_Agent``; this subclass only overrides
    ``_build_policy`` so that the ``"QR_Q_network"`` policy is constructed.

    Args:
        config: Namespace carrying hyperparameters and other settings.
        envs: the vectorized environments (optional).
        observation_space: optional observation space, forwarded as-is to
            ``DQN_Agent.__init__``.
        action_space: optional action space, forwarded as-is to
            ``DQN_Agent.__init__``.
        callback: optional user-defined callback object used to inject custom
            logic during training.
    """

    def __init__(
        self,
        config: Namespace,
        envs: Optional[DummyVecEnv | SubprocVecEnv] = None,
        observation_space: Optional[Space] = None,
        action_space: Optional[Space] = None,
        callback: Optional[BaseCallback] = None
    ):
        # All arguments are passed positionally, in their declared order, to the base class.
        super().__init__(config, envs, observation_space, action_space, callback)

    def _build_policy(self) -> Module:
        """Create the ``QR_Q_network`` policy described by ``self.config``.

        Config fields read: ``normalize`` (only if present), ``activation``,
        ``representation``, ``policy``, ``quantile_num``, ``q_hidden_size``,
        and ``agent`` (for the error message only).

        Returns:
            The module produced by ``REGISTRY_Policy["QR_Q_network"]``.

        Raises:
            AttributeError: if ``config.policy`` is not ``"QR_Q_network"``.
                The representation network is built before this check, so it
                is still constructed when the error is raised.
        """
        cfg = self.config

        # Normalization is optional: it is looked up only when the config defines it.
        normalize_fn = None
        if hasattr(cfg, "normalize"):
            normalize_fn = NormalizeFunctions[cfg.normalize]
        weight_init = torch.nn.init.orthogonal_
        activation_fn = ActivationFunctions[cfg.activation]
        target_device = self.device

        # Representation network; it is handed to the policy constructor below.
        representation = self._build_representation(cfg.representation, self.observation_space, cfg)

        if cfg.policy != "QR_Q_network":
            raise AttributeError(f"{cfg.agent} currently does not support the policy named {cfg.policy}.")

        return REGISTRY_Policy["QR_Q_network"](
            action_space=self.action_space,
            quantile_num=cfg.quantile_num,
            representation=representation,
            hidden_size=cfg.q_hidden_size,
            normalize=normalize_fn,
            initialize=weight_init,
            activation=activation_fn,
            device=target_device,
            use_distributed_training=self.distributed_training,
        )