File size: 3,631 Bytes
9390e2c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import json
import os
from collections import OrderedDict

import pkg_resources

# Paths
# Filesystem path of the 'data/annotations' resource shipped inside the
# `spiga` package; Evaluator loads ground truth from
# '<data_path>/<database>/<data_type>.json'.
# NOTE(review): pkg_resources is deprecated in recent setuptools releases;
# consider importlib.resources once the minimum supported Python allows it.
data_path = pkg_resources.resource_filename('spiga', 'data/annotations')

def main():
    """Command-line entry point: compute benchmark metrics for prediction files.

    Every positional argument is a prediction json file; each one gets its own
    evaluator built with the modes given via ``--eval``. A text report is
    written only when ``-s/--save`` is passed.
    """
    import argparse

    parser = argparse.ArgumentParser(description='Benchmark alignments evaluator')
    parser.add_argument('pred_file', nargs='+', type=str,
                        help='Absolute path to the prediction json file (Multi file)')
    parser.add_argument('--eval', nargs='+', type=str, default=['lnd'],
                        choices=['lnd', 'pose'], help='Evaluation modes')
    parser.add_argument('-s', '--save', action='store_true', help='Save results')
    cli = parser.parse_args()

    for prediction_path in cli.pred_file:
        get_evaluator(prediction_path, cli.eval, cli.save).metrics()


class Evaluator:
    """Score one prediction file against the packaged benchmark annotations.

    The prediction file name encodes the benchmark: its stem is split on '_'
    and the last two fields are taken as ``<database>_<data_type>`` (e.g.
    ``results_wflw_test.json`` -> database ``wflw``, data_type ``test``).
    Ground truth is loaded from ``<data_path>/<database>/<data_type>.json``.

    Each object in ``evals`` must provide a ``name`` attribute and the
    methods ``compute_error(anns, pred, database_ref, select_ids)``,
    ``get_pimg_err(error_pimg)`` and ``metrics()``.
    """

    def __init__(self, data_file, evals=(), save=True, process_err=True):
        """Load annotations and predictions and optionally compute errors.

        Args:
            data_file: Path to the prediction json file.
            evals: Metric objects to evaluate (see the class docstring).
            save: If True, ``metrics()`` writes
                ``metrics_<database>_<data_type>.txt`` next to ``data_file``.
            process_err: If True, call ``compute_error`` right away.

        Raises:
            ValueError: If the file name does not end in
                ``<database>_<data_type>``.
        """
        # Inputs
        self.data_file = data_file
        self.evals = evals
        self.save = save

        # Directory prefix (with trailing separator) and benchmark identifiers
        self.data_dir, self.database, self.data_type = self._split_data_file(data_file)

        # Load predictions and annotations
        anns_file = os.path.join(data_path, self.database, '%s.json' % self.data_type)
        self.anns = self.load_files(anns_file)
        self.pred = self.load_files(data_file)

        # Compute errors
        self.error = OrderedDict()
        self.error_pimg = OrderedDict()
        self.metrics_log = OrderedDict()
        if process_err:
            self.compute_error(self.anns, self.pred)

    @staticmethod
    def _split_data_file(data_file):
        """Split a prediction path into ``(data_dir, database, data_type)``.

        ``data_dir`` is everything before the file name, including the
        trailing separator ('' for a bare file name).

        Raises:
            ValueError: If the file stem has fewer than two '_' fields.
        """
        file_name = os.path.basename(data_file)
        # Slice instead of data_file.split(file_name): the file name may also
        # appear earlier in the path and split() would cut there.
        data_dir = data_file[:len(data_file) - len(file_name)]
        # splitext drops only the final extension, so dots inside the stem
        # (e.g. a model version such as 'v1.2') do not truncate it.
        fields = os.path.splitext(file_name)[0].split('_')
        if len(fields) < 2:
            raise ValueError(
                "Cannot infer database and data type from %r: the file name "
                "must end in '<database>_<data_type>'" % file_name)
        return data_dir, fields[-2], fields[-1]

    def compute_error(self, anns, pred, select_ids=None):
        """Run every metric's error computation.

        Args:
            anns: Loaded annotation data.
            pred: Loaded prediction data.
            select_ids: Optional selection forwarded unchanged to each metric.

        Returns:
            ``self.error``, mapping each metric name to its computed error.
            ``self.error_pimg`` is updated as a side effect.
        """
        database_ref = [self.database, self.data_type]
        for metric in self.evals:
            self.error[metric.name] = metric.compute_error(anns, pred, database_ref, select_ids)
            # Each metric receives the accumulated dict and returns the updated one.
            self.error_pimg = metric.get_pimg_err(self.error_pimg)
        return self.error

    def metrics(self):
        """Collect every metric's results into ``self.metrics_log``.

        When ``self.save`` is True the rendered report (``str(self)``) is
        written to ``metrics_<database>_<data_type>.txt`` in ``self.data_dir``.

        Returns:
            ``self.metrics_log``, mapping each metric name to its results.
        """
        for metric in self.evals:
            self.metrics_log[metric.name] = metric.metrics()

        if self.save:
            # os.path.join keeps a bare file name in the current directory
            # instead of producing an absolute '/metrics_...' path.
            file_name = os.path.join(
                self.data_dir, 'metrics_%s_%s.txt' % (self.database, self.data_type))
            with open(file_name, 'w', encoding='utf-8') as report:
                report.write(str(self))

        return self.metrics_log

    def load_files(self, input_file):
        """Return the parsed contents of the json file at ``input_file``."""
        with open(input_file, encoding='utf-8') as jsonfile:
            return json.load(jsonfile)

    def _dict2text(self, name, dictionary, num_tab=1):
        """Render ``dictionary`` as a brace-delimited, tab-indented block.

        Nested dicts are rendered recursively one tab deeper. The output
        format (including the closing brace indentation) is unchanged from
        the original implementation, so existing report files stay comparable.
        """
        indent = '\t' * num_tab
        parts = ['{} {{\n'.format(name)]
        for key, value in dictionary.items():
            if isinstance(value, dict):  # OrderedDict is a dict subclass
                parts.append(indent + self._dict2text(key, value, num_tab=num_tab + 1))
            else:
                parts.append('{}{}: {}\n'.format(indent, key, value))
        parts.append(indent + '}\n')
        return ''.join(parts)

    def __str__(self):
        """Render ``metrics_log`` as the text report written by ``metrics()``."""
        return self._dict2text('Metrics', self.metrics_log)


def get_evaluator(pred_file, evaluate=('lnd', 'pose'), save=False, process_err=True):
    """Build an Evaluator for ``pred_file`` with the requested metrics.

    Args:
        pred_file: Path to the prediction json file.
        evaluate: Iterable of evaluation modes: 'lnd' (landmark metrics)
            and/or 'pose' (head-pose metrics). A single mode may also be
            passed as a plain string.
        save: Forwarded to Evaluator; write the metrics report when True.
        process_err: Forwarded to Evaluator; compute errors on construction.

    Returns:
        The constructed Evaluator. Landmark metrics, when requested, are
        always listed before head-pose metrics.

    Raises:
        ValueError: If ``evaluate`` contains an unknown mode.
    """
    # Treat a bare string such as 'lnd' as one mode, not a sequence of chars.
    if isinstance(evaluate, str):
        evaluate = (evaluate,)
    # Materialise once so generators are not consumed by the membership tests.
    modes = set(evaluate)
    # Reject typos up front instead of silently building an evaluator that
    # computes nothing and reports an empty metrics block.
    unknown = modes - {'lnd', 'pose'}
    if unknown:
        raise ValueError('Unknown evaluation mode(s) %s; expected any of %s'
                         % (sorted(unknown, key=str), ['lnd', 'pose']))

    eval_list = []
    if 'lnd' in modes:
        import spiga.eval.benchmark.metrics.landmarks as mlnd
        eval_list.append(mlnd.MetricsLandmarks())
    if 'pose' in modes:
        import spiga.eval.benchmark.metrics.pose as mpose
        eval_list.append(mpose.MetricsHeadpose())

    return Evaluator(pred_file, evals=eval_list, save=save, process_err=process_err)


# Run the command-line evaluator when this module is executed as a script.
if __name__ == '__main__':
    main()