Chao Xu
commited on
Commit
•
169a228
1
Parent(s):
53ab577
add taming
Browse files- .gitignore +2 -1
- taming-transformers/.gitignore +2 -0
- taming-transformers/License.txt +19 -0
- taming-transformers/README.md +410 -0
- taming-transformers/configs/coco_cond_stage.yaml +49 -0
- taming-transformers/configs/coco_scene_images_transformer.yaml +80 -0
- taming-transformers/configs/custom_vqgan.yaml +43 -0
- taming-transformers/configs/drin_transformer.yaml +77 -0
- taming-transformers/configs/faceshq_transformer.yaml +61 -0
- taming-transformers/configs/faceshq_vqgan.yaml +42 -0
- taming-transformers/configs/imagenet_vqgan.yaml +42 -0
- taming-transformers/configs/imagenetdepth_vqgan.yaml +41 -0
- taming-transformers/configs/open_images_scene_images_transformer.yaml +86 -0
- taming-transformers/configs/sflckr_cond_stage.yaml +43 -0
- taming-transformers/environment.yaml +25 -0
- taming-transformers/main.py +585 -0
- taming-transformers/scripts/extract_depth.py +112 -0
- taming-transformers/scripts/extract_segmentation.py +130 -0
- taming-transformers/scripts/extract_submodel.py +17 -0
- taming-transformers/scripts/make_samples.py +292 -0
- taming-transformers/scripts/make_scene_samples.py +198 -0
- taming-transformers/scripts/sample_conditional.py +355 -0
- taming-transformers/scripts/sample_fast.py +260 -0
- taming-transformers/setup.py +13 -0
- taming-transformers/taming/lr_scheduler.py +34 -0
- taming-transformers/taming/models/cond_transformer.py +352 -0
- taming-transformers/taming/models/dummy_cond_stage.py +22 -0
- taming-transformers/taming/models/vqgan.py +404 -0
- taming-transformers/taming/modules/diffusionmodules/model.py +776 -0
- taming-transformers/taming/modules/discriminator/model.py +67 -0
- taming-transformers/taming/modules/losses/__init__.py +2 -0
- taming-transformers/taming/modules/losses/lpips.py +123 -0
- taming-transformers/taming/modules/losses/segmentation.py +22 -0
- taming-transformers/taming/modules/losses/vqperceptual.py +136 -0
- taming-transformers/taming/modules/misc/coord.py +31 -0
- taming-transformers/taming/modules/transformer/mingpt.py +415 -0
- taming-transformers/taming/modules/transformer/permuter.py +248 -0
- taming-transformers/taming/modules/util.py +130 -0
- taming-transformers/taming/modules/vqvae/quantize.py +445 -0
- taming-transformers/taming/util.py +157 -0
- taming-transformers/taming_transformers.egg-info/PKG-INFO +10 -0
- taming-transformers/taming_transformers.egg-info/SOURCES.txt +7 -0
- taming-transformers/taming_transformers.egg-info/dependency_links.txt +1 -0
- taming-transformers/taming_transformers.egg-info/requires.txt +3 -0
- taming-transformers/taming_transformers.egg-info/top_level.txt +1 -0
.gitignore
CHANGED
@@ -1,2 +1,3 @@
|
|
1 |
__pycache__/
|
2 |
-
*.DS_Store
|
|
|
|
1 |
__pycache__/
|
2 |
+
*.DS_Store
|
3 |
+
*.ipynb
|
taming-transformers/.gitignore
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
assets/
|
2 |
+
data/
|
taming-transformers/License.txt
ADDED
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Copyright (c) 2020 Patrick Esser and Robin Rombach and Björn Ommer
|
2 |
+
|
3 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
4 |
+
of this software and associated documentation files (the "Software"), to deal
|
5 |
+
in the Software without restriction, including without limitation the rights
|
6 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
7 |
+
copies of the Software, and to permit persons to whom the Software is
|
8 |
+
furnished to do so, subject to the following conditions:
|
9 |
+
|
10 |
+
The above copyright notice and this permission notice shall be included in all
|
11 |
+
copies or substantial portions of the Software.
|
12 |
+
|
13 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
|
14 |
+
EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
15 |
+
MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
16 |
+
IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
|
17 |
+
DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
|
18 |
+
OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE
|
19 |
+
OR OTHER DEALINGS IN THE SOFTWARE./
|
taming-transformers/README.md
ADDED
@@ -0,0 +1,410 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Taming Transformers for High-Resolution Image Synthesis
|
2 |
+
##### CVPR 2021 (Oral)
|
3 |
+
![teaser](assets/mountain.jpeg)
|
4 |
+
|
5 |
+
[**Taming Transformers for High-Resolution Image Synthesis**](https://compvis.github.io/taming-transformers/)<br/>
|
6 |
+
[Patrick Esser](https://github.com/pesser)\*,
|
7 |
+
[Robin Rombach](https://github.com/rromb)\*,
|
8 |
+
[Björn Ommer](https://hci.iwr.uni-heidelberg.de/Staff/bommer)<br/>
|
9 |
+
\* equal contribution
|
10 |
+
|
11 |
+
**tl;dr** We combine the efficiancy of convolutional approaches with the expressivity of transformers by introducing a convolutional VQGAN, which learns a codebook of context-rich visual parts, whose composition is modeled with an autoregressive transformer.
|
12 |
+
|
13 |
+
![teaser](assets/teaser.png)
|
14 |
+
[arXiv](https://arxiv.org/abs/2012.09841) | [BibTeX](#bibtex) | [Project Page](https://compvis.github.io/taming-transformers/)
|
15 |
+
|
16 |
+
|
17 |
+
### News
|
18 |
+
#### 2022
|
19 |
+
- More pretrained VQGANs (e.g. a f8-model with only 256 codebook entries) are available in our new work on [Latent Diffusion Models](https://github.com/CompVis/latent-diffusion).
|
20 |
+
- Added scene synthesis models as proposed in the paper [High-Resolution Complex Scene Synthesis with Transformers](https://arxiv.org/abs/2105.06458), see [this section](#scene-image-synthesis).
|
21 |
+
#### 2021
|
22 |
+
- Thanks to [rom1504](https://github.com/rom1504) it is now easy to [train a VQGAN on your own datasets](#training-on-custom-data).
|
23 |
+
- Included a bugfix for the quantizer. For backward compatibility it is
|
24 |
+
disabled by default (which corresponds to always training with `beta=1.0`).
|
25 |
+
Use `legacy=False` in the quantizer config to enable it.
|
26 |
+
Thanks [richcmwang](https://github.com/richcmwang) and [wcshin-git](https://github.com/wcshin-git)!
|
27 |
+
- Our paper received an update: See https://arxiv.org/abs/2012.09841v3 and the corresponding changelog.
|
28 |
+
- Added a pretrained, [1.4B transformer model](https://k00.fr/s511rwcv) trained for class-conditional ImageNet synthesis, which obtains state-of-the-art FID scores among autoregressive approaches and outperforms BigGAN.
|
29 |
+
- Added pretrained, unconditional models on [FFHQ](https://k00.fr/yndvfu95) and [CelebA-HQ](https://k00.fr/2xkmielf).
|
30 |
+
- Added accelerated sampling via caching of keys/values in the self-attention operation, used in `scripts/sample_fast.py`.
|
31 |
+
- Added a checkpoint of a [VQGAN](https://heibox.uni-heidelberg.de/d/2e5662443a6b4307b470/) trained with f8 compression and Gumbel-Quantization.
|
32 |
+
See also our updated [reconstruction notebook](https://colab.research.google.com/github/CompVis/taming-transformers/blob/master/scripts/reconstruction_usage.ipynb).
|
33 |
+
- We added a [colab notebook](https://colab.research.google.com/github/CompVis/taming-transformers/blob/master/scripts/reconstruction_usage.ipynb) which compares two VQGANs and OpenAI's [DALL-E](https://github.com/openai/DALL-E). See also [this section](#more-resources).
|
34 |
+
- We now include an overview of pretrained models in [Tab.1](#overview-of-pretrained-models). We added models for [COCO](#coco) and [ADE20k](#ade20k).
|
35 |
+
- The streamlit demo now supports image completions.
|
36 |
+
- We now include a couple of examples from the D-RIN dataset so you can run the
|
37 |
+
[D-RIN demo](#d-rin) without preparing the dataset first.
|
38 |
+
- You can now jump right into sampling with our [Colab quickstart notebook](https://colab.research.google.com/github/CompVis/taming-transformers/blob/master/scripts/taming-transformers.ipynb).
|
39 |
+
|
40 |
+
## Requirements
|
41 |
+
A suitable [conda](https://conda.io/) environment named `taming` can be created
|
42 |
+
and activated with:
|
43 |
+
|
44 |
+
```
|
45 |
+
conda env create -f environment.yaml
|
46 |
+
conda activate taming
|
47 |
+
```
|
48 |
+
## Overview of pretrained models
|
49 |
+
The following table provides an overview of all models that are currently available.
|
50 |
+
FID scores were evaluated using [torch-fidelity](https://github.com/toshas/torch-fidelity).
|
51 |
+
For reference, we also include a link to the recently released autoencoder of the [DALL-E](https://github.com/openai/DALL-E) model.
|
52 |
+
See the corresponding [colab
|
53 |
+
notebook](https://colab.research.google.com/github/CompVis/taming-transformers/blob/master/scripts/reconstruction_usage.ipynb)
|
54 |
+
for a comparison and discussion of reconstruction capabilities.
|
55 |
+
|
56 |
+
| Dataset | FID vs train | FID vs val | Link | Samples (256x256) | Comments
|
57 |
+
| ------------- | ------------- | ------------- |------------- | ------------- |------------- |
|
58 |
+
| FFHQ (f=16) | 9.6 | -- | [ffhq_transformer](https://k00.fr/yndvfu95) | [ffhq_samples](https://k00.fr/j626x093) |
|
59 |
+
| CelebA-HQ (f=16) | 10.2 | -- | [celebahq_transformer](https://k00.fr/2xkmielf) | [celebahq_samples](https://k00.fr/j626x093) |
|
60 |
+
| ADE20K (f=16) | -- | 35.5 | [ade20k_transformer](https://k00.fr/ot46cksa) | [ade20k_samples.zip](https://heibox.uni-heidelberg.de/f/70bb78cbaf844501b8fb/) [2k] | evaluated on val split (2k images)
|
61 |
+
| COCO-Stuff (f=16) | -- | 20.4 | [coco_transformer](https://k00.fr/2zz6i2ce) | [coco_samples.zip](https://heibox.uni-heidelberg.de/f/a395a9be612f4a7a8054/) [5k] | evaluated on val split (5k images)
|
62 |
+
| ImageNet (cIN) (f=16) | 15.98/15.78/6.59/5.88/5.20 | -- | [cin_transformer](https://k00.fr/s511rwcv) | [cin_samples](https://k00.fr/j626x093) | different decoding hyperparameters |
|
63 |
+
| | | | || |
|
64 |
+
| FacesHQ (f=16) | -- | -- | [faceshq_transformer](https://k00.fr/qqfl2do8)
|
65 |
+
| S-FLCKR (f=16) | -- | -- | [sflckr](https://heibox.uni-heidelberg.de/d/73487ab6e5314cb5adba/)
|
66 |
+
| D-RIN (f=16) | -- | -- | [drin_transformer](https://k00.fr/39jcugc5)
|
67 |
+
| | | | | || |
|
68 |
+
| VQGAN ImageNet (f=16), 1024 | 10.54 | 7.94 | [vqgan_imagenet_f16_1024](https://heibox.uni-heidelberg.de/d/8088892a516d4e3baf92/) | [reconstructions](https://k00.fr/j626x093) | Reconstruction-FIDs.
|
69 |
+
| VQGAN ImageNet (f=16), 16384 | 7.41 | 4.98 |[vqgan_imagenet_f16_16384](https://heibox.uni-heidelberg.de/d/a7530b09fed84f80a887/) | [reconstructions](https://k00.fr/j626x093) | Reconstruction-FIDs.
|
70 |
+
| VQGAN OpenImages (f=8), 256 | -- | 1.49 |https://ommer-lab.com/files/latent-diffusion/vq-f8-n256.zip | --- | Reconstruction-FIDs. Available via [latent diffusion](https://github.com/CompVis/latent-diffusion).
|
71 |
+
| VQGAN OpenImages (f=8), 16384 | -- | 1.14 |https://ommer-lab.com/files/latent-diffusion/vq-f8.zip | --- | Reconstruction-FIDs. Available via [latent diffusion](https://github.com/CompVis/latent-diffusion)
|
72 |
+
| VQGAN OpenImages (f=8), 8192, GumbelQuantization | 3.24 | 1.49 |[vqgan_gumbel_f8](https://heibox.uni-heidelberg.de/d/2e5662443a6b4307b470/) | --- | Reconstruction-FIDs.
|
73 |
+
| | | | | || |
|
74 |
+
| DALL-E dVAE (f=8), 8192, GumbelQuantization | 33.88 | 32.01 | https://github.com/openai/DALL-E | [reconstructions](https://k00.fr/j626x093) | Reconstruction-FIDs.
|
75 |
+
|
76 |
+
|
77 |
+
## Running pretrained models
|
78 |
+
|
79 |
+
The commands below will start a streamlit demo which supports sampling at
|
80 |
+
different resolutions and image completions. To run a non-interactive version
|
81 |
+
of the sampling process, replace `streamlit run scripts/sample_conditional.py --`
|
82 |
+
by `python scripts/make_samples.py --outdir <path_to_write_samples_to>` and
|
83 |
+
keep the remaining command line arguments.
|
84 |
+
|
85 |
+
To sample from unconditional or class-conditional models,
|
86 |
+
run `python scripts/sample_fast.py -r <path/to/config_and_checkpoint>`.
|
87 |
+
We describe below how to use this script to sample from the ImageNet, FFHQ, and CelebA-HQ models,
|
88 |
+
respectively.
|
89 |
+
|
90 |
+
### S-FLCKR
|
91 |
+
![teaser](assets/sunset_and_ocean.jpg)
|
92 |
+
|
93 |
+
You can also [run this model in a Colab
|
94 |
+
notebook](https://colab.research.google.com/github/CompVis/taming-transformers/blob/master/scripts/taming-transformers.ipynb),
|
95 |
+
which includes all necessary steps to start sampling.
|
96 |
+
|
97 |
+
Download the
|
98 |
+
[2020-11-09T13-31-51_sflckr](https://heibox.uni-heidelberg.de/d/73487ab6e5314cb5adba/)
|
99 |
+
folder and place it into `logs`. Then, run
|
100 |
+
```
|
101 |
+
streamlit run scripts/sample_conditional.py -- -r logs/2020-11-09T13-31-51_sflckr/
|
102 |
+
```
|
103 |
+
|
104 |
+
### ImageNet
|
105 |
+
![teaser](assets/imagenet.png)
|
106 |
+
|
107 |
+
Download the [2021-04-03T19-39-50_cin_transformer](https://k00.fr/s511rwcv)
|
108 |
+
folder and place it into logs. Sampling from the class-conditional ImageNet
|
109 |
+
model does not require any data preparation. To produce 50 samples for each of
|
110 |
+
the 1000 classes of ImageNet, with k=600 for top-k sampling, p=0.92 for nucleus
|
111 |
+
sampling and temperature t=1.0, run
|
112 |
+
|
113 |
+
```
|
114 |
+
python scripts/sample_fast.py -r logs/2021-04-03T19-39-50_cin_transformer/ -n 50 -k 600 -t 1.0 -p 0.92 --batch_size 25
|
115 |
+
```
|
116 |
+
|
117 |
+
To restrict the model to certain classes, provide them via the `--classes` argument, separated by
|
118 |
+
commas. For example, to sample 50 *ostriches*, *border collies* and *whiskey jugs*, run
|
119 |
+
|
120 |
+
```
|
121 |
+
python scripts/sample_fast.py -r logs/2021-04-03T19-39-50_cin_transformer/ -n 50 -k 600 -t 1.0 -p 0.92 --batch_size 25 --classes 9,232,901
|
122 |
+
```
|
123 |
+
We recommended to experiment with the autoregressive decoding parameters (top-k, top-p and temperature) for best results.
|
124 |
+
|
125 |
+
### FFHQ/CelebA-HQ
|
126 |
+
|
127 |
+
Download the [2021-04-23T18-19-01_ffhq_transformer](https://k00.fr/yndvfu95) and
|
128 |
+
[2021-04-23T18-11-19_celebahq_transformer](https://k00.fr/2xkmielf)
|
129 |
+
folders and place them into logs.
|
130 |
+
Again, sampling from these unconditional models does not require any data preparation.
|
131 |
+
To produce 50000 samples, with k=250 for top-k sampling,
|
132 |
+
p=1.0 for nucleus sampling and temperature t=1.0, run
|
133 |
+
|
134 |
+
```
|
135 |
+
python scripts/sample_fast.py -r logs/2021-04-23T18-19-01_ffhq_transformer/
|
136 |
+
```
|
137 |
+
for FFHQ and
|
138 |
+
|
139 |
+
```
|
140 |
+
python scripts/sample_fast.py -r logs/2021-04-23T18-11-19_celebahq_transformer/
|
141 |
+
```
|
142 |
+
to sample from the CelebA-HQ model.
|
143 |
+
For both models it can be advantageous to vary the top-k/top-p parameters for sampling.
|
144 |
+
|
145 |
+
### FacesHQ
|
146 |
+
![teaser](assets/faceshq.jpg)
|
147 |
+
|
148 |
+
Download [2020-11-13T21-41-45_faceshq_transformer](https://k00.fr/qqfl2do8) and
|
149 |
+
place it into `logs`. Follow the data preparation steps for
|
150 |
+
[CelebA-HQ](#celeba-hq) and [FFHQ](#ffhq). Run
|
151 |
+
```
|
152 |
+
streamlit run scripts/sample_conditional.py -- -r logs/2020-11-13T21-41-45_faceshq_transformer/
|
153 |
+
```
|
154 |
+
|
155 |
+
### D-RIN
|
156 |
+
![teaser](assets/drin.jpg)
|
157 |
+
|
158 |
+
Download [2020-11-20T12-54-32_drin_transformer](https://k00.fr/39jcugc5) and
|
159 |
+
place it into `logs`. To run the demo on a couple of example depth maps
|
160 |
+
included in the repository, run
|
161 |
+
|
162 |
+
```
|
163 |
+
streamlit run scripts/sample_conditional.py -- -r logs/2020-11-20T12-54-32_drin_transformer/ --ignore_base_data data="{target: main.DataModuleFromConfig, params: {batch_size: 1, validation: {target: taming.data.imagenet.DRINExamples}}}"
|
164 |
+
```
|
165 |
+
|
166 |
+
To run the demo on the complete validation set, first follow the data preparation steps for
|
167 |
+
[ImageNet](#imagenet) and then run
|
168 |
+
```
|
169 |
+
streamlit run scripts/sample_conditional.py -- -r logs/2020-11-20T12-54-32_drin_transformer/
|
170 |
+
```
|
171 |
+
|
172 |
+
### COCO
|
173 |
+
Download [2021-01-20T16-04-20_coco_transformer](https://k00.fr/2zz6i2ce) and
|
174 |
+
place it into `logs`. To run the demo on a couple of example segmentation maps
|
175 |
+
included in the repository, run
|
176 |
+
|
177 |
+
```
|
178 |
+
streamlit run scripts/sample_conditional.py -- -r logs/2021-01-20T16-04-20_coco_transformer/ --ignore_base_data data="{target: main.DataModuleFromConfig, params: {batch_size: 1, validation: {target: taming.data.coco.Examples}}}"
|
179 |
+
```
|
180 |
+
|
181 |
+
### ADE20k
|
182 |
+
Download [2020-11-20T21-45-44_ade20k_transformer](https://k00.fr/ot46cksa) and
|
183 |
+
place it into `logs`. To run the demo on a couple of example segmentation maps
|
184 |
+
included in the repository, run
|
185 |
+
|
186 |
+
```
|
187 |
+
streamlit run scripts/sample_conditional.py -- -r logs/2020-11-20T21-45-44_ade20k_transformer/ --ignore_base_data data="{target: main.DataModuleFromConfig, params: {batch_size: 1, validation: {target: taming.data.ade20k.Examples}}}"
|
188 |
+
```
|
189 |
+
|
190 |
+
## Scene Image Synthesis
|
191 |
+
![teaser](assets/scene_images_samples.svg)
|
192 |
+
Scene image generation based on bounding box conditionals as done in our CVPR2021 AI4CC workshop paper [High-Resolution Complex Scene Synthesis with Transformers](https://arxiv.org/abs/2105.06458) (see talk on [workshop page](https://visual.cs.brown.edu/workshops/aicc2021/#awards)). Supporting the datasets COCO and Open Images.
|
193 |
+
|
194 |
+
### Training
|
195 |
+
Download first-stage models [COCO-8k-VQGAN](https://heibox.uni-heidelberg.de/f/78dea9589974474c97c1/) for COCO or [COCO/Open-Images-8k-VQGAN](https://heibox.uni-heidelberg.de/f/461d9a9f4fcf48ab84f4/) for Open Images.
|
196 |
+
Change `ckpt_path` in `data/coco_scene_images_transformer.yaml` and `data/open_images_scene_images_transformer.yaml` to point to the downloaded first-stage models.
|
197 |
+
Download the full COCO/OI datasets and adapt `data_path` in the same files, unless working with the 100 files provided for training and validation suits your needs already.
|
198 |
+
|
199 |
+
Code can be run with
|
200 |
+
`python main.py --base configs/coco_scene_images_transformer.yaml -t True --gpus 0,`
|
201 |
+
or
|
202 |
+
`python main.py --base configs/open_images_scene_images_transformer.yaml -t True --gpus 0,`
|
203 |
+
|
204 |
+
### Sampling
|
205 |
+
Train a model as described above or download a pre-trained model:
|
206 |
+
- [Open Images 1 billion parameter model](https://drive.google.com/file/d/1FEK-Z7hyWJBvFWQF50pzSK9y1W_CJEig/view?usp=sharing) available that trained 100 epochs. On 256x256 pixels, FID 41.48±0.21, SceneFID 14.60±0.15, Inception Score 18.47±0.27. The model was trained with 2d crops of images and is thus well-prepared for the task of generating high-resolution images, e.g. 512x512.
|
207 |
+
- [Open Images distilled version of the above model with 125 million parameters](https://drive.google.com/file/d/1xf89g0mc78J3d8Bx5YhbK4tNRNlOoYaO) allows for sampling on smaller GPUs (4 GB is enough for sampling 256x256 px images). Model was trained for 60 epochs with 10% soft loss, 90% hard loss. On 256x256 pixels, FID 43.07±0.40, SceneFID 15.93±0.19, Inception Score 17.23±0.11.
|
208 |
+
- [COCO 30 epochs](https://heibox.uni-heidelberg.de/f/0d0b2594e9074c7e9a33/)
|
209 |
+
- [COCO 60 epochs](https://drive.google.com/file/d/1bInd49g2YulTJBjU32Awyt5qnzxxG5U9/) (find model statistics for both COCO versions in `assets/coco_scene_images_training.svg`)
|
210 |
+
|
211 |
+
When downloading a pre-trained model, remember to change `ckpt_path` in `configs/*project.yaml` to point to your downloaded first-stage model (see ->Training).
|
212 |
+
|
213 |
+
Scene image generation can be run with
|
214 |
+
`python scripts/make_scene_samples.py --outdir=/some/outdir -r /path/to/pretrained/model --resolution=512,512`
|
215 |
+
|
216 |
+
|
217 |
+
## Training on custom data
|
218 |
+
|
219 |
+
Training on your own dataset can be beneficial to get better tokens and hence better images for your domain.
|
220 |
+
Those are the steps to follow to make this work:
|
221 |
+
1. install the repo with `conda env create -f environment.yaml`, `conda activate taming` and `pip install -e .`
|
222 |
+
1. put your .jpg files in a folder `your_folder`
|
223 |
+
2. create 2 text files a `xx_train.txt` and `xx_test.txt` that point to the files in your training and test set respectively (for example `find $(pwd)/your_folder -name "*.jpg" > train.txt`)
|
224 |
+
3. adapt `configs/custom_vqgan.yaml` to point to these 2 files
|
225 |
+
4. run `python main.py --base configs/custom_vqgan.yaml -t True --gpus 0,1` to
|
226 |
+
train on two GPUs. Use `--gpus 0,` (with a trailing comma) to train on a single GPU.
|
227 |
+
|
228 |
+
## Data Preparation
|
229 |
+
|
230 |
+
### ImageNet
|
231 |
+
The code will try to download (through [Academic
|
232 |
+
Torrents](http://academictorrents.com/)) and prepare ImageNet the first time it
|
233 |
+
is used. However, since ImageNet is quite large, this requires a lot of disk
|
234 |
+
space and time. If you already have ImageNet on your disk, you can speed things
|
235 |
+
up by putting the data into
|
236 |
+
`${XDG_CACHE}/autoencoders/data/ILSVRC2012_{split}/data/` (which defaults to
|
237 |
+
`~/.cache/autoencoders/data/ILSVRC2012_{split}/data/`), where `{split}` is one
|
238 |
+
of `train`/`validation`. It should have the following structure:
|
239 |
+
|
240 |
+
```
|
241 |
+
${XDG_CACHE}/autoencoders/data/ILSVRC2012_{split}/data/
|
242 |
+
├── n01440764
|
243 |
+
│ ├── n01440764_10026.JPEG
|
244 |
+
│ ├── n01440764_10027.JPEG
|
245 |
+
│ ├── ...
|
246 |
+
├── n01443537
|
247 |
+
│ ├── n01443537_10007.JPEG
|
248 |
+
│ ├── n01443537_10014.JPEG
|
249 |
+
│ ├── ...
|
250 |
+
├── ...
|
251 |
+
```
|
252 |
+
|
253 |
+
If you haven't extracted the data, you can also place
|
254 |
+
`ILSVRC2012_img_train.tar`/`ILSVRC2012_img_val.tar` (or symlinks to them) into
|
255 |
+
`${XDG_CACHE}/autoencoders/data/ILSVRC2012_train/` /
|
256 |
+
`${XDG_CACHE}/autoencoders/data/ILSVRC2012_validation/`, which will then be
|
257 |
+
extracted into above structure without downloading it again. Note that this
|
258 |
+
will only happen if neither a folder
|
259 |
+
`${XDG_CACHE}/autoencoders/data/ILSVRC2012_{split}/data/` nor a file
|
260 |
+
`${XDG_CACHE}/autoencoders/data/ILSVRC2012_{split}/.ready` exist. Remove them
|
261 |
+
if you want to force running the dataset preparation again.
|
262 |
+
|
263 |
+
You will then need to prepare the depth data using
|
264 |
+
[MiDaS](https://github.com/intel-isl/MiDaS). Create a symlink
|
265 |
+
`data/imagenet_depth` pointing to a folder with two subfolders `train` and
|
266 |
+
`val`, each mirroring the structure of the corresponding ImageNet folder
|
267 |
+
described above and containing a `png` file for each of ImageNet's `JPEG`
|
268 |
+
files. The `png` encodes `float32` depth values obtained from MiDaS as RGBA
|
269 |
+
images. We provide the script `scripts/extract_depth.py` to generate this data.
|
270 |
+
**Please note** that this script uses [MiDaS via PyTorch
|
271 |
+
Hub](https://pytorch.org/hub/intelisl_midas_v2/). When we prepared the data,
|
272 |
+
the hub provided the [MiDaS
|
273 |
+
v2.0](https://github.com/intel-isl/MiDaS/releases/tag/v2) version, but now it
|
274 |
+
provides a v2.1 version. We haven't tested our models with depth maps obtained
|
275 |
+
via v2.1 and if you want to make sure that things work as expected, you must
|
276 |
+
adjust the script to make sure it explicitly uses
|
277 |
+
[v2.0](https://github.com/intel-isl/MiDaS/releases/tag/v2)!
|
278 |
+
|
279 |
+
### CelebA-HQ
|
280 |
+
Create a symlink `data/celebahq` pointing to a folder containing the `.npy`
|
281 |
+
files of CelebA-HQ (instructions to obtain them can be found in the [PGGAN
|
282 |
+
repository](https://github.com/tkarras/progressive_growing_of_gans)).
|
283 |
+
|
284 |
+
### FFHQ
|
285 |
+
Create a symlink `data/ffhq` pointing to the `images1024x1024` folder obtained
|
286 |
+
from the [FFHQ repository](https://github.com/NVlabs/ffhq-dataset).
|
287 |
+
|
288 |
+
### S-FLCKR
|
289 |
+
Unfortunately, we are not allowed to distribute the images we collected for the
|
290 |
+
S-FLCKR dataset and can therefore only give a description how it was produced.
|
291 |
+
There are many resources on [collecting images from the
|
292 |
+
web](https://github.com/adrianmrit/flickrdatasets) to get started.
|
293 |
+
We collected sufficiently large images from [flickr](https://www.flickr.com)
|
294 |
+
(see `data/flickr_tags.txt` for a full list of tags used to find images)
|
295 |
+
and various [subreddits](https://www.reddit.com/r/sfwpornnetwork/wiki/network)
|
296 |
+
(see `data/subreddits.txt` for all subreddits that were used).
|
297 |
+
Overall, we collected 107625 images, and split them randomly into 96861
|
298 |
+
training images and 10764 validation images. We then obtained segmentation
|
299 |
+
masks for each image using [DeepLab v2](https://arxiv.org/abs/1606.00915)
|
300 |
+
trained on [COCO-Stuff](https://arxiv.org/abs/1612.03716). We used a [PyTorch
|
301 |
+
reimplementation](https://github.com/kazuto1011/deeplab-pytorch) and include an
|
302 |
+
example script for this process in `scripts/extract_segmentation.py`.
|
303 |
+
|
304 |
+
### COCO
|
305 |
+
Create a symlink `data/coco` containing the images from the 2017 split in
|
306 |
+
`train2017` and `val2017`, and their annotations in `annotations`. Files can be
|
307 |
+
obtained from the [COCO webpage](https://cocodataset.org/). In addition, we use
|
308 |
+
the [Stuff+thing PNG-style annotations on COCO 2017
|
309 |
+
trainval](http://calvin.inf.ed.ac.uk/wp-content/uploads/data/cocostuffdataset/stuffthingmaps_trainval2017.zip)
|
310 |
+
annotations from [COCO-Stuff](https://github.com/nightrome/cocostuff), which
|
311 |
+
should be placed under `data/cocostuffthings`.
|
312 |
+
|
313 |
+
### ADE20k
|
314 |
+
Create a symlink `data/ade20k_root` containing the contents of
|
315 |
+
[ADEChallengeData2016.zip](http://data.csail.mit.edu/places/ADEchallenge/ADEChallengeData2016.zip)
|
316 |
+
from the [MIT Scene Parsing Benchmark](http://sceneparsing.csail.mit.edu/).
|
317 |
+
|
318 |
+
## Training models
|
319 |
+
|
320 |
+
### FacesHQ
|
321 |
+
|
322 |
+
Train a VQGAN with
|
323 |
+
```
|
324 |
+
python main.py --base configs/faceshq_vqgan.yaml -t True --gpus 0,
|
325 |
+
```
|
326 |
+
|
327 |
+
Then, adjust the checkpoint path of the config key
|
328 |
+
`model.params.first_stage_config.params.ckpt_path` in
|
329 |
+
`configs/faceshq_transformer.yaml` (or download
|
330 |
+
[2020-11-09T13-33-36_faceshq_vqgan](https://k00.fr/uxy5usa9) and place into `logs`, which
|
331 |
+
corresponds to the preconfigured checkpoint path), then run
|
332 |
+
```
|
333 |
+
python main.py --base configs/faceshq_transformer.yaml -t True --gpus 0,
|
334 |
+
```
|
335 |
+
|
336 |
+
### D-RIN
|
337 |
+
|
338 |
+
Train a VQGAN on ImageNet with
|
339 |
+
```
|
340 |
+
python main.py --base configs/imagenet_vqgan.yaml -t True --gpus 0,
|
341 |
+
```
|
342 |
+
|
343 |
+
or download a pretrained one from [2020-09-23T17-56-33_imagenet_vqgan](https://k00.fr/u0j2dtac)
|
344 |
+
and place under `logs`. If you trained your own, adjust the path in the config
|
345 |
+
key `model.params.first_stage_config.params.ckpt_path` of
|
346 |
+
`configs/drin_transformer.yaml`.
|
347 |
+
|
348 |
+
Train a VQGAN on Depth Maps of ImageNet with
|
349 |
+
```
|
350 |
+
python main.py --base configs/imagenetdepth_vqgan.yaml -t True --gpus 0,
|
351 |
+
```
|
352 |
+
|
353 |
+
or download a pretrained one from [2020-11-03T15-34-24_imagenetdepth_vqgan](https://k00.fr/55rlxs6i)
|
354 |
+
and place under `logs`. If you trained your own, adjust the path in the config
|
355 |
+
key `model.params.cond_stage_config.params.ckpt_path` of
|
356 |
+
`configs/drin_transformer.yaml`.
|
357 |
+
|
358 |
+
To train the transformer, run
|
359 |
+
```
|
360 |
+
python main.py --base configs/drin_transformer.yaml -t True --gpus 0,
|
361 |
+
```
|
362 |
+
|
363 |
+
## More Resources
|
364 |
+
### Comparing Different First Stage Models
|
365 |
+
The reconstruction and compression capabilities of different fist stage models can be analyzed in this [colab notebook](https://colab.research.google.com/github/CompVis/taming-transformers/blob/master/scripts/reconstruction_usage.ipynb).
|
366 |
+
In particular, the notebook compares two VQGANs with a downsampling factor of f=16 for each and codebook dimensionality of 1024 and 16384,
|
367 |
+
a VQGAN with f=8 and 8192 codebook entries and the discrete autoencoder of OpenAI's [DALL-E](https://github.com/openai/DALL-E) (which has f=8 and 8192
|
368 |
+
codebook entries).
|
369 |
+
![firststages1](assets/first_stage_squirrels.png)
|
370 |
+
![firststages2](assets/first_stage_mushrooms.png)
|
371 |
+
|
372 |
+
### Other
|
373 |
+
- A [video summary](https://www.youtube.com/watch?v=o7dqGcLDf0A&feature=emb_imp_woyt) by [Two Minute Papers](https://www.youtube.com/channel/UCbfYPyITQ-7l4upoX8nvctg).
|
374 |
+
- A [video summary](https://www.youtube.com/watch?v=-wDSDtIAyWQ) by [Gradient Dude](https://www.youtube.com/c/GradientDude/about).
|
375 |
+
- A [weights and biases report summarizing the paper](https://wandb.ai/ayush-thakur/taming-transformer/reports/-Overview-Taming-Transformers-for-High-Resolution-Image-Synthesis---Vmlldzo0NjEyMTY)
|
376 |
+
by [ayulockin](https://github.com/ayulockin).
|
377 |
+
- A [video summary](https://www.youtube.com/watch?v=JfUTd8fjtX8&feature=emb_imp_woyt) by [What's AI](https://www.youtube.com/channel/UCUzGQrN-lyyc0BWTYoJM_Sg).
|
378 |
+
- Take a look at [ak9250's notebook](https://github.com/ak9250/taming-transformers/blob/master/tamingtransformerscolab.ipynb) if you want to run the streamlit demos on Colab.
|
379 |
+
|
380 |
+
### Text-to-Image Optimization via CLIP
|
381 |
+
VQGAN has been successfully used as an image generator guided by the [CLIP](https://github.com/openai/CLIP) model, both for pure image generation
|
382 |
+
from scratch and image-to-image translation. We recommend the following notebooks/videos/resources:
|
383 |
+
|
384 |
+
- [Advadnouns](https://twitter.com/advadnoun/status/1389316507134357506) Patreon and corresponding LatentVision notebooks: https://www.patreon.com/patronizeme
|
385 |
+
- The [notebook]( https://colab.research.google.com/drive/1L8oL-vLJXVcRzCFbPwOoMkPKJ8-aYdPN) of [Rivers Have Wings](https://twitter.com/RiversHaveWings).
|
386 |
+
- A [video](https://www.youtube.com/watch?v=90QDe6DQXF4&t=12s) explanation by [Dot CSV](https://www.youtube.com/channel/UCy5znSnfMsDwaLlROnZ7Qbg) (in Spanish, but English subtitles are available)
|
387 |
+
|
388 |
+
![txt2img](assets/birddrawnbyachild.png)
|
389 |
+
|
390 |
+
Text prompt: *'A bird drawn by a child'*
|
391 |
+
|
392 |
+
## Shout-outs
|
393 |
+
Thanks to everyone who makes their code and models available. In particular,
|
394 |
+
|
395 |
+
- The architecture of our VQGAN is inspired by [Denoising Diffusion Probabilistic Models](https://github.com/hojonathanho/diffusion)
|
396 |
+
- The very hackable transformer implementation [minGPT](https://github.com/karpathy/minGPT)
|
397 |
+
- The good ol' [PatchGAN](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix) and [Learned Perceptual Similarity (LPIPS)](https://github.com/richzhang/PerceptualSimilarity)
|
398 |
+
|
399 |
+
## BibTeX
|
400 |
+
|
401 |
+
```
|
402 |
+
@misc{esser2020taming,
|
403 |
+
title={Taming Transformers for High-Resolution Image Synthesis},
|
404 |
+
author={Patrick Esser and Robin Rombach and Björn Ommer},
|
405 |
+
year={2020},
|
406 |
+
eprint={2012.09841},
|
407 |
+
archivePrefix={arXiv},
|
408 |
+
primaryClass={cs.CV}
|
409 |
+
}
|
410 |
+
```
|
taming-transformers/configs/coco_cond_stage.yaml
ADDED
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
model:
|
2 |
+
base_learning_rate: 4.5e-06
|
3 |
+
target: taming.models.vqgan.VQSegmentationModel
|
4 |
+
params:
|
5 |
+
embed_dim: 256
|
6 |
+
n_embed: 1024
|
7 |
+
image_key: "segmentation"
|
8 |
+
n_labels: 183
|
9 |
+
ddconfig:
|
10 |
+
double_z: false
|
11 |
+
z_channels: 256
|
12 |
+
resolution: 256
|
13 |
+
in_channels: 183
|
14 |
+
out_ch: 183
|
15 |
+
ch: 128
|
16 |
+
ch_mult:
|
17 |
+
- 1
|
18 |
+
- 1
|
19 |
+
- 2
|
20 |
+
- 2
|
21 |
+
- 4
|
22 |
+
num_res_blocks: 2
|
23 |
+
attn_resolutions:
|
24 |
+
- 16
|
25 |
+
dropout: 0.0
|
26 |
+
|
27 |
+
lossconfig:
|
28 |
+
target: taming.modules.losses.segmentation.BCELossWithQuant
|
29 |
+
params:
|
30 |
+
codebook_weight: 1.0
|
31 |
+
|
32 |
+
data:
|
33 |
+
target: main.DataModuleFromConfig
|
34 |
+
params:
|
35 |
+
batch_size: 12
|
36 |
+
train:
|
37 |
+
target: taming.data.coco.CocoImagesAndCaptionsTrain
|
38 |
+
params:
|
39 |
+
size: 296
|
40 |
+
crop_size: 256
|
41 |
+
onehot_segmentation: true
|
42 |
+
use_stuffthing: true
|
43 |
+
validation:
|
44 |
+
target: taming.data.coco.CocoImagesAndCaptionsValidation
|
45 |
+
params:
|
46 |
+
size: 256
|
47 |
+
crop_size: 256
|
48 |
+
onehot_segmentation: true
|
49 |
+
use_stuffthing: true
|
taming-transformers/configs/coco_scene_images_transformer.yaml
ADDED
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
model:
|
2 |
+
base_learning_rate: 4.5e-06
|
3 |
+
target: taming.models.cond_transformer.Net2NetTransformer
|
4 |
+
params:
|
5 |
+
cond_stage_key: objects_bbox
|
6 |
+
transformer_config:
|
7 |
+
target: taming.modules.transformer.mingpt.GPT
|
8 |
+
params:
|
9 |
+
vocab_size: 8192
|
10 |
+
block_size: 348 # = 256 + 92 = dim(vqgan_latent_space,16x16) + dim(conditional_builder.embedding_dim)
|
11 |
+
n_layer: 40
|
12 |
+
n_head: 16
|
13 |
+
n_embd: 1408
|
14 |
+
embd_pdrop: 0.1
|
15 |
+
resid_pdrop: 0.1
|
16 |
+
attn_pdrop: 0.1
|
17 |
+
first_stage_config:
|
18 |
+
target: taming.models.vqgan.VQModel
|
19 |
+
params:
|
20 |
+
ckpt_path: /path/to/coco_epoch117.ckpt # https://heibox.uni-heidelberg.de/f/78dea9589974474c97c1/
|
21 |
+
embed_dim: 256
|
22 |
+
n_embed: 8192
|
23 |
+
ddconfig:
|
24 |
+
double_z: false
|
25 |
+
z_channels: 256
|
26 |
+
resolution: 256
|
27 |
+
in_channels: 3
|
28 |
+
out_ch: 3
|
29 |
+
ch: 128
|
30 |
+
ch_mult:
|
31 |
+
- 1
|
32 |
+
- 1
|
33 |
+
- 2
|
34 |
+
- 2
|
35 |
+
- 4
|
36 |
+
num_res_blocks: 2
|
37 |
+
attn_resolutions:
|
38 |
+
- 16
|
39 |
+
dropout: 0.0
|
40 |
+
lossconfig:
|
41 |
+
target: taming.modules.losses.DummyLoss
|
42 |
+
cond_stage_config:
|
43 |
+
target: taming.models.dummy_cond_stage.DummyCondStage
|
44 |
+
params:
|
45 |
+
conditional_key: objects_bbox
|
46 |
+
|
47 |
+
data:
|
48 |
+
target: main.DataModuleFromConfig
|
49 |
+
params:
|
50 |
+
batch_size: 6
|
51 |
+
train:
|
52 |
+
target: taming.data.annotated_objects_coco.AnnotatedObjectsCoco
|
53 |
+
params:
|
54 |
+
data_path: data/coco_annotations_100 # substitute with path to full dataset
|
55 |
+
split: train
|
56 |
+
keys: [image, objects_bbox, file_name, annotations]
|
57 |
+
no_tokens: 8192
|
58 |
+
target_image_size: 256
|
59 |
+
min_object_area: 0.00001
|
60 |
+
min_objects_per_image: 2
|
61 |
+
max_objects_per_image: 30
|
62 |
+
crop_method: random-1d
|
63 |
+
random_flip: true
|
64 |
+
use_group_parameter: true
|
65 |
+
encode_crop: true
|
66 |
+
validation:
|
67 |
+
target: taming.data.annotated_objects_coco.AnnotatedObjectsCoco
|
68 |
+
params:
|
69 |
+
data_path: data/coco_annotations_100 # substitute with path to full dataset
|
70 |
+
split: validation
|
71 |
+
keys: [image, objects_bbox, file_name, annotations]
|
72 |
+
no_tokens: 8192
|
73 |
+
target_image_size: 256
|
74 |
+
min_object_area: 0.00001
|
75 |
+
min_objects_per_image: 2
|
76 |
+
max_objects_per_image: 30
|
77 |
+
crop_method: center
|
78 |
+
random_flip: false
|
79 |
+
use_group_parameter: true
|
80 |
+
encode_crop: true
|
taming-transformers/configs/custom_vqgan.yaml
ADDED
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
model:
|
2 |
+
base_learning_rate: 4.5e-6
|
3 |
+
target: taming.models.vqgan.VQModel
|
4 |
+
params:
|
5 |
+
embed_dim: 256
|
6 |
+
n_embed: 1024
|
7 |
+
ddconfig:
|
8 |
+
double_z: False
|
9 |
+
z_channels: 256
|
10 |
+
resolution: 256
|
11 |
+
in_channels: 3
|
12 |
+
out_ch: 3
|
13 |
+
ch: 128
|
14 |
+
ch_mult: [ 1,1,2,2,4] # num_down = len(ch_mult)-1
|
15 |
+
num_res_blocks: 2
|
16 |
+
attn_resolutions: [16]
|
17 |
+
dropout: 0.0
|
18 |
+
|
19 |
+
lossconfig:
|
20 |
+
target: taming.modules.losses.vqperceptual.VQLPIPSWithDiscriminator
|
21 |
+
params:
|
22 |
+
disc_conditional: False
|
23 |
+
disc_in_channels: 3
|
24 |
+
disc_start: 10000
|
25 |
+
disc_weight: 0.8
|
26 |
+
codebook_weight: 1.0
|
27 |
+
|
28 |
+
data:
|
29 |
+
target: main.DataModuleFromConfig
|
30 |
+
params:
|
31 |
+
batch_size: 5
|
32 |
+
num_workers: 8
|
33 |
+
train:
|
34 |
+
target: taming.data.custom.CustomTrain
|
35 |
+
params:
|
36 |
+
training_images_list_file: some/training.txt
|
37 |
+
size: 256
|
38 |
+
validation:
|
39 |
+
target: taming.data.custom.CustomTest
|
40 |
+
params:
|
41 |
+
test_images_list_file: some/test.txt
|
42 |
+
size: 256
|
43 |
+
|
taming-transformers/configs/drin_transformer.yaml
ADDED
@@ -0,0 +1,77 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
model:
|
2 |
+
base_learning_rate: 4.5e-06
|
3 |
+
target: taming.models.cond_transformer.Net2NetTransformer
|
4 |
+
params:
|
5 |
+
cond_stage_key: depth
|
6 |
+
transformer_config:
|
7 |
+
target: taming.modules.transformer.mingpt.GPT
|
8 |
+
params:
|
9 |
+
vocab_size: 1024
|
10 |
+
block_size: 512
|
11 |
+
n_layer: 24
|
12 |
+
n_head: 16
|
13 |
+
n_embd: 1024
|
14 |
+
first_stage_config:
|
15 |
+
target: taming.models.vqgan.VQModel
|
16 |
+
params:
|
17 |
+
ckpt_path: logs/2020-09-23T17-56-33_imagenet_vqgan/checkpoints/last.ckpt
|
18 |
+
embed_dim: 256
|
19 |
+
n_embed: 1024
|
20 |
+
ddconfig:
|
21 |
+
double_z: false
|
22 |
+
z_channels: 256
|
23 |
+
resolution: 256
|
24 |
+
in_channels: 3
|
25 |
+
out_ch: 3
|
26 |
+
ch: 128
|
27 |
+
ch_mult:
|
28 |
+
- 1
|
29 |
+
- 1
|
30 |
+
- 2
|
31 |
+
- 2
|
32 |
+
- 4
|
33 |
+
num_res_blocks: 2
|
34 |
+
attn_resolutions:
|
35 |
+
- 16
|
36 |
+
dropout: 0.0
|
37 |
+
lossconfig:
|
38 |
+
target: taming.modules.losses.DummyLoss
|
39 |
+
cond_stage_config:
|
40 |
+
target: taming.models.vqgan.VQModel
|
41 |
+
params:
|
42 |
+
ckpt_path: logs/2020-11-03T15-34-24_imagenetdepth_vqgan/checkpoints/last.ckpt
|
43 |
+
embed_dim: 256
|
44 |
+
n_embed: 1024
|
45 |
+
ddconfig:
|
46 |
+
double_z: false
|
47 |
+
z_channels: 256
|
48 |
+
resolution: 256
|
49 |
+
in_channels: 1
|
50 |
+
out_ch: 1
|
51 |
+
ch: 128
|
52 |
+
ch_mult:
|
53 |
+
- 1
|
54 |
+
- 1
|
55 |
+
- 2
|
56 |
+
- 2
|
57 |
+
- 4
|
58 |
+
num_res_blocks: 2
|
59 |
+
attn_resolutions:
|
60 |
+
- 16
|
61 |
+
dropout: 0.0
|
62 |
+
lossconfig:
|
63 |
+
target: taming.modules.losses.DummyLoss
|
64 |
+
|
65 |
+
data:
|
66 |
+
target: main.DataModuleFromConfig
|
67 |
+
params:
|
68 |
+
batch_size: 2
|
69 |
+
num_workers: 8
|
70 |
+
train:
|
71 |
+
target: taming.data.imagenet.RINTrainWithDepth
|
72 |
+
params:
|
73 |
+
size: 256
|
74 |
+
validation:
|
75 |
+
target: taming.data.imagenet.RINValidationWithDepth
|
76 |
+
params:
|
77 |
+
size: 256
|
taming-transformers/configs/faceshq_transformer.yaml
ADDED
@@ -0,0 +1,61 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
model:
|
2 |
+
base_learning_rate: 4.5e-06
|
3 |
+
target: taming.models.cond_transformer.Net2NetTransformer
|
4 |
+
params:
|
5 |
+
cond_stage_key: coord
|
6 |
+
transformer_config:
|
7 |
+
target: taming.modules.transformer.mingpt.GPT
|
8 |
+
params:
|
9 |
+
vocab_size: 1024
|
10 |
+
block_size: 512
|
11 |
+
n_layer: 24
|
12 |
+
n_head: 16
|
13 |
+
n_embd: 1024
|
14 |
+
first_stage_config:
|
15 |
+
target: taming.models.vqgan.VQModel
|
16 |
+
params:
|
17 |
+
ckpt_path: logs/2020-11-09T13-33-36_faceshq_vqgan/checkpoints/last.ckpt
|
18 |
+
embed_dim: 256
|
19 |
+
n_embed: 1024
|
20 |
+
ddconfig:
|
21 |
+
double_z: false
|
22 |
+
z_channels: 256
|
23 |
+
resolution: 256
|
24 |
+
in_channels: 3
|
25 |
+
out_ch: 3
|
26 |
+
ch: 128
|
27 |
+
ch_mult:
|
28 |
+
- 1
|
29 |
+
- 1
|
30 |
+
- 2
|
31 |
+
- 2
|
32 |
+
- 4
|
33 |
+
num_res_blocks: 2
|
34 |
+
attn_resolutions:
|
35 |
+
- 16
|
36 |
+
dropout: 0.0
|
37 |
+
lossconfig:
|
38 |
+
target: taming.modules.losses.DummyLoss
|
39 |
+
cond_stage_config:
|
40 |
+
target: taming.modules.misc.coord.CoordStage
|
41 |
+
params:
|
42 |
+
n_embed: 1024
|
43 |
+
down_factor: 16
|
44 |
+
|
45 |
+
data:
|
46 |
+
target: main.DataModuleFromConfig
|
47 |
+
params:
|
48 |
+
batch_size: 2
|
49 |
+
num_workers: 8
|
50 |
+
train:
|
51 |
+
target: taming.data.faceshq.FacesHQTrain
|
52 |
+
params:
|
53 |
+
size: 256
|
54 |
+
crop_size: 256
|
55 |
+
coord: True
|
56 |
+
validation:
|
57 |
+
target: taming.data.faceshq.FacesHQValidation
|
58 |
+
params:
|
59 |
+
size: 256
|
60 |
+
crop_size: 256
|
61 |
+
coord: True
|
taming-transformers/configs/faceshq_vqgan.yaml
ADDED
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
model:
|
2 |
+
base_learning_rate: 4.5e-6
|
3 |
+
target: taming.models.vqgan.VQModel
|
4 |
+
params:
|
5 |
+
embed_dim: 256
|
6 |
+
n_embed: 1024
|
7 |
+
ddconfig:
|
8 |
+
double_z: False
|
9 |
+
z_channels: 256
|
10 |
+
resolution: 256
|
11 |
+
in_channels: 3
|
12 |
+
out_ch: 3
|
13 |
+
ch: 128
|
14 |
+
ch_mult: [ 1,1,2,2,4] # num_down = len(ch_mult)-1
|
15 |
+
num_res_blocks: 2
|
16 |
+
attn_resolutions: [16]
|
17 |
+
dropout: 0.0
|
18 |
+
|
19 |
+
lossconfig:
|
20 |
+
target: taming.modules.losses.vqperceptual.VQLPIPSWithDiscriminator
|
21 |
+
params:
|
22 |
+
disc_conditional: False
|
23 |
+
disc_in_channels: 3
|
24 |
+
disc_start: 30001
|
25 |
+
disc_weight: 0.8
|
26 |
+
codebook_weight: 1.0
|
27 |
+
|
28 |
+
data:
|
29 |
+
target: main.DataModuleFromConfig
|
30 |
+
params:
|
31 |
+
batch_size: 3
|
32 |
+
num_workers: 8
|
33 |
+
train:
|
34 |
+
target: taming.data.faceshq.FacesHQTrain
|
35 |
+
params:
|
36 |
+
size: 256
|
37 |
+
crop_size: 256
|
38 |
+
validation:
|
39 |
+
target: taming.data.faceshq.FacesHQValidation
|
40 |
+
params:
|
41 |
+
size: 256
|
42 |
+
crop_size: 256
|
taming-transformers/configs/imagenet_vqgan.yaml
ADDED
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
model:
|
2 |
+
base_learning_rate: 4.5e-6
|
3 |
+
target: taming.models.vqgan.VQModel
|
4 |
+
params:
|
5 |
+
embed_dim: 256
|
6 |
+
n_embed: 1024
|
7 |
+
ddconfig:
|
8 |
+
double_z: False
|
9 |
+
z_channels: 256
|
10 |
+
resolution: 256
|
11 |
+
in_channels: 3
|
12 |
+
out_ch: 3
|
13 |
+
ch: 128
|
14 |
+
ch_mult: [ 1,1,2,2,4] # num_down = len(ch_mult)-1
|
15 |
+
num_res_blocks: 2
|
16 |
+
attn_resolutions: [16]
|
17 |
+
dropout: 0.0
|
18 |
+
|
19 |
+
lossconfig:
|
20 |
+
target: taming.modules.losses.vqperceptual.VQLPIPSWithDiscriminator
|
21 |
+
params:
|
22 |
+
disc_conditional: False
|
23 |
+
disc_in_channels: 3
|
24 |
+
disc_start: 250001
|
25 |
+
disc_weight: 0.8
|
26 |
+
codebook_weight: 1.0
|
27 |
+
|
28 |
+
data:
|
29 |
+
target: main.DataModuleFromConfig
|
30 |
+
params:
|
31 |
+
batch_size: 12
|
32 |
+
num_workers: 24
|
33 |
+
train:
|
34 |
+
target: taming.data.imagenet.ImageNetTrain
|
35 |
+
params:
|
36 |
+
config:
|
37 |
+
size: 256
|
38 |
+
validation:
|
39 |
+
target: taming.data.imagenet.ImageNetValidation
|
40 |
+
params:
|
41 |
+
config:
|
42 |
+
size: 256
|
taming-transformers/configs/imagenetdepth_vqgan.yaml
ADDED
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
model:
|
2 |
+
base_learning_rate: 4.5e-6
|
3 |
+
target: taming.models.vqgan.VQModel
|
4 |
+
params:
|
5 |
+
embed_dim: 256
|
6 |
+
n_embed: 1024
|
7 |
+
image_key: depth
|
8 |
+
ddconfig:
|
9 |
+
double_z: False
|
10 |
+
z_channels: 256
|
11 |
+
resolution: 256
|
12 |
+
in_channels: 1
|
13 |
+
out_ch: 1
|
14 |
+
ch: 128
|
15 |
+
ch_mult: [ 1,1,2,2,4] # num_down = len(ch_mult)-1
|
16 |
+
num_res_blocks: 2
|
17 |
+
attn_resolutions: [16]
|
18 |
+
dropout: 0.0
|
19 |
+
|
20 |
+
lossconfig:
|
21 |
+
target: taming.modules.losses.vqperceptual.VQLPIPSWithDiscriminator
|
22 |
+
params:
|
23 |
+
disc_conditional: False
|
24 |
+
disc_in_channels: 1
|
25 |
+
disc_start: 50001
|
26 |
+
disc_weight: 0.75
|
27 |
+
codebook_weight: 1.0
|
28 |
+
|
29 |
+
data:
|
30 |
+
target: main.DataModuleFromConfig
|
31 |
+
params:
|
32 |
+
batch_size: 3
|
33 |
+
num_workers: 8
|
34 |
+
train:
|
35 |
+
target: taming.data.imagenet.ImageNetTrainWithDepth
|
36 |
+
params:
|
37 |
+
size: 256
|
38 |
+
validation:
|
39 |
+
target: taming.data.imagenet.ImageNetValidationWithDepth
|
40 |
+
params:
|
41 |
+
size: 256
|
taming-transformers/configs/open_images_scene_images_transformer.yaml
ADDED
@@ -0,0 +1,86 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
model:
|
2 |
+
base_learning_rate: 4.5e-06
|
3 |
+
target: taming.models.cond_transformer.Net2NetTransformer
|
4 |
+
params:
|
5 |
+
cond_stage_key: objects_bbox
|
6 |
+
transformer_config:
|
7 |
+
target: taming.modules.transformer.mingpt.GPT
|
8 |
+
params:
|
9 |
+
vocab_size: 8192
|
10 |
+
block_size: 348 # = 256 + 92 = dim(vqgan_latent_space,16x16) + dim(conditional_builder.embedding_dim)
|
11 |
+
n_layer: 36
|
12 |
+
n_head: 16
|
13 |
+
n_embd: 1536
|
14 |
+
embd_pdrop: 0.1
|
15 |
+
resid_pdrop: 0.1
|
16 |
+
attn_pdrop: 0.1
|
17 |
+
first_stage_config:
|
18 |
+
target: taming.models.vqgan.VQModel
|
19 |
+
params:
|
20 |
+
ckpt_path: /path/to/coco_oi_epoch12.ckpt # https://heibox.uni-heidelberg.de/f/461d9a9f4fcf48ab84f4/
|
21 |
+
embed_dim: 256
|
22 |
+
n_embed: 8192
|
23 |
+
ddconfig:
|
24 |
+
double_z: false
|
25 |
+
z_channels: 256
|
26 |
+
resolution: 256
|
27 |
+
in_channels: 3
|
28 |
+
out_ch: 3
|
29 |
+
ch: 128
|
30 |
+
ch_mult:
|
31 |
+
- 1
|
32 |
+
- 1
|
33 |
+
- 2
|
34 |
+
- 2
|
35 |
+
- 4
|
36 |
+
num_res_blocks: 2
|
37 |
+
attn_resolutions:
|
38 |
+
- 16
|
39 |
+
dropout: 0.0
|
40 |
+
lossconfig:
|
41 |
+
target: taming.modules.losses.DummyLoss
|
42 |
+
cond_stage_config:
|
43 |
+
target: taming.models.dummy_cond_stage.DummyCondStage
|
44 |
+
params:
|
45 |
+
conditional_key: objects_bbox
|
46 |
+
|
47 |
+
data:
|
48 |
+
target: main.DataModuleFromConfig
|
49 |
+
params:
|
50 |
+
batch_size: 6
|
51 |
+
train:
|
52 |
+
target: taming.data.annotated_objects_open_images.AnnotatedObjectsOpenImages
|
53 |
+
params:
|
54 |
+
data_path: data/open_images_annotations_100 # substitute with path to full dataset
|
55 |
+
split: train
|
56 |
+
keys: [image, objects_bbox, file_name, annotations]
|
57 |
+
no_tokens: 8192
|
58 |
+
target_image_size: 256
|
59 |
+
category_allow_list_target: taming.data.open_images_helper.top_300_classes_plus_coco_compatibility
|
60 |
+
category_mapping_target: taming.data.open_images_helper.open_images_unify_categories_for_coco
|
61 |
+
min_object_area: 0.0001
|
62 |
+
min_objects_per_image: 2
|
63 |
+
max_objects_per_image: 30
|
64 |
+
crop_method: random-2d
|
65 |
+
random_flip: true
|
66 |
+
use_group_parameter: true
|
67 |
+
use_additional_parameters: true
|
68 |
+
encode_crop: true
|
69 |
+
validation:
|
70 |
+
target: taming.data.annotated_objects_open_images.AnnotatedObjectsOpenImages
|
71 |
+
params:
|
72 |
+
data_path: data/open_images_annotations_100 # substitute with path to full dataset
|
73 |
+
split: validation
|
74 |
+
keys: [image, objects_bbox, file_name, annotations]
|
75 |
+
no_tokens: 8192
|
76 |
+
target_image_size: 256
|
77 |
+
category_allow_list_target: taming.data.open_images_helper.top_300_classes_plus_coco_compatibility
|
78 |
+
category_mapping_target: taming.data.open_images_helper.open_images_unify_categories_for_coco
|
79 |
+
min_object_area: 0.0001
|
80 |
+
min_objects_per_image: 2
|
81 |
+
max_objects_per_image: 30
|
82 |
+
crop_method: center
|
83 |
+
random_flip: false
|
84 |
+
use_group_parameter: true
|
85 |
+
use_additional_parameters: true
|
86 |
+
encode_crop: true
|
taming-transformers/configs/sflckr_cond_stage.yaml
ADDED
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
model:
|
2 |
+
base_learning_rate: 4.5e-06
|
3 |
+
target: taming.models.vqgan.VQSegmentationModel
|
4 |
+
params:
|
5 |
+
embed_dim: 256
|
6 |
+
n_embed: 1024
|
7 |
+
image_key: "segmentation"
|
8 |
+
n_labels: 182
|
9 |
+
ddconfig:
|
10 |
+
double_z: false
|
11 |
+
z_channels: 256
|
12 |
+
resolution: 256
|
13 |
+
in_channels: 182
|
14 |
+
out_ch: 182
|
15 |
+
ch: 128
|
16 |
+
ch_mult:
|
17 |
+
- 1
|
18 |
+
- 1
|
19 |
+
- 2
|
20 |
+
- 2
|
21 |
+
- 4
|
22 |
+
num_res_blocks: 2
|
23 |
+
attn_resolutions:
|
24 |
+
- 16
|
25 |
+
dropout: 0.0
|
26 |
+
|
27 |
+
lossconfig:
|
28 |
+
target: taming.modules.losses.segmentation.BCELossWithQuant
|
29 |
+
params:
|
30 |
+
codebook_weight: 1.0
|
31 |
+
|
32 |
+
data:
|
33 |
+
target: cutlit.DataModuleFromConfig
|
34 |
+
params:
|
35 |
+
batch_size: 12
|
36 |
+
train:
|
37 |
+
target: taming.data.sflckr.Examples # adjust
|
38 |
+
params:
|
39 |
+
size: 256
|
40 |
+
validation:
|
41 |
+
target: taming.data.sflckr.Examples # adjust
|
42 |
+
params:
|
43 |
+
size: 256
|
taming-transformers/environment.yaml
ADDED
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
name: taming
|
2 |
+
channels:
|
3 |
+
- pytorch
|
4 |
+
- defaults
|
5 |
+
dependencies:
|
6 |
+
- python=3.8.5
|
7 |
+
- pip=20.3
|
8 |
+
- cudatoolkit=10.2
|
9 |
+
- pytorch=1.7.0
|
10 |
+
- torchvision=0.8.1
|
11 |
+
- numpy=1.19.2
|
12 |
+
- pip:
|
13 |
+
- albumentations==0.4.3
|
14 |
+
- opencv-python==4.1.2.30
|
15 |
+
- pudb==2019.2
|
16 |
+
- imageio==2.9.0
|
17 |
+
- imageio-ffmpeg==0.4.2
|
18 |
+
- pytorch-lightning==1.0.8
|
19 |
+
- omegaconf==2.0.0
|
20 |
+
- test-tube>=0.7.5
|
21 |
+
- streamlit>=0.73.1
|
22 |
+
- einops==0.3.0
|
23 |
+
- more-itertools>=8.0.0
|
24 |
+
- transformers==4.3.1
|
25 |
+
- -e .
|
taming-transformers/main.py
ADDED
@@ -0,0 +1,585 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse, os, sys, datetime, glob, importlib
|
2 |
+
from omegaconf import OmegaConf
|
3 |
+
import numpy as np
|
4 |
+
from PIL import Image
|
5 |
+
import torch
|
6 |
+
import torchvision
|
7 |
+
from torch.utils.data import random_split, DataLoader, Dataset
|
8 |
+
import pytorch_lightning as pl
|
9 |
+
from pytorch_lightning import seed_everything
|
10 |
+
from pytorch_lightning.trainer import Trainer
|
11 |
+
from pytorch_lightning.callbacks import ModelCheckpoint, Callback, LearningRateMonitor
|
12 |
+
from pytorch_lightning.utilities import rank_zero_only
|
13 |
+
|
14 |
+
from taming.data.utils import custom_collate
|
15 |
+
|
16 |
+
|
17 |
+
def get_obj_from_str(string, reload=False):
|
18 |
+
module, cls = string.rsplit(".", 1)
|
19 |
+
if reload:
|
20 |
+
module_imp = importlib.import_module(module)
|
21 |
+
importlib.reload(module_imp)
|
22 |
+
return getattr(importlib.import_module(module, package=None), cls)
|
23 |
+
|
24 |
+
|
25 |
+
def get_parser(**parser_kwargs):
|
26 |
+
def str2bool(v):
|
27 |
+
if isinstance(v, bool):
|
28 |
+
return v
|
29 |
+
if v.lower() in ("yes", "true", "t", "y", "1"):
|
30 |
+
return True
|
31 |
+
elif v.lower() in ("no", "false", "f", "n", "0"):
|
32 |
+
return False
|
33 |
+
else:
|
34 |
+
raise argparse.ArgumentTypeError("Boolean value expected.")
|
35 |
+
|
36 |
+
parser = argparse.ArgumentParser(**parser_kwargs)
|
37 |
+
parser.add_argument(
|
38 |
+
"-n",
|
39 |
+
"--name",
|
40 |
+
type=str,
|
41 |
+
const=True,
|
42 |
+
default="",
|
43 |
+
nargs="?",
|
44 |
+
help="postfix for logdir",
|
45 |
+
)
|
46 |
+
parser.add_argument(
|
47 |
+
"-r",
|
48 |
+
"--resume",
|
49 |
+
type=str,
|
50 |
+
const=True,
|
51 |
+
default="",
|
52 |
+
nargs="?",
|
53 |
+
help="resume from logdir or checkpoint in logdir",
|
54 |
+
)
|
55 |
+
parser.add_argument(
|
56 |
+
"-b",
|
57 |
+
"--base",
|
58 |
+
nargs="*",
|
59 |
+
metavar="base_config.yaml",
|
60 |
+
help="paths to base configs. Loaded from left-to-right. "
|
61 |
+
"Parameters can be overwritten or added with command-line options of the form `--key value`.",
|
62 |
+
default=list(),
|
63 |
+
)
|
64 |
+
parser.add_argument(
|
65 |
+
"-t",
|
66 |
+
"--train",
|
67 |
+
type=str2bool,
|
68 |
+
const=True,
|
69 |
+
default=False,
|
70 |
+
nargs="?",
|
71 |
+
help="train",
|
72 |
+
)
|
73 |
+
parser.add_argument(
|
74 |
+
"--no-test",
|
75 |
+
type=str2bool,
|
76 |
+
const=True,
|
77 |
+
default=False,
|
78 |
+
nargs="?",
|
79 |
+
help="disable test",
|
80 |
+
)
|
81 |
+
parser.add_argument("-p", "--project", help="name of new or path to existing project")
|
82 |
+
parser.add_argument(
|
83 |
+
"-d",
|
84 |
+
"--debug",
|
85 |
+
type=str2bool,
|
86 |
+
nargs="?",
|
87 |
+
const=True,
|
88 |
+
default=False,
|
89 |
+
help="enable post-mortem debugging",
|
90 |
+
)
|
91 |
+
parser.add_argument(
|
92 |
+
"-s",
|
93 |
+
"--seed",
|
94 |
+
type=int,
|
95 |
+
default=23,
|
96 |
+
help="seed for seed_everything",
|
97 |
+
)
|
98 |
+
parser.add_argument(
|
99 |
+
"-f",
|
100 |
+
"--postfix",
|
101 |
+
type=str,
|
102 |
+
default="",
|
103 |
+
help="post-postfix for default name",
|
104 |
+
)
|
105 |
+
|
106 |
+
return parser
|
107 |
+
|
108 |
+
|
109 |
+
def nondefault_trainer_args(opt):
|
110 |
+
parser = argparse.ArgumentParser()
|
111 |
+
parser = Trainer.add_argparse_args(parser)
|
112 |
+
args = parser.parse_args([])
|
113 |
+
return sorted(k for k in vars(args) if getattr(opt, k) != getattr(args, k))
|
114 |
+
|
115 |
+
|
116 |
+
def instantiate_from_config(config):
|
117 |
+
if not "target" in config:
|
118 |
+
raise KeyError("Expected key `target` to instantiate.")
|
119 |
+
return get_obj_from_str(config["target"])(**config.get("params", dict()))
|
120 |
+
|
121 |
+
|
122 |
+
class WrappedDataset(Dataset):
|
123 |
+
"""Wraps an arbitrary object with __len__ and __getitem__ into a pytorch dataset"""
|
124 |
+
def __init__(self, dataset):
|
125 |
+
self.data = dataset
|
126 |
+
|
127 |
+
def __len__(self):
|
128 |
+
return len(self.data)
|
129 |
+
|
130 |
+
def __getitem__(self, idx):
|
131 |
+
return self.data[idx]
|
132 |
+
|
133 |
+
|
134 |
+
class DataModuleFromConfig(pl.LightningDataModule):
|
135 |
+
def __init__(self, batch_size, train=None, validation=None, test=None,
|
136 |
+
wrap=False, num_workers=None):
|
137 |
+
super().__init__()
|
138 |
+
self.batch_size = batch_size
|
139 |
+
self.dataset_configs = dict()
|
140 |
+
self.num_workers = num_workers if num_workers is not None else batch_size*2
|
141 |
+
if train is not None:
|
142 |
+
self.dataset_configs["train"] = train
|
143 |
+
self.train_dataloader = self._train_dataloader
|
144 |
+
if validation is not None:
|
145 |
+
self.dataset_configs["validation"] = validation
|
146 |
+
self.val_dataloader = self._val_dataloader
|
147 |
+
if test is not None:
|
148 |
+
self.dataset_configs["test"] = test
|
149 |
+
self.test_dataloader = self._test_dataloader
|
150 |
+
self.wrap = wrap
|
151 |
+
|
152 |
+
def prepare_data(self):
|
153 |
+
for data_cfg in self.dataset_configs.values():
|
154 |
+
instantiate_from_config(data_cfg)
|
155 |
+
|
156 |
+
def setup(self, stage=None):
|
157 |
+
self.datasets = dict(
|
158 |
+
(k, instantiate_from_config(self.dataset_configs[k]))
|
159 |
+
for k in self.dataset_configs)
|
160 |
+
if self.wrap:
|
161 |
+
for k in self.datasets:
|
162 |
+
self.datasets[k] = WrappedDataset(self.datasets[k])
|
163 |
+
|
164 |
+
def _train_dataloader(self):
|
165 |
+
return DataLoader(self.datasets["train"], batch_size=self.batch_size,
|
166 |
+
num_workers=self.num_workers, shuffle=True, collate_fn=custom_collate)
|
167 |
+
|
168 |
+
def _val_dataloader(self):
|
169 |
+
return DataLoader(self.datasets["validation"],
|
170 |
+
batch_size=self.batch_size,
|
171 |
+
num_workers=self.num_workers, collate_fn=custom_collate)
|
172 |
+
|
173 |
+
def _test_dataloader(self):
|
174 |
+
return DataLoader(self.datasets["test"], batch_size=self.batch_size,
|
175 |
+
num_workers=self.num_workers, collate_fn=custom_collate)
|
176 |
+
|
177 |
+
|
178 |
+
class SetupCallback(Callback):
|
179 |
+
def __init__(self, resume, now, logdir, ckptdir, cfgdir, config, lightning_config):
|
180 |
+
super().__init__()
|
181 |
+
self.resume = resume
|
182 |
+
self.now = now
|
183 |
+
self.logdir = logdir
|
184 |
+
self.ckptdir = ckptdir
|
185 |
+
self.cfgdir = cfgdir
|
186 |
+
self.config = config
|
187 |
+
self.lightning_config = lightning_config
|
188 |
+
|
189 |
+
def on_pretrain_routine_start(self, trainer, pl_module):
|
190 |
+
if trainer.global_rank == 0:
|
191 |
+
# Create logdirs and save configs
|
192 |
+
os.makedirs(self.logdir, exist_ok=True)
|
193 |
+
os.makedirs(self.ckptdir, exist_ok=True)
|
194 |
+
os.makedirs(self.cfgdir, exist_ok=True)
|
195 |
+
|
196 |
+
print("Project config")
|
197 |
+
print(self.config.pretty())
|
198 |
+
OmegaConf.save(self.config,
|
199 |
+
os.path.join(self.cfgdir, "{}-project.yaml".format(self.now)))
|
200 |
+
|
201 |
+
print("Lightning config")
|
202 |
+
print(self.lightning_config.pretty())
|
203 |
+
OmegaConf.save(OmegaConf.create({"lightning": self.lightning_config}),
|
204 |
+
os.path.join(self.cfgdir, "{}-lightning.yaml".format(self.now)))
|
205 |
+
|
206 |
+
else:
|
207 |
+
# ModelCheckpoint callback created log directory --- remove it
|
208 |
+
if not self.resume and os.path.exists(self.logdir):
|
209 |
+
dst, name = os.path.split(self.logdir)
|
210 |
+
dst = os.path.join(dst, "child_runs", name)
|
211 |
+
os.makedirs(os.path.split(dst)[0], exist_ok=True)
|
212 |
+
try:
|
213 |
+
os.rename(self.logdir, dst)
|
214 |
+
except FileNotFoundError:
|
215 |
+
pass
|
216 |
+
|
217 |
+
|
218 |
+
class ImageLogger(Callback):
|
219 |
+
def __init__(self, batch_frequency, max_images, clamp=True, increase_log_steps=True):
|
220 |
+
super().__init__()
|
221 |
+
self.batch_freq = batch_frequency
|
222 |
+
self.max_images = max_images
|
223 |
+
self.logger_log_images = {
|
224 |
+
pl.loggers.WandbLogger: self._wandb,
|
225 |
+
pl.loggers.TestTubeLogger: self._testtube,
|
226 |
+
}
|
227 |
+
self.log_steps = [2 ** n for n in range(int(np.log2(self.batch_freq)) + 1)]
|
228 |
+
if not increase_log_steps:
|
229 |
+
self.log_steps = [self.batch_freq]
|
230 |
+
self.clamp = clamp
|
231 |
+
|
232 |
+
@rank_zero_only
|
233 |
+
def _wandb(self, pl_module, images, batch_idx, split):
|
234 |
+
raise ValueError("No way wandb")
|
235 |
+
grids = dict()
|
236 |
+
for k in images:
|
237 |
+
grid = torchvision.utils.make_grid(images[k])
|
238 |
+
grids[f"{split}/{k}"] = wandb.Image(grid)
|
239 |
+
pl_module.logger.experiment.log(grids)
|
240 |
+
|
241 |
+
@rank_zero_only
|
242 |
+
def _testtube(self, pl_module, images, batch_idx, split):
|
243 |
+
for k in images:
|
244 |
+
grid = torchvision.utils.make_grid(images[k])
|
245 |
+
grid = (grid+1.0)/2.0 # -1,1 -> 0,1; c,h,w
|
246 |
+
|
247 |
+
tag = f"{split}/{k}"
|
248 |
+
pl_module.logger.experiment.add_image(
|
249 |
+
tag, grid,
|
250 |
+
global_step=pl_module.global_step)
|
251 |
+
|
252 |
+
@rank_zero_only
|
253 |
+
def log_local(self, save_dir, split, images,
|
254 |
+
global_step, current_epoch, batch_idx):
|
255 |
+
root = os.path.join(save_dir, "images", split)
|
256 |
+
for k in images:
|
257 |
+
grid = torchvision.utils.make_grid(images[k], nrow=4)
|
258 |
+
|
259 |
+
grid = (grid+1.0)/2.0 # -1,1 -> 0,1; c,h,w
|
260 |
+
grid = grid.transpose(0,1).transpose(1,2).squeeze(-1)
|
261 |
+
grid = grid.numpy()
|
262 |
+
grid = (grid*255).astype(np.uint8)
|
263 |
+
filename = "{}_gs-{:06}_e-{:06}_b-{:06}.png".format(
|
264 |
+
k,
|
265 |
+
global_step,
|
266 |
+
current_epoch,
|
267 |
+
batch_idx)
|
268 |
+
path = os.path.join(root, filename)
|
269 |
+
os.makedirs(os.path.split(path)[0], exist_ok=True)
|
270 |
+
Image.fromarray(grid).save(path)
|
271 |
+
|
272 |
+
def log_img(self, pl_module, batch, batch_idx, split="train"):
|
273 |
+
if (self.check_frequency(batch_idx) and # batch_idx % self.batch_freq == 0
|
274 |
+
hasattr(pl_module, "log_images") and
|
275 |
+
callable(pl_module.log_images) and
|
276 |
+
self.max_images > 0):
|
277 |
+
logger = type(pl_module.logger)
|
278 |
+
|
279 |
+
is_train = pl_module.training
|
280 |
+
if is_train:
|
281 |
+
pl_module.eval()
|
282 |
+
|
283 |
+
with torch.no_grad():
|
284 |
+
images = pl_module.log_images(batch, split=split, pl_module=pl_module)
|
285 |
+
|
286 |
+
for k in images:
|
287 |
+
N = min(images[k].shape[0], self.max_images)
|
288 |
+
images[k] = images[k][:N]
|
289 |
+
if isinstance(images[k], torch.Tensor):
|
290 |
+
images[k] = images[k].detach().cpu()
|
291 |
+
if self.clamp:
|
292 |
+
images[k] = torch.clamp(images[k], -1., 1.)
|
293 |
+
|
294 |
+
self.log_local(pl_module.logger.save_dir, split, images,
|
295 |
+
pl_module.global_step, pl_module.current_epoch, batch_idx)
|
296 |
+
|
297 |
+
logger_log_images = self.logger_log_images.get(logger, lambda *args, **kwargs: None)
|
298 |
+
logger_log_images(pl_module, images, pl_module.global_step, split)
|
299 |
+
|
300 |
+
if is_train:
|
301 |
+
pl_module.train()
|
302 |
+
|
303 |
+
def check_frequency(self, batch_idx):
|
304 |
+
if (batch_idx % self.batch_freq) == 0 or (batch_idx in self.log_steps):
|
305 |
+
try:
|
306 |
+
self.log_steps.pop(0)
|
307 |
+
except IndexError:
|
308 |
+
pass
|
309 |
+
return True
|
310 |
+
return False
|
311 |
+
|
312 |
+
def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
|
313 |
+
self.log_img(pl_module, batch, batch_idx, split="train")
|
314 |
+
|
315 |
+
def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
|
316 |
+
self.log_img(pl_module, batch, batch_idx, split="val")
|
317 |
+
|
318 |
+
|
319 |
+
|
320 |
+
if __name__ == "__main__":
|
321 |
+
# custom parser to specify config files, train, test and debug mode,
|
322 |
+
# postfix, resume.
|
323 |
+
# `--key value` arguments are interpreted as arguments to the trainer.
|
324 |
+
# `nested.key=value` arguments are interpreted as config parameters.
|
325 |
+
# configs are merged from left-to-right followed by command line parameters.
|
326 |
+
|
327 |
+
# model:
|
328 |
+
# base_learning_rate: float
|
329 |
+
# target: path to lightning module
|
330 |
+
# params:
|
331 |
+
# key: value
|
332 |
+
# data:
|
333 |
+
# target: main.DataModuleFromConfig
|
334 |
+
# params:
|
335 |
+
# batch_size: int
|
336 |
+
# wrap: bool
|
337 |
+
# train:
|
338 |
+
# target: path to train dataset
|
339 |
+
# params:
|
340 |
+
# key: value
|
341 |
+
# validation:
|
342 |
+
# target: path to validation dataset
|
343 |
+
# params:
|
344 |
+
# key: value
|
345 |
+
# test:
|
346 |
+
# target: path to test dataset
|
347 |
+
# params:
|
348 |
+
# key: value
|
349 |
+
# lightning: (optional, has sane defaults and can be specified on cmdline)
|
350 |
+
# trainer:
|
351 |
+
# additional arguments to trainer
|
352 |
+
# logger:
|
353 |
+
# logger to instantiate
|
354 |
+
# modelcheckpoint:
|
355 |
+
# modelcheckpoint to instantiate
|
356 |
+
# callbacks:
|
357 |
+
# callback1:
|
358 |
+
# target: importpath
|
359 |
+
# params:
|
360 |
+
# key: value
|
361 |
+
|
362 |
+
now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
|
363 |
+
|
364 |
+
# add cwd for convenience and to make classes in this file available when
|
365 |
+
# running as `python main.py`
|
366 |
+
# (in particular `main.DataModuleFromConfig`)
|
367 |
+
sys.path.append(os.getcwd())
|
368 |
+
|
369 |
+
parser = get_parser()
|
370 |
+
parser = Trainer.add_argparse_args(parser)
|
371 |
+
|
372 |
+
opt, unknown = parser.parse_known_args()
|
373 |
+
if opt.name and opt.resume:
|
374 |
+
raise ValueError(
|
375 |
+
"-n/--name and -r/--resume cannot be specified both."
|
376 |
+
"If you want to resume training in a new log folder, "
|
377 |
+
"use -n/--name in combination with --resume_from_checkpoint"
|
378 |
+
)
|
379 |
+
if opt.resume:
|
380 |
+
if not os.path.exists(opt.resume):
|
381 |
+
raise ValueError("Cannot find {}".format(opt.resume))
|
382 |
+
if os.path.isfile(opt.resume):
|
383 |
+
paths = opt.resume.split("/")
|
384 |
+
idx = len(paths)-paths[::-1].index("logs")+1
|
385 |
+
logdir = "/".join(paths[:idx])
|
386 |
+
ckpt = opt.resume
|
387 |
+
else:
|
388 |
+
assert os.path.isdir(opt.resume), opt.resume
|
389 |
+
logdir = opt.resume.rstrip("/")
|
390 |
+
ckpt = os.path.join(logdir, "checkpoints", "last.ckpt")
|
391 |
+
|
392 |
+
opt.resume_from_checkpoint = ckpt
|
393 |
+
base_configs = sorted(glob.glob(os.path.join(logdir, "configs/*.yaml")))
|
394 |
+
opt.base = base_configs+opt.base
|
395 |
+
_tmp = logdir.split("/")
|
396 |
+
nowname = _tmp[_tmp.index("logs")+1]
|
397 |
+
else:
|
398 |
+
if opt.name:
|
399 |
+
name = "_"+opt.name
|
400 |
+
elif opt.base:
|
401 |
+
cfg_fname = os.path.split(opt.base[0])[-1]
|
402 |
+
cfg_name = os.path.splitext(cfg_fname)[0]
|
403 |
+
name = "_"+cfg_name
|
404 |
+
else:
|
405 |
+
name = ""
|
406 |
+
nowname = now+name+opt.postfix
|
407 |
+
logdir = os.path.join("logs", nowname)
|
408 |
+
|
409 |
+
ckptdir = os.path.join(logdir, "checkpoints")
|
410 |
+
cfgdir = os.path.join(logdir, "configs")
|
411 |
+
seed_everything(opt.seed)
|
412 |
+
|
413 |
+
try:
|
414 |
+
# init and save configs
|
415 |
+
configs = [OmegaConf.load(cfg) for cfg in opt.base]
|
416 |
+
cli = OmegaConf.from_dotlist(unknown)
|
417 |
+
config = OmegaConf.merge(*configs, cli)
|
418 |
+
lightning_config = config.pop("lightning", OmegaConf.create())
|
419 |
+
# merge trainer cli with config
|
420 |
+
trainer_config = lightning_config.get("trainer", OmegaConf.create())
|
421 |
+
# default to ddp
|
422 |
+
trainer_config["distributed_backend"] = "ddp"
|
423 |
+
for k in nondefault_trainer_args(opt):
|
424 |
+
trainer_config[k] = getattr(opt, k)
|
425 |
+
if not "gpus" in trainer_config:
|
426 |
+
del trainer_config["distributed_backend"]
|
427 |
+
cpu = True
|
428 |
+
else:
|
429 |
+
gpuinfo = trainer_config["gpus"]
|
430 |
+
print(f"Running on GPUs {gpuinfo}")
|
431 |
+
cpu = False
|
432 |
+
trainer_opt = argparse.Namespace(**trainer_config)
|
433 |
+
lightning_config.trainer = trainer_config
|
434 |
+
|
435 |
+
# model
|
436 |
+
model = instantiate_from_config(config.model)
|
437 |
+
|
438 |
+
# trainer and callbacks
|
439 |
+
trainer_kwargs = dict()
|
440 |
+
|
441 |
+
# default logger configs
|
442 |
+
# NOTE wandb < 0.10.0 interferes with shutdown
|
443 |
+
# wandb >= 0.10.0 seems to fix it but still interferes with pudb
|
444 |
+
# debugging (wrongly sized pudb ui)
|
445 |
+
# thus prefer testtube for now
|
446 |
+
default_logger_cfgs = {
|
447 |
+
"wandb": {
|
448 |
+
"target": "pytorch_lightning.loggers.WandbLogger",
|
449 |
+
"params": {
|
450 |
+
"name": nowname,
|
451 |
+
"save_dir": logdir,
|
452 |
+
"offline": opt.debug,
|
453 |
+
"id": nowname,
|
454 |
+
}
|
455 |
+
},
|
456 |
+
"testtube": {
|
457 |
+
"target": "pytorch_lightning.loggers.TestTubeLogger",
|
458 |
+
"params": {
|
459 |
+
"name": "testtube",
|
460 |
+
"save_dir": logdir,
|
461 |
+
}
|
462 |
+
},
|
463 |
+
}
|
464 |
+
default_logger_cfg = default_logger_cfgs["testtube"]
|
465 |
+
logger_cfg = lightning_config.logger or OmegaConf.create()
|
466 |
+
logger_cfg = OmegaConf.merge(default_logger_cfg, logger_cfg)
|
467 |
+
trainer_kwargs["logger"] = instantiate_from_config(logger_cfg)
|
468 |
+
|
469 |
+
# modelcheckpoint - use TrainResult/EvalResult(checkpoint_on=metric) to
|
470 |
+
# specify which metric is used to determine best models
|
471 |
+
default_modelckpt_cfg = {
|
472 |
+
"target": "pytorch_lightning.callbacks.ModelCheckpoint",
|
473 |
+
"params": {
|
474 |
+
"dirpath": ckptdir,
|
475 |
+
"filename": "{epoch:06}",
|
476 |
+
"verbose": True,
|
477 |
+
"save_last": True,
|
478 |
+
}
|
479 |
+
}
|
480 |
+
if hasattr(model, "monitor"):
|
481 |
+
print(f"Monitoring {model.monitor} as checkpoint metric.")
|
482 |
+
default_modelckpt_cfg["params"]["monitor"] = model.monitor
|
483 |
+
default_modelckpt_cfg["params"]["save_top_k"] = 3
|
484 |
+
|
485 |
+
modelckpt_cfg = lightning_config.modelcheckpoint or OmegaConf.create()
|
486 |
+
modelckpt_cfg = OmegaConf.merge(default_modelckpt_cfg, modelckpt_cfg)
|
487 |
+
trainer_kwargs["checkpoint_callback"] = instantiate_from_config(modelckpt_cfg)
|
488 |
+
|
489 |
+
# add callback which sets up log directory
|
490 |
+
default_callbacks_cfg = {
|
491 |
+
"setup_callback": {
|
492 |
+
"target": "main.SetupCallback",
|
493 |
+
"params": {
|
494 |
+
"resume": opt.resume,
|
495 |
+
"now": now,
|
496 |
+
"logdir": logdir,
|
497 |
+
"ckptdir": ckptdir,
|
498 |
+
"cfgdir": cfgdir,
|
499 |
+
"config": config,
|
500 |
+
"lightning_config": lightning_config,
|
501 |
+
}
|
502 |
+
},
|
503 |
+
"image_logger": {
|
504 |
+
"target": "main.ImageLogger",
|
505 |
+
"params": {
|
506 |
+
"batch_frequency": 750,
|
507 |
+
"max_images": 4,
|
508 |
+
"clamp": True
|
509 |
+
}
|
510 |
+
},
|
511 |
+
"learning_rate_logger": {
|
512 |
+
"target": "main.LearningRateMonitor",
|
513 |
+
"params": {
|
514 |
+
"logging_interval": "step",
|
515 |
+
#"log_momentum": True
|
516 |
+
}
|
517 |
+
},
|
518 |
+
}
|
519 |
+
callbacks_cfg = lightning_config.callbacks or OmegaConf.create()
|
520 |
+
callbacks_cfg = OmegaConf.merge(default_callbacks_cfg, callbacks_cfg)
|
521 |
+
trainer_kwargs["callbacks"] = [instantiate_from_config(callbacks_cfg[k]) for k in callbacks_cfg]
|
522 |
+
|
523 |
+
trainer = Trainer.from_argparse_args(trainer_opt, **trainer_kwargs)
|
524 |
+
|
525 |
+
# data
|
526 |
+
data = instantiate_from_config(config.data)
|
527 |
+
# NOTE according to https://pytorch-lightning.readthedocs.io/en/latest/datamodules.html
|
528 |
+
# calling these ourselves should not be necessary but it is.
|
529 |
+
# lightning still takes care of proper multiprocessing though
|
530 |
+
data.prepare_data()
|
531 |
+
data.setup()
|
532 |
+
|
533 |
+
# configure learning rate
|
534 |
+
bs, base_lr = config.data.params.batch_size, config.model.base_learning_rate
|
535 |
+
if not cpu:
|
536 |
+
ngpu = len(lightning_config.trainer.gpus.strip(",").split(','))
|
537 |
+
else:
|
538 |
+
ngpu = 1
|
539 |
+
accumulate_grad_batches = lightning_config.trainer.accumulate_grad_batches or 1
|
540 |
+
print(f"accumulate_grad_batches = {accumulate_grad_batches}")
|
541 |
+
lightning_config.trainer.accumulate_grad_batches = accumulate_grad_batches
|
542 |
+
model.learning_rate = accumulate_grad_batches * ngpu * bs * base_lr
|
543 |
+
print("Setting learning rate to {:.2e} = {} (accumulate_grad_batches) * {} (num_gpus) * {} (batchsize) * {:.2e} (base_lr)".format(
|
544 |
+
model.learning_rate, accumulate_grad_batches, ngpu, bs, base_lr))
|
545 |
+
|
546 |
+
# allow checkpointing via USR1
|
547 |
+
def melk(*args, **kwargs):
|
548 |
+
# run all checkpoint hooks
|
549 |
+
if trainer.global_rank == 0:
|
550 |
+
print("Summoning checkpoint.")
|
551 |
+
ckpt_path = os.path.join(ckptdir, "last.ckpt")
|
552 |
+
trainer.save_checkpoint(ckpt_path)
|
553 |
+
|
554 |
+
def divein(*args, **kwargs):
|
555 |
+
if trainer.global_rank == 0:
|
556 |
+
import pudb; pudb.set_trace()
|
557 |
+
|
558 |
+
import signal
|
559 |
+
signal.signal(signal.SIGUSR1, melk)
|
560 |
+
signal.signal(signal.SIGUSR2, divein)
|
561 |
+
|
562 |
+
# run
|
563 |
+
if opt.train:
|
564 |
+
try:
|
565 |
+
trainer.fit(model, data)
|
566 |
+
except Exception:
|
567 |
+
melk()
|
568 |
+
raise
|
569 |
+
if not opt.no_test and not trainer.interrupted:
|
570 |
+
trainer.test(model, data)
|
571 |
+
except Exception:
|
572 |
+
if opt.debug and trainer.global_rank==0:
|
573 |
+
try:
|
574 |
+
import pudb as debugger
|
575 |
+
except ImportError:
|
576 |
+
import pdb as debugger
|
577 |
+
debugger.post_mortem()
|
578 |
+
raise
|
579 |
+
finally:
|
580 |
+
# move newly created debug project to debug_runs
|
581 |
+
if opt.debug and not opt.resume and trainer.global_rank==0:
|
582 |
+
dst, name = os.path.split(logdir)
|
583 |
+
dst = os.path.join(dst, "debug_runs", name)
|
584 |
+
os.makedirs(os.path.split(dst)[0], exist_ok=True)
|
585 |
+
os.rename(logdir, dst)
|
taming-transformers/scripts/extract_depth.py
ADDED
@@ -0,0 +1,112 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import torch
|
3 |
+
import numpy as np
|
4 |
+
from tqdm import trange
|
5 |
+
from PIL import Image
|
6 |
+
|
7 |
+
|
8 |
+
def get_state(gpu):
|
9 |
+
import torch
|
10 |
+
midas = torch.hub.load("intel-isl/MiDaS", "MiDaS")
|
11 |
+
if gpu:
|
12 |
+
midas.cuda()
|
13 |
+
midas.eval()
|
14 |
+
|
15 |
+
midas_transforms = torch.hub.load("intel-isl/MiDaS", "transforms")
|
16 |
+
transform = midas_transforms.default_transform
|
17 |
+
|
18 |
+
state = {"model": midas,
|
19 |
+
"transform": transform}
|
20 |
+
return state
|
21 |
+
|
22 |
+
|
23 |
+
def depth_to_rgba(x):
|
24 |
+
assert x.dtype == np.float32
|
25 |
+
assert len(x.shape) == 2
|
26 |
+
y = x.copy()
|
27 |
+
y.dtype = np.uint8
|
28 |
+
y = y.reshape(x.shape+(4,))
|
29 |
+
return np.ascontiguousarray(y)
|
30 |
+
|
31 |
+
|
32 |
+
def rgba_to_depth(x):
|
33 |
+
assert x.dtype == np.uint8
|
34 |
+
assert len(x.shape) == 3 and x.shape[2] == 4
|
35 |
+
y = x.copy()
|
36 |
+
y.dtype = np.float32
|
37 |
+
y = y.reshape(x.shape[:2])
|
38 |
+
return np.ascontiguousarray(y)
|
39 |
+
|
40 |
+
|
41 |
+
def run(x, state):
|
42 |
+
model = state["model"]
|
43 |
+
transform = state["transform"]
|
44 |
+
hw = x.shape[:2]
|
45 |
+
with torch.no_grad():
|
46 |
+
prediction = model(transform((x + 1.0) * 127.5).cuda())
|
47 |
+
prediction = torch.nn.functional.interpolate(
|
48 |
+
prediction.unsqueeze(1),
|
49 |
+
size=hw,
|
50 |
+
mode="bicubic",
|
51 |
+
align_corners=False,
|
52 |
+
).squeeze()
|
53 |
+
output = prediction.cpu().numpy()
|
54 |
+
return output
|
55 |
+
|
56 |
+
|
57 |
+
def get_filename(relpath, level=-2):
|
58 |
+
# save class folder structure and filename:
|
59 |
+
fn = relpath.split(os.sep)[level:]
|
60 |
+
folder = fn[-2]
|
61 |
+
file = fn[-1].split('.')[0]
|
62 |
+
return folder, file
|
63 |
+
|
64 |
+
|
65 |
+
def save_depth(dataset, path, debug=False):
|
66 |
+
os.makedirs(path)
|
67 |
+
N = len(dset)
|
68 |
+
if debug:
|
69 |
+
N = 10
|
70 |
+
state = get_state(gpu=True)
|
71 |
+
for idx in trange(N, desc="Data"):
|
72 |
+
ex = dataset[idx]
|
73 |
+
image, relpath = ex["image"], ex["relpath"]
|
74 |
+
folder, filename = get_filename(relpath)
|
75 |
+
# prepare
|
76 |
+
folderabspath = os.path.join(path, folder)
|
77 |
+
os.makedirs(folderabspath, exist_ok=True)
|
78 |
+
savepath = os.path.join(folderabspath, filename)
|
79 |
+
# run model
|
80 |
+
xout = run(image, state)
|
81 |
+
I = depth_to_rgba(xout)
|
82 |
+
Image.fromarray(I).save("{}.png".format(savepath))
|
83 |
+
|
84 |
+
|
85 |
+
if __name__ == "__main__":
|
86 |
+
from taming.data.imagenet import ImageNetTrain, ImageNetValidation
|
87 |
+
out = "data/imagenet_depth"
|
88 |
+
if not os.path.exists(out):
|
89 |
+
print("Please create a folder or symlink '{}' to extract depth data ".format(out) +
|
90 |
+
"(be prepared that the output size will be larger than ImageNet itself).")
|
91 |
+
exit(1)
|
92 |
+
|
93 |
+
# go
|
94 |
+
dset = ImageNetValidation()
|
95 |
+
abspath = os.path.join(out, "val")
|
96 |
+
if os.path.exists(abspath):
|
97 |
+
print("{} exists - not doing anything.".format(abspath))
|
98 |
+
else:
|
99 |
+
print("preparing {}".format(abspath))
|
100 |
+
save_depth(dset, abspath)
|
101 |
+
print("done with validation split")
|
102 |
+
|
103 |
+
dset = ImageNetTrain()
|
104 |
+
abspath = os.path.join(out, "train")
|
105 |
+
if os.path.exists(abspath):
|
106 |
+
print("{} exists - not doing anything.".format(abspath))
|
107 |
+
else:
|
108 |
+
print("preparing {}".format(abspath))
|
109 |
+
save_depth(dset, abspath)
|
110 |
+
print("done with train split")
|
111 |
+
|
112 |
+
print("done done.")
|
taming-transformers/scripts/extract_segmentation.py
ADDED
@@ -0,0 +1,130 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import sys, os
|
2 |
+
import numpy as np
|
3 |
+
import scipy
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
from scipy import ndimage
|
7 |
+
from tqdm import tqdm, trange
|
8 |
+
from PIL import Image
|
9 |
+
import torch.hub
|
10 |
+
import torchvision
|
11 |
+
import torch.nn.functional as F
|
12 |
+
|
13 |
+
# download deeplabv2_resnet101_msc-cocostuff164k-100000.pth from
|
14 |
+
# https://github.com/kazuto1011/deeplab-pytorch/releases/download/v1.0/deeplabv2_resnet101_msc-cocostuff164k-100000.pth
|
15 |
+
# and put the path here
|
16 |
+
CKPT_PATH = "TODO"
|
17 |
+
|
18 |
+
rescale = lambda x: (x + 1.) / 2.
|
19 |
+
|
20 |
+
def rescale_bgr(x):
|
21 |
+
x = (x+1)*127.5
|
22 |
+
x = torch.flip(x, dims=[0])
|
23 |
+
return x
|
24 |
+
|
25 |
+
|
26 |
+
class COCOStuffSegmenter(nn.Module):
|
27 |
+
def __init__(self, config):
|
28 |
+
super().__init__()
|
29 |
+
self.config = config
|
30 |
+
self.n_labels = 182
|
31 |
+
model = torch.hub.load("kazuto1011/deeplab-pytorch", "deeplabv2_resnet101", n_classes=self.n_labels)
|
32 |
+
ckpt_path = CKPT_PATH
|
33 |
+
model.load_state_dict(torch.load(ckpt_path))
|
34 |
+
self.model = model
|
35 |
+
|
36 |
+
normalize = torchvision.transforms.Normalize(mean=self.mean, std=self.std)
|
37 |
+
self.image_transform = torchvision.transforms.Compose([
|
38 |
+
torchvision.transforms.Lambda(lambda image: torch.stack(
|
39 |
+
[normalize(rescale_bgr(x)) for x in image]))
|
40 |
+
])
|
41 |
+
|
42 |
+
def forward(self, x, upsample=None):
|
43 |
+
x = self._pre_process(x)
|
44 |
+
x = self.model(x)
|
45 |
+
if upsample is not None:
|
46 |
+
x = torch.nn.functional.upsample_bilinear(x, size=upsample)
|
47 |
+
return x
|
48 |
+
|
49 |
+
def _pre_process(self, x):
|
50 |
+
x = self.image_transform(x)
|
51 |
+
return x
|
52 |
+
|
53 |
+
@property
|
54 |
+
def mean(self):
|
55 |
+
# bgr
|
56 |
+
return [104.008, 116.669, 122.675]
|
57 |
+
|
58 |
+
@property
|
59 |
+
def std(self):
|
60 |
+
return [1.0, 1.0, 1.0]
|
61 |
+
|
62 |
+
@property
|
63 |
+
def input_size(self):
|
64 |
+
return [3, 224, 224]
|
65 |
+
|
66 |
+
|
67 |
+
def run_model(img, model):
|
68 |
+
model = model.eval()
|
69 |
+
with torch.no_grad():
|
70 |
+
segmentation = model(img, upsample=(img.shape[2], img.shape[3]))
|
71 |
+
segmentation = torch.argmax(segmentation, dim=1, keepdim=True)
|
72 |
+
return segmentation.detach().cpu()
|
73 |
+
|
74 |
+
|
75 |
+
def get_input(batch, k):
|
76 |
+
x = batch[k]
|
77 |
+
if len(x.shape) == 3:
|
78 |
+
x = x[..., None]
|
79 |
+
x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format)
|
80 |
+
return x.float()
|
81 |
+
|
82 |
+
|
83 |
+
def save_segmentation(segmentation, path):
|
84 |
+
# --> class label to uint8, save as png
|
85 |
+
os.makedirs(os.path.dirname(path), exist_ok=True)
|
86 |
+
assert len(segmentation.shape)==4
|
87 |
+
assert segmentation.shape[0]==1
|
88 |
+
for seg in segmentation:
|
89 |
+
seg = seg.permute(1,2,0).numpy().squeeze().astype(np.uint8)
|
90 |
+
seg = Image.fromarray(seg)
|
91 |
+
seg.save(path)
|
92 |
+
|
93 |
+
|
94 |
+
def iterate_dataset(dataloader, destpath, model):
|
95 |
+
os.makedirs(destpath, exist_ok=True)
|
96 |
+
num_processed = 0
|
97 |
+
for i, batch in tqdm(enumerate(dataloader), desc="Data"):
|
98 |
+
try:
|
99 |
+
img = get_input(batch, "image")
|
100 |
+
img = img.cuda()
|
101 |
+
seg = run_model(img, model)
|
102 |
+
|
103 |
+
path = batch["relative_file_path_"][0]
|
104 |
+
path = os.path.splitext(path)[0]
|
105 |
+
|
106 |
+
path = os.path.join(destpath, path + ".png")
|
107 |
+
save_segmentation(seg, path)
|
108 |
+
num_processed += 1
|
109 |
+
except Exception as e:
|
110 |
+
print(e)
|
111 |
+
print("but anyhow..")
|
112 |
+
|
113 |
+
print("Processed {} files. Bye.".format(num_processed))
|
114 |
+
|
115 |
+
|
116 |
+
from taming.data.sflckr import Examples
|
117 |
+
from torch.utils.data import DataLoader
|
118 |
+
|
119 |
+
if __name__ == "__main__":
|
120 |
+
dest = sys.argv[1]
|
121 |
+
batchsize = 1
|
122 |
+
print("Running with batch-size {}, saving to {}...".format(batchsize, dest))
|
123 |
+
|
124 |
+
model = COCOStuffSegmenter({}).cuda()
|
125 |
+
print("Instantiated model.")
|
126 |
+
|
127 |
+
dataset = Examples()
|
128 |
+
dloader = DataLoader(dataset, batch_size=batchsize)
|
129 |
+
iterate_dataset(dataloader=dloader, destpath=dest, model=model)
|
130 |
+
print("done.")
|
taming-transformers/scripts/extract_submodel.py
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import sys
|
3 |
+
|
4 |
+
if __name__ == "__main__":
|
5 |
+
inpath = sys.argv[1]
|
6 |
+
outpath = sys.argv[2]
|
7 |
+
submodel = "cond_stage_model"
|
8 |
+
if len(sys.argv) > 3:
|
9 |
+
submodel = sys.argv[3]
|
10 |
+
|
11 |
+
print("Extracting {} from {} to {}.".format(submodel, inpath, outpath))
|
12 |
+
|
13 |
+
sd = torch.load(inpath, map_location="cpu")
|
14 |
+
new_sd = {"state_dict": dict((k.split(".", 1)[-1],v)
|
15 |
+
for k,v in sd["state_dict"].items()
|
16 |
+
if k.startswith("cond_stage_model"))}
|
17 |
+
torch.save(new_sd, outpath)
|
taming-transformers/scripts/make_samples.py
ADDED
@@ -0,0 +1,292 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse, os, sys, glob, math, time
|
2 |
+
import torch
|
3 |
+
import numpy as np
|
4 |
+
from omegaconf import OmegaConf
|
5 |
+
from PIL import Image
|
6 |
+
from main import instantiate_from_config, DataModuleFromConfig
|
7 |
+
from torch.utils.data import DataLoader
|
8 |
+
from torch.utils.data.dataloader import default_collate
|
9 |
+
from tqdm import trange
|
10 |
+
|
11 |
+
|
12 |
+
def save_image(x, path):
|
13 |
+
c,h,w = x.shape
|
14 |
+
assert c==3
|
15 |
+
x = ((x.detach().cpu().numpy().transpose(1,2,0)+1.0)*127.5).clip(0,255).astype(np.uint8)
|
16 |
+
Image.fromarray(x).save(path)
|
17 |
+
|
18 |
+
|
19 |
+
@torch.no_grad()
|
20 |
+
def run_conditional(model, dsets, outdir, top_k, temperature, batch_size=1):
|
21 |
+
if len(dsets.datasets) > 1:
|
22 |
+
split = sorted(dsets.datasets.keys())[0]
|
23 |
+
dset = dsets.datasets[split]
|
24 |
+
else:
|
25 |
+
dset = next(iter(dsets.datasets.values()))
|
26 |
+
print("Dataset: ", dset.__class__.__name__)
|
27 |
+
for start_idx in trange(0,len(dset)-batch_size+1,batch_size):
|
28 |
+
indices = list(range(start_idx, start_idx+batch_size))
|
29 |
+
example = default_collate([dset[i] for i in indices])
|
30 |
+
|
31 |
+
x = model.get_input("image", example).to(model.device)
|
32 |
+
for i in range(x.shape[0]):
|
33 |
+
save_image(x[i], os.path.join(outdir, "originals",
|
34 |
+
"{:06}.png".format(indices[i])))
|
35 |
+
|
36 |
+
cond_key = model.cond_stage_key
|
37 |
+
c = model.get_input(cond_key, example).to(model.device)
|
38 |
+
|
39 |
+
scale_factor = 1.0
|
40 |
+
quant_z, z_indices = model.encode_to_z(x)
|
41 |
+
quant_c, c_indices = model.encode_to_c(c)
|
42 |
+
|
43 |
+
cshape = quant_z.shape
|
44 |
+
|
45 |
+
xrec = model.first_stage_model.decode(quant_z)
|
46 |
+
for i in range(xrec.shape[0]):
|
47 |
+
save_image(xrec[i], os.path.join(outdir, "reconstructions",
|
48 |
+
"{:06}.png".format(indices[i])))
|
49 |
+
|
50 |
+
if cond_key == "segmentation":
|
51 |
+
# get image from segmentation mask
|
52 |
+
num_classes = c.shape[1]
|
53 |
+
c = torch.argmax(c, dim=1, keepdim=True)
|
54 |
+
c = torch.nn.functional.one_hot(c, num_classes=num_classes)
|
55 |
+
c = c.squeeze(1).permute(0, 3, 1, 2).float()
|
56 |
+
c = model.cond_stage_model.to_rgb(c)
|
57 |
+
|
58 |
+
idx = z_indices
|
59 |
+
|
60 |
+
half_sample = False
|
61 |
+
if half_sample:
|
62 |
+
start = idx.shape[1]//2
|
63 |
+
else:
|
64 |
+
start = 0
|
65 |
+
|
66 |
+
idx[:,start:] = 0
|
67 |
+
idx = idx.reshape(cshape[0],cshape[2],cshape[3])
|
68 |
+
start_i = start//cshape[3]
|
69 |
+
start_j = start %cshape[3]
|
70 |
+
|
71 |
+
cidx = c_indices
|
72 |
+
cidx = cidx.reshape(quant_c.shape[0],quant_c.shape[2],quant_c.shape[3])
|
73 |
+
|
74 |
+
sample = True
|
75 |
+
|
76 |
+
for i in range(start_i,cshape[2]-0):
|
77 |
+
if i <= 8:
|
78 |
+
local_i = i
|
79 |
+
elif cshape[2]-i < 8:
|
80 |
+
local_i = 16-(cshape[2]-i)
|
81 |
+
else:
|
82 |
+
local_i = 8
|
83 |
+
for j in range(start_j,cshape[3]-0):
|
84 |
+
if j <= 8:
|
85 |
+
local_j = j
|
86 |
+
elif cshape[3]-j < 8:
|
87 |
+
local_j = 16-(cshape[3]-j)
|
88 |
+
else:
|
89 |
+
local_j = 8
|
90 |
+
|
91 |
+
i_start = i-local_i
|
92 |
+
i_end = i_start+16
|
93 |
+
j_start = j-local_j
|
94 |
+
j_end = j_start+16
|
95 |
+
patch = idx[:,i_start:i_end,j_start:j_end]
|
96 |
+
patch = patch.reshape(patch.shape[0],-1)
|
97 |
+
cpatch = cidx[:, i_start:i_end, j_start:j_end]
|
98 |
+
cpatch = cpatch.reshape(cpatch.shape[0], -1)
|
99 |
+
patch = torch.cat((cpatch, patch), dim=1)
|
100 |
+
logits,_ = model.transformer(patch[:,:-1])
|
101 |
+
logits = logits[:, -256:, :]
|
102 |
+
logits = logits.reshape(cshape[0],16,16,-1)
|
103 |
+
logits = logits[:,local_i,local_j,:]
|
104 |
+
|
105 |
+
logits = logits/temperature
|
106 |
+
|
107 |
+
if top_k is not None:
|
108 |
+
logits = model.top_k_logits(logits, top_k)
|
109 |
+
# apply softmax to convert to probabilities
|
110 |
+
probs = torch.nn.functional.softmax(logits, dim=-1)
|
111 |
+
# sample from the distribution or take the most likely
|
112 |
+
if sample:
|
113 |
+
ix = torch.multinomial(probs, num_samples=1)
|
114 |
+
else:
|
115 |
+
_, ix = torch.topk(probs, k=1, dim=-1)
|
116 |
+
idx[:,i,j] = ix
|
117 |
+
|
118 |
+
xsample = model.decode_to_img(idx[:,:cshape[2],:cshape[3]], cshape)
|
119 |
+
for i in range(xsample.shape[0]):
|
120 |
+
save_image(xsample[i], os.path.join(outdir, "samples",
|
121 |
+
"{:06}.png".format(indices[i])))
|
122 |
+
|
123 |
+
|
124 |
+
def get_parser():
|
125 |
+
parser = argparse.ArgumentParser()
|
126 |
+
parser.add_argument(
|
127 |
+
"-r",
|
128 |
+
"--resume",
|
129 |
+
type=str,
|
130 |
+
nargs="?",
|
131 |
+
help="load from logdir or checkpoint in logdir",
|
132 |
+
)
|
133 |
+
parser.add_argument(
|
134 |
+
"-b",
|
135 |
+
"--base",
|
136 |
+
nargs="*",
|
137 |
+
metavar="base_config.yaml",
|
138 |
+
help="paths to base configs. Loaded from left-to-right. "
|
139 |
+
"Parameters can be overwritten or added with command-line options of the form `--key value`.",
|
140 |
+
default=list(),
|
141 |
+
)
|
142 |
+
parser.add_argument(
|
143 |
+
"-c",
|
144 |
+
"--config",
|
145 |
+
nargs="?",
|
146 |
+
metavar="single_config.yaml",
|
147 |
+
help="path to single config. If specified, base configs will be ignored "
|
148 |
+
"(except for the last one if left unspecified).",
|
149 |
+
const=True,
|
150 |
+
default="",
|
151 |
+
)
|
152 |
+
parser.add_argument(
|
153 |
+
"--ignore_base_data",
|
154 |
+
action="store_true",
|
155 |
+
help="Ignore data specification from base configs. Useful if you want "
|
156 |
+
"to specify a custom datasets on the command line.",
|
157 |
+
)
|
158 |
+
parser.add_argument(
|
159 |
+
"--outdir",
|
160 |
+
required=True,
|
161 |
+
type=str,
|
162 |
+
help="Where to write outputs to.",
|
163 |
+
)
|
164 |
+
parser.add_argument(
|
165 |
+
"--top_k",
|
166 |
+
type=int,
|
167 |
+
default=100,
|
168 |
+
help="Sample from among top-k predictions.",
|
169 |
+
)
|
170 |
+
parser.add_argument(
|
171 |
+
"--temperature",
|
172 |
+
type=float,
|
173 |
+
default=1.0,
|
174 |
+
help="Sampling temperature.",
|
175 |
+
)
|
176 |
+
return parser
|
177 |
+
|
178 |
+
|
179 |
+
def load_model_from_config(config, sd, gpu=True, eval_mode=True):
|
180 |
+
if "ckpt_path" in config.params:
|
181 |
+
print("Deleting the restore-ckpt path from the config...")
|
182 |
+
config.params.ckpt_path = None
|
183 |
+
if "downsample_cond_size" in config.params:
|
184 |
+
print("Deleting downsample-cond-size from the config and setting factor=0.5 instead...")
|
185 |
+
config.params.downsample_cond_size = -1
|
186 |
+
config.params["downsample_cond_factor"] = 0.5
|
187 |
+
try:
|
188 |
+
if "ckpt_path" in config.params.first_stage_config.params:
|
189 |
+
config.params.first_stage_config.params.ckpt_path = None
|
190 |
+
print("Deleting the first-stage restore-ckpt path from the config...")
|
191 |
+
if "ckpt_path" in config.params.cond_stage_config.params:
|
192 |
+
config.params.cond_stage_config.params.ckpt_path = None
|
193 |
+
print("Deleting the cond-stage restore-ckpt path from the config...")
|
194 |
+
except:
|
195 |
+
pass
|
196 |
+
|
197 |
+
model = instantiate_from_config(config)
|
198 |
+
if sd is not None:
|
199 |
+
missing, unexpected = model.load_state_dict(sd, strict=False)
|
200 |
+
print(f"Missing Keys in State Dict: {missing}")
|
201 |
+
print(f"Unexpected Keys in State Dict: {unexpected}")
|
202 |
+
if gpu:
|
203 |
+
model.cuda()
|
204 |
+
if eval_mode:
|
205 |
+
model.eval()
|
206 |
+
return {"model": model}
|
207 |
+
|
208 |
+
|
209 |
+
def get_data(config):
|
210 |
+
# get data
|
211 |
+
data = instantiate_from_config(config.data)
|
212 |
+
data.prepare_data()
|
213 |
+
data.setup()
|
214 |
+
return data
|
215 |
+
|
216 |
+
|
217 |
+
def load_model_and_dset(config, ckpt, gpu, eval_mode):
|
218 |
+
# get data
|
219 |
+
dsets = get_data(config) # calls data.config ...
|
220 |
+
|
221 |
+
# now load the specified checkpoint
|
222 |
+
if ckpt:
|
223 |
+
pl_sd = torch.load(ckpt, map_location="cpu")
|
224 |
+
global_step = pl_sd["global_step"]
|
225 |
+
else:
|
226 |
+
pl_sd = {"state_dict": None}
|
227 |
+
global_step = None
|
228 |
+
model = load_model_from_config(config.model,
|
229 |
+
pl_sd["state_dict"],
|
230 |
+
gpu=gpu,
|
231 |
+
eval_mode=eval_mode)["model"]
|
232 |
+
return dsets, model, global_step
|
233 |
+
|
234 |
+
|
235 |
+
if __name__ == "__main__":
|
236 |
+
sys.path.append(os.getcwd())
|
237 |
+
|
238 |
+
parser = get_parser()
|
239 |
+
|
240 |
+
opt, unknown = parser.parse_known_args()
|
241 |
+
|
242 |
+
ckpt = None
|
243 |
+
if opt.resume:
|
244 |
+
if not os.path.exists(opt.resume):
|
245 |
+
raise ValueError("Cannot find {}".format(opt.resume))
|
246 |
+
if os.path.isfile(opt.resume):
|
247 |
+
paths = opt.resume.split("/")
|
248 |
+
try:
|
249 |
+
idx = len(paths)-paths[::-1].index("logs")+1
|
250 |
+
except ValueError:
|
251 |
+
idx = -2 # take a guess: path/to/logdir/checkpoints/model.ckpt
|
252 |
+
logdir = "/".join(paths[:idx])
|
253 |
+
ckpt = opt.resume
|
254 |
+
else:
|
255 |
+
assert os.path.isdir(opt.resume), opt.resume
|
256 |
+
logdir = opt.resume.rstrip("/")
|
257 |
+
ckpt = os.path.join(logdir, "checkpoints", "last.ckpt")
|
258 |
+
print(f"logdir:{logdir}")
|
259 |
+
base_configs = sorted(glob.glob(os.path.join(logdir, "configs/*-project.yaml")))
|
260 |
+
opt.base = base_configs+opt.base
|
261 |
+
|
262 |
+
if opt.config:
|
263 |
+
if type(opt.config) == str:
|
264 |
+
opt.base = [opt.config]
|
265 |
+
else:
|
266 |
+
opt.base = [opt.base[-1]]
|
267 |
+
|
268 |
+
configs = [OmegaConf.load(cfg) for cfg in opt.base]
|
269 |
+
cli = OmegaConf.from_dotlist(unknown)
|
270 |
+
if opt.ignore_base_data:
|
271 |
+
for config in configs:
|
272 |
+
if hasattr(config, "data"): del config["data"]
|
273 |
+
config = OmegaConf.merge(*configs, cli)
|
274 |
+
|
275 |
+
print(ckpt)
|
276 |
+
gpu = True
|
277 |
+
eval_mode = True
|
278 |
+
show_config = False
|
279 |
+
if show_config:
|
280 |
+
print(OmegaConf.to_container(config))
|
281 |
+
|
282 |
+
dsets, model, global_step = load_model_and_dset(config, ckpt, gpu, eval_mode)
|
283 |
+
print(f"Global step: {global_step}")
|
284 |
+
|
285 |
+
outdir = os.path.join(opt.outdir, "{:06}_{}_{}".format(global_step,
|
286 |
+
opt.top_k,
|
287 |
+
opt.temperature))
|
288 |
+
os.makedirs(outdir, exist_ok=True)
|
289 |
+
print("Writing samples to ", outdir)
|
290 |
+
for k in ["originals", "reconstructions", "samples"]:
|
291 |
+
os.makedirs(os.path.join(outdir, k), exist_ok=True)
|
292 |
+
run_conditional(model, dsets, outdir, opt.top_k, opt.temperature)
|
taming-transformers/scripts/make_scene_samples.py
ADDED
@@ -0,0 +1,198 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import glob
|
2 |
+
import os
|
3 |
+
import sys
|
4 |
+
from itertools import product
|
5 |
+
from pathlib import Path
|
6 |
+
from typing import Literal, List, Optional, Tuple
|
7 |
+
|
8 |
+
import numpy as np
|
9 |
+
import torch
|
10 |
+
from omegaconf import OmegaConf
|
11 |
+
from pytorch_lightning import seed_everything
|
12 |
+
from torch import Tensor
|
13 |
+
from torchvision.utils import save_image
|
14 |
+
from tqdm import tqdm
|
15 |
+
|
16 |
+
from scripts.make_samples import get_parser, load_model_and_dset
|
17 |
+
from taming.data.conditional_builder.objects_center_points import ObjectsCenterPointsConditionalBuilder
|
18 |
+
from taming.data.helper_types import BoundingBox, Annotation
|
19 |
+
from taming.data.annotated_objects_dataset import AnnotatedObjectsDataset
|
20 |
+
from taming.models.cond_transformer import Net2NetTransformer
|
21 |
+
|
22 |
+
seed_everything(42424242)
|
23 |
+
device: Literal['cuda', 'cpu'] = 'cuda'
|
24 |
+
first_stage_factor = 16
|
25 |
+
trained_on_res = 256
|
26 |
+
|
27 |
+
|
28 |
+
def _helper(coord: int, coord_max: int, coord_window: int) -> (int, int):
|
29 |
+
assert 0 <= coord < coord_max
|
30 |
+
coord_desired_center = (coord_window - 1) // 2
|
31 |
+
return np.clip(coord - coord_desired_center, 0, coord_max - coord_window)
|
32 |
+
|
33 |
+
|
34 |
+
def get_crop_coordinates(x: int, y: int) -> BoundingBox:
|
35 |
+
WIDTH, HEIGHT = desired_z_shape[1], desired_z_shape[0]
|
36 |
+
x0 = _helper(x, WIDTH, first_stage_factor) / WIDTH
|
37 |
+
y0 = _helper(y, HEIGHT, first_stage_factor) / HEIGHT
|
38 |
+
w = first_stage_factor / WIDTH
|
39 |
+
h = first_stage_factor / HEIGHT
|
40 |
+
return x0, y0, w, h
|
41 |
+
|
42 |
+
|
43 |
+
def get_z_indices_crop_out(z_indices: Tensor, predict_x: int, predict_y: int) -> Tensor:
|
44 |
+
WIDTH, HEIGHT = desired_z_shape[1], desired_z_shape[0]
|
45 |
+
x0 = _helper(predict_x, WIDTH, first_stage_factor)
|
46 |
+
y0 = _helper(predict_y, HEIGHT, first_stage_factor)
|
47 |
+
no_images = z_indices.shape[0]
|
48 |
+
cut_out_1 = z_indices[:, y0:predict_y, x0:x0+first_stage_factor].reshape((no_images, -1))
|
49 |
+
cut_out_2 = z_indices[:, predict_y, x0:predict_x]
|
50 |
+
return torch.cat((cut_out_1, cut_out_2), dim=1)
|
51 |
+
|
52 |
+
|
53 |
+
@torch.no_grad()
|
54 |
+
def sample(model: Net2NetTransformer, annotations: List[Annotation], dataset: AnnotatedObjectsDataset,
|
55 |
+
conditional_builder: ObjectsCenterPointsConditionalBuilder, no_samples: int,
|
56 |
+
temperature: float, top_k: int) -> Tensor:
|
57 |
+
x_max, y_max = desired_z_shape[1], desired_z_shape[0]
|
58 |
+
|
59 |
+
annotations = [a._replace(category_no=dataset.get_category_number(a.category_id)) for a in annotations]
|
60 |
+
|
61 |
+
recompute_conditional = any((desired_resolution[0] > trained_on_res, desired_resolution[1] > trained_on_res))
|
62 |
+
if not recompute_conditional:
|
63 |
+
crop_coordinates = get_crop_coordinates(0, 0)
|
64 |
+
conditional_indices = conditional_builder.build(annotations, crop_coordinates)
|
65 |
+
c_indices = conditional_indices.to(device).repeat(no_samples, 1)
|
66 |
+
z_indices = torch.zeros((no_samples, 0), device=device).long()
|
67 |
+
output_indices = model.sample(z_indices, c_indices, steps=x_max*y_max, temperature=temperature,
|
68 |
+
sample=True, top_k=top_k)
|
69 |
+
else:
|
70 |
+
output_indices = torch.zeros((no_samples, y_max, x_max), device=device).long()
|
71 |
+
for predict_y, predict_x in tqdm(product(range(y_max), range(x_max)), desc='sampling_image', total=x_max*y_max):
|
72 |
+
crop_coordinates = get_crop_coordinates(predict_x, predict_y)
|
73 |
+
z_indices = get_z_indices_crop_out(output_indices, predict_x, predict_y)
|
74 |
+
conditional_indices = conditional_builder.build(annotations, crop_coordinates)
|
75 |
+
c_indices = conditional_indices.to(device).repeat(no_samples, 1)
|
76 |
+
new_index = model.sample(z_indices, c_indices, steps=1, temperature=temperature, sample=True, top_k=top_k)
|
77 |
+
output_indices[:, predict_y, predict_x] = new_index[:, -1]
|
78 |
+
z_shape = (
|
79 |
+
no_samples,
|
80 |
+
model.first_stage_model.quantize.e_dim, # codebook embed_dim
|
81 |
+
desired_z_shape[0], # z_height
|
82 |
+
desired_z_shape[1] # z_width
|
83 |
+
)
|
84 |
+
x_sample = model.decode_to_img(output_indices, z_shape) * 0.5 + 0.5
|
85 |
+
x_sample = x_sample.to('cpu')
|
86 |
+
|
87 |
+
plotter = conditional_builder.plot
|
88 |
+
figure_size = (x_sample.shape[2], x_sample.shape[3])
|
89 |
+
scene_graph = conditional_builder.build(annotations, (0., 0., 1., 1.))
|
90 |
+
plot = plotter(scene_graph, dataset.get_textual_label_for_category_no, figure_size)
|
91 |
+
return torch.cat((x_sample, plot.unsqueeze(0)))
|
92 |
+
|
93 |
+
|
94 |
+
def get_resolution(resolution_str: str) -> (Tuple[int, int], Tuple[int, int]):
|
95 |
+
if not resolution_str.count(',') == 1:
|
96 |
+
raise ValueError("Give resolution as in 'height,width'")
|
97 |
+
res_h, res_w = resolution_str.split(',')
|
98 |
+
res_h = max(int(res_h), trained_on_res)
|
99 |
+
res_w = max(int(res_w), trained_on_res)
|
100 |
+
z_h = int(round(res_h/first_stage_factor))
|
101 |
+
z_w = int(round(res_w/first_stage_factor))
|
102 |
+
return (z_h, z_w), (z_h*first_stage_factor, z_w*first_stage_factor)
|
103 |
+
|
104 |
+
|
105 |
+
def add_arg_to_parser(parser):
|
106 |
+
parser.add_argument(
|
107 |
+
"-R",
|
108 |
+
"--resolution",
|
109 |
+
type=str,
|
110 |
+
default='256,256',
|
111 |
+
help=f"give resolution in multiples of {first_stage_factor}, default is '256,256'",
|
112 |
+
)
|
113 |
+
parser.add_argument(
|
114 |
+
"-C",
|
115 |
+
"--conditional",
|
116 |
+
type=str,
|
117 |
+
default='objects_bbox',
|
118 |
+
help=f"objects_bbox or objects_center_points",
|
119 |
+
)
|
120 |
+
parser.add_argument(
|
121 |
+
"-N",
|
122 |
+
"--n_samples_per_layout",
|
123 |
+
type=int,
|
124 |
+
default=4,
|
125 |
+
help=f"how many samples to generate per layout",
|
126 |
+
)
|
127 |
+
return parser
|
128 |
+
|
129 |
+
|
130 |
+
if __name__ == "__main__":
|
131 |
+
sys.path.append(os.getcwd())
|
132 |
+
|
133 |
+
parser = get_parser()
|
134 |
+
parser = add_arg_to_parser(parser)
|
135 |
+
|
136 |
+
opt, unknown = parser.parse_known_args()
|
137 |
+
|
138 |
+
ckpt = None
|
139 |
+
if opt.resume:
|
140 |
+
if not os.path.exists(opt.resume):
|
141 |
+
raise ValueError("Cannot find {}".format(opt.resume))
|
142 |
+
if os.path.isfile(opt.resume):
|
143 |
+
paths = opt.resume.split("/")
|
144 |
+
try:
|
145 |
+
idx = len(paths)-paths[::-1].index("logs")+1
|
146 |
+
except ValueError:
|
147 |
+
idx = -2 # take a guess: path/to/logdir/checkpoints/model.ckpt
|
148 |
+
logdir = "/".join(paths[:idx])
|
149 |
+
ckpt = opt.resume
|
150 |
+
else:
|
151 |
+
assert os.path.isdir(opt.resume), opt.resume
|
152 |
+
logdir = opt.resume.rstrip("/")
|
153 |
+
ckpt = os.path.join(logdir, "checkpoints", "last.ckpt")
|
154 |
+
print(f"logdir:{logdir}")
|
155 |
+
base_configs = sorted(glob.glob(os.path.join(logdir, "configs/*-project.yaml")))
|
156 |
+
opt.base = base_configs+opt.base
|
157 |
+
|
158 |
+
if opt.config:
|
159 |
+
if type(opt.config) == str:
|
160 |
+
opt.base = [opt.config]
|
161 |
+
else:
|
162 |
+
opt.base = [opt.base[-1]]
|
163 |
+
|
164 |
+
configs = [OmegaConf.load(cfg) for cfg in opt.base]
|
165 |
+
cli = OmegaConf.from_dotlist(unknown)
|
166 |
+
if opt.ignore_base_data:
|
167 |
+
for config in configs:
|
168 |
+
if hasattr(config, "data"):
|
169 |
+
del config["data"]
|
170 |
+
config = OmegaConf.merge(*configs, cli)
|
171 |
+
desired_z_shape, desired_resolution = get_resolution(opt.resolution)
|
172 |
+
conditional = opt.conditional
|
173 |
+
|
174 |
+
print(ckpt)
|
175 |
+
gpu = True
|
176 |
+
eval_mode = True
|
177 |
+
show_config = False
|
178 |
+
if show_config:
|
179 |
+
print(OmegaConf.to_container(config))
|
180 |
+
|
181 |
+
dsets, model, global_step = load_model_and_dset(config, ckpt, gpu, eval_mode)
|
182 |
+
print(f"Global step: {global_step}")
|
183 |
+
|
184 |
+
data_loader = dsets.val_dataloader()
|
185 |
+
print(dsets.datasets["validation"].conditional_builders)
|
186 |
+
conditional_builder = dsets.datasets["validation"].conditional_builders[conditional]
|
187 |
+
|
188 |
+
outdir = Path(opt.outdir).joinpath(f"{global_step:06}_{opt.top_k}_{opt.temperature}")
|
189 |
+
outdir.mkdir(exist_ok=True, parents=True)
|
190 |
+
print("Writing samples to ", outdir)
|
191 |
+
|
192 |
+
p_bar_1 = tqdm(enumerate(iter(data_loader)), desc='batch', total=len(data_loader))
|
193 |
+
for batch_no, batch in p_bar_1:
|
194 |
+
save_img: Optional[Tensor] = None
|
195 |
+
for i, annotations in tqdm(enumerate(batch['annotations']), desc='within_batch', total=data_loader.batch_size):
|
196 |
+
imgs = sample(model, annotations, dsets.datasets["validation"], conditional_builder,
|
197 |
+
opt.n_samples_per_layout, opt.temperature, opt.top_k)
|
198 |
+
save_image(imgs, outdir.joinpath(f'{batch_no:04}_{i:02}.png'), n_row=opt.n_samples_per_layout+1)
|
taming-transformers/scripts/sample_conditional.py
ADDED
@@ -0,0 +1,355 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse, os, sys, glob, math, time
|
2 |
+
import torch
|
3 |
+
import numpy as np
|
4 |
+
from omegaconf import OmegaConf
|
5 |
+
import streamlit as st
|
6 |
+
from streamlit import caching
|
7 |
+
from PIL import Image
|
8 |
+
from main import instantiate_from_config, DataModuleFromConfig
|
9 |
+
from torch.utils.data import DataLoader
|
10 |
+
from torch.utils.data.dataloader import default_collate
|
11 |
+
|
12 |
+
|
13 |
+
rescale = lambda x: (x + 1.) / 2.
|
14 |
+
|
15 |
+
|
16 |
+
def bchw_to_st(x):
|
17 |
+
return rescale(x.detach().cpu().numpy().transpose(0,2,3,1))
|
18 |
+
|
19 |
+
def save_img(xstart, fname):
|
20 |
+
I = (xstart.clip(0,1)[0]*255).astype(np.uint8)
|
21 |
+
Image.fromarray(I).save(fname)
|
22 |
+
|
23 |
+
|
24 |
+
|
25 |
+
def get_interactive_image(resize=False):
|
26 |
+
image = st.file_uploader("Input", type=["jpg", "JPEG", "png"])
|
27 |
+
if image is not None:
|
28 |
+
image = Image.open(image)
|
29 |
+
if not image.mode == "RGB":
|
30 |
+
image = image.convert("RGB")
|
31 |
+
image = np.array(image).astype(np.uint8)
|
32 |
+
print("upload image shape: {}".format(image.shape))
|
33 |
+
img = Image.fromarray(image)
|
34 |
+
if resize:
|
35 |
+
img = img.resize((256, 256))
|
36 |
+
image = np.array(img)
|
37 |
+
return image
|
38 |
+
|
39 |
+
|
40 |
+
def single_image_to_torch(x, permute=True):
|
41 |
+
assert x is not None, "Please provide an image through the upload function"
|
42 |
+
x = np.array(x)
|
43 |
+
x = torch.FloatTensor(x/255.*2. - 1.)[None,...]
|
44 |
+
if permute:
|
45 |
+
x = x.permute(0, 3, 1, 2)
|
46 |
+
return x
|
47 |
+
|
48 |
+
|
49 |
+
def pad_to_M(x, M):
|
50 |
+
hp = math.ceil(x.shape[2]/M)*M-x.shape[2]
|
51 |
+
wp = math.ceil(x.shape[3]/M)*M-x.shape[3]
|
52 |
+
x = torch.nn.functional.pad(x, (0,wp,0,hp,0,0,0,0))
|
53 |
+
return x
|
54 |
+
|
55 |
+
@torch.no_grad()
|
56 |
+
def run_conditional(model, dsets):
|
57 |
+
if len(dsets.datasets) > 1:
|
58 |
+
split = st.sidebar.radio("Split", sorted(dsets.datasets.keys()))
|
59 |
+
dset = dsets.datasets[split]
|
60 |
+
else:
|
61 |
+
dset = next(iter(dsets.datasets.values()))
|
62 |
+
batch_size = 1
|
63 |
+
start_index = st.sidebar.number_input("Example Index (Size: {})".format(len(dset)), value=0,
|
64 |
+
min_value=0,
|
65 |
+
max_value=len(dset)-batch_size)
|
66 |
+
indices = list(range(start_index, start_index+batch_size))
|
67 |
+
|
68 |
+
example = default_collate([dset[i] for i in indices])
|
69 |
+
|
70 |
+
x = model.get_input("image", example).to(model.device)
|
71 |
+
|
72 |
+
cond_key = model.cond_stage_key
|
73 |
+
c = model.get_input(cond_key, example).to(model.device)
|
74 |
+
|
75 |
+
scale_factor = st.sidebar.slider("Scale Factor", min_value=0.5, max_value=4.0, step=0.25, value=1.00)
|
76 |
+
if scale_factor != 1.0:
|
77 |
+
x = torch.nn.functional.interpolate(x, scale_factor=scale_factor, mode="bicubic")
|
78 |
+
c = torch.nn.functional.interpolate(c, scale_factor=scale_factor, mode="bicubic")
|
79 |
+
|
80 |
+
quant_z, z_indices = model.encode_to_z(x)
|
81 |
+
quant_c, c_indices = model.encode_to_c(c)
|
82 |
+
|
83 |
+
cshape = quant_z.shape
|
84 |
+
|
85 |
+
xrec = model.first_stage_model.decode(quant_z)
|
86 |
+
st.write("image: {}".format(x.shape))
|
87 |
+
st.image(bchw_to_st(x), clamp=True, output_format="PNG")
|
88 |
+
st.write("image reconstruction: {}".format(xrec.shape))
|
89 |
+
st.image(bchw_to_st(xrec), clamp=True, output_format="PNG")
|
90 |
+
|
91 |
+
if cond_key == "segmentation":
|
92 |
+
# get image from segmentation mask
|
93 |
+
num_classes = c.shape[1]
|
94 |
+
c = torch.argmax(c, dim=1, keepdim=True)
|
95 |
+
c = torch.nn.functional.one_hot(c, num_classes=num_classes)
|
96 |
+
c = c.squeeze(1).permute(0, 3, 1, 2).float()
|
97 |
+
c = model.cond_stage_model.to_rgb(c)
|
98 |
+
|
99 |
+
st.write(f"{cond_key}: {tuple(c.shape)}")
|
100 |
+
st.image(bchw_to_st(c), clamp=True, output_format="PNG")
|
101 |
+
|
102 |
+
idx = z_indices
|
103 |
+
|
104 |
+
half_sample = st.sidebar.checkbox("Image Completion", value=False)
|
105 |
+
if half_sample:
|
106 |
+
start = idx.shape[1]//2
|
107 |
+
else:
|
108 |
+
start = 0
|
109 |
+
|
110 |
+
idx[:,start:] = 0
|
111 |
+
idx = idx.reshape(cshape[0],cshape[2],cshape[3])
|
112 |
+
start_i = start//cshape[3]
|
113 |
+
start_j = start %cshape[3]
|
114 |
+
|
115 |
+
if not half_sample and quant_z.shape == quant_c.shape:
|
116 |
+
st.info("Setting idx to c_indices")
|
117 |
+
idx = c_indices.clone().reshape(cshape[0],cshape[2],cshape[3])
|
118 |
+
|
119 |
+
cidx = c_indices
|
120 |
+
cidx = cidx.reshape(quant_c.shape[0],quant_c.shape[2],quant_c.shape[3])
|
121 |
+
|
122 |
+
xstart = model.decode_to_img(idx[:,:cshape[2],:cshape[3]], cshape)
|
123 |
+
st.image(bchw_to_st(xstart), clamp=True, output_format="PNG")
|
124 |
+
|
125 |
+
temperature = st.number_input("Temperature", value=1.0)
|
126 |
+
top_k = st.number_input("Top k", value=100)
|
127 |
+
sample = st.checkbox("Sample", value=True)
|
128 |
+
update_every = st.number_input("Update every", value=75)
|
129 |
+
|
130 |
+
st.text(f"Sampling shape ({cshape[2]},{cshape[3]})")
|
131 |
+
|
132 |
+
animate = st.checkbox("animate")
|
133 |
+
if animate:
|
134 |
+
import imageio
|
135 |
+
outvid = "sampling.mp4"
|
136 |
+
writer = imageio.get_writer(outvid, fps=25)
|
137 |
+
elapsed_t = st.empty()
|
138 |
+
info = st.empty()
|
139 |
+
st.text("Sampled")
|
140 |
+
if st.button("Sample"):
|
141 |
+
output = st.empty()
|
142 |
+
start_t = time.time()
|
143 |
+
for i in range(start_i,cshape[2]-0):
|
144 |
+
if i <= 8:
|
145 |
+
local_i = i
|
146 |
+
elif cshape[2]-i < 8:
|
147 |
+
local_i = 16-(cshape[2]-i)
|
148 |
+
else:
|
149 |
+
local_i = 8
|
150 |
+
for j in range(start_j,cshape[3]-0):
|
151 |
+
if j <= 8:
|
152 |
+
local_j = j
|
153 |
+
elif cshape[3]-j < 8:
|
154 |
+
local_j = 16-(cshape[3]-j)
|
155 |
+
else:
|
156 |
+
local_j = 8
|
157 |
+
|
158 |
+
i_start = i-local_i
|
159 |
+
i_end = i_start+16
|
160 |
+
j_start = j-local_j
|
161 |
+
j_end = j_start+16
|
162 |
+
elapsed_t.text(f"Time: {time.time() - start_t} seconds")
|
163 |
+
info.text(f"Step: ({i},{j}) | Local: ({local_i},{local_j}) | Crop: ({i_start}:{i_end},{j_start}:{j_end})")
|
164 |
+
patch = idx[:,i_start:i_end,j_start:j_end]
|
165 |
+
patch = patch.reshape(patch.shape[0],-1)
|
166 |
+
cpatch = cidx[:, i_start:i_end, j_start:j_end]
|
167 |
+
cpatch = cpatch.reshape(cpatch.shape[0], -1)
|
168 |
+
patch = torch.cat((cpatch, patch), dim=1)
|
169 |
+
logits,_ = model.transformer(patch[:,:-1])
|
170 |
+
logits = logits[:, -256:, :]
|
171 |
+
logits = logits.reshape(cshape[0],16,16,-1)
|
172 |
+
logits = logits[:,local_i,local_j,:]
|
173 |
+
|
174 |
+
logits = logits/temperature
|
175 |
+
|
176 |
+
if top_k is not None:
|
177 |
+
logits = model.top_k_logits(logits, top_k)
|
178 |
+
# apply softmax to convert to probabilities
|
179 |
+
probs = torch.nn.functional.softmax(logits, dim=-1)
|
180 |
+
# sample from the distribution or take the most likely
|
181 |
+
if sample:
|
182 |
+
ix = torch.multinomial(probs, num_samples=1)
|
183 |
+
else:
|
184 |
+
_, ix = torch.topk(probs, k=1, dim=-1)
|
185 |
+
idx[:,i,j] = ix
|
186 |
+
|
187 |
+
if (i*cshape[3]+j)%update_every==0:
|
188 |
+
xstart = model.decode_to_img(idx[:, :cshape[2], :cshape[3]], cshape,)
|
189 |
+
|
190 |
+
xstart = bchw_to_st(xstart)
|
191 |
+
output.image(xstart, clamp=True, output_format="PNG")
|
192 |
+
|
193 |
+
if animate:
|
194 |
+
writer.append_data((xstart[0]*255).clip(0, 255).astype(np.uint8))
|
195 |
+
|
196 |
+
xstart = model.decode_to_img(idx[:,:cshape[2],:cshape[3]], cshape)
|
197 |
+
xstart = bchw_to_st(xstart)
|
198 |
+
output.image(xstart, clamp=True, output_format="PNG")
|
199 |
+
#save_img(xstart, "full_res_sample.png")
|
200 |
+
if animate:
|
201 |
+
writer.close()
|
202 |
+
st.video(outvid)
|
203 |
+
|
204 |
+
|
205 |
+
def get_parser():
|
206 |
+
parser = argparse.ArgumentParser()
|
207 |
+
parser.add_argument(
|
208 |
+
"-r",
|
209 |
+
"--resume",
|
210 |
+
type=str,
|
211 |
+
nargs="?",
|
212 |
+
help="load from logdir or checkpoint in logdir",
|
213 |
+
)
|
214 |
+
parser.add_argument(
|
215 |
+
"-b",
|
216 |
+
"--base",
|
217 |
+
nargs="*",
|
218 |
+
metavar="base_config.yaml",
|
219 |
+
help="paths to base configs. Loaded from left-to-right. "
|
220 |
+
"Parameters can be overwritten or added with command-line options of the form `--key value`.",
|
221 |
+
default=list(),
|
222 |
+
)
|
223 |
+
parser.add_argument(
|
224 |
+
"-c",
|
225 |
+
"--config",
|
226 |
+
nargs="?",
|
227 |
+
metavar="single_config.yaml",
|
228 |
+
help="path to single config. If specified, base configs will be ignored "
|
229 |
+
"(except for the last one if left unspecified).",
|
230 |
+
const=True,
|
231 |
+
default="",
|
232 |
+
)
|
233 |
+
parser.add_argument(
|
234 |
+
"--ignore_base_data",
|
235 |
+
action="store_true",
|
236 |
+
help="Ignore data specification from base configs. Useful if you want "
|
237 |
+
"to specify a custom datasets on the command line.",
|
238 |
+
)
|
239 |
+
return parser
|
240 |
+
|
241 |
+
|
242 |
+
def load_model_from_config(config, sd, gpu=True, eval_mode=True):
|
243 |
+
if "ckpt_path" in config.params:
|
244 |
+
st.warning("Deleting the restore-ckpt path from the config...")
|
245 |
+
config.params.ckpt_path = None
|
246 |
+
if "downsample_cond_size" in config.params:
|
247 |
+
st.warning("Deleting downsample-cond-size from the config and setting factor=0.5 instead...")
|
248 |
+
config.params.downsample_cond_size = -1
|
249 |
+
config.params["downsample_cond_factor"] = 0.5
|
250 |
+
try:
|
251 |
+
if "ckpt_path" in config.params.first_stage_config.params:
|
252 |
+
config.params.first_stage_config.params.ckpt_path = None
|
253 |
+
st.warning("Deleting the first-stage restore-ckpt path from the config...")
|
254 |
+
if "ckpt_path" in config.params.cond_stage_config.params:
|
255 |
+
config.params.cond_stage_config.params.ckpt_path = None
|
256 |
+
st.warning("Deleting the cond-stage restore-ckpt path from the config...")
|
257 |
+
except:
|
258 |
+
pass
|
259 |
+
|
260 |
+
model = instantiate_from_config(config)
|
261 |
+
if sd is not None:
|
262 |
+
missing, unexpected = model.load_state_dict(sd, strict=False)
|
263 |
+
st.info(f"Missing Keys in State Dict: {missing}")
|
264 |
+
st.info(f"Unexpected Keys in State Dict: {unexpected}")
|
265 |
+
if gpu:
|
266 |
+
model.cuda()
|
267 |
+
if eval_mode:
|
268 |
+
model.eval()
|
269 |
+
return {"model": model}
|
270 |
+
|
271 |
+
|
272 |
+
def get_data(config):
|
273 |
+
# get data
|
274 |
+
data = instantiate_from_config(config.data)
|
275 |
+
data.prepare_data()
|
276 |
+
data.setup()
|
277 |
+
return data
|
278 |
+
|
279 |
+
|
280 |
+
@st.cache(allow_output_mutation=True, suppress_st_warning=True)
|
281 |
+
def load_model_and_dset(config, ckpt, gpu, eval_mode):
|
282 |
+
# get data
|
283 |
+
dsets = get_data(config) # calls data.config ...
|
284 |
+
|
285 |
+
# now load the specified checkpoint
|
286 |
+
if ckpt:
|
287 |
+
pl_sd = torch.load(ckpt, map_location="cpu")
|
288 |
+
global_step = pl_sd["global_step"]
|
289 |
+
else:
|
290 |
+
pl_sd = {"state_dict": None}
|
291 |
+
global_step = None
|
292 |
+
model = load_model_from_config(config.model,
|
293 |
+
pl_sd["state_dict"],
|
294 |
+
gpu=gpu,
|
295 |
+
eval_mode=eval_mode)["model"]
|
296 |
+
return dsets, model, global_step
|
297 |
+
|
298 |
+
|
299 |
+
if __name__ == "__main__":
|
300 |
+
sys.path.append(os.getcwd())
|
301 |
+
|
302 |
+
parser = get_parser()
|
303 |
+
|
304 |
+
opt, unknown = parser.parse_known_args()
|
305 |
+
|
306 |
+
ckpt = None
|
307 |
+
if opt.resume:
|
308 |
+
if not os.path.exists(opt.resume):
|
309 |
+
raise ValueError("Cannot find {}".format(opt.resume))
|
310 |
+
if os.path.isfile(opt.resume):
|
311 |
+
paths = opt.resume.split("/")
|
312 |
+
try:
|
313 |
+
idx = len(paths)-paths[::-1].index("logs")+1
|
314 |
+
except ValueError:
|
315 |
+
idx = -2 # take a guess: path/to/logdir/checkpoints/model.ckpt
|
316 |
+
logdir = "/".join(paths[:idx])
|
317 |
+
ckpt = opt.resume
|
318 |
+
else:
|
319 |
+
assert os.path.isdir(opt.resume), opt.resume
|
320 |
+
logdir = opt.resume.rstrip("/")
|
321 |
+
ckpt = os.path.join(logdir, "checkpoints", "last.ckpt")
|
322 |
+
print(f"logdir:{logdir}")
|
323 |
+
base_configs = sorted(glob.glob(os.path.join(logdir, "configs/*-project.yaml")))
|
324 |
+
opt.base = base_configs+opt.base
|
325 |
+
|
326 |
+
if opt.config:
|
327 |
+
if type(opt.config) == str:
|
328 |
+
opt.base = [opt.config]
|
329 |
+
else:
|
330 |
+
opt.base = [opt.base[-1]]
|
331 |
+
|
332 |
+
configs = [OmegaConf.load(cfg) for cfg in opt.base]
|
333 |
+
cli = OmegaConf.from_dotlist(unknown)
|
334 |
+
if opt.ignore_base_data:
|
335 |
+
for config in configs:
|
336 |
+
if hasattr(config, "data"): del config["data"]
|
337 |
+
config = OmegaConf.merge(*configs, cli)
|
338 |
+
|
339 |
+
st.sidebar.text(ckpt)
|
340 |
+
gs = st.sidebar.empty()
|
341 |
+
gs.text(f"Global step: ?")
|
342 |
+
st.sidebar.text("Options")
|
343 |
+
#gpu = st.sidebar.checkbox("GPU", value=True)
|
344 |
+
gpu = True
|
345 |
+
#eval_mode = st.sidebar.checkbox("Eval Mode", value=True)
|
346 |
+
eval_mode = True
|
347 |
+
#show_config = st.sidebar.checkbox("Show Config", value=False)
|
348 |
+
show_config = False
|
349 |
+
if show_config:
|
350 |
+
st.info("Checkpoint: {}".format(ckpt))
|
351 |
+
st.json(OmegaConf.to_container(config))
|
352 |
+
|
353 |
+
dsets, model, global_step = load_model_and_dset(config, ckpt, gpu, eval_mode)
|
354 |
+
gs.text(f"Global step: {global_step}")
|
355 |
+
run_conditional(model, dsets)
|
taming-transformers/scripts/sample_fast.py
ADDED
@@ -0,0 +1,260 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse, os, sys, glob
|
2 |
+
import torch
|
3 |
+
import time
|
4 |
+
import numpy as np
|
5 |
+
from omegaconf import OmegaConf
|
6 |
+
from PIL import Image
|
7 |
+
from tqdm import tqdm, trange
|
8 |
+
from einops import repeat
|
9 |
+
|
10 |
+
from main import instantiate_from_config
|
11 |
+
from taming.modules.transformer.mingpt import sample_with_past
|
12 |
+
|
13 |
+
|
14 |
+
rescale = lambda x: (x + 1.) / 2.
|
15 |
+
|
16 |
+
|
17 |
+
def chw_to_pillow(x):
|
18 |
+
return Image.fromarray((255*rescale(x.detach().cpu().numpy().transpose(1,2,0))).clip(0,255).astype(np.uint8))
|
19 |
+
|
20 |
+
|
21 |
+
@torch.no_grad()
|
22 |
+
def sample_classconditional(model, batch_size, class_label, steps=256, temperature=None, top_k=None, callback=None,
|
23 |
+
dim_z=256, h=16, w=16, verbose_time=False, top_p=None):
|
24 |
+
log = dict()
|
25 |
+
assert type(class_label) == int, f'expecting type int but type is {type(class_label)}'
|
26 |
+
qzshape = [batch_size, dim_z, h, w]
|
27 |
+
assert not model.be_unconditional, 'Expecting a class-conditional Net2NetTransformer.'
|
28 |
+
c_indices = repeat(torch.tensor([class_label]), '1 -> b 1', b=batch_size).to(model.device) # class token
|
29 |
+
t1 = time.time()
|
30 |
+
index_sample = sample_with_past(c_indices, model.transformer, steps=steps,
|
31 |
+
sample_logits=True, top_k=top_k, callback=callback,
|
32 |
+
temperature=temperature, top_p=top_p)
|
33 |
+
if verbose_time:
|
34 |
+
sampling_time = time.time() - t1
|
35 |
+
print(f"Full sampling takes about {sampling_time:.2f} seconds.")
|
36 |
+
x_sample = model.decode_to_img(index_sample, qzshape)
|
37 |
+
log["samples"] = x_sample
|
38 |
+
log["class_label"] = c_indices
|
39 |
+
return log
|
40 |
+
|
41 |
+
|
42 |
+
@torch.no_grad()
|
43 |
+
def sample_unconditional(model, batch_size, steps=256, temperature=None, top_k=None, top_p=None, callback=None,
|
44 |
+
dim_z=256, h=16, w=16, verbose_time=False):
|
45 |
+
log = dict()
|
46 |
+
qzshape = [batch_size, dim_z, h, w]
|
47 |
+
assert model.be_unconditional, 'Expecting an unconditional model.'
|
48 |
+
c_indices = repeat(torch.tensor([model.sos_token]), '1 -> b 1', b=batch_size).to(model.device) # sos token
|
49 |
+
t1 = time.time()
|
50 |
+
index_sample = sample_with_past(c_indices, model.transformer, steps=steps,
|
51 |
+
sample_logits=True, top_k=top_k, callback=callback,
|
52 |
+
temperature=temperature, top_p=top_p)
|
53 |
+
if verbose_time:
|
54 |
+
sampling_time = time.time() - t1
|
55 |
+
print(f"Full sampling takes about {sampling_time:.2f} seconds.")
|
56 |
+
x_sample = model.decode_to_img(index_sample, qzshape)
|
57 |
+
log["samples"] = x_sample
|
58 |
+
return log
|
59 |
+
|
60 |
+
|
61 |
+
@torch.no_grad()
|
62 |
+
def run(logdir, model, batch_size, temperature, top_k, unconditional=True, num_samples=50000,
|
63 |
+
given_classes=None, top_p=None):
|
64 |
+
batches = [batch_size for _ in range(num_samples//batch_size)] + [num_samples % batch_size]
|
65 |
+
if not unconditional:
|
66 |
+
assert given_classes is not None
|
67 |
+
print("Running in pure class-conditional sampling mode. I will produce "
|
68 |
+
f"{num_samples} samples for each of the {len(given_classes)} classes, "
|
69 |
+
f"i.e. {num_samples*len(given_classes)} in total.")
|
70 |
+
for class_label in tqdm(given_classes, desc="Classes"):
|
71 |
+
for n, bs in tqdm(enumerate(batches), desc="Sampling Class"):
|
72 |
+
if bs == 0: break
|
73 |
+
logs = sample_classconditional(model, batch_size=bs, class_label=class_label,
|
74 |
+
temperature=temperature, top_k=top_k, top_p=top_p)
|
75 |
+
save_from_logs(logs, logdir, base_count=n * batch_size, cond_key=logs["class_label"])
|
76 |
+
else:
|
77 |
+
print(f"Running in unconditional sampling mode, producing {num_samples} samples.")
|
78 |
+
for n, bs in tqdm(enumerate(batches), desc="Sampling"):
|
79 |
+
if bs == 0: break
|
80 |
+
logs = sample_unconditional(model, batch_size=bs, temperature=temperature, top_k=top_k, top_p=top_p)
|
81 |
+
save_from_logs(logs, logdir, base_count=n * batch_size)
|
82 |
+
|
83 |
+
|
84 |
+
def save_from_logs(logs, logdir, base_count, key="samples", cond_key=None):
|
85 |
+
xx = logs[key]
|
86 |
+
for i, x in enumerate(xx):
|
87 |
+
x = chw_to_pillow(x)
|
88 |
+
count = base_count + i
|
89 |
+
if cond_key is None:
|
90 |
+
x.save(os.path.join(logdir, f"{count:06}.png"))
|
91 |
+
else:
|
92 |
+
condlabel = cond_key[i]
|
93 |
+
if type(condlabel) == torch.Tensor: condlabel = condlabel.item()
|
94 |
+
os.makedirs(os.path.join(logdir, str(condlabel)), exist_ok=True)
|
95 |
+
x.save(os.path.join(logdir, str(condlabel), f"{count:06}.png"))
|
96 |
+
|
97 |
+
|
98 |
+
def get_parser():
|
99 |
+
def str2bool(v):
|
100 |
+
if isinstance(v, bool):
|
101 |
+
return v
|
102 |
+
if v.lower() in ("yes", "true", "t", "y", "1"):
|
103 |
+
return True
|
104 |
+
elif v.lower() in ("no", "false", "f", "n", "0"):
|
105 |
+
return False
|
106 |
+
else:
|
107 |
+
raise argparse.ArgumentTypeError("Boolean value expected.")
|
108 |
+
|
109 |
+
parser = argparse.ArgumentParser()
|
110 |
+
parser.add_argument(
|
111 |
+
"-r",
|
112 |
+
"--resume",
|
113 |
+
type=str,
|
114 |
+
nargs="?",
|
115 |
+
help="load from logdir or checkpoint in logdir",
|
116 |
+
)
|
117 |
+
parser.add_argument(
|
118 |
+
"-o",
|
119 |
+
"--outdir",
|
120 |
+
type=str,
|
121 |
+
nargs="?",
|
122 |
+
help="path where the samples will be logged to.",
|
123 |
+
default=""
|
124 |
+
)
|
125 |
+
parser.add_argument(
|
126 |
+
"-b",
|
127 |
+
"--base",
|
128 |
+
nargs="*",
|
129 |
+
metavar="base_config.yaml",
|
130 |
+
help="paths to base configs. Loaded from left-to-right. "
|
131 |
+
"Parameters can be overwritten or added with command-line options of the form `--key value`.",
|
132 |
+
default=list(),
|
133 |
+
)
|
134 |
+
parser.add_argument(
|
135 |
+
"-n",
|
136 |
+
"--num_samples",
|
137 |
+
type=int,
|
138 |
+
nargs="?",
|
139 |
+
help="num_samples to draw",
|
140 |
+
default=50000
|
141 |
+
)
|
142 |
+
parser.add_argument(
|
143 |
+
"--batch_size",
|
144 |
+
type=int,
|
145 |
+
nargs="?",
|
146 |
+
help="the batch size",
|
147 |
+
default=25
|
148 |
+
)
|
149 |
+
parser.add_argument(
|
150 |
+
"-k",
|
151 |
+
"--top_k",
|
152 |
+
type=int,
|
153 |
+
nargs="?",
|
154 |
+
help="top-k value to sample with",
|
155 |
+
default=250,
|
156 |
+
)
|
157 |
+
parser.add_argument(
|
158 |
+
"-t",
|
159 |
+
"--temperature",
|
160 |
+
type=float,
|
161 |
+
nargs="?",
|
162 |
+
help="temperature value to sample with",
|
163 |
+
default=1.0
|
164 |
+
)
|
165 |
+
parser.add_argument(
|
166 |
+
"-p",
|
167 |
+
"--top_p",
|
168 |
+
type=float,
|
169 |
+
nargs="?",
|
170 |
+
help="top-p value to sample with",
|
171 |
+
default=1.0
|
172 |
+
)
|
173 |
+
parser.add_argument(
|
174 |
+
"--classes",
|
175 |
+
type=str,
|
176 |
+
nargs="?",
|
177 |
+
help="specify comma-separated classes to sample from. Uses 1000 classes per default.",
|
178 |
+
default="imagenet"
|
179 |
+
)
|
180 |
+
return parser
|
181 |
+
|
182 |
+
|
183 |
+
def load_model_from_config(config, sd, gpu=True, eval_mode=True):
|
184 |
+
model = instantiate_from_config(config)
|
185 |
+
if sd is not None:
|
186 |
+
model.load_state_dict(sd)
|
187 |
+
if gpu:
|
188 |
+
model.cuda()
|
189 |
+
if eval_mode:
|
190 |
+
model.eval()
|
191 |
+
return {"model": model}
|
192 |
+
|
193 |
+
|
194 |
+
def load_model(config, ckpt, gpu, eval_mode):
|
195 |
+
# load the specified checkpoint
|
196 |
+
if ckpt:
|
197 |
+
pl_sd = torch.load(ckpt, map_location="cpu")
|
198 |
+
global_step = pl_sd["global_step"]
|
199 |
+
print(f"loaded model from global step {global_step}.")
|
200 |
+
else:
|
201 |
+
pl_sd = {"state_dict": None}
|
202 |
+
global_step = None
|
203 |
+
model = load_model_from_config(config.model, pl_sd["state_dict"], gpu=gpu, eval_mode=eval_mode)["model"]
|
204 |
+
return model, global_step
|
205 |
+
|
206 |
+
|
207 |
+
if __name__ == "__main__":
|
208 |
+
sys.path.append(os.getcwd())
|
209 |
+
parser = get_parser()
|
210 |
+
|
211 |
+
opt, unknown = parser.parse_known_args()
|
212 |
+
assert opt.resume
|
213 |
+
|
214 |
+
ckpt = None
|
215 |
+
|
216 |
+
if not os.path.exists(opt.resume):
|
217 |
+
raise ValueError("Cannot find {}".format(opt.resume))
|
218 |
+
if os.path.isfile(opt.resume):
|
219 |
+
paths = opt.resume.split("/")
|
220 |
+
try:
|
221 |
+
idx = len(paths)-paths[::-1].index("logs")+1
|
222 |
+
except ValueError:
|
223 |
+
idx = -2 # take a guess: path/to/logdir/checkpoints/model.ckpt
|
224 |
+
logdir = "/".join(paths[:idx])
|
225 |
+
ckpt = opt.resume
|
226 |
+
else:
|
227 |
+
assert os.path.isdir(opt.resume), opt.resume
|
228 |
+
logdir = opt.resume.rstrip("/")
|
229 |
+
ckpt = os.path.join(logdir, "checkpoints", "last.ckpt")
|
230 |
+
|
231 |
+
base_configs = sorted(glob.glob(os.path.join(logdir, "configs/*-project.yaml")))
|
232 |
+
opt.base = base_configs+opt.base
|
233 |
+
|
234 |
+
configs = [OmegaConf.load(cfg) for cfg in opt.base]
|
235 |
+
cli = OmegaConf.from_dotlist(unknown)
|
236 |
+
config = OmegaConf.merge(*configs, cli)
|
237 |
+
|
238 |
+
model, global_step = load_model(config, ckpt, gpu=True, eval_mode=True)
|
239 |
+
|
240 |
+
if opt.outdir:
|
241 |
+
print(f"Switching logdir from '{logdir}' to '{opt.outdir}'")
|
242 |
+
logdir = opt.outdir
|
243 |
+
|
244 |
+
if opt.classes == "imagenet":
|
245 |
+
given_classes = [i for i in range(1000)]
|
246 |
+
else:
|
247 |
+
cls_str = opt.classes
|
248 |
+
assert not cls_str.endswith(","), 'class string should not end with a ","'
|
249 |
+
given_classes = [int(c) for c in cls_str.split(",")]
|
250 |
+
|
251 |
+
logdir = os.path.join(logdir, "samples", f"top_k_{opt.top_k}_temp_{opt.temperature:.2f}_top_p_{opt.top_p}",
|
252 |
+
f"{global_step}")
|
253 |
+
|
254 |
+
print(f"Logging to {logdir}")
|
255 |
+
os.makedirs(logdir, exist_ok=True)
|
256 |
+
|
257 |
+
run(logdir, model, opt.batch_size, opt.temperature, opt.top_k, unconditional=model.be_unconditional,
|
258 |
+
given_classes=given_classes, num_samples=opt.num_samples, top_p=opt.top_p)
|
259 |
+
|
260 |
+
print("done.")
|
taming-transformers/setup.py
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from setuptools import setup, find_packages
|
2 |
+
|
3 |
+
setup(
|
4 |
+
name='taming-transformers',
|
5 |
+
version='0.0.1',
|
6 |
+
description='Taming Transformers for High-Resolution Image Synthesis',
|
7 |
+
packages=find_packages(),
|
8 |
+
install_requires=[
|
9 |
+
'torch',
|
10 |
+
'numpy',
|
11 |
+
'tqdm',
|
12 |
+
],
|
13 |
+
)
|
taming-transformers/taming/lr_scheduler.py
ADDED
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
|
3 |
+
|
4 |
+
class LambdaWarmUpCosineScheduler:
|
5 |
+
"""
|
6 |
+
note: use with a base_lr of 1.0
|
7 |
+
"""
|
8 |
+
def __init__(self, warm_up_steps, lr_min, lr_max, lr_start, max_decay_steps, verbosity_interval=0):
|
9 |
+
self.lr_warm_up_steps = warm_up_steps
|
10 |
+
self.lr_start = lr_start
|
11 |
+
self.lr_min = lr_min
|
12 |
+
self.lr_max = lr_max
|
13 |
+
self.lr_max_decay_steps = max_decay_steps
|
14 |
+
self.last_lr = 0.
|
15 |
+
self.verbosity_interval = verbosity_interval
|
16 |
+
|
17 |
+
def schedule(self, n):
|
18 |
+
if self.verbosity_interval > 0:
|
19 |
+
if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_lr}")
|
20 |
+
if n < self.lr_warm_up_steps:
|
21 |
+
lr = (self.lr_max - self.lr_start) / self.lr_warm_up_steps * n + self.lr_start
|
22 |
+
self.last_lr = lr
|
23 |
+
return lr
|
24 |
+
else:
|
25 |
+
t = (n - self.lr_warm_up_steps) / (self.lr_max_decay_steps - self.lr_warm_up_steps)
|
26 |
+
t = min(t, 1.0)
|
27 |
+
lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * (
|
28 |
+
1 + np.cos(t * np.pi))
|
29 |
+
self.last_lr = lr
|
30 |
+
return lr
|
31 |
+
|
32 |
+
def __call__(self, n):
|
33 |
+
return self.schedule(n)
|
34 |
+
|
taming-transformers/taming/models/cond_transformer.py
ADDED
@@ -0,0 +1,352 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os, math
|
2 |
+
import torch
|
3 |
+
import torch.nn.functional as F
|
4 |
+
import pytorch_lightning as pl
|
5 |
+
|
6 |
+
from main import instantiate_from_config
|
7 |
+
from taming.modules.util import SOSProvider
|
8 |
+
|
9 |
+
|
10 |
+
def disabled_train(self, mode=True):
|
11 |
+
"""Overwrite model.train with this function to make sure train/eval mode
|
12 |
+
does not change anymore."""
|
13 |
+
return self
|
14 |
+
|
15 |
+
|
16 |
+
class Net2NetTransformer(pl.LightningModule):
|
17 |
+
def __init__(self,
|
18 |
+
transformer_config,
|
19 |
+
first_stage_config,
|
20 |
+
cond_stage_config,
|
21 |
+
permuter_config=None,
|
22 |
+
ckpt_path=None,
|
23 |
+
ignore_keys=[],
|
24 |
+
first_stage_key="image",
|
25 |
+
cond_stage_key="depth",
|
26 |
+
downsample_cond_size=-1,
|
27 |
+
pkeep=1.0,
|
28 |
+
sos_token=0,
|
29 |
+
unconditional=False,
|
30 |
+
):
|
31 |
+
super().__init__()
|
32 |
+
self.be_unconditional = unconditional
|
33 |
+
self.sos_token = sos_token
|
34 |
+
self.first_stage_key = first_stage_key
|
35 |
+
self.cond_stage_key = cond_stage_key
|
36 |
+
self.init_first_stage_from_ckpt(first_stage_config)
|
37 |
+
self.init_cond_stage_from_ckpt(cond_stage_config)
|
38 |
+
if permuter_config is None:
|
39 |
+
permuter_config = {"target": "taming.modules.transformer.permuter.Identity"}
|
40 |
+
self.permuter = instantiate_from_config(config=permuter_config)
|
41 |
+
self.transformer = instantiate_from_config(config=transformer_config)
|
42 |
+
|
43 |
+
if ckpt_path is not None:
|
44 |
+
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
|
45 |
+
self.downsample_cond_size = downsample_cond_size
|
46 |
+
self.pkeep = pkeep
|
47 |
+
|
48 |
+
def init_from_ckpt(self, path, ignore_keys=list()):
|
49 |
+
sd = torch.load(path, map_location="cpu")["state_dict"]
|
50 |
+
for k in sd.keys():
|
51 |
+
for ik in ignore_keys:
|
52 |
+
if k.startswith(ik):
|
53 |
+
self.print("Deleting key {} from state_dict.".format(k))
|
54 |
+
del sd[k]
|
55 |
+
self.load_state_dict(sd, strict=False)
|
56 |
+
print(f"Restored from {path}")
|
57 |
+
|
58 |
+
def init_first_stage_from_ckpt(self, config):
|
59 |
+
model = instantiate_from_config(config)
|
60 |
+
model = model.eval()
|
61 |
+
model.train = disabled_train
|
62 |
+
self.first_stage_model = model
|
63 |
+
|
64 |
+
def init_cond_stage_from_ckpt(self, config):
|
65 |
+
if config == "__is_first_stage__":
|
66 |
+
print("Using first stage also as cond stage.")
|
67 |
+
self.cond_stage_model = self.first_stage_model
|
68 |
+
elif config == "__is_unconditional__" or self.be_unconditional:
|
69 |
+
print(f"Using no cond stage. Assuming the training is intended to be unconditional. "
|
70 |
+
f"Prepending {self.sos_token} as a sos token.")
|
71 |
+
self.be_unconditional = True
|
72 |
+
self.cond_stage_key = self.first_stage_key
|
73 |
+
self.cond_stage_model = SOSProvider(self.sos_token)
|
74 |
+
else:
|
75 |
+
model = instantiate_from_config(config)
|
76 |
+
model = model.eval()
|
77 |
+
model.train = disabled_train
|
78 |
+
self.cond_stage_model = model
|
79 |
+
|
80 |
+
def forward(self, x, c):
|
81 |
+
# one step to produce the logits
|
82 |
+
_, z_indices = self.encode_to_z(x)
|
83 |
+
_, c_indices = self.encode_to_c(c)
|
84 |
+
|
85 |
+
if self.training and self.pkeep < 1.0:
|
86 |
+
mask = torch.bernoulli(self.pkeep*torch.ones(z_indices.shape,
|
87 |
+
device=z_indices.device))
|
88 |
+
mask = mask.round().to(dtype=torch.int64)
|
89 |
+
r_indices = torch.randint_like(z_indices, self.transformer.config.vocab_size)
|
90 |
+
a_indices = mask*z_indices+(1-mask)*r_indices
|
91 |
+
else:
|
92 |
+
a_indices = z_indices
|
93 |
+
|
94 |
+
cz_indices = torch.cat((c_indices, a_indices), dim=1)
|
95 |
+
|
96 |
+
# target includes all sequence elements (no need to handle first one
|
97 |
+
# differently because we are conditioning)
|
98 |
+
target = z_indices
|
99 |
+
# make the prediction
|
100 |
+
logits, _ = self.transformer(cz_indices[:, :-1])
|
101 |
+
# cut off conditioning outputs - output i corresponds to p(z_i | z_{<i}, c)
|
102 |
+
logits = logits[:, c_indices.shape[1]-1:]
|
103 |
+
|
104 |
+
return logits, target
|
105 |
+
|
106 |
+
def top_k_logits(self, logits, k):
|
107 |
+
v, ix = torch.topk(logits, k)
|
108 |
+
out = logits.clone()
|
109 |
+
out[out < v[..., [-1]]] = -float('Inf')
|
110 |
+
return out
|
111 |
+
|
112 |
+
@torch.no_grad()
|
113 |
+
def sample(self, x, c, steps, temperature=1.0, sample=False, top_k=None,
|
114 |
+
callback=lambda k: None):
|
115 |
+
x = torch.cat((c,x),dim=1)
|
116 |
+
block_size = self.transformer.get_block_size()
|
117 |
+
assert not self.transformer.training
|
118 |
+
if self.pkeep <= 0.0:
|
119 |
+
# one pass suffices since input is pure noise anyway
|
120 |
+
assert len(x.shape)==2
|
121 |
+
noise_shape = (x.shape[0], steps-1)
|
122 |
+
#noise = torch.randint(self.transformer.config.vocab_size, noise_shape).to(x)
|
123 |
+
noise = c.clone()[:,x.shape[1]-c.shape[1]:-1]
|
124 |
+
x = torch.cat((x,noise),dim=1)
|
125 |
+
logits, _ = self.transformer(x)
|
126 |
+
# take all logits for now and scale by temp
|
127 |
+
logits = logits / temperature
|
128 |
+
# optionally crop probabilities to only the top k options
|
129 |
+
if top_k is not None:
|
130 |
+
logits = self.top_k_logits(logits, top_k)
|
131 |
+
# apply softmax to convert to probabilities
|
132 |
+
probs = F.softmax(logits, dim=-1)
|
133 |
+
# sample from the distribution or take the most likely
|
134 |
+
if sample:
|
135 |
+
shape = probs.shape
|
136 |
+
probs = probs.reshape(shape[0]*shape[1],shape[2])
|
137 |
+
ix = torch.multinomial(probs, num_samples=1)
|
138 |
+
probs = probs.reshape(shape[0],shape[1],shape[2])
|
139 |
+
ix = ix.reshape(shape[0],shape[1])
|
140 |
+
else:
|
141 |
+
_, ix = torch.topk(probs, k=1, dim=-1)
|
142 |
+
# cut off conditioning
|
143 |
+
x = ix[:, c.shape[1]-1:]
|
144 |
+
else:
|
145 |
+
for k in range(steps):
|
146 |
+
callback(k)
|
147 |
+
assert x.size(1) <= block_size # make sure model can see conditioning
|
148 |
+
x_cond = x if x.size(1) <= block_size else x[:, -block_size:] # crop context if needed
|
149 |
+
logits, _ = self.transformer(x_cond)
|
150 |
+
# pluck the logits at the final step and scale by temperature
|
151 |
+
logits = logits[:, -1, :] / temperature
|
152 |
+
# optionally crop probabilities to only the top k options
|
153 |
+
if top_k is not None:
|
154 |
+
logits = self.top_k_logits(logits, top_k)
|
155 |
+
# apply softmax to convert to probabilities
|
156 |
+
probs = F.softmax(logits, dim=-1)
|
157 |
+
# sample from the distribution or take the most likely
|
158 |
+
if sample:
|
159 |
+
ix = torch.multinomial(probs, num_samples=1)
|
160 |
+
else:
|
161 |
+
_, ix = torch.topk(probs, k=1, dim=-1)
|
162 |
+
# append to the sequence and continue
|
163 |
+
x = torch.cat((x, ix), dim=1)
|
164 |
+
# cut off conditioning
|
165 |
+
x = x[:, c.shape[1]:]
|
166 |
+
return x
|
167 |
+
|
168 |
+
@torch.no_grad()
|
169 |
+
def encode_to_z(self, x):
|
170 |
+
quant_z, _, info = self.first_stage_model.encode(x)
|
171 |
+
indices = info[2].view(quant_z.shape[0], -1)
|
172 |
+
indices = self.permuter(indices)
|
173 |
+
return quant_z, indices
|
174 |
+
|
175 |
+
@torch.no_grad()
|
176 |
+
def encode_to_c(self, c):
|
177 |
+
if self.downsample_cond_size > -1:
|
178 |
+
c = F.interpolate(c, size=(self.downsample_cond_size, self.downsample_cond_size))
|
179 |
+
quant_c, _, [_,_,indices] = self.cond_stage_model.encode(c)
|
180 |
+
if len(indices.shape) > 2:
|
181 |
+
indices = indices.view(c.shape[0], -1)
|
182 |
+
return quant_c, indices
|
183 |
+
|
184 |
+
@torch.no_grad()
|
185 |
+
def decode_to_img(self, index, zshape):
|
186 |
+
index = self.permuter(index, reverse=True)
|
187 |
+
bhwc = (zshape[0],zshape[2],zshape[3],zshape[1])
|
188 |
+
quant_z = self.first_stage_model.quantize.get_codebook_entry(
|
189 |
+
index.reshape(-1), shape=bhwc)
|
190 |
+
x = self.first_stage_model.decode(quant_z)
|
191 |
+
return x
|
192 |
+
|
193 |
+
@torch.no_grad()
|
194 |
+
def log_images(self, batch, temperature=None, top_k=None, callback=None, lr_interface=False, **kwargs):
|
195 |
+
log = dict()
|
196 |
+
|
197 |
+
N = 4
|
198 |
+
if lr_interface:
|
199 |
+
x, c = self.get_xc(batch, N, diffuse=False, upsample_factor=8)
|
200 |
+
else:
|
201 |
+
x, c = self.get_xc(batch, N)
|
202 |
+
x = x.to(device=self.device)
|
203 |
+
c = c.to(device=self.device)
|
204 |
+
|
205 |
+
quant_z, z_indices = self.encode_to_z(x)
|
206 |
+
quant_c, c_indices = self.encode_to_c(c)
|
207 |
+
|
208 |
+
# create a "half"" sample
|
209 |
+
z_start_indices = z_indices[:,:z_indices.shape[1]//2]
|
210 |
+
index_sample = self.sample(z_start_indices, c_indices,
|
211 |
+
steps=z_indices.shape[1]-z_start_indices.shape[1],
|
212 |
+
temperature=temperature if temperature is not None else 1.0,
|
213 |
+
sample=True,
|
214 |
+
top_k=top_k if top_k is not None else 100,
|
215 |
+
callback=callback if callback is not None else lambda k: None)
|
216 |
+
x_sample = self.decode_to_img(index_sample, quant_z.shape)
|
217 |
+
|
218 |
+
# sample
|
219 |
+
z_start_indices = z_indices[:, :0]
|
220 |
+
index_sample = self.sample(z_start_indices, c_indices,
|
221 |
+
steps=z_indices.shape[1],
|
222 |
+
temperature=temperature if temperature is not None else 1.0,
|
223 |
+
sample=True,
|
224 |
+
top_k=top_k if top_k is not None else 100,
|
225 |
+
callback=callback if callback is not None else lambda k: None)
|
226 |
+
x_sample_nopix = self.decode_to_img(index_sample, quant_z.shape)
|
227 |
+
|
228 |
+
# det sample
|
229 |
+
z_start_indices = z_indices[:, :0]
|
230 |
+
index_sample = self.sample(z_start_indices, c_indices,
|
231 |
+
steps=z_indices.shape[1],
|
232 |
+
sample=False,
|
233 |
+
callback=callback if callback is not None else lambda k: None)
|
234 |
+
x_sample_det = self.decode_to_img(index_sample, quant_z.shape)
|
235 |
+
|
236 |
+
# reconstruction
|
237 |
+
x_rec = self.decode_to_img(z_indices, quant_z.shape)
|
238 |
+
|
239 |
+
log["inputs"] = x
|
240 |
+
log["reconstructions"] = x_rec
|
241 |
+
|
242 |
+
if self.cond_stage_key in ["objects_bbox", "objects_center_points"]:
|
243 |
+
figure_size = (x_rec.shape[2], x_rec.shape[3])
|
244 |
+
dataset = kwargs["pl_module"].trainer.datamodule.datasets["validation"]
|
245 |
+
label_for_category_no = dataset.get_textual_label_for_category_no
|
246 |
+
plotter = dataset.conditional_builders[self.cond_stage_key].plot
|
247 |
+
log["conditioning"] = torch.zeros_like(log["reconstructions"])
|
248 |
+
for i in range(quant_c.shape[0]):
|
249 |
+
log["conditioning"][i] = plotter(quant_c[i], label_for_category_no, figure_size)
|
250 |
+
log["conditioning_rec"] = log["conditioning"]
|
251 |
+
elif self.cond_stage_key != "image":
|
252 |
+
cond_rec = self.cond_stage_model.decode(quant_c)
|
253 |
+
if self.cond_stage_key == "segmentation":
|
254 |
+
# get image from segmentation mask
|
255 |
+
num_classes = cond_rec.shape[1]
|
256 |
+
|
257 |
+
c = torch.argmax(c, dim=1, keepdim=True)
|
258 |
+
c = F.one_hot(c, num_classes=num_classes)
|
259 |
+
c = c.squeeze(1).permute(0, 3, 1, 2).float()
|
260 |
+
c = self.cond_stage_model.to_rgb(c)
|
261 |
+
|
262 |
+
cond_rec = torch.argmax(cond_rec, dim=1, keepdim=True)
|
263 |
+
cond_rec = F.one_hot(cond_rec, num_classes=num_classes)
|
264 |
+
cond_rec = cond_rec.squeeze(1).permute(0, 3, 1, 2).float()
|
265 |
+
cond_rec = self.cond_stage_model.to_rgb(cond_rec)
|
266 |
+
log["conditioning_rec"] = cond_rec
|
267 |
+
log["conditioning"] = c
|
268 |
+
|
269 |
+
log["samples_half"] = x_sample
|
270 |
+
log["samples_nopix"] = x_sample_nopix
|
271 |
+
log["samples_det"] = x_sample_det
|
272 |
+
return log
|
273 |
+
|
274 |
+
def get_input(self, key, batch):
|
275 |
+
x = batch[key]
|
276 |
+
if len(x.shape) == 3:
|
277 |
+
x = x[..., None]
|
278 |
+
if len(x.shape) == 4:
|
279 |
+
x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format)
|
280 |
+
if x.dtype == torch.double:
|
281 |
+
x = x.float()
|
282 |
+
return x
|
283 |
+
|
284 |
+
def get_xc(self, batch, N=None):
|
285 |
+
x = self.get_input(self.first_stage_key, batch)
|
286 |
+
c = self.get_input(self.cond_stage_key, batch)
|
287 |
+
if N is not None:
|
288 |
+
x = x[:N]
|
289 |
+
c = c[:N]
|
290 |
+
return x, c
|
291 |
+
|
292 |
+
def shared_step(self, batch, batch_idx):
|
293 |
+
x, c = self.get_xc(batch)
|
294 |
+
logits, target = self(x, c)
|
295 |
+
loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), target.reshape(-1))
|
296 |
+
return loss
|
297 |
+
|
298 |
+
def training_step(self, batch, batch_idx):
|
299 |
+
loss = self.shared_step(batch, batch_idx)
|
300 |
+
self.log("train/loss", loss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
|
301 |
+
return loss
|
302 |
+
|
303 |
+
def validation_step(self, batch, batch_idx):
|
304 |
+
loss = self.shared_step(batch, batch_idx)
|
305 |
+
self.log("val/loss", loss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
|
306 |
+
return loss
|
307 |
+
|
308 |
+
def configure_optimizers(self):
|
309 |
+
"""
|
310 |
+
Following minGPT:
|
311 |
+
This long function is unfortunately doing something very simple and is being very defensive:
|
312 |
+
We are separating out all parameters of the model into two buckets: those that will experience
|
313 |
+
weight decay for regularization and those that won't (biases, and layernorm/embedding weights).
|
314 |
+
We are then returning the PyTorch optimizer object.
|
315 |
+
"""
|
316 |
+
# separate out all parameters to those that will and won't experience regularizing weight decay
|
317 |
+
decay = set()
|
318 |
+
no_decay = set()
|
319 |
+
whitelist_weight_modules = (torch.nn.Linear, )
|
320 |
+
blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding)
|
321 |
+
for mn, m in self.transformer.named_modules():
|
322 |
+
for pn, p in m.named_parameters():
|
323 |
+
fpn = '%s.%s' % (mn, pn) if mn else pn # full param name
|
324 |
+
|
325 |
+
if pn.endswith('bias'):
|
326 |
+
# all biases will not be decayed
|
327 |
+
no_decay.add(fpn)
|
328 |
+
elif pn.endswith('weight') and isinstance(m, whitelist_weight_modules):
|
329 |
+
# weights of whitelist modules will be weight decayed
|
330 |
+
decay.add(fpn)
|
331 |
+
elif pn.endswith('weight') and isinstance(m, blacklist_weight_modules):
|
332 |
+
# weights of blacklist modules will NOT be weight decayed
|
333 |
+
no_decay.add(fpn)
|
334 |
+
|
335 |
+
# special case the position embedding parameter in the root GPT module as not decayed
|
336 |
+
no_decay.add('pos_emb')
|
337 |
+
|
338 |
+
# validate that we considered every parameter
|
339 |
+
param_dict = {pn: p for pn, p in self.transformer.named_parameters()}
|
340 |
+
inter_params = decay & no_decay
|
341 |
+
union_params = decay | no_decay
|
342 |
+
assert len(inter_params) == 0, "parameters %s made it into both decay/no_decay sets!" % (str(inter_params), )
|
343 |
+
assert len(param_dict.keys() - union_params) == 0, "parameters %s were not separated into either decay/no_decay set!" \
|
344 |
+
% (str(param_dict.keys() - union_params), )
|
345 |
+
|
346 |
+
# create the pytorch optimizer object
|
347 |
+
optim_groups = [
|
348 |
+
{"params": [param_dict[pn] for pn in sorted(list(decay))], "weight_decay": 0.01},
|
349 |
+
{"params": [param_dict[pn] for pn in sorted(list(no_decay))], "weight_decay": 0.0},
|
350 |
+
]
|
351 |
+
optimizer = torch.optim.AdamW(optim_groups, lr=self.learning_rate, betas=(0.9, 0.95))
|
352 |
+
return optimizer
|
taming-transformers/taming/models/dummy_cond_stage.py
ADDED
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from torch import Tensor
|
2 |
+
|
3 |
+
|
4 |
+
class DummyCondStage:
|
5 |
+
def __init__(self, conditional_key):
|
6 |
+
self.conditional_key = conditional_key
|
7 |
+
self.train = None
|
8 |
+
|
9 |
+
def eval(self):
|
10 |
+
return self
|
11 |
+
|
12 |
+
@staticmethod
|
13 |
+
def encode(c: Tensor):
|
14 |
+
return c, None, (None, None, c)
|
15 |
+
|
16 |
+
@staticmethod
|
17 |
+
def decode(c: Tensor):
|
18 |
+
return c
|
19 |
+
|
20 |
+
@staticmethod
|
21 |
+
def to_rgb(c: Tensor):
|
22 |
+
return c
|
taming-transformers/taming/models/vqgan.py
ADDED
@@ -0,0 +1,404 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn.functional as F
|
3 |
+
import pytorch_lightning as pl
|
4 |
+
|
5 |
+
from main import instantiate_from_config
|
6 |
+
|
7 |
+
from taming.modules.diffusionmodules.model import Encoder, Decoder
|
8 |
+
from taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer
|
9 |
+
from taming.modules.vqvae.quantize import GumbelQuantize
|
10 |
+
from taming.modules.vqvae.quantize import EMAVectorQuantizer
|
11 |
+
|
12 |
+
class VQModel(pl.LightningModule):
|
13 |
+
def __init__(self,
|
14 |
+
ddconfig,
|
15 |
+
lossconfig,
|
16 |
+
n_embed,
|
17 |
+
embed_dim,
|
18 |
+
ckpt_path=None,
|
19 |
+
ignore_keys=[],
|
20 |
+
image_key="image",
|
21 |
+
colorize_nlabels=None,
|
22 |
+
monitor=None,
|
23 |
+
remap=None,
|
24 |
+
sane_index_shape=False, # tell vector quantizer to return indices as bhw
|
25 |
+
):
|
26 |
+
super().__init__()
|
27 |
+
self.image_key = image_key
|
28 |
+
self.encoder = Encoder(**ddconfig)
|
29 |
+
self.decoder = Decoder(**ddconfig)
|
30 |
+
self.loss = instantiate_from_config(lossconfig)
|
31 |
+
self.quantize = VectorQuantizer(n_embed, embed_dim, beta=0.25,
|
32 |
+
remap=remap, sane_index_shape=sane_index_shape)
|
33 |
+
self.quant_conv = torch.nn.Conv2d(ddconfig["z_channels"], embed_dim, 1)
|
34 |
+
self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
|
35 |
+
if ckpt_path is not None:
|
36 |
+
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
|
37 |
+
self.image_key = image_key
|
38 |
+
if colorize_nlabels is not None:
|
39 |
+
assert type(colorize_nlabels)==int
|
40 |
+
self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
|
41 |
+
if monitor is not None:
|
42 |
+
self.monitor = monitor
|
43 |
+
|
44 |
+
def init_from_ckpt(self, path, ignore_keys=list()):
|
45 |
+
sd = torch.load(path, map_location="cpu")["state_dict"]
|
46 |
+
keys = list(sd.keys())
|
47 |
+
for k in keys:
|
48 |
+
for ik in ignore_keys:
|
49 |
+
if k.startswith(ik):
|
50 |
+
print("Deleting key {} from state_dict.".format(k))
|
51 |
+
del sd[k]
|
52 |
+
self.load_state_dict(sd, strict=False)
|
53 |
+
print(f"Restored from {path}")
|
54 |
+
|
55 |
+
def encode(self, x):
|
56 |
+
h = self.encoder(x)
|
57 |
+
h = self.quant_conv(h)
|
58 |
+
quant, emb_loss, info = self.quantize(h)
|
59 |
+
return quant, emb_loss, info
|
60 |
+
|
61 |
+
def decode(self, quant):
|
62 |
+
quant = self.post_quant_conv(quant)
|
63 |
+
dec = self.decoder(quant)
|
64 |
+
return dec
|
65 |
+
|
66 |
+
def decode_code(self, code_b):
|
67 |
+
quant_b = self.quantize.embed_code(code_b)
|
68 |
+
dec = self.decode(quant_b)
|
69 |
+
return dec
|
70 |
+
|
71 |
+
def forward(self, input):
|
72 |
+
quant, diff, _ = self.encode(input)
|
73 |
+
dec = self.decode(quant)
|
74 |
+
return dec, diff
|
75 |
+
|
76 |
+
def get_input(self, batch, k):
|
77 |
+
x = batch[k]
|
78 |
+
if len(x.shape) == 3:
|
79 |
+
x = x[..., None]
|
80 |
+
x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format)
|
81 |
+
return x.float()
|
82 |
+
|
83 |
+
def training_step(self, batch, batch_idx, optimizer_idx):
|
84 |
+
x = self.get_input(batch, self.image_key)
|
85 |
+
xrec, qloss = self(x)
|
86 |
+
|
87 |
+
if optimizer_idx == 0:
|
88 |
+
# autoencode
|
89 |
+
aeloss, log_dict_ae = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
|
90 |
+
last_layer=self.get_last_layer(), split="train")
|
91 |
+
|
92 |
+
self.log("train/aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
|
93 |
+
self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True)
|
94 |
+
return aeloss
|
95 |
+
|
96 |
+
if optimizer_idx == 1:
|
97 |
+
# discriminator
|
98 |
+
discloss, log_dict_disc = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
|
99 |
+
last_layer=self.get_last_layer(), split="train")
|
100 |
+
self.log("train/discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
|
101 |
+
self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True)
|
102 |
+
return discloss
|
103 |
+
|
104 |
+
def validation_step(self, batch, batch_idx):
|
105 |
+
x = self.get_input(batch, self.image_key)
|
106 |
+
xrec, qloss = self(x)
|
107 |
+
aeloss, log_dict_ae = self.loss(qloss, x, xrec, 0, self.global_step,
|
108 |
+
last_layer=self.get_last_layer(), split="val")
|
109 |
+
|
110 |
+
discloss, log_dict_disc = self.loss(qloss, x, xrec, 1, self.global_step,
|
111 |
+
last_layer=self.get_last_layer(), split="val")
|
112 |
+
rec_loss = log_dict_ae["val/rec_loss"]
|
113 |
+
self.log("val/rec_loss", rec_loss,
|
114 |
+
prog_bar=True, logger=True, on_step=True, on_epoch=True, sync_dist=True)
|
115 |
+
self.log("val/aeloss", aeloss,
|
116 |
+
prog_bar=True, logger=True, on_step=True, on_epoch=True, sync_dist=True)
|
117 |
+
self.log_dict(log_dict_ae)
|
118 |
+
self.log_dict(log_dict_disc)
|
119 |
+
return self.log_dict
|
120 |
+
|
121 |
+
def configure_optimizers(self):
|
122 |
+
lr = self.learning_rate
|
123 |
+
opt_ae = torch.optim.Adam(list(self.encoder.parameters())+
|
124 |
+
list(self.decoder.parameters())+
|
125 |
+
list(self.quantize.parameters())+
|
126 |
+
list(self.quant_conv.parameters())+
|
127 |
+
list(self.post_quant_conv.parameters()),
|
128 |
+
lr=lr, betas=(0.5, 0.9))
|
129 |
+
opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
|
130 |
+
lr=lr, betas=(0.5, 0.9))
|
131 |
+
return [opt_ae, opt_disc], []
|
132 |
+
|
133 |
+
def get_last_layer(self):
|
134 |
+
return self.decoder.conv_out.weight
|
135 |
+
|
136 |
+
def log_images(self, batch, **kwargs):
|
137 |
+
log = dict()
|
138 |
+
x = self.get_input(batch, self.image_key)
|
139 |
+
x = x.to(self.device)
|
140 |
+
xrec, _ = self(x)
|
141 |
+
if x.shape[1] > 3:
|
142 |
+
# colorize with random projection
|
143 |
+
assert xrec.shape[1] > 3
|
144 |
+
x = self.to_rgb(x)
|
145 |
+
xrec = self.to_rgb(xrec)
|
146 |
+
log["inputs"] = x
|
147 |
+
log["reconstructions"] = xrec
|
148 |
+
return log
|
149 |
+
|
150 |
+
def to_rgb(self, x):
|
151 |
+
assert self.image_key == "segmentation"
|
152 |
+
if not hasattr(self, "colorize"):
|
153 |
+
self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
|
154 |
+
x = F.conv2d(x, weight=self.colorize)
|
155 |
+
x = 2.*(x-x.min())/(x.max()-x.min()) - 1.
|
156 |
+
return x
|
157 |
+
|
158 |
+
|
159 |
+
class VQSegmentationModel(VQModel):
|
160 |
+
def __init__(self, n_labels, *args, **kwargs):
|
161 |
+
super().__init__(*args, **kwargs)
|
162 |
+
self.register_buffer("colorize", torch.randn(3, n_labels, 1, 1))
|
163 |
+
|
164 |
+
def configure_optimizers(self):
|
165 |
+
lr = self.learning_rate
|
166 |
+
opt_ae = torch.optim.Adam(list(self.encoder.parameters())+
|
167 |
+
list(self.decoder.parameters())+
|
168 |
+
list(self.quantize.parameters())+
|
169 |
+
list(self.quant_conv.parameters())+
|
170 |
+
list(self.post_quant_conv.parameters()),
|
171 |
+
lr=lr, betas=(0.5, 0.9))
|
172 |
+
return opt_ae
|
173 |
+
|
174 |
+
def training_step(self, batch, batch_idx):
|
175 |
+
x = self.get_input(batch, self.image_key)
|
176 |
+
xrec, qloss = self(x)
|
177 |
+
aeloss, log_dict_ae = self.loss(qloss, x, xrec, split="train")
|
178 |
+
self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True)
|
179 |
+
return aeloss
|
180 |
+
|
181 |
+
def validation_step(self, batch, batch_idx):
|
182 |
+
x = self.get_input(batch, self.image_key)
|
183 |
+
xrec, qloss = self(x)
|
184 |
+
aeloss, log_dict_ae = self.loss(qloss, x, xrec, split="val")
|
185 |
+
self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True)
|
186 |
+
total_loss = log_dict_ae["val/total_loss"]
|
187 |
+
self.log("val/total_loss", total_loss,
|
188 |
+
prog_bar=True, logger=True, on_step=True, on_epoch=True, sync_dist=True)
|
189 |
+
return aeloss
|
190 |
+
|
191 |
+
@torch.no_grad()
|
192 |
+
def log_images(self, batch, **kwargs):
|
193 |
+
log = dict()
|
194 |
+
x = self.get_input(batch, self.image_key)
|
195 |
+
x = x.to(self.device)
|
196 |
+
xrec, _ = self(x)
|
197 |
+
if x.shape[1] > 3:
|
198 |
+
# colorize with random projection
|
199 |
+
assert xrec.shape[1] > 3
|
200 |
+
# convert logits to indices
|
201 |
+
xrec = torch.argmax(xrec, dim=1, keepdim=True)
|
202 |
+
xrec = F.one_hot(xrec, num_classes=x.shape[1])
|
203 |
+
xrec = xrec.squeeze(1).permute(0, 3, 1, 2).float()
|
204 |
+
x = self.to_rgb(x)
|
205 |
+
xrec = self.to_rgb(xrec)
|
206 |
+
log["inputs"] = x
|
207 |
+
log["reconstructions"] = xrec
|
208 |
+
return log
|
209 |
+
|
210 |
+
|
211 |
+
class VQNoDiscModel(VQModel):
|
212 |
+
def __init__(self,
|
213 |
+
ddconfig,
|
214 |
+
lossconfig,
|
215 |
+
n_embed,
|
216 |
+
embed_dim,
|
217 |
+
ckpt_path=None,
|
218 |
+
ignore_keys=[],
|
219 |
+
image_key="image",
|
220 |
+
colorize_nlabels=None
|
221 |
+
):
|
222 |
+
super().__init__(ddconfig=ddconfig, lossconfig=lossconfig, n_embed=n_embed, embed_dim=embed_dim,
|
223 |
+
ckpt_path=ckpt_path, ignore_keys=ignore_keys, image_key=image_key,
|
224 |
+
colorize_nlabels=colorize_nlabels)
|
225 |
+
|
226 |
+
def training_step(self, batch, batch_idx):
|
227 |
+
x = self.get_input(batch, self.image_key)
|
228 |
+
xrec, qloss = self(x)
|
229 |
+
# autoencode
|
230 |
+
aeloss, log_dict_ae = self.loss(qloss, x, xrec, self.global_step, split="train")
|
231 |
+
output = pl.TrainResult(minimize=aeloss)
|
232 |
+
output.log("train/aeloss", aeloss,
|
233 |
+
prog_bar=True, logger=True, on_step=True, on_epoch=True)
|
234 |
+
output.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True)
|
235 |
+
return output
|
236 |
+
|
237 |
+
def validation_step(self, batch, batch_idx):
|
238 |
+
x = self.get_input(batch, self.image_key)
|
239 |
+
xrec, qloss = self(x)
|
240 |
+
aeloss, log_dict_ae = self.loss(qloss, x, xrec, self.global_step, split="val")
|
241 |
+
rec_loss = log_dict_ae["val/rec_loss"]
|
242 |
+
output = pl.EvalResult(checkpoint_on=rec_loss)
|
243 |
+
output.log("val/rec_loss", rec_loss,
|
244 |
+
prog_bar=True, logger=True, on_step=True, on_epoch=True)
|
245 |
+
output.log("val/aeloss", aeloss,
|
246 |
+
prog_bar=True, logger=True, on_step=True, on_epoch=True)
|
247 |
+
output.log_dict(log_dict_ae)
|
248 |
+
|
249 |
+
return output
|
250 |
+
|
251 |
+
def configure_optimizers(self):
|
252 |
+
optimizer = torch.optim.Adam(list(self.encoder.parameters())+
|
253 |
+
list(self.decoder.parameters())+
|
254 |
+
list(self.quantize.parameters())+
|
255 |
+
list(self.quant_conv.parameters())+
|
256 |
+
list(self.post_quant_conv.parameters()),
|
257 |
+
lr=self.learning_rate, betas=(0.5, 0.9))
|
258 |
+
return optimizer
|
259 |
+
|
260 |
+
|
261 |
+
class GumbelVQ(VQModel):
|
262 |
+
def __init__(self,
|
263 |
+
ddconfig,
|
264 |
+
lossconfig,
|
265 |
+
n_embed,
|
266 |
+
embed_dim,
|
267 |
+
temperature_scheduler_config,
|
268 |
+
ckpt_path=None,
|
269 |
+
ignore_keys=[],
|
270 |
+
image_key="image",
|
271 |
+
colorize_nlabels=None,
|
272 |
+
monitor=None,
|
273 |
+
kl_weight=1e-8,
|
274 |
+
remap=None,
|
275 |
+
):
|
276 |
+
|
277 |
+
z_channels = ddconfig["z_channels"]
|
278 |
+
super().__init__(ddconfig,
|
279 |
+
lossconfig,
|
280 |
+
n_embed,
|
281 |
+
embed_dim,
|
282 |
+
ckpt_path=None,
|
283 |
+
ignore_keys=ignore_keys,
|
284 |
+
image_key=image_key,
|
285 |
+
colorize_nlabels=colorize_nlabels,
|
286 |
+
monitor=monitor,
|
287 |
+
)
|
288 |
+
|
289 |
+
self.loss.n_classes = n_embed
|
290 |
+
self.vocab_size = n_embed
|
291 |
+
|
292 |
+
self.quantize = GumbelQuantize(z_channels, embed_dim,
|
293 |
+
n_embed=n_embed,
|
294 |
+
kl_weight=kl_weight, temp_init=1.0,
|
295 |
+
remap=remap)
|
296 |
+
|
297 |
+
self.temperature_scheduler = instantiate_from_config(temperature_scheduler_config) # annealing of temp
|
298 |
+
|
299 |
+
if ckpt_path is not None:
|
300 |
+
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
|
301 |
+
|
302 |
+
def temperature_scheduling(self):
|
303 |
+
self.quantize.temperature = self.temperature_scheduler(self.global_step)
|
304 |
+
|
305 |
+
def encode_to_prequant(self, x):
|
306 |
+
h = self.encoder(x)
|
307 |
+
h = self.quant_conv(h)
|
308 |
+
return h
|
309 |
+
|
310 |
+
def decode_code(self, code_b):
|
311 |
+
raise NotImplementedError
|
312 |
+
|
313 |
+
def training_step(self, batch, batch_idx, optimizer_idx):
|
314 |
+
self.temperature_scheduling()
|
315 |
+
x = self.get_input(batch, self.image_key)
|
316 |
+
xrec, qloss = self(x)
|
317 |
+
|
318 |
+
if optimizer_idx == 0:
|
319 |
+
# autoencode
|
320 |
+
aeloss, log_dict_ae = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
|
321 |
+
last_layer=self.get_last_layer(), split="train")
|
322 |
+
|
323 |
+
self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True)
|
324 |
+
self.log("temperature", self.quantize.temperature, prog_bar=False, logger=True, on_step=True, on_epoch=True)
|
325 |
+
return aeloss
|
326 |
+
|
327 |
+
if optimizer_idx == 1:
|
328 |
+
# discriminator
|
329 |
+
discloss, log_dict_disc = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
|
330 |
+
last_layer=self.get_last_layer(), split="train")
|
331 |
+
self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True)
|
332 |
+
return discloss
|
333 |
+
|
334 |
+
def validation_step(self, batch, batch_idx):
|
335 |
+
x = self.get_input(batch, self.image_key)
|
336 |
+
xrec, qloss = self(x, return_pred_indices=True)
|
337 |
+
aeloss, log_dict_ae = self.loss(qloss, x, xrec, 0, self.global_step,
|
338 |
+
last_layer=self.get_last_layer(), split="val")
|
339 |
+
|
340 |
+
discloss, log_dict_disc = self.loss(qloss, x, xrec, 1, self.global_step,
|
341 |
+
last_layer=self.get_last_layer(), split="val")
|
342 |
+
rec_loss = log_dict_ae["val/rec_loss"]
|
343 |
+
self.log("val/rec_loss", rec_loss,
|
344 |
+
prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True)
|
345 |
+
self.log("val/aeloss", aeloss,
|
346 |
+
prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True)
|
347 |
+
self.log_dict(log_dict_ae)
|
348 |
+
self.log_dict(log_dict_disc)
|
349 |
+
return self.log_dict
|
350 |
+
|
351 |
+
def log_images(self, batch, **kwargs):
|
352 |
+
log = dict()
|
353 |
+
x = self.get_input(batch, self.image_key)
|
354 |
+
x = x.to(self.device)
|
355 |
+
# encode
|
356 |
+
h = self.encoder(x)
|
357 |
+
h = self.quant_conv(h)
|
358 |
+
quant, _, _ = self.quantize(h)
|
359 |
+
# decode
|
360 |
+
x_rec = self.decode(quant)
|
361 |
+
log["inputs"] = x
|
362 |
+
log["reconstructions"] = x_rec
|
363 |
+
return log
|
364 |
+
|
365 |
+
|
366 |
+
class EMAVQ(VQModel):
|
367 |
+
def __init__(self,
|
368 |
+
ddconfig,
|
369 |
+
lossconfig,
|
370 |
+
n_embed,
|
371 |
+
embed_dim,
|
372 |
+
ckpt_path=None,
|
373 |
+
ignore_keys=[],
|
374 |
+
image_key="image",
|
375 |
+
colorize_nlabels=None,
|
376 |
+
monitor=None,
|
377 |
+
remap=None,
|
378 |
+
sane_index_shape=False, # tell vector quantizer to return indices as bhw
|
379 |
+
):
|
380 |
+
super().__init__(ddconfig,
|
381 |
+
lossconfig,
|
382 |
+
n_embed,
|
383 |
+
embed_dim,
|
384 |
+
ckpt_path=None,
|
385 |
+
ignore_keys=ignore_keys,
|
386 |
+
image_key=image_key,
|
387 |
+
colorize_nlabels=colorize_nlabels,
|
388 |
+
monitor=monitor,
|
389 |
+
)
|
390 |
+
self.quantize = EMAVectorQuantizer(n_embed=n_embed,
|
391 |
+
embedding_dim=embed_dim,
|
392 |
+
beta=0.25,
|
393 |
+
remap=remap)
|
394 |
+
def configure_optimizers(self):
|
395 |
+
lr = self.learning_rate
|
396 |
+
#Remove self.quantize from parameter list since it is updated via EMA
|
397 |
+
opt_ae = torch.optim.Adam(list(self.encoder.parameters())+
|
398 |
+
list(self.decoder.parameters())+
|
399 |
+
list(self.quant_conv.parameters())+
|
400 |
+
list(self.post_quant_conv.parameters()),
|
401 |
+
lr=lr, betas=(0.5, 0.9))
|
402 |
+
opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
|
403 |
+
lr=lr, betas=(0.5, 0.9))
|
404 |
+
return [opt_ae, opt_disc], []
|
taming-transformers/taming/modules/diffusionmodules/model.py
ADDED
@@ -0,0 +1,776 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# pytorch_diffusion + derived encoder decoder
|
2 |
+
import math
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
import numpy as np
|
6 |
+
|
7 |
+
|
8 |
+
def get_timestep_embedding(timesteps, embedding_dim):
|
9 |
+
"""
|
10 |
+
This matches the implementation in Denoising Diffusion Probabilistic Models:
|
11 |
+
From Fairseq.
|
12 |
+
Build sinusoidal embeddings.
|
13 |
+
This matches the implementation in tensor2tensor, but differs slightly
|
14 |
+
from the description in Section 3.5 of "Attention Is All You Need".
|
15 |
+
"""
|
16 |
+
assert len(timesteps.shape) == 1
|
17 |
+
|
18 |
+
half_dim = embedding_dim // 2
|
19 |
+
emb = math.log(10000) / (half_dim - 1)
|
20 |
+
emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
|
21 |
+
emb = emb.to(device=timesteps.device)
|
22 |
+
emb = timesteps.float()[:, None] * emb[None, :]
|
23 |
+
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
|
24 |
+
if embedding_dim % 2 == 1: # zero pad
|
25 |
+
emb = torch.nn.functional.pad(emb, (0,1,0,0))
|
26 |
+
return emb
|
27 |
+
|
28 |
+
|
29 |
+
def nonlinearity(x):
|
30 |
+
# swish
|
31 |
+
return x*torch.sigmoid(x)
|
32 |
+
|
33 |
+
|
34 |
+
def Normalize(in_channels):
|
35 |
+
return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
|
36 |
+
|
37 |
+
|
38 |
+
class Upsample(nn.Module):
|
39 |
+
def __init__(self, in_channels, with_conv):
|
40 |
+
super().__init__()
|
41 |
+
self.with_conv = with_conv
|
42 |
+
if self.with_conv:
|
43 |
+
self.conv = torch.nn.Conv2d(in_channels,
|
44 |
+
in_channels,
|
45 |
+
kernel_size=3,
|
46 |
+
stride=1,
|
47 |
+
padding=1)
|
48 |
+
|
49 |
+
def forward(self, x):
|
50 |
+
x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
|
51 |
+
if self.with_conv:
|
52 |
+
x = self.conv(x)
|
53 |
+
return x
|
54 |
+
|
55 |
+
|
56 |
+
class Downsample(nn.Module):
|
57 |
+
def __init__(self, in_channels, with_conv):
|
58 |
+
super().__init__()
|
59 |
+
self.with_conv = with_conv
|
60 |
+
if self.with_conv:
|
61 |
+
# no asymmetric padding in torch conv, must do it ourselves
|
62 |
+
self.conv = torch.nn.Conv2d(in_channels,
|
63 |
+
in_channels,
|
64 |
+
kernel_size=3,
|
65 |
+
stride=2,
|
66 |
+
padding=0)
|
67 |
+
|
68 |
+
def forward(self, x):
|
69 |
+
if self.with_conv:
|
70 |
+
pad = (0,1,0,1)
|
71 |
+
x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
|
72 |
+
x = self.conv(x)
|
73 |
+
else:
|
74 |
+
x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
|
75 |
+
return x
|
76 |
+
|
77 |
+
|
78 |
+
class ResnetBlock(nn.Module):
|
79 |
+
def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False,
|
80 |
+
dropout, temb_channels=512):
|
81 |
+
super().__init__()
|
82 |
+
self.in_channels = in_channels
|
83 |
+
out_channels = in_channels if out_channels is None else out_channels
|
84 |
+
self.out_channels = out_channels
|
85 |
+
self.use_conv_shortcut = conv_shortcut
|
86 |
+
|
87 |
+
self.norm1 = Normalize(in_channels)
|
88 |
+
self.conv1 = torch.nn.Conv2d(in_channels,
|
89 |
+
out_channels,
|
90 |
+
kernel_size=3,
|
91 |
+
stride=1,
|
92 |
+
padding=1)
|
93 |
+
if temb_channels > 0:
|
94 |
+
self.temb_proj = torch.nn.Linear(temb_channels,
|
95 |
+
out_channels)
|
96 |
+
self.norm2 = Normalize(out_channels)
|
97 |
+
self.dropout = torch.nn.Dropout(dropout)
|
98 |
+
self.conv2 = torch.nn.Conv2d(out_channels,
|
99 |
+
out_channels,
|
100 |
+
kernel_size=3,
|
101 |
+
stride=1,
|
102 |
+
padding=1)
|
103 |
+
if self.in_channels != self.out_channels:
|
104 |
+
if self.use_conv_shortcut:
|
105 |
+
self.conv_shortcut = torch.nn.Conv2d(in_channels,
|
106 |
+
out_channels,
|
107 |
+
kernel_size=3,
|
108 |
+
stride=1,
|
109 |
+
padding=1)
|
110 |
+
else:
|
111 |
+
self.nin_shortcut = torch.nn.Conv2d(in_channels,
|
112 |
+
out_channels,
|
113 |
+
kernel_size=1,
|
114 |
+
stride=1,
|
115 |
+
padding=0)
|
116 |
+
|
117 |
+
def forward(self, x, temb):
|
118 |
+
h = x
|
119 |
+
h = self.norm1(h)
|
120 |
+
h = nonlinearity(h)
|
121 |
+
h = self.conv1(h)
|
122 |
+
|
123 |
+
if temb is not None:
|
124 |
+
h = h + self.temb_proj(nonlinearity(temb))[:,:,None,None]
|
125 |
+
|
126 |
+
h = self.norm2(h)
|
127 |
+
h = nonlinearity(h)
|
128 |
+
h = self.dropout(h)
|
129 |
+
h = self.conv2(h)
|
130 |
+
|
131 |
+
if self.in_channels != self.out_channels:
|
132 |
+
if self.use_conv_shortcut:
|
133 |
+
x = self.conv_shortcut(x)
|
134 |
+
else:
|
135 |
+
x = self.nin_shortcut(x)
|
136 |
+
|
137 |
+
return x+h
|
138 |
+
|
139 |
+
|
140 |
+
class AttnBlock(nn.Module):
|
141 |
+
def __init__(self, in_channels):
|
142 |
+
super().__init__()
|
143 |
+
self.in_channels = in_channels
|
144 |
+
|
145 |
+
self.norm = Normalize(in_channels)
|
146 |
+
self.q = torch.nn.Conv2d(in_channels,
|
147 |
+
in_channels,
|
148 |
+
kernel_size=1,
|
149 |
+
stride=1,
|
150 |
+
padding=0)
|
151 |
+
self.k = torch.nn.Conv2d(in_channels,
|
152 |
+
in_channels,
|
153 |
+
kernel_size=1,
|
154 |
+
stride=1,
|
155 |
+
padding=0)
|
156 |
+
self.v = torch.nn.Conv2d(in_channels,
|
157 |
+
in_channels,
|
158 |
+
kernel_size=1,
|
159 |
+
stride=1,
|
160 |
+
padding=0)
|
161 |
+
self.proj_out = torch.nn.Conv2d(in_channels,
|
162 |
+
in_channels,
|
163 |
+
kernel_size=1,
|
164 |
+
stride=1,
|
165 |
+
padding=0)
|
166 |
+
|
167 |
+
|
168 |
+
def forward(self, x):
|
169 |
+
h_ = x
|
170 |
+
h_ = self.norm(h_)
|
171 |
+
q = self.q(h_)
|
172 |
+
k = self.k(h_)
|
173 |
+
v = self.v(h_)
|
174 |
+
|
175 |
+
# compute attention
|
176 |
+
b,c,h,w = q.shape
|
177 |
+
q = q.reshape(b,c,h*w)
|
178 |
+
q = q.permute(0,2,1) # b,hw,c
|
179 |
+
k = k.reshape(b,c,h*w) # b,c,hw
|
180 |
+
w_ = torch.bmm(q,k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
|
181 |
+
w_ = w_ * (int(c)**(-0.5))
|
182 |
+
w_ = torch.nn.functional.softmax(w_, dim=2)
|
183 |
+
|
184 |
+
# attend to values
|
185 |
+
v = v.reshape(b,c,h*w)
|
186 |
+
w_ = w_.permute(0,2,1) # b,hw,hw (first hw of k, second of q)
|
187 |
+
h_ = torch.bmm(v,w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
|
188 |
+
h_ = h_.reshape(b,c,h,w)
|
189 |
+
|
190 |
+
h_ = self.proj_out(h_)
|
191 |
+
|
192 |
+
return x+h_
|
193 |
+
|
194 |
+
|
195 |
+
class Model(nn.Module):
|
196 |
+
def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
|
197 |
+
attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
|
198 |
+
resolution, use_timestep=True):
|
199 |
+
super().__init__()
|
200 |
+
self.ch = ch
|
201 |
+
self.temb_ch = self.ch*4
|
202 |
+
self.num_resolutions = len(ch_mult)
|
203 |
+
self.num_res_blocks = num_res_blocks
|
204 |
+
self.resolution = resolution
|
205 |
+
self.in_channels = in_channels
|
206 |
+
|
207 |
+
self.use_timestep = use_timestep
|
208 |
+
if self.use_timestep:
|
209 |
+
# timestep embedding
|
210 |
+
self.temb = nn.Module()
|
211 |
+
self.temb.dense = nn.ModuleList([
|
212 |
+
torch.nn.Linear(self.ch,
|
213 |
+
self.temb_ch),
|
214 |
+
torch.nn.Linear(self.temb_ch,
|
215 |
+
self.temb_ch),
|
216 |
+
])
|
217 |
+
|
218 |
+
# downsampling
|
219 |
+
self.conv_in = torch.nn.Conv2d(in_channels,
|
220 |
+
self.ch,
|
221 |
+
kernel_size=3,
|
222 |
+
stride=1,
|
223 |
+
padding=1)
|
224 |
+
|
225 |
+
curr_res = resolution
|
226 |
+
in_ch_mult = (1,)+tuple(ch_mult)
|
227 |
+
self.down = nn.ModuleList()
|
228 |
+
for i_level in range(self.num_resolutions):
|
229 |
+
block = nn.ModuleList()
|
230 |
+
attn = nn.ModuleList()
|
231 |
+
block_in = ch*in_ch_mult[i_level]
|
232 |
+
block_out = ch*ch_mult[i_level]
|
233 |
+
for i_block in range(self.num_res_blocks):
|
234 |
+
block.append(ResnetBlock(in_channels=block_in,
|
235 |
+
out_channels=block_out,
|
236 |
+
temb_channels=self.temb_ch,
|
237 |
+
dropout=dropout))
|
238 |
+
block_in = block_out
|
239 |
+
if curr_res in attn_resolutions:
|
240 |
+
attn.append(AttnBlock(block_in))
|
241 |
+
down = nn.Module()
|
242 |
+
down.block = block
|
243 |
+
down.attn = attn
|
244 |
+
if i_level != self.num_resolutions-1:
|
245 |
+
down.downsample = Downsample(block_in, resamp_with_conv)
|
246 |
+
curr_res = curr_res // 2
|
247 |
+
self.down.append(down)
|
248 |
+
|
249 |
+
# middle
|
250 |
+
self.mid = nn.Module()
|
251 |
+
self.mid.block_1 = ResnetBlock(in_channels=block_in,
|
252 |
+
out_channels=block_in,
|
253 |
+
temb_channels=self.temb_ch,
|
254 |
+
dropout=dropout)
|
255 |
+
self.mid.attn_1 = AttnBlock(block_in)
|
256 |
+
self.mid.block_2 = ResnetBlock(in_channels=block_in,
|
257 |
+
out_channels=block_in,
|
258 |
+
temb_channels=self.temb_ch,
|
259 |
+
dropout=dropout)
|
260 |
+
|
261 |
+
# upsampling
|
262 |
+
self.up = nn.ModuleList()
|
263 |
+
for i_level in reversed(range(self.num_resolutions)):
|
264 |
+
block = nn.ModuleList()
|
265 |
+
attn = nn.ModuleList()
|
266 |
+
block_out = ch*ch_mult[i_level]
|
267 |
+
skip_in = ch*ch_mult[i_level]
|
268 |
+
for i_block in range(self.num_res_blocks+1):
|
269 |
+
if i_block == self.num_res_blocks:
|
270 |
+
skip_in = ch*in_ch_mult[i_level]
|
271 |
+
block.append(ResnetBlock(in_channels=block_in+skip_in,
|
272 |
+
out_channels=block_out,
|
273 |
+
temb_channels=self.temb_ch,
|
274 |
+
dropout=dropout))
|
275 |
+
block_in = block_out
|
276 |
+
if curr_res in attn_resolutions:
|
277 |
+
attn.append(AttnBlock(block_in))
|
278 |
+
up = nn.Module()
|
279 |
+
up.block = block
|
280 |
+
up.attn = attn
|
281 |
+
if i_level != 0:
|
282 |
+
up.upsample = Upsample(block_in, resamp_with_conv)
|
283 |
+
curr_res = curr_res * 2
|
284 |
+
self.up.insert(0, up) # prepend to get consistent order
|
285 |
+
|
286 |
+
# end
|
287 |
+
self.norm_out = Normalize(block_in)
|
288 |
+
self.conv_out = torch.nn.Conv2d(block_in,
|
289 |
+
out_ch,
|
290 |
+
kernel_size=3,
|
291 |
+
stride=1,
|
292 |
+
padding=1)
|
293 |
+
|
294 |
+
|
295 |
+
def forward(self, x, t=None):
|
296 |
+
#assert x.shape[2] == x.shape[3] == self.resolution
|
297 |
+
|
298 |
+
if self.use_timestep:
|
299 |
+
# timestep embedding
|
300 |
+
assert t is not None
|
301 |
+
temb = get_timestep_embedding(t, self.ch)
|
302 |
+
temb = self.temb.dense[0](temb)
|
303 |
+
temb = nonlinearity(temb)
|
304 |
+
temb = self.temb.dense[1](temb)
|
305 |
+
else:
|
306 |
+
temb = None
|
307 |
+
|
308 |
+
# downsampling
|
309 |
+
hs = [self.conv_in(x)]
|
310 |
+
for i_level in range(self.num_resolutions):
|
311 |
+
for i_block in range(self.num_res_blocks):
|
312 |
+
h = self.down[i_level].block[i_block](hs[-1], temb)
|
313 |
+
if len(self.down[i_level].attn) > 0:
|
314 |
+
h = self.down[i_level].attn[i_block](h)
|
315 |
+
hs.append(h)
|
316 |
+
if i_level != self.num_resolutions-1:
|
317 |
+
hs.append(self.down[i_level].downsample(hs[-1]))
|
318 |
+
|
319 |
+
# middle
|
320 |
+
h = hs[-1]
|
321 |
+
h = self.mid.block_1(h, temb)
|
322 |
+
h = self.mid.attn_1(h)
|
323 |
+
h = self.mid.block_2(h, temb)
|
324 |
+
|
325 |
+
# upsampling
|
326 |
+
for i_level in reversed(range(self.num_resolutions)):
|
327 |
+
for i_block in range(self.num_res_blocks+1):
|
328 |
+
h = self.up[i_level].block[i_block](
|
329 |
+
torch.cat([h, hs.pop()], dim=1), temb)
|
330 |
+
if len(self.up[i_level].attn) > 0:
|
331 |
+
h = self.up[i_level].attn[i_block](h)
|
332 |
+
if i_level != 0:
|
333 |
+
h = self.up[i_level].upsample(h)
|
334 |
+
|
335 |
+
# end
|
336 |
+
h = self.norm_out(h)
|
337 |
+
h = nonlinearity(h)
|
338 |
+
h = self.conv_out(h)
|
339 |
+
return h
|
340 |
+
|
341 |
+
|
342 |
+
class Encoder(nn.Module):
|
343 |
+
def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
|
344 |
+
attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
|
345 |
+
resolution, z_channels, double_z=True, **ignore_kwargs):
|
346 |
+
super().__init__()
|
347 |
+
self.ch = ch
|
348 |
+
self.temb_ch = 0
|
349 |
+
self.num_resolutions = len(ch_mult)
|
350 |
+
self.num_res_blocks = num_res_blocks
|
351 |
+
self.resolution = resolution
|
352 |
+
self.in_channels = in_channels
|
353 |
+
|
354 |
+
# downsampling
|
355 |
+
self.conv_in = torch.nn.Conv2d(in_channels,
|
356 |
+
self.ch,
|
357 |
+
kernel_size=3,
|
358 |
+
stride=1,
|
359 |
+
padding=1)
|
360 |
+
|
361 |
+
curr_res = resolution
|
362 |
+
in_ch_mult = (1,)+tuple(ch_mult)
|
363 |
+
self.down = nn.ModuleList()
|
364 |
+
for i_level in range(self.num_resolutions):
|
365 |
+
block = nn.ModuleList()
|
366 |
+
attn = nn.ModuleList()
|
367 |
+
block_in = ch*in_ch_mult[i_level]
|
368 |
+
block_out = ch*ch_mult[i_level]
|
369 |
+
for i_block in range(self.num_res_blocks):
|
370 |
+
block.append(ResnetBlock(in_channels=block_in,
|
371 |
+
out_channels=block_out,
|
372 |
+
temb_channels=self.temb_ch,
|
373 |
+
dropout=dropout))
|
374 |
+
block_in = block_out
|
375 |
+
if curr_res in attn_resolutions:
|
376 |
+
attn.append(AttnBlock(block_in))
|
377 |
+
down = nn.Module()
|
378 |
+
down.block = block
|
379 |
+
down.attn = attn
|
380 |
+
if i_level != self.num_resolutions-1:
|
381 |
+
down.downsample = Downsample(block_in, resamp_with_conv)
|
382 |
+
curr_res = curr_res // 2
|
383 |
+
self.down.append(down)
|
384 |
+
|
385 |
+
# middle
|
386 |
+
self.mid = nn.Module()
|
387 |
+
self.mid.block_1 = ResnetBlock(in_channels=block_in,
|
388 |
+
out_channels=block_in,
|
389 |
+
temb_channels=self.temb_ch,
|
390 |
+
dropout=dropout)
|
391 |
+
self.mid.attn_1 = AttnBlock(block_in)
|
392 |
+
self.mid.block_2 = ResnetBlock(in_channels=block_in,
|
393 |
+
out_channels=block_in,
|
394 |
+
temb_channels=self.temb_ch,
|
395 |
+
dropout=dropout)
|
396 |
+
|
397 |
+
# end
|
398 |
+
self.norm_out = Normalize(block_in)
|
399 |
+
self.conv_out = torch.nn.Conv2d(block_in,
|
400 |
+
2*z_channels if double_z else z_channels,
|
401 |
+
kernel_size=3,
|
402 |
+
stride=1,
|
403 |
+
padding=1)
|
404 |
+
|
405 |
+
|
406 |
+
def forward(self, x):
|
407 |
+
#assert x.shape[2] == x.shape[3] == self.resolution, "{}, {}, {}".format(x.shape[2], x.shape[3], self.resolution)
|
408 |
+
|
409 |
+
# timestep embedding
|
410 |
+
temb = None
|
411 |
+
|
412 |
+
# downsampling
|
413 |
+
hs = [self.conv_in(x)]
|
414 |
+
for i_level in range(self.num_resolutions):
|
415 |
+
for i_block in range(self.num_res_blocks):
|
416 |
+
h = self.down[i_level].block[i_block](hs[-1], temb)
|
417 |
+
if len(self.down[i_level].attn) > 0:
|
418 |
+
h = self.down[i_level].attn[i_block](h)
|
419 |
+
hs.append(h)
|
420 |
+
if i_level != self.num_resolutions-1:
|
421 |
+
hs.append(self.down[i_level].downsample(hs[-1]))
|
422 |
+
|
423 |
+
# middle
|
424 |
+
h = hs[-1]
|
425 |
+
h = self.mid.block_1(h, temb)
|
426 |
+
h = self.mid.attn_1(h)
|
427 |
+
h = self.mid.block_2(h, temb)
|
428 |
+
|
429 |
+
# end
|
430 |
+
h = self.norm_out(h)
|
431 |
+
h = nonlinearity(h)
|
432 |
+
h = self.conv_out(h)
|
433 |
+
return h
|
434 |
+
|
435 |
+
|
436 |
+
class Decoder(nn.Module):
|
437 |
+
def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
|
438 |
+
attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
|
439 |
+
resolution, z_channels, give_pre_end=False, **ignorekwargs):
|
440 |
+
super().__init__()
|
441 |
+
self.ch = ch
|
442 |
+
self.temb_ch = 0
|
443 |
+
self.num_resolutions = len(ch_mult)
|
444 |
+
self.num_res_blocks = num_res_blocks
|
445 |
+
self.resolution = resolution
|
446 |
+
self.in_channels = in_channels
|
447 |
+
self.give_pre_end = give_pre_end
|
448 |
+
|
449 |
+
# compute in_ch_mult, block_in and curr_res at lowest res
|
450 |
+
in_ch_mult = (1,)+tuple(ch_mult)
|
451 |
+
block_in = ch*ch_mult[self.num_resolutions-1]
|
452 |
+
curr_res = resolution // 2**(self.num_resolutions-1)
|
453 |
+
self.z_shape = (1,z_channels,curr_res,curr_res)
|
454 |
+
print("Working with z of shape {} = {} dimensions.".format(
|
455 |
+
self.z_shape, np.prod(self.z_shape)))
|
456 |
+
|
457 |
+
# z to block_in
|
458 |
+
self.conv_in = torch.nn.Conv2d(z_channels,
|
459 |
+
block_in,
|
460 |
+
kernel_size=3,
|
461 |
+
stride=1,
|
462 |
+
padding=1)
|
463 |
+
|
464 |
+
# middle
|
465 |
+
self.mid = nn.Module()
|
466 |
+
self.mid.block_1 = ResnetBlock(in_channels=block_in,
|
467 |
+
out_channels=block_in,
|
468 |
+
temb_channels=self.temb_ch,
|
469 |
+
dropout=dropout)
|
470 |
+
self.mid.attn_1 = AttnBlock(block_in)
|
471 |
+
self.mid.block_2 = ResnetBlock(in_channels=block_in,
|
472 |
+
out_channels=block_in,
|
473 |
+
temb_channels=self.temb_ch,
|
474 |
+
dropout=dropout)
|
475 |
+
|
476 |
+
# upsampling
|
477 |
+
self.up = nn.ModuleList()
|
478 |
+
for i_level in reversed(range(self.num_resolutions)):
|
479 |
+
block = nn.ModuleList()
|
480 |
+
attn = nn.ModuleList()
|
481 |
+
block_out = ch*ch_mult[i_level]
|
482 |
+
for i_block in range(self.num_res_blocks+1):
|
483 |
+
block.append(ResnetBlock(in_channels=block_in,
|
484 |
+
out_channels=block_out,
|
485 |
+
temb_channels=self.temb_ch,
|
486 |
+
dropout=dropout))
|
487 |
+
block_in = block_out
|
488 |
+
if curr_res in attn_resolutions:
|
489 |
+
attn.append(AttnBlock(block_in))
|
490 |
+
up = nn.Module()
|
491 |
+
up.block = block
|
492 |
+
up.attn = attn
|
493 |
+
if i_level != 0:
|
494 |
+
up.upsample = Upsample(block_in, resamp_with_conv)
|
495 |
+
curr_res = curr_res * 2
|
496 |
+
self.up.insert(0, up) # prepend to get consistent order
|
497 |
+
|
498 |
+
# end
|
499 |
+
self.norm_out = Normalize(block_in)
|
500 |
+
self.conv_out = torch.nn.Conv2d(block_in,
|
501 |
+
out_ch,
|
502 |
+
kernel_size=3,
|
503 |
+
stride=1,
|
504 |
+
padding=1)
|
505 |
+
|
506 |
+
def forward(self, z):
|
507 |
+
#assert z.shape[1:] == self.z_shape[1:]
|
508 |
+
self.last_z_shape = z.shape
|
509 |
+
|
510 |
+
# timestep embedding
|
511 |
+
temb = None
|
512 |
+
|
513 |
+
# z to block_in
|
514 |
+
h = self.conv_in(z)
|
515 |
+
|
516 |
+
# middle
|
517 |
+
h = self.mid.block_1(h, temb)
|
518 |
+
h = self.mid.attn_1(h)
|
519 |
+
h = self.mid.block_2(h, temb)
|
520 |
+
|
521 |
+
# upsampling
|
522 |
+
for i_level in reversed(range(self.num_resolutions)):
|
523 |
+
for i_block in range(self.num_res_blocks+1):
|
524 |
+
h = self.up[i_level].block[i_block](h, temb)
|
525 |
+
if len(self.up[i_level].attn) > 0:
|
526 |
+
h = self.up[i_level].attn[i_block](h)
|
527 |
+
if i_level != 0:
|
528 |
+
h = self.up[i_level].upsample(h)
|
529 |
+
|
530 |
+
# end
|
531 |
+
if self.give_pre_end:
|
532 |
+
return h
|
533 |
+
|
534 |
+
h = self.norm_out(h)
|
535 |
+
h = nonlinearity(h)
|
536 |
+
h = self.conv_out(h)
|
537 |
+
return h
|
538 |
+
|
539 |
+
|
540 |
+
class VUNet(nn.Module):
|
541 |
+
def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
|
542 |
+
attn_resolutions, dropout=0.0, resamp_with_conv=True,
|
543 |
+
in_channels, c_channels,
|
544 |
+
resolution, z_channels, use_timestep=False, **ignore_kwargs):
|
545 |
+
super().__init__()
|
546 |
+
self.ch = ch
|
547 |
+
self.temb_ch = self.ch*4
|
548 |
+
self.num_resolutions = len(ch_mult)
|
549 |
+
self.num_res_blocks = num_res_blocks
|
550 |
+
self.resolution = resolution
|
551 |
+
|
552 |
+
self.use_timestep = use_timestep
|
553 |
+
if self.use_timestep:
|
554 |
+
# timestep embedding
|
555 |
+
self.temb = nn.Module()
|
556 |
+
self.temb.dense = nn.ModuleList([
|
557 |
+
torch.nn.Linear(self.ch,
|
558 |
+
self.temb_ch),
|
559 |
+
torch.nn.Linear(self.temb_ch,
|
560 |
+
self.temb_ch),
|
561 |
+
])
|
562 |
+
|
563 |
+
# downsampling
|
564 |
+
self.conv_in = torch.nn.Conv2d(c_channels,
|
565 |
+
self.ch,
|
566 |
+
kernel_size=3,
|
567 |
+
stride=1,
|
568 |
+
padding=1)
|
569 |
+
|
570 |
+
curr_res = resolution
|
571 |
+
in_ch_mult = (1,)+tuple(ch_mult)
|
572 |
+
self.down = nn.ModuleList()
|
573 |
+
for i_level in range(self.num_resolutions):
|
574 |
+
block = nn.ModuleList()
|
575 |
+
attn = nn.ModuleList()
|
576 |
+
block_in = ch*in_ch_mult[i_level]
|
577 |
+
block_out = ch*ch_mult[i_level]
|
578 |
+
for i_block in range(self.num_res_blocks):
|
579 |
+
block.append(ResnetBlock(in_channels=block_in,
|
580 |
+
out_channels=block_out,
|
581 |
+
temb_channels=self.temb_ch,
|
582 |
+
dropout=dropout))
|
583 |
+
block_in = block_out
|
584 |
+
if curr_res in attn_resolutions:
|
585 |
+
attn.append(AttnBlock(block_in))
|
586 |
+
down = nn.Module()
|
587 |
+
down.block = block
|
588 |
+
down.attn = attn
|
589 |
+
if i_level != self.num_resolutions-1:
|
590 |
+
down.downsample = Downsample(block_in, resamp_with_conv)
|
591 |
+
curr_res = curr_res // 2
|
592 |
+
self.down.append(down)
|
593 |
+
|
594 |
+
self.z_in = torch.nn.Conv2d(z_channels,
|
595 |
+
block_in,
|
596 |
+
kernel_size=1,
|
597 |
+
stride=1,
|
598 |
+
padding=0)
|
599 |
+
# middle
|
600 |
+
self.mid = nn.Module()
|
601 |
+
self.mid.block_1 = ResnetBlock(in_channels=2*block_in,
|
602 |
+
out_channels=block_in,
|
603 |
+
temb_channels=self.temb_ch,
|
604 |
+
dropout=dropout)
|
605 |
+
self.mid.attn_1 = AttnBlock(block_in)
|
606 |
+
self.mid.block_2 = ResnetBlock(in_channels=block_in,
|
607 |
+
out_channels=block_in,
|
608 |
+
temb_channels=self.temb_ch,
|
609 |
+
dropout=dropout)
|
610 |
+
|
611 |
+
# upsampling
|
612 |
+
self.up = nn.ModuleList()
|
613 |
+
for i_level in reversed(range(self.num_resolutions)):
|
614 |
+
block = nn.ModuleList()
|
615 |
+
attn = nn.ModuleList()
|
616 |
+
block_out = ch*ch_mult[i_level]
|
617 |
+
skip_in = ch*ch_mult[i_level]
|
618 |
+
for i_block in range(self.num_res_blocks+1):
|
619 |
+
if i_block == self.num_res_blocks:
|
620 |
+
skip_in = ch*in_ch_mult[i_level]
|
621 |
+
block.append(ResnetBlock(in_channels=block_in+skip_in,
|
622 |
+
out_channels=block_out,
|
623 |
+
temb_channels=self.temb_ch,
|
624 |
+
dropout=dropout))
|
625 |
+
block_in = block_out
|
626 |
+
if curr_res in attn_resolutions:
|
627 |
+
attn.append(AttnBlock(block_in))
|
628 |
+
up = nn.Module()
|
629 |
+
up.block = block
|
630 |
+
up.attn = attn
|
631 |
+
if i_level != 0:
|
632 |
+
up.upsample = Upsample(block_in, resamp_with_conv)
|
633 |
+
curr_res = curr_res * 2
|
634 |
+
self.up.insert(0, up) # prepend to get consistent order
|
635 |
+
|
636 |
+
# end
|
637 |
+
self.norm_out = Normalize(block_in)
|
638 |
+
self.conv_out = torch.nn.Conv2d(block_in,
|
639 |
+
out_ch,
|
640 |
+
kernel_size=3,
|
641 |
+
stride=1,
|
642 |
+
padding=1)
|
643 |
+
|
644 |
+
|
645 |
+
def forward(self, x, z):
|
646 |
+
#assert x.shape[2] == x.shape[3] == self.resolution
|
647 |
+
|
648 |
+
if self.use_timestep:
|
649 |
+
# timestep embedding
|
650 |
+
assert t is not None
|
651 |
+
temb = get_timestep_embedding(t, self.ch)
|
652 |
+
temb = self.temb.dense[0](temb)
|
653 |
+
temb = nonlinearity(temb)
|
654 |
+
temb = self.temb.dense[1](temb)
|
655 |
+
else:
|
656 |
+
temb = None
|
657 |
+
|
658 |
+
# downsampling
|
659 |
+
hs = [self.conv_in(x)]
|
660 |
+
for i_level in range(self.num_resolutions):
|
661 |
+
for i_block in range(self.num_res_blocks):
|
662 |
+
h = self.down[i_level].block[i_block](hs[-1], temb)
|
663 |
+
if len(self.down[i_level].attn) > 0:
|
664 |
+
h = self.down[i_level].attn[i_block](h)
|
665 |
+
hs.append(h)
|
666 |
+
if i_level != self.num_resolutions-1:
|
667 |
+
hs.append(self.down[i_level].downsample(hs[-1]))
|
668 |
+
|
669 |
+
# middle
|
670 |
+
h = hs[-1]
|
671 |
+
z = self.z_in(z)
|
672 |
+
h = torch.cat((h,z),dim=1)
|
673 |
+
h = self.mid.block_1(h, temb)
|
674 |
+
h = self.mid.attn_1(h)
|
675 |
+
h = self.mid.block_2(h, temb)
|
676 |
+
|
677 |
+
# upsampling
|
678 |
+
for i_level in reversed(range(self.num_resolutions)):
|
679 |
+
for i_block in range(self.num_res_blocks+1):
|
680 |
+
h = self.up[i_level].block[i_block](
|
681 |
+
torch.cat([h, hs.pop()], dim=1), temb)
|
682 |
+
if len(self.up[i_level].attn) > 0:
|
683 |
+
h = self.up[i_level].attn[i_block](h)
|
684 |
+
if i_level != 0:
|
685 |
+
h = self.up[i_level].upsample(h)
|
686 |
+
|
687 |
+
# end
|
688 |
+
h = self.norm_out(h)
|
689 |
+
h = nonlinearity(h)
|
690 |
+
h = self.conv_out(h)
|
691 |
+
return h
|
692 |
+
|
693 |
+
|
694 |
+
class SimpleDecoder(nn.Module):
|
695 |
+
def __init__(self, in_channels, out_channels, *args, **kwargs):
|
696 |
+
super().__init__()
|
697 |
+
self.model = nn.ModuleList([nn.Conv2d(in_channels, in_channels, 1),
|
698 |
+
ResnetBlock(in_channels=in_channels,
|
699 |
+
out_channels=2 * in_channels,
|
700 |
+
temb_channels=0, dropout=0.0),
|
701 |
+
ResnetBlock(in_channels=2 * in_channels,
|
702 |
+
out_channels=4 * in_channels,
|
703 |
+
temb_channels=0, dropout=0.0),
|
704 |
+
ResnetBlock(in_channels=4 * in_channels,
|
705 |
+
out_channels=2 * in_channels,
|
706 |
+
temb_channels=0, dropout=0.0),
|
707 |
+
nn.Conv2d(2*in_channels, in_channels, 1),
|
708 |
+
Upsample(in_channels, with_conv=True)])
|
709 |
+
# end
|
710 |
+
self.norm_out = Normalize(in_channels)
|
711 |
+
self.conv_out = torch.nn.Conv2d(in_channels,
|
712 |
+
out_channels,
|
713 |
+
kernel_size=3,
|
714 |
+
stride=1,
|
715 |
+
padding=1)
|
716 |
+
|
717 |
+
def forward(self, x):
|
718 |
+
for i, layer in enumerate(self.model):
|
719 |
+
if i in [1,2,3]:
|
720 |
+
x = layer(x, None)
|
721 |
+
else:
|
722 |
+
x = layer(x)
|
723 |
+
|
724 |
+
h = self.norm_out(x)
|
725 |
+
h = nonlinearity(h)
|
726 |
+
x = self.conv_out(h)
|
727 |
+
return x
|
728 |
+
|
729 |
+
|
730 |
+
class UpsampleDecoder(nn.Module):
|
731 |
+
def __init__(self, in_channels, out_channels, ch, num_res_blocks, resolution,
|
732 |
+
ch_mult=(2,2), dropout=0.0):
|
733 |
+
super().__init__()
|
734 |
+
# upsampling
|
735 |
+
self.temb_ch = 0
|
736 |
+
self.num_resolutions = len(ch_mult)
|
737 |
+
self.num_res_blocks = num_res_blocks
|
738 |
+
block_in = in_channels
|
739 |
+
curr_res = resolution // 2 ** (self.num_resolutions - 1)
|
740 |
+
self.res_blocks = nn.ModuleList()
|
741 |
+
self.upsample_blocks = nn.ModuleList()
|
742 |
+
for i_level in range(self.num_resolutions):
|
743 |
+
res_block = []
|
744 |
+
block_out = ch * ch_mult[i_level]
|
745 |
+
for i_block in range(self.num_res_blocks + 1):
|
746 |
+
res_block.append(ResnetBlock(in_channels=block_in,
|
747 |
+
out_channels=block_out,
|
748 |
+
temb_channels=self.temb_ch,
|
749 |
+
dropout=dropout))
|
750 |
+
block_in = block_out
|
751 |
+
self.res_blocks.append(nn.ModuleList(res_block))
|
752 |
+
if i_level != self.num_resolutions - 1:
|
753 |
+
self.upsample_blocks.append(Upsample(block_in, True))
|
754 |
+
curr_res = curr_res * 2
|
755 |
+
|
756 |
+
# end
|
757 |
+
self.norm_out = Normalize(block_in)
|
758 |
+
self.conv_out = torch.nn.Conv2d(block_in,
|
759 |
+
out_channels,
|
760 |
+
kernel_size=3,
|
761 |
+
stride=1,
|
762 |
+
padding=1)
|
763 |
+
|
764 |
+
def forward(self, x):
|
765 |
+
# upsampling
|
766 |
+
h = x
|
767 |
+
for k, i_level in enumerate(range(self.num_resolutions)):
|
768 |
+
for i_block in range(self.num_res_blocks + 1):
|
769 |
+
h = self.res_blocks[i_level][i_block](h, None)
|
770 |
+
if i_level != self.num_resolutions - 1:
|
771 |
+
h = self.upsample_blocks[k](h)
|
772 |
+
h = self.norm_out(h)
|
773 |
+
h = nonlinearity(h)
|
774 |
+
h = self.conv_out(h)
|
775 |
+
return h
|
776 |
+
|
taming-transformers/taming/modules/discriminator/model.py
ADDED
@@ -0,0 +1,67 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import functools
|
2 |
+
import torch.nn as nn
|
3 |
+
|
4 |
+
|
5 |
+
from taming.modules.util import ActNorm
|
6 |
+
|
7 |
+
|
8 |
+
def weights_init(m):
|
9 |
+
classname = m.__class__.__name__
|
10 |
+
if classname.find('Conv') != -1:
|
11 |
+
nn.init.normal_(m.weight.data, 0.0, 0.02)
|
12 |
+
elif classname.find('BatchNorm') != -1:
|
13 |
+
nn.init.normal_(m.weight.data, 1.0, 0.02)
|
14 |
+
nn.init.constant_(m.bias.data, 0)
|
15 |
+
|
16 |
+
|
17 |
+
class NLayerDiscriminator(nn.Module):
|
18 |
+
"""Defines a PatchGAN discriminator as in Pix2Pix
|
19 |
+
--> see https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py
|
20 |
+
"""
|
21 |
+
def __init__(self, input_nc=3, ndf=64, n_layers=3, use_actnorm=False):
|
22 |
+
"""Construct a PatchGAN discriminator
|
23 |
+
Parameters:
|
24 |
+
input_nc (int) -- the number of channels in input images
|
25 |
+
ndf (int) -- the number of filters in the last conv layer
|
26 |
+
n_layers (int) -- the number of conv layers in the discriminator
|
27 |
+
norm_layer -- normalization layer
|
28 |
+
"""
|
29 |
+
super(NLayerDiscriminator, self).__init__()
|
30 |
+
if not use_actnorm:
|
31 |
+
norm_layer = nn.BatchNorm2d
|
32 |
+
else:
|
33 |
+
norm_layer = ActNorm
|
34 |
+
if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters
|
35 |
+
use_bias = norm_layer.func != nn.BatchNorm2d
|
36 |
+
else:
|
37 |
+
use_bias = norm_layer != nn.BatchNorm2d
|
38 |
+
|
39 |
+
kw = 4
|
40 |
+
padw = 1
|
41 |
+
sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]
|
42 |
+
nf_mult = 1
|
43 |
+
nf_mult_prev = 1
|
44 |
+
for n in range(1, n_layers): # gradually increase the number of filters
|
45 |
+
nf_mult_prev = nf_mult
|
46 |
+
nf_mult = min(2 ** n, 8)
|
47 |
+
sequence += [
|
48 |
+
nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias),
|
49 |
+
norm_layer(ndf * nf_mult),
|
50 |
+
nn.LeakyReLU(0.2, True)
|
51 |
+
]
|
52 |
+
|
53 |
+
nf_mult_prev = nf_mult
|
54 |
+
nf_mult = min(2 ** n_layers, 8)
|
55 |
+
sequence += [
|
56 |
+
nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias),
|
57 |
+
norm_layer(ndf * nf_mult),
|
58 |
+
nn.LeakyReLU(0.2, True)
|
59 |
+
]
|
60 |
+
|
61 |
+
sequence += [
|
62 |
+
nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)] # output 1 channel prediction map
|
63 |
+
self.main = nn.Sequential(*sequence)
|
64 |
+
|
65 |
+
def forward(self, input):
|
66 |
+
"""Standard forward."""
|
67 |
+
return self.main(input)
|
taming-transformers/taming/modules/losses/__init__.py
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
from taming.modules.losses.vqperceptual import DummyLoss
|
2 |
+
|
taming-transformers/taming/modules/losses/lpips.py
ADDED
@@ -0,0 +1,123 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Stripped version of https://github.com/richzhang/PerceptualSimilarity/tree/master/models"""
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
from torchvision import models
|
6 |
+
from collections import namedtuple
|
7 |
+
|
8 |
+
from taming.util import get_ckpt_path
|
9 |
+
|
10 |
+
|
11 |
+
class LPIPS(nn.Module):
|
12 |
+
# Learned perceptual metric
|
13 |
+
def __init__(self, use_dropout=True):
|
14 |
+
super().__init__()
|
15 |
+
self.scaling_layer = ScalingLayer()
|
16 |
+
self.chns = [64, 128, 256, 512, 512] # vg16 features
|
17 |
+
self.net = vgg16(pretrained=True, requires_grad=False)
|
18 |
+
self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout)
|
19 |
+
self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout)
|
20 |
+
self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout)
|
21 |
+
self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout)
|
22 |
+
self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout)
|
23 |
+
self.load_from_pretrained()
|
24 |
+
for param in self.parameters():
|
25 |
+
param.requires_grad = False
|
26 |
+
|
27 |
+
def load_from_pretrained(self, name="vgg_lpips"):
|
28 |
+
ckpt = get_ckpt_path(name, "taming/modules/autoencoder/lpips")
|
29 |
+
self.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu")), strict=False)
|
30 |
+
print("loaded pretrained LPIPS loss from {}".format(ckpt))
|
31 |
+
|
32 |
+
@classmethod
|
33 |
+
def from_pretrained(cls, name="vgg_lpips"):
|
34 |
+
if name != "vgg_lpips":
|
35 |
+
raise NotImplementedError
|
36 |
+
model = cls()
|
37 |
+
ckpt = get_ckpt_path(name)
|
38 |
+
model.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu")), strict=False)
|
39 |
+
return model
|
40 |
+
|
41 |
+
def forward(self, input, target):
|
42 |
+
in0_input, in1_input = (self.scaling_layer(input), self.scaling_layer(target))
|
43 |
+
outs0, outs1 = self.net(in0_input), self.net(in1_input)
|
44 |
+
feats0, feats1, diffs = {}, {}, {}
|
45 |
+
lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4]
|
46 |
+
for kk in range(len(self.chns)):
|
47 |
+
feats0[kk], feats1[kk] = normalize_tensor(outs0[kk]), normalize_tensor(outs1[kk])
|
48 |
+
diffs[kk] = (feats0[kk] - feats1[kk]) ** 2
|
49 |
+
|
50 |
+
res = [spatial_average(lins[kk].model(diffs[kk]), keepdim=True) for kk in range(len(self.chns))]
|
51 |
+
val = res[0]
|
52 |
+
for l in range(1, len(self.chns)):
|
53 |
+
val += res[l]
|
54 |
+
return val
|
55 |
+
|
56 |
+
|
57 |
+
class ScalingLayer(nn.Module):
|
58 |
+
def __init__(self):
|
59 |
+
super(ScalingLayer, self).__init__()
|
60 |
+
self.register_buffer('shift', torch.Tensor([-.030, -.088, -.188])[None, :, None, None])
|
61 |
+
self.register_buffer('scale', torch.Tensor([.458, .448, .450])[None, :, None, None])
|
62 |
+
|
63 |
+
def forward(self, inp):
|
64 |
+
return (inp - self.shift) / self.scale
|
65 |
+
|
66 |
+
|
67 |
+
class NetLinLayer(nn.Module):
|
68 |
+
""" A single linear layer which does a 1x1 conv """
|
69 |
+
def __init__(self, chn_in, chn_out=1, use_dropout=False):
|
70 |
+
super(NetLinLayer, self).__init__()
|
71 |
+
layers = [nn.Dropout(), ] if (use_dropout) else []
|
72 |
+
layers += [nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False), ]
|
73 |
+
self.model = nn.Sequential(*layers)
|
74 |
+
|
75 |
+
|
76 |
+
class vgg16(torch.nn.Module):
|
77 |
+
def __init__(self, requires_grad=False, pretrained=True):
|
78 |
+
super(vgg16, self).__init__()
|
79 |
+
vgg_pretrained_features = models.vgg16(pretrained=pretrained).features
|
80 |
+
self.slice1 = torch.nn.Sequential()
|
81 |
+
self.slice2 = torch.nn.Sequential()
|
82 |
+
self.slice3 = torch.nn.Sequential()
|
83 |
+
self.slice4 = torch.nn.Sequential()
|
84 |
+
self.slice5 = torch.nn.Sequential()
|
85 |
+
self.N_slices = 5
|
86 |
+
for x in range(4):
|
87 |
+
self.slice1.add_module(str(x), vgg_pretrained_features[x])
|
88 |
+
for x in range(4, 9):
|
89 |
+
self.slice2.add_module(str(x), vgg_pretrained_features[x])
|
90 |
+
for x in range(9, 16):
|
91 |
+
self.slice3.add_module(str(x), vgg_pretrained_features[x])
|
92 |
+
for x in range(16, 23):
|
93 |
+
self.slice4.add_module(str(x), vgg_pretrained_features[x])
|
94 |
+
for x in range(23, 30):
|
95 |
+
self.slice5.add_module(str(x), vgg_pretrained_features[x])
|
96 |
+
if not requires_grad:
|
97 |
+
for param in self.parameters():
|
98 |
+
param.requires_grad = False
|
99 |
+
|
100 |
+
def forward(self, X):
|
101 |
+
h = self.slice1(X)
|
102 |
+
h_relu1_2 = h
|
103 |
+
h = self.slice2(h)
|
104 |
+
h_relu2_2 = h
|
105 |
+
h = self.slice3(h)
|
106 |
+
h_relu3_3 = h
|
107 |
+
h = self.slice4(h)
|
108 |
+
h_relu4_3 = h
|
109 |
+
h = self.slice5(h)
|
110 |
+
h_relu5_3 = h
|
111 |
+
vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3'])
|
112 |
+
out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3)
|
113 |
+
return out
|
114 |
+
|
115 |
+
|
116 |
+
def normalize_tensor(x,eps=1e-10):
|
117 |
+
norm_factor = torch.sqrt(torch.sum(x**2,dim=1,keepdim=True))
|
118 |
+
return x/(norm_factor+eps)
|
119 |
+
|
120 |
+
|
121 |
+
def spatial_average(x, keepdim=True):
|
122 |
+
return x.mean([2,3],keepdim=keepdim)
|
123 |
+
|
taming-transformers/taming/modules/losses/segmentation.py
ADDED
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch.nn as nn
|
2 |
+
import torch.nn.functional as F
|
3 |
+
|
4 |
+
|
5 |
+
class BCELoss(nn.Module):
|
6 |
+
def forward(self, prediction, target):
|
7 |
+
loss = F.binary_cross_entropy_with_logits(prediction,target)
|
8 |
+
return loss, {}
|
9 |
+
|
10 |
+
|
11 |
+
class BCELossWithQuant(nn.Module):
|
12 |
+
def __init__(self, codebook_weight=1.):
|
13 |
+
super().__init__()
|
14 |
+
self.codebook_weight = codebook_weight
|
15 |
+
|
16 |
+
def forward(self, qloss, target, prediction, split):
|
17 |
+
bce_loss = F.binary_cross_entropy_with_logits(prediction,target)
|
18 |
+
loss = bce_loss + self.codebook_weight*qloss
|
19 |
+
return loss, {"{}/total_loss".format(split): loss.clone().detach().mean(),
|
20 |
+
"{}/bce_loss".format(split): bce_loss.detach().mean(),
|
21 |
+
"{}/quant_loss".format(split): qloss.detach().mean()
|
22 |
+
}
|
taming-transformers/taming/modules/losses/vqperceptual.py
ADDED
@@ -0,0 +1,136 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
|
5 |
+
from taming.modules.losses.lpips import LPIPS
|
6 |
+
from taming.modules.discriminator.model import NLayerDiscriminator, weights_init
|
7 |
+
|
8 |
+
|
9 |
+
class DummyLoss(nn.Module):
|
10 |
+
def __init__(self):
|
11 |
+
super().__init__()
|
12 |
+
|
13 |
+
|
14 |
+
def adopt_weight(weight, global_step, threshold=0, value=0.):
|
15 |
+
if global_step < threshold:
|
16 |
+
weight = value
|
17 |
+
return weight
|
18 |
+
|
19 |
+
|
20 |
+
def hinge_d_loss(logits_real, logits_fake):
|
21 |
+
loss_real = torch.mean(F.relu(1. - logits_real))
|
22 |
+
loss_fake = torch.mean(F.relu(1. + logits_fake))
|
23 |
+
d_loss = 0.5 * (loss_real + loss_fake)
|
24 |
+
return d_loss
|
25 |
+
|
26 |
+
|
27 |
+
def vanilla_d_loss(logits_real, logits_fake):
|
28 |
+
d_loss = 0.5 * (
|
29 |
+
torch.mean(torch.nn.functional.softplus(-logits_real)) +
|
30 |
+
torch.mean(torch.nn.functional.softplus(logits_fake)))
|
31 |
+
return d_loss
|
32 |
+
|
33 |
+
|
34 |
+
class VQLPIPSWithDiscriminator(nn.Module):
|
35 |
+
def __init__(self, disc_start, codebook_weight=1.0, pixelloss_weight=1.0,
|
36 |
+
disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0,
|
37 |
+
perceptual_weight=1.0, use_actnorm=False, disc_conditional=False,
|
38 |
+
disc_ndf=64, disc_loss="hinge"):
|
39 |
+
super().__init__()
|
40 |
+
assert disc_loss in ["hinge", "vanilla"]
|
41 |
+
self.codebook_weight = codebook_weight
|
42 |
+
self.pixel_weight = pixelloss_weight
|
43 |
+
self.perceptual_loss = LPIPS().eval()
|
44 |
+
self.perceptual_weight = perceptual_weight
|
45 |
+
|
46 |
+
self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels,
|
47 |
+
n_layers=disc_num_layers,
|
48 |
+
use_actnorm=use_actnorm,
|
49 |
+
ndf=disc_ndf
|
50 |
+
).apply(weights_init)
|
51 |
+
self.discriminator_iter_start = disc_start
|
52 |
+
if disc_loss == "hinge":
|
53 |
+
self.disc_loss = hinge_d_loss
|
54 |
+
elif disc_loss == "vanilla":
|
55 |
+
self.disc_loss = vanilla_d_loss
|
56 |
+
else:
|
57 |
+
raise ValueError(f"Unknown GAN loss '{disc_loss}'.")
|
58 |
+
print(f"VQLPIPSWithDiscriminator running with {disc_loss} loss.")
|
59 |
+
self.disc_factor = disc_factor
|
60 |
+
self.discriminator_weight = disc_weight
|
61 |
+
self.disc_conditional = disc_conditional
|
62 |
+
|
63 |
+
def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):
|
64 |
+
if last_layer is not None:
|
65 |
+
nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
|
66 |
+
g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
|
67 |
+
else:
|
68 |
+
nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0]
|
69 |
+
g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0]
|
70 |
+
|
71 |
+
d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
|
72 |
+
d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
|
73 |
+
d_weight = d_weight * self.discriminator_weight
|
74 |
+
return d_weight
|
75 |
+
|
76 |
+
def forward(self, codebook_loss, inputs, reconstructions, optimizer_idx,
|
77 |
+
global_step, last_layer=None, cond=None, split="train"):
|
78 |
+
rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous())
|
79 |
+
if self.perceptual_weight > 0:
|
80 |
+
p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous())
|
81 |
+
rec_loss = rec_loss + self.perceptual_weight * p_loss
|
82 |
+
else:
|
83 |
+
p_loss = torch.tensor([0.0])
|
84 |
+
|
85 |
+
nll_loss = rec_loss
|
86 |
+
#nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]
|
87 |
+
nll_loss = torch.mean(nll_loss)
|
88 |
+
|
89 |
+
# now the GAN part
|
90 |
+
if optimizer_idx == 0:
|
91 |
+
# generator update
|
92 |
+
if cond is None:
|
93 |
+
assert not self.disc_conditional
|
94 |
+
logits_fake = self.discriminator(reconstructions.contiguous())
|
95 |
+
else:
|
96 |
+
assert self.disc_conditional
|
97 |
+
logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1))
|
98 |
+
g_loss = -torch.mean(logits_fake)
|
99 |
+
|
100 |
+
try:
|
101 |
+
d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer)
|
102 |
+
except RuntimeError:
|
103 |
+
assert not self.training
|
104 |
+
d_weight = torch.tensor(0.0)
|
105 |
+
|
106 |
+
disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
|
107 |
+
loss = nll_loss + d_weight * disc_factor * g_loss + self.codebook_weight * codebook_loss.mean()
|
108 |
+
|
109 |
+
log = {"{}/total_loss".format(split): loss.clone().detach().mean(),
|
110 |
+
"{}/quant_loss".format(split): codebook_loss.detach().mean(),
|
111 |
+
"{}/nll_loss".format(split): nll_loss.detach().mean(),
|
112 |
+
"{}/rec_loss".format(split): rec_loss.detach().mean(),
|
113 |
+
"{}/p_loss".format(split): p_loss.detach().mean(),
|
114 |
+
"{}/d_weight".format(split): d_weight.detach(),
|
115 |
+
"{}/disc_factor".format(split): torch.tensor(disc_factor),
|
116 |
+
"{}/g_loss".format(split): g_loss.detach().mean(),
|
117 |
+
}
|
118 |
+
return loss, log
|
119 |
+
|
120 |
+
if optimizer_idx == 1:
|
121 |
+
# second pass for discriminator update
|
122 |
+
if cond is None:
|
123 |
+
logits_real = self.discriminator(inputs.contiguous().detach())
|
124 |
+
logits_fake = self.discriminator(reconstructions.contiguous().detach())
|
125 |
+
else:
|
126 |
+
logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1))
|
127 |
+
logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1))
|
128 |
+
|
129 |
+
disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
|
130 |
+
d_loss = disc_factor * self.disc_loss(logits_real, logits_fake)
|
131 |
+
|
132 |
+
log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(),
|
133 |
+
"{}/logits_real".format(split): logits_real.detach().mean(),
|
134 |
+
"{}/logits_fake".format(split): logits_fake.detach().mean()
|
135 |
+
}
|
136 |
+
return d_loss, log
|
taming-transformers/taming/modules/misc/coord.py
ADDED
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
|
3 |
+
class CoordStage(object):
|
4 |
+
def __init__(self, n_embed, down_factor):
|
5 |
+
self.n_embed = n_embed
|
6 |
+
self.down_factor = down_factor
|
7 |
+
|
8 |
+
def eval(self):
|
9 |
+
return self
|
10 |
+
|
11 |
+
def encode(self, c):
|
12 |
+
"""fake vqmodel interface"""
|
13 |
+
assert 0.0 <= c.min() and c.max() <= 1.0
|
14 |
+
b,ch,h,w = c.shape
|
15 |
+
assert ch == 1
|
16 |
+
|
17 |
+
c = torch.nn.functional.interpolate(c, scale_factor=1/self.down_factor,
|
18 |
+
mode="area")
|
19 |
+
c = c.clamp(0.0, 1.0)
|
20 |
+
c = self.n_embed*c
|
21 |
+
c_quant = c.round()
|
22 |
+
c_ind = c_quant.to(dtype=torch.long)
|
23 |
+
|
24 |
+
info = None, None, c_ind
|
25 |
+
return c_quant, None, info
|
26 |
+
|
27 |
+
def decode(self, c):
|
28 |
+
c = c/self.n_embed
|
29 |
+
c = torch.nn.functional.interpolate(c, scale_factor=self.down_factor,
|
30 |
+
mode="nearest")
|
31 |
+
return c
|
taming-transformers/taming/modules/transformer/mingpt.py
ADDED
@@ -0,0 +1,415 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
taken from: https://github.com/karpathy/minGPT/
|
3 |
+
GPT model:
|
4 |
+
- the initial stem consists of a combination of token encoding and a positional encoding
|
5 |
+
- the meat of it is a uniform sequence of Transformer blocks
|
6 |
+
- each Transformer is a sequential combination of a 1-hidden-layer MLP block and a self-attention block
|
7 |
+
- all blocks feed into a central residual pathway similar to resnets
|
8 |
+
- the final decoder is a linear projection into a vanilla Softmax classifier
|
9 |
+
"""
|
10 |
+
|
11 |
+
import math
|
12 |
+
import logging
|
13 |
+
|
14 |
+
import torch
|
15 |
+
import torch.nn as nn
|
16 |
+
from torch.nn import functional as F
|
17 |
+
from transformers import top_k_top_p_filtering
|
18 |
+
|
19 |
+
logger = logging.getLogger(__name__)
|
20 |
+
|
21 |
+
|
22 |
+
class GPTConfig:
|
23 |
+
""" base GPT config, params common to all GPT versions """
|
24 |
+
embd_pdrop = 0.1
|
25 |
+
resid_pdrop = 0.1
|
26 |
+
attn_pdrop = 0.1
|
27 |
+
|
28 |
+
def __init__(self, vocab_size, block_size, **kwargs):
|
29 |
+
self.vocab_size = vocab_size
|
30 |
+
self.block_size = block_size
|
31 |
+
for k,v in kwargs.items():
|
32 |
+
setattr(self, k, v)
|
33 |
+
|
34 |
+
|
35 |
+
class GPT1Config(GPTConfig):
|
36 |
+
""" GPT-1 like network roughly 125M params """
|
37 |
+
n_layer = 12
|
38 |
+
n_head = 12
|
39 |
+
n_embd = 768
|
40 |
+
|
41 |
+
|
42 |
+
class CausalSelfAttention(nn.Module):
|
43 |
+
"""
|
44 |
+
A vanilla multi-head masked self-attention layer with a projection at the end.
|
45 |
+
It is possible to use torch.nn.MultiheadAttention here but I am including an
|
46 |
+
explicit implementation here to show that there is nothing too scary here.
|
47 |
+
"""
|
48 |
+
|
49 |
+
def __init__(self, config):
|
50 |
+
super().__init__()
|
51 |
+
assert config.n_embd % config.n_head == 0
|
52 |
+
# key, query, value projections for all heads
|
53 |
+
self.key = nn.Linear(config.n_embd, config.n_embd)
|
54 |
+
self.query = nn.Linear(config.n_embd, config.n_embd)
|
55 |
+
self.value = nn.Linear(config.n_embd, config.n_embd)
|
56 |
+
# regularization
|
57 |
+
self.attn_drop = nn.Dropout(config.attn_pdrop)
|
58 |
+
self.resid_drop = nn.Dropout(config.resid_pdrop)
|
59 |
+
# output projection
|
60 |
+
self.proj = nn.Linear(config.n_embd, config.n_embd)
|
61 |
+
# causal mask to ensure that attention is only applied to the left in the input sequence
|
62 |
+
mask = torch.tril(torch.ones(config.block_size,
|
63 |
+
config.block_size))
|
64 |
+
if hasattr(config, "n_unmasked"):
|
65 |
+
mask[:config.n_unmasked, :config.n_unmasked] = 1
|
66 |
+
self.register_buffer("mask", mask.view(1, 1, config.block_size, config.block_size))
|
67 |
+
self.n_head = config.n_head
|
68 |
+
|
69 |
+
def forward(self, x, layer_past=None):
|
70 |
+
B, T, C = x.size()
|
71 |
+
|
72 |
+
# calculate query, key, values for all heads in batch and move head forward to be the batch dim
|
73 |
+
k = self.key(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
|
74 |
+
q = self.query(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
|
75 |
+
v = self.value(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
|
76 |
+
|
77 |
+
present = torch.stack((k, v))
|
78 |
+
if layer_past is not None:
|
79 |
+
past_key, past_value = layer_past
|
80 |
+
k = torch.cat((past_key, k), dim=-2)
|
81 |
+
v = torch.cat((past_value, v), dim=-2)
|
82 |
+
|
83 |
+
# causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
|
84 |
+
att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
|
85 |
+
if layer_past is None:
|
86 |
+
att = att.masked_fill(self.mask[:,:,:T,:T] == 0, float('-inf'))
|
87 |
+
|
88 |
+
att = F.softmax(att, dim=-1)
|
89 |
+
att = self.attn_drop(att)
|
90 |
+
y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
|
91 |
+
y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side
|
92 |
+
|
93 |
+
# output projection
|
94 |
+
y = self.resid_drop(self.proj(y))
|
95 |
+
return y, present # TODO: check that this does not break anything
|
96 |
+
|
97 |
+
|
98 |
+
class Block(nn.Module):
|
99 |
+
""" an unassuming Transformer block """
|
100 |
+
def __init__(self, config):
|
101 |
+
super().__init__()
|
102 |
+
self.ln1 = nn.LayerNorm(config.n_embd)
|
103 |
+
self.ln2 = nn.LayerNorm(config.n_embd)
|
104 |
+
self.attn = CausalSelfAttention(config)
|
105 |
+
self.mlp = nn.Sequential(
|
106 |
+
nn.Linear(config.n_embd, 4 * config.n_embd),
|
107 |
+
nn.GELU(), # nice
|
108 |
+
nn.Linear(4 * config.n_embd, config.n_embd),
|
109 |
+
nn.Dropout(config.resid_pdrop),
|
110 |
+
)
|
111 |
+
|
112 |
+
def forward(self, x, layer_past=None, return_present=False):
|
113 |
+
# TODO: check that training still works
|
114 |
+
if return_present: assert not self.training
|
115 |
+
# layer past: tuple of length two with B, nh, T, hs
|
116 |
+
attn, present = self.attn(self.ln1(x), layer_past=layer_past)
|
117 |
+
|
118 |
+
x = x + attn
|
119 |
+
x = x + self.mlp(self.ln2(x))
|
120 |
+
if layer_past is not None or return_present:
|
121 |
+
return x, present
|
122 |
+
return x
|
123 |
+
|
124 |
+
|
125 |
+
class GPT(nn.Module):
|
126 |
+
""" the full GPT language model, with a context size of block_size """
|
127 |
+
def __init__(self, vocab_size, block_size, n_layer=12, n_head=8, n_embd=256,
|
128 |
+
embd_pdrop=0., resid_pdrop=0., attn_pdrop=0., n_unmasked=0):
|
129 |
+
super().__init__()
|
130 |
+
config = GPTConfig(vocab_size=vocab_size, block_size=block_size,
|
131 |
+
embd_pdrop=embd_pdrop, resid_pdrop=resid_pdrop, attn_pdrop=attn_pdrop,
|
132 |
+
n_layer=n_layer, n_head=n_head, n_embd=n_embd,
|
133 |
+
n_unmasked=n_unmasked)
|
134 |
+
# input embedding stem
|
135 |
+
self.tok_emb = nn.Embedding(config.vocab_size, config.n_embd)
|
136 |
+
self.pos_emb = nn.Parameter(torch.zeros(1, config.block_size, config.n_embd))
|
137 |
+
self.drop = nn.Dropout(config.embd_pdrop)
|
138 |
+
# transformer
|
139 |
+
self.blocks = nn.Sequential(*[Block(config) for _ in range(config.n_layer)])
|
140 |
+
# decoder head
|
141 |
+
self.ln_f = nn.LayerNorm(config.n_embd)
|
142 |
+
self.head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
|
143 |
+
self.block_size = config.block_size
|
144 |
+
self.apply(self._init_weights)
|
145 |
+
self.config = config
|
146 |
+
logger.info("number of parameters: %e", sum(p.numel() for p in self.parameters()))
|
147 |
+
|
148 |
+
def get_block_size(self):
|
149 |
+
return self.block_size
|
150 |
+
|
151 |
+
def _init_weights(self, module):
|
152 |
+
if isinstance(module, (nn.Linear, nn.Embedding)):
|
153 |
+
module.weight.data.normal_(mean=0.0, std=0.02)
|
154 |
+
if isinstance(module, nn.Linear) and module.bias is not None:
|
155 |
+
module.bias.data.zero_()
|
156 |
+
elif isinstance(module, nn.LayerNorm):
|
157 |
+
module.bias.data.zero_()
|
158 |
+
module.weight.data.fill_(1.0)
|
159 |
+
|
160 |
+
def forward(self, idx, embeddings=None, targets=None):
|
161 |
+
# forward the GPT model
|
162 |
+
token_embeddings = self.tok_emb(idx) # each index maps to a (learnable) vector
|
163 |
+
|
164 |
+
if embeddings is not None: # prepend explicit embeddings
|
165 |
+
token_embeddings = torch.cat((embeddings, token_embeddings), dim=1)
|
166 |
+
|
167 |
+
t = token_embeddings.shape[1]
|
168 |
+
assert t <= self.block_size, "Cannot forward, model block size is exhausted."
|
169 |
+
position_embeddings = self.pos_emb[:, :t, :] # each position maps to a (learnable) vector
|
170 |
+
x = self.drop(token_embeddings + position_embeddings)
|
171 |
+
x = self.blocks(x)
|
172 |
+
x = self.ln_f(x)
|
173 |
+
logits = self.head(x)
|
174 |
+
|
175 |
+
# if we are given some desired targets also calculate the loss
|
176 |
+
loss = None
|
177 |
+
if targets is not None:
|
178 |
+
loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
|
179 |
+
|
180 |
+
return logits, loss
|
181 |
+
|
182 |
+
def forward_with_past(self, idx, embeddings=None, targets=None, past=None, past_length=None):
|
183 |
+
# inference only
|
184 |
+
assert not self.training
|
185 |
+
token_embeddings = self.tok_emb(idx) # each index maps to a (learnable) vector
|
186 |
+
if embeddings is not None: # prepend explicit embeddings
|
187 |
+
token_embeddings = torch.cat((embeddings, token_embeddings), dim=1)
|
188 |
+
|
189 |
+
if past is not None:
|
190 |
+
assert past_length is not None
|
191 |
+
past = torch.cat(past, dim=-2) # n_layer, 2, b, nh, len_past, dim_head
|
192 |
+
past_shape = list(past.shape)
|
193 |
+
expected_shape = [self.config.n_layer, 2, idx.shape[0], self.config.n_head, past_length, self.config.n_embd//self.config.n_head]
|
194 |
+
assert past_shape == expected_shape, f"{past_shape} =/= {expected_shape}"
|
195 |
+
position_embeddings = self.pos_emb[:, past_length, :] # each position maps to a (learnable) vector
|
196 |
+
else:
|
197 |
+
position_embeddings = self.pos_emb[:, :token_embeddings.shape[1], :]
|
198 |
+
|
199 |
+
x = self.drop(token_embeddings + position_embeddings)
|
200 |
+
presents = [] # accumulate over layers
|
201 |
+
for i, block in enumerate(self.blocks):
|
202 |
+
x, present = block(x, layer_past=past[i, ...] if past is not None else None, return_present=True)
|
203 |
+
presents.append(present)
|
204 |
+
|
205 |
+
x = self.ln_f(x)
|
206 |
+
logits = self.head(x)
|
207 |
+
# if we are given some desired targets also calculate the loss
|
208 |
+
loss = None
|
209 |
+
if targets is not None:
|
210 |
+
loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
|
211 |
+
|
212 |
+
return logits, loss, torch.stack(presents) # _, _, n_layer, 2, b, nh, 1, dim_head
|
213 |
+
|
214 |
+
|
215 |
+
class DummyGPT(nn.Module):
|
216 |
+
# for debugging
|
217 |
+
def __init__(self, add_value=1):
|
218 |
+
super().__init__()
|
219 |
+
self.add_value = add_value
|
220 |
+
|
221 |
+
def forward(self, idx):
|
222 |
+
return idx + self.add_value, None
|
223 |
+
|
224 |
+
|
225 |
+
class CodeGPT(nn.Module):
|
226 |
+
"""Takes in semi-embeddings"""
|
227 |
+
def __init__(self, vocab_size, block_size, in_channels, n_layer=12, n_head=8, n_embd=256,
|
228 |
+
embd_pdrop=0., resid_pdrop=0., attn_pdrop=0., n_unmasked=0):
|
229 |
+
super().__init__()
|
230 |
+
config = GPTConfig(vocab_size=vocab_size, block_size=block_size,
|
231 |
+
embd_pdrop=embd_pdrop, resid_pdrop=resid_pdrop, attn_pdrop=attn_pdrop,
|
232 |
+
n_layer=n_layer, n_head=n_head, n_embd=n_embd,
|
233 |
+
n_unmasked=n_unmasked)
|
234 |
+
# input embedding stem
|
235 |
+
self.tok_emb = nn.Linear(in_channels, config.n_embd)
|
236 |
+
self.pos_emb = nn.Parameter(torch.zeros(1, config.block_size, config.n_embd))
|
237 |
+
self.drop = nn.Dropout(config.embd_pdrop)
|
238 |
+
# transformer
|
239 |
+
self.blocks = nn.Sequential(*[Block(config) for _ in range(config.n_layer)])
|
240 |
+
# decoder head
|
241 |
+
self.ln_f = nn.LayerNorm(config.n_embd)
|
242 |
+
self.head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
|
243 |
+
self.block_size = config.block_size
|
244 |
+
self.apply(self._init_weights)
|
245 |
+
self.config = config
|
246 |
+
logger.info("number of parameters: %e", sum(p.numel() for p in self.parameters()))
|
247 |
+
|
248 |
+
def get_block_size(self):
|
249 |
+
return self.block_size
|
250 |
+
|
251 |
+
def _init_weights(self, module):
|
252 |
+
if isinstance(module, (nn.Linear, nn.Embedding)):
|
253 |
+
module.weight.data.normal_(mean=0.0, std=0.02)
|
254 |
+
if isinstance(module, nn.Linear) and module.bias is not None:
|
255 |
+
module.bias.data.zero_()
|
256 |
+
elif isinstance(module, nn.LayerNorm):
|
257 |
+
module.bias.data.zero_()
|
258 |
+
module.weight.data.fill_(1.0)
|
259 |
+
|
260 |
+
def forward(self, idx, embeddings=None, targets=None):
|
261 |
+
# forward the GPT model
|
262 |
+
token_embeddings = self.tok_emb(idx) # each index maps to a (learnable) vector
|
263 |
+
|
264 |
+
if embeddings is not None: # prepend explicit embeddings
|
265 |
+
token_embeddings = torch.cat((embeddings, token_embeddings), dim=1)
|
266 |
+
|
267 |
+
t = token_embeddings.shape[1]
|
268 |
+
assert t <= self.block_size, "Cannot forward, model block size is exhausted."
|
269 |
+
position_embeddings = self.pos_emb[:, :t, :] # each position maps to a (learnable) vector
|
270 |
+
x = self.drop(token_embeddings + position_embeddings)
|
271 |
+
x = self.blocks(x)
|
272 |
+
x = self.taming_cinln_f(x)
|
273 |
+
logits = self.head(x)
|
274 |
+
|
275 |
+
# if we are given some desired targets also calculate the loss
|
276 |
+
loss = None
|
277 |
+
if targets is not None:
|
278 |
+
loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
|
279 |
+
|
280 |
+
return logits, loss
|
281 |
+
|
282 |
+
|
283 |
+
|
284 |
+
#### sampling utils
|
285 |
+
|
286 |
+
def top_k_logits(logits, k):
|
287 |
+
v, ix = torch.topk(logits, k)
|
288 |
+
out = logits.clone()
|
289 |
+
out[out < v[:, [-1]]] = -float('Inf')
|
290 |
+
return out
|
291 |
+
|
292 |
+
@torch.no_grad()
|
293 |
+
def sample(model, x, steps, temperature=1.0, sample=False, top_k=None):
|
294 |
+
"""
|
295 |
+
take a conditioning sequence of indices in x (of shape (b,t)) and predict the next token in
|
296 |
+
the sequence, feeding the predictions back into the model each time. Clearly the sampling
|
297 |
+
has quadratic complexity unlike an RNN that is only linear, and has a finite context window
|
298 |
+
of block_size, unlike an RNN that has an infinite context window.
|
299 |
+
"""
|
300 |
+
block_size = model.get_block_size()
|
301 |
+
model.eval()
|
302 |
+
for k in range(steps):
|
303 |
+
x_cond = x if x.size(1) <= block_size else x[:, -block_size:] # crop context if needed
|
304 |
+
logits, _ = model(x_cond)
|
305 |
+
# pluck the logits at the final step and scale by temperature
|
306 |
+
logits = logits[:, -1, :] / temperature
|
307 |
+
# optionally crop probabilities to only the top k options
|
308 |
+
if top_k is not None:
|
309 |
+
logits = top_k_logits(logits, top_k)
|
310 |
+
# apply softmax to convert to probabilities
|
311 |
+
probs = F.softmax(logits, dim=-1)
|
312 |
+
# sample from the distribution or take the most likely
|
313 |
+
if sample:
|
314 |
+
ix = torch.multinomial(probs, num_samples=1)
|
315 |
+
else:
|
316 |
+
_, ix = torch.topk(probs, k=1, dim=-1)
|
317 |
+
# append to the sequence and continue
|
318 |
+
x = torch.cat((x, ix), dim=1)
|
319 |
+
|
320 |
+
return x
|
321 |
+
|
322 |
+
|
323 |
+
@torch.no_grad()
|
324 |
+
def sample_with_past(x, model, steps, temperature=1., sample_logits=True,
|
325 |
+
top_k=None, top_p=None, callback=None):
|
326 |
+
# x is conditioning
|
327 |
+
sample = x
|
328 |
+
cond_len = x.shape[1]
|
329 |
+
past = None
|
330 |
+
for n in range(steps):
|
331 |
+
if callback is not None:
|
332 |
+
callback(n)
|
333 |
+
logits, _, present = model.forward_with_past(x, past=past, past_length=(n+cond_len-1))
|
334 |
+
if past is None:
|
335 |
+
past = [present]
|
336 |
+
else:
|
337 |
+
past.append(present)
|
338 |
+
logits = logits[:, -1, :] / temperature
|
339 |
+
if top_k is not None:
|
340 |
+
logits = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)
|
341 |
+
|
342 |
+
probs = F.softmax(logits, dim=-1)
|
343 |
+
if not sample_logits:
|
344 |
+
_, x = torch.topk(probs, k=1, dim=-1)
|
345 |
+
else:
|
346 |
+
x = torch.multinomial(probs, num_samples=1)
|
347 |
+
# append to the sequence and continue
|
348 |
+
sample = torch.cat((sample, x), dim=1)
|
349 |
+
del past
|
350 |
+
sample = sample[:, cond_len:] # cut conditioning off
|
351 |
+
return sample
|
352 |
+
|
353 |
+
|
354 |
+
#### clustering utils
|
355 |
+
|
356 |
+
class KMeans(nn.Module):
|
357 |
+
def __init__(self, ncluster=512, nc=3, niter=10):
|
358 |
+
super().__init__()
|
359 |
+
self.ncluster = ncluster
|
360 |
+
self.nc = nc
|
361 |
+
self.niter = niter
|
362 |
+
self.shape = (3,32,32)
|
363 |
+
self.register_buffer("C", torch.zeros(self.ncluster,nc))
|
364 |
+
self.register_buffer('initialized', torch.tensor(0, dtype=torch.uint8))
|
365 |
+
|
366 |
+
def is_initialized(self):
|
367 |
+
return self.initialized.item() == 1
|
368 |
+
|
369 |
+
@torch.no_grad()
|
370 |
+
def initialize(self, x):
|
371 |
+
N, D = x.shape
|
372 |
+
assert D == self.nc, D
|
373 |
+
c = x[torch.randperm(N)[:self.ncluster]] # init clusters at random
|
374 |
+
for i in range(self.niter):
|
375 |
+
# assign all pixels to the closest codebook element
|
376 |
+
a = ((x[:, None, :] - c[None, :, :])**2).sum(-1).argmin(1)
|
377 |
+
# move each codebook element to be the mean of the pixels that assigned to it
|
378 |
+
c = torch.stack([x[a==k].mean(0) for k in range(self.ncluster)])
|
379 |
+
# re-assign any poorly positioned codebook elements
|
380 |
+
nanix = torch.any(torch.isnan(c), dim=1)
|
381 |
+
ndead = nanix.sum().item()
|
382 |
+
print('done step %d/%d, re-initialized %d dead clusters' % (i+1, self.niter, ndead))
|
383 |
+
c[nanix] = x[torch.randperm(N)[:ndead]] # re-init dead clusters
|
384 |
+
|
385 |
+
self.C.copy_(c)
|
386 |
+
self.initialized.fill_(1)
|
387 |
+
|
388 |
+
|
389 |
+
def forward(self, x, reverse=False, shape=None):
|
390 |
+
if not reverse:
|
391 |
+
# flatten
|
392 |
+
bs,c,h,w = x.shape
|
393 |
+
assert c == self.nc
|
394 |
+
x = x.reshape(bs,c,h*w,1)
|
395 |
+
C = self.C.permute(1,0)
|
396 |
+
C = C.reshape(1,c,1,self.ncluster)
|
397 |
+
a = ((x-C)**2).sum(1).argmin(-1) # bs, h*w indices
|
398 |
+
return a
|
399 |
+
else:
|
400 |
+
# flatten
|
401 |
+
bs, HW = x.shape
|
402 |
+
"""
|
403 |
+
c = self.C.reshape( 1, self.nc, 1, self.ncluster)
|
404 |
+
c = c[bs*[0],:,:,:]
|
405 |
+
c = c[:,:,HW*[0],:]
|
406 |
+
x = x.reshape(bs, 1, HW, 1)
|
407 |
+
x = x[:,3*[0],:,:]
|
408 |
+
x = torch.gather(c, dim=3, index=x)
|
409 |
+
"""
|
410 |
+
x = self.C[x]
|
411 |
+
x = x.permute(0,2,1)
|
412 |
+
shape = shape if shape is not None else self.shape
|
413 |
+
x = x.reshape(bs, *shape)
|
414 |
+
|
415 |
+
return x
|
taming-transformers/taming/modules/transformer/permuter.py
ADDED
@@ -0,0 +1,248 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import numpy as np
|
4 |
+
|
5 |
+
|
6 |
+
class AbstractPermuter(nn.Module):
|
7 |
+
def __init__(self, *args, **kwargs):
|
8 |
+
super().__init__()
|
9 |
+
def forward(self, x, reverse=False):
|
10 |
+
raise NotImplementedError
|
11 |
+
|
12 |
+
|
13 |
+
class Identity(AbstractPermuter):
|
14 |
+
def __init__(self):
|
15 |
+
super().__init__()
|
16 |
+
|
17 |
+
def forward(self, x, reverse=False):
|
18 |
+
return x
|
19 |
+
|
20 |
+
|
21 |
+
class Subsample(AbstractPermuter):
|
22 |
+
def __init__(self, H, W):
|
23 |
+
super().__init__()
|
24 |
+
C = 1
|
25 |
+
indices = np.arange(H*W).reshape(C,H,W)
|
26 |
+
while min(H, W) > 1:
|
27 |
+
indices = indices.reshape(C,H//2,2,W//2,2)
|
28 |
+
indices = indices.transpose(0,2,4,1,3)
|
29 |
+
indices = indices.reshape(C*4,H//2, W//2)
|
30 |
+
H = H//2
|
31 |
+
W = W//2
|
32 |
+
C = C*4
|
33 |
+
assert H == W == 1
|
34 |
+
idx = torch.tensor(indices.ravel())
|
35 |
+
self.register_buffer('forward_shuffle_idx',
|
36 |
+
nn.Parameter(idx, requires_grad=False))
|
37 |
+
self.register_buffer('backward_shuffle_idx',
|
38 |
+
nn.Parameter(torch.argsort(idx), requires_grad=False))
|
39 |
+
|
40 |
+
def forward(self, x, reverse=False):
|
41 |
+
if not reverse:
|
42 |
+
return x[:, self.forward_shuffle_idx]
|
43 |
+
else:
|
44 |
+
return x[:, self.backward_shuffle_idx]
|
45 |
+
|
46 |
+
|
47 |
+
def mortonify(i, j):
|
48 |
+
"""(i,j) index to linear morton code"""
|
49 |
+
i = np.uint64(i)
|
50 |
+
j = np.uint64(j)
|
51 |
+
|
52 |
+
z = np.uint(0)
|
53 |
+
|
54 |
+
for pos in range(32):
|
55 |
+
z = (z |
|
56 |
+
((j & (np.uint64(1) << np.uint64(pos))) << np.uint64(pos)) |
|
57 |
+
((i & (np.uint64(1) << np.uint64(pos))) << np.uint64(pos+1))
|
58 |
+
)
|
59 |
+
return z
|
60 |
+
|
61 |
+
|
62 |
+
class ZCurve(AbstractPermuter):
|
63 |
+
def __init__(self, H, W):
|
64 |
+
super().__init__()
|
65 |
+
reverseidx = [np.int64(mortonify(i,j)) for i in range(H) for j in range(W)]
|
66 |
+
idx = np.argsort(reverseidx)
|
67 |
+
idx = torch.tensor(idx)
|
68 |
+
reverseidx = torch.tensor(reverseidx)
|
69 |
+
self.register_buffer('forward_shuffle_idx',
|
70 |
+
idx)
|
71 |
+
self.register_buffer('backward_shuffle_idx',
|
72 |
+
reverseidx)
|
73 |
+
|
74 |
+
def forward(self, x, reverse=False):
|
75 |
+
if not reverse:
|
76 |
+
return x[:, self.forward_shuffle_idx]
|
77 |
+
else:
|
78 |
+
return x[:, self.backward_shuffle_idx]
|
79 |
+
|
80 |
+
|
81 |
+
class SpiralOut(AbstractPermuter):
|
82 |
+
def __init__(self, H, W):
|
83 |
+
super().__init__()
|
84 |
+
assert H == W
|
85 |
+
size = W
|
86 |
+
indices = np.arange(size*size).reshape(size,size)
|
87 |
+
|
88 |
+
i0 = size//2
|
89 |
+
j0 = size//2-1
|
90 |
+
|
91 |
+
i = i0
|
92 |
+
j = j0
|
93 |
+
|
94 |
+
idx = [indices[i0, j0]]
|
95 |
+
step_mult = 0
|
96 |
+
for c in range(1, size//2+1):
|
97 |
+
step_mult += 1
|
98 |
+
# steps left
|
99 |
+
for k in range(step_mult):
|
100 |
+
i = i - 1
|
101 |
+
j = j
|
102 |
+
idx.append(indices[i, j])
|
103 |
+
|
104 |
+
# step down
|
105 |
+
for k in range(step_mult):
|
106 |
+
i = i
|
107 |
+
j = j + 1
|
108 |
+
idx.append(indices[i, j])
|
109 |
+
|
110 |
+
step_mult += 1
|
111 |
+
if c < size//2:
|
112 |
+
# step right
|
113 |
+
for k in range(step_mult):
|
114 |
+
i = i + 1
|
115 |
+
j = j
|
116 |
+
idx.append(indices[i, j])
|
117 |
+
|
118 |
+
# step up
|
119 |
+
for k in range(step_mult):
|
120 |
+
i = i
|
121 |
+
j = j - 1
|
122 |
+
idx.append(indices[i, j])
|
123 |
+
else:
|
124 |
+
# end reached
|
125 |
+
for k in range(step_mult-1):
|
126 |
+
i = i + 1
|
127 |
+
idx.append(indices[i, j])
|
128 |
+
|
129 |
+
assert len(idx) == size*size
|
130 |
+
idx = torch.tensor(idx)
|
131 |
+
self.register_buffer('forward_shuffle_idx', idx)
|
132 |
+
self.register_buffer('backward_shuffle_idx', torch.argsort(idx))
|
133 |
+
|
134 |
+
def forward(self, x, reverse=False):
|
135 |
+
if not reverse:
|
136 |
+
return x[:, self.forward_shuffle_idx]
|
137 |
+
else:
|
138 |
+
return x[:, self.backward_shuffle_idx]
|
139 |
+
|
140 |
+
|
141 |
+
class SpiralIn(AbstractPermuter):
|
142 |
+
def __init__(self, H, W):
|
143 |
+
super().__init__()
|
144 |
+
assert H == W
|
145 |
+
size = W
|
146 |
+
indices = np.arange(size*size).reshape(size,size)
|
147 |
+
|
148 |
+
i0 = size//2
|
149 |
+
j0 = size//2-1
|
150 |
+
|
151 |
+
i = i0
|
152 |
+
j = j0
|
153 |
+
|
154 |
+
idx = [indices[i0, j0]]
|
155 |
+
step_mult = 0
|
156 |
+
for c in range(1, size//2+1):
|
157 |
+
step_mult += 1
|
158 |
+
# steps left
|
159 |
+
for k in range(step_mult):
|
160 |
+
i = i - 1
|
161 |
+
j = j
|
162 |
+
idx.append(indices[i, j])
|
163 |
+
|
164 |
+
# step down
|
165 |
+
for k in range(step_mult):
|
166 |
+
i = i
|
167 |
+
j = j + 1
|
168 |
+
idx.append(indices[i, j])
|
169 |
+
|
170 |
+
step_mult += 1
|
171 |
+
if c < size//2:
|
172 |
+
# step right
|
173 |
+
for k in range(step_mult):
|
174 |
+
i = i + 1
|
175 |
+
j = j
|
176 |
+
idx.append(indices[i, j])
|
177 |
+
|
178 |
+
# step up
|
179 |
+
for k in range(step_mult):
|
180 |
+
i = i
|
181 |
+
j = j - 1
|
182 |
+
idx.append(indices[i, j])
|
183 |
+
else:
|
184 |
+
# end reached
|
185 |
+
for k in range(step_mult-1):
|
186 |
+
i = i + 1
|
187 |
+
idx.append(indices[i, j])
|
188 |
+
|
189 |
+
assert len(idx) == size*size
|
190 |
+
idx = idx[::-1]
|
191 |
+
idx = torch.tensor(idx)
|
192 |
+
self.register_buffer('forward_shuffle_idx', idx)
|
193 |
+
self.register_buffer('backward_shuffle_idx', torch.argsort(idx))
|
194 |
+
|
195 |
+
def forward(self, x, reverse=False):
|
196 |
+
if not reverse:
|
197 |
+
return x[:, self.forward_shuffle_idx]
|
198 |
+
else:
|
199 |
+
return x[:, self.backward_shuffle_idx]
|
200 |
+
|
201 |
+
|
202 |
+
class Random(nn.Module):
|
203 |
+
def __init__(self, H, W):
|
204 |
+
super().__init__()
|
205 |
+
indices = np.random.RandomState(1).permutation(H*W)
|
206 |
+
idx = torch.tensor(indices.ravel())
|
207 |
+
self.register_buffer('forward_shuffle_idx', idx)
|
208 |
+
self.register_buffer('backward_shuffle_idx', torch.argsort(idx))
|
209 |
+
|
210 |
+
def forward(self, x, reverse=False):
|
211 |
+
if not reverse:
|
212 |
+
return x[:, self.forward_shuffle_idx]
|
213 |
+
else:
|
214 |
+
return x[:, self.backward_shuffle_idx]
|
215 |
+
|
216 |
+
|
217 |
+
class AlternateParsing(AbstractPermuter):
|
218 |
+
def __init__(self, H, W):
|
219 |
+
super().__init__()
|
220 |
+
indices = np.arange(W*H).reshape(H,W)
|
221 |
+
for i in range(1, H, 2):
|
222 |
+
indices[i, :] = indices[i, ::-1]
|
223 |
+
idx = indices.flatten()
|
224 |
+
assert len(idx) == H*W
|
225 |
+
idx = torch.tensor(idx)
|
226 |
+
self.register_buffer('forward_shuffle_idx', idx)
|
227 |
+
self.register_buffer('backward_shuffle_idx', torch.argsort(idx))
|
228 |
+
|
229 |
+
def forward(self, x, reverse=False):
|
230 |
+
if not reverse:
|
231 |
+
return x[:, self.forward_shuffle_idx]
|
232 |
+
else:
|
233 |
+
return x[:, self.backward_shuffle_idx]
|
234 |
+
|
235 |
+
|
236 |
+
if __name__ == "__main__":
|
237 |
+
p0 = AlternateParsing(16, 16)
|
238 |
+
print(p0.forward_shuffle_idx)
|
239 |
+
print(p0.backward_shuffle_idx)
|
240 |
+
|
241 |
+
x = torch.randint(0, 768, size=(11, 256))
|
242 |
+
y = p0(x)
|
243 |
+
xre = p0(y, reverse=True)
|
244 |
+
assert torch.equal(x, xre)
|
245 |
+
|
246 |
+
p1 = SpiralOut(2, 2)
|
247 |
+
print(p1.forward_shuffle_idx)
|
248 |
+
print(p1.backward_shuffle_idx)
|
taming-transformers/taming/modules/util.py
ADDED
@@ -0,0 +1,130 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
|
4 |
+
|
5 |
+
def count_params(model):
|
6 |
+
total_params = sum(p.numel() for p in model.parameters())
|
7 |
+
return total_params
|
8 |
+
|
9 |
+
|
10 |
+
class ActNorm(nn.Module):
|
11 |
+
def __init__(self, num_features, logdet=False, affine=True,
|
12 |
+
allow_reverse_init=False):
|
13 |
+
assert affine
|
14 |
+
super().__init__()
|
15 |
+
self.logdet = logdet
|
16 |
+
self.loc = nn.Parameter(torch.zeros(1, num_features, 1, 1))
|
17 |
+
self.scale = nn.Parameter(torch.ones(1, num_features, 1, 1))
|
18 |
+
self.allow_reverse_init = allow_reverse_init
|
19 |
+
|
20 |
+
self.register_buffer('initialized', torch.tensor(0, dtype=torch.uint8))
|
21 |
+
|
22 |
+
def initialize(self, input):
|
23 |
+
with torch.no_grad():
|
24 |
+
flatten = input.permute(1, 0, 2, 3).contiguous().view(input.shape[1], -1)
|
25 |
+
mean = (
|
26 |
+
flatten.mean(1)
|
27 |
+
.unsqueeze(1)
|
28 |
+
.unsqueeze(2)
|
29 |
+
.unsqueeze(3)
|
30 |
+
.permute(1, 0, 2, 3)
|
31 |
+
)
|
32 |
+
std = (
|
33 |
+
flatten.std(1)
|
34 |
+
.unsqueeze(1)
|
35 |
+
.unsqueeze(2)
|
36 |
+
.unsqueeze(3)
|
37 |
+
.permute(1, 0, 2, 3)
|
38 |
+
)
|
39 |
+
|
40 |
+
self.loc.data.copy_(-mean)
|
41 |
+
self.scale.data.copy_(1 / (std + 1e-6))
|
42 |
+
|
43 |
+
def forward(self, input, reverse=False):
|
44 |
+
if reverse:
|
45 |
+
return self.reverse(input)
|
46 |
+
if len(input.shape) == 2:
|
47 |
+
input = input[:,:,None,None]
|
48 |
+
squeeze = True
|
49 |
+
else:
|
50 |
+
squeeze = False
|
51 |
+
|
52 |
+
_, _, height, width = input.shape
|
53 |
+
|
54 |
+
if self.training and self.initialized.item() == 0:
|
55 |
+
self.initialize(input)
|
56 |
+
self.initialized.fill_(1)
|
57 |
+
|
58 |
+
h = self.scale * (input + self.loc)
|
59 |
+
|
60 |
+
if squeeze:
|
61 |
+
h = h.squeeze(-1).squeeze(-1)
|
62 |
+
|
63 |
+
if self.logdet:
|
64 |
+
log_abs = torch.log(torch.abs(self.scale))
|
65 |
+
logdet = height*width*torch.sum(log_abs)
|
66 |
+
logdet = logdet * torch.ones(input.shape[0]).to(input)
|
67 |
+
return h, logdet
|
68 |
+
|
69 |
+
return h
|
70 |
+
|
71 |
+
def reverse(self, output):
|
72 |
+
if self.training and self.initialized.item() == 0:
|
73 |
+
if not self.allow_reverse_init:
|
74 |
+
raise RuntimeError(
|
75 |
+
"Initializing ActNorm in reverse direction is "
|
76 |
+
"disabled by default. Use allow_reverse_init=True to enable."
|
77 |
+
)
|
78 |
+
else:
|
79 |
+
self.initialize(output)
|
80 |
+
self.initialized.fill_(1)
|
81 |
+
|
82 |
+
if len(output.shape) == 2:
|
83 |
+
output = output[:,:,None,None]
|
84 |
+
squeeze = True
|
85 |
+
else:
|
86 |
+
squeeze = False
|
87 |
+
|
88 |
+
h = output / self.scale - self.loc
|
89 |
+
|
90 |
+
if squeeze:
|
91 |
+
h = h.squeeze(-1).squeeze(-1)
|
92 |
+
return h
|
93 |
+
|
94 |
+
|
95 |
+
class AbstractEncoder(nn.Module):
|
96 |
+
def __init__(self):
|
97 |
+
super().__init__()
|
98 |
+
|
99 |
+
def encode(self, *args, **kwargs):
|
100 |
+
raise NotImplementedError
|
101 |
+
|
102 |
+
|
103 |
+
class Labelator(AbstractEncoder):
|
104 |
+
"""Net2Net Interface for Class-Conditional Model"""
|
105 |
+
def __init__(self, n_classes, quantize_interface=True):
|
106 |
+
super().__init__()
|
107 |
+
self.n_classes = n_classes
|
108 |
+
self.quantize_interface = quantize_interface
|
109 |
+
|
110 |
+
def encode(self, c):
|
111 |
+
c = c[:,None]
|
112 |
+
if self.quantize_interface:
|
113 |
+
return c, None, [None, None, c.long()]
|
114 |
+
return c
|
115 |
+
|
116 |
+
|
117 |
+
class SOSProvider(AbstractEncoder):
|
118 |
+
# for unconditional training
|
119 |
+
def __init__(self, sos_token, quantize_interface=True):
|
120 |
+
super().__init__()
|
121 |
+
self.sos_token = sos_token
|
122 |
+
self.quantize_interface = quantize_interface
|
123 |
+
|
124 |
+
def encode(self, x):
|
125 |
+
# get batch size from data and replicate sos_token
|
126 |
+
c = torch.ones(x.shape[0], 1)*self.sos_token
|
127 |
+
c = c.long().to(x.device)
|
128 |
+
if self.quantize_interface:
|
129 |
+
return c, None, [None, None, c]
|
130 |
+
return c
|
taming-transformers/taming/modules/vqvae/quantize.py
ADDED
@@ -0,0 +1,445 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
import numpy as np
|
5 |
+
from torch import einsum
|
6 |
+
from einops import rearrange
|
7 |
+
|
8 |
+
|
9 |
+
class VectorQuantizer(nn.Module):
|
10 |
+
"""
|
11 |
+
see https://github.com/MishaLaskin/vqvae/blob/d761a999e2267766400dc646d82d3ac3657771d4/models/quantizer.py
|
12 |
+
____________________________________________
|
13 |
+
Discretization bottleneck part of the VQ-VAE.
|
14 |
+
Inputs:
|
15 |
+
- n_e : number of embeddings
|
16 |
+
- e_dim : dimension of embedding
|
17 |
+
- beta : commitment cost used in loss term, beta * ||z_e(x)-sg[e]||^2
|
18 |
+
_____________________________________________
|
19 |
+
"""
|
20 |
+
|
21 |
+
# NOTE: this class contains a bug regarding beta; see VectorQuantizer2 for
|
22 |
+
# a fix and use legacy=False to apply that fix. VectorQuantizer2 can be
|
23 |
+
# used wherever VectorQuantizer has been used before and is additionally
|
24 |
+
# more efficient.
|
25 |
+
def __init__(self, n_e, e_dim, beta):
|
26 |
+
super(VectorQuantizer, self).__init__()
|
27 |
+
self.n_e = n_e
|
28 |
+
self.e_dim = e_dim
|
29 |
+
self.beta = beta
|
30 |
+
|
31 |
+
self.embedding = nn.Embedding(self.n_e, self.e_dim)
|
32 |
+
self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)
|
33 |
+
|
34 |
+
def forward(self, z):
|
35 |
+
"""
|
36 |
+
Inputs the output of the encoder network z and maps it to a discrete
|
37 |
+
one-hot vector that is the index of the closest embedding vector e_j
|
38 |
+
z (continuous) -> z_q (discrete)
|
39 |
+
z.shape = (batch, channel, height, width)
|
40 |
+
quantization pipeline:
|
41 |
+
1. get encoder input (B,C,H,W)
|
42 |
+
2. flatten input to (B*H*W,C)
|
43 |
+
"""
|
44 |
+
# reshape z -> (batch, height, width, channel) and flatten
|
45 |
+
z = z.permute(0, 2, 3, 1).contiguous()
|
46 |
+
z_flattened = z.view(-1, self.e_dim)
|
47 |
+
# distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
|
48 |
+
|
49 |
+
d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \
|
50 |
+
torch.sum(self.embedding.weight**2, dim=1) - 2 * \
|
51 |
+
torch.matmul(z_flattened, self.embedding.weight.t())
|
52 |
+
|
53 |
+
## could possible replace this here
|
54 |
+
# #\start...
|
55 |
+
# find closest encodings
|
56 |
+
min_encoding_indices = torch.argmin(d, dim=1).unsqueeze(1)
|
57 |
+
|
58 |
+
min_encodings = torch.zeros(
|
59 |
+
min_encoding_indices.shape[0], self.n_e).to(z)
|
60 |
+
min_encodings.scatter_(1, min_encoding_indices, 1)
|
61 |
+
|
62 |
+
# dtype min encodings: torch.float32
|
63 |
+
# min_encodings shape: torch.Size([2048, 512])
|
64 |
+
# min_encoding_indices.shape: torch.Size([2048, 1])
|
65 |
+
|
66 |
+
# get quantized latent vectors
|
67 |
+
z_q = torch.matmul(min_encodings, self.embedding.weight).view(z.shape)
|
68 |
+
#.........\end
|
69 |
+
|
70 |
+
# with:
|
71 |
+
# .........\start
|
72 |
+
#min_encoding_indices = torch.argmin(d, dim=1)
|
73 |
+
#z_q = self.embedding(min_encoding_indices)
|
74 |
+
# ......\end......... (TODO)
|
75 |
+
|
76 |
+
# compute loss for embedding
|
77 |
+
loss = torch.mean((z_q.detach()-z)**2) + self.beta * \
|
78 |
+
torch.mean((z_q - z.detach()) ** 2)
|
79 |
+
|
80 |
+
# preserve gradients
|
81 |
+
z_q = z + (z_q - z).detach()
|
82 |
+
|
83 |
+
# perplexity
|
84 |
+
e_mean = torch.mean(min_encodings, dim=0)
|
85 |
+
perplexity = torch.exp(-torch.sum(e_mean * torch.log(e_mean + 1e-10)))
|
86 |
+
|
87 |
+
# reshape back to match original input shape
|
88 |
+
z_q = z_q.permute(0, 3, 1, 2).contiguous()
|
89 |
+
|
90 |
+
return z_q, loss, (perplexity, min_encodings, min_encoding_indices)
|
91 |
+
|
92 |
+
def get_codebook_entry(self, indices, shape):
|
93 |
+
# shape specifying (batch, height, width, channel)
|
94 |
+
# TODO: check for more easy handling with nn.Embedding
|
95 |
+
min_encodings = torch.zeros(indices.shape[0], self.n_e).to(indices)
|
96 |
+
min_encodings.scatter_(1, indices[:,None], 1)
|
97 |
+
|
98 |
+
# get quantized latent vectors
|
99 |
+
z_q = torch.matmul(min_encodings.float(), self.embedding.weight)
|
100 |
+
|
101 |
+
if shape is not None:
|
102 |
+
z_q = z_q.view(shape)
|
103 |
+
|
104 |
+
# reshape back to match original input shape
|
105 |
+
z_q = z_q.permute(0, 3, 1, 2).contiguous()
|
106 |
+
|
107 |
+
return z_q
|
108 |
+
|
109 |
+
|
110 |
+
class GumbelQuantize(nn.Module):
|
111 |
+
"""
|
112 |
+
credit to @karpathy: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py (thanks!)
|
113 |
+
Gumbel Softmax trick quantizer
|
114 |
+
Categorical Reparameterization with Gumbel-Softmax, Jang et al. 2016
|
115 |
+
https://arxiv.org/abs/1611.01144
|
116 |
+
"""
|
117 |
+
def __init__(self, num_hiddens, embedding_dim, n_embed, straight_through=True,
|
118 |
+
kl_weight=5e-4, temp_init=1.0, use_vqinterface=True,
|
119 |
+
remap=None, unknown_index="random"):
|
120 |
+
super().__init__()
|
121 |
+
|
122 |
+
self.embedding_dim = embedding_dim
|
123 |
+
self.n_embed = n_embed
|
124 |
+
|
125 |
+
self.straight_through = straight_through
|
126 |
+
self.temperature = temp_init
|
127 |
+
self.kl_weight = kl_weight
|
128 |
+
|
129 |
+
self.proj = nn.Conv2d(num_hiddens, n_embed, 1)
|
130 |
+
self.embed = nn.Embedding(n_embed, embedding_dim)
|
131 |
+
|
132 |
+
self.use_vqinterface = use_vqinterface
|
133 |
+
|
134 |
+
self.remap = remap
|
135 |
+
if self.remap is not None:
|
136 |
+
self.register_buffer("used", torch.tensor(np.load(self.remap)))
|
137 |
+
self.re_embed = self.used.shape[0]
|
138 |
+
self.unknown_index = unknown_index # "random" or "extra" or integer
|
139 |
+
if self.unknown_index == "extra":
|
140 |
+
self.unknown_index = self.re_embed
|
141 |
+
self.re_embed = self.re_embed+1
|
142 |
+
print(f"Remapping {self.n_embed} indices to {self.re_embed} indices. "
|
143 |
+
f"Using {self.unknown_index} for unknown indices.")
|
144 |
+
else:
|
145 |
+
self.re_embed = n_embed
|
146 |
+
|
147 |
+
def remap_to_used(self, inds):
|
148 |
+
ishape = inds.shape
|
149 |
+
assert len(ishape)>1
|
150 |
+
inds = inds.reshape(ishape[0],-1)
|
151 |
+
used = self.used.to(inds)
|
152 |
+
match = (inds[:,:,None]==used[None,None,...]).long()
|
153 |
+
new = match.argmax(-1)
|
154 |
+
unknown = match.sum(2)<1
|
155 |
+
if self.unknown_index == "random":
|
156 |
+
new[unknown]=torch.randint(0,self.re_embed,size=new[unknown].shape).to(device=new.device)
|
157 |
+
else:
|
158 |
+
new[unknown] = self.unknown_index
|
159 |
+
return new.reshape(ishape)
|
160 |
+
|
161 |
+
def unmap_to_all(self, inds):
|
162 |
+
ishape = inds.shape
|
163 |
+
assert len(ishape)>1
|
164 |
+
inds = inds.reshape(ishape[0],-1)
|
165 |
+
used = self.used.to(inds)
|
166 |
+
if self.re_embed > self.used.shape[0]: # extra token
|
167 |
+
inds[inds>=self.used.shape[0]] = 0 # simply set to zero
|
168 |
+
back=torch.gather(used[None,:][inds.shape[0]*[0],:], 1, inds)
|
169 |
+
return back.reshape(ishape)
|
170 |
+
|
171 |
+
def forward(self, z, temp=None, return_logits=False):
|
172 |
+
# force hard = True when we are in eval mode, as we must quantize. actually, always true seems to work
|
173 |
+
hard = self.straight_through if self.training else True
|
174 |
+
temp = self.temperature if temp is None else temp
|
175 |
+
|
176 |
+
logits = self.proj(z)
|
177 |
+
if self.remap is not None:
|
178 |
+
# continue only with used logits
|
179 |
+
full_zeros = torch.zeros_like(logits)
|
180 |
+
logits = logits[:,self.used,...]
|
181 |
+
|
182 |
+
soft_one_hot = F.gumbel_softmax(logits, tau=temp, dim=1, hard=hard)
|
183 |
+
if self.remap is not None:
|
184 |
+
# go back to all entries but unused set to zero
|
185 |
+
full_zeros[:,self.used,...] = soft_one_hot
|
186 |
+
soft_one_hot = full_zeros
|
187 |
+
z_q = einsum('b n h w, n d -> b d h w', soft_one_hot, self.embed.weight)
|
188 |
+
|
189 |
+
# + kl divergence to the prior loss
|
190 |
+
qy = F.softmax(logits, dim=1)
|
191 |
+
diff = self.kl_weight * torch.sum(qy * torch.log(qy * self.n_embed + 1e-10), dim=1).mean()
|
192 |
+
|
193 |
+
ind = soft_one_hot.argmax(dim=1)
|
194 |
+
if self.remap is not None:
|
195 |
+
ind = self.remap_to_used(ind)
|
196 |
+
if self.use_vqinterface:
|
197 |
+
if return_logits:
|
198 |
+
return z_q, diff, (None, None, ind), logits
|
199 |
+
return z_q, diff, (None, None, ind)
|
200 |
+
return z_q, diff, ind
|
201 |
+
|
202 |
+
def get_codebook_entry(self, indices, shape):
|
203 |
+
b, h, w, c = shape
|
204 |
+
assert b*h*w == indices.shape[0]
|
205 |
+
indices = rearrange(indices, '(b h w) -> b h w', b=b, h=h, w=w)
|
206 |
+
if self.remap is not None:
|
207 |
+
indices = self.unmap_to_all(indices)
|
208 |
+
one_hot = F.one_hot(indices, num_classes=self.n_embed).permute(0, 3, 1, 2).float()
|
209 |
+
z_q = einsum('b n h w, n d -> b d h w', one_hot, self.embed.weight)
|
210 |
+
return z_q
|
211 |
+
|
212 |
+
|
213 |
+
class VectorQuantizer2(nn.Module):
|
214 |
+
"""
|
215 |
+
Improved version over VectorQuantizer, can be used as a drop-in replacement. Mostly
|
216 |
+
avoids costly matrix multiplications and allows for post-hoc remapping of indices.
|
217 |
+
"""
|
218 |
+
# NOTE: due to a bug the beta term was applied to the wrong term. for
|
219 |
+
# backwards compatibility we use the buggy version by default, but you can
|
220 |
+
# specify legacy=False to fix it.
|
221 |
+
def __init__(self, n_e, e_dim, beta, remap=None, unknown_index="random",
|
222 |
+
sane_index_shape=False, legacy=True):
|
223 |
+
super().__init__()
|
224 |
+
self.n_e = n_e
|
225 |
+
self.e_dim = e_dim
|
226 |
+
self.beta = beta
|
227 |
+
self.legacy = legacy
|
228 |
+
|
229 |
+
self.embedding = nn.Embedding(self.n_e, self.e_dim)
|
230 |
+
self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)
|
231 |
+
|
232 |
+
self.remap = remap
|
233 |
+
if self.remap is not None:
|
234 |
+
self.register_buffer("used", torch.tensor(np.load(self.remap)))
|
235 |
+
self.re_embed = self.used.shape[0]
|
236 |
+
self.unknown_index = unknown_index # "random" or "extra" or integer
|
237 |
+
if self.unknown_index == "extra":
|
238 |
+
self.unknown_index = self.re_embed
|
239 |
+
self.re_embed = self.re_embed+1
|
240 |
+
print(f"Remapping {self.n_e} indices to {self.re_embed} indices. "
|
241 |
+
f"Using {self.unknown_index} for unknown indices.")
|
242 |
+
else:
|
243 |
+
self.re_embed = n_e
|
244 |
+
|
245 |
+
self.sane_index_shape = sane_index_shape
|
246 |
+
|
247 |
+
def remap_to_used(self, inds):
|
248 |
+
ishape = inds.shape
|
249 |
+
assert len(ishape)>1
|
250 |
+
inds = inds.reshape(ishape[0],-1)
|
251 |
+
used = self.used.to(inds)
|
252 |
+
match = (inds[:,:,None]==used[None,None,...]).long()
|
253 |
+
new = match.argmax(-1)
|
254 |
+
unknown = match.sum(2)<1
|
255 |
+
if self.unknown_index == "random":
|
256 |
+
new[unknown]=torch.randint(0,self.re_embed,size=new[unknown].shape).to(device=new.device)
|
257 |
+
else:
|
258 |
+
new[unknown] = self.unknown_index
|
259 |
+
return new.reshape(ishape)
|
260 |
+
|
261 |
+
def unmap_to_all(self, inds):
|
262 |
+
ishape = inds.shape
|
263 |
+
assert len(ishape)>1
|
264 |
+
inds = inds.reshape(ishape[0],-1)
|
265 |
+
used = self.used.to(inds)
|
266 |
+
if self.re_embed > self.used.shape[0]: # extra token
|
267 |
+
inds[inds>=self.used.shape[0]] = 0 # simply set to zero
|
268 |
+
back=torch.gather(used[None,:][inds.shape[0]*[0],:], 1, inds)
|
269 |
+
return back.reshape(ishape)
|
270 |
+
|
271 |
+
def forward(self, z, temp=None, rescale_logits=False, return_logits=False):
|
272 |
+
assert temp is None or temp==1.0, "Only for interface compatible with Gumbel"
|
273 |
+
assert rescale_logits==False, "Only for interface compatible with Gumbel"
|
274 |
+
assert return_logits==False, "Only for interface compatible with Gumbel"
|
275 |
+
# reshape z -> (batch, height, width, channel) and flatten
|
276 |
+
z = rearrange(z, 'b c h w -> b h w c').contiguous()
|
277 |
+
z_flattened = z.view(-1, self.e_dim)
|
278 |
+
# distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
|
279 |
+
|
280 |
+
d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \
|
281 |
+
torch.sum(self.embedding.weight**2, dim=1) - 2 * \
|
282 |
+
torch.einsum('bd,dn->bn', z_flattened, rearrange(self.embedding.weight, 'n d -> d n'))
|
283 |
+
|
284 |
+
min_encoding_indices = torch.argmin(d, dim=1)
|
285 |
+
z_q = self.embedding(min_encoding_indices).view(z.shape)
|
286 |
+
perplexity = None
|
287 |
+
min_encodings = None
|
288 |
+
|
289 |
+
# compute loss for embedding
|
290 |
+
if not self.legacy:
|
291 |
+
loss = self.beta * torch.mean((z_q.detach()-z)**2) + \
|
292 |
+
torch.mean((z_q - z.detach()) ** 2)
|
293 |
+
else:
|
294 |
+
loss = torch.mean((z_q.detach()-z)**2) + self.beta * \
|
295 |
+
torch.mean((z_q - z.detach()) ** 2)
|
296 |
+
|
297 |
+
# preserve gradients
|
298 |
+
z_q = z + (z_q - z).detach()
|
299 |
+
|
300 |
+
# reshape back to match original input shape
|
301 |
+
z_q = rearrange(z_q, 'b h w c -> b c h w').contiguous()
|
302 |
+
|
303 |
+
if self.remap is not None:
|
304 |
+
min_encoding_indices = min_encoding_indices.reshape(z.shape[0],-1) # add batch axis
|
305 |
+
min_encoding_indices = self.remap_to_used(min_encoding_indices)
|
306 |
+
min_encoding_indices = min_encoding_indices.reshape(-1,1) # flatten
|
307 |
+
|
308 |
+
if self.sane_index_shape:
|
309 |
+
min_encoding_indices = min_encoding_indices.reshape(
|
310 |
+
z_q.shape[0], z_q.shape[2], z_q.shape[3])
|
311 |
+
|
312 |
+
return z_q, loss, (perplexity, min_encodings, min_encoding_indices)
|
313 |
+
|
314 |
+
def get_codebook_entry(self, indices, shape):
|
315 |
+
# shape specifying (batch, height, width, channel)
|
316 |
+
if self.remap is not None:
|
317 |
+
indices = indices.reshape(shape[0],-1) # add batch axis
|
318 |
+
indices = self.unmap_to_all(indices)
|
319 |
+
indices = indices.reshape(-1) # flatten again
|
320 |
+
|
321 |
+
# get quantized latent vectors
|
322 |
+
z_q = self.embedding(indices)
|
323 |
+
|
324 |
+
if shape is not None:
|
325 |
+
z_q = z_q.view(shape)
|
326 |
+
# reshape back to match original input shape
|
327 |
+
z_q = z_q.permute(0, 3, 1, 2).contiguous()
|
328 |
+
|
329 |
+
return z_q
|
330 |
+
|
331 |
+
class EmbeddingEMA(nn.Module):
|
332 |
+
def __init__(self, num_tokens, codebook_dim, decay=0.99, eps=1e-5):
|
333 |
+
super().__init__()
|
334 |
+
self.decay = decay
|
335 |
+
self.eps = eps
|
336 |
+
weight = torch.randn(num_tokens, codebook_dim)
|
337 |
+
self.weight = nn.Parameter(weight, requires_grad = False)
|
338 |
+
self.cluster_size = nn.Parameter(torch.zeros(num_tokens), requires_grad = False)
|
339 |
+
self.embed_avg = nn.Parameter(weight.clone(), requires_grad = False)
|
340 |
+
self.update = True
|
341 |
+
|
342 |
+
def forward(self, embed_id):
|
343 |
+
return F.embedding(embed_id, self.weight)
|
344 |
+
|
345 |
+
def cluster_size_ema_update(self, new_cluster_size):
|
346 |
+
self.cluster_size.data.mul_(self.decay).add_(new_cluster_size, alpha=1 - self.decay)
|
347 |
+
|
348 |
+
def embed_avg_ema_update(self, new_embed_avg):
|
349 |
+
self.embed_avg.data.mul_(self.decay).add_(new_embed_avg, alpha=1 - self.decay)
|
350 |
+
|
351 |
+
def weight_update(self, num_tokens):
|
352 |
+
n = self.cluster_size.sum()
|
353 |
+
smoothed_cluster_size = (
|
354 |
+
(self.cluster_size + self.eps) / (n + num_tokens * self.eps) * n
|
355 |
+
)
|
356 |
+
#normalize embedding average with smoothed cluster size
|
357 |
+
embed_normalized = self.embed_avg / smoothed_cluster_size.unsqueeze(1)
|
358 |
+
self.weight.data.copy_(embed_normalized)
|
359 |
+
|
360 |
+
|
361 |
+
class EMAVectorQuantizer(nn.Module):
|
362 |
+
def __init__(self, n_embed, embedding_dim, beta, decay=0.99, eps=1e-5,
|
363 |
+
remap=None, unknown_index="random"):
|
364 |
+
super().__init__()
|
365 |
+
self.codebook_dim = codebook_dim
|
366 |
+
self.num_tokens = num_tokens
|
367 |
+
self.beta = beta
|
368 |
+
self.embedding = EmbeddingEMA(self.num_tokens, self.codebook_dim, decay, eps)
|
369 |
+
|
370 |
+
self.remap = remap
|
371 |
+
if self.remap is not None:
|
372 |
+
self.register_buffer("used", torch.tensor(np.load(self.remap)))
|
373 |
+
self.re_embed = self.used.shape[0]
|
374 |
+
self.unknown_index = unknown_index # "random" or "extra" or integer
|
375 |
+
if self.unknown_index == "extra":
|
376 |
+
self.unknown_index = self.re_embed
|
377 |
+
self.re_embed = self.re_embed+1
|
378 |
+
print(f"Remapping {self.n_embed} indices to {self.re_embed} indices. "
|
379 |
+
f"Using {self.unknown_index} for unknown indices.")
|
380 |
+
else:
|
381 |
+
self.re_embed = n_embed
|
382 |
+
|
383 |
+
def remap_to_used(self, inds):
|
384 |
+
ishape = inds.shape
|
385 |
+
assert len(ishape)>1
|
386 |
+
inds = inds.reshape(ishape[0],-1)
|
387 |
+
used = self.used.to(inds)
|
388 |
+
match = (inds[:,:,None]==used[None,None,...]).long()
|
389 |
+
new = match.argmax(-1)
|
390 |
+
unknown = match.sum(2)<1
|
391 |
+
if self.unknown_index == "random":
|
392 |
+
new[unknown]=torch.randint(0,self.re_embed,size=new[unknown].shape).to(device=new.device)
|
393 |
+
else:
|
394 |
+
new[unknown] = self.unknown_index
|
395 |
+
return new.reshape(ishape)
|
396 |
+
|
397 |
+
def unmap_to_all(self, inds):
|
398 |
+
ishape = inds.shape
|
399 |
+
assert len(ishape)>1
|
400 |
+
inds = inds.reshape(ishape[0],-1)
|
401 |
+
used = self.used.to(inds)
|
402 |
+
if self.re_embed > self.used.shape[0]: # extra token
|
403 |
+
inds[inds>=self.used.shape[0]] = 0 # simply set to zero
|
404 |
+
back=torch.gather(used[None,:][inds.shape[0]*[0],:], 1, inds)
|
405 |
+
return back.reshape(ishape)
|
406 |
+
|
407 |
+
def forward(self, z):
|
408 |
+
# reshape z -> (batch, height, width, channel) and flatten
|
409 |
+
#z, 'b c h w -> b h w c'
|
410 |
+
z = rearrange(z, 'b c h w -> b h w c')
|
411 |
+
z_flattened = z.reshape(-1, self.codebook_dim)
|
412 |
+
|
413 |
+
# distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
|
414 |
+
d = z_flattened.pow(2).sum(dim=1, keepdim=True) + \
|
415 |
+
self.embedding.weight.pow(2).sum(dim=1) - 2 * \
|
416 |
+
torch.einsum('bd,nd->bn', z_flattened, self.embedding.weight) # 'n d -> d n'
|
417 |
+
|
418 |
+
|
419 |
+
encoding_indices = torch.argmin(d, dim=1)
|
420 |
+
|
421 |
+
z_q = self.embedding(encoding_indices).view(z.shape)
|
422 |
+
encodings = F.one_hot(encoding_indices, self.num_tokens).type(z.dtype)
|
423 |
+
avg_probs = torch.mean(encodings, dim=0)
|
424 |
+
perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))
|
425 |
+
|
426 |
+
if self.training and self.embedding.update:
|
427 |
+
#EMA cluster size
|
428 |
+
encodings_sum = encodings.sum(0)
|
429 |
+
self.embedding.cluster_size_ema_update(encodings_sum)
|
430 |
+
#EMA embedding average
|
431 |
+
embed_sum = encodings.transpose(0,1) @ z_flattened
|
432 |
+
self.embedding.embed_avg_ema_update(embed_sum)
|
433 |
+
#normalize embed_avg and update weight
|
434 |
+
self.embedding.weight_update(self.num_tokens)
|
435 |
+
|
436 |
+
# compute loss for embedding
|
437 |
+
loss = self.beta * F.mse_loss(z_q.detach(), z)
|
438 |
+
|
439 |
+
# preserve gradients
|
440 |
+
z_q = z + (z_q - z).detach()
|
441 |
+
|
442 |
+
# reshape back to match original input shape
|
443 |
+
#z_q, 'b h w c -> b c h w'
|
444 |
+
z_q = rearrange(z_q, 'b h w c -> b c h w')
|
445 |
+
return z_q, loss, (perplexity, encodings, encoding_indices)
|
taming-transformers/taming/util.py
ADDED
@@ -0,0 +1,157 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os, hashlib
|
2 |
+
import requests
|
3 |
+
from tqdm import tqdm
|
4 |
+
|
5 |
+
URL_MAP = {
|
6 |
+
"vgg_lpips": "https://heibox.uni-heidelberg.de/f/607503859c864bc1b30b/?dl=1"
|
7 |
+
}
|
8 |
+
|
9 |
+
CKPT_MAP = {
|
10 |
+
"vgg_lpips": "vgg.pth"
|
11 |
+
}
|
12 |
+
|
13 |
+
MD5_MAP = {
|
14 |
+
"vgg_lpips": "d507d7349b931f0638a25a48a722f98a"
|
15 |
+
}
|
16 |
+
|
17 |
+
|
18 |
+
def download(url, local_path, chunk_size=1024):
|
19 |
+
os.makedirs(os.path.split(local_path)[0], exist_ok=True)
|
20 |
+
with requests.get(url, stream=True) as r:
|
21 |
+
total_size = int(r.headers.get("content-length", 0))
|
22 |
+
with tqdm(total=total_size, unit="B", unit_scale=True) as pbar:
|
23 |
+
with open(local_path, "wb") as f:
|
24 |
+
for data in r.iter_content(chunk_size=chunk_size):
|
25 |
+
if data:
|
26 |
+
f.write(data)
|
27 |
+
pbar.update(chunk_size)
|
28 |
+
|
29 |
+
|
30 |
+
def md5_hash(path):
|
31 |
+
with open(path, "rb") as f:
|
32 |
+
content = f.read()
|
33 |
+
return hashlib.md5(content).hexdigest()
|
34 |
+
|
35 |
+
|
36 |
+
def get_ckpt_path(name, root, check=False):
|
37 |
+
assert name in URL_MAP
|
38 |
+
path = os.path.join(root, CKPT_MAP[name])
|
39 |
+
if not os.path.exists(path) or (check and not md5_hash(path) == MD5_MAP[name]):
|
40 |
+
print("Downloading {} model from {} to {}".format(name, URL_MAP[name], path))
|
41 |
+
download(URL_MAP[name], path)
|
42 |
+
md5 = md5_hash(path)
|
43 |
+
assert md5 == MD5_MAP[name], md5
|
44 |
+
return path
|
45 |
+
|
46 |
+
|
47 |
+
class KeyNotFoundError(Exception):
|
48 |
+
def __init__(self, cause, keys=None, visited=None):
|
49 |
+
self.cause = cause
|
50 |
+
self.keys = keys
|
51 |
+
self.visited = visited
|
52 |
+
messages = list()
|
53 |
+
if keys is not None:
|
54 |
+
messages.append("Key not found: {}".format(keys))
|
55 |
+
if visited is not None:
|
56 |
+
messages.append("Visited: {}".format(visited))
|
57 |
+
messages.append("Cause:\n{}".format(cause))
|
58 |
+
message = "\n".join(messages)
|
59 |
+
super().__init__(message)
|
60 |
+
|
61 |
+
|
62 |
+
def retrieve(
|
63 |
+
list_or_dict, key, splitval="/", default=None, expand=True, pass_success=False
|
64 |
+
):
|
65 |
+
"""Given a nested list or dict return the desired value at key expanding
|
66 |
+
callable nodes if necessary and :attr:`expand` is ``True``. The expansion
|
67 |
+
is done in-place.
|
68 |
+
|
69 |
+
Parameters
|
70 |
+
----------
|
71 |
+
list_or_dict : list or dict
|
72 |
+
Possibly nested list or dictionary.
|
73 |
+
key : str
|
74 |
+
key/to/value, path like string describing all keys necessary to
|
75 |
+
consider to get to the desired value. List indices can also be
|
76 |
+
passed here.
|
77 |
+
splitval : str
|
78 |
+
String that defines the delimiter between keys of the
|
79 |
+
different depth levels in `key`.
|
80 |
+
default : obj
|
81 |
+
Value returned if :attr:`key` is not found.
|
82 |
+
expand : bool
|
83 |
+
Whether to expand callable nodes on the path or not.
|
84 |
+
|
85 |
+
Returns
|
86 |
+
-------
|
87 |
+
The desired value or if :attr:`default` is not ``None`` and the
|
88 |
+
:attr:`key` is not found returns ``default``.
|
89 |
+
|
90 |
+
Raises
|
91 |
+
------
|
92 |
+
Exception if ``key`` not in ``list_or_dict`` and :attr:`default` is
|
93 |
+
``None``.
|
94 |
+
"""
|
95 |
+
|
96 |
+
keys = key.split(splitval)
|
97 |
+
|
98 |
+
success = True
|
99 |
+
try:
|
100 |
+
visited = []
|
101 |
+
parent = None
|
102 |
+
last_key = None
|
103 |
+
for key in keys:
|
104 |
+
if callable(list_or_dict):
|
105 |
+
if not expand:
|
106 |
+
raise KeyNotFoundError(
|
107 |
+
ValueError(
|
108 |
+
"Trying to get past callable node with expand=False."
|
109 |
+
),
|
110 |
+
keys=keys,
|
111 |
+
visited=visited,
|
112 |
+
)
|
113 |
+
list_or_dict = list_or_dict()
|
114 |
+
parent[last_key] = list_or_dict
|
115 |
+
|
116 |
+
last_key = key
|
117 |
+
parent = list_or_dict
|
118 |
+
|
119 |
+
try:
|
120 |
+
if isinstance(list_or_dict, dict):
|
121 |
+
list_or_dict = list_or_dict[key]
|
122 |
+
else:
|
123 |
+
list_or_dict = list_or_dict[int(key)]
|
124 |
+
except (KeyError, IndexError, ValueError) as e:
|
125 |
+
raise KeyNotFoundError(e, keys=keys, visited=visited)
|
126 |
+
|
127 |
+
visited += [key]
|
128 |
+
# final expansion of retrieved value
|
129 |
+
if expand and callable(list_or_dict):
|
130 |
+
list_or_dict = list_or_dict()
|
131 |
+
parent[last_key] = list_or_dict
|
132 |
+
except KeyNotFoundError as e:
|
133 |
+
if default is None:
|
134 |
+
raise e
|
135 |
+
else:
|
136 |
+
list_or_dict = default
|
137 |
+
success = False
|
138 |
+
|
139 |
+
if not pass_success:
|
140 |
+
return list_or_dict
|
141 |
+
else:
|
142 |
+
return list_or_dict, success
|
143 |
+
|
144 |
+
|
145 |
+
if __name__ == "__main__":
|
146 |
+
config = {"keya": "a",
|
147 |
+
"keyb": "b",
|
148 |
+
"keyc":
|
149 |
+
{"cc1": 1,
|
150 |
+
"cc2": 2,
|
151 |
+
}
|
152 |
+
}
|
153 |
+
from omegaconf import OmegaConf
|
154 |
+
config = OmegaConf.create(config)
|
155 |
+
print(config)
|
156 |
+
retrieve(config, "keya")
|
157 |
+
|
taming-transformers/taming_transformers.egg-info/PKG-INFO
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Metadata-Version: 2.1
|
2 |
+
Name: taming-transformers
|
3 |
+
Version: 0.0.1
|
4 |
+
Summary: Taming Transformers for High-Resolution Image Synthesis
|
5 |
+
Home-page: UNKNOWN
|
6 |
+
License: UNKNOWN
|
7 |
+
Platform: UNKNOWN
|
8 |
+
|
9 |
+
UNKNOWN
|
10 |
+
|
taming-transformers/taming_transformers.egg-info/SOURCES.txt
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
README.md
|
2 |
+
setup.py
|
3 |
+
taming_transformers.egg-info/PKG-INFO
|
4 |
+
taming_transformers.egg-info/SOURCES.txt
|
5 |
+
taming_transformers.egg-info/dependency_links.txt
|
6 |
+
taming_transformers.egg-info/requires.txt
|
7 |
+
taming_transformers.egg-info/top_level.txt
|
taming-transformers/taming_transformers.egg-info/dependency_links.txt
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
|
taming-transformers/taming_transformers.egg-info/requires.txt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
torch
|
2 |
+
numpy
|
3 |
+
tqdm
|
taming-transformers/taming_transformers.egg-info/top_level.txt
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
|