Source code for xuance.tensorflow.communications.gnn_comm
from typing import Sequence, Optional, Union
import torch
from torch import nn
[docs]
class DGNComm(nn.Module):
    """Attention-based communication module over the observations of several agents.

    Each agent's encoded observation is used as the attention query; the
    (alive-masked, count-normalised) observations of all agents provide the
    keys and values. The attended result is added back to the agent's own
    observation as a residual.

    Args:
        input_shape: Shape of a raw observation; only ``input_shape[0]`` is
            used, as the input width of the observation encoder.
        hidden_sizes: Must contain ``"fc_hidden_sizes"`` and
            ``"recurrent_hidden_size"``.
        atten_head: Number of attention heads. ``None`` is treated as 1.
        agent_keys: Iterable of agent identifiers used to index the ``obs``
            and ``alive_ally`` dictionaries passed to :meth:`forward`.
        device: Device on which the layers and converted tensors are placed.
        **kwargs: Optional ``config`` entry, stored on ``self.config``
            (``None`` when not supplied).
    """

    def __init__(self,
                 input_shape: Sequence[int],
                 hidden_sizes: dict,
                 atten_head: Optional[int] = 1,
                 agent_keys: dict = None,
                 device: Optional[Union[str, int, torch.device]] = None,
                 **kwargs
                 ):
        super().__init__()
        self.input_shape = input_shape
        self.device = device
        self.fc_hidden_sizes = hidden_sizes["fc_hidden_sizes"]
        self.recurrent_hidden_size: int = hidden_sizes["recurrent_hidden_size"]
        self.agent_keys = agent_keys
        # `config` is not used by this class, so do not fail when it is absent.
        self.config = kwargs.get("config")
        self.obs_encoder = nn.Linear(input_shape[0], self.recurrent_hidden_size, device=self.device)

        # The signature advertises Optional[int]; fall back to a single head.
        self.atten_head = 1 if atten_head is None else atten_head
        # NOTE: recurrent_hidden_size should be divisible by atten_head;
        # otherwise the concatenated heads in `gcn` are narrower than `obs`
        # and the residual addition there cannot line up.
        self.q_dim = self.recurrent_hidden_size // self.atten_head
        self.scale = self.q_dim ** -0.5

        # Creation order (all q, then all k, then all v) is kept stable so
        # that seeded parameter initialisation is reproducible.
        self.q = nn.ModuleList(
            [nn.Linear(self.recurrent_hidden_size, self.q_dim).to(self.device) for _ in range(self.atten_head)])
        self.k = nn.ModuleList(
            [nn.Linear(self.recurrent_hidden_size, self.q_dim).to(self.device) for _ in range(self.atten_head)])
        self.v = nn.ModuleList(
            [nn.Linear(self.recurrent_hidden_size, self.q_dim).to(self.device) for _ in range(self.atten_head)])

    def obs_encode(self, observation):
        """Convert ``observation`` to a float32 tensor on ``self.device`` and encode it.

        Args:
            observation: Array-like whose last dimension equals ``input_shape[0]``.

        Returns:
            The output of the linear observation encoder
            (last dimension ``recurrent_hidden_size``).
        """
        observation = torch.as_tensor(observation, dtype=torch.float32, device=self.device)
        return self.obs_encoder(observation)

    def gcn(self, obs, matrix, alive_ally):
        """Attend from ``obs`` over the entries of ``matrix`` and add the result to ``obs``.

        Args:
            obs: Query tensor whose last dimension is ``recurrent_hidden_size``.
            matrix: List of tensors (one per agent) that can be stacked along
                a new second-to-last dimension.
            alive_ally: Dict of alive-mask tensors. They are stacked and summed
                along dim 2, so each needs at least two dimensions
                (presumably ``(batch, seq, 1)`` -- TODO confirm against caller).

        Returns:
            ``obs`` plus the concatenation of all attention heads.
        """
        # Per-sample count of alive allies, floored at 2 before being used as a divisor.
        alive_agent_num = torch.sum(torch.stack(list(alive_ally.values()), dim=2), dim=2)
        alive_agent_num = torch.clamp(alive_agent_num, min=2.0)
        neighbours = torch.stack([data / alive_agent_num for data in matrix], dim=-2)

        head_outputs = []
        for q_proj, k_proj, v_proj in zip(self.q, self.k, self.v):
            query = q_proj(obs).unsqueeze(dim=-2)
            key = k_proj(neighbours)
            value = v_proj(neighbours)
            # Scale the logits *before* the softmax so the attention weights
            # still sum to one (scaled dot-product attention).
            logits = torch.matmul(query, key.transpose(-1, -2)) * self.scale
            weights = torch.softmax(logits, dim=-1)
            head_outputs.append(torch.matmul(weights, value).squeeze(-2))
        return obs + torch.cat(head_outputs, dim=-1)

    def forward(self, key: str, obs: dict, alive_ally: dict):
        """Compute the communication-augmented observation for agent ``key``.

        Args:
            key: Identifier of the agent whose observation is the query.
            obs: Dict mapping agent identifier to its observation tensor.
            alive_ally: Dict mapping agent identifier to an alive mask; values
                are converted to float32 tensors on ``self.device``.

        Returns:
            The result of :meth:`gcn` for agent ``key``.
        """
        alive_ally = {name: torch.as_tensor(mask, dtype=torch.float32, device=self.device)
                      for name, mask in alive_ally.items()}
        # The agent's own observation comes first and is not masked; every
        # other agent's observation is multiplied by its alive mask.
        matrix = [obs[key]]
        matrix.extend(obs[k] * alive_ally[k] for k in self.agent_keys if k != key)
        return self.gcn(obs[key], matrix, alive_ally)