Spaces:
Sleeping
Sleeping
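"""
Downloads the MovieLens dataset, ETLs it into Parquet, trains an
ALS model, and uses the ALS model to train a Keras neural network.
Each step runs as its own MLflow project entry point; earlier steps
are reused from a previous matching run when one exists.
See README.rst for more details.
"""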
import os
import posixpath

import click
import mlflow
from mlflow.entities import RunStatus
from mlflow.tracking import MlflowClient
from mlflow.tracking.fluent import _get_experiment_id
from mlflow.utils import mlflow_tags
from mlflow.utils.logging_utils import eprint
def _already_ran(entry_point_name, parameters, git_commit, experiment_id=None):
    """Return a previous successful run matching this step, or ``None``.

    Best-effort detection of whether a run of ``entry_point_name`` already
    completed in the given experiment. A run matches when:

    * its ``MLFLOW_PROJECT_ENTRY_POINT`` tag equals ``entry_point_name``;
    * every key in ``parameters`` was logged with the same value (the run may
      have additional parameters);
    * its status is ``FINISHED``;
    * its ``MLFLOW_GIT_COMMIT`` tag equals ``git_commit``.

    Args:
        entry_point_name: MLproject entry point name to look for.
        parameters: Mapping of parameter name to expected value. Values are
            compared in string form, because logged run params are strings.
        git_commit: Expected source commit; compared as-is (may be ``None``).
        experiment_id: Experiment to search. Defaults to the current
            experiment as returned by ``_get_experiment_id()``.

    Returns:
        The first matching run, re-fetched with ``MlflowClient.get_run``, or
        ``None`` if no run matches.
    """
    if experiment_id is None:
        experiment_id = _get_experiment_id()
    client = MlflowClient()

    # Logged param values are strings, so compare against the string form of
    # each expected value; otherwise non-string params (e.g. an int
    # max_row_limit) can never match and the cache is silently bypassed.
    expected_params = {
        key: None if value is None else str(value) for key, value in parameters.items()
    }

    # NOTE(review): search_runs returns a single page of results, so very old
    # runs in a large experiment may not be considered.
    for run in reversed(client.search_runs([experiment_id])):
        tags = run.data.tags
        if tags.get(mlflow_tags.MLFLOW_PROJECT_ENTRY_POINT) != entry_point_name:
            continue

        run_params = run.data.params
        if any(run_params.get(key) != value for key, value in expected_params.items()):
            continue

        if run.info.to_proto().status != RunStatus.FINISHED:
            eprint(
                "Run matched, but is not FINISHED, so skipping "
                f"(run_id={run.info.run_id}, status={run.info.status})"
            )
            continue

        previous_version = tags.get(mlflow_tags.MLFLOW_GIT_COMMIT)
        if git_commit != previous_version:
            eprint(
                "Run matched, but has a different source version, so skipping "
                f"(found={previous_version}, expected={git_commit})"
            )
            continue

        return client.get_run(run.info.run_id)

    eprint("No matching run has been found.")
    return None
# TODO(aaron): This cache check is not great because it doesn't account for:
#  - changes in code that are not captured by the git commit (e.g. uncommitted edits)
#  - changes in dependent steps
def _get_or_run(entrypoint, parameters, git_commit, use_cache=True):
    """Return an MLflow run for ``entrypoint``, reusing a previous one if allowed.

    When ``use_cache`` is true, a previous FINISHED run with the same entry
    point, parameters and git commit (as found by ``_already_ran``) is
    returned. Otherwise, or when no such run exists, the entry point is
    launched from the current directory with ``mlflow.run`` using the local
    environment, and the resulting run is fetched and returned.

    Args:
        entrypoint: MLproject entry point name.
        parameters: Parameters passed to the entry point (and used for the
            cache lookup).
        git_commit: Source commit the cached run must have been produced from.
        use_cache: Whether a previous matching run may be reused.

    Returns:
        The reused or newly launched run, as returned by ``MlflowClient.get_run``.
    """
    # Only query the tracking server when a cached run could actually be used;
    # this avoids a needless search and a misleading "no match" message.
    if use_cache:
        existing_run = _already_ran(entrypoint, parameters, git_commit)
        if existing_run is not None:
            print(f"Found existing run for entrypoint={entrypoint} and parameters={parameters}")
            return existing_run

    print(f"Launching new run for entrypoint={entrypoint} and parameters={parameters}")
    submitted_run = mlflow.run(".", entrypoint, parameters=parameters, env_manager="local")
    return MlflowClient().get_run(submitted_run.run_id)
def workflow(als_max_iter, keras_hidden_units, max_row_limit):
    """Run the full multistep pipeline inside a parent MLflow run.

    Steps run in order: ``load_raw_data`` -> ``etl_data`` -> ``als`` ->
    ``train_keras``. Each step's output artifact directory is passed to the
    next step as a URI parameter. Every step except ``train_keras`` may be
    served from a previous matching run instead of being re-executed.

    Args:
        als_max_iter: Passed to the ``als`` step as ``max_iter``.
        keras_hidden_units: Passed to the ``train_keras`` step as ``hidden_units``.
        max_row_limit: Passed to the ``etl_data`` step as ``max_row_limit``.
    """
    # Note: The entrypoint names are defined in MLproject. The artifact directories
    # are documented by each step's .py file.
    with mlflow.start_run() as active_run:
        # We specify a spark-defaults.conf (located via SPARK_CONF_DIR, set to the
        # current directory) to override the default driver memory. ALS requires
        # significant memory. The driver memory property cannot be set by the
        # application itself.
        os.environ["SPARK_CONF_DIR"] = os.path.abspath(".")
        git_commit = active_run.data.tags.get(mlflow_tags.MLFLOW_GIT_COMMIT)

        # Parameter values are passed as strings so they compare equal to the
        # (string) values logged on earlier runs when looking for a cached run.
        # Artifact URIs always use "/" separators (they may be remote URIs), so
        # they are joined with posixpath rather than the OS-specific os.path.
        load_raw_data_run = _get_or_run("load_raw_data", {}, git_commit)
        ratings_csv_uri = posixpath.join(load_raw_data_run.info.artifact_uri, "ratings-csv-dir")

        etl_data_run = _get_or_run(
            "etl_data",
            {"ratings_csv": ratings_csv_uri, "max_row_limit": str(max_row_limit)},
            git_commit,
        )
        ratings_parquet_uri = posixpath.join(
            etl_data_run.info.artifact_uri, "ratings-parquet-dir"
        )

        als_run = _get_or_run(
            "als", {"ratings_data": ratings_parquet_uri, "max_iter": str(als_max_iter)}, git_commit
        )
        als_model_uri = posixpath.join(als_run.info.artifact_uri, "als-model")

        keras_params = {
            "ratings_data": ratings_parquet_uri,
            "als_model_uri": als_model_uri,
            "hidden_units": str(keras_hidden_units),
        }
        # The final training step is always re-run rather than served from cache.
        _get_or_run("train_keras", keras_params, git_commit, use_cache=False)
if __name__ == "__main__":
    # `workflow` takes three required arguments, so calling it bare raises
    # TypeError. Expose them as command-line options via click instead.
    # NOTE(review): option names/defaults are assumed to match the parameters
    # of the `main` entry point declared in MLproject; confirm.

    @click.command()
    @click.option("--als-max-iter", default=10, type=int)
    @click.option("--keras-hidden-units", default=20, type=int)
    @click.option("--max-row-limit", default=100000, type=int)
    def _cli(als_max_iter, keras_hidden_units, max_row_limit):
        """Run the MovieLens ALS + Keras multistep workflow."""
        workflow(als_max_iter, keras_hidden_units, max_row_limit)

    _cli()