Skip to content

Commit

Permalink
add prefix_suffix_writer
Browse files Browse the repository at this point in the history
Change-Id: I2a6bf1d675b60d7862726f922c6f0c13836d0124
  • Loading branch information
garymm committed Dec 13, 2024
1 parent 9393020 commit d98c2c8
Show file tree
Hide file tree
Showing 3 changed files with 142 additions and 0 deletions.
4 changes: 4 additions & 0 deletions src/jax_loop_utils/metric_writers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,9 @@
from jax_loop_utils.metric_writers.interface import MetricWriter
from jax_loop_utils.metric_writers.keep_last_writer import KeepLastWriter
from jax_loop_utils.metric_writers.logging_writer import LoggingWriter
from jax_loop_utils.metric_writers.memory_writer import MemoryWriter
from jax_loop_utils.metric_writers.multi_writer import MultiWriter
from jax_loop_utils.metric_writers.prefix_suffix_writer import PrefixSuffixWriter
from jax_loop_utils.metric_writers.utils import write_values

# TODO(b/200953513): Migrate away from logging imports (on module level)
Expand All @@ -60,7 +62,9 @@
"ensure_flushes",
"KeepLastWriter",
"LoggingWriter",
"MemoryWriter",
"MetricWriter",
"MultiWriter",
"PrefixSuffixWriter",
"write_values",
]
70 changes: 70 additions & 0 deletions src/jax_loop_utils/metric_writers/prefix_suffix_writer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
"""Writer that adds prefix and suffix to metric keys."""

from typing import Any, Mapping, Optional

from jax_loop_utils.metric_writers import interface


class PrefixSuffixWriter(interface.MetricWriter):
"""Wraps a MetricWriter and adds prefix/suffix to all keys."""

def __init__(
self,
writer: interface.MetricWriter,
prefix: str = "",
suffix: str = "",
):
"""Initialize the writer.
Args:
writer: The underlying MetricWriter to wrap
prefix: String to prepend to all keys
suffix: String to append to all keys
"""
self._writer = writer
self._prefix = prefix
self._suffix = suffix

def _transform_keys(self, data: Mapping[str, Any]) -> dict[str, Any]:
"""Add prefix and suffix to all keys in the mapping."""
return {
f"{self._prefix}{key}{self._suffix}": value for key, value in data.items()
}

def write_scalars(self, step: int, scalars: Mapping[str, interface.Scalar]):
self._writer.write_scalars(step, self._transform_keys(scalars))

def write_images(self, step: int, images: Mapping[str, interface.Array]):
self._writer.write_images(step, self._transform_keys(images))

def write_videos(self, step: int, videos: Mapping[str, interface.Array]):
self._writer.write_videos(step, self._transform_keys(videos))

def write_audios(
self, step: int, audios: Mapping[str, interface.Array], *, sample_rate: int
):
self._writer.write_audios(
step, self._transform_keys(audios), sample_rate=sample_rate
)

def write_texts(self, step: int, texts: Mapping[str, str]):
self._writer.write_texts(step, self._transform_keys(texts))

def write_histograms(
self,
step: int,
arrays: Mapping[str, interface.Array],
num_buckets: Optional[Mapping[str, int]] = None,
):
if num_buckets is not None:
num_buckets = self._transform_keys(num_buckets)
self._writer.write_histograms(step, self._transform_keys(arrays), num_buckets)

def write_hparams(self, hparams: Mapping[str, Any]):
self._writer.write_hparams(self._transform_keys(hparams))

def flush(self):
self._writer.flush()

Check warning on line 67 in src/jax_loop_utils/metric_writers/prefix_suffix_writer.py

View check run for this annotation

Codecov / codecov/patch

src/jax_loop_utils/metric_writers/prefix_suffix_writer.py#L67

Added line #L67 was not covered by tests

def close(self):
self._writer.close()

Check warning on line 70 in src/jax_loop_utils/metric_writers/prefix_suffix_writer.py

View check run for this annotation

Codecov / codecov/patch

src/jax_loop_utils/metric_writers/prefix_suffix_writer.py#L70

Added line #L70 was not covered by tests
68 changes: 68 additions & 0 deletions src/jax_loop_utils/metric_writers/prefix_suffix_writer_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
"""Tests for PrefixSuffixWriter."""

import numpy as np
from absl.testing import absltest

from jax_loop_utils.metric_writers import memory_writer, prefix_suffix_writer


class PrefixSuffixWriterTest(absltest.TestCase):
def setUp(self):
super().setUp()
self.memory_writer = memory_writer.MemoryWriter()
self.writer = prefix_suffix_writer.PrefixSuffixWriter(
self.memory_writer,
prefix="prefix/",
suffix="/suffix",
)

def test_write_scalars(self):
self.writer.write_scalars(0, {"metric": 1.0})
self.assertEqual(self.memory_writer.scalars, {0: {"prefix/metric/suffix": 1.0}})

def test_write_images(self):
image = np.zeros((2, 2, 3))
self.writer.write_images(0, {"image": image})
self.assertEqual(
list(self.memory_writer.images[0].keys()), ["prefix/image/suffix"]
)

def test_write_texts(self):
self.writer.write_texts(0, {"text": "hello"})
self.assertEqual(self.memory_writer.texts, {0: {"prefix/text/suffix": "hello"}})

def test_write_histograms(self):
data = np.array([1, 2, 3])
buckets = {"hist": 10}
self.writer.write_histograms(0, {"hist": data}, buckets)
self.assertEqual(
list(self.memory_writer.histograms[0].arrays.keys()),
["prefix/hist/suffix"],
)

def test_write_hparams(self):
self.writer.write_hparams({"param": 1})
self.assertEqual(self.memory_writer.hparams, {"prefix/param/suffix": 1})

def test_empty_prefix_suffix(self):
writer = prefix_suffix_writer.PrefixSuffixWriter(self.memory_writer)
writer.write_scalars(0, {"metric": 1.0})
self.assertEqual(self.memory_writer.scalars, {0: {"metric": 1.0}})

def test_write_videos(self):
video = np.zeros((10, 32, 32, 3)) # Simple video array with 10 frames
self.writer.write_videos(0, {"video": video})
self.assertEqual(
list(self.memory_writer.videos[0].keys()), ["prefix/video/suffix"]
)

def test_write_audios(self):
audio = np.zeros((16000,)) # 1 second of audio at 16kHz
self.writer.write_audios(0, {"audio": audio}, sample_rate=16000)
self.assertEqual(
list(self.memory_writer.audios[0].audios.keys()), ["prefix/audio/suffix"]
)


if __name__ == "__main__":
absltest.main()

Check warning on line 68 in src/jax_loop_utils/metric_writers/prefix_suffix_writer_test.py

View check run for this annotation

Codecov / codecov/patch

src/jax_loop_utils/metric_writers/prefix_suffix_writer_test.py#L68

Added line #L68 was not covered by tests

0 comments on commit d98c2c8

Please sign in to comment.