JMalott committed
Commit 81c65a4
1 Parent(s): 90bed62

Upload vqgan_detokenizer.py

Files changed (1)
  1. min_dalle/models/vqgan_detokenizer.py +197 -0
min_dalle/models/vqgan_detokenizer.py ADDED
@@ -0,0 +1,197 @@
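+# VQGAN detokenizer for min-dalle: decodes grids of VQGAN image-token indices
+# back into RGB pixel values.
+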
+import torch
+from torch import nn
+from torch import FloatTensor, LongTensor
+from math import sqrt
+
+
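+# Residual block; channel counts are given as log2 (e.g. 9 -> 512 channels).
+# Each branch is GroupNorm -> swish (x * sigmoid(x)) -> 3x3 conv, with a 1x1
+# shortcut projection whenever the channel count changes.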
+class ResnetBlock(nn.Module):
+    def __init__(self, log2_count_in: int, log2_count_out: int):
+        super().__init__()
+        m, n = 2 ** log2_count_in, 2 ** log2_count_out
+        self.is_middle = m == n
+        self.norm1 = nn.GroupNorm(2 ** 5, m)
+        self.conv1 = nn.Conv2d(m, n, 3, padding=1)
+        self.norm2 = nn.GroupNorm(2 ** 5, n)
+        self.conv2 = nn.Conv2d(n, n, 3, padding=1)
+        if not self.is_middle:
+            self.nin_shortcut = nn.Conv2d(m, n, 1)
+
+    def forward(self, x: FloatTensor) -> FloatTensor:
+        h = x
+        h = self.norm1.forward(h)
+        h *= torch.sigmoid(h)
+        h = self.conv1.forward(h)
+        h = self.norm2.forward(h)
+        h *= torch.sigmoid(h)
+        h = self.conv2.forward(h)
+        if not self.is_middle:
+            x = self.nin_shortcut.forward(x)
+        return x + h
+
+
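+# Single-head self-attention over all spatial positions at 512 channels;
+# q, k, v and the output projection are all 1x1 convolutions.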
+class AttentionBlock(nn.Module):
+    def __init__(self):
+        super().__init__()
+        n = 2 ** 9
+        self.norm = nn.GroupNorm(2 ** 5, n)
+        self.q = nn.Conv2d(n, n, 1)
+        self.k = nn.Conv2d(n, n, 1)
+        self.v = nn.Conv2d(n, n, 1)
+        self.proj_out = nn.Conv2d(n, n, 1)
+
+    def forward(self, x: FloatTensor) -> FloatTensor:
+        n, m = 2 ** 9, x.shape[0]
+        h = x
+        h = self.norm.forward(h)
+        k = self.k.forward(h)
+        v = self.v.forward(h)
+        q = self.q.forward(h)
+        # flatten the spatial grid so each of the H*W positions is one token
+        k = k.reshape(m, n, -1)
+        v = v.reshape(m, n, -1)
+        q = q.reshape(m, n, -1)
+        q = q.permute(0, 2, 1)
+        # scaled dot-product attention across positions
+        w = torch.bmm(q, k)
+        w /= n ** 0.5
+        w = torch.softmax(w, dim=2)
+        w = w.permute(0, 2, 1)
+        h = torch.bmm(v, w)
+        # restore the square spatial layout
+        token_count = int(sqrt(h.shape[-1]))
+        h = h.reshape(m, n, token_count, token_count)
+        h = self.proj_out.forward(h)
+        return x + h
+
+
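+# Bottleneck of the decoder: resnet -> attention -> resnet at 512 channels.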
+class MiddleLayer(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.block_1 = ResnetBlock(9, 9)
+        self.attn_1 = AttentionBlock()
+        self.block_2 = ResnetBlock(9, 9)
+
+    def forward(self, h: FloatTensor) -> FloatTensor:
+        h = self.block_1.forward(h)
+        h = self.attn_1.forward(h)
+        h = self.block_2.forward(h)
+        return h
+
+
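+# Doubles the spatial resolution: nearest-neighbor 2x upsample followed by a
+# 3x3 conv. The input is cast to float32 first, presumably because nearest
+# upsampling lacks half-precision support on some backends.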
+class Upsample(nn.Module):
+    def __init__(self, log2_count):
+        super().__init__()
+        n = 2 ** log2_count
+        self.upsample = nn.UpsamplingNearest2d(scale_factor=2)
+        self.conv = nn.Conv2d(n, n, 3, padding=1)
+
+    def forward(self, x: FloatTensor) -> FloatTensor:
+        x = self.upsample.forward(x.to(torch.float32))
+        x = self.conv.forward(x)
+        return x
+
+
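+# One decoder stage: three resnet blocks, each optionally followed by
+# attention, then an optional 2x upsample.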
+class UpsampleBlock(nn.Module):
+    def __init__(
+        self,
+        log2_count_in: int,
+        log2_count_out: int,
+        has_attention: bool,
+        has_upsample: bool
+    ):
+        super().__init__()
+        self.has_attention = has_attention
+        self.has_upsample = has_upsample
+
+        self.block = nn.ModuleList([
+            ResnetBlock(log2_count_in, log2_count_out),
+            ResnetBlock(log2_count_out, log2_count_out),
+            ResnetBlock(log2_count_out, log2_count_out)
+        ])
+
+        if has_attention:
+            self.attn = nn.ModuleList([
+                AttentionBlock(),
+                AttentionBlock(),
+                AttentionBlock()
+            ])
+
+        if has_upsample:
+            self.upsample = Upsample(log2_count_out)
+
+    def forward(self, h: FloatTensor) -> FloatTensor:
+        for j in range(3):
+            h = self.block[j].forward(h)
+            if self.has_attention:
+                h = self.attn[j].forward(h)
+        if self.has_upsample:
+            h = self.upsample.forward(h)
+        return h
+
+
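+# Full VQGAN decoder: 256-channel latents -> 512 channels via conv_in, a
+# middle resnet/attention/resnet stack, five upsample stages applied from
+# index 4 down to 0 (four of which double the resolution, growing 16x16 token
+# maps to 256x256 pixels), and a final norm/swish/conv down to 3 RGB channels.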
+class Decoder(nn.Module):
+    def __init__(self):
+        super().__init__()
+
+        self.conv_in = nn.Conv2d(2 ** 8, 2 ** 9, 3, padding=1)
+        self.mid = MiddleLayer()
+
+        self.up = nn.ModuleList([
+            UpsampleBlock(7, 7, False, False),
+            UpsampleBlock(8, 7, False, True),
+            UpsampleBlock(8, 8, False, True),
+            UpsampleBlock(9, 8, False, True),
+            UpsampleBlock(9, 9, True, True)
+        ])
+
+        self.norm_out = nn.GroupNorm(2 ** 5, 2 ** 7)
+        self.conv_out = nn.Conv2d(2 ** 7, 3, 3, padding=1)
+
+    def forward(self, z: FloatTensor) -> FloatTensor:
+        z = self.conv_in.forward(z)
+        z = self.mid.forward(z)
+
+        for i in reversed(range(5)):
+            z = self.up[i].forward(z)
+
+        z = self.norm_out.forward(z)
+        z *= torch.sigmoid(z)
+        z = self.conv_out.forward(z)
+        return z
+
+
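+# Top-level detokenizer: embeds token indices (vocab 2**14 = 16384, 256-dim
+# embeddings), runs the VQGAN decoder, and returns pixel values in [0, 255].
+# Each image contributes one row of 2**4 * 2**4 = 256 tokens, so a call might
+# look like (hypothetical usage): VQGanDetokenizer().forward(False, tokens)
+# with tokens of shape [image_count, 256].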
+class VQGanDetokenizer(nn.Module):
+    def __init__(self):
+        super().__init__()
+        vocab_count, embed_count = 2 ** 14, 2 ** 8
+        self.vocab_count = vocab_count
+        self.embedding = nn.Embedding(vocab_count, embed_count)
+        self.post_quant_conv = nn.Conv2d(embed_count, embed_count, 1)
+        self.decoder = Decoder()
+
+    def forward(self, is_seamless: bool, z: LongTensor) -> FloatTensor:
+        z.clamp_(0, self.vocab_count - 1)
+        grid_size = int(sqrt(z.shape[0]))
+        token_count = grid_size * 2 ** 4
+
+        if is_seamless:
+            # merge the token grids of all images into one large token map and
+            # decode it as a single image, so the tiles blend into each other
+            z = z.view([grid_size, grid_size, 2 ** 4, 2 ** 4])
+            z = z.flatten(1, 2).transpose(1, 0).flatten(1, 2)
+            z = z.flatten().unsqueeze(1)
+            z = self.embedding.forward(z)
+            z = z.view((1, token_count, token_count, 2 ** 8))
+        else:
+            # decode each image's 16x16 token grid separately
+            z = self.embedding.forward(z)
+            z = z.view((z.shape[0], 2 ** 4, 2 ** 4, 2 ** 8))
+
+        z = z.permute(0, 3, 1, 2).contiguous()
+        z = self.post_quant_conv.forward(z)
+        z = self.decoder.forward(z)
+        z = z.permute(0, 2, 3, 1)
+        z = z.clip(0.0, 1.0) * 255
+
+        if is_seamless:
+            z = z[0]
+        else:
+            # stitch the decoded 256x256 images into one grid-shaped sheet
+            z = z.view([grid_size, grid_size, 2 ** 8, 2 ** 8, 3])
+            z = z.flatten(1, 2).transpose(1, 0).flatten(1, 2)
+
+        return z