Skip to content

Commit

Permalink
PR comments:
Browse files Browse the repository at this point in the history
* rename class
* use list comprehension

Change-Id: I68fe00540b7805cf7b7f520d80333915735d34f2
  • Loading branch information
garymm committed Dec 11, 2024
1 parent 49cb536 commit 56536f0
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 14 deletions.
4 changes: 2 additions & 2 deletions src/jax_loop_utils/metric_writers/mlflow/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from .metric_writer import MetricWriter
from .metric_writer import MlflowMetricWriter

__all__ = ["MetricWriter"]
__all__ = ["MlflowMetricWriter"]
9 changes: 5 additions & 4 deletions src/jax_loop_utils/metric_writers/mlflow/metric_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
)


class MetricWriter(MetricWriterInterface):
class MlflowMetricWriter(MetricWriterInterface):
"""MLflow implementation of MetricWriter."""

def __init__(self, experiment_name: str, tracking_uri: str | None = None):
Expand All @@ -46,10 +46,11 @@ def __init__(self, experiment_name: str, tracking_uri: str | None = None):

def write_scalars(self, step: int, scalars: Mapping[str, Scalar]):
"""Write scalar metrics to MLflow."""
metrics_list: list[mlflow.entities.Metric] = []
timestamp = int(time() * 1000)
for k, v in scalars.items():
metrics_list.append(mlflow.entities.Metric(k, float(v), timestamp, step))
metrics_list = [
mlflow.entities.Metric(k, float(v), timestamp, step)
for k, v in scalars.items()
]
self._client.log_batch(self._run_id, metrics=metrics_list, synchronous=False)

def write_images(self, step: int, images: Mapping[str, Array]):
Expand Down
16 changes: 8 additions & 8 deletions src/jax_loop_utils/metric_writers/mlflow/metric_writer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import numpy as np
from absl.testing import absltest

from jax_loop_utils.metric_writers.mlflow import MetricWriter
from jax_loop_utils.metric_writers.mlflow import MlflowMetricWriter


def _get_runs(tracking_uri: str, experiment_name: str) -> list[mlflow.entities.Run]:
Expand All @@ -16,12 +16,12 @@ def _get_runs(tracking_uri: str, experiment_name: str) -> list[mlflow.entities.R
return client.search_runs([experiment.experiment_id])


class MetricWriterTest(absltest.TestCase):
class MlflowMetricWriterTest(absltest.TestCase):
def test_write_scalars(self):
with tempfile.TemporaryDirectory() as temp_dir:
tracking_uri = f"file://{temp_dir}"
experiment_name = "experiment_name"
writer = MetricWriter(experiment_name, tracking_uri=tracking_uri)
writer = MlflowMetricWriter(experiment_name, tracking_uri=tracking_uri)
seq_of_scalars = (
{"a": 3, "b": 0.15},
{"a": 5, "b": 0.007},
Expand All @@ -45,7 +45,7 @@ def test_write_scalars(self):
run = runs[0]
self.assertEqual(run.info.status, "FINISHED")
# check we can create a new writer with an existing experiment
writer = MetricWriter(experiment_name, tracking_uri=tracking_uri)
writer = MlflowMetricWriter(experiment_name, tracking_uri=tracking_uri)
writer.write_scalars(0, {"a": 1, "b": 2})
writer.flush()
writer.close()
Expand All @@ -57,7 +57,7 @@ def test_write_images(self):
with tempfile.TemporaryDirectory() as temp_dir:
tracking_uri = f"file://{temp_dir}"
experiment_name = "experiment_name"
writer = MetricWriter(experiment_name, tracking_uri=tracking_uri)
writer = MlflowMetricWriter(experiment_name, tracking_uri=tracking_uri)
writer.write_images(0, {"test_image": np.zeros((3, 3, 3), dtype=np.uint8)})
writer.close()

Expand All @@ -79,7 +79,7 @@ def test_write_texts(self):
with tempfile.TemporaryDirectory() as temp_dir:
tracking_uri = f"file://{temp_dir}"
experiment_name = "experiment_name"
writer = MetricWriter(experiment_name, tracking_uri=tracking_uri)
writer = MlflowMetricWriter(experiment_name, tracking_uri=tracking_uri)
test_text = "Hello world"
writer.write_texts(0, {"test_text": test_text})
writer.close()
Expand All @@ -102,7 +102,7 @@ def test_write_hparams(self):
with tempfile.TemporaryDirectory() as temp_dir:
tracking_uri = f"file://{temp_dir}"
experiment_name = "experiment_name"
writer = MetricWriter(experiment_name, tracking_uri=tracking_uri)
writer = MlflowMetricWriter(experiment_name, tracking_uri=tracking_uri)
test_params = {"learning_rate": 0.001, "batch_size": 32, "epochs": 100}
writer.write_hparams(test_params)
writer.close()
Expand All @@ -118,7 +118,7 @@ def test_no_ops(self):
with tempfile.TemporaryDirectory() as temp_dir:
tracking_uri = f"file://{temp_dir}"
experiment_name = "experiment_name"
writer = MetricWriter(experiment_name, tracking_uri=tracking_uri)
writer = MlflowMetricWriter(experiment_name, tracking_uri=tracking_uri)
writer.write_videos(0, {"video": np.zeros((4, 28, 28, 3))})
writer.write_audios(0, {"audio": np.zeros((2, 1000))}, sample_rate=16000)
writer.write_histograms(
Expand Down

0 comments on commit 56536f0

Please sign in to comment.