File size: 3,495 Bytes
47af768
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
""" Test to ensure that the code is working correctly.
Should test ALL metrics across all datasets and splits currently supported.
Only tests one tracker per dataset/split to give a quick test result.
"""

import sys
import os
import numpy as np
from multiprocessing import freeze_support

sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
import trackeval  # noqa: E402

# Fixes multiprocessing on windows, does nothing otherwise
if __name__ == '__main__':
    freeze_support()

eval_config = {'USE_PARALLEL': False,
               'NUM_PARALLEL_CORES': 8,
               }
evaluator = trackeval.Evaluator(eval_config)
metrics_list = [trackeval.metrics.HOTA(), trackeval.metrics.CLEAR(), trackeval.metrics.Identity()]

tests = [
    {'DATASET': 'Kitti2DBox', 'SPLIT_TO_EVAL': 'training', 'TRACKERS_TO_EVAL': ['CIWT']},
    {'DATASET': 'MotChallenge2DBox', 'BENCHMARK': 'MOT15', 'SPLIT_TO_EVAL': 'train', 'TRACKERS_TO_EVAL': ['MPNTrack']},
    {'DATASET': 'MotChallenge2DBox', 'BENCHMARK': 'MOT16', 'SPLIT_TO_EVAL': 'train', 'TRACKERS_TO_EVAL': ['MPNTrack']},
    {'DATASET': 'MotChallenge2DBox', 'BENCHMARK': 'MOT17', 'SPLIT_TO_EVAL': 'train', 'TRACKERS_TO_EVAL': ['MPNTrack']},
    {'DATASET': 'MotChallenge2DBox', 'BENCHMARK': 'MOT20', 'SPLIT_TO_EVAL': 'train', 'TRACKERS_TO_EVAL': ['MPNTrack']},
]

for dataset_config in tests:

    dataset_name = dataset_config.pop('DATASET')
    if dataset_name == 'MotChallenge2DBox':
        dataset_list = [trackeval.datasets.MotChallenge2DBox(dataset_config)]
        file_loc = os.path.join('mot_challenge', dataset_config['BENCHMARK'] + '-' + dataset_config['SPLIT_TO_EVAL'])
    elif dataset_name == 'Kitti2DBox':
        dataset_list = [trackeval.datasets.Kitti2DBox(dataset_config)]
        file_loc = os.path.join('kitti', 'kitti_2d_box_train')
    else:
        raise Exception('Dataset %s does not exist.' % dataset_name)

    raw_results, messages = evaluator.evaluate(dataset_list, metrics_list)

    classes = dataset_list[0].config['CLASSES_TO_EVAL']
    tracker = dataset_config['TRACKERS_TO_EVAL'][0]
    test_data_loc = os.path.join(os.path.dirname(__file__), '..', 'data', 'tests', file_loc)

    for cls in classes:
        results = {seq: raw_results[dataset_name][tracker][seq][cls] for seq in raw_results[dataset_name][tracker].keys()}
        current_metrics_list = metrics_list + [trackeval.metrics.Count()]
        metric_names = trackeval.utils.validate_metrics_list(current_metrics_list)

        # Load expected results:
        test_data = trackeval.utils.load_detail(os.path.join(test_data_loc, tracker, cls + '_detailed.csv'))

        # Do checks
        for seq in test_data.keys():
            assert len(test_data[seq].keys()) > 250, len(test_data[seq].keys())

            details = []
            for metric, metric_name in zip(current_metrics_list, metric_names):
                table_res = {seq_key: seq_value[metric_name] for seq_key, seq_value in results.items()}
                details.append(metric.detailed_results(table_res))
            res_fields = sum([list(s['COMBINED_SEQ'].keys()) for s in details], [])
            res_values = sum([list(s[seq].values()) for s in details], [])
            res_dict = dict(zip(res_fields, res_values))

            for field in test_data[seq].keys():
                assert np.isclose(res_dict[field], test_data[seq][field]), seq + ': ' + cls + ': ' + field

    print('Tracker %s tests passed' % tracker)
print('All tests passed')