|
import yaml |
|
import time |
|
from collections import OrderedDict |
|
from os import path as osp |
|
from basicsr.utils.misc import get_time_str |
|
|
|
def ordered_yaml():
    """Build a yaml Loader/Dumper pair that round-trips ``OrderedDict``.

    The C-accelerated ``CLoader``/``CDumper`` are used when the installed
    PyYAML provides them; otherwise the pure-Python classes are used.
    The hooks are registered on the Loader/Dumper classes themselves.

    Returns:
        tuple: ``(Loader, Dumper)`` with the following hooks registered:
            mappings are loaded as ``OrderedDict`` and ``OrderedDict``
            instances are dumped as plain yaml mappings.
    """
    try:
        from yaml import CDumper as Dumper
        from yaml import CLoader as Loader
    except ImportError:
        from yaml import Dumper, Loader

    # Tag that PyYAML attaches to every plain mapping node.
    mapping_tag = yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG

    def _represent_ordered(dumper, data):
        # Emit the items in their stored order as a regular mapping.
        return dumper.represent_dict(data.items())

    def _construct_ordered(loader, node):
        # construct_pairs keeps the key order found in the document.
        return OrderedDict(loader.construct_pairs(node))

    Dumper.add_representer(OrderedDict, _represent_ordered)
    Loader.add_constructor(mapping_tag, _construct_ordered)
    return Loader, Dumper
|
|
|
|
|
def parse(opt_path, root_path, is_train=True):
    """Parse option file.

    Loads the yaml option file, derives the experiment name, expands
    ``~`` in dataset and checkpoint paths, and fills in the output
    directories under ``root_path``.

    Args:
        opt_path (str): Option file path.
        root_path (str): Directory under which the ``experiments``
            (training) or ``results`` (non-training) folders are placed.
        is_train (bool): Indicate whether in training or not. Default: True.

    Returns:
        (dict): Options.

    Raises:
        IndexError: If ``path.resume_state`` is set but has fewer than
            three path components, so no experiment name can be taken
            from it.
    """
    # NOTE(security): yaml.load with the (C)Loader returned by
    # ordered_yaml() may construct arbitrary Python objects. Only parse
    # option files from a trusted source.
    with open(opt_path, mode='r') as f:
        Loader, _ = ordered_yaml()
        opt = yaml.load(f, Loader=Loader)

    opt['is_train'] = is_train

    resume_state_path = opt['path'].get('resume_state', None)
    if resume_state_path:
        # When resuming, reuse the experiment name embedded in the state
        # path: the third-from-last component, presumably
        # <...>/<name>/training_states/<state file>.
        # normpath + osp.sep keeps this working for Windows-style
        # separators and ignores trailing or doubled separators.
        components = osp.normpath(resume_state_path).split(osp.sep)
        if len(components) < 3:
            raise IndexError(
                'Cannot derive the experiment name from resume_state '
                f'{resume_state_path!r}: expected at least three path '
                'components.')
        opt['name'] = components[-3]
    else:
        # Fresh run: prefix the configured name with a timestamp.
        opt['name'] = f"{get_time_str()}_{opt['name']}"

    # datasets
    for phase, dataset in opt['datasets'].items():
        # Keys such as "val_1" / "val_2" all map to the phase "val".
        phase = phase.split('_')[0]
        dataset['phase'] = phase
        if 'scale' in opt:
            dataset['scale'] = opt['scale']
        if dataset.get('dataroot_gt') is not None:
            dataset['dataroot_gt'] = osp.expanduser(dataset['dataroot_gt'])
        if dataset.get('dataroot_lq') is not None:
            dataset['dataroot_lq'] = osp.expanduser(dataset['dataroot_lq'])

    # paths: expand "~" in resume-state and pretrained-network entries.
    # Only existing keys are reassigned, so iterating while assigning is safe.
    for key, val in opt['path'].items():
        if (val is not None) and ('resume_state' in key or 'pretrain_network' in key):
            opt['path'][key] = osp.expanduser(val)

    if is_train:
        experiments_root = osp.join(root_path, 'experiments', opt['name'])
        opt['path']['experiments_root'] = experiments_root
        opt['path']['models'] = osp.join(experiments_root, 'models')
        opt['path']['training_states'] = osp.join(experiments_root, 'training_states')
        opt['path']['log'] = experiments_root
        opt['path']['visualization'] = osp.join(experiments_root, 'visualization')
    else:
        results_root = osp.join(root_path, 'results', opt['name'])
        opt['path']['results_root'] = results_root
        opt['path']['log'] = results_root
        opt['path']['visualization'] = osp.join(results_root, 'visualization')

    return opt
|
|
|
|
|
def dict2str(opt, indent_level=1):
    """dict to string for printing options.

    Nested dicts are rendered recursively between ``key:[`` and ``]``,
    indented two more spaces per nesting level. Keys and values are
    converted with ``str()``, so non-string keys (e.g. integer yaml keys)
    are accepted.

    Args:
        opt (dict): Option dict.
        indent_level (int): Indent level. Default: 1.

    Return:
        (str): Option string for printing.
    """
    pad = ' ' * (indent_level * 2)
    # Collect fragments and join once instead of repeated string +=.
    pieces = ['\n']
    for key, value in opt.items():
        if isinstance(value, dict):
            pieces.append(f'{pad}{key!s}:[')
            # The nested rendering starts with its own '\n'.
            pieces.append(dict2str(value, indent_level + 1))
            pieces.append(f'{pad}]\n')
        else:
            pieces.append(f'{pad}{key!s}: {value!s}\n')
    return ''.join(pieces)
|
|