import os
import csv
import argparse
from collections import OrderedDict


def init_config(config, default_config, name=None):
    """Initialise non-given config values with defaults.

    :param config: dict of user-supplied settings, or None. It is updated in place.
    :param default_config: dict of default settings used for every missing key.
    :param name: optional label; when given and ``config['PRINT_CONFIG']`` is truthy,
        the resulting config is printed.
    :return: ``default_config`` itself when ``config`` is None, otherwise ``config``.
    """
    if config is None:
        config = default_config
    else:
        for key, default_value in default_config.items():
            if key not in config:
                config[key] = default_value
    if name and config['PRINT_CONFIG']:
        print('\n%s Config:' % name)
        for key in config.keys():
            print('%-20s : %-30s' % (key, config[key]))
    return config


def update_config(config):
    """
    Parse the arguments of a script and updates the config values for a given value if specified in the arguments.

    One ``--<setting>`` option is registered per config key. Settings whose current value is a
    list or None accept one or more values; bool settings must be given as the literal strings
    ``True`` or ``False``; int settings are converted with ``int``; everything else is stored as
    the raw string.

    :param config: the config to update (modified in place)
    :return: the updated config
    :raises Exception: if a bool setting is given a value other than 'True' or 'False'
    """
    parser = argparse.ArgumentParser()
    for setting, current in config.items():
        if isinstance(current, list) or current is None:
            parser.add_argument("--" + setting, nargs='+')
        else:
            parser.add_argument("--" + setting)
    args = vars(parser.parse_args())
    for setting, raw in args.items():
        if raw is None:
            # Option not given on the command line: keep the existing value.
            continue
        current = config[setting]
        # bool must be tested before int, since bool is a subclass of int.
        if isinstance(current, bool):
            if raw == 'True':
                value = True
            elif raw == 'False':
                value = False
            else:
                raise Exception('Command line parameter ' + setting + ' must be True or False')
        elif isinstance(current, int):
            value = int(raw)
        else:
            value = raw
        config[setting] = value
    return config


def get_code_path():
    """Get base path where code is (the parent of the directory containing this file)."""
    return os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))


def validate_metrics_list(metrics_list):
    """Get names of metric class and ensures they are unique, further checks that the fields within each metric class
    do not have overlapping names.

    :param metrics_list: iterable of metric objects exposing ``get_name()`` and a ``fields`` list
    :return: list of metric names, in the order of ``metrics_list``
    :raises TrackEvalException: if two metrics share a name or two fields share a name
    """
    metric_names = [metric.get_name() for metric in metrics_list]
    # check metric names are unique
    if len(metric_names) != len(set(metric_names)):
        raise TrackEvalException('Code being run with multiple metrics of the same name')
    fields = []
    for metric in metrics_list:
        fields += metric.fields
    # check metric fields are unique
    if len(fields) != len(set(fields)):
        raise TrackEvalException('Code being run with multiple metrics with fields of the same name')
    return metric_names


def _ensure_parent_dir(path):
    """Create the directory that will contain ``path``, unless ``path`` has no directory part."""
    parent = os.path.dirname(path)
    # os.makedirs('') raises FileNotFoundError, so skip it for files in the current directory.
    if parent:
        os.makedirs(parent, exist_ok=True)


def write_summary_results(summaries, cls, output_folder):
    """Write summary results to file.

    Writes ``<output_folder>/<cls>_summary.txt`` as two space-delimited rows: field names, then values.

    :param summaries: list of dicts mapping field name to value, one dict per metric family
    :param cls: class name used as the file-name prefix
    :param output_folder: folder to write into (created if missing)
    """
    # In order to remain consistent upon new fields being adding, for each of the following fields if they are present
    # they will be output in the summary first in the order below. Any further fields will be output in the order each
    # metric family is called, and within each family either in the order they were added to the dict (python >= 3.6) or
    # randomly (python < 3.6).
    default_order = ['HOTA', 'DetA', 'AssA', 'DetRe', 'DetPr', 'AssRe', 'AssPr', 'LocA', 'OWTA', 'HOTA(0)', 'LocA(0)',
                     'HOTALocA(0)', 'MOTA', 'MOTP', 'MODA', 'CLR_Re', 'CLR_Pr', 'MTR', 'PTR', 'MLR', 'CLR_TP', 'CLR_FN',
                     'CLR_FP', 'IDSW', 'MT', 'PT', 'ML', 'Frag', 'sMOTA', 'IDF1', 'IDR', 'IDP', 'IDTP', 'IDFN', 'IDFP',
                     'Dets', 'GT_Dets', 'IDs', 'GT_IDs']
    # Seed the known fields with None placeholders so they keep their position when filled in.
    ordered = OrderedDict((field, None) for field in default_order)
    for summary in summaries:
        for field, value in summary.items():
            ordered[field] = value
    # Drop default fields that were never filled (or whose value is None).
    for field in default_order:
        if ordered[field] is None:
            del ordered[field]
    fields = list(ordered.keys())
    values = list(ordered.values())

    out_file = os.path.join(output_folder, cls + '_summary.txt')
    _ensure_parent_dir(out_file)
    with open(out_file, 'w', newline='') as f:
        writer = csv.writer(f, delimiter=' ')
        writer.writerow(fields)
        writer.writerow(values)


def write_detailed_results(details, cls, output_folder):
    """Write detailed results to file.

    Writes ``<output_folder>/<cls>_detailed.csv`` with a header row, one row per sequence (sorted by
    name) and a final ``COMBINED`` row taken from the ``'COMBINED_SEQ'`` entries.

    :param details: list of dicts (one per metric family) mapping sequence name to a dict of
        field name -> value; each must contain a ``'COMBINED_SEQ'`` entry
    :param cls: class name used as the file-name prefix
    :param output_folder: folder to write into (created if missing)
    """
    # The sequence names are taken from the first metric family only.
    sequences = details[0].keys()
    fields = ['seq'] + [field for detail in details for field in detail['COMBINED_SEQ'].keys()]

    out_file = os.path.join(output_folder, cls + '_detailed.csv')
    _ensure_parent_dir(out_file)
    with open(out_file, 'w', newline='') as f:
        writer = csv.writer(f)
        writer.writerow(fields)
        for seq in sorted(sequences):
            if seq == 'COMBINED_SEQ':
                continue
            writer.writerow([seq] + [value for detail in details for value in detail[seq].values()])
        writer.writerow(['COMBINED'] + [value for detail in details for value in detail['COMBINED_SEQ'].values()])


def load_detail(file):
    """Loads detailed data for a tracker.

    Reads a CSV file in the layout produced by ``write_detailed_results``: the first row holds the
    field names (first column ignored), each later row holds a sequence name followed by values.
    Rows with an empty sequence name or a value count different from the header are skipped.
    The ``COMBINED`` row is stored under the key ``'COMBINED_SEQ'``.

    :param file: path of the CSV file to read
    :return: dict mapping sequence name to a dict of field name -> float value
    :raises ValueError: if a value in an accepted row cannot be converted to float
    """
    data = {}
    keys = []
    # csv.reader (with newline='') mirrors csv.writer, so quoted fields such as sequence names
    # containing commas are read back intact.
    with open(file, newline='') as f:
        for i, row in enumerate(csv.reader(f)):
            if i == 0:
                keys = row[1:]
                continue
            if not row:
                # Blank line.
                continue
            seq = row[0]
            current_values = row[1:]
            if seq == 'COMBINED':
                seq = 'COMBINED_SEQ'
            if len(current_values) == len(keys) and seq != '':
                data[seq] = {key: float(value) for key, value in zip(keys, current_values)}
    return data


class TrackEvalException(Exception):
    """Custom exception for catching expected errors."""
    ...