luxmorocco's picture
Upload 86 files
4efbc62 verified
raw
history blame
3.45 kB
# EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
# Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
# International Conference on Computer Vision (ICCV), 2023
import os
import yaml
__all__ = [
"parse_with_yaml",
"parse_unknown_args",
"partial_update_config",
"resolve_and_load_config",
"load_config",
"dump_config",
]
def parse_with_yaml(config_str: str) -> str or dict:
try:
# add space manually for dict
if "{" in config_str and "}" in config_str and ":" in config_str:
out_str = config_str.replace(":", ": ")
else:
out_str = config_str
return yaml.safe_load(out_str)
except ValueError:
# return raw string if parsing fails
return config_str
def parse_unknown_args(unknown: list) -> dict:
"""Parse unknown args."""
index = 0
parsed_dict = {}
while index < len(unknown):
key, val = unknown[index], unknown[index + 1]
index += 2
if not key.startswith("--"):
continue
key = key[2:]
# try parsing with either dot notation or full yaml notation
# Note that the vanilla case "--key value" will be parsed the same
if "." in key:
# key == a.b.c, val == val --> parsed_dict[a][b][c] = val
keys = key.split(".")
dict_to_update = parsed_dict
for key in keys[:-1]:
if not (
key in dict_to_update and isinstance(dict_to_update[key], dict)
):
dict_to_update[key] = {}
dict_to_update = dict_to_update[key]
dict_to_update[keys[-1]] = parse_with_yaml(
val
) # so we can parse lists, bools, etc...
else:
parsed_dict[key] = parse_with_yaml(val)
return parsed_dict
def partial_update_config(config: dict, partial_config: dict) -> dict:
for key in partial_config:
if (
key in config
and isinstance(partial_config[key], dict)
and isinstance(config[key], dict)
):
partial_update_config(config[key], partial_config[key])
else:
config[key] = partial_config[key]
return config
def resolve_and_load_config(path: str, config_name="config.yaml") -> dict:
path = os.path.realpath(os.path.expanduser(path))
if os.path.isdir(path):
config_path = os.path.join(path, config_name)
else:
config_path = path
if os.path.isfile(config_path):
pass
else:
raise Exception(f"Cannot find a valid config at {path}")
config = load_config(config_path)
return config
class SafeLoaderWithTuple(yaml.SafeLoader):
"""A yaml safe loader with python tuple loading capabilities."""
def construct_python_tuple(self, node):
return tuple(self.construct_sequence(node))
SafeLoaderWithTuple.add_constructor(
"tag:yaml.org,2002:python/tuple", SafeLoaderWithTuple.construct_python_tuple
)
def load_config(filename: str) -> dict:
"""Load a yaml file."""
filename = os.path.realpath(os.path.expanduser(filename))
return yaml.load(open(filename), Loader=SafeLoaderWithTuple)
def dump_config(config: dict, filename: str) -> None:
"""Dump a config file"""
filename = os.path.realpath(os.path.expanduser(filename))
yaml.dump(config, open(filename, "w"), sort_keys=False)