import torch
import torch.nn as nn
import numpy as np
from safetensors.torch import load_model, save_model
def normalize(x: torch.Tensor, dim=None, eps=1e-4) -> torch.Tensor:
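    """Normalize x to unit magnitude over `dim` (all non-batch dims by default), EDM2-style, with an eps floor on the norm."""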
if dim is None:
dim = list(range(1, x.ndim))
norm = torch.linalg.vector_norm(
x, dim=dim, keepdim=True, dtype=torch.float32) # type: torch.Tensor
norm = torch.add(eps, norm, alpha=np.sqrt(norm.numel() / x.numel()))
norm_detached = norm.detach().to(x.dtype) # Detach and cast to x's dtype
return x / norm_detached
# return x / norm.to(x.dtype)
class FourierFeatureExtractor(nn.Module):
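    """Random Fourier features y = sqrt(2) * cos(freq * x + phase), scaled to preserve unit magnitude (EDM2 MPFourier-style)."""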
def __init__(self, num_channels, bandwidth=1):
super().__init__()
self.register_buffer('freqs', 2 * torch.pi *
torch.randn(num_channels) * bandwidth)
self.register_buffer('phases', 2 * torch.pi * torch.rand(num_channels))
        self.sqrt_two = torch.sqrt(torch.tensor(2.0))  # float literal avoids sqrt on an integer tensor
def forward(self, x: torch.Tensor) -> torch.Tensor:
y = x.to(torch.float32)
        y = torch.outer(y, self.freqs.to(torch.float32))  # outer product: [batch] -> [batch, num_channels]; torch.ger is deprecated
y = y + self.phases.to(torch.float32) # type: torch.Tensor
y = y.cos() * self.sqrt_two
return y.to(x.dtype)
class NormalizedLinearLayer(nn.Module):
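    """Linear or conv2d layer with forced weight normalization and magnitude-preserving scaling (EDM2 MPConv-style)."""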
def __init__(self, in_channels, out_channels, kernel):
super().__init__()
self.out_channels = out_channels
self.weight = nn.Parameter(torch.randn(
out_channels, in_channels, *kernel))
def forward(self, x: torch.Tensor, gain=1) -> torch.Tensor:
w = self.weight.to(torch.float32)
if self.training:
with torch.no_grad():
self.weight.copy_(normalize(w)) # forced weight normalization
w = normalize(w) # traditional weight normalization
        w = w * (gain / np.sqrt(w[0].numel()))  # magnitude-preserving scaling  # type: torch.Tensor
w = w.to(x.dtype)
if w.ndim == 2:
return x @ w.t()
assert w.ndim == 4
return nn.functional.conv2d(x, w, padding=(w.shape[-1]//2,))
class AdaptiveLossWeightMLP(nn.Module):
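    """MLP (Fourier features + normalized linear head) that learns a per-timestep log-variance used to re-weight the training loss (EDM2 adaptive loss weighting)."""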
def __init__(
self,
noise_scheduler,
logvar_channels=128,
device='cuda',
**kwargs
):
super().__init__()
self.device = device
self.noise_scheduler = noise_scheduler
self.noise_scheduler.alphas_cumprod = self.noise_scheduler.alphas_cumprod.to(device)
self.a_bar_mean = noise_scheduler.alphas_cumprod.mean().to(device)
self.a_bar_std = noise_scheduler.alphas_cumprod.std().to(device)
self.alphas_cumprod = noise_scheduler.alphas_cumprod.to(device)
self.logvar_fourier = FourierFeatureExtractor(logvar_channels).to(device)
        # kernel not specified in the reference code; kernel=[] gives a plain linear layer, matching EDM2's MPConv
self.logvar_linear = NormalizedLinearLayer(
logvar_channels, 1, kernel=[]).to(device)
def _forward(self, timesteps: torch.Tensor):
return self.compute_variance(timesteps)
def forward(self, loss: torch.Tensor, timesteps):
adaptive_loss_weights = self.compute_variance(timesteps)
        loss_scaled = loss / torch.exp(adaptive_loss_weights)  # type: torch.Tensor
# loss = loss_scaled + adaptive_loss_weights # type: torch.Tensor
# stdev, mean = torch.std_mean(loss)
# print(f"{mean=:.4f} {stdev=:.4f}")
return loss_scaled
def compute_variance(self, timesteps: torch.Tensor):
return self._compute_ddpm_variance(timesteps)
def _compute_ddpm_variance(self, timesteps: torch.Tensor):
timesteps = timesteps.to(self.device)
a_bar = self.noise_scheduler.alphas_cumprod[timesteps]
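        # Standardize a_bar to zero mean / unit std (over the schedule) before the Fourier features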
c_noise = a_bar.sub(self.a_bar_mean).div_(self.a_bar_std)
return self.logvar_linear(self.logvar_fourier(c_noise)).squeeze()
class EDM2WeightingWrapper:
def __init__(self,
noise_scheduler,
optimizer=torch.optim.AdamW,
lr=5e-3, optimizer_args={'weight_decay': 0},
logvar_channels=128,
device="cuda",
):
"""
Initialize EDM2Loss with Fourier features for training with dynamic loss weighting.
:param optimizer: Optimizer class to use.
:param lr: Learning rate for the optimizer.
:param optimizer_args: Additional arguments for the optimizer.
:param device: Device to run computations on.
:param logvar_channels: Fourier decomposition complexity.
"""
self.device = device
noise_scheduler.alphas_cumprod = noise_scheduler.alphas_cumprod.to(device)
self.model = AdaptiveLossWeightMLP(
noise_scheduler=noise_scheduler,
logvar_channels=logvar_channels,
device=device
).to(device)
        # # Move every parameter of the model to the specified device
        # for param in self.model.parameters():
        #     param.data = param.data.to(device)
self.optimizer = optimizer(
self.model.parameters(), lr=lr, **optimizer_args)
self.model.train(mode=True) # Ensure the model is in training mode
def __call__(self, loss, timesteps):
"""
Compute the weighted loss and backpropagate it through the loss_module.
:param timesteps: Tensor of timesteps (shape: [batch_size]).
:param loss: Tensor of individual losses (shape: [batch_size]).
:return: Scalar tensor representing the total weighted loss.
"""
timesteps = timesteps.to(self.device)
loss = loss.to(self.device)
# Forward pass through the loss_module
weighted_losses = self.model(loss, timesteps)
weighted_loss = weighted_losses.mean()
# Backward pass for loss_module
# Only compute gradients for self.model, don't touch anything else
weighted_loss.backward(
retain_graph=True, inputs=list(self.model.parameters()))
self.optimizer.step()
self.optimizer.zero_grad()
return weighted_losses
def save_model(self, path):
save_model(self.model, path)
def load_model(self, path):
load_model(self.model, path)
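

if __name__ == "__main__":
    # Minimal usage sketch (not part of the library code). It assumes only that the scheduler
    # object exposes a 1-D `alphas_cumprod` tensor indexed by integer timesteps; a diffusers
    # DDPMScheduler satisfies this, and the dummy class below stands in for it here.
    class _DummyScheduler:
        def __init__(self, num_train_timesteps=1000):
            betas = torch.linspace(1e-4, 0.02, num_train_timesteps)
            self.alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    wrapper = EDM2WeightingWrapper(_DummyScheduler(), device=device)

    batch_size = 8
    timesteps = torch.randint(0, 1000, (batch_size,))
    per_sample_loss = torch.rand(batch_size)  # stand-in for the per-sample diffusion MSE loss

    # One weighting step: updates the log-variance MLP internally and returns the re-weighted
    # losses, which a training loop would typically average and backpropagate through the main network.
    weighted = wrapper(per_sample_loss, timesteps)
    print(weighted.shape, weighted.mean().item())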