"""Streamlit playground for Fashion-SDX: text-to-image and sketch-to-image generation."""

from typing import Optional

import cv2
import numpy as np
import streamlit as st
from PIL import Image

from sdfile import PIPELINES, generate

DEFAULT_PROMPT = "belted shirt black belted portrait-collar wrap blouse with black prints"
DEFAULT_WIDTH, DEFAULT_HEIGHT = 512, 512
# Backward-compatible alias for the previously misspelled constant name.
DEAFULT_WIDTH = DEFAULT_WIDTH
OUTPUT_IMAGE_KEY = "output_img"
LOADED_IMAGE_KEY = "loaded_img"


def get_image(key: str) -> Optional[Image.Image]:
    """Return the image stored in Streamlit session state under *key*, or None."""
    if key in st.session_state:
        return st.session_state[key]
    return None


def set_image(key: str, img: Image.Image) -> None:
    """Store *img* in Streamlit session state under *key*."""
    st.session_state[key] = img


def prompt_and_generate_button(prefix: str, pipeline_name: PIPELINES, **kwargs) -> None:
    """Render generation controls and run the pipeline when "Generate Image" is clicked.

    Args:
        prefix: Prefix for every widget key, so the same controls can appear
            in several tabs without key collisions.
        pipeline_name: Pipeline identifier passed through to ``generate``.
        **kwargs: Extra keyword arguments forwarded unchanged to ``generate``
            (e.g. ``width``/``height`` or sketch-conditioning inputs).

    On click, the generated image is displayed and a copy is saved under
    ``OUTPUT_IMAGE_KEY`` so the sidebar can show the most recent output.
    """
    prompt = st.text_area("Prompt", value=DEFAULT_PROMPT, key=f"{prefix}-prompt")
    negative_prompt = st.text_area(
        "Negative prompt",
        value="",
        key=f"{prefix}-negative_prompt",
    )

    col1, col2 = st.columns(2)
    with col1:
        steps = st.slider(
            "Number of inference steps",
            min_value=1,
            max_value=200,
            value=30,
            key=f"{prefix}-inference-steps",
        )
    with col2:
        guidance_scale = st.slider(
            "Guidance scale",
            min_value=0.0,
            max_value=20.0,
            value=7.5,
            step=0.5,
            key=f"{prefix}-guidance-scale",
        )

    enable_cpu_offload = st.checkbox(
        "Enable CPU offload if you run out of memory",
        key=f"{prefix}-cpu-offload",
        value=False,
    )

    if st.button("Generate Image", key=f"{prefix}-btn"):
        with st.spinner("Generating image ..."):
            image = generate(
                prompt,
                pipeline_name,
                negative_prompt=negative_prompt,
                num_inference_steps=steps,
                guidance_scale=guidance_scale,
                enable_cpu_offload=enable_cpu_offload,
                **kwargs,
            )
            # Store a copy so later in-place edits of `image` cannot alter the
            # image shown in the sidebar.
            set_image(OUTPUT_IMAGE_KEY, image.copy())
        st.image(image)


def width_and_height_sliders(prefix: str) -> tuple:
    """Render side-by-side width/height sliders and return ``(width, height)``.

    Both sliders range over 64..1600 px in steps of 16 and default to
    ``DEFAULT_WIDTH`` / ``DEFAULT_HEIGHT``.
    """
    col1, col2 = st.columns(2)
    with col1:
        width = st.slider(
            "Width",
            min_value=64,
            max_value=1600,
            step=16,
            value=DEFAULT_WIDTH,
            key=f"{prefix}-width",
        )
    with col2:
        height = st.slider(
            "Height",
            min_value=64,
            max_value=1600,
            step=16,
            value=DEFAULT_HEIGHT,
            key=f"{prefix}-height",
        )
    return width, height


def image_uploader(prefix: str) -> Optional[Image.Image]:
    """Show a JPG/PNG uploader and return the uploaded image as a PIL image.

    When nothing is uploaded, fall back to whatever is stored under
    ``LOADED_IMAGE_KEY`` in session state (None if absent).
    """
    uploaded = st.file_uploader("Image", ["jpg", "png"], key=f"{prefix}-uploader")
    if uploaded is not None:
        image = Image.open(uploaded)
        print(f"loaded input image of size ({image.width}, {image.height})")
        return image
    # NOTE(review): nothing in this file writes LOADED_IMAGE_KEY, so this
    # fallback presumably relies on code elsewhere — confirm it is still needed.
    return get_image(LOADED_IMAGE_KEY)


def sketching() -> Optional[Image.Image]:
    """Let the user upload an image and convert it into a binary line sketch.

    Returns:
        A single-channel PIL image (0/255 pixels), or None when no input
        image is available.
    """
    image = image_uploader("Controlnet")
    if image is None:
        # Return a single None (not a tuple): callers test the result directly,
        # and a tuple such as (None, None) would be truthy.
        return None

    # image_uploader yields a PIL image, not a file path, so cv2.imread cannot
    # be used. Convert to an RGB ndarray instead; .convert("RGB") also
    # normalises RGBA, palette and grayscale PNGs.
    rgb = np.array(image.convert("RGB"))
    gray = cv2.cvtColor(rgb, cv2.COLOR_RGB2GRAY)
    # Smooth first so the threshold does not pick up pixel-level noise.
    blurred = cv2.GaussianBlur(gray, (5, 5), 0)
    # Local-mean adaptive threshold: 11x11 neighbourhood, constant C=2.
    sketch = cv2.adaptiveThreshold(
        blurred, 255, cv2.ADAPTIVE_THRESH_MEAN_C, cv2.THRESH_BINARY, 11, 2
    )
    return Image.fromarray(sketch)


def txt2img_tab() -> None:
    """Render the text-to-image tab: size sliders plus prompt/generate controls."""
    prefix = "txt2img"
    width, height = width_and_height_sliders(prefix)
    prompt_and_generate_button(prefix, "txt2img", width=width, height=height)


def sketching_tab() -> None:
    """Render the sketch-to-image tab.

    The left column takes the input image; the right column shows the
    generation controls only once a sketch is available.
    """
    prefix = "sketch2img"
    col1, col2 = st.columns(2)
    with col1:
        sketch_pil = sketching()
    with col2:
        if sketch_pil is not None:
            controlnet_conditioning_scale = st.slider(
                "Strength or dependence on the input sketch",
                min_value=0.0,
                max_value=1.0,
                value=0.5,
                step=0.05,
                key=f"{prefix}-controlnet_conditioning_scale",
            )
            prompt_and_generate_button(
                prefix,
                "sketch2img",
                sketch_pil=sketch_pil,
                controlnet_conditioning_scale=controlnet_conditioning_scale,
            )


def main() -> None:
    """Build the page: the two generation tabs plus a sidebar with the latest output."""
    st.set_page_config(layout="wide")
    st.title("Fashion-SDX: Playground")

    tab1, tab2 = st.tabs(["Text to image", "Sketch to image"])
    with tab1:
        txt2img_tab()
    with tab2:
        sketching_tab()

    with st.sidebar:
        st.header("Most Recent Output Image")
        output_image = get_image(OUTPUT_IMAGE_KEY)
        if output_image is not None:
            st.image(output_image)
        else:
            st.markdown("no output generated yet")


if __name__ == "__main__":
    main()