# Source code for xuance.mindspore.agents.policy_gradient.mpdqn_agent
from argparse import Namespace
from xuance.common import Optional, BaseCallback
from xuance.environment.single_agent_env import Gym_Env
from xuance.mindspore import Module
from xuance.mindspore.utils import NormalizeFunctions, ActivationFunctions, InitializeFunctions
from xuance.mindspore.policies import REGISTRY_Policy
from xuance.mindspore.agents.policy_gradient.pdqn_agent import PDQN_Agent
# [docs]
class MPDQN_Agent(PDQN_Agent):
    """The implementation of MPDQN agent.

    Args:
        config: the Namespace variable that provides hyperparameters and other settings.
        envs: the vectorized environments.
        callback: A user-defined callback function object to inject custom logic during training.
    """

    def __init__(self,
                 config: Namespace,
                 envs: Gym_Env,
                 callback: Optional[BaseCallback] = None):
        # Forward only the arguments this constructor receives. The previous
        # call also passed `observation_space` and `action_space`, which are
        # not defined in this scope and raised NameError on construction.
        # NOTE(review): assumes PDQN_Agent.__init__ accepts
        # (config, envs, callback) -- confirm against the parent class.
        super(MPDQN_Agent, self).__init__(config, envs, callback)

    def _build_policy(self) -> Module:
        """Build the policy module selected by ``self.config.policy``.

        Reads the optional ``normalize`` and ``initialize`` settings (falling
        back to ``None`` when the config has no such attribute) and the
        required ``activation`` / ``activation_action`` settings, builds the
        representation, and instantiates the registered "MPDQN_Policy".

        Returns:
            The policy object created from ``REGISTRY_Policy["MPDQN_Policy"]``.

        Raises:
            AttributeError: if ``self.config.policy`` is not "MPDQN_Policy".
        """
        config = self.config

        # Optional settings: only looked up when present on the config.
        normalize_fn = NormalizeFunctions[config.normalize] if hasattr(config, "normalize") else None
        initializer = InitializeFunctions[config.initialize] if hasattr(config, "initialize") else None
        activation = ActivationFunctions[config.activation]

        # build representation.
        representation = self._build_representation(config.representation, self.observation_space, config)

        # build policy.
        if config.policy != "MPDQN_Policy":
            raise AttributeError(
                f"{self.config.agent} currently does not support the policy named {self.config.policy}.")

        return REGISTRY_Policy["MPDQN_Policy"](
            observation_space=self.observation_space,
            action_space=self.action_space,
            representation=representation,
            conactor_hidden_size=config.conactor_hidden_size,
            qnetwork_hidden_size=config.qnetwork_hidden_size,
            normalize=normalize_fn,
            initialize=initializer,
            activation=activation,
            activation_action=ActivationFunctions[config.activation_action],
        )