From 884d6732742a4413177aefa0d9cd7b8d37546824 Mon Sep 17 00:00:00 2001 From: Gary Miguel Date: Fri, 20 Dec 2024 17:19:41 -0800 Subject: [PATCH] support writing videos to mlflow --- pyproject.toml | 4 +- .../metric_writers/_audio_video/__init__.py | 8 ++ .../_audio_video/audio_video.py | 63 ++++++++++++++++ .../_audio_video/audio_video_test.py | 74 +++++++++++++++++++ .../metric_writers/mlflow/metric_writer.py | 49 ++++++++---- .../mlflow/metric_writer_test.py | 18 ++++- uv.lock | 31 +++++++- 7 files changed, 229 insertions(+), 18 deletions(-) create mode 100644 src/jax_loop_utils/metric_writers/_audio_video/__init__.py create mode 100644 src/jax_loop_utils/metric_writers/_audio_video/audio_video.py create mode 100644 src/jax_loop_utils/metric_writers/_audio_video/audio_video_test.py diff --git a/pyproject.toml b/pyproject.toml index 2dd7f67..36c937c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -12,7 +12,6 @@ dependencies = [ "jax>=0.4.36", "numpy", "packaging", - "Pillow", "wrapt", ] classifiers = [ @@ -29,13 +28,14 @@ keywords = ["JAX", "machine learning"] Homepage = "http://github.com/Astera-org/jax_loop_utils" [project.optional-dependencies] -mlflow = ["mlflow-skinny>=2.0"] +mlflow = ["mlflow-skinny>=2.0", "Pillow"] pyright = ["pyright"] # for synopsis.ipynb synopsis = ["chex", "flax", "ipykernel", "matplotlib"] tensorflow = ["tensorflow>=2.12"] test = ["chex", "pytest", "pytest-cov"] torch = ["torch>=2.0"] +audio-video = ["av>=14.0"] [build-system] requires = ["hatchling"] diff --git a/src/jax_loop_utils/metric_writers/_audio_video/__init__.py b/src/jax_loop_utils/metric_writers/_audio_video/__init__.py new file mode 100644 index 0000000..dc013c9 --- /dev/null +++ b/src/jax_loop_utils/metric_writers/_audio_video/__init__.py @@ -0,0 +1,8 @@ +from jax_loop_utils.metric_writers._audio_video.audio_video import ( + CODEC, + CONTAINER_FORMAT, + FPS, + encode_video, +) + +__all__ = ["encode_video", "CONTAINER_FORMAT", "CODEC", "FPS"] diff --git a/src/jax_loop_utils/metric_writers/_audio_video/audio_video.py b/src/jax_loop_utils/metric_writers/_audio_video/audio_video.py new file mode 100644 index 0000000..2598aa2 --- /dev/null +++ b/src/jax_loop_utils/metric_writers/_audio_video/audio_video.py @@ -0,0 +1,63 @@ +"""Utilities for audio and video. + +Requires additional dependencies, part of the `audio-video` extra. +""" + +import io + +import av +import numpy as np + +from jax_loop_utils.metric_writers.interface import ( + Array, +) + +CONTAINER_FORMAT = "mp4" +CODEC = "h264" +FPS = 10 + + +def encode_video(video_array: Array, destination: io.IOBase): + """Encode a video array. + + Encodes using CODEC and writes using CONTAINER_FORMAT at FPS frames per second. + + Args: + video_array: array to encode. Must have shape (T, H, W, 1) or (T, H, W, 3), + where T is the number of frames, H is the height, W is the width, and the last + dimension is the number of channels. Must have dtype uint8. + destination: Destination to write the encoded video. + """ + video_array = np.array(video_array) + if ( + video_array.dtype != np.uint8 + or video_array.ndim != 4 + or video_array.shape[-1] not in (1, 3) + ): + raise ValueError( + "Expected a uint8 array with shape (T, H, W, 1) or (T, H, W, 3)." + f"Got shape {video_array.shape} with dtype {video_array.dtype}." + ) + + T, H, W, C = video_array.shape + is_grayscale = C == 1 + if is_grayscale: + video_array = np.squeeze(video_array, axis=-1) + + with av.open(destination, mode="w", format=CONTAINER_FORMAT) as container: + stream = container.add_stream(CODEC, rate=FPS) + stream.width = W + stream.height = H + stream.pix_fmt = "yuv420p" + + for t in range(T): + frame_data = video_array[t] + if is_grayscale: + # For grayscale, use gray format and let av handle conversion to yuv420p + frame = av.VideoFrame.from_ndarray(frame_data, format="gray") + else: + frame = av.VideoFrame.from_ndarray(frame_data, format="rgb24") + frame.pts = t + container.mux(stream.encode(frame)) + + container.mux(stream.encode(None)) diff --git a/src/jax_loop_utils/metric_writers/_audio_video/audio_video_test.py b/src/jax_loop_utils/metric_writers/_audio_video/audio_video_test.py new file mode 100644 index 0000000..3609dd8 --- /dev/null +++ b/src/jax_loop_utils/metric_writers/_audio_video/audio_video_test.py @@ -0,0 +1,74 @@ +"""Tests for video encoding utilities.""" + +import io + +import av +import numpy as np +from absl.testing import absltest + +from jax_loop_utils.metric_writers._audio_video import ( + CONTAINER_FORMAT, + FPS, + encode_video, +) + + +class VideoTest(absltest.TestCase): + """Tests for video encoding utilities.""" + + def test_encode_video_invalid_args(self): + """Test that encode_video raises appropriate errors for invalid inputs.""" + invalid_shape = np.zeros((10, 20, 30, 4), dtype=np.uint8) + with self.assertRaisesRegex(ValueError, "Expected a uint8 array with shape"): + encode_video(invalid_shape, io.BytesIO()) + + invalid_dtype = np.zeros((10, 20, 30, 3), dtype=np.float32) + with self.assertRaisesRegex(ValueError, "Expected a uint8 array with shape"): + encode_video(invalid_dtype, io.BytesIO()) + + def test_encode_video_success(self): + """Test successful video encoding.""" + # Create a simple test video - red square moving diagonally + T, H, W = 20, 64, 64 + video = np.zeros((T, H, W, 3), dtype=np.uint8) + for t in range(T): + pos = t * 5 # Move 5 pixels each frame + video[t, pos : pos + 10, pos : pos + 10, 0] = 255 # Red square + + output = io.BytesIO() + encode_video(video, output) + + output.seek(0) + with av.open(output, mode="r", format=CONTAINER_FORMAT) as container: + stream = container.streams.video[0] + self.assertEqual(stream.codec_context.width, W) + self.assertEqual(stream.codec_context.height, H) + self.assertEqual(stream.codec_context.framerate, FPS) + # Check we can decode all frames + frame_count = sum(1 for _ in container.decode(stream)) + self.assertEqual(frame_count, T) + + def test_encode_video_grayscale(self): + """Test encoding grayscale video (1 channel).""" + T, H, W = 5, 32, 32 + video = np.zeros((T, H, W, 1), dtype=np.uint8) + + # Create pulsing pattern + for t in range(T): + video[t, :, :, 0] = (t * 50) % 256 # Increasing brightness + + output = io.BytesIO() + encode_video(video, output) + + output.seek(0) + with av.open(output, mode="r", format=CONTAINER_FORMAT) as container: + stream = container.streams.video[0] + self.assertEqual(stream.codec_context.width, W) + self.assertEqual(stream.codec_context.height, H) + # Check we can decode all frames + frame_count = sum(1 for _ in container.decode(stream)) + self.assertEqual(frame_count, T) + + +if __name__ == "__main__": + absltest.main() diff --git a/src/jax_loop_utils/metric_writers/mlflow/metric_writer.py b/src/jax_loop_utils/metric_writers/mlflow/metric_writer.py index 809eb53..2e931a6 100644 --- a/src/jax_loop_utils/metric_writers/mlflow/metric_writer.py +++ b/src/jax_loop_utils/metric_writers/mlflow/metric_writer.py @@ -1,5 +1,7 @@ """MLflow implementation of MetricWriter interface.""" +import pathlib +import tempfile from collections.abc import Mapping from time import time from typing import Any @@ -15,14 +17,17 @@ from jax_loop_utils.metric_writers.interface import ( Array, + MetricWriter, Scalar, ) -from jax_loop_utils.metric_writers.interface import ( - MetricWriter as MetricWriterInterface, -) +try: + from jax_loop_utils.metric_writers import _audio_video +except ImportError: + _audio_video = None -class MlflowMetricWriter(MetricWriterInterface): + +class MlflowMetricWriter(MetricWriter): """Writes metrics to MLflow Tracking.""" def __init__( @@ -91,18 +96,36 @@ def write_images(self, step: int, images: Mapping[str, Array]): ) def write_videos(self, step: int, videos: Mapping[str, Array]): - """MLflow doesn't support video logging directly.""" - # this could be supported if we convert the video to a file - # and log the file as an artifact. - logging.log_first_n( - logging.WARNING, - "mlflow.MetricWriter does not support writing videos.", - 1, - ) + """Convert videos to images and write them to MLflow. + + Requires pillow to be installed. + """ + if _audio_video is None: + logging.log_first_n( + logging.WARNING, + "MlflowMetricWriter.write_videos requires the [video] extra to be installed.", + 1, + ) + return + + temp_dir = tempfile.mkdtemp() + + for key, video_array in videos.items(): + local_path = ( + pathlib.Path(temp_dir) + / f"{key}_{step:09d}.{_audio_video.CONTAINER_FORMAT}" + ) + with open(local_path, "wb") as f: + _audio_video.encode_video(video_array, f) + self._client.log_artifact( + self._run_id, + local_path, + artifact_path="videos", + ) def write_audios(self, step: int, audios: Mapping[str, Array], *, sample_rate: int): """MLflow doesn't support audio logging directly.""" - # this could be supported if we convert the video to a file + # this could be supported if we convert the audio to a file # and log the file as an artifact. logging.log_first_n( logging.WARNING, diff --git a/src/jax_loop_utils/metric_writers/mlflow/metric_writer_test.py b/src/jax_loop_utils/metric_writers/mlflow/metric_writer_test.py index f41a587..b0de958 100644 --- a/src/jax_loop_utils/metric_writers/mlflow/metric_writer_test.py +++ b/src/jax_loop_utils/metric_writers/mlflow/metric_writer_test.py @@ -132,12 +132,28 @@ def test_write_hparams(self): self.assertEqual(run.data.params["batch_size"], "32") self.assertEqual(run.data.params["epochs"], "100") + def test_write_videos(self): + with tempfile.TemporaryDirectory() as temp_dir: + tracking_uri = f"file://{temp_dir}" + experiment_name = "experiment_name" + writer = MlflowMetricWriter(experiment_name, tracking_uri=tracking_uri) + + # Generate 100 frames of noise video + frames = [] + for _ in range(100): + frame = np.random.randint(0, 256, (64, 64, 3), dtype=np.uint8) + frames.append(frame) + + # Stack frames into video array [frames, height, width, channels] + video = np.stack(frames, axis=0) + writer.write_videos(0, {"noise_video": video}) + writer.close() + def test_no_ops(self): with tempfile.TemporaryDirectory() as temp_dir: tracking_uri = f"file://{temp_dir}" experiment_name = "experiment_name" writer = MlflowMetricWriter(experiment_name, tracking_uri=tracking_uri) - writer.write_videos(0, {"video": np.zeros((4, 28, 28, 3))}) writer.write_audios(0, {"audio": np.zeros((2, 1000))}, sample_rate=16000) writer.write_histograms( 0, {"histogram": np.zeros((10,))}, num_buckets={"histogram": 10} diff --git a/uv.lock b/uv.lock index 6d87ecb..17fd811 100644 --- a/uv.lock +++ b/uv.lock @@ -54,6 +54,29 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/2b/03/13dde6512ad7b4557eb792fbcf0c653af6076b81e5941d36ec61f7ce6028/astunparse-1.6.3-py2.py3-none-any.whl", hash = "sha256:c2652417f2c8b5bb325c885ae329bdf3f86424075c4fd1a128674bc6fba4b8e8", size = 12732 }, ] +[[package]] +name = "av" +version = "14.0.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/04/6a/8461055082eee773b549b87519eb8519c2194af1157f6971e3a0722b308e/av-14.0.1.tar.gz", hash = "sha256:2b0a17301af469ddaea46b5c1c982df1b7b5de8bc6c94cdc98cad4a67178c82a", size = 3918806 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/07/46/ad32d62c13369116ff9761c5e4eac248cfe9553abb6e4828f77a7adafaf6/av-14.0.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:b9cfdc671bb7e09824897164626d76a8ecdc009c13afef3decb7071de57f0c71", size = 19527044 }, + { url = "https://files.pythonhosted.org/packages/73/60/13024290b4c726a08ec83d4da4d8f46078479de9946df042456ea4688b47/av-14.0.1-cp311-cp311-macosx_11_0_x86_64.whl", hash = "sha256:9198ffaab74b8ac659d14b355c1821208e8b16f35138f4922721113bc6c7b7ab", size = 24331238 }, + { url = "https://files.pythonhosted.org/packages/3e/9d/c230d1a035ce69d43637da592be5f18a0c9d37b1eca9a07e4ab92f464351/av-14.0.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e85d933bbcde2db01d114419283edda35330ca651b11f9f5d6a694ce0c1b26ee", size = 31990679 }, + { url = "https://files.pythonhosted.org/packages/7f/61/ead56d9c2ced8c7770a84732d12ff33abab02c83629edf13232375a87f9f/av-14.0.1-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f1f13b662e8f8a7736fb7bc34e17d6c5a4e57b7142bae6b0502d962173883b26", size = 31345454 }, + { url = "https://files.pythonhosted.org/packages/46/b0/6380e05f36ec78c4695df9cc6c6fc6c3dec976caf105bacee5328f83d198/av-14.0.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7e54203fccff31c8bc6563df6dceff8c236fd1e18fdb1771b1a56ed1525cd72b", size = 33811771 }, + { url = "https://files.pythonhosted.org/packages/7d/01/b198c3af186003b6429ad4dd9f654f99761a8efa59ab838884c6ac842117/av-14.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:e022c6ec1ef4bf324f83ecbe4aa6b0afa89e86537d8c5e15f065e4da1a5de6a8", size = 19530078 }, + { url = "https://files.pythonhosted.org/packages/00/48/2f0ec27a521f4963683d76f3b7870c7667dfb8be4ab9512dad92566a8cbc/av-14.0.1-cp312-cp312-macosx_11_0_x86_64.whl", hash = "sha256:97daa268795449709f0034b19bb7ca4e99018825f9c7640fde30f2cb51f63f00", size = 24343809 }, + { url = "https://files.pythonhosted.org/packages/ca/6f/1bf89ac1f82c79f797d8331a7b191a16c6a8336fc59dbbc3d0b66d543e0a/av-14.0.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:14faba003c6bf6d6b9b9521f77a2059cfd206ae95b48f610b14de8d5ba2ccd4e", size = 32252684 }, + { url = "https://files.pythonhosted.org/packages/ba/71/8eedb586a32f7d1f8d23d9e0a7bc879d4dae6f82c2b7977c69171ea2e5b6/av-14.0.1-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9c1e0319a09525916a51f3c0ab4c07dfe9e82b3c1d8cf7aa3bb495d5dd28e767", size = 31617399 }, + { url = "https://files.pythonhosted.org/packages/16/19/7a833ba5b3c190f204dabc138bbc3f06730084400d79723cbe8d5792364e/av-14.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d37bf19987c8cad8c5701c174bf4e89df7e2de97be2176bd81a2b0f86516f1c3", size = 34166747 }, + { url = "https://files.pythonhosted.org/packages/25/32/3ce1b51d6638a1141614f24ed4f4e483613ae9b4a6e53f23f4a5b25907e2/av-14.0.1-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:71d7cae036bc362b3c255f99669e2e3412dc9b9e3e390ff426b9ea167f1f1c37", size = 19489839 }, + { url = "https://files.pythonhosted.org/packages/ef/e4/152c072da36a91b3158816a304eed73d2ab1bc6accd690ccce5bcb2dbc3a/av-14.0.1-cp313-cp313-macosx_11_0_x86_64.whl", hash = "sha256:e54079e560cc8d91b9be224a8ced4c8c6b97efdb8932f27c56efcbc2181c8129", size = 24303231 }, + { url = "https://files.pythonhosted.org/packages/6d/f1/968a4182df7dae4badb581106c8655591d42d88bc8efcc54bd5f27c7f3a6/av-14.0.1-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d98ab1b5ef8b13fa5875757d6e16fbe0a0c48e98e4c2c1da8478a0dda0ed500b", size = 31899455 }, + { url = "https://files.pythonhosted.org/packages/62/fe/61a85fc2c8286c11a0f607cc044b0c1cd858ddfc71004d5a9857d071f357/av-14.0.1-cp313-cp313-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:15924960530541ae6a2b3ce5de2e9bcb5c20379ac57850cfac3ee179b4411f64", size = 31286047 }, + { url = "https://files.pythonhosted.org/packages/e3/86/5eb261662b8766007239f35c33e37589f51b6346fb29cd5281df49f01c84/av-14.0.1-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:502ecd3cc0973cd2f868ec245bb3f838bb9d4da91bcc608f72a7cdd2cd44f0d1", size = 33822372 }, +] + [[package]] name = "cachetools" version = "5.5.0" @@ -697,13 +720,16 @@ dependencies = [ { name = "jax", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, { name = "numpy", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, { name = "packaging", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, - { name = "pillow", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, { name = "wrapt", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, ] [package.optional-dependencies] +audio-video = [ + { name = "av", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, +] mlflow = [ { name = "mlflow-skinny", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, + { name = "pillow", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, ] pyright = [ { name = "pyright", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, @@ -729,6 +755,7 @@ torch = [ [package.metadata] requires-dist = [ { name = "absl-py" }, + { name = "av", marker = "extra == 'audio-video'", specifier = ">=14.0" }, { name = "chex", marker = "extra == 'synopsis'" }, { name = "chex", marker = "extra == 'test'" }, { name = "etils", extras = ["epath", "epy"] }, @@ -739,7 +766,7 @@ requires-dist = [ { name = "mlflow-skinny", marker = "extra == 'mlflow'", specifier = ">=2.0" }, { name = "numpy" }, { name = "packaging" }, - { name = "pillow" }, + { name = "pillow", marker = "extra == 'mlflow'" }, { name = "pyright", marker = "extra == 'pyright'" }, { name = "pytest", marker = "extra == 'test'" }, { name = "pytest-cov", marker = "extra == 'test'" },