|
import torch |
|
from torch import nn |
|
from torchtools.nn import VectorQuantize |
|
from einops import rearrange |
|
import torch.nn.functional as F |
|
import math |
|
class ResBlock(nn.Module):
    """Residual block with a depthwise 3x3 branch followed by a per-pixel MLP branch.

    Both branches are preceded by a non-affine ``LayerNorm`` applied over the
    channel dimension. Six learnable scalars (``gammas``) provide, for each
    branch, a scale and shift applied to the normalised input and a gate on
    the branch output. They start at zero, so both gates are zero and the
    block initially returns its input unchanged.

    Args:
        c: Number of input/output channels.
        c_hidden: Width of the hidden layer of the per-pixel MLP.
    """

    def __init__(self, c, c_hidden):
        super().__init__()

        # Spatial mixing: replicate-pad by one pixel so the unpadded 3x3
        # depthwise convolution keeps the spatial size.
        self.norm1 = nn.LayerNorm(c, elementwise_affine=False, eps=1e-6)
        spatial_pad = nn.ReplicationPad2d(1)
        spatial_conv = nn.Conv2d(c, c, kernel_size=3, groups=c)
        self.depthwise = nn.Sequential(spatial_pad, spatial_conv)

        # Channel mixing: two-layer MLP applied on channels-last tensors.
        self.norm2 = nn.LayerNorm(c, elementwise_affine=False, eps=1e-6)
        expand = nn.Linear(c, c_hidden)
        project = nn.Linear(c_hidden, c)
        self.channelwise = nn.Sequential(expand, nn.GELU(), project)

        # [scale1, shift1, gate1, scale2, shift2, gate2]
        self.gammas = nn.Parameter(torch.zeros(6), requires_grad=True)

        def _reset(layer):
            # Xavier-uniform weights and zero biases for every Linear/Conv2d.
            if not isinstance(layer, (nn.Linear, nn.Conv2d)):
                return
            torch.nn.init.xavier_uniform_(layer.weight)
            if layer.bias is not None:
                nn.init.constant_(layer.bias, 0)

        self.apply(_reset)

    def _norm(self, x, norm):
        """Apply ``norm`` over dim 1 of ``x`` by moving it last and back again."""
        channels_last = x.permute(0, 2, 3, 1)
        normed = norm(channels_last)
        return normed.permute(0, 3, 1, 2)

    def forward(self, x):
        """Run both gated residual branches and return a tensor shaped like ``x``."""
        scale1, shift1, gate1, scale2, shift2, gate2 = self.gammas.unbind(0)

        # Depthwise (spatial) residual branch.
        spatial_in = self._norm(x, self.norm1) * (1 + scale1) + shift1
        x = x + self.depthwise(spatial_in) * gate1

        # Channelwise (MLP) residual branch; the MLP expects channels last.
        channel_in = self._norm(x, self.norm2) * (1 + scale2) + shift2
        mixed = self.channelwise(channel_in.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
        return x + mixed * gate2
|
|
|
|
|
def _patch_padding(size, patch_size, stride):
    """Return the trailing padding that lets ``size`` tile into whole patches.

    After padding, ``size + pad - patch_size`` is a non-negative multiple of
    ``stride``, so a sliding window of ``patch_size`` covers every position
    instead of silently dropping a remainder at the far edge.
    """
    if size <= patch_size:
        # At least one full patch has to fit.
        return patch_size - size
    return (stride - (size - patch_size) % stride) % stride


def extract_patches(tensor, patch_size, stride):
    """Split a 4-D tensor into square, possibly overlapping patches.

    The bottom and right edges are reflect-padded so that the sliding window
    covers the whole input. ``F.pad`` in reflect mode requires the padding to
    be smaller than the padded dimension, so very small inputs may raise.

    Args:
        tensor: Input of shape ``(b, c, H, W)``.
        patch_size: Side length of each square patch.
        stride: Step between neighbouring patches in both directions.

    Returns:
        A tuple ``(patches, (H, W))`` where ``patches`` has shape
        ``(b, num_patches, c, patch_size, patch_size)`` with patches ordered
        row-major over the window grid, and ``(H, W)`` is the unpadded input
        size to pass to ``fuse_patches``.
    """
    b, c, H, W = tensor.shape
    pad_h = _patch_padding(H, patch_size, stride)
    pad_w = _patch_padding(W, patch_size, stride)
    tensor = F.pad(tensor, (0, pad_w, 0, pad_h), mode='reflect')

    # (b, c, n_h, n_w, patch, patch) after the two unfolds.
    windows = tensor.unfold(2, patch_size, stride).unfold(3, patch_size, stride)
    patches = windows.contiguous().view(b, c, -1, patch_size, patch_size)
    return patches.permute(0, 2, 1, 3, 4), (H, W)


def fuse_patches(patches, patch_size, stride, H, W):
    """Reassemble patches into an image, averaging where patches overlap.

    This is the inverse of ``extract_patches`` when called with the same
    ``patch_size`` and ``stride`` and the ``(H, W)`` it returned.

    Args:
        patches: Tensor of shape ``(b, num_patches, c, patch_size, patch_size)``
            ordered row-major over the window grid.
        patch_size: Side length of each square patch.
        stride: Step between neighbouring patches.
        H: Height of the unpadded output.
        W: Width of the unpadded output.

    Returns:
        Tensor of shape ``(b, c, H, W)``.
    """
    b, num_patches, c, _, _ = patches.shape
    out_h = H + _patch_padding(H, patch_size, stride)
    out_w = W + _patch_padding(W, patch_size, stride)

    # F.fold expects columns laid out as (b, c * patch * patch, num_patches).
    columns = patches.permute(0, 2, 3, 4, 1).reshape(
        b, c * patch_size * patch_size, num_patches)

    fold_args = dict(output_size=(out_h, out_w), kernel_size=patch_size, stride=stride)
    summed = F.fold(columns, **fold_args)
    # The number of patches covering each pixel is the same for every batch
    # item and channel, so fold a single channel of ones and broadcast.
    coverage = F.fold(
        columns.new_ones(1, patch_size * patch_size, num_patches), **fold_args)

    # Drop the padding that extract_patches added on the bottom/right.
    return (summed / coverage)[:, :, :H, :W]
|
|
|
|
|
|
|
class StageA(nn.Module):
    """Convolutional autoencoder with an optional vector-quantised bottleneck.

    The encoder pixel-unshuffles the image by 2, runs one ``ResBlock`` per
    level (with a stride-2 convolution between levels) and projects to
    ``c_latent`` channels followed by batch norm. The decoder mirrors this
    with ``bottleneck_blocks`` residual blocks at the deepest level, one block
    at every other level, transposed convolutions between levels and a final
    pixel shuffle back to 3 channels.

    Args:
        levels: Number of resolution levels in encoder and decoder.
        bottleneck_blocks: Number of ``ResBlock`` s at the deepest decoder level.
        c_hidden: Channel width at the deepest level; each shallower level
            halves it (integer division).
        c_latent: Number of latent channels.
        codebook_size: Number of codes passed to ``VectorQuantize`` as ``k``.
        scale_factor: Latents are divided by this on encode and multiplied by
            it on decode.
    """

    def __init__(self, levels=2, bottleneck_blocks=12, c_hidden=384, c_latent=4, codebook_size=8192,
                 scale_factor=0.43):
        super().__init__()
        self.c_latent = c_latent
        self.scale_factor = scale_factor
        # Channel widths ordered from the shallowest level to the deepest.
        widths = [c_hidden // (2 ** (levels - 1 - lvl)) for lvl in range(levels)]

        # ---- Encoder ----
        self.in_block = nn.Sequential(
            nn.PixelUnshuffle(2),
            nn.Conv2d(3 * 4, widths[0], kernel_size=1),
        )
        encoder = []
        for lvl, width in enumerate(widths):
            if lvl > 0:
                # Halve the resolution when moving to the next level.
                encoder.append(nn.Conv2d(widths[lvl - 1], width, kernel_size=4, stride=2, padding=1))
            encoder.append(ResBlock(width, width * 4))
        encoder.append(nn.Sequential(
            nn.Conv2d(widths[-1], c_latent, kernel_size=1, bias=False),
            nn.BatchNorm2d(c_latent),
        ))
        self.down_blocks = nn.Sequential(*encoder)

        # ---- Quantiser ----
        self.codebook_size = codebook_size
        self.vquantizer = VectorQuantize(c_latent, k=codebook_size)

        # ---- Decoder ----
        decoder = [nn.Sequential(nn.Conv2d(c_latent, widths[-1], kernel_size=1))]
        for step in range(levels):
            lvl = levels - 1 - step
            width = widths[lvl]
            repeats = bottleneck_blocks if step == 0 else 1
            for _ in range(repeats):
                decoder.append(ResBlock(width, width * 4))
            if lvl > 0:
                # Double the resolution when moving to the shallower level.
                decoder.append(
                    nn.ConvTranspose2d(width, widths[lvl - 1], kernel_size=4, stride=2, padding=1))
        self.up_blocks = nn.Sequential(*decoder)
        self.out_block = nn.Sequential(
            nn.Conv2d(widths[0], 3 * 4, kernel_size=1),
            nn.PixelShuffle(2),
        )

    def encode(self, x, quantize=False):
        """Encode ``x`` into scaled latents.

        Returns:
            Without quantisation: ``(latent / scale_factor, None, None, None)``.
            With quantisation: ``(quantised / scale_factor,
            latent / scale_factor, indices, vq_loss + 0.25 * commit_loss)``,
            where the quantised tensor, losses and indices are whatever
            ``self.vquantizer`` returns.
        """
        latent = self.down_blocks(self.in_block(x))
        if not quantize:
            return latent / self.scale_factor, None, None, None

        qe, (vq_loss, commit_loss), indices = self.vquantizer.forward(latent, dim=1)
        return (
            qe / self.scale_factor,
            latent / self.scale_factor,
            indices,
            vq_loss + commit_loss * 0.25,
        )

    def decode(self, x, tiled_decoding=False):
        """Decode scaled latents back to image space.

        ``tiled_decoding`` is accepted for interface compatibility but is not
        used by this implementation.
        """
        hidden = self.up_blocks(x * self.scale_factor)
        return self.out_block(hidden)

    def forward(self, x, quantize=False):
        """Encode then decode ``x``; return ``(reconstruction, vq_loss)``.

        ``vq_loss`` is ``None`` when ``quantize`` is false.
        """
        latent, _, _, vq_loss = self.encode(x, quantize)
        return self.decode(latent), vq_loss
|
|
|
|
|
class Discriminator(nn.Module):
    """Strided convolutional discriminator producing a per-location probability map.

    The encoder is a stack of ``depth`` spectrally-normalised 3x3 stride-2
    convolutions. Channel width starts at ``c_hidden // 2**d`` with
    ``d = max(depth - 3, 3)`` and doubles per layer until it reaches
    ``c_hidden``. A 1x1 convolution followed by a sigmoid maps the features
    (optionally concatenated with a conditioning vector) to one channel.

    Args:
        c_in: Number of input image channels.
        c_cond: Number of conditioning channels; ``0`` disables conditioning.
        c_hidden: Maximum channel width of the encoder.
        depth: Number of stride-2 convolutions in the encoder.
    """

    def __init__(self, c_in=3, c_cond=0, c_hidden=512, depth=6):
        super().__init__()
        d = max(depth - 3, 3)
        width = c_hidden // (2 ** d)
        layers = [
            nn.utils.spectral_norm(nn.Conv2d(c_in, width, kernel_size=3, stride=2, padding=1)),
            nn.LeakyReLU(0.2),
        ]
        for i in range(depth - 1):
            next_width = c_hidden // (2 ** max((d - 1 - i), 0))
            layers.append(
                nn.utils.spectral_norm(nn.Conv2d(width, next_width, kernel_size=3, stride=2, padding=1)))
            layers.append(nn.InstanceNorm2d(next_width))
            layers.append(nn.LeakyReLU(0.2))
            width = next_width
        self.encoder = nn.Sequential(*layers)

        # Size the head from the encoder's real output width. It equals
        # c_hidden whenever depth >= 4, but is smaller for shallower encoders,
        # where a hard-coded c_hidden would make forward() fail.
        head_in = (width + c_cond) if c_cond > 0 else width
        self.shuffle = nn.Conv2d(head_in, 1, kernel_size=1)
        self.logits = nn.Sigmoid()

    def forward(self, x, cond=None):
        """Return sigmoid scores of shape ``(batch, 1, h, w)`` for input ``x``.

        Args:
            x: Input images with ``c_in`` channels.
            cond: Optional conditioning tensor. It is viewed as
                ``(cond.size(0), cond.size(1), 1, 1)`` and broadcast over the
                spatial grid, so it is presumably ``(batch, c_cond)``; it must
                be supplied when the module was built with ``c_cond > 0``.
        """
        x = self.encoder(x)
        if cond is not None:
            # Tile the conditioning vector over every spatial location.
            cond = cond.view(cond.size(0), cond.size(1), 1, 1).expand(-1, -1, x.size(-2), x.size(-1))
            x = torch.cat([x, cond], dim=1)
        x = self.shuffle(x)
        x = self.logits(x)
        return x
|
|