"""
DQN with Double Q-learning (Double DQN)
Paper link: https://ojs.aaai.org/index.php/AAAI/article/view/10295
Implementation: TensorFlow2
"""
import numpy as np
from argparse import Namespace
from xuance.tensorflow import tf, tk, Module
from xuance.tensorflow.learners import Learner
class DDQN_Learner(Learner):
    """Learner that applies Double DQN updates to a Q-network policy.

    The online network (``self.policy``) picks the next-state action and the
    target network (``self.policy.target``) supplies the value of that action;
    the online network is then regressed onto the resulting one-step target
    with a mean-squared-error loss.

    Attributes set here:
        optimizer: Adam optimizer built with ``config.learning_rate``.
        gamma: Discount factor, read from ``config.gamma``.
        sync_frequency: Number of ``update`` calls between target-network
            synchronisations, read from ``config.sync_frequency``.
        mse_loss: ``tk.losses.MeanSquaredError`` instance used as the TD loss.

    NOTE(review): ``os_name``, ``distributed_training``, ``use_grad_clip``,
    ``grad_clip_norm``, ``iterations``, ``policy`` and ``callback`` are read
    from ``self`` but never assigned in this class; they are presumably set
    by ``Learner.__init__`` -- confirm against the base class.
    """

    def __init__(self,
                 config: Namespace,
                 policy: Module,
                 callback):
        """Build the optimizer and cache the hyper-parameters used in updates.

        Args:
            config: Experiment configuration; ``learning_rate``, ``gamma`` and
                ``sync_frequency`` are read from it here.
            policy: Policy module. It is called on observation batches and
                must expose ``target``, ``copy_target``, ``trainable_variables``
                and (for distributed training) ``mirrored_strategy``.
            callback: Object providing ``on_update_start`` / ``on_update_end``
                hooks, forwarded to the base class.
        """
        super(DDQN_Learner, self).__init__(config, policy, callback)

        # For macOS with Apple's M-series chips the legacy Adam class is used.
        # The conditional expression keeps `tk.optimizers.legacy` from being
        # touched at all on other platforms.
        on_apple_silicon = ("macOS" in self.os_name) and ("arm" in self.os_name)
        adam_cls = tk.optimizers.legacy.Adam if on_apple_silicon else tk.optimizers.Adam

        if self.distributed_training:
            # Create the optimizer under the strategy scope so that it is
            # created the same way as in the non-deduplicated original.
            with self.policy.mirrored_strategy.scope():
                self.optimizer = adam_cls(config.learning_rate)
        else:
            self.optimizer = adam_cls(config.learning_rate)

        self.gamma = config.gamma
        self.sync_frequency = config.sync_frequency
        self.mse_loss = tk.losses.MeanSquaredError()

    @tf.function
    def forward_fn(self, obs_batch, act_batch, next_batch, rew_batch, ter_batch):
        """Run one Double DQN gradient step on a batch of transitions.

        Args:
            obs_batch: Batch of current observations.
            act_batch: Batch of integer action indices taken in ``obs_batch``.
            next_batch: Batch of next observations.
            rew_batch: Batch of rewards.
            ter_batch: Batch of termination flags; ``1 - ter_batch`` masks the
                bootstrapped value (assumes 1 marks a terminal transition --
                TODO confirm against the replay buffer).

        Returns:
            A tuple ``(predictQ, loss)``: the online network's Q-value for the
            taken action of every sample, and the MSE loss against the target.
        """
        with tf.GradientTape() as tape:
            _, _, evalQ = self.policy(obs_batch)
            # Double DQN: the action comes from the online network's output on
            # the next observation, its value from the target network.
            _, targetA, _ = self.policy(next_batch)
            _, _, targetQ = self.policy.target(next_batch)
            targetA = tf.one_hot(targetA, targetQ.shape[1])
            targetQ = tf.reduce_sum(targetQ * targetA, axis=-1)
            targetQ = rew_batch + self.gamma * (1 - ter_batch) * targetQ
            # The target is a constant for the regression; no gradient may
            # flow back through it.
            targetQ = tf.stop_gradient(targetQ)

            # Q-value of the action that was actually taken.
            predictQ = tf.math.reduce_sum(evalQ * tf.one_hot(act_batch, evalQ.shape[1]), axis=-1)
            loss = self.mse_loss(targetQ, predictQ)

        variables = self.policy.trainable_variables
        gradients = tape.gradient(loss, variables)

        # Variables that did not contribute to the loss have a None gradient
        # and are skipped.
        grads_and_vars = [(grad, var) for grad, var in zip(gradients, variables) if grad is not None]
        if self.use_grad_clip:
            # tf.clip_by_norm clips each gradient tensor on its own norm
            # (this is not a global-norm clip).
            grads_and_vars = [(tf.clip_by_norm(grad, self.grad_clip_norm), var)
                              for grad, var in grads_and_vars]
        self.optimizer.apply_gradients(grads_and_vars)
        return predictQ, loss

    @tf.function
    def learn(self, *inputs):
        """Dispatch ``forward_fn`` either directly or through the strategy.

        Args:
            *inputs: Positional arguments forwarded unchanged to
                ``forward_fn`` (obs, actions, next obs, rewards, terminals).

        Returns:
            ``(predictQ, loss)``. With distributed training, both are the
            per-replica results summed across replicas
            (``ReduceOp.SUM``, ``axis=None``).
        """
        if not self.distributed_training:
            return self.forward_fn(*inputs)

        strategy = self.policy.mirrored_strategy
        predictQ, loss = strategy.run(self.forward_fn, args=inputs)
        return (strategy.reduce(tf.distribute.ReduceOp.SUM, predictQ, axis=None),
                strategy.reduce(tf.distribute.ReduceOp.SUM, loss, axis=None))

    def update(self, **samples):
        """Perform one training iteration from a sampled batch.

        Args:
            **samples: Batch contents; the keys ``'obs'``, ``'actions'``,
                ``'obs_next'``, ``'rewards'`` and ``'terminals'`` are read.
                ``samples['actions']`` must support ``.astype`` (it is cast
                to ``np.int32``).

        Returns:
            The info mapping returned by ``callback.on_update_start``,
            extended with ``"Qloss"`` and ``"predictQ"`` (mean predicted Q)
            and with whatever ``callback.on_update_end`` returns.
        """
        self.iterations += 1
        obs_batch = samples['obs']
        act_batch = samples['actions'].astype(np.int32)
        next_batch = samples['obs_next']
        rew_batch = samples['rewards']
        ter_batch = samples['terminals']

        info = self.callback.on_update_start(self.iterations,
                                             policy=self.policy, obs=obs_batch, act=act_batch,
                                             next_obs=next_batch, rew=rew_batch, termination=ter_batch)

        predictQ, loss = self.learn(obs_batch, act_batch, next_batch, rew_batch, ter_batch)

        # Refresh the target network every `sync_frequency` updates.
        if self.iterations % self.sync_frequency == 0:
            self.policy.copy_target()

        info.update({
            "Qloss": loss.numpy(),
            "predictQ": tf.math.reduce_mean(predictQ).numpy(),
        })
        info.update(self.callback.on_update_end(self.iterations,
                                                policy=self.policy, info=info,
                                                predictQ=predictQ, loss=loss))
        return info