Spaces:
Runtime error
Runtime error
Kabatubare
commited on
Update app.py
Browse files
app.py
CHANGED
@@ -35,12 +35,15 @@ if SAFETY_CHECKER:
|
|
35 |
"openai/clip-vit-base-patch32"
|
36 |
)
|
37 |
|
38 |
-
def check_nsfw_images(
|
|
|
|
|
39 |
safety_checker_input = feature_extractor(images=[image.convert("RGB") for image in images], return_tensors="pt").to("cuda")
|
40 |
has_nsfw_concepts = safety_checker(
|
41 |
images=images,
|
42 |
clip_input=safety_checker_input.pixel_values.to("cuda"),
|
43 |
)
|
|
|
44 |
return images, has_nsfw_concepts.bool().tolist()
|
45 |
|
46 |
@spaces.GPU(enable_queue=True)
|
@@ -52,18 +55,17 @@ def generate_image(prompt, ckpt):
|
|
52 |
pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing", prediction_type="sample" if num_inference_steps == 1 else "epsilon")
|
53 |
pipe.unet.load_state_dict(load_file(hf_hub_download(repo, checkpoint), device="cuda"))
|
54 |
loaded = num_inference_steps
|
55 |
-
|
56 |
results = pipe(prompt, num_inference_steps=num_inference_steps, guidance_scale=7.5)
|
57 |
|
58 |
if SAFETY_CHECKER:
|
59 |
images, has_nsfw_concepts = check_nsfw_images(results.images)
|
60 |
if any(has_nsfw_concepts):
|
61 |
-
return
|
62 |
return results.images[0]
|
63 |
|
64 |
description = """
|
65 |
-
🌌
|
66 |
-
🖖 Boldly go where no one has gone before - create images as vast as the universe with your imagination.
|
67 |
"""
|
68 |
|
69 |
with gr.Blocks(css="style.css") as demo:
|
@@ -71,10 +73,10 @@ with gr.Blocks(css="style.css") as demo:
|
|
71 |
gr.Markdown(description)
|
72 |
with gr.Group():
|
73 |
with gr.Row():
|
74 |
-
prompt = gr.Textbox(label='
|
75 |
-
ckpt = gr.Dropdown(label='Warp Factor
|
76 |
-
submit = gr.Button("
|
77 |
-
img = gr.Image(label='
|
78 |
|
79 |
prompt.submit(fn=generate_image, inputs=[prompt, ckpt], outputs=img)
|
80 |
submit.click(fn=generate_image, inputs=[prompt, ckpt], outputs=img)
|
|
|
35 |
"openai/clip-vit-base-patch32"
|
36 |
)
|
37 |
|
38 |
+
def check_nsfw_images(
|
39 |
+
images: list[Image.Image],
|
40 |
+
) -> tuple[list[Image.Image], list[bool]]:
|
41 |
safety_checker_input = feature_extractor(images=[image.convert("RGB") for image in images], return_tensors="pt").to("cuda")
|
42 |
has_nsfw_concepts = safety_checker(
|
43 |
images=images,
|
44 |
clip_input=safety_checker_input.pixel_values.to("cuda"),
|
45 |
)
|
46 |
+
|
47 |
return images, has_nsfw_concepts.bool().tolist()
|
48 |
|
49 |
@spaces.GPU(enable_queue=True)
|
|
|
55 |
pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing", prediction_type="sample" if num_inference_steps == 1 else "epsilon")
|
56 |
pipe.unet.load_state_dict(load_file(hf_hub_download(repo, checkpoint), device="cuda"))
|
57 |
loaded = num_inference_steps
|
58 |
+
|
59 |
results = pipe(prompt, num_inference_steps=num_inference_steps, guidance_scale=7.5)
|
60 |
|
61 |
if SAFETY_CHECKER:
|
62 |
images, has_nsfw_concepts = check_nsfw_images(results.images)
|
63 |
if any(has_nsfw_concepts):
|
64 |
+
return Image.new("RGB", (512, 512), "black")
|
65 |
return results.images[0]
|
66 |
|
67 |
description = """
|
68 |
+
🌌 Engage in the exploration of galaxies with the advanced SDXL-Lightning model, a creation of ByteDance capable of transforming your textual descriptions into vivid images at warp speed. This is a joint venture initiated by Starfleet, enabling creative minds to visualize the uncharted territories of space. 🚀 Link to model: [ByteDance/SDXL-Lightning](https://huggingface.co/ByteDance/SDXL-Lightning)
|
|
|
69 |
"""
|
70 |
|
71 |
with gr.Blocks(css="style.css") as demo:
|
|
|
73 |
gr.Markdown(description)
|
74 |
with gr.Group():
|
75 |
with gr.Row():
|
76 |
+
prompt = gr.Textbox(label='Input your cosmic prompt (English)', placeholder="Describe the celestial phenomena...", scale=8)
|
77 |
+
ckpt = gr.Dropdown(label='Choose your Warp Factor', choices=['Warp 1', 'Warp 2', 'Warp 4', 'Warp 8'], value='Warp 4', interactive=True)
|
78 |
+
submit = gr.Button("Initiate Image Generation", scale=1, variant='primary')
|
79 |
+
img = gr.Image(label='Visual Manifestation of the Cosmos')
|
80 |
|
81 |
prompt.submit(fn=generate_image, inputs=[prompt, ckpt], outputs=img)
|
82 |
submit.click(fn=generate_image, inputs=[prompt, ckpt], outputs=img)
|