import os
import numpy as np
from argparse import Namespace
from operator import itemgetter
from abc import ABC, abstractmethod
from xuance.common import Optional, List, Union
from xuance.mindspore import ms, nn, Tensor, Module, optim, ops
class Learner(ABC):
    """Base class of single-agent learners for the MindSpore backend.

    Reads the training hyper-parameters from ``config`` and provides helpers for
    distributed gradient reduction and checkpoint saving/loading.
    Subclasses must implement :meth:`update`.
    """

    def __init__(self,
                 config: Namespace,
                 policy: Module,
                 callback):
        """
        Parameters:
            config (Namespace): The configuration of the experiment.
            policy (Module): The policy network to be trained.
            callback: The callback object, stored as ``self.callback``.
        """
        self.value_normalizer = None
        self.config = config
        self.distributed_training = config.distributed_training
        self.episode_length = config.episode_length
        # Optional settings fall back to these defaults when absent from the config.
        self.learning_rate = getattr(config, 'learning_rate', None)
        self.use_linear_lr_decay = getattr(config, 'use_linear_lr_decay', False)
        self.end_factor_lr_decay = getattr(config, 'end_factor_lr_decay', 0.5)
        self.gamma = getattr(config, 'gamma', 0.99)
        self.use_rnn = getattr(config, 'use_rnn', False)
        self.use_actions_mask = getattr(config, 'use_actions_mask', False)
        self.policy = policy
        # Initialized to None here; presumably assigned by subclasses.
        self.optimizer: Union[dict, list, Optional[nn.Optimizer]] = None
        self.scheduler: Union[dict, list, Optional[optim.lr_scheduler.LRScheduler]] = None
        self.callback = callback
        self.use_grad_clip = config.use_grad_clip
        self.grad_clip_norm = config.grad_clip_norm
        self.device = config.device
        self.model_dir = config.model_dir
        # Needs self.config, so it must run after the assignments above.
        self.total_iters = self.estimate_total_iterations()
        self.iterations = 0

    def get_grad_reducer(self,
                         optimizer: Union[dict, list, Optional[nn.Optimizer]]
                         ) -> Optional[nn.DistributedGradReducer]:
        """Create a gradient reducer when distributed training is enabled.

        Parameters:
            optimizer: The optimizer whose ``parameters`` attribute is handed to the
                reducer. Only an object exposing ``.parameters`` works here.

        Returns:
            An ``nn.DistributedGradReducer`` if ``self.distributed_training`` is truthy,
            otherwise ``None``.
        """
        if self.distributed_training:
            # The reducer's mean flag is taken from the auto-parallel context's "gradients_mean".
            mean = ms.context.get_auto_parallel_context("gradients_mean")
            return nn.DistributedGradReducer(optimizer.parameters, mean)
        else:
            return None

    def estimate_total_iterations(self):
        """Estimated total number of training iterations.

        Computed as
        ``(running_steps - start_training) // (training_frequency * parallels)``.
        ``start_training`` defaults to 0 and ``training_frequency`` to 1 when they
        are missing from the config.
        """
        start_training = getattr(self.config, "start_training", 0)
        training_frequency = getattr(self.config, "training_frequency", 1)
        total_iters = (self.config.running_steps - start_training) // (training_frequency * self.config.parallels)
        return total_iters

    def save_model(self, model_path):
        """Save the parameters of ``self.policy`` to the checkpoint file ``model_path``."""
        ms.save_checkpoint(self.policy, model_path)

    def load_model(self, path, model=None):
        """Load the parameters of ``self.policy`` from a checkpoint found under ``path``.

        Parameters:
            path (str): Directory containing the model folders.
            model (str): Name of the sub-folder of ``path`` to load from. If None,
                the lexicographically last entry whose name contains "seed_" is used.

        Returns:
            str: The folder the checkpoint was loaded from.

        Raises:
            RuntimeError: If ``model`` is not in ``path``, if no "seed_" entry exists
                (when ``model`` is None), or if the chosen folder has no model file.
        """
        file_names = os.listdir(path)
        if model is not None:
            path = os.path.join(path, model)
            if model not in file_names:
                raise RuntimeError(f"The folder '{path}' does not exist, please specify a correct path to load model.")
        else:
            # Filter into a new list: calling remove() on a list while iterating over it
            # skips elements, which could leave non-"seed_" entries in place.
            seed_folders = sorted(f for f in file_names if "seed_" in f)
            if not seed_folders:
                raise RuntimeError(f"There is no 'seed_' folder in '{path}'!")
            path = os.path.join(path, seed_folders[-1])
        # obs_rms.npy (presumably observation normalization statistics) is not a checkpoint.
        model_names = sorted(name for name in os.listdir(path) if name != "obs_rms.npy")
        if not model_names:
            raise RuntimeError(f"There is no model file in '{path}'!")
        model_path = os.path.join(path, model_names[-1])
        ms.load_param_into_net(self.policy, ms.load_checkpoint(model_path))
        print(f"Successfully load model from '{path}'.")
        return path

    @abstractmethod
    def update(self, *args):
        """Run one training update; must be implemented by subclasses."""
        raise NotImplementedError
class LearnerMAS(ABC):
    """Base class of multi-agent learners for the MindSpore backend.

    Reads the training hyper-parameters from ``config`` and converts sampled numpy
    batches into MindSpore tensors (:meth:`build_training_data`). It also provides
    checkpoint saving/loading. Subclasses must implement :meth:`update`.
    """

    def __init__(self,
                 config: Namespace,
                 model_keys: List[str],
                 agent_keys: List[str],
                 policy: Module,
                 callback):
        """
        Parameters:
            config (Namespace): The configuration of the experiment.
            model_keys (List[str]): Keys of the models. With parameter sharing, only
                ``model_keys[0]`` is used by :meth:`build_training_data`.
            agent_keys (List[str]): Keys of the agents.
            policy (Module): The policy network to be trained.
            callback: The callback object, stored as ``self.callback``.
        """
        self.value_normalizer = None
        self.config = config
        self.n_agents = config.n_agents
        # Presumably the size of the one-hot agent IDs built in build_training_data.
        self.dim_id = self.n_agents
        self.use_parameter_sharing = config.use_parameter_sharing
        self.model_keys = model_keys
        self.agent_keys = agent_keys
        self.episode_length = config.episode_length
        # Optional settings fall back to these defaults when absent from the config.
        self.learning_rate = getattr(config, 'learning_rate', None)
        self.use_linear_lr_decay = getattr(config, 'use_linear_lr_decay', False)
        self.end_factor_lr_decay = getattr(config, 'end_factor_lr_decay', 0.5)
        self.gamma = getattr(config, 'gamma', 0.99)
        self.use_rnn = getattr(config, 'use_rnn', False)
        self.use_actions_mask = getattr(config, 'use_actions_mask', False)
        self.policy = policy
        # Initialized to None here; presumably assigned by subclasses.
        self.optimizer: Union[dict, list, Optional[nn.Optimizer]] = None
        self.scheduler: Union[dict, list, Optional[ms.experimental.optim.lr_scheduler.LRScheduler]] = None
        self.callback = callback
        self.use_grad_clip = config.use_grad_clip
        self.grad_clip_norm = config.grad_clip_norm
        self.device = config.device
        self.model_dir = config.model_dir
        # Needs self.config, self.episode_length and self.use_rnn, so it must run after them.
        self.total_iters = self.estimate_total_iterations()
        self.iterations = 0
        self.eye = ops.Eye()

    def estimate_total_iterations(self):
        """Estimated total number of training iterations.

        The environment steps after ``start_training`` are divided by
        ``episode_length * parallels`` (RNN) or ``training_frequency * parallels``
        (otherwise). The result is multiplied by ``n_epochs``. Missing config entries
        default to ``start_training=0``, ``training_frequency=1`` and ``n_epochs=1``.
        """
        start_training = getattr(self.config, "start_training", 0)
        training_frequency = getattr(self.config, "training_frequency", 1)
        n_epochs = getattr(self.config, "n_epochs", 1)
        episode_length = self.episode_length
        if self.use_rnn:
            total_iters = (self.config.running_steps - start_training) // (episode_length * self.config.parallels)
        else:
            total_iters = (self.config.running_steps - start_training) // (training_frequency * self.config.parallels)
        total_iters *= n_epochs
        return total_iters

    def _stack_agents(self, data: dict) -> np.ndarray:
        """Stack ``data[k]`` for every ``k`` in ``self.agent_keys`` along a new axis 1."""
        # Build an explicit list instead of using ``itemgetter(*self.agent_keys)``:
        # with a single agent key, itemgetter returns the bare array rather than a
        # tuple, and np.stack would then split that array along its first axis.
        return np.stack([data[k] for k in self.agent_keys], axis=1)

    def build_training_data(self, sample: Optional[dict],
                            use_parameter_sharing: Optional[bool] = False,
                            use_actions_mask: Optional[bool] = False,
                            use_global_state: Optional[bool] = False):
        """
        Prepare the training data.

        Parameters:
            sample (dict): The raw sampled data.
            use_parameter_sharing (bool): Whether to use parameter sharing for individual agent models.
            use_actions_mask (bool): Whether to use actions mask for unavailable actions.
            use_global_state (bool): Whether to use global state.

        Returns:
            sample_Tensor (dict): The formatted sampled data. It has the keys
            'batch_size', 'state', 'state_next', 'obs', 'actions', 'obs_next',
            'rewards', 'terminals', 'agent_mask', 'avail_actions',
            'avail_actions_next', 'agent_ids', 'filled' and 'seq_length'.
            Entries that do not apply to the current settings are None.
            With parameter sharing, the per-agent dicts have the single key
            ``self.model_keys[0]`` and all agents are merged into the batch
            dimension (``batch_size * n_agents``). Otherwise they are keyed by
            ``self.agent_keys``.
        """
        batch_size = sample['batch_size']
        seq_length = sample['sequence_length'] if self.use_rnn else 1
        state, avail_actions, filled = None, None, None
        obs_next, state_next, avail_actions_next = None, None, None
        IDs = None
        if use_parameter_sharing:
            k = self.model_keys[0]
            bs = batch_size * self.n_agents
            obs_tensor = Tensor(self._stack_agents(sample['obs']))
            actions_tensor = Tensor(self._stack_agents(sample['actions']))
            rewards_tensor = Tensor(self._stack_agents(sample['rewards']))
            ter_tensor = Tensor(self._stack_agents(sample['terminals'])).astype(ms.float32)
            msk_tensor = Tensor(self._stack_agents(sample['agent_mask'])).astype(ms.float32)
            if self.use_rnn:
                # Observations carry seq_length + 1 steps; actions, rewards, terminals and masks carry seq_length.
                obs = {k: obs_tensor.reshape(bs, seq_length + 1, -1)}
                if len(actions_tensor.shape) == 3:  # presumably (batch, n_agents, seq_length): one value per step
                    actions = {k: actions_tensor.reshape(bs, seq_length)}
                elif len(actions_tensor.shape) == 4:  # an extra trailing action dimension
                    actions = {k: actions_tensor.reshape(bs, seq_length, -1)}
                else:
                    raise AttributeError("Wrong actions shape.")
                rewards = {k: rewards_tensor.reshape(batch_size, self.n_agents, seq_length)}
                terminals = {k: ter_tensor.reshape(batch_size, self.n_agents, seq_length)}
                agent_mask = {k: msk_tensor.reshape(bs, seq_length)}
                # One-hot agent IDs: identity rows broadcast over batch and time -> (bs, seq_length + 1, n_agents).
                IDs = self.eye(self.n_agents, self.n_agents, ms.float32).unsqueeze(1).unsqueeze(0).broadcast_to(
                    (batch_size, -1, seq_length + 1, -1)).reshape(bs, seq_length + 1, self.n_agents)
            else:
                obs = {k: obs_tensor.reshape(bs, -1)}
                if len(actions_tensor.shape) == 2:  # one value per agent
                    actions = {k: actions_tensor.reshape(bs)}
                elif len(actions_tensor.shape) == 3:  # an extra trailing action dimension
                    actions = {k: actions_tensor.reshape(bs, -1)}
                else:
                    raise AttributeError("Wrong actions shape.")
                rewards = {k: rewards_tensor.reshape(batch_size, self.n_agents)}
                terminals = {k: ter_tensor.reshape(batch_size, self.n_agents)}
                agent_mask = {k: msk_tensor.reshape(bs)}
                obs_next = {k: Tensor(self._stack_agents(sample['obs_next'])).reshape(bs, -1)}
                # One-hot agent IDs: identity rows broadcast over the batch -> (bs, n_agents).
                IDs = self.eye(self.n_agents, self.n_agents, ms.float32).unsqueeze(0).broadcast_to(
                    (batch_size, -1, -1)).reshape(bs, self.n_agents)
            if use_actions_mask:
                avail_a = self._stack_agents(sample['avail_actions'])
                if self.use_rnn:
                    avail_actions = {k: Tensor(avail_a.reshape([bs, seq_length + 1, -1])).astype(ms.float32)}
                else:
                    avail_actions = {k: Tensor(avail_a.reshape([bs, -1])).astype(ms.float32)}
                    avail_a_next = self._stack_agents(sample['avail_actions_next'])
                    avail_actions_next = {k: Tensor(avail_a_next.reshape([bs, -1])).astype(ms.float32)}
        else:
            obs = {k: Tensor(sample['obs'][k]) for k in self.agent_keys}
            actions = {k: Tensor(sample['actions'][k]) for k in self.agent_keys}
            rewards = {k: Tensor(sample['rewards'][k]) for k in self.agent_keys}
            terminals = {k: Tensor(sample['terminals'][k]).astype(ms.float32) for k in self.agent_keys}
            agent_mask = {k: Tensor(sample['agent_mask'][k]).astype(ms.float32) for k in self.agent_keys}
            if not self.use_rnn:
                obs_next = {k: Tensor(sample['obs_next'][k]) for k in self.agent_keys}
            if use_actions_mask:
                avail_actions = {k: Tensor(sample['avail_actions'][k]).astype(ms.float32) for k in self.agent_keys}
                if not self.use_rnn:
                    # Keyed by agent_keys, consistent with avail_actions and every other field of this branch.
                    avail_actions_next = {k: Tensor(sample['avail_actions_next'][k]).astype(ms.float32)
                                          for k in self.agent_keys}
        if use_global_state:
            state = Tensor(sample['state'])
            if not self.use_rnn:
                state_next = Tensor(sample['state_next'])
        if self.use_rnn:
            filled = Tensor(sample['filled']).astype(ms.float32)
        sample_Tensor = {
            'batch_size': batch_size,
            'state': state,
            'state_next': state_next,
            'obs': obs,
            'actions': actions,
            'obs_next': obs_next,
            'rewards': rewards,
            'terminals': terminals,
            'agent_mask': agent_mask,
            'avail_actions': avail_actions,
            'avail_actions_next': avail_actions_next,
            'agent_ids': IDs,
            'filled': filled,
            'seq_length': seq_length,
        }
        return sample_Tensor

    @abstractmethod
    def update(self, *args):
        """Run one training update; must be implemented by subclasses."""
        raise NotImplementedError

    def update_rnn(self, *args):
        """Run one training update for recurrent policies; not implemented in this base class."""
        raise NotImplementedError

    def save_model(self, model_path):
        """Save the parameters of ``self.policy`` to the checkpoint file ``model_path``."""
        ms.save_checkpoint(self.policy, model_path)

    def load_model(self, path, model=None):
        """Load the parameters of ``self.policy`` from a checkpoint found under ``path``.

        Parameters:
            path (str): Directory containing the model folders.
            model (str): Name of the sub-folder of ``path`` to load from. If None,
                the lexicographically last entry whose name contains "seed_" is used.

        Returns:
            str: The folder the checkpoint was loaded from.

        Raises:
            RuntimeError: If ``model`` is not in ``path``, if no "seed_" entry exists
                (when ``model`` is None), or if the chosen folder has no model file.
        """
        file_names = os.listdir(path)
        if model is not None:
            path = os.path.join(path, model)
            if model not in file_names:
                raise RuntimeError(f"The folder '{path}' does not exist, please specify a correct path to load model.")
        else:
            # Filter into a new list: calling remove() on a list while iterating over it
            # skips elements, which could leave non-"seed_" entries in place.
            seed_folders = sorted(f for f in file_names if "seed_" in f)
            if not seed_folders:
                raise RuntimeError(f"There is no 'seed_' folder in '{path}'!")
            path = os.path.join(path, seed_folders[-1])
        # obs_rms.npy (presumably observation normalization statistics) is not a checkpoint.
        model_names = sorted(name for name in os.listdir(path) if name != "obs_rms.npy")
        if not model_names:
            raise RuntimeError(f"There is no model file in '{path}'!")
        model_path = os.path.join(path, model_names[-1])
        ms.load_param_into_net(self.policy, ms.load_checkpoint(model_path))
        print(f"Successfully load model from '{path}'.")
        return path