Spaces:
Build error
Build error
File size: 5,806 Bytes
47af768 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 |
import os
import csv
import argparse
from collections import OrderedDict
def init_config(config, default_config, name=None):
    """Fill in any settings missing from ``config`` using ``default_config``.

    :param config: dict of user-supplied settings, or None to use the defaults unchanged
    :param default_config: dict of default settings
    :param name: optional label; when it is truthy and the resulting config's 'PRINT_CONFIG' entry is truthy,
        every setting is printed
    :return: the completed config; this is the ``config`` object itself (updated in place), or ``default_config``
        when ``config`` is None
    """
    if config is None:
        config = default_config
    else:
        for key, default_value in default_config.items():
            # Only fill the gaps: values already present in config always win.
            config.setdefault(key, default_value)

    should_print = name and config['PRINT_CONFIG']
    if should_print:
        print('\n%s Config:' % name)
        for key, value in config.items():
            print('%-20s : %-30s' % (key, value))
    return config
def update_config(config):
    """
    Parse the command-line arguments of a script and override config values with any that were given.

    One ``--<SETTING>`` option is registered per config key. Settings whose current value is a list or None accept
    one or more values and are stored as a list of strings. Bool settings must be given literally as ``True`` or
    ``False``. Int settings are converted with ``int()``. Every other setting is stored as the raw string.
    :param config: the config to update (modified in place)
    :return: the updated config
    :raises Exception: if a bool setting is given a value other than 'True' or 'False'
    """
    parser = argparse.ArgumentParser()
    for setting, current in config.items():
        if current is None or isinstance(current, list):
            parser.add_argument("--" + setting, nargs='+')
        else:
            parser.add_argument("--" + setting)
    args = vars(parser.parse_args())

    bool_literals = {'True': True, 'False': False}
    for setting, raw in args.items():
        if raw is None:
            # Option not given on the command line: keep the existing value.
            continue
        current = config[setting]
        # bool is a subclass of int, so it has to be tested before the int branch.
        if isinstance(current, bool):
            if raw not in bool_literals:
                raise Exception('Command line parameter ' + setting + ' must be True or False')
            value = bool_literals[raw]
        elif isinstance(current, int):
            value = int(raw)
        else:
            value = raw
        config[setting] = value
    return config
def get_code_path():
    """Return the absolute path of the directory one level above the directory containing this file."""
    this_dir = os.path.dirname(__file__)
    parent_dir = os.path.join(this_dir, '..')
    return os.path.abspath(parent_dir)
def validate_metrics_list(metrics_list):
    """Check that metric names are unique and that no field name is shared between (or repeated within) metrics.

    :param metrics_list: metric objects, each providing ``get_name()`` and a ``fields`` iterable
    :return: list of the metric names, in the order the metrics were given
    :raises TrackEvalException: if two metrics share a name, or if any field name occurs more than once
    """
    metric_names = []
    for metric in metrics_list:
        metric_names.append(metric.get_name())
    # A set drops duplicates, so a size mismatch means at least one name was repeated.
    if len(set(metric_names)) != len(metric_names):
        raise TrackEvalException('Code being run with multiple metrics of the same name')

    all_fields = [field for metric in metrics_list for field in metric.fields]
    if len(set(all_fields)) != len(all_fields):
        raise TrackEvalException('Code being run with multiple metrics with fields of the same name')
    return metric_names
def write_summary_results(summaries, cls, output_folder):
    """Write summary results to file.

    Creates ``<cls>_summary.txt`` inside ``output_folder`` containing two space-delimited rows: the field names and
    the corresponding values.

    :param summaries: list of dicts mapping field name -> value; if a field occurs in several dicts the last value
        is used
    :param cls: class name, used as the filename prefix
    :param output_folder: folder to write into; created if it does not exist. An empty string writes to the current
        working directory.
    """
    # In order to remain consistent upon new fields being added, for each of the following fields if they are present
    # they will be output in the summary first in the order below. Any further fields will be output in the order each
    # metric family is called, and within each family either in the order they were added to the dict (python >= 3.6) or
    # randomly (python < 3.6).
    default_order = ['HOTA', 'DetA', 'AssA', 'DetRe', 'DetPr', 'AssRe', 'AssPr', 'LocA', 'OWTA', 'HOTA(0)', 'LocA(0)',
                     'HOTALocA(0)', 'MOTA', 'MOTP', 'MODA', 'CLR_Re', 'CLR_Pr', 'MTR', 'PTR', 'MLR', 'CLR_TP', 'CLR_FN',
                     'CLR_FP', 'IDSW', 'MT', 'PT', 'ML', 'Frag', 'sMOTA', 'IDF1', 'IDR', 'IDP', 'IDTP', 'IDFN', 'IDFP',
                     'Dets', 'GT_Dets', 'IDs', 'GT_IDs']
    known_fields = set(default_order)

    # Merge all summaries in call order (linear, unlike repeated list concatenation via sum()).
    merged = OrderedDict()
    for summary in summaries:
        merged.update(summary)

    ordered = OrderedDict()
    # Known fields first, in the canonical order; a known field that is absent or None is left out.
    for field in default_order:
        if merged.get(field) is not None:
            ordered[field] = merged[field]
    # Then every remaining field in the order it was first seen.
    for field, value in merged.items():
        if field not in known_fields:
            ordered[field] = value

    out_file = os.path.join(output_folder, cls + '_summary.txt')
    out_dir = os.path.dirname(out_file)
    # os.makedirs('') raises FileNotFoundError, so only create a directory when there is one to create.
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    with open(out_file, 'w', newline='') as f:
        writer = csv.writer(f, delimiter=' ')
        writer.writerow(list(ordered.keys()))
        writer.writerow(list(ordered.values()))
def write_detailed_results(details, cls, output_folder):
    """Write detailed results to file.

    Creates ``<cls>_detailed.csv`` inside ``output_folder``. The header row is 'seq' followed by the field names of
    each entry's 'COMBINED_SEQ' dict. One row is written per sequence in sorted order, then a final 'COMBINED' row
    holding the 'COMBINED_SEQ' values.

    :param details: non-empty list of dicts mapping sequence name -> {field: value}; each must contain a
        'COMBINED_SEQ' entry, and the sequence names are taken from the first dict
    :param cls: class name, used as the filename prefix
    :param output_folder: folder to write into; created if it does not exist. An empty string writes to the current
        working directory.
    """
    sequences = details[0].keys()
    fields = ['seq'] + [key for detail in details for key in detail['COMBINED_SEQ'].keys()]

    out_file = os.path.join(output_folder, cls + '_detailed.csv')
    out_dir = os.path.dirname(out_file)
    # os.makedirs('') raises FileNotFoundError, so only create a directory when there is one to create.
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    with open(out_file, 'w', newline='') as f:
        writer = csv.writer(f)
        writer.writerow(fields)
        for seq in sorted(sequences):
            # The combined entry is written last, under the label 'COMBINED'.
            if seq == 'COMBINED_SEQ':
                continue
            writer.writerow([seq] + [value for detail in details for value in detail[seq].values()])
        writer.writerow(['COMBINED'] + [value for detail in details for value in detail['COMBINED_SEQ'].values()])
def load_detail(file):
    """Loads detailed data for a tracker.

    The file is parsed as CSV. The first row is the header, whose first column is ignored and whose remaining
    columns are the field names. Each later row is a sequence name followed by one numeric value per field. Rows
    that are blank, have an empty sequence name, or have a different number of values than the header are skipped.
    A row named 'COMBINED' is stored under the key 'COMBINED_SEQ'.

    :param file: path of the detailed-results csv file to read
    :return: dict mapping sequence name -> {field name: float value}
    """
    data = {}
    # Parse with the csv module (and newline='' as it requires) so that fields quoted by csv.writer, e.g. sequence
    # names containing a comma, are read back intact instead of being split apart.
    with open(file, newline='') as f:
        reader = csv.reader(f)
        header = next(reader, None)
        if header is None:
            # Empty file: nothing to load.
            return data
        keys = header[1:]
        for row in reader:
            if not row:
                # csv.reader yields an empty list for a blank line.
                continue
            seq = row[0]
            current_values = row[1:]
            if seq == 'COMBINED':
                seq = 'COMBINED_SEQ'
            if len(current_values) == len(keys) and seq != '':
                data[seq] = {key: float(value) for key, value in zip(keys, current_values)}
    return data
class TrackEvalException(Exception):
    """Exception type raised for expected errors, so that callers can catch them specifically."""
    pass
|