Spaces:
Runtime error
Runtime error
Linoy Tsaban
commited on
Commit
·
3fcb5ce
1
Parent(s):
cb271cd
Update modified_pipeline_semantic_stable_diffusion.py
Browse files
modified_pipeline_semantic_stable_diffusion.py
CHANGED
|
@@ -717,37 +717,37 @@ class SemanticStableDiffusionPipeline(DiffusionPipeline):
|
|
| 717 |
callback(i, t, latents)
|
| 718 |
|
| 719 |
|
| 720 |
-
|
| 721 |
-
image = self.decode_latents(latents)
|
| 722 |
|
| 723 |
-
# 9. Run safety checker
|
| 724 |
-
image, has_nsfw_concept = self.run_safety_checker(image, self.device, text_embeddings.dtype)
|
| 725 |
|
| 726 |
-
# 10. Convert to PIL
|
| 727 |
-
if output_type == "pil":
|
| 728 |
-
|
| 729 |
|
| 730 |
-
if not return_dict:
|
| 731 |
-
|
| 732 |
|
| 733 |
-
return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
|
| 734 |
|
| 735 |
-
#
|
| 736 |
-
|
| 737 |
-
|
| 738 |
-
|
| 739 |
-
|
| 740 |
-
|
| 741 |
-
|
| 742 |
|
| 743 |
-
|
| 744 |
-
|
| 745 |
-
|
| 746 |
-
|
| 747 |
|
| 748 |
-
|
| 749 |
|
| 750 |
-
|
| 751 |
-
|
| 752 |
|
| 753 |
-
|
|
|
|
| 717 |
callback(i, t, latents)
|
| 718 |
|
| 719 |
|
| 720 |
+
# # 8. Post-processing
|
| 721 |
+
# image = self.decode_latents(latents)
|
| 722 |
|
| 723 |
+
# # 9. Run safety checker
|
| 724 |
+
# image, has_nsfw_concept = self.run_safety_checker(image, self.device, text_embeddings.dtype)
|
| 725 |
|
| 726 |
+
# # 10. Convert to PIL
|
| 727 |
+
# if output_type == "pil":
|
| 728 |
+
# image = self.numpy_to_pil(image)
|
| 729 |
|
| 730 |
+
# if not return_dict:
|
| 731 |
+
# return (image, has_nsfw_concept)
|
| 732 |
|
| 733 |
+
# return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
|
| 734 |
|
| 735 |
+
# 8. Post-processing
|
| 736 |
+
if not output_type == "latent":
|
| 737 |
+
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
|
| 738 |
+
image, has_nsfw_concept = self.run_safety_checker(image, self.device, text_embeddings.dtype)
|
| 739 |
+
else:
|
| 740 |
+
image = latents
|
| 741 |
+
has_nsfw_concept = None
|
| 742 |
|
| 743 |
+
if has_nsfw_concept is None:
|
| 744 |
+
do_denormalize = [True] * image.shape[0]
|
| 745 |
+
else:
|
| 746 |
+
do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
|
| 747 |
|
| 748 |
+
image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
|
| 749 |
|
| 750 |
+
if not return_dict:
|
| 751 |
+
return (image, has_nsfw_concept)
|
| 752 |
|
| 753 |
+
return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
|