File size: 6,670 Bytes
05c9ac2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 |
import argparse
import json
import os
import shutil
import sys
import subprocess
import time
from .yamato_utils import (
find_executables,
get_base_path,
get_base_output_path,
run_standalone_build,
init_venv,
override_config_file,
checkout_csharp_version,
undo_git_checkout,
)
def run_training(python_version: str, csharp_version: str) -> bool:
    """Run a short 3DBall training job and check that it produces an ONNX model.

    Args:
        python_version: Value forwarded to ``init_venv``. ``None`` is treated
            as "latest".
        csharp_version: C# version to check out and build a standalone player
            for. ``None`` is treated as "latest": nothing is built and the
            existing ``testPlayer`` build is used.

    Returns:
        True if ``mlagents-learn`` exited with code 0, the expected ``.onnx``
        file exists and, when both versions are ``None``, inference with the
        trained model also succeeded. False otherwise.
    """
    latest = "latest"
    run_id = int(time.time() * 1000.0)
    print(
        f"Running training with python={python_version or latest} and c#={csharp_version or latest}"
    )
    output_dir = "results"
    onnx_file_expected = f"./{output_dir}/{run_id}/3DBall.onnx"

    if os.path.exists(onnx_file_expected):
        # Should never happen - make sure nothing leftover from an old test.
        print("Artifacts from previous build found!")
        return False

    base_path = get_base_path()
    print(f"Running in base path {base_path}")

    # Only build the standalone player if we're overriding the C# version
    # Otherwise we'll use the one built earlier in the pipeline.
    if csharp_version is not None:
        # We can't rely on the old C# code recognizing the commandline argument to set the output
        # So rename testPlayer (containing the most recent build) to something else temporarily
        artifact_path = get_base_output_path()
        full_player_path = os.path.join(artifact_path, "testPlayer.app")
        temp_player_path = os.path.join(artifact_path, "temp_testPlayer.app")
        final_player_path = os.path.join(
            artifact_path, f"testPlayer_{csharp_version}.app"
        )

        os.rename(full_player_path, temp_player_path)

        checkout_csharp_version(csharp_version)
        build_returncode = run_standalone_build(base_path)

        if build_returncode != 0:
            print(f"Standalone build FAILED! with return code {build_returncode}")
            # Put the most recent build back under its original name so a failed
            # build does not leave the artifact directory in a renamed state.
            # Skip this if the failed build left something at that path, since
            # os.rename could then fail or clobber it.
            if not os.path.exists(full_player_path):
                os.rename(temp_player_path, full_player_path)
            return False

        # Now rename the newly-built executable, and restore the old one
        os.rename(full_player_path, final_player_path)
        os.rename(temp_player_path, full_player_path)
        standalone_player_path = f"testPlayer_{csharp_version}"
    else:
        standalone_player_path = "testPlayer"

    init_venv(python_version)

    # Copy the default training config but override the max_steps parameter,
    # and reduce the batch_size and buffer_size enough to ensure an update step happens.
    yaml_out = "override.yaml"
    overrides = {
        "hyperparameters": {"batch_size": 10, "buffer_size": 10},
        "max_steps": 100,
    }
    override_config_file("config/ppo/3DBall.yaml", yaml_out, overrides)

    log_output_path = f"{get_base_output_path()}/training.log"
    env_path = os.path.join(get_base_output_path(), standalone_player_path)
    mla_learn_cmd = [
        "mlagents-learn",
        yaml_out,
        "--force",
        "--env",
        env_path,
        "--run-id",
        str(run_id),
        "--no-graphics",
        "--env-args",
        "-logFile",
        log_output_path,
    ]
    res = subprocess.run(mla_learn_cmd)

    # Model archiving and the inference check only apply to the "latest/latest" run.
    using_latest = csharp_version is None and python_version is None

    # Save models as artifacts (only if we're using latest python and C#)
    if using_latest:
        model_artifacts_dir = os.path.join(get_base_output_path(), "models")
        os.makedirs(model_artifacts_dir, exist_ok=True)
        if os.path.exists(onnx_file_expected):
            shutil.copy(onnx_file_expected, model_artifacts_dir)

    if res.returncode != 0 or not os.path.exists(onnx_file_expected):
        print("mlagents-learn run FAILED!")
        print("Command line: " + " ".join(mla_learn_cmd))
        # Dump the player log to make the failure easier to diagnose.
        subprocess.run(["cat", log_output_path])
        return False

    if using_latest:
        # Use abs path so that loading doesn't get confused
        model_path = os.path.abspath(os.path.dirname(onnx_file_expected))
        inference_ok = run_inference(env_path, model_path, "onnx")
        if not inference_ok:
            return False

    print("mlagents-learn run SUCCEEDED!")
    return True
def run_inference(env_path: str, output_path: str, model_extension: str) -> bool:
    """Run the standalone player with a trained model and check it reports rewards.

    Args:
        env_path: Path handed to ``find_executables`` to locate the player
            executable. Exactly one executable must be found.
        output_path: Directory passed to the player as
            ``--mlagents-override-model-directory``.
        model_extension: Model file extension passed as
            ``--mlagents-override-model-extension``. It is also used in the name
            of the inference log file.

    Returns:
        True if the player exited with code 0 within the timeout and the timer
        file contains a ``max`` value for the
        ``Override_3DBall.CumulativeReward`` gauge. False otherwise.
    """
    start_time = time.time()
    exes = find_executables(env_path)
    if len(exes) != 1:
        print(f"Can't determine the player executable in {env_path}. Found {exes}.")
        return False

    log_output_path = f"{get_base_output_path()}/inference.{model_extension}.txt"

    # 10 minutes for inference is more than enough
    process_timeout = 10 * 60
    # Try to gracefully exit a few seconds before that.
    model_override_timeout = process_timeout - 15

    exe_path = exes[0]
    args = [
        exe_path,
        "-nographics",
        "-batchmode",
        "-logfile",
        log_output_path,
        "--mlagents-override-model-directory",
        output_path,
        "--mlagents-quit-on-load-failure",
        "--mlagents-quit-after-episodes",
        "1",
        "--mlagents-override-model-extension",
        model_extension,
        "--mlagents-quit-after-seconds",
        str(model_override_timeout),
    ]
    print(f"Starting inference with args {' '.join(args)}")
    try:
        res = subprocess.run(args, timeout=process_timeout)
    except subprocess.TimeoutExpired:
        # subprocess.run kills the child before re-raising, so a hung player is
        # reported as an ordinary failure instead of crashing the whole script.
        print(f"Inference timed out after {process_timeout} seconds!")
        print("Command line: " + " ".join(args))
        subprocess.run(["cat", log_output_path])
        return False
    end_time = time.time()

    if res.returncode != 0:
        print("Error running inference!")
        print("Command line: " + " ".join(args))
        subprocess.run(["cat", log_output_path])
        return False
    print(f"Inference finished! Took {end_time - start_time} seconds")

    # Check the artifacts directory for the timers, so we can get the gauges
    timer_file = f"{exe_path}_Data/ML-Agents/Timers/3DBall_timers.json"
    try:
        with open(timer_file) as f:
            timer_data = json.load(f)
    except (OSError, ValueError) as err:
        # OSError: file missing or unreadable. ValueError covers
        # json.JSONDecodeError (and decoding errors) for a malformed file.
        print(f"Unable to read timer file {timer_file}: {err}")
        return False

    gauges = timer_data.get("gauges", {})
    rewards = gauges.get("Override_3DBall.CumulativeReward", {})
    max_reward = rewards.get("max")
    if max_reward is None:
        print(
            "Unable to find rewards in timer file. This usually indicates a problem with Barracuda or inference."
        )
        return False
    # We could check that the rewards are over a threshold, but since we train for so short a time,
    # the values could be highly variable. So don't do it for now.

    return True
def main():
    """Parse ``--python`` / ``--csharp`` and run the training check.

    Exits with status 1 when ``run_training`` reports failure. The git
    checkout is always undone afterwards.
    """
    arg_parser = argparse.ArgumentParser()
    # Both version flags are optional and default to None.
    for flag in ("--python", "--csharp"):
        arg_parser.add_argument(flag, default=None)
    options = arg_parser.parse_args()

    try:
        if not run_training(options.python, options.csharp):
            sys.exit(1)
    finally:
        # Cleanup - this gets executed even if we hit sys.exit()
        undo_git_checkout()


# Only run when executed as a script, not when imported.
if __name__ == "__main__":
    main()
|