diff --git a/pyproject.toml b/pyproject.toml index f39de7c..2dd7f67 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -9,7 +9,7 @@ authors = [{ name = "Astera Institute", email = "no-reply@astera.org" }] dependencies = [ "absl-py", "etils[epath,epy]", - "jax", + "jax>=0.4.36", "numpy", "packaging", "Pillow", diff --git a/src/jax_loop_utils/metric_writers/__init__.py b/src/jax_loop_utils/metric_writers/__init__.py index 716a8f5..892f040 100644 --- a/src/jax_loop_utils/metric_writers/__init__.py +++ b/src/jax_loop_utils/metric_writers/__init__.py @@ -46,8 +46,11 @@ ensure_flushes, ) from jax_loop_utils.metric_writers.interface import MetricWriter +from jax_loop_utils.metric_writers.keep_last_writer import KeepLastWriter from jax_loop_utils.metric_writers.logging_writer import LoggingWriter +from jax_loop_utils.metric_writers.memory_writer import MemoryWriter from jax_loop_utils.metric_writers.multi_writer import MultiWriter +from jax_loop_utils.metric_writers.prefix_suffix_writer import PrefixSuffixWriter from jax_loop_utils.metric_writers.utils import write_values # TODO(b/200953513): Migrate away from logging imports (on module level) @@ -57,8 +60,11 @@ "AsyncMultiWriter", "AsyncWriter", "ensure_flushes", + "KeepLastWriter", "LoggingWriter", + "MemoryWriter", "MetricWriter", "MultiWriter", + "PrefixSuffixWriter", "write_values", ] diff --git a/src/jax_loop_utils/metric_writers/keep_last_writer.py b/src/jax_loop_utils/metric_writers/keep_last_writer.py new file mode 100644 index 0000000..ceadeda --- /dev/null +++ b/src/jax_loop_utils/metric_writers/keep_last_writer.py @@ -0,0 +1,53 @@ +from collections.abc import Mapping +from typing import Any, Optional + +from .interface import Array, MetricWriter, Scalar + + +class KeepLastWriter(MetricWriter): + """MetricWriter that keeps the last value for each metric in memory.""" + + def __init__(self, inner: MetricWriter): + self._inner: MetricWriter = inner + self.scalars: Optional[Mapping[str, Scalar]] = None + self.images: Optional[Mapping[str, Array]] = None + self.videos: Optional[Mapping[str, Array]] = None + self.audios: Optional[Mapping[str, Array]] = None + self.texts: Optional[Mapping[str, str]] = None + self.hparams: Optional[Mapping[str, Any]] = None + self.histogram_arrays: Optional[Mapping[str, Array]] = None + self.histogram_num_buckets: Optional[Mapping[str, int]] = None + + def write_scalars(self, step: int, scalars: Mapping[str, Scalar]): + self._inner.write_scalars(step, scalars) + self.scalars = scalars + + def write_images(self, step: int, images: Mapping[str, Array]): + self._inner.write_images(step, images) + self.images = images + + def write_videos(self, step: int, videos: Mapping[str, Array]): + self._inner.write_videos(step, videos) + self.videos = videos + + def write_audios(self, step: int, audios: Mapping[str, Array], *, sample_rate: int): + self._inner.write_audios(step, audios, sample_rate=sample_rate) + self.audios = audios + + def write_texts(self, step: int, texts: Mapping[str, str]): + self._inner.write_texts(step, texts) + self.texts = texts + + def write_hparams(self, hparams: Mapping[str, Any]): + self._inner.write_hparams(hparams) + self.hparams = hparams + + def write_histograms( + self, + step: int, + arrays: Mapping[str, Array], + num_buckets: Optional[Mapping[str, int]] = None, + ): + self._inner.write_histograms(step, arrays, num_buckets) + self.histogram_arrays = arrays + self.histogram_num_buckets = num_buckets diff --git a/src/jax_loop_utils/metric_writers/keep_last_writer_test.py b/src/jax_loop_utils/metric_writers/keep_last_writer_test.py new file mode 100644 index 0000000..17b313c --- /dev/null +++ b/src/jax_loop_utils/metric_writers/keep_last_writer_test.py @@ -0,0 +1,77 @@ +"""Tests for KeepLastMetricWriter.""" + +import numpy as np +from absl.testing import absltest + +from jax_loop_utils.metric_writers import keep_last_writer, noop_writer + + +class KeepLastWriterTest(absltest.TestCase): + def setUp(self): + super().setUp() + self.writer = keep_last_writer.KeepLastWriter(noop_writer.NoOpWriter()) + + def test_write_scalars(self): + scalars1 = {"metric1": 1.0} + scalars2 = {"metric2": 2.0} + + self.writer.write_scalars(0, scalars1) + self.assertEqual(self.writer.scalars, scalars1) + + self.writer.write_scalars(1, scalars2) + self.assertEqual(self.writer.scalars, scalars2) + + def test_write_images(self): + image = np.zeros((2, 2, 3)) + images = {"image": image} + + self.writer.write_images(0, images) + self.assertEqual(self.writer.images, images) + + def test_write_videos(self): + video = np.zeros((10, 2, 2, 3)) + videos = {"video": video} + + self.writer.write_videos(0, videos) + self.assertEqual(self.writer.videos, videos) + + def test_write_audios(self): + audio = np.zeros((100, 2)) + audios = {"audio": audio} + + self.writer.write_audios(0, audios, sample_rate=44100) + self.assertEqual(self.writer.audios, audios) + + def test_write_texts(self): + texts = {"text": "hello world"} + + self.writer.write_texts(0, texts) + self.assertEqual(self.writer.texts, texts) + + def test_write_hparams(self): + hparams = {"learning_rate": 0.1} + + self.writer.write_hparams(hparams) + self.assertEqual(self.writer.hparams, hparams) + + def test_write_histograms(self): + arrays = {"hist": np.array([1, 2, 3])} + num_buckets = {"hist": 10} + + self.writer.write_histograms(0, arrays, num_buckets) + self.assertEqual(self.writer.histogram_arrays, arrays) + self.assertEqual(self.writer.histogram_num_buckets, num_buckets) + + def test_initial_values_are_none(self): + self.assertIsNone(self.writer.scalars) + self.assertIsNone(self.writer.images) + self.assertIsNone(self.writer.videos) + self.assertIsNone(self.writer.audios) + self.assertIsNone(self.writer.texts) + self.assertIsNone(self.writer.hparams) + self.assertIsNone(self.writer.histogram_arrays) + self.assertIsNone(self.writer.histogram_num_buckets) + + +if __name__ == "__main__": + absltest.main() diff --git a/src/jax_loop_utils/metric_writers/mlflow/metric_writer.py b/src/jax_loop_utils/metric_writers/mlflow/metric_writer.py index ec47b89..508ea5e 100644 --- a/src/jax_loop_utils/metric_writers/mlflow/metric_writer.py +++ b/src/jax_loop_utils/metric_writers/mlflow/metric_writer.py @@ -20,17 +20,22 @@ class MlflowMetricWriter(MetricWriterInterface): - """MLflow implementation of MetricWriter.""" + """Writes metrics to MLflow Tracking.""" - def __init__(self, experiment_name: str, tracking_uri: str | None = None): + def __init__( + self, + experiment_name: str, + run_name: str | None = None, + tracking_uri: str | None = None, + ): """Initialize MLflow writer. Args: experiment_name: Name of the MLflow experiment. - tracking_uri: Address of local or remote tracking server. If not provided, defaults - to the service set by ``mlflow.tracking.set_tracking_uri``. See - `Where Runs Get Recorded <../tracking.html#where-runs-get-recorded>`_ - for more info. + run_name: Name of the MLflow run. + tracking_uri: Address of local or remote tracking server. + Treated the same as arguments to mlflow.set_tracking_uri. + See https://www.mlflow.org/docs/latest/python_api/mlflow.html#mlflow.set_tracking_uri """ self._client = mlflow.MlflowClient(tracking_uri=tracking_uri) experiment = self._client.get_experiment_by_name(experiment_name) @@ -42,7 +47,9 @@ def __init__(self, experiment_name: str, tracking_uri: str | None = None): experiment_name, ) experiment_id = self._client.create_experiment(experiment_name) - self._run_id = self._client.create_run(experiment_id=experiment_id).info.run_id + self._run_id = self._client.create_run( + experiment_id=experiment_id, run_name=run_name + ).info.run_id def write_scalars(self, step: int, scalars: Mapping[str, Scalar]): """Write scalar metrics to MLflow.""" diff --git a/src/jax_loop_utils/metric_writers/prefix_suffix_writer.py b/src/jax_loop_utils/metric_writers/prefix_suffix_writer.py new file mode 100644 index 0000000..cef4e5e --- /dev/null +++ b/src/jax_loop_utils/metric_writers/prefix_suffix_writer.py @@ -0,0 +1,70 @@ +"""Writer that adds prefix and suffix to metric keys.""" + +from typing import Any, Mapping, Optional + +from jax_loop_utils.metric_writers import interface + + +class PrefixSuffixWriter(interface.MetricWriter): + """Wraps a MetricWriter and adds prefix/suffix to all keys.""" + + def __init__( + self, + writer: interface.MetricWriter, + prefix: str = "", + suffix: str = "", + ): + """Initialize the writer. + + Args: + writer: The underlying MetricWriter to wrap + prefix: String to prepend to all keys + suffix: String to append to all keys + """ + self._writer = writer + self._prefix = prefix + self._suffix = suffix + + def _transform_keys(self, data: Mapping[str, Any]) -> dict[str, Any]: + """Add prefix and suffix to all keys in the mapping.""" + return { + f"{self._prefix}{key}{self._suffix}": value for key, value in data.items() + } + + def write_scalars(self, step: int, scalars: Mapping[str, interface.Scalar]): + self._writer.write_scalars(step, self._transform_keys(scalars)) + + def write_images(self, step: int, images: Mapping[str, interface.Array]): + self._writer.write_images(step, self._transform_keys(images)) + + def write_videos(self, step: int, videos: Mapping[str, interface.Array]): + self._writer.write_videos(step, self._transform_keys(videos)) + + def write_audios( + self, step: int, audios: Mapping[str, interface.Array], *, sample_rate: int + ): + self._writer.write_audios( + step, self._transform_keys(audios), sample_rate=sample_rate + ) + + def write_texts(self, step: int, texts: Mapping[str, str]): + self._writer.write_texts(step, self._transform_keys(texts)) + + def write_histograms( + self, + step: int, + arrays: Mapping[str, interface.Array], + num_buckets: Optional[Mapping[str, int]] = None, + ): + if num_buckets is not None: + num_buckets = self._transform_keys(num_buckets) + self._writer.write_histograms(step, self._transform_keys(arrays), num_buckets) + + def write_hparams(self, hparams: Mapping[str, Any]): + self._writer.write_hparams(self._transform_keys(hparams)) + + def flush(self): + self._writer.flush() + + def close(self): + self._writer.close() diff --git a/src/jax_loop_utils/metric_writers/prefix_suffix_writer_test.py b/src/jax_loop_utils/metric_writers/prefix_suffix_writer_test.py new file mode 100644 index 0000000..b255b10 --- /dev/null +++ b/src/jax_loop_utils/metric_writers/prefix_suffix_writer_test.py @@ -0,0 +1,80 @@ +"""Tests for PrefixSuffixWriter.""" + +from unittest import mock + +import numpy as np +from absl.testing import absltest + +from jax_loop_utils.metric_writers import memory_writer, prefix_suffix_writer + + +class PrefixSuffixWriterTest(absltest.TestCase): + def setUp(self): + super().setUp() + self.memory_writer = memory_writer.MemoryWriter() + self.writer = prefix_suffix_writer.PrefixSuffixWriter( + self.memory_writer, + prefix="prefix/", + suffix="/suffix", + ) + + def test_write_scalars(self): + self.writer.write_scalars(0, {"metric": 1.0}) + self.assertEqual(self.memory_writer.scalars, {0: {"prefix/metric/suffix": 1.0}}) + + def test_write_images(self): + image = np.zeros((2, 2, 3)) + self.writer.write_images(0, {"image": image}) + self.assertEqual( + list(self.memory_writer.images[0].keys()), ["prefix/image/suffix"] + ) + + def test_write_texts(self): + self.writer.write_texts(0, {"text": "hello"}) + self.assertEqual(self.memory_writer.texts, {0: {"prefix/text/suffix": "hello"}}) + + def test_write_histograms(self): + data = np.array([1, 2, 3]) + buckets = {"hist": 10} + self.writer.write_histograms(0, {"hist": data}, buckets) + self.assertEqual( + list(self.memory_writer.histograms[0].arrays.keys()), + ["prefix/hist/suffix"], + ) + + def test_write_hparams(self): + self.writer.write_hparams({"param": 1}) + self.assertEqual(self.memory_writer.hparams, {"prefix/param/suffix": 1}) + + def test_empty_prefix_suffix(self): + writer = prefix_suffix_writer.PrefixSuffixWriter(self.memory_writer) + writer.write_scalars(0, {"metric": 1.0}) + self.assertEqual(self.memory_writer.scalars, {0: {"metric": 1.0}}) + + def test_write_videos(self): + video = np.zeros((10, 32, 32, 3)) # Simple video array with 10 frames + self.writer.write_videos(0, {"video": video}) + self.assertEqual( + list(self.memory_writer.videos[0].keys()), ["prefix/video/suffix"] + ) + + def test_write_audios(self): + audio = np.zeros((16000,)) # 1 second of audio at 16kHz + self.writer.write_audios(0, {"audio": audio}, sample_rate=16000) + self.assertEqual( + list(self.memory_writer.audios[0].audios.keys()), ["prefix/audio/suffix"] + ) + + def test_close(self): + with mock.patch.object(self.memory_writer, "close") as mock_close: + self.writer.close() + mock_close.assert_called_once() + + def test_flush(self): + with mock.patch.object(self.memory_writer, "flush") as mock_flush: + self.writer.flush() + mock_flush.assert_called_once() + + +if __name__ == "__main__": + absltest.main() diff --git a/uv.lock b/uv.lock index f2ec1e3..6d87ecb 100644 --- a/uv.lock +++ b/uv.lock @@ -734,7 +734,7 @@ requires-dist = [ { name = "etils", extras = ["epath", "epy"] }, { name = "flax", marker = "extra == 'synopsis'" }, { name = "ipykernel", marker = "extra == 'synopsis'" }, - { name = "jax" }, + { name = "jax", specifier = ">=0.4.36" }, { name = "matplotlib", marker = "extra == 'synopsis'" }, { name = "mlflow-skinny", marker = "extra == 'mlflow'", specifier = ">=2.0" }, { name = "numpy" },