update for image batch processing
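This change lets `image` be passed as a list of input images: each one is preprocessed with `prepare_image` individually, per-image latents are prepared and concatenated along the batch dimension with `torch.cat`, and the non-guidance branch now clones `latents` before scaling the model input.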
pipeline.py CHANGED (+20 -5)
@@ -856,7 +856,8 @@ class StableDiffusionControlNetImg2ImgPipeline(DiffusionPipeline, TextualInversi
         )
 
         # 4. Prepare image, and controlnet_conditioning_image
-        image = prepare_image(image)
+        # image = prepare_image(image)
+        images = [prepare_image(img) for img in image]
 
         # condition image(s)
         if isinstance(self.controlnet, ControlNetModel):
@@ -897,15 +898,27 @@ class StableDiffusionControlNetImg2ImgPipeline(DiffusionPipeline, TextualInversi
         latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
 
         # 6. Prepare latent variables
-        latents = self.prepare_latents(
-            image,
+        # latents = self.prepare_latents(
+        #     image,
+        #     latent_timestep,
+        #     batch_size,
+        #     num_images_per_prompt,
+        #     prompt_embeds.dtype,
+        #     device,
+        #     generator,
+        # )
+
+        latents = [self.prepare_latents(
+            img,
             latent_timestep,
             batch_size,
             num_images_per_prompt,
             prompt_embeds.dtype,
             device,
             generator,
-        )
+        ) for img in images]
+        latents = torch.cat(latents)
+
 
         # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
         extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
@@ -915,7 +928,9 @@ class StableDiffusionControlNetImg2ImgPipeline(DiffusionPipeline, TextualInversi
         with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                # expand the latents if we are doing classifier free guidance
-                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+                # latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents.clone()
+
 
                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
 
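The batching pattern is easier to see in isolation. Below is a minimal, runnable sketch of the same flow; `prepare_image` and `prepare_latents` here are trivial stand-ins for the pipeline's real helpers (which resize/normalize the image and VAE-encode plus noise it), so only the list comprehensions, `torch.cat`, and the `.clone()` guard mirror the patched code.

import torch

def prepare_image(img):
    # Stand-in for the pipeline helper: the real one converts a PIL image to a
    # normalized tensor; here we only ensure a leading batch dimension.
    return img.unsqueeze(0) if img.dim() == 3 else img

def prepare_latents(img, generator):
    # Stand-in for the pipeline helper: the real one VAE-encodes `img` and
    # noises it to `latent_timestep`; here we just return noise of latent shape.
    b, _, h, w = img.shape
    return torch.randn(b, 4, h // 8, w // 8, generator=generator)

generator = torch.Generator().manual_seed(0)
image = [torch.rand(3, 512, 512), torch.rand(3, 512, 512)]  # two input images

# 4. Prepare each input image individually
images = [prepare_image(img) for img in image]

# 6. Prepare latents per image, then concatenate along the batch dimension
latents = torch.cat([prepare_latents(img, generator) for img in images])
print(latents.shape)  # torch.Size([2, 4, 64, 64])

# Inside the denoising loop: double the batch for classifier-free guidance,
# or clone so downstream in-place ops cannot mutate `latents`.
do_classifier_free_guidance = True
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents.clone()
print(latent_model_input.shape)  # torch.Size([4, 4, 64, 64])

Concatenating per-image latents means the denoising loop then runs once over a true batch instead of once per image; the `.clone()` on the non-guidance branch appears intended to guard `latents` against in-place modification later in the step.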