Spaces:
Running
on
Zero
Running
on
Zero
File size: 4,218 Bytes
c2c42ca 9d41bd5 c2c42ca 9d41bd5 c2c42ca 9d41bd5 4ac12a1 c2c42ca 4ac12a1 3bd17ee 9d41bd5 e10dc6d c2c42ca 4ac12a1 c2c42ca 4ac12a1 3bd17ee 9d41bd5 c2c42ca 4ac12a1 e10dc6d 143f063 34cb1b5 4ac12a1 c2c42ca 4ac12a1 c2c42ca 4ac12a1 c2c42ca 4ac12a1 c2c42ca 143f063 c2c42ca 34cb1b5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 |
import gradio as gr
from diffusers import StableDiffusionXLPipeline, UNet2DConditionModel, EulerDiscreteScheduler, LCMScheduler, AutoencoderKL
import torch
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
import spaces
# VAE loaded once in float16 from the "sdxl-vae-fp16-fix" repo; this same
# object is passed as `vae=` to each pipeline constructed in this file.
vae = AutoencoderKL.from_pretrained(
    "madebyollin/sdxl-vae-fp16-fix",
    torch_dtype=torch.float16,
)

# --- SDXL Turbo ---
# Half-precision weights (fp16 variant) with the shared VAE swapped in.
# No `.to(...)` call is made here, so the pipeline is not moved to a GPU at
# import time.
pipe_turbo = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/sdxl-turbo",
    variant="fp16",
    torch_dtype=torch.float16,
    vae=vae,
)
### SDXL Lightning ###
# NOTE: `base` is also read by the Hyper SDXL section later in this file.
base = "stabilityai/stable-diffusion-xl-base-1.0"
repo = "ByteDance/SDXL-Lightning"
ckpt = "sdxl_lightning_1step_unet_x0.safetensors"

# Instantiate a UNet from the base model's config, cast it to float16, then
# load the 1-step Lightning checkpoint downloaded from the Hub into it.
unet = UNet2DConditionModel.from_config(base, subfolder="unet").to(torch.float16)
unet.load_state_dict(load_file(hf_hub_download(repo, ckpt)))

pipe_lightning = StableDiffusionXLPipeline.from_pretrained(
    base,
    unet=unet,
    vae=vae,
    # The SDXL pipeline's first text encoder component is named
    # `text_encoder` (not `text_encoder_1`); reuse the Turbo pipeline's
    # encoders and tokenizers instead of loading a second copy.
    text_encoder=pipe_turbo.text_encoder,
    text_encoder_2=pipe_turbo.text_encoder_2,
    tokenizer=pipe_turbo.tokenizer,
    tokenizer_2=pipe_turbo.tokenizer_2,
    torch_dtype=torch.float16,
    variant="fp16",
)
# Release the module-level name; the pipeline holds its own reference.
del unet

# Rebuild the scheduler from the pipeline's existing scheduler config with
# trailing timestep spacing and "sample" prediction, matching the 1-step
# "x0" checkpoint selected above.
pipe_lightning.scheduler = EulerDiscreteScheduler.from_config(
    pipe_lightning.scheduler.config,
    timestep_spacing="trailing",
    prediction_type="sample",
)
### Hyper SDXL ###
repo_name = "ByteDance/Hyper-SD"
ckpt_name = "Hyper-SDXL-1step-Unet.safetensors"

# Instantiate a UNet from the base model's config (`base` is defined in the
# SDXL Lightning section above), cast it to float16, then load the 1-step
# Hyper-SD checkpoint downloaded from the Hub into it.
unet = UNet2DConditionModel.from_config(base, subfolder="unet").to(torch.float16)
unet.load_state_dict(load_file(hf_hub_download(repo_name, ckpt_name)))

pipe_hyper = StableDiffusionXLPipeline.from_pretrained(
    base,
    unet=unet,
    vae=vae,
    # The SDXL pipeline's first text encoder component is named
    # `text_encoder` (not `text_encoder_1`); reuse the Turbo pipeline's
    # encoders and tokenizers instead of loading a second copy.
    text_encoder=pipe_turbo.text_encoder,
    text_encoder_2=pipe_turbo.text_encoder_2,
    tokenizer=pipe_turbo.tokenizer,
    tokenizer_2=pipe_turbo.tokenizer_2,
    torch_dtype=torch.float16,
    variant="fp16",
)

# Replace the default scheduler with an LCM scheduler built from the same
# scheduler config.
pipe_hyper.scheduler = LCMScheduler.from_config(pipe_hyper.scheduler.config)

# Release the module-level name; the pipeline holds its own reference.
del unet
def _generate_on_gpu(pipe, prompt, **extra_call_kwargs):
    """Generate one image with *pipe* using a single, guidance-free step.

    The pipeline is moved to CUDA just for this call and is moved back to the
    CPU afterwards, even if generation raises, so that it does not stay
    resident on the GPU between calls.

    Args:
        pipe: A diffusers pipeline object exposing ``.to(device)`` and being
            callable with ``prompt=...``.
        prompt: Text prompt forwarded to the pipeline.
        **extra_call_kwargs: Additional keyword arguments forwarded to the
            pipeline call (e.g. ``timesteps``).

    Returns:
        The first entry of the pipeline output's ``images``.
    """
    pipe.to("cuda")
    try:
        result = pipe(
            prompt=prompt,
            num_inference_steps=1,
            guidance_scale=0,
            **extra_call_kwargs,
        )
        return result.images[0]
    finally:
        pipe.to("cpu")


@spaces.GPU
def run_comparison(prompt):
    """Render *prompt* with SDXL Turbo, SDXL Lightning and Hyper SDXL.

    Each pipeline is run for one inference step with guidance disabled, one
    after another; the device moves are applied to the pipelines themselves
    (not to the generated images).

    Args:
        prompt: Text prompt shared by all three pipelines.

    Returns:
        A tuple ``(image_turbo, image_lightning, image_hyper)`` in the order
        expected by the UI outputs.
    """
    image_turbo = _generate_on_gpu(pipe_turbo, prompt)
    image_lightning = _generate_on_gpu(pipe_lightning, prompt)
    # The Hyper SDXL call additionally pins the single step to timestep 800.
    image_hyper = _generate_on_gpu(pipe_hyper, prompt, timesteps=[800])
    return image_turbo, image_lightning, image_hyper
# Custom CSS passed to gr.Blocks: caps the Gradio container width at 768px.
css = '''
.gradio-container{max-width: 768px !important}
'''

# UI: a prompt box and a Run button, followed by one output image per model.
with gr.Blocks(css=css) as demo:
    prompt = gr.Textbox(label="Prompt")
    run = gr.Button("Run")
    with gr.Row():
        image_turbo = gr.Image(label="SDXL Turbo")
        image_lightning = gr.Image(label="SDXL Lightning")
        # `label=` must be a keyword: the first positional argument of
        # gr.Image is the initial value, not the caption.
        image_hyper = gr.Image(label="Hyper SDXL")
    # Output order matches the tuple returned by run_comparison.
    run.click(
        fn=run_comparison,
        inputs=prompt,
        outputs=[image_turbo, image_lightning, image_hyper],
    )
# NOTE(review): no demo.launch() call is visible in this view of the file —
# confirm the app is launched elsewhere.
|