# NOTE: "Spaces: Sleeping" is a page-status banner captured with this listing; it is not part of the script.
import os

import numpy as np
import torch
import torch.nn.functional as F
from torchvision.utils import save_image

from data_utils.svg_utils import render
from dataloader import get_loader
from models.model_main import ModelMain
from models.transformers import denumericalize
from models.util_funcs import svg2img, cal_iou
from options import get_parser_main_model
# Testing (only averaged losses are reported)
def test_main_model(opts):
    """Evaluate a trained main model on the 'test' split and print mean losses.

    The checkpoint at ``opts.model_path`` is loaded into a freshly built
    ``ModelMain``, every batch of the test loader is run through the model in
    ``'val'`` mode, and the per-batch ``'img'`` and ``'svg'`` losses are
    averaged over the number of batches and printed.

    Args:
        opts: Parsed options object. This function reads ``data_root``,
            ``img_size``, ``language``, ``char_num``, ``max_seq_len``,
            ``dim_seq``, ``batch_size`` and ``model_path`` from it, and
            passes the whole object to ``ModelMain``.

    Raises:
        ValueError: If the test loader contains no batches, since a mean
            loss cannot be computed.
    """
    test_loader = get_loader(opts.data_root, opts.img_size, opts.language, opts.char_num,
                             opts.max_seq_len, opts.dim_seq, opts.batch_size, 'test')

    # Fail early with a clear message instead of a ZeroDivisionError when
    # averaging at the end.
    num_batches = len(test_loader)
    if num_batches == 0:
        raise ValueError("Test loader is empty: cannot compute average losses.")

    model_main = ModelMain(opts)
    path_ckpt = str(opts.model_path)
    # Load onto the CPU first so the checkpoint does not depend on the CUDA
    # device it was saved from; the model is moved to the GPU afterwards.
    checkpoint = torch.load(path_ckpt, map_location='cpu')
    model_main.load_state_dict(checkpoint['model'])
    model_main.cuda()
    model_main.eval()  # Testing mode

    # Running sums of each loss term; divided by the batch count below.
    loss_val = {
        'img': {'l1': 0.0, 'vggpt': 0.0},
        'svg': {'total': 0.0, 'cmd': 0.0, 'args': 0.0, 'aux': 0.0},
    }

    with torch.no_grad():
        for val_data in test_loader:
            # Move every entry of the batch to the GPU in place.
            for key in val_data:
                val_data[key] = val_data[key].cuda()
            # Only the loss dictionary is needed; the first return value is unused.
            _, loss_dict_val = model_main(val_data, mode='val')
            for loss_cat, totals in loss_val.items():
                for key in totals:
                    totals[key] += loss_dict_val[loss_cat][key]

        for totals in loss_val.values():
            for key in totals:
                totals[key] /= num_batches

    val_msg = (
        f"Val loss img l1: {loss_val['img']['l1']: .6f}, "
        f"Val loss img pt: {loss_val['img']['vggpt']: .6f}, "
        f"Val loss total: {loss_val['svg']['total']: .6f}, "
        f"Val loss cmd: {loss_val['svg']['cmd']: .6f}, "
        f"Val loss args: {loss_val['svg']['args']: .6f}, "
    )
    print(val_msg)
    print(f"l1: {loss_val['img']['l1']: .6f}, pt: {loss_val['img']['vggpt']: .6f}")
def main():
    """Parse command-line options and run the test pass for the experiment."""
    opts = get_parser_main_model().parse_args()
    # The experiment name is suffixed with the model name before it is reported.
    opts.name_exp = opts.name_exp + '_' + opts.model_name
    print(f"Testing on experiment {opts.name_exp}...")
    test_main_model(opts)
# Run the test pass only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()