Skip to content

Commit

Permalink
Implement MlflowMetricWriter.write_tags
Browse files Browse the repository at this point in the history
  • Loading branch information
mickvangelderen committed Jan 13, 2025
1 parent 2670570 commit 0c1a842
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 0 deletions.
7 changes: 7 additions & 0 deletions src/jax_loop_utils/metric_writers/mlflow/metric_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import mlflow.tracking.fluent
import numpy as np
from absl import logging
from mlflow.entities import RunTag

from jax_loop_utils import asynclib
from jax_loop_utils.metric_writers.interface import (
Expand Down Expand Up @@ -82,6 +83,12 @@ def __init__(
experiment_id=experiment_id, run_name=run_name
).info.run_id

def write_tags(self, tags: dict[str, Any]):
"""Set tags on the MLFlow run"""
self._client.log_batch(

Check warning on line 88 in src/jax_loop_utils/metric_writers/mlflow/metric_writer.py

View check run for this annotation

Codecov / codecov/patch

src/jax_loop_utils/metric_writers/mlflow/metric_writer.py#L88

Added line #L88 was not covered by tests
self._run_id, [], [], [RunTag(k, str(v)) for k, v in tags.items()], synchronous=False
)

def write_scalars(self, step: int, scalars: Mapping[str, Scalar]):
"""Write scalar metrics to MLflow."""
timestamp = int(time.time() * 1000)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,13 @@ def create_experiment(self, *args, **kwargs):


class MlflowMetricWriterTest(absltest.TestCase):
def test_set_tags(self):
with tempfile.TemporaryDirectory() as temp_dir:
tracking_uri = f"file://{temp_dir}"
experiment_name = "experiment_name"
writer = MlflowMetricWriter(experiment_name, tracking_uri=tracking_uri)
writer.set_tags({"ooh": "aah"})

def test_write_scalars(self):
with tempfile.TemporaryDirectory() as temp_dir:
tracking_uri = f"file://{temp_dir}"
Expand Down

0 comments on commit 0c1a842

Please sign in to comment.