AlekseyCalvin committed on
Commit 07ec991 · verified · 1 Parent(s): 2395f33

Delete pipeline3.py

Files changed (1)
  1. pipeline3.py +0 -626
pipeline3.py DELETED
@@ -1,626 +0,0 @@
import torch
import numpy as np
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast, CLIPTextModelWithProjection
from diffusers import FlowMatchEulerDiscreteScheduler, AutoPipelineForImage2Image, FluxPipeline, FluxTransformer2DModel
from diffusers import StableDiffusion3Pipeline, AutoencoderKL, DiffusionPipeline
from diffusers.image_processor import VaeImageProcessor
from diffusers.loaders import FluxLoraLoaderMixin, FromSingleFileMixin, SD3LoraLoaderMixin
from diffusers.utils import (
    USE_PEFT_BACKEND,
    is_torch_xla_available,
    logging,
    replace_example_docstring,
    scale_lora_layers,
    unscale_lora_layers,
)
from diffusers.utils.torch_utils import randn_tensor
from diffusers.pipelines.flux.pipeline_output import FluxPipelineOutput
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from PIL import Image
from diffusers.pipelines.flux.pipeline_flux import calculate_shift, retrieve_timesteps, FluxTransformer2DModel

if is_torch_xla_available():
    import torch_xla.core.xla_model as xm

    XLA_AVAILABLE = True
else:
    XLA_AVAILABLE = False

# Module-level logger used by the truncation warnings below
logger = logging.get_logger(__name__)

# Constants for shift calculation
BASE_SEQ_LEN = 256
MAX_SEQ_LEN = 4096
BASE_SHIFT = 0.5
MAX_SHIFT = 1.2

# Helper functions
def calculate_timestep_shift(image_seq_len: int) -> float:
    """Calculates the timestep shift (mu) based on the image sequence length."""
    m = (MAX_SHIFT - BASE_SHIFT) / (MAX_SEQ_LEN - BASE_SEQ_LEN)
    b = BASE_SHIFT - m * BASE_SEQ_LEN
    mu = image_seq_len * m + b
    return mu

def prepare_timesteps(
    scheduler: FlowMatchEulerDiscreteScheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    sigmas: Optional[List[float]] = None,
    mu: Optional[float] = None,
) -> Tuple[torch.Tensor, int]:
    """Prepares the timesteps for the diffusion process."""
    if timesteps is not None and sigmas is not None:
        raise ValueError("Only one of `timesteps` or `sigmas` can be passed.")

    if timesteps is not None:
        scheduler.set_timesteps(timesteps=timesteps, device=device)
    elif sigmas is not None:
        # `mu` must be forwarded here as well; schedulers configured with
        # `use_dynamic_shifting=True` (as in FLUX) raise if only sigmas are given.
        scheduler.set_timesteps(sigmas=sigmas, device=device, mu=mu)
    else:
        scheduler.set_timesteps(num_inference_steps, device=device, mu=mu)

    timesteps = scheduler.timesteps
    num_inference_steps = len(timesteps)
    return timesteps, num_inference_steps

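# Illustrative check of the shift schedule above, assuming the usual Flux latent layout
# (8x VAE downsampling plus 2x2 patch packing, i.e. one token per 16x16 pixel patch):
#   512x512 image  -> image_seq_len = (512 // 16) ** 2 = 1024  -> calculate_timestep_shift(1024) ~= 0.64
#   1024x1024 image -> image_seq_len = 4096 = MAX_SEQ_LEN       -> calculate_timestep_shift(4096) == 1.2
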
# FLUX pipeline function
class FluxWithCFGPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFileMixin):
    def __init__(
        self,
        scheduler: FlowMatchEulerDiscreteScheduler,
        vae: AutoencoderKL,
        text_encoder: CLIPTextModel,
        tokenizer: CLIPTokenizer,
        text_encoder_2: T5EncoderModel,
        tokenizer_2: T5TokenizerFast,
        transformer: FluxTransformer2DModel,
    ):
        super().__init__()

        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
            text_encoder_2=text_encoder_2,
            tokenizer=tokenizer,
            tokenizer_2=tokenizer_2,
            transformer=transformer,
            scheduler=scheduler,
        )
        self.vae_scale_factor = (
            2 ** (len(self.vae.config.block_out_channels)) if hasattr(self, "vae") and self.vae is not None else 16
        )
        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
        self.tokenizer_max_length = (
            self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
        )
        self.default_sample_size = 64

    def _get_t5_prompt_embeds(
        self,
        prompt: Union[str, List[str]] = None,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: int = 1,
        max_sequence_length: int = 512,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ):
        device = device or self._execution_device
        dtype = dtype or self.text_encoder.dtype

        prompt = [prompt] if isinstance(prompt, str) else prompt
        batch_size = len(prompt)

        text_inputs = self.tokenizer_2(
            prompt,
            padding="max_length",
            max_length=max_sequence_length,
            truncation=True,
            return_length=False,
            return_overflowing_tokens=False,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids
        untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids

        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
            removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
            logger.warning(
                "The following part of your input was truncated because `max_sequence_length` is set to "
                f" {max_sequence_length} tokens: {removed_text}"
            )

        prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0]

        dtype = self.text_encoder_2.dtype
        prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)

        _, seq_len, _ = prompt_embeds.shape

        # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
        prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

        return prompt_embeds

    def _get_clip_prompt_embeds(
        self,
        prompt: Union[str, List[str]],
        num_images_per_prompt: int = 1,
        device: Optional[torch.device] = None,
    ):
        device = device or self._execution_device

        prompt = [prompt] if isinstance(prompt, str) else prompt
        batch_size = len(prompt)

        text_inputs = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.tokenizer_max_length,
            truncation=True,
            return_overflowing_tokens=False,
            return_length=False,
            return_tensors="pt",
        )

        text_input_ids = text_inputs.input_ids
        untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
            removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
            logger.warning(
                "The following part of your input was truncated because CLIP can only handle sequences up to"
                f" {self.tokenizer_max_length} tokens: {removed_text}"
            )
        prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False)

        # Use pooled output of CLIPTextModel
        prompt_embeds = prompt_embeds.pooler_output
        prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)

        # duplicate text embeddings for each generation per prompt, using mps friendly method
        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt)
        prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1)

        return prompt_embeds

    def encode_prompt(
        self,
        prompt: Union[str, List[str]],
        prompt_2: Union[str, List[str]],
        do_classifier_free_guidance: bool = True,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        negative_prompt_2: Optional[Union[str, List[str]]] = None,
        device: Optional[torch.device] = None,
        num_images_per_prompt: int = 1,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_prompt_embeds: Optional[torch.Tensor] = None,
        pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_prompt_attention_mask: Optional[torch.Tensor] = None,
        max_sequence_length: int = 512,
        lora_scale: Optional[float] = None,
        adapter_weights: Optional[float] = None,
    ):
        r"""
        Args:
            prompt (`str` or `List[str]`, *optional*):
                prompt to be encoded
            prompt_2 (`str` or `List[str]`, *optional*):
                The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
                used in all text encoders.
            device: (`torch.device`):
                torch device
            num_images_per_prompt (`int`):
                number of images that should be generated per prompt
            prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from the `prompt` input argument.
            pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
                If not provided, pooled text embeddings will be generated from the `prompt` input argument.
            lora_scale (`float`, *optional*):
                A LoRA scale that will be applied to all LoRA layers of the text encoders if LoRA layers are loaded.
        """
        device = device or self._execution_device

        # set lora scale so that monkey patched LoRA
        # function of text encoder can correctly access it
        if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin):
            self._lora_scale = lora_scale

            # dynamically adjust the LoRA scale
            if self.text_encoder is not None and USE_PEFT_BACKEND:
                scale_lora_layers(self.text_encoder, lora_scale)
            if self.text_encoder_2 is not None and USE_PEFT_BACKEND:
                scale_lora_layers(self.text_encoder_2, lora_scale)

        prompt = [prompt] if isinstance(prompt, str) else prompt

        if prompt_embeds is None:
            prompt_2 = prompt_2 or prompt
            prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2

            # We only use the pooled prompt output from the CLIPTextModel
            pooled_prompt_embeds = self._get_clip_prompt_embeds(
                prompt=prompt,
                device=device,
                num_images_per_prompt=num_images_per_prompt,
            )
            prompt_embeds = self._get_t5_prompt_embeds(
                prompt=prompt_2,
                num_images_per_prompt=num_images_per_prompt,
                max_sequence_length=max_sequence_length,
                device=device,
            )

        dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype
        text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)

        return prompt_embeds, pooled_prompt_embeds, text_ids

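    # Note on shapes (illustrative, assuming the standard FLUX.1 text encoders: CLIP-L with a 768-dim
    # pooled output and T5-XXL with a 4096-dim hidden state): for a batch of B prompts and
    # max_sequence_length L, encode_prompt returns prompt_embeds of shape (B * num_images_per_prompt, L, 4096),
    # pooled_prompt_embeds of shape (B * num_images_per_prompt, 768), and zero-initialized text_ids of shape (L, 3).
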
    def check_inputs(
        self,
        prompt,
        prompt_2,
        height,
        width,
        negative_prompt=None,
        lora_scale=None,
        prompt_embeds=None,
        pooled_prompt_embeds=None,
        negative_prompt_embeds=None,
        negative_prompt_attention_mask=None,
        callback_on_step_end_tensor_inputs=None,
        max_sequence_length=None,
    ):
        if height % 8 != 0 or width % 8 != 0:
            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

        if callback_on_step_end_tensor_inputs is not None and not all(
            k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
        ):
            raise ValueError(
                f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
            )

        if prompt is not None and prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
                " only forward one of the two."
            )
        elif prompt_2 is not None and prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
                " only forward one of the two."
            )
        elif prompt is None and prompt_embeds is None:
            raise ValueError(
                "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
            )
        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
        elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
            raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")

        if prompt_embeds is not None and pooled_prompt_embeds is None:
            raise ValueError(
                "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
            )
        if negative_prompt_embeds is not None and negative_prompt_attention_mask is None:
            raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.")

        if max_sequence_length is not None and max_sequence_length > 512:
            raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")

    @staticmethod
    def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
        latent_image_ids = torch.zeros(height // 2, width // 2, 3)
        latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None]
        latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :]

        latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape

        latent_image_ids = latent_image_ids.reshape(
            latent_image_id_height * latent_image_id_width, latent_image_id_channels
        )

        return latent_image_ids.to(device=device, dtype=dtype)

    @staticmethod
    def _pack_latents(latents, batch_size, num_channels_latents, height, width):
        latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
        latents = latents.permute(0, 2, 4, 1, 3, 5)
        latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)

        return latents

    @staticmethod
    def _unpack_latents(latents, height, width, vae_scale_factor):
        batch_size, num_patches, channels = latents.shape

        height = height // vae_scale_factor
        width = width // vae_scale_factor

        latents = latents.view(batch_size, height, width, channels // 4, 2, 2)
        latents = latents.permute(0, 3, 1, 4, 2, 5)

        latents = latents.reshape(batch_size, channels // (2 * 2), height * 2, width * 2)

        return latents

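    # Shape walkthrough (illustrative, assuming a 1024x1024 image and the 16-channel Flux VAE latent space):
    # the VAE encodes to (B, 16, 128, 128); _pack_latents folds each 2x2 spatial patch into the channel axis,
    # giving (B, 64 * 64, 64) = (B, 4096, 64) tokens for the transformer; _unpack_latents inverts this exactly,
    # restoring (B, 16, 128, 128) before VAE decoding.
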
    def enable_vae_slicing(self):
        r"""
        Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
        compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
        """
        self.vae.enable_slicing()

    def disable_vae_slicing(self):
        r"""
        Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
        computing decoding in one step.
        """
        self.vae.disable_slicing()

    def enable_vae_tiling(self):
        r"""
        Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
        compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
        processing larger images.
        """
        self.vae.enable_tiling()

    def disable_vae_tiling(self):
        r"""
        Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
        computing decoding in one step.
        """
        self.vae.disable_tiling()

    def prepare_latents(
        self,
        batch_size,
        num_channels_latents,
        height,
        width,
        dtype,
        device,
        generator,
        latents=None,
    ):
        height = 2 * (int(height) // self.vae_scale_factor)
        width = 2 * (int(width) // self.vae_scale_factor)

        shape = (batch_size, num_channels_latents, height, width)

        if latents is not None:
            latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype)
            return latents.to(device=device, dtype=dtype), latent_image_ids

        if isinstance(generator, list) and len(generator) != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )

        latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
        latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)

        latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype)

        return latents, latent_image_ids

    @property
    def guidance_scale(self):
        return self._guidance_scale

    @property
    def joint_attention_kwargs(self):
        return self._joint_attention_kwargs

    @property
    def num_timesteps(self):
        return self._num_timesteps

    @property
    def interrupt(self):
        return self._interrupt

    @torch.no_grad()
    def __call__(
        self,
        prompt: Union[str, List[str]] = None,
        prompt_2: Optional[Union[str, List[str]]] = None,
        height: Optional[int] = None,
        width: Optional[int] = None,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        negative_prompt_2: Optional[Union[str, List[str]]] = None,
        num_inference_steps: int = 4,
        timesteps: List[int] = None,
        guidance_scale: float = 3.5,
        lora_scale: Optional[float] = None,
        num_images_per_prompt: Optional[int] = 1,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
        max_sequence_length: int = 300,
    ):
        height = height or self.default_sample_size * self.vae_scale_factor
        width = width or self.default_sample_size * self.vae_scale_factor

        # 1. Check inputs
        self.check_inputs(
            prompt,
            prompt_2,
            height,
            width,
            negative_prompt=negative_prompt,
            prompt_embeds=prompt_embeds,
            pooled_prompt_embeds=pooled_prompt_embeds,
            max_sequence_length=max_sequence_length,
        )

        self._guidance_scale = guidance_scale
        self._joint_attention_kwargs = joint_attention_kwargs
        self._interrupt = False

        # 2. Define call parameters
        batch_size = 1 if isinstance(prompt, str) else len(prompt)
        device = "cuda" if torch.cuda.is_available() else "cpu"

        # 3. Encode prompt
        lora_scale = joint_attention_kwargs.get("scale", None) if joint_attention_kwargs is not None else None
        prompt_embeds, pooled_prompt_embeds, text_ids = self.encode_prompt(
            prompt=prompt,
            prompt_2=prompt_2,
            prompt_embeds=prompt_embeds,
            pooled_prompt_embeds=pooled_prompt_embeds,
            device=device,
            num_images_per_prompt=num_images_per_prompt,
            max_sequence_length=max_sequence_length,
            lora_scale=lora_scale,
        )
        # Fall back to an empty negative prompt so the negative branch can always be encoded
        negative_prompt = negative_prompt if negative_prompt is not None else ""
        negative_prompt_embeds, negative_pooled_prompt_embeds, negative_text_ids = self.encode_prompt(
            prompt=negative_prompt,
            prompt_2=negative_prompt_2,
            prompt_embeds=negative_prompt_embeds,
            pooled_prompt_embeds=negative_pooled_prompt_embeds,
            device=device,
            num_images_per_prompt=num_images_per_prompt,
            max_sequence_length=max_sequence_length,
            lora_scale=lora_scale,
        )

        # 4. Prepare latent variables
        num_channels_latents = self.transformer.config.in_channels // 4
        latents, latent_image_ids = self.prepare_latents(
            batch_size * num_images_per_prompt,
            num_channels_latents,
            height,
            width,
            prompt_embeds.dtype,
            device,
            generator,
            latents,
        )

        # 5. Prepare timesteps
        sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
        image_seq_len = latents.shape[1]
        mu = calculate_timestep_shift(image_seq_len)
        timesteps, num_inference_steps = prepare_timesteps(
            self.scheduler,
            num_inference_steps,
            device,
            timesteps,
            sigmas,
            mu=mu,
        )
        self._num_timesteps = len(timesteps)

        # Handle guidance (the embedded guidance signal is only used when the transformer was
        # trained with guidance distillation, i.e. `config.guidance_embeds` is True)
        guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float16).expand(latents.shape[0]) if self.transformer.config.guidance_embeds else None

        # 6. Denoising loop
        for i, t in enumerate(timesteps):
            if self.interrupt:
                continue

            timestep = t.expand(latents.shape[0]).to(latents.dtype)

            # Conditional (text-guided) prediction
            noise_pred_text = self.transformer(
                hidden_states=latents,
                timestep=timestep / 1000,
                guidance=guidance,
                pooled_projections=pooled_prompt_embeds,
                encoder_hidden_states=prompt_embeds,
                txt_ids=text_ids,
                img_ids=latent_image_ids,
                joint_attention_kwargs=self.joint_attention_kwargs,
                return_dict=False,
            )[0]

            # Unconditional / negative-prompt prediction
            noise_pred_uncond = self.transformer(
                hidden_states=latents,
                timestep=timestep / 1000,
                guidance=guidance,
                pooled_projections=negative_pooled_prompt_embeds,
                encoder_hidden_states=negative_prompt_embeds,
                txt_ids=negative_text_ids,
                img_ids=latent_image_ids,
                joint_attention_kwargs=self.joint_attention_kwargs,
                return_dict=False,
            )[0]

            # Classifier-free guidance: push the prediction away from the unconditional branch
            noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)

            latents_dtype = latents.dtype
            latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
            torch.cuda.empty_cache()

        # Final image
        image = self._decode_latents_to_image(latents, height, width, output_type)
        self.maybe_free_model_hooks()
        torch.cuda.empty_cache()
        return image

    def _decode_latents_to_image(self, latents, height, width, output_type, vae=None):
        """Decodes the given latents into an image."""
        vae = vae or self.vae
        latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
        latents = (latents / vae.config.scaling_factor) + vae.config.shift_factor
        image = vae.decode(latents, return_dict=False)[0]
        return self.image_processor.postprocess(image, output_type=output_type)[0]
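
For reference, a minimal sketch of how this pipeline class was presumably meant to be used, assuming a standard FLUX.1 checkpoint (e.g. black-forest-labs/FLUX.1-dev) whose components match the __init__ signature above:

    import torch
    from pipeline3 import FluxWithCFGPipeline

    # Load the custom pipeline class directly from a Flux checkpoint; components are
    # matched by name (scheduler, vae, text encoders, tokenizers, transformer).
    pipe = FluxWithCFGPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
    pipe.to("cuda")

    # __call__ runs true classifier-free guidance (two transformer passes per step)
    # and returns a single PIL image.
    image = pipe(
        prompt="a watercolor fox in a snowy forest",
        negative_prompt="blurry, low quality",
        num_inference_steps=8,
        guidance_scale=3.5,
        height=1024,
        width=1024,
        generator=torch.Generator("cpu").manual_seed(0),
    )
    image.save("fox.png")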