microhum commited on
Commit
6d8843e
·
1 Parent(s): 86aa827

latest output update

Browse files
Files changed (3) hide show
  1. app.py +3 -3
  2. generate.py +4 -3
  3. test_few_shot.py +5 -1
app.py CHANGED
@@ -65,8 +65,8 @@ def generate_button(prefix, file_input, version, **kwargs):
65
  )
66
 
67
  if st.button("Generate image", key=f"{prefix}-btn"):
68
- with st.spinner(f"⏳ Generating image... ({n_samples*5} minutes estimated time)"):
69
- image = toggle_process( ttf_to_image(file_input, n_samples, ref_char_ids, version) )
70
  set_img(OUTPUT_IMG_KEY, image.copy())
71
  st.image(image)
72
 
@@ -110,7 +110,7 @@ def main():
110
  generate_tab()
111
 
112
  with st.sidebar:
113
- st.header("Latest Output Image")
114
  output_image = get_img(OUTPUT_IMG_KEY)
115
  if output_image:
116
  st.image(output_image)
 
65
  )
66
 
67
  if st.button("Generate image", key=f"{prefix}-btn"):
68
+ with st.spinner(f"⏳ Generating image (5 minutes per n_sample estimated time)"):
69
+ image = toggle_process( ttf_to_image(file_input, OUTPUT_IMG_KEY, n_samples, ref_char_ids, version) )
70
  set_img(OUTPUT_IMG_KEY, image.copy())
71
  st.image(image)
72
 
 
110
  generate_tab()
111
 
112
  with st.sidebar:
113
+ st.header("Latest Output")
114
  output_image = get_img(OUTPUT_IMG_KEY)
115
  if output_image:
116
  st.image(output_image)
generate.py CHANGED
@@ -123,9 +123,10 @@ def preprocessing(ttf_file) -> str:
123
  print("Saved at", output_path)
124
  return output_path
125
 
126
- def inference_model(n_samples, ref_char_ids, version):
127
  opts.n_samples = n_samples
128
  opts.ref_char_ids = ref_char_ids
 
129
 
130
  # Select Model
131
  if version == "TH2TH":
@@ -137,9 +138,9 @@ def inference_model(n_samples, ref_char_ids, version):
137
 
138
  return test_main_model(opts)
139
 
140
- def ttf_to_image(ttf_file, n_samples=10, ref_char_ids="1,2,3,4,5,6,7,8", version="TH2TH"):
141
  preprocessing(ttf_file) # Make Data
142
- merge_svg_img = inference_model(n_samples, ref_char_ids, version) # Inference
143
  return merge_svg_img
144
 
145
  def main():
 
123
  print("Saved at", output_path)
124
  return output_path
125
 
126
+ def inference_model(OUTPUT_IMG_KEY, n_samples, ref_char_ids, version):
127
  opts.n_samples = n_samples
128
  opts.ref_char_ids = ref_char_ids
129
+ opts.OUTPUT_IMG_KEY = OUTPUT_IMG_KEY
130
 
131
  # Select Model
132
  if version == "TH2TH":
 
138
 
139
  return test_main_model(opts)
140
 
141
+ def ttf_to_image(ttf_file, OUTPUT_IMG_KEY, n_samples=10, ref_char_ids="1,2,3,4,5,6,7,8", version="TH2TH"):
142
  preprocessing(ttf_file) # Make Data
143
+ merge_svg_img = inference_model(OUTPUT_IMG_KEY, n_samples, ref_char_ids, version) # Inference
144
  return merge_svg_img
145
 
146
  def main():
test_few_shot.py CHANGED
@@ -31,6 +31,9 @@ def test_main_model(opts):
31
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
32
  print("Inference With Device:", device)
33
  if opts.streamlit:
 
 
 
34
  st.write("Loading Model Weight...")
35
  st.write("Inference With Device:", device)
36
 
@@ -78,7 +81,8 @@ def test_main_model(opts):
78
  if opts.streamlit:
79
  st.progress((sample_idx+1)/opts.n_samples, f"Generating Font Sample {sample_idx+1} Please wait...")
80
  im = Image.open(save_file_merge)
81
- st.image(im, caption='img_sample_merge')
 
82
 
83
  for char_idx in tqdm(range(opts.char_num)):
84
  img_gt = (1.0 - img_trg[char_idx,...]).data
 
31
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
32
  print("Inference With Device:", device)
33
  if opts.streamlit:
34
+ def set_img(key: str, img: Image.Image):
35
+ st.session_state[key] = img
36
+
37
  st.write("Loading Model Weight...")
38
  st.write("Inference With Device:", device)
39
 
 
81
  if opts.streamlit:
82
  st.progress((sample_idx+1)/opts.n_samples, f"Generating Font Sample {sample_idx+1} Please wait...")
83
  im = Image.open(save_file_merge)
84
+ set_img(opts.OUTPUT_IMG_KEY, im)
85
+ st.image(im, caption=f"sample {sample_idx+1}")
86
 
87
  for char_idx in tqdm(range(opts.char_num)):
88
  img_gt = (1.0 - img_trg[char_idx,...]).data