"""
Deep Q-Network (DQN)
Paper link: https://www.nature.com/articles/nature14236
Implementation: Pytorch
"""
import torch
from torch import nn
from argparse import Namespace
try:
from torchvision.transforms import transforms
except:
pass
from xuance.torch.learners import Learner
# DrQ learner definition.
class DrQ_Learner(Learner):
    """Q-learning learner that averages Q-values over augmented observations.

    Each call to :meth:`update` transforms the sampled observations
    ``aug_factor`` times, evaluates the online and target networks on every
    transformed view, averages the resulting Q-values across views and then
    minimises the mean-squared TD error with Adam.

    Attributes such as ``device``, ``iterations``, ``use_grad_clip``,
    ``grad_clip_norm``, ``end_factor_lr_decay``, ``total_iters``,
    ``distributed_training`` and ``rank`` are read from ``self`` and are
    expected to be provided by the ``Learner`` base class.
    """

    def __init__(self,
                 config: Namespace,
                 policy: nn.Module):
        """Set up the optimizer, LR schedule and augmentation settings.

        Args:
            config: Hyper-parameters; this class reads ``learning_rate``
                (through ``self.config``), ``gamma``, ``sync_frequency`` and
                ``aug_factor``.
            policy: Q-network module. It must expose ``action_dim``,
                ``target(obs)`` and ``copy_target()``.

        Raises:
            ValueError: If ``config.aug_factor`` is smaller than 1, because
                at least one augmented view is required to form the batch of
                Q-values.
        """
        super(DrQ_Learner, self).__init__(config, policy)
        self.optimizer = torch.optim.Adam(self.policy.parameters(),
                                          self.config.learning_rate,
                                          eps=1e-5)
        # Linearly scale the learning rate from 1.0x down to
        # `end_factor_lr_decay`x over `total_iters` scheduler steps.
        self.scheduler = torch.optim.lr_scheduler.LinearLR(
            self.optimizer,
            start_factor=1.0,
            end_factor=self.end_factor_lr_decay,
            total_iters=self.total_iters)
        self.gamma = config.gamma
        self.sync_frequency = config.sync_frequency
        self.mse_loss = nn.MSELoss()
        self.one_hot = nn.functional.one_hot
        self.n_actions = self.policy.action_dim
        # NOTE(review): FrameStackTransform is neither defined nor imported in
        # the visible part of this module - confirm it is in scope at runtime.
        self.transform = FrameStackTransform()
        if config.aug_factor < 1:
            # Fail early: with no augmented view, update() would try to stack
            # an empty list and raise a far less helpful error.
            raise ValueError(
                f"aug_factor must be >= 1, got {config.aug_factor}")
        self.aug_factor = config.aug_factor

    def update(self, **samples) -> dict:
        """Run one gradient step on a sampled batch.

        Args:
            **samples: Batch data with the keys ``'obs'``, ``'actions'``,
                ``'obs_next'``, ``'rewards'`` and ``'terminals'``.

        Returns:
            A dict with the loss (``Qloss``), the mean predicted Q-value
            (``predictQ``) and the current learning rate (``learning_rate``).
            When ``self.distributed_training`` is true every key carries a
            ``/rank_<rank>`` suffix.
        """
        self.iterations += 1
        obs_batch = torch.as_tensor(samples['obs'], device=self.device)
        act_batch = torch.as_tensor(samples['actions'], device=self.device)
        next_batch = torch.as_tensor(samples['obs_next'], device=self.device)
        rew_batch = torch.as_tensor(samples['rewards'], device=self.device)
        ter_batch = torch.as_tensor(samples['terminals'], dtype=torch.float, device=self.device)

        aug_evalQ = []
        aug_targetQ = []
        for _ in range(self.aug_factor):
            # Always augment the *sampled* batch. Rebinding obs_batch /
            # next_batch here would feed the already-transformed tensors into
            # the next pass, so transforms would pile up across iterations.
            aug_obs = self.transform(obs_batch)
            aug_next = self.transform(next_batch)
            _, _, evalQ = self.policy(aug_obs)
            # The bootstrapped target is a constant of the loss: no gradient
            # may reach the target network through it.
            with torch.no_grad():
                _, _, targetQ = self.policy.target(aug_next)
            aug_evalQ.append(evalQ)
            aug_targetQ.append(targetQ)

        # Average the per-action Q-values over all augmented views.
        evalQ = torch.stack(aug_evalQ, dim=0).mean(dim=0)
        targetQ = torch.stack(aug_targetQ, dim=0).mean(dim=0)

        # One-step TD target: r + gamma * (1 - terminal) * max_a' Q_target(s', a').
        targetQ = targetQ.max(dim=-1).values
        targetQ = rew_batch + self.gamma * (1 - ter_batch) * targetQ
        # Pick Q(s, a) of the action actually taken via a one-hot mask.
        predictQ = (evalQ * self.one_hot(act_batch.long(), evalQ.shape[1])).sum(dim=-1)

        loss = self.mse_loss(predictQ, targetQ)
        self.optimizer.zero_grad()
        loss.backward()
        if self.use_grad_clip:
            torch.nn.utils.clip_grad_norm_(self.policy.parameters(), self.grad_clip_norm)
        self.optimizer.step()
        if self.scheduler is not None:
            self.scheduler.step()

        # hard update for target network
        if self.iterations % self.sync_frequency == 0:
            self.policy.copy_target()

        # Read the rate straight from the optimizer; building a full
        # state_dict() every step only to look up one number is wasteful.
        lr = self.optimizer.param_groups[0]['lr']

        # Same keys as before: a "/rank_<rank>" suffix only when distributed.
        suffix = f"/rank_{self.rank}" if self.distributed_training else ""
        info = {
            f"Qloss{suffix}": loss.item(),
            f"predictQ{suffix}": predictQ.mean().item(),
            f"learning_rate{suffix}": lr,
        }
        return info