Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add mlflow metric writer #12

Merged
merged 3 commits into from
Dec 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion .github/workflows/checks.yml
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,17 @@ jobs:
with:
enable-cache: true
- name: pytest
run: uv run --extra test pytest --capture=no --verbose --cov --cov-report=xml --ignore-glob='jax_loop_utils/metric_writers/tf/*' --ignore-glob='jax_loop_utils/metric_writers/torch/*' jax_loop_utils/
run: uv run --extra test pytest --capture=no --verbose --cov --cov-report=xml --ignore-glob='jax_loop_utils/metric_writers/tf/*' --ignore-glob='jax_loop_utils/metric_writers/torch/*' --ignore-glob='jax_loop_utils/metric_writers/mlflow/*' jax_loop_utils/
working-directory: src
- name: pytest tensorflow
run: uv run --extra test --extra tensorflow pytest --capture=no --verbose --cov --cov-report=xml --cov-append jax_loop_utils/metric_writers/tf
working-directory: src
- name: pytest torch
run: uv run --extra test --extra torch pytest --capture=no --verbose --cov --cov-report=xml --cov-append jax_loop_utils/metric_writers/torch
working-directory: src
- name: pytest mlflow
run: uv run --extra test --extra mlflow pytest --capture=no --verbose --cov --cov-report=xml --cov-append jax_loop_utils/metric_writers/mlflow
working-directory: src
mickvangelderen marked this conversation as resolved.
Show resolved Hide resolved
- name: Upload coverage reports to Codecov
if: always()
uses: codecov/codecov-action@v4
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ keywords = ["JAX", "machine learning"]
Homepage = "http://github.com/Astera-org/jax_loop_utils"

[project.optional-dependencies]
mlflow = ["mlflow-skinny>=2.0"]
pyright = ["pyright"]
# for synopsis.ipynb
synopsis = ["chex", "flax", "ipykernel", "matplotlib"]
Expand Down
3 changes: 3 additions & 0 deletions src/jax_loop_utils/metric_writers/mlflow/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .metric_writer import MlflowMetricWriter

__all__ = ["MlflowMetricWriter"]
120 changes: 120 additions & 0 deletions src/jax_loop_utils/metric_writers/mlflow/metric_writer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
"""MLflow implementation of MetricWriter interface."""

from collections.abc import Mapping
from time import time
from typing import Any

import mlflow
import mlflow.config
import mlflow.entities
import mlflow.tracking.fluent
from absl import logging

from jax_loop_utils.metric_writers.interface import (
Array,
Scalar,
)
from jax_loop_utils.metric_writers.interface import (
MetricWriter as MetricWriterInterface,
)


class MlflowMetricWriter(MetricWriterInterface):
    """MLflow implementation of MetricWriter.

    Creates (or reuses) the named MLflow experiment and opens a fresh run.
    Scalars, images, texts and hyperparameters are forwarded to the MLflow
    tracking service using asynchronous logging; videos, audios and
    histograms are not supported by MLflow and are logged as warnings once.
    """

    def __init__(self, experiment_name: str, tracking_uri: str | None = None):
        """Initialize MLflow writer.

        Args:
            experiment_name: Name of the MLflow experiment.
            tracking_uri: Address of local or remote tracking server. If not provided, defaults
                to the service set by ``mlflow.tracking.set_tracking_uri``. See
                `Where Runs Get Recorded <../tracking.html#where-runs-get-recorded>`_
                for more info.
        """
        self._client = mlflow.MlflowClient(tracking_uri=tracking_uri)
        experiment = self._client.get_experiment_by_name(experiment_name)
        if experiment:
            experiment_id = experiment.experiment_id
        else:
            logging.info(
                "Experiment with name '%s' does not exist. Creating a new experiment.",
                experiment_name,
            )
            experiment_id = self._client.create_experiment(experiment_name)
        # Every writer instance gets its own run; close() terminates it.
        self._run_id = self._client.create_run(experiment_id=experiment_id).info.run_id

    def write_scalars(self, step: int, scalars: Mapping[str, Scalar]):
        """Write scalar metrics to MLflow.

        All scalars in the mapping share a single wall-clock timestamp
        (milliseconds since the epoch, as MLflow expects) and are sent in
        one batched, asynchronous request.
        """
        timestamp = int(time() * 1000)
        metrics_list = [
            mlflow.entities.Metric(k, float(v), timestamp, step)
            for k, v in scalars.items()
        ]
        self._client.log_batch(self._run_id, metrics=metrics_list, synchronous=False)

    def write_images(self, step: int, images: Mapping[str, Array]):
        """Write images to MLflow, one asynchronous upload per key."""
        for key, image_array in images.items():
            self._client.log_image(
                self._run_id, image_array, key=key, step=step, synchronous=False
            )

    def write_videos(self, step: int, videos: Mapping[str, Array]):
        """MLflow doesn't support video logging directly."""
        # this could be supported if we convert the video to a file
        # and log the file as an artifact.
        logging.log_first_n(
            logging.WARNING,
            "mlflow.MetricWriter does not support writing videos.",
            1,
        )

    def write_audios(self, step: int, audios: Mapping[str, Array], *, sample_rate: int):
        """MLflow doesn't support audio logging directly."""
        # this could be supported if we convert the audio to a file
        # and log the file as an artifact.
        logging.log_first_n(
            logging.WARNING,
            "mlflow.MetricWriter does not support writing audios.",
            1,
        )

    def write_texts(self, step: int, texts: Mapping[str, str]):
        """Write text artifacts to MLflow, one ``<key>_step_<step>.txt`` file each."""
        for key, text in texts.items():
            self._client.log_text(self._run_id, text, f"{key}_step_{step}.txt")

    def write_histograms(
        self,
        step: int,
        arrays: Mapping[str, Array],
        num_buckets: Mapping[str, int] | None = None,
    ):
        """MLflow doesn't support histogram logging directly.

        https://github.com/mlflow/mlflow/issues/8145
        """
        logging.log_first_n(
            logging.WARNING,
            "mlflow.MetricWriter does not support writing histograms.",
            1,
        )

    def write_hparams(self, hparams: Mapping[str, Any]):
        """Log hyperparameters to MLflow.

        Values are stringified because MLflow stores params as strings.
        """
        params = [
            mlflow.entities.Param(key, str(value)) for key, value in hparams.items()
        ]
        self._client.log_batch(self._run_id, params=params, synchronous=False)

    def flush(self):
        """Block until all pending asynchronous logging has completed."""
        mlflow.flush_artifact_async_logging()
        mlflow.flush_async_logging()
        mlflow.flush_trace_async_logging()

    def close(self):
        """Flush pending data and end the MLflow run."""
        # Flush while the run is still active so asynchronously queued
        # metrics/artifacts are not delivered to an already-terminated run;
        # only then mark the run as finished.
        self.flush()
        self._client.set_terminated(self._run_id)
139 changes: 139 additions & 0 deletions src/jax_loop_utils/metric_writers/mlflow/metric_writer_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
import tempfile
import time

import mlflow
import mlflow.entities
import numpy as np
from absl.testing import absltest

from jax_loop_utils.metric_writers.mlflow import MlflowMetricWriter


def _get_runs(tracking_uri: str, experiment_name: str) -> list[mlflow.entities.Run]:
    """Return every run recorded for ``experiment_name`` at ``tracking_uri``."""
    mlflow_client = mlflow.MlflowClient(tracking_uri=tracking_uri)
    found = mlflow_client.get_experiment_by_name(experiment_name)
    assert found is not None
    return mlflow_client.search_runs([found.experiment_id])


class MlflowMetricWriterTest(absltest.TestCase):
    """Integration tests for MlflowMetricWriter against a file-backed MLflow store.

    Each test creates a temporary directory and uses a ``file://`` tracking URI,
    so no external tracking server is required.
    """

    def test_write_scalars(self):
        """Scalars are batched per step; close() finishes the run; reusing an
        existing experiment creates a new run."""
        with tempfile.TemporaryDirectory() as temp_dir:
            tracking_uri = f"file://{temp_dir}"
            experiment_name = "experiment_name"
            writer = MlflowMetricWriter(experiment_name, tracking_uri=tracking_uri)
            seq_of_scalars = (
                {"a": 3, "b": 0.15},
                {"a": 5, "b": 0.007},
            )
            for step, scalars in enumerate(seq_of_scalars):
                writer.write_scalars(step, scalars)
            writer.flush()
            runs = _get_runs(tracking_uri, experiment_name)
            self.assertEqual(len(runs), 1)
            run = runs[0]
            for metric_key in ("a", "b"):
                self.assertIn(metric_key, run.data.metrics)
                # run.data.metrics holds the latest value per key, so compare
                # against the last step's scalars.
                self.assertEqual(
                    run.data.metrics[metric_key], seq_of_scalars[-1][metric_key]
                )
            # constant defined in mlflow.entities.RunStatus
            self.assertEqual(run.info.status, "RUNNING")
            writer.close()
            runs = _get_runs(tracking_uri, experiment_name)
            self.assertEqual(len(runs), 1)
            run = runs[0]
            self.assertEqual(run.info.status, "FINISHED")
            # check we can create a new writer with an existing experiment
            writer = MlflowMetricWriter(experiment_name, tracking_uri=tracking_uri)
            writer.write_scalars(0, {"a": 1, "b": 2})
            writer.flush()
            writer.close()
            # should result in a new run
            runs = _get_runs(tracking_uri, experiment_name)
            self.assertEqual(len(runs), 2)

    def test_write_images(self):
        """Images are logged as artifacts under the MLflow ``images`` directory."""
        with tempfile.TemporaryDirectory() as temp_dir:
            tracking_uri = f"file://{temp_dir}"
            experiment_name = "experiment_name"
            writer = MlflowMetricWriter(experiment_name, tracking_uri=tracking_uri)
            writer.write_images(0, {"test_image": np.zeros((3, 3, 3), dtype=np.uint8)})
            writer.close()

            runs = _get_runs(tracking_uri, experiment_name)
            self.assertEqual(len(runs), 1)
            run = runs[0]
            # the string "images" is hardcoded in MlflowClient.log_image.
            artifacts = writer._client.list_artifacts(run.info.run_id, "images")
            if not artifacts:
                # have seen some latency in artifacts becoming available
                # Maybe file system sync? Not sure.
                time.sleep(0.1)
                artifacts = writer._client.list_artifacts(run.info.run_id, "images")
            artifact_paths = [artifact.path for artifact in artifacts]
            self.assertGreaterEqual(len(artifact_paths), 1)
            self.assertIn("test_image", artifact_paths[0])

    def test_write_texts(self):
        """Texts are logged as ``<key>_step_<step>.txt`` artifacts with the
        original content preserved."""
        with tempfile.TemporaryDirectory() as temp_dir:
            tracking_uri = f"file://{temp_dir}"
            experiment_name = "experiment_name"
            writer = MlflowMetricWriter(experiment_name, tracking_uri=tracking_uri)
            test_text = "Hello world"
            writer.write_texts(0, {"test_text": test_text})
            writer.close()

            runs = _get_runs(tracking_uri, experiment_name)
            self.assertEqual(len(runs), 1)
            run = runs[0]
            artifacts = writer._client.list_artifacts(run.info.run_id)
            artifact_paths = [artifact.path for artifact in artifacts]
            self.assertGreaterEqual(len(artifact_paths), 1)
            self.assertIn("test_text_step_0.txt", artifact_paths)
            # Round-trip: download the artifact and verify its content.
            local_path = writer._client.download_artifacts(
                run.info.run_id, "test_text_step_0.txt"
            )
            with open(local_path, "r") as f:
                content = f.read()
            self.assertEqual(content, test_text)

    def test_write_hparams(self):
        """Hyperparameters are stored as strings in the run's params."""
        with tempfile.TemporaryDirectory() as temp_dir:
            tracking_uri = f"file://{temp_dir}"
            experiment_name = "experiment_name"
            writer = MlflowMetricWriter(experiment_name, tracking_uri=tracking_uri)
            test_params = {"learning_rate": 0.001, "batch_size": 32, "epochs": 100}
            writer.write_hparams(test_params)
            writer.close()

            runs = _get_runs(tracking_uri, experiment_name)
            self.assertEqual(len(runs), 1)
            run = runs[0]
            # MLflow stores params as strings, hence the string comparisons.
            self.assertEqual(run.data.params["learning_rate"], "0.001")
            self.assertEqual(run.data.params["batch_size"], "32")
            self.assertEqual(run.data.params["epochs"], "100")

    def test_no_ops(self):
        """Unsupported writer methods (videos, audios, histograms) leave no trace."""
        with tempfile.TemporaryDirectory() as temp_dir:
            tracking_uri = f"file://{temp_dir}"
            experiment_name = "experiment_name"
            writer = MlflowMetricWriter(experiment_name, tracking_uri=tracking_uri)
            writer.write_videos(0, {"video": np.zeros((4, 28, 28, 3))})
            writer.write_audios(0, {"audio": np.zeros((2, 1000))}, sample_rate=16000)
            writer.write_histograms(
                0, {"histogram": np.zeros((10,))}, num_buckets={"histogram": 10}
            )
            writer.close()
            runs = _get_runs(tracking_uri, experiment_name)
            self.assertEqual(len(runs), 1)
            run = runs[0]
            artifacts = writer._client.list_artifacts(run.info.run_id)
            # the above ops are all no-ops so no artifacts, metrics or params
            self.assertEqual(len(artifacts), 0)
            self.assertEqual(run.data.metrics, {})
            self.assertEqual(run.data.params, {})


if __name__ == "__main__":
    # Entry point: run the tests via absl's test runner.
    absltest.main()

Check warning on line 139 in src/jax_loop_utils/metric_writers/mlflow/metric_writer_test.py

View check run for this annotation

Codecov / codecov/patch

src/jax_loop_utils/metric_writers/mlflow/metric_writer_test.py#L139

Added line #L139 was not covered by tests
4 changes: 2 additions & 2 deletions src/jax_loop_utils/metric_writers/torch/tensorboard_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def write_images(self, step: int, images: Mapping[str, Array]):
def write_videos(self, step: int, videos: Mapping[str, Array]):
logging.log_first_n(
logging.WARNING,
"TorchTensorBoardWriter does not support writing videos.",
"torch.TensorboardWriter does not support writing videos.",
1,
)

Expand All @@ -60,7 +60,7 @@ def write_audios(self, step: int, audios: Mapping[str, Array], *, sample_rate: i

def write_texts(self, step: int, texts: Mapping[str, str]):
raise NotImplementedError(
"TorchTensorBoardWriter does not support writing texts."
"torch.TensorboardWriter does not support writing texts."
)

def write_histograms(
Expand Down
Loading