File size: 3,840 Bytes
0024d7b
 
 
 
78b2da7
0024d7b
3d2c0fd
 
0024d7b
3d2c0fd
cca535e
deea868
ab79cec
 
cca535e
8b6d2e9
 
 
 
cca535e
d149eaf
0024d7b
2fc58e2
cca535e
 
d017ab0
 
3d2c0fd
 
9979ce7
3d2c0fd
 
 
 
9979ce7
3d2c0fd
 
 
c2ced1c
 
 
2fc58e2
3d2c0fd
2fc58e2
 
3d2c0fd
c2ced1c
2fc58e2
cca535e
2fc58e2
cca535e
 
d149eaf
2fc58e2
 
 
 
d149eaf
 
2fc58e2
d149eaf
 
2fc58e2
 
3d2c0fd
 
 
 
2fc58e2
 
 
3d2c0fd
 
2fc58e2
eaaecdb
3588c73
 
eaaecdb
0024d7b
cca535e
9f7adf4
cca535e
 
 
c2ced1c
 
 
 
0024d7b
eaaecdb
 
 
9f7adf4
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
import gradio as gr
import torch
from diffusers import StableDiffusionXLPipeline, EulerDiscreteScheduler
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
import spaces
import os
from PIL import Image

# NSFW filtering is opt-in: enabled only when the SAFETY_CHECKER environment
# variable is exactly "1" (unset defaults to "0", i.e. disabled).
SAFETY_CHECKER = os.getenv("SAFETY_CHECKER", "0") == "1"

# Constants
# Hub repo id passed to StableDiffusionXLPipeline.from_pretrained.
base = "stabilityai/stable-diffusion-xl-base-1.0"
# Hub repo id the UNet checkpoint files below are downloaded from.
repo = "ByteDance/SDXL-Lightning"
# UI label -> [checkpoint filename inside `repo`, number of inference steps].
checkpoints = {
    "Warp 1": ["sdxl_lightning_1step_unet_x0.safetensors", 1],
    "Warp 2": ["sdxl_lightning_2step_unet.safetensors", 2],
    "Warp 4": ["sdxl_lightning_4step_unet.safetensors", 4],
    "Warp 8": ["sdxl_lightning_8step_unet.safetensors", 8],
}
# Step count of the UNet weights currently loaded into the pipeline;
# None until generate_image has loaded a checkpoint for the first time.
loaded = None


# Build the SDXL base pipeline once, at import time, in fp16 on the GPU.
# NOTE(review): when CUDA is not available `pipe` is never defined, so
# generate_image would fail with NameError -- confirm this app only runs on
# GPU-backed hosts.
if torch.cuda.is_available():
    pipe = StableDiffusionXLPipeline.from_pretrained(
        base,
        torch_dtype=torch.float16,
        variant="fp16",
    ).to("cuda")

if SAFETY_CHECKER:
    # Imported here so these dependencies are only required when the
    # checker is switched on.
    from safety_checker import StableDiffusionSafetyChecker
    from transformers import CLIPFeatureExtractor

    safety_checker = StableDiffusionSafetyChecker.from_pretrained(
        "CompVis/stable-diffusion-safety-checker"
    ).to("cuda")
    feature_extractor = CLIPFeatureExtractor.from_pretrained(
        "openai/clip-vit-base-patch32"
    )

    def check_nsfw_images(
        images: list[Image.Image],
    ) -> tuple[list[Image.Image], list[bool]]:
        """Run the safety checker over generated images.

        Args:
            images: PIL images to screen.

        Returns:
            The input ``images`` unchanged, together with whatever the
            safety checker returned for them (annotated as per-image flags).
        """
        clip_features = feature_extractor(images, return_tensors="pt").to("cuda")
        pixel_values = clip_features.pixel_values.to("cuda")
        # NOTE(review): the image list is deliberately passed wrapped in an
        # outer list, and the checker's result is used as a single value;
        # both depend on the project-local `safety_checker` module -- confirm
        # against its forward() signature.
        nsfw_flags = safety_checker(images=[images], clip_input=pixel_values)
        return images, nsfw_flags

# Image generation handler (bound to the UI events below)
@spaces.GPU(enable_queue=True)
def generate_image(prompt, ckpt):
    """Generate one image for ``prompt`` with the selected checkpoint.

    Args:
        prompt: Text prompt forwarded to the pipeline.
        ckpt: Key into ``checkpoints`` (e.g. ``"Warp 4"``) selecting the
            UNet weights file and the number of inference steps.

    Returns:
        The first generated image. When ``SAFETY_CHECKER`` is enabled and any
        image is flagged, a Gradio warning is shown and a blank 512x512 RGB
        image is returned instead.

    Side effects:
        Prints the request, and updates the module-level ``loaded`` marker
        (plus the pipeline's scheduler and UNet weights) whenever the
        requested step count differs from the one currently loaded.
    """
    global loaded
    print(prompt, ckpt)

    weights_file, step_count = checkpoints[ckpt]

    # Only swap scheduler and UNet weights when the step count changes.
    if loaded != step_count:
        # The 1-step checkpoint is run with "sample" prediction; all other
        # checkpoints use "epsilon".
        if step_count == 1:
            prediction = "sample"
        else:
            prediction = "epsilon"
        pipe.scheduler = EulerDiscreteScheduler.from_config(
            pipe.scheduler.config,
            timestep_spacing="trailing",
            prediction_type=prediction,
        )
        weights_path = hf_hub_download(repo, weights_file)
        pipe.unet.load_state_dict(load_file(weights_path, device="cuda"))
        loaded = step_count

    output = pipe(prompt, num_inference_steps=step_count, guidance_scale=0)

    if not SAFETY_CHECKER:
        return output.images[0]

    checked_images, nsfw_flags = check_nsfw_images(output.images)
    if any(nsfw_flags):
        gr.Warning("NSFW content detected.")
        return Image.new("RGB", (512, 512))
    return checked_images[0]

# Gradio Interface
# Intro text rendered as Markdown under the page title.
description = """
🌌 Welcome to the Starfleet Command's Advanced Image Generation Console. Utilizing the cutting-edge SDXL-Lightning model, this terminal allows Starfleet officers and Federation citizens alike to materialize visual representations from textual descriptions at various warp speeds.
🖖 Boldly go where no one has gone before - create images as vast as the universe with your imagination.
"""

with gr.Blocks(css="style.css") as demo:
    gr.HTML("<h1><center>🌌 Starfleet Command: Text-to-Image Warp Drive - SDXL-Lightning ⚡</center></h1>")
    gr.Markdown(description)
    with gr.Group():
        with gr.Row():
            prompt = gr.Textbox(
                label='Input your cosmic prompt (English)',
                placeholder="Describe the celestial phenomena...",
                scale=8,
            )
            ckpt = gr.Dropdown(
                label='Choose your Warp Factor',
                choices=['Warp 1', 'Warp 2', 'Warp 4', 'Warp 8'],
                value='Warp 4',
                interactive=True,
            )
            submit = gr.Button(
                "Initiate Image Generation",
                scale=1,
                variant='primary',
            )
    img = gr.Image(label='Visual Manifestation of the Cosmos')

    # Submitting the textbox and clicking the button run the same handler
    # with the same inputs and output.
    for register in (prompt.submit, submit.click):
        register(fn=generate_image, inputs=[prompt, ckpt], outputs=img)

demo.queue()
demo.launch()