File size: 701 Bytes
bc1ada8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
from typing import Dict, Any

import yaml

from pathlib import Path

from utils import get_full_file_path

def load_config(file_name: str = 'config.yaml') -> Dict[str, Any]:
    """
    Loads a YAML config from a relative file path.

    The file name is turned into a full path by ``get_full_file_path``
    (imported from ``utils``; presumably resolves against the project
    root — confirm there) and parsed with ``yaml.safe_load``, which does
    not construct arbitrary Python objects.

    Args:
        file_name: Name / relative path of the YAML file. Defaults to
            ``'config.yaml'``.

    Returns:
        The parsed top-level YAML value. An empty file yields ``{}``
        instead of ``None`` so the result honours the declared return type.

    Raises:
        OSError: If the file cannot be opened (e.g. FileNotFoundError).
        yaml.YAMLError: If the file contents are not valid YAML.
    """
    file_path = get_full_file_path(file_name=file_name)
    # Explicit encoding: open()'s default depends on the platform locale,
    # which would make the same config file parse differently across machines.
    with open(file_path, 'r', encoding='utf-8') as file:
        config = yaml.safe_load(file)

    # safe_load returns None for an empty (or comment-only) document.
    if config is None:
        return {}
    return config

# print(list(load_config().items()))

def get_weights_file_path(config, epoch: str) -> str:
    """
    Return the path of the weights checkpoint for ``epoch``.

    The file name is ``<model_basename><epoch>.pt``, joined under
    ``config['model']['model_folder']`` starting from ``'.'``. The result is
    returned as a plain string. A missing ``'model'``, ``'model_folder'`` or
    ``'model_basename'`` key raises ``KeyError`` (folder is looked up first).
    """
    model_section = config['model']
    folder = model_section['model_folder']
    checkpoint_name = '{}{}.pt'.format(model_section['model_basename'], epoch)
    # joinpath with several segments is equivalent to chaining the / operator.
    return str(Path('.').joinpath(folder, checkpoint_name))