import os

import torch
import torch.nn as nn
import torch.nn.functional as F
import einops
from einops.layers.torch import Rearrange

def normalize(in_channels):
    return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)


def swish(x):
    # x * sigmoid(x); equivalent to F.silu(x)
    return x * torch.sigmoid(x)

class ResBlock(nn.Module):
    def __init__(self, in_channels, out_channels=None, activation_fn="relu"):
        super(ResBlock, self).__init__()
        self.in_channels = in_channels
        self.out_channels = in_channels if out_channels is None else out_channels
        self.norm1 = normalize(in_channels)
        self.conv1 = nn.Conv2d(in_channels, self.out_channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.norm2 = normalize(self.out_channels)
        self.conv2 = nn.Conv2d(self.out_channels, self.out_channels, kernel_size=3, stride=1, padding=1, bias=False)
        if self.in_channels != self.out_channels:
            # 1x1 projection so the skip connection matches the output channel count
            self.conv_out = nn.Conv2d(in_channels, self.out_channels, kernel_size=1, stride=1, padding=0, bias=False)
        self.activation_fn = activation_fn
        if activation_fn == "relu":
            self.actn = nn.ReLU()

    def forward(self, x_in):
        x = x_in
        x = self.norm1(x)
        if self.activation_fn == "relu":
            x = self.actn(x)
        elif self.activation_fn == "swish":
            x = swish(x)
        x = self.conv1(x)
        x = self.norm2(x)
        if self.activation_fn == "relu":
            x = self.actn(x)
        elif self.activation_fn == "swish":
            x = swish(x)
        x = self.conv2(x)
        if self.in_channels != self.out_channels:
            x_in = self.conv_out(x_in)
        return x + x_in
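
# Illustrative sketch (the _demo_* helpers below are additions, not part of the
# original checkpoint code): ResBlock preserves spatial size, and the 1x1 conv_out
# projection lets the residual addition work when channel counts differ.
def _demo_resblock():
    block = ResBlock(128, 256, activation_fn="swish")
    y = block(torch.randn(1, 128, 16, 16))
    assert y.shape == (1, 256, 16, 16)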

class Encoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.filters = 128
        self.num_res_blocks = 2
        self.ch_mult = [1, 1, 2, 2, 4]
        self.in_ch_mult = (1,) + tuple(self.ch_mult)
        self.embedding_dim = 32
        self.conv_downsample = False
        self.conv1 = nn.Conv2d(3, 128, kernel_size=3, stride=1, padding=1, bias=False)
        blocks = []
        for i in range(len(self.ch_mult)):
            block_in_ch = self.filters * self.in_ch_mult[i]
            block_out_ch = self.filters * self.ch_mult[i]
            for _ in range(self.num_res_blocks):
                blocks.append(ResBlock(block_in_ch, block_out_ch, activation_fn="swish"))
                block_in_ch = block_out_ch
        # two extra res blocks at the bottleneck resolution
        for _ in range(self.num_res_blocks):
            blocks.append(ResBlock(block_in_ch, block_out_ch, activation_fn="swish"))
        self.norm1 = normalize(block_in_ch)
        self.conv2 = nn.Conv2d(block_in_ch, self.embedding_dim, kernel_size=1, stride=1, padding=0)
        self.blocks = nn.ModuleList(blocks)

    def forward(self, x):
        x = self.conv1(x)
        for i in range(len(self.ch_mult)):
            for j in range(self.num_res_blocks):
                # flat indexing assumes num_res_blocks == 2
                x = self.blocks[i * 2 + j](x)
            if i < len(self.ch_mult) - 1:
                x = F.avg_pool2d(x, (2, 2), (2, 2))
        x = self.blocks[-2](x)
        x = self.blocks[-1](x)
        x = self.norm1(x)
        x = swish(x)
        x = self.conv2(x)
        return x
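
# Illustrative sketch: with five ch_mult stages the encoder average-pools four
# times, shrinking spatial size by 2**4 = 16 (e.g. 224x224 -> 14x14) and projecting
# to embedding_dim = 32 channels.
def _demo_encoder():
    enc = Encoder()
    z = enc(torch.randn(1, 3, 224, 224))
    assert z.shape == (1, 32, 14, 14)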

class VectorQuantizer(nn.Module):
    def __init__(self, codebook_size=8192, emb_dim=32, beta=None):
        super(VectorQuantizer, self).__init__()
        self.codebook_size = codebook_size  # number of embeddings
        self.emb_dim = emb_dim  # dimension of each embedding
        self.embedding = nn.Embedding(self.codebook_size, self.emb_dim)
        self.embedding.weight.data.uniform_(-1.0 / self.codebook_size, 1.0 / self.codebook_size)
        self.beta = 0.0 if beta is None else beta  # commitment-loss weight
        self.z_dim = emb_dim

    def forward(self, z):
        # preprocess: flatten (b, c, h, w) -> (b*h*w, c)
        b, c, h, w = z.size()
        flatten = z.permute(0, 2, 3, 1).reshape(-1, c)
        codebook = self.embedding.weight
        with torch.no_grad():
            # nearest codebook entry (Euclidean distance) for every latent vector
            tokens = torch.cdist(flatten, codebook).argmin(dim=1)
        quantized = F.embedding(tokens, codebook).view(b, h, w, c).permute(0, 3, 1, 2)
        # compute loss
        codebook_loss = F.mse_loss(quantized, z.detach())
        commitment_loss = F.mse_loss(quantized.detach(), z)
        loss = codebook_loss + self.beta * commitment_loss
        # perplexity of codebook usage (high = many codes used uniformly)
        counts = F.one_hot(tokens, self.codebook_size).sum(dim=0).to(z.dtype)
        # dist.all_reduce(counts)
        p = counts / counts.sum()
        perplexity = torch.exp(-torch.sum(p * torch.log(p + 1e-10)))
        # postprocess
        tokens = tokens.view(b, h, w)
        # straight-through estimator: forward uses the quantized values,
        # backward passes gradients to z unchanged
        quantized = z + (quantized - z).detach()
        return quantized, tokens, loss, perplexity

    def get_codebook_feat(self, indices, shape=None):
        # input indices: batch*token_num -> (batch*token_num)*1
        # shape: batch, height, width, channel
        indices = indices.view(-1, 1)
        min_encodings = torch.zeros(indices.shape[0], self.codebook_size).to(indices)
        min_encodings.scatter_(1, indices, 1)
        # get quantized latent vectors
        z_q = torch.matmul(min_encodings.float(), self.embedding.weight)
        if shape is not None:  # reshape back to match original input shape
            z_q = z_q.view(shape).permute(0, 3, 1, 2).contiguous()
        return z_q
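
# Illustrative sketch: the quantizer snaps each 32-dim latent to its nearest
# codebook entry, and get_codebook_feat reconstructs the same quantized tensor from
# the token indices alone, which is what VQVAE.decode relies on.
def _demo_quantizer():
    vq = VectorQuantizer()
    z = torch.randn(2, 32, 14, 14)
    quant, tokens, loss, perplexity = vq(z)
    assert quant.shape == z.shape and tokens.shape == (2, 14, 14)
    requant = vq.get_codebook_feat(tokens.view(2, -1), (2, 14, 14, 32))
    assert torch.allclose(requant, quant, atol=1e-5)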

class Decoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.filters = 128
        self.num_res_blocks = 2
        self.ch_mult = [1, 1, 2, 2, 4]
        self.in_ch_mult = (1,) + tuple(self.ch_mult)
        self.embedding_dim = 32
        self.out_channels = 3
        self.in_channels = self.embedding_dim
        self.conv_downsample = False
        self.conv1 = nn.Conv2d(32, 512, kernel_size=3, stride=1, padding=1)
        blocks = []
        block_in_ch = self.filters * self.ch_mult[-1]
        block_out_ch = self.filters * self.ch_mult[-1]
        for _ in range(self.num_res_blocks):
            blocks.append(ResBlock(block_in_ch, block_out_ch, activation_fn="swish"))
        upsample_conv_layers = []
        for i in reversed(range(len(self.ch_mult))):
            block_out_ch = self.filters * self.ch_mult[i]
            for _ in range(self.num_res_blocks):
                blocks.append(ResBlock(block_in_ch, block_out_ch, activation_fn="swish"))
                block_in_ch = block_out_ch
            if i > 0:
                # 4x channels feed the 2x2 depth-to-space rearrange in forward
                upsample_conv_layers.append(nn.Conv2d(block_in_ch, block_out_ch * 4, kernel_size=3, stride=1, padding=1))
        self.upsample = Rearrange("b h w (h2 w2 c) -> b (h h2) (w w2) c", h2=2, w2=2)
        self.norm1 = normalize(block_in_ch)
        self.conv6 = nn.Conv2d(block_in_ch, self.out_channels, kernel_size=3, stride=1, padding=1)
        self.blocks = nn.ModuleList(blocks)
        self.up_convs = nn.ModuleList(upsample_conv_layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.blocks[0](x)
        x = self.blocks[1](x)
        for i in range(len(self.ch_mult)):
            for j in range(self.num_res_blocks):
                x = self.blocks[2 + i * 2 + j](x)
            if i < len(self.ch_mult) - 1:
                x = self.up_convs[i](x)
                # depth-to-space upsample: (b, 4c, h, w) -> (b, c, 2h, 2w)
                x = x.permute(0, 2, 3, 1)
                x = self.upsample(x)
                x = x.permute(0, 3, 1, 2)
        x = self.norm1(x)
        x = swish(x)
        x = self.conv6(x)
        return x
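
# Illustrative sketch: the decoder mirrors the encoder, upsampling 2x at four
# stages via a conv that quadruples channels followed by depth-to-space, so a
# 14x14x32 latent grid maps back to a 224x224 RGB image.
def _demo_decoder():
    dec = Decoder()
    img = dec(torch.randn(1, 32, 14, 14))
    assert img.shape == (1, 3, 224, 224)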

class VQVAE(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = Encoder()
        self.quantizer = VectorQuantizer()
        self.decoder = Decoder()

    def forward(self, x):
        x = self.encoder(x)
        quant, tokens, loss, perplexity = self.quantizer(x)
        x = self.decoder(quant)
        return x

    def tokenize(self, x):
        # collapse any leading batch dims into one, encode, then restore them
        batch_shape = x.shape[:-3]
        x = x.reshape(-1, *x.shape[-3:])
        x = self.encoder(x)
        quant, tokens, loss, perplexity = self.quantizer(x)
        return tokens.reshape(*batch_shape, *tokens.shape[1:])

    def decode(self, tokens):
        tokens = einops.rearrange(tokens, 'b ... -> b (...)')
        b = tokens.shape[0]
        # infer the square token-grid side from the sequence length
        if tokens.shape[-1] == 256:
            hw = 16
        elif tokens.shape[-1] == 196:  # 14 * 14
            hw = 14
        else:
            raise ValueError("Invalid tokens shape")
        quant = self.quantizer.get_codebook_feat(tokens, (b, hw, hw, 32))
        x = self.decoder(quant)
        return x
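
# Illustrative round-trip sketch (with randomly initialized weights the output is
# noise; load the checkpoint via get_tokenizer for real reconstructions):
def _demo_vqvae_roundtrip():
    vqvae = VQVAE()
    imgs = torch.randn(2, 3, 224, 224)
    tokens = vqvae.tokenize(imgs)  # (2, 14, 14) indices in [0, 8192)
    recon = vqvae.decode(tokens)   # (2, 3, 224, 224)
    assert recon.shape == imgs.shape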

class VAEDecoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.quantizer = VectorQuantizer()
        self.decoder = Decoder()

    def forward(self, x):
        # expects a single flat sequence of 196 token indices (a 14x14 grid)
        quant = self.quantizer.get_codebook_feat(x, (1, 14, 14, 32))
        x = self.decoder(quant)
        return x

def get_tokenizer():
    # expects xh_ckpt.pth next to this file
    checkpoint_path = os.path.join(
        os.path.dirname(os.path.realpath(__file__)), "xh_ckpt.pth"
    )
    torch_state_dict = torch.load(checkpoint_path, map_location="cpu")
    net = VQVAE()
    net.load_state_dict(torch_state_dict)
    return net
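
# Illustrative usage sketch; requires the xh_ckpt.pth checkpoint to be present.
if __name__ == "__main__":
    tokenizer = get_tokenizer().eval()
    with torch.no_grad():
        tokens = tokenizer.tokenize(torch.randn(1, 3, 224, 224))
        recon = tokenizer.decode(tokens)
    print(tokens.shape, recon.shape)  # torch.Size([1, 14, 14]) torch.Size([1, 3, 224, 224])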