Source code for xuance.torch.agents.core.offline

import numpy as np
from tqdm import tqdm
from argparse import Namespace
from xuance.common import Optional, DummyOffPolicyBuffer, OfflineBuffer_D4RL, BaseCallback
from xuance.torch import Module
from xuance.torch.agents.base import Agent


[docs] class OfflineAgent(Agent): """The core class for offline reinforcement learning. Args: config: the Namespace variable that provides hyperparameters and other settings. envs: the vectorized environments. callback: A user-defined callback function object to inject custom logic during training. It can be used for logging, early stopping, model saving, or visualization. If not provided, a default no-op callback is used. """ def __init__(self, config: Namespace, envs, callback: Optional[BaseCallback] = None): super(OfflineAgent, self).__init__(config, envs, callback) self.auxiliary_info_shape = None self.buffer_size = self.config.buffer_size self.batch_size = self.config.batch_size self.memory: Optional[OfflineBuffer_D4RL] = self._build_memory() def _build_memory(self, auxiliary_info_shape=None): self.d4rl = True if self.config.env_name != "atari" else False Buffer = OfflineBuffer_D4RL if self.d4rl else DummyOffPolicyBuffer input_buffer = dict(observation_space=self.observation_space, action_space=self.action_space, auxiliary_shape=auxiliary_info_shape, n_envs=self.n_envs, buffer_size=self.buffer_size, batch_size=self.batch_size) return Buffer(**input_buffer) def _build_policy(self) -> Module: raise NotImplementedError
[docs] def train_epochs(self, n_epochs=1): train_info = {} for _ in range(n_epochs): samples = self.memory.sample() train_info = self.learner.update(**samples) return train_info
[docs] def train(self, train_steps): train_info = {} for _ in tqdm(range(train_steps)): if self.current_step > self.start_training and self.current_step % self.training_frequency == 0: update_info = self.train_epochs(n_epochs=self.n_epochs) # self.n_epochs = 16 self.log_infos(update_info, self.current_step) train_info.update(update_info) self.callback.on_train_epochs_end(self.current_step, policy=self.policy, memory=self.memory, train_info=train_info, train_steps=train_steps) self.current_step += 1 return train_info
[docs] def get_actions(self, observations: np.ndarray): raise NotImplementedError
[docs] def test(self, env_fn, steps): raise NotImplementedError