Source code for xuance.common.callback

from abc import ABC


class BaseCallback(ABC):
    """Common parent for user-defined training/testing callback hooks.

    Every hook defined here is a no-op. Subclass it and override only the
    hooks needed to inject custom logic at the corresponding stage of
    training or evaluation.
    """

    def __init__(self, *args, **kwargs):
        # Positional arguments are accepted but ignored. Only the optional
        # ``logger`` keyword is stored (``None`` when it is not supplied).
        self.logger = kwargs.get('logger', None)

    def on_update_start(self, iterations, **kwargs):
        """Hook invoked right before a policy update begins.

        Args:
            iterations (int): Number of update iterations performed so far.
            **kwargs: Additional optional keyword arguments.

        Returns:
            dict: A new empty dictionary in this base implementation.
        """
        return dict()

    def on_update_end(self, iterations, **kwargs):
        """Hook invoked right after a policy update is completed.

        Args:
            iterations (int): Number of update iterations performed so far.
            **kwargs: Optional keyword arguments.

        Returns:
            dict: A new empty dictionary in this base implementation.
        """
        return dict()

    def on_train_step(self, current_step, **kwargs):
        """Hook invoked after each training step.

        A training step here means the point after one transition has been
        collected.

        Args:
            current_step (int): The current global training step.
            **kwargs: Additional optional information.

        Returns:
            None: The base implementation does nothing.
        """

    def on_train_epochs_end(self, current_step, **kwargs):
        """Hook invoked when the training epochs end.

        Args:
            current_step (int): The current global training step.
            **kwargs: Additional optional information.

        Returns:
            None: The base implementation does nothing.
        """

    def on_train_episode_info(self, **kwargs):
        """Hook invoked when an episode of an environment terminates or is truncated.

        Args:
            **kwargs: Optional keyword arguments.

        Returns:
            None: The base implementation does nothing.
        """

    def on_train_step_end(self, current_step, **kwargs):
        """Hook invoked once a training step is fully completed.

        "Fully completed" covers the update, logging, and similar work.

        Args:
            current_step (int): The current global training step.
            **kwargs: Optional keyword arguments. Environment information
                (``envs_info``) and training information (``train_info``)
                are expected to arrive here rather than as named parameters.

        Returns:
            None: The base implementation does nothing.
        """

    def on_test_step(self, *args, **kwargs):
        """Hook invoked during each step of the testing loop.

        Args:
            *args: Optional positional arguments.
            **kwargs: Optional keyword arguments.

        Returns:
            None: The base implementation does nothing.
        """

    def on_test_end(self, *args, **kwargs):
        """Hook invoked at the end of the testing loop.

        Args:
            *args: Optional positional arguments.
            **kwargs: Optional keyword arguments.

        Returns:
            None: The base implementation does nothing.
        """
class MultiAgentBaseCallback(BaseCallback):
    """Callback base class that adds a per-agent update hook to ``BaseCallback``."""

    def __init__(self, *args, **kwargs):
        # Construction is delegated unchanged to the parent class.
        super().__init__(*args, **kwargs)

    def on_update_agent_wise(self, iterations, agent_key, **kwargs) -> dict:
        """Hook invoked when an individual agent's policy is being updated.

        Args:
            iterations (int): Number of update iterations performed so far.
            agent_key (str): The key of the agent to update.
            **kwargs: Optional keyword arguments.

        Returns:
            dict: A new empty dictionary in this base implementation.
        """
        return dict()