sprite-generator / model.py
basil-ahmad's picture
Upload 3 files
53ef34c verified
raw
history blame
4.55 kB
import torch
import torch.nn as nn
class ContextEmbedding(nn.Module):
    """Map integer context labels to a learned feature vector.

    A lookup table (``nn.Embedding``) is followed by a square linear layer,
    so the output's last axis has size ``embedding_dim``.

    Args:
        vocab_size: Number of distinct context labels.
        embedding_dim: Width of both the lookup vectors and the linear output.
    """

    def __init__(self, vocab_size, embedding_dim):
        super().__init__()
        # Attribute names (and their registration order) are kept unchanged
        # so state_dict keys and parameter initialisation stay the same.
        self.embedd_fc1 = nn.Embedding(vocab_size, embedding_dim)
        self.fc1 = nn.Linear(embedding_dim, embedding_dim)

    def forward(self, x):
        """Look up the label indices in ``x`` and project them through ``fc1``."""
        return self.fc1(self.embedd_fc1(x))
class TimeEbedding(nn.Module):
    """Two-layer MLP (Linear -> GELU -> Linear) applied to a timestep tensor.

    The class name keeps its original spelling because it is referenced by
    that name elsewhere (e.g. by ConvBlock).

    Args:
        time_dim: Size of the input's last axis.
        hidden_dim: Width of the hidden layer.
        out_dim: Size of the output's last axis.
    """

    def __init__(self, time_dim, hidden_dim, out_dim):
        super().__init__()
        self.fc1 = nn.Linear(time_dim, hidden_dim)
        self.gelu = nn.GELU()
        self.fc2 = nn.Linear(hidden_dim, out_dim)

    def forward(self, x):
        """Apply ``fc1``, then GELU, then ``fc2`` along the last axis of ``x``."""
        hidden = self.gelu(self.fc1(x))
        return self.fc2(hidden)
class ConvBlock(nn.Module):
    """Two 2x2 convolutions whose output is conditioned on time and context.

    The feature map produced by ``conv1 -> ReLU -> conv2 -> ReLU`` is scaled
    channel-wise by a context embedding and shifted channel-wise by a time
    embedding. When ``residual`` is true, a 3x3 projection of the input is
    added first and the sum is batch-normalised.

    Args:
        in_channels: Channels of the input feature map.
        out_channels: Channels of the output feature map.
        vocab_size: Number of distinct context labels.
        embedding_dim: Width of the context embedding. It is multiplied
            against the ``out_channels`` axis, so it must match
            ``out_channels`` for the product to broadcast.
        residual: Whether to add the projected input and apply batch norm.
    """

    def __init__(self, in_channels, out_channels, vocab_size, embedding_dim, residual=False):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, 2, padding="same")
        self.conv2 = nn.Conv2d(out_channels, out_channels, 2, padding="same")
        # Despite its name this does not reduce spatial size (stride 1, "same"
        # padding); it only projects in_channels -> out_channels for the
        # residual sum. The name is kept so state_dict keys do not change.
        self.downsample = nn.Conv2d(in_channels, out_channels, 3, padding="same", bias=False)
        self.relu = nn.ReLU()
        self.bn = nn.BatchNorm2d(out_channels)
        # Time input has a last axis of size 1 (a scalar timestep per sample).
        self.te = TimeEbedding(1, out_channels, out_channels)
        self.context_embedding = ContextEmbedding(vocab_size, embedding_dim)
        self.__residual = residual

    def forward(self, x, t, context):
        """Run the block.

        Args:
            x: Input feature map with ``in_channels`` channels.
            t: Float timestep tensor for ``self.te``. Its last axis must have
                size 1, and the embedding is indexed as 2-D below, so it is
                expected to be (B, 1).
            context: Integer label indices for ``self.context_embedding``.
                The embedding is indexed as 2-D below, so this is 1-D (B,).

        Returns:
            Feature map with ``out_channels`` channels and the same spatial
            size as ``x`` (every convolution uses "same" padding).
        """
        out = self.relu(self.conv1(x))
        out = self.relu(self.conv2(out))
        # (B, C) -> (B, C, 1, 1) so the embeddings broadcast over H and W.
        time_embedding = self.te(t)[:, :, None, None]
        context_scale = self.context_embedding(context)[:, :, None, None]
        if self.__residual:
            # The input projection is only needed on the residual path, so it
            # is computed here rather than unconditionally.
            out = self.bn(out + self.downsample(x))
        return out * context_scale + time_embedding
class UNET(nn.Module):
    """Three-level U-Net whose blocks are conditioned on a timestep and a label.

    Encoder: ConvBlocks in -> 64 -> 128 -> 256, each followed by 2x2
    max-pooling, then a 256 -> 512 bottleneck. Decoder: each stage upsamples
    with a stride-2 transposed convolution, concatenates the matching encoder
    output (skip connection) and applies a ConvBlock. A final 1x1 convolution
    maps 64 channels to ``out_channels``.

    Because the input is max-pooled three times by 2, its height and width
    must be divisible by 8 for the skip connections to line up.

    Args:
        in_channels: Channels of the input image.
        out_channels: Channels of the predicted output.
        residual: Forwarded to every ConvBlock.
        vocab_size: Number of distinct context labels. Defaults to 5, the
            value that was previously hard-coded.
    """

    def __init__(self, in_channels, out_channels, residual=False, vocab_size=5):
        super().__init__()
        # Module registration order is unchanged so state_dict keys and
        # parameter initialisation order match the original model.
        self.convBD1 = ConvBlock(in_channels, 64, vocab_size, 64, residual)
        self.convBD2 = ConvBlock(64, 128, vocab_size, 128, residual)
        self.convBD3 = ConvBlock(128, 256, vocab_size, 256, residual)
        self.bottle_neck = ConvBlock(256, 512, vocab_size, 512, residual)
        self.convBU2 = ConvBlock(512, 256, vocab_size, 256, residual)
        self.convBU3 = ConvBlock(256, 128, vocab_size, 128, residual)
        self.convBU4 = ConvBlock(128, 64, vocab_size, 64, residual)
        # NOTE: convT1 is never used in forward(); it belonged to a fourth
        # (512 -> 1024) level that was removed. It is kept so existing
        # state_dict keys (e.g. in saved checkpoints) still load with strict=True.
        self.convT1 = nn.ConvTranspose2d(1024, 512, 2, 2)
        self.convT2 = nn.ConvTranspose2d(512, 256, 2, 2)
        self.convT3 = nn.ConvTranspose2d(256, 128, 2, 2)
        self.convT4 = nn.ConvTranspose2d(128, 64, 2, 2)
        self.final = nn.Conv2d(64, out_channels, 1)
        self.maxpool = nn.MaxPool2d(2)

    def forward(self, x, t, context):
        """Predict an ``out_channels`` map from ``x``, timestep ``t`` and ``context``.

        A 3-D ``x`` (a single unbatched image) gets a leading batch axis;
        ``t`` and ``context`` are passed through unchanged, and the returned
        tensor keeps the added batch axis.
        """
        if x.ndim == 3:
            x = x.unsqueeze(0)

        # Encoder: keep each stage's output for the skip connections.
        x1 = self.convBD1(x, t, context)
        x = self.maxpool(x1)
        x2 = self.convBD2(x, t, context)
        x = self.maxpool(x2)
        x3 = self.convBD3(x, t, context)
        x = self.maxpool(x3)

        x = self.bottle_neck(x, t, context)

        # Decoder: upsample, concatenate the skip (skip first), then convolve.
        x = self.convT2(x)
        x = torch.cat([x3, x], dim=1)
        x = self.convBU2(x, t, context)
        x = self.convT3(x)
        x = torch.cat([x2, x], dim=1)
        x = self.convBU3(x, t, context)
        x = self.convT4(x)
        x = torch.cat([x1, x], dim=1)
        x = self.convBU4(x, t, context)

        return self.final(x)
import torch
import torch.nn as nn
from model import UNET
from helper import DDPM, denoise_image
# --- Training setup ---------------------------------------------------------
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Linear noise schedule: timesteps + 1 betas evenly spaced from beta1 to beta2.
timesteps = 500
beta1 = 1e-4
beta2 = 0.02
betas = ((beta2 - beta1) * torch.linspace(0, 1, timesteps + 1) + beta1).to(device)
alpha = 1.0 - betas
# Running product of alpha along the timestep axis.
alpha_bar = alpha.cumprod(dim=0).to(device)

# 3-channel in, 3-channel out model with residual ConvBlocks.
model = UNET(3, 3, residual=True).to(device)
loss_fn = nn.MSELoss()
optim = torch.optim.Adam(model.parameters(), lr=1e-3)
sampler = DDPM(betas)