mazpie's picture
Initial commit
2d9a728
raw
history blame
2.02 kB
import logging
import os
import sys
from os.path import dirname, join
from utils.config import Config
from utils.distributed import init_distributed_mode, is_main_process
from utils.logger import setup_logger
logger = logging.getLogger(__name__)
def setup_config():
    """Build the run configuration via ``Config.get_config()``.

    Per the original note on this function, the YAML config and the
    command-line config are combined with OmegaConf, and types are converted
    along the way, e.g. ``'None'`` (str) --> ``None`` (None). That work
    happens inside ``Config.get_config()``, not here.

    The only adjustment made here: when ``debug`` is truthy, wandb logging is
    switched off.

    Returns:
        The config object produced by ``Config.get_config()``.
    """
    cfg = Config.get_config()
    debugging = cfg.debug
    if debugging:
        # Debug runs should not be reported to wandb.
        cfg.wandb.enable = False
    return cfg
def setup_evaluate_config(config):
    """Apply the default settings used for evaluation runs.

    Disables wandb and, when no output directory was given, falls back to an
    ``eval`` folder located next to the pretrained checkpoint.

    Args:
        config: config object; must have a truthy ``evaluate`` attribute.
            It is modified in place.

    Returns:
        The same ``config`` object.
    """
    assert config.evaluate
    config.wandb.enable = False

    if config.output_dir is not None:
        return config

    # No explicit output dir: write results beside the checkpoint.
    checkpoint_dir = dirname(config.pretrained_path)
    config.output_dir = join(checkpoint_dir, "eval")
    return config
def setup_output_dir(output_dir, excludes=("code",)):
    """Create ``output_dir``, or warn if an existing one already has content.

    If ``output_dir`` does not exist it is created (including parents). If it
    does exist, its entries are inspected and a warning listing the unexpected
    ones is logged; nothing is deleted and no error is raised.

    Entries are considered expected, and are not reported, when they are
    named in ``excludes`` or when their name contains ``"slurm"`` or
    ``".out"``.

    Args:
        output_dir: path of the directory to create or inspect.
        excludes: iterable of entry names that may already be present.
            ``None`` is treated as "no excludes".

    Returns:
        None.
    """
    if not os.path.exists(output_dir):
        os.makedirs(output_dir, exist_ok=False)
        return

    allowed = set(excludes or ())
    # Sorted so the logged message is deterministic across runs.
    leftovers = sorted(
        entry
        for entry in os.listdir(output_dir)
        if entry not in allowed
        and "slurm" not in entry
        and ".out" not in entry
    )
    # Deliberately a warning rather than a hard failure, so an existing
    # output dir can be reused. Only speak up when something is left over.
    if leftovers:
        logger.warning(f"remaining dirs or files: {leftovers}")
def setup_main():
    """Prepare config, distributed mode, output directory and logging.

    Shared by pretraining and all downstream tasks. Steps, in order:

    1. Build the config with ``setup_config()``.
    2. If the config has a truthy ``evaluate`` attribute, apply the
       evaluation defaults via ``setup_evaluate_config()``.
    3. Call ``init_distributed_mode(config)``.
    4. On the main process only: prepare the output directory, set up the
       logger, log the config, and dump it to ``config.json`` inside the
       output directory.

    Returns:
        The config object.
    """
    config = setup_config()
    if getattr(config, "evaluate", False):
        config = setup_evaluate_config(config)

    init_distributed_mode(config)

    # NOTE(review): every step below is assumed to be main-process-only
    # (directory setup, logger setup, config logging and dump) -- confirm.
    if not is_main_process():
        return config

    out_dir = config.output_dir
    setup_output_dir(out_dir, excludes=["code"])
    setup_logger(output=out_dir, color=True, name="vindlu")
    logger.info(f"config: {Config.pretty_text(config)}")
    Config.dump(config, os.path.join(out_dir, "config.json"))
    return config