import torch
import torch.nn as nn
import torch.nn.functional as F
from gymnasium.spaces import Discrete
from xuance.common import Sequence, Optional, Callable, Union, Dict
from xuance.torch import Tensor, Module
from xuance.torch.utils import ModuleType, mlp_block, gru_block, lstm_block
from xuance.torch.utils import CategoricalDistribution, DiagGaussianDistribution, ActivatedDiagGaussianDistribution
[docs]
class BasicQhead(Module):
    """
    Fully connected head that maps a state feature vector to one Q-value per discrete action.

    The network is one ``mlp_block`` per entry of ``hidden_sizes`` followed by an output block of
    width ``n_actions`` that is built without normalization or activation.

    Args:
        state_dim (int): The input state dimension.
        n_actions (int): The number of discrete actions (width of the output block).
        hidden_sizes (Sequence[int]): Hidden units of each fully connected block, in order.
        normalize (Optional[ModuleType]): The normalization layer passed to every hidden block.
        initialize (Optional[Callable[..., Tensor]]): The parameters initializer passed to every block.
        activation (Optional[ModuleType]): The activation function passed to every hidden block.
        device (Optional[Union[str, int, torch.device]]): The calculating device.
    """

    def __init__(self,
                 state_dim: int,
                 n_actions: int,
                 hidden_sizes: Sequence[int],
                 normalize: Optional[ModuleType] = None,
                 initialize: Optional[Callable[..., Tensor]] = None,
                 activation: Optional[ModuleType] = None,
                 device: Optional[Union[str, int, torch.device]] = None):
        super().__init__()
        modules = []
        in_features = state_dim
        for width in hidden_sizes:
            block, out_shape = mlp_block(in_features, width, normalize, activation, initialize, device)
            modules += block
            # The next block consumes whatever size the previous block reports.
            in_features = out_shape[0]
        # Output block: raw Q-values, so no normalization and no activation.
        q_block, _ = mlp_block(in_features, n_actions, None, None, initialize, device)
        modules += q_block
        self.model = nn.Sequential(*modules)

    def forward(self, x: Tensor):
        """
        Returns the output of the Q network.

        Parameters:
            x (Tensor): The input tensor.
        """
        q_values = self.model(x)
        return q_values
[docs]
class DuelQhead(Module):
    """
    Dueling Q head: a state-value stream and an advantage stream combined into Q-values.

    Both streams read the same input and use hidden blocks of width ``h // 2`` for every ``h`` in
    ``hidden_sizes``. The two output blocks (width 1 and width ``n_actions``) are built without
    normalization, activation or a custom initializer.

    Args:
        state_dim (int): The input state dimension.
        n_actions (int): The number of discrete actions.
        hidden_sizes (Sequence[int]): Hidden units per layer; each stream uses half of every entry.
        normalize (Optional[ModuleType]): The normalization layer passed to every hidden block.
        initialize (Optional[Callable[..., Tensor]]): The parameters initializer for the hidden blocks.
        activation (Optional[ModuleType]): The activation function passed to every hidden block.
        device (Optional[Union[str, int, torch.device]]): The calculating device.
    """

    def __init__(self,
                 state_dim: int,
                 n_actions: int,
                 hidden_sizes: Sequence[int],
                 normalize: Optional[ModuleType] = None,
                 initialize: Optional[Callable[..., Tensor]] = None,
                 activation: Optional[ModuleType] = None,
                 device: Optional[Union[str, int, torch.device]] = None):
        super().__init__()

        def build_stream(out_features: int):
            """Create the layer list of one stream ending in ``out_features`` units."""
            stream = []
            in_features = state_dim
            for width in hidden_sizes:
                block, out_shape = mlp_block(in_features, width // 2, normalize, activation, initialize, device)
                stream += block
                in_features = out_shape[0]
            out_block, _ = mlp_block(in_features, out_features, None, None, None, device)
            stream += out_block
            return stream

        # The value stream is created first, then the advantage stream.
        value_layers = build_stream(1)
        advantage_layers = build_stream(n_actions)
        # The advantage model is registered before the value model.
        self.a_model = nn.Sequential(*advantage_layers)
        self.v_model = nn.Sequential(*value_layers)

    def forward(self, x: Tensor):
        """
        Returns the dueling Q-values.

        Parameters:
            x (Tensor): The input tensor.

        Returns:
            q: The dueling Q-values, ``v + (a - mean(a))`` with the mean taken over the last axis.
        """
        value = self.v_model(x)
        advantage = self.a_model(x)
        centered_advantage = advantage - advantage.mean(dim=-1, keepdim=True)
        return value + centered_advantage
[docs]
class C51Qhead(Module):
    """
    Distributional (C51-style) Q head producing a probability vector over atoms for every action.

    Args:
        state_dim (int): The input state dimension.
        n_actions (int): The number of discrete actions.
        atom_num (int): The number of atoms per action.
        hidden_sizes (Sequence[int]): Hidden units of each fully connected block, in order.
        normalize (Optional[ModuleType]): The normalization layer passed to every hidden block.
        initialize (Optional[Callable[..., Tensor]]): The parameters initializer passed to every block.
        activation (Optional[ModuleType]): The activation function passed to every hidden block.
        device (Optional[Union[str, int, torch.device]]): The calculating device.
    """

    def __init__(self,
                 state_dim: int,
                 n_actions: int,
                 atom_num: int,
                 hidden_sizes: Sequence[int],
                 normalize: Optional[ModuleType] = None,
                 initialize: Optional[Callable[..., Tensor]] = None,
                 activation: Optional[ModuleType] = None,
                 device: Optional[Union[str, int, torch.device]] = None):
        super().__init__()
        self.n_actions = n_actions
        self.atom_num = atom_num
        modules = []
        in_features = state_dim
        for width in hidden_sizes:
            block, out_shape = mlp_block(in_features, width, normalize, activation, initialize, device)
            modules += block
            in_features = out_shape[0]
        # One logit per (action, atom) pair; reshaped in forward().
        logits_block, _ = mlp_block(in_features, n_actions * atom_num, None, None, initialize, device)
        modules += logits_block
        self.model = nn.Sequential(*modules)

    def forward(self, x: Tensor):
        """
        Returns the per-action probability distributions over atoms.

        Parameters:
            x (Tensor): The input tensor.

        Returns:
            dist_probs: Tensor viewed as (-1, n_actions, atom_num), softmax-normalized over the atoms.
        """
        flat_logits = self.model(x)
        atom_logits = flat_logits.view(-1, self.n_actions, self.atom_num)
        return F.softmax(atom_logits, dim=-1)
[docs]
class QRDQNhead(Module):
    """
    Quantile-regression Q head producing ``atom_num`` quantile values for every action.

    Args:
        state_dim (int): The input state dimension.
        n_actions (int): The number of discrete actions.
        atom_num (int): The number of quantiles (atoms) per action.
        hidden_sizes (Sequence[int]): Hidden units of each fully connected block, in order.
        normalize (Optional[ModuleType]): The normalization layer passed to every hidden block.
        initialize (Optional[Callable[..., Tensor]]): The parameters initializer for the hidden blocks.
        activation (Optional[ModuleType]): The activation function passed to every hidden block.
        device (Optional[Union[str, int, torch.device]]): The calculating device.
    """

    def __init__(self,
                 state_dim: int,
                 n_actions: int,
                 atom_num: int,
                 hidden_sizes: Sequence[int],
                 normalize: Optional[ModuleType] = None,
                 initialize: Optional[Callable[..., Tensor]] = None,
                 activation: Optional[ModuleType] = None,
                 device: Optional[Union[str, int, torch.device]] = None):
        super().__init__()
        self.n_actions = n_actions
        self.atom_num = atom_num
        modules = []
        in_features = state_dim
        for width in hidden_sizes:
            block, out_shape = mlp_block(in_features, width, normalize, activation, initialize, device)
            modules += block
            in_features = out_shape[0]
        # The output block is built without the custom initializer (None is passed).
        quantile_block, _ = mlp_block(in_features, n_actions * atom_num, None, None, None, device)
        modules += quantile_block
        self.model = nn.Sequential(*modules)

    def forward(self, x: Tensor):
        """
        Returns the quantiles of the distribution.

        Parameters:
            x (Tensor): The input tensor.

        Returns:
            quantiles: Network output viewed as (-1, n_actions, atom_num).
        """
        flat_quantiles = self.model(x)
        return flat_quantiles.view(-1, self.n_actions, self.atom_num)
[docs]
class BasicRecurrent(Module):
    """
    Build recurrent neural network to calculate Q values.

    A recurrent layer (GRU or LSTM) is followed by one linear block that maps the recurrent
    features to ``action_dim`` outputs.

    Keyword Args:
        rnn (str): Type of recurrent layer, either "GRU" or "LSTM".
        input_dim (int): Input size of the recurrent layer.
        recurrent_hidden_size (int): Hidden size of the recurrent layer.
        recurrent_layer_N (int): Number of recurrent layers.
        dropout: Dropout setting forwarded to the recurrent block builder.
        initialize: Parameter initializer forwarded to the recurrent block builder.
        action_dim (int): Output size of the final linear block.
        device: The calculating device.

    Raises:
        KeyError: If a required keyword argument is missing.
        ValueError: If ``rnn`` is neither "GRU" nor "LSTM".
    """

    def __init__(self, **kwargs):
        super(BasicRecurrent, self).__init__()
        rnn_type = kwargs["rnn"]
        if rnn_type == "GRU":
            build_rnn = gru_block
        elif rnn_type == "LSTM":
            build_rnn = lstm_block
        else:
            # Raising a plain string is itself a TypeError in Python 3; raise a real exception
            # that names the offending value instead.
            raise ValueError(f"Unknown recurrent module! Got rnn={rnn_type!r}, expected 'GRU' or 'LSTM'.")
        # forward() uses this flag to decide whether a cell state is threaded through.
        self.lstm = rnn_type == "LSTM"
        rnn_layer, _ = build_rnn(kwargs["input_dim"],
                                 kwargs["recurrent_hidden_size"],
                                 kwargs["recurrent_layer_N"],
                                 kwargs["dropout"],
                                 kwargs["initialize"],
                                 kwargs["device"])
        self.rnn_layer = rnn_layer
        # Output block without normalization, activation or custom initializer.
        fc_layer, _ = mlp_block(kwargs["recurrent_hidden_size"], kwargs["action_dim"],
                                None, None, None, kwargs["device"])
        self.model = nn.Sequential(*fc_layer)

    def forward(self, x: Tensor, h: Tensor, c: Tensor = None):
        """
        Returns the rnn hidden and Q-values via RNN networks.

        Parameters:
            x (Tensor): The input sequence for the recurrent layer.
            h (Tensor): The recurrent hidden state.
            c (Tensor): The LSTM cell state; only used when the module was built with rnn="LSTM".

        Returns:
            ``(hn, cn, q)`` for LSTM, ``(hn, q)`` for GRU, where ``q`` is the linear block applied
            to the recurrent outputs.
        """
        self.rnn_layer.flatten_parameters()
        if self.lstm:
            output, (hn, cn) = self.rnn_layer(x, (h, c))
            return hn, cn, self.model(output)
        output, hn = self.rnn_layer(x, h)
        return hn, self.model(output)
[docs]
class ActorNet(Module):
    """
    The actor network for deterministic policy, which outputs activated continuous actions directly.

    Args:
        state_dim (int): The input state dimension.
        action_dim (int): The dimension of continuous action space.
        hidden_sizes (Sequence[int]): List of hidden units for fully connect layers.
        normalize (Optional[ModuleType]): The normalization layer passed to every hidden block.
        initialize (Optional[Callable[..., Tensor]]): The parameters initializer.
        activation (Optional[ModuleType]): The activation function for each hidden layer.
        activation_action (Optional[ModuleType]): The activation of final layer to bound the actions.
        device (Optional[Union[str, int, torch.device]]): The calculating device.
    """

    def __init__(self,
                 state_dim: int,
                 action_dim: int,
                 hidden_sizes: Sequence[int],
                 normalize: Optional[ModuleType] = None,
                 initialize: Optional[Callable[..., Tensor]] = None,
                 activation: Optional[ModuleType] = None,
                 activation_action: Optional[ModuleType] = None,
                 device: Optional[Union[str, int, torch.device]] = None):
        super(ActorNet, self).__init__()
        layers = []
        input_shape = (state_dim,)
        for h in hidden_sizes:
            mlp, input_shape = mlp_block(input_shape[0], h, normalize, activation, initialize, device)
            layers.extend(mlp)
        # Final block: no normalization, bounded by ``activation_action`` when one is given.
        layers.extend(mlp_block(input_shape[0], action_dim, None, activation_action, initialize, device)[0])
        self.model = nn.Sequential(*layers)

    def forward(self, x: Tensor, avail_actions: Optional[Tensor] = None):
        """
        Returns the output of the actor.

        Parameters:
            x (Tensor): The input tensor.
            avail_actions (Optional[Tensor]): The actions mask values when use actions mask, default is None.
                Entries equal to 0 mark outputs that are overwritten with -1e10.

        Returns:
            The network output; when a mask is given, a masked copy of it.
        """
        logits = self.model(x)
        if avail_actions is None:
            return logits
        # Mask a copy instead of writing into the network output. The last layer may be an
        # activation (e.g. Tanh/Sigmoid) whose backward pass reuses its saved output; modifying
        # that tensor in place makes autograd raise during backward().
        masked_logits = logits.clone()
        masked_logits[avail_actions == 0] = -1e10
        return masked_logits
[docs]
class CategoricalActorNet(Module):
    """
    The actor network for categorical policy, which outputs a distribution over all discrete actions.

    Args:
        state_dim (int): The input state dimension.
        action_dim (int): The number of discrete actions (width of the logits layer).
        hidden_sizes (Sequence[int]): List of hidden units for fully connect layers.
        normalize (Optional[ModuleType]): The normalization layer passed to every hidden block.
        initialize (Optional[Callable[..., Tensor]]): The parameters initializer.
        activation (Optional[ModuleType]): The activation function for each hidden layer.
        device (Optional[Union[str, int, torch.device]]): The calculating device.
    """

    def __init__(self,
                 state_dim: int,
                 action_dim: int,
                 hidden_sizes: Sequence[int],
                 normalize: Optional[ModuleType] = None,
                 initialize: Optional[Callable[..., Tensor]] = None,
                 activation: Optional[ModuleType] = None,
                 device: Optional[Union[str, int, torch.device]] = None):
        super().__init__()
        modules = []
        in_features = state_dim
        for width in hidden_sizes:
            block, out_shape = mlp_block(in_features, width, normalize, activation, initialize, device)
            modules += block
            in_features = out_shape[0]
        # Logits block: no normalization and no activation.
        logits_block, _ = mlp_block(in_features, action_dim, None, None, initialize, device)
        modules += logits_block
        self.model = nn.Sequential(*modules)
        # A single distribution object is kept and re-parameterized on every forward pass.
        self.dist = CategoricalDistribution(action_dim)

    def forward(self, x: Tensor, avail_actions: Optional[Tensor] = None):
        """
        Returns the stochastic distribution over all discrete actions.

        Parameters:
            x (Tensor): The input tensor.
            avail_actions (Optional[Tensor]): The actions mask values when use actions mask, default is None.

        Returns:
            self.dist: CategoricalDistribution(action_dim), updated with the (masked) logits.
        """
        action_logits = self.model(x)
        if avail_actions is not None:
            unavailable = avail_actions == 0
            # Unavailable actions get a very negative logit.
            action_logits[unavailable] = -1e10
        self.dist.set_param(logits=action_logits)
        return self.dist
[docs]
class CategoricalActorNet_SAC(CategoricalActorNet):
    """
    The actor network for categorical policy in SAC-DIS, which outputs a distribution over all discrete actions.

    Unlike the parent class, the distribution is parameterized with softmax probabilities
    instead of raw logits.

    Args:
        state_dim (int): The input state dimension.
        action_dim (int): The number of discrete actions.
        hidden_sizes (Sequence[int]): List of hidden units for fully connect layers.
        normalize (Optional[ModuleType]): The normalization layer passed to every hidden block.
        initialize (Optional[Callable[..., Tensor]]): The parameters initializer.
        activation (Optional[ModuleType]): The activation function for each hidden layer.
        device (Optional[Union[str, int, torch.device]]): The calculating device.
    """

    def __init__(self,
                 state_dim: int,
                 action_dim: int,
                 hidden_sizes: Sequence[int],
                 normalize: Optional[ModuleType] = None,
                 initialize: Optional[Callable[..., Tensor]] = None,
                 activation: Optional[ModuleType] = None,
                 device: Optional[Union[str, int, torch.device]] = None):
        super().__init__(state_dim=state_dim,
                         action_dim=action_dim,
                         hidden_sizes=hidden_sizes,
                         normalize=normalize,
                         initialize=initialize,
                         activation=activation,
                         device=device)
        # Converts logits to probabilities over the last axis.
        self.output = nn.Softmax(dim=-1)

    def forward(self, x: Tensor, avail_actions: Optional[Tensor] = None):
        """
        Returns the stochastic distribution over all discrete actions.

        Parameters:
            x (Tensor): The input tensor.
            avail_actions (Optional[Tensor]): The actions mask values when use actions mask, default is None.

        Returns:
            self.dist: CategoricalDistribution(action_dim), updated with the softmax probabilities.
        """
        action_logits = self.model(x)
        if avail_actions is not None:
            unavailable = avail_actions == 0
            # Unavailable actions get a very negative logit before the softmax.
            action_logits[unavailable] = -1e10
        action_probs = self.output(action_logits)
        self.dist.set_param(probs=action_probs)
        return self.dist
[docs]
class GaussianActorNet(Module):
    """
    The actor network for Gaussian policy, which outputs a distribution over the continuous action space.

    The mean comes from an MLP; the log standard deviation is a state-independent learnable
    parameter initialized to -1 for every action dimension.

    Args:
        state_dim (int): The input state dimension.
        action_dim (int): The dimension of continuous action space.
        hidden_sizes (Sequence[int]): List of hidden units for fully connect layers.
        normalize (Optional[ModuleType]): The normalization layer passed to every hidden block.
        initialize (Optional[Callable[..., Tensor]]): The parameters initializer.
        activation (Optional[ModuleType]): The activation function for each hidden layer.
        activation_action (Optional[ModuleType]): The activation of final layer to bound the actions.
        device (Optional[Union[str, int, torch.device]]): The calculating device.
    """

    def __init__(self,
                 state_dim: int,
                 action_dim: int,
                 hidden_sizes: Sequence[int],
                 normalize: Optional[ModuleType] = None,
                 initialize: Optional[Callable[..., Tensor]] = None,
                 activation: Optional[ModuleType] = None,
                 activation_action: Optional[ModuleType] = None,
                 device: Optional[Union[str, int, torch.device]] = None):
        super().__init__()
        modules = []
        in_features = state_dim
        for width in hidden_sizes:
            block, out_shape = mlp_block(in_features, width, normalize, activation, initialize, device)
            modules += block
            in_features = out_shape[0]
        mean_block, _ = mlp_block(in_features, action_dim, None, activation_action, initialize, device)
        modules += mean_block
        self.mu = nn.Sequential(*modules)
        initial_logstd = torch.full((action_dim,), -1.0, device=device)
        self.logstd = nn.Parameter(initial_logstd)
        # A single distribution object is kept and re-parameterized on every forward pass.
        self.dist = DiagGaussianDistribution(action_dim)

    def forward(self, x: Tensor):
        """
        Returns the stochastic distribution over the continuous action space.

        Parameters:
            x (Tensor): The input tensor.

        Returns:
            self.dist: The distribution updated with the predicted mean and ``exp(logstd)``.
        """
        mean = self.mu(x)
        std = torch.exp(self.logstd)
        self.dist.set_param(mean, std)
        return self.dist
[docs]
class CriticNet(Module):
    """
    The critic network that outputs the evaluated values for states (State-Value) or state-action pairs (Q-value).

    Args:
        input_dim (int): The input dimension (dim_state or dim_state + dim_action).
        hidden_sizes (Sequence[int]): List of hidden units for fully connect layers.
        normalize (Optional[ModuleType]): The normalization layer passed to every hidden block.
        initialize (Optional[Callable[..., Tensor]]): The parameters initializer.
        activation (Optional[ModuleType]): The activation function for each hidden layer.
        device (Optional[Union[str, int, torch.device]]): The calculating device.
    """

    def __init__(self,
                 input_dim: int,
                 hidden_sizes: Sequence[int],
                 normalize: Optional[ModuleType] = None,
                 initialize: Optional[Callable[..., Tensor]] = None,
                 activation: Optional[ModuleType] = None,
                 device: Optional[Union[str, int, torch.device]] = None):
        super().__init__()
        modules = []
        in_features = input_dim
        for width in hidden_sizes:
            block, out_shape = mlp_block(in_features, width, normalize, activation, initialize, device)
            modules += block
            in_features = out_shape[0]
        # Single-unit output block without normalization or activation.
        value_block, _ = mlp_block(in_features, 1, None, None, initialize, device)
        modules += value_block
        self.model = nn.Sequential(*modules)

    def forward(self, x: Tensor):
        """
        Returns the evaluated value for the given input.

        Parameters:
            x (Tensor): The input tensor.
        """
        value = self.model(x)
        return value
[docs]
class GaussianActorNet_SAC(Module):
    """
    The actor network for Gaussian policy in SAC, which outputs a distribution over the continuous action space.

    A shared MLP trunk feeds two linear layers that predict the mean and the log standard
    deviation. The log standard deviation is clamped to [-20, 2] before exponentiation.

    Args:
        state_dim (int): The input state dimension.
        action_dim (int): The dimension of continuous action space.
        hidden_sizes (Sequence[int]): List of hidden units for fully connect layers. May be empty,
            in which case the two output layers read the input directly.
        normalize (Optional[ModuleType]): The normalization layer passed to every hidden block.
        initialize (Optional[Callable[..., Tensor]]): The parameters initializer for the hidden blocks.
        activation (Optional[ModuleType]): The activation function for each hidden layer.
        activation_action (Optional[ModuleType]): The activation handed to the distribution to bound the actions.
        device (Optional[Union[str, int, torch.device]]): The calculating device.
    """

    def __init__(self,
                 state_dim: int,
                 action_dim: int,
                 hidden_sizes: Sequence[int],
                 normalize: Optional[ModuleType] = None,
                 initialize: Optional[Callable[..., Tensor]] = None,
                 activation: Optional[ModuleType] = None,
                 activation_action: Optional[ModuleType] = None,
                 device: Optional[Union[str, int, torch.device]] = None):
        super(GaussianActorNet_SAC, self).__init__()
        layers = []
        input_shape = (state_dim,)
        for h in hidden_sizes:
            mlp, input_shape = mlp_block(input_shape[0], h, normalize, activation, initialize, device)
            layers.extend(mlp)
        self.output = nn.Sequential(*layers)
        # Size the output layers from the tracked trunk output size rather than hidden_sizes[-1]:
        # this matches the other networks in this file and also works for an empty hidden_sizes.
        feature_dim = input_shape[0]
        self.out_mu = nn.Linear(feature_dim, action_dim, device=device)
        self.out_log_std = nn.Linear(feature_dim, action_dim, device=device)
        self.dist = ActivatedDiagGaussianDistribution(action_dim, activation_action, device)

    def forward(self, x: Tensor):
        """
        Returns the stochastic distribution over the continuous action space.

        Parameters:
            x (Tensor): The input tensor.

        Returns:
            self.dist: The distribution updated with the predicted mean and standard deviation.
        """
        output = self.output(x)
        mu = self.out_mu(output)
        # Clamp the log-std so that exp() stays within a numerically safe range.
        log_std = torch.clamp(self.out_log_std(output), -20, 2)
        std = log_std.exp()
        self.dist.set_param(mu, std)
        return self.dist
[docs]
class VDN_mixer(Module):
    """
    The value decomposition networks mixer. (Additivity)

    The mixer has no parameters: the team value is the sum of the individual values.
    """

    def __init__(self):
        super().__init__()

    def forward(self, values_n, states=None):
        """
        Returns the individual values summed over axis 1.

        Parameters:
            values_n: The individual values; axis 1 is the one that is summed out.
            states: Unused; accepted so that all mixers share the same call signature.
        """
        total = torch.sum(values_n, dim=1)
        return total
[docs]
class QMIX_mixer(Module):
    """
    The QMIX mixer. (Monotonicity)

    Hyper-networks conditioned on the global state generate the weights and biases of a two-layer
    mixing network. The generated weights pass through an absolute value, so they are non-negative.

    Args:
        dim_state (int): The dimension of global state.
        dim_hidden (int): The size of the mixing network's hidden layer.
        dim_hypernet_hidden (int): The size of each hidden layer for hyper network.
        n_agents (int): The number of agents.
        device (Optional[Union[str, int, torch.device]]): The calculating device.
    """

    def __init__(self,
                 dim_state: Optional[int] = None,
                 dim_hidden: int = 32,
                 dim_hypernet_hidden: int = 32,
                 n_agents: int = 1,
                 device: Optional[Union[str, int, torch.device]] = None):
        super().__init__()
        self.device = device
        self.dim_state = dim_state
        self.dim_hidden = dim_hidden
        self.dim_hypernet_hidden = dim_hypernet_hidden
        self.n_agents = n_agents

        def two_layer_hypernet(out_features: int):
            """State -> hidden -> ReLU -> ``out_features`` hyper-network on ``device``."""
            return nn.Sequential(nn.Linear(dim_state, dim_hypernet_hidden),
                                 nn.ReLU(),
                                 nn.Linear(dim_hypernet_hidden, out_features)).to(device)

        # Weights of the first and second mixing layers.
        self.hyper_w_1 = two_layer_hypernet(dim_hidden * n_agents)
        self.hyper_w_2 = two_layer_hypernet(dim_hidden)
        # Bias of the first mixing layer is a single linear map of the state.
        self.hyper_b_1 = nn.Linear(dim_state, dim_hidden).to(device)
        # Bias of the second mixing layer.
        self.hyper_b_2 = two_layer_hypernet(1)

    def forward(self, values_n, states):
        """
        Returns the total Q-values for multi-agent team.

        Parameters:
            values_n: The individual values for agents in team.
            states: The global states (converted to a float32 tensor on ``self.device``).

        Returns:
            q_tot: The total Q-values for the multi-agent team, shaped (-1, 1).
        """
        state_batch = torch.as_tensor(states, dtype=torch.float32, device=self.device)
        state_batch = state_batch.reshape(-1, self.dim_state)
        individual_qs = values_n.reshape(-1, 1, self.n_agents)

        # First mixing layer: non-negative weights, ELU non-linearity.
        weight_1 = self.hyper_w_1(state_batch).abs().view(-1, self.n_agents, self.dim_hidden)
        bias_1 = self.hyper_b_1(state_batch).view(-1, 1, self.dim_hidden)
        mixed_hidden = F.elu(torch.bmm(individual_qs, weight_1) + bias_1)

        # Second mixing layer: non-negative weights, no non-linearity.
        weight_2 = self.hyper_w_2(state_batch).abs().view(-1, self.dim_hidden, 1)
        bias_2 = self.hyper_b_2(state_batch).view(-1, 1, 1)
        mixed = torch.bmm(mixed_hidden, weight_2) + bias_2

        return mixed.view(-1, 1)
[docs]
class QMIX_FF_mixer(Module):
    """
    The feedforward mixer without the constraints of monotonicity.

    The individual values are concatenated with the global state and passed through a plain MLP;
    a second, state-only MLP adds a bias term.

    Args:
        dim_state (int): The dimension of global state.
        dim_hidden (int): The size of each hidden layer.
        n_agents (int): The number of agents.
        device (Optional[Union[str, int, torch.device]]): The calculating device.
    """

    def __init__(self,
                 dim_state: int = 0,
                 dim_hidden: int = 32,
                 n_agents: int = 1,
                 device: Optional[Union[str, int, torch.device]] = None):
        super(QMIX_FF_mixer, self).__init__()
        self.device = device
        self.dim_state = dim_state
        self.dim_hidden = dim_hidden
        self.n_agents = n_agents
        self.dim_input = self.n_agents + self.dim_state
        self.ff_net = nn.Sequential(nn.Linear(self.dim_input, self.dim_hidden),
                                    nn.ReLU(),
                                    nn.Linear(self.dim_hidden, self.dim_hidden),
                                    nn.ReLU(),
                                    nn.Linear(self.dim_hidden, self.dim_hidden),
                                    nn.ReLU(),
                                    nn.Linear(self.dim_hidden, 1)).to(self.device)
        self.ff_net_bias = nn.Sequential(nn.Linear(self.dim_state, self.dim_hidden),
                                         nn.ReLU(),
                                         nn.Linear(self.dim_hidden, 1)).to(self.device)

    def forward(self, values_n, states=None):
        """
        Returns the feedforward total Q-values.

        Parameters:
            values_n: The individual Q-values.
            states: The global states. Required despite the ``None`` default, which only keeps the
                signature aligned with the other mixers. Non-tensor inputs are converted to a
                float32 tensor on ``self.device``, as in ``QMIX_mixer``.

        Returns:
            q_tot: The total Q-values, shaped (-1, 1).

        Raises:
            ValueError: If ``states`` is None.
        """
        if states is None:
            raise ValueError("QMIX_FF_mixer requires the global `states` to compute the total Q-values.")
        if not torch.is_tensor(states):
            # Accept array-like states (e.g. numpy) the same way QMIX_mixer does.
            states = torch.as_tensor(states, dtype=torch.float32, device=self.device)
        states = states.reshape(-1, self.dim_state)
        # reshape (not view) so that non-contiguous value tensors are handled too.
        agent_qs = values_n.reshape(-1, self.n_agents)
        inputs = torch.cat([agent_qs, states], dim=-1).to(self.device)
        out_put = self.ff_net(inputs)
        bias = self.ff_net_bias(states)
        y = out_put + bias
        q_tot = y.reshape(-1, 1)
        return q_tot
[docs]
class QTRAN_base(Module):
    """
    The basic QTRAN module.

    It holds a joint action-value network ``Q_jt``, a joint state-value network ``V_jt`` and an
    encoder applied to the concatenated per-agent hidden states and one-hot actions.

    Args:
        dim_state (int): The dimension of the global state.
        action_space (Dict[str, Discrete]): The action space for all agents.
        dim_hidden (int): The dimension of the hidden layers.
        n_agents (int): The number of agents.
        dim_utility_hidden (int): The dimension of the utility hidden states.
        use_parameter_sharing (bool): Whether to use parameters sharing trick.
        device: Optional[Union[str, int, torch.device]]: The calculating device.
    """

    def __init__(self,
                 dim_state: int = 0,
                 action_space: Dict[str, Discrete] = None,
                 dim_hidden: int = 32,
                 n_agents: int = 1,
                 dim_utility_hidden: int = 1,
                 use_parameter_sharing: bool = False,
                 device: Optional[Union[str, int, torch.device]] = None):
        super().__init__()
        self.dim_state = dim_state
        self.action_space = action_space
        self.n_actions_list = [space.n for space in action_space.values()]
        # The widest action space fixes the one-hot action size.
        self.n_actions_max = max(self.n_actions_list)
        self.dim_hidden = dim_hidden
        self.n_agents = n_agents
        self.use_parameter_sharing = use_parameter_sharing

        self.dim_q_input = self.dim_state + dim_utility_hidden + self.n_actions_max
        self.dim_v_input = self.dim_state

        def scalar_mlp(in_features: int):
            """Two hidden ReLU layers of width ``dim_hidden`` followed by a scalar output."""
            return nn.Sequential(nn.Linear(in_features, dim_hidden),
                                 nn.ReLU(),
                                 nn.Linear(dim_hidden, dim_hidden),
                                 nn.ReLU(),
                                 nn.Linear(dim_hidden, 1)).to(device)

        self.Q_jt = scalar_mlp(self.dim_q_input)
        self.V_jt = scalar_mlp(self.dim_v_input)

        self.dim_ae_input = dim_utility_hidden + self.n_actions_max
        encoder_width = self.dim_ae_input
        self.action_encoding = nn.Sequential(nn.Linear(encoder_width, encoder_width),
                                             nn.ReLU(),
                                             nn.Linear(encoder_width, encoder_width)).to(device)

    def forward(self, states: Tensor, hidden_state_inputs: Tensor, actions_onehot: Tensor):
        """
        Calculating the joint Q and V values.

        Parameters:
            states (Tensor): The global states.
            hidden_state_inputs (Tensor): The joint hidden states inputs for QTRAN network.
            actions_onehot (Tensor): The joint onehot actions for QTRAN network.

        Returns:
            q_jt (Tensor): The evaluated joint Q values.
            v_jt (Tensor): The evaluated joint V values.
        """
        encoder_input = torch.cat([hidden_state_inputs, actions_onehot], dim=-1)
        per_agent_code = self.action_encoding(encoder_input)
        per_agent_code = per_agent_code.reshape(-1, self.n_agents, self.dim_ae_input)
        # Collapse the agent axis: one summed encoding per sample.
        joint_code = per_agent_code.sum(dim=1)
        q_jt = self.Q_jt(torch.cat([states, joint_code], dim=-1))
        v_jt = self.V_jt(states)
        return q_jt, v_jt
[docs]
class QTRAN_alt(Module):
    """
    The alternative QTRAN module (QTRAN-alt).

    For every agent the joint action-value network ``Q_jt`` receives the global state, the summed
    encodings of all *other* agents and a one-hot agent id, and outputs ``n_actions_max`` values.

    Parameters:
        dim_state (int): The dimension of the global state.
        action_space (Dict[str, Discrete]): The action space for all agents.
        dim_hidden (int): The dimension of the hidden layers.
        n_agents (int): The number of agents.
        dim_utility_hidden (int): The dimension of the utility hidden states.
        use_parameter_sharing (bool): Whether to use parameters sharing trick.
        device: Optional[Union[str, int, torch.device]]: The calculating device.
    """

    def __init__(self,
                 dim_state: int = 0,
                 action_space: Dict[str, Discrete] = None,
                 dim_hidden: int = 32,
                 n_agents: int = 1,
                 dim_utility_hidden: int = 1,
                 use_parameter_sharing: bool = False,
                 device: Optional[Union[str, int, torch.device]] = None):
        super(QTRAN_alt, self).__init__()
        self.dim_state = dim_state
        self.action_space = action_space
        self.n_actions_list = [a_space.n for a_space in action_space.values()]
        self.n_actions_max = max(self.n_actions_list)
        self.dim_hidden = dim_hidden
        self.n_agents = n_agents
        self.use_parameter_sharing = use_parameter_sharing
        self.device = device

        # Q input per agent: state + encoding of the other agents + one-hot agent id.
        self.dim_q_input = self.dim_state + dim_utility_hidden + self.n_actions_max + self.n_agents
        self.dim_v_input = self.dim_state
        self.Q_jt = nn.Sequential(nn.Linear(self.dim_q_input, self.dim_hidden),
                                  nn.ReLU(),
                                  nn.Linear(self.dim_hidden, self.dim_hidden),
                                  nn.ReLU(),
                                  nn.Linear(self.dim_hidden, self.n_actions_max)).to(device)
        self.V_jt = nn.Sequential(nn.Linear(self.dim_v_input, self.dim_hidden),
                                  nn.ReLU(),
                                  nn.Linear(self.dim_hidden, self.dim_hidden),
                                  nn.ReLU(),
                                  nn.Linear(self.dim_hidden, 1)).to(device)
        self.dim_ae_input = dim_utility_hidden + self.n_actions_max
        self.action_encoding = nn.Sequential(nn.Linear(self.dim_ae_input, self.dim_ae_input),
                                             nn.ReLU(),
                                             nn.Linear(self.dim_ae_input, self.dim_ae_input)).to(device)

    def forward(self, states: Tensor, hidden_state_inputs: Tensor, actions_onehot: Tensor):
        """
        Calculating the joint Q and V values.

        Parameters:
            states (Tensor): The global states.
            hidden_state_inputs (Tensor): The joint hidden states inputs for QTRAN network.
            actions_onehot (Tensor): The joint onehot actions for QTRAN network.

        Returns:
            q_jt (Tensor): The evaluated joint Q values, one row of ``n_actions_max`` values per agent.
            v_jt (Tensor): The evaluated joint V values.
        """
        h_state_action_input = torch.cat([hidden_state_inputs, actions_onehot], dim=-1)
        h_state_action_encode = self.action_encoding(h_state_action_input).reshape(-1, self.n_agents, self.dim_ae_input)
        bs = h_state_action_encode.shape[0]

        agent_ids = torch.eye(self.n_agents, dtype=torch.float32, device=self.device)
        # agent_masks[i, j] is 0 when i == j and 1 otherwise.
        agent_masks = 1 - agent_ids

        # pairwise[b, i, j] = encoding of agent i, kept only when i != j.
        pairwise = h_state_action_encode.unsqueeze(2) * agent_masks.unsqueeze(0).unsqueeze(-1)
        # Reduce over the source-agent axis (dim=1) so that entry j holds the sum of the encodings
        # of all agents other than j. Reducing over dim=2 instead would only scale each agent's
        # own encoding by (n_agents - 1), because the encoding is constant along that axis.
        others_encode = pairwise.sum(dim=1)

        repeat_agent_ids = agent_ids.unsqueeze(0).repeat(bs, 1, 1)
        repeated_states = states.unsqueeze(1).repeat(1, self.n_agents, 1)
        input_q = torch.cat([repeated_states, others_encode, repeat_agent_ids], dim=-1)
        input_v = states
        q_jt = self.Q_jt(input_q)
        v_jt = self.V_jt(input_v)
        return q_jt, v_jt