AlekseyCalvin committed
Commit cb4a9fb
1 Parent(s): caaeec1

Update pipeline.py

Files changed (1):
  1. pipeline.py +45 -59
pipeline.py CHANGED
@@ -292,11 +292,9 @@ class FluxWithCFGPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFile
         unscale_lora_layers(self.text_encoder_2, lora_scale)

         dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype
-        text_ids = torch.zeros(batch_size, prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
-        text_ids = text_ids.repeat(num_images_per_prompt, 1, 1)
-        negative_text_ids = torch.zeros(batch_size, negative_prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
+        text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)

-        return prompt_embeds, pooled_prompt_embeds, text_ids, negative_prompt_embeds, negative_pooled_prompt_embeds, negative_text_ids
+        return prompt_embeds, pooled_prompt_embeds, text_ids, negative_prompt_embeds, negative_pooled_prompt_embeds

     def check_inputs(
         self,
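For orientation, a minimal sketch of what the text_ids change in this hunk amounts to (the shapes below are illustrative, not taken from the commit): the removed code built one (seq_len, 3) ID block per sample in the batch, while the new code keeps a single 2D (seq_len, 3) tensor shared across the batch, which is the shape more recent diffusers Flux transformers appear to expect, and it also stops returning a separate negative_text_ids.

import torch

# Illustrative shapes only: 2 prompts, 512 text tokens, 4096-dim T5 embeddings.
prompt_embeds = torch.randn(2, 512, 4096)

# Removed behaviour: a 3D (batch, seq_len, 3) tensor of text position IDs.
old_text_ids = torch.zeros(prompt_embeds.shape[0], prompt_embeds.shape[1], 3)

# New behaviour: a single 2D (seq_len, 3) tensor shared by every sample.
new_text_ids = torch.zeros(prompt_embeds.shape[1], 3)

print(old_text_ids.shape)  # torch.Size([2, 512, 3])
print(new_text_ids.shape)  # torch.Size([512, 3])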
@@ -485,13 +483,11 @@ class FluxWithCFGPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFile
         self,
         prompt: Union[str, List[str]] = None,
         prompt_2: Optional[Union[str, List[str]]] = None,
+        negative_prompt: Union[str, List[str]] = None,
         height: Optional[int] = None,
         width: Optional[int] = None,
-        negative_prompt: Optional[Union[str, List[str]]] = None,
-        negative_prompt_2: Optional[Union[str, List[str]]] = None,
         num_inference_steps: int = 8,
         timesteps: List[int] = None,
-        eta: Optional[float] = 0.0,
         guidance_scale: float = 3.5,
         device: Optional[int] = None,
         num_images_per_prompt: Optional[int] = 1,
@@ -499,14 +495,13 @@ class FluxWithCFGPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFile
         latents: Optional[torch.FloatTensor] = None,
         prompt_embeds: Optional[torch.FloatTensor] = None,
         pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
-        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
-        negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
         output_type: Optional[str] = "pil",
         cfg: Optional[bool] = True,
         return_dict: bool = True,
         joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+        callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
         max_sequence_length: int = 512,
-        **kwargs,
     ):
         height = height or self.default_sample_size * self.vae_scale_factor
         width = width or self.default_sample_size * self.vae_scale_factor
@@ -518,9 +513,8 @@ class FluxWithCFGPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFile
             height,
             width,
             prompt_embeds=prompt_embeds,
-            negative_prompt_embeds=negative_prompt_embeds,
             pooled_prompt_embeds=pooled_prompt_embeds,
-            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
+            callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
             max_sequence_length=max_sequence_length,
         )
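Taken together, the signature hunks above drop negative_prompt_2, eta, the explicit negative embedding inputs, and **kwargs, move negative_prompt up next to prompt_2, and add the standard callback_on_step_end hooks. A hypothetical call against the updated signature might look like the sketch below; the import path and checkpoint id are placeholders, not taken from the commit.

import torch
from pipeline import FluxWithCFGPipeline  # assumed module layout for this Space

# Placeholder checkpoint id; use whichever Flux checkpoint the Space actually loads.
pipe = FluxWithCFGPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16
).to("cuda")

def on_step_end(pipe, step, timestep, callback_kwargs):
    # Only tensors named in callback_on_step_end_tensor_inputs are handed in.
    print(f"step {step}: latents {tuple(callback_kwargs['latents'].shape)}")
    return callback_kwargs

image = pipe(
    prompt="a watercolor fox in a birch forest",
    negative_prompt="blurry, low quality",
    num_inference_steps=8,
    guidance_scale=3.5,
    callback_on_step_end=on_step_end,
    callback_on_step_end_tensor_inputs=["latents"],
).images[0]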
 
@@ -546,21 +540,16 @@ class FluxWithCFGPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFile
             pooled_prompt_embeds,
             text_ids,
             negative_prompt_embeds,
-            negative_pooled_prompt_embeds,
-            negative_text_ids,
+            negative_pooled_prompt_embeds
         ) = self.encode_prompt(
             prompt=prompt,
             prompt_2=prompt_2,
             num_images_per_prompt=num_images_per_prompt,
             max_sequence_length=max_sequence_length,
-            do_classifier_free_guidance=self.do_classifier_free_guidance,
             device=device,
             negative_prompt=negative_prompt,
-            negative_prompt_2=negative_prompt_2,
             prompt_embeds=prompt_embeds,
-            negative_prompt_embeds=negative_prompt_embeds,
             pooled_prompt_embeds=pooled_prompt_embeds,
-            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
             lora_scale=lora_scale,
         )
@@ -607,67 +596,64 @@ class FluxWithCFGPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFile
             for i, t in enumerate(timesteps):
                 if self.interrupt:
                     continue
-
+
+                # expand the latents if we are doing classifier free guidance
                 latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
-
-                timestep = t.expand(latent_model_input.shape[0]).to(latents.dtype)
+                # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+                timestep = t.expand(latent_model_input.shape[0])

-                if self.transformer.config.guidance_embeds:
-                    guidance = torch.tensor([guidance_scale], device=device)
-                    guidance = guidance.expand(latents.shape[0])
-                else:
-                    guidance = None
-
-                noise_pred_text = self.transformer(
+                noise_pred = self.transformer(
                     hidden_states=latent_model_input,
                     timestep=timestep / 1000,
-                    guidance=guidance,
-                    pooled_projections=pooled_prompt_embeds.shape[1],
+                    pooled_projections=pooled_prompt_embeds,
                     encoder_hidden_states=prompt_embeds,
                     txt_ids=text_ids,
                     img_ids=latent_image_ids,
                     joint_attention_kwargs=self.joint_attention_kwargs,
                     return_dict=False,
                 )[0]
-                noise_pred_uncond = self.transformer(
-                    hidden_states=latents,
-                    timestep=timestep / 1000,
-                    guidance=guidance,
-                    pooled_projections=negative_pooled_prompt_embeds.shape[1],
-                    encoder_hidden_states=negative_prompt_embeds,
-                    txt_ids=negative_text_ids,
-                    img_ids=latent_image_ids,
-                    joint_attention_kwargs=self.joint_attention_kwargs,
-                    return_dict=False,
-                )[0]
-
+
                 if self.do_classifier_free_guidance:
-                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(0)
-                    noise_pred = noise_pred_uncond + self._guidance_scale * (noise_pred_text - noise_pred_uncond)
-                else: noise_pred = noise_pred_uncond + self._guidance_scale * (noise_pred_text - noise_pred_uncond)
-
+                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+                    noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
+
                 # compute the previous noisy sample x_t -> x_t-1
                 latents_dtype = latents.dtype
                 latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
-
+
                 if latents.dtype != latents_dtype:
                     if torch.backends.mps.is_available():
                         # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
                         latents = latents.to(latents_dtype)
+
+                if callback_on_step_end is not None:
+                    callback_kwargs = {}
+                    for k in callback_on_step_end_tensor_inputs:
+                        callback_kwargs[k] = locals()[k]
+                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+                    latents = callback_outputs.pop("latents", latents)
+                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)

-                # call the callback, if provided
                 if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                     progress_bar.update()
+
+                if XLA_AVAILABLE:
+                    xm.mark_step()

-            # Final image
-            return self._decode_latents_to_image(latents, height, width, output_type)
+        if output_type == "latent":
+            image = latents
+
+        else:
+            latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
+            latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
+            image = self.vae.decode(latents, return_dict=False)[0]
+            image = self.image_processor.postprocess(image, output_type=output_type)
+
+        # Offload all models
         self.maybe_free_model_hooks()
-        torch.cuda.empty_cache()
-
-    def _decode_latents_to_image(self, latents, height, width, output_type, vae=None):
-        """Decodes the given latents into an image."""
-        vae = vae or self.vae
-        latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
-        latents = (latents / vae.config.scaling_factor) + vae.config.shift_factor
-        image = vae.decode(latents, return_dict=False)[0]
-        return self.image_processor.postprocess(image, output_type=output_type)[0]
+
+        if not return_dict:
+            return (image,)
+
+        return FluxPipelineOutput(images=image)
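The rewritten loop replaces the two separate transformer passes (conditional and unconditional) with the usual batched classifier-free-guidance pattern: double the latents, run one forward pass, then split the prediction with chunk(2). A stand-alone sketch of that combination step, with toy tensors standing in for real model outputs (in the standard diffusers formulation the doubled latents are paired with prompt and negative-prompt embeddings concatenated along the batch dimension):

import torch

guidance_scale = 3.5
latents = torch.randn(1, 4096, 64)  # toy packed Flux latents: (batch, tokens, channels)

# Batched CFG: unconditional and conditional inputs share one forward pass.
latent_model_input = torch.cat([latents] * 2)
noise_pred = torch.randn_like(latent_model_input)  # stands in for the transformer output

# Split back into the unconditional and text-conditioned halves and recombine.
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

print(noise_pred.shape)  # torch.Size([1, 4096, 64])

The final-image path also moves from the removed _decode_latents_to_image helper into __call__ itself, adding an output_type == "latent" branch and returning a FluxPipelineOutput (or a plain tuple when return_dict=False); this assumes FluxPipelineOutput, XLA_AVAILABLE, and xm are imported elsewhere in pipeline.py, since the import section is not part of this diff.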
 