|
import argparse |
|
import json |
|
import os |
|
import shutil |
|
import sys |
|
import subprocess |
|
import time |
|
|
|
from .yamato_utils import ( |
|
find_executables, |
|
get_base_path, |
|
get_base_output_path, |
|
run_standalone_build, |
|
init_venv, |
|
override_config_file, |
|
checkout_csharp_version, |
|
undo_git_checkout, |
|
) |
|
|
|
|
|
def run_training(python_version: str, csharp_version: str) -> bool: |
|
latest = "latest" |
|
run_id = int(time.time() * 1000.0) |
|
print( |
|
f"Running training with python={python_version or latest} and c#={csharp_version or latest}" |
|
) |
|
output_dir = "results" |
|
onnx_file_expected = f"./{output_dir}/{run_id}/3DBall.onnx" |
|
|
|
if os.path.exists(onnx_file_expected): |
|
|
|
print("Artifacts from previous build found!") |
|
return False |
|
|
|
base_path = get_base_path() |
|
print(f"Running in base path {base_path}") |
|
|
|
|
|
|
|
if csharp_version is not None: |
|
|
|
|
|
artifact_path = get_base_output_path() |
|
full_player_path = os.path.join(artifact_path, "testPlayer.app") |
|
temp_player_path = os.path.join(artifact_path, "temp_testPlayer.app") |
|
final_player_path = os.path.join( |
|
artifact_path, f"testPlayer_{csharp_version}.app" |
|
) |
|
|
|
os.rename(full_player_path, temp_player_path) |
|
|
|
checkout_csharp_version(csharp_version) |
|
build_returncode = run_standalone_build(base_path) |
|
|
|
if build_returncode != 0: |
|
print(f"Standalone build FAILED! with return code {build_returncode}") |
|
return False |
|
|
|
|
|
os.rename(full_player_path, final_player_path) |
|
os.rename(temp_player_path, full_player_path) |
|
standalone_player_path = f"testPlayer_{csharp_version}" |
|
else: |
|
standalone_player_path = "testPlayer" |
|
|
|
init_venv(python_version) |
|
|
|
|
|
|
|
yaml_out = "override.yaml" |
|
overrides = { |
|
"hyperparameters": {"batch_size": 10, "buffer_size": 10}, |
|
"max_steps": 100, |
|
} |
|
override_config_file("config/ppo/3DBall.yaml", yaml_out, overrides) |
|
|
|
log_output_path = f"{get_base_output_path()}/training.log" |
|
env_path = os.path.join(get_base_output_path(), standalone_player_path) |
|
mla_learn_cmd = [ |
|
"mlagents-learn", |
|
yaml_out, |
|
"--force", |
|
"--env", |
|
env_path, |
|
"--run-id", |
|
str(run_id), |
|
"--no-graphics", |
|
"--env-args", |
|
"-logFile", |
|
log_output_path, |
|
] |
|
|
|
res = subprocess.run(mla_learn_cmd) |
|
|
|
|
|
if csharp_version is None and python_version is None: |
|
model_artifacts_dir = os.path.join(get_base_output_path(), "models") |
|
os.makedirs(model_artifacts_dir, exist_ok=True) |
|
if os.path.exists(onnx_file_expected): |
|
shutil.copy(onnx_file_expected, model_artifacts_dir) |
|
|
|
if res.returncode != 0 or not os.path.exists(onnx_file_expected): |
|
print("mlagents-learn run FAILED!") |
|
print("Command line: " + " ".join(mla_learn_cmd)) |
|
subprocess.run(["cat", log_output_path]) |
|
return False |
|
|
|
if csharp_version is None and python_version is None: |
|
|
|
model_path = os.path.abspath(os.path.dirname(onnx_file_expected)) |
|
inference_ok = run_inference(env_path, model_path, "onnx") |
|
if not inference_ok: |
|
return False |
|
|
|
print("mlagents-learn run SUCCEEDED!") |
|
return True |
|
|
|
|
|
def run_inference(env_path: str, output_path: str, model_extension: str) -> bool: |
|
start_time = time.time() |
|
exes = find_executables(env_path) |
|
if len(exes) != 1: |
|
print(f"Can't determine the player executable in {env_path}. Found {exes}.") |
|
return False |
|
|
|
log_output_path = f"{get_base_output_path()}/inference.{model_extension}.txt" |
|
|
|
|
|
process_timeout = 10 * 60 |
|
|
|
model_override_timeout = process_timeout - 15 |
|
|
|
exe_path = exes[0] |
|
args = [ |
|
exe_path, |
|
"-nographics", |
|
"-batchmode", |
|
"-logfile", |
|
log_output_path, |
|
"--mlagents-override-model-directory", |
|
output_path, |
|
"--mlagents-quit-on-load-failure", |
|
"--mlagents-quit-after-episodes", |
|
"1", |
|
"--mlagents-override-model-extension", |
|
model_extension, |
|
"--mlagents-quit-after-seconds", |
|
str(model_override_timeout), |
|
] |
|
print(f"Starting inference with args {' '.join(args)}") |
|
res = subprocess.run(args, timeout=process_timeout) |
|
end_time = time.time() |
|
if res.returncode != 0: |
|
print("Error running inference!") |
|
print("Command line: " + " ".join(args)) |
|
subprocess.run(["cat", log_output_path]) |
|
return False |
|
else: |
|
print(f"Inference finished! Took {end_time - start_time} seconds") |
|
|
|
|
|
timer_file = f"{exe_path}_Data/ML-Agents/Timers/3DBall_timers.json" |
|
with open(timer_file) as f: |
|
timer_data = json.load(f) |
|
|
|
gauges = timer_data.get("gauges", {}) |
|
rewards = gauges.get("Override_3DBall.CumulativeReward", {}) |
|
max_reward = rewards.get("max") |
|
if max_reward is None: |
|
print( |
|
"Unable to find rewards in timer file. This usually indicates a problem with Barracuda or inference." |
|
) |
|
return False |
|
|
|
|
|
|
|
return True |
|
|
|
|
|
def main(): |
|
parser = argparse.ArgumentParser() |
|
parser.add_argument("--python", default=None) |
|
parser.add_argument("--csharp", default=None) |
|
args = parser.parse_args() |
|
|
|
try: |
|
ok = run_training(args.python, args.csharp) |
|
if not ok: |
|
sys.exit(1) |
|
|
|
finally: |
|
|
|
undo_git_checkout() |
|
|
|
|
|
if __name__ == "__main__": |
|
main() |
|
|