Skip to content

Commit

Permalink
mlflow: Handle race in create experiment
Browse files Browse the repository at this point in the history
Fixes: Astera-org/obelisk#805
Change-Id: I196f285a77a931b967ea9cfa49294d8e45e82c2a
  • Loading branch information
garymm committed Dec 23, 2024
1 parent c514376 commit e1c9b4a
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 4 deletions.
24 changes: 21 additions & 3 deletions src/jax_loop_utils/metric_writers/mlflow/metric_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
import mlflow
import mlflow.config
import mlflow.entities
import mlflow.exceptions
import mlflow.protos.databricks_pb2
import mlflow.tracking.fluent
import numpy as np
from absl import logging
Expand All @@ -28,6 +30,7 @@ def __init__(
experiment_name: str,
run_name: str | None = None,
tracking_uri: str | None = None,
_client_class: type[mlflow.MlflowClient] = mlflow.MlflowClient,
):
"""Initialize MLflow writer.
Expand All @@ -37,17 +40,32 @@ def __init__(
tracking_uri: Address of local or remote tracking server.
Treated the same as arguments to mlflow.set_tracking_uri.
See https://www.mlflow.org/docs/latest/python_api/mlflow.html#mlflow.set_tracking_uri
_client_class: MLflow client class (for testing only).
"""
self._client = mlflow.MlflowClient(tracking_uri=tracking_uri)
self._client = _client_class(tracking_uri=tracking_uri)
experiment = self._client.get_experiment_by_name(experiment_name)
if experiment:
experiment_id = experiment.experiment_id
else:
logging.info(
"Experiment with name '%s' does not exist. Creating a new experiment.",
"Experiment '%s' does not exist. Creating a new experiment.",
experiment_name,
)
experiment_id = self._client.create_experiment(experiment_name)
try:
experiment_id = self._client.create_experiment(experiment_name)
except mlflow.exceptions.MlflowException as e:
# Handle race in creating experiment.
if e.error_code != mlflow.protos.databricks_pb2.ErrorCode.Name(
mlflow.protos.databricks_pb2.RESOURCE_ALREADY_EXISTS
):
raise
experiment = self._client.get_experiment_by_name(experiment_name)
if not experiment:
raise RuntimeError(
"Failed to get, then failed to create, then failed to get "
f"again experiment '{experiment_name}'"
)
experiment_id = experiment.experiment_id
self._run_id = self._client.create_run(
experiment_id=experiment_id, run_name=run_name
).info.run_id
Expand Down
47 changes: 46 additions & 1 deletion src/jax_loop_utils/metric_writers/mlflow/metric_writer_test.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import tempfile
import time

import jax.numpy as jnp
import mlflow
import mlflow.entities
import mlflow.exceptions
import mlflow.protos.databricks_pb2
import numpy as np
import jax.numpy as jnp
from absl.testing import absltest

from jax_loop_utils.metric_writers.mlflow import MlflowMetricWriter
Expand All @@ -17,6 +19,21 @@ def _get_runs(tracking_uri: str, experiment_name: str) -> list[mlflow.entities.R
return client.search_runs([experiment.experiment_id])


def _exceptional_mlflow_client_class(
    actually_create: bool,
) -> type[mlflow.MlflowClient]:
    """Build an ``MlflowClient`` subclass whose ``create_experiment`` always raises.

    The returned class simulates losing a creation race: every call to
    ``create_experiment`` ends with an ``MlflowException`` carrying the
    ``RESOURCE_ALREADY_EXISTS`` error code.

    Args:
        actually_create: If True, the experiment is really created before the
            exception is raised, as if a concurrent process had created it
            first. A follow-up lookup by name will then find it. If False,
            nothing is created, so a follow-up lookup by name still finds
            nothing.

    Returns:
        The ``MlflowClient`` subclass (the class itself, not an instance).
    """

    class ExceptionalMlflowClient(mlflow.MlflowClient):
        def create_experiment(self, *args, **kwargs):
            if actually_create:
                # Deliberately no `return`: the experiment must exist AND the
                # caller must still see RESOURCE_ALREADY_EXISTS, otherwise the
                # caller's recovery path is never exercised.
                super().create_experiment(*args, **kwargs)
            raise mlflow.exceptions.MlflowException(
                "Experiment already exists",
                error_code=mlflow.protos.databricks_pb2.RESOURCE_ALREADY_EXISTS,
            )

    return ExceptionalMlflowClient


class MlflowMetricWriterTest(absltest.TestCase):
def test_write_scalars(self):
with tempfile.TemporaryDirectory() as temp_dir:
Expand Down Expand Up @@ -135,6 +152,34 @@ def test_no_ops(self):
self.assertEqual(run.data.metrics, {})
self.assertEqual(run.data.params, {})

def test_experiment_creation_race_condition(self):
    """Construct a writer with the injected client class and expect one run.

    The writer is built with the client class returned by
    ``_exceptional_mlflow_client_class(True)``; afterwards exactly one run
    must be recorded under the experiment.
    """
    client_class = _exceptional_mlflow_client_class(True)
    with tempfile.TemporaryDirectory() as temp_dir:
        tracking_uri = f"file://{temp_dir}"
        experiment_name = "race_condition_experiment"

        writer = MlflowMetricWriter(
            experiment_name,
            tracking_uri=tracking_uri,
            _client_class=client_class,
        )

        # Constructing the writer is what creates the run.
        recorded_runs = _get_runs(tracking_uri, experiment_name)
        self.assertEqual(len(recorded_runs), 1)
        writer.close()

def test_experiment_creation_race_condition_and_then_fail(self):
    """Expect ``RuntimeError`` when the experiment can be neither created nor found.

    The client class from ``_exceptional_mlflow_client_class(False)`` raises
    ``RESOURCE_ALREADY_EXISTS`` from ``create_experiment`` without creating
    anything, so constructing the writer must raise ``RuntimeError``.
    """
    failing_client_class = _exceptional_mlflow_client_class(False)
    with tempfile.TemporaryDirectory() as temp_dir:
        tracking_uri = f"file://{temp_dir}"
        experiment_name = "race_condition_experiment"

        with self.assertRaises(RuntimeError):
            MlflowMetricWriter(
                experiment_name,
                tracking_uri=tracking_uri,
                _client_class=failing_client_class,
            )


if __name__ == "__main__":
absltest.main()

0 comments on commit e1c9b4a

Please sign in to comment.