poipiii commited on
Commit
f25c918
1 Parent(s): 1b4eb1d

test in latnent upcale

Browse files
Files changed (1) hide show
  1. pipeline.py +34 -2
pipeline.py CHANGED
@@ -842,10 +842,42 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
842
  return None
843
  print(latents)
844
  print(latents.shape)
845
- resized_image = torch.nn.functional.interpolate(
846
  latents, size=(int(height*resize_scale)//8, int(width*resize_scale)//8))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
847
 
848
- print(resized_image.shape)
 
 
 
 
 
 
 
 
 
 
 
849
  #do latent upscale here
850
 
851
  # 9. Post-processing
 
842
  return None
843
  print(latents)
844
  print(latents.shape)
845
+ latents = torch.nn.functional.interpolate(
846
  latents, size=(int(height*resize_scale)//8, int(width*resize_scale)//8))
847
+
848
+ for i, t in enumerate(self.progress_bar(timesteps)):
849
+ # expand the latents if we are doing classifier free guidance
850
+ latent_model_input = torch.cat(
851
+ [latents] * 2) if do_classifier_free_guidance else latents
852
+ latent_model_input = self.scheduler.scale_model_input(
853
+ latent_model_input, t)
854
+
855
+ # predict the noise residual
856
+ noise_pred = self.unet(latent_model_input, t,
857
+ encoder_hidden_states=text_embeddings).sample
858
+
859
+ # perform guidance
860
+ if do_classifier_free_guidance:
861
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
862
+ noise_pred = noise_pred_uncond + guidance_scale * \
863
+ (noise_pred_text - noise_pred_uncond)
864
+
865
+ # compute the previous noisy sample x_t -> x_t-1
866
+ latents = self.scheduler.step(
867
+ noise_pred, t, latents, **extra_step_kwargs).prev_sample
868
 
869
+ if mask is not None:
870
+ # masking
871
+ init_latents_proper = self.scheduler.add_noise(
872
+ init_latents_orig, noise, torch.tensor([t]))
873
+ latents = (init_latents_proper * mask) + (latents * (1 - mask))
874
+
875
+ # call the callback, if provided
876
+ if i % callback_steps == 0:
877
+ if callback is not None:
878
+ callback(i, t, latents)
879
+ if is_cancelled_callback is not None and is_cancelled_callback():
880
+ return None
881
  #do latent upscale here
882
 
883
  # 9. Post-processing