Source code for xuance.mindspore.learners.multi_agent_rl.coma_learner

"""
COMA: Counterfactual Multi-Agent Policy Gradients
Paper link: https://ojs.aaai.org/index.php/AAAI/article/view/11794
Implementation: MindSpore
"""
import numpy as np
from argparse import Namespace
from operator import itemgetter
from xuance.common import List, Optional
from xuance.mindspore import ms, msd, Tensor, Module, ops, nn, optim
from xuance.mindspore.utils import clip_grads
from xuance.mindspore.learners import LearnerMAS  # Different from torch


[docs] class COMA_Learner(LearnerMAS): def __init__(self, config: Namespace, model_keys: List[str], agent_keys: List[str], policy: Module, callback): config.use_value_clip, config.value_clip_range = False, None config.use_huber_loss, config.huber_delta = False, None config.use_value_norm = False config.vf_coef, config.ent_coef = None, None super(COMA_Learner, self).__init__(config, model_keys, agent_keys, policy, callback) self.build_optimizer() self.use_value_clip, self.value_clip_range = config.use_value_clip, config.value_clip_range self.use_huber_loss, self.huber_delta = config.use_huber_loss, config.huber_delta self.use_value_norm = config.use_value_norm self.vf_coef, self.ent_coef = config.vf_coef, config.ent_coef self.mse_loss = nn.MSELoss() self.huber_loss = nn.HuberLoss(reduction="none", delta=self.huber_delta) self.sync_frequency = config.sync_frequency self.n_actions = {k: int(self.policy.action_space[k].n) for k in self.model_keys} self.mse_loss = nn.MSELoss() self.softmax = nn.Softmax(axis=-1) self.is_continuous = self.policy.is_continuous if self.is_continuous: self.pi_dist = {k: msd.Normal(dtype=ms.float32) for k in self.model_keys} else: self.pi_dist = {k: msd.Categorical() for k in self.model_keys} # Get gradient function self.actor_grad_fn = ms.value_and_grad(self.actor_forward_fn, None, self.optimizer['actor'].parameters, has_aux=True) self.critic_grad_fn = ms.value_and_grad(self.critic_forward_fn, None, self.optimizer['critic'].parameters, has_aux=True) self.policy.set_train()
[docs] def build_optimizer(self): self.optimizer = { 'actor': optim.Adam(params=self.policy.parameters_actor, lr=self.config.learning_rate_actor, eps=1e-5), 'critic': optim.Adam(params=self.policy.parameters_critic, lr=self.config.learning_rate_critic, eps=1e-5) } self.scheduler = { 'actor': optim.lr_scheduler.LinearLR(self.optimizer['actor'], start_factor=1.0, end_factor=self.end_factor_lr_decay, total_iters=self.config.running_steps), 'critic': optim.lr_scheduler.LinearLR(self.optimizer['critic'], start_factor=1.0, end_factor=self.end_factor_lr_decay, total_iters=self.config.running_steps) }
[docs] def build_training_data(self, sample: Optional[dict], use_parameter_sharing: Optional[bool] = False, use_actions_mask: Optional[bool] = False, use_global_state: Optional[bool] = False): """ Prepare the training data. Parameters: sample (dict): The raw sampled data. use_parameter_sharing (bool): Whether to use parameter sharing for individual agent models. use_actions_mask (bool): Whether to use actions mask for unavailable actions. use_global_state (bool): Whether to use global state. Returns: sample_Tensor (dict): The formatted sampled data. """ batch_size = sample['batch_size'] seq_length = sample['sequence_length'] if self.use_rnn else 1 state, avail_actions, filled, IDs = None, None, None, None if use_parameter_sharing: k = self.model_keys[0] bs = batch_size * self.n_agents obs_tensor = Tensor(np.stack(itemgetter(*self.agent_keys)(sample['obs']), axis=1)) actions_tensor = Tensor(np.stack(itemgetter(*self.agent_keys)(sample['actions']), axis=1)) values_tensor = Tensor(np.stack(itemgetter(*self.agent_keys)(sample['values']), axis=1)) returns_tensor = Tensor(np.stack(itemgetter(*self.agent_keys)(sample['returns']), axis=1)) advantages_tensor = Tensor(np.stack(itemgetter(*self.agent_keys)(sample['advantages']), 1)) log_pi_old_tensor = Tensor(np.stack(itemgetter(*self.agent_keys)(sample['log_pi_old']), 1)) ter_tensor = Tensor(np.stack(itemgetter(*self.agent_keys)(sample['terminals']), 1)).float() msk_tensor = Tensor(np.stack(itemgetter(*self.agent_keys)(sample['agent_mask']), 1)).float() if self.use_rnn: obs = {k: obs_tensor.reshape(bs, seq_length, -1)} if len(actions_tensor.shape) == 3: actions = {k: actions_tensor.reshape(bs, seq_length)} elif len(actions_tensor.shape) == 4: actions = {k: actions_tensor.reshape(bs, seq_length, -1)} else: raise AttributeError("Wrong actions shape.") values = {k: values_tensor.reshape(bs, seq_length)} returns = {k: returns_tensor.reshape(bs, seq_length)} advantages = {k: advantages_tensor.reshape(bs, seq_length)} log_pi_old = {k: log_pi_old_tensor.reshape(bs, seq_length)} terminals = {k: ter_tensor.reshape(bs, seq_length)} agent_mask = {k: msk_tensor.reshape(bs, seq_length)} IDs = ops.eye(self.n_agents).unsqueeze(1).unsqueeze(0).expand( batch_size, -1, seq_length, -1).reshape(bs, seq_length, self.n_agents) else: obs = {k: obs_tensor.reshape(bs, -1)} if len(actions_tensor.shape) == 2: actions = {k: actions_tensor.reshape(bs)} elif len(actions_tensor.shape) == 3: actions = {k: actions_tensor.reshape(bs, -1)} else: raise AttributeError("Wrong actions shape.") values = {k: values_tensor.reshape(bs)} returns = {k: returns_tensor.reshape(bs)} advantages = {k: advantages_tensor.reshape(bs)} log_pi_old = {k: log_pi_old_tensor.reshape(bs)} terminals = {k: ter_tensor.reshape(bs)} agent_mask = {k: msk_tensor.reshape(bs)} IDs = Tensor(np.eye(self.n_agents, dtype=np.float32)[None].repeat(batch_size, axis=0).reshape(bs, -1)) if use_actions_mask: avail_a = np.stack(itemgetter(*self.agent_keys)(sample['avail_actions']), axis=1) if self.use_rnn: avail_actions = {k: Tensor(avail_a.reshape([bs, seq_length, -1])).float()} else: avail_actions = {k: Tensor(avail_a.reshape([bs, -1])).float()} else: obs = {k: Tensor(sample['obs'][k]) for k in self.agent_keys} actions = {k: Tensor(sample['actions'][k]) for k in self.agent_keys} values = {k: Tensor(sample['values'][k]) for k in self.agent_keys} returns = {k: Tensor(sample['returns'][k]) for k in self.agent_keys} advantages = {k: Tensor(sample['advantages'][k]) for k in self.agent_keys} log_pi_old = {k: Tensor(sample['log_pi_old'][k]) for k in self.agent_keys} terminals = {k: Tensor(sample['terminals'][k]).float() for k in self.agent_keys} agent_mask = {k: Tensor(sample['agent_mask'][k]).float() for k in self.agent_keys} if use_actions_mask: avail_actions = {k: Tensor(sample['avail_actions'][k]).float() for k in self.agent_keys} if use_global_state: state = Tensor(sample['state']) if self.use_rnn: filled = Tensor(sample['filled']).float() sample_Tensor = { 'batch_size': batch_size, 'state': state, 'obs': obs, 'actions': actions, 'values': values, 'returns': returns, 'advantages': advantages, 'log_pi_old': log_pi_old, 'terminals': terminals, 'agent_mask': agent_mask, 'avail_actions': avail_actions, 'agent_ids': IDs, 'filled': filled, 'seq_length': seq_length, } return sample_Tensor
[docs] def actor_forward_fn(self, bs, obs, actions, values_pred_dict, agent_mask, avail_actions, IDs, epsilon): # feedforward _, pi_logits = self.policy(observation=obs, agent_ids=IDs, avail_actions=avail_actions, epsilon=epsilon) # calculate loss loss_a, advantages_list = [], [] for key in self.model_keys: mask_values = agent_mask[key] mask_values_sum = ops.reduce_sum(mask_values) pi_probs = self.softmax(pi_logits[key]) pi_probs = (1 - epsilon) * pi_probs + epsilon * 1 / self.n_actions[key] baseline = ops.reduce_sum(pi_probs * values_pred_dict[key], axis=-1).reshape(bs) pi_taken = pi_probs.gather(-1, actions[key].unsqueeze(-1).long()) q_taken = values_pred_dict[key].gather(-1, actions[key].unsqueeze(-1).long()).reshape(bs) log_pi_taken = ops.log(pi_taken).reshape(bs) advantages = ops.stop_gradient(q_taken - baseline) loss_a.append(-(advantages * log_pi_taken * mask_values).sum() / mask_values_sum) # info.update(self.callback.on_update_agent_wise(self.iterations, key, info=info, method="update", # mask_values=mask_values, pi_probs=pi_probs, # baseline=baseline, pi_taken=pi_taken, # q_taken=q_taken, log_pi_taken=log_pi_taken, # advantages=advantages, loss_a=loss_a, # td_error=td_error)) advantages_list.append(advantages) loss_coma = sum(loss_a) advantages_mean = sum(advantages_list) / len(advantages_list) return loss_coma, advantages_mean
[docs] def critic_forward_fn(self, bs, batch_size, obs, state, actions, agent_mask, returns, IDs): # forward if self.use_parameter_sharing: key = self.model_keys[0] actions_onehot = {key: ops.one_hot(actions[key].long(), depth=self.n_actions[key])} else: IDs = ops.repeat_elements(ops.eye(self.n_agents).unsqueeze(0), rep=batch_size, axis=0).reshape(bs, -1) actions_onehot = {k: ops.one_hot(actions[k].long(), depth=self.n_actions[k]) for k in self.agent_keys} _, values_pred = self.policy.get_values(state=state, observation=obs, actions=actions_onehot, agent_ids=IDs, target=False) if self.use_parameter_sharing: values_pred_dict = {k: values_pred.reshape(bs, -1) for k in self.model_keys} else: values_pred_dict = {k: values_pred[:, i] for i, k in enumerate(self.model_keys)} # calculate loss loss_c = [] for key in self.model_keys: mask_values = agent_mask[key] q_taken = values_pred_dict[key].gather(-1, actions[key].unsqueeze(-1).long()).reshape(bs) td_error = (q_taken - ops.stop_gradient(returns[key])) * mask_values loss_c.append((td_error ** 2).sum() / mask_values.sum()) loss_critic = sum(loss_c) return loss_critic, values_pred_dict
[docs] def update(self, sample, epsilon=0.0): self.iterations += 1 # prepare training data sample_Tensor = self.build_training_data(sample=sample, use_parameter_sharing=self.use_parameter_sharing, use_actions_mask=self.use_actions_mask, use_global_state=True) batch_size = sample_Tensor['batch_size'] state = sample_Tensor['state'] obs = sample_Tensor['obs'] actions = sample_Tensor['actions'] agent_mask = sample_Tensor['agent_mask'] avail_actions = sample_Tensor['avail_actions'] returns = sample_Tensor['returns'] IDs = sample_Tensor['agent_ids'] bs = batch_size * self.n_agents if self.use_parameter_sharing else batch_size info = self.callback.on_update_start(self.iterations, method="update", policy=self.policy, sample_Tensor=sample_Tensor, bs=bs) # update critic (loss_critic, values_pred_dict), grad_critic = self.critic_grad_fn(bs, batch_size, obs, state, actions, agent_mask, returns, IDs) if self.use_grad_clip: grad_critic = clip_grads(grad_critic, Tensor(-self.grad_clip_norm), Tensor(self.grad_clip_norm)) self.optimizer['critic'](grad_critic) self.scheduler['critic'].step() # update learning rate # update actor (loss_coma, advantages_mean), grad_actor = self.actor_grad_fn(bs, obs, actions, values_pred_dict, agent_mask, avail_actions, IDs, epsilon) if self.use_grad_clip: grad_actor = clip_grads(grad_actor, Tensor(-self.grad_clip_norm), Tensor(self.grad_clip_norm)) self.optimizer['actor'](grad_actor) self.scheduler['actor'].step() # update learning rate # Logger learning_rate_actor = self.scheduler['actor'].get_last_lr()[0] learning_rate_critic = self.scheduler['critic'].get_last_lr()[0] info.update({ "learning_rate_actor": learning_rate_actor.asnumpy(), "learning_rate_critic": learning_rate_critic.asnumpy(), "actor_loss": loss_coma.asnumpy(), "critic_loss": loss_critic.asnumpy(), "advantage": advantages_mean.mean().asnumpy(), }) info.update(self.callback.on_update_end(self.iterations, method="update", policy=self.policy, info=info)) return info