Spaces:
Running
on
Zero
Running
on
Zero
File size: 4,247 Bytes
c2c42ca 9d41bd5 c2c42ca de56cd9 c2c42ca de56cd9 9d41bd5 c2c42ca 9d41bd5 4ac12a1 c2c42ca 4ac12a1 3bd17ee 9d41bd5 428d2aa 9d41bd5 e10dc6d de56cd9 c2c42ca 4ac12a1 c2c42ca 4ac12a1 3bd17ee 9d41bd5 428d2aa 9d41bd5 c2c42ca 4ac12a1 e10dc6d de56cd9 143f063 34cb1b5 4ac12a1 c2c42ca 4ac12a1 c2c42ca 4ac12a1 c2c42ca 4ac12a1 c2c42ca 143f063 c2c42ca 34cb1b5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 |
import gradio as gr
from diffusers import StableDiffusionXLPipeline, UNet2DConditionModel, EulerDiscreteScheduler, LCMScheduler, AutoencoderKL
import torch
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
import gc
import spaces
# ----- Shared VAE -----
# fp16 VAE checkpoint, loaded once and handed to every pipeline in this file.
vae = AutoencoderKL.from_pretrained(
    "madebyollin/sdxl-vae-fp16-fix",
    torch_dtype=torch.float16,
)

# ----- SDXL Turbo -----
pipe_turbo = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/sdxl-turbo",
    variant="fp16",
    torch_dtype=torch.float16,
    vae=vae,
)
# NOTE: the pipeline is not moved to CUDA at import time.
# ----- SDXL Lightning -----
base = "stabilityai/stable-diffusion-xl-base-1.0"
repo = "ByteDance/SDXL-Lightning"
ckpt = "sdxl_lightning_1step_unet_x0.safetensors"

# Instantiate a UNet from the base model's "unet" config, cast it to float16,
# then load the 1-step Lightning checkpoint into it.
unet = UNet2DConditionModel.from_config(base, subfolder="unet")
unet = unet.to(torch.float16)
unet.load_state_dict(load_file(hf_hub_download(repo_id=repo, filename=ckpt)))

# The text encoders and tokenizers are taken from the SDXL Turbo pipeline and
# the VAE from the module-level `vae`; only the UNet is specific to Lightning.
pipe_lightning = StableDiffusionXLPipeline.from_pretrained(
    base,
    variant="fp16",
    torch_dtype=torch.float16,
    unet=unet,
    vae=vae,
    text_encoder=pipe_turbo.text_encoder,
    text_encoder_2=pipe_turbo.text_encoder_2,
    tokenizer=pipe_turbo.tokenizer,
    tokenizer_2=pipe_turbo.tokenizer_2,
)

# Drop the module-level name for the UNet (the pipeline still references it)
# and run a garbage-collection pass.
del unet
gc.collect()

# Replace the default scheduler with an Euler scheduler using "trailing"
# timestep spacing and "sample" prediction type.
pipe_lightning.scheduler = EulerDiscreteScheduler.from_config(
    pipe_lightning.scheduler.config,
    timestep_spacing="trailing",
    prediction_type="sample",
)
# NOTE: the pipeline is not moved to CUDA at import time.
# ----- Hyper SDXL -----
repo_name = "ByteDance/Hyper-SD"
ckpt_name = "Hyper-SDXL-1step-Unet.safetensors"

# Instantiate a UNet from the base model's "unet" config (`base` is the SDXL
# base repo id defined in the Lightning section), cast it to float16, then
# load the Hyper-SDXL 1-step checkpoint into it.
unet = UNet2DConditionModel.from_config(base, subfolder="unet")
unet = unet.to(torch.float16)
unet.load_state_dict(
    load_file(hf_hub_download(repo_id=repo_name, filename=ckpt_name))
)

# The text encoders and tokenizers are taken from the SDXL Turbo pipeline and
# the VAE from the module-level `vae`; only the UNet is specific to Hyper SDXL.
pipe_hyper = StableDiffusionXLPipeline.from_pretrained(
    base,
    variant="fp16",
    torch_dtype=torch.float16,
    unet=unet,
    vae=vae,
    text_encoder=pipe_turbo.text_encoder,
    text_encoder_2=pipe_turbo.text_encoder_2,
    tokenizer=pipe_turbo.tokenizer,
    tokenizer_2=pipe_turbo.tokenizer_2,
)

# Swap the default scheduler for an LCM scheduler built from the same config.
pipe_hyper.scheduler = LCMScheduler.from_config(pipe_hyper.scheduler.config)
# NOTE: the pipeline is not moved to CUDA at import time.

# Drop the module-level name for the UNet (the pipeline still references it)
# and run a garbage-collection pass.
del unet
gc.collect()
@spaces.GPU
def run_comparison(prompt):
    """Generate one single-step image per model for the same prompt.

    Each pipeline is moved to CUDA immediately before its inference call and
    back to the CPU immediately after it, so the three pipelines take turns
    on the GPU (presumably to limit GPU memory use; they share the VAE and
    text encoders).

    Args:
        prompt: Text prompt passed unchanged to all three pipelines.

    Returns:
        A tuple ``(turbo, lightning, hyper)`` holding the first generated
        image of SDXL Turbo, SDXL Lightning and Hyper SDXL, in that order.
    """
    # The device moves must target the *pipelines*. The previous code called
    # `.to(...)` on the result variables (`image_turbo` etc.), which are
    # locals of this function and therefore raised UnboundLocalError before
    # any image was produced. Result locals are named `*_image` here so they
    # do not shadow the module-level Gradio components `image_*`.
    pipe_turbo.to("cuda")
    turbo_image = pipe_turbo(
        prompt=prompt, num_inference_steps=1, guidance_scale=0
    ).images[0]
    pipe_turbo.to("cpu")

    pipe_lightning.to("cuda")
    lightning_image = pipe_lightning(
        prompt=prompt, num_inference_steps=1, guidance_scale=0
    ).images[0]
    pipe_lightning.to("cpu")

    pipe_hyper.to("cuda")
    # This pipeline is additionally given an explicit single timestep (800).
    hyper_image = pipe_hyper(
        prompt=prompt, num_inference_steps=1, guidance_scale=0, timesteps=[800]
    ).images[0]
    pipe_hyper.to("cpu")

    return turbo_image, lightning_image, hyper_image
# Custom CSS: cap the width of the Gradio container at 768px.
css = '''
.gradio-container{max-width: 768px !important}
'''

# UI: a prompt box and a Run button, followed by one row with the three
# model outputs side by side.
with gr.Blocks(css=css) as demo:
    prompt = gr.Textbox(label="Prompt")
    run = gr.Button("Run")
    with gr.Row():
        image_turbo = gr.Image(label="SDXL Turbo")
        image_lightning = gr.Image(label="SDXL Lightning")
        # `label=` must be passed by keyword: the first positional argument
        # of gr.Image is the initial image value, not the caption.
        image_hyper = gr.Image(label="Hyper SDXL")
    # Clicking Run sends the prompt to run_comparison and routes its three
    # returned images to the three output components, in order.
    run.click(
        fn=run_comparison,
        inputs=prompt,
        outputs=[image_turbo, image_lightning, image_hyper],
    )
|