import numpy as np
from abc import ABC, abstractmethod
from mindspore.nn.probability.distribution import Categorical, Normal
from xuance.mindspore import ms, ops, Tensor
[docs]
def split_distributions(distribution):
    """Split one batched distribution into an array of per-entry distributions.

    The parameter tensor (``logits`` or ``mu``) is flattened to two axes,
    one new distribution object is built per row, and the objects are
    returned in a NumPy array reshaped to the original leading (batch) axes.

    Args:
        distribution: A ``CategoricalDistribution`` or a
            ``DiagGaussianDistribution`` whose parameters have already been set.

    Returns:
        A NumPy array of distribution objects with shape
        ``param.shape[:-1]``, where ``param`` is ``logits`` or ``mu``.

    Raises:
        NotImplementedError: If ``distribution`` is of any other type.
    """
    if isinstance(distribution, CategoricalDistribution):
        shape = distribution.logits.shape
        flat_logits = distribution.logits.view(-1, shape[-1])

        def _build(row):
            # Each piece keeps a leading axis of size 1 and is detached
            # from the graph via stop_gradient.
            piece = CategoricalDistribution(row.shape[-1])
            piece.set_param(logits=ops.stop_gradient(row.unsqueeze(0)))
            return piece

        pieces = [_build(row) for row in flat_logits]
    elif isinstance(distribution, DiagGaussianDistribution):
        shape = distribution.mu.shape
        flat_means = distribution.mu.view(-1, shape[-1])
        # The standard deviation is not split: every piece receives the
        # same ``std`` object as the source distribution.
        shared_std = distribution.std

        def _build(row):
            piece = DiagGaussianDistribution(shape[-1])
            piece.set_param(row, shared_std)
            return piece

        pieces = [_build(row) for row in flat_means]
    else:
        raise NotImplementedError
    return np.array(pieces).reshape(shape[:-1])
[docs]
def merge_distributions(distribution_list):
    """Merge a collection of distributions into a single batched distribution.

    The input is first normalised to a NumPy object array so that plain
    lists/tuples and n-dimensional object arrays are handled uniformly by
    both branches. The distribution type is taken from the first element of
    the flattened collection.

    Args:
        distribution_list: A list, tuple or NumPy object array of
            ``CategoricalDistribution`` or ``DiagGaussianDistribution``
            objects (all of the same type).

    Returns:
        A new distribution of the same type:

        * Categorical: logits are concatenated along axis 0 and detached
          with ``stop_gradient``. For an n-dimensional input array whose
          entries each hold a single row of logits, the result is reshaped
          to ``array_shape + (action_dim,)``; otherwise the concatenated
          ``(rows, action_dim)`` tensor is used as is.
        * Diagonal Gaussian: ``mu`` and ``std`` are concatenated and viewed
          as ``array_shape + (action_dim,)``.

    Raises:
        IndexError: If ``distribution_list`` is empty.
        NotImplementedError: If the elements are of an unsupported type.
    """
    # dtype=object keeps NumPy from trying to interpret the elements as
    # nested sequences; an existing object array is passed through.
    dists = np.asarray(distribution_list, dtype=object)
    shape = dists.shape
    flat = dists.reshape(-1)
    # Dispatch on a real element: indexing an n-D array with [0] would
    # yield a sub-array rather than a distribution.
    first = flat[0]

    if isinstance(first, CategoricalDistribution):
        logits = ops.concat([dist.logits for dist in flat], 0)
        action_dim = logits.shape[-1]
        # Restore the batch layout only when it is unambiguous, i.e. one
        # logits row per array element; 1-D input is left untouched.
        if dists.ndim > 1 and logits.shape[0] == flat.size:
            logits = logits.view(shape + (action_dim,))
        merged = CategoricalDistribution(action_dim)
        merged.set_param(logits=ops.stop_gradient(logits))
        return merged

    if isinstance(first, DiagGaussianDistribution):
        mu = ops.cat([dist.mu for dist in flat])
        std = ops.cat([dist.std for dist in flat])
        action_dim = first.mu.shape[-1]
        merged = DiagGaussianDistribution(action_dim)
        mu = mu.view(shape + (action_dim,))
        std = std.view(shape + (action_dim,))
        merged.set_param(mu, std)
        return merged

    raise NotImplementedError
[docs]
class Distribution(ABC):
    """Abstract interface shared by the action-distribution wrappers.

    Concrete subclasses store their backend distribution object in
    ``self.distribution`` (initialised to ``None`` here) and must implement
    every abstract method below.
    """

    def __init__(self):
        super().__init__()
        # Backend distribution object; assigned by subclasses.
        self.distribution = None

    @abstractmethod
    def set_param(self, *args):
        """Store the parameters that define the distribution."""
        raise NotImplementedError

    @abstractmethod
    def get_param(self):
        """Return the parameters currently stored."""
        raise NotImplementedError

    @abstractmethod
    def log_prob(self, x: ms.Tensor):
        """Return the log-probability of ``x`` under the distribution."""
        raise NotImplementedError

    @abstractmethod
    def entropy(self):
        """Return the entropy of the distribution."""
        raise NotImplementedError

    @abstractmethod
    def stochastic_sample(self):
        """Draw a random sample from the distribution."""
        raise NotImplementedError

    @abstractmethod
    def deterministic_sample(self):
        """Return a sample chosen without randomness."""
        raise NotImplementedError
[docs]
class CategoricalDistribution(Distribution):
    """Categorical distribution over ``action_dim`` discrete actions.

    Wraps a parameter-free MindSpore ``Categorical`` object; the
    probabilities are kept on this wrapper (``self.probs``) and handed to
    the backend on every call.
    """

    def __init__(self, action_dim: int):
        super(CategoricalDistribution, self).__init__()
        self.action_dim = action_dim
        self.distribution = Categorical()
        self.probs, self.logits = None, None

    def set_param(self, probs=None, logits=None):
        """Set the distribution from probabilities or from logits.

        Args:
            probs: Event probabilities. Takes precedence over ``logits``
                when both are given.
            logits: Unnormalised log-probabilities; converted to
                probabilities with a softmax over the last axis.

        Raises:
            RuntimeError: If neither ``probs`` nor ``logits`` is given.
        """
        if probs is not None:
            # Categorical logits are log-probabilities, so that
            # softmax(logits) reproduces ``probs``. The Bernoulli log-odds
            # form log(p) - log1p(-p) does not satisfy this and produces
            # NaN under softmax for one-hot probabilities.
            logits = ops.log(probs)
        elif logits is not None:
            probs = ops.softmax(logits, axis=-1)
        else:
            raise RuntimeError("Failed to setup distributions without given probs or logits.")
        self.probs = probs
        self.logits = logits

    def get_param(self):
        """Return the stored logits."""
        return self.logits

    def log_prob(self, x):
        """Return the log-probability of the action indices ``x``."""
        return self.distribution.log_prob(value=Tensor(x), probs=self.probs)

    def entropy(self):
        """Return the entropy computed from the stored probabilities."""
        return self.distribution.entropy(probs=self.probs)

    def stochastic_sample(self):
        """Draw a random action from the stored probabilities."""
        return self.distribution.sample(probs=self.probs)

    def deterministic_sample(self):
        """Return the most probable action (argmax over the last axis)."""
        # The probabilities live on this wrapper; the backend Categorical()
        # was created without parameters.
        return self.probs.argmax(axis=-1)

    def kl_divergence(self, other: Distribution):
        """Return KL(self || other) for another categorical distribution.

        Raises:
            AssertionError: If ``other`` is not a ``CategoricalDistribution``.
        """
        assert isinstance(other,
                          CategoricalDistribution), "KL Divergence should be measured by two same distribution with the same type"
        # kl_loss expects the type name of the second distribution followed
        # by its probabilities, then the probabilities of this one.
        return self.distribution.kl_loss('Categorical', other.probs, self.probs)
[docs]
class DiagGaussianDistribution(Distribution):
    """Gaussian distribution with independent action dimensions.

    Wraps a MindSpore ``Normal`` that is (re)created on every
    ``set_param`` call. Per-dimension quantities are summed over the last
    axis so that ``log_prob``, ``entropy`` and ``kl_divergence`` describe
    the joint distribution over all action dimensions.
    """

    def __init__(self, action_dim: int):
        super(DiagGaussianDistribution, self).__init__()
        self.mu, self.std = None, None
        self.action_dim = action_dim

    def set_param(self, mu, std):
        """Store the mean and standard deviation and rebuild the backend.

        Args:
            mu: Mean tensor.
            std: Standard-deviation tensor.
        """
        self.mu = mu
        self.std = std
        self.distribution = Normal(mean=self.mu, sd=self.std, dtype=ms.float32)

    def get_param(self):
        """Return the stored ``(mu, std)`` pair."""
        return self.mu, self.std

    def log_prob(self, x: ms.Tensor):
        """Return the log-density of ``x``, summed over the last axis."""
        return self.distribution.log_prob(value=Tensor(x), mean=self.mu, sd=self.std).sum(-1)

    def entropy(self):
        """Return the entropy, summed over the last axis."""
        return self.distribution.entropy(mean=self.mu, sd=self.std).sum(-1)

    def stochastic_sample(self):
        """Draw a random sample from the backend ``Normal``."""
        return self.distribution.sample()

    def deterministic_sample(self):
        """Return the mean as the deterministic action."""
        return self.mu

    def kl_divergence(self, other: Distribution):
        """Return KL(self || other) for another diagonal Gaussian.

        The per-dimension divergences are summed over the last axis, in
        line with ``log_prob`` and ``entropy``.

        Raises:
            AssertionError: If ``other`` is not a ``DiagGaussianDistribution``.
        """
        assert isinstance(other,
                          DiagGaussianDistribution), "KL Divergence should be measured by two same distribution with the same type"
        # ops.kl_div is a tensor loss and cannot take Normal objects; use
        # the distribution's own kl_loss with explicit parameters:
        # (type name, mean_b, sd_b, mean_a, sd_a).
        per_dim = self.distribution.kl_loss('Normal', other.mu, other.std, self.mu, self.std)
        return per_dim.sum(-1)
[docs]
class ActivatedDiagGaussianDistribution(DiagGaussianDistribution):
    """Diagonal Gaussian whose samples are passed through an activation.

    Args:
        action_dim: Number of action dimensions.
        activation_action: Zero-argument callable that builds the
            activation; it is instantiated once in the constructor.
    """

    def __init__(self, action_dim: int, activation_action):
        super().__init__(action_dim)
        self.activation_fn = activation_action()

    def activated_rsample(self):
        """Return a random sample after applying the activation."""
        raw_action = self.stochastic_sample()
        return self.activation_fn(raw_action)

    def activated_rsample_and_logprob(self):
        """Return an activated sample together with its log-probability.

        Returns:
            A pair ``(action, log_prob)`` where ``action`` is the activated
            sample and ``log_prob`` is summed over the last axis.
        """
        raw_action = self.stochastic_sample()  # sample before activation
        squashed_action = self.activation_fn(raw_action)
        raw_log_prob = self.distribution.log_prob(raw_action)
        # 2 * (log 2 - x - softplus(-2x)) equals log(1 - tanh(x)^2), so the
        # term below is the change-of-variables correction for a tanh
        # squashing. NOTE(review): this is only exact when activation_fn is
        # tanh - confirm against the callers.
        log_two = ops.log(Tensor([2.0]))
        correction = - 2. * (log_two - raw_action - ops.softplus(-2. * raw_action))
        corrected_log_prob = raw_log_prob + correction
        return squashed_action, corrected_log_prob.sum(-1)