File size: 5,278 Bytes
c2c42ca
d827a95
81435cb
61bc6a3
 
6da6b11
9d41bd5
a1f66f7
 
c2c42ca
aa5a24b
 
 
 
 
91dd651
c2c42ca
61bc6a3
 
 
 
c2c42ca
aa5a24b
 
 
 
 
 
 
 
 
 
 
 
 
61bc6a3
 
c2c42ca
61bc6a3
 
 
143f063
aa5a24b
 
 
 
 
 
 
 
 
 
 
 
61bc6a3
 
aa5a24b
91dd651
6da6b11
370f468
61bc6a3
 
 
 
143f063
b634b72
ddbaa70
 
 
 
 
 
 
 
61bc6a3
dc81866
d452942
c2c42ca
 
 
dc81866
 
 
 
 
 
 
 
 
ddbaa70
bc87ae3
 
 
 
ddbaa70
 
 
 
 
 
4835fe3
 
 
bc87ae3
5e64d98
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
import gradio as gr
from diffusers import StableDiffusionXLPipeline, UNet2DConditionModel, EulerDiscreteScheduler, LCMScheduler, AutoencoderKL
import torch
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
import spaces

# Single fp16 VAE shared by all three pipelines below.
vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16)

### SDXL Turbo ###
pipe_turbo = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/sdxl-turbo",
    vae=vae,
    torch_dtype=torch.float16,
    variant="fp16",
)
pipe_turbo.to("cuda")

# SDXL base repo: supplies the UNet architecture config and the remaining
# pipeline components for the Lightning and Hyper variants.
base = "stabilityai/stable-diffusion-xl-base-1.0"


def _load_one_step_unet(repo_id, filename):
    """Build an SDXL-base-shaped fp16 UNet and load a distilled checkpoint into it.

    Args:
        repo_id: Hugging Face Hub repo holding the distilled UNet weights.
        filename: safetensors file inside ``repo_id`` to load.

    Returns:
        The UNet (still on CPU) with the downloaded state dict loaded.
    """
    # Load the config as a dict first: passing a repo id directly to
    # ``from_config`` is deprecated in diffusers.
    unet_config = UNet2DConditionModel.load_config(base, subfolder="unet")
    unet = UNet2DConditionModel.from_config(unet_config).to(torch.float16)
    unet.load_state_dict(load_file(hf_hub_download(repo_id, filename)))
    return unet


def _build_base_pipe(unet):
    """Assemble an fp16 SDXL pipeline from ``base`` around ``unet``.

    The shared ``vae`` and the Turbo pipeline's text encoders/tokenizers are
    passed in as the same objects, so they are not loaded a second time.
    """
    return StableDiffusionXLPipeline.from_pretrained(
        base,
        unet=unet,
        vae=vae,
        text_encoder=pipe_turbo.text_encoder,
        text_encoder_2=pipe_turbo.text_encoder_2,
        tokenizer=pipe_turbo.tokenizer,
        tokenizer_2=pipe_turbo.tokenizer_2,
        torch_dtype=torch.float16,
        variant="fp16",
    )


### SDXL Lightning ###
repo = "ByteDance/SDXL-Lightning"
ckpt = "sdxl_lightning_1step_unet_x0.safetensors"

pipe_lightning = _build_base_pipe(_load_one_step_unet(repo, ckpt))
# The 1-step "x0" checkpoint is sampled with Euler, trailing timestep spacing,
# and prediction_type="sample" (the model output is treated as the sample itself).
pipe_lightning.scheduler = EulerDiscreteScheduler.from_config(
    pipe_lightning.scheduler.config, timestep_spacing="trailing", prediction_type="sample"
)
pipe_lightning.to("cuda")

### Hyper SDXL ###
repo_name = "ByteDance/Hyper-SD"
ckpt_name = "Hyper-SDXL-1step-Unet.safetensors"

pipe_hyper = _build_base_pipe(_load_one_step_unet(repo_name, ckpt_name))
pipe_hyper.scheduler = LCMScheduler.from_config(pipe_hyper.scheduler.config)
pipe_hyper.to("cuda")

@spaces.GPU
def run_comparison(prompt, progress=gr.Progress(track_tqdm=True)):
    """Render ``prompt`` once with each one-step SDXL variant.

    Each pipeline runs exactly one inference step with ``guidance_scale=0``.
    ``progress`` is not read here; it is declared so Gradio's tracker
    (``track_tqdm=True``) can surface the pipelines' tqdm bars.

    Returns:
        Tuple ``(turbo_image, lightning_image, hyper_image)``.
    """
    common_kwargs = dict(prompt=prompt, num_inference_steps=1, guidance_scale=0)
    schedule = (
        (pipe_turbo, {}),
        (pipe_lightning, {}),
        # The Hyper-SD 1-step UNet is sampled at an explicit timestep.
        (pipe_hyper, {"timesteps": [800]}),
    )
    return tuple(
        pipeline(**common_kwargs, **extra_kwargs).images[0]
        for pipeline, extra_kwargs in schedule
    )

# Example prompts shown under the UI; clicking one runs the comparison.
examples = [
    "A dignified beaver wearing glasses, a vest, and colorful neck tie.",
    "The spirit of a tamagotchi wandering in the city of Barcelona",
    "an ornate, high-backed mahogany chair with a red cushion",
    "a sketch of a camel next to a stream",
    "a delicate porcelain teacup sits on a saucer, its surface adorned with intricate blue patterns",
    "a baby swan graffiti",
    "A bald eagle made of chocolate powder, mango, and whipped cream",
]

# (image label, linked heading) for each model column, left to right; the
# order matches the tuple returned by run_comparison.
_MODEL_PANELS = (
    ("SDXL Turbo", "## [SDXL Turbo](https://huggingface.co/stabilityai/sdxl-turbo)"),
    ("SDXL Lightning", "## [SDXL Lightning](https://huggingface.co/ByteDance/SDXL-Lightning)"),
    ("Hyper SDXL", "## [Hyper SDXL](https://huggingface.co/ByteDance/Hyper-SD)"),
)

with gr.Blocks() as demo:
    gr.Markdown("## One step SDXL comparison 🦶")
    gr.Markdown('Compare SDXL variants and distillations able to generate images in a single diffusion step')
    prompt = gr.Textbox(label="Prompt")
    run = gr.Button("Run")

    image_outputs = []
    with gr.Row():
        for panel_label, panel_heading in _MODEL_PANELS:
            with gr.Column():
                image_outputs.append(gr.Image(label=panel_label))
                gr.Markdown(panel_heading)
    image_turbo, image_lightning, image_hyper = image_outputs

    # Pressing Enter in the textbox or clicking "Run" both trigger generation.
    gr.on(
        triggers=[prompt.submit, run.click],
        fn=run_comparison,
        inputs=prompt,
        outputs=image_outputs,
    )
    gr.Examples(
        examples=examples,
        fn=run_comparison,
        inputs=prompt,
        outputs=image_outputs,
        cache_examples=False,
        run_on_click=True,
    )

demo.launch()