Skip to content

Commit

Permalink
parallelize writing multiple videos
Browse files Browse the repository at this point in the history
Change-Id: I0de8095d3322c674eea4d2a7a1d9c973eb723ca5
  • Loading branch information
garymm committed Dec 23, 2024
1 parent 884d673 commit 5864abe
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 23 deletions.
4 changes: 2 additions & 2 deletions src/jax_loop_utils/asynclib.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ class Pool:
Synopsis:
from clu.internal import asynclib
from jax_loop_utils import asynclib
pool = asynclib.Pool()
@pool
Expand All @@ -54,7 +54,7 @@ def __init__(self, thread_name_prefix: str = "", max_workers: Optional[int] = No
thread_name_prefix: See documentation of `ThreadPoolExecutor`.
max_workers: See documentation of `ThreadPoolExecutor`. The default `None`
optimizes for parallelizability using the number of CPU cores. If you
specify `max_workers=1` you the async calls are executed in the same
specify `max_workers=1` the async calls are executed in the same
order they have been scheduled.
"""
self._pool = concurrent.futures.ThreadPoolExecutor(
Expand Down
50 changes: 36 additions & 14 deletions src/jax_loop_utils/metric_writers/mlflow/metric_writer.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
"""MLflow implementation of MetricWriter interface."""

import os
import pathlib
import tempfile
import time
from collections.abc import Mapping
from time import time
from typing import Any

import mlflow
Expand All @@ -15,6 +16,7 @@
import numpy as np
from absl import logging

from jax_loop_utils import asynclib
from jax_loop_utils.metric_writers.interface import (
Array,
MetricWriter,
Expand All @@ -27,6 +29,10 @@
_audio_video = None


def _noop_decorator(func):
return func


class MlflowMetricWriter(MetricWriter):
"""Writes metrics to MLflow Tracking."""

Expand Down Expand Up @@ -77,7 +83,7 @@ def __init__(

def write_scalars(self, step: int, scalars: Mapping[str, Scalar]):
"""Write scalar metrics to MLflow."""
timestamp = int(time() * 1000)
timestamp = int(time.time() * 1000)
metrics_list = [
mlflow.entities.Metric(k, float(v), timestamp, step)
for k, v in scalars.items()
Expand Down Expand Up @@ -108,20 +114,36 @@ def write_videos(self, step: int, videos: Mapping[str, Array]):
)
return

temp_dir = tempfile.mkdtemp()
pool = asynclib.Pool()

for key, video_array in videos.items():
local_path = (
pathlib.Path(temp_dir)
/ f"{key}_{step:09d}.{_audio_video.CONTAINER_FORMAT}"
)
with open(local_path, "wb") as f:
_audio_video.encode_video(video_array, f)
self._client.log_artifact(
self._run_id,
local_path,
artifact_path="videos",
if len(videos) > 1:
maybe_async = pool
else:
maybe_async = _noop_decorator

encode_and_log = maybe_async(self._encode_and_log_video)

temp_dir = pathlib.Path(tempfile.mkdtemp())
paths_arrays = [
(
temp_dir / f"{key}_{step:09d}.{_audio_video.CONTAINER_FORMAT}",
video_array,
)
for key, video_array in videos.items()
]

for path, video_array in paths_arrays:
encode_and_log(path, video_array)

pool.close()

def _encode_and_log_video(self, path: pathlib.Path, video_array: Array):
    """Encode *video_array* into the container file at *path*, then upload
    it to the current MLflow run as an artifact under ``videos/``.
    """
    with path.open("wb") as out_file:
        _audio_video.encode_video(video_array, out_file)  # pyright: ignore[reportOptionalMemberAccess]
    # log_artifact offers no synchronous=False option, so callers must
    # synchronize at the end of write_videos rather than via self.flush().
    # https://github.com/mlflow/mlflow/issues/14153
    self._client.log_artifact(self._run_id, path, os.path.join("videos", path.name))

def write_audios(self, step: int, audios: Mapping[str, Array], *, sample_rate: int):
"""MLflow doesn't support audio logging directly."""
Expand Down
30 changes: 23 additions & 7 deletions src/jax_loop_utils/metric_writers/mlflow/metric_writer_test.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import tempfile
import time

import jax.numpy as jnp
import mlflow
Expand Down Expand Up @@ -85,9 +84,6 @@ def test_write_images(self):
# the string "images" is hardcoded in MlflowClient.log_image.
artifacts = writer._client.list_artifacts(run.info.run_id, "images")
if not artifacts:
# have seen some latency in artifacts becoming available
# Maybe file system sync? Not sure.
time.sleep(0.1)
artifacts = writer._client.list_artifacts(run.info.run_id, "images")
artifact_paths = [artifact.path for artifact in artifacts]
self.assertGreaterEqual(len(artifact_paths), 1)
Expand Down Expand Up @@ -144,11 +140,31 @@ def test_write_videos(self):
frame = np.random.randint(0, 256, (64, 64, 3), dtype=np.uint8)
frames.append(frame)

# Stack frames into video array [frames, height, width, channels]
video = np.stack(frames, axis=0)
writer.write_videos(0, {"noise_video": video})
videos = {
"noise_0": np.stack(frames, axis=0),
"noise_1": np.stack(frames, axis=0),
}
writer.write_videos(0, videos)
writer.close()

# Verify artifacts were written
runs = _get_runs(tracking_uri, experiment_name)
self.assertEqual(len(runs), 1)
run = runs[0]

artifacts = writer._client.list_artifacts(run.info.run_id, "videos")
if not artifacts:
artifacts = writer._client.list_artifacts(run.info.run_id, "videos")

artifact_paths = [artifact.path for artifact in artifacts]
self.assertEqual(len(artifact_paths), 2)
self.assertTrue(
any(path.startswith("videos/noise_0") for path in artifact_paths)
)
self.assertTrue(
any(path.startswith("videos/noise_1") for path in artifact_paths)
)

def test_no_ops(self):
with tempfile.TemporaryDirectory() as temp_dir:
tracking_uri = f"file://{temp_dir}"
Expand Down

0 comments on commit 5864abe

Please sign in to comment.