import torch
import torch.nn as nn
from xuance.torch.learners import Learner
from argparse import Namespace
try:
from torchvision import transforms
except:
pass
# (Removed "[docs]": a documentation-site source-link marker, not part of the module.)
class CURL_Learner(Learner):
    """Learner that trains an encoder contrastively and a Q-function by TD.

    Each call to :meth:`update` performs two optimisation steps:

    1. an InfoNCE (contrastive) step on ``policy.representation`` using
       ``encoder_optim``, followed by a moving-average update of
       ``policy.target_representation``;
    2. a one-step TD regression of the Q-values using ``q_optim``.

    Besides the values read from ``config`` in ``__init__``, the methods rely
    on attributes that are not assigned in this class (``learning_rate``,
    ``end_factor_lr_decay``, ``running_steps``, ``use_grad_clip``,
    ``grad_clip_norm``, ``gamma``, ``device``, ``iterations``, ``callback``,
    ``distributed_training``, ``rank``) -- presumably provided by the base
    ``Learner``; verify against that class.
    """

    def __init__(self,
                 config: Namespace,
                 policy: nn.Module,
                 callback):
        """Build the optimisers, the LR scheduler and the loss functions.

        Args:
            config: Run configuration; this constructor reads ``temperature``,
                ``tau``, ``sync_frequency`` and ``repr_lr`` from it.
            policy: Module exposing ``representation``, ``action_dim`` and
                ``parameters()``; other methods also use
                ``target_representation``, ``target`` and ``copy_target``.
            callback: Object whose ``on_update_start`` / ``on_update_end``
                hooks are invoked from :meth:`update`.
        """
        super().__init__(config, policy, callback)
        self.temperature = config.temperature  # temperature of the InfoNCE loss
        self.tau = config.tau  # weight kept by the target encoder in its moving average
        self.sync_frequency = config.sync_frequency  # iterations between copy_target() calls

        # Optimiser for the contrastive objective: encoder parameters only.
        self.encoder_optim = torch.optim.Adam(
            self.policy.representation.parameters(),
            lr=config.repr_lr,
            eps=1e-5
        )
        # Optimiser for the TD objective: every parameter the policy exposes.
        self.q_optim = torch.optim.Adam(
            self.policy.parameters(),
            lr=self.learning_rate,
            eps=1e-5
        )
        # Only the Q optimiser's learning rate is decayed.
        self.scheduler = torch.optim.lr_scheduler.LinearLR(
            self.q_optim,
            start_factor=1.0,
            end_factor=self.end_factor_lr_decay,
            total_iters=self.running_steps
        )

        self.mse_loss = nn.MSELoss()
        self.ce_loss = nn.CrossEntropyLoss()
        self.one_hot = nn.functional.one_hot
        self.n_actions = self.policy.action_dim
        # FrameStackTransform is not defined in the visible part of this file;
        # it is called on the observation batch to produce augmented views.
        self.transform = FrameStackTransform()

    def _update_target_encoder(self):
        """Blend the online encoder into the target encoder in place.

        Every target parameter becomes
        ``tau * target + (1 - tau) * online``, so ``tau`` is the fraction of
        the old target that is retained.
        NOTE(review): confirm the configured ``tau`` follows this convention
        (close to 1 for a slowly moving target) rather than the opposite one.
        """
        with torch.no_grad():
            for target_param, param in zip(
                    self.policy.target_representation.parameters(),
                    self.policy.representation.parameters()
            ):
                target_param.copy_(
                    self.tau * target_param + (1 - self.tau) * param
                )

    def _compute_contrastive_loss(self, obs: torch.Tensor) -> torch.Tensor:
        """Return the InfoNCE loss between two transformed views of ``obs``.

        Args:
            obs: Observation batch; its first dimension is used as the batch
                size.

        Returns:
            Cross-entropy over the ``(batch, batch)`` similarity matrix, with
            the matching index (the diagonal) as the target class.
        """
        # Two separate calls to the transform produce the query and key views
        # (presumably stochastic augmentations -- confirm in FrameStackTransform).
        aug_obs_q = self.transform(obs)
        aug_obs_k = self.transform(obs)

        q = self.policy.representation(aug_obs_q)
        # Keys come from the target encoder and receive no gradient.
        with torch.no_grad():
            k = self.policy.target_representation(aug_obs_k)

        # Both encoders return a mapping; the feature tensor is under 'state'.
        q = nn.functional.normalize(q['state'], dim=1)
        k = nn.functional.normalize(k['state'], dim=1)

        logits = torch.mm(q, k.T) / self.temperature  # (batch_size, batch_size)
        labels = torch.arange(obs.size(0), device=obs.device)
        return self.ce_loss(logits, labels)

    def update(self, **samples):
        """Run one contrastive step and one TD step on a sampled batch.

        Args:
            **samples: Must contain the keys ``'obs'``, ``'actions'``,
                ``'obs_next'``, ``'rewards'`` and ``'terminals'``; each value
                is converted with ``torch.as_tensor`` onto ``self.device``.

        Returns:
            The info dict returned by ``callback.on_update_start``, extended
            with ``curl_loss``, ``q_loss``, ``predictQ``, ``learning_rate``
            and whatever ``callback.on_update_end`` returns. When
            ``distributed_training`` is set, a ``"<key>/rank_<rank>"`` copy of
            every entry present at that point is added as well.
        """
        self.iterations += 1
        obs = torch.as_tensor(samples['obs'], device=self.device)
        act = torch.as_tensor(samples['actions'], device=self.device)
        next_obs = torch.as_tensor(samples['obs_next'], device=self.device)
        rew = torch.as_tensor(samples['rewards'], device=self.device)
        done = torch.as_tensor(samples['terminals'], dtype=torch.float, device=self.device)

        info = self.callback.on_update_start(self.iterations,
                                             policy=self.policy, obs=obs, act=act,
                                             next_obs=next_obs, rew=rew, termination=done)

        # --------------------- update CURL ---------------------
        curl_loss = self._compute_contrastive_loss(obs)
        self.encoder_optim.zero_grad()
        curl_loss.backward()
        if self.use_grad_clip:
            torch.nn.utils.clip_grad_norm_(self.policy.representation.parameters(), self.grad_clip_norm)
        self.encoder_optim.step()
        self._update_target_encoder()

        # --------------------- update Q ------------------------
        _, _, evalQ = self.policy(obs)
        # The bootstrap target is a constant for the regression, so it is
        # built without recording an autograd graph.
        # Assumes rew and done broadcast against the per-sample max -- TODO confirm shapes.
        with torch.no_grad():
            _, _, targetQ = self.policy.target(next_obs)
            targetQ = targetQ.max(dim=-1).values
            targetQ = rew + self.gamma * (1 - done) * targetQ
        # Select the Q-value of the action actually taken via a one-hot mask.
        predictQ = (evalQ * self.one_hot(act.long(), evalQ.shape[1])).sum(dim=-1)
        q_loss = self.mse_loss(predictQ, targetQ)

        # zero_grad() also clears gradients left on the encoder by the CURL step.
        self.q_optim.zero_grad()
        q_loss.backward()
        if self.use_grad_clip:
            torch.nn.utils.clip_grad_norm_(self.policy.parameters(), self.grad_clip_norm)
        self.q_optim.step()

        if self.scheduler is not None:
            self.scheduler.step()

        # Periodic hard synchronisation of the policy's target copy.
        if self.iterations % self.sync_frequency == 0:
            self.policy.copy_target()

        lr = self.q_optim.state_dict()['param_groups'][0]['lr']
        info.update({
            "curl_loss": curl_loss.item(),
            "q_loss": q_loss.item(),
            "predictQ": predictQ.mean().item(),
            "learning_rate": lr,
        })

        if self.distributed_training:
            # Adds rank-suffixed duplicates; the un-suffixed entries are kept.
            info.update({f"{k}/rank_{self.rank}": v for k, v in info.items()})

        info.update(self.callback.on_update_end(self.iterations,
                                                policy=self.policy, info=info,
                                                curl_loss=curl_loss, q_loss=q_loss,
                                                evalQ=evalQ, predictQ=predictQ, targetQ=targetQ))
        return info