Spaces: Running on Zero
AlekseyCalvin committed
Commit f717fb3 · Parent(s): 6d55428
Update pipeline.py
Browse files · pipeline.py: +122 -112

pipeline.py CHANGED
@@ -655,63 +655,68 @@ class FluxWithCFGPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFile
                 if self.interrupt:
                     continue
 
-                [39 deleted lines (old 658-696) not rendered by the diff viewer]
-                latents = …  [old 697: only this fragment of the deleted line was rendered]
-                [13 deleted lines (old 698-710) not rendered]
+                latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
+
+                timestep = t.expand(latent_model_input.shape[0]).to(latents.dtype)
+
+                if self.transformer.config.guidance_embeds:
+                    guidance = torch.tensor([guidance_scale], device=device)
+                    guidance = guidance.expand(latents.shape[0])
+                else:
+                    guidance = None
+
+                noise_pred = self.transformer(
+                    hidden_states=latent_model_input,
+                    timestep=timestep / 1000,
+                    guidance=guidance,
+                    pooled_projections=pooled_prompt_embeds,
+                    encoder_hidden_states=prompt_embeds,
+                    txt_ids=text_ids,
+                    img_ids=latent_image_ids,
+                    joint_attention_kwargs=self.joint_attention_kwargs,
+                    return_dict=False,
+                )[0]
+
+                noise_pred_uncond = self.transformer(
+                    hidden_states=latents,
+                    timestep=timestep / 1000,
+                    guidance=guidance,
+                    pooled_projections=negative_pooled_prompt_embeds,
+                    encoder_hidden_states=negative_prompt_embeds,
+                    img_ids=latent_image_ids,
+                    joint_attention_kwargs=self.joint_attention_kwargs,
+                    return_dict=False,
+                )[0]
+
+                if self.do_classifier_free_guidance:
+                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+                    noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+                # compute the previous noisy sample x_t -> x_t-1
+                latents_dtype = latents.dtype
+                latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
+
+                if latents.dtype != latents_dtype:
+                    if torch.backends.mps.is_available():
+                        # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
+                        latents = latents.to(latents_dtype)
+
+                if callback_on_step_end is not None:
+                    callback_kwargs = {}
+                    for k in callback_on_step_end_tensor_inputs:
+                        callback_kwargs[k] = locals()[k]
+                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+                    latents = callback_outputs.pop("latents", latents)
+                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+                    negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
+                    negative_pooled_prompt_embeds = callback_outputs.pop(
+                        "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds
+                    )
 
                 # call the callback, if provided
-                [2 deleted lines (old 713-714) not rendered]
+                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+                    progress_bar.update()
 
         # Final image
         return self._decode_latents_to_image(latents, height, width, output_type)
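For orientation, here is a minimal sketch of the classifier-free-guidance blend the added loop performs once the transformer has run on the doubled batch. The helper name cfg_blend and the dummy tensor shapes are illustrative assumptions, not part of the commit:

import torch

def cfg_blend(noise_pred: torch.Tensor, guidance_scale: float) -> torch.Tensor:
    # The loop feeds the transformer torch.cat([latents] * 2), so chunk(2)
    # splits the doubled batch back into two predictions; the loop unpacks
    # the first half as the unconditional one.
    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
    # Standard CFG update: move the prediction away from the unconditional direction.
    return noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

# Dummy doubled-batch prediction; (2 * batch, tokens, channels) is an assumed shape.
batched_pred = torch.randn(2, 4096, 64)
guided = cfg_blend(batched_pred, guidance_scale=3.5)
print(guided.shape)  # torch.Size([1, 4096, 64])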
@@ -850,65 +855,70 @@ class FluxWithCFGPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFile
                 if self.interrupt:
                     continue
 
-                [39 deleted lines (old 853-891) not rendered by the diff viewer]
-                latents = …  [old 892: only this fragment of the deleted line was rendered]
-                [13 deleted lines (old 893-905) not rendered]
+                latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
+
+                timestep = t.expand(latent_model_input.shape[0]).to(latents.dtype)
+
+                if self.transformer.config.guidance_embeds:
+                    guidance = torch.tensor([guidance_scale], device=device)
+                    guidance = guidance.expand(latents.shape[0])
+                else:
+                    guidance = None
+
+                noise_pred = self.transformer(
+                    hidden_states=latent_model_input,
+                    timestep=timestep / 1000,
+                    guidance=guidance,
+                    pooled_projections=pooled_prompt_embeds,
+                    encoder_hidden_states=prompt_embeds,
+                    txt_ids=text_ids,
+                    img_ids=latent_image_ids,
+                    joint_attention_kwargs=self.joint_attention_kwargs,
+                    return_dict=False,
+                )[0]
+
+                noise_pred_uncond = self.transformer(
+                    hidden_states=latents,
+                    timestep=timestep / 1000,
+                    guidance=guidance,
+                    pooled_projections=negative_pooled_prompt_embeds,
+                    encoder_hidden_states=negative_prompt_embeds,
+                    img_ids=latent_image_ids,
+                    joint_attention_kwargs=self.joint_attention_kwargs,
+                    return_dict=False,
+                )[0]
+
+                if self.do_classifier_free_guidance:
+                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+                    noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+                # compute the previous noisy sample x_t -> x_t-1
+                latents_dtype = latents.dtype
+                latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
+
+                if latents.dtype != latents_dtype:
+                    if torch.backends.mps.is_available():
+                        # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
+                        latents = latents.to(latents_dtype)
+
+                if callback_on_step_end is not None:
+                    callback_kwargs = {}
+                    for k in callback_on_step_end_tensor_inputs:
+                        callback_kwargs[k] = locals()[k]
+                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+                    latents = callback_outputs.pop("latents", latents)
+                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+                    negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
+                    negative_pooled_prompt_embeds = callback_outputs.pop(
+                        "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds
+                    )
 
                 # call the callback, if provided
-                [4 deleted lines (old 908-911) not rendered]
+                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+                    progress_bar.update()
+
+        # Final image
         return self._decode_latents_to_image(latents, height, width, output_type)
         self.maybe_free_model_hooks()
         torch.cuda.empty_cache()
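Both rewritten loops wire in diffusers-style step-end callbacks. A hypothetical usage sketch follows; the checkpoint path, the prompt keyword, and the handling of the return value are assumptions, since the methods' full signatures are not visible in this diff:

import torch
from pipeline import FluxWithCFGPipeline  # the module this commit edits

pipe = FluxWithCFGPipeline.from_pretrained(
    "path/to/flux-checkpoint",  # placeholder model id
    torch_dtype=torch.bfloat16,
).to("cuda")

def log_step(pipeline, i, t, callback_kwargs):
    # Every tensor named in callback_on_step_end_tensor_inputs arrives here;
    # whatever this returns is handed back to the loop via callback_outputs.pop().
    print(f"step {i}: t={float(t):.1f}, latents {tuple(callback_kwargs['latents'].shape)}")
    return callback_kwargs

result = pipe(
    prompt="a watercolor fox in a birch forest",
    guidance_scale=3.5,
    callback_on_step_end=log_step,
    callback_on_step_end_tensor_inputs=["latents"],
)
# The loop ends with self._decode_latents_to_image(...), so the return value
# depends on output_type.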