Spaces:
Running
on
T4
Running
on
T4
patrickvonplaten
commited on
Commit
•
7302472
1
Parent(s):
459a0bd
[Safety Checker] Add Safety Checker Module
Browse filesFormer-commit-id: d0c714ae4afa1c011269a956d6f260f84f77025e
- scripts/txt2img.py +24 -1
scripts/txt2img.py
CHANGED
@@ -16,12 +16,29 @@ from ldm.util import instantiate_from_config
|
|
16 |
from ldm.models.diffusion.ddim import DDIMSampler
|
17 |
from ldm.models.diffusion.plms import PLMSSampler
|
18 |
|
|
|
|
|
|
|
|
|
|
|
19 |
|
20 |
def chunk(it, size):
|
21 |
it = iter(it)
|
22 |
return iter(lambda: tuple(islice(it, size)), ())
|
23 |
|
24 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
25 |
def load_model_from_config(config, ckpt, verbose=False):
|
26 |
print(f"Loading model from {ckpt}")
|
27 |
pl_sd = torch.load(ckpt, map_location="cpu")
|
@@ -220,7 +237,9 @@ def main():
|
|
220 |
if opt.fixed_code:
|
221 |
start_code = torch.randn([opt.n_samples, opt.C, opt.H // opt.f, opt.W // opt.f], device=device)
|
222 |
|
|
|
223 |
precision_scope = autocast if opt.precision=="autocast" else nullcontext
|
|
|
224 |
with torch.no_grad():
|
225 |
with precision_scope("cuda"):
|
226 |
with model.ema_scope():
|
@@ -269,7 +288,11 @@ def main():
|
|
269 |
Image.fromarray(grid.astype(np.uint8)).save(os.path.join(outpath, f'grid-{grid_count:04}.png'))
|
270 |
grid_count += 1
|
271 |
|
272 |
-
|
|
|
|
|
|
|
|
|
273 |
|
274 |
print(f"Your samples are ready and waiting for you here: \n{outpath} \n"
|
275 |
f" \nEnjoy.")
|
|
|
16 |
from ldm.models.diffusion.ddim import DDIMSampler
|
17 |
from ldm.models.diffusion.plms import PLMSSampler
|
18 |
|
19 |
+
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
|
20 |
+
from transformers import AutoFeatureExtractor
|
21 |
+
|
22 |
+
feature_extractor = AutoFeatureExtractor.from_pretrained("CompVis/stable-diffusion-v-1-3", use_auth_token=True)
|
23 |
+
safety_checker = StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-v-1-3", use_auth_token=True)
|
24 |
|
25 |
def chunk(it, size):
|
26 |
it = iter(it)
|
27 |
return iter(lambda: tuple(islice(it, size)), ())
|
28 |
|
29 |
|
30 |
+
def numpy_to_pil(images):
|
31 |
+
"""
|
32 |
+
Convert a numpy image or a batch of images to a PIL image.
|
33 |
+
"""
|
34 |
+
if images.ndim == 3:
|
35 |
+
images = images[None, ...]
|
36 |
+
images = (images * 255).round().astype("uint8")
|
37 |
+
pil_images = [Image.fromarray(image) for image in images]
|
38 |
+
|
39 |
+
return pil_images
|
40 |
+
|
41 |
+
|
42 |
def load_model_from_config(config, ckpt, verbose=False):
|
43 |
print(f"Loading model from {ckpt}")
|
44 |
pl_sd = torch.load(ckpt, map_location="cpu")
|
|
|
237 |
if opt.fixed_code:
|
238 |
start_code = torch.randn([opt.n_samples, opt.C, opt.H // opt.f, opt.W // opt.f], device=device)
|
239 |
|
240 |
+
print("start code", start_code.abs().sum())
|
241 |
precision_scope = autocast if opt.precision=="autocast" else nullcontext
|
242 |
+
precision_scope = nullcontext
|
243 |
with torch.no_grad():
|
244 |
with precision_scope("cuda"):
|
245 |
with model.ema_scope():
|
|
|
288 |
Image.fromarray(grid.astype(np.uint8)).save(os.path.join(outpath, f'grid-{grid_count:04}.png'))
|
289 |
grid_count += 1
|
290 |
|
291 |
+
image = x_samples_ddim.cpu().permute(0, 2, 3, 1).numpy()
|
292 |
+
|
293 |
+
# run safety checker
|
294 |
+
safety_checker_input = pipe.feature_extractor(numpy_to_pil(image), return_tensors="pt")
|
295 |
+
image, has_nsfw_concept = pipe.safety_checker(images=image, clip_input=safety_checker_input.pixel_values)
|
296 |
|
297 |
print(f"Your samples are ready and waiting for you here: \n{outpath} \n"
|
298 |
f" \nEnjoy.")
|