Source code for xuance.mindspore.agents.policy_gradient.td3_agent

from argparse import Namespace
from gymnasium.spaces import Space
from xuance.common import Optional, BaseCallback
from xuance.environment import DummyVecEnv, SubprocVecEnv
from xuance.mindspore import Module
from xuance.mindspore.utils import NormalizeFunctions, ActivationFunctions, InitializeFunctions
from xuance.mindspore.policies import REGISTRY_Policy
from xuance.mindspore.agents.policy_gradient.ddpg_agent import DDPG_Agent


class TD3_Agent(DDPG_Agent):
    """TD3 agent, implemented as a thin specialization of ``DDPG_Agent``.

    Construction is delegated unchanged to the parent class; the only
    behavior overridden here is how the policy object is assembled.

    Args:
        config: the Namespace variable that provides hyperparameters and other settings.
        envs: the vectorized environments.
        observation_space: optional observation space, passed through to the parent constructor.
        action_space: optional action space, passed through to the parent constructor.
        callback: A user-defined callback function object to inject custom logic during training.
    """

    def __init__(
            self,
            config: Namespace,
            envs: Optional[DummyVecEnv | SubprocVecEnv] = None,
            observation_space: Optional[Space] = None,
            action_space: Optional[Space] = None,
            callback: Optional[BaseCallback] = None
    ):
        # All set-up work lives in the DDPG base class.
        super().__init__(config, envs, observation_space, action_space, callback)

    def _build_policy(self) -> Module:
        """Assemble the TD3 policy described by ``self.config``.

        Returns:
            The policy instance created from ``REGISTRY_Policy["TD3_Policy"]``.

        Raises:
            AttributeError: if ``self.config.policy`` is not ``"TD3_Policy"``.
        """
        # Normalization and initialization are optional settings: they are
        # looked up only when the config defines the attribute at all.
        if hasattr(self.config, "normalize"):
            norm_layer = NormalizeFunctions[self.config.normalize]
        else:
            norm_layer = None

        if hasattr(self.config, "initialize"):
            weight_init = InitializeFunctions[self.config.initialize]
        else:
            weight_init = None

        hidden_activation = ActivationFunctions[self.config.activation]

        # The representation is built before the policy name is validated,
        # so an unsupported policy is reported only after this step.
        backbone = self._build_representation(
            self.config.representation,
            self.observation_space,
            self.config,
        )

        # Guard clause: this agent only knows how to build "TD3_Policy".
        if self.config.policy != "TD3_Policy":
            raise AttributeError(f"TD3 currently does not support the policy named {self.config.policy}.")

        policy_cls = REGISTRY_Policy["TD3_Policy"]
        return policy_cls(
            action_space=self.action_space,
            representation=backbone,
            actor_hidden_size=self.config.actor_hidden_size,
            critic_hidden_size=self.config.critic_hidden_size,
            normalize=norm_layer,
            initialize=weight_init,
            activation=hidden_activation,
            activation_action=ActivationFunctions[self.config.activation_action],
        )