import gradio as gr
from diffusers import StableDiffusionXLPipeline, UNet2DConditionModel, EulerDiscreteScheduler, LCMScheduler, AutoencoderKL
import torch
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
import spaces

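# fp16-safe SDXL VAE, shared by all three pipelines below.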
vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16)

### SDXL Turbo ###
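# SDXL-Turbo is an adversarial-diffusion-distilled SDXL that generates in a single step with guidance disabled.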
pipe_turbo = StableDiffusionXLPipeline.from_pretrained("stabilityai/sdxl-turbo",
                                                       vae=vae,
                                                       torch_dtype=torch.float16,
                                                       variant="fp16"
                                                      )
#pipe_turbo.to("cuda")

### SDXL Lightning ### 
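# Swap a 1-step Lightning UNet into the SDXL base pipeline, sharing the Turbo text encoders and tokenizers to save memory.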
base = "stabilityai/stable-diffusion-xl-base-1.0"
repo = "ByteDance/SDXL-Lightning"
ckpt = "sdxl_lightning_1step_unet_x0.safetensors" 

unet = UNet2DConditionModel.from_config(base, subfolder="unet").to(torch.float16)
unet.load_state_dict(load_file(hf_hub_download(repo, ckpt)))
pipe_lightning = StableDiffusionXLPipeline.from_pretrained(base,
                                                           unet=unet,
                                                           vae=vae,
                                                           text_encoder=pipe_turbo.text_encoder,
                                                           text_encoder_2=pipe_turbo.text_encoder_2,
                                                           tokenizer=pipe_turbo.tokenizer,
                                                           tokenizer_2=pipe_turbo.tokenizer_2,
                                                           torch_dtype=torch.float16,
                                                           variant="fp16"
                                                          )#.to("cuda")
del unet
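# The 1-step Lightning checkpoint predicts x0, so the scheduler needs "sample" prediction and trailing timestep spacing.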
pipe_lightning.scheduler = EulerDiscreteScheduler.from_config(pipe_lightning.scheduler.config, timestep_spacing="trailing", prediction_type="sample")
#pipe_lightning.to("cuda")

### Hyper SDXL ### 
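# Same pattern for Hyper-SD: load the 1-step UNet into the base pipeline and sample with an LCM scheduler at a fixed timestep.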
repo_name = "ByteDance/Hyper-SD"
ckpt_name = "Hyper-SDXL-1step-Unet.safetensors"

unet = UNet2DConditionModel.from_config(base, subfolder="unet").to(torch.float16)
unet.load_state_dict(load_file(hf_hub_download(repo_name, ckpt_name)))
pipe_hyper = StableDiffusionXLPipeline.from_pretrained(base,
                                                       unet=unet,
                                                       vae=vae,
                                                       text_encoder=pipe_turbo.text_encoder,
                                                       text_encoder_2=pipe_turbo.text_encoder_2,
                                                       tokenizer=pipe_turbo.tokenizer,
                                                       tokenizer_2=pipe_turbo.tokenizer_2,
                                                       torch_dtype=torch.float16,
                                                       variant="fp16"
                                                      )#.to("cuda")
pipe_hyper.scheduler = LCMScheduler.from_config(pipe_hyper.scheduler.config)
#pipe_hyper.to("cuda")
del unet

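# On ZeroGPU Spaces, @spaces.GPU attaches a GPU only for the duration of this call; the pipelines stay on CPU at import time.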
@spaces.GPU
def run_comparison(prompt):
    # Move each pipeline to the GPU only while it generates, then back to CPU to free VRAM.
    pipe_turbo.to("cuda")
    image_turbo = pipe_turbo(prompt=prompt, num_inference_steps=1, guidance_scale=0).images[0]
    pipe_turbo.to("cpu")

    pipe_lightning.to("cuda")
    image_lightning = pipe_lightning(prompt=prompt, num_inference_steps=1, guidance_scale=0).images[0]
    pipe_lightning.to("cpu")

    pipe_hyper.to("cuda")
    image_hyper = pipe_hyper(prompt=prompt, num_inference_steps=1, guidance_scale=0, timesteps=[800]).images[0]
    pipe_hyper.to("cpu")
    return image_turbo, image_lightning, image_hyper


css = '''
.gradio-container{max-width: 768px !important}
'''
with gr.Blocks(css=css) as demo:
    prompt = gr.Textbox(label="Prompt")
    run = gr.Button("Run")
    with gr.Row():
        image_turbo = gr.Image(label="SDXL Turbo")
        image_lightning = gr.Image(label="SDXL Lightning")
        image_hyper = gr.Image(label="Hyper SDXL")
    
    run.click(fn=run_comparison, inputs=prompt, outputs=[image_turbo, image_lightning, image_hyper])
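
# Start the Gradio app; without launch() the Space has no server to render the UI.
demo.launch()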