"""
Twin Delayed Deep Deterministic Policy Gradient (TD3)
Paper link: http://proceedings.mlr.press/v80/fujimoto18a/fujimoto18a.pdf
Implementation: TensorFlow2
"""
from argparse import Namespace
from xuance.tensorflow import tf, tk, Module
from xuance.tensorflow.learners import Learner
class TD3_Learner(Learner):
    """Learner for Twin Delayed Deep Deterministic Policy Gradient (TD3).

    Every call to :meth:`update` regresses both critic heads returned by
    ``policy.Qaction`` toward one shared bootstrapped target built from
    ``policy.Qtarget``. Every ``actor_update_delay`` calls it also takes an
    actor step and soft-updates the policy's target networks with ``tau``.
    """

    def __init__(self,
                 config: Namespace,
                 policy: Module,
                 callback):
        """Create the actor/critic optimizers and read the TD3 hyper-parameters.

        Args:
            config: Reads ``learning_rate_actor``, ``learning_rate_critic``,
                ``tau``, ``gamma`` and ``actor_update_delay`` (in addition to
                whatever the base ``Learner`` reads).
            policy: TD3 policy module. This class calls its ``Qpolicy``,
                ``Qaction``, ``Qtarget`` and ``soft_update`` methods and reads
                ``actor_trainable_variables`` / ``critic_trainable_variables``.
                It also reads ``mirrored_strategy`` when distributed training
                is enabled.
            callback: Object whose ``on_update_start`` / ``on_update_end``
                hooks are invoked around every update.
        """
        super(TD3_Learner, self).__init__(config, policy, callback)
        # For macOS with Apple's M-series chips the legacy Keras Adam is used.
        # The `legacy` attribute is only touched on that platform, as before.
        if ("macOS" in self.os_name) and ("arm" in self.os_name):
            adam_cls = tk.optimizers.legacy.Adam
        else:
            adam_cls = tk.optimizers.Adam
        if self.distributed_training:
            # Build the optimizers under the strategy scope so any variables
            # they create belong to the mirrored strategy.
            with self.policy.mirrored_strategy.scope():
                self.optimizer = self._make_adam_optimizers(adam_cls, config)
        else:
            self.optimizer = self._make_adam_optimizers(adam_cls, config)
        self.tau = config.tau
        self.gamma = config.gamma
        self.actor_update_delay = config.actor_update_delay
        self.mse_loss = tk.losses.MeanSquaredError()

    @staticmethod
    def _make_adam_optimizers(adam_cls, config):
        """Return ``{'actor': ..., 'critic': ...}`` optimizers built with *adam_cls*."""
        return {'actor': adam_cls(config.learning_rate_actor),
                'critic': adam_cls(config.learning_rate_critic)}

    def _step_optimizer(self, optimizer, gradients, variables):
        """Apply *gradients* to *variables* with *optimizer*.

        Pairs whose gradient is ``None`` (variables not reached by the loss)
        are skipped. When ``self.use_grad_clip`` is set, each gradient tensor is
        clipped individually to L2 norm ``self.grad_clip_norm``.
        """
        if self.use_grad_clip:
            grads_and_vars = [(tf.clip_by_norm(grad, self.grad_clip_norm), var)
                              for grad, var in zip(gradients, variables)
                              if grad is not None]
        else:
            grads_and_vars = [(grad, var)
                              for grad, var in zip(gradients, variables)
                              if grad is not None]
        optimizer.apply_gradients(grads_and_vars)

    @tf.function
    def actor_forward_fn(self, obs_batch):
        """Take one actor step that maximizes the mean of ``policy.Qpolicy(obs_batch)``.

        Returns:
            The actor loss ``-mean(Qpolicy(obs_batch))``.
        """
        with tf.GradientTape() as tape:
            policy_q = self.policy.Qpolicy(obs_batch)
            p_loss = -tf.reduce_mean(policy_q)
        # Read the variable list after the forward pass: a lazily-built model
        # may have no variables before its first call.
        actor_vars = self.policy.actor_trainable_variables
        gradients = tape.gradient(p_loss, actor_vars)
        self._step_optimizer(self.optimizer['actor'], gradients, actor_vars)
        return p_loss

    @tf.function
    def critic_forward_fn(self, obs_batch, act_batch, rew_batch, next_batch, ter_batch):
        """Take one critic step, fitting both Q heads to the same target.

        The target is ``rew + gamma * (1 - ter) * Qtarget(next_obs)`` and is
        wrapped in ``stop_gradient``. How ``Qtarget`` combines the twin target
        critics (e.g. a minimum, and any target-policy noise) is defined by the
        policy, not here. NOTE(review): confirm in the policy.

        Returns:
            ``(q_loss, action_q_A, action_q_B)``. ``q_loss`` is the sum of the
            two MSE losses; each Q tensor is flattened to rank 1.
        """
        with tf.GradientTape() as tape:
            action_q_A, action_q_B = self.policy.Qaction(obs_batch, act_batch)
            action_q_A = tf.reshape(action_q_A, [-1])
            action_q_B = tf.reshape(action_q_B, [-1])
            next_q = tf.reshape(self.policy.Qtarget(next_batch), [-1])
            # NOTE(review): assumes rew_batch / ter_batch are rank-1 [batch] with
            # 0/1 terminal flags; a [batch, 1] shape would broadcast against the
            # flattened next_q. TODO confirm with the replay buffer.
            target_q = rew_batch + self.gamma * (1 - ter_batch) * next_q
            target_q = tf.stop_gradient(tf.reshape(target_q, [-1]))
            # NOTE(review): Keras 2 rejects AUTO / SUM_OVER_BATCH_SIZE loss
            # reduction inside a tf.distribute strategy's custom loop. Confirm
            # this path works when distributed_training is enabled.
            q_loss = self.mse_loss(target_q, action_q_A) + self.mse_loss(target_q, action_q_B)
        critic_vars = self.policy.critic_trainable_variables
        gradients = tape.gradient(q_loss, critic_vars)
        self._step_optimizer(self.optimizer['critic'], gradients, critic_vars)
        return q_loss, action_q_A, action_q_B

    @tf.function
    def learn_actor(self, *inputs):
        """Run :meth:`actor_forward_fn`, on every replica when distributed.

        In distributed mode the per-replica losses are SUM-reduced across
        replicas. NOTE(review): for a per-replica mean loss this reports
        roughly num_replicas x the loss; confirm whether MEAN is intended.
        """
        if self.distributed_training:
            p_loss = self.policy.mirrored_strategy.run(self.actor_forward_fn, args=inputs)
            return self.policy.mirrored_strategy.reduce(tf.distribute.ReduceOp.SUM, p_loss, axis=None)
        return self.actor_forward_fn(*inputs)

    @tf.function
    def learn_critic(self, *inputs):
        """Run :meth:`critic_forward_fn`, on every replica when distributed.

        In distributed mode the loss and both Q tensors are SUM-reduced across
        replicas (``axis=None``, so element-wise across replicas).
        NOTE(review): this likely inflates the logged Q values and loss by the
        replica count; confirm whether MEAN is intended.
        """
        if self.distributed_training:
            strategy = self.policy.mirrored_strategy
            q_loss, action_q_A, action_q_B = strategy.run(self.critic_forward_fn, args=inputs)
            return (strategy.reduce(tf.distribute.ReduceOp.SUM, q_loss, axis=None),
                    strategy.reduce(tf.distribute.ReduceOp.SUM, action_q_A, axis=None),
                    strategy.reduce(tf.distribute.ReduceOp.SUM, action_q_B, axis=None))
        return self.critic_forward_fn(*inputs)

    def update(self, **samples):
        """Perform one TD3 training iteration on a sampled batch.

        Args:
            **samples: Must contain ``'obs'``, ``'actions'``, ``'obs_next'``,
                ``'rewards'`` and ``'terminals'``.

        Returns:
            The dict from ``callback.on_update_start``, extended with
            ``"Qloss"``, ``"QvalueA"``, ``"QvalueB"`` and the entries returned
            by ``callback.on_update_end``. ``"Ploss"`` is added only on
            iterations where the actor is updated.
        """
        self.iterations += 1
        obs_batch = samples['obs']
        act_batch = samples['actions']
        next_batch = samples['obs_next']
        rew_batch = samples['rewards']
        ter_batch = samples['terminals']
        info = self.callback.on_update_start(self.iterations,
                                             policy=self.policy, obs=obs_batch, act=act_batch,
                                             next_obs=next_batch, rew=rew_batch, termination=ter_batch)

        # The critics are updated on every iteration.
        q_loss, action_q_A, action_q_B = self.learn_critic(obs_batch, act_batch, rew_batch, next_batch, ter_batch)

        # Delayed policy update: the actor and the target networks are updated
        # only once every `actor_update_delay` iterations.
        if self.iterations % self.actor_update_delay == 0:
            p_loss = self.learn_actor(obs_batch)
            self.policy.soft_update(self.tau)
            info["Ploss"] = p_loss.numpy()

        info.update({
            "Qloss": q_loss.numpy(),
            "QvalueA": tf.math.reduce_mean(action_q_A).numpy(),
            "QvalueB": tf.math.reduce_mean(action_q_B).numpy(),
        })
        info.update(self.callback.on_update_end(self.iterations, policy=self.policy, info=info,
                                                action_q_A=action_q_A, action_q_B=action_q_B, q_loss=q_loss))
        return info