File size: 2,433 Bytes
6711464
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
import streamlit as st
import torch
from diffusers import StableDiffusionPipeline
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput

def void(*args, **kwargs):
    """Accept any positional/keyword arguments, ignore them, return ``None``.

    Handy for discarding the result of an expression evaluated inside a
    ``lambda``.
    """
    del args, kwargs  # deliberately unused
    return None

st.title("AI 元火娘")

# DreamBooth checkpoints offered in the sidebar; choosing the "<Custom>"
# entry reveals a text box where any model path / id can be typed instead.
_model_choices = [
    "wybxc/yanhuo-v1-dreambooth",
    "wybxc/yanyuan-v1-dreambooth",
    "wybxc/yuanhuo-v1-dreambooth",
    "<Custom>",
]

with st.sidebar:
    model = st.selectbox("Model Name", _model_choices)
    # Surrounding whitespace is stripped; an empty box yields "".
    model = (
        st.text_input("Model Path", "").strip()
        if model == "<Custom>"
        else model
    )

# Caching model.
# Streamlit re-executes this script on every widget interaction, so the loaded
# pipeline and the model id it was built from are kept in st.session_state and
# only rebuilt when the selection changes or nothing has been loaded yet.
if 'model' not in st.session_state:
    st.session_state.model = model
if 'pipeline' not in st.session_state:
    st.session_state.pipeline = None


if model != st.session_state.model or st.session_state.pipeline is None:
    # Stays None when no model is selected or loading fails.
    pipeline = None
    if model:
        use_cuda = torch.cuda.is_available()
        # Half precision only on the GPU: many CPU kernels have no float16
        # implementation, so CPU inference must run in float32.
        dtype = torch.float16 if use_cuda else torch.float32
        # Drop the previously cached pipeline before loading the new one so
        # two models are not resident at the same time.
        st.session_state.pipeline = None
        if use_cuda:
            torch.cuda.empty_cache()
        with st.spinner("Loading Model..."):
            try:
                loaded = StableDiffusionPipeline.from_pretrained(
                    model,
                    torch_dtype=dtype,
                )
            except Exception as err:  # UI boundary: report instead of crashing
                st.error(f"Failed to load model {model!r}: {err}")
            else:
                assert isinstance(loaded, StableDiffusionPipeline)
                if use_cuda:
                    loaded = loaded.to("cuda")
                st.session_state.model = model
                st.session_state.pipeline = loaded
                pipeline = loaded
else:
    # Same model as last run and already loaded: reuse the cached pipeline.
    pipeline = st.session_state.pipeline
    assert isinstance(pipeline, StableDiffusionPipeline)


# Default prompt: a sample description of the "(yanhuo)" subject.
_default_prompt = (
    "(yanhuo), 1girl, masterpiece, best quality, "
    "white hair, ahoge, snowy street, [smile], dynamic angle, full body, "
    "[blue eyes], flat chest, cinematic light"
)
# Default negative prompt: artefacts and quality terms to steer away from.
_default_negative_prompt = (
    "lowres, bad anatomy, bad hands, "
    "text, error, missing fingers, extra digit, fewer digits, cropped, "
    "worst quality, low quality, normal quality, jpeg artifacts, signature, "
    "watermark, username, blurry"
)

prompt = st.text_area("Prompt", _default_prompt)
negative_prompt = st.text_area("Negative Prompt", _default_negative_prompt)

# Output size (multiples of 64 between 256 and 1024) and denoising steps.
with st.sidebar:
    height = st.slider("Height", min_value=256, max_value=1024, value=512, step=64)
    width = st.slider("Width", min_value=256, max_value=1024, value=512, step=64)
    steps = st.slider("Steps", min_value=1, max_value=100, value=20, step=1)

# Generate only when a pipeline is loaded and the user clicks the button.
if pipeline is not None and st.button("Generate"):
    progress = st.progress(0)

    def _on_step(step, *_):
        # Step callback; per the diffusers docs it is invoked as
        # callback(step, timestep, latents). Depending on the scheduler and
        # diffusers version the reported step can reach or exceed `steps`,
        # and st.progress rejects values above 1.0, so clamp the fraction.
        progress.progress(min(step / steps, 1.0))

    result = pipeline(
        prompt=prompt,
        negative_prompt=negative_prompt,
        height=height,
        width=width,
        num_inference_steps=steps,
        callback=_on_step,
    )
    assert isinstance(result, StableDiffusionPipelineOutput)
    image = result.images[0]

    # The callbacks may stop short of 100%; mark the run as complete.
    progress.progress(1.0)

    st.image(image)