|
import os |
|
import argparse |
|
|
|
import yaml |
|
import json |
|
|
|
from pathlib import Path |
|
|
|
from huggingface_hub import HfApi |
|
from huggingface_hub.repocard import metadata_save |
|
|
|
from typing import Tuple |
|
|
|
from mlagents_envs import logging_util |
|
from mlagents_envs.logging_util import get_logger |
|
|
|
# Module logger used for the progress messages emitted by package_to_hub().
logger = get_logger(__name__)
# Set the ML-Agents logging level to INFO (a side effect at import time) so that
# the logger.info() progress messages are shown.
# NOTE(review): assumes set_log_level also applies to loggers obtained via get_logger — confirm.
logging_util.set_log_level(logging_util.INFO)
|
|
|
|
|
def _generate_config(local_dir: Path, configfile_name: str) -> None:
    """
    Generate a config.json file from the YAML training configuration.

    The YAML file is parsed with ``yaml.safe_load`` and the resulting mapping is
    written as JSON to ``<local_dir>/config.json``.

    :param local_dir: path of the local directory (the run_id folder)
    :param configfile_name: name of the yaml config file (by default configuration.yaml)
    :raises ValueError: if the YAML file does not contain a mapping at the top level
    """
    run_dir = Path(local_dir)
    yaml_path = run_dir / configfile_name

    # Use an explicit UTF-8 encoding: the platform default (e.g. cp1252 on
    # Windows) cannot decode non-ASCII content in the YAML file.
    with open(yaml_path, encoding="utf-8") as yaml_in:
        yaml_object = yaml.safe_load(yaml_in)

    # An empty file loads as None. The model card needs a mapping with a
    # "behaviors" section, so fail here with a clear message instead of
    # writing "null" and crashing later.
    if not isinstance(yaml_object, dict):
        raise ValueError(
            f"{yaml_path} does not contain a YAML mapping; cannot generate config.json"
        )

    with open(run_dir / "config.json", "w", encoding="utf-8") as json_out:
        json.dump(yaml_object, json_out)
|
|
|
|
|
def _generate_metadata(model_name: str, env_id: str) -> dict: |
|
""" |
|
Define the tags for the model card |
|
:param model_name: name of the model |
|
:param env_id: name of the environment |
|
""" |
|
env_tag = "ML-Agents-" + env_id |
|
|
|
metadata = {} |
|
metadata["library_name"] = "ml-agents" |
|
metadata["tags"] = [ |
|
env_id, |
|
"deep-reinforcement-learning", |
|
"reinforcement-learning", |
|
env_tag, |
|
] |
|
|
|
return metadata |
|
|
|
|
|
def _generate_model_card(
    local_dir: Path, configfile_name: str, repo_id: str
) -> Tuple[str, dict]:
    """
    Generate the model card (Markdown README body) and its metadata.

    Reads ``<local_dir>/config.json`` (as written by ``_generate_config``).
    The first entry of its "behaviors" section is used as the environment id,
    and that entry's "trainer_type" is used as the model name.

    :param local_dir: local path of the directory
    :param configfile_name: name of the yaml config file (by default
        configuration.yaml); not read here, kept for interface compatibility
    :param repo_id: id of the model repository from the Hugging Face Hub
    :return: ``(model_card, metadata)``, i.e. the Markdown text and the dict
        built by ``_generate_metadata``
    :raises ValueError: if config.json has no (or an empty) "behaviors" section
    """
    config_path = Path(local_dir) / "config.json"
    with open(config_path, encoding="utf-8") as f:
        data = json.load(f)

    behaviors = data.get("behaviors") or {}
    if not behaviors:
        raise ValueError(
            f"{config_path} has no 'behaviors' section; cannot build the model card"
        )

    # Only the first behavior is described in the card.
    env_id = next(iter(behaviors))
    model_name = behaviors[env_id]["trainer_type"]

    metadata = _generate_metadata(model_name, env_id)

    model_card = f"""
# **{model_name}** Agent playing **{env_id}**
This is a trained model of a **{model_name}** agent playing **{env_id}**
using the [Unity ML-Agents Library](https://github.com/Unity-Technologies/ml-agents).

## Usage (with ML-Agents)
The Documentation: https://unity-technologies.github.io/ml-agents/ML-Agents-Toolkit-Documentation/

We wrote a complete tutorial to learn to train your first agent using ML-Agents and publish it to the Hub:
- A *short tutorial* where you teach Huggy the Dog 🐶 to fetch the stick and then play with him directly in your
  browser: https://huggingface.co/learn/deep-rl-course/unitbonus1/introduction
- A *longer tutorial* to understand how ML-Agents works:
  https://huggingface.co/learn/deep-rl-course/unit5/introduction

### Resume the training
```bash
mlagents-learn <your_configuration_file_path.yaml> --run-id=<run_id> --resume
```

### Watch your Agent play
You can watch your agent **playing directly in your browser**

1. If the environment is part of ML-Agents official environments, go to https://huggingface.co/unity
2. Step 1: Find your model_id: {repo_id}
3. Step 2: Select your `*.nn` / `*.onnx` file
4. Click on Watch the agent play 👀
"""

    return model_card, metadata
|
|
|
|
|
def _save_model_card(
    local_dir: Path, generated_model_card: str, metadata: dict
) -> None:
    """
    Write the model card to ``<local_dir>/README.md`` and attach its metadata.

    The Markdown body is written first (UTF-8, overwriting any existing
    README.md). ``metadata_save`` from huggingface_hub is then called on that
    same file with ``metadata``.

    :param local_dir: local directory path
    :param generated_model_card: model card text produced by _generate_model_card()
    :param metadata: metadata dict produced by _generate_metadata()
    """
    target = local_dir / "README.md"
    with open(target, "w", encoding="utf-8") as handle:
        handle.write(generated_model_card)
    metadata_save(target, metadata)
|
|
|
|
|
def package_to_hub(
    run_id: str,
    path_of_run_id: Path,
    repo_id: str,
    commit_message: str,
    configfile_name: str,
) -> None:
    """
    Generate the model card and upload the run_id folder, with all its files,
    to the Hugging Face Hub.

    Steps:
    1. Validate ``repo_id``.
    2. Write ``config.json`` and ``README.md`` (model card plus metadata) into
       ``path_of_run_id``.
    3. Create the Hub repository if it does not already exist.
    4. Upload the whole folder in one commit.

    :param run_id: name of the run (used in log messages)
    :param path_of_run_id: path of the run_id folder that contains the onnx model
    :param repo_id: id of the model repository on the Hugging Face Hub,
        in the form "<namespace>/<repo_name>"
    :param commit_message: commit message
    :param configfile_name: name of the yaml config file (by default configuration.yaml)
    :raises ValueError: if ``repo_id`` is not of the form "<namespace>/<repo_name>"
    :raises FileNotFoundError: if ``path_of_run_id`` is not an existing directory
    """
    logger.info(
        f"This function will create a model card and upload your {run_id} "
        "into the Hugging Face Hub. This is a work in progress: if you encounter "
        "a bug, please open an issue"
    )

    # The Hub expects "<namespace>/<repo_name>". Validate it up front with a
    # readable message instead of an opaque tuple-unpacking error.
    parts = repo_id.split("/") if isinstance(repo_id, str) else []
    if len(parts) != 2 or not all(parts):
        raise ValueError(
            f"Invalid repo_id {repo_id!r}: expected '<namespace>/<repo_name>'"
        )

    local_path = Path(path_of_run_id)
    if not local_path.is_dir():
        raise FileNotFoundError(f"Run folder not found: {local_path}")

    # Build every local artifact before touching the Hub, so that a missing or
    # invalid configuration does not leave an empty repository behind.
    _generate_config(local_path, configfile_name)
    generated_model_card, metadata = _generate_model_card(
        local_path, configfile_name, repo_id
    )
    _save_model_card(local_path, generated_model_card, metadata)

    api = HfApi()
    repo_url = api.create_repo(repo_id=repo_id, exist_ok=True)

    logger.info(f"Pushing repo {run_id} to the Hugging Face Hub")
    api.upload_folder(
        repo_id=repo_id, folder_path=local_path, commit_message=commit_message
    )

    logger.info(
        f"Your model is pushed to the hub. You can view your model here: {repo_url}"
    )
|
|
|
|
|
def main(argv=None):
    """
    Command-line entry point: parse the arguments and push a trained run to the Hub.

    :param argv: list of arguments to parse. ``None`` (the default, used by the
        console script) means ``sys.argv[1:]``, as with
        ``argparse.ArgumentParser.parse_args``.
    """
    parser = argparse.ArgumentParser(
        description="Package a trained ML-Agents run and push it to the Hugging Face Hub."
    )
    parser.add_argument("--run-id", help="Name of the run-id folder", type=str)
    parser.add_argument(
        "--local-dir",
        help="Path of the run_id folder that contains the trained model",
        type=str,
        default="./",
    )
    # Required: without it package_to_hub cannot address a repository and
    # would crash on None. argparse now reports a proper usage error instead.
    parser.add_argument(
        "--repo-id",
        help="Repo id of the model repository from the Hugging Face Hub",
        type=str,
        required=True,
    )
    parser.add_argument(
        "--commit-message", help="Commit message", type=str, default="Push to Hub"
    )
    parser.add_argument(
        "--configfile-name",
        help="Name of the configuration yaml file",
        type=str,
        default="configuration.yaml",
    )
    args = parser.parse_args(argv)

    package_to_hub(
        args.run_id,
        args.local_dir,
        args.repo_id,
        args.commit_message,
        args.configfile_name,
    )
|
|
|
|
|
|
|
# Script entry point: run the CLI only when this file is executed directly,
# not when it is imported.
if __name__ == "__main__":
    main()
|
|