-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Change-Id: I2a6bf1d675b60d7862726f922c6f0c13836d0124
- Loading branch information
Showing
3 changed files
with
142 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,70 @@ | ||
"""Writer that adds prefix and suffix to metric keys.""" | ||
|
||
from typing import Any, Mapping, Optional | ||
|
||
from jax_loop_utils.metric_writers import interface | ||
|
||
|
||
class PrefixSuffixWriter(interface.MetricWriter): | ||
"""Wraps a MetricWriter and adds prefix/suffix to all keys.""" | ||
|
||
def __init__( | ||
self, | ||
writer: interface.MetricWriter, | ||
prefix: str = "", | ||
suffix: str = "", | ||
): | ||
"""Initialize the writer. | ||
Args: | ||
writer: The underlying MetricWriter to wrap | ||
prefix: String to prepend to all keys | ||
suffix: String to append to all keys | ||
""" | ||
self._writer = writer | ||
self._prefix = prefix | ||
self._suffix = suffix | ||
|
||
def _transform_keys(self, data: Mapping[str, Any]) -> dict[str, Any]: | ||
"""Add prefix and suffix to all keys in the mapping.""" | ||
return { | ||
f"{self._prefix}{key}{self._suffix}": value for key, value in data.items() | ||
} | ||
|
||
def write_scalars(self, step: int, scalars: Mapping[str, interface.Scalar]): | ||
self._writer.write_scalars(step, self._transform_keys(scalars)) | ||
|
||
def write_images(self, step: int, images: Mapping[str, interface.Array]): | ||
self._writer.write_images(step, self._transform_keys(images)) | ||
|
||
def write_videos(self, step: int, videos: Mapping[str, interface.Array]): | ||
self._writer.write_videos(step, self._transform_keys(videos)) | ||
|
||
def write_audios( | ||
self, step: int, audios: Mapping[str, interface.Array], *, sample_rate: int | ||
): | ||
self._writer.write_audios( | ||
step, self._transform_keys(audios), sample_rate=sample_rate | ||
) | ||
|
||
def write_texts(self, step: int, texts: Mapping[str, str]): | ||
self._writer.write_texts(step, self._transform_keys(texts)) | ||
|
||
def write_histograms( | ||
self, | ||
step: int, | ||
arrays: Mapping[str, interface.Array], | ||
num_buckets: Optional[Mapping[str, int]] = None, | ||
): | ||
if num_buckets is not None: | ||
num_buckets = self._transform_keys(num_buckets) | ||
self._writer.write_histograms(step, self._transform_keys(arrays), num_buckets) | ||
|
||
def write_hparams(self, hparams: Mapping[str, Any]): | ||
self._writer.write_hparams(self._transform_keys(hparams)) | ||
|
||
def flush(self): | ||
self._writer.flush() | ||
|
||
def close(self): | ||
self._writer.close() | ||
68 changes: 68 additions & 0 deletions
68
src/jax_loop_utils/metric_writers/prefix_suffix_writer_test.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,68 @@ | ||
"""Tests for PrefixSuffixWriter.""" | ||
|
||
import numpy as np | ||
from absl.testing import absltest | ||
|
||
from jax_loop_utils.metric_writers import memory_writer, prefix_suffix_writer | ||
|
||
|
||
class PrefixSuffixWriterTest(absltest.TestCase): | ||
def setUp(self): | ||
super().setUp() | ||
self.memory_writer = memory_writer.MemoryWriter() | ||
self.writer = prefix_suffix_writer.PrefixSuffixWriter( | ||
self.memory_writer, | ||
prefix="prefix/", | ||
suffix="/suffix", | ||
) | ||
|
||
def test_write_scalars(self): | ||
self.writer.write_scalars(0, {"metric": 1.0}) | ||
self.assertEqual(self.memory_writer.scalars, {0: {"prefix/metric/suffix": 1.0}}) | ||
|
||
def test_write_images(self): | ||
image = np.zeros((2, 2, 3)) | ||
self.writer.write_images(0, {"image": image}) | ||
self.assertEqual( | ||
list(self.memory_writer.images[0].keys()), ["prefix/image/suffix"] | ||
) | ||
|
||
def test_write_texts(self): | ||
self.writer.write_texts(0, {"text": "hello"}) | ||
self.assertEqual(self.memory_writer.texts, {0: {"prefix/text/suffix": "hello"}}) | ||
|
||
def test_write_histograms(self): | ||
data = np.array([1, 2, 3]) | ||
buckets = {"hist": 10} | ||
self.writer.write_histograms(0, {"hist": data}, buckets) | ||
self.assertEqual( | ||
list(self.memory_writer.histograms[0].arrays.keys()), | ||
["prefix/hist/suffix"], | ||
) | ||
|
||
def test_write_hparams(self): | ||
self.writer.write_hparams({"param": 1}) | ||
self.assertEqual(self.memory_writer.hparams, {"prefix/param/suffix": 1}) | ||
|
||
def test_empty_prefix_suffix(self): | ||
writer = prefix_suffix_writer.PrefixSuffixWriter(self.memory_writer) | ||
writer.write_scalars(0, {"metric": 1.0}) | ||
self.assertEqual(self.memory_writer.scalars, {0: {"metric": 1.0}}) | ||
|
||
def test_write_videos(self): | ||
video = np.zeros((10, 32, 32, 3)) # Simple video array with 10 frames | ||
self.writer.write_videos(0, {"video": video}) | ||
self.assertEqual( | ||
list(self.memory_writer.videos[0].keys()), ["prefix/video/suffix"] | ||
) | ||
|
||
def test_write_audios(self): | ||
audio = np.zeros((16000,)) # 1 second of audio at 16kHz | ||
self.writer.write_audios(0, {"audio": audio}, sample_rate=16000) | ||
self.assertEqual( | ||
list(self.memory_writer.audios[0].audios.keys()), ["prefix/audio/suffix"] | ||
) | ||
|
||
|
||
if __name__ == "__main__": | ||
absltest.main() | ||