# Hugging Face Space status: Runtime error
import gc

import streamlit as st
import torch
from diffusers import StableDiffusionPipeline
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
def void(*args, **kwargs):
    """Accept any positional/keyword arguments, ignore them, and return None.

    Handy for wrapping a side-effecting expression inside a lambda so the
    lambda's return value is always ``None``.
    """
    del args, kwargs  # intentionally unused
    return None
st.title("AI 元火娘")

# Models offered in the sidebar. Picking the sentinel entry lets the user type
# their own model path instead.
_CUSTOM_MODEL = "<Custom>"
_MODEL_CHOICES = [
    "wybxc/yanhuo-v1-dreambooth",
    "wybxc/yanyuan-v1-dreambooth",
    "wybxc/yuanhuo-v1-dreambooth",
    _CUSTOM_MODEL,
]

with st.sidebar:
    selected_model = st.selectbox("Model Name", _MODEL_CHOICES)
    # For the custom entry the typed path is whitespace-stripped; it may be an
    # empty string if the user has not entered anything yet.
    model = (
        st.text_input("Model Path", "").strip()
        if selected_model == _CUSTOM_MODEL
        else selected_model
    )
# Caching model
# Streamlit re-executes this whole script on every widget interaction, so the
# loaded pipeline is kept in st.session_state and reused across reruns.
# NOTE(review): session_state is per browser session, so each concurrent
# visitor loads a separate copy of the weights; a process-wide cache
# (st.cache_resource in recent Streamlit releases) would share one copy —
# confirm the installed Streamlit version before switching.
if "model" not in st.session_state:
    st.session_state.model = model
if "pipeline" not in st.session_state:
    st.session_state.pipeline = None

# Stays None when no model is selected or when loading fails.
pipeline = None
if model != st.session_state.model or st.session_state.pipeline is None:
    if model:
        if st.session_state.pipeline is not None:
            # Free the previous model *before* loading the new one so two sets
            # of weights are never resident at the same time.
            st.session_state.pipeline = None
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
        with st.spinner("Loading Model..."):
            try:
                loaded = StableDiffusionPipeline.from_pretrained(model)
            except (OSError, ValueError) as err:
                # e.g. a mistyped custom path or an incompatible checkpoint:
                # report it in the UI instead of crashing the whole app.
                st.error(f"Failed to load model {model!r}: {err}")
            else:
                if not isinstance(loaded, StableDiffusionPipeline):
                    raise TypeError(
                        "Expected StableDiffusionPipeline, got "
                        f"{type(loaded).__name__}"
                    )
                if torch.cuda.is_available():
                    loaded = loaded.to("cuda")
                st.session_state.model = model
                st.session_state.pipeline = loaded
                pipeline = loaded
else:
    # Same model as last run and already loaded: reuse it.
    pipeline = st.session_state.pipeline
    if not isinstance(pipeline, StableDiffusionPipeline):
        raise TypeError(
            f"Expected StableDiffusionPipeline, got {type(pipeline).__name__}"
        )
# Default contents pre-filled into the positive and negative prompt boxes.
_DEFAULT_PROMPT = (
    "(yanhuo), 1girl, masterpiece, best quality, "
    "white hair, ahoge, snowy street, [smile], dynamic angle, full body, "
    "[blue eyes], flat chest, cinematic light"
)
_DEFAULT_NEGATIVE_PROMPT = (
    "lowres, bad anatomy, bad hands, "
    "text, error, missing fingers, extra digit, fewer digits, cropped, "
    "worst quality, low quality, normal quality, jpeg artifacts, signature, "
    "watermark, username, blurry"
)

prompt = st.text_area("Prompt", value=_DEFAULT_PROMPT)
negative_prompt = st.text_area("Negative Prompt", value=_DEFAULT_NEGATIVE_PROMPT)
# Image-size sliders share one spec: 256..1024 px in steps of 64, default 512.
_SIZE_SLIDER = dict(min_value=256, max_value=1024, value=512, step=64)

with st.sidebar:
    height = st.slider("Height", **_SIZE_SLIDER)
    width = st.slider("Width", **_SIZE_SLIDER)
    steps = st.slider("Steps", min_value=1, max_value=100, value=20, step=1)
# Run the pipeline only when a model is loaded and the user presses Generate.
if pipeline is not None and st.button("Generate"):
    progress = st.progress(0)

    def _on_step(step, *_):
        """Advance the progress bar after each denoising step.

        diffusers calls this with the 0-based index of the step that just
        finished (plus the timestep and latents, ignored here), so ``step + 1``
        steps are complete at this point. The value is clamped because
        ``st.progress`` rejects anything above 1.0.
        """
        progress.progress(min((step + 1) / steps, 1.0))

    # NOTE(review): ``callback`` is deprecated in newer diffusers releases in
    # favour of ``callback_on_step_end``; kept for compatibility with the
    # installed version — confirm before upgrading.
    result = pipeline(
        prompt=prompt,
        negative_prompt=negative_prompt,
        height=height,
        width=width,
        num_inference_steps=steps,
        callback=_on_step,
    )
    if not isinstance(result, StableDiffusionPipelineOutput):
        raise TypeError(
            f"Unexpected pipeline output type: {type(result).__name__}"
        )
    image = result.images[0]
    progress.progress(1.0)
    # nsfw_content_detected is None when the pipeline has no safety checker,
    # otherwise one flag per generated image.
    nsfw_flags = result.nsfw_content_detected
    if nsfw_flags and nsfw_flags[0]:
        st.warning(
            "The safety checker flagged this image; it may have been "
            "replaced with a blank image."
        )
    st.image(image)