# Spaces:
# Sleeping
# Sleeping
import os
from threading import Thread
from typing import Any, Optional

import streamlit as st
from PIL import Image

from generate import ttf_to_image
# Keys used to stash values in Streamlit's session state between reruns.
LOADED_TTF_KEY = "loaded_ttf"  # fallback font that ttf_uploader reads via get_ttf
SET_IMG_KEY = "set_img"  # NOTE(review): not referenced in the code shown here; presumably used elsewhere
OUTPUT_IMG_KEY = "output_img"  # most recent generated image; displayed in the sidebar by main()
def get_ttf(key: str) -> Optional[Any]:
    """Return the font object stored in ``st.session_state`` under *key*.

    Args:
        key: Session-state key to look up.

    Returns:
        Whatever value was stored under *key*, or ``None`` if the key is unset.
    """
    # The previous annotation ``Optional[any]`` referenced the builtin
    # function ``any``; ``typing.Any`` is the intended "unconstrained" type.
    return st.session_state.get(key)
def get_img(key: str) -> Optional[Image.Image]:
    """Fetch an image previously saved in Streamlit session state.

    Args:
        key: Session-state key the image was saved under.

    Returns:
        The stored image, or ``None`` when nothing has been saved under *key*.
    """
    state = st.session_state
    return state[key] if key in state else None
def set_img(key: str, img: Image.Image):
    """Save *img* in Streamlit session state under *key*, replacing any prior value.

    Args:
        key: Session-state key to write.
        img: Image to store.
    """
    state = st.session_state
    state[key] = img
def ttf_uploader(prefix):
    """Render a TTF/OTF file uploader and return the font to use.

    Args:
        prefix: Namespace for the widget key (``"<prefix>-uploader"``).

    Returns:
        The freshly uploaded file when the user supplied one; otherwise the
        value ``get_ttf`` finds under ``LOADED_TTF_KEY`` (``None`` if unset).
    """
    uploaded = st.file_uploader("TTF, OTF", ["ttf", "otf"], key=f"{prefix}-uploader")
    # Fall back to a previously loaded font only when nothing was uploaded.
    return uploaded or get_ttf(LOADED_TTF_KEY)
def generate_button(prefix, file_input, version, **kwargs):
    """Render the generation controls and run ``ttf_to_image`` when clicked.

    Draws a slider for the number of inference samples and a text area for
    the reference character ids, then a "Generate image" button. On click,
    ``ttf_to_image`` is called inside a spinner. A non-``None`` result is
    copied into session state under ``OUTPUT_IMG_KEY`` and displayed. A
    "test font" text area is rendered below the button.

    Args:
        prefix: Namespace prepended to every widget key so the controls do not
            collide with widgets from other tabs.
        file_input: Font source forwarded to ``ttf_to_image``. In this file it
            is either an upload buffer or a path under ``font_sample/``.
        version: Model version string forwarded to ``ttf_to_image``.
        **kwargs: Accepted for backward compatibility; currently unused.
    """
    samples_col, ids_col = st.columns(2)
    with samples_col:
        n_samples = st.slider(
            "Number of inference sample",
            min_value=1,
            max_value=200,
            value=20,
            key=f"{prefix}-inference-sample",
        )
    with ids_col:
        ref_char_ids = st.text_area(
            "ref_char_ids",
            value="1,2,3,4,5,6,7,8",
            key=f"{prefix}-ref_char_ids",
        )

    # The previous thread-based "toggle" could not work:
    # - ttf_to_image(...) was evaluated eagerly, and its *result* was passed
    #   as Thread(target=...).
    # - The running/thread flags were plain locals rebound inside a helper.
    # - The helper returned None, so image.copy() always raised.
    # Calling the model synchronously inside the spinner is the working
    # equivalent.
    if st.button("Generate image", key=f"{prefix}-btn"):
        with st.spinner("⏳ Generating image (5 minutes per n_sample estimated time)"):
            # NOTE(review): assumes ttf_to_image returns an image (the original
            # code called .copy() and st.image on the result), or None when
            # nothing was produced — confirm against generate.py.
            image = ttf_to_image(file_input, OUTPUT_IMG_KEY, n_samples, ref_char_ids, version)
        if image is not None:
            # Store a copy so the sidebar keeps showing it on later reruns.
            set_img(OUTPUT_IMG_KEY, image.copy())
            st.image(image)

    # The widget is rendered, but its value is not used anywhere in this function.
    st.text_area(
        "test font",
        value="กขคง",
        key=f"{prefix}-prompt",
    )
def generate_tab():
    """Render the font-source picker and, once a font is available, the generator.

    The left column offers "Custom" (upload a TTF/OTF) plus every entry in
    the ``font_sample`` directory. The right column appears only when a font
    source is available. It offers a model-version selector and the
    ``generate_button`` controls.
    """
    prefix = "ttf2img"
    sample_dir = "font_sample"
    try:
        # Sorted because os.listdir order is arbitrary; this keeps the
        # dropdown order stable across reruns and platforms.
        samples = sorted(os.listdir(sample_dir))
    except FileNotFoundError:
        # A missing sample folder should not take down the page; the
        # "Custom" upload option still works.
        samples = []

    col1, col2 = st.columns(2)
    with col1:
        sample_choose = st.selectbox(
            "Choose Sample", ["Custom"] + samples, key=f"{prefix}-sample_choose"
        )
        if sample_choose == "Custom":
            uploaded_file = ttf_uploader(prefix)
            if uploaded_file:
                st.write("filename:", uploaded_file.name)
                uploaded_file = uploaded_file.getbuffer()  # Send file as Buffer
        else:
            st.write("filename:", sample_choose)
            uploaded_file = os.path.join(sample_dir, sample_choose)
    with col2:
        if uploaded_file:
            version = st.selectbox(
                "Model version", ["TH2TH", "ENG2TH"], key=f"{prefix}-version"
            )
            generate_button(prefix, file_input=uploaded_file, version=version)
def main():
    """Build the ThaiVecFont Playground page.

    Uses a wide layout and renders the generation tab. The sidebar shows the
    most recently generated image, or a placeholder message if there is none.
    """
    # Kept first: Streamlit documents set_page_config as a call that must
    # precede other st.* commands.
    st.set_page_config(layout="wide")
    st.title("ThaiVecFont Playground")
    generate_tab()
    with st.sidebar:
        st.header("Latest Output")
        # Read after generate_tab() so an image produced in this run is shown.
        output_image = get_img(OUTPUT_IMG_KEY)
        # Explicit None check: object truthiness is not a reliable "has output"
        # test (e.g. array-like images raise on bool()).
        if output_image is not None:
            st.image(output_image)
        else:
            st.markdown("No output generated yet")
if __name__ == "__main__":
    # Launch the app only when this file is executed as the main script.
    main()