Spaces:
Sleeping
Sleeping
File size: 4,405 Bytes
79e1719 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 |
"""
Downloads the MovieLens dataset, ETLs it into Parquet, trains an
ALS model, and uses the ALS model to train a Keras neural network.
See README.rst for more details.
"""
import os
import click
import mlflow
from mlflow.entities import RunStatus
from mlflow.tracking import MlflowClient
from mlflow.tracking.fluent import _get_experiment_id
from mlflow.utils import mlflow_tags
from mlflow.utils.logging_utils import eprint
def _param_as_logged(value):
    """Return ``value`` in the form MLflow stores run parameters in.

    MLflow logs parameter values as strings, so an expected value such as the
    integer ``100000`` must be compared as ``"100000"``. ``None`` is kept as
    ``None`` so that it still matches a parameter that was never logged.
    """
    return None if value is None else str(value)


def _already_ran(entry_point_name, parameters, git_commit, experiment_id=None):
    """Best-effort detection of whether a matching run already completed.

    A run matches when it was launched from ``entry_point_name``, has every
    parameter in ``parameters`` logged with the same value, has status
    ``FINISHED``, and carries the same git-commit tag as ``git_commit``.

    Args:
        entry_point_name: MLproject entry point name the run must have been
            launched from.
        parameters: Mapping of parameter name to expected value. Values are
            compared in their string form because MLflow stores params as
            strings. The run must have at least these parameters.
        git_commit: Expected value of the run's git-commit tag (may be None).
        experiment_id: Experiment to search. Defaults to the current experiment.

    Returns:
        The matching run (re-fetched with ``MlflowClient.get_run``), or ``None``
        if no run matches.
    """
    if experiment_id is None:
        experiment_id = _get_experiment_id()
    client = MlflowClient()
    # NOTE(review): search_runs returns a single page of results (newest first by
    # default, per the MLflow docs). Reversing it checks the oldest of those first.
    # Runs beyond that first page are never considered.
    for run in reversed(client.search_runs([experiment_id])):
        tags = run.data.tags
        if tags.get(mlflow_tags.MLFLOW_PROJECT_ENTRY_POINT) != entry_point_name:
            continue
        logged_params = run.data.params
        # Compare against the string form: an int such as max_row_limit would
        # otherwise never equal the logged "100000" and the cache would always miss.
        if any(
            logged_params.get(key) != _param_as_logged(value)
            for key, value in parameters.items()
        ):
            continue
        if run.info.to_proto().status != RunStatus.FINISHED:
            eprint(
                "Run matched, but is not FINISHED, so skipping "
                f"(run_id={run.info.run_id}, status={run.info.status})"
            )
            continue
        previous_version = tags.get(mlflow_tags.MLFLOW_GIT_COMMIT)
        if git_commit != previous_version:
            eprint(
                "Run matched, but has a different source version, so skipping "
                f"(found={previous_version}, expected={git_commit})"
            )
            continue
        return client.get_run(run.info.run_id)
    eprint("No matching run has been found.")
    return None
# TODO(aaron): This caching is not great because it doesn't account for:
# - uncommitted changes in code (only the recorded git commit is compared)
# - changes in dependent steps
def _get_or_run(entrypoint, parameters, git_commit, use_cache=True):
    """Return a run of ``entrypoint`` with ``parameters``, launching one if needed.

    Args:
        entrypoint: MLproject entry point name to run.
        parameters: Parameters passed to the entry point. When caching, they are
            also used to find a previous matching run.
        git_commit: Git commit a cached run must have been recorded with.
        use_cache: If True, reuse a previously finished matching run when one
            exists. If False, skip the lookup entirely and always launch a new run.

    Returns:
        The run, fetched with ``MlflowClient().get_run``.
    """
    # Only search for a cached run when caching is enabled. Searching with
    # use_cache=False would scan the experiment and print misleading "skipping" /
    # "No matching run" messages for a result that is then ignored.
    if use_cache:
        existing_run = _already_ran(entrypoint, parameters, git_commit)
        if existing_run is not None:
            print(f"Found existing run for entrypoint={entrypoint} and parameters={parameters}")
            return existing_run
    print(f"Launching new run for entrypoint={entrypoint} and parameters={parameters}")
    submitted_run = mlflow.run(".", entrypoint, parameters=parameters, env_manager="local")
    return MlflowClient().get_run(submitted_run.run_id)
@click.command()
@click.option(
    "--als-max-iter", default=10, type=int, help="Value for the als step's max_iter parameter."
)
@click.option(
    "--keras-hidden-units",
    default=20,
    type=int,
    help="Value for the train_keras step's hidden_units parameter.",
)
@click.option(
    "--max-row-limit",
    default=100000,
    type=int,
    help="Value for the etl_data step's max_row_limit parameter.",
)
def workflow(als_max_iter, keras_hidden_units, max_row_limit):
    """Run the load_raw_data -> etl_data -> als -> train_keras pipeline."""
    # Note: The entrypoint names are defined in MLproject. The artifact directories
    # are documented by each step's .py file.
    with mlflow.start_run() as active_run:
        # We specify a spark-defaults.conf (found via SPARK_CONF_DIR) to override the
        # default driver memory. ALS requires significant memory, and the driver
        # memory property cannot be set by the application itself.
        os.environ["SPARK_CONF_DIR"] = os.path.abspath(".")
        git_commit = active_run.data.tags.get(mlflow_tags.MLFLOW_GIT_COMMIT)

        load_raw_data_run = _get_or_run("load_raw_data", {}, git_commit)
        ratings_csv_uri = os.path.join(load_raw_data_run.info.artifact_uri, "ratings-csv-dir")

        # Parameter values are passed as strings, the form MLflow stores logged
        # params in, so they compare equal to a previous run's params on cache lookup.
        etl_data_run = _get_or_run(
            "etl_data",
            {"ratings_csv": ratings_csv_uri, "max_row_limit": str(max_row_limit)},
            git_commit,
        )
        ratings_parquet_uri = os.path.join(etl_data_run.info.artifact_uri, "ratings-parquet-dir")

        als_run = _get_or_run(
            "als", {"ratings_data": ratings_parquet_uri, "max_iter": str(als_max_iter)}, git_commit
        )
        als_model_uri = os.path.join(als_run.info.artifact_uri, "als-model")

        keras_params = {
            "ratings_data": ratings_parquet_uri,
            "als_model_uri": als_model_uri,
            "hidden_units": str(keras_hidden_units),
        }
        # The final training step is always re-run rather than reused from cache.
        _get_or_run("train_keras", keras_params, git_commit, use_cache=False)
if __name__ == "__main__":
    # Invoke the click command (parses CLI options) when run as a script.
    workflow()