AlekseyCalvin committed
Commit 6d55428 · verified · 1 Parent(s): 2711b32

Update pipeline.py

Files changed (1)
  1. pipeline.py +70 -29
pipeline.py CHANGED
@@ -323,22 +323,22 @@ class FluxWithCFGPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFile
         text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
 
         return prompt_embeds, pooled_prompt_embeds, text_ids, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
-
-    def check_inputs(
-        self,
-        prompt,
-        prompt_2,
-        height,
-        width,
-        negative_prompt=None,
-        negative_prompt_2=None,
-        prompt_embeds=None,
-        negative_prompt_embeds=None,
-        pooled_prompt_embeds=None,
-        negative_pooled_prompt_embeds=None,
-        callback_on_step_end_tensor_inputs=None,
-        max_sequence_length=None,
-    ):
+
+    def check_inputs(
+        self,
+        prompt,
+        prompt_2,
+        height,
+        width,
+        negative_prompt=None,
+        negative_prompt_2=None,
+        prompt_embeds=None,
+        negative_prompt_embeds=None,
+        pooled_prompt_embeds=None,
+        negative_pooled_prompt_embeds=None,
+        callback_on_step_end_tensor_inputs=None,
+        max_sequence_length=None,
+    ):
         if height % 8 != 0 or width % 8 != 0:
             raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
 
@@ -368,6 +368,10 @@ class FluxWithCFGPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFile
         elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
             raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
 
+        if prompt_embeds is not None and pooled_prompt_embeds is None:
+            raise ValueError(
+                "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
+            )
         if negative_prompt is not None and negative_prompt_embeds is not None:
             raise ValueError(
                 f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
@@ -378,18 +382,11 @@ class FluxWithCFGPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFile
                 f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:"
                 f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
             )
-
-        if prompt_embeds is not None and pooled_prompt_embeds is None:
-            raise ValueError(
-                "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
-            )
         if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None:
             raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.")
 
         if max_sequence_length is not None and max_sequence_length > 512:
             raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
-
-        return prompt_embeds, negative_prompt_embeds
 
     @staticmethod
     def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
@@ -528,7 +525,7 @@ class FluxWithCFGPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFile
 
     @torch.no_grad()
     @torch.inference_mode()
-    def generate_image(
+    def generate_images(
        self,
        prompt: Union[str, List[str]] = None,
        prompt_2: Optional[Union[str, List[str]]] = None,
@@ -538,7 +535,7 @@ class FluxWithCFGPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFile
        negative_prompt_2: Optional[Union[str, List[str]]] = None,
        num_inference_steps: int = 8,
        timesteps: List[int] = None,
-       eta: float = 0.0,
+       eta: Optional[float] = 0.0,
        guidance_scale: float = 3.5,
        num_images_per_prompt: Optional[int] = 1,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
@@ -554,6 +551,7 @@ class FluxWithCFGPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFile
        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
        clip_skip: Optional[int] = None,
        max_sequence_length: int = 300,
+       **kwargs,
     ):
         height = height or self.default_sample_size * self.vae_scale_factor
         width = width or self.default_sample_size * self.vae_scale_factor
@@ -572,7 +570,6 @@ class FluxWithCFGPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFile
            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
            callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
            max_sequence_length=max_sequence_length,
-           lora_scale=lora_scale
        )
 
        self._guidance_scale = guidance_scale
@@ -595,6 +592,27 @@ class FluxWithCFGPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFile
        lora_scale = (
            self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
        )
+       (
+           prompt_embeds,
+           negative_prompt_embeds,
+           pooled_prompt_embeds,
+           negative_pooled_prompt_embeds,
+       ) = self.encode_prompt(
+           prompt=prompt,
+           prompt_2=prompt_2,
+           negative_prompt=negative_prompt,
+           negative_prompt_2=negative_prompt_2,
+           do_classifier_free_guidance=self.do_classifier_free_guidance,
+           prompt_embeds=prompt_embeds,
+           negative_prompt_embeds=negative_prompt_embeds,
+           pooled_prompt_embeds=pooled_prompt_embeds,
+           negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
+           device=device,
+           clip_skip=self.clip_skip,
+           num_images_per_prompt=num_images_per_prompt,
+           max_sequence_length=max_sequence_length,
+           lora_scale=lora_scale,
+       )
 
        if self.do_classifier_free_guidance:
            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
@@ -699,7 +717,9 @@ class FluxWithCFGPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFile
            return self._decode_latents_to_image(latents, height, width, output_type)
        self.maybe_free_model_hooks()
        torch.cuda.empty_cache()
-
+
+    @torch.no_grad()
+    @torch.inference_mode()
     def __call__(
        self,
        prompt: Union[str, List[str]] = None,
@@ -710,7 +730,7 @@ class FluxWithCFGPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFile
        negative_prompt_2: Optional[Union[str, List[str]]] = None,
        num_inference_steps: int = 8,
        timesteps: List[int] = None,
-       eta: float = 0.0,
+       eta: Optional[float] = 0.0,
        guidance_scale: float = 3.5,
        num_images_per_prompt: Optional[int] = 1,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
@@ -726,6 +746,7 @@ class FluxWithCFGPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFile
        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
        clip_skip: Optional[int] = None,
        max_sequence_length: int = 300,
+       **kwargs,
     ):
        height = height or self.default_sample_size * self.vae_scale_factor
        width = width or self.default_sample_size * self.vae_scale_factor
@@ -744,7 +765,6 @@ class FluxWithCFGPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFile
            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
            callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
            max_sequence_length=max_sequence_length,
-           lora_scale=lora_scale
        )
 
        self._guidance_scale = guidance_scale
@@ -767,6 +787,27 @@ class FluxWithCFGPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFile
        lora_scale = (
            self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
        )
+       (
+           prompt_embeds,
+           negative_prompt_embeds,
+           pooled_prompt_embeds,
+           negative_pooled_prompt_embeds,
+       ) = self.encode_prompt(
+           prompt=prompt,
+           prompt_2=prompt_2,
+           negative_prompt=negative_prompt,
+           negative_prompt_2=negative_prompt_2,
+           do_classifier_free_guidance=self.do_classifier_free_guidance,
+           prompt_embeds=prompt_embeds,
+           negative_prompt_embeds=negative_prompt_embeds,
+           pooled_prompt_embeds=pooled_prompt_embeds,
+           negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
+           device=device,
+           clip_skip=self.clip_skip,
+           num_images_per_prompt=num_images_per_prompt,
+           max_sequence_length=max_sequence_length,
+           lora_scale=lora_scale,
+       )
 
        if self.do_classifier_free_guidance:
            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
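For orientation, here is a minimal usage sketch of the pipeline after this commit. The checkpoint id, device, dtype, and output handling are illustrative assumptions and are not part of the diff above; the parameter names come from the signatures shown in the hunks.

import torch
from pipeline import FluxWithCFGPipeline  # this repository's pipeline.py

# Hypothetical checkpoint id; substitute whichever Flux weights this pipeline is paired with.
pipe = FluxWithCFGPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell",
    torch_dtype=torch.bfloat16,
).to("cuda")

# __call__ (and the renamed generate_images) now tolerate extra keyword arguments
# via **kwargs and keep `eta` as an Optional[float].
result = pipe(
    prompt="a watercolor fox in a birch forest",
    negative_prompt="blurry, low quality",
    num_inference_steps=8,
    guidance_scale=3.5,
    height=1024,
    width=1024,
    max_sequence_length=300,
    generator=torch.Generator("cuda").manual_seed(0),
)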
 
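Continuing the sketch above, a hedged illustration of the reordered validation: per the hunks, check_inputs now raises on a missing pooled_prompt_embeds before the negative-prompt checks run and no longer returns the embeddings. The tensor shape below is a placeholder, not taken from the file.

import torch

dummy_prompt_embeds = torch.zeros(1, 512, 4096)  # hypothetical shape, for illustration only

try:
    pipe.check_inputs(
        prompt=None,
        prompt_2=None,
        height=1024,
        width=1024,
        prompt_embeds=dummy_prompt_embeds,
        pooled_prompt_embeds=None,  # deliberately missing -> ValueError
    )
except ValueError as err:
    print(err)  # complains that pooled_prompt_embeds must accompany prompt_embeds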