fix mvdream
- .gitignore +1 -0
- README.md +3 -0
- convert_mvdream_to_diffusers.py +6 -0
- mvdream/attention.py +1 -4
- mvdream/models.py +9 -23
- mvdream/util.py +0 -9
.gitignore CHANGED
@@ -4,4 +4,5 @@
 *.pyc
 
 weights
+models
 sd-v2*
README.md CHANGED
@@ -4,6 +4,9 @@ modified from https://github.com/KokeCacao/mvdream-hf.
 
 ### convert weights
 ```bash
+# dependency
+pip install -U omegaconf diffusers safetensors huggingface_hub transformers accelerate
+
 # download original ckpt
 wget https://huggingface.co/MVDream/MVDream/resolve/main/sd-v2.1-base-4view.pt
 wget https://raw.githubusercontent.com/bytedance/MVDream/main/mvdream/configs/sd-v2-base.yaml
convert_mvdream_to_diffusers.py CHANGED
@@ -405,6 +405,12 @@ def convert_from_original_mvdream_ckpt(checkpoint_path, original_config_file, de
     # )
     # print(f"Unet Config: {original_config.model.params.unet_config.params}")
     unet_config = create_unet_config(original_config)
+
+    # remove unused configs
+    del unet_config['legacy']
+    del unet_config['use_linear_in_transformer']
+    del unet_config['use_spatial_transformer']
+
     unet = MultiViewUNetModel(**unet_config)
     unet.register_to_config(**unet_config)
     # print(f"Unet State Dict: {unet.state_dict().keys()}")
mvdream/attention.py CHANGED
@@ -1,6 +1,3 @@
-# obtained and modified from https://github.com/bytedance/MVDream
-
-import math
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
@@ -14,9 +11,9 @@ from .util import checkpoint, zero_module
 try:
     import xformers  # type: ignore
     import xformers.ops  # type: ignore
-
     XFORMERS_IS_AVAILBLE = True
 except:
+    print(f'[WARN] xformers is unavailable!')
     XFORMERS_IS_AVAILBLE = False
 
 # CrossAttn precision handling
mvdream/models.py CHANGED
@@ -1,6 +1,3 @@
-# obtained and modified from https://github.com/bytedance/MVDream
-
-import math
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
@@ -9,7 +6,6 @@ from diffusers.models.modeling_utils import ModelMixin
 from typing import Any, List, Optional
 from torch import Tensor
 
-from abc import abstractmethod
 from .util import (
     checkpoint,
     conv_nd,
@@ -19,19 +15,8 @@ from .util import (
 )
 from .attention import SpatialTransformer, SpatialTransformer3D
 
-class TimestepBlock(nn.Module):
-    """
-    Any module where forward() takes timestep embeddings as a second argument.
-    """
 
-    @abstractmethod
-    def forward(self, x, emb):
-        """
-        Apply the module to `x` given `emb` timestep embeddings.
-        """
-
-
-class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
+class CondSequential(nn.Sequential):
     """
     A sequential module that passes timestep embeddings to the children that
     support it as an extra input.
@@ -39,7 +24,7 @@ class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
 
     def forward(self, x, emb, context=None, num_frames=1):
         for layer in self:
-            if isinstance(layer, TimestepBlock):
+            if isinstance(layer, ResBlock):
                 x = layer(x, emb)
             elif isinstance(layer, SpatialTransformer3D):
                 x = layer(x, context, num_frames=num_frames)
@@ -117,7 +102,7 @@ class Downsample(nn.Module):
         return self.op(x)
 
 
-class ResBlock(TimestepBlock):
+class ResBlock(nn.Module):
     """
     A residual block that can optionally change the number of channels.
     :param channels: the number of input channels.
@@ -289,6 +274,7 @@ class MultiViewUNetModel(ModelMixin, ConfigMixin):
         disable_middle_self_attn=False,
         adm_in_channels=None,
         camera_dim=None,
+        **kwargs,
     ):
         super().__init__()
         assert context_dim is not None
@@ -383,7 +369,7 @@ class MultiViewUNetModel(ModelMixin, ConfigMixin):
 
         self.input_blocks = nn.ModuleList(
             [
-                TimestepEmbedSequential(
+                CondSequential(
                     conv_nd(dims, in_channels, model_channels, 3, padding=1)
                 )
             ]
@@ -430,13 +416,13 @@ class MultiViewUNetModel(ModelMixin, ConfigMixin):
                         use_checkpoint=use_checkpoint,
                     )
                 )
-                self.input_blocks.append(TimestepEmbedSequential(*layers))
+                self.input_blocks.append(CondSequential(*layers))
                 self._feature_size += ch
                 input_block_chans.append(ch)
             if level != len(channel_mult) - 1:
                 out_ch = ch
                 self.input_blocks.append(
-                    TimestepEmbedSequential(
+                    CondSequential(
                         ResBlock(
                             ch,
                             time_embed_dim,
@@ -464,7 +450,7 @@ class MultiViewUNetModel(ModelMixin, ConfigMixin):
             num_heads = ch // num_head_channels
             dim_head = num_head_channels
 
-        self.middle_block = TimestepEmbedSequential(
+        self.middle_block = CondSequential(
             ResBlock(
                 ch,
                 time_embed_dim,
@@ -550,7 +536,7 @@ class MultiViewUNetModel(ModelMixin, ConfigMixin):
                     else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
                 )
                 ds //= 2
-                self.output_blocks.append(TimestepEmbedSequential(*layers))
+                self.output_blocks.append(CondSequential(*layers))
                 self._feature_size += ch
 
         self.out = nn.Sequential(
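Pieced together from the hunks above, the renamed container reads roughly as below. Only the docstring and the first two `isinstance` branches are visible in this diff; the trailing fall-through and the `return` are assumed from the upstream `TimestepEmbedSequential` it replaces. Dispatching on the concrete `ResBlock` is what allows dropping the abstract `TimestepBlock` base class (and the `abstractmethod` import), while the new `**kwargs` on `MultiViewUNetModel.__init__` presumably absorbs any stray config keys the converter does not delete.

```python
class CondSequential(nn.Sequential):
    """
    A sequential module that passes timestep embeddings to the children that
    support it as an extra input.
    """

    def forward(self, x, emb, context=None, num_frames=1):
        for layer in self:
            if isinstance(layer, ResBlock):
                # dispatch on the concrete class, since the abstract
                # TimestepBlock base was removed
                x = layer(x, emb)
            elif isinstance(layer, SpatialTransformer3D):
                x = layer(x, context, num_frames=num_frames)
            else:
                # assumed fall-through, mirroring the upstream
                # TimestepEmbedSequential
                x = layer(x)
        return x
```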
mvdream/util.py CHANGED
@@ -1,12 +1,3 @@
-# adopted from
-# https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
-# and
-# https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
-# and
-# https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py
-#
-# thanks!
-
 import math
 import torch
 import torch.nn as nn