Source code for xuance.torch.communications.gnn_comm

from argparse import Namespace
from typing import Optional, Sequence, Union
import torch
from torch import nn
from xuance.torch.communications.comm_net import CommNet


[docs] class DGNComm(CommNet): def __init__(self, input_shape: Sequence[int], hidden_sizes: dict, comm_passes: Optional[int] = 1, model_keys: dict = None, agent_keys: dict = None, n_agents: int = 1, device: Optional[Union[str, int, torch.device]] = None, config: Optional[Namespace] = None, **kwargs): super(DGNComm, self).__init__(input_shape, hidden_sizes, comm_passes, model_keys, agent_keys, n_agents, device, config, **kwargs) self.n_head = config.n_head self.convolution_layers = config.convolution_layers self.gcn = self.build_gcn(self.convolution_layers, self.n_head) self.activation = nn.ReLU()
[docs] def build_gcn(self, num_layers, num_heads, activation = nn.ReLU()): gcn_layers = [] for _ in range(num_layers): gcn_layers.append(GraphMultiHeadAttentionLayer(self.recurrent_hidden_size, self.recurrent_hidden_size, num_heads, device=self.device)) gcn_layers.append(activation) return nn.Sequential(*gcn_layers)
[docs] def create_adjacency_matrix(self, alive_ally): adj_matrix = torch.matmul(alive_ally, alive_ally.transpose(1, 2)) return adj_matrix
[docs] def gcn_block(self, x, adj_matrix): for layer in self.gcn: if isinstance(layer, GraphMultiHeadAttentionLayer): x = layer(x, adj_matrix) else: x = layer(x) return x
[docs] def forward(self, obs: torch.Tensor, msg_send: dict, alive_ally: dict) -> torch.Tensor: alive_ally = {k: torch.as_tensor(alive_ally[k], dtype=torch.float32, device=self.device) for k in self.agent_keys} batch_size, seq_length = obs.shape[0], obs.shape[1] key = self.model_keys[0] msg_send = msg_send[key].view(batch_size // self.n_agents, self.n_agents, seq_length, -1) alive_ally = torch.stack(list(alive_ally.values()), dim=1) alive_ally = alive_ally.squeeze(2) msg_send = msg_send.squeeze(2) adj_matrix = self.create_adjacency_matrix(alive_ally) message = self.gcn_block(msg_send, adj_matrix) msg_receive = self.msg_encoder(message) msg_receive = msg_receive.view(batch_size, seq_length, -1) return msg_receive
[docs] class GraphMultiHeadAttentionLayer(nn.Module): def __init__(self, input_dim, output_dim, num_heads=2, dropout=0.6, concat=True, device='cpu'): super().__init__() self.device = device self.num_heads = num_heads self.output_dim = output_dim self.concat = concat self.head_dim = output_dim // num_heads assert self.head_dim * num_heads == output_dim, "output_dim必须能被num_heads整除" self.W = nn.Linear(input_dim, num_heads * self.head_dim).to(self.device) self.a = nn.Parameter(torch.empty(size=(2 * self.head_dim, 1))).to(self.device) self.leakyrelu = nn.LeakyReLU(0.2) self.dropout = nn.Dropout(dropout) self.reset_parameters()
[docs] def reset_parameters(self): nn.init.xavier_uniform_(self.W.weight) nn.init.xavier_uniform_(self.a)
[docs] def forward(self, x, adj): """ x: node feature matrix [batch_size, n_agents, input_dim] adj: adjacency matrix [batch_size, n_agents, n_agents] """ batch_size, n_agents, _ = x.shape # [batch_size, n_agents, num_heads * head_dim] h = self.W(x) h = h.view(batch_size, n_agents, self.num_heads, self.head_dim) # h_i: [batch_size, n_agents, num_heads, head_dim] # h_j: [batch_size, n_agents, num_heads, head_dim] h_i = h.unsqueeze(2) # [batch_size, n_agents, 1, num_heads, head_dim] h_j = h.unsqueeze(1) # [batch_size, 1, n_agents, num_heads, head_dim] # [batch_size, n_agents, n_agents, num_heads, 2*head_dim] h_cat = torch.cat([h_i.repeat(1, 1, n_agents, 1, 1), h_j.repeat(1, n_agents, 1, 1, 1)], dim=-1) # [batch_size, n_agents, n_agents, num_heads] e = self.leakyrelu(torch.matmul(h_cat, self.a).squeeze(-1)) mask = -9e15 * torch.ones_like(e) adj_expanded = adj.unsqueeze(-1).repeat(1, 1, 1, self.num_heads) e = torch.where(adj_expanded > 0, e, mask) attention = nn.Softmax(dim=2)(e) attention = self.dropout(attention) # attention: [batch_size, n_agents, n_agents, num_heads] # h: [batch_size, n_agents, num_heads, head_dim] h = h.permute(0, 2, 1, 3) # [batch_size, num_heads, n_agents, head_dim] output = torch.zeros(batch_size, n_agents, self.num_heads, self.head_dim).to(self.device) for head in range(self.num_heads): att_head = attention[:, :, :, head] # [batch_size, n_agents, n_agents] h_head = h[:, head, :, :] # [batch_size, n_agents, head_dim] output[:, :, head, :] = torch.bmm(att_head, h_head) if self.concat: output = output.reshape(batch_size, n_agents, self.output_dim) else: output = output.mean(dim=2) return output