image-to-3d
Chao Xu committed
Commit 169a228
1 Parent(s): 53ab577

add taming

Files changed (45)
  1. .gitignore +2 -1
  2. taming-transformers/.gitignore +2 -0
  3. taming-transformers/License.txt +19 -0
  4. taming-transformers/README.md +410 -0
  5. taming-transformers/configs/coco_cond_stage.yaml +49 -0
  6. taming-transformers/configs/coco_scene_images_transformer.yaml +80 -0
  7. taming-transformers/configs/custom_vqgan.yaml +43 -0
  8. taming-transformers/configs/drin_transformer.yaml +77 -0
  9. taming-transformers/configs/faceshq_transformer.yaml +61 -0
  10. taming-transformers/configs/faceshq_vqgan.yaml +42 -0
  11. taming-transformers/configs/imagenet_vqgan.yaml +42 -0
  12. taming-transformers/configs/imagenetdepth_vqgan.yaml +41 -0
  13. taming-transformers/configs/open_images_scene_images_transformer.yaml +86 -0
  14. taming-transformers/configs/sflckr_cond_stage.yaml +43 -0
  15. taming-transformers/environment.yaml +25 -0
  16. taming-transformers/main.py +585 -0
  17. taming-transformers/scripts/extract_depth.py +112 -0
  18. taming-transformers/scripts/extract_segmentation.py +130 -0
  19. taming-transformers/scripts/extract_submodel.py +17 -0
  20. taming-transformers/scripts/make_samples.py +292 -0
  21. taming-transformers/scripts/make_scene_samples.py +198 -0
  22. taming-transformers/scripts/sample_conditional.py +355 -0
  23. taming-transformers/scripts/sample_fast.py +260 -0
  24. taming-transformers/setup.py +13 -0
  25. taming-transformers/taming/lr_scheduler.py +34 -0
  26. taming-transformers/taming/models/cond_transformer.py +352 -0
  27. taming-transformers/taming/models/dummy_cond_stage.py +22 -0
  28. taming-transformers/taming/models/vqgan.py +404 -0
  29. taming-transformers/taming/modules/diffusionmodules/model.py +776 -0
  30. taming-transformers/taming/modules/discriminator/model.py +67 -0
  31. taming-transformers/taming/modules/losses/__init__.py +2 -0
  32. taming-transformers/taming/modules/losses/lpips.py +123 -0
  33. taming-transformers/taming/modules/losses/segmentation.py +22 -0
  34. taming-transformers/taming/modules/losses/vqperceptual.py +136 -0
  35. taming-transformers/taming/modules/misc/coord.py +31 -0
  36. taming-transformers/taming/modules/transformer/mingpt.py +415 -0
  37. taming-transformers/taming/modules/transformer/permuter.py +248 -0
  38. taming-transformers/taming/modules/util.py +130 -0
  39. taming-transformers/taming/modules/vqvae/quantize.py +445 -0
  40. taming-transformers/taming/util.py +157 -0
  41. taming-transformers/taming_transformers.egg-info/PKG-INFO +10 -0
  42. taming-transformers/taming_transformers.egg-info/SOURCES.txt +7 -0
  43. taming-transformers/taming_transformers.egg-info/dependency_links.txt +1 -0
  44. taming-transformers/taming_transformers.egg-info/requires.txt +3 -0
  45. taming-transformers/taming_transformers.egg-info/top_level.txt +1 -0
.gitignore CHANGED
@@ -1,2 +1,3 @@
  __pycache__/
- *.DS_Store
+ *.DS_Store
+ *.ipynb
taming-transformers/.gitignore ADDED
@@ -0,0 +1,2 @@
1
+ assets/
2
+ data/
taming-transformers/License.txt ADDED
@@ -0,0 +1,19 @@
1
+ Copyright (c) 2020 Patrick Esser and Robin Rombach and Björn Ommer
2
+
3
+ Permission is hereby granted, free of charge, to any person obtaining a copy
4
+ of this software and associated documentation files (the "Software"), to deal
5
+ in the Software without restriction, including without limitation the rights
6
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
7
+ copies of the Software, and to permit persons to whom the Software is
8
+ furnished to do so, subject to the following conditions:
9
+
10
+ The above copyright notice and this permission notice shall be included in all
11
+ copies or substantial portions of the Software.
12
+
13
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
14
+ EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
15
+ MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
16
+ IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
17
+ DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
18
+ OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE
19
+ OR OTHER DEALINGS IN THE SOFTWARE./
taming-transformers/README.md ADDED
@@ -0,0 +1,410 @@
1
+ # Taming Transformers for High-Resolution Image Synthesis
2
+ ##### CVPR 2021 (Oral)
3
+ ![teaser](assets/mountain.jpeg)
4
+
5
+ [**Taming Transformers for High-Resolution Image Synthesis**](https://compvis.github.io/taming-transformers/)<br/>
6
+ [Patrick Esser](https://github.com/pesser)\*,
7
+ [Robin Rombach](https://github.com/rromb)\*,
8
+ [Björn Ommer](https://hci.iwr.uni-heidelberg.de/Staff/bommer)<br/>
9
+ \* equal contribution
10
+
11
+ **tl;dr** We combine the efficiency of convolutional approaches with the expressivity of transformers by introducing a convolutional VQGAN, which learns a codebook of context-rich visual parts, whose composition is modeled with an autoregressive transformer.
12
+
13
+ ![teaser](assets/teaser.png)
14
+ [arXiv](https://arxiv.org/abs/2012.09841) | [BibTeX](#bibtex) | [Project Page](https://compvis.github.io/taming-transformers/)
15
+
16
+
17
+ ### News
18
+ #### 2022
19
+ - More pretrained VQGANs (e.g. a f8-model with only 256 codebook entries) are available in our new work on [Latent Diffusion Models](https://github.com/CompVis/latent-diffusion).
20
+ - Added scene synthesis models as proposed in the paper [High-Resolution Complex Scene Synthesis with Transformers](https://arxiv.org/abs/2105.06458), see [this section](#scene-image-synthesis).
21
+ #### 2021
22
+ - Thanks to [rom1504](https://github.com/rom1504) it is now easy to [train a VQGAN on your own datasets](#training-on-custom-data).
23
+ - Included a bugfix for the quantizer. For backward compatibility it is
24
+ disabled by default (which corresponds to always training with `beta=1.0`).
25
+ Use `legacy=False` in the quantizer config to enable it.
26
+ Thanks [richcmwang](https://github.com/richcmwang) and [wcshin-git](https://github.com/wcshin-git)!
27
+ - Our paper received an update: See https://arxiv.org/abs/2012.09841v3 and the corresponding changelog.
28
+ - Added a pretrained, [1.4B transformer model](https://k00.fr/s511rwcv) trained for class-conditional ImageNet synthesis, which obtains state-of-the-art FID scores among autoregressive approaches and outperforms BigGAN.
29
+ - Added pretrained, unconditional models on [FFHQ](https://k00.fr/yndvfu95) and [CelebA-HQ](https://k00.fr/2xkmielf).
30
+ - Added accelerated sampling via caching of keys/values in the self-attention operation, used in `scripts/sample_fast.py`.
31
+ - Added a checkpoint of a [VQGAN](https://heibox.uni-heidelberg.de/d/2e5662443a6b4307b470/) trained with f8 compression and Gumbel-Quantization.
32
+ See also our updated [reconstruction notebook](https://colab.research.google.com/github/CompVis/taming-transformers/blob/master/scripts/reconstruction_usage.ipynb).
33
+ - We added a [colab notebook](https://colab.research.google.com/github/CompVis/taming-transformers/blob/master/scripts/reconstruction_usage.ipynb) which compares two VQGANs and OpenAI's [DALL-E](https://github.com/openai/DALL-E). See also [this section](#more-resources).
34
+ - We now include an overview of pretrained models in [Tab.1](#overview-of-pretrained-models). We added models for [COCO](#coco) and [ADE20k](#ade20k).
35
+ - The streamlit demo now supports image completions.
36
+ - We now include a couple of examples from the D-RIN dataset so you can run the
37
+ [D-RIN demo](#d-rin) without preparing the dataset first.
38
+ - You can now jump right into sampling with our [Colab quickstart notebook](https://colab.research.google.com/github/CompVis/taming-transformers/blob/master/scripts/taming-transformers.ipynb).
39
+
40
+ ## Requirements
41
+ A suitable [conda](https://conda.io/) environment named `taming` can be created
42
+ and activated with:
43
+
44
+ ```
45
+ conda env create -f environment.yaml
46
+ conda activate taming
47
+ ```
48
+ ## Overview of pretrained models
49
+ The following table provides an overview of all models that are currently available.
50
+ FID scores were evaluated using [torch-fidelity](https://github.com/toshas/torch-fidelity).
51
+ For reference, we also include a link to the recently released autoencoder of the [DALL-E](https://github.com/openai/DALL-E) model.
52
+ See the corresponding [colab
53
+ notebook](https://colab.research.google.com/github/CompVis/taming-transformers/blob/master/scripts/reconstruction_usage.ipynb)
54
+ for a comparison and discussion of reconstruction capabilities.
55
+
56
+ | Dataset | FID vs train | FID vs val | Link | Samples (256x256) | Comments
57
+ | ------------- | ------------- | ------------- |------------- | ------------- |------------- |
58
+ | FFHQ (f=16) | 9.6 | -- | [ffhq_transformer](https://k00.fr/yndvfu95) | [ffhq_samples](https://k00.fr/j626x093) |
59
+ | CelebA-HQ (f=16) | 10.2 | -- | [celebahq_transformer](https://k00.fr/2xkmielf) | [celebahq_samples](https://k00.fr/j626x093) |
60
+ | ADE20K (f=16) | -- | 35.5 | [ade20k_transformer](https://k00.fr/ot46cksa) | [ade20k_samples.zip](https://heibox.uni-heidelberg.de/f/70bb78cbaf844501b8fb/) [2k] | evaluated on val split (2k images)
61
+ | COCO-Stuff (f=16) | -- | 20.4 | [coco_transformer](https://k00.fr/2zz6i2ce) | [coco_samples.zip](https://heibox.uni-heidelberg.de/f/a395a9be612f4a7a8054/) [5k] | evaluated on val split (5k images)
62
+ | ImageNet (cIN) (f=16) | 15.98/15.78/6.59/5.88/5.20 | -- | [cin_transformer](https://k00.fr/s511rwcv) | [cin_samples](https://k00.fr/j626x093) | different decoding hyperparameters |
63
+ | | | | || |
64
+ | FacesHQ (f=16) | -- | -- | [faceshq_transformer](https://k00.fr/qqfl2do8)
65
+ | S-FLCKR (f=16) | -- | -- | [sflckr](https://heibox.uni-heidelberg.de/d/73487ab6e5314cb5adba/)
66
+ | D-RIN (f=16) | -- | -- | [drin_transformer](https://k00.fr/39jcugc5)
67
+ | | | | | || |
68
+ | VQGAN ImageNet (f=16), 1024 | 10.54 | 7.94 | [vqgan_imagenet_f16_1024](https://heibox.uni-heidelberg.de/d/8088892a516d4e3baf92/) | [reconstructions](https://k00.fr/j626x093) | Reconstruction-FIDs.
69
+ | VQGAN ImageNet (f=16), 16384 | 7.41 | 4.98 |[vqgan_imagenet_f16_16384](https://heibox.uni-heidelberg.de/d/a7530b09fed84f80a887/) | [reconstructions](https://k00.fr/j626x093) | Reconstruction-FIDs.
70
+ | VQGAN OpenImages (f=8), 256 | -- | 1.49 |https://ommer-lab.com/files/latent-diffusion/vq-f8-n256.zip | --- | Reconstruction-FIDs. Available via [latent diffusion](https://github.com/CompVis/latent-diffusion).
71
+ | VQGAN OpenImages (f=8), 16384 | -- | 1.14 |https://ommer-lab.com/files/latent-diffusion/vq-f8.zip | --- | Reconstruction-FIDs. Available via [latent diffusion](https://github.com/CompVis/latent-diffusion)
72
+ | VQGAN OpenImages (f=8), 8192, GumbelQuantization | 3.24 | 1.49 |[vqgan_gumbel_f8](https://heibox.uni-heidelberg.de/d/2e5662443a6b4307b470/) | --- | Reconstruction-FIDs.
73
+ | | | | | || |
74
+ | DALL-E dVAE (f=8), 8192, GumbelQuantization | 33.88 | 32.01 | https://github.com/openai/DALL-E | [reconstructions](https://k00.fr/j626x093) | Reconstruction-FIDs.
75
+
76
+
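The FID scores in the table were computed with [torch-fidelity](https://github.com/toshas/torch-fidelity); a minimal sketch of such an evaluation, with both folder paths as placeholders for your own sample and reference directories, looks roughly like this:

```python
# Sketch only: point input1/input2 at a folder of generated samples and a folder
# of reference images; torch-fidelity then reports the FID between the two sets.
import torch_fidelity

metrics = torch_fidelity.calculate_metrics(
    input1="path/to/generated_samples",   # placeholder path
    input2="path/to/reference_images",    # placeholder path
    fid=True,
)
print(metrics)  # dict containing 'frechet_inception_distance'
```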
77
+ ## Running pretrained models
78
+
79
+ The commands below will start a streamlit demo which supports sampling at
80
+ different resolutions and image completions. To run a non-interactive version
81
+ of the sampling process, replace `streamlit run scripts/sample_conditional.py --`
82
+ by `python scripts/make_samples.py --outdir <path_to_write_samples_to>` and
83
+ keep the remaining command line arguments.
84
+
85
+ To sample from unconditional or class-conditional models,
86
+ run `python scripts/sample_fast.py -r <path/to/config_and_checkpoint>`.
87
+ We describe below how to use this script to sample from the ImageNet, FFHQ, and CelebA-HQ models,
88
+ respectively.
89
+
90
+ ### S-FLCKR
91
+ ![teaser](assets/sunset_and_ocean.jpg)
92
+
93
+ You can also [run this model in a Colab
94
+ notebook](https://colab.research.google.com/github/CompVis/taming-transformers/blob/master/scripts/taming-transformers.ipynb),
95
+ which includes all necessary steps to start sampling.
96
+
97
+ Download the
98
+ [2020-11-09T13-31-51_sflckr](https://heibox.uni-heidelberg.de/d/73487ab6e5314cb5adba/)
99
+ folder and place it into `logs`. Then, run
100
+ ```
101
+ streamlit run scripts/sample_conditional.py -- -r logs/2020-11-09T13-31-51_sflckr/
102
+ ```
103
+
104
+ ### ImageNet
105
+ ![teaser](assets/imagenet.png)
106
+
107
+ Download the [2021-04-03T19-39-50_cin_transformer](https://k00.fr/s511rwcv)
108
+ folder and place it into logs. Sampling from the class-conditional ImageNet
109
+ model does not require any data preparation. To produce 50 samples for each of
110
+ the 1000 classes of ImageNet, with k=600 for top-k sampling, p=0.92 for nucleus
111
+ sampling and temperature t=1.0, run
112
+
113
+ ```
114
+ python scripts/sample_fast.py -r logs/2021-04-03T19-39-50_cin_transformer/ -n 50 -k 600 -t 1.0 -p 0.92 --batch_size 25
115
+ ```
116
+
117
+ To restrict the model to certain classes, provide them via the `--classes` argument, separated by
118
+ commas. For example, to sample 50 *ostriches*, *border collies* and *whiskey jugs*, run
119
+
120
+ ```
121
+ python scripts/sample_fast.py -r logs/2021-04-03T19-39-50_cin_transformer/ -n 50 -k 600 -t 1.0 -p 0.92 --batch_size 25 --classes 9,232,901
122
+ ```
123
+ We recommend experimenting with the autoregressive decoding parameters (top-k, top-p and temperature) for best results.
124
+
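As a rough sketch of what the `-k` (top-k), `-p` (nucleus) and `-t` (temperature) flags control, the following shows standard logit filtering in plain PyTorch; the actual sampling logic lives in `scripts/sample_fast.py`, so treat this only as an illustration:

```python
# Illustrative only: temperature / top-k / nucleus filtering of next-token logits.
import torch
import torch.nn.functional as F

def filter_logits(logits, k=600, p=0.92, t=1.0):
    logits = logits / t                                         # temperature scaling
    # top-k: drop everything below the k-th largest logit
    kth = torch.topk(logits, k)[0][..., -1, None]
    logits = logits.masked_fill(logits < kth, float("-inf"))
    # nucleus (top-p): keep the smallest sorted prefix whose probability mass exceeds p
    sorted_logits, sorted_idx = torch.sort(logits, descending=True)
    cum_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
    remove = cum_probs > p
    remove[..., 1:] = remove[..., :-1].clone()                  # shift so the token crossing p is kept
    remove[..., 0] = False
    logits = logits.masked_fill(remove.scatter(-1, sorted_idx, remove), float("-inf"))
    return logits

# e.g. sample one codebook index from a (batch, vocab) logit tensor
probs = F.softmax(filter_logits(torch.randn(1, 16384)), dim=-1)
next_token = torch.multinomial(probs, num_samples=1)
```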
125
+ ### FFHQ/CelebA-HQ
126
+
127
+ Download the [2021-04-23T18-19-01_ffhq_transformer](https://k00.fr/yndvfu95) and
128
+ [2021-04-23T18-11-19_celebahq_transformer](https://k00.fr/2xkmielf)
129
+ folders and place them into logs.
130
+ Again, sampling from these unconditional models does not require any data preparation.
131
+ To produce 50000 samples, with k=250 for top-k sampling,
132
+ p=1.0 for nucleus sampling and temperature t=1.0, run
133
+
134
+ ```
135
+ python scripts/sample_fast.py -r logs/2021-04-23T18-19-01_ffhq_transformer/
136
+ ```
137
+ for FFHQ and
138
+
139
+ ```
140
+ python scripts/sample_fast.py -r logs/2021-04-23T18-11-19_celebahq_transformer/
141
+ ```
142
+ to sample from the CelebA-HQ model.
143
+ For both models it can be advantageous to vary the top-k/top-p parameters for sampling.
144
+
145
+ ### FacesHQ
146
+ ![teaser](assets/faceshq.jpg)
147
+
148
+ Download [2020-11-13T21-41-45_faceshq_transformer](https://k00.fr/qqfl2do8) and
149
+ place it into `logs`. Follow the data preparation steps for
150
+ [CelebA-HQ](#celeba-hq) and [FFHQ](#ffhq). Run
151
+ ```
152
+ streamlit run scripts/sample_conditional.py -- -r logs/2020-11-13T21-41-45_faceshq_transformer/
153
+ ```
154
+
155
+ ### D-RIN
156
+ ![teaser](assets/drin.jpg)
157
+
158
+ Download [2020-11-20T12-54-32_drin_transformer](https://k00.fr/39jcugc5) and
159
+ place it into `logs`. To run the demo on a couple of example depth maps
160
+ included in the repository, run
161
+
162
+ ```
163
+ streamlit run scripts/sample_conditional.py -- -r logs/2020-11-20T12-54-32_drin_transformer/ --ignore_base_data data="{target: main.DataModuleFromConfig, params: {batch_size: 1, validation: {target: taming.data.imagenet.DRINExamples}}}"
164
+ ```
165
+
166
+ To run the demo on the complete validation set, first follow the data preparation steps for
167
+ [ImageNet](#imagenet) and then run
168
+ ```
169
+ streamlit run scripts/sample_conditional.py -- -r logs/2020-11-20T12-54-32_drin_transformer/
170
+ ```
171
+
172
+ ### COCO
173
+ Download [2021-01-20T16-04-20_coco_transformer](https://k00.fr/2zz6i2ce) and
174
+ place it into `logs`. To run the demo on a couple of example segmentation maps
175
+ included in the repository, run
176
+
177
+ ```
178
+ streamlit run scripts/sample_conditional.py -- -r logs/2021-01-20T16-04-20_coco_transformer/ --ignore_base_data data="{target: main.DataModuleFromConfig, params: {batch_size: 1, validation: {target: taming.data.coco.Examples}}}"
179
+ ```
180
+
181
+ ### ADE20k
182
+ Download [2020-11-20T21-45-44_ade20k_transformer](https://k00.fr/ot46cksa) and
183
+ place it into `logs`. To run the demo on a couple of example segmentation maps
184
+ included in the repository, run
185
+
186
+ ```
187
+ streamlit run scripts/sample_conditional.py -- -r logs/2020-11-20T21-45-44_ade20k_transformer/ --ignore_base_data data="{target: main.DataModuleFromConfig, params: {batch_size: 1, validation: {target: taming.data.ade20k.Examples}}}"
188
+ ```
189
+
190
+ ## Scene Image Synthesis
191
+ ![teaser](assets/scene_images_samples.svg)
192
+ Scene image generation based on bounding box conditionals as done in our CVPR2021 AI4CC workshop paper [High-Resolution Complex Scene Synthesis with Transformers](https://arxiv.org/abs/2105.06458) (see talk on [workshop page](https://visual.cs.brown.edu/workshops/aicc2021/#awards)). It supports the COCO and Open Images datasets.
193
+
194
+ ### Training
195
+ Download first-stage models [COCO-8k-VQGAN](https://heibox.uni-heidelberg.de/f/78dea9589974474c97c1/) for COCO or [COCO/Open-Images-8k-VQGAN](https://heibox.uni-heidelberg.de/f/461d9a9f4fcf48ab84f4/) for Open Images.
196
+ Change `ckpt_path` in `data/coco_scene_images_transformer.yaml` and `data/open_images_scene_images_transformer.yaml` to point to the downloaded first-stage models.
197
+ Download the full COCO/OI datasets and adapt `data_path` in the same files, unless the 100 files provided for training and validation already suit your needs.
198
+
199
+ Code can be run with
200
+ `python main.py --base configs/coco_scene_images_transformer.yaml -t True --gpus 0,`
201
+ or
202
+ `python main.py --base configs/open_images_scene_images_transformer.yaml -t True --gpus 0,`
203
+
204
+ ### Sampling
205
+ Train a model as described above or download a pre-trained model:
206
+ - [Open Images 1 billion parameter model](https://drive.google.com/file/d/1FEK-Z7hyWJBvFWQF50pzSK9y1W_CJEig/view?usp=sharing), trained for 100 epochs. On 256x256 pixels, FID 41.48±0.21, SceneFID 14.60±0.15, Inception Score 18.47±0.27. The model was trained with 2d crops of images and is thus well-prepared for the task of generating high-resolution images, e.g. 512x512.
207
+ - [Open Images distilled version of the above model with 125 million parameters](https://drive.google.com/file/d/1xf89g0mc78J3d8Bx5YhbK4tNRNlOoYaO) allows for sampling on smaller GPUs (4 GB is enough for sampling 256x256 px images). Model was trained for 60 epochs with 10% soft loss, 90% hard loss. On 256x256 pixels, FID 43.07±0.40, SceneFID 15.93±0.19, Inception Score 17.23±0.11.
208
+ - [COCO 30 epochs](https://heibox.uni-heidelberg.de/f/0d0b2594e9074c7e9a33/)
209
+ - [COCO 60 epochs](https://drive.google.com/file/d/1bInd49g2YulTJBjU32Awyt5qnzxxG5U9/) (find model statistics for both COCO versions in `assets/coco_scene_images_training.svg`)
210
+
211
+ When downloading a pre-trained model, remember to change `ckpt_path` in `configs/*project.yaml` to point to your downloaded first-stage model (see ->Training).
212
+
213
+ Scene image generation can be run with
214
+ `python scripts/make_scene_samples.py --outdir=/some/outdir -r /path/to/pretrained/model --resolution=512,512`
215
+
216
+
217
+ ## Training on custom data
218
+
219
+ Training on your own dataset can be beneficial to get better tokens and hence better images for your domain.
220
+ These are the steps to follow to make this work:
221
+ 1. install the repo with `conda env create -f environment.yaml`, `conda activate taming` and `pip install -e .`
222
+ 2. put your .jpg files in a folder `your_folder`
223
+ 3. create two text files, `xx_train.txt` and `xx_test.txt`, that point to the files in your training and test set respectively (for example `find $(pwd)/your_folder -name "*.jpg" > train.txt`; a small helper sketch follows this list)
224
+ 4. adapt `configs/custom_vqgan.yaml` to point to these two files
225
+ 5. run `python main.py --base configs/custom_vqgan.yaml -t True --gpus 0,1` to
226
+ train on two GPUs. Use `--gpus 0,` (with a trailing comma) to train on a single GPU.
227
+
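A minimal sketch for step 3, assuming your images sit in `your_folder/` and using the `xx_train.txt`/`xx_test.txt` names from the list above:

```python
# Sketch: collect absolute .jpg paths and write a random ~90/10 train/test split.
import random
from pathlib import Path

files = sorted(str(p.resolve()) for p in Path("your_folder").glob("*.jpg"))
random.seed(0)
random.shuffle(files)
n_test = max(1, len(files) // 10)

Path("xx_test.txt").write_text("\n".join(files[:n_test]) + "\n")
Path("xx_train.txt").write_text("\n".join(files[n_test:]) + "\n")
```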
228
+ ## Data Preparation
229
+
230
+ ### ImageNet
231
+ The code will try to download (through [Academic
232
+ Torrents](http://academictorrents.com/)) and prepare ImageNet the first time it
233
+ is used. However, since ImageNet is quite large, this requires a lot of disk
234
+ space and time. If you already have ImageNet on your disk, you can speed things
235
+ up by putting the data into
236
+ `${XDG_CACHE}/autoencoders/data/ILSVRC2012_{split}/data/` (which defaults to
237
+ `~/.cache/autoencoders/data/ILSVRC2012_{split}/data/`), where `{split}` is one
238
+ of `train`/`validation`. It should have the following structure:
239
+
240
+ ```
241
+ ${XDG_CACHE}/autoencoders/data/ILSVRC2012_{split}/data/
242
+ ├── n01440764
243
+ │ ├── n01440764_10026.JPEG
244
+ │ ├── n01440764_10027.JPEG
245
+ │ ├── ...
246
+ ├── n01443537
247
+ │ ├── n01443537_10007.JPEG
248
+ │ ├── n01443537_10014.JPEG
249
+ │ ├── ...
250
+ ├── ...
251
+ ```
252
+
253
+ If you haven't extracted the data, you can also place
254
+ `ILSVRC2012_img_train.tar`/`ILSVRC2012_img_val.tar` (or symlinks to them) into
255
+ `${XDG_CACHE}/autoencoders/data/ILSVRC2012_train/` /
256
+ `${XDG_CACHE}/autoencoders/data/ILSVRC2012_validation/`, which will then be
257
+ extracted into the above structure without downloading it again. Note that this
258
+ will only happen if neither a folder
259
+ `${XDG_CACHE}/autoencoders/data/ILSVRC2012_{split}/data/` nor a file
260
+ `${XDG_CACHE}/autoencoders/data/ILSVRC2012_{split}/.ready` exist. Remove them
261
+ if you want to force running the dataset preparation again.
262
+
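A small sketch that removes both markers so the preparation runs again (assuming the default `~/.cache` location mentioned above):

```python
# Sketch: delete the extracted data folder and the .ready marker for each split,
# which forces the dataset preparation logic to run again on next use.
import os, shutil

cache = os.path.expanduser("~/.cache/autoencoders/data")
for split in ("ILSVRC2012_train", "ILSVRC2012_validation"):
    shutil.rmtree(os.path.join(cache, split, "data"), ignore_errors=True)
    ready = os.path.join(cache, split, ".ready")
    if os.path.exists(ready):
        os.remove(ready)
```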
263
+ You will then need to prepare the depth data using
264
+ [MiDaS](https://github.com/intel-isl/MiDaS). Create a symlink
265
+ `data/imagenet_depth` pointing to a folder with two subfolders `train` and
266
+ `val`, each mirroring the structure of the corresponding ImageNet folder
267
+ described above and containing a `png` file for each of ImageNet's `JPEG`
268
+ files. The `png` encodes `float32` depth values obtained from MiDaS as RGBA
269
+ images. We provide the script `scripts/extract_depth.py` to generate this data.
270
+ **Please note** that this script uses [MiDaS via PyTorch
271
+ Hub](https://pytorch.org/hub/intelisl_midas_v2/). When we prepared the data,
272
+ the hub provided the [MiDaS
273
+ v2.0](https://github.com/intel-isl/MiDaS/releases/tag/v2) version, but now it
274
+ provides a v2.1 version. We haven't tested our models with depth maps obtained
275
+ via v2.1 and if you want to make sure that things work as expected, you must
276
+ adjust the script to make sure it explicitly uses
277
+ [v2.0](https://github.com/intel-isl/MiDaS/releases/tag/v2)!
278
+
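One way to pack `float32` depth into the four 8-bit RGBA channels of a PNG, as a sketch of the format described here (`scripts/extract_depth.py` in this commit is the actual tool):

```python
# Sketch: reinterpret each float32 depth value as 4 uint8 bytes -> one RGBA pixel,
# which an 8-bit PNG stores losslessly, and invert the view when loading.
import numpy as np
from PIL import Image

def save_depth_as_rgba(depth, path):
    d = np.ascontiguousarray(depth.astype(np.float32))         # H x W float32
    rgba = d.view(np.uint8).reshape(d.shape[0], d.shape[1], 4)  # H x W x 4 uint8
    Image.fromarray(rgba).save(path)                            # RGBA PNG

def load_depth_from_rgba(path):
    rgba = np.ascontiguousarray(np.array(Image.open(path)))    # H x W x 4 uint8
    return rgba.view(np.float32).reshape(rgba.shape[0], rgba.shape[1])
```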
279
+ ### CelebA-HQ
280
+ Create a symlink `data/celebahq` pointing to a folder containing the `.npy`
281
+ files of CelebA-HQ (instructions to obtain them can be found in the [PGGAN
282
+ repository](https://github.com/tkarras/progressive_growing_of_gans)).
283
+
284
+ ### FFHQ
285
+ Create a symlink `data/ffhq` pointing to the `images1024x1024` folder obtained
286
+ from the [FFHQ repository](https://github.com/NVlabs/ffhq-dataset).
287
+
288
+ ### S-FLCKR
289
+ Unfortunately, we are not allowed to distribute the images we collected for the
290
+ S-FLCKR dataset and can therefore only give a description of how it was produced.
291
+ There are many resources on [collecting images from the
292
+ web](https://github.com/adrianmrit/flickrdatasets) to get started.
293
+ We collected sufficiently large images from [flickr](https://www.flickr.com)
294
+ (see `data/flickr_tags.txt` for a full list of tags used to find images)
295
+ and various [subreddits](https://www.reddit.com/r/sfwpornnetwork/wiki/network)
296
+ (see `data/subreddits.txt` for all subreddits that were used).
297
+ Overall, we collected 107625 images, and split them randomly into 96861
298
+ training images and 10764 validation images. We then obtained segmentation
299
+ masks for each image using [DeepLab v2](https://arxiv.org/abs/1606.00915)
300
+ trained on [COCO-Stuff](https://arxiv.org/abs/1612.03716). We used a [PyTorch
301
+ reimplementation](https://github.com/kazuto1011/deeplab-pytorch) and include an
302
+ example script for this process in `scripts/extract_segmentation.py`.
303
+
304
+ ### COCO
305
+ Create a symlink `data/coco` containing the images from the 2017 split in
306
+ `train2017` and `val2017`, and their annotations in `annotations`. Files can be
307
+ obtained from the [COCO webpage](https://cocodataset.org/). In addition, we use
308
+ the [Stuff+thing PNG-style annotations on COCO 2017
309
+ trainval](http://calvin.inf.ed.ac.uk/wp-content/uploads/data/cocostuffdataset/stuffthingmaps_trainval2017.zip)
310
+ annotations from [COCO-Stuff](https://github.com/nightrome/cocostuff), which
311
+ should be placed under `data/cocostuffthings`.
312
+
313
+ ### ADE20k
314
+ Create a symlink `data/ade20k_root` containing the contents of
315
+ [ADEChallengeData2016.zip](http://data.csail.mit.edu/places/ADEchallenge/ADEChallengeData2016.zip)
316
+ from the [MIT Scene Parsing Benchmark](http://sceneparsing.csail.mit.edu/).
317
+
318
+ ## Training models
319
+
320
+ ### FacesHQ
321
+
322
+ Train a VQGAN with
323
+ ```
324
+ python main.py --base configs/faceshq_vqgan.yaml -t True --gpus 0,
325
+ ```
326
+
327
+ Then, adjust the checkpoint path of the config key
328
+ `model.params.first_stage_config.params.ckpt_path` in
329
+ `configs/faceshq_transformer.yaml` (or download
330
+ [2020-11-09T13-33-36_faceshq_vqgan](https://k00.fr/uxy5usa9) and place into `logs`, which
331
+ corresponds to the preconfigured checkpoint path), then run
332
+ ```
333
+ python main.py --base configs/faceshq_transformer.yaml -t True --gpus 0,
334
+ ```
335
+
336
+ ### D-RIN
337
+
338
+ Train a VQGAN on ImageNet with
339
+ ```
340
+ python main.py --base configs/imagenet_vqgan.yaml -t True --gpus 0,
341
+ ```
342
+
343
+ or download a pretrained one from [2020-09-23T17-56-33_imagenet_vqgan](https://k00.fr/u0j2dtac)
344
+ and place under `logs`. If you trained your own, adjust the path in the config
345
+ key `model.params.first_stage_config.params.ckpt_path` of
346
+ `configs/drin_transformer.yaml`.
347
+
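The same adjustment can be scripted with OmegaConf (pinned in `environment.yaml`); the checkpoint path below is a placeholder for your own run:

```python
# Sketch: point the first-stage VQGAN checkpoint of the D-RIN transformer config
# at a locally trained model, then write the config back out.
from omegaconf import OmegaConf

cfg = OmegaConf.load("configs/drin_transformer.yaml")
cfg.model.params.first_stage_config.params.ckpt_path = "logs/my_imagenet_vqgan/checkpoints/last.ckpt"
OmegaConf.save(cfg, "configs/drin_transformer.yaml")
```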
348
+ Train a VQGAN on Depth Maps of ImageNet with
349
+ ```
350
+ python main.py --base configs/imagenetdepth_vqgan.yaml -t True --gpus 0,
351
+ ```
352
+
353
+ or download a pretrained one from [2020-11-03T15-34-24_imagenetdepth_vqgan](https://k00.fr/55rlxs6i)
354
+ and place under `logs`. If you trained your own, adjust the path in the config
355
+ key `model.params.cond_stage_config.params.ckpt_path` of
356
+ `configs/drin_transformer.yaml`.
357
+
358
+ To train the transformer, run
359
+ ```
360
+ python main.py --base configs/drin_transformer.yaml -t True --gpus 0,
361
+ ```
362
+
363
+ ## More Resources
364
+ ### Comparing Different First Stage Models
365
+ The reconstruction and compression capabilities of different first stage models can be analyzed in this [colab notebook](https://colab.research.google.com/github/CompVis/taming-transformers/blob/master/scripts/reconstruction_usage.ipynb).
366
+ In particular, the notebook compares two VQGANs with a downsampling factor of f=16 for each and codebook dimensionality of 1024 and 16384,
367
+ a VQGAN with f=8 and 8192 codebook entries and the discrete autoencoder of OpenAI's [DALL-E](https://github.com/openai/DALL-E) (which has f=8 and 8192
368
+ codebook entries).
369
+ ![firststages1](assets/first_stage_squirrels.png)
370
+ ![firststages2](assets/first_stage_mushrooms.png)
371
+
372
+ ### Other
373
+ - A [video summary](https://www.youtube.com/watch?v=o7dqGcLDf0A&feature=emb_imp_woyt) by [Two Minute Papers](https://www.youtube.com/channel/UCbfYPyITQ-7l4upoX8nvctg).
374
+ - A [video summary](https://www.youtube.com/watch?v=-wDSDtIAyWQ) by [Gradient Dude](https://www.youtube.com/c/GradientDude/about).
375
+ - A [weights and biases report summarizing the paper](https://wandb.ai/ayush-thakur/taming-transformer/reports/-Overview-Taming-Transformers-for-High-Resolution-Image-Synthesis---Vmlldzo0NjEyMTY)
376
+ by [ayulockin](https://github.com/ayulockin).
377
+ - A [video summary](https://www.youtube.com/watch?v=JfUTd8fjtX8&feature=emb_imp_woyt) by [What's AI](https://www.youtube.com/channel/UCUzGQrN-lyyc0BWTYoJM_Sg).
378
+ - Take a look at [ak9250's notebook](https://github.com/ak9250/taming-transformers/blob/master/tamingtransformerscolab.ipynb) if you want to run the streamlit demos on Colab.
379
+
380
+ ### Text-to-Image Optimization via CLIP
381
+ VQGAN has been successfully used as an image generator guided by the [CLIP](https://github.com/openai/CLIP) model, both for pure image generation
382
+ from scratch and image-to-image translation. We recommend the following notebooks/videos/resources:
383
+
384
+ - [Advadnouns](https://twitter.com/advadnoun/status/1389316507134357506) Patreon and corresponding LatentVision notebooks: https://www.patreon.com/patronizeme
385
+ - The [notebook]( https://colab.research.google.com/drive/1L8oL-vLJXVcRzCFbPwOoMkPKJ8-aYdPN) of [Rivers Have Wings](https://twitter.com/RiversHaveWings).
386
+ - A [video](https://www.youtube.com/watch?v=90QDe6DQXF4&t=12s) explanation by [Dot CSV](https://www.youtube.com/channel/UCy5znSnfMsDwaLlROnZ7Qbg) (in Spanish, but English subtitles are available)
387
+
388
+ ![txt2img](assets/birddrawnbyachild.png)
389
+
390
+ Text prompt: *'A bird drawn by a child'*
391
+
392
+ ## Shout-outs
393
+ Thanks to everyone who makes their code and models available. In particular,
394
+
395
+ - The architecture of our VQGAN is inspired by [Denoising Diffusion Probabilistic Models](https://github.com/hojonathanho/diffusion)
396
+ - The very hackable transformer implementation [minGPT](https://github.com/karpathy/minGPT)
397
+ - The good ol' [PatchGAN](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix) and [Learned Perceptual Similarity (LPIPS)](https://github.com/richzhang/PerceptualSimilarity)
398
+
399
+ ## BibTeX
400
+
401
+ ```
402
+ @misc{esser2020taming,
403
+ title={Taming Transformers for High-Resolution Image Synthesis},
404
+ author={Patrick Esser and Robin Rombach and Björn Ommer},
405
+ year={2020},
406
+ eprint={2012.09841},
407
+ archivePrefix={arXiv},
408
+ primaryClass={cs.CV}
409
+ }
410
+ ```
taming-transformers/configs/coco_cond_stage.yaml ADDED
@@ -0,0 +1,49 @@
1
+ model:
2
+ base_learning_rate: 4.5e-06
3
+ target: taming.models.vqgan.VQSegmentationModel
4
+ params:
5
+ embed_dim: 256
6
+ n_embed: 1024
7
+ image_key: "segmentation"
8
+ n_labels: 183
9
+ ddconfig:
10
+ double_z: false
11
+ z_channels: 256
12
+ resolution: 256
13
+ in_channels: 183
14
+ out_ch: 183
15
+ ch: 128
16
+ ch_mult:
17
+ - 1
18
+ - 1
19
+ - 2
20
+ - 2
21
+ - 4
22
+ num_res_blocks: 2
23
+ attn_resolutions:
24
+ - 16
25
+ dropout: 0.0
26
+
27
+ lossconfig:
28
+ target: taming.modules.losses.segmentation.BCELossWithQuant
29
+ params:
30
+ codebook_weight: 1.0
31
+
32
+ data:
33
+ target: main.DataModuleFromConfig
34
+ params:
35
+ batch_size: 12
36
+ train:
37
+ target: taming.data.coco.CocoImagesAndCaptionsTrain
38
+ params:
39
+ size: 296
40
+ crop_size: 256
41
+ onehot_segmentation: true
42
+ use_stuffthing: true
43
+ validation:
44
+ target: taming.data.coco.CocoImagesAndCaptionsValidation
45
+ params:
46
+ size: 256
47
+ crop_size: 256
48
+ onehot_segmentation: true
49
+ use_stuffthing: true
taming-transformers/configs/coco_scene_images_transformer.yaml ADDED
@@ -0,0 +1,80 @@
1
+ model:
2
+ base_learning_rate: 4.5e-06
3
+ target: taming.models.cond_transformer.Net2NetTransformer
4
+ params:
5
+ cond_stage_key: objects_bbox
6
+ transformer_config:
7
+ target: taming.modules.transformer.mingpt.GPT
8
+ params:
9
+ vocab_size: 8192
10
+ block_size: 348 # = 256 + 92 = dim(vqgan_latent_space,16x16) + dim(conditional_builder.embedding_dim)
11
+ n_layer: 40
12
+ n_head: 16
13
+ n_embd: 1408
14
+ embd_pdrop: 0.1
15
+ resid_pdrop: 0.1
16
+ attn_pdrop: 0.1
17
+ first_stage_config:
18
+ target: taming.models.vqgan.VQModel
19
+ params:
20
+ ckpt_path: /path/to/coco_epoch117.ckpt # https://heibox.uni-heidelberg.de/f/78dea9589974474c97c1/
21
+ embed_dim: 256
22
+ n_embed: 8192
23
+ ddconfig:
24
+ double_z: false
25
+ z_channels: 256
26
+ resolution: 256
27
+ in_channels: 3
28
+ out_ch: 3
29
+ ch: 128
30
+ ch_mult:
31
+ - 1
32
+ - 1
33
+ - 2
34
+ - 2
35
+ - 4
36
+ num_res_blocks: 2
37
+ attn_resolutions:
38
+ - 16
39
+ dropout: 0.0
40
+ lossconfig:
41
+ target: taming.modules.losses.DummyLoss
42
+ cond_stage_config:
43
+ target: taming.models.dummy_cond_stage.DummyCondStage
44
+ params:
45
+ conditional_key: objects_bbox
46
+
47
+ data:
48
+ target: main.DataModuleFromConfig
49
+ params:
50
+ batch_size: 6
51
+ train:
52
+ target: taming.data.annotated_objects_coco.AnnotatedObjectsCoco
53
+ params:
54
+ data_path: data/coco_annotations_100 # substitute with path to full dataset
55
+ split: train
56
+ keys: [image, objects_bbox, file_name, annotations]
57
+ no_tokens: 8192
58
+ target_image_size: 256
59
+ min_object_area: 0.00001
60
+ min_objects_per_image: 2
61
+ max_objects_per_image: 30
62
+ crop_method: random-1d
63
+ random_flip: true
64
+ use_group_parameter: true
65
+ encode_crop: true
66
+ validation:
67
+ target: taming.data.annotated_objects_coco.AnnotatedObjectsCoco
68
+ params:
69
+ data_path: data/coco_annotations_100 # substitute with path to full dataset
70
+ split: validation
71
+ keys: [image, objects_bbox, file_name, annotations]
72
+ no_tokens: 8192
73
+ target_image_size: 256
74
+ min_object_area: 0.00001
75
+ min_objects_per_image: 2
76
+ max_objects_per_image: 30
77
+ crop_method: center
78
+ random_flip: false
79
+ use_group_parameter: true
80
+ encode_crop: true
taming-transformers/configs/custom_vqgan.yaml ADDED
@@ -0,0 +1,43 @@
1
+ model:
2
+ base_learning_rate: 4.5e-6
3
+ target: taming.models.vqgan.VQModel
4
+ params:
5
+ embed_dim: 256
6
+ n_embed: 1024
7
+ ddconfig:
8
+ double_z: False
9
+ z_channels: 256
10
+ resolution: 256
11
+ in_channels: 3
12
+ out_ch: 3
13
+ ch: 128
14
+ ch_mult: [ 1,1,2,2,4] # num_down = len(ch_mult)-1
15
+ num_res_blocks: 2
16
+ attn_resolutions: [16]
17
+ dropout: 0.0
18
+
19
+ lossconfig:
20
+ target: taming.modules.losses.vqperceptual.VQLPIPSWithDiscriminator
21
+ params:
22
+ disc_conditional: False
23
+ disc_in_channels: 3
24
+ disc_start: 10000
25
+ disc_weight: 0.8
26
+ codebook_weight: 1.0
27
+
28
+ data:
29
+ target: main.DataModuleFromConfig
30
+ params:
31
+ batch_size: 5
32
+ num_workers: 8
33
+ train:
34
+ target: taming.data.custom.CustomTrain
35
+ params:
36
+ training_images_list_file: some/training.txt
37
+ size: 256
38
+ validation:
39
+ target: taming.data.custom.CustomTest
40
+ params:
41
+ test_images_list_file: some/test.txt
42
+ size: 256
43
+
taming-transformers/configs/drin_transformer.yaml ADDED
@@ -0,0 +1,77 @@
1
+ model:
2
+ base_learning_rate: 4.5e-06
3
+ target: taming.models.cond_transformer.Net2NetTransformer
4
+ params:
5
+ cond_stage_key: depth
6
+ transformer_config:
7
+ target: taming.modules.transformer.mingpt.GPT
8
+ params:
9
+ vocab_size: 1024
10
+ block_size: 512
11
+ n_layer: 24
12
+ n_head: 16
13
+ n_embd: 1024
14
+ first_stage_config:
15
+ target: taming.models.vqgan.VQModel
16
+ params:
17
+ ckpt_path: logs/2020-09-23T17-56-33_imagenet_vqgan/checkpoints/last.ckpt
18
+ embed_dim: 256
19
+ n_embed: 1024
20
+ ddconfig:
21
+ double_z: false
22
+ z_channels: 256
23
+ resolution: 256
24
+ in_channels: 3
25
+ out_ch: 3
26
+ ch: 128
27
+ ch_mult:
28
+ - 1
29
+ - 1
30
+ - 2
31
+ - 2
32
+ - 4
33
+ num_res_blocks: 2
34
+ attn_resolutions:
35
+ - 16
36
+ dropout: 0.0
37
+ lossconfig:
38
+ target: taming.modules.losses.DummyLoss
39
+ cond_stage_config:
40
+ target: taming.models.vqgan.VQModel
41
+ params:
42
+ ckpt_path: logs/2020-11-03T15-34-24_imagenetdepth_vqgan/checkpoints/last.ckpt
43
+ embed_dim: 256
44
+ n_embed: 1024
45
+ ddconfig:
46
+ double_z: false
47
+ z_channels: 256
48
+ resolution: 256
49
+ in_channels: 1
50
+ out_ch: 1
51
+ ch: 128
52
+ ch_mult:
53
+ - 1
54
+ - 1
55
+ - 2
56
+ - 2
57
+ - 4
58
+ num_res_blocks: 2
59
+ attn_resolutions:
60
+ - 16
61
+ dropout: 0.0
62
+ lossconfig:
63
+ target: taming.modules.losses.DummyLoss
64
+
65
+ data:
66
+ target: main.DataModuleFromConfig
67
+ params:
68
+ batch_size: 2
69
+ num_workers: 8
70
+ train:
71
+ target: taming.data.imagenet.RINTrainWithDepth
72
+ params:
73
+ size: 256
74
+ validation:
75
+ target: taming.data.imagenet.RINValidationWithDepth
76
+ params:
77
+ size: 256
taming-transformers/configs/faceshq_transformer.yaml ADDED
@@ -0,0 +1,61 @@
1
+ model:
2
+ base_learning_rate: 4.5e-06
3
+ target: taming.models.cond_transformer.Net2NetTransformer
4
+ params:
5
+ cond_stage_key: coord
6
+ transformer_config:
7
+ target: taming.modules.transformer.mingpt.GPT
8
+ params:
9
+ vocab_size: 1024
10
+ block_size: 512
11
+ n_layer: 24
12
+ n_head: 16
13
+ n_embd: 1024
14
+ first_stage_config:
15
+ target: taming.models.vqgan.VQModel
16
+ params:
17
+ ckpt_path: logs/2020-11-09T13-33-36_faceshq_vqgan/checkpoints/last.ckpt
18
+ embed_dim: 256
19
+ n_embed: 1024
20
+ ddconfig:
21
+ double_z: false
22
+ z_channels: 256
23
+ resolution: 256
24
+ in_channels: 3
25
+ out_ch: 3
26
+ ch: 128
27
+ ch_mult:
28
+ - 1
29
+ - 1
30
+ - 2
31
+ - 2
32
+ - 4
33
+ num_res_blocks: 2
34
+ attn_resolutions:
35
+ - 16
36
+ dropout: 0.0
37
+ lossconfig:
38
+ target: taming.modules.losses.DummyLoss
39
+ cond_stage_config:
40
+ target: taming.modules.misc.coord.CoordStage
41
+ params:
42
+ n_embed: 1024
43
+ down_factor: 16
44
+
45
+ data:
46
+ target: main.DataModuleFromConfig
47
+ params:
48
+ batch_size: 2
49
+ num_workers: 8
50
+ train:
51
+ target: taming.data.faceshq.FacesHQTrain
52
+ params:
53
+ size: 256
54
+ crop_size: 256
55
+ coord: True
56
+ validation:
57
+ target: taming.data.faceshq.FacesHQValidation
58
+ params:
59
+ size: 256
60
+ crop_size: 256
61
+ coord: True
taming-transformers/configs/faceshq_vqgan.yaml ADDED
@@ -0,0 +1,42 @@
1
+ model:
2
+ base_learning_rate: 4.5e-6
3
+ target: taming.models.vqgan.VQModel
4
+ params:
5
+ embed_dim: 256
6
+ n_embed: 1024
7
+ ddconfig:
8
+ double_z: False
9
+ z_channels: 256
10
+ resolution: 256
11
+ in_channels: 3
12
+ out_ch: 3
13
+ ch: 128
14
+ ch_mult: [ 1,1,2,2,4] # num_down = len(ch_mult)-1
15
+ num_res_blocks: 2
16
+ attn_resolutions: [16]
17
+ dropout: 0.0
18
+
19
+ lossconfig:
20
+ target: taming.modules.losses.vqperceptual.VQLPIPSWithDiscriminator
21
+ params:
22
+ disc_conditional: False
23
+ disc_in_channels: 3
24
+ disc_start: 30001
25
+ disc_weight: 0.8
26
+ codebook_weight: 1.0
27
+
28
+ data:
29
+ target: main.DataModuleFromConfig
30
+ params:
31
+ batch_size: 3
32
+ num_workers: 8
33
+ train:
34
+ target: taming.data.faceshq.FacesHQTrain
35
+ params:
36
+ size: 256
37
+ crop_size: 256
38
+ validation:
39
+ target: taming.data.faceshq.FacesHQValidation
40
+ params:
41
+ size: 256
42
+ crop_size: 256
taming-transformers/configs/imagenet_vqgan.yaml ADDED
@@ -0,0 +1,42 @@
1
+ model:
2
+ base_learning_rate: 4.5e-6
3
+ target: taming.models.vqgan.VQModel
4
+ params:
5
+ embed_dim: 256
6
+ n_embed: 1024
7
+ ddconfig:
8
+ double_z: False
9
+ z_channels: 256
10
+ resolution: 256
11
+ in_channels: 3
12
+ out_ch: 3
13
+ ch: 128
14
+ ch_mult: [ 1,1,2,2,4] # num_down = len(ch_mult)-1
15
+ num_res_blocks: 2
16
+ attn_resolutions: [16]
17
+ dropout: 0.0
18
+
19
+ lossconfig:
20
+ target: taming.modules.losses.vqperceptual.VQLPIPSWithDiscriminator
21
+ params:
22
+ disc_conditional: False
23
+ disc_in_channels: 3
24
+ disc_start: 250001
25
+ disc_weight: 0.8
26
+ codebook_weight: 1.0
27
+
28
+ data:
29
+ target: main.DataModuleFromConfig
30
+ params:
31
+ batch_size: 12
32
+ num_workers: 24
33
+ train:
34
+ target: taming.data.imagenet.ImageNetTrain
35
+ params:
36
+ config:
37
+ size: 256
38
+ validation:
39
+ target: taming.data.imagenet.ImageNetValidation
40
+ params:
41
+ config:
42
+ size: 256
taming-transformers/configs/imagenetdepth_vqgan.yaml ADDED
1
+ model:
2
+ base_learning_rate: 4.5e-6
3
+ target: taming.models.vqgan.VQModel
4
+ params:
5
+ embed_dim: 256
6
+ n_embed: 1024
7
+ image_key: depth
8
+ ddconfig:
9
+ double_z: False
10
+ z_channels: 256
11
+ resolution: 256
12
+ in_channels: 1
13
+ out_ch: 1
14
+ ch: 128
15
+ ch_mult: [ 1,1,2,2,4] # num_down = len(ch_mult)-1
16
+ num_res_blocks: 2
17
+ attn_resolutions: [16]
18
+ dropout: 0.0
19
+
20
+ lossconfig:
21
+ target: taming.modules.losses.vqperceptual.VQLPIPSWithDiscriminator
22
+ params:
23
+ disc_conditional: False
24
+ disc_in_channels: 1
25
+ disc_start: 50001
26
+ disc_weight: 0.75
27
+ codebook_weight: 1.0
28
+
29
+ data:
30
+ target: main.DataModuleFromConfig
31
+ params:
32
+ batch_size: 3
33
+ num_workers: 8
34
+ train:
35
+ target: taming.data.imagenet.ImageNetTrainWithDepth
36
+ params:
37
+ size: 256
38
+ validation:
39
+ target: taming.data.imagenet.ImageNetValidationWithDepth
40
+ params:
41
+ size: 256
taming-transformers/configs/open_images_scene_images_transformer.yaml ADDED
@@ -0,0 +1,86 @@
1
+ model:
2
+ base_learning_rate: 4.5e-06
3
+ target: taming.models.cond_transformer.Net2NetTransformer
4
+ params:
5
+ cond_stage_key: objects_bbox
6
+ transformer_config:
7
+ target: taming.modules.transformer.mingpt.GPT
8
+ params:
9
+ vocab_size: 8192
10
+ block_size: 348 # = 256 + 92 = dim(vqgan_latent_space,16x16) + dim(conditional_builder.embedding_dim)
11
+ n_layer: 36
12
+ n_head: 16
13
+ n_embd: 1536
14
+ embd_pdrop: 0.1
15
+ resid_pdrop: 0.1
16
+ attn_pdrop: 0.1
17
+ first_stage_config:
18
+ target: taming.models.vqgan.VQModel
19
+ params:
20
+ ckpt_path: /path/to/coco_oi_epoch12.ckpt # https://heibox.uni-heidelberg.de/f/461d9a9f4fcf48ab84f4/
21
+ embed_dim: 256
22
+ n_embed: 8192
23
+ ddconfig:
24
+ double_z: false
25
+ z_channels: 256
26
+ resolution: 256
27
+ in_channels: 3
28
+ out_ch: 3
29
+ ch: 128
30
+ ch_mult:
31
+ - 1
32
+ - 1
33
+ - 2
34
+ - 2
35
+ - 4
36
+ num_res_blocks: 2
37
+ attn_resolutions:
38
+ - 16
39
+ dropout: 0.0
40
+ lossconfig:
41
+ target: taming.modules.losses.DummyLoss
42
+ cond_stage_config:
43
+ target: taming.models.dummy_cond_stage.DummyCondStage
44
+ params:
45
+ conditional_key: objects_bbox
46
+
47
+ data:
48
+ target: main.DataModuleFromConfig
49
+ params:
50
+ batch_size: 6
51
+ train:
52
+ target: taming.data.annotated_objects_open_images.AnnotatedObjectsOpenImages
53
+ params:
54
+ data_path: data/open_images_annotations_100 # substitute with path to full dataset
55
+ split: train
56
+ keys: [image, objects_bbox, file_name, annotations]
57
+ no_tokens: 8192
58
+ target_image_size: 256
59
+ category_allow_list_target: taming.data.open_images_helper.top_300_classes_plus_coco_compatibility
60
+ category_mapping_target: taming.data.open_images_helper.open_images_unify_categories_for_coco
61
+ min_object_area: 0.0001
62
+ min_objects_per_image: 2
63
+ max_objects_per_image: 30
64
+ crop_method: random-2d
65
+ random_flip: true
66
+ use_group_parameter: true
67
+ use_additional_parameters: true
68
+ encode_crop: true
69
+ validation:
70
+ target: taming.data.annotated_objects_open_images.AnnotatedObjectsOpenImages
71
+ params:
72
+ data_path: data/open_images_annotations_100 # substitute with path to full dataset
73
+ split: validation
74
+ keys: [image, objects_bbox, file_name, annotations]
75
+ no_tokens: 8192
76
+ target_image_size: 256
77
+ category_allow_list_target: taming.data.open_images_helper.top_300_classes_plus_coco_compatibility
78
+ category_mapping_target: taming.data.open_images_helper.open_images_unify_categories_for_coco
79
+ min_object_area: 0.0001
80
+ min_objects_per_image: 2
81
+ max_objects_per_image: 30
82
+ crop_method: center
83
+ random_flip: false
84
+ use_group_parameter: true
85
+ use_additional_parameters: true
86
+ encode_crop: true
taming-transformers/configs/sflckr_cond_stage.yaml ADDED
1
+ model:
2
+ base_learning_rate: 4.5e-06
3
+ target: taming.models.vqgan.VQSegmentationModel
4
+ params:
5
+ embed_dim: 256
6
+ n_embed: 1024
7
+ image_key: "segmentation"
8
+ n_labels: 182
9
+ ddconfig:
10
+ double_z: false
11
+ z_channels: 256
12
+ resolution: 256
13
+ in_channels: 182
14
+ out_ch: 182
15
+ ch: 128
16
+ ch_mult:
17
+ - 1
18
+ - 1
19
+ - 2
20
+ - 2
21
+ - 4
22
+ num_res_blocks: 2
23
+ attn_resolutions:
24
+ - 16
25
+ dropout: 0.0
26
+
27
+ lossconfig:
28
+ target: taming.modules.losses.segmentation.BCELossWithQuant
29
+ params:
30
+ codebook_weight: 1.0
31
+
32
+ data:
33
+ target: cutlit.DataModuleFromConfig
34
+ params:
35
+ batch_size: 12
36
+ train:
37
+ target: taming.data.sflckr.Examples # adjust
38
+ params:
39
+ size: 256
40
+ validation:
41
+ target: taming.data.sflckr.Examples # adjust
42
+ params:
43
+ size: 256
taming-transformers/environment.yaml ADDED
@@ -0,0 +1,25 @@
1
+ name: taming
2
+ channels:
3
+ - pytorch
4
+ - defaults
5
+ dependencies:
6
+ - python=3.8.5
7
+ - pip=20.3
8
+ - cudatoolkit=10.2
9
+ - pytorch=1.7.0
10
+ - torchvision=0.8.1
11
+ - numpy=1.19.2
12
+ - pip:
13
+ - albumentations==0.4.3
14
+ - opencv-python==4.1.2.30
15
+ - pudb==2019.2
16
+ - imageio==2.9.0
17
+ - imageio-ffmpeg==0.4.2
18
+ - pytorch-lightning==1.0.8
19
+ - omegaconf==2.0.0
20
+ - test-tube>=0.7.5
21
+ - streamlit>=0.73.1
22
+ - einops==0.3.0
23
+ - more-itertools>=8.0.0
24
+ - transformers==4.3.1
25
+ - -e .
taming-transformers/main.py ADDED
@@ -0,0 +1,585 @@
1
+ import argparse, os, sys, datetime, glob, importlib
2
+ from omegaconf import OmegaConf
3
+ import numpy as np
4
+ from PIL import Image
5
+ import torch
6
+ import torchvision
7
+ from torch.utils.data import random_split, DataLoader, Dataset
8
+ import pytorch_lightning as pl
9
+ from pytorch_lightning import seed_everything
10
+ from pytorch_lightning.trainer import Trainer
11
+ from pytorch_lightning.callbacks import ModelCheckpoint, Callback, LearningRateMonitor
12
+ from pytorch_lightning.utilities import rank_zero_only
13
+
14
+ from taming.data.utils import custom_collate
15
+
16
+
17
+ def get_obj_from_str(string, reload=False):
18
+ module, cls = string.rsplit(".", 1)
19
+ if reload:
20
+ module_imp = importlib.import_module(module)
21
+ importlib.reload(module_imp)
22
+ return getattr(importlib.import_module(module, package=None), cls)
23
+
24
+
25
+ def get_parser(**parser_kwargs):
26
+ def str2bool(v):
27
+ if isinstance(v, bool):
28
+ return v
29
+ if v.lower() in ("yes", "true", "t", "y", "1"):
30
+ return True
31
+ elif v.lower() in ("no", "false", "f", "n", "0"):
32
+ return False
33
+ else:
34
+ raise argparse.ArgumentTypeError("Boolean value expected.")
35
+
36
+ parser = argparse.ArgumentParser(**parser_kwargs)
37
+ parser.add_argument(
38
+ "-n",
39
+ "--name",
40
+ type=str,
41
+ const=True,
42
+ default="",
43
+ nargs="?",
44
+ help="postfix for logdir",
45
+ )
46
+ parser.add_argument(
47
+ "-r",
48
+ "--resume",
49
+ type=str,
50
+ const=True,
51
+ default="",
52
+ nargs="?",
53
+ help="resume from logdir or checkpoint in logdir",
54
+ )
55
+ parser.add_argument(
56
+ "-b",
57
+ "--base",
58
+ nargs="*",
59
+ metavar="base_config.yaml",
60
+ help="paths to base configs. Loaded from left-to-right. "
61
+ "Parameters can be overwritten or added with command-line options of the form `--key value`.",
62
+ default=list(),
63
+ )
64
+ parser.add_argument(
65
+ "-t",
66
+ "--train",
67
+ type=str2bool,
68
+ const=True,
69
+ default=False,
70
+ nargs="?",
71
+ help="train",
72
+ )
73
+ parser.add_argument(
74
+ "--no-test",
75
+ type=str2bool,
76
+ const=True,
77
+ default=False,
78
+ nargs="?",
79
+ help="disable test",
80
+ )
81
+ parser.add_argument("-p", "--project", help="name of new or path to existing project")
82
+ parser.add_argument(
83
+ "-d",
84
+ "--debug",
85
+ type=str2bool,
86
+ nargs="?",
87
+ const=True,
88
+ default=False,
89
+ help="enable post-mortem debugging",
90
+ )
91
+ parser.add_argument(
92
+ "-s",
93
+ "--seed",
94
+ type=int,
95
+ default=23,
96
+ help="seed for seed_everything",
97
+ )
98
+ parser.add_argument(
99
+ "-f",
100
+ "--postfix",
101
+ type=str,
102
+ default="",
103
+ help="post-postfix for default name",
104
+ )
105
+
106
+ return parser
107
+
108
+
109
+ def nondefault_trainer_args(opt):
110
+ parser = argparse.ArgumentParser()
111
+ parser = Trainer.add_argparse_args(parser)
112
+ args = parser.parse_args([])
113
+ return sorted(k for k in vars(args) if getattr(opt, k) != getattr(args, k))
114
+
115
+
116
+ def instantiate_from_config(config):
117
+ if not "target" in config:
118
+ raise KeyError("Expected key `target` to instantiate.")
119
+ return get_obj_from_str(config["target"])(**config.get("params", dict()))
120
+
121
+
122
+ class WrappedDataset(Dataset):
123
+ """Wraps an arbitrary object with __len__ and __getitem__ into a pytorch dataset"""
124
+ def __init__(self, dataset):
125
+ self.data = dataset
126
+
127
+ def __len__(self):
128
+ return len(self.data)
129
+
130
+ def __getitem__(self, idx):
131
+ return self.data[idx]
132
+
133
+
134
+ class DataModuleFromConfig(pl.LightningDataModule):
135
+ def __init__(self, batch_size, train=None, validation=None, test=None,
136
+ wrap=False, num_workers=None):
137
+ super().__init__()
138
+ self.batch_size = batch_size
139
+ self.dataset_configs = dict()
140
+ self.num_workers = num_workers if num_workers is not None else batch_size*2
141
+ if train is not None:
142
+ self.dataset_configs["train"] = train
143
+ self.train_dataloader = self._train_dataloader
144
+ if validation is not None:
145
+ self.dataset_configs["validation"] = validation
146
+ self.val_dataloader = self._val_dataloader
147
+ if test is not None:
148
+ self.dataset_configs["test"] = test
149
+ self.test_dataloader = self._test_dataloader
150
+ self.wrap = wrap
151
+
152
+ def prepare_data(self):
153
+ for data_cfg in self.dataset_configs.values():
154
+ instantiate_from_config(data_cfg)
155
+
156
+ def setup(self, stage=None):
157
+ self.datasets = dict(
158
+ (k, instantiate_from_config(self.dataset_configs[k]))
159
+ for k in self.dataset_configs)
160
+ if self.wrap:
161
+ for k in self.datasets:
162
+ self.datasets[k] = WrappedDataset(self.datasets[k])
163
+
164
+ def _train_dataloader(self):
165
+ return DataLoader(self.datasets["train"], batch_size=self.batch_size,
166
+ num_workers=self.num_workers, shuffle=True, collate_fn=custom_collate)
167
+
168
+ def _val_dataloader(self):
169
+ return DataLoader(self.datasets["validation"],
170
+ batch_size=self.batch_size,
171
+ num_workers=self.num_workers, collate_fn=custom_collate)
172
+
173
+ def _test_dataloader(self):
174
+ return DataLoader(self.datasets["test"], batch_size=self.batch_size,
175
+ num_workers=self.num_workers, collate_fn=custom_collate)
176
+
177
+
178
+ class SetupCallback(Callback):
179
+ def __init__(self, resume, now, logdir, ckptdir, cfgdir, config, lightning_config):
180
+ super().__init__()
181
+ self.resume = resume
182
+ self.now = now
183
+ self.logdir = logdir
184
+ self.ckptdir = ckptdir
185
+ self.cfgdir = cfgdir
186
+ self.config = config
187
+ self.lightning_config = lightning_config
188
+
189
+ def on_pretrain_routine_start(self, trainer, pl_module):
190
+ if trainer.global_rank == 0:
191
+ # Create logdirs and save configs
192
+ os.makedirs(self.logdir, exist_ok=True)
193
+ os.makedirs(self.ckptdir, exist_ok=True)
194
+ os.makedirs(self.cfgdir, exist_ok=True)
195
+
196
+ print("Project config")
197
+ print(self.config.pretty())
198
+ OmegaConf.save(self.config,
199
+ os.path.join(self.cfgdir, "{}-project.yaml".format(self.now)))
200
+
201
+ print("Lightning config")
202
+ print(self.lightning_config.pretty())
203
+ OmegaConf.save(OmegaConf.create({"lightning": self.lightning_config}),
204
+ os.path.join(self.cfgdir, "{}-lightning.yaml".format(self.now)))
205
+
206
+ else:
207
+ # ModelCheckpoint callback created log directory --- remove it
208
+ if not self.resume and os.path.exists(self.logdir):
209
+ dst, name = os.path.split(self.logdir)
210
+ dst = os.path.join(dst, "child_runs", name)
211
+ os.makedirs(os.path.split(dst)[0], exist_ok=True)
212
+ try:
213
+ os.rename(self.logdir, dst)
214
+ except FileNotFoundError:
215
+ pass
216
+
217
+
218
+ class ImageLogger(Callback):
219
+ def __init__(self, batch_frequency, max_images, clamp=True, increase_log_steps=True):
220
+ super().__init__()
221
+ self.batch_freq = batch_frequency
222
+ self.max_images = max_images
223
+ self.logger_log_images = {
224
+ pl.loggers.WandbLogger: self._wandb,
225
+ pl.loggers.TestTubeLogger: self._testtube,
226
+ }
227
+ self.log_steps = [2 ** n for n in range(int(np.log2(self.batch_freq)) + 1)]
228
+ if not increase_log_steps:
229
+ self.log_steps = [self.batch_freq]
230
+ self.clamp = clamp
231
+
232
+ @rank_zero_only
233
+ def _wandb(self, pl_module, images, batch_idx, split):
234
+ raise ValueError("No way wandb")
235
+ grids = dict()
236
+ for k in images:
237
+ grid = torchvision.utils.make_grid(images[k])
238
+ grids[f"{split}/{k}"] = wandb.Image(grid)
239
+ pl_module.logger.experiment.log(grids)
240
+
241
+ @rank_zero_only
242
+ def _testtube(self, pl_module, images, batch_idx, split):
243
+ for k in images:
244
+ grid = torchvision.utils.make_grid(images[k])
245
+ grid = (grid+1.0)/2.0 # -1,1 -> 0,1; c,h,w
246
+
247
+ tag = f"{split}/{k}"
248
+ pl_module.logger.experiment.add_image(
249
+ tag, grid,
250
+ global_step=pl_module.global_step)
251
+
252
+ @rank_zero_only
253
+ def log_local(self, save_dir, split, images,
254
+ global_step, current_epoch, batch_idx):
255
+ root = os.path.join(save_dir, "images", split)
256
+ for k in images:
257
+ grid = torchvision.utils.make_grid(images[k], nrow=4)
258
+
259
+ grid = (grid+1.0)/2.0 # -1,1 -> 0,1; c,h,w
260
+ grid = grid.transpose(0,1).transpose(1,2).squeeze(-1)
261
+ grid = grid.numpy()
262
+ grid = (grid*255).astype(np.uint8)
263
+ filename = "{}_gs-{:06}_e-{:06}_b-{:06}.png".format(
264
+ k,
265
+ global_step,
266
+ current_epoch,
267
+ batch_idx)
268
+ path = os.path.join(root, filename)
269
+ os.makedirs(os.path.split(path)[0], exist_ok=True)
270
+ Image.fromarray(grid).save(path)
271
+
272
+ def log_img(self, pl_module, batch, batch_idx, split="train"):
273
+ if (self.check_frequency(batch_idx) and # batch_idx % self.batch_freq == 0
274
+ hasattr(pl_module, "log_images") and
275
+ callable(pl_module.log_images) and
276
+ self.max_images > 0):
277
+ logger = type(pl_module.logger)
278
+
279
+ is_train = pl_module.training
280
+ if is_train:
281
+ pl_module.eval()
282
+
283
+ with torch.no_grad():
284
+ images = pl_module.log_images(batch, split=split, pl_module=pl_module)
285
+
286
+ for k in images:
287
+ N = min(images[k].shape[0], self.max_images)
288
+ images[k] = images[k][:N]
289
+ if isinstance(images[k], torch.Tensor):
290
+ images[k] = images[k].detach().cpu()
291
+ if self.clamp:
292
+ images[k] = torch.clamp(images[k], -1., 1.)
293
+
294
+ self.log_local(pl_module.logger.save_dir, split, images,
295
+ pl_module.global_step, pl_module.current_epoch, batch_idx)
296
+
297
+ logger_log_images = self.logger_log_images.get(logger, lambda *args, **kwargs: None)
298
+ logger_log_images(pl_module, images, pl_module.global_step, split)
299
+
300
+ if is_train:
301
+ pl_module.train()
302
+
303
+ def check_frequency(self, batch_idx):
304
+ if (batch_idx % self.batch_freq) == 0 or (batch_idx in self.log_steps):
305
+ try:
306
+ self.log_steps.pop(0)
307
+ except IndexError:
308
+ pass
309
+ return True
310
+ return False
311
+
312
+ def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
313
+ self.log_img(pl_module, batch, batch_idx, split="train")
314
+
315
+ def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
316
+ self.log_img(pl_module, batch, batch_idx, split="val")
317
+
318
+
319
+
320
+ if __name__ == "__main__":
+     # custom parser to specify config files, train, test and debug mode,
+     # postfix, resume.
+     # `--key value` arguments are interpreted as arguments to the trainer.
+     # `nested.key=value` arguments are interpreted as config parameters.
+     # configs are merged from left-to-right followed by command line parameters.
+ 
+     # model:
+     #   base_learning_rate: float
+     #   target: path to lightning module
+     #   params:
+     #     key: value
+     # data:
+     #   target: main.DataModuleFromConfig
+     #   params:
+     #     batch_size: int
+     #     wrap: bool
+     #     train:
+     #       target: path to train dataset
+     #       params:
+     #         key: value
+     #     validation:
+     #       target: path to validation dataset
+     #       params:
+     #         key: value
+     #     test:
+     #       target: path to test dataset
+     #       params:
+     #         key: value
+     # lightning: (optional, has sane defaults and can be specified on cmdline)
+     #   trainer:
+     #     additional arguments to trainer
+     #   logger:
+     #     logger to instantiate
+     #   modelcheckpoint:
+     #     modelcheckpoint to instantiate
+     #   callbacks:
+     #     callback1:
+     #       target: importpath
+     #       params:
+     #         key: value
+ 
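+     # A minimal config following this schema could look like the sketch below.
+     # This is purely illustrative; the values and dataset target are hypothetical
+     # and not part of this commit.
+     #
+     # model:
+     #   base_learning_rate: 4.5e-6
+     #   target: taming.models.vqgan.VQModel
+     #   params:
+     #     embed_dim: 256
+     #     n_embed: 1024
+     # data:
+     #   target: main.DataModuleFromConfig
+     #   params:
+     #     batch_size: 8
+     #     wrap: true
+     #     train:
+     #       target: taming.data.custom.CustomTrain
+     #       params:
+     #         training_images_list_file: some/train_images.txt
+     #         size: 256
+ 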
362
+ now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
363
+
364
+ # add cwd for convenience and to make classes in this file available when
365
+ # running as `python main.py`
366
+ # (in particular `main.DataModuleFromConfig`)
367
+ sys.path.append(os.getcwd())
368
+
369
+ parser = get_parser()
370
+ parser = Trainer.add_argparse_args(parser)
371
+
372
+ opt, unknown = parser.parse_known_args()
373
+ if opt.name and opt.resume:
374
+ raise ValueError(
375
+ "-n/--name and -r/--resume cannot be specified both."
376
+ "If you want to resume training in a new log folder, "
377
+ "use -n/--name in combination with --resume_from_checkpoint"
378
+ )
379
+ if opt.resume:
380
+ if not os.path.exists(opt.resume):
381
+ raise ValueError("Cannot find {}".format(opt.resume))
382
+ if os.path.isfile(opt.resume):
383
+ paths = opt.resume.split("/")
384
+ idx = len(paths)-paths[::-1].index("logs")+1
385
+ logdir = "/".join(paths[:idx])
386
+ ckpt = opt.resume
387
+ else:
388
+ assert os.path.isdir(opt.resume), opt.resume
389
+ logdir = opt.resume.rstrip("/")
390
+ ckpt = os.path.join(logdir, "checkpoints", "last.ckpt")
391
+
392
+ opt.resume_from_checkpoint = ckpt
393
+ base_configs = sorted(glob.glob(os.path.join(logdir, "configs/*.yaml")))
394
+ opt.base = base_configs+opt.base
395
+ _tmp = logdir.split("/")
396
+ nowname = _tmp[_tmp.index("logs")+1]
397
+ else:
398
+ if opt.name:
399
+ name = "_"+opt.name
400
+ elif opt.base:
401
+ cfg_fname = os.path.split(opt.base[0])[-1]
402
+ cfg_name = os.path.splitext(cfg_fname)[0]
403
+ name = "_"+cfg_name
404
+ else:
405
+ name = ""
406
+ nowname = now+name+opt.postfix
407
+ logdir = os.path.join("logs", nowname)
408
+
409
+ ckptdir = os.path.join(logdir, "checkpoints")
410
+ cfgdir = os.path.join(logdir, "configs")
411
+ seed_everything(opt.seed)
412
+
413
+ try:
414
+ # init and save configs
415
+ configs = [OmegaConf.load(cfg) for cfg in opt.base]
416
+ cli = OmegaConf.from_dotlist(unknown)
417
+ config = OmegaConf.merge(*configs, cli)
418
+ lightning_config = config.pop("lightning", OmegaConf.create())
419
+ # merge trainer cli with config
420
+ trainer_config = lightning_config.get("trainer", OmegaConf.create())
421
+ # default to ddp
422
+ trainer_config["distributed_backend"] = "ddp"
423
+ for k in nondefault_trainer_args(opt):
424
+ trainer_config[k] = getattr(opt, k)
425
+ if not "gpus" in trainer_config:
426
+ del trainer_config["distributed_backend"]
427
+ cpu = True
428
+ else:
429
+ gpuinfo = trainer_config["gpus"]
430
+ print(f"Running on GPUs {gpuinfo}")
431
+ cpu = False
432
+ trainer_opt = argparse.Namespace(**trainer_config)
433
+ lightning_config.trainer = trainer_config
434
+
435
+ # model
436
+ model = instantiate_from_config(config.model)
437
+
438
+ # trainer and callbacks
439
+ trainer_kwargs = dict()
440
+
441
+ # default logger configs
442
+ # NOTE wandb < 0.10.0 interferes with shutdown
443
+ # wandb >= 0.10.0 seems to fix it but still interferes with pudb
444
+ # debugging (wrongly sized pudb ui)
445
+ # thus prefer testtube for now
446
+ default_logger_cfgs = {
447
+ "wandb": {
448
+ "target": "pytorch_lightning.loggers.WandbLogger",
449
+ "params": {
450
+ "name": nowname,
451
+ "save_dir": logdir,
452
+ "offline": opt.debug,
453
+ "id": nowname,
454
+ }
455
+ },
456
+ "testtube": {
457
+ "target": "pytorch_lightning.loggers.TestTubeLogger",
458
+ "params": {
459
+ "name": "testtube",
460
+ "save_dir": logdir,
461
+ }
462
+ },
463
+ }
464
+ default_logger_cfg = default_logger_cfgs["testtube"]
465
+ logger_cfg = lightning_config.logger or OmegaConf.create()
466
+ logger_cfg = OmegaConf.merge(default_logger_cfg, logger_cfg)
467
+ trainer_kwargs["logger"] = instantiate_from_config(logger_cfg)
468
+
469
+ # modelcheckpoint - use TrainResult/EvalResult(checkpoint_on=metric) to
470
+ # specify which metric is used to determine best models
471
+ default_modelckpt_cfg = {
472
+ "target": "pytorch_lightning.callbacks.ModelCheckpoint",
473
+ "params": {
474
+ "dirpath": ckptdir,
475
+ "filename": "{epoch:06}",
476
+ "verbose": True,
477
+ "save_last": True,
478
+ }
479
+ }
480
+ if hasattr(model, "monitor"):
481
+ print(f"Monitoring {model.monitor} as checkpoint metric.")
482
+ default_modelckpt_cfg["params"]["monitor"] = model.monitor
483
+ default_modelckpt_cfg["params"]["save_top_k"] = 3
484
+
485
+ modelckpt_cfg = lightning_config.modelcheckpoint or OmegaConf.create()
486
+ modelckpt_cfg = OmegaConf.merge(default_modelckpt_cfg, modelckpt_cfg)
487
+ trainer_kwargs["checkpoint_callback"] = instantiate_from_config(modelckpt_cfg)
488
+
489
+ # add callback which sets up log directory
490
+ default_callbacks_cfg = {
491
+ "setup_callback": {
492
+ "target": "main.SetupCallback",
493
+ "params": {
494
+ "resume": opt.resume,
495
+ "now": now,
496
+ "logdir": logdir,
497
+ "ckptdir": ckptdir,
498
+ "cfgdir": cfgdir,
499
+ "config": config,
500
+ "lightning_config": lightning_config,
501
+ }
502
+ },
503
+ "image_logger": {
504
+ "target": "main.ImageLogger",
505
+ "params": {
506
+ "batch_frequency": 750,
507
+ "max_images": 4,
508
+ "clamp": True
509
+ }
510
+ },
511
+ "learning_rate_logger": {
512
+ "target": "main.LearningRateMonitor",
513
+ "params": {
514
+ "logging_interval": "step",
515
+ #"log_momentum": True
516
+ }
517
+ },
518
+ }
519
+ callbacks_cfg = lightning_config.callbacks or OmegaConf.create()
520
+ callbacks_cfg = OmegaConf.merge(default_callbacks_cfg, callbacks_cfg)
521
+ trainer_kwargs["callbacks"] = [instantiate_from_config(callbacks_cfg[k]) for k in callbacks_cfg]
522
+
523
+ trainer = Trainer.from_argparse_args(trainer_opt, **trainer_kwargs)
524
+
525
+ # data
526
+ data = instantiate_from_config(config.data)
527
+ # NOTE according to https://pytorch-lightning.readthedocs.io/en/latest/datamodules.html
528
+ # calling these ourselves should not be necessary but it is.
529
+ # lightning still takes care of proper multiprocessing though
530
+ data.prepare_data()
531
+ data.setup()
532
+
533
+         # configure learning rate
+         bs, base_lr = config.data.params.batch_size, config.model.base_learning_rate
+         if not cpu:
+             ngpu = len(lightning_config.trainer.gpus.strip(",").split(','))
+         else:
+             ngpu = 1
+         accumulate_grad_batches = lightning_config.trainer.accumulate_grad_batches or 1
+         print(f"accumulate_grad_batches = {accumulate_grad_batches}")
+         lightning_config.trainer.accumulate_grad_batches = accumulate_grad_batches
+         model.learning_rate = accumulate_grad_batches * ngpu * bs * base_lr
+         print("Setting learning rate to {:.2e} = {} (accumulate_grad_batches) * {} (num_gpus) * {} (batchsize) * {:.2e} (base_lr)".format(
+             model.learning_rate, accumulate_grad_batches, ngpu, bs, base_lr))
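+         # Worked example (hypothetical numbers, for illustration only): with
+         # accumulate_grad_batches=2, 4 GPUs, batch_size=8 and base_lr=4.5e-6,
+         # the effective learning rate is 2 * 4 * 8 * 4.5e-6 = 2.88e-4.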
545
+
546
+         # allow checkpointing via USR1
+         def melk(*args, **kwargs):
+             # run all checkpoint hooks
+             if trainer.global_rank == 0:
+                 print("Summoning checkpoint.")
+                 ckpt_path = os.path.join(ckptdir, "last.ckpt")
+                 trainer.save_checkpoint(ckpt_path)
+ 
+         def divein(*args, **kwargs):
+             if trainer.global_rank == 0:
+                 import pudb; pudb.set_trace()
+ 
+         import signal
+         signal.signal(signal.SIGUSR1, melk)
+         signal.signal(signal.SIGUSR2, divein)
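+         # Usage note (illustrative): `kill -USR1 <pid>` invokes melk() and writes
+         # checkpoints/last.ckpt mid-training; `kill -USR2 <pid>` drops into pudb.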
561
+
562
+ # run
563
+ if opt.train:
564
+ try:
565
+ trainer.fit(model, data)
566
+ except Exception:
567
+ melk()
568
+ raise
569
+ if not opt.no_test and not trainer.interrupted:
570
+ trainer.test(model, data)
571
+ except Exception:
572
+ if opt.debug and trainer.global_rank==0:
573
+ try:
574
+ import pudb as debugger
575
+ except ImportError:
576
+ import pdb as debugger
577
+ debugger.post_mortem()
578
+ raise
579
+ finally:
580
+ # move newly created debug project to debug_runs
581
+ if opt.debug and not opt.resume and trainer.global_rank==0:
582
+ dst, name = os.path.split(logdir)
583
+ dst = os.path.join(dst, "debug_runs", name)
584
+ os.makedirs(os.path.split(dst)[0], exist_ok=True)
585
+ os.rename(logdir, dst)
taming-transformers/scripts/extract_depth.py ADDED
@@ -0,0 +1,112 @@
+ import os
+ import torch
+ import numpy as np
+ from tqdm import trange
+ from PIL import Image
+ 
+ 
+ def get_state(gpu):
+     import torch
+     midas = torch.hub.load("intel-isl/MiDaS", "MiDaS")
+     if gpu:
+         midas.cuda()
+     midas.eval()
+ 
+     midas_transforms = torch.hub.load("intel-isl/MiDaS", "transforms")
+     transform = midas_transforms.default_transform
+ 
+     state = {"model": midas,
+              "transform": transform}
+     return state
+ 
+ 
+ def depth_to_rgba(x):
+     # pack a float32 depth map into a uint8 RGBA image by reinterpreting the
+     # four bytes of each float as four channels (lossless round-trip)
+     assert x.dtype == np.float32
+     assert len(x.shape) == 2
+     y = x.copy()
+     y.dtype = np.uint8
+     y = y.reshape(x.shape+(4,))
+     return np.ascontiguousarray(y)
+ 
+ 
+ def rgba_to_depth(x):
+     # inverse of depth_to_rgba: reinterpret the four uint8 channels as float32
+     assert x.dtype == np.uint8
+     assert len(x.shape) == 3 and x.shape[2] == 4
+     y = x.copy()
+     y.dtype = np.float32
+     y = y.reshape(x.shape[:2])
+     return np.ascontiguousarray(y)
+ 
+ 
+ def run(x, state):
+     model = state["model"]
+     transform = state["transform"]
+     hw = x.shape[:2]
+     with torch.no_grad():
+         prediction = model(transform((x + 1.0) * 127.5).cuda())
+         prediction = torch.nn.functional.interpolate(
+             prediction.unsqueeze(1),
+             size=hw,
+             mode="bicubic",
+             align_corners=False,
+         ).squeeze()
+     output = prediction.cpu().numpy()
+     return output
+ 
+ 
+ def get_filename(relpath, level=-2):
+     # save class folder structure and filename:
+     fn = relpath.split(os.sep)[level:]
+     folder = fn[-2]
+     file = fn[-1].split('.')[0]
+     return folder, file
+ 
+ 
+ def save_depth(dataset, path, debug=False):
+     os.makedirs(path)
+     N = len(dataset)
+     if debug:
+         N = 10
+     state = get_state(gpu=True)
+     for idx in trange(N, desc="Data"):
+         ex = dataset[idx]
+         image, relpath = ex["image"], ex["relpath"]
+         folder, filename = get_filename(relpath)
+         # prepare
+         folderabspath = os.path.join(path, folder)
+         os.makedirs(folderabspath, exist_ok=True)
+         savepath = os.path.join(folderabspath, filename)
+         # run model
+         xout = run(image, state)
+         I = depth_to_rgba(xout)
+         Image.fromarray(I).save("{}.png".format(savepath))
+ 
+ 
+ if __name__ == "__main__":
+     from taming.data.imagenet import ImageNetTrain, ImageNetValidation
+     out = "data/imagenet_depth"
+     if not os.path.exists(out):
+         print("Please create a folder or symlink '{}' to extract depth data ".format(out) +
+               "(be prepared that the output size will be larger than ImageNet itself).")
+         exit(1)
+ 
+     # go
+     dset = ImageNetValidation()
+     abspath = os.path.join(out, "val")
+     if os.path.exists(abspath):
+         print("{} exists - not doing anything.".format(abspath))
+     else:
+         print("preparing {}".format(abspath))
+         save_depth(dset, abspath)
+     print("done with validation split")
+ 
+     dset = ImageNetTrain()
+     abspath = os.path.join(out, "train")
+     if os.path.exists(abspath):
+         print("{} exists - not doing anything.".format(abspath))
+     else:
+         print("preparing {}".format(abspath))
+         save_depth(dset, abspath)
+     print("done with train split")
+ 
+     print("done done.")
taming-transformers/scripts/extract_segmentation.py ADDED
@@ -0,0 +1,130 @@
1
+ import sys, os
2
+ import numpy as np
3
+ import scipy
4
+ import torch
5
+ import torch.nn as nn
6
+ from scipy import ndimage
7
+ from tqdm import tqdm, trange
8
+ from PIL import Image
9
+ import torch.hub
10
+ import torchvision
11
+ import torch.nn.functional as F
12
+
13
+ # download deeplabv2_resnet101_msc-cocostuff164k-100000.pth from
14
+ # https://github.com/kazuto1011/deeplab-pytorch/releases/download/v1.0/deeplabv2_resnet101_msc-cocostuff164k-100000.pth
15
+ # and put the path here
16
+ CKPT_PATH = "TODO"
17
+
18
+ rescale = lambda x: (x + 1.) / 2.
19
+
20
+ def rescale_bgr(x):
21
+ x = (x+1)*127.5
22
+ x = torch.flip(x, dims=[0])
23
+ return x
24
+
25
+
26
+ class COCOStuffSegmenter(nn.Module):
27
+ def __init__(self, config):
28
+ super().__init__()
29
+ self.config = config
30
+ self.n_labels = 182
31
+ model = torch.hub.load("kazuto1011/deeplab-pytorch", "deeplabv2_resnet101", n_classes=self.n_labels)
32
+ ckpt_path = CKPT_PATH
33
+ model.load_state_dict(torch.load(ckpt_path))
34
+ self.model = model
35
+
36
+ normalize = torchvision.transforms.Normalize(mean=self.mean, std=self.std)
37
+ self.image_transform = torchvision.transforms.Compose([
38
+ torchvision.transforms.Lambda(lambda image: torch.stack(
39
+ [normalize(rescale_bgr(x)) for x in image]))
40
+ ])
41
+
42
+ def forward(self, x, upsample=None):
43
+ x = self._pre_process(x)
44
+ x = self.model(x)
45
+ if upsample is not None:
46
+ x = torch.nn.functional.upsample_bilinear(x, size=upsample)
47
+ return x
48
+
49
+ def _pre_process(self, x):
50
+ x = self.image_transform(x)
51
+ return x
52
+
53
+ @property
54
+ def mean(self):
55
+ # bgr
56
+ return [104.008, 116.669, 122.675]
57
+
58
+ @property
59
+ def std(self):
60
+ return [1.0, 1.0, 1.0]
61
+
62
+ @property
63
+ def input_size(self):
64
+ return [3, 224, 224]
65
+
66
+
67
+ def run_model(img, model):
68
+ model = model.eval()
69
+ with torch.no_grad():
70
+ segmentation = model(img, upsample=(img.shape[2], img.shape[3]))
71
+ segmentation = torch.argmax(segmentation, dim=1, keepdim=True)
72
+ return segmentation.detach().cpu()
73
+
74
+
75
+ def get_input(batch, k):
76
+ x = batch[k]
77
+ if len(x.shape) == 3:
78
+ x = x[..., None]
79
+ x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format)
80
+ return x.float()
81
+
82
+
83
+ def save_segmentation(segmentation, path):
84
+ # --> class label to uint8, save as png
85
+ os.makedirs(os.path.dirname(path), exist_ok=True)
86
+ assert len(segmentation.shape)==4
87
+ assert segmentation.shape[0]==1
88
+ for seg in segmentation:
89
+ seg = seg.permute(1,2,0).numpy().squeeze().astype(np.uint8)
90
+ seg = Image.fromarray(seg)
91
+ seg.save(path)
92
+
93
+
94
+ def iterate_dataset(dataloader, destpath, model):
95
+ os.makedirs(destpath, exist_ok=True)
96
+ num_processed = 0
97
+ for i, batch in tqdm(enumerate(dataloader), desc="Data"):
98
+ try:
99
+ img = get_input(batch, "image")
100
+ img = img.cuda()
101
+ seg = run_model(img, model)
102
+
103
+ path = batch["relative_file_path_"][0]
104
+ path = os.path.splitext(path)[0]
105
+
106
+ path = os.path.join(destpath, path + ".png")
107
+ save_segmentation(seg, path)
108
+ num_processed += 1
109
+ except Exception as e:
110
+ print(e)
111
+ print("but anyhow..")
112
+
113
+ print("Processed {} files. Bye.".format(num_processed))
114
+
115
+
116
+ from taming.data.sflckr import Examples
117
+ from torch.utils.data import DataLoader
118
+
119
+ if __name__ == "__main__":
120
+ dest = sys.argv[1]
121
+ batchsize = 1
122
+ print("Running with batch-size {}, saving to {}...".format(batchsize, dest))
123
+
124
+ model = COCOStuffSegmenter({}).cuda()
125
+ print("Instantiated model.")
126
+
127
+ dataset = Examples()
128
+ dloader = DataLoader(dataset, batch_size=batchsize)
129
+ iterate_dataset(dataloader=dloader, destpath=dest, model=model)
130
+ print("done.")
taming-transformers/scripts/extract_submodel.py ADDED
@@ -0,0 +1,17 @@
+ import torch
+ import sys
+ 
+ # usage: python scripts/extract_submodel.py <input.ckpt> <output.ckpt> [submodel_key]
+ if __name__ == "__main__":
+     inpath = sys.argv[1]
+     outpath = sys.argv[2]
+     submodel = "cond_stage_model"
+     if len(sys.argv) > 3:
+         submodel = sys.argv[3]
+ 
+     print("Extracting {} from {} to {}.".format(submodel, inpath, outpath))
+ 
+     sd = torch.load(inpath, map_location="cpu")
+     # keep only the keys of the requested submodel and strip its prefix
+     new_sd = {"state_dict": dict((k.split(".", 1)[-1], v)
+                                  for k, v in sd["state_dict"].items()
+                                  if k.startswith(submodel))}
+     torch.save(new_sd, outpath)
taming-transformers/scripts/make_samples.py ADDED
@@ -0,0 +1,292 @@
1
+ import argparse, os, sys, glob, math, time
2
+ import torch
3
+ import numpy as np
4
+ from omegaconf import OmegaConf
5
+ from PIL import Image
6
+ from main import instantiate_from_config, DataModuleFromConfig
7
+ from torch.utils.data import DataLoader
8
+ from torch.utils.data.dataloader import default_collate
9
+ from tqdm import trange
10
+
11
+
12
+ def save_image(x, path):
13
+ c,h,w = x.shape
14
+ assert c==3
15
+ x = ((x.detach().cpu().numpy().transpose(1,2,0)+1.0)*127.5).clip(0,255).astype(np.uint8)
16
+ Image.fromarray(x).save(path)
17
+
18
+
19
+ @torch.no_grad()
20
+ def run_conditional(model, dsets, outdir, top_k, temperature, batch_size=1):
21
+ if len(dsets.datasets) > 1:
22
+ split = sorted(dsets.datasets.keys())[0]
23
+ dset = dsets.datasets[split]
24
+ else:
25
+ dset = next(iter(dsets.datasets.values()))
26
+ print("Dataset: ", dset.__class__.__name__)
27
+ for start_idx in trange(0,len(dset)-batch_size+1,batch_size):
28
+ indices = list(range(start_idx, start_idx+batch_size))
29
+ example = default_collate([dset[i] for i in indices])
30
+
31
+ x = model.get_input("image", example).to(model.device)
32
+ for i in range(x.shape[0]):
33
+ save_image(x[i], os.path.join(outdir, "originals",
34
+ "{:06}.png".format(indices[i])))
35
+
36
+ cond_key = model.cond_stage_key
37
+ c = model.get_input(cond_key, example).to(model.device)
38
+
39
+ scale_factor = 1.0
40
+ quant_z, z_indices = model.encode_to_z(x)
41
+ quant_c, c_indices = model.encode_to_c(c)
42
+
43
+ cshape = quant_z.shape
44
+
45
+ xrec = model.first_stage_model.decode(quant_z)
46
+ for i in range(xrec.shape[0]):
47
+ save_image(xrec[i], os.path.join(outdir, "reconstructions",
48
+ "{:06}.png".format(indices[i])))
49
+
50
+ if cond_key == "segmentation":
51
+ # get image from segmentation mask
52
+ num_classes = c.shape[1]
53
+ c = torch.argmax(c, dim=1, keepdim=True)
54
+ c = torch.nn.functional.one_hot(c, num_classes=num_classes)
55
+ c = c.squeeze(1).permute(0, 3, 1, 2).float()
56
+ c = model.cond_stage_model.to_rgb(c)
57
+
58
+ idx = z_indices
59
+
60
+ half_sample = False
61
+ if half_sample:
62
+ start = idx.shape[1]//2
63
+ else:
64
+ start = 0
65
+
66
+ idx[:,start:] = 0
67
+ idx = idx.reshape(cshape[0],cshape[2],cshape[3])
68
+ start_i = start//cshape[3]
69
+ start_j = start %cshape[3]
70
+
71
+ cidx = c_indices
72
+ cidx = cidx.reshape(quant_c.shape[0],quant_c.shape[2],quant_c.shape[3])
73
+
74
+ sample = True
75
+
76
+ for i in range(start_i,cshape[2]-0):
77
+ if i <= 8:
78
+ local_i = i
79
+ elif cshape[2]-i < 8:
80
+ local_i = 16-(cshape[2]-i)
81
+ else:
82
+ local_i = 8
83
+ for j in range(start_j,cshape[3]-0):
84
+ if j <= 8:
85
+ local_j = j
86
+ elif cshape[3]-j < 8:
87
+ local_j = 16-(cshape[3]-j)
88
+ else:
89
+ local_j = 8
90
+
91
+ i_start = i-local_i
92
+ i_end = i_start+16
93
+ j_start = j-local_j
94
+ j_end = j_start+16
95
+ patch = idx[:,i_start:i_end,j_start:j_end]
96
+ patch = patch.reshape(patch.shape[0],-1)
97
+ cpatch = cidx[:, i_start:i_end, j_start:j_end]
98
+ cpatch = cpatch.reshape(cpatch.shape[0], -1)
99
+ patch = torch.cat((cpatch, patch), dim=1)
100
+ logits,_ = model.transformer(patch[:,:-1])
101
+ logits = logits[:, -256:, :]
102
+ logits = logits.reshape(cshape[0],16,16,-1)
103
+ logits = logits[:,local_i,local_j,:]
104
+
105
+ logits = logits/temperature
106
+
107
+ if top_k is not None:
108
+ logits = model.top_k_logits(logits, top_k)
109
+ # apply softmax to convert to probabilities
110
+ probs = torch.nn.functional.softmax(logits, dim=-1)
111
+ # sample from the distribution or take the most likely
112
+ if sample:
113
+ ix = torch.multinomial(probs, num_samples=1)
114
+ else:
115
+ _, ix = torch.topk(probs, k=1, dim=-1)
116
+ idx[:,i,j] = ix
117
+
118
+ xsample = model.decode_to_img(idx[:,:cshape[2],:cshape[3]], cshape)
119
+ for i in range(xsample.shape[0]):
120
+ save_image(xsample[i], os.path.join(outdir, "samples",
121
+ "{:06}.png".format(indices[i])))
122
+
123
+
124
+ def get_parser():
125
+ parser = argparse.ArgumentParser()
126
+ parser.add_argument(
127
+ "-r",
128
+ "--resume",
129
+ type=str,
130
+ nargs="?",
131
+ help="load from logdir or checkpoint in logdir",
132
+ )
133
+ parser.add_argument(
134
+ "-b",
135
+ "--base",
136
+ nargs="*",
137
+ metavar="base_config.yaml",
138
+ help="paths to base configs. Loaded from left-to-right. "
139
+ "Parameters can be overwritten or added with command-line options of the form `--key value`.",
140
+ default=list(),
141
+ )
142
+ parser.add_argument(
143
+ "-c",
144
+ "--config",
145
+ nargs="?",
146
+ metavar="single_config.yaml",
147
+ help="path to single config. If specified, base configs will be ignored "
148
+ "(except for the last one if left unspecified).",
149
+ const=True,
150
+ default="",
151
+ )
152
+ parser.add_argument(
153
+ "--ignore_base_data",
154
+ action="store_true",
155
+ help="Ignore data specification from base configs. Useful if you want "
156
+ "to specify a custom datasets on the command line.",
157
+ )
158
+ parser.add_argument(
159
+ "--outdir",
160
+ required=True,
161
+ type=str,
162
+ help="Where to write outputs to.",
163
+ )
164
+ parser.add_argument(
165
+ "--top_k",
166
+ type=int,
167
+ default=100,
168
+ help="Sample from among top-k predictions.",
169
+ )
170
+ parser.add_argument(
171
+ "--temperature",
172
+ type=float,
173
+ default=1.0,
174
+ help="Sampling temperature.",
175
+ )
176
+ return parser
177
+
178
+
179
+ def load_model_from_config(config, sd, gpu=True, eval_mode=True):
180
+ if "ckpt_path" in config.params:
181
+ print("Deleting the restore-ckpt path from the config...")
182
+ config.params.ckpt_path = None
183
+ if "downsample_cond_size" in config.params:
184
+ print("Deleting downsample-cond-size from the config and setting factor=0.5 instead...")
185
+ config.params.downsample_cond_size = -1
186
+ config.params["downsample_cond_factor"] = 0.5
187
+ try:
188
+ if "ckpt_path" in config.params.first_stage_config.params:
189
+ config.params.first_stage_config.params.ckpt_path = None
190
+ print("Deleting the first-stage restore-ckpt path from the config...")
191
+ if "ckpt_path" in config.params.cond_stage_config.params:
192
+ config.params.cond_stage_config.params.ckpt_path = None
193
+ print("Deleting the cond-stage restore-ckpt path from the config...")
194
+ except:
195
+ pass
196
+
197
+ model = instantiate_from_config(config)
198
+ if sd is not None:
199
+ missing, unexpected = model.load_state_dict(sd, strict=False)
200
+ print(f"Missing Keys in State Dict: {missing}")
201
+ print(f"Unexpected Keys in State Dict: {unexpected}")
202
+ if gpu:
203
+ model.cuda()
204
+ if eval_mode:
205
+ model.eval()
206
+ return {"model": model}
207
+
208
+
209
+ def get_data(config):
210
+ # get data
211
+ data = instantiate_from_config(config.data)
212
+ data.prepare_data()
213
+ data.setup()
214
+ return data
215
+
216
+
217
+ def load_model_and_dset(config, ckpt, gpu, eval_mode):
218
+ # get data
219
+ dsets = get_data(config) # calls data.config ...
220
+
221
+ # now load the specified checkpoint
222
+ if ckpt:
223
+ pl_sd = torch.load(ckpt, map_location="cpu")
224
+ global_step = pl_sd["global_step"]
225
+ else:
226
+ pl_sd = {"state_dict": None}
227
+ global_step = None
228
+ model = load_model_from_config(config.model,
229
+ pl_sd["state_dict"],
230
+ gpu=gpu,
231
+ eval_mode=eval_mode)["model"]
232
+ return dsets, model, global_step
233
+
234
+
235
+ if __name__ == "__main__":
236
+ sys.path.append(os.getcwd())
237
+
238
+ parser = get_parser()
239
+
240
+ opt, unknown = parser.parse_known_args()
241
+
242
+ ckpt = None
243
+ if opt.resume:
244
+ if not os.path.exists(opt.resume):
245
+ raise ValueError("Cannot find {}".format(opt.resume))
246
+ if os.path.isfile(opt.resume):
247
+ paths = opt.resume.split("/")
248
+ try:
249
+ idx = len(paths)-paths[::-1].index("logs")+1
250
+ except ValueError:
251
+ idx = -2 # take a guess: path/to/logdir/checkpoints/model.ckpt
252
+ logdir = "/".join(paths[:idx])
253
+ ckpt = opt.resume
254
+ else:
255
+ assert os.path.isdir(opt.resume), opt.resume
256
+ logdir = opt.resume.rstrip("/")
257
+ ckpt = os.path.join(logdir, "checkpoints", "last.ckpt")
258
+ print(f"logdir:{logdir}")
259
+ base_configs = sorted(glob.glob(os.path.join(logdir, "configs/*-project.yaml")))
260
+ opt.base = base_configs+opt.base
261
+
262
+ if opt.config:
263
+ if type(opt.config) == str:
264
+ opt.base = [opt.config]
265
+ else:
266
+ opt.base = [opt.base[-1]]
267
+
268
+ configs = [OmegaConf.load(cfg) for cfg in opt.base]
269
+ cli = OmegaConf.from_dotlist(unknown)
270
+ if opt.ignore_base_data:
271
+ for config in configs:
272
+ if hasattr(config, "data"): del config["data"]
273
+ config = OmegaConf.merge(*configs, cli)
274
+
275
+ print(ckpt)
276
+ gpu = True
277
+ eval_mode = True
278
+ show_config = False
279
+ if show_config:
280
+ print(OmegaConf.to_container(config))
281
+
282
+ dsets, model, global_step = load_model_and_dset(config, ckpt, gpu, eval_mode)
283
+ print(f"Global step: {global_step}")
284
+
285
+ outdir = os.path.join(opt.outdir, "{:06}_{}_{}".format(global_step,
286
+ opt.top_k,
287
+ opt.temperature))
288
+ os.makedirs(outdir, exist_ok=True)
289
+ print("Writing samples to ", outdir)
290
+ for k in ["originals", "reconstructions", "samples"]:
291
+ os.makedirs(os.path.join(outdir, k), exist_ok=True)
292
+ run_conditional(model, dsets, outdir, opt.top_k, opt.temperature)
taming-transformers/scripts/make_scene_samples.py ADDED
@@ -0,0 +1,198 @@
1
+ import glob
2
+ import os
3
+ import sys
4
+ from itertools import product
5
+ from pathlib import Path
6
+ from typing import Literal, List, Optional, Tuple
7
+
8
+ import numpy as np
9
+ import torch
10
+ from omegaconf import OmegaConf
11
+ from pytorch_lightning import seed_everything
12
+ from torch import Tensor
13
+ from torchvision.utils import save_image
14
+ from tqdm import tqdm
15
+
16
+ from scripts.make_samples import get_parser, load_model_and_dset
17
+ from taming.data.conditional_builder.objects_center_points import ObjectsCenterPointsConditionalBuilder
18
+ from taming.data.helper_types import BoundingBox, Annotation
19
+ from taming.data.annotated_objects_dataset import AnnotatedObjectsDataset
20
+ from taming.models.cond_transformer import Net2NetTransformer
21
+
22
+ seed_everything(42424242)
23
+ device: Literal['cuda', 'cpu'] = 'cuda'
24
+ first_stage_factor = 16
25
+ trained_on_res = 256
26
+
27
+
28
+ def _helper(coord: int, coord_max: int, coord_window: int) -> (int, int):
29
+ assert 0 <= coord < coord_max
30
+ coord_desired_center = (coord_window - 1) // 2
31
+ return np.clip(coord - coord_desired_center, 0, coord_max - coord_window)
32
+
33
+
34
+ def get_crop_coordinates(x: int, y: int) -> BoundingBox:
35
+ WIDTH, HEIGHT = desired_z_shape[1], desired_z_shape[0]
36
+ x0 = _helper(x, WIDTH, first_stage_factor) / WIDTH
37
+ y0 = _helper(y, HEIGHT, first_stage_factor) / HEIGHT
38
+ w = first_stage_factor / WIDTH
39
+ h = first_stage_factor / HEIGHT
40
+ return x0, y0, w, h
41
+
42
+
43
+ def get_z_indices_crop_out(z_indices: Tensor, predict_x: int, predict_y: int) -> Tensor:
44
+ WIDTH, HEIGHT = desired_z_shape[1], desired_z_shape[0]
45
+ x0 = _helper(predict_x, WIDTH, first_stage_factor)
46
+ y0 = _helper(predict_y, HEIGHT, first_stage_factor)
47
+ no_images = z_indices.shape[0]
48
+ cut_out_1 = z_indices[:, y0:predict_y, x0:x0+first_stage_factor].reshape((no_images, -1))
49
+ cut_out_2 = z_indices[:, predict_y, x0:predict_x]
50
+ return torch.cat((cut_out_1, cut_out_2), dim=1)
51
+
52
+
53
+ @torch.no_grad()
54
+ def sample(model: Net2NetTransformer, annotations: List[Annotation], dataset: AnnotatedObjectsDataset,
55
+ conditional_builder: ObjectsCenterPointsConditionalBuilder, no_samples: int,
56
+ temperature: float, top_k: int) -> Tensor:
57
+ x_max, y_max = desired_z_shape[1], desired_z_shape[0]
58
+
59
+ annotations = [a._replace(category_no=dataset.get_category_number(a.category_id)) for a in annotations]
60
+
61
+ recompute_conditional = any((desired_resolution[0] > trained_on_res, desired_resolution[1] > trained_on_res))
62
+ if not recompute_conditional:
63
+ crop_coordinates = get_crop_coordinates(0, 0)
64
+ conditional_indices = conditional_builder.build(annotations, crop_coordinates)
65
+ c_indices = conditional_indices.to(device).repeat(no_samples, 1)
66
+ z_indices = torch.zeros((no_samples, 0), device=device).long()
67
+ output_indices = model.sample(z_indices, c_indices, steps=x_max*y_max, temperature=temperature,
68
+ sample=True, top_k=top_k)
69
+ else:
70
+ output_indices = torch.zeros((no_samples, y_max, x_max), device=device).long()
71
+ for predict_y, predict_x in tqdm(product(range(y_max), range(x_max)), desc='sampling_image', total=x_max*y_max):
72
+ crop_coordinates = get_crop_coordinates(predict_x, predict_y)
73
+ z_indices = get_z_indices_crop_out(output_indices, predict_x, predict_y)
74
+ conditional_indices = conditional_builder.build(annotations, crop_coordinates)
75
+ c_indices = conditional_indices.to(device).repeat(no_samples, 1)
76
+ new_index = model.sample(z_indices, c_indices, steps=1, temperature=temperature, sample=True, top_k=top_k)
77
+ output_indices[:, predict_y, predict_x] = new_index[:, -1]
78
+ z_shape = (
79
+ no_samples,
80
+ model.first_stage_model.quantize.e_dim, # codebook embed_dim
81
+ desired_z_shape[0], # z_height
82
+ desired_z_shape[1] # z_width
83
+ )
84
+ x_sample = model.decode_to_img(output_indices, z_shape) * 0.5 + 0.5
85
+ x_sample = x_sample.to('cpu')
86
+
87
+ plotter = conditional_builder.plot
88
+ figure_size = (x_sample.shape[2], x_sample.shape[3])
89
+ scene_graph = conditional_builder.build(annotations, (0., 0., 1., 1.))
90
+ plot = plotter(scene_graph, dataset.get_textual_label_for_category_no, figure_size)
91
+ return torch.cat((x_sample, plot.unsqueeze(0)))
92
+
93
+
94
+ def get_resolution(resolution_str: str) -> (Tuple[int, int], Tuple[int, int]):
95
+ if not resolution_str.count(',') == 1:
96
+ raise ValueError("Give resolution as in 'height,width'")
97
+ res_h, res_w = resolution_str.split(',')
98
+ res_h = max(int(res_h), trained_on_res)
99
+ res_w = max(int(res_w), trained_on_res)
100
+ z_h = int(round(res_h/first_stage_factor))
101
+ z_w = int(round(res_w/first_stage_factor))
102
+ return (z_h, z_w), (z_h*first_stage_factor, z_w*first_stage_factor)
103
+
104
+
105
+ def add_arg_to_parser(parser):
106
+ parser.add_argument(
107
+ "-R",
108
+ "--resolution",
109
+ type=str,
110
+ default='256,256',
111
+ help=f"give resolution in multiples of {first_stage_factor}, default is '256,256'",
112
+ )
113
+ parser.add_argument(
114
+ "-C",
115
+ "--conditional",
116
+ type=str,
117
+ default='objects_bbox',
118
+ help=f"objects_bbox or objects_center_points",
119
+ )
120
+ parser.add_argument(
121
+ "-N",
122
+ "--n_samples_per_layout",
123
+ type=int,
124
+ default=4,
125
+ help=f"how many samples to generate per layout",
126
+ )
127
+ return parser
128
+
129
+
130
+ if __name__ == "__main__":
131
+ sys.path.append(os.getcwd())
132
+
133
+ parser = get_parser()
134
+ parser = add_arg_to_parser(parser)
135
+
136
+ opt, unknown = parser.parse_known_args()
137
+
138
+ ckpt = None
139
+ if opt.resume:
140
+ if not os.path.exists(opt.resume):
141
+ raise ValueError("Cannot find {}".format(opt.resume))
142
+ if os.path.isfile(opt.resume):
143
+ paths = opt.resume.split("/")
144
+ try:
145
+ idx = len(paths)-paths[::-1].index("logs")+1
146
+ except ValueError:
147
+ idx = -2 # take a guess: path/to/logdir/checkpoints/model.ckpt
148
+ logdir = "/".join(paths[:idx])
149
+ ckpt = opt.resume
150
+ else:
151
+ assert os.path.isdir(opt.resume), opt.resume
152
+ logdir = opt.resume.rstrip("/")
153
+ ckpt = os.path.join(logdir, "checkpoints", "last.ckpt")
154
+ print(f"logdir:{logdir}")
155
+ base_configs = sorted(glob.glob(os.path.join(logdir, "configs/*-project.yaml")))
156
+ opt.base = base_configs+opt.base
157
+
158
+ if opt.config:
159
+ if type(opt.config) == str:
160
+ opt.base = [opt.config]
161
+ else:
162
+ opt.base = [opt.base[-1]]
163
+
164
+ configs = [OmegaConf.load(cfg) for cfg in opt.base]
165
+ cli = OmegaConf.from_dotlist(unknown)
166
+ if opt.ignore_base_data:
167
+ for config in configs:
168
+ if hasattr(config, "data"):
169
+ del config["data"]
170
+ config = OmegaConf.merge(*configs, cli)
171
+ desired_z_shape, desired_resolution = get_resolution(opt.resolution)
172
+ conditional = opt.conditional
173
+
174
+ print(ckpt)
175
+ gpu = True
176
+ eval_mode = True
177
+ show_config = False
178
+ if show_config:
179
+ print(OmegaConf.to_container(config))
180
+
181
+ dsets, model, global_step = load_model_and_dset(config, ckpt, gpu, eval_mode)
182
+ print(f"Global step: {global_step}")
183
+
184
+ data_loader = dsets.val_dataloader()
185
+ print(dsets.datasets["validation"].conditional_builders)
186
+ conditional_builder = dsets.datasets["validation"].conditional_builders[conditional]
187
+
188
+ outdir = Path(opt.outdir).joinpath(f"{global_step:06}_{opt.top_k}_{opt.temperature}")
189
+ outdir.mkdir(exist_ok=True, parents=True)
190
+ print("Writing samples to ", outdir)
191
+
192
+ p_bar_1 = tqdm(enumerate(iter(data_loader)), desc='batch', total=len(data_loader))
193
+ for batch_no, batch in p_bar_1:
194
+ save_img: Optional[Tensor] = None
195
+ for i, annotations in tqdm(enumerate(batch['annotations']), desc='within_batch', total=data_loader.batch_size):
196
+ imgs = sample(model, annotations, dsets.datasets["validation"], conditional_builder,
197
+ opt.n_samples_per_layout, opt.temperature, opt.top_k)
198
+ save_image(imgs, outdir.joinpath(f'{batch_no:04}_{i:02}.png'), n_row=opt.n_samples_per_layout+1)
taming-transformers/scripts/sample_conditional.py ADDED
@@ -0,0 +1,355 @@
1
+ import argparse, os, sys, glob, math, time
2
+ import torch
3
+ import numpy as np
4
+ from omegaconf import OmegaConf
5
+ import streamlit as st
6
+ from streamlit import caching
7
+ from PIL import Image
8
+ from main import instantiate_from_config, DataModuleFromConfig
9
+ from torch.utils.data import DataLoader
10
+ from torch.utils.data.dataloader import default_collate
11
+
12
+
13
+ rescale = lambda x: (x + 1.) / 2.
14
+
15
+
16
+ def bchw_to_st(x):
17
+ return rescale(x.detach().cpu().numpy().transpose(0,2,3,1))
18
+
19
+ def save_img(xstart, fname):
20
+ I = (xstart.clip(0,1)[0]*255).astype(np.uint8)
21
+ Image.fromarray(I).save(fname)
22
+
23
+
24
+
25
+ def get_interactive_image(resize=False):
26
+ image = st.file_uploader("Input", type=["jpg", "JPEG", "png"])
27
+ if image is not None:
28
+ image = Image.open(image)
29
+ if not image.mode == "RGB":
30
+ image = image.convert("RGB")
31
+ image = np.array(image).astype(np.uint8)
32
+ print("upload image shape: {}".format(image.shape))
33
+ img = Image.fromarray(image)
34
+ if resize:
35
+ img = img.resize((256, 256))
36
+ image = np.array(img)
37
+ return image
38
+
39
+
40
+ def single_image_to_torch(x, permute=True):
41
+ assert x is not None, "Please provide an image through the upload function"
42
+ x = np.array(x)
43
+ x = torch.FloatTensor(x/255.*2. - 1.)[None,...]
44
+ if permute:
45
+ x = x.permute(0, 3, 1, 2)
46
+ return x
47
+
48
+
49
+ def pad_to_M(x, M):
50
+ hp = math.ceil(x.shape[2]/M)*M-x.shape[2]
51
+ wp = math.ceil(x.shape[3]/M)*M-x.shape[3]
52
+ x = torch.nn.functional.pad(x, (0,wp,0,hp,0,0,0,0))
53
+ return x
54
+
55
+ @torch.no_grad()
56
+ def run_conditional(model, dsets):
57
+ if len(dsets.datasets) > 1:
58
+ split = st.sidebar.radio("Split", sorted(dsets.datasets.keys()))
59
+ dset = dsets.datasets[split]
60
+ else:
61
+ dset = next(iter(dsets.datasets.values()))
62
+ batch_size = 1
63
+ start_index = st.sidebar.number_input("Example Index (Size: {})".format(len(dset)), value=0,
64
+ min_value=0,
65
+ max_value=len(dset)-batch_size)
66
+ indices = list(range(start_index, start_index+batch_size))
67
+
68
+ example = default_collate([dset[i] for i in indices])
69
+
70
+ x = model.get_input("image", example).to(model.device)
71
+
72
+ cond_key = model.cond_stage_key
73
+ c = model.get_input(cond_key, example).to(model.device)
74
+
75
+ scale_factor = st.sidebar.slider("Scale Factor", min_value=0.5, max_value=4.0, step=0.25, value=1.00)
76
+ if scale_factor != 1.0:
77
+ x = torch.nn.functional.interpolate(x, scale_factor=scale_factor, mode="bicubic")
78
+ c = torch.nn.functional.interpolate(c, scale_factor=scale_factor, mode="bicubic")
79
+
80
+ quant_z, z_indices = model.encode_to_z(x)
81
+ quant_c, c_indices = model.encode_to_c(c)
82
+
83
+ cshape = quant_z.shape
84
+
85
+ xrec = model.first_stage_model.decode(quant_z)
86
+ st.write("image: {}".format(x.shape))
87
+ st.image(bchw_to_st(x), clamp=True, output_format="PNG")
88
+ st.write("image reconstruction: {}".format(xrec.shape))
89
+ st.image(bchw_to_st(xrec), clamp=True, output_format="PNG")
90
+
91
+ if cond_key == "segmentation":
92
+ # get image from segmentation mask
93
+ num_classes = c.shape[1]
94
+ c = torch.argmax(c, dim=1, keepdim=True)
95
+ c = torch.nn.functional.one_hot(c, num_classes=num_classes)
96
+ c = c.squeeze(1).permute(0, 3, 1, 2).float()
97
+ c = model.cond_stage_model.to_rgb(c)
98
+
99
+ st.write(f"{cond_key}: {tuple(c.shape)}")
100
+ st.image(bchw_to_st(c), clamp=True, output_format="PNG")
101
+
102
+ idx = z_indices
103
+
104
+ half_sample = st.sidebar.checkbox("Image Completion", value=False)
105
+ if half_sample:
106
+ start = idx.shape[1]//2
107
+ else:
108
+ start = 0
109
+
110
+ idx[:,start:] = 0
111
+ idx = idx.reshape(cshape[0],cshape[2],cshape[3])
112
+ start_i = start//cshape[3]
113
+ start_j = start %cshape[3]
114
+
115
+ if not half_sample and quant_z.shape == quant_c.shape:
116
+ st.info("Setting idx to c_indices")
117
+ idx = c_indices.clone().reshape(cshape[0],cshape[2],cshape[3])
118
+
119
+ cidx = c_indices
120
+ cidx = cidx.reshape(quant_c.shape[0],quant_c.shape[2],quant_c.shape[3])
121
+
122
+ xstart = model.decode_to_img(idx[:,:cshape[2],:cshape[3]], cshape)
123
+ st.image(bchw_to_st(xstart), clamp=True, output_format="PNG")
124
+
125
+ temperature = st.number_input("Temperature", value=1.0)
126
+ top_k = st.number_input("Top k", value=100)
127
+ sample = st.checkbox("Sample", value=True)
128
+ update_every = st.number_input("Update every", value=75)
129
+
130
+ st.text(f"Sampling shape ({cshape[2]},{cshape[3]})")
131
+
132
+ animate = st.checkbox("animate")
133
+ if animate:
134
+ import imageio
135
+ outvid = "sampling.mp4"
136
+ writer = imageio.get_writer(outvid, fps=25)
137
+ elapsed_t = st.empty()
138
+ info = st.empty()
139
+ st.text("Sampled")
140
+ if st.button("Sample"):
141
+ output = st.empty()
142
+ start_t = time.time()
143
+ for i in range(start_i,cshape[2]-0):
144
+ if i <= 8:
145
+ local_i = i
146
+ elif cshape[2]-i < 8:
147
+ local_i = 16-(cshape[2]-i)
148
+ else:
149
+ local_i = 8
150
+ for j in range(start_j,cshape[3]-0):
151
+ if j <= 8:
152
+ local_j = j
153
+ elif cshape[3]-j < 8:
154
+ local_j = 16-(cshape[3]-j)
155
+ else:
156
+ local_j = 8
157
+
158
+ i_start = i-local_i
159
+ i_end = i_start+16
160
+ j_start = j-local_j
161
+ j_end = j_start+16
162
+ elapsed_t.text(f"Time: {time.time() - start_t} seconds")
163
+ info.text(f"Step: ({i},{j}) | Local: ({local_i},{local_j}) | Crop: ({i_start}:{i_end},{j_start}:{j_end})")
164
+ patch = idx[:,i_start:i_end,j_start:j_end]
165
+ patch = patch.reshape(patch.shape[0],-1)
166
+ cpatch = cidx[:, i_start:i_end, j_start:j_end]
167
+ cpatch = cpatch.reshape(cpatch.shape[0], -1)
168
+ patch = torch.cat((cpatch, patch), dim=1)
169
+ logits,_ = model.transformer(patch[:,:-1])
170
+ logits = logits[:, -256:, :]
171
+ logits = logits.reshape(cshape[0],16,16,-1)
172
+ logits = logits[:,local_i,local_j,:]
173
+
174
+ logits = logits/temperature
175
+
176
+ if top_k is not None:
177
+ logits = model.top_k_logits(logits, top_k)
178
+ # apply softmax to convert to probabilities
179
+ probs = torch.nn.functional.softmax(logits, dim=-1)
180
+ # sample from the distribution or take the most likely
181
+ if sample:
182
+ ix = torch.multinomial(probs, num_samples=1)
183
+ else:
184
+ _, ix = torch.topk(probs, k=1, dim=-1)
185
+ idx[:,i,j] = ix
186
+
187
+ if (i*cshape[3]+j)%update_every==0:
188
+ xstart = model.decode_to_img(idx[:, :cshape[2], :cshape[3]], cshape,)
189
+
190
+ xstart = bchw_to_st(xstart)
191
+ output.image(xstart, clamp=True, output_format="PNG")
192
+
193
+ if animate:
194
+ writer.append_data((xstart[0]*255).clip(0, 255).astype(np.uint8))
195
+
196
+ xstart = model.decode_to_img(idx[:,:cshape[2],:cshape[3]], cshape)
197
+ xstart = bchw_to_st(xstart)
198
+ output.image(xstart, clamp=True, output_format="PNG")
199
+ #save_img(xstart, "full_res_sample.png")
200
+ if animate:
201
+ writer.close()
202
+ st.video(outvid)
203
+
204
+
205
+ def get_parser():
206
+ parser = argparse.ArgumentParser()
207
+ parser.add_argument(
208
+ "-r",
209
+ "--resume",
210
+ type=str,
211
+ nargs="?",
212
+ help="load from logdir or checkpoint in logdir",
213
+ )
214
+ parser.add_argument(
215
+ "-b",
216
+ "--base",
217
+ nargs="*",
218
+ metavar="base_config.yaml",
219
+ help="paths to base configs. Loaded from left-to-right. "
220
+ "Parameters can be overwritten or added with command-line options of the form `--key value`.",
221
+ default=list(),
222
+ )
223
+ parser.add_argument(
224
+ "-c",
225
+ "--config",
226
+ nargs="?",
227
+ metavar="single_config.yaml",
228
+ help="path to single config. If specified, base configs will be ignored "
229
+ "(except for the last one if left unspecified).",
230
+ const=True,
231
+ default="",
232
+ )
233
+ parser.add_argument(
234
+ "--ignore_base_data",
235
+ action="store_true",
236
+ help="Ignore data specification from base configs. Useful if you want "
237
+ "to specify a custom datasets on the command line.",
238
+ )
239
+ return parser
240
+
241
+
242
+ def load_model_from_config(config, sd, gpu=True, eval_mode=True):
243
+ if "ckpt_path" in config.params:
244
+ st.warning("Deleting the restore-ckpt path from the config...")
245
+ config.params.ckpt_path = None
246
+ if "downsample_cond_size" in config.params:
247
+ st.warning("Deleting downsample-cond-size from the config and setting factor=0.5 instead...")
248
+ config.params.downsample_cond_size = -1
249
+ config.params["downsample_cond_factor"] = 0.5
250
+ try:
251
+ if "ckpt_path" in config.params.first_stage_config.params:
252
+ config.params.first_stage_config.params.ckpt_path = None
253
+ st.warning("Deleting the first-stage restore-ckpt path from the config...")
254
+ if "ckpt_path" in config.params.cond_stage_config.params:
255
+ config.params.cond_stage_config.params.ckpt_path = None
256
+ st.warning("Deleting the cond-stage restore-ckpt path from the config...")
257
+ except:
258
+ pass
259
+
260
+ model = instantiate_from_config(config)
261
+ if sd is not None:
262
+ missing, unexpected = model.load_state_dict(sd, strict=False)
263
+ st.info(f"Missing Keys in State Dict: {missing}")
264
+ st.info(f"Unexpected Keys in State Dict: {unexpected}")
265
+ if gpu:
266
+ model.cuda()
267
+ if eval_mode:
268
+ model.eval()
269
+ return {"model": model}
270
+
271
+
272
+ def get_data(config):
273
+ # get data
274
+ data = instantiate_from_config(config.data)
275
+ data.prepare_data()
276
+ data.setup()
277
+ return data
278
+
279
+
280
+ @st.cache(allow_output_mutation=True, suppress_st_warning=True)
281
+ def load_model_and_dset(config, ckpt, gpu, eval_mode):
282
+ # get data
283
+ dsets = get_data(config) # calls data.config ...
284
+
285
+ # now load the specified checkpoint
286
+ if ckpt:
287
+ pl_sd = torch.load(ckpt, map_location="cpu")
288
+ global_step = pl_sd["global_step"]
289
+ else:
290
+ pl_sd = {"state_dict": None}
291
+ global_step = None
292
+ model = load_model_from_config(config.model,
293
+ pl_sd["state_dict"],
294
+ gpu=gpu,
295
+ eval_mode=eval_mode)["model"]
296
+ return dsets, model, global_step
297
+
298
+
299
+ if __name__ == "__main__":
300
+ sys.path.append(os.getcwd())
301
+
302
+ parser = get_parser()
303
+
304
+ opt, unknown = parser.parse_known_args()
305
+
306
+ ckpt = None
307
+ if opt.resume:
308
+ if not os.path.exists(opt.resume):
309
+ raise ValueError("Cannot find {}".format(opt.resume))
310
+ if os.path.isfile(opt.resume):
311
+ paths = opt.resume.split("/")
312
+ try:
313
+ idx = len(paths)-paths[::-1].index("logs")+1
314
+ except ValueError:
315
+ idx = -2 # take a guess: path/to/logdir/checkpoints/model.ckpt
316
+ logdir = "/".join(paths[:idx])
317
+ ckpt = opt.resume
318
+ else:
319
+ assert os.path.isdir(opt.resume), opt.resume
320
+ logdir = opt.resume.rstrip("/")
321
+ ckpt = os.path.join(logdir, "checkpoints", "last.ckpt")
322
+ print(f"logdir:{logdir}")
323
+ base_configs = sorted(glob.glob(os.path.join(logdir, "configs/*-project.yaml")))
324
+ opt.base = base_configs+opt.base
325
+
326
+ if opt.config:
327
+ if type(opt.config) == str:
328
+ opt.base = [opt.config]
329
+ else:
330
+ opt.base = [opt.base[-1]]
331
+
332
+ configs = [OmegaConf.load(cfg) for cfg in opt.base]
333
+ cli = OmegaConf.from_dotlist(unknown)
334
+ if opt.ignore_base_data:
335
+ for config in configs:
336
+ if hasattr(config, "data"): del config["data"]
337
+ config = OmegaConf.merge(*configs, cli)
338
+
339
+ st.sidebar.text(ckpt)
340
+ gs = st.sidebar.empty()
341
+ gs.text(f"Global step: ?")
342
+ st.sidebar.text("Options")
343
+ #gpu = st.sidebar.checkbox("GPU", value=True)
344
+ gpu = True
345
+ #eval_mode = st.sidebar.checkbox("Eval Mode", value=True)
346
+ eval_mode = True
347
+ #show_config = st.sidebar.checkbox("Show Config", value=False)
348
+ show_config = False
349
+ if show_config:
350
+ st.info("Checkpoint: {}".format(ckpt))
351
+ st.json(OmegaConf.to_container(config))
352
+
353
+ dsets, model, global_step = load_model_and_dset(config, ckpt, gpu, eval_mode)
354
+ gs.text(f"Global step: {global_step}")
355
+ run_conditional(model, dsets)
taming-transformers/scripts/sample_fast.py ADDED
@@ -0,0 +1,260 @@
1
+ import argparse, os, sys, glob
2
+ import torch
3
+ import time
4
+ import numpy as np
5
+ from omegaconf import OmegaConf
6
+ from PIL import Image
7
+ from tqdm import tqdm, trange
8
+ from einops import repeat
9
+
10
+ from main import instantiate_from_config
11
+ from taming.modules.transformer.mingpt import sample_with_past
12
+
13
+
14
+ rescale = lambda x: (x + 1.) / 2.
15
+
16
+
17
+ def chw_to_pillow(x):
18
+ return Image.fromarray((255*rescale(x.detach().cpu().numpy().transpose(1,2,0))).clip(0,255).astype(np.uint8))
19
+
20
+
21
+ @torch.no_grad()
22
+ def sample_classconditional(model, batch_size, class_label, steps=256, temperature=None, top_k=None, callback=None,
23
+ dim_z=256, h=16, w=16, verbose_time=False, top_p=None):
24
+ log = dict()
25
+ assert type(class_label) == int, f'expecting type int but type is {type(class_label)}'
26
+ qzshape = [batch_size, dim_z, h, w]
27
+ assert not model.be_unconditional, 'Expecting a class-conditional Net2NetTransformer.'
28
+ c_indices = repeat(torch.tensor([class_label]), '1 -> b 1', b=batch_size).to(model.device) # class token
29
+ t1 = time.time()
30
+ index_sample = sample_with_past(c_indices, model.transformer, steps=steps,
31
+ sample_logits=True, top_k=top_k, callback=callback,
32
+ temperature=temperature, top_p=top_p)
33
+ if verbose_time:
34
+ sampling_time = time.time() - t1
35
+ print(f"Full sampling takes about {sampling_time:.2f} seconds.")
36
+ x_sample = model.decode_to_img(index_sample, qzshape)
37
+ log["samples"] = x_sample
38
+ log["class_label"] = c_indices
39
+ return log
40
+
41
+
42
+ @torch.no_grad()
43
+ def sample_unconditional(model, batch_size, steps=256, temperature=None, top_k=None, top_p=None, callback=None,
44
+ dim_z=256, h=16, w=16, verbose_time=False):
45
+ log = dict()
46
+ qzshape = [batch_size, dim_z, h, w]
47
+ assert model.be_unconditional, 'Expecting an unconditional model.'
48
+ c_indices = repeat(torch.tensor([model.sos_token]), '1 -> b 1', b=batch_size).to(model.device) # sos token
49
+ t1 = time.time()
50
+ index_sample = sample_with_past(c_indices, model.transformer, steps=steps,
51
+ sample_logits=True, top_k=top_k, callback=callback,
52
+ temperature=temperature, top_p=top_p)
53
+ if verbose_time:
54
+ sampling_time = time.time() - t1
55
+ print(f"Full sampling takes about {sampling_time:.2f} seconds.")
56
+ x_sample = model.decode_to_img(index_sample, qzshape)
57
+ log["samples"] = x_sample
58
+ return log
59
+
60
+
61
+ @torch.no_grad()
62
+ def run(logdir, model, batch_size, temperature, top_k, unconditional=True, num_samples=50000,
63
+ given_classes=None, top_p=None):
64
+ batches = [batch_size for _ in range(num_samples//batch_size)] + [num_samples % batch_size]
65
+ if not unconditional:
66
+ assert given_classes is not None
67
+ print("Running in pure class-conditional sampling mode. I will produce "
68
+ f"{num_samples} samples for each of the {len(given_classes)} classes, "
69
+ f"i.e. {num_samples*len(given_classes)} in total.")
70
+ for class_label in tqdm(given_classes, desc="Classes"):
71
+ for n, bs in tqdm(enumerate(batches), desc="Sampling Class"):
72
+ if bs == 0: break
73
+ logs = sample_classconditional(model, batch_size=bs, class_label=class_label,
74
+ temperature=temperature, top_k=top_k, top_p=top_p)
75
+ save_from_logs(logs, logdir, base_count=n * batch_size, cond_key=logs["class_label"])
76
+ else:
77
+ print(f"Running in unconditional sampling mode, producing {num_samples} samples.")
78
+ for n, bs in tqdm(enumerate(batches), desc="Sampling"):
79
+ if bs == 0: break
80
+ logs = sample_unconditional(model, batch_size=bs, temperature=temperature, top_k=top_k, top_p=top_p)
81
+ save_from_logs(logs, logdir, base_count=n * batch_size)
82
+
83
+
84
+ def save_from_logs(logs, logdir, base_count, key="samples", cond_key=None):
85
+ xx = logs[key]
86
+ for i, x in enumerate(xx):
87
+ x = chw_to_pillow(x)
88
+ count = base_count + i
89
+ if cond_key is None:
90
+ x.save(os.path.join(logdir, f"{count:06}.png"))
91
+ else:
92
+ condlabel = cond_key[i]
93
+ if type(condlabel) == torch.Tensor: condlabel = condlabel.item()
94
+ os.makedirs(os.path.join(logdir, str(condlabel)), exist_ok=True)
95
+ x.save(os.path.join(logdir, str(condlabel), f"{count:06}.png"))
96
+
97
+
98
+ def get_parser():
99
+ def str2bool(v):
100
+ if isinstance(v, bool):
101
+ return v
102
+ if v.lower() in ("yes", "true", "t", "y", "1"):
103
+ return True
104
+ elif v.lower() in ("no", "false", "f", "n", "0"):
105
+ return False
106
+ else:
107
+ raise argparse.ArgumentTypeError("Boolean value expected.")
108
+
109
+ parser = argparse.ArgumentParser()
110
+ parser.add_argument(
111
+ "-r",
112
+ "--resume",
113
+ type=str,
114
+ nargs="?",
115
+ help="load from logdir or checkpoint in logdir",
116
+ )
117
+ parser.add_argument(
118
+ "-o",
119
+ "--outdir",
120
+ type=str,
121
+ nargs="?",
122
+ help="path where the samples will be logged to.",
123
+ default=""
124
+ )
125
+ parser.add_argument(
126
+ "-b",
127
+ "--base",
128
+ nargs="*",
129
+ metavar="base_config.yaml",
130
+ help="paths to base configs. Loaded from left-to-right. "
131
+ "Parameters can be overwritten or added with command-line options of the form `--key value`.",
132
+ default=list(),
133
+ )
134
+ parser.add_argument(
135
+ "-n",
136
+ "--num_samples",
137
+ type=int,
138
+ nargs="?",
139
+ help="num_samples to draw",
140
+ default=50000
141
+ )
142
+ parser.add_argument(
143
+ "--batch_size",
144
+ type=int,
145
+ nargs="?",
146
+ help="the batch size",
147
+ default=25
148
+ )
149
+ parser.add_argument(
150
+ "-k",
151
+ "--top_k",
152
+ type=int,
153
+ nargs="?",
154
+ help="top-k value to sample with",
155
+ default=250,
156
+ )
157
+ parser.add_argument(
158
+ "-t",
159
+ "--temperature",
160
+ type=float,
161
+ nargs="?",
162
+ help="temperature value to sample with",
163
+ default=1.0
164
+ )
165
+ parser.add_argument(
166
+ "-p",
167
+ "--top_p",
168
+ type=float,
169
+ nargs="?",
170
+ help="top-p value to sample with",
171
+ default=1.0
172
+ )
173
+ parser.add_argument(
174
+ "--classes",
175
+ type=str,
176
+ nargs="?",
177
+ help="specify comma-separated classes to sample from. Uses 1000 classes per default.",
178
+ default="imagenet"
179
+ )
180
+ return parser
181
+
182
+
183
+ def load_model_from_config(config, sd, gpu=True, eval_mode=True):
184
+ model = instantiate_from_config(config)
185
+ if sd is not None:
186
+ model.load_state_dict(sd)
187
+ if gpu:
188
+ model.cuda()
189
+ if eval_mode:
190
+ model.eval()
191
+ return {"model": model}
192
+
193
+
194
+ def load_model(config, ckpt, gpu, eval_mode):
195
+ # load the specified checkpoint
196
+ if ckpt:
197
+ pl_sd = torch.load(ckpt, map_location="cpu")
198
+ global_step = pl_sd["global_step"]
199
+ print(f"loaded model from global step {global_step}.")
200
+ else:
201
+ pl_sd = {"state_dict": None}
202
+ global_step = None
203
+ model = load_model_from_config(config.model, pl_sd["state_dict"], gpu=gpu, eval_mode=eval_mode)["model"]
204
+ return model, global_step
205
+
206
+
207
+ if __name__ == "__main__":
208
+ sys.path.append(os.getcwd())
209
+ parser = get_parser()
210
+
211
+ opt, unknown = parser.parse_known_args()
212
+ assert opt.resume
213
+
214
+ ckpt = None
215
+
216
+ if not os.path.exists(opt.resume):
217
+ raise ValueError("Cannot find {}".format(opt.resume))
218
+ if os.path.isfile(opt.resume):
219
+ paths = opt.resume.split("/")
220
+ try:
221
+ idx = len(paths)-paths[::-1].index("logs")+1
222
+ except ValueError:
223
+ idx = -2 # take a guess: path/to/logdir/checkpoints/model.ckpt
224
+ logdir = "/".join(paths[:idx])
225
+ ckpt = opt.resume
226
+ else:
227
+ assert os.path.isdir(opt.resume), opt.resume
228
+ logdir = opt.resume.rstrip("/")
229
+ ckpt = os.path.join(logdir, "checkpoints", "last.ckpt")
230
+
231
+ base_configs = sorted(glob.glob(os.path.join(logdir, "configs/*-project.yaml")))
232
+ opt.base = base_configs+opt.base
233
+
234
+ configs = [OmegaConf.load(cfg) for cfg in opt.base]
235
+ cli = OmegaConf.from_dotlist(unknown)
236
+ config = OmegaConf.merge(*configs, cli)
237
+
238
+ model, global_step = load_model(config, ckpt, gpu=True, eval_mode=True)
239
+
240
+ if opt.outdir:
241
+ print(f"Switching logdir from '{logdir}' to '{opt.outdir}'")
242
+ logdir = opt.outdir
243
+
244
+ if opt.classes == "imagenet":
245
+ given_classes = [i for i in range(1000)]
246
+ else:
247
+ cls_str = opt.classes
248
+ assert not cls_str.endswith(","), 'class string should not end with a ","'
249
+ given_classes = [int(c) for c in cls_str.split(",")]
250
+
251
+ logdir = os.path.join(logdir, "samples", f"top_k_{opt.top_k}_temp_{opt.temperature:.2f}_top_p_{opt.top_p}",
252
+ f"{global_step}")
253
+
254
+ print(f"Logging to {logdir}")
255
+ os.makedirs(logdir, exist_ok=True)
256
+
257
+ run(logdir, model, opt.batch_size, opt.temperature, opt.top_k, unconditional=model.be_unconditional,
258
+ given_classes=given_classes, num_samples=opt.num_samples, top_p=opt.top_p)
259
+
260
+ print("done.")
taming-transformers/setup.py ADDED
@@ -0,0 +1,13 @@
1
+ from setuptools import setup, find_packages
2
+
3
+ setup(
4
+ name='taming-transformers',
5
+ version='0.0.1',
6
+ description='Taming Transformers for High-Resolution Image Synthesis',
7
+ packages=find_packages(),
8
+ install_requires=[
9
+ 'torch',
10
+ 'numpy',
11
+ 'tqdm',
12
+ ],
13
+ )
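A conventional way to make the taming package importable elsewhere is an editable install from the directory containing this setup.py (a common workflow, not something this commit prescribes):

    pip install -e taming-transformers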
taming-transformers/taming/lr_scheduler.py ADDED
@@ -0,0 +1,34 @@
1
+ import numpy as np
2
+
3
+
4
+ class LambdaWarmUpCosineScheduler:
5
+ """
6
+ note: use with a base_lr of 1.0
7
+ """
8
+ def __init__(self, warm_up_steps, lr_min, lr_max, lr_start, max_decay_steps, verbosity_interval=0):
9
+ self.lr_warm_up_steps = warm_up_steps
10
+ self.lr_start = lr_start
11
+ self.lr_min = lr_min
12
+ self.lr_max = lr_max
13
+ self.lr_max_decay_steps = max_decay_steps
14
+ self.last_lr = 0.
15
+ self.verbosity_interval = verbosity_interval
16
+
17
+ def schedule(self, n):
18
+ if self.verbosity_interval > 0:
19
+ if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_lr}")
20
+ if n < self.lr_warm_up_steps:
21
+ lr = (self.lr_max - self.lr_start) / self.lr_warm_up_steps * n + self.lr_start
22
+ self.last_lr = lr
23
+ return lr
24
+ else:
25
+ t = (n - self.lr_warm_up_steps) / (self.lr_max_decay_steps - self.lr_warm_up_steps)
26
+ t = min(t, 1.0)
27
+ lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * (
28
+ 1 + np.cos(t * np.pi))
29
+ self.last_lr = lr
30
+ return lr
31
+
32
+ def __call__(self, n):
33
+ return self.schedule(n)
34
+
taming-transformers/taming/models/cond_transformer.py ADDED
@@ -0,0 +1,352 @@
1
+ import os, math
2
+ import torch
3
+ import torch.nn.functional as F
4
+ import pytorch_lightning as pl
5
+
6
+ from main import instantiate_from_config
7
+ from taming.modules.util import SOSProvider
8
+
9
+
10
+ def disabled_train(self, mode=True):
11
+ """Overwrite model.train with this function to make sure train/eval mode
12
+ does not change anymore."""
13
+ return self
14
+
15
+
16
+ class Net2NetTransformer(pl.LightningModule):
17
+ def __init__(self,
18
+ transformer_config,
19
+ first_stage_config,
20
+ cond_stage_config,
21
+ permuter_config=None,
22
+ ckpt_path=None,
23
+ ignore_keys=[],
24
+ first_stage_key="image",
25
+ cond_stage_key="depth",
26
+ downsample_cond_size=-1,
27
+ pkeep=1.0,
28
+ sos_token=0,
29
+ unconditional=False,
30
+ ):
31
+ super().__init__()
32
+ self.be_unconditional = unconditional
33
+ self.sos_token = sos_token
34
+ self.first_stage_key = first_stage_key
35
+ self.cond_stage_key = cond_stage_key
36
+ self.init_first_stage_from_ckpt(first_stage_config)
37
+ self.init_cond_stage_from_ckpt(cond_stage_config)
38
+ if permuter_config is None:
39
+ permuter_config = {"target": "taming.modules.transformer.permuter.Identity"}
40
+ self.permuter = instantiate_from_config(config=permuter_config)
41
+ self.transformer = instantiate_from_config(config=transformer_config)
42
+
43
+ if ckpt_path is not None:
44
+ self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
45
+ self.downsample_cond_size = downsample_cond_size
46
+ self.pkeep = pkeep
47
+
48
+ def init_from_ckpt(self, path, ignore_keys=list()):
49
+ sd = torch.load(path, map_location="cpu")["state_dict"]
50
+ for k in sd.keys():
51
+ for ik in ignore_keys:
52
+ if k.startswith(ik):
53
+ self.print("Deleting key {} from state_dict.".format(k))
54
+ del sd[k]
55
+ self.load_state_dict(sd, strict=False)
56
+ print(f"Restored from {path}")
57
+
58
+ def init_first_stage_from_ckpt(self, config):
59
+ model = instantiate_from_config(config)
60
+ model = model.eval()
61
+ model.train = disabled_train
62
+ self.first_stage_model = model
63
+
64
+ def init_cond_stage_from_ckpt(self, config):
65
+ if config == "__is_first_stage__":
66
+ print("Using first stage also as cond stage.")
67
+ self.cond_stage_model = self.first_stage_model
68
+ elif config == "__is_unconditional__" or self.be_unconditional:
69
+ print(f"Using no cond stage. Assuming the training is intended to be unconditional. "
70
+ f"Prepending {self.sos_token} as a sos token.")
71
+ self.be_unconditional = True
72
+ self.cond_stage_key = self.first_stage_key
73
+ self.cond_stage_model = SOSProvider(self.sos_token)
74
+ else:
75
+ model = instantiate_from_config(config)
76
+ model = model.eval()
77
+ model.train = disabled_train
78
+ self.cond_stage_model = model
79
+
80
+ def forward(self, x, c):
81
+ # one step to produce the logits
82
+ _, z_indices = self.encode_to_z(x)
83
+ _, c_indices = self.encode_to_c(c)
84
+
85
+ if self.training and self.pkeep < 1.0:
86
+ mask = torch.bernoulli(self.pkeep*torch.ones(z_indices.shape,
87
+ device=z_indices.device))
88
+ mask = mask.round().to(dtype=torch.int64)
89
+ r_indices = torch.randint_like(z_indices, self.transformer.config.vocab_size)
90
+ a_indices = mask*z_indices+(1-mask)*r_indices
91
+ else:
92
+ a_indices = z_indices
93
+
94
+ cz_indices = torch.cat((c_indices, a_indices), dim=1)
95
+
96
+ # target includes all sequence elements (no need to handle first one
97
+ # differently because we are conditioning)
98
+ target = z_indices
99
+ # make the prediction
100
+ logits, _ = self.transformer(cz_indices[:, :-1])
101
+ # cut off conditioning outputs - output i corresponds to p(z_i | z_{<i}, c)
102
+ logits = logits[:, c_indices.shape[1]-1:]
103
+
104
+ return logits, target
105
+
106
+ def top_k_logits(self, logits, k):
107
+ v, ix = torch.topk(logits, k)
108
+ out = logits.clone()
109
+ out[out < v[..., [-1]]] = -float('Inf')
110
+ return out
111
+
112
+ @torch.no_grad()
113
+ def sample(self, x, c, steps, temperature=1.0, sample=False, top_k=None,
114
+ callback=lambda k: None):
115
+ x = torch.cat((c,x),dim=1)
116
+ block_size = self.transformer.get_block_size()
117
+ assert not self.transformer.training
118
+ if self.pkeep <= 0.0:
119
+ # one pass suffices since input is pure noise anyway
120
+ assert len(x.shape)==2
121
+ noise_shape = (x.shape[0], steps-1)
122
+ #noise = torch.randint(self.transformer.config.vocab_size, noise_shape).to(x)
123
+ noise = c.clone()[:,x.shape[1]-c.shape[1]:-1]
124
+ x = torch.cat((x,noise),dim=1)
125
+ logits, _ = self.transformer(x)
126
+ # take all logits for now and scale by temp
127
+ logits = logits / temperature
128
+ # optionally crop probabilities to only the top k options
129
+ if top_k is not None:
130
+ logits = self.top_k_logits(logits, top_k)
131
+ # apply softmax to convert to probabilities
132
+ probs = F.softmax(logits, dim=-1)
133
+ # sample from the distribution or take the most likely
134
+ if sample:
135
+ shape = probs.shape
136
+ probs = probs.reshape(shape[0]*shape[1],shape[2])
137
+ ix = torch.multinomial(probs, num_samples=1)
138
+ probs = probs.reshape(shape[0],shape[1],shape[2])
139
+ ix = ix.reshape(shape[0],shape[1])
140
+ else:
141
+ _, ix = torch.topk(probs, k=1, dim=-1)
142
+ # cut off conditioning
143
+ x = ix[:, c.shape[1]-1:]
144
+ else:
145
+ for k in range(steps):
146
+ callback(k)
147
+ assert x.size(1) <= block_size # make sure model can see conditioning
148
+ x_cond = x if x.size(1) <= block_size else x[:, -block_size:] # crop context if needed
149
+ logits, _ = self.transformer(x_cond)
150
+ # pluck the logits at the final step and scale by temperature
151
+ logits = logits[:, -1, :] / temperature
152
+ # optionally crop probabilities to only the top k options
153
+ if top_k is not None:
154
+ logits = self.top_k_logits(logits, top_k)
155
+ # apply softmax to convert to probabilities
156
+ probs = F.softmax(logits, dim=-1)
157
+ # sample from the distribution or take the most likely
158
+ if sample:
159
+ ix = torch.multinomial(probs, num_samples=1)
160
+ else:
161
+ _, ix = torch.topk(probs, k=1, dim=-1)
162
+ # append to the sequence and continue
163
+ x = torch.cat((x, ix), dim=1)
164
+ # cut off conditioning
165
+ x = x[:, c.shape[1]:]
166
+ return x
167
+
168
+ @torch.no_grad()
169
+ def encode_to_z(self, x):
170
+ quant_z, _, info = self.first_stage_model.encode(x)
171
+ indices = info[2].view(quant_z.shape[0], -1)
172
+ indices = self.permuter(indices)
173
+ return quant_z, indices
174
+
175
+ @torch.no_grad()
176
+ def encode_to_c(self, c):
177
+ if self.downsample_cond_size > -1:
178
+ c = F.interpolate(c, size=(self.downsample_cond_size, self.downsample_cond_size))
179
+ quant_c, _, [_,_,indices] = self.cond_stage_model.encode(c)
180
+ if len(indices.shape) > 2:
181
+ indices = indices.view(c.shape[0], -1)
182
+ return quant_c, indices
183
+
184
+ @torch.no_grad()
185
+ def decode_to_img(self, index, zshape):
186
+ index = self.permuter(index, reverse=True)
187
+ bhwc = (zshape[0],zshape[2],zshape[3],zshape[1])
188
+ quant_z = self.first_stage_model.quantize.get_codebook_entry(
189
+ index.reshape(-1), shape=bhwc)
190
+ x = self.first_stage_model.decode(quant_z)
191
+ return x
192
+
193
+ @torch.no_grad()
194
+ def log_images(self, batch, temperature=None, top_k=None, callback=None, lr_interface=False, **kwargs):
195
+ log = dict()
196
+
197
+ N = 4
198
+ if lr_interface:
199
+ x, c = self.get_xc(batch, N, diffuse=False, upsample_factor=8)
200
+ else:
201
+ x, c = self.get_xc(batch, N)
202
+ x = x.to(device=self.device)
203
+ c = c.to(device=self.device)
204
+
205
+ quant_z, z_indices = self.encode_to_z(x)
206
+ quant_c, c_indices = self.encode_to_c(c)
207
+
208
+ # create a "half" sample
209
+ z_start_indices = z_indices[:,:z_indices.shape[1]//2]
210
+ index_sample = self.sample(z_start_indices, c_indices,
211
+ steps=z_indices.shape[1]-z_start_indices.shape[1],
212
+ temperature=temperature if temperature is not None else 1.0,
213
+ sample=True,
214
+ top_k=top_k if top_k is not None else 100,
215
+ callback=callback if callback is not None else lambda k: None)
216
+ x_sample = self.decode_to_img(index_sample, quant_z.shape)
217
+
218
+ # sample
219
+ z_start_indices = z_indices[:, :0]
220
+ index_sample = self.sample(z_start_indices, c_indices,
221
+ steps=z_indices.shape[1],
222
+ temperature=temperature if temperature is not None else 1.0,
223
+ sample=True,
224
+ top_k=top_k if top_k is not None else 100,
225
+ callback=callback if callback is not None else lambda k: None)
226
+ x_sample_nopix = self.decode_to_img(index_sample, quant_z.shape)
227
+
228
+ # det sample
229
+ z_start_indices = z_indices[:, :0]
230
+ index_sample = self.sample(z_start_indices, c_indices,
231
+ steps=z_indices.shape[1],
232
+ sample=False,
233
+ callback=callback if callback is not None else lambda k: None)
234
+ x_sample_det = self.decode_to_img(index_sample, quant_z.shape)
235
+
236
+ # reconstruction
237
+ x_rec = self.decode_to_img(z_indices, quant_z.shape)
238
+
239
+ log["inputs"] = x
240
+ log["reconstructions"] = x_rec
241
+
242
+ if self.cond_stage_key in ["objects_bbox", "objects_center_points"]:
243
+ figure_size = (x_rec.shape[2], x_rec.shape[3])
244
+ dataset = kwargs["pl_module"].trainer.datamodule.datasets["validation"]
245
+ label_for_category_no = dataset.get_textual_label_for_category_no
246
+ plotter = dataset.conditional_builders[self.cond_stage_key].plot
247
+ log["conditioning"] = torch.zeros_like(log["reconstructions"])
248
+ for i in range(quant_c.shape[0]):
249
+ log["conditioning"][i] = plotter(quant_c[i], label_for_category_no, figure_size)
250
+ log["conditioning_rec"] = log["conditioning"]
251
+ elif self.cond_stage_key != "image":
252
+ cond_rec = self.cond_stage_model.decode(quant_c)
253
+ if self.cond_stage_key == "segmentation":
254
+ # get image from segmentation mask
255
+ num_classes = cond_rec.shape[1]
256
+
257
+ c = torch.argmax(c, dim=1, keepdim=True)
258
+ c = F.one_hot(c, num_classes=num_classes)
259
+ c = c.squeeze(1).permute(0, 3, 1, 2).float()
260
+ c = self.cond_stage_model.to_rgb(c)
261
+
262
+ cond_rec = torch.argmax(cond_rec, dim=1, keepdim=True)
263
+ cond_rec = F.one_hot(cond_rec, num_classes=num_classes)
264
+ cond_rec = cond_rec.squeeze(1).permute(0, 3, 1, 2).float()
265
+ cond_rec = self.cond_stage_model.to_rgb(cond_rec)
266
+ log["conditioning_rec"] = cond_rec
267
+ log["conditioning"] = c
268
+
269
+ log["samples_half"] = x_sample
270
+ log["samples_nopix"] = x_sample_nopix
271
+ log["samples_det"] = x_sample_det
272
+ return log
273
+
274
+ def get_input(self, key, batch):
275
+ x = batch[key]
276
+ if len(x.shape) == 3:
277
+ x = x[..., None]
278
+ if len(x.shape) == 4:
279
+ x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format)
280
+ if x.dtype == torch.double:
281
+ x = x.float()
282
+ return x
283
+
284
+ def get_xc(self, batch, N=None):
285
+ x = self.get_input(self.first_stage_key, batch)
286
+ c = self.get_input(self.cond_stage_key, batch)
287
+ if N is not None:
288
+ x = x[:N]
289
+ c = c[:N]
290
+ return x, c
291
+
292
+ def shared_step(self, batch, batch_idx):
293
+ x, c = self.get_xc(batch)
294
+ logits, target = self(x, c)
295
+ loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), target.reshape(-1))
296
+ return loss
297
+
298
+ def training_step(self, batch, batch_idx):
299
+ loss = self.shared_step(batch, batch_idx)
300
+ self.log("train/loss", loss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
301
+ return loss
302
+
303
+ def validation_step(self, batch, batch_idx):
304
+ loss = self.shared_step(batch, batch_idx)
305
+ self.log("val/loss", loss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
306
+ return loss
307
+
308
+ def configure_optimizers(self):
309
+ """
310
+ Following minGPT:
311
+ This long function is unfortunately doing something very simple and is being very defensive:
312
+ We are separating out all parameters of the model into two buckets: those that will experience
313
+ weight decay for regularization and those that won't (biases, and layernorm/embedding weights).
314
+ We are then returning the PyTorch optimizer object.
315
+ """
316
+ # separate out all parameters to those that will and won't experience regularizing weight decay
317
+ decay = set()
318
+ no_decay = set()
319
+ whitelist_weight_modules = (torch.nn.Linear, )
320
+ blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding)
321
+ for mn, m in self.transformer.named_modules():
322
+ for pn, p in m.named_parameters():
323
+ fpn = '%s.%s' % (mn, pn) if mn else pn # full param name
324
+
325
+ if pn.endswith('bias'):
326
+ # all biases will not be decayed
327
+ no_decay.add(fpn)
328
+ elif pn.endswith('weight') and isinstance(m, whitelist_weight_modules):
329
+ # weights of whitelist modules will be weight decayed
330
+ decay.add(fpn)
331
+ elif pn.endswith('weight') and isinstance(m, blacklist_weight_modules):
332
+ # weights of blacklist modules will NOT be weight decayed
333
+ no_decay.add(fpn)
334
+
335
+ # special case the position embedding parameter in the root GPT module as not decayed
336
+ no_decay.add('pos_emb')
337
+
338
+ # validate that we considered every parameter
339
+ param_dict = {pn: p for pn, p in self.transformer.named_parameters()}
340
+ inter_params = decay & no_decay
341
+ union_params = decay | no_decay
342
+ assert len(inter_params) == 0, "parameters %s made it into both decay/no_decay sets!" % (str(inter_params), )
343
+ assert len(param_dict.keys() - union_params) == 0, "parameters %s were not separated into either decay/no_decay set!" \
344
+ % (str(param_dict.keys() - union_params), )
345
+
346
+ # create the pytorch optimizer object
347
+ optim_groups = [
348
+ {"params": [param_dict[pn] for pn in sorted(list(decay))], "weight_decay": 0.01},
349
+ {"params": [param_dict[pn] for pn in sorted(list(no_decay))], "weight_decay": 0.0},
350
+ ]
351
+ optimizer = torch.optim.AdamW(optim_groups, lr=self.learning_rate, betas=(0.9, 0.95))
352
+ return optimizer
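One mechanism worth highlighting in Net2NetTransformer.forward above is the pkeep corruption: during training, each ground-truth code index is kept with probability pkeep and otherwise replaced by a random codebook entry before the sequence is fed to the transformer. A standalone sketch of just that masking step, with the tensor shape, vocabulary size and pkeep value chosen for illustration:

    import torch

    pkeep, vocab_size = 0.9, 1024                         # illustrative values
    z_indices = torch.randint(0, vocab_size, (2, 256))    # stand-in for encoded VQ indices
    mask = torch.bernoulli(pkeep * torch.ones(z_indices.shape)).to(dtype=torch.int64)
    r_indices = torch.randint_like(z_indices, vocab_size)
    a_indices = mask * z_indices + (1 - mask) * r_indices   # original where mask==1, random elsewhere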
taming-transformers/taming/models/dummy_cond_stage.py ADDED
@@ -0,0 +1,22 @@
1
+ from torch import Tensor
2
+
3
+
4
+ class DummyCondStage:
5
+ def __init__(self, conditional_key):
6
+ self.conditional_key = conditional_key
7
+ self.train = None
8
+
9
+ def eval(self):
10
+ return self
11
+
12
+ @staticmethod
13
+ def encode(c: Tensor):
14
+ return c, None, (None, None, c)
15
+
16
+ @staticmethod
17
+ def decode(c: Tensor):
18
+ return c
19
+
20
+ @staticmethod
21
+ def to_rgb(c: Tensor):
22
+ return c
taming-transformers/taming/models/vqgan.py ADDED
@@ -0,0 +1,404 @@
1
+ import torch
2
+ import torch.nn.functional as F
3
+ import pytorch_lightning as pl
4
+
5
+ from main import instantiate_from_config
6
+
7
+ from taming.modules.diffusionmodules.model import Encoder, Decoder
8
+ from taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer
9
+ from taming.modules.vqvae.quantize import GumbelQuantize
10
+ from taming.modules.vqvae.quantize import EMAVectorQuantizer
11
+
12
+ class VQModel(pl.LightningModule):
13
+ def __init__(self,
14
+ ddconfig,
15
+ lossconfig,
16
+ n_embed,
17
+ embed_dim,
18
+ ckpt_path=None,
19
+ ignore_keys=[],
20
+ image_key="image",
21
+ colorize_nlabels=None,
22
+ monitor=None,
23
+ remap=None,
24
+ sane_index_shape=False, # tell vector quantizer to return indices as bhw
25
+ ):
26
+ super().__init__()
27
+ self.image_key = image_key
28
+ self.encoder = Encoder(**ddconfig)
29
+ self.decoder = Decoder(**ddconfig)
30
+ self.loss = instantiate_from_config(lossconfig)
31
+ self.quantize = VectorQuantizer(n_embed, embed_dim, beta=0.25,
32
+ remap=remap, sane_index_shape=sane_index_shape)
33
+ self.quant_conv = torch.nn.Conv2d(ddconfig["z_channels"], embed_dim, 1)
34
+ self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
35
+ if ckpt_path is not None:
36
+ self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
37
+ self.image_key = image_key
38
+ if colorize_nlabels is not None:
39
+ assert type(colorize_nlabels)==int
40
+ self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
41
+ if monitor is not None:
42
+ self.monitor = monitor
43
+
44
+ def init_from_ckpt(self, path, ignore_keys=list()):
45
+ sd = torch.load(path, map_location="cpu")["state_dict"]
46
+ keys = list(sd.keys())
47
+ for k in keys:
48
+ for ik in ignore_keys:
49
+ if k.startswith(ik):
50
+ print("Deleting key {} from state_dict.".format(k))
51
+ del sd[k]
52
+ self.load_state_dict(sd, strict=False)
53
+ print(f"Restored from {path}")
54
+
55
+ def encode(self, x):
56
+ h = self.encoder(x)
57
+ h = self.quant_conv(h)
58
+ quant, emb_loss, info = self.quantize(h)
59
+ return quant, emb_loss, info
60
+
61
+ def decode(self, quant):
62
+ quant = self.post_quant_conv(quant)
63
+ dec = self.decoder(quant)
64
+ return dec
65
+
66
+ def decode_code(self, code_b):
67
+ quant_b = self.quantize.embed_code(code_b)
68
+ dec = self.decode(quant_b)
69
+ return dec
70
+
71
+ def forward(self, input):
72
+ quant, diff, _ = self.encode(input)
73
+ dec = self.decode(quant)
74
+ return dec, diff
75
+
76
+ def get_input(self, batch, k):
77
+ x = batch[k]
78
+ if len(x.shape) == 3:
79
+ x = x[..., None]
80
+ x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format)
81
+ return x.float()
82
+
83
+ def training_step(self, batch, batch_idx, optimizer_idx):
84
+ x = self.get_input(batch, self.image_key)
85
+ xrec, qloss = self(x)
86
+
87
+ if optimizer_idx == 0:
88
+ # autoencode
89
+ aeloss, log_dict_ae = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
90
+ last_layer=self.get_last_layer(), split="train")
91
+
92
+ self.log("train/aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
93
+ self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True)
94
+ return aeloss
95
+
96
+ if optimizer_idx == 1:
97
+ # discriminator
98
+ discloss, log_dict_disc = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
99
+ last_layer=self.get_last_layer(), split="train")
100
+ self.log("train/discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
101
+ self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True)
102
+ return discloss
103
+
104
+ def validation_step(self, batch, batch_idx):
105
+ x = self.get_input(batch, self.image_key)
106
+ xrec, qloss = self(x)
107
+ aeloss, log_dict_ae = self.loss(qloss, x, xrec, 0, self.global_step,
108
+ last_layer=self.get_last_layer(), split="val")
109
+
110
+ discloss, log_dict_disc = self.loss(qloss, x, xrec, 1, self.global_step,
111
+ last_layer=self.get_last_layer(), split="val")
112
+ rec_loss = log_dict_ae["val/rec_loss"]
113
+ self.log("val/rec_loss", rec_loss,
114
+ prog_bar=True, logger=True, on_step=True, on_epoch=True, sync_dist=True)
115
+ self.log("val/aeloss", aeloss,
116
+ prog_bar=True, logger=True, on_step=True, on_epoch=True, sync_dist=True)
117
+ self.log_dict(log_dict_ae)
118
+ self.log_dict(log_dict_disc)
119
+ return self.log_dict
120
+
121
+ def configure_optimizers(self):
122
+ lr = self.learning_rate
123
+ opt_ae = torch.optim.Adam(list(self.encoder.parameters())+
124
+ list(self.decoder.parameters())+
125
+ list(self.quantize.parameters())+
126
+ list(self.quant_conv.parameters())+
127
+ list(self.post_quant_conv.parameters()),
128
+ lr=lr, betas=(0.5, 0.9))
129
+ opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
130
+ lr=lr, betas=(0.5, 0.9))
131
+ return [opt_ae, opt_disc], []
132
+
133
+ def get_last_layer(self):
134
+ return self.decoder.conv_out.weight
135
+
136
+ def log_images(self, batch, **kwargs):
137
+ log = dict()
138
+ x = self.get_input(batch, self.image_key)
139
+ x = x.to(self.device)
140
+ xrec, _ = self(x)
141
+ if x.shape[1] > 3:
142
+ # colorize with random projection
143
+ assert xrec.shape[1] > 3
144
+ x = self.to_rgb(x)
145
+ xrec = self.to_rgb(xrec)
146
+ log["inputs"] = x
147
+ log["reconstructions"] = xrec
148
+ return log
149
+
150
+ def to_rgb(self, x):
151
+ assert self.image_key == "segmentation"
152
+ if not hasattr(self, "colorize"):
153
+ self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
154
+ x = F.conv2d(x, weight=self.colorize)
155
+ x = 2.*(x-x.min())/(x.max()-x.min()) - 1.
156
+ return x
157
+
158
+
159
+ class VQSegmentationModel(VQModel):
160
+ def __init__(self, n_labels, *args, **kwargs):
161
+ super().__init__(*args, **kwargs)
162
+ self.register_buffer("colorize", torch.randn(3, n_labels, 1, 1))
163
+
164
+ def configure_optimizers(self):
165
+ lr = self.learning_rate
166
+ opt_ae = torch.optim.Adam(list(self.encoder.parameters())+
167
+ list(self.decoder.parameters())+
168
+ list(self.quantize.parameters())+
169
+ list(self.quant_conv.parameters())+
170
+ list(self.post_quant_conv.parameters()),
171
+ lr=lr, betas=(0.5, 0.9))
172
+ return opt_ae
173
+
174
+ def training_step(self, batch, batch_idx):
175
+ x = self.get_input(batch, self.image_key)
176
+ xrec, qloss = self(x)
177
+ aeloss, log_dict_ae = self.loss(qloss, x, xrec, split="train")
178
+ self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True)
179
+ return aeloss
180
+
181
+ def validation_step(self, batch, batch_idx):
182
+ x = self.get_input(batch, self.image_key)
183
+ xrec, qloss = self(x)
184
+ aeloss, log_dict_ae = self.loss(qloss, x, xrec, split="val")
185
+ self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True)
186
+ total_loss = log_dict_ae["val/total_loss"]
187
+ self.log("val/total_loss", total_loss,
188
+ prog_bar=True, logger=True, on_step=True, on_epoch=True, sync_dist=True)
189
+ return aeloss
190
+
191
+ @torch.no_grad()
192
+ def log_images(self, batch, **kwargs):
193
+ log = dict()
194
+ x = self.get_input(batch, self.image_key)
195
+ x = x.to(self.device)
196
+ xrec, _ = self(x)
197
+ if x.shape[1] > 3:
198
+ # colorize with random projection
199
+ assert xrec.shape[1] > 3
200
+ # convert logits to indices
201
+ xrec = torch.argmax(xrec, dim=1, keepdim=True)
202
+ xrec = F.one_hot(xrec, num_classes=x.shape[1])
203
+ xrec = xrec.squeeze(1).permute(0, 3, 1, 2).float()
204
+ x = self.to_rgb(x)
205
+ xrec = self.to_rgb(xrec)
206
+ log["inputs"] = x
207
+ log["reconstructions"] = xrec
208
+ return log
209
+
210
+
211
+ class VQNoDiscModel(VQModel):
212
+ def __init__(self,
213
+ ddconfig,
214
+ lossconfig,
215
+ n_embed,
216
+ embed_dim,
217
+ ckpt_path=None,
218
+ ignore_keys=[],
219
+ image_key="image",
220
+ colorize_nlabels=None
221
+ ):
222
+ super().__init__(ddconfig=ddconfig, lossconfig=lossconfig, n_embed=n_embed, embed_dim=embed_dim,
223
+ ckpt_path=ckpt_path, ignore_keys=ignore_keys, image_key=image_key,
224
+ colorize_nlabels=colorize_nlabels)
225
+
226
+ def training_step(self, batch, batch_idx):
227
+ x = self.get_input(batch, self.image_key)
228
+ xrec, qloss = self(x)
229
+ # autoencode
230
+ aeloss, log_dict_ae = self.loss(qloss, x, xrec, self.global_step, split="train")
231
+ output = pl.TrainResult(minimize=aeloss)
232
+ output.log("train/aeloss", aeloss,
233
+ prog_bar=True, logger=True, on_step=True, on_epoch=True)
234
+ output.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True)
235
+ return output
236
+
237
+ def validation_step(self, batch, batch_idx):
238
+ x = self.get_input(batch, self.image_key)
239
+ xrec, qloss = self(x)
240
+ aeloss, log_dict_ae = self.loss(qloss, x, xrec, self.global_step, split="val")
241
+ rec_loss = log_dict_ae["val/rec_loss"]
242
+ output = pl.EvalResult(checkpoint_on=rec_loss)
243
+ output.log("val/rec_loss", rec_loss,
244
+ prog_bar=True, logger=True, on_step=True, on_epoch=True)
245
+ output.log("val/aeloss", aeloss,
246
+ prog_bar=True, logger=True, on_step=True, on_epoch=True)
247
+ output.log_dict(log_dict_ae)
248
+
249
+ return output
250
+
251
+ def configure_optimizers(self):
252
+ optimizer = torch.optim.Adam(list(self.encoder.parameters())+
253
+ list(self.decoder.parameters())+
254
+ list(self.quantize.parameters())+
255
+ list(self.quant_conv.parameters())+
256
+ list(self.post_quant_conv.parameters()),
257
+ lr=self.learning_rate, betas=(0.5, 0.9))
258
+ return optimizer
259
+
260
+
261
+ class GumbelVQ(VQModel):
262
+ def __init__(self,
263
+ ddconfig,
264
+ lossconfig,
265
+ n_embed,
266
+ embed_dim,
267
+ temperature_scheduler_config,
268
+ ckpt_path=None,
269
+ ignore_keys=[],
270
+ image_key="image",
271
+ colorize_nlabels=None,
272
+ monitor=None,
273
+ kl_weight=1e-8,
274
+ remap=None,
275
+ ):
276
+
277
+ z_channels = ddconfig["z_channels"]
278
+ super().__init__(ddconfig,
279
+ lossconfig,
280
+ n_embed,
281
+ embed_dim,
282
+ ckpt_path=None,
283
+ ignore_keys=ignore_keys,
284
+ image_key=image_key,
285
+ colorize_nlabels=colorize_nlabels,
286
+ monitor=monitor,
287
+ )
288
+
289
+ self.loss.n_classes = n_embed
290
+ self.vocab_size = n_embed
291
+
292
+ self.quantize = GumbelQuantize(z_channels, embed_dim,
293
+ n_embed=n_embed,
294
+ kl_weight=kl_weight, temp_init=1.0,
295
+ remap=remap)
296
+
297
+ self.temperature_scheduler = instantiate_from_config(temperature_scheduler_config) # annealing of temp
298
+
299
+ if ckpt_path is not None:
300
+ self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
301
+
302
+ def temperature_scheduling(self):
303
+ self.quantize.temperature = self.temperature_scheduler(self.global_step)
304
+
305
+ def encode_to_prequant(self, x):
306
+ h = self.encoder(x)
307
+ h = self.quant_conv(h)
308
+ return h
309
+
310
+ def decode_code(self, code_b):
311
+ raise NotImplementedError
312
+
313
+ def training_step(self, batch, batch_idx, optimizer_idx):
314
+ self.temperature_scheduling()
315
+ x = self.get_input(batch, self.image_key)
316
+ xrec, qloss = self(x)
317
+
318
+ if optimizer_idx == 0:
319
+ # autoencode
320
+ aeloss, log_dict_ae = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
321
+ last_layer=self.get_last_layer(), split="train")
322
+
323
+ self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True)
324
+ self.log("temperature", self.quantize.temperature, prog_bar=False, logger=True, on_step=True, on_epoch=True)
325
+ return aeloss
326
+
327
+ if optimizer_idx == 1:
328
+ # discriminator
329
+ discloss, log_dict_disc = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
330
+ last_layer=self.get_last_layer(), split="train")
331
+ self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True)
332
+ return discloss
333
+
334
+ def validation_step(self, batch, batch_idx):
335
+ x = self.get_input(batch, self.image_key)
336
+ xrec, qloss = self(x, return_pred_indices=True)
337
+ aeloss, log_dict_ae = self.loss(qloss, x, xrec, 0, self.global_step,
338
+ last_layer=self.get_last_layer(), split="val")
339
+
340
+ discloss, log_dict_disc = self.loss(qloss, x, xrec, 1, self.global_step,
341
+ last_layer=self.get_last_layer(), split="val")
342
+ rec_loss = log_dict_ae["val/rec_loss"]
343
+ self.log("val/rec_loss", rec_loss,
344
+ prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True)
345
+ self.log("val/aeloss", aeloss,
346
+ prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True)
347
+ self.log_dict(log_dict_ae)
348
+ self.log_dict(log_dict_disc)
349
+ return self.log_dict
350
+
351
+ def log_images(self, batch, **kwargs):
352
+ log = dict()
353
+ x = self.get_input(batch, self.image_key)
354
+ x = x.to(self.device)
355
+ # encode
356
+ h = self.encoder(x)
357
+ h = self.quant_conv(h)
358
+ quant, _, _ = self.quantize(h)
359
+ # decode
360
+ x_rec = self.decode(quant)
361
+ log["inputs"] = x
362
+ log["reconstructions"] = x_rec
363
+ return log
364
+
365
+
366
+ class EMAVQ(VQModel):
367
+ def __init__(self,
368
+ ddconfig,
369
+ lossconfig,
370
+ n_embed,
371
+ embed_dim,
372
+ ckpt_path=None,
373
+ ignore_keys=[],
374
+ image_key="image",
375
+ colorize_nlabels=None,
376
+ monitor=None,
377
+ remap=None,
378
+ sane_index_shape=False, # tell vector quantizer to return indices as bhw
379
+ ):
380
+ super().__init__(ddconfig,
381
+ lossconfig,
382
+ n_embed,
383
+ embed_dim,
384
+ ckpt_path=None,
385
+ ignore_keys=ignore_keys,
386
+ image_key=image_key,
387
+ colorize_nlabels=colorize_nlabels,
388
+ monitor=monitor,
389
+ )
390
+ self.quantize = EMAVectorQuantizer(n_embed=n_embed,
391
+ embedding_dim=embed_dim,
392
+ beta=0.25,
393
+ remap=remap)
394
+ def configure_optimizers(self):
395
+ lr = self.learning_rate
396
+ #Remove self.quantize from parameter list since it is updated via EMA
397
+ opt_ae = torch.optim.Adam(list(self.encoder.parameters())+
398
+ list(self.decoder.parameters())+
399
+ list(self.quant_conv.parameters())+
400
+ list(self.post_quant_conv.parameters()),
401
+ lr=lr, betas=(0.5, 0.9))
402
+ opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
403
+ lr=lr, betas=(0.5, 0.9))
404
+ return [opt_ae, opt_disc], []
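The segmentation variants above visualize multi-channel label maps by projecting them to three channels with a fixed random 1x1 convolution (VQModel.to_rgb and the registered colorize buffer). A self-contained sketch of that trick; the label count and the synthetic input are assumptions made for illustration:

    import torch
    import torch.nn.functional as F

    n_labels = 182                                               # illustrative number of classes
    seg = F.one_hot(torch.randint(0, n_labels, (1, 64, 64)), n_labels)
    seg = seg.permute(0, 3, 1, 2).float()                        # (B, n_labels, H, W) one-hot map
    colorize = torch.randn(3, n_labels, 1, 1)                    # fixed random projection to RGB
    rgb = F.conv2d(seg, weight=colorize)
    rgb = 2. * (rgb - rgb.min()) / (rgb.max() - rgb.min()) - 1.  # rescale to [-1, 1], as the model does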
taming-transformers/taming/modules/diffusionmodules/model.py ADDED
@@ -0,0 +1,776 @@
1
+ # pytorch_diffusion + derived encoder decoder
2
+ import math
3
+ import torch
4
+ import torch.nn as nn
5
+ import numpy as np
6
+
7
+
8
+ def get_timestep_embedding(timesteps, embedding_dim):
9
+ """
10
+ This matches the implementation in Denoising Diffusion Probabilistic Models:
11
+ From Fairseq.
12
+ Build sinusoidal embeddings.
13
+ This matches the implementation in tensor2tensor, but differs slightly
14
+ from the description in Section 3.5 of "Attention Is All You Need".
15
+ """
16
+ assert len(timesteps.shape) == 1
17
+
18
+ half_dim = embedding_dim // 2
19
+ emb = math.log(10000) / (half_dim - 1)
20
+ emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
21
+ emb = emb.to(device=timesteps.device)
22
+ emb = timesteps.float()[:, None] * emb[None, :]
23
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
24
+ if embedding_dim % 2 == 1: # zero pad
25
+ emb = torch.nn.functional.pad(emb, (0,1,0,0))
26
+ return emb
27
+
28
+
29
+ def nonlinearity(x):
30
+ # swish
31
+ return x*torch.sigmoid(x)
32
+
33
+
34
+ def Normalize(in_channels):
35
+ return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
36
+
37
+
38
+ class Upsample(nn.Module):
39
+ def __init__(self, in_channels, with_conv):
40
+ super().__init__()
41
+ self.with_conv = with_conv
42
+ if self.with_conv:
43
+ self.conv = torch.nn.Conv2d(in_channels,
44
+ in_channels,
45
+ kernel_size=3,
46
+ stride=1,
47
+ padding=1)
48
+
49
+ def forward(self, x):
50
+ x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
51
+ if self.with_conv:
52
+ x = self.conv(x)
53
+ return x
54
+
55
+
56
+ class Downsample(nn.Module):
57
+ def __init__(self, in_channels, with_conv):
58
+ super().__init__()
59
+ self.with_conv = with_conv
60
+ if self.with_conv:
61
+ # no asymmetric padding in torch conv, must do it ourselves
62
+ self.conv = torch.nn.Conv2d(in_channels,
63
+ in_channels,
64
+ kernel_size=3,
65
+ stride=2,
66
+ padding=0)
67
+
68
+ def forward(self, x):
69
+ if self.with_conv:
70
+ pad = (0,1,0,1)
71
+ x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
72
+ x = self.conv(x)
73
+ else:
74
+ x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
75
+ return x
76
+
77
+
78
+ class ResnetBlock(nn.Module):
79
+ def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False,
80
+ dropout, temb_channels=512):
81
+ super().__init__()
82
+ self.in_channels = in_channels
83
+ out_channels = in_channels if out_channels is None else out_channels
84
+ self.out_channels = out_channels
85
+ self.use_conv_shortcut = conv_shortcut
86
+
87
+ self.norm1 = Normalize(in_channels)
88
+ self.conv1 = torch.nn.Conv2d(in_channels,
89
+ out_channels,
90
+ kernel_size=3,
91
+ stride=1,
92
+ padding=1)
93
+ if temb_channels > 0:
94
+ self.temb_proj = torch.nn.Linear(temb_channels,
95
+ out_channels)
96
+ self.norm2 = Normalize(out_channels)
97
+ self.dropout = torch.nn.Dropout(dropout)
98
+ self.conv2 = torch.nn.Conv2d(out_channels,
99
+ out_channels,
100
+ kernel_size=3,
101
+ stride=1,
102
+ padding=1)
103
+ if self.in_channels != self.out_channels:
104
+ if self.use_conv_shortcut:
105
+ self.conv_shortcut = torch.nn.Conv2d(in_channels,
106
+ out_channels,
107
+ kernel_size=3,
108
+ stride=1,
109
+ padding=1)
110
+ else:
111
+ self.nin_shortcut = torch.nn.Conv2d(in_channels,
112
+ out_channels,
113
+ kernel_size=1,
114
+ stride=1,
115
+ padding=0)
116
+
117
+ def forward(self, x, temb):
118
+ h = x
119
+ h = self.norm1(h)
120
+ h = nonlinearity(h)
121
+ h = self.conv1(h)
122
+
123
+ if temb is not None:
124
+ h = h + self.temb_proj(nonlinearity(temb))[:,:,None,None]
125
+
126
+ h = self.norm2(h)
127
+ h = nonlinearity(h)
128
+ h = self.dropout(h)
129
+ h = self.conv2(h)
130
+
131
+ if self.in_channels != self.out_channels:
132
+ if self.use_conv_shortcut:
133
+ x = self.conv_shortcut(x)
134
+ else:
135
+ x = self.nin_shortcut(x)
136
+
137
+ return x+h
138
+
139
+
140
+ class AttnBlock(nn.Module):
141
+ def __init__(self, in_channels):
142
+ super().__init__()
143
+ self.in_channels = in_channels
144
+
145
+ self.norm = Normalize(in_channels)
146
+ self.q = torch.nn.Conv2d(in_channels,
147
+ in_channels,
148
+ kernel_size=1,
149
+ stride=1,
150
+ padding=0)
151
+ self.k = torch.nn.Conv2d(in_channels,
152
+ in_channels,
153
+ kernel_size=1,
154
+ stride=1,
155
+ padding=0)
156
+ self.v = torch.nn.Conv2d(in_channels,
157
+ in_channels,
158
+ kernel_size=1,
159
+ stride=1,
160
+ padding=0)
161
+ self.proj_out = torch.nn.Conv2d(in_channels,
162
+ in_channels,
163
+ kernel_size=1,
164
+ stride=1,
165
+ padding=0)
166
+
167
+
168
+ def forward(self, x):
169
+ h_ = x
170
+ h_ = self.norm(h_)
171
+ q = self.q(h_)
172
+ k = self.k(h_)
173
+ v = self.v(h_)
174
+
175
+ # compute attention
176
+ b,c,h,w = q.shape
177
+ q = q.reshape(b,c,h*w)
178
+ q = q.permute(0,2,1) # b,hw,c
179
+ k = k.reshape(b,c,h*w) # b,c,hw
180
+ w_ = torch.bmm(q,k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
181
+ w_ = w_ * (int(c)**(-0.5))
182
+ w_ = torch.nn.functional.softmax(w_, dim=2)
183
+
184
+ # attend to values
185
+ v = v.reshape(b,c,h*w)
186
+ w_ = w_.permute(0,2,1) # b,hw,hw (first hw of k, second of q)
187
+ h_ = torch.bmm(v,w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
188
+ h_ = h_.reshape(b,c,h,w)
189
+
190
+ h_ = self.proj_out(h_)
191
+
192
+ return x+h_
193
+
194
+
195
+ class Model(nn.Module):
196
+ def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
197
+ attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
198
+ resolution, use_timestep=True):
199
+ super().__init__()
200
+ self.ch = ch
201
+ self.temb_ch = self.ch*4
202
+ self.num_resolutions = len(ch_mult)
203
+ self.num_res_blocks = num_res_blocks
204
+ self.resolution = resolution
205
+ self.in_channels = in_channels
206
+
207
+ self.use_timestep = use_timestep
208
+ if self.use_timestep:
209
+ # timestep embedding
210
+ self.temb = nn.Module()
211
+ self.temb.dense = nn.ModuleList([
212
+ torch.nn.Linear(self.ch,
213
+ self.temb_ch),
214
+ torch.nn.Linear(self.temb_ch,
215
+ self.temb_ch),
216
+ ])
217
+
218
+ # downsampling
219
+ self.conv_in = torch.nn.Conv2d(in_channels,
220
+ self.ch,
221
+ kernel_size=3,
222
+ stride=1,
223
+ padding=1)
224
+
225
+ curr_res = resolution
226
+ in_ch_mult = (1,)+tuple(ch_mult)
227
+ self.down = nn.ModuleList()
228
+ for i_level in range(self.num_resolutions):
229
+ block = nn.ModuleList()
230
+ attn = nn.ModuleList()
231
+ block_in = ch*in_ch_mult[i_level]
232
+ block_out = ch*ch_mult[i_level]
233
+ for i_block in range(self.num_res_blocks):
234
+ block.append(ResnetBlock(in_channels=block_in,
235
+ out_channels=block_out,
236
+ temb_channels=self.temb_ch,
237
+ dropout=dropout))
238
+ block_in = block_out
239
+ if curr_res in attn_resolutions:
240
+ attn.append(AttnBlock(block_in))
241
+ down = nn.Module()
242
+ down.block = block
243
+ down.attn = attn
244
+ if i_level != self.num_resolutions-1:
245
+ down.downsample = Downsample(block_in, resamp_with_conv)
246
+ curr_res = curr_res // 2
247
+ self.down.append(down)
248
+
249
+ # middle
250
+ self.mid = nn.Module()
251
+ self.mid.block_1 = ResnetBlock(in_channels=block_in,
252
+ out_channels=block_in,
253
+ temb_channels=self.temb_ch,
254
+ dropout=dropout)
255
+ self.mid.attn_1 = AttnBlock(block_in)
256
+ self.mid.block_2 = ResnetBlock(in_channels=block_in,
257
+ out_channels=block_in,
258
+ temb_channels=self.temb_ch,
259
+ dropout=dropout)
260
+
261
+ # upsampling
262
+ self.up = nn.ModuleList()
263
+ for i_level in reversed(range(self.num_resolutions)):
264
+ block = nn.ModuleList()
265
+ attn = nn.ModuleList()
266
+ block_out = ch*ch_mult[i_level]
267
+ skip_in = ch*ch_mult[i_level]
268
+ for i_block in range(self.num_res_blocks+1):
269
+ if i_block == self.num_res_blocks:
270
+ skip_in = ch*in_ch_mult[i_level]
271
+ block.append(ResnetBlock(in_channels=block_in+skip_in,
272
+ out_channels=block_out,
273
+ temb_channels=self.temb_ch,
274
+ dropout=dropout))
275
+ block_in = block_out
276
+ if curr_res in attn_resolutions:
277
+ attn.append(AttnBlock(block_in))
278
+ up = nn.Module()
279
+ up.block = block
280
+ up.attn = attn
281
+ if i_level != 0:
282
+ up.upsample = Upsample(block_in, resamp_with_conv)
283
+ curr_res = curr_res * 2
284
+ self.up.insert(0, up) # prepend to get consistent order
285
+
286
+ # end
287
+ self.norm_out = Normalize(block_in)
288
+ self.conv_out = torch.nn.Conv2d(block_in,
289
+ out_ch,
290
+ kernel_size=3,
291
+ stride=1,
292
+ padding=1)
293
+
294
+
295
+ def forward(self, x, t=None):
296
+ #assert x.shape[2] == x.shape[3] == self.resolution
297
+
298
+ if self.use_timestep:
299
+ # timestep embedding
300
+ assert t is not None
301
+ temb = get_timestep_embedding(t, self.ch)
302
+ temb = self.temb.dense[0](temb)
303
+ temb = nonlinearity(temb)
304
+ temb = self.temb.dense[1](temb)
305
+ else:
306
+ temb = None
307
+
308
+ # downsampling
309
+ hs = [self.conv_in(x)]
310
+ for i_level in range(self.num_resolutions):
311
+ for i_block in range(self.num_res_blocks):
312
+ h = self.down[i_level].block[i_block](hs[-1], temb)
313
+ if len(self.down[i_level].attn) > 0:
314
+ h = self.down[i_level].attn[i_block](h)
315
+ hs.append(h)
316
+ if i_level != self.num_resolutions-1:
317
+ hs.append(self.down[i_level].downsample(hs[-1]))
318
+
319
+ # middle
320
+ h = hs[-1]
321
+ h = self.mid.block_1(h, temb)
322
+ h = self.mid.attn_1(h)
323
+ h = self.mid.block_2(h, temb)
324
+
325
+ # upsampling
326
+ for i_level in reversed(range(self.num_resolutions)):
327
+ for i_block in range(self.num_res_blocks+1):
328
+ h = self.up[i_level].block[i_block](
329
+ torch.cat([h, hs.pop()], dim=1), temb)
330
+ if len(self.up[i_level].attn) > 0:
331
+ h = self.up[i_level].attn[i_block](h)
332
+ if i_level != 0:
333
+ h = self.up[i_level].upsample(h)
334
+
335
+ # end
336
+ h = self.norm_out(h)
337
+ h = nonlinearity(h)
338
+ h = self.conv_out(h)
339
+ return h
340
+
341
+
342
+ class Encoder(nn.Module):
343
+ def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
344
+ attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
345
+ resolution, z_channels, double_z=True, **ignore_kwargs):
346
+ super().__init__()
347
+ self.ch = ch
348
+ self.temb_ch = 0
349
+ self.num_resolutions = len(ch_mult)
350
+ self.num_res_blocks = num_res_blocks
351
+ self.resolution = resolution
352
+ self.in_channels = in_channels
353
+
354
+ # downsampling
355
+ self.conv_in = torch.nn.Conv2d(in_channels,
356
+ self.ch,
357
+ kernel_size=3,
358
+ stride=1,
359
+ padding=1)
360
+
361
+ curr_res = resolution
362
+ in_ch_mult = (1,)+tuple(ch_mult)
363
+ self.down = nn.ModuleList()
364
+ for i_level in range(self.num_resolutions):
365
+ block = nn.ModuleList()
366
+ attn = nn.ModuleList()
367
+ block_in = ch*in_ch_mult[i_level]
368
+ block_out = ch*ch_mult[i_level]
369
+ for i_block in range(self.num_res_blocks):
370
+ block.append(ResnetBlock(in_channels=block_in,
371
+ out_channels=block_out,
372
+ temb_channels=self.temb_ch,
373
+ dropout=dropout))
374
+ block_in = block_out
375
+ if curr_res in attn_resolutions:
376
+ attn.append(AttnBlock(block_in))
377
+ down = nn.Module()
378
+ down.block = block
379
+ down.attn = attn
380
+ if i_level != self.num_resolutions-1:
381
+ down.downsample = Downsample(block_in, resamp_with_conv)
382
+ curr_res = curr_res // 2
383
+ self.down.append(down)
384
+
385
+ # middle
386
+ self.mid = nn.Module()
387
+ self.mid.block_1 = ResnetBlock(in_channels=block_in,
388
+ out_channels=block_in,
389
+ temb_channels=self.temb_ch,
390
+ dropout=dropout)
391
+ self.mid.attn_1 = AttnBlock(block_in)
392
+ self.mid.block_2 = ResnetBlock(in_channels=block_in,
393
+ out_channels=block_in,
394
+ temb_channels=self.temb_ch,
395
+ dropout=dropout)
396
+
397
+ # end
398
+ self.norm_out = Normalize(block_in)
399
+ self.conv_out = torch.nn.Conv2d(block_in,
400
+ 2*z_channels if double_z else z_channels,
401
+ kernel_size=3,
402
+ stride=1,
403
+ padding=1)
404
+
405
+
406
+ def forward(self, x):
407
+ #assert x.shape[2] == x.shape[3] == self.resolution, "{}, {}, {}".format(x.shape[2], x.shape[3], self.resolution)
408
+
409
+ # timestep embedding
410
+ temb = None
411
+
412
+ # downsampling
413
+ hs = [self.conv_in(x)]
414
+ for i_level in range(self.num_resolutions):
415
+ for i_block in range(self.num_res_blocks):
416
+ h = self.down[i_level].block[i_block](hs[-1], temb)
417
+ if len(self.down[i_level].attn) > 0:
418
+ h = self.down[i_level].attn[i_block](h)
419
+ hs.append(h)
420
+ if i_level != self.num_resolutions-1:
421
+ hs.append(self.down[i_level].downsample(hs[-1]))
422
+
423
+ # middle
424
+ h = hs[-1]
425
+ h = self.mid.block_1(h, temb)
426
+ h = self.mid.attn_1(h)
427
+ h = self.mid.block_2(h, temb)
428
+
429
+ # end
430
+ h = self.norm_out(h)
431
+ h = nonlinearity(h)
432
+ h = self.conv_out(h)
433
+ return h
434
+
435
+
436
+ class Decoder(nn.Module):
437
+ def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
438
+ attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
439
+ resolution, z_channels, give_pre_end=False, **ignorekwargs):
440
+ super().__init__()
441
+ self.ch = ch
442
+ self.temb_ch = 0
443
+ self.num_resolutions = len(ch_mult)
444
+ self.num_res_blocks = num_res_blocks
445
+ self.resolution = resolution
446
+ self.in_channels = in_channels
447
+ self.give_pre_end = give_pre_end
448
+
449
+ # compute in_ch_mult, block_in and curr_res at lowest res
450
+ in_ch_mult = (1,)+tuple(ch_mult)
451
+ block_in = ch*ch_mult[self.num_resolutions-1]
452
+ curr_res = resolution // 2**(self.num_resolutions-1)
453
+ self.z_shape = (1,z_channels,curr_res,curr_res)
454
+ print("Working with z of shape {} = {} dimensions.".format(
455
+ self.z_shape, np.prod(self.z_shape)))
456
+
457
+ # z to block_in
458
+ self.conv_in = torch.nn.Conv2d(z_channels,
459
+ block_in,
460
+ kernel_size=3,
461
+ stride=1,
462
+ padding=1)
463
+
464
+ # middle
465
+ self.mid = nn.Module()
466
+ self.mid.block_1 = ResnetBlock(in_channels=block_in,
467
+ out_channels=block_in,
468
+ temb_channels=self.temb_ch,
469
+ dropout=dropout)
470
+ self.mid.attn_1 = AttnBlock(block_in)
471
+ self.mid.block_2 = ResnetBlock(in_channels=block_in,
472
+ out_channels=block_in,
473
+ temb_channels=self.temb_ch,
474
+ dropout=dropout)
475
+
476
+ # upsampling
477
+ self.up = nn.ModuleList()
478
+ for i_level in reversed(range(self.num_resolutions)):
479
+ block = nn.ModuleList()
480
+ attn = nn.ModuleList()
481
+ block_out = ch*ch_mult[i_level]
482
+ for i_block in range(self.num_res_blocks+1):
483
+ block.append(ResnetBlock(in_channels=block_in,
484
+ out_channels=block_out,
485
+ temb_channels=self.temb_ch,
486
+ dropout=dropout))
487
+ block_in = block_out
488
+ if curr_res in attn_resolutions:
489
+ attn.append(AttnBlock(block_in))
490
+ up = nn.Module()
491
+ up.block = block
492
+ up.attn = attn
493
+ if i_level != 0:
494
+ up.upsample = Upsample(block_in, resamp_with_conv)
495
+ curr_res = curr_res * 2
496
+ self.up.insert(0, up) # prepend to get consistent order
497
+
498
+ # end
499
+ self.norm_out = Normalize(block_in)
500
+ self.conv_out = torch.nn.Conv2d(block_in,
501
+ out_ch,
502
+ kernel_size=3,
503
+ stride=1,
504
+ padding=1)
505
+
506
+ def forward(self, z):
507
+ #assert z.shape[1:] == self.z_shape[1:]
508
+ self.last_z_shape = z.shape
509
+
510
+ # timestep embedding
511
+ temb = None
512
+
513
+ # z to block_in
514
+ h = self.conv_in(z)
515
+
516
+ # middle
517
+ h = self.mid.block_1(h, temb)
518
+ h = self.mid.attn_1(h)
519
+ h = self.mid.block_2(h, temb)
520
+
521
+ # upsampling
522
+ for i_level in reversed(range(self.num_resolutions)):
523
+ for i_block in range(self.num_res_blocks+1):
524
+ h = self.up[i_level].block[i_block](h, temb)
525
+ if len(self.up[i_level].attn) > 0:
526
+ h = self.up[i_level].attn[i_block](h)
527
+ if i_level != 0:
528
+ h = self.up[i_level].upsample(h)
529
+
530
+ # end
531
+ if self.give_pre_end:
532
+ return h
533
+
534
+ h = self.norm_out(h)
535
+ h = nonlinearity(h)
536
+ h = self.conv_out(h)
537
+ return h
538
+
539
+
540
+ class VUNet(nn.Module):
541
+ def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
542
+ attn_resolutions, dropout=0.0, resamp_with_conv=True,
543
+ in_channels, c_channels,
544
+ resolution, z_channels, use_timestep=False, **ignore_kwargs):
545
+ super().__init__()
546
+ self.ch = ch
547
+ self.temb_ch = self.ch*4
548
+ self.num_resolutions = len(ch_mult)
549
+ self.num_res_blocks = num_res_blocks
550
+ self.resolution = resolution
551
+
552
+ self.use_timestep = use_timestep
553
+ if self.use_timestep:
554
+ # timestep embedding
555
+ self.temb = nn.Module()
556
+ self.temb.dense = nn.ModuleList([
557
+ torch.nn.Linear(self.ch,
558
+ self.temb_ch),
559
+ torch.nn.Linear(self.temb_ch,
560
+ self.temb_ch),
561
+ ])
562
+
563
+ # downsampling
564
+ self.conv_in = torch.nn.Conv2d(c_channels,
565
+ self.ch,
566
+ kernel_size=3,
567
+ stride=1,
568
+ padding=1)
569
+
570
+ curr_res = resolution
571
+ in_ch_mult = (1,)+tuple(ch_mult)
572
+ self.down = nn.ModuleList()
573
+ for i_level in range(self.num_resolutions):
574
+ block = nn.ModuleList()
575
+ attn = nn.ModuleList()
576
+ block_in = ch*in_ch_mult[i_level]
577
+ block_out = ch*ch_mult[i_level]
578
+ for i_block in range(self.num_res_blocks):
579
+ block.append(ResnetBlock(in_channels=block_in,
580
+ out_channels=block_out,
581
+ temb_channels=self.temb_ch,
582
+ dropout=dropout))
583
+ block_in = block_out
584
+ if curr_res in attn_resolutions:
585
+ attn.append(AttnBlock(block_in))
586
+ down = nn.Module()
587
+ down.block = block
588
+ down.attn = attn
589
+ if i_level != self.num_resolutions-1:
590
+ down.downsample = Downsample(block_in, resamp_with_conv)
591
+ curr_res = curr_res // 2
592
+ self.down.append(down)
593
+
594
+ self.z_in = torch.nn.Conv2d(z_channels,
595
+ block_in,
596
+ kernel_size=1,
597
+ stride=1,
598
+ padding=0)
599
+ # middle
600
+ self.mid = nn.Module()
601
+ self.mid.block_1 = ResnetBlock(in_channels=2*block_in,
602
+ out_channels=block_in,
603
+ temb_channels=self.temb_ch,
604
+ dropout=dropout)
605
+ self.mid.attn_1 = AttnBlock(block_in)
606
+ self.mid.block_2 = ResnetBlock(in_channels=block_in,
607
+ out_channels=block_in,
608
+ temb_channels=self.temb_ch,
609
+ dropout=dropout)
610
+
611
+ # upsampling
612
+ self.up = nn.ModuleList()
613
+ for i_level in reversed(range(self.num_resolutions)):
614
+ block = nn.ModuleList()
615
+ attn = nn.ModuleList()
616
+ block_out = ch*ch_mult[i_level]
617
+ skip_in = ch*ch_mult[i_level]
618
+ for i_block in range(self.num_res_blocks+1):
619
+ if i_block == self.num_res_blocks:
620
+ skip_in = ch*in_ch_mult[i_level]
621
+ block.append(ResnetBlock(in_channels=block_in+skip_in,
622
+ out_channels=block_out,
623
+ temb_channels=self.temb_ch,
624
+ dropout=dropout))
625
+ block_in = block_out
626
+ if curr_res in attn_resolutions:
627
+ attn.append(AttnBlock(block_in))
628
+ up = nn.Module()
629
+ up.block = block
630
+ up.attn = attn
631
+ if i_level != 0:
632
+ up.upsample = Upsample(block_in, resamp_with_conv)
633
+ curr_res = curr_res * 2
634
+ self.up.insert(0, up) # prepend to get consistent order
635
+
636
+ # end
637
+ self.norm_out = Normalize(block_in)
638
+ self.conv_out = torch.nn.Conv2d(block_in,
639
+ out_ch,
640
+ kernel_size=3,
641
+ stride=1,
642
+ padding=1)
643
+
644
+
645
+ def forward(self, x, z, t=None):  # t is only used when use_timestep=True
646
+ #assert x.shape[2] == x.shape[3] == self.resolution
647
+
648
+ if self.use_timestep:
649
+ # timestep embedding
650
+ assert t is not None
651
+ temb = get_timestep_embedding(t, self.ch)
652
+ temb = self.temb.dense[0](temb)
653
+ temb = nonlinearity(temb)
654
+ temb = self.temb.dense[1](temb)
655
+ else:
656
+ temb = None
657
+
658
+ # downsampling
659
+ hs = [self.conv_in(x)]
660
+ for i_level in range(self.num_resolutions):
661
+ for i_block in range(self.num_res_blocks):
662
+ h = self.down[i_level].block[i_block](hs[-1], temb)
663
+ if len(self.down[i_level].attn) > 0:
664
+ h = self.down[i_level].attn[i_block](h)
665
+ hs.append(h)
666
+ if i_level != self.num_resolutions-1:
667
+ hs.append(self.down[i_level].downsample(hs[-1]))
668
+
669
+ # middle
670
+ h = hs[-1]
671
+ z = self.z_in(z)
672
+ h = torch.cat((h,z),dim=1)
673
+ h = self.mid.block_1(h, temb)
674
+ h = self.mid.attn_1(h)
675
+ h = self.mid.block_2(h, temb)
676
+
677
+ # upsampling
678
+ for i_level in reversed(range(self.num_resolutions)):
679
+ for i_block in range(self.num_res_blocks+1):
680
+ h = self.up[i_level].block[i_block](
681
+ torch.cat([h, hs.pop()], dim=1), temb)
682
+ if len(self.up[i_level].attn) > 0:
683
+ h = self.up[i_level].attn[i_block](h)
684
+ if i_level != 0:
685
+ h = self.up[i_level].upsample(h)
686
+
687
+ # end
688
+ h = self.norm_out(h)
689
+ h = nonlinearity(h)
690
+ h = self.conv_out(h)
691
+ return h
692
+
693
+
694
+ class SimpleDecoder(nn.Module):
695
+ def __init__(self, in_channels, out_channels, *args, **kwargs):
696
+ super().__init__()
697
+ self.model = nn.ModuleList([nn.Conv2d(in_channels, in_channels, 1),
698
+ ResnetBlock(in_channels=in_channels,
699
+ out_channels=2 * in_channels,
700
+ temb_channels=0, dropout=0.0),
701
+ ResnetBlock(in_channels=2 * in_channels,
702
+ out_channels=4 * in_channels,
703
+ temb_channels=0, dropout=0.0),
704
+ ResnetBlock(in_channels=4 * in_channels,
705
+ out_channels=2 * in_channels,
706
+ temb_channels=0, dropout=0.0),
707
+ nn.Conv2d(2*in_channels, in_channels, 1),
708
+ Upsample(in_channels, with_conv=True)])
709
+ # end
710
+ self.norm_out = Normalize(in_channels)
711
+ self.conv_out = torch.nn.Conv2d(in_channels,
712
+ out_channels,
713
+ kernel_size=3,
714
+ stride=1,
715
+ padding=1)
716
+
717
+ def forward(self, x):
718
+ for i, layer in enumerate(self.model):
719
+ if i in [1,2,3]:
720
+ x = layer(x, None)
721
+ else:
722
+ x = layer(x)
723
+
724
+ h = self.norm_out(x)
725
+ h = nonlinearity(h)
726
+ x = self.conv_out(h)
727
+ return x
728
+
729
+
730
+ class UpsampleDecoder(nn.Module):
731
+ def __init__(self, in_channels, out_channels, ch, num_res_blocks, resolution,
732
+ ch_mult=(2,2), dropout=0.0):
733
+ super().__init__()
734
+ # upsampling
735
+ self.temb_ch = 0
736
+ self.num_resolutions = len(ch_mult)
737
+ self.num_res_blocks = num_res_blocks
738
+ block_in = in_channels
739
+ curr_res = resolution // 2 ** (self.num_resolutions - 1)
740
+ self.res_blocks = nn.ModuleList()
741
+ self.upsample_blocks = nn.ModuleList()
742
+ for i_level in range(self.num_resolutions):
743
+ res_block = []
744
+ block_out = ch * ch_mult[i_level]
745
+ for i_block in range(self.num_res_blocks + 1):
746
+ res_block.append(ResnetBlock(in_channels=block_in,
747
+ out_channels=block_out,
748
+ temb_channels=self.temb_ch,
749
+ dropout=dropout))
750
+ block_in = block_out
751
+ self.res_blocks.append(nn.ModuleList(res_block))
752
+ if i_level != self.num_resolutions - 1:
753
+ self.upsample_blocks.append(Upsample(block_in, True))
754
+ curr_res = curr_res * 2
755
+
756
+ # end
757
+ self.norm_out = Normalize(block_in)
758
+ self.conv_out = torch.nn.Conv2d(block_in,
759
+ out_channels,
760
+ kernel_size=3,
761
+ stride=1,
762
+ padding=1)
763
+
764
+ def forward(self, x):
765
+ # upsampling
766
+ h = x
767
+ for k, i_level in enumerate(range(self.num_resolutions)):
768
+ for i_block in range(self.num_res_blocks + 1):
769
+ h = self.res_blocks[i_level][i_block](h, None)
770
+ if i_level != self.num_resolutions - 1:
771
+ h = self.upsample_blocks[k](h)
772
+ h = self.norm_out(h)
773
+ h = nonlinearity(h)
774
+ h = self.conv_out(h)
775
+ return h
776
+
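Not part of the commit: a minimal usage sketch of SimpleDecoder above, assuming the package is importable (e.g. via pip install -e taming-transformers) and that the channel count is a multiple of 32 so the GroupNorm inside Normalize accepts it.

import torch
from taming.modules.diffusionmodules.model import SimpleDecoder

dec = SimpleDecoder(in_channels=64, out_channels=3)
z = torch.randn(1, 64, 16, 16)        # latent feature map
out = dec(z)                          # the single Upsample stage doubles the resolution
print(out.shape)                      # torch.Size([1, 3, 32, 32])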
taming-transformers/taming/modules/discriminator/model.py ADDED
@@ -0,0 +1,67 @@
1
+ import functools
2
+ import torch.nn as nn
3
+
4
+
5
+ from taming.modules.util import ActNorm
6
+
7
+
8
+ def weights_init(m):
9
+ classname = m.__class__.__name__
10
+ if classname.find('Conv') != -1:
11
+ nn.init.normal_(m.weight.data, 0.0, 0.02)
12
+ elif classname.find('BatchNorm') != -1:
13
+ nn.init.normal_(m.weight.data, 1.0, 0.02)
14
+ nn.init.constant_(m.bias.data, 0)
15
+
16
+
17
+ class NLayerDiscriminator(nn.Module):
18
+ """Defines a PatchGAN discriminator as in Pix2Pix
19
+ --> see https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py
20
+ """
21
+ def __init__(self, input_nc=3, ndf=64, n_layers=3, use_actnorm=False):
22
+ """Construct a PatchGAN discriminator
23
+ Parameters:
24
+ input_nc (int) -- the number of channels in input images
25
+ ndf (int) -- the number of filters in the last conv layer
26
+ n_layers (int) -- the number of conv layers in the discriminator
27
+ norm_layer -- normalization layer
28
+ """
29
+ super(NLayerDiscriminator, self).__init__()
30
+ if not use_actnorm:
31
+ norm_layer = nn.BatchNorm2d
32
+ else:
33
+ norm_layer = ActNorm
34
+ if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters
35
+ use_bias = norm_layer.func != nn.BatchNorm2d
36
+ else:
37
+ use_bias = norm_layer != nn.BatchNorm2d
38
+
39
+ kw = 4
40
+ padw = 1
41
+ sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]
42
+ nf_mult = 1
43
+ nf_mult_prev = 1
44
+ for n in range(1, n_layers): # gradually increase the number of filters
45
+ nf_mult_prev = nf_mult
46
+ nf_mult = min(2 ** n, 8)
47
+ sequence += [
48
+ nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias),
49
+ norm_layer(ndf * nf_mult),
50
+ nn.LeakyReLU(0.2, True)
51
+ ]
52
+
53
+ nf_mult_prev = nf_mult
54
+ nf_mult = min(2 ** n_layers, 8)
55
+ sequence += [
56
+ nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias),
57
+ norm_layer(ndf * nf_mult),
58
+ nn.LeakyReLU(0.2, True)
59
+ ]
60
+
61
+ sequence += [
62
+ nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)] # output 1 channel prediction map
63
+ self.main = nn.Sequential(*sequence)
64
+
65
+ def forward(self, input):
66
+ """Standard forward."""
67
+ return self.main(input)
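Not part of the commit: a minimal usage sketch of the PatchGAN discriminator, assuming the package is importable. Each output logit scores one image patch rather than the whole image.

import torch
from taming.modules.discriminator.model import NLayerDiscriminator, weights_init

disc = NLayerDiscriminator(input_nc=3, ndf=64, n_layers=3).apply(weights_init)
img = torch.randn(2, 3, 256, 256)     # batch of RGB images
logits = disc(img)                    # per-patch realness logits
print(logits.shape)                   # torch.Size([2, 1, 30, 30])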
taming-transformers/taming/modules/losses/__init__.py ADDED
@@ -0,0 +1,2 @@
1
+ from taming.modules.losses.vqperceptual import DummyLoss
2
+
taming-transformers/taming/modules/losses/lpips.py ADDED
@@ -0,0 +1,123 @@
1
+ """Stripped version of https://github.com/richzhang/PerceptualSimilarity/tree/master/models"""
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ from torchvision import models
6
+ from collections import namedtuple
7
+
8
+ from taming.util import get_ckpt_path
9
+
10
+
11
+ class LPIPS(nn.Module):
12
+ # Learned perceptual metric
13
+ def __init__(self, use_dropout=True):
14
+ super().__init__()
15
+ self.scaling_layer = ScalingLayer()
16
+ self.chns = [64, 128, 256, 512, 512]  # vgg16 features
17
+ self.net = vgg16(pretrained=True, requires_grad=False)
18
+ self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout)
19
+ self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout)
20
+ self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout)
21
+ self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout)
22
+ self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout)
23
+ self.load_from_pretrained()
24
+ for param in self.parameters():
25
+ param.requires_grad = False
26
+
27
+ def load_from_pretrained(self, name="vgg_lpips"):
28
+ ckpt = get_ckpt_path(name, "taming/modules/autoencoder/lpips")
29
+ self.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu")), strict=False)
30
+ print("loaded pretrained LPIPS loss from {}".format(ckpt))
31
+
32
+ @classmethod
33
+ def from_pretrained(cls, name="vgg_lpips"):
34
+ if name != "vgg_lpips":
35
+ raise NotImplementedError
36
+ model = cls()
37
+ ckpt = get_ckpt_path(name)
38
+ model.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu")), strict=False)
39
+ return model
40
+
41
+ def forward(self, input, target):
42
+ in0_input, in1_input = (self.scaling_layer(input), self.scaling_layer(target))
43
+ outs0, outs1 = self.net(in0_input), self.net(in1_input)
44
+ feats0, feats1, diffs = {}, {}, {}
45
+ lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4]
46
+ for kk in range(len(self.chns)):
47
+ feats0[kk], feats1[kk] = normalize_tensor(outs0[kk]), normalize_tensor(outs1[kk])
48
+ diffs[kk] = (feats0[kk] - feats1[kk]) ** 2
49
+
50
+ res = [spatial_average(lins[kk].model(diffs[kk]), keepdim=True) for kk in range(len(self.chns))]
51
+ val = res[0]
52
+ for l in range(1, len(self.chns)):
53
+ val += res[l]
54
+ return val
55
+
56
+
57
+ class ScalingLayer(nn.Module):
58
+ def __init__(self):
59
+ super(ScalingLayer, self).__init__()
60
+ self.register_buffer('shift', torch.Tensor([-.030, -.088, -.188])[None, :, None, None])
61
+ self.register_buffer('scale', torch.Tensor([.458, .448, .450])[None, :, None, None])
62
+
63
+ def forward(self, inp):
64
+ return (inp - self.shift) / self.scale
65
+
66
+
67
+ class NetLinLayer(nn.Module):
68
+ """ A single linear layer which does a 1x1 conv """
69
+ def __init__(self, chn_in, chn_out=1, use_dropout=False):
70
+ super(NetLinLayer, self).__init__()
71
+ layers = [nn.Dropout(), ] if (use_dropout) else []
72
+ layers += [nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False), ]
73
+ self.model = nn.Sequential(*layers)
74
+
75
+
76
+ class vgg16(torch.nn.Module):
77
+ def __init__(self, requires_grad=False, pretrained=True):
78
+ super(vgg16, self).__init__()
79
+ vgg_pretrained_features = models.vgg16(pretrained=pretrained).features
80
+ self.slice1 = torch.nn.Sequential()
81
+ self.slice2 = torch.nn.Sequential()
82
+ self.slice3 = torch.nn.Sequential()
83
+ self.slice4 = torch.nn.Sequential()
84
+ self.slice5 = torch.nn.Sequential()
85
+ self.N_slices = 5
86
+ for x in range(4):
87
+ self.slice1.add_module(str(x), vgg_pretrained_features[x])
88
+ for x in range(4, 9):
89
+ self.slice2.add_module(str(x), vgg_pretrained_features[x])
90
+ for x in range(9, 16):
91
+ self.slice3.add_module(str(x), vgg_pretrained_features[x])
92
+ for x in range(16, 23):
93
+ self.slice4.add_module(str(x), vgg_pretrained_features[x])
94
+ for x in range(23, 30):
95
+ self.slice5.add_module(str(x), vgg_pretrained_features[x])
96
+ if not requires_grad:
97
+ for param in self.parameters():
98
+ param.requires_grad = False
99
+
100
+ def forward(self, X):
101
+ h = self.slice1(X)
102
+ h_relu1_2 = h
103
+ h = self.slice2(h)
104
+ h_relu2_2 = h
105
+ h = self.slice3(h)
106
+ h_relu3_3 = h
107
+ h = self.slice4(h)
108
+ h_relu4_3 = h
109
+ h = self.slice5(h)
110
+ h_relu5_3 = h
111
+ vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3'])
112
+ out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3)
113
+ return out
114
+
115
+
116
+ def normalize_tensor(x,eps=1e-10):
117
+ norm_factor = torch.sqrt(torch.sum(x**2,dim=1,keepdim=True))
118
+ return x/(norm_factor+eps)
119
+
120
+
121
+ def spatial_average(x, keepdim=True):
122
+ return x.mean([2,3],keepdim=keepdim)
123
+
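Not part of the commit: a minimal usage sketch of the LPIPS metric, assuming the package is importable; constructing LPIPS fetches the VGG16 backbone and the vgg_lpips linear weights on first use, and inputs are expected in [-1, 1].

import torch
from taming.modules.losses.lpips import LPIPS

lpips = LPIPS().eval()
a = torch.rand(2, 3, 256, 256) * 2 - 1    # images scaled to [-1, 1]
b = torch.rand(2, 3, 256, 256) * 2 - 1
with torch.no_grad():
    dist = lpips(a, b)                    # shape (2, 1, 1, 1); larger = perceptually more different
print(dist.flatten())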
taming-transformers/taming/modules/losses/segmentation.py ADDED
@@ -0,0 +1,22 @@
1
+ import torch.nn as nn
2
+ import torch.nn.functional as F
3
+
4
+
5
+ class BCELoss(nn.Module):
6
+ def forward(self, prediction, target):
7
+ loss = F.binary_cross_entropy_with_logits(prediction,target)
8
+ return loss, {}
9
+
10
+
11
+ class BCELossWithQuant(nn.Module):
12
+ def __init__(self, codebook_weight=1.):
13
+ super().__init__()
14
+ self.codebook_weight = codebook_weight
15
+
16
+ def forward(self, qloss, target, prediction, split):
17
+ bce_loss = F.binary_cross_entropy_with_logits(prediction,target)
18
+ loss = bce_loss + self.codebook_weight*qloss
19
+ return loss, {"{}/total_loss".format(split): loss.clone().detach().mean(),
20
+ "{}/bce_loss".format(split): bce_loss.detach().mean(),
21
+ "{}/quant_loss".format(split): qloss.detach().mean()
22
+ }
taming-transformers/taming/modules/losses/vqperceptual.py ADDED
@@ -0,0 +1,136 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+ from taming.modules.losses.lpips import LPIPS
6
+ from taming.modules.discriminator.model import NLayerDiscriminator, weights_init
7
+
8
+
9
+ class DummyLoss(nn.Module):
10
+ def __init__(self):
11
+ super().__init__()
12
+
13
+
14
+ def adopt_weight(weight, global_step, threshold=0, value=0.):
15
+ if global_step < threshold:
16
+ weight = value
17
+ return weight
18
+
19
+
20
+ def hinge_d_loss(logits_real, logits_fake):
21
+ loss_real = torch.mean(F.relu(1. - logits_real))
22
+ loss_fake = torch.mean(F.relu(1. + logits_fake))
23
+ d_loss = 0.5 * (loss_real + loss_fake)
24
+ return d_loss
25
+
26
+
27
+ def vanilla_d_loss(logits_real, logits_fake):
28
+ d_loss = 0.5 * (
29
+ torch.mean(torch.nn.functional.softplus(-logits_real)) +
30
+ torch.mean(torch.nn.functional.softplus(logits_fake)))
31
+ return d_loss
32
+
33
+
34
+ class VQLPIPSWithDiscriminator(nn.Module):
35
+ def __init__(self, disc_start, codebook_weight=1.0, pixelloss_weight=1.0,
36
+ disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0,
37
+ perceptual_weight=1.0, use_actnorm=False, disc_conditional=False,
38
+ disc_ndf=64, disc_loss="hinge"):
39
+ super().__init__()
40
+ assert disc_loss in ["hinge", "vanilla"]
41
+ self.codebook_weight = codebook_weight
42
+ self.pixel_weight = pixelloss_weight
43
+ self.perceptual_loss = LPIPS().eval()
44
+ self.perceptual_weight = perceptual_weight
45
+
46
+ self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels,
47
+ n_layers=disc_num_layers,
48
+ use_actnorm=use_actnorm,
49
+ ndf=disc_ndf
50
+ ).apply(weights_init)
51
+ self.discriminator_iter_start = disc_start
52
+ if disc_loss == "hinge":
53
+ self.disc_loss = hinge_d_loss
54
+ elif disc_loss == "vanilla":
55
+ self.disc_loss = vanilla_d_loss
56
+ else:
57
+ raise ValueError(f"Unknown GAN loss '{disc_loss}'.")
58
+ print(f"VQLPIPSWithDiscriminator running with {disc_loss} loss.")
59
+ self.disc_factor = disc_factor
60
+ self.discriminator_weight = disc_weight
61
+ self.disc_conditional = disc_conditional
62
+
63
+ def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):
64
+ if last_layer is not None:
65
+ nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
66
+ g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
67
+ else:
68
+ nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0]
69
+ g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0]
70
+
71
+ d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
72
+ d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
73
+ d_weight = d_weight * self.discriminator_weight
74
+ return d_weight
75
+
76
+ def forward(self, codebook_loss, inputs, reconstructions, optimizer_idx,
77
+ global_step, last_layer=None, cond=None, split="train"):
78
+ rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous())
79
+ if self.perceptual_weight > 0:
80
+ p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous())
81
+ rec_loss = rec_loss + self.perceptual_weight * p_loss
82
+ else:
83
+ p_loss = torch.tensor([0.0])
84
+
85
+ nll_loss = rec_loss
86
+ #nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]
87
+ nll_loss = torch.mean(nll_loss)
88
+
89
+ # now the GAN part
90
+ if optimizer_idx == 0:
91
+ # generator update
92
+ if cond is None:
93
+ assert not self.disc_conditional
94
+ logits_fake = self.discriminator(reconstructions.contiguous())
95
+ else:
96
+ assert self.disc_conditional
97
+ logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1))
98
+ g_loss = -torch.mean(logits_fake)
99
+
100
+ try:
101
+ d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer)
102
+ except RuntimeError:
103
+ assert not self.training
104
+ d_weight = torch.tensor(0.0)
105
+
106
+ disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
107
+ loss = nll_loss + d_weight * disc_factor * g_loss + self.codebook_weight * codebook_loss.mean()
108
+
109
+ log = {"{}/total_loss".format(split): loss.clone().detach().mean(),
110
+ "{}/quant_loss".format(split): codebook_loss.detach().mean(),
111
+ "{}/nll_loss".format(split): nll_loss.detach().mean(),
112
+ "{}/rec_loss".format(split): rec_loss.detach().mean(),
113
+ "{}/p_loss".format(split): p_loss.detach().mean(),
114
+ "{}/d_weight".format(split): d_weight.detach(),
115
+ "{}/disc_factor".format(split): torch.tensor(disc_factor),
116
+ "{}/g_loss".format(split): g_loss.detach().mean(),
117
+ }
118
+ return loss, log
119
+
120
+ if optimizer_idx == 1:
121
+ # second pass for discriminator update
122
+ if cond is None:
123
+ logits_real = self.discriminator(inputs.contiguous().detach())
124
+ logits_fake = self.discriminator(reconstructions.contiguous().detach())
125
+ else:
126
+ logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1))
127
+ logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1))
128
+
129
+ disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
130
+ d_loss = disc_factor * self.disc_loss(logits_real, logits_fake)
131
+
132
+ log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(),
133
+ "{}/logits_real".format(split): logits_real.detach().mean(),
134
+ "{}/logits_fake".format(split): logits_fake.detach().mean()
135
+ }
136
+ return d_loss, log
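Not part of the commit: a small sketch of the discriminator warm-up logic above, assuming the package is importable. adopt_weight zeroes the GAN weight until global_step reaches the disc_start threshold, after which hinge_d_loss (or vanilla_d_loss) drives the discriminator.

import torch
from taming.modules.losses.vqperceptual import adopt_weight, hinge_d_loss

print(adopt_weight(weight=1.0, global_step=100, threshold=250))   # 0.0 -> GAN term disabled
print(adopt_weight(weight=1.0, global_step=500, threshold=250))   # 1.0 -> GAN term active

logits_real = torch.randn(4, 1, 30, 30)   # PatchGAN outputs
logits_fake = torch.randn(4, 1, 30, 30)
print(hinge_d_loss(logits_real, logits_fake))                     # scalar discriminator loss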
taming-transformers/taming/modules/misc/coord.py ADDED
@@ -0,0 +1,31 @@
1
+ import torch
2
+
3
+ class CoordStage(object):
4
+ def __init__(self, n_embed, down_factor):
5
+ self.n_embed = n_embed
6
+ self.down_factor = down_factor
7
+
8
+ def eval(self):
9
+ return self
10
+
11
+ def encode(self, c):
12
+ """fake vqmodel interface"""
13
+ assert 0.0 <= c.min() and c.max() <= 1.0
14
+ b,ch,h,w = c.shape
15
+ assert ch == 1
16
+
17
+ c = torch.nn.functional.interpolate(c, scale_factor=1/self.down_factor,
18
+ mode="area")
19
+ c = c.clamp(0.0, 1.0)
20
+ c = self.n_embed*c
21
+ c_quant = c.round()
22
+ c_ind = c_quant.to(dtype=torch.long)
23
+
24
+ info = None, None, c_ind
25
+ return c_quant, None, info
26
+
27
+ def decode(self, c):
28
+ c = c/self.n_embed
29
+ c = torch.nn.functional.interpolate(c, scale_factor=self.down_factor,
30
+ mode="nearest")
31
+ return c
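Not part of the commit: a minimal sketch of CoordStage, assuming the package is importable. It mimics the VQ-model interface by downsampling a [0, 1] coordinate map and rounding it to n_embed discrete levels.

import torch
from taming.modules.misc.coord import CoordStage

stage = CoordStage(n_embed=1024, down_factor=16)
c = torch.rand(1, 1, 256, 256)                 # coordinate map in [0, 1]
c_quant, _, (_, _, c_ind) = stage.encode(c)    # (1, 1, 16, 16), rounded to 1024 levels
rec = stage.decode(c_quant)                    # nearest-neighbour upsample back to 256x256
print(c_ind.shape, rec.shape)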
taming-transformers/taming/modules/transformer/mingpt.py ADDED
@@ -0,0 +1,415 @@
1
+ """
2
+ taken from: https://github.com/karpathy/minGPT/
3
+ GPT model:
4
+ - the initial stem consists of a combination of token encoding and a positional encoding
5
+ - the meat of it is a uniform sequence of Transformer blocks
6
+ - each Transformer is a sequential combination of a 1-hidden-layer MLP block and a self-attention block
7
+ - all blocks feed into a central residual pathway similar to resnets
8
+ - the final decoder is a linear projection into a vanilla Softmax classifier
9
+ """
10
+
11
+ import math
12
+ import logging
13
+
14
+ import torch
15
+ import torch.nn as nn
16
+ from torch.nn import functional as F
17
+ from transformers import top_k_top_p_filtering
18
+
19
+ logger = logging.getLogger(__name__)
20
+
21
+
22
+ class GPTConfig:
23
+ """ base GPT config, params common to all GPT versions """
24
+ embd_pdrop = 0.1
25
+ resid_pdrop = 0.1
26
+ attn_pdrop = 0.1
27
+
28
+ def __init__(self, vocab_size, block_size, **kwargs):
29
+ self.vocab_size = vocab_size
30
+ self.block_size = block_size
31
+ for k,v in kwargs.items():
32
+ setattr(self, k, v)
33
+
34
+
35
+ class GPT1Config(GPTConfig):
36
+ """ GPT-1 like network roughly 125M params """
37
+ n_layer = 12
38
+ n_head = 12
39
+ n_embd = 768
40
+
41
+
42
+ class CausalSelfAttention(nn.Module):
43
+ """
44
+ A vanilla multi-head masked self-attention layer with a projection at the end.
45
+ It is possible to use torch.nn.MultiheadAttention here but I am including an
46
+ explicit implementation here to show that there is nothing too scary here.
47
+ """
48
+
49
+ def __init__(self, config):
50
+ super().__init__()
51
+ assert config.n_embd % config.n_head == 0
52
+ # key, query, value projections for all heads
53
+ self.key = nn.Linear(config.n_embd, config.n_embd)
54
+ self.query = nn.Linear(config.n_embd, config.n_embd)
55
+ self.value = nn.Linear(config.n_embd, config.n_embd)
56
+ # regularization
57
+ self.attn_drop = nn.Dropout(config.attn_pdrop)
58
+ self.resid_drop = nn.Dropout(config.resid_pdrop)
59
+ # output projection
60
+ self.proj = nn.Linear(config.n_embd, config.n_embd)
61
+ # causal mask to ensure that attention is only applied to the left in the input sequence
62
+ mask = torch.tril(torch.ones(config.block_size,
63
+ config.block_size))
64
+ if hasattr(config, "n_unmasked"):
65
+ mask[:config.n_unmasked, :config.n_unmasked] = 1
66
+ self.register_buffer("mask", mask.view(1, 1, config.block_size, config.block_size))
67
+ self.n_head = config.n_head
68
+
69
+ def forward(self, x, layer_past=None):
70
+ B, T, C = x.size()
71
+
72
+ # calculate query, key, values for all heads in batch and move head forward to be the batch dim
73
+ k = self.key(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
74
+ q = self.query(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
75
+ v = self.value(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
76
+
77
+ present = torch.stack((k, v))
78
+ if layer_past is not None:
79
+ past_key, past_value = layer_past
80
+ k = torch.cat((past_key, k), dim=-2)
81
+ v = torch.cat((past_value, v), dim=-2)
82
+
83
+ # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
84
+ att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
85
+ if layer_past is None:
86
+ att = att.masked_fill(self.mask[:,:,:T,:T] == 0, float('-inf'))
87
+
88
+ att = F.softmax(att, dim=-1)
89
+ att = self.attn_drop(att)
90
+ y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
91
+ y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side
92
+
93
+ # output projection
94
+ y = self.resid_drop(self.proj(y))
95
+ return y, present # TODO: check that this does not break anything
96
+
97
+
98
+ class Block(nn.Module):
99
+ """ an unassuming Transformer block """
100
+ def __init__(self, config):
101
+ super().__init__()
102
+ self.ln1 = nn.LayerNorm(config.n_embd)
103
+ self.ln2 = nn.LayerNorm(config.n_embd)
104
+ self.attn = CausalSelfAttention(config)
105
+ self.mlp = nn.Sequential(
106
+ nn.Linear(config.n_embd, 4 * config.n_embd),
107
+ nn.GELU(), # nice
108
+ nn.Linear(4 * config.n_embd, config.n_embd),
109
+ nn.Dropout(config.resid_pdrop),
110
+ )
111
+
112
+ def forward(self, x, layer_past=None, return_present=False):
113
+ # TODO: check that training still works
114
+ if return_present: assert not self.training
115
+ # layer past: tuple of length two with B, nh, T, hs
116
+ attn, present = self.attn(self.ln1(x), layer_past=layer_past)
117
+
118
+ x = x + attn
119
+ x = x + self.mlp(self.ln2(x))
120
+ if layer_past is not None or return_present:
121
+ return x, present
122
+ return x
123
+
124
+
125
+ class GPT(nn.Module):
126
+ """ the full GPT language model, with a context size of block_size """
127
+ def __init__(self, vocab_size, block_size, n_layer=12, n_head=8, n_embd=256,
128
+ embd_pdrop=0., resid_pdrop=0., attn_pdrop=0., n_unmasked=0):
129
+ super().__init__()
130
+ config = GPTConfig(vocab_size=vocab_size, block_size=block_size,
131
+ embd_pdrop=embd_pdrop, resid_pdrop=resid_pdrop, attn_pdrop=attn_pdrop,
132
+ n_layer=n_layer, n_head=n_head, n_embd=n_embd,
133
+ n_unmasked=n_unmasked)
134
+ # input embedding stem
135
+ self.tok_emb = nn.Embedding(config.vocab_size, config.n_embd)
136
+ self.pos_emb = nn.Parameter(torch.zeros(1, config.block_size, config.n_embd))
137
+ self.drop = nn.Dropout(config.embd_pdrop)
138
+ # transformer
139
+ self.blocks = nn.Sequential(*[Block(config) for _ in range(config.n_layer)])
140
+ # decoder head
141
+ self.ln_f = nn.LayerNorm(config.n_embd)
142
+ self.head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
143
+ self.block_size = config.block_size
144
+ self.apply(self._init_weights)
145
+ self.config = config
146
+ logger.info("number of parameters: %e", sum(p.numel() for p in self.parameters()))
147
+
148
+ def get_block_size(self):
149
+ return self.block_size
150
+
151
+ def _init_weights(self, module):
152
+ if isinstance(module, (nn.Linear, nn.Embedding)):
153
+ module.weight.data.normal_(mean=0.0, std=0.02)
154
+ if isinstance(module, nn.Linear) and module.bias is not None:
155
+ module.bias.data.zero_()
156
+ elif isinstance(module, nn.LayerNorm):
157
+ module.bias.data.zero_()
158
+ module.weight.data.fill_(1.0)
159
+
160
+ def forward(self, idx, embeddings=None, targets=None):
161
+ # forward the GPT model
162
+ token_embeddings = self.tok_emb(idx) # each index maps to a (learnable) vector
163
+
164
+ if embeddings is not None: # prepend explicit embeddings
165
+ token_embeddings = torch.cat((embeddings, token_embeddings), dim=1)
166
+
167
+ t = token_embeddings.shape[1]
168
+ assert t <= self.block_size, "Cannot forward, model block size is exhausted."
169
+ position_embeddings = self.pos_emb[:, :t, :] # each position maps to a (learnable) vector
170
+ x = self.drop(token_embeddings + position_embeddings)
171
+ x = self.blocks(x)
172
+ x = self.ln_f(x)
173
+ logits = self.head(x)
174
+
175
+ # if we are given some desired targets also calculate the loss
176
+ loss = None
177
+ if targets is not None:
178
+ loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
179
+
180
+ return logits, loss
181
+
182
+ def forward_with_past(self, idx, embeddings=None, targets=None, past=None, past_length=None):
183
+ # inference only
184
+ assert not self.training
185
+ token_embeddings = self.tok_emb(idx) # each index maps to a (learnable) vector
186
+ if embeddings is not None: # prepend explicit embeddings
187
+ token_embeddings = torch.cat((embeddings, token_embeddings), dim=1)
188
+
189
+ if past is not None:
190
+ assert past_length is not None
191
+ past = torch.cat(past, dim=-2) # n_layer, 2, b, nh, len_past, dim_head
192
+ past_shape = list(past.shape)
193
+ expected_shape = [self.config.n_layer, 2, idx.shape[0], self.config.n_head, past_length, self.config.n_embd//self.config.n_head]
194
+ assert past_shape == expected_shape, f"{past_shape} =/= {expected_shape}"
195
+ position_embeddings = self.pos_emb[:, past_length, :] # each position maps to a (learnable) vector
196
+ else:
197
+ position_embeddings = self.pos_emb[:, :token_embeddings.shape[1], :]
198
+
199
+ x = self.drop(token_embeddings + position_embeddings)
200
+ presents = [] # accumulate over layers
201
+ for i, block in enumerate(self.blocks):
202
+ x, present = block(x, layer_past=past[i, ...] if past is not None else None, return_present=True)
203
+ presents.append(present)
204
+
205
+ x = self.ln_f(x)
206
+ logits = self.head(x)
207
+ # if we are given some desired targets also calculate the loss
208
+ loss = None
209
+ if targets is not None:
210
+ loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
211
+
212
+ return logits, loss, torch.stack(presents) # _, _, n_layer, 2, b, nh, 1, dim_head
213
+
214
+
215
+ class DummyGPT(nn.Module):
216
+ # for debugging
217
+ def __init__(self, add_value=1):
218
+ super().__init__()
219
+ self.add_value = add_value
220
+
221
+ def forward(self, idx):
222
+ return idx + self.add_value, None
223
+
224
+
225
+ class CodeGPT(nn.Module):
226
+ """Takes in semi-embeddings"""
227
+ def __init__(self, vocab_size, block_size, in_channels, n_layer=12, n_head=8, n_embd=256,
228
+ embd_pdrop=0., resid_pdrop=0., attn_pdrop=0., n_unmasked=0):
229
+ super().__init__()
230
+ config = GPTConfig(vocab_size=vocab_size, block_size=block_size,
231
+ embd_pdrop=embd_pdrop, resid_pdrop=resid_pdrop, attn_pdrop=attn_pdrop,
232
+ n_layer=n_layer, n_head=n_head, n_embd=n_embd,
233
+ n_unmasked=n_unmasked)
234
+ # input embedding stem
235
+ self.tok_emb = nn.Linear(in_channels, config.n_embd)
236
+ self.pos_emb = nn.Parameter(torch.zeros(1, config.block_size, config.n_embd))
237
+ self.drop = nn.Dropout(config.embd_pdrop)
238
+ # transformer
239
+ self.blocks = nn.Sequential(*[Block(config) for _ in range(config.n_layer)])
240
+ # decoder head
241
+ self.ln_f = nn.LayerNorm(config.n_embd)
242
+ self.head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
243
+ self.block_size = config.block_size
244
+ self.apply(self._init_weights)
245
+ self.config = config
246
+ logger.info("number of parameters: %e", sum(p.numel() for p in self.parameters()))
247
+
248
+ def get_block_size(self):
249
+ return self.block_size
250
+
251
+ def _init_weights(self, module):
252
+ if isinstance(module, (nn.Linear, nn.Embedding)):
253
+ module.weight.data.normal_(mean=0.0, std=0.02)
254
+ if isinstance(module, nn.Linear) and module.bias is not None:
255
+ module.bias.data.zero_()
256
+ elif isinstance(module, nn.LayerNorm):
257
+ module.bias.data.zero_()
258
+ module.weight.data.fill_(1.0)
259
+
260
+ def forward(self, idx, embeddings=None, targets=None):
261
+ # forward the GPT model
262
+ token_embeddings = self.tok_emb(idx) # each index maps to a (learnable) vector
263
+
264
+ if embeddings is not None: # prepend explicit embeddings
265
+ token_embeddings = torch.cat((embeddings, token_embeddings), dim=1)
266
+
267
+ t = token_embeddings.shape[1]
268
+ assert t <= self.block_size, "Cannot forward, model block size is exhausted."
269
+ position_embeddings = self.pos_emb[:, :t, :] # each position maps to a (learnable) vector
270
+ x = self.drop(token_embeddings + position_embeddings)
271
+ x = self.blocks(x)
272
+ x = self.ln_f(x)
273
+ logits = self.head(x)
274
+
275
+ # if we are given some desired targets also calculate the loss
276
+ loss = None
277
+ if targets is not None:
278
+ loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
279
+
280
+ return logits, loss
281
+
282
+
283
+
284
+ #### sampling utils
285
+
286
+ def top_k_logits(logits, k):
287
+ v, ix = torch.topk(logits, k)
288
+ out = logits.clone()
289
+ out[out < v[:, [-1]]] = -float('Inf')
290
+ return out
291
+
292
+ @torch.no_grad()
293
+ def sample(model, x, steps, temperature=1.0, sample=False, top_k=None):
294
+ """
295
+ take a conditioning sequence of indices in x (of shape (b,t)) and predict the next token in
296
+ the sequence, feeding the predictions back into the model each time. Clearly the sampling
297
+ has quadratic complexity unlike an RNN that is only linear, and has a finite context window
298
+ of block_size, unlike an RNN that has an infinite context window.
299
+ """
300
+ block_size = model.get_block_size()
301
+ model.eval()
302
+ for k in range(steps):
303
+ x_cond = x if x.size(1) <= block_size else x[:, -block_size:] # crop context if needed
304
+ logits, _ = model(x_cond)
305
+ # pluck the logits at the final step and scale by temperature
306
+ logits = logits[:, -1, :] / temperature
307
+ # optionally crop probabilities to only the top k options
308
+ if top_k is not None:
309
+ logits = top_k_logits(logits, top_k)
310
+ # apply softmax to convert to probabilities
311
+ probs = F.softmax(logits, dim=-1)
312
+ # sample from the distribution or take the most likely
313
+ if sample:
314
+ ix = torch.multinomial(probs, num_samples=1)
315
+ else:
316
+ _, ix = torch.topk(probs, k=1, dim=-1)
317
+ # append to the sequence and continue
318
+ x = torch.cat((x, ix), dim=1)
319
+
320
+ return x
321
+
322
+
323
+ @torch.no_grad()
324
+ def sample_with_past(x, model, steps, temperature=1., sample_logits=True,
325
+ top_k=None, top_p=None, callback=None):
326
+ # x is conditioning
327
+ sample = x
328
+ cond_len = x.shape[1]
329
+ past = None
330
+ for n in range(steps):
331
+ if callback is not None:
332
+ callback(n)
333
+ logits, _, present = model.forward_with_past(x, past=past, past_length=(n+cond_len-1))
334
+ if past is None:
335
+ past = [present]
336
+ else:
337
+ past.append(present)
338
+ logits = logits[:, -1, :] / temperature
339
+ if top_k is not None:
340
+ logits = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)
341
+
342
+ probs = F.softmax(logits, dim=-1)
343
+ if not sample_logits:
344
+ _, x = torch.topk(probs, k=1, dim=-1)
345
+ else:
346
+ x = torch.multinomial(probs, num_samples=1)
347
+ # append to the sequence and continue
348
+ sample = torch.cat((sample, x), dim=1)
349
+ del past
350
+ sample = sample[:, cond_len:] # cut conditioning off
351
+ return sample
352
+
353
+
354
+ #### clustering utils
355
+
356
+ class KMeans(nn.Module):
357
+ def __init__(self, ncluster=512, nc=3, niter=10):
358
+ super().__init__()
359
+ self.ncluster = ncluster
360
+ self.nc = nc
361
+ self.niter = niter
362
+ self.shape = (3,32,32)
363
+ self.register_buffer("C", torch.zeros(self.ncluster,nc))
364
+ self.register_buffer('initialized', torch.tensor(0, dtype=torch.uint8))
365
+
366
+ def is_initialized(self):
367
+ return self.initialized.item() == 1
368
+
369
+ @torch.no_grad()
370
+ def initialize(self, x):
371
+ N, D = x.shape
372
+ assert D == self.nc, D
373
+ c = x[torch.randperm(N)[:self.ncluster]] # init clusters at random
374
+ for i in range(self.niter):
375
+ # assign all pixels to the closest codebook element
376
+ a = ((x[:, None, :] - c[None, :, :])**2).sum(-1).argmin(1)
377
+ # move each codebook element to be the mean of the pixels that are assigned to it
378
+ c = torch.stack([x[a==k].mean(0) for k in range(self.ncluster)])
379
+ # re-assign any poorly positioned codebook elements
380
+ nanix = torch.any(torch.isnan(c), dim=1)
381
+ ndead = nanix.sum().item()
382
+ print('done step %d/%d, re-initialized %d dead clusters' % (i+1, self.niter, ndead))
383
+ c[nanix] = x[torch.randperm(N)[:ndead]] # re-init dead clusters
384
+
385
+ self.C.copy_(c)
386
+ self.initialized.fill_(1)
387
+
388
+
389
+ def forward(self, x, reverse=False, shape=None):
390
+ if not reverse:
391
+ # flatten
392
+ bs,c,h,w = x.shape
393
+ assert c == self.nc
394
+ x = x.reshape(bs,c,h*w,1)
395
+ C = self.C.permute(1,0)
396
+ C = C.reshape(1,c,1,self.ncluster)
397
+ a = ((x-C)**2).sum(1).argmin(-1) # bs, h*w indices
398
+ return a
399
+ else:
400
+ # flatten
401
+ bs, HW = x.shape
402
+ """
403
+ c = self.C.reshape( 1, self.nc, 1, self.ncluster)
404
+ c = c[bs*[0],:,:,:]
405
+ c = c[:,:,HW*[0],:]
406
+ x = x.reshape(bs, 1, HW, 1)
407
+ x = x[:,3*[0],:,:]
408
+ x = torch.gather(c, dim=3, index=x)
409
+ """
410
+ x = self.C[x]
411
+ x = x.permute(0,2,1)
412
+ shape = shape if shape is not None else self.shape
413
+ x = x.reshape(bs, *shape)
414
+
415
+ return x
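Not part of the commit: a minimal sampling sketch with the minGPT variant above, assuming the package and the transformers dependency it imports are available; the model here is randomly initialized, so the sampled tokens are meaningless.

import torch
from taming.modules.transformer.mingpt import GPT, sample

model = GPT(vocab_size=1024, block_size=256, n_layer=2, n_head=4, n_embd=128)
start = torch.randint(0, 1024, (1, 1))           # one conditioning token
out = sample(model, start, steps=16, temperature=1.0, sample=True, top_k=50)
print(out.shape)                                 # torch.Size([1, 17]): conditioning + 16 sampled tokens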
taming-transformers/taming/modules/transformer/permuter.py ADDED
@@ -0,0 +1,248 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import numpy as np
4
+
5
+
6
+ class AbstractPermuter(nn.Module):
7
+ def __init__(self, *args, **kwargs):
8
+ super().__init__()
9
+ def forward(self, x, reverse=False):
10
+ raise NotImplementedError
11
+
12
+
13
+ class Identity(AbstractPermuter):
14
+ def __init__(self):
15
+ super().__init__()
16
+
17
+ def forward(self, x, reverse=False):
18
+ return x
19
+
20
+
21
+ class Subsample(AbstractPermuter):
22
+ def __init__(self, H, W):
23
+ super().__init__()
24
+ C = 1
25
+ indices = np.arange(H*W).reshape(C,H,W)
26
+ while min(H, W) > 1:
27
+ indices = indices.reshape(C,H//2,2,W//2,2)
28
+ indices = indices.transpose(0,2,4,1,3)
29
+ indices = indices.reshape(C*4,H//2, W//2)
30
+ H = H//2
31
+ W = W//2
32
+ C = C*4
33
+ assert H == W == 1
34
+ idx = torch.tensor(indices.ravel())
35
+ self.register_buffer('forward_shuffle_idx',
36
+ nn.Parameter(idx, requires_grad=False))
37
+ self.register_buffer('backward_shuffle_idx',
38
+ nn.Parameter(torch.argsort(idx), requires_grad=False))
39
+
40
+ def forward(self, x, reverse=False):
41
+ if not reverse:
42
+ return x[:, self.forward_shuffle_idx]
43
+ else:
44
+ return x[:, self.backward_shuffle_idx]
45
+
46
+
47
+ def mortonify(i, j):
48
+ """(i,j) index to linear morton code"""
49
+ i = np.uint64(i)
50
+ j = np.uint64(j)
51
+
52
+ z = np.uint(0)
53
+
54
+ for pos in range(32):
55
+ z = (z |
56
+ ((j & (np.uint64(1) << np.uint64(pos))) << np.uint64(pos)) |
57
+ ((i & (np.uint64(1) << np.uint64(pos))) << np.uint64(pos+1))
58
+ )
59
+ return z
60
+
61
+
62
+ class ZCurve(AbstractPermuter):
63
+ def __init__(self, H, W):
64
+ super().__init__()
65
+ reverseidx = [np.int64(mortonify(i,j)) for i in range(H) for j in range(W)]
66
+ idx = np.argsort(reverseidx)
67
+ idx = torch.tensor(idx)
68
+ reverseidx = torch.tensor(reverseidx)
69
+ self.register_buffer('forward_shuffle_idx',
70
+ idx)
71
+ self.register_buffer('backward_shuffle_idx',
72
+ reverseidx)
73
+
74
+ def forward(self, x, reverse=False):
75
+ if not reverse:
76
+ return x[:, self.forward_shuffle_idx]
77
+ else:
78
+ return x[:, self.backward_shuffle_idx]
79
+
80
+
81
+ class SpiralOut(AbstractPermuter):
82
+ def __init__(self, H, W):
83
+ super().__init__()
84
+ assert H == W
85
+ size = W
86
+ indices = np.arange(size*size).reshape(size,size)
87
+
88
+ i0 = size//2
89
+ j0 = size//2-1
90
+
91
+ i = i0
92
+ j = j0
93
+
94
+ idx = [indices[i0, j0]]
95
+ step_mult = 0
96
+ for c in range(1, size//2+1):
97
+ step_mult += 1
98
+ # steps left
99
+ for k in range(step_mult):
100
+ i = i - 1
101
+ j = j
102
+ idx.append(indices[i, j])
103
+
104
+ # step down
105
+ for k in range(step_mult):
106
+ i = i
107
+ j = j + 1
108
+ idx.append(indices[i, j])
109
+
110
+ step_mult += 1
111
+ if c < size//2:
112
+ # step right
113
+ for k in range(step_mult):
114
+ i = i + 1
115
+ j = j
116
+ idx.append(indices[i, j])
117
+
118
+ # step up
119
+ for k in range(step_mult):
120
+ i = i
121
+ j = j - 1
122
+ idx.append(indices[i, j])
123
+ else:
124
+ # end reached
125
+ for k in range(step_mult-1):
126
+ i = i + 1
127
+ idx.append(indices[i, j])
128
+
129
+ assert len(idx) == size*size
130
+ idx = torch.tensor(idx)
131
+ self.register_buffer('forward_shuffle_idx', idx)
132
+ self.register_buffer('backward_shuffle_idx', torch.argsort(idx))
133
+
134
+ def forward(self, x, reverse=False):
135
+ if not reverse:
136
+ return x[:, self.forward_shuffle_idx]
137
+ else:
138
+ return x[:, self.backward_shuffle_idx]
139
+
140
+
141
+ class SpiralIn(AbstractPermuter):
142
+ def __init__(self, H, W):
143
+ super().__init__()
144
+ assert H == W
145
+ size = W
146
+ indices = np.arange(size*size).reshape(size,size)
147
+
148
+ i0 = size//2
149
+ j0 = size//2-1
150
+
151
+ i = i0
152
+ j = j0
153
+
154
+ idx = [indices[i0, j0]]
155
+ step_mult = 0
156
+ for c in range(1, size//2+1):
157
+ step_mult += 1
158
+ # steps left
159
+ for k in range(step_mult):
160
+ i = i - 1
161
+ j = j
162
+ idx.append(indices[i, j])
163
+
164
+ # step down
165
+ for k in range(step_mult):
166
+ i = i
167
+ j = j + 1
168
+ idx.append(indices[i, j])
169
+
170
+ step_mult += 1
171
+ if c < size//2:
172
+ # step right
173
+ for k in range(step_mult):
174
+ i = i + 1
175
+ j = j
176
+ idx.append(indices[i, j])
177
+
178
+ # step up
179
+ for k in range(step_mult):
180
+ i = i
181
+ j = j - 1
182
+ idx.append(indices[i, j])
183
+ else:
184
+ # end reached
185
+ for k in range(step_mult-1):
186
+ i = i + 1
187
+ idx.append(indices[i, j])
188
+
189
+ assert len(idx) == size*size
190
+ idx = idx[::-1]
191
+ idx = torch.tensor(idx)
192
+ self.register_buffer('forward_shuffle_idx', idx)
193
+ self.register_buffer('backward_shuffle_idx', torch.argsort(idx))
194
+
195
+ def forward(self, x, reverse=False):
196
+ if not reverse:
197
+ return x[:, self.forward_shuffle_idx]
198
+ else:
199
+ return x[:, self.backward_shuffle_idx]
200
+
201
+
202
+ class Random(nn.Module):
203
+ def __init__(self, H, W):
204
+ super().__init__()
205
+ indices = np.random.RandomState(1).permutation(H*W)
206
+ idx = torch.tensor(indices.ravel())
207
+ self.register_buffer('forward_shuffle_idx', idx)
208
+ self.register_buffer('backward_shuffle_idx', torch.argsort(idx))
209
+
210
+ def forward(self, x, reverse=False):
211
+ if not reverse:
212
+ return x[:, self.forward_shuffle_idx]
213
+ else:
214
+ return x[:, self.backward_shuffle_idx]
215
+
216
+
217
+ class AlternateParsing(AbstractPermuter):
218
+ def __init__(self, H, W):
219
+ super().__init__()
220
+ indices = np.arange(W*H).reshape(H,W)
221
+ for i in range(1, H, 2):
222
+ indices[i, :] = indices[i, ::-1]
223
+ idx = indices.flatten()
224
+ assert len(idx) == H*W
225
+ idx = torch.tensor(idx)
226
+ self.register_buffer('forward_shuffle_idx', idx)
227
+ self.register_buffer('backward_shuffle_idx', torch.argsort(idx))
228
+
229
+ def forward(self, x, reverse=False):
230
+ if not reverse:
231
+ return x[:, self.forward_shuffle_idx]
232
+ else:
233
+ return x[:, self.backward_shuffle_idx]
234
+
235
+
236
+ if __name__ == "__main__":
237
+ p0 = AlternateParsing(16, 16)
238
+ print(p0.forward_shuffle_idx)
239
+ print(p0.backward_shuffle_idx)
240
+
241
+ x = torch.randint(0, 768, size=(11, 256))
242
+ y = p0(x)
243
+ xre = p0(y, reverse=True)
244
+ assert torch.equal(x, xre)
245
+
246
+ p1 = SpiralOut(2, 2)
247
+ print(p1.forward_shuffle_idx)
248
+ print(p1.backward_shuffle_idx)
taming-transformers/taming/modules/util.py ADDED
@@ -0,0 +1,130 @@
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+
5
+ def count_params(model):
6
+ total_params = sum(p.numel() for p in model.parameters())
7
+ return total_params
8
+
9
+
10
+ class ActNorm(nn.Module):
11
+ def __init__(self, num_features, logdet=False, affine=True,
12
+ allow_reverse_init=False):
13
+ assert affine
14
+ super().__init__()
15
+ self.logdet = logdet
16
+ self.loc = nn.Parameter(torch.zeros(1, num_features, 1, 1))
17
+ self.scale = nn.Parameter(torch.ones(1, num_features, 1, 1))
18
+ self.allow_reverse_init = allow_reverse_init
19
+
20
+ self.register_buffer('initialized', torch.tensor(0, dtype=torch.uint8))
21
+
22
+ def initialize(self, input):
23
+ with torch.no_grad():
24
+ flatten = input.permute(1, 0, 2, 3).contiguous().view(input.shape[1], -1)
25
+ mean = (
26
+ flatten.mean(1)
27
+ .unsqueeze(1)
28
+ .unsqueeze(2)
29
+ .unsqueeze(3)
30
+ .permute(1, 0, 2, 3)
31
+ )
32
+ std = (
33
+ flatten.std(1)
34
+ .unsqueeze(1)
35
+ .unsqueeze(2)
36
+ .unsqueeze(3)
37
+ .permute(1, 0, 2, 3)
38
+ )
39
+
40
+ self.loc.data.copy_(-mean)
41
+ self.scale.data.copy_(1 / (std + 1e-6))
42
+
43
+ def forward(self, input, reverse=False):
44
+ if reverse:
45
+ return self.reverse(input)
46
+ if len(input.shape) == 2:
47
+ input = input[:,:,None,None]
48
+ squeeze = True
49
+ else:
50
+ squeeze = False
51
+
52
+ _, _, height, width = input.shape
53
+
54
+ if self.training and self.initialized.item() == 0:
55
+ self.initialize(input)
56
+ self.initialized.fill_(1)
57
+
58
+ h = self.scale * (input + self.loc)
59
+
60
+ if squeeze:
61
+ h = h.squeeze(-1).squeeze(-1)
62
+
63
+ if self.logdet:
64
+ log_abs = torch.log(torch.abs(self.scale))
65
+ logdet = height*width*torch.sum(log_abs)
66
+ logdet = logdet * torch.ones(input.shape[0]).to(input)
67
+ return h, logdet
68
+
69
+ return h
70
+
71
+ def reverse(self, output):
72
+ if self.training and self.initialized.item() == 0:
73
+ if not self.allow_reverse_init:
74
+ raise RuntimeError(
75
+ "Initializing ActNorm in reverse direction is "
76
+ "disabled by default. Use allow_reverse_init=True to enable."
77
+ )
78
+ else:
79
+ self.initialize(output)
80
+ self.initialized.fill_(1)
81
+
82
+ if len(output.shape) == 2:
83
+ output = output[:,:,None,None]
84
+ squeeze = True
85
+ else:
86
+ squeeze = False
87
+
88
+ h = output / self.scale - self.loc
89
+
90
+ if squeeze:
91
+ h = h.squeeze(-1).squeeze(-1)
92
+ return h
93
+
94
+
95
+ class AbstractEncoder(nn.Module):
96
+ def __init__(self):
97
+ super().__init__()
98
+
99
+ def encode(self, *args, **kwargs):
100
+ raise NotImplementedError
101
+
102
+
103
+ class Labelator(AbstractEncoder):
104
+ """Net2Net Interface for Class-Conditional Model"""
105
+ def __init__(self, n_classes, quantize_interface=True):
106
+ super().__init__()
107
+ self.n_classes = n_classes
108
+ self.quantize_interface = quantize_interface
109
+
110
+ def encode(self, c):
111
+ c = c[:,None]
112
+ if self.quantize_interface:
113
+ return c, None, [None, None, c.long()]
114
+ return c
115
+
116
+
117
+ class SOSProvider(AbstractEncoder):
118
+ # for unconditional training
119
+ def __init__(self, sos_token, quantize_interface=True):
120
+ super().__init__()
121
+ self.sos_token = sos_token
122
+ self.quantize_interface = quantize_interface
123
+
124
+ def encode(self, x):
125
+ # get batch size from data and replicate sos_token
126
+ c = torch.ones(x.shape[0], 1)*self.sos_token
127
+ c = c.long().to(x.device)
128
+ if self.quantize_interface:
129
+ return c, None, [None, None, c]
130
+ return c
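Not part of the commit: a minimal sketch of ActNorm's data-dependent initialization, assuming the package is importable. The first forward pass in training mode sets loc/scale from the batch statistics, and the transform is invertible via reverse (up to floating-point error).

import torch
from taming.modules.util import ActNorm

norm = ActNorm(num_features=16, logdet=True)
x = torch.randn(4, 16, 8, 8) * 3 + 1
norm.train()
h, logdet = norm(x)                     # first call initializes loc/scale from x
x_rec = norm(h, reverse=True)           # undo the affine transform
print(round(h.mean().item(), 3), round(h.std().item(), 3))   # approximately 0 and 1
print(torch.allclose(x_rec, x, atol=1e-4))                   # True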
taming-transformers/taming/modules/vqvae/quantize.py ADDED
@@ -0,0 +1,445 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import numpy as np
5
+ from torch import einsum
6
+ from einops import rearrange
7
+
8
+
9
+ class VectorQuantizer(nn.Module):
10
+ """
11
+ see https://github.com/MishaLaskin/vqvae/blob/d761a999e2267766400dc646d82d3ac3657771d4/models/quantizer.py
12
+ ____________________________________________
13
+ Discretization bottleneck part of the VQ-VAE.
14
+ Inputs:
15
+ - n_e : number of embeddings
16
+ - e_dim : dimension of embedding
17
+ - beta : commitment cost used in loss term, beta * ||z_e(x)-sg[e]||^2
18
+ _____________________________________________
19
+ """
20
+
21
+ # NOTE: this class contains a bug regarding beta; see VectorQuantizer2 for
22
+ # a fix and use legacy=False to apply that fix. VectorQuantizer2 can be
23
+ # used wherever VectorQuantizer has been used before and is additionally
24
+ # more efficient.
25
+ def __init__(self, n_e, e_dim, beta):
26
+ super(VectorQuantizer, self).__init__()
27
+ self.n_e = n_e
28
+ self.e_dim = e_dim
29
+ self.beta = beta
30
+
31
+ self.embedding = nn.Embedding(self.n_e, self.e_dim)
32
+ self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)
33
+
34
+ def forward(self, z):
35
+ """
36
+ Inputs the output of the encoder network z and maps it to a discrete
37
+ one-hot vector that is the index of the closest embedding vector e_j
38
+ z (continuous) -> z_q (discrete)
39
+ z.shape = (batch, channel, height, width)
40
+ quantization pipeline:
41
+ 1. get encoder input (B,C,H,W)
42
+ 2. flatten input to (B*H*W,C)
43
+ """
44
+ # reshape z -> (batch, height, width, channel) and flatten
45
+ z = z.permute(0, 2, 3, 1).contiguous()
46
+ z_flattened = z.view(-1, self.e_dim)
47
+ # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
48
+
49
+ d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \
50
+ torch.sum(self.embedding.weight**2, dim=1) - 2 * \
51
+ torch.matmul(z_flattened, self.embedding.weight.t())
52
+
53
+ ## could possibly replace this here
54
+ # #\start...
55
+ # find closest encodings
56
+ min_encoding_indices = torch.argmin(d, dim=1).unsqueeze(1)
57
+
58
+ min_encodings = torch.zeros(
59
+ min_encoding_indices.shape[0], self.n_e).to(z)
60
+ min_encodings.scatter_(1, min_encoding_indices, 1)
61
+
62
+ # dtype min encodings: torch.float32
63
+ # min_encodings shape: torch.Size([2048, 512])
64
+ # min_encoding_indices.shape: torch.Size([2048, 1])
65
+
66
+ # get quantized latent vectors
67
+ z_q = torch.matmul(min_encodings, self.embedding.weight).view(z.shape)
68
+ #.........\end
69
+
70
+ # with:
71
+ # .........\start
72
+ #min_encoding_indices = torch.argmin(d, dim=1)
73
+ #z_q = self.embedding(min_encoding_indices)
74
+ # ......\end......... (TODO)
75
+
76
+ # compute loss for embedding
77
+ loss = torch.mean((z_q.detach()-z)**2) + self.beta * \
78
+ torch.mean((z_q - z.detach()) ** 2)
79
+
80
+ # preserve gradients
81
+ z_q = z + (z_q - z).detach()
82
+
83
+ # perplexity
84
+ e_mean = torch.mean(min_encodings, dim=0)
85
+ perplexity = torch.exp(-torch.sum(e_mean * torch.log(e_mean + 1e-10)))
86
+
87
+ # reshape back to match original input shape
88
+ z_q = z_q.permute(0, 3, 1, 2).contiguous()
89
+
90
+ return z_q, loss, (perplexity, min_encodings, min_encoding_indices)
91
+
92
+ def get_codebook_entry(self, indices, shape):
93
+ # shape specifying (batch, height, width, channel)
94
+ # TODO: check for more easy handling with nn.Embedding
95
+ min_encodings = torch.zeros(indices.shape[0], self.n_e).to(indices)
96
+ min_encodings.scatter_(1, indices[:,None], 1)
97
+
98
+ # get quantized latent vectors
99
+ z_q = torch.matmul(min_encodings.float(), self.embedding.weight)
100
+
101
+ if shape is not None:
102
+ z_q = z_q.view(shape)
103
+
104
+ # reshape back to match original input shape
105
+ z_q = z_q.permute(0, 3, 1, 2).contiguous()
106
+
107
+ return z_q
108
+
109
+
110
+ class GumbelQuantize(nn.Module):
111
+ """
112
+ credit to @karpathy: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py (thanks!)
113
+ Gumbel Softmax trick quantizer
114
+ Categorical Reparameterization with Gumbel-Softmax, Jang et al. 2016
115
+ https://arxiv.org/abs/1611.01144
116
+ """
117
+ def __init__(self, num_hiddens, embedding_dim, n_embed, straight_through=True,
118
+ kl_weight=5e-4, temp_init=1.0, use_vqinterface=True,
119
+ remap=None, unknown_index="random"):
120
+ super().__init__()
121
+
122
+ self.embedding_dim = embedding_dim
123
+ self.n_embed = n_embed
124
+
125
+ self.straight_through = straight_through
126
+ self.temperature = temp_init
127
+ self.kl_weight = kl_weight
128
+
129
+ self.proj = nn.Conv2d(num_hiddens, n_embed, 1)
130
+ self.embed = nn.Embedding(n_embed, embedding_dim)
131
+
132
+ self.use_vqinterface = use_vqinterface
133
+
134
+ self.remap = remap
135
+ if self.remap is not None:
136
+ self.register_buffer("used", torch.tensor(np.load(self.remap)))
137
+ self.re_embed = self.used.shape[0]
138
+ self.unknown_index = unknown_index # "random" or "extra" or integer
139
+ if self.unknown_index == "extra":
140
+ self.unknown_index = self.re_embed
141
+ self.re_embed = self.re_embed+1
142
+ print(f"Remapping {self.n_embed} indices to {self.re_embed} indices. "
143
+ f"Using {self.unknown_index} for unknown indices.")
144
+ else:
145
+ self.re_embed = n_embed
146
+
147
+ def remap_to_used(self, inds):
148
+ ishape = inds.shape
149
+ assert len(ishape)>1
150
+ inds = inds.reshape(ishape[0],-1)
151
+ used = self.used.to(inds)
152
+ match = (inds[:,:,None]==used[None,None,...]).long()
153
+ new = match.argmax(-1)
154
+ unknown = match.sum(2)<1
155
+ if self.unknown_index == "random":
156
+ new[unknown]=torch.randint(0,self.re_embed,size=new[unknown].shape).to(device=new.device)
157
+ else:
158
+ new[unknown] = self.unknown_index
159
+ return new.reshape(ishape)
160
+
161
+ def unmap_to_all(self, inds):
162
+ ishape = inds.shape
163
+ assert len(ishape)>1
164
+ inds = inds.reshape(ishape[0],-1)
165
+ used = self.used.to(inds)
166
+ if self.re_embed > self.used.shape[0]: # extra token
167
+ inds[inds>=self.used.shape[0]] = 0 # simply set to zero
168
+ back=torch.gather(used[None,:][inds.shape[0]*[0],:], 1, inds)
169
+ return back.reshape(ishape)
170
+
171
+ def forward(self, z, temp=None, return_logits=False):
172
+ # force hard = True when we are in eval mode, as we must quantize. actually, always true seems to work
173
+ hard = self.straight_through if self.training else True
174
+ temp = self.temperature if temp is None else temp
175
+
176
+ logits = self.proj(z)
177
+ if self.remap is not None:
178
+ # continue only with used logits
179
+ full_zeros = torch.zeros_like(logits)
180
+ logits = logits[:,self.used,...]
181
+
182
+ soft_one_hot = F.gumbel_softmax(logits, tau=temp, dim=1, hard=hard)
183
+ if self.remap is not None:
184
+ # go back to all entries but unused set to zero
185
+ full_zeros[:,self.used,...] = soft_one_hot
186
+ soft_one_hot = full_zeros
187
+ z_q = einsum('b n h w, n d -> b d h w', soft_one_hot, self.embed.weight)
188
+
189
+ # + kl divergence to the prior loss
190
+ qy = F.softmax(logits, dim=1)
191
+ diff = self.kl_weight * torch.sum(qy * torch.log(qy * self.n_embed + 1e-10), dim=1).mean()
192
+
193
+ ind = soft_one_hot.argmax(dim=1)
194
+ if self.remap is not None:
195
+ ind = self.remap_to_used(ind)
196
+ if self.use_vqinterface:
197
+ if return_logits:
198
+ return z_q, diff, (None, None, ind), logits
199
+ return z_q, diff, (None, None, ind)
200
+ return z_q, diff, ind
201
+
202
+ def get_codebook_entry(self, indices, shape):
203
+ b, h, w, c = shape
204
+ assert b*h*w == indices.shape[0]
205
+ indices = rearrange(indices, '(b h w) -> b h w', b=b, h=h, w=w)
206
+ if self.remap is not None:
207
+ indices = self.unmap_to_all(indices)
208
+ one_hot = F.one_hot(indices, num_classes=self.n_embed).permute(0, 3, 1, 2).float()
209
+ z_q = einsum('b n h w, n d -> b d h w', one_hot, self.embed.weight)
210
+ return z_q
211
+
212
+
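For orientation, here is a minimal usage sketch of the GumbelQuantize module defined above. The shapes, hyperparameters, and variable names are illustrative assumptions, not values prescribed by the repository.

import torch
from taming.modules.vqvae.quantize import GumbelQuantize

# hypothetical encoder output: 4 feature maps, 128 channels, 16x16 spatial grid
h = torch.randn(4, 128, 16, 16)

quantizer = GumbelQuantize(num_hiddens=128, embedding_dim=64, n_embed=1024,
                           kl_weight=5e-4, temp_init=1.0)
z_q, kl_loss, (_, _, indices) = quantizer(h)   # use_vqinterface=True by default

print(z_q.shape)      # torch.Size([4, 64, 16, 16]) -- relaxed/quantized latents
print(indices.shape)  # torch.Size([4, 16, 16])     -- per-position code ids
print(kl_loss)        # scalar KL regularizer toward the uniform prior over codes

During training the Gumbel-Softmax relaxation keeps the selection of code indices differentiable; in eval mode the forward pass falls back to hard one-hot selection.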
213
+ class VectorQuantizer2(nn.Module):
214
+ """
215
+ Improved version over VectorQuantizer, can be used as a drop-in replacement. Mostly
216
+ avoids costly matrix multiplications and allows for post-hoc remapping of indices.
217
+ """
218
+ # NOTE: due to a bug the beta term was applied to the wrong term. for
219
+ # backwards compatibility we use the buggy version by default, but you can
220
+ # specify legacy=False to fix it.
221
+ def __init__(self, n_e, e_dim, beta, remap=None, unknown_index="random",
222
+ sane_index_shape=False, legacy=True):
223
+ super().__init__()
224
+ self.n_e = n_e
225
+ self.e_dim = e_dim
226
+ self.beta = beta
227
+ self.legacy = legacy
228
+
229
+ self.embedding = nn.Embedding(self.n_e, self.e_dim)
230
+ self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)
231
+
232
+ self.remap = remap
233
+ if self.remap is not None:
234
+ self.register_buffer("used", torch.tensor(np.load(self.remap)))
235
+ self.re_embed = self.used.shape[0]
236
+ self.unknown_index = unknown_index # "random" or "extra" or integer
237
+ if self.unknown_index == "extra":
238
+ self.unknown_index = self.re_embed
239
+ self.re_embed = self.re_embed+1
240
+ print(f"Remapping {self.n_e} indices to {self.re_embed} indices. "
241
+ f"Using {self.unknown_index} for unknown indices.")
242
+ else:
243
+ self.re_embed = n_e
244
+
245
+ self.sane_index_shape = sane_index_shape
246
+
247
+ def remap_to_used(self, inds):
248
+ ishape = inds.shape
249
+ assert len(ishape)>1
250
+ inds = inds.reshape(ishape[0],-1)
251
+ used = self.used.to(inds)
252
+ match = (inds[:,:,None]==used[None,None,...]).long()
253
+ new = match.argmax(-1)
254
+ unknown = match.sum(2)<1
255
+ if self.unknown_index == "random":
256
+ new[unknown]=torch.randint(0,self.re_embed,size=new[unknown].shape).to(device=new.device)
257
+ else:
258
+ new[unknown] = self.unknown_index
259
+ return new.reshape(ishape)
260
+
261
+ def unmap_to_all(self, inds):
262
+ ishape = inds.shape
263
+ assert len(ishape)>1
264
+ inds = inds.reshape(ishape[0],-1)
265
+ used = self.used.to(inds)
266
+ if self.re_embed > self.used.shape[0]: # extra token
267
+ inds[inds>=self.used.shape[0]] = 0 # simply set to zero
268
+ back=torch.gather(used[None,:][inds.shape[0]*[0],:], 1, inds)
269
+ return back.reshape(ishape)
270
+
271
+ def forward(self, z, temp=None, rescale_logits=False, return_logits=False):
272
+ assert temp is None or temp==1.0, "Only for interface compatible with Gumbel"
273
+ assert rescale_logits==False, "Only for interface compatible with Gumbel"
274
+ assert return_logits==False, "Only for interface compatible with Gumbel"
275
+ # reshape z -> (batch, height, width, channel) and flatten
276
+ z = rearrange(z, 'b c h w -> b h w c').contiguous()
277
+ z_flattened = z.view(-1, self.e_dim)
278
+ # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
279
+
280
+ d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \
281
+ torch.sum(self.embedding.weight**2, dim=1) - 2 * \
282
+ torch.einsum('bd,dn->bn', z_flattened, rearrange(self.embedding.weight, 'n d -> d n'))
283
+
284
+ min_encoding_indices = torch.argmin(d, dim=1)
285
+ z_q = self.embedding(min_encoding_indices).view(z.shape)
286
+ perplexity = None
287
+ min_encodings = None
288
+
289
+ # compute loss for embedding
290
+ if not self.legacy:
291
+ loss = self.beta * torch.mean((z_q.detach()-z)**2) + \
292
+ torch.mean((z_q - z.detach()) ** 2)
293
+ else:
294
+ loss = torch.mean((z_q.detach()-z)**2) + self.beta * \
295
+ torch.mean((z_q - z.detach()) ** 2)
296
+
297
+ # preserve gradients
298
+ z_q = z + (z_q - z).detach()
299
+
300
+ # reshape back to match original input shape
301
+ z_q = rearrange(z_q, 'b h w c -> b c h w').contiguous()
302
+
303
+ if self.remap is not None:
304
+ min_encoding_indices = min_encoding_indices.reshape(z.shape[0],-1) # add batch axis
305
+ min_encoding_indices = self.remap_to_used(min_encoding_indices)
306
+ min_encoding_indices = min_encoding_indices.reshape(-1,1) # flatten
307
+
308
+ if self.sane_index_shape:
309
+ min_encoding_indices = min_encoding_indices.reshape(
310
+ z_q.shape[0], z_q.shape[2], z_q.shape[3])
311
+
312
+ return z_q, loss, (perplexity, min_encodings, min_encoding_indices)
313
+
314
+ def get_codebook_entry(self, indices, shape):
315
+ # shape specifying (batch, height, width, channel)
316
+ if self.remap is not None:
317
+ indices = indices.reshape(shape[0],-1) # add batch axis
318
+ indices = self.unmap_to_all(indices)
319
+ indices = indices.reshape(-1) # flatten again
320
+
321
+ # get quantized latent vectors
322
+ z_q = self.embedding(indices)
323
+
324
+ if shape is not None:
325
+ z_q = z_q.view(shape)
326
+ # reshape back to match original input shape
327
+ z_q = z_q.permute(0, 3, 1, 2).contiguous()
328
+
329
+ return z_q
330
+
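The legacy flag in VectorQuantizer2 only changes which of the two loss terms the beta weight multiplies. A small self-contained sketch of the two variants; the function and variable names here are invented for illustration:

import torch

def vq_losses(z, z_q, beta=0.25):
    # codebook term pulls the embeddings toward the (detached) encoder output;
    # commitment term pulls the encoder output toward the (detached) embeddings
    codebook = torch.mean((z_q - z.detach()) ** 2)
    commitment = torch.mean((z_q.detach() - z) ** 2)
    legacy_loss = commitment + beta * codebook  # legacy=True (default): beta on the codebook term
    paper_loss = beta * commitment + codebook   # legacy=False: beta on the commitment term, as in the VQ-VAE objective
    return legacy_loss, paper_loss

z, z_q = torch.randn(2, 8, 4, 4), torch.randn(2, 8, 4, 4)
print(vq_losses(z, z_q))

In both cases the straight-through assignment z + (z_q - z).detach() passes decoder gradients through the quantization step unchanged.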
331
+ class EmbeddingEMA(nn.Module):
332
+ def __init__(self, num_tokens, codebook_dim, decay=0.99, eps=1e-5):
333
+ super().__init__()
334
+ self.decay = decay
335
+ self.eps = eps
336
+ weight = torch.randn(num_tokens, codebook_dim)
337
+ self.weight = nn.Parameter(weight, requires_grad = False)
338
+ self.cluster_size = nn.Parameter(torch.zeros(num_tokens), requires_grad = False)
339
+ self.embed_avg = nn.Parameter(weight.clone(), requires_grad = False)
340
+ self.update = True
341
+
342
+ def forward(self, embed_id):
343
+ return F.embedding(embed_id, self.weight)
344
+
345
+ def cluster_size_ema_update(self, new_cluster_size):
346
+ self.cluster_size.data.mul_(self.decay).add_(new_cluster_size, alpha=1 - self.decay)
347
+
348
+ def embed_avg_ema_update(self, new_embed_avg):
349
+ self.embed_avg.data.mul_(self.decay).add_(new_embed_avg, alpha=1 - self.decay)
350
+
351
+ def weight_update(self, num_tokens):
352
+ n = self.cluster_size.sum()
353
+ smoothed_cluster_size = (
354
+ (self.cluster_size + self.eps) / (n + num_tokens * self.eps) * n
355
+ )
356
+ #normalize embedding average with smoothed cluster size
357
+ embed_normalized = self.embed_avg / smoothed_cluster_size.unsqueeze(1)
358
+ self.weight.data.copy_(embed_normalized)
359
+
360
+
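EmbeddingEMA maintains the codebook with exponential moving averages rather than gradient updates. A tiny numeric sketch of one update step, with made-up counts, showing why the Laplace smoothing in weight_update matters:

import torch

decay, eps, num_tokens = 0.99, 1e-5, 4
cluster_size = torch.tensor([10., 0., 5., 1.])  # running EMA of how often each code was selected
new_counts = torch.tensor([2., 0., 1., 0.])     # selection counts from the current batch

# exponential moving average of per-code usage, as in cluster_size_ema_update
cluster_size = cluster_size * decay + new_counts * (1 - decay)

# Laplace smoothing, as in weight_update: keeps never-used codes from producing a zero denominator
n = cluster_size.sum()
smoothed = (cluster_size + eps) / (n + num_tokens * eps) * n
print(smoothed)  # strictly positive, even for the code that was never selected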
361
+ class EMAVectorQuantizer(nn.Module):
362
+ def __init__(self, n_embed, embedding_dim, beta, decay=0.99, eps=1e-5,
363
+ remap=None, unknown_index="random"):
364
+ super().__init__()
365
+ self.codebook_dim = embedding_dim
366
+ self.num_tokens = n_embed
367
+ self.beta = beta
368
+ self.embedding = EmbeddingEMA(self.num_tokens, self.codebook_dim, decay, eps)
369
+
370
+ self.remap = remap
371
+ if self.remap is not None:
372
+ self.register_buffer("used", torch.tensor(np.load(self.remap)))
373
+ self.re_embed = self.used.shape[0]
374
+ self.unknown_index = unknown_index # "random" or "extra" or integer
375
+ if self.unknown_index == "extra":
376
+ self.unknown_index = self.re_embed
377
+ self.re_embed = self.re_embed+1
378
+ print(f"Remapping {self.n_embed} indices to {self.re_embed} indices. "
379
+ f"Using {self.unknown_index} for unknown indices.")
380
+ else:
381
+ self.re_embed = n_embed
382
+
383
+ def remap_to_used(self, inds):
384
+ ishape = inds.shape
385
+ assert len(ishape)>1
386
+ inds = inds.reshape(ishape[0],-1)
387
+ used = self.used.to(inds)
388
+ match = (inds[:,:,None]==used[None,None,...]).long()
389
+ new = match.argmax(-1)
390
+ unknown = match.sum(2)<1
391
+ if self.unknown_index == "random":
392
+ new[unknown]=torch.randint(0,self.re_embed,size=new[unknown].shape).to(device=new.device)
393
+ else:
394
+ new[unknown] = self.unknown_index
395
+ return new.reshape(ishape)
396
+
397
+ def unmap_to_all(self, inds):
398
+ ishape = inds.shape
399
+ assert len(ishape)>1
400
+ inds = inds.reshape(ishape[0],-1)
401
+ used = self.used.to(inds)
402
+ if self.re_embed > self.used.shape[0]: # extra token
403
+ inds[inds>=self.used.shape[0]] = 0 # simply set to zero
404
+ back=torch.gather(used[None,:][inds.shape[0]*[0],:], 1, inds)
405
+ return back.reshape(ishape)
406
+
407
+ def forward(self, z):
408
+ # reshape z -> (batch, height, width, channel) and flatten
409
+ #z, 'b c h w -> b h w c'
410
+ z = rearrange(z, 'b c h w -> b h w c')
411
+ z_flattened = z.reshape(-1, self.codebook_dim)
412
+
413
+ # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
414
+ d = z_flattened.pow(2).sum(dim=1, keepdim=True) + \
415
+ self.embedding.weight.pow(2).sum(dim=1) - 2 * \
416
+ torch.einsum('bd,nd->bn', z_flattened, self.embedding.weight) # 'n d -> d n'
417
+
418
+
419
+ encoding_indices = torch.argmin(d, dim=1)
420
+
421
+ z_q = self.embedding(encoding_indices).view(z.shape)
422
+ encodings = F.one_hot(encoding_indices, self.num_tokens).type(z.dtype)
423
+ avg_probs = torch.mean(encodings, dim=0)
424
+ perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))
425
+
426
+ if self.training and self.embedding.update:
427
+ #EMA cluster size
428
+ encodings_sum = encodings.sum(0)
429
+ self.embedding.cluster_size_ema_update(encodings_sum)
430
+ #EMA embedding average
431
+ embed_sum = encodings.transpose(0,1) @ z_flattened
432
+ self.embedding.embed_avg_ema_update(embed_sum)
433
+ #normalize embed_avg and update weight
434
+ self.embedding.weight_update(self.num_tokens)
435
+
436
+ # compute loss for embedding
437
+ loss = self.beta * F.mse_loss(z_q.detach(), z)
438
+
439
+ # preserve gradients
440
+ z_q = z + (z_q - z).detach()
441
+
442
+ # reshape back to match original input shape
443
+ #z_q, 'b h w c -> b c h w'
444
+ z_q = rearrange(z_q, 'b h w c -> b c h w')
445
+ return z_q, loss, (perplexity, encodings, encoding_indices)
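Finally, a minimal usage sketch of the EMAVectorQuantizer defined above; shapes and hyperparameters are again illustrative:

import torch
from taming.modules.vqvae.quantize import EMAVectorQuantizer

# hypothetical encoder output: 2 maps, 64 channels, 8x8 spatial grid
z = torch.randn(2, 64, 8, 8)

quantizer = EMAVectorQuantizer(n_embed=512, embedding_dim=64, beta=0.25)
z_q, loss, (perplexity, encodings, indices) = quantizer(z)

print(z_q.shape)    # torch.Size([2, 64, 8, 8]) -- straight-through quantized latents
print(loss)         # commitment loss only; the codebook itself is updated via EMA, not by gradients
print(perplexity)   # effective number of distinct codes used for this batch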
taming-transformers/taming/util.py ADDED
@@ -0,0 +1,157 @@
1
+ import os, hashlib
2
+ import requests
3
+ from tqdm import tqdm
4
+
5
+ URL_MAP = {
6
+ "vgg_lpips": "https://heibox.uni-heidelberg.de/f/607503859c864bc1b30b/?dl=1"
7
+ }
8
+
9
+ CKPT_MAP = {
10
+ "vgg_lpips": "vgg.pth"
11
+ }
12
+
13
+ MD5_MAP = {
14
+ "vgg_lpips": "d507d7349b931f0638a25a48a722f98a"
15
+ }
16
+
17
+
18
+ def download(url, local_path, chunk_size=1024):
19
+ os.makedirs(os.path.split(local_path)[0], exist_ok=True)
20
+ with requests.get(url, stream=True) as r:
21
+ total_size = int(r.headers.get("content-length", 0))
22
+ with tqdm(total=total_size, unit="B", unit_scale=True) as pbar:
23
+ with open(local_path, "wb") as f:
24
+ for data in r.iter_content(chunk_size=chunk_size):
25
+ if data:
26
+ f.write(data)
27
+ pbar.update(len(data))
28
+
29
+
30
+ def md5_hash(path):
31
+ with open(path, "rb") as f:
32
+ content = f.read()
33
+ return hashlib.md5(content).hexdigest()
34
+
35
+
36
+ def get_ckpt_path(name, root, check=False):
37
+ assert name in URL_MAP
38
+ path = os.path.join(root, CKPT_MAP[name])
39
+ if not os.path.exists(path) or (check and not md5_hash(path) == MD5_MAP[name]):
40
+ print("Downloading {} model from {} to {}".format(name, URL_MAP[name], path))
41
+ download(URL_MAP[name], path)
42
+ md5 = md5_hash(path)
43
+ assert md5 == MD5_MAP[name], md5
44
+ return path
45
+
46
+
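For reference, a short sketch of how the checkpoint helper above is typically invoked; the cache directory is an arbitrary example path:

from taming.util import get_ckpt_path

# downloads vgg.pth into the given directory on first use and verifies its MD5;
# later calls simply return the cached path
path = get_ckpt_path("vgg_lpips", "cache/lpips", check=True)
print(path)  # cache/lpips/vgg.pth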
47
+ class KeyNotFoundError(Exception):
48
+ def __init__(self, cause, keys=None, visited=None):
49
+ self.cause = cause
50
+ self.keys = keys
51
+ self.visited = visited
52
+ messages = list()
53
+ if keys is not None:
54
+ messages.append("Key not found: {}".format(keys))
55
+ if visited is not None:
56
+ messages.append("Visited: {}".format(visited))
57
+ messages.append("Cause:\n{}".format(cause))
58
+ message = "\n".join(messages)
59
+ super().__init__(message)
60
+
61
+
62
+ def retrieve(
63
+ list_or_dict, key, splitval="/", default=None, expand=True, pass_success=False
64
+ ):
65
+ """Given a nested list or dict return the desired value at key expanding
66
+ callable nodes if necessary and :attr:`expand` is ``True``. The expansion
67
+ is done in-place.
68
+
69
+ Parameters
70
+ ----------
71
+ list_or_dict : list or dict
72
+ Possibly nested list or dictionary.
73
+ key : str
74
+ key/to/value, path like string describing all keys necessary to
75
+ consider to get to the desired value. List indices can also be
76
+ passed here.
77
+ splitval : str
78
+ String that defines the delimiter between keys of the
79
+ different depth levels in `key`.
80
+ default : obj
81
+ Value returned if :attr:`key` is not found.
82
+ expand : bool
83
+ Whether to expand callable nodes on the path or not.
84
+
85
+ Returns
86
+ -------
87
+ The desired value or if :attr:`default` is not ``None`` and the
88
+ :attr:`key` is not found returns ``default``.
89
+
90
+ Raises
91
+ ------
92
+ Exception if ``key`` not in ``list_or_dict`` and :attr:`default` is
93
+ ``None``.
94
+ """
95
+
96
+ keys = key.split(splitval)
97
+
98
+ success = True
99
+ try:
100
+ visited = []
101
+ parent = None
102
+ last_key = None
103
+ for key in keys:
104
+ if callable(list_or_dict):
105
+ if not expand:
106
+ raise KeyNotFoundError(
107
+ ValueError(
108
+ "Trying to get past callable node with expand=False."
109
+ ),
110
+ keys=keys,
111
+ visited=visited,
112
+ )
113
+ list_or_dict = list_or_dict()
114
+ parent[last_key] = list_or_dict
115
+
116
+ last_key = key
117
+ parent = list_or_dict
118
+
119
+ try:
120
+ if isinstance(list_or_dict, dict):
121
+ list_or_dict = list_or_dict[key]
122
+ else:
123
+ list_or_dict = list_or_dict[int(key)]
124
+ except (KeyError, IndexError, ValueError) as e:
125
+ raise KeyNotFoundError(e, keys=keys, visited=visited)
126
+
127
+ visited += [key]
128
+ # final expansion of retrieved value
129
+ if expand and callable(list_or_dict):
130
+ list_or_dict = list_or_dict()
131
+ parent[last_key] = list_or_dict
132
+ except KeyNotFoundError as e:
133
+ if default is None:
134
+ raise e
135
+ else:
136
+ list_or_dict = default
137
+ success = False
138
+
139
+ if not pass_success:
140
+ return list_or_dict
141
+ else:
142
+ return list_or_dict, success
143
+
144
+
145
+ if __name__ == "__main__":
146
+ config = {"keya": "a",
147
+ "keyb": "b",
148
+ "keyc":
149
+ {"cc1": 1,
150
+ "cc2": 2,
151
+ }
152
+ }
153
+ from omegaconf import OmegaConf
154
+ config = OmegaConf.create(config)
155
+ print(config)
156
+ retrieve(config, "keya")
157
+
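A short sketch of how retrieve resolves path-like keys, including the default fallback; the config dictionary is invented for illustration:

from taming.util import retrieve

config = {"model": {"params": {"embed_dim": 256, "layers": [4, 8, 12]}}}

print(retrieve(config, "model/params/embed_dim"))             # 256
print(retrieve(config, "model/params/layers/1"))              # 8 -- list indices work too
print(retrieve(config, "model/params/dropout", default=0.1))  # 0.1 -- missing keys fall back to default
value, ok = retrieve(config, "model/params/missing", default=0, pass_success=True)
print(value, ok)  # 0 False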
taming-transformers/taming_transformers.egg-info/PKG-INFO ADDED
@@ -0,0 +1,10 @@
1
+ Metadata-Version: 2.1
2
+ Name: taming-transformers
3
+ Version: 0.0.1
4
+ Summary: Taming Transformers for High-Resolution Image Synthesis
5
+ Home-page: UNKNOWN
6
+ License: UNKNOWN
7
+ Platform: UNKNOWN
8
+
9
+ UNKNOWN
10
+
taming-transformers/taming_transformers.egg-info/SOURCES.txt ADDED
@@ -0,0 +1,7 @@
1
+ README.md
2
+ setup.py
3
+ taming_transformers.egg-info/PKG-INFO
4
+ taming_transformers.egg-info/SOURCES.txt
5
+ taming_transformers.egg-info/dependency_links.txt
6
+ taming_transformers.egg-info/requires.txt
7
+ taming_transformers.egg-info/top_level.txt
taming-transformers/taming_transformers.egg-info/dependency_links.txt ADDED
@@ -0,0 +1 @@
1
+
taming-transformers/taming_transformers.egg-info/requires.txt ADDED
@@ -0,0 +1,3 @@
1
+ torch
2
+ numpy
3
+ tqdm
taming-transformers/taming_transformers.egg-info/top_level.txt ADDED
@@ -0,0 +1 @@
1
+