File size: 1,263 Bytes
05c9ac2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
import os
import argparse

from huggingface_hub import snapshot_download

from mlagents_envs import logging_util
from mlagents_envs.logging_util import get_logger

# Module-level logger, obtained through the mlagents_envs logging helper.
logger = get_logger(__name__)
# Set the log level to INFO via logging_util (presumably so the logger.info
# confirmation emitted after a download is shown -- confirm against
# logging_util.set_log_level).
logging_util.set_log_level(logging_util.INFO)


def load_from_hf(repo_id: str, local_dir: str) -> None:
    """
    Download a model repository from the Hugging Face Hub.

    The snapshot is stored in a sub-directory of ``local_dir`` named after the
    repository, i.e. the part of ``repo_id`` that follows the last "/".

    :param repo_id: id of the model repository from the Hugging Face Hub, either
        "namespace/repo_name" or a bare "repo_name"
    :param local_dir: local directory under which the repository folder is placed
    :raises ValueError: if ``repo_id`` is empty or does not end in a repository name
    """
    if not repo_id:
        raise ValueError("repo_id must be a non-empty string")

    # rpartition also accepts ids without a namespace, which a strict
    # two-value unpack of split("/") rejects with an opaque unpacking error.
    repo_name = repo_id.rpartition("/")[2]
    if not repo_name:
        raise ValueError(
            f"Invalid repo_id {repo_id!r}: expected 'namespace/repo_name' or 'repo_name'"
        )

    local_dir = os.path.join(local_dir, repo_name)

    snapshot_download(repo_id=repo_id, local_dir=local_dir)

    logger.info(f"The repository {repo_id} has been downloaded to {local_dir}")


def main():
    """
    Command-line entry point: parse the arguments and download the repository.

    Arguments:
      --repo-id    (required) id of the model repository on the Hugging Face Hub
      --local-dir  destination directory, defaults to the current directory
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--repo-id",
        help="Repo id of the model repository from the Hugging Face Hub",
        type=str,
        # Without this, omitting the flag passes None to load_from_hf, which
        # then fails with an AttributeError instead of a usage message.
        required=True,
    )
    parser.add_argument(
        "--local-dir",
        help="Local destination of the repository",
        type=str,
        default="./",
    )
    args = parser.parse_args()

    # Load model from the Hub
    load_from_hf(args.repo_id, args.local_dir)


# Run main() only when this file is executed directly as a script (e.g. from a
# Python debugger); merely importing the module does not call main().
if __name__ == "__main__":
    main()