Source code for xuance.torch.learners.multi_agent_rl.isac_learner

"""
Independent Soft Actor-critic (ISAC)
Implementation: Pytorch
"""
import torch
from torch import nn
from xuance.torch.learners import LearnerMAS
from xuance.common import List
from argparse import Namespace


class ISAC_Learner(LearnerMAS):
    """Learner for Independent Soft Actor-Critic (ISAC).

    For every entry of ``model_keys`` the learner keeps a separate actor and
    critic optimizer (plus linear learning-rate schedulers) and, optionally, a
    learnable entropy temperature ``alpha``.

    Attributes such as ``self.policy``, ``self.config``, ``self.device``,
    ``self.callback``, ``self.n_agents``, ``self.iterations``,
    ``self.use_grad_clip`` and ``self.grad_clip_norm`` are read here but not
    assigned in this class; presumably they are set by ``LearnerMAS`` --
    TODO confirm against the base class.
    """

    def __init__(self,
                 config: Namespace,
                 model_keys: List[str],
                 agent_keys: List[str],
                 policy: nn.Module,
                 callback):
        """Build per-key optimizers, schedulers and entropy temperatures.

        Args:
            config: Hyper-parameters. This method reads ``gamma``, ``tau``,
                ``alpha``, ``learning_rate_actor``, ``learning_rate_critic``
                and ``use_automatic_entropy_tuning``.
            model_keys: Keys identifying the models to optimize.
            agent_keys: Agent keys, forwarded unchanged to the base class.
            policy: Policy module. It must expose ``parameters_actor`` and
                ``parameters_critic`` indexed by model key, and
                ``action_space`` when automatic entropy tuning is enabled.
            callback: Callback object, forwarded to the base class.
        """
        super(ISAC_Learner, self).__init__(config, model_keys, agent_keys, policy, callback)
        self.optimizer = {
            key: {'actor': torch.optim.Adam(self.policy.parameters_actor[key],
                                            self.config.learning_rate_actor, eps=1e-5),
                  'critic': torch.optim.Adam(self.policy.parameters_critic[key],
                                             self.config.learning_rate_critic, eps=1e-5)}
            for key in self.model_keys}
        self.scheduler = {
            key: {'actor': torch.optim.lr_scheduler.LinearLR(self.optimizer[key]['actor'],
                                                             start_factor=1.0,
                                                             end_factor=self.end_factor_lr_decay,
                                                             total_iters=self.total_iters),
                  'critic': torch.optim.lr_scheduler.LinearLR(self.optimizer[key]['critic'],
                                                              start_factor=1.0,
                                                              end_factor=self.end_factor_lr_decay,
                                                              total_iters=self.total_iters)}
            for key in self.model_keys}
        self.gamma = config.gamma
        self.tau = config.tau
        # ``self.alpha`` is always a dict indexed by model key; both update
        # methods rely on that.
        self.alpha = {key: config.alpha for key in self.model_keys}
        self.mse_loss = nn.MSELoss()
        self.use_automatic_entropy_tuning = config.use_automatic_entropy_tuning
        if self.use_automatic_entropy_tuning:
            # Target entropy is minus the size of the last action-space axis.
            self.target_entropy = {key: -policy.action_space[key].shape[-1] for key in self.model_keys}
            self.log_alpha = {key: nn.Parameter(torch.zeros(1, requires_grad=True, device=self.device))
                              for key in self.model_keys}
            self.alpha = {key: self.log_alpha[key].exp() for key in self.model_keys}
            self.alpha_optimizer = {key: torch.optim.Adam([self.log_alpha[key]], lr=config.learning_rate_actor)
                                    for key in self.model_keys}

    def update(self, sample):
        """Run one training iteration on a batch of single-step transitions.

        For every model key the critic is updated first, then the actor, then
        (if enabled) the entropy temperature. The target networks are
        soft-updated once after all keys have been processed.

        Args:
            sample: Raw batch, passed as-is to ``self.build_training_data``
                (defined outside this class).

        Returns:
            The info mapping returned by ``callback.on_update_start``, extended
            with per-key learning rates, losses, the mean predicted Q value,
            optional alpha statistics and whatever the callbacks contribute.
        """
        self.iterations += 1

        # Prepare training data.
        sample_Tensor = self.build_training_data(sample,
                                                 use_parameter_sharing=self.use_parameter_sharing,
                                                 use_actions_mask=False)
        batch_size = sample_Tensor['batch_size']
        obs = sample_Tensor['obs']
        actions = sample_Tensor['actions']
        obs_next = sample_Tensor['obs_next']
        rewards = sample_Tensor['rewards']
        terminals = sample_Tensor['terminals']
        agent_mask = sample_Tensor['agent_mask']
        IDs = sample_Tensor['agent_ids']
        if self.use_parameter_sharing:
            # With parameter sharing all agents go through the first model key,
            # so the agent axis is folded into the batch axis.
            key = self.model_keys[0]
            bs = batch_size * self.n_agents
            rewards[key] = rewards[key].reshape(batch_size * self.n_agents)
            terminals[key] = terminals[key].reshape(batch_size * self.n_agents)
        else:
            bs = batch_size

        info = self.callback.on_update_start(self.iterations, method="update",
                                             policy=self.policy, sample_Tensor=sample_Tensor, bs=bs)

        # feedforward
        _, actions_eval, log_pi_eval = self.policy(observation=obs, agent_ids=IDs)
        _, actions_next, log_pi_next = self.policy(observation=obs_next, agent_ids=IDs)
        _, _, action_q_1, action_q_2 = self.policy.Qpolicy(observation=obs, actions=actions, agent_ids=IDs)
        _, _, next_q = self.policy.Qtarget(next_observation=obs_next, next_actions=actions_next, agent_ids=IDs)

        for key in self.model_keys:
            mask_values = agent_mask[key]
            # update critic
            action_q_1_i, action_q_2_i = action_q_1[key].reshape(bs), action_q_2[key].reshape(bs)
            log_pi_next_eval = log_pi_next[key].reshape(bs)
            next_q_i = next_q[key].reshape(bs)
            # Entropy-regularized target value.
            target_value = next_q_i - self.alpha[key] * log_pi_next_eval
            backup = rewards[key] + (1 - terminals[key]) * self.gamma * target_value
            # The backup is detached so only the online critics receive gradients.
            td_error_1, td_error_2 = action_q_1_i - backup.detach(), action_q_2_i - backup.detach()
            td_error_1 *= mask_values
            td_error_2 *= mask_values
            # Masked squared error, normalized by the sum of the mask.
            loss_c = ((td_error_1 ** 2).sum() + (td_error_2 ** 2).sum()) / mask_values.sum()
            self.optimizer[key]['critic'].zero_grad()
            loss_c.backward()
            if self.use_grad_clip:
                torch.nn.utils.clip_grad_norm_(self.policy.parameters_critic[key], self.grad_clip_norm)
            self.optimizer[key]['critic'].step()
            if self.scheduler[key]['critic'] is not None:
                self.scheduler[key]['critic'].step()

            # update actor
            # Q values are re-evaluated after the critic step, on the actions
            # sampled from the current policy.
            _, _, policy_q_1, policy_q_2 = self.policy.Qpolicy(observation=obs, actions=actions_eval,
                                                               agent_ids=IDs, agent_key=key)
            log_pi_eval_i = log_pi_eval[key].reshape(bs)
            policy_q = torch.min(policy_q_1[key], policy_q_2[key]).reshape(bs)
            loss_a = ((self.alpha[key] * log_pi_eval_i - policy_q) * mask_values).sum() / mask_values.sum()
            self.optimizer[key]['actor'].zero_grad()
            loss_a.backward()
            if self.use_grad_clip:
                torch.nn.utils.clip_grad_norm_(self.policy.parameters_actor[key], self.grad_clip_norm)
            self.optimizer[key]['actor'].step()
            if self.scheduler[key]['actor'] is not None:
                self.scheduler[key]['actor'].step()

            # automatic entropy tuning
            if self.use_automatic_entropy_tuning:
                alpha_loss = -(self.log_alpha[key] * (log_pi_eval_i + self.target_entropy[key]).detach()).mean()
                self.alpha_optimizer[key].zero_grad()
                alpha_loss.backward()
                self.alpha_optimizer[key].step()
                # Refresh the cached temperature for this key only.
                self.alpha[key] = self.log_alpha[key].exp()
            else:
                alpha_loss = 0

            learning_rate_actor = self.optimizer[key]['actor'].state_dict()['param_groups'][0]['lr']
            learning_rate_critic = self.optimizer[key]['critic'].state_dict()['param_groups'][0]['lr']

            info.update({
                f"{key}/learning_rate_actor": learning_rate_actor,
                f"{key}/learning_rate_critic": learning_rate_critic,
                f"{key}/loss_actor": loss_a.item(),
                f"{key}/loss_critic": loss_c.item(),
                f"{key}/predictQ": policy_q.mean().item(),
            })
            if self.use_automatic_entropy_tuning:
                info.update({f"{key}/alpha_loss": alpha_loss.item(),
                             f"{key}/alpha": self.alpha[key].item()})

            info.update(self.callback.on_update_agent_wise(self.iterations, key, info=info, method="update",
                                                           mask_values=mask_values,
                                                           action_q_1_i=action_q_1_i, action_q_2_i=action_q_2_i,
                                                           log_pi_next_eval=log_pi_next_eval,
                                                           next_q_i=next_q_i, target_value=target_value,
                                                           backup=backup,
                                                           td_error_1=td_error_1, td_error_2=td_error_2,
                                                           policy_q_1=policy_q_1, policy_q_2=policy_q_2,
                                                           log_pi_eval_i=log_pi_eval_i, policy_q=policy_q))

        self.policy.soft_update(self.tau)
        info.update(self.callback.on_update_end(self.iterations, method="update", policy=self.policy, info=info))
        return info

    def update_rnn(self, sample):
        """Run one training iteration on a batch of sequences (recurrent policy).

        Same update order as ``update`` (critic, actor, optional temperature,
        then one soft target update), but tensors are reshaped to
        ``(bs_rnn, seq_len)`` and the loss mask is additionally multiplied by
        ``filled``.

        Args:
            sample: Raw batch, passed as-is to ``self.build_training_data``
                (defined outside this class).

        Returns:
            The info mapping returned by ``callback.on_update_start``, extended
            with per-key learning rates, losses, the mean predicted Q value,
            optional alpha statistics and whatever the callbacks contribute.
        """
        self.iterations += 1

        # prepare training data
        sample_Tensor = self.build_training_data(sample=sample,
                                                 use_parameter_sharing=self.use_parameter_sharing,
                                                 use_actions_mask=self.use_actions_mask)
        batch_size = sample_Tensor['batch_size']
        seq_len = sample_Tensor['seq_length']
        obs = sample_Tensor['obs']
        actions = sample_Tensor['actions']
        rewards = sample_Tensor['rewards']
        terminals = sample_Tensor['terminals']
        agent_mask = sample_Tensor['agent_mask']
        filled = sample_Tensor['filled']
        IDs = sample_Tensor['agent_ids']
        if self.use_parameter_sharing:
            # With parameter sharing all agents go through the first model key,
            # so the agent axis is folded into the batch axis.
            key = self.model_keys[0]
            bs_rnn = batch_size * self.n_agents
            filled = filled.unsqueeze(1).expand(-1, self.n_agents, -1).reshape(bs_rnn, seq_len)
            rewards[key] = rewards[key].reshape(bs_rnn, seq_len)
            terminals[key] = terminals[key].reshape(bs_rnn, seq_len)
            # Agent ids without the last entry along dim 1, to match ``obs_t``.
            IDs_t = IDs[:, :-1]
        else:
            bs_rnn, IDs_t = batch_size, None

        info = self.callback.on_update_start(self.iterations, method="update_rnn",
                                             policy=self.policy, sample_Tensor=sample_Tensor, bs_rnn=bs_rnn)

        # initial hidden states for rnn
        rnn_hidden_actor = {k: self.policy.actor_representation[k].init_hidden(bs_rnn) for k in self.model_keys}
        # NOTE(review): the same initial hidden state (taken from critic_1) is
        # passed to both critics -- presumably they share an architecture.
        rnn_hidden_critic = {k: self.policy.critic_1_representation[k].init_hidden(bs_rnn) for k in self.model_keys}

        _, actions_eval, log_pi_eval = self.policy(observation=obs, agent_ids=IDs, rnn_hidden=rnn_hidden_actor)
        # Observations without the last entry along dim 1 (presumably the time
        # axis with seq_len + 1 steps -- TODO confirm).
        obs_t = {k: v[:, :-1] for k, v in obs.items()}
        _, _, action_q_1, action_q_2 = self.policy.Qpolicy(observation=obs_t, actions=actions, agent_ids=IDs_t,
                                                           rnn_hidden_critic_1=rnn_hidden_critic,
                                                           rnn_hidden_critic_2=rnn_hidden_critic)
        _, _, next_q = self.policy.Qtarget(next_observation=obs, next_actions=actions_eval, agent_ids=IDs,
                                           rnn_hidden_critic_1=rnn_hidden_critic,
                                           rnn_hidden_critic_2=rnn_hidden_critic)

        for key in self.model_keys:
            mask_values = agent_mask[key] * filled
            # update critic
            action_q_1_i = action_q_1[key].reshape(bs_rnn, seq_len)
            action_q_2_i = action_q_2[key].reshape(bs_rnn, seq_len)
            # Entries from index 1 onward along dim 1 act as the "next step"
            # quantities for the bootstrap target.
            log_pi_next_eval = log_pi_eval[key][:, 1:].reshape(bs_rnn, seq_len)
            next_q_i = next_q[key][:, 1:].reshape(bs_rnn, seq_len)
            target_value = next_q_i - self.alpha[key] * log_pi_next_eval
            backup = rewards[key] + (1 - terminals[key]) * self.gamma * target_value
            # The backup is detached so only the online critics receive gradients.
            td_error_1, td_error_2 = action_q_1_i - backup.detach(), action_q_2_i - backup.detach()
            td_error_1 *= mask_values
            td_error_2 *= mask_values
            # Masked squared error, normalized by the sum of the mask.
            loss_c = ((td_error_1 ** 2).sum() + (td_error_2 ** 2).sum()) / mask_values.sum()
            self.optimizer[key]['critic'].zero_grad()
            loss_c.backward()
            if self.use_grad_clip:
                torch.nn.utils.clip_grad_norm_(self.policy.parameters_critic[key], self.grad_clip_norm)
            self.optimizer[key]['critic'].step()
            if self.scheduler[key]['critic'] is not None:
                self.scheduler[key]['critic'].step()

            # update actor
            _, _, policy_q_1, policy_q_2 = self.policy.Qpolicy(observation=obs, actions=actions_eval,
                                                               agent_ids=IDs, agent_key=key,
                                                               rnn_hidden_critic_1=rnn_hidden_critic,
                                                               rnn_hidden_critic_2=rnn_hidden_critic)
            # Drop the last entry along dim 1 so shapes match ``mask_values``.
            log_pi_eval_i = log_pi_eval[key][:, :-1].reshape(bs_rnn, seq_len)
            policy_q = torch.min(policy_q_1[key][:, :-1], policy_q_2[key][:, :-1]).reshape(bs_rnn, seq_len)
            loss_a = ((self.alpha[key] * log_pi_eval_i - policy_q) * mask_values).sum() / mask_values.sum()
            self.optimizer[key]['actor'].zero_grad()
            loss_a.backward()
            if self.use_grad_clip:
                torch.nn.utils.clip_grad_norm_(self.policy.parameters_actor[key], self.grad_clip_norm)
            self.optimizer[key]['actor'].step()
            if self.scheduler[key]['actor'] is not None:
                self.scheduler[key]['actor'].step()

            # automatic entropy tuning
            if self.use_automatic_entropy_tuning:
                alpha_loss = -(self.log_alpha[key] * (log_pi_eval_i + self.target_entropy[key]).detach()).mean()
                self.alpha_optimizer[key].zero_grad()
                alpha_loss.backward()
                self.alpha_optimizer[key].step()
                # Assign per key: rebinding ``self.alpha`` itself would replace
                # the dict with a tensor and break every later
                # ``self.alpha[key]`` lookup (including the logging below).
                self.alpha[key] = self.log_alpha[key].exp()
            else:
                alpha_loss = 0

            learning_rate_actor = self.optimizer[key]['actor'].state_dict()['param_groups'][0]['lr']
            learning_rate_critic = self.optimizer[key]['critic'].state_dict()['param_groups'][0]['lr']

            info.update({
                f"{key}/learning_rate_actor": learning_rate_actor,
                f"{key}/learning_rate_critic": learning_rate_critic,
                f"{key}/loss_actor": loss_a.item(),
                f"{key}/loss_critic": loss_c.item(),
                f"{key}/predictQ": policy_q.mean().item(),
            })
            if self.use_automatic_entropy_tuning:
                info.update({f"{key}/alpha_loss": alpha_loss.item(),
                             f"{key}/alpha": self.alpha[key].item()})

            info.update(self.callback.on_update_agent_wise(self.iterations, key, info=info, method="update_rnn",
                                                           mask_values=mask_values,
                                                           action_q_1_i=action_q_1_i, action_q_2_i=action_q_2_i,
                                                           log_pi_next_eval=log_pi_next_eval,
                                                           next_q_i=next_q_i, target_value=target_value,
                                                           backup=backup,
                                                           td_error_1=td_error_1, td_error_2=td_error_2,
                                                           policy_q_1=policy_q_1, policy_q_2=policy_q_2,
                                                           log_pi_eval_i=log_pi_eval_i, policy_q=policy_q))

        self.policy.soft_update(self.tau)
        info.update(self.callback.on_update_end(self.iterations, method="update_rnn",
                                                policy=self.policy, info=info))
        return info