File size: 1,019 Bytes
32b2aaa
 
1df74c6
32b2aaa
 
 
 
 
 
 
 
 
 
 
 
1df74c6
32b2aaa
 
 
 
 
1df74c6
 
 
 
 
 
32b2aaa
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
import logging
from pathlib import Path
from typing import Union

import torch

# Name of the pretrained run. It is used both as the subdirectory in the
# Hugging Face repository URL (see get_source_url) and as the default local
# directory name under model_repo/ (see get_target_path).
RUN_NAME = "enhancer_stage2"

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)


def get_source_url(relpath: Union[str, Path]) -> str:
    """Build the Hugging Face download URL for one file of the ``RUN_NAME`` run.

    Args:
        relpath: Location of the file relative to the run directory inside the
            ``ResembleAI/resemble-enhance`` repository, e.g. ``"ds/G/latest"``.
            ``pathlib`` paths are accepted and always rendered with forward
            slashes, so the URL is correct on Windows too.

    Returns:
        The ``resolve/main`` URL for that file, with ``?download=true`` appended.
    """
    from pathlib import PurePath
    from urllib.parse import quote

    # str(Path(...)) uses the OS separator ("\\" on Windows), which would corrupt
    # the URL path; as_posix() always yields "/"-separated components.
    if isinstance(relpath, PurePath):
        relpath = relpath.as_posix()
    # Percent-encode characters that are not legal in a URL path (spaces,
    # non-ASCII, "?", "#") while keeping "/" as the path separator.
    # Ordinary names such as "ds/G/latest" pass through unchanged.
    encoded = quote(str(relpath), safe="/")
    return f"https://huggingface.co/ResembleAI/resemble-enhance/resolve/main/{RUN_NAME}/{encoded}?download=true"


def get_target_path(relpath: Union[str, Path], run_dir: Union[str, Path, None] = None):
    """Map a run-relative file name to its location on the local filesystem.

    Args:
        relpath: File location relative to the run directory.
        run_dir: Directory that holds the run's files. When omitted, it
            defaults to ``model_repo/<RUN_NAME>`` two directory levels above
            this source file.

    Returns:
        ``Path`` formed by joining the run directory and ``relpath``.
    """
    base = (
        Path(run_dir)
        if run_dir is not None
        else Path(__file__).parent.parent / "model_repo" / RUN_NAME
    )
    return base / relpath


def download(run_dir: Union[str, Path, None] = None):
    """Fetch any missing checkpoint files of the run into the local run directory.

    Each required file is looked up with ``get_target_path``. Files that
    already exist locally are left alone. Missing ones are fetched from
    ``get_source_url`` with ``torch.hub.download_url_to_file``, after their
    parent directory has been created.

    Args:
        run_dir: Local run directory; ``None`` selects the default used by
            ``get_target_path``.

    Returns:
        The local run directory as a ``Path``.
    """
    required_files = (
        "hparams.yaml",
        "ds/G/latest",
        "ds/G/default/mp_rank_00_model_states.pt",
    )
    for rel in required_files:
        dest = get_target_path(rel, run_dir=run_dir)
        if not dest.exists():
            source = get_source_url(rel)
            dest.parent.mkdir(parents=True, exist_ok=True)
            torch.hub.download_url_to_file(source, str(dest))
    # Joining an empty relpath yields the run directory itself.
    return get_target_path("", run_dir=run_dir)