Spaces:
Build error
Build error
import argparse | |
import os | |
import yaml | |
# Gate for the hyper-parameter table: set_hparams() prints it only while this
# flag is still True and turns it off after the first print.
global_print_hparams: bool = True

# Process-wide hyper-parameter dictionary; set_hparams() clears and refills it
# in place when called with global_hparams=True.
hparams: dict = {}
class Args:
    '''Plain attribute bag: every keyword argument becomes an instance attribute.

    set_hparams() uses it in place of an argparse namespace when the settings
    are passed programmatically rather than on the command line.
    '''

    def __init__(self, **kwargs):
        for name in kwargs:
            setattr(self, name, kwargs[name])
def override_config(old_config: dict, new_config: dict):
    '''Recursively merge ``new_config`` into ``old_config`` in place.

    For every key in ``new_config``:

    * if both the new value and the existing value are dicts, the two are
      merged recursively so that keys only present in ``old_config`` survive;
    * otherwise the new value replaces (or creates) the entry.

    Args:
        old_config: dictionary that is updated in place.
        new_config: dictionary whose entries take precedence.

    Returns:
        None. ``old_config`` is modified in place.
    '''
    for k, v in new_config.items():
        # Only descend when BOTH sides are dicts; a scalar/None on the old side
        # cannot be merged into and is simply replaced by the new dict.
        if isinstance(v, dict) and isinstance(old_config.get(k), dict):
            override_config(old_config[k], v)
        else:
            old_config[k] = v
def set_hparams(config='', exp_name='', hparams_str='', print_hparams=True, global_hparams=True, reset=True,
                infer=True):
    '''Build the hyper-parameter dictionary for a run and optionally publish it globally.

    Sources are merged in this order (later ones win):

    1. the config chain: each ``base_config`` referenced by a config file is
       loaded first (depth first), then the config file itself;
    2. unless ``reset`` is set, the complete config previously saved as
       ``checkpoints/<exp_name>/config.yaml`` (it does not rely on base configs);
    3. the ``key=value`` overrides from ``--hparams`` / ``hparams_str``, as a
       temporary modification.

    When ``config`` is empty, the settings are parsed from the command line
    (``--config``, ``--exp_name``, ``--hparams``, ``--infer``, ``--validate``,
    ``--reset``, ``--debug``; unknown arguments are ignored) and the
    ``exp_name``, ``hparams_str``, ``reset`` and ``infer`` parameters are not used.

    If an experiment name is given, inference mode is off, and the saved
    config is missing or ``reset`` is set, the merged config (without its
    ``base_config`` entry) is written to ``checkpoints/<exp_name>/config.yaml``.

    Args:
        config: path of the YAML config file; '' means "read the command line".
        exp_name: experiment name; the work dir becomes ``checkpoints/<exp_name>``.
        hparams_str: comma separated ``key=value`` overrides.
        print_hparams: print the merged table (only once per process, and only
            when ``global_hparams`` is also true).
        global_hparams: replace the content of the module-level ``hparams`` dict.
        reset: ignore the saved config and allow it to be rewritten.
        infer: inference mode; the config file is never written.

    Returns:
        The merged hyper-parameter dictionary.

    Raises:
        AssertionError: if neither a config file nor an experiment name is given.
    '''
    global global_print_hparams

    if config == '':
        parser = argparse.ArgumentParser(description='neural music')
        parser.add_argument('--config', type=str, default='',
                            help='path of the yaml config file')
        parser.add_argument('--exp_name', type=str, default='', help='exp_name')
        parser.add_argument('--hparams', type=str, default='',
                            help='comma separated key=value overrides')
        parser.add_argument('--infer', action='store_true', help='infer')
        parser.add_argument('--validate', action='store_true', help='validate')
        parser.add_argument('--reset', action='store_true', help='reset hparams')
        parser.add_argument('--debug', action='store_true', help='debug')
        # Unknown arguments are tolerated so other components can share argv.
        args, _ = parser.parse_known_args()
    else:
        args = Args(config=config, exp_name=exp_name, hparams=hparams_str,
                    infer=infer, validate=False, reset=reset, debug=False)

    args_work_dir = ''
    if args.exp_name != '':
        args.work_dir = args.exp_name
        args_work_dir = f'checkpoints/{args.work_dir}'

    config_chains = []  # files in the order they finished loading (bases first)
    loaded_config = set()  # resolved paths already visited

    def load_config(config_fn):  # depth first
        with open(config_fn, encoding='utf-8') as f:
            # safe_load() returns None for an empty document.
            hparams_ = yaml.safe_load(f) or {}
        loaded_config.add(config_fn)
        if 'base_config' in hparams_:
            ret_hparams = {}
            if not isinstance(hparams_['base_config'], list):
                hparams_['base_config'] = [hparams_['base_config']]
            for c in hparams_['base_config']:
                if c.startswith('.'):
                    # Relative to the referencing file. os.path.join keeps the
                    # path relative when that file lives in the current directory.
                    c = os.path.normpath(os.path.join(os.path.dirname(config_fn), c))
                # Compare the *resolved* path so a base shared by several
                # configs (or a cycle) is loaded only once.
                if c not in loaded_config:
                    override_config(ret_hparams, load_config(c))
            override_config(ret_hparams, hparams_)
        else:
            ret_hparams = hparams_
        config_chains.append(config_fn)
        return ret_hparams

    assert args.config != '' or args_work_dir != '', \
        'either a config file or an exp_name must be given'

    saved_hparams = {}
    ckpt_config_path = ''
    if args_work_dir != '':
        ckpt_config_path = f'{args_work_dir}/config.yaml'
        if os.path.exists(ckpt_config_path):
            # Best effort: an unreadable or malformed saved config is reported
            # and skipped instead of aborting the run.
            try:
                with open(ckpt_config_path, encoding='utf-8') as f:
                    saved_hparams.update(yaml.safe_load(f) or {})
            except (OSError, yaml.YAMLError, TypeError, ValueError) as err:
                print(f'| Warning: could not load saved hparams from {ckpt_config_path}: {err}')
        if args.config == '':
            args.config = ckpt_config_path

    hparams_ = {}
    hparams_.update(load_config(args.config))
    if not args.reset:
        hparams_.update(saved_hparams)
    hparams_['work_dir'] = args_work_dir

    if args.hparams != "":
        for new_hparam in args.hparams.split(","):
            # Split on the first '=' only so values may contain '='.
            k, v = new_hparam.split("=", 1)
            # NOTE(security): eval() executes arbitrary Python taken from
            # --hparams / hparams_str. Only pass trusted input here.
            if k not in hparams_:
                # No existing value to infer a type from: take the evaluated literal.
                hparams_[k] = eval(v)
            elif v in ('True', 'False') or isinstance(hparams_[k], bool):
                hparams_[k] = eval(v)
            else:
                # Cast the text to the type of the value being overridden.
                hparams_[k] = type(hparams_[k])(v)

    if args_work_dir != '' and (not os.path.exists(ckpt_config_path) or args.reset) and not args.infer:
        os.makedirs(hparams_['work_dir'], exist_ok=True)
        # Dump a copy without 'base_config' (the saved file is self-contained)
        # so the dictionary returned to the caller is left untouched.
        saved_view = {k: v for k, v in hparams_.items() if k != 'base_config'}
        with open(ckpt_config_path, 'w', encoding='utf-8') as f:
            yaml.safe_dump(saved_view, f)

    hparams_['infer'] = args.infer
    hparams_['debug'] = args.debug
    hparams_['validate'] = args.validate

    if global_hparams:
        # Mutate in place so modules holding a reference to `hparams` see the update.
        hparams.clear()
        hparams.update(hparams_)

    if print_hparams and global_print_hparams and global_hparams:
        print('| Hparams chains: ', config_chains)
        print('| Hparams: ')
        for i, (k, v) in enumerate(sorted(hparams_.items())):
            print(f"\033[;33;m{k}\033[0m: {v}, ", end="\n" if i % 5 == 4 else "")
        print("")
        global_print_hparams = False

    # Make sure 'exp_name' is always available.
    # NOTE(review): the global dict is touched here even when global_hparams is
    # False - kept as in the original; confirm this is intended.
    if hparams.get('exp_name') is None:
        hparams['exp_name'] = args.exp_name
    if hparams_.get('exp_name') is None:
        hparams_['exp_name'] = args.exp_name
    return hparams_