lllyasviel commited on
Commit
7386161
·
0 Parent(s):
Files changed (50) hide show
  1. .gitignore +144 -0
  2. entry.py +0 -0
  3. sd_xl_base.yaml +98 -0
  4. sd_xl_refiner.yaml +91 -0
  5. sgm/__init__.py +4 -0
  6. sgm/data/__init__.py +1 -0
  7. sgm/data/cifar10.py +67 -0
  8. sgm/data/dataset.py +80 -0
  9. sgm/data/mnist.py +85 -0
  10. sgm/inference/api.py +388 -0
  11. sgm/inference/helpers.py +305 -0
  12. sgm/lr_scheduler.py +135 -0
  13. sgm/models/__init__.py +2 -0
  14. sgm/models/autoencoder.py +335 -0
  15. sgm/models/diffusion.py +320 -0
  16. sgm/modules/__init__.py +6 -0
  17. sgm/modules/attention.py +633 -0
  18. sgm/modules/autoencoding/__init__.py +0 -0
  19. sgm/modules/autoencoding/losses/__init__.py +246 -0
  20. sgm/modules/autoencoding/lpips/__init__.py +0 -0
  21. sgm/modules/autoencoding/lpips/loss/.gitignore +1 -0
  22. sgm/modules/autoencoding/lpips/loss/LICENSE +23 -0
  23. sgm/modules/autoencoding/lpips/loss/__init__.py +0 -0
  24. sgm/modules/autoencoding/lpips/loss/lpips.py +147 -0
  25. sgm/modules/autoencoding/lpips/model/LICENSE +58 -0
  26. sgm/modules/autoencoding/lpips/model/__init__.py +0 -0
  27. sgm/modules/autoencoding/lpips/model/model.py +88 -0
  28. sgm/modules/autoencoding/lpips/util.py +128 -0
  29. sgm/modules/autoencoding/lpips/vqperceptual.py +17 -0
  30. sgm/modules/autoencoding/regularizers/__init__.py +53 -0
  31. sgm/modules/diffusionmodules/__init__.py +7 -0
  32. sgm/modules/diffusionmodules/denoiser.py +63 -0
  33. sgm/modules/diffusionmodules/denoiser_scaling.py +31 -0
  34. sgm/modules/diffusionmodules/denoiser_weighting.py +24 -0
  35. sgm/modules/diffusionmodules/discretizer.py +69 -0
  36. sgm/modules/diffusionmodules/guiders.py +53 -0
  37. sgm/modules/diffusionmodules/loss.py +69 -0
  38. sgm/modules/diffusionmodules/model.py +743 -0
  39. sgm/modules/diffusionmodules/openaimodel.py +1262 -0
  40. sgm/modules/diffusionmodules/sampling.py +365 -0
  41. sgm/modules/diffusionmodules/sampling_utils.py +48 -0
  42. sgm/modules/diffusionmodules/sigma_sampling.py +31 -0
  43. sgm/modules/diffusionmodules/util.py +308 -0
  44. sgm/modules/diffusionmodules/wrappers.py +34 -0
  45. sgm/modules/distributions/__init__.py +0 -0
  46. sgm/modules/distributions/distributions.py +102 -0
  47. sgm/modules/ema.py +86 -0
  48. sgm/modules/encoders/__init__.py +0 -0
  49. sgm/modules/encoders/modules.py +960 -0
  50. sgm/util.py +248 -0
.gitignore ADDED
@@ -0,0 +1,144 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ .idea/
2
+
3
+ training/
4
+ lightning_logs/
5
+ image_log/
6
+ result/
7
+ results/
8
+
9
+ *.pth
10
+ *.pt
11
+ *.ckpt
12
+ *.safetensors
13
+ *.mp4
14
+ *.avi
15
+
16
+ # Byte-compiled / optimized / DLL files
17
+ __pycache__/
18
+ *.py[cod]
19
+ *$py.class
20
+
21
+ # C extensions
22
+ *.so
23
+
24
+ # Distribution / packaging
25
+ .Python
26
+ build/
27
+ develop-eggs/
28
+ dist/
29
+ downloads/
30
+ eggs/
31
+ .eggs/
32
+ lib/
33
+ lib64/
34
+ parts/
35
+ sdist/
36
+ var/
37
+ wheels/
38
+ pip-wheel-metadata/
39
+ share/python-wheels/
40
+ *.egg-info/
41
+ .installed.cfg
42
+ *.egg
43
+ MANIFEST
44
+
45
+ # PyInstaller
46
+ # Usually these files are written by a python script from a template
47
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
48
+ *.manifest
49
+ *.spec
50
+
51
+ # Installer logs
52
+ pip-log.txt
53
+ pip-delete-this-directory.txt
54
+
55
+ # Unit test / coverage reports
56
+ htmlcov/
57
+ .tox/
58
+ .nox/
59
+ .coverage
60
+ .coverage.*
61
+ .cache
62
+ nosetests.xml
63
+ coverage.xml
64
+ *.cover
65
+ *.py,cover
66
+ .hypothesis/
67
+ .pytest_cache/
68
+
69
+ # Translations
70
+ *.mo
71
+ *.pot
72
+
73
+ # Django stuff:
74
+ *.log
75
+ local_settings.py
76
+ db.sqlite3
77
+ db.sqlite3-journal
78
+
79
+ # Flask stuff:
80
+ instance/
81
+ .webassets-cache
82
+
83
+ # Scrapy stuff:
84
+ .scrapy
85
+
86
+ # Sphinx documentation
87
+ docs/_build/
88
+
89
+ # PyBuilder
90
+ target/
91
+
92
+ # Jupyter Notebook
93
+ .ipynb_checkpoints
94
+
95
+ # IPython
96
+ profile_default/
97
+ ipython_config.py
98
+
99
+ # pyenv
100
+ .python-version
101
+
102
+ # pipenv
103
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
104
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
105
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
106
+ # install all needed dependencies.
107
+ #Pipfile.lock
108
+
109
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow
110
+ __pypackages__/
111
+
112
+ # Celery stuff
113
+ celerybeat-schedule
114
+ celerybeat.pid
115
+
116
+ # SageMath parsed files
117
+ *.sage.py
118
+
119
+ # Environments
120
+ .env
121
+ .venv
122
+ env/
123
+ venv/
124
+ ENV/
125
+ env.bak/
126
+ venv.bak/
127
+
128
+ # Spyder project settings
129
+ .spyderproject
130
+ .spyproject
131
+
132
+ # Rope project settings
133
+ .ropeproject
134
+
135
+ # mkdocs documentation
136
+ /site
137
+
138
+ # mypy
139
+ .mypy_cache/
140
+ .dmypy.json
141
+ dmypy.json
142
+
143
+ # Pyre type checker
144
+ .pyre/
entry.py ADDED
File without changes
sd_xl_base.yaml ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model:
2
+ target: sgm.models.diffusion.DiffusionEngine
3
+ params:
4
+ scale_factor: 0.13025
5
+ disable_first_stage_autocast: True
6
+
7
+ denoiser_config:
8
+ target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser
9
+ params:
10
+ num_idx: 1000
11
+
12
+ weighting_config:
13
+ target: sgm.modules.diffusionmodules.denoiser_weighting.EpsWeighting
14
+ scaling_config:
15
+ target: sgm.modules.diffusionmodules.denoiser_scaling.EpsScaling
16
+ discretization_config:
17
+ target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
18
+
19
+ network_config:
20
+ target: sgm.modules.diffusionmodules.openaimodel.UNetModel
21
+ params:
22
+ adm_in_channels: 2816
23
+ num_classes: sequential
24
+ use_checkpoint: True
25
+ in_channels: 4
26
+ out_channels: 4
27
+ model_channels: 320
28
+ attention_resolutions: [4, 2]
29
+ num_res_blocks: 2
30
+ channel_mult: [1, 2, 4]
31
+ num_head_channels: 64
32
+ use_spatial_transformer: True
33
+ use_linear_in_transformer: True
34
+ transformer_depth: [1, 2, 10] # note: the first is unused (due to attn_res starting at 2) 32, 16, 8 --> 64, 32, 16
35
+ context_dim: 2048
36
+ spatial_transformer_attn_type: softmax-xformers
37
+ legacy: False
38
+
39
+ conditioner_config:
40
+ target: sgm.modules.GeneralConditioner
41
+ params:
42
+ emb_models:
43
+ # crossattn cond
44
+ - is_trainable: False
45
+ input_key: txt
46
+ target: sgm.modules.encoders.modules.FrozenCLIPEmbedder
47
+ params:
48
+ layer: hidden
49
+ layer_idx: 11
50
+ # crossattn and vector cond
51
+ - is_trainable: False
52
+ input_key: txt
53
+ target: sgm.modules.encoders.modules.FrozenOpenCLIPEmbedder2
54
+ params:
55
+ arch: ViT-bigG-14
56
+ version: laion2b_s39b_b160k
57
+ freeze: True
58
+ layer: penultimate
59
+ always_return_pooled: True
60
+ legacy: False
61
+ # vector cond
62
+ - is_trainable: False
63
+ input_key: original_size_as_tuple
64
+ target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
65
+ params:
66
+ outdim: 256 # multiplied by two
67
+ # vector cond
68
+ - is_trainable: False
69
+ input_key: crop_coords_top_left
70
+ target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
71
+ params:
72
+ outdim: 256 # multiplied by two
73
+ # vector cond
74
+ - is_trainable: False
75
+ input_key: target_size_as_tuple
76
+ target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
77
+ params:
78
+ outdim: 256 # multiplied by two
79
+
80
+ first_stage_config:
81
+ target: sgm.models.autoencoder.AutoencoderKLInferenceWrapper
82
+ params:
83
+ embed_dim: 4
84
+ monitor: val/rec_loss
85
+ ddconfig:
86
+ attn_type: vanilla-xformers
87
+ double_z: true
88
+ z_channels: 4
89
+ resolution: 256
90
+ in_channels: 3
91
+ out_ch: 3
92
+ ch: 128
93
+ ch_mult: [1, 2, 4, 4]
94
+ num_res_blocks: 2
95
+ attn_resolutions: []
96
+ dropout: 0.0
97
+ lossconfig:
98
+ target: torch.nn.Identity
sd_xl_refiner.yaml ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model:
2
+ target: sgm.models.diffusion.DiffusionEngine
3
+ params:
4
+ scale_factor: 0.13025
5
+ disable_first_stage_autocast: True
6
+
7
+ denoiser_config:
8
+ target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser
9
+ params:
10
+ num_idx: 1000
11
+
12
+ weighting_config:
13
+ target: sgm.modules.diffusionmodules.denoiser_weighting.EpsWeighting
14
+ scaling_config:
15
+ target: sgm.modules.diffusionmodules.denoiser_scaling.EpsScaling
16
+ discretization_config:
17
+ target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
18
+
19
+ network_config:
20
+ target: sgm.modules.diffusionmodules.openaimodel.UNetModel
21
+ params:
22
+ adm_in_channels: 2560
23
+ num_classes: sequential
24
+ use_checkpoint: True
25
+ in_channels: 4
26
+ out_channels: 4
27
+ model_channels: 384
28
+ attention_resolutions: [4, 2]
29
+ num_res_blocks: 2
30
+ channel_mult: [1, 2, 4, 4]
31
+ num_head_channels: 64
32
+ use_spatial_transformer: True
33
+ use_linear_in_transformer: True
34
+ transformer_depth: 4
35
+ context_dim: [1280, 1280, 1280, 1280] # 1280
36
+ spatial_transformer_attn_type: softmax-xformers
37
+ legacy: False
38
+
39
+ conditioner_config:
40
+ target: sgm.modules.GeneralConditioner
41
+ params:
42
+ emb_models:
43
+ # crossattn and vector cond
44
+ - is_trainable: False
45
+ input_key: txt
46
+ target: sgm.modules.encoders.modules.FrozenOpenCLIPEmbedder2
47
+ params:
48
+ arch: ViT-bigG-14
49
+ version: laion2b_s39b_b160k
50
+ legacy: False
51
+ freeze: True
52
+ layer: penultimate
53
+ always_return_pooled: True
54
+ # vector cond
55
+ - is_trainable: False
56
+ input_key: original_size_as_tuple
57
+ target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
58
+ params:
59
+ outdim: 256 # multiplied by two
60
+ # vector cond
61
+ - is_trainable: False
62
+ input_key: crop_coords_top_left
63
+ target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
64
+ params:
65
+ outdim: 256 # multiplied by two
66
+ # vector cond
67
+ - is_trainable: False
68
+ input_key: aesthetic_score
69
+ target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
70
+ params:
71
+ outdim: 256 # multiplied by one
72
+
73
+ first_stage_config:
74
+ target: sgm.models.autoencoder.AutoencoderKLInferenceWrapper
75
+ params:
76
+ embed_dim: 4
77
+ monitor: val/rec_loss
78
+ ddconfig:
79
+ attn_type: vanilla-xformers
80
+ double_z: true
81
+ z_channels: 4
82
+ resolution: 256
83
+ in_channels: 3
84
+ out_ch: 3
85
+ ch: 128
86
+ ch_mult: [1, 2, 4, 4]
87
+ num_res_blocks: 2
88
+ attn_resolutions: []
89
+ dropout: 0.0
90
+ lossconfig:
91
+ target: torch.nn.Identity
sgm/__init__.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ from .models import AutoencodingEngine, DiffusionEngine
2
+ from .util import get_configs_path, instantiate_from_config
3
+
4
+ __version__ = "0.1.0"
sgm/data/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .dataset import StableDataModuleFromConfig
sgm/data/cifar10.py ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pytorch_lightning as pl
2
+ import torchvision
3
+ from torch.utils.data import DataLoader, Dataset
4
+ from torchvision import transforms
5
+
6
+
7
+ class CIFAR10DataDictWrapper(Dataset):
8
+ def __init__(self, dset):
9
+ super().__init__()
10
+ self.dset = dset
11
+
12
+ def __getitem__(self, i):
13
+ x, y = self.dset[i]
14
+ return {"jpg": x, "cls": y}
15
+
16
+ def __len__(self):
17
+ return len(self.dset)
18
+
19
+
20
+ class CIFAR10Loader(pl.LightningDataModule):
21
+ def __init__(self, batch_size, num_workers=0, shuffle=True):
22
+ super().__init__()
23
+
24
+ transform = transforms.Compose(
25
+ [transforms.ToTensor(), transforms.Lambda(lambda x: x * 2.0 - 1.0)]
26
+ )
27
+
28
+ self.batch_size = batch_size
29
+ self.num_workers = num_workers
30
+ self.shuffle = shuffle
31
+ self.train_dataset = CIFAR10DataDictWrapper(
32
+ torchvision.datasets.CIFAR10(
33
+ root=".data/", train=True, download=True, transform=transform
34
+ )
35
+ )
36
+ self.test_dataset = CIFAR10DataDictWrapper(
37
+ torchvision.datasets.CIFAR10(
38
+ root=".data/", train=False, download=True, transform=transform
39
+ )
40
+ )
41
+
42
+ def prepare_data(self):
43
+ pass
44
+
45
+ def train_dataloader(self):
46
+ return DataLoader(
47
+ self.train_dataset,
48
+ batch_size=self.batch_size,
49
+ shuffle=self.shuffle,
50
+ num_workers=self.num_workers,
51
+ )
52
+
53
+ def test_dataloader(self):
54
+ return DataLoader(
55
+ self.test_dataset,
56
+ batch_size=self.batch_size,
57
+ shuffle=self.shuffle,
58
+ num_workers=self.num_workers,
59
+ )
60
+
61
+ def val_dataloader(self):
62
+ return DataLoader(
63
+ self.test_dataset,
64
+ batch_size=self.batch_size,
65
+ shuffle=self.shuffle,
66
+ num_workers=self.num_workers,
67
+ )
sgm/data/dataset.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional
2
+
3
+ import torchdata.datapipes.iter
4
+ import webdataset as wds
5
+ from omegaconf import DictConfig
6
+ from pytorch_lightning import LightningDataModule
7
+
8
+ try:
9
+ from sdata import create_dataset, create_dummy_dataset, create_loader
10
+ except ImportError as e:
11
+ print("#" * 100)
12
+ print("Datasets not yet available")
13
+ print("to enable, we need to add stable-datasets as a submodule")
14
+ print("please use ``git submodule update --init --recursive``")
15
+ print("and do ``pip install -e stable-datasets/`` from the root of this repo")
16
+ print("#" * 100)
17
+ exit(1)
18
+
19
+
20
+ class StableDataModuleFromConfig(LightningDataModule):
21
+ def __init__(
22
+ self,
23
+ train: DictConfig,
24
+ validation: Optional[DictConfig] = None,
25
+ test: Optional[DictConfig] = None,
26
+ skip_val_loader: bool = False,
27
+ dummy: bool = False,
28
+ ):
29
+ super().__init__()
30
+ self.train_config = train
31
+ assert (
32
+ "datapipeline" in self.train_config and "loader" in self.train_config
33
+ ), "train config requires the fields `datapipeline` and `loader`"
34
+
35
+ self.val_config = validation
36
+ if not skip_val_loader:
37
+ if self.val_config is not None:
38
+ assert (
39
+ "datapipeline" in self.val_config and "loader" in self.val_config
40
+ ), "validation config requires the fields `datapipeline` and `loader`"
41
+ else:
42
+ print(
43
+ "Warning: No Validation datapipeline defined, using that one from training"
44
+ )
45
+ self.val_config = train
46
+
47
+ self.test_config = test
48
+ if self.test_config is not None:
49
+ assert (
50
+ "datapipeline" in self.test_config and "loader" in self.test_config
51
+ ), "test config requires the fields `datapipeline` and `loader`"
52
+
53
+ self.dummy = dummy
54
+ if self.dummy:
55
+ print("#" * 100)
56
+ print("USING DUMMY DATASET: HOPE YOU'RE DEBUGGING ;)")
57
+ print("#" * 100)
58
+
59
+ def setup(self, stage: str) -> None:
60
+ print("Preparing datasets")
61
+ if self.dummy:
62
+ data_fn = create_dummy_dataset
63
+ else:
64
+ data_fn = create_dataset
65
+
66
+ self.train_datapipeline = data_fn(**self.train_config.datapipeline)
67
+ if self.val_config:
68
+ self.val_datapipeline = data_fn(**self.val_config.datapipeline)
69
+ if self.test_config:
70
+ self.test_datapipeline = data_fn(**self.test_config.datapipeline)
71
+
72
+ def train_dataloader(self) -> torchdata.datapipes.iter.IterDataPipe:
73
+ loader = create_loader(self.train_datapipeline, **self.train_config.loader)
74
+ return loader
75
+
76
+ def val_dataloader(self) -> wds.DataPipeline:
77
+ return create_loader(self.val_datapipeline, **self.val_config.loader)
78
+
79
+ def test_dataloader(self) -> wds.DataPipeline:
80
+ return create_loader(self.test_datapipeline, **self.test_config.loader)
sgm/data/mnist.py ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pytorch_lightning as pl
2
+ import torchvision
3
+ from torch.utils.data import DataLoader, Dataset
4
+ from torchvision import transforms
5
+
6
+
7
+ class MNISTDataDictWrapper(Dataset):
8
+ def __init__(self, dset):
9
+ super().__init__()
10
+ self.dset = dset
11
+
12
+ def __getitem__(self, i):
13
+ x, y = self.dset[i]
14
+ return {"jpg": x, "cls": y}
15
+
16
+ def __len__(self):
17
+ return len(self.dset)
18
+
19
+
20
+ class MNISTLoader(pl.LightningDataModule):
21
+ def __init__(self, batch_size, num_workers=0, prefetch_factor=2, shuffle=True):
22
+ super().__init__()
23
+
24
+ transform = transforms.Compose(
25
+ [transforms.ToTensor(), transforms.Lambda(lambda x: x * 2.0 - 1.0)]
26
+ )
27
+
28
+ self.batch_size = batch_size
29
+ self.num_workers = num_workers
30
+ self.prefetch_factor = prefetch_factor if num_workers > 0 else 0
31
+ self.shuffle = shuffle
32
+ self.train_dataset = MNISTDataDictWrapper(
33
+ torchvision.datasets.MNIST(
34
+ root=".data/", train=True, download=True, transform=transform
35
+ )
36
+ )
37
+ self.test_dataset = MNISTDataDictWrapper(
38
+ torchvision.datasets.MNIST(
39
+ root=".data/", train=False, download=True, transform=transform
40
+ )
41
+ )
42
+
43
+ def prepare_data(self):
44
+ pass
45
+
46
+ def train_dataloader(self):
47
+ return DataLoader(
48
+ self.train_dataset,
49
+ batch_size=self.batch_size,
50
+ shuffle=self.shuffle,
51
+ num_workers=self.num_workers,
52
+ prefetch_factor=self.prefetch_factor,
53
+ )
54
+
55
+ def test_dataloader(self):
56
+ return DataLoader(
57
+ self.test_dataset,
58
+ batch_size=self.batch_size,
59
+ shuffle=self.shuffle,
60
+ num_workers=self.num_workers,
61
+ prefetch_factor=self.prefetch_factor,
62
+ )
63
+
64
+ def val_dataloader(self):
65
+ return DataLoader(
66
+ self.test_dataset,
67
+ batch_size=self.batch_size,
68
+ shuffle=self.shuffle,
69
+ num_workers=self.num_workers,
70
+ prefetch_factor=self.prefetch_factor,
71
+ )
72
+
73
+
74
+ if __name__ == "__main__":
75
+ dset = MNISTDataDictWrapper(
76
+ torchvision.datasets.MNIST(
77
+ root=".data/",
78
+ train=False,
79
+ download=True,
80
+ transform=transforms.Compose(
81
+ [transforms.ToTensor(), transforms.Lambda(lambda x: x * 2.0 - 1.0)]
82
+ ),
83
+ )
84
+ )
85
+ ex = dset[0]
sgm/inference/api.py ADDED
@@ -0,0 +1,388 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass, asdict
2
+ from enum import Enum
3
+ from omegaconf import OmegaConf
4
+ import pathlib
5
+ from sgm.inference.helpers import (
6
+ do_sample,
7
+ do_img2img,
8
+ Img2ImgDiscretizationWrapper,
9
+ )
10
+ from sgm.modules.diffusionmodules.sampling import (
11
+ EulerEDMSampler,
12
+ HeunEDMSampler,
13
+ EulerAncestralSampler,
14
+ DPMPP2SAncestralSampler,
15
+ DPMPP2MSampler,
16
+ LinearMultistepSampler,
17
+ )
18
+ from sgm.util import load_model_from_config
19
+ from typing import Optional
20
+
21
+
22
+ class ModelArchitecture(str, Enum):
23
+ SD_2_1 = "stable-diffusion-v2-1"
24
+ SD_2_1_768 = "stable-diffusion-v2-1-768"
25
+ SDXL_V0_9_BASE = "stable-diffusion-xl-v0-9-base"
26
+ SDXL_V0_9_REFINER = "stable-diffusion-xl-v0-9-refiner"
27
+ SDXL_V1_BASE = "stable-diffusion-xl-v1-base"
28
+ SDXL_V1_REFINER = "stable-diffusion-xl-v1-refiner"
29
+
30
+
31
+ class Sampler(str, Enum):
32
+ EULER_EDM = "EulerEDMSampler"
33
+ HEUN_EDM = "HeunEDMSampler"
34
+ EULER_ANCESTRAL = "EulerAncestralSampler"
35
+ DPMPP2S_ANCESTRAL = "DPMPP2SAncestralSampler"
36
+ DPMPP2M = "DPMPP2MSampler"
37
+ LINEAR_MULTISTEP = "LinearMultistepSampler"
38
+
39
+
40
+ class Discretization(str, Enum):
41
+ LEGACY_DDPM = "LegacyDDPMDiscretization"
42
+ EDM = "EDMDiscretization"
43
+
44
+
45
+ class Guider(str, Enum):
46
+ VANILLA = "VanillaCFG"
47
+ IDENTITY = "IdentityGuider"
48
+
49
+
50
+ class Thresholder(str, Enum):
51
+ NONE = "None"
52
+
53
+
54
+ @dataclass
55
+ class SamplingParams:
56
+ width: int = 1024
57
+ height: int = 1024
58
+ steps: int = 50
59
+ sampler: Sampler = Sampler.DPMPP2M
60
+ discretization: Discretization = Discretization.LEGACY_DDPM
61
+ guider: Guider = Guider.VANILLA
62
+ thresholder: Thresholder = Thresholder.NONE
63
+ scale: float = 6.0
64
+ aesthetic_score: float = 5.0
65
+ negative_aesthetic_score: float = 5.0
66
+ img2img_strength: float = 1.0
67
+ orig_width: int = 1024
68
+ orig_height: int = 1024
69
+ crop_coords_top: int = 0
70
+ crop_coords_left: int = 0
71
+ sigma_min: float = 0.0292
72
+ sigma_max: float = 14.6146
73
+ rho: float = 3.0
74
+ s_churn: float = 0.0
75
+ s_tmin: float = 0.0
76
+ s_tmax: float = 999.0
77
+ s_noise: float = 1.0
78
+ eta: float = 1.0
79
+ order: int = 4
80
+
81
+
82
+ @dataclass
83
+ class SamplingSpec:
84
+ width: int
85
+ height: int
86
+ channels: int
87
+ factor: int
88
+ is_legacy: bool
89
+ config: str
90
+ ckpt: str
91
+ is_guided: bool
92
+
93
+
94
+ model_specs = {
95
+ ModelArchitecture.SD_2_1: SamplingSpec(
96
+ height=512,
97
+ width=512,
98
+ channels=4,
99
+ factor=8,
100
+ is_legacy=True,
101
+ config="sd_2_1.yaml",
102
+ ckpt="v2-1_512-ema-pruned.safetensors",
103
+ is_guided=True,
104
+ ),
105
+ ModelArchitecture.SD_2_1_768: SamplingSpec(
106
+ height=768,
107
+ width=768,
108
+ channels=4,
109
+ factor=8,
110
+ is_legacy=True,
111
+ config="sd_2_1_768.yaml",
112
+ ckpt="v2-1_768-ema-pruned.safetensors",
113
+ is_guided=True,
114
+ ),
115
+ ModelArchitecture.SDXL_V0_9_BASE: SamplingSpec(
116
+ height=1024,
117
+ width=1024,
118
+ channels=4,
119
+ factor=8,
120
+ is_legacy=False,
121
+ config="sd_xl_base.yaml",
122
+ ckpt="sd_xl_base_0.9.safetensors",
123
+ is_guided=True,
124
+ ),
125
+ ModelArchitecture.SDXL_V0_9_REFINER: SamplingSpec(
126
+ height=1024,
127
+ width=1024,
128
+ channels=4,
129
+ factor=8,
130
+ is_legacy=True,
131
+ config="sd_xl_refiner.yaml",
132
+ ckpt="sd_xl_refiner_0.9.safetensors",
133
+ is_guided=True,
134
+ ),
135
+ ModelArchitecture.SDXL_V1_BASE: SamplingSpec(
136
+ height=1024,
137
+ width=1024,
138
+ channels=4,
139
+ factor=8,
140
+ is_legacy=False,
141
+ config="sd_xl_base.yaml",
142
+ ckpt="sd_xl_base_1.0.safetensors",
143
+ is_guided=True,
144
+ ),
145
+ ModelArchitecture.SDXL_V1_REFINER: SamplingSpec(
146
+ height=1024,
147
+ width=1024,
148
+ channels=4,
149
+ factor=8,
150
+ is_legacy=True,
151
+ config="sd_xl_refiner.yaml",
152
+ ckpt="sd_xl_refiner_1.0.safetensors",
153
+ is_guided=True,
154
+ ),
155
+ }
156
+
157
+
158
+ class SamplingPipeline:
159
+ def __init__(
160
+ self,
161
+ model_id: ModelArchitecture,
162
+ model_path="checkpoints",
163
+ config_path="configs/inference",
164
+ device="cuda",
165
+ use_fp16=True,
166
+ ) -> None:
167
+ if model_id not in model_specs:
168
+ raise ValueError(f"Model {model_id} not supported")
169
+ self.model_id = model_id
170
+ self.specs = model_specs[self.model_id]
171
+ self.config = str(pathlib.Path(config_path, self.specs.config))
172
+ self.ckpt = str(pathlib.Path(model_path, self.specs.ckpt))
173
+ self.device = device
174
+ self.model = self._load_model(device=device, use_fp16=use_fp16)
175
+
176
+ def _load_model(self, device="cuda", use_fp16=True):
177
+ config = OmegaConf.load(self.config)
178
+ model = load_model_from_config(config, self.ckpt)
179
+ if model is None:
180
+ raise ValueError(f"Model {self.model_id} could not be loaded")
181
+ model.to(device)
182
+ if use_fp16:
183
+ model.conditioner.half()
184
+ model.model.half()
185
+ return model
186
+
187
+ def text_to_image(
188
+ self,
189
+ params: SamplingParams,
190
+ prompt: str,
191
+ negative_prompt: str = "",
192
+ samples: int = 1,
193
+ return_latents: bool = False,
194
+ ):
195
+ sampler = get_sampler_config(params)
196
+ value_dict = asdict(params)
197
+ value_dict["prompt"] = prompt
198
+ value_dict["negative_prompt"] = negative_prompt
199
+ value_dict["target_width"] = params.width
200
+ value_dict["target_height"] = params.height
201
+ return do_sample(
202
+ self.model,
203
+ sampler,
204
+ value_dict,
205
+ samples,
206
+ params.height,
207
+ params.width,
208
+ self.specs.channels,
209
+ self.specs.factor,
210
+ force_uc_zero_embeddings=["txt"] if not self.specs.is_legacy else [],
211
+ return_latents=return_latents,
212
+ filter=None,
213
+ )
214
+
215
+ def image_to_image(
216
+ self,
217
+ params: SamplingParams,
218
+ image,
219
+ prompt: str,
220
+ negative_prompt: str = "",
221
+ samples: int = 1,
222
+ return_latents: bool = False,
223
+ ):
224
+ sampler = get_sampler_config(params)
225
+
226
+ if params.img2img_strength < 1.0:
227
+ sampler.discretization = Img2ImgDiscretizationWrapper(
228
+ sampler.discretization,
229
+ strength=params.img2img_strength,
230
+ )
231
+ height, width = image.shape[2], image.shape[3]
232
+ value_dict = asdict(params)
233
+ value_dict["prompt"] = prompt
234
+ value_dict["negative_prompt"] = negative_prompt
235
+ value_dict["target_width"] = width
236
+ value_dict["target_height"] = height
237
+ return do_img2img(
238
+ image,
239
+ self.model,
240
+ sampler,
241
+ value_dict,
242
+ samples,
243
+ force_uc_zero_embeddings=["txt"] if not self.specs.is_legacy else [],
244
+ return_latents=return_latents,
245
+ filter=None,
246
+ )
247
+
248
+ def refiner(
249
+ self,
250
+ params: SamplingParams,
251
+ image,
252
+ prompt: str,
253
+ negative_prompt: Optional[str] = None,
254
+ samples: int = 1,
255
+ return_latents: bool = False,
256
+ ):
257
+ sampler = get_sampler_config(params)
258
+ value_dict = {
259
+ "orig_width": image.shape[3] * 8,
260
+ "orig_height": image.shape[2] * 8,
261
+ "target_width": image.shape[3] * 8,
262
+ "target_height": image.shape[2] * 8,
263
+ "prompt": prompt,
264
+ "negative_prompt": negative_prompt,
265
+ "crop_coords_top": 0,
266
+ "crop_coords_left": 0,
267
+ "aesthetic_score": 6.0,
268
+ "negative_aesthetic_score": 2.5,
269
+ }
270
+
271
+ return do_img2img(
272
+ image,
273
+ self.model,
274
+ sampler,
275
+ value_dict,
276
+ samples,
277
+ skip_encode=True,
278
+ return_latents=return_latents,
279
+ filter=None,
280
+ )
281
+
282
+
283
+ def get_guider_config(params: SamplingParams):
284
+ if params.guider == Guider.IDENTITY:
285
+ guider_config = {
286
+ "target": "sgm.modules.diffusionmodules.guiders.IdentityGuider"
287
+ }
288
+ elif params.guider == Guider.VANILLA:
289
+ scale = params.scale
290
+
291
+ thresholder = params.thresholder
292
+
293
+ if thresholder == Thresholder.NONE:
294
+ dyn_thresh_config = {
295
+ "target": "sgm.modules.diffusionmodules.sampling_utils.NoDynamicThresholding"
296
+ }
297
+ else:
298
+ raise NotImplementedError
299
+
300
+ guider_config = {
301
+ "target": "sgm.modules.diffusionmodules.guiders.VanillaCFG",
302
+ "params": {"scale": scale, "dyn_thresh_config": dyn_thresh_config},
303
+ }
304
+ else:
305
+ raise NotImplementedError
306
+ return guider_config
307
+
308
+
309
+ def get_discretization_config(params: SamplingParams):
310
+ if params.discretization == Discretization.LEGACY_DDPM:
311
+ discretization_config = {
312
+ "target": "sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization",
313
+ }
314
+ elif params.discretization == Discretization.EDM:
315
+ discretization_config = {
316
+ "target": "sgm.modules.diffusionmodules.discretizer.EDMDiscretization",
317
+ "params": {
318
+ "sigma_min": params.sigma_min,
319
+ "sigma_max": params.sigma_max,
320
+ "rho": params.rho,
321
+ },
322
+ }
323
+ else:
324
+ raise ValueError(f"unknown discretization {params.discretization}")
325
+ return discretization_config
326
+
327
+
328
+ def get_sampler_config(params: SamplingParams):
329
+ discretization_config = get_discretization_config(params)
330
+ guider_config = get_guider_config(params)
331
+ sampler = None
332
+ if params.sampler == Sampler.EULER_EDM:
333
+ return EulerEDMSampler(
334
+ num_steps=params.steps,
335
+ discretization_config=discretization_config,
336
+ guider_config=guider_config,
337
+ s_churn=params.s_churn,
338
+ s_tmin=params.s_tmin,
339
+ s_tmax=params.s_tmax,
340
+ s_noise=params.s_noise,
341
+ verbose=True,
342
+ )
343
+ if params.sampler == Sampler.HEUN_EDM:
344
+ return HeunEDMSampler(
345
+ num_steps=params.steps,
346
+ discretization_config=discretization_config,
347
+ guider_config=guider_config,
348
+ s_churn=params.s_churn,
349
+ s_tmin=params.s_tmin,
350
+ s_tmax=params.s_tmax,
351
+ s_noise=params.s_noise,
352
+ verbose=True,
353
+ )
354
+ if params.sampler == Sampler.EULER_ANCESTRAL:
355
+ return EulerAncestralSampler(
356
+ num_steps=params.steps,
357
+ discretization_config=discretization_config,
358
+ guider_config=guider_config,
359
+ eta=params.eta,
360
+ s_noise=params.s_noise,
361
+ verbose=True,
362
+ )
363
+ if params.sampler == Sampler.DPMPP2S_ANCESTRAL:
364
+ return DPMPP2SAncestralSampler(
365
+ num_steps=params.steps,
366
+ discretization_config=discretization_config,
367
+ guider_config=guider_config,
368
+ eta=params.eta,
369
+ s_noise=params.s_noise,
370
+ verbose=True,
371
+ )
372
+ if params.sampler == Sampler.DPMPP2M:
373
+ return DPMPP2MSampler(
374
+ num_steps=params.steps,
375
+ discretization_config=discretization_config,
376
+ guider_config=guider_config,
377
+ verbose=True,
378
+ )
379
+ if params.sampler == Sampler.LINEAR_MULTISTEP:
380
+ return LinearMultistepSampler(
381
+ num_steps=params.steps,
382
+ discretization_config=discretization_config,
383
+ guider_config=guider_config,
384
+ order=params.order,
385
+ verbose=True,
386
+ )
387
+
388
+ raise ValueError(f"unknown sampler {params.sampler}!")
sgm/inference/helpers.py ADDED
@@ -0,0 +1,305 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from typing import Union, List, Optional
3
+
4
+ import math
5
+ import numpy as np
6
+ import torch
7
+ from PIL import Image
8
+ from einops import rearrange
9
+ from imwatermark import WatermarkEncoder
10
+ from omegaconf import ListConfig
11
+ from torch import autocast
12
+
13
+ from sgm.util import append_dims
14
+
15
+
16
+ class WatermarkEmbedder:
17
+ def __init__(self, watermark):
18
+ self.watermark = watermark
19
+ self.num_bits = len(WATERMARK_BITS)
20
+ self.encoder = WatermarkEncoder()
21
+ self.encoder.set_watermark("bits", self.watermark)
22
+
23
+ def __call__(self, image: torch.Tensor):
24
+ """
25
+ Adds a predefined watermark to the input image
26
+
27
+ Args:
28
+ image: ([N,] B, C, H, W) in range [0, 1]
29
+
30
+ Returns:
31
+ same as input but watermarked
32
+ """
33
+ # watermarking libary expects input as cv2 BGR format
34
+ squeeze = len(image.shape) == 4
35
+ if squeeze:
36
+ image = image[None, ...]
37
+ n = image.shape[0]
38
+ image_np = rearrange(
39
+ (255 * image).detach().cpu(), "n b c h w -> (n b) h w c"
40
+ ).numpy()[:, :, :, ::-1]
41
+ # torch (b, c, h, w) in [0, 1] -> numpy (b, h, w, c) [0, 255]
42
+ for k in range(image_np.shape[0]):
43
+ image_np[k] = self.encoder.encode(image_np[k], "dwtDct")
44
+ image = torch.from_numpy(
45
+ rearrange(image_np[:, :, :, ::-1], "(n b) h w c -> n b c h w", n=n)
46
+ ).to(image.device)
47
+ image = torch.clamp(image / 255, min=0.0, max=1.0)
48
+ if squeeze:
49
+ image = image[0]
50
+ return image
51
+
52
+
53
+ # A fixed 48-bit message that was choosen at random
54
+ # WATERMARK_MESSAGE = 0xB3EC907BB19E
55
+ WATERMARK_MESSAGE = 0b101100111110110010010000011110111011000110011110
56
+ # bin(x)[2:] gives bits of x as str, use int to convert them to 0/1
57
+ WATERMARK_BITS = [int(bit) for bit in bin(WATERMARK_MESSAGE)[2:]]
58
+ embed_watermark = WatermarkEmbedder(WATERMARK_BITS)
59
+
60
+
61
+ def get_unique_embedder_keys_from_conditioner(conditioner):
62
+ return list({x.input_key for x in conditioner.embedders})
63
+
64
+
65
+ def perform_save_locally(save_path, samples):
66
+ os.makedirs(os.path.join(save_path), exist_ok=True)
67
+ base_count = len(os.listdir(os.path.join(save_path)))
68
+ samples = embed_watermark(samples)
69
+ for sample in samples:
70
+ sample = 255.0 * rearrange(sample.cpu().numpy(), "c h w -> h w c")
71
+ Image.fromarray(sample.astype(np.uint8)).save(
72
+ os.path.join(save_path, f"{base_count:09}.png")
73
+ )
74
+ base_count += 1
75
+
76
+
77
+ class Img2ImgDiscretizationWrapper:
78
+ """
79
+ wraps a discretizer, and prunes the sigmas
80
+ params:
81
+ strength: float between 0.0 and 1.0. 1.0 means full sampling (all sigmas are returned)
82
+ """
83
+
84
+ def __init__(self, discretization, strength: float = 1.0):
85
+ self.discretization = discretization
86
+ self.strength = strength
87
+ assert 0.0 <= self.strength <= 1.0
88
+
89
+ def __call__(self, *args, **kwargs):
90
+ # sigmas start large first, and decrease then
91
+ sigmas = self.discretization(*args, **kwargs)
92
+ print(f"sigmas after discretization, before pruning img2img: ", sigmas)
93
+ sigmas = torch.flip(sigmas, (0,))
94
+ sigmas = sigmas[: max(int(self.strength * len(sigmas)), 1)]
95
+ print("prune index:", max(int(self.strength * len(sigmas)), 1))
96
+ sigmas = torch.flip(sigmas, (0,))
97
+ print(f"sigmas after pruning: ", sigmas)
98
+ return sigmas
99
+
100
+
101
+ def do_sample(
102
+ model,
103
+ sampler,
104
+ value_dict,
105
+ num_samples,
106
+ H,
107
+ W,
108
+ C,
109
+ F,
110
+ force_uc_zero_embeddings: Optional[List] = None,
111
+ batch2model_input: Optional[List] = None,
112
+ return_latents=False,
113
+ filter=None,
114
+ device="cuda",
115
+ ):
116
+ if force_uc_zero_embeddings is None:
117
+ force_uc_zero_embeddings = []
118
+ if batch2model_input is None:
119
+ batch2model_input = []
120
+
121
+ with torch.no_grad():
122
+ with autocast(device) as precision_scope:
123
+ with model.ema_scope():
124
+ num_samples = [num_samples]
125
+ batch, batch_uc = get_batch(
126
+ get_unique_embedder_keys_from_conditioner(model.conditioner),
127
+ value_dict,
128
+ num_samples,
129
+ )
130
+ for key in batch:
131
+ if isinstance(batch[key], torch.Tensor):
132
+ print(key, batch[key].shape)
133
+ elif isinstance(batch[key], list):
134
+ print(key, [len(l) for l in batch[key]])
135
+ else:
136
+ print(key, batch[key])
137
+ c, uc = model.conditioner.get_unconditional_conditioning(
138
+ batch,
139
+ batch_uc=batch_uc,
140
+ force_uc_zero_embeddings=force_uc_zero_embeddings,
141
+ )
142
+
143
+ for k in c:
144
+ if not k == "crossattn":
145
+ c[k], uc[k] = map(
146
+ lambda y: y[k][: math.prod(num_samples)].to(device), (c, uc)
147
+ )
148
+
149
+ additional_model_inputs = {}
150
+ for k in batch2model_input:
151
+ additional_model_inputs[k] = batch[k]
152
+
153
+ shape = (math.prod(num_samples), C, H // F, W // F)
154
+ randn = torch.randn(shape).to(device)
155
+
156
+ def denoiser(input, sigma, c):
157
+ return model.denoiser(
158
+ model.model, input, sigma, c, **additional_model_inputs
159
+ )
160
+
161
+ samples_z = sampler(denoiser, randn, cond=c, uc=uc)
162
+ samples_x = model.decode_first_stage(samples_z)
163
+ samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0)
164
+
165
+ if filter is not None:
166
+ samples = filter(samples)
167
+
168
+ if return_latents:
169
+ return samples, samples_z
170
+ return samples
171
+
172
+
173
+ def get_batch(keys, value_dict, N: Union[List, ListConfig], device="cuda"):
174
+ # Hardcoded demo setups; might undergo some changes in the future
175
+
176
+ batch = {}
177
+ batch_uc = {}
178
+
179
+ for key in keys:
180
+ if key == "txt":
181
+ batch["txt"] = (
182
+ np.repeat([value_dict["prompt"]], repeats=math.prod(N))
183
+ .reshape(N)
184
+ .tolist()
185
+ )
186
+ batch_uc["txt"] = (
187
+ np.repeat([value_dict["negative_prompt"]], repeats=math.prod(N))
188
+ .reshape(N)
189
+ .tolist()
190
+ )
191
+ elif key == "original_size_as_tuple":
192
+ batch["original_size_as_tuple"] = (
193
+ torch.tensor([value_dict["orig_height"], value_dict["orig_width"]])
194
+ .to(device)
195
+ .repeat(*N, 1)
196
+ )
197
+ elif key == "crop_coords_top_left":
198
+ batch["crop_coords_top_left"] = (
199
+ torch.tensor(
200
+ [value_dict["crop_coords_top"], value_dict["crop_coords_left"]]
201
+ )
202
+ .to(device)
203
+ .repeat(*N, 1)
204
+ )
205
+ elif key == "aesthetic_score":
206
+ batch["aesthetic_score"] = (
207
+ torch.tensor([value_dict["aesthetic_score"]]).to(device).repeat(*N, 1)
208
+ )
209
+ batch_uc["aesthetic_score"] = (
210
+ torch.tensor([value_dict["negative_aesthetic_score"]])
211
+ .to(device)
212
+ .repeat(*N, 1)
213
+ )
214
+
215
+ elif key == "target_size_as_tuple":
216
+ batch["target_size_as_tuple"] = (
217
+ torch.tensor([value_dict["target_height"], value_dict["target_width"]])
218
+ .to(device)
219
+ .repeat(*N, 1)
220
+ )
221
+ else:
222
+ batch[key] = value_dict[key]
223
+
224
+ for key in batch.keys():
225
+ if key not in batch_uc and isinstance(batch[key], torch.Tensor):
226
+ batch_uc[key] = torch.clone(batch[key])
227
+ return batch, batch_uc
228
+
229
+
230
+ def get_input_image_tensor(image: Image.Image, device="cuda"):
231
+ w, h = image.size
232
+ print(f"loaded input image of size ({w}, {h})")
233
+ width, height = map(
234
+ lambda x: x - x % 64, (w, h)
235
+ ) # resize to integer multiple of 64
236
+ image = image.resize((width, height))
237
+ image_array = np.array(image.convert("RGB"))
238
+ image_array = image_array[None].transpose(0, 3, 1, 2)
239
+ image_tensor = torch.from_numpy(image_array).to(dtype=torch.float32) / 127.5 - 1.0
240
+ return image_tensor.to(device)
241
+
242
+
243
+ def do_img2img(
244
+ img,
245
+ model,
246
+ sampler,
247
+ value_dict,
248
+ num_samples,
249
+ force_uc_zero_embeddings=[],
250
+ additional_kwargs={},
251
+ offset_noise_level: float = 0.0,
252
+ return_latents=False,
253
+ skip_encode=False,
254
+ filter=None,
255
+ device="cuda",
256
+ ):
257
+ with torch.no_grad():
258
+ with autocast(device) as precision_scope:
259
+ with model.ema_scope():
260
+ batch, batch_uc = get_batch(
261
+ get_unique_embedder_keys_from_conditioner(model.conditioner),
262
+ value_dict,
263
+ [num_samples],
264
+ )
265
+ c, uc = model.conditioner.get_unconditional_conditioning(
266
+ batch,
267
+ batch_uc=batch_uc,
268
+ force_uc_zero_embeddings=force_uc_zero_embeddings,
269
+ )
270
+
271
+ for k in c:
272
+ c[k], uc[k] = map(lambda y: y[k][:num_samples].to(device), (c, uc))
273
+
274
+ for k in additional_kwargs:
275
+ c[k] = uc[k] = additional_kwargs[k]
276
+ if skip_encode:
277
+ z = img
278
+ else:
279
+ z = model.encode_first_stage(img)
280
+ noise = torch.randn_like(z)
281
+ sigmas = sampler.discretization(sampler.num_steps)
282
+ sigma = sigmas[0].to(z.device)
283
+
284
+ if offset_noise_level > 0.0:
285
+ noise = noise + offset_noise_level * append_dims(
286
+ torch.randn(z.shape[0], device=z.device), z.ndim
287
+ )
288
+ noised_z = z + noise * append_dims(sigma, z.ndim)
289
+ noised_z = noised_z / torch.sqrt(
290
+ 1.0 + sigmas[0] ** 2.0
291
+ ) # Note: hardcoded to DDPM-like scaling. need to generalize later.
292
+
293
+ def denoiser(x, sigma, c):
294
+ return model.denoiser(model.model, x, sigma, c)
295
+
296
+ samples_z = sampler(denoiser, noised_z, cond=c, uc=uc)
297
+ samples_x = model.decode_first_stage(samples_z)
298
+ samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0)
299
+
300
+ if filter is not None:
301
+ samples = filter(samples)
302
+
303
+ if return_latents:
304
+ return samples, samples_z
305
+ return samples
sgm/lr_scheduler.py ADDED
@@ -0,0 +1,135 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+
3
+
4
+ class LambdaWarmUpCosineScheduler:
5
+ """
6
+ note: use with a base_lr of 1.0
7
+ """
8
+
9
+ def __init__(
10
+ self,
11
+ warm_up_steps,
12
+ lr_min,
13
+ lr_max,
14
+ lr_start,
15
+ max_decay_steps,
16
+ verbosity_interval=0,
17
+ ):
18
+ self.lr_warm_up_steps = warm_up_steps
19
+ self.lr_start = lr_start
20
+ self.lr_min = lr_min
21
+ self.lr_max = lr_max
22
+ self.lr_max_decay_steps = max_decay_steps
23
+ self.last_lr = 0.0
24
+ self.verbosity_interval = verbosity_interval
25
+
26
+ def schedule(self, n, **kwargs):
27
+ if self.verbosity_interval > 0:
28
+ if n % self.verbosity_interval == 0:
29
+ print(f"current step: {n}, recent lr-multiplier: {self.last_lr}")
30
+ if n < self.lr_warm_up_steps:
31
+ lr = (
32
+ self.lr_max - self.lr_start
33
+ ) / self.lr_warm_up_steps * n + self.lr_start
34
+ self.last_lr = lr
35
+ return lr
36
+ else:
37
+ t = (n - self.lr_warm_up_steps) / (
38
+ self.lr_max_decay_steps - self.lr_warm_up_steps
39
+ )
40
+ t = min(t, 1.0)
41
+ lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * (
42
+ 1 + np.cos(t * np.pi)
43
+ )
44
+ self.last_lr = lr
45
+ return lr
46
+
47
+ def __call__(self, n, **kwargs):
48
+ return self.schedule(n, **kwargs)
49
+
50
+
51
+ class LambdaWarmUpCosineScheduler2:
52
+ """
53
+ supports repeated iterations, configurable via lists
54
+ note: use with a base_lr of 1.0.
55
+ """
56
+
57
+ def __init__(
58
+ self, warm_up_steps, f_min, f_max, f_start, cycle_lengths, verbosity_interval=0
59
+ ):
60
+ assert (
61
+ len(warm_up_steps)
62
+ == len(f_min)
63
+ == len(f_max)
64
+ == len(f_start)
65
+ == len(cycle_lengths)
66
+ )
67
+ self.lr_warm_up_steps = warm_up_steps
68
+ self.f_start = f_start
69
+ self.f_min = f_min
70
+ self.f_max = f_max
71
+ self.cycle_lengths = cycle_lengths
72
+ self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths))
73
+ self.last_f = 0.0
74
+ self.verbosity_interval = verbosity_interval
75
+
76
+ def find_in_interval(self, n):
77
+ interval = 0
78
+ for cl in self.cum_cycles[1:]:
79
+ if n <= cl:
80
+ return interval
81
+ interval += 1
82
+
83
+ def schedule(self, n, **kwargs):
84
+ cycle = self.find_in_interval(n)
85
+ n = n - self.cum_cycles[cycle]
86
+ if self.verbosity_interval > 0:
87
+ if n % self.verbosity_interval == 0:
88
+ print(
89
+ f"current step: {n}, recent lr-multiplier: {self.last_f}, "
90
+ f"current cycle {cycle}"
91
+ )
92
+ if n < self.lr_warm_up_steps[cycle]:
93
+ f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[
94
+ cycle
95
+ ] * n + self.f_start[cycle]
96
+ self.last_f = f
97
+ return f
98
+ else:
99
+ t = (n - self.lr_warm_up_steps[cycle]) / (
100
+ self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle]
101
+ )
102
+ t = min(t, 1.0)
103
+ f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * (
104
+ 1 + np.cos(t * np.pi)
105
+ )
106
+ self.last_f = f
107
+ return f
108
+
109
+ def __call__(self, n, **kwargs):
110
+ return self.schedule(n, **kwargs)
111
+
112
+
113
+ class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2):
114
+ def schedule(self, n, **kwargs):
115
+ cycle = self.find_in_interval(n)
116
+ n = n - self.cum_cycles[cycle]
117
+ if self.verbosity_interval > 0:
118
+ if n % self.verbosity_interval == 0:
119
+ print(
120
+ f"current step: {n}, recent lr-multiplier: {self.last_f}, "
121
+ f"current cycle {cycle}"
122
+ )
123
+
124
+ if n < self.lr_warm_up_steps[cycle]:
125
+ f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[
126
+ cycle
127
+ ] * n + self.f_start[cycle]
128
+ self.last_f = f
129
+ return f
130
+ else:
131
+ f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * (
132
+ self.cycle_lengths[cycle] - n
133
+ ) / (self.cycle_lengths[cycle])
134
+ self.last_f = f
135
+ return f
sgm/models/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ from .autoencoder import AutoencodingEngine
2
+ from .diffusion import DiffusionEngine
sgm/models/autoencoder.py ADDED
@@ -0,0 +1,335 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ from abc import abstractmethod
3
+ from contextlib import contextmanager
4
+ from typing import Any, Dict, Tuple, Union
5
+
6
+ import pytorch_lightning as pl
7
+ import torch
8
+ from omegaconf import ListConfig
9
+ from packaging import version
10
+ from safetensors.torch import load_file as load_safetensors
11
+
12
+ from ..modules.diffusionmodules.model import Decoder, Encoder
13
+ from ..modules.distributions.distributions import DiagonalGaussianDistribution
14
+ from ..modules.ema import LitEma
15
+ from ..util import default, get_obj_from_str, instantiate_from_config
16
+
17
+
18
+ class AbstractAutoencoder(pl.LightningModule):
19
+ """
20
+ This is the base class for all autoencoders, including image autoencoders, image autoencoders with discriminators,
21
+ unCLIP models, etc. Hence, it is fairly general, and specific features
22
+ (e.g. discriminator training, encoding, decoding) must be implemented in subclasses.
23
+ """
24
+
25
+ def __init__(
26
+ self,
27
+ ema_decay: Union[None, float] = None,
28
+ monitor: Union[None, str] = None,
29
+ input_key: str = "jpg",
30
+ ckpt_path: Union[None, str] = None,
31
+ ignore_keys: Union[Tuple, list, ListConfig] = (),
32
+ ):
33
+ super().__init__()
34
+ self.input_key = input_key
35
+ self.use_ema = ema_decay is not None
36
+ if monitor is not None:
37
+ self.monitor = monitor
38
+
39
+ if self.use_ema:
40
+ self.model_ema = LitEma(self, decay=ema_decay)
41
+ print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
42
+
43
+ if ckpt_path is not None:
44
+ self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
45
+
46
+ if version.parse(torch.__version__) >= version.parse("2.0.0"):
47
+ self.automatic_optimization = False
48
+
49
+ def init_from_ckpt(
50
+ self, path: str, ignore_keys: Union[Tuple, list, ListConfig] = tuple()
51
+ ) -> None:
52
+ if path.endswith("ckpt"):
53
+ sd = torch.load(path, map_location="cpu")["state_dict"]
54
+ elif path.endswith("safetensors"):
55
+ sd = load_safetensors(path)
56
+ else:
57
+ raise NotImplementedError
58
+
59
+ keys = list(sd.keys())
60
+ for k in keys:
61
+ for ik in ignore_keys:
62
+ if re.match(ik, k):
63
+ print("Deleting key {} from state_dict.".format(k))
64
+ del sd[k]
65
+ missing, unexpected = self.load_state_dict(sd, strict=False)
66
+ print(
67
+ f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys"
68
+ )
69
+ if len(missing) > 0:
70
+ print(f"Missing Keys: {missing}")
71
+ if len(unexpected) > 0:
72
+ print(f"Unexpected Keys: {unexpected}")
73
+
74
+ @abstractmethod
75
+ def get_input(self, batch) -> Any:
76
+ raise NotImplementedError()
77
+
78
+ def on_train_batch_end(self, *args, **kwargs):
79
+ # for EMA computation
80
+ if self.use_ema:
81
+ self.model_ema(self)
82
+
83
+ @contextmanager
84
+ def ema_scope(self, context=None):
85
+ if self.use_ema:
86
+ self.model_ema.store(self.parameters())
87
+ self.model_ema.copy_to(self)
88
+ if context is not None:
89
+ print(f"{context}: Switched to EMA weights")
90
+ try:
91
+ yield None
92
+ finally:
93
+ if self.use_ema:
94
+ self.model_ema.restore(self.parameters())
95
+ if context is not None:
96
+ print(f"{context}: Restored training weights")
97
+
98
+ @abstractmethod
99
+ def encode(self, *args, **kwargs) -> torch.Tensor:
100
+ raise NotImplementedError("encode()-method of abstract base class called")
101
+
102
+ @abstractmethod
103
+ def decode(self, *args, **kwargs) -> torch.Tensor:
104
+ raise NotImplementedError("decode()-method of abstract base class called")
105
+
106
+ def instantiate_optimizer_from_config(self, params, lr, cfg):
107
+ print(f"loading >>> {cfg['target']} <<< optimizer from config")
108
+ return get_obj_from_str(cfg["target"])(
109
+ params, lr=lr, **cfg.get("params", dict())
110
+ )
111
+
112
+ def configure_optimizers(self) -> Any:
113
+ raise NotImplementedError()
114
+
115
+
116
+ class AutoencodingEngine(AbstractAutoencoder):
117
+ """
118
+ Base class for all image autoencoders that we train, like VQGAN or AutoencoderKL
119
+ (we also restore them explicitly as special cases for legacy reasons).
120
+ Regularizations such as KL or VQ are moved to the regularizer class.
121
+ """
122
+
123
+ def __init__(
124
+ self,
125
+ *args,
126
+ encoder_config: Dict,
127
+ decoder_config: Dict,
128
+ loss_config: Dict,
129
+ regularizer_config: Dict,
130
+ optimizer_config: Union[Dict, None] = None,
131
+ lr_g_factor: float = 1.0,
132
+ **kwargs,
133
+ ):
134
+ super().__init__(*args, **kwargs)
135
+ # todo: add options to freeze encoder/decoder
136
+ self.encoder = instantiate_from_config(encoder_config)
137
+ self.decoder = instantiate_from_config(decoder_config)
138
+ self.loss = instantiate_from_config(loss_config)
139
+ self.regularization = instantiate_from_config(regularizer_config)
140
+ self.optimizer_config = default(
141
+ optimizer_config, {"target": "torch.optim.Adam"}
142
+ )
143
+ self.lr_g_factor = lr_g_factor
144
+
145
+ def get_input(self, batch: Dict) -> torch.Tensor:
146
+ # assuming unified data format, dataloader returns a dict.
147
+ # image tensors should be scaled to -1 ... 1 and in channels-first format (e.g., bchw instead if bhwc)
148
+ return batch[self.input_key]
149
+
150
+ def get_autoencoder_params(self) -> list:
151
+ params = (
152
+ list(self.encoder.parameters())
153
+ + list(self.decoder.parameters())
154
+ + list(self.regularization.get_trainable_parameters())
155
+ + list(self.loss.get_trainable_autoencoder_parameters())
156
+ )
157
+ return params
158
+
159
+ def get_discriminator_params(self) -> list:
160
+ params = list(self.loss.get_trainable_parameters()) # e.g., discriminator
161
+ return params
162
+
163
+ def get_last_layer(self):
164
+ return self.decoder.get_last_layer()
165
+
166
+ def encode(self, x: Any, return_reg_log: bool = False) -> Any:
167
+ z = self.encoder(x)
168
+ z, reg_log = self.regularization(z)
169
+ if return_reg_log:
170
+ return z, reg_log
171
+ return z
172
+
173
+ def decode(self, z: Any) -> torch.Tensor:
174
+ x = self.decoder(z)
175
+ return x
176
+
177
+ def forward(self, x: Any) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
178
+ z, reg_log = self.encode(x, return_reg_log=True)
179
+ dec = self.decode(z)
180
+ return z, dec, reg_log
181
+
182
+ def training_step(self, batch, batch_idx, optimizer_idx) -> Any:
183
+ x = self.get_input(batch)
184
+ z, xrec, regularization_log = self(x)
185
+
186
+ if optimizer_idx == 0:
187
+ # autoencode
188
+ aeloss, log_dict_ae = self.loss(
189
+ regularization_log,
190
+ x,
191
+ xrec,
192
+ optimizer_idx,
193
+ self.global_step,
194
+ last_layer=self.get_last_layer(),
195
+ split="train",
196
+ )
197
+
198
+ self.log_dict(
199
+ log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True
200
+ )
201
+ return aeloss
202
+
203
+ if optimizer_idx == 1:
204
+ # discriminator
205
+ discloss, log_dict_disc = self.loss(
206
+ regularization_log,
207
+ x,
208
+ xrec,
209
+ optimizer_idx,
210
+ self.global_step,
211
+ last_layer=self.get_last_layer(),
212
+ split="train",
213
+ )
214
+ self.log_dict(
215
+ log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True
216
+ )
217
+ return discloss
218
+
219
+ def validation_step(self, batch, batch_idx) -> Dict:
220
+ log_dict = self._validation_step(batch, batch_idx)
221
+ with self.ema_scope():
222
+ log_dict_ema = self._validation_step(batch, batch_idx, postfix="_ema")
223
+ log_dict.update(log_dict_ema)
224
+ return log_dict
225
+
226
+ def _validation_step(self, batch, batch_idx, postfix="") -> Dict:
227
+ x = self.get_input(batch)
228
+
229
+ z, xrec, regularization_log = self(x)
230
+ aeloss, log_dict_ae = self.loss(
231
+ regularization_log,
232
+ x,
233
+ xrec,
234
+ 0,
235
+ self.global_step,
236
+ last_layer=self.get_last_layer(),
237
+ split="val" + postfix,
238
+ )
239
+
240
+ discloss, log_dict_disc = self.loss(
241
+ regularization_log,
242
+ x,
243
+ xrec,
244
+ 1,
245
+ self.global_step,
246
+ last_layer=self.get_last_layer(),
247
+ split="val" + postfix,
248
+ )
249
+ self.log(f"val{postfix}/rec_loss", log_dict_ae[f"val{postfix}/rec_loss"])
250
+ log_dict_ae.update(log_dict_disc)
251
+ self.log_dict(log_dict_ae)
252
+ return log_dict_ae
253
+
254
+ def configure_optimizers(self) -> Any:
255
+ ae_params = self.get_autoencoder_params()
256
+ disc_params = self.get_discriminator_params()
257
+
258
+ opt_ae = self.instantiate_optimizer_from_config(
259
+ ae_params,
260
+ default(self.lr_g_factor, 1.0) * self.learning_rate,
261
+ self.optimizer_config,
262
+ )
263
+ opt_disc = self.instantiate_optimizer_from_config(
264
+ disc_params, self.learning_rate, self.optimizer_config
265
+ )
266
+
267
+ return [opt_ae, opt_disc], []
268
+
269
+ @torch.no_grad()
270
+ def log_images(self, batch: Dict, **kwargs) -> Dict:
271
+ log = dict()
272
+ x = self.get_input(batch)
273
+ _, xrec, _ = self(x)
274
+ log["inputs"] = x
275
+ log["reconstructions"] = xrec
276
+ with self.ema_scope():
277
+ _, xrec_ema, _ = self(x)
278
+ log["reconstructions_ema"] = xrec_ema
279
+ return log
280
+
281
+
282
+ class AutoencoderKL(AutoencodingEngine):
283
+ def __init__(self, embed_dim: int, **kwargs):
284
+ ddconfig = kwargs.pop("ddconfig")
285
+ ckpt_path = kwargs.pop("ckpt_path", None)
286
+ ignore_keys = kwargs.pop("ignore_keys", ())
287
+ super().__init__(
288
+ encoder_config={"target": "torch.nn.Identity"},
289
+ decoder_config={"target": "torch.nn.Identity"},
290
+ regularizer_config={"target": "torch.nn.Identity"},
291
+ loss_config=kwargs.pop("lossconfig"),
292
+ **kwargs,
293
+ )
294
+ assert ddconfig["double_z"]
295
+ self.encoder = Encoder(**ddconfig)
296
+ self.decoder = Decoder(**ddconfig)
297
+ self.quant_conv = torch.nn.Conv2d(2 * ddconfig["z_channels"], 2 * embed_dim, 1)
298
+ self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
299
+ self.embed_dim = embed_dim
300
+
301
+ if ckpt_path is not None:
302
+ self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
303
+
304
+ def encode(self, x):
305
+ assert (
306
+ not self.training
307
+ ), f"{self.__class__.__name__} only supports inference currently"
308
+ h = self.encoder(x)
309
+ moments = self.quant_conv(h)
310
+ posterior = DiagonalGaussianDistribution(moments)
311
+ return posterior
312
+
313
+ def decode(self, z, **decoder_kwargs):
314
+ z = self.post_quant_conv(z)
315
+ dec = self.decoder(z, **decoder_kwargs)
316
+ return dec
317
+
318
+
319
+ class AutoencoderKLInferenceWrapper(AutoencoderKL):
320
+ def encode(self, x):
321
+ return super().encode(x).sample()
322
+
323
+
324
+ class IdentityFirstStage(AbstractAutoencoder):
325
+ def __init__(self, *args, **kwargs):
326
+ super().__init__(*args, **kwargs)
327
+
328
+ def get_input(self, x: Any) -> Any:
329
+ return x
330
+
331
+ def encode(self, x: Any, *args, **kwargs) -> Any:
332
+ return x
333
+
334
+ def decode(self, x: Any, *args, **kwargs) -> Any:
335
+ return x
sgm/models/diffusion.py ADDED
@@ -0,0 +1,320 @@
1
+ from contextlib import contextmanager
2
+ from typing import Any, Dict, List, Tuple, Union
3
+
4
+ import pytorch_lightning as pl
5
+ import torch
6
+ from omegaconf import ListConfig, OmegaConf
7
+ from safetensors.torch import load_file as load_safetensors
8
+ from torch.optim.lr_scheduler import LambdaLR
9
+
10
+ from ..modules import UNCONDITIONAL_CONFIG
11
+ from ..modules.diffusionmodules.wrappers import OPENAIUNETWRAPPER
12
+ from ..modules.ema import LitEma
13
+ from ..util import (
14
+ default,
15
+ disabled_train,
16
+ get_obj_from_str,
17
+ instantiate_from_config,
18
+ log_txt_as_img,
19
+ )
20
+
21
+
22
+ class DiffusionEngine(pl.LightningModule):
23
+ def __init__(
24
+ self,
25
+ network_config,
26
+ denoiser_config,
27
+ first_stage_config,
28
+ conditioner_config: Union[None, Dict, ListConfig, OmegaConf] = None,
29
+ sampler_config: Union[None, Dict, ListConfig, OmegaConf] = None,
30
+ optimizer_config: Union[None, Dict, ListConfig, OmegaConf] = None,
31
+ scheduler_config: Union[None, Dict, ListConfig, OmegaConf] = None,
32
+ loss_fn_config: Union[None, Dict, ListConfig, OmegaConf] = None,
33
+ network_wrapper: Union[None, str] = None,
34
+ ckpt_path: Union[None, str] = None,
35
+ use_ema: bool = False,
36
+ ema_decay_rate: float = 0.9999,
37
+ scale_factor: float = 1.0,
38
+ disable_first_stage_autocast=False,
39
+ input_key: str = "jpg",
40
+ log_keys: Union[List, None] = None,
41
+ no_cond_log: bool = False,
42
+ compile_model: bool = False,
43
+ ):
44
+ super().__init__()
45
+ self.log_keys = log_keys
46
+ self.input_key = input_key
47
+ self.optimizer_config = default(
48
+ optimizer_config, {"target": "torch.optim.AdamW"}
49
+ )
50
+ model = instantiate_from_config(network_config)
51
+ self.model = get_obj_from_str(default(network_wrapper, OPENAIUNETWRAPPER))(
52
+ model, compile_model=compile_model
53
+ )
54
+
55
+ self.denoiser = instantiate_from_config(denoiser_config)
56
+ self.sampler = (
57
+ instantiate_from_config(sampler_config)
58
+ if sampler_config is not None
59
+ else None
60
+ )
61
+ self.conditioner = instantiate_from_config(
62
+ default(conditioner_config, UNCONDITIONAL_CONFIG)
63
+ )
64
+ self.scheduler_config = scheduler_config
65
+ self._init_first_stage(first_stage_config)
66
+
67
+ self.loss_fn = (
68
+ instantiate_from_config(loss_fn_config)
69
+ if loss_fn_config is not None
70
+ else None
71
+ )
72
+
73
+ self.use_ema = use_ema
74
+ if self.use_ema:
75
+ self.model_ema = LitEma(self.model, decay=ema_decay_rate)
76
+ print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
77
+
78
+ self.scale_factor = scale_factor
79
+ self.disable_first_stage_autocast = disable_first_stage_autocast
80
+ self.no_cond_log = no_cond_log
81
+
82
+ if ckpt_path is not None:
83
+ self.init_from_ckpt(ckpt_path)
84
+
85
+ def init_from_ckpt(
86
+ self,
87
+ path: str,
88
+ ) -> None:
89
+ if path.endswith("ckpt"):
90
+ sd = torch.load(path, map_location="cpu")["state_dict"]
91
+ elif path.endswith("safetensors"):
92
+ sd = load_safetensors(path)
93
+ else:
94
+ raise NotImplementedError
95
+
96
+ missing, unexpected = self.load_state_dict(sd, strict=False)
97
+ print(
98
+ f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys"
99
+ )
100
+ if len(missing) > 0:
101
+ print(f"Missing Keys: {missing}")
102
+ if len(unexpected) > 0:
103
+ print(f"Unexpected Keys: {unexpected}")
104
+
105
+ def _init_first_stage(self, config):
106
+ model = instantiate_from_config(config).eval()
107
+ model.train = disabled_train
108
+ for param in model.parameters():
109
+ param.requires_grad = False
110
+ self.first_stage_model = model
111
+
112
+ def get_input(self, batch):
113
+ # assuming unified data format, dataloader returns a dict.
114
+ # image tensors should be scaled to -1 ... 1 and in bchw format
115
+ return batch[self.input_key]
116
+
117
+ @torch.no_grad()
118
+ def decode_first_stage(self, z):
119
+ z = 1.0 / self.scale_factor * z
120
+ with torch.autocast("cuda", enabled=not self.disable_first_stage_autocast):
121
+ out = self.first_stage_model.decode(z)
122
+ return out
123
+
124
+ @torch.no_grad()
125
+ def encode_first_stage(self, x):
126
+ with torch.autocast("cuda", enabled=not self.disable_first_stage_autocast):
127
+ z = self.first_stage_model.encode(x)
128
+ z = self.scale_factor * z
129
+ return z
130
+
131
+ def forward(self, x, batch):
132
+ loss = self.loss_fn(self.model, self.denoiser, self.conditioner, x, batch)
133
+ loss_mean = loss.mean()
134
+ loss_dict = {"loss": loss_mean}
135
+ return loss_mean, loss_dict
136
+
137
+ def shared_step(self, batch: Dict) -> Any:
138
+ x = self.get_input(batch)
139
+ x = self.encode_first_stage(x)
140
+ batch["global_step"] = self.global_step
141
+ loss, loss_dict = self(x, batch)
142
+ return loss, loss_dict
143
+
144
+ def training_step(self, batch, batch_idx):
145
+ loss, loss_dict = self.shared_step(batch)
146
+
147
+ self.log_dict(
148
+ loss_dict, prog_bar=True, logger=True, on_step=True, on_epoch=False
149
+ )
150
+
151
+ self.log(
152
+ "global_step",
153
+ self.global_step,
154
+ prog_bar=True,
155
+ logger=True,
156
+ on_step=True,
157
+ on_epoch=False,
158
+ )
159
+
160
+ if self.scheduler_config is not None:
161
+ lr = self.optimizers().param_groups[0]["lr"]
162
+ self.log(
163
+ "lr_abs", lr, prog_bar=True, logger=True, on_step=True, on_epoch=False
164
+ )
165
+
166
+ return loss
167
+
168
+ def on_train_start(self, *args, **kwargs):
169
+ if self.sampler is None or self.loss_fn is None:
170
+ raise ValueError("Sampler and loss function need to be set for training.")
171
+
172
+ def on_train_batch_end(self, *args, **kwargs):
173
+ if self.use_ema:
174
+ self.model_ema(self.model)
175
+
176
+ @contextmanager
177
+ def ema_scope(self, context=None):
178
+ if self.use_ema:
179
+ self.model_ema.store(self.model.parameters())
180
+ self.model_ema.copy_to(self.model)
181
+ if context is not None:
182
+ print(f"{context}: Switched to EMA weights")
183
+ try:
184
+ yield None
185
+ finally:
186
+ if self.use_ema:
187
+ self.model_ema.restore(self.model.parameters())
188
+ if context is not None:
189
+ print(f"{context}: Restored training weights")
190
+
191
+ def instantiate_optimizer_from_config(self, params, lr, cfg):
192
+ return get_obj_from_str(cfg["target"])(
193
+ params, lr=lr, **cfg.get("params", dict())
194
+ )
195
+
196
+ def configure_optimizers(self):
197
+ lr = self.learning_rate
198
+ params = list(self.model.parameters())
199
+ for embedder in self.conditioner.embedders:
200
+ if embedder.is_trainable:
201
+ params = params + list(embedder.parameters())
202
+ opt = self.instantiate_optimizer_from_config(params, lr, self.optimizer_config)
203
+ if self.scheduler_config is not None:
204
+ scheduler = instantiate_from_config(self.scheduler_config)
205
+ print("Setting up LambdaLR scheduler...")
206
+ scheduler = [
207
+ {
208
+ "scheduler": LambdaLR(opt, lr_lambda=scheduler.schedule),
209
+ "interval": "step",
210
+ "frequency": 1,
211
+ }
212
+ ]
213
+ return [opt], scheduler
214
+ return opt
215
+
216
+ @torch.no_grad()
217
+ def sample(
218
+ self,
219
+ cond: Dict,
220
+ uc: Union[Dict, None] = None,
221
+ batch_size: int = 16,
222
+ shape: Union[None, Tuple, List] = None,
223
+ **kwargs,
224
+ ):
225
+ randn = torch.randn(batch_size, *shape).to(self.device)
226
+
227
+ denoiser = lambda input, sigma, c: self.denoiser(
228
+ self.model, input, sigma, c, **kwargs
229
+ )
230
+ samples = self.sampler(denoiser, randn, cond, uc=uc)
231
+ return samples
232
+
233
+ @torch.no_grad()
234
+ def log_conditionings(self, batch: Dict, n: int) -> Dict:
235
+ """
236
+ Defines heuristics to log different conditionings.
237
+ These can be lists of strings (text-to-image), tensors, ints, ...
238
+ """
239
+ image_h, image_w = batch[self.input_key].shape[2:]
240
+ log = dict()
241
+
242
+ for embedder in self.conditioner.embedders:
243
+ if (
244
+ (self.log_keys is None) or (embedder.input_key in self.log_keys)
245
+ ) and not self.no_cond_log:
246
+ x = batch[embedder.input_key][:n]
247
+ if isinstance(x, torch.Tensor):
248
+ if x.dim() == 1:
249
+ # class-conditional, convert integer to string
250
+ x = [str(x[i].item()) for i in range(x.shape[0])]
251
+ xc = log_txt_as_img((image_h, image_w), x, size=image_h // 4)
252
+ elif x.dim() == 2:
253
+ # size and crop cond and the like
254
+ x = [
255
+ "x".join([str(xx) for xx in x[i].tolist()])
256
+ for i in range(x.shape[0])
257
+ ]
258
+ xc = log_txt_as_img((image_h, image_w), x, size=image_h // 20)
259
+ else:
260
+ raise NotImplementedError()
261
+ elif isinstance(x, (List, ListConfig)):
262
+ if isinstance(x[0], str):
263
+ # strings
264
+ xc = log_txt_as_img((image_h, image_w), x, size=image_h // 20)
265
+ else:
266
+ raise NotImplementedError()
267
+ else:
268
+ raise NotImplementedError()
269
+ log[embedder.input_key] = xc
270
+ return log
271
+
272
+ @torch.no_grad()
273
+ def log_images(
274
+ self,
275
+ batch: Dict,
276
+ N: int = 8,
277
+ sample: bool = True,
278
+ ucg_keys: List[str] = None,
279
+ **kwargs,
280
+ ) -> Dict:
281
+ conditioner_input_keys = [e.input_key for e in self.conditioner.embedders]
282
+ if ucg_keys:
283
+ assert all(map(lambda x: x in conditioner_input_keys, ucg_keys)), (
284
+ "Each defined ucg key for sampling must be in the provided conditioner input keys,"
285
+ f"but we have {ucg_keys} vs. {conditioner_input_keys}"
286
+ )
287
+ else:
288
+ ucg_keys = conditioner_input_keys
289
+ log = dict()
290
+
291
+ x = self.get_input(batch)
292
+
293
+ c, uc = self.conditioner.get_unconditional_conditioning(
294
+ batch,
295
+ force_uc_zero_embeddings=ucg_keys
296
+ if len(self.conditioner.embedders) > 0
297
+ else [],
298
+ )
299
+
300
+ sampling_kwargs = {}
301
+
302
+ N = min(x.shape[0], N)
303
+ x = x.to(self.device)[:N]
304
+ log["inputs"] = x
305
+ z = self.encode_first_stage(x)
306
+ log["reconstructions"] = self.decode_first_stage(z)
307
+ log.update(self.log_conditionings(batch, N))
308
+
309
+ for k in c:
310
+ if isinstance(c[k], torch.Tensor):
311
+ c[k], uc[k] = map(lambda y: y[k][:N].to(self.device), (c, uc))
312
+
313
+ if sample:
314
+ with self.ema_scope("Plotting"):
315
+ samples = self.sample(
316
+ c, shape=z.shape[1:], uc=uc, batch_size=N, **sampling_kwargs
317
+ )
318
+ samples = self.decode_first_stage(samples)
319
+ log["samples"] = samples
320
+ return log
sgm/modules/__init__.py ADDED
@@ -0,0 +1,6 @@
1
+ from .encoders.modules import GeneralConditioner
2
+
3
+ UNCONDITIONAL_CONFIG = {
4
+ "target": "sgm.modules.GeneralConditioner",
5
+ "params": {"emb_models": []},
6
+ }
sgm/modules/attention.py ADDED
@@ -0,0 +1,633 @@
1
+ import math
2
+ from inspect import isfunction
3
+ from typing import Any, Optional
4
+
5
+ import torch
6
+ import torch.nn.functional as F
7
+ from einops import rearrange, repeat
8
+ from packaging import version
9
+ from torch import nn
10
+
11
+ if version.parse(torch.__version__) >= version.parse("2.0.0"):
12
+ SDP_IS_AVAILABLE = True
13
+ from torch.backends.cuda import SDPBackend, sdp_kernel
14
+
15
+ BACKEND_MAP = {
16
+ SDPBackend.MATH: {
17
+ "enable_math": True,
18
+ "enable_flash": False,
19
+ "enable_mem_efficient": False,
20
+ },
21
+ SDPBackend.FLASH_ATTENTION: {
22
+ "enable_math": False,
23
+ "enable_flash": True,
24
+ "enable_mem_efficient": False,
25
+ },
26
+ SDPBackend.EFFICIENT_ATTENTION: {
27
+ "enable_math": False,
28
+ "enable_flash": False,
29
+ "enable_mem_efficient": True,
30
+ },
31
+ None: {"enable_math": True, "enable_flash": True, "enable_mem_efficient": True},
32
+ }
33
+ else:
34
+ from contextlib import nullcontext
35
+
36
+ SDP_IS_AVAILABLE = False
37
+ sdp_kernel = nullcontext
38
+ BACKEND_MAP = {}
39
+ print(
40
+ f"No SDP backend available, likely because you are running in pytorch versions < 2.0. In fact, "
41
+ f"you are using PyTorch {torch.__version__}. You might want to consider upgrading."
42
+ )
43
+
44
+ try:
45
+ import xformers
46
+ import xformers.ops
47
+
48
+ XFORMERS_IS_AVAILABLE = True
49
+ except:
50
+ XFORMERS_IS_AVAILABLE = False
51
+ print("no module 'xformers'. Processing without...")
52
+
53
+ from .diffusionmodules.util import checkpoint
54
+
55
+
56
+ def exists(val):
57
+ return val is not None
58
+
59
+
60
+ def uniq(arr):
61
+ return {el: True for el in arr}.keys()
62
+
63
+
64
+ def default(val, d):
65
+ if exists(val):
66
+ return val
67
+ return d() if isfunction(d) else d
68
+
69
+
70
+ def max_neg_value(t):
71
+ return -torch.finfo(t.dtype).max
72
+
73
+
74
+ def init_(tensor):
75
+ dim = tensor.shape[-1]
76
+ std = 1 / math.sqrt(dim)
77
+ tensor.uniform_(-std, std)
78
+ return tensor
79
+
80
+
81
+ # feedforward
82
+ class GEGLU(nn.Module):
83
+ def __init__(self, dim_in, dim_out):
84
+ super().__init__()
85
+ self.proj = nn.Linear(dim_in, dim_out * 2)
86
+
87
+ def forward(self, x):
88
+ x, gate = self.proj(x).chunk(2, dim=-1)
89
+ return x * F.gelu(gate)
90
+
91
+
92
+ class FeedForward(nn.Module):
93
+ def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0):
94
+ super().__init__()
95
+ inner_dim = int(dim * mult)
96
+ dim_out = default(dim_out, dim)
97
+ project_in = (
98
+ nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU())
99
+ if not glu
100
+ else GEGLU(dim, inner_dim)
101
+ )
102
+
103
+ self.net = nn.Sequential(
104
+ project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out)
105
+ )
106
+
107
+ def forward(self, x):
108
+ return self.net(x)
109
+
110
+
111
+ def zero_module(module):
112
+ """
113
+ Zero out the parameters of a module and return it.
114
+ """
115
+ for p in module.parameters():
116
+ p.detach().zero_()
117
+ return module
118
+
119
+
120
+ def Normalize(in_channels):
121
+ return torch.nn.GroupNorm(
122
+ num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
123
+ )
124
+
125
+
126
+ class LinearAttention(nn.Module):
127
+ def __init__(self, dim, heads=4, dim_head=32):
128
+ super().__init__()
129
+ self.heads = heads
130
+ hidden_dim = dim_head * heads
131
+ self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
132
+ self.to_out = nn.Conv2d(hidden_dim, dim, 1)
133
+
134
+ def forward(self, x):
135
+ b, c, h, w = x.shape
136
+ qkv = self.to_qkv(x)
137
+ q, k, v = rearrange(
138
+ qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=self.heads, qkv=3
139
+ )
140
+ k = k.softmax(dim=-1)
141
+ context = torch.einsum("bhdn,bhen->bhde", k, v)
142
+ out = torch.einsum("bhde,bhdn->bhen", context, q)
143
+ out = rearrange(
144
+ out, "b heads c (h w) -> b (heads c) h w", heads=self.heads, h=h, w=w
145
+ )
146
+ return self.to_out(out)
147
+
148
+
149
+ class SpatialSelfAttention(nn.Module):
150
+ def __init__(self, in_channels):
151
+ super().__init__()
152
+ self.in_channels = in_channels
153
+
154
+ self.norm = Normalize(in_channels)
155
+ self.q = torch.nn.Conv2d(
156
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
157
+ )
158
+ self.k = torch.nn.Conv2d(
159
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
160
+ )
161
+ self.v = torch.nn.Conv2d(
162
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
163
+ )
164
+ self.proj_out = torch.nn.Conv2d(
165
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
166
+ )
167
+
168
+ def forward(self, x):
169
+ h_ = x
170
+ h_ = self.norm(h_)
171
+ q = self.q(h_)
172
+ k = self.k(h_)
173
+ v = self.v(h_)
174
+
175
+ # compute attention
176
+ b, c, h, w = q.shape
177
+ q = rearrange(q, "b c h w -> b (h w) c")
178
+ k = rearrange(k, "b c h w -> b c (h w)")
179
+ w_ = torch.einsum("bij,bjk->bik", q, k)
180
+
181
+ w_ = w_ * (int(c) ** (-0.5))
182
+ w_ = torch.nn.functional.softmax(w_, dim=2)
183
+
184
+ # attend to values
185
+ v = rearrange(v, "b c h w -> b c (h w)")
186
+ w_ = rearrange(w_, "b i j -> b j i")
187
+ h_ = torch.einsum("bij,bjk->bik", v, w_)
188
+ h_ = rearrange(h_, "b c (h w) -> b c h w", h=h)
189
+ h_ = self.proj_out(h_)
190
+
191
+ return x + h_
192
+
193
+
194
+ class CrossAttention(nn.Module):
195
+ def __init__(
196
+ self,
197
+ query_dim,
198
+ context_dim=None,
199
+ heads=8,
200
+ dim_head=64,
201
+ dropout=0.0,
202
+ backend=None,
203
+ ):
204
+ super().__init__()
205
+ inner_dim = dim_head * heads
206
+ context_dim = default(context_dim, query_dim)
207
+
208
+ self.scale = dim_head**-0.5
209
+ self.heads = heads
210
+
211
+ self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
212
+ self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
213
+ self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
214
+
215
+ self.to_out = nn.Sequential(
216
+ nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)
217
+ )
218
+ self.backend = backend
219
+
220
+ def forward(
221
+ self,
222
+ x,
223
+ context=None,
224
+ mask=None,
225
+ additional_tokens=None,
226
+ n_times_crossframe_attn_in_self=0,
227
+ ):
228
+ h = self.heads
229
+
230
+ if additional_tokens is not None:
231
+ # get the number of masked tokens at the beginning of the output sequence
232
+ n_tokens_to_mask = additional_tokens.shape[1]
233
+ # add additional token
234
+ x = torch.cat([additional_tokens, x], dim=1)
235
+
236
+ q = self.to_q(x)
237
+ context = default(context, x)
238
+ k = self.to_k(context)
239
+ v = self.to_v(context)
240
+
241
+ if n_times_crossframe_attn_in_self:
242
+ # reprogramming cross-frame attention as in https://arxiv.org/abs/2303.13439
243
+ assert x.shape[0] % n_times_crossframe_attn_in_self == 0
244
+ n_cp = x.shape[0] // n_times_crossframe_attn_in_self
245
+ k = repeat(
246
+ k[::n_times_crossframe_attn_in_self], "b ... -> (b n) ...", n=n_cp
247
+ )
248
+ v = repeat(
249
+ v[::n_times_crossframe_attn_in_self], "b ... -> (b n) ...", n=n_cp
250
+ )
251
+
252
+ q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v))
253
+
254
+ ## old
255
+ """
256
+ sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
257
+ del q, k
258
+
259
+ if exists(mask):
260
+ mask = rearrange(mask, 'b ... -> b (...)')
261
+ max_neg_value = -torch.finfo(sim.dtype).max
262
+ mask = repeat(mask, 'b j -> (b h) () j', h=h)
263
+ sim.masked_fill_(~mask, max_neg_value)
264
+
265
+ # attention, what we cannot get enough of
266
+ sim = sim.softmax(dim=-1)
267
+
268
+ out = einsum('b i j, b j d -> b i d', sim, v)
269
+ """
270
+ ## new
271
+ with sdp_kernel(**BACKEND_MAP[self.backend]):
272
+ # print("dispatching into backend", self.backend, "q/k/v shape: ", q.shape, k.shape, v.shape)
273
+ out = F.scaled_dot_product_attention(
274
+ q, k, v, attn_mask=mask
275
+ ) # scale is dim_head ** -0.5 per default
276
+
277
+ del q, k, v
278
+ out = rearrange(out, "b h n d -> b n (h d)", h=h)
279
+
280
+ if additional_tokens is not None:
281
+ # remove additional token
282
+ out = out[:, n_tokens_to_mask:]
283
+ return self.to_out(out)
284
+
285
+
286
+ class MemoryEfficientCrossAttention(nn.Module):
287
+ # https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
288
+ def __init__(
289
+ self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0, **kwargs
290
+ ):
291
+ super().__init__()
292
+ print(
293
+ f"Setting up {self.__class__.__name__}. Query dim is {query_dim}, context_dim is {context_dim} and using "
294
+ f"{heads} heads with a dimension of {dim_head}."
295
+ )
296
+ inner_dim = dim_head * heads
297
+ context_dim = default(context_dim, query_dim)
298
+
299
+ self.heads = heads
300
+ self.dim_head = dim_head
301
+
302
+ self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
303
+ self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
304
+ self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
305
+
306
+ self.to_out = nn.Sequential(
307
+ nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)
308
+ )
309
+ self.attention_op: Optional[Any] = None
310
+
311
+ def forward(
312
+ self,
313
+ x,
314
+ context=None,
315
+ mask=None,
316
+ additional_tokens=None,
317
+ n_times_crossframe_attn_in_self=0,
318
+ ):
319
+ if additional_tokens is not None:
320
+ # get the number of masked tokens at the beginning of the output sequence
321
+ n_tokens_to_mask = additional_tokens.shape[1]
322
+ # add additional token
323
+ x = torch.cat([additional_tokens, x], dim=1)
324
+ q = self.to_q(x)
325
+ context = default(context, x)
326
+ k = self.to_k(context)
327
+ v = self.to_v(context)
328
+
329
+ if n_times_crossframe_attn_in_self:
330
+ # reprogramming cross-frame attention as in https://arxiv.org/abs/2303.13439
331
+ assert x.shape[0] % n_times_crossframe_attn_in_self == 0
332
+ # n_cp = x.shape[0]//n_times_crossframe_attn_in_self
333
+ k = repeat(
334
+ k[::n_times_crossframe_attn_in_self],
335
+ "b ... -> (b n) ...",
336
+ n=n_times_crossframe_attn_in_self,
337
+ )
338
+ v = repeat(
339
+ v[::n_times_crossframe_attn_in_self],
340
+ "b ... -> (b n) ...",
341
+ n=n_times_crossframe_attn_in_self,
342
+ )
343
+
344
+ b, _, _ = q.shape
345
+ q, k, v = map(
346
+ lambda t: t.unsqueeze(3)
347
+ .reshape(b, t.shape[1], self.heads, self.dim_head)
348
+ .permute(0, 2, 1, 3)
349
+ .reshape(b * self.heads, t.shape[1], self.dim_head)
350
+ .contiguous(),
351
+ (q, k, v),
352
+ )
353
+
354
+ # actually compute the attention, what we cannot get enough of
355
+ out = xformers.ops.memory_efficient_attention(
356
+ q, k, v, attn_bias=None, op=self.attention_op
357
+ )
358
+
359
+ # TODO: Use this directly in the attention operation, as a bias
360
+ if exists(mask):
361
+ raise NotImplementedError
362
+ out = (
363
+ out.unsqueeze(0)
364
+ .reshape(b, self.heads, out.shape[1], self.dim_head)
365
+ .permute(0, 2, 1, 3)
366
+ .reshape(b, out.shape[1], self.heads * self.dim_head)
367
+ )
368
+ if additional_tokens is not None:
369
+ # remove additional token
370
+ out = out[:, n_tokens_to_mask:]
371
+ return self.to_out(out)
372
+
373
+
374
+ class BasicTransformerBlock(nn.Module):
375
+ ATTENTION_MODES = {
376
+ "softmax": CrossAttention, # vanilla attention
377
+ "softmax-xformers": MemoryEfficientCrossAttention, # ampere
378
+ }
379
+
380
+ def __init__(
381
+ self,
382
+ dim,
383
+ n_heads,
384
+ d_head,
385
+ dropout=0.0,
386
+ context_dim=None,
387
+ gated_ff=True,
388
+ checkpoint=True,
389
+ disable_self_attn=False,
390
+ attn_mode="softmax",
391
+ sdp_backend=None,
392
+ ):
393
+ super().__init__()
394
+ assert attn_mode in self.ATTENTION_MODES
395
+ if attn_mode != "softmax" and not XFORMERS_IS_AVAILABLE:
396
+ print(
397
+ f"Attention mode '{attn_mode}' is not available. Falling back to native attention. "
398
+ f"This is not a problem in Pytorch >= 2.0. FYI, you are running with PyTorch version {torch.__version__}"
399
+ )
400
+ attn_mode = "softmax"
401
+ elif attn_mode == "softmax" and not SDP_IS_AVAILABLE:
402
+ print(
403
+ "We do not support vanilla attention anymore, as it is too expensive. Sorry."
404
+ )
405
+ if not XFORMERS_IS_AVAILABLE:
406
+ assert (
407
+ False
408
+ ), "Please install xformers via e.g. 'pip install xformers==0.0.16'"
409
+ else:
410
+ print("Falling back to xformers efficient attention.")
411
+ attn_mode = "softmax-xformers"
412
+ attn_cls = self.ATTENTION_MODES[attn_mode]
413
+ if version.parse(torch.__version__) >= version.parse("2.0.0"):
414
+ assert sdp_backend is None or isinstance(sdp_backend, SDPBackend)
415
+ else:
416
+ assert sdp_backend is None
417
+ self.disable_self_attn = disable_self_attn
418
+ self.attn1 = attn_cls(
419
+ query_dim=dim,
420
+ heads=n_heads,
421
+ dim_head=d_head,
422
+ dropout=dropout,
423
+ context_dim=context_dim if self.disable_self_attn else None,
424
+ backend=sdp_backend,
425
+ ) # is a self-attention if not self.disable_self_attn
426
+ self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
427
+ self.attn2 = attn_cls(
428
+ query_dim=dim,
429
+ context_dim=context_dim,
430
+ heads=n_heads,
431
+ dim_head=d_head,
432
+ dropout=dropout,
433
+ backend=sdp_backend,
434
+ ) # is self-attn if context is none
435
+ self.norm1 = nn.LayerNorm(dim)
436
+ self.norm2 = nn.LayerNorm(dim)
437
+ self.norm3 = nn.LayerNorm(dim)
438
+ self.checkpoint = checkpoint
439
+ if self.checkpoint:
440
+ print(f"{self.__class__.__name__} is using checkpointing")
441
+
442
+ def forward(
443
+ self, x, context=None, additional_tokens=None, n_times_crossframe_attn_in_self=0
444
+ ):
445
+ kwargs = {"x": x}
446
+
447
+ if context is not None:
448
+ kwargs.update({"context": context})
449
+
450
+ if additional_tokens is not None:
451
+ kwargs.update({"additional_tokens": additional_tokens})
452
+
453
+ if n_times_crossframe_attn_in_self:
454
+ kwargs.update(
455
+ {"n_times_crossframe_attn_in_self": n_times_crossframe_attn_in_self}
456
+ )
457
+
458
+ # return mixed_checkpoint(self._forward, kwargs, self.parameters(), self.checkpoint)
459
+ return checkpoint(
460
+ self._forward, (x, context), self.parameters(), self.checkpoint
461
+ )
462
+
463
+ def _forward(
464
+ self, x, context=None, additional_tokens=None, n_times_crossframe_attn_in_self=0
465
+ ):
466
+ x = (
467
+ self.attn1(
468
+ self.norm1(x),
469
+ context=context if self.disable_self_attn else None,
470
+ additional_tokens=additional_tokens,
471
+ n_times_crossframe_attn_in_self=n_times_crossframe_attn_in_self
472
+ if not self.disable_self_attn
473
+ else 0,
474
+ )
475
+ + x
476
+ )
477
+ x = (
478
+ self.attn2(
479
+ self.norm2(x), context=context, additional_tokens=additional_tokens
480
+ )
481
+ + x
482
+ )
483
+ x = self.ff(self.norm3(x)) + x
484
+ return x
485
+
486
+
487
+ class BasicTransformerSingleLayerBlock(nn.Module):
488
+ ATTENTION_MODES = {
489
+ "softmax": CrossAttention, # vanilla attention
490
+ "softmax-xformers": MemoryEfficientCrossAttention # on the A100s not quite as fast as the above version
491
+ # (todo might depend on head_dim, check, falls back to semi-optimized kernels for dim!=[16,32,64,128])
492
+ }
493
+
494
+ def __init__(
495
+ self,
496
+ dim,
497
+ n_heads,
498
+ d_head,
499
+ dropout=0.0,
500
+ context_dim=None,
501
+ gated_ff=True,
502
+ checkpoint=True,
503
+ attn_mode="softmax",
504
+ ):
505
+ super().__init__()
506
+ assert attn_mode in self.ATTENTION_MODES
507
+ attn_cls = self.ATTENTION_MODES[attn_mode]
508
+ self.attn1 = attn_cls(
509
+ query_dim=dim,
510
+ heads=n_heads,
511
+ dim_head=d_head,
512
+ dropout=dropout,
513
+ context_dim=context_dim,
514
+ )
515
+ self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
516
+ self.norm1 = nn.LayerNorm(dim)
517
+ self.norm2 = nn.LayerNorm(dim)
518
+ self.checkpoint = checkpoint
519
+
520
+ def forward(self, x, context=None):
521
+ return checkpoint(
522
+ self._forward, (x, context), self.parameters(), self.checkpoint
523
+ )
524
+
525
+ def _forward(self, x, context=None):
526
+ x = self.attn1(self.norm1(x), context=context) + x
527
+ x = self.ff(self.norm2(x)) + x
528
+ return x
529
+
530
+
531
+ class SpatialTransformer(nn.Module):
532
+ """
533
+ Transformer block for image-like data.
534
+ First, project the input (aka embedding)
535
+ and reshape to b, t, d.
536
+ Then apply standard transformer action.
537
+ Finally, reshape to image
538
+ NEW: use_linear for more efficiency instead of the 1x1 convs
539
+ """
540
+
541
+ def __init__(
542
+ self,
543
+ in_channels,
544
+ n_heads,
545
+ d_head,
546
+ depth=1,
547
+ dropout=0.0,
548
+ context_dim=None,
549
+ disable_self_attn=False,
550
+ use_linear=False,
551
+ attn_type="softmax",
552
+ use_checkpoint=True,
553
+ # sdp_backend=SDPBackend.FLASH_ATTENTION
554
+ sdp_backend=None,
555
+ ):
556
+ super().__init__()
557
+ print(
558
+ f"constructing {self.__class__.__name__} of depth {depth} w/ {in_channels} channels and {n_heads} heads"
559
+ )
560
+ from omegaconf import ListConfig
561
+
562
+ if exists(context_dim) and not isinstance(context_dim, (list, ListConfig)):
563
+ context_dim = [context_dim]
564
+ if exists(context_dim) and isinstance(context_dim, list):
565
+ if depth != len(context_dim):
566
+ print(
567
+ f"WARNING: {self.__class__.__name__}: Found context dims {context_dim} of depth {len(context_dim)}, "
568
+ f"which does not match the specified 'depth' of {depth}. Setting context_dim to {depth * [context_dim[0]]} now."
569
+ )
570
+ # depth does not match context dims.
571
+ assert all(
572
+ map(lambda x: x == context_dim[0], context_dim)
573
+ ), "need homogenous context_dim to match depth automatically"
574
+ context_dim = depth * [context_dim[0]]
575
+ elif context_dim is None:
576
+ context_dim = [None] * depth
577
+ self.in_channels = in_channels
578
+ inner_dim = n_heads * d_head
579
+ self.norm = Normalize(in_channels)
580
+ if not use_linear:
581
+ self.proj_in = nn.Conv2d(
582
+ in_channels, inner_dim, kernel_size=1, stride=1, padding=0
583
+ )
584
+ else:
585
+ self.proj_in = nn.Linear(in_channels, inner_dim)
586
+
587
+ self.transformer_blocks = nn.ModuleList(
588
+ [
589
+ BasicTransformerBlock(
590
+ inner_dim,
591
+ n_heads,
592
+ d_head,
593
+ dropout=dropout,
594
+ context_dim=context_dim[d],
595
+ disable_self_attn=disable_self_attn,
596
+ attn_mode=attn_type,
597
+ checkpoint=use_checkpoint,
598
+ sdp_backend=sdp_backend,
599
+ )
600
+ for d in range(depth)
601
+ ]
602
+ )
603
+ if not use_linear:
604
+ self.proj_out = zero_module(
605
+ nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
606
+ )
607
+ else:
608
+ # self.proj_out = zero_module(nn.Linear(in_channels, inner_dim))
609
+ self.proj_out = zero_module(nn.Linear(inner_dim, in_channels))
610
+ self.use_linear = use_linear
611
+
612
+ def forward(self, x, context=None):
613
+ # note: if no context is given, cross-attention defaults to self-attention
614
+ if not isinstance(context, list):
615
+ context = [context]
616
+ b, c, h, w = x.shape
617
+ x_in = x
618
+ x = self.norm(x)
619
+ if not self.use_linear:
620
+ x = self.proj_in(x)
621
+ x = rearrange(x, "b c h w -> b (h w) c").contiguous()
622
+ if self.use_linear:
623
+ x = self.proj_in(x)
624
+ for i, block in enumerate(self.transformer_blocks):
625
+ if i > 0 and len(context) == 1:
626
+ i = 0 # use same context for each block
627
+ x = block(x, context=context[i])
628
+ if self.use_linear:
629
+ x = self.proj_out(x)
630
+ x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w).contiguous()
631
+ if not self.use_linear:
632
+ x = self.proj_out(x)
633
+ return x + x_in
sgm/modules/autoencoding/__init__.py ADDED
File without changes
sgm/modules/autoencoding/losses/__init__.py ADDED
@@ -0,0 +1,246 @@
1
+ from typing import Any, Union
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ from einops import rearrange
6
+
7
+ from ....util import default, instantiate_from_config
8
+ from ..lpips.loss.lpips import LPIPS
9
+ from ..lpips.model.model import NLayerDiscriminator, weights_init
10
+ from ..lpips.vqperceptual import hinge_d_loss, vanilla_d_loss
11
+
12
+
13
+ def adopt_weight(weight, global_step, threshold=0, value=0.0):
14
+ if global_step < threshold:
15
+ weight = value
16
+ return weight
17
+
18
+
19
+ class LatentLPIPS(nn.Module):
20
+ def __init__(
21
+ self,
22
+ decoder_config,
23
+ perceptual_weight=1.0,
24
+ latent_weight=1.0,
25
+ scale_input_to_tgt_size=False,
26
+ scale_tgt_to_input_size=False,
27
+ perceptual_weight_on_inputs=0.0,
28
+ ):
29
+ super().__init__()
30
+ self.scale_input_to_tgt_size = scale_input_to_tgt_size
31
+ self.scale_tgt_to_input_size = scale_tgt_to_input_size
32
+ self.init_decoder(decoder_config)
33
+ self.perceptual_loss = LPIPS().eval()
34
+ self.perceptual_weight = perceptual_weight
35
+ self.latent_weight = latent_weight
36
+ self.perceptual_weight_on_inputs = perceptual_weight_on_inputs
37
+
38
+ def init_decoder(self, config):
39
+ self.decoder = instantiate_from_config(config)
40
+ if hasattr(self.decoder, "encoder"):
41
+ del self.decoder.encoder
42
+
43
+ def forward(self, latent_inputs, latent_predictions, image_inputs, split="train"):
44
+ log = dict()
45
+ loss = (latent_inputs - latent_predictions) ** 2
46
+ log[f"{split}/latent_l2_loss"] = loss.mean().detach()
47
+ image_reconstructions = None
48
+ if self.perceptual_weight > 0.0:
49
+ image_reconstructions = self.decoder.decode(latent_predictions)
50
+ image_targets = self.decoder.decode(latent_inputs)
51
+ perceptual_loss = self.perceptual_loss(
52
+ image_targets.contiguous(), image_reconstructions.contiguous()
53
+ )
54
+ loss = (
55
+ self.latent_weight * loss.mean()
56
+ + self.perceptual_weight * perceptual_loss.mean()
57
+ )
58
+ log[f"{split}/perceptual_loss"] = perceptual_loss.mean().detach()
59
+
60
+ if self.perceptual_weight_on_inputs > 0.0:
61
+ image_reconstructions = default(
62
+ image_reconstructions, self.decoder.decode(latent_predictions)
63
+ )
64
+ if self.scale_input_to_tgt_size:
65
+ image_inputs = torch.nn.functional.interpolate(
66
+ image_inputs,
67
+ image_reconstructions.shape[2:],
68
+ mode="bicubic",
69
+ antialias=True,
70
+ )
71
+ elif self.scale_tgt_to_input_size:
72
+ image_reconstructions = torch.nn.functional.interpolate(
73
+ image_reconstructions,
74
+ image_inputs.shape[2:],
75
+ mode="bicubic",
76
+ antialias=True,
77
+ )
78
+
79
+ perceptual_loss2 = self.perceptual_loss(
80
+ image_inputs.contiguous(), image_reconstructions.contiguous()
81
+ )
82
+ loss = loss + self.perceptual_weight_on_inputs * perceptual_loss2.mean()
83
+ log[f"{split}/perceptual_loss_on_inputs"] = perceptual_loss2.mean().detach()
84
+ return loss, log
85
+
86
+
87
+ class GeneralLPIPSWithDiscriminator(nn.Module):
88
+ def __init__(
89
+ self,
90
+ disc_start: int,
91
+ logvar_init: float = 0.0,
92
+ pixelloss_weight=1.0,
93
+ disc_num_layers: int = 3,
94
+ disc_in_channels: int = 3,
95
+ disc_factor: float = 1.0,
96
+ disc_weight: float = 1.0,
97
+ perceptual_weight: float = 1.0,
98
+ disc_loss: str = "hinge",
99
+ scale_input_to_tgt_size: bool = False,
100
+ dims: int = 2,
101
+ learn_logvar: bool = False,
102
+ regularization_weights: Union[None, dict] = None,
103
+ ):
104
+ super().__init__()
105
+ self.dims = dims
106
+ if self.dims > 2:
107
+ print(
108
+ f"running with dims={dims}. This means that for perceptual loss calculation, "
109
+ f"the LPIPS loss will be applied to each frame independently. "
110
+ )
111
+ self.scale_input_to_tgt_size = scale_input_to_tgt_size
112
+ assert disc_loss in ["hinge", "vanilla"]
113
+ self.pixel_weight = pixelloss_weight
114
+ self.perceptual_loss = LPIPS().eval()
115
+ self.perceptual_weight = perceptual_weight
116
+ # output log variance
117
+ self.logvar = nn.Parameter(torch.ones(size=()) * logvar_init)
118
+ self.learn_logvar = learn_logvar
119
+
120
+ self.discriminator = NLayerDiscriminator(
121
+ input_nc=disc_in_channels, n_layers=disc_num_layers, use_actnorm=False
122
+ ).apply(weights_init)
123
+ self.discriminator_iter_start = disc_start
124
+ self.disc_loss = hinge_d_loss if disc_loss == "hinge" else vanilla_d_loss
125
+ self.disc_factor = disc_factor
126
+ self.discriminator_weight = disc_weight
127
+ self.regularization_weights = default(regularization_weights, {})
128
+
129
+ def get_trainable_parameters(self) -> Any:
130
+ return self.discriminator.parameters()
131
+
132
+ def get_trainable_autoencoder_parameters(self) -> Any:
133
+ if self.learn_logvar:
134
+ yield self.logvar
135
+ yield from ()
136
+
137
+ def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):
138
+ if last_layer is not None:
139
+ nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
140
+ g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
141
+ else:
142
+ nll_grads = torch.autograd.grad(
143
+ nll_loss, self.last_layer[0], retain_graph=True
144
+ )[0]
145
+ g_grads = torch.autograd.grad(
146
+ g_loss, self.last_layer[0], retain_graph=True
147
+ )[0]
148
+
149
+ d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
150
+ d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
151
+ d_weight = d_weight * self.discriminator_weight
152
+ return d_weight
153
+
154
+ def forward(
155
+ self,
156
+ regularization_log,
157
+ inputs,
158
+ reconstructions,
159
+ optimizer_idx,
160
+ global_step,
161
+ last_layer=None,
162
+ split="train",
163
+ weights=None,
164
+ ):
165
+ if self.scale_input_to_tgt_size:
166
+ inputs = torch.nn.functional.interpolate(
167
+ inputs, reconstructions.shape[2:], mode="bicubic", antialias=True
168
+ )
169
+
170
+ if self.dims > 2:
171
+ inputs, reconstructions = map(
172
+ lambda x: rearrange(x, "b c t h w -> (b t) c h w"),
173
+ (inputs, reconstructions),
174
+ )
175
+
176
+ rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous())
177
+ if self.perceptual_weight > 0:
178
+ p_loss = self.perceptual_loss(
179
+ inputs.contiguous(), reconstructions.contiguous()
180
+ )
181
+ rec_loss = rec_loss + self.perceptual_weight * p_loss
182
+
183
+ nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar
184
+ weighted_nll_loss = nll_loss
185
+ if weights is not None:
186
+ weighted_nll_loss = weights * nll_loss
187
+ weighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0]
188
+ nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]
189
+
190
+ # now the GAN part
191
+ if optimizer_idx == 0:
192
+ # generator update
193
+ logits_fake = self.discriminator(reconstructions.contiguous())
194
+ g_loss = -torch.mean(logits_fake)
195
+
196
+ if self.disc_factor > 0.0:
197
+ try:
198
+ d_weight = self.calculate_adaptive_weight(
199
+ nll_loss, g_loss, last_layer=last_layer
200
+ )
201
+ except RuntimeError:
202
+ assert not self.training
203
+ d_weight = torch.tensor(0.0)
204
+ else:
205
+ d_weight = torch.tensor(0.0)
206
+
207
+ disc_factor = adopt_weight(
208
+ self.disc_factor, global_step, threshold=self.discriminator_iter_start
209
+ )
210
+ loss = weighted_nll_loss + d_weight * disc_factor * g_loss
211
+ log = dict()
212
+ for k in regularization_log:
213
+ if k in self.regularization_weights:
214
+ loss = loss + self.regularization_weights[k] * regularization_log[k]
215
+ log[f"{split}/{k}"] = regularization_log[k].detach().mean()
216
+
217
+ log.update(
218
+ {
219
+ "{}/total_loss".format(split): loss.clone().detach().mean(),
220
+ "{}/logvar".format(split): self.logvar.detach(),
221
+ "{}/nll_loss".format(split): nll_loss.detach().mean(),
222
+ "{}/rec_loss".format(split): rec_loss.detach().mean(),
223
+ "{}/d_weight".format(split): d_weight.detach(),
224
+ "{}/disc_factor".format(split): torch.tensor(disc_factor),
225
+ "{}/g_loss".format(split): g_loss.detach().mean(),
226
+ }
227
+ )
228
+
229
+ return loss, log
230
+
231
+ if optimizer_idx == 1:
232
+ # second pass for discriminator update
233
+ logits_real = self.discriminator(inputs.contiguous().detach())
234
+ logits_fake = self.discriminator(reconstructions.contiguous().detach())
235
+
236
+ disc_factor = adopt_weight(
237
+ self.disc_factor, global_step, threshold=self.discriminator_iter_start
238
+ )
239
+ d_loss = disc_factor * self.disc_loss(logits_real, logits_fake)
240
+
241
+ log = {
242
+ "{}/disc_loss".format(split): d_loss.clone().detach().mean(),
243
+ "{}/logits_real".format(split): logits_real.detach().mean(),
244
+ "{}/logits_fake".format(split): logits_fake.detach().mean(),
245
+ }
246
+ return d_loss, log
sgm/modules/autoencoding/lpips/__init__.py ADDED
File without changes
sgm/modules/autoencoding/lpips/loss/.gitignore ADDED
@@ -0,0 +1 @@
1
+ vgg.pth
sgm/modules/autoencoding/lpips/loss/LICENSE ADDED
@@ -0,0 +1,23 @@
1
+ Copyright (c) 2018, Richard Zhang, Phillip Isola, Alexei A. Efros, Eli Shechtman, Oliver Wang
2
+ All rights reserved.
3
+
4
+ Redistribution and use in source and binary forms, with or without
5
+ modification, are permitted provided that the following conditions are met:
6
+
7
+ * Redistributions of source code must retain the above copyright notice, this
8
+ list of conditions and the following disclaimer.
9
+
10
+ * Redistributions in binary form must reproduce the above copyright notice,
11
+ this list of conditions and the following disclaimer in the documentation
12
+ and/or other materials provided with the distribution.
13
+
14
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
15
+ AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
16
+ IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
17
+ DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
18
+ FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
19
+ DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
20
+ SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
21
+ CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
22
+ OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
23
+ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
sgm/modules/autoencoding/lpips/loss/__init__.py ADDED
File without changes
sgm/modules/autoencoding/lpips/loss/lpips.py ADDED
@@ -0,0 +1,147 @@
1
+ """Stripped version of https://github.com/richzhang/PerceptualSimilarity/tree/master/models"""
2
+
3
+ from collections import namedtuple
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ from torchvision import models
8
+
9
+ from ..util import get_ckpt_path
10
+
11
+
12
+ class LPIPS(nn.Module):
13
+ # Learned perceptual metric
14
+ def __init__(self, use_dropout=True):
15
+ super().__init__()
16
+ self.scaling_layer = ScalingLayer()
17
+ self.chns = [64, 128, 256, 512, 512] # vgg16 features
18
+ self.net = vgg16(pretrained=True, requires_grad=False)
19
+ self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout)
20
+ self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout)
21
+ self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout)
22
+ self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout)
23
+ self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout)
24
+ self.load_from_pretrained()
25
+ for param in self.parameters():
26
+ param.requires_grad = False
27
+
28
+ def load_from_pretrained(self, name="vgg_lpips"):
29
+ ckpt = get_ckpt_path(name, "sgm/modules/autoencoding/lpips/loss")
30
+ self.load_state_dict(
31
+ torch.load(ckpt, map_location=torch.device("cpu")), strict=False
32
+ )
33
+ print("loaded pretrained LPIPS loss from {}".format(ckpt))
34
+
35
+ @classmethod
36
+ def from_pretrained(cls, name="vgg_lpips"):
37
+ if name != "vgg_lpips":
38
+ raise NotImplementedError
39
+ model = cls()
40
+ ckpt = get_ckpt_path(name)
41
+ model.load_state_dict(
42
+ torch.load(ckpt, map_location=torch.device("cpu")), strict=False
43
+ )
44
+ return model
45
+
46
+ def forward(self, input, target):
47
+ in0_input, in1_input = (self.scaling_layer(input), self.scaling_layer(target))
48
+ outs0, outs1 = self.net(in0_input), self.net(in1_input)
49
+ feats0, feats1, diffs = {}, {}, {}
50
+ lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4]
51
+ for kk in range(len(self.chns)):
52
+ feats0[kk], feats1[kk] = normalize_tensor(outs0[kk]), normalize_tensor(
53
+ outs1[kk]
54
+ )
55
+ diffs[kk] = (feats0[kk] - feats1[kk]) ** 2
56
+
57
+ res = [
58
+ spatial_average(lins[kk].model(diffs[kk]), keepdim=True)
59
+ for kk in range(len(self.chns))
60
+ ]
61
+ val = res[0]
62
+ for l in range(1, len(self.chns)):
63
+ val += res[l]
64
+ return val
65
+
66
+
67
+ class ScalingLayer(nn.Module):
68
+ def __init__(self):
69
+ super(ScalingLayer, self).__init__()
70
+ self.register_buffer(
71
+ "shift", torch.Tensor([-0.030, -0.088, -0.188])[None, :, None, None]
72
+ )
73
+ self.register_buffer(
74
+ "scale", torch.Tensor([0.458, 0.448, 0.450])[None, :, None, None]
75
+ )
76
+
77
+ def forward(self, inp):
78
+ return (inp - self.shift) / self.scale
79
+
80
+
81
+ class NetLinLayer(nn.Module):
82
+ """A single linear layer which does a 1x1 conv"""
83
+
84
+ def __init__(self, chn_in, chn_out=1, use_dropout=False):
85
+ super(NetLinLayer, self).__init__()
86
+ layers = (
87
+ [
88
+ nn.Dropout(),
89
+ ]
90
+ if (use_dropout)
91
+ else []
92
+ )
93
+ layers += [
94
+ nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False),
95
+ ]
96
+ self.model = nn.Sequential(*layers)
97
+
98
+
99
+ class vgg16(torch.nn.Module):
100
+ def __init__(self, requires_grad=False, pretrained=True):
101
+ super(vgg16, self).__init__()
102
+ vgg_pretrained_features = models.vgg16(pretrained=pretrained).features
103
+ self.slice1 = torch.nn.Sequential()
104
+ self.slice2 = torch.nn.Sequential()
105
+ self.slice3 = torch.nn.Sequential()
106
+ self.slice4 = torch.nn.Sequential()
107
+ self.slice5 = torch.nn.Sequential()
108
+ self.N_slices = 5
109
+ for x in range(4):
110
+ self.slice1.add_module(str(x), vgg_pretrained_features[x])
111
+ for x in range(4, 9):
112
+ self.slice2.add_module(str(x), vgg_pretrained_features[x])
113
+ for x in range(9, 16):
114
+ self.slice3.add_module(str(x), vgg_pretrained_features[x])
115
+ for x in range(16, 23):
116
+ self.slice4.add_module(str(x), vgg_pretrained_features[x])
117
+ for x in range(23, 30):
118
+ self.slice5.add_module(str(x), vgg_pretrained_features[x])
119
+ if not requires_grad:
120
+ for param in self.parameters():
121
+ param.requires_grad = False
122
+
123
+ def forward(self, X):
124
+ h = self.slice1(X)
125
+ h_relu1_2 = h
126
+ h = self.slice2(h)
127
+ h_relu2_2 = h
128
+ h = self.slice3(h)
129
+ h_relu3_3 = h
130
+ h = self.slice4(h)
131
+ h_relu4_3 = h
132
+ h = self.slice5(h)
133
+ h_relu5_3 = h
134
+ vgg_outputs = namedtuple(
135
+ "VggOutputs", ["relu1_2", "relu2_2", "relu3_3", "relu4_3", "relu5_3"]
136
+ )
137
+ out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3)
138
+ return out
139
+
140
+
141
+ def normalize_tensor(x, eps=1e-10):
142
+ norm_factor = torch.sqrt(torch.sum(x**2, dim=1, keepdim=True))
143
+ return x / (norm_factor + eps)
144
+
145
+
146
+ def spatial_average(x, keepdim=True):
147
+ return x.mean([2, 3], keepdim=keepdim)
sgm/modules/autoencoding/lpips/model/LICENSE ADDED
@@ -0,0 +1,58 @@
1
+ Copyright (c) 2017, Jun-Yan Zhu and Taesung Park
2
+ All rights reserved.
3
+
4
+ Redistribution and use in source and binary forms, with or without
5
+ modification, are permitted provided that the following conditions are met:
6
+
7
+ * Redistributions of source code must retain the above copyright notice, this
8
+ list of conditions and the following disclaimer.
9
+
10
+ * Redistributions in binary form must reproduce the above copyright notice,
11
+ this list of conditions and the following disclaimer in the documentation
12
+ and/or other materials provided with the distribution.
13
+
14
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
15
+ AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
16
+ IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
17
+ DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
18
+ FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
19
+ DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
20
+ SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
21
+ CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
22
+ OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
23
+ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
24
+
25
+
26
+ --------------------------- LICENSE FOR pix2pix --------------------------------
27
+ BSD License
28
+
29
+ For pix2pix software
30
+ Copyright (c) 2016, Phillip Isola and Jun-Yan Zhu
31
+ All rights reserved.
32
+
33
+ Redistribution and use in source and binary forms, with or without
34
+ modification, are permitted provided that the following conditions are met:
35
+
36
+ * Redistributions of source code must retain the above copyright notice, this
37
+ list of conditions and the following disclaimer.
38
+
39
+ * Redistributions in binary form must reproduce the above copyright notice,
40
+ this list of conditions and the following disclaimer in the documentation
41
+ and/or other materials provided with the distribution.
42
+
43
+ ----------------------------- LICENSE FOR DCGAN --------------------------------
44
+ BSD License
45
+
46
+ For dcgan.torch software
47
+
48
+ Copyright (c) 2015, Facebook, Inc. All rights reserved.
49
+
50
+ Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
51
+
52
+ Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
53
+
54
+ Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
55
+
56
+ Neither the name Facebook nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
57
+
58
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
sgm/modules/autoencoding/lpips/model/__init__.py ADDED
File without changes
sgm/modules/autoencoding/lpips/model/model.py ADDED
@@ -0,0 +1,88 @@
1
+ import functools
2
+
3
+ import torch.nn as nn
4
+
5
+ from ..util import ActNorm
6
+
7
+
8
+ def weights_init(m):
9
+ classname = m.__class__.__name__
10
+ if classname.find("Conv") != -1:
11
+ nn.init.normal_(m.weight.data, 0.0, 0.02)
12
+ elif classname.find("BatchNorm") != -1:
13
+ nn.init.normal_(m.weight.data, 1.0, 0.02)
14
+ nn.init.constant_(m.bias.data, 0)
15
+
16
+
17
+ class NLayerDiscriminator(nn.Module):
18
+ """Defines a PatchGAN discriminator as in Pix2Pix
19
+ --> see https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py
20
+ """
21
+
22
+ def __init__(self, input_nc=3, ndf=64, n_layers=3, use_actnorm=False):
23
+ """Construct a PatchGAN discriminator
24
+ Parameters:
25
+ input_nc (int) -- the number of channels in input images
26
+ ndf (int) -- the number of filters in the last conv layer
27
+ n_layers (int) -- the number of conv layers in the discriminator
28
+ norm_layer -- normalization layer
29
+ """
30
+ super(NLayerDiscriminator, self).__init__()
31
+ if not use_actnorm:
32
+ norm_layer = nn.BatchNorm2d
33
+ else:
34
+ norm_layer = ActNorm
35
+ if (
36
+ type(norm_layer) == functools.partial
37
+ ): # no need to use bias as BatchNorm2d has affine parameters
38
+ use_bias = norm_layer.func != nn.BatchNorm2d
39
+ else:
40
+ use_bias = norm_layer != nn.BatchNorm2d
41
+
42
+ kw = 4
43
+ padw = 1
44
+ sequence = [
45
+ nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw),
46
+ nn.LeakyReLU(0.2, True),
47
+ ]
48
+ nf_mult = 1
49
+ nf_mult_prev = 1
50
+ for n in range(1, n_layers): # gradually increase the number of filters
51
+ nf_mult_prev = nf_mult
52
+ nf_mult = min(2**n, 8)
53
+ sequence += [
54
+ nn.Conv2d(
55
+ ndf * nf_mult_prev,
56
+ ndf * nf_mult,
57
+ kernel_size=kw,
58
+ stride=2,
59
+ padding=padw,
60
+ bias=use_bias,
61
+ ),
62
+ norm_layer(ndf * nf_mult),
63
+ nn.LeakyReLU(0.2, True),
64
+ ]
65
+
66
+ nf_mult_prev = nf_mult
67
+ nf_mult = min(2**n_layers, 8)
68
+ sequence += [
69
+ nn.Conv2d(
70
+ ndf * nf_mult_prev,
71
+ ndf * nf_mult,
72
+ kernel_size=kw,
73
+ stride=1,
74
+ padding=padw,
75
+ bias=use_bias,
76
+ ),
77
+ norm_layer(ndf * nf_mult),
78
+ nn.LeakyReLU(0.2, True),
79
+ ]
80
+
81
+ sequence += [
82
+ nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)
83
+ ] # output 1 channel prediction map
84
+ self.main = nn.Sequential(*sequence)
85
+
86
+ def forward(self, input):
87
+ """Standard forward."""
88
+ return self.main(input)
sgm/modules/autoencoding/lpips/util.py ADDED
@@ -0,0 +1,128 @@
1
+ import hashlib
2
+ import os
3
+
4
+ import requests
5
+ import torch
6
+ import torch.nn as nn
7
+ from tqdm import tqdm
8
+
9
+ URL_MAP = {"vgg_lpips": "https://heibox.uni-heidelberg.de/f/607503859c864bc1b30b/?dl=1"}
10
+
11
+ CKPT_MAP = {"vgg_lpips": "vgg.pth"}
12
+
13
+ MD5_MAP = {"vgg_lpips": "d507d7349b931f0638a25a48a722f98a"}
14
+
15
+
16
+ def download(url, local_path, chunk_size=1024):
17
+ os.makedirs(os.path.split(local_path)[0], exist_ok=True)
18
+ with requests.get(url, stream=True) as r:
19
+ total_size = int(r.headers.get("content-length", 0))
20
+ with tqdm(total=total_size, unit="B", unit_scale=True) as pbar:
21
+ with open(local_path, "wb") as f:
22
+ for data in r.iter_content(chunk_size=chunk_size):
23
+ if data:
24
+ f.write(data)
25
+ pbar.update(chunk_size)
26
+
27
+
28
+ def md5_hash(path):
29
+ with open(path, "rb") as f:
30
+ content = f.read()
31
+ return hashlib.md5(content).hexdigest()
32
+
33
+
34
+ def get_ckpt_path(name, root, check=False):
35
+ assert name in URL_MAP
36
+ path = os.path.join(root, CKPT_MAP[name])
37
+ if not os.path.exists(path) or (check and not md5_hash(path) == MD5_MAP[name]):
38
+ print("Downloading {} model from {} to {}".format(name, URL_MAP[name], path))
39
+ download(URL_MAP[name], path)
40
+ md5 = md5_hash(path)
41
+ assert md5 == MD5_MAP[name], md5
42
+ return path
43
+
44
+
45
+ class ActNorm(nn.Module):
46
+ def __init__(
47
+ self, num_features, logdet=False, affine=True, allow_reverse_init=False
48
+ ):
49
+ assert affine
50
+ super().__init__()
51
+ self.logdet = logdet
52
+ self.loc = nn.Parameter(torch.zeros(1, num_features, 1, 1))
53
+ self.scale = nn.Parameter(torch.ones(1, num_features, 1, 1))
54
+ self.allow_reverse_init = allow_reverse_init
55
+
56
+ self.register_buffer("initialized", torch.tensor(0, dtype=torch.uint8))
57
+
58
+ def initialize(self, input):
59
+ with torch.no_grad():
60
+ flatten = input.permute(1, 0, 2, 3).contiguous().view(input.shape[1], -1)
61
+ mean = (
62
+ flatten.mean(1)
63
+ .unsqueeze(1)
64
+ .unsqueeze(2)
65
+ .unsqueeze(3)
66
+ .permute(1, 0, 2, 3)
67
+ )
68
+ std = (
69
+ flatten.std(1)
70
+ .unsqueeze(1)
71
+ .unsqueeze(2)
72
+ .unsqueeze(3)
73
+ .permute(1, 0, 2, 3)
74
+ )
75
+
76
+ self.loc.data.copy_(-mean)
77
+ self.scale.data.copy_(1 / (std + 1e-6))
78
+
79
+ def forward(self, input, reverse=False):
80
+ if reverse:
81
+ return self.reverse(input)
82
+ if len(input.shape) == 2:
83
+ input = input[:, :, None, None]
84
+ squeeze = True
85
+ else:
86
+ squeeze = False
87
+
88
+ _, _, height, width = input.shape
89
+
90
+ if self.training and self.initialized.item() == 0:
91
+ self.initialize(input)
92
+ self.initialized.fill_(1)
93
+
94
+ h = self.scale * (input + self.loc)
95
+
96
+ if squeeze:
97
+ h = h.squeeze(-1).squeeze(-1)
98
+
99
+ if self.logdet:
100
+ log_abs = torch.log(torch.abs(self.scale))
101
+ logdet = height * width * torch.sum(log_abs)
102
+ logdet = logdet * torch.ones(input.shape[0]).to(input)
103
+ return h, logdet
104
+
105
+ return h
106
+
107
+ def reverse(self, output):
108
+ if self.training and self.initialized.item() == 0:
109
+ if not self.allow_reverse_init:
110
+ raise RuntimeError(
111
+ "Initializing ActNorm in reverse direction is "
112
+ "disabled by default. Use allow_reverse_init=True to enable."
113
+ )
114
+ else:
115
+ self.initialize(output)
116
+ self.initialized.fill_(1)
117
+
118
+ if len(output.shape) == 2:
119
+ output = output[:, :, None, None]
120
+ squeeze = True
121
+ else:
122
+ squeeze = False
123
+
124
+ h = output / self.scale - self.loc
125
+
126
+ if squeeze:
127
+ h = h.squeeze(-1).squeeze(-1)
128
+ return h
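
A minimal sketch of ActNorm's data-dependent initialization (assuming the import path from this commit's layout): the first training-mode forward pass sets loc/scale from batch statistics, so outputs come out roughly zero-mean and unit-variance.

import torch
from sgm.modules.autoencoding.lpips.util import ActNorm

actnorm = ActNorm(num_features=64, logdet=True)
actnorm.train()
x = torch.randn(8, 64, 32, 32) * 3.0 + 1.5        # arbitrary scale/offset
h, logdet = actnorm(x)                            # first call triggers initialize()
print(round(h.mean().item(), 3), round(h.std().item(), 3))  # ~0.0 and ~1.0
print(logdet.shape)                               # per-sample log|det|, torch.Size([8])
x_rec = actnorm(h, reverse=True)                  # reverse() inverts the affine map
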
sgm/modules/autoencoding/lpips/vqperceptual.py ADDED
@@ -0,0 +1,17 @@
1
+ import torch
2
+ import torch.nn.functional as F
3
+
4
+
5
+ def hinge_d_loss(logits_real, logits_fake):
6
+ loss_real = torch.mean(F.relu(1.0 - logits_real))
7
+ loss_fake = torch.mean(F.relu(1.0 + logits_fake))
8
+ d_loss = 0.5 * (loss_real + loss_fake)
9
+ return d_loss
10
+
11
+
12
+ def vanilla_d_loss(logits_real, logits_fake):
13
+ d_loss = 0.5 * (
14
+ torch.mean(torch.nn.functional.softplus(-logits_real))
15
+ + torch.mean(torch.nn.functional.softplus(logits_fake))
16
+ )
17
+ return d_loss
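
A quick numeric sanity check of the two discriminator losses above (hedged example; the logits are synthetic):

import torch
from sgm.modules.autoencoding.lpips.vqperceptual import hinge_d_loss, vanilla_d_loss

# A discriminator that scores real patches high and fake patches low should get
# a small loss under both formulations.
logits_real = torch.randn(4, 1, 30, 30) + 2.0
logits_fake = torch.randn(4, 1, 30, 30) - 2.0
print(hinge_d_loss(logits_real, logits_fake).item())
print(vanilla_d_loss(logits_real, logits_fake).item())
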
sgm/modules/autoencoding/regularizers/__init__.py ADDED
@@ -0,0 +1,53 @@
1
+ from abc import abstractmethod
2
+ from typing import Any, Tuple
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+
8
+ from ....modules.distributions.distributions import DiagonalGaussianDistribution
9
+
10
+
11
+ class AbstractRegularizer(nn.Module):
12
+ def __init__(self):
13
+ super().__init__()
14
+
15
+ def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, dict]:
16
+ raise NotImplementedError()
17
+
18
+ @abstractmethod
19
+ def get_trainable_parameters(self) -> Any:
20
+ raise NotImplementedError()
21
+
22
+
23
+ class DiagonalGaussianRegularizer(AbstractRegularizer):
24
+ def __init__(self, sample: bool = True):
25
+ super().__init__()
26
+ self.sample = sample
27
+
28
+ def get_trainable_parameters(self) -> Any:
29
+ yield from ()
30
+
31
+ def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, dict]:
32
+ log = dict()
33
+ posterior = DiagonalGaussianDistribution(z)
34
+ if self.sample:
35
+ z = posterior.sample()
36
+ else:
37
+ z = posterior.mode()
38
+ kl_loss = posterior.kl()
39
+ kl_loss = torch.sum(kl_loss) / kl_loss.shape[0]
40
+ log["kl_loss"] = kl_loss
41
+ return z, log
42
+
43
+
44
+ def measure_perplexity(predicted_indices, num_centroids):
45
+ # src: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py
46
+ # eval cluster perplexity. when perplexity == num_embeddings then all clusters are used exactly equally
47
+ encodings = (
48
+ F.one_hot(predicted_indices, num_centroids).float().reshape(-1, num_centroids)
49
+ )
50
+ avg_probs = encodings.mean(0)
51
+ perplexity = (-(avg_probs * torch.log(avg_probs + 1e-10)).sum()).exp()
52
+ cluster_use = torch.sum(avg_probs > 0)
53
+ return perplexity, cluster_use
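
A shape sketch for DiagonalGaussianRegularizer, assuming the usual 2 * z_channels convention for the encoder output (values are illustrative):

import torch
from sgm.modules.autoencoding.regularizers import DiagonalGaussianRegularizer

reg = DiagonalGaussianRegularizer(sample=True)
z_params = torch.randn(2, 8, 16, 16)    # mean and logvar stacked: 2 * 4 latent channels
z, log = reg(z_params)
print(z.shape)                          # torch.Size([2, 4, 16, 16]) -- sampled latent
print(float(log["kl_loss"]))            # batch-averaged KL against N(0, I)
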
sgm/modules/diffusionmodules/__init__.py ADDED
@@ -0,0 +1,7 @@
1
+ from .denoiser import Denoiser
2
+ from .discretizer import Discretization
3
+ from .loss import StandardDiffusionLoss
4
+ from .model import Decoder, Encoder, Model
5
+ from .openaimodel import UNetModel
6
+ from .sampling import BaseDiffusionSampler
7
+ from .wrappers import OpenAIWrapper
sgm/modules/diffusionmodules/denoiser.py ADDED
@@ -0,0 +1,63 @@
1
+ import torch.nn as nn
2
+
3
+ from ...util import append_dims, instantiate_from_config
4
+
5
+
6
+ class Denoiser(nn.Module):
7
+ def __init__(self, weighting_config, scaling_config):
8
+ super().__init__()
9
+
10
+ self.weighting = instantiate_from_config(weighting_config)
11
+ self.scaling = instantiate_from_config(scaling_config)
12
+
13
+ def possibly_quantize_sigma(self, sigma):
14
+ return sigma
15
+
16
+ def possibly_quantize_c_noise(self, c_noise):
17
+ return c_noise
18
+
19
+ def w(self, sigma):
20
+ return self.weighting(sigma)
21
+
22
+ def __call__(self, network, input, sigma, cond):
23
+ sigma = self.possibly_quantize_sigma(sigma)
24
+ sigma_shape = sigma.shape
25
+ sigma = append_dims(sigma, input.ndim)
26
+ c_skip, c_out, c_in, c_noise = self.scaling(sigma)
27
+ c_noise = self.possibly_quantize_c_noise(c_noise.reshape(sigma_shape))
28
+ return network(input * c_in, c_noise, cond) * c_out + input * c_skip
29
+
30
+
31
+ class DiscreteDenoiser(Denoiser):
32
+ def __init__(
33
+ self,
34
+ weighting_config,
35
+ scaling_config,
36
+ num_idx,
37
+ discretization_config,
38
+ do_append_zero=False,
39
+ quantize_c_noise=True,
40
+ flip=True,
41
+ ):
42
+ super().__init__(weighting_config, scaling_config)
43
+ sigmas = instantiate_from_config(discretization_config)(
44
+ num_idx, do_append_zero=do_append_zero, flip=flip
45
+ )
46
+ self.register_buffer("sigmas", sigmas)
47
+ self.quantize_c_noise = quantize_c_noise
48
+
49
+ def sigma_to_idx(self, sigma):
50
+ dists = sigma - self.sigmas[:, None]
51
+ return dists.abs().argmin(dim=0).view(sigma.shape)
52
+
53
+ def idx_to_sigma(self, idx):
54
+ return self.sigmas[idx]
55
+
56
+ def possibly_quantize_sigma(self, sigma):
57
+ return self.idx_to_sigma(self.sigma_to_idx(sigma))
58
+
59
+ def possibly_quantize_c_noise(self, c_noise):
60
+ if self.quantize_c_noise:
61
+ return self.sigma_to_idx(c_noise)
62
+ else:
63
+ return c_noise
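
A hedged construction sketch for DiscreteDenoiser; the config targets below are assumptions that simply point at the classes added elsewhere in this commit (denoiser_weighting.py, denoiser_scaling.py, discretizer.py):

import torch
from sgm.modules.diffusionmodules.denoiser import DiscreteDenoiser

denoiser = DiscreteDenoiser(
    weighting_config={"target": "sgm.modules.diffusionmodules.denoiser_weighting.EpsWeighting"},
    scaling_config={"target": "sgm.modules.diffusionmodules.denoiser_scaling.EpsScaling"},
    num_idx=1000,
    discretization_config={"target": "sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization"},
)
# Continuous sigmas get snapped to the 1000-entry DDPM table before conditioning the
# network; c_noise is quantized to the matching integer timestep index.
sigma = torch.tensor([0.4, 3.0])
print(denoiser.possibly_quantize_sigma(sigma))
print(denoiser.possibly_quantize_c_noise(sigma))
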
sgm/modules/diffusionmodules/denoiser_scaling.py ADDED
@@ -0,0 +1,31 @@
1
+ import torch
2
+
3
+
4
+ class EDMScaling:
5
+ def __init__(self, sigma_data=0.5):
6
+ self.sigma_data = sigma_data
7
+
8
+ def __call__(self, sigma):
9
+ c_skip = self.sigma_data**2 / (sigma**2 + self.sigma_data**2)
10
+ c_out = sigma * self.sigma_data / (sigma**2 + self.sigma_data**2) ** 0.5
11
+ c_in = 1 / (sigma**2 + self.sigma_data**2) ** 0.5
12
+ c_noise = 0.25 * sigma.log()
13
+ return c_skip, c_out, c_in, c_noise
14
+
15
+
16
+ class EpsScaling:
17
+ def __call__(self, sigma):
18
+ c_skip = torch.ones_like(sigma, device=sigma.device)
19
+ c_out = -sigma
20
+ c_in = 1 / (sigma**2 + 1.0) ** 0.5
21
+ c_noise = sigma.clone()
22
+ return c_skip, c_out, c_in, c_noise
23
+
24
+
25
+ class VScaling:
26
+ def __call__(self, sigma):
27
+ c_skip = 1.0 / (sigma**2 + 1.0)
28
+ c_out = -sigma / (sigma**2 + 1.0) ** 0.5
29
+ c_in = 1.0 / (sigma**2 + 1.0) ** 0.5
30
+ c_noise = sigma.clone()
31
+ return c_skip, c_out, c_in, c_noise
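
A small sketch of how the EDM preconditioning coefficients behave across noise levels (values only; nothing here is wired into the commit):

import torch
from sgm.modules.diffusionmodules.denoiser_scaling import EDMScaling

scaling = EDMScaling(sigma_data=0.5)
for s in (0.01, 1.0, 80.0):
    c_skip, c_out, c_in, c_noise = scaling(torch.tensor(s))
    # Low sigma: the skip branch dominates (c_skip -> 1). High sigma: the network
    # output dominates and the input is scaled down (c_in ~ 1/sigma).
    print(f"sigma={s:<6} c_skip={c_skip:.4f} c_out={c_out:.4f} c_in={c_in:.4f}")
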
sgm/modules/diffusionmodules/denoiser_weighting.py ADDED
@@ -0,0 +1,24 @@
1
+ import torch
2
+
3
+
4
+ class UnitWeighting:
5
+ def __call__(self, sigma):
6
+ return torch.ones_like(sigma, device=sigma.device)
7
+
8
+
9
+ class EDMWeighting:
10
+ def __init__(self, sigma_data=0.5):
11
+ self.sigma_data = sigma_data
12
+
13
+ def __call__(self, sigma):
14
+ return (sigma**2 + self.sigma_data**2) / (sigma * self.sigma_data) ** 2
15
+
16
+
17
+ class VWeighting(EDMWeighting):
18
+ def __init__(self):
19
+ super().__init__(sigma_data=1.0)
20
+
21
+
22
+ class EpsWeighting:
23
+ def __call__(self, sigma):
24
+ return sigma**-2.0
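
One identity worth noting (a hedged check combining this file with denoiser_scaling.py): the EDM weighting equals 1 / c_out(sigma)**2, so the loss is unit-weighted in the network's output space.

import torch
from sgm.modules.diffusionmodules.denoiser_scaling import EDMScaling
from sgm.modules.diffusionmodules.denoiser_weighting import EDMWeighting

sigma = torch.tensor([0.1, 1.0, 10.0])
w = EDMWeighting(sigma_data=0.5)(sigma)
_, c_out, _, _ = EDMScaling(sigma_data=0.5)(sigma)
print(w * c_out**2)   # tensor([1., 1., 1.]) up to floating-point error
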
sgm/modules/diffusionmodules/discretizer.py ADDED
@@ -0,0 +1,69 @@
1
+ from abc import abstractmethod
2
+ from functools import partial
3
+
4
+ import numpy as np
5
+ import torch
6
+
7
+ from ...modules.diffusionmodules.util import make_beta_schedule
8
+ from ...util import append_zero
9
+
10
+
11
+ def generate_roughly_equally_spaced_steps(
12
+ num_substeps: int, max_step: int
13
+ ) -> np.ndarray:
14
+ return np.linspace(max_step - 1, 0, num_substeps, endpoint=False).astype(int)[::-1]
15
+
16
+
17
+ class Discretization:
18
+ def __call__(self, n, do_append_zero=True, device="cpu", flip=False):
19
+ sigmas = self.get_sigmas(n, device=device)
20
+ sigmas = append_zero(sigmas) if do_append_zero else sigmas
21
+ return sigmas if not flip else torch.flip(sigmas, (0,))
22
+
23
+ @abstractmethod
24
+ def get_sigmas(self, n, device):
25
+ pass
26
+
27
+
28
+ class EDMDiscretization(Discretization):
29
+ def __init__(self, sigma_min=0.02, sigma_max=80.0, rho=7.0):
30
+ self.sigma_min = sigma_min
31
+ self.sigma_max = sigma_max
32
+ self.rho = rho
33
+
34
+ def get_sigmas(self, n, device="cpu"):
35
+ ramp = torch.linspace(0, 1, n, device=device)
36
+ min_inv_rho = self.sigma_min ** (1 / self.rho)
37
+ max_inv_rho = self.sigma_max ** (1 / self.rho)
38
+ sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** self.rho
39
+ return sigmas
40
+
41
+
42
+ class LegacyDDPMDiscretization(Discretization):
43
+ def __init__(
44
+ self,
45
+ linear_start=0.00085,
46
+ linear_end=0.0120,
47
+ num_timesteps=1000,
48
+ ):
49
+ super().__init__()
50
+ self.num_timesteps = num_timesteps
51
+ betas = make_beta_schedule(
52
+ "linear", num_timesteps, linear_start=linear_start, linear_end=linear_end
53
+ )
54
+ alphas = 1.0 - betas
55
+ self.alphas_cumprod = np.cumprod(alphas, axis=0)
56
+ self.to_torch = partial(torch.tensor, dtype=torch.float32)
57
+
58
+ def get_sigmas(self, n, device="cpu"):
59
+ if n < self.num_timesteps:
60
+ timesteps = generate_roughly_equally_spaced_steps(n, self.num_timesteps)
61
+ alphas_cumprod = self.alphas_cumprod[timesteps]
62
+ elif n == self.num_timesteps:
63
+ alphas_cumprod = self.alphas_cumprod
64
+ else:
65
+ raise ValueError
66
+
67
+ to_torch = partial(torch.tensor, dtype=torch.float32, device=device)
68
+ sigmas = to_torch((1 - alphas_cumprod) / alphas_cumprod) ** 0.5
69
+ return torch.flip(sigmas, (0,))
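
A sketch of the sigma schedules the samplers consume (the step count is illustrative):

import torch
from sgm.modules.diffusionmodules.discretizer import EDMDiscretization, LegacyDDPMDiscretization

edm = EDMDiscretization(sigma_min=0.02, sigma_max=80.0, rho=7.0)
sigmas = edm(10)                              # do_append_zero=True by default
print(sigmas.shape)                           # torch.Size([11]) -- 10 sigmas plus a trailing 0
print(float(sigmas[0]), float(sigmas[-1]))    # 80.0 ... 0.0

ddpm = LegacyDDPMDiscretization()
print(ddpm(10).shape)                         # 10 sigmas sub-sampled from the 1000-step table, plus 0
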
sgm/modules/diffusionmodules/guiders.py ADDED
@@ -0,0 +1,53 @@
1
+ from functools import partial
2
+
3
+ import torch
4
+
5
+ from ...util import default, instantiate_from_config
6
+
7
+
8
+ class VanillaCFG:
9
+ """
10
+ implements parallelized CFG
11
+ """
12
+
13
+ def __init__(self, scale, dyn_thresh_config=None):
14
+ scale_schedule = lambda scale, sigma: scale # independent of step
15
+ self.scale_schedule = partial(scale_schedule, scale)
16
+ self.dyn_thresh = instantiate_from_config(
17
+ default(
18
+ dyn_thresh_config,
19
+ {
20
+ "target": "sgm.modules.diffusionmodules.sampling_utils.NoDynamicThresholding"
21
+ },
22
+ )
23
+ )
24
+
25
+ def __call__(self, x, sigma):
26
+ x_u, x_c = x.chunk(2)
27
+ scale_value = self.scale_schedule(sigma)
28
+ x_pred = self.dyn_thresh(x_u, x_c, scale_value)
29
+ return x_pred
30
+
31
+ def prepare_inputs(self, x, s, c, uc):
32
+ c_out = dict()
33
+
34
+ for k in c:
35
+ if k in ["vector", "crossattn", "concat"]:
36
+ c_out[k] = torch.cat((uc[k], c[k]), 0)
37
+ else:
38
+ assert c[k] == uc[k]
39
+ c_out[k] = c[k]
40
+ return torch.cat([x] * 2), torch.cat([s] * 2), c_out
41
+
42
+
43
+ class IdentityGuider:
44
+ def __call__(self, x, sigma):
45
+ return x
46
+
47
+ def prepare_inputs(self, x, s, c, uc):
48
+ c_out = dict()
49
+
50
+ for k in c:
51
+ c_out[k] = c[k]
52
+
53
+ return x, s, c_out
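
A flow sketch for VanillaCFG as a sampler would use it (tensor sizes are illustrative; the stand-in model output replaces a real denoiser call):

import torch
from sgm.modules.diffusionmodules.guiders import VanillaCFG

guider = VanillaCFG(scale=7.5)

x = torch.randn(2, 4, 64, 64)                        # current latents
s = torch.full((2,), 14.6)                           # current sigma per sample
c = {"crossattn": torch.randn(2, 77, 2048)}          # conditional embeddings
uc = {"crossattn": torch.zeros(2, 77, 2048)}         # unconditional embeddings

# prepare_inputs doubles the batch so one network call covers both branches.
x_in, s_in, c_in = guider.prepare_inputs(x, s, c, uc)
print(x_in.shape, c_in["crossattn"].shape)           # (4, 4, 64, 64) and (4, 77, 2048)

# After the denoiser runs on the doubled batch, the guider recombines the halves
# as uncond + scale * (cond - uncond).
model_out = torch.randn(4, 4, 64, 64)
print(guider(model_out, s_in).shape)                 # torch.Size([2, 4, 64, 64])
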
sgm/modules/diffusionmodules/loss.py ADDED
@@ -0,0 +1,69 @@
1
+ from typing import List, Optional, Union
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ from omegaconf import ListConfig
6
+
7
+ from ...util import append_dims, instantiate_from_config
8
+ from ...modules.autoencoding.lpips.loss.lpips import LPIPS
9
+
10
+
11
+ class StandardDiffusionLoss(nn.Module):
12
+ def __init__(
13
+ self,
14
+ sigma_sampler_config,
15
+ type="l2",
16
+ offset_noise_level=0.0,
17
+ batch2model_keys: Optional[Union[str, List[str], ListConfig]] = None,
18
+ ):
19
+ super().__init__()
20
+
21
+ assert type in ["l2", "l1", "lpips"]
22
+
23
+ self.sigma_sampler = instantiate_from_config(sigma_sampler_config)
24
+
25
+ self.type = type
26
+ self.offset_noise_level = offset_noise_level
27
+
28
+ if type == "lpips":
29
+ self.lpips = LPIPS().eval()
30
+
31
+ if not batch2model_keys:
32
+ batch2model_keys = []
33
+
34
+ if isinstance(batch2model_keys, str):
35
+ batch2model_keys = [batch2model_keys]
36
+
37
+ self.batch2model_keys = set(batch2model_keys)
38
+
39
+ def __call__(self, network, denoiser, conditioner, input, batch):
40
+ cond = conditioner(batch)
41
+ additional_model_inputs = {
42
+ key: batch[key] for key in self.batch2model_keys.intersection(batch)
43
+ }
44
+
45
+ sigmas = self.sigma_sampler(input.shape[0]).to(input.device)
46
+ noise = torch.randn_like(input)
47
+ if self.offset_noise_level > 0.0:
48
+ noise = noise + self.offset_noise_level * append_dims(
49
+ torch.randn(input.shape[0], device=input.device), input.ndim
50
+ )
51
+ noised_input = input + noise * append_dims(sigmas, input.ndim)
52
+ model_output = denoiser(
53
+ network, noised_input, sigmas, cond, **additional_model_inputs
54
+ )
55
+ w = append_dims(denoiser.w(sigmas), input.ndim)
56
+ return self.get_loss(model_output, input, w)
57
+
58
+ def get_loss(self, model_output, target, w):
59
+ if self.type == "l2":
60
+ return torch.mean(
61
+ (w * (model_output - target) ** 2).reshape(target.shape[0], -1), 1
62
+ )
63
+ elif self.type == "l1":
64
+ return torch.mean(
65
+ (w * (model_output - target).abs()).reshape(target.shape[0], -1), 1
66
+ )
67
+ elif self.type == "lpips":
68
+ loss = self.lpips(model_output, target).reshape(-1)
69
+ return loss
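
For reference, the weighted-L2 branch of get_loss above averages over everything but the batch dimension; a hedged standalone rendering of that arithmetic:

import torch
from sgm.util import append_dims

pred = torch.randn(4, 4, 32, 32)
target = torch.randn(4, 4, 32, 32)
sigmas = torch.tensor([0.1, 0.5, 1.0, 5.0])
w = append_dims(1.0 / sigmas**2, pred.ndim)     # eps-style weighting, for illustration
loss = torch.mean((w * (pred - target) ** 2).reshape(target.shape[0], -1), 1)
print(loss.shape)                               # torch.Size([4]) -- one value per sample
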
sgm/modules/diffusionmodules/model.py ADDED
@@ -0,0 +1,743 @@
1
+ # pytorch_diffusion + derived encoder decoder
2
+ import math
3
+ from typing import Any, Callable, Optional
4
+
5
+ import numpy as np
6
+ import torch
7
+ import torch.nn as nn
8
+ from einops import rearrange
9
+ from packaging import version
10
+
11
+ try:
12
+ import xformers
13
+ import xformers.ops
14
+
15
+ XFORMERS_IS_AVAILABLE = True
16
+ except:
17
+ XFORMERS_IS_AVAILABLE = False
18
+ print("no module 'xformers'. Processing without...")
19
+
20
+ from ...modules.attention import LinearAttention, MemoryEfficientCrossAttention
21
+
22
+
23
+ def get_timestep_embedding(timesteps, embedding_dim):
24
+ """
25
+ Build sinusoidal timestep embeddings.
26
+ Ported from the Fairseq implementation referenced in
27
+ Denoising Diffusion Probabilistic Models.
28
+ This matches the implementation in tensor2tensor, but differs slightly
29
+ from the description in Section 3.5 of "Attention Is All You Need".
30
+ """
31
+ assert len(timesteps.shape) == 1
32
+
33
+ half_dim = embedding_dim // 2
34
+ emb = math.log(10000) / (half_dim - 1)
35
+ emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
36
+ emb = emb.to(device=timesteps.device)
37
+ emb = timesteps.float()[:, None] * emb[None, :]
38
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
39
+ if embedding_dim % 2 == 1: # zero pad
40
+ emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
41
+ return emb
42
+
43
+
44
+ def nonlinearity(x):
45
+ # swish
46
+ return x * torch.sigmoid(x)
47
+
48
+
49
+ def Normalize(in_channels, num_groups=32):
50
+ return torch.nn.GroupNorm(
51
+ num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True
52
+ )
53
+
54
+
55
+ class Upsample(nn.Module):
56
+ def __init__(self, in_channels, with_conv):
57
+ super().__init__()
58
+ self.with_conv = with_conv
59
+ if self.with_conv:
60
+ self.conv = torch.nn.Conv2d(
61
+ in_channels, in_channels, kernel_size=3, stride=1, padding=1
62
+ )
63
+
64
+ def forward(self, x):
65
+ x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
66
+ if self.with_conv:
67
+ x = self.conv(x)
68
+ return x
69
+
70
+
71
+ class Downsample(nn.Module):
72
+ def __init__(self, in_channels, with_conv):
73
+ super().__init__()
74
+ self.with_conv = with_conv
75
+ if self.with_conv:
76
+ # no asymmetric padding in torch conv, must do it ourselves
77
+ self.conv = torch.nn.Conv2d(
78
+ in_channels, in_channels, kernel_size=3, stride=2, padding=0
79
+ )
80
+
81
+ def forward(self, x):
82
+ if self.with_conv:
83
+ pad = (0, 1, 0, 1)
84
+ x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
85
+ x = self.conv(x)
86
+ else:
87
+ x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
88
+ return x
89
+
90
+
91
+ class ResnetBlock(nn.Module):
92
+ def __init__(
93
+ self,
94
+ *,
95
+ in_channels,
96
+ out_channels=None,
97
+ conv_shortcut=False,
98
+ dropout,
99
+ temb_channels=512,
100
+ ):
101
+ super().__init__()
102
+ self.in_channels = in_channels
103
+ out_channels = in_channels if out_channels is None else out_channels
104
+ self.out_channels = out_channels
105
+ self.use_conv_shortcut = conv_shortcut
106
+
107
+ self.norm1 = Normalize(in_channels)
108
+ self.conv1 = torch.nn.Conv2d(
109
+ in_channels, out_channels, kernel_size=3, stride=1, padding=1
110
+ )
111
+ if temb_channels > 0:
112
+ self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
113
+ self.norm2 = Normalize(out_channels)
114
+ self.dropout = torch.nn.Dropout(dropout)
115
+ self.conv2 = torch.nn.Conv2d(
116
+ out_channels, out_channels, kernel_size=3, stride=1, padding=1
117
+ )
118
+ if self.in_channels != self.out_channels:
119
+ if self.use_conv_shortcut:
120
+ self.conv_shortcut = torch.nn.Conv2d(
121
+ in_channels, out_channels, kernel_size=3, stride=1, padding=1
122
+ )
123
+ else:
124
+ self.nin_shortcut = torch.nn.Conv2d(
125
+ in_channels, out_channels, kernel_size=1, stride=1, padding=0
126
+ )
127
+
128
+ def forward(self, x, temb):
129
+ h = x
130
+ h = self.norm1(h)
131
+ h = nonlinearity(h)
132
+ h = self.conv1(h)
133
+
134
+ if temb is not None:
135
+ h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]
136
+
137
+ h = self.norm2(h)
138
+ h = nonlinearity(h)
139
+ h = self.dropout(h)
140
+ h = self.conv2(h)
141
+
142
+ if self.in_channels != self.out_channels:
143
+ if self.use_conv_shortcut:
144
+ x = self.conv_shortcut(x)
145
+ else:
146
+ x = self.nin_shortcut(x)
147
+
148
+ return x + h
149
+
150
+
151
+ class LinAttnBlock(LinearAttention):
152
+ """to match AttnBlock usage"""
153
+
154
+ def __init__(self, in_channels):
155
+ super().__init__(dim=in_channels, heads=1, dim_head=in_channels)
156
+
157
+
158
+ class AttnBlock(nn.Module):
159
+ def __init__(self, in_channels):
160
+ super().__init__()
161
+ self.in_channels = in_channels
162
+
163
+ self.norm = Normalize(in_channels)
164
+ self.q = torch.nn.Conv2d(
165
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
166
+ )
167
+ self.k = torch.nn.Conv2d(
168
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
169
+ )
170
+ self.v = torch.nn.Conv2d(
171
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
172
+ )
173
+ self.proj_out = torch.nn.Conv2d(
174
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
175
+ )
176
+
177
+ def attention(self, h_: torch.Tensor) -> torch.Tensor:
178
+ h_ = self.norm(h_)
179
+ q = self.q(h_)
180
+ k = self.k(h_)
181
+ v = self.v(h_)
182
+
183
+ b, c, h, w = q.shape
184
+ q, k, v = map(
185
+ lambda x: rearrange(x, "b c h w -> b 1 (h w) c").contiguous(), (q, k, v)
186
+ )
187
+ h_ = torch.nn.functional.scaled_dot_product_attention(
188
+ q, k, v
189
+ ) # scale is dim ** -0.5 per default
190
+ # compute attention
191
+
192
+ return rearrange(h_, "b 1 (h w) c -> b c h w", h=h, w=w, c=c, b=b)
193
+
194
+ def forward(self, x, **kwargs):
195
+ h_ = x
196
+ h_ = self.attention(h_)
197
+ h_ = self.proj_out(h_)
198
+ return x + h_
199
+
200
+
201
+ class MemoryEfficientAttnBlock(nn.Module):
202
+ """
203
+ Uses xformers efficient implementation,
204
+ see https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
205
+ Note: this is a single-head self-attention operation
206
+ """
207
+
208
+ #
209
+ def __init__(self, in_channels):
210
+ super().__init__()
211
+ self.in_channels = in_channels
212
+
213
+ self.norm = Normalize(in_channels)
214
+ self.q = torch.nn.Conv2d(
215
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
216
+ )
217
+ self.k = torch.nn.Conv2d(
218
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
219
+ )
220
+ self.v = torch.nn.Conv2d(
221
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
222
+ )
223
+ self.proj_out = torch.nn.Conv2d(
224
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
225
+ )
226
+ self.attention_op: Optional[Any] = None
227
+
228
+ def attention(self, h_: torch.Tensor) -> torch.Tensor:
229
+ h_ = self.norm(h_)
230
+ q = self.q(h_)
231
+ k = self.k(h_)
232
+ v = self.v(h_)
233
+
234
+ # compute attention
235
+ B, C, H, W = q.shape
236
+ q, k, v = map(lambda x: rearrange(x, "b c h w -> b (h w) c"), (q, k, v))
237
+
238
+ q, k, v = map(
239
+ lambda t: t.unsqueeze(3)
240
+ .reshape(B, t.shape[1], 1, C)
241
+ .permute(0, 2, 1, 3)
242
+ .reshape(B * 1, t.shape[1], C)
243
+ .contiguous(),
244
+ (q, k, v),
245
+ )
246
+ out = xformers.ops.memory_efficient_attention(
247
+ q, k, v, attn_bias=None, op=self.attention_op
248
+ )
249
+
250
+ out = (
251
+ out.unsqueeze(0)
252
+ .reshape(B, 1, out.shape[1], C)
253
+ .permute(0, 2, 1, 3)
254
+ .reshape(B, out.shape[1], C)
255
+ )
256
+ return rearrange(out, "b (h w) c -> b c h w", b=B, h=H, w=W, c=C)
257
+
258
+ def forward(self, x, **kwargs):
259
+ h_ = x
260
+ h_ = self.attention(h_)
261
+ h_ = self.proj_out(h_)
262
+ return x + h_
263
+
264
+
265
+ class MemoryEfficientCrossAttentionWrapper(MemoryEfficientCrossAttention):
266
+ def forward(self, x, context=None, mask=None, **unused_kwargs):
267
+ b, c, h, w = x.shape
268
+ x = rearrange(x, "b c h w -> b (h w) c")
269
+ out = super().forward(x, context=context, mask=mask)
270
+ out = rearrange(out, "b (h w) c -> b c h w", h=h, w=w, c=c)
271
+ return rearrange(x, "b (h w) c -> b c h w", h=h, w=w, c=c) + out
272
+
273
+
274
+ def make_attn(in_channels, attn_type="vanilla", attn_kwargs=None):
275
+ assert attn_type in [
276
+ "vanilla",
277
+ "vanilla-xformers",
278
+ "memory-efficient-cross-attn",
279
+ "linear",
280
+ "none",
281
+ ], f"attn_type {attn_type} unknown"
282
+ if (
283
+ version.parse(torch.__version__) < version.parse("2.0.0")
284
+ and attn_type != "none"
285
+ ):
286
+ assert XFORMERS_IS_AVAILABLE, (
287
+ f"We do not support vanilla attention in {torch.__version__} anymore, "
288
+ f"as it is too expensive. Please install xformers via e.g. 'pip install xformers==0.0.16'"
289
+ )
290
+ attn_type = "vanilla-xformers"
291
+ print(f"making attention of type '{attn_type}' with {in_channels} in_channels")
292
+ if attn_type == "vanilla":
293
+ assert attn_kwargs is None
294
+ return AttnBlock(in_channels)
295
+ elif attn_type == "vanilla-xformers":
296
+ print(f"building MemoryEfficientAttnBlock with {in_channels} in_channels...")
297
+ return MemoryEfficientAttnBlock(in_channels)
298
+ elif attn_type == "memory-efficient-cross-attn":
299
+ attn_kwargs["query_dim"] = in_channels
300
+ return MemoryEfficientCrossAttentionWrapper(**attn_kwargs)
301
+ elif attn_type == "none":
302
+ return nn.Identity(in_channels)
303
+ else:
304
+ return LinAttnBlock(in_channels)
305
+
306
+
307
+ class Model(nn.Module):
308
+ def __init__(
309
+ self,
310
+ *,
311
+ ch,
312
+ out_ch,
313
+ ch_mult=(1, 2, 4, 8),
314
+ num_res_blocks,
315
+ attn_resolutions,
316
+ dropout=0.0,
317
+ resamp_with_conv=True,
318
+ in_channels,
319
+ resolution,
320
+ use_timestep=True,
321
+ use_linear_attn=False,
322
+ attn_type="vanilla",
323
+ ):
324
+ super().__init__()
325
+ if use_linear_attn:
326
+ attn_type = "linear"
327
+ self.ch = ch
328
+ self.temb_ch = self.ch * 4
329
+ self.num_resolutions = len(ch_mult)
330
+ self.num_res_blocks = num_res_blocks
331
+ self.resolution = resolution
332
+ self.in_channels = in_channels
333
+
334
+ self.use_timestep = use_timestep
335
+ if self.use_timestep:
336
+ # timestep embedding
337
+ self.temb = nn.Module()
338
+ self.temb.dense = nn.ModuleList(
339
+ [
340
+ torch.nn.Linear(self.ch, self.temb_ch),
341
+ torch.nn.Linear(self.temb_ch, self.temb_ch),
342
+ ]
343
+ )
344
+
345
+ # downsampling
346
+ self.conv_in = torch.nn.Conv2d(
347
+ in_channels, self.ch, kernel_size=3, stride=1, padding=1
348
+ )
349
+
350
+ curr_res = resolution
351
+ in_ch_mult = (1,) + tuple(ch_mult)
352
+ self.down = nn.ModuleList()
353
+ for i_level in range(self.num_resolutions):
354
+ block = nn.ModuleList()
355
+ attn = nn.ModuleList()
356
+ block_in = ch * in_ch_mult[i_level]
357
+ block_out = ch * ch_mult[i_level]
358
+ for i_block in range(self.num_res_blocks):
359
+ block.append(
360
+ ResnetBlock(
361
+ in_channels=block_in,
362
+ out_channels=block_out,
363
+ temb_channels=self.temb_ch,
364
+ dropout=dropout,
365
+ )
366
+ )
367
+ block_in = block_out
368
+ if curr_res in attn_resolutions:
369
+ attn.append(make_attn(block_in, attn_type=attn_type))
370
+ down = nn.Module()
371
+ down.block = block
372
+ down.attn = attn
373
+ if i_level != self.num_resolutions - 1:
374
+ down.downsample = Downsample(block_in, resamp_with_conv)
375
+ curr_res = curr_res // 2
376
+ self.down.append(down)
377
+
378
+ # middle
379
+ self.mid = nn.Module()
380
+ self.mid.block_1 = ResnetBlock(
381
+ in_channels=block_in,
382
+ out_channels=block_in,
383
+ temb_channels=self.temb_ch,
384
+ dropout=dropout,
385
+ )
386
+ self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
387
+ self.mid.block_2 = ResnetBlock(
388
+ in_channels=block_in,
389
+ out_channels=block_in,
390
+ temb_channels=self.temb_ch,
391
+ dropout=dropout,
392
+ )
393
+
394
+ # upsampling
395
+ self.up = nn.ModuleList()
396
+ for i_level in reversed(range(self.num_resolutions)):
397
+ block = nn.ModuleList()
398
+ attn = nn.ModuleList()
399
+ block_out = ch * ch_mult[i_level]
400
+ skip_in = ch * ch_mult[i_level]
401
+ for i_block in range(self.num_res_blocks + 1):
402
+ if i_block == self.num_res_blocks:
403
+ skip_in = ch * in_ch_mult[i_level]
404
+ block.append(
405
+ ResnetBlock(
406
+ in_channels=block_in + skip_in,
407
+ out_channels=block_out,
408
+ temb_channels=self.temb_ch,
409
+ dropout=dropout,
410
+ )
411
+ )
412
+ block_in = block_out
413
+ if curr_res in attn_resolutions:
414
+ attn.append(make_attn(block_in, attn_type=attn_type))
415
+ up = nn.Module()
416
+ up.block = block
417
+ up.attn = attn
418
+ if i_level != 0:
419
+ up.upsample = Upsample(block_in, resamp_with_conv)
420
+ curr_res = curr_res * 2
421
+ self.up.insert(0, up) # prepend to get consistent order
422
+
423
+ # end
424
+ self.norm_out = Normalize(block_in)
425
+ self.conv_out = torch.nn.Conv2d(
426
+ block_in, out_ch, kernel_size=3, stride=1, padding=1
427
+ )
428
+
429
+ def forward(self, x, t=None, context=None):
430
+ # assert x.shape[2] == x.shape[3] == self.resolution
431
+ if context is not None:
432
+ # assume aligned context, cat along channel axis
433
+ x = torch.cat((x, context), dim=1)
434
+ if self.use_timestep:
435
+ # timestep embedding
436
+ assert t is not None
437
+ temb = get_timestep_embedding(t, self.ch)
438
+ temb = self.temb.dense[0](temb)
439
+ temb = nonlinearity(temb)
440
+ temb = self.temb.dense[1](temb)
441
+ else:
442
+ temb = None
443
+
444
+ # downsampling
445
+ hs = [self.conv_in(x)]
446
+ for i_level in range(self.num_resolutions):
447
+ for i_block in range(self.num_res_blocks):
448
+ h = self.down[i_level].block[i_block](hs[-1], temb)
449
+ if len(self.down[i_level].attn) > 0:
450
+ h = self.down[i_level].attn[i_block](h)
451
+ hs.append(h)
452
+ if i_level != self.num_resolutions - 1:
453
+ hs.append(self.down[i_level].downsample(hs[-1]))
454
+
455
+ # middle
456
+ h = hs[-1]
457
+ h = self.mid.block_1(h, temb)
458
+ h = self.mid.attn_1(h)
459
+ h = self.mid.block_2(h, temb)
460
+
461
+ # upsampling
462
+ for i_level in reversed(range(self.num_resolutions)):
463
+ for i_block in range(self.num_res_blocks + 1):
464
+ h = self.up[i_level].block[i_block](
465
+ torch.cat([h, hs.pop()], dim=1), temb
466
+ )
467
+ if len(self.up[i_level].attn) > 0:
468
+ h = self.up[i_level].attn[i_block](h)
469
+ if i_level != 0:
470
+ h = self.up[i_level].upsample(h)
471
+
472
+ # end
473
+ h = self.norm_out(h)
474
+ h = nonlinearity(h)
475
+ h = self.conv_out(h)
476
+ return h
477
+
478
+ def get_last_layer(self):
479
+ return self.conv_out.weight
480
+
481
+
482
+ class Encoder(nn.Module):
483
+ def __init__(
484
+ self,
485
+ *,
486
+ ch,
487
+ out_ch,
488
+ ch_mult=(1, 2, 4, 8),
489
+ num_res_blocks,
490
+ attn_resolutions,
491
+ dropout=0.0,
492
+ resamp_with_conv=True,
493
+ in_channels,
494
+ resolution,
495
+ z_channels,
496
+ double_z=True,
497
+ use_linear_attn=False,
498
+ attn_type="vanilla",
499
+ **ignore_kwargs,
500
+ ):
501
+ super().__init__()
502
+ if use_linear_attn:
503
+ attn_type = "linear"
504
+ self.ch = ch
505
+ self.temb_ch = 0
506
+ self.num_resolutions = len(ch_mult)
507
+ self.num_res_blocks = num_res_blocks
508
+ self.resolution = resolution
509
+ self.in_channels = in_channels
510
+
511
+ # downsampling
512
+ self.conv_in = torch.nn.Conv2d(
513
+ in_channels, self.ch, kernel_size=3, stride=1, padding=1
514
+ )
515
+
516
+ curr_res = resolution
517
+ in_ch_mult = (1,) + tuple(ch_mult)
518
+ self.in_ch_mult = in_ch_mult
519
+ self.down = nn.ModuleList()
520
+ for i_level in range(self.num_resolutions):
521
+ block = nn.ModuleList()
522
+ attn = nn.ModuleList()
523
+ block_in = ch * in_ch_mult[i_level]
524
+ block_out = ch * ch_mult[i_level]
525
+ for i_block in range(self.num_res_blocks):
526
+ block.append(
527
+ ResnetBlock(
528
+ in_channels=block_in,
529
+ out_channels=block_out,
530
+ temb_channels=self.temb_ch,
531
+ dropout=dropout,
532
+ )
533
+ )
534
+ block_in = block_out
535
+ if curr_res in attn_resolutions:
536
+ attn.append(make_attn(block_in, attn_type=attn_type))
537
+ down = nn.Module()
538
+ down.block = block
539
+ down.attn = attn
540
+ if i_level != self.num_resolutions - 1:
541
+ down.downsample = Downsample(block_in, resamp_with_conv)
542
+ curr_res = curr_res // 2
543
+ self.down.append(down)
544
+
545
+ # middle
546
+ self.mid = nn.Module()
547
+ self.mid.block_1 = ResnetBlock(
548
+ in_channels=block_in,
549
+ out_channels=block_in,
550
+ temb_channels=self.temb_ch,
551
+ dropout=dropout,
552
+ )
553
+ self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
554
+ self.mid.block_2 = ResnetBlock(
555
+ in_channels=block_in,
556
+ out_channels=block_in,
557
+ temb_channels=self.temb_ch,
558
+ dropout=dropout,
559
+ )
560
+
561
+ # end
562
+ self.norm_out = Normalize(block_in)
563
+ self.conv_out = torch.nn.Conv2d(
564
+ block_in,
565
+ 2 * z_channels if double_z else z_channels,
566
+ kernel_size=3,
567
+ stride=1,
568
+ padding=1,
569
+ )
570
+
571
+ def forward(self, x):
572
+ # timestep embedding
573
+ temb = None
574
+
575
+ # downsampling
576
+ hs = [self.conv_in(x)]
577
+ for i_level in range(self.num_resolutions):
578
+ for i_block in range(self.num_res_blocks):
579
+ h = self.down[i_level].block[i_block](hs[-1], temb)
580
+ if len(self.down[i_level].attn) > 0:
581
+ h = self.down[i_level].attn[i_block](h)
582
+ hs.append(h)
583
+ if i_level != self.num_resolutions - 1:
584
+ hs.append(self.down[i_level].downsample(hs[-1]))
585
+
586
+ # middle
587
+ h = hs[-1]
588
+ h = self.mid.block_1(h, temb)
589
+ h = self.mid.attn_1(h)
590
+ h = self.mid.block_2(h, temb)
591
+
592
+ # end
593
+ h = self.norm_out(h)
594
+ h = nonlinearity(h)
595
+ h = self.conv_out(h)
596
+ return h
597
+
598
+
599
+ class Decoder(nn.Module):
600
+ def __init__(
601
+ self,
602
+ *,
603
+ ch,
604
+ out_ch,
605
+ ch_mult=(1, 2, 4, 8),
606
+ num_res_blocks,
607
+ attn_resolutions,
608
+ dropout=0.0,
609
+ resamp_with_conv=True,
610
+ in_channels,
611
+ resolution,
612
+ z_channels,
613
+ give_pre_end=False,
614
+ tanh_out=False,
615
+ use_linear_attn=False,
616
+ attn_type="vanilla",
617
+ **ignorekwargs,
618
+ ):
619
+ super().__init__()
620
+ if use_linear_attn:
621
+ attn_type = "linear"
622
+ self.ch = ch
623
+ self.temb_ch = 0
624
+ self.num_resolutions = len(ch_mult)
625
+ self.num_res_blocks = num_res_blocks
626
+ self.resolution = resolution
627
+ self.in_channels = in_channels
628
+ self.give_pre_end = give_pre_end
629
+ self.tanh_out = tanh_out
630
+
631
+ # compute in_ch_mult, block_in and curr_res at lowest res
632
+ in_ch_mult = (1,) + tuple(ch_mult)
633
+ block_in = ch * ch_mult[self.num_resolutions - 1]
634
+ curr_res = resolution // 2 ** (self.num_resolutions - 1)
635
+ self.z_shape = (1, z_channels, curr_res, curr_res)
636
+ print(
637
+ "Working with z of shape {} = {} dimensions.".format(
638
+ self.z_shape, np.prod(self.z_shape)
639
+ )
640
+ )
641
+
642
+ make_attn_cls = self._make_attn()
643
+ make_resblock_cls = self._make_resblock()
644
+ make_conv_cls = self._make_conv()
645
+ # z to block_in
646
+ self.conv_in = torch.nn.Conv2d(
647
+ z_channels, block_in, kernel_size=3, stride=1, padding=1
648
+ )
649
+
650
+ # middle
651
+ self.mid = nn.Module()
652
+ self.mid.block_1 = make_resblock_cls(
653
+ in_channels=block_in,
654
+ out_channels=block_in,
655
+ temb_channels=self.temb_ch,
656
+ dropout=dropout,
657
+ )
658
+ self.mid.attn_1 = make_attn_cls(block_in, attn_type=attn_type)
659
+ self.mid.block_2 = make_resblock_cls(
660
+ in_channels=block_in,
661
+ out_channels=block_in,
662
+ temb_channels=self.temb_ch,
663
+ dropout=dropout,
664
+ )
665
+
666
+ # upsampling
667
+ self.up = nn.ModuleList()
668
+ for i_level in reversed(range(self.num_resolutions)):
669
+ block = nn.ModuleList()
670
+ attn = nn.ModuleList()
671
+ block_out = ch * ch_mult[i_level]
672
+ for i_block in range(self.num_res_blocks + 1):
673
+ block.append(
674
+ make_resblock_cls(
675
+ in_channels=block_in,
676
+ out_channels=block_out,
677
+ temb_channels=self.temb_ch,
678
+ dropout=dropout,
679
+ )
680
+ )
681
+ block_in = block_out
682
+ if curr_res in attn_resolutions:
683
+ attn.append(make_attn_cls(block_in, attn_type=attn_type))
684
+ up = nn.Module()
685
+ up.block = block
686
+ up.attn = attn
687
+ if i_level != 0:
688
+ up.upsample = Upsample(block_in, resamp_with_conv)
689
+ curr_res = curr_res * 2
690
+ self.up.insert(0, up) # prepend to get consistent order
691
+
692
+ # end
693
+ self.norm_out = Normalize(block_in)
694
+ self.conv_out = make_conv_cls(
695
+ block_in, out_ch, kernel_size=3, stride=1, padding=1
696
+ )
697
+
698
+ def _make_attn(self) -> Callable:
699
+ return make_attn
700
+
701
+ def _make_resblock(self) -> Callable:
702
+ return ResnetBlock
703
+
704
+ def _make_conv(self) -> Callable:
705
+ return torch.nn.Conv2d
706
+
707
+ def get_last_layer(self, **kwargs):
708
+ return self.conv_out.weight
709
+
710
+ def forward(self, z, **kwargs):
711
+ # assert z.shape[1:] == self.z_shape[1:]
712
+ self.last_z_shape = z.shape
713
+
714
+ # timestep embedding
715
+ temb = None
716
+
717
+ # z to block_in
718
+ h = self.conv_in(z)
719
+
720
+ # middle
721
+ h = self.mid.block_1(h, temb, **kwargs)
722
+ h = self.mid.attn_1(h, **kwargs)
723
+ h = self.mid.block_2(h, temb, **kwargs)
724
+
725
+ # upsampling
726
+ for i_level in reversed(range(self.num_resolutions)):
727
+ for i_block in range(self.num_res_blocks + 1):
728
+ h = self.up[i_level].block[i_block](h, temb, **kwargs)
729
+ if len(self.up[i_level].attn) > 0:
730
+ h = self.up[i_level].attn[i_block](h, **kwargs)
731
+ if i_level != 0:
732
+ h = self.up[i_level].upsample(h)
733
+
734
+ # end
735
+ if self.give_pre_end:
736
+ return h
737
+
738
+ h = self.norm_out(h)
739
+ h = nonlinearity(h)
740
+ h = self.conv_out(h, **kwargs)
741
+ if self.tanh_out:
742
+ h = torch.tanh(h)
743
+ return h
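
A shape round-trip through the Encoder/Decoder pair defined in this file (settings are illustrative, not the shipped yaml configs; with torch >= 2.0 the vanilla attention path is used, otherwise xformers is required):

import torch
from sgm.modules.diffusionmodules.model import Decoder, Encoder

enc = Encoder(ch=128, out_ch=3, ch_mult=(1, 2, 4, 4), num_res_blocks=2,
              attn_resolutions=[], dropout=0.0, in_channels=3,
              resolution=256, z_channels=4, double_z=True)
dec = Decoder(ch=128, out_ch=3, ch_mult=(1, 2, 4, 4), num_res_blocks=2,
              attn_resolutions=[], dropout=0.0, in_channels=3,
              resolution=256, z_channels=4)

x = torch.randn(1, 3, 256, 256)
moments = enc(x)              # (1, 8, 32, 32): mean and logvar stacked along channels
z = moments[:, :4]            # take the mean half as a stand-in for sampling
print(dec(z).shape)           # torch.Size([1, 3, 256, 256])
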
sgm/modules/diffusionmodules/openaimodel.py ADDED
@@ -0,0 +1,1262 @@
1
+ import math
2
+ from abc import abstractmethod
3
+ from functools import partial
4
+ from typing import Iterable
5
+
6
+ import numpy as np
7
+ import torch as th
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+ from einops import rearrange
11
+
12
+ from ...modules.attention import SpatialTransformer
13
+ from ...modules.diffusionmodules.util import (
14
+ avg_pool_nd,
15
+ checkpoint,
16
+ conv_nd,
17
+ linear,
18
+ normalization,
19
+ timestep_embedding,
20
+ zero_module,
21
+ )
22
+ from ...util import default, exists
23
+
24
+
25
+ # dummy replace
26
+ def convert_module_to_f16(x):
27
+ pass
28
+
29
+
30
+ def convert_module_to_f32(x):
31
+ pass
32
+
33
+
34
+ ## go
35
+ class AttentionPool2d(nn.Module):
36
+ """
37
+ Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py
38
+ """
39
+
40
+ def __init__(
41
+ self,
42
+ spacial_dim: int,
43
+ embed_dim: int,
44
+ num_heads_channels: int,
45
+ output_dim: int = None,
46
+ ):
47
+ super().__init__()
48
+ self.positional_embedding = nn.Parameter(
49
+ th.randn(embed_dim, spacial_dim**2 + 1) / embed_dim**0.5
50
+ )
51
+ self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1)
52
+ self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1)
53
+ self.num_heads = embed_dim // num_heads_channels
54
+ self.attention = QKVAttention(self.num_heads)
55
+
56
+ def forward(self, x):
57
+ b, c, *_spatial = x.shape
58
+ x = x.reshape(b, c, -1) # NC(HW)
59
+ x = th.cat([x.mean(dim=-1, keepdim=True), x], dim=-1) # NC(HW+1)
60
+ x = x + self.positional_embedding[None, :, :].to(x.dtype) # NC(HW+1)
61
+ x = self.qkv_proj(x)
62
+ x = self.attention(x)
63
+ x = self.c_proj(x)
64
+ return x[:, :, 0]
65
+
66
+
67
+ class TimestepBlock(nn.Module):
68
+ """
69
+ Any module where forward() takes timestep embeddings as a second argument.
70
+ """
71
+
72
+ @abstractmethod
73
+ def forward(self, x, emb):
74
+ """
75
+ Apply the module to `x` given `emb` timestep embeddings.
76
+ """
77
+
78
+
79
+ class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
80
+ """
81
+ A sequential module that passes timestep embeddings to the children that
82
+ support it as an extra input.
83
+ """
84
+
85
+ def forward(
86
+ self,
87
+ x,
88
+ emb,
89
+ context=None,
90
+ skip_time_mix=False,
91
+ time_context=None,
92
+ num_video_frames=None,
93
+ time_context_cat=None,
94
+ use_crossframe_attention_in_spatial_layers=False,
95
+ ):
96
+ for layer in self:
97
+ if isinstance(layer, TimestepBlock):
98
+ x = layer(x, emb)
99
+ elif isinstance(layer, SpatialTransformer):
100
+ x = layer(x, context)
101
+ else:
102
+ x = layer(x)
103
+ return x
104
+
105
+
106
+ class Upsample(nn.Module):
107
+ """
108
+ An upsampling layer with an optional convolution.
109
+ :param channels: channels in the inputs and outputs.
110
+ :param use_conv: a bool determining if a convolution is applied.
111
+ :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
112
+ upsampling occurs in the inner-two dimensions.
113
+ """
114
+
115
+ def __init__(
116
+ self, channels, use_conv, dims=2, out_channels=None, padding=1, third_up=False
117
+ ):
118
+ super().__init__()
119
+ self.channels = channels
120
+ self.out_channels = out_channels or channels
121
+ self.use_conv = use_conv
122
+ self.dims = dims
123
+ self.third_up = third_up
124
+ if use_conv:
125
+ self.conv = conv_nd(
126
+ dims, self.channels, self.out_channels, 3, padding=padding
127
+ )
128
+
129
+ def forward(self, x):
130
+ assert x.shape[1] == self.channels
131
+ if self.dims == 3:
132
+ t_factor = 1 if not self.third_up else 2
133
+ x = F.interpolate(
134
+ x,
135
+ (t_factor * x.shape[2], x.shape[3] * 2, x.shape[4] * 2),
136
+ mode="nearest",
137
+ )
138
+ else:
139
+ x = F.interpolate(x, scale_factor=2, mode="nearest")
140
+ if self.use_conv:
141
+ x = self.conv(x)
142
+ return x
143
+
144
+
145
+ class TransposedUpsample(nn.Module):
146
+ "Learned 2x upsampling without padding"
147
+
148
+ def __init__(self, channels, out_channels=None, ks=5):
149
+ super().__init__()
150
+ self.channels = channels
151
+ self.out_channels = out_channels or channels
152
+
153
+ self.up = nn.ConvTranspose2d(
154
+ self.channels, self.out_channels, kernel_size=ks, stride=2
155
+ )
156
+
157
+ def forward(self, x):
158
+ return self.up(x)
159
+
160
+
161
+ class Downsample(nn.Module):
162
+ """
163
+ A downsampling layer with an optional convolution.
164
+ :param channels: channels in the inputs and outputs.
165
+ :param use_conv: a bool determining if a convolution is applied.
166
+ :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
167
+ downsampling occurs in the inner-two dimensions.
168
+ """
169
+
170
+ def __init__(
171
+ self, channels, use_conv, dims=2, out_channels=None, padding=1, third_down=False
172
+ ):
173
+ super().__init__()
174
+ self.channels = channels
175
+ self.out_channels = out_channels or channels
176
+ self.use_conv = use_conv
177
+ self.dims = dims
178
+ stride = 2 if dims != 3 else ((1, 2, 2) if not third_down else (2, 2, 2))
179
+ if use_conv:
180
+ print(f"Building a Downsample layer with {dims} dims.")
181
+ print(
182
+ f" --> settings are: \n in-chn: {self.channels}, out-chn: {self.out_channels}, "
183
+ f"kernel-size: 3, stride: {stride}, padding: {padding}"
184
+ )
185
+ if dims == 3:
186
+ print(f" --> Downsampling third axis (time): {third_down}")
187
+ self.op = conv_nd(
188
+ dims,
189
+ self.channels,
190
+ self.out_channels,
191
+ 3,
192
+ stride=stride,
193
+ padding=padding,
194
+ )
195
+ else:
196
+ assert self.channels == self.out_channels
197
+ self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
198
+
199
+ def forward(self, x):
200
+ assert x.shape[1] == self.channels
201
+ return self.op(x)
202
+
203
+
204
+ class ResBlock(TimestepBlock):
205
+ """
206
+ A residual block that can optionally change the number of channels.
207
+ :param channels: the number of input channels.
208
+ :param emb_channels: the number of timestep embedding channels.
209
+ :param dropout: the rate of dropout.
210
+ :param out_channels: if specified, the number of out channels.
211
+ :param use_conv: if True and out_channels is specified, use a spatial
212
+ convolution instead of a smaller 1x1 convolution to change the
213
+ channels in the skip connection.
214
+ :param dims: determines if the signal is 1D, 2D, or 3D.
215
+ :param use_checkpoint: if True, use gradient checkpointing on this module.
216
+ :param up: if True, use this block for upsampling.
217
+ :param down: if True, use this block for downsampling.
218
+ """
219
+
220
+ def __init__(
221
+ self,
222
+ channels,
223
+ emb_channels,
224
+ dropout,
225
+ out_channels=None,
226
+ use_conv=False,
227
+ use_scale_shift_norm=False,
228
+ dims=2,
229
+ use_checkpoint=False,
230
+ up=False,
231
+ down=False,
232
+ kernel_size=3,
233
+ exchange_temb_dims=False,
234
+ skip_t_emb=False,
235
+ ):
236
+ super().__init__()
237
+ self.channels = channels
238
+ self.emb_channels = emb_channels
239
+ self.dropout = dropout
240
+ self.out_channels = out_channels or channels
241
+ self.use_conv = use_conv
242
+ self.use_checkpoint = use_checkpoint
243
+ self.use_scale_shift_norm = use_scale_shift_norm
244
+ self.exchange_temb_dims = exchange_temb_dims
245
+
246
+ if isinstance(kernel_size, Iterable):
247
+ padding = [k // 2 for k in kernel_size]
248
+ else:
249
+ padding = kernel_size // 2
250
+
251
+ self.in_layers = nn.Sequential(
252
+ normalization(channels),
253
+ nn.SiLU(),
254
+ conv_nd(dims, channels, self.out_channels, kernel_size, padding=padding),
255
+ )
256
+
257
+ self.updown = up or down
258
+
259
+ if up:
260
+ self.h_upd = Upsample(channels, False, dims)
261
+ self.x_upd = Upsample(channels, False, dims)
262
+ elif down:
263
+ self.h_upd = Downsample(channels, False, dims)
264
+ self.x_upd = Downsample(channels, False, dims)
265
+ else:
266
+ self.h_upd = self.x_upd = nn.Identity()
267
+
268
+ self.skip_t_emb = skip_t_emb
269
+ self.emb_out_channels = (
270
+ 2 * self.out_channels if use_scale_shift_norm else self.out_channels
271
+ )
272
+ if self.skip_t_emb:
273
+ print(f"Skipping timestep embedding in {self.__class__.__name__}")
274
+ assert not self.use_scale_shift_norm
275
+ self.emb_layers = None
276
+ self.exchange_temb_dims = False
277
+ else:
278
+ self.emb_layers = nn.Sequential(
279
+ nn.SiLU(),
280
+ linear(
281
+ emb_channels,
282
+ self.emb_out_channels,
283
+ ),
284
+ )
285
+
286
+ self.out_layers = nn.Sequential(
287
+ normalization(self.out_channels),
288
+ nn.SiLU(),
289
+ nn.Dropout(p=dropout),
290
+ zero_module(
291
+ conv_nd(
292
+ dims,
293
+ self.out_channels,
294
+ self.out_channels,
295
+ kernel_size,
296
+ padding=padding,
297
+ )
298
+ ),
299
+ )
300
+
301
+ if self.out_channels == channels:
302
+ self.skip_connection = nn.Identity()
303
+ elif use_conv:
304
+ self.skip_connection = conv_nd(
305
+ dims, channels, self.out_channels, kernel_size, padding=padding
306
+ )
307
+ else:
308
+ self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
309
+
310
+ def forward(self, x, emb):
311
+ """
312
+ Apply the block to a Tensor, conditioned on a timestep embedding.
313
+ :param x: an [N x C x ...] Tensor of features.
314
+ :param emb: an [N x emb_channels] Tensor of timestep embeddings.
315
+ :return: an [N x C x ...] Tensor of outputs.
316
+ """
317
+ return checkpoint(
318
+ self._forward, (x, emb), self.parameters(), self.use_checkpoint
319
+ )
320
+
321
+ def _forward(self, x, emb):
322
+ if self.updown:
323
+ in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
324
+ h = in_rest(x)
325
+ h = self.h_upd(h)
326
+ x = self.x_upd(x)
327
+ h = in_conv(h)
328
+ else:
329
+ h = self.in_layers(x)
330
+
331
+ if self.skip_t_emb:
332
+ emb_out = th.zeros_like(h)
333
+ else:
334
+ emb_out = self.emb_layers(emb).type(h.dtype)
335
+ while len(emb_out.shape) < len(h.shape):
336
+ emb_out = emb_out[..., None]
337
+ if self.use_scale_shift_norm:
338
+ out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
339
+ scale, shift = th.chunk(emb_out, 2, dim=1)
340
+ h = out_norm(h) * (1 + scale) + shift
341
+ h = out_rest(h)
342
+ else:
343
+ if self.exchange_temb_dims:
344
+ emb_out = rearrange(emb_out, "b t c ... -> b c t ...")
345
+ h = h + emb_out
346
+ h = self.out_layers(h)
347
+ return self.skip_connection(x) + h
348
+
349
+
350
+ class AttentionBlock(nn.Module):
351
+ """
352
+ An attention block that allows spatial positions to attend to each other.
353
+ Originally ported from here, but adapted to the N-d case.
354
+ https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
355
+ """
356
+
357
+ def __init__(
358
+ self,
359
+ channels,
360
+ num_heads=1,
361
+ num_head_channels=-1,
362
+ use_checkpoint=False,
363
+ use_new_attention_order=False,
364
+ ):
365
+ super().__init__()
366
+ self.channels = channels
367
+ if num_head_channels == -1:
368
+ self.num_heads = num_heads
369
+ else:
370
+ assert (
371
+ channels % num_head_channels == 0
372
+ ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
373
+ self.num_heads = channels // num_head_channels
374
+ self.use_checkpoint = use_checkpoint
375
+ self.norm = normalization(channels)
376
+ self.qkv = conv_nd(1, channels, channels * 3, 1)
377
+ if use_new_attention_order:
378
+ # split qkv before split heads
379
+ self.attention = QKVAttention(self.num_heads)
380
+ else:
381
+ # split heads before split qkv
382
+ self.attention = QKVAttentionLegacy(self.num_heads)
383
+
384
+ self.proj_out = zero_module(conv_nd(1, channels, channels, 1))
385
+
386
+ def forward(self, x, **kwargs):
387
+ # TODO add crossframe attention and use mixed checkpoint
388
+ return checkpoint(
389
+ self._forward, (x,), self.parameters(), True
390
+ ) # TODO: check checkpoint usage, is True # TODO: fix the .half call!!!
391
+ # return pt_checkpoint(self._forward, x) # pytorch
392
+
393
+ def _forward(self, x):
394
+ b, c, *spatial = x.shape
395
+ x = x.reshape(b, c, -1)
396
+ qkv = self.qkv(self.norm(x))
397
+ h = self.attention(qkv)
398
+ h = self.proj_out(h)
399
+ return (x + h).reshape(b, c, *spatial)
400
+
401
+
402
+ def count_flops_attn(model, _x, y):
403
+ """
404
+ A counter for the `thop` package to count the operations in an
405
+ attention operation.
406
+ Meant to be used like:
407
+ macs, params = thop.profile(
408
+ model,
409
+ inputs=(inputs, timestamps),
410
+ custom_ops={QKVAttention: QKVAttention.count_flops},
411
+ )
412
+ """
413
+ b, c, *spatial = y[0].shape
414
+ num_spatial = int(np.prod(spatial))
415
+ # We perform two matmuls with the same number of ops.
416
+ # The first computes the weight matrix, the second computes
417
+ # the combination of the value vectors.
418
+ matmul_ops = 2 * b * (num_spatial**2) * c
419
+ model.total_ops += th.DoubleTensor([matmul_ops])
420
+
421
+
422
+ class QKVAttentionLegacy(nn.Module):
423
+ """
424
+ A module which performs QKV attention. Matches legacy QKVAttention + input/output heads shaping
425
+ """
426
+
427
+ def __init__(self, n_heads):
428
+ super().__init__()
429
+ self.n_heads = n_heads
430
+
431
+ def forward(self, qkv):
432
+ """
433
+ Apply QKV attention.
434
+ :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
435
+ :return: an [N x (H * C) x T] tensor after attention.
436
+ """
437
+ bs, width, length = qkv.shape
438
+ assert width % (3 * self.n_heads) == 0
439
+ ch = width // (3 * self.n_heads)
440
+ q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
441
+ scale = 1 / math.sqrt(math.sqrt(ch))
442
+ weight = th.einsum(
443
+ "bct,bcs->bts", q * scale, k * scale
444
+ ) # More stable with f16 than dividing afterwards
445
+ weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
446
+ a = th.einsum("bts,bcs->bct", weight, v)
447
+ return a.reshape(bs, -1, length)
448
+
449
+ @staticmethod
450
+ def count_flops(model, _x, y):
451
+ return count_flops_attn(model, _x, y)
452
+
453
+
454
+ class QKVAttention(nn.Module):
455
+ """
456
+ A module which performs QKV attention and splits in a different order.
457
+ """
458
+
459
+ def __init__(self, n_heads):
460
+ super().__init__()
461
+ self.n_heads = n_heads
462
+
463
+ def forward(self, qkv):
464
+ """
465
+ Apply QKV attention.
466
+ :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs.
467
+ :return: an [N x (H * C) x T] tensor after attention.
468
+ """
469
+ bs, width, length = qkv.shape
470
+ assert width % (3 * self.n_heads) == 0
471
+ ch = width // (3 * self.n_heads)
472
+ q, k, v = qkv.chunk(3, dim=1)
473
+ scale = 1 / math.sqrt(math.sqrt(ch))
474
+ weight = th.einsum(
475
+ "bct,bcs->bts",
476
+ (q * scale).view(bs * self.n_heads, ch, length),
477
+ (k * scale).view(bs * self.n_heads, ch, length),
478
+ ) # More stable with f16 than dividing afterwards
479
+ weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
480
+ a = th.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length))
481
+ return a.reshape(bs, -1, length)
482
+
483
+ @staticmethod
484
+ def count_flops(model, _x, y):
485
+ return count_flops_attn(model, _x, y)
486
+
487
+
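The two attention modules above consume the same fused qkv tensor and differ only in how it is split into q, k and v (per-head interleaving in QKVAttentionLegacy vs. three contiguous blocks in QKVAttention). A minimal shape sketch, illustrative only and not part of the committed code:

import torch as th
from sgm.modules.diffusionmodules.openaimodel import QKVAttention, QKVAttentionLegacy

heads, ch, tokens = 4, 64, 256
qkv = th.randn(2, 3 * heads * ch, tokens)     # [N, 3*H*C, T]
legacy = QKVAttentionLegacy(heads)            # reshapes to [N*H, 3*C, T], then splits
new = QKVAttention(heads)                     # chunks into three [N, H*C, T] blocks first
print(legacy(qkv).shape, new(qkv).shape)      # both torch.Size([2, 256, 256])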
488
+ class Timestep(nn.Module):
489
+ def __init__(self, dim):
490
+ super().__init__()
491
+ self.dim = dim
492
+
493
+ def forward(self, t):
494
+ return timestep_embedding(t, self.dim)
495
+
496
+
497
+ class UNetModel(nn.Module):
498
+ """
499
+ The full UNet model with attention and timestep embedding.
500
+ :param in_channels: channels in the input Tensor.
501
+ :param model_channels: base channel count for the model.
502
+ :param out_channels: channels in the output Tensor.
503
+ :param num_res_blocks: number of residual blocks per downsample.
504
+ :param attention_resolutions: a collection of downsample rates at which
505
+ attention will take place. May be a set, list, or tuple.
506
+ For example, if this contains 4, then at 4x downsampling, attention
507
+ will be used.
508
+ :param dropout: the dropout probability.
509
+ :param channel_mult: channel multiplier for each level of the UNet.
510
+ :param conv_resample: if True, use learned convolutions for upsampling and
511
+ downsampling.
512
+ :param dims: determines if the signal is 1D, 2D, or 3D.
513
+ :param num_classes: if specified (as an int), then this model will be
514
+ class-conditional with `num_classes` classes.
515
+ :param use_checkpoint: use gradient checkpointing to reduce memory usage.
516
+ :param num_heads: the number of attention heads in each attention layer.
517
+ :param num_head_channels: if specified, ignore num_heads and instead use
518
+ a fixed channel width per attention head.
519
+ :param num_heads_upsample: works with num_heads to set a different number
520
+ of heads for upsampling. Deprecated.
521
+ :param use_scale_shift_norm: use a FiLM-like conditioning mechanism.
522
+ :param resblock_updown: use residual blocks for up/downsampling.
523
+ :param use_new_attention_order: use a different attention pattern for potentially
524
+ increased efficiency.
525
+ """
526
+
527
+ def __init__(
528
+ self,
529
+ in_channels,
530
+ model_channels,
531
+ out_channels,
532
+ num_res_blocks,
533
+ attention_resolutions,
534
+ dropout=0,
535
+ channel_mult=(1, 2, 4, 8),
536
+ conv_resample=True,
537
+ dims=2,
538
+ num_classes=None,
539
+ use_checkpoint=False,
540
+ use_fp16=False,
541
+ num_heads=-1,
542
+ num_head_channels=-1,
543
+ num_heads_upsample=-1,
544
+ use_scale_shift_norm=False,
545
+ resblock_updown=False,
546
+ use_new_attention_order=False,
547
+ use_spatial_transformer=False, # custom transformer support
548
+ transformer_depth=1, # custom transformer support
549
+ context_dim=None, # custom transformer support
550
+ n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model
551
+ legacy=True,
552
+ disable_self_attentions=None,
553
+ num_attention_blocks=None,
554
+ disable_middle_self_attn=False,
555
+ use_linear_in_transformer=False,
556
+ spatial_transformer_attn_type="softmax",
557
+ adm_in_channels=None,
558
+ use_fairscale_checkpoint=False,
559
+ offload_to_cpu=False,
560
+ transformer_depth_middle=None,
561
+ ):
562
+ super().__init__()
563
+ from omegaconf.listconfig import ListConfig
564
+
565
+ if use_spatial_transformer:
566
+ assert (
567
+ context_dim is not None
568
+ ), "Fool!! You forgot to include the dimension of your cross-attention conditioning..."
569
+
570
+ if context_dim is not None:
571
+ assert (
572
+ use_spatial_transformer
573
+ ), "Fool!! You forgot to use the spatial transformer for your cross-attention conditioning..."
574
+ if type(context_dim) == ListConfig:
575
+ context_dim = list(context_dim)
576
+
577
+ if num_heads_upsample == -1:
578
+ num_heads_upsample = num_heads
579
+
580
+ if num_heads == -1:
581
+ assert (
582
+ num_head_channels != -1
583
+ ), "Either num_heads or num_head_channels has to be set"
584
+
585
+ if num_head_channels == -1:
586
+ assert (
587
+ num_heads != -1
588
+ ), "Either num_heads or num_head_channels has to be set"
589
+
590
+ self.in_channels = in_channels
591
+ self.model_channels = model_channels
592
+ self.out_channels = out_channels
593
+ if isinstance(transformer_depth, int):
594
+ transformer_depth = len(channel_mult) * [transformer_depth]
595
+ elif isinstance(transformer_depth, ListConfig):
596
+ transformer_depth = list(transformer_depth)
597
+ transformer_depth_middle = default(
598
+ transformer_depth_middle, transformer_depth[-1]
599
+ )
600
+
601
+ if isinstance(num_res_blocks, int):
602
+ self.num_res_blocks = len(channel_mult) * [num_res_blocks]
603
+ else:
604
+ if len(num_res_blocks) != len(channel_mult):
605
+ raise ValueError(
606
+ "provide num_res_blocks either as an int (globally constant) or "
607
+ "as a list/tuple (per-level) with the same length as channel_mult"
608
+ )
609
+ self.num_res_blocks = num_res_blocks
610
+ # self.num_res_blocks = num_res_blocks
611
+ if disable_self_attentions is not None:
612
+ # should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not
613
+ assert len(disable_self_attentions) == len(channel_mult)
614
+ if num_attention_blocks is not None:
615
+ assert len(num_attention_blocks) == len(self.num_res_blocks)
616
+ assert all(
617
+ map(
618
+ lambda i: self.num_res_blocks[i] >= num_attention_blocks[i],
619
+ range(len(num_attention_blocks)),
620
+ )
621
+ )
622
+ print(
623
+ f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. "
624
+ f"This option has LESS priority than attention_resolutions {attention_resolutions}, "
625
+ f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, "
626
+ f"attention will still not be set."
627
+ ) # todo: convert to warning
628
+
629
+ self.attention_resolutions = attention_resolutions
630
+ self.dropout = dropout
631
+ self.channel_mult = channel_mult
632
+ self.conv_resample = conv_resample
633
+ self.num_classes = num_classes
634
+ self.use_checkpoint = use_checkpoint
635
+ if use_fp16:
636
+ print("WARNING: use_fp16 was dropped and has no effect anymore.")
637
+ # self.dtype = th.float16 if use_fp16 else th.float32
638
+ self.num_heads = num_heads
639
+ self.num_head_channels = num_head_channels
640
+ self.num_heads_upsample = num_heads_upsample
641
+ self.predict_codebook_ids = n_embed is not None
642
+
643
+ assert use_fairscale_checkpoint != use_checkpoint or not (
644
+ use_checkpoint or use_fairscale_checkpoint
645
+ )
646
+
647
+ self.use_fairscale_checkpoint = False
648
+ checkpoint_wrapper_fn = (
649
+ partial(checkpoint_wrapper, offload_to_cpu=offload_to_cpu)
650
+ if self.use_fairscale_checkpoint
651
+ else lambda x: x
652
+ )
653
+
654
+ time_embed_dim = model_channels * 4
655
+ self.time_embed = checkpoint_wrapper_fn(
656
+ nn.Sequential(
657
+ linear(model_channels, time_embed_dim),
658
+ nn.SiLU(),
659
+ linear(time_embed_dim, time_embed_dim),
660
+ )
661
+ )
662
+
663
+ if self.num_classes is not None:
664
+ if isinstance(self.num_classes, int):
665
+ self.label_emb = nn.Embedding(num_classes, time_embed_dim)
666
+ elif self.num_classes == "continuous":
667
+ print("setting up linear c_adm embedding layer")
668
+ self.label_emb = nn.Linear(1, time_embed_dim)
669
+ elif self.num_classes == "timestep":
670
+ self.label_emb = checkpoint_wrapper_fn(
671
+ nn.Sequential(
672
+ Timestep(model_channels),
673
+ nn.Sequential(
674
+ linear(model_channels, time_embed_dim),
675
+ nn.SiLU(),
676
+ linear(time_embed_dim, time_embed_dim),
677
+ ),
678
+ )
679
+ )
680
+ elif self.num_classes == "sequential":
681
+ assert adm_in_channels is not None
682
+ self.label_emb = nn.Sequential(
683
+ nn.Sequential(
684
+ linear(adm_in_channels, time_embed_dim),
685
+ nn.SiLU(),
686
+ linear(time_embed_dim, time_embed_dim),
687
+ )
688
+ )
689
+ else:
690
+ raise ValueError()
691
+
692
+ self.input_blocks = nn.ModuleList(
693
+ [
694
+ TimestepEmbedSequential(
695
+ conv_nd(dims, in_channels, model_channels, 3, padding=1)
696
+ )
697
+ ]
698
+ )
699
+ self._feature_size = model_channels
700
+ input_block_chans = [model_channels]
701
+ ch = model_channels
702
+ ds = 1
703
+ for level, mult in enumerate(channel_mult):
704
+ for nr in range(self.num_res_blocks[level]):
705
+ layers = [
706
+ checkpoint_wrapper_fn(
707
+ ResBlock(
708
+ ch,
709
+ time_embed_dim,
710
+ dropout,
711
+ out_channels=mult * model_channels,
712
+ dims=dims,
713
+ use_checkpoint=use_checkpoint,
714
+ use_scale_shift_norm=use_scale_shift_norm,
715
+ )
716
+ )
717
+ ]
718
+ ch = mult * model_channels
719
+ if ds in attention_resolutions:
720
+ if num_head_channels == -1:
721
+ dim_head = ch // num_heads
722
+ else:
723
+ num_heads = ch // num_head_channels
724
+ dim_head = num_head_channels
725
+ if legacy:
726
+ # num_heads = 1
727
+ dim_head = (
728
+ ch // num_heads
729
+ if use_spatial_transformer
730
+ else num_head_channels
731
+ )
732
+ if exists(disable_self_attentions):
733
+ disabled_sa = disable_self_attentions[level]
734
+ else:
735
+ disabled_sa = False
736
+
737
+ if (
738
+ not exists(num_attention_blocks)
739
+ or nr < num_attention_blocks[level]
740
+ ):
741
+ layers.append(
742
+ checkpoint_wrapper_fn(
743
+ AttentionBlock(
744
+ ch,
745
+ use_checkpoint=use_checkpoint,
746
+ num_heads=num_heads,
747
+ num_head_channels=dim_head,
748
+ use_new_attention_order=use_new_attention_order,
749
+ )
750
+ )
751
+ if not use_spatial_transformer
752
+ else checkpoint_wrapper_fn(
753
+ SpatialTransformer(
754
+ ch,
755
+ num_heads,
756
+ dim_head,
757
+ depth=transformer_depth[level],
758
+ context_dim=context_dim,
759
+ disable_self_attn=disabled_sa,
760
+ use_linear=use_linear_in_transformer,
761
+ attn_type=spatial_transformer_attn_type,
762
+ use_checkpoint=use_checkpoint,
763
+ )
764
+ )
765
+ )
766
+ self.input_blocks.append(TimestepEmbedSequential(*layers))
767
+ self._feature_size += ch
768
+ input_block_chans.append(ch)
769
+ if level != len(channel_mult) - 1:
770
+ out_ch = ch
771
+ self.input_blocks.append(
772
+ TimestepEmbedSequential(
773
+ checkpoint_wrapper_fn(
774
+ ResBlock(
775
+ ch,
776
+ time_embed_dim,
777
+ dropout,
778
+ out_channels=out_ch,
779
+ dims=dims,
780
+ use_checkpoint=use_checkpoint,
781
+ use_scale_shift_norm=use_scale_shift_norm,
782
+ down=True,
783
+ )
784
+ )
785
+ if resblock_updown
786
+ else Downsample(
787
+ ch, conv_resample, dims=dims, out_channels=out_ch
788
+ )
789
+ )
790
+ )
791
+ ch = out_ch
792
+ input_block_chans.append(ch)
793
+ ds *= 2
794
+ self._feature_size += ch
795
+
796
+ if num_head_channels == -1:
797
+ dim_head = ch // num_heads
798
+ else:
799
+ num_heads = ch // num_head_channels
800
+ dim_head = num_head_channels
801
+ if legacy:
802
+ # num_heads = 1
803
+ dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
804
+ self.middle_block = TimestepEmbedSequential(
805
+ checkpoint_wrapper_fn(
806
+ ResBlock(
807
+ ch,
808
+ time_embed_dim,
809
+ dropout,
810
+ dims=dims,
811
+ use_checkpoint=use_checkpoint,
812
+ use_scale_shift_norm=use_scale_shift_norm,
813
+ )
814
+ ),
815
+ checkpoint_wrapper_fn(
816
+ AttentionBlock(
817
+ ch,
818
+ use_checkpoint=use_checkpoint,
819
+ num_heads=num_heads,
820
+ num_head_channels=dim_head,
821
+ use_new_attention_order=use_new_attention_order,
822
+ )
823
+ )
824
+ if not use_spatial_transformer
825
+ else checkpoint_wrapper_fn(
826
+ SpatialTransformer( # always uses a self-attn
827
+ ch,
828
+ num_heads,
829
+ dim_head,
830
+ depth=transformer_depth_middle,
831
+ context_dim=context_dim,
832
+ disable_self_attn=disable_middle_self_attn,
833
+ use_linear=use_linear_in_transformer,
834
+ attn_type=spatial_transformer_attn_type,
835
+ use_checkpoint=use_checkpoint,
836
+ )
837
+ ),
838
+ checkpoint_wrapper_fn(
839
+ ResBlock(
840
+ ch,
841
+ time_embed_dim,
842
+ dropout,
843
+ dims=dims,
844
+ use_checkpoint=use_checkpoint,
845
+ use_scale_shift_norm=use_scale_shift_norm,
846
+ )
847
+ ),
848
+ )
849
+ self._feature_size += ch
850
+
851
+ self.output_blocks = nn.ModuleList([])
852
+ for level, mult in list(enumerate(channel_mult))[::-1]:
853
+ for i in range(self.num_res_blocks[level] + 1):
854
+ ich = input_block_chans.pop()
855
+ layers = [
856
+ checkpoint_wrapper_fn(
857
+ ResBlock(
858
+ ch + ich,
859
+ time_embed_dim,
860
+ dropout,
861
+ out_channels=model_channels * mult,
862
+ dims=dims,
863
+ use_checkpoint=use_checkpoint,
864
+ use_scale_shift_norm=use_scale_shift_norm,
865
+ )
866
+ )
867
+ ]
868
+ ch = model_channels * mult
869
+ if ds in attention_resolutions:
870
+ if num_head_channels == -1:
871
+ dim_head = ch // num_heads
872
+ else:
873
+ num_heads = ch // num_head_channels
874
+ dim_head = num_head_channels
875
+ if legacy:
876
+ # num_heads = 1
877
+ dim_head = (
878
+ ch // num_heads
879
+ if use_spatial_transformer
880
+ else num_head_channels
881
+ )
882
+ if exists(disable_self_attentions):
883
+ disabled_sa = disable_self_attentions[level]
884
+ else:
885
+ disabled_sa = False
886
+
887
+ if (
888
+ not exists(num_attention_blocks)
889
+ or i < num_attention_blocks[level]
890
+ ):
891
+ layers.append(
892
+ checkpoint_wrapper_fn(
893
+ AttentionBlock(
894
+ ch,
895
+ use_checkpoint=use_checkpoint,
896
+ num_heads=num_heads_upsample,
897
+ num_head_channels=dim_head,
898
+ use_new_attention_order=use_new_attention_order,
899
+ )
900
+ )
901
+ if not use_spatial_transformer
902
+ else checkpoint_wrapper_fn(
903
+ SpatialTransformer(
904
+ ch,
905
+ num_heads,
906
+ dim_head,
907
+ depth=transformer_depth[level],
908
+ context_dim=context_dim,
909
+ disable_self_attn=disabled_sa,
910
+ use_linear=use_linear_in_transformer,
911
+ attn_type=spatial_transformer_attn_type,
912
+ use_checkpoint=use_checkpoint,
913
+ )
914
+ )
915
+ )
916
+ if level and i == self.num_res_blocks[level]:
917
+ out_ch = ch
918
+ layers.append(
919
+ checkpoint_wrapper_fn(
920
+ ResBlock(
921
+ ch,
922
+ time_embed_dim,
923
+ dropout,
924
+ out_channels=out_ch,
925
+ dims=dims,
926
+ use_checkpoint=use_checkpoint,
927
+ use_scale_shift_norm=use_scale_shift_norm,
928
+ up=True,
929
+ )
930
+ )
931
+ if resblock_updown
932
+ else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
933
+ )
934
+ ds //= 2
935
+ self.output_blocks.append(TimestepEmbedSequential(*layers))
936
+ self._feature_size += ch
937
+
938
+ self.out = checkpoint_wrapper_fn(
939
+ nn.Sequential(
940
+ normalization(ch),
941
+ nn.SiLU(),
942
+ zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)),
943
+ )
944
+ )
945
+ if self.predict_codebook_ids:
946
+ self.id_predictor = checkpoint_wrapper_fn(
947
+ nn.Sequential(
948
+ normalization(ch),
949
+ conv_nd(dims, model_channels, n_embed, 1),
950
+ # nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits
951
+ )
952
+ )
953
+
954
+ def convert_to_fp16(self):
955
+ """
956
+ Convert the torso of the model to float16.
957
+ """
958
+ self.input_blocks.apply(convert_module_to_f16)
959
+ self.middle_block.apply(convert_module_to_f16)
960
+ self.output_blocks.apply(convert_module_to_f16)
961
+
962
+ def convert_to_fp32(self):
963
+ """
964
+ Convert the torso of the model to float32.
965
+ """
966
+ self.input_blocks.apply(convert_module_to_f32)
967
+ self.middle_block.apply(convert_module_to_f32)
968
+ self.output_blocks.apply(convert_module_to_f32)
969
+
970
+ def forward(self, x, timesteps=None, context=None, y=None, **kwargs):
971
+ """
972
+ Apply the model to an input batch.
973
+ :param x: an [N x C x ...] Tensor of inputs.
974
+ :param timesteps: a 1-D batch of timesteps.
975
+ :param context: conditioning plugged in via crossattn
976
+ :param y: an [N] Tensor of labels, if class-conditional.
977
+ :return: an [N x C x ...] Tensor of outputs.
978
+ """
979
+ assert (y is not None) == (
980
+ self.num_classes is not None
981
+ ), "must specify y if and only if the model is class-conditional"
982
+ hs = []
983
+ t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
984
+ emb = self.time_embed(t_emb)
985
+
986
+ if self.num_classes is not None:
987
+ assert y.shape[0] == x.shape[0]
988
+ emb = emb + self.label_emb(y)
989
+
990
+ # h = x.type(self.dtype)
991
+ h = x
992
+ for module in self.input_blocks:
993
+ h = module(h, emb, context)
994
+ hs.append(h)
995
+ h = self.middle_block(h, emb, context)
996
+ for module in self.output_blocks:
997
+ h = th.cat([h, hs.pop()], dim=1)
998
+ h = module(h, emb, context)
999
+ h = h.type(x.dtype)
1000
+ if self.predict_codebook_ids:
1001
+ assert False, "not supported anymore. what the f*** are you doing?"
1002
+ else:
1003
+ return self.out(h)
1004
+
1005
+
1006
+ class NoTimeUNetModel(UNetModel):
1007
+ def forward(self, x, timesteps=None, context=None, y=None, **kwargs):
1008
+ timesteps = th.zeros_like(timesteps)
1009
+ return super().forward(x, timesteps, context, y, **kwargs)
1010
+
1011
+
1012
+ class EncoderUNetModel(nn.Module):
1013
+ """
1014
+ The half UNet model with attention and timestep embedding.
1015
+ For usage, see UNet.
1016
+ """
1017
+
1018
+ def __init__(
1019
+ self,
1020
+ image_size,
1021
+ in_channels,
1022
+ model_channels,
1023
+ out_channels,
1024
+ num_res_blocks,
1025
+ attention_resolutions,
1026
+ dropout=0,
1027
+ channel_mult=(1, 2, 4, 8),
1028
+ conv_resample=True,
1029
+ dims=2,
1030
+ use_checkpoint=False,
1031
+ use_fp16=False,
1032
+ num_heads=1,
1033
+ num_head_channels=-1,
1034
+ num_heads_upsample=-1,
1035
+ use_scale_shift_norm=False,
1036
+ resblock_updown=False,
1037
+ use_new_attention_order=False,
1038
+ pool="adaptive",
1039
+ *args,
1040
+ **kwargs,
1041
+ ):
1042
+ super().__init__()
1043
+
1044
+ if num_heads_upsample == -1:
1045
+ num_heads_upsample = num_heads
1046
+
1047
+ self.in_channels = in_channels
1048
+ self.model_channels = model_channels
1049
+ self.out_channels = out_channels
1050
+ self.num_res_blocks = num_res_blocks
1051
+ self.attention_resolutions = attention_resolutions
1052
+ self.dropout = dropout
1053
+ self.channel_mult = channel_mult
1054
+ self.conv_resample = conv_resample
1055
+ self.use_checkpoint = use_checkpoint
1056
+ self.dtype = th.float16 if use_fp16 else th.float32
1057
+ self.num_heads = num_heads
1058
+ self.num_head_channels = num_head_channels
1059
+ self.num_heads_upsample = num_heads_upsample
1060
+
1061
+ time_embed_dim = model_channels * 4
1062
+ self.time_embed = nn.Sequential(
1063
+ linear(model_channels, time_embed_dim),
1064
+ nn.SiLU(),
1065
+ linear(time_embed_dim, time_embed_dim),
1066
+ )
1067
+
1068
+ self.input_blocks = nn.ModuleList(
1069
+ [
1070
+ TimestepEmbedSequential(
1071
+ conv_nd(dims, in_channels, model_channels, 3, padding=1)
1072
+ )
1073
+ ]
1074
+ )
1075
+ self._feature_size = model_channels
1076
+ input_block_chans = [model_channels]
1077
+ ch = model_channels
1078
+ ds = 1
1079
+ for level, mult in enumerate(channel_mult):
1080
+ for _ in range(num_res_blocks):
1081
+ layers = [
1082
+ ResBlock(
1083
+ ch,
1084
+ time_embed_dim,
1085
+ dropout,
1086
+ out_channels=mult * model_channels,
1087
+ dims=dims,
1088
+ use_checkpoint=use_checkpoint,
1089
+ use_scale_shift_norm=use_scale_shift_norm,
1090
+ )
1091
+ ]
1092
+ ch = mult * model_channels
1093
+ if ds in attention_resolutions:
1094
+ layers.append(
1095
+ AttentionBlock(
1096
+ ch,
1097
+ use_checkpoint=use_checkpoint,
1098
+ num_heads=num_heads,
1099
+ num_head_channels=num_head_channels,
1100
+ use_new_attention_order=use_new_attention_order,
1101
+ )
1102
+ )
1103
+ self.input_blocks.append(TimestepEmbedSequential(*layers))
1104
+ self._feature_size += ch
1105
+ input_block_chans.append(ch)
1106
+ if level != len(channel_mult) - 1:
1107
+ out_ch = ch
1108
+ self.input_blocks.append(
1109
+ TimestepEmbedSequential(
1110
+ ResBlock(
1111
+ ch,
1112
+ time_embed_dim,
1113
+ dropout,
1114
+ out_channels=out_ch,
1115
+ dims=dims,
1116
+ use_checkpoint=use_checkpoint,
1117
+ use_scale_shift_norm=use_scale_shift_norm,
1118
+ down=True,
1119
+ )
1120
+ if resblock_updown
1121
+ else Downsample(
1122
+ ch, conv_resample, dims=dims, out_channels=out_ch
1123
+ )
1124
+ )
1125
+ )
1126
+ ch = out_ch
1127
+ input_block_chans.append(ch)
1128
+ ds *= 2
1129
+ self._feature_size += ch
1130
+
1131
+ self.middle_block = TimestepEmbedSequential(
1132
+ ResBlock(
1133
+ ch,
1134
+ time_embed_dim,
1135
+ dropout,
1136
+ dims=dims,
1137
+ use_checkpoint=use_checkpoint,
1138
+ use_scale_shift_norm=use_scale_shift_norm,
1139
+ ),
1140
+ AttentionBlock(
1141
+ ch,
1142
+ use_checkpoint=use_checkpoint,
1143
+ num_heads=num_heads,
1144
+ num_head_channels=num_head_channels,
1145
+ use_new_attention_order=use_new_attention_order,
1146
+ ),
1147
+ ResBlock(
1148
+ ch,
1149
+ time_embed_dim,
1150
+ dropout,
1151
+ dims=dims,
1152
+ use_checkpoint=use_checkpoint,
1153
+ use_scale_shift_norm=use_scale_shift_norm,
1154
+ ),
1155
+ )
1156
+ self._feature_size += ch
1157
+ self.pool = pool
1158
+ if pool == "adaptive":
1159
+ self.out = nn.Sequential(
1160
+ normalization(ch),
1161
+ nn.SiLU(),
1162
+ nn.AdaptiveAvgPool2d((1, 1)),
1163
+ zero_module(conv_nd(dims, ch, out_channels, 1)),
1164
+ nn.Flatten(),
1165
+ )
1166
+ elif pool == "attention":
1167
+ assert num_head_channels != -1
1168
+ self.out = nn.Sequential(
1169
+ normalization(ch),
1170
+ nn.SiLU(),
1171
+ AttentionPool2d(
1172
+ (image_size // ds), ch, num_head_channels, out_channels
1173
+ ),
1174
+ )
1175
+ elif pool == "spatial":
1176
+ self.out = nn.Sequential(
1177
+ nn.Linear(self._feature_size, 2048),
1178
+ nn.ReLU(),
1179
+ nn.Linear(2048, self.out_channels),
1180
+ )
1181
+ elif pool == "spatial_v2":
1182
+ self.out = nn.Sequential(
1183
+ nn.Linear(self._feature_size, 2048),
1184
+ normalization(2048),
1185
+ nn.SiLU(),
1186
+ nn.Linear(2048, self.out_channels),
1187
+ )
1188
+ else:
1189
+ raise NotImplementedError(f"Unexpected {pool} pooling")
1190
+
1191
+ def convert_to_fp16(self):
1192
+ """
1193
+ Convert the torso of the model to float16.
1194
+ """
1195
+ self.input_blocks.apply(convert_module_to_f16)
1196
+ self.middle_block.apply(convert_module_to_f16)
1197
+
1198
+ def convert_to_fp32(self):
1199
+ """
1200
+ Convert the torso of the model to float32.
1201
+ """
1202
+ self.input_blocks.apply(convert_module_to_f32)
1203
+ self.middle_block.apply(convert_module_to_f32)
1204
+
1205
+ def forward(self, x, timesteps):
1206
+ """
1207
+ Apply the model to an input batch.
1208
+ :param x: an [N x C x ...] Tensor of inputs.
1209
+ :param timesteps: a 1-D batch of timesteps.
1210
+ :return: an [N x K] Tensor of outputs.
1211
+ """
1212
+ emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
1213
+
1214
+ results = []
1215
+ # h = x.type(self.dtype)
1216
+ h = x
1217
+ for module in self.input_blocks:
1218
+ h = module(h, emb)
1219
+ if self.pool.startswith("spatial"):
1220
+ results.append(h.type(x.dtype).mean(dim=(2, 3)))
1221
+ h = self.middle_block(h, emb)
1222
+ if self.pool.startswith("spatial"):
1223
+ results.append(h.type(x.dtype).mean(dim=(2, 3)))
1224
+ h = th.cat(results, axis=-1)
1225
+ return self.out(h)
1226
+ else:
1227
+ h = h.type(x.dtype)
1228
+ return self.out(h)
1229
+
1230
+
1231
+ if __name__ == "__main__":
1232
+
1233
+ class Dummy(nn.Module):
1234
+ def __init__(self, in_channels=3, model_channels=64):
1235
+ super().__init__()
1236
+ self.input_blocks = nn.ModuleList(
1237
+ [
1238
+ TimestepEmbedSequential(
1239
+ conv_nd(2, in_channels, model_channels, 3, padding=1)
1240
+ )
1241
+ ]
1242
+ )
1243
+
1244
+ model = UNetModel(
1245
+ use_checkpoint=True,
1246
+ image_size=64,
1247
+ in_channels=4,
1248
+ out_channels=4,
1249
+ model_channels=128,
1250
+ attention_resolutions=[4, 2],
1251
+ num_res_blocks=2,
1252
+ channel_mult=[1, 2, 4],
1253
+ num_head_channels=64,
1254
+ use_spatial_transformer=False,
1255
+ use_linear_in_transformer=True,
1256
+ transformer_depth=1,
1257
+ legacy=False,
1258
+ ).cuda()
1259
+ x = th.randn(11, 4, 64, 64).cuda()
1260
+ t = th.randint(low=0, high=10, size=(11,), device="cuda")
1261
+ o = model(x, t)
1262
+ print("done.")
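UNetModel.forward threads two kinds of conditioning through the network: context feeds the cross-attention layers of each SpatialTransformer, while y is embedded by label_emb (here via the "sequential" ADM-style path) and added to the timestep embedding. A minimal sketch of that wiring with deliberately small, purely illustrative sizes (the shipped SD-XL configs set the real values), not part of the committed code:

import torch as th
from sgm.modules.diffusionmodules.openaimodel import UNetModel

unet = UNetModel(
    in_channels=4, model_channels=64, out_channels=4,
    num_res_blocks=2, attention_resolutions=[4, 2], channel_mult=[1, 2, 4],
    num_head_channels=32, use_spatial_transformer=True, transformer_depth=1,
    context_dim=128, num_classes="sequential", adm_in_channels=256, legacy=False,
)
x = th.randn(2, 4, 32, 32)            # latent batch
t = th.randint(0, 1000, (2,))         # timesteps
context = th.randn(2, 77, 128)        # token-level cross-attention conditioning
y = th.randn(2, 256)                  # pooled "vector" conditioning
with th.no_grad():
    out = unet(x, timesteps=t, context=context, y=y)
print(out.shape)                      # torch.Size([2, 4, 32, 32])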
sgm/modules/diffusionmodules/sampling.py ADDED
@@ -0,0 +1,365 @@
1
+ """
2
+ Partially ported from https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/sampling.py
3
+ """
4
+
5
+
6
+ from typing import Dict, Union
7
+
8
+ import torch
9
+ from omegaconf import ListConfig, OmegaConf
10
+ from tqdm import tqdm
11
+
12
+ from ...modules.diffusionmodules.sampling_utils import (
13
+ get_ancestral_step,
14
+ linear_multistep_coeff,
15
+ to_d,
16
+ to_neg_log_sigma,
17
+ to_sigma,
18
+ )
19
+ from ...util import append_dims, default, instantiate_from_config
20
+
21
+ DEFAULT_GUIDER = {"target": "sgm.modules.diffusionmodules.guiders.IdentityGuider"}
22
+
23
+
24
+ class BaseDiffusionSampler:
25
+ def __init__(
26
+ self,
27
+ discretization_config: Union[Dict, ListConfig, OmegaConf],
28
+ num_steps: Union[int, None] = None,
29
+ guider_config: Union[Dict, ListConfig, OmegaConf, None] = None,
30
+ verbose: bool = False,
31
+ device: str = "cuda",
32
+ ):
33
+ self.num_steps = num_steps
34
+ self.discretization = instantiate_from_config(discretization_config)
35
+ self.guider = instantiate_from_config(
36
+ default(
37
+ guider_config,
38
+ DEFAULT_GUIDER,
39
+ )
40
+ )
41
+ self.verbose = verbose
42
+ self.device = device
43
+
44
+ def prepare_sampling_loop(self, x, cond, uc=None, num_steps=None):
45
+ sigmas = self.discretization(
46
+ self.num_steps if num_steps is None else num_steps, device=self.device
47
+ )
48
+ uc = default(uc, cond)
49
+
50
+ x *= torch.sqrt(1.0 + sigmas[0] ** 2.0)
51
+ num_sigmas = len(sigmas)
52
+
53
+ s_in = x.new_ones([x.shape[0]])
54
+
55
+ return x, s_in, sigmas, num_sigmas, cond, uc
56
+
57
+ def denoise(self, x, denoiser, sigma, cond, uc):
58
+ denoised = denoiser(*self.guider.prepare_inputs(x, sigma, cond, uc))
59
+ denoised = self.guider(denoised, sigma)
60
+ return denoised
61
+
62
+ def get_sigma_gen(self, num_sigmas):
63
+ sigma_generator = range(num_sigmas - 1)
64
+ if self.verbose:
65
+ print("#" * 30, " Sampling setting ", "#" * 30)
66
+ print(f"Sampler: {self.__class__.__name__}")
67
+ print(f"Discretization: {self.discretization.__class__.__name__}")
68
+ print(f"Guider: {self.guider.__class__.__name__}")
69
+ sigma_generator = tqdm(
70
+ sigma_generator,
71
+ total=num_sigmas,
72
+ desc=f"Sampling with {self.__class__.__name__} for {num_sigmas} steps",
73
+ )
74
+ return sigma_generator
75
+
76
+
77
+ class SingleStepDiffusionSampler(BaseDiffusionSampler):
78
+ def sampler_step(self, sigma, next_sigma, denoiser, x, cond, uc, *args, **kwargs):
79
+ raise NotImplementedError
80
+
81
+ def euler_step(self, x, d, dt):
82
+ return x + dt * d
83
+
84
+
85
+ class EDMSampler(SingleStepDiffusionSampler):
86
+ def __init__(
87
+ self, s_churn=0.0, s_tmin=0.0, s_tmax=float("inf"), s_noise=1.0, *args, **kwargs
88
+ ):
89
+ super().__init__(*args, **kwargs)
90
+
91
+ self.s_churn = s_churn
92
+ self.s_tmin = s_tmin
93
+ self.s_tmax = s_tmax
94
+ self.s_noise = s_noise
95
+
96
+ def sampler_step(self, sigma, next_sigma, denoiser, x, cond, uc=None, gamma=0.0):
97
+ sigma_hat = sigma * (gamma + 1.0)
98
+ if gamma > 0:
99
+ eps = torch.randn_like(x) * self.s_noise
100
+ x = x + eps * append_dims(sigma_hat**2 - sigma**2, x.ndim) ** 0.5
101
+
102
+ denoised = self.denoise(x, denoiser, sigma_hat, cond, uc)
103
+ d = to_d(x, sigma_hat, denoised)
104
+ dt = append_dims(next_sigma - sigma_hat, x.ndim)
105
+
106
+ euler_step = self.euler_step(x, d, dt)
107
+ x = self.possible_correction_step(
108
+ euler_step, x, d, dt, next_sigma, denoiser, cond, uc
109
+ )
110
+ return x
111
+
112
+ def __call__(self, denoiser, x, cond, uc=None, num_steps=None):
113
+ x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(
114
+ x, cond, uc, num_steps
115
+ )
116
+
117
+ for i in self.get_sigma_gen(num_sigmas):
118
+ gamma = (
119
+ min(self.s_churn / (num_sigmas - 1), 2**0.5 - 1)
120
+ if self.s_tmin <= sigmas[i] <= self.s_tmax
121
+ else 0.0
122
+ )
123
+ x = self.sampler_step(
124
+ s_in * sigmas[i],
125
+ s_in * sigmas[i + 1],
126
+ denoiser,
127
+ x,
128
+ cond,
129
+ uc,
130
+ gamma,
131
+ )
132
+
133
+ return x
134
+
135
+
136
+ class AncestralSampler(SingleStepDiffusionSampler):
137
+ def __init__(self, eta=1.0, s_noise=1.0, *args, **kwargs):
138
+ super().__init__(*args, **kwargs)
139
+
140
+ self.eta = eta
141
+ self.s_noise = s_noise
142
+ self.noise_sampler = lambda x: torch.randn_like(x)
143
+
144
+ def ancestral_euler_step(self, x, denoised, sigma, sigma_down):
145
+ d = to_d(x, sigma, denoised)
146
+ dt = append_dims(sigma_down - sigma, x.ndim)
147
+
148
+ return self.euler_step(x, d, dt)
149
+
150
+ def ancestral_step(self, x, sigma, next_sigma, sigma_up):
151
+ x = torch.where(
152
+ append_dims(next_sigma, x.ndim) > 0.0,
153
+ x + self.noise_sampler(x) * self.s_noise * append_dims(sigma_up, x.ndim),
154
+ x,
155
+ )
156
+ return x
157
+
158
+ def __call__(self, denoiser, x, cond, uc=None, num_steps=None):
159
+ x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(
160
+ x, cond, uc, num_steps
161
+ )
162
+
163
+ for i in self.get_sigma_gen(num_sigmas):
164
+ x = self.sampler_step(
165
+ s_in * sigmas[i],
166
+ s_in * sigmas[i + 1],
167
+ denoiser,
168
+ x,
169
+ cond,
170
+ uc,
171
+ )
172
+
173
+ return x
174
+
175
+
176
+ class LinearMultistepSampler(BaseDiffusionSampler):
177
+ def __init__(
178
+ self,
179
+ order=4,
180
+ *args,
181
+ **kwargs,
182
+ ):
183
+ super().__init__(*args, **kwargs)
184
+
185
+ self.order = order
186
+
187
+ def __call__(self, denoiser, x, cond, uc=None, num_steps=None, **kwargs):
188
+ x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(
189
+ x, cond, uc, num_steps
190
+ )
191
+
192
+ ds = []
193
+ sigmas_cpu = sigmas.detach().cpu().numpy()
194
+ for i in self.get_sigma_gen(num_sigmas):
195
+ sigma = s_in * sigmas[i]
196
+ denoised = denoiser(
197
+ *self.guider.prepare_inputs(x, sigma, cond, uc), **kwargs
198
+ )
199
+ denoised = self.guider(denoised, sigma)
200
+ d = to_d(x, sigma, denoised)
201
+ ds.append(d)
202
+ if len(ds) > self.order:
203
+ ds.pop(0)
204
+ cur_order = min(i + 1, self.order)
205
+ coeffs = [
206
+ linear_multistep_coeff(cur_order, sigmas_cpu, i, j)
207
+ for j in range(cur_order)
208
+ ]
209
+ x = x + sum(coeff * d for coeff, d in zip(coeffs, reversed(ds)))
210
+
211
+ return x
212
+
213
+
214
+ class EulerEDMSampler(EDMSampler):
215
+ def possible_correction_step(
216
+ self, euler_step, x, d, dt, next_sigma, denoiser, cond, uc
217
+ ):
218
+ return euler_step
219
+
220
+
221
+ class HeunEDMSampler(EDMSampler):
222
+ def possible_correction_step(
223
+ self, euler_step, x, d, dt, next_sigma, denoiser, cond, uc
224
+ ):
225
+ if torch.sum(next_sigma) < 1e-14:
226
+ # Save a network evaluation if all noise levels are 0
227
+ return euler_step
228
+ else:
229
+ denoised = self.denoise(euler_step, denoiser, next_sigma, cond, uc)
230
+ d_new = to_d(euler_step, next_sigma, denoised)
231
+ d_prime = (d + d_new) / 2.0
232
+
233
+ # apply correction if noise level is not 0
234
+ x = torch.where(
235
+ append_dims(next_sigma, x.ndim) > 0.0, x + d_prime * dt, euler_step
236
+ )
237
+ return x
238
+
239
+
240
+ class EulerAncestralSampler(AncestralSampler):
241
+ def sampler_step(self, sigma, next_sigma, denoiser, x, cond, uc):
242
+ sigma_down, sigma_up = get_ancestral_step(sigma, next_sigma, eta=self.eta)
243
+ denoised = self.denoise(x, denoiser, sigma, cond, uc)
244
+ x = self.ancestral_euler_step(x, denoised, sigma, sigma_down)
245
+ x = self.ancestral_step(x, sigma, next_sigma, sigma_up)
246
+
247
+ return x
248
+
249
+
250
+ class DPMPP2SAncestralSampler(AncestralSampler):
251
+ def get_variables(self, sigma, sigma_down):
252
+ t, t_next = [to_neg_log_sigma(s) for s in (sigma, sigma_down)]
253
+ h = t_next - t
254
+ s = t + 0.5 * h
255
+ return h, s, t, t_next
256
+
257
+ def get_mult(self, h, s, t, t_next):
258
+ mult1 = to_sigma(s) / to_sigma(t)
259
+ mult2 = (-0.5 * h).expm1()
260
+ mult3 = to_sigma(t_next) / to_sigma(t)
261
+ mult4 = (-h).expm1()
262
+
263
+ return mult1, mult2, mult3, mult4
264
+
265
+ def sampler_step(self, sigma, next_sigma, denoiser, x, cond, uc=None, **kwargs):
266
+ sigma_down, sigma_up = get_ancestral_step(sigma, next_sigma, eta=self.eta)
267
+ denoised = self.denoise(x, denoiser, sigma, cond, uc)
268
+ x_euler = self.ancestral_euler_step(x, denoised, sigma, sigma_down)
269
+
270
+ if torch.sum(sigma_down) < 1e-14:
271
+ # Save a network evaluation if all noise levels are 0
272
+ x = x_euler
273
+ else:
274
+ h, s, t, t_next = self.get_variables(sigma, sigma_down)
275
+ mult = [
276
+ append_dims(mult, x.ndim) for mult in self.get_mult(h, s, t, t_next)
277
+ ]
278
+
279
+ x2 = mult[0] * x - mult[1] * denoised
280
+ denoised2 = self.denoise(x2, denoiser, to_sigma(s), cond, uc)
281
+ x_dpmpp2s = mult[2] * x - mult[3] * denoised2
282
+
283
+ # apply correction if noise level is not 0
284
+ x = torch.where(append_dims(sigma_down, x.ndim) > 0.0, x_dpmpp2s, x_euler)
285
+
286
+ x = self.ancestral_step(x, sigma, next_sigma, sigma_up)
287
+ return x
288
+
289
+
290
+ class DPMPP2MSampler(BaseDiffusionSampler):
291
+ def get_variables(self, sigma, next_sigma, previous_sigma=None):
292
+ t, t_next = [to_neg_log_sigma(s) for s in (sigma, next_sigma)]
293
+ h = t_next - t
294
+
295
+ if previous_sigma is not None:
296
+ h_last = t - to_neg_log_sigma(previous_sigma)
297
+ r = h_last / h
298
+ return h, r, t, t_next
299
+ else:
300
+ return h, None, t, t_next
301
+
302
+ def get_mult(self, h, r, t, t_next, previous_sigma):
303
+ mult1 = to_sigma(t_next) / to_sigma(t)
304
+ mult2 = (-h).expm1()
305
+
306
+ if previous_sigma is not None:
307
+ mult3 = 1 + 1 / (2 * r)
308
+ mult4 = 1 / (2 * r)
309
+ return mult1, mult2, mult3, mult4
310
+ else:
311
+ return mult1, mult2
312
+
313
+ def sampler_step(
314
+ self,
315
+ old_denoised,
316
+ previous_sigma,
317
+ sigma,
318
+ next_sigma,
319
+ denoiser,
320
+ x,
321
+ cond,
322
+ uc=None,
323
+ ):
324
+ denoised = self.denoise(x, denoiser, sigma, cond, uc)
325
+
326
+ h, r, t, t_next = self.get_variables(sigma, next_sigma, previous_sigma)
327
+ mult = [
328
+ append_dims(mult, x.ndim)
329
+ for mult in self.get_mult(h, r, t, t_next, previous_sigma)
330
+ ]
331
+
332
+ x_standard = mult[0] * x - mult[1] * denoised
333
+ if old_denoised is None or torch.sum(next_sigma) < 1e-14:
334
+ # Save a network evaluation if all noise levels are 0 or on the first step
335
+ return x_standard, denoised
336
+ else:
337
+ denoised_d = mult[2] * denoised - mult[3] * old_denoised
338
+ x_advanced = mult[0] * x - mult[1] * denoised_d
339
+
340
+ # apply correction if noise level is not 0 and not first step
341
+ x = torch.where(
342
+ append_dims(next_sigma, x.ndim) > 0.0, x_advanced, x_standard
343
+ )
344
+
345
+ return x, denoised
346
+
347
+ def __call__(self, denoiser, x, cond, uc=None, num_steps=None, **kwargs):
348
+ x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(
349
+ x, cond, uc, num_steps
350
+ )
351
+
352
+ old_denoised = None
353
+ for i in self.get_sigma_gen(num_sigmas):
354
+ x, old_denoised = self.sampler_step(
355
+ old_denoised,
356
+ None if i == 0 else s_in * sigmas[i - 1],
357
+ s_in * sigmas[i],
358
+ s_in * sigmas[i + 1],
359
+ denoiser,
360
+ x,
361
+ cond,
362
+ uc=uc,
363
+ )
364
+
365
+ return x
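All of the samplers above share the same calling convention: build one from a discretization config (plus an optional guider config), then call it with a denoiser callable, an initial noise tensor and the conditioning. A minimal sketch; the discretizer target name is assumed to match discretizer.py in this commit and the denoiser is a stand-in, so treat both as illustrative assumptions:

import torch
from sgm.modules.diffusionmodules.sampling import EulerEDMSampler

sampler = EulerEDMSampler(
    num_steps=30,
    discretization_config={
        "target": "sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization"
    },
    guider_config=None,      # falls back to the IdentityGuider default above
    verbose=True,
    device="cpu",
)

def toy_denoiser(x, sigma, cond, **kwargs):
    # placeholder for the real Denoiser(network, ...) callable
    return torch.zeros_like(x)

x = torch.randn(1, 4, 8, 8)
sample = sampler(toy_denoiser, x, cond={})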
sgm/modules/diffusionmodules/sampling_utils.py ADDED
@@ -0,0 +1,48 @@
1
+ import torch
2
+ from scipy import integrate
3
+
4
+ from ...util import append_dims
5
+
6
+
7
+ class NoDynamicThresholding:
8
+ def __call__(self, uncond, cond, scale):
9
+ return uncond + scale * (cond - uncond)
10
+
11
+
12
+ def linear_multistep_coeff(order, t, i, j, epsrel=1e-4):
13
+ if order - 1 > i:
14
+ raise ValueError(f"Order {order} too high for step {i}")
15
+
16
+ def fn(tau):
17
+ prod = 1.0
18
+ for k in range(order):
19
+ if j == k:
20
+ continue
21
+ prod *= (tau - t[i - k]) / (t[i - j] - t[i - k])
22
+ return prod
23
+
24
+ return integrate.quad(fn, t[i], t[i + 1], epsrel=epsrel)[0]
25
+
26
+
27
+ def get_ancestral_step(sigma_from, sigma_to, eta=1.0):
28
+ if not eta:
29
+ return sigma_to, 0.0
30
+ sigma_up = torch.minimum(
31
+ sigma_to,
32
+ eta
33
+ * (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5,
34
+ )
35
+ sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5
36
+ return sigma_down, sigma_up
37
+
38
+
39
+ def to_d(x, sigma, denoised):
40
+ return (x - denoised) / append_dims(sigma, x.ndim)
41
+
42
+
43
+ def to_neg_log_sigma(sigma):
44
+ return sigma.log().neg()
45
+
46
+
47
+ def to_sigma(neg_log_sigma):
48
+ return neg_log_sigma.neg().exp()
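get_ancestral_step splits a transition from sigma_from to sigma_to into a deterministic part (sigma_down) and a stochastic part (sigma_up) such that sigma_down**2 + sigma_up**2 == sigma_to**2 whenever the eta-scaled noise is not clipped. A quick numerical check, not part of the committed code:

import torch
from sgm.modules.diffusionmodules.sampling_utils import get_ancestral_step

sigma_from, sigma_to = torch.tensor(2.0), torch.tensor(1.0)
sigma_down, sigma_up = get_ancestral_step(sigma_from, sigma_to, eta=1.0)
print(sigma_down, sigma_up)                                   # ~0.5 and ~0.866
assert torch.allclose(sigma_down**2 + sigma_up**2, sigma_to**2)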
sgm/modules/diffusionmodules/sigma_sampling.py ADDED
@@ -0,0 +1,31 @@
1
+ import torch
2
+
3
+ from ...util import default, instantiate_from_config
4
+
5
+
6
+ class EDMSampling:
7
+ def __init__(self, p_mean=-1.2, p_std=1.2):
8
+ self.p_mean = p_mean
9
+ self.p_std = p_std
10
+
11
+ def __call__(self, n_samples, rand=None):
12
+ log_sigma = self.p_mean + self.p_std * default(rand, torch.randn((n_samples,)))
13
+ return log_sigma.exp()
14
+
15
+
16
+ class DiscreteSampling:
17
+ def __init__(self, discretization_config, num_idx, do_append_zero=False, flip=True):
18
+ self.num_idx = num_idx
19
+ self.sigmas = instantiate_from_config(discretization_config)(
20
+ num_idx, do_append_zero=do_append_zero, flip=flip
21
+ )
22
+
23
+ def idx_to_sigma(self, idx):
24
+ return self.sigmas[idx]
25
+
26
+ def __call__(self, n_samples, rand=None):
27
+ idx = default(
28
+ rand,
29
+ torch.randint(0, self.num_idx, (n_samples,)),
30
+ )
31
+ return self.idx_to_sigma(idx)
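EDMSampling draws training noise levels log-normally, i.e. log(sigma) ~ N(p_mean, p_std**2), while DiscreteSampling picks uniformly among the sigmas of a fixed discretization. A minimal sketch of the former, not part of the committed code:

import torch
from sgm.modules.diffusionmodules.sigma_sampling import EDMSampling

sigma_sampler = EDMSampling(p_mean=-1.2, p_std=1.2)
sigmas = sigma_sampler(4)        # shape [4], strictly positive
print(sigmas)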
sgm/modules/diffusionmodules/util.py ADDED
@@ -0,0 +1,308 @@
1
+ """
2
+ adapted from
3
+ https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
4
+ and
5
+ https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
6
+ and
7
+ https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py
8
+
9
+ thanks!
10
+ """
11
+
12
+ import math
13
+
14
+ import torch
15
+ import torch.nn as nn
16
+ from einops import repeat
17
+
18
+
19
+ def make_beta_schedule(
20
+ schedule,
21
+ n_timestep,
22
+ linear_start=1e-4,
23
+ linear_end=2e-2,
24
+ ):
25
+ if schedule == "linear":
26
+ betas = (
27
+ torch.linspace(
28
+ linear_start**0.5, linear_end**0.5, n_timestep, dtype=torch.float64
29
+ )
30
+ ** 2
31
+ )
32
+ return betas.numpy()
33
+
34
+
35
+ def extract_into_tensor(a, t, x_shape):
36
+ b, *_ = t.shape
37
+ out = a.gather(-1, t)
38
+ return out.reshape(b, *((1,) * (len(x_shape) - 1)))
39
+
40
+
41
+ def mixed_checkpoint(func, inputs: dict, params, flag):
42
+ """
43
+ Evaluate a function without caching intermediate activations, allowing for
44
+ reduced memory at the expense of extra compute in the backward pass. This differs from the original checkpoint function
45
+ borrowed from https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py in that
46
+ it also works with non-tensor inputs
47
+ :param func: the function to evaluate.
48
+ :param inputs: the argument dictionary to pass to `func`.
49
+ :param params: a sequence of parameters `func` depends on but does not
50
+ explicitly take as arguments.
51
+ :param flag: if False, disable gradient checkpointing.
52
+ """
53
+ if flag:
54
+ tensor_keys = [key for key in inputs if isinstance(inputs[key], torch.Tensor)]
55
+ tensor_inputs = [
56
+ inputs[key] for key in inputs if isinstance(inputs[key], torch.Tensor)
57
+ ]
58
+ non_tensor_keys = [
59
+ key for key in inputs if not isinstance(inputs[key], torch.Tensor)
60
+ ]
61
+ non_tensor_inputs = [
62
+ inputs[key] for key in inputs if not isinstance(inputs[key], torch.Tensor)
63
+ ]
64
+ args = tuple(tensor_inputs) + tuple(non_tensor_inputs) + tuple(params)
65
+ return MixedCheckpointFunction.apply(
66
+ func,
67
+ len(tensor_inputs),
68
+ len(non_tensor_inputs),
69
+ tensor_keys,
70
+ non_tensor_keys,
71
+ *args,
72
+ )
73
+ else:
74
+ return func(**inputs)
75
+
76
+
77
+ class MixedCheckpointFunction(torch.autograd.Function):
78
+ @staticmethod
79
+ def forward(
80
+ ctx,
81
+ run_function,
82
+ length_tensors,
83
+ length_non_tensors,
84
+ tensor_keys,
85
+ non_tensor_keys,
86
+ *args,
87
+ ):
88
+ ctx.end_tensors = length_tensors
89
+ ctx.end_non_tensors = length_tensors + length_non_tensors
90
+ ctx.gpu_autocast_kwargs = {
91
+ "enabled": torch.is_autocast_enabled(),
92
+ "dtype": torch.get_autocast_gpu_dtype(),
93
+ "cache_enabled": torch.is_autocast_cache_enabled(),
94
+ }
95
+ assert (
96
+ len(tensor_keys) == length_tensors
97
+ and len(non_tensor_keys) == length_non_tensors
98
+ )
99
+
100
+ ctx.input_tensors = {
101
+ key: val for (key, val) in zip(tensor_keys, list(args[: ctx.end_tensors]))
102
+ }
103
+ ctx.input_non_tensors = {
104
+ key: val
105
+ for (key, val) in zip(
106
+ non_tensor_keys, list(args[ctx.end_tensors : ctx.end_non_tensors])
107
+ )
108
+ }
109
+ ctx.run_function = run_function
110
+ ctx.input_params = list(args[ctx.end_non_tensors :])
111
+
112
+ with torch.no_grad():
113
+ output_tensors = ctx.run_function(
114
+ **ctx.input_tensors, **ctx.input_non_tensors
115
+ )
116
+ return output_tensors
117
+
118
+ @staticmethod
119
+ def backward(ctx, *output_grads):
120
+ # additional_args = {key: ctx.input_tensors[key] for key in ctx.input_tensors if not isinstance(ctx.input_tensors[key],torch.Tensor)}
121
+ ctx.input_tensors = {
122
+ key: ctx.input_tensors[key].detach().requires_grad_(True)
123
+ for key in ctx.input_tensors
124
+ }
125
+
126
+ with torch.enable_grad(), torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs):
127
+ # Fixes a bug where the first op in run_function modifies the
128
+ # Tensor storage in place, which is not allowed for detach()'d
129
+ # Tensors.
130
+ shallow_copies = {
131
+ key: ctx.input_tensors[key].view_as(ctx.input_tensors[key])
132
+ for key in ctx.input_tensors
133
+ }
134
+ # shallow_copies.update(additional_args)
135
+ output_tensors = ctx.run_function(**shallow_copies, **ctx.input_non_tensors)
136
+ input_grads = torch.autograd.grad(
137
+ output_tensors,
138
+ list(ctx.input_tensors.values()) + ctx.input_params,
139
+ output_grads,
140
+ allow_unused=True,
141
+ )
142
+ del ctx.input_tensors
143
+ del ctx.input_params
144
+ del output_tensors
145
+ return (
146
+ (None, None, None, None, None)
147
+ + input_grads[: ctx.end_tensors]
148
+ + (None,) * (ctx.end_non_tensors - ctx.end_tensors)
149
+ + input_grads[ctx.end_tensors :]
150
+ )
151
+
152
+
153
+ def checkpoint(func, inputs, params, flag):
154
+ """
155
+ Evaluate a function without caching intermediate activations, allowing for
156
+ reduced memory at the expense of extra compute in the backward pass.
157
+ :param func: the function to evaluate.
158
+ :param inputs: the argument sequence to pass to `func`.
159
+ :param params: a sequence of parameters `func` depends on but does not
160
+ explicitly take as arguments.
161
+ :param flag: if False, disable gradient checkpointing.
162
+ """
163
+ if flag:
164
+ args = tuple(inputs) + tuple(params)
165
+ return CheckpointFunction.apply(func, len(inputs), *args)
166
+ else:
167
+ return func(*inputs)
168
+
169
+
170
+ class CheckpointFunction(torch.autograd.Function):
171
+ @staticmethod
172
+ def forward(ctx, run_function, length, *args):
173
+ ctx.run_function = run_function
174
+ ctx.input_tensors = list(args[:length])
175
+ ctx.input_params = list(args[length:])
176
+ ctx.gpu_autocast_kwargs = {
177
+ "enabled": torch.is_autocast_enabled(),
178
+ "dtype": torch.get_autocast_gpu_dtype(),
179
+ "cache_enabled": torch.is_autocast_cache_enabled(),
180
+ }
181
+ with torch.no_grad():
182
+ output_tensors = ctx.run_function(*ctx.input_tensors)
183
+ return output_tensors
184
+
185
+ @staticmethod
186
+ def backward(ctx, *output_grads):
187
+ ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
188
+ with torch.enable_grad(), torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs):
189
+ # Fixes a bug where the first op in run_function modifies the
190
+ # Tensor storage in place, which is not allowed for detach()'d
191
+ # Tensors.
192
+ shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
193
+ output_tensors = ctx.run_function(*shallow_copies)
194
+ input_grads = torch.autograd.grad(
195
+ output_tensors,
196
+ ctx.input_tensors + ctx.input_params,
197
+ output_grads,
198
+ allow_unused=True,
199
+ )
200
+ del ctx.input_tensors
201
+ del ctx.input_params
202
+ del output_tensors
203
+ return (None, None) + input_grads
204
+
205
+
206
+ def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
207
+ """
208
+ Create sinusoidal timestep embeddings.
209
+ :param timesteps: a 1-D Tensor of N indices, one per batch element.
210
+ These may be fractional.
211
+ :param dim: the dimension of the output.
212
+ :param max_period: controls the minimum frequency of the embeddings.
213
+ :return: an [N x dim] Tensor of positional embeddings.
214
+ """
215
+ if not repeat_only:
216
+ half = dim // 2
217
+ freqs = torch.exp(
218
+ -math.log(max_period)
219
+ * torch.arange(start=0, end=half, dtype=torch.float32)
220
+ / half
221
+ ).to(device=timesteps.device)
222
+ args = timesteps[:, None].float() * freqs[None]
223
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
224
+ if dim % 2:
225
+ embedding = torch.cat(
226
+ [embedding, torch.zeros_like(embedding[:, :1])], dim=-1
227
+ )
228
+ else:
229
+ embedding = repeat(timesteps, "b -> b d", d=dim)
230
+ return embedding
231
+
232
+
233
+ def zero_module(module):
234
+ """
235
+ Zero out the parameters of a module and return it.
236
+ """
237
+ for p in module.parameters():
238
+ p.detach().zero_()
239
+ return module
240
+
241
+
242
+ def scale_module(module, scale):
243
+ """
244
+ Scale the parameters of a module and return it.
245
+ """
246
+ for p in module.parameters():
247
+ p.detach().mul_(scale)
248
+ return module
249
+
250
+
251
+ def mean_flat(tensor):
252
+ """
253
+ Take the mean over all non-batch dimensions.
254
+ """
255
+ return tensor.mean(dim=list(range(1, len(tensor.shape))))
256
+
257
+
258
+ def normalization(channels):
259
+ """
260
+ Make a standard normalization layer.
261
+ :param channels: number of input channels.
262
+ :return: an nn.Module for normalization.
263
+ """
264
+ return GroupNorm32(32, channels)
265
+
266
+
267
+ # PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
268
+ class SiLU(nn.Module):
269
+ def forward(self, x):
270
+ return x * torch.sigmoid(x)
271
+
272
+
273
+ class GroupNorm32(nn.GroupNorm):
274
+ def forward(self, x):
275
+ return super().forward(x.float()).type(x.dtype)
276
+
277
+
278
+ def conv_nd(dims, *args, **kwargs):
279
+ """
280
+ Create a 1D, 2D, or 3D convolution module.
281
+ """
282
+ if dims == 1:
283
+ return nn.Conv1d(*args, **kwargs)
284
+ elif dims == 2:
285
+ return nn.Conv2d(*args, **kwargs)
286
+ elif dims == 3:
287
+ return nn.Conv3d(*args, **kwargs)
288
+ raise ValueError(f"unsupported dimensions: {dims}")
289
+
290
+
291
+ def linear(*args, **kwargs):
292
+ """
293
+ Create a linear module.
294
+ """
295
+ return nn.Linear(*args, **kwargs)
296
+
297
+
298
+ def avg_pool_nd(dims, *args, **kwargs):
299
+ """
300
+ Create a 1D, 2D, or 3D average pooling module.
301
+ """
302
+ if dims == 1:
303
+ return nn.AvgPool1d(*args, **kwargs)
304
+ elif dims == 2:
305
+ return nn.AvgPool2d(*args, **kwargs)
306
+ elif dims == 3:
307
+ return nn.AvgPool3d(*args, **kwargs)
308
+ raise ValueError(f"unsupported dimensions: {dims}")
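timestep_embedding produces the transformer-style sinusoidal features that time_embed then projects: for N (possibly fractional) timesteps it returns an [N, dim] tensor whose first half holds cosines and second half sines. A quick shape check, not part of the committed code:

import torch
from sgm.modules.diffusionmodules.util import timestep_embedding

t = torch.tensor([0.0, 10.0, 999.0])
emb = timestep_embedding(t, dim=320)
print(emb.shape)                     # torch.Size([3, 320])
print(emb.abs().max() <= 1.0)        # sinusoids stay within [-1, 1]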
sgm/modules/diffusionmodules/wrappers.py ADDED
@@ -0,0 +1,34 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from packaging import version
4
+
5
+ OPENAIUNETWRAPPER = "sgm.modules.diffusionmodules.wrappers.OpenAIWrapper"
6
+
7
+
8
+ class IdentityWrapper(nn.Module):
9
+ def __init__(self, diffusion_model, compile_model: bool = False):
10
+ super().__init__()
11
+ compile = (
12
+ torch.compile
13
+ if (version.parse(torch.__version__) >= version.parse("2.0.0"))
14
+ and compile_model
15
+ else lambda x: x
16
+ )
17
+ self.diffusion_model = compile(diffusion_model)
18
+
19
+ def forward(self, *args, **kwargs):
20
+ return self.diffusion_model(*args, **kwargs)
21
+
22
+
23
+ class OpenAIWrapper(IdentityWrapper):
24
+ def forward(
25
+ self, x: torch.Tensor, t: torch.Tensor, c: dict, **kwargs
26
+ ) -> torch.Tensor:
27
+ x = torch.cat((x, c.get("concat", torch.Tensor([]).type_as(x))), dim=1)
28
+ return self.diffusion_model(
29
+ x,
30
+ timesteps=t,
31
+ context=c.get("crossattn", None),
32
+ y=c.get("vector", None),
33
+ **kwargs,
34
+ )
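OpenAIWrapper is the glue between the conditioner output and UNetModel.forward: the "concat" entry is channel-concatenated onto the latent, "crossattn" becomes context, and "vector" becomes y. A minimal sketch with a stand-in diffusion model and illustrative shapes, not part of the committed code:

import torch
import torch.nn as nn
from sgm.modules.diffusionmodules.wrappers import OpenAIWrapper

class ToyDiffusionModel(nn.Module):
    # echoes the latent; stands in for a real UNetModel
    def forward(self, x, timesteps=None, context=None, y=None, **kwargs):
        return x

wrapped = OpenAIWrapper(ToyDiffusionModel())
c = {
    "crossattn": torch.randn(2, 77, 128),    # token-level conditioning -> context
    "vector": torch.randn(2, 256),           # pooled / ADM conditioning -> y
    # "concat": ...                          # optional, concatenated on dim=1
}
out = wrapped(torch.randn(2, 4, 32, 32), torch.randint(0, 1000, (2,)), c)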
sgm/modules/distributions/__init__.py ADDED
File without changes
sgm/modules/distributions/distributions.py ADDED
@@ -0,0 +1,102 @@
1
+ import numpy as np
2
+ import torch
3
+
4
+
5
+ class AbstractDistribution:
6
+ def sample(self):
7
+ raise NotImplementedError()
8
+
9
+ def mode(self):
10
+ raise NotImplementedError()
11
+
12
+
13
+ class DiracDistribution(AbstractDistribution):
14
+ def __init__(self, value):
15
+ self.value = value
16
+
17
+ def sample(self):
18
+ return self.value
19
+
20
+ def mode(self):
21
+ return self.value
22
+
23
+
24
+ class DiagonalGaussianDistribution(object):
25
+ def __init__(self, parameters, deterministic=False):
26
+ self.parameters = parameters
27
+ self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
28
+ self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
29
+ self.deterministic = deterministic
30
+ self.std = torch.exp(0.5 * self.logvar)
31
+ self.var = torch.exp(self.logvar)
32
+ if self.deterministic:
33
+ self.var = self.std = torch.zeros_like(self.mean).to(
34
+ device=self.parameters.device
35
+ )
36
+
37
+ def sample(self):
38
+ x = self.mean + self.std * torch.randn(self.mean.shape).to(
39
+ device=self.parameters.device
40
+ )
41
+ return x
42
+
43
+ def kl(self, other=None):
44
+ if self.deterministic:
45
+ return torch.Tensor([0.0])
46
+ else:
47
+ if other is None:
48
+ return 0.5 * torch.sum(
49
+ torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
50
+ dim=[1, 2, 3],
51
+ )
52
+ else:
53
+ return 0.5 * torch.sum(
54
+ torch.pow(self.mean - other.mean, 2) / other.var
55
+ + self.var / other.var
56
+ - 1.0
57
+ - self.logvar
58
+ + other.logvar,
59
+ dim=[1, 2, 3],
60
+ )
61
+
62
+ def nll(self, sample, dims=[1, 2, 3]):
63
+ if self.deterministic:
64
+ return torch.Tensor([0.0])
65
+ logtwopi = np.log(2.0 * np.pi)
66
+ return 0.5 * torch.sum(
67
+ logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
68
+ dim=dims,
69
+ )
70
+
71
+ def mode(self):
72
+ return self.mean
73
+
74
+
75
+ def normal_kl(mean1, logvar1, mean2, logvar2):
76
+ """
77
+ source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12
78
+ Compute the KL divergence between two gaussians.
79
+ Shapes are automatically broadcasted, so batches can be compared to
80
+ scalars, among other use cases.
81
+ """
82
+ tensor = None
83
+ for obj in (mean1, logvar1, mean2, logvar2):
84
+ if isinstance(obj, torch.Tensor):
85
+ tensor = obj
86
+ break
87
+ assert tensor is not None, "at least one argument must be a Tensor"
88
+
89
+ # Force variances to be Tensors. Broadcasting helps convert scalars to
90
+ # Tensors, but it does not work for torch.exp().
91
+ logvar1, logvar2 = [
92
+ x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor)
93
+ for x in (logvar1, logvar2)
94
+ ]
95
+
96
+ return 0.5 * (
97
+ -1.0
98
+ + logvar2
99
+ - logvar1
100
+ + torch.exp(logvar1 - logvar2)
101
+ + ((mean1 - mean2) ** 2) * torch.exp(-logvar2)
102
+ )
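DiagonalGaussianDistribution wraps the encoder output of the autoencoder: the parameters tensor stacks mean and log-variance along the channel axis, so a [B, 2*C, H, W] input yields [B, C, H, W] samples. A quick sketch, not part of the committed code:

import torch
from sgm.modules.distributions.distributions import DiagonalGaussianDistribution

params = torch.randn(1, 8, 16, 16)      # 4 mean channels + 4 log-variance channels
dist = DiagonalGaussianDistribution(params)
z = dist.sample()                       # [1, 4, 16, 16]
kl = dist.kl()                          # KL against a standard normal, shape [1]
print(z.shape, kl.shape)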
sgm/modules/ema.py ADDED
@@ -0,0 +1,86 @@
1
+ import torch
2
+ from torch import nn
3
+
4
+
5
+ class LitEma(nn.Module):
6
+ def __init__(self, model, decay=0.9999, use_num_upates=True):
7
+ super().__init__()
8
+ if decay < 0.0 or decay > 1.0:
9
+ raise ValueError("Decay must be between 0 and 1")
10
+
11
+ self.m_name2s_name = {}
12
+ self.register_buffer("decay", torch.tensor(decay, dtype=torch.float32))
13
+ self.register_buffer(
14
+ "num_updates",
15
+ torch.tensor(0, dtype=torch.int)
16
+ if use_num_upates
17
+ else torch.tensor(-1, dtype=torch.int),
18
+ )
19
+
20
+ for name, p in model.named_parameters():
21
+ if p.requires_grad:
22
+ # remove '.' since that character is not allowed in buffer names
23
+ s_name = name.replace(".", "")
24
+ self.m_name2s_name.update({name: s_name})
25
+ self.register_buffer(s_name, p.clone().detach().data)
26
+
27
+ self.collected_params = []
28
+
29
+ def reset_num_updates(self):
30
+ del self.num_updates
31
+ self.register_buffer("num_updates", torch.tensor(0, dtype=torch.int))
32
+
33
+ def forward(self, model):
34
+ decay = self.decay
35
+
36
+ if self.num_updates >= 0:
37
+ self.num_updates += 1
38
+ decay = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates))
39
+
40
+ one_minus_decay = 1.0 - decay
41
+
42
+ with torch.no_grad():
43
+ m_param = dict(model.named_parameters())
44
+ shadow_params = dict(self.named_buffers())
45
+
46
+ for key in m_param:
47
+ if m_param[key].requires_grad:
48
+ sname = self.m_name2s_name[key]
49
+ shadow_params[sname] = shadow_params[sname].type_as(m_param[key])
50
+ shadow_params[sname].sub_(
51
+ one_minus_decay * (shadow_params[sname] - m_param[key])
52
+ )
53
+ else:
54
+ assert not key in self.m_name2s_name
55
+
56
+ def copy_to(self, model):
57
+ m_param = dict(model.named_parameters())
58
+ shadow_params = dict(self.named_buffers())
59
+ for key in m_param:
60
+ if m_param[key].requires_grad:
61
+ m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data)
62
+ else:
63
+ assert not key in self.m_name2s_name
64
+
65
+ def store(self, parameters):
66
+ """
67
+ Save the current parameters for restoring later.
68
+ Args:
69
+ parameters: Iterable of `torch.nn.Parameter`; the parameters to be
70
+ temporarily stored.
71
+ """
72
+ self.collected_params = [param.clone() for param in parameters]
73
+
74
+ def restore(self, parameters):
75
+ """
76
+ Restore the parameters stored with the `store` method.
77
+ Useful to validate the model with EMA parameters without affecting the
78
+ original optimization process. Store the parameters before the
79
+ `copy_to` method. After validation (or model saving), use this to
80
+ restore the former parameters.
81
+ Args:
82
+ parameters: Iterable of `torch.nn.Parameter`; the parameters to be
83
+ updated with the stored parameters.
84
+ """
85
+ for c_param, param in zip(self.collected_params, parameters):
86
+ param.data.copy_(c_param.data)
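LitEma keeps an exponential moving average of all trainable parameters as buffers; the intended pattern is to update it after each optimizer step and to swap the averaged weights in temporarily for validation or export. A minimal sketch, not part of the committed code:

import torch
from sgm.modules.ema import LitEma

model = torch.nn.Linear(4, 4)
ema = LitEma(model, decay=0.9999)

# after every optimizer step:
ema(model)                         # update the shadow parameters

# evaluate with EMA weights, then switch back:
ema.store(model.parameters())      # stash the live weights
ema.copy_to(model)                 # load the averaged weights into the model
# ... run validation here ...
ema.restore(model.parameters())    # put the live weights back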
sgm/modules/encoders/__init__.py ADDED
File without changes
sgm/modules/encoders/modules.py ADDED
@@ -0,0 +1,960 @@
1
+ from contextlib import nullcontext
2
+ from functools import partial
3
+ from typing import Dict, List, Optional, Tuple, Union
4
+
5
+ import kornia
6
+ import numpy as np
7
+ import open_clip
8
+ import torch
9
+ import torch.nn as nn
10
+ from einops import rearrange, repeat
11
+ from omegaconf import ListConfig
12
+ from torch.utils.checkpoint import checkpoint
13
+ from transformers import (
14
+ ByT5Tokenizer,
15
+ CLIPTextModel,
16
+ CLIPTokenizer,
17
+ T5EncoderModel,
18
+ T5Tokenizer,
19
+ )
20
+
21
+ from ...modules.autoencoding.regularizers import DiagonalGaussianRegularizer
22
+ from ...modules.diffusionmodules.model import Encoder
23
+ from ...modules.diffusionmodules.openaimodel import Timestep
24
+ from ...modules.diffusionmodules.util import extract_into_tensor, make_beta_schedule
25
+ from ...modules.distributions.distributions import DiagonalGaussianDistribution
26
+ from ...util import (
27
+ autocast,
28
+ count_params,
29
+ default,
30
+ disabled_train,
31
+ expand_dims_like,
32
+ instantiate_from_config,
33
+ )
34
+
35
+
36
+ class AbstractEmbModel(nn.Module):
37
+ def __init__(self):
38
+ super().__init__()
39
+ self._is_trainable = None
40
+ self._ucg_rate = None
41
+ self._input_key = None
42
+
43
+ @property
44
+ def is_trainable(self) -> bool:
45
+ return self._is_trainable
46
+
47
+ @property
48
+ def ucg_rate(self) -> Union[float, torch.Tensor]:
49
+ return self._ucg_rate
50
+
51
+ @property
52
+ def input_key(self) -> str:
53
+ return self._input_key
54
+
55
+ @is_trainable.setter
56
+ def is_trainable(self, value: bool):
57
+ self._is_trainable = value
58
+
59
+ @ucg_rate.setter
60
+ def ucg_rate(self, value: Union[float, torch.Tensor]):
61
+ self._ucg_rate = value
62
+
63
+ @input_key.setter
64
+ def input_key(self, value: str):
65
+ self._input_key = value
66
+
67
+ @is_trainable.deleter
68
+ def is_trainable(self):
69
+ del self._is_trainable
70
+
71
+ @ucg_rate.deleter
72
+ def ucg_rate(self):
73
+ del self._ucg_rate
74
+
75
+ @input_key.deleter
76
+ def input_key(self):
77
+ del self._input_key
78
+
79
+
80
+ class GeneralConditioner(nn.Module):
81
+ OUTPUT_DIM2KEYS = {2: "vector", 3: "crossattn", 4: "concat", 5: "concat"}
82
+ KEY2CATDIM = {"vector": 1, "crossattn": 2, "concat": 1}
83
+
84
+ def __init__(self, emb_models: Union[List, ListConfig]):
85
+ super().__init__()
86
+ embedders = []
87
+ for n, embconfig in enumerate(emb_models):
88
+ embedder = instantiate_from_config(embconfig)
89
+ assert isinstance(
90
+ embedder, AbstractEmbModel
91
+ ), f"embedder model {embedder.__class__.__name__} has to inherit from AbstractEmbModel"
92
+ embedder.is_trainable = embconfig.get("is_trainable", False)
93
+ embedder.ucg_rate = embconfig.get("ucg_rate", 0.0)
94
+ if not embedder.is_trainable:
95
+ embedder.train = disabled_train
96
+ for param in embedder.parameters():
97
+ param.requires_grad = False
98
+ embedder.eval()
99
+ print(
100
+ f"Initialized embedder #{n}: {embedder.__class__.__name__} "
101
+ f"with {count_params(embedder, False)} params. Trainable: {embedder.is_trainable}"
102
+ )
103
+
104
+ if "input_key" in embconfig:
105
+ embedder.input_key = embconfig["input_key"]
106
+ elif "input_keys" in embconfig:
107
+ embedder.input_keys = embconfig["input_keys"]
108
+ else:
109
+ raise KeyError(
110
+ f"need either 'input_key' or 'input_keys' for embedder {embedder.__class__.__name__}"
111
+ )
112
+
113
+ embedder.legacy_ucg_val = embconfig.get("legacy_ucg_value", None)
114
+ if embedder.legacy_ucg_val is not None:
115
+ embedder.ucg_prng = np.random.RandomState()
116
+
117
+ embedders.append(embedder)
118
+ self.embedders = nn.ModuleList(embedders)
119
+
120
+ def possibly_get_ucg_val(self, embedder: AbstractEmbModel, batch: Dict) -> Dict:
121
+ assert embedder.legacy_ucg_val is not None
122
+ p = embedder.ucg_rate
123
+ val = embedder.legacy_ucg_val
124
+ for i in range(len(batch[embedder.input_key])):
125
+ if embedder.ucg_prng.choice(2, p=[1 - p, p]):
126
+ batch[embedder.input_key][i] = val
127
+ return batch
128
+
129
+ def forward(
130
+ self, batch: Dict, force_zero_embeddings: Optional[List] = None
131
+ ) -> Dict:
132
+ output = dict()
133
+ if force_zero_embeddings is None:
134
+ force_zero_embeddings = []
135
+ for embedder in self.embedders:
136
+ embedding_context = nullcontext if embedder.is_trainable else torch.no_grad
137
+ with embedding_context():
138
+ if hasattr(embedder, "input_key") and (embedder.input_key is not None):
139
+ if embedder.legacy_ucg_val is not None:
140
+ batch = self.possibly_get_ucg_val(embedder, batch)
141
+ emb_out = embedder(batch[embedder.input_key])
142
+ elif hasattr(embedder, "input_keys"):
143
+ emb_out = embedder(*[batch[k] for k in embedder.input_keys])
144
+ assert isinstance(
145
+ emb_out, (torch.Tensor, list, tuple)
146
+ ), f"encoder outputs must be tensors or a sequence, but got {type(emb_out)}"
147
+ if not isinstance(emb_out, (list, tuple)):
148
+ emb_out = [emb_out]
149
+ for emb in emb_out:
150
+ out_key = self.OUTPUT_DIM2KEYS[emb.dim()]
151
+ if embedder.ucg_rate > 0.0 and embedder.legacy_ucg_val is None:
152
+ emb = (
153
+ expand_dims_like(
154
+ torch.bernoulli(
155
+ (1.0 - embedder.ucg_rate)
156
+ * torch.ones(emb.shape[0], device=emb.device)
157
+ ),
158
+ emb,
159
+ )
160
+ * emb
161
+ )
162
+ if (
163
+ hasattr(embedder, "input_key")
164
+ and embedder.input_key in force_zero_embeddings
165
+ ):
166
+ emb = torch.zeros_like(emb)
167
+ if out_key in output:
168
+ output[out_key] = torch.cat(
169
+ (output[out_key], emb), self.KEY2CATDIM[out_key]
170
+ )
171
+ else:
172
+ output[out_key] = emb
173
+ return output
174
+
175
+ def get_unconditional_conditioning(
176
+ self, batch_c, batch_uc=None, force_uc_zero_embeddings=None
177
+ ):
178
+ if force_uc_zero_embeddings is None:
179
+ force_uc_zero_embeddings = []
180
+ ucg_rates = list()
181
+ for embedder in self.embedders:
182
+ ucg_rates.append(embedder.ucg_rate)
183
+ embedder.ucg_rate = 0.0
184
+ c = self(batch_c)
185
+ uc = self(batch_c if batch_uc is None else batch_uc, force_uc_zero_embeddings)
186
+
187
+ for embedder, rate in zip(self.embedders, ucg_rates):
188
+ embedder.ucg_rate = rate
189
+ return c, uc
190
+
191
+
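GeneralConditioner above aggregates several embedders into the conditioning dict consumed by the diffusion model, routing each embedder output to "vector", "crossattn" or "concat" by its dimensionality. An illustrative configuration sketch — the target paths are classes defined later in this file, but the input keys, parameters and ucg rates are placeholders, and a CUDA device is assumed since the embedders default to device="cuda":

import torch
from sgm.modules.encoders.modules import GeneralConditioner

emb_models = [
    {
        "target": "sgm.modules.encoders.modules.FrozenCLIPEmbedder",
        "is_trainable": False,
        "ucg_rate": 0.1,                     # drop the text prompt 10% of the time
        "input_key": "txt",
        "params": {"layer": "hidden", "layer_idx": 11},
    },
    {
        "target": "sgm.modules.encoders.modules.ConcatTimestepEmbedderND",
        "input_key": "size",                 # placeholder batch key
        "params": {"outdim": 256},
    },
]

conditioner = GeneralConditioner(emb_models).cuda()
batch = {"txt": ["a photo of a cat"], "size": torch.tensor([[1024, 1024]])}
c, uc = conditioner.get_unconditional_conditioning(
    batch, force_uc_zero_embeddings=["txt"]
)
# c["crossattn"]: CLIP token embeddings; c["vector"]: concatenated Fourier size embeddings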
192
+ class InceptionV3(nn.Module):
193
+ """Wrapper around the https://github.com/mseitzer/pytorch-fid inception
194
+ port with an additional squeeze at the end"""
195
+
196
+ def __init__(self, normalize_input=False, **kwargs):
197
+ super().__init__()
198
+ from pytorch_fid import inception
199
+
200
+ kwargs["resize_input"] = True
201
+ self.model = inception.InceptionV3(normalize_input=normalize_input, **kwargs)
202
+
203
+ def forward(self, inp):
204
+ # inp = kornia.geometry.resize(inp, (299, 299),
205
+ # interpolation='bicubic',
206
+ # align_corners=False,
207
+ # antialias=True)
208
+ # inp = inp.clamp(min=-1, max=1)
209
+
210
+ outp = self.model(inp)
211
+
212
+ if len(outp) == 1:
213
+ return outp[0].squeeze()
214
+
215
+ return outp
216
+
217
+
218
+ class IdentityEncoder(AbstractEmbModel):
219
+ def encode(self, x):
220
+ return x
221
+
222
+ def forward(self, x):
223
+ return x
224
+
225
+
226
+ class ClassEmbedder(AbstractEmbModel):
227
+ def __init__(self, embed_dim, n_classes=1000, add_sequence_dim=False):
228
+ super().__init__()
229
+ self.embedding = nn.Embedding(n_classes, embed_dim)
230
+ self.n_classes = n_classes
231
+ self.add_sequence_dim = add_sequence_dim
232
+
233
+ def forward(self, c):
234
+ c = self.embedding(c)
235
+ if self.add_sequence_dim:
236
+ c = c[:, None, :]
237
+ return c
238
+
239
+ def get_unconditional_conditioning(self, bs, device="cuda"):
240
+ uc_class = (
241
+ self.n_classes - 1
242
+ ) # 1000 classes --> 0 ... 999, one extra class for ucg (class 1000)
243
+ uc = torch.ones((bs,), device=device) * uc_class
244
+ uc = {self.key: uc.long()}
245
+ return uc
246
+
247
+
248
+ class ClassEmbedderForMultiCond(ClassEmbedder):
249
+ def forward(self, batch, key=None, disable_dropout=False):
250
+ out = batch
251
+ key = default(key, self.key)
252
+ islist = isinstance(batch[key], list)
253
+ if islist:
254
+ batch[key] = batch[key][0]
255
+ c_out = super().forward(batch, key, disable_dropout)
256
+ out[key] = [c_out] if islist else c_out
257
+ return out
258
+
259
+
260
+ class FrozenT5Embedder(AbstractEmbModel):
261
+ """Uses the T5 transformer encoder for text"""
262
+
263
+ def __init__(
264
+ self, version="google/t5-v1_1-xxl", device="cuda", max_length=77, freeze=True
265
+ ): # others are google/t5-v1_1-xl and google/t5-v1_1-xxl
266
+ super().__init__()
267
+ self.tokenizer = T5Tokenizer.from_pretrained(version)
268
+ self.transformer = T5EncoderModel.from_pretrained(version)
269
+ self.device = device
270
+ self.max_length = max_length
271
+ if freeze:
272
+ self.freeze()
273
+
274
+ def freeze(self):
275
+ self.transformer = self.transformer.eval()
276
+
277
+ for param in self.parameters():
278
+ param.requires_grad = False
279
+
280
+ # @autocast
281
+ def forward(self, text):
282
+ batch_encoding = self.tokenizer(
283
+ text,
284
+ truncation=True,
285
+ max_length=self.max_length,
286
+ return_length=True,
287
+ return_overflowing_tokens=False,
288
+ padding="max_length",
289
+ return_tensors="pt",
290
+ )
291
+ tokens = batch_encoding["input_ids"].to(self.device)
292
+ with torch.autocast("cuda", enabled=False):
293
+ outputs = self.transformer(input_ids=tokens)
294
+ z = outputs.last_hidden_state
295
+ return z
296
+
297
+ def encode(self, text):
298
+ return self(text)
299
+
300
+
301
+ class FrozenByT5Embedder(AbstractEmbModel):
302
+ """
303
+ Uses the ByT5 transformer encoder for text. Is character-aware.
304
+ """
305
+
306
+ def __init__(
307
+ self, version="google/byt5-base", device="cuda", max_length=77, freeze=True
308
+ ): # others are e.g. google/byt5-small and google/byt5-large
309
+ super().__init__()
310
+ self.tokenizer = ByT5Tokenizer.from_pretrained(version)
311
+ self.transformer = T5EncoderModel.from_pretrained(version)
312
+ self.device = device
313
+ self.max_length = max_length
314
+ if freeze:
315
+ self.freeze()
316
+
317
+ def freeze(self):
318
+ self.transformer = self.transformer.eval()
319
+
320
+ for param in self.parameters():
321
+ param.requires_grad = False
322
+
323
+ def forward(self, text):
324
+ batch_encoding = self.tokenizer(
325
+ text,
326
+ truncation=True,
327
+ max_length=self.max_length,
328
+ return_length=True,
329
+ return_overflowing_tokens=False,
330
+ padding="max_length",
331
+ return_tensors="pt",
332
+ )
333
+ tokens = batch_encoding["input_ids"].to(self.device)
334
+ with torch.autocast("cuda", enabled=False):
335
+ outputs = self.transformer(input_ids=tokens)
336
+ z = outputs.last_hidden_state
337
+ return z
338
+
339
+ def encode(self, text):
340
+ return self(text)
341
+
342
+
343
+ class FrozenCLIPEmbedder(AbstractEmbModel):
344
+ """Uses the CLIP transformer encoder for text (from huggingface)"""
345
+
346
+ LAYERS = ["last", "pooled", "hidden"]
347
+
348
+ def __init__(
349
+ self,
350
+ version="openai/clip-vit-large-patch14",
351
+ device="cuda",
352
+ max_length=77,
353
+ freeze=True,
354
+ layer="last",
355
+ layer_idx=None,
356
+ always_return_pooled=False,
357
+ ): # clip-vit-base-patch32
358
+ super().__init__()
359
+ assert layer in self.LAYERS
360
+ self.tokenizer = CLIPTokenizer.from_pretrained(version)
361
+ self.transformer = CLIPTextModel.from_pretrained(version)
362
+ self.device = device
363
+ self.max_length = max_length
364
+ if freeze:
365
+ self.freeze()
366
+ self.layer = layer
367
+ self.layer_idx = layer_idx
368
+ self.return_pooled = always_return_pooled
369
+ if layer == "hidden":
370
+ assert layer_idx is not None
371
+ assert 0 <= abs(layer_idx) <= 12
372
+
373
+ def freeze(self):
374
+ self.transformer = self.transformer.eval()
375
+
376
+ for param in self.parameters():
377
+ param.requires_grad = False
378
+
379
+ @autocast
380
+ def forward(self, text):
381
+ batch_encoding = self.tokenizer(
382
+ text,
383
+ truncation=True,
384
+ max_length=self.max_length,
385
+ return_length=True,
386
+ return_overflowing_tokens=False,
387
+ padding="max_length",
388
+ return_tensors="pt",
389
+ )
390
+ tokens = batch_encoding["input_ids"].to(self.device)
391
+ outputs = self.transformer(
392
+ input_ids=tokens, output_hidden_states=self.layer == "hidden"
393
+ )
394
+ if self.layer == "last":
395
+ z = outputs.last_hidden_state
396
+ elif self.layer == "pooled":
397
+ z = outputs.pooler_output[:, None, :]
398
+ else:
399
+ z = outputs.hidden_states[self.layer_idx]
400
+ if self.return_pooled:
401
+ return z, outputs.pooler_output
402
+ return z
403
+
404
+ def encode(self, text):
405
+ return self(text)
406
+
407
+
408
+ class FrozenOpenCLIPEmbedder2(AbstractEmbModel):
409
+ """
410
+ Uses the OpenCLIP transformer encoder for text
411
+ """
412
+
413
+ LAYERS = ["pooled", "last", "penultimate"]
414
+
415
+ def __init__(
416
+ self,
417
+ arch="ViT-H-14",
418
+ version="laion2b_s32b_b79k",
419
+ device="cuda",
420
+ max_length=77,
421
+ freeze=True,
422
+ layer="last",
423
+ always_return_pooled=False,
424
+ legacy=True,
425
+ ):
426
+ super().__init__()
427
+ assert layer in self.LAYERS
428
+ model, _, _ = open_clip.create_model_and_transforms(
429
+ arch,
430
+ device=torch.device("cpu"),
431
+ pretrained=version,
432
+ )
433
+ del model.visual
434
+ self.model = model
435
+
436
+ self.device = device
437
+ self.max_length = max_length
438
+ self.return_pooled = always_return_pooled
439
+ if freeze:
440
+ self.freeze()
441
+ self.layer = layer
442
+ if self.layer == "last":
443
+ self.layer_idx = 0
444
+ elif self.layer == "penultimate":
445
+ self.layer_idx = 1
446
+ else:
447
+ raise NotImplementedError()
448
+ self.legacy = legacy
449
+
450
+ def freeze(self):
451
+ self.model = self.model.eval()
452
+ for param in self.parameters():
453
+ param.requires_grad = False
454
+
455
+ @autocast
456
+ def forward(self, text):
457
+ tokens = open_clip.tokenize(text)
458
+ z = self.encode_with_transformer(tokens.to(self.device))
459
+ if not self.return_pooled and self.legacy:
460
+ return z
461
+ if self.return_pooled:
462
+ assert not self.legacy
463
+ return z[self.layer], z["pooled"]
464
+ return z[self.layer]
465
+
466
+ def encode_with_transformer(self, text):
467
+ x = self.model.token_embedding(text) # [batch_size, n_ctx, d_model]
468
+ x = x + self.model.positional_embedding
469
+ x = x.permute(1, 0, 2) # NLD -> LND
470
+ x = self.text_transformer_forward(x, attn_mask=self.model.attn_mask)
471
+ if self.legacy:
472
+ x = x[self.layer]
473
+ x = self.model.ln_final(x)
474
+ return x
475
+ else:
476
+ # x is a dict and will stay a dict
477
+ o = x["last"]
478
+ o = self.model.ln_final(o)
479
+ pooled = self.pool(o, text)
480
+ x["pooled"] = pooled
481
+ return x
482
+
483
+ def pool(self, x, text):
484
+ # take features from the eot embedding (eot_token is the highest number in each sequence)
485
+ x = (
486
+ x[torch.arange(x.shape[0]), text.argmax(dim=-1)]
487
+ @ self.model.text_projection
488
+ )
489
+ return x
490
+
491
+ def text_transformer_forward(self, x: torch.Tensor, attn_mask=None):
492
+ outputs = {}
493
+ for i, r in enumerate(self.model.transformer.resblocks):
494
+ if i == len(self.model.transformer.resblocks) - 1:
495
+ outputs["penultimate"] = x.permute(1, 0, 2) # LND -> NLD
496
+ if (
497
+ self.model.transformer.grad_checkpointing
498
+ and not torch.jit.is_scripting()
499
+ ):
500
+ x = checkpoint(r, x, attn_mask)
501
+ else:
502
+ x = r(x, attn_mask=attn_mask)
503
+ outputs["last"] = x.permute(1, 0, 2) # LND -> NLD
504
+ return outputs
505
+
506
+ def encode(self, text):
507
+ return self(text)
508
+
509
+
510
+ class FrozenOpenCLIPEmbedder(AbstractEmbModel):
511
+ LAYERS = [
512
+ # "pooled",
513
+ "last",
514
+ "penultimate",
515
+ ]
516
+
517
+ def __init__(
518
+ self,
519
+ arch="ViT-H-14",
520
+ version="laion2b_s32b_b79k",
521
+ device="cuda",
522
+ max_length=77,
523
+ freeze=True,
524
+ layer="last",
525
+ ):
526
+ super().__init__()
527
+ assert layer in self.LAYERS
528
+ model, _, _ = open_clip.create_model_and_transforms(
529
+ arch, device=torch.device("cpu"), pretrained=version
530
+ )
531
+ del model.visual
532
+ self.model = model
533
+
534
+ self.device = device
535
+ self.max_length = max_length
536
+ if freeze:
537
+ self.freeze()
538
+ self.layer = layer
539
+ if self.layer == "last":
540
+ self.layer_idx = 0
541
+ elif self.layer == "penultimate":
542
+ self.layer_idx = 1
543
+ else:
544
+ raise NotImplementedError()
545
+
546
+ def freeze(self):
547
+ self.model = self.model.eval()
548
+ for param in self.parameters():
549
+ param.requires_grad = False
550
+
551
+ def forward(self, text):
552
+ tokens = open_clip.tokenize(text)
553
+ z = self.encode_with_transformer(tokens.to(self.device))
554
+ return z
555
+
556
+ def encode_with_transformer(self, text):
557
+ x = self.model.token_embedding(text) # [batch_size, n_ctx, d_model]
558
+ x = x + self.model.positional_embedding
559
+ x = x.permute(1, 0, 2) # NLD -> LND
560
+ x = self.text_transformer_forward(x, attn_mask=self.model.attn_mask)
561
+ x = x.permute(1, 0, 2) # LND -> NLD
562
+ x = self.model.ln_final(x)
563
+ return x
564
+
565
+ def text_transformer_forward(self, x: torch.Tensor, attn_mask=None):
566
+ for i, r in enumerate(self.model.transformer.resblocks):
567
+ if i == len(self.model.transformer.resblocks) - self.layer_idx:
568
+ break
569
+ if (
570
+ self.model.transformer.grad_checkpointing
571
+ and not torch.jit.is_scripting()
572
+ ):
573
+ x = checkpoint(r, x, attn_mask)
574
+ else:
575
+ x = r(x, attn_mask=attn_mask)
576
+ return x
577
+
578
+ def encode(self, text):
579
+ return self(text)
580
+
581
+
582
+ class FrozenOpenCLIPImageEmbedder(AbstractEmbModel):
583
+ """
584
+ Uses the OpenCLIP vision transformer encoder for images
585
+ """
586
+
587
+ def __init__(
588
+ self,
589
+ arch="ViT-H-14",
590
+ version="laion2b_s32b_b79k",
591
+ device="cuda",
592
+ max_length=77,
593
+ freeze=True,
594
+ antialias=True,
595
+ ucg_rate=0.0,
596
+ unsqueeze_dim=False,
597
+ repeat_to_max_len=False,
598
+ num_image_crops=0,
599
+ output_tokens=False,
600
+ ):
601
+ super().__init__()
602
+ model, _, _ = open_clip.create_model_and_transforms(
603
+ arch,
604
+ device=torch.device("cpu"),
605
+ pretrained=version,
606
+ )
607
+ del model.transformer
608
+ self.model = model
609
+ self.max_crops = num_image_crops
610
+ self.pad_to_max_len = self.max_crops > 0
611
+ self.repeat_to_max_len = repeat_to_max_len and (not self.pad_to_max_len)
612
+ self.device = device
613
+ self.max_length = max_length
614
+ if freeze:
615
+ self.freeze()
616
+
617
+ self.antialias = antialias
618
+
619
+ self.register_buffer(
620
+ "mean", torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False
621
+ )
622
+ self.register_buffer(
623
+ "std", torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False
624
+ )
625
+ self.ucg_rate = ucg_rate
626
+ self.unsqueeze_dim = unsqueeze_dim
627
+ self.stored_batch = None
628
+ self.model.visual.output_tokens = output_tokens
629
+ self.output_tokens = output_tokens
630
+
631
+ def preprocess(self, x):
632
+ # normalize to [0,1]
633
+ x = kornia.geometry.resize(
634
+ x,
635
+ (224, 224),
636
+ interpolation="bicubic",
637
+ align_corners=True,
638
+ antialias=self.antialias,
639
+ )
640
+ x = (x + 1.0) / 2.0
641
+ # renormalize according to clip
642
+ x = kornia.enhance.normalize(x, self.mean, self.std)
643
+ return x
644
+
645
+ def freeze(self):
646
+ self.model = self.model.eval()
647
+ for param in self.parameters():
648
+ param.requires_grad = False
649
+
650
+ @autocast
651
+ def forward(self, image, no_dropout=False):
652
+ z = self.encode_with_vision_transformer(image)
653
+ tokens = None
654
+ if self.output_tokens:
655
+ z, tokens = z[0], z[1]
656
+ z = z.to(image.dtype)
657
+ if self.ucg_rate > 0.0 and not no_dropout and not (self.max_crops > 0):
658
+ z = (
659
+ torch.bernoulli(
660
+ (1.0 - self.ucg_rate) * torch.ones(z.shape[0], device=z.device)
661
+ )[:, None]
662
+ * z
663
+ )
664
+ if tokens is not None:
665
+ tokens = (
666
+ expand_dims_like(
667
+ torch.bernoulli(
668
+ (1.0 - self.ucg_rate)
669
+ * torch.ones(tokens.shape[0], device=tokens.device)
670
+ ),
671
+ tokens,
672
+ )
673
+ * tokens
674
+ )
675
+ if self.unsqueeze_dim:
676
+ z = z[:, None, :]
677
+ if self.output_tokens:
678
+ assert not self.repeat_to_max_len
679
+ assert not self.pad_to_max_len
680
+ return tokens, z
681
+ if self.repeat_to_max_len:
682
+ if z.dim() == 2:
683
+ z_ = z[:, None, :]
684
+ else:
685
+ z_ = z
686
+ return repeat(z_, "b 1 d -> b n d", n=self.max_length), z
687
+ elif self.pad_to_max_len:
688
+ assert z.dim() == 3
689
+ z_pad = torch.cat(
690
+ (
691
+ z,
692
+ torch.zeros(
693
+ z.shape[0],
694
+ self.max_length - z.shape[1],
695
+ z.shape[2],
696
+ device=z.device,
697
+ ),
698
+ ),
699
+ 1,
700
+ )
701
+ return z_pad, z_pad[:, 0, ...]
702
+ return z
703
+
704
+ def encode_with_vision_transformer(self, img):
705
+ # if self.max_crops > 0:
706
+ # img = self.preprocess_by_cropping(img)
707
+ if img.dim() == 5:
708
+ assert self.max_crops == img.shape[1]
709
+ img = rearrange(img, "b n c h w -> (b n) c h w")
710
+ img = self.preprocess(img)
711
+ if not self.output_tokens:
712
+ assert not self.model.visual.output_tokens
713
+ x = self.model.visual(img)
714
+ tokens = None
715
+ else:
716
+ assert self.model.visual.output_tokens
717
+ x, tokens = self.model.visual(img)
718
+ if self.max_crops > 0:
719
+ x = rearrange(x, "(b n) d -> b n d", n=self.max_crops)
720
+ # drop out between 0 and all along the sequence axis
721
+ x = (
722
+ torch.bernoulli(
723
+ (1.0 - self.ucg_rate)
724
+ * torch.ones(x.shape[0], x.shape[1], 1, device=x.device)
725
+ )
726
+ * x
727
+ )
728
+ if tokens is not None:
729
+ tokens = rearrange(tokens, "(b n) t d -> b t (n d)", n=self.max_crops)
730
+ print(
731
+ f"You are running very experimental token-concat in {self.__class__.__name__}. "
732
+ f"Check what you are doing, and then remove this message."
733
+ )
734
+ if self.output_tokens:
735
+ return x, tokens
736
+ return x
737
+
738
+ def encode(self, text):
739
+ return self(text)
740
+
741
+
742
+ class FrozenCLIPT5Encoder(AbstractEmbModel):
743
+ def __init__(
744
+ self,
745
+ clip_version="openai/clip-vit-large-patch14",
746
+ t5_version="google/t5-v1_1-xl",
747
+ device="cuda",
748
+ clip_max_length=77,
749
+ t5_max_length=77,
750
+ ):
751
+ super().__init__()
752
+ self.clip_encoder = FrozenCLIPEmbedder(
753
+ clip_version, device, max_length=clip_max_length
754
+ )
755
+ self.t5_encoder = FrozenT5Embedder(t5_version, device, max_length=t5_max_length)
756
+ print(
757
+ f"{self.clip_encoder.__class__.__name__} has {count_params(self.clip_encoder) * 1.e-6:.2f} M parameters, "
758
+ f"{self.t5_encoder.__class__.__name__} comes with {count_params(self.t5_encoder) * 1.e-6:.2f} M params."
759
+ )
760
+
761
+ def encode(self, text):
762
+ return self(text)
763
+
764
+ def forward(self, text):
765
+ clip_z = self.clip_encoder.encode(text)
766
+ t5_z = self.t5_encoder.encode(text)
767
+ return [clip_z, t5_z]
768
+
769
+
770
+ class SpatialRescaler(nn.Module):
771
+ def __init__(
772
+ self,
773
+ n_stages=1,
774
+ method="bilinear",
775
+ multiplier=0.5,
776
+ in_channels=3,
777
+ out_channels=None,
778
+ bias=False,
779
+ wrap_video=False,
780
+ kernel_size=1,
781
+ remap_output=False,
782
+ ):
783
+ super().__init__()
784
+ self.n_stages = n_stages
785
+ assert self.n_stages >= 0
786
+ assert method in [
787
+ "nearest",
788
+ "linear",
789
+ "bilinear",
790
+ "trilinear",
791
+ "bicubic",
792
+ "area",
793
+ ]
794
+ self.multiplier = multiplier
795
+ self.interpolator = partial(torch.nn.functional.interpolate, mode=method)
796
+ self.remap_output = out_channels is not None or remap_output
797
+ if self.remap_output:
798
+ print(
799
+ f"Spatial Rescaler mapping from {in_channels} to {out_channels} channels after resizing."
800
+ )
801
+ self.channel_mapper = nn.Conv2d(
802
+ in_channels,
803
+ out_channels,
804
+ kernel_size=kernel_size,
805
+ bias=bias,
806
+ padding=kernel_size // 2,
807
+ )
808
+ self.wrap_video = wrap_video
809
+
810
+ def forward(self, x):
811
+ if self.wrap_video and x.ndim == 5:
812
+ B, C, T, H, W = x.shape
813
+ x = rearrange(x, "b c t h w -> b t c h w")
814
+ x = rearrange(x, "b t c h w -> (b t) c h w")
815
+
816
+ for stage in range(self.n_stages):
817
+ x = self.interpolator(x, scale_factor=self.multiplier)
818
+
819
+ if self.wrap_video:
820
+ x = rearrange(x, "(b t) c h w -> b t c h w", b=B, t=T, c=C)
821
+ x = rearrange(x, "b t c h w -> b c t h w")
822
+ if self.remap_output:
823
+ x = self.channel_mapper(x)
824
+ return x
825
+
826
+ def encode(self, x):
827
+ return self(x)
828
+
829
+
830
+ class LowScaleEncoder(nn.Module):
831
+ def __init__(
832
+ self,
833
+ model_config,
834
+ linear_start,
835
+ linear_end,
836
+ timesteps=1000,
837
+ max_noise_level=250,
838
+ output_size=64,
839
+ scale_factor=1.0,
840
+ ):
841
+ super().__init__()
842
+ self.max_noise_level = max_noise_level
843
+ self.model = instantiate_from_config(model_config)
844
+ self.augmentation_schedule = self.register_schedule(
845
+ timesteps=timesteps, linear_start=linear_start, linear_end=linear_end
846
+ )
847
+ self.out_size = output_size
848
+ self.scale_factor = scale_factor
849
+
850
+ def register_schedule(
851
+ self,
852
+ beta_schedule="linear",
853
+ timesteps=1000,
854
+ linear_start=1e-4,
855
+ linear_end=2e-2,
856
+ cosine_s=8e-3,
857
+ ):
858
+ betas = make_beta_schedule(
859
+ beta_schedule,
860
+ timesteps,
861
+ linear_start=linear_start,
862
+ linear_end=linear_end,
863
+ cosine_s=cosine_s,
864
+ )
865
+ alphas = 1.0 - betas
866
+ alphas_cumprod = np.cumprod(alphas, axis=0)
867
+ alphas_cumprod_prev = np.append(1.0, alphas_cumprod[:-1])
868
+
869
+ (timesteps,) = betas.shape
870
+ self.num_timesteps = int(timesteps)
871
+ self.linear_start = linear_start
872
+ self.linear_end = linear_end
873
+ assert (
874
+ alphas_cumprod.shape[0] == self.num_timesteps
875
+ ), "alphas have to be defined for each timestep"
876
+
877
+ to_torch = partial(torch.tensor, dtype=torch.float32)
878
+
879
+ self.register_buffer("betas", to_torch(betas))
880
+ self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod))
881
+ self.register_buffer("alphas_cumprod_prev", to_torch(alphas_cumprod_prev))
882
+
883
+ # calculations for diffusion q(x_t | x_{t-1}) and others
884
+ self.register_buffer("sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod)))
885
+ self.register_buffer(
886
+ "sqrt_one_minus_alphas_cumprod", to_torch(np.sqrt(1.0 - alphas_cumprod))
887
+ )
888
+ self.register_buffer(
889
+ "log_one_minus_alphas_cumprod", to_torch(np.log(1.0 - alphas_cumprod))
890
+ )
891
+ self.register_buffer(
892
+ "sqrt_recip_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod))
893
+ )
894
+ self.register_buffer(
895
+ "sqrt_recipm1_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod - 1))
896
+ )
897
+
898
+ def q_sample(self, x_start, t, noise=None):
899
+ noise = default(noise, lambda: torch.randn_like(x_start))
900
+ return (
901
+ extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
902
+ + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape)
903
+ * noise
904
+ )
905
+
906
+ def forward(self, x):
907
+ z = self.model.encode(x)
908
+ if isinstance(z, DiagonalGaussianDistribution):
909
+ z = z.sample()
910
+ z = z * self.scale_factor
911
+ noise_level = torch.randint(
912
+ 0, self.max_noise_level, (x.shape[0],), device=x.device
913
+ ).long()
914
+ z = self.q_sample(z, noise_level)
915
+ if self.out_size is not None:
916
+ z = torch.nn.functional.interpolate(z, size=self.out_size, mode="nearest")
917
+ # z = z.repeat_interleave(2, -2).repeat_interleave(2, -1)
918
+ return z, noise_level
919
+
920
+ def decode(self, z):
921
+ z = z / self.scale_factor
922
+ return self.model.decode(z)
923
+
924
+
925
+ class ConcatTimestepEmbedderND(AbstractEmbModel):
926
+ """embeds each dimension independently and concatenates them"""
927
+
928
+ def __init__(self, outdim):
929
+ super().__init__()
930
+ self.timestep = Timestep(outdim)
931
+ self.outdim = outdim
932
+
933
+ def forward(self, x):
934
+ if x.ndim == 1:
935
+ x = x[:, None]
936
+ assert len(x.shape) == 2
937
+ b, dims = x.shape[0], x.shape[1]
938
+ x = rearrange(x, "b d -> (b d)")
939
+ emb = self.timestep(x)
940
+ emb = rearrange(emb, "(b d) d2 -> b (d d2)", b=b, d=dims, d2=self.outdim)
941
+ return emb
942
+
943
+
944
+ class GaussianEncoder(Encoder, AbstractEmbModel):
945
+ def __init__(
946
+ self, weight: float = 1.0, flatten_output: bool = True, *args, **kwargs
947
+ ):
948
+ super().__init__(*args, **kwargs)
949
+ self.posterior = DiagonalGaussianRegularizer()
950
+ self.weight = weight
951
+ self.flatten_output = flatten_output
952
+
953
+ def forward(self, x) -> Tuple[Dict, torch.Tensor]:
954
+ z = super().forward(x)
955
+ z, log = self.posterior(z)
956
+ log["loss"] = log["kl_loss"]
957
+ log["weight"] = self.weight
958
+ if self.flatten_output:
959
+ z = rearrange(z, "b c h w -> b (h w ) c")
960
+ return log, z
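Among the embedders above, ConcatTimestepEmbedderND is the simplest to exercise in isolation: it Fourier-embeds every scalar of its input independently and concatenates the results, so 2-D outputs land in the "vector" conditioning per GeneralConditioner.OUTPUT_DIM2KEYS. A small shape sketch (runs on CPU, no pretrained weights needed; the size values are illustrative):

import torch
from sgm.modules.encoders.modules import ConcatTimestepEmbedderND

embedder = ConcatTimestepEmbedderND(outdim=256)
sizes = torch.tensor([[1024, 1024], [768, 512]])   # batch of two scalar pairs
emb = embedder(sizes)
print(emb.shape)                                    # torch.Size([2, 512]) -- 2 scalars x 256 dims each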
sgm/util.py ADDED
@@ -0,0 +1,248 @@
1
+ import functools
2
+ import importlib
3
+ import os
4
+ from functools import partial
5
+ from inspect import isfunction
6
+
7
+ import fsspec
8
+ import numpy as np
9
+ import torch
10
+ from PIL import Image, ImageDraw, ImageFont
11
+ from safetensors.torch import load_file as load_safetensors
12
+
13
+
14
+ def disabled_train(self, mode=True):
15
+ """Overwrite model.train with this function to make sure train/eval mode
16
+ does not change anymore."""
17
+ return self
18
+
19
+
20
+ def get_string_from_tuple(s):
21
+ try:
22
+ # Check if the string starts and ends with parentheses
23
+ if s[0] == "(" and s[-1] == ")":
24
+ # Convert the string to a tuple
25
+ t = eval(s)
26
+ # Check if the type of t is tuple
27
+ if type(t) == tuple:
28
+ return t[0]
29
+ else:
30
+ pass
31
+ except:
32
+ pass
33
+ return s
34
+
35
+
36
+ def is_power_of_two(n):
37
+ """
38
+ chat.openai.com/chat
39
+ Return True if n is a power of 2, otherwise return False.
40
+
41
+ The function is_power_of_two takes an integer n as input and returns True if n is a power of 2, otherwise it returns False.
42
+ The function works by first checking if n is less than or equal to 0. If n is less than or equal to 0, it can't be a power of 2, so the function returns False.
43
+ If n is greater than 0, the function checks whether n is a power of 2 by using a bitwise AND operation between n and n-1. If n is a power of 2, then it will have only one bit set to 1 in its binary representation. When we subtract 1 from a power of 2, all the bits to the right of that bit become 1, and the bit itself becomes 0. So, when we perform a bitwise AND between n and n-1, we get 0 if n is a power of 2, and a non-zero value otherwise.
44
+ Thus, if the result of the bitwise AND operation is 0, then n is a power of 2 and the function returns True. Otherwise, the function returns False.
45
+
46
+ """
47
+ if n <= 0:
48
+ return False
49
+ return (n & (n - 1)) == 0
50
+
51
+
52
+ def autocast(f, enabled=True):
53
+ def do_autocast(*args, **kwargs):
54
+ with torch.cuda.amp.autocast(
55
+ enabled=enabled,
56
+ dtype=torch.get_autocast_gpu_dtype(),
57
+ cache_enabled=torch.is_autocast_cache_enabled(),
58
+ ):
59
+ return f(*args, **kwargs)
60
+
61
+ return do_autocast
62
+
63
+
64
+ def load_partial_from_config(config):
65
+ return partial(get_obj_from_str(config["target"]), **config.get("params", dict()))
66
+
67
+
68
+ def log_txt_as_img(wh, xc, size=10):
69
+ # wh a tuple of (width, height)
70
+ # xc a list of captions to plot
71
+ b = len(xc)
72
+ txts = list()
73
+ for bi in range(b):
74
+ txt = Image.new("RGB", wh, color="white")
75
+ draw = ImageDraw.Draw(txt)
76
+ font = ImageFont.truetype("data/DejaVuSans.ttf", size=size)
77
+ nc = int(40 * (wh[0] / 256))
78
+ if isinstance(xc[bi], list):
79
+ text_seq = xc[bi][0]
80
+ else:
81
+ text_seq = xc[bi]
82
+ lines = "\n".join(
83
+ text_seq[start : start + nc] for start in range(0, len(text_seq), nc)
84
+ )
85
+
86
+ try:
87
+ draw.text((0, 0), lines, fill="black", font=font)
88
+ except UnicodeEncodeError:
89
+ print("Cant encode string for logging. Skipping.")
90
+
91
+ txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0
92
+ txts.append(txt)
93
+ txts = np.stack(txts)
94
+ txts = torch.tensor(txts)
95
+ return txts
96
+
97
+
98
+ def partialclass(cls, *args, **kwargs):
99
+ class NewCls(cls):
100
+ __init__ = functools.partialmethod(cls.__init__, *args, **kwargs)
101
+
102
+ return NewCls
103
+
104
+
105
+ def make_path_absolute(path):
106
+ fs, p = fsspec.core.url_to_fs(path)
107
+ if fs.protocol == "file":
108
+ return os.path.abspath(p)
109
+ return path
110
+
111
+
112
+ def ismap(x):
113
+ if not isinstance(x, torch.Tensor):
114
+ return False
115
+ return (len(x.shape) == 4) and (x.shape[1] > 3)
116
+
117
+
118
+ def isimage(x):
119
+ if not isinstance(x, torch.Tensor):
120
+ return False
121
+ return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1)
122
+
123
+
124
+ def isheatmap(x):
125
+ if not isinstance(x, torch.Tensor):
126
+ return False
127
+
128
+ return x.ndim == 2
129
+
130
+
131
+ def isneighbors(x):
132
+ if not isinstance(x, torch.Tensor):
133
+ return False
134
+ return x.ndim == 5 and (x.shape[2] == 3 or x.shape[2] == 1)
135
+
136
+
137
+ def exists(x):
138
+ return x is not None
139
+
140
+
141
+ def expand_dims_like(x, y):
142
+ while x.dim() != y.dim():
143
+ x = x.unsqueeze(-1)
144
+ return x
145
+
146
+
147
+ def default(val, d):
148
+ if exists(val):
149
+ return val
150
+ return d() if isfunction(d) else d
151
+
152
+
153
+ def mean_flat(tensor):
154
+ """
155
+ https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86
156
+ Take the mean over all non-batch dimensions.
157
+ """
158
+ return tensor.mean(dim=list(range(1, len(tensor.shape))))
159
+
160
+
161
+ def count_params(model, verbose=False):
162
+ total_params = sum(p.numel() for p in model.parameters())
163
+ if verbose:
164
+ print(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.")
165
+ return total_params
166
+
167
+
168
+ def instantiate_from_config(config):
169
+ if not "target" in config:
170
+ if config == "__is_first_stage__":
171
+ return None
172
+ elif config == "__is_unconditional__":
173
+ return None
174
+ raise KeyError("Expected key `target` to instantiate.")
175
+ return get_obj_from_str(config["target"])(**config.get("params", dict()))
176
+
177
+
178
+ def get_obj_from_str(string, reload=False, invalidate_cache=True):
179
+ module, cls = string.rsplit(".", 1)
180
+ if invalidate_cache:
181
+ importlib.invalidate_caches()
182
+ if reload:
183
+ module_imp = importlib.import_module(module)
184
+ importlib.reload(module_imp)
185
+ return getattr(importlib.import_module(module, package=None), cls)
186
+
187
+
188
+ def append_zero(x):
189
+ return torch.cat([x, x.new_zeros([1])])
190
+
191
+
192
+ def append_dims(x, target_dims):
193
+ """Appends dimensions to the end of a tensor until it has target_dims dimensions."""
194
+ dims_to_append = target_dims - x.ndim
195
+ if dims_to_append < 0:
196
+ raise ValueError(
197
+ f"input has {x.ndim} dims but target_dims is {target_dims}, which is less"
198
+ )
199
+ return x[(...,) + (None,) * dims_to_append]
200
+
201
+
202
+ def load_model_from_config(config, ckpt, verbose=True, freeze=True):
203
+ print(f"Loading model from {ckpt}")
204
+ if ckpt.endswith("ckpt"):
205
+ pl_sd = torch.load(ckpt, map_location="cpu")
206
+ if "global_step" in pl_sd:
207
+ print(f"Global Step: {pl_sd['global_step']}")
208
+ sd = pl_sd["state_dict"]
209
+ elif ckpt.endswith("safetensors"):
210
+ sd = load_safetensors(ckpt)
211
+ else:
212
+ raise NotImplementedError
213
+
214
+ model = instantiate_from_config(config.model)
215
+
216
+ m, u = model.load_state_dict(sd, strict=False)
217
+
218
+ if len(m) > 0 and verbose:
219
+ print("missing keys:")
220
+ print(m)
221
+ if len(u) > 0 and verbose:
222
+ print("unexpected keys:")
223
+ print(u)
224
+
225
+ if freeze:
226
+ for param in model.parameters():
227
+ param.requires_grad = False
228
+
229
+ model.eval()
230
+ return model
231
+
232
+
233
+ def get_configs_path() -> str:
234
+ """
235
+ Get the `configs` directory.
236
+ For a working copy, this is the one in the root of the repository,
237
+ but for an installed copy, it's in the `sgm` package (see pyproject.toml).
238
+ """
239
+ this_dir = os.path.dirname(__file__)
240
+ candidates = (
241
+ os.path.join(this_dir, "configs"),
242
+ os.path.join(this_dir, "..", "configs"),
243
+ )
244
+ for candidate in candidates:
245
+ candidate = os.path.abspath(candidate)
246
+ if os.path.isdir(candidate):
247
+ return candidate
248
+ raise FileNotFoundError(f"Could not find SGM configs in {candidates}")
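A few of these helpers in action — instantiate_from_config resolves a dotted target path and applies params, append_dims right-pads a per-sample tensor for broadcasting, and default falls back to a value or callable. The target class below is only an illustrative stand-in:

import torch
from sgm.util import append_dims, default, instantiate_from_config

# build any object from a {"target", "params"} dict
norm = instantiate_from_config(
    {"target": "torch.nn.LayerNorm", "params": {"normalized_shape": 16}}
)

sigma = torch.tensor([0.5, 2.0])              # one value per batch element, shape (2,)
x = torch.randn(2, 4, 8, 8)
scaled = x * append_dims(sigma, x.ndim)       # sigma broadcast as shape (2, 1, 1, 1)

noise = default(None, lambda: torch.zeros_like(x))   # None -> evaluate the fallback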