arthur-qiu committed
Commit 3adee15
1 Parent(s): 0aff381
Files changed (8)
  1. .gitattributes +0 -35
  2. .gitignore +8 -0
  3. README.md +0 -14
  4. app.py +208 -0
  5. free_lunch_utils.py +306 -0
  6. pipeline_freescale.py +1204 -0
  7. requirements.txt +12 -0
  8. scale_attention.py +372 -0
.gitattributes DELETED
@@ -1,35 +0,0 @@
1
- *.7z filter=lfs diff=lfs merge=lfs -text
2
- *.arrow filter=lfs diff=lfs merge=lfs -text
3
- *.bin filter=lfs diff=lfs merge=lfs -text
4
- *.bz2 filter=lfs diff=lfs merge=lfs -text
5
- *.ckpt filter=lfs diff=lfs merge=lfs -text
6
- *.ftz filter=lfs diff=lfs merge=lfs -text
7
- *.gz filter=lfs diff=lfs merge=lfs -text
8
- *.h5 filter=lfs diff=lfs merge=lfs -text
9
- *.joblib filter=lfs diff=lfs merge=lfs -text
10
- *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
- *.model filter=lfs diff=lfs merge=lfs -text
13
- *.msgpack filter=lfs diff=lfs merge=lfs -text
14
- *.npy filter=lfs diff=lfs merge=lfs -text
15
- *.npz filter=lfs diff=lfs merge=lfs -text
16
- *.onnx filter=lfs diff=lfs merge=lfs -text
17
- *.ot filter=lfs diff=lfs merge=lfs -text
18
- *.parquet filter=lfs diff=lfs merge=lfs -text
19
- *.pb filter=lfs diff=lfs merge=lfs -text
20
- *.pickle filter=lfs diff=lfs merge=lfs -text
21
- *.pkl filter=lfs diff=lfs merge=lfs -text
22
- *.pt filter=lfs diff=lfs merge=lfs -text
23
- *.pth filter=lfs diff=lfs merge=lfs -text
24
- *.rar filter=lfs diff=lfs merge=lfs -text
25
- *.safetensors filter=lfs diff=lfs merge=lfs -text
26
- saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
- *.tar.* filter=lfs diff=lfs merge=lfs -text
28
- *.tar filter=lfs diff=lfs merge=lfs -text
29
- *.tflite filter=lfs diff=lfs merge=lfs -text
30
- *.tgz filter=lfs diff=lfs merge=lfs -text
31
- *.wasm filter=lfs diff=lfs merge=lfs -text
32
- *.xz filter=lfs diff=lfs merge=lfs -text
33
- *.zip filter=lfs diff=lfs merge=lfs -text
34
- *.zst filter=lfs diff=lfs merge=lfs -text
35
- *tfevents* filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,8 @@
1
+ .DS_Store
2
+ *pyc
3
+ .vscode
4
+ __pycache__
5
+ *.egg-info
6
+
7
+ checkpoints
8
+ results
README.md DELETED
@@ -1,14 +0,0 @@
1
- ---
2
- title: FreeScale
3
- emoji: 👁
4
- colorFrom: pink
5
- colorTo: gray
6
- sdk: gradio
7
- sdk_version: 5.9.0
8
- app_file: app.py
9
- pinned: false
10
- license: mit
11
- short_description: Unleashing the resolution of your SDXL using FreeScale
12
- ---
13
-
14
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,208 @@
1
+ import os
2
+ import torch
3
+ from PIL import Image
4
+
5
+ from pipeline_freescale import StableDiffusionXLPipeline
6
+ from free_lunch_utils import register_free_upblock2d, register_free_crossattn_upblock2d
7
+
8
+ import gradio as gr
9
+ import spaces
10
+
11
+ @spaces.GPU(duration=120)
12
+ def infer_gpu_part(pipe, generator, prompt, negative_prompt, ddim_steps, guidance_scale, resolutions_list, fast_mode, cosine_scale):
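+ # ZeroGPU attaches a GPU only for this decorated call (up to 120 s), so the pipeline and generator are moved to CUDA here rather than at import time.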
13
+ pipe = pipe.to("cuda")
14
+ if hasattr(generator, "to"):  # a plain torch.Generator has no .to(); guard so a CPU generator still works here
+     generator = generator.to("cuda")
15
+ result = pipe(prompt, negative_prompt=negative_prompt, generator=generator,
16
+ num_inference_steps=ddim_steps, guidance_scale=guidance_scale,
17
+ resolutions_list=resolutions_list, fast_mode=fast_mode, cosine_scale=cosine_scale,
18
+ )
19
+ return result
20
+
21
+ def infer(prompt, output_size, ddim_steps, guidance_scale, cosine_scale, seed, options, negative_prompt):
22
+
23
+ disable_freeu = 'Disable FreeU' in options
24
+ fast_mode = 'Fast Mode' in options
25
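+ # Each resolutions_list below is FreeScale's coarse-to-fine cascade: generation starts at the native SDXL resolution and is progressively refined up to the requested output size.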
+ if output_size == "2048 x 2048":
26
+ resolutions_list = [[1024, 1024],
27
+ [2048, 2048]]
28
+ elif output_size == "2048 x 4096":
29
+ resolutions_list = [[512, 1024],
30
+ [1024, 2048],
31
+ [2048, 4096]]
32
+ elif output_size == "4096 x 2048":
33
+ resolutions_list = [[1024, 512],
34
+ [2048, 1024],
35
+ [4096, 2048]]
36
+ elif output_size == "4096 x 4096":
37
+ resolutions_list = [[1024, 1024],
38
+ [2048, 2048],
39
+ [4096, 4096]]
40
+
41
+ model_ckpt = "stabilityai/stable-diffusion-xl-base-1.0"
42
+ pipe = StableDiffusionXLPipeline.from_pretrained(model_ckpt, torch_dtype=torch.float16)
43
+ if not disable_freeu:
44
+ register_free_upblock2d(pipe, b1=1.1, b2=1.2, s1=0.6, s2=0.4)
45
+ register_free_crossattn_upblock2d(pipe, b1=1.1, b2=1.2, s1=0.6, s2=0.4)
46
+
47
+ generator = torch.Generator()
48
+ generator = generator.manual_seed(seed)
49
+
50
+ result = infer_gpu_part(pipe, generator, prompt, negative_prompt, ddim_steps, guidance_scale, resolutions_list, fast_mode, cosine_scale)
51
+
52
+ image = result.images[0]
53
+ save_path = 'output.png'
54
+ image.save(save_path)
55
+
56
+ return save_path
57
+
58
+
59
+ examples = [
60
+ ["A Enchanted illustration of a Palatial Ghost Explosion with a Mystical Sky, in the style of Eric, viewed from CamProX, Bokeh. High resolution, 8k, insanely detailed.",],
61
+ ["Brunette pilot girl in a snowstorm, full body, moody lighting, intricate details, depth of field, outdoors, Fujifilm XT3, RAW, 8K UHD, film grain, Unreal Engine 5, ray tracing.",],
62
+ ["A cute and adorable fluffy puppy wearing a witch hat in a Halloween autumn evening forest, falling autumn leaves, brown acorns on the ground, Halloween pumpkins spiderwebs, bats, and a witch’s broom.",],
63
+ ["A Fantasy Realism illustration of a Heroic Phoenix Rising Adventurous with a Fantasy Waterfall, in the style of Illusia, viewed from Capture360XPro, Historical light. High resolution, 8k, insanely detailed.",],
64
+ ]
65
+
66
+ css = """
67
+ #col-container {max-width: 640px; margin-left: auto; margin-right: auto;}
68
+ a {text-decoration-line: underline; font-weight: 600;}
69
+ .animate-spin {
70
+ animation: spin 1s linear infinite;
71
+ }
72
+ @keyframes spin {
73
+ from {
74
+ transform: rotate(0deg);
75
+ }
76
+ to {
77
+ transform: rotate(360deg);
78
+ }
79
+ }
80
+ #share-btn-container {
81
+ display: flex;
82
+ padding-left: 0.5rem !important;
83
+ padding-right: 0.5rem !important;
84
+ background-color: #000000;
85
+ justify-content: center;
86
+ align-items: center;
87
+ border-radius: 9999px !important;
88
+ max-width: 15rem;
89
+ height: 36px;
90
+ }
91
+ div#share-btn-container > div {
92
+ flex-direction: row;
93
+ background: black;
94
+ align-items: center;
95
+ }
96
+ #share-btn-container:hover {
97
+ background-color: #060606;
98
+ }
99
+ #share-btn {
100
+ all: initial;
101
+ color: #ffffff;
102
+ font-weight: 600;
103
+ cursor:pointer;
104
+ font-family: 'IBM Plex Sans', sans-serif;
105
+ margin-left: 0.5rem !important;
106
+ padding-top: 0.5rem !important;
107
+ padding-bottom: 0.5rem !important;
108
+ right:0;
109
+ }
110
+ #share-btn * {
111
+ all: unset;
112
+ }
113
+ #share-btn-container div:nth-child(-n+2){
114
+ width: auto !important;
115
+ min-height: 0px !important;
116
+ }
117
+ #share-btn-container .wrap {
118
+ display: none !important;
119
+ }
120
+ #share-btn-container.hidden {
121
+ display: none!important;
122
+ }
123
+ img[src*='#center'] {
124
+ display: inline-block;
125
+ margin: unset;
126
+ }
127
+ .footer {
128
+ margin-bottom: 45px;
129
+ margin-top: 10px;
130
+ text-align: center;
131
+ border-bottom: 1px solid #e5e5e5;
132
+ }
133
+ .footer>p {
134
+ font-size: .8rem;
135
+ display: inline-block;
136
+ padding: 0 10px;
137
+ transform: translateY(10px);
138
+ background: white;
139
+ }
140
+ .dark .footer {
141
+ border-color: #303030;
142
+ }
143
+ .dark .footer>p {
144
+ background: #0b0f19;
145
+ }
146
+ """
147
+
148
+ with gr.Blocks(css=css) as demo:
149
+ with gr.Column(elem_id="col-container"):
150
+ gr.Markdown(
151
+ """
152
+ <h1 style="text-align: center;">FreeScale (unleash the resolution of SDXL)</h1>
153
+ <p style="text-align: center;">
154
+ FreeScale: Unleashing the Resolution of Diffusion Models via Tuning-Free Scale Fusion
155
+ </p>
156
+ <p style="text-align: center;">
157
+ <a href="https://arxiv.org/abs/2412.09626" target="_blank"><b>[arXiv]</b></a> &nbsp;&nbsp;&nbsp;&nbsp;
158
+ <a href="http://haonanqiu.com/projects/FreeScale.html" target="_blank"><b>[Project Page]</b></a> &nbsp;&nbsp;&nbsp;&nbsp;
159
+ <a href="https://github.com/ali-vilab/FreeScale" target="_blank"><b>[Code]</b></a>
160
+ </p>
161
+ """
162
+ )
163
+
164
+ prompt_in = gr.Textbox(label="Prompt", placeholder="A chihuahua in astronaut suit floating in space, cinematic lighting, glow effect")
165
+
166
+ with gr.Row():
167
+ with gr.Accordion('FreeScale Parameters (feel free to adjust these parameters based on your prompt): ', open=False):
168
+ with gr.Row():
169
+ output_size = gr.Dropdown(["2048 x 2048", "2048 x 4096", "4096 x 2048", "4096 x 4096"], value="2048 x 2048", label="Output Size (H x W)")
170
+ with gr.Row():
171
+ ddim_steps = gr.Slider(label='DDIM Steps',
172
+ minimum=5,
173
+ maximum=200,
174
+ step=1,
175
+ value=50)
176
+ guidance_scale = gr.Slider(label='Guidance Scale',
177
+ minimum=1.0,
178
+ maximum=20.0,
179
+ step=0.1,
180
+ value=7.5)
181
+ with gr.Row():
182
+ cosine_scale = gr.Slider(label='Cosine Scale',
183
+ minimum=0,
184
+ maximum=10,
185
+ step=0.1,
186
+ value=2.0)
187
+ seed = gr.Slider(label='Random Seed',
188
+ minimum=0,
189
+ maximum=10000,
190
+ step=1,
191
+ value=123)
192
+ with gr.Row():
193
+ options = gr.CheckboxGroup(['Disable FreeU', 'Fast Mode'], label='Options (NOT recommended to change)')
194
+ with gr.Row():
195
+ negative_prompt = gr.Textbox(label='Negative Prompt', value='blurry, ugly, duplicate, poorly drawn, deformed, mosaic')
196
+
197
+ submit_btn = gr.Button("Generate", variant='primary')
198
+ image_result = gr.Image(label="Image Output")
199
+
200
+ gr.Examples(examples=examples, inputs=[prompt_in, output_size, ddim_steps, guidance_scale, cosine_scale, seed, options, negative_prompt])
201
+
202
+ submit_btn.click(fn=infer,
203
+ inputs=[prompt_in, output_size, ddim_steps, guidance_scale, cosine_scale, seed, options, negative_prompt],
204
+ outputs=[image_result],
205
+ api_name="freescalehf")
206
+
207
+ if __name__ == "__main__":
208
+ demo.queue(max_size=8).launch(show_api=True)
free_lunch_utils.py ADDED
@@ -0,0 +1,306 @@
1
+ import torch
2
+ import torch.fft as fft
3
+ from diffusers.models.unet_2d_condition import logger
4
+ from diffusers.utils import is_torch_version
5
+ from typing import Any, Dict, List, Optional, Tuple, Union
6
+
7
+ """ Borrowed from https://github.com/ChenyangSi/FreeU/blob/main/demo/free_lunch_utils.py
8
+ """
9
+
10
+ def isinstance_str(x: object, cls_name: str):
11
+ """
12
+ Checks whether x has any class *named* cls_name in its ancestry.
13
+ Doesn't require access to the class's implementation.
14
+
15
+ Useful for patching!
16
+ """
17
+
18
+ for _cls in x.__class__.__mro__:
19
+ if _cls.__name__ == cls_name:
20
+ return True
21
+
22
+ return False
23
+
24
+
25
+ def Fourier_filter(x, threshold, scale):
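+ # Scale the centered low-frequency band (half-width `threshold` after fftshift) by `scale`, leaving high frequencies untouched; FreeU applies this to skip-connection features.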
26
+ dtype = x.dtype
27
+ x = x.type(torch.float32)
28
+ # FFT
29
+ x_freq = fft.fftn(x, dim=(-2, -1))
30
+ x_freq = fft.fftshift(x_freq, dim=(-2, -1))
31
+
32
+ B, C, H, W = x_freq.shape
33
+ mask = torch.ones((B, C, H, W), device=x_freq.device)  # follow the input's device instead of assuming CUDA
34
+
35
+ crow, ccol = H // 2, W //2
36
+ mask[..., crow - threshold:crow + threshold, ccol - threshold:ccol + threshold] = scale
37
+ x_freq = x_freq * mask
38
+
39
+ # IFFT
40
+ x_freq = fft.ifftshift(x_freq, dim=(-2, -1))
41
+ x_filtered = fft.ifftn(x_freq, dim=(-2, -1)).real
42
+
43
+ x_filtered = x_filtered.type(dtype)
44
+ return x_filtered
45
+
46
+
47
+ def register_upblock2d(model):
48
+ def up_forward(self):
49
+ def forward(hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None):
50
+ for resnet in self.resnets:
51
+ # pop res hidden states
52
+ res_hidden_states = res_hidden_states_tuple[-1]
53
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
54
+ #print(f"in upblock2d, hidden states shape: {hidden_states.shape}")
55
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
56
+
57
+ if self.training and self.gradient_checkpointing:
58
+
59
+ def create_custom_forward(module):
60
+ def custom_forward(*inputs):
61
+ return module(*inputs)
62
+
63
+ return custom_forward
64
+
65
+ if is_torch_version(">=", "1.11.0"):
66
+ hidden_states = torch.utils.checkpoint.checkpoint(
67
+ create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
68
+ )
69
+ else:
70
+ hidden_states = torch.utils.checkpoint.checkpoint(
71
+ create_custom_forward(resnet), hidden_states, temb
72
+ )
73
+ else:
74
+ hidden_states = resnet(hidden_states, temb)
75
+
76
+ if self.upsamplers is not None:
77
+ for upsampler in self.upsamplers:
78
+ hidden_states = upsampler(hidden_states, upsample_size)
79
+
80
+ return hidden_states
81
+
82
+ return forward
83
+
84
+ for i, upsample_block in enumerate(model.unet.up_blocks):
85
+ if isinstance_str(upsample_block, "UpBlock2D"):
86
+ upsample_block.forward = up_forward(upsample_block)
87
+
88
+
89
+ def register_free_upblock2d(model, b1=1.2, b2=1.4, s1=0.9, s2=0.2):
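+ # FreeU: patch each UpBlock2D forward so backbone channels are amplified by b1/b2 and the corresponding skip features are low-pass scaled by s1/s2.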
90
+ def up_forward(self):
91
+ def forward(hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None):
92
+ for resnet in self.resnets:
93
+ # pop res hidden states
94
+ res_hidden_states = res_hidden_states_tuple[-1]
95
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
96
+ #print(f"in free upblock2d, hidden states shape: {hidden_states.shape}")
97
+
98
+ # --------------- FreeU code -----------------------
99
+ # Only operate on the first two stages
100
+ if hidden_states.shape[1] == 1280:
101
+ hidden_states[:,:640] = hidden_states[:,:640] * self.b1
102
+ res_hidden_states = Fourier_filter(res_hidden_states, threshold=1, scale=self.s1)
103
+ if hidden_states.shape[1] == 640:
104
+ hidden_states[:,:320] = hidden_states[:,:320] * self.b2
105
+ res_hidden_states = Fourier_filter(res_hidden_states, threshold=1, scale=self.s2)
106
+ # ---------------------------------------------------------
107
+
108
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
109
+
110
+ if self.training and self.gradient_checkpointing:
111
+
112
+ def create_custom_forward(module):
113
+ def custom_forward(*inputs):
114
+ return module(*inputs)
115
+
116
+ return custom_forward
117
+
118
+ if is_torch_version(">=", "1.11.0"):
119
+ hidden_states = torch.utils.checkpoint.checkpoint(
120
+ create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
121
+ )
122
+ else:
123
+ hidden_states = torch.utils.checkpoint.checkpoint(
124
+ create_custom_forward(resnet), hidden_states, temb
125
+ )
126
+ else:
127
+ hidden_states = resnet(hidden_states, temb)
128
+
129
+ if self.upsamplers is not None:
130
+ for upsampler in self.upsamplers:
131
+ hidden_states = upsampler(hidden_states, upsample_size)
132
+
133
+ return hidden_states
134
+
135
+ return forward
136
+
137
+ for i, upsample_block in enumerate(model.unet.up_blocks):
138
+ if isinstance_str(upsample_block, "UpBlock2D"):
139
+ upsample_block.forward = up_forward(upsample_block)
140
+ setattr(upsample_block, 'b1', b1)
141
+ setattr(upsample_block, 'b2', b2)
142
+ setattr(upsample_block, 's1', s1)
143
+ setattr(upsample_block, 's2', s2)
144
+
145
+
146
+ def register_crossattn_upblock2d(model):
147
+ def up_forward(self):
148
+ def forward(
149
+ hidden_states: torch.FloatTensor,
150
+ res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
151
+ temb: Optional[torch.FloatTensor] = None,
152
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
153
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
154
+ upsample_size: Optional[int] = None,
155
+ attention_mask: Optional[torch.FloatTensor] = None,
156
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
157
+ ):
158
+ for resnet, attn in zip(self.resnets, self.attentions):
159
+ # pop res hidden states
160
+ #print(f"in crossatten upblock2d, hidden states shape: {hidden_states.shape}")
161
+ res_hidden_states = res_hidden_states_tuple[-1]
162
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
163
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
164
+
165
+ if self.training and self.gradient_checkpointing:
166
+
167
+ def create_custom_forward(module, return_dict=None):
168
+ def custom_forward(*inputs):
169
+ if return_dict is not None:
170
+ return module(*inputs, return_dict=return_dict)
171
+ else:
172
+ return module(*inputs)
173
+
174
+ return custom_forward
175
+
176
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
177
+ hidden_states = torch.utils.checkpoint.checkpoint(
178
+ create_custom_forward(resnet),
179
+ hidden_states,
180
+ temb,
181
+ **ckpt_kwargs,
182
+ )
183
+ hidden_states = torch.utils.checkpoint.checkpoint(
184
+ create_custom_forward(attn, return_dict=False),
185
+ hidden_states,
186
+ encoder_hidden_states,
187
+ None, # timestep
188
+ None, # class_labels
189
+ cross_attention_kwargs,
190
+ attention_mask,
191
+ encoder_attention_mask,
192
+ **ckpt_kwargs,
193
+ )[0]
194
+ else:
195
+ hidden_states = resnet(hidden_states, temb)
196
+ hidden_states = attn(
197
+ hidden_states,
198
+ encoder_hidden_states=encoder_hidden_states,
199
+ cross_attention_kwargs=cross_attention_kwargs,
200
+ attention_mask=attention_mask,
201
+ encoder_attention_mask=encoder_attention_mask,
202
+ return_dict=False,
203
+ )[0]
204
+
205
+ if self.upsamplers is not None:
206
+ for upsampler in self.upsamplers:
207
+ hidden_states = upsampler(hidden_states, upsample_size)
208
+
209
+ return hidden_states
210
+
211
+ return forward
212
+
213
+ for i, upsample_block in enumerate(model.unet.up_blocks):
214
+ if isinstance_str(upsample_block, "CrossAttnUpBlock2D"):
215
+ upsample_block.forward = up_forward(upsample_block)
216
+
217
+
218
+ def register_free_crossattn_upblock2d(model, b1=1.2, b2=1.4, s1=0.9, s2=0.2):
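+ # Same FreeU patch as above, applied to CrossAttnUpBlock2D blocks (resnet + attention pairs).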
219
+ def up_forward(self):
220
+ def forward(
221
+ hidden_states: torch.FloatTensor,
222
+ res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
223
+ temb: Optional[torch.FloatTensor] = None,
224
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
225
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
226
+ upsample_size: Optional[int] = None,
227
+ attention_mask: Optional[torch.FloatTensor] = None,
228
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
229
+ ):
230
+ for resnet, attn in zip(self.resnets, self.attentions):
231
+ # pop res hidden states
232
+ #print(f"in free crossatten upblock2d, hidden states shape: {hidden_states.shape}")
233
+ res_hidden_states = res_hidden_states_tuple[-1]
234
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
235
+
236
+ # --------------- FreeU code -----------------------
237
+ # Only operate on the first two stages
238
+ if hidden_states.shape[1] == 1280:
239
+ hidden_states[:,:640] = hidden_states[:,:640] * self.b1
240
+ res_hidden_states = Fourier_filter(res_hidden_states, threshold=1, scale=self.s1)
241
+ if hidden_states.shape[1] == 640:
242
+ hidden_states[:,:320] = hidden_states[:,:320] * self.b2
243
+ res_hidden_states = Fourier_filter(res_hidden_states, threshold=1, scale=self.s2)
244
+ # ---------------------------------------------------------
245
+
246
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
247
+
248
+ if self.training and self.gradient_checkpointing:
249
+
250
+ def create_custom_forward(module, return_dict=None):
251
+ def custom_forward(*inputs):
252
+ if return_dict is not None:
253
+ return module(*inputs, return_dict=return_dict)
254
+ else:
255
+ return module(*inputs)
256
+
257
+ return custom_forward
258
+
259
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
260
+ hidden_states = torch.utils.checkpoint.checkpoint(
261
+ create_custom_forward(resnet),
262
+ hidden_states,
263
+ temb,
264
+ **ckpt_kwargs,
265
+ )
266
+ hidden_states = torch.utils.checkpoint.checkpoint(
267
+ create_custom_forward(attn, return_dict=False),
268
+ hidden_states,
269
+ encoder_hidden_states,
270
+ None, # timestep
271
+ None, # class_labels
272
+ cross_attention_kwargs,
273
+ attention_mask,
274
+ encoder_attention_mask,
275
+ **ckpt_kwargs,
276
+ )[0]
277
+ else:
278
+ hidden_states = resnet(hidden_states, temb)
279
+ # hidden_states = attn(
280
+ # hidden_states,
281
+ # encoder_hidden_states=encoder_hidden_states,
282
+ # cross_attention_kwargs=cross_attention_kwargs,
283
+ # encoder_attention_mask=encoder_attention_mask,
284
+ # return_dict=False,
285
+ # )[0]
286
+ hidden_states = attn(
287
+ hidden_states,
288
+ encoder_hidden_states=encoder_hidden_states,
289
+ cross_attention_kwargs=cross_attention_kwargs,
290
+ )[0]
291
+
292
+ if self.upsamplers is not None:
293
+ for upsampler in self.upsamplers:
294
+ hidden_states = upsampler(hidden_states, upsample_size)
295
+
296
+ return hidden_states
297
+
298
+ return forward
299
+
300
+ for i, upsample_block in enumerate(model.unet.up_blocks):
301
+ if isinstance_str(upsample_block, "CrossAttnUpBlock2D"):
302
+ upsample_block.forward = up_forward(upsample_block)
303
+ setattr(upsample_block, 'b1', b1)
304
+ setattr(upsample_block, 'b2', b2)
305
+ setattr(upsample_block, 's1', s1)
306
+ setattr(upsample_block, 's2', s2)
pipeline_freescale.py ADDED
@@ -0,0 +1,1204 @@
1
+ import inspect
2
+ import os
3
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
4
+
5
+ import torch
6
+ from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
7
+
8
+ from diffusers.image_processor import VaeImageProcessor
9
+ from diffusers.loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
10
+ from diffusers.models import AutoencoderKL, UNet2DConditionModel
11
+ from diffusers.models.attention_processor import (
12
+ AttnProcessor2_0,
13
+ LoRAAttnProcessor2_0,
14
+ LoRAXFormersAttnProcessor,
15
+ XFormersAttnProcessor,
16
+ )
17
+ from diffusers.schedulers import KarrasDiffusionSchedulers
18
+ from diffusers.utils import (
19
+ is_accelerate_available,
20
+ is_accelerate_version,
21
+ is_invisible_watermark_available,
22
+ logging,
23
+ randn_tensor,
24
+ replace_example_docstring,
25
+ )
26
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
27
+ from diffusers.pipelines.stable_diffusion_xl import StableDiffusionXLPipelineOutput
28
+
29
+ if is_invisible_watermark_available():
30
+ from .watermark import StableDiffusionXLWatermarker
31
+
32
+ from inspect import isfunction
33
+ from functools import partial
34
+ import numpy as np
+ import random  # used by get_views for the random window jitter
+ import torch.nn.functional as F  # used by gaussian_filter (F.conv2d)
+ from einops import rearrange  # used by gaussian_filter for 5-D latents
35
+
36
+ from diffusers.models.attention import BasicTransformerBlock
37
+ from scale_attention import ori_forward, scale_forward
38
+
39
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
40
+
41
+ EXAMPLE_DOC_STRING = """
42
+ Examples:
43
+ ```py
44
+ >>> import torch
45
+ >>> from diffusers import StableDiffusionXLPipeline
46
+
47
+ >>> pipe = StableDiffusionXLPipeline.from_pretrained(
48
+ ... "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
49
+ ... )
50
+ >>> pipe = pipe.to("cuda")
51
+
52
+ >>> prompt = "a photo of an astronaut riding a horse on mars"
53
+ >>> image = pipe(prompt).images[0]
54
+ ```
55
+ """
56
+
57
+ def default(val, d):
58
+ if exists(val):
59
+ return val
60
+ return d() if isfunction(d) else d
61
+
62
+ def exists(val):
63
+ return val is not None
64
+
65
+ def extract_into_tensor(a, t, x_shape):
66
+ b, *_ = t.shape
67
+ out = a.gather(-1, t)
68
+ return out.reshape(b, *((1,) * (len(x_shape) - 1)))
69
+
70
+ def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
71
+ if schedule == "linear":
72
+ betas = (
73
+ torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2
74
+ )
75
+ elif schedule == "cosine":
76
+ timesteps = (
77
+ torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s
78
+ )
79
+ alphas = timesteps / (1 + cosine_s) * np.pi / 2
80
+ alphas = torch.cos(alphas).pow(2)
81
+ alphas = alphas / alphas[0]
82
+ betas = 1 - alphas[1:] / alphas[:-1]
83
+ betas = np.clip(betas, a_min=0, a_max=0.999)
84
+ elif schedule == "sqrt_linear":
85
+ betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64)
86
+ elif schedule == "sqrt":
87
+ betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5
88
+ else:
89
+ raise ValueError(f"schedule '{schedule}' unknown.")
90
+ return betas.numpy()
91
+
92
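+ # Precompute the default Stable Diffusion DDPM schedule (linear betas, 1000 steps) so q_sample below can re-noise latents without going through the scheduler.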
+ to_torch = partial(torch.tensor, dtype=torch.float16)
93
+ betas = make_beta_schedule("linear", 1000, linear_start=0.00085, linear_end=0.012)
94
+ alphas = 1. - betas
95
+ alphas_cumprod = np.cumprod(alphas, axis=0)
96
+ sqrt_alphas_cumprod = to_torch(np.sqrt(alphas_cumprod))
97
+ sqrt_one_minus_alphas_cumprod = to_torch(np.sqrt(1. - alphas_cumprod))
98
+
99
+ def q_sample(x_start, t, init_noise_sigma = 1.0, noise=None, device=None):
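+ # Forward diffusion q(x_t | x_0): re-noise x_start to timestep t using the cumulative-alpha tables precomputed above.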
100
+ noise = default(noise, lambda: torch.randn_like(x_start)).to(device) * init_noise_sigma
101
+ return (extract_into_tensor(sqrt_alphas_cumprod.to(device), t, x_start.shape) * x_start +
102
+ extract_into_tensor(sqrt_one_minus_alphas_cumprod.to(device), t, x_start.shape) * noise)
103
+
104
+ def get_views(height, width, h_window_size=128, w_window_size=128, h_window_stride=64, w_window_stride=64, vae_scale_factor=8):
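+ # Tile the latent canvas into overlapping windows (sizes and strides are in latent units after VAE downscaling); random jitter shifts the window borders so tile seams do not sit at fixed positions.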
105
+ height //= vae_scale_factor
106
+ width //= vae_scale_factor
107
+ num_blocks_height = int((height - h_window_size) / h_window_stride - 1e-6) + 2 if height > h_window_size else 1
108
+ num_blocks_width = int((width - w_window_size) / w_window_stride - 1e-6) + 2 if width > w_window_size else 1
109
+ total_num_blocks = int(num_blocks_height * num_blocks_width)
110
+ views = []
111
+ for i in range(total_num_blocks):
112
+ h_start = int((i // num_blocks_width) * h_window_stride)
113
+ h_end = h_start + h_window_size
114
+ w_start = int((i % num_blocks_width) * w_window_stride)
115
+ w_end = w_start + w_window_size
116
+
117
+ if h_end > height:
118
+ h_start = int(h_start + height - h_end)
119
+ h_end = int(height)
120
+ if w_end > width:
121
+ w_start = int(w_start + width - w_end)
122
+ w_end = int(width)
123
+ if h_start < 0:
124
+ h_end = int(h_end - h_start)
125
+ h_start = 0
126
+ if w_start < 0:
127
+ w_end = int(w_end - w_start)
128
+ w_start = 0
129
+
130
+ random_jitter = True
131
+ if random_jitter:
132
+ h_jitter_range = (h_window_size - h_window_stride) // 4
133
+ w_jitter_range = (w_window_size - w_window_stride) // 4
134
+ h_jitter = 0
135
+ w_jitter = 0
136
+
137
+ if (w_start != 0) and (w_end != width):
138
+ w_jitter = random.randint(-w_jitter_range, w_jitter_range)
139
+ elif (w_start == 0) and (w_end != width):
140
+ w_jitter = random.randint(-w_jitter_range, 0)
141
+ elif (w_start != 0) and (w_end == width):
142
+ w_jitter = random.randint(0, w_jitter_range)
143
+ if (h_start != 0) and (h_end != height):
144
+ h_jitter = random.randint(-h_jitter_range, h_jitter_range)
145
+ elif (h_start == 0) and (h_end != height):
146
+ h_jitter = random.randint(-h_jitter_range, 0)
147
+ elif (h_start != 0) and (h_end == height):
148
+ h_jitter = random.randint(0, h_jitter_range)
149
+ h_start += (h_jitter + h_jitter_range)
150
+ h_end += (h_jitter + h_jitter_range)
151
+ w_start += (w_jitter + w_jitter_range)
152
+ w_end += (w_jitter + w_jitter_range)
153
+
154
+ views.append((h_start, h_end, w_start, w_end))
155
+ return views
156
+
157
+ def gaussian_kernel(kernel_size=3, sigma=1.0, channels=3):
158
+ x_coord = torch.arange(kernel_size)
159
+ gaussian_1d = torch.exp(-(x_coord - (kernel_size - 1) / 2) ** 2 / (2 * sigma ** 2))
160
+ gaussian_1d = gaussian_1d / gaussian_1d.sum()
161
+ gaussian_2d = gaussian_1d[:, None] * gaussian_1d[None, :]
162
+ kernel = gaussian_2d[None, None, :, :].repeat(channels, 1, 1, 1)
163
+
164
+ return kernel
165
+
166
+ def gaussian_filter(latents, kernel_size=3, sigma=1.0):
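+ # Depthwise Gaussian blur over the spatial dimensions (groups=channels conv2d); 5-D latents are flattened over the extra dimension first.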
167
+ channels = latents.shape[1]
168
+ kernel = gaussian_kernel(kernel_size, sigma, channels).to(latents.device, latents.dtype)
169
+ if len(latents.shape) == 5:
170
+ b = latents.shape[0]
171
+ latents = rearrange(latents, 'b c t i j -> (b t) c i j')
172
+ blurred_latents = F.conv2d(latents, kernel, padding=kernel_size//2, groups=channels)
173
+ blurred_latents = rearrange(blurred_latents, '(b t) c i j -> b c t i j', b=b)
174
+ else:
175
+ blurred_latents = F.conv2d(latents, kernel, padding=kernel_size//2, groups=channels)
176
+
177
+ return blurred_latents
178
+
179
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
180
+ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
181
+ """
182
+ Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
183
+ Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
184
+ """
185
+ std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
186
+ std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
187
+ # rescale the results from guidance (fixes overexposure)
188
+ noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
189
+ # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
190
+ noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
191
+ return noise_cfg
192
+
193
+
194
+ class StableDiffusionXLPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoaderMixin):
195
+ r"""
196
+ Pipeline for text-to-image generation using Stable Diffusion XL.
197
+
198
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
199
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
200
+
201
+ In addition the pipeline inherits the following loading methods:
202
+ - *LoRA*: [`StableDiffusionXLPipeline.load_lora_weights`]
203
+ - *Ckpt*: [`loaders.FromSingleFileMixin.from_single_file`]
204
+
205
+ as well as the following saving methods:
206
+ - *LoRA*: [`loaders.StableDiffusionXLPipeline.save_lora_weights`]
207
+
208
+ Args:
209
+ vae ([`AutoencoderKL`]):
210
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
211
+ text_encoder ([`CLIPTextModel`]):
212
+ Frozen text-encoder. Stable Diffusion XL uses the text portion of
213
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
214
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
215
+ text_encoder_2 ([`CLIPTextModelWithProjection`]):
216
+ Second frozen text-encoder. Stable Diffusion XL uses the text and pool portion of
217
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection),
218
+ specifically the
219
+ [laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k)
220
+ variant.
221
+ tokenizer (`CLIPTokenizer`):
222
+ Tokenizer of class
223
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
224
+ tokenizer_2 (`CLIPTokenizer`):
225
+ Second Tokenizer of class
226
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
227
+ unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
228
+ scheduler ([`SchedulerMixin`]):
229
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
230
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
231
+ """
232
+
233
+ def __init__(
234
+ self,
235
+ vae: AutoencoderKL,
236
+ text_encoder: CLIPTextModel,
237
+ text_encoder_2: CLIPTextModelWithProjection,
238
+ tokenizer: CLIPTokenizer,
239
+ tokenizer_2: CLIPTokenizer,
240
+ unet: UNet2DConditionModel,
241
+ scheduler: KarrasDiffusionSchedulers,
242
+ force_zeros_for_empty_prompt: bool = True,
243
+ add_watermarker: Optional[bool] = None,
244
+ ):
245
+ super().__init__()
246
+
247
+ self.register_modules(
248
+ vae=vae,
249
+ text_encoder=text_encoder,
250
+ text_encoder_2=text_encoder_2,
251
+ tokenizer=tokenizer,
252
+ tokenizer_2=tokenizer_2,
253
+ unet=unet,
254
+ scheduler=scheduler,
255
+ )
256
+ self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
257
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
258
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
259
+ self.default_sample_size = self.unet.config.sample_size
260
+
261
+ add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available()
262
+
263
+ if add_watermarker:
264
+ self.watermark = StableDiffusionXLWatermarker()
265
+ else:
266
+ self.watermark = None
267
+
268
+ self.vae.enable_tiling()
269
+
270
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing
271
+ def enable_vae_slicing(self):
272
+ r"""
273
+ Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
274
+ compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
275
+ """
276
+ self.vae.enable_slicing()
277
+
278
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing
279
+ def disable_vae_slicing(self):
280
+ r"""
281
+ Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
282
+ computing decoding in one step.
283
+ """
284
+ self.vae.disable_slicing()
285
+
286
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_tiling
287
+ def enable_vae_tiling(self):
288
+ r"""
289
+ Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
290
+ compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
291
+ processing larger images.
292
+ """
293
+ self.vae.enable_tiling()
294
+
295
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_tiling
296
+ def disable_vae_tiling(self):
297
+ r"""
298
+ Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
299
+ computing decoding in one step.
300
+ """
301
+ self.vae.disable_tiling()
302
+
303
+ def enable_model_cpu_offload(self, gpu_id=0):
304
+ r"""
305
+ Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
306
+ to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
307
+ method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with
308
+ `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
309
+ """
310
+ if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
311
+ from accelerate import cpu_offload_with_hook
312
+ else:
313
+ raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")
314
+
315
+ device = torch.device(f"cuda:{gpu_id}")
316
+
317
+ if self.device.type != "cpu":
318
+ self.to("cpu", silence_dtype_warnings=True)
319
+ torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
320
+
321
+ model_sequence = (
322
+ [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2]
323
+ )
324
+ model_sequence.extend([self.unet, self.vae])
325
+
326
+ hook = None
327
+ for cpu_offloaded_model in model_sequence:
328
+ _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)
329
+
330
+ # We'll offload the last model manually.
331
+ self.final_offload_hook = hook
332
+
333
+ def encode_prompt(
334
+ self,
335
+ prompt: str,
336
+ prompt_2: Optional[str] = None,
337
+ device: Optional[torch.device] = None,
338
+ num_images_per_prompt: int = 1,
339
+ do_classifier_free_guidance: bool = True,
340
+ negative_prompt: Optional[str] = None,
341
+ negative_prompt_2: Optional[str] = None,
342
+ prompt_embeds: Optional[torch.FloatTensor] = None,
343
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
344
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
345
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
346
+ lora_scale: Optional[float] = None,
347
+ ):
348
+ r"""
349
+ Encodes the prompt into text encoder hidden states.
350
+
351
+ Args:
352
+ prompt (`str` or `List[str]`, *optional*):
353
+ prompt to be encoded
354
+ prompt_2 (`str` or `List[str]`, *optional*):
355
+ The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
356
+ used in both text-encoders
357
+ device: (`torch.device`):
358
+ torch device
359
+ num_images_per_prompt (`int`):
360
+ number of images that should be generated per prompt
361
+ do_classifier_free_guidance (`bool`):
362
+ whether to use classifier free guidance or not
363
+ negative_prompt (`str` or `List[str]`, *optional*):
364
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
365
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
366
+ less than `1`).
367
+ negative_prompt_2 (`str` or `List[str]`, *optional*):
368
+ The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
369
+ `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders
370
+ prompt_embeds (`torch.FloatTensor`, *optional*):
371
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
372
+ provided, text embeddings will be generated from `prompt` input argument.
373
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
374
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
375
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
376
+ argument.
377
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
378
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
379
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
380
+ negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
381
+ Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
382
+ weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
383
+ input argument.
384
+ lora_scale (`float`, *optional*):
385
+ A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
386
+ """
387
+ device = device or self._execution_device
388
+
389
+ # set lora scale so that monkey patched LoRA
390
+ # function of text encoder can correctly access it
391
+ if lora_scale is not None and isinstance(self, LoraLoaderMixin):
392
+ self._lora_scale = lora_scale
393
+
394
+ if prompt is not None and isinstance(prompt, str):
395
+ batch_size = 1
396
+ elif prompt is not None and isinstance(prompt, list):
397
+ batch_size = len(prompt)
398
+ else:
399
+ batch_size = prompt_embeds.shape[0]
400
+
401
+ # Define tokenizers and text encoders
402
+ tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2]
403
+ text_encoders = (
404
+ [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2]
405
+ )
406
+
407
+ if prompt_embeds is None:
408
+ prompt_2 = prompt_2 or prompt
409
+ # textual inversion: process multi-vector tokens if necessary
410
+ prompt_embeds_list = []
411
+ prompts = [prompt, prompt_2]
412
+ for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders):
413
+ if isinstance(self, TextualInversionLoaderMixin):
414
+ prompt = self.maybe_convert_prompt(prompt, tokenizer)
415
+
416
+ text_inputs = tokenizer(
417
+ prompt,
418
+ padding="max_length",
419
+ max_length=tokenizer.model_max_length,
420
+ truncation=True,
421
+ return_tensors="pt",
422
+ )
423
+
424
+ text_input_ids = text_inputs.input_ids
425
+ untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
426
+
427
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
428
+ text_input_ids, untruncated_ids
429
+ ):
430
+ removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1])
431
+ logger.warning(
432
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
433
+ f" {tokenizer.model_max_length} tokens: {removed_text}"
434
+ )
435
+
436
+ prompt_embeds = text_encoder(
437
+ text_input_ids.to(device),
438
+ output_hidden_states=True,
439
+ )
440
+
441
+ # We are always only interested in the pooled output of the final text encoder
442
+ pooled_prompt_embeds = prompt_embeds[0]
443
+ prompt_embeds = prompt_embeds.hidden_states[-2]
444
+
445
+ prompt_embeds_list.append(prompt_embeds)
446
+
447
+ prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
448
+
449
+ # get unconditional embeddings for classifier free guidance
450
+ zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt
451
+ if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt:
452
+ negative_prompt_embeds = torch.zeros_like(prompt_embeds)
453
+ negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds)
454
+ elif do_classifier_free_guidance and negative_prompt_embeds is None:
455
+ negative_prompt = negative_prompt or ""
456
+ negative_prompt_2 = negative_prompt_2 or negative_prompt
457
+
458
+ uncond_tokens: List[str]
459
+ if prompt is not None and type(prompt) is not type(negative_prompt):
460
+ raise TypeError(
461
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
462
+ f" {type(prompt)}."
463
+ )
464
+ elif isinstance(negative_prompt, str):
465
+ uncond_tokens = [negative_prompt, negative_prompt_2]
466
+ elif batch_size != len(negative_prompt):
467
+ raise ValueError(
468
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
469
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
470
+ " the batch size of `prompt`."
471
+ )
472
+ else:
473
+ uncond_tokens = [negative_prompt, negative_prompt_2]
474
+
475
+ negative_prompt_embeds_list = []
476
+ for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders):
477
+ if isinstance(self, TextualInversionLoaderMixin):
478
+ negative_prompt = self.maybe_convert_prompt(negative_prompt, tokenizer)
479
+
480
+ max_length = prompt_embeds.shape[1]
481
+ uncond_input = tokenizer(
482
+ negative_prompt,
483
+ padding="max_length",
484
+ max_length=max_length,
485
+ truncation=True,
486
+ return_tensors="pt",
487
+ )
488
+
489
+ negative_prompt_embeds = text_encoder(
490
+ uncond_input.input_ids.to(device),
491
+ output_hidden_states=True,
492
+ )
493
+ # We are always only interested in the pooled output of the final text encoder
494
+ negative_pooled_prompt_embeds = negative_prompt_embeds[0]
495
+ negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
496
+
497
+ negative_prompt_embeds_list.append(negative_prompt_embeds)
498
+
499
+ negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)
500
+
501
+ prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
502
+ bs_embed, seq_len, _ = prompt_embeds.shape
503
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
504
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
505
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
506
+
507
+ if do_classifier_free_guidance:
508
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
509
+ seq_len = negative_prompt_embeds.shape[1]
510
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
511
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
512
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
513
+
514
+ pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
515
+ bs_embed * num_images_per_prompt, -1
516
+ )
517
+ if do_classifier_free_guidance:
518
+ negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
519
+ bs_embed * num_images_per_prompt, -1
520
+ )
521
+
522
+ return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
523
+
524
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
525
+ def prepare_extra_step_kwargs(self, generator, eta):
526
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
527
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
528
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
529
+ # and should be between [0, 1]
530
+
531
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
532
+ extra_step_kwargs = {}
533
+ if accepts_eta:
534
+ extra_step_kwargs["eta"] = eta
535
+
536
+ # check if the scheduler accepts generator
537
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
538
+ if accepts_generator:
539
+ extra_step_kwargs["generator"] = generator
540
+ return extra_step_kwargs
541
+
542
+ def check_inputs(
543
+ self,
544
+ prompt,
545
+ prompt_2,
546
+ height,
547
+ width,
548
+ callback_steps,
549
+ negative_prompt=None,
550
+ negative_prompt_2=None,
551
+ prompt_embeds=None,
552
+ negative_prompt_embeds=None,
553
+ pooled_prompt_embeds=None,
554
+ negative_pooled_prompt_embeds=None,
555
+ ):
556
+ if height % 8 != 0 or width % 8 != 0:
557
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
558
+
559
+ if (callback_steps is None) or (
560
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
561
+ ):
562
+ raise ValueError(
563
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
564
+ f" {type(callback_steps)}."
565
+ )
566
+
567
+ if prompt is not None and prompt_embeds is not None:
568
+ raise ValueError(
569
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
570
+ " only forward one of the two."
571
+ )
572
+ elif prompt_2 is not None and prompt_embeds is not None:
573
+ raise ValueError(
574
+ f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
575
+ " only forward one of the two."
576
+ )
577
+ elif prompt is None and prompt_embeds is None:
578
+ raise ValueError(
579
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
580
+ )
581
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
582
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
583
+ elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
584
+ raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
585
+
586
+ if negative_prompt is not None and negative_prompt_embeds is not None:
587
+ raise ValueError(
588
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
589
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
590
+ )
591
+ elif negative_prompt_2 is not None and negative_prompt_embeds is not None:
592
+ raise ValueError(
593
+ f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:"
594
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
595
+ )
596
+
597
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
598
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
599
+ raise ValueError(
600
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
601
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
602
+ f" {negative_prompt_embeds.shape}."
603
+ )
604
+
605
+ if prompt_embeds is not None and pooled_prompt_embeds is None:
606
+ raise ValueError(
607
+ "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
608
+ )
609
+
610
+ if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None:
611
+ raise ValueError(
612
+ "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
613
+ )
614
+
615
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
616
+ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
617
+ shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
618
+ if isinstance(generator, list) and len(generator) != batch_size:
619
+ raise ValueError(
620
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
621
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
622
+ )
623
+
624
+ if latents is None:
625
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
626
+ else:
627
+ latents = latents.to(device)
628
+
629
+ # scale the initial noise by the standard deviation required by the scheduler
630
+ latents = latents * self.scheduler.init_noise_sigma
631
+ return latents
632
+
633
+ def _get_add_time_ids(self, original_size, crops_coords_top_left, target_size, dtype):
634
+ add_time_ids = list(original_size + crops_coords_top_left + target_size)
635
+
636
+ passed_add_embed_dim = (
637
+ self.unet.config.addition_time_embed_dim * len(add_time_ids) + self.text_encoder_2.config.projection_dim
638
+ )
639
+ expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features
640
+
641
+ if expected_add_embed_dim != passed_add_embed_dim:
642
+ raise ValueError(
643
+ f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`."
644
+ )
645
+
646
+ add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
647
+ return add_time_ids
648
+
649
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae
650
+ def upcast_vae(self):
651
+ dtype = self.vae.dtype
652
+ self.vae.to(dtype=torch.float32)
653
+ use_torch_2_0_or_xformers = isinstance(
654
+ self.vae.decoder.mid_block.attentions[0].processor,
655
+ (
656
+ AttnProcessor2_0,
657
+ XFormersAttnProcessor,
658
+ LoRAXFormersAttnProcessor,
659
+ LoRAAttnProcessor2_0,
660
+ ),
661
+ )
662
+ # if xformers or torch_2_0 is used attention block does not need
663
+ # to be in float32 which can save lots of memory
664
+ if use_torch_2_0_or_xformers:
665
+ self.vae.post_quant_conv.to(dtype)
666
+ self.vae.decoder.conv_in.to(dtype)
667
+ self.vae.decoder.mid_block.to(dtype)
668
+
669
+ @torch.no_grad()
670
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
671
+ def __call__(
672
+ self,
673
+ prompt: Union[str, List[str]] = None,
674
+ prompt_2: Optional[Union[str, List[str]]] = None,
675
+ height: Optional[int] = None,
676
+ width: Optional[int] = None,
677
+ num_inference_steps: int = 50,
678
+ denoising_end: Optional[float] = None,
679
+ guidance_scale: float = 5.0,
680
+ negative_prompt: Optional[Union[str, List[str]]] = None,
681
+ negative_prompt_2: Optional[Union[str, List[str]]] = None,
682
+ num_images_per_prompt: Optional[int] = 1,
683
+ eta: float = 0.0,
684
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
685
+ latents: Optional[torch.FloatTensor] = None,
686
+ prompt_embeds: Optional[torch.FloatTensor] = None,
687
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
688
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
689
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
690
+ output_type: Optional[str] = "pil",
691
+ return_dict: bool = True,
692
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
693
+ callback_steps: int = 1,
694
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
695
+ guidance_rescale: float = 0.0,
696
+ original_size: Optional[Tuple[int, int]] = None,
697
+ crops_coords_top_left: Tuple[int, int] = (0, 0),
698
+ target_size: Optional[Tuple[int, int]] = None,
699
+ resolutions_list: Optional[Union[int, List[int]]] = None,
700
+ restart_steps: Optional[Union[int, List[int]]] = None,
701
+ cosine_scale: float = 2.0,
702
+ dilate_tau: int = 35,
703
+ fast_mode: bool = False,
704
+ ):
705
+ r"""
706
+ Function invoked when calling the pipeline for generation.
707
+
708
+ Args:
709
+ prompt (`str` or `List[str]`, *optional*):
710
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
711
+ instead.
712
+ prompt_2 (`str` or `List[str]`, *optional*):
713
+ The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
714
+ used in both text-encoders.
715
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
716
+ The height in pixels of the generated image.
717
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
718
+ The width in pixels of the generated image.
719
+ num_inference_steps (`int`, *optional*, defaults to 50):
720
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
721
+ expense of slower inference.
722
+ denoising_end (`float`, *optional*):
723
+ When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be
724
+ completed before it is intentionally prematurely terminated. As a result, the returned sample will
725
+ still retain a substantial amount of noise as determined by the discrete timesteps selected by the
726
+ scheduler. The denoising_end parameter should ideally be utilized when this pipeline forms a part of a
727
+ "Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image
728
+ Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output)
729
+ guidance_scale (`float`, *optional*, defaults to 5.0):
730
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
731
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
732
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
733
+ 1`. A higher guidance scale encourages generating images that are closely linked to the text `prompt`,
734
+ usually at the expense of lower image quality.
735
+ negative_prompt (`str` or `List[str]`, *optional*):
736
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
737
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
738
+ less than `1`).
739
+ negative_prompt_2 (`str` or `List[str]`, *optional*):
740
+ The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
741
+ `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders.
742
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
743
+ The number of images to generate per prompt.
744
+ eta (`float`, *optional*, defaults to 0.0):
745
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
746
+ [`schedulers.DDIMScheduler`], will be ignored for others.
747
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
748
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
749
+ to make generation deterministic.
750
+ latents (`torch.FloatTensor`, *optional*):
751
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
752
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
753
+ tensor will be generated by sampling using the supplied random `generator`.
754
+ prompt_embeds (`torch.FloatTensor`, *optional*):
755
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
756
+ provided, text embeddings will be generated from `prompt` input argument.
757
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
758
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
759
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
760
+ argument.
761
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
762
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
763
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
764
+ negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
765
+ Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
766
+ weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
767
+ input argument.
768
+ output_type (`str`, *optional*, defaults to `"pil"`):
769
+ The output format of the generated image. Choose between
770
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
771
+ return_dict (`bool`, *optional*, defaults to `True`):
772
+ Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
773
+ of a plain tuple.
774
+ callback (`Callable`, *optional*):
775
+ A function that will be called every `callback_steps` steps during inference. The function will be
776
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
777
+ callback_steps (`int`, *optional*, defaults to 1):
778
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
779
+ called at every step.
780
+ cross_attention_kwargs (`dict`, *optional*):
781
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
782
+ `self.processor` in
783
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
784
+ guidance_rescale (`float`, *optional*, defaults to 0.0):
785
+ Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are
786
+ Flawed](https://arxiv.org/pdf/2305.08891.pdf). `guidance_rescale` is defined as `φ` in equation 16 of
787
+ [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf).
788
+ Guidance rescale factor should fix overexposure when using zero terminal SNR.
789
+ original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
790
+ If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled.
791
+ `original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as
792
+ explained in section 2.2 of
793
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
794
+ crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
795
+ `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position
796
+ `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting
797
+ `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of
798
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
799
+ target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
800
+ For most cases, `target_size` should be set to the desired height and width of the generated image. If
801
+ not specified it will default to `(height, width)`. Part of SDXL's micro-conditioning as explained in
802
+ section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
803
+
804
+ Examples:
805
+
806
+ Returns:
807
+ [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] or `tuple`:
808
+ [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a
809
+ `tuple`. When returning a tuple, the first element is a list with the generated images.
810
+ """
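+ # Note on the FreeScale-specific arguments above (summarized from how they are used below):
+ # - resolutions_list: list of (height, width) pairs; the first entry is the base generation
+ #   resolution, the remaining entries are the upscale targets handled by the restart stages.
+ # - restart_steps: per upscale stage, the timestep index at which denoising restarts (earlier
+ #   steps are skipped); defaults to 15 for every stage.
+ # - cosine_scale: exponent on a cosine decay that controls how quickly the re-noised, upsampled
+ #   latents stop being blended back into the sample during a restart stage.
+ # - dilate_tau: number of denoising steps during which selected UNet convolutions are dilated.
+ # - fast_mode: swap the full global self-attention in the scaled attention path for the cheaper
+ #   strided (sub-grid) variant.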
811
+
812
+
813
+ # 0. Default height and width to unet
814
+ if resolutions_list:
815
+ height, width = resolutions_list[0]
816
+ target_sizes = resolutions_list[1:]
817
+ if not restart_steps:
818
+ restart_steps = [15] * len(target_sizes)
819
+ else:
820
+ height = height or self.default_sample_size * self.vae_scale_factor
821
+ width = width or self.default_sample_size * self.vae_scale_factor
+ target_sizes = []  # no upscale stages when no resolutions_list is given; avoids a NameError below
822
+
823
+ original_size = original_size or (height, width)
824
+ target_size = target_size or (height, width)
825
+
826
+ # 1. Check inputs. Raise error if not correct
827
+ self.check_inputs(
828
+ prompt,
829
+ prompt_2,
830
+ height,
831
+ width,
832
+ callback_steps,
833
+ negative_prompt,
834
+ negative_prompt_2,
835
+ prompt_embeds,
836
+ negative_prompt_embeds,
837
+ pooled_prompt_embeds,
838
+ negative_pooled_prompt_embeds,
839
+ )
840
+
841
+ # 2. Define call parameters
842
+ if prompt is not None and isinstance(prompt, str):
843
+ batch_size = 1
844
+ elif prompt is not None and isinstance(prompt, list):
845
+ batch_size = len(prompt)
846
+ else:
847
+ batch_size = prompt_embeds.shape[0]
848
+
849
+ device = self._execution_device
850
+
851
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
852
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
853
+ # corresponds to doing no classifier free guidance.
854
+ do_classifier_free_guidance = guidance_scale > 1.0
855
+
856
+ # 3. Encode input prompt
857
+ text_encoder_lora_scale = (
858
+ cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
859
+ )
860
+ (
861
+ prompt_embeds,
862
+ negative_prompt_embeds,
863
+ pooled_prompt_embeds,
864
+ negative_pooled_prompt_embeds,
865
+ ) = self.encode_prompt(
866
+ prompt=prompt,
867
+ prompt_2=prompt_2,
868
+ device=device,
869
+ num_images_per_prompt=num_images_per_prompt,
870
+ do_classifier_free_guidance=do_classifier_free_guidance,
871
+ negative_prompt=negative_prompt,
872
+ negative_prompt_2=negative_prompt_2,
873
+ prompt_embeds=prompt_embeds,
874
+ negative_prompt_embeds=negative_prompt_embeds,
875
+ pooled_prompt_embeds=pooled_prompt_embeds,
876
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
877
+ lora_scale=text_encoder_lora_scale,
878
+ )
879
+
880
+ # 4. Prepare timesteps
881
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
882
+
883
+ timesteps = self.scheduler.timesteps
884
+
885
+ # 5. Prepare latent variables
886
+ num_channels_latents = self.unet.config.in_channels
887
+ latents = self.prepare_latents(
888
+ batch_size * num_images_per_prompt,
889
+ num_channels_latents,
890
+ height,
891
+ width,
892
+ prompt_embeds.dtype,
893
+ device,
894
+ generator,
895
+ latents,
896
+ )
897
+
898
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
899
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
900
+
901
+ # 7. Prepare added time ids & embeddings
902
+ add_text_embeds = pooled_prompt_embeds
903
+ add_time_ids = self._get_add_time_ids(
904
+ original_size, crops_coords_top_left, target_size, dtype=prompt_embeds.dtype
905
+ )
906
+
907
+ if do_classifier_free_guidance:
908
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
909
+ add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
910
+ add_time_ids = torch.cat([add_time_ids, add_time_ids], dim=0)
911
+
912
+ prompt_embeds = prompt_embeds.to(device)
913
+ add_text_embeds = add_text_embeds.to(device)
914
+ add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)
915
+
916
+ # 8. Denoising loop
917
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
918
+
919
+ # 9.1 Apply denoising_end
920
+ if denoising_end is not None and isinstance(denoising_end, float) and 0 < denoising_end < 1:
921
+ discrete_timestep_cutoff = int(
922
+ round(
923
+ self.scheduler.config.num_train_timesteps
924
+ - (denoising_end * self.scheduler.config.num_train_timesteps)
925
+ )
926
+ )
927
+ num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)))
928
+ timesteps = timesteps[:num_inference_steps]
929
+
930
+ for block in self.unet.down_blocks + [self.unet.mid_block] + self.unet.up_blocks:
931
+ for module in block.modules():
932
+ if isinstance(module, BasicTransformerBlock):
933
+ module.forward = ori_forward.__get__(module, BasicTransformerBlock)
934
+
935
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
936
+ for i, t in enumerate(timesteps):
937
+ # expand the latents if we are doing classifier free guidance
938
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
939
+
940
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
941
+
942
+ # predict the noise residual
943
+ added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
944
+ noise_pred = self.unet(
945
+ latent_model_input,
946
+ t,
947
+ encoder_hidden_states=prompt_embeds,
948
+ cross_attention_kwargs=cross_attention_kwargs,
949
+ added_cond_kwargs=added_cond_kwargs,
950
+ return_dict=False,
951
+ )[0]
952
+
953
+ # perform guidance
954
+ if do_classifier_free_guidance:
955
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
956
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
957
+
958
+ if do_classifier_free_guidance and guidance_rescale > 0.0:
959
+ # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
960
+ noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
961
+
962
+ # compute the previous noisy sample x_t -> x_t-1
963
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
964
+
965
+ # call the callback, if provided
966
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
967
+ progress_bar.update()
968
+ if callback is not None and i % callback_steps == 0:
969
+ callback(i, t, latents)
970
+
971
+ for restart_index, target_size in enumerate(target_sizes):
972
+ restart_step = restart_steps[restart_index]
973
+ target_size_ = [target_size[0]//8, target_size[1]//8]
974
+
975
+ for block in self.unet.down_blocks + [self.unet.mid_block] + self.unet.up_blocks:
976
+ for module in block.modules():
977
+ if isinstance(module, BasicTransformerBlock):
978
+ module.forward = scale_forward.__get__(module, BasicTransformerBlock)
979
+ module.current_hw = target_size
980
+ module.fast_mode = fast_mode
981
+
982
+ needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
983
+ if needs_upcasting:
984
+ self.upcast_vae()
985
+ latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
986
+
987
+ latents = latents / self.vae.config.scaling_factor
988
+ image = self.vae.decode(latents, return_dict=False)[0]
989
+ image = torch.nn.functional.interpolate(
990
+ image,
991
+ size=target_size,
992
+ mode='bicubic',
993
+ )
994
+ latents = self.vae.encode(image).latent_dist.sample().half()
995
+ latents = latents * self.vae.config.scaling_factor
996
+
997
+ noise_latents = []
998
+ noise = torch.randn_like(latents)
999
+ for timestep in self.scheduler.timesteps:
1000
+ noise_latent = self.scheduler.add_noise(latents, noise, timestep.unsqueeze(0))
1001
+ noise_latents.append(noise_latent)
1002
+ latents = noise_latents[restart_step]
1003
+
1004
+ self.scheduler._step_index = 0
1005
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
1006
+ for i, t in enumerate(timesteps):
1007
+
1008
+ if i < restart_step:
1009
+ self.scheduler._step_index += 1
1010
+ progress_bar.update()
1011
+ continue
1012
+
1013
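+ # Blend the evolving latents with the pre-computed re-noised latents; the weight c1 follows a
+ # cosine schedule raised to `cosine_scale`, so the injected reference fades as denoising proceeds.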
+ cosine_factor = 0.5 * (1 + torch.cos(torch.pi * (self.scheduler.config.num_train_timesteps - t) / self.scheduler.config.num_train_timesteps)).cpu()
1014
+ c1 = cosine_factor ** cosine_scale
1015
+ latents = latents * (1 - c1) + noise_latents[i] * c1
1016
+
1017
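+ # For the first `dilate_tau` steps, dilate the convolutions listed below by the upscale factor
+ # (dilate_coef), so their receptive field on the enlarged latent roughly matches training scale.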
+ dilate_coef = target_size[1] // 1024
1018
+
1019
+ dilate_layers = [
1020
+ # "down_blocks.1.resnets.0.conv1",
1021
+ # "down_blocks.1.resnets.0.conv2",
1022
+ # "down_blocks.1.resnets.1.conv1",
1023
+ # "down_blocks.1.resnets.1.conv2",
1024
+ "down_blocks.1.downsamplers.0.conv",
1025
+ "down_blocks.2.resnets.0.conv1",
1026
+ "down_blocks.2.resnets.0.conv2",
1027
+ "down_blocks.2.resnets.1.conv1",
1028
+ "down_blocks.2.resnets.1.conv2",
1029
+ # "up_blocks.0.resnets.0.conv1",
1030
+ # "up_blocks.0.resnets.0.conv2",
1031
+ # "up_blocks.0.resnets.1.conv1",
1032
+ # "up_blocks.0.resnets.1.conv2",
1033
+ # "up_blocks.0.resnets.2.conv1",
1034
+ # "up_blocks.0.resnets.2.conv2",
1035
+ # "up_blocks.0.upsamplers.0.conv",
1036
+ # "up_blocks.1.resnets.0.conv1",
1037
+ # "up_blocks.1.resnets.0.conv2",
1038
+ # "up_blocks.1.resnets.1.conv1",
1039
+ # "up_blocks.1.resnets.1.conv2",
1040
+ # "up_blocks.1.resnets.2.conv1",
1041
+ # "up_blocks.1.resnets.2.conv2",
1042
+ # "up_blocks.1.upsamplers.0.conv",
1043
+ # "up_blocks.2.resnets.0.conv1",
1044
+ # "up_blocks.2.resnets.0.conv2",
1045
+ # "up_blocks.2.resnets.1.conv1",
1046
+ # "up_blocks.2.resnets.1.conv2",
1047
+ # "up_blocks.2.resnets.2.conv1",
1048
+ # "up_blocks.2.resnets.2.conv2",
1049
+ "mid_block.resnets.0.conv1",
1050
+ "mid_block.resnets.0.conv2",
1051
+ "mid_block.resnets.1.conv1",
1052
+ "mid_block.resnets.1.conv2"
1053
+ ]
1054
+
1055
+ for name, module in self.unet.named_modules():
1056
+ if name in dilate_layers:
1057
+ if i < dilate_tau:
1058
+ module.dilation = (dilate_coef, dilate_coef)
1059
+ module.padding = (dilate_coef, dilate_coef)
1060
+ else:
1061
+ module.dilation = (1, 1)
1062
+ module.padding = (1, 1)
1063
+
1064
+ # expand the latents if we are doing classifier free guidance
1065
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
1066
+
1067
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
1068
+
1069
+
1070
+ # predict the noise residual
1071
+ added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
1072
+ noise_pred = self.unet(
1073
+ latent_model_input,
1074
+ t,
1075
+ encoder_hidden_states=prompt_embeds,
1076
+ cross_attention_kwargs=cross_attention_kwargs,
1077
+ added_cond_kwargs=added_cond_kwargs,
1078
+ return_dict=False,
1079
+ )[0]
1080
+
1081
+ # perform guidance
1082
+ if do_classifier_free_guidance:
1083
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
1084
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
1085
+
1086
+ if do_classifier_free_guidance and guidance_rescale > 0.0:
1087
+ # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
1088
+ noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
1089
+
1090
+ # compute the previous noisy sample x_t -> x_t-1
1091
+ latents_dtype = latents.dtype
1092
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
1093
+ if latents.dtype != latents_dtype:
1094
+ if torch.backends.mps.is_available():
1095
+ # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
1096
+ latents = latents.to(latents_dtype)
1097
+
1098
+ # call the callback, if provided
1099
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
1100
+ progress_bar.update()
1101
+ if callback is not None and i % callback_steps == 0:
1102
+ callback(i, t, latents)
1103
+
1104
+ for name, module in self.unet.named_modules():
1105
+ # if ('.conv' in name) and ('.conv_' not in name):
1106
+ if name in dilate_layers:
1107
+ module.dilation = (1, 1)
1108
+ module.padding = (1, 1)
1109
+
1110
+ # make sure the VAE is in float32 mode, as it overflows in float16
1111
+ if self.vae.dtype == torch.float16 and self.vae.config.force_upcast:
1112
+ self.upcast_vae()
1113
+ latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
1114
+
1115
+ if not output_type == "latent":
1116
+ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
1117
+ else:
1118
+ image = latents
1119
+ return StableDiffusionXLPipelineOutput(images=image)
1120
+
1121
+ # apply watermark if available
1122
+ if self.watermark is not None:
1123
+ image = self.watermark.apply_watermark(image)
1124
+
1125
+ image = self.image_processor.postprocess(image, output_type=output_type)
1126
+
1127
+ # Offload last model to CPU
1128
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
1129
+ self.final_offload_hook.offload()
1130
+
1131
+ if not return_dict:
1132
+ return (image,)
1133
+
1134
+ return StableDiffusionXLPipelineOutput(images=image)
1135
+
1136
+ # Override to properly handle the loading and unloading of the additional text encoder.
1137
+ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs):
1138
+ # We could have accessed the unet config from `lora_state_dict()` too. We pass
1139
+ # it here explicitly to be able to tell that it's coming from an SDXL
1140
+ # pipeline.
1141
+ state_dict, network_alphas = self.lora_state_dict(
1142
+ pretrained_model_name_or_path_or_dict,
1143
+ unet_config=self.unet.config,
1144
+ **kwargs,
1145
+ )
1146
+ self.load_lora_into_unet(state_dict, network_alphas=network_alphas, unet=self.unet)
1147
+
1148
+ text_encoder_state_dict = {k: v for k, v in state_dict.items() if "text_encoder." in k}
1149
+ if len(text_encoder_state_dict) > 0:
1150
+ self.load_lora_into_text_encoder(
1151
+ text_encoder_state_dict,
1152
+ network_alphas=network_alphas,
1153
+ text_encoder=self.text_encoder,
1154
+ prefix="text_encoder",
1155
+ lora_scale=self.lora_scale,
1156
+ )
1157
+
1158
+ text_encoder_2_state_dict = {k: v for k, v in state_dict.items() if "text_encoder_2." in k}
1159
+ if len(text_encoder_2_state_dict) > 0:
1160
+ self.load_lora_into_text_encoder(
1161
+ text_encoder_2_state_dict,
1162
+ network_alphas=network_alphas,
1163
+ text_encoder=self.text_encoder_2,
1164
+ prefix="text_encoder_2",
1165
+ lora_scale=self.lora_scale,
1166
+ )
1167
+
1168
+ @classmethod
1169
+ def save_lora_weights(
1170
+ self,
1171
+ save_directory: Union[str, os.PathLike],
1172
+ unet_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
1173
+ text_encoder_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
1174
+ text_encoder_2_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
1175
+ is_main_process: bool = True,
1176
+ weight_name: str = None,
1177
+ save_function: Callable = None,
1178
+ safe_serialization: bool = True,
1179
+ ):
1180
+ state_dict = {}
1181
+
1182
+ def pack_weights(layers, prefix):
1183
+ layers_weights = layers.state_dict() if isinstance(layers, torch.nn.Module) else layers
1184
+ layers_state_dict = {f"{prefix}.{module_name}": param for module_name, param in layers_weights.items()}
1185
+ return layers_state_dict
1186
+
1187
+ state_dict.update(pack_weights(unet_lora_layers, "unet"))
1188
+
1189
+ if text_encoder_lora_layers and text_encoder_2_lora_layers:
1190
+ state_dict.update(pack_weights(text_encoder_lora_layers, "text_encoder"))
1191
+ state_dict.update(pack_weights(text_encoder_2_lora_layers, "text_encoder_2"))
1192
+
1193
+ self.write_lora_layers(
1194
+ state_dict=state_dict,
1195
+ save_directory=save_directory,
1196
+ is_main_process=is_main_process,
1197
+ weight_name=weight_name,
1198
+ save_function=save_function,
1199
+ safe_serialization=safe_serialization,
1200
+ )
1201
+
1202
+ def _remove_text_encoder_monkey_patch(self):
1203
+ self._remove_text_encoder_monkey_patch_classmethod(self.text_encoder)
1204
+ self._remove_text_encoder_monkey_patch_classmethod(self.text_encoder_2)
requirements.txt ADDED
@@ -0,0 +1,12 @@
1
+ diffusers==0.20.2
2
+ accelerate==0.16.0
3
+ clean-fid==0.1.35
4
+ torch~=2.0.0
5
+ scipy~=1.9.1
6
+ omegaconf~=2.1.1
8
+ transformers~=4.25.1
9
+ tqdm
10
+ xformers~=0.0.18
11
+ einops
12
+ gradio
scale_attention.py ADDED
@@ -0,0 +1,372 @@
1
+ from typing import Any, Dict, Optional
2
+
3
+ import torch
4
+ import torch.nn.functional as F
5
+ from torch import nn
6
+ from einops import rearrange, repeat
7
+ import random
8
+
9
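+ # Depthwise 2D Gaussian kernel, used by gaussian_filter below to low-pass attention outputs so
+ # that local and global attention results can be combined by frequency band.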
+ def gaussian_kernel(kernel_size=3, sigma=1.0, channels=3):
10
+ x_coord = torch.arange(kernel_size)
11
+ gaussian_1d = torch.exp(-(x_coord - (kernel_size - 1) / 2) ** 2 / (2 * sigma ** 2))
12
+ gaussian_1d = gaussian_1d / gaussian_1d.sum()
13
+ gaussian_2d = gaussian_1d[:, None] * gaussian_1d[None, :]
14
+ kernel = gaussian_2d[None, None, :, :].repeat(channels, 1, 1, 1)
15
+
16
+ return kernel
17
+
18
+ def gaussian_filter(latents, kernel_size=3, sigma=1.0):
19
+ channels = latents.shape[1]
20
+ kernel = gaussian_kernel(kernel_size, sigma, channels).to(latents.device, latents.dtype)
21
+ blurred_latents = F.conv2d(latents, kernel, padding=kernel_size//2, groups=channels)
22
+
23
+ return blurred_latents
24
+
25
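+ # Tile a (height, width) latent grid into overlapping windows of roughly sub_h x sub_w (given in
+ # pixel units and divided by scale_factor), with random jitter applied to interior window borders
+ # to avoid fixed seams; returns a list of (h_start, h_end, w_start, w_end) tuples.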
+ def get_views(height, width, h_window_size=128, w_window_size=128, scale_factor=8):
26
+ height = int(height)
27
+ width = int(width)
28
+ h_window_stride = h_window_size // 2
29
+ w_window_stride = w_window_size // 2
30
+ h_window_size = int(h_window_size / scale_factor)
31
+ w_window_size = int(w_window_size / scale_factor)
32
+ h_window_stride = int(h_window_stride / scale_factor)
33
+ w_window_stride = int(w_window_stride / scale_factor)
34
+ num_blocks_height = int((height - h_window_size) / h_window_stride - 1e-6) + 2 if height > h_window_size else 1
35
+ num_blocks_width = int((width - w_window_size) / w_window_stride - 1e-6) + 2 if width > w_window_size else 1
36
+ total_num_blocks = int(num_blocks_height * num_blocks_width)
37
+ views = []
38
+ for i in range(total_num_blocks):
39
+ h_start = int((i // num_blocks_width) * h_window_stride)
40
+ h_end = h_start + h_window_size
41
+ w_start = int((i % num_blocks_width) * w_window_stride)
42
+ w_end = w_start + w_window_size
43
+
44
+ if h_end > height:
45
+ h_start = int(h_start + height - h_end)
46
+ h_end = int(height)
47
+ if w_end > width:
48
+ w_start = int(w_start + width - w_end)
49
+ w_end = int(width)
50
+ if h_start < 0:
51
+ h_end = int(h_end - h_start)
52
+ h_start = 0
53
+ if w_start < 0:
54
+ w_end = int(w_end - w_start)
55
+ w_start = 0
56
+
57
+ random_jitter = True
58
+ if random_jitter:
59
+ h_jitter_range = h_window_size // 8
60
+ w_jitter_range = w_window_size // 8
61
+ h_jitter = 0
62
+ w_jitter = 0
63
+
64
+ if (w_start != 0) and (w_end != width):
65
+ w_jitter = random.randint(-w_jitter_range, w_jitter_range)
66
+ elif (w_start == 0) and (w_end != width):
67
+ w_jitter = random.randint(-w_jitter_range, 0)
68
+ elif (w_start != 0) and (w_end == width):
69
+ w_jitter = random.randint(0, w_jitter_range)
70
+ if (h_start != 0) and (h_end != height):
71
+ h_jitter = random.randint(-h_jitter_range, h_jitter_range)
72
+ elif (h_start == 0) and (h_end != height):
73
+ h_jitter = random.randint(-h_jitter_range, 0)
74
+ elif (h_start != 0) and (h_end == height):
75
+ h_jitter = random.randint(0, h_jitter_range)
76
+ h_start += (h_jitter + h_jitter_range)
77
+ h_end += (h_jitter + h_jitter_range)
78
+ w_start += (w_jitter + w_jitter_range)
79
+ w_end += (w_jitter + w_jitter_range)
80
+
81
+ views.append((h_start, h_end, w_start, w_end))
82
+ return views
83
+
84
+ def scale_forward(
85
+ self,
86
+ hidden_states: torch.FloatTensor,
87
+ attention_mask: Optional[torch.FloatTensor] = None,
88
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
89
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
90
+ timestep: Optional[torch.LongTensor] = None,
91
+ cross_attention_kwargs: Dict[str, Any] = None,
92
+ class_labels: Optional[torch.LongTensor] = None,
93
+ ):
94
+ # Notice that normalization is always applied before the real computation in the following blocks.
95
+ if self.current_hw:
96
+ current_scale_num_h, current_scale_num_w = self.current_hw[0] // 1024, self.current_hw[1] // 1024
97
+ else:
98
+ current_scale_num_h, current_scale_num_w = 1, 1
99
+
100
+ # 0. Self-Attention
101
+ if self.use_ada_layer_norm:
102
+ norm_hidden_states = self.norm1(hidden_states, timestep)
103
+ elif self.use_ada_layer_norm_zero:
104
+ norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
105
+ hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
106
+ )
107
+ else:
108
+ norm_hidden_states = self.norm1(hidden_states)
109
+
110
+ # 2. Prepare GLIGEN inputs
111
+ cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
112
+ gligen_kwargs = cross_attention_kwargs.pop("gligen", None)
113
+
114
+ ratio_hw = current_scale_num_h / current_scale_num_w
115
+ latent_h = int((norm_hidden_states.shape[1] * ratio_hw) ** 0.5)
116
+ latent_w = int(latent_h / ratio_hw)
117
+ scale_factor = 128 * current_scale_num_h / latent_h
118
+ if ratio_hw > 1:
119
+ sub_h = 128
120
+ sub_w = int(128 / ratio_hw)
121
+ else:
122
+ sub_h = int(128 * ratio_hw)
123
+ sub_w = 128
124
+
125
+ h_jitter_range = int(sub_h / scale_factor // 8)
126
+ w_jitter_range = int(sub_w / scale_factor // 8)
127
+ views = get_views(latent_h, latent_w, sub_h, sub_w, scale_factor = scale_factor)
128
+
129
+ current_scale_num = max(current_scale_num_h, current_scale_num_w)
130
+ global_views = [[h, w] for h in range(current_scale_num_h) for w in range(current_scale_num_w)]
131
+
132
+ if self.fast_mode:
133
+ four_window = False
134
+ fourg_window = True
135
+ else:
136
+ four_window = True
137
+ fourg_window = False
138
+
139
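+ # Self-attention is computed twice: over overlapping local windows (restoring a training-scale
+ # receptive field) and over the full global layout. The two results are fused so that low
+ # frequencies come from the local pass and high frequencies from the global pass, via the
+ # Gaussian filters above.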
+ if four_window:
140
+ norm_hidden_states_ = rearrange(norm_hidden_states, 'bh (h w) d -> bh h w d', h = latent_h)
141
+ norm_hidden_states_ = F.pad(norm_hidden_states_, (0, 0, w_jitter_range, w_jitter_range, h_jitter_range, h_jitter_range), 'constant', 0)
142
+ value = torch.zeros_like(norm_hidden_states_)
143
+ count = torch.zeros_like(norm_hidden_states_)
144
+ for index, view in enumerate(views):
145
+ h_start, h_end, w_start, w_end = view
146
+ local_states = norm_hidden_states_[:, h_start:h_end, w_start:w_end, :]
147
+ local_states = rearrange(local_states, 'bh h w d -> bh (h w) d')
148
+ local_output = self.attn1(
149
+ local_states,
150
+ encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
151
+ attention_mask=attention_mask,
152
+ **cross_attention_kwargs,
153
+ )
154
+ local_output = rearrange(local_output, 'bh (h w) d -> bh h w d', h = int(sub_h / scale_factor))
155
+
156
+ value[:, h_start:h_end, w_start:w_end, :] += local_output * 1
157
+ count[:, h_start:h_end, w_start:w_end, :] += 1
158
+
159
+ value = value[:, h_jitter_range:-h_jitter_range, w_jitter_range:-w_jitter_range, :]
160
+ count = count[:, h_jitter_range:-h_jitter_range, w_jitter_range:-w_jitter_range, :]
161
+ attn_output = torch.where(count>0, value/count, value)
162
+
163
+ gaussian_local = gaussian_filter(attn_output, kernel_size=(2*current_scale_num-1), sigma=1.0)
164
+
165
+ attn_output_global = self.attn1(
166
+ norm_hidden_states,
167
+ encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
168
+ attention_mask=attention_mask,
169
+ **cross_attention_kwargs,
170
+ )
171
+ attn_output_global = rearrange(attn_output_global, 'bh (h w) d -> bh h w d', h = latent_h)
172
+
173
+ gaussian_global = gaussian_filter(attn_output_global, kernel_size=(2*current_scale_num-1), sigma=1.0)
174
+
175
+ attn_output = gaussian_local + (attn_output_global - gaussian_global)
176
+ attn_output = rearrange(attn_output, 'bh h w d -> bh (h w) d')
177
+
178
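+ # fast_mode path: instead of one full global self-attention, attend over strided sub-grids
+ # (every current_scale_num_h-th row / current_scale_num_w-th column) and stitch the results.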
+ elif fourg_window:
179
+ norm_hidden_states = rearrange(norm_hidden_states, 'bh (h w) d -> bh h w d', h = latent_h)
180
+ norm_hidden_states_ = F.pad(norm_hidden_states, (0, 0, w_jitter_range, w_jitter_range, h_jitter_range, h_jitter_range), 'constant', 0)
181
+ value = torch.zeros_like(norm_hidden_states_)
182
+ count = torch.zeros_like(norm_hidden_states_)
183
+ for index, view in enumerate(views):
184
+ h_start, h_end, w_start, w_end = view
185
+ local_states = norm_hidden_states_[:, h_start:h_end, w_start:w_end, :]
186
+ local_states = rearrange(local_states, 'bh h w d -> bh (h w) d')
187
+ local_output = self.attn1(
188
+ local_states,
189
+ encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
190
+ attention_mask=attention_mask,
191
+ **cross_attention_kwargs,
192
+ )
193
+ local_output = rearrange(local_output, 'bh (h w) d -> bh h w d', h = int(sub_h / scale_factor))
194
+
195
+ value[:, h_start:h_end, w_start:w_end, :] += local_output * 1
196
+ count[:, h_start:h_end, w_start:w_end, :] += 1
197
+
198
+ value = value[:, h_jitter_range:-h_jitter_range, w_jitter_range:-w_jitter_range, :]
199
+ count = count[:, h_jitter_range:-h_jitter_range, w_jitter_range:-w_jitter_range, :]
200
+ attn_output = torch.where(count>0, value/count, value)
201
+
202
+ gaussian_local = gaussian_filter(attn_output, kernel_size=(2*current_scale_num-1), sigma=1.0)
203
+
204
+ value = torch.zeros_like(norm_hidden_states)
205
+ count = torch.zeros_like(norm_hidden_states)
206
+ for index, global_view in enumerate(global_views):
207
+ h, w = global_view
208
+ global_states = norm_hidden_states[:, h::current_scale_num_h, w::current_scale_num_w, :]
209
+ global_states = rearrange(global_states, 'bh h w d -> bh (h w) d')
210
+ global_output = self.attn1(
211
+ global_states,
212
+ encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
213
+ attention_mask=attention_mask,
214
+ **cross_attention_kwargs,
215
+ )
216
+ global_output = rearrange(global_output, 'bh (h w) d -> bh h w d', h = int(global_output.shape[1] ** 0.5))
217
+
218
+ value[:, h::current_scale_num_h, w::current_scale_num_w, :] += global_output * 1
219
+ count[:, h::current_scale_num_h, w::current_scale_num_w, :] += 1
220
+
221
+ attn_output_global = torch.where(count>0, value/count, value)
222
+
223
+ gaussian_global = gaussian_filter(attn_output_global, kernel_size=(2*current_scale_num-1), sigma=1.0)
224
+
225
+ attn_output = gaussian_local + (attn_output_global - gaussian_global)
226
+ attn_output = rearrange(attn_output, 'bh h w d -> bh (h w) d')
227
+
228
+ else:
229
+ attn_output = self.attn1(
230
+ norm_hidden_states,
231
+ encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
232
+ attention_mask=attention_mask,
233
+ **cross_attention_kwargs,
234
+ )
235
+
236
+ if self.use_ada_layer_norm_zero:
237
+ attn_output = gate_msa.unsqueeze(1) * attn_output
238
+ hidden_states = attn_output + hidden_states
239
+
240
+ # 2.5 GLIGEN Control
241
+ if gligen_kwargs is not None:
242
+ hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"])
243
+ # 2.5 ends
244
+
245
+ # 3. Cross-Attention
246
+ if self.attn2 is not None:
247
+ norm_hidden_states = (
248
+ self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
249
+ )
250
+ attn_output = self.attn2(
251
+ norm_hidden_states,
252
+ encoder_hidden_states=encoder_hidden_states,
253
+ attention_mask=encoder_attention_mask,
254
+ **cross_attention_kwargs,
255
+ )
256
+ hidden_states = attn_output + hidden_states
257
+
258
+ # 4. Feed-forward
259
+ norm_hidden_states = self.norm3(hidden_states)
260
+
261
+ if self.use_ada_layer_norm_zero:
262
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
263
+
264
+ if self._chunk_size is not None:
265
+ # "feed_forward_chunk_size" can be used to save memory
266
+ if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0:
267
+ raise ValueError(
268
+ f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: {self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
269
+ )
270
+
271
+ num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size
272
+ ff_output = torch.cat(
273
+ [
274
+ self.ff(hid_slice)
275
+ for hid_slice in norm_hidden_states.chunk(num_chunks, dim=self._chunk_dim)
276
+ ],
277
+ dim=self._chunk_dim,
278
+ )
279
+ else:
280
+ ff_output = self.ff(norm_hidden_states)
281
+
282
+ if self.use_ada_layer_norm_zero:
283
+ ff_output = gate_mlp.unsqueeze(1) * ff_output
284
+
285
+ hidden_states = ff_output + hidden_states
286
+
287
+ return hidden_states
288
+
289
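+ # Essentially the stock BasicTransformerBlock.forward from diffusers; re-bound onto the blocks
+ # for the base-resolution pass before the restart stages switch them to scale_forward.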
+ def ori_forward(
290
+ self,
291
+ hidden_states: torch.FloatTensor,
292
+ attention_mask: Optional[torch.FloatTensor] = None,
293
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
294
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
295
+ timestep: Optional[torch.LongTensor] = None,
296
+ cross_attention_kwargs: Dict[str, Any] = None,
297
+ class_labels: Optional[torch.LongTensor] = None,
298
+ ):
299
+ # Notice that normalization is always applied before the real computation in the following blocks.
300
+ # 0. Self-Attention
301
+ if self.use_ada_layer_norm:
302
+ norm_hidden_states = self.norm1(hidden_states, timestep)
303
+ elif self.use_ada_layer_norm_zero:
304
+ norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
305
+ hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
306
+ )
307
+ else:
308
+ norm_hidden_states = self.norm1(hidden_states)
309
+
310
+ # 2. Prepare GLIGEN inputs
311
+ cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
312
+ gligen_kwargs = cross_attention_kwargs.pop("gligen", None)
313
+
314
+ attn_output = self.attn1(
315
+ norm_hidden_states,
316
+ encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
317
+ attention_mask=attention_mask,
318
+ **cross_attention_kwargs,
319
+ )
320
+
321
+ if self.use_ada_layer_norm_zero:
322
+ attn_output = gate_msa.unsqueeze(1) * attn_output
323
+ hidden_states = attn_output + hidden_states
324
+
325
+ # 2.5 GLIGEN Control
326
+ if gligen_kwargs is not None:
327
+ hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"])
328
+ # 2.5 ends
329
+
330
+ # 3. Cross-Attention
331
+ if self.attn2 is not None:
332
+ norm_hidden_states = (
333
+ self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
334
+ )
335
+ attn_output = self.attn2(
336
+ norm_hidden_states,
337
+ encoder_hidden_states=encoder_hidden_states,
338
+ attention_mask=encoder_attention_mask,
339
+ **cross_attention_kwargs,
340
+ )
341
+ hidden_states = attn_output + hidden_states
342
+
343
+ # 4. Feed-forward
344
+ norm_hidden_states = self.norm3(hidden_states)
345
+
346
+ if self.use_ada_layer_norm_zero:
347
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
348
+
349
+ if self._chunk_size is not None:
350
+ # "feed_forward_chunk_size" can be used to save memory
351
+ if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0:
352
+ raise ValueError(
353
+ f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: {self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
354
+ )
355
+
356
+ num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size
357
+ ff_output = torch.cat(
358
+ [
359
+ self.ff(hid_slice)
360
+ for hid_slice in norm_hidden_states.chunk(num_chunks, dim=self._chunk_dim)
361
+ ],
362
+ dim=self._chunk_dim,
363
+ )
364
+ else:
365
+ ff_output = self.ff(norm_hidden_states)
366
+
367
+ if self.use_ada_layer_norm_zero:
368
+ ff_output = gate_mlp.unsqueeze(1) * ff_output
369
+
370
+ hidden_states = ff_output + hidden_states
371
+
372
+ return hidden_states
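
For reference, a minimal usage sketch of the pieces above, assuming the pipeline class in pipeline_freescale.py is exported as StableDiffusionXLPipeline; the checkpoint id, prompt, and concrete values are illustrative assumptions, with keyword names taken from the `__call__` signature shown earlier:

import torch
from pipeline_freescale import StableDiffusionXLPipeline

# Load an SDXL checkpoint into the FreeScale pipeline (checkpoint id is an example).
pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")

result = pipe(
    prompt="a photo of an astronaut riding a horse",
    num_inference_steps=50,
    guidance_scale=7.5,
    resolutions_list=[[1024, 1024], [2048, 2048]],  # base resolution, then upscale target(s)
    restart_steps=[15],   # restart the 2048x2048 stage from step 15
    cosine_scale=2.0,
    fast_mode=True,
)
image = result.images[0]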