JackAILab commited on
Commit
435d39c
·
verified ·
1 Parent(s): 5043913

Update pipline_StableDiffusion_ConsistentID.py

Browse files
pipline_StableDiffusion_ConsistentID.py CHANGED
@@ -419,6 +419,7 @@ class ConsistentIDStableDiffusionPipeline(StableDiffusionPipeline):
419
  class_tokens_mask: Optional[torch.LongTensor] = None,
420
  prompt_embeds_text_only: Optional[torch.FloatTensor] = None,
421
  retouching: bool=False,
 
422
  ):
423
  # 0. Default height and width to unet
424
  height = height or self.unet.config.sample_size * self.vae_scale_factor
@@ -604,9 +605,12 @@ class ConsistentIDStableDiffusionPipeline(StableDiffusionPipeline):
604
  image = self.decode_latents(latents)
605
 
606
  # 9.2 Run safety checker
607
- image, has_nsfw_concept = self.run_safety_checker(
608
- image, device, prompt_embeds.dtype
609
- )
 
 
 
610
 
611
  # 9.3 Convert to PIL
612
  image = self.numpy_to_pil(image)
 
419
  class_tokens_mask: Optional[torch.LongTensor] = None,
420
  prompt_embeds_text_only: Optional[torch.FloatTensor] = None,
421
  retouching: bool=False,
422
+ need_safetycheck: bool=True,
423
  ):
424
  # 0. Default height and width to unet
425
  height = height or self.unet.config.sample_size * self.vae_scale_factor
 
605
  image = self.decode_latents(latents)
606
 
607
  # 9.2 Run safety checker
608
+ if need_safetycheck:
609
+ image, has_nsfw_concept = self.run_safety_checker(
610
+ image, device, prompt_embeds.dtype
611
+ )
612
+ else:
613
+ has_nsfw_concept = None
614
 
615
  # 9.3 Convert to PIL
616
  image = self.numpy_to_pil(image)