latest output update

- app.py            +3 -3
- generate.py       +4 -3
- test_few_shot.py  +5 -1
app.py
CHANGED

@@ -65,8 +65,8 @@ def generate_button(prefix, file_input, version, **kwargs):
     )

     if st.button("Generate image", key=f"{prefix}-btn"):
-        with st.spinner(f"⏳ Generating image
-            image = toggle_process( ttf_to_image(file_input, n_samples, ref_char_ids, version) )
+        with st.spinner(f"⏳ Generating image (5 minutes per n_sample estimated time)"):
+            image = toggle_process( ttf_to_image(file_input, OUTPUT_IMG_KEY, n_samples, ref_char_ids, version) )
         set_img(OUTPUT_IMG_KEY, image.copy())
         st.image(image)

@@ -110,7 +110,7 @@ def main():
     generate_tab()

     with st.sidebar:
-        st.header("Latest Output
+        st.header("Latest Output")
         output_image = get_img(OUTPUT_IMG_KEY)
         if output_image:
             st.image(output_image)
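For context, app.py leans on small session-state helpers to cache the latest render. set_img matches the definition added in test_few_shot.py below; get_img and the OUTPUT_IMG_KEY value are not shown in this commit, so the sketch below treats them as assumptions.

import streamlit as st
from PIL import Image

OUTPUT_IMG_KEY = "output_img"  # assumed key name, not taken from this commit

def set_img(key: str, img: Image.Image):
    # Cache the PIL image in Streamlit's per-session state
    # (mirrors the helper added in test_few_shot.py)
    st.session_state[key] = img

def get_img(key: str):
    # Assumed counterpart: return the cached image, or None before the first generation
    return st.session_state.get(key)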
generate.py
CHANGED

@@ -123,9 +123,10 @@ def preprocessing(ttf_file) -> str:
     print("Saved at", output_path)
     return output_path

-def inference_model(n_samples, ref_char_ids, version):
+def inference_model(OUTPUT_IMG_KEY, n_samples, ref_char_ids, version):
     opts.n_samples = n_samples
     opts.ref_char_ids = ref_char_ids
+    opts.OUTPUT_IMG_KEY = OUTPUT_IMG_KEY

     # Select Model
     if version == "TH2TH":
@@ -137,9 +138,9 @@ def inference_model(n_samples, ref_char_ids, version):

     return test_main_model(opts)

-def ttf_to_image(ttf_file, n_samples=10, ref_char_ids="1,2,3,4,5,6,7,8", version="TH2TH"):
+def ttf_to_image(ttf_file, OUTPUT_IMG_KEY, n_samples=10, ref_char_ids="1,2,3,4,5,6,7,8", version="TH2TH"):
     preprocessing(ttf_file) # Make Data
-    merge_svg_img = inference_model(n_samples, ref_char_ids, version) # Inference
+    merge_svg_img = inference_model(OUTPUT_IMG_KEY, n_samples, ref_char_ids, version) # Inference
     return merge_svg_img

 def main():
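With the new signature, the session-state key travels ttf_to_image -> inference_model -> opts.OUTPUT_IMG_KEY, where test_main_model can read it. A hedged usage sketch follows, assuming the returned merge_svg_img is a PIL image as app.py's image.copy() / st.image(image) calls suggest; the file path and key value are placeholders.

from generate import ttf_to_image

# Placeholder arguments; the parameter order and defaults come from the diff above,
# the concrete values are made up.
preview = ttf_to_image(
    "fonts/example.ttf",           # hypothetical uploaded TTF
    "output_img",                  # forwarded into opts.OUTPUT_IMG_KEY
    n_samples=2,
    ref_char_ids="1,2,3,4,5,6,7,8",
    version="TH2TH",
)
preview.save("latest_sample.png")  # assumes a PIL-image return value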
test_few_shot.py
CHANGED

@@ -31,6 +31,9 @@ def test_main_model(opts):
     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
     print("Inference With Device:", device)
     if opts.streamlit:
+        def set_img(key: str, img: Image.Image):
+            st.session_state[key] = img
+
         st.write("Loading Model Weight...")
         st.write("Inference With Device:", device)

@@ -78,7 +81,8 @@ def test_main_model(opts):
         if opts.streamlit:
             st.progress((sample_idx+1)/opts.n_samples, f"Generating Font Sample {sample_idx+1} Please wait...")
             im = Image.open(save_file_merge)
-
+            set_img(opts.OUTPUT_IMG_KEY, im)
+            st.image(im, caption=f"sample {sample_idx+1}")

         for char_idx in tqdm(range(opts.char_num)):
             img_gt = (1.0 - img_trg[char_idx,...]).data