|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import math |
|
|
|
import numpy as np |
|
import torch |
|
from torch import nn |
|
|
|
|
|
def get_timestep_embedding(
    timesteps: torch.Tensor,
    embedding_dim: int,
    flip_sin_to_cos: bool = False,
    downscale_freq_shift: float = 1,
    scale: float = 1,
    max_period: int = 10000,
):
    """
    Create sinusoidal timestep embeddings.

    This matches the implementation in Denoising Diffusion Probabilistic Models. With
    ``half_dim = embedding_dim // 2`` the frequencies are
    ``f_i = exp(-log(max_period) * i / (half_dim - downscale_freq_shift))`` for ``i`` in ``[0, half_dim)``,
    and each row of the output is ``[sin(scale * t * f) | cos(scale * t * f)]``.

    :param timesteps: a 1-D Tensor of N indices, one per batch element. These may be fractional.
    :param embedding_dim: the dimension of the output. When odd, a zero column is appended on the right.
    :param flip_sin_to_cos: if True, return ``[cos | sin]`` instead of ``[sin | cos]``.
    :param downscale_freq_shift: subtracted from ``half_dim`` in the denominator of the frequency exponent.
    :param scale: multiplier applied to ``timesteps * frequency`` before taking sin/cos.
    :param max_period: controls the minimum frequency of the embeddings.
    :return: an [N x embedding_dim] Tensor of positional embeddings.
    :raises ValueError: if ``half_dim > 0`` and ``half_dim == downscale_freq_shift``. The frequency denominator
        would be zero and the result would silently contain NaN, e.g. ``embedding_dim=2`` with the default shift of 1.
    """
    assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"

    half_dim = embedding_dim // 2
    denominator = half_dim - downscale_freq_shift
    # With an empty frequency range (half_dim == 0) a zero denominator is harmless. Otherwise the i == 0
    # exponent becomes 0/0 = NaN and every other frequency collapses to exp(-inf) = 0.
    if half_dim > 0 and denominator == 0:
        raise ValueError(
            f"embedding_dim={embedding_dim} with downscale_freq_shift={downscale_freq_shift} gives a zero "
            "frequency denominator (embedding_dim // 2 - downscale_freq_shift == 0)."
        )

    exponent = -math.log(max_period) * torch.arange(
        start=0, end=half_dim, dtype=torch.float32, device=timesteps.device
    )
    exponent = exponent / denominator

    # Outer product: one row per timestep, one column per frequency.
    emb = torch.exp(exponent)
    emb = timesteps[:, None].float() * emb[None, :]

    # scale embeddings
    emb = scale * emb

    # concat sine and cosine embeddings
    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)

    # flip sine and cosine embeddings
    if flip_sin_to_cos:
        emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)

    # zero pad the last column so the width matches an odd embedding_dim
    if embedding_dim % 2 == 1:
        emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
    return emb
|
|
|
def zero_module(module):
    """
    Set every parameter of ``module`` to zero in place and hand the same module back.

    The write happens outside autograd tracking, so parameters keep their ``requires_grad`` flag and
    remain trainable afterwards.
    """
    with torch.no_grad():
        for param in module.parameters():
            param.zero_()
    return module
|
|
|
class TimestepEmbedding(nn.Module):
    """
    Two-layer MLP applied to a timestep embedding: ``linear_2(act(linear_1(sample)))``.

    Args:
        in_channels (`int`): size of the last dimension of ``sample``.
        time_embed_dim (`int`): output width of ``linear_1``, i.e. the hidden width.
        act_fn (`str` or `None`): activation between the two linear layers. One of ``"silu"``, ``"swish"``
            (alias of SiLU), ``"mish"``, ``"gelu"``, ``"relu"``, or ``None`` for no activation. Any other
            value raises ``ValueError``.
        out_dim (`int`, *optional*): output width of ``linear_2``; defaults to ``time_embed_dim``.
        time_cond_proj_dim (`int`, *optional*): when given, a bias-free linear projection from a condition of
            this size to ``in_channels`` is created with all weights zeroed. ``forward`` adds its output to
            ``sample`` before ``linear_1``.
    """

    # Supported activation names -> module constructors.
    _ACTIVATIONS = {
        "silu": nn.SiLU,
        "swish": nn.SiLU,
        "mish": nn.Mish,
        "gelu": nn.GELU,
        "relu": nn.ReLU,
    }

    def __init__(self, in_channels: int, time_embed_dim: int, act_fn: str = "silu", out_dim: int = None, time_cond_proj_dim=None):
        super().__init__()

        self.linear_1 = nn.Linear(in_channels, time_embed_dim)

        # An unrecognised name is an error rather than silently meaning "no activation".
        if act_fn is None:
            self.act = None
        elif act_fn in self._ACTIVATIONS:
            self.act = self._ACTIVATIONS[act_fn]()
        else:
            raise ValueError(
                f"Unsupported act_fn {act_fn!r}; expected one of {sorted(self._ACTIVATIONS)} or None."
            )

        if time_cond_proj_dim is not None:
            # Zero-initialised, so adding it has no effect on the output until it is trained.
            self.cond_proj = zero_module(nn.Linear(time_cond_proj_dim, in_channels, bias=False))
        else:
            self.cond_proj = None

        time_embed_dim_out = out_dim if out_dim is not None else time_embed_dim
        self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out)

    def forward(self, sample, condition=None):
        """Project ``sample`` (optionally shifted by the projected ``condition``) through the MLP."""
        if condition is not None:
            if self.cond_proj is None:
                raise TypeError(
                    "TimestepEmbedding.forward received a condition, but the module was built without "
                    "time_cond_proj_dim, so there is no projection for it."
                )
            sample = sample + self.cond_proj(condition)

        sample = self.linear_1(sample)
        if self.act is not None:
            sample = self.act(sample)
        return self.linear_2(sample)
|
|
|
|
|
class Timesteps(nn.Module):
    """
    Parameter-free module that turns a 1-D tensor of timesteps into sinusoidal embeddings.

    It stores the embedding width and the sin/cos layout options and forwards them to
    ``get_timestep_embedding``. ``scale`` and ``max_period`` are not passed, so that function's defaults apply.
    """

    def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float):
        super().__init__()
        self.num_channels, self.flip_sin_to_cos, self.downscale_freq_shift = (
            num_channels,
            flip_sin_to_cos,
            downscale_freq_shift,
        )

    def forward(self, timesteps):
        return get_timestep_embedding(
            timesteps,
            self.num_channels,
            downscale_freq_shift=self.downscale_freq_shift,
            flip_sin_to_cos=self.flip_sin_to_cos,
        )
|
|
|
|
|
class GaussianFourierProjection(nn.Module):
    """
    Gaussian Fourier embeddings for noise levels.

    A fixed (non-trainable) vector of ``embedding_size`` Gaussian frequencies, multiplied by ``scale``, maps each
    input value ``x`` to ``[sin(2*pi*x*w) | cos(2*pi*x*w)]``. The halves are swapped when ``flip_sin_to_cos``
    is set. When ``log`` is True the input is passed through ``torch.log`` first.
    """

    def __init__(
        self, embedding_size: int = 256, scale: float = 1.0, set_W_to_weight=True, log=True, flip_sin_to_cos=False
    ):
        super().__init__()
        # "weight" is registered first; this draw happens even when it is replaced below, so the RNG
        # state after construction depends on set_W_to_weight.
        self.weight = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False)
        self.log = log
        self.flip_sin_to_cos = flip_sin_to_cos

        if not set_W_to_weight:
            return

        # A second parameter named "W" is drawn and aliased as "weight", so state_dict carries both keys.
        # NOTE(review): presumably kept so checkpoints saved with a "W" key still load — confirm.
        self.W = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False)
        self.weight = self.W

    def forward(self, x):
        inputs = torch.log(x) if self.log else x
        angles = inputs[:, None] * self.weight[None, :] * 2 * np.pi
        first, second = (torch.cos, torch.sin) if self.flip_sin_to_cos else (torch.sin, torch.cos)
        return torch.cat([first(angles), second(angles)], dim=-1)
|
|
|
|
|
class ImagePositionalEmbeddings(nn.Module):
    """
    Converts latent image classes into vector embeddings. Sums the vector embeddings with positional embeddings for the
    height and width of the latent space.

    For more details, see figure 10 of the dall-e paper: https://arxiv.org/abs/2102.12092

    For VQ-diffusion:

    Output vector embeddings are used as input for the transformer.

    Note that the vector embeddings for the transformer are different from the vector embeddings from the VQVAE.

    Args:
        num_embed (`int`):
            Number of embeddings for the latent pixels embeddings.
        height (`int`):
            Height of the latent image i.e. the number of height embeddings.
        width (`int`):
            Width of the latent image i.e. the number of width embeddings.
        embed_dim (`int`):
            Dimension of the produced vector embeddings. Used for the latent pixel, height, and width embeddings.
    """

    def __init__(
        self,
        num_embed: int,
        height: int,
        width: int,
        embed_dim: int,
    ):
        super().__init__()

        self.height = height
        self.width = width
        self.num_embed = num_embed
        self.embed_dim = embed_dim

        # Creation order is kept so random initialisation draws happen in the same sequence.
        self.emb = nn.Embedding(self.num_embed, embed_dim)
        self.height_emb = nn.Embedding(self.height, embed_dim)
        self.width_emb = nn.Embedding(self.width, embed_dim)

    def forward(self, index):
        # NOTE(review): index is presumably (batch, seq_len) latent class ids with seq_len <= height * width.
        token_emb = self.emb(index)
        device = index.device

        rows = self.height_emb(torch.arange(self.height, device=device))  # (height, embed_dim)
        cols = self.width_emb(torch.arange(self.width, device=device))  # (width, embed_dim)

        # Every (row, col) cell gets row_embedding + col_embedding, flattened in row-major order.
        grid = rows[:, None, :] + cols[None, :, :]  # (height, width, embed_dim)
        pos_emb = grid.view(1, self.height * self.width, -1)

        # Only the first seq_len positions are used when the sequence is shorter than the full grid.
        return token_emb + pos_emb[:, : token_emb.shape[1], :]
|
|