# Provenance (residue from the hosting page this file was copied from):
# author ryanrahmadifa, commit 79e1719 "Added files", 4.17 kB.
"""
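Time series forecasting workflow: chains the MLflow project entry points
``ingest_request`` and ``ingest_convert``, reusing a previous finished run of a
step when its entry point, parameters, and git commit match.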
"""
import os
import posixpath

import click
import mlflow
from mlflow.entities import RunStatus
from mlflow.tracking import MlflowClient
from mlflow.tracking.fluent import _get_experiment_id
from mlflow.utils import mlflow_tags
from mlflow.utils.logging_utils import eprint
def _params_match(run_params, parameters):
    """Return True if ``run_params`` records every key of ``parameters`` with an equal value.

    MLflow stores run parameters as strings, so each expected value is compared via
    ``str()``; otherwise a non-string value (e.g. the integer ``max-row-limit``) could
    never match. Extra keys present only in ``run_params`` are ignored, and a key that
    is missing from ``run_params`` never matches.

    Args:
        run_params: Parameter mapping recorded on a run (``run.data.params``).
        parameters: Expected parameter values keyed by parameter name.
    """
    return all(
        run_params.get(param_key) == str(param_value)
        for param_key, param_value in parameters.items()
    )


def _already_ran(entry_point_name, parameters, git_commit, experiment_id=None):
    """Best-effort lookup of a previous run that can be reused instead of re-running a step.

    A run matches when all of the following hold:

    - it was launched from ``entry_point_name``;
    - it logged every key/value in ``parameters`` (values compared as strings; extra
      parameters on the run are allowed);
    - its status is FINISHED;
    - it carries the same git-commit tag.

    Args:
        entry_point_name: MLproject entry point the run must have been launched from.
        parameters: Parameters the run must have logged.
        git_commit: Expected value of the run's git-commit tag (may be ``None``).
        experiment_id: Experiment to search; defaults to the current experiment.

    Returns:
        The matching run, re-fetched with ``MlflowClient.get_run``, or ``None`` if
        no run matches.
    """
    if experiment_id is None:
        experiment_id = _get_experiment_id()
    client = MlflowClient()
    # NOTE(review): search_runs returns a single page of results (MLflow's default
    # max_results), so very large experiments are only partially scanned. Reversing
    # presumably visits runs oldest-first (default order is newest-first) — confirm
    # whether the oldest or the newest match is the one that should be reused.
    for run in reversed(client.search_runs([experiment_id])):
        tags = run.data.tags
        if tags.get(mlflow_tags.MLFLOW_PROJECT_ENTRY_POINT) != entry_point_name:
            continue
        if not _params_match(run.data.params, parameters):
            continue
        if run.info.to_proto().status != RunStatus.FINISHED:
            eprint(
                ("Run matched, but is not FINISHED, so skipping (run_id={}, status={})").format(
                    run.info.run_id, run.info.status
                )
            )
            continue
        previous_version = tags.get(mlflow_tags.MLFLOW_GIT_COMMIT)
        if git_commit != previous_version:
            eprint(
                "Run matched, but has a different source version, so skipping "
                f"(found={previous_version}, expected={git_commit})"
            )
            continue
        return client.get_run(run.info.run_id)
    eprint("No matching run has been found.")
    return None
# TODO(aaron): This caching is not great because it doesn't account for:
# - changes in code (only the git commit tag is compared, so uncommitted edits are missed)
# - changes in dependent steps
def _get_or_run(entrypoint, parameters, git_commit, use_cache=True):
    """Return a run of ``entrypoint`` with ``parameters``, launching a new one if needed.

    When ``use_cache`` is true, a previous FINISHED run with a matching entry point,
    parameters and git commit (as determined by ``_already_ran``) is reused. Otherwise,
    or when no such run exists, the entry point is launched from the current directory
    with ``mlflow.run`` using the local environment.

    Args:
        entrypoint: MLproject entry point name.
        parameters: Parameters passed to the entry point and used for cache matching.
        git_commit: Git commit the reused run must have been recorded with.
        use_cache: Whether to look for a reusable run before launching a new one.

    Returns:
        The reused run, or the newly launched run fetched through ``MlflowClient``.
    """
    if use_cache:
        # Only search when caching is enabled: the lookup scans the experiment's runs
        # and emits "skipping"/"not found" messages that are pure noise otherwise.
        existing_run = _already_ran(entrypoint, parameters, git_commit)
        if existing_run is not None:
            print(f"Found existing run for entrypoint={entrypoint} and parameters={parameters}")
            return existing_run
    print(f"Launching new run for entrypoint={entrypoint} and parameters={parameters}")
    submitted_run = mlflow.run(".", entrypoint, parameters=parameters, env_manager="local")
    return MlflowClient().get_run(submitted_run.run_id)
@click.command()
@click.option("--max-row-limit", default=100000, type=int)
def workflow(max_row_limit):
    """Run the forecasting pipeline as MLflow project steps inside one parent run.

    Entry point names are defined in MLproject; the artifact directories are
    documented by each step's .py file. Steps:

    1. ``ingest_request`` -- its ``data-csv-dir`` artifact URI is passed on.
    2. ``ingest_convert`` -- receives that URI as ``data-csv`` together with
       ``max-row-limit``; its ``data-parquet-dir`` artifact URI is printed.

    Each step is reused from a previous FINISHED run when entry point, parameters
    and git commit match.
    """
    with mlflow.start_run() as active_run:
        # We specify a spark-defaults.conf (expected in the project root) to override the
        # default driver memory, which a Spark application cannot set on itself. Steps
        # launched by mlflow.run presumably inherit this environment variable.
        os.environ["SPARK_CONF_DIR"] = os.path.abspath(".")
        git_commit = active_run.data.tags.get(mlflow_tags.MLFLOW_GIT_COMMIT)

        ingest_request_run = _get_or_run("ingest_request", {}, git_commit)
        # Artifact URIs always use "/" separators; posixpath avoids the "\" that
        # os.path.join would insert on Windows.
        data_csv_uri = posixpath.join(ingest_request_run.info.artifact_uri, "data-csv-dir")
        print(data_csv_uri)

        # MLflow records run parameters as strings, so pass the limit as a string so
        # the cache lookup in _get_or_run compares like with like.
        ingest_convert_run = _get_or_run(
            "ingest_convert",
            {"data-csv": data_csv_uri, "max-row-limit": str(max_row_limit)},
            git_commit,
        )
        # TODO: feed this Parquet directory into the downstream training steps.
        data_parquet_uri = posixpath.join(
            ingest_convert_run.info.artifact_uri, "data-parquet-dir"
        )
        print(data_parquet_uri)
if __name__ == "__main__":
    # Hand control to click: it parses the options declared on ``workflow`` and runs it.
    workflow.main()