Skip to content

Commit

Permalink
video: fix bugs
Browse files Browse the repository at this point in the history
* handle odd (non-even) input heights and widths by padding them up to even numbers
* handle float inputs in [0, 1] by scaling them to uint8
* upload video artifacts to the correct artifact directory

Change-Id: I84c7fe4cbcd311ed67f736b0e035ec32a9566cca
  • Loading branch information
garymm committed Dec 24, 2024
1 parent b0b1e12 commit eb1e8e0
Show file tree
Hide file tree
Showing 4 changed files with 41 additions and 15 deletions.
35 changes: 29 additions & 6 deletions src/jax_loop_utils/metric_writers/_audio_video/audio_video.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,21 +25,44 @@ def encode_video(video_array: Array, destination: io.IOBase):
Args:
video_array: array to encode. Must have shape (T, H, W, 1) or (T, H, W, 3),
where T is the number of frames, H is the height, W is the width, and the last
dimension is the number of channels. Must have dtype uint8.
dimension is the number of channels.
Must be ints in [0, 255] or floats in [0, 1].
destination: Destination to write the encoded video.
"""
video_array = np.array(video_array)
if video_array.ndim != 4 or video_array.shape[-1] not in (1, 3):
raise ValueError(
"Expected an array with shape (T, H, W, 1) or (T, H, W, 3)."
f"Got shape {video_array.shape} with dtype {video_array.dtype}."
)

if (
video_array.dtype != np.uint8
or video_array.ndim != 4
or video_array.shape[-1] not in (1, 3)
np.issubdtype(video_array.dtype, np.floating)
and np.all(0 <= video_array)
and np.all(video_array <= 1.0)
):
video_array = (video_array * 255).astype(np.uint8)
elif (
np.issubdtype(video_array.dtype, np.integer)
and np.all(0 <= video_array)
and np.all(video_array <= 255)
):
video_array = video_array.astype(np.uint8)
else:
raise ValueError(
"Expected a uint8 array with shape (T, H, W, 1) or (T, H, W, 3)."
f"Got shape {video_array.shape} with dtype {video_array.dtype}."
f"Expected video_array to be floats in [0, 1] or ints in [0, 255], got {video_array.dtype}"
)

T, H, W, C = video_array.shape
# Pad height and width to even numbers if necessary
pad_h = H % 2
pad_w = W % 2
if pad_h or pad_w:
padding = [(0, 0), (0, pad_h), (0, pad_w), (0, 0)]
video_array = np.pad(video_array, padding, mode="constant")
H += pad_h
W += pad_w

is_grayscale = C == 1
if is_grayscale:
video_array = np.squeeze(video_array, axis=-1)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,17 +19,19 @@ class VideoTest(absltest.TestCase):
def test_encode_video_invalid_args(self):
"""Test that encode_video raises appropriate errors for invalid inputs."""
invalid_shape = np.zeros((10, 20, 30, 4), dtype=np.uint8)
with self.assertRaisesRegex(ValueError, "Expected a uint8 array with shape"):
with self.assertRaisesRegex(ValueError, r"Expected an array with shape"):
encode_video(invalid_shape, io.BytesIO())

invalid_dtype = np.zeros((10, 20, 30, 3), dtype=np.float32)
with self.assertRaisesRegex(ValueError, "Expected a uint8 array with shape"):
invalid_dtype = 2 * np.ones((10, 20, 30, 3), dtype=np.float32)
with self.assertRaisesRegex(
ValueError, r"Expected video_array to be floats in \[0, 1\]"
):
encode_video(invalid_dtype, io.BytesIO())

def test_encode_video_success(self):
"""Test successful video encoding."""
# Create a simple test video - red square moving diagonally
T, H, W = 20, 64, 64
T, H, W = 20, 63, 63 # test non-even dimensions
video = np.zeros((T, H, W, 3), dtype=np.uint8)
for t in range(T):
pos = t * 5 # Move 5 pixels each frame
Expand All @@ -41,8 +43,8 @@ def test_encode_video_success(self):
output.seek(0)
with av.open(output, mode="r", format=CONTAINER_FORMAT) as container:
stream = container.streams.video[0]
self.assertEqual(stream.codec_context.width, W)
self.assertEqual(stream.codec_context.height, H)
self.assertEqual(stream.codec_context.width, W + 1)
self.assertEqual(stream.codec_context.height, H + 1)
self.assertEqual(stream.codec_context.framerate, FPS)
# Check we can decode all frames
frame_count = sum(1 for _ in container.decode(stream))
Expand Down
5 changes: 2 additions & 3 deletions src/jax_loop_utils/metric_writers/mlflow/metric_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,12 +148,11 @@ def _encode_and_log_video(
temp_path.parent.mkdir(parents=True, exist_ok=True)
with open(temp_path, "wb") as f:
_audio_video.encode_video(video_array, f) # pyright: ignore[reportOptionalMemberAccess]
dest_dir = os.path.join("videos", os.path.dirname(rel_path)).rstrip("/")
# If log_artifact(synchronous=False) existed,
# we could synchronize with self.flush() rather than at the end of write_videos.
# https://github.com/mlflow/mlflow/issues/14153
self._client.log_artifact(
self._run_id, temp_path, os.path.join("videos", rel_path)
)
self._client.log_artifact(self._run_id, temp_path, dest_dir)

def write_audios(self, step: int, audios: Mapping[str, Array], *, sample_rate: int):
"""MLflow doesn't support audio logging directly."""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -156,10 +156,12 @@ def test_write_videos(self):
self.assertEqual(
sorted_artifacts_videos[0].path, "videos/noise_1_000000000.mp4"
)
self.assertFalse(sorted_artifacts_videos[0].is_dir)

artifacts_zzz = writer._client.list_artifacts(run.info.run_id, "videos/zzz")
self.assertEqual(len(artifacts_zzz), 1)
self.assertEqual(artifacts_zzz[0].path, "videos/zzz/noise_0_000000000.mp4")
self.assertFalse(artifacts_zzz[0].is_dir)

def test_no_ops(self):
with tempfile.TemporaryDirectory() as temp_dir:
Expand Down

0 comments on commit eb1e8e0

Please sign in to comment.