Spaces:
Sleeping
Sleeping
cpu pls
Browse files- test_few_shot.py +6 -6
test_few_shot.py
CHANGED
@@ -23,7 +23,7 @@ def test_main_model(opts):
|
|
23 |
dir_res = os.path.join(f"{opts.exp_path}", "experiments/", opts.name_exp, "results")
|
24 |
|
25 |
test_loader = get_loader(opts.data_root, opts.img_size, opts.language, opts.char_num, opts.max_seq_len, opts.dim_seq, opts.batch_size, 'test')
|
26 |
-
|
27 |
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
28 |
print("Inference With Device:", device)
|
29 |
if opts.streamlit:
|
@@ -75,8 +75,8 @@ def test_main_model(opts):
|
|
75 |
st.progress((sample_idx+1)/opts.n_samples, f"Generating Font Sample {sample_idx+1} Please wait...")
|
76 |
im = Image.open(save_file_merge)
|
77 |
st.image(im, caption='img_sample_merge')
|
78 |
-
|
79 |
-
for char_idx in range(opts.char_num):
|
80 |
img_gt = (1.0 - img_trg[char_idx,...]).data
|
81 |
save_file_gt = os.path.join(dir_save,"imgs", f"{char_idx:02d}_gt.png")
|
82 |
save_image(img_gt, save_file_gt, normalize=True)
|
@@ -87,7 +87,7 @@ def test_main_model(opts):
|
|
87 |
|
88 |
# write results w/o parallel refinement
|
89 |
svg_dec_out = svg_sampled.clone().detach()
|
90 |
-
for i, one_seq in enumerate(svg_dec_out):
|
91 |
syn_svg_outfile = os.path.join(os.path.join(dir_save, "svgs_single"), f"syn_{i:02d}_{sample_idx}_wo_refine.svg")
|
92 |
|
93 |
syn_svg_f_ = open(syn_svg_outfile, 'w')
|
@@ -105,7 +105,7 @@ def test_main_model(opts):
|
|
105 |
|
106 |
# write results w/ parallel refinement
|
107 |
svg_dec_out = sampled_svg_2.clone().detach()
|
108 |
-
for i, one_seq in enumerate(svg_dec_out):
|
109 |
syn_svg_outfile = os.path.join(os.path.join(dir_save, "svgs_single"), f"syn_{i:02d}_{sample_idx}_refined.svg")
|
110 |
|
111 |
syn_svg_f = open(syn_svg_outfile, 'w')
|
@@ -127,7 +127,7 @@ def test_main_model(opts):
|
|
127 |
iou_max[i] = iou_tmp
|
128 |
idx_best_sample[i] = sample_idx
|
129 |
|
130 |
-
for i in range(opts.char_num):
|
131 |
# print(idx_best_sample[i])
|
132 |
syn_svg_outfile_best = os.path.join(os.path.join(dir_save, "svgs_single"), f"syn_{i:02d}_{int(idx_best_sample[i])}_refined.svg")
|
133 |
syn_svg_merge_f.write(open(syn_svg_outfile_best, 'r').read())
|
|
|
23 |
dir_res = os.path.join(f"{opts.exp_path}", "experiments/", opts.name_exp, "results")
|
24 |
|
25 |
test_loader = get_loader(opts.data_root, opts.img_size, opts.language, opts.char_num, opts.max_seq_len, opts.dim_seq, opts.batch_size, 'test')
|
26 |
+
|
27 |
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
28 |
print("Inference With Device:", device)
|
29 |
if opts.streamlit:
|
|
|
75 |
st.progress((sample_idx+1)/opts.n_samples, f"Generating Font Sample {sample_idx+1} Please wait...")
|
76 |
im = Image.open(save_file_merge)
|
77 |
st.image(im, caption='img_sample_merge')
|
78 |
+
|
79 |
+
for char_idx in tqdm(range(opts.char_num)):
|
80 |
img_gt = (1.0 - img_trg[char_idx,...]).data
|
81 |
save_file_gt = os.path.join(dir_save,"imgs", f"{char_idx:02d}_gt.png")
|
82 |
save_image(img_gt, save_file_gt, normalize=True)
|
|
|
87 |
|
88 |
# write results w/o parallel refinement
|
89 |
svg_dec_out = svg_sampled.clone().detach()
|
90 |
+
for i, one_seq in tqdm(enumerate(svg_dec_out)):
|
91 |
syn_svg_outfile = os.path.join(os.path.join(dir_save, "svgs_single"), f"syn_{i:02d}_{sample_idx}_wo_refine.svg")
|
92 |
|
93 |
syn_svg_f_ = open(syn_svg_outfile, 'w')
|
|
|
105 |
|
106 |
# write results w/ parallel refinement
|
107 |
svg_dec_out = sampled_svg_2.clone().detach()
|
108 |
+
for i, one_seq in tqdm(enumerate(svg_dec_out)):
|
109 |
syn_svg_outfile = os.path.join(os.path.join(dir_save, "svgs_single"), f"syn_{i:02d}_{sample_idx}_refined.svg")
|
110 |
|
111 |
syn_svg_f = open(syn_svg_outfile, 'w')
|
|
|
127 |
iou_max[i] = iou_tmp
|
128 |
idx_best_sample[i] = sample_idx
|
129 |
|
130 |
+
for i in tqdm(range(opts.char_num)):
|
131 |
# print(idx_best_sample[i])
|
132 |
syn_svg_outfile_best = os.path.join(os.path.join(dir_save, "svgs_single"), f"syn_{i:02d}_{int(idx_best_sample[i])}_refined.svg")
|
133 |
syn_svg_merge_f.write(open(syn_svg_outfile_best, 'r').read())
|