# Source code for xuance.torch.learners.multi_agent_rl.iac_learner

"""
Independent Advantage Actor Critic (IAC)
Paper link: https://ojs.aaai.org/index.php/AAAI/article/view/11794
Implementation: PyTorch
"""
import numpy as np
import torch
from torch import nn
from argparse import Namespace
from operator import itemgetter
from xuance.common import Optional, List
from xuance.torch import Tensor
from xuance.torch.utils import ValueNorm
from xuance.torch.learners import LearnerMAS


class IAC_Learner(LearnerMAS):
    """Learner for Independent Advantage Actor-Critic (IAC).

    For every model key the policy yields an action distribution and a value prediction.
    The optimized loss is: sum of policy-gradient losses + ``vf_coef`` * sum of value losses
    - ``ent_coef`` * sum of entropy bonuses. The value loss is optionally clipped, Huber-based,
    and/or computed against a normalized target. All of ``policy.parameters_model`` is updated by
    one Adam optimizer with a linear learning-rate schedule.
    """

    def __init__(self,
                 config: Namespace,
                 model_keys: List[str],
                 agent_keys: List[str],
                 policy: nn.Module,
                 callback):
        super(IAC_Learner, self).__init__(config, model_keys, agent_keys, policy, callback)
        self.build_optimizer()
        self.use_value_clip, self.value_clip_range = config.use_value_clip, config.value_clip_range
        self.use_huber_loss, self.huber_delta = config.use_huber_loss, config.huber_delta
        self.use_value_norm = config.use_value_norm
        self.vf_coef, self.ent_coef = config.vf_coef, config.ent_coef
        self.mse_loss = nn.MSELoss()
        # reduction="none" so the element-wise loss can be masked before averaging.
        self.huber_loss = nn.HuberLoss(reduction="none", delta=self.huber_delta)
        if self.use_value_norm:
            # One scalar value normalizer per model key.
            self.value_normalizer = {key: ValueNorm(1).to(self.device) for key in self.model_keys}
        else:
            self.value_normalizer = None

    def estimate_total_iterations(self):
        """Estimated total number of training iterations"""
        buffer_size = self.config.buffer_size
        n_epochs = getattr(self.config, "n_epochs", 1)
        n_minibatch = getattr(self.config, "n_minibatch", 1)
        episode_length = self.episode_length
        if self.use_rnn:
            # With RNN the buffer stores whole episodes, so count episodes instead of steps.
            update_times = (self.config.running_steps // episode_length) // buffer_size
        else:
            update_times = self.config.running_steps // buffer_size
        total_iters = update_times * n_epochs * n_minibatch
        return total_iters

    def build_optimizer(self):
        """Create the Adam optimizer and the linear learning-rate decay scheduler."""
        self.optimizer = torch.optim.Adam(self.policy.parameters_model, lr=self.learning_rate, eps=1e-5,
                                          weight_decay=self.config.weight_decay)
        self.scheduler = torch.optim.lr_scheduler.LinearLR(self.optimizer,
                                                           start_factor=1.0,
                                                           end_factor=self.end_factor_lr_decay,
                                                           total_iters=self.total_iters)

    def build_training_data(self, sample: Optional[dict],
                            use_parameter_sharing: Optional[bool] = False,
                            use_actions_mask: Optional[bool] = False,
                            use_global_state: Optional[bool] = False):
        """
        Prepare the training data.

        Parameters:
            sample (dict): The raw sampled data.
            use_parameter_sharing (bool): Whether to use parameter sharing for individual agent models.
            use_actions_mask (bool): Whether to use actions mask for unavailable actions.
            use_global_state (bool): Whether to use global state.

        Returns:
            sample_Tensor (dict): The formatted sampled data, with keys 'batch_size', 'state', 'obs',
                'actions', 'values', 'returns', 'advantages', 'log_pi_old', 'terminals', 'agent_mask',
                'avail_actions', 'agent_ids', 'filled' and 'seq_length'. With parameter sharing every
                per-key dict has the single key ``model_keys[0]`` and the agent axis is folded into the
                batch axis (``batch_size * n_agents`` rows); otherwise dicts are keyed by agent.
        """
        batch_size = sample['batch_size']
        seq_length = sample['sequence_length'] if self.use_rnn else 1
        state, avail_actions, filled, IDs = None, None, None, None

        def stack_agents(field):
            # Stack sample[field][agent] for all agents along a new axis 1. An explicit list is used
            # (not itemgetter) so a single agent still yields a list of one array to np.stack.
            return np.stack([sample[field][agent_key] for agent_key in self.agent_keys], axis=1)

        if use_parameter_sharing:
            k = self.model_keys[0]
            bs = batch_size * self.n_agents
            obs_tensor = Tensor(stack_agents('obs')).to(self.device)
            actions_tensor = Tensor(stack_agents('actions')).to(self.device)
            values_tensor = Tensor(stack_agents('values')).to(self.device)
            returns_tensor = Tensor(stack_agents('returns')).to(self.device)
            advantages_tensor = Tensor(stack_agents('advantages')).to(self.device)
            log_pi_old_tensor = Tensor(stack_agents('log_pi_old')).to(self.device)
            ter_tensor = Tensor(stack_agents('terminals')).float().to(self.device)
            msk_tensor = Tensor(stack_agents('agent_mask')).float().to(self.device)
            if self.use_cnn and len(obs_tensor.shape) > 3:
                # Leading dims are (batch, agent) or, with RNN, (batch, agent, seq); the rest is the
                # per-step image shape that must be preserved for the CNN.
                obs_shape_item = obs_tensor.shape[3:] if self.use_rnn else obs_tensor.shape[2:]
            else:
                obs_shape_item = (-1, )
            if self.use_rnn:
                obs = {k: obs_tensor.reshape(bs, seq_length, *obs_shape_item)}
                if len(actions_tensor.shape) == 3:
                    actions = {k: actions_tensor.reshape(bs, seq_length)}
                elif len(actions_tensor.shape) == 4:
                    actions = {k: actions_tensor.reshape(bs, seq_length, -1)}
                else:
                    raise AttributeError("Wrong actions shape.")
                values = {k: values_tensor.reshape(bs, seq_length)}
                returns = {k: returns_tensor.reshape(bs, seq_length)}
                advantages = {k: advantages_tensor.reshape(bs, seq_length)}
                log_pi_old = {k: log_pi_old_tensor.reshape(bs, seq_length)}
                terminals = {k: ter_tensor.reshape(bs, seq_length)}
                agent_mask = {k: msk_tensor.reshape(bs, seq_length)}
                # One-hot agent ids, repeated over batch and time.
                IDs = torch.eye(self.n_agents).unsqueeze(1).unsqueeze(0).expand(
                    batch_size, -1, seq_length, -1).reshape(bs, seq_length, self.n_agents).to(self.device)
            else:
                obs = {k: obs_tensor.reshape(bs, *obs_shape_item)}
                if len(actions_tensor.shape) == 2:
                    actions = {k: actions_tensor.reshape(bs)}
                elif len(actions_tensor.shape) == 3:
                    actions = {k: actions_tensor.reshape(bs, -1)}
                else:
                    raise AttributeError("Wrong actions shape.")
                values = {k: values_tensor.reshape(bs)}
                returns = {k: returns_tensor.reshape(bs)}
                advantages = {k: advantages_tensor.reshape(bs)}
                log_pi_old = {k: log_pi_old_tensor.reshape(bs)}
                terminals = {k: ter_tensor.reshape(bs)}
                agent_mask = {k: msk_tensor.reshape(bs)}
                # One-hot agent ids, repeated over the batch.
                IDs = torch.eye(self.n_agents).unsqueeze(0).expand(
                    batch_size, -1, -1).reshape(bs, self.n_agents).to(self.device)
            if use_actions_mask:
                avail_a = stack_agents('avail_actions')
                if self.use_rnn:
                    avail_actions = {k: Tensor(avail_a.reshape([bs, seq_length, -1])).float().to(self.device)}
                else:
                    avail_actions = {k: Tensor(avail_a.reshape([bs, -1])).float().to(self.device)}
        else:
            obs = {}
            # With RNN each agent's observations keep their (batch, seq) leading dims; flattening to
            # (batch, -1) would merge the time axis into the feature axis.
            lead_shape = [batch_size, seq_length] if self.use_rnn else [batch_size]
            for key in self.agent_keys:
                obs_tensor = Tensor(sample['obs'][key]).to(self.device)
                if self.use_cnn and len(obs_tensor.shape) > 3:
                    obs_shape_item = obs_tensor.shape[len(lead_shape):]
                else:
                    obs_shape_item = (-1,)
                obs[key] = obs_tensor.reshape([*lead_shape, *obs_shape_item])
            actions = {k: Tensor(sample['actions'][k]).to(self.device) for k in self.agent_keys}
            values = {k: Tensor(sample['values'][k]).to(self.device) for k in self.agent_keys}
            returns = {k: Tensor(sample['returns'][k]).to(self.device) for k in self.agent_keys}
            advantages = {k: Tensor(sample['advantages'][k]).to(self.device) for k in self.agent_keys}
            log_pi_old = {k: Tensor(sample['log_pi_old'][k]).to(self.device) for k in self.agent_keys}
            terminals = {k: Tensor(sample['terminals'][k]).float().to(self.device) for k in self.agent_keys}
            agent_mask = {k: Tensor(sample['agent_mask'][k]).float().to(self.device) for k in self.agent_keys}
            if use_actions_mask:
                avail_actions = {k: Tensor(sample['avail_actions'][k]).float().to(self.device)
                                 for k in self.agent_keys}

        if use_global_state:
            state = Tensor(sample['state']).to(self.device)

        if self.use_rnn:
            filled = Tensor(sample['filled']).float().to(self.device)

        sample_Tensor = {
            'batch_size': batch_size,
            'state': state,
            'obs': obs,
            'actions': actions,
            'values': values,
            'returns': returns,
            'advantages': advantages,
            'log_pi_old': log_pi_old,
            'terminals': terminals,
            'agent_mask': agent_mask,
            'avail_actions': avail_actions,
            'agent_ids': IDs,
            'filled': filled,
            'seq_length': seq_length,
        }
        return sample_Tensor

    def _normalize_value_target(self, key, value_target):
        """Update ``key``'s value normalizer with ``value_target`` and return the normalized target.

        The target is flattened to ``(-1, 1)`` before being passed to the one-dimensional
        normalizer, and the result is reshaped back to ``value_target``'s original shape, so
        callers get a tensor aligned element-wise with their predictions and masks.
        """
        normalizer = self.value_normalizer[key]
        flat_target = value_target.reshape(-1, 1)
        normalizer.update(flat_target)
        return normalizer.normalize(flat_target).reshape(value_target.shape)

    def update(self, sample):
        """Run one gradient step on a feed-forward (non-recurrent) batch.

        Parameters:
            sample (dict): The raw sampled data, formatted by ``build_training_data``.

        Returns:
            info (dict): Training statistics (losses, learning rate, predicted values, optionally the
                gradient norm) merged with whatever the callbacks add.
        """
        self.iterations += 1

        # prepare training data
        sample_Tensor = self.build_training_data(sample=sample,
                                                 use_parameter_sharing=self.use_parameter_sharing,
                                                 use_actions_mask=self.use_actions_mask)
        batch_size = sample_Tensor['batch_size']
        obs = sample_Tensor['obs']
        actions = sample_Tensor['actions']
        agent_mask = sample_Tensor['agent_mask']
        avail_actions = sample_Tensor['avail_actions']
        values = sample_Tensor['values']
        returns = sample_Tensor['returns']
        advantages = sample_Tensor['advantages']
        IDs = sample_Tensor['agent_ids']
        bs = batch_size * self.n_agents if self.use_parameter_sharing else batch_size

        info = self.callback.on_update_start(self.iterations, method="update",
                                             policy=self.policy, sample_Tensor=sample_Tensor, bs=bs)

        # feedforward
        _, pi_dist_dict = self.policy(observation=obs, agent_ids=IDs, avail_actions=avail_actions)
        _, values_pred_dict = self.policy.get_values(observation=obs, agent_ids=IDs)

        loss_a, loss_e, loss_c = [], [], []
        for key in self.model_keys:
            mask_values = agent_mask[key]
            # policy gradient loss
            log_pi = pi_dist_dict[key].log_prob(actions[key])
            pg_loss = -((advantages[key].detach() * log_pi) * mask_values).sum() / mask_values.sum()
            loss_a.append(pg_loss)

            # entropy loss
            entropy = pi_dist_dict[key].entropy()
            entropy_loss = (entropy * mask_values).sum() / mask_values.sum()
            loss_e.append(entropy_loss)

            # value loss
            value_pred_i = values_pred_dict[key].reshape(bs)
            value_target = returns[key].reshape(bs)
            values_i = values[key].reshape(bs)
            if self.use_value_norm:
                value_target = self._normalize_value_target(key, value_target)
            if self.use_value_clip:
                # Keep the new prediction within value_clip_range of the value stored in the sample.
                value_clipped = values_i + (value_pred_i - values_i).clamp(-self.value_clip_range,
                                                                           self.value_clip_range)
                if self.use_huber_loss:
                    loss_v = self.huber_loss(value_pred_i, value_target)
                    loss_v_clipped = self.huber_loss(value_clipped, value_target)
                else:
                    loss_v = (value_pred_i - value_target) ** 2
                    loss_v_clipped = (value_clipped - value_target) ** 2
                loss_c_ = torch.max(loss_v, loss_v_clipped) * mask_values
                loss_c.append(loss_c_.sum() / mask_values.sum())
            else:
                if self.use_huber_loss:
                    loss_v = self.huber_loss(value_pred_i, value_target) * mask_values
                else:
                    loss_v = ((value_pred_i - value_target) ** 2) * mask_values
                loss_c.append(loss_v.sum() / mask_values.sum())

            info.update({
                f"predict_value/{key}": value_pred_i.mean().item()
            })

            info.update(self.callback.on_update_agent_wise(self.iterations, key, info=info, method="update",
                                                           mask_values=mask_values, log_pi=log_pi, pg_loss=pg_loss,
                                                           entropy=entropy, entropy_loss=entropy_loss,
                                                           value_pred_i=value_pred_i, value_target=value_target,
                                                           values_i=values_i, loss_v=loss_v))

        # Total loss
        loss = sum(loss_a) + self.vf_coef * sum(loss_c) - self.ent_coef * sum(loss_e)
        self.optimizer.zero_grad()
        loss.backward()
        if self.use_grad_clip:
            grad_norm = torch.nn.utils.clip_grad_norm_(self.policy.parameters_model, self.grad_clip_norm)
            info["gradient_norm"] = grad_norm.item()
        self.optimizer.step()
        if self.scheduler is not None:
            self.scheduler.step()

        # Logger
        lr = self.optimizer.state_dict()['param_groups'][0]['lr']

        info.update({
            "learning_rate": lr,
            "pg_loss": sum(loss_a).item(),
            "vf_loss": sum(loss_c).item(),
            "entropy_loss": sum(loss_e).item(),
            "loss": loss.item(),
        })

        info.update(self.callback.on_update_end(self.iterations, method="update", policy=self.policy, info=info))

        return info

    def update_rnn(self, sample):
        """Run one gradient step on a batch of sequences for recurrent policies.

        Losses are averaged over valid entries only: the agent mask is multiplied by the
        ``filled`` mask from the sample, which is expanded over agents when parameters are shared.

        Parameters:
            sample (dict): The raw sampled data, formatted by ``build_training_data``.

        Returns:
            info (dict): Training statistics (losses, learning rate, predicted values, optionally the
                gradient norm) merged with whatever the callbacks add.
        """
        self.iterations += 1

        sample_Tensor = self.build_training_data(sample=sample,
                                                 use_parameter_sharing=self.use_parameter_sharing,
                                                 use_actions_mask=self.use_actions_mask)
        batch_size = sample_Tensor['batch_size']
        bs_rnn = batch_size * self.n_agents if self.use_parameter_sharing else batch_size
        obs = sample_Tensor['obs']
        actions = sample_Tensor['actions']
        values = sample_Tensor['values']
        returns = sample_Tensor['returns']
        advantages = sample_Tensor['advantages']
        avail_actions = sample_Tensor['avail_actions']
        agent_mask = sample_Tensor['agent_mask']
        filled = sample_Tensor['filled']
        seq_len = filled.shape[1]
        IDs = sample_Tensor['agent_ids']

        if self.use_parameter_sharing:
            # Repeat the per-sequence mask for every agent to match the (batch * n_agents) layout.
            filled = filled.unsqueeze(1).expand(batch_size, self.n_agents, seq_len).reshape(bs_rnn, seq_len)

        info = self.callback.on_update_start(self.iterations, method="update_rnn",
                                             policy=self.policy, sample_Tensor=sample_Tensor, bs_rnn=bs_rnn)

        rnn_hidden_actor = {k: self.policy.actor_representation[k].init_hidden(bs_rnn) for k in self.model_keys}
        rnn_hidden_critic = {k: self.policy.critic_representation[k].init_hidden(bs_rnn) for k in self.model_keys}

        # feedforward
        _, pi_dist_dict = self.policy(obs, agent_ids=IDs, avail_actions=avail_actions, rnn_hidden=rnn_hidden_actor)
        _, values_pred_dict = self.policy.get_values(obs, agent_ids=IDs, rnn_hidden=rnn_hidden_critic)

        # calculate losses for each agent
        loss_a, loss_e, loss_c = [], [], []
        for key in self.model_keys:
            mask_values = agent_mask[key] * filled
            # policy gradient loss
            log_pi = pi_dist_dict[key].log_prob(actions[key]).reshape(bs_rnn, seq_len)
            pg_loss = -((advantages[key].detach() * log_pi) * mask_values).sum() / mask_values.sum()
            loss_a.append(pg_loss)

            # entropy loss (reshaped like log_pi so it lines up element-wise with the mask)
            entropy = pi_dist_dict[key].entropy().reshape(bs_rnn, seq_len)
            entropy_loss = (entropy * mask_values).sum() / mask_values.sum()
            loss_e.append(entropy_loss)

            # value loss
            value_pred_i = values_pred_dict[key].reshape(bs_rnn, seq_len)
            value_target = returns[key].reshape(bs_rnn, seq_len)
            values_i = values[key].reshape(bs_rnn, seq_len)
            if self.use_value_norm:
                value_target = self._normalize_value_target(key, value_target)
            if self.use_value_clip:
                # Keep the new prediction within value_clip_range of the value stored in the sample.
                value_clipped = values_i + (value_pred_i - values_i).clamp(-self.value_clip_range,
                                                                           self.value_clip_range)
                if self.use_huber_loss:
                    loss_v = self.huber_loss(value_pred_i, value_target)
                    loss_v_clipped = self.huber_loss(value_clipped, value_target)
                else:
                    loss_v = (value_pred_i - value_target) ** 2
                    loss_v_clipped = (value_clipped - value_target) ** 2
                loss_c_ = torch.max(loss_v, loss_v_clipped) * mask_values
                loss_c.append(loss_c_.sum() / mask_values.sum())
            else:
                if self.use_huber_loss:
                    loss_v = self.huber_loss(value_pred_i, value_target)
                else:
                    loss_v = (value_pred_i - value_target) ** 2
                loss_c.append((loss_v * mask_values).sum() / mask_values.sum())

            info.update({
                f"predict_value/{key}": value_pred_i.mean().item()
            })

            info.update(self.callback.on_update_agent_wise(self.iterations, key, info=info, method="update_rnn",
                                                           mask_values=mask_values, log_pi=log_pi, pg_loss=pg_loss,
                                                           entropy=entropy, entropy_loss=entropy_loss,
                                                           value_pred_i=value_pred_i, value_target=value_target,
                                                           values_i=values_i, loss_v=loss_v))

        loss = sum(loss_a) + self.vf_coef * sum(loss_c) - self.ent_coef * sum(loss_e)
        self.optimizer.zero_grad()
        loss.backward()
        if self.use_grad_clip:
            grad_norm = torch.nn.utils.clip_grad_norm_(self.policy.parameters_model, self.grad_clip_norm)
            info["gradient_norm"] = grad_norm.item()
        self.optimizer.step()
        if self.scheduler is not None:
            self.scheduler.step()

        # Logger
        lr = self.optimizer.state_dict()['param_groups'][0]['lr']

        info.update({
            "learning_rate": lr,
            "pg_loss": sum(loss_a).item(),
            "vf_loss": sum(loss_c).item(),
            "entropy_loss": sum(loss_e).item(),
            "loss": loss.item(),
        })

        info.update(self.callback.on_update_end(self.iterations, method="update_rnn", policy=self.policy, info=info))

        return info