Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

changes necessary for usage in earl #13

Merged
merged 4 commits into from
Dec 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ authors = [{ name = "Astera Institute", email = "[email protected]" }]
dependencies = [
"absl-py",
"etils[epath,epy]",
"jax",
"jax>=0.4.36",
"numpy",
"packaging",
"Pillow",
Expand Down
6 changes: 6 additions & 0 deletions src/jax_loop_utils/metric_writers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,11 @@
ensure_flushes,
)
from jax_loop_utils.metric_writers.interface import MetricWriter
from jax_loop_utils.metric_writers.keep_last_writer import KeepLastWriter
from jax_loop_utils.metric_writers.logging_writer import LoggingWriter
from jax_loop_utils.metric_writers.memory_writer import MemoryWriter
from jax_loop_utils.metric_writers.multi_writer import MultiWriter
from jax_loop_utils.metric_writers.prefix_suffix_writer import PrefixSuffixWriter
from jax_loop_utils.metric_writers.utils import write_values

# TODO(b/200953513): Migrate away from logging imports (on module level)
Expand All @@ -57,8 +60,11 @@
"AsyncMultiWriter",
"AsyncWriter",
"ensure_flushes",
"KeepLastWriter",
"LoggingWriter",
"MemoryWriter",
"MetricWriter",
"MultiWriter",
"PrefixSuffixWriter",
"write_values",
]
53 changes: 53 additions & 0 deletions src/jax_loop_utils/metric_writers/keep_last_writer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
from collections.abc import Mapping
from typing import Any, Optional

from .interface import Array, MetricWriter, Scalar


class KeepLastWriter(MetricWriter):
"""MetricWriter that keeps the last value for each metric in memory."""

def __init__(self, inner: MetricWriter):
self._inner: MetricWriter = inner
self.scalars: Optional[Mapping[str, Scalar]] = None
self.images: Optional[Mapping[str, Array]] = None
self.videos: Optional[Mapping[str, Array]] = None
self.audios: Optional[Mapping[str, Array]] = None
self.texts: Optional[Mapping[str, str]] = None
self.hparams: Optional[Mapping[str, Any]] = None
self.histogram_arrays: Optional[Mapping[str, Array]] = None
self.histogram_num_buckets: Optional[Mapping[str, int]] = None

def write_scalars(self, step: int, scalars: Mapping[str, Scalar]):
self._inner.write_scalars(step, scalars)
self.scalars = scalars

def write_images(self, step: int, images: Mapping[str, Array]):
self._inner.write_images(step, images)
self.images = images

def write_videos(self, step: int, videos: Mapping[str, Array]):
self._inner.write_videos(step, videos)
self.videos = videos

def write_audios(self, step: int, audios: Mapping[str, Array], *, sample_rate: int):
self._inner.write_audios(step, audios, sample_rate=sample_rate)
self.audios = audios

def write_texts(self, step: int, texts: Mapping[str, str]):
self._inner.write_texts(step, texts)
self.texts = texts

def write_hparams(self, hparams: Mapping[str, Any]):
self._inner.write_hparams(hparams)
self.hparams = hparams

def write_histograms(
self,
step: int,
arrays: Mapping[str, Array],
num_buckets: Optional[Mapping[str, int]] = None,
):
self._inner.write_histograms(step, arrays, num_buckets)
self.histogram_arrays = arrays
self.histogram_num_buckets = num_buckets
77 changes: 77 additions & 0 deletions src/jax_loop_utils/metric_writers/keep_last_writer_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
"""Tests for KeepLastMetricWriter."""

import numpy as np
from absl.testing import absltest

from jax_loop_utils.metric_writers import keep_last_writer, noop_writer


class KeepLastWriterTest(absltest.TestCase):
def setUp(self):
super().setUp()
self.writer = keep_last_writer.KeepLastWriter(noop_writer.NoOpWriter())

def test_write_scalars(self):
scalars1 = {"metric1": 1.0}
scalars2 = {"metric2": 2.0}

self.writer.write_scalars(0, scalars1)
self.assertEqual(self.writer.scalars, scalars1)

self.writer.write_scalars(1, scalars2)
self.assertEqual(self.writer.scalars, scalars2)

def test_write_images(self):
image = np.zeros((2, 2, 3))
images = {"image": image}

self.writer.write_images(0, images)
self.assertEqual(self.writer.images, images)

def test_write_videos(self):
video = np.zeros((10, 2, 2, 3))
videos = {"video": video}

self.writer.write_videos(0, videos)
self.assertEqual(self.writer.videos, videos)

def test_write_audios(self):
audio = np.zeros((100, 2))
audios = {"audio": audio}

self.writer.write_audios(0, audios, sample_rate=44100)
self.assertEqual(self.writer.audios, audios)

def test_write_texts(self):
texts = {"text": "hello world"}

self.writer.write_texts(0, texts)
self.assertEqual(self.writer.texts, texts)

def test_write_hparams(self):
hparams = {"learning_rate": 0.1}

self.writer.write_hparams(hparams)
self.assertEqual(self.writer.hparams, hparams)

def test_write_histograms(self):
arrays = {"hist": np.array([1, 2, 3])}
num_buckets = {"hist": 10}

self.writer.write_histograms(0, arrays, num_buckets)
self.assertEqual(self.writer.histogram_arrays, arrays)
self.assertEqual(self.writer.histogram_num_buckets, num_buckets)

def test_initial_values_are_none(self):
self.assertIsNone(self.writer.scalars)
self.assertIsNone(self.writer.images)
self.assertIsNone(self.writer.videos)
self.assertIsNone(self.writer.audios)
self.assertIsNone(self.writer.texts)
self.assertIsNone(self.writer.hparams)
self.assertIsNone(self.writer.histogram_arrays)
self.assertIsNone(self.writer.histogram_num_buckets)


if __name__ == "__main__":
absltest.main()

Check warning on line 77 in src/jax_loop_utils/metric_writers/keep_last_writer_test.py

View check run for this annotation

Codecov / codecov/patch

src/jax_loop_utils/metric_writers/keep_last_writer_test.py#L77

Added line #L77 was not covered by tests
21 changes: 14 additions & 7 deletions src/jax_loop_utils/metric_writers/mlflow/metric_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,17 +20,22 @@


class MlflowMetricWriter(MetricWriterInterface):
"""MLflow implementation of MetricWriter."""
"""Writes metrics to MLflow Tracking."""

def __init__(self, experiment_name: str, tracking_uri: str | None = None):
def __init__(
self,
experiment_name: str,
run_name: str | None = None,
tracking_uri: str | None = None,
):
"""Initialize MLflow writer.

Args:
experiment_name: Name of the MLflow experiment.
tracking_uri: Address of local or remote tracking server. If not provided, defaults
to the service set by ``mlflow.tracking.set_tracking_uri``. See
`Where Runs Get Recorded <../tracking.html#where-runs-get-recorded>`_
for more info.
run_name: Name of the MLflow run.
tracking_uri: Address of local or remote tracking server.
Treated the same as arguments to mlflow.set_tracking_uri.
See https://www.mlflow.org/docs/latest/python_api/mlflow.html#mlflow.set_tracking_uri
"""
self._client = mlflow.MlflowClient(tracking_uri=tracking_uri)
experiment = self._client.get_experiment_by_name(experiment_name)
Expand All @@ -42,7 +47,9 @@ def __init__(self, experiment_name: str, tracking_uri: str | None = None):
experiment_name,
)
experiment_id = self._client.create_experiment(experiment_name)
self._run_id = self._client.create_run(experiment_id=experiment_id).info.run_id
self._run_id = self._client.create_run(
experiment_id=experiment_id, run_name=run_name
).info.run_id

def write_scalars(self, step: int, scalars: Mapping[str, Scalar]):
"""Write scalar metrics to MLflow."""
Expand Down
70 changes: 70 additions & 0 deletions src/jax_loop_utils/metric_writers/prefix_suffix_writer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
"""Writer that adds prefix and suffix to metric keys."""

from typing import Any, Mapping, Optional

from jax_loop_utils.metric_writers import interface


class PrefixSuffixWriter(interface.MetricWriter):
"""Wraps a MetricWriter and adds prefix/suffix to all keys."""

def __init__(
self,
writer: interface.MetricWriter,
prefix: str = "",
suffix: str = "",
):
"""Initialize the writer.

Args:
writer: The underlying MetricWriter to wrap
prefix: String to prepend to all keys
suffix: String to append to all keys
"""

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

probably should assert prefix | suffix

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The user will get exactly what they asked for which will come at a small performance cost. I'd rather "define errors out of existence".

self._writer = writer
self._prefix = prefix
self._suffix = suffix

def _transform_keys(self, data: Mapping[str, Any]) -> dict[str, Any]:
"""Add prefix and suffix to all keys in the mapping."""
return {
f"{self._prefix}{key}{self._suffix}": value for key, value in data.items()
}
mickvangelderen marked this conversation as resolved.
Show resolved Hide resolved

def write_scalars(self, step: int, scalars: Mapping[str, interface.Scalar]):
self._writer.write_scalars(step, self._transform_keys(scalars))

def write_images(self, step: int, images: Mapping[str, interface.Array]):
self._writer.write_images(step, self._transform_keys(images))

def write_videos(self, step: int, videos: Mapping[str, interface.Array]):
self._writer.write_videos(step, self._transform_keys(videos))

def write_audios(
self, step: int, audios: Mapping[str, interface.Array], *, sample_rate: int
):
self._writer.write_audios(
step, self._transform_keys(audios), sample_rate=sample_rate
)

def write_texts(self, step: int, texts: Mapping[str, str]):
self._writer.write_texts(step, self._transform_keys(texts))

def write_histograms(
self,
step: int,
arrays: Mapping[str, interface.Array],
num_buckets: Optional[Mapping[str, int]] = None,
):
if num_buckets is not None:
num_buckets = self._transform_keys(num_buckets)
self._writer.write_histograms(step, self._transform_keys(arrays), num_buckets)

def write_hparams(self, hparams: Mapping[str, Any]):
self._writer.write_hparams(self._transform_keys(hparams))

def flush(self):
self._writer.flush()

def close(self):
self._writer.close()
80 changes: 80 additions & 0 deletions src/jax_loop_utils/metric_writers/prefix_suffix_writer_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
"""Tests for PrefixSuffixWriter."""

from unittest import mock

import numpy as np
from absl.testing import absltest

from jax_loop_utils.metric_writers import memory_writer, prefix_suffix_writer


class PrefixSuffixWriterTest(absltest.TestCase):
def setUp(self):
super().setUp()
self.memory_writer = memory_writer.MemoryWriter()
self.writer = prefix_suffix_writer.PrefixSuffixWriter(
self.memory_writer,
prefix="prefix/",
suffix="/suffix",
)

def test_write_scalars(self):
self.writer.write_scalars(0, {"metric": 1.0})
self.assertEqual(self.memory_writer.scalars, {0: {"prefix/metric/suffix": 1.0}})

def test_write_images(self):
image = np.zeros((2, 2, 3))
self.writer.write_images(0, {"image": image})
self.assertEqual(
list(self.memory_writer.images[0].keys()), ["prefix/image/suffix"]
)

def test_write_texts(self):
self.writer.write_texts(0, {"text": "hello"})
self.assertEqual(self.memory_writer.texts, {0: {"prefix/text/suffix": "hello"}})

def test_write_histograms(self):
data = np.array([1, 2, 3])
buckets = {"hist": 10}
self.writer.write_histograms(0, {"hist": data}, buckets)
self.assertEqual(
list(self.memory_writer.histograms[0].arrays.keys()),
["prefix/hist/suffix"],
)

def test_write_hparams(self):
self.writer.write_hparams({"param": 1})
self.assertEqual(self.memory_writer.hparams, {"prefix/param/suffix": 1})

def test_empty_prefix_suffix(self):
writer = prefix_suffix_writer.PrefixSuffixWriter(self.memory_writer)
writer.write_scalars(0, {"metric": 1.0})
self.assertEqual(self.memory_writer.scalars, {0: {"metric": 1.0}})

def test_write_videos(self):
video = np.zeros((10, 32, 32, 3)) # Simple video array with 10 frames
self.writer.write_videos(0, {"video": video})
self.assertEqual(
list(self.memory_writer.videos[0].keys()), ["prefix/video/suffix"]
)

def test_write_audios(self):
audio = np.zeros((16000,)) # 1 second of audio at 16kHz
self.writer.write_audios(0, {"audio": audio}, sample_rate=16000)
self.assertEqual(
list(self.memory_writer.audios[0].audios.keys()), ["prefix/audio/suffix"]
)

def test_close(self):
with mock.patch.object(self.memory_writer, "close") as mock_close:
self.writer.close()
mock_close.assert_called_once()

def test_flush(self):
with mock.patch.object(self.memory_writer, "flush") as mock_flush:
self.writer.flush()
mock_flush.assert_called_once()


if __name__ == "__main__":
absltest.main()

Check warning on line 80 in src/jax_loop_utils/metric_writers/prefix_suffix_writer_test.py

View check run for this annotation

Codecov / codecov/patch

src/jax_loop_utils/metric_writers/prefix_suffix_writer_test.py#L80

Added line #L80 was not covered by tests
2 changes: 1 addition & 1 deletion uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.