ashawkey commited on
Commit
b861db3
1 Parent(s): f0e1e27

fix mvdream

Browse files
.gitignore CHANGED
@@ -4,4 +4,5 @@
4
  *.pyc
5
 
6
  weights
 
7
  sd-v2*
 
4
  *.pyc
5
 
6
  weights
7
+ models
8
  sd-v2*
README.md CHANGED
@@ -4,6 +4,9 @@ modified from https://github.com/KokeCacao/mvdream-hf.
4
 
5
  ### convert weights
6
  ```bash
 
 
 
7
  # download original ckpt
8
  wget https://huggingface.co/MVDream/MVDream/resolve/main/sd-v2.1-base-4view.pt
9
  wget https://raw.githubusercontent.com/bytedance/MVDream/main/mvdream/configs/sd-v2-base.yaml
 
4
 
5
  ### convert weights
6
  ```bash
7
+ # dependency
8
+ pip install -U omegaconf diffusers safetensors huggingface_hub transformers accelerate
9
+
10
  # download original ckpt
11
  wget https://huggingface.co/MVDream/MVDream/resolve/main/sd-v2.1-base-4view.pt
12
  wget https://raw.githubusercontent.com/bytedance/MVDream/main/mvdream/configs/sd-v2-base.yaml
convert_mvdream_to_diffusers.py CHANGED
@@ -405,6 +405,12 @@ def convert_from_original_mvdream_ckpt(checkpoint_path, original_config_file, de
405
  # )
406
  # print(f"Unet Config: {original_config.model.params.unet_config.params}")
407
  unet_config = create_unet_config(original_config)
 
 
 
 
 
 
408
  unet = MultiViewUNetModel(**unet_config)
409
  unet.register_to_config(**unet_config)
410
  # print(f"Unet State Dict: {unet.state_dict().keys()}")
 
405
  # )
406
  # print(f"Unet Config: {original_config.model.params.unet_config.params}")
407
  unet_config = create_unet_config(original_config)
408
+
409
+ # remove unused configs
410
+ del unet_config['legacy']
411
+ del unet_config['use_linear_in_transformer']
412
+ del unet_config['use_spatial_transformer']
413
+
414
  unet = MultiViewUNetModel(**unet_config)
415
  unet.register_to_config(**unet_config)
416
  # print(f"Unet State Dict: {unet.state_dict().keys()}")
mvdream/attention.py CHANGED
@@ -1,6 +1,3 @@
1
- # obtained and modified from https://github.com/bytedance/MVDream
2
-
3
- import math
4
  import torch
5
  import torch.nn as nn
6
  import torch.nn.functional as F
@@ -14,9 +11,9 @@ from .util import checkpoint, zero_module
14
  try:
15
  import xformers # type: ignore
16
  import xformers.ops # type: ignore
17
-
18
  XFORMERS_IS_AVAILBLE = True
19
  except:
 
20
  XFORMERS_IS_AVAILBLE = False
21
 
22
  # CrossAttn precision handling
 
 
 
 
1
  import torch
2
  import torch.nn as nn
3
  import torch.nn.functional as F
 
11
  try:
12
  import xformers # type: ignore
13
  import xformers.ops # type: ignore
 
14
  XFORMERS_IS_AVAILBLE = True
15
  except:
16
+ print(f'[WARN] xformers is unavailable!')
17
  XFORMERS_IS_AVAILBLE = False
18
 
19
  # CrossAttn precision handling
mvdream/models.py CHANGED
@@ -1,6 +1,3 @@
1
- # obtained and modified from https://github.com/bytedance/MVDream
2
-
3
- import math
4
  import torch
5
  import torch.nn as nn
6
  import torch.nn.functional as F
@@ -9,7 +6,6 @@ from diffusers.models.modeling_utils import ModelMixin
9
  from typing import Any, List, Optional
10
  from torch import Tensor
11
 
12
- from abc import abstractmethod
13
  from .util import (
14
  checkpoint,
15
  conv_nd,
@@ -19,19 +15,8 @@ from .util import (
19
  )
20
  from .attention import SpatialTransformer, SpatialTransformer3D
21
 
22
- class TimestepBlock(nn.Module):
23
- """
24
- Any module where forward() takes timestep embeddings as a second argument.
25
- """
26
 
27
- @abstractmethod
28
- def forward(self, x, emb):
29
- """
30
- Apply the module to `x` given `emb` timestep embeddings.
31
- """
32
-
33
-
34
- class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
35
  """
36
  A sequential module that passes timestep embeddings to the children that
37
  support it as an extra input.
@@ -39,7 +24,7 @@ class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
39
 
40
  def forward(self, x, emb, context=None, num_frames=1):
41
  for layer in self:
42
- if isinstance(layer, TimestepBlock):
43
  x = layer(x, emb)
44
  elif isinstance(layer, SpatialTransformer3D):
45
  x = layer(x, context, num_frames=num_frames)
@@ -117,7 +102,7 @@ class Downsample(nn.Module):
117
  return self.op(x)
118
 
119
 
120
- class ResBlock(TimestepBlock):
121
  """
122
  A residual block that can optionally change the number of channels.
123
  :param channels: the number of input channels.
@@ -289,6 +274,7 @@ class MultiViewUNetModel(ModelMixin, ConfigMixin):
289
  disable_middle_self_attn=False,
290
  adm_in_channels=None,
291
  camera_dim=None,
 
292
  ):
293
  super().__init__()
294
  assert context_dim is not None
@@ -383,7 +369,7 @@ class MultiViewUNetModel(ModelMixin, ConfigMixin):
383
 
384
  self.input_blocks = nn.ModuleList(
385
  [
386
- TimestepEmbedSequential(
387
  conv_nd(dims, in_channels, model_channels, 3, padding=1)
388
  )
389
  ]
@@ -430,13 +416,13 @@ class MultiViewUNetModel(ModelMixin, ConfigMixin):
430
  use_checkpoint=use_checkpoint,
431
  )
432
  )
433
- self.input_blocks.append(TimestepEmbedSequential(*layers))
434
  self._feature_size += ch
435
  input_block_chans.append(ch)
436
  if level != len(channel_mult) - 1:
437
  out_ch = ch
438
  self.input_blocks.append(
439
- TimestepEmbedSequential(
440
  ResBlock(
441
  ch,
442
  time_embed_dim,
@@ -464,7 +450,7 @@ class MultiViewUNetModel(ModelMixin, ConfigMixin):
464
  num_heads = ch // num_head_channels
465
  dim_head = num_head_channels
466
 
467
- self.middle_block = TimestepEmbedSequential(
468
  ResBlock(
469
  ch,
470
  time_embed_dim,
@@ -550,7 +536,7 @@ class MultiViewUNetModel(ModelMixin, ConfigMixin):
550
  else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
551
  )
552
  ds //= 2
553
- self.output_blocks.append(TimestepEmbedSequential(*layers))
554
  self._feature_size += ch
555
 
556
  self.out = nn.Sequential(
 
 
 
 
1
  import torch
2
  import torch.nn as nn
3
  import torch.nn.functional as F
 
6
  from typing import Any, List, Optional
7
  from torch import Tensor
8
 
 
9
  from .util import (
10
  checkpoint,
11
  conv_nd,
 
15
  )
16
  from .attention import SpatialTransformer, SpatialTransformer3D
17
 
 
 
 
 
18
 
19
+ class CondSequential(nn.Sequential):
 
 
 
 
 
 
 
20
  """
21
  A sequential module that passes timestep embeddings to the children that
22
  support it as an extra input.
 
24
 
25
  def forward(self, x, emb, context=None, num_frames=1):
26
  for layer in self:
27
+ if isinstance(layer, ResBlock):
28
  x = layer(x, emb)
29
  elif isinstance(layer, SpatialTransformer3D):
30
  x = layer(x, context, num_frames=num_frames)
 
102
  return self.op(x)
103
 
104
 
105
+ class ResBlock(nn.Module):
106
  """
107
  A residual block that can optionally change the number of channels.
108
  :param channels: the number of input channels.
 
274
  disable_middle_self_attn=False,
275
  adm_in_channels=None,
276
  camera_dim=None,
277
+ **kwargs,
278
  ):
279
  super().__init__()
280
  assert context_dim is not None
 
369
 
370
  self.input_blocks = nn.ModuleList(
371
  [
372
+ CondSequential(
373
  conv_nd(dims, in_channels, model_channels, 3, padding=1)
374
  )
375
  ]
 
416
  use_checkpoint=use_checkpoint,
417
  )
418
  )
419
+ self.input_blocks.append(CondSequential(*layers))
420
  self._feature_size += ch
421
  input_block_chans.append(ch)
422
  if level != len(channel_mult) - 1:
423
  out_ch = ch
424
  self.input_blocks.append(
425
+ CondSequential(
426
  ResBlock(
427
  ch,
428
  time_embed_dim,
 
450
  num_heads = ch // num_head_channels
451
  dim_head = num_head_channels
452
 
453
+ self.middle_block = CondSequential(
454
  ResBlock(
455
  ch,
456
  time_embed_dim,
 
536
  else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
537
  )
538
  ds //= 2
539
+ self.output_blocks.append(CondSequential(*layers))
540
  self._feature_size += ch
541
 
542
  self.out = nn.Sequential(
mvdream/util.py CHANGED
@@ -1,12 +1,3 @@
1
- # adopted from
2
- # https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
3
- # and
4
- # https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
5
- # and
6
- # https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py
7
- #
8
- # thanks!
9
-
10
  import math
11
  import torch
12
  import torch.nn as nn
 
 
 
 
 
 
 
 
 
 
1
  import math
2
  import torch
3
  import torch.nn as nn