import math
from numbers import Number
import torch
import numpy as np
from abc import ABC, abstractmethod
from torch.nn.functional import softplus, one_hot, softmax
from torch.distributions import Categorical, Bernoulli, Normal, constraints
from torch.distributions.utils import broadcast_all
from xuance.common import Callable
from xuance.torch import Tensor
from xuance.torch.utils.operations import sym_log, sym_exp
kl_div = torch.distributions.kl_divergence
[docs]
def split_distributions(distribution):
    """Splits a batch of distributions into individual instances.

    This function separates a batch of distributions (either `CategoricalDistribution`
    or `DiagGaussianDistribution`) into individual distribution objects, one for
    every index of the leading (non-action) axes of the parameters.

    Args:
        distribution (CategoricalDistribution or DiagGaussianDistribution): The input
            distribution batch to be split.

    Returns:
        np.ndarray: An array of individual distribution instances, reshaped to the
            parameter shape without its last axis.

    Raises:
        NotImplementedError: If the distribution type is not supported.
    """
    return_list = []
    if isinstance(distribution, CategoricalDistribution):
        shape = distribution.logits.shape
        action_dim = shape[-1]
        # reshape (rather than view) also accepts non-contiguous logits.
        for logit in distribution.logits.reshape(-1, action_dim):
            dist = CategoricalDistribution(action_dim)
            # Each piece keeps a leading axis of size 1; merge_distributions
            # concatenates the pieces along axis 0.
            dist.set_param(logits=logit.unsqueeze(0).detach())
            return_list.append(dist)
    elif isinstance(distribution, DiagGaussianDistribution):
        shape = distribution.mu.shape
        action_dim = shape[-1]
        means = distribution.mu.reshape(-1, action_dim)
        std = distribution.std
        # When std is batched exactly like mu, every item must receive its own
        # row; handing the whole batch to each item would not yield an
        # individual distribution. A std of any other shape (e.g. one vector
        # shared by the batch) is passed through unchanged, as before.
        per_item_std = std.shape == shape
        std_rows = std.reshape(-1, action_dim) if per_item_std else None
        for index, mu in enumerate(means):
            item_std = std_rows[index] if per_item_std else std
            dist = DiagGaussianDistribution(action_dim)
            dist.set_param(mu.detach(), item_std.detach())
            return_list.append(dist)
    else:
        raise NotImplementedError
    return np.array(return_list).reshape(shape[:-1])
[docs]
def merge_distributions(distribution_list):
    """Merges a list of individual distributions back into a batch distribution.

    This function reconstructs a batched distribution from a list (or array) of
    individual distributions, supporting both categorical and diagonal Gaussian
    distributions. The element type is decided from the first element.

    Args:
        distribution_list (list or np.ndarray): A list or array (of any number of
            dimensions) of individual distribution instances.

    Returns:
        CategoricalDistribution or DiagGaussianDistribution: A merged batch distribution.

    Raises:
        NotImplementedError: If the distribution type is not supported.
        IndexError: If ``distribution_list`` is empty.
    """
    # Flatten once so the same code serves lists and arrays of any rank.
    if isinstance(distribution_list, np.ndarray):
        shape = distribution_list.shape
        flat = distribution_list.reshape(-1)
    else:
        flat = list(distribution_list)
        shape = (len(flat),)

    first = flat[0]

    if isinstance(first, CategoricalDistribution):
        logits = torch.cat([dist.logits for dist in flat], dim=0)
        action_dim = logits.shape[-1]
        if len(shape) > 1:
            # Restore the array layout for multi-dimensional inputs; for
            # one-dimensional inputs the concatenated tensor is used as is.
            logits = logits.view(shape + (action_dim,))
        merged = CategoricalDistribution(action_dim)
        merged.set_param(logits=logits.detach())
        return merged

    if isinstance(first, DiagGaussianDistribution):
        mu = torch.cat([dist.mu for dist in flat], dim=0)
        std = torch.cat([dist.std for dist in flat], dim=0)
        action_dim = first.mu.shape[-1]
        merged = DiagGaussianDistribution(action_dim)
        # Unlike the categorical case, the parameters are not detached here.
        merged.set_param(mu.view(shape + (action_dim,)), std.view(shape + (action_dim,)))
        return merged

    raise NotImplementedError(
        f"Cannot merge distributions of type {type(first).__name__}."
    )
[docs]
class Distribution(ABC):
    """Abstract interface for the distribution wrappers defined in this module.

    Concrete subclasses must implement every method below. Each abstract body
    raises ``NotImplementedError`` so that a subclass delegating to ``super()``
    fails loudly.
    """

    def __init__(self):
        super().__init__()
        # Placeholder for the wrapped distribution object; starts out empty.
        self.distribution = None

    @abstractmethod
    def set_param(self, *args):
        """Set the parameters of the distribution."""
        raise NotImplementedError

    @abstractmethod
    def get_param(self):
        """Return the parameters of the distribution."""
        raise NotImplementedError

    @abstractmethod
    def log_prob(self, x: torch.Tensor):
        """Return the log-probability of ``x``."""
        raise NotImplementedError

    @abstractmethod
    def entropy(self):
        """Return the entropy of the distribution."""
        raise NotImplementedError

    @abstractmethod
    def stochastic_sample(self):
        """Draw a random sample."""
        raise NotImplementedError

    @abstractmethod
    def deterministic_sample(self):
        """Return a deterministic representative sample."""
        raise NotImplementedError
[docs]
class CategoricalDistribution(Distribution):
    """Categorical distribution wrapper around ``torch.distributions.Categorical``.

    Args:
        action_dim (int): Number of categories; stored as ``self.action_dim``.
    """

    def __init__(self, action_dim: int):
        super(CategoricalDistribution, self).__init__()
        self.action_dim = action_dim
        self.probs, self.logits = None, None

    def set_param(self, probs=None, logits=None):
        """Build the underlying ``Categorical`` from ``probs`` or ``logits``.

        Args:
            probs: Category probabilities, or None.
            logits: Category logits, or None.

        Raises:
            RuntimeError: If neither ``probs`` nor ``logits`` is given.
        """
        if probs is None and logits is None:
            raise RuntimeError("Failed to setup distributions without given probs or logits.")
        # Both arguments are forwarded as given; validation of the combination
        # is left to Categorical itself.
        self.distribution = Categorical(probs=probs, logits=logits)
        self.probs = self.distribution.probs
        self.logits = self.distribution.logits

    def get_param(self):
        """Return the logits of the distribution."""
        return self.logits

    def log_prob(self, x):
        """Return the log-probability of the category indices ``x``."""
        return self.distribution.log_prob(x)

    def entropy(self):
        """Return the entropy of the distribution."""
        return self.distribution.entropy()

    def stochastic_sample(self):
        """Draw random category indices."""
        return self.distribution.sample()

    def deterministic_sample(self):
        """Return the most probable category index.

        The arg-max is taken over the last axis, which is the category axis of
        ``probs``. For two-dimensional ``probs`` this equals the former
        ``dim=1``; it additionally works for one-dimensional parameters and
        parameters with more than one leading axis.
        """
        return torch.argmax(self.distribution.probs, dim=-1)

    def kl_divergence(self, other: Distribution):
        """Return KL(self || other) for another ``CategoricalDistribution``."""
        assert isinstance(other,
                          CategoricalDistribution), "KL Divergence should be measured by two same distribution with the same type"
        return kl_div(self.distribution, other.distribution)
[docs]
class DiagGaussianDistribution(Distribution):
    """Gaussian with independent dimensions, backed by ``torch.distributions.Normal``.

    Args:
        action_dim (int): Stored as ``self.action_dim``.
    """

    def __init__(self, action_dim: int):
        super().__init__()
        self.mu = None
        self.std = None
        self.action_dim = action_dim

    def set_param(self, mu, std):
        """Store ``mu`` and ``std`` and build the underlying ``Normal``."""
        self.mu, self.std = mu, std
        self.distribution = Normal(mu, std)

    def get_param(self):
        """Return the ``(mu, std)`` pair."""
        return self.mu, self.std

    def log_prob(self, x):
        """Return the log-density of ``x``, summed over the last axis."""
        per_dim = self.distribution.log_prob(x)
        return per_dim.sum(-1)

    def entropy(self):
        """Return the entropy, summed over the last axis."""
        per_dim = self.distribution.entropy()
        return per_dim.sum(-1)

    def stochastic_sample(self):
        """Draw a sample via ``Normal.sample``."""
        return self.distribution.sample()

    def rsample(self):
        """Draw a sample via ``Normal.rsample``."""
        return self.distribution.rsample()

    def deterministic_sample(self):
        """Return the mean ``mu``."""
        return self.mu

    def kl_divergence(self, other: Distribution):
        """Return ``kl_div`` between the two underlying ``Normal`` objects.

        No reduction is applied here to the value returned by ``kl_div``.
        """
        assert isinstance(other,
                          DiagGaussianDistribution), "KL Divergence should be measured by two same distribution with the same type"
        return kl_div(self.distribution, other.distribution)
[docs]
class ActivatedDiagGaussianDistribution(DiagGaussianDistribution):
    """Diagonal Gaussian whose samples are passed through an activation function.

    Args:
        action_dim (int): Forwarded to ``DiagGaussianDistribution``.
        activation_action: Zero-argument callable; it is called once here and
            the result is stored as ``self.activation_fn``.
        device: Stored as ``self.device``.
    """

    def __init__(self, action_dim: int, activation_action, device):
        super(ActivatedDiagGaussianDistribution, self).__init__(action_dim)
        self.activation_fn = activation_action()
        self.device = device

    def activated_rsample(self):
        """Return a reparameterized sample passed through the activation."""
        return self.activation_fn(self.rsample())

    def activated_rsample_and_logprob(self):
        """Return an activated reparameterized sample and its log-probability.

        Returns:
            tuple: ``(activated_sample, log_prob)`` where ``log_prob`` is summed
            over the last axis.
        """
        act_pre_activated = self.rsample()  # sample without being activated.
        act_activated = self.activation_fn(act_pre_activated)
        log_prob = self.distribution.log_prob(act_pre_activated)
        # 2 * (log 2 - u - softplus(-2u)) equals log(1 - tanh(u)^2), so this
        # subtracts the log-Jacobian of tanh.
        # NOTE(review): the correction is specific to tanh; confirm that
        # activation_action is always a tanh module.
        # log 2 is a plain Python scalar, so no tensor is allocated or moved to
        # the device on each call.
        correction = -2. * (math.log(2.0) - act_pre_activated - softplus(-2. * act_pre_activated))
        log_prob = log_prob + correction
        return act_activated, log_prob.sum(-1)
[docs]
class SymLogDistribution:
    """Distance-based pseudo-distribution whose stored mode lives in symlog space.

    Args:
        mode: Prediction tensor; compared against ``sym_log(value)`` in ``log_prob``.
        dims: Number of trailing axes treated as event axes.
        dist: Distance type, ``"mse"`` or ``"abs"``.
        agg: Aggregation over the event axes, ``"sum"`` or ``"mean"``.
        tol: Distances below this threshold are replaced by 0.
    """

    def __init__(self,
                 mode: Tensor,
                 dims: int,
                 dist: str = "mse",
                 agg: str = "sum",
                 tol: float = 1e-8):
        self._mode = mode
        # Negative indices of the trailing event axes: (-1, -2, ..., -dims).
        self._dims = tuple(-axis for axis in range(1, dims + 1))
        self._dist = dist
        self._agg = agg
        self._tol = tol
        split_at = len(mode.shape) - dims
        self._batch_shape = mode.shape[:split_at]
        self._event_shape = mode.shape[split_at:]

    @property
    def mode(self) -> Tensor:
        return sym_exp(self._mode)

    @property
    def mean(self) -> Tensor:
        return sym_exp(self._mode)

    def log_prob(self, value: Tensor) -> Tensor:
        """Computes the log probability of a value under this distribution.

        Args:
            value: The observed value to evaluate; it is passed through
                ``sym_log`` before being compared with the stored mode.

        Returns:
            The negated distance, aggregated over the event dimensions.

        Raises:
            AssertionError: If value shape does not match the stored mode's shape.
            NotImplementedError: If the distance or aggregation method is unknown.
        """
        assert self._mode.shape == value.shape, (self._mode.shape, value.shape)
        if self._dist not in ("mse", "abs"):
            raise NotImplementedError(self._dist)

        residual = self._mode - sym_log(value)
        if self._dist == "mse":
            distance = residual ** 2
        else:
            distance = torch.abs(residual)
        # Snap tiny distances to exactly zero.
        distance = torch.where(distance < self._tol, 0, distance)

        if self._agg == "mean":
            return -distance.mean(self._dims)
        if self._agg == "sum":
            return -distance.sum(self._dims)
        raise NotImplementedError(self._agg)
[docs]
class MSEDistribution:
    """Pseudo-distribution whose log-probability is the negative squared error.

    Args:
        mode: Prediction tensor, returned unchanged by ``mode`` and ``mean``.
        dims: Number of trailing axes treated as event axes.
        agg: Aggregation over the event axes, ``"sum"`` or ``"mean"``.
    """

    def __init__(self, mode: Tensor, dims: int, agg: str = "sum"):
        self._mode = mode
        # Negative indices of the trailing event axes: (-1, -2, ..., -dims).
        self._dims = tuple(-axis for axis in range(1, dims + 1))
        self._agg = agg
        split_at = len(mode.shape) - dims
        self._batch_shape = mode.shape[:split_at]
        self._event_shape = mode.shape[split_at:]

    @property
    def mode(self) -> Tensor:
        return self._mode

    @property
    def mean(self) -> Tensor:
        return self._mode

    def log_prob(self, value: Tensor) -> Tensor:
        """Return the negated squared error to ``value``, aggregated over event axes.

        Raises:
            AssertionError: If ``value`` does not have the shape of the mode.
            NotImplementedError: If the aggregation method is unknown.
        """
        assert self._mode.shape == value.shape, (self._mode.shape, value.shape)
        squared_error = (self._mode - value) ** 2
        if self._agg == "sum":
            return -squared_error.sum(self._dims)
        if self._agg == "mean":
            return -squared_error.mean(self._dims)
        raise NotImplementedError(self._agg)
[docs]
class TwoHotEncodingDistribution:
    """Distribution over evenly spaced bins, scored against a two-hot target.

    Args:
        logits: Unnormalized scores; the last axis indexes the bins.
        dims: Number of trailing axes reduced in ``mean``, ``mode`` and ``log_prob``.
        low: Value of the first bin.
        high: Value of the last bin.
        transfwd: Transform applied to targets inside ``log_prob``.
        transbwd: Transform applied to the expected bin value in ``mean``/``mode``.
    """

    def __init__(
        self,
        logits: Tensor,
        dims: int = 0,
        low: int = -20,
        high: int = 20,
        transfwd: Callable[[Tensor], Tensor] = sym_log,
        transbwd: Callable[[Tensor], Tensor] = sym_exp,
    ):
        self.logits = logits
        self.probs = softmax(logits, dim=-1)
        # Negative indices of the trailing axes: (-1, -2, ..., -dims).
        self.dims = tuple(-axis for axis in range(1, dims + 1))
        # One bin per logit, evenly spaced on [low, high].
        self.bins = torch.linspace(low, high, logits.shape[-1], device=logits.device)
        self.low = low
        self.high = high
        self.transfwd = transfwd
        self.transbwd = transbwd
        split_at = len(logits.shape) - dims
        self._batch_shape = logits.shape[:split_at]
        self._event_shape = logits.shape[split_at:-1] + (1,)

    @property
    def mean(self) -> Tensor:
        expected_bin = (self.probs * self.bins).sum(dim=self.dims, keepdim=True)
        return self.transbwd(expected_bin)

    @property
    def mode(self) -> Tensor:
        expected_bin = (self.probs * self.bins).sum(dim=self.dims, keepdim=True)
        return self.transbwd(expected_bin)

    def log_prob(self, x: Tensor) -> Tensor:
        """Return the cross-entropy style score of ``x`` against the logits.

        ``x`` is transformed with ``transfwd`` and spread over the two
        neighbouring bins, each weighted by the distance to the other one.
        """
        num_bins = len(self.bins)
        x = self.transfwd(x)

        # Index of the last bin that is <= x; -1 when x is below every bin.
        lower = (self.bins <= x).to(torch.int32).sum(dim=-1, keepdim=True) - 1
        # Neighbouring bin above, then both indices limited to [0, num_bins - 1].
        upper = (lower + 1).clamp(max=num_bins - 1)
        lower = lower.clamp(min=0)

        # After clamping, equal indices mean x lies outside the bin range;
        # the unit distances then split the weight evenly onto that single bin.
        same_bin = lower == upper
        gap_to_lower = torch.where(same_bin, 1, torch.abs(self.bins[lower] - x))
        gap_to_upper = torch.where(same_bin, 1, torch.abs(self.bins[upper] - x))
        gap_sum = gap_to_lower + gap_to_upper
        lower_weight = gap_to_upper / gap_sum
        upper_weight = gap_to_lower / gap_sum

        lower_part = one_hot(lower, num_bins) * lower_weight[..., None]
        upper_part = one_hot(upper, num_bins) * upper_weight[..., None]
        target = (lower_part + upper_part).squeeze(-2)

        log_pred = self.logits - torch.logsumexp(self.logits, dim=-1, keepdims=True)
        return (target * log_pred).sum(dim=self.dims)
[docs]
class BernoulliSafeMode(Bernoulli):
    """Bernoulli distribution whose ``mode`` is ``probs`` thresholded at 0.5."""

    def __init__(self, probs=None, logits=None, validate_args=None):
        super().__init__(probs=probs, logits=logits, validate_args=validate_args)

    @property
    def mode(self):
        # Strict comparison: a probability of exactly 0.5 maps to 0.
        above_half = self.probs > 0.5
        # Cast the boolean mask to the dtype and device of probs.
        return above_half.to(self.probs)
# Numeric constants used by the truncated-normal classes below.
CONST_SQRT_2 = math.sqrt(2)  # sqrt(2)
CONST_INV_SQRT_2PI = 1 / math.sqrt(2 * math.pi)  # 1 / sqrt(2*pi)
CONST_INV_SQRT_2 = 1 / math.sqrt(2)  # 1 / sqrt(2)
CONST_LOG_INV_SQRT_2PI = math.log(CONST_INV_SQRT_2PI)  # log(1 / sqrt(2*pi))
CONST_LOG_SQRT_2PI_E = 0.5 * math.log(2 * math.pi * math.e)  # log(sqrt(2*pi*e))
[docs]
class TruncatedStandardNormal(torch.distributions.Distribution):
    """
    Truncated Standard Normal distribution
    https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf

    Args:
        a: Lower truncation bound (number or tensor).
        b: Upper truncation bound (number or tensor); must be greater than ``a``
            element-wise.
        validate_args: Forwarded to ``torch.distributions.Distribution``.

    Raises:
        ValueError: If the bounds have different dtypes or ``a >= b`` anywhere.
    """

    arg_constraints = {
        "a": constraints.real,
        "b": constraints.real,
    }
    has_rsample = True

    def __init__(self, a, b, validate_args=None):
        self.a, self.b = broadcast_all(a, b)
        if isinstance(a, Number) and isinstance(b, Number):
            batch_shape = torch.Size()
        else:
            batch_shape = self.a.size()
        super(TruncatedStandardNormal, self).__init__(batch_shape, validate_args=validate_args)
        if self.a.dtype != self.b.dtype:
            raise ValueError("Truncation bounds types are different")
        if any((self.a >= self.b).view(-1).tolist()):
            raise ValueError("Incorrect truncation range")
        eps = torch.finfo(self.a.dtype).eps
        # Bounds of the uniform draw in rsample, kept away from 0 and 1.
        self._dtype_min_gt_0 = eps
        self._dtype_max_lt_1 = 1 - eps
        # Standard normal pdf and cdf evaluated at both bounds.
        self._little_phi_a = self._little_phi(self.a)
        self._little_phi_b = self._little_phi(self.b)
        self._big_phi_a = self._big_phi(self.a)
        self._big_phi_b = self._big_phi(self.b)
        # Probability mass between the bounds, clamped so that the log and the
        # divisions below stay finite.
        self._Z = (self._big_phi_b - self._big_phi_a).clamp_min(eps)
        self._log_Z = self._Z.log()
        # nan_to_num replaces +/-inf bounds by large finite values (NaN is kept
        # as NaN), so pdf(bound) * bound does not become 0 * inf.
        little_phi_coeff_a = torch.nan_to_num(self.a, nan=math.nan)
        little_phi_coeff_b = torch.nan_to_num(self.b, nan=math.nan)
        self._lpbb_m_lpaa_d_Z = (
            self._little_phi_b * little_phi_coeff_b - self._little_phi_a * little_phi_coeff_a
        ) / self._Z
        self._mean = -(self._little_phi_b - self._little_phi_a) / self._Z
        self._variance = 1 - self._lpbb_m_lpaa_d_Z - ((self._little_phi_b - self._little_phi_a) / self._Z) ** 2
        self._entropy = CONST_LOG_SQRT_2PI_E + self._log_Z - 0.5 * self._lpbb_m_lpaa_d_Z

    @constraints.dependent_property
    def support(self):
        return constraints.interval(self.a, self.b)

    @property
    def mean(self):
        return self._mean

    @property
    def variance(self):
        return self._variance

    @property
    def auc(self):
        """Probability mass of the untruncated normal between the bounds."""
        return self._Z

    @staticmethod
    def _little_phi(x):
        """Standard normal density at ``x``."""
        return (-(x**2) * 0.5).exp() * CONST_INV_SQRT_2PI

    @staticmethod
    def _big_phi(x):
        """Standard normal cumulative distribution function at ``x``."""
        return 0.5 * (1 + (x * CONST_INV_SQRT_2).erf())

    @staticmethod
    def _inv_big_phi(x):
        """Inverse of ``_big_phi``."""
        return CONST_SQRT_2 * (2 * x - 1).erfinv()

    def cdf(self, value):
        """Return the cumulative distribution function at ``value``, clamped to [0, 1]."""
        if self._validate_args:
            self._validate_sample(value)
        return ((self._big_phi(value) - self._big_phi_a) / self._Z).clamp(0, 1)

    def icdf(self, value):
        """Return the inverse cumulative distribution function at ``value``."""
        return self._inv_big_phi(self._big_phi_a + value * self._Z)

    def log_prob(self, value):
        """Return the log-density at ``value``."""
        if self._validate_args:
            self._validate_sample(value)
        return CONST_LOG_INV_SQRT_2PI - self._log_Z - (value**2) * 0.5

    def rsample(self, sample_shape=torch.Size()):
        """Draw a reparameterized sample by inverse-CDF sampling."""
        shape = self._extended_shape(sample_shape)
        # Draw the uniform variates in the dtype of the bounds. The limits
        # eps and 1 - eps come from that dtype's finfo; in a narrower default
        # dtype 1 - eps would round to 1.0.
        p = torch.empty(shape, dtype=self.a.dtype, device=self.a.device).uniform_(
            self._dtype_min_gt_0, self._dtype_max_lt_1
        )
        return self.icdf(p)

    def entropy(self):
        """Return the entropy computed at construction time."""
        return self._entropy
[docs]
class TruncatedNormal(TruncatedStandardNormal):
    """
    Truncated Normal distribution
    https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf

    The bounds ``a`` and ``b`` are standardized with ``loc`` and ``scale`` and
    passed to ``TruncatedStandardNormal``; mean, variance and entropy are then
    adjusted for location and scale.
    """

    has_rsample = True

    def __init__(self, loc, scale, a, b, validate_args=None):
        self.loc, self.scale, lower, upper = broadcast_all(loc, scale, a, b)
        # Express the truncation bounds in standard-normal units.
        std_lower = (lower - self.loc) / self.scale
        std_upper = (upper - self.loc) / self.scale
        super().__init__(std_lower, std_upper, validate_args=validate_args)
        self._log_scale = self.scale.log()
        # Map the standardized moments back to the original location and scale.
        self._mean = self._mean * self.scale + self.loc
        self._variance = self._variance * self.scale**2
        self._entropy += self._log_scale

    def _to_std_rv(self, value):
        """Map ``value`` to standard-normal units."""
        return (value - self.loc) / self.scale

    def _from_std_rv(self, value):
        """Map a standardized ``value`` back to the original units."""
        return value * self.scale + self.loc

    def cdf(self, value):
        standardized = self._to_std_rv(value)
        return super().cdf(standardized)

    def icdf(self, value):
        standardized = super().icdf(value)
        return self._from_std_rv(standardized)

    def log_prob(self, value):
        standardized = self._to_std_rv(value)
        # Change of variables: subtract log(scale) from the standardized log-density.
        return super().log_prob(standardized) - self._log_scale