Source code for xuance.common.common_tools

import os
import yaml
import time
import scipy.signal
from copy import deepcopy
from types import SimpleNamespace as SN

EPS = 1e-8


[docs] def recursive_dict_update(basic_dict, target_dict): """Update the dict values. Args: basic_dict: the original dict variable that to be updated. target_dict: the target dict variable with new values. Returns: A dict mapping keys of basic_dict to the values of the same keys in target_dict. For example: basic_dict = {'a': 1, 'b': 2} target_dict = {'a': 3, 'c': 4} out_dict = recursive_dict_update(basic_dict, target_dict) output_dict = {'a': 3, 'b': 2, 'c': 4} """ out_dict = deepcopy(basic_dict) for key, value in target_dict.items(): if isinstance(value, dict): out_dict[key] = recursive_dict_update(out_dict.get(key, {}), value) else: out_dict[key] = value return out_dict
[docs] def load_yaml(file_dir) -> dict: """Get dict variable from a YAML file. Args: file_dir: the directory of the YAML file. Returns: config_dict: the keys and corresponding values in the YAML file. """ with open(file_dir, "r") as f: try: config_dict = yaml.load(f, Loader=yaml.FullLoader) except yaml.YAMLError as exc: assert False, file_dir + " error: {}".format(exc) return config_dict
[docs] def get_arguments(algo, env, env_id, config_path=None, parser_args=None): """Get arguments from ``.yaml`` files Args: algo: the algorithm name that will be implemented, env: The name of the environment, env_id: The name of the scenario in the environment. config_path: default is None, if None, the default configs (``xuance/configs/.../*.yaml``) will be loaded. parser_args: arguments that specified by parser tools. Returns: args: the SimpleNamespace variables that contains attributes for DRL implementations. """ main_path = os.getcwd() main_path_package = os.path.dirname(os.path.dirname(__file__)) config_path_default = os.path.join(main_path_package, "configs") ''' get the arguments from xuance/config/basic.yaml ''' config_basic = load_yaml(os.path.join(config_path_default, "basic.yaml")) ''' get the arguments from, e.g., xuance/config/dqn/box2d/CarRacing-v3.yaml ''' if type(algo) == list: # for different groups of MARL algorithms. if config_path is None: config_path = [] file_name_env_id = env + "/" + env_id + ".yaml" file_name_env = env + "/" + env_id + ".yaml" config_path_env_id = [os.path.join(config_path_default, agent, file_name_env_id) for agent in algo] config_path_env = [os.path.join(config_path_default, agent, file_name_env) for agent in algo] for i_agent, agent in enumerate(algo): if os.path.exists(config_path_env_id[i_agent]): config_path.append(config_path_env_id[i_agent]) elif os.path.exists(config_path_env[i_agent]): config_path.append(config_path_env[i_agent]) else: raise AttributeError( f"Cannot find file named '{config_path_env_id[i_agent]}' or '{config_path_env[i_agent]}'." f"You can also customize the configuration file by specifying the `config_path` parameter " f"in the `get_runner()` function.") else: config_path = [os.path.join(main_path, _path) for _path in config_path] config_algo_default = [load_yaml(_path) for _path in config_path] configs = [recursive_dict_update(config_basic, config_i) for config_i in config_algo_default] # load parser_args and rewrite the parameters if their names are same. if parser_args is not None: configs = [recursive_dict_update(config_i, parser_args.__dict__) for config_i in configs] args = [SN(**config_i) for config_i in configs] for arg in args: if arg.dl_toolbox == "torch": from xuance.torch import set_device elif arg.dl_toolbox == "tensorflow": from xuance.tensorflow import set_device elif arg.dl_toolbox == "mindspore": from xuance.mindspore import set_device else: raise f"Unsupported dl_toolbox: {arg}" arg.device = set_device(arg.device) elif type(algo) == str: if config_path is None: file_name_env_id = env + "/" + env_id + ".yaml" file_name_env = env + ".yaml" config_path_env_id = os.path.join(config_path_default, algo, file_name_env_id) config_path_env = os.path.join(config_path_default, algo, file_name_env) if os.path.exists(config_path_env_id): config_path = config_path_env_id elif os.path.exists(config_path_env): config_path = config_path_env else: error_path_env_id = os.path.join('./xuance/configs', algo, file_name_env_id) error_path_env = os.path.join('./xuance/configs', algo, file_name_env) raise AttributeError( f"The file '{error_path_env_id}' or '{error_path_env}' does not exist in this library. " f"You can also customize the configuration file by specifying the `config_path` parameter " f"in the `get_runner()` function.") else: config_path = os.path.join(main_path, config_path) config_algo_default = load_yaml(config_path) configs = recursive_dict_update(config_basic, config_algo_default) # load parser_args and rewrite the parameters if their names are same. if parser_args is not None: configs = recursive_dict_update(configs, parser_args.__dict__) if not ('env_id' in configs.keys()): configs['env_id'] = env_id args = SN(**configs) if args.dl_toolbox == "torch": from xuance.torch import set_device elif args.dl_toolbox == "tensorflow": from xuance.tensorflow import set_device elif args.dl_toolbox == "mindspore": from xuance.mindspore import set_device else: raise f"Unsupported dl_toolbox: {args}" args.device = set_device(args.device) else: raise AttributeError("Unsupported agent_name or env_name!") return args
[docs] def create_directory(path): """Create an empty directory. Args: path: the path of the directory """ """Create a directory (recursively) if it does not exist.""" if not path: return os.makedirs(path, exist_ok=True)
[docs] def discount_cumsum(x, discount=0.99): """Get a discounted cumulated summation. Args: x: The original sequence. In DRL, x can be reward sequence. discount: the discount factor (gamma), default is 0.99. Returns: The discounted cumulative returns for each step. Examples: >>> x = [0, 1, 2, 2] >>> y = discount_cumsum(x, discount=0.99) [4.890798, 4.9402, 3.98, 2.0] """ return scipy.signal.lfilter([1], [1, float(-discount)], x[::-1], axis=0)[::-1]
[docs] def get_time_string(): t_now = time.localtime(time.time()) t_year = str(t_now.tm_year).zfill(4) t_month = str(t_now.tm_mon).zfill(2) t_day = str(t_now.tm_mday).zfill(2) t_hour = str(t_now.tm_hour).zfill(2) t_min = str(t_now.tm_min).zfill(2) t_sec = str(t_now.tm_sec).zfill(2) time_string = f"{t_year}_{t_month}{t_day}_{t_hour}{t_min}{t_sec}" return time_string