|
|
|
|
|
import os |
|
import torch |
|
|
|
from detectron2.config import get_cfg |
|
from detectron2.engine import default_setup |
|
from detectron2.modeling import build_model |
|
|
|
from densepose import add_dataset_category_config, add_densepose_config |
|
|
|
# Name of the configuration directory, resolved next to the parent of this module's directory
_BASE_CONFIG_DIR = "configs"

# Sub-directories of the base configuration directory holding specialized config sets
_EVOLUTION_CONFIG_SUB_DIR = "evolution"
_QUICK_SCHEDULES_CONFIG_SUB_DIR = "quick_schedules"

# File-name prefix of configs that the collectors skip
_BASE_CONFIG_FILE_PREFIX = "Base-"

# Only files with exactly this extension are collected as configuration files
_CONFIG_FILE_EXT = ".yaml"
|
|
|
|
|
def _get_base_config_dir():
    """
    Compute the path of the base configuration directory.

    The directory is the ``_BASE_CONFIG_DIR`` folder located one level above the directory
    containing this module (symlinks in this module's path are resolved). The returned path
    is not normalized, so it still contains the ".." component.
    """
    this_file = os.path.realpath(__file__)
    this_dir = os.path.dirname(this_file)
    return os.path.join(this_dir, "..", _BASE_CONFIG_DIR)
|
|
|
|
|
def _get_evolution_config_dir():
    """
    Compute the path of the evolution configuration directory, a sub-directory of the base
    configuration directory.
    """
    base_dir = _get_base_config_dir()
    evolution_dir = os.path.join(base_dir, _EVOLUTION_CONFIG_SUB_DIR)
    return evolution_dir
|
|
|
|
|
def _get_quick_schedules_config_dir():
    """
    Compute the path of the quick-schedules configuration directory, a sub-directory of the
    base configuration directory.
    """
    base_dir = _get_base_config_dir()
    quick_dir = os.path.join(base_dir, _QUICK_SCHEDULES_CONFIG_SUB_DIR)
    return quick_dir
|
|
|
|
|
def _collect_config_files(config_dir):
    """
    Collect the configuration files located directly in ``config_dir`` (no recursion).

    A directory entry is collected when all of the following hold:
    - it is a regular file (``os.path.isfile``, which follows symlinks);
    - its extension is exactly ``_CONFIG_FILE_EXT``;
    - its name does not start with ``_BASE_CONFIG_FILE_PREFIX``.

    Args:
        config_dir (str): directory to scan
    Returns:
        list[str]: paths of the collected files relative to the base configuration directory,
            sorted by file name so that the result is deterministic
    """
    start = _get_base_config_dir()
    results = []
    # os.listdir returns entries in arbitrary order; sort so callers (e.g. parametrized
    # tests) see a stable list across platforms and runs.
    for entry in sorted(os.listdir(config_dir)):
        path = os.path.join(config_dir, entry)
        if not os.path.isfile(path):
            continue
        _, ext = os.path.splitext(entry)
        if ext != _CONFIG_FILE_EXT or entry.startswith(_BASE_CONFIG_FILE_PREFIX):
            continue
        results.append(os.path.relpath(path, start))
    return results
|
|
|
|
|
def get_config_files():
    """
    List the configuration files found directly in the base configuration directory.

    Returned paths are relative to the base configuration directory.
    """
    base_dir = _get_base_config_dir()
    return _collect_config_files(base_dir)
|
|
|
|
|
def get_evolution_config_files():
    """
    List the configuration files found directly in the evolution sub-directory.

    Returned paths are relative to the base configuration directory, so they include the
    sub-directory name.
    """
    evolution_dir = _get_evolution_config_dir()
    return _collect_config_files(evolution_dir)
|
|
|
|
|
def get_quick_schedules_config_files():
    """
    List the configuration files found directly in the quick-schedules sub-directory.

    Returned paths are relative to the base configuration directory, so they include the
    sub-directory name.
    """
    quick_dir = _get_quick_schedules_config_dir()
    return _collect_config_files(quick_dir)
|
|
|
|
|
def _get_model_config(config_file):
    """
    Load and return the configuration from the specified file (relative to the base
    configuration directory).

    The config starts from detectron2's defaults (``get_cfg``), is extended with the
    dataset-category and DensePose options, and is then overridden by ``config_file``.
    When CUDA is unavailable, the model device is switched to CPU.

    Args:
        config_file (str): config path relative to the base configuration directory
    Returns:
        the loaded configuration; it is not frozen here
    """
    cfg = get_cfg()
    add_dataset_category_config(cfg)
    add_densepose_config(cfg)
    path = os.path.join(_get_base_config_dir(), config_file)
    cfg.merge_from_file(path)
    if not torch.cuda.is_available():
        # The device option lives under the MODEL node. Assigning ``cfg.MODEL_DEVICE`` on an
        # unfrozen config would only add an unrelated top-level key and leave the model on CUDA.
        cfg.MODEL.DEVICE = "cpu"
    return cfg
|
|
|
|
|
def get_model(config_file):
    """
    Build a model from a configuration file.

    Args:
        config_file (str): config path relative to the base configuration directory
    Returns:
        the model produced by detectron2's ``build_model`` for the loaded configuration
    """
    model_cfg = _get_model_config(config_file)
    model = build_model(model_cfg)
    return model
|
|
|
|
|
def setup(config_file):
    """
    Load the configuration from a file (relative to the base configuration directory),
    freeze it, and run detectron2's ``default_setup`` on it.

    An empty dict stands in for command-line arguments. Nothing is returned.
    """
    config = _get_model_config(config_file)
    config.freeze()
    no_cli_args = {}
    default_setup(config, no_cli_args)
|
|