# Source code for xuance.torch.agents.policy_gradient.sac_agent

import torch
import numpy as np
from argparse import Namespace
from gymnasium.spaces import Space
from xuance.common import Optional, BaseCallback
from xuance.environment import DummyVecEnv, SubprocVecEnv
from xuance.torch import Module
from xuance.torch.utils import NormalizeFunctions, ActivationFunctions
from xuance.torch.policies import REGISTRY_Policy
from xuance.torch.agents import OffPolicyAgent


class SAC_Agent(OffPolicyAgent):
    """The implementation of SAC agent.

    Args:
        config: the Namespace variable that provides hyperparameters and other settings.
        envs: the vectorized environments.
        observation_space: the observation space; forwarded unchanged to ``OffPolicyAgent.__init__``.
        action_space: the action space; forwarded unchanged to ``OffPolicyAgent.__init__``.
        callback: A user-defined callback function object to inject custom logic during training.
    """

    def __init__(
            self,
            config: Namespace,
            envs: Optional[DummyVecEnv | SubprocVecEnv] = None,
            observation_space: Optional[Space] = None,
            action_space: Optional[Space] = None,
            callback: Optional[BaseCallback] = None
    ):
        super(SAC_Agent, self).__init__(config, envs, observation_space, action_space, callback)
        self.policy = self._build_policy()  # build policy
        self.memory = self._build_memory()  # build memory
        self.learner = self._build_learner(self.config, self.policy, self.callback)

    def _build_policy(self) -> Module:
        """Build the SAC policy selected by ``self.config.policy``.

        A representation network is built with ``self._build_representation``. It is then passed,
        together with hidden sizes, the optional normalization, orthogonal initialization, the
        activation functions, the device and the distributed-training flag, to the policy class
        registered under that name in ``REGISTRY_Policy``.

        Returns:
            Module: the constructed policy.

        Raises:
            AttributeError: if ``self.config.policy`` is neither ``"Gaussian_SAC"`` nor
                ``"Categorical_SAC"``.
        """
        # Normalization is optional: None when the config has no ``normalize`` entry.
        normalize_fn = NormalizeFunctions[self.config.normalize] if hasattr(self.config, "normalize") else None
        initializer = torch.nn.init.orthogonal_
        activation = ActivationFunctions[self.config.activation]
        device = self.device

        # build representations.
        representation = self._build_representation(self.config.representation, self.observation_space, self.config)

        # Both variants share the same constructor arguments; only the Gaussian variant
        # additionally takes ``activation_action``.
        policy_name = self.config.policy
        if policy_name == "Gaussian_SAC":
            variant_kwargs = {"activation_action": ActivationFunctions[self.config.activation_action]}
        elif policy_name == "Categorical_SAC":
            variant_kwargs = {}
        else:
            # The original exception type is preserved for backward compatibility.
            raise AttributeError(f"SAC currently does not support the policy named {self.config.policy}.")

        policy = REGISTRY_Policy[policy_name](
            action_space=self.action_space,
            representation=representation,
            actor_hidden_size=self.config.actor_hidden_size,
            critic_hidden_size=self.config.critic_hidden_size,
            normalize=normalize_fn,
            initialize=initializer,
            activation=activation,
            device=device,
            use_distributed_training=self.distributed_training,
            **variant_kwargs,
        )
        return policy

    def get_actions(self, observations: np.ndarray, test_mode: Optional[bool] = False) -> dict:
        """Returns the actions produced by the policy for the given observations.

        Parameters:
            observations (np.ndarray): The observation(s), passed directly to ``self.policy``.
            test_mode (Optional[bool]): Accepted for interface compatibility but currently unused;
                the policy is queried identically in training and testing.

        Returns:
            dict: ``{"actions": actions}``, where ``actions`` is the second output of
            ``self.policy``, detached and converted to a NumPy array on the CPU.
        """
        # NOTE(review): ``test_mode`` is ignored here, so evaluation uses the same (possibly
        # stochastic) policy output as training. Confirm whether deterministic actions are intended.
        _, actions_output = self.policy(observations)
        actions = actions_output.detach().cpu().numpy()
        return {"actions": actions}