Spaces:
Runtime error
Runtime error
Kabatubare
commited on
Update app.py
Browse files
app.py
CHANGED
@@ -13,41 +13,39 @@ SAFETY_CHECKER = os.environ.get("SAFETY_CHECKER", "0") == "1"
|
|
13 |
base = "stabilityai/stable-diffusion-xl-base-1.0"
|
14 |
repo = "ByteDance/SDXL-Lightning"
|
15 |
checkpoints = {
|
16 |
-
"Warp 1"
|
17 |
-
"Warp 2"
|
18 |
-
"Warp 4"
|
19 |
-
"Warp 8"
|
20 |
}
|
21 |
loaded = None
|
22 |
|
23 |
-
|
24 |
# Ensure model and scheduler are initialized in GPU-enabled function
|
25 |
if torch.cuda.is_available():
|
26 |
pipe = StableDiffusionXLPipeline.from_pretrained(base, torch_dtype=torch.float16, variant="fp16").to("cuda")
|
27 |
|
28 |
if SAFETY_CHECKER:
|
29 |
from safety_checker import StableDiffusionSafetyChecker
|
30 |
-
from transformers import CLIPFeatureExtractor
|
31 |
|
32 |
safety_checker = StableDiffusionSafetyChecker.from_pretrained(
|
33 |
"CompVis/stable-diffusion-safety-checker"
|
34 |
).to("cuda")
|
35 |
-
|
|
|
36 |
"openai/clip-vit-base-patch32"
|
37 |
)
|
38 |
|
39 |
-
def check_nsfw_images(
|
40 |
-
images: list[Image.Image],
|
41 |
-
) -> tuple[list[Image.Image], list[bool]]:
|
42 |
safety_checker_input = feature_extractor(images, return_tensors="pt").to("cuda")
|
43 |
has_nsfw_concepts = safety_checker(
|
44 |
-
images=
|
45 |
-
clip_input=safety_checker_input.pixel_values.to("cuda")
|
46 |
)
|
47 |
|
48 |
return images, has_nsfw_concepts
|
49 |
|
50 |
-
# Function
|
51 |
@spaces.GPU(enable_queue=True)
|
52 |
def generate_image(prompt, ckpt):
|
53 |
global loaded
|
@@ -57,10 +55,10 @@ def generate_image(prompt, ckpt):
|
|
57 |
num_inference_steps = checkpoints[ckpt][1]
|
58 |
|
59 |
if loaded != num_inference_steps:
|
60 |
-
pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing", prediction_type="sample" if num_inference_steps==1 else "epsilon")
|
61 |
pipe.unet.load_state_dict(load_file(hf_hub_download(repo, checkpoint), device="cuda"))
|
62 |
loaded = num_inference_steps
|
63 |
-
|
64 |
results = pipe(prompt, num_inference_steps=num_inference_steps, guidance_scale=0)
|
65 |
|
66 |
if SAFETY_CHECKER:
|
@@ -71,8 +69,11 @@ def generate_image(prompt, ckpt):
|
|
71 |
return images[0]
|
72 |
return results.images[0]
|
73 |
|
|
|
|
|
|
|
|
|
74 |
|
75 |
-
# Update the dropdown choices to match the corrected keys in the 'checkpoints' dictionary
|
76 |
with gr.Blocks(css="style.css") as demo:
|
77 |
gr.HTML("<h1><center>🌌 Starfleet Command: Text-to-Image Warp Drive - SDXL-Lightning ⚡</center></h1>")
|
78 |
gr.Markdown(description)
|
@@ -83,13 +84,7 @@ with gr.Blocks(css="style.css") as demo:
|
|
83 |
submit = gr.Button(scale=1, variant='primary')
|
84 |
img = gr.Image(label='The Universe, As You Envision It')
|
85 |
|
86 |
-
prompt.submit(fn=generate_image,
|
87 |
-
|
88 |
-
|
89 |
-
)
|
90 |
-
submit.click(fn=generate_image,
|
91 |
-
inputs=[prompt, ckpt],
|
92 |
-
outputs=img,
|
93 |
-
)
|
94 |
-
|
95 |
demo.queue().launch()
|
|
|
13 |
base = "stabilityai/stable-diffusion-xl-base-1.0"
|
14 |
repo = "ByteDance/SDXL-Lightning"
|
15 |
checkpoints = {
|
16 |
+
"Warp 1": ["sdxl_lightning_1step_unet_x0.safetensors", 1],
|
17 |
+
"Warp 2": ["sdxl_lightning_2step_unet.safetensors", 2],
|
18 |
+
"Warp 4": ["sdxl_lightning_4step_unet.safetensors", 4],
|
19 |
+
"Warp 8": ["sdxl_lightning_8step_unet.safetensors", 8],
|
20 |
}
|
21 |
loaded = None
|
22 |
|
|
|
23 |
# Ensure model and scheduler are initialized in GPU-enabled function
|
24 |
if torch.cuda.is_available():
|
25 |
pipe = StableDiffusionXLPipeline.from_pretrained(base, torch_dtype=torch.float16, variant="fp16").to("cuda")
|
26 |
|
27 |
if SAFETY_CHECKER:
|
28 |
from safety_checker import StableDiffusionSafetyChecker
|
29 |
+
from transformers import CLIPFeatureExtractor, CLIPProcessor
|
30 |
|
31 |
safety_checker = StableDiffusionSafetyChecker.from_pretrained(
|
32 |
"CompVis/stable-diffusion-safety-checker"
|
33 |
).to("cuda")
|
34 |
+
# Updated to use CLIPProcessor as CLIPFeatureExtractor is deprecated
|
35 |
+
feature_extractor = CLIPProcessor.from_pretrained(
|
36 |
"openai/clip-vit-base-patch32"
|
37 |
)
|
38 |
|
39 |
+
def check_nsfw_images(images: list[Image.Image]) -> tuple[list[Image.Image], list[bool]]:
|
|
|
|
|
40 |
safety_checker_input = feature_extractor(images, return_tensors="pt").to("cuda")
|
41 |
has_nsfw_concepts = safety_checker(
|
42 |
+
images=images,
|
43 |
+
clip_input=safety_checker_input.pixel_values.to("cuda"),
|
44 |
)
|
45 |
|
46 |
return images, has_nsfw_concepts
|
47 |
|
48 |
+
# Function
|
49 |
@spaces.GPU(enable_queue=True)
|
50 |
def generate_image(prompt, ckpt):
|
51 |
global loaded
|
|
|
55 |
num_inference_steps = checkpoints[ckpt][1]
|
56 |
|
57 |
if loaded != num_inference_steps:
|
58 |
+
pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing", prediction_type="sample" if num_inference_steps == 1 else "epsilon")
|
59 |
pipe.unet.load_state_dict(load_file(hf_hub_download(repo, checkpoint), device="cuda"))
|
60 |
loaded = num_inference_steps
|
61 |
+
|
62 |
results = pipe(prompt, num_inference_steps=num_inference_steps, guidance_scale=0)
|
63 |
|
64 |
if SAFETY_CHECKER:
|
|
|
69 |
return images[0]
|
70 |
return results.images[0]
|
71 |
|
72 |
+
# Gradio Interface
|
73 |
+
description = """
|
74 |
+
Welcome aboard the Starship SDXL Enterprise! Our mission: to explore strange new AI generations, to seek out new visual frontiers and computational boundaries, to boldly generate images like no one has seen before. Utilizing the cutting-edge SDXL-Lightning model, we're at the forefront of text-to-image technology, ready to transform your imaginative prompts into visual spectacles. Whether you're navigating the uncharted territories of outer space or delving into the realms of fantasy, your adventure begins now. Model powered by the pioneering intellects at ByteDance. Journey safely through the stars!
|
75 |
+
"""
|
76 |
|
|
|
77 |
with gr.Blocks(css="style.css") as demo:
|
78 |
gr.HTML("<h1><center>🌌 Starfleet Command: Text-to-Image Warp Drive - SDXL-Lightning ⚡</center></h1>")
|
79 |
gr.Markdown(description)
|
|
|
84 |
submit = gr.Button(scale=1, variant='primary')
|
85 |
img = gr.Image(label='The Universe, As You Envision It')
|
86 |
|
87 |
+
prompt.submit(fn=generate_image, inputs=[prompt, ckpt], outputs=img)
|
88 |
+
submit.click(fn=generate_image, inputs=[prompt, ckpt], outputs=img)
|
89 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
90 |
demo.queue().launch()
|