"""
MFQ: Mean Field Q-Learning
Paper link:
http://proceedings.mlr.press/v80/yang18d/yang18d.pdf
Implementation: PyTorch
"""
import torch
import numpy as np
from operator import itemgetter
from xuance.torch import nn, Tensor, Module
from xuance.torch.learners import LearnerMAS
from xuance.common import List, Optional
from argparse import Namespace
class MFQ_Learner(LearnerMAS):
    """Learner for Mean Field Q-Learning (MFQ).

    For every model key the learner owns an Adam optimizer and a linear
    learning-rate scheduler. Each update minimises a masked squared TD error
    whose bootstrap value is taken from ``policy.Qtarget``; the target is
    synchronised through ``policy.copy_target()`` every ``sync_frequency``
    iterations.

    NOTE(review): ``iterations``, ``n_agents``, ``use_parameter_sharing``,
    ``use_actions_mask``, ``use_grad_clip``, ``grad_clip_norm``,
    ``end_factor_lr_decay``, ``total_iters``, ``build_training_data`` and
    ``build_actions_mean_input`` are not defined in this class; presumably
    they are provided by ``LearnerMAS`` -- confirm against the base class.
    """

    def __init__(self,
                 config: Namespace,
                 model_keys: List[str],
                 agent_keys: List[str],
                 policy: Module,
                 callback):
        """Build the per-model optimizers and schedulers.

        Args:
            config: Run configuration. ``learning_rate``, ``gamma`` and
                ``sync_frequency`` are read directly here.
            model_keys: Keys of the models that get their own optimizer.
            agent_keys: Agent keys, forwarded to the base class.
            policy: Policy module; ``parameters_model``, ``action_space`` and
                ``policy_type`` are read from it.
            callback: Object whose ``on_update_*`` hooks are called during
                updates, forwarded to the base class.
        """
        super().__init__(config, model_keys, agent_keys, policy, callback)
        self.optimizer = {key: torch.optim.Adam(self.policy.parameters_model[key], config.learning_rate, eps=1e-5)
                          for key in self.model_keys}
        self.scheduler = {key: torch.optim.lr_scheduler.LinearLR(self.optimizer[key],
                                                                 start_factor=1.0,
                                                                 end_factor=self.end_factor_lr_decay,
                                                                 total_iters=self.total_iters)
                          for key in self.model_keys}
        self.gamma = config.gamma
        self.sync_frequency = config.sync_frequency
        self.n_actions = {k: self.policy.action_space[k].n for k in self.model_keys}
        self.policy_type = self.policy.policy_type

    def _mfq_gradient_step(self, key, q_eval_a, q_target, mask_values):
        """Take one optimizer step on the masked mean squared TD error of model ``key``.

        Args:
            key: Model key selecting the optimizer, scheduler and parameters.
            q_eval_a: Q-values of the taken actions (carries the gradient).
            q_target: TD targets; detached before use.
            mask_values: Multiplicative mask applied to the TD error; its sum
                normalises the loss.

        Returns:
            Tuple ``(td_error, loss, lr)`` where ``lr`` is the learning rate of
            the optimizer's first parameter group after the scheduler step.
        """
        td_error = (q_eval_a - q_target.detach()) * mask_values
        loss = (td_error ** 2).sum() / mask_values.sum()
        self.optimizer[key].zero_grad()
        loss.backward()
        if self.use_grad_clip:
            torch.nn.utils.clip_grad_norm_(self.policy.parameters_model[key], self.grad_clip_norm)
        self.optimizer[key].step()
        # Check the per-key entry: ``self.scheduler`` itself is a dict and is never None.
        if self.scheduler[key] is not None:
            self.scheduler[key].step()
        lr = self.optimizer[key].state_dict()['param_groups'][0]['lr']
        return td_error, loss, lr

    def update(self, sample):
        """Run one training iteration on a batch of single-step transitions.

        Args:
            sample: Raw batch handed to ``build_actions_mean_input`` and
                ``build_training_data``.

        Returns:
            The ``info`` object returned by ``callback.on_update_start``,
            extended with ``{key}/learning_rate``, ``{key}/loss_Q`` and
            ``{key}/predictQ`` for every model key and with whatever the
            callbacks return.

        Raises:
            NotImplementedError: If ``self.policy_type`` is neither
                ``"Boltzmann"`` nor ``"greedy"``.
        """
        self.iterations += 1

        # prepare training data
        act_mean, act_mean_next = self.build_actions_mean_input(sample=sample,
                                                                use_parameter_sharing=self.use_parameter_sharing)
        sample_Tensor = self.build_training_data(sample=sample,
                                                 use_parameter_sharing=self.use_parameter_sharing,
                                                 use_actions_mask=self.use_actions_mask)
        batch_size = sample_Tensor['batch_size']
        obs = sample_Tensor['obs']
        actions = sample_Tensor['actions']
        obs_next = sample_Tensor['obs_next']
        rewards = sample_Tensor['rewards']
        terminals = sample_Tensor['terminals']
        agent_mask = sample_Tensor['agent_mask']
        avail_actions = sample_Tensor['avail_actions']
        avail_actions_next = sample_Tensor['avail_actions_next']
        IDs = sample_Tensor['agent_ids']

        if self.use_parameter_sharing:
            # With shared parameters the agent dimension is folded into the batch.
            key = self.model_keys[0]
            bs = batch_size * self.n_agents
            rewards[key] = rewards[key].reshape(batch_size * self.n_agents)
            terminals[key] = terminals[key].reshape(batch_size * self.n_agents)
        else:
            bs = batch_size

        info = self.callback.on_update_start(self.iterations, method="update", policy=self.policy)

        _, _, q_eval = self.policy(observation=obs, agent_ids=IDs, actions_mean=act_mean, avail_actions=avail_actions)
        _, q_next = self.policy.Qtarget(observation=obs_next, actions_mean=act_mean_next, agent_ids=IDs)

        for key in self.model_keys:
            mask_values = agent_mask[key]
            q_eval_a = q_eval[key].gather(-1, actions[key].long().unsqueeze(-1)).reshape(bs)

            if self.use_actions_mask:
                # Push unavailable next-step actions far below any real Q-value.
                q_next[key][avail_actions_next[key] == 0] = -1e10

            if self.policy_type == "Boltzmann":
                pi_probs = self.policy.get_boltzmann_policy(q_next[key])
                v_mf = (pi_probs * q_next[key]).sum(-1).reshape(-1)
                q_target = rewards[key] + (1 - terminals[key]) * self.gamma * v_mf
            elif self.policy_type == "greedy":
                # The greedy action belongs to the NEXT observation, so it must be
                # chosen under the next-step availability mask; using the current
                # mask could select an action whose target value was set to -1e10.
                # Only the action indices are used, so no autograd graph is needed.
                with torch.no_grad():
                    _, actions_next_greedy, _ = self.policy(obs_next, IDs, actions_mean=act_mean_next,
                                                            agent_key=key,
                                                            avail_actions=avail_actions_next)
                q_next_a = q_next[key].gather(-1, actions_next_greedy[key].unsqueeze(-1).long()).reshape(bs)
                q_target = rewards[key] + (1 - terminals[key]) * self.gamma * q_next_a
            else:
                raise NotImplementedError

            # calculate the loss function and apply the gradient step
            td_error, loss, lr = self._mfq_gradient_step(key, q_eval_a, q_target, mask_values)

            info.update({
                f"{key}/learning_rate": lr,
                f"{key}/loss_Q": loss.item(),
                f"{key}/predictQ": q_eval_a.mean().item()
            })

            info.update(self.callback.on_update_agent_wise(self.iterations, key, info=info, method="update",
                                                           mask_values=mask_values, q_eval_a=q_eval_a,
                                                           q_next=q_next[key], q_target=q_target,
                                                           td_error=td_error, loss=loss))

        if self.iterations % self.sync_frequency == 0:
            self.policy.copy_target()

        info.update(self.callback.on_update_end(self.iterations, method="update", policy=self.policy, info=info))

        return info

    def update_rnn(self, sample):
        """Run one training iteration on a batch of sequences with recurrent models.

        Evaluation and target networks are both run over the same observation
        sequence; steps ``[:-1]`` provide the evaluated Q-values and steps
        ``[1:]`` provide the bootstrap values.
        NOTE(review): this assumes the policy outputs ``seq_length + 1`` steps
        along dim 1 -- confirm against ``build_training_data``.

        Args:
            sample: Raw batch handed to ``build_actions_mean_input`` and
                ``build_training_data``.

        Returns:
            The ``info`` object returned by ``callback.on_update_start``,
            extended with ``{key}/learning_rate``, ``{key}/loss_Q`` and
            ``{key}/predictQ`` for every model key and with whatever the
            callbacks return.

        Raises:
            NotImplementedError: If ``self.policy_type`` is neither
                ``"Boltzmann"`` nor ``"greedy"``.
        """
        self.iterations += 1

        # prepare training data (the next-step mean actions are not needed here)
        act_mean, _ = self.build_actions_mean_input(sample=sample,
                                                    use_parameter_sharing=self.use_parameter_sharing)
        sample_Tensor = self.build_training_data(sample=sample,
                                                 use_parameter_sharing=self.use_parameter_sharing,
                                                 use_actions_mask=self.use_actions_mask)
        batch_size = sample_Tensor['batch_size']
        seq_len = sample_Tensor['seq_length']
        obs = sample_Tensor['obs']
        actions = sample_Tensor['actions']
        rewards = sample_Tensor['rewards']
        terminals = sample_Tensor['terminals']
        agent_mask = sample_Tensor['agent_mask']
        avail_actions = sample_Tensor['avail_actions']
        filled = sample_Tensor['filled']
        IDs = sample_Tensor['agent_ids']

        if self.use_parameter_sharing:
            # With shared parameters the agent dimension is folded into the batch,
            # so the per-sequence ``filled`` mask is repeated once per agent.
            key = self.model_keys[0]
            bs_rnn = batch_size * self.n_agents
            filled = filled.unsqueeze(1).expand(-1, self.n_agents, -1).reshape(bs_rnn, seq_len)
            rewards[key] = rewards[key].reshape(batch_size * self.n_agents, seq_len)
            terminals[key] = terminals[key].reshape(batch_size * self.n_agents, seq_len)
        else:
            bs_rnn = batch_size

        info = self.callback.on_update_start(self.iterations, method="update_rnn",
                                             policy=self.policy, sample_Tensor=sample_Tensor, bs_rnn=bs_rnn)

        rnn_hidden = {k: self.policy.representation[k].init_hidden(bs_rnn) for k in self.model_keys}
        _, actions_greedy, q_eval = self.policy(observation=obs, agent_ids=IDs, actions_mean=act_mean,
                                                avail_actions=avail_actions,
                                                rnn_hidden=rnn_hidden)
        target_rnn_hidden = {k: self.policy.target_representation[k].init_hidden(bs_rnn) for k in self.model_keys}
        _, q_next_seq = self.policy.Qtarget(observation=obs, agent_ids=IDs, actions_mean=act_mean,
                                            rnn_hidden=target_rnn_hidden)

        for key in self.model_keys:
            mask_values = agent_mask[key] * filled

            # calculate the target Q values
            q_eval_a = q_eval[key][:, :-1].gather(-1, actions[key].long().unsqueeze(-1)).reshape(bs_rnn, seq_len)
            q_next = q_next_seq[key][:, 1:]

            if self.use_actions_mask:
                # Push unavailable next-step actions far below any real Q-value.
                q_next[avail_actions[key][:, 1:] == 0] = -1e10

            if self.policy_type == "Boltzmann":
                pi_probs = self.policy.get_boltzmann_policy(q_next)
                v_mf = (pi_probs * q_next).sum(-1).reshape(bs_rnn, seq_len)
                q_target = rewards[key] + (1 - terminals[key]) * self.gamma * v_mf
            elif self.policy_type == "greedy":
                # Greedy actions of steps 1.. come from the evaluation pass above.
                actions_next_greedy = actions_greedy[key][:, 1:].unsqueeze(-1)
                q_next_a = q_next.gather(-1, actions_next_greedy.long().detach()).reshape(bs_rnn, seq_len)
                q_target = rewards[key] + (1 - terminals[key]) * self.gamma * q_next_a
            else:
                raise NotImplementedError

            # calculate the loss function and apply the gradient step
            td_errors, loss, lr = self._mfq_gradient_step(key, q_eval_a, q_target, mask_values)

            info.update({
                f"{key}/learning_rate": lr,
                f"{key}/loss_Q": loss.item(),
                f"{key}/predictQ": q_eval_a.mean().item()
            })

            info.update(self.callback.on_update_agent_wise(self.iterations, key, info=info, method="update_rnn",
                                                           mask_values=mask_values, q_eval_a=q_eval_a,
                                                           q_next_a=q_next, q_target=q_target,
                                                           td_error=td_errors, loss=loss))

        if self.iterations % self.sync_frequency == 0:
            self.policy.copy_target()

        info.update(self.callback.on_update_end(self.iterations, method="update_rnn", policy=self.policy, info=info))

        return info