From b11e0695cac54d285ab84c5870a0c5454228c53b Mon Sep 17 00:00:00 2001 From: Gary Miguel Date: Fri, 6 Dec 2024 13:55:58 -0800 Subject: [PATCH 1/3] add mlflow metric writer Change-Id: Ief75dd77c24a63b05eec87b5b1f4822abb206abd --- pyproject.toml | 1 + .../metric_writers/mlflow/__init__.py | 3 + .../metric_writers/mlflow/metric_writer.py | 119 +++++++++ .../mlflow/metric_writer_test.py | 139 +++++++++++ .../torch/tensorboard_writer.py | 4 +- uv.lock | 233 ++++++++++++++++++ 6 files changed, 497 insertions(+), 2 deletions(-) create mode 100644 src/jax_loop_utils/metric_writers/mlflow/__init__.py create mode 100644 src/jax_loop_utils/metric_writers/mlflow/metric_writer.py create mode 100644 src/jax_loop_utils/metric_writers/mlflow/metric_writer_test.py diff --git a/pyproject.toml b/pyproject.toml index 18751f4..f39de7c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -29,6 +29,7 @@ keywords = ["JAX", "machine learning"] Homepage = "http://github.com/Astera-org/jax_loop_utils" [project.optional-dependencies] +mlflow = ["mlflow-skinny>=2.0"] pyright = ["pyright"] # for synopsis.ipynb synopsis = ["chex", "flax", "ipykernel", "matplotlib"] diff --git a/src/jax_loop_utils/metric_writers/mlflow/__init__.py b/src/jax_loop_utils/metric_writers/mlflow/__init__.py new file mode 100644 index 0000000..db3025a --- /dev/null +++ b/src/jax_loop_utils/metric_writers/mlflow/__init__.py @@ -0,0 +1,3 @@ +from .metric_writer import MetricWriter + +__all__ = ["MetricWriter"] diff --git a/src/jax_loop_utils/metric_writers/mlflow/metric_writer.py b/src/jax_loop_utils/metric_writers/mlflow/metric_writer.py new file mode 100644 index 0000000..d7397f2 --- /dev/null +++ b/src/jax_loop_utils/metric_writers/mlflow/metric_writer.py @@ -0,0 +1,119 @@ +"""MLflow implementation of MetricWriter interface.""" + +from collections.abc import Mapping +from time import time +from typing import Any + +import mlflow +import mlflow.config +import mlflow.entities +import mlflow.tracking.fluent +from absl 
import logging + +from jax_loop_utils.metric_writers.interface import ( + Array, + Scalar, +) +from jax_loop_utils.metric_writers.interface import ( + MetricWriter as MetricWriterInterface, +) + + +class MetricWriter(MetricWriterInterface): + """MLflow implementation of MetricWriter.""" + + def __init__(self, experiment_name: str, tracking_uri: str | None = None): + """Initialize MLflow writer. + + Args: + experiment_name: Name of the MLflow experiment. + tracking_uri: Address of local or remote tracking server. If not provided, defaults + to the service set by ``mlflow.tracking.set_tracking_uri``. See + `Where Runs Get Recorded <../tracking.html#where-runs-get-recorded>`_ + for more info. + """ + self._client = mlflow.MlflowClient(tracking_uri=tracking_uri) + experiment = self._client.get_experiment_by_name(experiment_name) + if experiment: + experiment_id = experiment.experiment_id + else: + logging.info( + "Experiment with name '%s' does not exist. Creating a new experiment.", + experiment_name, + ) + experiment_id = self._client.create_experiment(experiment_name) + self._run_id = self._client.create_run(experiment_id=experiment_id).info.run_id + + def write_scalars(self, step: int, scalars: Mapping[str, Scalar]): + """Write scalar metrics to MLflow.""" + metrics_list: list[mlflow.entities.Metric] = [] + timestamp = int(time() * 1000) + for k, v in scalars.items(): + metrics_list.append(mlflow.entities.Metric(k, float(v), timestamp, step)) + self._client.log_batch(self._run_id, metrics=metrics_list, synchronous=False) + + def write_images(self, step: int, images: Mapping[str, Array]): + """Write images to MLflow.""" + for key, image_array in images.items(): + self._client.log_image( + self._run_id, image_array, key=key, step=step, synchronous=False + ) + + def write_videos(self, step: int, videos: Mapping[str, Array]): + """MLflow doesn't support video logging directly.""" + # this could be supported if we convert the video to a file + # and log the file as an 
artifact.
+        logging.log_first_n(
+            logging.WARNING,
+            "mlflow.MetricWriter does not support writing videos.",
+            1,
+        )
+
+    def write_audios(self, step: int, audios: Mapping[str, Array], *, sample_rate: int):
+        """MLflow doesn't support audio logging directly."""
+        # this could be supported if we convert the audio to a file
+        # and log the file as an artifact.
+        logging.log_first_n(
+            logging.WARNING,
+            "mlflow.MetricWriter does not support writing audios.",
+            1,
+        )
+
+    def write_texts(self, step: int, texts: Mapping[str, str]):
+        """Write text artifacts to MLflow."""
+        for key, text in texts.items():
+            self._client.log_text(self._run_id, text, f"{key}_step_{step}.txt")
+
+    def write_histograms(
+        self,
+        step: int,
+        arrays: Mapping[str, Array],
+        num_buckets: Mapping[str, int] | None = None,
+    ):
+        """MLflow doesn't support histogram logging directly.
+
+        https://github.com/mlflow/mlflow/issues/8145
+        """
+        logging.log_first_n(
+            logging.WARNING,
+            "mlflow.MetricWriter does not support writing histograms.",
+            1,
+        )
+
+    def write_hparams(self, hparams: Mapping[str, Any]):
+        """Log hyperparameters to MLflow."""
+        params = [
+            mlflow.entities.Param(key, str(value)) for key, value in hparams.items()
+        ]
+        self._client.log_batch(self._run_id, params=params, synchronous=False)
+
+    def flush(self):
+        """Flushes all logged data."""
+        mlflow.flush_artifact_async_logging()
+        mlflow.flush_async_logging()
+        mlflow.flush_trace_async_logging()
+
+    def close(self):
+        """End the MLflow run."""
+        self._client.set_terminated(self._run_id)
+        self.flush()
diff --git a/src/jax_loop_utils/metric_writers/mlflow/metric_writer_test.py b/src/jax_loop_utils/metric_writers/mlflow/metric_writer_test.py
new file mode 100644
index 0000000..7c43649
--- /dev/null
+++ b/src/jax_loop_utils/metric_writers/mlflow/metric_writer_test.py
@@ -0,0 +1,139 @@
+import tempfile
+import time
+
+import mlflow
+import mlflow.entities
+import numpy as np
+from absl.testing import absltest
+
+from 
jax_loop_utils.metric_writers.mlflow import MetricWriter + + +def _get_runs(tracking_uri: str, experiment_name: str) -> list[mlflow.entities.Run]: + client = mlflow.MlflowClient(tracking_uri=tracking_uri) + experiment = client.get_experiment_by_name(experiment_name) + assert experiment is not None + return client.search_runs([experiment.experiment_id]) + + +class MetricWriterTest(absltest.TestCase): + def test_write_scalars(self): + with tempfile.TemporaryDirectory() as temp_dir: + tracking_uri = f"file://{temp_dir}" + experiment_name = "experiment_name" + writer = MetricWriter(experiment_name, tracking_uri=tracking_uri) + seq_of_scalars = ( + {"a": 3, "b": 0.15}, + {"a": 5, "b": 0.007}, + ) + for step, scalars in enumerate(seq_of_scalars): + writer.write_scalars(step, scalars) + writer.flush() + runs = _get_runs(tracking_uri, experiment_name) + self.assertEqual(len(runs), 1) + run = runs[0] + for metric_key in ("a", "b"): + self.assertIn(metric_key, run.data.metrics) + self.assertEqual( + run.data.metrics[metric_key], seq_of_scalars[-1][metric_key] + ) + # constant defined in mlflow.entities.RunStatus + self.assertEqual(run.info.status, "RUNNING") + writer.close() + runs = _get_runs(tracking_uri, experiment_name) + self.assertEqual(len(runs), 1) + run = runs[0] + self.assertEqual(run.info.status, "FINISHED") + # check we can create a new writer with an existing experiment + writer = MetricWriter(experiment_name, tracking_uri=tracking_uri) + writer.write_scalars(0, {"a": 1, "b": 2}) + writer.flush() + writer.close() + # should result in a new run + runs = _get_runs(tracking_uri, experiment_name) + self.assertEqual(len(runs), 2) + + def test_write_images(self): + with tempfile.TemporaryDirectory() as temp_dir: + tracking_uri = f"file://{temp_dir}" + experiment_name = "experiment_name" + writer = MetricWriter(experiment_name, tracking_uri=tracking_uri) + writer.write_images(0, {"test_image": np.zeros((3, 3, 3), dtype=np.uint8)}) + writer.close() + + runs = 
_get_runs(tracking_uri, experiment_name) + self.assertEqual(len(runs), 1) + run = runs[0] + # the string "images" is hardcoded in MlflowClient.log_image. + artifacts = writer._client.list_artifacts(run.info.run_id, "images") + if not artifacts: + # have seen some latency in artifacts becoming available + # Maybe file system sync? Not sure. + time.sleep(0.1) + artifacts = writer._client.list_artifacts(run.info.run_id, "images") + artifact_paths = [artifact.path for artifact in artifacts] + self.assertGreaterEqual(len(artifact_paths), 1) + self.assertIn("test_image", artifact_paths[0]) + + def test_write_texts(self): + with tempfile.TemporaryDirectory() as temp_dir: + tracking_uri = f"file://{temp_dir}" + experiment_name = "experiment_name" + writer = MetricWriter(experiment_name, tracking_uri=tracking_uri) + test_text = "Hello world" + writer.write_texts(0, {"test_text": test_text}) + writer.close() + + runs = _get_runs(tracking_uri, experiment_name) + self.assertEqual(len(runs), 1) + run = runs[0] + artifacts = writer._client.list_artifacts(run.info.run_id) + artifact_paths = [artifact.path for artifact in artifacts] + self.assertGreaterEqual(len(artifact_paths), 1) + self.assertIn("test_text_step_0.txt", artifact_paths) + local_path = writer._client.download_artifacts( + run.info.run_id, "test_text_step_0.txt" + ) + with open(local_path, "r") as f: + content = f.read() + self.assertEqual(content, test_text) + + def test_write_hparams(self): + with tempfile.TemporaryDirectory() as temp_dir: + tracking_uri = f"file://{temp_dir}" + experiment_name = "experiment_name" + writer = MetricWriter(experiment_name, tracking_uri=tracking_uri) + test_params = {"learning_rate": 0.001, "batch_size": 32, "epochs": 100} + writer.write_hparams(test_params) + writer.close() + + runs = _get_runs(tracking_uri, experiment_name) + self.assertEqual(len(runs), 1) + run = runs[0] + self.assertEqual(run.data.params["learning_rate"], "0.001") + self.assertEqual(run.data.params["batch_size"], 
"32") + self.assertEqual(run.data.params["epochs"], "100") + + def test_no_ops(self): + with tempfile.TemporaryDirectory() as temp_dir: + tracking_uri = f"file://{temp_dir}" + experiment_name = "experiment_name" + writer = MetricWriter(experiment_name, tracking_uri=tracking_uri) + writer.write_videos(0, {"video": np.zeros((4, 28, 28, 3))}) + writer.write_audios(0, {"audio": np.zeros((2, 1000))}, sample_rate=16000) + writer.write_histograms( + 0, {"histogram": np.zeros((10,))}, num_buckets={"histogram": 10} + ) + writer.close() + runs = _get_runs(tracking_uri, experiment_name) + self.assertEqual(len(runs), 1) + run = runs[0] + artifacts = writer._client.list_artifacts(run.info.run_id) + # the above ops are all no-ops so no artifacts, metrics or params + self.assertEqual(len(artifacts), 0) + self.assertEqual(run.data.metrics, {}) + self.assertEqual(run.data.params, {}) + + +if __name__ == "__main__": + absltest.main() diff --git a/src/jax_loop_utils/metric_writers/torch/tensorboard_writer.py b/src/jax_loop_utils/metric_writers/torch/tensorboard_writer.py index 840db5e..2993b33 100644 --- a/src/jax_loop_utils/metric_writers/torch/tensorboard_writer.py +++ b/src/jax_loop_utils/metric_writers/torch/tensorboard_writer.py @@ -48,7 +48,7 @@ def write_images(self, step: int, images: Mapping[str, Array]): def write_videos(self, step: int, videos: Mapping[str, Array]): logging.log_first_n( logging.WARNING, - "TorchTensorBoardWriter does not support writing videos.", + "torch.TensorboardWriter does not support writing videos.", 1, ) @@ -60,7 +60,7 @@ def write_audios(self, step: int, audios: Mapping[str, Array], *, sample_rate: i def write_texts(self, step: int, texts: Mapping[str, str]): raise NotImplementedError( - "TorchTensorBoardWriter does not support writing texts." + "torch.TensorboardWriter does not support writing texts." 
) def write_histograms( diff --git a/uv.lock b/uv.lock index b6dd2a7..f2ec1e3 100644 --- a/uv.lock +++ b/uv.lock @@ -54,6 +54,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/2b/03/13dde6512ad7b4557eb792fbcf0c653af6076b81e5941d36ec61f7ce6028/astunparse-1.6.3-py2.py3-none-any.whl", hash = "sha256:c2652417f2c8b5bb325c885ae329bdf3f86424075c4fd1a128674bc6fba4b8e8", size = 12732 }, ] +[[package]] +name = "cachetools" +version = "5.5.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/c3/38/a0f315319737ecf45b4319a8cd1f3a908e29d9277b46942263292115eee7/cachetools-5.5.0.tar.gz", hash = "sha256:2cc24fb4cbe39633fb7badd9db9ca6295d766d9c2995f245725a46715d050f2a", size = 27661 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/a4/07/14f8ad37f2d12a5ce41206c21820d8cb6561b728e51fad4530dff0552a67/cachetools-5.5.0-py3-none-any.whl", hash = "sha256:02134e8439cdc2ffb62023ce1debca2944c3f289d66bb17ead3ab3dede74b292", size = 9524 }, +] + [[package]] name = "certifi" version = "2024.8.30" @@ -168,6 +177,36 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/46/dd/c1ff2eb8fbf95a8ca804abb1cc3ce70b283ee7b4bc653c3abac245670400/chex-0.1.87-py3-none-any.whl", hash = "sha256:ce536475661fd96d21be0c1728ecdbedd03f8ff950c662dfc338c92ea782cb16", size = 99369 }, ] +[[package]] +name = "click" +version = "8.1.7" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "colorama", marker = "(platform_machine == 'aarch64' and platform_system == 'Windows' and sys_platform == 'linux') or (platform_machine == 'x86_64' and platform_system == 'Windows' and sys_platform == 'linux') or (platform_machine == 'arm64' and platform_system == 'Windows' and sys_platform == 'darwin')" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/96/d3/f04c7bfcf5c1862a2a5b845c6b2b360488cf47af55dfa79c98f6a6bf98b5/click-8.1.7.tar.gz", hash = 
"sha256:ca9853ad459e787e2192211578cc907e7594e294c7ccc834310722b41b9ca6de", size = 336121 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/00/2e/d53fa4befbf2cfa713304affc7ca780ce4fc1fd8710527771b58311a3229/click-8.1.7-py3-none-any.whl", hash = "sha256:ae74fb96c20a0277a1d615f1e4d73c8414f5a98db8b799a7931d1582f3390c28", size = 97941 }, +] + +[[package]] +name = "cloudpickle" +version = "3.1.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/97/c7/f746cadd08c4c08129215cf1b984b632f9e579fc781301e63da9e85c76c1/cloudpickle-3.1.0.tar.gz", hash = "sha256:81a929b6e3c7335c863c771d673d105f02efdb89dfaba0c90495d1c64796601b", size = 66155 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/48/41/e1d85ca3cab0b674e277c8c4f678cf66a91cd2cecf93df94353a606fe0db/cloudpickle-3.1.0-py3-none-any.whl", hash = "sha256:fe11acda67f61aaaec473e3afe030feb131d78a43461b718185363384f1ba12e", size = 22021 }, +] + +[[package]] +name = "colorama" +version = "0.4.6" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/d8/53/6f443c9a4a8358a93a6792e2acffb9d9d5cb0a5cfd8802644b7b1c9a02e4/colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44", size = 27697 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/d1/d6/3965ed04c63042e047cb6a3e6ed1a63a35087b6a609aa3a15ed8ac56c221/colorama-0.4.6-py2.py3-none-any.whl", hash = "sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6", size = 25335 }, +] + [[package]] name = "comm" version = "0.2.2" @@ -277,6 +316,19 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/e7/05/c19819d5e3d95294a6f5947fb9b9629efb316b96de511b418c53d245aae6/cycler-0.12.1-py3-none-any.whl", hash = "sha256:85cef7cff222d8644161529808465972e51340599459b8ac3ccbac5a854e0d30", size = 8321 }, ] +[[package]] +name = "databricks-sdk" +version = "0.38.0" +source = { 
registry = "https://pypi.org/simple" } +dependencies = [ + { name = "google-auth", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, + { name = "requests", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/1d/c2/02fd3dad8d25b8b24a69925f79f04902e906212808e267ce0e39462e525b/databricks_sdk-0.38.0.tar.gz", hash = "sha256:65e505201b65d8a2b4110d3eabfebce5a25426d3ccdd5f8bc69eb03333ea1f39", size = 594528 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/ed/0d/d203456eef915a40a90292be49ab87904509606d503eceff7ac8c28dbd8e/databricks_sdk-0.38.0-py3-none-any.whl", hash = "sha256:3cc3808e7a294ccf99a3f19f1e86c8e36a5dc0845ac62112dcae2e625ef97c28", size = 575096 }, +] + [[package]] name = "debugpy" version = "1.8.9" @@ -301,6 +353,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/d5/50/83c593b07763e1161326b3b8c6686f0f4b0f24d5526546bee538c89837d6/decorator-5.1.1-py3-none-any.whl", hash = "sha256:b8c3f85900b9dc423225913c5aace94729fe1fa9763b38939a95226f02d37186", size = 9073 }, ] +[[package]] +name = "deprecated" +version = "1.2.15" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "wrapt", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/2e/a3/53e7d78a6850ffdd394d7048a31a6f14e44900adedf190f9a165f6b69439/deprecated-1.2.15.tar.gz", hash = "sha256:683e561a90de76239796e6b6feac66b99030d2dd3fcf61ef996330f14bbb9b0d", size = 2977612 } +wheels = [ + { url = 
"https://files.pythonhosted.org/packages/1d/8f/c7f227eb42cfeaddce3eb0c96c60cbca37797fa7b34f8e1aeadf6c5c0983/Deprecated-1.2.15-py2.py3-none-any.whl", hash = "sha256:353bc4a8ac4bfc96800ddab349d89c25dec1079f65fd53acdcc1e0b975b21320", size = 9941 }, +] + [[package]] name = "etils" version = "1.11.0" @@ -413,6 +477,44 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/a3/61/8001b38461d751cd1a0c3a6ae84346796a5758123f3ed97a1b121dfbf4f3/gast-0.6.0-py3-none-any.whl", hash = "sha256:52b182313f7330389f72b069ba00f174cfe2a06411099547288839c6cbafbd54", size = 21173 }, ] +[[package]] +name = "gitdb" +version = "4.0.11" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "smmap", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/19/0d/bbb5b5ee188dec84647a4664f3e11b06ade2bde568dbd489d9d64adef8ed/gitdb-4.0.11.tar.gz", hash = "sha256:bf5421126136d6d0af55bc1e7c1af1c397a34f5b7bd79e776cd3e89785c2b04b", size = 394469 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/fd/5b/8f0c4a5bb9fd491c277c21eff7ccae71b47d43c4446c9d0c6cff2fe8c2c4/gitdb-4.0.11-py3-none-any.whl", hash = "sha256:81a3407ddd2ee8df444cbacea00e2d038e40150acfa3001696fe0dcf1d3adfa4", size = 62721 }, +] + +[[package]] +name = "gitpython" +version = "3.1.43" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "gitdb", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/b6/a1/106fd9fa2dd989b6fb36e5893961f82992cf676381707253e0bf93eb1662/GitPython-3.1.43.tar.gz", hash = "sha256:35f314a9f878467f5453cc1fee295c3e18e52f1b99f10f6cf5b1682e968a9e7c", 
size = 214149 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/e9/bd/cc3a402a6439c15c3d4294333e13042b915bbeab54edc457c723931fed3f/GitPython-3.1.43-py3-none-any.whl", hash = "sha256:eec7ec56b92aad751f9912a73404bc02ba212a23adb2c7098ee668417051a1ff", size = 207337 }, +] + +[[package]] +name = "google-auth" +version = "2.36.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "cachetools", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, + { name = "pyasn1-modules", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, + { name = "rsa", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/6a/71/4c5387d8a3e46e3526a8190ae396659484377a73b33030614dd3b28e7ded/google_auth-2.36.0.tar.gz", hash = "sha256:545e9618f2df0bcbb7dcbc45a546485b1212624716975a1ea5ae8149ce769ab1", size = 268336 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/2d/9a/3d5087d27865c2f0431b942b5c4500b7d1b744dd3262fdc973a4c39d099e/google_auth-2.36.0-py2.py3-none-any.whl", hash = "sha256:51a15d47028b66fd36e5c64a82d2d57480075bccc7da37cde257fc94177a61fb", size = 209519 }, +] + [[package]] name = "google-pasta" version = "0.2.0" @@ -495,6 +597,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/76/c6/c88e154df9c4e1a2a66ccf0005a88dfb2650c1dffb6f5ce603dfbd452ce3/idna-3.10-py3-none-any.whl", hash = "sha256:946d195a0d259cbba61165e88e65941f16e9b36ea6ddb97f00452bae8b1287d3", size = 70442 }, ] +[[package]] +name = "importlib-metadata" +version = "8.5.0" 
+source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "zipp", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/cd/12/33e59336dca5be0c398a7482335911a33aa0e20776128f038019f1a95f1b/importlib_metadata-8.5.0.tar.gz", hash = "sha256:71522656f0abace1d072b9e5481a48f07c138e00f079c38c8f883823f9c26bd7", size = 55304 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/a0/d9/a1e041c5e7caa9a05c925f4bdbdfb7f006d1f74996af53467bc394c97be7/importlib_metadata-8.5.0-py3-none-any.whl", hash = "sha256:45e54197d28b7a7f1559e60b95e7c567032b602131fbd588f1497f47880aa68b", size = 26514 }, +] + [[package]] name = "importlib-resources" version = "6.4.5" @@ -588,6 +702,9 @@ dependencies = [ ] [package.optional-dependencies] +mlflow = [ + { name = "mlflow-skinny", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, +] pyright = [ { name = "pyright", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, ] @@ -619,6 +736,7 @@ requires-dist = [ { name = "ipykernel", marker = "extra == 'synopsis'" }, { name = "jax" }, { name = "matplotlib", marker = "extra == 'synopsis'" }, + { name = "mlflow-skinny", marker = "extra == 'mlflow'", specifier = ">=2.0" }, { name = "numpy" }, { name = "packaging" }, { name = "pillow" }, @@ -926,6 +1044,30 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/c7/c6/f89620cecc0581dc1839e218c4315171312e46c62a62da6ace204bda91c0/ml_dtypes-0.4.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = 
"sha256:126e7d679b8676d1a958f2651949fbfa182832c3cd08020d8facd94e4114f3e9", size = 2160488 }, ] +[[package]] +name = "mlflow-skinny" +version = "2.18.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "cachetools", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, + { name = "click", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, + { name = "cloudpickle", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, + { name = "databricks-sdk", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, + { name = "gitpython", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, + { name = "importlib-metadata", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, + { name = "opentelemetry-api", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, + { name = "opentelemetry-sdk", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and 
sys_platform == 'darwin')" }, + { name = "packaging", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, + { name = "protobuf", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, + { name = "pyyaml", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, + { name = "requests", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, + { name = "sqlparse", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/f9/89/3fbcf0e415678029b783d6951373443aa64cb352c4959374f08903710690/mlflow_skinny-2.18.0.tar.gz", hash = "sha256:87e83f56c362a520196b2f0292b24efdca7f8b2068a6a6941f2ec9feb9bfd914", size = 5445516 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/2e/1b/20128a015405fdfda2dce38b975acf19cd532f4c8dc4231fd088fb8553dd/mlflow_skinny-2.18.0-py3-none-any.whl", hash = "sha256:b924730b38cf9a7400737aa3e011c97edf978eed354bb0eb89ccb1f9e42764dc", size = 5793030 }, +] + [[package]] name = "mpmath" version = "1.3.0" @@ -1150,6 +1292,46 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/87/20/199b8713428322a2f22b722c62b8cc278cc53dffa9705d744484b5035ee9/nvidia_nvtx_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl", hash = 
"sha256:781e950d9b9f60d8241ccea575b32f5105a5baf4c2351cab5256a24869f12a1a", size = 99144 }, ] +[[package]] +name = "opentelemetry-api" +version = "1.28.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "deprecated", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, + { name = "importlib-metadata", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/51/34/e4e9245c868c6490a46ffedf6bd5b0f512bbc0a848b19e3a51f6bbad648c/opentelemetry_api-1.28.2.tar.gz", hash = "sha256:ecdc70c7139f17f9b0cf3742d57d7020e3e8315d6cffcdf1a12a905d45b19cc0", size = 62796 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/4d/58/b17393cdfc149e14ee84c662abf921993dcce8058628359ef1f49e2abb97/opentelemetry_api-1.28.2-py3-none-any.whl", hash = "sha256:6fcec89e265beb258fe6b1acaaa3c8c705a934bd977b9f534a2b7c0d2d4275a6", size = 64302 }, +] + +[[package]] +name = "opentelemetry-sdk" +version = "1.28.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "opentelemetry-api", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, + { name = "opentelemetry-semantic-conventions", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, + { name = "typing-extensions", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or 
(platform_machine == 'arm64' and sys_platform == 'darwin')" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/4b/f4/840a5af4efe48d7fb4c456ad60fd624673e871a60d6494f7ff8a934755d4/opentelemetry_sdk-1.28.2.tar.gz", hash = "sha256:5fed24c5497e10df30282456fe2910f83377797511de07d14cec0d3e0a1a3110", size = 157272 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/da/8b/4f2b418496c08016d4384f9b1c4725a8af7faafa248d624be4bb95993ce1/opentelemetry_sdk-1.28.2-py3-none-any.whl", hash = "sha256:93336c129556f1e3ccd21442b94d3521759541521861b2214c499571b85cb71b", size = 118757 }, +] + +[[package]] +name = "opentelemetry-semantic-conventions" +version = "0.49b2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "deprecated", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, + { name = "opentelemetry-api", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/7d/0a/e3b93f94aa3223c6fd8e743502a1fefd4fb3a753d8f501ce2a418f7c0bd4/opentelemetry_semantic_conventions-0.49b2.tar.gz", hash = "sha256:44e32ce6a5bb8d7c0c617f84b9dc1c8deda1045a07dc16a688cc7cbeab679997", size = 95213 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/b1/be/6661c8f76708bb3ba38c90be8fa8d7ffe17ccbc5cbbc229334f5535f6448/opentelemetry_semantic_conventions-0.49b2-py3-none-any.whl", hash = "sha256:51e7e1d0daa958782b6c2a8ed05e5f0e7dd0716fc327ac058777b8659649ee54", size = 159199 }, +] + [[package]] name = "opt-einsum" version = "3.4.0" @@ -1378,6 +1560,27 @@ wheels = [ { url = 
"https://files.pythonhosted.org/packages/8e/37/efad0257dc6e593a18957422533ff0f87ede7c9c6ea010a2177d738fb82f/pure_eval-0.2.3-py3-none-any.whl", hash = "sha256:1db8e35b67b3d218d818ae653e27f06c3aa420901fa7b081ca98cbedc874e0d0", size = 11842 }, ] +[[package]] +name = "pyasn1" +version = "0.6.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/ba/e9/01f1a64245b89f039897cb0130016d79f77d52669aae6ee7b159a6c4c018/pyasn1-0.6.1.tar.gz", hash = "sha256:6f580d2bdd84365380830acf45550f2511469f673cb4a5ae3857a3170128b034", size = 145322 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/c8/f1/d6a797abb14f6283c0ddff96bbdd46937f64122b8c925cab503dd37f8214/pyasn1-0.6.1-py3-none-any.whl", hash = "sha256:0d632f46f2ba09143da3a8afe9e33fb6f92fa2320ab7e886e2d0f7672af84629", size = 83135 }, +] + +[[package]] +name = "pyasn1-modules" +version = "0.4.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "pyasn1", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/1d/67/6afbf0d507f73c32d21084a79946bfcfca5fbc62a72057e9c23797a737c9/pyasn1_modules-0.4.1.tar.gz", hash = "sha256:c28e2dbf9c06ad61c71a075c7e0f9fd0f1b0bb2d2ad4377f240d33ac2ab60a7c", size = 310028 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/77/89/bc88a6711935ba795a679ea6ebee07e128050d6382eaa35a0a47c8032bdc/pyasn1_modules-0.4.1-py3-none-any.whl", hash = "sha256:49bfa96b45a292b711e986f222502c1c9a5e1f4e568fc30e2574a6c7d07838fd", size = 181537 }, +] + [[package]] name = "pycparser" version = "2.22" @@ -1561,6 +1764,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/19/71/39c7c0d87f8d4e6c020a393182060eaefeeae6c01dab6a84ec346f2567df/rich-13.9.4-py3-none-any.whl", hash = 
"sha256:6049d5e6ec054bf2779ab3358186963bac2ea89175919d699e378b99738c2a90", size = 242424 }, ] +[[package]] +name = "rsa" +version = "4.9" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "pyasn1", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/aa/65/7d973b89c4d2351d7fb232c2e452547ddfa243e93131e7cfa766da627b52/rsa-4.9.tar.gz", hash = "sha256:e38464a49c6c85d7f1351b0126661487a7e0a14a50f1675ec50eb34d4f20ef21", size = 29711 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/49/97/fa78e3d2f65c02c8e1268b9aba606569fe97f6c8f7c2d74394553347c145/rsa-4.9-py3-none-any.whl", hash = "sha256:90260d9058e514786967344d0ef75fa8727eed8a7d2e43ce9f4bcf1b536174f7", size = 34315 }, +] + [[package]] name = "scipy" version = "1.14.1" @@ -1653,6 +1868,24 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/b7/ce/149a00dd41f10bc29e5921b496af8b574d8413afcd5e30dfa0ed46c2cc5e/six-1.17.0-py2.py3-none-any.whl", hash = "sha256:4721f391ed90541fddacab5acf947aa0d3dc7d27b2e1e8eda2be8970586c3274", size = 11050 }, ] +[[package]] +name = "smmap" +version = "5.0.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/88/04/b5bf6d21dc4041000ccba7eb17dd3055feb237e7ffc2c20d3fae3af62baa/smmap-5.0.1.tar.gz", hash = "sha256:dceeb6c0028fdb6734471eb07c0cd2aae706ccaecab45965ee83f11c8d3b1f62", size = 22291 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/a7/a5/10f97f73544edcdef54409f1d839f6049a0d79df68adbc1ceb24d1aaca42/smmap-5.0.1-py3-none-any.whl", hash = "sha256:e6d8668fa5f93e706934a62d7b4db19c8d9eb8cf2adbb75ef1b675aa332b69da", size = 24282 }, +] + +[[package]] +name = "sqlparse" +version = "0.5.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = 
"https://files.pythonhosted.org/packages/57/61/5bc3aff85dc5bf98291b37cf469dab74b3d0aef2dd88eade9070a200af05/sqlparse-0.5.2.tar.gz", hash = "sha256:9e37b35e16d1cc652a2545f0997c1deb23ea28fa1f3eefe609eee3063c3b105f", size = 84951 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/7a/13/5f6654c9d915077fae255686ca6fa42095b62b7337e3e1aa9e82caa6f43a/sqlparse-0.5.2-py3-none-any.whl", hash = "sha256:e99bc85c78160918c3e1d9230834ab8d80fc06c59d03f8db2618f65f65dda55e", size = 44407 }, +] + [[package]] name = "stack-data" version = "0.6.3" From 49cb536b1bf97cccd1cde07ef4eed7cfd192e63a Mon Sep 17 00:00:00 2001 From: Gary Miguel Date: Tue, 10 Dec 2024 14:03:06 -0800 Subject: [PATCH 2/3] test mlflow in workflow Change-Id: I650c7cfdaecc3da47df3ed1653790b97e84ee34f --- .github/workflows/checks.yml | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/.github/workflows/checks.yml b/.github/workflows/checks.yml index 0beaab0..93565c2 100644 --- a/.github/workflows/checks.yml +++ b/.github/workflows/checks.yml @@ -16,7 +16,7 @@ jobs: with: enable-cache: true - name: pytest - run: uv run --extra test pytest --capture=no --verbose --cov --cov-report=xml --ignore-glob='jax_loop_utils/metric_writers/tf/*' --ignore-glob='jax_loop_utils/metric_writers/torch/*' jax_loop_utils/ + run: uv run --extra test pytest --capture=no --verbose --cov --cov-report=xml --ignore-glob='jax_loop_utils/metric_writers/tf/*' --ignore-glob='jax_loop_utils/metric_writers/torch/*' --ignore-glob='jax_loop_utils/metric_writers/mlflow/*' jax_loop_utils/ working-directory: src - name: pytest tensorflow run: uv run --extra test --extra tensorflow pytest --capture=no --verbose --cov --cov-report=xml --cov-append jax_loop_utils/metric_writers/tf @@ -24,6 +24,9 @@ jobs: - name: pytest torch run: uv run --extra test --extra torch pytest --capture=no --verbose --cov --cov-report=xml --cov-append jax_loop_utils/metric_writers/torch working-directory: src + - name: pytest mlflow + run: uv run 
--extra test --extra mlflow pytest --capture=no --verbose --cov --cov-report=xml --cov-append jax_loop_utils/metric_writers/mlflow + working-directory: src - name: Upload coverage reports to Codecov if: always() uses: codecov/codecov-action@v4 From 56536f0ba52b17e52f5e67cb437c141e45c1755b Mon Sep 17 00:00:00 2001 From: Gary Miguel Date: Tue, 10 Dec 2024 16:14:54 -0800 Subject: [PATCH 3/3] PR comments: * rename class * use list comprehension Change-Id: I68fe00540b7805cf7b7f520d80333915735d34f2 --- .../metric_writers/mlflow/__init__.py | 4 ++-- .../metric_writers/mlflow/metric_writer.py | 9 +++++---- .../metric_writers/mlflow/metric_writer_test.py | 16 ++++++++-------- 3 files changed, 15 insertions(+), 14 deletions(-) diff --git a/src/jax_loop_utils/metric_writers/mlflow/__init__.py b/src/jax_loop_utils/metric_writers/mlflow/__init__.py index db3025a..5c4f83f 100644 --- a/src/jax_loop_utils/metric_writers/mlflow/__init__.py +++ b/src/jax_loop_utils/metric_writers/mlflow/__init__.py @@ -1,3 +1,3 @@ -from .metric_writer import MetricWriter +from .metric_writer import MlflowMetricWriter -__all__ = ["MetricWriter"] +__all__ = ["MlflowMetricWriter"] diff --git a/src/jax_loop_utils/metric_writers/mlflow/metric_writer.py b/src/jax_loop_utils/metric_writers/mlflow/metric_writer.py index d7397f2..ec47b89 100644 --- a/src/jax_loop_utils/metric_writers/mlflow/metric_writer.py +++ b/src/jax_loop_utils/metric_writers/mlflow/metric_writer.py @@ -19,7 +19,7 @@ ) -class MetricWriter(MetricWriterInterface): +class MlflowMetricWriter(MetricWriterInterface): """MLflow implementation of MetricWriter.""" def __init__(self, experiment_name: str, tracking_uri: str | None = None): @@ -46,10 +46,11 @@ def __init__(self, experiment_name: str, tracking_uri: str | None = None): def write_scalars(self, step: int, scalars: Mapping[str, Scalar]): """Write scalar metrics to MLflow.""" - metrics_list: list[mlflow.entities.Metric] = [] timestamp = int(time() * 1000) - for k, v in scalars.items(): - 
metrics_list.append(mlflow.entities.Metric(k, float(v), timestamp, step)) + metrics_list = [ + mlflow.entities.Metric(k, float(v), timestamp, step) + for k, v in scalars.items() + ] self._client.log_batch(self._run_id, metrics=metrics_list, synchronous=False) def write_images(self, step: int, images: Mapping[str, Array]): diff --git a/src/jax_loop_utils/metric_writers/mlflow/metric_writer_test.py b/src/jax_loop_utils/metric_writers/mlflow/metric_writer_test.py index 7c43649..d951b37 100644 --- a/src/jax_loop_utils/metric_writers/mlflow/metric_writer_test.py +++ b/src/jax_loop_utils/metric_writers/mlflow/metric_writer_test.py @@ -6,7 +6,7 @@ import numpy as np from absl.testing import absltest -from jax_loop_utils.metric_writers.mlflow import MetricWriter +from jax_loop_utils.metric_writers.mlflow import MlflowMetricWriter def _get_runs(tracking_uri: str, experiment_name: str) -> list[mlflow.entities.Run]: @@ -16,12 +16,12 @@ def _get_runs(tracking_uri: str, experiment_name: str) -> list[mlflow.entities.R return client.search_runs([experiment.experiment_id]) -class MetricWriterTest(absltest.TestCase): +class MlflowMetricWriterTest(absltest.TestCase): def test_write_scalars(self): with tempfile.TemporaryDirectory() as temp_dir: tracking_uri = f"file://{temp_dir}" experiment_name = "experiment_name" - writer = MetricWriter(experiment_name, tracking_uri=tracking_uri) + writer = MlflowMetricWriter(experiment_name, tracking_uri=tracking_uri) seq_of_scalars = ( {"a": 3, "b": 0.15}, {"a": 5, "b": 0.007}, @@ -45,7 +45,7 @@ def test_write_scalars(self): run = runs[0] self.assertEqual(run.info.status, "FINISHED") # check we can create a new writer with an existing experiment - writer = MetricWriter(experiment_name, tracking_uri=tracking_uri) + writer = MlflowMetricWriter(experiment_name, tracking_uri=tracking_uri) writer.write_scalars(0, {"a": 1, "b": 2}) writer.flush() writer.close() @@ -57,7 +57,7 @@ def test_write_images(self): with tempfile.TemporaryDirectory() as 
temp_dir: tracking_uri = f"file://{temp_dir}" experiment_name = "experiment_name" - writer = MetricWriter(experiment_name, tracking_uri=tracking_uri) + writer = MlflowMetricWriter(experiment_name, tracking_uri=tracking_uri) writer.write_images(0, {"test_image": np.zeros((3, 3, 3), dtype=np.uint8)}) writer.close() @@ -79,7 +79,7 @@ def test_write_texts(self): with tempfile.TemporaryDirectory() as temp_dir: tracking_uri = f"file://{temp_dir}" experiment_name = "experiment_name" - writer = MetricWriter(experiment_name, tracking_uri=tracking_uri) + writer = MlflowMetricWriter(experiment_name, tracking_uri=tracking_uri) test_text = "Hello world" writer.write_texts(0, {"test_text": test_text}) writer.close() @@ -102,7 +102,7 @@ def test_write_hparams(self): with tempfile.TemporaryDirectory() as temp_dir: tracking_uri = f"file://{temp_dir}" experiment_name = "experiment_name" - writer = MetricWriter(experiment_name, tracking_uri=tracking_uri) + writer = MlflowMetricWriter(experiment_name, tracking_uri=tracking_uri) test_params = {"learning_rate": 0.001, "batch_size": 32, "epochs": 100} writer.write_hparams(test_params) writer.close() @@ -118,7 +118,7 @@ def test_no_ops(self): with tempfile.TemporaryDirectory() as temp_dir: tracking_uri = f"file://{temp_dir}" experiment_name = "experiment_name" - writer = MetricWriter(experiment_name, tracking_uri=tracking_uri) + writer = MlflowMetricWriter(experiment_name, tracking_uri=tracking_uri) writer.write_videos(0, {"video": np.zeros((4, 28, 28, 3))}) writer.write_audios(0, {"audio": np.zeros((2, 1000))}, sample_rate=16000) writer.write_histograms(