Source code for xuance.torch.utils.harmonizer

import torch

class Harmonizer(torch.nn.Module):
    """Learnable scalar used to balance the scale of a loss term.

    The module owns a single parameter ``harmony_s`` (initialised to 0.0) and
    rescales an incoming loss by ``exp(-harmony_s)``. Optionally a
    regularisation term ``log(1 + exp(harmony_s))`` is added to the result.

    Ref: https://github.com/thuml/HarmonyDream/blob/main/dreamerv3-jax/dreamerv3/nets.py

    Args:
        device: Device on which the ``harmony_s`` parameter is created
            (passed straight to ``torch.tensor``).
    """

    def __init__(self, device):
        super().__init__()
        self.harmony_s = torch.nn.Parameter(torch.tensor(0.0, device=device))

    def forward(self, loss: torch.Tensor, regularize: bool = True) -> torch.Tensor:
        """Rescale ``loss`` by the learned harmony coefficient.

        Args:
            loss: Loss tensor to rescale.
            regularize: When True (default), add ``log(1 + exp(harmony_s))``
                to the rescaled loss.

        Returns:
            ``loss * exp(-harmony_s)``, plus the regulariser if requested.
        """
        # Multiplying by exp(-s) is mathematically identical to dividing by
        # exp(s), but keeps the value and its gradient finite when exp(s)
        # would overflow to inf.
        scaled = loss * torch.exp(-self.harmony_s)
        if not regularize:
            return scaled
        # softplus(s) == log(1 + exp(s)), evaluated without overflowing for
        # large s (the naive form yields inf and a NaN gradient).
        return scaled + torch.nn.functional.softplus(self.harmony_s)

    def get_harmony(self) -> torch.nn.Parameter:
        """Return the raw learnable parameter ``harmony_s``."""
        return self.harmony_s