ashawkey commited on
Commit
8747d5d
·
1 Parent(s): fbd4c7f

reformat, further clean

Browse files
convert_mvdream_to_diffusers.py CHANGED
@@ -4,7 +4,7 @@ import argparse
4
  import torch
5
  import sys
6
 
7
- sys.path.insert(0, '.')
8
 
9
  from diffusers.models import (
10
  AutoencoderKL,
@@ -15,20 +15,29 @@ from diffusers.utils import logging
15
  from typing import Any
16
  from accelerate import init_empty_weights
17
  from accelerate.utils import set_module_tensor_to_device
18
- from mvdream.models import MultiViewUNetWrapperModel
19
  from mvdream.pipeline_mvdream import MVDreamStableDiffusionPipeline
20
  from transformers import CLIPTokenizer, CLIPTextModel
21
 
22
  logger = logging.get_logger(__name__)
23
 
24
 
25
- def assign_to_checkpoint(paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None):
 
 
 
 
 
 
 
26
  """
27
  This does the final conversion step: take locally converted weights and apply a global renaming to them. It splits
28
  attention layers, and takes into account additional replacements that may arise.
29
  Assigns the weights to the new checkpoint.
30
  """
31
- assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys."
 
 
32
 
33
  # Splits the attention layers into three variables.
34
  if attention_paths_to_split is not None:
@@ -41,7 +50,9 @@ def assign_to_checkpoint(paths, checkpoint, old_checkpoint, attention_paths_to_s
41
  assert config is not None
42
  num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3
43
 
44
- old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:])
 
 
45
  query, key, value = old_tensor.split(channels // num_heads, dim=1)
46
 
47
  checkpoint[path_map["query"]] = query.reshape(target_shape)
@@ -52,7 +63,10 @@ def assign_to_checkpoint(paths, checkpoint, old_checkpoint, attention_paths_to_s
52
  new_path = path["new"]
53
 
54
  # These have already been assigned
55
- if attention_paths_to_split is not None and new_path in attention_paths_to_split:
 
 
 
56
  continue
57
 
58
  # Global renaming happens here
@@ -65,7 +79,9 @@ def assign_to_checkpoint(paths, checkpoint, old_checkpoint, attention_paths_to_s
65
  new_path = new_path.replace(replacement["old"], replacement["new"])
66
 
67
  # proj_attn.weight has to be converted from conv 1D to linear
68
- is_attn_weight = "proj_attn.weight" in new_path or ("attentions" in new_path and "to_" in new_path)
 
 
69
  shape = old_checkpoint[path["old"]].shape
70
  if is_attn_weight and len(shape) == 3:
71
  checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0]
@@ -122,17 +138,29 @@ def convert_ldm_vae_checkpoint(checkpoint, config):
122
 
123
  new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"]
124
  new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"]
125
- new_checkpoint["encoder.conv_out.weight"] = vae_state_dict["encoder.conv_out.weight"]
 
 
126
  new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"]
127
- new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict["encoder.norm_out.weight"]
128
- new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict["encoder.norm_out.bias"]
 
 
 
 
129
 
130
  new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"]
131
  new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"]
132
- new_checkpoint["decoder.conv_out.weight"] = vae_state_dict["decoder.conv_out.weight"]
 
 
133
  new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"]
134
- new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict["decoder.norm_out.weight"]
135
- new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict["decoder.norm_out.bias"]
 
 
 
 
136
 
137
  new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"]
138
  new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"]
@@ -140,23 +168,55 @@ def convert_ldm_vae_checkpoint(checkpoint, config):
140
  new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"]
141
 
142
  # Retrieves the keys for the encoder down blocks only
143
- num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer})
144
- down_blocks = {layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks)}
 
 
 
 
 
 
 
 
 
145
 
146
  # Retrieves the keys for the decoder up blocks only
147
- num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "decoder.up" in layer})
148
- up_blocks = {layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks)}
 
 
 
 
 
 
 
 
 
149
 
150
  for i in range(num_down_blocks):
151
- resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key]
 
 
 
 
152
 
153
  if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict:
154
- new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop(f"encoder.down.{i}.downsample.conv.weight")
155
- new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop(f"encoder.down.{i}.downsample.conv.bias")
 
 
 
 
156
 
157
  paths = renew_vae_resnet_paths(resnets)
158
  meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"}
159
- assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
 
 
 
 
 
 
160
 
161
  mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key]
162
  num_mid_res_blocks = 2
@@ -165,25 +225,51 @@ def convert_ldm_vae_checkpoint(checkpoint, config):
165
 
166
  paths = renew_vae_resnet_paths(resnets)
167
  meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
168
- assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
 
 
 
 
 
 
169
 
170
  mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key]
171
  paths = renew_vae_attention_paths(mid_attentions)
172
  meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
173
- assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
 
 
 
 
 
 
174
  conv_attn_to_linear(new_checkpoint)
175
 
176
  for i in range(num_up_blocks):
177
  block_id = num_up_blocks - 1 - i
178
- resnets = [key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key]
 
 
 
 
179
 
180
  if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict:
181
- new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[f"decoder.up.{block_id}.upsample.conv.weight"]
182
- new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[f"decoder.up.{block_id}.upsample.conv.bias"]
 
 
 
 
183
 
184
  paths = renew_vae_resnet_paths(resnets)
185
  meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"}
186
- assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
 
 
 
 
 
 
187
 
188
  mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key]
189
  num_mid_res_blocks = 2
@@ -192,12 +278,24 @@ def convert_ldm_vae_checkpoint(checkpoint, config):
192
 
193
  paths = renew_vae_resnet_paths(resnets)
194
  meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
195
- assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
 
 
 
 
 
 
196
 
197
  mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key]
198
  paths = renew_vae_attention_paths(mid_attentions)
199
  meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
200
- assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
 
 
 
 
 
 
201
  conv_attn_to_linear(new_checkpoint)
202
  return new_checkpoint
203
 
@@ -211,7 +309,9 @@ def renew_vae_resnet_paths(old_list, n_shave_prefix_segments=0):
211
  new_item = old_item
212
 
213
  new_item = new_item.replace("nin_shortcut", "conv_shortcut")
214
- new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
 
 
215
 
216
  mapping.append({"old": old_item, "new": new_item})
217
 
@@ -241,7 +341,9 @@ def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0):
241
  new_item = new_item.replace("proj_out.weight", "to_out.0.weight")
242
  new_item = new_item.replace("proj_out.bias", "to_out.0.bias")
243
 
244
- new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
 
 
245
 
246
  mapping.append({"old": old_item, "new": new_item})
247
 
@@ -259,8 +361,12 @@ def conv_attn_to_linear(checkpoint):
259
  if checkpoint[key].ndim > 2:
260
  checkpoint[key] = checkpoint[key][:, :, 0]
261
 
 
262
  def create_unet_config(original_config) -> Any:
263
- return OmegaConf.to_container(original_config.model.params.unet_config.params, resolve=True)
 
 
 
264
 
265
  def convert_from_original_mvdream_ckpt(checkpoint_path, original_config_file, device):
266
  checkpoint = torch.load(checkpoint_path, map_location=device)
@@ -271,7 +377,9 @@ def convert_from_original_mvdream_ckpt(checkpoint_path, original_config_file, de
271
  # print(f"Original Config: {original_config}")
272
  prediction_type = "epsilon"
273
  image_size = 256
274
- num_train_timesteps = getattr(original_config.model.params, "timesteps", None) or 1000
 
 
275
  beta_start = getattr(original_config.model.params, "linear_start", None) or 0.02
276
  beta_end = getattr(original_config.model.params, "linear_end", None) or 0.085
277
  scheduler = DDIMScheduler(
@@ -297,10 +405,16 @@ def convert_from_original_mvdream_ckpt(checkpoint_path, original_config_file, de
297
  # )
298
  # print(f"Unet Config: {original_config.model.params.unet_config.params}")
299
  unet_config = create_unet_config(original_config)
300
- unet: MultiViewUNetWrapperModel = MultiViewUNetWrapperModel(**unet_config)
301
  unet.register_to_config(**unet_config)
302
  # print(f"Unet State Dict: {unet.state_dict().keys()}")
303
- unet.load_state_dict({key.replace("model.diffusion_model.", "unet."): value for key, value in checkpoint.items() if key.replace("model.diffusion_model.", "unet.") in unet.state_dict()})
 
 
 
 
 
 
304
  for param_name, param in unet.state_dict().items():
305
  set_module_tensor_to_device(unet, param_name, device=device, value=param)
306
 
@@ -308,10 +422,14 @@ def convert_from_original_mvdream_ckpt(checkpoint_path, original_config_file, de
308
  vae_config = create_vae_diffusers_config(original_config, image_size=image_size)
309
  converted_vae_checkpoint = convert_ldm_vae_checkpoint(checkpoint, vae_config)
310
 
311
- if ("model" in original_config and "params" in original_config.model and "scale_factor" in original_config.model.params):
 
 
 
 
312
  vae_scaling_factor = original_config.model.params.scale_factor
313
  else:
314
- vae_scaling_factor = 0.18215 # default SD scaling factor
315
 
316
  vae_config["scaling_factor"] = vae_scaling_factor
317
 
@@ -322,13 +440,19 @@ def convert_from_original_mvdream_ckpt(checkpoint_path, original_config_file, de
322
  set_module_tensor_to_device(vae, param_name, device=device, value=param)
323
 
324
  if original_config.model.params.unet_config.params.context_dim == 768:
325
- tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
326
- text_encoder: CLIPTextModel = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14").to(device=device) # type: ignore
 
 
327
  elif original_config.model.params.unet_config.params.context_dim == 1024:
328
- tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained("stabilityai/stable-diffusion-2-1", subfolder="tokenizer")
329
- text_encoder: CLIPTextModel = CLIPTextModel.from_pretrained("stabilityai/stable-diffusion-2-1", subfolder="text_encoder").to(device=device) # type: ignore
 
 
330
  else:
331
- raise ValueError(f"Unknown context_dim: {original_config.model.paams.unet_config.params.context_dim}")
 
 
332
 
333
  pipe = MVDreamStableDiffusionPipeline(
334
  vae=vae,
@@ -344,7 +468,13 @@ def convert_from_original_mvdream_ckpt(checkpoint_path, original_config_file, de
344
  if __name__ == "__main__":
345
  parser = argparse.ArgumentParser()
346
 
347
- parser.add_argument("--checkpoint_path", default=None, type=str, required=True, help="Path to the checkpoint to convert.")
 
 
 
 
 
 
348
  parser.add_argument(
349
  "--original_config_file",
350
  default=None,
@@ -356,13 +486,33 @@ if __name__ == "__main__":
356
  action="store_true",
357
  help="Whether to store pipeline in safetensors format or not.",
358
  )
359
- parser.add_argument("--half", action="store_true", help="Save weights in half precision.")
360
- parser.add_argument("--test", action="store_true", help="Whether to test inference after convertion.")
361
- parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.")
362
- parser.add_argument("--device", type=str, help="Device to use (e.g. cpu, cuda:0, cuda:1, etc.)")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
363
  args = parser.parse_args()
364
-
365
- args.device = torch.device(args.device if args.device is not None else "cuda" if torch.cuda.is_available() else "cpu")
 
 
 
 
 
 
366
 
367
  pipe = convert_from_original_mvdream_ckpt(
368
  checkpoint_path=args.checkpoint_path,
@@ -375,7 +525,7 @@ if __name__ == "__main__":
375
 
376
  print(f"Saving pipeline to {args.dump_path}...")
377
  pipe.save_pretrained(args.dump_path, safe_serialization=args.to_safetensors)
378
-
379
  if args.test:
380
  try:
381
  print(f"Testing each subcomponent of the pipeline...")
@@ -388,10 +538,10 @@ if __name__ == "__main__":
388
  device=args.device,
389
  )
390
  for i, image in enumerate(images):
391
- image.save(f"image_{i}.png") # type: ignore
392
 
393
  print(f"Testing entire pipeline...")
394
- loaded_pipe: MVDreamStableDiffusionPipeline = MVDreamStableDiffusionPipeline.from_pretrained(args.dump_path, safe_serialization=args.to_safetensors) # type: ignore
395
  images = loaded_pipe(
396
  prompt="Head of Hatsune Miku",
397
  negative_prompt="painting, bad quality, flat",
@@ -401,7 +551,7 @@ if __name__ == "__main__":
401
  device=args.device,
402
  )
403
  for i, image in enumerate(images):
404
- image.save(f"image_{i}.png") # type: ignore
405
  except Exception as e:
406
  print(f"Failed to test inference: {e}")
407
  raise e from e
 
4
  import torch
5
  import sys
6
 
7
+ sys.path.insert(0, ".")
8
 
9
  from diffusers.models import (
10
  AutoencoderKL,
 
15
  from typing import Any
16
  from accelerate import init_empty_weights
17
  from accelerate.utils import set_module_tensor_to_device
18
+ from mvdream.models import MultiViewUNetModel
19
  from mvdream.pipeline_mvdream import MVDreamStableDiffusionPipeline
20
  from transformers import CLIPTokenizer, CLIPTextModel
21
 
22
  logger = logging.get_logger(__name__)
23
 
24
 
25
+ def assign_to_checkpoint(
26
+ paths,
27
+ checkpoint,
28
+ old_checkpoint,
29
+ attention_paths_to_split=None,
30
+ additional_replacements=None,
31
+ config=None,
32
+ ):
33
  """
34
  This does the final conversion step: take locally converted weights and apply a global renaming to them. It splits
35
  attention layers, and takes into account additional replacements that may arise.
36
  Assigns the weights to the new checkpoint.
37
  """
38
+ assert isinstance(
39
+ paths, list
40
+ ), "Paths should be a list of dicts containing 'old' and 'new' keys."
41
 
42
  # Splits the attention layers into three variables.
43
  if attention_paths_to_split is not None:
 
50
  assert config is not None
51
  num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3
52
 
53
+ old_tensor = old_tensor.reshape(
54
+ (num_heads, 3 * channels // num_heads) + old_tensor.shape[1:]
55
+ )
56
  query, key, value = old_tensor.split(channels // num_heads, dim=1)
57
 
58
  checkpoint[path_map["query"]] = query.reshape(target_shape)
 
63
  new_path = path["new"]
64
 
65
  # These have already been assigned
66
+ if (
67
+ attention_paths_to_split is not None
68
+ and new_path in attention_paths_to_split
69
+ ):
70
  continue
71
 
72
  # Global renaming happens here
 
79
  new_path = new_path.replace(replacement["old"], replacement["new"])
80
 
81
  # proj_attn.weight has to be converted from conv 1D to linear
82
+ is_attn_weight = "proj_attn.weight" in new_path or (
83
+ "attentions" in new_path and "to_" in new_path
84
+ )
85
  shape = old_checkpoint[path["old"]].shape
86
  if is_attn_weight and len(shape) == 3:
87
  checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0]
 
138
 
139
  new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"]
140
  new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"]
141
+ new_checkpoint["encoder.conv_out.weight"] = vae_state_dict[
142
+ "encoder.conv_out.weight"
143
+ ]
144
  new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"]
145
+ new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict[
146
+ "encoder.norm_out.weight"
147
+ ]
148
+ new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict[
149
+ "encoder.norm_out.bias"
150
+ ]
151
 
152
  new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"]
153
  new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"]
154
+ new_checkpoint["decoder.conv_out.weight"] = vae_state_dict[
155
+ "decoder.conv_out.weight"
156
+ ]
157
  new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"]
158
+ new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict[
159
+ "decoder.norm_out.weight"
160
+ ]
161
+ new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict[
162
+ "decoder.norm_out.bias"
163
+ ]
164
 
165
  new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"]
166
  new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"]
 
168
  new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"]
169
 
170
  # Retrieves the keys for the encoder down blocks only
171
+ num_down_blocks = len(
172
+ {
173
+ ".".join(layer.split(".")[:3])
174
+ for layer in vae_state_dict
175
+ if "encoder.down" in layer
176
+ }
177
+ )
178
+ down_blocks = {
179
+ layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key]
180
+ for layer_id in range(num_down_blocks)
181
+ }
182
 
183
  # Retrieves the keys for the decoder up blocks only
184
+ num_up_blocks = len(
185
+ {
186
+ ".".join(layer.split(".")[:3])
187
+ for layer in vae_state_dict
188
+ if "decoder.up" in layer
189
+ }
190
+ )
191
+ up_blocks = {
192
+ layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key]
193
+ for layer_id in range(num_up_blocks)
194
+ }
195
 
196
  for i in range(num_down_blocks):
197
+ resnets = [
198
+ key
199
+ for key in down_blocks[i]
200
+ if f"down.{i}" in key and f"down.{i}.downsample" not in key
201
+ ]
202
 
203
  if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict:
204
+ new_checkpoint[
205
+ f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"
206
+ ] = vae_state_dict.pop(f"encoder.down.{i}.downsample.conv.weight")
207
+ new_checkpoint[
208
+ f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"
209
+ ] = vae_state_dict.pop(f"encoder.down.{i}.downsample.conv.bias")
210
 
211
  paths = renew_vae_resnet_paths(resnets)
212
  meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"}
213
+ assign_to_checkpoint(
214
+ paths,
215
+ new_checkpoint,
216
+ vae_state_dict,
217
+ additional_replacements=[meta_path],
218
+ config=config,
219
+ )
220
 
221
  mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key]
222
  num_mid_res_blocks = 2
 
225
 
226
  paths = renew_vae_resnet_paths(resnets)
227
  meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
228
+ assign_to_checkpoint(
229
+ paths,
230
+ new_checkpoint,
231
+ vae_state_dict,
232
+ additional_replacements=[meta_path],
233
+ config=config,
234
+ )
235
 
236
  mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key]
237
  paths = renew_vae_attention_paths(mid_attentions)
238
  meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
239
+ assign_to_checkpoint(
240
+ paths,
241
+ new_checkpoint,
242
+ vae_state_dict,
243
+ additional_replacements=[meta_path],
244
+ config=config,
245
+ )
246
  conv_attn_to_linear(new_checkpoint)
247
 
248
  for i in range(num_up_blocks):
249
  block_id = num_up_blocks - 1 - i
250
+ resnets = [
251
+ key
252
+ for key in up_blocks[block_id]
253
+ if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key
254
+ ]
255
 
256
  if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict:
257
+ new_checkpoint[
258
+ f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"
259
+ ] = vae_state_dict[f"decoder.up.{block_id}.upsample.conv.weight"]
260
+ new_checkpoint[
261
+ f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"
262
+ ] = vae_state_dict[f"decoder.up.{block_id}.upsample.conv.bias"]
263
 
264
  paths = renew_vae_resnet_paths(resnets)
265
  meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"}
266
+ assign_to_checkpoint(
267
+ paths,
268
+ new_checkpoint,
269
+ vae_state_dict,
270
+ additional_replacements=[meta_path],
271
+ config=config,
272
+ )
273
 
274
  mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key]
275
  num_mid_res_blocks = 2
 
278
 
279
  paths = renew_vae_resnet_paths(resnets)
280
  meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
281
+ assign_to_checkpoint(
282
+ paths,
283
+ new_checkpoint,
284
+ vae_state_dict,
285
+ additional_replacements=[meta_path],
286
+ config=config,
287
+ )
288
 
289
  mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key]
290
  paths = renew_vae_attention_paths(mid_attentions)
291
  meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
292
+ assign_to_checkpoint(
293
+ paths,
294
+ new_checkpoint,
295
+ vae_state_dict,
296
+ additional_replacements=[meta_path],
297
+ config=config,
298
+ )
299
  conv_attn_to_linear(new_checkpoint)
300
  return new_checkpoint
301
 
 
309
  new_item = old_item
310
 
311
  new_item = new_item.replace("nin_shortcut", "conv_shortcut")
312
+ new_item = shave_segments(
313
+ new_item, n_shave_prefix_segments=n_shave_prefix_segments
314
+ )
315
 
316
  mapping.append({"old": old_item, "new": new_item})
317
 
 
341
  new_item = new_item.replace("proj_out.weight", "to_out.0.weight")
342
  new_item = new_item.replace("proj_out.bias", "to_out.0.bias")
343
 
344
+ new_item = shave_segments(
345
+ new_item, n_shave_prefix_segments=n_shave_prefix_segments
346
+ )
347
 
348
  mapping.append({"old": old_item, "new": new_item})
349
 
 
361
  if checkpoint[key].ndim > 2:
362
  checkpoint[key] = checkpoint[key][:, :, 0]
363
 
364
+
365
  def create_unet_config(original_config) -> Any:
366
+ return OmegaConf.to_container(
367
+ original_config.model.params.unet_config.params, resolve=True
368
+ )
369
+
370
 
371
  def convert_from_original_mvdream_ckpt(checkpoint_path, original_config_file, device):
372
  checkpoint = torch.load(checkpoint_path, map_location=device)
 
377
  # print(f"Original Config: {original_config}")
378
  prediction_type = "epsilon"
379
  image_size = 256
380
+ num_train_timesteps = (
381
+ getattr(original_config.model.params, "timesteps", None) or 1000
382
+ )
383
  beta_start = getattr(original_config.model.params, "linear_start", None) or 0.02
384
  beta_end = getattr(original_config.model.params, "linear_end", None) or 0.085
385
  scheduler = DDIMScheduler(
 
405
  # )
406
  # print(f"Unet Config: {original_config.model.params.unet_config.params}")
407
  unet_config = create_unet_config(original_config)
408
+ unet = MultiViewUNetModel(**unet_config)
409
  unet.register_to_config(**unet_config)
410
  # print(f"Unet State Dict: {unet.state_dict().keys()}")
411
+ unet.load_state_dict(
412
+ {
413
+ key.replace("model.diffusion_model.", ""): value
414
+ for key, value in checkpoint.items()
415
+ if key.replace("model.diffusion_model.", "") in unet.state_dict()
416
+ }
417
+ )
418
  for param_name, param in unet.state_dict().items():
419
  set_module_tensor_to_device(unet, param_name, device=device, value=param)
420
 
 
422
  vae_config = create_vae_diffusers_config(original_config, image_size=image_size)
423
  converted_vae_checkpoint = convert_ldm_vae_checkpoint(checkpoint, vae_config)
424
 
425
+ if (
426
+ "model" in original_config
427
+ and "params" in original_config.model
428
+ and "scale_factor" in original_config.model.params
429
+ ):
430
  vae_scaling_factor = original_config.model.params.scale_factor
431
  else:
432
+ vae_scaling_factor = 0.18215 # default SD scaling factor
433
 
434
  vae_config["scaling_factor"] = vae_scaling_factor
435
 
 
440
  set_module_tensor_to_device(vae, param_name, device=device, value=param)
441
 
442
  if original_config.model.params.unet_config.params.context_dim == 768:
443
+ tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained(
444
+ "openai/clip-vit-large-patch14"
445
+ )
446
+ text_encoder: CLIPTextModel = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14").to(device=device) # type: ignore
447
  elif original_config.model.params.unet_config.params.context_dim == 1024:
448
+ tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained(
449
+ "stabilityai/stable-diffusion-2-1", subfolder="tokenizer"
450
+ )
451
+ text_encoder: CLIPTextModel = CLIPTextModel.from_pretrained("stabilityai/stable-diffusion-2-1", subfolder="text_encoder").to(device=device) # type: ignore
452
  else:
453
+ raise ValueError(
454
+ f"Unknown context_dim: {original_config.model.paams.unet_config.params.context_dim}"
455
+ )
456
 
457
  pipe = MVDreamStableDiffusionPipeline(
458
  vae=vae,
 
468
  if __name__ == "__main__":
469
  parser = argparse.ArgumentParser()
470
 
471
+ parser.add_argument(
472
+ "--checkpoint_path",
473
+ default=None,
474
+ type=str,
475
+ required=True,
476
+ help="Path to the checkpoint to convert.",
477
+ )
478
  parser.add_argument(
479
  "--original_config_file",
480
  default=None,
 
486
  action="store_true",
487
  help="Whether to store pipeline in safetensors format or not.",
488
  )
489
+ parser.add_argument(
490
+ "--half", action="store_true", help="Save weights in half precision."
491
+ )
492
+ parser.add_argument(
493
+ "--test",
494
+ action="store_true",
495
+ help="Whether to test inference after convertion.",
496
+ )
497
+ parser.add_argument(
498
+ "--dump_path",
499
+ default=None,
500
+ type=str,
501
+ required=True,
502
+ help="Path to the output model.",
503
+ )
504
+ parser.add_argument(
505
+ "--device", type=str, help="Device to use (e.g. cpu, cuda:0, cuda:1, etc.)"
506
+ )
507
  args = parser.parse_args()
508
+
509
+ args.device = torch.device(
510
+ args.device
511
+ if args.device is not None
512
+ else "cuda"
513
+ if torch.cuda.is_available()
514
+ else "cpu"
515
+ )
516
 
517
  pipe = convert_from_original_mvdream_ckpt(
518
  checkpoint_path=args.checkpoint_path,
 
525
 
526
  print(f"Saving pipeline to {args.dump_path}...")
527
  pipe.save_pretrained(args.dump_path, safe_serialization=args.to_safetensors)
528
+
529
  if args.test:
530
  try:
531
  print(f"Testing each subcomponent of the pipeline...")
 
538
  device=args.device,
539
  )
540
  for i, image in enumerate(images):
541
+ image.save(f"image_{i}.png") # type: ignore
542
 
543
  print(f"Testing entire pipeline...")
544
+ loaded_pipe: MVDreamStableDiffusionPipeline = MVDreamStableDiffusionPipeline.from_pretrained(args.dump_path, safe_serialization=args.to_safetensors) # type: ignore
545
  images = loaded_pipe(
546
  prompt="Head of Hatsune Miku",
547
  negative_prompt="painting, bad quality, flat",
 
551
  device=args.device,
552
  )
553
  for i, image in enumerate(images):
554
+ image.save(f"image_{i}.png") # type: ignore
555
  except Exception as e:
556
  print(f"Failed to test inference: {e}")
557
  raise e from e
main.py CHANGED
@@ -4,18 +4,25 @@ import numpy as np
4
  import argparse
5
  from mvdream.pipeline_mvdream import MVDreamStableDiffusionPipeline
6
 
7
- pipe = MVDreamStableDiffusionPipeline.from_pretrained('./weights', torch_dtype=torch.float16)
 
 
 
 
8
  pipe = pipe.to("cuda")
9
 
10
 
11
- parser = argparse.ArgumentParser(description='MVDream')
12
- parser.add_argument('prompt', type=str, default="a cute owl 3d model")
13
  args = parser.parse_args()
14
 
15
  while True:
16
  image = pipe(args.prompt)
17
- grid = np.concatenate([
18
- np.concatenate([image[0], image[2]], axis=0),
19
- np.concatenate([image[1], image[3]], axis=0),
20
- ], axis=1)
21
- kiui.vis.plot_image(grid)
 
 
 
 
4
  import argparse
5
  from mvdream.pipeline_mvdream import MVDreamStableDiffusionPipeline
6
 
7
+ pipe = MVDreamStableDiffusionPipeline.from_pretrained(
8
+ # "./weights", # local weights
9
+ "ashawkey/mvdream-sd2.1-diffusers",
10
+ torch_dtype=torch.float16
11
+ )
12
  pipe = pipe.to("cuda")
13
 
14
 
15
+ parser = argparse.ArgumentParser(description="MVDream")
16
+ parser.add_argument("prompt", type=str, default="a cute owl 3d model")
17
  args = parser.parse_args()
18
 
19
  while True:
20
  image = pipe(args.prompt)
21
+ grid = np.concatenate(
22
+ [
23
+ np.concatenate([image[0], image[2]], axis=0),
24
+ np.concatenate([image[1], image[3]], axis=0),
25
+ ],
26
+ axis=1,
27
+ )
28
+ kiui.vis.plot_image(grid)
mvdream/attention.py CHANGED
@@ -12,8 +12,9 @@ from typing import Optional, Any
12
  from .util import checkpoint
13
 
14
  try:
15
- import xformers # type: ignore
16
- import xformers.ops # type: ignore
 
17
  XFORMERS_IS_AVAILBLE = True
18
  except:
19
  XFORMERS_IS_AVAILBLE = False
@@ -47,7 +48,6 @@ def init_(tensor):
47
 
48
  # feedforward
49
  class GEGLU(nn.Module):
50
-
51
  def __init__(self, dim_in, dim_out):
52
  super().__init__()
53
  self.proj = nn.Linear(dim_in, dim_out * 2)
@@ -58,14 +58,19 @@ class GEGLU(nn.Module):
58
 
59
 
60
  class FeedForward(nn.Module):
61
-
62
- def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
63
  super().__init__()
64
  inner_dim = int(dim * mult)
65
  dim_out = default(dim_out, dim)
66
- project_in = nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU()) if not glu else GEGLU(dim, inner_dim)
 
 
 
 
67
 
68
- self.net = nn.Sequential(project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out))
 
 
69
 
70
  def forward(self, x):
71
  return self.net(x)
@@ -81,20 +86,29 @@ def zero_module(module):
81
 
82
 
83
  def Normalize(in_channels):
84
- return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
 
 
85
 
86
 
87
  class SpatialSelfAttention(nn.Module):
88
-
89
  def __init__(self, in_channels):
90
  super().__init__()
91
  self.in_channels = in_channels
92
 
93
  self.norm = Normalize(in_channels)
94
- self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
95
- self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
96
- self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
97
- self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
 
 
 
 
 
 
 
 
98
 
99
  def forward(self, x):
100
  h_ = x
@@ -105,26 +119,25 @@ class SpatialSelfAttention(nn.Module):
105
 
106
  # compute attention
107
  b, c, h, w = q.shape
108
- q = rearrange(q, 'b c h w -> b (h w) c')
109
- k = rearrange(k, 'b c h w -> b c (h w)')
110
- w_ = torch.einsum('bij,bjk->bik', q, k)
111
 
112
- w_ = w_ * (int(c)**(-0.5))
113
  w_ = torch.nn.functional.softmax(w_, dim=2)
114
 
115
  # attend to values
116
- v = rearrange(v, 'b c h w -> b c (h w)')
117
- w_ = rearrange(w_, 'b i j -> b j i')
118
- h_ = torch.einsum('bij,bjk->bik', v, w_)
119
- h_ = rearrange(h_, 'b c (h w) -> b c h w', h=h)
120
  h_ = self.proj_out(h_)
121
 
122
  return x + h_
123
 
124
 
125
  class CrossAttention(nn.Module):
126
-
127
- def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.):
128
  super().__init__()
129
  inner_dim = dim_head * heads
130
  context_dim = default(context_dim, query_dim)
@@ -136,7 +149,9 @@ class CrossAttention(nn.Module):
136
  self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
137
  self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
138
 
139
- self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))
 
 
140
 
141
  def forward(self, x, context=None, mask=None):
142
  h = self.heads
@@ -146,29 +161,29 @@ class CrossAttention(nn.Module):
146
  k = self.to_k(context)
147
  v = self.to_v(context)
148
 
149
- q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
150
 
151
  # force cast to fp32 to avoid overflowing
152
  if _ATTN_PRECISION == "fp32":
153
- with autocast(enabled=False, device_type='cuda'):
154
  q, k = q.float(), k.float()
155
- sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
156
  else:
157
- sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
158
 
159
  del q, k
160
 
161
  if mask is not None:
162
- mask = rearrange(mask, 'b ... -> b (...)')
163
  max_neg_value = -torch.finfo(sim.dtype).max
164
- mask = repeat(mask, 'b j -> (b h) () j', h=h)
165
  sim.masked_fill_(~mask, max_neg_value)
166
 
167
  # attention, what we cannot get enough of
168
  sim = sim.softmax(dim=-1)
169
 
170
- out = einsum('b i j, b j d -> b i d', sim, v)
171
- out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
172
  return self.to_out(out)
173
 
174
 
@@ -187,7 +202,9 @@ class MemoryEfficientCrossAttention(nn.Module):
187
  self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
188
  self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
189
 
190
- self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))
 
 
191
  self.attention_op: Optional[Any] = None
192
 
193
  def forward(self, x, context=None, mask=None):
@@ -198,44 +215,84 @@ class MemoryEfficientCrossAttention(nn.Module):
198
 
199
  b, _, _ = q.shape
200
  q, k, v = map(
201
- lambda t: t.unsqueeze(3).reshape(b, t.shape[1], self.heads, self.dim_head).permute(0, 2, 1, 3).reshape(b * self.heads, t.shape[1], self.dim_head).contiguous(),
 
 
 
 
202
  (q, k, v),
203
  )
204
 
205
  # actually compute the attention, what we cannot get enough of
206
- out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op)
 
 
207
 
208
  if mask is not None:
209
  raise NotImplementedError
210
- out = (out.unsqueeze(0).reshape(b, self.heads, out.shape[1], self.dim_head).permute(0, 2, 1, 3).reshape(b, out.shape[1], self.heads * self.dim_head))
 
 
 
 
 
211
  return self.to_out(out)
212
 
213
 
214
  class BasicTransformerBlock(nn.Module):
215
  ATTENTION_MODES = {
216
- "softmax": CrossAttention, # vanilla attention
217
- "softmax-xformers": MemoryEfficientCrossAttention
218
- }
219
-
220
- def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True, disable_self_attn=False):
 
 
 
 
 
 
 
 
 
 
221
  super().__init__()
222
  attn_mode = "softmax-xformers" if XFORMERS_IS_AVAILBLE else "softmax"
223
  assert attn_mode in self.ATTENTION_MODES
224
  attn_cls = self.ATTENTION_MODES[attn_mode]
225
  self.disable_self_attn = disable_self_attn
226
- self.attn1 = attn_cls(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout, context_dim=context_dim if self.disable_self_attn else None) # is a self-attention if not self.disable_self_attn
 
 
 
 
 
 
227
  self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
228
- self.attn2 = attn_cls(query_dim=dim, context_dim=context_dim, heads=n_heads, dim_head=d_head, dropout=dropout) # is self-attn if context is none
 
 
 
 
 
 
229
  self.norm1 = nn.LayerNorm(dim)
230
  self.norm2 = nn.LayerNorm(dim)
231
  self.norm3 = nn.LayerNorm(dim)
232
  self.checkpoint = checkpoint
233
 
234
  def forward(self, x, context=None):
235
- return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint)
 
 
236
 
237
  def _forward(self, x, context=None):
238
- x = self.attn1(self.norm1(x), context=context if self.disable_self_attn else None) + x
 
 
 
 
 
239
  x = self.attn2(self.norm2(x), context=context) + x
240
  x = self.ff(self.norm3(x)) + x
241
  return x
@@ -251,7 +308,18 @@ class SpatialTransformer(nn.Module):
251
  NEW: use_linear for more efficiency instead of the 1x1 convs
252
  """
253
 
254
- def __init__(self, in_channels, n_heads, d_head, depth=1, dropout=0., context_dim=None, disable_self_attn=False, use_linear=False, use_checkpoint=True):
 
 
 
 
 
 
 
 
 
 
 
255
  super().__init__()
256
  assert context_dim is not None
257
  if not isinstance(context_dim, list):
@@ -260,13 +328,30 @@ class SpatialTransformer(nn.Module):
260
  inner_dim = n_heads * d_head
261
  self.norm = Normalize(in_channels)
262
  if not use_linear:
263
- self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
 
 
264
  else:
265
  self.proj_in = nn.Linear(in_channels, inner_dim)
266
 
267
- self.transformer_blocks = nn.ModuleList([BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim[d], disable_self_attn=disable_self_attn, checkpoint=use_checkpoint) for d in range(depth)])
 
 
 
 
 
 
 
 
 
 
 
 
 
268
  if not use_linear:
269
- self.proj_out = zero_module(nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0))
 
 
270
  else:
271
  self.proj_out = zero_module(nn.Linear(in_channels, inner_dim))
272
  self.use_linear = use_linear
@@ -280,27 +365,33 @@ class SpatialTransformer(nn.Module):
280
  x = self.norm(x)
281
  if not self.use_linear:
282
  x = self.proj_in(x)
283
- x = rearrange(x, 'b c h w -> b (h w) c').contiguous()
284
  if self.use_linear:
285
  x = self.proj_in(x)
286
  for i, block in enumerate(self.transformer_blocks):
287
  x = block(x, context=context[i])
288
  if self.use_linear:
289
  x = self.proj_out(x)
290
- x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w).contiguous()
291
  if not self.use_linear:
292
  x = self.proj_out(x)
293
  return x + x_in
294
 
295
 
296
  class BasicTransformerBlock3D(BasicTransformerBlock):
297
-
298
  def forward(self, x, context=None, num_frames=1):
299
- return checkpoint(self._forward, (x, context, num_frames), self.parameters(), self.checkpoint)
 
 
300
 
301
  def _forward(self, x, context=None, num_frames=1):
302
  x = rearrange(x, "(b f) l c -> b (f l) c", f=num_frames).contiguous()
303
- x = self.attn1(self.norm1(x), context=context if self.disable_self_attn else None) + x
 
 
 
 
 
304
  x = rearrange(x, "b (f l) c -> (b f) l c", f=num_frames).contiguous()
305
  x = self.attn2(self.norm2(x), context=context) + x
306
  x = self.ff(self.norm3(x)) + x
@@ -308,9 +399,20 @@ class BasicTransformerBlock3D(BasicTransformerBlock):
308
 
309
 
310
  class SpatialTransformer3D(nn.Module):
311
- ''' 3D self-attention '''
312
-
313
- def __init__(self, in_channels, n_heads, d_head, depth=1, dropout=0., context_dim=None, disable_self_attn=False, use_linear=False, use_checkpoint=True):
 
 
 
 
 
 
 
 
 
 
 
314
  super().__init__()
315
  assert context_dim is not None
316
  if not isinstance(context_dim, list):
@@ -319,13 +421,30 @@ class SpatialTransformer3D(nn.Module):
319
  inner_dim = n_heads * d_head
320
  self.norm = Normalize(in_channels)
321
  if not use_linear:
322
- self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
 
 
323
  else:
324
  self.proj_in = nn.Linear(in_channels, inner_dim)
325
 
326
- self.transformer_blocks = nn.ModuleList([BasicTransformerBlock3D(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim[d], disable_self_attn=disable_self_attn, checkpoint=use_checkpoint) for d in range(depth)])
 
 
 
 
 
 
 
 
 
 
 
 
 
327
  if not use_linear:
328
- self.proj_out = zero_module(nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0))
 
 
329
  else:
330
  self.proj_out = zero_module(nn.Linear(in_channels, inner_dim))
331
  self.use_linear = use_linear
@@ -339,14 +458,14 @@ class SpatialTransformer3D(nn.Module):
339
  x = self.norm(x)
340
  if not self.use_linear:
341
  x = self.proj_in(x)
342
- x = rearrange(x, 'b c h w -> b (h w) c').contiguous()
343
  if self.use_linear:
344
  x = self.proj_in(x)
345
  for i, block in enumerate(self.transformer_blocks):
346
  x = block(x, context=context[i], num_frames=num_frames)
347
  if self.use_linear:
348
  x = self.proj_out(x)
349
- x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w).contiguous()
350
  if not self.use_linear:
351
  x = self.proj_out(x)
352
  return x + x_in
 
12
  from .util import checkpoint
13
 
14
  try:
15
+ import xformers # type: ignore
16
+ import xformers.ops # type: ignore
17
+
18
  XFORMERS_IS_AVAILBLE = True
19
  except:
20
  XFORMERS_IS_AVAILBLE = False
 
48
 
49
  # feedforward
50
  class GEGLU(nn.Module):
 
51
  def __init__(self, dim_in, dim_out):
52
  super().__init__()
53
  self.proj = nn.Linear(dim_in, dim_out * 2)
 
58
 
59
 
60
  class FeedForward(nn.Module):
61
+ def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0):
 
62
  super().__init__()
63
  inner_dim = int(dim * mult)
64
  dim_out = default(dim_out, dim)
65
+ project_in = (
66
+ nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU())
67
+ if not glu
68
+ else GEGLU(dim, inner_dim)
69
+ )
70
 
71
+ self.net = nn.Sequential(
72
+ project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out)
73
+ )
74
 
75
  def forward(self, x):
76
  return self.net(x)
 
86
 
87
 
88
  def Normalize(in_channels):
89
+ return torch.nn.GroupNorm(
90
+ num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
91
+ )
92
 
93
 
94
  class SpatialSelfAttention(nn.Module):
 
95
  def __init__(self, in_channels):
96
  super().__init__()
97
  self.in_channels = in_channels
98
 
99
  self.norm = Normalize(in_channels)
100
+ self.q = torch.nn.Conv2d(
101
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
102
+ )
103
+ self.k = torch.nn.Conv2d(
104
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
105
+ )
106
+ self.v = torch.nn.Conv2d(
107
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
108
+ )
109
+ self.proj_out = torch.nn.Conv2d(
110
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
111
+ )
112
 
113
  def forward(self, x):
114
  h_ = x
 
119
 
120
  # compute attention
121
  b, c, h, w = q.shape
122
+ q = rearrange(q, "b c h w -> b (h w) c")
123
+ k = rearrange(k, "b c h w -> b c (h w)")
124
+ w_ = torch.einsum("bij,bjk->bik", q, k)
125
 
126
+ w_ = w_ * (int(c) ** (-0.5))
127
  w_ = torch.nn.functional.softmax(w_, dim=2)
128
 
129
  # attend to values
130
+ v = rearrange(v, "b c h w -> b c (h w)")
131
+ w_ = rearrange(w_, "b i j -> b j i")
132
+ h_ = torch.einsum("bij,bjk->bik", v, w_)
133
+ h_ = rearrange(h_, "b c (h w) -> b c h w", h=h)
134
  h_ = self.proj_out(h_)
135
 
136
  return x + h_
137
 
138
 
139
  class CrossAttention(nn.Module):
140
+ def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
 
141
  super().__init__()
142
  inner_dim = dim_head * heads
143
  context_dim = default(context_dim, query_dim)
 
149
  self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
150
  self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
151
 
152
+ self.to_out = nn.Sequential(
153
+ nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)
154
+ )
155
 
156
  def forward(self, x, context=None, mask=None):
157
  h = self.heads
 
161
  k = self.to_k(context)
162
  v = self.to_v(context)
163
 
164
+ q, k, v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (q, k, v))
165
 
166
  # force cast to fp32 to avoid overflowing
167
  if _ATTN_PRECISION == "fp32":
168
+ with autocast(enabled=False, device_type="cuda"):
169
  q, k = q.float(), k.float()
170
+ sim = einsum("b i d, b j d -> b i j", q, k) * self.scale
171
  else:
172
+ sim = einsum("b i d, b j d -> b i j", q, k) * self.scale
173
 
174
  del q, k
175
 
176
  if mask is not None:
177
+ mask = rearrange(mask, "b ... -> b (...)")
178
  max_neg_value = -torch.finfo(sim.dtype).max
179
+ mask = repeat(mask, "b j -> (b h) () j", h=h)
180
  sim.masked_fill_(~mask, max_neg_value)
181
 
182
  # attention, what we cannot get enough of
183
  sim = sim.softmax(dim=-1)
184
 
185
+ out = einsum("b i j, b j d -> b i d", sim, v)
186
+ out = rearrange(out, "(b h) n d -> b n (h d)", h=h)
187
  return self.to_out(out)
188
 
189
 
 
202
  self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
203
  self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
204
 
205
+ self.to_out = nn.Sequential(
206
+ nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)
207
+ )
208
  self.attention_op: Optional[Any] = None
209
 
210
  def forward(self, x, context=None, mask=None):
 
215
 
216
  b, _, _ = q.shape
217
  q, k, v = map(
218
+ lambda t: t.unsqueeze(3)
219
+ .reshape(b, t.shape[1], self.heads, self.dim_head)
220
+ .permute(0, 2, 1, 3)
221
+ .reshape(b * self.heads, t.shape[1], self.dim_head)
222
+ .contiguous(),
223
  (q, k, v),
224
  )
225
 
226
  # actually compute the attention, what we cannot get enough of
227
+ out = xformers.ops.memory_efficient_attention(
228
+ q, k, v, attn_bias=None, op=self.attention_op
229
+ )
230
 
231
  if mask is not None:
232
  raise NotImplementedError
233
+ out = (
234
+ out.unsqueeze(0)
235
+ .reshape(b, self.heads, out.shape[1], self.dim_head)
236
+ .permute(0, 2, 1, 3)
237
+ .reshape(b, out.shape[1], self.heads * self.dim_head)
238
+ )
239
  return self.to_out(out)
240
 
241
 
242
  class BasicTransformerBlock(nn.Module):
243
  ATTENTION_MODES = {
244
+ "softmax": CrossAttention,
245
+ "softmax-xformers": MemoryEfficientCrossAttention,
246
+ } # vanilla attention
247
+
248
+ def __init__(
249
+ self,
250
+ dim,
251
+ n_heads,
252
+ d_head,
253
+ dropout=0.0,
254
+ context_dim=None,
255
+ gated_ff=True,
256
+ checkpoint=True,
257
+ disable_self_attn=False,
258
+ ):
259
  super().__init__()
260
  attn_mode = "softmax-xformers" if XFORMERS_IS_AVAILBLE else "softmax"
261
  assert attn_mode in self.ATTENTION_MODES
262
  attn_cls = self.ATTENTION_MODES[attn_mode]
263
  self.disable_self_attn = disable_self_attn
264
+ self.attn1 = attn_cls(
265
+ query_dim=dim,
266
+ heads=n_heads,
267
+ dim_head=d_head,
268
+ dropout=dropout,
269
+ context_dim=context_dim if self.disable_self_attn else None,
270
+ ) # is a self-attention if not self.disable_self_attn
271
  self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
272
+ self.attn2 = attn_cls(
273
+ query_dim=dim,
274
+ context_dim=context_dim,
275
+ heads=n_heads,
276
+ dim_head=d_head,
277
+ dropout=dropout,
278
+ ) # is self-attn if context is none
279
  self.norm1 = nn.LayerNorm(dim)
280
  self.norm2 = nn.LayerNorm(dim)
281
  self.norm3 = nn.LayerNorm(dim)
282
  self.checkpoint = checkpoint
283
 
284
  def forward(self, x, context=None):
285
+ return checkpoint(
286
+ self._forward, (x, context), self.parameters(), self.checkpoint
287
+ )
288
 
289
  def _forward(self, x, context=None):
290
+ x = (
291
+ self.attn1(
292
+ self.norm1(x), context=context if self.disable_self_attn else None
293
+ )
294
+ + x
295
+ )
296
  x = self.attn2(self.norm2(x), context=context) + x
297
  x = self.ff(self.norm3(x)) + x
298
  return x
 
308
  NEW: use_linear for more efficiency instead of the 1x1 convs
309
  """
310
 
311
+ def __init__(
312
+ self,
313
+ in_channels,
314
+ n_heads,
315
+ d_head,
316
+ depth=1,
317
+ dropout=0.0,
318
+ context_dim=None,
319
+ disable_self_attn=False,
320
+ use_linear=False,
321
+ use_checkpoint=True,
322
+ ):
323
  super().__init__()
324
  assert context_dim is not None
325
  if not isinstance(context_dim, list):
 
328
  inner_dim = n_heads * d_head
329
  self.norm = Normalize(in_channels)
330
  if not use_linear:
331
+ self.proj_in = nn.Conv2d(
332
+ in_channels, inner_dim, kernel_size=1, stride=1, padding=0
333
+ )
334
  else:
335
  self.proj_in = nn.Linear(in_channels, inner_dim)
336
 
337
+ self.transformer_blocks = nn.ModuleList(
338
+ [
339
+ BasicTransformerBlock(
340
+ inner_dim,
341
+ n_heads,
342
+ d_head,
343
+ dropout=dropout,
344
+ context_dim=context_dim[d],
345
+ disable_self_attn=disable_self_attn,
346
+ checkpoint=use_checkpoint,
347
+ )
348
+ for d in range(depth)
349
+ ]
350
+ )
351
  if not use_linear:
352
+ self.proj_out = zero_module(
353
+ nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
354
+ )
355
  else:
356
  self.proj_out = zero_module(nn.Linear(in_channels, inner_dim))
357
  self.use_linear = use_linear
 
365
  x = self.norm(x)
366
  if not self.use_linear:
367
  x = self.proj_in(x)
368
+ x = rearrange(x, "b c h w -> b (h w) c").contiguous()
369
  if self.use_linear:
370
  x = self.proj_in(x)
371
  for i, block in enumerate(self.transformer_blocks):
372
  x = block(x, context=context[i])
373
  if self.use_linear:
374
  x = self.proj_out(x)
375
+ x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w).contiguous()
376
  if not self.use_linear:
377
  x = self.proj_out(x)
378
  return x + x_in
379
 
380
 
381
  class BasicTransformerBlock3D(BasicTransformerBlock):
 
382
  def forward(self, x, context=None, num_frames=1):
383
+ return checkpoint(
384
+ self._forward, (x, context, num_frames), self.parameters(), self.checkpoint
385
+ )
386
 
387
  def _forward(self, x, context=None, num_frames=1):
388
  x = rearrange(x, "(b f) l c -> b (f l) c", f=num_frames).contiguous()
389
+ x = (
390
+ self.attn1(
391
+ self.norm1(x), context=context if self.disable_self_attn else None
392
+ )
393
+ + x
394
+ )
395
  x = rearrange(x, "b (f l) c -> (b f) l c", f=num_frames).contiguous()
396
  x = self.attn2(self.norm2(x), context=context) + x
397
  x = self.ff(self.norm3(x)) + x
 
399
 
400
 
401
  class SpatialTransformer3D(nn.Module):
402
+ """3D self-attention"""
403
+
404
+ def __init__(
405
+ self,
406
+ in_channels,
407
+ n_heads,
408
+ d_head,
409
+ depth=1,
410
+ dropout=0.0,
411
+ context_dim=None,
412
+ disable_self_attn=False,
413
+ use_linear=False,
414
+ use_checkpoint=True,
415
+ ):
416
  super().__init__()
417
  assert context_dim is not None
418
  if not isinstance(context_dim, list):
 
421
  inner_dim = n_heads * d_head
422
  self.norm = Normalize(in_channels)
423
  if not use_linear:
424
+ self.proj_in = nn.Conv2d(
425
+ in_channels, inner_dim, kernel_size=1, stride=1, padding=0
426
+ )
427
  else:
428
  self.proj_in = nn.Linear(in_channels, inner_dim)
429
 
430
+ self.transformer_blocks = nn.ModuleList(
431
+ [
432
+ BasicTransformerBlock3D(
433
+ inner_dim,
434
+ n_heads,
435
+ d_head,
436
+ dropout=dropout,
437
+ context_dim=context_dim[d],
438
+ disable_self_attn=disable_self_attn,
439
+ checkpoint=use_checkpoint,
440
+ )
441
+ for d in range(depth)
442
+ ]
443
+ )
444
  if not use_linear:
445
+ self.proj_out = zero_module(
446
+ nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
447
+ )
448
  else:
449
  self.proj_out = zero_module(nn.Linear(in_channels, inner_dim))
450
  self.use_linear = use_linear
 
458
  x = self.norm(x)
459
  if not self.use_linear:
460
  x = self.proj_in(x)
461
+ x = rearrange(x, "b c h w -> b (h w) c").contiguous()
462
  if self.use_linear:
463
  x = self.proj_in(x)
464
  for i, block in enumerate(self.transformer_blocks):
465
  x = block(x, context=context[i], num_frames=num_frames)
466
  if self.use_linear:
467
  x = self.proj_out(x)
468
+ x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w).contiguous()
469
  if not self.use_linear:
470
  x = self.proj_out(x)
471
  return x + x_in
mvdream/models.py CHANGED
@@ -5,6 +5,10 @@ import numpy as np
5
  import torch as th
6
  import torch.nn as nn
7
  import torch.nn.functional as F
 
 
 
 
8
 
9
  from abc import abstractmethod
10
  from .util import (
@@ -15,80 +19,6 @@ from .util import (
15
  timestep_embedding,
16
  )
17
  from .attention import SpatialTransformer, SpatialTransformer3D
18
- from diffusers.configuration_utils import ConfigMixin
19
- from diffusers.models.modeling_utils import ModelMixin
20
- from typing import Any, List, Optional
21
- from torch import Tensor
22
-
23
-
24
- class MultiViewUNetWrapperModel(ModelMixin, ConfigMixin):
25
-
26
- def __init__(self,
27
- image_size,
28
- in_channels,
29
- model_channels,
30
- out_channels,
31
- num_res_blocks,
32
- attention_resolutions,
33
- dropout=0,
34
- channel_mult=(1, 2, 4, 8),
35
- conv_resample=True,
36
- dims=2,
37
- num_classes=None,
38
- use_checkpoint=False,
39
- num_heads=-1,
40
- num_head_channels=-1,
41
- num_heads_upsample=-1,
42
- use_scale_shift_norm=False,
43
- resblock_updown=False,
44
- use_new_attention_order=False,
45
- use_spatial_transformer=False, # custom transformer support
46
- transformer_depth=1, # custom transformer support
47
- context_dim=None, # custom transformer support
48
- n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model
49
- legacy=True,
50
- disable_self_attentions=None,
51
- num_attention_blocks=None,
52
- disable_middle_self_attn=False,
53
- use_linear_in_transformer=False,
54
- adm_in_channels=None,
55
- camera_dim=None,):
56
- super().__init__()
57
- self.unet = MultiViewUNetModel(
58
- image_size=image_size,
59
- in_channels=in_channels,
60
- model_channels=model_channels,
61
- out_channels=out_channels,
62
- num_res_blocks=num_res_blocks,
63
- attention_resolutions=attention_resolutions,
64
- dropout=dropout,
65
- channel_mult=channel_mult,
66
- conv_resample=conv_resample,
67
- dims=dims,
68
- num_classes=num_classes,
69
- use_checkpoint=use_checkpoint,
70
- num_heads=num_heads,
71
- num_head_channels=num_head_channels,
72
- num_heads_upsample=num_heads_upsample,
73
- use_scale_shift_norm=use_scale_shift_norm,
74
- resblock_updown=resblock_updown,
75
- use_new_attention_order=use_new_attention_order,
76
- use_spatial_transformer=use_spatial_transformer,
77
- transformer_depth=transformer_depth,
78
- context_dim=context_dim,
79
- n_embed=n_embed,
80
- legacy=legacy,
81
- disable_self_attentions=disable_self_attentions,
82
- num_attention_blocks=num_attention_blocks,
83
- disable_middle_self_attn=disable_middle_self_attn,
84
- use_linear_in_transformer=use_linear_in_transformer,
85
- adm_in_channels=adm_in_channels,
86
- camera_dim=camera_dim,
87
- )
88
-
89
- def forward(self, *args, **kwargs):
90
- return self.unet(*args, **kwargs)
91
-
92
 
93
  class TimestepBlock(nn.Module):
94
  """
@@ -137,12 +67,16 @@ class Upsample(nn.Module):
137
  self.use_conv = use_conv
138
  self.dims = dims
139
  if use_conv:
140
- self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=padding)
 
 
141
 
142
  def forward(self, x):
143
  assert x.shape[1] == self.channels
144
  if self.dims == 3:
145
- x = F.interpolate(x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest")
 
 
146
  else:
147
  x = F.interpolate(x, scale_factor=2, mode="nearest")
148
  if self.use_conv:
@@ -167,7 +101,14 @@ class Downsample(nn.Module):
167
  self.dims = dims
168
  stride = 2 if dims != 3 else (1, 2, 2)
169
  if use_conv:
170
- self.op = conv_nd(dims, self.channels, self.out_channels, 3, stride=stride, padding=padding)
 
 
 
 
 
 
 
171
  else:
172
  assert self.channels == self.out_channels
173
  self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
@@ -243,13 +184,17 @@ class ResBlock(TimestepBlock):
243
  nn.GroupNorm(32, self.out_channels),
244
  nn.SiLU(),
245
  nn.Dropout(p=dropout),
246
- zero_module(conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)),
 
 
247
  )
248
 
249
  if self.out_channels == channels:
250
  self.skip_connection = nn.Identity()
251
  elif use_conv:
252
- self.skip_connection = conv_nd(dims, channels, self.out_channels, 3, padding=1)
 
 
253
  else:
254
  self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
255
 
@@ -260,7 +205,9 @@ class ResBlock(TimestepBlock):
260
  :param emb: an [N x emb_channels] Tensor of timestep embeddings.
261
  :return: an [N x C x ...] Tensor of outputs.
262
  """
263
- return checkpoint(self._forward, (x, emb), self.parameters(), self.use_checkpoint)
 
 
264
 
265
  def _forward(self, x, emb):
266
  if self.updown:
@@ -305,7 +252,9 @@ class AttentionBlock(nn.Module):
305
  if num_head_channels == -1:
306
  self.num_heads = num_heads
307
  else:
308
- assert (channels % num_head_channels == 0), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
 
 
309
  self.num_heads = channels // num_head_channels
310
  self.use_checkpoint = use_checkpoint
311
  self.norm = nn.GroupNorm(32, channels)
@@ -320,8 +269,7 @@ class AttentionBlock(nn.Module):
320
  self.proj_out = zero_module(conv_nd(1, channels, channels, 1))
321
 
322
  def forward(self, x):
323
- return checkpoint(self._forward, (x,), self.parameters(), True) # TODO: check checkpoint usage, is True # TODO: fix the .half call!!!
324
- #return pt_checkpoint(self._forward, x) # pytorch
325
 
326
  def _forward(self, x):
327
  b, c, *spatial = x.shape
@@ -332,26 +280,6 @@ class AttentionBlock(nn.Module):
332
  return (x + h).reshape(b, c, *spatial)
333
 
334
 
335
- def count_flops_attn(model, _x, y):
336
- """
337
- A counter for the `thop` package to count the operations in an
338
- attention operation.
339
- Meant to be used like:
340
- macs, params = thop.profile(
341
- model,
342
- inputs=(inputs, timestamps),
343
- custom_ops={QKVAttention: QKVAttention.count_flops},
344
- )
345
- """
346
- b, c, *spatial = y[0].shape
347
- num_spatial = int(np.prod(spatial))
348
- # We perform two matmuls with the same number of ops.
349
- # The first computes the weight matrix, the second computes
350
- # the combination of the value vectors.
351
- matmul_ops = 2 * b * (num_spatial**2) * c
352
- model.total_ops += th.DoubleTensor([matmul_ops])
353
-
354
-
355
  class QKVAttentionLegacy(nn.Module):
356
  """
357
  A module which performs QKV attention. Matches legacy QKVAttention + input/ouput heads shaping
@@ -372,15 +300,13 @@ class QKVAttentionLegacy(nn.Module):
372
  ch = width // (3 * self.n_heads)
373
  q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
374
  scale = 1 / math.sqrt(math.sqrt(ch))
375
- weight = th.einsum("bct,bcs->bts", q * scale, k * scale) # More stable with f16 than dividing afterwards
 
 
376
  weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
377
  a = th.einsum("bts,bcs->bct", weight, v)
378
  return a.reshape(bs, -1, length)
379
 
380
- @staticmethod
381
- def count_flops(model, _x, y):
382
- return count_flops_attn(model, _x, y)
383
-
384
 
385
  class QKVAttention(nn.Module):
386
  """
@@ -406,17 +332,13 @@ class QKVAttention(nn.Module):
406
  "bct,bcs->bts",
407
  (q * scale).view(bs * self.n_heads, ch, length),
408
  (k * scale).view(bs * self.n_heads, ch, length),
409
- ) # More stable with f16 than dividing afterwards
410
  weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
411
  a = th.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length))
412
  return a.reshape(bs, -1, length)
413
 
414
- @staticmethod
415
- def count_flops(model, _x, y):
416
- return count_flops_attn(model, _x, y)
417
 
418
-
419
- class MultiViewUNetModel(nn.Module):
420
  """
421
  The full multi-view UNet model with attention, timestep embedding and camera embedding.
422
  :param in_channels: channels in the input Tensor.
@@ -448,44 +370,49 @@ class MultiViewUNetModel(nn.Module):
448
  """
449
 
450
  def __init__(
451
- self,
452
- image_size,
453
- in_channels,
454
- model_channels,
455
- out_channels,
456
- num_res_blocks,
457
- attention_resolutions,
458
- dropout=0,
459
- channel_mult=(1, 2, 4, 8),
460
- conv_resample=True,
461
- dims=2,
462
- num_classes=None,
463
- use_checkpoint=False,
464
- num_heads=-1,
465
- num_head_channels=-1,
466
- num_heads_upsample=-1,
467
- use_scale_shift_norm=False,
468
- resblock_updown=False,
469
- use_new_attention_order=False,
470
- use_spatial_transformer=False, # custom transformer support
471
- transformer_depth=1, # custom transformer support
472
- context_dim=None, # custom transformer support
473
- n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model
474
- legacy=True,
475
- disable_self_attentions=None,
476
- num_attention_blocks=None,
477
- disable_middle_self_attn=False,
478
- use_linear_in_transformer=False,
479
- adm_in_channels=None,
480
- camera_dim=None,
481
  ):
482
  super().__init__()
483
  if use_spatial_transformer:
484
- assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...'
 
 
485
 
486
  if context_dim is not None:
487
- assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...'
 
 
488
  from omegaconf.listconfig import ListConfig
 
489
  if type(context_dim) == ListConfig:
490
  context_dim = list(context_dim)
491
 
@@ -493,10 +420,14 @@ class MultiViewUNetModel(nn.Module):
493
  num_heads_upsample = num_heads
494
 
495
  if num_heads == -1:
496
- assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set'
 
 
497
 
498
  if num_head_channels == -1:
499
- assert num_heads != -1, 'Either num_heads or num_head_channels has to be set'
 
 
500
 
501
  self.image_size = image_size
502
  self.in_channels = in_channels
@@ -506,19 +437,28 @@ class MultiViewUNetModel(nn.Module):
506
  self.num_res_blocks = len(channel_mult) * [num_res_blocks]
507
  else:
508
  if len(num_res_blocks) != len(channel_mult):
509
- raise ValueError("provide num_res_blocks either as an int (globally constant) or "
510
- "as a list/tuple (per-level) with the same length as channel_mult")
 
 
511
  self.num_res_blocks = num_res_blocks
512
  if disable_self_attentions is not None:
513
  # should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not
514
  assert len(disable_self_attentions) == len(channel_mult)
515
  if num_attention_blocks is not None:
516
  assert len(num_attention_blocks) == len(self.num_res_blocks)
517
- assert all(map(lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], range(len(num_attention_blocks))))
518
- print(f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. "
519
- f"This option has LESS priority than attention_resolutions {attention_resolutions}, "
520
- f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, "
521
- f"attention will still not be set.")
 
 
 
 
 
 
 
522
 
523
  self.attention_resolutions = attention_resolutions
524
  self.dropout = dropout
@@ -554,30 +494,40 @@ class MultiViewUNetModel(nn.Module):
554
  self.label_emb = nn.Linear(1, time_embed_dim)
555
  elif self.num_classes == "sequential":
556
  assert adm_in_channels is not None
557
- self.label_emb = nn.Sequential(nn.Sequential(
558
- nn.Linear(adm_in_channels, time_embed_dim),
559
- nn.SiLU(),
560
- nn.Linear(time_embed_dim, time_embed_dim),
561
- ))
 
 
562
  else:
563
  raise ValueError()
564
 
565
- self.input_blocks = nn.ModuleList([TimestepEmbedSequential(conv_nd(dims, in_channels, model_channels, 3, padding=1))])
 
 
 
 
 
 
566
  self._feature_size = model_channels
567
  input_block_chans = [model_channels]
568
  ch = model_channels
569
  ds = 1
570
  for level, mult in enumerate(channel_mult):
571
  for nr in range(self.num_res_blocks[level]):
572
- layers: List[Any] = [ResBlock(
573
- ch,
574
- time_embed_dim,
575
- dropout,
576
- out_channels=mult * model_channels,
577
- dims=dims,
578
- use_checkpoint=use_checkpoint,
579
- use_scale_shift_norm=use_scale_shift_norm,
580
- )]
 
 
581
  ch = mult * model_channels
582
  if ds in attention_resolutions:
583
  if num_head_channels == -1:
@@ -586,36 +536,61 @@ class MultiViewUNetModel(nn.Module):
586
  num_heads = ch // num_head_channels
587
  dim_head = num_head_channels
588
  if legacy:
589
- #num_heads = 1
590
- dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
 
 
 
 
591
  if disable_self_attentions is not None:
592
  disabled_sa = disable_self_attentions[level]
593
  else:
594
  disabled_sa = False
595
 
596
  if num_attention_blocks is None or nr < num_attention_blocks[level]:
597
- layers.append(AttentionBlock(
598
- ch,
599
- use_checkpoint=use_checkpoint,
600
- num_heads=num_heads,
601
- num_head_channels=dim_head,
602
- use_new_attention_order=use_new_attention_order,
603
- ) if not use_spatial_transformer else SpatialTransformer3D(ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim, disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer, use_checkpoint=use_checkpoint))
 
 
 
 
 
 
 
 
 
 
 
 
 
604
  self.input_blocks.append(TimestepEmbedSequential(*layers))
605
  self._feature_size += ch
606
  input_block_chans.append(ch)
607
  if level != len(channel_mult) - 1:
608
  out_ch = ch
609
- self.input_blocks.append(TimestepEmbedSequential(ResBlock(
610
- ch,
611
- time_embed_dim,
612
- dropout,
613
- out_channels=out_ch,
614
- dims=dims,
615
- use_checkpoint=use_checkpoint,
616
- use_scale_shift_norm=use_scale_shift_norm,
617
- down=True,
618
- ) if resblock_updown else Downsample(ch, conv_resample, dims=dims, out_channels=out_ch)))
 
 
 
 
 
 
 
 
619
  ch = out_ch
620
  input_block_chans.append(ch)
621
  ds *= 2
@@ -627,7 +602,7 @@ class MultiViewUNetModel(nn.Module):
627
  num_heads = ch // num_head_channels
628
  dim_head = num_head_channels
629
  if legacy:
630
- #num_heads = 1
631
  dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
632
  self.middle_block = TimestepEmbedSequential(
633
  ResBlock(
@@ -644,8 +619,18 @@ class MultiViewUNetModel(nn.Module):
644
  num_heads=num_heads,
645
  num_head_channels=dim_head,
646
  use_new_attention_order=use_new_attention_order,
647
- ) if not use_spatial_transformer else SpatialTransformer3D( # always uses a self-attn
648
- ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim, disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer, use_checkpoint=use_checkpoint),
 
 
 
 
 
 
 
 
 
 
649
  ResBlock(
650
  ch,
651
  time_embed_dim,
@@ -661,15 +646,17 @@ class MultiViewUNetModel(nn.Module):
661
  for level, mult in list(enumerate(channel_mult))[::-1]:
662
  for i in range(self.num_res_blocks[level] + 1):
663
  ich = input_block_chans.pop()
664
- layers = [ResBlock(
665
- ch + ich,
666
- time_embed_dim,
667
- dropout,
668
- out_channels=model_channels * mult,
669
- dims=dims,
670
- use_checkpoint=use_checkpoint,
671
- use_scale_shift_norm=use_scale_shift_norm,
672
- )]
 
 
673
  ch = model_channels * mult
674
  if ds in attention_resolutions:
675
  if num_head_channels == -1:
@@ -678,33 +665,54 @@ class MultiViewUNetModel(nn.Module):
678
  num_heads = ch // num_head_channels
679
  dim_head = num_head_channels
680
  if legacy:
681
- #num_heads = 1
682
- dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
 
 
 
 
683
  if disable_self_attentions is not None:
684
  disabled_sa = disable_self_attentions[level]
685
  else:
686
  disabled_sa = False
687
 
688
  if num_attention_blocks is None or i < num_attention_blocks[level]:
689
- layers.append(AttentionBlock(
690
- ch,
691
- use_checkpoint=use_checkpoint,
692
- num_heads=num_heads_upsample,
693
- num_head_channels=dim_head,
694
- use_new_attention_order=use_new_attention_order,
695
- ) if not use_spatial_transformer else SpatialTransformer3D(ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim, disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer, use_checkpoint=use_checkpoint))
 
 
 
 
 
 
 
 
 
 
 
 
 
696
  if level and i == self.num_res_blocks[level]:
697
  out_ch = ch
698
- layers.append(ResBlock(
699
- ch,
700
- time_embed_dim,
701
- dropout,
702
- out_channels=out_ch,
703
- dims=dims,
704
- use_checkpoint=use_checkpoint,
705
- use_scale_shift_norm=use_scale_shift_norm,
706
- up=True,
707
- ) if resblock_updown else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch))
 
 
 
 
708
  ds //= 2
709
  self.output_blocks.append(TimestepEmbedSequential(*layers))
710
  self._feature_size += ch
@@ -718,10 +726,19 @@ class MultiViewUNetModel(nn.Module):
718
  self.id_predictor = nn.Sequential(
719
  nn.GroupNorm(32, ch),
720
  conv_nd(dims, model_channels, n_embed, 1),
721
- #nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits
722
  )
723
 
724
- def forward(self, x, timesteps=None, context=None, y: Optional[Tensor] = None, camera=None, num_frames=1, **kwargs):
 
 
 
 
 
 
 
 
 
725
  """
726
  Apply the model to an input batch.
727
  :param x: an [(N x F) x C x ...] Tensor of inputs. F is the number of frames (views).
@@ -731,11 +748,17 @@ class MultiViewUNetModel(nn.Module):
731
  :param num_frames: a integer indicating number of frames for tensor reshaping.
732
  :return: an [(N x F) x C x ...] Tensor of outputs. F is the number of frames (views).
733
  """
734
- assert x.shape[0] % num_frames == 0, "[UNet] input batch size must be dividable by num_frames!"
735
- assert (y is not None) == (self.num_classes is not None), "must specify y if and only if the model is class-conditional"
 
 
 
 
736
  hs = []
737
- t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(x.dtype)
738
-
 
 
739
  emb = self.time_embed(t_emb)
740
 
741
  if self.num_classes is not None:
 
5
  import torch as th
6
  import torch.nn as nn
7
  import torch.nn.functional as F
8
+ from diffusers.configuration_utils import ConfigMixin
9
+ from diffusers.models.modeling_utils import ModelMixin
10
+ from typing import Any, List, Optional
11
+ from torch import Tensor
12
 
13
  from abc import abstractmethod
14
  from .util import (
 
19
  timestep_embedding,
20
  )
21
  from .attention import SpatialTransformer, SpatialTransformer3D
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
 
23
  class TimestepBlock(nn.Module):
24
  """
 
67
  self.use_conv = use_conv
68
  self.dims = dims
69
  if use_conv:
70
+ self.conv = conv_nd(
71
+ dims, self.channels, self.out_channels, 3, padding=padding
72
+ )
73
 
74
  def forward(self, x):
75
  assert x.shape[1] == self.channels
76
  if self.dims == 3:
77
+ x = F.interpolate(
78
+ x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest"
79
+ )
80
  else:
81
  x = F.interpolate(x, scale_factor=2, mode="nearest")
82
  if self.use_conv:
 
101
  self.dims = dims
102
  stride = 2 if dims != 3 else (1, 2, 2)
103
  if use_conv:
104
+ self.op = conv_nd(
105
+ dims,
106
+ self.channels,
107
+ self.out_channels,
108
+ 3,
109
+ stride=stride,
110
+ padding=padding,
111
+ )
112
  else:
113
  assert self.channels == self.out_channels
114
  self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
 
184
  nn.GroupNorm(32, self.out_channels),
185
  nn.SiLU(),
186
  nn.Dropout(p=dropout),
187
+ zero_module(
188
+ conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)
189
+ ),
190
  )
191
 
192
  if self.out_channels == channels:
193
  self.skip_connection = nn.Identity()
194
  elif use_conv:
195
+ self.skip_connection = conv_nd(
196
+ dims, channels, self.out_channels, 3, padding=1
197
+ )
198
  else:
199
  self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
200
 
 
205
  :param emb: an [N x emb_channels] Tensor of timestep embeddings.
206
  :return: an [N x C x ...] Tensor of outputs.
207
  """
208
+ return checkpoint(
209
+ self._forward, (x, emb), self.parameters(), self.use_checkpoint
210
+ )
211
 
212
  def _forward(self, x, emb):
213
  if self.updown:
 
252
  if num_head_channels == -1:
253
  self.num_heads = num_heads
254
  else:
255
+ assert (
256
+ channels % num_head_channels == 0
257
+ ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
258
  self.num_heads = channels // num_head_channels
259
  self.use_checkpoint = use_checkpoint
260
  self.norm = nn.GroupNorm(32, channels)
 
269
  self.proj_out = zero_module(conv_nd(1, channels, channels, 1))
270
 
271
  def forward(self, x):
272
+ return checkpoint(self._forward, (x,), self.parameters(), True)
 
273
 
274
  def _forward(self, x):
275
  b, c, *spatial = x.shape
 
280
  return (x + h).reshape(b, c, *spatial)
281
 
282
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
283
  class QKVAttentionLegacy(nn.Module):
284
  """
285
  A module which performs QKV attention. Matches legacy QKVAttention + input/ouput heads shaping
 
300
  ch = width // (3 * self.n_heads)
301
  q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
302
  scale = 1 / math.sqrt(math.sqrt(ch))
303
+ weight = th.einsum(
304
+ "bct,bcs->bts", q * scale, k * scale
305
+ ) # More stable with f16 than dividing afterwards
306
  weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
307
  a = th.einsum("bts,bcs->bct", weight, v)
308
  return a.reshape(bs, -1, length)
309
 
 
 
 
 
310
 
311
  class QKVAttention(nn.Module):
312
  """
 
332
  "bct,bcs->bts",
333
  (q * scale).view(bs * self.n_heads, ch, length),
334
  (k * scale).view(bs * self.n_heads, ch, length),
335
+ ) # More stable with f16 than dividing afterwards
336
  weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
337
  a = th.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length))
338
  return a.reshape(bs, -1, length)
339
 
 
 
 
340
 
341
+ class MultiViewUNetModel(ModelMixin, ConfigMixin):
 
342
  """
343
  The full multi-view UNet model with attention, timestep embedding and camera embedding.
344
  :param in_channels: channels in the input Tensor.
 
370
  """
371
 
372
  def __init__(
373
+ self,
374
+ image_size,
375
+ in_channels,
376
+ model_channels,
377
+ out_channels,
378
+ num_res_blocks,
379
+ attention_resolutions,
380
+ dropout=0,
381
+ channel_mult=(1, 2, 4, 8),
382
+ conv_resample=True,
383
+ dims=2,
384
+ num_classes=None,
385
+ use_checkpoint=False,
386
+ num_heads=-1,
387
+ num_head_channels=-1,
388
+ num_heads_upsample=-1,
389
+ use_scale_shift_norm=False,
390
+ resblock_updown=False,
391
+ use_new_attention_order=False,
392
+ use_spatial_transformer=False, # custom transformer support
393
+ transformer_depth=1, # custom transformer support
394
+ context_dim=None, # custom transformer support
395
+ n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model
396
+ legacy=True,
397
+ disable_self_attentions=None,
398
+ num_attention_blocks=None,
399
+ disable_middle_self_attn=False,
400
+ use_linear_in_transformer=False,
401
+ adm_in_channels=None,
402
+ camera_dim=None,
403
  ):
404
  super().__init__()
405
  if use_spatial_transformer:
406
+ assert (
407
+ context_dim is not None
408
+ ), "Fool!! You forgot to include the dimension of your cross-attention conditioning..."
409
 
410
  if context_dim is not None:
411
+ assert (
412
+ use_spatial_transformer
413
+ ), "Fool!! You forgot to use the spatial transformer for your cross-attention conditioning..."
414
  from omegaconf.listconfig import ListConfig
415
+
416
  if type(context_dim) == ListConfig:
417
  context_dim = list(context_dim)
418
 
 
420
  num_heads_upsample = num_heads
421
 
422
  if num_heads == -1:
423
+ assert (
424
+ num_head_channels != -1
425
+ ), "Either num_heads or num_head_channels has to be set"
426
 
427
  if num_head_channels == -1:
428
+ assert (
429
+ num_heads != -1
430
+ ), "Either num_heads or num_head_channels has to be set"
431
 
432
  self.image_size = image_size
433
  self.in_channels = in_channels
 
437
  self.num_res_blocks = len(channel_mult) * [num_res_blocks]
438
  else:
439
  if len(num_res_blocks) != len(channel_mult):
440
+ raise ValueError(
441
+ "provide num_res_blocks either as an int (globally constant) or "
442
+ "as a list/tuple (per-level) with the same length as channel_mult"
443
+ )
444
  self.num_res_blocks = num_res_blocks
445
  if disable_self_attentions is not None:
446
  # should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not
447
  assert len(disable_self_attentions) == len(channel_mult)
448
  if num_attention_blocks is not None:
449
  assert len(num_attention_blocks) == len(self.num_res_blocks)
450
+ assert all(
451
+ map(
452
+ lambda i: self.num_res_blocks[i] >= num_attention_blocks[i],
453
+ range(len(num_attention_blocks)),
454
+ )
455
+ )
456
+ print(
457
+ f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. "
458
+ f"This option has LESS priority than attention_resolutions {attention_resolutions}, "
459
+ f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, "
460
+ f"attention will still not be set."
461
+ )
462
 
463
  self.attention_resolutions = attention_resolutions
464
  self.dropout = dropout
 
494
  self.label_emb = nn.Linear(1, time_embed_dim)
495
  elif self.num_classes == "sequential":
496
  assert adm_in_channels is not None
497
+ self.label_emb = nn.Sequential(
498
+ nn.Sequential(
499
+ nn.Linear(adm_in_channels, time_embed_dim),
500
+ nn.SiLU(),
501
+ nn.Linear(time_embed_dim, time_embed_dim),
502
+ )
503
+ )
504
  else:
505
  raise ValueError()
506
 
507
+ self.input_blocks = nn.ModuleList(
508
+ [
509
+ TimestepEmbedSequential(
510
+ conv_nd(dims, in_channels, model_channels, 3, padding=1)
511
+ )
512
+ ]
513
+ )
514
  self._feature_size = model_channels
515
  input_block_chans = [model_channels]
516
  ch = model_channels
517
  ds = 1
518
  for level, mult in enumerate(channel_mult):
519
  for nr in range(self.num_res_blocks[level]):
520
+ layers: List[Any] = [
521
+ ResBlock(
522
+ ch,
523
+ time_embed_dim,
524
+ dropout,
525
+ out_channels=mult * model_channels,
526
+ dims=dims,
527
+ use_checkpoint=use_checkpoint,
528
+ use_scale_shift_norm=use_scale_shift_norm,
529
+ )
530
+ ]
531
  ch = mult * model_channels
532
  if ds in attention_resolutions:
533
  if num_head_channels == -1:
 
536
  num_heads = ch // num_head_channels
537
  dim_head = num_head_channels
538
  if legacy:
539
+ # num_heads = 1
540
+ dim_head = (
541
+ ch // num_heads
542
+ if use_spatial_transformer
543
+ else num_head_channels
544
+ )
545
  if disable_self_attentions is not None:
546
  disabled_sa = disable_self_attentions[level]
547
  else:
548
  disabled_sa = False
549
 
550
  if num_attention_blocks is None or nr < num_attention_blocks[level]:
551
+ layers.append(
552
+ AttentionBlock(
553
+ ch,
554
+ use_checkpoint=use_checkpoint,
555
+ num_heads=num_heads,
556
+ num_head_channels=dim_head,
557
+ use_new_attention_order=use_new_attention_order,
558
+ )
559
+ if not use_spatial_transformer
560
+ else SpatialTransformer3D(
561
+ ch,
562
+ num_heads,
563
+ dim_head,
564
+ depth=transformer_depth,
565
+ context_dim=context_dim,
566
+ disable_self_attn=disabled_sa,
567
+ use_linear=use_linear_in_transformer,
568
+ use_checkpoint=use_checkpoint,
569
+ )
570
+ )
571
  self.input_blocks.append(TimestepEmbedSequential(*layers))
572
  self._feature_size += ch
573
  input_block_chans.append(ch)
574
  if level != len(channel_mult) - 1:
575
  out_ch = ch
576
+ self.input_blocks.append(
577
+ TimestepEmbedSequential(
578
+ ResBlock(
579
+ ch,
580
+ time_embed_dim,
581
+ dropout,
582
+ out_channels=out_ch,
583
+ dims=dims,
584
+ use_checkpoint=use_checkpoint,
585
+ use_scale_shift_norm=use_scale_shift_norm,
586
+ down=True,
587
+ )
588
+ if resblock_updown
589
+ else Downsample(
590
+ ch, conv_resample, dims=dims, out_channels=out_ch
591
+ )
592
+ )
593
+ )
594
  ch = out_ch
595
  input_block_chans.append(ch)
596
  ds *= 2
 
602
  num_heads = ch // num_head_channels
603
  dim_head = num_head_channels
604
  if legacy:
605
+ # num_heads = 1
606
  dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
607
  self.middle_block = TimestepEmbedSequential(
608
  ResBlock(
 
619
  num_heads=num_heads,
620
  num_head_channels=dim_head,
621
  use_new_attention_order=use_new_attention_order,
622
+ )
623
+ if not use_spatial_transformer
624
+ else SpatialTransformer3D(
625
+ ch,
626
+ num_heads,
627
+ dim_head,
628
+ depth=transformer_depth,
629
+ context_dim=context_dim,
630
+ disable_self_attn=disable_middle_self_attn,
631
+ use_linear=use_linear_in_transformer,
632
+ use_checkpoint=use_checkpoint,
633
+ ), # always uses a self-attn
634
  ResBlock(
635
  ch,
636
  time_embed_dim,
 
646
  for level, mult in list(enumerate(channel_mult))[::-1]:
647
  for i in range(self.num_res_blocks[level] + 1):
648
  ich = input_block_chans.pop()
649
+ layers = [
650
+ ResBlock(
651
+ ch + ich,
652
+ time_embed_dim,
653
+ dropout,
654
+ out_channels=model_channels * mult,
655
+ dims=dims,
656
+ use_checkpoint=use_checkpoint,
657
+ use_scale_shift_norm=use_scale_shift_norm,
658
+ )
659
+ ]
660
  ch = model_channels * mult
661
  if ds in attention_resolutions:
662
  if num_head_channels == -1:
 
665
  num_heads = ch // num_head_channels
666
  dim_head = num_head_channels
667
  if legacy:
668
+ # num_heads = 1
669
+ dim_head = (
670
+ ch // num_heads
671
+ if use_spatial_transformer
672
+ else num_head_channels
673
+ )
674
  if disable_self_attentions is not None:
675
  disabled_sa = disable_self_attentions[level]
676
  else:
677
  disabled_sa = False
678
 
679
  if num_attention_blocks is None or i < num_attention_blocks[level]:
680
+ layers.append(
681
+ AttentionBlock(
682
+ ch,
683
+ use_checkpoint=use_checkpoint,
684
+ num_heads=num_heads_upsample,
685
+ num_head_channels=dim_head,
686
+ use_new_attention_order=use_new_attention_order,
687
+ )
688
+ if not use_spatial_transformer
689
+ else SpatialTransformer3D(
690
+ ch,
691
+ num_heads,
692
+ dim_head,
693
+ depth=transformer_depth,
694
+ context_dim=context_dim,
695
+ disable_self_attn=disabled_sa,
696
+ use_linear=use_linear_in_transformer,
697
+ use_checkpoint=use_checkpoint,
698
+ )
699
+ )
700
  if level and i == self.num_res_blocks[level]:
701
  out_ch = ch
702
+ layers.append(
703
+ ResBlock(
704
+ ch,
705
+ time_embed_dim,
706
+ dropout,
707
+ out_channels=out_ch,
708
+ dims=dims,
709
+ use_checkpoint=use_checkpoint,
710
+ use_scale_shift_norm=use_scale_shift_norm,
711
+ up=True,
712
+ )
713
+ if resblock_updown
714
+ else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
715
+ )
716
  ds //= 2
717
  self.output_blocks.append(TimestepEmbedSequential(*layers))
718
  self._feature_size += ch
 
726
  self.id_predictor = nn.Sequential(
727
  nn.GroupNorm(32, ch),
728
  conv_nd(dims, model_channels, n_embed, 1),
729
+ # nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits
730
  )
731
 
732
+ def forward(
733
+ self,
734
+ x,
735
+ timesteps=None,
736
+ context=None,
737
+ y: Optional[Tensor] = None,
738
+ camera=None,
739
+ num_frames=1,
740
+ **kwargs,
741
+ ):
742
  """
743
  Apply the model to an input batch.
744
  :param x: an [(N x F) x C x ...] Tensor of inputs. F is the number of frames (views).
 
748
  :param num_frames: a integer indicating number of frames for tensor reshaping.
749
  :return: an [(N x F) x C x ...] Tensor of outputs. F is the number of frames (views).
750
  """
751
+ assert (
752
+ x.shape[0] % num_frames == 0
753
+ ), "[UNet] input batch size must be dividable by num_frames!"
754
+ assert (y is not None) == (
755
+ self.num_classes is not None
756
+ ), "must specify y if and only if the model is class-conditional"
757
  hs = []
758
+ t_emb = timestep_embedding(
759
+ timesteps, self.model_channels, repeat_only=False
760
+ ).to(x.dtype)
761
+
762
  emb = self.time_embed(t_emb)
763
 
764
  if self.num_classes is not None:
mvdream/pipeline_mvdream.py CHANGED
@@ -1,6 +1,6 @@
1
  import torch
2
- import numpy as np
3
  import inspect
 
4
  from typing import Callable, List, Optional, Union
5
  from transformers import CLIPTextModel, CLIPTokenizer
6
  from diffusers import AutoencoderKL, DiffusionPipeline
@@ -12,15 +12,12 @@ from diffusers.utils import (
12
  )
13
  from diffusers.configuration_utils import FrozenDict
14
  from diffusers.schedulers import DDIMScheduler
15
- try:
16
- from diffusers import randn_tensor # old import # type: ignore
17
- except ImportError:
18
- from diffusers.utils.torch_utils import randn_tensor # new import # type: ignore
19
 
20
- from .models import MultiViewUNetWrapperModel
21
- from accelerate.utils import set_module_tensor_to_device
22
 
23
- logger = logging.get_logger(__name__) # pylint: disable=invalid-name
24
 
25
  def create_camera_to_world_matrix(elevation, azimuth):
26
  elevation = np.radians(elevation)
@@ -55,14 +52,18 @@ def convert_opengl_to_blender(camera_matrix):
55
  camera_matrix_blender = np.dot(flip_yz, camera_matrix)
56
  else:
57
  # Construct transformation matrix to convert from OpenGL space to Blender space
58
- flip_yz = torch.tensor([[1, 0, 0, 0], [0, 0, -1, 0], [0, 1, 0, 0], [0, 0, 0, 1]])
 
 
59
  if camera_matrix.ndim == 3:
60
  flip_yz = flip_yz.unsqueeze(0)
61
  camera_matrix_blender = torch.matmul(flip_yz.to(camera_matrix), camera_matrix)
62
  return camera_matrix_blender
63
 
64
 
65
- def get_camera(num_frames, elevation=15, azimuth_start=0, azimuth_span=360, blender_coord=True):
 
 
66
  angle_gap = azimuth_span / num_frames
67
  cameras = []
68
  for azimuth in np.arange(azimuth_start, azimuth_span + azimuth_start, angle_gap):
@@ -74,11 +75,10 @@ def get_camera(num_frames, elevation=15, azimuth_start=0, azimuth_span=360, blen
74
 
75
 
76
  class MVDreamStableDiffusionPipeline(DiffusionPipeline):
77
-
78
  def __init__(
79
  self,
80
  vae: AutoencoderKL,
81
- unet: MultiViewUNetWrapperModel,
82
  tokenizer: CLIPTokenizer,
83
  text_encoder: CLIPTextModel,
84
  scheduler: DDIMScheduler,
@@ -86,25 +86,33 @@ class MVDreamStableDiffusionPipeline(DiffusionPipeline):
86
  ):
87
  super().__init__()
88
 
89
- if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: # type: ignore
90
- deprecation_message = (f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
91
- f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " # type: ignore
92
- "to update the config accordingly as leaving `steps_offset` might led to incorrect results"
93
- " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
94
- " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
95
- " file")
96
- deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
 
 
 
 
97
  new_config = dict(scheduler.config)
98
  new_config["steps_offset"] = 1
99
  scheduler._internal_dict = FrozenDict(new_config)
100
 
101
- if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True: # type: ignore
102
- deprecation_message = (f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
103
- " `clip_sample` should be set to False in the configuration file. Please make sure to update the"
104
- " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
105
- " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
106
- " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file")
107
- deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
 
 
 
 
108
  new_config = dict(scheduler.config)
109
  new_config["clip_sample"] = False
110
  scheduler._internal_dict = FrozenDict(new_config)
@@ -116,7 +124,7 @@ class MVDreamStableDiffusionPipeline(DiffusionPipeline):
116
  tokenizer=tokenizer,
117
  text_encoder=text_encoder,
118
  )
119
- self.vae_scale_factor = 2**(len(self.vae.config.block_out_channels) - 1)
120
  self.register_to_config(requires_safety_checker=requires_safety_checker)
121
 
122
  def enable_vae_slicing(self):
@@ -162,13 +170,15 @@ class MVDreamStableDiffusionPipeline(DiffusionPipeline):
162
  if is_accelerate_available() and is_accelerate_version(">=", "0.14.0"):
163
  from accelerate import cpu_offload
164
  else:
165
- raise ImportError("`enable_sequential_cpu_offload` requires `accelerate v0.14.0` or higher")
 
 
166
 
167
  device = torch.device(f"cuda:{gpu_id}")
168
 
169
  if self.device.type != "cpu":
170
  self.to("cpu", silence_dtype_warnings=True)
171
- torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
172
 
173
  for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
174
  cpu_offload(cpu_offloaded_model, device)
@@ -183,17 +193,21 @@ class MVDreamStableDiffusionPipeline(DiffusionPipeline):
183
  if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
184
  from accelerate import cpu_offload_with_hook
185
  else:
186
- raise ImportError("`enable_model_offload` requires `accelerate v0.17.0` or higher.")
 
 
187
 
188
  device = torch.device(f"cuda:{gpu_id}")
189
 
190
  if self.device.type != "cpu":
191
  self.to("cpu", silence_dtype_warnings=True)
192
- torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
193
 
194
  hook = None
195
  for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]:
196
- _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)
 
 
197
 
198
  # We'll offload the last model manually.
199
  self.final_offload_hook = hook
@@ -208,7 +222,11 @@ class MVDreamStableDiffusionPipeline(DiffusionPipeline):
208
  if not hasattr(self.unet, "_hf_hook"):
209
  return self.device
210
  for module in self.unet.modules():
211
- if (hasattr(module, "_hf_hook") and hasattr(module._hf_hook, "execution_device") and module._hf_hook.execution_device is not None):
 
 
 
 
212
  return torch.device(module._hf_hook.execution_device)
213
  return self.device
214
 
@@ -249,7 +267,9 @@ class MVDreamStableDiffusionPipeline(DiffusionPipeline):
249
  elif prompt is not None and isinstance(prompt, list):
250
  batch_size = len(prompt)
251
  else:
252
- raise ValueError(f"`prompt` should be either a string or a list of strings, but got {type(prompt)}.")
 
 
253
 
254
  text_inputs = self.tokenizer(
255
  prompt,
@@ -259,14 +279,25 @@ class MVDreamStableDiffusionPipeline(DiffusionPipeline):
259
  return_tensors="pt",
260
  )
261
  text_input_ids = text_inputs.input_ids
262
- untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
263
-
264
- if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
265
- removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1:-1])
266
- logger.warning("The following part of your input was truncated because CLIP can only handle sequences up to"
267
- f" {self.tokenizer.model_max_length} tokens: {removed_text}")
 
 
 
 
 
 
 
 
268
 
269
- if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
 
 
 
270
  attention_mask = text_inputs.attention_mask.to(device)
271
  else:
272
  attention_mask = None
@@ -282,7 +313,9 @@ class MVDreamStableDiffusionPipeline(DiffusionPipeline):
282
  bs_embed, seq_len, _ = prompt_embeds.shape
283
  # duplicate text embeddings for each generation per prompt, using mps friendly method
284
  prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
285
- prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
 
 
286
 
287
  # get unconditional embeddings for classifier free guidance
288
  if do_classifier_free_guidance:
@@ -290,14 +323,18 @@ class MVDreamStableDiffusionPipeline(DiffusionPipeline):
290
  if negative_prompt is None:
291
  uncond_tokens = [""] * batch_size
292
  elif type(prompt) is not type(negative_prompt):
293
- raise TypeError(f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
294
- f" {type(prompt)}.")
 
 
295
  elif isinstance(negative_prompt, str):
296
  uncond_tokens = [negative_prompt]
297
  elif batch_size != len(negative_prompt):
298
- raise ValueError(f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
299
- f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
300
- " the batch size of `prompt`.")
 
 
301
  else:
302
  uncond_tokens = negative_prompt
303
 
@@ -310,7 +347,10 @@ class MVDreamStableDiffusionPipeline(DiffusionPipeline):
310
  return_tensors="pt",
311
  )
312
 
313
- if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
 
 
 
314
  attention_mask = uncond_input.attention_mask.to(device)
315
  else:
316
  attention_mask = None
@@ -324,10 +364,16 @@ class MVDreamStableDiffusionPipeline(DiffusionPipeline):
324
  # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
325
  seq_len = negative_prompt_embeds.shape[1]
326
 
327
- negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
 
 
328
 
329
- negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
330
- negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
 
 
 
 
331
 
332
  # For classifier free guidance, we need to do two forward passes.
333
  # Here we concatenate the unconditional and text embeddings into a single batch
@@ -350,25 +396,48 @@ class MVDreamStableDiffusionPipeline(DiffusionPipeline):
350
  # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
351
  # and should be between [0, 1]
352
 
353
- accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
 
 
354
  extra_step_kwargs = {}
355
  if accepts_eta:
356
  extra_step_kwargs["eta"] = eta
357
 
358
  # check if the scheduler accepts generator
359
- accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
 
 
360
  if accepts_generator:
361
  extra_step_kwargs["generator"] = generator
362
  return extra_step_kwargs
363
 
364
- def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
365
- shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
366
  if isinstance(generator, list) and len(generator) != batch_size:
367
- raise ValueError(f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
368
- f" size of {batch_size}. Make sure the batch size matches the length of the generators.")
 
 
369
 
370
  if latents is None:
371
- latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
 
 
372
  else:
373
  latents = latents.to(device)
374
 
@@ -392,14 +461,13 @@ class MVDreamStableDiffusionPipeline(DiffusionPipeline):
392
  callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
393
  callback_steps: int = 1,
394
  batch_size: int = 4,
395
- device = torch.device("cuda:0"),
396
  ):
397
  self.unet = self.unet.to(device=device)
398
  self.vae = self.vae.to(device=device)
399
 
400
  self.text_encoder = self.text_encoder.to(device=device)
401
 
402
-
403
  # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
404
  # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
405
  # corresponds to doing no classifier free guidance.
@@ -415,7 +483,7 @@ class MVDreamStableDiffusionPipeline(DiffusionPipeline):
415
  num_images_per_prompt=num_images_per_prompt,
416
  do_classifier_free_guidance=do_classifier_free_guidance,
417
  negative_prompt=negative_prompt,
418
- ) # type: ignore
419
  prompt_embeds_neg, prompt_embeds_pos = _prompt_embeds.chunk(2)
420
 
421
  # Prepare latent variables
@@ -429,7 +497,7 @@ class MVDreamStableDiffusionPipeline(DiffusionPipeline):
429
  generator,
430
  None,
431
  )
432
-
433
  camera = get_camera(batch_size).to(dtype=latents.dtype, device=device)
434
 
435
  # Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
@@ -442,13 +510,21 @@ class MVDreamStableDiffusionPipeline(DiffusionPipeline):
442
  # expand the latents if we are doing classifier free guidance
443
  multiplier = 2 if do_classifier_free_guidance else 1
444
  latent_model_input = torch.cat([latents] * multiplier)
445
- latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
 
 
446
 
447
  # predict the noise residual
448
  noise_pred = self.unet.forward(
449
  x=latent_model_input,
450
- timesteps=torch.tensor([t] * 4 * multiplier, dtype=latent_model_input.dtype, device=device),
451
- context=torch.cat([prompt_embeds_neg] * 4 + [prompt_embeds_pos] * 4),
 
 
 
 
 
 
452
  num_frames=4,
453
  camera=torch.cat([camera] * multiplier),
454
  )
@@ -456,17 +532,23 @@ class MVDreamStableDiffusionPipeline(DiffusionPipeline):
456
  # perform guidance
457
  if do_classifier_free_guidance:
458
  noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
459
- noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
 
 
460
 
461
  # compute the previous noisy sample x_t -> x_t-1
462
  # latents = self.scheduler.step(noise_pred.to(dtype=torch.float32), t, latents.to(dtype=torch.float32)).prev_sample.to(prompt_embeds.dtype)
463
- latents: torch.Tensor = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
 
 
464
 
465
  # call the callback, if provided
466
- if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
 
 
467
  progress_bar.update()
468
  if callback is not None and i % callback_steps == 0:
469
- callback(i, t, latents) # type: ignore
470
 
471
  # Post-processing
472
  if output_type == "latent":
 
1
  import torch
 
2
  import inspect
3
+ import numpy as np
4
  from typing import Callable, List, Optional, Union
5
  from transformers import CLIPTextModel, CLIPTokenizer
6
  from diffusers import AutoencoderKL, DiffusionPipeline
 
12
  )
13
  from diffusers.configuration_utils import FrozenDict
14
  from diffusers.schedulers import DDIMScheduler
15
+ from diffusers.utils.torch_utils import randn_tensor
16
+
17
+ from .models import MultiViewUNetModel
 
18
 
19
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
 
20
 
 
21
 
22
  def create_camera_to_world_matrix(elevation, azimuth):
23
  elevation = np.radians(elevation)
 
52
  camera_matrix_blender = np.dot(flip_yz, camera_matrix)
53
  else:
54
  # Construct transformation matrix to convert from OpenGL space to Blender space
55
+ flip_yz = torch.tensor(
56
+ [[1, 0, 0, 0], [0, 0, -1, 0], [0, 1, 0, 0], [0, 0, 0, 1]]
57
+ )
58
  if camera_matrix.ndim == 3:
59
  flip_yz = flip_yz.unsqueeze(0)
60
  camera_matrix_blender = torch.matmul(flip_yz.to(camera_matrix), camera_matrix)
61
  return camera_matrix_blender
62
 
63
 
64
+ def get_camera(
65
+ num_frames, elevation=15, azimuth_start=0, azimuth_span=360, blender_coord=True
66
+ ):
67
  angle_gap = azimuth_span / num_frames
68
  cameras = []
69
  for azimuth in np.arange(azimuth_start, azimuth_span + azimuth_start, angle_gap):
 
75
 
76
 
77
  class MVDreamStableDiffusionPipeline(DiffusionPipeline):
 
78
  def __init__(
79
  self,
80
  vae: AutoencoderKL,
81
+ unet: MultiViewUNetModel,
82
  tokenizer: CLIPTokenizer,
83
  text_encoder: CLIPTextModel,
84
  scheduler: DDIMScheduler,
 
86
  ):
87
  super().__init__()
88
 
89
+ if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: # type: ignore
90
+ deprecation_message = (
91
+ f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
92
+ f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " # type: ignore
93
+ "to update the config accordingly as leaving `steps_offset` might led to incorrect results"
94
+ " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
95
+ " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
96
+ " file"
97
+ )
98
+ deprecate(
99
+ "steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False
100
+ )
101
  new_config = dict(scheduler.config)
102
  new_config["steps_offset"] = 1
103
  scheduler._internal_dict = FrozenDict(new_config)
104
 
105
+ if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True: # type: ignore
106
+ deprecation_message = (
107
+ f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
108
+ " `clip_sample` should be set to False in the configuration file. Please make sure to update the"
109
+ " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
110
+ " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
111
+ " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
112
+ )
113
+ deprecate(
114
+ "clip_sample not set", "1.0.0", deprecation_message, standard_warn=False
115
+ )
116
  new_config = dict(scheduler.config)
117
  new_config["clip_sample"] = False
118
  scheduler._internal_dict = FrozenDict(new_config)
 
124
  tokenizer=tokenizer,
125
  text_encoder=text_encoder,
126
  )
127
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
128
  self.register_to_config(requires_safety_checker=requires_safety_checker)
129
 
130
  def enable_vae_slicing(self):
 
170
  if is_accelerate_available() and is_accelerate_version(">=", "0.14.0"):
171
  from accelerate import cpu_offload
172
  else:
173
+ raise ImportError(
174
+ "`enable_sequential_cpu_offload` requires `accelerate v0.14.0` or higher"
175
+ )
176
 
177
  device = torch.device(f"cuda:{gpu_id}")
178
 
179
  if self.device.type != "cpu":
180
  self.to("cpu", silence_dtype_warnings=True)
181
+ torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
182
 
183
  for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
184
  cpu_offload(cpu_offloaded_model, device)
 
193
  if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
194
  from accelerate import cpu_offload_with_hook
195
  else:
196
+ raise ImportError(
197
+ "`enable_model_offload` requires `accelerate v0.17.0` or higher."
198
+ )
199
 
200
  device = torch.device(f"cuda:{gpu_id}")
201
 
202
  if self.device.type != "cpu":
203
  self.to("cpu", silence_dtype_warnings=True)
204
+ torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
205
 
206
  hook = None
207
  for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]:
208
+ _, hook = cpu_offload_with_hook(
209
+ cpu_offloaded_model, device, prev_module_hook=hook
210
+ )
211
 
212
  # We'll offload the last model manually.
213
  self.final_offload_hook = hook
 
222
  if not hasattr(self.unet, "_hf_hook"):
223
  return self.device
224
  for module in self.unet.modules():
225
+ if (
226
+ hasattr(module, "_hf_hook")
227
+ and hasattr(module._hf_hook, "execution_device")
228
+ and module._hf_hook.execution_device is not None
229
+ ):
230
  return torch.device(module._hf_hook.execution_device)
231
  return self.device
232
 
 
267
  elif prompt is not None and isinstance(prompt, list):
268
  batch_size = len(prompt)
269
  else:
270
+ raise ValueError(
271
+ f"`prompt` should be either a string or a list of strings, but got {type(prompt)}."
272
+ )
273
 
274
  text_inputs = self.tokenizer(
275
  prompt,
 
279
  return_tensors="pt",
280
  )
281
  text_input_ids = text_inputs.input_ids
282
+ untruncated_ids = self.tokenizer(
283
+ prompt, padding="longest", return_tensors="pt"
284
+ ).input_ids
285
+
286
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
287
+ text_input_ids, untruncated_ids
288
+ ):
289
+ removed_text = self.tokenizer.batch_decode(
290
+ untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
291
+ )
292
+ logger.warning(
293
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
294
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
295
+ )
296
 
297
+ if (
298
+ hasattr(self.text_encoder.config, "use_attention_mask")
299
+ and self.text_encoder.config.use_attention_mask
300
+ ):
301
  attention_mask = text_inputs.attention_mask.to(device)
302
  else:
303
  attention_mask = None
 
313
  bs_embed, seq_len, _ = prompt_embeds.shape
314
  # duplicate text embeddings for each generation per prompt, using mps friendly method
315
  prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
316
+ prompt_embeds = prompt_embeds.view(
317
+ bs_embed * num_images_per_prompt, seq_len, -1
318
+ )
319
 
320
  # get unconditional embeddings for classifier free guidance
321
  if do_classifier_free_guidance:
 
323
  if negative_prompt is None:
324
  uncond_tokens = [""] * batch_size
325
  elif type(prompt) is not type(negative_prompt):
326
+ raise TypeError(
327
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
328
+ f" {type(prompt)}."
329
+ )
330
  elif isinstance(negative_prompt, str):
331
  uncond_tokens = [negative_prompt]
332
  elif batch_size != len(negative_prompt):
333
+ raise ValueError(
334
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
335
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
336
+ " the batch size of `prompt`."
337
+ )
338
  else:
339
  uncond_tokens = negative_prompt
340
 
 
347
  return_tensors="pt",
348
  )
349
 
350
+ if (
351
+ hasattr(self.text_encoder.config, "use_attention_mask")
352
+ and self.text_encoder.config.use_attention_mask
353
+ ):
354
  attention_mask = uncond_input.attention_mask.to(device)
355
  else:
356
  attention_mask = None
 
364
  # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
365
  seq_len = negative_prompt_embeds.shape[1]
366
 
367
+ negative_prompt_embeds = negative_prompt_embeds.to(
368
+ dtype=self.text_encoder.dtype, device=device
369
+ )
370
 
371
+ negative_prompt_embeds = negative_prompt_embeds.repeat(
372
+ 1, num_images_per_prompt, 1
373
+ )
374
+ negative_prompt_embeds = negative_prompt_embeds.view(
375
+ batch_size * num_images_per_prompt, seq_len, -1
376
+ )
377
 
378
  # For classifier free guidance, we need to do two forward passes.
379
  # Here we concatenate the unconditional and text embeddings into a single batch
 
396
  # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
397
  # and should be between [0, 1]
398
 
399
+ accepts_eta = "eta" in set(
400
+ inspect.signature(self.scheduler.step).parameters.keys()
401
+ )
402
  extra_step_kwargs = {}
403
  if accepts_eta:
404
  extra_step_kwargs["eta"] = eta
405
 
406
  # check if the scheduler accepts generator
407
+ accepts_generator = "generator" in set(
408
+ inspect.signature(self.scheduler.step).parameters.keys()
409
+ )
410
  if accepts_generator:
411
  extra_step_kwargs["generator"] = generator
412
  return extra_step_kwargs
413
 
414
+ def prepare_latents(
415
+ self,
416
+ batch_size,
417
+ num_channels_latents,
418
+ height,
419
+ width,
420
+ dtype,
421
+ device,
422
+ generator,
423
+ latents=None,
424
+ ):
425
+ shape = (
426
+ batch_size,
427
+ num_channels_latents,
428
+ height // self.vae_scale_factor,
429
+ width // self.vae_scale_factor,
430
+ )
431
  if isinstance(generator, list) and len(generator) != batch_size:
432
+ raise ValueError(
433
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
434
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
435
+ )
436
 
437
  if latents is None:
438
+ latents = randn_tensor(
439
+ shape, generator=generator, device=device, dtype=dtype
440
+ )
441
  else:
442
  latents = latents.to(device)
443
 
 
461
  callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
462
  callback_steps: int = 1,
463
  batch_size: int = 4,
464
+ device=torch.device("cuda:0"),
465
  ):
466
  self.unet = self.unet.to(device=device)
467
  self.vae = self.vae.to(device=device)
468
 
469
  self.text_encoder = self.text_encoder.to(device=device)
470
 
 
471
  # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
472
  # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
473
  # corresponds to doing no classifier free guidance.
 
483
  num_images_per_prompt=num_images_per_prompt,
484
  do_classifier_free_guidance=do_classifier_free_guidance,
485
  negative_prompt=negative_prompt,
486
+ ) # type: ignore
487
  prompt_embeds_neg, prompt_embeds_pos = _prompt_embeds.chunk(2)
488
 
489
  # Prepare latent variables
 
497
  generator,
498
  None,
499
  )
500
+
501
  camera = get_camera(batch_size).to(dtype=latents.dtype, device=device)
502
 
503
  # Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
 
510
  # expand the latents if we are doing classifier free guidance
511
  multiplier = 2 if do_classifier_free_guidance else 1
512
  latent_model_input = torch.cat([latents] * multiplier)
513
+ latent_model_input = self.scheduler.scale_model_input(
514
+ latent_model_input, t
515
+ )
516
 
517
  # predict the noise residual
518
  noise_pred = self.unet.forward(
519
  x=latent_model_input,
520
+ timesteps=torch.tensor(
521
+ [t] * 4 * multiplier,
522
+ dtype=latent_model_input.dtype,
523
+ device=device,
524
+ ),
525
+ context=torch.cat(
526
+ [prompt_embeds_neg] * 4 + [prompt_embeds_pos] * 4
527
+ ),
528
  num_frames=4,
529
  camera=torch.cat([camera] * multiplier),
530
  )
 
532
  # perform guidance
533
  if do_classifier_free_guidance:
534
  noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
535
+ noise_pred = noise_pred_uncond + guidance_scale * (
536
+ noise_pred_text - noise_pred_uncond
537
+ )
538
 
539
  # compute the previous noisy sample x_t -> x_t-1
540
  # latents = self.scheduler.step(noise_pred.to(dtype=torch.float32), t, latents.to(dtype=torch.float32)).prev_sample.to(prompt_embeds.dtype)
541
+ latents: torch.Tensor = self.scheduler.step(
542
+ noise_pred, t, latents, **extra_step_kwargs, return_dict=False
543
+ )[0]
544
 
545
  # call the callback, if provided
546
+ if i == len(timesteps) - 1 or (
547
+ (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
548
+ ):
549
  progress_bar.update()
550
  if callback is not None and i % callback_steps == 0:
551
+ callback(i, t, latents) # type: ignore
552
 
553
  # Post-processing
554
  if output_type == "latent":
mvdream/util.py CHANGED
@@ -12,6 +12,7 @@ import torch
12
  import torch.nn as nn
13
  from einops import repeat
14
 
 
15
  def checkpoint(func, inputs, params, flag):
16
  """
17
  Evaluate a function without caching intermediate activations, allowing for
@@ -30,7 +31,6 @@ def checkpoint(func, inputs, params, flag):
30
 
31
 
32
  class CheckpointFunction(torch.autograd.Function):
33
-
34
  @staticmethod
35
  def forward(ctx, run_function, length, *args):
36
  ctx.run_function = run_function
@@ -43,9 +43,7 @@ class CheckpointFunction(torch.autograd.Function):
43
 
44
  @staticmethod
45
  def backward(ctx, *output_grads):
46
- ctx.input_tensors = [
47
- x.detach().requires_grad_(True) for x in ctx.input_tensors
48
- ]
49
  with torch.enable_grad():
50
  # Fixes a bug where the first op in run_function modifies the
51
  # Tensor storage in place, which is not allowed for detach()'d
@@ -76,16 +74,18 @@ def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
76
  if not repeat_only:
77
  half = dim // 2
78
  freqs = torch.exp(
79
- -math.log(max_period) *
80
- torch.arange(start=0, end=half, dtype=torch.float32) /
81
- half).to(device=timesteps.device)
 
82
  args = timesteps[:, None] * freqs[None]
83
  embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
84
  if dim % 2:
85
  embedding = torch.cat(
86
- [embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
 
87
  else:
88
- embedding = repeat(timesteps, 'b -> b d', d=dim)
89
  # import pdb; pdb.set_trace()
90
  return embedding
91
 
@@ -98,6 +98,7 @@ def zero_module(module):
98
  p.detach().zero_()
99
  return module
100
 
 
101
  def conv_nd(dims, *args, **kwargs):
102
  """
103
  Create a 1D, 2D, or 3D convolution module.
 
12
  import torch.nn as nn
13
  from einops import repeat
14
 
15
+
16
  def checkpoint(func, inputs, params, flag):
17
  """
18
  Evaluate a function without caching intermediate activations, allowing for
 
31
 
32
 
33
  class CheckpointFunction(torch.autograd.Function):
 
34
  @staticmethod
35
  def forward(ctx, run_function, length, *args):
36
  ctx.run_function = run_function
 
43
 
44
  @staticmethod
45
  def backward(ctx, *output_grads):
46
+ ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
 
 
47
  with torch.enable_grad():
48
  # Fixes a bug where the first op in run_function modifies the
49
  # Tensor storage in place, which is not allowed for detach()'d
 
74
  if not repeat_only:
75
  half = dim // 2
76
  freqs = torch.exp(
77
+ -math.log(max_period)
78
+ * torch.arange(start=0, end=half, dtype=torch.float32)
79
+ / half
80
+ ).to(device=timesteps.device)
81
  args = timesteps[:, None] * freqs[None]
82
  embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
83
  if dim % 2:
84
  embedding = torch.cat(
85
+ [embedding, torch.zeros_like(embedding[:, :1])], dim=-1
86
+ )
87
  else:
88
+ embedding = repeat(timesteps, "b -> b d", d=dim)
89
  # import pdb; pdb.set_trace()
90
  return embedding
91
 
 
98
  p.detach().zero_()
99
  return module
100
 
101
+
102
  def conv_nd(dims, *args, **kwargs):
103
  """
104
  Create a 1D, 2D, or 3D convolution module.