Spaces:
Build error
Build error
""" | |
Thresholder | |
Author: Jonathon Luiten | |
Simply reads in a set of detections, thresholds them at a given score threshold, and writes them out again.
""" | |
import os | |
import sys | |
from multiprocessing.pool import Pool | |
from multiprocessing import freeze_support | |
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..'))) | |
from trackeval.baselines import baseline_utils as butils | |
from trackeval.utils import get_code_path | |
# Score threshold handed to butils.threshold (via config['DETECTION_THRESHOLD']).
THRESHOLD = 0.2

code_path = get_code_path()

# The output folder name embeds 100 * THRESHOLD. It is rounded before being turned into a
# string so float noise (e.g. 100 * 0.29 == 28.999999999999996) cannot leak into the path.
# For THRESHOLD = 0.2 this still produces 'threshold_20.0'.
config = {
    'INPUT_FOL': os.path.join(code_path, 'data/detections/rob_mots/{split}/non_overlap_supplied/data/'),
    'OUTPUT_FOL': os.path.join(code_path, 'data/detections/rob_mots/{split}/threshold_'
                               + str(round(100 * THRESHOLD, 6)) + '/data/'),
    'SPLIT': 'train',  # valid: 'train', 'val', 'test'.
    'Benchmarks': None,  # If None, all benchmarks in SPLIT.
    'Num_Parallel_Cores': None,  # If None, run without parallel.
    'DETECTION_THRESHOLD': THRESHOLD,
}
def do_sequence(seq_file):
    """Threshold the detections of one sequence file and write the result out.

    The sequence is processed in three steps:

    1. Load it with ``butils.load_seq``.
    2. Pass each timestep of each class through ``butils.threshold`` with
       ``config['DETECTION_THRESHOLD']``.
    3. Write the remaining detections with ``butils.write_seq``. The output path is
       built by replacing the split-formatted ``config['INPUT_FOL']`` inside
       ``seq_file`` with the split-formatted ``config['OUTPUT_FOL']``.

    Args:
        seq_file: Path of an input sequence file, expected to lie under
            ``config['INPUT_FOL']`` for ``config['SPLIT']``.
    """
    # Load input data from file (e.g. provided detections)
    # data format: data['cls'][t] = {'ids', 'scores', 'im_hs', 'im_ws', 'mask_rles'}
    data = butils.load_seq(seq_file)

    # Rows accumulated in output format: [timestep ID class score im_h im_w mask_RLE]
    output_data = []
    for cls, cls_data in data.items():
        for timestep, t_data in enumerate(cls_data):
            kept = butils.threshold(t_data, config['DETECTION_THRESHOLD'])
            # The per-detection fields are parallel sequences; walk them in lockstep.
            # NOTE(review): zip stops at the shortest field -- assumes every field has
            # one entry per id, which the previous index-based loop also relied on.
            for obj_id, score, im_h, im_w, mask_rle in zip(
                    kept['ids'], kept['scores'], kept['im_hs'], kept['im_ws'], kept['mask_rles']):
                output_data.append([timestep, int(obj_id), cls, score, im_h, im_w, mask_rle])

    # Mirror the input path into the output folder tree.
    in_fol = config['INPUT_FOL'].format(split=config['SPLIT'])
    out_fol = config['OUTPUT_FOL'].format(split=config['SPLIT'])
    out_file = seq_file.replace(in_fol, out_fol)
    butils.write_seq(output_data, out_file)
    # Report this call's own argument. The previous code printed the module-level loop
    # variable `seq_todo`, which is undefined in Pool workers and raised NameError there.
    print('DONE:', seq_file)
if __name__ == '__main__':
    # Required to fix bug in multiprocessing on windows.
    freeze_support()

    # Obtain list of benchmarks to process: the configured ones, or the default set for SPLIT.
    if config['Benchmarks']:
        benchmarks = config['Benchmarks']
    else:
        benchmarks = ['davis_unsupervised', 'kitti_mots', 'youtube_vis', 'ovis', 'bdd_mots', 'tao']
        if config['SPLIT'] != 'train':
            benchmarks += ['waymo', 'mots_challenge']

    # Collect every sequence file under each benchmark's input folder. Sorting makes the
    # processing order deterministic (os.listdir order is arbitrary).
    input_root = config['INPUT_FOL'].format(split=config['SPLIT'])
    seqs_todo = []
    for bench in benchmarks:
        bench_fol = os.path.join(input_root, bench)
        seqs_todo += [os.path.join(bench_fol, seq) for seq in sorted(os.listdir(bench_fol))]

    if config['Num_Parallel_Cores']:
        # Run in parallel. do_sequence writes its own output file, so map's results are discarded.
        with Pool(config['Num_Parallel_Cores']) as pool:
            pool.map(do_sequence, seqs_todo)
    else:
        # Run in series.
        for seq_todo in seqs_todo:
            do_sequence(seq_todo)