AlekseyCalvin committed: Update pipeline.py

pipeline.py CHANGED (+70 -29)
@@ -323,22 +323,22 @@ class FluxWithCFGPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFile
         text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)

         return prompt_embeds, pooled_prompt_embeds, text_ids, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+
+    def check_inputs(
+        self,
+        prompt,
+        prompt_2,
+        height,
+        width,
+        negative_prompt=None,
+        negative_prompt_2=None,
+        prompt_embeds=None,
+        negative_prompt_embeds=None,
+        pooled_prompt_embeds=None,
+        negative_pooled_prompt_embeds=None,
+        callback_on_step_end_tensor_inputs=None,
+        max_sequence_length=None,
+    ):
         if height % 8 != 0 or width % 8 != 0:
             raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
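The hunk above replaces a run of blank lines with a `check_inputs` validator. A minimal sketch of its effect, assuming `pipe` is an instantiated FluxWithCFGPipeline (the variable name is illustrative, not from the commit):

pipe.check_inputs(
    prompt="a photo of a cat",
    prompt_2=None,
    height=1023,  # not divisible by 8
    width=1024,
)
# ValueError: `height` and `width` have to be divisible by 8 but are 1023 and 1024.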
@@ -368,6 +368,10 @@ class FluxWithCFGPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFile
         elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
             raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")

+        if prompt_embeds is not None and pooled_prompt_embeds is None:
+            raise ValueError(
+                "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
+            )
         if negative_prompt is not None and negative_prompt_embeds is not None:
             raise ValueError(
                 f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
@@ -378,18 +382,11 @@ class FluxWithCFGPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFile
                 f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:"
                 f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
             )
-
-        if prompt_embeds is not None and pooled_prompt_embeds is None:
-            raise ValueError(
-                "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
-            )
         if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None:
             raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.")

         if max_sequence_length is not None and max_sequence_length > 512:
             raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
-
-        return prompt_embeds, negative_prompt_embeds

     @staticmethod
     def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
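Two points on the hunk above. The `prompt_embeds`/`pooled_prompt_embeds` check is not lost; it moves up ahead of the negative-prompt checks (see the `@@ -368,6 +368,10 @@` hunk). Dropping `return prompt_embeds, negative_prompt_embeds` also turns `check_inputs` into a pure validator, so callers must keep their own references to the embeddings. Note too that the retained error message names `negative_prompt_attention_mask` while the condition tests `negative_pooled_prompt_embeds`; a message matching the condition (my suggestion, not part of the commit) would be:

if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None:
    # Suggested wording; the committed message mentions an attention mask instead.
    raise ValueError("Must provide `negative_pooled_prompt_embeds` when specifying `negative_prompt_embeds`.")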
@@ -528,7 +525,7 @@ class FluxWithCFGPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFile

     @torch.no_grad()
     @torch.inference_mode()
-    def
+    def generate_images(
         self,
         prompt: Union[str, List[str]] = None,
         prompt_2: Optional[Union[str, List[str]]] = None,
@@ -538,7 +535,7 @@ class FluxWithCFGPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFile
         negative_prompt_2: Optional[Union[str, List[str]]] = None,
         num_inference_steps: int = 8,
         timesteps: List[int] = None,
-        eta: float = 0.0,
+        eta: Optional[float] = 0.0,
         guidance_scale: float = 3.5,
         num_images_per_prompt: Optional[int] = 1,
         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
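`eta` is the stochasticity knob used by DDIM-style schedulers elsewhere in diffusers; Flux pipelines run a flow-matching scheduler, and nothing in these hunks consumes the argument, so it appears to be accepted purely for signature compatibility. A hypothetical call:

images = pipe.generate_images(prompt="a photo of a cat", eta=0.0)  # eta accepted but, per these hunks, unused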
@@ -554,6 +551,7 @@ class FluxWithCFGPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFile
         callback_on_step_end_tensor_inputs: List[str] = ["latents"],
         clip_skip: Optional[int] = None,
         max_sequence_length: int = 300,
+        **kwargs,
     ):
         height = height or self.default_sample_size * self.vae_scale_factor
         width = width or self.default_sample_size * self.vae_scale_factor
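With `**kwargs` in the signature, stray keyword arguments (for example, options forwarded wholesale from a UI) are absorbed instead of raising TypeError. The flip side is that misspelled real parameters are silently ignored too. A sketch with a made-up argument name:

images = pipe.generate_images(
    prompt="a photo of a cat",
    num_inference_steps=8,
    some_ui_flag=True,  # hypothetical; swallowed by **kwargs rather than raising
)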
@@ -572,7 +570,6 @@ class FluxWithCFGPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFile
             negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
             callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
             max_sequence_length=max_sequence_length,
-            lora_scale=lora_scale
         )

         self._guidance_scale = guidance_scale
@@ -595,6 +592,27 @@ class FluxWithCFGPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFile
         lora_scale = (
             self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
         )
+        (
+            prompt_embeds,
+            negative_prompt_embeds,
+            pooled_prompt_embeds,
+            negative_pooled_prompt_embeds,
+        ) = self.encode_prompt(
+            prompt=prompt,
+            prompt_2=prompt_2,
+            negative_prompt=negative_prompt,
+            negative_prompt_2=negative_prompt_2,
+            do_classifier_free_guidance=self.do_classifier_free_guidance,
+            prompt_embeds=prompt_embeds,
+            negative_prompt_embeds=negative_prompt_embeds,
+            pooled_prompt_embeds=pooled_prompt_embeds,
+            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
+            device=device,
+            clip_skip=self.clip_skip,
+            num_images_per_prompt=num_images_per_prompt,
+            max_sequence_length=max_sequence_length,
+            lora_scale=lora_scale,
+        )

         if self.do_classifier_free_guidance:
             prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
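One caution on the hunk above: per old line 325 shown earlier, `encode_prompt` returns six values, but this call site unpacks four. Unless `encode_prompt` changed elsewhere (not visible in these hunks), the unpacking itself would fail. A minimal reproduction:

def encode_prompt():
    # Stand-in returning six values, mirroring the return statement at line 325.
    return 1, 2, 3, 4, 5, 6

(a, b, c, d) = encode_prompt()
# ValueError: too many values to unpack (expected 4)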
@@ -699,7 +717,9 @@ class FluxWithCFGPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFile
             return self._decode_latents_to_image(latents, height, width, output_type)
         self.maybe_free_model_hooks()
         torch.cuda.empty_cache()
-
+
+    @torch.no_grad()
+    @torch.inference_mode()
     def __call__(
         self,
         prompt: Union[str, List[str]] = None,
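The decorators added above match the ones on `generate_images`. `torch.inference_mode()` already implies no-grad (and further disables view and version-counter tracking), so stacking `@torch.no_grad()` on top is redundant but harmless:

import torch

x = torch.randn(2, 2, requires_grad=True)
with torch.inference_mode():
    y = x * 2  # computed with autograd fully disabled
assert not y.requires_grad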
@@ -710,7 +730,7 @@ class FluxWithCFGPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFile
         negative_prompt_2: Optional[Union[str, List[str]]] = None,
         num_inference_steps: int = 8,
         timesteps: List[int] = None,
-        eta: float = 0.0,
+        eta: Optional[float] = 0.0,
         guidance_scale: float = 3.5,
         num_images_per_prompt: Optional[int] = 1,
         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
@@ -726,6 +746,7 @@ class FluxWithCFGPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFile
         callback_on_step_end_tensor_inputs: List[str] = ["latents"],
         clip_skip: Optional[int] = None,
         max_sequence_length: int = 300,
+        **kwargs,
     ):
         height = height or self.default_sample_size * self.vae_scale_factor
         width = width or self.default_sample_size * self.vae_scale_factor
@@ -744,7 +765,6 @@ class FluxWithCFGPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFile
             negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
             callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
             max_sequence_length=max_sequence_length,
-            lora_scale=lora_scale
         )

         self._guidance_scale = guidance_scale
@@ -767,6 +787,27 @@ class FluxWithCFGPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFile
         lora_scale = (
             self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
         )
+        (
+            prompt_embeds,
+            negative_prompt_embeds,
+            pooled_prompt_embeds,
+            negative_pooled_prompt_embeds,
+        ) = self.encode_prompt(
+            prompt=prompt,
+            prompt_2=prompt_2,
+            negative_prompt=negative_prompt,
+            negative_prompt_2=negative_prompt_2,
+            do_classifier_free_guidance=self.do_classifier_free_guidance,
+            prompt_embeds=prompt_embeds,
+            negative_prompt_embeds=negative_prompt_embeds,
+            pooled_prompt_embeds=pooled_prompt_embeds,
+            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
+            device=device,
+            clip_skip=self.clip_skip,
+            num_images_per_prompt=num_images_per_prompt,
+            max_sequence_length=max_sequence_length,
+            lora_scale=lora_scale,
+        )

         if self.do_classifier_free_guidance:
             prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
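Both entry points converge on the same classifier-free-guidance batching: negative and positive embeddings are concatenated along dim 0 so a single transformer forward pass covers both branches. The split-and-combine this enables happens outside the hunks shown; a conventional sketch of it, not taken verbatim from pipeline.py:

import torch

guidance_scale = 3.5
noise_pred = torch.randn(2, 4, 64, 64)          # stand-in for one batched forward pass
noise_uncond, noise_text = noise_pred.chunk(2)  # order matches cat([negative, positive])
guided = noise_uncond + guidance_scale * (noise_text - noise_uncond)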