# Source code for xuance.tensorflow.agents.multi_agent_rl.ic3net_agents
# (note: the implementation below imports and runs on xuance.torch / PyTorch, not TensorFlow).

from argparse import Namespace
from copy import deepcopy
from operator import itemgetter
from typing import List, Union, Optional, Dict
import gymnasium as gym
import numpy as np
import torch
from gymnasium import Space
from torch.nn import Module, ModuleDict

from xuance.common.memory_tools_marl import IC3Net_OnPolicyBuffer_RNN
from xuance.torch.communications import IC3NetComm

from xuance.torch import REGISTRY_Policy
from xuance.torch.utils import NormalizeFunctions, ActivationFunctions

from xuance.environment import DummyVecMultiAgentEnv, SubprocVecMultiAgentEnv, space2shape
from xuance.common import MultiAgentBaseCallback
from xuance.torch.agents.multi_agent_rl.commnet_agents import CommNet_Agents


class IC3Net_Agents(CommNet_Agents):
    """IC3Net multi-agent agents built on top of ``CommNet_Agents``.

    Compared with the parent class, this agent uses an ``IC3NetComm`` communicator per
    model key, an ``IC3Net_Policy`` and an ``IC3Net_OnPolicyBuffer_RNN`` buffer. Besides the
    action log-probabilities, it also records the log-probabilities of the policy's gate
    actions (``gate_log_pi``).
    """

    def __init__(
            self,
            config: Namespace,
            envs: Optional[DummyVecMultiAgentEnv | SubprocVecMultiAgentEnv] = None,
            num_agents: Optional[int] = None,
            agent_keys: Optional[List[str]] = None,
            state_space: Optional[Space] = None,
            observation_space: Optional[Space] = None,
            action_space: Optional[Space] = None,
            callback: Optional[MultiAgentBaseCallback] = None
    ) -> None:
        """Initialise the agent and (re)build its policy, memory and learner.

        Args:
            config: Hyper-parameter namespace (buffer sizes, network sizes, flags, ...).
            envs: Vectorized multi-agent environments.
            num_agents: Number of agents.
            agent_keys: Names of the agents.
            state_space: Global state space.
            observation_space: Per-agent observation spaces.
            action_space: Per-agent action spaces.
            callback: Optional multi-agent callback forwarded to the learner.
        """
        super(IC3Net_Agents, self).__init__(config, envs, num_agents, agent_keys, state_space,
                                            observation_space, action_space, callback)
        # NOTE(review): built again here after the parent constructor; confirm whether
        # CommNet_Agents.__init__ already builds these (would mean double construction).
        self.policy = self._build_policy()
        self.memory = self._build_memory()  # build memory
        self.learner = self._build_learner(self.config, self.model_keys, self.agent_keys, self.policy, callback)

    def _build_memory(self):
        """Create the ``IC3Net_OnPolicyBuffer_RNN`` replay buffer from the current config."""
        if self.use_actions_mask:
            avail_actions_shape = {key: (self.action_space[key].n,) for key in self.agent_keys}
        else:
            avail_actions_shape = None
        input_buffer = dict(agent_keys=self.agent_keys,
                            state_space=self.state_space if self.use_global_state else None,
                            obs_space=self.observation_space,
                            act_space=self.action_space,
                            n_envs=self.n_envs,
                            buffer_size=self.config.buffer_size,
                            use_gae=self.config.use_gae,
                            use_advnorm=self.config.use_advnorm,
                            gamma=self.config.gamma,
                            gae_lam=self.config.gae_lambda,
                            avail_actions_shape=avail_actions_shape,
                            use_actions_mask=self.use_actions_mask,
                            max_episode_steps=self.episode_length)
        Buffer = IC3Net_OnPolicyBuffer_RNN
        return Buffer(**input_buffer)

    def _build_communicator(self, input_space: Union[Dict[str, Space], Dict[str, tuple]], ) -> Module:
        """Build one ``IC3NetComm`` module per model key.

        Args:
            input_space: Per-model-key input space (or shape); converted with ``space2shape``.

        Returns:
            A ``ModuleDict`` mapping each model key to its communicator.
        """
        communicator = ModuleDict()
        hidden_sizes = {'fc_hidden_sizes': self.config.fc_hidden_sizes,
                        'recurrent_hidden_size': self.config.recurrent_hidden_size}
        for key in self.model_keys:
            input_communicator = dict(
                input_shape=space2shape(input_space[key]),
                hidden_sizes=hidden_sizes,
                comm_passes=self.config.comm_passes,
                model_keys=self.model_keys,
                agent_keys=self.agent_keys,
                n_agents=self.n_agents,
                device=self.device,
                config=self.config)
            communicator[key] = IC3NetComm(**input_communicator)
        return communicator

    def _build_policy(self) -> Module:
        """Build the IC3Net policy.

        Side effect: ``self.observation_space`` is replaced by 1-D ``Box`` spaces that all have
        the length of the longest agent observation (observations are padded to that length).

        Returns:
            The ``IC3Net_Policy`` instance.

        Raises:
            AttributeError: If ``config.policy`` is not ``"IC3Net_Policy"``.
        """
        normalize_fn = NormalizeFunctions[self.config.normalize] if hasattr(self.config, "normalize") else None
        initializer = torch.nn.init.orthogonal_
        activation = ActivationFunctions[self.config.activation]
        device = self.device
        agent = self.config.agent  # only used in the error message below

        # Pad every agent's observation space to the longest observation length.
        max_length = max(space.shape[0] for space in self.observation_space.values())
        self.observation_space = {key: gym.spaces.Box(-np.inf, np.inf, (max_length,), dtype=np.float32)
                                  for key in self.observation_space}

        # build representations
        communicator = self._build_communicator(self.observation_space)
        # The actor representation consumes the communicator's recurrent hidden vector.
        space_actor_in = {key: gym.spaces.Box(-np.inf, np.inf, (self.config.recurrent_hidden_size,),
                                              dtype=np.float32)
                          for key in self.observation_space}
        if self.use_global_state:
            dim_obs_all = sum(self.state_space.shape)
        else:
            dim_obs_all = sum([sum(self.observation_space[k].shape) for k in self.agent_keys])
        space_critic_in = {k: (dim_obs_all,) for k in self.agent_keys}
        A_representation = self._build_representation(self.config.representation, space_actor_in, self.config)
        C_representation = self._build_representation(self.config.representation, space_critic_in, self.config)

        # build policies
        if self.config.policy == "IC3Net_Policy":
            policy = REGISTRY_Policy[self.config.policy](
                action_space=self.action_space, n_agents=self.n_agents,
                representation_actor=A_representation, representation_critic=C_representation,
                actor_hidden_size=self.config.actor_hidden_size,
                critic_hidden_size=self.config.critic_hidden_size,
                normalize=normalize_fn, initialize=initializer, activation=activation, device=device,
                use_distributed_training=self.distributed_training,
                use_parameter_sharing=self.use_parameter_sharing, model_keys=self.model_keys,
                use_rnn=self.use_rnn, rnn=self.config.rnn if self.use_rnn else None,
                communicator=communicator, agent_keys=self.agent_keys,
                comm_passes=self.config.comm_passes, config=self.config)
            self.continuous_control = False
        else:
            raise AttributeError(f"{agent} currently does not support the policy named {self.config.policy}.")
        return policy

    def get_actions(self,
                    obs_dict: List[dict],
                    state: Optional[np.ndarray] = None,
                    avail_actions_dict: Optional[List[dict]] = None,
                    rnn_hidden_actor: Optional[dict] = None,
                    rnn_hidden_critic: Optional[dict] = None,
                    test_mode: Optional[bool] = False,
                    info: Optional[List[dict]] = None,
                    **kwargs) -> dict:
        """Sample actions (and, when training, log-probs and values) for every env.

        Args:
            obs_dict: One ``{agent_key: observation}`` dict per environment.
            state: Global states, used to build critic inputs when training.
            avail_actions_dict: One ``{agent_key: action mask}`` dict per environment, or None.
            rnn_hidden_actor: Actor recurrent hidden states.
            rnn_hidden_critic: Critic recurrent hidden states.
            test_mode: If True, only actions (and actor hidden states) are produced.
            info: One info dict per environment, each holding an ``'agent_mask'`` dict
                ``{agent_key: alive flag}``. If None, every agent is treated as alive.

        Returns:
            A dict with keys ``rnn_hidden_actor``, ``rnn_hidden_critic``, ``actions`` (list of
            per-env ``{agent_key: action}`` dicts), ``log_pi``, ``gate_log_pi`` and ``values``.
            In test mode, ``rnn_hidden_critic``, ``log_pi``, ``gate_log_pi`` and ``values`` are empty dicts.
        """
        n_env = len(obs_dict)
        rnn_hidden_critic_new, values_out, log_pi_a_dict, values_dict, log_pi_gate_dict = {}, {}, {}, {}, {}

        obs_input, agents_id, avail_actions_input = self._build_inputs(obs_dict, avail_actions_dict)
        if info is None:
            # Same all-alive default that run_episodes uses for the first decision step.
            info = [{'agent_mask': {k: True for k in self.agent_keys}} for _ in range(n_env)]
        alive_ally = {k: np.stack([int(data['agent_mask'][k]) for data in info]).reshape([n_env, 1, -1])
                      for k in self.agent_keys}
        rnn_hidden_actor_new, pi_dists, gate_log_prob = self.policy(observation=obs_input,
                                                                    agent_ids=agents_id,
                                                                    avail_actions=avail_actions_input,
                                                                    rnn_hidden=rnn_hidden_actor,
                                                                    alive_ally=alive_ally)
        if not test_mode:
            critic_input = self._build_critic_inputs(batch_size=n_env, obs_batch=obs_input, state=state)
            rnn_hidden_critic_new, values_out = self.policy.get_values(observation=critic_input,
                                                                       agent_ids=agents_id,
                                                                       rnn_hidden=rnn_hidden_critic)

        if self.use_parameter_sharing:
            key = self.agent_keys[0]
            actions_sample = pi_dists[key].stochastic_sample()
            if self.continuous_control:
                actions_out = actions_sample.reshape(n_env, self.n_agents, -1)
            else:
                actions_out = actions_sample.reshape(n_env, self.n_agents)
            actions_dict = [{k: actions_out[e, i].cpu().detach().numpy() for i, k in enumerate(self.agent_keys)}
                            for e in range(n_env)]
            if not test_mode:
                log_pi_a = pi_dists[key].log_prob(actions_sample).cpu().detach().numpy()
                log_pi_a = log_pi_a.reshape(n_env, self.n_agents)
                model_key = self.model_keys[0]
                gate_log_pi = gate_log_prob[model_key].reshape(n_env, self.n_agents).cpu().detach().numpy()
                log_pi_a_dict = {k: log_pi_a[:, i] for i, k in enumerate(self.agent_keys)}
                log_pi_gate_dict = {k: gate_log_pi[:, i] for i, k in enumerate(self.agent_keys)}
                values_out[model_key] = values_out[model_key].reshape(n_env, self.n_agents)
                values_dict = {k: values_out[model_key][:, i].cpu().detach().numpy()
                               for i, k in enumerate(self.agent_keys)}
        else:
            actions_sample = {k: pi_dists[k].stochastic_sample() for k in self.agent_keys}
            # Move each agent's sampled actions to numpy once instead of once per env.
            actions_np = {k: actions_sample[k].cpu().detach().numpy() for k in self.agent_keys}
            if self.continuous_control:
                actions_dict = [{k: actions_np[k][e].reshape([-1]) for k in self.agent_keys}
                                for e in range(n_env)]
            else:
                actions_dict = [{k: actions_np[k][e].reshape([]) for k in self.agent_keys}
                                for e in range(n_env)]
            if not test_mode:
                log_pi_a = {k: pi_dists[k].log_prob(actions_sample[k]).cpu().detach().numpy()
                            for k in self.agent_keys}
                log_pi_gate = {k: gate_log_prob[k].cpu().detach().numpy() for k in self.agent_keys}
                log_pi_a_dict = {k: log_pi_a[k].reshape([n_env]) for k in self.agent_keys}
                log_pi_gate_dict = {k: log_pi_gate[k].reshape([n_env]) for k in self.agent_keys}
                values_dict = {k: values_out[k].cpu().detach().numpy().reshape([n_env]) for k in self.agent_keys}

        return {"rnn_hidden_actor": rnn_hidden_actor_new, "rnn_hidden_critic": rnn_hidden_critic_new,
                "actions": actions_dict, "log_pi": log_pi_a_dict, "gate_log_pi": log_pi_gate_dict,
                "values": values_dict}

    def store_experience(self, obs_dict, avail_actions, actions_dict, log_pi_a, rewards_dict, values_dict,
                         terminals_dict, info, **kwargs):
        """Convert one vectorized step into per-agent arrays and append it to the buffer.

        Args:
            obs_dict, actions_dict, rewards_dict, terminals_dict: Per-env lists of
                ``{agent_key: value}`` dicts.
            avail_actions: Per-env action-mask dicts (read only when ``use_actions_mask``).
            log_pi_a: ``{agent_key: log-prob array}`` from ``get_actions``.
            values_dict: ``{agent_key: value array}`` from ``get_actions``.
            info: Per-env info dicts providing ``'agent_mask'`` (and ``'episode_step'`` with RNN).
            **kwargs: Must contain ``'log_pi_gate'``; ``'state'`` is read when ``use_global_state``.
        """
        log_pi_gate = kwargs['log_pi_gate']
        experience_data = {
            'obs': {k: np.array([data[k] for data in obs_dict]) for k in self.agent_keys},
            'actions': {k: np.array([data[k] for data in actions_dict]) for k in self.agent_keys},
            'gate_log_pi_old': log_pi_gate,
            'log_pi_old': log_pi_a,
            'rewards': {k: np.array([data[k] for data in rewards_dict]) for k in self.agent_keys},
            'values': values_dict,
            'terminals': {k: np.array([data[k] for data in terminals_dict]) for k in self.agent_keys},
            'agent_mask': {k: np.array([data['agent_mask'][k] for data in info]) for k in self.agent_keys},
        }
        if self.use_rnn:
            # info carries the step count *after* this step, hence the -1.
            experience_data['episode_steps'] = np.array([data['episode_step'] - 1 for data in info])
        if self.use_global_state:
            experience_data['state'] = np.array(kwargs['state'])
        if self.use_actions_mask:
            experience_data['avail_actions'] = {k: np.array([data[k] for data in avail_actions])
                                                for k in self.agent_keys}
        self.memory.store(**experience_data)

    def run_episodes(self,
                     n_episodes: int = 1,
                     run_envs: Optional[DummyVecMultiAgentEnv | SubprocVecMultiAgentEnv] = None,
                     test_mode: bool = False,
                     close_envs: bool = True) -> list:
        """Roll out the vectorized environments until ``n_episodes`` episodes have finished.

        In training mode, transitions are stored and buffer paths are finished at episode end.
        In test mode, frames are rendered (``rgb_array``), the best episode's video and the
        mean/std score are logged.

        Args:
            n_episodes: Number of episodes to complete (counted across all envs).
            run_envs: Environments to run; defaults to ``self.train_envs``.
            test_mode: Whether to run in evaluation mode.
            close_envs: Whether to close the environments afterwards.

        Returns:
            The list of finished-episode scores (mean over agents of ``info['episode_score']``).
        """
        envs = self.train_envs if run_envs is None else run_envs
        num_envs = envs.num_envs
        videos, episode_videos = [[] for _ in range(num_envs)], []
        episode_count, scores, best_score = 0, [], -np.inf
        obs_dict, info = envs.reset()
        avail_actions = envs.buf_avail_actions if self.use_actions_mask else None
        state = envs.buf_state if self.use_global_state else None
        if test_mode:
            if self.config.render_mode == "rgb_array" and self.render:
                images = envs.render(self.config.render_mode)
                for idx, img in enumerate(images):
                    videos[idx].append(img)
        else:
            if self.use_rnn:
                self.memory.clear_episodes()
        rnn_hidden_actor, rnn_hidden_critic = self.init_rnn_hidden(num_envs)
        # Use an all-alive agent mask for the first decision step.
        info = [{'agent_mask': {k: True for k in self.agent_keys}} for _ in range(num_envs)]

        while episode_count < n_episodes:
            step_info = {}
            obs_dict = [self.pad_observation(obs) for obs in obs_dict]
            policy_out = self.get_actions(obs_dict=obs_dict, state=state, avail_actions_dict=avail_actions,
                                          rnn_hidden_actor=rnn_hidden_actor, rnn_hidden_critic=rnn_hidden_critic,
                                          test_mode=test_mode, info=info)
            rnn_hidden_actor, rnn_hidden_critic = policy_out['rnn_hidden_actor'], policy_out['rnn_hidden_critic']
            actions_dict, log_pi_a_dict = policy_out['actions'], policy_out['log_pi']
            log_pi_gate_dict, values_dict = policy_out['gate_log_pi'], policy_out['values']
            next_obs_dict, rewards_dict, terminated_dict, truncated, info = envs.step(actions_dict)
            # Shallow copy: per-env entries are replaced after a reset without touching the env's buffer.
            info = list(info)
            next_state = envs.buf_state if self.use_global_state else None
            next_avail_actions = envs.buf_avail_actions if self.use_actions_mask else None
            if test_mode:
                if self.config.render_mode == "rgb_array" and self.render:
                    images = envs.render(self.config.render_mode)
                    for idx, img in enumerate(images):
                        videos[idx].append(img)
            else:
                self.store_experience(obs_dict, avail_actions, actions_dict, log_pi_a_dict, rewards_dict,
                                      values_dict, terminated_dict, info,
                                      **{'state': state, 'log_pi_gate': log_pi_gate_dict})
            obs_dict, avail_actions = deepcopy(next_obs_dict), deepcopy(next_avail_actions)
            state = deepcopy(next_state) if self.use_global_state else None
            # Padded here as well because values_next below consumes obs_dict[i].
            obs_dict = [self.pad_observation(obs) for obs in obs_dict]

            for i in range(num_envs):
                if all(terminated_dict[i].values()) or truncated[i]:
                    episode_count += 1
                    episode_score = float(np.mean(itemgetter(*self.agent_keys)(info[i]["episode_score"])))
                    scores.append(episode_score)
                    if test_mode:
                        if self.use_rnn:
                            rnn_hidden_actor, _ = self.init_hidden_item(i, rnn_hidden_actor)
                        if best_score < episode_score:
                            best_score = episode_score
                            episode_videos = videos[i].copy()
                    else:
                        if all(terminated_dict[i].values()):
                            value_next = {key: 0.0 for key in self.agent_keys}
                        else:
                            # Bootstrap from the (terminal-step) observation/state before they are reset below.
                            _, value_next = self.values_next(i_env=i, obs_dict=obs_dict[i],
                                                             state=None if state is None else state[i],
                                                             rnn_hidden_critic=rnn_hidden_critic)
                        self.memory.finish_path(i_env=i, i_step=info[i]['episode_step'], value_next=value_next,
                                                value_normalizer=self.learner.value_normalizer)
                        if self.use_rnn:
                            rnn_hidden_actor, rnn_hidden_critic = self.init_hidden_item(i, rnn_hidden_actor,
                                                                                        rnn_hidden_critic)
                        if self.use_wandb:
                            step_info["Train-Results/Episode-Steps/env-%d" % i] = info[i]["episode_step"]
                            step_info["Train-Results/Episode-Rewards/env-%d" % i] = info[i]["episode_score"]
                        else:
                            step_info["Train-Results/Episode-Steps"] = {"env-%d" % i: info[i]["episode_step"]}
                            step_info["Train-Results/Episode-Rewards"] = {
                                "env-%d" % i: np.mean(itemgetter(*self.agent_keys)(info[i]["episode_score"]))}
                        self.current_step += info[i]["episode_step"]
                        self.log_infos(step_info, self.current_step)
                    # Start the next episode of env i from its reset observation / state / masks.
                    obs_dict[i] = info[i]["reset_obs"]
                    envs.buf_obs[i] = info[i]["reset_obs"]
                    if self.use_global_state:
                        state[i] = info[i]["reset_state"]
                        envs.buf_state[i] = info[i]["reset_state"]
                    if self.use_actions_mask:
                        avail_actions[i] = info[i]["reset_avail_actions"]
                        envs.buf_avail_actions[i] = info[i]["reset_avail_actions"]
                    # The terminal-step agent mask does not apply to the freshly reset episode.
                    info[i] = {**info[i], 'agent_mask': {k: True for k in self.agent_keys}}

        if test_mode:
            if self.config.render_mode == "rgb_array" and self.render:
                # time, height, width, channel -> time, channel, height, width
                videos_info = {"Videos_Test": np.array([episode_videos], dtype=np.uint8).transpose((0, 1, 4, 2, 3))}
                self.log_videos(info=videos_info, fps=self.fps, x_index=self.current_step)

            test_info = {
                "Test-Results/Episode-Rewards/Mean-Score": np.mean(scores),
                "Test-Results/Episode-Rewards/Std-Score": np.std(scores),
            }
            self.log_infos(test_info, self.current_step)
        if close_envs:
            envs.close()
        return scores