Esmail-AGumaan committed on
Commit f8b823c
1 Parent(s): 898fdaa

Update diffusion.py

Files changed (1)
  1. diffusion.py +212 -212
diffusion.py CHANGED
@@ -1,213 +1,213 @@
- import torch
- from torch import nn
- from torch.nn import functional as F
- from nanograd.models.stable_diffusion.attention import SelfAttention, CrossAttention
-
- class TimeEmbedding(nn.Module):
-     def __init__(self, n_embd):
-         super().__init__()
-         self.linear_1 = nn.Linear(n_embd, 4 * n_embd)
-         self.linear_2 = nn.Linear(4 * n_embd, 4 * n_embd)
-
-     def forward(self, x):
-         x = self.linear_1(x)
-         x = F.silu(x)
-         x = self.linear_2(x)
-
-         return x
-
- class UNET_ResidualBlock(nn.Module):
-     def __init__(self, in_channels, out_channels, n_time=1280):
-         super().__init__()
-         self.groupnorm_feature = nn.GroupNorm(32, in_channels)
-         self.conv_feature = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
-         self.linear_time = nn.Linear(n_time, out_channels)
-
-         self.groupnorm_merged = nn.GroupNorm(32, out_channels)
-         self.conv_merged = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
-
-         if in_channels == out_channels:
-             self.residual_layer = nn.Identity()
-         else:
-             self.residual_layer = nn.Conv2d(in_channels, out_channels, kernel_size=1, padding=0)
-
-     def forward(self, feature, time):
-         residue = feature
-
-         feature = self.groupnorm_feature(feature)
-         feature = F.silu(feature)
-         feature = self.conv_feature(feature)
-
-         time = F.silu(time)
-
-         time = self.linear_time(time)
-         merged = feature + time.unsqueeze(-1).unsqueeze(-1)
-         merged = self.groupnorm_merged(merged)
-         merged = F.silu(merged)
-         merged = self.conv_merged(merged)
-
-         return merged + self.residual_layer(residue)
-
- class UNET_AttentionBlock(nn.Module):
-     def __init__(self, n_head: int, n_embd: int, d_context=768):
-         super().__init__()
-         channels = n_head * n_embd
-
-         self.groupnorm = nn.GroupNorm(32, channels, eps=1e-6)
-         self.conv_input = nn.Conv2d(channels, channels, kernel_size=1, padding=0)
-
-         self.layernorm_1 = nn.LayerNorm(channels)
-         self.attention_1 = SelfAttention(n_head, channels, in_proj_bias=False)
-         self.layernorm_2 = nn.LayerNorm(channels)
-         self.attention_2 = CrossAttention(n_head, channels, d_context, in_proj_bias=False)
-         self.layernorm_3 = nn.LayerNorm(channels)
-         self.linear_geglu_1 = nn.Linear(channels, 4 * channels * 2)
-         self.linear_geglu_2 = nn.Linear(4 * channels, channels)
-
-         self.conv_output = nn.Conv2d(channels, channels, kernel_size=1, padding=0)
-
-     def forward(self, x, context):
-         residue_long = x
-
-         x = self.groupnorm(x)
-         x = self.conv_input(x)
-
-         n, c, h, w = x.shape
-         x = x.view((n, c, h * w))
-
-         x = x.transpose(-1, -2)
-
-         residue_short = x
-
-         x = self.layernorm_1(x)
-         x = self.attention_1(x)
-         x += residue_short
-
-         residue_short = x
-
-         x = self.layernorm_2(x)
-         x = self.attention_2(x, context)
-
-         x += residue_short
-
-         residue_short = x
-
-         x = self.layernorm_3(x)
-
-         # GeGLU as implemented in the original code: https://github.com/CompVis/stable-diffusion/blob/21f890f9da3cfbeaba8e2ac3c425ee9e998d5229/ldm/modules/attention.py#L37C10-L37C10
-         x, gate = self.linear_geglu_1(x).chunk(2, dim=-1)
-         x = x * F.gelu(gate)
-         x = self.linear_geglu_2(x)
-         x += residue_short
-         x = x.transpose(-1, -2)
-
-         x = x.view((n, c, h, w))
-
-         return self.conv_output(x) + residue_long
-
- class Upsample(nn.Module):
-     def __init__(self, channels):
-         super().__init__()
-         self.conv = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
-
-     def forward(self, x):
-         x = F.interpolate(x, scale_factor=2, mode='nearest')
-         return self.conv(x)
-
- class SwitchSequential(nn.Sequential):
-     def forward(self, x, context, time):
-         for layer in self:
-             if isinstance(layer, UNET_AttentionBlock):
-                 x = layer(x, context)
-             elif isinstance(layer, UNET_ResidualBlock):
-                 x = layer(x, time)
-             else:
-                 x = layer(x)
-         return x
-
- class UNET(nn.Module):
-     def __init__(self):
-         super().__init__()
-         self.encoders = nn.ModuleList([
-             SwitchSequential(nn.Conv2d(4, 320, kernel_size=3, padding=1)),
-             SwitchSequential(UNET_ResidualBlock(320, 320), UNET_AttentionBlock(8, 40)),
-             SwitchSequential(UNET_ResidualBlock(320, 320), UNET_AttentionBlock(8, 40)),
-             SwitchSequential(nn.Conv2d(320, 320, kernel_size=3, stride=2, padding=1)),
-             SwitchSequential(UNET_ResidualBlock(320, 640), UNET_AttentionBlock(8, 80)),
-             SwitchSequential(UNET_ResidualBlock(640, 640), UNET_AttentionBlock(8, 80)),
-             SwitchSequential(nn.Conv2d(640, 640, kernel_size=3, stride=2, padding=1)),
-             SwitchSequential(UNET_ResidualBlock(640, 1280), UNET_AttentionBlock(8, 160)),
-             SwitchSequential(UNET_ResidualBlock(1280, 1280), UNET_AttentionBlock(8, 160)),
-             SwitchSequential(nn.Conv2d(1280, 1280, kernel_size=3, stride=2, padding=1)),
-             SwitchSequential(UNET_ResidualBlock(1280, 1280)),
-             SwitchSequential(UNET_ResidualBlock(1280, 1280)),
-         ])
-
-         self.bottleneck = SwitchSequential(
-             UNET_ResidualBlock(1280, 1280),
-             UNET_AttentionBlock(8, 160),
-             UNET_ResidualBlock(1280, 1280),
-         )
-
-         self.decoders = nn.ModuleList([
-             SwitchSequential(UNET_ResidualBlock(2560, 1280)),
-             SwitchSequential(UNET_ResidualBlock(2560, 1280)),
-             SwitchSequential(UNET_ResidualBlock(2560, 1280), Upsample(1280)),
-             SwitchSequential(UNET_ResidualBlock(2560, 1280), UNET_AttentionBlock(8, 160)),
-
-             SwitchSequential(UNET_ResidualBlock(2560, 1280), UNET_AttentionBlock(8, 160)),
-             SwitchSequential(UNET_ResidualBlock(1920, 1280), UNET_AttentionBlock(8, 160), Upsample(1280)),
-
-             SwitchSequential(UNET_ResidualBlock(1920, 640), UNET_AttentionBlock(8, 80)),
-             SwitchSequential(UNET_ResidualBlock(1280, 640), UNET_AttentionBlock(8, 80)),
-
-             SwitchSequential(UNET_ResidualBlock(960, 640), UNET_AttentionBlock(8, 80), Upsample(640)),
-             SwitchSequential(UNET_ResidualBlock(960, 320), UNET_AttentionBlock(8, 40)),
-             SwitchSequential(UNET_ResidualBlock(640, 320), UNET_AttentionBlock(8, 40)),
-
-             SwitchSequential(UNET_ResidualBlock(640, 320), UNET_AttentionBlock(8, 40)),
-         ])
-
-     def forward(self, x, context, time):
-         skip_connections = []
-         for layers in self.encoders:
-             x = layers(x, context, time)
-             skip_connections.append(x)
-
-         x = self.bottleneck(x, context, time)
-
-         for layers in self.decoders:
-             x = torch.cat((x, skip_connections.pop()), dim=1)
-             x = layers(x, context, time)
-
-         return x
-
-
- class UNET_OutputLayer(nn.Module):
-     def __init__(self, in_channels, out_channels):
-         super().__init__()
-         self.groupnorm = nn.GroupNorm(32, in_channels)
-         self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
-
-     def forward(self, x):
-         x = self.groupnorm(x)
-         x = F.silu(x)
-         x = self.conv(x)
-
-         return x
-
- class Diffusion(nn.Module):
-     def __init__(self):
-         super().__init__()
-         self.time_embedding = TimeEmbedding(320)
-         self.unet = UNET()
-         self.final = UNET_OutputLayer(320, 4)
-
-     def forward(self, latent, context, time):
-         time = self.time_embedding(time)
-
-         output = self.unet(latent, context, time)
-
-         output = self.final(output)
-
+ import torch
+ from torch import nn
+ from torch.nn import functional as F
+ from attention import SelfAttention, CrossAttention
+
+ class TimeEmbedding(nn.Module):
+     def __init__(self, n_embd):
+         super().__init__()
+         self.linear_1 = nn.Linear(n_embd, 4 * n_embd)
+         self.linear_2 = nn.Linear(4 * n_embd, 4 * n_embd)
+
+     def forward(self, x):
+         x = self.linear_1(x)
+         x = F.silu(x)
+         x = self.linear_2(x)
+
+         return x
+
+ class UNET_ResidualBlock(nn.Module):
+     def __init__(self, in_channels, out_channels, n_time=1280):
+         super().__init__()
+         self.groupnorm_feature = nn.GroupNorm(32, in_channels)
+         self.conv_feature = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
+         self.linear_time = nn.Linear(n_time, out_channels)
+
+         self.groupnorm_merged = nn.GroupNorm(32, out_channels)
+         self.conv_merged = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
+
+         if in_channels == out_channels:
+             self.residual_layer = nn.Identity()
+         else:
+             self.residual_layer = nn.Conv2d(in_channels, out_channels, kernel_size=1, padding=0)
+
+     def forward(self, feature, time):
+         residue = feature
+
+         feature = self.groupnorm_feature(feature)
+         feature = F.silu(feature)
+         feature = self.conv_feature(feature)
+
+         time = F.silu(time)
+
+         time = self.linear_time(time)
+         merged = feature + time.unsqueeze(-1).unsqueeze(-1)
+         merged = self.groupnorm_merged(merged)
+         merged = F.silu(merged)
+         merged = self.conv_merged(merged)
+
+         return merged + self.residual_layer(residue)
+
+ class UNET_AttentionBlock(nn.Module):
+     def __init__(self, n_head: int, n_embd: int, d_context=768):
+         super().__init__()
+         channels = n_head * n_embd
+
+         self.groupnorm = nn.GroupNorm(32, channels, eps=1e-6)
+         self.conv_input = nn.Conv2d(channels, channels, kernel_size=1, padding=0)
+
+         self.layernorm_1 = nn.LayerNorm(channels)
+         self.attention_1 = SelfAttention(n_head, channels, in_proj_bias=False)
+         self.layernorm_2 = nn.LayerNorm(channels)
+         self.attention_2 = CrossAttention(n_head, channels, d_context, in_proj_bias=False)
+         self.layernorm_3 = nn.LayerNorm(channels)
+         self.linear_geglu_1 = nn.Linear(channels, 4 * channels * 2)
+         self.linear_geglu_2 = nn.Linear(4 * channels, channels)
+
+         self.conv_output = nn.Conv2d(channels, channels, kernel_size=1, padding=0)
+
+     def forward(self, x, context):
+         residue_long = x
+
+         x = self.groupnorm(x)
+         x = self.conv_input(x)
+
+         n, c, h, w = x.shape
+         x = x.view((n, c, h * w))
+
+         x = x.transpose(-1, -2)
+
+         residue_short = x
+
+         x = self.layernorm_1(x)
+         x = self.attention_1(x)
+         x += residue_short
+
+         residue_short = x
+
+         x = self.layernorm_2(x)
+         x = self.attention_2(x, context)
+
+         x += residue_short
+
+         residue_short = x
+
+         x = self.layernorm_3(x)
+
+         # GeGLU as implemented in the original code: https://github.com/CompVis/stable-diffusion/blob/21f890f9da3cfbeaba8e2ac3c425ee9e998d5229/ldm/modules/attention.py#L37C10-L37C10
+         x, gate = self.linear_geglu_1(x).chunk(2, dim=-1)
+         x = x * F.gelu(gate)
+         x = self.linear_geglu_2(x)
+         x += residue_short
+         x = x.transpose(-1, -2)
+
+         x = x.view((n, c, h, w))
+
+         return self.conv_output(x) + residue_long
+
+ class Upsample(nn.Module):
+     def __init__(self, channels):
+         super().__init__()
+         self.conv = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
+
+     def forward(self, x):
+         x = F.interpolate(x, scale_factor=2, mode='nearest')
+         return self.conv(x)
+
+ class SwitchSequential(nn.Sequential):
+     def forward(self, x, context, time):
+         for layer in self:
+             if isinstance(layer, UNET_AttentionBlock):
+                 x = layer(x, context)
+             elif isinstance(layer, UNET_ResidualBlock):
+                 x = layer(x, time)
+             else:
+                 x = layer(x)
+         return x
+
+ class UNET(nn.Module):
+     def __init__(self):
+         super().__init__()
+         self.encoders = nn.ModuleList([
+             SwitchSequential(nn.Conv2d(4, 320, kernel_size=3, padding=1)),
+             SwitchSequential(UNET_ResidualBlock(320, 320), UNET_AttentionBlock(8, 40)),
+             SwitchSequential(UNET_ResidualBlock(320, 320), UNET_AttentionBlock(8, 40)),
+             SwitchSequential(nn.Conv2d(320, 320, kernel_size=3, stride=2, padding=1)),
+             SwitchSequential(UNET_ResidualBlock(320, 640), UNET_AttentionBlock(8, 80)),
+             SwitchSequential(UNET_ResidualBlock(640, 640), UNET_AttentionBlock(8, 80)),
+             SwitchSequential(nn.Conv2d(640, 640, kernel_size=3, stride=2, padding=1)),
+             SwitchSequential(UNET_ResidualBlock(640, 1280), UNET_AttentionBlock(8, 160)),
+             SwitchSequential(UNET_ResidualBlock(1280, 1280), UNET_AttentionBlock(8, 160)),
+             SwitchSequential(nn.Conv2d(1280, 1280, kernel_size=3, stride=2, padding=1)),
+             SwitchSequential(UNET_ResidualBlock(1280, 1280)),
+             SwitchSequential(UNET_ResidualBlock(1280, 1280)),
+         ])
+
+         self.bottleneck = SwitchSequential(
+             UNET_ResidualBlock(1280, 1280),
+             UNET_AttentionBlock(8, 160),
+             UNET_ResidualBlock(1280, 1280),
+         )
+
+         self.decoders = nn.ModuleList([
+             SwitchSequential(UNET_ResidualBlock(2560, 1280)),
+             SwitchSequential(UNET_ResidualBlock(2560, 1280)),
+             SwitchSequential(UNET_ResidualBlock(2560, 1280), Upsample(1280)),
+             SwitchSequential(UNET_ResidualBlock(2560, 1280), UNET_AttentionBlock(8, 160)),
+
+             SwitchSequential(UNET_ResidualBlock(2560, 1280), UNET_AttentionBlock(8, 160)),
+             SwitchSequential(UNET_ResidualBlock(1920, 1280), UNET_AttentionBlock(8, 160), Upsample(1280)),
+
+             SwitchSequential(UNET_ResidualBlock(1920, 640), UNET_AttentionBlock(8, 80)),
+             SwitchSequential(UNET_ResidualBlock(1280, 640), UNET_AttentionBlock(8, 80)),
+
+             SwitchSequential(UNET_ResidualBlock(960, 640), UNET_AttentionBlock(8, 80), Upsample(640)),
+             SwitchSequential(UNET_ResidualBlock(960, 320), UNET_AttentionBlock(8, 40)),
+             SwitchSequential(UNET_ResidualBlock(640, 320), UNET_AttentionBlock(8, 40)),
+
+             SwitchSequential(UNET_ResidualBlock(640, 320), UNET_AttentionBlock(8, 40)),
+         ])
+
+     def forward(self, x, context, time):
+         skip_connections = []
+         for layers in self.encoders:
+             x = layers(x, context, time)
+             skip_connections.append(x)
+
+         x = self.bottleneck(x, context, time)
+
+         for layers in self.decoders:
+             x = torch.cat((x, skip_connections.pop()), dim=1)
+             x = layers(x, context, time)
+
+         return x
+
+
+ class UNET_OutputLayer(nn.Module):
+     def __init__(self, in_channels, out_channels):
+         super().__init__()
+         self.groupnorm = nn.GroupNorm(32, in_channels)
+         self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
+
+     def forward(self, x):
+         x = self.groupnorm(x)
+         x = F.silu(x)
+         x = self.conv(x)
+
+         return x
+
+ class Diffusion(nn.Module):
+     def __init__(self):
+         super().__init__()
+         self.time_embedding = TimeEmbedding(320)
+         self.unet = UNET()
+         self.final = UNET_OutputLayer(320, 4)
+
+     def forward(self, latent, context, time):
+         time = self.time_embedding(time)
+
+         output = self.unet(latent, context, time)
+
+         output = self.final(output)
+
          return output
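
For reference, a minimal smoke-test sketch of how the Diffusion module in this file can be driven. The shapes are inferred from the code itself (4 latent channels, d_context=768, TimeEmbedding(320), three stride-2 downsamples, so the latent's height and width should be divisible by 8); the get_time_embedding helper below is a hypothetical stand-in for the sinusoidal timestep embedding produced elsewhere in the pipeline and is not part of this commit.

import torch

from diffusion import Diffusion  # assumes this file is importable as diffusion.py

# Hypothetical sinusoidal timestep embedding, sized to match TimeEmbedding(320).
def get_time_embedding(timestep: int, dim: int = 320) -> torch.Tensor:
    half = dim // 2
    freqs = torch.pow(10000, -torch.arange(half, dtype=torch.float32) / half)
    args = torch.tensor([timestep], dtype=torch.float32)[:, None] * freqs[None, :]
    return torch.cat([torch.cos(args), torch.sin(args)], dim=-1)  # (1, 320)

model = Diffusion().eval()

latent = torch.randn(1, 4, 64, 64)      # e.g. a 512x512 image encoded to a 64x64 latent (factor 8)
context = torch.randn(1, 77, 768)       # e.g. a text-encoder output with d_context=768
time = get_time_embedding(500)          # (1, 320); expanded to (1, 1280) inside TimeEmbedding

with torch.no_grad():
    out = model(latent, context, time)

print(out.shape)  # torch.Size([1, 4, 64, 64]) -- prediction in latent space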