jadechoghari committed
Commit 4de0c6c
1 parent: b1aa766

Create attention.py

Files changed (1)
unet/attention.py +331 −0
unet/attention.py ADDED
@@ -0,0 +1,331 @@
from inspect import isfunction
import math
from typing import Any, Optional

import torch
import torch.nn.functional as F
from torch import nn, einsum
from einops import rearrange, repeat


def checkpoint(func, inputs, params, flag):
    """
    Evaluate a function without caching intermediate activations, allowing for
    reduced memory at the expense of extra compute in the backward pass.
    :param func: the function to evaluate.
    :param inputs: the argument sequence to pass to `func`.
    :param params: a sequence of parameters `func` depends on but does not
        explicitly take as arguments.
    :param flag: if False, disable gradient checkpointing.
    """
    # Checkpointing is disabled here so the main model can run with requires_grad=False;
    # the CheckpointFunction branch below is never taken (CheckpointFunction is not defined in this file).
    if False:
        args = tuple(inputs) + tuple(params)
        return CheckpointFunction.apply(func, len(inputs), *args)
    else:
        return func(*inputs)


try:
    import xformers
    import xformers.ops
    XFORMERS_IS_AVAILBLE = True
except ImportError:
    XFORMERS_IS_AVAILBLE = False


def exists(val):
    return val is not None


def uniq(arr):
    return {el: True for el in arr}.keys()


def default(val, d):
    if exists(val):
        return val
    return d() if isfunction(d) else d


def max_neg_value(t):
    return -torch.finfo(t.dtype).max


def init_(tensor):
    dim = tensor.shape[-1]
    std = 1 / math.sqrt(dim)
    tensor.uniform_(-std, std)
    return tensor


# feedforward
class GEGLU(nn.Module):
    def __init__(self, dim_in, dim_out):
        super().__init__()
        self.proj = nn.Linear(dim_in, dim_out * 2)

    def forward(self, x):
        x, gate = self.proj(x).chunk(2, dim=-1)
        return x * F.gelu(gate)


class FeedForward(nn.Module):
    def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
        super().__init__()
        inner_dim = int(dim * mult)
        dim_out = default(dim_out, dim)
        project_in = nn.Sequential(
            nn.Linear(dim, inner_dim),
            nn.GELU()
        ) if not glu else GEGLU(dim, inner_dim)

        self.net = nn.Sequential(
            project_in,
            nn.Dropout(dropout),
            nn.Linear(inner_dim, dim_out)
        )

    def forward(self, x):
        return self.net(x)


def zero_module(module):
    """
    Zero out the parameters of a module and return it.
    """
    for p in module.parameters():
        p.detach().zero_()
    return module


def Normalize(in_channels):
    return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)


class LinearAttention(nn.Module):
    def __init__(self, dim, heads=4, dim_head=32):
        super().__init__()
        self.heads = heads
        hidden_dim = dim_head * heads
        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
        self.to_out = nn.Conv2d(hidden_dim, dim, 1)

    def forward(self, x):
        b, c, h, w = x.shape
        qkv = self.to_qkv(x)
        q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads=self.heads, qkv=3)
        k = k.softmax(dim=-1)
        context = torch.einsum('bhdn,bhen->bhde', k, v)
        out = torch.einsum('bhde,bhdn->bhen', context, q)
        out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w)
        return self.to_out(out)


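# Shape note on LinearAttention above (d = dim_head, n = h * w; sizes follow directly from the code):
#   q, k, v come out of the rearrange as (b, heads, d, n); k is softmaxed over n.
#   context = einsum('bhdn,bhen->bhde', k, v)       -> (b, heads, d, d), independent of n
#   out     = einsum('bhde,bhdn->bhen', context, q) -> (b, heads, d, n)
# so the cost grows linearly with the number of spatial positions n rather than quadratically.

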
class SpatialSelfAttention(nn.Module):
    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = in_channels

        self.norm = Normalize(in_channels)
        self.q = torch.nn.Conv2d(in_channels,
                                 in_channels,
                                 kernel_size=1,
                                 stride=1,
                                 padding=0)
        self.k = torch.nn.Conv2d(in_channels,
                                 in_channels,
                                 kernel_size=1,
                                 stride=1,
                                 padding=0)
        self.v = torch.nn.Conv2d(in_channels,
                                 in_channels,
                                 kernel_size=1,
                                 stride=1,
                                 padding=0)
        self.proj_out = torch.nn.Conv2d(in_channels,
                                        in_channels,
                                        kernel_size=1,
                                        stride=1,
                                        padding=0)

    def forward(self, x):
        h_ = x
        h_ = self.norm(h_)
        q = self.q(h_)
        k = self.k(h_)
        v = self.v(h_)

        # compute attention
        b, c, h, w = q.shape
        q = rearrange(q, 'b c h w -> b (h w) c')
        k = rearrange(k, 'b c h w -> b c (h w)')
        w_ = torch.einsum('bij,bjk->bik', q, k)

        w_ = w_ * (int(c) ** (-0.5))
        w_ = torch.nn.functional.softmax(w_, dim=2)

        # attend to values
        v = rearrange(v, 'b c h w -> b c (h w)')
        w_ = rearrange(w_, 'b i j -> b j i')
        h_ = torch.einsum('bij,bjk->bik', v, w_)
        h_ = rearrange(h_, 'b c (h w) -> b c h w', h=h)
        h_ = self.proj_out(h_)

        return x + h_


class CrossAttention(nn.Module):
    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.):
        super().__init__()
        inner_dim = dim_head * heads
        context_dim = default(context_dim, query_dim)

        self.scale = dim_head ** -0.5
        self.heads = heads

        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
        self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
        self.to_v = nn.Linear(context_dim, inner_dim, bias=False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, query_dim),
            nn.Dropout(dropout)
        )

    def forward(self, x, context=None, mask=None):
        h = self.heads

        q = self.to_q(x)
        context = default(context, x)
        k = self.to_k(context)
        v = self.to_v(context)

        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))

        sim = einsum('b i d, b j d -> b i j', q, k) * self.scale

        if exists(mask):
            mask = rearrange(mask, 'b ... -> b (...)')
            max_neg_value = -torch.finfo(sim.dtype).max
            mask = repeat(mask, 'b j -> (b h) () j', h=h)
            sim.masked_fill_(~mask, max_neg_value)

        # attention, what we cannot get enough of
        attn = sim.softmax(dim=-1)

        out = einsum('b i j, b j d -> b i d', attn, v)
        out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
        return self.to_out(out)


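# Shape walkthrough for CrossAttention above (tensor sizes are illustrative, not fixed by this file):
#   x:       (B, N, query_dim)   -> q: (B*heads, N, dim_head)
#   context: (B, M, context_dim) -> k, v: (B*heads, M, dim_head)   (context defaults to x, i.e. self-attention)
#   sim = q @ k^T * scale        -> (B*heads, N, M)
#   out = softmax(sim) @ v       -> (B*heads, N, dim_head) -> (B, N, heads*dim_head) -> to_out -> (B, N, query_dim)

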
class BasicTransformerBlock(nn.Module):
    def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True):
        super().__init__()
        self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout)  # a self-attention layer
        self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
        self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim,
                                    heads=n_heads, dim_head=d_head, dropout=dropout)  # self-attention if context is None
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        self.norm3 = nn.LayerNorm(dim)
        self.checkpoint = checkpoint

    def forward(self, x, context=None):
        return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint)

    def _forward(self, x, context=None):
        x = self.attn1(self.norm1(x)) + x
        x = self.attn2(self.norm2(x), context=context) + x
        x = self.ff(self.norm3(x)) + x
        return x


class SpatialTransformer(nn.Module):
    """
    Transformer block for image-like data.
    First, project the input (i.e. the embedding) and reshape to (b, t, d).
    Then apply standard transformer blocks.
    Finally, reshape back to an image.
    """
    def __init__(self, in_channels, n_heads, d_head,
                 depth=1, dropout=0., context_dim=None):
        super().__init__()
        self.in_channels = in_channels
        inner_dim = n_heads * d_head
        self.norm = Normalize(in_channels)

        self.proj_in = nn.Conv2d(in_channels,
                                 inner_dim,
                                 kernel_size=1,
                                 stride=1,
                                 padding=0)

        self.transformer_blocks = nn.ModuleList(
            [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim)
             for _ in range(depth)]
        )

        self.proj_out = zero_module(nn.Conv2d(inner_dim,
                                              in_channels,
                                              kernel_size=1,
                                              stride=1,
                                              padding=0))

    def forward(self, x, context=None):
        # note: if no context is given, cross-attention defaults to self-attention
        b, c, h, w = x.shape
        x_in = x
        x = self.norm(x)
        x = self.proj_in(x)
        x = rearrange(x, 'b c h w -> b (h w) c')
        for block in self.transformer_blocks:
            x = block(x, context=context)
        x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w)
        x = self.proj_out(x)
        return x + x_in


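# Note on SpatialTransformer above: proj_out is wrapped in zero_module, so the block initially outputs
# zeros and the residual `x + x_in` starts as an identity mapping; the transformer's contribution is
# then learned gradually during training.

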
class MemoryEfficientCrossAttention(nn.Module):
    # https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
        super().__init__()
        print(f"Setting up {self.__class__.__name__}. Query dim is {query_dim}, context_dim is {context_dim} and using "
              f"{heads} heads.")
        inner_dim = dim_head * heads
        context_dim = default(context_dim, query_dim)

        self.heads = heads
        self.dim_head = dim_head

        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
        self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
        self.to_v = nn.Linear(context_dim, inner_dim, bias=False)

        self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))
        self.attention_op: Optional[Any] = None

    def forward(self, x, context=None, mask=None):
        q = self.to_q(x)
        context = default(context, x)
        k = self.to_k(context)
        v = self.to_v(context)

        b, _, _ = q.shape
        q, k, v = map(
            lambda t: t.unsqueeze(3)
            .reshape(b, t.shape[1], self.heads, self.dim_head)
            .permute(0, 2, 1, 3)
            .reshape(b * self.heads, t.shape[1], self.dim_head)
            .contiguous(),
            (q, k, v),
        )

        # actually compute the attention, what we cannot get enough of
        out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op)

        if exists(mask):
            raise NotImplementedError
        out = (
            out.unsqueeze(0)
            .reshape(b, self.heads, out.shape[1], self.dim_head)
            .permute(0, 2, 1, 3)
            .reshape(b, out.shape[1], self.heads * self.dim_head)
        )
        return self.to_out(out)
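For orientation, a minimal usage sketch of the SpatialTransformer defined in this file (not part of the commit; the import path, tensor sizes, and conditioning width are illustrative assumptions):

import torch
from unet.attention import SpatialTransformer  # assumes the repo root is on the Python path

# 320-channel feature map, 8 heads of size 40 (inner_dim = 320), cross-attending to 768-dim tokens
block = SpatialTransformer(in_channels=320, n_heads=8, d_head=40, depth=1, context_dim=768)
feats = torch.randn(1, 320, 64, 64)   # image-like features: (b, c, h, w)
cond = torch.randn(1, 77, 768)        # conditioning tokens: (b, t, context_dim)
out = block(feats, context=cond)      # residual output with the same shape: (1, 320, 64, 64)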