File size: 3,330 Bytes
47af768
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
"""
Non-Overlap: Take a set of raw detections, remove redundant masks with mask NMS, and produce a set of non-overlapping detections from the result.

Author: Jonathon Luiten
"""

import os
import sys
from multiprocessing.pool import Pool
from multiprocessing import freeze_support

sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..')))
from trackeval.baselines import baseline_utils as butils
from trackeval.utils import get_code_path

code_path = get_code_path()

# Script settings. The two folder paths carry a '{split}' placeholder that is
# filled in with SPLIT wherever they are used.
config = dict(
    INPUT_FOL=os.path.join(code_path, 'data/detections/rob_mots/{split}/raw_supplied/data/'),
    OUTPUT_FOL=os.path.join(code_path, 'data/detections/rob_mots/{split}/non_overlap_supplied/data/'),
    SPLIT='train',  # valid: 'train', 'val', 'test'.
    Benchmarks=None,  # If None, all benchmarks in SPLIT.

    Num_Parallel_Cores=None,  # If None, run without parallel.

    # Mask-IoU threshold handed to butils.mask_NMS as `nms_threshold`.
    THRESHOLD_NMS_MASK_IOU=0.5,
)


def do_sequence(seq_file):
    """Run mask NMS followed by non-overlap on one sequence file and write the result.

    The output path is derived by substituting the formatted
    ``config['OUTPUT_FOL']`` for the formatted ``config['INPUT_FOL']`` inside
    ``seq_file`` (both formatted with ``config['SPLIT']``).

    Args:
        seq_file: Path to a raw detection file located under the input folder.

    Raises:
        ValueError: If the derived output path is identical to ``seq_file``
            (the input folder does not occur in ``seq_file``, or the input and
            output folders are the same). Writing would otherwise overwrite
            the input detections.
    """
    in_root = config['INPUT_FOL'].format(split=config['SPLIT'])
    out_root = config['OUTPUT_FOL'].format(split=config['SPLIT'])

    # Resolve (and validate) the destination up front so a misconfiguration
    # fails fast instead of clobbering the source file after all the work.
    out_file = seq_file.replace(in_root, out_root)
    if out_file == seq_file:
        raise ValueError(
            'Refusing to overwrite input file %r: output path could not be derived '
            'from INPUT_FOL=%r and OUTPUT_FOL=%r.' % (seq_file, in_root, out_root))

    nms_threshold = config['THRESHOLD_NMS_MASK_IOU']

    # Load input data from file (e.g. provided detections)
    # data format: data['cls'][t] = {'ids', 'scores', 'im_hs', 'im_ws', 'mask_rles'}
    data = butils.load_seq(seq_file)

    # Converts data from a class-separated to a class-combined format.
    # data[t] = {'ids', 'scores', 'im_hs', 'im_ws', 'mask_rles', 'cls'}
    data = butils.combine_classes(data)

    # Rows accumulated for writing out: [timestep ID class score im_h im_w mask_RLE]
    output_data = []

    for timestep, t_data in enumerate(data):
        # Remove redundant masks by performing non-maximum suppression (NMS).
        t_data = butils.mask_NMS(t_data, nms_threshold=nms_threshold)

        # Perform non-overlap, to get non-overlapping masks. already_sorted=True
        # relies on mask_NMS leaving detections in the order non_overlap expects
        # (NOTE(review): presumably descending score — confirm in baseline_utils).
        t_data = butils.non_overlap(t_data, already_sorted=True)

        # The per-detection fields are parallel sequences; walk them together.
        for det_id, det_cls, score, im_h, im_w, mask_rle in zip(
                t_data['ids'], t_data['cls'], t_data['scores'],
                t_data['im_hs'], t_data['im_ws'], t_data['mask_rles']):
            output_data.append([timestep, int(det_id), det_cls, score, im_h, im_w, mask_rle])

    # Write results to file
    butils.write_seq(output_data, out_file)

    print('DONE:', seq_file)


if __name__ == '__main__':

    # Required to fix bug in multiprocessing on windows.
    freeze_support()

    # Obtain list of benchmarks to process.
    if config['Benchmarks']:
        benchmarks = config['Benchmarks']
        # A single benchmark given as a plain string would otherwise be
        # iterated character by character below.
        if isinstance(benchmarks, str):
            benchmarks = [benchmarks]
    else:
        benchmarks = ['davis_unsupervised', 'kitti_mots', 'youtube_vis', 'ovis', 'bdd_mots', 'tao']
        if config['SPLIT'] != 'train':
            benchmarks += ['waymo', 'mots_challenge']

    # Collect every sequence file of every selected benchmark.
    input_root = config['INPUT_FOL'].format(split=config['SPLIT'])
    seqs_todo = []
    for bench in benchmarks:
        bench_fol = os.path.join(input_root, bench)
        # os.listdir returns entries in arbitrary order; sort for a
        # reproducible processing order.
        seqs_todo += [os.path.join(bench_fol, seq) for seq in sorted(os.listdir(bench_fol))]

    # Run in parallel (do_sequence writes its own output; nothing is returned).
    if config['Num_Parallel_Cores']:
        with Pool(config['Num_Parallel_Cores']) as pool:
            pool.map(do_sequence, seqs_todo)

    # Run in series
    else:
        for seq_todo in seqs_todo:
            do_sequence(seq_todo)