xfys's picture
Upload 645 files
47af768
raw
history blame
3.33 kB
"""
Non-Overlap: Code to take in a set of raw detections and produce a set of non-overlapping detections from it.
Author: Jonathon Luiten
"""
import os
import sys
from multiprocessing.pool import Pool
from multiprocessing import freeze_support
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..')))
from trackeval.baselines import baseline_utils as butils
from trackeval.utils import get_code_path
code_path = get_code_path()

# Settings for this script. The '{split}' placeholder in the two folder templates
# is filled in with config['SPLIT'] wherever the folders are used.
config = dict(
    INPUT_FOL=os.path.join(code_path, 'data/detections/rob_mots/{split}/raw_supplied/data/'),
    OUTPUT_FOL=os.path.join(code_path, 'data/detections/rob_mots/{split}/non_overlap_supplied/data/'),
    SPLIT='train',  # One of 'train', 'val', 'test'.
    Benchmarks=None,  # Benchmark names to process; None means all benchmarks in SPLIT.
    Num_Parallel_Cores=None,  # Worker-process count for the pool; None runs sequentially.
    THRESHOLD_NMS_MASK_IOU=0.5,  # Mask-IoU threshold handed to butils.mask_NMS.
)
def do_sequence(seq_file):
    """Apply mask NMS and non-overlap to one sequence file and write the result.

    The file is loaded with ``butils.load_seq`` and converted to one record per
    timestep with ``butils.combine_classes``. Each timestep is then passed
    through ``butils.mask_NMS`` (threshold ``config['THRESHOLD_NMS_MASK_IOU']``)
    and ``butils.non_overlap``. The remaining detections are written with
    ``butils.write_seq`` to the path obtained by swapping the configured input
    folder for the configured output folder in ``seq_file``.

    Args:
        seq_file: Path of a detection file located under
            ``config['INPUT_FOL'].format(split=config['SPLIT'])``.

    Raises:
        ValueError: If ``seq_file`` does not map to a different output path
            (it does not contain the input folder, or the input and output
            folders are identical). Writing would otherwise overwrite the
            input file in place.
    """
    in_fol = config['INPUT_FOL'].format(split=config['SPLIT'])
    out_fol = config['OUTPUT_FOL'].format(split=config['SPLIT'])
    # Derive the output path up front. str.replace returns seq_file unchanged
    # when in_fol is absent, and writing there would clobber the input detections.
    out_file = seq_file.replace(in_fol, out_fol)
    if out_file == seq_file:
        raise ValueError(
            'Cannot derive a distinct output path for {!r} (input folder {!r}, '
            'output folder {!r}); refusing to overwrite the input file.'.format(
                seq_file, in_fol, out_fol))

    # Load input data from file (e.g. provided detections)
    # data format: data['cls'][t] = {'ids', 'scores', 'im_hs', 'im_ws', 'mask_rles'}
    data = butils.load_seq(seq_file)

    # Convert data from a class-separated to a class-combined format.
    # data[t] = {'ids', 'scores', 'im_hs', 'im_ws', 'mask_rles', 'cls'}
    data = butils.combine_classes(data)

    # Rows of [timestep ID class score im_h im_w mask_RLE], accumulated for writing out.
    output_data = []

    for timestep, t_data in enumerate(data):
        # Remove redundant masks by performing non-maximum suppression (NMS).
        t_data = butils.mask_NMS(t_data, nms_threshold=config['THRESHOLD_NMS_MASK_IOU'])
        # Make the remaining masks non-overlapping. already_sorted=True presumably
        # relies on mask_NMS returning detections in score order -- TODO confirm.
        t_data = butils.non_overlap(t_data, already_sorted=True)

        # The per-detection fields are parallel sequences; walk them in lockstep.
        for det_id, cls, score, im_h, im_w, mask_rle in zip(
                t_data['ids'], t_data['cls'], t_data['scores'],
                t_data['im_hs'], t_data['im_ws'], t_data['mask_rles']):
            output_data.append([timestep, int(det_id), cls, score, im_h, im_w, mask_rle])

    butils.write_seq(output_data, out_file)
    print('DONE:', seq_file)
if __name__ == '__main__':
    # freeze_support() only matters when a multiprocessing program has been frozen
    # into a Windows executable; in all other cases it has no effect.
    freeze_support()

    # Obtain list of benchmarks whose sequences should be processed.
    if config['Benchmarks']:
        benchmarks = config['Benchmarks']
    else:
        benchmarks = ['davis_unsupervised', 'kitti_mots', 'youtube_vis', 'ovis', 'bdd_mots', 'tao']
        # waymo and mots_challenge are only included for non-train splits.
        if config['SPLIT'] != 'train':
            benchmarks += ['waymo', 'mots_challenge']

    input_fol = config['INPUT_FOL'].format(split=config['SPLIT'])
    seqs_todo = []
    for bench in benchmarks:
        bench_fol = os.path.join(input_fol, bench)
        # os.listdir returns entries in arbitrary order; sort so every run
        # processes (and reports) sequences in the same order.
        seqs_todo += [os.path.join(bench_fol, seq) for seq in sorted(os.listdir(bench_fol))]

    if config['Num_Parallel_Cores']:
        # Run in parallel. pool.map propagates exceptions raised in the workers;
        # its return value (do_sequence returns None) is not needed.
        with Pool(config['Num_Parallel_Cores']) as pool:
            pool.map(do_sequence, seqs_todo)
    else:
        # Run in series.
        for seq_todo in seqs_todo:
            do_sequence(seq_todo)