nsfw filter
- StableDiffuser.py +15 -1
StableDiffuser.py CHANGED
@@ -5,10 +5,12 @@ from baukit import TraceDict
 from diffusers import AutoencoderKL, UNet2DConditionModel
 from PIL import Image
 from tqdm.auto import tqdm
-from transformers import CLIPTextModel, CLIPTokenizer
+from transformers import CLIPTextModel, CLIPTokenizer, CLIPFeatureExtractor
+from diffusers.schedulers import EulerAncestralDiscreteScheduler
 from diffusers.schedulers.scheduling_ddim import DDIMScheduler
 from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
 from diffusers.schedulers.scheduling_lms_discrete import LMSDiscreteScheduler
+from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
 import util
 
 
@@ -53,6 +55,9 @@ class StableDiffuser(torch.nn.Module):
         self.unet = UNet2DConditionModel.from_pretrained(
             "CompVis/stable-diffusion-v1-4", subfolder="unet")
 
+        self.feature_extractor = CLIPFeatureExtractor.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="feature_extractor")
+        self.safety_checker = StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="safety_checker")
+
         if scheduler == 'LMS':
             self.scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000)
         elif scheduler == 'DDIM':
@@ -237,6 +242,15 @@ class StableDiffuser(torch.nn.Module):
         latents_steps = [self.decode(latents.to(self.unet.device)) for latents in latents_steps]
         images_steps = [self.to_image(latents) for latents in latents_steps]
 
+        for i in range(len(images_steps)):
+            self.safety_checker = self.safety_checker.float()
+            safety_checker_input = self.feature_extractor(images_steps[i], return_tensors="pt").to(latents_steps[0].device)
+            image, has_nsfw_concept = self.safety_checker(
+                images=latents_steps[i].float().cpu().numpy(), clip_input=safety_checker_input.pixel_values.float()
+            )
+
+            images_steps[i][0] = self.to_image(torch.from_numpy(image))[0]
+
         images_steps = list(zip(*images_steps))
 
         if trace_steps:
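For context, the three hunks load diffusers' built-in safety checker and CLIP feature extractor alongside the existing SD v1.4 weights, then run each decoded frame through the checker, which blacks out any frame it flags as NSFW. Below is a minimal standalone sketch of the same check outside the class; the variable names and the sample image path are illustrative, not taken from StableDiffuser.py.

    # Sketch of the NSFW check this commit wires into StableDiffuser,
    # shown as a standalone script under assumed names.
    import numpy as np
    from PIL import Image
    from transformers import CLIPFeatureExtractor
    from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker

    feature_extractor = CLIPFeatureExtractor.from_pretrained(
        "CompVis/stable-diffusion-v1-4", subfolder="feature_extractor")
    safety_checker = StableDiffusionSafetyChecker.from_pretrained(
        "CompVis/stable-diffusion-v1-4", subfolder="safety_checker")

    # Placeholder input: a list of PIL frames from a finished diffusion run.
    images = [Image.open("sample.png")]

    # CLIP-preprocess the frames for the checker's vision tower, and pass the
    # raw pixels as float arrays in [0, 1] so flagged frames can be zeroed out.
    clip_input = feature_extractor(images, return_tensors="pt")
    np_images = np.stack(
        [np.asarray(img, dtype=np.float32) / 255.0 for img in images])
    checked, has_nsfw = safety_checker(
        images=np_images, clip_input=clip_input.pixel_values)
    print(has_nsfw)  # one boolean per frame; flagged frames in `checked` are blacked out

In the commit itself this check runs once per saved inference step, and the (possibly blacked-out) result is written back with images_steps[i][0] = self.to_image(torch.from_numpy(image))[0].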