Try to fix ControlNet batch processing
Browse files- pipeline.py +36 -9
pipeline.py
CHANGED
@@ -980,6 +980,24 @@ class StableDiffusionControlNetImg2ImgPipeline(DiffusionPipeline, TextualInversi
|
|
980 |
# compute the percentage of total steps we are at
|
981 |
current_sampling_percent = i / len(timesteps)
|
982 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
983 |
if (
|
984 |
current_sampling_percent < controlnet_guidance_start
|
985 |
or current_sampling_percent > controlnet_guidance_end
|
@@ -988,15 +1006,24 @@ class StableDiffusionControlNetImg2ImgPipeline(DiffusionPipeline, TextualInversi
|
|
988 |
down_block_res_samples = None
|
989 |
mid_block_res_sample = None
|
990 |
else:
|
991 |
-
|
992 |
-
|
993 |
-
|
994 |
-
|
995 |
-
|
996 |
-
|
997 |
-
|
998 |
-
|
999 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1000 |
|
1001 |
# predict the noise residual
|
1002 |
noise_pred = self.unet(
|
|
|
980 |
# compute the percentage of total steps we are at
|
981 |
current_sampling_percent = i / len(timesteps)
|
982 |
|
983 |
+
# if (
|
984 |
+
# current_sampling_percent < controlnet_guidance_start
|
985 |
+
# or current_sampling_percent > controlnet_guidance_end
|
986 |
+
# ):
|
987 |
+
# # do not apply the controlnet
|
988 |
+
# down_block_res_samples = None
|
989 |
+
# mid_block_res_sample = None
|
990 |
+
# else:
|
991 |
+
# # apply the controlnet
|
992 |
+
# down_block_res_samples, mid_block_res_sample = self.controlnet(
|
993 |
+
# latent_model_input,
|
994 |
+
# t,
|
995 |
+
# encoder_hidden_states=prompt_embeds,
|
996 |
+
# controlnet_cond=controlnet_conditioning_image,
|
997 |
+
# conditioning_scale=controlnet_conditioning_scale,
|
998 |
+
# return_dict=False,
|
999 |
+
# )
|
1000 |
+
|
1001 |
if (
|
1002 |
current_sampling_percent < controlnet_guidance_start
|
1003 |
or current_sampling_percent > controlnet_guidance_end
|
|
|
1006 |
down_block_res_samples = None
|
1007 |
mid_block_res_sample = None
|
1008 |
else:
|
1009 |
+
down_block_res_samples = []
|
1010 |
+
mid_block_res_samples = []
|
1011 |
+
for i in range(batch_size):
|
1012 |
+
# apply the controlnet
|
1013 |
+
down_block_res_sample, mid_block_res_sample = self.controlnet(
|
1014 |
+
latent_model_input[i * num_images_per_prompt:(i + 1) * num_images_per_prompt],
|
1015 |
+
t,
|
1016 |
+
encoder_hidden_states=prompt_embeds,
|
1017 |
+
controlnet_cond=controlnet_conditioning_image[i],
|
1018 |
+
conditioning_scale=controlnet_conditioning_scale,
|
1019 |
+
return_dict=False,
|
1020 |
+
)
|
1021 |
+
down_block_res_samples.append(down_block_res_sample)
|
1022 |
+
mid_block_res_samples.append(mid_block_res_sample)
|
1023 |
+
|
1024 |
+
down_block_res_samples = torch.cat(down_block_res_samples, dim=0)
|
1025 |
+
mid_block_res_sample = torch.cat(mid_block_res_samples, dim=0)
|
1026 |
+
|
1027 |
|
1028 |
# predict the noise residual
|
1029 |
noise_pred = self.unet(
|