ysharma HF staff commited on
Commit
d950775
·
1 Parent(s): f7de7b2

upload git code base

Browse files
Files changed (49) hide show
  1. .gitattributes +10 -0
  2. .gitignore +6 -0
  3. LICENSE +21 -0
  4. README.md +154 -13
  5. assets/.DS_Store +0 -0
  6. assets/comparison.jpg +3 -0
  7. assets/embeddings_sd_1.4/cat.pt +3 -0
  8. assets/embeddings_sd_1.4/dog.pt +3 -0
  9. assets/embeddings_sd_1.4/horse.pt +3 -0
  10. assets/embeddings_sd_1.4/zebra.pt +3 -0
  11. assets/grid_cat2dog.jpg +3 -0
  12. assets/grid_dog2cat.jpg +3 -0
  13. assets/grid_horse2zebra.jpg +3 -0
  14. assets/grid_tree2fall.jpg +3 -0
  15. assets/grid_zebra2horse.jpg +3 -0
  16. assets/main.gif +3 -0
  17. assets/method.jpeg +3 -0
  18. assets/results_real.jpg +3 -0
  19. assets/results_syn.jpg +3 -0
  20. assets/results_teaser.jpg +0 -0
  21. assets/test_images/cats/cat_1.png +0 -0
  22. assets/test_images/cats/cat_2.png +0 -0
  23. assets/test_images/cats/cat_3.png +0 -0
  24. assets/test_images/cats/cat_4.png +0 -0
  25. assets/test_images/cats/cat_5.png +0 -0
  26. assets/test_images/cats/cat_6.png +0 -0
  27. assets/test_images/cats/cat_7.png +0 -0
  28. assets/test_images/cats/cat_8.png +0 -0
  29. assets/test_images/cats/cat_9.png +0 -0
  30. assets/test_images/dogs/dog_1.png +0 -0
  31. assets/test_images/dogs/dog_2.png +0 -0
  32. assets/test_images/dogs/dog_3.png +0 -0
  33. assets/test_images/dogs/dog_4.png +0 -0
  34. assets/test_images/dogs/dog_5.png +0 -0
  35. assets/test_images/dogs/dog_6.png +0 -0
  36. assets/test_images/dogs/dog_7.png +0 -0
  37. assets/test_images/dogs/dog_8.png +0 -0
  38. assets/test_images/dogs/dog_9.png +0 -0
  39. environment.yml +23 -0
  40. src/edit_real.py +65 -0
  41. src/edit_synthetic.py +52 -0
  42. src/inversion.py +64 -0
  43. src/make_edit_direction.py +61 -0
  44. src/utils/base_pipeline.py +322 -0
  45. src/utils/cross_attention.py +57 -0
  46. src/utils/ddim_inv.py +140 -0
  47. src/utils/edit_directions.py +29 -0
  48. src/utils/edit_pipeline.py +174 -0
  49. src/utils/scheduler.py +289 -0
.gitattributes CHANGED
@@ -32,3 +32,13 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
32
  *.zip filter=lfs diff=lfs merge=lfs -text
33
  *.zst filter=lfs diff=lfs merge=lfs -text
34
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
 
 
 
 
32
  *.zip filter=lfs diff=lfs merge=lfs -text
33
  *.zst filter=lfs diff=lfs merge=lfs -text
34
  *tfevents* filter=lfs diff=lfs merge=lfs -text
35
+ assets/comparison.jpg filter=lfs diff=lfs merge=lfs -text
36
+ assets/grid_cat2dog.jpg filter=lfs diff=lfs merge=lfs -text
37
+ assets/grid_dog2cat.jpg filter=lfs diff=lfs merge=lfs -text
38
+ assets/grid_horse2zebra.jpg filter=lfs diff=lfs merge=lfs -text
39
+ assets/grid_tree2fall.jpg filter=lfs diff=lfs merge=lfs -text
40
+ assets/grid_zebra2horse.jpg filter=lfs diff=lfs merge=lfs -text
41
+ assets/main.gif filter=lfs diff=lfs merge=lfs -text
42
+ assets/method.jpeg filter=lfs diff=lfs merge=lfs -text
43
+ assets/results_real.jpg filter=lfs diff=lfs merge=lfs -text
44
+ assets/results_syn.jpg filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ output
2
+ scripts
3
+ src/folder_*.py
4
+ src/ig_*.py
5
+ assets/edit_sentences
6
+ src/utils/edit_pipeline_spatial.py
LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2023 pix2pixzero
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
README.md CHANGED
@@ -1,13 +1,154 @@
1
- ---
2
- title: Pix2pix Zero
3
- emoji: 🕶
4
- colorFrom: gray
5
- colorTo: green
6
- sdk: gradio
7
- sdk_version: 3.18.0
8
- app_file: app.py
9
- pinned: false
10
- license: mit
11
- ---
12
-
13
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # pix2pix-zero
2
+
3
+ ## [**[website]**](https://pix2pixzero.github.io/)
4
+
5
+
6
+ This is author's reimplementation of "Zero-shot Image-to-Image Translation" using the diffusers library. <br>
7
+ The results in the paper are based on the [CompVis](https://github.com/CompVis/stable-diffusion) library, which will be released later.
8
+
9
+ **[New!]** Code for editing real and synthetic images released!
10
+
11
+
12
+
13
+ <br>
14
+ <div class="gif">
15
+ <p align="center">
16
+ <img src='assets/main.gif' align="center">
17
+ </p>
18
+ </div>
19
+
20
+
21
+ We propose pix2pix-zero, a diffusion-based image-to-image approach that allows users to specify the edit direction on-the-fly (e.g., cat to dog). Our method can directly use pre-trained [Stable Diffusion](https://github.com/CompVis/stable-diffusion), for editing real and synthetic images while preserving the input image's structure. Our method is training-free and prompt-free, as it requires neither manual text prompting for each input image nor costly fine-tuning for each task.
22
+
23
+ **TL;DR**: no finetuning required, no text input needed, input structure preserved.
24
+
25
+ ## Results
26
+ All our results are based on [stable-diffusion-v1-4](https://github.com/CompVis/stable-diffusion) model. Please the website for more results.
27
+
28
+ <div>
29
+ <p align="center">
30
+ <img src='assets/results_teaser.jpg' align="center" width=800px>
31
+ </p>
32
+ </div>
33
+ <hr>
34
+
35
+ The top row for each of the results below show editing of real images, and the bottom row shows synthetic image editing.
36
+ <div>
37
+ <p align="center">
38
+ <img src='assets/grid_dog2cat.jpg' align="center" width=800px>
39
+ </p>
40
+ <p align="center">
41
+ <img src='assets/grid_zebra2horse.jpg' align="center" width=800px>
42
+ </p>
43
+ <p align="center">
44
+ <img src='assets/grid_cat2dog.jpg' align="center" width=800px>
45
+ </p>
46
+ <p align="center">
47
+ <img src='assets/grid_horse2zebra.jpg' align="center" width=800px>
48
+ </p>
49
+ <p align="center">
50
+ <img src='assets/grid_tree2fall.jpg' align="center" width=800px>
51
+ </p>
52
+ </div>
53
+
54
+ ## Real Image Editing
55
+ <div>
56
+ <p align="center">
57
+ <img src='assets/results_real.jpg' align="center" width=800px>
58
+ </p>
59
+ </div>
60
+
61
+ ## Synthetic Image Editing
62
+ <div>
63
+ <p align="center">
64
+ <img src='assets/results_syn.jpg' align="center" width=800px>
65
+ </p>
66
+ </div>
67
+
68
+ ## Method Details
69
+
70
+ Given an input image, we first generate text captions using [BLIP](https://github.com/salesforce/LAVIS) and apply regularized DDIM inversion to obtain our inverted noise map.
71
+ Then, we obtain reference cross-attention maps that correspoind to the structure of the input image by denoising, guided with the CLIP embeddings
72
+ of our generated text (c). Next, we denoise with edited text embeddings, while enforcing a loss to match current cross-attention maps with the
73
+ reference cross-attention maps.
74
+
75
+ <div>
76
+ <p align="center">
77
+ <img src='assets/method.jpeg' align="center" width=900>
78
+ </p>
79
+ </div>
80
+
81
+
82
+ ## Getting Started
83
+
84
+ **Environment Setup**
85
+ - We provide a [conda env file](environment.yml) that contains all the required dependencies
86
+ ```
87
+ conda env create -f environment.yml
88
+ ```
89
+ - Following this, you can activate the conda environment with the command below.
90
+ ```
91
+ conda activate pix2pix-zero
92
+ ```
93
+
94
+ **Real Image Translation**
95
+ - First, run the inversion command below to obtain the input noise that reconstructs the image.
96
+ The command below will save the inversion in the results folder as `output/test_cat/inversion/cat_1.pt`
97
+ and the BLIP-generated prompt as `output/test_cat/prompt/cat_1.txt`
98
+ ```
99
+ python src/inversion.py \
100
+ --input_image "assets/test_images/cats/cat_1.png" \
101
+ --results_folder "output/test_cat"
102
+ ```
103
+ - Next, we can perform image editing with the editing direction as shown below.
104
+ The command below will save the edited image as `output/test_cat/edit/cat_1.png`
105
+ ```
106
+ python src/edit_real.py \
107
+ --inversion "output/test_cat/inversion/cat_1.pt" \
108
+ --prompt "output/test_cat/prompt/cat_1.txt" \
109
+ --task_name "cat2dog" \
110
+ --results_folder "output/test_cat/"
111
+ ```
112
+
113
+ **Editing Synthetic Images**
114
+ - Similarly, we can edit the synthetic images generated by Stable Diffusion with the following command.
115
+ ```
116
+ python src/edit_synthetic.py \
117
+ --results_folder "output/synth_editing" \
118
+ --prompt_str "a high resolution painting of a cat in the style of van gough" \
119
+ --task "cat2dog"
120
+ ```
121
+
122
+ ### **Tips and Debugging**
123
+ - **Controlling the Image Structure:**<br>
124
+ The `--xa_guidance` flag controls the amount of cross-attention guidance to be applied when performing the edit. If the output edited image does not retain the structure from the input, increasing the value will typically address the issue. We recommend changing the value in increments of 0.05.
125
+
126
+ - **Improving Image Quality:**<br>
127
+ If the output image quality is low or has some artifacts, using more steps for both the inversion and editing would be helpful.
128
+ This can be controlled with the `--num_ddim_steps` flag.
129
+
130
+ - **Reducing the VRAM Requirements:**<br>
131
+ We can reduce the VRAM requirements using lower precision and setting the flag `--use_float_16`.
132
+
133
+ <br>
134
+
135
+ **Finding Custom Edit Directions**<br>
136
+ - We provide some pre-computed directions in the assets [folder](assets/embeddings_sd_1.4).
137
+ To generate new edit directions, users can first generate two files containing a large number of sentences (~1000) and then run the command as shown below.
138
+ ```
139
+ python src/make_edit_direction.py \
140
+ --file_source_sentences sentences/apple.txt \
141
+ --file_target_sentences sentences/orange.txt \
142
+ --output_folder assets/embeddings_sd_1.4
143
+ ```
144
+ - After running the above command, you can set the flag `--task apple2orange` for the new edit.
145
+
146
+
147
+
148
+ ## Comparison
149
+ Comparisons with different baselines, including, SDEdit + word swap, DDIM + word swap, and prompt-to-propmt. Our method successfully applies the edit, while preserving the structure of the input image.
150
+ <div>
151
+ <p align="center">
152
+ <img src='assets/comparison.jpg' align="center" width=900>
153
+ </p>
154
+ </div>
assets/.DS_Store ADDED
Binary file (6.15 kB). View file
 
assets/comparison.jpg ADDED

Git LFS Details

  • SHA256: acab8ed1680a42dd2f540e8188a43eb0d101895fca8ed36c0e06c8b351d2c276
  • Pointer size: 132 Bytes
  • Size of remote file: 3.39 MB
assets/embeddings_sd_1.4/cat.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:aa9441dc014d5e86567c5ef165e10b50d2a7b3a68d90686d0cd1006792adf334
3
+ size 237300
assets/embeddings_sd_1.4/dog.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:becf079d61d7f35727bcc0d8506ddcdcddb61e62d611840ff3d18eca7fb6338c
3
+ size 237300
assets/embeddings_sd_1.4/horse.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c5d499299544d11371f84674761292b0512055ef45776c700c0b0da164cbf6c7
3
+ size 118949
assets/embeddings_sd_1.4/zebra.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a29f6a11d91f3a276e27326b7623fae9d61a3d253ad430bb868bd40fb7e02fec
3
+ size 118949
assets/grid_cat2dog.jpg ADDED

Git LFS Details

  • SHA256: 0080134b70277af723e25c4627494fda8555d43a9f6376e682b67b3341d1f1f3
  • Pointer size: 132 Bytes
  • Size of remote file: 1.21 MB
assets/grid_dog2cat.jpg ADDED

Git LFS Details

  • SHA256: 0e5059ec1ad8e4b07fe8b715295e82fcead652b9c366733793674e84d51427d9
  • Pointer size: 132 Bytes
  • Size of remote file: 1.25 MB
assets/grid_horse2zebra.jpg ADDED

Git LFS Details

  • SHA256: a31e0a456e9323697c966e675b02403511ebf0b7c334416a8da91df1c14723df
  • Pointer size: 132 Bytes
  • Size of remote file: 1.05 MB
assets/grid_tree2fall.jpg ADDED

Git LFS Details

  • SHA256: 559ab066e4ef0972748d0a7f004d2ca18fd15062c667ac6665309727f6dc0cc8
  • Pointer size: 132 Bytes
  • Size of remote file: 1.63 MB
assets/grid_zebra2horse.jpg ADDED

Git LFS Details

  • SHA256: b44b4aa4576be49289515f0aa9023dfd4424b3ba2476c66516b876dd83a06713
  • Pointer size: 132 Bytes
  • Size of remote file: 1.05 MB
assets/main.gif ADDED

Git LFS Details

  • SHA256: d1ebc380a461c4847beece13bdc9b5ea88312e8a8013f384eb8809109ff198fc
  • Pointer size: 132 Bytes
  • Size of remote file: 6.19 MB
assets/method.jpeg ADDED

Git LFS Details

  • SHA256: 8b1b4ea3608b9ad3797c4c7423bf2fd88e5e24f34fecbb00d3d2de22a99fd2ee
  • Pointer size: 132 Bytes
  • Size of remote file: 2.35 MB
assets/results_real.jpg ADDED

Git LFS Details

  • SHA256: 94095526e76b7a000ed56df15f7b5208c0f5a069b20b04fc9bcade14c54d92dc
  • Pointer size: 132 Bytes
  • Size of remote file: 1.48 MB
assets/results_syn.jpg ADDED

Git LFS Details

  • SHA256: 5731190e33098406995de563ca12bd6d2f84d9db725618a6d6580b4d1f2f0813
  • Pointer size: 132 Bytes
  • Size of remote file: 1.28 MB
assets/results_teaser.jpg ADDED
assets/test_images/cats/cat_1.png ADDED
assets/test_images/cats/cat_2.png ADDED
assets/test_images/cats/cat_3.png ADDED
assets/test_images/cats/cat_4.png ADDED
assets/test_images/cats/cat_5.png ADDED
assets/test_images/cats/cat_6.png ADDED
assets/test_images/cats/cat_7.png ADDED
assets/test_images/cats/cat_8.png ADDED
assets/test_images/cats/cat_9.png ADDED
assets/test_images/dogs/dog_1.png ADDED
assets/test_images/dogs/dog_2.png ADDED
assets/test_images/dogs/dog_3.png ADDED
assets/test_images/dogs/dog_4.png ADDED
assets/test_images/dogs/dog_5.png ADDED
assets/test_images/dogs/dog_6.png ADDED
assets/test_images/dogs/dog_7.png ADDED
assets/test_images/dogs/dog_8.png ADDED
assets/test_images/dogs/dog_9.png ADDED
environment.yml ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: pix2pix-zero
2
+ channels:
3
+ - pytorch
4
+ - nvidia
5
+ - defaults
6
+ dependencies:
7
+ - pip
8
+ - pytorch-cuda=11.6
9
+ - torchvision
10
+ - pytorch
11
+ - pip:
12
+ - accelerate
13
+ - diffusers
14
+ - einops
15
+ - gradio
16
+ - ipython
17
+ - numpy
18
+ - opencv-python-headless
19
+ - pillow
20
+ - psutil
21
+ - tqdm
22
+ - transformers
23
+ - salesforce-lavis
src/edit_real.py ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os, pdb
2
+
3
+ import argparse
4
+ import numpy as np
5
+ import torch
6
+ import requests
7
+ from PIL import Image
8
+
9
+ from diffusers import DDIMScheduler
10
+ from utils.ddim_inv import DDIMInversion
11
+ from utils.edit_directions import construct_direction
12
+ from utils.edit_pipeline import EditingPipeline
13
+
14
+
15
+ if __name__=="__main__":
16
+ parser = argparse.ArgumentParser()
17
+ parser.add_argument('--inversion', required=True)
18
+ parser.add_argument('--prompt', type=str, required=True)
19
+ parser.add_argument('--task_name', type=str, default='cat2dog')
20
+ parser.add_argument('--results_folder', type=str, default='output/test_cat')
21
+ parser.add_argument('--num_ddim_steps', type=int, default=50)
22
+ parser.add_argument('--model_path', type=str, default='CompVis/stable-diffusion-v1-4')
23
+ parser.add_argument('--xa_guidance', default=0.1, type=float)
24
+ parser.add_argument('--negative_guidance_scale', default=5.0, type=float)
25
+ parser.add_argument('--use_float_16', action='store_true')
26
+
27
+ args = parser.parse_args()
28
+
29
+ os.makedirs(os.path.join(args.results_folder, "edit"), exist_ok=True)
30
+ os.makedirs(os.path.join(args.results_folder, "reconstruction"), exist_ok=True)
31
+
32
+ if args.use_float_16:
33
+ torch_dtype = torch.float16
34
+ else:
35
+ torch_dtype = torch.float32
36
+
37
+ # if the inversion is a folder, the prompt should also be a folder
38
+ assert (os.path.isdir(args.inversion)==os.path.isdir(args.prompt)), "If the inversion is a folder, the prompt should also be a folder"
39
+ if os.path.isdir(args.inversion):
40
+ l_inv_paths = sorted(glob(os.path.join(args.inversion, "*.pt")))
41
+ l_bnames = [os.path.basename(x) for x in l_inv_paths]
42
+ l_prompt_paths = [os.path.join(args.prompt, x.replace(".pt",".txt")) for x in l_bnames]
43
+ else:
44
+ l_inv_paths = [args.inversion]
45
+ l_prompt_paths = [args.prompt]
46
+
47
+ # Make the editing pipeline
48
+ pipe = EditingPipeline.from_pretrained(args.model_path, torch_dtype=torch_dtype).to("cuda")
49
+ pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
50
+
51
+
52
+ for inv_path, prompt_path in zip(l_inv_paths, l_prompt_paths):
53
+ prompt_str = open(prompt_path).read().strip()
54
+ rec_pil, edit_pil = pipe(prompt_str,
55
+ num_inference_steps=args.num_ddim_steps,
56
+ x_in=torch.load(inv_path).unsqueeze(0),
57
+ edit_dir=construct_direction(args.task_name),
58
+ guidance_amount=args.xa_guidance,
59
+ guidance_scale=args.negative_guidance_scale,
60
+ negative_prompt=prompt_str # use the unedited prompt for the negative prompt
61
+ )
62
+
63
+ bname = os.path.basename(args.inversion).split(".")[0]
64
+ edit_pil[0].save(os.path.join(args.results_folder, f"edit/{bname}.png"))
65
+ rec_pil[0].save(os.path.join(args.results_folder, f"reconstruction/{bname}.png"))
src/edit_synthetic.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os, pdb
2
+
3
+ import argparse
4
+ import numpy as np
5
+ import torch
6
+ import requests
7
+ from PIL import Image
8
+
9
+ from diffusers import DDIMScheduler
10
+ from utils.edit_directions import construct_direction
11
+ from utils.edit_pipeline import EditingPipeline
12
+
13
+
14
+ if __name__=="__main__":
15
+ parser = argparse.ArgumentParser()
16
+ parser.add_argument('--prompt_str', type=str, required=True)
17
+ parser.add_argument('--random_seed', default=0)
18
+ parser.add_argument('--task_name', type=str, default='cat2dog')
19
+ parser.add_argument('--results_folder', type=str, default='output/test_cat')
20
+ parser.add_argument('--num_ddim_steps', type=int, default=50)
21
+ parser.add_argument('--model_path', type=str, default='CompVis/stable-diffusion-v1-4')
22
+ parser.add_argument('--xa_guidance', default=0.15, type=float)
23
+ parser.add_argument('--negative_guidance_scale', default=5.0, type=float)
24
+ parser.add_argument('--use_float_16', action='store_true')
25
+ args = parser.parse_args()
26
+
27
+ os.makedirs(args.results_folder, exist_ok=True)
28
+
29
+ if args.use_float_16:
30
+ torch_dtype = torch.float16
31
+ else:
32
+ torch_dtype = torch.float32
33
+
34
+ # make the input noise map
35
+ torch.cuda.manual_seed(args.random_seed)
36
+ x = torch.randn((1,4,64,64), device="cuda")
37
+
38
+ # Make the editing pipeline
39
+ pipe = EditingPipeline.from_pretrained(args.model_path, torch_dtype=torch_dtype).to("cuda")
40
+ pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
41
+
42
+ rec_pil, edit_pil = pipe(args.prompt_str,
43
+ num_inference_steps=args.num_ddim_steps,
44
+ x_in=x,
45
+ edit_dir=construct_direction(args.task_name),
46
+ guidance_amount=args.xa_guidance,
47
+ guidance_scale=args.negative_guidance_scale,
48
+ negative_prompt="" # use the empty string for the negative prompt
49
+ )
50
+
51
+ edit_pil[0].save(os.path.join(args.results_folder, f"edit.png"))
52
+ rec_pil[0].save(os.path.join(args.results_folder, f"reconstruction.png"))
src/inversion.py ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os, pdb
2
+
3
+ import argparse
4
+ import numpy as np
5
+ import torch
6
+ import requests
7
+ from PIL import Image
8
+
9
+ from lavis.models import load_model_and_preprocess
10
+
11
+ from utils.ddim_inv import DDIMInversion
12
+ from utils.scheduler import DDIMInverseScheduler
13
+
14
+ if __name__=="__main__":
15
+ parser = argparse.ArgumentParser()
16
+ parser.add_argument('--input_image', type=str, default='assets/test_images/cat_a.png')
17
+ parser.add_argument('--results_folder', type=str, default='output/test_cat')
18
+ parser.add_argument('--num_ddim_steps', type=int, default=50)
19
+ parser.add_argument('--model_path', type=str, default='CompVis/stable-diffusion-v1-4')
20
+ parser.add_argument('--use_float_16', action='store_true')
21
+ args = parser.parse_args()
22
+
23
+ # make the output folders
24
+ os.makedirs(os.path.join(args.results_folder, "inversion"), exist_ok=True)
25
+ os.makedirs(os.path.join(args.results_folder, "prompt"), exist_ok=True)
26
+
27
+ if args.use_float_16:
28
+ torch_dtype = torch.float16
29
+ else:
30
+ torch_dtype = torch.float32
31
+
32
+
33
+ # load the BLIP model
34
+ model_blip, vis_processors, _ = load_model_and_preprocess(name="blip_caption", model_type="base_coco", is_eval=True, device=torch.device("cuda"))
35
+ # make the DDIM inversion pipeline
36
+ pipe = DDIMInversion.from_pretrained(args.model_path, torch_dtype=torch_dtype).to("cuda")
37
+ pipe.scheduler = DDIMInverseScheduler.from_config(pipe.scheduler.config)
38
+
39
+
40
+ # if the input is a folder, collect all the images as a list
41
+ if os.path.isdir(args.input_image):
42
+ l_img_paths = sorted(glob(os.path.join(args.input_image, "*.png")))
43
+ else:
44
+ l_img_paths = [args.input_image]
45
+
46
+
47
+ for img_path in l_img_paths:
48
+ bname = os.path.basename(args.input_image).split(".")[0]
49
+ img = Image.open(args.input_image).resize((512,512), Image.Resampling.LANCZOS)
50
+ # generate the caption
51
+ _image = vis_processors["eval"](img).unsqueeze(0).cuda()
52
+ prompt_str = model_blip.generate({"image": _image})[0]
53
+ x_inv, x_inv_image, x_dec_img = pipe(
54
+ prompt_str,
55
+ guidance_scale=1,
56
+ num_inversion_steps=args.num_ddim_steps,
57
+ img=img,
58
+ torch_dtype=torch_dtype
59
+ )
60
+ # save the inversion
61
+ torch.save(x_inv[0], os.path.join(args.results_folder, f"inversion/{bname}.pt"))
62
+ # save the prompt string
63
+ with open(os.path.join(args.results_folder, f"prompt/{bname}.txt"), "w") as f:
64
+ f.write(prompt_str)
src/make_edit_direction.py ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os, pdb
2
+
3
+ import argparse
4
+ import numpy as np
5
+ import torch
6
+ import requests
7
+ from PIL import Image
8
+
9
+ from diffusers import DDIMScheduler
10
+ from utils.edit_pipeline import EditingPipeline
11
+
12
+
13
+ ## convert sentences to sentence embeddings
14
+ def load_sentence_embeddings(l_sentences, tokenizer, text_encoder, device="cuda"):
15
+ with torch.no_grad():
16
+ l_embeddings = []
17
+ for sent in l_sentences:
18
+ text_inputs = tokenizer(
19
+ sent,
20
+ padding="max_length",
21
+ max_length=tokenizer.model_max_length,
22
+ truncation=True,
23
+ return_tensors="pt",
24
+ )
25
+ text_input_ids = text_inputs.input_ids
26
+ prompt_embeds = text_encoder(text_input_ids.to(device), attention_mask=None)[0]
27
+ l_embeddings.append(prompt_embeds)
28
+ return torch.concatenate(l_embeddings, dim=0).mean(dim=0).unsqueeze(0)
29
+
30
+
31
+ if __name__=="__main__":
32
+ parser = argparse.ArgumentParser()
33
+ parser.add_argument('--file_source_sentences', required=True)
34
+ parser.add_argument('--file_target_sentences', required=True)
35
+ parser.add_argument('--output_folder', required=True)
36
+ parser.add_argument('--model_path', type=str, default='CompVis/stable-diffusion-v1-4')
37
+ args = parser.parse_args()
38
+
39
+ # load the model
40
+ pipe = EditingPipeline.from_pretrained(args.model_path, torch_dtype=torch.float16).to("cuda")
41
+ bname_src = os.path.basename(args.file_source_sentences).strip(".txt")
42
+ outf_src = os.path.join(args.output_folder, bname_src+".pt")
43
+ if os.path.exists(outf_src):
44
+ print(f"Skipping source file {outf_src} as it already exists")
45
+ else:
46
+ with open(args.file_source_sentences, "r") as f:
47
+ l_sents = [x.strip() for x in f.readlines()]
48
+ mean_emb = load_sentence_embeddings(l_sents, pipe.tokenizer, pipe.text_encoder, device="cuda")
49
+ print(mean_emb.shape)
50
+ torch.save(mean_emb, outf_src)
51
+
52
+ bname_tgt = os.path.basename(args.file_target_sentences).strip(".txt")
53
+ outf_tgt = os.path.join(args.output_folder, bname_tgt+".pt")
54
+ if os.path.exists(outf_tgt):
55
+ print(f"Skipping target file {outf_tgt} as it already exists")
56
+ else:
57
+ with open(args.file_target_sentences, "r") as f:
58
+ l_sents = [x.strip() for x in f.readlines()]
59
+ mean_emb = load_sentence_embeddings(l_sents, pipe.tokenizer, pipe.text_encoder, device="cuda")
60
+ print(mean_emb.shape)
61
+ torch.save(mean_emb, outf_tgt)
src/utils/base_pipeline.py ADDED
@@ -0,0 +1,322 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import torch
3
+ import inspect
4
+ from packaging import version
5
+ from typing import Any, Callable, Dict, List, Optional, Union
6
+
7
+ from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
8
+ from diffusers import DiffusionPipeline
9
+ from diffusers.models import AutoencoderKL, UNet2DConditionModel
10
+ from diffusers.schedulers import KarrasDiffusionSchedulers
11
+ from diffusers.utils import deprecate, is_accelerate_available, logging, randn_tensor, replace_example_docstring
12
+ from diffusers import StableDiffusionPipeline
13
+ from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
14
+
15
+
16
+
17
+ class BasePipeline(DiffusionPipeline):
18
+ _optional_components = ["safety_checker", "feature_extractor"]
19
+ def __init__(
20
+ self,
21
+ vae: AutoencoderKL,
22
+ text_encoder: CLIPTextModel,
23
+ tokenizer: CLIPTokenizer,
24
+ unet: UNet2DConditionModel,
25
+ scheduler: KarrasDiffusionSchedulers,
26
+ safety_checker: StableDiffusionSafetyChecker,
27
+ feature_extractor: CLIPFeatureExtractor,
28
+ requires_safety_checker: bool = True,
29
+ ):
30
+ super().__init__()
31
+
32
+ if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
33
+ deprecation_message = (
34
+ f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
35
+ f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
36
+ "to update the config accordingly as leaving `steps_offset` might led to incorrect results"
37
+ " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
38
+ " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
39
+ " file"
40
+ )
41
+ deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
42
+ new_config = dict(scheduler.config)
43
+ new_config["steps_offset"] = 1
44
+ scheduler._internal_dict = FrozenDict(new_config)
45
+
46
+ if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
47
+ deprecation_message = (
48
+ f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
49
+ " `clip_sample` should be set to False in the configuration file. Please make sure to update the"
50
+ " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
51
+ " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
52
+ " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
53
+ )
54
+ deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
55
+ new_config = dict(scheduler.config)
56
+ new_config["clip_sample"] = False
57
+ scheduler._internal_dict = FrozenDict(new_config)
58
+
59
+ if safety_checker is None and requires_safety_checker:
60
+ logger.warning(
61
+ f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
62
+ " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
63
+ " results in services or applications open to the public. Both the diffusers team and Hugging Face"
64
+ " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
65
+ " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
66
+ " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
67
+ )
68
+
69
+ if safety_checker is not None and feature_extractor is None:
70
+ raise ValueError(
71
+ "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
72
+ " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
73
+ )
74
+
75
+ is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
76
+ version.parse(unet.config._diffusers_version).base_version
77
+ ) < version.parse("0.9.0.dev0")
78
+ is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
79
+ if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
80
+ deprecation_message = (
81
+ "The configuration file of the unet has set the default `sample_size` to smaller than"
82
+ " 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
83
+ " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
84
+ " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
85
+ " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
86
+ " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
87
+ " in the config might lead to incorrect results in future versions. If you have downloaded this"
88
+ " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
89
+ " the `unet/config.json` file"
90
+ )
91
+ deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
92
+ new_config = dict(unet.config)
93
+ new_config["sample_size"] = 64
94
+ unet._internal_dict = FrozenDict(new_config)
95
+
96
+ self.register_modules(
97
+ vae=vae,
98
+ text_encoder=text_encoder,
99
+ tokenizer=tokenizer,
100
+ unet=unet,
101
+ scheduler=scheduler,
102
+ safety_checker=safety_checker,
103
+ feature_extractor=feature_extractor,
104
+ )
105
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
106
+ self.register_to_config(requires_safety_checker=requires_safety_checker)
107
+
108
+ @property
109
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device
110
+ def _execution_device(self):
111
+ r"""
112
+ Returns the device on which the pipeline's models will be executed. After calling
113
+ `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
114
+ hooks.
115
+ """
116
+ if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"):
117
+ return self.device
118
+ for module in self.unet.modules():
119
+ if (
120
+ hasattr(module, "_hf_hook")
121
+ and hasattr(module._hf_hook, "execution_device")
122
+ and module._hf_hook.execution_device is not None
123
+ ):
124
+ return torch.device(module._hf_hook.execution_device)
125
+ return self.device
126
+
127
+
128
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
129
+ def _encode_prompt(
130
+ self,
131
+ prompt,
132
+ device,
133
+ num_images_per_prompt,
134
+ do_classifier_free_guidance,
135
+ negative_prompt=None,
136
+ prompt_embeds: Optional[torch.FloatTensor] = None,
137
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
138
+ ):
139
+ r"""
140
+ Encodes the prompt into text encoder hidden states.
141
+
142
+ Args:
143
+ prompt (`str` or `List[str]`, *optional*):
144
+ prompt to be encoded
145
+ device: (`torch.device`):
146
+ torch device
147
+ num_images_per_prompt (`int`):
148
+ number of images that should be generated per prompt
149
+ do_classifier_free_guidance (`bool`):
150
+ whether to use classifier free guidance or not
151
+ negative_ prompt (`str` or `List[str]`, *optional*):
152
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
153
+ `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead.
154
+ Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`).
155
+ prompt_embeds (`torch.FloatTensor`, *optional*):
156
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
157
+ provided, text embeddings will be generated from `prompt` input argument.
158
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
159
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
160
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
161
+ argument.
162
+ """
163
+ if prompt is not None and isinstance(prompt, str):
164
+ batch_size = 1
165
+ elif prompt is not None and isinstance(prompt, list):
166
+ batch_size = len(prompt)
167
+ else:
168
+ batch_size = prompt_embeds.shape[0]
169
+
170
+ if prompt_embeds is None:
171
+ text_inputs = self.tokenizer(
172
+ prompt,
173
+ padding="max_length",
174
+ max_length=self.tokenizer.model_max_length,
175
+ truncation=True,
176
+ return_tensors="pt",
177
+ )
178
+ text_input_ids = text_inputs.input_ids
179
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
180
+
181
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
182
+ text_input_ids, untruncated_ids
183
+ ):
184
+ removed_text = self.tokenizer.batch_decode(
185
+ untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
186
+ )
187
+ logger.warning(
188
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
189
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
190
+ )
191
+
192
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
193
+ attention_mask = text_inputs.attention_mask.to(device)
194
+ else:
195
+ attention_mask = None
196
+
197
+ prompt_embeds = self.text_encoder(
198
+ text_input_ids.to(device),
199
+ attention_mask=attention_mask,
200
+ )
201
+ prompt_embeds = prompt_embeds[0]
202
+
203
+ prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
204
+
205
+ bs_embed, seq_len, _ = prompt_embeds.shape
206
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
207
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
208
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
209
+
210
+ # get unconditional embeddings for classifier free guidance
211
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
212
+ uncond_tokens: List[str]
213
+ if negative_prompt is None:
214
+ uncond_tokens = [""] * batch_size
215
+ elif type(prompt) is not type(negative_prompt):
216
+ raise TypeError(
217
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
218
+ f" {type(prompt)}."
219
+ )
220
+ elif isinstance(negative_prompt, str):
221
+ uncond_tokens = [negative_prompt]
222
+ elif batch_size != len(negative_prompt):
223
+ raise ValueError(
224
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
225
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
226
+ " the batch size of `prompt`."
227
+ )
228
+ else:
229
+ uncond_tokens = negative_prompt
230
+
231
+ max_length = prompt_embeds.shape[1]
232
+ uncond_input = self.tokenizer(
233
+ uncond_tokens,
234
+ padding="max_length",
235
+ max_length=max_length,
236
+ truncation=True,
237
+ return_tensors="pt",
238
+ )
239
+
240
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
241
+ attention_mask = uncond_input.attention_mask.to(device)
242
+ else:
243
+ attention_mask = None
244
+
245
+ negative_prompt_embeds = self.text_encoder(
246
+ uncond_input.input_ids.to(device),
247
+ attention_mask=attention_mask,
248
+ )
249
+ negative_prompt_embeds = negative_prompt_embeds[0]
250
+
251
+ if do_classifier_free_guidance:
252
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
253
+ seq_len = negative_prompt_embeds.shape[1]
254
+
255
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
256
+
257
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
258
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
259
+
260
+ # For classifier free guidance, we need to do two forward passes.
261
+ # Here we concatenate the unconditional and text embeddings into a single batch
262
+ # to avoid doing two forward passes
263
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
264
+
265
+ return prompt_embeds
266
+
267
+
268
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
269
+ def decode_latents(self, latents):
270
+ latents = 1 / 0.18215 * latents
271
+ image = self.vae.decode(latents).sample
272
+ image = (image / 2 + 0.5).clamp(0, 1)
273
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16
274
+ image = image.detach().cpu().permute(0, 2, 3, 1).float().numpy()
275
+ return image
276
+
277
+ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
278
+ shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
279
+ if isinstance(generator, list) and len(generator) != batch_size:
280
+ raise ValueError(
281
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
282
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
283
+ )
284
+
285
+ if latents is None:
286
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
287
+ else:
288
+ latents = latents.to(device)
289
+
290
+ # scale the initial noise by the standard deviation required by the scheduler
291
+ latents = latents * self.scheduler.init_noise_sigma
292
+ return latents
293
+
294
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
295
+ def prepare_extra_step_kwargs(self, generator, eta):
296
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
297
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
298
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
299
+ # and should be between [0, 1]
300
+
301
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
302
+ extra_step_kwargs = {}
303
+ if accepts_eta:
304
+ extra_step_kwargs["eta"] = eta
305
+
306
+ # check if the scheduler accepts generator
307
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
308
+ if accepts_generator:
309
+ extra_step_kwargs["generator"] = generator
310
+ return extra_step_kwargs
311
+
312
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
313
+ def run_safety_checker(self, image, device, dtype):
314
+ if self.safety_checker is not None:
315
+ safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device)
316
+ image, has_nsfw_concept = self.safety_checker(
317
+ images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
318
+ )
319
+ else:
320
+ has_nsfw_concept = None
321
+ return image, has_nsfw_concept
322
+
src/utils/cross_attention.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from diffusers.models.attention import CrossAttention
3
+
4
+ class MyCrossAttnProcessor:
5
+ def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None):
6
+ batch_size, sequence_length, _ = hidden_states.shape
7
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length)
8
+
9
+ query = attn.to_q(hidden_states)
10
+
11
+ encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
12
+ key = attn.to_k(encoder_hidden_states)
13
+ value = attn.to_v(encoder_hidden_states)
14
+
15
+ query = attn.head_to_batch_dim(query)
16
+ key = attn.head_to_batch_dim(key)
17
+ value = attn.head_to_batch_dim(value)
18
+
19
+ attention_probs = attn.get_attention_scores(query, key, attention_mask)
20
+ # new bookkeeping to save the attn probs
21
+ attn.attn_probs = attention_probs
22
+
23
+ hidden_states = torch.bmm(attention_probs, value)
24
+ hidden_states = attn.batch_to_head_dim(hidden_states)
25
+
26
+ # linear proj
27
+ hidden_states = attn.to_out[0](hidden_states)
28
+ # dropout
29
+ hidden_states = attn.to_out[1](hidden_states)
30
+
31
+ return hidden_states
32
+
33
+
34
+ """
35
+ A function that prepares a U-Net model for training by enabling gradient computation
36
+ for a specified set of parameters and setting the forward pass to be performed by a
37
+ custom cross attention processor.
38
+
39
+ Parameters:
40
+ unet: A U-Net model.
41
+
42
+ Returns:
43
+ unet: The prepared U-Net model.
44
+ """
45
+ def prep_unet(unet):
46
+ # set the gradients for XA maps to be true
47
+ for name, params in unet.named_parameters():
48
+ if 'attn2' in name:
49
+ params.requires_grad = True
50
+ else:
51
+ params.requires_grad = False
52
+ # replace the fwd function
53
+ for name, module in unet.named_modules():
54
+ module_name = type(module).__name__
55
+ if module_name == "CrossAttention":
56
+ module.set_processor(MyCrossAttnProcessor())
57
+ return unet
src/utils/ddim_inv.py ADDED
@@ -0,0 +1,140 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ import numpy as np
3
+ import torch
4
+ import torch.nn.functional as F
5
+ from random import randrange
6
+ from typing import Any, Callable, Dict, List, Optional, Union, Tuple
7
+ from diffusers import DDIMScheduler
8
+ from diffusers.schedulers.scheduling_ddim import DDIMSchedulerOutput
9
+ from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
10
+ sys.path.insert(0, "src/utils")
11
+ from base_pipeline import BasePipeline
12
+ from cross_attention import prep_unet
13
+
14
+
15
+ class DDIMInversion(BasePipeline):
16
+
17
+ def auto_corr_loss(self, x, random_shift=True):
18
+ B,C,H,W = x.shape
19
+ assert B==1
20
+ x = x.squeeze(0)
21
+ # x must be shape [C,H,W] now
22
+ reg_loss = 0.0
23
+ for ch_idx in range(x.shape[0]):
24
+ noise = x[ch_idx][None, None,:,:]
25
+ while True:
26
+ if random_shift: roll_amount = randrange(noise.shape[2]//2)
27
+ else: roll_amount = 1
28
+ reg_loss += (noise*torch.roll(noise, shifts=roll_amount, dims=2)).mean()**2
29
+ reg_loss += (noise*torch.roll(noise, shifts=roll_amount, dims=3)).mean()**2
30
+ if noise.shape[2] <= 8:
31
+ break
32
+ noise = F.avg_pool2d(noise, kernel_size=2)
33
+ return reg_loss
34
+
35
+ def kl_divergence(self, x):
36
+ _mu = x.mean()
37
+ _var = x.var()
38
+ return _var + _mu**2 - 1 - torch.log(_var+1e-7)
39
+
40
+
41
+ def __call__(
42
+ self,
43
+ prompt: Union[str, List[str]] = None,
44
+ num_inversion_steps: int = 50,
45
+ guidance_scale: float = 7.5,
46
+ negative_prompt: Optional[Union[str, List[str]]] = None,
47
+ num_images_per_prompt: Optional[int] = 1,
48
+ eta: float = 0.0,
49
+ output_type: Optional[str] = "pil",
50
+ return_dict: bool = True,
51
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
52
+ img=None, # the input image as a PIL image
53
+ torch_dtype=torch.float32,
54
+
55
+ # inversion regularization parameters
56
+ lambda_ac: float = 20.0,
57
+ lambda_kl: float = 20.0,
58
+ num_reg_steps: int = 5,
59
+ num_ac_rolls: int = 5,
60
+ ):
61
+
62
+ # 0. modify the unet to be useful :D
63
+ self.unet = prep_unet(self.unet)
64
+
65
+ # set the scheduler to be the Inverse DDIM scheduler
66
+ # self.scheduler = MyDDIMScheduler.from_config(self.scheduler.config)
67
+
68
+ device = self._execution_device
69
+ do_classifier_free_guidance = guidance_scale > 1.0
70
+ self.scheduler.set_timesteps(num_inversion_steps, device=device)
71
+ timesteps = self.scheduler.timesteps
72
+
73
+ # Encode the input image with the first stage model
74
+ x0 = np.array(img)/255
75
+ x0 = torch.from_numpy(x0).type(torch_dtype).permute(2, 0, 1).unsqueeze(dim=0).repeat(1, 1, 1, 1).cuda()
76
+ x0 = (x0 - 0.5) * 2.
77
+ with torch.no_grad():
78
+ x0_enc = self.vae.encode(x0).latent_dist.sample().to(device, torch_dtype)
79
+ latents = x0_enc = 0.18215 * x0_enc
80
+
81
+ # Decode and return the image
82
+ with torch.no_grad():
83
+ x0_dec = self.decode_latents(x0_enc.detach())
84
+ image_x0_dec = self.numpy_to_pil(x0_dec)
85
+
86
+ with torch.no_grad():
87
+ prompt_embeds = self._encode_prompt(prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt).to(device)
88
+ extra_step_kwargs = self.prepare_extra_step_kwargs(None, eta)
89
+
90
+ # Do the inversion
91
+ num_warmup_steps = len(timesteps) - num_inversion_steps * self.scheduler.order # should be 0?
92
+ with self.progress_bar(total=num_inversion_steps) as progress_bar:
93
+ for i, t in enumerate(timesteps.flip(0)[1:-1]):
94
+ # expand the latents if we are doing classifier free guidance
95
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
96
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
97
+
98
+ # predict the noise residual
99
+ with torch.no_grad():
100
+ noise_pred = self.unet(latent_model_input,t,encoder_hidden_states=prompt_embeds,cross_attention_kwargs=cross_attention_kwargs,).sample
101
+
102
+ # perform guidance
103
+ if do_classifier_free_guidance:
104
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
105
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
106
+
107
+ # regularization of the noise prediction
108
+ e_t = noise_pred
109
+ for _outer in range(num_reg_steps):
110
+ if lambda_ac>0:
111
+ for _inner in range(num_ac_rolls):
112
+ _var = torch.autograd.Variable(e_t.detach().clone(), requires_grad=True)
113
+ l_ac = self.auto_corr_loss(_var)
114
+ l_ac.backward()
115
+ _grad = _var.grad.detach()/num_ac_rolls
116
+ e_t = e_t - lambda_ac*_grad
117
+ if lambda_kl>0:
118
+ _var = torch.autograd.Variable(e_t.detach().clone(), requires_grad=True)
119
+ l_kld = self.kl_divergence(_var)
120
+ l_kld.backward()
121
+ _grad = _var.grad.detach()
122
+ e_t = e_t - lambda_kl*_grad
123
+ e_t = e_t.detach()
124
+ noise_pred = e_t
125
+
126
+ # compute the previous noisy sample x_t -> x_t-1
127
+ latents = self.scheduler.step(noise_pred, t, latents, reverse=True, **extra_step_kwargs).prev_sample
128
+
129
+ # call the callback, if provided
130
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
131
+ progress_bar.update()
132
+
133
+
134
+ x_inv = latents.detach().clone()
135
+ # reconstruct the image
136
+
137
+ # 8. Post-processing
138
+ image = self.decode_latents(latents.detach())
139
+ image = self.numpy_to_pil(image)
140
+ return x_inv, image, image_x0_dec
src/utils/edit_directions.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+
4
+
5
+ """
6
+ This function takes in a task name and returns the direction in the embedding space that transforms class A to class B for the given task.
7
+
8
+ Parameters:
9
+ task_name (str): name of the task for which direction is to be constructed.
10
+
11
+ Returns:
12
+ torch.Tensor: A tensor representing the direction in the embedding space that transforms class A to class B.
13
+
14
+ Examples:
15
+ >>> construct_direction("cat2dog")
16
+ """
17
+ def construct_direction(task_name):
18
+ if task_name=="cat2dog":
19
+ emb_dir = f"assets/embeddings_sd_1.4"
20
+ embs_a = torch.load(os.path.join(emb_dir, f"cat.pt"))
21
+ embs_b = torch.load(os.path.join(emb_dir, f"dog.pt"))
22
+ return (embs_b.mean(0)-embs_a.mean(0)).unsqueeze(0)
23
+ elif task_name=="dog2cat":
24
+ emb_dir = f"assets/embeddings_sd_1.4"
25
+ embs_a = torch.load(os.path.join(emb_dir, f"dog.pt"))
26
+ embs_b = torch.load(os.path.join(emb_dir, f"cat.pt"))
27
+ return (embs_b.mean(0)-embs_a.mean(0)).unsqueeze(0)
28
+ else:
29
+ raise NotImplementedError
src/utils/edit_pipeline.py ADDED
@@ -0,0 +1,174 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pdb, sys
2
+
3
+ import numpy as np
4
+ import torch
5
+ from typing import Any, Callable, Dict, List, Optional, Union
6
+ from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
7
+ sys.path.insert(0, "src/utils")
8
+ from base_pipeline import BasePipeline
9
+ from cross_attention import prep_unet
10
+
11
+
12
+ class EditingPipeline(BasePipeline):
13
+ def __call__(
14
+ self,
15
+ prompt: Union[str, List[str]] = None,
16
+ height: Optional[int] = None,
17
+ width: Optional[int] = None,
18
+ num_inference_steps: int = 50,
19
+ guidance_scale: float = 7.5,
20
+ negative_prompt: Optional[Union[str, List[str]]] = None,
21
+ num_images_per_prompt: Optional[int] = 1,
22
+ eta: float = 0.0,
23
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
24
+ latents: Optional[torch.FloatTensor] = None,
25
+ prompt_embeds: Optional[torch.FloatTensor] = None,
26
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
27
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
28
+
29
+ # pix2pix parameters
30
+ guidance_amount=0.1,
31
+ edit_dir=None,
32
+ x_in=None,
33
+
34
+ ):
35
+
36
+ x_in.to(dtype=self.unet.dtype, device=self._execution_device)
37
+
38
+ # 0. modify the unet to be useful :D
39
+ self.unet = prep_unet(self.unet)
40
+
41
+ # 1. setup all caching objects
42
+ d_ref_t2attn = {} # reference cross attention maps
43
+
44
+ # 2. Default height and width to unet
45
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
46
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
47
+
48
+ # TODO: add the input checker function
49
+ # self.check_inputs( prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds )
50
+
51
+ # 2. Define call parameters
52
+ if prompt is not None and isinstance(prompt, str):
53
+ batch_size = 1
54
+ elif prompt is not None and isinstance(prompt, list):
55
+ batch_size = len(prompt)
56
+ else:
57
+ batch_size = prompt_embeds.shape[0]
58
+
59
+ device = self._execution_device
60
+ do_classifier_free_guidance = guidance_scale > 1.0
61
+ x_in = x_in.to(dtype=self.unet.dtype, device=self._execution_device)
62
+ # 3. Encode input prompt = 2x77x1024
63
+ prompt_embeds = self._encode_prompt( prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds,)
64
+
65
+ # 4. Prepare timesteps
66
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
67
+ timesteps = self.scheduler.timesteps
68
+
69
+ # 5. Prepare latent variables
70
+ num_channels_latents = self.unet.in_channels
71
+
72
+ # randomly sample a latent code if not provided
73
+ latents = self.prepare_latents(batch_size * num_images_per_prompt, num_channels_latents, height, width, prompt_embeds.dtype, device, generator, x_in,)
74
+
75
+ latents_init = latents.clone()
76
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
77
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
78
+
79
+ # 7. First Denoising loop for getting the reference cross attention maps
80
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
81
+ with torch.no_grad():
82
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
83
+ for i, t in enumerate(timesteps):
84
+ # expand the latents if we are doing classifier free guidance
85
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
86
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
87
+
88
+ # predict the noise residual
89
+ noise_pred = self.unet(latent_model_input,t,encoder_hidden_states=prompt_embeds,cross_attention_kwargs=cross_attention_kwargs,).sample
90
+
91
+ # add the cross attention map to the dictionary
92
+ d_ref_t2attn[t.item()] = {}
93
+ for name, module in self.unet.named_modules():
94
+ module_name = type(module).__name__
95
+ if module_name == "CrossAttention" and 'attn2' in name:
96
+ attn_mask = module.attn_probs # size is num_channel,s*s,77
97
+ d_ref_t2attn[t.item()][name] = attn_mask.detach().cpu()
98
+
99
+ # perform guidance
100
+ if do_classifier_free_guidance:
101
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
102
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
103
+
104
+ # compute the previous noisy sample x_t -> x_t-1
105
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
106
+
107
+ # call the callback, if provided
108
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
109
+ progress_bar.update()
110
+
111
+ # make the reference image (reconstruction)
112
+ image_rec = self.numpy_to_pil(self.decode_latents(latents.detach()))
113
+
114
+ prompt_embeds_edit = prompt_embeds.clone()
115
+ #add the edit only to the second prompt, idx 0 is the negative prompt
116
+ prompt_embeds_edit[1:2] += edit_dir
117
+
118
+ latents = latents_init
119
+ # Second denoising loop for editing the text prompt
120
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
121
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
122
+ for i, t in enumerate(timesteps):
123
+ # expand the latents if we are doing classifier free guidance
124
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
125
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
126
+
127
+ x_in = latent_model_input.detach().clone()
128
+ x_in.requires_grad = True
129
+
130
+ opt = torch.optim.SGD([x_in], lr=guidance_amount)
131
+
132
+ # predict the noise residual
133
+ noise_pred = self.unet(x_in,t,encoder_hidden_states=prompt_embeds_edit.detach(),cross_attention_kwargs=cross_attention_kwargs,).sample
134
+
135
+ loss = 0.0
136
+ for name, module in self.unet.named_modules():
137
+ module_name = type(module).__name__
138
+ if module_name == "CrossAttention" and 'attn2' in name:
139
+ curr = module.attn_probs # size is num_channel,s*s,77
140
+ ref = d_ref_t2attn[t.item()][name].detach().cuda()
141
+ loss += ((curr-ref)**2).sum((1,2)).mean(0)
142
+ loss.backward(retain_graph=False)
143
+ opt.step()
144
+
145
+ # recompute the noise
146
+ with torch.no_grad():
147
+ noise_pred = self.unet(x_in.detach(),t,encoder_hidden_states=prompt_embeds_edit,cross_attention_kwargs=cross_attention_kwargs,).sample
148
+
149
+ latents = x_in.detach().chunk(2)[0]
150
+
151
+ # perform guidance
152
+ if do_classifier_free_guidance:
153
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
154
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
155
+
156
+ # compute the previous noisy sample x_t -> x_t-1
157
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
158
+
159
+ # call the callback, if provided
160
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
161
+ progress_bar.update()
162
+
163
+
164
+ # 8. Post-processing
165
+ image = self.decode_latents(latents.detach())
166
+
167
+ # 9. Run safety checker
168
+ image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
169
+
170
+ # 10. Convert to PIL
171
+ image_edit = self.numpy_to_pil(image)
172
+
173
+
174
+ return image_rec, image_edit
src/utils/scheduler.py ADDED
@@ -0,0 +1,289 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2022 Stanford University Team and The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ # DISCLAIMER: This code is strongly influenced by https://github.com/pesser/pytorch_diffusion
16
+ # and https://github.com/hojonathanho/diffusion
17
+ import os, sys, pdb
18
+ import math
19
+ from dataclasses import dataclass
20
+ from typing import List, Optional, Tuple, Union
21
+
22
+ import numpy as np
23
+ import torch
24
+
25
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
26
+ from diffusers.utils import BaseOutput, randn_tensor
27
+ from diffusers.schedulers.scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin
28
+
29
+
30
+ @dataclass
31
+ # Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->DDIM
32
+ class DDIMSchedulerOutput(BaseOutput):
33
+ """
34
+ Output class for the scheduler's step function output.
35
+
36
+ Args:
37
+ prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
38
+ Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the
39
+ denoising loop.
40
+ pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
41
+ The predicted denoised sample (x_{0}) based on the model output from the current timestep.
42
+ `pred_original_sample` can be used to preview progress or for guidance.
43
+ """
44
+
45
+ prev_sample: torch.FloatTensor
46
+ pred_original_sample: Optional[torch.FloatTensor] = None
47
+
48
+
49
+ def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999) -> torch.Tensor:
50
+ """
51
+ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
52
+ (1-beta) over time from t = [0,1].
53
+
54
+ Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
55
+ to that part of the diffusion process.
56
+
57
+
58
+ Args:
59
+ num_diffusion_timesteps (`int`): the number of betas to produce.
60
+ max_beta (`float`): the maximum beta to use; use values lower than 1 to
61
+ prevent singularities.
62
+
63
+ Returns:
64
+ betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
65
+ """
66
+
67
+ def alpha_bar(time_step):
68
+ return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2
69
+
70
+ betas = []
71
+ for i in range(num_diffusion_timesteps):
72
+ t1 = i / num_diffusion_timesteps
73
+ t2 = (i + 1) / num_diffusion_timesteps
74
+ betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
75
+ return torch.tensor(betas)
76
+
77
+
78
+ class DDIMInverseScheduler(SchedulerMixin, ConfigMixin):
79
+ """
80
+ Denoising diffusion implicit models is a scheduler that extends the denoising procedure introduced in denoising
81
+ diffusion probabilistic models (DDPMs) with non-Markovian guidance.
82
+
83
+ [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
84
+ function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
85
+ [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
86
+ [`~SchedulerMixin.from_pretrained`] functions.
87
+
88
+ For more details, see the original paper: https://arxiv.org/abs/2010.02502
89
+
90
+ Args:
91
+ num_train_timesteps (`int`): number of diffusion steps used to train the model.
92
+ beta_start (`float`): the starting `beta` value of inference.
93
+ beta_end (`float`): the final `beta` value.
94
+ beta_schedule (`str`):
95
+ the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
96
+ `linear`, `scaled_linear`, or `squaredcos_cap_v2`.
97
+ trained_betas (`np.ndarray`, optional):
98
+ option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.
99
+ clip_sample (`bool`, default `True`):
100
+ option to clip predicted sample between -1 and 1 for numerical stability.
101
+ set_alpha_to_one (`bool`, default `True`):
102
+ each diffusion step uses the value of alphas product at that step and at the previous one. For the final
103
+ step there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`,
104
+ otherwise it uses the value of alpha at step 0.
105
+ steps_offset (`int`, default `0`):
106
+ an offset added to the inference steps. You can use a combination of `offset=1` and
107
+ `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in
108
+ stable diffusion.
109
+ prediction_type (`str`, default `epsilon`, optional):
110
+ prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion
111
+ process), `sample` (directly predicting the noisy sample`) or `v_prediction` (see section 2.4
112
+ https://imagen.research.google/video/paper.pdf)
113
+ """
114
+
115
+ _compatibles = [e.name for e in KarrasDiffusionSchedulers]
116
+ order = 1
117
+
118
+ @register_to_config
119
+ def __init__(
120
+ self,
121
+ num_train_timesteps: int = 1000,
122
+ beta_start: float = 0.0001,
123
+ beta_end: float = 0.02,
124
+ beta_schedule: str = "linear",
125
+ trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
126
+ clip_sample: bool = True,
127
+ set_alpha_to_one: bool = True,
128
+ steps_offset: int = 0,
129
+ prediction_type: str = "epsilon",
130
+ ):
131
+ if trained_betas is not None:
132
+ self.betas = torch.tensor(trained_betas, dtype=torch.float32)
133
+ elif beta_schedule == "linear":
134
+ self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
135
+ elif beta_schedule == "scaled_linear":
136
+ # this schedule is very specific to the latent diffusion model.
137
+ self.betas = (
138
+ torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
139
+ )
140
+ elif beta_schedule == "squaredcos_cap_v2":
141
+ # Glide cosine schedule
142
+ self.betas = betas_for_alpha_bar(num_train_timesteps)
143
+ else:
144
+ raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
145
+
146
+ self.alphas = 1.0 - self.betas
147
+ self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
148
+
149
+ # At every step in ddim, we are looking into the previous alphas_cumprod
150
+ # For the final step, there is no previous alphas_cumprod because we are already at 0
151
+ # `set_alpha_to_one` decides whether we set this parameter simply to one or
152
+ # whether we use the final alpha of the "non-previous" one.
153
+ self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0]
154
+
155
+ # standard deviation of the initial noise distribution
156
+ self.init_noise_sigma = 1.0
157
+
158
+ # setable values
159
+ self.num_inference_steps = None
160
+ self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy().astype(np.int64))
161
+
162
+ def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor:
163
+ """
164
+ Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
165
+ current timestep.
166
+
167
+ Args:
168
+ sample (`torch.FloatTensor`): input sample
169
+ timestep (`int`, optional): current timestep
170
+
171
+ Returns:
172
+ `torch.FloatTensor`: scaled input sample
173
+ """
174
+ return sample
175
+
176
+ def _get_variance(self, timestep, prev_timestep):
177
+ alpha_prod_t = self.alphas_cumprod[timestep]
178
+ alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
179
+ beta_prod_t = 1 - alpha_prod_t
180
+ beta_prod_t_prev = 1 - alpha_prod_t_prev
181
+
182
+ variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)
183
+
184
+ return variance
185
+
186
+ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
187
+ """
188
+ Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.
189
+
190
+ Args:
191
+ num_inference_steps (`int`):
192
+ the number of diffusion steps used when generating samples with a pre-trained model.
193
+ """
194
+
195
+ if num_inference_steps > self.config.num_train_timesteps:
196
+ raise ValueError(
197
+ f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.config.train_timesteps`:"
198
+ f" {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle"
199
+ f" maximal {self.config.num_train_timesteps} timesteps."
200
+ )
201
+
202
+ self.num_inference_steps = num_inference_steps
203
+ step_ratio = self.config.num_train_timesteps // self.num_inference_steps
204
+ # creates integer timesteps by multiplying by ratio
205
+ # casting to int to avoid issues when num_inference_step is power of 3
206
+ timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64)
207
+ self.timesteps = torch.from_numpy(timesteps).to(device)
208
+ self.timesteps += self.config.steps_offset
209
+
210
+ def step(
211
+ self,
212
+ model_output: torch.FloatTensor,
213
+ timestep: int,
214
+ sample: torch.FloatTensor,
215
+ eta: float = 0.0,
216
+ use_clipped_model_output: bool = False,
217
+ generator=None,
218
+ variance_noise: Optional[torch.FloatTensor] = None,
219
+ return_dict: bool = True,
220
+ reverse=False
221
+ ) -> Union[DDIMSchedulerOutput, Tuple]:
222
+
223
+
224
+ e_t = model_output
225
+
226
+ x = sample
227
+ prev_timestep = timestep + self.config.num_train_timesteps // self.num_inference_steps
228
+ # print(timestep, prev_timestep)
229
+ a_t = alpha_prod_t = self.alphas_cumprod[timestep-1]
230
+ a_prev = alpha_t_prev = self.alphas_cumprod[prev_timestep-1] if prev_timestep >= 0 else self.final_alpha_cumprod
231
+ beta_prod_t = 1 - alpha_prod_t
232
+
233
+ pred_x0 = (x - (1-a_t)**0.5 * e_t) / a_t.sqrt()
234
+ # direction pointing to x_t
235
+ dir_xt = (1. - a_prev).sqrt() * e_t
236
+ x = a_prev.sqrt()*pred_x0 + dir_xt
237
+ if not return_dict:
238
+ return (x,)
239
+ return DDIMSchedulerOutput(prev_sample=x, pred_original_sample=pred_x0)
240
+
241
+
242
+
243
+
244
+
245
+ def add_noise(
246
+ self,
247
+ original_samples: torch.FloatTensor,
248
+ noise: torch.FloatTensor,
249
+ timesteps: torch.IntTensor,
250
+ ) -> torch.FloatTensor:
251
+ # Make sure alphas_cumprod and timestep have same device and dtype as original_samples
252
+ self.alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
253
+ timesteps = timesteps.to(original_samples.device)
254
+
255
+ sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
256
+ sqrt_alpha_prod = sqrt_alpha_prod.flatten()
257
+ while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
258
+ sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
259
+
260
+ sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5
261
+ sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
262
+ while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
263
+ sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
264
+
265
+ noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
266
+ return noisy_samples
267
+
268
+ def get_velocity(
269
+ self, sample: torch.FloatTensor, noise: torch.FloatTensor, timesteps: torch.IntTensor
270
+ ) -> torch.FloatTensor:
271
+ # Make sure alphas_cumprod and timestep have same device and dtype as sample
272
+ self.alphas_cumprod = self.alphas_cumprod.to(device=sample.device, dtype=sample.dtype)
273
+ timesteps = timesteps.to(sample.device)
274
+
275
+ sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
276
+ sqrt_alpha_prod = sqrt_alpha_prod.flatten()
277
+ while len(sqrt_alpha_prod.shape) < len(sample.shape):
278
+ sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
279
+
280
+ sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5
281
+ sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
282
+ while len(sqrt_one_minus_alpha_prod.shape) < len(sample.shape):
283
+ sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
284
+
285
+ velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample
286
+ return velocity
287
+
288
+ def __len__(self):
289
+ return self.config.num_train_timesteps