import torch
from torch.distributions import Independent, Normal, Bernoulli
from torch.distributions.utils import logits_to_probs
import numpy as np
from argparse import Namespace
from xuance.common import Any, Dict, Tuple, Sequence, List
from xuance.torch import Tensor, Module
from xuance.torch.utils.distributions import MSEDistribution, SymLogDistribution, TwoHotEncodingDistribution, BernoulliSafeMode
from xuance.torch.utils import dotdict, Moments, compute_lambda_values
from xuance.torch.utils.harmonizer import Harmonizer
class DreamerV3Policy(Module):  # checked
    """Policy wrapper around a DreamerV3 model.

    Holds references to the model's world model, actor, critic and target
    critic, and implements the two forward passes that build the quantities
    needed for training: ``model_forward`` (world model over a replayed
    sequence) and ``actor_critic_forward`` (actor/critic over imagined
    rollouts).
    """

    def __init__(self,
                 model: Module,
                 config: Namespace):
        """Store sub-networks and cache the sizes read from ``config``.

        Args:
            model: Module exposing ``world_model``, ``actor``, ``critic`` and
                ``target_critic`` attributes.
            config: Namespace of hyper-parameters; converted to a ``dotdict``
                so nested entries can be read with attribute access.
        """
        super(DreamerV3Policy, self).__init__()
        # convert to dotdict
        self.config = dotdict(vars(config))
        self.stoch_size = self.config.world_model.stochastic_size
        self.disc_size = self.config.world_model.discrete_size
        self.stoch_state_size = self.stoch_size * self.disc_size  # e.g. 1024 = 32 * 32
        self.batch_size = self.config.batch_size
        self.seq_len = self.config.seq_len
        self.recurrent_state_size = self.config.world_model.recurrent_model.recurrent_state_size
        self.device = self.config.device
        self.is_continuous = self.config.is_continuous
        # continuous: num of action props; discrete: num of actions.
        # Cast to a plain int: the value is used as a tensor size and as a split size.
        self.actions_dim = int(np.sum(self.config.act_shape))
        # nets
        self.model: Module = model
        self.world_model: Module = self.model.world_model
        self.actor: Module = self.model.actor
        self.critic: Module = self.model.critic
        self.target_critic: Module = self.model.target_critic
        # running statistics used to normalise the lambda returns
        self.moments = Moments(
            self.config.actor.moments.decay,
            self.config.actor.moments.max,
            self.config.actor.moments.percentile.low,
            self.config.actor.moments.percentile.high,
        )
        # NOTE(review): the harmonizers are not used inside this class;
        # presumably they are consumed by the learner -- confirm against callers.
        self.harmonizer_s1 = Harmonizer(self.device)
        self.harmonizer_s2 = Harmonizer(self.device)
        self.harmonizer_s3 = Harmonizer(self.device)

    def model_forward(self,
                      obs: Tensor,
                      acts: Tensor,
                      is_first: Tensor) \
            -> "Tuple[SymLogDistribution, TwoHotEncodingDistribution, Independent, Tensor, Tensor, Tensor, Tensor]":
        """Run the world model over a sequence and build its prediction heads.

        The recurrent state and posterior start from zeros and are advanced one
        step at a time with ``world_model.rssm.dynamic``. All buffers are sized
        from ``config.seq_len`` and ``config.batch_size``, so the inputs must
        have those sizes on their first two dimensions.

        Args:
            obs: Observations, passed to ``world_model.encoder``.
            acts: Actions, sliced one step at a time along dim 0.
            is_first: Episode-start flags, sliced one step at a time along dim 0.

        Returns:
            ``(po, pr, pc, priors_logits, posteriors_logits, recurrent_states,
            posteriors)`` where ``po``/``pr``/``pc`` are the observation, reward
            and continue distributions, and both logits tensors are reshaped to
            ``[..., stoch_size, disc_size]``.
        """
        recurrent_state = torch.zeros(1, self.batch_size, self.recurrent_state_size,
                                      device=self.device)  # e.g. [1, 16, 512]
        recurrent_states = torch.empty(self.seq_len, self.batch_size, self.recurrent_state_size,
                                       device=self.device)  # e.g. [64, 16, 512]
        priors_logits = torch.empty(self.seq_len, self.batch_size, self.stoch_state_size,
                                    device=self.device)  # e.g. [64, 16, 1024]
        embedded_obs = self.world_model.encoder(obs)  # e.g. [64, 16, 512]
        # e.g. [1, 16, 32, 32], [64, 16, 32, 32], [64, 16, 1024]
        posterior = torch.zeros(1, self.batch_size, self.stoch_size, self.disc_size, device=self.device)
        posteriors = torch.empty(self.seq_len, self.batch_size, self.stoch_size, self.disc_size,
                                 device=self.device)
        posteriors_logits = torch.empty(self.seq_len, self.batch_size, self.stoch_state_size,
                                        device=self.device)
        for i in range(0, self.seq_len):
            # h0, cat(z0, a0) -> h1; h1 + x1 -> z1; h1 -> z1_hat
            recurrent_state, posterior, _, posterior_logits, prior_logits = self.world_model.rssm.dynamic(
                posterior,  # z0
                recurrent_state,  # h0
                acts[i: i + 1],  # a0
                embedded_obs[i: i + 1],  # x1
                is_first[i: i + 1],  # is_first1
            )
            recurrent_states[i] = recurrent_state
            priors_logits[i] = prior_logits  # z1_hat
            posteriors[i] = posterior
            posteriors_logits[i] = posterior_logits  # z1
        # latent state = flattened posterior concatenated with the recurrent state
        # (e.g. [64, 16, 32 * 32 + 512 = 1536])
        latent_states = torch.cat((posteriors.view(*posteriors.shape[:-2], -1), recurrent_states), -1)
        reconstructed_obs: Tensor = self.world_model.observation_model(latent_states)
        # po(obs, symlog_dist)
        po = SymLogDistribution(reconstructed_obs, dims=len(reconstructed_obs.shape[2:]))
        # pr(rews, two_hot_dist)
        pr = TwoHotEncodingDistribution(self.world_model.reward_model(latent_states), dims=1)
        # pc(cont, bernoulli_dist)
        pc = Independent(BernoulliSafeMode(logits=self.world_model.continue_model(latent_states)), 1)
        # -> [seq, batch, stoch_size, disc_size]
        priors_logits = priors_logits.view(*priors_logits.shape[:-1], self.stoch_size, self.disc_size)
        posteriors_logits = posteriors_logits.view(*posteriors_logits.shape[:-1], self.stoch_size, self.disc_size)
        return (po, pr, pc, priors_logits, posteriors_logits,
                recurrent_states, posteriors)

    def actor_critic_forward(self,
                             posteriors: Tensor,
                             recurrent_states: Tensor,
                             terms: Tensor) \
            -> Dict[str, List[Any]]:
        """Imagine rollouts from the given latent states and build loss inputs.

        Every (step, batch) latent state is flattened into one start state, then
        ``config.horizon`` steps are imagined with ``world_model.rssm.imagination``
        using actions from the actor.

        Args:
            posteriors: Posterior stochastic states; detached and reshaped to
                ``[1, -1, stoch_state_size]``.
            recurrent_states: Recurrent states; detached and reshaped to
                ``[1, -1, recurrent_state_size]``.
            terms: Termination flags; ``1 - terms`` replaces the predicted
                continue value of the first imagined step.

        Returns:
            A dict with ``'for_actor': [objective, discount, entropy]`` and
            ``'for_critic': [qv, predicted_target_values, lambda_values]``.
        """
        imagined_prior = posteriors.detach().reshape(1, -1, self.stoch_state_size)
        recurrent_state = recurrent_states.detach().reshape(1, -1, self.recurrent_state_size)
        imagined_latent_state = torch.cat((imagined_prior, recurrent_state), -1)
        imagined_trajectories = torch.empty(
            self.config.horizon + 1,
            self.batch_size * self.seq_len,
            self.stoch_state_size + self.recurrent_state_size,
            device=self.device,
        )  # e.g. [16, 1024, 1536]
        imagined_trajectories[0] = imagined_latent_state
        imagined_actions = torch.empty(
            self.config.horizon + 1,
            self.batch_size * self.seq_len,
            self.actions_dim,
            device=self.device,
        )  # e.g. [16, 1024, 2]
        actions = torch.cat(self.actor(imagined_latent_state.detach())[0], dim=-1)  # z0 -> a0
        imagined_actions[0] = actions
        for i in range(1, self.config.horizon + 1):
            imagined_prior, recurrent_state = self.world_model.rssm.imagination(
                imagined_prior, recurrent_state, actions
            )
            imagined_prior = imagined_prior.view(1, -1, self.stoch_state_size)
            imagined_latent_state = torch.cat((imagined_prior, recurrent_state), -1)
            imagined_trajectories[i] = imagined_latent_state
            # the actor only sees a detached latent state
            actions = torch.cat(self.actor(imagined_latent_state.detach())[0], dim=-1)
            imagined_actions[i] = actions

        predicted_values = TwoHotEncodingDistribution(self.critic(imagined_trajectories), dims=1).mean
        predicted_rewards = TwoHotEncodingDistribution(
            self.world_model.reward_model(imagined_trajectories), dims=1
        ).mean
        continues = Independent(
            BernoulliSafeMode(logits=self.world_model.continue_model(imagined_trajectories)), 1
        ).mode
        # The first step uses the true continue flag instead of the predicted one.
        true_continue = (1 - terms).flatten().reshape(1, -1, 1)
        continues = torch.cat((true_continue, continues[1:]))
        # seq_shift[1:]
        lambda_values = compute_lambda_values(
            predicted_rewards[1:],
            predicted_values[1:],
            continues[1:] * self.config.gamma,
            lmbda=self.config.lmbda,
        )
        with torch.no_grad():
            discount = torch.cumprod(continues * self.config.gamma, dim=0) / self.config.gamma
        policies: Sequence[torch.distributions.Distribution] = self.actor(imagined_trajectories.detach())[1]

        baseline = predicted_values[:-1]
        offset, invscale = self.moments(lambda_values)
        normed_lambda_values = (lambda_values - offset) / invscale
        normed_baseline = (baseline - offset) / invscale
        advantage = normed_lambda_values - normed_baseline
        if self.is_continuous:
            # gradients flow through the imagined dynamics
            objective = advantage
        else:
            # score-function estimator: log-prob of the taken actions weighted by
            # the detached advantage
            log_probs = torch.stack(
                [
                    p.log_prob(imgnd_act.detach()).unsqueeze(-1)[:-1]
                    for p, imgnd_act in zip(policies, torch.split(imagined_actions, [self.actions_dim], dim=-1))
                ],
                dim=-1,
            ).sum(dim=-1)
            objective = log_probs * advantage.detach()
        try:
            entropy = self.config.actor.ent_coef * torch.stack([p.entropy() for p in policies], -1).sum(dim=-1)
        except NotImplementedError:
            # distributions without an entropy implementation contribute nothing
            entropy = torch.zeros_like(objective)

        # seq_shift: the last imagined state is dropped for the critic outputs
        qv = TwoHotEncodingDistribution(self.critic(imagined_trajectories.detach()[:-1]), dims=1)
        predicted_target_values = TwoHotEncodingDistribution(
            self.target_critic(imagined_trajectories.detach()[:-1]), dims=1
        ).mean
        return {
            'for_actor': [objective, discount, entropy],
            'for_critic': [qv, predicted_target_values, lambda_values]
        }

    def soft_update(self, tau=0.02):  # checked
        """Move the target critic towards the critic by Polyak averaging.

        Each target parameter becomes ``(1 - tau) * target + tau * critic``.

        Args:
            tau: Weight given to the critic's parameters.
        """
        # In-place updates under no_grad instead of going through `.data`.
        with torch.no_grad():
            for ep, tp in zip(self.critic.parameters(), self.target_critic.parameters()):
                tp.mul_(1 - tau).add_(ep, alpha=tau)
class DreamerV2Policy(Module):  # checked
    """Policy wrapper around a DreamerV2 model.

    Holds references to the model's world model, actor, critic and target
    critic, and implements the forward passes that build the quantities needed
    for training: ``model_forward`` (world model over a replayed sequence) and
    ``actor_critic_forward`` (actor/critic over imagined rollouts).
    """

    def __init__(self,
                 model: Module,
                 config: Namespace):
        """Store sub-networks and cache the sizes read from ``config``.

        Args:
            model: Module exposing ``world_model``, ``actor``, ``critic`` and
                ``target_critic`` attributes.
            config: Namespace of hyper-parameters; converted to a ``dotdict``
                so nested entries can be read with attribute access.
        """
        super(DreamerV2Policy, self).__init__()
        # convert to dotdict
        self.config = dotdict(vars(config))
        self.stoch_size = self.config.world_model.stochastic_size
        self.disc_size = self.config.world_model.discrete_size
        self.stoch_state_size = self.stoch_size * self.disc_size  # e.g. 1024 = 32 * 32
        self.batch_size = self.config.batch_size
        self.seq_len = self.config.seq_len
        self.recurrent_state_size = self.config.world_model.recurrent_model.recurrent_state_size
        self.device = self.config.device
        self.is_continuous = self.config.is_continuous
        # continuous: num of action props; discrete: num of actions.
        # Cast to a plain int: the value is used as a tensor size and as a split size.
        self.actions_dim = int(np.sum(self.config.act_shape))
        # nets
        self.model: Module = model
        self.world_model: Module = self.model.world_model
        self.actor: Module = self.model.actor
        self.critic: Module = self.model.critic
        self.target_critic: Module = self.model.target_critic
        # NOTE(review): the harmonizers are not used inside this class;
        # presumably they are consumed by the learner -- confirm against callers.
        self.harmonizer_s1 = Harmonizer(self.device)
        self.harmonizer_s2 = Harmonizer(self.device)
        self.harmonizer_s3 = Harmonizer(self.device)

    def model_forward(self,
                      obs: Tensor,
                      acts: Tensor,
                      is_first: Tensor) \
            -> Tuple[Independent, Independent, Independent, Tensor, Tensor,
                     Tensor, Tensor]:
        """Run the world model over a sequence and build its prediction heads.

        The recurrent state and posterior start from zeros and are advanced one
        step at a time with ``world_model.rssm.dynamic``. All buffers are sized
        from ``config.seq_len`` and ``config.batch_size``, so the inputs must
        have those sizes on their first two dimensions.

        Args:
            obs: Observations, passed to ``world_model.encoder``.
            acts: Actions, sliced one step at a time along dim 0.
            is_first: Episode-start flags, sliced one step at a time along dim 0.

        Returns:
            ``(po, pr, pc, priors_logits, posteriors_logits, recurrent_states,
            posteriors)``. ``po`` and ``pr`` are unit-variance Normal
            distributions over observations and rewards; ``pc`` is a Bernoulli
            continue distribution, or ``None`` when
            ``config.world_model.use_continues`` is falsy. Both logits tensors
            are reshaped to ``[..., stoch_size, disc_size]``.
        """
        recurrent_state = torch.zeros(1, self.batch_size, self.recurrent_state_size, device=self.device)
        recurrent_states = torch.zeros(self.seq_len, self.batch_size, self.recurrent_state_size,
                                       device=self.device)
        priors_logits = torch.empty(self.seq_len, self.batch_size, self.stoch_state_size, device=self.device)
        embedded_obs = self.world_model.encoder(obs)
        posterior = torch.zeros(1, self.batch_size, self.stoch_size, self.disc_size, device=self.device)
        posteriors = torch.empty(self.seq_len, self.batch_size, self.stoch_size, self.disc_size,
                                 device=self.device)
        posteriors_logits = torch.empty(self.seq_len, self.batch_size, self.stoch_state_size,
                                        device=self.device)
        for i in range(0, self.seq_len):
            recurrent_state, posterior, _, posterior_logits, prior_logits = self.world_model.rssm.dynamic(
                posterior,
                recurrent_state,
                acts[i: i + 1],
                embedded_obs[i: i + 1],
                is_first[i: i + 1],
            )
            recurrent_states[i] = recurrent_state
            priors_logits[i] = prior_logits
            posteriors[i] = posterior
            posteriors_logits[i] = posterior_logits
        # latent state = flattened posterior concatenated with the recurrent state
        latent_states = torch.cat((posteriors.view(*posteriors.shape[:-2], -1), recurrent_states), -1)
        reconstructed_obs: Tensor = self.world_model.observation_model(latent_states)
        po = Independent(Normal(reconstructed_obs, 1), len(reconstructed_obs.shape[2:]))
        pr = Independent(Normal(self.world_model.reward_model(latent_states), 1), 1)
        # error due to not support of Boolean()
        if self.config.world_model.use_continues:
            pc = Independent(Bernoulli(logits=self.world_model.continue_model(latent_states)), 1)
        else:
            pc = None
        # -> [seq, batch, stoch_size, disc_size]
        priors_logits = priors_logits.view(*priors_logits.shape[:-1], self.stoch_size, self.disc_size)
        posteriors_logits = posteriors_logits.view(*posteriors_logits.shape[:-1], self.stoch_size, self.disc_size)
        return (po, pr, pc, priors_logits, posteriors_logits,
                recurrent_states, posteriors)

    def actor_critic_forward(self,
                             posteriors: Tensor,
                             recurrent_states: Tensor,
                             terms: Tensor) \
            -> Dict[str, List[Any]]:
        """Imagine rollouts from the given latent states and build loss inputs.

        Every (step, batch) latent state is flattened into one start state, then
        ``config.horizon`` steps are imagined with ``world_model.rssm.imagination``
        using actions from the actor.

        Args:
            posteriors: Posterior stochastic states; detached and reshaped to
                ``[1, -1, stoch_state_size]``.
            recurrent_states: Recurrent states; detached and reshaped to
                ``[1, -1, recurrent_state_size]``.
            terms: Termination flags; only read when
                ``config.world_model.use_continues`` is truthy.

        Returns:
            A dict with ``'for_actor': [objective, discount, entropy]`` and
            ``'for_critic': [qv, predicted_target_values, lambda_values]``.
        """
        imagined_prior = posteriors.detach().reshape(1, -1, self.stoch_state_size)
        recurrent_state = recurrent_states.detach().reshape(1, -1, self.recurrent_state_size)
        imagined_latent_state = torch.cat((imagined_prior, recurrent_state), -1)
        imagined_trajectories = torch.empty(
            self.config.horizon + 1,
            self.batch_size * self.seq_len,
            self.stoch_state_size + self.recurrent_state_size,
            device=self.device,
        )
        imagined_trajectories[0] = imagined_latent_state
        imagined_actions = torch.empty(
            self.config.horizon + 1,
            self.batch_size * self.seq_len,
            self.actions_dim,
            device=self.device,
        )
        # diff to v3; here is, at imagined_trajectories[0] takes action imagined_actions[1]
        # Slot 0 is a zero placeholder, filled in place on the buffer's own device
        # rather than copied from a zeros tensor allocated on the default device.
        imagined_actions[0] = 0.0
        for i in range(1, self.config.horizon + 1):
            # (1, batch_size * seq_len, sum(actions_dim))
            actions = torch.cat(self.actor(imagined_latent_state.detach())[0], dim=-1)
            imagined_actions[i] = actions
            imagined_prior, recurrent_state = self.world_model.rssm.imagination(
                imagined_prior, recurrent_state, actions
            )
            imagined_prior = imagined_prior.view(1, -1, self.stoch_state_size)
            imagined_latent_state = torch.cat((imagined_prior, recurrent_state), -1)
            imagined_trajectories[i] = imagined_latent_state

        predicted_target_values = self.target_critic(imagined_trajectories)
        predicted_rewards = self.world_model.reward_model(imagined_trajectories)
        if self.config.world_model.use_continues:
            # diff to v3
            continues = logits_to_probs(self.world_model.continue_model(imagined_trajectories), is_binary=True)
            # diff to v3 (v3: no self.config.gamma here, but mult gamma before passing to
            # 'compute_lambda_values')
            true_continue = (1 - terms).reshape(1, -1, 1) * self.config.gamma
            continues = torch.cat((true_continue, continues[1:]))
        else:
            continues = torch.ones_like(predicted_rewards.detach()) * self.config.gamma

        # Compute the lambda_values, by passing as last value the value of the last imagined state
        # (horizon, batch_size * seq_len, 1)
        lambda_values = DreamerV2Policy.compute_lambda_values(
            predicted_rewards[:-1],
            predicted_target_values[:-1],
            continues[:-1],
            bootstrap=predicted_target_values[-1:],
            horizon=self.config.horizon,
            lmbda=self.config.lmbda,
        )
        with torch.no_grad():
            discount = torch.cumprod(torch.cat((torch.ones_like(continues[:1]), continues[:-1]), 0), 0)
        policies: Sequence[torch.distributions.Distribution] = self.actor(imagined_trajectories[:-2].detach())[1]

        # Dynamics backpropagation
        dynamics = lambda_values[1:]
        # Reinforce
        advantage = (lambda_values[1:] - predicted_target_values[:-2]).detach()
        reinforce = (
            torch.stack(
                [
                    p.log_prob(imgnd_act[1:-1].detach()).unsqueeze(-1)
                    for p, imgnd_act in zip(policies, torch.split(imagined_actions, [self.actions_dim], -1))
                ],
                -1,
            ).sum(-1)
            * advantage
        )
        objective = self.config.actor.objective_mix * reinforce + (1 - self.config.actor.objective_mix) * dynamics
        try:
            entropy = self.config.actor.ent_coef * torch.stack([p.entropy() for p in policies], -1).sum(dim=-1)
        except NotImplementedError:
            # distributions without an entropy implementation contribute nothing
            entropy = torch.zeros_like(objective)
        # policy_loss = -torch.mean(discount[:-2].detach() * (objective + entropy.unsqueeze(-1)))
        # last imagined state (with position=horizon+1) in the trajectory only used for bootstrapping;
        qv = Independent(Normal(self.critic(imagined_trajectories.detach()[:-1]), 1), 1)
        return {
            'for_actor': [objective, discount, entropy],
            'for_critic': [qv, predicted_target_values, lambda_values]
        }

    def hard_update(self):  # checked
        """Copy every critic parameter into the target critic.

        A direct copy is used rather than ``mul_(0)`` followed by ``add_``: a
        non-finite target value multiplied by zero is NaN, which would then
        survive the update.
        """
        with torch.no_grad():
            for ep, tp in zip(self.critic.parameters(), self.target_critic.parameters()):
                tp.copy_(ep)

    @staticmethod
    def compute_lambda_values(
            rewards: Tensor,
            values: Tensor,
            continues: Tensor,
            bootstrap: Tensor = None,
            horizon: int = 15,
            lmbda: float = 0.95,
    ) -> Tensor:
        """Compute lambda-returns by a backward recursion over ``horizon`` steps.

        With ``next_val`` being ``values`` shifted by one step and terminated by
        ``bootstrap``, the recursion is::

            agg[i] = rewards[i] + continues[i] * next_val[i] * (1 - lmbda)
                     + continues[i] * lmbda * agg[i + 1]

        starting from ``agg = bootstrap``.

        Args:
            rewards: Per-step rewards, indexed along dim 0.
            values: Per-step value estimates, indexed along dim 0.
            continues: Per-step discount/continue factors.
            bootstrap: Value following the last step, with a leading dim of 1;
                zeros shaped like ``values[-1:]`` when ``None``.
            horizon: Number of steps to process; indices ``0 .. horizon - 1`` of
                the inputs are read, so it must not exceed their length.
            lmbda: Mixing coefficient between the one-step target and the
                recursive return.

        Returns:
            The per-step returns concatenated along dim 0, in forward time order.
        """
        if bootstrap is None:
            bootstrap = torch.zeros_like(values[-1:])
        agg = bootstrap
        next_val = torch.cat((values[1:], bootstrap), dim=0)
        inputs = rewards + continues * next_val * (1 - lmbda)
        lv = []
        for i in reversed(range(horizon)):
            agg = inputs[i] + continues[i] * lmbda * agg
            lv.append(agg)
        # results were collected from the last step backwards
        return torch.cat(lv[::-1], dim=0)