poipiii
commited on
Commit
•
f25c918
1
Parent(s):
1b4eb1d
test in latnent upcale
Browse files- pipeline.py +34 -2
pipeline.py
CHANGED
@@ -842,10 +842,42 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
|
|
842 |
return None
|
843 |
print(latents)
|
844 |
print(latents.shape)
|
845 |
-
|
846 |
latents, size=(int(height*resize_scale)//8, int(width*resize_scale)//8))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
847 |
|
848 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
849 |
#do latent upscale here
|
850 |
|
851 |
# 9. Post-processing
|
|
|
842 |
return None
|
843 |
print(latents)
|
844 |
print(latents.shape)
|
845 |
+
latents = torch.nn.functional.interpolate(
|
846 |
latents, size=(int(height*resize_scale)//8, int(width*resize_scale)//8))
|
847 |
+
|
848 |
+
for i, t in enumerate(self.progress_bar(timesteps)):
|
849 |
+
# expand the latents if we are doing classifier free guidance
|
850 |
+
latent_model_input = torch.cat(
|
851 |
+
[latents] * 2) if do_classifier_free_guidance else latents
|
852 |
+
latent_model_input = self.scheduler.scale_model_input(
|
853 |
+
latent_model_input, t)
|
854 |
+
|
855 |
+
# predict the noise residual
|
856 |
+
noise_pred = self.unet(latent_model_input, t,
|
857 |
+
encoder_hidden_states=text_embeddings).sample
|
858 |
+
|
859 |
+
# perform guidance
|
860 |
+
if do_classifier_free_guidance:
|
861 |
+
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
862 |
+
noise_pred = noise_pred_uncond + guidance_scale * \
|
863 |
+
(noise_pred_text - noise_pred_uncond)
|
864 |
+
|
865 |
+
# compute the previous noisy sample x_t -> x_t-1
|
866 |
+
latents = self.scheduler.step(
|
867 |
+
noise_pred, t, latents, **extra_step_kwargs).prev_sample
|
868 |
|
869 |
+
if mask is not None:
|
870 |
+
# masking
|
871 |
+
init_latents_proper = self.scheduler.add_noise(
|
872 |
+
init_latents_orig, noise, torch.tensor([t]))
|
873 |
+
latents = (init_latents_proper * mask) + (latents * (1 - mask))
|
874 |
+
|
875 |
+
# call the callback, if provided
|
876 |
+
if i % callback_steps == 0:
|
877 |
+
if callback is not None:
|
878 |
+
callback(i, t, latents)
|
879 |
+
if is_cancelled_callback is not None and is_cancelled_callback():
|
880 |
+
return None
|
881 |
#do latent upscale here
|
882 |
|
883 |
# 9. Post-processing
|