Spaces:
Runtime error
Runtime error
File size: 3,840 Bytes
0024d7b 78b2da7 0024d7b 3d2c0fd 0024d7b 3d2c0fd cca535e deea868 ab79cec cca535e 8b6d2e9 cca535e d149eaf 0024d7b 2fc58e2 cca535e d017ab0 3d2c0fd 9979ce7 3d2c0fd 9979ce7 3d2c0fd c2ced1c 2fc58e2 3d2c0fd 2fc58e2 3d2c0fd c2ced1c 2fc58e2 cca535e 2fc58e2 cca535e d149eaf 2fc58e2 d149eaf 2fc58e2 d149eaf 2fc58e2 3d2c0fd 2fc58e2 3d2c0fd 2fc58e2 eaaecdb 3588c73 eaaecdb 0024d7b cca535e 9f7adf4 cca535e c2ced1c 0024d7b eaaecdb 9f7adf4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 |
import gradio as gr
import torch
from diffusers import StableDiffusionXLPipeline, EulerDiscreteScheduler
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
import spaces
import os
from PIL import Image
# Constants
# Base SDXL weights and the Hub repo that hosts the distilled Lightning UNets.
base = "stabilityai/stable-diffusion-xl-base-1.0"
repo = "ByteDance/SDXL-Lightning"

# NSFW filtering is opt-in: only the exact value "1" turns it on.
SAFETY_CHECKER = os.getenv("SAFETY_CHECKER", "0") == "1"

# UI label -> [UNet checkpoint filename inside `repo`, number of inference steps].
checkpoints = {
    "Warp 1": ["sdxl_lightning_1step_unet_x0.safetensors", 1],
    "Warp 2": ["sdxl_lightning_2step_unet.safetensors", 2],
    "Warp 4": ["sdxl_lightning_4step_unet.safetensors", 4],
    "Warp 8": ["sdxl_lightning_8step_unet.safetensors", 8],
}

# Step count of the UNet currently loaded into the pipeline; None until the
# first generation request loads one.
loaded = None
# Load the SDXL base pipeline (fp16 weights) onto the GPU once, at import time.
# The per-checkpoint scheduler/UNet swap happens later in generate_image.
# When CUDA is unavailable nothing is loaded and `pipe` stays undefined.
if torch.cuda.is_available():
    pipe = StableDiffusionXLPipeline.from_pretrained(
        base,
        variant="fp16",
        torch_dtype=torch.float16,
    ).to("cuda")
# Optional NSFW filtering, enabled by setting the SAFETY_CHECKER=1 environment variable.
if SAFETY_CHECKER:
    # `safety_checker` is a project-local module that is not visible here;
    # presumably an adapted copy of diffusers' StableDiffusionSafetyChecker — confirm.
    from safety_checker import StableDiffusionSafetyChecker
    from transformers import CLIPFeatureExtractor
    # Both models are loaded at import time. The checker is moved to "cuda"
    # unconditionally, so enabling this branch requires a GPU.
    safety_checker = StableDiffusionSafetyChecker.from_pretrained(
        "CompVis/stable-diffusion-safety-checker"
    ).to("cuda")
    feature_extractor = CLIPFeatureExtractor.from_pretrained(
        "openai/clip-vit-base-patch32"
    )
    def check_nsfw_images(
        images: list[Image.Image],
    ) -> tuple[list[Image.Image], list[bool]]:
        """Run the safety checker over ``images``.

        Returns the input ``images`` object unchanged, paired with whatever
        ``safety_checker(...)`` returns, which the caller treats as per-image
        NSFW flags (it calls ``any()`` on it).
        """
        # CLIP preprocessing; the resulting tensors are moved to the GPU.
        safety_checker_input = feature_extractor(images, return_tensors="pt").to("cuda")
        # NOTE(review): `images=[images]` wraps the list in another list. This is
        # only harmless if the local checker ignores `images` and returns just the
        # flags. Upstream diffusers' checker, to my knowledge, returns an
        # (images, flags) tuple and indexes into `images`, which would break
        # `any(has_nsfw_concepts)` in the caller. Confirm against safety_checker.py.
        has_nsfw_concepts = safety_checker(
            images=[images],
            clip_input=safety_checker_input.pixel_values.to("cuda")
        )
        return images, has_nsfw_concepts
# Inference entry point (runs on a ZeroGPU worker via @spaces.GPU).
@spaces.GPU(enable_queue=True)
def generate_image(prompt, ckpt):
    """Generate one image for ``prompt`` using the SDXL-Lightning checkpoint ``ckpt``.

    The matching distilled UNet and Euler scheduler are swapped into the shared
    ``pipe`` only when ``ckpt`` needs a different step count than the one
    currently loaded (tracked in the module-level ``loaded``).

    Args:
        prompt: Text prompt passed to the pipeline.
        ckpt: A key of ``checkpoints`` (e.g. ``"Warp 4"``).

    Returns:
        A PIL image. When the safety checker is enabled and flags the result,
        a black placeholder with the same size as the generated image is
        returned instead.

    Raises:
        gr.Error: if ``ckpt`` is not one of the known checkpoint names.
    """
    global loaded
    print(prompt, ckpt)

    try:
        checkpoint, num_inference_steps = checkpoints[ckpt]
    except KeyError:
        # Reachable through the API even though the dropdown constrains the UI.
        raise gr.Error(
            f"Unknown checkpoint {ckpt!r}; choose one of {list(checkpoints)}."
        ) from None

    if loaded != num_inference_steps:
        # Download/read the weights before mutating the pipeline so a failure
        # here leaves the current scheduler/UNet pair untouched.
        state_dict = load_file(hf_hub_download(repo, checkpoint), device="cuda")
        # Mark the pipeline state as unknown while it is being changed: if the
        # scheduler swap or load_state_dict raises, the next call reloads
        # instead of trusting a half-swapped scheduler/UNet combination.
        loaded = None
        # The 1-step file is the `_x0` variant, sampled with prediction_type="sample";
        # the other checkpoints use "epsilon".
        pipe.scheduler = EulerDiscreteScheduler.from_config(
            pipe.scheduler.config,
            timestep_spacing="trailing",
            prediction_type="sample" if num_inference_steps == 1 else "epsilon",
        )
        pipe.unet.load_state_dict(state_dict)
        # Drop the temporary copy of the weights before inference so it does
        # not occupy GPU memory while the pipeline runs.
        del state_dict
        loaded = num_inference_steps

    results = pipe(prompt, num_inference_steps=num_inference_steps, guidance_scale=0)
    image = results.images[0]

    if SAFETY_CHECKER:
        images, has_nsfw_concepts = check_nsfw_images(results.images)
        if any(has_nsfw_concepts):
            gr.Warning("NSFW content detected.")
            # Black placeholder at the generated resolution, so the output
            # component keeps the same shape as a normal result.
            return Image.new("RGB", image.size)
        return images[0]
    return image
# Gradio Interface
description = """
🌌 Welcome to the Starfleet Command's Advanced Image Generation Console. Utilizing the cutting-edge SDXL-Lightning model, this terminal allows Starfleet officers and Federation citizens alike to materialize visual representations from textual descriptions at various warp speeds.
🖖 Boldly go where no one has gone before - create images as vast as the universe with your imagination.
"""
with gr.Blocks(css="style.css") as demo:
    gr.HTML("<h1><center>🌌 Starfleet Command: Text-to-Image Warp Drive - SDXL-Lightning ⚡</center></h1>")
    gr.Markdown(description)
    with gr.Group():
        with gr.Row():
            prompt = gr.Textbox(label='Input your cosmic prompt (English)', placeholder="Describe the celestial phenomena...", scale=8)
            # Derive the choices from `checkpoints` so the dropdown can never
            # offer a name that generate_image does not understand.
            ckpt = gr.Dropdown(label='Choose your Warp Factor', choices=list(checkpoints), value='Warp 4', interactive=True)
            submit = gr.Button("Initiate Image Generation", scale=1, variant='primary')
    img = gr.Image(label='Visual Manifestation of the Cosmos')
    # Pressing Enter in the textbox and clicking the button trigger the same handler.
    prompt.submit(fn=generate_image, inputs=[prompt, ckpt], outputs=img)
    submit.click(fn=generate_image, inputs=[prompt, ckpt], outputs=img)
demo.queue().launch()
|