# fashionsd / app.py
# Committed by Abhi5ingh: "deployed-v1" (commit 5b344d3), 4.62 kB
from typing import Optional
import numpy as np
import cv2
import streamlit as st
from PIL import Image
from sd.sdfile import PIPELINES, generate
# Prompt pre-filled in every prompt text area.
DEFAULT_PROMPT = "belted shirt black belted portrait-collar wrap blouse with black prints"
# Default output size in pixels (the width/height sliders currently hard-code
# the same value of 512).
DEFAULT_WIDTH, DEFAULT_HEIGHT = 512, 512
# Backward-compatible alias for the original misspelled constant name.
DEAFULT_WIDTH = DEFAULT_WIDTH
# Streamlit session-state key holding the most recently generated image.
OUTPUT_IMAGE_KEY = "output_img"
# Session-state key read as a fallback when no file is uploaded.
LOADED_IMAGE_KEY = "loaded_img"
def get_image(key: str) -> Optional[Image.Image]:
    """Look up an image stored in Streamlit session state.

    Args:
        key: Session-state key to read.

    Returns:
        The stored object, or ``None`` when ``key`` has never been set.
    """
    return st.session_state[key] if key in st.session_state else None
def set_image(key: str, img: Image.Image) -> None:
    """Save ``img`` in Streamlit session state under ``key``, overwriting any previous value."""
    state = st.session_state
    state[key] = img
def prompt_and_generate_button(prefix, pipeline_name: PIPELINES, **kwargs):
    """Render prompt and sampling controls, and generate an image on click.

    Every widget key is prefixed with ``prefix`` so the same form can be shown
    in several tabs with distinct widget keys.

    Args:
        prefix: Namespace prepended to each widget key.
        pipeline_name: Pipeline identifier passed through to ``generate``.
        **kwargs: Extra pipeline-specific arguments passed verbatim to ``generate``.
    """
    prompt = st.text_area("Prompt", value=DEFAULT_PROMPT, key=f"{prefix}-prompt")
    negative_prompt = st.text_area(
        "Negative prompt",
        value="",
        key=f"{prefix}-negative_prompt",
    )

    # Calling the widget on a column object places it in that column, the same
    # as creating it inside a ``with column:`` block.
    steps_column, guidance_column = st.columns(2)
    steps = steps_column.slider(
        "Number of inference steps",
        min_value=1,
        max_value=200,
        value=30,
        key=f"{prefix}-inference-steps",
    )
    guidance_scale = guidance_column.slider(
        "Guidance scale",
        min_value=0.0,
        max_value=20.0,
        value=7.5,
        step=0.5,
        key=f"{prefix}-guidance-scale",
    )
    enable_cpu_offload = st.checkbox(
        "Enable CPU offload if you run out of memory",
        key=f"{prefix}-cpu-offload",
        value=False,
    )

    clicked = st.button("Generate Image", key=f"{prefix}-btn")
    if not clicked:
        return

    with st.spinner("Generating image ..."):
        generated = generate(
            prompt,
            pipeline_name,
            negative_prompt=negative_prompt,
            num_inference_steps=steps,
            guidance_scale=guidance_scale,
            enable_cpu_offload=enable_cpu_offload,
            **kwargs,
        )
        # Keep a copy in session state so the sidebar can redisplay it on reruns.
        set_image(OUTPUT_IMAGE_KEY, generated.copy())
    st.image(generated)
def width_and_height_sliders(prefix):
    """Render side-by-side width and height sliders.

    Both sliders range from 64 to 1600 in steps of 16 and start at 512.

    Args:
        prefix: Namespace prepended to each slider's widget key.

    Returns:
        A ``(width, height)`` tuple with the currently selected values.
    """
    width_column, height_column = st.columns(2)
    common = dict(min_value=64, max_value=1600, step=16, value=512)
    width = width_column.slider("Width", key=f"{prefix}-width", **common)
    height = height_column.slider("Height", key=f"{prefix}-height", **common)
    return width, height
def image_uploader(prefix):
    """Show a JPG/PNG uploader and return the selected image.

    Args:
        prefix: Namespace prepended to the uploader's widget key.

    Returns:
        The uploaded file opened as a PIL image. When nothing is uploaded, the
        value stored under ``LOADED_IMAGE_KEY`` in session state is returned
        instead, which is ``None`` if that key was never set.
    """
    # NOTE(review): nothing visible in this file writes LOADED_IMAGE_KEY, so the
    # fallback presumably always yields None. Confirm whether it is meant to be populated.
    upload = st.file_uploader("Image", ["jpg", "png"], key=f"{prefix}-uploader")
    if not upload:
        return get_image(LOADED_IMAGE_KEY)
    opened = Image.open(upload)
    print(f"loaded input image of size ({opened.width}, {opened.height})")
    return opened
def sketching() -> Optional[Image.Image]:
    """Turn the uploaded image into a binary, sketch-like line drawing.

    The source image comes from ``image_uploader("Controlnet")``. It is
    converted to grayscale, smoothed with a 5x5 Gaussian blur, and binarised
    with a mean adaptive threshold (11-pixel neighbourhood, constant 2).

    Returns:
        The thresholded sketch as a single-channel PIL image, or ``None`` when
        no image is available.
    """
    image = image_uploader("Controlnet")
    # Return a plain None (not a tuple) so the caller's check correctly
    # treats "no image" as absent.
    if image is None:
        return None

    # image_uploader yields a PIL image, not a file path, so cv2.imread cannot
    # be used. Go through an RGB NumPy array instead. Converting to RGB first
    # normalises palette/RGBA/grayscale uploads to 3-channel uint8.
    rgb = np.array(image.convert("RGB"))
    gray = cv2.cvtColor(rgb, cv2.COLOR_RGB2GRAY)
    blurred = cv2.GaussianBlur(gray, (5, 5), 0)
    sketch = cv2.adaptiveThreshold(
        blurred,
        255,
        cv2.ADAPTIVE_THRESH_MEAN_C,
        cv2.THRESH_BINARY,
        11,
        2,
    )
    return Image.fromarray(sketch)
def txt2img_tab():
    """Render the text-to-image tab: size sliders followed by the prompt form."""
    tab_prefix = "txt2img"
    chosen_width, chosen_height = width_and_height_sliders(tab_prefix)
    prompt_and_generate_button(
        tab_prefix,
        "txt2img",
        width=chosen_width,
        height=chosen_height,
    )
def sketching_tab():
    """Render the sketch-to-image tab.

    The left column holds the uploader and sketch extraction. The right column
    shows the ControlNet strength slider and the prompt form, but only once a
    sketch image exists.
    """
    prefix = "sketch2img"
    col1, col2 = st.columns(2)
    with col1:
        sketch_pil = sketching()
    with col2:
        # Check the type rather than truthiness. Otherwise a truthy
        # "no image" sentinel (e.g. a tuple placeholder) would enable the form
        # and be forwarded to the pipeline as if it were a sketch.
        if isinstance(sketch_pil, Image.Image):
            controlnet_conditioning_scale = st.slider(
                "Strength or dependence on the input sketch",
                min_value=0.0,
                max_value=1.0,
                value=0.5,
                step=0.05,
                key=f"{prefix}-controlnet_conditioning_scale",
            )
            prompt_and_generate_button(
                prefix,
                "sketch2img",
                sketch_pil=sketch_pil,
                controlnet_conditioning_scale=controlnet_conditioning_scale,
            )
def main():
    """Build the playground page: two tool tabs plus a sidebar with the latest output."""
    # st.set_page_config has to be the first Streamlit call on the page.
    st.set_page_config(layout="wide")
    st.title("Fashion-SDX: Playground")

    tab1, tab2 = st.tabs(["Text to image", "Sketch to image"])
    with tab1:
        txt2img_tab()
    with tab2:
        sketching_tab()

    with st.sidebar:
        st.header("Most Recent Output Image")
        output_image = get_image(OUTPUT_IMAGE_KEY)
        # Explicit None check: get_image signals "nothing stored" with None.
        if output_image is not None:
            st.image(output_image)
        else:
            st.markdown("no output generated yet")


if __name__ == "__main__":
    main()