# Source code for xuance.torch.learners.multi_agent_rl.dcg_learner

"""
DCG: Deep coordination graphs
Paper link: http://proceedings.mlr.press/v119/boehmer20a/boehmer20a.pdf
Implementation: Pytorch
"""
import torch
from torch import nn
from operator import itemgetter
from xuance.torch.learners import LearnerMAS
from xuance.common import List
from argparse import Namespace
try:
    import torch_scatter
except ImportError:
    print("The module torch_scatter is not installed.")


class DCG_Learner(LearnerMAS):
    """Learner for DCG (Deep Coordination Graphs) and its state-bias variant DCG_S.

    The joint value of an action profile is the sum of:
      * per-agent utilities, each scaled by 1 / n_vertexes;
      * pairwise payoffs along the graph edges, each scaled by 1 / n_edges;
      * for DCG_S only, a bias computed from the global state.

    Greedy joint actions are found with max-plus message passing over the graph.
    The TD target is double-Q style: the online network selects the next greedy
    joint action and the target network evaluates it.
    """

    def __init__(self,
                 config: Namespace,
                 model_keys: List[str],
                 agent_keys: List[str],
                 policy: nn.Module,
                 callback):
        super(DCG_Learner, self).__init__(config, model_keys, agent_keys, policy, callback)
        self.optimizer = torch.optim.Adam(self.policy.parameters_model, self.learning_rate, eps=1e-5)
        # Linearly scales the learning rate from 1.0x to ``end_factor_lr_decay``x over ``total_iters`` steps.
        self.scheduler = torch.optim.lr_scheduler.LinearLR(self.optimizer,
                                                           start_factor=1.0,
                                                           end_factor=self.end_factor_lr_decay,
                                                           total_iters=self.total_iters)
        self.dim_hidden_state = policy.representation[self.model_keys[0]].output_shapes['state'][0]
        # Width of the shared action axis: the largest discrete action space over all agents.
        self.dim_act = max(self.policy.action_space[key].n for key in agent_keys)
        self.sync_frequency = config.sync_frequency
        self.mse_loss = nn.MSELoss()

    def get_graph_values(self, hidden_states, use_target_net=False):
        """Evaluate the utility and payoff heads of the coordination graph.

        Args:
            hidden_states (torch.Tensor): Hidden states of all agents.
            use_target_net (bool): If True, use the target heads instead of the online ones.

        Returns:
            tuple: ``(utilities, payoff)``, i.e. the per-agent utilities and the per-edge
            payoffs for the edges ``(edges_from[e], edges_to[e])``.
        """
        graph = self.policy.graph
        if use_target_net:
            utilities = self.policy.target_utility(hidden_states)
            payoff = self.policy.target_payoffs(hidden_states, graph.edges_from, graph.edges_to)
        else:
            utilities = self.policy.utility(hidden_states)
            payoff = self.policy.payoffs(hidden_states, graph.edges_from, graph.edges_to)
        return utilities, payoff

    @staticmethod
    def _gather_messages(messages, index, n_vertexes):
        """Sum per-edge messages into the vertexes selected by ``index``.

        Pure-PyTorch equivalent of ``torch_scatter.scatter_add(src=messages, index=index,
        dim=1, dim_size=n_vertexes)`` for a 1-D ``index``.
        """
        # NOTE(review): assumes ``index`` is an integer tensor on the same device as ``messages``.
        out = messages.new_zeros(messages.shape[0], n_vertexes, messages.shape[-1])
        return out.index_add_(1, index, messages)

    def act(self, hidden_states, avail_actions=None):
        """
        Calculate the greedy joint actions via max-plus belief propagation.

        Args:
            hidden_states (torch.Tensor): The hidden states for the representation of all agents.
            avail_actions (torch.Tensor or array-like): The available-action mask of the agents,
                where 0 marks an unavailable action. Default is None (no masking).

        Returns:
            torch.Tensor: The greedy action index of every agent, i.e. the argmax of the
            message-augmented utilities over the last axis.
        """
        graph = self.policy.graph
        with torch.no_grad():
            f_i, f_ij = self.get_graph_values(hidden_states)
        n_edges = graph.n_edges
        n_vertexes = graph.n_vertexes
        # Same normalisation as q_dcg: utilities by |V|, payoffs by |E|.
        f_i_mean = f_i.double() / n_vertexes
        f_ij_mean = f_ij.double() / n_edges
        # Payoff tables seen from the other endpoint of each edge (the two action axes swapped).
        f_ji_mean = f_ij_mean.transpose(-1, -2)
        batch_size = f_i.shape[0]

        msg_ij = torch.zeros(batch_size, n_edges, self.dim_act,
                             dtype=f_i_mean.dtype, device=self.device)  # i -> j (send)
        msg_ji = torch.zeros_like(msg_ij)  # j -> i (receive)
        utility = (f_i_mean
                   + self._gather_messages(msg_ij, graph.edges_to, n_vertexes)
                   + self._gather_messages(msg_ji, graph.edges_from, n_vertexes))

        if len(graph.edges) != 0:
            for _ in range(self.config.n_msg_iterations):
                # Before maximising over the sender's action (axis -2), remove the message the
                # sender received over this same edge so it is not echoed back.
                joint_forward = (utility[:, graph.edges_from, :] - msg_ji).unsqueeze(dim=-1) + f_ij_mean
                joint_backward = (utility[:, graph.edges_to, :] - msg_ij).unsqueeze(dim=-1) + f_ji_mean
                msg_ij = joint_forward.max(dim=-2).values
                msg_ji = joint_backward.max(dim=-2).values
                if self.config.msg_normalized:
                    msg_ij = msg_ij - msg_ij.mean(dim=-1, keepdim=True)
                    msg_ji = msg_ji - msg_ji.mean(dim=-1, keepdim=True)
                utility = (f_i_mean
                           + self._gather_messages(msg_ij, graph.edges_to, n_vertexes)
                           + self._gather_messages(msg_ji, graph.edges_from, n_vertexes))

        if avail_actions is None:
            return utility.argmax(dim=-1)
        # Accept numpy arrays or tensors on any device; build the mask next to ``utility``.
        avail_actions = torch.as_tensor(avail_actions, device=utility.device)
        return utility.masked_fill(avail_actions == 0, -1e10).argmax(dim=-1)

    def q_dcg(self, hidden_states, actions, states=None, use_target_net=False):
        """Compute the joint value Q_tot of the given joint actions.

        Args:
            hidden_states (torch.Tensor): Hidden states of all agents.
            actions (torch.Tensor): Action index of every agent, indexed as ``actions[:, agent]``.
            states (torch.Tensor): Global states for the state-bias head (used only by DCG_S).
            use_target_net (bool): If True, evaluate the target utility/payoff heads.

        Returns:
            torch.Tensor: Q_tot per sample, in float64 (shape (batch, 1) when the utilities
            are (batch, n_agents, dim_act)).
        """
        graph = self.policy.graph
        f_i, f_ij = self.get_graph_values(hidden_states, use_target_net=use_target_net)
        actions = actions.long()
        f_i_mean = f_i.double() / graph.n_vertexes
        q_values = f_i_mean.gather(-1, actions.unsqueeze(dim=-1)).sum(dim=1)

        # Payoffs are skipped for an edgeless graph or when message passing is disabled.
        if len(graph.edges) != 0 and self.config.n_msg_iterations != 0:
            f_ij_mean = f_ij.double() / graph.n_edges
            # Flatten each edge's (dim_act x dim_act) payoff table; entry (a_i, a_j) sits at a_i * dim_act + a_j.
            actions_ij = (actions[:, graph.edges_from] * self.dim_act + actions[:, graph.edges_to]).unsqueeze(-1)
            flat_payoffs = f_ij_mean.reshape(list(f_ij_mean.shape[0:-2]) + [-1])
            q_values = q_values + flat_payoffs.gather(-1, actions_ij).sum(dim=1)

        # The DCG_S state bias applies regardless of the graph topology.
        if self.config.agent == "DCG_S":
            # NOTE(review): the online ``bias`` head is used even when use_target_net=True — confirm intended.
            q_values = q_values + self.policy.bias(states)
        return q_values

    def _stack_agents(self, data, dim):
        """Stack the per-agent tensors of ``data`` along ``dim``, in ``agent_keys`` order.

        This works for a single agent too, unlike ``torch.stack(itemgetter(*keys)(data))``:
        with one key, ``itemgetter`` returns the bare tensor instead of a 1-tuple.
        """
        return torch.stack([data[k] for k in self.agent_keys], dim=dim)

    def _step_optimizer(self, loss):
        """Back-propagate ``loss``, optionally clip gradients, and step the optimizer and scheduler.

        Returns:
            float: The learning rate of the first parameter group after the step.
        """
        self.optimizer.zero_grad()
        loss.backward()
        if self.use_grad_clip:
            torch.nn.utils.clip_grad_norm_(self.policy.parameters_model, self.grad_clip_norm)
        self.optimizer.step()
        if self.scheduler is not None:
            self.scheduler.step()
        return self.optimizer.param_groups[0]['lr']

    def update(self, sample):
        """Perform one TD update from a batch of single-step transitions.

        Args:
            sample (dict): The sampled batch, converted by ``build_training_data``.

        Returns:
            dict: Training information (learning rate, loss, mean predicted Q, callback extras).
        """
        self.iterations += 1
        use_dcg_s = self.config.agent == "DCG_S"

        # prepare training data
        sample_Tensor = self.build_training_data(sample=sample,
                                                 use_parameter_sharing=self.use_parameter_sharing,
                                                 use_actions_mask=self.use_actions_mask,
                                                 use_global_state=use_dcg_s)
        batch_size = sample_Tensor['batch_size']
        state = sample_Tensor['state']
        state_next = sample_Tensor['state_next']
        obs = sample_Tensor['obs']
        actions = sample_Tensor['actions']
        obs_next = sample_Tensor['obs_next']
        rewards = sample_Tensor['rewards']
        terminals = sample_Tensor['terminals']
        avail_actions_next = sample_Tensor['avail_actions_next']

        # The team reward is the mean over agents; the episode ends only when all agents are terminal.
        if self.use_parameter_sharing:
            key = self.model_keys[0]
            rewards_tot = rewards[key].mean(dim=1).reshape(batch_size, 1)
            terminals_tot = terminals[key].all(dim=1, keepdim=False).float().reshape(batch_size, 1)
            actions = actions[key].reshape(batch_size, self.n_agents)
            if self.use_actions_mask:
                avail_actions_next = avail_actions_next[key].reshape(batch_size, self.n_agents, -1)
        else:
            rewards_tot = self._stack_agents(rewards, dim=1).mean(dim=-1, keepdim=True)
            terminals_tot = self._stack_agents(terminals, dim=1).all(dim=1, keepdim=True).float()
            actions = self._stack_agents(actions, dim=-1)
            if self.use_actions_mask:
                avail_actions_next = self._stack_agents(avail_actions_next, dim=-2)

        info = self.callback.on_update_start(self.iterations, method="update",
                                             policy=self.policy, sample_Tensor=sample_Tensor,
                                             rewards_tot=rewards_tot, terminals_tot=terminals_tot,
                                             actions=actions, avail_actions_next=avail_actions_next)

        _, hidden_states = self.policy.get_hidden_states(batch_size, obs, use_target_net=False)
        q_tot_eval = self.q_dcg(hidden_states, actions, states=state, use_target_net=False)

        # Double-Q target: the online network picks the next joint action, the target network scores it.
        _, hidden_states_next = self.policy.get_hidden_states(batch_size, obs_next, use_target_net=False)
        action_next_greedy = self.act(hidden_states_next, avail_actions_next).to(self.device)
        _, hidden_states_target = self.policy.get_hidden_states(batch_size, obs_next, use_target_net=True)
        q_tot_next = self.q_dcg(hidden_states_target, action_next_greedy, states=state_next, use_target_net=True)
        q_tot_target = rewards_tot + (1 - terminals_tot) * self.gamma * q_tot_next

        # calculate the loss function
        loss = self.mse_loss(q_tot_eval, q_tot_target.detach())
        lr = self._step_optimizer(loss)

        info.update({
            "learning_rate": lr,
            "loss_Q": loss.item(),
            "predictQ": q_tot_eval.mean().item()
        })

        if self.iterations % self.sync_frequency == 0:
            self.policy.copy_target()

        info.update(self.callback.on_update_end(self.iterations, method="update", policy=self.policy, info=info,
                                                hidden_states=hidden_states, q_tot_eval=q_tot_eval,
                                                hidden_states_next=hidden_states_next,
                                                action_next_greedy=action_next_greedy,
                                                hidden_states_target=hidden_states_target,
                                                q_tot_next=q_tot_next, q_tot_target=q_tot_target, loss=loss))
        return info

    def update_rnn(self, sample):
        """Perform one TD update from a batch of padded episodes with recurrent representations.

        Args:
            sample (dict): The sampled batch; ``sample['sequence_length']`` is the number of
                time steps per episode.

        Returns:
            dict: Training information (learning rate, loss, mean predicted Q, callback extras).
        """
        self.iterations += 1
        use_dcg_s = self.config.agent == "DCG_S"

        # prepare training data
        sample_Tensor = self.build_training_data(sample=sample,
                                                 use_parameter_sharing=self.use_parameter_sharing,
                                                 use_actions_mask=self.use_actions_mask,
                                                 use_global_state=use_dcg_s)
        batch_size = sample_Tensor['batch_size']
        seq_len = sample['sequence_length']
        n_samples = batch_size * seq_len  # every (episode, time step) pair becomes one row
        state = sample_Tensor['state']
        obs = sample_Tensor['obs']
        actions = sample_Tensor['actions']
        rewards = sample_Tensor['rewards']
        terminals = sample_Tensor['terminals']
        avail_actions = sample_Tensor['avail_actions']
        filled = sample_Tensor['filled'].reshape([-1, 1])

        if self.use_parameter_sharing:
            key = self.model_keys[0]
            bs_rnn = batch_size * self.n_agents
            rewards_tot = rewards[key].mean(dim=1).reshape([-1, 1])
            terminals_tot = terminals[key].all(dim=1, keepdim=False).float().reshape([-1, 1])
            actions = actions[key].reshape(batch_size, self.n_agents, seq_len).transpose(1, 2)
            if self.use_actions_mask:
                avail_actions = avail_actions[key].reshape(batch_size, self.n_agents, seq_len + 1, -1).transpose(1, 2)
        else:
            bs_rnn = batch_size
            rewards_tot = self._stack_agents(rewards, dim=1).mean(dim=1).reshape(-1, 1)
            terminals_tot = self._stack_agents(terminals, dim=1).all(1).reshape([-1, 1]).float()
            actions = self._stack_agents(actions, dim=-1)
            if self.use_actions_mask:
                avail_actions = self._stack_agents(avail_actions, dim=-2)

        info = self.callback.on_update_start(self.iterations, method="update_rnn",
                                             policy=self.policy, sample_Tensor=sample_Tensor, bs_rnn=bs_rnn,
                                             rewards_tot=rewards_tot, terminals_tot=terminals_tot,
                                             actions=actions, avail_actions=avail_actions)

        rnn_hidden = {k: self.policy.representation[k].init_hidden(bs_rnn) for k in self.model_keys}
        _, hidden_states = self.policy.get_hidden_states(batch_size, obs, rnn_hidden, use_target_net=False)

        if use_dcg_s:
            # Flatten (batch, time) so the state bias lines up with the (batch * seq_len, 1) Q terms.
            # NOTE(review): assumes state is (batch, seq_len + 1, state_dim) — confirm against the buffer.
            state_current = state[:, :-1].reshape(n_samples, -1)
            state_next = state[:, 1:].reshape(n_samples, -1)
        else:
            state_current = state_next = None

        q_tot_eval = self.q_dcg(hidden_states[:, :-1].reshape(n_samples, self.n_agents, -1),
                                actions.reshape(n_samples, self.n_agents),
                                states=state_current, use_target_net=False)

        # Double-Q target: the online network picks the next joint action, the target network scores it.
        avail_a_next = avail_actions[:, 1:].reshape(n_samples, self.n_agents, -1) if self.use_actions_mask else None
        hidden_states_next = hidden_states[:, 1:].reshape(n_samples, self.n_agents, -1)
        action_next_greedy = self.act(hidden_states_next, avail_actions=avail_a_next).to(self.device)
        rnn_hidden_target = {k: self.policy.target_representation[k].init_hidden(bs_rnn) for k in self.model_keys}
        _, hidden_states_tar = self.policy.get_hidden_states(batch_size, obs, rnn_hidden_target, use_target_net=True)
        q_tot_next = self.q_dcg(hidden_states_tar[:, 1:].reshape(n_samples, self.n_agents, -1),
                                action_next_greedy, states=state_next, use_target_net=True)
        q_tot_target = rewards_tot + (1 - terminals_tot) * self.gamma * q_tot_next

        # calculate the loss function: squared TD error averaged over the filled (non-padded) steps
        td_error = (q_tot_eval - q_tot_target.detach()) * filled
        loss = (td_error ** 2).sum() / filled.sum()
        lr = self._step_optimizer(loss)

        info.update({
            "learning_rate": lr,
            "loss_Q": loss.item(),
            "predictQ": q_tot_eval.mean().item()
        })

        if self.iterations % self.sync_frequency == 0:
            self.policy.copy_target()

        info.update(self.callback.on_update_end(self.iterations, method="update_rnn", policy=self.policy, info=info,
                                                hidden_states=hidden_states, q_tot_eval=q_tot_eval,
                                                hidden_states_next=hidden_states_next,
                                                action_next_greedy=action_next_greedy,
                                                hidden_states_target=hidden_states_tar,
                                                q_tot_next=q_tot_next, q_tot_target=q_tot_target, loss=loss))
        return info