import torch
from argparse import Namespace
from torch import nn
from gymnasium.spaces import Space
from xuance.common import Optional, BaseCallback
from xuance.environment import DummyVecEnv, SubprocVecEnv
from xuance.torch import REGISTRY_Policy
from xuance.torch.learners import CURL_Learner
from xuance.torch.utils import NormalizeFunctions, ActivationFunctions
from xuance.torch.agents import OffPolicyAgent
[docs]
class CURL_Agent(OffPolicyAgent):
    """Off-policy agent that wraps a registry Q policy and trains it with ``CURL_Learner``.

    The learner reports a ``curl_loss`` alongside a ``q_loss`` (surfaced as
    ``dqn_loss`` by ``_learn``).

    Args:
        config: Hyper-parameters. Fields read directly by this class are
            ``start_greedy``, ``end_greedy``, ``decay_step_greedy``,
            ``activation``, ``representation``, ``policy``, ``q_hidden_size``
            and, when present, ``normalize``.
        envs: Vectorised environments, forwarded to the base class.
        observation_space: Optional observation space, forwarded to the base class.
        action_space: Optional action space, forwarded to the base class.
        callback: Optional callback, forwarded to the base class and the learner.
    """

    def __init__(
            self,
            config: Namespace,
            envs: Optional[DummyVecEnv | SubprocVecEnv] = None,
            observation_space: Optional[Space] = None,
            action_space: Optional[Space] = None,
            callback: Optional[BaseCallback] = None
    ):
        super().__init__(config, envs, observation_space, action_space, callback)
        # The builders below read attributes supplied by the base initialiser
        # (config, spaces, device, n_envs); the learner needs the finished policy.
        self._init_exploration_params(config)
        self.policy = self._build_policy()
        self.memory = self._build_memory()
        self.learner = self._build_learner(self.config, self.policy, self.callback)

    def _init_exploration_params(self, config: Namespace):
        """Set the initial epsilon and its decrement.

        ``e_greedy`` starts at ``config.start_greedy``. ``e_greedy_decay`` is the
        gap ``start_greedy - end_greedy`` divided by ``decay_step_greedy / n_envs``
        (presumably one decrement per vectorised step; the decay itself is
        applied elsewhere).
        """
        start = config.start_greedy
        self.e_greedy = start
        span = start - config.end_greedy
        steps = config.decay_step_greedy / self.n_envs
        self.e_greedy_decay = span / steps

    def _build_policy(self) -> nn.Module:
        """Build the registry Q policy and wrap it in a ``CURL_Policy``.

        Returns:
            A ``CURL_Policy`` exposing the online/target encoders and Q heads.
        """
        cfg = self.config
        normalizer = None
        if hasattr(cfg, "normalize"):
            normalizer = NormalizeFunctions[cfg.normalize]
        act_fn = ActivationFunctions[cfg.activation]
        init_fn = torch.nn.init.orthogonal_
        encoder = self._build_representation(cfg.representation, self.observation_space, cfg)
        policy_cls = REGISTRY_Policy[cfg.policy]
        inner = policy_cls(
            action_space=self.action_space,
            representation=encoder,
            hidden_size=cfg.q_hidden_size,
            normalize=normalizer,
            initialize=init_fn,
            activation=act_fn,
            device=self.device,
            use_distributed_training=self.distributed_training,
        )
        return CURL_Policy(device=self.device, policy=inner)

    def _build_learner(self, config: Namespace, policy: nn.Module, callback: Optional[BaseCallback] = None):
        """Create the ``CURL_Learner`` that optimises ``policy`` under ``config``."""
        learner = CURL_Learner(config=config, policy=policy, callback=callback)
        return learner

    def _learn(self, batch_size: int):
        """Sample one batch, run a learner update and return selected metrics.

        Args:
            batch_size: Number of transitions requested from ``self.memory``.

        Returns:
            dict with ``curl_loss``, ``dqn_loss`` (the learner's ``q_loss``),
            ``predictQ`` and ``learning_rate``.
        """
        batch = self.memory.sample(batch_size)
        device = self.device
        # (learner keyword, key in the sampled batch)
        field_map = (
            ("obs", "obs"),
            ("actions", "actions"),
            ("rewards", "rewards"),
            ("obs_next", "next_obs"),
        )
        samples = {arg: torch.as_tensor(batch[key], device=device) for arg, key in field_map}
        # Done flags are cast to float, presumably for arithmetic use in the TD target.
        samples["terminals"] = torch.as_tensor(batch["dones"], dtype=torch.float, device=device)
        info = self.learner.update(**samples)
        # output metric name -> key reported by the learner
        metric_map = {
            "curl_loss": "curl_loss",
            "dqn_loss": "q_loss",
            "predictQ": "predictQ",
            "learning_rate": "learning_rate",
        }
        return {name: info[key] for name, key in metric_map.items()}
[docs]
class CURL_Policy(nn.Module):
    """Wrapper exposing the online/target encoders and Q heads of an inner Q policy.

    The wrapped ``policy`` must provide ``representation``,
    ``target_representation``, ``eval_Qhead``, ``target_Qhead`` and
    ``action_dim``. The representation modules return a mapping whose
    ``'state'`` entry is fed to the Q heads.

    Args:
        device: Device the sub-modules are moved to.
        policy: Inner Q-network policy whose sub-modules are aliased here.

    Raises:
        ValueError: If ``policy`` is None.
    """

    def __init__(self,
                 device: str = 'cuda:0',
                 policy: nn.Module = None,):
        super(CURL_Policy, self).__init__()
        if policy is None:
            raise ValueError("CURL_Policy requires a wrapped `policy` module.")
        self.device = device
        self.policy = policy
        # ``Module.to`` moves in place and returns the same module, so these
        # attributes alias the inner policy's sub-modules (shared parameters).
        self.representation = self.policy.representation.to(device)
        self.target_representation = self.policy.target_representation.to(device)
        self.q_net = self.policy.eval_Qhead.to(device)
        self.target_q_net = self.policy.target_Qhead.to(device)
        self.action_dim = policy.action_dim
        # Start the target head as an exact copy of the online head, then freeze it:
        # it only changes through ``copy_target``.
        self.target_q_net.load_state_dict(self.q_net.state_dict())
        for param in self.target_q_net.parameters():
            param.requires_grad = False

    def forward(self, x: torch.Tensor):
        """Evaluate the online branch.

        Returns:
            ``(features, argmax_action, evalQ)``: the encoder output mapping,
            the greedy action index along the last dimension, and the online
            Q-values.
        """
        features = self.representation(x)
        evalQ = self.q_net(features['state'])
        argmax_action = evalQ.argmax(dim=-1)
        return features, argmax_action, evalQ

    def target(self, x: torch.Tensor):
        """Evaluate the target branch without tracking gradients.

        Both the target encoder and the frozen target Q head are used, so the
        returned values do not follow online-head updates until ``copy_target``.

        Returns:
            ``(features, argmax_action, targetQ)`` computed with
            ``target_representation`` and ``target_q_net``.
        """
        with torch.no_grad():
            features = self.target_representation(x)
            targetQ = self.target_q_net(features['state'])
            argmax_action = targetQ.argmax(dim=-1)
        return features, argmax_action, targetQ

    def copy_target(self):
        """Hard-copy the online Q head's weights into the target Q head.

        NOTE(review): ``target_representation`` is not synchronised here;
        confirm the learner keeps it in sync (e.g. via a momentum update).
        """
        self.target_q_net.load_state_dict(self.q_net.state_dict())