Spaces:
Sleeping
Sleeping
import os | |
import shutil | |
import numpy as np | |
import torch | |
import torch.nn.functional as F | |
from torchvision.utils import save_image | |
from dataloader import get_loader | |
from models.model_main import ModelMain | |
from models.transformers import denumericalize | |
from options import get_parser_main_model | |
from data_utils.svg_utils import render | |
from models.util_funcs import svg2img, cal_iou | |
from tqdm import tqdm | |
from PIL import Image | |
def _render_svg_to_file(one_seq, out_path, append_br=False):
    """Render one sequence with ``render`` and write the result to *out_path*.

    The output file is opened (and therefore created/truncated) before
    rendering is attempted, so a failed render leaves an empty or partial
    file behind instead of no file at all.

    Args:
        one_seq: tensor for a single glyph; it is moved to CPU and converted
            to a NumPy array before being handed to ``render``.
        out_path: destination path of the SVG file.
        append_br: when true, the literal ``<br>`` is written after the SVG.

    Returns:
        True if rendering and writing succeeded, False if they raised.
    """
    # The context manager guarantees the handle is closed even when
    # rendering fails (the previous bare ``except: continue`` skipped close()).
    with open(out_path, 'w') as svg_file:
        try:
            svg_file.write(render(one_seq.cpu().numpy()))
            if append_br:
                svg_file.write('<br>')
        except Exception as err:
            # Best effort: one glyph failing must not abort the whole run,
            # but say so instead of failing silently.
            print(f"Failed to render {out_path}: {err!r}")
            return False
    return True


def test_main_model(opts):
    """Run inference over the test loader and write images / SVGs to disk.

    For every batch yielded by the test loader the model is called
    ``opts.n_samples`` times. Each call writes a merged target/output image,
    per-character images and two SVGs per sequence (``*_wo_refine.svg`` from
    ``ret_dict['svg']['sampled_1']`` and ``*_refined.svg`` from
    ``ret_dict['svg']['sampled_2']``). For each index the refined sample with
    the largest first value returned by ``cal_iou`` is remembered and
    concatenated, together with the rendered target sequences, into one HTML
    file under ``svgs_merge``.

    Args:
        opts: parsed options namespace. Attributes read here: ``streamlit``,
            ``dir_res``, ``exp_path``, ``name_exp``, ``data_root``,
            ``img_size``, ``language``, ``char_num``, ``max_seq_len``,
            ``dim_seq``, ``batch_size``, ``model_path``, ``name_ckpt``,
            ``n_samples`` and, in Streamlit mode, ``OUTPUT_IMG_KEY``.

    Returns:
        The last merged PIL image opened in Streamlit mode, or ``None`` when
        ``opts.streamlit`` is false or no sample was generated.
    """
    if opts.streamlit:
        # Imported lazily so that plain CLI runs do not require streamlit.
        import streamlit as st

    if opts.dir_res:
        dir_res = os.path.join(opts.dir_res, "results")
        # Start from a clean results directory.
        if os.path.exists(dir_res):
            shutil.rmtree(dir_res)
        os.makedirs(dir_res)
    else:
        dir_res = os.path.join(f"{opts.exp_path}", "experiments/", opts.name_exp, "results")

    test_loader = get_loader(opts.data_root, opts.img_size, opts.language, opts.char_num,
                             opts.max_seq_len, opts.dim_seq, opts.batch_size, 'test')
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print("Inference With Device:", device)
    if opts.streamlit:
        st.write("Loading Model Weight...")
        st.write("Inference With Device:", device)

    model_main = ModelMain(opts)
    path_ckpt = os.path.join(f"{opts.model_path}")
    # NOTE(review): torch.load unpickles the checkpoint; only load files from
    # a trusted source.
    model_main.load_state_dict(torch.load(path_ckpt, map_location=device)['model'])
    model_main.to(device)
    model_main.eval()

    # Bound up front so the final ``return`` is valid outside Streamlit mode.
    im = None

    with torch.no_grad():
        for test_idx, test_data in enumerate(test_loader):
            for key in test_data:
                test_data[key] = test_data[key].to(device)
            print("testing font %04d ..." % test_idx)

            dir_save = os.path.join(dir_res, "%04d" % test_idx)
            dir_imgs = os.path.join(dir_save, "imgs")
            dir_svgs_single = os.path.join(dir_save, "svgs_single")
            svg_merge_dir = os.path.join(dir_save, "svgs_merge")
            # makedirs also creates missing parents (dir_res in the
            # exp_path branch) and tolerates partially existing trees.
            for sub_dir in (dir_imgs, dir_svgs_single, svg_merge_dir):
                os.makedirs(sub_dir, exist_ok=True)

            # Best score seen so far and the sample index that produced it,
            # one slot per character index.
            iou_max = np.zeros(opts.char_num)
            idx_best_sample = np.zeros(opts.char_num)

            for sample_idx in tqdm(range(opts.n_samples)):
                ret_dict_test, _ = model_main(test_data, mode='test')

                svg_sampled = ret_dict_test['svg']['sampled_1']
                sampled_svg_2 = ret_dict_test['svg']['sampled_2']
                img_trg = ret_dict_test['img']['trg']
                img_output = ret_dict_test['img']['out']
                trg_seq_gt = ret_dict_test['svg']['trg']

                # Target and output images concatenated along dim -2.
                img_sample_merge = torch.cat((img_trg.data, img_output.data), -2)
                save_file_merge = os.path.join(dir_imgs, f"merge_{opts.img_size}.png")
                save_image(img_sample_merge, save_file_merge, nrow=8, normalize=True)

                if opts.streamlit:
                    st.progress((sample_idx+1)/opts.n_samples, f"Generating Font Sample {sample_idx+1} Please wait...")
                    im = Image.open(save_file_merge)
                    # Store a copy in the session state under the caller-chosen key.
                    st.session_state[opts.OUTPUT_IMG_KEY] = im.copy()
                    st.image(im, caption=f"sample {sample_idx+1}")

                for char_idx in tqdm(range(opts.char_num)):
                    # Images are inverted (1.0 - x) before being saved.
                    img_gt = (1.0 - img_trg[char_idx, ...]).data
                    save_file_gt = os.path.join(dir_imgs, f"{char_idx:02d}_gt.png")
                    save_image(img_gt, save_file_gt, normalize=True)

                    img_sample = (1.0 - img_output[char_idx, ...]).data
                    save_file = os.path.join(dir_imgs, f"{char_idx:02d}_{opts.img_size}.png")
                    save_image(img_sample, save_file, normalize=True)

                # write results w/o parallel refinement
                for i, one_seq in tqdm(enumerate(svg_sampled.clone().detach())):
                    syn_svg_outfile = os.path.join(dir_svgs_single, f"syn_{i:02d}_{sample_idx}_wo_refine.svg")
                    # NOTE(review): '<br>' is appended to a standalone .svg
                    # for every 13th glyph; this looks like a leftover from
                    # the merged-HTML output. Kept to preserve file contents.
                    _render_svg_to_file(one_seq, syn_svg_outfile, append_br=(i % 13 == 12))

                # write results w/ parallel refinement
                for i, one_seq in tqdm(enumerate(sampled_svg_2.clone().detach())):
                    syn_svg_outfile = os.path.join(dir_svgs_single, f"syn_{i:02d}_{sample_idx}_refined.svg")
                    if not _render_svg_to_file(one_seq, syn_svg_outfile):
                        # Nothing to rasterise or score for this glyph.
                        continue

                    # Swap only the extension; str.replace would also touch
                    # any '.svg' occurring elsewhere in the path.
                    syn_img_outfile = os.path.splitext(syn_svg_outfile)[0] + '.png'
                    svg2img(syn_svg_outfile, syn_img_outfile, img_size=opts.img_size)
                    # cal_iou returns two values; only the first is used here.
                    iou_tmp, _ = cal_iou(syn_img_outfile, os.path.join(dir_imgs, f"{i:02d}_{opts.img_size}.png"))
                    if iou_tmp > iou_max[i]:
                        iou_max[i] = iou_tmp
                        idx_best_sample[i] = sample_idx

            merge_path = os.path.join(svg_merge_dir, f"{opts.name_ckpt}_syn_merge_{test_idx}.html")
            with open(merge_path, 'w') as syn_svg_merge_f:
                # First the best refined sample per character index ...
                for i in tqdm(range(opts.char_num)):
                    syn_svg_outfile_best = os.path.join(
                        dir_svgs_single, f"syn_{i:02d}_{int(idx_best_sample[i])}_refined.svg")
                    with open(syn_svg_outfile_best, 'r') as best_f:
                        syn_svg_merge_f.write(best_f.read())
                    # Line break after every 13 glyphs.
                    if i % 13 == 12:
                        syn_svg_merge_f.write('<br>')

                # ... then the rendered target sequences from the last sample.
                svg_target = trg_seq_gt.clone().detach()
                # NOTE(review): squeeze() without a dim also drops any other
                # size-1 dimension (e.g. a single-sequence batch) — confirm.
                tgt_commands_onehot = F.one_hot(svg_target[:, :, :1].long(), 4).squeeze()
                tgt_args_denum = denumericalize(svg_target[:, :, 1:])
                svg_target = torch.cat([tgt_commands_onehot, tgt_args_denum], dim=-1)

                for i, one_gt_seq in enumerate(svg_target):
                    syn_svg_merge_f.write(render(one_gt_seq.cpu().numpy()))
                    if i % 13 == 12:
                        syn_svg_merge_f.write('<br>')

    return im
def main():
    """Parse command-line options and run ``test_main_model`` on them."""
    opts = get_parser_main_model().parse_args()
    # The experiment name is extended with the model name before use.
    opts.name_exp = opts.name_exp + '_' + opts.model_name
    print(f"Testing on experiment {opts.name_exp}...")
    test_main_model(opts)
# Run inference only when executed as a script, not when imported.
if __name__ == "__main__":
    main()