JadenFK commited on
Commit
01064e8
1 Parent(s): 506badf

nsfw filter

Browse files
Files changed (1) hide show
  1. StableDiffuser.py +15 -1
StableDiffuser.py CHANGED
@@ -5,10 +5,12 @@ from baukit import TraceDict
5
  from diffusers import AutoencoderKL, UNet2DConditionModel
6
  from PIL import Image
7
  from tqdm.auto import tqdm
8
- from transformers import CLIPTextModel, CLIPTokenizer
 
9
  from diffusers.schedulers.scheduling_ddim import DDIMScheduler
10
  from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
11
  from diffusers.schedulers.scheduling_lms_discrete import LMSDiscreteScheduler
 
12
  import util
13
 
14
 
@@ -53,6 +55,9 @@ class StableDiffuser(torch.nn.Module):
53
  self.unet = UNet2DConditionModel.from_pretrained(
54
  "CompVis/stable-diffusion-v1-4", subfolder="unet")
55
 
 
 
 
56
  if scheduler == 'LMS':
57
  self.scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000)
58
  elif scheduler == 'DDIM':
@@ -237,6 +242,15 @@ class StableDiffuser(torch.nn.Module):
237
  latents_steps = [self.decode(latents.to(self.unet.device)) for latents in latents_steps]
238
  images_steps = [self.to_image(latents) for latents in latents_steps]
239
 
 
 
 
 
 
 
 
 
 
240
  images_steps = list(zip(*images_steps))
241
 
242
  if trace_steps:
 
5
  from diffusers import AutoencoderKL, UNet2DConditionModel
6
  from PIL import Image
7
  from tqdm.auto import tqdm
8
+ from transformers import CLIPTextModel, CLIPTokenizer, CLIPFeatureExtractor
9
+ from diffusers.schedulers import EulerAncestralDiscreteScheduler
10
  from diffusers.schedulers.scheduling_ddim import DDIMScheduler
11
  from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
12
  from diffusers.schedulers.scheduling_lms_discrete import LMSDiscreteScheduler
13
+ from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
14
  import util
15
 
16
 
 
55
  self.unet = UNet2DConditionModel.from_pretrained(
56
  "CompVis/stable-diffusion-v1-4", subfolder="unet")
57
 
58
+ self.feature_extractor = CLIPFeatureExtractor.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="feature_extractor")
59
+ self.safety_checker = StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="safety_checker")
60
+
61
  if scheduler == 'LMS':
62
  self.scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000)
63
  elif scheduler == 'DDIM':
 
242
  latents_steps = [self.decode(latents.to(self.unet.device)) for latents in latents_steps]
243
  images_steps = [self.to_image(latents) for latents in latents_steps]
244
 
245
+ for i in range(len(images_steps)):
246
+ self.safety_checker = self.safety_checker.float()
247
+ safety_checker_input = self.feature_extractor(images_steps[i], return_tensors="pt").to(latents_steps[0].device)
248
+ image, has_nsfw_concept = self.safety_checker(
249
+ images=latents_steps[i].float().cpu().numpy(), clip_input=safety_checker_input.pixel_values.float()
250
+ )
251
+
252
+ images_steps[i][0] = self.to_image(torch.from_numpy(image))[0]
253
+
254
  images_steps = list(zip(*images_steps))
255
 
256
  if trace_steps: