poipiii commited on
Commit
e611cbb
1 Parent(s): 6dfb556

test interpolate latent

Browse files
Files changed (1) hide show
  1. pipeline.py +7 -1
pipeline.py CHANGED
@@ -665,6 +665,7 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
665
  mask_image: Union[torch.FloatTensor, PIL.Image.Image] = None,
666
  height: int = 512,
667
  width: int = 512,
 
668
  num_inference_steps: int = 50,
669
  guidance_scale: float = 7.5,
670
  strength: float = 0.8,
@@ -841,10 +842,15 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
841
  return None
842
  print(latents)
843
  print(latents.shape)
 
 
 
 
 
 
844
  # 9. Post-processing
845
  image = self.decode_latents(latents)
846
 
847
- #do latent upscale here
848
 
849
  # 10. Run safety checker
850
  image, has_nsfw_concept = self.run_safety_checker(image, device, text_embeddings.dtype)
 
665
  mask_image: Union[torch.FloatTensor, PIL.Image.Image] = None,
666
  height: int = 512,
667
  width: int = 512,
668
+ resize_scale: float = 1.2,
669
  num_inference_steps: int = 50,
670
  guidance_scale: float = 7.5,
671
  strength: float = 0.8,
 
842
  return None
843
  print(latents)
844
  print(latents.shape)
845
+ resized_image = torch.nn.functional.interpolate(
846
+ latents, size=(int(latents.shape[2]*resize_scale)//8, int(latents.shape[3]*resize_scale)//8))
847
+
848
+ print(resized_image.shape)
849
+ #do latent upscale here
850
+
851
  # 9. Post-processing
852
  image = self.decode_latents(latents)
853
 
 
854
 
855
  # 10. Run safety checker
856
  image, has_nsfw_concept = self.run_safety_checker(image, device, text_embeddings.dtype)