microhum commited on
Commit
448a707
·
1 Parent(s): 94dff7f
Files changed (1) hide show
  1. test_few_shot.py +6 -6
test_few_shot.py CHANGED
@@ -23,7 +23,7 @@ def test_main_model(opts):
23
  dir_res = os.path.join(f"{opts.exp_path}", "experiments/", opts.name_exp, "results")
24
 
25
  test_loader = get_loader(opts.data_root, opts.img_size, opts.language, opts.char_num, opts.max_seq_len, opts.dim_seq, opts.batch_size, 'test')
26
-
27
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
28
  print("Inference With Device:", device)
29
  if opts.streamlit:
@@ -75,8 +75,8 @@ def test_main_model(opts):
75
  st.progress((sample_idx+1)/opts.n_samples, f"Generating Font Sample {sample_idx+1} Please wait...")
76
  im = Image.open(save_file_merge)
77
  st.image(im, caption='img_sample_merge')
78
-
79
- for char_idx in range(opts.char_num):
80
  img_gt = (1.0 - img_trg[char_idx,...]).data
81
  save_file_gt = os.path.join(dir_save,"imgs", f"{char_idx:02d}_gt.png")
82
  save_image(img_gt, save_file_gt, normalize=True)
@@ -87,7 +87,7 @@ def test_main_model(opts):
87
 
88
  # write results w/o parallel refinement
89
  svg_dec_out = svg_sampled.clone().detach()
90
- for i, one_seq in enumerate(svg_dec_out):
91
  syn_svg_outfile = os.path.join(os.path.join(dir_save, "svgs_single"), f"syn_{i:02d}_{sample_idx}_wo_refine.svg")
92
 
93
  syn_svg_f_ = open(syn_svg_outfile, 'w')
@@ -105,7 +105,7 @@ def test_main_model(opts):
105
 
106
  # write results w/ parallel refinement
107
  svg_dec_out = sampled_svg_2.clone().detach()
108
- for i, one_seq in enumerate(svg_dec_out):
109
  syn_svg_outfile = os.path.join(os.path.join(dir_save, "svgs_single"), f"syn_{i:02d}_{sample_idx}_refined.svg")
110
 
111
  syn_svg_f = open(syn_svg_outfile, 'w')
@@ -127,7 +127,7 @@ def test_main_model(opts):
127
  iou_max[i] = iou_tmp
128
  idx_best_sample[i] = sample_idx
129
 
130
- for i in range(opts.char_num):
131
  # print(idx_best_sample[i])
132
  syn_svg_outfile_best = os.path.join(os.path.join(dir_save, "svgs_single"), f"syn_{i:02d}_{int(idx_best_sample[i])}_refined.svg")
133
  syn_svg_merge_f.write(open(syn_svg_outfile_best, 'r').read())
 
23
  dir_res = os.path.join(f"{opts.exp_path}", "experiments/", opts.name_exp, "results")
24
 
25
  test_loader = get_loader(opts.data_root, opts.img_size, opts.language, opts.char_num, opts.max_seq_len, opts.dim_seq, opts.batch_size, 'test')
26
+
27
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
28
  print("Inference With Device:", device)
29
  if opts.streamlit:
 
75
  st.progress((sample_idx+1)/opts.n_samples, f"Generating Font Sample {sample_idx+1} Please wait...")
76
  im = Image.open(save_file_merge)
77
  st.image(im, caption='img_sample_merge')
78
+
79
+ for char_idx in tqdm(range(opts.char_num)):
80
  img_gt = (1.0 - img_trg[char_idx,...]).data
81
  save_file_gt = os.path.join(dir_save,"imgs", f"{char_idx:02d}_gt.png")
82
  save_image(img_gt, save_file_gt, normalize=True)
 
87
 
88
  # write results w/o parallel refinement
89
  svg_dec_out = svg_sampled.clone().detach()
90
+ for i, one_seq in tqdm(enumerate(svg_dec_out)):
91
  syn_svg_outfile = os.path.join(os.path.join(dir_save, "svgs_single"), f"syn_{i:02d}_{sample_idx}_wo_refine.svg")
92
 
93
  syn_svg_f_ = open(syn_svg_outfile, 'w')
 
105
 
106
  # write results w/ parallel refinement
107
  svg_dec_out = sampled_svg_2.clone().detach()
108
+ for i, one_seq in tqdm(enumerate(svg_dec_out)):
109
  syn_svg_outfile = os.path.join(os.path.join(dir_save, "svgs_single"), f"syn_{i:02d}_{sample_idx}_refined.svg")
110
 
111
  syn_svg_f = open(syn_svg_outfile, 'w')
 
127
  iou_max[i] = iou_tmp
128
  idx_best_sample[i] = sample_idx
129
 
130
+ for i in tqdm(range(opts.char_num)):
131
  # print(idx_best_sample[i])
132
  syn_svg_outfile_best = os.path.join(os.path.join(dir_save, "svgs_single"), f"syn_{i:02d}_{int(idx_best_sample[i])}_refined.svg")
133
  syn_svg_merge_f.write(open(syn_svg_outfile_best, 'r').read())