Spaces:
Sleeping
Sleeping
import torch | |
import torch.nn as nn | |
class ContextEmbedding(nn.Module):
    """Embed integer context ids: an ``nn.Embedding`` lookup followed by a Linear layer.

    Args:
        vocab_size: Number of distinct context ids (valid ids are ``0 .. vocab_size - 1``).
        embedding_dim: Width of both the embedding table and the Linear output.
    """

    def __init__(self, vocab_size, embedding_dim):
        super().__init__()
        # Attribute names are part of the state_dict layout; keep them stable.
        self.embedd_fc1 = nn.Embedding(vocab_size, embedding_dim)
        self.fc1 = nn.Linear(embedding_dim, embedding_dim)

    def forward(self, x):
        """Return ``fc1(embedding(x))``; the output shape is ``x.shape + (embedding_dim,)``."""
        looked_up = self.embedd_fc1(x)
        return self.fc1(looked_up)
class TimeEbedding(nn.Module):
    """Two-layer MLP (Linear -> GELU -> Linear) used to embed a timestep tensor.

    The class name keeps its original spelling because ConvBlock refers to it by that name.

    Args:
        time_dim: Size of the last dimension of the input.
        hidden_dim: Width of the hidden layer.
        out_dim: Size of the last dimension of the output.
    """

    def __init__(self, time_dim, hidden_dim, out_dim):
        super().__init__()
        self.fc1 = nn.Linear(time_dim, hidden_dim)
        self.gelu = nn.GELU()
        self.fc2 = nn.Linear(hidden_dim, out_dim)

    def forward(self, x):
        """Map ``x`` (last dim ``time_dim``) to a tensor whose last dim is ``out_dim``."""
        return self.fc2(self.gelu(self.fc1(x)))
class ConvBlock(nn.Module):
    """Two 2x2 convolutions whose output is modulated by a context id and shifted by a timestep.

    Output: ``features * context_embedding + time_embedding``. Both embeddings are
    broadcast over the spatial dimensions.

    Args:
        in_channels: Channels of the input feature map.
        out_channels: Channels produced by the block.
        vocab_size: Number of distinct context ids accepted by the embedding table.
        embedding_dim: Width of the context embedding. It is multiplied channel-wise
            with the conv output, so it must equal ``out_channels`` (a width of 1
            would broadcast instead).
        residual: If True, add a 3x3 projection of the input to the conv output and
            batch-normalise the sum before conditioning.
    """

    def __init__(self, in_channels, out_channels, vocab_size, embedding_dim, residual=False):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, 2, padding="same")
        self.conv2 = nn.Conv2d(out_channels, out_channels, 2, padding="same")
        # Despite the name this is stride 1 with "same" padding: it only projects the
        # input to out_channels for the residual sum. It is always registered, so the
        # state_dict layout does not depend on `residual`.
        self.downsample = nn.Conv2d(in_channels, out_channels, 3, padding="same", bias=False)
        self.relu = nn.ReLU()
        self.bn = nn.BatchNorm2d(out_channels)
        self.te = TimeEbedding(1, out_channels, out_channels)
        self.context_embedding = ContextEmbedding(vocab_size, embedding_dim)
        self.__residual = residual

    def forward(self, x, t, context):
        """Apply the block.

        Args:
            x: Feature map with ``in_channels`` channels, (B, in_channels, H, W) as
                passed by UNET.
            t: Timestep tensor. It goes through ``Linear(1, out_channels)`` and is then
                indexed ``[:, :, None, None]``, so it must be 2-D with last dim 1, i.e. (B, 1).
            context: Integer ids in ``[0, vocab_size)``. The embedding is indexed
                ``[:, :, None, None]``, so this must be 1-D, i.e. (B,).

        Returns:
            Tensor of shape (B, out_channels, H, W). "same" padding keeps H and W.
        """
        h = self.relu(self.conv1(x))
        h = self.relu(self.conv2(h))
        time_embedding = self.te(t)[:, :, None, None]
        context = self.context_embedding(context)[:, :, None, None]
        if self.__residual:
            # The projection is only needed here, so it is not computed (and then
            # discarded) when residual=False.
            return self.bn(h + self.downsample(x)) * context + time_embedding
        return h * context + time_embedding
class UNET(nn.Module):
    """Three-level conditional U-Net built from ConvBlock stages.

    Encoder widths are 64 -> 128 -> 256 with a 512-channel bottleneck. Each encoder
    output is kept as a skip connection and concatenated with the upsampled decoder
    input at the matching resolution.

    Args:
        in_channels: Channels of the input image.
        out_channels: Channels produced by the final 1x1 convolution.
        residual: Forwarded to every ConvBlock.
    """

    def __init__(self, in_channels, out_channels, residual=False):
        super().__init__()

        def stage(cin, cout):
            # Every stage uses a context vocabulary of 5 ids and a context-embedding
            # width equal to its output channels.
            return ConvBlock(cin, cout, 5, cout, residual)

        # Registration order is unchanged. It fixes parameter order (optimizer
        # state) and the sequence of random initialisations.
        self.convBD1 = stage(in_channels, 64)
        self.convBD2 = stage(64, 128)
        self.convBD3 = stage(128, 256)
        self.bottle_neck = stage(256, 512)
        self.convBU2 = stage(512, 256)
        self.convBU3 = stage(256, 128)
        self.convBU4 = stage(128, 64)
        # NOTE: convT1 belongs to a deeper (4-level) variant and is never used in
        # forward. It stays registered because removing it would change the
        # state_dict keys.
        self.convT1 = nn.ConvTranspose2d(1024, 512, 2, 2)
        self.convT2 = nn.ConvTranspose2d(512, 256, 2, 2)
        self.convT3 = nn.ConvTranspose2d(256, 128, 2, 2)
        self.convT4 = nn.ConvTranspose2d(128, 64, 2, 2)
        self.final = nn.Conv2d(64, out_channels, 1)
        self.maxpool = nn.MaxPool2d(2)

    def forward(self, x, t, context):
        """Run the U-Net.

        Args:
            x: Images of shape (B, C, H, W), or (C, H, W), which gets a batch dim of 1.
                H and W must be divisible by 8 (three 2x max-pools, then three 2x
                transposed convs); otherwise the skip concatenation fails.
            t: Timestep tensor forwarded unchanged to every ConvBlock. Each block feeds
                it to ``Linear(1, ...)``, so its last dim must be 1. It is not
                unsqueezed when x is unbatched.
            context: Integer context ids in ``[0, 5)``, forwarded to every ConvBlock.

        Returns:
            Tensor of shape (B, out_channels, H, W). An unbatched x still yields a
            batched (batch size 1) output.
        """
        if x.ndim == 3:
            x = x.unsqueeze(0)

        skips = []
        for encoder in (self.convBD1, self.convBD2, self.convBD3):
            x = encoder(x, t, context)
            skips.append(x)
            x = self.maxpool(x)

        x = self.bottle_neck(x, t, context)

        decoder_stages = (
            (self.convT2, self.convBU2),
            (self.convT3, self.convBU3),
            (self.convT4, self.convBU4),
        )
        # The deepest skip pairs with the first decoder stage.
        for (upsample, decoder), skip in zip(decoder_stages, reversed(skips)):
            x = torch.cat([skip, upsample(x)], dim=1)
            x = decoder(x, t, context)

        return self.final(x)
import torch | |
import torch.nn as nn | |
from model import UNET | |
from helper import DDPM, denoise_image | |
# Training setup: device, DDPM noise schedule, model, loss and optimizer.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Linear beta schedule with timesteps + 1 entries running from beta1 to beta2.
# Computed on the CPU first, then moved to the target device.
timesteps = 500
beta1 = 1e-4
beta2 = 0.02
betas = ((beta2 - beta1) * torch.linspace(0, 1, timesteps + 1) + beta1).to(device)
alpha = 1.0 - betas
# Running product of alpha; it already lives on `device` because alpha does.
alpha_bar = torch.cumprod(alpha, dim=0)

model = UNET(3, 3, residual=True).to(device)  # nn.Module.to returns the module itself
loss_fn = nn.MSELoss()
optim = torch.optim.Adam(model.parameters(), lr=1e-3)
sampler = DDPM(betas)