"""Source code for xuance.environment.utils.wrapper: wrappers that let single-agent, Atari, and multi-agent environments run in XuanCe."""

from xuance.common import Optional, Tuple
from gymnasium import spaces


class XuanCeEnvWrapper:
    """
    Wraps an environment for single-agent system that can run in XuanCe.

    ``action_space``, ``observation_space``, ``metadata`` and ``max_episode_steps``
    are forwarded to the wrapped environment unless they have been overridden
    through the corresponding setters. ``reset`` and ``step`` add episode
    bookkeeping (``episode_step`` and, for ``step``, ``episode_score``) to the
    returned ``info`` dict.

    Args:
        env: The environment to wrap.
        **kwargs: Not used by this class; accepted so subclasses share the signature.
    """

    def __init__(self, env, **kwargs):
        super().__init__()
        self.env = env
        self._action_space: Optional[spaces.Space] = None
        self._observation_space: Optional[spaces.Space] = None
        self._metadata: Optional[dict] = None
        self._max_episode_steps: Optional[int] = None
        self._episode_step = 0
        self._episode_score = 0.0
        # Inspect the *wrapped* env's action space. ``self._action_space`` is still None
        # at this point, so checking it always reported "not continuous" and silently
        # disabled the action rescaling in ``step``.
        action_space = getattr(self.env, "action_space", None)
        self._is_continuous = isinstance(action_space, spaces.Box)
        self._action_low = getattr(action_space, "low", None)
        self._action_high = getattr(action_space, "high", None)
        # Mapping [-1, 1] onto [low, high] is only defined for finite bounds; an
        # unbounded Box would yield NaN/inf, so its actions are passed through unchanged.
        self._rescale_action = self._is_continuous and action_space.is_bounded()

    @property
    def action_space(self) -> spaces.Space:
        """Returns the action space of the environment."""
        if self._action_space is None:
            return self.env.action_space
        return self._action_space

    @action_space.setter
    def action_space(self, space: spaces.Space):
        """Sets the action space."""
        self._action_space = space

    @property
    def observation_space(self) -> spaces.Space:
        """Returns the observation space of the environment."""
        if self._observation_space is None:
            return self.env.observation_space
        return self._observation_space

    @observation_space.setter
    def observation_space(self, space: spaces.Space):
        """Sets the observation space."""
        self._observation_space = space

    @property
    def metadata(self) -> dict:
        """Returns the environment metadata."""
        if self._metadata is None:
            return self.env.metadata
        return self._metadata

    @metadata.setter
    def metadata(self, value):
        """Sets metadata."""
        self._metadata = value

    @property
    def max_episode_steps(self) -> int:
        """Returns the maximum of episode steps."""
        if self._max_episode_steps is None:
            return self.env.max_episode_steps
        return self._max_episode_steps

    @max_episode_steps.setter
    def max_episode_steps(self, value):
        """Sets the maximum of episode steps."""
        self._max_episode_steps = value

    @property
    def render_mode(self) -> Optional[str]:
        """Returns the environment render_mode."""
        return self.env.render_mode

    def reset(self, **kwargs):
        """Resets the environment with kwargs.

        Accepts both a ``reset`` that returns ``(obs, info)`` and one that returns
        only ``obs``. In the latter case an empty info dict is created.

        Returns:
            ``(obs, info)`` with ``info["episode_step"]`` set to 0.
        """
        result = self.env.reset(**kwargs)
        # Reset exactly once and inspect the result. The previous bare ``except``
        # reset the env a second time on *any* error, masking real failures.
        # Plain tuple-unpacking also mis-split obs-only results of length 2.
        if isinstance(result, tuple) and len(result) == 2 and isinstance(result[1], dict):
            obs, info = result
        else:
            obs, info = result, {}
        self._episode_step = 0
        self._episode_score = 0.0
        info["episode_step"] = self._episode_step
        return obs, info

    def step(self, action):
        """Steps through the environment with action.

        For a bounded ``Box`` action space, ``action`` is expected in [-1, 1] and is
        linearly mapped onto ``[low, high]`` before being sent to the wrapped env.

        Returns:
            ``(observation, reward, terminated, truncated, info)``, where ``info`` also
            carries ``episode_step`` and ``episode_score`` (the accumulated reward).
        """
        if self._rescale_action:
            action = (action + 1.0) * 0.5 * (self._action_high - self._action_low) + self._action_low
        observation, reward, terminated, truncated, info = self.env.step(action)
        self._episode_step += 1
        self._episode_score += reward
        info["episode_step"] = self._episode_step  # current episode step
        info["episode_score"] = self._episode_score  # the accumulated rewards
        return observation, reward, terminated, truncated, info

    def render(self, *args, **kwargs):
        """Renders the environment."""
        return self.env.render(*args, **kwargs)

    def close(self):
        """Closes the environment."""
        return self.env.close()

    @property
    def unwrapped(self):
        """Returns the base environment of the wrapper."""
        return self.env
class XuanCeAtariEnvWrapper(XuanCeEnvWrapper):
    """
    Wraps an Atari environment that can run in XuanCe.

    Episode bookkeeping is taken from the wrapped env. The counters are cleared on
    ``reset`` only when ``env.was_real_done`` is true. After each ``step`` they are
    copied from ``env._episode_step`` and ``env._episode_score``.
    """

    def __init__(self, env, **kwargs):
        """Passes ``env`` and ``kwargs`` straight through to the base wrapper."""
        super().__init__(env, **kwargs)

    def reset(self, **kwargs):
        """Resets the environment with kwargs.

        The step/score counters are zeroed only if ``env.was_real_done`` is set
        before the reset. NOTE(review): presumably a lost life is not a "real done";
        confirm against the wrapped Atari env.
        """
        episode_really_ended = self.env.was_real_done
        if episode_really_ended:
            self._episode_step, self._episode_score = 0, 0.0
        first_obs, reset_info = self.env.reset(**kwargs)
        reset_info["episode_step"] = self._episode_step
        return first_obs, reset_info

    def step(self, action):
        """Steps through the environment with action.

        Returns:
            ``(observation, reward, terminated, truncated, info)``; ``info`` gains
            ``episode_step`` and ``episode_score`` mirrored from the wrapped env.
        """
        env_action = action
        if self._is_continuous:
            env_action = (action + 1.0) * 0.5 * (self._action_high - self._action_low) + self._action_low
        obs_next, rew, term, trunc, step_info = self.env.step(env_action)
        # The wrapped env owns the counters; mirror them here.
        self._episode_step, self._episode_score = self.env._episode_step, self.env._episode_score
        step_info["episode_step"] = self._episode_step  # current episode step
        step_info["episode_score"] = self._episode_score  # the accumulated rewards
        return obs_next, rew, term, trunc, step_info
class XuanCeMultiAgentEnvWrapper(XuanCeEnvWrapper):
    """
    Wraps an environment for multi-agent system that can run in XuanCe.

    ``reset`` and ``step`` add the following to the returned ``info`` dict:
    the episode step, per-agent accumulated rewards, the agent mask, the
    available actions, and the global state.
    """

    def __init__(self, env, **kwargs):
        super().__init__(env, **kwargs)
        self._env_info: Optional[dict] = None
        self._state_space: Optional[spaces.Space] = None
        self.agents = self.env.agents  # e.g., ['red_0', 'red_1', 'blue_0', 'blue_1'].
        self.num_agents = self.env.num_agents  # Number of all agents, e.g., 4.
        self.agent_groups = self.env.agent_groups
        # Accumulated reward of the current episode, keyed by agent name.
        self._episode_score = {agent: 0.0 for agent in self.agents}
        self.env_info = self.env.get_env_info()
        self.groups_info = self.env.get_groups_info()

    def _attach_info(self, info: dict) -> dict:
        """Writes episode bookkeeping and multi-agent extras into ``info`` in place and returns it."""
        info["episode_step"] = self._episode_step  # current episode step
        # Store a snapshot. Handing out the live dict let later steps silently
        # rewrite info dicts the caller had already received.
        info["episode_score"] = dict(self._episode_score)  # the accumulated rewards
        info["agent_mask"] = self.agent_mask
        info["avail_actions"] = self.avail_actions
        info["state"] = self.state
        return info

    def reset(self, **kwargs) -> Tuple[dict, dict]:
        """Resets the environment with kwargs and clears per-agent episode scores.

        Returns:
            ``(obs, info)`` where ``info`` carries ``episode_step``, ``episode_score``,
            ``agent_mask``, ``avail_actions`` and ``state``.
        """
        obs, info = self.env.reset(**kwargs)
        self._episode_step = 0
        self._episode_score = {agent: 0.0 for agent in self.agents}
        return obs, self._attach_info(info)

    def step(self, action):
        """Steps through the environment with action.

        ``reward`` from the wrapped env is indexed by agent name to accumulate
        each agent's episode score.

        Returns:
            ``(observation, reward, terminated, truncated, info)`` with the same
            extra ``info`` keys as :meth:`reset`.
        """
        observation, reward, terminated, truncated, info = self.env.step(action)
        self._episode_step += 1
        for agent in self.agents:
            self._episode_score[agent] += reward[agent]
        return observation, reward, terminated, truncated, self._attach_info(info)

    @property
    def env_info(self) -> Optional[dict]:
        """Returns the information of the environment."""
        if self._env_info is None:
            return self.env.env_info
        return self._env_info

    @env_info.setter
    def env_info(self, info: Optional[dict]):
        """Sets the environment information."""
        self._env_info = info

    @property
    def state_space(self) -> spaces.Space:
        """Returns the global state space of the environment."""
        if self._state_space is None:
            return self.env.state_space
        return self._state_space

    @state_space.setter
    def state_space(self, space: spaces.Space):
        """Sets the global state space."""
        self._state_space = space

    @property
    def state(self):
        """Returns global states in the multi-agent environment (via ``env.state()``)."""
        return self.env.state()

    @property
    def agent_mask(self):
        """Returns mask variables to mark alive agents in multi-agent environment."""
        return self.env.agent_mask()

    @property
    def avail_actions(self):
        """Returns mask variables to mark available actions for each agent."""
        return self.env.avail_actions()