# MonsterForge-small / LightweightGAN.py
# Last commit 9a1cd80 by michaelriedl: "Moved GAN to single file"
import math
import torch
import torch.nn.functional as F
from math import log2
from torch import nn, einsum
from kornia.filters import filter2d
from einops import reduce, rearrange, repeat
def exists(val):
    """Return True when ``val`` is anything other than ``None``."""
    if val is None:
        return False
    return True
def is_power_of_two(val):
    """Return True if ``val`` is an exact power of two.

    Args:
        val: The number to test (typically an image size).

    Returns:
        bool: False for zero and negative values. Integers are tested
        exactly with a bit trick; any other numeric type falls back to
        checking whether ``log2(val)`` is a whole number.
    """
    # log2 is undefined for non-positive numbers; they are never powers of two.
    if val <= 0:
        return False
    if isinstance(val, int):
        # A power of two has a single set bit, so clearing the lowest set bit
        # leaves zero. This avoids float rounding for large integers, where
        # e.g. log2(2**53 + 1) evaluates to exactly 53.0.
        return (val & (val - 1)) == 0
    return log2(val).is_integer()
def default(val, d):
    """Return ``val`` unless it is ``None``, in which case return ``d``."""
    if exists(val):
        return val
    return d
def get_1d_dct(i, freq, L):
    """Evaluate one sample of a 1-D orthonormal DCT-II basis function.

    Args:
        i: Sample position along the axis (0-based).
        freq: Frequency index of the basis function.
        L: Length of the axis.

    Returns:
        float: ``cos(pi * freq * (i + 0.5) / L) / sqrt(L)``, additionally
        multiplied by ``sqrt(2)`` for every frequency other than zero.
    """
    basis = math.cos(math.pi * freq * (i + 0.5) / L) / math.sqrt(L)
    if freq == 0:
        # The constant (DC) component keeps the plain 1/sqrt(L) scaling.
        return basis
    return basis * math.sqrt(2)
def get_dct_weights(width, channel, fidx_u, fidx_v):
    """Build a bank of 2-D DCT basis planes, one per channel group.

    The channels are split into ``len(fidx_u)`` consecutive groups of
    ``channel // len(fidx_u)`` channels. Group ``i`` is filled with the
    separable basis ``get_1d_dct(x, fidx_u[i], width) *
    get_1d_dct(y, fidx_v[i], width)``. Channels left over by the integer
    division (or every channel, when there are fewer channels than
    frequencies) stay zero.

    Args:
        width: Side length of the square spatial grid.
        channel: Total number of channels in the returned tensor.
        fidx_u: Frequency indices used along the first spatial axis.
        fidx_v: Frequency indices used along the second spatial axis;
            paired element-wise with ``fidx_u``.

    Returns:
        torch.Tensor: Tensor of shape ``(1, channel, width, width)``.
    """
    dct_weights = torch.zeros(1, channel, width, width)
    c_part = channel // len(fidx_u)
    if c_part == 0:
        # Every group is empty, so there is nothing to fill.
        return dct_weights

    # Each 1-D basis depends only on (frequency, position), so compute it
    # once per distinct frequency instead of once per pixel. float64 keeps
    # the products identical to plain Python float arithmetic before the
    # final cast to the weight dtype.
    basis_cache = {}

    def basis(freq):
        if freq not in basis_cache:
            basis_cache[freq] = torch.tensor(
                [get_1d_dct(pos, freq, width) for pos in range(width)],
                dtype=torch.float64,
            )
        return basis_cache[freq]

    for i, (u_x, v_y) in enumerate(zip(fidx_u, fidx_v)):
        # plane[x, y] = basis_u[x] * basis_v[y]
        plane = torch.outer(basis(u_x), basis(v_y))
        dct_weights[:, i * c_part : (i + 1) * c_part] = plane.to(dct_weights.dtype)
    return dct_weights
class Blur(nn.Module):
    """Blur a feature map with a separable 3x3 kernel built from [1, 2, 1]."""

    def __init__(self):
        super().__init__()
        # Stored as a buffer so the taps follow the module across devices.
        self.register_buffer("f", torch.Tensor([1, 2, 1]))

    def forward(self, x):
        """Filter ``x`` with the outer product of the 1-D taps.

        The kernel is handed to ``filter2d`` with ``normalized=True``.
        """
        taps = self.f
        # (1, 3, 1) * (1, 1, 3) -> (1, 3, 3) kernel.
        kernel = taps[None, :, None] * taps[None, None, :]
        return filter2d(x, kernel, normalized=True)
class ChanNorm(nn.Module):
    """Normalize over the channel axis (dim 1) with a learned scale and shift.

    Args:
        dim: Number of channels; sets the size of the learned gain ``g``
            and bias ``b``, both shaped ``(1, dim, 1, 1)``.
        eps: Constant added to the variance before the square root.
    """

    def __init__(self, dim, eps=1e-5):
        super().__init__()
        self.eps = eps
        self.g = nn.Parameter(torch.ones(1, dim, 1, 1))
        self.b = nn.Parameter(torch.zeros(1, dim, 1, 1))

    def forward(self, x):
        """Return ``x`` standardized across dim 1, then scaled and shifted."""
        channel_var = torch.var(x, dim=1, unbiased=False, keepdim=True)
        channel_mean = torch.mean(x, dim=1, keepdim=True)
        centered = x - channel_mean
        std = (channel_var + self.eps).sqrt()
        return centered / std * self.g + self.b
def Conv2dSame(dim_in, dim_out, kernel_size, bias=True):
    """Build a zero-padded convolution that keeps the spatial size unchanged.

    Odd kernels are padded by ``kernel_size // 2`` on every side. Even
    kernels get one pixel less on the right and bottom so the output size
    still matches the input.

    Args:
        dim_in: Number of input channels.
        dim_out: Number of output channels.
        kernel_size: Side length of the square kernel.
        bias: Whether the convolution has a bias term.

    Returns:
        nn.Sequential: ``ZeroPad2d`` followed by ``Conv2d``.
    """
    leading = kernel_size // 2
    is_even = kernel_size % 2 == 0
    trailing = leading - 1 if is_even else leading
    # ZeroPad2d takes (left, right, top, bottom).
    pad = nn.ZeroPad2d((leading, trailing, leading, trailing))
    conv = nn.Conv2d(dim_in, dim_out, kernel_size, bias=bias)
    return nn.Sequential(pad, conv)
class DepthWiseConv2d(nn.Module):
    """Depthwise convolution followed by a 1x1 pointwise convolution.

    Args:
        dim_in: Number of input channels (also the depthwise group count).
        dim_out: Number of output channels of the pointwise convolution.
        kernel_size: Kernel size of the depthwise convolution.
        padding: Padding of the depthwise convolution.
        stride: Stride of the depthwise convolution.
        bias: Whether both convolutions have a bias term.
    """

    def __init__(self, dim_in, dim_out, kernel_size, padding=0, stride=1, bias=True):
        super().__init__()
        # One filter per input channel (groups == channels).
        depthwise = nn.Conv2d(
            dim_in,
            dim_in,
            kernel_size=kernel_size,
            padding=padding,
            groups=dim_in,
            stride=stride,
            bias=bias,
        )
        # 1x1 convolution mixes the channels and sets the output width.
        pointwise = nn.Conv2d(dim_in, dim_out, kernel_size=1, bias=bias)
        self.net = nn.Sequential(depthwise, pointwise)

    def forward(self, x):
        """Apply the depthwise then the pointwise convolution to ``x``."""
        return self.net(x)
class FCANet(nn.Module):
    """Channel gate computed from DCT-weighted spatial pooling.

    The input is multiplied by a fixed bank of 2-D DCT basis planes, summed
    over both spatial axes, and passed through a two-layer 1x1 conv network
    ending in a sigmoid.

    Args:
        chan_in: Channels of the incoming feature map.
        chan_out: Channels of the produced gate.
        reduction: Divisor for the hidden width, which is
            ``max(3, chan_out // reduction)``.
        width: Spatial side length the DCT planes are built for; the input
            to ``forward`` must broadcast against ``(width, width)``.
    """

    def __init__(self, *, chan_in, chan_out, reduction=4, width):
        super().__init__()

        # in paper, it seems 16 frequencies was ideal
        zeros = [0] * 8
        ramp = list(range(8))
        # 16 frequency pairs: (0, k) for k in 0..7, then (k, 0) for k in 0..7.
        fidx_u = zeros + ramp
        fidx_v = ramp + zeros
        self.register_buffer(
            "dct_weights", get_dct_weights(width, chan_in, fidx_u, fidx_v)
        )

        hidden = max(3, chan_out // reduction)
        squeeze = nn.Conv2d(chan_in, hidden, 1)
        excite = nn.Conv2d(hidden, chan_out, 1)
        self.net = nn.Sequential(squeeze, nn.LeakyReLU(0.1), excite, nn.Sigmoid())

    def forward(self, x):
        """Return a ``(b, chan_out, 1, 1)`` gate with values from a sigmoid."""
        weighted = x * self.dct_weights
        # With h1 = w1 = 1 this sums over the whole spatial extent while
        # keeping two singleton spatial axes.
        pooled = reduce(
            weighted, "b c (h h1) (w w1) -> b c h1 w1", "sum", h1=1, w1=1
        )
        return self.net(pooled)
class Generator(nn.Module):
    """Convolutional generator that upsamples a latent vector into an image.

    A transposed convolution turns the latent vector into a 4x4 feature map.
    ``log2(image_size) - 2`` upsampling layers then each double the spatial
    size. Selected low-resolution layers emit a channel gate (``sle``) that
    is multiplied into a later, higher-resolution layer's output, as given
    by ``sle_map``.

    Args:
        image_size: Side length of the generated image; must be a power of 2.
        latent_dim: Size of the latent vector.
        fmap_max: Upper bound on channels per layer; ``None`` falls back to
            ``latent_dim``.
        fmap_inverse_coef: Layer ``n`` gets ``2 ** (fmap_inverse_coef - n)``
            channels before clamping to ``fmap_max``.
        transparent: Produce 4 output channels.
        greyscale: Produce 1 output channel (ignored when ``transparent``).
        attn_res_layers: Feature-map widths (``2 ** res`` of the layer input
            index) at which attention is applied before upsampling.
        freq_chan_attn: Use ``FCANet`` instead of ``GlobalContext`` for the
            channel gates.
        syncbatchnorm: Use ``nn.SyncBatchNorm`` instead of ``nn.BatchNorm2d``.
        antialias: Insert a ``Blur`` after each upsample (otherwise identity).
    """

    def __init__(
        self,
        *,
        image_size,
        latent_dim=256,
        fmap_max=512,
        fmap_inverse_coef=12,
        transparent=False,
        greyscale=False,
        attn_res_layers=(),
        freq_chan_attn=False,
        syncbatchnorm=False,
        antialias=False,
    ):
        super().__init__()
        # Validate first, then derive the number of doublings from the size.
        assert is_power_of_two(image_size), "image size must be a power of 2"
        resolution = int(log2(image_size))

        # Set the normalization and blur
        norm_class = nn.SyncBatchNorm if syncbatchnorm else nn.BatchNorm2d
        # This must not be bound to the name ``Blur``: assigning to that name
        # would make it local to __init__, and reading the module-level class
        # on the right-hand side would raise UnboundLocalError.
        blur_class = Blur if antialias else nn.Identity

        if transparent:
            init_channel = 4
        elif greyscale:
            init_channel = 1
        else:
            init_channel = 3

        self.latent_dim = latent_dim
        fmap_max = default(fmap_max, latent_dim)

        # 1x1 latent -> 4x4 map; GLU halves the doubled channel count.
        self.initial_conv = nn.Sequential(
            nn.ConvTranspose2d(latent_dim, latent_dim * 2, 4),
            norm_class(latent_dim * 2),
            nn.GLU(dim=1),
        )

        num_layers = resolution - 2
        self.res_layers = range(2, num_layers + 2)

        # Output channels per layer: capped power of two, and 3 from
        # resolution index 8 onwards.
        features = [latent_dim]
        for n in self.res_layers:
            if n >= 8:
                features.append(3)
            else:
                features.append(min(2 ** (fmap_inverse_coef - n), fmap_max))
        in_out_features = list(zip(features[:-1], features[1:]))

        self.layers = nn.ModuleList([])
        self.res_to_feature_map = dict(zip(self.res_layers, in_out_features))

        # source layer -> resolution index whose activations it gates; keep
        # only the pairs that fit inside this generator.
        candidate_sle_pairs = ((3, 7), (4, 8), (5, 9), (6, 10))
        self.sle_map = {
            src: dst
            for src, dst in candidate_sle_pairs
            if src <= resolution and dst <= resolution
        }

        self.num_layers_spatial_res = 1

        for res, (chan_in, chan_out) in zip(self.res_layers, in_out_features):
            image_width = 2**res

            attn = None
            if image_width in attn_res_layers:
                attn = PreNorm(chan_in, LinearAttention(chan_in))

            sle = None
            if res in self.sle_map:
                residual_layer = self.sle_map[res]
                # The gate multiplies the output of layer ``residual_layer - 1``,
                # so it needs that layer's output channel count.
                sle_chan_out = self.res_to_feature_map[residual_layer - 1][-1]
                if freq_chan_attn:
                    sle = FCANet(
                        chan_in=chan_out, chan_out=sle_chan_out, width=2 ** (res + 1)
                    )
                else:
                    sle = GlobalContext(chan_in=chan_out, chan_out=sle_chan_out)

            upsample = nn.Sequential(
                PixelShuffleUpsample(chan_in),
                blur_class(),
                Conv2dSame(chan_in, chan_out * 2, 4),
                Noise(),
                norm_class(chan_out * 2),
                nn.GLU(dim=1),
            )
            self.layers.append(nn.ModuleList([upsample, sle, attn]))

        self.out_conv = nn.Conv2d(features[-1], init_channel, 3, padding=1)

    def forward(self, x):
        """Generate images from latent vectors.

        Args:
            x: Latent batch of shape ``(batch, latent_dim)``.

        Returns:
            torch.Tensor: Output of the final 3x3 convolution; no activation
            is applied to it here.
        """
        x = rearrange(x, "b c -> b c () ()")
        x = self.initial_conv(x)
        x = F.normalize(x, dim=1)

        # Channel gates keyed by the resolution index they will be applied at.
        residuals = {}
        for res, (up, sle, attn) in zip(self.res_layers, self.layers):
            if exists(attn):
                x = attn(x) + x

            x = up(x)

            if exists(sle):
                residuals[self.sle_map[res]] = sle(x)

            next_res = res + 1
            if next_res in residuals:
                x = x * residuals[next_res]

        return self.out_conv(x)
class GlobalContext(nn.Module):
    """Channel gate computed from attention-pooled global context.

    A 1x1 convolution scores every spatial position; the softmax of those
    scores pools the feature map into one vector per sample, which a
    two-layer 1x1 conv network ending in a sigmoid turns into the gate.

    Args:
        chan_in: Channels of the incoming feature map.
        chan_out: Channels of the produced gate.
    """

    def __init__(self, *, chan_in, chan_out):
        super().__init__()
        self.to_k = nn.Conv2d(chan_in, 1, 1)

        hidden = max(3, chan_out // 2)
        squeeze = nn.Conv2d(chan_in, hidden, 1)
        excite = nn.Conv2d(hidden, chan_out, 1)
        self.net = nn.Sequential(squeeze, nn.LeakyReLU(0.1), excite, nn.Sigmoid())

    def forward(self, x):
        """Return a ``(b, chan_out, 1, 1)`` gate for feature map ``x``."""
        scores = self.to_k(x).flatten(2)
        # Softmax over all spatial positions.
        weights = scores.softmax(dim=-1)
        # Weighted sum of the features over positions -> (b, c, 1).
        pooled = einsum("b i n, b c n -> b c i", weights, x.flatten(2))
        return self.net(pooled.unsqueeze(-1))
class LinearAttention(nn.Module):
    """Attention block combining a linear-attention branch with a local one.

    The two branches use separate projections. Their outputs are
    concatenated on the channel axis and mixed back to ``dim`` channels by a
    1x1 convolution.

    Args:
        dim: Channels of the input (and output) feature map.
        dim_head: Channels per attention head.
        heads: Number of attention heads.
        kernel_size: Side length of the neighbourhood unfolded around each
            position for the local branch.
    """

    def __init__(self, dim, dim_head=64, heads=8, kernel_size=3):
        super().__init__()
        inner_dim = dim_head * heads

        self.scale = dim_head**-0.5
        self.heads = heads
        self.dim_head = dim_head
        self.kernel_size = kernel_size
        self.nonlin = nn.GELU()

        # Projections for the linear-attention branch.
        self.to_lin_q = nn.Conv2d(dim, inner_dim, 1, bias=False)
        self.to_lin_kv = DepthWiseConv2d(dim, inner_dim * 2, 3, padding=1, bias=False)

        # Projections for the local (unfolded) attention branch.
        self.to_q = nn.Conv2d(dim, inner_dim, 1, bias=False)
        self.to_kv = nn.Conv2d(dim, inner_dim * 2, 1, bias=False)

        self.to_out = nn.Conv2d(inner_dim * 2, dim, 1)

    def forward(self, fmap):
        """Apply both attention branches to ``fmap`` and merge the results."""
        heads = self.heads
        height, width = fmap.shape[-2:]

        def merge_heads(t):
            # (batch * heads, positions, d) -> (batch, heads * d, height, width)
            return rearrange(
                t, "(b h) (x y) d -> b (h d) x y", h=heads, x=height, y=width
            )

        # linear attention
        lin_q = self.to_lin_q(fmap)
        lin_k, lin_v = self.to_lin_kv(fmap).chunk(2, dim=1)
        lin_q, lin_k, lin_v = (
            rearrange(t, "b (h c) x y -> (b h) (x y) c", h=heads)
            for t in (lin_q, lin_k, lin_v)
        )

        # Queries are normalized over features, keys over positions.
        lin_q = lin_q.softmax(dim=-1)
        lin_k = lin_k.softmax(dim=-2)
        lin_q = lin_q * self.scale

        context = einsum("b n d, b n e -> b d e", lin_k, lin_v)
        lin_out = einsum("b n d, b d e -> b n e", lin_q, context)
        lin_out = merge_heads(lin_out)

        # conv-like full attention
        q = self.to_q(fmap)
        k, v = self.to_kv(fmap).chunk(2, dim=1)
        q, k, v = (
            rearrange(t, "b (h c) x y -> (b h) c x y", h=heads) for t in (q, k, v)
        )

        # Gather the kernel_size x kernel_size neighbourhood of every position.
        halo = self.kernel_size // 2
        k = F.unfold(k, kernel_size=self.kernel_size, padding=halo)
        v = F.unfold(v, kernel_size=self.kernel_size, padding=halo)
        k, v = (
            rearrange(t, "b (d j) n -> b n j d", d=self.dim_head) for t in (k, v)
        )

        q = rearrange(q, "b c ... -> b (...) c") * self.scale

        sim = einsum("b i d, b i j d -> b i j", q, k)
        # Subtract the per-row maximum before the softmax.
        sim = sim - sim.amax(dim=-1, keepdim=True).detach()
        attn = sim.softmax(dim=-1)

        full_out = einsum("b i j, b i j d -> b i d", attn, v)
        full_out = merge_heads(full_out)

        # add outputs of linear attention + conv like full attention
        lin_out = self.nonlin(lin_out)
        merged = torch.cat((lin_out, full_out), dim=1)
        return self.to_out(merged)
class Noise(nn.Module):
    """Add per-pixel noise scaled by a single learned weight.

    The weight starts at zero, so a freshly built module returns its input
    unchanged in value.
    """

    def __init__(self):
        super().__init__()
        self.weight = nn.Parameter(torch.zeros(1))

    def forward(self, x, noise=None):
        """Return ``x + weight * noise``.

        Args:
            x: 4-D input; its batch and spatial sizes shape the sampled noise.
            noise: Optional noise tensor. When omitted, standard normal noise
                of shape ``(b, 1, h, w)`` is drawn on ``x``'s device.
        """
        batch, _, height, width = x.shape
        if noise is None:
            noise = torch.randn(batch, 1, height, width, device=x.device)
        return x + self.weight * noise
class PixelShuffleUpsample(nn.Module):
    """Double the spatial size with a 1x1 conv, SiLU and ``PixelShuffle(2)``.

    Args:
        dim: Channels of the input feature map.
        dim_out: Channels of the upsampled output; defaults to ``dim``.
    """

    def __init__(self, dim, dim_out=None):
        super().__init__()
        dim_out = default(dim_out, dim)
        # PixelShuffle(2) folds groups of 4 channels into a 2x2 spatial block,
        # so the convolution must produce 4x the wanted channels.
        expand = nn.Conv2d(dim, dim_out * 4, 1)

        self.net = nn.Sequential(expand, nn.SiLU(), nn.PixelShuffle(2))
        self.init_conv_(expand)

    def init_conv_(self, conv):
        """Initialize ``conv`` so each group of 4 output channels starts equal.

        A Kaiming-uniform kernel is drawn for a quarter of the output
        channels and each filter is repeated 4 times consecutively; the bias
        is zeroed.
        """
        out_ch, in_ch, k_h, k_w = conv.weight.shape
        base_kernel = torch.empty(out_ch // 4, in_ch, k_h, k_w)
        nn.init.kaiming_uniform_(base_kernel)
        tiled_kernel = repeat(base_kernel, "o ... -> (o 4) ...")

        conv.weight.data.copy_(tiled_kernel)
        nn.init.zeros_(conv.bias.data)

    def forward(self, x):
        """Return ``x`` upsampled by a factor of 2 along both spatial axes."""
        return self.net(x)
class PreNorm(nn.Module):
    """Wrap a module so its input is channel-normalized first.

    Args:
        dim: Number of channels passed to ``ChanNorm``.
        fn: Callable applied to the normalized input.
    """

    def __init__(self, dim, fn):
        super().__init__()
        self.fn = fn
        self.norm = ChanNorm(dim)

    def forward(self, x):
        """Return ``fn(norm(x))``."""
        normed = self.norm(x)
        return self.fn(normed)