daoyuan98 committed
Commit ef7c6c6
1 Parent(s): 54c32af

Create layers.py

Files changed (1)
  1. layers.py  +631 -0
layers.py ADDED
@@ -0,0 +1,631 @@
import math
from dataclasses import dataclass

import torch
import torch.nn.functional as F
from einops import rearrange, repeat
from torch import Tensor, nn

def attention(q, k, v, pe):
    q, k = apply_rope(q, k, pe)

    x = torch.nn.functional.scaled_dot_product_attention(q, k, v)
    x = rearrange(x, "B H L D -> B L (H D)")

    return x


def rope(pos, dim: int, theta: int):
    assert dim % 2 == 0
    scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
    omega = 1.0 / (theta**scale)
    out = torch.einsum("...n,d->...nd", pos, omega)
    out = torch.stack([torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1)
    out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2)
    return out.float()


def apply_rope(xq, xk, freqs_cis):
    xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
    xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
    xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
    xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
    return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)

class EmbedND(nn.Module):
    def __init__(self, dim: int, theta: int, axes_dim: list[int]):
        super().__init__()
        self.dim = dim
        self.theta = theta
        self.axes_dim = axes_dim

    def forward(self, ids: Tensor):
        n_axes = ids.shape[-1]
        emb = torch.cat(
            [rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)],
            dim=-3,
        )
        return emb.unsqueeze(1)

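# Usage sketch (illustrative only): how EmbedND rotary embeddings feed `attention`. The sizes
# below (4 heads, head_dim=128, axes_dim=[16, 56, 56]) are assumptions chosen so that
# sum(axes_dim) == head_dim; they are not values defined in this file.
def _example_rope_attention():
    batch, heads, tokens, head_dim = 1, 4, 32, 128
    pos_embedder = EmbedND(dim=head_dim, theta=10_000, axes_dim=[16, 56, 56])
    ids = torch.zeros(batch, tokens, 3)              # one (possibly fractional) position id per axis and token
    pe = pos_embedder(ids)                           # (B, 1, L, head_dim // 2, 2, 2) rotation matrices
    q = torch.randn(batch, heads, tokens, head_dim)
    k = torch.randn(batch, heads, tokens, head_dim)
    v = torch.randn(batch, heads, tokens, head_dim)
    return attention(q, k, v, pe=pe)                 # (B, L, heads * head_dim)
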
def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 1000.0):
    """
    Create sinusoidal timestep embeddings.

    :param t: a 1-D Tensor of N indices, one per batch element. These may be fractional.
    :param dim: the dimension of the output.
    :param max_period: controls the minimum frequency of the embeddings.
    :param time_factor: scales `t` before the embedding is computed.
    :return: an (N, D) Tensor of positional embeddings.
    """
    t = time_factor * t
    half = dim // 2
    freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(
        t.device
    )

    args = t[:, None].float() * freqs[None]
    embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
    if dim % 2:
        embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
    if torch.is_floating_point(t):
        embedding = embedding.to(t)
    return embedding

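# Usage sketch (illustrative only): `timestep_embedding` turns scalar timesteps into sinusoidal
# features; downstream code typically passes the result through an MLPEmbedder (defined below).
# The embedding width of 256 is an assumption for the example.
def _example_timestep_embedding():
    t = torch.tensor([0.0, 0.25, 1.0])               # fractional timesteps, one per batch element
    return timestep_embedding(t, dim=256)            # (3, 256): concatenated cos/sin features
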
class MLPEmbedder(nn.Module):
    def __init__(self, in_dim: int, hidden_dim: int):
        super().__init__()
        self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True)
        self.silu = nn.SiLU()
        self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True)

    def forward(self, x: Tensor):
        return self.out_layer(self.silu(self.in_layer(x)))

class RMSNorm(torch.nn.Module):
    def __init__(self, dim: int):
        super().__init__()
        self.scale = nn.Parameter(torch.ones(dim))

    def forward(self, x: Tensor):
        x_dtype = x.dtype
        x = x.float()
        rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + 1e-6)
        return (x * rrms).to(dtype=x_dtype) * self.scale


class QKNorm(torch.nn.Module):
    def __init__(self, dim: int):
        super().__init__()
        self.query_norm = RMSNorm(dim)
        self.key_norm = RMSNorm(dim)

    def forward(self, q: Tensor, k: Tensor, v: Tensor):
        q = self.query_norm(q)
        k = self.key_norm(k)
        return q.to(v), k.to(v)

class LoRALinearLayer(nn.Module):
    def __init__(self, in_features, out_features, rank=4, network_alpha=None, device=None, dtype=None):
        super().__init__()

        self.down = nn.Linear(in_features, rank, bias=False, device=device, dtype=dtype)
        self.up = nn.Linear(rank, out_features, bias=False, device=device, dtype=dtype)
        # This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script.
        # See https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning
        self.network_alpha = network_alpha
        self.rank = rank

        nn.init.normal_(self.down.weight, std=1 / rank)
        nn.init.zeros_(self.up.weight)

    def forward(self, hidden_states):
        orig_dtype = hidden_states.dtype
        dtype = self.down.weight.dtype

        down_hidden_states = self.down(hidden_states.to(dtype))
        up_hidden_states = self.up(down_hidden_states)

        if self.network_alpha is not None:
            up_hidden_states *= self.network_alpha / self.rank

        return up_hidden_states.to(orig_dtype)

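# Usage sketch (illustrative only): LoRALinearLayer adds a low-rank residual on top of a frozen
# base linear layer, y = W x + (network_alpha / rank) * up(down(x)), which is the same pattern the
# *_lora processors below apply to `attn.qkv` / `attn.proj`. The 512 width is an assumption.
def _example_lora_linear():
    dim, lora_weight = 512, 1.0
    base = nn.Linear(dim, dim)
    lora = LoRALinearLayer(dim, dim, rank=4, network_alpha=4)
    x = torch.randn(2, 16, dim)
    return base(x) + lora(x) * lora_weight           # equals base(x) at init, since `up` starts at zeros
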
class FLuxSelfAttnProcessor:
    def __call__(self, attn, x, pe, **attention_kwargs):
        qkv = attn.qkv(x)
        q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads)
        q, k = attn.norm(q, k, v)
        x = attention(q, k, v, pe=pe)
        x = attn.proj(x)
        return x


class LoraFluxAttnProcessor(nn.Module):
    def __init__(self, dim: int, rank=4, network_alpha=None, lora_weight=1):
        super().__init__()
        self.qkv_lora = LoRALinearLayer(dim, dim * 3, rank, network_alpha)
        self.proj_lora = LoRALinearLayer(dim, dim, rank, network_alpha)
        self.lora_weight = lora_weight

    def __call__(self, attn, x, pe, **attention_kwargs):
        qkv = attn.qkv(x) + self.qkv_lora(x) * self.lora_weight
        q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads)
        q, k = attn.norm(q, k, v)
        x = attention(q, k, v, pe=pe)
        x = attn.proj(x) + self.proj_lora(x) * self.lora_weight
        return x

class SelfAttention(nn.Module):
    def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.norm = QKNorm(head_dim)
        self.proj = nn.Linear(dim, dim)

    def forward(self):
        # SelfAttention only holds the qkv/proj parameters; attention itself is run by the
        # owning block's attention processor.
        pass

@dataclass
class ModulationOut:
    shift: Tensor
    scale: Tensor
    gate: Tensor


class Modulation(nn.Module):
    def __init__(self, dim: int, double: bool):
        super().__init__()
        self.is_double = double
        self.multiplier = 6 if double else 3
        self.lin = nn.Linear(dim, self.multiplier * dim, bias=True)

    def forward(self, vec: Tensor):
        out = self.lin(nn.functional.silu(vec))[:, None, :].chunk(self.multiplier, dim=-1)

        return (
            ModulationOut(*out[:3]),
            ModulationOut(*out[3:]) if self.is_double else None,
        )

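# Usage sketch (illustrative only): Modulation maps a conditioning vector to shift/scale/gate
# triples that the processors below wrap around LayerNorm outputs (adaLN-style modulation).
# The 512 width is an assumption.
def _example_modulation():
    dim = 512
    mod1, mod2 = Modulation(dim, double=True)(torch.randn(2, dim))
    x = torch.randn(2, 16, dim)
    x_norm = nn.LayerNorm(dim, elementwise_affine=False)(x)
    return (1 + mod1.scale) * x_norm + mod1.shift    # same pattern as img_mod1 / txt_mod1 below
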
class DoubleStreamBlockLoraProcessor(nn.Module):
    def __init__(self, dim: int, rank=4, network_alpha=None, lora_weight=1):
        super().__init__()
        self.qkv_lora1 = LoRALinearLayer(dim, dim * 3, rank, network_alpha)
        self.proj_lora1 = LoRALinearLayer(dim, dim, rank, network_alpha)
        self.qkv_lora2 = LoRALinearLayer(dim, dim * 3, rank, network_alpha)
        self.proj_lora2 = LoRALinearLayer(dim, dim, rank, network_alpha)
        self.lora_weight = lora_weight

    def __call__(self, attn, img, txt, vec, pe):
        img_mod1, img_mod2 = attn.img_mod(vec)
        txt_mod1, txt_mod2 = attn.txt_mod(vec)

        # prepare image for attention
        img_modulated = attn.img_norm1(img)
        img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
        img_qkv = attn.img_attn.qkv(img_modulated) + self.qkv_lora1(img_modulated) * self.lora_weight
        img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads)
        img_q, img_k = attn.img_attn.norm(img_q, img_k, img_v)

        # prepare txt for attention
        txt_modulated = attn.txt_norm1(txt)
        txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
        txt_qkv = attn.txt_attn.qkv(txt_modulated) + self.qkv_lora2(txt_modulated) * self.lora_weight
        txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads)
        txt_q, txt_k = attn.txt_attn.norm(txt_q, txt_k, txt_v)

        # run actual attention
        q = torch.cat((txt_q, img_q), dim=2)
        k = torch.cat((txt_k, img_k), dim=2)
        v = torch.cat((txt_v, img_v), dim=2)

        attn1 = attention(q, k, v, pe=pe)
        txt_attn, img_attn = attn1[:, : txt.shape[1]], attn1[:, txt.shape[1] :]

        # calculate the img blocks
        img = img + img_mod1.gate * attn.img_attn.proj(img_attn) + img_mod1.gate * self.proj_lora1(img_attn) * self.lora_weight
        img = img + img_mod2.gate * attn.img_mlp((1 + img_mod2.scale) * attn.img_norm2(img) + img_mod2.shift)

        # calculate the txt blocks
        txt = txt + txt_mod1.gate * attn.txt_attn.proj(txt_attn) + txt_mod1.gate * self.proj_lora2(txt_attn) * self.lora_weight
        txt = txt + txt_mod2.gate * attn.txt_mlp((1 + txt_mod2.scale) * attn.txt_norm2(txt) + txt_mod2.shift)

        return img, txt

class IPDoubleStreamBlockProcessor(nn.Module):
    """Attention processor for handling IP-adapter with double stream block."""

    def __init__(self, context_dim, hidden_dim):
        super().__init__()
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError(
                "IPDoubleStreamBlockProcessor requires PyTorch 2.0 or higher. Please upgrade PyTorch."
            )

        # Ensure context_dim matches the dimension of image_proj
        self.context_dim = context_dim
        self.hidden_dim = hidden_dim

        # Initialize projections for IP-adapter
        self.ip_adapter_double_stream_k_proj = nn.Linear(context_dim, hidden_dim, bias=True)
        self.ip_adapter_double_stream_v_proj = nn.Linear(context_dim, hidden_dim, bias=True)

        nn.init.zeros_(self.ip_adapter_double_stream_k_proj.weight)
        nn.init.zeros_(self.ip_adapter_double_stream_k_proj.bias)

        nn.init.zeros_(self.ip_adapter_double_stream_v_proj.weight)
        nn.init.zeros_(self.ip_adapter_double_stream_v_proj.bias)

    def __call__(self, attn, img, txt, vec, pe, image_proj, ip_scale=1.0, **attention_kwargs):
        # Prepare image for attention
        img_mod1, img_mod2 = attn.img_mod(vec)
        txt_mod1, txt_mod2 = attn.txt_mod(vec)

        img_modulated = attn.img_norm1(img)
        img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
        img_qkv = attn.img_attn.qkv(img_modulated)
        img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads, D=attn.head_dim)
        img_q, img_k = attn.img_attn.norm(img_q, img_k, img_v)

        txt_modulated = attn.txt_norm1(txt)
        txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
        txt_qkv = attn.txt_attn.qkv(txt_modulated)
        txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads, D=attn.head_dim)
        txt_q, txt_k = attn.txt_attn.norm(txt_q, txt_k, txt_v)

        q = torch.cat((txt_q, img_q), dim=2)
        k = torch.cat((txt_k, img_k), dim=2)
        v = torch.cat((txt_v, img_v), dim=2)

        attn1 = attention(q, k, v, pe=pe)
        txt_attn, img_attn = attn1[:, : txt.shape[1]], attn1[:, txt.shape[1] :]

        img = img + img_mod1.gate * attn.img_attn.proj(img_attn)
        img = img + img_mod2.gate * attn.img_mlp((1 + img_mod2.scale) * attn.img_norm2(img) + img_mod2.shift)

        txt = txt + txt_mod1.gate * attn.txt_attn.proj(txt_attn)
        txt = txt + txt_mod2.gate * attn.txt_mlp((1 + txt_mod2.scale) * attn.txt_norm2(txt) + txt_mod2.shift)

        # IP-adapter processing
        ip_query = img_q  # latent sample query
        ip_key = self.ip_adapter_double_stream_k_proj(image_proj)
        ip_value = self.ip_adapter_double_stream_v_proj(image_proj)

        # Reshape projections for multi-head attention
        ip_key = rearrange(ip_key, "B L (H D) -> B H L D", H=attn.num_heads, D=attn.head_dim)
        ip_value = rearrange(ip_value, "B L (H D) -> B H L D", H=attn.num_heads, D=attn.head_dim)

        # Compute attention between IP projections and the latent query
        ip_attention = F.scaled_dot_product_attention(
            ip_query,
            ip_key,
            ip_value,
            dropout_p=0.0,
            is_causal=False,
        )
        ip_attention = rearrange(ip_attention, "B H L D -> B L (H D)", H=attn.num_heads, D=attn.head_dim)

        img = img + ip_scale * ip_attention

        return img, txt

class DoubleStreamBlockProcessor:
    def __call__(self, attn, img, txt, vec, pe):
        img_mod1, img_mod2 = attn.img_mod(vec)
        txt_mod1, txt_mod2 = attn.txt_mod(vec)

        # prepare image for attention
        img_modulated = attn.img_norm1(img)
        img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
        img_qkv = attn.img_attn.qkv(img_modulated)
        img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads, D=attn.head_dim)
        img_q, img_k = attn.img_attn.norm(img_q, img_k, img_v)

        # prepare txt for attention
        txt_modulated = attn.txt_norm1(txt)
        txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
        txt_qkv = attn.txt_attn.qkv(txt_modulated)
        txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads, D=attn.head_dim)
        txt_q, txt_k = attn.txt_attn.norm(txt_q, txt_k, txt_v)

        # run actual attention
        q = torch.cat((txt_q, img_q), dim=2)
        k = torch.cat((txt_k, img_k), dim=2)
        v = torch.cat((txt_v, img_v), dim=2)

        attn1 = attention(q, k, v, pe=pe)
        txt_attn, img_attn = attn1[:, : txt.shape[1]], attn1[:, txt.shape[1] :]

        # calculate the img blocks
        img = img + img_mod1.gate * attn.img_attn.proj(img_attn)
        img = img + img_mod2.gate * attn.img_mlp((1 + img_mod2.scale) * attn.img_norm2(img) + img_mod2.shift)

        # calculate the txt blocks
        txt = txt + txt_mod1.gate * attn.txt_attn.proj(txt_attn)
        txt = txt + txt_mod2.gate * attn.txt_mlp((1 + txt_mod2.scale) * attn.txt_norm2(txt) + txt_mod2.shift)

        return img, txt


class DoubleStreamBlock(nn.Module):
    def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False):
        super().__init__()
        mlp_hidden_dim = int(hidden_size * mlp_ratio)
        self.num_heads = num_heads
        self.hidden_size = hidden_size
        self.head_dim = hidden_size // num_heads

        self.img_mod = Modulation(hidden_size, double=True)
        self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias)

        self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.img_mlp = nn.Sequential(
            nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
            nn.GELU(approximate="tanh"),
            nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
        )

        self.txt_mod = Modulation(hidden_size, double=True)
        self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias)

        self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.txt_mlp = nn.Sequential(
            nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
            nn.GELU(approximate="tanh"),
            nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
        )

        processor = DoubleStreamBlockProcessor()
        self.set_processor(processor)

    def set_processor(self, processor):
        self.processor = processor

    def get_processor(self):
        return self.processor

    def forward(
        self,
        img: Tensor,
        txt: Tensor,
        vec: Tensor,
        pe: Tensor,
        image_proj: Tensor = None,
        ip_scale: float = 1.0,
    ):
        if image_proj is None:
            return self.processor(self, img, txt, vec, pe)
        else:
            return self.processor(self, img, txt, vec, pe, image_proj, ip_scale)

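# Usage sketch (illustrative only): wiring a DoubleStreamBlock with EmbedND positional embeddings.
# The sizes are toy assumptions (head_dim = 512 / 4 = 128 = sum(axes_dim)); calling
# block.set_processor(DoubleStreamBlockLoraProcessor(dim=512)) would swap in the LoRA variant,
# which keeps the same forward signature.
def _example_double_stream_block():
    hidden_size, num_heads = 512, 4
    block = DoubleStreamBlock(hidden_size, num_heads, mlp_ratio=4.0, qkv_bias=True)
    img = torch.randn(1, 64, hidden_size)            # image tokens
    txt = torch.randn(1, 16, hidden_size)            # text tokens
    vec = torch.randn(1, hidden_size)                # conditioning vector
    ids = torch.zeros(1, 64 + 16, 3)
    pe = EmbedND(dim=hidden_size // num_heads, theta=10_000, axes_dim=[16, 56, 56])(ids)
    return block(img, txt, vec, pe)                  # updated (img, txt)
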
class IPSingleStreamBlockProcessor(nn.Module):
    """Attention processor for handling IP-adapter with single stream block."""

    def __init__(self, context_dim, hidden_dim):
        super().__init__()
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError(
                "IPSingleStreamBlockProcessor requires PyTorch 2.0 or higher. Please upgrade PyTorch."
            )

        # Ensure context_dim matches the dimension of image_proj
        self.context_dim = context_dim
        self.hidden_dim = hidden_dim

        # Initialize projections for IP-adapter
        self.ip_adapter_single_stream_k_proj = nn.Linear(context_dim, hidden_dim, bias=False)
        self.ip_adapter_single_stream_v_proj = nn.Linear(context_dim, hidden_dim, bias=False)

        nn.init.zeros_(self.ip_adapter_single_stream_k_proj.weight)
        nn.init.zeros_(self.ip_adapter_single_stream_v_proj.weight)

    def __call__(
        self,
        attn: nn.Module,
        x: Tensor,
        vec: Tensor,
        pe: Tensor,
        image_proj: Tensor = None,
        ip_scale: float = 1.0,
    ):
        mod, _ = attn.modulation(vec)
        x_mod = (1 + mod.scale) * attn.pre_norm(x) + mod.shift
        qkv, mlp = torch.split(attn.linear1(x_mod), [3 * attn.hidden_size, attn.mlp_hidden_dim], dim=-1)

        q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads, D=attn.head_dim)
        q, k = attn.norm(q, k, v)

        # compute attention
        attn_1 = attention(q, k, v, pe=pe)

        # IP-adapter processing
        ip_query = q
        ip_key = self.ip_adapter_single_stream_k_proj(image_proj)
        ip_value = self.ip_adapter_single_stream_v_proj(image_proj)

        # Reshape projections for multi-head attention
        ip_key = rearrange(ip_key, "B L (H D) -> B H L D", H=attn.num_heads, D=attn.head_dim)
        ip_value = rearrange(ip_value, "B L (H D) -> B H L D", H=attn.num_heads, D=attn.head_dim)

        # Compute attention between IP projections and the latent query
        ip_attention = F.scaled_dot_product_attention(
            ip_query,
            ip_key,
            ip_value,
        )
        ip_attention = rearrange(ip_attention, "B H L D -> B L (H D)")

        attn_out = attn_1 + ip_scale * ip_attention

        # compute activation in mlp stream, cat again and run second linear layer
        output = attn.linear2(torch.cat((attn_out, attn.mlp_act(mlp)), 2))
        out = x + mod.gate * output

        return out

class SingleStreamBlockLoraProcessor(nn.Module):
    def __init__(self, dim: int, rank: int = 4, network_alpha=None, lora_weight: float = 1):
        super().__init__()
        self.qkv_lora = LoRALinearLayer(dim, dim * 3, rank, network_alpha)
        self.proj_lora = LoRALinearLayer(dim, dim, rank, network_alpha)
        self.lora_weight = lora_weight

    def __call__(self, attn: nn.Module, x: Tensor, vec: Tensor, pe: Tensor):
        mod, _ = attn.modulation(vec)
        x_mod = (1 + mod.scale) * attn.pre_norm(x) + mod.shift
        qkv, mlp = torch.split(attn.linear1(x_mod), [3 * attn.hidden_size, attn.mlp_hidden_dim], dim=-1)
        qkv = qkv + self.qkv_lora(x_mod) * self.lora_weight

        q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads)
        q, k = attn.norm(q, k, v)

        # compute attention
        attn_1 = attention(q, k, v, pe=pe)

        # compute activation in mlp stream, cat again and run second linear layer
        output = attn.linear2(torch.cat((attn_1, attn.mlp_act(mlp)), 2))
        output = output + self.proj_lora(output) * self.lora_weight
        output = x + mod.gate * output

        return output

class SingleStreamBlockProcessor:
    def __call__(self, attn: nn.Module, x: Tensor, vec: Tensor, pe: Tensor):
        mod, _ = attn.modulation(vec)
        x_mod = (1 + mod.scale) * attn.pre_norm(x) + mod.shift
        qkv, mlp = torch.split(attn.linear1(x_mod), [3 * attn.hidden_size, attn.mlp_hidden_dim], dim=-1)

        q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads)
        q, k = attn.norm(q, k, v)

        # compute attention
        attn_1 = attention(q, k, v, pe=pe)

        # compute activation in mlp stream, cat again and run second linear layer
        output = attn.linear2(torch.cat((attn_1, attn.mlp_act(mlp)), 2))
        output = x + mod.gate * output

        return output

class SingleStreamBlock(nn.Module):
    """
    A DiT block with parallel linear layers as described in
    https://arxiv.org/abs/2302.05442 and adapted modulation interface.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qk_scale: float = None,
    ):
        super().__init__()
        self.hidden_dim = hidden_size
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads
        self.scale = qk_scale or self.head_dim**-0.5

        self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
        # qkv and mlp_in
        self.linear1 = nn.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim)
        # proj and mlp_out
        self.linear2 = nn.Linear(hidden_size + self.mlp_hidden_dim, hidden_size)

        self.norm = QKNorm(self.head_dim)

        self.hidden_size = hidden_size
        self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)

        self.mlp_act = nn.GELU(approximate="tanh")
        self.modulation = Modulation(hidden_size, double=False)

        processor = SingleStreamBlockProcessor()
        self.set_processor(processor)

    def set_processor(self, processor):
        self.processor = processor

    def get_processor(self):
        return self.processor

    def forward(
        self,
        x: Tensor,
        vec: Tensor,
        pe: Tensor,
        image_proj: Tensor = None,
        ip_scale: float = 1.0,
    ):
        if image_proj is None:
            return self.processor(self, x, vec, pe)
        else:
            return self.processor(self, x, vec, pe, image_proj, ip_scale)

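# Usage sketch (illustrative only): the single-stream block operates on the concatenated txt+img
# token sequence produced after the double-stream blocks. Sizes are toy assumptions matching the
# example above.
def _example_single_stream_block():
    hidden_size, num_heads = 512, 4                  # head_dim = 128 = sum(axes_dim) below
    block = SingleStreamBlock(hidden_size, num_heads, mlp_ratio=4.0)
    x = torch.randn(1, 80, hidden_size)              # concatenated txt+img tokens
    vec = torch.randn(1, hidden_size)
    ids = torch.zeros(1, 80, 3)
    pe = EmbedND(dim=hidden_size // num_heads, theta=10_000, axes_dim=[16, 56, 56])(ids)
    return block(x, vec, pe)                         # (1, 80, hidden_size)
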
class LastLayer(nn.Module):
    def __init__(self, hidden_size: int, patch_size: int, out_channels: int):
        super().__init__()
        self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
        self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True))

    def forward(self, x: Tensor, vec: Tensor) -> Tensor:
        shift, scale = self.adaLN_modulation(vec).chunk(2, dim=1)
        x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :]
        x = self.linear(x)
        return x

class ImageProjModel(torch.nn.Module):
    """Projection Model
    https://github.com/tencent-ailab/IP-Adapter/blob/main/ip_adapter/ip_adapter.py#L28
    """

    def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024, clip_extra_context_tokens=4):
        super().__init__()

        self.generator = None
        self.cross_attention_dim = cross_attention_dim
        self.clip_extra_context_tokens = clip_extra_context_tokens
        self.proj = torch.nn.Linear(clip_embeddings_dim, self.clip_extra_context_tokens * cross_attention_dim)
        self.norm = torch.nn.LayerNorm(cross_attention_dim)

    def forward(self, image_embeds):
        embeds = image_embeds
        clip_extra_context_tokens = self.proj(embeds).reshape(
            -1, self.clip_extra_context_tokens, self.cross_attention_dim
        )
        clip_extra_context_tokens = self.norm(clip_extra_context_tokens)
        return clip_extra_context_tokens

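# Usage sketch (illustrative only): ImageProjModel maps a pooled CLIP image embedding to a few
# extra context tokens, which the IP-adapter processors above consume as `image_proj`. The
# 768-dim CLIP embedding is an assumption for the example.
def _example_image_proj_model():
    proj = ImageProjModel(cross_attention_dim=1024, clip_embeddings_dim=768, clip_extra_context_tokens=4)
    image_embeds = torch.randn(2, 768)               # pooled CLIP image features
    return proj(image_embeds)                        # (2, 4, 1024) tokens for image_proj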