# Source code for xuance.torch.learners.policy_gradient.sacdis_learner

"""
Soft Actor-Critic with discrete action spaces (SAC-Discrete)
Paper link: https://arxiv.org/pdf/1910.07207.pdf
Implementation: PyTorch
"""
import torch
from torch import nn
from xuance.torch.learners import Learner
from argparse import Namespace


class SACDIS_Learner(Learner):
    """Learner that performs SAC-Discrete actor, critic and temperature updates.

    The learner owns one Adam optimizer and one linear learning-rate scheduler
    for the actor parameters and one of each for the critic parameters. When
    ``config.use_automatic_entropy_tuning`` is true it additionally learns the
    log of the entropy temperature ``alpha`` with its own Adam optimizer.
    """

    def __init__(self, config: Namespace, policy: nn.Module, callback):
        """Build optimizers, schedulers and (optionally) the temperature parameter.

        Args:
            config: Hyper-parameters. Reads ``learning_rate_actor``,
                ``learning_rate_critic``, ``tau``, ``gamma``, ``alpha`` and
                ``use_automatic_entropy_tuning``. An optional
                ``target_entropy`` attribute overrides the default entropy
                target used for automatic temperature tuning.
            policy: Policy module exposing ``actor_parameters``,
                ``critic_parameters`` and ``action_space.n``.
            callback: Object providing ``on_update_start`` / ``on_update_end``
                hooks, forwarded to the base ``Learner``.
        """
        super(SACDIS_Learner, self).__init__(config, policy, callback)
        self.optimizer = {
            'actor': torch.optim.Adam(self.policy.actor_parameters,
                                      self.config.learning_rate_actor),
            'critic': torch.optim.Adam(self.policy.critic_parameters,
                                       self.config.learning_rate_critic)}
        self.scheduler = {
            'actor': torch.optim.lr_scheduler.LinearLR(self.optimizer['actor'],
                                                       start_factor=1.0,
                                                       end_factor=self.end_factor_lr_decay,
                                                       total_iters=self.total_iters),
            'critic': torch.optim.lr_scheduler.LinearLR(self.optimizer['critic'],
                                                        start_factor=1.0,
                                                        end_factor=self.end_factor_lr_decay,
                                                        total_iters=self.total_iters)}
        self.mse_loss = nn.MSELoss()
        self.tau = config.tau
        self.gamma = config.gamma
        self.alpha = config.alpha
        self.use_automatic_entropy_tuning = config.use_automatic_entropy_tuning
        if self.use_automatic_entropy_tuning:
            # The entropy of a categorical policy lies in [0, log(n)], so the
            # target must be non-negative; a negative target (the heuristic
            # used for continuous actions) is always exceeded and would push
            # log_alpha down without bound. Default to 98% of the maximum
            # entropy unless the config supplies an explicit target.
            target_entropy = getattr(config, 'target_entropy', None)
            if target_entropy is None:
                n_actions = float(policy.action_space.n)
                target_entropy = 0.98 * torch.log(torch.tensor(n_actions)).item()
            self.target_entropy = float(target_entropy)
            self.log_alpha = nn.Parameter(torch.zeros(1, requires_grad=True, device=self.device))
            # Detach so that actor/critic losses treat alpha as a constant and
            # never back-propagate into log_alpha.
            self.alpha = self.log_alpha.exp().detach()
            self.alpha_optimizer = torch.optim.Adam([self.log_alpha], lr=config.learning_rate_actor)

    def update(self, **samples):
        """Run one actor, critic and (optional) temperature update step.

        Args:
            **samples: Batch data; the keys ``'obs'``, ``'actions'``,
                ``'obs_next'``, ``'rewards'`` and ``'terminals'`` are read and
                converted to tensors on ``self.device``.

        Returns:
            The ``info`` dict returned by ``callback.on_update_start``,
            extended with loss values, the mean of the clipped policy Q-values,
            current learning rates, temperature statistics (when automatic
            tuning is enabled) and whatever ``callback.on_update_end`` returns.
        """
        self.iterations += 1
        obs_batch = torch.as_tensor(samples['obs'], device=self.device)
        act_batch = torch.as_tensor(samples['actions'], device=self.device).unsqueeze(-1)
        next_batch = torch.as_tensor(samples['obs_next'], device=self.device)
        rew_batch = torch.as_tensor(samples['rewards'], device=self.device).unsqueeze(-1)
        ter_batch = torch.as_tensor(samples['terminals'], dtype=torch.float,
                                    device=self.device).reshape([-1, 1])

        info = self.callback.on_update_start(self.iterations, policy=self.policy,
                                             obs=obs_batch, act=act_batch,
                                             next_obs=next_batch, rew=rew_batch,
                                             termination=ter_batch)

        # actor update
        action_prob, log_pi, policy_q_1, policy_q_2 = self.policy.Qpolicy(obs_batch)
        # Element-wise minimum of the two critics' outputs.
        policy_q = torch.min(policy_q_1, policy_q_2)
        # Probability-weighted sum over dim 1 (presumably the action axis).
        p_loss = (action_prob * (self.alpha * log_pi - policy_q)).sum(dim=1).mean()
        self.optimizer['actor'].zero_grad()
        p_loss.backward()
        if self.use_grad_clip:
            torch.nn.utils.clip_grad_norm_(self.policy.actor_parameters, self.grad_clip_norm)
        self.optimizer['actor'].step()

        # critic update
        action_q_1, action_q_2 = self.policy.Qaction(obs_batch)
        # Select the Q-value of the action that was actually taken.
        action_q_1 = action_q_1.gather(1, act_batch.long())
        action_q_2 = action_q_2.gather(1, act_batch.long())
        action_prob_next, log_pi_next, target_q = self.policy.Qtarget(next_batch)
        target_q = action_prob_next * (target_q - self.alpha * log_pi_next)
        target_q = target_q.sum(dim=1).unsqueeze(-1)
        backup = rew_batch + (1 - ter_batch) * self.gamma * target_q
        # The bootstrap target is detached so only the online critics get gradients.
        q_loss = self.mse_loss(action_q_1, backup.detach()) + self.mse_loss(action_q_2, backup.detach())
        self.optimizer['critic'].zero_grad()
        q_loss.backward()
        if self.use_grad_clip:
            torch.nn.utils.clip_grad_norm_(self.policy.critic_parameters, self.grad_clip_norm)
        self.optimizer['critic'].step()

        # automatic entropy tuning
        if self.use_automatic_entropy_tuning:
            # Expected log-probability under the policy (the negative entropy)
            # per sample: the temperature objective is an expectation over
            # actions drawn from the policy, not an unweighted mean over all
            # action log-probabilities.
            expected_log_pi = (action_prob * log_pi).sum(dim=1, keepdim=True).detach()
            alpha_loss = -(self.log_alpha * (expected_log_pi + self.target_entropy)).mean()
            self.alpha_optimizer.zero_grad()
            alpha_loss.backward()
            self.alpha_optimizer.step()
            self.alpha = self.log_alpha.exp().detach()
        else:
            alpha_loss = 0

        if self.scheduler is not None:
            self.scheduler['actor'].step()
            self.scheduler['critic'].step()

        self.policy.soft_update(self.tau)

        actor_lr = self.optimizer['actor'].state_dict()['param_groups'][0]['lr']
        critic_lr = self.optimizer['critic'].state_dict()['param_groups'][0]['lr']

        if self.distributed_training:
            info.update({
                f"Qloss/rank_{self.rank}": q_loss.item(),
                f"Ploss/rank_{self.rank}": p_loss.item(),
                f"Qvalue/rank_{self.rank}": policy_q.mean().item(),
                f"actor_lr/rank_{self.rank}": actor_lr,
                f"critic_lr/rank_{self.rank}": critic_lr,
            })
        else:
            info.update({
                "Qloss": q_loss.item(),
                "Ploss": p_loss.item(),
                "Qvalue": policy_q.mean().item(),
                "actor_lr": actor_lr,
                "critic_lr": critic_lr,
            })

        if self.use_automatic_entropy_tuning:
            if self.distributed_training:
                info.update({f"alpha_loss/rank_{self.rank}": alpha_loss.item(),
                             f"alpha/rank_{self.rank}": self.alpha.item()})
            else:
                info.update({"alpha_loss": alpha_loss.item(),
                             "alpha": self.alpha.item()})

        info.update(self.callback.on_update_end(self.iterations, policy=self.policy, info=info,
                                                action_prob=action_prob, log_pi=log_pi,
                                                policy_q_1=policy_q_1, policy_q_2=policy_q_2,
                                                policy_q=policy_q, p_loss=p_loss,
                                                action_q_1=action_q_1, action_q_2=action_q_2,
                                                action_prob_next=action_prob_next,
                                                log_pi_next=log_pi_next,
                                                target_q=target_q, backup=backup,
                                                q_loss=q_loss, alpha_loss=alpha_loss,
                                                alpha=self.alpha))

        return info