lixiang46 commited on
Commit
d5bcc1a
1 Parent(s): ae6a57b
app.py CHANGED
@@ -3,7 +3,7 @@ import random
3
  import torch
4
  from huggingface_hub import snapshot_download
5
  from transformers import CLIPVisionModelWithProjection, CLIPImageProcessor
6
- from kolors.pipelines.pipeline_stable_diffusion_xl_chatglm_256_ipadapter import StableDiffusionXLPipeline
7
  from kolors.models.modeling_chatglm import ChatGLMModel
8
  from kolors.models.tokenization_chatglm import ChatGLMTokenizer
9
  from kolors.models.unet_2d_condition import UNet2DConditionModel
@@ -11,7 +11,6 @@ from diffusers import AutoencoderKL, EulerDiscreteScheduler
11
  import gradio as gr
12
  import numpy as np
13
 
14
- device = "cuda"
15
  device = "cuda"
16
  ckpt_dir = '/home/lixiang46/Kolors/weights/Kolors'
17
  ckpt_IPA_dir = '/home/lixiang46/Kolors/weights/Kolors-IP-Adapter-Plus'
@@ -28,7 +27,15 @@ image_encoder = CLIPVisionModelWithProjection.from_pretrained(f'{ckpt_IPA_dir}/i
28
  ip_img_size = 336
29
  clip_image_processor = CLIPImageProcessor(size=ip_img_size, crop_size=ip_img_size)
30
 
31
- pipe = StableDiffusionXLPipeline(
 
 
 
 
 
 
 
 
32
  vae=vae,
33
  text_encoder=text_encoder,
34
  tokenizer=tokenizer,
@@ -39,36 +46,47 @@ pipe = StableDiffusionXLPipeline(
39
  force_zeros_for_empty_prompt=False
40
  ).to(device)
41
 
42
- if hasattr(pipe.unet, 'encoder_hid_proj'):
43
- pipe.unet.text_encoder_hid_proj = pipe.unet.encoder_hid_proj
44
 
45
- pipe.load_ip_adapter( f'{ckpt_IPA_dir}' , subfolder="", weight_name=["ip_adapter_plus_general.bin"])
46
 
47
  MAX_SEED = np.iinfo(np.int32).max
48
  MAX_IMAGE_SIZE = 2048
49
 
50
- def infer(prompt, ip_adapter_image, ip_adapter_scale, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps):
51
-
52
  if randomize_seed:
53
  seed = random.randint(0, MAX_SEED)
54
-
55
  generator = torch.Generator().manual_seed(seed)
56
- pipe.set_ip_adapter_scale([ip_adapter_scale])
57
- image = pipe(
58
- prompt= prompt ,
59
- ip_adapter_image=[ip_adapter_image],
60
- negative_prompt=negative_prompt,
61
- height=height,
62
- width=width,
63
- num_inference_steps=num_inference_steps,
64
- guidance_scale=guidance_scale,
65
- num_images_per_prompt=1,
66
- generator=generator
67
- ).images[0]
68
-
69
- return image
 
 
 
 
 
 
 
 
 
 
 
 
70
 
71
  examples = [
 
72
  ["穿着黑色T恤衫,上面中文绿色大字写着“可图”", "image/test_ip.jpg", 0.5],
73
  ["一只可爱的小狗在奔跑", "image/test_ip2.png", 0.5]
74
  ]
@@ -171,7 +189,7 @@ with gr.Blocks(css=css) as demo:
171
 
172
  run_button.click(
173
  fn = infer,
174
- inputs = [prompt, ip_adapter_image, ip_adapter_scale, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps],
175
  outputs = [result]
176
  )
177
 
 
3
  import torch
4
  from huggingface_hub import snapshot_download
5
  from transformers import CLIPVisionModelWithProjection, CLIPImageProcessor
6
+ from kolors.pipelines import pipeline_stable_diffusion_xl_chatglm_256_ipadapter, pipeline_stable_diffusion_xl_chatglm_256
7
  from kolors.models.modeling_chatglm import ChatGLMModel
8
  from kolors.models.tokenization_chatglm import ChatGLMTokenizer
9
  from kolors.models.unet_2d_condition import UNet2DConditionModel
 
11
  import gradio as gr
12
  import numpy as np
13
 
 
14
  device = "cuda"
15
  ckpt_dir = '/home/lixiang46/Kolors/weights/Kolors'
16
  ckpt_IPA_dir = '/home/lixiang46/Kolors/weights/Kolors-IP-Adapter-Plus'
 
27
  ip_img_size = 336
28
  clip_image_processor = CLIPImageProcessor(size=ip_img_size, crop_size=ip_img_size)
29
 
30
+ pipe_t2i = pipeline_stable_diffusion_xl_chatglm_256.StableDiffusionXLPipeline(
31
+ vae=vae,text_encoder=text_encoder,
32
+ tokenizer=tokenizer,
33
+ unet=unet,
34
+ scheduler=scheduler,
35
+ force_zeros_for_empty_prompt=False
36
+ ).to(device)
37
+
38
+ pipe_i2i = pipeline_stable_diffusion_xl_chatglm_256_ipadapter.StableDiffusionXLPipeline(
39
  vae=vae,
40
  text_encoder=text_encoder,
41
  tokenizer=tokenizer,
 
46
  force_zeros_for_empty_prompt=False
47
  ).to(device)
48
 
49
+ if hasattr(pipe_i2i.unet, 'encoder_hid_proj'):
50
+ pipe_i2i.unet.text_encoder_hid_proj = pipe_i2i.unet.encoder_hid_proj
51
 
52
+ pipe_i2i.load_ip_adapter( f'{ckpt_IPA_dir}' , subfolder="", weight_name=["ip_adapter_plus_general.bin"])
53
 
54
  MAX_SEED = np.iinfo(np.int32).max
55
  MAX_IMAGE_SIZE = 2048
56
 
57
+ def infer(prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps, ip_adapter_image = None, ip_adapter_scale = None):
 
58
  if randomize_seed:
59
  seed = random.randint(0, MAX_SEED)
 
60
  generator = torch.Generator().manual_seed(seed)
61
+
62
+ if ip_adapter_image is None:
63
+ image = pipe_t2i(
64
+ prompt = prompt,
65
+ negative_prompt = negative_prompt,
66
+ guidance_scale = guidance_scale,
67
+ num_inference_steps = num_inference_steps,
68
+ width = width,
69
+ height = height,
70
+ generator = generator
71
+ ).images[0]
72
+ return image
73
+ else:
74
+ pipe_i2i.set_ip_adapter_scale([ip_adapter_scale])
75
+ image = pipe_i2i(
76
+ prompt= prompt ,
77
+ ip_adapter_image=[ip_adapter_image],
78
+ negative_prompt=negative_prompt,
79
+ height=height,
80
+ width=width,
81
+ num_inference_steps=num_inference_steps,
82
+ guidance_scale=guidance_scale,
83
+ num_images_per_prompt=1,
84
+ generator=generator
85
+ ).images[0]
86
+ return image
87
 
88
  examples = [
89
+ [None, "一张瓢虫的照片,微距,变焦,高质量,电影,拿着一个牌子,写着“可图”", None],
90
  ["穿着黑色T恤衫,上面中文绿色大字写着“可图”", "image/test_ip.jpg", 0.5],
91
  ["一只可爱的小狗在奔跑", "image/test_ip2.png", 0.5]
92
  ]
 
189
 
190
  run_button.click(
191
  fn = infer,
192
+ inputs = [prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps, ip_adapter_image, ip_adapter_scale],
193
  outputs = [result]
194
  )
195
 
kolors/pipelines/pipeline_stable_diffusion_xl_chatglm_256.py ADDED
@@ -0,0 +1,841 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import sys
15
+ import os
16
+ sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
17
+ from kolors.models.modeling_chatglm import ChatGLMModel
18
+ from kolors.models.tokenization_chatglm import ChatGLMTokenizer
19
+ import inspect
20
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
21
+ import torch
22
+ from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
23
+ from transformers import XLMRobertaModel, ChineseCLIPTextModel
24
+
25
+ from diffusers.image_processor import VaeImageProcessor
26
+ from diffusers.loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
27
+ from diffusers.models import AutoencoderKL, UNet2DConditionModel
28
+ from diffusers.models.attention_processor import (
29
+ AttnProcessor2_0,
30
+ LoRAAttnProcessor2_0,
31
+ LoRAXFormersAttnProcessor,
32
+ XFormersAttnProcessor,
33
+ )
34
+ from diffusers.schedulers import KarrasDiffusionSchedulers
35
+ from diffusers.utils import (
36
+ is_accelerate_available,
37
+ is_accelerate_version,
38
+ logging,
39
+ replace_example_docstring,
40
+ )
41
+ try:
42
+ from diffusers.utils import randn_tensor
43
+ except:
44
+ from diffusers.utils.torch_utils import randn_tensor
45
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
46
+ from diffusers.pipelines.stable_diffusion_xl import StableDiffusionXLPipelineOutput
47
+
48
+
49
+
50
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
51
+
52
+ EXAMPLE_DOC_STRING = """
53
+ Examples:
54
+ ```py
55
+ >>> import torch
56
+ >>> from diffusers import StableDiffusionXLPipeline
57
+
58
+ >>> pipe = StableDiffusionXLPipeline.from_pretrained(
59
+ ... "stabilityai/stable-diffusion-xl-base-0.9", torch_dtype=torch.float16
60
+ ... )
61
+ >>> pipe = pipe.to("cuda")
62
+
63
+ >>> prompt = "a photo of an astronaut riding a horse on mars"
64
+ >>> image = pipe(prompt).images[0]
65
+ ```
66
+ """
67
+
68
+
69
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
70
+ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
71
+ """
72
+ Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
73
+ Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
74
+ """
75
+ std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
76
+ std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
77
+ # rescale the results from guidance (fixes overexposure)
78
+ noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
79
+ # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
80
+ noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
81
+ return noise_cfg
82
+
83
+
84
+ class StableDiffusionXLPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoaderMixin):
85
+ r"""
86
+ Pipeline for text-to-image generation using Stable Diffusion XL.
87
+
88
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
89
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
90
+
91
+ In addition the pipeline inherits the following loading methods:
92
+ - *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`]
93
+ - *LoRA*: [`loaders.LoraLoaderMixin.load_lora_weights`]
94
+ - *Ckpt*: [`loaders.FromSingleFileMixin.from_single_file`]
95
+
96
+ as well as the following saving methods:
97
+ - *LoRA*: [`loaders.LoraLoaderMixin.save_lora_weights`]
98
+
99
+ Args:
100
+ vae ([`AutoencoderKL`]):
101
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
102
+ text_encoder ([`CLIPTextModel`]):
103
+ Frozen text-encoder. Stable Diffusion XL uses the text portion of
104
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
105
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
106
+
107
+ tokenizer (`CLIPTokenizer`):
108
+ Tokenizer of class
109
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
110
+
111
+ unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
112
+ scheduler ([`SchedulerMixin`]):
113
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
114
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
115
+ """
116
+
117
+ def __init__(
118
+ self,
119
+ vae: AutoencoderKL,
120
+ text_encoder: ChatGLMModel,
121
+ tokenizer: ChatGLMTokenizer,
122
+ unet: UNet2DConditionModel,
123
+ scheduler: KarrasDiffusionSchedulers,
124
+ force_zeros_for_empty_prompt: bool = True,
125
+ ):
126
+ super().__init__()
127
+
128
+ self.register_modules(
129
+ vae=vae,
130
+ text_encoder=text_encoder,
131
+ tokenizer=tokenizer,
132
+ unet=unet,
133
+ scheduler=scheduler,
134
+ )
135
+ self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
136
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
137
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
138
+ self.default_sample_size = self.unet.config.sample_size
139
+
140
+ # self.watermark = StableDiffusionXLWatermarker()
141
+
142
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing
143
+ def enable_vae_slicing(self):
144
+ r"""
145
+ Enable sliced VAE decoding.
146
+
147
+ When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several
148
+ steps. This is useful to save some memory and allow larger batch sizes.
149
+ """
150
+ self.vae.enable_slicing()
151
+
152
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing
153
+ def disable_vae_slicing(self):
154
+ r"""
155
+ Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to
156
+ computing decoding in one step.
157
+ """
158
+ self.vae.disable_slicing()
159
+
160
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_tiling
161
+ def enable_vae_tiling(self):
162
+ r"""
163
+ Enable tiled VAE decoding.
164
+
165
+ When this option is enabled, the VAE will split the input tensor into tiles to compute decoding and encoding in
166
+ several steps. This is useful to save a large amount of memory and to allow the processing of larger images.
167
+ """
168
+ self.vae.enable_tiling()
169
+
170
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_tiling
171
+ def disable_vae_tiling(self):
172
+ r"""
173
+ Disable tiled VAE decoding. If `enable_vae_tiling` was previously invoked, this method will go back to
174
+ computing decoding in one step.
175
+ """
176
+ self.vae.disable_tiling()
177
+
178
+ def enable_sequential_cpu_offload(self, gpu_id=0):
179
+ r"""
180
+ Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
181
+ text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a
182
+ `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called.
183
+ Note that offloading happens on a submodule basis. Memory savings are higher than with
184
+ `enable_model_cpu_offload`, but performance is lower.
185
+ """
186
+ if is_accelerate_available() and is_accelerate_version(">=", "0.14.0"):
187
+ from accelerate import cpu_offload
188
+ else:
189
+ raise ImportError("`enable_sequential_cpu_offload` requires `accelerate v0.14.0` or higher")
190
+
191
+ device = torch.device(f"cuda:{gpu_id}")
192
+
193
+ if self.device.type != "cpu":
194
+ self.to("cpu", silence_dtype_warnings=True)
195
+ torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
196
+
197
+ for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
198
+ cpu_offload(cpu_offloaded_model, device)
199
+
200
+ def enable_model_cpu_offload(self, gpu_id=0):
201
+ r"""
202
+ Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
203
+ to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
204
+ method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with
205
+ `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
206
+ """
207
+ if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
208
+ from accelerate import cpu_offload_with_hook
209
+ else:
210
+ raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")
211
+
212
+ device = torch.device(f"cuda:{gpu_id}")
213
+
214
+ if self.device.type != "cpu":
215
+ self.to("cpu", silence_dtype_warnings=True)
216
+ torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
217
+
218
+ model_sequence = (
219
+ [self.text_encoder]
220
+ )
221
+ model_sequence.extend([self.unet, self.vae])
222
+
223
+ hook = None
224
+ for cpu_offloaded_model in model_sequence:
225
+ _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)
226
+
227
+ # We'll offload the last model manually.
228
+ self.final_offload_hook = hook
229
+
230
+ @property
231
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device
232
+ def _execution_device(self):
233
+ r"""
234
+ Returns the device on which the pipeline's models will be executed. After calling
235
+ `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
236
+ hooks.
237
+ """
238
+ if not hasattr(self.unet, "_hf_hook"):
239
+ return self.device
240
+ for module in self.unet.modules():
241
+ if (
242
+ hasattr(module, "_hf_hook")
243
+ and hasattr(module._hf_hook, "execution_device")
244
+ and module._hf_hook.execution_device is not None
245
+ ):
246
+ return torch.device(module._hf_hook.execution_device)
247
+ return self.device
248
+
249
+ def encode_prompt(
250
+ self,
251
+ prompt,
252
+ device: Optional[torch.device] = None,
253
+ num_images_per_prompt: int = 1,
254
+ do_classifier_free_guidance: bool = True,
255
+ negative_prompt=None,
256
+ prompt_embeds: Optional[torch.FloatTensor] = None,
257
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
258
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
259
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
260
+ lora_scale: Optional[float] = None,
261
+ ):
262
+ r"""
263
+ Encodes the prompt into text encoder hidden states.
264
+
265
+ Args:
266
+ prompt (`str` or `List[str]`, *optional*):
267
+ prompt to be encoded
268
+ device: (`torch.device`):
269
+ torch device
270
+ num_images_per_prompt (`int`):
271
+ number of images that should be generated per prompt
272
+ do_classifier_free_guidance (`bool`):
273
+ whether to use classifier free guidance or not
274
+ negative_prompt (`str` or `List[str]`, *optional*):
275
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
276
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
277
+ less than `1`).
278
+ prompt_embeds (`torch.FloatTensor`, *optional*):
279
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
280
+ provided, text embeddings will be generated from `prompt` input argument.
281
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
282
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
283
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
284
+ argument.
285
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
286
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
287
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
288
+ negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
289
+ Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
290
+ weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
291
+ input argument.
292
+ lora_scale (`float`, *optional*):
293
+ A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
294
+ """
295
+ # from IPython import embed; embed(); exit()
296
+ device = device or self._execution_device
297
+
298
+ # set lora scale so that monkey patched LoRA
299
+ # function of text encoder can correctly access it
300
+ if lora_scale is not None and isinstance(self, LoraLoaderMixin):
301
+ self._lora_scale = lora_scale
302
+
303
+ if prompt is not None and isinstance(prompt, str):
304
+ batch_size = 1
305
+ elif prompt is not None and isinstance(prompt, list):
306
+ batch_size = len(prompt)
307
+ else:
308
+ batch_size = prompt_embeds.shape[0]
309
+
310
+ # Define tokenizers and text encoders
311
+ tokenizers = [self.tokenizer]
312
+ text_encoders = [self.text_encoder]
313
+
314
+ if prompt_embeds is None:
315
+ # textual inversion: procecss multi-vector tokens if necessary
316
+ prompt_embeds_list = []
317
+ for tokenizer, text_encoder in zip(tokenizers, text_encoders):
318
+ if isinstance(self, TextualInversionLoaderMixin):
319
+ prompt = self.maybe_convert_prompt(prompt, tokenizer)
320
+
321
+ text_inputs = tokenizer(
322
+ prompt,
323
+ padding="max_length",
324
+ max_length=256,
325
+ truncation=True,
326
+ return_tensors="pt",
327
+ ).to('cuda')
328
+ output = text_encoder(
329
+ input_ids=text_inputs['input_ids'] ,
330
+ attention_mask=text_inputs['attention_mask'],
331
+ position_ids=text_inputs['position_ids'],
332
+ output_hidden_states=True)
333
+ prompt_embeds = output.hidden_states[-2].permute(1, 0, 2).clone()
334
+ pooled_prompt_embeds = output.hidden_states[-1][-1, :, :].clone() # [batch_size, 4096]
335
+ bs_embed, seq_len, _ = prompt_embeds.shape
336
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
337
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
338
+
339
+ prompt_embeds_list.append(prompt_embeds)
340
+
341
+ # prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
342
+ prompt_embeds = prompt_embeds_list[0]
343
+
344
+ # get unconditional embeddings for classifier free guidance
345
+ zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt
346
+ if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt:
347
+ negative_prompt_embeds = torch.zeros_like(prompt_embeds)
348
+ negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds)
349
+ elif do_classifier_free_guidance and negative_prompt_embeds is None:
350
+ # negative_prompt = negative_prompt or ""
351
+ uncond_tokens: List[str]
352
+ if negative_prompt is None:
353
+ uncond_tokens = [""] * batch_size
354
+ elif prompt is not None and type(prompt) is not type(negative_prompt):
355
+ raise TypeError(
356
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
357
+ f" {type(prompt)}."
358
+ )
359
+ elif isinstance(negative_prompt, str):
360
+ uncond_tokens = [negative_prompt]
361
+ elif batch_size != len(negative_prompt):
362
+ raise ValueError(
363
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
364
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
365
+ " the batch size of `prompt`."
366
+ )
367
+ else:
368
+ uncond_tokens = negative_prompt
369
+
370
+ negative_prompt_embeds_list = []
371
+ for tokenizer, text_encoder in zip(tokenizers, text_encoders):
372
+ # textual inversion: procecss multi-vector tokens if necessary
373
+ if isinstance(self, TextualInversionLoaderMixin):
374
+ uncond_tokens = self.maybe_convert_prompt(uncond_tokens, tokenizer)
375
+
376
+ max_length = prompt_embeds.shape[1]
377
+ uncond_input = tokenizer(
378
+ uncond_tokens,
379
+ padding="max_length",
380
+ max_length=max_length,
381
+ truncation=True,
382
+ return_tensors="pt",
383
+ ).to('cuda')
384
+ output = text_encoder(
385
+ input_ids=uncond_input['input_ids'] ,
386
+ attention_mask=uncond_input['attention_mask'],
387
+ position_ids=uncond_input['position_ids'],
388
+ output_hidden_states=True)
389
+ negative_prompt_embeds = output.hidden_states[-2].permute(1, 0, 2).clone()
390
+ negative_pooled_prompt_embeds = output.hidden_states[-1][-1, :, :].clone() # [batch_size, 4096]
391
+
392
+ if do_classifier_free_guidance:
393
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
394
+ seq_len = negative_prompt_embeds.shape[1]
395
+
396
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=text_encoder.dtype, device=device)
397
+
398
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
399
+ negative_prompt_embeds = negative_prompt_embeds.view(
400
+ batch_size * num_images_per_prompt, seq_len, -1
401
+ )
402
+
403
+ # For classifier free guidance, we need to do two forward passes.
404
+ # Here we concatenate the unconditional and text embeddings into a single batch
405
+ # to avoid doing two forward passes
406
+
407
+ negative_prompt_embeds_list.append(negative_prompt_embeds)
408
+
409
+ # negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)
410
+ negative_prompt_embeds = negative_prompt_embeds_list[0]
411
+
412
+ bs_embed = pooled_prompt_embeds.shape[0]
413
+ pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
414
+ bs_embed * num_images_per_prompt, -1
415
+ )
416
+ if do_classifier_free_guidance:
417
+ negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
418
+ bs_embed * num_images_per_prompt, -1
419
+ )
420
+
421
+ return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
422
+
423
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
424
+ def prepare_extra_step_kwargs(self, generator, eta):
425
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
426
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
427
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
428
+ # and should be between [0, 1]
429
+
430
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
431
+ extra_step_kwargs = {}
432
+ if accepts_eta:
433
+ extra_step_kwargs["eta"] = eta
434
+
435
+ # check if the scheduler accepts generator
436
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
437
+ if accepts_generator:
438
+ extra_step_kwargs["generator"] = generator
439
+ return extra_step_kwargs
440
+
441
+ def check_inputs(
442
+ self,
443
+ prompt,
444
+ height,
445
+ width,
446
+ callback_steps,
447
+ negative_prompt=None,
448
+ prompt_embeds=None,
449
+ negative_prompt_embeds=None,
450
+ pooled_prompt_embeds=None,
451
+ negative_pooled_prompt_embeds=None,
452
+ ):
453
+ if height % 8 != 0 or width % 8 != 0:
454
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
455
+
456
+ if (callback_steps is None) or (
457
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
458
+ ):
459
+ raise ValueError(
460
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
461
+ f" {type(callback_steps)}."
462
+ )
463
+
464
+ if prompt is not None and prompt_embeds is not None:
465
+ raise ValueError(
466
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
467
+ " only forward one of the two."
468
+ )
469
+ elif prompt is None and prompt_embeds is None:
470
+ raise ValueError(
471
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
472
+ )
473
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
474
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
475
+
476
+ if negative_prompt is not None and negative_prompt_embeds is not None:
477
+ raise ValueError(
478
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
479
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
480
+ )
481
+
482
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
483
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
484
+ raise ValueError(
485
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
486
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
487
+ f" {negative_prompt_embeds.shape}."
488
+ )
489
+
490
+ if prompt_embeds is not None and pooled_prompt_embeds is None:
491
+ raise ValueError(
492
+ "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
493
+ )
494
+
495
+ if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None:
496
+ raise ValueError(
497
+ "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
498
+ )
499
+
500
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
501
+ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
502
+ shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
503
+ if isinstance(generator, list) and len(generator) != batch_size:
504
+ raise ValueError(
505
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
506
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
507
+ )
508
+
509
+ if latents is None:
510
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
511
+ else:
512
+ latents = latents.to(device)
513
+
514
+ # scale the initial noise by the standard deviation required by the scheduler
515
+ latents = latents * self.scheduler.init_noise_sigma
516
+ return latents
517
+
518
+ def _get_add_time_ids(self, original_size, crops_coords_top_left, target_size, dtype):
519
+ add_time_ids = list(original_size + crops_coords_top_left + target_size)
520
+
521
+ passed_add_embed_dim = (
522
+ self.unet.config.addition_time_embed_dim * len(add_time_ids) + 4096
523
+ )
524
+ expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features
525
+
526
+ if expected_add_embed_dim != passed_add_embed_dim:
527
+ raise ValueError(
528
+ f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`."
529
+ )
530
+
531
+ add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
532
+ return add_time_ids
533
+
534
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae
535
+ def upcast_vae(self):
536
+ dtype = self.vae.dtype
537
+ self.vae.to(dtype=torch.float32)
538
+ use_torch_2_0_or_xformers = isinstance(
539
+ self.vae.decoder.mid_block.attentions[0].processor,
540
+ (
541
+ AttnProcessor2_0,
542
+ XFormersAttnProcessor,
543
+ LoRAXFormersAttnProcessor,
544
+ LoRAAttnProcessor2_0,
545
+ ),
546
+ )
547
+ # if xformers or torch_2_0 is used attention block does not need
548
+ # to be in float32 which can save lots of memory
549
+ if use_torch_2_0_or_xformers:
550
+ self.vae.post_quant_conv.to(dtype)
551
+ self.vae.decoder.conv_in.to(dtype)
552
+ self.vae.decoder.mid_block.to(dtype)
553
+
554
+ @torch.no_grad()
555
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
556
+ def __call__(
557
+ self,
558
+ prompt: Union[str, List[str]] = None,
559
+ height: Optional[int] = None,
560
+ width: Optional[int] = None,
561
+ num_inference_steps: int = 50,
562
+ denoising_end: Optional[float] = None,
563
+ guidance_scale: float = 5.0,
564
+ negative_prompt: Optional[Union[str, List[str]]] = None,
565
+ num_images_per_prompt: Optional[int] = 1,
566
+ eta: float = 0.0,
567
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
568
+ latents: Optional[torch.FloatTensor] = None,
569
+ prompt_embeds: Optional[torch.FloatTensor] = None,
570
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
571
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
572
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
573
+ output_type: Optional[str] = "pil",
574
+ return_dict: bool = True,
575
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
576
+ callback_steps: int = 1,
577
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
578
+ guidance_rescale: float = 0.0,
579
+ original_size: Optional[Tuple[int, int]] = None,
580
+ crops_coords_top_left: Tuple[int, int] = (0, 0),
581
+ target_size: Optional[Tuple[int, int]] = None,
582
+ use_dynamic_threshold: Optional[bool] = False,
583
+ ):
584
+ r"""
585
+ Function invoked when calling the pipeline for generation.
586
+
587
+ Args:
588
+ prompt (`str` or `List[str]`, *optional*):
589
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
590
+ instead.
591
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
592
+ The height in pixels of the generated image.
593
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
594
+ The width in pixels of the generated image.
595
+ num_inference_steps (`int`, *optional*, defaults to 50):
596
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
597
+ expense of slower inference.
598
+ denoising_end (`float`, *optional*):
599
+ When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be
600
+ completed before it is intentionally prematurely terminated. For instance, if denoising_end is set to
601
+ 0.7 and `num_inference_steps` is fixed at 50, the process will execute only 35 (i.e., 0.7 * 50)
602
+ Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image
603
+ Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output)
604
+ guidance_scale (`float`, *optional*, defaults to 7.5):
605
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
606
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
607
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
608
+ negative_prompt (`str` or `List[str]`, *optional*):
609
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
610
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
611
+ less than `1`).
612
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
613
+ The number of images to generate per prompt.
614
+ eta (`float`, *optional*, defaults to 0.0):
615
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
616
+ [`schedulers.DDIMScheduler`], will be ignored for others.
617
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
618
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
619
+ to make generation deterministic.
620
+ latents (`torch.FloatTensor`, *optional*):
621
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
622
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
623
+ tensor will ge generated by sampling using the supplied random `generator`.
624
+ prompt_embeds (`torch.FloatTensor`, *optional*):
625
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
626
+ provided, text embeddings will be generated from `prompt` input argument.
627
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
628
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
629
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
630
+ argument.
631
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
632
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
633
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
634
+ negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
635
+ output_type (`str`, *optional*, defaults to `"pil"`):
636
+ The output format of the generate image. Choose between
637
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
638
+ return_dict (`bool`, *optional*, defaults to `True`):
639
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] instead of a
640
+ callback (`Callable`, *optional*):
641
+ A function that will be called every `callback_steps` steps during inference. The function will be
642
+ callback_steps (`int`, *optional*, defaults to 1):
643
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
644
+ called at every step.
645
+ cross_attention_kwargs (`dict`, *optional*):
646
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
647
+ `self.processor` in
648
+ [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
649
+ guidance_rescale (`float`, *optional*, defaults to 0.7):
650
+ Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are
651
+ Flawed](https://arxiv.org/pdf/2305.08891.pdf) `guidance_scale` is defined as `φ` in equation 16. of
652
+ [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf).
653
+ Guidance rescale factor should fix overexposure when using zero terminal SNR.
654
+ original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
655
+ TODO
656
+ crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
657
+ TODO
658
+ target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
659
+ TODO
660
+
661
+ Examples:
662
+
663
+ Returns:
664
+ [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] or `tuple`:
665
+ [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a
666
+ `tuple. When returning a tuple, the first element is a list with the generated images, and the second
667
+ element is a list of `bool`s denoting whether the corresponding generated image likely represents
668
+ "not-safe-for-work" (nsfw) content, according to the `safety_checker`.
669
+ """
670
+ # 0. Default height and width to unet
671
+ height = height or self.default_sample_size * self.vae_scale_factor
672
+ width = width or self.default_sample_size * self.vae_scale_factor
673
+
674
+ original_size = original_size or (height, width)
675
+ target_size = target_size or (height, width)
676
+
677
+ # 1. Check inputs. Raise error if not correct
678
+ self.check_inputs(
679
+ prompt,
680
+ height,
681
+ width,
682
+ callback_steps,
683
+ negative_prompt,
684
+ prompt_embeds,
685
+ negative_prompt_embeds,
686
+ pooled_prompt_embeds,
687
+ negative_pooled_prompt_embeds,
688
+ )
689
+
690
+ # 2. Define call parameters
691
+ if prompt is not None and isinstance(prompt, str):
692
+ batch_size = 1
693
+ elif prompt is not None and isinstance(prompt, list):
694
+ batch_size = len(prompt)
695
+ else:
696
+ batch_size = prompt_embeds.shape[0]
697
+
698
+ device = self._execution_device
699
+
700
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
701
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
702
+ # corresponds to doing no classifier free guidance.
703
+ do_classifier_free_guidance = guidance_scale > 1.0
704
+
705
+ # 3. Encode input prompt
706
+ text_encoder_lora_scale = (
707
+ cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
708
+ )
709
+ (
710
+ prompt_embeds,
711
+ negative_prompt_embeds,
712
+ pooled_prompt_embeds,
713
+ negative_pooled_prompt_embeds,
714
+ ) = self.encode_prompt(
715
+ prompt,
716
+ device,
717
+ num_images_per_prompt,
718
+ do_classifier_free_guidance,
719
+ negative_prompt,
720
+ prompt_embeds=prompt_embeds,
721
+ negative_prompt_embeds=negative_prompt_embeds,
722
+ pooled_prompt_embeds=pooled_prompt_embeds,
723
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
724
+ lora_scale=text_encoder_lora_scale,
725
+ )
726
+
727
+ # 4. Prepare timesteps
728
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
729
+
730
+ timesteps = self.scheduler.timesteps
731
+
732
+ # 5. Prepare latent variables
733
+ num_channels_latents = self.unet.config.in_channels
734
+ latents = self.prepare_latents(
735
+ batch_size * num_images_per_prompt,
736
+ num_channels_latents,
737
+ height,
738
+ width,
739
+ prompt_embeds.dtype,
740
+ device,
741
+ generator,
742
+ latents,
743
+ )
744
+
745
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
746
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
747
+
748
+ # 7. Prepare added time ids & embeddings
749
+ add_text_embeds = pooled_prompt_embeds
750
+ add_time_ids = self._get_add_time_ids(
751
+ original_size, crops_coords_top_left, target_size, dtype=prompt_embeds.dtype
752
+ )
753
+
754
+ if do_classifier_free_guidance:
755
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
756
+ add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
757
+ add_time_ids = torch.cat([add_time_ids, add_time_ids], dim=0)
758
+
759
+ prompt_embeds = prompt_embeds.to(device)
760
+ add_text_embeds = add_text_embeds.to(device)
761
+ add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)
762
+
763
+ # 8. Denoising loop
764
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
765
+
766
+ # 7.1 Apply denoising_end
767
+ if denoising_end is not None:
768
+ num_inference_steps = int(round(denoising_end * num_inference_steps))
769
+ timesteps = timesteps[: num_warmup_steps + self.scheduler.order * num_inference_steps]
770
+
771
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
772
+ for i, t in enumerate(timesteps):
773
+ # expand the latents if we are doing classifier free guidance
774
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
775
+
776
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
777
+
778
+ # predict the noise residual
779
+ added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
780
+ noise_pred = self.unet(
781
+ latent_model_input,
782
+ t,
783
+ encoder_hidden_states=prompt_embeds,
784
+ cross_attention_kwargs=cross_attention_kwargs,
785
+ added_cond_kwargs=added_cond_kwargs,
786
+ return_dict=False,
787
+ )[0]
788
+
789
+ # perform guidance
790
+ if do_classifier_free_guidance:
791
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
792
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
793
+ if use_dynamic_threshold:
794
+ DynamicThresh = DynThresh(maxSteps=num_inference_steps, experiment_mode=0)
795
+ noise_pred = DynamicThresh.dynthresh(noise_pred_text,
796
+ noise_pred_uncond,
797
+ guidance_scale,
798
+ None)
799
+
800
+ if do_classifier_free_guidance and guidance_rescale > 0.0:
801
+ # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
802
+ noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
803
+
804
+ # compute the previous noisy sample x_t -> x_t-1
805
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
806
+
807
+ # call the callback, if provided
808
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
809
+ progress_bar.update()
810
+ if callback is not None and i % callback_steps == 0:
811
+ callback(i, t, latents)
812
+
813
+ # make sureo the VAE is in float32 mode, as it overflows in float16
814
+ # torch.cuda.empty_cache()
815
+ if self.vae.dtype == torch.float16 and self.vae.config.force_upcast:
816
+ self.upcast_vae()
817
+ latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
818
+
819
+
820
+ if not output_type == "latent":
821
+ latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
822
+ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
823
+ else:
824
+ image = latents
825
+ return StableDiffusionXLPipelineOutput(images=image)
826
+
827
+ # image = self.watermark.apply_watermark(image)
828
+ image = self.image_processor.postprocess(image, output_type=output_type)
829
+
830
+ # Offload last model to CPU
831
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
832
+ self.final_offload_hook.offload()
833
+
834
+ if not return_dict:
835
+ return (image,)
836
+
837
+ return StableDiffusionXLPipelineOutput(images=image)
838
+
839
+
840
+ if __name__ == "__main__":
841
+ pass