|
from typing import Dict, Any |
|
|
|
import yaml |
|
|
|
from pathlib import Path |
|
|
|
from utils import get_full_file_path |
|
|
|
def load_config(file_name: str = 'config.yaml') -> Dict[str, Any]:
    """Load a YAML configuration file and return its parsed contents.

    The file name is turned into a full path by ``utils.get_full_file_path``
    (presumably relative to the project root — confirm against ``utils``)
    and parsed with ``yaml.safe_load``.

    Args:
        file_name: Name or relative path of the YAML file.
            Defaults to ``'config.yaml'``.

    Returns:
        The parsed YAML document. An empty file yields ``{}`` instead of
        ``None`` so the result matches the declared ``Dict`` return type.

    Raises:
        OSError: If the file cannot be opened (e.g. ``FileNotFoundError``).
        yaml.YAMLError: If the file is not valid YAML.
    """
    file_path = get_full_file_path(file_name=file_name)

    # YAML is a Unicode format; state the encoding explicitly instead of
    # relying on the platform locale (e.g. cp1252 on Windows).
    with open(file_path, 'r', encoding='utf-8') as file:
        config = yaml.safe_load(file)

    # safe_load returns None for an empty document; honour the Dict contract.
    return config if config is not None else {}
|
|
|
|
|
|
|
def get_weights_file_path(config, epoch: str) -> str:
    """Build the path of the weights file stored for a given epoch.

    The file is named ``<model_basename><epoch>.pt`` and lives inside the
    folder given by ``config['model']['model_folder']``. A relative folder
    produces a relative path.

    Args:
        config: Parsed configuration mapping containing a ``'model'`` section
            with ``'model_folder'`` and ``'model_basename'`` entries.
        epoch: Epoch label appended verbatim to the base name (e.g. ``'05'``).

    Returns:
        The weights file path as a string.
    """
    section = config['model']
    # Read the folder before the base name so a missing key reports the same
    # KeyError as before.
    folder = section['model_folder']
    weights_name = f"{section['model_basename']}{epoch}.pt"
    return str(Path(folder, weights_name))
|
|
|
|