diff --git a/src/jax_loop_utils/metric_writers/mlflow/metric_writer.py b/src/jax_loop_utils/metric_writers/mlflow/metric_writer.py index 81d889f..95c4a35 100644 --- a/src/jax_loop_utils/metric_writers/mlflow/metric_writer.py +++ b/src/jax_loop_utils/metric_writers/mlflow/metric_writer.py @@ -2,6 +2,7 @@ import os import pathlib +import shutil import tempfile import time from collections.abc import Mapping @@ -123,27 +124,36 @@ def write_videos(self, step: int, videos: Mapping[str, Array]): encode_and_log = maybe_async(self._encode_and_log_video) - temp_dir = pathlib.Path(tempfile.mkdtemp()) paths_arrays = [ ( - temp_dir / f"{key}_{step:09d}.{_audio_video.CONTAINER_FORMAT}", + f"{key}_{step:09d}.{_audio_video.CONTAINER_FORMAT}", video_array, ) for key, video_array in videos.items() ] + temp_dir = pathlib.Path(tempfile.mkdtemp()) for path, video_array in paths_arrays: - encode_and_log(path, video_array) + encode_and_log(temp_dir, path, video_array) pool.close() + shutil.rmtree(temp_dir) - def _encode_and_log_video(self, path: pathlib.Path, video_array: Array): - with open(path, "wb") as f: + def _encode_and_log_video( + self, temp_dir: pathlib.Path, rel_path: str, video_array: Array + ): + temp_path = temp_dir / rel_path + # handle keys with slashes + if not temp_path.parent.exists(): + temp_path.parent.mkdir(parents=True, exist_ok=True) + with open(temp_path, "wb") as f: _audio_video.encode_video(video_array, f) # pyright: ignore[reportOptionalMemberAccess] # If log_artifact(synchronous=False) existed, # we could synchronize with self.flush() rather than at the end of write_videos. # https://github.com/mlflow/mlflow/issues/14153 - self._client.log_artifact(self._run_id, path, os.path.join("videos", path.name)) + self._client.log_artifact( + self._run_id, temp_path, os.path.join("videos", rel_path) + ) def write_audios(self, step: int, audios: Mapping[str, Array], *, sample_rate: int): """MLflow doesn't support audio logging directly.""" diff --git a/src/jax_loop_utils/metric_writers/mlflow/metric_writer_test.py b/src/jax_loop_utils/metric_writers/mlflow/metric_writer_test.py index c7f265e..3771fcf 100644 --- a/src/jax_loop_utils/metric_writers/mlflow/metric_writer_test.py +++ b/src/jax_loop_utils/metric_writers/mlflow/metric_writer_test.py @@ -83,8 +83,6 @@ def test_write_images(self): run = runs[0] # the string "images" is hardcoded in MlflowClient.log_image. artifacts = writer._client.list_artifacts(run.info.run_id, "images") - if not artifacts: - artifacts = writer._client.list_artifacts(run.info.run_id, "images") artifact_paths = [artifact.path for artifact in artifacts] self.assertGreaterEqual(len(artifact_paths), 1) self.assertIn("test_image", artifact_paths[0]) @@ -141,7 +139,7 @@ def test_write_videos(self): frames.append(frame) videos = { - "noise_0": np.stack(frames, axis=0), + "zzz/noise_0": np.stack(frames, axis=0), "noise_1": np.stack(frames, axis=0), } writer.write_videos(0, videos) @@ -152,19 +150,17 @@ def test_write_videos(self): self.assertEqual(len(runs), 1) run = runs[0] - artifacts = writer._client.list_artifacts(run.info.run_id, "videos") - if not artifacts: - artifacts = writer._client.list_artifacts(run.info.run_id, "videos") - - artifact_paths = [artifact.path for artifact in artifacts] - self.assertEqual(len(artifact_paths), 2) - self.assertTrue( - any(path.startswith("videos/noise_0") for path in artifact_paths) - ) - self.assertTrue( - any(path.startswith("videos/noise_1") for path in artifact_paths) + artifacts_videos = writer._client.list_artifacts(run.info.run_id, "videos") + self.assertEqual(len(artifacts_videos), 2) + sorted_artifacts_videos = sorted(artifacts_videos, key=lambda x: x.path) + self.assertEqual( + sorted_artifacts_videos[0].path, "videos/noise_1_000000000.mp4" ) + artifacts_zzz = writer._client.list_artifacts(run.info.run_id, "videos/zzz") + self.assertEqual(len(artifacts_zzz), 1) + self.assertEqual(artifacts_zzz[0].path, "videos/zzz/noise_0_000000000.mp4") + def test_no_ops(self): with tempfile.TemporaryDirectory() as temp_dir: tracking_uri = f"file://{temp_dir}"