Spaces:
Paused
Paused
from typing import Any | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from einops import rearrange | |
from collections import defaultdict | |
import torch as th | |
import numpy as np | |
import math | |
from tqdm import tqdm | |
from PIL import Image | |
class GaussianDiffusion:
    """DDPM-style Gaussian diffusion process wrapped around a noise-prediction model.

    Holds the beta / alpha / cumulative-alpha schedules, applies the closed-form
    forward (noising) process, runs ancestral sampling and computes a validation loss.
    """

    def __init__(self, model, noise_steps, beta_0, beta_T, image_size, channels=3, schedule="linear"):
        """
        Suggested betas for the linear schedule: beta_0=1e-4, beta_T=0.02.

        Args:
            model: the noise-prediction network (nn.Module). ``sample`` reads
                ``model.is_conditional`` and, when it is True, ``model.num_classes``.
            noise_steps: number of diffusion steps T (int).
            beta_0: first beta of the linear schedule (float).
            beta_T: last beta of the linear schedule (float).
            image_size: spatial size as a pair of ints; sampled tensors have shape
                ``(num_samples, channels, *image_size)``.
            channels: number of image channels produced by ``sample``.
            schedule: "linear", "cosine" or "sigmoid". The cosine and sigmoid schedules
                ignore beta_0/beta_T and cap betas at 0.2.
        """
        self.device = 'cpu'
        self.channels = channels
        self.model = model
        self.noise_steps = noise_steps
        self.beta_0 = beta_0
        self.beta_T = beta_T
        self.image_size = image_size
        self.betas = self.beta_schedule(schedule=schedule)
        self.alphas = 1.0 - self.betas
        # Cumulative product of alphas: lets the forward process jump straight to step t.
        self.alpha_hat = torch.cumprod(self.alphas, dim=0)

    def beta_schedule(self, schedule="cosine"):
        """Return a 1-D float tensor of ``noise_steps`` betas for the named schedule.

        Raises:
            ValueError: if ``schedule`` is not "linear", "cosine" or "sigmoid".
        """
        if schedule == "linear":
            return torch.linspace(self.beta_0, self.beta_T, self.noise_steps).to(self.device)
        if schedule == "cosine":
            return self.betas_for_cosine(self.noise_steps)
        if schedule == "sigmoid":
            return self.betas_for_sigmoid(self.noise_steps)
        raise ValueError(f"Unknown beta schedule: {schedule!r}")

    @staticmethod
    def sigmoid(x):
        """Logistic function ``1 / (1 + exp(-x))`` computed with NumPy."""
        return 1 / (1 + np.exp(-x))

    def betas_for_sigmoid(self, num_diffusion_timesteps, start=-3, end=3, tau=1.0, clip_min=1e-9):
        """Betas shaped by a sigmoid over [start, end], scaled into [clip_min, 0.2].

        The curve is normalised to go from 1 (at ``start``) toward 0 (at ``end``), scaled
        by 0.2, clipped, and finally flipped so the betas increase with t.
        """
        betas = []
        v_start = self.sigmoid(start / tau)
        v_end = self.sigmoid(end / tau)
        for t in range(num_diffusion_timesteps):
            t_float = float(t / num_diffusion_timesteps)
            output0 = self.sigmoid((t_float * (end - start) + start) / tau)
            output = (v_end - output0) / (v_end - v_start)
            betas.append(np.clip(output * .2, clip_min, .2))
        return torch.flip(torch.tensor(betas).to(self.device), dims=[0]).float()

    def betas_for_cosine(self, num_steps, start=0, end=1, tau=1, clip_min=1e-9):
        """Betas shaped by ``cos(s * pi / 2) ** (2 * tau)``, scaled into [clip_min, 0.2].

        The curve is normalised to go from 1 (at ``start``) toward 0 (at ``end``), scaled
        by 0.2, clipped, and finally flipped so the betas increase with t.
        """
        v_start = math.cos(start * math.pi / 2) ** (2 * tau)
        # Exponent must be parenthesised: ``** 2 * tau`` would square first, then multiply.
        v_end = math.cos(end * math.pi / 2) ** (2 * tau)
        betas = []
        for t in range(num_steps):
            t_float = float(t) / num_steps
            output = math.cos((t_float * (end - start) + start) * math.pi / 2) ** (2 * tau)
            output = (v_end - output) / (v_end - v_start)
            betas.append(np.clip(output * .2, clip_min, .2))
        return torch.flip(torch.tensor(betas).to(self.device), dims=[0]).float()

    def sample_time_steps(self, batch_size=1):
        """Draw ``batch_size`` timesteps uniformly from ``[0, noise_steps)``."""
        return torch.randint(0, self.noise_steps, (batch_size,)).to(self.device)

    def to(self, device):
        """Move the schedule tensors to ``device`` and use it for new tensors.

        Returns ``self`` so calls can be chained.
        """
        self.device = device
        self.betas = self.betas.to(device)
        self.alphas = self.alphas.to(device)
        self.alpha_hat = self.alpha_hat.to(device)
        return self

    def q(self, x, t):
        """
        Forward process (placeholder: not implemented, returns None).
        """
        pass

    def p(self, x, t):
        """
        Backward process (placeholder: not implemented, returns None).
        """
        pass

    def apply_noise(self, x, t):
        """Noise ``x`` to timestep ``t`` in one shot using the closed-form q(x_t | x_0).

        Args:
            x: a single sample (3-D, a batch dim is added) or a batch whose first dim is
                the batch; the layout of the remaining dims does not matter.
            t: an int, a sequence of ints, or an integer tensor with one entry per batch
                element (a single timestep broadcasts over the batch).

        Returns:
            ``(noisy_image, epsilon)``, both on ``self.device``, where
            ``noisy_image = sqrt(alpha_hat_t) * x + sqrt(1 - alpha_hat_t) * epsilon``.
        """
        if x.dim() == 3:
            x = x.unsqueeze(0)
        x = x.to(self.device)
        t = torch.as_tensor(t, dtype=torch.long, device=self.alpha_hat.device).reshape(-1)
        alpha_hat_t = self.alpha_hat[t]
        # (batch,) -> (batch, 1, 1, ...) so the per-sample scale broadcasts over all other dims.
        scale_shape = (-1,) + (1,) * (x.dim() - 1)
        sqrt_alpha_hat = torch.sqrt(alpha_hat_t).to(self.device).view(scale_shape)
        sqrt_one_minus_alpha_hat = torch.sqrt(1.0 - alpha_hat_t).to(self.device).view(scale_shape)
        # Standard normal noise; returned so the model can be trained to predict it.
        epsilon = torch.randn_like(x)
        noisy_image = sqrt_alpha_hat * x + sqrt_one_minus_alpha_hat * epsilon
        return noisy_image, epsilon

    @staticmethod
    def normalize_image(x):
        """Map pixel values from [0, 255] to [-1, 1]."""
        return x / 255.0 * 2.0 - 1.0

    @staticmethod
    def denormalize_image(x):
        """Map values from [-1, 1] back to [0, 255]."""
        return (x + 1.0) / 2.0 * 255.0

    def sample_step(self, x, t, cond):
        """One reverse (denoising) step x_t -> x_{t-1} (DDPM Algorithm 2).

        Args:
            x: current batch x_t.
            t: integer timestep of ``x``.
            cond: conditioning passed to the model, or None for an unconditional call.
        """
        batch_size = x.shape[0]
        device = x.device
        # No noise is injected on the final step (t == 0).
        z = torch.randn_like(x) if t >= 1 else torch.zeros_like(x)
        alpha = self.alphas[t]
        beta = self.betas[t]
        one_over_sqrt_alpha = 1.0 / torch.sqrt(alpha)
        sqrt_one_minus_alpha_hat = torch.sqrt(1.0 - self.alpha_hat[t])
        # sigma_t^2 = beta_t; the posterior variance beta_tilde_t would be the alternative.
        std = torch.sqrt(beta)
        time = torch.full((batch_size,), t, dtype=torch.long, device=device)
        if cond is not None:
            predicted_noise = self.model(x, time, cond)
        else:
            predicted_noise = self.model(x, time)
        mean = one_over_sqrt_alpha * (x - (1.0 - alpha) / sqrt_one_minus_alpha_hat * predicted_noise)
        return mean + std * z

    def sample(self, num_samples, show_progress=True):
        """Generate ``num_samples`` images by running every reverse step T-1, ..., 0.

        For a conditional model, sample i is conditioned on class i.

        Returns:
            ``(images, image_versions)``: the final images clipped and mapped to
            [0, 255], and the list of intermediate states (same mapping, one per step,
            taken before that step, with dim 0 squeezed when it has size 1).
        """
        cond = None
        if self.model.is_conditional:
            assert num_samples <= self.model.num_classes, "num_samples must be less than or equal to the number of classes"
            # (num_samples,) -> (num_samples, 1): one class label per sample.
            cond = torch.arange(num_samples, device=self.device).unsqueeze(-1)
        was_training = self.model.training
        self.model.eval()
        image_versions = []
        try:
            with torch.no_grad():
                x = torch.randn(num_samples, self.channels, *self.image_size).to(self.device)
                # Include t == 0 so the final noise-free step is actually applied.
                steps = reversed(range(self.noise_steps))
                if show_progress:
                    steps = tqdm(steps)
                for t in steps:
                    image_versions.append(self.denormalize_image(torch.clip(x, -1, 1)).squeeze(0))
                    x = self.sample_step(x, t, cond)
        finally:
            self.model.train(was_training)
        x = torch.clip(x, -1.0, 1.0)
        return self.denormalize_image(x), image_versions

    def validate(self, dataloader):
        """Return the mean per-batch noise-prediction MSE over ``dataloader``.

        ``dataloader`` yields ``(image, cond)`` pairs; ``cond`` is only passed to the
        model when ``model.is_conditional`` is True. An empty loader returns 0.0.
        """
        was_training = self.model.training
        conditional = getattr(self.model, "is_conditional", False)
        self.model.eval()
        acc_loss = 0.0
        num_batches = 0
        try:
            with torch.no_grad():
                for image, cond in dataloader:
                    t = self.sample_time_steps(batch_size=image.shape[0])
                    noisy_image, added_noise = self.apply_noise(image, t)
                    if conditional:
                        predicted_noise = self.model(noisy_image, t, cond.to(self.device))
                    else:
                        predicted_noise = self.model(noisy_image, t)
                    acc_loss += F.mse_loss(predicted_noise, added_noise).item()
                    num_batches += 1
        finally:
            self.model.train(was_training)
        return acc_loss / num_batches if num_batches else 0.0
class DiffusionImageAPI:
    """Convenience wrapper that turns images into noised PIL images via a diffusion model."""

    def __init__(self, diffusion_model):
        # Object providing normalize_image / apply_noise / denormalize_image.
        self.diffusion_model = diffusion_model

    def get_noisy_image(self, image, t):
        """Return ``image`` noised to timestep ``t`` as a PIL image.

        Args:
            image: anything ``np.array`` accepts with 0-255 pixel values (e.g. a PIL image).
            t: timestep passed to ``diffusion_model.apply_noise``.
        """
        x = torch.tensor(np.array(image))
        x = self.diffusion_model.normalize_image(x)
        y, _ = self.diffusion_model.apply_noise(x, t)
        y = self.diffusion_model.denormalize_image(y)
        # Drop the batch dim that apply_noise adds to single samples.
        return self.tensor_to_image(y.squeeze(0))

    def get_noisy_images(self, image, time_steps):
        """
        image: the image to be processed (PIL.Image)
        time_steps: iterable of timesteps; one noisy image is produced per entry
        """
        return [self.get_noisy_image(image, int(t)) for t in time_steps]

    def tensor_to_image(self, tensor):
        """Convert a tensor of 0-255 values to a PIL image (PIL expects (H, W) or (H, W, C)).

        Values are clamped to [0, 255] and truncated to uint8.
        """
        # Clamp first: NumPy's float->uint8 cast does not saturate out-of-range values,
        # so added noise that leaves [0, 255] would otherwise produce wrapped pixels.
        array = tensor.detach().cpu().clamp(0, 255).numpy().astype(np.uint8)
        return Image.fromarray(array)
# Activation modules looked up by name. An unknown name gets (and caches) a fresh SiLU.
# The same module instance is handed out for every lookup of a given name.
str_to_act = defaultdict(nn.SiLU)
str_to_act.update(
    relu=nn.ReLU(),
    silu=nn.SiLU(),
    gelu=nn.GELU(),
)
class SinusoidalPositionalEncoding(nn.Module):
    """Transformer-style sinusoidal encoding of scalar timesteps.

    For a 1-D tensor of B timesteps returns a (B, dim) tensor laid out as
    ``[sin(t*f_0) .. sin(t*f_{k-1}), cos(t*f_0) .. cos(t*f_{k-1})]`` with
    ``f_i = 10000 ** (-2i / dim)`` and ``k = dim // 2``. ``dim`` must be even.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, t):
        half = self.dim // 2
        exponent = torch.arange(0, self.dim, 2, device=t.device).float() / self.dim
        inv_freq = 1.0 / (10000 ** exponent)
        # One row per timestep, one column per frequency; shared by the sin and cos halves.
        phase = t.unsqueeze(-1).repeat(1, half) * inv_freq
        return torch.cat([torch.sin(phase), torch.cos(phase)], dim=-1)
class TimeEmbedding(nn.Module):
    """Two-layer MLP (Linear -> activation -> Linear) lifting a time encoding to ``emb_dim``.

    Args:
        model_dim: size of the incoming time encoding.
        emb_dim: size of the produced embedding.
        act: activation name looked up in ``str_to_act``.
    """

    def __init__(self, model_dim: int, emb_dim: int, act="silu"):
        super().__init__()
        self.lin = nn.Linear(model_dim, emb_dim)
        self.act = str_to_act[act]
        self.lin2 = nn.Linear(emb_dim, emb_dim)

    def forward(self, x):
        return self.lin2(self.act(self.lin(x)))
class ConvBlock(nn.Module):
    """GroupNorm -> activation -> optional dropout -> 3x3 same-padding convolution.

    Args:
        in_channels: input channels (GroupNorm uses 32 groups, so this must be divisible by 32).
        out_channels: output channels.
        act: activation name looked up in ``str_to_act``.
        dropout: dropout probability, or None to skip dropout entirely.
        zero: zero-initialise the convolution so the block outputs exactly zero at init.
    """

    def __init__(self, in_channels, out_channels, act="silu", dropout=None, zero=False):
        super().__init__()
        self.norm = nn.GroupNorm(
            num_groups=32,
            num_channels=in_channels,
        )
        self.act = str_to_act[act]
        if dropout is not None:
            self.dropout = nn.Dropout(dropout)
        self.conv = nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=3,
            padding=1,
        )
        if zero:
            # Zero the bias too: with only the weight zeroed the block would output its
            # randomly initialised bias instead of zero.
            nn.init.zeros_(self.conv.weight)
            nn.init.zeros_(self.conv.bias)

    def forward(self, x):
        x = self.norm(x)
        x = self.act(x)
        if hasattr(self, "dropout"):
            x = self.dropout(x)
        x = self.conv(x)
        return x
class EmbeddingBlock(nn.Module):
    """Project an embedding to ``channels`` features: activation followed by a Linear layer.

    Args:
        channels: number of output features.
        emb_dim: size of the incoming embedding.
        act: activation name looked up in ``str_to_act``.
    """

    def __init__(self, channels: int, emb_dim: int, act="silu"):
        super().__init__()
        self.act = str_to_act[act]
        self.lin = nn.Linear(emb_dim, channels)

    def forward(self, x):
        return self.lin(self.act(x))
class ResBlock(nn.Module):
    """Embedding-conditioned residual block with an optional change in channel count.

    conv1 -> add projected embedding (broadcast over H and W) -> conv2 (dropout,
    zero-initialised), plus a skip path that is the identity, or a 1x1 convolution
    when ``out_channels`` differs from ``channels``.
    """

    def __init__(self, channels: int, emb_dim: int, dropout: float = 0, out_channels=None):
        super().__init__()
        out_channels = channels if out_channels is None else out_channels
        self.conv1 = ConvBlock(channels, out_channels)
        self.emb = EmbeddingBlock(out_channels, emb_dim)
        self.conv2 = ConvBlock(out_channels, out_channels, dropout=dropout, zero=True)
        self.skip_connection = (
            nn.Identity()
            if channels == out_channels
            else nn.Conv2d(channels, out_channels, kernel_size=1)
        )

    def forward(self, x, t):
        """x: (batch, channels, H, W); t: (batch, emb_dim) embedding."""
        h = self.conv1(x)
        # (batch, out_channels) -> (batch, out_channels, 1, 1) so it broadcasts over H, W.
        emb = self.emb(t)[:, :, None, None]
        h = self.conv2(h + emb)
        return h + self.skip_connection(x)
class SelfAttentionBlock(nn.Module):
    """Residual multi-head self-attention over all spatial positions of a feature map.

    GroupNorm(32) -> flatten the H*W positions into a token sequence ->
    nn.MultiheadAttention (batch_first) -> reshape back, then add the block input.
    """

    def __init__(self, channels, num_heads=1):
        super().__init__()
        self.channels = channels
        self.num_heads = num_heads
        self.norm = nn.GroupNorm(32, channels)
        self.attention = nn.MultiheadAttention(
            embed_dim=channels,
            num_heads=num_heads,
            dropout=0,
            batch_first=True,
            bias=True,
        )

    def forward(self, x):
        batch, chans, height, width = x.shape
        # (B, C, H, W) -> (B, H*W, C): one token per spatial position.
        tokens = self.norm(x).flatten(2).transpose(1, 2)
        attended, _ = self.attention(tokens, tokens, tokens)
        # (B, H*W, C) -> (B, C, H, W)
        restored = attended.transpose(1, 2).reshape(batch, chans, height, width)
        return restored + x
class Downsample(nn.Module):
    """Halve H and W (rounding up) with a stride-2 3x3 convolution; channels unchanged."""

    def __init__(self, channels):
        super().__init__()
        # Learned strided convolution rather than pooling (the iDDPM choice, per the original note).
        self.down = nn.Conv2d(channels, channels, kernel_size=3, stride=2, padding=1)

    def forward(self, x):
        return self.down(x)
class DownBlock(nn.Module):
    """One level of the U-Net's contracting path.

    Runs ``width`` ResBlocks (each followed by self-attention when ``use_attn``),
    keeps the result as the skip connection, then optionally halves the resolution.
    Loosely follows the U-Net paper's contracting path, but with padded convolutions
    inside ResBlocks and a strided-convolution downsample instead of max pooling.
    """

    def __init__(self, in_channels, out_channels, time_embedding_dim, use_attn=False, dropout=0, downsample=True, width=1):
        """in_channels will typically be half of out_channels"""
        super().__init__()
        self.width = width
        self.use_attn = use_attn
        self.do_downsample = downsample
        self.blocks = nn.ModuleList()
        channels = in_channels
        for _ in range(width):
            self.blocks.append(ResBlock(
                channels=channels,
                out_channels=out_channels,
                emb_dim=time_embedding_dim,
                dropout=dropout,
            ))
            if use_attn:
                self.blocks.append(SelfAttentionBlock(channels=out_channels))
            # Only the first ResBlock changes the channel count.
            channels = out_channels
        if downsample:
            self.downsample = Downsample(out_channels)

    def forward(self, x, t):
        """Return ``(output, skip)``; ``skip`` is captured before downsampling."""
        for layer in self.blocks:
            if isinstance(layer, ResBlock):
                x = layer(x, t)
            elif isinstance(layer, SelfAttentionBlock):
                x = layer(x)
        skip = x
        if self.do_downsample:
            x = self.downsample(x)
        return x, skip
class Upsample(nn.Module):
    """Double H and W with nn.Upsample (default nearest mode), then a 3x3 same-padding conv."""

    def __init__(self, channels):
        super().__init__()
        self.upsample = nn.Upsample(scale_factor=2)
        self.conv = nn.Conv2d(channels, channels, kernel_size=3, padding=1)

    def forward(self, x):
        return self.conv(self.upsample(x))
class UpBlock(nn.Module):
    """One level of the U-Net's expansive path.

    Expects its input to already be concatenated with the matching skip connection.
    Runs ``width`` ResBlocks (each followed by self-attention when ``use_attn``), then
    optionally doubles the resolution. Loosely follows the U-Net paper's expansive path,
    using nearest upsampling + 3x3 convolution rather than a 2x2 up-convolution.
    """

    def __init__(self, in_channels, out_channels, time_embedding_dim, use_attn=False, dropout=0, upsample=True, width=1):
        """in_channels will typically be double of out_channels
        """
        super().__init__()
        self.use_attn = use_attn
        self.do_upsample = upsample
        self.blocks = nn.ModuleList()
        channels = in_channels
        for _ in range(width):
            self.blocks.append(ResBlock(
                channels=channels,
                out_channels=out_channels,
                emb_dim=time_embedding_dim,
                dropout=dropout,
            ))
            if use_attn:
                self.blocks.append(SelfAttentionBlock(channels=out_channels))
            # Only the first ResBlock changes the channel count.
            channels = out_channels
        if upsample:
            self.upsample = Upsample(out_channels)

    def forward(self, x, t):
        for layer in self.blocks:
            if isinstance(layer, ResBlock):
                x = layer(x, t)
            elif isinstance(layer, SelfAttentionBlock):
                x = layer(x)
        return self.upsample(x) if self.do_upsample else x
class Bottleneck(nn.Module):
    """Middle of the U-Net: ResBlock -> self-attention -> ResBlock at a constant channel count."""

    def __init__(self, channels, dropout, time_embedding_dim):
        super().__init__()
        self.resblock_1 = ResBlock(
            channels=channels,
            out_channels=channels,
            dropout=dropout,
            emb_dim=time_embedding_dim,
        )
        self.attention_block = SelfAttentionBlock(channels=channels)
        self.resblock_2 = ResBlock(
            channels=channels,
            out_channels=channels,
            dropout=dropout,
            emb_dim=time_embedding_dim,
        )

    def forward(self, x, t):
        h = self.resblock_1(x, t)
        h = self.attention_block(h)
        return self.resblock_2(h, t)
class Unet(nn.Module):
    """Time-conditioned U-Net that predicts the noise added to an image.

    Args:
        image_channels: channels of the input image and of the predicted noise.
        res_block_width: ResBlocks (each optionally followed by attention) per level.
        starting_channels: base channel count; also the size of the time encoding.
        dropout: dropout probability forwarded to every ResBlock.
        channel_mults: per-level multipliers of ``starting_channels``. Every level except
            the last halves the spatial size on the way down, so the input H and W should
            be divisible by ``2 ** (len(channel_mults) - 1)`` for the skips to line up.
        attention_layers: per-level self-attention flags, same length as ``channel_mults``.
    """

    def __init__(
        self,
        image_channels=3,
        res_block_width=2,
        starting_channels=128,
        dropout=0,
        channel_mults=(1, 2, 2, 4, 4),
        attention_layers=(False, False, False, True, False)
    ):
        super().__init__()
        self.is_conditional = False
        self.image_channels = image_channels
        self.starting_channels = starting_channels
        time_embedding_dim = 4 * starting_channels
        self.time_encoding = SinusoidalPositionalEncoding(dim=starting_channels)
        self.time_embedding = TimeEmbedding(model_dim=starting_channels, emb_dim=time_embedding_dim)
        # Honour image_channels (was hard-coded to 3, breaking e.g. grayscale models).
        self.input = nn.Conv2d(image_channels, starting_channels, kernel_size=3, padding=1)

        current_channel_count = starting_channels
        # Output channels of each contracting level = channels of its skip connection.
        input_channel_counts = []
        self.contracting_path = nn.ModuleList([])
        for i, channel_multiplier in enumerate(channel_mults):
            is_last_layer = i == len(channel_mults) - 1
            next_channel_count = channel_multiplier * starting_channels
            self.contracting_path.append(DownBlock(
                in_channels=current_channel_count,
                out_channels=next_channel_count,
                time_embedding_dim=time_embedding_dim,
                use_attn=attention_layers[i],
                dropout=dropout,
                downsample=not is_last_layer,
                width=res_block_width,
            ))
            current_channel_count = next_channel_count
            input_channel_counts.append(current_channel_count)

        self.bottleneck = Bottleneck(channels=current_channel_count, time_embedding_dim=time_embedding_dim, dropout=dropout)

        self.expansive_path = nn.ModuleList([])
        reversed_attention = list(reversed(attention_layers))
        for i, channel_multiplier in enumerate(reversed(channel_mults)):
            next_channel_count = channel_multiplier * starting_channels
            self.expansive_path.append(UpBlock(
                # Input is the previous level's output concatenated with the skip.
                in_channels=current_channel_count + input_channel_counts.pop(),
                out_channels=next_channel_count,
                time_embedding_dim=time_embedding_dim,
                use_attn=reversed_attention[i],
                dropout=dropout,
                upsample=i != len(channel_mults) - 1,
                width=res_block_width,
            ))
            current_channel_count = next_channel_count

        # The last UpBlock emits channel_mults[0] * starting_channels channels; size the head
        # from current_channel_count so channel_mults[0] != 1 also works.
        last_conv = nn.Conv2d(
            in_channels=current_channel_count,
            out_channels=image_channels,
            kernel_size=3,
            padding=1,
        )
        # Zero-init weight and bias so the untrained network predicts exactly zero noise.
        nn.init.zeros_(last_conv.weight)
        nn.init.zeros_(last_conv.bias)
        self.head = nn.Sequential(
            nn.GroupNorm(32, current_channel_count),
            nn.SiLU(),
            last_conv,
        )

    def forward(self, x, t):
        """Predict noise for images ``x`` at integer timesteps ``t``."""
        t = self.time_encoding(t)
        return self._forward(x, t)

    def _forward(self, x, t):
        """Run the network given an already-encoded time vector ``t``."""
        t = self.time_embedding(t)
        x = self.input(x)
        residuals = []
        for contracting_block in self.contracting_path:
            x, residual = contracting_block(x, t)
            residuals.append(residual)
        x = self.bottleneck(x, t)
        for expansive_block in self.expansive_path:
            # Concatenate the matching skip connection along the channel axis.
            residual = residuals.pop()
            x = torch.cat([x, residual], dim=1)
            x = expansive_block(x, t)
        x = self.head(x)
        return x
class ConditionalUnet(nn.Module):
    """Class-conditional wrapper: adds a learned class embedding to a U-Net's time encoding."""

    def __init__(self, unet, num_classes):
        super().__init__()
        self.is_conditional = True
        self.unet = unet
        self.num_classes = num_classes
        # One vector per class, sized like the time encoding it is added to.
        self.class_embedding = nn.Embedding(num_classes, unet.starting_channels)

    def forward(self, x, t, cond=None):
        """
        Args:
            x: noisy images forwarded to ``unet._forward``.
            t: timesteps, encoded with ``unet.time_encoding``.
            cond: class indices, either (batch_size, n) -- n classes per sample whose
                embeddings are summed -- or (batch_size,) for one class per sample.
                None runs the model unconditionally.
        """
        t = self.unet.time_encoding(t)
        if cond is not None:
            if cond.dim() == 1:
                # (batch_size,) labels -> (batch_size, 1); otherwise sum(dim=1) below would
                # reduce over the embedding dimension instead of over the classes.
                cond = cond.unsqueeze(-1)
            cond = self.class_embedding(cond)
            # Sum across the classes to get a single vector representing the set of classes.
            cond = cond.sum(dim=1)
            t = t + cond
        return self.unet._forward(x, t)