AlekseyCalvin committed
Commit 2395f33 · verified · 1 Parent(s): bc4a507

Delete pipeline8.py

Files changed (1):
  1. pipeline8.py  +0  -874

pipeline8.py DELETED
@@ -1,874 +0,0 @@
- import torch
- import numpy as np
- import html
- import inspect
- import re
- import urllib.parse as ul
- from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast, CLIPTextModelWithProjection
- from diffusers import FlowMatchEulerDiscreteScheduler, AutoPipelineForImage2Image, FluxPipeline, FluxTransformer2DModel
- from diffusers import StableDiffusion3Pipeline, AutoencoderKL, DiffusionPipeline, ImagePipelineOutput
- from diffusers.image_processor import VaeImageProcessor
- from diffusers.loaders import FluxLoraLoaderMixin, FromSingleFileMixin, SD3LoraLoaderMixin
- from diffusers.utils import (
-     USE_PEFT_BACKEND,
-     is_torch_xla_available,
-     logging,
-     BACKENDS_MAPPING,
-     deprecate,
-     replace_example_docstring,
-     scale_lora_layers,
-     unscale_lora_layers,
- )
- from diffusers.utils.torch_utils import randn_tensor
- from diffusers.pipelines.flux.pipeline_output import FluxPipelineOutput
- from typing import Any, Callable, Dict, List, Optional, Tuple, Union
- from PIL import Image
- from diffusers.pipelines.flux.pipeline_flux import calculate_shift, retrieve_timesteps, FluxTransformer2DModel
-
- if is_torch_xla_available():
-     import torch_xla.core.xla_model as xm
-
-     XLA_AVAILABLE = True
- else:
-     XLA_AVAILABLE = False
-
- logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
-
- # Constants for shift calculation
- BASE_SEQ_LEN = 256
- MAX_SEQ_LEN = 4096
- BASE_SHIFT = 0.5
- MAX_SHIFT = 1.2
-
- # Helper functions
- def calculate_timestep_shift(image_seq_len: int) -> float:
-     """Calculates the timestep shift (mu) based on the image sequence length."""
-     m = (MAX_SHIFT - BASE_SHIFT) / (MAX_SEQ_LEN - BASE_SEQ_LEN)
-     b = BASE_SHIFT - m * BASE_SEQ_LEN
-     mu = image_seq_len * m + b
-     return mu
-
- def prepare_timesteps(
-     scheduler: FlowMatchEulerDiscreteScheduler,
-     num_inference_steps: Optional[int] = None,
-     device: Optional[Union[str, torch.device]] = None,
-     timesteps: Optional[List[int]] = None,
-     sigmas: Optional[List[float]] = None,
-     mu: Optional[float] = None,
- ) -> Tuple[torch.Tensor, int]:
-     """Prepares the timesteps for the diffusion process."""
-     if timesteps is not None and sigmas is not None:
-         raise ValueError("Only one of `timesteps` or `sigmas` can be passed.")
-
-     if timesteps is not None:
-         scheduler.set_timesteps(timesteps=timesteps, device=device)
-     elif sigmas is not None:
-         scheduler.set_timesteps(sigmas=sigmas, device=device)
-     else:
-         scheduler.set_timesteps(num_inference_steps, device=device, mu=mu)
-
-     timesteps = scheduler.timesteps
-     num_inference_steps = len(timesteps)
-     return timesteps, num_inference_steps
-
- # FLUX pipeline class
- class FluxWithCFGPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFileMixin):
-     def __init__(
-         self,
-         scheduler: FlowMatchEulerDiscreteScheduler,
-         vae: AutoencoderKL,
-         text_encoder: CLIPTextModel,
-         tokenizer: CLIPTokenizer,
-         text_encoder_2: T5EncoderModel,
-         tokenizer_2: T5TokenizerFast,
-         transformer: FluxTransformer2DModel,
-     ):
-         super().__init__()
-
-         self.register_modules(
-             vae=vae,
-             text_encoder=text_encoder,
-             text_encoder_2=text_encoder_2,
-             tokenizer=tokenizer,
-             tokenizer_2=tokenizer_2,
-             transformer=transformer,
-             scheduler=scheduler,
-         )
-         self.vae_scale_factor = (
-             2 ** (len(self.vae.config.block_out_channels)) if hasattr(self, "vae") and self.vae is not None else 16
-         )
-         self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
-         self.tokenizer_max_length = (
-             self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
-         )
-         self.default_sample_size = 64
-
-     def _get_t5_prompt_embeds(
-         self,
-         prompt: Union[str, List[str]] = None,
-         num_images_per_prompt: int = 1,
-         max_sequence_length: int = 512,
-         device: Optional[torch.device] = None,
-         dtype: Optional[torch.dtype] = None,
-     ):
-         device = device or self._execution_device
-         dtype = dtype or self.text_encoder.dtype
-
-         prompt = [prompt] if isinstance(prompt, str) else prompt
-         batch_size = len(prompt)
-
-         text_inputs = self.tokenizer_2(
-             prompt,
-             padding="max_length",
-             max_length=max_sequence_length,
-             truncation=True,
-             return_length=True,
-             return_overflowing_tokens=True,
-             return_tensors="pt",
-         )
-         text_input_ids = text_inputs.input_ids
-         untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids
-
-         if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
-             removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
-             logger.warning(
-                 "The following part of your input was truncated because `max_sequence_length` is set to"
-                 f" {max_sequence_length} tokens: {removed_text}"
-             )
-
-         prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0]
-
-         dtype = self.text_encoder_2.dtype
-         prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
-
-         _, seq_len, _ = prompt_embeds.shape
-
-         # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
-         prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
-         prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
-
-         return prompt_embeds
-
-     def _get_clip_prompt_embeds(
-         self,
-         prompt: Union[str, List[str]],
-         num_images_per_prompt: int = 1,
-         device: Optional[torch.device] = None,
-     ):
-         device = device or self._execution_device
-
-         prompt = [prompt] if isinstance(prompt, str) else prompt
-         batch_size = len(prompt)
-
-         text_inputs = self.tokenizer(
-             prompt,
-             padding="max_length",
-             max_length=self.tokenizer_max_length,
-             truncation=True,
-             return_overflowing_tokens=False,
-             return_length=False,
-             return_tensors="pt",
-         )
-
-         text_input_ids = text_inputs.input_ids
-         untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
-         if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
-             removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
-             logger.warning(
-                 "The following part of your input was truncated because CLIP can only handle sequences up to"
-                 f" {self.tokenizer_max_length} tokens: {removed_text}"
-             )
-         prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False)
-
-         # Use pooled output of CLIPTextModel
-         prompt_embeds = prompt_embeds.pooler_output
-         prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
-
-         # duplicate text embeddings for each generation per prompt, using mps friendly method
-         prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt)
-         prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1)
-
-         return prompt_embeds
-
-     def encode_prompt(
-         self,
-         prompt: Union[str, List[str]],
-         prompt_2: Union[str, List[str]],
-         device: Optional[torch.device] = None,
-         num_images_per_prompt: int = 1,
-         do_classifier_free_guidance: bool = True,
-         negative_prompt: Optional[Union[str, List[str]]] = None,
-         negative_prompt_2: Optional[Union[str, List[str]]] = None,
-         prompt_embeds: Optional[torch.FloatTensor] = None,
-         negative_prompt_embeds: Optional[torch.Tensor] = None,
-         negative_prompt_2_embed: Optional[torch.Tensor] = None,
-         pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
-         negative_pooled_prompt_2_embed: Optional[torch.FloatTensor] = None,
-         negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
-         max_sequence_length: int = 512,
-         lora_scale: Optional[float] = None,
-     ):
-         device = device or self._execution_device
-
-         # set lora scale so that monkey patched LoRA
-         # function of text encoder can correctly access it
-         if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin):
-             self._lora_scale = lora_scale
-
-             # dynamically adjust the LoRA scale
-             if self.text_encoder is not None and USE_PEFT_BACKEND:
-                 scale_lora_layers(self.text_encoder, lora_scale)
-             if self.text_encoder_2 is not None and USE_PEFT_BACKEND:
-                 scale_lora_layers(self.text_encoder_2, lora_scale)
-
-         prompt = [prompt] if isinstance(prompt, str) else prompt
-         if prompt is not None:
-             batch_size = len(prompt)
-         else:
-             batch_size = prompt_embeds.shape[0]
-
-         if prompt_embeds is None:
-             prompt_2 = prompt_2 or prompt
-             prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
-
-             # CLIP supplies the pooled projection, T5 supplies the sequence embeddings
-             pooled_prompt_embeds = self._get_clip_prompt_embeds(
-                 prompt=prompt,
-                 device=device,
-                 num_images_per_prompt=num_images_per_prompt,
-             )
-             prompt_embeds = self._get_t5_prompt_embeds(
-                 prompt=prompt_2,
-                 num_images_per_prompt=num_images_per_prompt,
-                 max_sequence_length=max_sequence_length,
-                 device=device,
-             )
-
-         if do_classifier_free_guidance and negative_prompt_embeds is None:
-             negative_prompt = negative_prompt or ""
-             negative_prompt_2 = negative_prompt_2 or negative_prompt
-
-             # normalize str to list
-             negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
-             negative_prompt_2 = (
-                 batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2
-             )
-
-             if prompt is not None and type(prompt) is not type(negative_prompt):
-                 raise TypeError(
-                     f"`negative_prompt` should be the same type as `prompt`, but got {type(negative_prompt)} !="
-                     f" {type(prompt)}."
-                 )
-             elif batch_size != len(negative_prompt):
-                 raise ValueError(
-                     f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
-                     f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
-                     " the batch size of `prompt`."
-                 )
-
-             # The negative embeddings mirror the positive path: pooled projection from CLIP,
-             # sequence embeddings from T5.
-             negative_pooled_prompt_embeds = self._get_clip_prompt_embeds(
-                 prompt=negative_prompt,
-                 device=device,
-                 num_images_per_prompt=num_images_per_prompt,
-             )
-             negative_prompt_embeds = self._get_t5_prompt_embeds(
-                 prompt=negative_prompt_2,
-                 num_images_per_prompt=num_images_per_prompt,
-                 max_sequence_length=max_sequence_length,
-                 device=device,
-             )
-
-         if self.text_encoder is not None:
-             if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
-                 # Retrieve the original scale by scaling back the LoRA layers
-                 unscale_lora_layers(self.text_encoder, lora_scale)
-
-         if self.text_encoder_2 is not None:
-             if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
-                 # Retrieve the original scale by scaling back the LoRA layers
-                 unscale_lora_layers(self.text_encoder_2, lora_scale)
-
-         dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype
-         text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
-
-         return prompt_embeds, pooled_prompt_embeds, text_ids, negative_prompt_embeds, negative_pooled_prompt_embeds
-
-     def check_inputs(
-         self,
-         prompt,
-         prompt_2,
-         height,
-         width,
-         negative_prompt=None,
-         negative_prompt_2=None,
-         prompt_embeds=None,
-         negative_prompt_embeds=None,
-         pooled_prompt_embeds=None,
-         negative_pooled_prompt_embeds=None,
-         max_sequence_length=None,
-     ):
-         if height % 8 != 0 or width % 8 != 0:
-             raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
-
-         if prompt is not None and prompt_embeds is not None:
-             raise ValueError(
-                 f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
-                 " only forward one of the two."
-             )
-         elif prompt_2 is not None and prompt_embeds is not None:
-             raise ValueError(
-                 f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
-                 " only forward one of the two."
-             )
-         elif prompt is None and prompt_embeds is None:
-             raise ValueError(
-                 "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
-             )
-         elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
-             raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
-         elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
-             raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
-
-         if prompt_embeds is not None and pooled_prompt_embeds is None:
-             raise ValueError(
-                 "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
-             )
-         if negative_prompt is not None and negative_prompt_embeds is not None:
-             raise ValueError(
-                 f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
-                 f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
-             )
-         elif negative_prompt_2 is not None and negative_prompt_embeds is not None:
-             raise ValueError(
-                 f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:"
-                 f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
-             )
-         if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None:
-             raise ValueError("Must provide `negative_pooled_prompt_embeds` when specifying `negative_prompt_embeds`.")
-
-         if max_sequence_length is not None and max_sequence_length > 512:
-             raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
-
-     @staticmethod
-     def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
-         latent_image_ids = torch.zeros(height // 2, width // 2, 3)
-         latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None]
-         latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :]
-
-         latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
-
-         latent_image_ids = latent_image_ids.reshape(
-             latent_image_id_height * latent_image_id_width, latent_image_id_channels
-         )
-
-         return latent_image_ids.to(device=device, dtype=dtype)
-
-     @staticmethod
-     def _pack_latents(latents, batch_size, num_channels_latents, height, width):
-         latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
-         latents = latents.permute(0, 2, 4, 1, 3, 5)
-         latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
-
-         return latents
-
-     @staticmethod
-     def _unpack_latents(latents, height, width, vae_scale_factor):
-         batch_size, num_patches, channels = latents.shape
-
-         height = height // vae_scale_factor
-         width = width // vae_scale_factor
-
-         latents = latents.view(batch_size, height, width, channels // 4, 2, 2)
-         latents = latents.permute(0, 3, 1, 4, 2, 5)
-
-         latents = latents.reshape(batch_size, channels // (2 * 2), height * 2, width * 2)
-
-         return latents
-
-     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
-     def prepare_extra_step_kwargs(self, generator, eta):
-         # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
-         # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
-         # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
-         # and should be between [0, 1]
-
-         accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
-         extra_step_kwargs = {}
-         if accepts_eta:
-             extra_step_kwargs["eta"] = eta
-
-         # check if the scheduler accepts generator
-         accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
-         if accepts_generator:
-             extra_step_kwargs["generator"] = generator
-         return extra_step_kwargs
-
-     def enable_vae_slicing(self):
-         r"""
-         Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
-         compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
-         """
-         self.vae.enable_slicing()
-
-     def disable_vae_slicing(self):
-         r"""
-         Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
-         computing decoding in one step.
-         """
-         self.vae.disable_slicing()
-
-     def enable_vae_tiling(self):
-         r"""
-         Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
-         compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
-         processing larger images.
-         """
-         self.vae.enable_tiling()
-
-     def disable_vae_tiling(self):
-         r"""
-         Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
-         computing decoding in one step.
-         """
-         self.vae.disable_tiling()
-
-     def prepare_latents(
-         self,
-         batch_size,
-         num_channels_latents,
-         height,
-         width,
-         dtype,
-         device,
-         generator,
-         latents=None,
-     ):
-         height = 2 * (int(height) // self.vae_scale_factor)
-         width = 2 * (int(width) // self.vae_scale_factor)
-
-         shape = (batch_size, num_channels_latents, height, width)
-
-         if latents is not None:
-             latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype)
-             return latents.to(device=device, dtype=dtype), latent_image_ids
-
-         if isinstance(generator, list) and len(generator) != batch_size:
-             raise ValueError(
-                 f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
-                 f" size of {batch_size}. Make sure the batch size matches the length of the generators."
-             )
-
-         latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
-         latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
-
-         latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype)
-
-         return latents, latent_image_ids
-
-     @property
-     def guidance_scale(self):
-         return self._guidance_scale
-
-     @property
-     def do_classifier_free_guidance(self):
-         return self._guidance_scale > 1
-
-     @property
-     def joint_attention_kwargs(self):
-         return self._joint_attention_kwargs
-
-     @property
-     def num_timesteps(self):
-         return self._num_timesteps
-
-     @property
-     def interrupt(self):
-         return self._interrupt
-
-     @torch.no_grad()
-     @torch.inference_mode()
-     def generate_images(
-         self,
-         prompt: Union[str, List[str]] = None,
-         prompt_2: Optional[Union[str, List[str]]] = None,
-         height: Optional[int] = None,
-         width: Optional[int] = None,
-         negative_prompt: Optional[Union[str, List[str]]] = None,
-         negative_prompt_2: Optional[Union[str, List[str]]] = None,
-         num_inference_steps: int = 8,
-         timesteps: List[int] = None,
-         eta: Optional[float] = 0.0,
-         guidance_scale: float = 3.5,
-         num_images_per_prompt: Optional[int] = 1,
-         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
-         latents: Optional[torch.FloatTensor] = None,
-         prompt_embeds: Optional[torch.FloatTensor] = None,
-         pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
-         negative_prompt_embeds: Optional[torch.FloatTensor] = None,
-         negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
-         output_type: Optional[str] = "pil",
-         return_dict: bool = True,
-         joint_attention_kwargs: Optional[Dict[str, Any]] = None,
-         max_sequence_length: int = 300,
-         **kwargs,
-     ):
-         height = height or self.default_sample_size * self.vae_scale_factor
-         width = width or self.default_sample_size * self.vae_scale_factor
-
-         # 1. Check inputs
-         self.check_inputs(
-             prompt,
-             prompt_2,
-             height,
-             width,
-             negative_prompt=negative_prompt,
-             negative_prompt_2=negative_prompt_2,
-             prompt_embeds=prompt_embeds,
-             negative_prompt_embeds=negative_prompt_embeds,
-             pooled_prompt_embeds=pooled_prompt_embeds,
-             negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
-             max_sequence_length=max_sequence_length,
-         )
-
-         self._guidance_scale = guidance_scale
-         self._joint_attention_kwargs = joint_attention_kwargs
-         self._interrupt = False
-
-         # 2. Define call parameters
-         if prompt is not None and isinstance(prompt, str):
-             batch_size = 1
-         elif prompt is not None and isinstance(prompt, list):
-             batch_size = len(prompt)
-         else:
-             batch_size = prompt_embeds.shape[0]
-
-         device = self._execution_device
-
-         # 3. Encode the prompts
-         lora_scale = (
-             self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
-         )
-         (
-             prompt_embeds,
-             pooled_prompt_embeds,
-             text_ids,
-             negative_prompt_embeds,
-             negative_pooled_prompt_embeds,
-         ) = self.encode_prompt(
-             prompt=prompt,
-             prompt_2=prompt_2,
-             negative_prompt=negative_prompt,
-             negative_prompt_2=negative_prompt_2,
-             do_classifier_free_guidance=self.do_classifier_free_guidance,
-             prompt_embeds=prompt_embeds,
-             negative_prompt_embeds=negative_prompt_embeds,
-             pooled_prompt_embeds=pooled_prompt_embeds,
-             negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
-             device=device,
-             num_images_per_prompt=num_images_per_prompt,
-             max_sequence_length=max_sequence_length,
-             lora_scale=lora_scale,
-         )
-
-         if self.do_classifier_free_guidance:
-             prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
-             pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)
-
-         # 4. Prepare latent variables
-         num_channels_latents = self.transformer.config.in_channels // 4
-         latents, latent_image_ids = self.prepare_latents(
-             batch_size * num_images_per_prompt,
-             num_channels_latents,
-             height,
-             width,
-             prompt_embeds.dtype,
-             device,
-             generator,
-             latents,
-         )
-
-         # 5. Prepare timesteps
-         sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
-         image_seq_len = latents.shape[1]
-         mu = calculate_timestep_shift(image_seq_len)
-         timesteps, num_inference_steps = prepare_timesteps(
-             self.scheduler,
-             num_inference_steps,
-             device,
-             timesteps,
-             sigmas,
-             mu=mu,
-         )
-         num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
-         self._num_timesteps = len(timesteps)
-
-         # 6. Denoising loop
-         with self.progress_bar(total=num_inference_steps) as progress_bar:
-             for i, t in enumerate(timesteps):
-                 if self.interrupt:
-                     continue
-
-                 latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
-
-                 timestep = t.expand(latent_model_input.shape[0]).to(latents.dtype)
-
-                 # Handle guidance
-                 if self.transformer.config.guidance_embeds:
-                     guidance = torch.tensor([guidance_scale], device=device)
-                     guidance = guidance.expand(latent_model_input.shape[0])
-                 else:
-                     guidance = None
-
-                 noise_pred = self.transformer(
-                     hidden_states=latent_model_input,
-                     timestep=timestep / 1000,
-                     guidance=guidance,
-                     pooled_projections=pooled_prompt_embeds,
-                     encoder_hidden_states=prompt_embeds,
-                     txt_ids=text_ids,
-                     img_ids=latent_image_ids,
-                     joint_attention_kwargs=self.joint_attention_kwargs,
-                     return_dict=False,
-                 )[0]
-
-                 if self.do_classifier_free_guidance:
-                     noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
-                     noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
-
-                 # compute the previous noisy sample x_t -> x_t-1
-                 latents_dtype = latents.dtype
-                 latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
-
-                 if latents.dtype != latents_dtype:
-                     if torch.backends.mps.is_available():
-                         # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
-                         latents = latents.to(latents_dtype)
-
-                 # call the callback, if provided
-                 if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
-                     progress_bar.update()
-
-         # Final image
-         image = self._decode_latents_to_image(latents, height, width, output_type)
-         self.maybe_free_model_hooks()
-         torch.cuda.empty_cache()
-         return image
-
-     @torch.no_grad()
-     @torch.inference_mode()
-     def __call__(
-         self,
-         prompt: Union[str, List[str]] = None,
-         prompt_2: Optional[Union[str, List[str]]] = None,
-         height: Optional[int] = None,
-         width: Optional[int] = None,
-         negative_prompt: Optional[Union[str, List[str]]] = None,
-         negative_prompt_2: Optional[Union[str, List[str]]] = None,
-         num_inference_steps: int = 8,
-         timesteps: List[int] = None,
-         eta: Optional[float] = 0.0,
-         guidance_scale: float = 3.5,
-         device: Optional[int] = None,
-         num_images_per_prompt: Optional[int] = 1,
-         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
-         latents: Optional[torch.FloatTensor] = None,
-         prompt_embeds: Optional[torch.FloatTensor] = None,
-         pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
-         negative_prompt_embeds: Optional[torch.FloatTensor] = None,
-         negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
-         output_type: Optional[str] = "pil",
-         return_dict: bool = True,
-         joint_attention_kwargs: Optional[Dict[str, Any]] = None,
-         max_sequence_length: int = 300,
-         **kwargs,
-     ):
-         height = height or self.default_sample_size * self.vae_scale_factor
-         width = width or self.default_sample_size * self.vae_scale_factor
-
-         # 1. Check inputs
-         self.check_inputs(
-             prompt,
-             prompt_2,
-             height,
-             width,
-             negative_prompt=negative_prompt,
-             negative_prompt_2=negative_prompt_2,
-             prompt_embeds=prompt_embeds,
-             negative_prompt_embeds=negative_prompt_embeds,
-             pooled_prompt_embeds=pooled_prompt_embeds,
-             negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
-             max_sequence_length=max_sequence_length,
-         )
-
-         self._guidance_scale = guidance_scale
-         self._joint_attention_kwargs = joint_attention_kwargs
-         self._interrupt = False
-
-         device = self._execution_device
-
-         # 2. Define call parameters
-         if prompt is not None and isinstance(prompt, str):
-             batch_size = 1
-         elif prompt is not None and isinstance(prompt, list):
-             batch_size = len(prompt)
-         else:
-             batch_size = prompt_embeds.shape[0]
-
-         # 3. Encode the prompts
-         lora_scale = (
-             self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
-         )
-         (
-             prompt_embeds,
-             pooled_prompt_embeds,
-             text_ids,
-             negative_prompt_embeds,
-             negative_pooled_prompt_embeds,
-         ) = self.encode_prompt(
-             prompt=prompt,
-             prompt_2=prompt_2,
-             negative_prompt=negative_prompt,
-             negative_prompt_2=negative_prompt_2,
-             do_classifier_free_guidance=self.do_classifier_free_guidance,
-             prompt_embeds=prompt_embeds,
-             negative_prompt_embeds=negative_prompt_embeds,
-             pooled_prompt_embeds=pooled_prompt_embeds,
-             negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
-             device=device,
-             num_images_per_prompt=num_images_per_prompt,
-             max_sequence_length=max_sequence_length,
-             lora_scale=lora_scale,
-         )
-
-         if self.do_classifier_free_guidance:
-             prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
-             pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)
-
-         # 4. Prepare latent variables
-         num_channels_latents = self.transformer.config.in_channels // 4
-         latents, latent_image_ids = self.prepare_latents(
-             batch_size * num_images_per_prompt,
-             num_channels_latents,
-             height,
-             width,
-             prompt_embeds.dtype,
-             device,
-             generator,
-             latents,
-         )
-
-         # 5. Prepare timesteps
-         sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
-         image_seq_len = latents.shape[1]
-         mu = calculate_timestep_shift(image_seq_len)
-         timesteps, num_inference_steps = prepare_timesteps(
-             self.scheduler,
-             num_inference_steps,
-             device,
-             timesteps,
-             sigmas,
-             mu=mu,
-         )
-         num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
-         self._num_timesteps = len(timesteps)
-
-         # 6. Denoising loop
-         with self.progress_bar(total=num_inference_steps) as progress_bar:
-             for i, t in enumerate(timesteps):
-                 if self.interrupt:
-                     continue
-
-                 latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
-
-                 timestep = t.expand(latent_model_input.shape[0]).to(latents.dtype)
-
-                 # Handle guidance
-                 if self.transformer.config.guidance_embeds:
-                     guidance = torch.tensor([guidance_scale], device=device)
-                     guidance = guidance.expand(latent_model_input.shape[0])
-                 else:
-                     guidance = None
-
-                 noise_pred = self.transformer(
-                     hidden_states=latent_model_input,
-                     timestep=timestep / 1000,
-                     guidance=guidance,
-                     pooled_projections=pooled_prompt_embeds,
-                     encoder_hidden_states=prompt_embeds,
-                     txt_ids=text_ids,
-                     img_ids=latent_image_ids,
-                     joint_attention_kwargs=self.joint_attention_kwargs,
-                     return_dict=False,
-                 )[0]
-
-                 if self.do_classifier_free_guidance:
-                     noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
-                     noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
-
-                 # compute the previous noisy sample x_t -> x_t-1
-                 latents_dtype = latents.dtype
-                 latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
-
-                 if latents.dtype != latents_dtype:
-                     if torch.backends.mps.is_available():
-                         # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
-                         latents = latents.to(latents_dtype)
-
-                 # call the callback, if provided
-                 if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
-                     progress_bar.update()
-
-         # Final image
-         image = self._decode_latents_to_image(latents, height, width, output_type)
-         self.maybe_free_model_hooks()
-         torch.cuda.empty_cache()
-         return image
-
-     def _decode_latents_to_image(self, latents, height, width, output_type, vae=None):
-         """Decodes the given latents into an image."""
-         vae = vae or self.vae
-         latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
-         latents = (latents / vae.config.scaling_factor) + vae.config.shift_factor
-         image = vae.decode(latents, return_dict=False)[0]
-         return self.image_processor.postprocess(image, output_type=output_type)[0]
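
For context, this is roughly how the deleted FluxWithCFGPipeline was evidently meant to be driven, going by its constructor and `__call__` signature above. The sketch below is illustrative only: the checkpoint id, dtype, prompts, and output filename are assumptions, not anything pinned by the file, while `from_pretrained` and `to` come from the inherited `DiffusionPipeline` base class.

import torch
from pipeline8 import FluxWithCFGPipeline  # the module removed by this commit

# Assumed checkpoint and dtype; pipeline8.py itself does not specify either.
pipe = FluxWithCFGPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    torch_dtype=torch.bfloat16,
)
pipe.to("cuda")

# guidance_scale > 1 enables the true classifier-free guidance path, so the
# negative prompt is actually encoded and applied. __call__ returns a single
# post-processed PIL image (see _decode_latents_to_image).
image = pipe(
    prompt="a watercolor fox in a birch forest",
    negative_prompt="blurry, low quality",
    num_inference_steps=8,
    guidance_scale=3.5,
    height=1024,
    width=1024,
    generator=torch.Generator("cuda").manual_seed(0),
)
image.save("fox.png")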