Skip to content

Commit

Permalink
Improve project setup and CI (#21)
Browse files Browse the repository at this point in the history
  • Loading branch information
mickvangelderen authored Jan 13, 2025
1 parent eb1e8e0 commit 2670570
Show file tree
Hide file tree
Showing 35 changed files with 534 additions and 649 deletions.
45 changes: 30 additions & 15 deletions .github/workflows/checks.yml
Original file line number Diff line number Diff line change
Expand Up @@ -16,20 +16,34 @@ jobs:
with:
enable-cache: true
- name: pytest
run: uv run --extra test pytest --capture=no --verbose --cov --cov-report=xml --ignore-glob='jax_loop_utils/metric_writers/tf/*' --ignore-glob='jax_loop_utils/metric_writers/torch/*' --ignore-glob='jax_loop_utils/metric_writers/mlflow/*' --ignore-glob='jax_loop_utils/metric_writers/_audio_video/*' jax_loop_utils/
working-directory: src
run: |
uv sync
uv run -- pytest --capture=no --verbose --cov --cov-report=xml \
--ignore=src/jax_loop_utils/metric_writers/tf/ \
--ignore=src/jax_loop_utils/metric_writers/torch/ \
--ignore=src/jax_loop_utils/metric_writers/mlflow/ \
--ignore=src/jax_loop_utils/metric_writers/_audio_video/ \
src/jax_loop_utils/
- name: pytest tensorflow
run: uv run --extra test --extra tensorflow pytest --capture=no --verbose --cov --cov-report=xml --cov-append jax_loop_utils/metric_writers/tf
working-directory: src
run: |
uv sync --extra tensorflow
uv run -- pytest --capture=no --verbose --cov --cov-report=xml --cov-append \
src/jax_loop_utils/metric_writers/tf
- name: pytest torch
run: uv run --extra test --extra torch pytest --capture=no --verbose --cov --cov-report=xml --cov-append jax_loop_utils/metric_writers/torch
working-directory: src
run: |
uv sync --group dev-torch --extra torch
uv run -- pytest --capture=no --verbose --cov --cov-report=xml --cov-append \
src/jax_loop_utils/metric_writers/torch
- name: pytest mlflow
run: uv run --extra test --extra mlflow --extra audio-video pytest --capture=no --verbose --cov --cov-report=xml --cov-append jax_loop_utils/metric_writers/mlflow
working-directory: src
run: |
uv sync --extra mlflow --extra audio-video
uv run -- pytest --capture=no --verbose --cov --cov-report=xml --cov-append \
src/jax_loop_utils/metric_writers/mlflow
- name: pytest audio-video
run: uv run --extra test --extra audio-video pytest --capture=no --verbose --cov --cov-report=xml --cov-append jax_loop_utils/metric_writers/_audio_video
working-directory: src
run: |
uv sync --extra audio-video
uv run -- pytest --capture=no --verbose --cov --cov-report=xml --cov-append \
src/jax_loop_utils/metric_writers/_audio_video
- name: Upload coverage reports to Codecov
if: always()
uses: codecov/codecov-action@v4
Expand All @@ -45,8 +59,10 @@ jobs:
- uses: astral-sh/setup-uv@v3
with:
enable-cache: true
- name: ruff
run: uv run --with ruff ruff check src
- name: ruff format
run: uv run -- ruff format --check
- name: ruff check
run: uv run -- ruff check

pyright:
runs-on: ubuntu-24.04
Expand All @@ -58,6 +74,5 @@ jobs:
- name: uv sync
run: uv sync --all-extras
- name: pyright
# TODO: add more dependencies as we fix the violations
run: uv run pyright jax_loop_utils/metric_writers/
working-directory: src
# TODO: check more directories as we fix the violations
run: uv run -- pyright src/jax_loop_utils/metric_writers/
38 changes: 36 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,14 +29,23 @@ Homepage = "http://github.com/Astera-org/jax_loop_utils"

[project.optional-dependencies]
mlflow = ["mlflow-skinny>=2.0", "Pillow"]
pyright = ["pyright"]
# for synopsis.ipynb
synopsis = ["chex", "flax", "ipykernel", "matplotlib"]
tensorflow = ["tensorflow>=2.12"]
test = ["chex", "pytest", "pytest-cov"]
torch = ["torch>=2.0"]
audio-video = ["av>=14.0"]

[dependency-groups]
dev = [
"chex>=0.1.87",
"pyright==1.1.391",
"pytest>=8.3.4",
"pytest-cov>=6.0.0",
"ruff>=0.9.1",
]
# Additional development dependencies needed when testing with --extra torch
dev-torch = ["tensorflow>=2.12"]

[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"
Expand All @@ -62,3 +71,28 @@ filterwarnings = [
# action:message:category:module:line
"error",
]

[tool.ruff]
line-length = 100

[tool.ruff.lint]
select = [
"A", # flake8-builtins
"B", # flake8-bugbear
"E", # pycodestyle
"F", # Pyflakes
"I", # isort
"PT", # flake8-pytest-style
"SIM", # flake8-simplify
"UP", # pyupgrade
]
ignore = [
"A003", # builtin-attribute-shadowing
"SIM108", # Use the ternary operator
"UP007", # Allow Optional[type] instead of X | Y
"PT", # TODO: Remove this ignore once the flake8-pytest-style violations are fixed
"A005", # builtin-module-shadowing
]

[tool.pyright]
typeCheckingMode = "standard"
9 changes: 4 additions & 5 deletions src/jax_loop_utils/asynclib.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@
import functools
import sys
import threading
from typing import Callable, List, Optional
from collections.abc import Callable
from typing import Optional

from absl import logging

Expand Down Expand Up @@ -104,7 +105,7 @@ def has_errors(self) -> bool:
"""Returns True if there are any pending errors."""
return bool(self._errors)

def clear_errors(self) -> List[Exception]:
def clear_errors(self) -> list[Exception]:
"""Clears all pending errors and returns them as a (possibly empty) list."""
with self._errors_mutex:
errors, self._errors = self._errors, collections.deque()
Expand Down Expand Up @@ -135,9 +136,7 @@ def trap_errors(*args, **kwargs):
except Exception as e:
with self._errors_mutex:
self._errors.append(sys.exc_info())
logging.exception(
"Error in producer thread for %s", self._thread_name_prefix
)
logging.exception("Error in producer thread for %s", self._thread_name_prefix)
raise e
finally:
self._queue_length -= 1
Expand Down
1 change: 1 addition & 0 deletions src/jax_loop_utils/asynclib_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from unittest import mock

from absl.testing import absltest

from jax_loop_utils import asynclib


Expand Down
10 changes: 4 additions & 6 deletions src/jax_loop_utils/internal/flax/struct.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,20 +14,18 @@

"""Utilities for defining custom classes that can be used with jax transformations."""

from collections.abc import Callable
import dataclasses
import functools
from typing import dataclass_transform, TypeVar, overload
from collections.abc import Callable
from typing import TypeVar, dataclass_transform, overload

import jax

_T = TypeVar("_T")


def field(pytree_node=True, *, metadata=None, **kwargs):
return dataclasses.field(
metadata=(metadata or {}) | {"pytree_node": pytree_node}, **kwargs
)
return dataclasses.field(metadata=(metadata or {}) | {"pytree_node": pytree_node}, **kwargs)


@dataclass_transform(field_specifiers=(field,)) # type: ignore[literal-required]
Expand Down Expand Up @@ -123,7 +121,7 @@ class method that provides the smart constructor.
if "_flax_dataclass" in clz.__dict__:
return clz

if "frozen" not in kwargs.keys():
if "frozen" not in kwargs:
kwargs["frozen"] = True
data_clz = dataclasses.dataclass(**kwargs)(clz) # type: ignore
meta_fields = []
Expand Down
22 changes: 9 additions & 13 deletions src/jax_loop_utils/internal/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,13 @@
import contextlib
import sys
import time
from typing import Any, List, Mapping, Tuple, Union

from absl import logging
from collections.abc import Mapping
from typing import Any, Union

import jax.numpy as jnp
import numpy as np
import wrapt
from absl import logging


@contextlib.contextmanager
Expand All @@ -37,9 +37,7 @@ def log_activity(activity_name: str):
dt = time.time() - t0
exc, *_ = sys.exc_info()
if exc is not None:
logging.exception(
"%s FAILED after %.2fs with %s.", activity_name, dt, exc.__name__
)
logging.exception("%s FAILED after %.2fs with %s.", activity_name, dt, exc.__name__)
else:
logging.info("%s finished after %.2fs.", activity_name, dt)

Expand Down Expand Up @@ -68,7 +66,7 @@ def check_param(value, *, ndim=None, dtype=jnp.float32):
A `ValueError` if `value` does not match `ndim` or `dtype`, or if `value`
is not an instance of `jnp.ndarray`.
"""
if not isinstance(value, (np.ndarray, jnp.ndarray)):
if not isinstance(value, np.ndarray | jnp.ndarray):
raise ValueError(f"Expected np.array or jnp.array, got type={type(value)}")
if ndim is not None and value.ndim != ndim:
raise ValueError(f"Expected ndim={ndim}, got ndim={value.ndim}")
Expand All @@ -77,8 +75,8 @@ def check_param(value, *, ndim=None, dtype=jnp.float32):


def flatten_dict(
d: Mapping[str, Any], prefix: Tuple[str, ...] = ()
) -> List[Tuple[str, Union[int, float, str]]]:
d: Mapping[str, Any], prefix: tuple[str, ...] = ()
) -> list[tuple[str, Union[int, float, str]]]:
"""Returns a sequence of flattened (k, v) pairs for tfsummary.hparams().
Args:
Expand All @@ -94,10 +92,8 @@ def flatten_dict(
# Note `ml_collections.ConfigDict` is not (yet) a `Mapping`.
if isinstance(v, Mapping) or hasattr(v, "items"):
ret += flatten_dict(v, prefix + (k,))
elif isinstance(v, (list, tuple)):
ret += flatten_dict(
{str(idx): value for idx, value in enumerate(v)}, prefix + (k,)
)
elif isinstance(v, list | tuple):
ret += flatten_dict({str(idx): value for idx, value in enumerate(v)}, prefix + (k,))
else:
ret.append((".".join(prefix + (k,)), v if v is not None else ""))
return ret
39 changes: 16 additions & 23 deletions src/jax_loop_utils/internal/utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,10 @@
# limitations under the License.


import jax.numpy as jnp
from absl.testing import absltest

from jax_loop_utils.internal import utils
import jax.numpy as jnp


class TestError(BaseException):
Expand All @@ -27,27 +28,24 @@ class HelpersTest(absltest.TestCase):
def test_log_activity(
self,
):
with self.assertLogs() as logs:
with utils.log_activity("test_activity"):
pass
with self.assertLogs() as logs, utils.log_activity("test_activity"):
pass
self.assertLen(logs.output, 2)
self.assertEqual(logs.output[0], "INFO:absl:test_activity ...")
self.assertRegex(
logs.output[1], r"^INFO:absl:test_activity finished after \d+.\d\ds.$"
)
self.assertRegex(logs.output[1], r"^INFO:absl:test_activity finished after \d+.\d\ds.$")

def test_log_activity_fails(
self,
):
with self.assertRaises(TestError): # pylint: disable=g-error-prone-assert-raises, line-too-long
with self.assertLogs() as logs:
with utils.log_activity("test_activity"):
raise TestError()
with (
self.assertRaises(TestError),
self.assertLogs() as logs,
utils.log_activity("test_activity"),
):
raise TestError()
self.assertLen(logs.output, 2)
self.assertEqual(logs.output[0], "INFO:absl:test_activity ...")
self.assertRegex(
logs.output[1], r"^ERROR:absl:test_activity FAILED after \d+.\d\ds"
)
self.assertRegex(logs.output[1], r"^ERROR:absl:test_activity FAILED after \d+.\d\ds")

def test_logged_with(self):
@utils.logged_with("test_activity")
Expand All @@ -58,23 +56,18 @@ def test():
test()
self.assertLen(logs.output, 2)
self.assertEqual(logs.output[0], "INFO:absl:test_activity ...")
self.assertRegex(
logs.output[1], r"^INFO:absl:test_activity finished after \d+.\d\ds.$"
)
self.assertRegex(logs.output[1], r"^INFO:absl:test_activity finished after \d+.\d\ds.$")

def test_logged_with_fails(self):
@utils.logged_with("test_activity")
def test():
raise TestError()

with self.assertRaises(TestError): # pylint: disable=g-error-prone-assert-raises, line-too-long
with self.assertLogs() as logs:
test()
with self.assertRaises(TestError), self.assertLogs() as logs:
test()
self.assertLen(logs.output, 2)
self.assertEqual(logs.output[0], "INFO:absl:test_activity ...")
self.assertRegex(
logs.output[1], r"^ERROR:absl:test_activity FAILED after \d+.\d\ds"
)
self.assertRegex(logs.output[1], r"^ERROR:absl:test_activity FAILED after \d+.\d\ds")

def test_check_param(self):
a = jnp.array(0.0)
Expand Down
7 changes: 4 additions & 3 deletions src/jax_loop_utils/metric_writers/_audio_video/audio_video.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,19 +38,20 @@ def encode_video(video_array: Array, destination: io.IOBase):

if (
np.issubdtype(video_array.dtype, np.floating)
and np.all(0 <= video_array)
and np.all(video_array >= 0)
and np.all(video_array <= 1.0)
):
video_array = (video_array * 255).astype(np.uint8)
elif (
np.issubdtype(video_array.dtype, np.integer)
and np.all(0 <= video_array)
and np.all(video_array >= 0)
and np.all(video_array <= 255)
):
video_array = video_array.astype(np.uint8)
else:
raise ValueError(
f"Expected video_array to be floats in [0, 1] or ints in [0, 255], got {video_array.dtype}"
"Expected video_array to be floats in [0, 1] "
f"or ints in [0, 255], got {video_array.dtype}"
)

T, H, W, C = video_array.shape
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,7 @@ def test_encode_video_invalid_args(self):
encode_video(invalid_shape, io.BytesIO())

invalid_dtype = 2 * np.ones((10, 20, 30, 3), dtype=np.float32)
with self.assertRaisesRegex(
ValueError, r"Expected video_array to be floats in \[0, 1\]"
):
with self.assertRaisesRegex(ValueError, r"Expected video_array to be floats in \[0, 1\]"):
encode_video(invalid_dtype, io.BytesIO())

def test_encode_video_success(self):
Expand Down
Loading

0 comments on commit 2670570

Please sign in to comment.