"""
MFAC: Mean Field Actor-Critic
Paper link:
http://proceedings.mlr.press/v80/yang18d/yang18d.pdf
Implementation: Pytorch
"""
import numpy as np
import torch
from torch import nn
from argparse import Namespace
from operator import itemgetter
from xuance.common import Optional, List
from xuance.torch import Tensor
from xuance.torch.learners.multi_agent_rl.ippo_learner import IPPO_Learner
class MFAC_Learner(IPPO_Learner):
    """Learner for MFAC (Mean Field Actor-Critic).

    The optimisation step is the PPO-style objective computed below:
    - clipped-ratio actor loss;
    - masked entropy bonus;
    - critic loss with optional value clipping, value normalisation and Huber loss.

    Compared with the plain IPPO update, the critic additionally receives a mean-action
    input built by ``build_actions_mean_input`` and passed to
    ``self.policy.get_values(..., actions_mean=...)``. Hyper-parameters such as
    ``clip_range``, ``vf_coef`` and ``ent_coef`` are attributes read from ``self``.
    They are presumably initialised by ``IPPO_Learner`` (not visible here).
    """

    def __init__(self,
                 config: Namespace,
                 model_keys: List[str],
                 agent_keys: List[str],
                 policy: nn.Module,
                 callback):
        """Forward all arguments to ``IPPO_Learner.__init__``; MFAC adds no extra state."""
        super(MFAC_Learner, self).__init__(config, model_keys, agent_keys, policy, callback)

    @staticmethod
    def _masked_count(mask_values: torch.Tensor) -> torch.Tensor:
        """Return ``mask_values.sum()`` floored at 1, used as the divisor of masked means.

        Without the floor, a key whose mask is entirely zero in a batch produces 0 / 0 = NaN.
        That NaN would propagate through ``loss.backward()`` into every parameter. All
        numerators it divides are masked sums, so an all-zero mask now yields a loss of 0.
        For 0/1 masks with at least one valid entry this is identical to ``mask_values.sum()``.
        """
        return mask_values.sum().clamp(min=1)

    def _normalize_value_target(self, key: str, value_target: torch.Tensor) -> torch.Tensor:
        """Update ``self.value_normalizer[key]`` with ``value_target`` and return it normalised.

        The normaliser is always fed a column vector of shape ``(N, 1)``, which is the layout
        the value-clip branches already used. The result is reshaped back to the input's
        shape, so ``(bs,)`` and ``(bs_rnn, seq_len)`` targets are handled the same way.
        """
        target_shape = value_target.shape
        column = value_target.reshape(-1, 1)
        self.value_normalizer[key].update(column)
        return self.value_normalizer[key].normalize(column).reshape(target_shape)

    def update(self, sample):
        """Perform one gradient step of MFAC on a non-recurrent batch.

        Args:
            sample: Raw batch. It is converted to tensors by ``build_training_data`` and
                ``build_actions_mean_input`` (both defined outside this file).

        Returns:
            dict: Logging info containing:
            - per-key actor loss, critic loss, entropy and mean predicted value;
            - the total loss and the current learning rate;
            - the gradient norm, when gradient clipping is enabled;
            - anything the callbacks add.
        """
        self.iterations += 1
        # prepare training data; act_mean is the extra mean-action input of the critic
        act_mean = self.build_actions_mean_input(sample=sample,
                                                 use_parameter_sharing=self.use_parameter_sharing)
        sample_Tensor = self.build_training_data(sample=sample,
                                                 use_parameter_sharing=self.use_parameter_sharing,
                                                 use_actions_mask=self.use_actions_mask)
        batch_size = sample_Tensor['batch_size']
        obs = sample_Tensor['obs']
        actions = sample_Tensor['actions']
        agent_mask = sample_Tensor['agent_mask']
        avail_actions = sample_Tensor['avail_actions']
        values = sample_Tensor['values']
        returns = sample_Tensor['returns']
        advantages = sample_Tensor['advantages']
        log_pi_old = sample_Tensor['log_pi_old']
        IDs = sample_Tensor['agent_ids']
        # With parameter sharing the per-key tensors are flattened to batch_size * n_agents
        # rows (see the reshape(bs) calls below).
        bs = batch_size * self.n_agents if self.use_parameter_sharing else batch_size
        info = self.callback.on_update_start(self.iterations, method="update", actions_mean=act_mean,
                                             policy=self.policy, sample_Tensor=sample_Tensor, bs=bs)
        # feedforward
        _, pi_dists_dict = self.policy(observation=obs, agent_ids=IDs, avail_actions=avail_actions)
        _, value_pred_dict = self.policy.get_values(observation=obs, actions_mean=act_mean, agent_ids=IDs)
        loss_a, loss_e, loss_c = [], [], []
        for key in self.model_keys:
            mask_values = agent_mask[key]
            n_valid = self._masked_count(mask_values)
            # actor loss: clipped surrogate objective
            log_pi = pi_dists_dict[key].log_prob(actions[key]).reshape(bs)
            ratio = torch.exp(log_pi - log_pi_old[key]).reshape(bs)
            advantages_mask = advantages[key].detach() * mask_values
            surrogate1 = ratio * advantages_mask
            surrogate2 = torch.clip(ratio, 1 - self.clip_range, 1 + self.clip_range) * advantages_mask
            loss_a.append(-torch.min(surrogate1, surrogate2).sum() / n_valid)
            # entropy loss
            entropy = pi_dists_dict[key].entropy().reshape(bs) * mask_values
            loss_e.append(entropy.sum() / n_valid)
            # critic loss
            value_pred_i = value_pred_dict[key].reshape(bs)
            value_target = returns[key].reshape(bs)
            values_i = values[key].reshape(bs)
            if self.use_value_clip:
                # keep the new prediction within value_clip_range of the stored value
                value_clipped = values_i + (value_pred_i - values_i).clamp(-self.value_clip_range,
                                                                           self.value_clip_range)
                if self.use_value_norm:
                    value_target = self._normalize_value_target(key, value_target)
                if self.use_huber_loss:
                    loss_v = self.huber_loss(value_pred_i, value_target)
                    loss_v_clipped = self.huber_loss(value_clipped, value_target)
                else:
                    loss_v = (value_pred_i - value_target) ** 2
                    loss_v_clipped = (value_clipped - value_target) ** 2
                loss_c_ = torch.max(loss_v, loss_v_clipped) * mask_values
                loss_c.append(loss_c_.sum() / n_valid)
            else:
                if self.use_value_norm:
                    value_target = self._normalize_value_target(key, value_target)
                if self.use_huber_loss:
                    loss_v = self.huber_loss(value_pred_i, value_target) * mask_values
                else:
                    loss_v = ((value_pred_i - value_target) ** 2) * mask_values
                loss_c.append(loss_v.sum() / n_valid)
            info.update({
                f"{key}/actor_loss": loss_a[-1].item(),
                f"{key}/critic_loss": loss_c[-1].item(),
                f"{key}/entropy": loss_e[-1].item(),
                f"{key}/predict_value": value_pred_i.mean().item()
            })
            info.update(self.callback.on_update_agent_wise(self.iterations, key, info=info, method="update",
                                                           mask_values=mask_values, log_pi=log_pi, ratio=ratio,
                                                           surrogate1=surrogate1, surrogate2=surrogate2,
                                                           entropy=entropy,
                                                           value_pred_i=value_pred_i, value_target=value_target,
                                                           values_i=values_i, loss_v=loss_v))
        loss = sum(loss_a) + self.vf_coef * sum(loss_c) - self.ent_coef * sum(loss_e)
        self.optimizer.zero_grad()
        loss.backward()
        if self.use_grad_clip:
            grad_norm = torch.nn.utils.clip_grad_norm_(self.policy.parameters_model, self.grad_clip_norm)
            info["gradient_norm"] = grad_norm.item()
        self.optimizer.step()
        if self.scheduler is not None and self.use_linear_lr_decay:
            self.scheduler.step()
        # Logger
        lr = self.optimizer.state_dict()['param_groups'][0]['lr']
        info.update({
            "learning_rate": lr,
            "loss": loss.item(),
        })
        info.update(self.callback.on_update_end(self.iterations, method="update", policy=self.policy, info=info))
        return info

    def update_rnn(self, sample):
        """Perform one gradient step of MFAC for recurrent policies on sequence batches.

        Every per-step loss term is masked by ``agent_mask[key] * filled``, where
        ``filled`` has shape ``(batch, seq_len)`` (presumably marking non-padded steps).
        Actor and critic hidden states are freshly initialised for each call.

        Args:
            sample: Raw batch. It is converted by ``build_training_data`` and
                ``build_actions_mean_input`` (both defined outside this file).

        Returns:
            dict: The same kind of logging info as ``update``.
        """
        self.iterations += 1
        # prepare training data; act_mean is the extra mean-action input of the critic
        act_mean = self.build_actions_mean_input(sample=sample,
                                                 use_parameter_sharing=self.use_parameter_sharing)
        sample_Tensor = self.build_training_data(sample=sample,
                                                 use_parameter_sharing=self.use_parameter_sharing,
                                                 use_actions_mask=self.use_actions_mask)
        batch_size = sample_Tensor['batch_size']
        bs_rnn = batch_size * self.n_agents if self.use_parameter_sharing else batch_size
        obs = sample_Tensor['obs']
        actions = sample_Tensor['actions']
        values = sample_Tensor['values']
        returns = sample_Tensor['returns']
        advantages = sample_Tensor['advantages']
        log_pi_old = sample_Tensor['log_pi_old']
        avail_actions = sample_Tensor['avail_actions']
        agent_mask = sample_Tensor['agent_mask']
        filled = sample_Tensor['filled']
        seq_len = filled.shape[1]
        IDs = sample_Tensor['agent_ids']
        if self.use_parameter_sharing:
            # replicate the per-episode fill mask for every agent so it lines up with the
            # (batch_size * n_agents, seq_len) layout of the shared model's tensors
            filled = filled.unsqueeze(1).expand(-1, self.n_agents, -1).reshape(bs_rnn, seq_len)
        info = self.callback.on_update_start(self.iterations, method="update_rnn", actions_mean=act_mean,
                                             policy=self.policy, sample_Tensor=sample_Tensor, bs_rnn=bs_rnn)
        # initial recurrent states for actor and critic
        rnn_hidden_actor = {k: self.policy.actor_representation[k].init_hidden(bs_rnn) for k in self.model_keys}
        rnn_hidden_critic = {k: self.policy.critic_representation[k].init_hidden(bs_rnn) for k in self.model_keys}
        # feedforward
        _, pi_dist_dict = self.policy(obs, agent_ids=IDs, avail_actions=avail_actions, rnn_hidden=rnn_hidden_actor)
        # calculate values
        if self.use_global_state:
            state = sample_Tensor['state']
            _, value_pred_dict = self.policy.get_values(observation=state, actions_mean=act_mean, agent_ids=IDs,
                                                        rnn_hidden=rnn_hidden_critic)
        else:
            _, value_pred_dict = self.policy.get_values(observation=obs, actions_mean=act_mean, agent_ids=IDs,
                                                        rnn_hidden=rnn_hidden_critic)
        # calculate losses for each agent
        loss_a, loss_e, loss_c = [], [], []
        for key in self.model_keys:
            mask_values = agent_mask[key] * filled
            n_valid = self._masked_count(mask_values)
            # actor loss: clipped surrogate objective
            log_pi = pi_dist_dict[key].log_prob(actions[key]).reshape(bs_rnn, seq_len)
            ratio = torch.exp(log_pi - log_pi_old[key])
            surrogate1 = ratio * advantages[key]
            surrogate2 = torch.clip(ratio, 1 - self.clip_range, 1 + self.clip_range) * advantages[key]
            loss_a.append(-(torch.min(surrogate1, surrogate2) * mask_values).sum() / n_valid)
            # entropy loss
            entropy = pi_dist_dict[key].entropy().reshape(bs_rnn, seq_len)
            entropy = entropy * mask_values
            loss_e.append(entropy.sum() / n_valid)
            # critic loss
            value_pred_i = value_pred_dict[key].reshape(bs_rnn, seq_len)
            value_target = returns[key].reshape(bs_rnn, seq_len)
            values_i = values[key].reshape(bs_rnn, seq_len)
            if self.use_value_clip:
                # keep the new prediction within value_clip_range of the stored value
                value_clipped = values_i + (value_pred_i - values_i).clamp(-self.value_clip_range,
                                                                           self.value_clip_range)
                if self.use_value_norm:
                    value_target = self._normalize_value_target(key, value_target)
                if self.use_huber_loss:
                    loss_v = self.huber_loss(value_pred_i, value_target)
                    loss_v_clipped = self.huber_loss(value_clipped, value_target)
                else:
                    loss_v = (value_pred_i - value_target) ** 2
                    loss_v_clipped = (value_clipped - value_target) ** 2
                loss_c_ = torch.max(loss_v, loss_v_clipped) * mask_values
                loss_c.append(loss_c_.sum() / n_valid)
            else:
                if self.use_value_norm:
                    # flattened to a column like the clip branch, not the raw (bs_rnn, seq_len) tensor
                    value_target = self._normalize_value_target(key, value_target)
                if self.use_huber_loss:
                    loss_v = self.huber_loss(value_pred_i, value_target)
                else:
                    loss_v = (value_pred_i - value_target) ** 2
                loss_c.append((loss_v * mask_values).sum() / n_valid)
            info.update({
                f"{key}/actor_loss": loss_a[-1].item(),
                f"{key}/critic_loss": loss_c[-1].item(),
                f"{key}/entropy": loss_e[-1].item(),
                f"{key}/predict_value": value_pred_i.mean().item()
            })
            info.update(self.callback.on_update_agent_wise(self.iterations, key, info=info, method="update_rnn",
                                                           mask_values=mask_values, log_pi=log_pi, ratio=ratio,
                                                           surrogate1=surrogate1, surrogate2=surrogate2,
                                                           entropy=entropy,
                                                           value_pred_i=value_pred_i, value_target=value_target,
                                                           values_i=values_i, loss_v=loss_v))
        loss = sum(loss_a) + self.vf_coef * sum(loss_c) - self.ent_coef * sum(loss_e)
        self.optimizer.zero_grad()
        loss.backward()
        if self.use_grad_clip:
            grad_norm = torch.nn.utils.clip_grad_norm_(self.policy.parameters_model, self.grad_clip_norm)
            info["gradient_norm"] = grad_norm.item()
        self.optimizer.step()
        # same condition as `update`: only decay the learning rate when linear decay is enabled
        if self.scheduler is not None and self.use_linear_lr_decay:
            self.scheduler.step()
        # Logger
        lr = self.optimizer.state_dict()['param_groups'][0]['lr']
        info.update({
            "learning_rate": lr,
            "loss": loss.item(),
        })
        info.update(self.callback.on_update_end(self.iterations, method="update_rnn", policy=self.policy, info=info))
        return info