# movie-diffusion / model.py
import torch
import torch.nn as nn
from einops import rearrange
from collections import defaultdict
import numpy as np
import math
from tqdm import tqdm
from PIL import Image
class GaussianDiffusion:
def __init__(self, model, noise_steps, beta_0, beta_T, image_size, channels=3, schedule="linear"):
"""
suggested betas for:
* linear schedule: 1e-4, 0.02
model: the model to be trained (nn.Module)
noise_steps: the number of steps to apply noise (int)
beta_0: the initial value of beta (float)
beta_T: the final value of beta (float)
image_size: the size of the image (int, int)
"""
self.device = 'cpu'
self.channels = channels
self.model = model
self.noise_steps = noise_steps
self.beta_0 = beta_0
self.beta_T = beta_T
self.image_size = image_size
self.betas = self.beta_schedule(schedule=schedule)
self.alphas = 1.0 - self.betas
# cumulative product of alphas, so we can optimize forward process calculation
self.alpha_hat = torch.cumprod(self.alphas, dim=0)
    def beta_schedule(self, schedule="cosine"):
        if schedule == "linear":
            return torch.linspace(self.beta_0, self.beta_T, self.noise_steps).to(self.device)
        elif schedule == "cosine":
            return self.betas_for_cosine(self.noise_steps)
        elif schedule == "sigmoid":
            return self.betas_for_sigmoid(self.noise_steps)
        else:
            raise ValueError(f"Unknown beta schedule: {schedule}")
@staticmethod
def sigmoid(x):
return 1 / (1 + np.exp(-x))
    def betas_for_sigmoid(self, num_diffusion_timesteps, start=-3, end=3, tau=1.0, clip_min=1e-9):
        # sigmoid schedule: squash a sigmoid ramp into (clip_min, 0.2] and reverse it
        v_start = self.sigmoid(start / tau)
        v_end = self.sigmoid(end / tau)
        betas = []
        for t in range(num_diffusion_timesteps):
            t_float = t / num_diffusion_timesteps
            output0 = self.sigmoid((t_float * (end - start) + start) / tau)
            output = (v_end - output0) / (v_end - v_start)
            betas.append(np.clip(output * 0.2, clip_min, 0.2))
        return torch.flip(torch.tensor(betas).to(self.device), dims=[0]).float()
    def betas_for_cosine(self, num_steps, start=0, end=1, tau=1, clip_min=1e-9):
        # cosine schedule, scaled into (clip_min, 0.2] and reversed
        v_start = math.cos(start * math.pi / 2) ** (2 * tau)
        v_end = math.cos(end * math.pi / 2) ** (2 * tau)
        betas = []
        for t in range(num_steps):
            t_float = t / num_steps
            output = math.cos((t_float * (end - start) + start) * math.pi / 2) ** (2 * tau)
            output = (v_end - output) / (v_end - v_start)
            betas.append(np.clip(output * 0.2, clip_min, 0.2))
        return torch.flip(torch.tensor(betas).to(self.device), dims=[0]).float()
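    # Illustrative check (a sketch, not part of the original code): the schedule builders can be
    # compared by constructing the wrapper without a model and inspecting alpha_hat, e.g.
    #
    #   gd = GaussianDiffusion(None, noise_steps=1000, beta_0=1e-4, beta_T=0.02,
    #                          image_size=(64, 64), schedule="cosine")
    #   gd.alpha_hat[::200]  # decreases monotonically from ~1 toward ~0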
def sample_time_steps(self, batch_size=1):
return torch.randint(0, self.noise_steps, (batch_size,)).to(self.device)
    def to(self, device):
        self.device = device
        self.betas = self.betas.to(device)
        self.alphas = self.alphas.to(device)
        self.alpha_hat = self.alpha_hat.to(device)
        return self
def q(self, x, t):
"""
Forward process
"""
pass
def p(self, x, t):
"""
Backward process
"""
pass
    def apply_noise(self, x, t):
        # closed form of the forward process q(x_t | x_0) from the DDPM paper:
        # x_t = sqrt(alpha_hat_t) * x_0 + sqrt(1 - alpha_hat_t) * epsilon
        # force x to have a leading batch dimension
        if len(x.shape) == 3:
            x = x.unsqueeze(0)
        if isinstance(t, int):
            t = torch.tensor([t])
        t = t.to(self.alpha_hat.device)
        sqrt_alpha_hat = torch.sqrt(self.alpha_hat[t])
        sqrt_one_minus_alpha_hat = torch.sqrt(1.0 - self.alpha_hat[t])
        # standard normal noise
        epsilon = torch.randn_like(x).to(self.device)
        # scale each image in the batch by its per-timestep coefficients
        noisy_image = (
            torch.einsum("b,bcwh->bcwh", sqrt_alpha_hat, x.to(self.device))
            + torch.einsum("b,bcwh->bcwh", sqrt_one_minus_alpha_hat, epsilon)
        )
        # return the noisy image and the noise that was added to it
        return noisy_image, epsilon
@staticmethod
def normalize_image(x):
# normalize image to [-1, 1]
return x / 255.0 * 2.0 - 1.0
@staticmethod
def denormalize_image(x):
# denormalize image to [0, 255]
return (x + 1.0) / 2.0 * 255.0
def sample_step(self, x, t, cond):
batch_size = x.shape[0]
device = x.device
z = torch.randn_like(x) if t >= 1 else torch.zeros_like(x)
z = z.to(device)
alpha = self.alphas[t]
one_over_sqrt_alpha = 1.0 / torch.sqrt(alpha)
one_minus_alpha = 1.0 - alpha
sqrt_one_minus_alpha_hat = torch.sqrt(1.0 - self.alpha_hat[t])
beta_hat = (1 - self.alpha_hat[t-1]) / (1 - self.alpha_hat[t]) * self.betas[t]
beta = self.betas[t]
# should we reshape the params to (batch_size, 1, 1, 1) ?
# we can either use beta_hat or beta_t
# std = torch.sqrt(beta_hat)
std = torch.sqrt(beta)
        # x_{t-1} = mean + std * z, where
        # mean = 1/sqrt(alpha_t) * (x_t - (1 - alpha_t) / sqrt(1 - alpha_hat_t) * eps_theta(x_t, t))
if cond is not None:
predicted_noise = self.model(x, torch.tensor([t]).repeat(batch_size).to(device), cond)
else:
predicted_noise = self.model(x, torch.tensor([t]).repeat(batch_size).to(device))
mean = one_over_sqrt_alpha * (x - one_minus_alpha / sqrt_one_minus_alpha_hat * predicted_noise)
x_t_minus_1 = mean + std * z
return x_t_minus_1
def sample(self, num_samples, show_progress=True):
"""
Sample from the model
"""
cond = None
if self.model.is_conditional:
# cond is arange()
assert num_samples <= self.model.num_classes, "num_samples must be less than or equal to the number of classes"
cond = torch.arange(self.model.num_classes)[:num_samples].to(self.device)
cond = rearrange(cond, 'i -> i ()')
self.model.eval()
image_versions = []
with torch.no_grad():
x = torch.randn(num_samples, self.channels, *self.image_size).to(self.device)
it = reversed(range(1, self.noise_steps))
if show_progress:
it = tqdm(it)
for t in it:
image_versions.append(self.denormalize_image(torch.clip(x, -1, 1)).clone().squeeze(0))
x = self.sample_step(x, t, cond)
self.model.train()
x = torch.clip(x, -1.0, 1.0)
return self.denormalize_image(x), image_versions
def validate(self, dataloader):
"""
Calculate the loss on the validation set
"""
self.model.eval()
acc_loss = 0
with torch.no_grad():
for (image, cond) in dataloader:
t = self.sample_time_steps(batch_size=image.shape[0])
noisy_image, added_noise = self.apply_noise(image, t)
noisy_image = noisy_image.to(self.device)
added_noise = added_noise.to(self.device)
cond = cond.to(self.device)
predicted_noise = self.model(noisy_image, t, cond)
loss = nn.MSELoss()(predicted_noise, added_noise)
acc_loss += loss.item()
self.model.train()
return acc_loss / len(dataloader)
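# Illustrative training step (a sketch mirroring `validate` above; `diffusion`, `optimizer`,
# `dataloader` and `device` are assumed to be defined elsewhere, not in this file):
#
#   for image, cond in dataloader:
#       t = diffusion.sample_time_steps(batch_size=image.shape[0])
#       noisy_image, added_noise = diffusion.apply_noise(image, t)
#       predicted_noise = diffusion.model(noisy_image.to(device), t, cond.to(device))
#       loss = nn.MSELoss()(predicted_noise, added_noise.to(device))
#       optimizer.zero_grad()
#       loss.backward()
#       optimizer.step()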
class DiffusionImageAPI:
def __init__(self, diffusion_model):
self.diffusion_model = diffusion_model
def get_noisy_image(self, image, t):
x = torch.tensor(np.array(image))
x = self.diffusion_model.normalize_image(x)
y, _ = self.diffusion_model.apply_noise(x, t)
y = self.diffusion_model.denormalize_image(y)
#print(f"Shape of Image: {y.shape}")
return Image.fromarray(y.squeeze(0).numpy().astype(np.uint8))
def get_noisy_images(self, image, time_steps):
"""
image: the image to be processed PIL.Image
time_steps: the number of time steps to apply noise (int)
"""
return [self.get_noisy_image(image, int(t)) for t in time_steps]
def tensor_to_image(self, tensor):
return Image.fromarray(tensor.cpu().numpy().astype(np.uint8))
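# Example usage of DiffusionImageAPI (illustrative; the file path is a placeholder):
#
#   api = DiffusionImageAPI(diffusion)
#   image = Image.open("example.png").convert("RGB")
#   frames = api.get_noisy_images(image, time_steps=range(0, diffusion.noise_steps, 100))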
str_to_act = defaultdict(lambda: nn.SiLU())
str_to_act.update({
"relu": nn.ReLU(),
"silu": nn.SiLU(),
"gelu": nn.GELU(),
})
class SinusoidalPositionalEncoding(nn.Module):
def __init__(self, dim):
super().__init__()
self.dim = dim
def forward(self, t):
device = t.device
t = t.unsqueeze(-1)
inv_freq = 1.0 / (10000 ** (torch.arange(0, self.dim, 2, device=device).float() / self.dim))
sin_enc = torch.sin(t.repeat(1, self.dim // 2) * inv_freq)
cos_enc = torch.cos(t.repeat(1, self.dim // 2) * inv_freq)
pos_enc = torch.cat([sin_enc, cos_enc], dim=-1)
return pos_enc
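# Shape check (illustrative; the batch size and dim are arbitrary):
#
#   enc = SinusoidalPositionalEncoding(dim=128)
#   enc(torch.randint(0, 1000, (8,))).shape  # -> torch.Size([8, 128])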
class TimeEmbedding(nn.Module):
def __init__(self, model_dim: int, emb_dim: int, act="silu"):
super().__init__()
self.lin = nn.Linear(model_dim, emb_dim)
self.act = str_to_act[act]
self.lin2 = nn.Linear(emb_dim, emb_dim)
def forward(self, x):
x = self.lin(x)
x = self.act(x)
x = self.lin2(x)
return x
class ConvBlock(nn.Module):
def __init__(self, in_channels, out_channels, act="silu", dropout=None, zero=False):
super().__init__()
self.norm = nn.GroupNorm(
num_groups=32,
num_channels=in_channels,
)
self.act = str_to_act[act]
if dropout is not None:
self.dropout = nn.Dropout(dropout)
self.conv = nn.Conv2d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=3,
padding=1,
)
if zero:
self.conv.weight.data.zero_()
def forward(self, x):
x = self.norm(x)
x = self.act(x)
if hasattr(self, "dropout"):
x = self.dropout(x)
x = self.conv(x)
return x
class EmbeddingBlock(nn.Module):
def __init__(self, channels: int, emb_dim: int, act="silu"):
super().__init__()
self.act = str_to_act[act]
self.lin = nn.Linear(emb_dim, channels)
def forward(self, x):
x = self.act(x)
x = self.lin(x)
return x
class ResBlock(nn.Module):
def __init__(self, channels: int, emb_dim: int, dropout: float = 0, out_channels=None):
"""A resblock with a time embedding and an optional change in channel count
"""
if out_channels is None:
out_channels = channels
super().__init__()
self.conv1 = ConvBlock(channels, out_channels)
self.emb = EmbeddingBlock(out_channels, emb_dim)
self.conv2 = ConvBlock(out_channels, out_channels, dropout=dropout, zero=True)
if channels != out_channels:
self.skip_connection = nn.Conv2d(channels, out_channels, kernel_size=1)
else:
self.skip_connection = nn.Identity()
def forward(self, x, t):
original = x
x = self.conv1(x)
t = self.emb(t)
# t: (batch_size, time_embedding_dim) = (batch_size, out_channels)
# x: (batch_size, out_channels, height, width)
# we repeat the time embedding to match the shape of x
t = t.unsqueeze(-1).unsqueeze(-1).repeat(1, 1, x.shape[2], x.shape[3])
x = x + t
x = self.conv2(x)
x = x + self.skip_connection(original)
return x
class SelfAttentionBlock(nn.Module):
def __init__(self, channels, num_heads=1):
super().__init__()
self.channels = channels
self.num_heads = num_heads
self.norm = nn.GroupNorm(32, channels)
self.attention = nn.MultiheadAttention(
embed_dim=channels,
num_heads=num_heads,
dropout=0,
batch_first=True,
bias=True,
)
def forward(self, x):
h, w = x.shape[-2:]
original = x
x = self.norm(x)
x = rearrange(x, "b c h w -> b (h w) c")
x = self.attention(x, x, x)[0]
x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w)
return x + original
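# Note (usage assumption): GroupNorm(32, channels) requires `channels` to be divisible by 32,
# and nn.MultiheadAttention requires `channels % num_heads == 0`. For example:
#
#   attn = SelfAttentionBlock(channels=256, num_heads=4)
#   attn(torch.randn(1, 256, 16, 16)).shape  # -> torch.Size([1, 256, 16, 16])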
class Downsample(nn.Module):
def __init__(self, channels):
super().__init__()
# ddpm uses maxpool
# self.down = nn.MaxPool2d
# iddpm uses strided conv
self.down = nn.Conv2d(
in_channels=channels,
out_channels=channels,
kernel_size=3,
stride=2,
padding=1,
)
def forward(self, x):
return self.down(x)
class DownBlock(nn.Module):
"""According to U-Net paper
'The contracting path follows the typical architecture of a convolutional network.
It consists of the repeated application of two 3x3 convolutions (unpadded convolutions),
each followed by a rectified linear unit (ReLU) and a 2x2 max pooling operation with stride 2
for downsampling. At each downsampling step we double the number of feature channels.'
"""
def __init__(self, in_channels, out_channels, time_embedding_dim, use_attn=False, dropout=0, downsample=True, width=1):
"""in_channels will typically be half of out_channels"""
super().__init__()
self.width = width
self.use_attn = use_attn
self.do_downsample = downsample
self.blocks = nn.ModuleList()
for _ in range(width):
self.blocks.append(ResBlock(
channels=in_channels,
out_channels=out_channels,
emb_dim=time_embedding_dim,
dropout=dropout,
))
if self.use_attn:
self.blocks.append(SelfAttentionBlock(
channels=out_channels,
))
in_channels = out_channels
if self.do_downsample:
self.downsample = Downsample(out_channels)
def forward(self, x, t):
for block in self.blocks:
if isinstance(block, ResBlock):
x = block(x, t)
elif isinstance(block, SelfAttentionBlock):
x = block(x)
residual = x
if self.do_downsample:
x = self.downsample(x)
return x, residual
class Upsample(nn.Module):
def __init__(self, channels):
super().__init__()
self.upsample = nn.Upsample(scale_factor=2)
self.conv = nn.Conv2d(
in_channels=channels,
out_channels=channels,
kernel_size=3,
padding=1,
)
def forward(self, x):
x = self.upsample(x)
x = self.conv(x)
return x
class UpBlock(nn.Module):
"""According to U-Net paper
Every step in the expansive path consists of an upsampling of the feature map followed by a 2x2
convolution (“up-convolution”) that halves the number of feature channels, a concatenation with
the correspondingly cropped feature map from the contracting path, and two 3x3 convolutions,
each followed by a ReLU.
"""
def __init__(self, in_channels, out_channels, time_embedding_dim, use_attn=False, dropout=0, upsample=True, width=1):
"""in_channels will typically be double of out_channels
"""
super().__init__()
self.use_attn = use_attn
self.do_upsample = upsample
self.blocks = nn.ModuleList()
for _ in range(width):
self.blocks.append(ResBlock(
channels=in_channels,
out_channels=out_channels,
emb_dim=time_embedding_dim,
dropout=dropout,
))
if self.use_attn:
self.blocks.append(SelfAttentionBlock(
channels=out_channels,
))
in_channels = out_channels
if self.do_upsample:
self.upsample = Upsample(out_channels)
def forward(self, x, t):
for block in self.blocks:
if isinstance(block, ResBlock):
x = block(x, t)
elif isinstance(block, SelfAttentionBlock):
x = block(x)
if self.do_upsample:
x = self.upsample(x)
return x
class Bottleneck(nn.Module):
def __init__(self, channels, dropout, time_embedding_dim):
super().__init__()
in_channels = channels
out_channels = channels
self.resblock_1 = ResBlock(
channels=in_channels,
out_channels=out_channels,
dropout=dropout,
emb_dim=time_embedding_dim
)
self.attention_block = SelfAttentionBlock(
channels=out_channels,
)
self.resblock_2 = ResBlock(
channels=out_channels,
out_channels=out_channels,
dropout=dropout,
emb_dim=time_embedding_dim
)
def forward(self, x, t):
x = self.resblock_1(x, t)
x = self.attention_block(x)
x = self.resblock_2(x, t)
return x
class Unet(nn.Module):
def __init__(
self,
image_channels=3,
res_block_width=2,
starting_channels=128,
dropout=0,
channel_mults=(1, 2, 2, 4, 4),
attention_layers=(False, False, False, True, False)
):
super().__init__()
self.is_conditional = False
self.image_channels = image_channels
self.starting_channels = starting_channels
time_embedding_dim = 4 * starting_channels
self.time_encoding = SinusoidalPositionalEncoding(dim=starting_channels)
self.time_embedding = TimeEmbedding(model_dim=starting_channels, emb_dim=time_embedding_dim)
        self.input = nn.Conv2d(image_channels, starting_channels, kernel_size=3, padding=1)
current_channel_count = starting_channels
input_channel_counts = []
self.contracting_path = nn.ModuleList([])
for i, channel_multiplier in enumerate(channel_mults):
is_last_layer = i == len(channel_mults) - 1
next_channel_count = channel_multiplier * starting_channels
self.contracting_path.append(DownBlock(
in_channels=current_channel_count,
out_channels=next_channel_count,
time_embedding_dim=time_embedding_dim,
use_attn=attention_layers[i],
dropout=dropout,
downsample=not is_last_layer,
width=res_block_width,
))
current_channel_count = next_channel_count
input_channel_counts.append(current_channel_count)
self.bottleneck = Bottleneck(channels=current_channel_count, time_embedding_dim=time_embedding_dim, dropout=dropout)
self.expansive_path = nn.ModuleList([])
for i, channel_multiplier in enumerate(reversed(channel_mults)):
next_channel_count = channel_multiplier * starting_channels
self.expansive_path.append(UpBlock(
in_channels=current_channel_count + input_channel_counts.pop(),
out_channels=next_channel_count,
time_embedding_dim=time_embedding_dim,
use_attn=list(reversed(attention_layers))[i],
dropout=dropout,
upsample=i != len(channel_mults) - 1,
width=res_block_width,
))
current_channel_count = next_channel_count
last_conv = nn.Conv2d(
in_channels=starting_channels,
out_channels=image_channels,
kernel_size=3,
padding=1,
)
last_conv.weight.data.zero_()
self.head = nn.Sequential(
nn.GroupNorm(32, starting_channels),
nn.SiLU(),
last_conv,
)
def forward(self, x, t):
t = self.time_encoding(t)
return self._forward(x, t)
def _forward(self, x, t):
t = self.time_embedding(t)
x = self.input(x)
residuals = []
for contracting_block in self.contracting_path:
x, residual = contracting_block(x, t)
residuals.append(residual)
x = self.bottleneck(x, t)
for expansive_block in self.expansive_path:
# Add the residual
residual = residuals.pop()
x = torch.cat([x, residual], dim=1)
x = expansive_block(x, t)
x = self.head(x)
return x
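# Quick shape check for the unconditional U-Net (illustrative; with the default channel_mults the
# spatial size must be divisible by 2 ** (len(channel_mults) - 1) = 16):
#
#   unet = Unet()
#   x = torch.randn(2, 3, 64, 64)
#   t = torch.randint(0, 1000, (2,))
#   unet(x, t).shape  # -> torch.Size([2, 3, 64, 64])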
class ConditionalUnet(nn.Module):
def __init__(self, unet, num_classes):
super().__init__()
self.is_conditional = True
self.unet = unet
self.num_classes = num_classes
self.class_embedding = nn.Embedding(num_classes, unet.starting_channels)
def forward(self, x, t, cond=None):
# cond: (batch_size, n), where n is the number of classes that we are conditioning on
t = self.unet.time_encoding(t)
if cond is not None:
cond = self.class_embedding(cond)
# sum across the classes so we get a single vector representing the set of classes
cond = cond.sum(dim=1)
t += cond
return self.unet._forward(x, t)
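if __name__ == "__main__":
    # Minimal smoke test (a sketch, not the project's training or sampling entry point;
    # all hyperparameters below are illustrative assumptions chosen to keep the run small).
    unet = Unet(
        image_channels=3,
        res_block_width=1,
        starting_channels=32,
        channel_mults=(1, 2),
        attention_layers=(False, False),
    )
    model = ConditionalUnet(unet, num_classes=10)
    diffusion = GaussianDiffusion(
        model, noise_steps=50, beta_0=1e-4, beta_T=0.02, image_size=(32, 32), schedule="linear"
    )
    samples, _ = diffusion.sample(num_samples=2, show_progress=False)
    print(samples.shape)  # expected: torch.Size([2, 3, 32, 32])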