Spaces:
Sleeping
Sleeping
File size: 2,350 Bytes
b762e56 86e64e9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 |
import os
import numpy as np
import torch
import torch.nn.functional as F
from torchvision.utils import save_image
from dataloader import get_loader
from models.model_main import ModelMain
from models.transformers import denumericalize
from options import get_parser_main_model
from data_utils.svg_utils import render
from models.util_funcs import svg2img, cal_iou
def test_main_model(opts):
    """Evaluate a trained ModelMain checkpoint on the test split and print its losses.

    Builds the 'test' loader from ``opts`` and restores the checkpoint at
    ``opts.model_path``, using its ``'model'`` entry as the state dict. It then
    runs the model with ``mode='val'`` under ``torch.no_grad()``. Finally it
    prints the image losses (l1, vggpt) and the SVG losses (total, cmd, args),
    each averaged over the number of batches.

    Args:
        opts: parsed options. Reads data_root, img_size, language, char_num,
            max_seq_len, dim_seq, batch_size and model_path; the whole object
            is also handed to ``ModelMain``.

    Raises:
        ValueError: if the test loader has no batches, since the per-batch
            average would otherwise divide by zero.
    """
    test_loader = get_loader(opts.data_root, opts.img_size, opts.language, opts.char_num,
                             opts.max_seq_len, opts.dim_seq, opts.batch_size, 'test')
    num_batches = len(test_loader)
    if num_batches == 0:
        raise ValueError("test loader is empty; cannot average losses over zero batches")

    model_main = ModelMain(opts)
    # Restore onto CPU first, so loading does not depend on the GPU the checkpoint
    # was saved from. load_state_dict copies into the model's own tensors, and
    # the model is moved to CUDA immediately afterwards.
    checkpoint = torch.load(opts.model_path, map_location='cpu')
    model_main.load_state_dict(checkpoint['model'])
    model_main.cuda()
    model_main.eval()  # Testing mode

    # Running sums of the reported losses. Keys mirror loss_dict_val[category].
    loss_val = {
        'img': {'l1': 0.0, 'vggpt': 0.0},
        'svg': {'total': 0.0, 'cmd': 0.0, 'args': 0.0, 'aux': 0.0},
    }
    with torch.no_grad():
        for val_data in test_loader:
            for key in val_data:
                val_data[key] = val_data[key].cuda()
            _, loss_dict_val = model_main(val_data, mode='val')
            for loss_cat, totals in loss_val.items():
                for key in totals:
                    # float() accepts Python numbers and single-element tensors
                    # alike, and keeps the running sums as plain floats.
                    totals[key] += float(loss_dict_val[loss_cat][key])

    # Turn the sums into per-batch means.
    for totals in loss_val.values():
        for key in totals:
            totals[key] /= num_batches

    val_msg = (
        f"Val loss img l1: {loss_val['img']['l1']: .6f}, "
        f"Val loss img pt: {loss_val['img']['vggpt']: .6f}, "
        f"Val loss total: {loss_val['svg']['total']: .6f}, "
        f"Val loss cmd: {loss_val['svg']['cmd']: .6f}, "
        f"Val loss args: {loss_val['svg']['args']: .6f}, "
    )
    print(val_msg)
    print(f"l1: {loss_val['img']['l1']: .6f}, pt: {loss_val['img']['vggpt']: .6f}")
def main():
    """Parse the main-model options, tag the experiment name and run the test evaluation."""
    opts = get_parser_main_model().parse_args()
    # Suffix the experiment name with the model name. The same opts object is
    # passed to test_main_model, so the updated name is what it sees.
    opts.name_exp = opts.name_exp + '_' + opts.model_name
    print(f"Testing on experiment {opts.name_exp}...")
    test_main_model(opts)
# Run the evaluation only when executed as a script, not when imported.
if __name__ == "__main__":
    main()