Source code for xuance.mindspore.learners.multi_agent_rl.qtran_learner

"""
QTRAN: Learning to Factorize with Transformation for Cooperative Multi-Agent Reinforcement Learning
Paper link:
http://proceedings.mlr.press/v97/son19a/son19a.pdf
Implementation: MindSpore
"""
from argparse import Namespace
from operator import itemgetter
from xuance.common import List
from xuance.mindspore import ms, Module, nn, Tensor, optim, ops
from xuance.mindspore.learners import LearnerMAS
from xuance.mindspore.utils import clip_grads


[docs] class QTRAN_Learner(LearnerMAS): def __init__(self, config: Namespace, model_keys: List[str], agent_keys: List[str], policy: Module, callback): self.sync_frequency = config.sync_frequency self.mse_loss = nn.MSELoss() super(QTRAN_Learner, self).__init__(config, model_keys, agent_keys, policy, callback) self.optimizer = optim.Adam(params=self.policy.parameters_model, lr=config.learning_rate, eps=1e-5) self.scheduler = optim.lr_scheduler.LinearLR(self.optimizer, start_factor=1.0, end_factor=self.end_factor_lr_decay, total_iters=self.config.running_steps) self.n_actions = {k: self.policy.action_space[k].n for k in self.model_keys} # Get gradient function self.grad_fn = ms.value_and_grad(self.forward_fn, None, self.optimizer.parameters, has_aux=True) self.policy.set_train()
[docs] def forward_fn(self, bs, batch_size, state, obs, actions, state_next, obs_next, agent_mask, avail_actions, avail_actions_next, rewards_tot, terminals_tot, IDs): info = {} _, hidden_state, actions_greedy, q_eval = self.policy(obs, agent_ids=IDs, avail_actions=avail_actions) _, hidden_state_next, q_next = self.policy.Qtarget(obs_next, agent_ids=IDs) q_eval_a, q_eval_greedy_a, q_next_a = {}, {}, {} actions_next_greedy = {} for key in self.model_keys: mask_values = agent_mask[key] q_eval_a[key] = q_eval[key].gather(-1, actions[key].long().unsqueeze(-1)).reshape(bs) q_eval_greedy_a[key] = q_eval[key].gather(-1, actions_greedy[key].long().unsqueeze(-1)).reshape(bs) if self.use_actions_mask: q_next[key][avail_actions_next[key] == 0] = -1e10 if self.config.double_q: _, _, act_next, _ = self.policy(observation=obs_next, agent_ids=IDs, avail_actions=avail_actions, agent_key=key) actions_next_greedy[key] = act_next[key] q_next_a[key] = q_next[key].gather(-1, act_next[key].long().unsqueeze(-1)).reshape(bs) else: actions_next_greedy[key] = q_next[key].argmax(dim=-1, keepdim=False) q_next_a[key] = q_next[key].max(dim=-1, keepdim=True).values.reshape(bs) q_eval_a[key] *= mask_values q_eval_greedy_a[key] *= mask_values q_next_a[key] *= mask_values info.update(self.callback.on_update_agent_wise(self.iterations, key, info=info, method="update", mask_values=mask_values, q_eval_a=q_eval_a, q_eval_greedy_a=q_eval_greedy_a)) if self.config.agent == "QTRAN_base": # -- TD Loss -- q_joint, v_joint = self.policy.Q_tran(state, hidden_state, actions, agent_mask) q_joint_next, _ = self.policy.Q_tran_target(state_next, hidden_state_next, actions_next_greedy, agent_mask) y_dqn = rewards_tot + (1 - terminals_tot) * self.gamma * q_joint_next loss_td = self.mse_loss(q_joint, ops.stop_gradient(y_dqn)) # TD loss # -- Opt Loss -- # Argmax across the current agents' actions q_tot_greedy = self.policy.Q_tot(q_eval_greedy_a) q_joint_greedy_hat, _ = self.policy.Q_tran(state, hidden_state, actions_greedy, agent_mask) error_opt = q_tot_greedy - ops.stop_gradient(q_joint_greedy_hat) + v_joint loss_opt = ops.reduce_mean(error_opt ** 2) # Opt loss # -- Nopt Loss -- q_tot = self.policy.Q_tot(q_eval_a) q_joint_hat = q_joint error_nopt = q_tot - ops.stop_gradient(q_joint_hat) + v_joint error_nopt = error_nopt.clamp(max=0) loss_nopt = ops.reduce_mean(error_nopt ** 2) # NOPT loss info["Q_joint"] = q_joint.mean().asnumpy() elif self.config.agent == "QTRAN_alt": # -- TD Loss -- (Computed for all agents) q_count, v_joint = self.policy.Q_tran(state, hidden_state, actions, agent_mask) actions_choosen = ops.stack(itemgetter(*self.model_keys)(actions), axis=0).reshape(-1, self.n_agents, 1) q_joint_choosen = q_count.gather(-1, actions_choosen.long()).reshape(-1, self.n_agents) q_next_count, _ = self.policy.Q_tran_target(state_next, hidden_state_next, actions_next_greedy, agent_mask) actions_next_choosen = ops.stack(itemgetter(*self.model_keys)(actions_next_greedy), axis=0) actions_next_choosen = actions_next_choosen.reshape(-1, self.n_agents, 1) q_joint_next_choosen = q_next_count.gather(-1, actions_next_choosen.long()).reshape(-1, self.n_agents) y_dqn = rewards_tot + (1 - terminals_tot) * self.gamma * q_joint_next_choosen loss_td = self.mse_loss(q_joint_choosen, ops.stop_gradient(y_dqn)) # TD loss # -- Opt Loss -- (Computed for all agents) q_tot_greedy = self.policy.Q_tot(q_eval_greedy_a) q_joint_greedy_hat, _ = self.policy.Q_tran(state, hidden_state, actions_greedy, agent_mask) actions_greedy_current = ops.stack(itemgetter(*self.model_keys)(actions_greedy), axis=0) actions_greedy_current = actions_greedy_current.reshape(-1, self.n_agents, 1) q_joint_greedy_hat_all = q_joint_greedy_hat.gather( -1, actions_greedy_current.long()).reshape(-1, self.n_agents) error_opt = q_tot_greedy - ops.stop_gradient(q_joint_greedy_hat_all) + v_joint loss_opt = ops.reduce_mean(error_opt ** 2) # Opt loss # -- Nopt Loss -- q_eval_count = ops.stack(itemgetter(*self.model_keys)(q_eval), axis=0).reshape(batch_size * self.n_agents, -1) q_sums = ops.stack(itemgetter(*self.model_keys)(q_eval_a), axis=0).reshape(-1, self.n_agents) q_sums_repeat = ops.repeat_elements(q_sums.unsqueeze(1), rep=self.n_agents, axis=1) agent_mask_diag = ops.repeat_elements((1 - ops.eye(self.n_agents, dtype=ms.float32)).unsqueeze(0), rep=batch_size, axis=0) q_sum_mask = (q_sums_repeat * agent_mask_diag).sum(axis=-1) q_count_for_nopt = q_count.view(batch_size * self.n_agents, -1) v_joint_repeated = ops.repeat_elements(v_joint, rep=self.n_agents, axis=1).reshape(-1, 1) error_nopt = q_eval_count + q_sum_mask.view(-1, 1) - ops.stop_gradient(q_count_for_nopt) + v_joint_repeated error_nopt_min, _ = ops.min(error_nopt, axis=-1) loss_nopt = ops.reduce_mean(error_nopt_min ** 2) # NOPT loss info["Q_joint"] = q_joint_choosen.mean().asnumpy() else: raise ValueError("Mixer {} not recognised.".format(self.config.agent)) # calculate the loss function loss = loss_td + self.config.lambda_opt * loss_opt + self.config.lambda_nopt * loss_nopt return loss, loss_td, loss_opt, loss_nopt, \ v_joint, y_dqn, q_tot_greedy, q_joint_greedy_hat, error_opt, error_nopt
[docs] def update(self, sample): self.iterations += 1 # prepare training data sample_Tensor = self.build_training_data(sample=sample, use_parameter_sharing=self.use_parameter_sharing, use_actions_mask=self.use_actions_mask, use_global_state=True) batch_size = sample_Tensor['batch_size'] state = sample_Tensor['state'] state_next = sample_Tensor['state_next'] obs = sample_Tensor['obs'] actions = sample_Tensor['actions'] obs_next = sample_Tensor['obs_next'] rewards = sample_Tensor['rewards'] terminals = sample_Tensor['terminals'] agent_mask = sample_Tensor['agent_mask'] avail_actions = sample_Tensor['avail_actions'] avail_actions_next = sample_Tensor['avail_actions_next'] IDs = sample_Tensor['agent_ids'] if self.use_parameter_sharing: key = self.model_keys[0] bs = batch_size * self.n_agents rewards_tot = rewards[key].mean(dim=1).reshape(batch_size, 1) terminals_tot = terminals[key].all(dim=1, keepdim=False).float().reshape(batch_size, 1) else: bs = batch_size rewards_tot = ops.stack(itemgetter(*self.agent_keys)(rewards), axis=1).mean(dim=-1, keepdim=True) terminals_tot = ops.stack(itemgetter(*self.agent_keys)(terminals), axis=1).all(dim=1, keepdim=True).float() info = self.callback.on_update_start(self.iterations, method="update", policy=self.policy, sample_Tensor=sample_Tensor, bs=bs, rewards_tot=rewards_tot, terminals_tot=terminals_tot) (loss, loss_td, loss_opt, loss_nopt, v_joint, y_dqn, q_tot_greedy, q_joint_greedy_hat, error_opt, error_nopt), grads = self.grad_fn(bs, batch_size, state, obs, actions, state_next, obs_next, agent_mask, avail_actions, avail_actions_next, rewards_tot, terminals_tot, IDs) if self.use_grad_clip: grads = clip_grads(grads, Tensor(-self.grad_clip_norm), Tensor(self.grad_clip_norm)) self.optimizer(grads) if self.iterations % self.sync_frequency == 0: self.policy.copy_target() self.scheduler.step() lr = self.scheduler.get_last_lr()[0] info.update({ "learning_rate": lr.asnumpy(), "loss_td": loss_td.asnumpy(), "loss_opt": loss_opt.asnumpy(), "loss_nopt": loss_nopt.asnumpy(), "loss": loss.asnumpy() }) info.update(self.callback.on_update_end(self.iterations, method="update", policy=self.policy, info=info, v_joint=v_joint, y_dqn=y_dqn, q_tot_greedy=q_tot_greedy, q_joint_greedy_hat=q_joint_greedy_hat, error_opt=error_opt, error_nopt=error_nopt)) return info