# Source code for xuance.environment
from argparse import Namespace
from xuance.environment.utils import XuanCeEnvWrapper, XuanCeAtariEnvWrapper, XuanCeMultiAgentEnvWrapper
from xuance.environment.utils import RawEnvironment, RawMultiAgentEnv
from xuance.environment.utils import space2shape, combined_shape
from xuance.environment.vector_envs import DummyVecEnv, DummyVecEnv_Atari, DummyVecMultiAgentEnv
from xuance.environment.vector_envs import SubprocVecEnv, SubprocVecEnv_Atari, SubprocVecMultiAgentEnv
from xuance.environment.single_agent_env import REGISTRY_ENV
from xuance.environment.multi_agent_env import REGISTRY_MULTI_AGENT_ENV
from xuance.environment.vector_envs import REGISTRY_VEC_ENV
# [docs]
def make_envs(config: Namespace):
    """
    Create the environment(s) described by ``config``.

    The environment is looked up by ``config.env_name`` in the single-agent
    registry first and in the multi-agent registry second. Depending on
    ``config.vectorize`` the result is either a vectorized environment built
    from ``config.parallels`` copies, or one single environment.

    Parameters:
    -----------
    config : Namespace
        A configuration object. The following attributes are used:
        - env_name (str): The name of the environment to create.
        - env_seed (int): The seed value for environment initialization.
        - parallels (int): The number of environment copies for vectorized setups.
        - vectorize (str): Either a key of ``REGISTRY_VEC_ENV`` or the literal
          ``"NOREQUIRED"`` to get a single, non-vectorized environment.
        - distributed_training (bool, optional): Treated as ``False`` when absent.
        - render_mode (str, optional): Set to ``"human"`` on ``config`` when absent.

    Returns:
    --------
    environment
        The object built by the selected vectorizer when ``config.vectorize``
        is registered, otherwise (``"NOREQUIRED"``) a single environment.

    Raises:
    -------
    AttributeError
        If ``config.env_name`` is in neither registry, or if ``config.vectorize``
        is neither registered nor ``"NOREQUIRED"``.

    Notes:
    ------
    ``config`` is mutated in place: ``render_mode`` may be added and
    ``env_seed`` is rewritten (distributed offset, and by every thunk call).
    """

    def _thunk(env_seed: int = None):
        """
        Build one environment from ``config`` using the given seed.

        Parameters:
        -----------
        env_seed : int, optional
            The seed to use for environment initialization. It is written to
            ``config.env_seed`` before the environment is constructed.
            Defaults to `None`.

        Returns:
        --------
        environment
            The created single-agent or multi-agent environment.

        Raises:
        -------
        AttributeError
            If ``config.env_name`` is not found in either registry.
        """
        # The seed is passed to the environment constructor via the shared config.
        config.env_seed = env_seed
        name = config.env_name
        if name in REGISTRY_ENV:
            raw_env = REGISTRY_ENV[name](config)
            if name == "Platform":
                # "Platform" is returned as constructed, without a wrapper.
                return raw_env
            if name == "Atari":
                return XuanCeAtariEnvWrapper(raw_env)
            return XuanCeEnvWrapper(raw_env)
        if name in REGISTRY_MULTI_AGENT_ENV:
            return XuanCeMultiAgentEnvWrapper(REGISTRY_MULTI_AGENT_ENV[name](config))
        raise AttributeError(f"The environment named {config.env_name} cannot be created.")

    distributed_training = getattr(config, "distributed_training", False)
    if not hasattr(config, "render_mode"):
        config.render_mode = "human"

    if distributed_training:
        # NOTE(review): the rank is fixed at 1 here; the original author left a
        # hint that it should come from os.environ['RANK'] (for
        # torch.nn.parallel.DistributedDataParallel) - confirm before changing.
        rank = 1
        # Shift the base seed by a block of `parallels` seeds.
        config.env_seed += rank * config.parallels

    if config.vectorize in REGISTRY_VEC_ENV:
        env_fn = [_thunk for _ in range(config.parallels)]
        return REGISTRY_VEC_ENV[config.vectorize](env_fn, config.env_seed)
    if config.vectorize == "NOREQUIRED":
        # Pass the configured seed explicitly: calling `_thunk()` bare would
        # overwrite `config.env_seed` with None and silently drop the seed.
        # `getattr` keeps configs without `env_seed` working as before.
        return _thunk(getattr(config, "env_seed", None))
    raise AttributeError(f"The vectorizer {config.vectorize} is not implemented.")