AlekseyCalvin committed on
Commit f717fb3
1 Parent(s): 6d55428

Update pipeline.py

Files changed (1):
  pipeline.py  +122 -112
pipeline.py CHANGED
@@ -655,63 +655,68 @@ class FluxWithCFGPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFile
                 if self.interrupt:
                     continue

-                latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
-
-                timestep = t.expand(latent_model_input.shape[0]).to(latents.dtype)
-
-                noise_pred = self.transformer(
-                    hidden_states=latent_model_input,
-                    timestep=timestep / 1000,
-                    guidance=guidance,
-                    pooled_projections=pooled_prompt_embeds,
-                    encoder_hidden_states=prompt_embeds,
-                    txt_ids=text_ids,
-                    img_ids=latent_image_ids,
-                    joint_attention_kwargs=self.joint_attention_kwargs,
-                    return_dict=False,
-                )[0]
-
-                noise_pred_uncond = self.transformer(
-                    hidden_states=latents,
-                    timestep=timestep / 1000,
-                    guidance=guidance,
-                    pooled_projections=negative_pooled_prompt_embeds,
-                    encoder_hidden_states=negative_prompt_embeds,
-                    img_ids=latent_image_ids,
-                    joint_attention_kwargs=self.joint_attention_kwargs,
-                    return_dict=False,
-                )[0]
-
-                if self.do_classifier_free_guidance:
-                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
-                    noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
-
-                latents_dtype = latents.dtype
-                latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
-                # Yield intermediate result
-                torch.cuda.empty_cache()
-
-                if latents.dtype != latents_dtype:
-                    if torch.backends.mps.is_available():
-                        # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
-                        latents = latents.to(latents_dtype)
-
-                if callback_on_step_end is not None:
-                    callback_kwargs = {}
-                    for k in callback_on_step_end_tensor_inputs:
-                        callback_kwargs[k] = locals()[k]
-                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
-
-                    latents = callback_outputs.pop("latents", latents)
-                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
-                    negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
-                    negative_pooled_prompt_embeds = callback_outputs.pop(
-                        "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds
-                    )
+                latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
+
+                timestep = t.expand(latent_model_input.shape[0]).to(latents.dtype)
+
+                if self.transformer.config.guidance_embeds:
+                    guidance = torch.tensor([guidance_scale], device=device)
+                    guidance = guidance.expand(latents.shape[0])
+                else:
+                    guidance = None
+
+                noise_pred = self.transformer(
+                    hidden_states=latent_model_input,
+                    timestep=timestep / 1000,
+                    guidance=guidance,
+                    pooled_projections=pooled_prompt_embeds,
+                    encoder_hidden_states=prompt_embeds,
+                    txt_ids=text_ids,
+                    img_ids=latent_image_ids,
+                    joint_attention_kwargs=self.joint_attention_kwargs,
+                    return_dict=False,
+                )[0]
+
+                noise_pred_uncond = self.transformer(
+                    hidden_states=latents,
+                    timestep=timestep / 1000,
+                    guidance=guidance,
+                    pooled_projections=negative_pooled_prompt_embeds,
+                    encoder_hidden_states=negative_prompt_embeds,
+                    img_ids=latent_image_ids,
+                    joint_attention_kwargs=self.joint_attention_kwargs,
+                    return_dict=False,
+                )[0]
+
+                if self.do_classifier_free_guidance:
+                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+                    noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+                # compute the previous noisy sample x_t -> x_t-1
+                latents_dtype = latents.dtype
+                latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
+
+                if latents.dtype != latents_dtype:
+                    if torch.backends.mps.is_available():
+                        # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
+                        latents = latents.to(latents_dtype)
+
+                if callback_on_step_end is not None:
+                    callback_kwargs = {}
+                    for k in callback_on_step_end_tensor_inputs:
+                        callback_kwargs[k] = locals()[k]
+                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+                    latents = callback_outputs.pop("latents", latents)
+                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+                    negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
+                    negative_pooled_prompt_embeds = callback_outputs.pop(
+                        "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds
+                    )

                 # call the callback, if provided
-                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
-                    progress_bar.update()
+                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+                    progress_bar.update()

         # Final image
         return self._decode_latents_to_image(latents, height, width, output_type)
@@ -850,65 +855,70 @@ class FluxWithCFGPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFile
                 if self.interrupt:
                     continue

-                latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
-
-                timestep = t.expand(latent_model_input.shape[0]).to(latents.dtype)
-
-                noise_pred = self.transformer(
-                    hidden_states=latent_model_input,
-                    timestep=timestep / 1000,
-                    guidance=guidance,
-                    pooled_projections=pooled_prompt_embeds,
-                    encoder_hidden_states=prompt_embeds,
-                    txt_ids=text_ids,
-                    img_ids=latent_image_ids,
-                    joint_attention_kwargs=self.joint_attention_kwargs,
-                    return_dict=False,
-                )[0]
-
-                noise_pred_uncond = self.transformer(
-                    hidden_states=latents,
-                    timestep=timestep / 1000,
-                    guidance=guidance,
-                    pooled_projections=negative_pooled_prompt_embeds,
-                    encoder_hidden_states=negative_prompt_embeds,
-                    img_ids=latent_image_ids,
-                    joint_attention_kwargs=self.joint_attention_kwargs,
-                    return_dict=False,
-                )[0]
-
-                if self.do_classifier_free_guidance:
-                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
-                    noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
-
-                latents_dtype = latents.dtype
-                latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
-                # Yield intermediate result
-                torch.cuda.empty_cache()
-
-                if latents.dtype != latents_dtype:
-                    if torch.backends.mps.is_available():
-                        # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
-                        latents = latents.to(latents_dtype)
-
-                if callback_on_step_end is not None:
-                    callback_kwargs = {}
-                    for k in callback_on_step_end_tensor_inputs:
-                        callback_kwargs[k] = locals()[k]
-                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
-
-                    latents = callback_outputs.pop("latents", latents)
-                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
-                    negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
-                    negative_pooled_prompt_embeds = callback_outputs.pop(
-                        "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds
-                    )
+                latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
+
+                timestep = t.expand(latent_model_input.shape[0]).to(latents.dtype)
+
+                if self.transformer.config.guidance_embeds:
+                    guidance = torch.tensor([guidance_scale], device=device)
+                    guidance = guidance.expand(latents.shape[0])
+                else:
+                    guidance = None
+
+                noise_pred = self.transformer(
+                    hidden_states=latent_model_input,
+                    timestep=timestep / 1000,
+                    guidance=guidance,
+                    pooled_projections=pooled_prompt_embeds,
+                    encoder_hidden_states=prompt_embeds,
+                    txt_ids=text_ids,
+                    img_ids=latent_image_ids,
+                    joint_attention_kwargs=self.joint_attention_kwargs,
+                    return_dict=False,
+                )[0]
+
+                noise_pred_uncond = self.transformer(
+                    hidden_states=latents,
+                    timestep=timestep / 1000,
+                    guidance=guidance,
+                    pooled_projections=negative_pooled_prompt_embeds,
+                    encoder_hidden_states=negative_prompt_embeds,
+                    img_ids=latent_image_ids,
+                    joint_attention_kwargs=self.joint_attention_kwargs,
+                    return_dict=False,
+                )[0]
+
+                if self.do_classifier_free_guidance:
+                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+                    noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+                # compute the previous noisy sample x_t -> x_t-1
+                latents_dtype = latents.dtype
+                latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
+
+                if latents.dtype != latents_dtype:
+                    if torch.backends.mps.is_available():
+                        # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
+                        latents = latents.to(latents_dtype)
+
+                if callback_on_step_end is not None:
+                    callback_kwargs = {}
+                    for k in callback_on_step_end_tensor_inputs:
+                        callback_kwargs[k] = locals()[k]
+                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+                    latents = callback_outputs.pop("latents", latents)
+                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+                    negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
+                    negative_pooled_prompt_embeds = callback_outputs.pop(
+                        "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds
+                    )

                 # call the callback, if provided
-                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
-                    progress_bar.update()
-                # Final image
-
+                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+                    progress_bar.update()
+
+        # Final image
         return self._decode_latents_to_image(latents, height, width, output_type)
         self.maybe_free_model_hooks()
         torch.cuda.empty_cache()
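
The substantive change in both hunks is the new per-step construction of the guidance tensor. A minimal sketch of that branch in isolation (illustrative only, not part of the commit; guidance_embeds here stands in for the config flag the pipeline reads via self.transformer.config.guidance_embeds, and the scale value is arbitrary):

import torch

def make_guidance(guidance_embeds: bool, guidance_scale: float,
                  batch_size: int, device: torch.device):
    # Guidance-distilled Flux checkpoints take the guidance scale as a model
    # input, one scalar per batch element; undistilled checkpoints get None.
    if guidance_embeds:
        guidance = torch.tensor([guidance_scale], device=device)
        return guidance.expand(batch_size)  # shape: (batch_size,)
    return None

print(make_guidance(True, 3.5, 2, torch.device("cpu")))   # tensor([3.5000, 3.5000])
print(make_guidance(False, 3.5, 2, torch.device("cpu")))  # None

In the diff above, the resulting tensor (or None) is then passed as guidance= to both transformer calls.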
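For reference, the classifier-free guidance blend that both the old and new code apply is the standard extrapolation between the unconditional and conditional predictions. A small self-contained example (the tensor values are made up):

import torch

def apply_cfg(noise_pred_text, noise_pred_uncond, guidance_scale):
    # guidance_scale = 1.0 reproduces the conditional prediction;
    # larger values extrapolate further away from the unconditional one.
    return noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

cond = torch.tensor([1.0, 2.0])
uncond = torch.tensor([0.5, 1.0])
print(apply_cfg(cond, uncond, 7.0))  # tensor([4., 8.])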