Source code for xuance.common.tuning_tools.hyperparameters

from dataclasses import dataclass
from typing import Any, List, Optional, Union


@dataclass()
class Hyperparameter:
    """
    Describe one tunable hyperparameter of an algorithm.

    An instance bundles what a tuner needs to know about a single
    hyperparameter: its name, the kind of value it takes, the space of values
    it may be drawn from, whether that space should be explored on a
    logarithmic scale, and a fallback value for untuned runs.

    Attributes:
        name (str): Identifier of the hyperparameter.
        type (str): Kind of value. The supported kinds are 'int', 'float'
            and 'categorical'.
        distribution (Union[List[Any], tuple]): The search space.
            - 'categorical': a list of candidate values.
            - 'int' / 'float': a (min, max) tuple giving the range.
        log (bool, optional): Whether the value should be sampled on a
            logarithmic scale. This is useful for quantities that span several
            orders of magnitude, such as learning rates. Defaults to False.
        default (Optional[Any], optional): Value to fall back on when no
            tuning is performed, so the algorithm can still run with standard
            settings. Defaults to None.
    """

    # Identifier of the hyperparameter.
    name: str
    # Value kind: 'int', 'float' or 'categorical'.
    type: str
    # Candidate list (categorical) or (min, max) range (int / float).
    distribution: Union[List[Any], tuple]
    # Sample from the log domain instead of the linear one.
    log: bool = False
    # Fallback value used when the hyperparameter is not tuned.
    default: Optional[Any] = None
class AlgorithmHyperparametersRegistry:
    """
    A registry for managing hyperparameters of different algorithms.

    This class allows for the registration of algorithms along with their
    corresponding hyperparameters. It provides methods to retrieve
    hyperparameters for a specific algorithm and to list all registered
    algorithms.

    The registry keeps its own copies of the hyperparameter lists. Mutating a
    list passed to `register_algorithm`, or a list returned by
    `get_hyperparameters`, does not alter what is registered. The
    `Hyperparameter` objects inside those lists are not copied and remain
    shared.

    Attributes:
        _registry (dict): A class-level dictionary mapping algorithm names to
            their list of hyperparameters. Because it is a class attribute, it
            is shared by every caller of the class.
    """

    # Class-level (shared) mapping: algorithm name -> list of Hyperparameter.
    _registry = {}

    @classmethod
    def register_algorithm(cls, algorithm_name: str, hyperparameters: List[Hyperparameter]):
        """
        Register an algorithm along with its hyperparameters.

        This method adds an algorithm and its associated hyperparameters to the
        registry. If the algorithm already exists, its hyperparameters are
        replaced.

        Args:
            algorithm_name (str): The name of the algorithm to register.
            hyperparameters (List[Hyperparameter]): The Hyperparameter
                instances defining the algorithm's hyperparameters. Any
                iterable (list, tuple, generator) is accepted. It is stored as
                a new list, so later changes to the caller's container do not
                affect the registry.

        Example:
            >>> hyperparams = [
            ...     Hyperparameter(name="learning_rate", type="float", distribution=(1e-5, 1e-2), log=True, default=1e-3),
            ...     Hyperparameter(name="gamma", type="float", distribution=(0.85, 0.99), log=False, default=0.99),
            ... ]
            >>> AlgorithmHyperparametersRegistry.register_algorithm("DQN", hyperparams)
        """
        # Store a private copy. This keeps the registry independent of the
        # caller's list and materializes one-shot iterables such as generators,
        # which would otherwise be exhausted after the first lookup.
        cls._registry[algorithm_name] = list(hyperparameters)

    @classmethod
    def get_hyperparameters(cls, algorithm_name: str) -> List[Hyperparameter]:
        """
        Retrieve the list of hyperparameters for a given algorithm.

        Args:
            algorithm_name (str): The name of the algorithm whose
                hyperparameters are to be retrieved.

        Returns:
            List[Hyperparameter]: A new list of the Hyperparameter instances
            associated with the specified algorithm, or an empty list if the
            algorithm is not registered. Adding, removing or reordering items
            in the returned list does not change the registry.

        Example:
            >>> hyperparams = AlgorithmHyperparametersRegistry.get_hyperparameters("DQN")
            >>> for hp in hyperparams:
            ...     print(hp.name, hp.type)
            learning_rate float
            gamma float
        """
        # Return a copy so callers cannot mutate the registered list in place.
        return list(cls._registry.get(algorithm_name, ()))

    @classmethod
    def list_algorithms(cls) -> List[str]:
        """
        List all registered algorithms.

        Returns:
            List[str]: The names of all algorithms that have been registered,
            in registration order.

        Example:
            >>> algorithms = AlgorithmHyperparametersRegistry.list_algorithms()
            >>> print(algorithms)
            ['DQN', 'A2C', 'SAC']
        """
        return list(cls._registry)