nicolas-dufour's picture
squash: merge all unpushed commits
c4c7cee
raw
history blame
7.09 kB
import torch.nn as nn
from models.positional_embeddings import FourierEmbedding, PositionalEmbedding
from models.networks.transformers import FusedMLP
import torch
import torch.nn.functional as F
import numpy as np
from einops import rearrange
class TimeEmbedder(nn.Module):
    """Encode a noise/time value and project it through a small MLP.

    The input is multiplied by ``time_scaling``, encoded with a
    ``PositionalEmbedding`` (when ``noise_embedding_type == "positional"``) or a
    ``FourierEmbedding`` (for any other value), standardised along the last
    axis, and finally mapped through ``Linear -> SiLU -> Linear``.

    Args:
        noise_embedding_type: ``"positional"`` selects ``PositionalEmbedding``;
            every other string selects ``FourierEmbedding``.
        dim: number of channels requested from the encoder and input width of
            the MLP.
        time_scaling: factor applied to the input before it is encoded.
        expansion: the MLP hidden and output width is ``dim * expansion``.
    """

    def __init__(
        self,
        noise_embedding_type: str,
        dim: int,
        time_scaling: float,
        expansion: int = 4,
    ):
        super().__init__()
        if noise_embedding_type == "positional":
            encoder = PositionalEmbedding(num_channels=dim, endpoint=True)
        else:
            encoder = FourierEmbedding(num_channels=dim)
        self.encode_time = encoder
        self.time_scaling = time_scaling

        hidden_dim = dim * expansion
        self.map_time = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.SiLU(),
            nn.Linear(hidden_dim, hidden_dim),
        )

    def forward(self, t):
        """Return the projected embedding of ``t`` (last axis: ``dim * expansion``)."""
        encoded = self.encode_time(t * self.time_scaling)
        # Standardise each embedding vector (zero mean, unit std over the last
        # axis) before handing it to the MLP.
        centered = encoded - encoded.mean(dim=-1, keepdim=True)
        standardised = centered / encoded.std(dim=-1, keepdim=True)
        return self.map_time(standardised)
def get_timestep_embedding(timesteps, embedding_dim, dtype=torch.float32):
    """Build sinusoidal embeddings for a 1-D tensor of timesteps.

    The timesteps are multiplied by 1000, then combined with
    ``embedding_dim // 2`` geometrically spaced frequencies (from 1 down to
    1/10000). The result is the concatenation ``[sin(...), cos(...)]`` along
    the last axis, right-padded with a single zero column when
    ``embedding_dim`` is odd.

    Args:
        timesteps: 1-D tensor of timestep values.
        embedding_dim: width of the returned embedding.
        dtype: floating point dtype used for the computation.

    Returns:
        Tensor of shape ``(timesteps.shape[0], embedding_dim)``.
    """
    assert timesteps.dim() == 1, "timesteps must be a 1-D tensor"
    timesteps = timesteps * 1000.0
    half_dim = embedding_dim // 2
    # With a single frequency (embedding_dim of 2 or 3) ``half_dim - 1`` is 0;
    # dividing by it would give an infinite step and NaNs after ``0 * inf``.
    # Clamping the denominator leaves every ``half_dim >= 2`` case unchanged.
    log_step = np.log(10000) / max(half_dim - 1, 1)
    freqs = (
        torch.arange(half_dim, dtype=dtype, device=timesteps.device) * -log_step
    ).exp()
    angles = timesteps.to(dtype)[:, None] * freqs[None, :]
    emb = torch.cat([angles.sin(), angles.cos()], dim=-1)
    if embedding_dim % 2 == 1:  # zero pad
        emb = F.pad(emb, (0, 1))
    assert emb.shape == (timesteps.shape[0], embedding_dim)
    return emb
class AdaLNMLPBlock(nn.Module):
    """Residual MLP block whose normalisation is modulated by a conditioning input.

    The conditioning ``y`` is mapped (SiLU -> Linear) to three chunks used as
    a scale and a shift applied to the layer-normalised input, and a gate
    applied to the MLP output before it is added back to ``x``.

    Args:
        dim: feature width of ``x`` and ``y``.
        expansion: hidden layer multiplier forwarded to ``FusedMLP``.
    """

    def __init__(self, dim, expansion):
        super().__init__()
        self.mlp = FusedMLP(
            dim,
            dropout=0.0,
            hidden_layer_multiplier=expansion,
            activation=nn.GELU,
        )
        self.ada_map = nn.Sequential(nn.SiLU(), nn.Linear(dim, dim * 3))
        self.ln = nn.LayerNorm(dim, elementwise_affine=False)
        # Zero the parameters of the last MLP layer (presumably so the residual
        # branch contributes nothing at initialisation -- confirm against
        # FusedMLP's layout).
        last_layer = self.mlp[-1]
        for param in (last_layer.weight, last_layer.bias):
            nn.init.zeros_(param)

    def forward(self, x, y):
        """Return ``x`` plus the gated MLP of its modulated, normalised version."""
        scale, shift, gate = self.ada_map(y).chunk(3, dim=-1)
        modulated = (1 + scale) * self.ln(x) + shift
        residual = self.mlp(modulated) * gate
        return x + residual
class GeoAdaLNMLP(nn.Module):
    """Stack of ``AdaLNMLPBlock`` layers conditioned on an embedding and a time value.

    The forward pass reads three entries from ``batch``: ``"y"`` (the input
    that is transformed), ``"gamma"`` (fed to the time embedder) and ``"emb"``
    (the conditioning embedding). The time embedding and the mapped
    conditioning are summed and used to modulate every block as well as the
    final layer norm.

    Args:
        input_dim: feature width of ``batch["y"]`` and of the output.
        dim: hidden width; must be a multiple of 4 (see ``__init__``).
        depth: number of ``AdaLNMLPBlock`` layers.
        expansion: hidden layer multiplier of each block's MLP.
        cond_dim: feature width of ``batch["emb"]``.

    Raises:
        ValueError: if ``dim`` is not a positive multiple of 4.
    """

    def __init__(self, input_dim, dim, depth, expansion, cond_dim):
        super().__init__()
        # The time embedder is built with ``dim // 4`` channels and an expansion
        # of 4, so it outputs ``(dim // 4) * 4`` features that are added to the
        # ``dim``-wide conditioning in ``forward``. Fail early instead of at the
        # first forward pass when the two widths cannot match.
        if dim <= 0 or dim % 4 != 0:
            raise ValueError(
                f"dim must be a positive multiple of 4, got {dim}"
            )
        self.time_embedder = TimeEmbedder("positional", dim // 4, 1000, expansion=4)
        self.cond_mapper = nn.Linear(cond_dim, dim)
        self.initial_mapper = nn.Linear(input_dim, dim)
        self.blocks = nn.ModuleList(
            [AdaLNMLPBlock(dim, expansion) for _ in range(depth)]
        )
        self.final_adaln = nn.Sequential(
            nn.SiLU(),
            nn.Linear(dim, dim * 2),
        )
        self.final_ln = nn.LayerNorm(dim, elementwise_affine=False)
        self.final_linear = nn.Linear(dim, input_dim)

    def forward(self, batch):
        """Map ``batch["y"]`` to an output with the same feature width.

        Args:
            batch: mapping with keys ``"y"``, ``"gamma"`` and ``"emb"``.

        Returns:
            Tensor produced by the final linear layer (last axis ``input_dim``).
        """
        x = self.initial_mapper(batch["y"])
        time_emb = self.time_embedder(batch["gamma"])
        # Conditioning = mapped embedding + time embedding.
        cond = self.cond_mapper(batch["emb"]) + time_emb
        for block in self.blocks:
            x = block(x, cond)
        # Final adaptive layer norm: scale and shift predicted from the conditioning.
        gamma_last, mu_last = self.final_adaln(cond).chunk(2, dim=-1)
        x = (1 + gamma_last) * self.final_ln(x) + mu_last
        return self.final_linear(x)
class GeoAdaLNMLPVonFisher(nn.Module):
    """Predict a unit direction ``mu`` and a non-negative ``kappa`` from a conditioning.

    A learned register vector is repeated over the batch, refined by a stack of
    ``AdaLNMLPBlock`` layers modulated by the mapped conditioning, and fed to
    two heads: one producing a direction normalised to unit length, one
    producing a Softplus-activated scalar. The class name suggests these are
    the parameters of a von Mises-Fisher distribution (to be confirmed against
    the loss that consumes them).

    Args:
        input_dim: width of the predicted direction ``mu``.
        dim: hidden width.
        depth: number of ``AdaLNMLPBlock`` layers.
        expansion: hidden layer multiplier of each block's MLP.
        cond_dim: feature width of ``batch["emb"]``.
    """

    def __init__(self, input_dim, dim, depth, expansion, cond_dim):
        super().__init__()
        self.cond_mapper = nn.Linear(cond_dim, dim)
        self.blocks = nn.ModuleList(
            [AdaLNMLPBlock(dim, expansion) for _ in range(depth)]
        )
        self.final_adaln = nn.Sequential(
            nn.SiLU(),
            nn.Linear(dim, dim * 2),
        )
        self.final_ln = nn.LayerNorm(dim, elementwise_affine=False)
        self.mu_predictor = nn.Sequential(
            FusedMLP(dim, dropout=0.0, hidden_layer_multiplier=2, activation=nn.GELU),
            nn.Linear(dim, input_dim),
        )
        self.kappa_predictor = nn.Sequential(
            FusedMLP(dim, dropout=0.0, hidden_layer_multiplier=2, activation=nn.GELU),
            nn.Linear(dim, 1),
            torch.nn.Softplus(),
        )
        # Learned starting vector shared by every sample. The randn values are
        # overwritten right away by the truncated normal init; the call is kept
        # so the random number stream (and seeded initialisation) is unchanged.
        self.init_registers = torch.nn.Parameter(torch.randn(dim), requires_grad=True)
        torch.nn.init.trunc_normal_(
            self.init_registers, std=0.02, a=-2 * 0.02, b=2 * 0.02
        )

    def forward(self, batch):
        """Return ``(mu, kappa)`` for the conditioning in ``batch["emb"]``.

        Args:
            batch: mapping with key ``"emb"``; its first axis is used as the
                batch size.

        Returns:
            ``mu``: direction normalised to unit L2 norm along the last axis
            (an all-zero prediction stays zero instead of becoming NaN).
            ``kappa``: Softplus output with a last axis of size 1.
        """
        cond = self.cond_mapper(batch["emb"])
        x = self.init_registers.unsqueeze(0).repeat(cond.shape[0], 1)
        for block in self.blocks:
            x = block(x, cond)
        gamma_last, mu_last = self.final_adaln(cond).chunk(2, dim=-1)
        x = (1 + gamma_last) * self.final_ln(x) + mu_last
        # F.normalize divides by max(norm, eps): same result as ``mu / mu.norm``
        # for any non-degenerate vector, but no 0/0 -> NaN on a zero vector.
        mu = F.normalize(self.mu_predictor(x), dim=-1)
        kappa = self.kappa_predictor(x)
        return mu, kappa
class GeoAdaLNMLPVonFisherMixture(nn.Module):
    """Predict ``num_mixtures`` unit directions, concentrations and mixture weights.

    A learned register vector is repeated over the batch, refined by a stack of
    ``AdaLNMLPBlock`` layers modulated by the mapped conditioning, and fed to
    three heads: directions (normalised to unit length per component),
    Softplus-activated scalars (one per component) and Softmax weights over the
    components. The class name suggests a mixture of von Mises-Fisher
    distributions (to be confirmed against the loss that consumes them).

    Args:
        input_dim: width of each predicted direction.
        dim: hidden width.
        depth: number of ``AdaLNMLPBlock`` layers.
        expansion: hidden layer multiplier of each block's MLP.
        cond_dim: feature width of ``batch["emb"]``.
        num_mixtures: number of mixture components.
    """

    def __init__(self, input_dim, dim, depth, expansion, cond_dim, num_mixtures=3):
        super().__init__()
        self.cond_mapper = nn.Linear(cond_dim, dim)
        self.blocks = nn.ModuleList(
            [AdaLNMLPBlock(dim, expansion) for _ in range(depth)]
        )
        self.final_adaln = nn.Sequential(
            nn.SiLU(),
            nn.Linear(dim, dim * 2),
        )
        self.final_ln = nn.LayerNorm(dim, elementwise_affine=False)
        self.mu_predictor = nn.Sequential(
            FusedMLP(dim, dropout=0.0, hidden_layer_multiplier=2, activation=nn.GELU),
            nn.Linear(dim, input_dim * num_mixtures),
        )
        self.kappa_predictor = nn.Sequential(
            FusedMLP(dim, dropout=0.0, hidden_layer_multiplier=2, activation=nn.GELU),
            nn.Linear(dim, num_mixtures),
            torch.nn.Softplus(),
        )
        self.mixture_weights = nn.Sequential(
            FusedMLP(dim, dropout=0.0, hidden_layer_multiplier=2, activation=nn.GELU),
            nn.Linear(dim, num_mixtures),
            torch.nn.Softmax(dim=-1),
        )
        self.num_mixtures = num_mixtures
        # Learned starting vector shared by every sample. The randn values are
        # overwritten right away by the truncated normal init; the call is kept
        # so the random number stream (and seeded initialisation) is unchanged.
        self.init_registers = torch.nn.Parameter(torch.randn(dim), requires_grad=True)
        torch.nn.init.trunc_normal_(
            self.init_registers, std=0.02, a=-2 * 0.02, b=2 * 0.02
        )

    def forward(self, batch):
        """Return ``(mu, kappa, weights)`` for the conditioning in ``batch["emb"]``.

        Args:
            batch: mapping with key ``"emb"``; its first axis is used as the
                batch size.

        Returns:
            ``mu``: ``(batch, num_mixtures, input_dim)`` directions, each
            normalised to unit L2 norm (an all-zero prediction stays zero
            instead of becoming NaN).
            ``kappa``: Softplus outputs, last axis of size ``num_mixtures``.
            ``weights``: Softmax outputs, last axis of size ``num_mixtures``.
        """
        cond = self.cond_mapper(batch["emb"])
        x = self.init_registers.unsqueeze(0).repeat(cond.shape[0], 1)
        for block in self.blocks:
            x = block(x, cond)
        gamma_last, mu_last = self.final_adaln(cond).chunk(2, dim=-1)
        x = (1 + gamma_last) * self.final_ln(x) + mu_last
        # Split the flat head output into one direction per component:
        # (batch, num_mixtures * input_dim) -> (batch, num_mixtures, input_dim),
        # with the component index as the outer factor.
        mu = self.mu_predictor(x)
        mu = mu.reshape(mu.shape[0], self.num_mixtures, -1)
        # F.normalize divides by max(norm, eps): same result as ``mu / mu.norm``
        # for any non-degenerate vector, but no 0/0 -> NaN on a zero vector.
        mu = F.normalize(mu, dim=-1)
        kappa = self.kappa_predictor(x)
        weights = self.mixture_weights(x)
        return mu, kappa, weights