michaelriedl commited on
Commit
002ca81
·
1 Parent(s): 1f1f217

Initial dump

Browse files
.gitignore ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ # Custom for repository
2
+ dev/
3
+
4
+ # Byte-compiled / optimized / DLL files
5
+ *__pycache__/
6
+ *.py[cod]
7
+
8
+ # VS Code
9
+ .vscode/
LightweightGAN/Blur.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
+ from kornia.filters import filter2d
4
+
5
+
6
+ class Blur(nn.Module):
7
+ def __init__(self):
8
+ super().__init__()
9
+ f = torch.Tensor([1, 2, 1])
10
+ self.register_buffer("f", f)
11
+
12
+ def forward(self, x):
13
+ f = self.f
14
+ f = f[None, None, :] * f[None, :, None]
15
+ return filter2d(x, f, normalized=True)
LightweightGAN/ChanNorm.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
+
4
+
5
+ class ChanNorm(nn.Module):
6
+ def __init__(self, dim, eps=1e-5):
7
+ super().__init__()
8
+ self.eps = eps
9
+ self.g = nn.Parameter(torch.ones(1, dim, 1, 1))
10
+ self.b = nn.Parameter(torch.zeros(1, dim, 1, 1))
11
+
12
+ def forward(self, x):
13
+ var = torch.var(x, dim=1, unbiased=False, keepdim=True)
14
+ mean = torch.mean(x, dim=1, keepdim=True)
15
+ return (x - mean) / (var + self.eps).sqrt() * self.g + self.b
LightweightGAN/Conv2dSame.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torch import nn
2
+
3
+
4
+ def Conv2dSame(dim_in, dim_out, kernel_size, bias=True):
5
+ pad_left = kernel_size // 2
6
+ pad_right = (pad_left - 1) if (kernel_size % 2) == 0 else pad_left
7
+
8
+ return nn.Sequential(
9
+ nn.ZeroPad2d((pad_left, pad_right, pad_left, pad_right)),
10
+ nn.Conv2d(dim_in, dim_out, kernel_size, bias=bias),
11
+ )
LightweightGAN/DepthWiseConv2d.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torch import nn
2
+
3
+
4
+ class DepthWiseConv2d(nn.Module):
5
+ def __init__(self, dim_in, dim_out, kernel_size, padding=0, stride=1, bias=True):
6
+ super().__init__()
7
+ self.net = nn.Sequential(
8
+ nn.Conv2d(
9
+ dim_in,
10
+ dim_in,
11
+ kernel_size=kernel_size,
12
+ padding=padding,
13
+ groups=dim_in,
14
+ stride=stride,
15
+ bias=bias,
16
+ ),
17
+ nn.Conv2d(dim_in, dim_out, kernel_size=1, bias=bias),
18
+ )
19
+
20
+ def forward(self, x):
21
+ return self.net(x)
LightweightGAN/FCANet.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torch import nn
2
+ from einops import reduce
3
+ from .helper_funcs import get_dct_weights
4
+
5
+
6
+ class FCANet(nn.Module):
7
+ def __init__(self, *, chan_in, chan_out, reduction=4, width):
8
+ super().__init__()
9
+
10
+ freq_w, freq_h = ([0] * 8), list(
11
+ range(8)
12
+ ) # in paper, it seems 16 frequencies was ideal
13
+ dct_weights = get_dct_weights(
14
+ width, chan_in, [*freq_w, *freq_h], [*freq_h, *freq_w]
15
+ )
16
+ self.register_buffer("dct_weights", dct_weights)
17
+
18
+ chan_intermediate = max(3, chan_out // reduction)
19
+
20
+ self.net = nn.Sequential(
21
+ nn.Conv2d(chan_in, chan_intermediate, 1),
22
+ nn.LeakyReLU(0.1),
23
+ nn.Conv2d(chan_intermediate, chan_out, 1),
24
+ nn.Sigmoid(),
25
+ )
26
+
27
+ def forward(self, x):
28
+ x = reduce(
29
+ x * self.dct_weights, "b c (h h1) (w w1) -> b c h1 w1", "sum", h1=1, w1=1
30
+ )
31
+ return self.net(x)
LightweightGAN/Generator.py ADDED
@@ -0,0 +1,138 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.nn.functional as F
2
+ from torch import nn
3
+ from math import log2
4
+ from einops import rearrange
5
+
6
+ from .Blur import Blur
7
+ from .Noise import Noise
8
+ from .FCANet import FCANet
9
+ from .PreNorm import PreNorm
10
+ from .Conv2dSame import Conv2dSame
11
+ from .GlobalContext import GlobalContext
12
+ from .LinearAttention import LinearAttention
13
+ from .PixelShuffleUpsample import PixelShuffleUpsample
14
+ from .helper_funcs import exists, is_power_of_two, default
15
+
16
+
17
+ class Generator(nn.Module):
18
+ def __init__(
19
+ self,
20
+ *,
21
+ image_size,
22
+ latent_dim=256,
23
+ fmap_max=512,
24
+ fmap_inverse_coef=12,
25
+ transparent=False,
26
+ greyscale=False,
27
+ attn_res_layers=[],
28
+ freq_chan_attn=False,
29
+ syncbatchnorm=False,
30
+ antialias=False,
31
+ ):
32
+ super().__init__()
33
+ resolution = log2(image_size)
34
+ assert is_power_of_two(image_size), "image size must be a power of 2"
35
+
36
+ # Set the normalization and blur
37
+ norm_class = nn.SyncBatchNorm if syncbatchnorm else nn.BatchNorm2d
38
+ Blur = nn.Identity if not antialias else Blur
39
+
40
+ if transparent:
41
+ init_channel = 4
42
+ elif greyscale:
43
+ init_channel = 1
44
+ else:
45
+ init_channel = 3
46
+
47
+ self.latent_dim = latent_dim
48
+
49
+ fmap_max = default(fmap_max, latent_dim)
50
+
51
+ self.initial_conv = nn.Sequential(
52
+ nn.ConvTranspose2d(latent_dim, latent_dim * 2, 4),
53
+ norm_class(latent_dim * 2),
54
+ nn.GLU(dim=1),
55
+ )
56
+
57
+ num_layers = int(resolution) - 2
58
+ features = list(
59
+ map(lambda n: (n, 2 ** (fmap_inverse_coef - n)), range(2, num_layers + 2))
60
+ )
61
+ features = list(map(lambda n: (n[0], min(n[1], fmap_max)), features))
62
+ features = list(map(lambda n: 3 if n[0] >= 8 else n[1], features))
63
+ features = [latent_dim, *features]
64
+
65
+ in_out_features = list(zip(features[:-1], features[1:]))
66
+
67
+ self.res_layers = range(2, num_layers + 2)
68
+ self.layers = nn.ModuleList([])
69
+ self.res_to_feature_map = dict(zip(self.res_layers, in_out_features))
70
+
71
+ self.sle_map = ((3, 7), (4, 8), (5, 9), (6, 10))
72
+ self.sle_map = list(
73
+ filter(lambda t: t[0] <= resolution and t[1] <= resolution, self.sle_map)
74
+ )
75
+ self.sle_map = dict(self.sle_map)
76
+
77
+ self.num_layers_spatial_res = 1
78
+
79
+ for res, (chan_in, chan_out) in zip(self.res_layers, in_out_features):
80
+ image_width = 2**res
81
+
82
+ attn = None
83
+ if image_width in attn_res_layers:
84
+ attn = PreNorm(chan_in, LinearAttention(chan_in))
85
+
86
+ sle = None
87
+ if res in self.sle_map:
88
+ residual_layer = self.sle_map[res]
89
+ sle_chan_out = self.res_to_feature_map[residual_layer - 1][-1]
90
+
91
+ if freq_chan_attn:
92
+ sle = FCANet(
93
+ chan_in=chan_out, chan_out=sle_chan_out, width=2 ** (res + 1)
94
+ )
95
+ else:
96
+ sle = GlobalContext(chan_in=chan_out, chan_out=sle_chan_out)
97
+
98
+ layer = nn.ModuleList(
99
+ [
100
+ nn.Sequential(
101
+ PixelShuffleUpsample(chan_in),
102
+ Blur(),
103
+ Conv2dSame(chan_in, chan_out * 2, 4),
104
+ Noise(),
105
+ norm_class(chan_out * 2),
106
+ nn.GLU(dim=1),
107
+ ),
108
+ sle,
109
+ attn,
110
+ ]
111
+ )
112
+ self.layers.append(layer)
113
+
114
+ self.out_conv = nn.Conv2d(features[-1], init_channel, 3, padding=1)
115
+
116
+ def forward(self, x):
117
+ x = rearrange(x, "b c -> b c () ()")
118
+ x = self.initial_conv(x)
119
+ x = F.normalize(x, dim=1)
120
+
121
+ residuals = dict()
122
+
123
+ for res, (up, sle, attn) in zip(self.res_layers, self.layers):
124
+ if exists(attn):
125
+ x = attn(x) + x
126
+
127
+ x = up(x)
128
+
129
+ if exists(sle):
130
+ out_res = self.sle_map[res]
131
+ residual = sle(x)
132
+ residuals[out_res] = residual
133
+
134
+ next_res = res + 1
135
+ if next_res in residuals:
136
+ x = x * residuals[next_res]
137
+
138
+ return self.out_conv(x)
LightweightGAN/GlobalContext.py ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torch import nn, einsum
2
+
3
+
4
+ class GlobalContext(nn.Module):
5
+ def __init__(self, *, chan_in, chan_out):
6
+ super().__init__()
7
+ self.to_k = nn.Conv2d(chan_in, 1, 1)
8
+ chan_intermediate = max(3, chan_out // 2)
9
+
10
+ self.net = nn.Sequential(
11
+ nn.Conv2d(chan_in, chan_intermediate, 1),
12
+ nn.LeakyReLU(0.1),
13
+ nn.Conv2d(chan_intermediate, chan_out, 1),
14
+ nn.Sigmoid(),
15
+ )
16
+
17
+ def forward(self, x):
18
+ context = self.to_k(x)
19
+ context = context.flatten(2).softmax(dim=-1)
20
+ out = einsum("b i n, b c n -> b c i", context, x.flatten(2))
21
+ out = out.unsqueeze(-1)
22
+ return self.net(out)
LightweightGAN/LinearAttention.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn.functional as F
3
+
4
+ from torch import nn, einsum
5
+ from einops import rearrange
6
+
7
+ from .DepthWiseConv2d import DepthWiseConv2d
8
+
9
+
10
+ class LinearAttention(nn.Module):
11
+ def __init__(self, dim, dim_head=64, heads=8, kernel_size=3):
12
+ super().__init__()
13
+ self.scale = dim_head**-0.5
14
+ self.heads = heads
15
+ self.dim_head = dim_head
16
+ inner_dim = dim_head * heads
17
+
18
+ self.kernel_size = kernel_size
19
+ self.nonlin = nn.GELU()
20
+
21
+ self.to_lin_q = nn.Conv2d(dim, inner_dim, 1, bias=False)
22
+ self.to_lin_kv = DepthWiseConv2d(dim, inner_dim * 2, 3, padding=1, bias=False)
23
+
24
+ self.to_q = nn.Conv2d(dim, inner_dim, 1, bias=False)
25
+ self.to_kv = nn.Conv2d(dim, inner_dim * 2, 1, bias=False)
26
+
27
+ self.to_out = nn.Conv2d(inner_dim * 2, dim, 1)
28
+
29
+ def forward(self, fmap):
30
+ h, x, y = self.heads, *fmap.shape[-2:]
31
+
32
+ # linear attention
33
+
34
+ lin_q, lin_k, lin_v = (
35
+ self.to_lin_q(fmap),
36
+ *self.to_lin_kv(fmap).chunk(2, dim=1),
37
+ )
38
+ lin_q, lin_k, lin_v = map(
39
+ lambda t: rearrange(t, "b (h c) x y -> (b h) (x y) c", h=h),
40
+ (lin_q, lin_k, lin_v),
41
+ )
42
+
43
+ lin_q = lin_q.softmax(dim=-1)
44
+ lin_k = lin_k.softmax(dim=-2)
45
+
46
+ lin_q = lin_q * self.scale
47
+
48
+ context = einsum("b n d, b n e -> b d e", lin_k, lin_v)
49
+ lin_out = einsum("b n d, b d e -> b n e", lin_q, context)
50
+ lin_out = rearrange(lin_out, "(b h) (x y) d -> b (h d) x y", h=h, x=x, y=y)
51
+
52
+ # conv-like full attention
53
+
54
+ q, k, v = (self.to_q(fmap), *self.to_kv(fmap).chunk(2, dim=1))
55
+ q, k, v = map(
56
+ lambda t: rearrange(t, "b (h c) x y -> (b h) c x y", h=h), (q, k, v)
57
+ )
58
+
59
+ k = F.unfold(k, kernel_size=self.kernel_size, padding=self.kernel_size // 2)
60
+ v = F.unfold(v, kernel_size=self.kernel_size, padding=self.kernel_size // 2)
61
+
62
+ k, v = map(
63
+ lambda t: rearrange(t, "b (d j) n -> b n j d", d=self.dim_head), (k, v)
64
+ )
65
+
66
+ q = rearrange(q, "b c ... -> b (...) c") * self.scale
67
+
68
+ sim = einsum("b i d, b i j d -> b i j", q, k)
69
+ sim = sim - sim.amax(dim=-1, keepdim=True).detach()
70
+
71
+ attn = sim.softmax(dim=-1)
72
+
73
+ full_out = einsum("b i j, b i j d -> b i d", attn, v)
74
+ full_out = rearrange(full_out, "(b h) (x y) d -> b (h d) x y", h=h, x=x, y=y)
75
+
76
+ # add outputs of linear attention + conv like full attention
77
+
78
+ lin_out = self.nonlin(lin_out)
79
+ out = torch.cat((lin_out, full_out), dim=1)
80
+ return self.to_out(out)
LightweightGAN/Noise.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
+ from .helper_funcs import exists
4
+
5
+
6
+ class Noise(nn.Module):
7
+ def __init__(self):
8
+ super().__init__()
9
+ self.weight = nn.Parameter(torch.zeros(1))
10
+
11
+ def forward(self, x, noise=None):
12
+ b, _, h, w, device = *x.shape, x.device
13
+
14
+ if not exists(noise):
15
+ noise = torch.randn(b, 1, h, w, device=device)
16
+
17
+ return x + self.weight * noise
LightweightGAN/PixelShuffleUpsample.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
+ from einops import repeat
4
+ from .helper_funcs import default
5
+
6
+
7
+ class PixelShuffleUpsample(nn.Module):
8
+ def __init__(self, dim, dim_out=None):
9
+ super().__init__()
10
+ dim_out = default(dim_out, dim)
11
+ conv = nn.Conv2d(dim, dim_out * 4, 1)
12
+
13
+ self.net = nn.Sequential(conv, nn.SiLU(), nn.PixelShuffle(2))
14
+
15
+ self.init_conv_(conv)
16
+
17
+ def init_conv_(self, conv):
18
+ o, i, h, w = conv.weight.shape
19
+ conv_weight = torch.empty(o // 4, i, h, w)
20
+ nn.init.kaiming_uniform_(conv_weight)
21
+ conv_weight = repeat(conv_weight, "o ... -> (o 4) ...")
22
+
23
+ conv.weight.data.copy_(conv_weight)
24
+ nn.init.zeros_(conv.bias.data)
25
+
26
+ def forward(self, x):
27
+ return self.net(x)
LightweightGAN/PreNorm.py ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torch import nn
2
+ from .ChanNorm import ChanNorm
3
+
4
+
5
+ class PreNorm(nn.Module):
6
+ def __init__(self, dim, fn):
7
+ super().__init__()
8
+ self.fn = fn
9
+ self.norm = ChanNorm(dim)
10
+
11
+ def forward(self, x):
12
+ return self.fn(self.norm(x))
LightweightGAN/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .Generator import Generator
LightweightGAN/helper_funcs.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import torch
3
+ from math import log2
4
+
5
+
6
+ def exists(val):
7
+ return val is not None
8
+
9
+
10
+ def is_power_of_two(val):
11
+ return log2(val).is_integer()
12
+
13
+
14
+ def default(val, d):
15
+ return val if exists(val) else d
16
+
17
+
18
+ def get_1d_dct(i, freq, L):
19
+ result = math.cos(math.pi * freq * (i + 0.5) / L) / math.sqrt(L)
20
+ return result * (1 if freq == 0 else math.sqrt(2))
21
+
22
+
23
+ def get_dct_weights(width, channel, fidx_u, fidx_v):
24
+ dct_weights = torch.zeros(1, channel, width, width)
25
+ c_part = channel // len(fidx_u)
26
+
27
+ for i, (u_x, v_y) in enumerate(zip(fidx_u, fidx_v)):
28
+ for x in range(width):
29
+ for y in range(width):
30
+ coor_value = get_1d_dct(x, u_x, width) * get_1d_dct(y, v_y, width)
31
+ dct_weights[:, i * c_part : (i + 1) * c_part, x, y] = coor_value
32
+
33
+ return dct_weights
MonsterForgeModel.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from transformers import PreTrainedModel
3
+ from MonsterForgeSmallConfig import MonsterForgeSmallConfig
4
+ from LightweightGAN import Generator
5
+
6
+
7
+ class MonsterForgeModel(PreTrainedModel):
8
+ config_class = MonsterForgeSmallConfig
9
+
10
+ def __init__(self, config):
11
+ super().__init__(config)
12
+ self.model = Generator(
13
+ image_size=config.image_size,
14
+ latent_dim=config.latent_dim,
15
+ fmap_max=config.fmap_max,
16
+ fmap_inverse_coef=config.fmap_inverse_coef,
17
+ transparent=config.transparent,
18
+ greyscale=config.greyscale,
19
+ attn_res_layers=config.attn_res_layers,
20
+ freq_chan_attn=config.freq_chan_attn,
21
+ syncbatchnorm=config.syncbatchnorm,
22
+ antialias=config.antialias,
23
+ )
24
+
25
+ def forward(self, tensor):
26
+ return self.model(tensor)
27
+
28
+ def load_params(self, pt_file):
29
+ self.model.load_state_dict(torch.load(pt_file))
MonsterForgeSmallConfig.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import PretrainedConfig
2
+
3
+
4
+ class MonsterForgeSmallConfig(PretrainedConfig):
5
+ model_type = "lightweight-gan"
6
+
7
+ def __init__(
8
+ self,
9
+ image_size=64,
10
+ latent_dim=256,
11
+ fmap_max=512,
12
+ fmap_inverse_coef=12,
13
+ transparent=False,
14
+ greyscale=False,
15
+ attn_res_layers=[32],
16
+ freq_chan_attn=False,
17
+ syncbatchnorm=False,
18
+ antialias=False,
19
+ **kwargs,
20
+ ):
21
+ self.image_size = image_size
22
+ self.latent_dim = latent_dim
23
+ self.fmap_max = fmap_max
24
+ self.fmap_inverse_coef = fmap_inverse_coef
25
+ self.transparent = transparent
26
+ self.greyscale = greyscale
27
+ self.attn_res_layers = attn_res_layers
28
+ self.freq_chan_attn = freq_chan_attn
29
+ self.syncbatchnorm = syncbatchnorm
30
+ self.antialias = antialias
31
+ super().__init__(**kwargs)
config.json ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "antialias": false,
3
+ "architectures": [
4
+ "MonsterForgeModel"
5
+ ],
6
+ "attn_res_layers": [
7
+ 32
8
+ ],
9
+ "auto_map": {
10
+ "AutoConfig": "MonsterForgeSmallConfig.MonsterForgeSmallConfig",
11
+ "AutoModel": "MonsterForgeModel.MonsterForgeModel"
12
+ },
13
+ "fmap_inverse_coef": 12,
14
+ "fmap_max": 512,
15
+ "freq_chan_attn": false,
16
+ "greyscale": false,
17
+ "image_size": 64,
18
+ "latent_dim": 256,
19
+ "model_type": "lightweight-gan",
20
+ "syncbatchnorm": false,
21
+ "torch_dtype": "float32",
22
+ "transformers_version": "4.31.0",
23
+ "transparent": false
24
+ }
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4786f2b7af8dfbeb0f558aa39458d3dc170c761734c4eb5334fabc9acad39590
3
+ size 94506911