MrAlex commited on
Commit
6c02a8b
1 Parent(s): 63652fd

try fix controlnet batch processing

Browse files
Files changed (1) hide show
  1. pipeline.py +36 -9
pipeline.py CHANGED
@@ -980,6 +980,24 @@ class StableDiffusionControlNetImg2ImgPipeline(DiffusionPipeline, TextualInversi
980
  # compute the percentage of total steps we are at
981
  current_sampling_percent = i / len(timesteps)
982
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
983
  if (
984
  current_sampling_percent < controlnet_guidance_start
985
  or current_sampling_percent > controlnet_guidance_end
@@ -988,15 +1006,24 @@ class StableDiffusionControlNetImg2ImgPipeline(DiffusionPipeline, TextualInversi
988
  down_block_res_samples = None
989
  mid_block_res_sample = None
990
  else:
991
- # apply the controlnet
992
- down_block_res_samples, mid_block_res_sample = self.controlnet(
993
- latent_model_input,
994
- t,
995
- encoder_hidden_states=prompt_embeds,
996
- controlnet_cond=controlnet_conditioning_image,
997
- conditioning_scale=controlnet_conditioning_scale,
998
- return_dict=False,
999
- )
 
 
 
 
 
 
 
 
 
1000
 
1001
  # predict the noise residual
1002
  noise_pred = self.unet(
 
980
  # compute the percentage of total steps we are at
981
  current_sampling_percent = i / len(timesteps)
982
 
983
+ # if (
984
+ # current_sampling_percent < controlnet_guidance_start
985
+ # or current_sampling_percent > controlnet_guidance_end
986
+ # ):
987
+ # # do not apply the controlnet
988
+ # down_block_res_samples = None
989
+ # mid_block_res_sample = None
990
+ # else:
991
+ # # apply the controlnet
992
+ # down_block_res_samples, mid_block_res_sample = self.controlnet(
993
+ # latent_model_input,
994
+ # t,
995
+ # encoder_hidden_states=prompt_embeds,
996
+ # controlnet_cond=controlnet_conditioning_image,
997
+ # conditioning_scale=controlnet_conditioning_scale,
998
+ # return_dict=False,
999
+ # )
1000
+
1001
  if (
1002
  current_sampling_percent < controlnet_guidance_start
1003
  or current_sampling_percent > controlnet_guidance_end
 
1006
  down_block_res_samples = None
1007
  mid_block_res_sample = None
1008
  else:
1009
+ down_block_res_samples = []
1010
+ mid_block_res_samples = []
1011
+ for i in range(batch_size):
1012
+ # apply the controlnet
1013
+ down_block_res_sample, mid_block_res_sample = self.controlnet(
1014
+ latent_model_input[i * num_images_per_prompt:(i + 1) * num_images_per_prompt],
1015
+ t,
1016
+ encoder_hidden_states=prompt_embeds,
1017
+ controlnet_cond=controlnet_conditioning_image[i],
1018
+ conditioning_scale=controlnet_conditioning_scale,
1019
+ return_dict=False,
1020
+ )
1021
+ down_block_res_samples.append(down_block_res_sample)
1022
+ mid_block_res_samples.append(mid_block_res_sample)
1023
+
1024
+ down_block_res_samples = torch.cat(down_block_res_samples, dim=0)
1025
+ mid_block_res_sample = torch.cat(mid_block_res_samples, dim=0)
1026
+
1027
 
1028
  # predict the noise residual
1029
  noise_pred = self.unet(