MrAlex committed on
Commit
27f9e25
1 Parent(s): f413d0c

update for image batch processing

Browse files
Files changed (1) hide show
  1. pipeline.py +20 -5
pipeline.py CHANGED
@@ -856,7 +856,8 @@ class StableDiffusionControlNetImg2ImgPipeline(DiffusionPipeline, TextualInversi
856
  )
857
 
858
  # 4. Prepare image, and controlnet_conditioning_image
859
- image = prepare_image(image)
 
860
 
861
  # condition image(s)
862
  if isinstance(self.controlnet, ControlNetModel):
@@ -897,15 +898,27 @@ class StableDiffusionControlNetImg2ImgPipeline(DiffusionPipeline, TextualInversi
897
  latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
898
 
899
  # 6. Prepare latent variables
900
- latents = self.prepare_latents(
901
- image,
 
 
 
 
 
 
 
 
 
 
902
  latent_timestep,
903
  batch_size,
904
  num_images_per_prompt,
905
  prompt_embeds.dtype,
906
  device,
907
  generator,
908
- )
 
 
909
 
910
  # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
911
  extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
@@ -915,7 +928,9 @@ class StableDiffusionControlNetImg2ImgPipeline(DiffusionPipeline, TextualInversi
915
  with self.progress_bar(total=num_inference_steps) as progress_bar:
916
  for i, t in enumerate(timesteps):
917
  # expand the latents if we are doing classifier free guidance
918
- latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
 
 
919
 
920
  latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
921
 
 
856
  )
857
 
858
  # 4. Prepare image, and controlnet_conditioning_image
859
+ # image = prepare_image(image)
860
+ images = [prepare_image(img) for img in image]
861
 
862
  # condition image(s)
863
  if isinstance(self.controlnet, ControlNetModel):
 
898
  latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
899
 
900
  # 6. Prepare latent variables
901
+ # latents = self.prepare_latents(
902
+ # image,
903
+ # latent_timestep,
904
+ # batch_size,
905
+ # num_images_per_prompt,
906
+ # prompt_embeds.dtype,
907
+ # device,
908
+ # generator,
909
+ # )
910
+
911
+ latents = [self.prepare_latents(
912
+ img,
913
  latent_timestep,
914
  batch_size,
915
  num_images_per_prompt,
916
  prompt_embeds.dtype,
917
  device,
918
  generator,
919
+ ) for img in images]
920
+ latents = torch.cat(latents)
921
+
922
 
923
  # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
924
  extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
 
928
  with self.progress_bar(total=num_inference_steps) as progress_bar:
929
  for i, t in enumerate(timesteps):
930
  # expand the latents if we are doing classifier free guidance
931
+ # latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
932
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents.clone()
933
+
934
 
935
  latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
936