camenduru committed on
Commit cd273ef · 1 Parent(s): 4a273c0

Delete unet_3d_condition.py

Files changed (1)
  1. unet_3d_condition.py +0 -500
unet_3d_condition.py DELETED
@@ -1,500 +0,0 @@
-# Copyright 2023 Alibaba DAMO-VILAB and The HuggingFace Team. All rights reserved.
-# Copyright 2023 The ModelScope Team.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-from dataclasses import dataclass
-from typing import Any, Dict, List, Optional, Tuple, Union
-
-import torch
-import torch.nn as nn
-import torch.utils.checkpoint
-
-from diffusers.configuration_utils import ConfigMixin, register_to_config
-from diffusers.utils import BaseOutput, logging
-from diffusers.models.embeddings import TimestepEmbedding, Timesteps
-from diffusers.models.modeling_utils import ModelMixin
-from diffusers.models.transformer_temporal import TransformerTemporalModel
-from .unet_3d_blocks import (
-    CrossAttnDownBlock3D,
-    CrossAttnUpBlock3D,
-    DownBlock3D,
-    UNetMidBlock3DCrossAttn,
-    UpBlock3D,
-    get_down_block,
-    get_up_block,
-    transformer_g_c
-)
-
-
-logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
-
-
-@dataclass
-class UNet3DConditionOutput(BaseOutput):
-    """
-    Args:
-        sample (`torch.FloatTensor` of shape `(batch_size, num_channels, num_frames, height, width)`):
-            Hidden states conditioned on `encoder_hidden_states` input. Output of last layer of model.
-    """
-
-    sample: torch.FloatTensor
-
-
-class UNet3DConditionModel(ModelMixin, ConfigMixin):
-    r"""
-    UNet3DConditionModel is a conditional 3D UNet model that takes in a noisy sample, conditional state, and a timestep
-    and returns a sample-shaped output.
-
-    This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library
-    implements for all the models (such as downloading or saving, etc.)
-
-    Parameters:
-        sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
-            Height and width of input/output sample.
-        in_channels (`int`, *optional*, defaults to 4): The number of channels in the input sample.
-        out_channels (`int`, *optional*, defaults to 4): The number of channels in the output.
-        down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock3D", "CrossAttnDownBlock3D", "CrossAttnDownBlock3D", "DownBlock3D")`):
-            The tuple of downsample blocks to use.
-        up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D")`):
-            The tuple of upsample blocks to use.
-        block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
-            The tuple of output channels for each block.
-        layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.
-        downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution.
-        mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block.
-        act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
-        norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization.
-            If `None`, the normalization and activation layers in post-processing are skipped.
-        norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
-        cross_attention_dim (`int`, *optional*, defaults to 1024): The dimension of the cross attention features.
-        attention_head_dim (`int` or `Tuple[int]`, *optional*, defaults to 64): The dimension of the attention heads.
-    """
-
-    _supports_gradient_checkpointing = True
-
-    @register_to_config
-    def __init__(
-        self,
-        sample_size: Optional[int] = None,
-        in_channels: int = 4,
-        out_channels: int = 4,
-        down_block_types: Tuple[str] = (
-            "CrossAttnDownBlock3D",
-            "CrossAttnDownBlock3D",
-            "CrossAttnDownBlock3D",
-            "DownBlock3D",
-        ),
-        up_block_types: Tuple[str] = ("UpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D"),
-        block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
-        layers_per_block: int = 2,
-        downsample_padding: int = 1,
-        mid_block_scale_factor: float = 1,
-        act_fn: str = "silu",
-        norm_num_groups: Optional[int] = 32,
-        norm_eps: float = 1e-5,
-        cross_attention_dim: int = 1024,
-        attention_head_dim: Union[int, Tuple[int]] = 64,
-    ):
-        super().__init__()
-
-        self.sample_size = sample_size
-        self.gradient_checkpointing = False
-        # Check inputs
-        if len(down_block_types) != len(up_block_types):
-            raise ValueError(
-                f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
-            )
-
-        if len(block_out_channels) != len(down_block_types):
-            raise ValueError(
-                f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
-            )
-
-        if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types):
-            raise ValueError(
-                f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}."
-            )
-
-        # input
-        conv_in_kernel = 3
-        conv_out_kernel = 3
-        conv_in_padding = (conv_in_kernel - 1) // 2
-        self.conv_in = nn.Conv2d(
-            in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
-        )
-
-        # time
-        time_embed_dim = block_out_channels[0] * 4
-        self.time_proj = Timesteps(block_out_channels[0], True, 0)
-        timestep_input_dim = block_out_channels[0]
-
-        self.time_embedding = TimestepEmbedding(
-            timestep_input_dim,
-            time_embed_dim,
-            act_fn=act_fn,
-        )
-
-        self.transformer_in = TransformerTemporalModel(
-            num_attention_heads=8,
-            attention_head_dim=attention_head_dim,
-            in_channels=block_out_channels[0],
-            num_layers=1,
-        )
-
-        # blocks
-        self.down_blocks = nn.ModuleList([])
-        self.up_blocks = nn.ModuleList([])
-
-        if isinstance(attention_head_dim, int):
-            attention_head_dim = (attention_head_dim,) * len(down_block_types)
-
-        # down
-        output_channel = block_out_channels[0]
-        for i, down_block_type in enumerate(down_block_types):
-            input_channel = output_channel
-            output_channel = block_out_channels[i]
-            is_final_block = i == len(block_out_channels) - 1
-
-            down_block = get_down_block(
-                down_block_type,
-                num_layers=layers_per_block,
-                in_channels=input_channel,
-                out_channels=output_channel,
-                temb_channels=time_embed_dim,
-                add_downsample=not is_final_block,
-                resnet_eps=norm_eps,
-                resnet_act_fn=act_fn,
-                resnet_groups=norm_num_groups,
-                cross_attention_dim=cross_attention_dim,
-                attn_num_head_channels=attention_head_dim[i],
-                downsample_padding=downsample_padding,
-                dual_cross_attention=False,
-            )
-            self.down_blocks.append(down_block)
-
-        # mid
-        self.mid_block = UNetMidBlock3DCrossAttn(
-            in_channels=block_out_channels[-1],
-            temb_channels=time_embed_dim,
-            resnet_eps=norm_eps,
-            resnet_act_fn=act_fn,
-            output_scale_factor=mid_block_scale_factor,
-            cross_attention_dim=cross_attention_dim,
-            attn_num_head_channels=attention_head_dim[-1],
-            resnet_groups=norm_num_groups,
-            dual_cross_attention=False,
-        )
-
-        # count how many layers upsample the images
-        self.num_upsamplers = 0
-
-        # up
-        reversed_block_out_channels = list(reversed(block_out_channels))
-        reversed_attention_head_dim = list(reversed(attention_head_dim))
-
-        output_channel = reversed_block_out_channels[0]
-        for i, up_block_type in enumerate(up_block_types):
-            is_final_block = i == len(block_out_channels) - 1
-
-            prev_output_channel = output_channel
-            output_channel = reversed_block_out_channels[i]
-            input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
-
-            # add upsample block for all BUT final layer
-            if not is_final_block:
-                add_upsample = True
-                self.num_upsamplers += 1
-            else:
-                add_upsample = False
-
-            up_block = get_up_block(
-                up_block_type,
-                num_layers=layers_per_block + 1,
-                in_channels=input_channel,
-                out_channels=output_channel,
-                prev_output_channel=prev_output_channel,
-                temb_channels=time_embed_dim,
-                add_upsample=add_upsample,
-                resnet_eps=norm_eps,
-                resnet_act_fn=act_fn,
-                resnet_groups=norm_num_groups,
-                cross_attention_dim=cross_attention_dim,
-                attn_num_head_channels=reversed_attention_head_dim[i],
-                dual_cross_attention=False,
-            )
-            self.up_blocks.append(up_block)
-            prev_output_channel = output_channel
-
-        # out
-        if norm_num_groups is not None:
-            self.conv_norm_out = nn.GroupNorm(
-                num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps
-            )
-            self.conv_act = nn.SiLU()
-        else:
-            self.conv_norm_out = None
-            self.conv_act = None
-
-        conv_out_padding = (conv_out_kernel - 1) // 2
-        self.conv_out = nn.Conv2d(
-            block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding
-        )
-
-    def set_attention_slice(self, slice_size):
-        r"""
-        Enable sliced attention computation.
-
-        When this option is enabled, the attention module will split the input tensor into slices to compute attention
-        in several steps. This is useful to save some memory in exchange for a small speed decrease.
-
-        Args:
-            slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
-                When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
-                `"max"`, the maximum amount of memory will be saved by running only one slice at a time. If a number is
-                provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
-                must be a multiple of `slice_size`.
-        """
-        sliceable_head_dims = []
-
-        def fn_recursive_retrieve_slicable_dims(module: torch.nn.Module):
-            if hasattr(module, "set_attention_slice"):
-                sliceable_head_dims.append(module.sliceable_head_dim)
-
-            for child in module.children():
-                fn_recursive_retrieve_slicable_dims(child)
-
-        # retrieve number of attention layers
-        for module in self.children():
-            fn_recursive_retrieve_slicable_dims(module)
-
-        num_slicable_layers = len(sliceable_head_dims)
-
-        if slice_size == "auto":
-            # half the attention head size is usually a good trade-off between
-            # speed and memory
-            slice_size = [dim // 2 for dim in sliceable_head_dims]
-        elif slice_size == "max":
-            # make smallest slice possible
-            slice_size = num_slicable_layers * [1]
-
-        slice_size = num_slicable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
-
-        if len(slice_size) != len(sliceable_head_dims):
-            raise ValueError(
-                f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
-                f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
-            )
-
-        for i in range(len(slice_size)):
-            size = slice_size[i]
-            dim = sliceable_head_dims[i]
-            if size is not None and size > dim:
-                raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
-
-        # Recursively walk through all the children.
-        # Any child that exposes the set_attention_slice method
-        # gets the message
-        def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
-            if hasattr(module, "set_attention_slice"):
-                module.set_attention_slice(slice_size.pop())
-
-            for child in module.children():
-                fn_recursive_set_attention_slice(child, slice_size)
-
-        reversed_slice_size = list(reversed(slice_size))
-        for module in self.children():
-            fn_recursive_set_attention_slice(module, reversed_slice_size)
-
-    def _set_gradient_checkpointing(self, value=False):
-        self.gradient_checkpointing = value
-        self.mid_block.gradient_checkpointing = value
-        for module in self.down_blocks + self.up_blocks:
-            if isinstance(module, (CrossAttnDownBlock3D, DownBlock3D, CrossAttnUpBlock3D, UpBlock3D)):
-                module.gradient_checkpointing = value
-
-    def forward(
-        self,
-        sample: torch.FloatTensor,
-        timestep: Union[torch.Tensor, float, int],
-        encoder_hidden_states: torch.Tensor,
-        class_labels: Optional[torch.Tensor] = None,
-        timestep_cond: Optional[torch.Tensor] = None,
-        attention_mask: Optional[torch.Tensor] = None,
-        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
-        down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
-        mid_block_additional_residual: Optional[torch.Tensor] = None,
-        return_dict: bool = True,
-    ) -> Union[UNet3DConditionOutput, Tuple]:
-        r"""
-        Args:
-            sample (`torch.FloatTensor`): (batch, channel, num_frames, height, width) noisy inputs tensor
-            timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps
-            encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states
-            return_dict (`bool`, *optional*, defaults to `True`):
-                Whether or not to return a [`~models.unet_3d_condition.UNet3DConditionOutput`] instead of a plain tuple.
-            cross_attention_kwargs (`dict`, *optional*):
-                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
-                `self.processor` in
-                [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
-
-        Returns:
-            [`~models.unet_3d_condition.UNet3DConditionOutput`] or `tuple`:
-            [`~models.unet_3d_condition.UNet3DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When
-            returning a tuple, the first element is the sample tensor.
-        """
-        # By default samples have to be at least a multiple of the overall upsampling factor.
-        # The overall upsampling factor is equal to 2 ** (number of upsampling layers).
-        # However, the upsampling interpolation output size can be forced to fit any upsampling size
-        # on the fly if necessary.
-        default_overall_up_factor = 2**self.num_upsamplers
-
-        # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
-        forward_upsample_size = False
-        upsample_size = None
-
-        if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
-            logger.info("Forward upsample size to force interpolation output size.")
-            forward_upsample_size = True
-
-        # prepare attention_mask
-        if attention_mask is not None:
-            attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
-            attention_mask = attention_mask.unsqueeze(1)
-
-        # 1. time
-        timesteps = timestep
-        if not torch.is_tensor(timesteps):
-            # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
-            # This would be a good case for the `match` statement (Python 3.10+)
-            is_mps = sample.device.type == "mps"
-            if isinstance(timestep, float):
-                dtype = torch.float32 if is_mps else torch.float64
-            else:
-                dtype = torch.int32 if is_mps else torch.int64
-            timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
-        elif len(timesteps.shape) == 0:
-            timesteps = timesteps[None].to(sample.device)
-
-        # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
-        num_frames = sample.shape[2]
-        timesteps = timesteps.expand(sample.shape[0])
-
-        t_emb = self.time_proj(timesteps)
-
-        # timesteps does not contain any weights and will always return f32 tensors
-        # but time_embedding might actually be running in fp16. so we need to cast here.
-        # there might be better ways to encapsulate this.
-        t_emb = t_emb.to(dtype=self.dtype)
-
-        emb = self.time_embedding(t_emb, timestep_cond)
-        emb = emb.repeat_interleave(repeats=num_frames, dim=0)
-        encoder_hidden_states = encoder_hidden_states.repeat_interleave(repeats=num_frames, dim=0)
-
-        # 2. pre-process
-        sample = sample.permute(0, 2, 1, 3, 4).reshape((sample.shape[0] * num_frames, -1) + sample.shape[3:])
-        sample = self.conv_in(sample)
-
-        if num_frames > 1:
-            if self.gradient_checkpointing:
-                sample = transformer_g_c(self.transformer_in, sample, num_frames)
-            else:
-                sample = self.transformer_in(sample, num_frames=num_frames).sample
-
-        # 3. down
-        down_block_res_samples = (sample,)
-        for downsample_block in self.down_blocks:
-            if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
-                sample, res_samples = downsample_block(
-                    hidden_states=sample,
-                    temb=emb,
-                    encoder_hidden_states=encoder_hidden_states,
-                    attention_mask=attention_mask,
-                    num_frames=num_frames,
-                    cross_attention_kwargs=cross_attention_kwargs,
-                )
-            else:
-                sample, res_samples = downsample_block(hidden_states=sample, temb=emb, num_frames=num_frames)
-
-            down_block_res_samples += res_samples
-
-        if down_block_additional_residuals is not None:
-            new_down_block_res_samples = ()
-
-            for down_block_res_sample, down_block_additional_residual in zip(
-                down_block_res_samples, down_block_additional_residuals
-            ):
-                down_block_res_sample = down_block_res_sample + down_block_additional_residual
-                new_down_block_res_samples += (down_block_res_sample,)
-
-            down_block_res_samples = new_down_block_res_samples
-
-        # 4. mid
-        if self.mid_block is not None:
-            sample = self.mid_block(
-                sample,
-                emb,
-                encoder_hidden_states=encoder_hidden_states,
-                attention_mask=attention_mask,
-                num_frames=num_frames,
-                cross_attention_kwargs=cross_attention_kwargs,
-            )
-
-        if mid_block_additional_residual is not None:
-            sample = sample + mid_block_additional_residual
-
-        # 5. up
-        for i, upsample_block in enumerate(self.up_blocks):
-            is_final_block = i == len(self.up_blocks) - 1
-
-            res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
-            down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
-
-            # if we have not reached the final block and need to forward the
-            # upsample size, we do it here
-            if not is_final_block and forward_upsample_size:
-                upsample_size = down_block_res_samples[-1].shape[2:]
-
-            if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
-                sample = upsample_block(
-                    hidden_states=sample,
-                    temb=emb,
-                    res_hidden_states_tuple=res_samples,
-                    encoder_hidden_states=encoder_hidden_states,
-                    upsample_size=upsample_size,
-                    attention_mask=attention_mask,
-                    num_frames=num_frames,
-                    cross_attention_kwargs=cross_attention_kwargs,
-                )
-            else:
-                sample = upsample_block(
-                    hidden_states=sample,
-                    temb=emb,
-                    res_hidden_states_tuple=res_samples,
-                    upsample_size=upsample_size,
-                    num_frames=num_frames,
-                )
-
-        # 6. post-process
-        if self.conv_norm_out:
-            sample = self.conv_norm_out(sample)
-            sample = self.conv_act(sample)
-
-        sample = self.conv_out(sample)
-
-        # reshape to (batch, channel, num_frames, height, width)
-        sample = sample[None, :].reshape((-1, num_frames) + sample.shape[1:]).permute(0, 2, 1, 3, 4)
-
-        if not return_dict:
-            return (sample,)
-
-        return UNet3DConditionOutput(sample=sample)
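
For reference, a minimal usage sketch of the deleted model's forward pass (not part of this commit). It only illustrates the tensor shapes the code above expects; the package name `models`, the config values, and the shapes below are assumptions for illustration, not code from this repository:

import torch
# Hypothetical layout: unet_3d_condition.py and its companion unet_3d_blocks.py
# placed together in a local package named `models` (assumption).
from models.unet_3d_condition import UNet3DConditionModel

# Default config from the signature above: in_channels=4,
# cross_attention_dim=1024, attention_head_dim=64.
unet = UNet3DConditionModel(sample_size=32)

batch, frames = 1, 8
sample = torch.randn(batch, 4, frames, 32, 32)        # (batch, channel, num_frames, height, width)
timestep = torch.tensor([10])                         # one timestep per batch element
encoder_hidden_states = torch.randn(batch, 77, 1024)  # (batch, sequence_length, cross_attention_dim)

out = unet(sample, timestep, encoder_hidden_states).sample
print(out.shape)  # expected: torch.Size([1, 4, 8, 32, 32])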