from copy import deepcopy
from operator import itemgetter
from typing import Dict, List
import numpy as np
import torch
from gymnasium import Space
from argparse import Namespace
from xuance.torch.agents.multi_agent_rl.mappo_agents import MAPPO_Agents
from xuance.common import Optional, Union, MultiAgentBaseCallback
import gymnasium as gym
from xuance.environment import DummyVecMultiAgentEnv, SubprocVecMultiAgentEnv, space2shape
from xuance.torch import Module, REGISTRY_Policy, ModuleDict
from xuance.torch.communications.comm_net import CommNet
from xuance.torch.utils import ActivationFunctions, NormalizeFunctions
class CommNet_Agents(MAPPO_Agents):
    """MAPPO agents whose actors exchange messages through a CommNet communicator.

    Every agent's observation is zero-padded to the longest observation length so that a single
    communicator input size fits all agents, and the per-agent ``agent_mask`` entries found in the
    environment ``info`` are passed to the policy as ``alive_ally``.

    All constructor arguments are forwarded unchanged to ``MAPPO_Agents.__init__``.
    """

    def __init__(
            self,
            config: Namespace,
            envs: Optional[DummyVecMultiAgentEnv | SubprocVecMultiAgentEnv] = None,
            num_agents: Optional[int] = None,
            agent_keys: Optional[List[str]] = None,
            state_space: Optional[Space] = None,
            observation_space: Optional[Space] = None,
            action_space: Optional[Space] = None,
            callback: Optional[MultiAgentBaseCallback] = None
    ) -> None:
        super(CommNet_Agents, self).__init__(
            config, envs, num_agents, agent_keys, state_space, observation_space, action_space, callback
        )
        # (Re)build the policy with this class's ``_build_policy``, then the memory and the learner
        # (which receives the new policy).
        # NOTE(review): the parent constructor may already build these objects; if so this is a
        # second build — confirm against MAPPO_Agents.
        self.policy = self._build_policy()
        self.memory = self._build_memory()  # build memory
        self.learner = self._build_learner(self.config, self.model_keys, self.agent_keys, self.policy, callback)

    def _build_communicator(self, input_space: Union[Dict[str, Space], Dict[str, tuple]], ) -> Module:
        """Build one ``CommNet`` communicator per model key.

        Args:
            input_space: Mapping from model key to the input space (or shape) of that model.

        Returns:
            A ``ModuleDict`` mapping every key of ``self.model_keys`` to its ``CommNet``.
        """
        communicator = ModuleDict()
        hidden_sizes = {'fc_hidden_sizes': self.config.fc_hidden_sizes,
                        'recurrent_hidden_size': self.config.recurrent_hidden_size}
        for key in self.model_keys:
            communicator[key] = CommNet(input_shape=space2shape(input_space[key]),
                                        hidden_sizes=hidden_sizes,
                                        model_keys=self.model_keys,
                                        agent_keys=self.agent_keys,
                                        n_agents=self.n_agents,
                                        device=self.device,
                                        config=self.config)
        return communicator

    def _build_policy(self) -> Module:
        """Build the CommNet actor-critic policy.

        Side effects:
            * ``self.observation_space`` is replaced by one ``Box`` per agent whose length is the
              largest first dimension among all agents' observation spaces (matching the zero
              padding done by ``pad_observation``).
            * ``self.continuous_control`` is set to ``False``.

        Returns:
            The policy built from ``REGISTRY_Policy[self.config.policy]``.

        Raises:
            AttributeError: If ``self.config.policy`` is not ``"CommNet_Policy"``.
        """
        normalize_fn = NormalizeFunctions[self.config.normalize] if hasattr(self.config, "normalize") else None
        initializer = torch.nn.init.orthogonal_
        activation = ActivationFunctions[self.config.activation]
        device = self.device
        agent = self.config.agent  # agent name, only used in the error message below

        # Pad every agent's observation space to the longest one so one input size fits all agents.
        max_length = max(space.shape[0] for space in self.observation_space.values())
        self.observation_space = {key: gym.spaces.Box(-np.inf, np.inf, (max_length,), dtype=np.float32)
                                  for key in self.observation_space}

        # build representations
        communicator = self._build_communicator(self.observation_space)
        # Actor representations take inputs of size ``recurrent_hidden_size``
        # (presumably the CommNet output width — confirm against CommNet).
        space_actor_in = {key: gym.spaces.Box(-np.inf, np.inf, (self.config.recurrent_hidden_size,),
                                              dtype=np.float32)
                          for key in self.observation_space}
        if self.use_global_state:
            dim_obs_all = sum(self.state_space.shape)
        else:
            dim_obs_all = sum([sum(self.observation_space[k].shape) for k in self.agent_keys])
        space_critic_in = {k: (dim_obs_all,) for k in self.agent_keys}
        A_representation = self._build_representation(self.config.representation, space_actor_in, self.config)
        C_representation = self._build_representation(self.config.representation, space_critic_in, self.config)

        # build policies
        if self.config.policy != "CommNet_Policy":
            raise AttributeError(f"{agent} currently does not support the policy named {self.config.policy}.")
        policy = REGISTRY_Policy[self.config.policy](
            action_space=self.action_space, n_agents=self.n_agents,
            representation_actor=A_representation, representation_critic=C_representation,
            actor_hidden_size=self.config.actor_hidden_size, critic_hidden_size=self.config.critic_hidden_size,
            normalize=normalize_fn, initialize=initializer, activation=activation,
            device=device, use_distributed_training=self.distributed_training,
            use_parameter_sharing=self.use_parameter_sharing, model_keys=self.model_keys,
            use_rnn=self.use_rnn, rnn=self.config.rnn if self.use_rnn else None,
            communicator=communicator, agent_keys=self.agent_keys, comm_passes=self.config.comm_passes)
        self.continuous_control = False
        return policy

    def get_actions(self,
                    obs_dict: List[dict],
                    state: Optional[np.ndarray] = None,
                    avail_actions_dict: Optional[List[dict]] = None,
                    rnn_hidden_actor: Optional[dict] = None,
                    rnn_hidden_critic: Optional[dict] = None,
                    test_mode: Optional[bool] = False,
                    info: Optional[List[dict]] = None,
                    **kwargs):
        """Sample actions for every agent in every environment.

        Args:
            obs_dict: One observation dict per environment, keyed by agent.
            state: Global states, forwarded to ``_build_critic_inputs``.
            avail_actions_dict: Optional per-environment action masks.
            rnn_hidden_actor: Actor RNN hidden states.
            rnn_hidden_critic: Critic RNN hidden states.
            test_mode: If True, the critic is not evaluated and no log-probabilities are computed.
            info: One info dict per environment whose ``'agent_mask'`` maps agent key -> alive flag.
                If None, every agent in every environment is treated as alive.
            **kwargs: Ignored.

        Returns:
            dict with keys ``"rnn_hidden_actor"``, ``"rnn_hidden_critic"``, ``"actions"`` (a list with one
            ``{agent_key: action}`` dict per environment), ``"log_pi"`` and ``"values"`` (``{agent_key: array}``
            with one entry per environment). In test mode ``"rnn_hidden_critic"``, ``"log_pi"`` and
            ``"values"`` are empty dicts.
        """
        n_env = len(obs_dict)
        rnn_hidden_critic_new, values_out, log_pi_a_dict, values_dict = {}, {}, {}, {}
        obs_input, agents_id, avail_actions_input = self._build_inputs(obs_dict, avail_actions_dict)
        if info is None:
            # No alive masks supplied: every agent in every environment is considered alive.
            alive_ally = {k: np.ones([n_env, 1, 1], dtype=int) for k in self.agent_keys}
        else:
            # One 0/1 alive flag per environment, shaped (n_env, 1, 1) for each agent.
            alive_ally = {k: np.stack([int(data['agent_mask'][k]) for data in info]).reshape([n_env, 1, -1])
                          for k in self.agent_keys}
        rnn_hidden_actor_new, pi_dists = self.policy(observation=obs_input,
                                                     agent_ids=agents_id,
                                                     avail_actions=avail_actions_input,
                                                     rnn_hidden=rnn_hidden_actor,
                                                     alive_ally=alive_ally)
        if not test_mode:
            critic_input = self._build_critic_inputs(batch_size=n_env, obs_batch=obs_input, state=state)
            rnn_hidden_critic_new, values_out = self.policy.get_values(observation=critic_input,
                                                                       agent_ids=agents_id,
                                                                       rnn_hidden=rnn_hidden_critic)

        if self.use_parameter_sharing:
            # All agents' outputs live under the first agent key, laid out as (n_env, n_agents, ...).
            key = self.agent_keys[0]
            actions_sample = pi_dists[key].stochastic_sample()
            if self.continuous_control:
                actions_out = actions_sample.reshape(n_env, self.n_agents, -1)
            else:
                actions_out = actions_sample.reshape(n_env, self.n_agents)
            actions_dict = [{k: actions_out[e, i].cpu().detach().numpy() for i, k in enumerate(self.agent_keys)}
                            for e in range(n_env)]
            if not test_mode:
                log_pi_a = pi_dists[key].log_prob(actions_sample).cpu().detach().numpy()
                log_pi_a = log_pi_a.reshape(n_env, self.n_agents)
                log_pi_a_dict = {k: log_pi_a[:, i] for i, k in enumerate(self.agent_keys)}
                values_out[key] = values_out[key].reshape(n_env, self.n_agents)
                values_dict = {k: values_out[key][:, i].cpu().detach().numpy()
                               for i, k in enumerate(self.agent_keys)}
        else:
            actions_sample = {k: pi_dists[k].stochastic_sample() for k in self.agent_keys}
            # Convert each agent's sampled actions to numpy once, not once per environment.
            actions_np = {k: actions_sample[k].cpu().detach().numpy() for k in self.agent_keys}
            if self.continuous_control:
                actions_dict = [{k: actions_np[k][e].reshape([-1]) for k in self.agent_keys}
                                for e in range(n_env)]
            else:
                actions_dict = [{k: actions_np[k][e].reshape([]) for k in self.agent_keys}
                                for e in range(n_env)]
            if not test_mode:
                log_pi_a = {k: pi_dists[k].log_prob(actions_sample[k]).cpu().detach().numpy()
                            for k in self.agent_keys}
                log_pi_a_dict = {k: log_pi_a[k].reshape([n_env]) for k in self.agent_keys}
                values_dict = {k: values_out[k].cpu().detach().numpy().reshape([n_env]) for k in self.agent_keys}
        return {"rnn_hidden_actor": rnn_hidden_actor_new, "rnn_hidden_critic": rnn_hidden_critic_new,
                "actions": actions_dict, "log_pi": log_pi_a_dict, "values": values_dict}

    @staticmethod
    def pad_observation(obs_dict):
        """Zero-pad every observation along its first axis to the longest first-axis length.

        The dict is modified in place and also returned. Observations that already have the
        maximal length are left untouched; any trailing axes are never padded. An empty dict
        is returned unchanged.

        Args:
            obs_dict: Mapping from agent key to an observation array.

        Returns:
            The same ``obs_dict`` object.
        """
        if not obs_dict:
            return obs_dict
        max_length = max(obs.shape[0] for obs in obs_dict.values())
        for k, obs in list(obs_dict.items()):
            n_pad = max_length - obs.shape[0]
            if n_pad > 0:
                # Pad the end of axis 0 only; (0, 0) keeps the remaining axes as they are.
                pad_width = [(0, n_pad)] + [(0, 0)] * (np.ndim(obs) - 1)
                obs_dict[k] = np.pad(obs, pad_width, mode='constant')
        return obs_dict

    def run_episodes(self,
                     n_episodes: int = 1,
                     run_envs: Optional[DummyVecMultiAgentEnv | SubprocVecMultiAgentEnv] = None,
                     test_mode: bool = False,
                     close_envs: bool = True) -> list:
        """Run episodes on the vectorized environments until ``n_episodes`` have finished.

        Observations are zero-padded with ``pad_observation`` before being fed to the policy.
        In training mode, transitions are stored in ``self.memory``, finished trajectories are
        closed with ``finish_path`` and episode statistics are logged. In test mode, frames are
        optionally rendered, the frames of the best-scoring episode are logged as a video, the
        mean/std of the scores are logged, and the environments are closed if ``close_envs``.

        Args:
            n_episodes: Number of finished episodes to collect, counted over all parallel envs.
                Several envs can finish in the same step, so slightly more may be collected.
            run_envs: Environments to run; defaults to ``self.train_envs``.
            test_mode: Whether to run in evaluation mode.
            close_envs: Whether to close the environments at the end of a test run.

        Returns:
            One score per finished episode: the mean over agents of ``info[i]["episode_score"]``.
        """
        envs = self.train_envs if run_envs is None else run_envs
        num_envs = envs.num_envs
        videos, episode_videos = [[] for _ in range(num_envs)], []
        episode_count, scores, best_score = 0, [], -np.inf
        obs_dict, info = envs.reset()
        avail_actions = envs.buf_avail_actions if self.use_actions_mask else None
        state = envs.buf_state if self.use_global_state else None
        if test_mode:
            if self.config.render_mode == "rgb_array" and self.render:
                images = envs.render(self.config.render_mode)
                for idx, img in enumerate(images):
                    videos[idx].append(img)
        else:
            if self.use_rnn:
                self.memory.clear_episodes()
        rnn_hidden_actor, rnn_hidden_critic = self.init_rnn_hidden(num_envs)
        # No step has reported alive masks yet, so every agent starts out alive.
        info = [{'agent_mask': {k: True for k in self.agent_keys}} for _ in range(num_envs)]

        while episode_count < n_episodes:
            step_info = {}
            obs_dict = [self.pad_observation(obs) for obs in obs_dict]
            policy_out = self.get_actions(obs_dict=obs_dict, state=state, avail_actions_dict=avail_actions,
                                          rnn_hidden_actor=rnn_hidden_actor, rnn_hidden_critic=rnn_hidden_critic,
                                          test_mode=test_mode, info=info)
            rnn_hidden_actor, rnn_hidden_critic = policy_out['rnn_hidden_actor'], policy_out['rnn_hidden_critic']
            actions_dict, log_pi_a_dict = policy_out['actions'], policy_out['log_pi']
            values_dict = policy_out['values']
            next_obs_dict, rewards_dict, terminated_dict, truncated, info = envs.step(actions_dict)
            next_state = envs.buf_state if self.use_global_state else None
            next_avail_actions = envs.buf_avail_actions if self.use_actions_mask else None
            if test_mode:
                if self.config.render_mode == "rgb_array" and self.render:
                    images = envs.render(self.config.render_mode)
                    for idx, img in enumerate(images):
                        videos[idx].append(img)
            else:
                self.store_experience(obs_dict, avail_actions, actions_dict, log_pi_a_dict, rewards_dict,
                                      values_dict, terminated_dict, info, **{'state': state})
            obs_dict, avail_actions = deepcopy(next_obs_dict), deepcopy(next_avail_actions)
            state = deepcopy(next_state) if self.use_global_state else None

            for i in range(num_envs):
                if not (all(terminated_dict[i].values()) or truncated[i]):
                    continue
                episode_count += 1
                episode_score = float(np.mean(itemgetter(*self.agent_keys)(info[i]["episode_score"])))
                scores.append(episode_score)
                if test_mode:
                    if self.use_rnn:
                        rnn_hidden_actor, _ = self.init_hidden_item(i, rnn_hidden_actor)
                    if best_score < episode_score:
                        best_score = episode_score
                        episode_videos = videos[i].copy()
                    # Start a fresh frame buffer so the next episode of env i is not appended to this one.
                    videos[i] = []
                else:
                    if all(terminated_dict[i].values()):
                        value_next = {key: 0.0 for key in self.agent_keys}
                    else:
                        # Truncated (not terminated): use the critic's estimate for the last
                        # observation instead of 0. Only env i's observations need padding here.
                        _, value_next = self.values_next(i_env=i, obs_dict=self.pad_observation(obs_dict[i]),
                                                         state=None if state is None else state[i],
                                                         rnn_hidden_critic=rnn_hidden_critic)
                    self.memory.finish_path(i_env=i, i_step=info[i]['episode_step'], value_next=value_next,
                                            value_normalizer=self.learner.value_normalizer)
                    if self.use_rnn:
                        rnn_hidden_actor, rnn_hidden_critic = self.init_hidden_item(i, rnn_hidden_actor,
                                                                                    rnn_hidden_critic)
                    if self.use_wandb:
                        step_info["Train-Results/Episode-Steps/env-%d" % i] = info[i]["episode_step"]
                        step_info["Train-Results/Episode-Rewards/env-%d" % i] = info[i]["episode_score"]
                    else:
                        step_info["Train-Results/Episode-Steps"] = {"env-%d" % i: info[i]["episode_step"]}
                        step_info["Train-Results/Episode-Rewards"] = {
                            "env-%d" % i: np.mean(itemgetter(*self.agent_keys)(info[i]["episode_score"]))}
                    self.current_step += info[i]["episode_step"]
                    self.log_infos(step_info, self.current_step)
                # Continue env i from the first observation (and action mask) of its next episode.
                obs_dict[i] = info[i]["reset_obs"]
                envs.buf_obs[i] = info[i]["reset_obs"]
                if self.use_actions_mask:
                    avail_actions[i] = info[i]["reset_avail_actions"]
                    envs.buf_avail_actions[i] = info[i]["reset_avail_actions"]

        if test_mode:
            if self.config.render_mode == "rgb_array" and self.render:
                # time, height, width, channel -> time, channel, height, width
                videos_info = {"Videos_Test": np.array([episode_videos], dtype=np.uint8).transpose((0, 1, 4, 2, 3))}
                self.log_videos(info=videos_info, fps=self.fps, x_index=self.current_step)
            test_info = {
                "Test-Results/Episode-Rewards/Mean-Score": np.mean(scores),
                "Test-Results/Episode-Rewards/Std-Score": np.std(scores),
            }
            self.log_infos(test_info, self.current_step)
            # Only test runs close their environments; closing the training envs here would break training.
            if close_envs:
                envs.close()
        return scores