multimodalart HF staff commited on
Commit
c2c42ca
·
verified ·
1 Parent(s): e6973b7

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +53 -0
app.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from diffusers import StableDiffusionXLPipeline, UNet2DConditionModel, EulerDiscreteScheduler, LCMScheduler
3
+ import torch
4
+ from huggingface_hub import hf_hub_download
5
+ from safetensors.torch import load_file
6
+ import spaces
7
+
8
+ ### SDXL Turbo ####
9
+
10
+ pipe_turbo = StableDiffusionXLPipeline.from_pretrained("stabilityai/sdxl-turbo", torch_dtype=torch.float16, variant="fp16")
11
+ pipe_turbo.to("cuda")
12
+
13
+ ### SDXL Lightning ###
14
+
15
+ base = "stabilityai/stable-diffusion-xl-base-1.0"
16
+ repo = "ByteDance/SDXL-Lightning"
17
+ ckpt = "sdxl_lightning_1step_unet_x0.safetensors"
18
+
19
+ unet = UNet2DConditionModel.from_config(base, subfolder="unet").to("cuda", torch.float16)
20
+ unet.load_state_dict(load_file(hf_hub_download(repo, ckpt), device="cuda"))
21
+ pipe_lightning = StableDiffusionXLPipeline.from_pretrained(base, unet=unet, torch_dtype=torch.float16, variant="fp16").to("cuda")
22
+
23
+ pipe_lightning.scheduler = EulerDiscreteScheduler.from_config(pipe_lightning.scheduler.config, timestep_spacing="trailing", prediction_type="sample")
24
+ pipe_lightning.to("cuda")
25
+
26
+ ### Hyper SDXL ###
27
+ repo_name = "ByteDance/Hyper-SD"
28
+ ckpt_name = "Hyper-SDXL-1step-Unet.safetensors"
29
+
30
+ unet = UNet2DConditionModel.from_config(base, subfolder="unet").to("cuda", torch.float16)
31
+ unet.load_state_dict(load_file(hf_hub_download(repo_name, ckpt_name), device="cuda"))
32
+ pipe_hyper = DiffusionPipeline.from_pretrained(base_model_id, unet=unet, torch_dtype=torch.float16, variant="fp16").to("cuda")
33
+ pipe_hyper.scheduler = LCMScheduler.from_config(pipe_hyper.scheduler.config)
34
+ pipe_hyper.to("cuda")
35
+
36
+ def run(prompt):
37
+ image_turbo=pipe_turbo(prompt=prompt, num_inference_steps=1, guidance_scale=0).images[0]
38
+ image_lightning=pipe_lightning(prompt=prompt, num_inference_steps=1, guidance_scale=0).images[0]
39
+ image_hyper=pipe_hyper(prompt=prompt, num_inference_steps=1, guidance_scale=0, timesteps=[800]).images[0]
40
+ return image_turbo, image_lightning, image_hyper
41
+ css = '''
42
+ .gradio-container{max-width: 768px !important}
43
+ '''
44
+
45
+ @spaces.GPU
46
+ with gr.Blocks(css=css) as demo:
47
+ prompt = gr.Textbox(label="Prompt")
48
+ run = gr.Button("Run")
49
+ with gr.Row():
50
+ image_turbo = gr.Image(label="SDXL Turbo")
51
+ image_lightning = gr.Image(label="SDXL Lightning")
52
+ image_hyper = gr.Image("Hyper SDXL")
53
+ run.click(fn=run, inputs=prompt, outputs=[image_turbo, image_lightning, image_hyper])