# Provenance (copied from the hosting page): uploaded by xfys in "Upload 645 files",
# commit 47af768, 2.73 kB.
""" Regression test for TrackEval on the MOT Challenge MOT17 training set.
Runs the HOTA, CLEAR, Identity and Count metrics on 14 trackers and checks every
detailed result field against stored reference values.
"""
import sys
import os
import numpy as np
from multiprocessing import freeze_support
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
import trackeval # noqa: E402
# Fixes multiprocessing on windows, does nothing otherwise
if __name__ == '__main__':
    freeze_support()

    # Regression check: evaluate each tracker on MOT17 and compare every detailed metric
    # field against the stored reference CSV. Raises AssertionError on the first mismatch.
    eval_config = {'USE_PARALLEL': False,
                   'NUM_PARALLEL_CORES': 8,
                   }
    evaluator = trackeval.Evaluator(eval_config)
    metrics_list = [trackeval.metrics.HOTA(), trackeval.metrics.CLEAR(), trackeval.metrics.Identity()]
    # Reference results are read from <test_data_loc>/<tracker>/pedestrian_detailed.csv.
    test_data_loc = os.path.join(os.path.dirname(__file__), '..', 'data', 'tests', 'mot_challenge', 'MOT17-train')
    trackers = [
        'DPMOT',
        'GNNMatch',
        'IA',
        'ISE_MOT17R',
        'Lif_T',
        'Lif_TsimInt',
        'LPC_MOT',
        'MAT',
        'MIFTv2',
        'MPNTrack',
        'SSAT',
        'TracktorCorr',
        'Tracktorv2',
        'UnsupTrack',
    ]

    for tracker in trackers:
        # Run code on tracker
        dataset_config = {'TRACKERS_TO_EVAL': [tracker],
                          'BENCHMARK': 'MOT17'}
        dataset_list = [trackeval.datasets.MotChallenge2DBox(dataset_config)]
        raw_results, _messages = evaluator.evaluate(dataset_list, metrics_list)

        # Keep only the 'pedestrian' class results, keyed by sequence name.
        tracker_results = raw_results['MotChallenge2DBox'][tracker]
        results = {seq: seq_res['pedestrian'] for seq, seq_res in tracker_results.items()}

        # NOTE(review): Count is not passed to evaluate() above, yet its metric name is looked
        # up in `results` below -- presumably the evaluator always adds Count; confirm.
        current_metrics_list = metrics_list + [trackeval.metrics.Count()]
        metric_names = trackeval.utils.validate_metrics_list(current_metrics_list)

        # Each metric's detailed table covers all sequences and does not depend on which
        # sequence is being checked, so build it once per tracker.
        # NOTE(review): assumes detailed_results() has no side effects on its input.
        details = []
        for metric, metric_name in zip(current_metrics_list, metric_names):
            table_res = {seq_key: seq_value[metric_name] for seq_key, seq_value in results.items()}
            details.append(metric.detailed_results(table_res))

        # Load expected results:
        expected_path = os.path.join(test_data_loc, tracker, 'pedestrian_detailed.csv')
        test_data = trackeval.utils.load_detail(expected_path)

        # Explicit raises (rather than `assert`) so these checks still run under `python -O`.
        # NOTE(review): 22 is presumably the 21 MOT17 train sequences plus a combined row; confirm.
        if len(test_data) != 22:
            raise AssertionError('%s: expected 22 sequences in %s, found %d'
                                 % (tracker, expected_path, len(test_data)))

        # Do checks
        for seq, expected_fields in test_data.items():
            if len(expected_fields) <= 250:
                raise AssertionError('%s / %s: expected more than 250 reference fields, found %d'
                                     % (tracker, seq, len(expected_fields)))

            # Merge every metric's values for this sequence into one field -> value mapping,
            # using each row's own field names.
            res_dict = {field: value for s in details for field, value in s[seq].items()}

            for field, expected in expected_fields.items():
                if field not in res_dict:
                    raise AssertionError('%s / %s: field %s missing from computed results'
                                         % (tracker, seq, field))
                if not np.isclose(res_dict[field], expected):
                    raise AssertionError('%s / %s / %s: computed %r, expected %r'
                                         % (tracker, seq, field, res_dict[field], expected))

        print('Tracker %s tests passed' % tracker)
    print('All tests passed')