Spaces:
Build error
Build error
import yaml | |
import json | |
import argparse | |
import logging | |
# Module-level logger; the helpers below report config overrides through it.
logger = logging.getLogger(name=__name__)
def load_config_dict_to_opt(opt, config_dict):
    """
    Load the key, value pairs from config_dict into opt, overriding existing values
    in opt if there are any.

    A key containing dots (e.g. ``"a.b.c"``) addresses a nested entry: missing
    intermediate levels are created as empty dicts and only the leaf is assigned.
    A key without dots replaces the top-level value as a whole (no deep merge).

    Args:
        opt: the (nested) dict to update in place.
        config_dict: dict mapping possibly-dotted keys to new values.

    Raises:
        TypeError: if config_dict is not a dict, or if an intermediate level of a
            dotted key already holds a non-dict value.
    """
    if not isinstance(config_dict, dict):
        raise TypeError("Config must be a Python dictionary")
    for k, v in config_dict.items():
        *parents, leaf = k.split('.')
        pointer = opt
        for k_part in parents:
            if k_part not in pointer:
                pointer[k_part] = {}
            pointer = pointer[k_part]
            # Explicit check rather than `assert`, which is stripped under `python -O`.
            if not isinstance(pointer, dict):
                raise TypeError("Overriding key needs to be inside a Python dict.")
        ori_value = pointer.get(leaf)
        pointer[leaf] = v
        # Compare against None (not truthiness) so overriding 0 / False / "" is logged too.
        if ori_value is not None:
            logger.warning(f"Overrided {k} from {ori_value} to {v}")
def load_opt_from_config_files(conf_file):
    """
    Load opt from one or more YAML config files.

    Files are applied in order through ``load_config_dict_to_opt``, so settings in
    later files override those in earlier ones. A top-level key from a later file
    replaces the earlier value wholesale; a dotted key (``"a.b"``) updates only that
    nested entry.

    Args:
        conf_file: a single config file path, or a list/tuple of config file paths.

    Returns:
        dict: a dictionary of opt settings

    Raises:
        TypeError: if a file's top-level YAML content is not a mapping (an empty
            file loads as None and is rejected as well).
    """
    # Keep the original single-path calling convention working alongside sequences.
    conf_files = conf_file if isinstance(conf_file, (list, tuple)) else [conf_file]
    opt = {}
    for path in conf_files:
        with open(path, encoding='utf-8') as f:
            # safe_load builds only plain Python objects, never arbitrary classes.
            config_dict = yaml.safe_load(f)
        load_config_dict_to_opt(opt, config_dict)
    return opt
# Sentinel for "no existing config value at this key".
_MISSING = object()

# Case-insensitive spellings accepted as False when overriding a bool setting.
_FALSE_STRINGS = frozenset({"false", "0", "no", "off", ""})


def _lookup_dotted(opt, dotted_key):
    """Return the value at ``dotted_key`` ("a.b.c") inside nested dict ``opt``, or ``_MISSING``."""
    node = opt
    for part in dotted_key.split('.'):
        try:
            node = node[part]
        except (KeyError, TypeError):
            return _MISSING
    return node


def _cast_override_value(raw, current):
    """Convert command-line string ``raw`` to match the type of existing value ``current``."""
    if current is _MISSING or current is None or isinstance(current, (list, dict)):
        # No scalar type to mirror (new key, null default, or container): parse the
        # string as YAML, e.g. "3" -> 3, "[1, 2]" -> [1, 2], "abc" -> "abc".
        return yaml.safe_load(raw)
    if isinstance(current, bool):
        # bool(str) is True for any non-empty string, so parse the spelling instead.
        return raw.strip().lower() not in _FALSE_STRINGS
    # int / float / str (str(raw) returns raw unchanged).
    return type(current)(raw)


def load_opt_command(args):
    """
    Parse command-line arguments and build the option dict.

    Settings are layered in this order, each layer overriding the previous one:
    1. the YAML config file given by ``--conf_files``;
    2. the JSON object given by ``--config_overrides``;
    3. ``key value`` pairs given by ``--overrides``. Each value is cast to the type
       of the config value currently at that (dotted) key. For bools, false/0/no/off
       (any case) mean False. New keys, null values and list/dict values are parsed
       as YAML;
    4. every parsed command-line argument that is not None, stored under its own
       name (e.g. ``opt['command']``).

    Args:
        args: list of command-line tokens; if None or empty, argparse falls back to
            ``sys.argv``.

    Returns:
        tuple: (opt dict, argparse.Namespace of the parsed arguments)
    """
    parser = argparse.ArgumentParser(description='MainzTrain: Pretrain or fine-tune models for NLP tasks.')
    parser.add_argument('command', help='Command: train/evaluate/train-and-evaluate')
    parser.add_argument('--conf_files', required=True, help='Path(s) to the MainzTrain config file(s).')
    parser.add_argument('--config_overrides', nargs='*', help='Override parameters on config with a json style string, e.g. {"<PARAM_NAME_1>": <PARAM_VALUE_1>, "<PARAM_GROUP_2>.<PARAM_SUBGROUP_2>.<PARAM_2>": <PARAM_VALUE_2>}. A key with "." updates the object in the corresponding nested dict. Remember to escape " in command line.')
    parser.add_argument('--overrides', help='arguments that used to overide the config file in cmdline', nargs=argparse.REMAINDER)

    # A falsy `args` (None or empty list) is mapped to None so argparse reads sys.argv.
    cmdline_args = parser.parse_args(args or None)

    opt = load_opt_from_config_files(cmdline_args.conf_files)

    if cmdline_args.config_overrides:
        config_overrides_string = ' '.join(cmdline_args.config_overrides)
        logger.warning(f"Command line config overrides: {config_overrides_string}")
        config_dict = json.loads(config_overrides_string)
        load_config_dict_to_opt(opt, config_dict)

    if cmdline_args.overrides:
        overrides = cmdline_args.overrides
        if len(overrides) % 2 != 0:
            # Standard argparse failure path (usage + message) instead of `assert`.
            parser.error("overides arguments is not paired, required: key value")
        # All casts are computed against opt *before* any of these overrides apply.
        config_dict = {
            key: _cast_override_value(raw, _lookup_dotted(opt, key))
            for key, raw in zip(overrides[0::2], overrides[1::2])
        }
        load_config_dict_to_opt(opt, config_dict)

    # combine cmdline_args into opt dictionary
    for key, val in vars(cmdline_args).items():
        if val is not None:
            opt[key] = val

    return opt, cmdline_args
def save_opt_to_json(opt, conf_file):
    """Write ``opt`` to ``conf_file`` as JSON indented by 4 spaces, overwriting the file."""
    with open(conf_file, mode='w', encoding='utf-8') as out_stream:
        json.dump(obj=opt, fp=out_stream, indent=4)
def save_opt_to_yaml(opt, conf_file):
    """
    Write ``opt`` to ``conf_file`` as YAML, overwriting the file.

    NOTE(review): this uses ``yaml.dump`` rather than ``yaml.safe_dump``. Values that
    are not plain YAML types may be emitted with Python-specific tags, which the
    ``yaml.safe_load``-based loader cannot read back. Confirm that is acceptable.
    """
    with open(conf_file, mode='w', encoding='utf-8') as out_stream:
        yaml.dump(opt, stream=out_stream)