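"""Evaluate a trained text recognition model on long-text LMDB datasets.

Forces long-text settings (max_text_length / decoder max_len of 200 and the
RecMetricLong metric), evaluates each dataset directory, and writes overall
and per-length accuracy / normalized edit-distance results to a CSV file
under Global.output_dir.
"""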
import csv
import os
import sys

__dir__ = os.path.dirname(os.path.abspath(__file__))

sys.path.append(__dir__)
sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '..')))

import numpy as np

from tools.data import build_dataloader
from tools.engine import Config, Trainer
from tools.utility import ArgsParser


def parse_args():
    parser = ArgsParser()
    args = parser.parse_args()
    return args


def main():
    FLAGS = parse_args()
    cfg = Config(FLAGS.config)
    FLAGS = vars(FLAGS)
    opt = FLAGS.pop('opt')
    cfg.merge_dict(FLAGS)
    cfg.merge_dict(opt)

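    # Long-text evaluation overrides: disable AMP, raise the max text length
    # and decoder length to 200, and switch to the long-text metric. If no
    # pretrained model is given, fall back to the best checkpoint saved in the
    # output directory.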
    cfg.cfg['Global']['use_amp'] = False
    if cfg.cfg['Global']['output_dir'][-1] == '/':
        cfg.cfg['Global']['output_dir'] = cfg.cfg['Global']['output_dir'][:-1]
    cfg.cfg['Global']['max_text_length'] = 200
    cfg.cfg['Architecture']['Decoder']['max_len'] = 200
    cfg.cfg['Metric']['name'] = 'RecMetricLong'
    if cfg.cfg['Global']['pretrained_model'] is None:
        cfg.cfg['Global'][
            'pretrained_model'] = cfg.cfg['Global']['output_dir'] + '/best.pth'
    trainer = Trainer(cfg, mode='eval')

    best_model_dict = trainer.status.get('metrics', {})
    trainer.logger.info('metric in ckpt ***************')
    for k, v in best_model_dict.items():
        trainer.logger.info('{}:{}'.format(k, v))

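    # LMDB dataset directories to evaluate; each inner list is evaluated as
    # one group and summarized together in the CSV.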
    data_dirs_list = [
        ['../ltb/long_lmdb'],
    ]

    cfg = cfg.cfg
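    # One result CSV for the whole run, named after the output directory.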
    output_dir = cfg['Global']['output_dir']
    file_csv = open(
        os.path.join(
            output_dir, os.path.basename(output_dir) +
            '_result1_1_test_all_long_final_ultra_bs1.csv'), 'w')
    csv_w = csv.writer(file_csv)

    for data_dirs in data_dirs_list:
        acc_each = []
        acc_each_num = []
        acc_each_dis = []
        each_long = {}
        for datadir in data_dirs:
            config_each = cfg.copy()

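            # Evaluate on this dataset only: point the Eval dataset at the
            # current directory, rebuild the dataloader, and run trainer.eval().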
            config_each['Eval']['dataset']['data_dir_list'] = [datadir]
            valid_dataloader = build_dataloader(config_each, 'Eval',
                                                trainer.logger)
            trainer.logger.info(
                f'{datadir} valid dataloader has {len(valid_dataloader)} iters'
            )
            trainer.valid_dataloader = valid_dataloader
            metric = trainer.eval()
            acc_each.append(metric['acc'] * 100)
            acc_each_dis.append(metric['norm_edit_dis'])
            acc_each_num.append(metric['all_num'])

            trainer.logger.info('metric eval ***************')
            for k, v in metric.items():
                trainer.logger.info('{}:{}'.format(k, v))
                if 'each' in k:
                    csv_w.writerow([k] + v[26:])
                    each_long[k] = each_long.get(k, []) + [np.array(v[26:])]
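        # Summary rows for the group: accuracy and normalized edit distance,
        # both sample-weighted over datasets and as a simple mean.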
        avg1 = np.array(acc_each) * np.array(acc_each_num) / sum(acc_each_num)
        csv_w.writerow(acc_each + [avg1.sum().tolist()] +
                       [sum(acc_each) / len(acc_each)])
        print(acc_each + [avg1.sum().tolist()] +
              [sum(acc_each) / len(acc_each)])
        avg1 = np.array(acc_each_dis) * np.array(acc_each_num) / sum(
            acc_each_num)
        csv_w.writerow(acc_each_dis + [avg1.sum().tolist()] +
                       [sum(acc_each_dis) / len(acc_each_dis)])

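        # Aggregate the per-length ('each_*') statistics across datasets,
        # weighting every length bucket by its sample count and guarding
        # against empty buckets. The rows written below are labelled for
        # lengths 26-35, 36-55, and 56+.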
        sum_all = np.array(each_long['each_len_num']).sum(0)
        for k, v in each_long.items():
            if k != 'each_len_num':
                v_sum_weight = (np.array(v) *
                                np.array(each_long['each_len_num'])).sum(0)
                sum_all_pad = np.where(sum_all == 0, 1., sum_all)
                v_all = v_sum_weight / sum_all_pad
                v_all = np.where(sum_all == 0, 0., v_all)
                csv_w.writerow([k] + v_all.tolist())
                v_26_35 = (v_all[:10] * sum_all[:10]) / sum_all[:10].sum()
                csv_w.writerow([k + '26_35'] + [v_26_35.sum().tolist()] +
                               [sum_all[:10].sum().tolist()])
                v_36_55 = (v_all[10:30] *
                           sum_all[10:30]) / sum_all[10:30].sum()
                csv_w.writerow([k + '36_55'] + [v_36_55.sum().tolist()] +
                               [sum_all[10:30].sum().tolist()])
                v_56_plus = (v_all[30:] * sum_all[30:]) / sum_all[30:].sum()
                csv_w.writerow([k + '56'] + [v_56_plus.sum().tolist()] +
                               [sum_all[30:].sum().tolist()])
            else:
                csv_w.writerow([k] + sum_all.tolist())
    file_csv.close()


if __name__ == '__main__':
    main()