Spaces:
Build error
Build error
Rongjiehuang
commited on
Commit
•
98a2c89
1
Parent(s):
6ceff9a
update
Browse files- utils/hparams.py +35 -38
utils/hparams.py
CHANGED
@@ -1,7 +1,5 @@
|
|
1 |
import argparse
|
2 |
import os
|
3 |
-
import subprocess
|
4 |
-
|
5 |
import yaml
|
6 |
|
7 |
global_print_hparams = True
|
@@ -23,31 +21,30 @@ def override_config(old_config: dict, new_config: dict):
|
|
23 |
|
24 |
|
25 |
def set_hparams(config='', exp_name='', hparams_str='', print_hparams=True, global_hparams=True):
|
26 |
-
if config == ''
|
27 |
-
parser = argparse.ArgumentParser(description='')
|
28 |
-
parser.add_argument('--config', type=str, default='
|
29 |
help='location of the data corpus')
|
30 |
parser.add_argument('--exp_name', type=str, default='', help='exp_name')
|
31 |
parser.add_argument('--hparams', type=str, default='',
|
32 |
help='location of the data corpus')
|
33 |
-
parser.add_argument('--
|
34 |
parser.add_argument('--validate', action='store_true', help='validate')
|
35 |
parser.add_argument('--reset', action='store_true', help='reset hparams')
|
36 |
-
parser.add_argument('--remove', action='store_true', help='remove old ckpt')
|
37 |
parser.add_argument('--debug', action='store_true', help='debug')
|
38 |
args, unknown = parser.parse_known_args()
|
39 |
else:
|
40 |
args = Args(config=config, exp_name=exp_name, hparams=hparams_str,
|
41 |
infer=False, validate=False, reset=False, debug=False)
|
42 |
-
|
43 |
-
|
|
|
|
|
44 |
|
45 |
config_chains = []
|
46 |
loaded_config = set()
|
47 |
|
48 |
def load_config(config_fn): # deep first
|
49 |
-
if not os.path.exists(config_fn):
|
50 |
-
return {}
|
51 |
with open(config_fn) as f:
|
52 |
hparams_ = yaml.safe_load(f)
|
53 |
loaded_config.add(config_fn)
|
@@ -56,10 +53,10 @@ def set_hparams(config='', exp_name='', hparams_str='', print_hparams=True, glob
|
|
56 |
if not isinstance(hparams_['base_config'], list):
|
57 |
hparams_['base_config'] = [hparams_['base_config']]
|
58 |
for c in hparams_['base_config']:
|
59 |
-
if c.startswith('.'):
|
60 |
-
c = f'{os.path.dirname(config_fn)}/{c}'
|
61 |
-
c = os.path.normpath(c)
|
62 |
if c not in loaded_config:
|
|
|
|
|
|
|
63 |
override_config(ret_hparams, load_config(c))
|
64 |
override_config(ret_hparams, hparams_)
|
65 |
else:
|
@@ -67,53 +64,48 @@ def set_hparams(config='', exp_name='', hparams_str='', print_hparams=True, glob
|
|
67 |
config_chains.append(config_fn)
|
68 |
return ret_hparams
|
69 |
|
|
|
|
|
70 |
saved_hparams = {}
|
71 |
-
args_work_dir
|
72 |
-
if args.exp_name != '':
|
73 |
-
args_work_dir = f'checkpoints/{args.exp_name}'
|
74 |
ckpt_config_path = f'{args_work_dir}/config.yaml'
|
75 |
if os.path.exists(ckpt_config_path):
|
76 |
-
|
77 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
78 |
hparams_ = {}
|
79 |
-
|
80 |
-
|
81 |
if not args.reset:
|
82 |
hparams_.update(saved_hparams)
|
83 |
hparams_['work_dir'] = args_work_dir
|
84 |
|
85 |
-
# --hparams="a=1,b.c=2,d=[1 1 1]"
|
86 |
if args.hparams != "":
|
87 |
for new_hparam in args.hparams.split(","):
|
88 |
k, v = new_hparam.split("=")
|
89 |
-
v
|
90 |
-
|
91 |
-
for k_ in k.split(".")[:-1]:
|
92 |
-
config_node = config_node[k_]
|
93 |
-
k = k.split(".")[-1]
|
94 |
-
if v in ['True', 'False'] or type(config_node[k]) in [bool, list, dict]:
|
95 |
-
if type(config_node[k]) == list:
|
96 |
-
v = v.replace(" ", ",")
|
97 |
-
config_node[k] = eval(v)
|
98 |
else:
|
99 |
-
|
100 |
-
|
101 |
-
answer = input("REMOVE old checkpoint? Y/N [Default: N]: ")
|
102 |
-
if answer.lower() == "y":
|
103 |
-
subprocess.check_call(f'rm -rf {args_work_dir}', shell=True)
|
104 |
if args_work_dir != '' and (not os.path.exists(ckpt_config_path) or args.reset) and not args.infer:
|
105 |
os.makedirs(hparams_['work_dir'], exist_ok=True)
|
106 |
with open(ckpt_config_path, 'w') as f:
|
107 |
yaml.safe_dump(hparams_, f)
|
108 |
|
109 |
-
hparams_['
|
110 |
hparams_['debug'] = args.debug
|
111 |
hparams_['validate'] = args.validate
|
112 |
-
hparams_['exp_name'] = args.exp_name
|
113 |
global global_print_hparams
|
114 |
if global_hparams:
|
115 |
hparams.clear()
|
116 |
hparams.update(hparams_)
|
|
|
117 |
if print_hparams and global_print_hparams and global_hparams:
|
118 |
print('| Hparams chains: ', config_chains)
|
119 |
print('| Hparams: ')
|
@@ -121,4 +113,9 @@ def set_hparams(config='', exp_name='', hparams_str='', print_hparams=True, glob
|
|
121 |
print(f"\033[;33;m{k}\033[0m: {v}, ", end="\n" if i % 5 == 4 else "")
|
122 |
print("")
|
123 |
global_print_hparams = False
|
|
|
|
|
|
|
|
|
|
|
124 |
return hparams_
|
|
|
1 |
import argparse
|
2 |
import os
|
|
|
|
|
3 |
import yaml
|
4 |
|
5 |
global_print_hparams = True
|
|
|
21 |
|
22 |
|
23 |
def set_hparams(config='', exp_name='', hparams_str='', print_hparams=True, global_hparams=True):
|
24 |
+
if config == '':
|
25 |
+
parser = argparse.ArgumentParser(description='neural music')
|
26 |
+
parser.add_argument('--config', type=str, default='',
|
27 |
help='location of the data corpus')
|
28 |
parser.add_argument('--exp_name', type=str, default='', help='exp_name')
|
29 |
parser.add_argument('--hparams', type=str, default='',
|
30 |
help='location of the data corpus')
|
31 |
+
parser.add_argument('--inference', action='store_true', help='inference')
|
32 |
parser.add_argument('--validate', action='store_true', help='validate')
|
33 |
parser.add_argument('--reset', action='store_true', help='reset hparams')
|
|
|
34 |
parser.add_argument('--debug', action='store_true', help='debug')
|
35 |
args, unknown = parser.parse_known_args()
|
36 |
else:
|
37 |
args = Args(config=config, exp_name=exp_name, hparams=hparams_str,
|
38 |
infer=False, validate=False, reset=False, debug=False)
|
39 |
+
args_work_dir = ''
|
40 |
+
if args.exp_name != '':
|
41 |
+
args.work_dir = args.exp_name
|
42 |
+
args_work_dir = f'checkpoints/{args.work_dir}'
|
43 |
|
44 |
config_chains = []
|
45 |
loaded_config = set()
|
46 |
|
47 |
def load_config(config_fn): # deep first
|
|
|
|
|
48 |
with open(config_fn) as f:
|
49 |
hparams_ = yaml.safe_load(f)
|
50 |
loaded_config.add(config_fn)
|
|
|
53 |
if not isinstance(hparams_['base_config'], list):
|
54 |
hparams_['base_config'] = [hparams_['base_config']]
|
55 |
for c in hparams_['base_config']:
|
|
|
|
|
|
|
56 |
if c not in loaded_config:
|
57 |
+
if c.startswith('.'):
|
58 |
+
c = f'{os.path.dirname(config_fn)}/{c}'
|
59 |
+
c = os.path.normpath(c)
|
60 |
override_config(ret_hparams, load_config(c))
|
61 |
override_config(ret_hparams, hparams_)
|
62 |
else:
|
|
|
64 |
config_chains.append(config_fn)
|
65 |
return ret_hparams
|
66 |
|
67 |
+
global hparams
|
68 |
+
assert args.config != '' or args_work_dir != ''
|
69 |
saved_hparams = {}
|
70 |
+
if args_work_dir != 'checkpoints/':
|
|
|
|
|
71 |
ckpt_config_path = f'{args_work_dir}/config.yaml'
|
72 |
if os.path.exists(ckpt_config_path):
|
73 |
+
try:
|
74 |
+
with open(ckpt_config_path) as f:
|
75 |
+
saved_hparams.update(yaml.safe_load(f))
|
76 |
+
except:
|
77 |
+
pass
|
78 |
+
if args.config == '':
|
79 |
+
args.config = ckpt_config_path
|
80 |
+
|
81 |
hparams_ = {}
|
82 |
+
hparams_.update(load_config(args.config))
|
83 |
+
|
84 |
if not args.reset:
|
85 |
hparams_.update(saved_hparams)
|
86 |
hparams_['work_dir'] = args_work_dir
|
87 |
|
|
|
88 |
if args.hparams != "":
|
89 |
for new_hparam in args.hparams.split(","):
|
90 |
k, v = new_hparam.split("=")
|
91 |
+
if v in ['True', 'False'] or type(hparams_[k]) == bool:
|
92 |
+
hparams_[k] = eval(v)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
93 |
else:
|
94 |
+
hparams_[k] = type(hparams_[k])(v)
|
95 |
+
|
|
|
|
|
|
|
96 |
if args_work_dir != '' and (not os.path.exists(ckpt_config_path) or args.reset) and not args.infer:
|
97 |
os.makedirs(hparams_['work_dir'], exist_ok=True)
|
98 |
with open(ckpt_config_path, 'w') as f:
|
99 |
yaml.safe_dump(hparams_, f)
|
100 |
|
101 |
+
hparams_['inference'] = args.infer
|
102 |
hparams_['debug'] = args.debug
|
103 |
hparams_['validate'] = args.validate
|
|
|
104 |
global global_print_hparams
|
105 |
if global_hparams:
|
106 |
hparams.clear()
|
107 |
hparams.update(hparams_)
|
108 |
+
|
109 |
if print_hparams and global_print_hparams and global_hparams:
|
110 |
print('| Hparams chains: ', config_chains)
|
111 |
print('| Hparams: ')
|
|
|
113 |
print(f"\033[;33;m{k}\033[0m: {v}, ", end="\n" if i % 5 == 4 else "")
|
114 |
print("")
|
115 |
global_print_hparams = False
|
116 |
+
# print(hparams_.keys())
|
117 |
+
if hparams.get('exp_name') is None:
|
118 |
+
hparams['exp_name'] = args.exp_name
|
119 |
+
if hparams_.get('exp_name') is None:
|
120 |
+
hparams_['exp_name'] = args.exp_name
|
121 |
return hparams_
|