|
import os |
|
import argparse |
|
|
|
from huggingface_hub import snapshot_download |
|
|
|
from mlagents_envs import logging_util |
|
from mlagents_envs.logging_util import get_logger |
|
|
|
# Module-level logger obtained through ML-Agents' logging utilities.
logger = get_logger(__name__)
# Set the ML-Agents log level to INFO so the logger.info() message emitted after
# a download is shown (assumes set_log_level applies to loggers obtained via
# get_logger — confirm in mlagents_envs.logging_util).
logging_util.set_log_level(logging_util.INFO)
|
|
|
|
|
def load_from_hf(repo_id: str, local_dir: str) -> None: |
|
""" |
|
Download a model from Hugging Face Hub. |
|
:param repo_id: id of the model repository from the Hugging Face Hub |
|
:param local_dir: local destination of the repository |
|
""" |
|
_, repo_name = repo_id.split("/") |
|
|
|
local_dir = os.path.join(local_dir, repo_name) |
|
|
|
snapshot_download(repo_id=repo_id, local_dir=local_dir) |
|
|
|
logger.info(f"The repository {repo_id} has been downloaded to {local_dir}") |
|
|
|
|
|
def main():
    """Parse command-line arguments and download the requested repository."""
    parser = argparse.ArgumentParser(
        description="Download a model repository from the Hugging Face Hub."
    )
    parser.add_argument(
        "--repo-id",
        help="Repo id of the model repository from the Hugging Face Hub",
        type=str,
        # Without this, omitting --repo-id passes None to load_from_hf and the
        # script crashes with an AttributeError instead of a usage message.
        required=True,
    )
    parser.add_argument(
        "--local-dir",
        help="Local destination of the repository",
        type=str,
        default="./",
    )
    args = parser.parse_args()

    load_from_hf(args.repo_id, args.local_dir)
|
|
|
|
|
|
|
# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
|
|