# sdxl-thumbs-up / app.py
# Author: uttamg07 — commit 251b17c ("Update app.py"), 1.04 kB
import gc
import os
import shutil
import tempfile

import streamlit as st
import torch
from diffusers import AutoencoderKL, DiffusionPipeline
from PIL import Image
# Hub IDs / local paths of the models combined by this app.
BASE_MODEL_ID = "stabilityai/stable-diffusion-xl-base-1.0"
VAE_MODEL_ID = "madebyollin/sdxl-vae-fp16-fix"
# Fused into the base weights.
FUSED_LORA_PATH = "./lora-trained-thumbs-up-pw-0.5-steps-800"
# Loaded (unfused) on top of the fused model.
EXTRA_LORA_PATH = "./lora-trained-thumbs-up-pw-0.5-steps-800-uttam-pw-0.5-steps-1600"


@st.cache_resource
def _load_pipeline():
    """Build the SDXL pipeline with both LoRA adapters applied, once per process.

    Streamlit re-executes this script on every widget interaction; caching
    the pipeline avoids re-loading, re-fusing and re-saving the model on
    each rerun.

    The first LoRA is fused into the base weights. The fused pipeline is
    written to a private temporary directory and reloaded, and the second
    LoRA is loaded on top of it. The temporary directory is always removed,
    even if loading fails.

    Returns:
        The assembled pipeline, moved to CUDA.
    """
    vae = AutoencoderKL.from_pretrained(VAE_MODEL_ID, torch_dtype=torch.float16)
    pipe = DiffusionPipeline.from_pretrained(
        BASE_MODEL_ID,
        vae=vae,
        torch_dtype=torch.float16,
        variant="fp16",
        use_safetensors=True,
    )
    pipe.load_lora_weights(FUSED_LORA_PATH)
    pipe.fuse_lora(lora_scale=1.0)

    # Unique directory per load rather than a fixed "temp_model" path that
    # concurrent sessions could clobber.
    tmp_dir = tempfile.mkdtemp(prefix="fused_sdxl_")
    try:
        # Save with the same variant suffix the reload below requests, so the
        # weight file names match.
        pipe.save_pretrained(tmp_dir, variant="fp16")
        del pipe
        gc.collect()
        torch.cuda.empty_cache()

        pipe = DiffusionPipeline.from_pretrained(
            tmp_dir,
            vae=vae,
            torch_dtype=torch.float16,
            variant="fp16",
            use_safetensors=True,
        )
        pipe.load_lora_weights(EXTRA_LORA_PATH)
        pipe.to("cuda")
    finally:
        shutil.rmtree(tmp_dir, ignore_errors=True)
    return pipe


pipe = _load_pipeline()

prompt = st.text_area("Enter the prompt!")
# Streamlit also runs the script on first page load with an empty prompt;
# only invoke the expensive pipeline once the user has entered text.
if prompt.strip():
    result = pipe(prompt=prompt, num_inference_steps=25, num_images_per_prompt=1)
    # The pipeline returns an output object; `.images` is the list of
    # generated images (one here, per num_images_per_prompt=1).
    image = result.images[0]
    st.image(image)