Spaces:
Sleeping
Sleeping
File size: 1,019 Bytes
32b2aaa 1df74c6 32b2aaa 1df74c6 32b2aaa 1df74c6 32b2aaa |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 |
import logging
from pathlib import Path
from typing import Union
import torch
# Checkpoint run identifier: used as the folder name inside the remote repo
# (see get_source_url) and as the local directory name under model_repo/
# (see get_target_path).
RUN_NAME: str = "enhancer_stage2"
# Module-level logger, named after this module.
logger = logging.getLogger(name=__name__)
def get_source_url(relpath: Union[str, Path], *, run_name: Union[str, None] = None) -> str:
    """Build the Hugging Face download URL for one file of a checkpoint run.

    Args:
        relpath: Location of the file relative to the run folder inside the
            ``ResembleAI/resemble-enhance`` repository. ``str`` values are used
            verbatim; path objects are converted to forward-slash form.
        run_name: Run folder inside the repository. Defaults to ``RUN_NAME``.

    Returns:
        A ``resolve/main`` URL for the file with ``?download=true`` appended.
    """
    from pathlib import PurePath

    if run_name is None:
        run_name = RUN_NAME
    if isinstance(relpath, PurePath):
        # str() of a path uses the OS separator (a backslash on Windows), which
        # is not a valid URL path separator; URLs always need forward slashes.
        relpath = relpath.as_posix()
    return f"https://huggingface.co/ResembleAI/resemble-enhance/resolve/main/{run_name}/{relpath}?download=true"
def get_target_path(relpath: Union[str, Path], run_dir: Union[str, Path, None] = None) -> Path:
    """Return the local filesystem location of *relpath* for this run.

    Args:
        relpath: File path relative to the run directory.
        run_dir: Root directory of the run. When omitted, the root is
            ``model_repo/<RUN_NAME>`` inside the parent of this module's
            directory.

    Returns:
        The run directory joined with *relpath*.
    """
    if run_dir is not None:
        root = Path(run_dir)
    else:
        root = Path(__file__).parent.parent / "model_repo" / RUN_NAME
    return root / relpath
def download(run_dir: Union[str, Path, None] = None) -> Path:
    """Make sure the run's checkpoint files exist locally, fetching missing ones.

    Each file is skipped when its local target already exists. This is only an
    existence check; file contents are not verified.

    Args:
        run_dir: Local run directory; forwarded to ``get_target_path`` (which
            picks its default location when this is ``None``).

    Returns:
        The local run directory.
    """
    checkpoint_files = (
        "hparams.yaml",
        "ds/G/latest",
        "ds/G/default/mp_rank_00_model_states.pt",
    )
    for rel in checkpoint_files:
        local = get_target_path(rel, run_dir=run_dir)
        if not local.exists():
            # Nested files (e.g. ds/G/default/...) need their folders created first.
            local.parent.mkdir(parents=True, exist_ok=True)
            torch.hub.download_url_to_file(get_source_url(rel), str(local))
    return get_target_path("", run_dir=run_dir)
|