Source code for xuance.torch.agents.contrastive_unsupervised_rl.drq_agent
import torch
from argparse import Namespace
from gymnasium.spaces import Space
from xuance.common import Optional, BaseCallback
from xuance.environment import DummyVecEnv, SubprocVecEnv
from xuance.torch import Module
from xuance.torch.utils import NormalizeFunctions, ActivationFunctions
from xuance.torch.policies import REGISTRY_Policy
from xuance.torch.agents import OffPolicyAgent
[docs]
class DrQ_Agent(OffPolicyAgent):
    """DrQ agent: an off-policy, Q-network based agent with epsilon-greedy settings.

    NOTE(review): "DrQ" presumably stands for Data-regularized Q-learning;
    confirm against the learner implementation, which is not visible here.

    Args:
        config: the Namespace variable that provides hyper-parameters and other
            settings. This class reads ``start_greedy``, ``end_greedy``,
            ``decay_step_greedy``, ``activation``, ``representation``,
            ``policy``, ``q_hidden_size``, ``agent`` and, when present,
            ``normalize``.
        envs: the vectorized environments (forwarded to ``OffPolicyAgent``).
        observation_space: the observation space (forwarded to ``OffPolicyAgent``).
        action_space: the action space (forwarded to ``OffPolicyAgent``).
        callback: an optional callback object (forwarded to ``OffPolicyAgent``).
    """

    def __init__(
            self,
            config: Namespace,
            envs: Optional[DummyVecEnv | SubprocVecEnv] = None,
            observation_space: Optional[Space] = None,
            action_space: Optional[Space] = None,
            callback: Optional[BaseCallback] = None
    ):
        super().__init__(config, envs, observation_space, action_space, callback)

        # Epsilon-greedy settings: start value, end value, current value.
        self.start_greedy = config.start_greedy
        self.end_greedy = config.end_greedy
        self.e_greedy = config.start_greedy

        # Per-step epsilon decrement; the decay horizon is divided by the
        # number of environments (presumably because every vectorized step
        # advances n_envs transitions at once -- confirm against the trainer).
        decay_steps = config.decay_step_greedy / self.n_envs
        self.delta_egreedy = (self.start_greedy - self.end_greedy) / decay_steps

        # Construction order matters: the learner is handed the policy.
        self.policy = self._build_policy()
        self.memory = self._build_memory()
        self.learner = self._build_learner(self.config, self.policy)

    def _build_policy(self) -> Module:
        """Build the Q-network policy selected by ``config.policy``.

        Returns:
            The policy created from ``REGISTRY_Policy["Basic_Q_network"]``.

        Raises:
            AttributeError: if ``config.policy`` is not ``"Basic_Q_network"``.
                The representation has already been built when this is raised.
        """
        # Normalisation is optional: only looked up when the config defines it.
        if hasattr(self.config, "normalize"):
            normalize_fn = NormalizeFunctions[self.config.normalize]
        else:
            normalize_fn = None
        initializer = torch.nn.init.orthogonal_
        activation = ActivationFunctions[self.config.activation]
        device = self.device

        # Build the representation network first.
        representation = self._build_representation(
            self.config.representation, self.observation_space, self.config)

        # Only the basic Q-network policy is supported by this agent.
        if self.config.policy != "Basic_Q_network":
            raise AttributeError(f"{self.config.agent} does not support the policy named {self.config.policy}.")

        policy_cls = REGISTRY_Policy["Basic_Q_network"]
        return policy_cls(
            action_space=self.action_space,
            representation=representation,
            hidden_size=self.config.q_hidden_size,
            normalize=normalize_fn,
            initialize=initializer,
            activation=activation,
            device=device,
            use_distributed_training=self.distributed_training,
        )