import torch
import torch.nn as nn


class ContextEmbedding(nn.Module):
    """Map integer context labels to a learned vector of size ``embedding_dim``.

    The embedding is an ``nn.Embedding`` lookup followed by a ``Linear`` projection.
    """

    def __init__(self, vocab_size, embedding_dim):
        super().__init__()
        # Attribute names are kept unchanged so existing state_dict checkpoints still load.
        self.embedd_fc1 = nn.Embedding(vocab_size, embedding_dim)
        self.fc1 = nn.Linear(embedding_dim, embedding_dim)

    def forward(self, x):
        """Embed ``x`` (integer indices in ``[0, vocab_size)``) and project it.

        Returns a tensor of shape ``x.shape + (embedding_dim,)``.
        """
        return self.fc1(self.embedd_fc1(x))


class TimeEbedding(nn.Module):
    """Two-layer MLP (Linear -> GELU -> Linear) that embeds a timestep feature.

    The misspelled class name is kept for backward compatibility; ``TimeEmbedding``
    is defined right after this class as a correctly spelled alias.
    """

    def __init__(self, time_dim, hidden_dim, out_dim):
        super().__init__()
        self.fc1 = nn.Linear(time_dim, hidden_dim)
        self.gelu = nn.GELU()
        self.fc2 = nn.Linear(hidden_dim, out_dim)

    def forward(self, x):
        """Map ``x`` (last dim ``time_dim``) to a tensor whose last dim is ``out_dim``."""
        return self.fc2(self.gelu(self.fc1(x)))


# Correctly spelled alias; the original name stays so existing imports keep working.
TimeEmbedding = TimeEbedding


class ConvBlock(nn.Module):
    """Two conv+ReLU layers, modulated by a context embedding and a time embedding.

    The block returns ``features * context_embedding + time_embedding``. When
    ``residual`` is True, ``features`` is ``BatchNorm(conv_path + projection(x))``;
    otherwise it is the conv path alone. ``embedding_dim`` must match
    ``out_channels`` (or be 1) so the channel-wise multiply broadcasts.
    """

    def __init__(self, in_channels, out_channels, vocab_size, embedding_dim, residual=False):
        super().__init__()
        # NOTE(review): kernel size 2 with padding="same" makes PyTorch pad asymmetrically.
        self.conv1 = nn.Conv2d(in_channels, out_channels, 2, padding="same")
        self.conv2 = nn.Conv2d(out_channels, out_channels, 2, padding="same")
        # 3x3 channel projection of the input, used only on the residual path. Despite
        # the name, it keeps the spatial size (padding="same", stride 1).
        self.downsample = nn.Conv2d(in_channels, out_channels, 3, padding="same", bias=False)
        self.relu = nn.ReLU()
        # Applied only on the residual path.
        self.bn = nn.BatchNorm2d(out_channels)
        # The timestep is embedded from a single feature per sample (time_dim=1).
        self.te = TimeEbedding(1, out_channels, out_channels)
        self.context_embedding = ContextEmbedding(vocab_size, embedding_dim)
        self.__residual = residual

    def forward(self, x, t, context):
        """Apply the block.

        Args:
            x: feature map of shape (B, in_channels, H, W).
            t: timestep tensor whose last dim is 1, e.g. (B, 1). A 1-D (B,) tensor is
               also accepted and treated as (B, 1). Presumably a float tensor
               (it feeds nn.Linear) — confirm against the caller.
            context: integer context labels of shape (B,).

        Returns:
            Tensor of shape (B, out_channels, H, W).
        """
        h = self.relu(self.conv2(self.relu(self.conv1(x))))
        if t.ndim == 1:
            # A (B,) timestep would otherwise break the [:, :, None, None] indexing below.
            t = t[:, None]
        time_embedding = self.te(t)[:, :, None, None]
        context = self.context_embedding(context)[:, :, None, None]
        if self.__residual:
            # Compute the projection only when it is used. Otherwise it is wasted compute
            # and its weight never receives a gradient.
            h = self.bn(h + self.downsample(x))
        return h * context + time_embedding


class UNET(nn.Module):
    """Three-level conditional U-Net (64 -> 128 -> 256 channels, 512-channel bottleneck).

    Every ConvBlock receives the timestep ``t`` and the context labels ``context``.
    Context labels must lie in ``[0, 5)`` because every block is built with
    ``vocab_size=5``. Height and width must be divisible by 8 (three 2x max-pools
    mirrored by three 2x transposed convolutions); otherwise the skip
    concatenations fail on mismatched shapes.
    """

    def __init__(self, in_channels, out_channels, residual=False):
        super().__init__()
        # Registration order is unchanged so parameter / state_dict order is stable.
        # Encoder.
        self.convBD1 = ConvBlock(in_channels, 64, 5, 64, residual)
        self.convBD2 = ConvBlock(64, 128, 5, 128, residual)
        self.convBD3 = ConvBlock(128, 256, 5, 256, residual)
        # A fourth encoder level (256 -> 512) with a 1024-channel bottleneck used to
        # exist here; it is disabled.
        self.bottle_neck = ConvBlock(256, 512, 5, 512, residual)
        # Decoder: input channels are doubled by concatenation with the skip connection.
        self.convBU2 = ConvBlock(512, 256, 5, 256, residual)
        self.convBU3 = ConvBlock(256, 128, 5, 128, residual)
        self.convBU4 = ConvBlock(128, 64, 5, 64, residual)
        # NOTE(review): convT1 belongs to the disabled fourth level and is never used in
        # forward(). It is kept only so existing state_dict / optimizer checkpoints still
        # load with strict=True. Drop it (~2.1M parameters) once that is no longer needed.
        self.convT1 = nn.ConvTranspose2d(1024, 512, 2, 2)
        self.convT2 = nn.ConvTranspose2d(512, 256, 2, 2)
        self.convT3 = nn.ConvTranspose2d(256, 128, 2, 2)
        self.convT4 = nn.ConvTranspose2d(128, 64, 2, 2)
        self.final = nn.Conv2d(64, out_channels, 1)
        self.maxpool = nn.MaxPool2d(2)

    def forward(self, x, t, context):
        """Return an output map of shape (B, out_channels, H, W).

        Args:
            x: input of shape (B, in_channels, H, W). A single unbatched
               (in_channels, H, W) image is promoted to a batch of one.
            t: timestep tensor, passed unchanged to every ConvBlock.
            context: integer context labels, passed unchanged to every ConvBlock.
        """
        if x.ndim == 3:
            x = x.unsqueeze(0)
        # Encoder: keep each level's output for the skip connections.
        skip1 = self.convBD1(x, t, context)
        skip2 = self.convBD2(self.maxpool(skip1), t, context)
        skip3 = self.convBD3(self.maxpool(skip2), t, context)
        h = self.bottle_neck(self.maxpool(skip3), t, context)
        # Decoder: upsample, concatenate with the matching skip (skip first), then convolve.
        h = self.convBU2(torch.cat([skip3, self.convT2(h)], dim=1), t, context)
        h = self.convBU3(torch.cat([skip2, self.convT3(h)], dim=1), t, context)
        h = self.convBU4(torch.cat([skip1, self.convT4(h)], dim=1), t, context)
        return self.final(h)


import torch
import torch.nn as nn
# NOTE(review): this rebinds UNET to the class from the `model` module, shadowing the
# definition above. The text looks like a model file and a training script concatenated
# together — confirm which UNET is intended. These imports stay here rather than at the
# top of the file, because moving them would change which UNET is instantiated below.
from model import UNET
from helper import DDPM, denoise_image

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Linear beta schedule from beta1 to beta2 with timesteps + 1 entries (indices 0..timesteps).
timesteps = 500
beta1 = 1e-4
beta2 = 0.02
betas = (beta2 - beta1) * torch.linspace(0, 1, timesteps + 1) + beta1
betas = betas.to(device)
alpha = 1.0 - betas
# Cumulative product of alpha; it is already on `device` because `betas` is.
alpha_bar = torch.cumprod(alpha, dim=0)

# Arguments are kept positional: the signature of the imported `model.UNET` is not visible here.
model = UNET(3, 3, True)
model = model.to(device)
loss_fn = nn.MSELoss()
optim = torch.optim.Adam(model.parameters(), lr=1e-3)
sampler = DDPM(betas)