Source code for xuance.mindspore.utils.distributions

import numpy as np
from abc import ABC, abstractmethod
from mindspore.nn.probability.distribution import Categorical, Normal
from xuance.mindspore import ms, ops, Tensor


def split_distributions(distribution):
    """Split a batched distribution into an object array of single-row distributions.

    Every axis of the distribution's parameter except the last is treated as a
    batch axis. The parameter is flattened to ``(-1, last_dim)``, one new
    distribution is built per row, and the result is reshaped back to the
    batch axes.

    Args:
        distribution: a ``CategoricalDistribution`` or ``DiagGaussianDistribution``
            whose parameters were already set via ``set_param``.

    Returns:
        np.ndarray: object array of shape ``param.shape[:-1]`` holding the new
        distributions. Their parameters are wrapped in ``ops.stop_gradient`` in
        both branches, so nothing flows back into ``distribution``.

    Raises:
        NotImplementedError: if ``distribution`` is of any other type.
    """
    return_list = []
    if isinstance(distribution, CategoricalDistribution):
        shape = distribution.logits.shape
        logits = distribution.logits.view(-1, shape[-1])
        for logit in logits:
            dist = CategoricalDistribution(logit.shape[-1])
            # Each split distribution keeps a leading axis of size 1.
            dist.set_param(logits=ops.stop_gradient(logit.unsqueeze(0)))
            return_list.append(dist)
    elif isinstance(distribution, DiagGaussianDistribution):
        shape = distribution.mu.shape
        means = distribution.mu.view(-1, shape[-1])
        std = distribution.std
        std_shape = getattr(std, "shape", None)
        if std_shape is not None and tuple(std_shape) == tuple(shape):
            # Per-sample std: split it row-wise in lockstep with mu. Otherwise a
            # single row would be paired with the std of the whole batch.
            stds = std.view(-1, shape[-1])
        else:
            # Shared std (e.g. one std for all samples): reuse it for every row.
            stds = [std] * means.shape[0]
        for mu, sd in zip(means, stds):
            dist = DiagGaussianDistribution(shape[-1])
            # Detach like the categorical branch does, so both branches agree.
            dist.set_param(ops.stop_gradient(mu), ops.stop_gradient(sd))
            return_list.append(dist)
    else:
        raise NotImplementedError
    return np.array(return_list).reshape(shape[:-1])
def merge_distributions(distribution_list):
    """Merge a collection of distributions back into one batched distribution.

    This is the counterpart of ``split_distributions``. The input may be a plain
    list or a NumPy object array of any shape, including the multi-dimensional
    or 0-d arrays that ``split_distributions`` returns. The distributions are
    assumed to all share the type of the first one; this is not checked.

    Args:
        distribution_list: list or object array of ``CategoricalDistribution``
            or ``DiagGaussianDistribution`` instances.

    Returns:
        CategoricalDistribution: the members' logits concatenated along axis 0,
        wrapped in ``ops.stop_gradient``.
        DiagGaussianDistribution: the members' mu/std concatenated and reshaped
        to ``batch_shape + (action_dim,)``, where ``batch_shape`` is the shape
        of the input collection.

    Raises:
        ValueError: if ``distribution_list`` is empty.
        NotImplementedError: if the first element is of an unsupported type.
    """
    # Normalize to a flat object array. This way a multi-dimensional array
    # dispatches on an actual distribution (not a row), and plain lists also
    # get a `.shape`.
    distribution_array = np.asarray(distribution_list, dtype=object)
    shape = distribution_array.shape
    flat = distribution_array.reshape([-1])
    if flat.size == 0:
        raise ValueError("Cannot merge an empty collection of distributions.")
    first = flat[0]

    if isinstance(first, CategoricalDistribution):
        logits = ops.concat([dist.logits for dist in flat], 0)
        action_dim = logits.shape[-1]
        dist = CategoricalDistribution(action_dim)
        dist.set_param(logits=ops.stop_gradient(logits))
        return dist
    if isinstance(first, DiagGaussianDistribution):
        mu = ops.concat([dist.mu for dist in flat], 0)
        std = ops.concat([dist.std for dist in flat], 0)
        action_dim = first.mu.shape[-1]
        dist = DiagGaussianDistribution(action_dim)
        # Restore the batch layout of the input collection.
        mu = mu.view(shape + (action_dim,))
        std = std.view(shape + (action_dim,))
        dist.set_param(mu, std)
        return dist
    raise NotImplementedError
class Distribution(ABC):
    """Abstract interface shared by the action distributions in this module.

    Concrete subclasses keep their parameters as attributes. They may also keep
    a wrapped MindSpore distribution object in ``self.distribution``.
    """

    def __init__(self):
        super().__init__()
        # Placeholder; subclasses may replace it with a MindSpore distribution.
        self.distribution = None

    @abstractmethod
    def set_param(self, *args):
        """Store the parameters that define this distribution."""
        raise NotImplementedError

    @abstractmethod
    def get_param(self):
        """Return the parameters previously stored by ``set_param``."""
        raise NotImplementedError

    @abstractmethod
    def log_prob(self, x: ms.Tensor):
        """Return the log-probability of ``x`` under this distribution."""
        raise NotImplementedError

    @abstractmethod
    def entropy(self):
        """Return the entropy of this distribution."""
        raise NotImplementedError

    @abstractmethod
    def stochastic_sample(self):
        """Draw a random sample."""
        raise NotImplementedError

    @abstractmethod
    def deterministic_sample(self):
        """Return a deterministic ("greedy") sample."""
        raise NotImplementedError
class CategoricalDistribution(Distribution):
    """Categorical distribution over ``action_dim`` discrete actions.

    Wraps ``mindspore.nn.probability.distribution.Categorical``. The wrapped
    object is constructed without parameters, so the current ``self.probs``
    is passed explicitly on every call.
    """

    def __init__(self, action_dim: int):
        super(CategoricalDistribution, self).__init__()
        self.action_dim = action_dim
        self.distribution = Categorical()
        self.probs, self.logits = None, None

    def set_param(self, probs=None, logits=None):
        """Set the parameters from ``probs`` or ``logits``.

        If both are given, ``probs`` takes precedence.

        Raises:
            RuntimeError: if neither ``probs`` nor ``logits`` is given.
        """
        if probs is not None:
            # Categorical logits are log-probabilities: softmax(log p) == p for
            # normalized p. The old log(p) - log1p(-p) is the Bernoulli
            # log-odds, which is wrong here and becomes +inf when p == 1.
            logits = ops.log(probs)
        elif logits is not None:
            probs = ops.softmax(logits, axis=-1)
        else:
            raise RuntimeError("Failed to setup distributions without given probs or logits.")
        self.probs = probs
        self.logits = logits

    def get_param(self):
        """Return the logits."""
        return self.logits

    def log_prob(self, x):
        """Return the log-probability of the category indices ``x``."""
        return self.distribution.log_prob(value=Tensor(x), probs=self.probs)

    def entropy(self):
        """Return the entropy under the current ``probs``."""
        return self.distribution.entropy(probs=self.probs)

    def stochastic_sample(self):
        """Sample category indices from the current ``probs``."""
        return self.distribution.sample(probs=self.probs)

    def deterministic_sample(self):
        """Return the index of the most probable category along the last axis."""
        # The wrapped Categorical holds no probs of its own, and the class has
        # no `argmax` method, so take the argmax over the stored probs directly.
        return ops.argmax(self.probs, -1)

    def kl_divergence(self, other: Distribution):
        """Return KL(self || other) for another categorical distribution."""
        assert isinstance(other, CategoricalDistribution), "KL Divergence should be measured by two same distribution with the same type"
        # MindSpore's kl_loss(dist_name, probs_b, probs_a) computes KL(a || b).
        return self.distribution.kl_loss("Categorical", other.probs, self.probs)
class DiagGaussianDistribution(Distribution):
    """Gaussian with diagonal covariance over ``action_dim`` dimensions.

    ``set_param`` must be called before any evaluation or sampling method,
    because it builds the underlying ``mindspore.nn.probability.distribution.Normal``.
    Per-dimension log-probabilities, entropies and KL terms are summed over
    the last axis.
    """

    def __init__(self, action_dim: int):
        super(DiagGaussianDistribution, self).__init__()
        self.mu, self.std = None, None
        self.action_dim = action_dim

    def set_param(self, mu, std):
        """Store ``mu``/``std`` and rebuild the wrapped ``Normal`` (float32)."""
        self.mu = mu
        self.std = std
        self.distribution = Normal(mean=self.mu, sd=self.std, dtype=ms.float32)

    def get_param(self):
        """Return ``(mu, std)``."""
        return self.mu, self.std

    def log_prob(self, x: ms.Tensor):
        """Return log p(x), summed over the last axis."""
        return self.distribution.log_prob(value=Tensor(x), mean=self.mu, sd=self.std).sum(-1)

    def entropy(self):
        """Return the entropy, summed over the last axis."""
        return self.distribution.entropy(mean=self.mu, sd=self.std).sum(-1)

    def stochastic_sample(self):
        """Draw a sample from the wrapped ``Normal``."""
        return self.distribution.sample()

    def deterministic_sample(self):
        """Return the mean."""
        return self.mu

    def kl_divergence(self, other: Distribution):
        """Return KL(self || other), summed over the last axis."""
        assert isinstance(other, DiagGaussianDistribution), "KL Divergence should be measured by two same distribution with the same type"
        # ops.kl_div is a loss on tensors and cannot take distribution objects.
        # Normal.kl_loss(dist_name, mean_b, sd_b, mean_a, sd_a) gives the
        # per-dimension KL(a || b); summing it matches log_prob/entropy.
        return self.distribution.kl_loss("Normal", other.mu, other.std, self.mu, self.std).sum(-1)
class ActivatedDiagGaussianDistribution(DiagGaussianDistribution):
    """Diagonal Gaussian whose samples are passed through an activation.

    ``activation_action`` is called once with no arguments, and the result is
    used as the activation function.
    """

    def __init__(self, action_dim: int, activation_action):
        super().__init__(action_dim)
        self.activation_fn = activation_action()

    def activated_rsample(self):
        """Draw a Gaussian sample and return it after the activation."""
        raw_sample = self.stochastic_sample()
        return self.activation_fn(raw_sample)

    def activated_rsample_and_logprob(self):
        """Return ``(activated_sample, log_prob)``, with log_prob summed over the last axis.

        NOTE(review): the change-of-variables term below is log(1 - tanh(u)^2),
        so the log-probability is only correct when the activation is tanh.
        """
        raw_sample = self.stochastic_sample()  # sample before the activation
        squashed = self.activation_fn(raw_sample)
        log_two = ops.log(Tensor([2.0]))
        # Numerically stable form: 2 * (log 2 - u - softplus(-2u)) == log(1 - tanh(u)^2).
        log_det = 2. * (log_two - raw_sample - ops.softplus(-2. * raw_sample))
        per_dim_log_prob = self.distribution.log_prob(raw_sample) - log_det
        return squashed, per_dim_log_prob.sum(-1)