Skip to content

Commit

Permalink
mlflow write_videos: handle keys with slashes
Browse files Browse the repository at this point in the history
Change-Id: I8f0cdce98b393387735564eb90be6e33c6f4582b
  • Loading branch information
garymm committed Dec 24, 2024
1 parent da2db1f commit 4e71685
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 20 deletions.
22 changes: 16 additions & 6 deletions src/jax_loop_utils/metric_writers/mlflow/metric_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import os
import pathlib
import shutil
import tempfile
import time
from collections.abc import Mapping
Expand Down Expand Up @@ -123,27 +124,36 @@ def write_videos(self, step: int, videos: Mapping[str, Array]):

encode_and_log = maybe_async(self._encode_and_log_video)

temp_dir = pathlib.Path(tempfile.mkdtemp())
paths_arrays = [
(
temp_dir / f"{key}_{step:09d}.{_audio_video.CONTAINER_FORMAT}",
f"{key}_{step:09d}.{_audio_video.CONTAINER_FORMAT}",
video_array,
)
for key, video_array in videos.items()
]

temp_dir = pathlib.Path(tempfile.mkdtemp())
for path, video_array in paths_arrays:
encode_and_log(path, video_array)
encode_and_log(temp_dir, path, video_array)

pool.close()
shutil.rmtree(temp_dir)

def _encode_and_log_video(self, path: pathlib.Path, video_array: Array):
with open(path, "wb") as f:
def _encode_and_log_video(
self, temp_dir: pathlib.Path, rel_path: str, video_array: Array
):
temp_path = temp_dir / rel_path
# handle keys with slashes
if not temp_path.parent.exists():
temp_path.parent.mkdir(parents=True, exist_ok=True)
with open(temp_path, "wb") as f:
_audio_video.encode_video(video_array, f) # pyright: ignore[reportOptionalMemberAccess]
# If log_artifact(synchronous=False) existed,
# we could synchronize with self.flush() rather than at the end of write_videos.
# https://github.com/mlflow/mlflow/issues/14153
self._client.log_artifact(self._run_id, path, os.path.join("videos", path.name))
self._client.log_artifact(
self._run_id, temp_path, os.path.join("videos", rel_path)
)

def write_audios(self, step: int, audios: Mapping[str, Array], *, sample_rate: int):
"""MLflow doesn't support audio logging directly."""
Expand Down
24 changes: 10 additions & 14 deletions src/jax_loop_utils/metric_writers/mlflow/metric_writer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,8 +83,6 @@ def test_write_images(self):
run = runs[0]
# the string "images" is hardcoded in MlflowClient.log_image.
artifacts = writer._client.list_artifacts(run.info.run_id, "images")
if not artifacts:
artifacts = writer._client.list_artifacts(run.info.run_id, "images")
artifact_paths = [artifact.path for artifact in artifacts]
self.assertGreaterEqual(len(artifact_paths), 1)
self.assertIn("test_image", artifact_paths[0])
Expand Down Expand Up @@ -141,7 +139,7 @@ def test_write_videos(self):
frames.append(frame)

videos = {
"noise_0": np.stack(frames, axis=0),
"zzz/noise_0": np.stack(frames, axis=0),
"noise_1": np.stack(frames, axis=0),
}
writer.write_videos(0, videos)
Expand All @@ -152,19 +150,17 @@ def test_write_videos(self):
self.assertEqual(len(runs), 1)
run = runs[0]

artifacts = writer._client.list_artifacts(run.info.run_id, "videos")
if not artifacts:
artifacts = writer._client.list_artifacts(run.info.run_id, "videos")

artifact_paths = [artifact.path for artifact in artifacts]
self.assertEqual(len(artifact_paths), 2)
self.assertTrue(
any(path.startswith("videos/noise_0") for path in artifact_paths)
)
self.assertTrue(
any(path.startswith("videos/noise_1") for path in artifact_paths)
artifacts_videos = writer._client.list_artifacts(run.info.run_id, "videos")
self.assertEqual(len(artifacts_videos), 2)
sorted_artifacts_videos = sorted(artifacts_videos, key=lambda x: x.path)
self.assertEqual(
sorted_artifacts_videos[0].path, "videos/noise_1_000000000.mp4"
)

artifacts_zzz = writer._client.list_artifacts(run.info.run_id, "videos/zzz")
self.assertEqual(len(artifacts_zzz), 1)
self.assertEqual(artifacts_zzz[0].path, "videos/zzz/noise_0_000000000.mp4")

def test_no_ops(self):
with tempfile.TemporaryDirectory() as temp_dir:
tracking_uri = f"file://{temp_dir}"
Expand Down

0 comments on commit 4e71685

Please sign in to comment.