File size: 4,247 Bytes
c2c42ca
9d41bd5
c2c42ca
 
 
de56cd9
c2c42ca
 
de56cd9
9d41bd5
 
c2c42ca
9d41bd5
 
 
 
 
4ac12a1
c2c42ca
 
 
 
 
 
4ac12a1
3bd17ee
9d41bd5
 
 
428d2aa
9d41bd5
 
 
 
 
 
e10dc6d
de56cd9
c2c42ca
4ac12a1
c2c42ca
 
 
 
 
4ac12a1
3bd17ee
9d41bd5
 
 
428d2aa
9d41bd5
 
 
 
 
 
c2c42ca
4ac12a1
e10dc6d
de56cd9
143f063
 
34cb1b5
4ac12a1
c2c42ca
4ac12a1
 
c2c42ca
4ac12a1
 
c2c42ca
4ac12a1
c2c42ca
143f063
 
c2c42ca
 
 
 
 
 
 
 
 
 
34cb1b5
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
import gradio as gr
from diffusers import StableDiffusionXLPipeline, UNet2DConditionModel, EulerDiscreteScheduler, LCMScheduler, AutoencoderKL
import torch
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
import gc
import spaces


# One fp16-fix VAE instance is loaded here and handed to all three pipelines below.
# No pipeline is moved to CUDA at load time; everything stays on the CPU here.
vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16)

### SDXL Turbo ###
pipe_turbo = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/sdxl-turbo",
    vae=vae,
    torch_dtype=torch.float16,
    variant="fp16",
)

# Lightning and Hyper-SD each provide a one-step UNet checkpoint that is loaded
# into an SDXL-base UNet skeleton and wrapped in an SDXL-base pipeline.
base = "stabilityai/stable-diffusion-xl-base-1.0"


def _distilled_unet(weights_repo, weights_file):
    """Return an fp16 SDXL-base UNet with weights from *weights_file* in *weights_repo*.

    The architecture comes from the ``unet`` config of ``base``; the state dict is
    the safetensors file downloaded from the Hub.
    """
    skeleton = UNet2DConditionModel.from_config(base, subfolder="unet").to(torch.float16)
    skeleton.load_state_dict(load_file(hf_hub_download(weights_repo, weights_file)))
    return skeleton


def _base_pipeline(distilled_unet):
    """Build an fp16 SDXL-base pipeline around *distilled_unet*.

    The VAE, both text encoders and both tokenizers are the objects already held
    by ``pipe_turbo``, so they are passed in instead of being loaded again.
    """
    shared_parts = dict(
        vae=vae,
        text_encoder=pipe_turbo.text_encoder,
        text_encoder_2=pipe_turbo.text_encoder_2,
        tokenizer=pipe_turbo.tokenizer,
        tokenizer_2=pipe_turbo.tokenizer_2,
    )
    return StableDiffusionXLPipeline.from_pretrained(
        base,
        unet=distilled_unet,
        torch_dtype=torch.float16,
        variant="fp16",
        **shared_parts,
    )


### SDXL Lightning ###
repo = "ByteDance/SDXL-Lightning"
ckpt = "sdxl_lightning_1step_unet_x0.safetensors"

pipe_lightning = _base_pipeline(_distilled_unet(repo, ckpt))
gc.collect()
# Euler scheduler configured for the `_x0` checkpoint: trailing timestep spacing
# and "sample" prediction type.
pipe_lightning.scheduler = EulerDiscreteScheduler.from_config(
    pipe_lightning.scheduler.config,
    timestep_spacing="trailing",
    prediction_type="sample",
)

### Hyper SDXL ###
repo_name = "ByteDance/Hyper-SD"
ckpt_name = "Hyper-SDXL-1step-Unet.safetensors"

pipe_hyper = _base_pipeline(_distilled_unet(repo_name, ckpt_name))
pipe_hyper.scheduler = LCMScheduler.from_config(pipe_hyper.scheduler.config)
gc.collect()

def _generate_on_gpu(pipe, prompt, **extra_kwargs):
    """Run *pipe* on CUDA for *prompt* and return the first generated image.

    Every pipeline here is sampled in a single step with guidance disabled;
    *extra_kwargs* carries per-pipeline extras (e.g. explicit ``timesteps``).
    The pipeline is moved back to the CPU afterwards, even if generation raises,
    so the next pipeline starts from a clean GPU.
    """
    pipe.to("cuda")
    try:
        result = pipe(prompt=prompt, num_inference_steps=1, guidance_scale=0, **extra_kwargs)
        return result.images[0]
    finally:
        pipe.to("cpu")


@spaces.GPU
def run_comparison(prompt):
    """Generate one image per distilled SDXL variant for the same *prompt*.

    The pipelines (not the returned PIL images) are moved to the GPU one at a
    time and run sequentially.

    Returns:
        ``(image_turbo, image_lightning, image_hyper)``, the first image
        produced by each pipeline, in the order of the UI's output components.
    """
    image_turbo = _generate_on_gpu(pipe_turbo, prompt)
    image_lightning = _generate_on_gpu(pipe_lightning, prompt)
    # Hyper-SD's one-step LCM sampling uses a single explicit timestep of 800.
    image_hyper = _generate_on_gpu(pipe_hyper, prompt, timesteps=[800])
    return image_turbo, image_lightning, image_hyper


# Caps the app width so the three result images sit side by side in a compact row.
css = '''
.gradio-container{max-width: 768px !important}
'''

with gr.Blocks(css=css) as demo:
    prompt = gr.Textbox(label="Prompt")
    run = gr.Button("Run")
    with gr.Row():
        image_turbo = gr.Image(label="SDXL Turbo")
        image_lightning = gr.Image(label="SDXL Lightning")
        # `label=` must be explicit: the first positional parameter of gr.Image
        # is `value`, so a bare string would be treated as the image itself.
        image_hyper = gr.Image(label="Hyper SDXL")

    # Output order must match the tuple returned by run_comparison.
    run.click(fn=run_comparison, inputs=prompt, outputs=[image_turbo, image_lightning, image_hyper])

if __name__ == "__main__":
    # Without launch() the script builds the UI and exits; guarded so that
    # importing this module does not start a server.
    demo.launch()