# Source code for xuance.environment.utils.base

from xuance.common import Optional, List, Tuple, Dict, Any, Union
from abc import ABC, abstractmethod
from gymnasium import spaces
import numpy as np

# Type aliases shared by the multi-agent environment API.
# Agent ids and per-agent values are both left unconstrained.
AgentID = AgentValue = Any
# A mapping from each agent id to that agent's value (observation, reward, ...).
MultiAgentDict = Dict[AgentID, AgentValue]
# A list of agent names, e.g. ['red_0', 'red_1', 'blue_0', 'blue_1'].
AgentKeys = List[str]


class RawEnvironment(ABC):
    """
    A base class for new environment.
    The following attributes are necessary when creating a new environment:
        - self.env: the environment object;
        - self.observation_space: the observation space of the agent;
        - self.action_space: the action space of the agent;
        - self.max_episode_steps: the maximum steps for one episode of the environment in XuanCe.
    """

    def __new__(cls, *args, **kwargs):
        # The base class itself must never be instantiated directly.
        if cls is RawEnvironment:
            raise TypeError("Type RawEnvironment cannot be instantiated; It can be used only as a base class")
        # object.__new__ does not accept the constructor arguments here (this class
        # overrides __new__), so they are only forwarded when a parent defines __new__.
        if super().__new__ is object.__new__ and cls.__init__ is not object.__init__:
            obj = super().__new__(cls)
        else:
            obj = super().__new__(cls, *args, **kwargs)
        return obj

    def __init__(self, *args, **kwargs):
        super(RawEnvironment, self).__init__(*args, **kwargs)
        # All required attributes start as None; subclasses are expected to fill them in.
        self.env = None
        self.observation_space: Optional[spaces.Space] = None
        self.action_space: Optional[Union[spaces.Discrete, spaces.Box]] = None
        self.max_episode_steps: Optional[int] = None

    @abstractmethod
    def reset(self, **kwargs):
        """
        Resets the environment with kwargs.

        Returns:
            observation (np.ndarray or list): The initial observations of the agent.
            info (dict): The information about the environment.
        """
        raise NotImplementedError

    @abstractmethod
    def step(self, action):
        """
        Steps through the environment with action.

        Parameters:
            action (np.ndarray or list): The action to be executed.

        Returns:
            observation (np.ndarray or list): The next step observation after executing action.
            reward (np.ndarray or list): The reward returned by the environment.
            terminated (np.ndarray or list): A bool value that indicates if the environment should be terminated.
            truncated (np.ndarray or list): A bool value that indicates if the environment should be truncated.
            info (dict): The information about the environment.
        """
        raise NotImplementedError

    @abstractmethod
    def render(self, *args, **kwargs):
        """
        Renders the environment.

        Returns:
            rgb_images (np.ndarray or list): The images used to visualize the environment.

        Raises:
            NotImplementedError: Always, in this base implementation; subclasses must override it.
        """
        raise NotImplementedError

    @abstractmethod
    def close(self):
        """
        Closes the environment.

        The base implementation does nothing and returns None, so subclasses may
        safely chain to ``super().close()``.
        """
        return None

    def avail_actions(self):
        """
        Returns a boolean mask indicating which actions are available.

        Every action of the discrete action space is marked as available.

        Returns:
            np.ndarray: A boolean array of shape ``(self.action_space.n,)`` filled with True.

        Raises:
            AssertionError: If ``self.action_space`` is not a ``spaces.Discrete`` space.
        """
        # Only a Discrete space exposes a finite action count ``n`` to build the mask from.
        assert isinstance(self.action_space, spaces.Discrete), "The action space should be discrete."
        return np.ones(self.action_space.n, np.bool_)


class RawMultiAgentEnv(ABC):
    """A base class for multi-agent environment.

    The following attributes are necessary when creating a new multi-agent environment in XuanCe:
        - self.env: the environment object;
        - self.state_space: the global state space of the environment;
        - self.observation_space: a dict mapping each agent's id to its observation space;
        - self.action_space: a dict mapping each agent's id to its action space;
        - self.agents: a list of all agents' ids;
        - self.num_agents: the number of total agents in the environment;
        - self.agent_groups: a list of groups. Each group is a list of agents' ids with a same role.
          It defaults to an empty list; the number of groups is ``len(self.agent_groups)``;
        - self.max_episode_steps: the maximum steps for one episode of the environment.
    """

    def __new__(cls, *args, **kwargs):
        # The base class itself must never be instantiated directly.
        if cls is RawMultiAgentEnv:
            raise TypeError("Type RawMultiAgentEnv cannot be instantiated; It can be used only as a base class")
        # object.__new__ does not accept the constructor arguments here (this class
        # overrides __new__), so they are only forwarded when a parent defines __new__.
        if super().__new__ is object.__new__ and cls.__init__ is not object.__init__:
            obj = super().__new__(cls)
        else:
            obj = super().__new__(cls, *args, **kwargs)
        return obj

    def __init__(self, *args, **kwargs):
        super(RawMultiAgentEnv, self).__init__(*args, **kwargs)
        self.env = None
        self.agents: Optional[AgentKeys] = None  # e.g., ['red_0', 'red_1', 'blue_0', 'blue_1'].
        self.state_space: Optional[spaces.Space] = None
        # Per-agent spaces, keyed by agent id.
        self.observation_space: Optional[Dict[AgentID, spaces.Space]] = None
        self.action_space: Optional[Dict[AgentID, spaces.Space]] = None
        self.num_agents: Optional[int] = None  # Number of all agents, e.g., 4.
        self.agent_groups: Optional[List[AgentKeys]] = []
        self.max_episode_steps: Optional[int] = None

    def get_env_info(self) -> Dict[str, Any]:
        """Returns a dict with the state/observation/action spaces, the agent list,
        the number of agents and the maximum episode length."""
        return {'state_space': self.state_space,
                'observation_space': self.observation_space,
                'action_space': self.action_space,
                'agents': self.agents,
                'num_agents': self.num_agents,
                'max_episode_steps': self.max_episode_steps}

    def get_groups_info(self) -> Dict[str, Any]:
        """Returns per-group information derived from ``self.agent_groups``.

        Returns:
            A dict with the number of groups, the groups themselves, and, for each
            group, the observation/action spaces of its agents (keyed by agent id)
            and the number of agents in it.
        """
        groups = self.agent_groups
        return {'num_groups': len(groups),
                'agent_groups': groups,
                'observation_space_groups': [{k: self.observation_space[k] for k in group} for group in groups],
                'action_space_groups': [{k: self.action_space[k] for k in group} for group in groups],
                'num_agents_groups': [len(group) for group in groups]}

    def agent_mask(self):
        """Returns boolean mask variables indicating which agents are currently alive.

        The base implementation reports every agent in ``self.agents`` as alive.
        """
        return {agent: True for agent in self.agents}

    def avail_actions(self):
        """Returns a boolean mask indicating which actions are available for each agent.

        Every action is marked available. Each agent's action space must expose ``n``
        (i.e. be discrete) for this default implementation to work.
        """
        return {agent: np.ones(self.action_space[agent].n, np.bool_) for agent in self.agents}

    def state(self):
        """Returns the global state of the environment."""
        raise NotImplementedError

    @abstractmethod
    def reset(self, **kwargs):
        """
        Resets the environment with kwargs.

        Returns:
            observation (MultiAgentDict): The initial observations of the agent.
            info (MultiAgentDict): The information about the environment.
        """
        raise NotImplementedError

    @abstractmethod
    def step(self, action_dict: MultiAgentDict) -> Tuple[
        MultiAgentDict, MultiAgentDict, MultiAgentDict, bool, MultiAgentDict]:
        """
        Steps through the environment with action.

        Parameters:
            action_dict (MultiAgentDict): A dict that contains all agents' actions.

        Returns:
            observation (MultiAgentDict): The next step observations after executing actions.
            reward (MultiAgentDict): The rewards returned by the environment.
            terminated (MultiAgentDict): A dict of bool values that indicates if the environment should be terminated.
            truncated (bool): A bool value that indicates if the environment should be truncated.
            info (MultiAgentDict): The information about the environment.
        """
        raise NotImplementedError

    @abstractmethod
    def render(self, *args, **kwargs):
        """
        Renders the environment.

        Returns:
            rgb_images (np.ndarray or list): The images used to visualize the environment.

        Raises:
            NotImplementedError: Always, in this base implementation; subclasses must override it.
        """
        raise NotImplementedError

    @abstractmethod
    def close(self):
        """Closes the environment. The base implementation does nothing."""
        return