Spaces:
Sleeping
Sleeping
apolinario
commited on
Commit
•
6624621
1
Parent(s):
8dfb33f
Better NSFW filter
Browse files
app.py
CHANGED
@@ -42,15 +42,62 @@ def load_model_from_config(config, ckpt, verbose=False):
|
|
42 |
model.eval()
|
43 |
return model
|
44 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
45 |
config = OmegaConf.load("latent-diffusion/configs/latent-diffusion/txt2img-1p4B-eval.yaml")
|
46 |
model = load_model_from_config(config, f"txt2img-f8-large.ckpt")
|
47 |
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
48 |
model = model.to(device)
|
|
|
49 |
#NSFW CLIP Filter
|
50 |
-
|
51 |
-
|
52 |
-
with torch.no_grad():
|
53 |
-
text_features = clip_model.encode_text(text)
|
54 |
|
55 |
def run(prompt, steps, width, height, images, scale):
|
56 |
opt = argparse.Namespace(
|
@@ -108,10 +155,13 @@ def run(prompt, steps, width, height, images, scale):
|
|
108 |
for x_sample in x_samples_ddim:
|
109 |
x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
|
110 |
image_vector = Image.fromarray(x_sample.astype(np.uint8))
|
111 |
-
|
112 |
-
|
113 |
-
|
114 |
-
|
|
|
|
|
|
|
115 |
all_samples_images.append(image_vector)
|
116 |
else:
|
117 |
return(None,None,"Sorry, potential NSFW content was detected on your outputs by our NSFW detection model. Try again with different prompts. If you feel your prompt was not supposed to give NSFW outputs, this may be due to a bias in the model. Read more about biases in the Biases Acknowledgment section below.")
|
|
|
42 |
model.eval()
|
43 |
return model
|
44 |
|
45 |
+
def load_safety_model(clip_model):
|
46 |
+
"""load the safety model"""
|
47 |
+
import autokeras as ak # pylint: disable=import-outside-toplevel
|
48 |
+
from tensorflow.keras.models import load_model # pylint: disable=import-outside-toplevel
|
49 |
+
from os.path import expanduser # pylint: disable=import-outside-toplevel
|
50 |
+
|
51 |
+
home = expanduser("~")
|
52 |
+
|
53 |
+
cache_folder = home + "/.cache/clip_retrieval/" + clip_model.replace("/", "_")
|
54 |
+
if clip_model == "ViT-L/14":
|
55 |
+
model_dir = cache_folder + "/clip_autokeras_binary_nsfw"
|
56 |
+
dim = 768
|
57 |
+
elif clip_model == "ViT-B/32":
|
58 |
+
model_dir = cache_folder + "/clip_autokeras_nsfw_b32"
|
59 |
+
dim = 512
|
60 |
+
else:
|
61 |
+
raise ValueError("Unknown clip model")
|
62 |
+
if not os.path.exists(model_dir):
|
63 |
+
os.makedirs(cache_folder, exist_ok=True)
|
64 |
+
|
65 |
+
from urllib.request import urlretrieve # pylint: disable=import-outside-toplevel
|
66 |
+
|
67 |
+
path_to_zip_file = cache_folder + "/clip_autokeras_binary_nsfw.zip"
|
68 |
+
if clip_model == "ViT-L/14":
|
69 |
+
url_model = "https://raw.githubusercontent.com/LAION-AI/CLIP-based-NSFW-Detector/main/clip_autokeras_binary_nsfw.zip"
|
70 |
+
elif clip_model == "ViT-B/32":
|
71 |
+
url_model = (
|
72 |
+
"https://raw.githubusercontent.com/LAION-AI/CLIP-based-NSFW-Detector/main/clip_autokeras_nsfw_b32.zip"
|
73 |
+
)
|
74 |
+
else:
|
75 |
+
raise ValueError("Unknown model {}".format(clip_model))
|
76 |
+
urlretrieve(url_model, path_to_zip_file)
|
77 |
+
import zipfile # pylint: disable=import-outside-toplevel
|
78 |
+
|
79 |
+
with zipfile.ZipFile(path_to_zip_file, "r") as zip_ref:
|
80 |
+
zip_ref.extractall(cache_folder)
|
81 |
+
|
82 |
+
loaded_model = load_model(model_dir, custom_objects=ak.CUSTOM_OBJECTS)
|
83 |
+
loaded_model.predict(np.random.rand(10 ** 3, dim).astype("float32"), batch_size=10 ** 3)
|
84 |
+
|
85 |
+
return loaded_model
|
86 |
+
|
87 |
+
def is_unsafe(safety_model, embeddings, threshold=0.5):
|
88 |
+
"""find unsafe embeddings"""
|
89 |
+
nsfw_values = safety_model.predict(embeddings, batch_size=embeddings.shape[0])
|
90 |
+
x = np.array([e[0] for e in nsfw_values])
|
91 |
+
return True if x > threshold else False
|
92 |
+
|
93 |
config = OmegaConf.load("latent-diffusion/configs/latent-diffusion/txt2img-1p4B-eval.yaml")
|
94 |
model = load_model_from_config(config, f"txt2img-f8-large.ckpt")
|
95 |
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
96 |
model = model.to(device)
|
97 |
+
|
98 |
#NSFW CLIP Filter
|
99 |
+
safety_model = load_safety_model("ViT-B/32")
|
100 |
+
clip_model, _, preprocess = open_clip.create_model_and_transforms('ViT-B-32', pretrained='openai')
|
|
|
|
|
101 |
|
102 |
def run(prompt, steps, width, height, images, scale):
|
103 |
opt = argparse.Namespace(
|
|
|
155 |
for x_sample in x_samples_ddim:
|
156 |
x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
|
157 |
image_vector = Image.fromarray(x_sample.astype(np.uint8))
|
158 |
+
image_preprocess = preprocess(image_vector).unsqueeze(0)
|
159 |
+
with torch.no_grad():
|
160 |
+
image_features = clip_model.encode_image(image_preprocess)
|
161 |
+
image_features /= image_features.norm(dim=-1, keepdim=True)
|
162 |
+
query = image_features.cpu().detach().numpy().astype("float32")
|
163 |
+
unsafe = is_unsafe(safety_model,query,0.5)
|
164 |
+
if(not unsafe):
|
165 |
all_samples_images.append(image_vector)
|
166 |
else:
|
167 |
return(None,None,"Sorry, potential NSFW content was detected on your outputs by our NSFW detection model. Try again with different prompts. If you feel your prompt was not supposed to give NSFW outputs, this may be due to a bias in the model. Read more about biases in the Biases Acknowledgment section below.")
|