Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve project setup and CI #21

Merged
merged 5 commits into from
Jan 13, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 30 additions & 15 deletions .github/workflows/checks.yml
Original file line number Diff line number Diff line change
Expand Up @@ -16,20 +16,34 @@ jobs:
with:
enable-cache: true
- name: pytest
run: uv run --extra test pytest --capture=no --verbose --cov --cov-report=xml --ignore-glob='jax_loop_utils/metric_writers/tf/*' --ignore-glob='jax_loop_utils/metric_writers/torch/*' --ignore-glob='jax_loop_utils/metric_writers/mlflow/*' --ignore-glob='jax_loop_utils/metric_writers/_audio_video/*' jax_loop_utils/
working-directory: src
run: |
uv sync
uv run -- pytest --capture=no --verbose --cov --cov-report=xml \
--ignore=src/jax_loop_utils/metric_writers/tf/ \
--ignore=src/jax_loop_utils/metric_writers/torch/ \
--ignore=src/jax_loop_utils/metric_writers/mlflow/ \
--ignore=src/jax_loop_utils/metric_writers/_audio_video/ \
src/jax_loop_utils/
- name: pytest tensorflow
run: uv run --extra test --extra tensorflow pytest --capture=no --verbose --cov --cov-report=xml --cov-append jax_loop_utils/metric_writers/tf
working-directory: src
run: |
uv sync --extra tensorflow
uv run -- pytest --capture=no --verbose --cov --cov-report=xml --cov-append \
src/jax_loop_utils/metric_writers/tf
- name: pytest torch
run: uv run --extra test --extra torch pytest --capture=no --verbose --cov --cov-report=xml --cov-append jax_loop_utils/metric_writers/torch
working-directory: src
run: |
uv sync --group dev-torch --extra torch
uv run -- pytest --capture=no --verbose --cov --cov-report=xml --cov-append \
src/jax_loop_utils/metric_writers/torch
- name: pytest mlflow
run: uv run --extra test --extra mlflow --extra audio-video pytest --capture=no --verbose --cov --cov-report=xml --cov-append jax_loop_utils/metric_writers/mlflow
working-directory: src
run: |
uv sync --extra mlflow --extra audio-video
uv run -- pytest --capture=no --verbose --cov --cov-report=xml --cov-append \
src/jax_loop_utils/metric_writers/mlflow
- name: pytest audio-video
run: uv run --extra test --extra audio-video pytest --capture=no --verbose --cov --cov-report=xml --cov-append jax_loop_utils/metric_writers/_audio_video
working-directory: src
run: |
uv sync --extra audio-video
uv run -- pytest --capture=no --verbose --cov --cov-report=xml --cov-append \
src/jax_loop_utils/metric_writers/_audio_video
- name: Upload coverage reports to Codecov
if: always()
uses: codecov/codecov-action@v4
Expand All @@ -45,8 +59,10 @@ jobs:
- uses: astral-sh/setup-uv@v3
with:
enable-cache: true
- name: ruff
run: uv run --with ruff ruff check src
- name: ruff format
run: uv run -- ruff format --check
- name: ruff check
run: uv run -- ruff check

pyright:
runs-on: ubuntu-24.04
Expand All @@ -58,6 +74,5 @@ jobs:
- name: uv sync
run: uv sync --all-extras
- name: pyright
# TODO: add more dependencies as we fix the violations
run: uv run pyright jax_loop_utils/metric_writers/
working-directory: src
# TODO: check more directories as we fix the violations
run: uv run -- pyright src/jax_loop_utils/metric_writers/
38 changes: 36 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,14 +29,23 @@ Homepage = "http://github.com/Astera-org/jax_loop_utils"

[project.optional-dependencies]
mlflow = ["mlflow-skinny>=2.0", "Pillow"]
pyright = ["pyright"]
# for synopsis.ipynb
synopsis = ["chex", "flax", "ipykernel", "matplotlib"]
tensorflow = ["tensorflow>=2.12"]
test = ["chex", "pytest", "pytest-cov"]
torch = ["torch>=2.0"]
audio-video = ["av>=14.0"]

[dependency-groups]
dev = [
"chex>=0.1.87",
"pyright==1.1.391",
"pytest>=8.3.4",
"pytest-cov>=6.0.0",
"ruff>=0.9.1",
]
# Development dependencies for --extra torch
dev-torch = ["tensorflow>=2.12"]

[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"
Expand All @@ -62,3 +71,28 @@ filterwarnings = [
# action:message:category:module:line
"error",
]

[tool.ruff]
line-length = 100

[tool.ruff.lint]
select = [
"A", # flake8-builtins
"B", # flake8-bugbear
"E", # pycodestyle
"F", # Pyflakes
"I", # isort
"PT", # flake8-pytest-style
"SIM", # flake8-simplify
"UP", # pyupgrade
]
ignore = [
"A003", # builtin-attribute-shadowing
"SIM108", # Use the ternary operator
"UP007", # Allow Optional[type] instead of X | Y
"PT", # TODO: Remove
"A005", # builtin-module-shadowing
]

[tool.pyright]
typeCheckingMode = "standard"
9 changes: 4 additions & 5 deletions src/jax_loop_utils/asynclib.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@
import functools
import sys
import threading
from typing import Callable, List, Optional
from collections.abc import Callable
from typing import Optional

from absl import logging

Expand Down Expand Up @@ -104,7 +105,7 @@ def has_errors(self) -> bool:
"""Returns True if there are any pending errors."""
return bool(self._errors)

def clear_errors(self) -> List[Exception]:
def clear_errors(self) -> list[Exception]:
"""Clears all pending errors and returns them as a (possibly empty) list."""
with self._errors_mutex:
errors, self._errors = self._errors, collections.deque()
Expand Down Expand Up @@ -135,9 +136,7 @@ def trap_errors(*args, **kwargs):
except Exception as e:
with self._errors_mutex:
self._errors.append(sys.exc_info())
logging.exception(
"Error in producer thread for %s", self._thread_name_prefix
)
logging.exception("Error in producer thread for %s", self._thread_name_prefix)
raise e
finally:
self._queue_length -= 1
Expand Down
1 change: 1 addition & 0 deletions src/jax_loop_utils/asynclib_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from unittest import mock

from absl.testing import absltest

from jax_loop_utils import asynclib


Expand Down
10 changes: 4 additions & 6 deletions src/jax_loop_utils/internal/flax/struct.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,20 +14,18 @@

"""Utilities for defining custom classes that can be used with jax transformations."""

from collections.abc import Callable
import dataclasses
import functools
from typing import dataclass_transform, TypeVar, overload
from collections.abc import Callable
from typing import TypeVar, dataclass_transform, overload

import jax

_T = TypeVar("_T")


def field(pytree_node=True, *, metadata=None, **kwargs):
return dataclasses.field(
metadata=(metadata or {}) | {"pytree_node": pytree_node}, **kwargs
)
return dataclasses.field(metadata=(metadata or {}) | {"pytree_node": pytree_node}, **kwargs)

Check warning on line 28 in src/jax_loop_utils/internal/flax/struct.py

View check run for this annotation

Codecov / codecov/patch

src/jax_loop_utils/internal/flax/struct.py#L28

Added line #L28 was not covered by tests


@dataclass_transform(field_specifiers=(field,)) # type: ignore[literal-required]
Expand Down Expand Up @@ -123,7 +121,7 @@
if "_flax_dataclass" in clz.__dict__:
return clz

if "frozen" not in kwargs.keys():
if "frozen" not in kwargs:
kwargs["frozen"] = True
data_clz = dataclasses.dataclass(**kwargs)(clz) # type: ignore
meta_fields = []
Expand Down
22 changes: 9 additions & 13 deletions src/jax_loop_utils/internal/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,13 @@
import contextlib
import sys
import time
from typing import Any, List, Mapping, Tuple, Union

from absl import logging
from collections.abc import Mapping
from typing import Any, Union

import jax.numpy as jnp
import numpy as np
import wrapt
from absl import logging


@contextlib.contextmanager
Expand All @@ -37,9 +37,7 @@ def log_activity(activity_name: str):
dt = time.time() - t0
exc, *_ = sys.exc_info()
if exc is not None:
logging.exception(
"%s FAILED after %.2fs with %s.", activity_name, dt, exc.__name__
)
logging.exception("%s FAILED after %.2fs with %s.", activity_name, dt, exc.__name__)
else:
logging.info("%s finished after %.2fs.", activity_name, dt)

Expand Down Expand Up @@ -68,7 +66,7 @@ def check_param(value, *, ndim=None, dtype=jnp.float32):
A `ValueError` if `value` does not match `ndim` or `dtype`, or if `value`
is not an instance of `jnp.ndarray`.
"""
if not isinstance(value, (np.ndarray, jnp.ndarray)):
if not isinstance(value, np.ndarray | jnp.ndarray):
raise ValueError(f"Expected np.array or jnp.array, got type={type(value)}")
if ndim is not None and value.ndim != ndim:
raise ValueError(f"Expected ndim={ndim}, got ndim={value.ndim}")
Expand All @@ -77,8 +75,8 @@ def check_param(value, *, ndim=None, dtype=jnp.float32):


def flatten_dict(
d: Mapping[str, Any], prefix: Tuple[str, ...] = ()
) -> List[Tuple[str, Union[int, float, str]]]:
d: Mapping[str, Any], prefix: tuple[str, ...] = ()
) -> list[tuple[str, Union[int, float, str]]]:
"""Returns a sequence of flattened (k, v) pairs for tfsummary.hparams().

Args:
Expand All @@ -94,10 +92,8 @@ def flatten_dict(
# Note `ml_collections.ConfigDict` is not (yet) a `Mapping`.
if isinstance(v, Mapping) or hasattr(v, "items"):
ret += flatten_dict(v, prefix + (k,))
elif isinstance(v, (list, tuple)):
ret += flatten_dict(
{str(idx): value for idx, value in enumerate(v)}, prefix + (k,)
)
elif isinstance(v, list | tuple):
ret += flatten_dict({str(idx): value for idx, value in enumerate(v)}, prefix + (k,))
else:
ret.append((".".join(prefix + (k,)), v if v is not None else ""))
return ret
39 changes: 16 additions & 23 deletions src/jax_loop_utils/internal/utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,10 @@
# limitations under the License.


import jax.numpy as jnp
from absl.testing import absltest

from jax_loop_utils.internal import utils
import jax.numpy as jnp


class TestError(BaseException):
Expand All @@ -27,27 +28,24 @@ class HelpersTest(absltest.TestCase):
def test_log_activity(
self,
):
with self.assertLogs() as logs:
with utils.log_activity("test_activity"):
pass
with self.assertLogs() as logs, utils.log_activity("test_activity"):
pass
self.assertLen(logs.output, 2)
self.assertEqual(logs.output[0], "INFO:absl:test_activity ...")
self.assertRegex(
logs.output[1], r"^INFO:absl:test_activity finished after \d+.\d\ds.$"
)
self.assertRegex(logs.output[1], r"^INFO:absl:test_activity finished after \d+.\d\ds.$")

def test_log_activity_fails(
self,
):
with self.assertRaises(TestError): # pylint: disable=g-error-prone-assert-raises, line-too-long
with self.assertLogs() as logs:
with utils.log_activity("test_activity"):
raise TestError()
with (
self.assertRaises(TestError),
self.assertLogs() as logs,
utils.log_activity("test_activity"),
):
raise TestError()
self.assertLen(logs.output, 2)
self.assertEqual(logs.output[0], "INFO:absl:test_activity ...")
self.assertRegex(
logs.output[1], r"^ERROR:absl:test_activity FAILED after \d+.\d\ds"
)
self.assertRegex(logs.output[1], r"^ERROR:absl:test_activity FAILED after \d+.\d\ds")

def test_logged_with(self):
@utils.logged_with("test_activity")
Expand All @@ -58,23 +56,18 @@ def test():
test()
self.assertLen(logs.output, 2)
self.assertEqual(logs.output[0], "INFO:absl:test_activity ...")
self.assertRegex(
logs.output[1], r"^INFO:absl:test_activity finished after \d+.\d\ds.$"
)
self.assertRegex(logs.output[1], r"^INFO:absl:test_activity finished after \d+.\d\ds.$")

def test_logged_with_fails(self):
@utils.logged_with("test_activity")
def test():
raise TestError()

with self.assertRaises(TestError): # pylint: disable=g-error-prone-assert-raises, line-too-long
with self.assertLogs() as logs:
test()
with self.assertRaises(TestError), self.assertLogs() as logs:
test()
self.assertLen(logs.output, 2)
self.assertEqual(logs.output[0], "INFO:absl:test_activity ...")
self.assertRegex(
logs.output[1], r"^ERROR:absl:test_activity FAILED after \d+.\d\ds"
)
self.assertRegex(logs.output[1], r"^ERROR:absl:test_activity FAILED after \d+.\d\ds")

def test_check_param(self):
a = jnp.array(0.0)
Expand Down
7 changes: 4 additions & 3 deletions src/jax_loop_utils/metric_writers/_audio_video/audio_video.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,19 +38,20 @@ def encode_video(video_array: Array, destination: io.IOBase):

if (
np.issubdtype(video_array.dtype, np.floating)
and np.all(0 <= video_array)
and np.all(video_array >= 0)
and np.all(video_array <= 1.0)
):
video_array = (video_array * 255).astype(np.uint8)
elif (
np.issubdtype(video_array.dtype, np.integer)
and np.all(0 <= video_array)
and np.all(video_array >= 0)
and np.all(video_array <= 255)
):
video_array = video_array.astype(np.uint8)
else:
raise ValueError(
f"Expected video_array to be floats in [0, 1] or ints in [0, 255], got {video_array.dtype}"
"Expected video_array to be floats in [0, 1] "
f"or ints in [0, 255], got {video_array.dtype}"
)

T, H, W, C = video_array.shape
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,7 @@ def test_encode_video_invalid_args(self):
encode_video(invalid_shape, io.BytesIO())

invalid_dtype = 2 * np.ones((10, 20, 30, 3), dtype=np.float32)
with self.assertRaisesRegex(
ValueError, r"Expected video_array to be floats in \[0, 1\]"
):
with self.assertRaisesRegex(ValueError, r"Expected video_array to be floats in \[0, 1\]"):
encode_video(invalid_dtype, io.BytesIO())

def test_encode_video_success(self):
Expand Down
Loading