diff --git a/src/jax_loop_utils/metric_writers/__init__.py b/src/jax_loop_utils/metric_writers/__init__.py index 7ae955e..892f040 100644 --- a/src/jax_loop_utils/metric_writers/__init__.py +++ b/src/jax_loop_utils/metric_writers/__init__.py @@ -48,7 +48,9 @@ from jax_loop_utils.metric_writers.interface import MetricWriter from jax_loop_utils.metric_writers.keep_last_writer import KeepLastWriter from jax_loop_utils.metric_writers.logging_writer import LoggingWriter +from jax_loop_utils.metric_writers.memory_writer import MemoryWriter from jax_loop_utils.metric_writers.multi_writer import MultiWriter +from jax_loop_utils.metric_writers.prefix_suffix_writer import PrefixSuffixWriter from jax_loop_utils.metric_writers.utils import write_values # TODO(b/200953513): Migrate away from logging imports (on module level) @@ -60,7 +62,9 @@ "ensure_flushes", "KeepLastWriter", "LoggingWriter", + "MemoryWriter", "MetricWriter", "MultiWriter", + "PrefixSuffixWriter", "write_values", ] diff --git a/src/jax_loop_utils/metric_writers/prefix_suffix_writer.py b/src/jax_loop_utils/metric_writers/prefix_suffix_writer.py new file mode 100644 index 0000000..cef4e5e --- /dev/null +++ b/src/jax_loop_utils/metric_writers/prefix_suffix_writer.py @@ -0,0 +1,70 @@ +"""Writer that adds prefix and suffix to metric keys.""" + +from typing import Any, Mapping, Optional + +from jax_loop_utils.metric_writers import interface + + +class PrefixSuffixWriter(interface.MetricWriter): + """Wraps a MetricWriter and adds prefix/suffix to all keys.""" + + def __init__( + self, + writer: interface.MetricWriter, + prefix: str = "", + suffix: str = "", + ): + """Initialize the writer. + + Args: + writer: The underlying MetricWriter to wrap + prefix: String to prepend to all keys + suffix: String to append to all keys + """ + self._writer = writer + self._prefix = prefix + self._suffix = suffix + + def _transform_keys(self, data: Mapping[str, Any]) -> dict[str, Any]: + """Add prefix and suffix to all keys in the mapping.""" + return { + f"{self._prefix}{key}{self._suffix}": value for key, value in data.items() + } + + def write_scalars(self, step: int, scalars: Mapping[str, interface.Scalar]): + self._writer.write_scalars(step, self._transform_keys(scalars)) + + def write_images(self, step: int, images: Mapping[str, interface.Array]): + self._writer.write_images(step, self._transform_keys(images)) + + def write_videos(self, step: int, videos: Mapping[str, interface.Array]): + self._writer.write_videos(step, self._transform_keys(videos)) + + def write_audios( + self, step: int, audios: Mapping[str, interface.Array], *, sample_rate: int + ): + self._writer.write_audios( + step, self._transform_keys(audios), sample_rate=sample_rate + ) + + def write_texts(self, step: int, texts: Mapping[str, str]): + self._writer.write_texts(step, self._transform_keys(texts)) + + def write_histograms( + self, + step: int, + arrays: Mapping[str, interface.Array], + num_buckets: Optional[Mapping[str, int]] = None, + ): + if num_buckets is not None: + num_buckets = self._transform_keys(num_buckets) + self._writer.write_histograms(step, self._transform_keys(arrays), num_buckets) + + def write_hparams(self, hparams: Mapping[str, Any]): + self._writer.write_hparams(self._transform_keys(hparams)) + + def flush(self): + self._writer.flush() + + def close(self): + self._writer.close() diff --git a/src/jax_loop_utils/metric_writers/prefix_suffix_writer_test.py b/src/jax_loop_utils/metric_writers/prefix_suffix_writer_test.py new file mode 100644 index 0000000..b255b10 --- /dev/null +++ b/src/jax_loop_utils/metric_writers/prefix_suffix_writer_test.py @@ -0,0 +1,80 @@ +"""Tests for PrefixSuffixWriter.""" + +from unittest import mock + +import numpy as np +from absl.testing import absltest + +from jax_loop_utils.metric_writers import memory_writer, prefix_suffix_writer + + +class PrefixSuffixWriterTest(absltest.TestCase): + def setUp(self): + super().setUp() + self.memory_writer = memory_writer.MemoryWriter() + self.writer = prefix_suffix_writer.PrefixSuffixWriter( + self.memory_writer, + prefix="prefix/", + suffix="/suffix", + ) + + def test_write_scalars(self): + self.writer.write_scalars(0, {"metric": 1.0}) + self.assertEqual(self.memory_writer.scalars, {0: {"prefix/metric/suffix": 1.0}}) + + def test_write_images(self): + image = np.zeros((2, 2, 3)) + self.writer.write_images(0, {"image": image}) + self.assertEqual( + list(self.memory_writer.images[0].keys()), ["prefix/image/suffix"] + ) + + def test_write_texts(self): + self.writer.write_texts(0, {"text": "hello"}) + self.assertEqual(self.memory_writer.texts, {0: {"prefix/text/suffix": "hello"}}) + + def test_write_histograms(self): + data = np.array([1, 2, 3]) + buckets = {"hist": 10} + self.writer.write_histograms(0, {"hist": data}, buckets) + self.assertEqual( + list(self.memory_writer.histograms[0].arrays.keys()), + ["prefix/hist/suffix"], + ) + + def test_write_hparams(self): + self.writer.write_hparams({"param": 1}) + self.assertEqual(self.memory_writer.hparams, {"prefix/param/suffix": 1}) + + def test_empty_prefix_suffix(self): + writer = prefix_suffix_writer.PrefixSuffixWriter(self.memory_writer) + writer.write_scalars(0, {"metric": 1.0}) + self.assertEqual(self.memory_writer.scalars, {0: {"metric": 1.0}}) + + def test_write_videos(self): + video = np.zeros((10, 32, 32, 3)) # Simple video array with 10 frames + self.writer.write_videos(0, {"video": video}) + self.assertEqual( + list(self.memory_writer.videos[0].keys()), ["prefix/video/suffix"] + ) + + def test_write_audios(self): + audio = np.zeros((16000,)) # 1 second of audio at 16kHz + self.writer.write_audios(0, {"audio": audio}, sample_rate=16000) + self.assertEqual( + list(self.memory_writer.audios[0].audios.keys()), ["prefix/audio/suffix"] + ) + + def test_close(self): + with mock.patch.object(self.memory_writer, "close") as mock_close: + self.writer.close() + mock_close.assert_called_once() + + def test_flush(self): + with mock.patch.object(self.memory_writer, "flush") as mock_flush: + self.writer.flush() + mock_flush.assert_called_once() + + +if __name__ == "__main__": + absltest.main()