import torch
import torch.nn as nn
import numpy as np

from safetensors.torch import load_model, save_model

def normalize(x: torch.Tensor, dim=None, eps=1e-4) -> torch.Tensor:
    """Normalize `x` to unit RMS magnitude along `dim` (all non-batch dims by default)."""
    if dim is None:
        dim = list(range(1, x.ndim))
    norm = torch.linalg.vector_norm(
        x, dim=dim, keepdim=True, dtype=torch.float32)
    # Convert the L2 norm to an RMS magnitude and stabilize it with eps.
    norm = torch.add(eps, norm, alpha=np.sqrt(norm.numel() / x.numel()))
    # Detach so the normalization factor itself carries no gradient.
    norm_detached = norm.detach().to(x.dtype)
    return x / norm_detached

class FourierFeatureExtractor(nn.Module):
    """Embeds a scalar conditioning signal into random Fourier features."""

    def __init__(self, num_channels, bandwidth=1):
        super().__init__()
        self.register_buffer('freqs', 2 * torch.pi *
                             torch.randn(num_channels) * bandwidth)
        self.register_buffer('phases', 2 * torch.pi * torch.rand(num_channels))
        self.sqrt_two = torch.sqrt(torch.tensor(2.0))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        y = x.to(torch.float32)
        y = y.ger(self.freqs.to(torch.float32))  # outer product: [batch] x [num_channels]
        y = y + self.phases.to(torch.float32)
        y = y.cos() * self.sqrt_two  # scale so the features have unit variance
        return y.to(x.dtype)

class NormalizedLinearLayer(nn.Module):
    """Linear / 2D-conv layer with forced weight normalization (magnitude-preserving)."""

    def __init__(self, in_channels, out_channels, kernel):
        super().__init__()
        self.out_channels = out_channels
        self.weight = nn.Parameter(torch.randn(
            out_channels, in_channels, *kernel))

    def forward(self, x: torch.Tensor, gain=1) -> torch.Tensor:
        w = self.weight.to(torch.float32)
        if self.training:
            with torch.no_grad():
                # Forced weight normalization: keep the stored weights at unit magnitude.
                self.weight.copy_(normalize(w))
        w = normalize(w)  # traditional weight normalization for the forward pass
        w = w * (gain / np.sqrt(w[0].numel()))  # magnitude-preserving scaling
        w = w.to(x.dtype)
        if w.ndim == 2:
            return x @ w.t()
        assert w.ndim == 4
        return nn.functional.conv2d(x, w, padding=(w.shape[-1] // 2,))

class AdaptiveLossWeightMLP(nn.Module):
    """Predicts a per-timestep log-variance used to re-weight the diffusion loss (EDM2-style)."""

    def __init__(
        self,
        noise_scheduler,
        logvar_channels=128,
        device='cuda',
        **kwargs
    ):
        super().__init__()
        self.device = device
        self.noise_scheduler = noise_scheduler
        self.noise_scheduler.alphas_cumprod = self.noise_scheduler.alphas_cumprod.to(device)

        # Statistics used to standardize alpha_bar before the Fourier embedding.
        self.a_bar_mean = noise_scheduler.alphas_cumprod.mean().to(device)
        self.a_bar_std = noise_scheduler.alphas_cumprod.std().to(device)
        self.alphas_cumprod = noise_scheduler.alphas_cumprod.to(device)

        self.logvar_fourier = FourierFeatureExtractor(logvar_channels).to(device)
        self.logvar_linear = NormalizedLinearLayer(
            logvar_channels, 1, kernel=[]).to(device)

    def _forward(self, timesteps: torch.Tensor):
        return self.compute_variance(timesteps)

    def forward(self, loss: torch.Tensor, timesteps):
        adaptive_loss_weights = self.compute_variance(timesteps)

        # EDM2-style uncertainty weighting: loss / exp(logvar) + logvar.
        # The additive logvar term keeps the learned log-variance from growing
        # without bound, which would otherwise drive the weighted loss to zero.
        loss_scaled = loss / torch.exp(adaptive_loss_weights) + adaptive_loss_weights

        return loss_scaled

    def compute_variance(self, timesteps: torch.Tensor):
        return self._compute_ddpm_variance(timesteps)

    def _compute_ddpm_variance(self, timesteps: torch.Tensor):
        timesteps = timesteps.to(self.device)
        a_bar = self.noise_scheduler.alphas_cumprod[timesteps]
        # Standardize alpha_bar to get a well-scaled conditioning signal.
        c_noise = a_bar.sub(self.a_bar_mean).div_(self.a_bar_std)
        return self.logvar_linear(self.logvar_fourier(c_noise)).squeeze()

class EDM2WeightingWrapper:
    def __init__(self,
                 noise_scheduler,
                 optimizer=torch.optim.AdamW,
                 lr=5e-3, optimizer_args=None,
                 logvar_channels=128,
                 device="cuda",
                 ):
        """
        Initialize the EDM2-style loss weighting module with Fourier features for training
        with dynamic loss weighting. See the usage sketch at the end of this file.

        :param noise_scheduler: Noise scheduler providing `alphas_cumprod`.
        :param optimizer: Optimizer class to use for the weighting MLP.
        :param lr: Learning rate for the optimizer.
        :param optimizer_args: Additional keyword arguments for the optimizer.
        :param logvar_channels: Number of Fourier features used to embed the noise level.
        :param device: Device to run computations on.
        """
        if optimizer_args is None:
            optimizer_args = {'weight_decay': 0}
        self.device = device
        noise_scheduler.alphas_cumprod = noise_scheduler.alphas_cumprod.to(device)
        self.model = AdaptiveLossWeightMLP(
            noise_scheduler=noise_scheduler,
            logvar_channels=logvar_channels,
            device=device
        ).to(device)

        self.model.train(mode=True)
        self.optimizer = optimizer(
            self.model.parameters(), lr=lr, **optimizer_args)

    def __call__(self, loss, timesteps):
        """
        Compute the weighted loss and take one optimizer step on the weighting MLP.

        :param loss: Tensor of individual losses (shape: [batch_size]).
        :param timesteps: Tensor of timesteps (shape: [batch_size]).
        :return: Tensor of weighted per-sample losses (shape: [batch_size]).
        """
        timesteps = timesteps.to(self.device)
        loss = loss.to(self.device)

        weighted_losses = self.model(loss, timesteps)
        weighted_loss = weighted_losses.mean()

        # Update only the weighting MLP here; retain the graph so the caller
        # can still backpropagate the returned losses through the main model.
        weighted_loss.backward(
            retain_graph=True, inputs=list(self.model.parameters()))

        self.optimizer.step()
        self.optimizer.zero_grad()

        return weighted_losses

    def save_model(self, path):
        save_model(self.model, path)

    def load_model(self, path):
        load_model(self.model, path)
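

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the training code).
# It uses a stub scheduler exposing just `alphas_cumprod`, standing in for a
# diffusers-style noise scheduler, and a random tensor in place of a real
# per-sample diffusion loss; all names below are placeholders.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    class _StubScheduler:
        """Linear-beta DDPM schedule; just enough for the wrapper to consume."""

        def __init__(self, num_timesteps=1000):
            betas = torch.linspace(1e-4, 2e-2, num_timesteps)
            self.alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    weighting = EDM2WeightingWrapper(_StubScheduler(), device=device)

    # Stand-ins for a real per-sample loss and its sampled timesteps.
    per_sample_loss = torch.rand(8, device=device)
    timesteps = torch.randint(0, 1000, (8,), device=device)

    # Each call re-weights the losses and takes one optimizer step on the MLP.
    weighted = weighting(per_sample_loss, timesteps)  # shape: [8]
    print(weighted.detach().cpu())
    # In real training, the caller would backpropagate the returned losses
    # through the main model and periodically call `weighting.save_model(...)`.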