Spaces:
Configuration error
Configuration error
# EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
# Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
# International Conference on Computer Vision (ICCV), 2023
import os
import time
from copy import deepcopy
from typing import Optional

import torch.backends.cudnn
import torch.distributed
import torch.nn as nn

from efficientvit.apps.data_provider import DataProvider
from efficientvit.apps.trainer.run_config import RunConfig
from efficientvit.apps.utils import (
    dist_init,
    dump_config,
    get_dist_local_rank,
    get_dist_rank,
    get_dist_size,
    init_modules,
    is_master,
    load_config,
    partial_update_config,
    zero_last_gamma,
)
from efficientvit.models.utils import (
    build_kwargs_from_config,
    load_state_dict_from_file,
)
# Public API of this module: helpers used to set up an experiment run
# (config files, distributed env, seeding, data, run config, model init).
__all__ = [
    "save_exp_config", "setup_dist_env", "setup_seed", "setup_exp_config",
    "setup_data_provider", "setup_run_config", "init_model",
]
def save_exp_config(exp_config: dict, path: str, name="config.yaml") -> None:
    """Dump ``exp_config`` to ``os.path.join(path, name)`` on the master process.

    Processes for which ``is_master()`` is false do nothing, so the file is
    written only by the master process.
    """
    if is_master():
        target_file = os.path.join(path, name)
        dump_config(exp_config, target_file)
def setup_dist_env(gpu: Optional[str] = None) -> None:
    """Initialise the distributed process group and bind this process to its GPU.

    Args:
        gpu: Value to store in ``CUDA_VISIBLE_DEVICES`` (e.g. ``"0,1"``). When
            ``None`` the existing environment setting is left untouched.
    """
    if gpu is not None:
        # NOTE: CUDA reads this variable when it initialises, so this presumably
        # needs to run before any CUDA call in the process to take effect.
        os.environ["CUDA_VISIBLE_DEVICES"] = gpu
    # Only initialise once; a caller may already have set up the process group.
    if not torch.distributed.is_initialized():
        dist_init()
    # Let cuDNN benchmark and pick the fastest convolution algorithms.
    torch.backends.cudnn.benchmark = True
    # Pin this process to the device matching its local rank.
    torch.cuda.set_device(get_dist_local_rank())
def setup_seed(manual_seed: int, resume: bool) -> None:
    """Seed torch's CPU and all CUDA RNGs with a rank-dependent seed.

    When ``resume`` is true, ``manual_seed`` is ignored and the current Unix
    time (truncated to ``int``) is used as the base instead. The distributed
    rank is added to the base so that each process receives a distinct seed.
    """
    base_seed = int(time.time()) if resume else manual_seed
    seed = get_dist_rank() + base_seed
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
def setup_exp_config(
    config_path: str, recursive=True, opt_args: Optional[dict] = None
) -> dict:
    """Build an experiment config by layering config files and overrides.

    Sources are applied in this order, each merged onto the previous result
    with ``partial_update_config``:

    1. (``recursive`` only) every ``default<ext>`` file found in the
       directories above ``config_path``, outermost directory first;
    2. ``config_path`` itself;
    3. ``opt_args``, if given.

    Args:
        config_path: Path to the experiment config file.
        recursive: Whether to also merge ``default<ext>`` files from ancestor
            directories, where ``<ext>`` is ``config_path``'s extension.
        opt_args: Optional overrides applied last.

    Returns:
        The merged config dict.

    Raises:
        ValueError: If ``config_path`` is not an existing file.
    """
    if not os.path.isfile(config_path):
        raise ValueError(f"Config file not found: {config_path}")

    # Collected nearest-first, then reversed so the most general file is the base.
    fpaths = [config_path]
    if recursive:
        extension = os.path.splitext(config_path)[1]
        cur_dir = config_path
        # os.path.dirname reaches a fixed point at the root ("/") for absolute
        # paths and at "" for relative ones ("" joins to the working directory).
        while os.path.dirname(cur_dir) != cur_dir:
            cur_dir = os.path.dirname(cur_dir)
            fpath = os.path.join(cur_dir, "default" + extension)
            if os.path.isfile(fpath):
                fpaths.append(fpath)
    fpaths = fpaths[::-1]

    # deepcopy so later in-place merges cannot alias the first loaded config.
    exp_config = deepcopy(load_config(fpaths[0]))
    for fpath in fpaths[1:]:
        partial_update_config(exp_config, load_config(fpath))

    # Command-line style overrides take the final say.
    if opt_args is not None:
        partial_update_config(exp_config, opt_args)
    return exp_config
def setup_data_provider(
    exp_config: dict,
    data_provider_classes: list[type[DataProvider]],
    is_distributed: bool = True,
) -> DataProvider:
    """Instantiate the data provider selected by ``exp_config["data_provider"]["dataset"]``.

    The ``exp_config["data_provider"]`` dict is updated in place with:

    - ``num_replicas`` / ``rank``: ``get_dist_size()`` / ``get_dist_rank()``,
      or ``None`` for both when ``is_distributed`` is false;
    - ``test_batch_size``: kept if already set to a truthy value, otherwise
      ``2 * base_batch_size``;
    - ``batch_size`` and ``train_batch_size``: both set to ``base_batch_size``.

    The class is chosen by matching the dataset name against each class's
    ``name`` attribute. Its constructor kwargs come from
    ``build_kwargs_from_config`` (presumably filtering the dict down to the
    parameters the class accepts; verify against that helper).

    Raises:
        KeyError: If no class in ``data_provider_classes`` has a ``name``
            equal to the configured dataset.
    """
    dp_config = exp_config["data_provider"]
    dp_config["num_replicas"] = get_dist_size() if is_distributed else None
    dp_config["rank"] = get_dist_rank() if is_distributed else None
    # `or` also replaces an explicit falsy value (None / 0) with the default.
    dp_config["test_batch_size"] = (
        dp_config.get("test_batch_size", None) or dp_config["base_batch_size"] * 2
    )
    dp_config["batch_size"] = dp_config["train_batch_size"] = dp_config[
        "base_batch_size"
    ]

    data_provider_lookup = {
        provider.name: provider for provider in data_provider_classes
    }
    dataset = dp_config["dataset"]
    if dataset not in data_provider_lookup:
        # Same exception type as the plain dict lookup, but with a usable message.
        raise KeyError(
            f"Unknown dataset {dataset!r}; available data providers: "
            f"{list(data_provider_lookup)}"
        )
    data_provider_class = data_provider_lookup[dataset]
    data_provider_kwargs = build_kwargs_from_config(dp_config, data_provider_class)
    return data_provider_class(**data_provider_kwargs)
def setup_run_config(exp_config: dict, run_config_cls: type[RunConfig]) -> RunConfig:
    """Construct the run config with the learning rate scaled by world size.

    ``exp_config["run_config"]["init_lr"]`` is written in place as
    ``base_lr * get_dist_size()`` (linear scaling with the number of
    processes). The whole ``run_config`` section is then passed as keyword
    arguments to ``run_config_cls``.
    """
    rc_section = exp_config["run_config"]
    rc_section["init_lr"] = rc_section["base_lr"] * get_dist_size()
    return run_config_cls(**rc_section)
def _checkpoint_exists(path: Optional[str], arg_name: str) -> bool:
    """Return True if ``path`` is an existing file; warn if it was given but is not."""
    if path is None:
        return False
    if os.path.isfile(path):
        return True
    # Best-effort: report the bad path instead of silently using random init.
    print(f"Warning: {arg_name} is not a file, skipping: {path}")
    return False


def init_model(
    network: nn.Module,
    init_from: Optional[str] = None,
    backbone_init_from: Optional[str] = None,
    rand_init="trunc_normal",
    last_gamma=None,
) -> None:
    """Initialise ``network`` in place, then optionally load pretrained weights.

    Steps:

    1. ``init_modules(network, init_type=rand_init)``.
    2. If ``last_gamma`` is not None, ``zero_last_gamma(network, last_gamma)``.
    3. Load weights from the first usable source, in this order:

       - ``init_from`` into the whole network;
       - ``backbone_init_from`` into ``network.backbone``;
       - otherwise keep the random initialisation.

       A path counts as usable only if it names an existing file. A path that
       was given but is not a file prints a warning and is skipped rather than
       raising.

    Args:
        network: Model to initialise; must expose ``.backbone`` if
            ``backbone_init_from`` is used.
        init_from: Optional checkpoint path for the full network.
        backbone_init_from: Optional checkpoint path for ``network.backbone``.
        rand_init: ``init_type`` forwarded to ``init_modules``.
        last_gamma: Forwarded to ``zero_last_gamma`` when not None.
    """
    init_modules(network, init_type=rand_init)
    # Zero the gamma of the last norm layer in each block when requested.
    if last_gamma is not None:
        zero_last_gamma(network, last_gamma)

    # `elif` short-circuits: backbone_init_from is only checked (and warned
    # about) when init_from was not loaded.
    if _checkpoint_exists(init_from, "init_from"):
        network.load_state_dict(load_state_dict_from_file(init_from))
        print(f"Loaded init from {init_from}")
    elif _checkpoint_exists(backbone_init_from, "backbone_init_from"):
        network.backbone.load_state_dict(load_state_dict_from_file(backbone_init_from))
        print(f"Loaded backbone init from {backbone_init_from}")
    else:
        print(f"Random init ({rand_init}) with last gamma {last_gamma}")