"""
Thresholder

Author: Jonathon Luiten

Simply reads in a set of detections, thresholds them at a given score, and writes them out again.
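
Detections are loaded per sequence with baseline_utils.load_seq, and the thresholded results are
written back out as rows of [timestep, id, class, score, im_h, im_w, mask_RLE].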
"""

import os
import sys
from multiprocessing.pool import Pool
from multiprocessing import freeze_support

sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..')))
from trackeval.baselines import baseline_utils as butils
from trackeval.utils import get_code_path

THRESHOLD = 0.2  # Detections with a score below this value are discarded.

code_path = get_code_path()
config = {
    'INPUT_FOL': os.path.join(code_path, 'data/detections/rob_mots/{split}/non_overlap_supplied/data/'),
    'OUTPUT_FOL': os.path.join(code_path, 'data/detections/rob_mots/{split}/threshold_' + str(100*THRESHOLD) + '/data/'),
    'SPLIT': 'train',  # valid: 'train', 'val', 'test'.
    'Benchmarks': None,  # If None, all benchmarks in SPLIT.

    'Num_Parallel_Cores': None,  # If None, run in series (no parallelism).

    'DETECTION_THRESHOLD': THRESHOLD,
}


def do_sequence(seq_file):
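    """Threshold all detections in a single sequence file and write the result to the output folder."""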

    # Load input data from file (e.g. provided detections)
    # data format: data['cls'][t] = {'ids', 'scores', 'im_hs', 'im_ws', 'mask_rles'}
    data = butils.load_seq(seq_file)

    # Where to accumulate output data for writing out
    output_data = []

    # Run for each class.
    for cls, cls_data in data.items():

        # Run for each timestep.
        for timestep, t_data in enumerate(cls_data):

            # Threshold detections.
            t_data = butils.threshold(t_data, config['DETECTION_THRESHOLD'])

            # Save result in output format to write to file later.
            # Output Format = [timestep ID class score im_h im_w mask_RLE]
            for i in range(len(t_data['ids'])):
                row = [timestep, int(t_data['ids'][i]), cls, t_data['scores'][i], t_data['im_hs'][i],
                       t_data['im_ws'][i], t_data['mask_rles'][i]]
                output_data.append(row)

    # Write results to file, mirroring the input path under the output folder.
    out_file = seq_file.replace(config['INPUT_FOL'].format(split=config['SPLIT']),
                                config['OUTPUT_FOL'].format(split=config['SPLIT']))
    butils.write_seq(output_data, out_file)

    print('DONE:', seq_file)


if __name__ == '__main__':

    # Required to fix a bug in multiprocessing on Windows.
    freeze_support()

    # Obtain the list of sequences to process.
    if config['Benchmarks']:
        benchmarks = config['Benchmarks']
    else:
        benchmarks = ['davis_unsupervised', 'kitti_mots', 'youtube_vis', 'ovis', 'bdd_mots', 'tao']
        if config['SPLIT'] != 'train':
            benchmarks += ['waymo', 'mots_challenge']
    seqs_todo = []
    for bench in benchmarks:
        bench_fol = os.path.join(config['INPUT_FOL'].format(split=config['SPLIT']), bench)
        seqs_todo += [os.path.join(bench_fol, seq) for seq in os.listdir(bench_fol)]

    # Run in parallel
    if config['Num_Parallel_Cores']:
        with Pool(config['Num_Parallel_Cores']) as pool:
            results = pool.map(do_sequence, seqs_todo)

    # Run in series
    else:
        for seq_todo in seqs_todo:
            do_sequence(seq_todo)