Spaces:
Running
on
Zero
Running
on
Zero
Update pipline_StableDiffusion_ConsistentID.py
Browse files
pipline_StableDiffusion_ConsistentID.py
CHANGED
@@ -419,6 +419,7 @@ class ConsistentIDStableDiffusionPipeline(StableDiffusionPipeline):
|
|
419 |
class_tokens_mask: Optional[torch.LongTensor] = None,
|
420 |
prompt_embeds_text_only: Optional[torch.FloatTensor] = None,
|
421 |
retouching: bool=False,
|
|
|
422 |
):
|
423 |
# 0. Default height and width to unet
|
424 |
height = height or self.unet.config.sample_size * self.vae_scale_factor
|
@@ -604,9 +605,12 @@ class ConsistentIDStableDiffusionPipeline(StableDiffusionPipeline):
|
|
604 |
image = self.decode_latents(latents)
|
605 |
|
606 |
# 9.2 Run safety checker
|
607 |
-
|
608 |
-
image,
|
609 |
-
|
|
|
|
|
|
|
610 |
|
611 |
# 9.3 Convert to PIL
|
612 |
image = self.numpy_to_pil(image)
|
|
|
419 |
class_tokens_mask: Optional[torch.LongTensor] = None,
|
420 |
prompt_embeds_text_only: Optional[torch.FloatTensor] = None,
|
421 |
retouching: bool=False,
|
422 |
+
need_safetycheck: bool=True,
|
423 |
):
|
424 |
# 0. Default height and width to unet
|
425 |
height = height or self.unet.config.sample_size * self.vae_scale_factor
|
|
|
605 |
image = self.decode_latents(latents)
|
606 |
|
607 |
# 9.2 Run safety checker
|
608 |
+
if need_safetycheck:
|
609 |
+
image, has_nsfw_concept = self.run_safety_checker(
|
610 |
+
image, device, prompt_embeds.dtype
|
611 |
+
)
|
612 |
+
else:
|
613 |
+
has_nsfw_concept = None
|
614 |
|
615 |
# 9.3 Convert to PIL
|
616 |
image = self.numpy_to_pil(image)
|