frankleeeee committed on
Commit
9784497
1 Parent(s): 46186e9

Upload STDiT

Files changed (6)
  1. config.json +39 -0
  2. configuration_stdit.py +51 -0
  3. layers.py +466 -0
  4. model.safetensors +3 -0
  5. modeling_stdit.py +260 -0
  6. utils.py +80 -0
config.json ADDED
@@ -0,0 +1,39 @@
+ {
+   "architectures": [
+     "STDiT"
+   ],
+   "auto_map": {
+     "AutoConfig": "configuration_stdit.STDiTConfig",
+     "AutoModel": "modeling_stdit.STDiT"
+   },
+   "caption_channels": 4096,
+   "class_dropout_prob": 0.1,
+   "depth": 28,
+   "drop_path": 0.0,
+   "enable_flash_attn": false,
+   "enable_layernorm_kernel": false,
+   "enable_sequence_parallelism": false,
+   "freeze": null,
+   "hidden_size": 1152,
+   "in_channels": 4,
+   "input_size": [
+     16,
+     32,
+     32
+   ],
+   "mlp_ratio": 4.0,
+   "model_max_length": 120,
+   "model_type": "resnet",
+   "no_temporal_pos_emb": false,
+   "num_heads": 16,
+   "patch_size": [
+     1,
+     2,
+     2
+   ],
+   "pred_sigma": true,
+   "space_scale": 0.5,
+   "time_scale": 1.0,
+   "torch_dtype": "float32",
+   "transformers_version": "4.38.2"
+ }
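Because config.json maps AutoConfig and AutoModel to the custom classes uploaded in this commit, the checkpoint can be loaded with trust_remote_code. A minimal sketch, where "./stdit" is only a placeholder for the local directory or Hub repo id that holds these files:

    from transformers import AutoConfig, AutoModel

    # "./stdit" is a placeholder path/repo id for wherever this commit's files live.
    config = AutoConfig.from_pretrained("./stdit", trust_remote_code=True)
    model = AutoModel.from_pretrained("./stdit", trust_remote_code=True)  # resolved to modeling_stdit.STDiT via auto_map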
configuration_stdit.py ADDED
@@ -0,0 +1,51 @@
+ import torch
+ from transformers import PretrainedConfig
+ from typing import List
+
+
+ class STDiTConfig(PretrainedConfig):
+     model_type = "resnet"
+
+     def __init__(
+         self,
+         input_size=(1, 32, 32),
+         in_channels=4,
+         patch_size=(1, 2, 2),
+         hidden_size=1152,
+         depth=28,
+         num_heads=16,
+         mlp_ratio=4.0,
+         class_dropout_prob=0.1,
+         pred_sigma=True,
+         drop_path=0.0,
+         no_temporal_pos_emb=False,
+         caption_channels=4096,
+         model_max_length=120,
+         space_scale=1.0,
+         time_scale=1.0,
+         freeze=None,
+         enable_flash_attn=False,
+         enable_layernorm_kernel=False,
+         enable_sequence_parallelism=False,
+         **kwargs,
+     ):
+         self.input_size = input_size
+         self.in_channels = in_channels
+         self.patch_size = patch_size
+         self.hidden_size = hidden_size
+         self.depth = depth
+         self.num_heads = num_heads
+         self.mlp_ratio = mlp_ratio
+         self.class_dropout_prob = class_dropout_prob
+         self.pred_sigma = pred_sigma
+         self.drop_path = drop_path
+         self.no_temporal_pos_emb = no_temporal_pos_emb
+         self.caption_channels = caption_channels
+         self.model_max_length = model_max_length
+         self.space_scale = space_scale
+         self.time_scale = time_scale
+         self.freeze = freeze
+         self.enable_flash_attn = enable_flash_attn
+         self.enable_layernorm_kernel = enable_layernorm_kernel
+         self.enable_sequence_parallelism = enable_sequence_parallelism
+         super().__init__(**kwargs)
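For reference, the shipped config.json differs from the defaults above only in input_size and space_scale, so it can be reproduced directly. A minimal sketch, assuming configuration_stdit.py is importable from the working directory:

    from configuration_stdit import STDiTConfig

    config = STDiTConfig(input_size=(16, 32, 32), space_scale=0.5)
    print(config.hidden_size, config.depth, config.num_heads)  # 1152 28 16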
layers.py ADDED
@@ -0,0 +1,466 @@
+ import math
+ import collections.abc
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ from einops import rearrange
+ from itertools import repeat
+ from functools import partial
+ from .utils import get_layernorm, t2i_modulate, approx_gelu
+
+
+ try:
+     import xformers.ops
+     HAS_XFORMERS = True
+ except ImportError:
+     HAS_XFORMERS = False
+
+ # ============================
+ # Attention
+ # ============================
+ class Attention(nn.Module):
+     def __init__(
+         self,
+         dim: int,
+         num_heads: int = 8,
+         qkv_bias: bool = False,
+         qk_norm: bool = False,
+         attn_drop: float = 0.0,
+         proj_drop: float = 0.0,
+         norm_layer: nn.Module = nn.LayerNorm,
+         enable_flash_attn: bool = False,
+     ) -> None:
+         super().__init__()
+         assert dim % num_heads == 0, "dim should be divisible by num_heads"
+         self.dim = dim
+         self.num_heads = num_heads
+         self.head_dim = dim // num_heads
+         self.scale = self.head_dim**-0.5
+         self.enable_flash_attn = enable_flash_attn
+
+         self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+         self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
+         self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
+         self.attn_drop = nn.Dropout(attn_drop)
+         self.proj = nn.Linear(dim, dim)
+         self.proj_drop = nn.Dropout(proj_drop)
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         B, N, C = x.shape
+         qkv = self.qkv(x)
+         qkv_shape = (B, N, 3, self.num_heads, self.head_dim)
+         if self.enable_flash_attn:
+             qkv_permute_shape = (2, 0, 1, 3, 4)
+         else:
+             qkv_permute_shape = (2, 0, 3, 1, 4)
+         qkv = qkv.view(qkv_shape).permute(qkv_permute_shape)
+         q, k, v = qkv.unbind(0)
+         q, k = self.q_norm(q), self.k_norm(k)
+         if self.enable_flash_attn:
+             from flash_attn import flash_attn_func
+
+             x = flash_attn_func(
+                 q,
+                 k,
+                 v,
+                 dropout_p=self.attn_drop.p if self.training else 0.0,
+                 softmax_scale=self.scale,
+             )
+         else:
+             dtype = q.dtype
+             q = q * self.scale
+             attn = q @ k.transpose(-2, -1)
+             attn = attn.to(torch.float32)  # compute softmax in float32 for numerical stability
+             attn = attn.softmax(dim=-1)
+             attn = attn.to(dtype)  # cast attn back to the original dtype
+             attn = self.attn_drop(attn)
+             x = attn @ v
+
+         x_output_shape = (B, N, C)
+         if not self.enable_flash_attn:
+             x = x.transpose(1, 2)
+         x = x.reshape(x_output_shape)
+         x = self.proj(x)
+         x = self.proj_drop(x)
+         return x
+
+
+ # ============================
+ # Caption Embedding
+ # ============================
+ class CaptionEmbedder(nn.Module):
+     """
+     Embeds caption (text) features into vector representations. Also handles caption dropout for classifier-free guidance.
+     """
+
+     def __init__(self, in_channels, hidden_size, uncond_prob, act_layer=approx_gelu, token_num=120):
+         super().__init__()
+         self.y_proj = Mlp(
+             in_features=in_channels, hidden_features=hidden_size, out_features=hidden_size, act_layer=act_layer, drop=0
+         )
+         self.register_buffer("y_embedding", nn.Parameter(torch.randn(token_num, in_channels) / in_channels**0.5))
+         self.uncond_prob = uncond_prob
+
+     def token_drop(self, caption, force_drop_ids=None):
+         """
+         Drops caption embeddings to enable classifier-free guidance.
+         """
+         if force_drop_ids is None:
+             drop_ids = torch.rand(caption.shape[0], device=caption.device) < self.uncond_prob
+         else:
+             drop_ids = force_drop_ids == 1
+         caption = torch.where(drop_ids[:, None, None, None], self.y_embedding, caption)
+         return caption
+
+     def forward(self, caption, train, force_drop_ids=None):
+         if train:
+             assert caption.shape[2:] == self.y_embedding.shape
+         use_dropout = self.uncond_prob > 0
+         if (train and use_dropout) or (force_drop_ids is not None):
+             caption = self.token_drop(caption, force_drop_ids)
+         caption = self.y_proj(caption)
+         return caption
+
+
+ # ============================
+ # Drop Path
+ # ============================
+ def drop_path(x, drop_prob: float = 0., training: bool = False, scale_by_keep: bool = True):
+     """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
+
+     This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
+     the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
+     See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
+     changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
+     'survival rate' as the argument.
+
+     """
+     if drop_prob == 0. or not training:
+         return x
+     keep_prob = 1 - drop_prob
+     shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
+     random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
+     if keep_prob > 0.0 and scale_by_keep:
+         random_tensor.div_(keep_prob)
+     return x * random_tensor
+
+ class DropPath(nn.Module):
+     """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
+     """
+     def __init__(self, drop_prob: float = 0., scale_by_keep: bool = True):
+         super(DropPath, self).__init__()
+         self.drop_prob = drop_prob
+         self.scale_by_keep = scale_by_keep
+
+     def forward(self, x):
+         return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)
+
+     def extra_repr(self):
+         return f'drop_prob={round(self.drop_prob,3):0.3f}'
+
+
+ # ============================
+ # MHA
+ # ============================
+ class MultiHeadCrossAttention(nn.Module):
+     def __init__(self, d_model, num_heads, attn_drop=0.0, proj_drop=0.0):
+         super(MultiHeadCrossAttention, self).__init__()
+         assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
+
+         self.d_model = d_model
+         self.num_heads = num_heads
+         self.head_dim = d_model // num_heads
+
+         self.q_linear = nn.Linear(d_model, d_model)
+         self.kv_linear = nn.Linear(d_model, d_model * 2)
+         self.attn_drop = nn.Dropout(attn_drop)
+         self.proj = nn.Linear(d_model, d_model)
+         self.proj_drop = nn.Dropout(proj_drop)
+
+     def forward(self, x, cond, mask=None):
+         # query: image tokens; key/value: condition (caption) tokens; mask: number of valid condition tokens per sample
+         B, N, C = x.shape
+
+         q = self.q_linear(x).view(1, -1, self.num_heads, self.head_dim)
+         kv = self.kv_linear(cond).view(1, -1, 2, self.num_heads, self.head_dim)
+         k, v = kv.unbind(2)
+
+         attn_bias = None
+         assert HAS_XFORMERS, "Please install xformers to use this module."
+         if mask is not None:
+             attn_bias = xformers.ops.fmha.BlockDiagonalMask.from_seqlens([N] * B, mask)
+         x = xformers.ops.memory_efficient_attention(q, k, v, p=self.attn_drop.p, attn_bias=attn_bias)
+
+         x = x.view(B, -1, C)
+         x = self.proj(x)
+         x = self.proj_drop(x)
+         return x
+
+
+ # ============================
+ # MLP
+ # ============================
+ def _ntuple(n):
+     def parse(x):
+         if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
+             return tuple(x)
+         return tuple(repeat(x, n))
+     return parse
+
+ to_2tuple = _ntuple(2)
+
+ class Mlp(nn.Module):
+     """ MLP as used in Vision Transformer, MLP-Mixer and related networks
+     """
+     def __init__(
+         self,
+         in_features,
+         hidden_features=None,
+         out_features=None,
+         act_layer=nn.GELU,
+         norm_layer=None,
+         bias=True,
+         drop=0.,
+         use_conv=False,
+     ):
+         super().__init__()
+         out_features = out_features or in_features
+         hidden_features = hidden_features or in_features
+         bias = to_2tuple(bias)
+         drop_probs = to_2tuple(drop)
+         linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear
+
+         self.fc1 = linear_layer(in_features, hidden_features, bias=bias[0])
+         self.act = act_layer()
+         self.drop1 = nn.Dropout(drop_probs[0])
+         self.norm = norm_layer(hidden_features) if norm_layer is not None else nn.Identity()
+         self.fc2 = linear_layer(hidden_features, out_features, bias=bias[1])
+         self.drop2 = nn.Dropout(drop_probs[1])
+
+     def forward(self, x):
+         x = self.fc1(x)
+         x = self.act(x)
+         x = self.drop1(x)
+         x = self.norm(x)
+         x = self.fc2(x)
+         x = self.drop2(x)
+         return x
+
+
+ # ============================
+ # Patch Embedding
+ # ============================
+ class PatchEmbed3D(nn.Module):
+     """Video to Patch Embedding.
+
+     Args:
+         patch_size (int): Patch token size. Default: (2,4,4).
+         in_chans (int): Number of input video channels. Default: 3.
+         embed_dim (int): Number of linear projection output channels. Default: 96.
+         norm_layer (nn.Module, optional): Normalization layer. Default: None
+     """
+
+     def __init__(
+         self,
+         patch_size=(2, 4, 4),
+         in_chans=3,
+         embed_dim=96,
+         norm_layer=None,
+         flatten=True,
+     ):
+         super().__init__()
+         self.patch_size = patch_size
+         self.flatten = flatten
+
+         self.in_chans = in_chans
+         self.embed_dim = embed_dim
+
+         self.proj = nn.Conv3d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
+         if norm_layer is not None:
+             self.norm = norm_layer(embed_dim)
+         else:
+             self.norm = None
+
+     def forward(self, x):
+         """Forward function."""
+         # padding
+         _, _, D, H, W = x.size()
+         if W % self.patch_size[2] != 0:
+             x = F.pad(x, (0, self.patch_size[2] - W % self.patch_size[2]))
+         if H % self.patch_size[1] != 0:
+             x = F.pad(x, (0, 0, 0, self.patch_size[1] - H % self.patch_size[1]))
+         if D % self.patch_size[0] != 0:
+             x = F.pad(x, (0, 0, 0, 0, 0, self.patch_size[0] - D % self.patch_size[0]))
+
+         x = self.proj(x)  # (B C T H W)
+         if self.norm is not None:
+             D, Wh, Ww = x.size(2), x.size(3), x.size(4)
+             x = x.flatten(2).transpose(1, 2)
+             x = self.norm(x)
+             x = x.transpose(1, 2).view(-1, self.embed_dim, D, Wh, Ww)
+         if self.flatten:
+             x = x.flatten(2).transpose(1, 2)  # BCTHW -> BNC
+         return x
+
+
+ # ============================
+ # T2I
+ # ============================
+ def t2i_modulate(x, shift, scale):
+     return x * (1 + scale) + shift
+
+ class T2IFinalLayer(nn.Module):
+     """
+     The final layer of PixArt.
+     """
+
+     def __init__(self, hidden_size, num_patch, out_channels):
+         super().__init__()
+         self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+         self.linear = nn.Linear(hidden_size, num_patch * out_channels, bias=True)
+         self.scale_shift_table = nn.Parameter(torch.randn(2, hidden_size) / hidden_size**0.5)
+         self.out_channels = out_channels
+
+     def forward(self, x, t):
+         shift, scale = (self.scale_shift_table[None] + t[:, None]).chunk(2, dim=1)
+         x = t2i_modulate(self.norm_final(x), shift, scale)
+         x = self.linear(x)
+         return x
+
+
+ # ============================
+ # Time Step Embedding
+ # ============================
+ class TimestepEmbedder(nn.Module):
+     """
+     Embeds scalar timesteps into vector representations.
+     """
+
+     def __init__(self, hidden_size, frequency_embedding_size=256):
+         super().__init__()
+         self.mlp = nn.Sequential(
+             nn.Linear(frequency_embedding_size, hidden_size, bias=True),
+             nn.SiLU(),
+             nn.Linear(hidden_size, hidden_size, bias=True),
+         )
+         self.frequency_embedding_size = frequency_embedding_size
+
+     @staticmethod
+     def timestep_embedding(t, dim, max_period=10000):
+         """
+         Create sinusoidal timestep embeddings.
+         :param t: a 1-D Tensor of N indices, one per batch element.
+                   These may be fractional.
+         :param dim: the dimension of the output.
+         :param max_period: controls the minimum frequency of the embeddings.
+         :return: an (N, D) Tensor of positional embeddings.
+         """
+         # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
+         half = dim // 2
+         freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half)
+         freqs = freqs.to(device=t.device)
+         args = t[:, None].float() * freqs[None]
+         embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
+         if dim % 2:
+             embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
+         return embedding
+
+     def forward(self, t, dtype):
+         t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
+         if t_freq.dtype != dtype:
+             t_freq = t_freq.to(dtype)
+         t_emb = self.mlp(t_freq)
+         return t_emb
+
+
+ # ============================
+ # STDiT Block
+ # ============================
+ class STDiTBlock(nn.Module):
+     def __init__(
+         self,
+         hidden_size,
+         num_heads,
+         d_s=None,
+         d_t=None,
+         mlp_ratio=4.0,
+         drop_path=0.0,
+         enable_flash_attn=False,
+         enable_layernorm_kernel=False,
+         enable_sequence_parallelism=False,
+     ):
+         super().__init__()
+         self.hidden_size = hidden_size
+         self.enable_flash_attn = enable_flash_attn
+         self._enable_sequence_parallelism = enable_sequence_parallelism
+
+         if enable_sequence_parallelism:
+             self.attn_cls = SeqParallelAttention
+             self.mha_cls = SeqParallelMultiHeadCrossAttention
+         else:
+             self.attn_cls = Attention
+             self.mha_cls = MultiHeadCrossAttention
+
+         self.norm1 = get_layernorm(hidden_size, eps=1e-6, affine=False, use_kernel=enable_layernorm_kernel)
+         self.attn = self.attn_cls(
+             hidden_size,
+             num_heads=num_heads,
+             qkv_bias=True,
+             enable_flash_attn=enable_flash_attn,
+         )
+         self.cross_attn = self.mha_cls(hidden_size, num_heads)
+         self.norm2 = get_layernorm(hidden_size, eps=1e-6, affine=False, use_kernel=enable_layernorm_kernel)
+         self.mlp = Mlp(
+             in_features=hidden_size, hidden_features=int(hidden_size * mlp_ratio), act_layer=approx_gelu, drop=0
+         )
+         self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
+         self.scale_shift_table = nn.Parameter(torch.randn(6, hidden_size) / hidden_size**0.5)
+
+         # temporal attention
+         self.d_s = d_s
+         self.d_t = d_t
+
+         if self._enable_sequence_parallelism:
+             sp_size = dist.get_world_size(get_sequence_parallel_group())
+             # make sure d_t is divisible by sp_size
+             assert d_t % sp_size == 0
+             self.d_t = d_t // sp_size
+
+         self.attn_temp = self.attn_cls(
+             hidden_size,
+             num_heads=num_heads,
+             qkv_bias=True,
+             enable_flash_attn=self.enable_flash_attn,
+         )
+
+     def forward(self, x, y, t, mask=None, tpe=None):
+         B, N, C = x.shape
+
+         shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
+             self.scale_shift_table[None] + t.reshape(B, 6, -1)
+         ).chunk(6, dim=1)
+         x_m = t2i_modulate(self.norm1(x), shift_msa, scale_msa)
+
+         # spatial branch
+         x_s = rearrange(x_m, "B (T S) C -> (B T) S C", T=self.d_t, S=self.d_s)
+         x_s = self.attn(x_s)
+         x_s = rearrange(x_s, "(B T) S C -> B (T S) C", T=self.d_t, S=self.d_s)
+         x = x + self.drop_path(gate_msa * x_s)
+
+         # temporal branch
+         x_t = rearrange(x, "B (T S) C -> (B S) T C", T=self.d_t, S=self.d_s)
+         if tpe is not None:
+             x_t = x_t + tpe
+         x_t = self.attn_temp(x_t)
+         x_t = rearrange(x_t, "(B S) T C -> B (T S) C", T=self.d_t, S=self.d_s)
+         x = x + self.drop_path(gate_msa * x_t)
+
+         # cross attn
+         x = x + self.cross_attn(x, y, mask)
+
+         # mlp
+         x = x + self.drop_path(gate_mlp * self.mlp(t2i_modulate(self.norm2(x), shift_mlp, scale_mlp)))
+
+         return x
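The modules above are plain nn.Modules and can be sanity-checked in isolation. A minimal shape check for Attention and TimestepEmbedder, assuming layers.py and utils.py sit inside an importable package ("stdit" below is only a placeholder name; the relative imports require a package context):

    import torch
    from stdit.layers import Attention, TimestepEmbedder  # "stdit" is a placeholder package name

    attn = Attention(dim=1152, num_heads=16, qkv_bias=True)
    x = torch.randn(2, 256, 1152)   # (batch, tokens, hidden)
    print(attn(x).shape)            # torch.Size([2, 256, 1152])

    t_emb = TimestepEmbedder(hidden_size=1152)
    t = torch.randint(0, 1000, (2,))
    print(t_emb(t, dtype=torch.float32).shape)  # torch.Size([2, 1152])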
model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:351efcf3c7656393606bffa1edf14d3b8508a9849157e806ab8582b002520a1f
+ size 3041758968
modeling_stdit.py ADDED
@@ -0,0 +1,260 @@
+ import numpy as np
+ import torch
+ import torch.distributed as dist
+ import torch.nn as nn
+ from einops import rearrange
+ from .configuration_stdit import STDiTConfig
+ from .layers import (
+     STDiTBlock,
+     CaptionEmbedder,
+     PatchEmbed3D,
+     T2IFinalLayer,
+     TimestepEmbedder,
+ )
+ from .utils import (
+     approx_gelu,
+     get_1d_sincos_pos_embed,
+     get_2d_sincos_pos_embed,
+ )
+ from transformers import PreTrainedModel
+
+
+ class STDiT(PreTrainedModel):
+
+     config_class = STDiTConfig
+
+     def __init__(
+         self,
+         config
+     ):
+         super().__init__(config)
+         self.pred_sigma = config.pred_sigma
+         self.in_channels = config.in_channels
+         self.out_channels = config.in_channels * 2 if config.pred_sigma else config.in_channels
+         self.hidden_size = config.hidden_size
+         self.patch_size = config.patch_size
+         self.input_size = config.input_size
+         num_patches = np.prod([config.input_size[i] // config.patch_size[i] for i in range(3)])
+         self.num_patches = num_patches
+         self.num_temporal = config.input_size[0] // config.patch_size[0]
+         self.num_spatial = num_patches // self.num_temporal
+         self.num_heads = config.num_heads
+         self.no_temporal_pos_emb = config.no_temporal_pos_emb
+         self.depth = config.depth
+         self.mlp_ratio = config.mlp_ratio
+         self.enable_flash_attn = config.enable_flash_attn
+         self.enable_layernorm_kernel = config.enable_layernorm_kernel
+         self.space_scale = config.space_scale
+         self.time_scale = config.time_scale
+
+         self.register_buffer("pos_embed", self.get_spatial_pos_embed())
+         self.register_buffer("pos_embed_temporal", self.get_temporal_pos_embed())
+
+         self.x_embedder = PatchEmbed3D(config.patch_size, config.in_channels, config.hidden_size)
+         self.t_embedder = TimestepEmbedder(config.hidden_size)
+         self.t_block = nn.Sequential(nn.SiLU(), nn.Linear(config.hidden_size, 6 * config.hidden_size, bias=True))
+         self.y_embedder = CaptionEmbedder(
+             in_channels=config.caption_channels,
+             hidden_size=config.hidden_size,
+             uncond_prob=config.class_dropout_prob,
+             act_layer=approx_gelu,
+             token_num=config.model_max_length,
+         )
+
+         drop_path = [x.item() for x in torch.linspace(0, config.drop_path, config.depth)]
+         self.blocks = nn.ModuleList(
+             [
+                 STDiTBlock(
+                     self.hidden_size,
+                     self.num_heads,
+                     mlp_ratio=self.mlp_ratio,
+                     drop_path=drop_path[i],
+                     enable_flash_attn=self.enable_flash_attn,
+                     enable_layernorm_kernel=self.enable_layernorm_kernel,
+                     enable_sequence_parallelism=config.enable_sequence_parallelism,
+                     d_t=self.num_temporal,
+                     d_s=self.num_spatial,
+                 )
+                 for i in range(self.depth)
+             ]
+         )
+         self.final_layer = T2IFinalLayer(config.hidden_size, np.prod(self.patch_size), self.out_channels)
+
+         # init model
+         self.initialize_weights()
+         self.initialize_temporal()
+         if config.freeze is not None:
+             assert config.freeze in ["not_temporal", "text"]
+             if config.freeze == "not_temporal":
+                 self.freeze_not_temporal()
+             elif config.freeze == "text":
+                 self.freeze_text()
+
+         # sequence parallel related configs
+         self.enable_sequence_parallelism = config.enable_sequence_parallelism
+         if config.enable_sequence_parallelism:
+             self.sp_rank = dist.get_rank(get_sequence_parallel_group())
+         else:
+             self.sp_rank = None
+
+     def forward(self, x, timestep, y, mask=None):
+         """
+         Forward pass of STDiT.
+         Args:
+             x (torch.Tensor): latent representation of video; of shape [B, C, T, H, W]
+             timestep (torch.Tensor): diffusion time steps; of shape [B]
+             y (torch.Tensor): representation of prompts; of shape [B, 1, N_token, C]
+             mask (torch.Tensor): mask for selecting prompt tokens; of shape [B, N_token]
+
+         Returns:
+             x (torch.Tensor): output latent representation; of shape [B, C_out, T, H, W]
+         """
+         # embedding
+         x = self.x_embedder(x)  # [B, N, C]
+         x = rearrange(x, "B (T S) C -> B T S C", T=self.num_temporal, S=self.num_spatial)
+         x = x + self.pos_embed
+         x = rearrange(x, "B T S C -> B (T S) C")
+
+         # shard over the sequence dim if sp is enabled
+         if self.enable_sequence_parallelism:
+             x = split_forward_gather_backward(x, get_sequence_parallel_group(), dim=1, grad_scale="down")
+
+         t = self.t_embedder(timestep, dtype=x.dtype)  # [B, C]
+         t0 = self.t_block(t)  # [B, 6 * C]
+         y = self.y_embedder(y, self.training)  # [B, 1, N_token, C]
+
+         if mask is not None:
+             if mask.shape[0] != y.shape[0]:
+                 mask = mask.repeat(y.shape[0] // mask.shape[0], 1)
+             mask = mask.squeeze(1).squeeze(1)
+             y = y.squeeze(1).masked_select(mask.unsqueeze(-1) != 0).view(1, -1, x.shape[-1])
+             y_lens = mask.sum(dim=1).tolist()
+         else:
+             y_lens = [y.shape[2]] * y.shape[0]
+             y = y.squeeze(1).view(1, -1, x.shape[-1])
+
+         # blocks
+         for i, block in enumerate(self.blocks):
+             if i == 0:
+                 if self.enable_sequence_parallelism:
+                     tpe = torch.chunk(
+                         self.pos_embed_temporal, dist.get_world_size(get_sequence_parallel_group()), dim=1
+                     )[self.sp_rank].contiguous()
+                 else:
+                     tpe = self.pos_embed_temporal
+             else:
+                 tpe = None
+             x = block(x, y, t0, y_lens, tpe)
+
+         if self.enable_sequence_parallelism:
+             x = gather_forward_split_backward(x, get_sequence_parallel_group(), dim=1, grad_scale="up")
+             # x.shape: [B, N, C]
+
+         # final process
+         x = self.final_layer(x, t)  # [B, N, C=T_p * H_p * W_p * C_out]
+         x = self.unpatchify(x)  # [B, C_out, T, H, W]
+
+         # cast to float32 for better accuracy
+         x = x.to(torch.float32)
+         return x
+
+     def unpatchify(self, x):
+         """
+         Args:
+             x (torch.Tensor): of shape [B, N, C]
+
+         Return:
+             x (torch.Tensor): of shape [B, C_out, T, H, W]
+         """
+
+         N_t, N_h, N_w = [self.input_size[i] // self.patch_size[i] for i in range(3)]
+         T_p, H_p, W_p = self.patch_size
+         x = rearrange(
+             x,
+             "B (N_t N_h N_w) (T_p H_p W_p C_out) -> B C_out (N_t T_p) (N_h H_p) (N_w W_p)",
+             N_t=N_t,
+             N_h=N_h,
+             N_w=N_w,
+             T_p=T_p,
+             H_p=H_p,
+             W_p=W_p,
+             C_out=self.out_channels,
+         )
+         return x
+
+     def unpatchify_old(self, x):
+         c = self.out_channels
+         t, h, w = [self.input_size[i] // self.patch_size[i] for i in range(3)]
+         pt, ph, pw = self.patch_size
+
+         x = x.reshape(shape=(x.shape[0], t, h, w, pt, ph, pw, c))
+         x = rearrange(x, "n t h w r p q c -> n c t r h p w q")
+         imgs = x.reshape(shape=(x.shape[0], c, t * pt, h * ph, w * pw))
+         return imgs
+
+     def get_spatial_pos_embed(self, grid_size=None):
+         if grid_size is None:
+             grid_size = self.input_size[1:]
+         pos_embed = get_2d_sincos_pos_embed(
+             self.hidden_size,
+             (grid_size[0] // self.patch_size[1], grid_size[1] // self.patch_size[2]),
+             scale=self.space_scale,
+         )
+         pos_embed = torch.from_numpy(pos_embed).float().unsqueeze(0).requires_grad_(False)
+         return pos_embed
+
+     def get_temporal_pos_embed(self):
+         pos_embed = get_1d_sincos_pos_embed(
+             self.hidden_size,
+             self.input_size[0] // self.patch_size[0],
+             scale=self.time_scale,
+         )
+         pos_embed = torch.from_numpy(pos_embed).float().unsqueeze(0).requires_grad_(False)
+         return pos_embed
+
+     def freeze_not_temporal(self):
+         for n, p in self.named_parameters():
+             if "attn_temp" not in n:
+                 p.requires_grad = False
+
+     def freeze_text(self):
+         for n, p in self.named_parameters():
+             if "cross_attn" in n:
+                 p.requires_grad = False
+
+     def initialize_temporal(self):
+         for block in self.blocks:
+             nn.init.constant_(block.attn_temp.proj.weight, 0)
+             nn.init.constant_(block.attn_temp.proj.bias, 0)
+
+     def initialize_weights(self):
+         # Initialize transformer layers:
+         def _basic_init(module):
+             if isinstance(module, nn.Linear):
+                 torch.nn.init.xavier_uniform_(module.weight)
+                 if module.bias is not None:
+                     nn.init.constant_(module.bias, 0)
+
+         self.apply(_basic_init)
+
+         # Initialize patch_embed like nn.Linear (instead of nn.Conv3d):
+         w = self.x_embedder.proj.weight.data
+         nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
+
+         # Initialize timestep embedding MLP:
+         nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
+         nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
+         nn.init.normal_(self.t_block[1].weight, std=0.02)
+
+         # Initialize caption embedding MLP:
+         nn.init.normal_(self.y_embedder.y_proj.fc1.weight, std=0.02)
+         nn.init.normal_(self.y_embedder.y_proj.fc2.weight, std=0.02)
+
+         # Zero-out cross-attention output projections in the blocks:
+         for block in self.blocks:
+             nn.init.constant_(block.cross_attn.proj.weight, 0)
+             nn.init.constant_(block.cross_attn.proj.bias, 0)
+
+         # Zero-out output layers:
+         nn.init.constant_(self.final_layer.linear.weight, 0)
+         nn.init.constant_(self.final_layer.linear.bias, 0)
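With the shipped config.json (input_size 16x32x32, patch_size 1x2x2, in_channels 4, pred_sigma true, caption_channels 4096, model_max_length 120), forward consumes a latent video, diffusion timesteps, T5-sized caption embeddings and a token mask, and returns a latent with doubled channels. A hedged sketch of the call, reusing `model` from the AutoModel example above; it assumes xformers is installed (MultiHeadCrossAttention requires it, and its memory-efficient attention typically needs a CUDA device, so move the model and tensors to GPU in practice). The flash-attention and sequence-parallel paths stay disabled here, as in the shipped config, since the sequence-parallel helpers they reference are not part of this upload:

    import torch

    B = 1
    x = torch.randn(B, 4, 16, 32, 32)        # [B, C, T, H, W] video latent
    timestep = torch.randint(0, 1000, (B,))  # [B] diffusion timesteps
    y = torch.randn(B, 1, 120, 4096)         # [B, 1, N_token, caption_channels]
    mask = torch.ones(B, 120, dtype=torch.long)

    with torch.no_grad():
        out = model(x, timestep, y, mask=mask)
    print(out.shape)  # torch.Size([1, 8, 16, 32, 32]) -- 2 * in_channels because pred_sigma is true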
utils.py ADDED
@@ -0,0 +1,80 @@
+ import torch
+ import torch.nn as nn
+ import numpy as np
+
+ approx_gelu = lambda: nn.GELU(approximate="tanh")
+
+ def get_layernorm(hidden_size: int, eps: float, affine: bool, use_kernel: bool):
+     if use_kernel:
+         try:
+             from apex.normalization import FusedLayerNorm
+
+             return FusedLayerNorm(hidden_size, elementwise_affine=affine, eps=eps)
+         except ImportError:
+             raise RuntimeError("FusedLayerNorm not available. Please install apex.")
+     else:
+         return nn.LayerNorm(hidden_size, eps=eps, elementwise_affine=affine)
+
+ def get_1d_sincos_pos_embed(embed_dim, length, scale=1.0):
+     pos = np.arange(0, length)[..., None] / scale
+     return get_1d_sincos_pos_embed_from_grid(embed_dim, pos)
+
+
+ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
+     """
+     embed_dim: output dimension for each position
+     pos: a list of positions to be encoded: size (M,)
+     out: (M, D)
+     """
+     assert embed_dim % 2 == 0
+     omega = np.arange(embed_dim // 2, dtype=np.float64)
+     omega /= embed_dim / 2.0
+     omega = 1.0 / 10000**omega  # (D/2,)
+
+     pos = pos.reshape(-1)  # (M,)
+     out = np.einsum("m,d->md", pos, omega)  # (M, D/2), outer product
+
+     emb_sin = np.sin(out)  # (M, D/2)
+     emb_cos = np.cos(out)  # (M, D/2)
+
+     emb = np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)
+     return emb
+
+
+ def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0, scale=1.0, base_size=None):
+     """
+     grid_size: int or (H, W) tuple of the grid height and width
+     return:
+     pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
+     """
+     if not isinstance(grid_size, tuple):
+         grid_size = (grid_size, grid_size)
+
+     grid_h = np.arange(grid_size[0], dtype=np.float32) / scale
+     grid_w = np.arange(grid_size[1], dtype=np.float32) / scale
+     if base_size is not None:
+         grid_h *= base_size / grid_size[0]
+         grid_w *= base_size / grid_size[1]
+     grid = np.meshgrid(grid_w, grid_h)  # here w goes first
+     grid = np.stack(grid, axis=0)
+
+     grid = grid.reshape([2, 1, grid_size[1], grid_size[0]])
+     pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
+     if cls_token and extra_tokens > 0:
+         pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
+     return pos_embed
+
+
+ def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
+     assert embed_dim % 2 == 0
+
+     # use half of dimensions to encode grid_h
+     emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])  # (H*W, D/2)
+     emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])  # (H*W, D/2)
+
+     emb = np.concatenate([emb_h, emb_w], axis=1)  # (H*W, D)
+     return emb
+
+
+ def t2i_modulate(x, shift, scale):
+     return x * (1 + scale) + shift
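These helpers produce the fixed sinusoidal position tables that STDiT registers as buffers. As a quick check of the shapes implied by the shipped configuration (hidden_size 1152, 32x32 spatial latents with 2x2 spatial patches and space_scale 0.5, 16 frames with temporal patch 1 and time_scale 1.0), assuming the same placeholder "stdit" package as above:

    from stdit.utils import get_1d_sincos_pos_embed, get_2d_sincos_pos_embed

    spatial = get_2d_sincos_pos_embed(1152, (16, 16), scale=0.5)  # 32/2 x 32/2 patch grid
    temporal = get_1d_sincos_pos_embed(1152, 16, scale=1.0)       # 16/1 temporal positions
    print(spatial.shape, temporal.shape)  # (256, 1152) (16, 1152)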