Source code for xuance.torch.learners.learner

import os
import torch
import numpy as np
from pathlib import Path
from abc import ABC, abstractmethod
from xuance.common import Optional, List, Union
from argparse import Namespace
from operator import itemgetter
from xuance.torch import Tensor, Module

# Number of CUDA devices visible to this process, evaluated once at import time
# (torch.cuda.device_count() returns 0 when CUDA is unavailable).
MAX_GPUs: int = torch.cuda.device_count()


class Learner(ABC):
    """Base class for single-agent learners.

    Stores the training configuration and the policy, and implements model
    checkpointing.  ``self.optimizer`` and ``self.scheduler`` start as ``None``
    and may later hold a single object, a dict, or a list of them.  When
    ``config.distributed_training`` is set, rank and world size are read from
    the ``RANK``/``WORLD_SIZE`` environment variables and a DDP snapshot
    directory ``<cwd>/<config.model_dir>/DDP_Snapshot`` is used.
    """

    def __init__(self, config: Namespace, policy: Module, callback):
        """
        Parameters:
            config (Namespace): Training configuration. ``distributed_training``,
                ``episode_length``, ``use_grad_clip``, ``grad_clip_norm``, ``device``
                and ``model_dir`` are required; the other fields read here fall
                back to defaults.
            policy (Module): The policy network being trained.
            callback: Callback object, stored as ``self.callback``.
        """
        self.value_normalizer = None
        self.config = config
        self.distributed_training = config.distributed_training
        self.episode_length = config.episode_length
        self.learning_rate = getattr(config, 'learning_rate', None)
        self.use_linear_lr_decay = getattr(config, 'use_linear_lr_decay', False)
        self.end_factor_lr_decay = getattr(config, 'end_factor_lr_decay', 1.0)
        self.gamma = getattr(config, 'gamma', 0.99)
        self.use_rnn = getattr(config, 'use_rnn', False)
        self.use_actions_mask = getattr(config, 'use_actions_mask', False)
        self.policy = policy
        self.optimizer: Union[dict, list, Optional[torch.optim.Optimizer]] = None
        self.scheduler: Union[dict, list, Optional[torch.optim.lr_scheduler.LinearLR]] = None
        self.callback = callback
        if self.distributed_training:
            self.world_size = int(os.environ['WORLD_SIZE'])
            self.rank = self.device = int(os.environ['RANK'])
            self.snapshot_path = os.path.join(os.getcwd(), config.model_dir, "DDP_Snapshot")
            if os.path.exists(os.path.join(self.snapshot_path, "snapshot.pt")):
                print("Loading Snapshot...")
                # NOTE(review): self.optimizer is still None here, so only the policy
                # and RNG states can be restored from the snapshot at this point.
                self.load_snapshot(self.snapshot_path)
            elif not os.path.exists(self.snapshot_path) and self.device == 0:
                # Only rank 0 creates the directory; exist_ok guards against races.
                os.makedirs(self.snapshot_path, exist_ok=True)
        else:
            self.world_size = 1
            self.rank = 0
            self.device = config.device
        self.use_grad_clip = config.use_grad_clip
        self.grad_clip_norm = config.grad_clip_norm
        # NOTE(review): this overrides the rank-based device chosen above in
        # distributed mode -- confirm config.device is already rank-specific.
        self.device = config.device
        self.model_dir = config.model_dir
        self.total_iters = self.estimate_total_iterations()
        self.iterations = 0

    def estimate_total_iterations(self):
        """Estimated total number of training iterations (never negative)."""
        start_training = getattr(self.config, "start_training", 0)
        training_frequency = getattr(self.config, "training_frequency", 1)
        total_iters = (self.config.running_steps - start_training) // (training_frequency * self.config.parallels)
        # start_training > running_steps would otherwise yield a negative count.
        return max(total_iters, 0)

    # ------------------------------------------------------------------ helpers
    @classmethod
    def _optimizer_state_dict(cls, optimizer):
        """Return the state dict(s) of an optimizer, a dict/list of them, or None."""
        if optimizer is None:
            return None
        if isinstance(optimizer, dict):
            return {key: cls._optimizer_state_dict(value) for key, value in optimizer.items()}
        if isinstance(optimizer, (list, tuple)):
            return [cls._optimizer_state_dict(value) for value in optimizer]
        return optimizer.state_dict()

    @classmethod
    def _load_optimizer_state_dict(cls, optimizer, state):
        """Load ``state`` (as produced by ``_optimizer_state_dict``) into ``optimizer``.

        A ``None`` optimizer or ``None`` state is skipped.  Dict optimizers raise
        ``KeyError`` when ``state`` lacks one of their keys.
        """
        if optimizer is None or state is None:
            return
        if isinstance(optimizer, dict):
            for key, value in optimizer.items():
                cls._load_optimizer_state_dict(value, state[key])
        elif isinstance(optimizer, (list, tuple)):
            for value, item_state in zip(optimizer, state):
                cls._load_optimizer_state_dict(value, item_state)
        else:
            optimizer.load_state_dict(state)

    @staticmethod
    def _restore_rng_states(checkpoint):
        """Restore the CPU and (when available) CUDA RNG states stored in ``checkpoint``."""
        if 'rng_state' in checkpoint:
            # set_rng_state requires a CPU ByteTensor, whatever map_location was used.
            torch.set_rng_state(checkpoint['rng_state'].cpu().to(dtype=torch.uint8))
        cuda_states = checkpoint.get('cuda_rng_state')
        if isinstance(cuda_states, list) and torch.cuda.is_available():
            # Only restore states for devices that exist on this machine.
            for i, state in enumerate(cuda_states[:torch.cuda.device_count()]):
                torch.cuda.set_rng_state(state.cpu().to(dtype=torch.uint8), device=i)

    @staticmethod
    def _resolve_model_path(path, model=None):
        """Locate the checkpoint file to load.

        If ``path/model`` (or ``path`` itself when ``model`` is None) is a file it
        is used directly.  Otherwise ``path`` must be a directory: the last
        (sorted) sub-folder containing ``"seed_"`` is searched for ``*.pth``
        files, preferring ``final_train_model.pth`` and falling back to the last
        file in sorted order.

        Returns:
            (model_path, dir_name): the checkpoint file and its directory, as str.

        Raises:
            RuntimeError: ``path`` is neither a file nor a directory, or has no
                ``seed_`` sub-folder.
            FileNotFoundError: the selected seed folder contains no ``.pth`` file.
        """
        target_path = os.path.join(path, model) if model is not None else path
        if os.path.isfile(target_path):
            return target_path, os.path.dirname(target_path)
        if not os.path.isdir(path):
            raise RuntimeError(f"The path '{path}' is not a valid directory or file!")
        folder_names = sorted(f for f in os.listdir(path) if "seed_" in f)
        if not folder_names:
            raise RuntimeError(f"No model files with 'seed_' found in '{path}'!")
        seed_dir = Path(os.path.join(path, folder_names[-1]))
        model_names = sorted(seed_dir.glob("*.pth"))
        if not model_names:
            raise FileNotFoundError(f"No .pth file found in {seed_dir}")
        model_path = next((f for f in model_names if "final_train_model.pth" in str(f)), model_names[-1])
        return str(model_path), str(seed_dir)

    def _checkpoint_map_location(self):
        """Device every loaded tensor is mapped to (an int device means a CUDA index)."""
        device = self.device
        if isinstance(device, int):
            device = f"cuda:{device}"
        return torch.device(device)

    # ------------------------------------------------------------ checkpoints
    def save_model(self, model_path):
        """Save policy, optimizer(s) and RNG states to ``model_path``.

        In distributed mode a DDP snapshot is also written.
        """
        torch.save(
            {
                'policy': self.policy.state_dict(),
                'optimizer': self._optimizer_state_dict(self.optimizer),
                'rng_state': torch.get_rng_state(),
                'cuda_rng_state': torch.cuda.get_rng_state_all(),
            }, model_path)
        if self.distributed_training:
            self.save_snapshot()

    def load_model(self, path, model=None):
        """Load a checkpoint written by ``save_model``.

        Parameters:
            path (str): A checkpoint file, or a directory searched as described in
                ``_resolve_model_path``.
            model (str | None): Optional file name inside ``path``.

        Returns:
            str: The directory containing the loaded checkpoint.
        """
        model_path, dir_name = self._resolve_model_path(path, model)
        checkpoint = torch.load(model_path, map_location=self._checkpoint_map_location(), weights_only=True)
        self.policy.load_state_dict(checkpoint['policy'], strict=False)
        if 'optimizer' in checkpoint and self.optimizer is not None:
            self._load_optimizer_state_dict(self.optimizer, checkpoint['optimizer'])
            if hasattr(self.optimizer, 'param_groups'):
                # Single optimizer: keep the stored learning rate in sync.
                self.learning_rate = self.optimizer.param_groups[0]['lr']
        self._restore_rng_states(checkpoint)
        self._safe_scheduler_step()
        print(f"Successfully load model from '{model_path}'.")
        return dir_name

    def load_snapshot(self, snapshot_path):
        """Restore policy/optimizer/RNG states from a DDP snapshot.

        ``snapshot_path`` may be the snapshot file or its directory (in which case
        ``snapshot.pt`` inside it is loaded).
        """
        if os.path.isdir(snapshot_path):
            snapshot_path = os.path.join(snapshot_path, "snapshot.pt")
        if torch.cuda.is_available():
            loc = f"cuda:{self.device}" if isinstance(self.device, int) else str(self.device)
        else:
            loc = "cpu"
        snapshot = torch.load(snapshot_path, map_location=loc, weights_only=True)
        if "MODEL_STATE" in snapshot:
            self.policy.load_state_dict(snapshot["MODEL_STATE"])
        elif "policy" in snapshot:
            self.policy.load_state_dict(snapshot["policy"])
        if "optimizer" in snapshot:
            self._load_optimizer_state_dict(self.optimizer, snapshot["optimizer"])
        self._restore_rng_states(snapshot)
        print("Resuming training from snapshot (including optimizer/rng state).")

    def save_snapshot(self):
        """Write policy/optimizer/RNG states to ``<snapshot_path>/snapshot.pt``."""
        snapshot = {
            "policy": self.policy.state_dict(),
            "optimizer": self._optimizer_state_dict(self.optimizer),
            "rng_state": torch.get_rng_state(),
            "cuda_rng_state": torch.cuda.get_rng_state_all() if torch.cuda.is_available() else None,
        }
        os.makedirs(self.snapshot_path, exist_ok=True)
        torch.save(snapshot, os.path.join(self.snapshot_path, "snapshot.pt"))

    # -------------------------------------------------------------- scheduler
    @staticmethod
    def _iter_schedulers(scheduler):
        """Yield every scheduler in ``scheduler`` (single object, dict or list; None yields nothing)."""
        if scheduler is None:
            return
        if isinstance(scheduler, dict):
            for item in scheduler.values():
                yield from Learner._iter_schedulers(item)
        elif isinstance(scheduler, (list, tuple)):
            for item in scheduler:
                yield from Learner._iter_schedulers(item)
        else:
            yield scheduler

    def _safe_scheduler_step(self):
        """Fast-forward the LR scheduler(s) after resuming from ``config.rt_epoch``.

        Best effort: failures are printed, never raised.  Does nothing when there
        is no scheduler or ``config`` has no ``rt_epoch``.
        """
        schedulers = list(self._iter_schedulers(getattr(self, 'scheduler', None)))
        if not schedulers or not hasattr(self.config, 'rt_epoch'):
            return
        try:
            train_steps = self.config.running_steps // self.config.parallels
            eval_interval = self.config.eval_interval // self.config.parallels
            num_epoch = int(train_steps / eval_interval)
            current_iters = int(self.total_iters * self.config.rt_epoch / num_epoch)
        except Exception as e:  # e.g. missing config fields or a zero interval
            print(f"scheduler.step failure:{e}")
            return
        for scheduler in schedulers:
            try:
                scheduler.step(current_iters)
                print(f"scheduler.step success,rt_epoch={self.config.rt_epoch}")
            except TypeError as e:
                # Schedulers whose step() takes no epoch argument.
                if "positional argument" not in str(e):
                    print(f"scheduler.step failure:{e}")
                    continue
                try:
                    scheduler.step()
                    print(f"scheduler.step success, rt_epoch={self.config.rt_epoch}")
                except Exception as e2:
                    print(f"scheduler.step failure:{e2}")
            except Exception as e:
                print(f"scheduler.step failure:{e}")

    @abstractmethod
    def update(self, *args):
        """Run one training update; implemented by subclasses."""
        raise NotImplementedError
class LearnerMAS(Learner):
    """Base class for multi-agent learners.

    Does not call ``Learner.__init__``; it sets up the multi-agent fields itself
    and does not configure distributed snapshots.  ``self.optimizer`` may be a
    single optimizer, a dict of optimizers, or a dict of dicts (agent-wise).
    """

    def __init__(self,
                 config: Namespace,
                 model_keys: List[str],
                 agent_keys: List[str],
                 policy: Module,
                 callback):
        """
        Parameters:
            config (Namespace): Training configuration.
            model_keys (List[str]): Keys of the (possibly shared) agent models.
            agent_keys (List[str]): Keys of the individual agents; per-agent data
                is stacked in this order.
            policy (Module): The multi-agent policy.
            callback: Callback object, stored as ``self.callback``.
        """
        self.value_normalizer = None
        self.config = config
        self.distributed_training = config.distributed_training
        self.n_agents = config.n_agents
        self.dim_id = self.n_agents
        self.use_parameter_sharing = config.use_parameter_sharing
        self.model_keys = model_keys
        self.agent_keys = agent_keys
        self.episode_length = config.episode_length
        self.learning_rate = getattr(config, 'learning_rate', None)
        self.use_linear_lr_decay = getattr(config, 'use_linear_lr_decay', False)
        self.end_factor_lr_decay = getattr(config, 'end_factor_lr_decay', 0.5)
        self.gamma = getattr(config, 'gamma', 0.99)
        self.use_cnn = getattr(config, "use_cnn", False)
        self.use_rnn = getattr(config, 'use_rnn', False)
        self.use_actions_mask = getattr(config, 'use_actions_mask', False)
        self.policy = policy
        self.optimizer: Union[dict, list, Optional[torch.optim.Optimizer]] = None
        self.scheduler: Union[dict, list, Optional[torch.optim.lr_scheduler.LinearLR]] = None
        self.callback = callback
        self.use_grad_clip = config.use_grad_clip
        self.grad_clip_norm = config.grad_clip_norm
        self.device = config.device
        self.model_dir = config.model_dir
        self.total_iters = self.estimate_total_iterations()
        self.iterations = 0

    def estimate_total_iterations(self):
        """Estimated total number of training iterations (never negative)."""
        start_training = getattr(self.config, "start_training", 0)
        training_frequency = getattr(self.config, "training_frequency", 1)
        n_epochs = getattr(self.config, "n_epochs", 1)
        episode_length = self.episode_length
        if self.use_rnn:
            # RNN learners train once per collected episode.
            total_iters = (self.config.running_steps - start_training) // (episode_length * self.config.parallels)
        else:
            total_iters = (self.config.running_steps - start_training) // (training_frequency * self.config.parallels)
        total_iters = max(total_iters, 0)
        total_iters *= n_epochs
        return total_iters

    def _stack_agents(self, data):
        """Stack ``data[k]`` for every agent key along a new axis 1 (agent_keys order).

        Building an explicit list keeps this correct for a single agent, where
        ``itemgetter`` would return the bare array and ``np.stack`` would
        transpose it instead of adding an agent axis.
        """
        return np.stack([data[k] for k in self.agent_keys], axis=1)

    def build_training_data(self, sample: Optional[dict],
                            use_parameter_sharing: Optional[bool] = False,
                            use_actions_mask: Optional[bool] = False,
                            use_global_state: Optional[bool] = False):
        """
        Prepare the training data.

        Parameters:
            sample (dict): The raw sampled data.
            use_parameter_sharing (bool): Whether to use parameter sharing for individual agent models.
            use_actions_mask (bool): Whether to use actions mask for unavailable actions.
            use_global_state (bool): Whether to use global state.

        Returns:
            sample_Tensor (dict): The formatted sampled data.  With parameter
            sharing, per-agent data is merged into a single entry keyed by
            ``model_keys[0]`` with batch size ``batch_size * n_agents``; otherwise
            each agent key keeps its own tensor.
        """
        batch_size = sample['batch_size']
        seq_length = sample['sequence_length'] if self.use_rnn else 1
        state, avail_actions, filled = None, None, None
        obs_next, state_next, avail_actions_next = None, None, None
        IDs = None

        if use_parameter_sharing:
            k = self.model_keys[0]
            bs = batch_size * self.n_agents
            if self.n_agents == 1:
                obs_tensor = Tensor(sample['obs'][k]).to(self.device).unsqueeze(1)
                actions_tensor = Tensor(sample['actions'][k]).to(self.device).unsqueeze(1)
                rewards_tensor = Tensor(sample['rewards'][k]).to(self.device).unsqueeze(1)
                ter_tensor = Tensor(sample['terminals'][k]).float().to(self.device).unsqueeze(1)
                msk_tensor = Tensor(sample['agent_mask'][k]).float().to(self.device).unsqueeze(1)
            else:
                obs_tensor = Tensor(self._stack_agents(sample['obs'])).to(self.device)
                actions_tensor = Tensor(self._stack_agents(sample['actions'])).to(self.device)
                rewards_tensor = Tensor(self._stack_agents(sample['rewards'])).to(self.device)
                ter_tensor = Tensor(self._stack_agents(sample['terminals'])).float().to(self.device)
                msk_tensor = Tensor(self._stack_agents(sample['agent_mask'])).float().to(self.device)
            if self.use_cnn and len(obs_tensor.shape) > 3:  # obs_array consists of images
                obs_shape_item = obs_tensor.shape[2:]
            else:
                obs_shape_item = (-1,)
            if self.use_rnn:
                obs = {k: obs_tensor.reshape(bs, seq_length + 1, -1)}
                if len(actions_tensor.shape) == 3:
                    actions = {k: actions_tensor.reshape(bs, seq_length)}
                elif len(actions_tensor.shape) == 4:
                    actions = {k: actions_tensor.reshape(bs, seq_length, -1)}
                else:
                    raise AttributeError("Wrong actions shape.")
                rewards = {k: rewards_tensor.reshape(batch_size, self.n_agents, seq_length)}
                terminals = {k: ter_tensor.reshape(batch_size, self.n_agents, seq_length)}
                agent_mask = {k: msk_tensor.reshape(bs, seq_length)}
                # One-hot agent IDs repeated over every time step: (bs, seq_length + 1, n_agents).
                IDs = torch.eye(self.n_agents).unsqueeze(1).unsqueeze(0).expand(
                    batch_size, -1, seq_length + 1, -1).reshape(bs, seq_length + 1, self.n_agents).to(self.device)
            else:
                obs = {k: obs_tensor.reshape(bs, *obs_shape_item)}
                if len(actions_tensor.shape) == 2:
                    actions = {k: actions_tensor.reshape(bs)}
                elif len(actions_tensor.shape) == 3:
                    actions = {k: actions_tensor.reshape(bs, -1)}
                else:
                    raise AttributeError("Wrong actions shape.")
                rewards = {k: rewards_tensor.reshape(batch_size, self.n_agents)}
                terminals = {k: ter_tensor.reshape(batch_size, self.n_agents)}
                agent_mask = {k: msk_tensor.reshape(bs)}
                obs_next = {k: Tensor(self._stack_agents(sample['obs_next'])).to(self.device).reshape(
                    bs, *obs_shape_item)}
                # One-hot agent IDs: (bs, n_agents).
                IDs = torch.eye(self.n_agents).unsqueeze(0).expand(
                    batch_size, -1, -1).reshape(bs, self.n_agents).to(self.device)

            if use_actions_mask:
                avail_a = self._stack_agents(sample['avail_actions'])
                if self.use_rnn:
                    avail_actions = {k: Tensor(avail_a.reshape([bs, seq_length + 1, -1])).float().to(self.device)}
                else:
                    avail_actions = {k: Tensor(avail_a.reshape([bs, -1])).float().to(self.device)}
                    avail_a_next = self._stack_agents(sample['avail_actions_next'])
                    avail_actions_next = {k: Tensor(avail_a_next.reshape([bs, -1])).float().to(self.device)}

        else:
            obs = {k: Tensor(sample['obs'][k]).to(self.device) for k in self.agent_keys}
            actions = {k: Tensor(sample['actions'][k]).to(self.device) for k in self.agent_keys}
            rewards = {k: Tensor(sample['rewards'][k]).to(self.device) for k in self.agent_keys}
            terminals = {k: Tensor(sample['terminals'][k]).float().to(self.device) for k in self.agent_keys}
            agent_mask = {k: Tensor(sample['agent_mask'][k]).float().to(self.device) for k in self.agent_keys}
            if not self.use_rnn:
                obs_next = {k: Tensor(sample['obs_next'][k]).to(self.device) for k in self.agent_keys}
            if use_actions_mask:
                avail_actions = {k: Tensor(sample['avail_actions'][k]).float().to(self.device)
                                 for k in self.agent_keys}
                if not self.use_rnn:
                    avail_actions_next = {k: Tensor(sample['avail_actions_next'][k]).float().to(self.device)
                                          for k in self.agent_keys}

        if use_global_state:
            state = Tensor(sample['state']).to(self.device)
            if not self.use_rnn:
                state_next = Tensor(sample['state_next']).to(self.device)

        if self.use_rnn:
            filled = Tensor(sample['filled']).float().to(self.device)

        sample_Tensor = {
            'batch_size': batch_size,
            'state': state,
            'state_next': state_next,
            'obs': obs,
            'actions': actions,
            'obs_next': obs_next,
            'rewards': rewards,
            'terminals': terminals,
            'agent_mask': agent_mask,
            'avail_actions': avail_actions,
            'avail_actions_next': avail_actions_next,
            'agent_ids': IDs,
            'filled': filled,
            'seq_length': seq_length,
        }
        return sample_Tensor

    def get_joint_input(self, input_tensor, output_shape=None):
        """Concatenate per-agent tensors (agent_keys order) along the last dimension.

        With a single agent its tensor is returned unchanged.  When
        ``output_shape`` is given the result is reshaped to it.
        """
        if self.n_agents == 1:
            joint_tensor = itemgetter(*self.agent_keys)(input_tensor)
        else:
            joint_tensor = torch.concat(itemgetter(*self.agent_keys)(input_tensor), dim=-1)
        if output_shape is not None:
            joint_tensor = joint_tensor.reshape(output_shape)
        return joint_tensor

    @abstractmethod
    def update(self, *args):
        """Run one training update; implemented by subclasses."""
        raise NotImplementedError

    def update_rnn(self, *args):
        """Run one recurrent training update; subclasses that support RNNs override this."""
        raise NotImplementedError

    # ------------------------------------------------------------------ helpers
    @classmethod
    def _optimizer_state_dict(cls, optimizer):
        """Return state dict(s) of an optimizer or of a (nested) dict/list of them; None stays None."""
        if optimizer is None:
            return None
        if isinstance(optimizer, dict):
            return {key: cls._optimizer_state_dict(value) for key, value in optimizer.items()}
        if isinstance(optimizer, (list, tuple)):
            return [cls._optimizer_state_dict(value) for value in optimizer]
        return optimizer.state_dict()

    @classmethod
    def _load_optimizer_state_dict(cls, optimizer, state):
        """Load ``state`` into ``optimizer`` following the same (nested) structure.

        A ``None`` optimizer or ``None`` state is skipped; missing dict keys raise ``KeyError``.
        """
        if optimizer is None or state is None:
            return
        if isinstance(optimizer, dict):
            for key, value in optimizer.items():
                cls._load_optimizer_state_dict(value, state[key])
        elif isinstance(optimizer, (list, tuple)):
            for value, item_state in zip(optimizer, state):
                cls._load_optimizer_state_dict(value, item_state)
        else:
            optimizer.load_state_dict(state)

    @staticmethod
    def _restore_rng_states(checkpoint):
        """Restore the CPU and (when available) CUDA RNG states stored in ``checkpoint``."""
        if 'rng_state' in checkpoint:
            # set_rng_state requires a CPU ByteTensor, whatever map_location was used.
            torch.set_rng_state(checkpoint['rng_state'].cpu().to(dtype=torch.uint8))
        cuda_states = checkpoint.get('cuda_rng_state')
        if isinstance(cuda_states, list) and torch.cuda.is_available():
            for i, state in enumerate(cuda_states[:torch.cuda.device_count()]):
                torch.cuda.set_rng_state(state.cpu().to(dtype=torch.uint8), device=i)

    @staticmethod
    def _resolve_model_path(path, model=None):
        """Locate the checkpoint file to load.

        If ``path/model`` (or ``path`` when ``model`` is None) is a file it is used
        directly.  Otherwise the last sorted ``seed_*`` sub-folder of ``path`` is
        searched for ``*.pth`` files, preferring ``final_train_model.pth`` and
        falling back to the last file in sorted order.

        Returns:
            (model_path, dir_name) as str.

        Raises:
            RuntimeError: invalid ``path`` or no ``seed_`` sub-folder.
            FileNotFoundError: no ``.pth`` file in the selected folder.
        """
        target_path = os.path.join(path, model) if model is not None else path
        if os.path.isfile(target_path):
            return target_path, os.path.dirname(target_path)
        if not os.path.isdir(path):
            raise RuntimeError(f"The path '{path}' is not a valid directory or file!")
        folder_names = sorted(f for f in os.listdir(path) if "seed_" in f)
        if not folder_names:
            raise RuntimeError(f"No model files with 'seed_' found in '{path}'!")
        seed_dir = Path(os.path.join(path, folder_names[-1]))
        model_names = sorted(seed_dir.glob("*.pth"))
        if not model_names:
            raise FileNotFoundError(f"No .pth file found in {seed_dir}")
        model_path = next((f for f in model_names if "final_train_model.pth" in str(f)), model_names[-1])
        return str(model_path), str(seed_dir)

    def _checkpoint_map_location(self):
        """Device every loaded tensor is mapped to (an int device means a CUDA index)."""
        device = self.device
        if isinstance(device, int):
            device = f"cuda:{device}"
        return torch.device(device)

    # ------------------------------------------------------------ checkpoints
    def save_model(self, model_path):
        """Save policy, optimizer(s) (possibly agent-wise nested) and RNG states to ``model_path``."""
        torch.save(
            {
                'policy': self.policy.state_dict(),
                'optimizer': self._optimizer_state_dict(self.optimizer),
                'rng_state': torch.get_rng_state(),
                'cuda_rng_state': torch.cuda.get_rng_state_all(),
            }, model_path)

    def load_model(self, path, model=None):
        """Load a checkpoint written by ``save_model``.

        Parameters:
            path (str): A checkpoint file, or a directory containing ``seed_*`` folders.
            model (str | None): Optional file name inside ``path``.

        Returns:
            str: The directory containing the loaded checkpoint.
        """
        model_path, dir_name = self._resolve_model_path(path, model)
        checkpoint = torch.load(model_path, map_location=self._checkpoint_map_location(), weights_only=True)
        self.policy.load_state_dict(checkpoint['policy'], strict=False)
        if 'optimizer' in checkpoint and self.optimizer is not None:
            self._load_optimizer_state_dict(self.optimizer, checkpoint['optimizer'])
            if hasattr(self.optimizer, 'param_groups'):
                # Single optimizer: keep the stored learning rate in sync.
                self.learning_rate = self.optimizer.param_groups[0]['lr']
        self._restore_rng_states(checkpoint)
        self._safe_scheduler_step()
        print(f"Successfully load model from '{model_path}'.")
        return dir_name