junnyu commited on
Commit
9bf4a2c
1 Parent(s): c210e7f

Upload webui_stable_diffusion_controlnet.py

Browse files
Files changed (1) hide show
  1. webui_stable_diffusion_controlnet.py +1837 -0
webui_stable_diffusion_controlnet.py ADDED
@@ -0,0 +1,1837 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
2
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ #
16
+ # modified from https://github.com/AUTOMATIC1111/stable-diffusion-webui
17
+ # Here is the AGPL-3.0 license https://github.com/AUTOMATIC1111/stable-diffusion-webui/blob/master/LICENSE.txt
18
+ from ppdiffusers.utils import check_min_version
19
+ check_min_version("0.14.1")
20
+
21
+ import inspect
22
+ from typing import Any, Callable, Dict, List, Optional, Union
23
+
24
+ import paddle
25
+ import paddle.nn as nn
26
+ import PIL
27
+ import PIL.Image
28
+
29
+ from paddlenlp.transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
30
+ from ppdiffusers.models import AutoencoderKL, ControlNetModel, UNet2DConditionModel
31
+ from ppdiffusers.pipelines.pipeline_utils import DiffusionPipeline
32
+ from ppdiffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
33
+ from ppdiffusers.pipelines.stable_diffusion.safety_checker import (
34
+ StableDiffusionSafetyChecker,
35
+ )
36
+ from ppdiffusers.schedulers import KarrasDiffusionSchedulers
37
+ from ppdiffusers.utils import (
38
+ PIL_INTERPOLATION,
39
+ logging,
40
+ randn_tensor,
41
+ safetensors_load,
42
+ torch_load,
43
+ )
44
+
45
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
46
+
47
+
48
+ class WebUIStableDiffusionControlNetPipeline(DiffusionPipeline):
49
+ r"""
50
+ Pipeline for text-to-image generation using Stable Diffusion with ControlNet guidance.
51
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
52
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
53
+ Args:
54
+ vae ([`AutoencoderKL`]):
55
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
56
+ text_encoder ([`CLIPTextModel`]):
57
+ Frozen text-encoder. Stable Diffusion uses the text portion of
58
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
59
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
60
+ tokenizer (`CLIPTokenizer`):
61
+ Tokenizer of class
62
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
63
+ unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
64
+ controlnet ([`ControlNetModel`]):
65
+ Provides additional conditioning to the unet during the denoising process.
66
+ scheduler ([`SchedulerMixin`]):
67
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
68
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
69
+ safety_checker ([`StableDiffusionSafetyChecker`]):
70
+ Classification module that estimates whether generated images could be considered offensive or harmful.
71
+ Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
72
+ feature_extractor ([`CLIPFeatureExtractor`]):
73
+ Model that extracts features from generated images to be used as inputs for the `safety_checker`.
74
+ """
75
+ _optional_components = ["safety_checker", "feature_extractor"]
76
+ enable_emphasis = True
77
+ comma_padding_backtrack = 20
78
+
79
+ def __init__(
80
+ self,
81
+ vae: AutoencoderKL,
82
+ text_encoder: CLIPTextModel,
83
+ tokenizer: CLIPTokenizer,
84
+ unet: UNet2DConditionModel,
85
+ controlnet: ControlNetModel,
86
+ scheduler: KarrasDiffusionSchedulers,
87
+ safety_checker: StableDiffusionSafetyChecker,
88
+ feature_extractor: CLIPFeatureExtractor,
89
+ requires_safety_checker: bool = True,
90
+ ):
91
+ super().__init__()
92
+
93
+ if safety_checker is None and requires_safety_checker:
94
+ logger.warning(
95
+ f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
96
+ " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
97
+ " results in services or applications open to the public. PaddleNLP team, diffusers team and Hugging Face"
98
+ " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
99
+ " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
100
+ " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
101
+ )
102
+
103
+ if safety_checker is not None and feature_extractor is None:
104
+ raise ValueError(
105
+ f"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
106
+ " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
107
+ )
108
+
109
+ self.register_modules(
110
+ vae=vae,
111
+ text_encoder=text_encoder,
112
+ tokenizer=tokenizer,
113
+ unet=unet,
114
+ controlnet=controlnet,
115
+ scheduler=scheduler,
116
+ safety_checker=safety_checker,
117
+ feature_extractor=feature_extractor,
118
+ )
119
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
120
+ self.register_to_config(requires_safety_checker=requires_safety_checker)
121
+
122
+ # custom data
123
+ clip_model = FrozenCLIPEmbedder(text_encoder, tokenizer)
124
+ self.sj = StableDiffusionModelHijack(clip_model)
125
+ self.orginal_scheduler_config = self.scheduler.config
126
+ self.supported_scheduler = [
127
+ "pndm",
128
+ "lms",
129
+ "euler",
130
+ "euler-ancestral",
131
+ "dpm-multi",
132
+ "dpm-single",
133
+ "unipc-multi",
134
+ "ddim",
135
+ "ddpm",
136
+ "deis-multi",
137
+ "heun",
138
+ "kdpm2-ancestral",
139
+ "kdpm2",
140
+ ]
141
+
142
+ def add_ti_embedding_dir(self, embeddings_dir):
143
+ self.sj.embedding_db.add_embedding_dir(embeddings_dir)
144
+ self.sj.embedding_db.load_textual_inversion_embeddings()
145
+
146
+ def clear_ti_embedding(self):
147
+ self.sj.embedding_db.clear_embedding_dirs()
148
+ self.sj.embedding_db.load_textual_inversion_embeddings(True)
149
+
150
+ def switch_scheduler(self, scheduler_type="ddim"):
151
+ scheduler_type = scheduler_type.lower()
152
+ from ppdiffusers import (
153
+ DDIMScheduler,
154
+ DDPMScheduler,
155
+ DEISMultistepScheduler,
156
+ DPMSolverMultistepScheduler,
157
+ DPMSolverSinglestepScheduler,
158
+ EulerAncestralDiscreteScheduler,
159
+ EulerDiscreteScheduler,
160
+ HeunDiscreteScheduler,
161
+ KDPM2AncestralDiscreteScheduler,
162
+ KDPM2DiscreteScheduler,
163
+ LMSDiscreteScheduler,
164
+ PNDMScheduler,
165
+ UniPCMultistepScheduler,
166
+ )
167
+
168
+ if scheduler_type == "pndm":
169
+ scheduler = PNDMScheduler.from_config(self.orginal_scheduler_config, skip_prk_steps=True)
170
+ elif scheduler_type == "lms":
171
+ scheduler = LMSDiscreteScheduler.from_config(self.orginal_scheduler_config)
172
+ elif scheduler_type == "heun":
173
+ scheduler = HeunDiscreteScheduler.from_config(self.orginal_scheduler_config)
174
+ elif scheduler_type == "euler":
175
+ scheduler = EulerDiscreteScheduler.from_config(self.orginal_scheduler_config)
176
+ elif scheduler_type == "euler-ancestral":
177
+ scheduler = EulerAncestralDiscreteScheduler.from_config(self.orginal_scheduler_config)
178
+ elif scheduler_type == "dpm-multi":
179
+ scheduler = DPMSolverMultistepScheduler.from_config(self.orginal_scheduler_config)
180
+ elif scheduler_type == "dpm-single":
181
+ scheduler = DPMSolverSinglestepScheduler.from_config(self.orginal_scheduler_config)
182
+ elif scheduler_type == "kdpm2-ancestral":
183
+ scheduler = KDPM2AncestralDiscreteScheduler.from_config(self.orginal_scheduler_config)
184
+ elif scheduler_type == "kdpm2":
185
+ scheduler = KDPM2DiscreteScheduler.from_config(self.orginal_scheduler_config)
186
+ elif scheduler_type == "unipc-multi":
187
+ scheduler = UniPCMultistepScheduler.from_config(self.orginal_scheduler_config)
188
+ elif scheduler_type == "ddim":
189
+ scheduler = DDIMScheduler.from_config(
190
+ self.orginal_scheduler_config,
191
+ steps_offset=1,
192
+ clip_sample=False,
193
+ set_alpha_to_one=False,
194
+ )
195
+ elif scheduler_type == "ddpm":
196
+ scheduler = DDPMScheduler.from_config(
197
+ self.orginal_scheduler_config,
198
+ )
199
+ elif scheduler_type == "deis-multi":
200
+ scheduler = DEISMultistepScheduler.from_config(
201
+ self.orginal_scheduler_config,
202
+ )
203
+ else:
204
+ raise ValueError(
205
+ f"Scheduler of type {scheduler_type} doesn't exist! Please choose in {self.supported_scheduler}!"
206
+ )
207
+ self.scheduler = scheduler
208
+
209
+ @paddle.no_grad()
210
+ def _encode_prompt(
211
+ self,
212
+ prompt: str,
213
+ do_classifier_free_guidance: float = 7.5,
214
+ negative_prompt: str = None,
215
+ num_inference_steps: int = 50,
216
+ ):
217
+ if do_classifier_free_guidance:
218
+ assert isinstance(negative_prompt, str)
219
+ negative_prompt = [negative_prompt]
220
+ uc = get_learned_conditioning(self.sj.clip, negative_prompt, num_inference_steps)
221
+ else:
222
+ uc = None
223
+
224
+ c = get_multicond_learned_conditioning(self.sj.clip, prompt, num_inference_steps)
225
+ return c, uc
226
+
227
+ def run_safety_checker(self, image, dtype):
228
+ if self.safety_checker is not None:
229
+ safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pd")
230
+ image, has_nsfw_concept = self.safety_checker(
231
+ images=image, clip_input=safety_checker_input.pixel_values.cast(dtype)
232
+ )
233
+ else:
234
+ has_nsfw_concept = None
235
+ return image, has_nsfw_concept
236
+
237
+ def decode_latents(self, latents):
238
+ latents = 1 / self.vae.config.scaling_factor * latents
239
+ image = self.vae.decode(latents).sample
240
+ image = (image / 2 + 0.5).clip(0, 1)
241
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
242
+ image = image.transpose([0, 2, 3, 1]).cast("float32").numpy()
243
+ return image
244
+
245
+ def prepare_extra_step_kwargs(self, generator, eta):
246
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
247
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
248
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
249
+ # and should be between [0, 1]
250
+
251
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
252
+ extra_step_kwargs = {}
253
+ if accepts_eta:
254
+ extra_step_kwargs["eta"] = eta
255
+
256
+ # check if the scheduler accepts generator
257
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
258
+ if accepts_generator:
259
+ extra_step_kwargs["generator"] = generator
260
+ return extra_step_kwargs
261
+
262
+ def check_inputs(
263
+ self,
264
+ prompt,
265
+ image,
266
+ height,
267
+ width,
268
+ callback_steps,
269
+ negative_prompt=None,
270
+ controlnet_conditioning_scale=1.0,
271
+ ):
272
+ if height % 8 != 0 or width % 8 != 0:
273
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
274
+
275
+ if (callback_steps is None) or (
276
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
277
+ ):
278
+ raise ValueError(
279
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
280
+ f" {type(callback_steps)}."
281
+ )
282
+
283
+ if prompt is not None and not isinstance(prompt, str):
284
+ raise ValueError(f"`prompt` has to be of type `str` but is {type(prompt)}")
285
+
286
+ if negative_prompt is not None and not isinstance(negative_prompt, str):
287
+ raise ValueError(f"`negative_prompt` has to be of type `str` but is {type(negative_prompt)}")
288
+
289
+ # Check `image`
290
+
291
+ if isinstance(self.controlnet, ControlNetModel):
292
+ self.check_image(image, prompt)
293
+ else:
294
+ assert False
295
+
296
+ # Check `controlnet_conditioning_scale`
297
+ if isinstance(self.controlnet, ControlNetModel):
298
+ if not isinstance(controlnet_conditioning_scale, (float, list, tuple)):
299
+ raise TypeError(
300
+ "For single controlnet: `controlnet_conditioning_scale` must be type `float, list(float) or tuple(float)`."
301
+ )
302
+
303
+ def check_image(self, image, prompt):
304
+ image_is_pil = isinstance(image, PIL.Image.Image)
305
+ image_is_tensor = isinstance(image, paddle.Tensor)
306
+ image_is_pil_list = isinstance(image, list) and isinstance(image[0], PIL.Image.Image)
307
+ image_is_tensor_list = isinstance(image, list) and isinstance(image[0], paddle.Tensor)
308
+
309
+ if not image_is_pil and not image_is_tensor and not image_is_pil_list and not image_is_tensor_list:
310
+ raise TypeError(
311
+ "image must be one of PIL image, paddle tensor, list of PIL images, or list of paddle tensors"
312
+ )
313
+
314
+ if image_is_pil:
315
+ image_batch_size = 1
316
+ elif image_is_tensor:
317
+ image_batch_size = image.shape[0]
318
+ elif image_is_pil_list:
319
+ image_batch_size = len(image)
320
+ elif image_is_tensor_list:
321
+ image_batch_size = len(image)
322
+
323
+ if prompt is not None and isinstance(prompt, str):
324
+ prompt_batch_size = 1
325
+ elif prompt is not None and isinstance(prompt, list):
326
+ prompt_batch_size = len(prompt)
327
+
328
+ if image_batch_size != 1 and image_batch_size != prompt_batch_size:
329
+ raise ValueError(
330
+ f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}"
331
+ )
332
+
333
+ def prepare_image(self, image, width, height, dtype):
334
+ if not isinstance(image, paddle.Tensor):
335
+ if isinstance(image, PIL.Image.Image):
336
+ image = [image]
337
+
338
+ if isinstance(image[0], PIL.Image.Image):
339
+ images = []
340
+ for image_ in image:
341
+ image_ = image_.convert("RGB")
342
+ image_ = image_.resize((width, height), resample=PIL_INTERPOLATION["lanczos"])
343
+ image_ = np.array(image_)
344
+ image_ = image_[None, :]
345
+ images.append(image_)
346
+
347
+ image = np.concatenate(images, axis=0)
348
+ image = np.array(image).astype(np.float32) / 255.0
349
+ image = image.transpose(0, 3, 1, 2)
350
+ image = paddle.to_tensor(image)
351
+ elif isinstance(image[0], paddle.Tensor):
352
+ image = paddle.concat(image, axis=0)
353
+
354
+ image = image.cast(dtype)
355
+ return image
356
+
357
+ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, generator, latents=None):
358
+ shape = [batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor]
359
+ if isinstance(generator, list) and len(generator) != batch_size:
360
+ raise ValueError(
361
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
362
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
363
+ )
364
+
365
+ if latents is None:
366
+ latents = randn_tensor(shape, generator=generator, dtype=dtype)
367
+
368
+ # scale the initial noise by the standard deviation required by the scheduler
369
+ latents = latents * self.scheduler.init_noise_sigma
370
+ return latents
371
+
372
+ def _default_height_width(self, height, width, image):
373
+ while isinstance(image, list):
374
+ image = image[0]
375
+
376
+ if height is None:
377
+ if isinstance(image, PIL.Image.Image):
378
+ height = image.height
379
+ elif isinstance(image, paddle.Tensor):
380
+ height = image.shape[3]
381
+
382
+ height = (height // 8) * 8 # round down to nearest multiple of 8
383
+
384
+ if width is None:
385
+ if isinstance(image, PIL.Image.Image):
386
+ width = image.width
387
+ elif isinstance(image, paddle.Tensor):
388
+ width = image.shape[2]
389
+
390
+ width = (width // 8) * 8 # round down to nearest multiple of 8
391
+
392
+ return height, width
393
+
394
+ @paddle.no_grad()
395
+ def __call__(
396
+ self,
397
+ prompt: str = None,
398
+ image: PIL.Image.Image = None,
399
+ height: Optional[int] = None,
400
+ width: Optional[int] = None,
401
+ num_inference_steps: int = 50,
402
+ guidance_scale: float = 7.5,
403
+ negative_prompt: str = None,
404
+ eta: float = 0.0,
405
+ generator: Optional[Union[paddle.Generator, List[paddle.Generator]]] = None,
406
+ latents: Optional[paddle.Tensor] = None,
407
+ output_type: Optional[str] = "pil",
408
+ return_dict: bool = True,
409
+ callback: Optional[Callable[[int, int, paddle.Tensor], None]] = None,
410
+ callback_steps: Optional[int] = 1,
411
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
412
+ clip_skip: int = 0,
413
+ controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
414
+ ):
415
+ r"""
416
+ Function invoked when calling the pipeline for generation.
417
+
418
+ Args:
419
+ prompt (`str`, *optional*):
420
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
421
+ instead.
422
+ image (`paddle.Tensor`, `PIL.Image.Image`):
423
+ The ControlNet input condition. ControlNet uses this input condition to generate guidance to Unet. If
424
+ the type is specified as `paddle.Tensor`, it is passed to ControlNet as is. `PIL.Image.Image` can
425
+ also be accepted as an image. The dimensions of the output image defaults to `image`'s dimensions. If
426
+ height and/or width are passed, `image` is resized according to them. If multiple ControlNets are
427
+ specified in init, images must be passed as a list such that each element of the list can be correctly
428
+ batched for input to a single controlnet.
429
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
430
+ The height in pixels of the generated image.
431
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
432
+ The width in pixels of the generated image.
433
+ num_inference_steps (`int`, *optional*, defaults to 50):
434
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
435
+ expense of slower inference.
436
+ guidance_scale (`float`, *optional*, defaults to 7.5):
437
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
438
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
439
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
440
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
441
+ usually at the expense of lower image quality.
442
+ negative_prompt (`str`, *optional*):
443
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
444
+ `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead.
445
+ Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`).
446
+ eta (`float`, *optional*, defaults to 0.0):
447
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
448
+ [`schedulers.DDIMScheduler`], will be ignored for others.
449
+ generator (`paddle.Generator` or `List[paddle.Generator]`, *optional*):
450
+ One or a list of paddle generator(s) to make generation deterministic.
451
+ latents (`paddle.Tensor`, *optional*):
452
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
453
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
454
+ tensor will ge generated by sampling using the supplied random `generator`.
455
+ output_type (`str`, *optional*, defaults to `"pil"`):
456
+ The output format of the generate image. Choose between
457
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
458
+ return_dict (`bool`, *optional*, defaults to `True`):
459
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
460
+ plain tuple.
461
+ callback (`Callable`, *optional*):
462
+ A function that will be called every `callback_steps` steps during inference. The function will be
463
+ called with the following arguments: `callback(step: int, timestep: int, latents: paddle.Tensor)`.
464
+ callback_steps (`int`, *optional*, defaults to 1):
465
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
466
+ called at every step.
467
+ cross_attention_kwargs (`dict`, *optional*):
468
+ A kwargs dictionary that if specified is passed along to the `AttnProcessor` as defined under
469
+ `self.processor` in
470
+ [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
471
+ clip_skip (`int`, *optional*, defaults to 0):
472
+ CLIP_stop_at_last_layers, if clip_skip < 1, we will use the last_hidden_state from text_encoder.
473
+ controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0):
474
+ The outputs of the controlnet are multiplied by `controlnet_conditioning_scale` before they are added
475
+ to the residual in the original unet. If multiple ControlNets are specified in init, you can set the
476
+ corresponding scale as a list.
477
+ Examples:
478
+
479
+ Returns:
480
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
481
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple.
482
+ When returning a tuple, the first element is a list with the generated images, and the second element is a
483
+ list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
484
+ (nsfw) content, according to the `safety_checker`.
485
+ """
486
+ # 0. Default height and width to unet
487
+ height, width = self._default_height_width(height, width, image)
488
+
489
+ # 1. Check inputs. Raise error if not correct
490
+ self.check_inputs(
491
+ prompt,
492
+ image,
493
+ height,
494
+ width,
495
+ callback_steps,
496
+ negative_prompt,
497
+ controlnet_conditioning_scale,
498
+ )
499
+
500
+ batch_size = 1
501
+
502
+ image = self.prepare_image(
503
+ image=image,
504
+ width=width,
505
+ height=height,
506
+ dtype=self.controlnet.dtype,
507
+ )
508
+
509
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
510
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
511
+ # corresponds to doing no classifier free guidance.
512
+ do_classifier_free_guidance = guidance_scale > 1.0
513
+
514
+ prompts, extra_network_data = parse_prompts([prompt])
515
+
516
+ self.sj.clip.CLIP_stop_at_last_layers = clip_skip
517
+ # 3. Encode input prompt
518
+ prompt_embeds, negative_prompt_embeds = self._encode_prompt(
519
+ prompts,
520
+ do_classifier_free_guidance,
521
+ negative_prompt,
522
+ num_inference_steps=num_inference_steps,
523
+ )
524
+
525
+ # 4. Prepare timesteps
526
+ self.scheduler.set_timesteps(num_inference_steps)
527
+ timesteps = self.scheduler.timesteps
528
+
529
+ # 5. Prepare latent variables
530
+ num_channels_latents = self.unet.in_channels
531
+ latents = self.prepare_latents(
532
+ batch_size,
533
+ num_channels_latents,
534
+ height,
535
+ width,
536
+ self.unet.dtype,
537
+ generator,
538
+ latents,
539
+ )
540
+
541
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
542
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
543
+
544
+ # 7. Denoising loop
545
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
546
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
547
+ for i, t in enumerate(timesteps):
548
+ step = i // self.scheduler.order
549
+ do_batch = False
550
+ conds_list, cond_tensor = reconstruct_multicond_batch(prompt_embeds, step)
551
+ try:
552
+ weight = conds_list[0][0][1]
553
+ except Exception:
554
+ weight = 1.0
555
+ if do_classifier_free_guidance:
556
+ uncond_tensor = reconstruct_cond_batch(negative_prompt_embeds, step)
557
+ do_batch = cond_tensor.shape[1] == uncond_tensor.shape[1]
558
+
559
+ # expand the latents if we are doing classifier free guidance
560
+ latent_model_input = paddle.concat([latents] * 2) if do_batch else latents
561
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
562
+
563
+ if do_batch:
564
+ encoder_hidden_states = paddle.concat([uncond_tensor, cond_tensor])
565
+ down_block_res_samples, mid_block_res_sample = self.controlnet(
566
+ latent_model_input,
567
+ t,
568
+ encoder_hidden_states=encoder_hidden_states,
569
+ controlnet_cond=paddle.concat([image, image]),
570
+ conditioning_scale=controlnet_conditioning_scale,
571
+ return_dict=False,
572
+ )
573
+ noise_pred = self.unet(
574
+ latent_model_input,
575
+ t,
576
+ encoder_hidden_states=encoder_hidden_states,
577
+ cross_attention_kwargs=cross_attention_kwargs,
578
+ down_block_additional_residuals=down_block_res_samples,
579
+ mid_block_additional_residual=mid_block_res_sample,
580
+ ).sample
581
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
582
+ noise_pred = noise_pred_uncond + weight * guidance_scale * (noise_pred_text - noise_pred_uncond)
583
+ else:
584
+ down_block_res_samples, mid_block_res_sample = self.controlnet(
585
+ latent_model_input,
586
+ t,
587
+ encoder_hidden_states=cond_tensor,
588
+ controlnet_cond=image,
589
+ conditioning_scale=controlnet_conditioning_scale,
590
+ return_dict=False,
591
+ )
592
+ noise_pred = self.unet(
593
+ latent_model_input,
594
+ t,
595
+ encoder_hidden_states=cond_tensor,
596
+ cross_attention_kwargs=cross_attention_kwargs,
597
+ down_block_additional_residuals=down_block_res_samples,
598
+ mid_block_additional_residual=mid_block_res_sample,
599
+ ).sample
600
+
601
+ if do_classifier_free_guidance:
602
+ down_block_res_samples, mid_block_res_sample = self.controlnet(
603
+ latent_model_input,
604
+ t,
605
+ encoder_hidden_states=uncond_tensor,
606
+ controlnet_cond=image,
607
+ conditioning_scale=controlnet_conditioning_scale,
608
+ return_dict=False,
609
+ )
610
+ noise_pred_uncond = self.unet(
611
+ latent_model_input,
612
+ t,
613
+ encoder_hidden_states=uncond_tensor,
614
+ cross_attention_kwargs=cross_attention_kwargs,
615
+ down_block_additional_residuals=down_block_res_samples,
616
+ mid_block_additional_residual=mid_block_res_sample,
617
+ ).sample
618
+ noise_pred = noise_pred_uncond + weight * guidance_scale * (noise_pred - noise_pred_uncond)
619
+
620
+ # compute the previous noisy sample x_t -> x_t-1
621
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
622
+
623
+ # call the callback, if provided
624
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
625
+ progress_bar.update()
626
+ if callback is not None and i % callback_steps == 0:
627
+ callback(i, t, latents)
628
+
629
+ if output_type == "latent":
630
+ image = latents
631
+ has_nsfw_concept = None
632
+ elif output_type == "pil":
633
+ # 8. Post-processing
634
+ image = self.decode_latents(latents)
635
+
636
+ # 9. Run safety checker
637
+ image, has_nsfw_concept = self.run_safety_checker(image, self.unet.dtype)
638
+
639
+ # 10. Convert to PIL
640
+ image = self.numpy_to_pil(image)
641
+ else:
642
+ # 8. Post-processing
643
+ image = self.decode_latents(latents)
644
+
645
+ # 9. Run safety checker
646
+ image, has_nsfw_concept = self.run_safety_checker(image, self.unet.dtype)
647
+
648
+ if not return_dict:
649
+ return (image, has_nsfw_concept)
650
+
651
+ return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
652
+
653
+
654
+ # clip.py
655
+ import math
656
+ from collections import namedtuple
657
+
658
+
659
+ class PromptChunk:
660
+ """
661
+ This object contains token ids, weight (multipliers:1.4) and textual inversion embedding info for a chunk of prompt.
662
+ If a prompt is short, it is represented by one PromptChunk, otherwise, multiple are necessary.
663
+ Each PromptChunk contains an exact amount of tokens - 77, which includes one for start and end token,
664
+ so just 75 tokens from prompt.
665
+ """
666
+
667
+ def __init__(self):
668
+ self.tokens = []
669
+ self.multipliers = []
670
+ self.fixes = []
671
+
672
+
673
+ PromptChunkFix = namedtuple("PromptChunkFix", ["offset", "embedding"])
674
+ """An object of this type is a marker showing that textual inversion embedding's vectors have to placed at offset in the prompt
675
+ chunk. Thos objects are found in PromptChunk.fixes and, are placed into FrozenCLIPEmbedderWithCustomWordsBase.hijack.fixes, and finally
676
+ are applied by sd_hijack.EmbeddingsWithFixes's forward function."""
677
+
678
+
679
+ class FrozenCLIPEmbedder(nn.Layer):
680
+ """Uses the CLIP transformer encoder for text (from huggingface)"""
681
+
682
+ LAYERS = ["last", "pooled", "hidden"]
683
+
684
+ def __init__(self, text_encoder, tokenizer, freeze=True, layer="last", layer_idx=None):
685
+ super().__init__()
686
+ assert layer in self.LAYERS
687
+ self.tokenizer = tokenizer
688
+ self.text_encoder = text_encoder
689
+ if freeze:
690
+ self.freeze()
691
+ self.layer = layer
692
+ self.layer_idx = layer_idx
693
+ if layer == "hidden":
694
+ assert layer_idx is not None
695
+ assert 0 <= abs(layer_idx) <= 12
696
+
697
+ def freeze(self):
698
+ self.text_encoder.eval()
699
+ for param in self.parameters():
700
+ param.stop_gradient = False
701
+
702
+ def forward(self, text):
703
+ batch_encoding = self.tokenizer(
704
+ text,
705
+ truncation=True,
706
+ max_length=self.tokenizer.model_max_length,
707
+ padding="max_length",
708
+ return_tensors="pd",
709
+ )
710
+ tokens = batch_encoding["input_ids"]
711
+ outputs = self.text_encoder(input_ids=tokens, output_hidden_states=self.layer == "hidden", return_dict=True)
712
+ if self.layer == "last":
713
+ z = outputs.last_hidden_state
714
+ elif self.layer == "pooled":
715
+ z = outputs.pooler_output[:, None, :]
716
+ else:
717
+ z = outputs.hidden_states[self.layer_idx]
718
+ return z
719
+
720
+ def encode(self, text):
721
+ return self(text)
722
+
723
+
724
+ class FrozenCLIPEmbedderWithCustomWordsBase(nn.Layer):
725
+ """A pytorch module that is a wrapper for FrozenCLIPEmbedder module. it enhances FrozenCLIPEmbedder, making it possible to
726
+ have unlimited prompt length and assign weights to tokens in prompt.
727
+ """
728
+
729
+ def __init__(self, wrapped, hijack):
730
+ super().__init__()
731
+
732
+ self.wrapped = wrapped
733
+ """Original FrozenCLIPEmbedder module; can also be FrozenOpenCLIPEmbedder or xlmr.BertSeriesModelWithTransformation,
734
+ depending on model."""
735
+
736
+ self.hijack = hijack
737
+ self.chunk_length = 75
738
+
739
+ def empty_chunk(self):
740
+ """creates an empty PromptChunk and returns it"""
741
+
742
+ chunk = PromptChunk()
743
+ chunk.tokens = [self.id_start] + [self.id_end] * (self.chunk_length + 1)
744
+ chunk.multipliers = [1.0] * (self.chunk_length + 2)
745
+ return chunk
746
+
747
+ def get_target_prompt_token_count(self, token_count):
748
+ """returns the maximum number of tokens a prompt of a known length can have before it requires one more PromptChunk to be represented"""
749
+
750
+ return math.ceil(max(token_count, 1) / self.chunk_length) * self.chunk_length
751
+
752
+ def tokenize(self, texts):
753
+ """Converts a batch of texts into a batch of token ids"""
754
+
755
+ raise NotImplementedError
756
+
757
+ def encode_with_text_encoder(self, tokens):
758
+ """
759
+ converts a batch of token ids (in python lists) into a single tensor with numeric respresentation of those tokens;
760
+ All python lists with tokens are assumed to have same length, usually 77.
761
+ if input is a list with B elements and each element has T tokens, expected output shape is (B, T, C), where C depends on
762
+ model - can be 768 and 1024.
763
+ Among other things, this call will read self.hijack.fixes, apply it to its inputs, and clear it (setting it to None).
764
+ """
765
+
766
+ raise NotImplementedError
767
+
768
+ def encode_embedding_init_text(self, init_text, nvpt):
769
+ """Converts text into a tensor with this text's tokens' embeddings. Note that those are embeddings before they are passed through
770
+ transformers. nvpt is used as a maximum length in tokens. If text produces less teokens than nvpt, only this many is returned."""
771
+
772
+ raise NotImplementedError
773
+
774
+ def tokenize_line(self, line):
775
+ """
776
+ this transforms a single prompt into a list of PromptChunk objects - as many as needed to
777
+ represent the prompt.
778
+ Returns the list and the total number of tokens in the prompt.
779
+ """
780
+
781
+ if WebUIStableDiffusionControlNetPipeline.enable_emphasis:
782
+ parsed = parse_prompt_attention(line)
783
+ else:
784
+ parsed = [[line, 1.0]]
785
+
786
+ tokenized = self.tokenize([text for text, _ in parsed])
787
+
788
+ chunks = []
789
+ chunk = PromptChunk()
790
+ token_count = 0
791
+ last_comma = -1
792
+
793
+ def next_chunk(is_last=False):
794
+ """puts current chunk into the list of results and produces the next one - empty;
795
+ if is_last is true, tokens <end-of-text> tokens at the end won't add to token_count"""
796
+ nonlocal token_count
797
+ nonlocal last_comma
798
+ nonlocal chunk
799
+
800
+ if is_last:
801
+ token_count += len(chunk.tokens)
802
+ else:
803
+ token_count += self.chunk_length
804
+
805
+ to_add = self.chunk_length - len(chunk.tokens)
806
+ if to_add > 0:
807
+ chunk.tokens += [self.id_end] * to_add
808
+ chunk.multipliers += [1.0] * to_add
809
+
810
+ chunk.tokens = [self.id_start] + chunk.tokens + [self.id_end]
811
+ chunk.multipliers = [1.0] + chunk.multipliers + [1.0]
812
+
813
+ last_comma = -1
814
+ chunks.append(chunk)
815
+ chunk = PromptChunk()
816
+
817
+ for tokens, (text, weight) in zip(tokenized, parsed):
818
+ if text == "BREAK" and weight == -1:
819
+ next_chunk()
820
+ continue
821
+
822
+ position = 0
823
+ while position < len(tokens):
824
+ token = tokens[position]
825
+
826
+ if token == self.comma_token:
827
+ last_comma = len(chunk.tokens)
828
+
829
+ # this is when we are at the end of alloted 75 tokens for the current chunk, and the current token is not a comma. opts.comma_padding_backtrack
830
+ # is a setting that specifies that if there is a comma nearby, the text after the comma should be moved out of this chunk and into the next.
831
+ elif (
832
+ WebUIStableDiffusionControlNetPipeline.comma_padding_backtrack != 0
833
+ and len(chunk.tokens) == self.chunk_length
834
+ and last_comma != -1
835
+ and len(chunk.tokens) - last_comma
836
+ <= WebUIStableDiffusionControlNetPipeline.comma_padding_backtrack
837
+ ):
838
+ break_location = last_comma + 1
839
+
840
+ reloc_tokens = chunk.tokens[break_location:]
841
+ reloc_mults = chunk.multipliers[break_location:]
842
+
843
+ chunk.tokens = chunk.tokens[:break_location]
844
+ chunk.multipliers = chunk.multipliers[:break_location]
845
+
846
+ next_chunk()
847
+ chunk.tokens = reloc_tokens
848
+ chunk.multipliers = reloc_mults
849
+
850
+ if len(chunk.tokens) == self.chunk_length:
851
+ next_chunk()
852
+
853
+ embedding, embedding_length_in_tokens = self.hijack.embedding_db.find_embedding_at_position(
854
+ tokens, position
855
+ )
856
+ if embedding is None:
857
+ chunk.tokens.append(token)
858
+ chunk.multipliers.append(weight)
859
+ position += 1
860
+ continue
861
+
862
+ emb_len = int(embedding.vec.shape[0])
863
+ if len(chunk.tokens) + emb_len > self.chunk_length:
864
+ next_chunk()
865
+
866
+ chunk.fixes.append(PromptChunkFix(len(chunk.tokens), embedding))
867
+
868
+ chunk.tokens += [0] * emb_len
869
+ chunk.multipliers += [weight] * emb_len
870
+ position += embedding_length_in_tokens
871
+
872
+ if len(chunk.tokens) > 0 or len(chunks) == 0:
873
+ next_chunk(is_last=True)
874
+
875
+ return chunks, token_count
876
+
877
+ def process_texts(self, texts):
878
+ """
879
+ Accepts a list of texts and calls tokenize_line() on each, with cache. Returns the list of results and maximum
880
+ length, in tokens, of all texts.
881
+ """
882
+
883
+ token_count = 0
884
+
885
+ cache = {}
886
+ batch_chunks = []
887
+ for line in texts:
888
+ if line in cache:
889
+ chunks = cache[line]
890
+ else:
891
+ chunks, current_token_count = self.tokenize_line(line)
892
+ token_count = max(current_token_count, token_count)
893
+
894
+ cache[line] = chunks
895
+
896
+ batch_chunks.append(chunks)
897
+
898
+ return batch_chunks, token_count
899
+
900
+ def forward(self, texts):
901
+ """
902
+ Accepts an array of texts; Passes texts through transformers network to create a tensor with numerical representation of those texts.
903
+ Returns a tensor with shape of (B, T, C), where B is length of the array; T is length, in tokens, of texts (including padding) - T will
904
+ be a multiple of 77; and C is dimensionality of each token - for SD1 it's 768, and for SD2 it's 1024.
905
+ An example shape returned by this function can be: (2, 77, 768).
906
+ Webui usually sends just one text at a time through this function - the only time when texts is an array with more than one elemenet
907
+ is when you do prompt editing: "a picture of a [cat:dog:0.4] eating ice cream"
908
+ """
909
+
910
+ batch_chunks, token_count = self.process_texts(texts)
911
+
912
+ used_embeddings = {}
913
+ chunk_count = max([len(x) for x in batch_chunks])
914
+
915
+ zs = []
916
+ for i in range(chunk_count):
917
+ batch_chunk = [chunks[i] if i < len(chunks) else self.empty_chunk() for chunks in batch_chunks]
918
+
919
+ tokens = [x.tokens for x in batch_chunk]
920
+ multipliers = [x.multipliers for x in batch_chunk]
921
+ self.hijack.fixes = [x.fixes for x in batch_chunk]
922
+
923
+ for fixes in self.hijack.fixes:
924
+ for position, embedding in fixes:
925
+ used_embeddings[embedding.name] = embedding
926
+
927
+ z = self.process_tokens(tokens, multipliers)
928
+ zs.append(z)
929
+
930
+ if len(used_embeddings) > 0:
931
+ embeddings_list = ", ".join(
932
+ [f"{name} [{embedding.checksum()}]" for name, embedding in used_embeddings.items()]
933
+ )
934
+ self.hijack.comments.append(f"Used embeddings: {embeddings_list}")
935
+
936
+ return paddle.concat(zs, axis=1)
937
+
938
+ def process_tokens(self, remade_batch_tokens, batch_multipliers):
939
+ """
940
+ sends one single prompt chunk to be encoded by transformers neural network.
941
+ remade_batch_tokens is a batch of tokens - a list, where every element is a list of tokens; usually
942
+ there are exactly 77 tokens in the list. batch_multipliers is the same but for multipliers instead of tokens.
943
+ Multipliers are used to give more or less weight to the outputs of transformers network. Each multiplier
944
+ corresponds to one token.
945
+ """
946
+ tokens = paddle.to_tensor(remade_batch_tokens)
947
+
948
+ # this is for SD2: SD1 uses the same token for padding and end of text, while SD2 uses different ones.
949
+ if self.id_end != self.id_pad:
950
+ for batch_pos in range(len(remade_batch_tokens)):
951
+ index = remade_batch_tokens[batch_pos].index(self.id_end)
952
+ tokens[batch_pos, index + 1 : tokens.shape[1]] = self.id_pad
953
+
954
+ z = self.encode_with_text_encoder(tokens)
955
+
956
+ # restoring original mean is likely not correct, but it seems to work well to prevent artifacts that happen otherwise
957
+ batch_multipliers = paddle.to_tensor(batch_multipliers)
958
+ original_mean = z.mean()
959
+ z = z * batch_multipliers.reshape(
960
+ batch_multipliers.shape
961
+ + [
962
+ 1,
963
+ ]
964
+ ).expand(z.shape)
965
+ new_mean = z.mean()
966
+ z = z * (original_mean / new_mean)
967
+
968
+ return z
969
+
970
+
971
+ class FrozenCLIPEmbedderWithCustomWords(FrozenCLIPEmbedderWithCustomWordsBase):
972
+ def __init__(self, wrapped, hijack, CLIP_stop_at_last_layers=-1):
973
+ super().__init__(wrapped, hijack)
974
+ self.CLIP_stop_at_last_layers = CLIP_stop_at_last_layers
975
+ self.tokenizer = wrapped.tokenizer
976
+
977
+ vocab = self.tokenizer.get_vocab()
978
+
979
+ self.comma_token = vocab.get(",</w>", None)
980
+
981
+ self.token_mults = {}
982
+ tokens_with_parens = [(k, v) for k, v in vocab.items() if "(" in k or ")" in k or "[" in k or "]" in k]
983
+ for text, ident in tokens_with_parens:
984
+ mult = 1.0
985
+ for c in text:
986
+ if c == "[":
987
+ mult /= 1.1
988
+ if c == "]":
989
+ mult *= 1.1
990
+ if c == "(":
991
+ mult *= 1.1
992
+ if c == ")":
993
+ mult /= 1.1
994
+
995
+ if mult != 1.0:
996
+ self.token_mults[ident] = mult
997
+
998
+ self.id_start = self.wrapped.tokenizer.bos_token_id
999
+ self.id_end = self.wrapped.tokenizer.eos_token_id
1000
+ self.id_pad = self.id_end
1001
+
1002
+ def tokenize(self, texts):
1003
+ tokenized = self.wrapped.tokenizer(texts, truncation=False, add_special_tokens=False)["input_ids"]
1004
+
1005
+ return tokenized
1006
+
1007
+ def encode_with_text_encoder(self, tokens):
1008
+ output_hidden_states = self.CLIP_stop_at_last_layers > 1
1009
+ outputs = self.wrapped.text_encoder(
1010
+ input_ids=tokens, output_hidden_states=output_hidden_states, return_dict=True
1011
+ )
1012
+
1013
+ if output_hidden_states:
1014
+ z = outputs.hidden_states[-self.CLIP_stop_at_last_layers]
1015
+ z = self.wrapped.text_encoder.text_model.ln_final(z)
1016
+ else:
1017
+ z = outputs.last_hidden_state
1018
+
1019
+ return z
1020
+
1021
+ def encode_embedding_init_text(self, init_text, nvpt):
1022
+ embedding_layer = self.wrapped.text_encoder.text_model
1023
+ ids = self.wrapped.tokenizer(init_text, max_length=nvpt, return_tensors="pd", add_special_tokens=False)[
1024
+ "input_ids"
1025
+ ]
1026
+ embedded = embedding_layer.token_embedding.wrapped(ids).squeeze(0)
1027
+
1028
+ return embedded
1029
+
1030
+
1031
+ # extra_networks.py
1032
+ import re
1033
+ from collections import defaultdict
1034
+
1035
+
1036
+ class ExtraNetworkParams:
1037
+ def __init__(self, items=None):
1038
+ self.items = items or []
1039
+
1040
+
1041
+ re_extra_net = re.compile(r"<(\w+):([^>]+)>")
1042
+
1043
+
1044
+ def parse_prompt(prompt):
1045
+ res = defaultdict(list)
1046
+
1047
+ def found(m):
1048
+ name = m.group(1)
1049
+ args = m.group(2)
1050
+
1051
+ res[name].append(ExtraNetworkParams(items=args.split(":")))
1052
+
1053
+ return ""
1054
+
1055
+ prompt = re.sub(re_extra_net, found, prompt)
1056
+
1057
+ return prompt, res
1058
+
1059
+
1060
+ def parse_prompts(prompts):
1061
+ res = []
1062
+ extra_data = None
1063
+
1064
+ for prompt in prompts:
1065
+ updated_prompt, parsed_extra_data = parse_prompt(prompt)
1066
+
1067
+ if extra_data is None:
1068
+ extra_data = parsed_extra_data
1069
+
1070
+ res.append(updated_prompt)
1071
+
1072
+ return res, extra_data
1073
+
1074
+
1075
+ # image_embeddings.py
1076
+
1077
+ import base64
1078
+ import json
1079
+ import zlib
1080
+
1081
+ import numpy as np
1082
+ from PIL import Image
1083
+
1084
+
1085
+ class EmbeddingDecoder(json.JSONDecoder):
1086
+ def __init__(self, *args, **kwargs):
1087
+ json.JSONDecoder.__init__(self, object_hook=self.object_hook, *args, **kwargs)
1088
+
1089
+ def object_hook(self, d):
1090
+ if "TORCHTENSOR" in d:
1091
+ return paddle.to_tensor(np.array(d["TORCHTENSOR"]))
1092
+ return d
1093
+
1094
+
1095
+ def embedding_from_b64(data):
1096
+ d = base64.b64decode(data)
1097
+ return json.loads(d, cls=EmbeddingDecoder)
1098
+
1099
+
1100
+ def lcg(m=2**32, a=1664525, c=1013904223, seed=0):
1101
+ while True:
1102
+ seed = (a * seed + c) % m
1103
+ yield seed % 255
1104
+
1105
+
1106
+ def xor_block(block):
1107
+ g = lcg()
1108
+ randblock = np.array([next(g) for _ in range(np.product(block.shape))]).astype(np.uint8).reshape(block.shape)
1109
+ return np.bitwise_xor(block.astype(np.uint8), randblock & 0x0F)
1110
+
1111
+
1112
+ def crop_black(img, tol=0):
1113
+ mask = (img > tol).all(2)
1114
+ mask0, mask1 = mask.any(0), mask.any(1)
1115
+ col_start, col_end = mask0.argmax(), mask.shape[1] - mask0[::-1].argmax()
1116
+ row_start, row_end = mask1.argmax(), mask.shape[0] - mask1[::-1].argmax()
1117
+ return img[row_start:row_end, col_start:col_end]
1118
+
1119
+
1120
+ def extract_image_data_embed(image):
1121
+ d = 3
1122
+ outarr = (
1123
+ crop_black(np.array(image.convert("RGB").getdata()).reshape(image.size[1], image.size[0], d).astype(np.uint8))
1124
+ & 0x0F
1125
+ )
1126
+ black_cols = np.where(np.sum(outarr, axis=(0, 2)) == 0)
1127
+ if black_cols[0].shape[0] < 2:
1128
+ print("No Image data blocks found.")
1129
+ return None
1130
+
1131
+ data_block_lower = outarr[:, : black_cols[0].min(), :].astype(np.uint8)
1132
+ data_block_upper = outarr[:, black_cols[0].max() + 1 :, :].astype(np.uint8)
1133
+
1134
+ data_block_lower = xor_block(data_block_lower)
1135
+ data_block_upper = xor_block(data_block_upper)
1136
+
1137
+ data_block = (data_block_upper << 4) | (data_block_lower)
1138
+ data_block = data_block.flatten().tobytes()
1139
+
1140
+ data = zlib.decompress(data_block)
1141
+ return json.loads(data, cls=EmbeddingDecoder)
1142
+
1143
+
1144
+ # prompt_parser.py
1145
+ import re
1146
+ from collections import namedtuple
1147
+ from typing import List
1148
+
1149
+ import lark
1150
+
1151
+ # a prompt like this: "fantasy landscape with a [mountain:lake:0.25] and [an oak:a christmas tree:0.75][ in foreground::0.6][ in background:0.25] [shoddy:masterful:0.5]"
1152
+ # will be represented with prompt_schedule like this (assuming steps=100):
1153
+ # [25, 'fantasy landscape with a mountain and an oak in foreground shoddy']
1154
+ # [50, 'fantasy landscape with a lake and an oak in foreground in background shoddy']
1155
+ # [60, 'fantasy landscape with a lake and an oak in foreground in background masterful']
1156
+ # [75, 'fantasy landscape with a lake and an oak in background masterful']
1157
+ # [100, 'fantasy landscape with a lake and a christmas tree in background masterful']
1158
+
1159
+ schedule_parser = lark.Lark(
1160
+ r"""
1161
+ !start: (prompt | /[][():]/+)*
1162
+ prompt: (emphasized | scheduled | alternate | plain | WHITESPACE)*
1163
+ !emphasized: "(" prompt ")"
1164
+ | "(" prompt ":" prompt ")"
1165
+ | "[" prompt "]"
1166
+ scheduled: "[" [prompt ":"] prompt ":" [WHITESPACE] NUMBER "]"
1167
+ alternate: "[" prompt ("|" prompt)+ "]"
1168
+ WHITESPACE: /\s+/
1169
+ plain: /([^\\\[\]():|]|\\.)+/
1170
+ %import common.SIGNED_NUMBER -> NUMBER
1171
+ """
1172
+ )
1173
+
1174
+
1175
+ def get_learned_conditioning_prompt_schedules(prompts, steps):
1176
+ """
1177
+ >>> g = lambda p: get_learned_conditioning_prompt_schedules([p], 10)[0]
1178
+ >>> g("test")
1179
+ [[10, 'test']]
1180
+ >>> g("a [b:3]")
1181
+ [[3, 'a '], [10, 'a b']]
1182
+ >>> g("a [b: 3]")
1183
+ [[3, 'a '], [10, 'a b']]
1184
+ >>> g("a [[[b]]:2]")
1185
+ [[2, 'a '], [10, 'a [[b]]']]
1186
+ >>> g("[(a:2):3]")
1187
+ [[3, ''], [10, '(a:2)']]
1188
+ >>> g("a [b : c : 1] d")
1189
+ [[1, 'a b d'], [10, 'a c d']]
1190
+ >>> g("a[b:[c:d:2]:1]e")
1191
+ [[1, 'abe'], [2, 'ace'], [10, 'ade']]
1192
+ >>> g("a [unbalanced")
1193
+ [[10, 'a [unbalanced']]
1194
+ >>> g("a [b:.5] c")
1195
+ [[5, 'a c'], [10, 'a b c']]
1196
+ >>> g("a [{b|d{:.5] c") # not handling this right now
1197
+ [[5, 'a c'], [10, 'a {b|d{ c']]
1198
+ >>> g("((a][:b:c [d:3]")
1199
+ [[3, '((a][:b:c '], [10, '((a][:b:c d']]
1200
+ >>> g("[a|(b:1.1)]")
1201
+ [[1, 'a'], [2, '(b:1.1)'], [3, 'a'], [4, '(b:1.1)'], [5, 'a'], [6, '(b:1.1)'], [7, 'a'], [8, '(b:1.1)'], [9, 'a'], [10, '(b:1.1)']]
1202
+ """
1203
+
1204
+ def collect_steps(steps, tree):
1205
+ l = [steps]
1206
+
1207
+ class CollectSteps(lark.Visitor):
1208
+ def scheduled(self, tree):
1209
+ tree.children[-1] = float(tree.children[-1])
1210
+ if tree.children[-1] < 1:
1211
+ tree.children[-1] *= steps
1212
+ tree.children[-1] = min(steps, int(tree.children[-1]))
1213
+ l.append(tree.children[-1])
1214
+
1215
+ def alternate(self, tree):
1216
+ l.extend(range(1, steps + 1))
1217
+
1218
+ CollectSteps().visit(tree)
1219
+ return sorted(set(l))
1220
+
1221
+ def at_step(step, tree):
1222
+ class AtStep(lark.Transformer):
1223
+ def scheduled(self, args):
1224
+ before, after, _, when = args
1225
+ yield before or () if step <= when else after
1226
+
1227
+ def alternate(self, args):
1228
+ yield next(args[(step - 1) % len(args)])
1229
+
1230
+ def start(self, args):
1231
+ def flatten(x):
1232
+ if type(x) == str:
1233
+ yield x
1234
+ else:
1235
+ for gen in x:
1236
+ yield from flatten(gen)
1237
+
1238
+ return "".join(flatten(args))
1239
+
1240
+ def plain(self, args):
1241
+ yield args[0].value
1242
+
1243
+ def __default__(self, data, children, meta):
1244
+ for child in children:
1245
+ yield child
1246
+
1247
+ return AtStep().transform(tree)
1248
+
1249
+ def get_schedule(prompt):
1250
+ try:
1251
+ tree = schedule_parser.parse(prompt)
1252
+ except lark.exceptions.LarkError:
1253
+ if 0:
1254
+ import traceback
1255
+
1256
+ traceback.print_exc()
1257
+ return [[steps, prompt]]
1258
+ return [[t, at_step(t, tree)] for t in collect_steps(steps, tree)]
1259
+
1260
+ promptdict = {prompt: get_schedule(prompt) for prompt in set(prompts)}
1261
+ return [promptdict[prompt] for prompt in prompts]
1262
+
1263
+
1264
+ ScheduledPromptConditioning = namedtuple("ScheduledPromptConditioning", ["end_at_step", "cond"])
1265
+
1266
+
1267
+ def get_learned_conditioning(model, prompts, steps):
1268
+ """converts a list of prompts into a list of prompt schedules - each schedule is a list of ScheduledPromptConditioning, specifying the comdition (cond),
1269
+ and the sampling step at which this condition is to be replaced by the next one.
1270
+
1271
+ Input:
1272
+ (model, ['a red crown', 'a [blue:green:5] jeweled crown'], 20)
1273
+
1274
+ Output:
1275
+ [
1276
+ [
1277
+ ScheduledPromptConditioning(end_at_step=20, cond=tensor([[-0.3886, 0.0229, -0.0523, ..., -0.4901, -0.3066, 0.0674], ..., [ 0.3317, -0.5102, -0.4066, ..., 0.4119, -0.7647, -1.0160]], device='cuda:0'))
1278
+ ],
1279
+ [
1280
+ ScheduledPromptConditioning(end_at_step=5, cond=tensor([[-0.3886, 0.0229, -0.0522, ..., -0.4901, -0.3067, 0.0673], ..., [-0.0192, 0.3867, -0.4644, ..., 0.1135, -0.3696, -0.4625]], device='cuda:0')),
1281
+ ScheduledPromptConditioning(end_at_step=20, cond=tensor([[-0.3886, 0.0229, -0.0522, ..., -0.4901, -0.3067, 0.0673], ..., [-0.7352, -0.4356, -0.7888, ..., 0.6994, -0.4312, -1.2593]], device='cuda:0'))
1282
+ ]
1283
+ ]
1284
+ """
1285
+ res = []
1286
+
1287
+ prompt_schedules = get_learned_conditioning_prompt_schedules(prompts, steps)
1288
+ cache = {}
1289
+
1290
+ for prompt, prompt_schedule in zip(prompts, prompt_schedules):
1291
+
1292
+ cached = cache.get(prompt, None)
1293
+ if cached is not None:
1294
+ res.append(cached)
1295
+ continue
1296
+
1297
+ texts = [x[1] for x in prompt_schedule]
1298
+ conds = model(texts)
1299
+
1300
+ cond_schedule = []
1301
+ for i, (end_at_step, text) in enumerate(prompt_schedule):
1302
+ cond_schedule.append(ScheduledPromptConditioning(end_at_step, conds[i]))
1303
+
1304
+ cache[prompt] = cond_schedule
1305
+ res.append(cond_schedule)
1306
+
1307
+ return res
1308
+
1309
+
1310
+ re_AND = re.compile(r"\bAND\b")
1311
+ re_weight = re.compile(r"^(.*?)(?:\s*:\s*([-+]?(?:\d+\.?|\d*\.\d+)))?\s*$")
1312
+
1313
+
1314
+ def get_multicond_prompt_list(prompts):
1315
+ res_indexes = []
1316
+
1317
+ prompt_flat_list = []
1318
+ prompt_indexes = {}
1319
+
1320
+ for prompt in prompts:
1321
+ subprompts = re_AND.split(prompt)
1322
+
1323
+ indexes = []
1324
+ for subprompt in subprompts:
1325
+ match = re_weight.search(subprompt)
1326
+
1327
+ text, weight = match.groups() if match is not None else (subprompt, 1.0)
1328
+
1329
+ weight = float(weight) if weight is not None else 1.0
1330
+
1331
+ index = prompt_indexes.get(text, None)
1332
+ if index is None:
1333
+ index = len(prompt_flat_list)
1334
+ prompt_flat_list.append(text)
1335
+ prompt_indexes[text] = index
1336
+
1337
+ indexes.append((index, weight))
1338
+
1339
+ res_indexes.append(indexes)
1340
+
1341
+ return res_indexes, prompt_flat_list, prompt_indexes
1342
+
1343
+
1344
+ class ComposableScheduledPromptConditioning:
1345
+ def __init__(self, schedules, weight=1.0):
1346
+ self.schedules: List[ScheduledPromptConditioning] = schedules
1347
+ self.weight: float = weight
1348
+
1349
+
1350
+ class MulticondLearnedConditioning:
1351
+ def __init__(self, shape, batch):
1352
+ self.shape: tuple = shape # the shape field is needed to send this object to DDIM/PLMS
1353
+ self.batch: List[List[ComposableScheduledPromptConditioning]] = batch
1354
+
1355
+
1356
+ def get_multicond_learned_conditioning(model, prompts, steps) -> MulticondLearnedConditioning:
1357
+ """same as get_learned_conditioning, but returns a list of ScheduledPromptConditioning along with the weight objects for each prompt.
1358
+ For each prompt, the list is obtained by splitting the prompt using the AND separator.
1359
+
1360
+ https://energy-based-model.github.io/Compositional-Visual-Generation-with-Composable-Diffusion-Models/
1361
+ """
1362
+
1363
+ res_indexes, prompt_flat_list, prompt_indexes = get_multicond_prompt_list(prompts)
1364
+
1365
+ learned_conditioning = get_learned_conditioning(model, prompt_flat_list, steps)
1366
+
1367
+ res = []
1368
+ for indexes in res_indexes:
1369
+ res.append([ComposableScheduledPromptConditioning(learned_conditioning[i], weight) for i, weight in indexes])
1370
+
1371
+ return MulticondLearnedConditioning(shape=(len(prompts),), batch=res)
1372
+
1373
+
1374
+ def reconstruct_cond_batch(c: List[List[ScheduledPromptConditioning]], current_step):
1375
+ param = c[0][0].cond
1376
+ res = paddle.zeros(
1377
+ [
1378
+ len(c),
1379
+ ]
1380
+ + param.shape,
1381
+ dtype=param.dtype,
1382
+ )
1383
+ for i, cond_schedule in enumerate(c):
1384
+ target_index = 0
1385
+ for current, (end_at, cond) in enumerate(cond_schedule):
1386
+ if current_step <= end_at:
1387
+ target_index = current
1388
+ break
1389
+ res[i] = cond_schedule[target_index].cond
1390
+
1391
+ return res
1392
+
1393
+
1394
+ def reconstruct_multicond_batch(c: MulticondLearnedConditioning, current_step):
1395
+ param = c.batch[0][0].schedules[0].cond
1396
+
1397
+ tensors = []
1398
+ conds_list = []
1399
+
1400
+ for batch_no, composable_prompts in enumerate(c.batch):
1401
+ conds_for_batch = []
1402
+
1403
+ for cond_index, composable_prompt in enumerate(composable_prompts):
1404
+ target_index = 0
1405
+ for current, (end_at, cond) in enumerate(composable_prompt.schedules):
1406
+ if current_step <= end_at:
1407
+ target_index = current
1408
+ break
1409
+
1410
+ conds_for_batch.append((len(tensors), composable_prompt.weight))
1411
+ tensors.append(composable_prompt.schedules[target_index].cond)
1412
+
1413
+ conds_list.append(conds_for_batch)
1414
+
1415
+ # if prompts have wildly different lengths above the limit we'll get tensors fo different shapes
1416
+ # and won't be able to torch.stack them. So this fixes that.
1417
+ token_count = max([x.shape[0] for x in tensors])
1418
+ for i in range(len(tensors)):
1419
+ if tensors[i].shape[0] != token_count:
1420
+ last_vector = tensors[i][-1:]
1421
+ last_vector_repeated = last_vector.tile([token_count - tensors[i].shape[0], 1])
1422
+ tensors[i] = paddle.concat([tensors[i], last_vector_repeated], axis=0)
1423
+
1424
+ return conds_list, paddle.stack(tensors).cast(dtype=param.dtype)
1425
+
1426
+
1427
+ re_attention = re.compile(
1428
+ r"""
1429
+ \\\(|
1430
+ \\\)|
1431
+ \\\[|
1432
+ \\]|
1433
+ \\\\|
1434
+ \\|
1435
+ \(|
1436
+ \[|
1437
+ :([+-]?[.\d]+)\)|
1438
+ \)|
1439
+ ]|
1440
+ [^\\()\[\]:]+|
1441
+ :
1442
+ """,
1443
+ re.X,
1444
+ )
1445
+
1446
+ re_break = re.compile(r"\s*\bBREAK\b\s*", re.S)
1447
+
1448
+
1449
+ def parse_prompt_attention(text):
1450
+ """
1451
+ Parses a string with attention tokens and returns a list of pairs: text and its associated weight.
1452
+ Accepted tokens are:
1453
+ (abc) - increases attention to abc by a multiplier of 1.1
1454
+ (abc:3.12) - increases attention to abc by a multiplier of 3.12
1455
+ [abc] - decreases attention to abc by a multiplier of 1.1
1456
+ \( - literal character '('
1457
+ \[ - literal character '['
1458
+ \) - literal character ')'
1459
+ \] - literal character ']'
1460
+ \\ - literal character '\'
1461
+ anything else - just text
1462
+
1463
+ >>> parse_prompt_attention('normal text')
1464
+ [['normal text', 1.0]]
1465
+ >>> parse_prompt_attention('an (important) word')
1466
+ [['an ', 1.0], ['important', 1.1], [' word', 1.0]]
1467
+ >>> parse_prompt_attention('(unbalanced')
1468
+ [['unbalanced', 1.1]]
1469
+ >>> parse_prompt_attention('\(literal\]')
1470
+ [['(literal]', 1.0]]
1471
+ >>> parse_prompt_attention('(unnecessary)(parens)')
1472
+ [['unnecessaryparens', 1.1]]
1473
+ >>> parse_prompt_attention('a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).')
1474
+ [['a ', 1.0],
1475
+ ['house', 1.5730000000000004],
1476
+ [' ', 1.1],
1477
+ ['on', 1.0],
1478
+ [' a ', 1.1],
1479
+ ['hill', 0.55],
1480
+ [', sun, ', 1.1],
1481
+ ['sky', 1.4641000000000006],
1482
+ ['.', 1.1]]
1483
+ """
1484
+
1485
+ res = []
1486
+ round_brackets = []
1487
+ square_brackets = []
1488
+
1489
+ round_bracket_multiplier = 1.1
1490
+ square_bracket_multiplier = 1 / 1.1
1491
+
1492
+ def multiply_range(start_position, multiplier):
1493
+ for p in range(start_position, len(res)):
1494
+ res[p][1] *= multiplier
1495
+
1496
+ for m in re_attention.finditer(text):
1497
+ text = m.group(0)
1498
+ weight = m.group(1)
1499
+
1500
+ if text.startswith("\\"):
1501
+ res.append([text[1:], 1.0])
1502
+ elif text == "(":
1503
+ round_brackets.append(len(res))
1504
+ elif text == "[":
1505
+ square_brackets.append(len(res))
1506
+ elif weight is not None and len(round_brackets) > 0:
1507
+ multiply_range(round_brackets.pop(), float(weight))
1508
+ elif text == ")" and len(round_brackets) > 0:
1509
+ multiply_range(round_brackets.pop(), round_bracket_multiplier)
1510
+ elif text == "]" and len(square_brackets) > 0:
1511
+ multiply_range(square_brackets.pop(), square_bracket_multiplier)
1512
+ else:
1513
+ parts = re.split(re_break, text)
1514
+ for i, part in enumerate(parts):
1515
+ if i > 0:
1516
+ res.append(["BREAK", -1])
1517
+ res.append([part, 1.0])
1518
+
1519
+ for pos in round_brackets:
1520
+ multiply_range(pos, round_bracket_multiplier)
1521
+
1522
+ for pos in square_brackets:
1523
+ multiply_range(pos, square_bracket_multiplier)
1524
+
1525
+ if len(res) == 0:
1526
+ res = [["", 1.0]]
1527
+
1528
+ # merge runs of identical weights
1529
+ i = 0
1530
+ while i + 1 < len(res):
1531
+ if res[i][1] == res[i + 1][1]:
1532
+ res[i][0] += res[i + 1][0]
1533
+ res.pop(i + 1)
1534
+ else:
1535
+ i += 1
1536
+
1537
+ return res
1538
+
1539
+
1540
+ # sd_hijack.py
1541
+
1542
+
1543
+ class StableDiffusionModelHijack:
1544
+ fixes = None
1545
+ comments = []
1546
+ layers = None
1547
+ circular_enabled = False
1548
+
1549
+ def __init__(self, clip_model, embeddings_dir=None, CLIP_stop_at_last_layers=-1):
1550
+ model_embeddings = clip_model.text_encoder.text_model
1551
+ model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.token_embedding, self)
1552
+ clip_model = FrozenCLIPEmbedderWithCustomWords(
1553
+ clip_model, self, CLIP_stop_at_last_layers=CLIP_stop_at_last_layers
1554
+ )
1555
+
1556
+ self.embedding_db = EmbeddingDatabase(clip_model)
1557
+ self.embedding_db.add_embedding_dir(embeddings_dir)
1558
+
1559
+ # hack this!
1560
+ self.clip = clip_model
1561
+
1562
+ def flatten(el):
1563
+ flattened = [flatten(children) for children in el.children()]
1564
+ res = [el]
1565
+ for c in flattened:
1566
+ res += c
1567
+ return res
1568
+
1569
+ self.layers = flatten(clip_model)
1570
+
1571
+ def clear_comments(self):
1572
+ self.comments = []
1573
+
1574
+ def get_prompt_lengths(self, text):
1575
+ _, token_count = self.clip.process_texts([text])
1576
+
1577
+ return token_count, self.clip.get_target_prompt_token_count(token_count)
1578
+
1579
+
1580
+ class EmbeddingsWithFixes(nn.Layer):
1581
+ def __init__(self, wrapped, embeddings):
1582
+ super().__init__()
1583
+ self.wrapped = wrapped
1584
+ self.embeddings = embeddings
1585
+
1586
+ def forward(self, input_ids):
1587
+ batch_fixes = self.embeddings.fixes
1588
+ self.embeddings.fixes = None
1589
+
1590
+ inputs_embeds = self.wrapped(input_ids)
1591
+
1592
+ if batch_fixes is None or len(batch_fixes) == 0 or max([len(x) for x in batch_fixes]) == 0:
1593
+ return inputs_embeds
1594
+
1595
+ vecs = []
1596
+ for fixes, tensor in zip(batch_fixes, inputs_embeds):
1597
+ for offset, embedding in fixes:
1598
+ emb = embedding.vec.cast(self.wrapped.dtype)
1599
+ emb_len = min(tensor.shape[0] - offset - 1, emb.shape[0])
1600
+ tensor = paddle.concat([tensor[0 : offset + 1], emb[0:emb_len], tensor[offset + 1 + emb_len :]])
1601
+
1602
+ vecs.append(tensor)
1603
+
1604
+ return paddle.stack(vecs)
1605
+
1606
+
1607
+ # textual_inversion.py
1608
+
1609
+ import os
1610
+ import sys
1611
+ import traceback
1612
+
1613
+
1614
+ class Embedding:
1615
+ def __init__(self, vec, name, step=None):
1616
+ self.vec = vec
1617
+ self.name = name
1618
+ self.step = step
1619
+ self.shape = None
1620
+ self.vectors = 0
1621
+ self.cached_checksum = None
1622
+ self.sd_checkpoint = None
1623
+ self.sd_checkpoint_name = None
1624
+ self.optimizer_state_dict = None
1625
+ self.filename = None
1626
+
1627
+ def save(self, filename):
1628
+ embedding_data = {
1629
+ "string_to_token": {"*": 265},
1630
+ "string_to_param": {"*": self.vec},
1631
+ "name": self.name,
1632
+ "step": self.step,
1633
+ "sd_checkpoint": self.sd_checkpoint,
1634
+ "sd_checkpoint_name": self.sd_checkpoint_name,
1635
+ }
1636
+
1637
+ paddle.save(embedding_data, filename)
1638
+
1639
+ def checksum(self):
1640
+ if self.cached_checksum is not None:
1641
+ return self.cached_checksum
1642
+
1643
+ def const_hash(a):
1644
+ r = 0
1645
+ for v in a:
1646
+ r = (r * 281 ^ int(v) * 997) & 0xFFFFFFFF
1647
+ return r
1648
+
1649
+ self.cached_checksum = f"{const_hash(self.vec.flatten() * 100) & 0xffff:04x}"
1650
+ return self.cached_checksum
1651
+
1652
+
1653
+ class DirWithTextualInversionEmbeddings:
1654
+ def __init__(self, path):
1655
+ self.path = path
1656
+ self.mtime = None
1657
+
1658
+ def has_changed(self):
1659
+ if not os.path.isdir(self.path):
1660
+ return False
1661
+
1662
+ mt = os.path.getmtime(self.path)
1663
+ if self.mtime is None or mt > self.mtime:
1664
+ return True
1665
+
1666
+ def update(self):
1667
+ if not os.path.isdir(self.path):
1668
+ return
1669
+
1670
+ self.mtime = os.path.getmtime(self.path)
1671
+
1672
+
1673
+ class EmbeddingDatabase:
1674
+ def __init__(self, clip):
1675
+ self.clip = clip
1676
+ self.ids_lookup = {}
1677
+ self.word_embeddings = {}
1678
+ self.skipped_embeddings = {}
1679
+ self.expected_shape = -1
1680
+ self.embedding_dirs = {}
1681
+ self.previously_displayed_embeddings = ()
1682
+
1683
+ def add_embedding_dir(self, path):
1684
+ if path is not None:
1685
+ self.embedding_dirs[path] = DirWithTextualInversionEmbeddings(path)
1686
+
1687
+ def clear_embedding_dirs(self):
1688
+ self.embedding_dirs.clear()
1689
+
1690
+ def register_embedding(self, embedding, model):
1691
+ self.word_embeddings[embedding.name] = embedding
1692
+
1693
+ ids = model.tokenize([embedding.name])[0]
1694
+
1695
+ first_id = ids[0]
1696
+ if first_id not in self.ids_lookup:
1697
+ self.ids_lookup[first_id] = []
1698
+
1699
+ self.ids_lookup[first_id] = sorted(
1700
+ self.ids_lookup[first_id] + [(ids, embedding)], key=lambda x: len(x[0]), reverse=True
1701
+ )
1702
+
1703
+ return embedding
1704
+
1705
+ def get_expected_shape(self):
1706
+ vec = self.clip.encode_embedding_init_text(",", 1)
1707
+ return vec.shape[1]
1708
+
1709
+ def load_from_file(self, path, filename):
1710
+ name, ext = os.path.splitext(filename)
1711
+ ext = ext.upper()
1712
+
1713
+ if ext in [".PNG", ".WEBP", ".JXL", ".AVIF"]:
1714
+ _, second_ext = os.path.splitext(name)
1715
+ if second_ext.upper() == ".PREVIEW":
1716
+ return
1717
+
1718
+ embed_image = Image.open(path)
1719
+ if hasattr(embed_image, "text") and "sd-ti-embedding" in embed_image.text:
1720
+ data = embedding_from_b64(embed_image.text["sd-ti-embedding"])
1721
+ name = data.get("name", name)
1722
+ else:
1723
+ data = extract_image_data_embed(embed_image)
1724
+ if data:
1725
+ name = data.get("name", name)
1726
+ else:
1727
+ # if data is None, means this is not an embeding, just a preview image
1728
+ return
1729
+ elif ext in [".BIN", ".PT"]:
1730
+ data = torch_load(path)
1731
+ elif ext in [".SAFETENSORS"]:
1732
+ data = safetensors_load(path)
1733
+ else:
1734
+ return
1735
+
1736
+ # textual inversion embeddings
1737
+ if "string_to_param" in data:
1738
+ param_dict = data["string_to_param"]
1739
+ if hasattr(param_dict, "_parameters"):
1740
+ param_dict = getattr(param_dict, "_parameters")
1741
+ assert len(param_dict) == 1, "embedding file has multiple terms in it"
1742
+ emb = next(iter(param_dict.items()))[1]
1743
+ # diffuser concepts
1744
+ elif type(data) == dict and type(next(iter(data.values()))) == paddle.Tensor:
1745
+ assert len(data.keys()) == 1, "embedding file has multiple terms in it"
1746
+
1747
+ emb = next(iter(data.values()))
1748
+ if len(emb.shape) == 1:
1749
+ emb = emb.unsqueeze(0)
1750
+ else:
1751
+ raise Exception(
1752
+ f"Couldn't identify {filename} as neither textual inversion embedding nor diffuser concept."
1753
+ )
1754
+
1755
+ with paddle.no_grad():
1756
+ if hasattr(emb, "detach"):
1757
+ emb = emb.detach()
1758
+ if hasattr(emb, "cpu"):
1759
+ emb = emb.cpu()
1760
+ if hasattr(emb, "numpy"):
1761
+ emb = emb.numpy()
1762
+ emb = paddle.to_tensor(emb)
1763
+ vec = emb.detach().cast(paddle.float32)
1764
+ embedding = Embedding(vec, name)
1765
+ embedding.step = data.get("step", None)
1766
+ embedding.sd_checkpoint = data.get("sd_checkpoint", None)
1767
+ embedding.sd_checkpoint_name = data.get("sd_checkpoint_name", None)
1768
+ embedding.vectors = vec.shape[0]
1769
+ embedding.shape = vec.shape[-1]
1770
+ embedding.filename = path
1771
+
1772
+ if self.expected_shape == -1 or self.expected_shape == embedding.shape:
1773
+ self.register_embedding(embedding, self.clip)
1774
+ else:
1775
+ self.skipped_embeddings[name] = embedding
1776
+
1777
+ def load_from_dir(self, embdir):
1778
+ if not os.path.isdir(embdir.path):
1779
+ return
1780
+
1781
+ for root, dirs, fns in os.walk(embdir.path, followlinks=True):
1782
+ for fn in fns:
1783
+ try:
1784
+ fullfn = os.path.join(root, fn)
1785
+
1786
+ if os.stat(fullfn).st_size == 0:
1787
+ continue
1788
+
1789
+ self.load_from_file(fullfn, fn)
1790
+ except Exception:
1791
+ print(f"Error loading embedding {fn}:", file=sys.stderr)
1792
+ print(traceback.format_exc(), file=sys.stderr)
1793
+ continue
1794
+
1795
+ def load_textual_inversion_embeddings(self, force_reload=False):
1796
+ if not force_reload:
1797
+ need_reload = False
1798
+ for path, embdir in self.embedding_dirs.items():
1799
+ if embdir.has_changed():
1800
+ need_reload = True
1801
+ break
1802
+
1803
+ if not need_reload:
1804
+ return
1805
+
1806
+ self.ids_lookup.clear()
1807
+ self.word_embeddings.clear()
1808
+ self.skipped_embeddings.clear()
1809
+ self.expected_shape = self.get_expected_shape()
1810
+
1811
+ for path, embdir in self.embedding_dirs.items():
1812
+ self.load_from_dir(embdir)
1813
+ embdir.update()
1814
+
1815
+ displayed_embeddings = (tuple(self.word_embeddings.keys()), tuple(self.skipped_embeddings.keys()))
1816
+ if self.previously_displayed_embeddings != displayed_embeddings:
1817
+ self.previously_displayed_embeddings = displayed_embeddings
1818
+ print(
1819
+ f"Textual inversion embeddings loaded({len(self.word_embeddings)}): {', '.join(self.word_embeddings.keys())}"
1820
+ )
1821
+ if len(self.skipped_embeddings) > 0:
1822
+ print(
1823
+ f"Textual inversion embeddings skipped({len(self.skipped_embeddings)}): {', '.join(self.skipped_embeddings.keys())}"
1824
+ )
1825
+
1826
+ def find_embedding_at_position(self, tokens, offset):
1827
+ token = tokens[offset]
1828
+ possible_matches = self.ids_lookup.get(token, None)
1829
+
1830
+ if possible_matches is None:
1831
+ return None, None
1832
+
1833
+ for ids, embedding in possible_matches:
1834
+ if tokens[offset : offset + len(ids)] == ids:
1835
+ return embedding, len(ids)
1836
+
1837
+ return None, None