# NOTE(review): removed web-viewer scrape artifacts (pagination header,
# "Configuration error" lines, byte count, commit hash, line-number strip)
# that were not part of the original source file.
# EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
# Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
# International Conference on Computer Vision (ICCV), 2023
import os
import time
from copy import deepcopy
from typing import Optional

import torch.backends.cudnn
import torch.distributed
import torch.nn as nn

from efficientvit.apps.data_provider import DataProvider
from efficientvit.apps.trainer.run_config import RunConfig
from efficientvit.apps.utils import (
    dist_init,
    dump_config,
    get_dist_local_rank,
    get_dist_rank,
    get_dist_size,
    init_modules,
    is_master,
    load_config,
    partial_update_config,
    zero_last_gamma,
)
from efficientvit.models.utils import (
    build_kwargs_from_config,
    load_state_dict_from_file,
)
# Public API of this module: experiment-setup helpers for config loading,
# distributed environment, seeding, data providers, run config, and model init.
__all__ = [
    "save_exp_config",
    "setup_dist_env",
    "setup_seed",
    "setup_exp_config",
    "setup_data_provider",
    "setup_run_config",
    "init_model",
]
def save_exp_config(exp_config: dict, path: str, name="config.yaml") -> None:
    """Dump the experiment config to ``path/name``; only the master process writes."""
    if is_master():
        dump_config(exp_config, os.path.join(path, name))
def setup_dist_env(gpu: Optional[str] = None) -> None:
    """Initialize the distributed environment and bind this process to its GPU.

    Args:
        gpu: value assigned to ``CUDA_VISIBLE_DEVICES`` (e.g. ``"0,1"``);
            ``None`` leaves the environment variable untouched.
    """
    # Fix: the original annotation `str or None` evaluates to plain `str`
    # (the `or` collapses at definition time); Optional[str] is correct.
    if gpu is not None:
        os.environ["CUDA_VISIBLE_DEVICES"] = gpu
    if not torch.distributed.is_initialized():
        dist_init()
    # cudnn autotuner: fastest conv algorithms for fixed input shapes
    torch.backends.cudnn.benchmark = True
    torch.cuda.set_device(get_dist_local_rank())
def setup_seed(manual_seed: int, resume: bool) -> None:
    """Seed torch RNGs, giving each distributed rank a distinct seed.

    When resuming, the base seed is replaced by the current wall-clock time
    so the resumed run does not replay the original random stream.
    """
    base_seed = int(time.time()) if resume else manual_seed
    rank_seed = base_seed + get_dist_rank()
    torch.manual_seed(rank_seed)
    torch.cuda.manual_seed_all(rank_seed)
def setup_exp_config(
    config_path: str, recursive=True, opt_args: Optional[dict] = None
) -> dict:
    """Load an experiment config, layering in ancestor-directory defaults.

    Walks from the config file's directory up to the filesystem root
    collecting ``default<ext>`` files, applies them outermost-first, then the
    config file itself, then ``opt_args`` overrides last.

    Args:
        config_path: path to the experiment config file.
        recursive: also merge ``default`` configs found in ancestor dirs.
        opt_args: optional overrides (e.g. from the command line).

    Returns:
        The merged config dict.

    Raises:
        ValueError: if ``config_path`` is not an existing file.
    """
    # Fix: the original annotation `dict or None` evaluates to plain `dict`;
    # Optional[dict] is the correct type.
    if not os.path.isfile(config_path):
        raise ValueError(config_path)

    fpaths = [config_path]
    if recursive:
        extension = os.path.splitext(config_path)[1]
        # dirname(root) == root, so this terminates at the filesystem root
        while os.path.dirname(config_path) != config_path:
            config_path = os.path.dirname(config_path)
            fpath = os.path.join(config_path, "default" + extension)
            if os.path.isfile(fpath):
                fpaths.append(fpath)
        fpaths = fpaths[::-1]  # apply outermost defaults first

    exp_config = deepcopy(load_config(fpaths[0]))
    for fpath in fpaths[1:]:
        partial_update_config(exp_config, load_config(fpath))
    # command-line / caller overrides win over everything from disk
    if opt_args is not None:
        partial_update_config(exp_config, opt_args)
    return exp_config
def setup_data_provider(
    exp_config: dict,
    data_provider_classes: list[type[DataProvider]],
    is_distributed: bool = True,
) -> DataProvider:
    """Build the data provider selected by ``exp_config["data_provider"]["dataset"]``.

    Mutates the ``data_provider`` sub-config in place: fills in distributed
    replica/rank info and derives train/test batch sizes from
    ``base_batch_size``.
    """
    dp_config = exp_config["data_provider"]

    if is_distributed:
        dp_config["num_replicas"] = get_dist_size()
        dp_config["rank"] = get_dist_rank()
    else:
        dp_config["num_replicas"] = None
        dp_config["rank"] = None

    # test batch size defaults to twice the base batch size when unset/falsy
    if not dp_config.get("test_batch_size", None):
        dp_config["test_batch_size"] = dp_config["base_batch_size"] * 2
    base_bs = dp_config["base_batch_size"]
    dp_config["batch_size"] = base_bs
    dp_config["train_batch_size"] = base_bs

    registry = {cls.name: cls for cls in data_provider_classes}
    provider_cls = registry[dp_config["dataset"]]
    provider_kwargs = build_kwargs_from_config(dp_config, provider_cls)
    return provider_cls(**provider_kwargs)
def setup_run_config(exp_config: dict, run_config_cls: type[RunConfig]) -> RunConfig:
    """Instantiate ``run_config_cls`` from config, scaling the LR by world size."""
    run_cfg = exp_config["run_config"]
    # linear LR scaling: init_lr = base_lr * number of distributed processes
    run_cfg["init_lr"] = run_cfg["base_lr"] * get_dist_size()
    return run_config_cls(**run_cfg)
def init_model(
    network: nn.Module,
    init_from: Optional[str] = None,
    backbone_init_from: Optional[str] = None,
    rand_init="trunc_normal",
    last_gamma=None,
) -> None:
    """Initialize ``network`` weights in place, optionally from a checkpoint.

    Args:
        network: model to initialize.
        init_from: checkpoint path for the full network (takes precedence
            over ``backbone_init_from``).
        backbone_init_from: checkpoint path loaded into ``network.backbone``
            only.
        rand_init: random-initialization scheme passed to ``init_modules``.
        last_gamma: if not None, value set on the last BN gamma in each block
            via ``zero_last_gamma``.
    """
    # Fix: the original annotations `str or None` evaluate to plain `str`;
    # Optional[str] is the correct type for these defaults.
    # Random initialization first; a loaded checkpoint then overwrites it.
    init_modules(network, init_type=rand_init)
    if last_gamma is not None:
        zero_last_gamma(network, last_gamma)
    # Load pretrained weights if a checkpoint file exists; otherwise keep
    # the random initialization (missing paths fall through silently).
    if init_from is not None and os.path.isfile(init_from):
        network.load_state_dict(load_state_dict_from_file(init_from))
        print(f"Loaded init from {init_from}")
    elif backbone_init_from is not None and os.path.isfile(backbone_init_from):
        network.backbone.load_state_dict(load_state_dict_from_file(backbone_init_from))
        print(f"Loaded backbone init from {backbone_init_from}")
    else:
        print(f"Random init ({rand_init}) with last gamma {last_gamma}")
# NOTE(review): removed trailing web-viewer scrape artifact ("|").