import torch
from torch import nn
from xuance.torch.learners import Learner
from argparse import Namespace
# Natural Policy Gradient (NPG) learner.
class NPG_Learner(Learner):
    """Natural-policy-gradient style learner for an actor-critic policy.

    The critic is fitted by mean-squared-error regression onto the sampled
    returns. The actor follows the advantage-weighted log-likelihood gradient
    after each parameter's flattened gradient has been multiplied by the
    matrix returned from ``compute_fisher_information``.

    NOTE(review): ``compute_fisher_information`` is not defined in this class
    body; it is presumably inherited or attached elsewhere -- confirm.
    NOTE(review): a natural gradient is conventionally ``F^-1 @ g``; this
    code multiplies by whatever the helper returns, so the helper presumably
    has to return the *inverse* Fisher matrix -- confirm.
    """

    def __init__(self,
                 config: Namespace,
                 policy: nn.Module,
                 callback):
        """Create the optimizers, learning-rate schedulers and loss function.

        Args:
            config: Hyper-parameters. ``learning_rate``, ``running_steps``
                and ``gamma`` are read here.
            policy: Policy network exposing ``actor`` and ``critic``
                sub-modules whose parameters are optimized separately.
            callback: Hook object forwarded to the base class; ``update``
                calls ``on_update_start`` / ``on_update_end`` on
                ``self.callback``.
        """
        super().__init__(config, policy, callback)
        self.actor_optimizer = torch.optim.Adam(
            self.policy.actor.parameters(), config.learning_rate, eps=1e-5)
        self.critic_optimizer = torch.optim.Adam(
            self.policy.critic.parameters(), config.learning_rate, eps=1e-5)
        # Both learning rates decay linearly from their initial value to
        # ``initial * end_factor_lr_decay`` over ``running_steps`` scheduler
        # steps. ``end_factor_lr_decay`` is not set in this class; it is
        # expected to exist after the base-class constructor has run.
        self.actor_scheduler = torch.optim.lr_scheduler.LinearLR(
            self.actor_optimizer,
            start_factor=1.0,
            end_factor=self.end_factor_lr_decay,
            total_iters=config.running_steps)
        self.critic_scheduler = torch.optim.lr_scheduler.LinearLR(
            self.critic_optimizer,
            start_factor=1.0,
            end_factor=self.end_factor_lr_decay,
            total_iters=config.running_steps)
        self.gamma = config.gamma
        self.mse_loss = nn.MSELoss()

    def update(self, **samples):
        """Run one critic step and one actor step on a batch of samples.

        Args:
            **samples: Must contain the keys ``'obs'``, ``'actions'``,
                ``'returns'`` and ``'advantages'``. Each value is converted
                with ``torch.as_tensor`` and placed on ``self.device``.

        Returns:
            The object returned by ``callback.on_update_start``, updated with
            the actor/critic losses, both learning rates, the mean predicted
            value, and whatever ``callback.on_update_end`` returns.
        """
        self.iterations += 1
        obs_batch = torch.as_tensor(samples['obs'], device=self.device)
        act_batch = torch.as_tensor(samples['actions'], device=self.device)
        ret_batch = torch.as_tensor(samples['returns'], device=self.device)
        adv_batch = torch.as_tensor(samples['advantages'], device=self.device)

        info = self.callback.on_update_start(self.iterations,
                                             policy=self.policy, obs=obs_batch, act=act_batch,
                                             returns=ret_batch, advantages=adv_batch)

        outputs, a_dist, v_pred = self.policy(obs_batch)
        log_prob = a_dist.log_prob(act_batch)
        a_loss = -(adv_batch * log_prob).mean()  # actor loss
        c_loss = self.mse_loss(v_pred, ret_batch)  # critic loss

        # Train critic. The graph is retained because ``a_loss`` is
        # back-propagated afterwards through the same forward pass.
        self.critic_optimizer.zero_grad()
        c_loss.backward(retain_graph=True)
        if self.use_grad_clip:
            torch.nn.utils.clip_grad_norm_(self.policy.critic.parameters(), self.grad_clip_norm)
        self.critic_optimizer.step()

        # Train actor.
        self.actor_optimizer.zero_grad()
        a_loss.backward()
        for param in self.policy.actor.parameters():
            # Skip frozen parameters and parameters that received no gradient
            # from ``a_loss`` (their ``grad`` is None and cannot be flattened).
            if not param.requires_grad or param.grad is None:
                continue
            fisher_info = self.compute_fisher_information(param, obs_batch, act_batch)
            # ``reshape`` also copes with a non-contiguous gradient, where
            # ``view`` would raise.
            grads = param.grad.reshape(-1)
            natural_grads = torch.matmul(fisher_info, grads)
            param.grad = natural_grads.reshape(param.shape).clone()
        if self.use_grad_clip:
            torch.nn.utils.clip_grad_norm_(self.policy.actor.parameters(), self.grad_clip_norm)
        self.actor_optimizer.step()

        if self.critic_scheduler is not None:
            self.critic_scheduler.step()
        if self.actor_scheduler is not None:
            self.actor_scheduler.step()

        # Logger: each learning rate is read from its own optimizer.
        lr_actor = self.actor_optimizer.param_groups[0]['lr']
        lr_critic = self.critic_optimizer.param_groups[0]['lr']
        # In distributed training every key is suffixed with the process rank.
        suffix = f"/rank_{self.rank}" if self.distributed_training else ""
        info.update({
            f"actor-loss{suffix}": a_loss.item(),
            f"critic-loss{suffix}": c_loss.item(),
            f"learning_rate_actor{suffix}": lr_actor,
            f"learning_rate_critic{suffix}": lr_critic,
            f"predict_value{suffix}": v_pred.mean().item()
        })

        info.update(self.callback.on_update_end(self.iterations,
                                                policy=self.policy, info=info, rep_output=outputs,
                                                a_dist=a_dist, v_pred=v_pred, log_prob=log_prob,
                                                a_loss=a_loss, c_loss=c_loss))
        return info