AlekseyCalvin committed on
Commit bc4a507 · verified · 1 Parent(s): 0c99ba1

Delete pipeline7.py

Files changed (1)
  1. pipeline7.py +0 -902
pipeline7.py DELETED
@@ -1,902 +0,0 @@
import torch
import numpy as np
import html
import inspect
import re
import urllib.parse as ul
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast, CLIPTextModelWithProjection
from diffusers import FlowMatchEulerDiscreteScheduler, AutoPipelineForImage2Image, FluxPipeline, FluxTransformer2DModel
from diffusers import StableDiffusion3Pipeline, AutoencoderKL, DiffusionPipeline, ImagePipelineOutput
from diffusers.image_processor import VaeImageProcessor
from diffusers.loaders import FluxLoraLoaderMixin, FromSingleFileMixin, SD3LoraLoaderMixin
from diffusers.utils import (
    USE_PEFT_BACKEND,
    is_torch_xla_available,
    logging,
    BACKENDS_MAPPING,
    is_bs4_available,
    is_ftfy_available,
    deprecate,
    replace_example_docstring,
    scale_lora_layers,
    unscale_lora_layers,
)
from diffusers.utils.torch_utils import randn_tensor
from diffusers.pipelines.flux.pipeline_output import FluxPipelineOutput
from typing import Any, Callable, Dict, List, Optional, Union
from PIL import Image
from diffusers.pipelines.flux.pipeline_flux import calculate_shift, retrieve_timesteps

if is_bs4_available():
    from bs4 import BeautifulSoup

if is_ftfy_available():
    import ftfy

if is_torch_xla_available():
    import torch_xla.core.xla_model as xm

    XLA_AVAILABLE = True
else:
    XLA_AVAILABLE = False

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

# Constants for shift calculation
BASE_SEQ_LEN = 256
MAX_SEQ_LEN = 4096
BASE_SHIFT = 0.5
MAX_SHIFT = 1.2

# Helper functions
def calculate_timestep_shift(image_seq_len: int) -> float:
    """Calculates the timestep shift (mu) based on the image sequence length."""
    m = (MAX_SHIFT - BASE_SHIFT) / (MAX_SEQ_LEN - BASE_SEQ_LEN)
    b = BASE_SHIFT - m * BASE_SEQ_LEN
    mu = image_seq_len * m + b
    return mu

def prepare_timesteps(
    scheduler: FlowMatchEulerDiscreteScheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    sigmas: Optional[List[float]] = None,
    mu: Optional[float] = None,
) -> (torch.Tensor, int):
    """Prepares the timesteps for the diffusion process."""
    if timesteps is not None and sigmas is not None:
        raise ValueError("Only one of `timesteps` or `sigmas` can be passed.")

    if timesteps is not None:
        scheduler.set_timesteps(timesteps=timesteps, device=device)
    elif sigmas is not None:
        scheduler.set_timesteps(sigmas=sigmas, device=device)
    else:
        scheduler.set_timesteps(num_inference_steps, device=device, mu=mu)

    timesteps = scheduler.timesteps
    num_inference_steps = len(timesteps)
    return timesteps, num_inference_steps

# FLUX pipeline with classifier-free guidance (CFG)
class FluxWithCFGPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFileMixin):
    def __init__(
        self,
        scheduler: FlowMatchEulerDiscreteScheduler,
        vae: AutoencoderKL,
        text_encoder: CLIPTextModel,
        tokenizer: CLIPTokenizer,
        text_encoder_2: T5EncoderModel,
        tokenizer_2: T5TokenizerFast,
        transformer: FluxTransformer2DModel,
    ):
        super().__init__()

        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
            text_encoder_2=text_encoder_2,
            tokenizer=tokenizer,
            tokenizer_2=tokenizer_2,
            transformer=transformer,
            scheduler=scheduler,
        )
        self.vae_scale_factor = (
            2 ** (len(self.vae.config.block_out_channels)) if hasattr(self, "vae") and self.vae is not None else 16
        )
        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
        self.tokenizer_max_length = (
            self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
        )
        self.default_sample_size = 64

    def _get_t5_prompt_embeds(
        self,
        prompt: Union[str, List[str]] = None,
        num_images_per_prompt: int = 1,
        max_sequence_length: int = 512,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ):
        device = device or self._execution_device
        dtype = dtype or self.text_encoder.dtype

        prompt = [prompt] if isinstance(prompt, str) else prompt
        batch_size = len(prompt)

        text_inputs = self.tokenizer_2(
            prompt,
            padding="max_length",
            max_length=max_sequence_length,
            truncation=True,
            return_length=True,
            return_overflowing_tokens=True,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids
        untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids

        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
            removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
            logger.warning(
                "The following part of your input was truncated because `max_sequence_length` is set to "
                f" {max_sequence_length} tokens: {removed_text}"
            )

        prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0]

        dtype = self.text_encoder_2.dtype
        prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)

        _, seq_len, _ = prompt_embeds.shape

        # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
        prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

        return prompt_embeds

    def _get_clip_prompt_embeds(
        self,
        prompt: Union[str, List[str]],
        num_images_per_prompt: int = 1,
        device: Optional[torch.device] = None,
    ):
        device = device or self._execution_device

        prompt = [prompt] if isinstance(prompt, str) else prompt
        batch_size = len(prompt)

        text_inputs = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.tokenizer_max_length,
            truncation=True,
            return_overflowing_tokens=False,
            return_length=False,
            return_tensors="pt",
        )

        text_input_ids = text_inputs.input_ids
        untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
            removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
            logger.warning(
                "The following part of your input was truncated because CLIP can only handle sequences up to"
                f" {self.tokenizer_max_length} tokens: {removed_text}"
            )
        prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False)

        # Use the pooled output of CLIPTextModel
        prompt_embeds = prompt_embeds.pooler_output
        prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)

        # duplicate text embeddings for each generation per prompt, using mps friendly method
        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt)
        prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1)

        return prompt_embeds

    def encode_prompt(
        self,
        prompt: Union[str, List[str]],
        prompt_2: Union[str, List[str]],
        device: Optional[torch.device] = None,
        num_images_per_prompt: int = 1,
        do_classifier_free_guidance: bool = True,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        negative_prompt_2: Optional[Union[str, List[str]]] = None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_prompt_embeds: Optional[torch.Tensor] = None,
        pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
        max_sequence_length: int = 512,
        lora_scale: Optional[float] = None,
    ):
        device = device or self._execution_device

        # set lora scale so that monkey patched LoRA
        # function of text encoder can correctly access it
        if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin):
            self._lora_scale = lora_scale

            # dynamically adjust the LoRA scale
            if self.text_encoder is not None and USE_PEFT_BACKEND:
                scale_lora_layers(self.text_encoder, lora_scale)
            if self.text_encoder_2 is not None and USE_PEFT_BACKEND:
                scale_lora_layers(self.text_encoder_2, lora_scale)

        prompt = [prompt] if isinstance(prompt, str) else prompt

        if prompt is not None:
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        if prompt_embeds is None:
            prompt_2 = prompt_2 or prompt
            prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2

            # We only use the pooled prompt output from the CLIPTextModel;
            # the T5 encoder provides the sequence embeddings.
            pooled_prompt_embeds = self._get_clip_prompt_embeds(
                prompt=prompt,
                device=device,
                num_images_per_prompt=num_images_per_prompt,
            )
            prompt_embeds = self._get_t5_prompt_embeds(
                prompt=prompt_2,
                num_images_per_prompt=num_images_per_prompt,
                max_sequence_length=max_sequence_length,
                device=device,
            )

        if do_classifier_free_guidance and negative_prompt_embeds is None:
            negative_prompt = negative_prompt or ""
            negative_prompt_2 = negative_prompt_2 or negative_prompt

            # normalize str to list
            negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
            negative_prompt_2 = (
                batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2
            )

            if prompt is not None and type(prompt) is not type(negative_prompt):
                raise TypeError(
                    f"`negative_prompt` should be the same type as `prompt`, but got {type(negative_prompt)} !="
                    f" {type(prompt)}."
                )
            elif batch_size != len(negative_prompt):
                raise ValueError(
                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                    " the batch size of `prompt`."
                )

            negative_pooled_prompt_embeds = self._get_clip_prompt_embeds(
                prompt=negative_prompt,
                device=device,
                num_images_per_prompt=num_images_per_prompt,
            )
            negative_prompt_embeds = self._get_t5_prompt_embeds(
                prompt=negative_prompt_2,
                num_images_per_prompt=num_images_per_prompt,
                max_sequence_length=max_sequence_length,
                device=device,
            )

        if self.text_encoder is not None:
            if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
                # Retrieve the original scale by scaling back the LoRA layers
                unscale_lora_layers(self.text_encoder, lora_scale)
        if self.text_encoder_2 is not None:
            if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
                unscale_lora_layers(self.text_encoder_2, lora_scale)

        dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype
        text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)

        return prompt_embeds, pooled_prompt_embeds, text_ids, negative_prompt_embeds, negative_pooled_prompt_embeds

    def check_inputs(
        self,
        prompt,
        prompt_2,
        height,
        width,
        negative_prompt=None,
        negative_prompt_2=None,
        prompt_embeds=None,
        negative_prompt_embeds=None,
        pooled_prompt_embeds=None,
        negative_pooled_prompt_embeds=None,
        max_sequence_length=None,
    ):
        if height % 8 != 0 or width % 8 != 0:
            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

        if prompt is not None and prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
                " only forward one of the two."
            )
        elif prompt_2 is not None and prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
                " only forward one of the two."
            )
        elif prompt is None and prompt_embeds is None:
            raise ValueError(
                "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
            )
        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
        elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
            raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")

        if prompt_embeds is not None and pooled_prompt_embeds is None:
            raise ValueError(
                "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
            )
        if negative_prompt is not None and negative_prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
                f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
            )
        elif negative_prompt_2 is not None and negative_prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:"
                f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
            )
        if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None:
            raise ValueError(
                "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed."
            )

        if max_sequence_length is not None and max_sequence_length > 512:
            raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")

    @staticmethod
    def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
        latent_image_ids = torch.zeros(height // 2, width // 2, 3)
        latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None]
        latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :]

        latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape

        latent_image_ids = latent_image_ids.reshape(
            latent_image_id_height * latent_image_id_width, latent_image_id_channels
        )

        return latent_image_ids.to(device=device, dtype=dtype)

    @staticmethod
    def _pack_latents(latents, batch_size, num_channels_latents, height, width):
        latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
        latents = latents.permute(0, 2, 4, 1, 3, 5)
        latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)

        return latents

    @staticmethod
    def _unpack_latents(latents, height, width, vae_scale_factor):
        batch_size, num_patches, channels = latents.shape

        height = height // vae_scale_factor
        width = width // vae_scale_factor

        latents = latents.view(batch_size, height, width, channels // 4, 2, 2)
        latents = latents.permute(0, 3, 1, 4, 2, 5)

        latents = latents.reshape(batch_size, channels // (2 * 2), height * 2, width * 2)

        return latents

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
    def prepare_extra_step_kwargs(self, generator, eta):
        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
        # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
        # and should be between [0, 1]

        accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
        extra_step_kwargs = {}
        if accepts_eta:
            extra_step_kwargs["eta"] = eta

        # check if the scheduler accepts generator
        accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
        if accepts_generator:
            extra_step_kwargs["generator"] = generator
        return extra_step_kwargs

    def enable_vae_slicing(self):
        r"""
        Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
        compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
        """
        self.vae.enable_slicing()

    def disable_vae_slicing(self):
        r"""
        Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
        computing decoding in one step.
        """
        self.vae.disable_slicing()

    def enable_vae_tiling(self):
        r"""
        Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
        compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
        processing larger images.
        """
        self.vae.enable_tiling()

    def disable_vae_tiling(self):
        r"""
        Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
        computing decoding in one step.
        """
        self.vae.disable_tiling()

    def prepare_latents(
        self,
        batch_size,
        num_channels_latents,
        height,
        width,
        dtype,
        device,
        generator,
        latents=None,
    ):
        height = 2 * (int(height) // self.vae_scale_factor)
        width = 2 * (int(width) // self.vae_scale_factor)

        shape = (batch_size, num_channels_latents, height, width)

        if latents is not None:
            latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype)
            return latents.to(device=device, dtype=dtype), latent_image_ids

        if isinstance(generator, list) and len(generator) != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )

        latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
        latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)

        latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype)

        return latents, latent_image_ids

    @property
    def guidance_scale(self):
        return self._guidance_scale

    @property
    def do_classifier_free_guidance(self):
        return self._guidance_scale > 1

    @property
    def joint_attention_kwargs(self):
        return self._joint_attention_kwargs

    @property
    def num_timesteps(self):
        return self._num_timesteps

    @property
    def interrupt(self):
        return self._interrupt

    @torch.no_grad()
    @torch.inference_mode()
    def generate_images(
        self,
        prompt: Union[str, List[str]] = None,
        prompt_2: Optional[Union[str, List[str]]] = None,
        height: Optional[int] = None,
        width: Optional[int] = None,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        negative_prompt_2: Optional[Union[str, List[str]]] = None,
        num_inference_steps: int = 8,
        timesteps: List[int] = None,
        eta: Optional[float] = 0.0,
        guidance_scale: float = 3.5,
        num_images_per_prompt: Optional[int] = 1,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
        max_sequence_length: int = 300,
        **kwargs,
    ):
        height = height or self.default_sample_size * self.vae_scale_factor
        width = width or self.default_sample_size * self.vae_scale_factor

        # 1. Check inputs
        self.check_inputs(
            prompt,
            prompt_2,
            height,
            width,
            negative_prompt=negative_prompt,
            negative_prompt_2=negative_prompt_2,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            pooled_prompt_embeds=pooled_prompt_embeds,
            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
            max_sequence_length=max_sequence_length,
        )

        self._guidance_scale = guidance_scale
        self._joint_attention_kwargs = joint_attention_kwargs
        self._interrupt = False

        # 2. Define call parameters
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        device = self._execution_device

        # 3. Encode prompts
        lora_scale = (
            self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
        )
        (
            prompt_embeds,
            pooled_prompt_embeds,
            text_ids,
            negative_prompt_embeds,
            negative_pooled_prompt_embeds,
        ) = self.encode_prompt(
            prompt=prompt,
            prompt_2=prompt_2,
            negative_prompt=negative_prompt,
            negative_prompt_2=negative_prompt_2,
            do_classifier_free_guidance=self.do_classifier_free_guidance,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            pooled_prompt_embeds=pooled_prompt_embeds,
            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
            device=device,
            num_images_per_prompt=num_images_per_prompt,
            max_sequence_length=max_sequence_length,
            lora_scale=lora_scale,
        )

        if self.do_classifier_free_guidance:
            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
            pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)

        # 4. Prepare latent variables
        num_channels_latents = self.transformer.config.in_channels // 4
        latents, latent_image_ids = self.prepare_latents(
            batch_size * num_images_per_prompt,
            num_channels_latents,
            height,
            width,
            prompt_embeds.dtype,
            device,
            generator,
            latents,
        )

        # 5. Prepare timesteps
        sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
        image_seq_len = latents.shape[1]
        mu = calculate_timestep_shift(image_seq_len)
        timesteps, num_inference_steps = prepare_timesteps(
            self.scheduler,
            num_inference_steps,
            device,
            timesteps,
            sigmas,
            mu=mu,
        )
        self._num_timesteps = len(timesteps)
        num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)

        # 6. Denoising loop
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                if self.interrupt:
                    continue

                # run the unconditional and conditional branches in a single batched forward pass
                latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents

                timestep = t.expand(latent_model_input.shape[0]).to(latents.dtype)

                if self.transformer.config.guidance_embeds:
                    guidance = torch.tensor([guidance_scale], device=device)
                    guidance = guidance.expand(latent_model_input.shape[0])
                else:
                    guidance = None

                noise_pred = self.transformer(
                    hidden_states=latent_model_input,
                    timestep=timestep / 1000,
                    guidance=guidance,
                    pooled_projections=pooled_prompt_embeds,
                    encoder_hidden_states=prompt_embeds,
                    txt_ids=text_ids,
                    img_ids=latent_image_ids,
                    joint_attention_kwargs=self.joint_attention_kwargs,
                    return_dict=False,
                )[0]

                if self.do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)

                # compute the previous noisy sample x_t -> x_t-1
                latents_dtype = latents.dtype
                latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]

                if latents.dtype != latents_dtype:
                    if torch.backends.mps.is_available():
                        # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
                        latents = latents.to(latents_dtype)

                # update the progress bar
                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()

        # Final image
        image = self._decode_latents_to_image(latents, height, width, output_type)
        self.maybe_free_model_hooks()
        torch.cuda.empty_cache()
        return image

    @torch.no_grad()
    @torch.inference_mode()
    def __call__(
        self,
        prompt: Union[str, List[str]] = None,
        prompt_2: Optional[Union[str, List[str]]] = None,
        height: Optional[int] = None,
        width: Optional[int] = None,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        negative_prompt_2: Optional[Union[str, List[str]]] = None,
        num_inference_steps: int = 8,
        timesteps: List[int] = None,
        eta: Optional[float] = 0.0,
        guidance_scale: float = 3.5,
        num_images_per_prompt: Optional[int] = 1,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
        max_sequence_length: int = 300,
        **kwargs,
    ):
        height = height or self.default_sample_size * self.vae_scale_factor
        width = width or self.default_sample_size * self.vae_scale_factor

        # 1. Check inputs
        self.check_inputs(
            prompt,
            prompt_2,
            height,
            width,
            negative_prompt=negative_prompt,
            negative_prompt_2=negative_prompt_2,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            pooled_prompt_embeds=pooled_prompt_embeds,
            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
            max_sequence_length=max_sequence_length,
        )

        self._guidance_scale = guidance_scale
        self._joint_attention_kwargs = joint_attention_kwargs
        self._interrupt = False

        # 2. Define call parameters
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        device = self._execution_device

        # 3. Encode prompts
        lora_scale = (
            self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
        )
        (
            prompt_embeds,
            pooled_prompt_embeds,
            text_ids,
            negative_prompt_embeds,
            negative_pooled_prompt_embeds,
        ) = self.encode_prompt(
            prompt=prompt,
            prompt_2=prompt_2,
            negative_prompt=negative_prompt,
            negative_prompt_2=negative_prompt_2,
            do_classifier_free_guidance=self.do_classifier_free_guidance,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            pooled_prompt_embeds=pooled_prompt_embeds,
            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
            device=device,
            num_images_per_prompt=num_images_per_prompt,
            max_sequence_length=max_sequence_length,
            lora_scale=lora_scale,
        )

        if self.do_classifier_free_guidance:
            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
            pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)

        # 4. Prepare latent variables
        num_channels_latents = self.transformer.config.in_channels // 4
        latents, latent_image_ids = self.prepare_latents(
            batch_size * num_images_per_prompt,
            num_channels_latents,
            height,
            width,
            prompt_embeds.dtype,
            device,
            generator,
            latents,
        )

        # 5. Prepare timesteps
        sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
        image_seq_len = latents.shape[1]
        mu = calculate_timestep_shift(image_seq_len)
        timesteps, num_inference_steps = prepare_timesteps(
            self.scheduler,
            num_inference_steps,
            device,
            timesteps,
            sigmas,
            mu=mu,
        )
        self._num_timesteps = len(timesteps)
        num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)

        # 6. Denoising loop
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                if self.interrupt:
                    continue

                # run the unconditional and conditional branches in a single batched forward pass
                latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents

                timestep = t.expand(latent_model_input.shape[0]).to(latents.dtype)

                if self.transformer.config.guidance_embeds:
                    guidance = torch.tensor([guidance_scale], device=device)
                    guidance = guidance.expand(latent_model_input.shape[0])
                else:
                    guidance = None

                noise_pred = self.transformer(
                    hidden_states=latent_model_input,
                    timestep=timestep / 1000,
                    guidance=guidance,
                    pooled_projections=pooled_prompt_embeds,
                    encoder_hidden_states=prompt_embeds,
                    txt_ids=text_ids,
                    img_ids=latent_image_ids,
                    joint_attention_kwargs=self.joint_attention_kwargs,
                    return_dict=False,
                )[0]

                if self.do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)

                # compute the previous noisy sample x_t -> x_t-1
                latents_dtype = latents.dtype
                latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]

                if latents.dtype != latents_dtype:
                    if torch.backends.mps.is_available():
                        # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
                        latents = latents.to(latents_dtype)

                # update the progress bar
                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()

        # Final image
        image = self._decode_latents_to_image(latents, height, width, output_type)
        self.maybe_free_model_hooks()
        torch.cuda.empty_cache()
        return image

    def _decode_latents_to_image(self, latents, height, width, output_type, vae=None):
        """Decodes the given latents into an image."""
        vae = vae or self.vae
        latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
        latents = (latents / vae.config.scaling_factor) + vae.config.shift_factor
        image = vae.decode(latents, return_dict=False)[0]
        return self.image_processor.postprocess(image, output_type=output_type)[0]
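For reference, a minimal usage sketch of the pipeline class removed by this commit, assuming the file is kept locally as pipeline7.py and pointed at a FLUX.1 checkpoint; the repo id, dtype, and prompts below are illustrative assumptions, not taken from this commit:

import torch
from pipeline7 import FluxWithCFGPipeline  # assumes the deleted module is available locally

# Load the custom pipeline from a FLUX.1 checkpoint (illustrative repo id).
pipe = FluxWithCFGPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16
)
pipe.to("cuda")

# __call__ runs the CFG denoising loop and returns a single PIL image (output_type="pil").
image = pipe(
    prompt="a watercolor fox in a birch forest",
    negative_prompt="blurry, low quality",
    num_inference_steps=8,
    guidance_scale=3.5,
    height=1024,
    width=1024,
)
image.save("fox.png")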