import torch
import torch.nn as nn
from argparse import Namespace
from gymnasium.spaces import Space
from xuance.common import Optional, BaseCallback
from xuance.torch import REGISTRY_Policy
from xuance.torch.agents import OffPolicyAgent
from xuance.torch.learners.contrastive_unsupervised_rl.spr_learner import SPR_Learner
from xuance.torch.utils import NormalizeFunctions, ActivationFunctions
from xuance.environment import DummyVecEnv, SubprocVecEnv
class SPR_Agent(OffPolicyAgent):
    """Off-policy, epsilon-greedy Q-learning agent for SPR.

    SPR presumably stands for Self-Predictive Representations; the learner is
    imported from ``contrastive_unsupervised_rl``. This agent wires together:
    the exploration schedule, an ``SPR_Policy`` wrapping the registered policy
    network, the replay memory (built by the base class), and an ``SPR_Learner``.

    Args:
        config: Experiment configuration. Besides whatever ``OffPolicyAgent``
            reads, this class reads ``start_greedy``, ``end_greedy``,
            ``decay_step_greedy``, ``activation``, ``representation``, ``policy``,
            ``q_hidden_size``, the optional ``normalize``, and the learner settings
            ``temperature``, ``tau``, ``repr_lr`` and ``prediction_steps``.
        envs: Vectorised environments, forwarded to ``OffPolicyAgent``.
        observation_space: Observation space, forwarded to ``OffPolicyAgent``.
        action_space: Action space, forwarded to ``OffPolicyAgent``.
        callback: Optional callback, forwarded to the base class and to the learner.
    """

    def __init__(
            self,
            config: Namespace,
            envs: Optional[DummyVecEnv | SubprocVecEnv] = None,
            observation_space: Optional[Space] = None,
            action_space: Optional[Space] = None,
            callback: Optional[BaseCallback] = None
    ):
        super().__init__(config, envs, observation_space, action_space, callback)
        # The exploration schedule needs self.n_envs, which the base class sets up.
        self._init_exploration_params(config)
        self.policy = self._build_policy()
        self.memory = self._build_memory()
        self.learner = self._build_learner(self.config, self.policy, self.callback)

    def _init_exploration_params(self, config: Namespace):
        """Initialise the linearly decaying epsilon-greedy parameters.

        Sets ``self.e_greedy`` to ``config.start_greedy``. Sets
        ``self.e_greedy_decay`` to ``(start_greedy - end_greedy)`` divided by
        ``decay_step_greedy / n_envs``, i.e. the decrement per vectorised step.
        NOTE(review): the code that applies the decay is not in this file;
        confirm it steps once per vectorised step.

        Raises:
            ValueError: if ``config.decay_step_greedy`` is not positive. A zero
                value would otherwise raise ZeroDivisionError, and a negative
                value would make epsilon grow instead of decay.
        """
        if config.decay_step_greedy <= 0:
            raise ValueError(
                f"decay_step_greedy must be positive, got {config.decay_step_greedy!r}")
        self.e_greedy = config.start_greedy
        self.e_greedy_decay = (config.start_greedy - config.end_greedy) / (config.decay_step_greedy / self.n_envs)

    def _build_policy(self) -> nn.Module:
        """Build the registered policy network and wrap it in an ``SPR_Policy``.

        The representation comes from the base-class helper
        ``_build_representation``. The inner policy is looked up by name in
        ``REGISTRY_Policy[config.policy]`` and uses orthogonal initialisation.
        ``SPR_Policy`` is defined later in this module; the name is resolved
        when this method runs, so the forward reference is fine.

        Returns:
            An ``SPR_Policy`` exposing the online/target encoders and Q-heads.
        """
        # `normalize` is optional in the config; when absent no normalisation layer is used.
        normalize_fn = NormalizeFunctions[self.config.normalize] if hasattr(self.config, "normalize") else None
        activation = ActivationFunctions[self.config.activation]
        initializer = torch.nn.init.orthogonal_
        representation = self._build_representation(self.config.representation, self.observation_space, self.config)
        policy = REGISTRY_Policy[self.config.policy](
            action_space=self.action_space, representation=representation, hidden_size=self.config.q_hidden_size,
            normalize=normalize_fn, initialize=initializer, activation=activation, device=self.device,
            use_distributed_training=self.distributed_training)
        return SPR_Policy(
            device=self.device,
            policy=policy,
        )

    def _build_learner(self, config: Namespace, policy: nn.Module, callback: Optional[BaseCallback] = None):
        """Create the ``SPR_Learner`` that optimises ``policy``.

        Args:
            config: Configuration; ``temperature``, ``tau``, ``repr_lr`` and
                ``prediction_steps`` are read from it explicitly.
            policy: The policy returned by ``_build_policy``.
            callback: Optional callback forwarded to the learner.

        Returns:
            The constructed ``SPR_Learner``.
        """
        return SPR_Learner(
            config=config,
            policy=policy,
            temperature=config.temperature,
            tau=config.tau,
            repr_lr=config.repr_lr,
            prediction_steps=config.prediction_steps,
            callback=callback,
        )
class SPR_Policy(nn.Module):
    """Wrapper exposing the online/target encoders and Q-heads of an inner policy.

    Args:
        device: Device the wrapped sub-modules are moved to.
        policy: Inner policy module. It must expose ``representation``,
            ``target_representation``, ``eval_Qhead``, ``target_Qhead`` and
            ``action_dim``.

    Raises:
        ValueError: if ``policy`` is None.
    """

    def __init__(self,
                 device: str = 'cuda:0',
                 policy: nn.Module = None, ):
        super().__init__()
        if policy is None:
            # Fail with a clear message instead of an AttributeError on `None.representation`.
            raise ValueError("SPR_Policy requires a `policy` module, got None")
        self.device = device
        self.policy = policy
        # Module.to() moves in place and returns the same module, so these attributes
        # alias the sub-modules held by `self.policy`.
        self.representation = self.policy.representation.to(device)
        self.target_representation = self.policy.target_representation.to(device)
        self.q_net = self.policy.eval_Qhead.to(device)
        self.target_q_net = self.policy.target_Qhead.to(device)
        self.action_dim = policy.action_dim
        # The target Q-head is only changed through copy_target(), never by gradients.
        for param in self.target_q_net.parameters():
            param.requires_grad = False

    def forward(self, x: torch.Tensor):
        """Run the online encoder and online Q-head.

        Args:
            x: Observation batch passed straight to ``self.representation``.

        Returns:
            Tuple ``(features, argmax_action, evalQ)``:
            - ``features``: whatever the representation returns (only
              ``features['state']`` is used here);
            - ``argmax_action``: greedy action indices (argmax over the last dim);
            - ``evalQ``: Q-values from the online head.
        """
        features = self.representation(x)
        evalQ = self.q_net(features['state'])
        argmax_action = evalQ.argmax(dim=-1)
        return features, argmax_action, evalQ

    def target(self, x: torch.Tensor):
        """Run the target encoder and target Q-head without tracking gradients.

        Args:
            x: Observation batch passed to ``self.target_representation``.

        Returns:
            Tuple ``(features, argmax_action, targetQ)`` with the same layout as
            ``forward``, computed entirely from the target networks.
        """
        with torch.no_grad():
            features = self.target_representation(x)
            # Use the target Q-head; the online head would bypass the target network
            # that copy_target() keeps in sync.
            targetQ = self.target_q_net(features['state'])
            argmax_action = targetQ.argmax(dim=-1)
        return features, argmax_action, targetQ

    def copy_target(self):
        """Hard-copy the online Q-head's weights into the target Q-head.

        ``target_representation`` is not touched here.
        NOTE(review): presumably it is updated elsewhere, e.g. by the learner
        using ``tau``; verify.
        """
        self.target_q_net.load_state_dict(self.q_net.state_dict())