# Source code for xuance.torch.learners.policy_gradient.pdqn_learner

"""
Parameterised deep Q network (P-DQN)
Paper link: https://arxiv.org/pdf/1810.06394.pdf
Implementation: PyTorch
"""
import torch
from torch import nn
from xuance.torch.learners import Learner
from argparse import Namespace


class PDQN_Learner(Learner):
    """Learner for the parameterised deep Q-network (P-DQN) algorithm.

    Two networks of ``policy`` are trained with separate Adam optimizers:
    ``policy.conactor`` (index 0 in the optimizer/scheduler lists) and
    ``policy.qnetwork`` (index 1). Each optimizer gets its own ``LinearLR``
    schedule.

    Args:
        config: Run configuration; ``learning_rate``, ``tau`` and ``gamma``
            are read from it here.
        policy: Policy module exposing ``conactor``, ``qnetwork``, ``Atarget``,
            ``Qtarget``, ``Qeval``, ``Qpolicy`` and ``soft_update``.
        callback: Object providing ``on_update_start`` and ``on_update_end``
            hooks that are invoked around every update.
    """

    def __init__(self, config: Namespace, policy: nn.Module, callback):
        super(PDQN_Learner, self).__init__(config, policy, callback)
        conactor_optimizer = torch.optim.Adam(self.policy.conactor.parameters(),
                                              self.config.learning_rate)
        qnetwork_optimizer = torch.optim.Adam(self.policy.qnetwork.parameters(),
                                              self.config.learning_rate)
        # update() indexes ``self.optimizer``; the list must be stored under that
        # name. ``self.optimizers`` is kept as an alias of the same list for any
        # external code that reads the previous attribute name.
        self.optimizer = [conactor_optimizer, qnetwork_optimizer]
        self.optimizers = self.optimizer
        # NOTE(review): ``end_factor_lr_decay`` and ``total_iters`` are not set
        # in this class; presumably provided by the ``Learner`` base - confirm.
        conactor_lr_scheduler = torch.optim.lr_scheduler.LinearLR(
            conactor_optimizer,
            start_factor=1.0,
            end_factor=self.end_factor_lr_decay,
            total_iters=self.total_iters)
        qnetwork_lr_scheduler = torch.optim.lr_scheduler.LinearLR(
            qnetwork_optimizer,
            start_factor=1.0,
            end_factor=self.end_factor_lr_decay,
            total_iters=self.total_iters)
        self.scheduler = [conactor_lr_scheduler, qnetwork_lr_scheduler]
        self.tau = config.tau
        self.gamma = config.gamma
        self.mse_loss = nn.MSELoss()

    def update(self, **samples):
        """Run one training step on a sampled batch.

        Args:
            **samples: Batch with the keys ``'obs'``, ``'actions'``,
                ``'obs_next'``, ``'rewards'`` and ``'terminals'``. Column 0 of
                ``'actions'`` is cast to ``long`` and used as the index into
                the Q-values; the remaining columns are passed to
                ``policy.Qeval`` as the continuous part of the action.

        Returns:
            The dict-like object returned by ``callback.on_update_start``,
            extended with ``Q_loss``, ``Qvalue`` and ``P_loss`` (suffixed with
            ``/rank_{rank}`` when ``distributed_training`` is set) and with
            whatever ``callback.on_update_end`` returns.
        """
        self.iterations += 1
        obs_batch = torch.as_tensor(samples['obs'], device=self.device)
        hyact_batch = torch.as_tensor(samples['actions'], device=self.device)
        next_batch = torch.as_tensor(samples['obs_next'], device=self.device)
        rew_batch = torch.as_tensor(samples['rewards'], device=self.device)
        ter_batch = torch.as_tensor(samples['terminals'], dtype=torch.float, device=self.device)
        disact_batch = hyact_batch[:, 0].long()
        conact_batch = hyact_batch[:, 1:]

        info = self.callback.on_update_start(self.iterations,
                                             policy=self.policy, obs=obs_batch,
                                             hyact_batch=hyact_batch,
                                             disact_batch=disact_batch,
                                             conact_batch=conact_batch,
                                             next_obs=next_batch, rew=rew_batch,
                                             termination=ter_batch)

        # optimize Q-network
        with torch.no_grad():
            target_conact = self.policy.Atarget(next_batch)
            target_q = self.policy.Qtarget(next_batch, target_conact)
            # squeeze(-1) drops only the kept max dimension, so a batch of
            # size 1 stays 1-D instead of collapsing to a 0-d tensor.
            target_q = torch.max(target_q, 1, keepdim=True)[0].squeeze(-1)
            target_q = rew_batch + (1 - ter_batch) * self.gamma * target_q

        eval_qs = self.policy.Qeval(obs_batch, conact_batch)
        eval_q = eval_qs.gather(1, disact_batch.view(-1, 1)).squeeze(-1)
        q_loss = self.mse_loss(eval_q, target_q)

        self.optimizer[1].zero_grad()
        q_loss.backward()
        self.optimizer[1].step()

        # optimize actor network
        policy_q = self.policy.Qpolicy(obs_batch)
        p_loss = - policy_q.mean()
        self.optimizer[0].zero_grad()
        p_loss.backward()
        self.optimizer[0].step()

        if self.scheduler is not None:
            self.scheduler[0].step()
            self.scheduler[1].step()

        self.policy.soft_update(self.tau)

        # "P_loss" reports the actor loss (p_loss), not the critic loss.
        metrics = {
            "Q_loss": q_loss.item(),
            "Qvalue": eval_q.mean().item(),
            "P_loss": p_loss.item(),
        }
        if self.distributed_training:
            metrics = {f"{key}/rank_{self.rank}": value for key, value in metrics.items()}
        info.update(metrics)

        info.update(self.callback.on_update_end(self.iterations,
                                                policy=self.policy, scheduler=self.scheduler, info=info,
                                                target_conact=target_conact, target_q=target_q,
                                                eval_qs=eval_qs, eval_q=eval_q,
                                                policy_q=policy_q, q_loss=q_loss, p_loss=p_loss))
        return info