xfys's picture
Upload 645 files
47af768
raw
history blame
5.81 kB
import os
import csv
import argparse
from collections import OrderedDict
def init_config(config, default_config, name=None):
    """Fill in any settings missing from ``config`` with values from ``default_config``.

    If ``config`` is None, ``default_config`` itself (not a copy) is returned.
    Otherwise ``config`` is updated in place: every key of ``default_config`` that
    ``config`` lacks is copied over, and keys already present are left untouched.

    When ``name`` is truthy and the resulting config's ``'PRINT_CONFIG'`` entry is
    truthy, all settings are printed under a ``<name> Config:`` heading. In that
    case (``name`` given) the config must contain a ``'PRINT_CONFIG'`` key.

    :param config: user-supplied settings dict, or None
    :param default_config: dict of default settings
    :param name: optional label used when printing the config
    :return: the merged config
    """
    if config is not None:
        missing_keys = [key for key in default_config if key not in config]
        for key in missing_keys:
            config[key] = default_config[key]
    else:
        config = default_config

    should_print = name and config['PRINT_CONFIG']
    if should_print:
        print('\n%s Config:' % name)
        for key, value in config.items():
            print('%-20s : %-30s' % (key, value))
    return config
def update_config(config, argv=None):
    """
    Parse the arguments of a script and update the config values for any setting given on the command line.

    Every key of ``config`` becomes a ``--<key>`` option. Settings whose current value
    is a list or None accept one or more values, which are stored as a list of strings.
    All other settings accept a single string, converted according to the current
    value's type:

    * ``bool`` settings require exactly 'True' or 'False'.
    * ``int`` settings are parsed with ``int()``.
    * ``float`` settings are parsed with ``float()``.
    * anything else is stored as the raw string.

    Settings not given on the command line keep their existing value.

    :param config: the config to update (modified in place)
    :param argv: optional list of argument strings to parse; when None, argparse
        reads ``sys.argv[1:]`` as before
    :return: the updated config
    :raises Exception: if a bool setting is given a value other than 'True' or 'False'
    """
    parser = argparse.ArgumentParser()
    for setting, current in config.items():
        # dest=setting keeps the parsed attribute name identical to the config key,
        # even for keys containing '-' (argparse would otherwise rewrite them to '_').
        if isinstance(current, list) or current is None:
            parser.add_argument('--' + setting, dest=setting, nargs='+')
        else:
            parser.add_argument('--' + setting, dest=setting)
    args = vars(parser.parse_args(argv))
    for setting, value in args.items():
        if value is None:
            continue  # not given on the command line: keep the existing value
        current = config[setting]
        # bool must be tested before int because bool is a subclass of int.
        if isinstance(current, bool):
            if value == 'True':
                value = True
            elif value == 'False':
                value = False
            else:
                raise Exception('Command line parameter ' + setting + ' must be True or False')
        elif isinstance(current, int):
            value = int(value)
        elif isinstance(current, float):
            value = float(value)
        config[setting] = value
    return config
def get_code_path():
    """Return the absolute path of the directory one level above this module's directory.

    Presumably this is the root of the code base; confirm against the package layout.
    """
    module_dir = os.path.dirname(__file__)
    parent_dir = os.path.join(module_dir, '..')
    return os.path.abspath(parent_dir)
def validate_metrics_list(metrics_list):
    """Check that metric names, and the fields across all metrics, are unique.

    Each metric object must provide a ``get_name()`` method and a ``fields``
    iterable of field names.

    :param metrics_list: list of metric objects
    :return: list of metric names, in the same order as ``metrics_list``
    :raises TrackEvalException: if two metrics share a name, or if any field name
        appears more than once across all metrics
    """
    names = [metric.get_name() for metric in metrics_list]
    if len(set(names)) != len(names):
        raise TrackEvalException('Code being run with multiple metrics of the same name')

    # Flatten every metric's fields into one list so collisions between metrics are caught.
    all_fields = [field for metric in metrics_list for field in metric.fields]
    if len(set(all_fields)) != len(all_fields):
        raise TrackEvalException('Code being run with multiple metrics with fields of the same name')

    return names
def write_summary_results(summaries, cls, output_folder):
    """Write summary results to file.

    The summary dicts of all metric families are merged into a single row. That row is
    written, preceded by a header row of field names, as space-delimited text to
    ``<output_folder>/<cls>_summary.txt``. If the same field appears in several
    summaries, the last value wins.

    :param summaries: sequence of dicts mapping field name -> summary value
    :param cls: class name used as the output file prefix
    :param output_folder: directory to write into; it is created if missing, and ''
        means the current working directory
    """
    # In order to remain consistent upon new fields being adding, for each of the following fields if they are present
    # they will be output in the summary first in the order below. Any further fields will be output in the order each
    # metric family is called, and within each family either in the order they were added to the dict (python >= 3.6) or
    # randomly (python < 3.6).
    default_order = ['HOTA', 'DetA', 'AssA', 'DetRe', 'DetPr', 'AssRe', 'AssPr', 'LocA', 'OWTA', 'HOTA(0)', 'LocA(0)',
                     'HOTALocA(0)', 'MOTA', 'MOTP', 'MODA', 'CLR_Re', 'CLR_Pr', 'MTR', 'PTR', 'MLR', 'CLR_TP', 'CLR_FN',
                     'CLR_FP', 'IDSW', 'MT', 'PT', 'ML', 'Frag', 'sMOTA', 'IDF1', 'IDR', 'IDP', 'IDTP', 'IDFN', 'IDFP',
                     'Dets', 'GT_Dets', 'IDs', 'GT_IDs']
    # A private sentinel (rather than None) marks default slots no summary filled, so a
    # field whose real value is None is not mistaken for an absent one.
    unset = object()
    ordered = OrderedDict((field, unset) for field in default_order)
    for summary in summaries:
        for field, value in summary.items():
            # Known fields keep their default slot; unknown fields are appended in order.
            ordered[field] = value
    present = [(field, value) for field, value in ordered.items() if value is not unset]
    fields = [field for field, _ in present]
    values = [value for _, value in present]

    out_file = os.path.join(output_folder, cls + '_summary.txt')
    out_dir = os.path.dirname(out_file)
    if out_dir:  # os.makedirs('') raises FileNotFoundError (e.g. output_folder == '')
        os.makedirs(out_dir, exist_ok=True)
    with open(out_file, 'w', newline='') as f:
        writer = csv.writer(f, delimiter=' ')
        writer.writerow(fields)
        writer.writerow(values)
def write_detailed_results(details, cls, output_folder):
    """Write detailed (per-sequence) results to file.

    Writes the CSV file ``<output_folder>/<cls>_detailed.csv`` with the following rows:

    * a header row ``seq,<fields...>``;
    * one row per sequence, in sorted order;
    * a final ``COMBINED`` row built from each dict's ``'COMBINED_SEQ'`` entry.

    The columns of all metric families are concatenated in the order of ``details``.

    :param details: non-empty list of dicts, one per metric family, each mapping a
        sequence name (plus ``'COMBINED_SEQ'``) -> dict of field name -> value. The
        sequence names are taken from ``details[0]``.
    :param cls: class name used as the output file prefix
    :param output_folder: directory to write into; it is created if missing, and ''
        means the current working directory
    """
    sequences = details[0].keys()
    fields = ['seq'] + [field for d in details for field in d['COMBINED_SEQ'].keys()]
    out_file = os.path.join(output_folder, cls + '_detailed.csv')
    out_dir = os.path.dirname(out_file)
    if out_dir:  # os.makedirs('') raises FileNotFoundError (e.g. output_folder == '')
        os.makedirs(out_dir, exist_ok=True)
    with open(out_file, 'w', newline='') as f:
        writer = csv.writer(f)
        writer.writerow(fields)
        for seq in sorted(sequences):
            if seq == 'COMBINED_SEQ':
                continue
            # NOTE(review): values are written positionally, so each per-sequence dict is
            # assumed to list its fields in the same order as 'COMBINED_SEQ' -- confirm.
            writer.writerow([seq] + [value for d in details for value in d[seq].values()])
        writer.writerow(['COMBINED'] + [value for d in details for value in d['COMBINED_SEQ'].values()])
def load_detail(file):
    """Loads detailed data for a tracker from a CSV file.

    The file is expected to have the layout written by ``write_detailed_results``:
    a header row whose first column is the sequence label, followed by field names,
    then one row per sequence. The file is parsed with the ``csv`` module, so quoted
    values (e.g. sequence names containing commas) round-trip correctly.

    A row labelled ``COMBINED`` is stored under ``'COMBINED_SEQ'``. Blank rows, rows
    with an empty sequence label, and rows whose number of values differs from the
    header are skipped.

    :param file: path to the CSV file
    :return: dict mapping sequence name -> dict of field name -> float value
    :raises ValueError: if a value in a kept row cannot be parsed as a float
    """
    data = {}
    keys = None
    # newline='' is what the csv module requires for correct line-ending handling.
    with open(file, newline='') as f:
        for row in csv.reader(f):
            if keys is None:
                keys = row[1:]  # first row is the header
                continue
            if not row:
                continue  # blank line
            seq, current_values = row[0], row[1:]
            if seq == 'COMBINED':
                seq = 'COMBINED_SEQ'
            if seq != '' and len(current_values) == len(keys):
                data[seq] = {key: float(value) for key, value in zip(keys, current_values)}
    return data
class TrackEvalException(Exception):
    """Custom exception for expected, anticipated errors.

    Raised, for example, by ``validate_metrics_list`` when metric names or metric
    fields are duplicated. It lets callers catch anticipated failures separately
    from unexpected ones.
    """

    pass