File size: 3,852 Bytes
c2c42ca
61bc6a3
81435cb
61bc6a3
 
9d41bd5
c2c42ca
aa5a24b
 
 
 
 
91dd651
c2c42ca
61bc6a3
 
 
 
c2c42ca
aa5a24b
 
 
 
 
 
 
 
 
 
 
 
 
61bc6a3
 
c2c42ca
61bc6a3
 
 
143f063
aa5a24b
 
 
 
 
 
 
 
 
 
 
 
61bc6a3
 
aa5a24b
91dd651
825bfd6
61bc6a3
 
 
 
143f063
61bc6a3
c2c42ca
 
 
 
 
61bc6a3
34cb1b5
 
5e64d98
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
import gradio as gr
from diffusers import StableDiffusionXLPipeline, UNet2DConditionModel, EulerDiscreteScheduler, LCMScheduler
import torch
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file

### SDXL Turbo #### 
pipe_turbo = StableDiffusionXLPipeline.from_pretrained("stabilityai/sdxl-turbo",
                                                       vae=vae,
                                                       torch_dtype=torch.float16,
                                                       variant="fp16"
                                                      )
pipe_turbo.to("cuda")

### SDXL Lightning ### 
base = "stabilityai/stable-diffusion-xl-base-1.0"
repo = "ByteDance/SDXL-Lightning"
ckpt = "sdxl_lightning_1step_unet_x0.safetensors" 

unet = UNet2DConditionModel.from_config(base, subfolder="unet").to(torch.float16)
unet.load_state_dict(load_file(hf_hub_download(repo, ckpt)))
pipe_lightning = StableDiffusionXLPipeline.from_pretrained(base,
                                                           unet=unet,
                                                           vae=vae,
                                                           text_encoder=pipe_turbo.text_encoder,
                                                           text_encoder_2=pipe_turbo.text_encoder_2,
                                                           tokenizer=pipe_turbo.tokenizer,
                                                           tokenizer_2=pipe_turbo.tokenizer_2,
                                                           torch_dtype=torch.float16,
                                                           variant="fp16"
                                                          )#.to("cuda")
del unet
pipe_lightning.scheduler = EulerDiscreteScheduler.from_config(pipe_lightning.scheduler.config, timestep_spacing="trailing", prediction_type="sample")
pipe_lightning.to("cuda")

### Hyper SDXL ### 
repo_name = "ByteDance/Hyper-SD"
ckpt_name = "Hyper-SDXL-1step-Unet.safetensors"

unet = UNet2DConditionModel.from_config(base, subfolder="unet").to(torch.float16)
unet.load_state_dict(load_file(hf_hub_download(repo_name, ckpt_name)))
pipe_hyper = StableDiffusionXLPipeline.from_pretrained(base,
                                                       unet=unet,
                                                       vae=vae,
                                                       text_encoder=pipe_turbo.text_encoder,
                                                       text_encoder_2=pipe_turbo.text_encoder_2,
                                                       tokenizer=pipe_turbo.tokenizer,
                                                       tokenizer_2=pipe_turbo.tokenizer_2,
                                                       torch_dtype=torch.float16,
                                                       variant="fp16"
                                                      )#.to("cuda")
pipe_hyper.scheduler = LCMScheduler.from_config(pipe_hyper.scheduler.config)
pipe_hyper.to("cuda")
del unet

def run_comparison(prompt):
    image_turbo=pipe_turbo(prompt=prompt, num_inference_steps=1, guidance_scale=0).images[0]
    image_lightning=pipe_lightning(prompt=prompt, num_inference_steps=1, guidance_scale=0).images[0]
    image_hyper=pipe_hyper(prompt=prompt, num_inference_steps=1, guidance_scale=0, timesteps=[800]).images[0]
    return image_turbo, image_lightning, image_hyper

with gr.Blocks() as demo:
    prompt = gr.Textbox(label="Prompt")
    run = gr.Button("Run")
    with gr.Row():
        image_turbo = gr.Image(label="SDXL Turbo")
        image_lightning = gr.Image(label="SDXL Lightning")
        image_hyper = gr.Image(label="Hyper SDXL")
    
    run.click(fn=run_comparison, inputs=prompt, outputs=[image_turbo, image_lightning, image_hyper])
demo.launch()