Source code for xuance.tensorflow.utils.distributions

import numpy as np
from abc import abstractmethod
from tensorflow.keras.activations import softplus
from xuance.tensorflow import tf, Tensor


def split_distributions(distribution):
    """Split one batched distribution into an array of per-sample distributions.

    The last axis of the parameters is treated as the action axis; all leading
    axes are flattened, one distribution object is built per row, and the
    objects are returned in a NumPy object array shaped like the leading axes.

    Args:
        distribution: A ``CategoricalDistribution`` or a
            ``DiagGaussianDistribution`` whose parameters have been set.

    Returns:
        ``np.ndarray`` of dtype ``object`` with shape ``param.shape[:-1]``,
        where ``param`` is ``logits`` (categorical) or ``mu`` (Gaussian).

    Raises:
        NotImplementedError: If ``distribution`` is of any other type.
    """
    return_list = []
    if isinstance(distribution, CategoricalDistribution):
        shape = distribution.logits.shape
        action_dim = shape[-1]
        logits = tf.reshape(distribution.logits, [-1, action_dim])
        for logit in logits:
            dist = CategoricalDistribution(action_dim)
            # Each piece keeps a leading axis of size 1 and is detached from
            # the gradient path of the source logits.
            dist.set_param(logits=tf.stop_gradient(tf.expand_dims(logit, 0)))
            return_list.append(dist)
    elif isinstance(distribution, DiagGaussianDistribution):
        shape = distribution.mu.shape
        action_dim = shape[-1]
        means = tf.reshape(distribution.mu, [-1, action_dim])
        std = distribution.std
        # A std with the same (rank > 1) shape as mu is per-sample, so it must
        # be split row by row together with mu. Any other std (e.g. a single
        # vector shared by the whole batch) is handed to every piece unchanged.
        per_sample_std = len(std.shape) > 1 and tuple(std.shape) == tuple(shape)
        if per_sample_std:
            stds = tf.reshape(std, [-1, action_dim])
            for mu, sigma in zip(means, stds):
                dist = DiagGaussianDistribution(action_dim)
                dist.set_param(mu, sigma)
                return_list.append(dist)
        else:
            for mu in means:
                dist = DiagGaussianDistribution(action_dim)
                dist.set_param(mu, std)
                return_list.append(dist)
    else:
        raise NotImplementedError
    # dtype=object keeps NumPy from trying to interpret the elements; the
    # target shape is passed as a plain tuple of ints.
    return np.array(return_list, dtype=object).reshape(tuple(shape[:-1]))
def merge_distributions(distribution_list):
    """Merge a collection of distributions back into one batched distribution.

    Args:
        distribution_list: A list or NumPy object array (any number of
            dimensions) whose elements are all ``CategoricalDistribution`` or
            all ``DiagGaussianDistribution``. The type is decided from the
            first element.

    Returns:
        A single distribution of the same type.

        * Categorical, list / 1-D array: the logits are concatenated along
          axis 0 without any further reshaping.
        * Categorical, array with more than one dimension: the concatenated
          logits are reshaped to ``array.shape + (action_dim,)``.
        * Gaussian: ``mu`` and ``std`` are concatenated along axis 0 and
          reshaped to ``array.shape + (action_dim,)``.

        Categorical logits are passed through ``tf.stop_gradient``; Gaussian
        parameters are not.

    Raises:
        NotImplementedError: If the elements are of an unsupported type.
        IndexError: If ``distribution_list`` is empty.
    """
    # Normalise lists and arrays of any rank to one object array so that a
    # single code path handles all of them.
    dist_array = np.asarray(distribution_list, dtype=object)
    shape = dist_array.shape
    flat = dist_array.reshape(-1)
    first = flat[0]

    if isinstance(first, CategoricalDistribution):
        logits = tf.concat([dist.logits for dist in flat], axis=0)
        action_dim = logits.shape[-1]
        if dist_array.ndim > 1:
            # Restore the grid layout of the input array.
            logits = tf.reshape(logits, shape + (action_dim,))
        merged = CategoricalDistribution(action_dim)
        merged.set_param(logits=tf.stop_gradient(logits))
        return merged

    if isinstance(first, DiagGaussianDistribution):
        mu = tf.concat([dist.mu for dist in flat], axis=0)
        std = tf.concat([dist.std for dist in flat], axis=0)
        action_dim = first.mu.shape[-1]
        merged = DiagGaussianDistribution(action_dim)
        merged.set_param(mu=tf.reshape(mu, shape + (action_dim,)),
                         std=tf.reshape(std, shape + (action_dim,)))
        return merged

    # Fail loudly instead of returning None, mirroring split_distributions.
    raise NotImplementedError(
        f"Cannot merge distributions of type {type(first).__name__}.")
class Distribution:
    """Common interface of the action distributions defined in this module.

    The constructor does nothing. Every other method here raises
    ``NotImplementedError`` and is meant to be overridden by subclasses.
    """

    def __init__(self):
        return None

    @abstractmethod
    def set_param(self, *args):
        """Store the parameters that define the distribution."""
        raise NotImplementedError()

    @abstractmethod
    def get_param(self):
        """Return the stored parameters."""
        raise NotImplementedError()

    @abstractmethod
    def log_prob(self, x: Tensor):
        """Return the log-probability (density) of ``x``."""
        raise NotImplementedError()

    @abstractmethod
    def entropy(self):
        """Return the entropy of the distribution."""
        raise NotImplementedError()

    @abstractmethod
    def stochastic_sample(self):
        """Draw a random sample."""
        raise NotImplementedError()

    @abstractmethod
    def deterministic_sample(self):
        """Return a sample chosen without randomness."""
        raise NotImplementedError()
class CategoricalDistribution(Distribution):
    """Categorical distribution over ``action_dim`` discrete actions.

    The distribution is parameterised along the last axis of ``logits`` /
    ``probs``; all leading axes are treated as batch axes.
    """

    def __init__(self, action_dim: int):
        super(CategoricalDistribution, self).__init__()
        self.probs, self.logits = None, None
        self.action_dim = action_dim

    def set_param(self, probs=None, logits=None):
        """Set the distribution from probabilities or from logits.

        Args:
            probs: Non-negative weights; they are normalised over the last
                axis. Takes precedence over ``logits`` when both are given.
            logits: Unnormalised log-probabilities.

        Raises:
            RuntimeError: If neither ``probs`` nor ``logits`` is given.
        """
        if probs is not None:
            self.probs = probs / tf.reduce_sum(probs, axis=-1, keepdims=True)
            # Categorical logits are the log of the normalised probabilities,
            # so that softmax(logits) reproduces self.probs. (The binary
            # logit log(p) - log1p(-p) does not have this property.)
            self.logits = tf.math.log(self.probs)
        elif logits is not None:
            self.logits = logits
            self.probs = tf.nn.softmax(logits, axis=-1)
        else:
            raise RuntimeError("Either probs or logits must be specified.")

    def get_param(self):
        """Return ``probs`` if it has been set, otherwise ``logits``."""
        # An explicit None check is required: the truth value of a
        # multi-element tensor is ambiguous, so `probs or logits` cannot be used.
        return self.probs if self.probs is not None else self.logits

    def log_prob(self, x):
        """Return the log-probability of the action indices ``x``.

        ``x`` is cast to int32 and used to index the last axis of the
        log-softmax of ``logits``; all leading axes of ``logits`` are treated
        as batch axes.
        """
        x = tf.expand_dims(tf.cast(x, dtype=tf.int32), -1)
        log_probs = tf.nn.log_softmax(self.logits)
        # Gather along the last (action) axis whatever the number of leading
        # batch axes; for 2-D logits this is batch_dims=1.
        batch_dims = len(log_probs.shape) - 1
        y = tf.gather(log_probs, x, batch_dims=batch_dims)
        y = tf.squeeze(y, axis=-1)
        return y

    def entropy(self):
        """Return the entropy, keeping a trailing axis of size 1."""
        log_probs = tf.nn.log_softmax(self.logits)
        e = -tf.reduce_sum(self.probs * log_probs, axis=-1, keepdims=True)
        return e

    def stochastic_sample(self):
        """Sample one action index per batch entry.

        Returns:
            Integer tensor shaped like ``logits`` without its last axis.
        """
        logits = tf.stop_gradient(self.logits)
        # tf.random.categorical only accepts 2-D logits, so flatten the
        # leading axes for sampling and restore them afterwards.
        flat_logits = tf.reshape(logits, [-1, tf.shape(logits)[-1]])
        sampled_actions = tf.random.categorical(flat_logits, num_samples=1)
        return tf.reshape(sampled_actions, tf.shape(logits)[:-1])

    def deterministic_sample(self):
        """Return the most probable action index along the last axis."""
        return tf.argmax(self.probs, axis=-1)

    def kl_divergence(self, other: Distribution):
        """Return KL(self || other), summed over the action axis."""
        assert isinstance(other, CategoricalDistribution), "KL Divergence should be measured by two same distribution with the same type"
        log_p = tf.nn.log_softmax(self.logits, axis=-1)  # log P(a)
        log_q = tf.nn.log_softmax(other.logits, axis=-1)  # log Q(a)
        p = tf.math.exp(log_p)  # P(a)
        kl = tf.reduce_sum(p * (log_p - log_q), axis=-1)
        return kl
class DiagGaussianDistribution(Distribution):
    """Gaussian distribution with independent dimensions (diagonal covariance).

    ``mu`` and ``std`` are stored exactly as given to ``set_param``; the last
    axis is reduced over by ``log_prob``, ``entropy`` and ``kl_divergence``.
    """

    def __init__(self, action_dim: int):
        super().__init__()
        self.mu = None
        self.std = None
        self.action_dim = action_dim

    def set_param(self, mu, std):
        """Store the mean ``mu`` and standard deviation ``std``."""
        self.mu, self.std = mu, std

    def get_param(self):
        """Return the tuple ``(mu, std)``."""
        return self.mu, self.std

    def log_prob(self, x):
        """Return the log-density of ``x``, summed over the last axis."""
        # 1e-8 keeps the division and the logarithm finite when std is 0.
        safe_std = self.std + 1e-8
        standardized = (x - self.mu) / safe_std
        per_dim = -0.5 * (standardized ** 2
                          + 2.0 * tf.math.log(safe_std)
                          + tf.math.log(2.0 * np.pi))
        return tf.reduce_sum(per_dim, axis=-1, keepdims=False)

    def entropy(self):
        """Return the entropy summed over the last axis (size-1 axis kept)."""
        per_dim = 0.5 + 0.5 * tf.math.log(2.0 * np.pi) + tf.math.log(self.std + 1e-8)
        return tf.reduce_sum(per_dim, axis=-1, keepdims=True)

    def stochastic_sample(self):
        """Draw a sample as ``mu + std * eps`` with ``eps ~ N(0, 1)``."""
        noise = tf.random.normal(shape=tf.shape(self.mu))
        return self.mu + self.std * noise

    def deterministic_sample(self):
        """Return the mean ``mu``."""
        return self.mu

    def kl_divergence(self, other: Distribution):
        """Return KL(self || other), summed over the last axis."""
        assert isinstance(other, DiagGaussianDistribution), "KL Divergence should be measured by two same distribution with the same type"
        own_var = tf.square(self.std)
        other_var = tf.square(other.std)
        log_ratio = tf.math.log(other.std / self.std)
        mean_gap_sq = tf.square(self.mu - other.mu)
        per_dim = log_ratio + (own_var + mean_gap_sq) / (2.0 * other_var) - 0.5
        return tf.reduce_sum(per_dim, axis=-1)
class ActivatedDiagGaussianDistribution(DiagGaussianDistribution):
    """Diagonal Gaussian whose samples are passed through an activation.

    Args:
        action_dim: Stored as ``self.action_dim`` by the parent constructor.
        activation_action: Callable applied to raw Gaussian samples.
    """

    def __init__(self, action_dim: int, activation_action):
        super(ActivatedDiagGaussianDistribution, self).__init__(action_dim)
        self.activation_fn = activation_action

    def activated_rsample(self):
        """Return a reparameterised sample passed through the activation."""
        return self.activation_fn(self.stochastic_sample())

    def activated_rsample_and_logprob(self):
        """Sample an activated action and the log-probability of that sample.

        Returns:
            Tuple ``(action, log_prob)``. ``action`` is the activated sample;
            ``log_prob`` has the shape of the sample without its last axis.

        NOTE: the correction term equals ``-log(1 - tanh(u)^2)`` written in a
        numerically stable form, so the returned log-probability is only
        meaningful when ``activation_fn`` is tanh — confirm against callers.
        """
        act_pre_activated = self.stochastic_sample()  # sample without being activated.
        act_activated = self.activation_fn(act_pre_activated)
        # self.log_prob already sums over the action axis.
        log_prob = self.log_prob(act_pre_activated)
        # Per-dimension change-of-variables term; it has to be summed over
        # the action axis before being combined with the summed log_prob,
        # otherwise the two tensors broadcast against each other.
        correction = - 2. * (tf.math.log(2.0) - act_pre_activated - softplus(-2. * act_pre_activated))
        log_prob = log_prob + tf.math.reduce_sum(correction, axis=-1)
        return act_activated, log_prob