Source code for xuance.mindspore.learners.policy_gradient.pdqn_learner
"""
Parameterised deep Q network (P-DQN)
Paper link: https://arxiv.org/pdf/1810.06394.pdf
Implementation: MindSpore
"""
from xuance.mindspore import ms, Module, Tensor, optim
from xuance.mindspore.learners import Learner
from argparse import Namespace
from mindspore.ops import OneHot
[docs]
class PDQN_Learner(Learner):
[docs]
class ConActorNetWithLossCell(Module):
def __init__(self, backbone):
super(PDQN_Learner.ConActorNetWithLossCell, self).__init__(auto_prefix=False)
self._backbone = backbone
[docs]
def construct(self, x):
# optimize actor network
policy_q = self._backbone.Qpolicy(x)
p_loss = - policy_q.mean()
return p_loss
[docs]
class QNetWithLossCell(Module):
def __init__(self, backbone, loss_fn):
super(PDQN_Learner.QNetWithLossCell, self).__init__(auto_prefix=False)
self._backbone = backbone
self._loss_fn = loss_fn
[docs]
def construct(self, x, dis_a, con_a, label):
# optimize q-network
eval_qs = self._backbone.Qeval(x, con_a)
eval_q = eval_qs.gather(dis_a.astype(ms.int32).view(-1, 1), 1).squeeze()
q_loss = self._loss_fn(eval_q, label)
return q_loss
def __init__(self,
config: Namespace,
policy: Module,
callback):
self.gamma = config.gamma
self.tau = config.tau
super(PDQN_Learner, self).__init__(config, policy, callback)
# define loss function
loss_fn = nn.MSELoss()
# connect the feed forward network with loss function.
self.con_loss_net = self.ConActorNetWithLossCell(policy)
self.q_loss_net = self.QNetWithLossCell(policy, loss_fn)
# define the training network
self.con_policy_train = nn.TrainOneStepCell(self.con_loss_net, optimizer[0])
self.q_policy_train = nn.TrainOneStepCell(self.q_loss_net, optimizer[1])
# set the training network as train mode.
self.con_policy_train.set_train()
self.q_policy_train.set_train()
[docs]
def update(self, obs_batch, act_batch, rew_batch, next_batch, terminal_batch):
self.iterations += 1
obs_batch = Tensor(obs_batch)
hyact_batch = Tensor(act_batch)
disact_batch = hyact_batch[:, 0]#.long()
conact_batch = hyact_batch[:, 1:]
rew_batch = Tensor(rew_batch)
next_batch = Tensor(next_batch)
ter_batch = Tensor(terminal_batch)
target_conact = self.policy.Atarget(next_batch)
target_q = self.policy.Qtarget(next_batch, target_conact)
target_q = target_q.max(axis=-1)
target_q = rew_batch + (1 - ter_batch) * self.gamma * target_q
q_loss = self.q_policy_train(obs_batch, disact_batch, conact_batch, target_q)
p_loss = self.con_policy_train(obs_batch)
self.policy.soft_update(self.tau)
con_actor_lr = self.scheduler[0](self.iterations).asnumpy()
qnet_lr = self.scheduler[1](self.iterations).asnumpy()
info = {
"P_loss": p_loss.asnumpy(),
"Q_loss": q_loss.asnumpy(),
"con_actor_lr": con_actor_lr,
"qnet_lr": qnet_lr
}
return info