Handle the experiment-creation race in MlflowMetricWriter.

When another process creates the experiment between the lookup and the
create call, the RESOURCE_ALREADY_EXISTS error is caught and the experiment
is looked up again. The test client creates the experiment (when asked to)
and then still raises, so the recovery path is actually exercised.

Note: the blob ids on the "index" lines are those of the originally
generated patch and do not reflect the two review changes (the test
client no longer returns early; the RuntimeError is chained with "from e").

diff --git a/src/jax_loop_utils/metric_writers/mlflow/metric_writer.py b/src/jax_loop_utils/metric_writers/mlflow/metric_writer.py
index af639a0..809eb53 100644
--- a/src/jax_loop_utils/metric_writers/mlflow/metric_writer.py
+++ b/src/jax_loop_utils/metric_writers/mlflow/metric_writer.py
@@ -7,6 +7,8 @@
 import mlflow
 import mlflow.config
 import mlflow.entities
+import mlflow.exceptions
+import mlflow.protos.databricks_pb2
 import mlflow.tracking.fluent
 import numpy as np
 from absl import logging
@@ -28,6 +30,7 @@ def __init__(
         experiment_name: str,
         run_name: str | None = None,
         tracking_uri: str | None = None,
+        _client_class: type[mlflow.MlflowClient] = mlflow.MlflowClient,
     ):
         """Initialize MLflow writer.
 
@@ -37,17 +40,32 @@ def __init__(
             tracking_uri: Address of local or remote tracking server.
                 Treated the same as arguments to mlflow.set_tracking_uri.
                 See https://www.mlflow.org/docs/latest/python_api/mlflow.html#mlflow.set_tracking_uri
+            _client_class: MLflow client class (for testing only).
         """
-        self._client = mlflow.MlflowClient(tracking_uri=tracking_uri)
+        self._client = _client_class(tracking_uri=tracking_uri)
         experiment = self._client.get_experiment_by_name(experiment_name)
         if experiment:
             experiment_id = experiment.experiment_id
         else:
             logging.info(
-                "Experiment with name '%s' does not exist. Creating a new experiment.",
+                "Experiment '%s' does not exist. Creating a new experiment.",
                 experiment_name,
             )
-            experiment_id = self._client.create_experiment(experiment_name)
+            try:
+                experiment_id = self._client.create_experiment(experiment_name)
+            except mlflow.exceptions.MlflowException as e:
+                # Handle race in creating experiment.
+                if e.error_code != mlflow.protos.databricks_pb2.ErrorCode.Name(
+                    mlflow.protos.databricks_pb2.RESOURCE_ALREADY_EXISTS
+                ):
+                    raise
+                experiment = self._client.get_experiment_by_name(experiment_name)
+                if not experiment:
+                    raise RuntimeError(
+                        "Failed to get, then failed to create, then failed to get "
+                        f"again experiment '{experiment_name}'"
+                    ) from e
+                experiment_id = experiment.experiment_id
         self._run_id = self._client.create_run(
             experiment_id=experiment_id, run_name=run_name
         ).info.run_id
diff --git a/src/jax_loop_utils/metric_writers/mlflow/metric_writer_test.py b/src/jax_loop_utils/metric_writers/mlflow/metric_writer_test.py
index 920902e..109e806 100644
--- a/src/jax_loop_utils/metric_writers/mlflow/metric_writer_test.py
+++ b/src/jax_loop_utils/metric_writers/mlflow/metric_writer_test.py
@@ -1,10 +1,12 @@
 import tempfile
 import time
 
+import jax.numpy as jnp
 import mlflow
 import mlflow.entities
+import mlflow.exceptions
+import mlflow.protos.databricks_pb2
 import numpy as np
-import jax.numpy as jnp
 from absl.testing import absltest
 
 from jax_loop_utils.metric_writers.mlflow import MlflowMetricWriter
@@ -17,6 +19,23 @@ def _get_runs(tracking_uri: str, experiment_name: str) -> list[mlflow.entities.R
     return client.search_runs([experiment.experiment_id])
 
 
+def _exceptional_mlflow_client_class(
+    actually_create: bool,
+) -> type[mlflow.MlflowClient]:
+    class ExceptionalMlflowClient(mlflow.MlflowClient):
+        def create_experiment(self, *args, **kwargs):
+            # Simulate losing the creation race: optionally create the
+            # experiment for real, but always report that it already exists.
+            if actually_create:
+                super().create_experiment(*args, **kwargs)
+            raise mlflow.exceptions.MlflowException(
+                "Experiment already exists",
+                error_code=mlflow.protos.databricks_pb2.RESOURCE_ALREADY_EXISTS,
+            )
+
+    return ExceptionalMlflowClient
+
+
 class MlflowMetricWriterTest(absltest.TestCase):
     def test_write_scalars(self):
         with tempfile.TemporaryDirectory() as temp_dir:
@@ -135,6 +154,34 @@ def test_no_ops(self):
             self.assertEqual(run.data.metrics, {})
             self.assertEqual(run.data.params, {})
 
+    def test_experiment_creation_race_condition(self):
+        with tempfile.TemporaryDirectory() as temp_dir:
+            tracking_uri = f"file://{temp_dir}"
+            experiment_name = "race_condition_experiment"
+
+            writer = MlflowMetricWriter(
+                experiment_name,
+                tracking_uri=tracking_uri,
+                _client_class=_exceptional_mlflow_client_class(True),
+            )
+
+            runs = _get_runs(tracking_uri, experiment_name)
+            self.assertEqual(len(runs), 1)
+            writer.close()
+
+    def test_experiment_creation_race_condition_and_then_fail(self):
+        with tempfile.TemporaryDirectory() as temp_dir:
+            tracking_uri = f"file://{temp_dir}"
+            experiment_name = "race_condition_experiment"
+
+            self.assertRaises(
+                RuntimeError,
+                MlflowMetricWriter,
+                experiment_name,
+                tracking_uri=tracking_uri,
+                _client_class=_exceptional_mlflow_client_class(False),
+            )
+
 
 if __name__ == "__main__":
     absltest.main()