import argparse |
import json |
import os |
import shutil |
import sys |
import subprocess |
import time |
from .yamato_utils import ( |
find_executables, |
get_base_path, |
get_base_output_path, |
run_standalone_build, |
init_venv, |
override_config_file, |
checkout_csharp_version, |
undo_git_checkout, |
) |
def run_training(python_version: str, csharp_version: str) -> bool:
    """Run a short 3DBall training job against a standalone player and verify its output.

    The stock ``config/ppo/3DBall.yaml`` is copied to ``override.yaml`` with a
    small batch/buffer size and ``max_steps`` of 100, ``mlagents-learn`` is
    launched against the player, and the expected ONNX file is checked for.
    When both versions are ``None`` the model is additionally copied into the
    ``models`` artifact directory and exercised through ``run_inference``.

    Args:
        python_version: Passed to ``init_venv``; ``None`` is reported as "latest".
        csharp_version: Passed to ``checkout_csharp_version``. When ``None`` no
            checkout or build happens and the existing ``testPlayer`` is used.

    Returns:
        ``True`` if training (and inference, when it runs) succeeded,
        ``False`` otherwise.
    """
    latest = "latest"
    # Millisecond timestamp, used as the mlagents-learn run id and results subdirectory.
    run_id = int(time.time() * 1000.0)
    print(
        f"Running training with python={python_version or latest} and c#={csharp_version or latest}"
    )
    output_dir = "results"
    onnx_file_expected = f"./{output_dir}/{run_id}/3DBall.onnx"

    if os.path.exists(onnx_file_expected):
        # The file must be produced by this run; anything already there would
        # make the success check below meaningless.
        print("Artifacts from previous build found!")
        return False

    base_path = get_base_path()
    print(f"Running in base path {base_path}")

    # A standalone player is only built when a specific C# version is requested.
    if csharp_version is not None:
        artifact_path = get_base_output_path()
        full_player_path = os.path.join(artifact_path, "testPlayer.app")
        temp_player_path = os.path.join(artifact_path, "temp_testPlayer.app")
        final_player_path = os.path.join(
            artifact_path, f"testPlayer_{csharp_version}.app"
        )

        # Move the existing player aside: the build output is picked up from
        # testPlayer.app and then renamed to its versioned name.
        os.rename(full_player_path, temp_player_path)

        build_returncode = None
        try:
            checkout_csharp_version(csharp_version)
            build_returncode = run_standalone_build(base_path)
            if build_returncode == 0:
                os.rename(full_player_path, final_player_path)
        finally:
            # Put the original player back whether the build succeeded, failed
            # or raised. Never rename onto an existing path (e.g. a partial
            # build left behind), since that could raise and mask the real error.
            if os.path.exists(temp_player_path) and not os.path.exists(
                full_player_path
            ):
                os.rename(temp_player_path, full_player_path)

        if build_returncode != 0:
            print(f"Standalone build FAILED! with return code {build_returncode}")
            if os.path.exists(temp_player_path):
                print(
                    f"Original player could not be restored and remains at {temp_player_path}"
                )
            return False

        standalone_player_path = f"testPlayer_{csharp_version}"
    else:
        standalone_player_path = "testPlayer"

    init_venv(python_version)

    # Override hyperparameters and max_steps of the stock 3DBall config
    # (values are small, presumably to keep the test run short).
    yaml_out = "override.yaml"
    overrides = {
        "hyperparameters": {"batch_size": 10, "buffer_size": 10},
        "max_steps": 100,
    }
    override_config_file("config/ppo/3DBall.yaml", yaml_out, overrides)

    log_output_path = f"{get_base_output_path()}/training.log"
    env_path = os.path.join(get_base_output_path(), standalone_player_path)
    mla_learn_cmd = [
        "mlagents-learn",
        yaml_out,
        "--force",
        "--env",
        env_path,
        "--run-id",
        str(run_id),
        "--no-graphics",
        "--env-args",
        "-logFile",
        log_output_path,
    ]

    res = subprocess.run(mla_learn_cmd)

    # Models are saved as artifacts only for the default (latest/latest) run.
    # This happens before the failure check, so a model is kept even if
    # mlagents-learn exited non-zero after writing it.
    if csharp_version is None and python_version is None:
        model_artifacts_dir = os.path.join(get_base_output_path(), "models")
        os.makedirs(model_artifacts_dir, exist_ok=True)
        if os.path.exists(onnx_file_expected):
            shutil.copy(onnx_file_expected, model_artifacts_dir)

    if res.returncode != 0 or not os.path.exists(onnx_file_expected):
        print("mlagents-learn run FAILED!")
        print("Command line: " + " ".join(mla_learn_cmd))
        # Dump the player log to the console to aid debugging.
        subprocess.run(["cat", log_output_path])
        return False

    if csharp_version is None and python_version is None:
        # Hand the player an absolute model directory.
        model_path = os.path.abspath(os.path.dirname(onnx_file_expected))
        inference_ok = run_inference(env_path, model_path, "onnx")
        if not inference_ok:
            return False

    print("mlagents-learn run SUCCEEDED!")
    return True
def run_inference(env_path: str, output_path: str, model_extension: str) -> bool:
    """Run the standalone player with a model override and check that rewards were recorded.

    Args:
        env_path: Location handed to ``find_executables`` to locate the player
            executable; exactly one executable must be found.
        output_path: Directory passed via ``--mlagents-override-model-directory``.
        model_extension: Passed via ``--mlagents-override-model-extension`` and
            used in the inference log file name.

    Returns:
        ``True`` if the player exited with status 0 and its timer file contains
        a ``max`` value for the ``Override_3DBall.CumulativeReward`` gauge.
        ``False`` on any failure, including a player timeout or a missing or
        unparsable timer file.
    """
    start_time = time.time()
    exes = find_executables(env_path)
    if len(exes) != 1:
        print(f"Can't determine the player executable in {env_path}. Found {exes}.")
        return False

    log_output_path = f"{get_base_output_path()}/inference.{model_extension}.txt"

    # Hard limit for the player process (10 minutes).
    process_timeout = 10 * 60
    # Ask the player to quit by itself slightly before the hard limit.
    model_override_timeout = process_timeout - 15

    exe_path = exes[0]
    args = [
        exe_path,
        "-nographics",
        "-batchmode",
        "-logfile",
        log_output_path,
        "--mlagents-override-model-directory",
        output_path,
        "--mlagents-quit-on-load-failure",
        "--mlagents-quit-after-episodes",
        "1",
        "--mlagents-override-model-extension",
        model_extension,
        "--mlagents-quit-after-seconds",
        str(model_override_timeout),
    ]
    print(f"Starting inference with args {' '.join(args)}")
    try:
        res = subprocess.run(args, timeout=process_timeout)
    except subprocess.TimeoutExpired:
        # subprocess.run has already killed the child; report it like any other
        # inference failure instead of letting the exception escape.
        print(f"Inference timed out after {process_timeout} seconds!")
        print("Command line: " + " ".join(args))
        subprocess.run(["cat", log_output_path])
        return False
    end_time = time.time()

    if res.returncode != 0:
        print("Error running inference!")
        print("Command line: " + " ".join(args))
        subprocess.run(["cat", log_output_path])
        return False
    print(f"Inference finished! Took {end_time - start_time} seconds")

    # Timers are read from the player's own data directory.
    timer_file = f"{exe_path}_Data/ML-Agents/Timers/3DBall_timers.json"
    try:
        with open(timer_file) as f:
            timer_data = json.load(f)
    except (OSError, ValueError) as err:
        # OSError: file missing/unreadable; ValueError covers json.JSONDecodeError.
        print(f"Unable to read timer file {timer_file}: {err}")
        return False

    gauges = timer_data.get("gauges", {})
    rewards = gauges.get("Override_3DBall.CumulativeReward", {})
    max_reward = rewards.get("max")
    if max_reward is None:
        print(
            "Unable to find rewards in timer file. This usually indicates a problem with Barracuda or inference."
        )
        return False
    # Only the presence of a reward is checked, not its magnitude.
    return True
def main():
    """Parse ``--python`` / ``--csharp`` and run the training test.

    Exits with status 1 when ``run_training`` reports failure.
    ``undo_git_checkout`` is called whether training succeeds, fails or raises.
    """
    cli = argparse.ArgumentParser()
    # Both options are optional and default to None.
    for flag in ("--python", "--csharp"):
        cli.add_argument(flag, default=None)
    options = cli.parse_args()

    try:
        succeeded = run_training(options.python, options.csharp)
        if not succeeded:
            sys.exit(1)
    finally:
        undo_git_checkout()
# Run the test only when executed as a script, not when imported.
if __name__ == "__main__":
    main()