Skip to content

Commit

Permalink
support writing videos to mlflow
Browse files Browse the repository at this point in the history
  • Loading branch information
garymm committed Dec 23, 2024
1 parent 8aaad5b commit 884d673
Show file tree
Hide file tree
Showing 7 changed files with 229 additions and 18 deletions.
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ dependencies = [
"jax>=0.4.36",
"numpy",
"packaging",
"Pillow",
"wrapt",
]
classifiers = [
Expand All @@ -29,13 +28,14 @@ keywords = ["JAX", "machine learning"]
Homepage = "http://github.com/Astera-org/jax_loop_utils"

[project.optional-dependencies]
mlflow = ["mlflow-skinny>=2.0"]
mlflow = ["mlflow-skinny>=2.0", "Pillow"]
pyright = ["pyright"]
# for synopsis.ipynb
synopsis = ["chex", "flax", "ipykernel", "matplotlib"]
tensorflow = ["tensorflow>=2.12"]
test = ["chex", "pytest", "pytest-cov"]
torch = ["torch>=2.0"]
audio-video = ["av>=14.0"]

[build-system]
requires = ["hatchling"]
Expand Down
8 changes: 8 additions & 0 deletions src/jax_loop_utils/metric_writers/_audio_video/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
from jax_loop_utils.metric_writers._audio_video.audio_video import (
CODEC,
CONTAINER_FORMAT,
FPS,
encode_video,
)

__all__ = ["encode_video", "CONTAINER_FORMAT", "CODEC", "FPS"]
63 changes: 63 additions & 0 deletions src/jax_loop_utils/metric_writers/_audio_video/audio_video.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
"""Utilities for audio and video.
Requires additional dependencies, part of the `audio-video` extra.
"""

import io

import av
import numpy as np

from jax_loop_utils.metric_writers.interface import (

Check warning on line 11 in src/jax_loop_utils/metric_writers/_audio_video/audio_video.py

View check run for this annotation

Codecov / codecov/patch

src/jax_loop_utils/metric_writers/_audio_video/audio_video.py#L11

Added line #L11 was not covered by tests
Array,
)

# Container format name passed as `format=` to `av.open` when writing a video.
CONTAINER_FORMAT = "mp4"
# Codec name passed to `container.add_stream` when creating the video stream.
CODEC = "h264"
# Frames per second; passed as the stream `rate` when encoding.
FPS = 10

Check warning on line 17 in src/jax_loop_utils/metric_writers/_audio_video/audio_video.py

View check run for this annotation

Codecov / codecov/patch

src/jax_loop_utils/metric_writers/_audio_video/audio_video.py#L15-L17

Added lines #L15 - L17 were not covered by tests


def encode_video(video_array: Array, destination: io.IOBase):
    """Encode a video array and write the result to ``destination``.

    Encodes using CODEC and writes using CONTAINER_FORMAT at FPS frames per
    second.

    Args:
        video_array: array to encode. Must have shape (T, H, W, 1) or
            (T, H, W, 3), where T is the number of frames, H is the height,
            W is the width, and the last dimension is the number of channels.
            Must have dtype uint8.
        destination: Destination to write the encoded video.

    Raises:
        ValueError: if ``video_array`` is not a rank-4 uint8 array whose last
            dimension is 1 or 3.
    """
    video_array = np.array(video_array)
    if (
        video_array.dtype != np.uint8
        or video_array.ndim != 4
        or video_array.shape[-1] not in (1, 3)
    ):
        # The trailing space keeps the two sentences apart once the adjacent
        # string literals are concatenated.
        raise ValueError(
            "Expected a uint8 array with shape (T, H, W, 1) or (T, H, W, 3). "
            f"Got shape {video_array.shape} with dtype {video_array.dtype}."
        )

    _, height, width, channels = video_array.shape
    is_grayscale = channels == 1
    if is_grayscale:
        # Drop the channel axis so each frame is a 2-D (H, W) array.
        video_array = np.squeeze(video_array, axis=-1)
    # The source pixel format is the same for every frame, so pick it once.
    # For grayscale, use gray format and let av handle conversion to yuv420p.
    frame_format = "gray" if is_grayscale else "rgb24"

    with av.open(destination, mode="w", format=CONTAINER_FORMAT) as container:
        stream = container.add_stream(CODEC, rate=FPS)
        stream.width = width
        stream.height = height
        stream.pix_fmt = "yuv420p"

        for t, frame_data in enumerate(video_array):
            frame = av.VideoFrame.from_ndarray(frame_data, format=frame_format)
            frame.pts = t
            container.mux(stream.encode(frame))

        # Encoding None after the last frame; presumably this flushes packets
        # still buffered in the encoder -- confirm against the PyAV docs.
        container.mux(stream.encode(None))

Check warning on line 63 in src/jax_loop_utils/metric_writers/_audio_video/audio_video.py

View check run for this annotation

Codecov / codecov/patch

src/jax_loop_utils/metric_writers/_audio_video/audio_video.py#L63

Added line #L63 was not covered by tests
74 changes: 74 additions & 0 deletions src/jax_loop_utils/metric_writers/_audio_video/audio_video_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
"""Tests for video encoding utilities."""

import io

import av
import numpy as np
from absl.testing import absltest

from jax_loop_utils.metric_writers._audio_video import (
CONTAINER_FORMAT,
FPS,
encode_video,
)


class VideoTest(absltest.TestCase):
    """Unit tests for `encode_video`."""

    def test_encode_video_invalid_args(self):
        """A bad channel count or a non-uint8 dtype is rejected with ValueError."""
        expected_message = "Expected a uint8 array with shape"

        four_channel_clip = np.zeros((10, 20, 30, 4), dtype=np.uint8)
        with self.assertRaisesRegex(ValueError, expected_message):
            encode_video(four_channel_clip, io.BytesIO())

        float_clip = np.zeros((10, 20, 30, 3), dtype=np.float32)
        with self.assertRaisesRegex(ValueError, expected_message):
            encode_video(float_clip, io.BytesIO())

    def test_encode_video_success(self):
        """An RGB clip round-trips with the expected size, frame rate and frame count."""
        num_frames, height, width = 20, 64, 64
        clip = np.zeros((num_frames, height, width, 3), dtype=np.uint8)
        # Paint a 10x10 red square that steps 5 pixels along the diagonal per
        # frame; slices that fall outside the frame simply select nothing.
        for index in range(num_frames):
            offset = index * 5
            clip[index, offset : offset + 10, offset : offset + 10, 0] = 255

        buffer = io.BytesIO()
        encode_video(clip, buffer)
        buffer.seek(0)

        with av.open(buffer, mode="r", format=CONTAINER_FORMAT) as container:
            video_stream = container.streams.video[0]
            self.assertEqual(video_stream.codec_context.width, width)
            self.assertEqual(video_stream.codec_context.height, height)
            self.assertEqual(video_stream.codec_context.framerate, FPS)
            # Every encoded frame must be decodable.
            decoded_frames = len(list(container.decode(video_stream)))
            self.assertEqual(decoded_frames, num_frames)

    def test_encode_video_grayscale(self):
        """A single-channel clip round-trips with the expected size and frame count."""
        num_frames, height, width = 5, 32, 32
        clip = np.zeros((num_frames, height, width, 1), dtype=np.uint8)
        # Uniform frames whose brightness rises by 50 per frame.
        for index in range(num_frames):
            clip[index, :, :, 0] = (index * 50) % 256

        buffer = io.BytesIO()
        encode_video(clip, buffer)
        buffer.seek(0)

        with av.open(buffer, mode="r", format=CONTAINER_FORMAT) as container:
            video_stream = container.streams.video[0]
            self.assertEqual(video_stream.codec_context.width, width)
            self.assertEqual(video_stream.codec_context.height, height)
            # Every encoded frame must be decodable.
            decoded_frames = len(list(container.decode(video_stream)))
            self.assertEqual(decoded_frames, num_frames)


if __name__ == "__main__":
absltest.main()
49 changes: 36 additions & 13 deletions src/jax_loop_utils/metric_writers/mlflow/metric_writer.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
"""MLflow implementation of MetricWriter interface."""

import pathlib
import tempfile
from collections.abc import Mapping
from time import time
from typing import Any
Expand All @@ -15,14 +17,17 @@

from jax_loop_utils.metric_writers.interface import (
Array,
MetricWriter,
Scalar,
)
from jax_loop_utils.metric_writers.interface import (
MetricWriter as MetricWriterInterface,
)

try:
from jax_loop_utils.metric_writers import _audio_video
except ImportError:
_audio_video = None

class MlflowMetricWriter(MetricWriterInterface):

class MlflowMetricWriter(MetricWriter):
"""Writes metrics to MLflow Tracking."""

def __init__(
Expand Down Expand Up @@ -91,18 +96,36 @@ def write_images(self, step: int, images: Mapping[str, Array]):
)

def write_videos(self, step: int, videos: Mapping[str, Array]):
    """Encode videos and log them to MLflow as artifacts.

    Each entry is encoded with `_audio_video.encode_video` into a temporary
    file named `{key}_{step:09d}.{CONTAINER_FORMAT}`, which is then logged
    to the current run under the `videos` artifact path.

    Requires the `audio-video` extra to be installed. When it is missing, a
    warning is logged once and nothing is written.

    Args:
        step: step number, embedded zero-padded in the artifact file name.
        videos: mapping from a name to the video array to encode; each array
            is passed unchanged to `_audio_video.encode_video`.
    """
    if _audio_video is None:
        logging.log_first_n(
            logging.WARNING,
            "MlflowMetricWriter.write_videos requires the [audio-video] extra to be installed.",
            1,
        )
        return

    # The directory only stages encoded files until they are logged; the
    # context manager removes it afterwards so repeated calls do not leave
    # temporary directories behind.
    with tempfile.TemporaryDirectory() as temp_dir:
        staging_dir = pathlib.Path(temp_dir)
        for key, video_array in videos.items():
            # NOTE(review): a key containing a path separator would point into
            # a sub-directory that does not exist -- confirm whether such keys
            # are expected here.
            local_path = (
                staging_dir / f"{key}_{step:09d}.{_audio_video.CONTAINER_FORMAT}"
            )
            with open(local_path, "wb") as f:
                _audio_video.encode_video(video_array, f)
            self._client.log_artifact(
                self._run_id,
                local_path,
                artifact_path="videos",
            )

def write_audios(self, step: int, audios: Mapping[str, Array], *, sample_rate: int):
"""MLflow doesn't support audio logging directly."""
# this could be supported if we convert the video to a file
# this could be supported if we convert the audio to a file
# and log the file as an artifact.
logging.log_first_n(
logging.WARNING,
Expand Down
18 changes: 17 additions & 1 deletion src/jax_loop_utils/metric_writers/mlflow/metric_writer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,12 +132,28 @@ def test_write_hparams(self):
self.assertEqual(run.data.params["batch_size"], "32")
self.assertEqual(run.data.params["epochs"], "100")

def test_write_videos(self):
    """Writing a random-noise video through the writer completes without raising."""
    with tempfile.TemporaryDirectory() as temp_dir:
        writer = MlflowMetricWriter(
            "experiment_name", tracking_uri=f"file://{temp_dir}"
        )

        # 100 independently drawn 64x64 RGB noise frames, stacked into a
        # [frames, height, width, channels] video array.
        noise_frames = [
            np.random.randint(0, 256, (64, 64, 3), dtype=np.uint8)
            for _ in range(100)
        ]
        noise_video = np.stack(noise_frames, axis=0)

        writer.write_videos(0, {"noise_video": noise_video})
        writer.close()

def test_no_ops(self):
with tempfile.TemporaryDirectory() as temp_dir:
tracking_uri = f"file://{temp_dir}"
experiment_name = "experiment_name"
writer = MlflowMetricWriter(experiment_name, tracking_uri=tracking_uri)
writer.write_videos(0, {"video": np.zeros((4, 28, 28, 3))})
writer.write_audios(0, {"audio": np.zeros((2, 1000))}, sample_rate=16000)
writer.write_histograms(
0, {"histogram": np.zeros((10,))}, num_buckets={"histogram": 10}
Expand Down
31 changes: 29 additions & 2 deletions uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit 884d673

Please sign in to comment.