multimodalart HF staff commited on
Commit
9d41bd5
·
verified ·
1 Parent(s): 3bd17ee

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +28 -4
app.py CHANGED
@@ -1,12 +1,18 @@
1
  import gradio as gr
2
- from diffusers import StableDiffusionXLPipeline, UNet2DConditionModel, EulerDiscreteScheduler, LCMScheduler
3
  import torch
4
  from huggingface_hub import hf_hub_download
5
  from safetensors.torch import load_file
6
  import spaces
7
 
 
 
8
  ### SDXL Turbo ####
9
- pipe_turbo = StableDiffusionXLPipeline.from_pretrained("stabilityai/sdxl-turbo", torch_dtype=torch.float16, variant="fp16")
 
 
 
 
10
  #pipe_turbo.to("cuda")
11
 
12
  ### SDXL Lightning ###
@@ -16,7 +22,16 @@ ckpt = "sdxl_lightning_1step_unet_x0.safetensors"
16
 
17
  unet = UNet2DConditionModel.from_config(base, subfolder="unet").to(torch.float16)
18
  unet.load_state_dict(load_file(hf_hub_download(repo, ckpt)))
19
- pipe_lightning = StableDiffusionXLPipeline.from_pretrained(base, unet=unet, torch_dtype=torch.float16, variant="fp16")#.to("cuda")
 
 
 
 
 
 
 
 
 
20
  del unet
21
  pipe_lightning.scheduler = EulerDiscreteScheduler.from_config(pipe_lightning.scheduler.config, timestep_spacing="trailing", prediction_type="sample")
22
  #pipe_lightning.to("cuda")
@@ -27,7 +42,16 @@ ckpt_name = "Hyper-SDXL-1step-Unet.safetensors"
27
 
28
  unet = UNet2DConditionModel.from_config(base, subfolder="unet").to(torch.float16)
29
  unet.load_state_dict(load_file(hf_hub_download(repo_name, ckpt_name)))
30
- pipe_hyper = StableDiffusionXLPipeline.from_pretrained(base, unet=unet, torch_dtype=torch.float16, variant="fp16")#.to("cuda")
 
 
 
 
 
 
 
 
 
31
  pipe_hyper.scheduler = LCMScheduler.from_config(pipe_hyper.scheduler.config)
32
  #pipe_hyper.to("cuda")
33
  del unet
 
1
  import gradio as gr
2
+ from diffusers import StableDiffusionXLPipeline, UNet2DConditionModel, EulerDiscreteScheduler, LCMScheduler, AutoencoderKL
3
  import torch
4
  from huggingface_hub import hf_hub_download
5
  from safetensors.torch import load_file
6
  import spaces
7
 
8
+ vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16)
9
+
10
  ### SDXL Turbo ####
11
+ pipe_turbo = StableDiffusionXLPipeline.from_pretrained("stabilityai/sdxl-turbo",
12
+ vae=vae,
13
+ torch_dtype=torch.float16,
14
+ variant="fp16"
15
+ )
16
  #pipe_turbo.to("cuda")
17
 
18
  ### SDXL Lightning ###
 
22
 
23
  unet = UNet2DConditionModel.from_config(base, subfolder="unet").to(torch.float16)
24
  unet.load_state_dict(load_file(hf_hub_download(repo, ckpt)))
25
+ pipe_lightning = StableDiffusionXLPipeline.from_pretrained(base,
26
+ unet=unet,
27
+ vae=vae,
28
+ text_encoder_1=pipe_turbo.text_encoder_1,
29
+ text_encoder_2=pipe_turbo.text_encoder_2,
30
+ tokenizer=pipe_turbo.tokenizer,
31
+ tokenizer_2=pipe_turbo.tokenizer_2,
32
+ torch_dtype=torch.float16,
33
+ variant="fp16"
34
+ )#.to("cuda")
35
  del unet
36
  pipe_lightning.scheduler = EulerDiscreteScheduler.from_config(pipe_lightning.scheduler.config, timestep_spacing="trailing", prediction_type="sample")
37
  #pipe_lightning.to("cuda")
 
42
 
43
  unet = UNet2DConditionModel.from_config(base, subfolder="unet").to(torch.float16)
44
  unet.load_state_dict(load_file(hf_hub_download(repo_name, ckpt_name)))
45
+ pipe_hyper = StableDiffusionXLPipeline.from_pretrained(base,
46
+ unet=unet,
47
+ vae=vae,
48
+ text_encoder_1=pipe_turbo.text_encoder_1,
49
+ text_encoder_2=pipe_turbo.text_encoder_2,
50
+ tokenizer=pipe_turbo.tokenizer,
51
+ tokenizer_2=pipe_turbo.tokenizer_2,
52
+ torch_dtype=torch.float16,
53
+ variant="fp16"
54
+ )#.to("cuda")
55
  pipe_hyper.scheduler = LCMScheduler.from_config(pipe_hyper.scheduler.config)
56
  #pipe_hyper.to("cuda")
57
  del unet