Spaces:
Running
on
Zero
Running
on
Zero
File size: 3,964 Bytes
c2c42ca d827a95 81435cb 61bc6a3 9d41bd5 a1f66f7 c2c42ca aa5a24b 91dd651 c2c42ca 61bc6a3 c2c42ca aa5a24b 61bc6a3 c2c42ca 61bc6a3 143f063 aa5a24b 61bc6a3 aa5a24b 91dd651 825bfd6 61bc6a3 143f063 61bc6a3 c2c42ca 61bc6a3 34cb1b5 5e64d98 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 |
import gradio as gr
from diffusers import StableDiffusionXLPipeline, UNet2DConditionModel, EulerDiscreteScheduler, LCMScheduler, AutoencoderKL
import torch
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
# One VAE instance shared by all three pipelines below. The fp16-fix variant
# (per its repo name) is the SDXL VAE patched for float16 use.
vae = AutoencoderKL.from_pretrained(
    "madebyollin/sdxl-vae-fp16-fix",
    torch_dtype=torch.float16,
)

### SDXL Turbo ###
# Full pipeline from the sdxl-turbo repo, fp16 weights, with the shared VAE
# swapped in. DiffusionPipeline.to() returns the pipeline itself, so the
# device move can be chained onto the assignment.
pipe_turbo = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/sdxl-turbo",
    vae=vae,
    torch_dtype=torch.float16,
    variant="fp16",
).to("cuda")
### SDXL Lightning & Hyper SDXL ###
# Both distilled models ship only a replacement UNet checkpoint; every other
# component comes from the SDXL base repo or is shared with pipe_turbo.
base = "stabilityai/stable-diffusion-xl-base-1.0"


def _build_one_step_pipeline(unet_repo, unet_file):
    """Return an fp16 SDXL pipeline whose UNet weights come from a Hub checkpoint.

    Args:
        unet_repo: Hugging Face Hub repo id that holds the UNet checkpoint.
        unet_file: ``.safetensors`` filename inside ``unet_repo``.

    The UNet architecture is built from the SDXL base ``unet`` config and its
    weights are then replaced with the downloaded state dict. The VAE, both
    text encoders and both tokenizers are the objects already loaded for
    ``vae`` / ``pipe_turbo``, so they are not loaded a second time. The
    pipeline is not moved to a device here; callers set the scheduler and
    call ``.to("cuda")`` themselves.
    """
    # load_config + from_config(dict) is the non-deprecated way to build a
    # model from a repo's config; only config.json is fetched, not base weights.
    unet_config = UNet2DConditionModel.load_config(base, subfolder="unet")
    unet = UNet2DConditionModel.from_config(unet_config).to(torch.float16)
    unet.load_state_dict(load_file(hf_hub_download(unet_repo, unet_file)))
    return StableDiffusionXLPipeline.from_pretrained(
        base,
        unet=unet,
        vae=vae,
        text_encoder=pipe_turbo.text_encoder,
        text_encoder_2=pipe_turbo.text_encoder_2,
        tokenizer=pipe_turbo.tokenizer,
        tokenizer_2=pipe_turbo.tokenizer_2,
        torch_dtype=torch.float16,
        variant="fp16",
    )


### SDXL Lightning ###
repo = "ByteDance/SDXL-Lightning"
ckpt = "sdxl_lightning_1step_unet_x0.safetensors"
pipe_lightning = _build_one_step_pipeline(repo, ckpt)
# The "_x0" checkpoint predicts the denoised sample directly, hence
# prediction_type="sample"; "trailing" timestep spacing is what the 1-step
# Lightning checkpoint expects.
pipe_lightning.scheduler = EulerDiscreteScheduler.from_config(
    pipe_lightning.scheduler.config,
    timestep_spacing="trailing",
    prediction_type="sample",
)
pipe_lightning.to("cuda")

### Hyper SDXL ###
repo_name = "ByteDance/Hyper-SD"
ckpt_name = "Hyper-SDXL-1step-Unet.safetensors"
pipe_hyper = _build_one_step_pipeline(repo_name, ckpt_name)
pipe_hyper.scheduler = LCMScheduler.from_config(pipe_hyper.scheduler.config)
pipe_hyper.to("cuda")
def run_comparison(prompt):
    """Render ``prompt`` once with each model so the outputs can be compared.

    Every pipeline runs a single inference step with ``guidance_scale=0``
    (classifier-free guidance off). Hyper-SDXL additionally receives an
    explicit ``timesteps=[800]`` schedule.

    Returns:
        Tuple ``(turbo_image, lightning_image, hyper_image)``, each the first
        entry of the corresponding pipeline's ``.images``.
    """
    common = dict(prompt=prompt, num_inference_steps=1, guidance_scale=0)
    turbo_out = pipe_turbo(**common)
    lightning_out = pipe_lightning(**common)
    hyper_out = pipe_hyper(**common, timesteps=[800])
    return turbo_out.images[0], lightning_out.images[0], hyper_out.images[0]
# Minimal UI: a prompt box, a Run button, and the three models' outputs in one row.
with gr.Blocks() as demo:
    prompt = gr.Textbox(label="Prompt")
    run = gr.Button("Run")
    with gr.Row():
        image_turbo = gr.Image(label="SDXL Turbo")
        image_lightning = gr.Image(label="SDXL Lightning")
        image_hyper = gr.Image(label="Hyper SDXL")
    # Event listeners must be registered inside the Blocks context. The output
    # order matches the (turbo, lightning, hyper) tuple run_comparison returns.
    run.click(
        fn=run_comparison,
        inputs=prompt,
        outputs=[image_turbo, image_lightning, image_hyper],
    )

demo.launch()