diff --git a/.github/workflows/checks.yml b/.github/workflows/checks.yml index e32db68..a6c67a2 100644 --- a/.github/workflows/checks.yml +++ b/.github/workflows/checks.yml @@ -16,20 +16,34 @@ jobs: with: enable-cache: true - name: pytest - run: uv run --extra test pytest --capture=no --verbose --cov --cov-report=xml --ignore-glob='jax_loop_utils/metric_writers/tf/*' --ignore-glob='jax_loop_utils/metric_writers/torch/*' --ignore-glob='jax_loop_utils/metric_writers/mlflow/*' --ignore-glob='jax_loop_utils/metric_writers/_audio_video/*' jax_loop_utils/ - working-directory: src + run: | + uv sync + uv run -- pytest --capture=no --verbose --cov --cov-report=xml \ + --ignore=src/jax_loop_utils/metric_writers/tf/ \ + --ignore=src/jax_loop_utils/metric_writers/torch/ \ + --ignore=src/jax_loop_utils/metric_writers/mlflow/ \ + --ignore=src/jax_loop_utils/metric_writers/_audio_video/ \ + src/jax_loop_utils/ - name: pytest tensorflow - run: uv run --extra test --extra tensorflow pytest --capture=no --verbose --cov --cov-report=xml --cov-append jax_loop_utils/metric_writers/tf - working-directory: src + run: | + uv sync --extra tensorflow + uv run -- pytest --capture=no --verbose --cov --cov-report=xml --cov-append \ + src/jax_loop_utils/metric_writers/tf - name: pytest torch - run: uv run --extra test --extra torch pytest --capture=no --verbose --cov --cov-report=xml --cov-append jax_loop_utils/metric_writers/torch - working-directory: src + run: | + uv sync --group dev-torch --extra torch + uv run -- pytest --capture=no --verbose --cov --cov-report=xml --cov-append \ + src/jax_loop_utils/metric_writers/torch - name: pytest mlflow - run: uv run --extra test --extra mlflow --extra audio-video pytest --capture=no --verbose --cov --cov-report=xml --cov-append jax_loop_utils/metric_writers/mlflow - working-directory: src + run: | + uv sync --extra mlflow --extra audio-video + uv run -- pytest --capture=no --verbose --cov --cov-report=xml --cov-append \ + src/jax_loop_utils/metric_writers/mlflow - name: pytest audio-video - run: uv run --extra test --extra audio-video pytest --capture=no --verbose --cov --cov-report=xml --cov-append jax_loop_utils/metric_writers/_audio_video - working-directory: src + run: | + uv sync --extra audio-video + uv run -- pytest --capture=no --verbose --cov --cov-report=xml --cov-append \ + src/jax_loop_utils/metric_writers/_audio_video - name: Upload coverage reports to Codecov if: always() uses: codecov/codecov-action@v4 @@ -45,8 +59,10 @@ jobs: - uses: astral-sh/setup-uv@v3 with: enable-cache: true - - name: ruff - run: uv run --with ruff ruff check src + - name: ruff format + run: uv run -- ruff format --check + - name: ruff check + run: uv run -- ruff check pyright: runs-on: ubuntu-24.04 @@ -58,6 +74,5 @@ jobs: - name: uv sync run: uv sync --all-extras - name: pyright - # TODO: add more dependencies as we fix the violations - run: uv run pyright jax_loop_utils/metric_writers/ - working-directory: src + # TODO: check more directories as we fix the violations + run: uv run -- pyright src/jax_loop_utils/metric_writers/ diff --git a/pyproject.toml b/pyproject.toml index 36c937c..cb9e6c5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -29,14 +29,23 @@ Homepage = "http://github.com/Astera-org/jax_loop_utils" [project.optional-dependencies] mlflow = ["mlflow-skinny>=2.0", "Pillow"] -pyright = ["pyright"] # for synopsis.ipynb synopsis = ["chex", "flax", "ipykernel", "matplotlib"] tensorflow = ["tensorflow>=2.12"] -test = ["chex", "pytest", "pytest-cov"] torch = ["torch>=2.0"] audio-video = ["av>=14.0"] +[dependency-groups] +dev = [ + "chex>=0.1.87", + "pyright==1.1.391", + "pytest>=8.3.4", + "pytest-cov>=6.0.0", + "ruff>=0.9.1", +] +# Development dependencies for --extra torch +dev-torch = ["tensorflow>=2.12"] + [build-system] requires = ["hatchling"] build-backend = "hatchling.build" @@ -62,3 +71,28 @@ filterwarnings = [ # action:message:category:module:line "error", ] + +[tool.ruff] +line-length = 100 + +[tool.ruff.lint] +select = [ + "A", # flake8-builtins + "B", # flake8-bugbear + "E", # pycodestyle + "F", # Pyflakes + "I", # isort + "PT", # flake8-pytest-style + "SIM", # flake8-simplify + "UP", # pyupgrade +] +ignore = [ + "A003", # builtin-attribute-shadowing + "SIM108", # Use the ternary operator + "UP007", # Allow Optional[type] instead of X | Y + "PT", # TODO: Remove + "A005", # builtin-module-shadowing +] + +[tool.pyright] +typeCheckingMode = "standard" diff --git a/src/jax_loop_utils/asynclib.py b/src/jax_loop_utils/asynclib.py index b091095..554877a 100644 --- a/src/jax_loop_utils/asynclib.py +++ b/src/jax_loop_utils/asynclib.py @@ -19,7 +19,8 @@ import functools import sys import threading -from typing import Callable, List, Optional +from collections.abc import Callable +from typing import Optional from absl import logging @@ -104,7 +105,7 @@ def has_errors(self) -> bool: """Returns True if there are any pending errors.""" return bool(self._errors) - def clear_errors(self) -> List[Exception]: + def clear_errors(self) -> list[Exception]: """Clears all pending errors and returns them as a (possibly empty) list.""" with self._errors_mutex: errors, self._errors = self._errors, collections.deque() @@ -135,9 +136,7 @@ def trap_errors(*args, **kwargs): except Exception as e: with self._errors_mutex: self._errors.append(sys.exc_info()) - logging.exception( - "Error in producer thread for %s", self._thread_name_prefix - ) + logging.exception("Error in producer thread for %s", self._thread_name_prefix) raise e finally: self._queue_length -= 1 diff --git a/src/jax_loop_utils/asynclib_test.py b/src/jax_loop_utils/asynclib_test.py index 274b9c5..70d008e 100644 --- a/src/jax_loop_utils/asynclib_test.py +++ b/src/jax_loop_utils/asynclib_test.py @@ -17,6 +17,7 @@ from unittest import mock from absl.testing import absltest + from jax_loop_utils import asynclib diff --git a/src/jax_loop_utils/internal/flax/struct.py b/src/jax_loop_utils/internal/flax/struct.py index 54a1576..6cc76e0 100644 --- a/src/jax_loop_utils/internal/flax/struct.py +++ b/src/jax_loop_utils/internal/flax/struct.py @@ -14,10 +14,10 @@ """Utilities for defining custom classes that can be used with jax transformations.""" -from collections.abc import Callable import dataclasses import functools -from typing import dataclass_transform, TypeVar, overload +from collections.abc import Callable +from typing import TypeVar, dataclass_transform, overload import jax @@ -25,9 +25,7 @@ def field(pytree_node=True, *, metadata=None, **kwargs): - return dataclasses.field( - metadata=(metadata or {}) | {"pytree_node": pytree_node}, **kwargs - ) + return dataclasses.field(metadata=(metadata or {}) | {"pytree_node": pytree_node}, **kwargs) @dataclass_transform(field_specifiers=(field,)) # type: ignore[literal-required] @@ -123,7 +121,7 @@ class method that provides the smart constructor. if "_flax_dataclass" in clz.__dict__: return clz - if "frozen" not in kwargs.keys(): + if "frozen" not in kwargs: kwargs["frozen"] = True data_clz = dataclasses.dataclass(**kwargs)(clz) # type: ignore meta_fields = [] diff --git a/src/jax_loop_utils/internal/utils.py b/src/jax_loop_utils/internal/utils.py index 9db282e..1803b80 100644 --- a/src/jax_loop_utils/internal/utils.py +++ b/src/jax_loop_utils/internal/utils.py @@ -17,13 +17,13 @@ import contextlib import sys import time -from typing import Any, List, Mapping, Tuple, Union - -from absl import logging +from collections.abc import Mapping +from typing import Any, Union import jax.numpy as jnp import numpy as np import wrapt +from absl import logging @contextlib.contextmanager @@ -37,9 +37,7 @@ def log_activity(activity_name: str): dt = time.time() - t0 exc, *_ = sys.exc_info() if exc is not None: - logging.exception( - "%s FAILED after %.2fs with %s.", activity_name, dt, exc.__name__ - ) + logging.exception("%s FAILED after %.2fs with %s.", activity_name, dt, exc.__name__) else: logging.info("%s finished after %.2fs.", activity_name, dt) @@ -68,7 +66,7 @@ def check_param(value, *, ndim=None, dtype=jnp.float32): A `ValueError` if `value` does not match `ndim` or `dtype`, or if `value` is not an instance of `jnp.ndarray`. """ - if not isinstance(value, (np.ndarray, jnp.ndarray)): + if not isinstance(value, np.ndarray | jnp.ndarray): raise ValueError(f"Expected np.array or jnp.array, got type={type(value)}") if ndim is not None and value.ndim != ndim: raise ValueError(f"Expected ndim={ndim}, got ndim={value.ndim}") @@ -77,8 +75,8 @@ def check_param(value, *, ndim=None, dtype=jnp.float32): def flatten_dict( - d: Mapping[str, Any], prefix: Tuple[str, ...] = () -) -> List[Tuple[str, Union[int, float, str]]]: + d: Mapping[str, Any], prefix: tuple[str, ...] = () +) -> list[tuple[str, Union[int, float, str]]]: """Returns a sequence of flattened (k, v) pairs for tfsummary.hparams(). Args: @@ -94,10 +92,8 @@ def flatten_dict( # Note `ml_collections.ConfigDict` is not (yet) a `Mapping`. if isinstance(v, Mapping) or hasattr(v, "items"): ret += flatten_dict(v, prefix + (k,)) - elif isinstance(v, (list, tuple)): - ret += flatten_dict( - {str(idx): value for idx, value in enumerate(v)}, prefix + (k,) - ) + elif isinstance(v, list | tuple): + ret += flatten_dict({str(idx): value for idx, value in enumerate(v)}, prefix + (k,)) else: ret.append((".".join(prefix + (k,)), v if v is not None else "")) return ret diff --git a/src/jax_loop_utils/internal/utils_test.py b/src/jax_loop_utils/internal/utils_test.py index 1344fb6..67bc11d 100644 --- a/src/jax_loop_utils/internal/utils_test.py +++ b/src/jax_loop_utils/internal/utils_test.py @@ -13,9 +13,10 @@ # limitations under the License. +import jax.numpy as jnp from absl.testing import absltest + from jax_loop_utils.internal import utils -import jax.numpy as jnp class TestError(BaseException): @@ -27,27 +28,24 @@ class HelpersTest(absltest.TestCase): def test_log_activity( self, ): - with self.assertLogs() as logs: - with utils.log_activity("test_activity"): - pass + with self.assertLogs() as logs, utils.log_activity("test_activity"): + pass self.assertLen(logs.output, 2) self.assertEqual(logs.output[0], "INFO:absl:test_activity ...") - self.assertRegex( - logs.output[1], r"^INFO:absl:test_activity finished after \d+.\d\ds.$" - ) + self.assertRegex(logs.output[1], r"^INFO:absl:test_activity finished after \d+.\d\ds.$") def test_log_activity_fails( self, ): - with self.assertRaises(TestError): # pylint: disable=g-error-prone-assert-raises, line-too-long - with self.assertLogs() as logs: - with utils.log_activity("test_activity"): - raise TestError() + with ( + self.assertRaises(TestError), + self.assertLogs() as logs, + utils.log_activity("test_activity"), + ): + raise TestError() self.assertLen(logs.output, 2) self.assertEqual(logs.output[0], "INFO:absl:test_activity ...") - self.assertRegex( - logs.output[1], r"^ERROR:absl:test_activity FAILED after \d+.\d\ds" - ) + self.assertRegex(logs.output[1], r"^ERROR:absl:test_activity FAILED after \d+.\d\ds") def test_logged_with(self): @utils.logged_with("test_activity") @@ -58,23 +56,18 @@ def test(): test() self.assertLen(logs.output, 2) self.assertEqual(logs.output[0], "INFO:absl:test_activity ...") - self.assertRegex( - logs.output[1], r"^INFO:absl:test_activity finished after \d+.\d\ds.$" - ) + self.assertRegex(logs.output[1], r"^INFO:absl:test_activity finished after \d+.\d\ds.$") def test_logged_with_fails(self): @utils.logged_with("test_activity") def test(): raise TestError() - with self.assertRaises(TestError): # pylint: disable=g-error-prone-assert-raises, line-too-long - with self.assertLogs() as logs: - test() + with self.assertRaises(TestError), self.assertLogs() as logs: + test() self.assertLen(logs.output, 2) self.assertEqual(logs.output[0], "INFO:absl:test_activity ...") - self.assertRegex( - logs.output[1], r"^ERROR:absl:test_activity FAILED after \d+.\d\ds" - ) + self.assertRegex(logs.output[1], r"^ERROR:absl:test_activity FAILED after \d+.\d\ds") def test_check_param(self): a = jnp.array(0.0) diff --git a/src/jax_loop_utils/metric_writers/_audio_video/audio_video.py b/src/jax_loop_utils/metric_writers/_audio_video/audio_video.py index b0fb041..118564f 100644 --- a/src/jax_loop_utils/metric_writers/_audio_video/audio_video.py +++ b/src/jax_loop_utils/metric_writers/_audio_video/audio_video.py @@ -38,19 +38,20 @@ def encode_video(video_array: Array, destination: io.IOBase): if ( np.issubdtype(video_array.dtype, np.floating) - and np.all(0 <= video_array) + and np.all(video_array >= 0) and np.all(video_array <= 1.0) ): video_array = (video_array * 255).astype(np.uint8) elif ( np.issubdtype(video_array.dtype, np.integer) - and np.all(0 <= video_array) + and np.all(video_array >= 0) and np.all(video_array <= 255) ): video_array = video_array.astype(np.uint8) else: raise ValueError( - f"Expected video_array to be floats in [0, 1] or ints in [0, 255], got {video_array.dtype}" + "Expected video_array to be floats in [0, 1] " + f"or ints in [0, 255], got {video_array.dtype}" ) T, H, W, C = video_array.shape diff --git a/src/jax_loop_utils/metric_writers/_audio_video/audio_video_test.py b/src/jax_loop_utils/metric_writers/_audio_video/audio_video_test.py index dfdedd6..a3192b8 100644 --- a/src/jax_loop_utils/metric_writers/_audio_video/audio_video_test.py +++ b/src/jax_loop_utils/metric_writers/_audio_video/audio_video_test.py @@ -23,9 +23,7 @@ def test_encode_video_invalid_args(self): encode_video(invalid_shape, io.BytesIO()) invalid_dtype = 2 * np.ones((10, 20, 30, 3), dtype=np.float32) - with self.assertRaisesRegex( - ValueError, r"Expected video_array to be floats in \[0, 1\]" - ): + with self.assertRaisesRegex(ValueError, r"Expected video_array to be floats in \[0, 1\]"): encode_video(invalid_dtype, io.BytesIO()) def test_encode_video_success(self): diff --git a/src/jax_loop_utils/metric_writers/async_writer.py b/src/jax_loop_utils/metric_writers/async_writer.py index 89dd8b0..9e00b38 100644 --- a/src/jax_loop_utils/metric_writers/async_writer.py +++ b/src/jax_loop_utils/metric_writers/async_writer.py @@ -63,17 +63,13 @@ class AsyncWriter(interface.MetricWriter): processes. """ - def __init__( - self, writer: interface.MetricWriter, *, num_workers: Optional[int] = 1 - ): + def __init__(self, writer: interface.MetricWriter, *, num_workers: Optional[int] = 1): super().__init__() self._writer = writer # By default, we have a thread pool with a single worker to ensure that # calls to the function are run in order (but in a background thread). self._num_workers = num_workers - self._pool = asynclib.Pool( - thread_name_prefix="AsyncWriter", max_workers=num_workers - ) + self._pool = asynclib.Pool(thread_name_prefix="AsyncWriter", max_workers=num_workers) @_wrap_exceptions # type: ignore[call-arg] def write_scalars(self, step: int, scalars: Mapping[str, Scalar]): @@ -89,9 +85,7 @@ def write_videos(self, step: int, videos: Mapping[str, Array]): @_wrap_exceptions # type: ignore[call-arg] def write_audios(self, step: int, audios: Mapping[str, Array], *, sample_rate: int): - self._pool(self._writer.write_audios)( - step=step, audios=audios, sample_rate=sample_rate - ) + self._pool(self._writer.write_audios)(step=step, audios=audios, sample_rate=sample_rate) @_wrap_exceptions # type: ignore[call-arg] def write_texts(self, step: int, texts: Mapping[str, str]): @@ -104,9 +98,7 @@ def write_histograms( arrays: Mapping[str, Array], num_buckets: Optional[Mapping[str, int]] = None, ): - self._pool(self._writer.write_histograms)( - step=step, arrays=arrays, num_buckets=num_buckets - ) + self._pool(self._writer.write_histograms)(step=step, arrays=arrays, num_buckets=num_buckets) @_wrap_exceptions # type: ignore[call-arg] def write_hparams(self, hparams: Mapping[str, Any]): diff --git a/src/jax_loop_utils/metric_writers/interface.py b/src/jax_loop_utils/metric_writers/interface.py index e218ddb..8fc0cca 100644 --- a/src/jax_loop_utils/metric_writers/interface.py +++ b/src/jax_loop_utils/metric_writers/interface.py @@ -148,11 +148,11 @@ def write_hparams(self, hparams: Mapping[str, Any]): hparams: Flat mapping from hyper parameter name to value. """ - def flush(self): + def flush(self): # noqa: B027 """Tells the MetricWriter to write out any cached values.""" pass - def close(self): + def close(self): # noqa: B027 """Flushes and closes the MetricWriter. Calling any method on MetricWriter after MetricWriter.close() diff --git a/src/jax_loop_utils/metric_writers/logging_writer.py b/src/jax_loop_utils/metric_writers/logging_writer.py index af23c61..7ccecd7 100644 --- a/src/jax_loop_utils/metric_writers/logging_writer.py +++ b/src/jax_loop_utils/metric_writers/logging_writer.py @@ -136,9 +136,7 @@ def _compute_histogram_as_tf( histo = np.asarray([array.size], dtype=np.int64) bins = np.asarray([range_max - 0.5, range_max + 0.5], dtype=np.float64) else: - histo, bins = np.histogram( - array, bins=num_buckets, range=(range_min, range_max) - ) + histo, bins = np.histogram(array, bins=num_buckets, range=(range_min, range_max)) bins = np.asarray(bins, dtype=np.float64) return histo, bins @@ -146,10 +144,7 @@ def _compute_histogram_as_tf( def _get_histogram_as_string(histo: np.ndarray, bins: np.ndarray): # First items are right-open (i.e. [a, b)). - items = [ - f"[{bins[i]:.3g}, {bins[i+1]:.3g}): {count}" - for i, count in enumerate(histo[:-1]) - ] + items = [f"[{bins[i]:.3g}, {bins[i + 1]:.3g}): {count}" for i, count in enumerate(histo[:-1])] # Last item is right-closed (i.e. [a, b]). items.append(f"[{bins[-2]:.3g}, {bins[-1]:.3g}]: {histo[-1]}") return ", ".join(items) diff --git a/src/jax_loop_utils/metric_writers/logging_writer_test.py b/src/jax_loop_utils/metric_writers/logging_writer_test.py index a22b39c..a83945c 100644 --- a/src/jax_loop_utils/metric_writers/logging_writer_test.py +++ b/src/jax_loop_utils/metric_writers/logging_writer_test.py @@ -68,9 +68,7 @@ def test_write_histogram(self): ) # Note: There are 31 distinct values [0, 1, ..., 30], and 30 buckets by # default. Last bucket gets 2 values. - expected_histo_b = ", ".join( - [f"[{i}, {i + 1}): 1" for i in range(29)] + ["[29, 30]: 2"] - ) + expected_histo_b = ", ".join([f"[{i}, {i + 1}): 1" for i in range(29)] + ["[29, 30]: 2"]) self.assertEqual( logs.output, [ @@ -112,7 +110,7 @@ def test_collection(self): "INFO:absl:[0] collection=train a=3, b=0.15", "INFO:absl:[4] collection=train Got images: {'input_images': (2, 28, 28, 3)}.", "INFO:absl:[4] collection=train Got texts: {'samples': 'bla'}.", - "INFO:absl:[4] collection=train Histogram for 'a' = {[-0.1, 0.1): 1, [0.1, 0.3]: 2}", + "INFO:absl:[4] collection=train Histogram for 'a' = {[-0.1, 0.1): 1, [0.1, 0.3]: 2}", # noqa: E501 "INFO:absl:[Hyperparameters] collection=train {'learning_rate': 0.1}", ], ) diff --git a/src/jax_loop_utils/metric_writers/memory_writer.py b/src/jax_loop_utils/metric_writers/memory_writer.py index 0fc9a40..39fbf32 100644 --- a/src/jax_loop_utils/metric_writers/memory_writer.py +++ b/src/jax_loop_utils/metric_writers/memory_writer.py @@ -1,6 +1,7 @@ +from collections import OrderedDict from collections.abc import Mapping from dataclasses import dataclass -from typing import Any, Optional, OrderedDict, TypeVar +from typing import Any, Optional, TypeVar import jax @@ -60,9 +61,7 @@ def write_videos(self, step: int, videos: Mapping[str, Array]): self.videos[step] = videos def write_audios(self, step: int, audios: Mapping[str, Array], *, sample_rate: int): - self.audios[step] = MemoryWriterAudioEntry( - audios=audios, sample_rate=sample_rate - ) + self.audios[step] = MemoryWriterAudioEntry(audios=audios, sample_rate=sample_rate) def write_texts(self, step: int, texts: Mapping[str, str]): self.texts[step] = texts @@ -73,9 +72,7 @@ def write_histograms( arrays: Mapping[str, Array], num_buckets: Optional[Mapping[str, int]] = None, ): - self.histograms[step] = MemoryWriterHistogramEntry( - arrays=arrays, num_buckets=num_buckets - ) + self.histograms[step] = MemoryWriterHistogramEntry(arrays=arrays, num_buckets=num_buckets) def write_hparams(self, hparams: Mapping[str, Any]): assert self.hparams is None, "Hyperparameters can only be set once." diff --git a/src/jax_loop_utils/metric_writers/memory_writer_test.py b/src/jax_loop_utils/metric_writers/memory_writer_test.py index 5a3bd8f..93e265c 100644 --- a/src/jax_loop_utils/metric_writers/memory_writer_test.py +++ b/src/jax_loop_utils/metric_writers/memory_writer_test.py @@ -22,9 +22,7 @@ def test_write_scalars(): def test_write_scalars_fails_when_using_same_step(): writer = MemoryWriter() writer.write_scalars(0, {}) - with pytest.raises( - ValueError, match=r"Step must be greater than the last inserted step\." - ): + with pytest.raises(ValueError, match=r"Step must be greater than the last inserted step\."): writer.write_scalars(0, {}) @@ -90,9 +88,7 @@ def test_write_histograms(): dict(writer.histograms), { 5: MemoryWriterHistogramEntry(arrays={"a": _histogram()}, num_buckets=None), - 6: MemoryWriterHistogramEntry( - arrays={"b": _histogram()}, num_buckets={"b": 10} - ), + 6: MemoryWriterHistogramEntry(arrays={"b": _histogram()}, num_buckets={"b": 10}), }, strict=True, ) diff --git a/src/jax_loop_utils/metric_writers/mlflow/metric_writer.py b/src/jax_loop_utils/metric_writers/mlflow/metric_writer.py index 25c2b44..f3aba7e 100644 --- a/src/jax_loop_utils/metric_writers/mlflow/metric_writer.py +++ b/src/jax_loop_utils/metric_writers/mlflow/metric_writer.py @@ -74,9 +74,9 @@ def __init__( experiment = self._client.get_experiment_by_name(experiment_name) if not experiment: raise RuntimeError( - "Failed to get, then failed to create, then failed to get " - f"again experiment '{experiment_name}'" - ) + "Failed to get, then failed to create, " + f"then failed to get again experiment '{experiment_name}'" + ) from None experiment_id = experiment.experiment_id self._run_id = self._client.create_run( experiment_id=experiment_id, run_name=run_name @@ -86,8 +86,7 @@ def write_scalars(self, step: int, scalars: Mapping[str, Scalar]): """Write scalar metrics to MLflow.""" timestamp = int(time.time() * 1000) metrics_list = [ - mlflow.entities.Metric(k, float(v), timestamp, step) - for k, v in scalars.items() + mlflow.entities.Metric(k, float(v), timestamp, step) for k, v in scalars.items() ] self._client.log_batch(self._run_id, metrics=metrics_list, synchronous=False) @@ -139,9 +138,7 @@ def write_videos(self, step: int, videos: Mapping[str, Array]): pool.close() shutil.rmtree(temp_dir) - def _encode_and_log_video( - self, temp_dir: pathlib.Path, rel_path: str, video_array: Array - ): + def _encode_and_log_video(self, temp_dir: pathlib.Path, rel_path: str, video_array: Array): temp_path = temp_dir / rel_path # handle keys with slashes if not temp_path.parent.exists(): @@ -187,9 +184,7 @@ def write_histograms( def write_hparams(self, hparams: Mapping[str, Any]): """Log hyperparameters to MLflow.""" - params = [ - mlflow.entities.Param(key, str(value)) for key, value in hparams.items() - ] + params = [mlflow.entities.Param(key, str(value)) for key, value in hparams.items()] self._client.log_batch(self._run_id, params=params, synchronous=False) def flush(self): diff --git a/src/jax_loop_utils/metric_writers/mlflow/metric_writer_test.py b/src/jax_loop_utils/metric_writers/mlflow/metric_writer_test.py index 16cf0a5..803d97e 100644 --- a/src/jax_loop_utils/metric_writers/mlflow/metric_writer_test.py +++ b/src/jax_loop_utils/metric_writers/mlflow/metric_writer_test.py @@ -51,9 +51,7 @@ def test_write_scalars(self): run = runs[0] for metric_key in ("a", "b"): self.assertIn(metric_key, run.data.metrics) - self.assertEqual( - run.data.metrics[metric_key], seq_of_scalars[-1][metric_key] - ) + self.assertEqual(run.data.metrics[metric_key], seq_of_scalars[-1][metric_key]) # constant defined in mlflow.entities.RunStatus self.assertEqual(run.info.status, "RUNNING") writer.close() @@ -103,10 +101,8 @@ def test_write_texts(self): artifact_paths = [artifact.path for artifact in artifacts] self.assertGreaterEqual(len(artifact_paths), 1) self.assertIn("test_text_step_0.txt", artifact_paths) - local_path = writer._client.download_artifacts( - run.info.run_id, "test_text_step_0.txt" - ) - with open(local_path, "r") as f: + local_path = writer._client.download_artifacts(run.info.run_id, "test_text_step_0.txt") + with open(local_path) as f: content = f.read() self.assertEqual(content, test_text) @@ -153,9 +149,7 @@ def test_write_videos(self): artifacts_videos = writer._client.list_artifacts(run.info.run_id, "videos") self.assertEqual(len(artifacts_videos), 2) sorted_artifacts_videos = sorted(artifacts_videos, key=lambda x: x.path) - self.assertEqual( - sorted_artifacts_videos[0].path, "videos/noise_1_000000000.mp4" - ) + self.assertEqual(sorted_artifacts_videos[0].path, "videos/noise_1_000000000.mp4") self.assertFalse(sorted_artifacts_videos[0].is_dir) artifacts_zzz = writer._client.list_artifacts(run.info.run_id, "videos/zzz") diff --git a/src/jax_loop_utils/metric_writers/prefix_suffix_writer.py b/src/jax_loop_utils/metric_writers/prefix_suffix_writer.py index cef4e5e..dbdb5c5 100644 --- a/src/jax_loop_utils/metric_writers/prefix_suffix_writer.py +++ b/src/jax_loop_utils/metric_writers/prefix_suffix_writer.py @@ -1,6 +1,7 @@ """Writer that adds prefix and suffix to metric keys.""" -from typing import Any, Mapping, Optional +from collections.abc import Mapping +from typing import Any, Optional from jax_loop_utils.metric_writers import interface @@ -27,9 +28,7 @@ def __init__( def _transform_keys(self, data: Mapping[str, Any]) -> dict[str, Any]: """Add prefix and suffix to all keys in the mapping.""" - return { - f"{self._prefix}{key}{self._suffix}": value for key, value in data.items() - } + return {f"{self._prefix}{key}{self._suffix}": value for key, value in data.items()} def write_scalars(self, step: int, scalars: Mapping[str, interface.Scalar]): self._writer.write_scalars(step, self._transform_keys(scalars)) @@ -40,12 +39,8 @@ def write_images(self, step: int, images: Mapping[str, interface.Array]): def write_videos(self, step: int, videos: Mapping[str, interface.Array]): self._writer.write_videos(step, self._transform_keys(videos)) - def write_audios( - self, step: int, audios: Mapping[str, interface.Array], *, sample_rate: int - ): - self._writer.write_audios( - step, self._transform_keys(audios), sample_rate=sample_rate - ) + def write_audios(self, step: int, audios: Mapping[str, interface.Array], *, sample_rate: int): + self._writer.write_audios(step, self._transform_keys(audios), sample_rate=sample_rate) def write_texts(self, step: int, texts: Mapping[str, str]): self._writer.write_texts(step, self._transform_keys(texts)) diff --git a/src/jax_loop_utils/metric_writers/prefix_suffix_writer_test.py b/src/jax_loop_utils/metric_writers/prefix_suffix_writer_test.py index b255b10..cd7d1fe 100644 --- a/src/jax_loop_utils/metric_writers/prefix_suffix_writer_test.py +++ b/src/jax_loop_utils/metric_writers/prefix_suffix_writer_test.py @@ -25,9 +25,7 @@ def test_write_scalars(self): def test_write_images(self): image = np.zeros((2, 2, 3)) self.writer.write_images(0, {"image": image}) - self.assertEqual( - list(self.memory_writer.images[0].keys()), ["prefix/image/suffix"] - ) + self.assertEqual(list(self.memory_writer.images[0].keys()), ["prefix/image/suffix"]) def test_write_texts(self): self.writer.write_texts(0, {"text": "hello"}) @@ -54,16 +52,12 @@ def test_empty_prefix_suffix(self): def test_write_videos(self): video = np.zeros((10, 32, 32, 3)) # Simple video array with 10 frames self.writer.write_videos(0, {"video": video}) - self.assertEqual( - list(self.memory_writer.videos[0].keys()), ["prefix/video/suffix"] - ) + self.assertEqual(list(self.memory_writer.videos[0].keys()), ["prefix/video/suffix"]) def test_write_audios(self): audio = np.zeros((16000,)) # 1 second of audio at 16kHz self.writer.write_audios(0, {"audio": audio}, sample_rate=16000) - self.assertEqual( - list(self.memory_writer.audios[0].audios.keys()), ["prefix/audio/suffix"] - ) + self.assertEqual(list(self.memory_writer.audios[0].audios.keys()), ["prefix/audio/suffix"]) def test_close(self): with mock.patch.object(self.memory_writer, "close") as mock_close: diff --git a/src/jax_loop_utils/metric_writers/tf/summary_writer.py b/src/jax_loop_utils/metric_writers/tf/summary_writer.py index 59ffc22..314d018 100644 --- a/src/jax_loop_utils/metric_writers/tf/summary_writer.py +++ b/src/jax_loop_utils/metric_writers/tf/summary_writer.py @@ -58,9 +58,7 @@ def write_images(self, step: int, images: Mapping[str, Array]): tf.summary.image(key, value, step=step, max_outputs=value.shape[0]) def write_videos(self, step: int, videos: Mapping[str, Array]): - logging.log_first_n( - logging.WARNING, "SummaryWriter does not support writing videos.", 1 - ) + logging.log_first_n(logging.WARNING, "SummaryWriter does not support writing videos.", 1) def write_audios(self, step: int, audios: Mapping[str, Array], *, sample_rate: int): with self._summary_writer.as_default(): diff --git a/src/jax_loop_utils/metric_writers/tf/summary_writer_test.py b/src/jax_loop_utils/metric_writers/tf/summary_writer_test.py index 660a6b7..c69078f 100644 --- a/src/jax_loop_utils/metric_writers/tf/summary_writer_test.py +++ b/src/jax_loop_utils/metric_writers/tf/summary_writer_test.py @@ -53,10 +53,7 @@ def _load_histograms_data(logdir): current_steps + [event.step], current_tensors + [tf.make_ndarray(value.tensor)], ) - return { - tag: (np.stack(steps), np.stack(tensors)) - for tag, (steps, tensors) in data.items() - } + return {tag: (np.stack(steps), np.stack(tensors)) for tag, (steps, tensors) in data.items()} def _load_scalars_data(logdir: str): diff --git a/src/jax_loop_utils/metric_writers/torch/tensorboard_writer.py b/src/jax_loop_utils/metric_writers/torch/tensorboard_writer.py index 2113347..0b58035 100644 --- a/src/jax_loop_utils/metric_writers/torch/tensorboard_writer.py +++ b/src/jax_loop_utils/metric_writers/torch/tensorboard_writer.py @@ -54,14 +54,10 @@ def write_videos(self, step: int, videos: Mapping[str, Array]): def write_audios(self, step: int, audios: Mapping[str, Array], *, sample_rate: int): for key, value in audios.items(): - self._writer.add_audio( - key, value, global_step=step, sample_rate=sample_rate - ) + self._writer.add_audio(key, value, global_step=step, sample_rate=sample_rate) def write_texts(self, step: int, texts: Mapping[str, str]): - raise NotImplementedError( - "torch.TensorboardWriter does not support writing texts." - ) + raise NotImplementedError("torch.TensorboardWriter does not support writing texts.") def write_histograms( self, @@ -71,9 +67,7 @@ def write_histograms( ): for tag, values in arrays.items(): bins = None if num_buckets is None else num_buckets.get(tag) - self._writer.add_histogram( - tag, values, global_step=step, bins="auto", max_bins=bins - ) + self._writer.add_histogram(tag, values, global_step=step, bins="auto", max_bins=bins) def write_hparams(self, hparams: Mapping[str, Any]): self._writer.add_hparams(hparams, {}) diff --git a/src/jax_loop_utils/metric_writers/torch/tensorboard_writer_test.py b/src/jax_loop_utils/metric_writers/torch/tensorboard_writer_test.py index 74aa0ca..76166c6 100644 --- a/src/jax_loop_utils/metric_writers/torch/tensorboard_writer_test.py +++ b/src/jax_loop_utils/metric_writers/torch/tensorboard_writer_test.py @@ -16,7 +16,7 @@ import collections import os -from typing import Any, Dict +from typing import Any import numpy as np import tensorflow as tf @@ -38,7 +38,7 @@ def _load_scalars_data(logdir: str): return data -def _load_histograms_data(logdir: str) -> Dict[int, Dict[str, Any]]: +def _load_histograms_data(logdir: str) -> dict[int, dict[str, Any]]: """Loads histograms summaries from events in a logdir. Args: diff --git a/src/jax_loop_utils/metric_writers/utils.py b/src/jax_loop_utils/metric_writers/utils.py index 9df3923..905d646 100644 --- a/src/jax_loop_utils/metric_writers/utils.py +++ b/src/jax_loop_utils/metric_writers/utils.py @@ -22,7 +22,8 @@ # pylint: disable=g-importing-member import collections -from typing import Any, Mapping, Union +from collections.abc import Mapping +from typing import Any, Union import jax.numpy as jnp import numpy as np @@ -35,9 +36,9 @@ def _is_scalar(value: Any) -> bool: - if isinstance(value, values.Scalar) or isinstance(value, (int, float, np.number)): + if isinstance(value, values.Scalar | int | float | np.number): return True - if isinstance(value, (np.ndarray, jnp.ndarray)): + if isinstance(value, np.ndarray | jnp.ndarray): return value.ndim == 0 or value.size <= 1 return False diff --git a/src/jax_loop_utils/metric_writers/utils_test.py b/src/jax_loop_utils/metric_writers/utils_test.py index 2cc8550..551fe5b 100644 --- a/src/jax_loop_utils/metric_writers/utils_test.py +++ b/src/jax_loop_utils/metric_writers/utils_test.py @@ -77,7 +77,7 @@ def _to_list_of_dicts(d): return [{k: v} for k, v in d.items()] -class ONEOF(object): +class ONEOF: """ONEOF(options_list) check value in options_list.""" def __init__(self, container): @@ -104,27 +104,19 @@ def test_write(self): num_buckets = 4 sample_rate = 10 scalar_metrics = { - "loss": jax_loop_utils.metrics.Average.from_model_output( - jnp.asarray([1, 2, 3]) - ), - "accuracy": jax_loop_utils.metrics.LastValue.from_model_output( - jnp.asarray([5]) - ), + "loss": jax_loop_utils.metrics.Average.from_model_output(jnp.asarray([1, 2, 3])), + "accuracy": jax_loop_utils.metrics.LastValue.from_model_output(jnp.asarray([5])), } image_metrics = { "image": ImageMetric(jnp.asarray([[4, 5], [1, 2]])), } histogram_metrics = { "hist": HistogramMetric(value=jnp.asarray([7, 8]), num_buckets=num_buckets), - "hist2": HistogramMetric( - value=jnp.asarray([9, 10]), num_buckets=num_buckets - ), + "hist2": HistogramMetric(value=jnp.asarray([9, 10]), num_buckets=num_buckets), } audio_metrics = { "audio": AudioMetric(value=jnp.asarray([1, 5]), sample_rate=sample_rate), - "audio2": AudioMetric( - value=jnp.asarray([1, 5]), sample_rate=sample_rate + 2 - ), + "audio2": AudioMetric(value=jnp.asarray([1, 5]), sample_rate=sample_rate + 2), } text_metrics = { "text": TextMetric(value="hello"), diff --git a/src/jax_loop_utils/metrics.py b/src/jax_loop_utils/metrics.py index 7c63c6b..ecf9dee 100644 --- a/src/jax_loop_utils/metrics.py +++ b/src/jax_loop_utils/metrics.py @@ -389,13 +389,9 @@ def empty(cls) -> CollectingMetric: return cls(values={}) def merge(self, other: CollectingMetric) -> CollectingMetric: - values = { - name: (*value, *other.values[name]) for name, value in self.values.items() - } + values = {name: (*value, *other.values[name]) for name, value in self.values.items()} if any(isinstance(vv, jax.core.Tracer) for v in values.values() for vv in v): # pylint: disable=g-complex-comprehension - raise RuntimeError( - "Tracer detected! CollectingMetric cannot be JIT compiled." - ) + raise RuntimeError("Tracer detected! CollectingMetric cannot be JIT compiled.") if other.values and not self.values: return other if self.values and not other.values: @@ -539,10 +535,7 @@ class MyMetrics(metrics.Collection): def empty(cls: type[C]) -> C: return cls( _reduction_counter=_ReductionCounter(jnp.array(1, dtype=jnp.int32)), - **{ - metric_name: metric.empty() - for metric_name, metric in cls.__annotations__.items() - }, + **{metric_name: metric.empty() for metric_name, metric in cls.__annotations__.items()}, ) @classmethod @@ -584,9 +577,7 @@ def gather_from_model_output(cls: type[C], axis_name="batch", **kwargs) -> C: A metric collection from provided `kwargs` model outputs that contains metrics for all devices across all hosts. """ - return jax.lax.all_gather( - cls._from_model_output(**kwargs), axis_name=axis_name - ).reduce() + return jax.lax.all_gather(cls._from_model_output(**kwargs), axis_name=axis_name).reduce() def merge(self: C, other: C) -> C: """Returns `Collection` that is the accumulation of `self` and `other`.""" @@ -625,10 +616,7 @@ def reduce(self: C) -> C: Reduced collection. """ return type(self)( - **{ - metric_name: metric.reduce() - for metric_name, metric in vars(self).items() - } + **{metric_name: metric.reduce() for metric_name, metric in vars(self).items()} ) def compute(self) -> dict[str, jnp.ndarray]: @@ -714,8 +702,7 @@ def __init__( count = count if count is not _default else jnp.array(1, dtype=jnp.int32) if (value is _default) == (total is _default): raise ValueError( - "Exactly one of 'total' and 'value' should be passed. " - f"Got {total}, {value}" + f"Exactly one of 'total' and 'value' should be passed. Got {total}, {value}" ) if total is _default: total = value * count @@ -849,9 +836,7 @@ def empty(cls) -> Std: ) @classmethod - def from_model_output( - cls, values: jnp.ndarray, mask: jnp.ndarray | None = None, **_ - ) -> Std: + def from_model_output(cls, values: jnp.ndarray, mask: jnp.ndarray | None = None, **_) -> Std: values, mask = _broadcast_masks(values, mask) return cls( total=jnp.where(mask, values, jnp.zeros_like(values)).sum(), @@ -897,13 +882,11 @@ class Accuracy(Average): """ @classmethod - def from_model_output( - cls, *, logits: jnp.ndarray, labels: jnp.ndarray, **kwargs - ) -> Accuracy: + def from_model_output(cls, *, logits: jnp.ndarray, labels: jnp.ndarray, **kwargs) -> Accuracy: if logits.ndim != labels.ndim + 1 or labels.dtype != jnp.int32: raise ValueError( - f"Expected labels.dtype==jnp.int32 and logits.ndim={logits.ndim}==" - f"labels.ndim+1={labels.ndim + 1}" + "Expected labels.dtype==jnp.int32 and " + f"logits.ndim={logits.ndim}==labels.ndim+1={labels.ndim + 1}" ) metric = super().from_model_output( values=(logits.argmax(axis=-1) == labels).astype(jnp.float32), **kwargs diff --git a/src/jax_loop_utils/metrics_test.py b/src/jax_loop_utils/metrics_test.py index b24fb6e..77ea9fd 100644 --- a/src/jax_loop_utils/metrics_test.py +++ b/src/jax_loop_utils/metrics_test.py @@ -17,21 +17,18 @@ import functools from unittest import mock -from absl.testing import absltest -from absl.testing import parameterized import chex -from jax_loop_utils import asynclib -from jax_loop_utils import metrics -from jax_loop_utils.internal import flax import jax import jax.numpy as jnp import numpy as np +from absl.testing import absltest, parameterized + +from jax_loop_utils import asynclib, metrics +from jax_loop_utils.internal import flax @flax.struct.dataclass -class CollectingMetricAccuracy( - metrics.CollectingMetric.from_outputs(("logits", "labels")) -): +class CollectingMetricAccuracy(metrics.CollectingMetric.from_outputs(("logits", "labels"))): def compute(self): values = super().compute() logits = values["logits"] @@ -83,7 +80,7 @@ def setUp(self): ) self.model_outputs_masked = tuple( dict(mask=mask, **model_output) - for mask, model_output in zip(masks, self.model_outputs) + for mask, model_output in zip(masks, self.model_outputs, strict=False) ) self.count = 4 @@ -130,12 +127,9 @@ def make_compute_metric(self, metric_class, reduce, jit=True): def compute_metric(model_outputs): if reduce: metric_list = [ - metric_class.from_model_output(**model_output) - for model_output in model_outputs + metric_class.from_model_output(**model_output) for model_output in model_outputs ] - metric_stacked = jax.tree_util.tree_map( - lambda *args: jnp.stack(args), *metric_list - ) + metric_stacked = jax.tree_util.tree_map(lambda *args: jnp.stack(args), *metric_list) metric = metric_stacked.reduce() else: metric = metric_class.empty() @@ -151,31 +145,17 @@ def compute_metric(model_outputs): def test_metric_last_value_reduce(self): metric1 = metrics.LastValue.from_model_output(jnp.array([1, 2])) metric2 = metrics.LastValue.from_model_output(jnp.array([3, 4])) - metric3 = metrics.LastValue.from_model_output( - jnp.array([3, 4]), jnp.array([0, 0]) - ) - metric12 = jax.tree_util.tree_map( - lambda *args: jnp.stack(args), metric1, metric2 - ) - metric21 = jax.tree_util.tree_map( - lambda *args: jnp.stack(args), metric2, metric1 - ) + metric3 = metrics.LastValue.from_model_output(jnp.array([3, 4]), jnp.array([0, 0])) + metric12 = jax.tree_util.tree_map(lambda *args: jnp.stack(args), metric1, metric2) + metric21 = jax.tree_util.tree_map(lambda *args: jnp.stack(args), metric2, metric1) self.assertEqual(metric12.reduce().value, 2.5) - chex.assert_trees_all_equal( - metric12.reduce().compute(), metric21.reduce().compute() - ) + chex.assert_trees_all_equal(metric12.reduce().compute(), metric21.reduce().compute()) - metric13 = jax.tree_util.tree_map( - lambda *args: jnp.stack(args), metric1, metric3 - ) - metric31 = jax.tree_util.tree_map( - lambda *args: jnp.stack(args), metric1, metric3 - ) + metric13 = jax.tree_util.tree_map(lambda *args: jnp.stack(args), metric1, metric3) + metric31 = jax.tree_util.tree_map(lambda *args: jnp.stack(args), metric1, metric3) self.assertEqual(metric13.reduce().value, 1.5) - chex.assert_trees_all_equal( - metric13.reduce().compute(), metric31.reduce().compute() - ) + chex.assert_trees_all_equal(metric13.reduce().compute(), metric31.reduce().compute()) def test_metric_last_value(self): metric0 = metrics.LastValue.from_model_output(jnp.array([])) @@ -307,9 +287,7 @@ def rename_mask(**kwargs): ) def test_merge_asserts_shape(self, metric_cls): metric1 = metric_cls.from_model_output(jnp.arange(3.0)) - metric2 = jax.tree_util.tree_map( - lambda *args: jnp.stack(args), metric1, metric1 - ) + metric2 = jax.tree_util.tree_map(lambda *args: jnp.stack(args), metric1, metric1) with self.assertRaisesRegex(ValueError, r"^Expected same shape"): metric1.merge(metric2) @@ -325,9 +303,7 @@ def test_accuracy(self, reduce): def test_last_value_asserts_shape(self): metric1 = metrics.LastValue.from_model_output(jnp.arange(3.0)) - metric2 = jax.tree_util.tree_map( - lambda *args: jnp.stack(args), metric1, metric1 - ) + metric2 = jax.tree_util.tree_map(lambda *args: jnp.stack(args), metric1, metric1) with self.assertRaisesRegex(ValueError, r"^Expected same shape"): metric1.merge(metric2) @@ -343,9 +319,9 @@ def test_loss_average(self, reduce): self.model_outputs_stacked["loss"].mean(), ) chex.assert_trees_all_close( - self.make_compute_metric( - metrics.Average.from_output("example_loss"), reduce - )(self.model_outputs_masked), + self.make_compute_metric(metrics.Average.from_output("example_loss"), reduce)( + self.model_outputs_masked + ), self.model_outputs_stacked["loss"].mean(), ) @@ -446,8 +422,7 @@ def compute_collection(model_outputs): def test_collection_gather(self, masked, all_gather_mock): model_outputs = self.model_outputs_masked if masked else self.model_outputs collections = [ - Collection.single_from_model_output(**model_output) - for model_output in (model_outputs) + Collection.single_from_model_output(**model_output) for model_output in (model_outputs) ] all_gather_mock.return_value = jax.tree_util.tree_map( lambda *args: jnp.stack(args), *collections @@ -474,9 +449,7 @@ def compute_collection(model_outputs): if jax.local_device_count() > 1: chex.assert_trees_all_close( compute_collection( - self.model_outputs_masked_stacked - if masked - else self.model_outputs_stacked + self.model_outputs_masked_stacked if masked else self.model_outputs_stacked ) .unreplicate() .compute(), @@ -494,13 +467,9 @@ def test_collection_asserts_replication(self): def test_collecting_metric(self): metric_class = metrics.CollectingMetric.from_outputs(("logits", "loss")) - logits = np.concatenate( - [model_output["logits"] for model_output in self.model_outputs] - ) + logits = np.concatenate([model_output["logits"] for model_output in self.model_outputs]) loss = np.array([model_output["loss"] for model_output in self.model_outputs]) - result = self.make_compute_metric(metric_class, reduce=False, jit=False)( - self.model_outputs - ) + result = self.make_compute_metric(metric_class, reduce=False, jit=False)(self.model_outputs) chex.assert_trees_all_close( result, { @@ -537,9 +506,7 @@ def copy_to_host(update): def test_collecting_metric_tracer(self): metric_class = metrics.CollectingMetric.from_outputs(("logits",)) with self.assertRaisesRegex(RuntimeError, r"^Tracer detected!"): - _ = self.make_compute_metric(metric_class, reduce=False, jit=True)( - self.model_outputs - ) + _ = self.make_compute_metric(metric_class, reduce=False, jit=True)(self.model_outputs) def test_collection_mixed_async(self): metric = CollectionMixed.empty() diff --git a/src/jax_loop_utils/parameter_overview.py b/src/jax_loop_utils/parameter_overview.py index 704b360..fe53552 100644 --- a/src/jax_loop_utils/parameter_overview.py +++ b/src/jax_loop_utils/parameter_overview.py @@ -14,15 +14,14 @@ """Helper function for creating and logging JAX variable overviews.""" -from collections.abc import Callable, Mapping, Sequence import dataclasses +from collections.abc import Callable, Mapping, Sequence from typing import Any -from absl import logging - import jax import jax.numpy as jnp import numpy as np +from absl import logging _ParamsContainer = dict[str, np.ndarray] | Mapping[str, Mapping[str, Any]] @@ -70,9 +69,7 @@ def flatten_dict( for key, value in input_dict.items(): nested_key = f"{prefix}{delimiter}{key}" if prefix else key if isinstance(value, Mapping): - output_dict.update( - flatten_dict(value, prefix=nested_key, delimiter=delimiter) - ) + output_dict.update(flatten_dict(value, prefix=nested_key, delimiter=delimiter)) else: output_dict[nested_key] = value return output_dict @@ -126,9 +123,7 @@ def _make_row_with_stats(name, value, mean, std) -> _ParamRowWithStats: ) -def _make_row_with_stats_and_sharding( - name, value, mean, std -) -> _ParamRowWithStatsAndSharding: +def _make_row_with_stats_and_sharding(name, value, mean, std) -> _ParamRowWithStatsAndSharding: row = _make_row_with_sharding(name, value) return _ParamRowWithStatsAndSharding( **dataclasses.asdict(row), @@ -164,7 +159,7 @@ def _get_parameter_rows( if params: params = flatten_dict(params) - names, values = map(list, tuple(zip(*sorted(params.items())))) + names, values = map(list, tuple(zip(*sorted(params.items()), strict=False))) else: names, values = [], [] @@ -174,9 +169,7 @@ def _get_parameter_rows( case True: mean_and_std = _mean_std(values) - return jax.tree_util.tree_map( - _make_row_with_stats, names, values, *mean_and_std - ) + return jax.tree_util.tree_map(_make_row_with_stats, names, values, *mean_and_std) case "global": mean_and_std = _mean_std_jit(values) @@ -196,9 +189,9 @@ def _default_table_value_formatter(value): if isinstance(value, bool): return str(value) elif isinstance(value, int): - return "{:,}".format(value) + return f"{value:,}" elif isinstance(value, float): - return "{:.3}".format(value) + return f"{value:.3}" else: return str(value) @@ -247,8 +240,7 @@ def __init__(self, name, values): column_names = [field.name for field in dataclasses.fields(rows[0])] columns = [ - Column(name, [value_formatter(getattr(row, name)) for row in rows]) - for name in column_names + Column(name, [value_formatter(getattr(row, name)) for row in rows]) for name in column_names ] var_line_format = "|" + "".join(f" {{: <{c.width}s}} |" for c in columns) @@ -324,9 +316,7 @@ def get_parameter_overview( Total: 65,172,512 """ - return _get_parameter_overview( - params, include_stats=include_stats, max_lines=max_lines - ) + return _get_parameter_overview(params, include_stats=include_stats, max_lines=max_lines) def _log_parameter_overview( @@ -339,9 +329,7 @@ def _log_parameter_overview( ): """See log_parameter_overview().""" - table = _get_parameter_overview( - params, include_stats=include_stats, max_lines=max_lines - ) + table = _get_parameter_overview(params, include_stats=include_stats, max_lines=max_lines) if jax_logging_process is None or jax_logging_process == jax.process_index(): lines = [msg] if msg else [] lines += table.split("\n") diff --git a/src/jax_loop_utils/parameter_overview_test.py b/src/jax_loop_utils/parameter_overview_test.py index ba68d30..a90db2d 100644 --- a/src/jax_loop_utils/parameter_overview_test.py +++ b/src/jax_loop_utils/parameter_overview_test.py @@ -14,20 +14,22 @@ """Tests for parameter overviews.""" -from absl.testing import absltest -from jax_loop_utils import parameter_overview import jax import jax.numpy as jnp import numpy as np +from absl.testing import absltest +from jax_loop_utils import parameter_overview -EMPTY_PARAMETER_OVERVIEW = """+------+-------+-------+------+------+-----+ +EMPTY_PARAMETER_OVERVIEW = """\ ++------+-------+-------+------+------+-----+ | Name | Shape | Dtype | Size | Mean | Std | +------+-------+-------+------+------+-----+ +------+-------+-------+------+------+-----+ Total: 0 -- 0 bytes""" -CONV2D_PARAMETER_OVERVIEW = """+-------------+--------------+---------+------+ +CONV2D_PARAMETER_OVERVIEW = """\ ++-------------+--------------+---------+------+ | Name | Shape | Dtype | Size | +-------------+--------------+---------+------+ | conv/bias | (2,) | float32 | 2 | @@ -35,7 +37,8 @@ +-------------+--------------+---------+------+ Total: 56 -- 224 bytes""" -CONV2D_PARAMETER_OVERVIEW_WITH_SHARDING = """+-------------+--------------+---------+------+----------+ +CONV2D_PARAMETER_OVERVIEW_WITH_SHARDING = """\ ++-------------+--------------+---------+------+----------+ | Name | Shape | Dtype | Size | Sharding | +-------------+--------------+---------+------+----------+ | conv/bias | (2,) | float32 | 2 | () | @@ -43,7 +46,8 @@ +-------------+--------------+---------+------+----------+ Total: 56 -- 224 bytes""" -CONV2D_PARAMETER_OVERVIEW_WITH_STATS = """+-------------+--------------+---------+------+------+-----+ +CONV2D_PARAMETER_OVERVIEW_WITH_STATS = """\ ++-------------+--------------+---------+------+------+-----+ | Name | Shape | Dtype | Size | Mean | Std | +-------------+--------------+---------+------+------+-----+ | conv/bias | (2,) | float32 | 2 | 1.0 | 0.0 | @@ -51,7 +55,8 @@ +-------------+--------------+---------+------+------+-----+ Total: 56 -- 224 bytes""" -CONV2D_PARAMETER_OVERVIEW_WITH_STATS_AND_SHARDING = """+-------------+--------------+---------+------+------+-----+----------+ +CONV2D_PARAMETER_OVERVIEW_WITH_STATS_AND_SHARDING = """\ ++-------------+--------------+---------+------+------+-----+----------+ | Name | Shape | Dtype | Size | Mean | Std | Sharding | +-------------+--------------+---------+------+------+-----+----------+ | conv/bias | (2,) | float32 | 2 | 1.0 | 0.0 | () | @@ -71,9 +76,7 @@ def test_count_parameters(self): self.assertEqual(56, parameter_overview.count_parameters(params)) def test_get_parameter_overview_empty(self): - self.assertEqual( - EMPTY_PARAMETER_OVERVIEW, parameter_overview.get_parameter_overview({}) - ) + self.assertEqual(EMPTY_PARAMETER_OVERVIEW, parameter_overview.get_parameter_overview({})) def test_get_parameter_overview(self): # Weights of a 2D convolution with 2 filters. @@ -110,12 +113,8 @@ def test_get_parameter_overview_shape_dtype_struct(self): ) def test_printing_bool(self): - self.assertEqual( - parameter_overview._default_table_value_formatter(True), "True" - ) - self.assertEqual( - parameter_overview._default_table_value_formatter(False), "False" - ) + self.assertEqual(parameter_overview._default_table_value_formatter(True), "True") + self.assertEqual(parameter_overview._default_table_value_formatter(False), "False") if __name__ == "__main__": diff --git a/src/jax_loop_utils/periodic_actions.py b/src/jax_loop_utils/periodic_actions.py index aadef7c..f55a222 100644 --- a/src/jax_loop_utils/periodic_actions.py +++ b/src/jax_loop_utils/periodic_actions.py @@ -21,17 +21,15 @@ import functools import os import time -from typing import Callable, Iterable, Optional, Sequence +from collections.abc import Callable, Iterable, Sequence +from typing import Optional -from absl import logging -from jax_loop_utils import asynclib -from jax_loop_utils import metric_writers -from jax_loop_utils import platform -from jax_loop_utils import profiler - -from etils import epath import jax import jax.numpy as jnp +from absl import logging +from etils import epath + +from jax_loop_utils import asynclib, metric_writers, platform, profiler # TODO(b/200953513): Migrate away from logging imports (on module level) # to logging the actual usage. See b/200953513. @@ -119,9 +117,7 @@ def _should_trigger(self, step: int, t: float) -> bool: return True if self._every_secs is not None and t - self._previous_time > self._every_secs: return True - if step in self._on_steps: - return True - return False + return step in self._on_steps def _after_apply(self, step: int, t: float): """Called after each time the action triggered.""" @@ -185,9 +181,7 @@ def __init__( on_steps = set(on_steps or []) if num_train_steps is not None: on_steps.add(num_train_steps) - super().__init__( - every_steps=every_steps, every_secs=every_secs, on_steps=on_steps - ) + super().__init__(every_steps=every_steps, every_secs=every_secs, on_steps=on_steps) # Check for negative values, e.g. tf.data.UNKNOWN/INFINITE_CARDINALTY. if num_train_steps is not None and num_train_steps < 0: num_train_steps = None @@ -281,17 +275,13 @@ def start_measurement(barrier: jax.Array) -> float: barrier.block_until_ready() return time.monotonic() - def stop_measurement( - start_future: concurrent.futures.Future[float], barrier: jax.Array - ): + def stop_measurement(start_future: concurrent.futures.Future[float], barrier: jax.Array): barrier.block_until_ready() self._time_per_part[name] += time.monotonic() - start_future.result() # Call _squareit on this thread so that it is guaranteed to be dispatched # to the TPU before any computations inside `yield`. - start_future = self._executor.submit( - start_measurement, barrier=_squareit(jnp.array(0.0)) - ) + start_future = self._executor.submit(start_measurement, barrier=_squareit(jnp.array(0.0))) yield # Same pattern: _squareit is dispatched after any programs dispatched from @@ -337,12 +327,8 @@ def __init__( artifact_name: Name of the artifact to record. """ if not num_profile_steps and not profile_duration_ms: - raise ValueError( - "Must specify num_profile_steps and/or profile_duration_ms." - ) - super().__init__( - every_steps=every_steps, every_secs=every_secs, on_steps=on_steps - ) + raise ValueError("Must specify num_profile_steps and/or profile_duration_ms.") + super().__init__(every_steps=every_steps, every_secs=every_secs, on_steps=on_steps) self._num_profile_steps = num_profile_steps self._first_profile = first_profile self._profile_duration_ms = profile_duration_ms @@ -355,12 +341,9 @@ def _should_trigger(self, step: int, t: float) -> bool: if self._session_running: # If a session is running we only check if we should stop it. dt = t - self._session_started - cond = ( - not self._profile_duration_ms or dt * 1e3 >= self._profile_duration_ms - ) + cond = not self._profile_duration_ms or dt * 1e3 >= self._profile_duration_ms cond &= ( - not self._num_profile_steps - or step >= self._previous_step + self._num_profile_steps + not self._num_profile_steps or step >= self._previous_step + self._num_profile_steps ) if cond: self._end_session(profiler.stop()) @@ -416,9 +399,7 @@ def __init__( every_secs: See `PeriodicAction.__init__()`. on_steps: See `PeriodicAction.__init__()`. """ - super().__init__( - every_steps=every_steps, every_secs=every_secs, on_steps=on_steps - ) + super().__init__(every_steps=every_steps, every_secs=every_secs, on_steps=on_steps) self._hosts = hosts self._first_profile = first_profile self._profile_duration_ms = profile_duration_ms @@ -472,9 +453,7 @@ def __init__( execute_async: if True wraps the callback into an async call. pass_step_and_time: if True the step and t are passed to the callback. """ - super().__init__( - every_steps=every_steps, every_secs=every_secs, on_steps=on_steps - ) + super().__init__(every_steps=every_steps, every_secs=every_secs, on_steps=on_steps) self._cb_results = collections.deque(maxlen=1) self.pass_step_and_time = pass_step_and_time if execute_async: diff --git a/src/jax_loop_utils/periodic_actions_test.py b/src/jax_loop_utils/periodic_actions_test.py index 374cc60..999765d 100644 --- a/src/jax_loop_utils/periodic_actions_test.py +++ b/src/jax_loop_utils/periodic_actions_test.py @@ -18,16 +18,15 @@ import time from unittest import mock -from absl.testing import absltest -from absl.testing import parameterized +from absl.testing import absltest, parameterized + from jax_loop_utils import periodic_actions +from jax_loop_utils.asynclib import AsyncError class ReportProgressTest(parameterized.TestCase): def test_every_steps(self): - hook = periodic_actions.ReportProgress( - every_steps=4, every_secs=None, num_train_steps=10 - ) + hook = periodic_actions.ReportProgress(every_steps=4, every_secs=None, num_train_steps=10) t = time.monotonic() with self.assertLogs(level="INFO") as logs: self.assertFalse(hook(1, t)) @@ -44,9 +43,7 @@ def test_every_steps(self): ) def test_every_secs(self): - hook = periodic_actions.ReportProgress( - every_steps=None, every_secs=0.3, num_train_steps=10 - ) + hook = periodic_actions.ReportProgress(every_steps=None, every_secs=0.3, num_train_steps=10) t = time.monotonic() with self.assertLogs(level="INFO") as logs: self.assertFalse(hook(1, t)) @@ -69,9 +66,7 @@ def test_without_num_train_steps(self): self.assertFalse(report(1, t)) self.assertTrue(report(2, t + 0.12)) # We did 1 step in 0.12s => 8.333 steps/s. - self.assertEqual( - logs.output, ["INFO:absl:Setting work unit notes: 8.3 steps/s"] - ) + self.assertEqual(logs.output, ["INFO:absl:Setting work unit notes: 8.3 steps/s"]) def test_unknown_cardinality(self): report = periodic_actions.ReportProgress(every_steps=2) @@ -80,16 +75,12 @@ def test_unknown_cardinality(self): self.assertFalse(report(1, t)) self.assertTrue(report(2, t + 0.12)) # We did 1 step in 0.12s => 8.333 steps/s. - self.assertEqual( - logs.output, ["INFO:absl:Setting work unit notes: 8.3 steps/s"] - ) + self.assertEqual(logs.output, ["INFO:absl:Setting work unit notes: 8.3 steps/s"]) def test_called_every_step(self): hook = periodic_actions.ReportProgress(every_steps=3, num_train_steps=10) t = time.monotonic() - with self.assertRaisesRegex( - ValueError, "PeriodicAction must be called after every step" - ): + with self.assertRaisesRegex(ValueError, "PeriodicAction must be called after every step"): hook(1, t) hook(11, t) # Raises exception. @@ -100,9 +91,7 @@ def test_called_every_step(self): @mock.patch("time.monotonic") def test_named(self, wait_jax_async_dispatch, mock_time): mock_time.return_value = 0 - hook = periodic_actions.ReportProgress( - every_steps=1, every_secs=None, num_train_steps=10 - ) + hook = periodic_actions.ReportProgress(every_steps=1, every_secs=None, num_train_steps=10) def _wait(): # Here we depend on hook._executor=ThreadPoolExecutor(max_workers=1) @@ -127,8 +116,8 @@ def _wait(): self.assertEqual( logs.output, [ - "INFO:absl:Setting work unit notes: 0.2 steps/s, 20.0% (2/10), ETA: 0m" - " (0m : 50.0% test1, 25.0% test2)" + "INFO:absl:Setting work unit notes: 0.2 steps/s, 20.0% (2/10), " + "ETA: 0m (0m : 50.0% test1, 25.0% test2)" ], ) @@ -136,9 +125,7 @@ def _wait(): def test_write_metrics(self, time_mock): time_mock.return_value = 0 writer_mock = mock.Mock() - hook = periodic_actions.ReportProgress( - every_steps=2, every_secs=None, writer=writer_mock - ) + hook = periodic_actions.ReportProgress(every_steps=2, every_secs=None, writer=writer_mock) time_mock.return_value = 1 hook(1) time_mock.return_value = 2 @@ -269,9 +256,7 @@ def cb(step, t): del t out.append(step) - hook = periodic_actions.PeriodicCallback( - every_steps=1, callback_fn=cb, execute_async=True - ) + hook = periodic_actions.PeriodicCallback(every_steps=1, callback_fn=cb, execute_async=True) hook(0) hook(1) hook(2) @@ -285,15 +270,13 @@ def test_error_async_is_forwarded(self): def cb(step, t): del step del t - raise Exception + raise ValueError("oh no") - hook = periodic_actions.PeriodicCallback( - every_steps=1, callback_fn=cb, execute_async=True - ) + hook = periodic_actions.PeriodicCallback(every_steps=1, callback_fn=cb, execute_async=True) hook(0) - with self.assertRaises(Exception): + with self.assertRaises(AsyncError): hook(1) def test_function_without_step_and_time(self): diff --git a/src/jax_loop_utils/platform/interface.py b/src/jax_loop_utils/platform/interface.py index 3ec7c4e..0d18115 100644 --- a/src/jax_loop_utils/platform/interface.py +++ b/src/jax_loop_utils/platform/interface.py @@ -68,7 +68,5 @@ def set_task_status(self, msg: str): """Sets the status string for this task.""" @abc.abstractmethod - def create_artifact( - self, artifact_type: ArtifactType, artifact: Any, description: str - ): + def create_artifact(self, artifact_type: ArtifactType, artifact: Any, description: str): """Creates an artifact entry for the work unit.""" diff --git a/src/jax_loop_utils/platform/local.py b/src/jax_loop_utils/platform/local.py index 6173261..f48db5e 100644 --- a/src/jax_loop_utils/platform/local.py +++ b/src/jax_loop_utils/platform/local.py @@ -17,6 +17,7 @@ from typing import Any from absl import logging + from jax_loop_utils.platform import interface WorkUnit = interface.WorkUnit @@ -44,9 +45,7 @@ def set_task_status(self, msg: str): """Sets the status string for this task.""" logging.info("Setting task status: %s", msg) - def create_artifact( - self, artifact_type: ArtifactType, artifact: Any, description: str - ): + def create_artifact(self, artifact_type: ArtifactType, artifact: Any, description: str): """Creates an artifact entry for the work unit.""" logging.info( "Created artifact %s of type %s and value %s.", diff --git a/synopsis.ipynb b/synopsis.ipynb index 6440110..fc4d2a0 100644 --- a/synopsis.ipynb +++ b/synopsis.ipynb @@ -208,6 +208,7 @@ "source": [ "import os\n", "import pathlib\n", + "\n", "from jax_loop_utils import metric_writers\n", "from jax_loop_utils.metric_writers.tf import SummaryWriter\n", "\n", @@ -248,9 +249,7 @@ " [metric_writers.LoggingWriter(collection=collection)]\n", " )\n", " else:\n", - " return metric_writers.MultiWriter(\n", - " [metric_writers.LoggingWriter(collection=collection)]\n", - " )\n", + " return metric_writers.MultiWriter([metric_writers.LoggingWriter(collection=collection)])\n", " writers = [metric_writers.LoggingWriter(collection=collection)]\n", " if logdir is not None:\n", " logdir = pathlib.Path(logdir)\n", @@ -330,9 +329,7 @@ "total_steps = 100\n", "hooks = [\n", " # Outputs progress via metric writer (in this case logs & TensorBoard).\n", - " periodic_actions.ReportProgress(\n", - " num_train_steps=total_steps, every_steps=10, writer=writer\n", - " ),\n", + " periodic_actions.ReportProgress(num_train_steps=total_steps, every_steps=10, writer=writer),\n", " periodic_actions.Profile(logdir=logdir),\n", "]\n", "\n", @@ -424,9 +421,10 @@ } ], "source": [ - "from jax_loop_utils import metrics\n", "import flax\n", "\n", + "from jax_loop_utils import metrics\n", + "\n", "# Metrics are computed in three steps:\n", "\n", "# 1. Compute intermediate values from model outputs\n", @@ -707,9 +705,7 @@ " pred_positives: jnp.array\n", "\n", " @classmethod\n", - " def from_model_output(\n", - " cls, *, logits: jnp.array, labels: jnp.array, **_\n", - " ) -> metrics.Metric:\n", + " def from_model_output(cls, *, logits: jnp.array, labels: jnp.array, **_) -> metrics.Metric:\n", " assert logits.shape[-1] == 2, \"Expected binary logits.\"\n", " preds = logits.argmax(axis=-1)\n", " return cls(\n", diff --git a/uv.lock b/uv.lock index 17fd811..eadb360 100644 --- a/uv.lock +++ b/uv.lock @@ -46,8 +46,8 @@ name = "astunparse" version = "1.6.3" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "six", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, - { name = "wheel", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, + { name = "six", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "wheel", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, ] sdist = { url = "https://files.pythonhosted.org/packages/f3/af/4182184d3c338792894f34a62672919db7ca008c89abee9b564dd34d8029/astunparse-1.6.3.tar.gz", hash = "sha256:5ad93a8456f0d084c3456d059fd9a92cce667963232cbf763eac3bc5b7940872", size = 18290 } wheels = [ @@ -100,7 +100,7 @@ name = "cffi" version = "1.17.1" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "pycparser", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, + { name = "pycparser", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, ] sdist = { url = "https://files.pythonhosted.org/packages/fc/97/c783634659c2920c3fc70419e3af40972dbaf758daa229a7d6ea6135c90d/cffi-1.17.1.tar.gz", hash = "sha256:1c39c6016c32bc48dd54561950ebd6836e1670f2ae46128f67cf49e789c52824", size = 516621 } wheels = [ @@ -187,13 +187,13 @@ name = "chex" version = "0.1.87" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "absl-py", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, - { name = "jax", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, - { name = "jaxlib", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, - { name = "numpy", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, - { name = "setuptools", marker = "(python_full_version >= '3.12' and platform_machine == 'aarch64' and sys_platform == 'linux') or (python_full_version >= '3.12' and platform_machine == 'x86_64' and sys_platform == 'linux') or (python_full_version >= '3.12' and platform_machine == 'arm64' and sys_platform == 'darwin')" }, - { name = "toolz", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, - { name = "typing-extensions", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, + { name = "absl-py", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "jax", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "jaxlib", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "numpy", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "setuptools", marker = "(python_full_version >= '3.12' and platform_machine == 'arm64' and sys_platform == 'darwin') or (python_full_version >= '3.12' and platform_machine == 'aarch64' and sys_platform == 'linux') or (python_full_version >= '3.12' and platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "toolz", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "typing-extensions", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, ] sdist = { url = "https://files.pythonhosted.org/packages/ba/8a/857474810b64ab135a0c3e594b0453c7f39f140757c4cd26a32bccadcbc4/chex-0.1.87.tar.gz", hash = "sha256:0096d89cc8d898bb521ef4bfbf5c24549022b0e5b301f529ab57238896fe6c5d", size = 90063 } wheels = [ @@ -204,9 +204,6 @@ wheels = [ name = "click" version = "8.1.7" source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "colorama", marker = "(platform_machine == 'aarch64' and platform_system == 'Windows' and sys_platform == 'linux') or (platform_machine == 'x86_64' and platform_system == 'Windows' and sys_platform == 'linux') or (platform_machine == 'arm64' and platform_system == 'Windows' and sys_platform == 'darwin')" }, -] sdist = { url = "https://files.pythonhosted.org/packages/96/d3/f04c7bfcf5c1862a2a5b845c6b2b360488cf47af55dfa79c98f6a6bf98b5/click-8.1.7.tar.gz", hash = "sha256:ca9853ad459e787e2192211578cc907e7594e294c7ccc834310722b41b9ca6de", size = 336121 } wheels = [ { url = "https://files.pythonhosted.org/packages/00/2e/d53fa4befbf2cfa713304affc7ca780ce4fc1fd8710527771b58311a3229/click-8.1.7-py3-none-any.whl", hash = "sha256:ae74fb96c20a0277a1d615f1e4d73c8414f5a98db8b799a7931d1582f3390c28", size = 97941 }, @@ -221,21 +218,12 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/48/41/e1d85ca3cab0b674e277c8c4f678cf66a91cd2cecf93df94353a606fe0db/cloudpickle-3.1.0-py3-none-any.whl", hash = "sha256:fe11acda67f61aaaec473e3afe030feb131d78a43461b718185363384f1ba12e", size = 22021 }, ] -[[package]] -name = "colorama" -version = "0.4.6" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/d8/53/6f443c9a4a8358a93a6792e2acffb9d9d5cb0a5cfd8802644b7b1c9a02e4/colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44", size = 27697 } -wheels = [ - { url = "https://files.pythonhosted.org/packages/d1/d6/3965ed04c63042e047cb6a3e6ed1a63a35087b6a609aa3a15ed8ac56c221/colorama-0.4.6-py2.py3-none-any.whl", hash = "sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6", size = 25335 }, -] - [[package]] name = "comm" version = "0.2.2" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "traitlets", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, + { name = "traitlets", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, ] sdist = { url = "https://files.pythonhosted.org/packages/e9/a8/fb783cb0abe2b5fded9f55e5703015cdf1c9c85b3669087c538dd15a6a86/comm-0.2.2.tar.gz", hash = "sha256:3fd7a84065306e07bea1773df6eb8282de51ba82f77c72f9c85716ab11fe980e", size = 6210 } wheels = [ @@ -247,7 +235,7 @@ name = "contourpy" version = "1.3.1" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "numpy", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, + { name = "numpy", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, ] sdist = { url = "https://files.pythonhosted.org/packages/25/c2/fc7193cc5383637ff390a712e88e4ded0452c9fbcf84abe3de5ea3df1866/contourpy-1.3.1.tar.gz", hash = "sha256:dfd97abd83335045a913e3bcc4a09c0ceadbe66580cf573fe961f4a825efa699", size = 13465753 } wheels = [ @@ -327,7 +315,7 @@ wheels = [ [package.optional-dependencies] toml = [ - { name = "tomli", marker = "(python_full_version <= '3.11' and platform_machine == 'aarch64' and sys_platform == 'linux') or (python_full_version <= '3.11' and platform_machine == 'x86_64' and sys_platform == 'linux') or (python_full_version <= '3.11' and platform_machine == 'arm64' and sys_platform == 'darwin')" }, + { name = "tomli", marker = "(python_full_version <= '3.11' and platform_machine == 'arm64' and sys_platform == 'darwin') or (python_full_version <= '3.11' and platform_machine == 'aarch64' and sys_platform == 'linux') or (python_full_version <= '3.11' and platform_machine == 'x86_64' and sys_platform == 'linux')" }, ] [[package]] @@ -344,8 +332,8 @@ name = "databricks-sdk" version = "0.38.0" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "google-auth", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, - { name = "requests", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, + { name = "google-auth", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "requests", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, ] sdist = { url = "https://files.pythonhosted.org/packages/1d/c2/02fd3dad8d25b8b24a69925f79f04902e906212808e267ce0e39462e525b/databricks_sdk-0.38.0.tar.gz", hash = "sha256:65e505201b65d8a2b4110d3eabfebce5a25426d3ccdd5f8bc69eb03333ea1f39", size = 594528 } wheels = [ @@ -381,7 +369,7 @@ name = "deprecated" version = "1.2.15" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "wrapt", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, + { name = "wrapt", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, ] sdist = { url = "https://files.pythonhosted.org/packages/2e/a3/53e7d78a6850ffdd394d7048a31a6f14e44900adedf190f9a165f6b69439/deprecated-1.2.15.tar.gz", hash = "sha256:683e561a90de76239796e6b6feac66b99030d2dd3fcf61ef996330f14bbb9b0d", size = 2977612 } wheels = [ @@ -399,13 +387,13 @@ wheels = [ [package.optional-dependencies] epath = [ - { name = "fsspec", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, - { name = "importlib-resources", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, - { name = "typing-extensions", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, - { name = "zipp", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, + { name = "fsspec", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "importlib-resources", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "typing-extensions", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "zipp", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, ] epy = [ - { name = "typing-extensions", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, + { name = "typing-extensions", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, ] [[package]] @@ -440,15 +428,15 @@ name = "flax" version = "0.10.2" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "jax", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, - { name = "msgpack", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, - { name = "numpy", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, - { name = "optax", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, - { name = "orbax-checkpoint", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, - { name = "pyyaml", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, - { name = "rich", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, - { name = "tensorstore", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, - { name = "typing-extensions", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, + { name = "jax", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "msgpack", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "numpy", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "optax", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "orbax-checkpoint", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "pyyaml", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "rich", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "tensorstore", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "typing-extensions", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, ] sdist = { url = "https://files.pythonhosted.org/packages/ff/38/4a0203198ac9459832abd33246d4e4fe250528b928a1fcd14cd6559bfcb4/flax-0.10.2.tar.gz", hash = "sha256:6f831350026ad48182ba6588bb4dd72dc1084985d9aca923254cb3e4c78d75f3", size = 5082773 } wheels = [ @@ -505,7 +493,7 @@ name = "gitdb" version = "4.0.11" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "smmap", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, + { name = "smmap", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, ] sdist = { url = "https://files.pythonhosted.org/packages/19/0d/bbb5b5ee188dec84647a4664f3e11b06ade2bde568dbd489d9d64adef8ed/gitdb-4.0.11.tar.gz", hash = "sha256:bf5421126136d6d0af55bc1e7c1af1c397a34f5b7bd79e776cd3e89785c2b04b", size = 394469 } wheels = [ @@ -517,7 +505,7 @@ name = "gitpython" version = "3.1.43" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "gitdb", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, + { name = "gitdb", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, ] sdist = { url = "https://files.pythonhosted.org/packages/b6/a1/106fd9fa2dd989b6fb36e5893961f82992cf676381707253e0bf93eb1662/GitPython-3.1.43.tar.gz", hash = "sha256:35f314a9f878467f5453cc1fee295c3e18e52f1b99f10f6cf5b1682e968a9e7c", size = 214149 } wheels = [ @@ -529,9 +517,9 @@ name = "google-auth" version = "2.36.0" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "cachetools", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, - { name = "pyasn1-modules", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, - { name = "rsa", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, + { name = "cachetools", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "pyasn1-modules", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "rsa", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, ] sdist = { url = "https://files.pythonhosted.org/packages/6a/71/4c5387d8a3e46e3526a8190ae396659484377a73b33030614dd3b28e7ded/google_auth-2.36.0.tar.gz", hash = "sha256:545e9618f2df0bcbb7dcbc45a546485b1212624716975a1ea5ae8149ce769ab1", size = 268336 } wheels = [ @@ -543,7 +531,7 @@ name = "google-pasta" version = "0.2.0" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "six", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, + { name = "six", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, ] sdist = { url = "https://files.pythonhosted.org/packages/35/4a/0bd53b36ff0323d10d5f24ebd67af2de10a1117f5cf4d7add90df92756f1/google-pasta-0.2.0.tar.gz", hash = "sha256:c9f2c8dfc8f96d0d5808299920721be30c9eec37f2389f28904f454565c8a16e", size = 40430 } wheels = [ @@ -584,7 +572,7 @@ name = "h5py" version = "3.12.1" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "numpy", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, + { name = "numpy", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, ] sdist = { url = "https://files.pythonhosted.org/packages/cc/0c/5c2b0a88158682aeafb10c1c2b735df5bc31f165bfe192f2ee9f2a23b5f1/h5py-3.12.1.tar.gz", hash = "sha256:326d70b53d31baa61f00b8aa5f95c2fcb9621a3ee8365d770c551a13dbbcbfdf", size = 411457 } wheels = [ @@ -625,7 +613,7 @@ name = "importlib-metadata" version = "8.5.0" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "zipp", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, + { name = "zipp", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, ] sdist = { url = "https://files.pythonhosted.org/packages/cd/12/33e59336dca5be0c398a7482335911a33aa0e20776128f038019f1a95f1b/importlib_metadata-8.5.0.tar.gz", hash = "sha256:71522656f0abace1d072b9e5481a48f07c138e00f079c38c8f883823f9c26bd7", size = 55304 } wheels = [ @@ -655,19 +643,19 @@ name = "ipykernel" version = "6.29.5" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "appnope", marker = "(platform_machine == 'aarch64' and platform_system == 'Darwin' and sys_platform == 'linux') or (platform_machine == 'x86_64' and platform_system == 'Darwin' and sys_platform == 'linux') or (platform_machine == 'arm64' and platform_system == 'Darwin' and sys_platform == 'darwin')" }, - { name = "comm", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, - { name = "debugpy", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, - { name = "ipython", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, - { name = "jupyter-client", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, - { name = "jupyter-core", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, - { name = "matplotlib-inline", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, - { name = "nest-asyncio", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, - { name = "packaging", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, - { name = "psutil", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, - { name = "pyzmq", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, - { name = "tornado", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, - { name = "traitlets", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, + { name = "appnope", marker = "platform_machine == 'arm64' and sys_platform == 'darwin'" }, + { name = "comm", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "debugpy", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "ipython", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "jupyter-client", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "jupyter-core", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "matplotlib-inline", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "nest-asyncio", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "packaging", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "psutil", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "pyzmq", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "tornado", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "traitlets", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, ] sdist = { url = "https://files.pythonhosted.org/packages/e9/5c/67594cb0c7055dc50814b21731c22a601101ea3b1b50a9a1b090e11f5d0f/ipykernel-6.29.5.tar.gz", hash = "sha256:f093a22c4a40f8828f8e330a9c297cb93dcab13bd9678ded6de8e5cf81c56215", size = 163367 } wheels = [ @@ -679,15 +667,15 @@ name = "ipython" version = "8.30.0" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "decorator", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, - { name = "jedi", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, - { name = "matplotlib-inline", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, - { name = "pexpect", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, - { name = "prompt-toolkit", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, - { name = "pygments", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, - { name = "stack-data", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, - { name = "traitlets", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, - { name = "typing-extensions", marker = "(python_full_version < '3.12' and platform_machine == 'aarch64' and sys_platform == 'linux') or (python_full_version < '3.12' and platform_machine == 'x86_64' and sys_platform == 'linux') or (python_full_version < '3.12' and platform_machine == 'arm64' and sys_platform == 'darwin')" }, + { name = "decorator", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "jedi", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "matplotlib-inline", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "pexpect", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "prompt-toolkit", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "pygments", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "stack-data", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "traitlets", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "typing-extensions", marker = "(python_full_version < '3.12' and platform_machine == 'arm64' and sys_platform == 'darwin') or (python_full_version < '3.12' and platform_machine == 'aarch64' and sys_platform == 'linux') or (python_full_version < '3.12' and platform_machine == 'x86_64' and sys_platform == 'linux')" }, ] sdist = { url = "https://files.pythonhosted.org/packages/d8/8b/710af065ab8ed05649afa5bd1e07401637c9ec9fb7cfda9eac7e91e9fbd4/ipython-8.30.0.tar.gz", hash = "sha256:cb0a405a306d2995a5cbb9901894d240784a9f341394c6ba3f4fe8c6eb89ff6e", size = 5592205 } wheels = [ @@ -699,11 +687,11 @@ name = "jax" version = "0.4.36" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "jaxlib", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, - { name = "ml-dtypes", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, - { name = "numpy", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, - { name = "opt-einsum", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, - { name = "scipy", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, + { name = "jaxlib", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "ml-dtypes", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "numpy", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "opt-einsum", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "scipy", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, ] sdist = { url = "https://files.pythonhosted.org/packages/df/79/b8033b443d15671725ef4aef7f756c8b0026a7add3e807981d4d7c6abba7/jax-0.4.36.tar.gz", hash = "sha256:088bff0575d01fc82682a9af4eb07433d60de7e5164686bd2cea3439492e608a", size = 1915594 } wheels = [ @@ -715,41 +703,45 @@ name = "jax-loop-utils" version = "0.0.12" source = { editable = "." } dependencies = [ - { name = "absl-py", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, - { name = "etils", extra = ["epath", "epy"], marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, - { name = "jax", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, - { name = "numpy", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, - { name = "packaging", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, - { name = "wrapt", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, + { name = "absl-py", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "etils", extra = ["epath", "epy"], marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "jax", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "numpy", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "packaging", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "wrapt", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, ] [package.optional-dependencies] audio-video = [ - { name = "av", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, + { name = "av", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, ] mlflow = [ - { name = "mlflow-skinny", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, - { name = "pillow", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, -] -pyright = [ - { name = "pyright", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, + { name = "mlflow-skinny", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "pillow", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, ] synopsis = [ - { name = "chex", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, - { name = "flax", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, - { name = "ipykernel", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, - { name = "matplotlib", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, + { name = "chex", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "flax", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "ipykernel", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "matplotlib", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, ] tensorflow = [ - { name = "tensorflow", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, -] -test = [ - { name = "chex", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, - { name = "pytest", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, - { name = "pytest-cov", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, + { name = "tensorflow", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, ] torch = [ - { name = "torch", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, + { name = "torch", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, +] + +[package.dev-dependencies] +dev = [ + { name = "chex", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "pyright", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "pytest", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "pytest-cov", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "ruff", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, +] +dev-torch = [ + { name = "tensorflow", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, ] [package.metadata] @@ -757,7 +749,6 @@ requires-dist = [ { name = "absl-py" }, { name = "av", marker = "extra == 'audio-video'", specifier = ">=14.0" }, { name = "chex", marker = "extra == 'synopsis'" }, - { name = "chex", marker = "extra == 'test'" }, { name = "etils", extras = ["epath", "epy"] }, { name = "flax", marker = "extra == 'synopsis'" }, { name = "ipykernel", marker = "extra == 'synopsis'" }, @@ -767,22 +758,29 @@ requires-dist = [ { name = "numpy" }, { name = "packaging" }, { name = "pillow", marker = "extra == 'mlflow'" }, - { name = "pyright", marker = "extra == 'pyright'" }, - { name = "pytest", marker = "extra == 'test'" }, - { name = "pytest-cov", marker = "extra == 'test'" }, { name = "tensorflow", marker = "extra == 'tensorflow'", specifier = ">=2.12" }, { name = "torch", marker = "extra == 'torch'", specifier = ">=2.0" }, { name = "wrapt" }, ] +[package.metadata.requires-dev] +dev = [ + { name = "chex", specifier = ">=0.1.87" }, + { name = "pyright", specifier = "==1.1.391" }, + { name = "pytest", specifier = ">=8.3.4" }, + { name = "pytest-cov", specifier = ">=6.0.0" }, + { name = "ruff", specifier = ">=0.9.1" }, +] +dev-torch = [{ name = "tensorflow", specifier = ">=2.12" }] + [[package]] name = "jaxlib" version = "0.4.36" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "ml-dtypes", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, - { name = "numpy", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, - { name = "scipy", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, + { name = "ml-dtypes", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "numpy", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "scipy", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, ] wheels = [ { url = "https://files.pythonhosted.org/packages/e4/7d/9394ff39af5c23bb98a241c33742a328df5a43c21d569855ea7e096aaf5e/jaxlib-0.4.36-cp311-cp311-macosx_10_14_x86_64.whl", hash = "sha256:213792db3b876206b45f6a9fbea15e4dd22a9e80be25b03136f20c94784fecfa", size = 98669744 }, @@ -804,7 +802,7 @@ name = "jedi" version = "0.19.2" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "parso", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, + { name = "parso", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, ] sdist = { url = "https://files.pythonhosted.org/packages/72/3a/79a912fbd4d8dd6fbb02bf69afd3bb72cf0c729bb3063c6f4498603db17a/jedi-0.19.2.tar.gz", hash = "sha256:4770dc3de41bde3966b02eb84fbcf557fb33cce26ad23da12c742fb50ecb11f0", size = 1231287 } wheels = [ @@ -816,7 +814,7 @@ name = "jinja2" version = "3.1.4" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "markupsafe", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, + { name = "markupsafe", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, ] sdist = { url = "https://files.pythonhosted.org/packages/ed/55/39036716d19cab0747a5020fc7e907f362fbf48c984b14e62127f7e68e5d/jinja2-3.1.4.tar.gz", hash = "sha256:4a3aee7acbbe7303aede8e9648d13b8bf88a429282aa6122a993f0ac800cb369", size = 240245 } wheels = [ @@ -828,11 +826,11 @@ name = "jupyter-client" version = "8.6.3" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "jupyter-core", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, - { name = "python-dateutil", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, - { name = "pyzmq", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, - { name = "tornado", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, - { name = "traitlets", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, + { name = "jupyter-core", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "python-dateutil", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "pyzmq", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "tornado", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "traitlets", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, ] sdist = { url = "https://files.pythonhosted.org/packages/71/22/bf9f12fdaeae18019a468b68952a60fe6dbab5d67cd2a103cac7659b41ca/jupyter_client-8.6.3.tar.gz", hash = "sha256:35b3a0947c4a6e9d589eb97d7d4cd5e90f910ee73101611f01283732bd6d9419", size = 342019 } wheels = [ @@ -844,8 +842,8 @@ name = "jupyter-core" version = "5.7.2" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "platformdirs", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, - { name = "traitlets", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, + { name = "platformdirs", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "traitlets", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, ] sdist = { url = "https://files.pythonhosted.org/packages/00/11/b56381fa6c3f4cc5d2cf54a7dbf98ad9aa0b339ef7a601d6053538b079a7/jupyter_core-5.7.2.tar.gz", hash = "sha256:aa5f8d32bbf6b431ac830496da7392035d6f61b4f54872f15c4bd2a9c3f536d9", size = 87629 } wheels = [ @@ -857,14 +855,14 @@ name = "keras" version = "3.7.0" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "absl-py", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, - { name = "h5py", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, - { name = "ml-dtypes", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, - { name = "namex", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, - { name = "numpy", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, - { name = "optree", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, - { name = "packaging", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, - { name = "rich", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, + { name = "absl-py", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "h5py", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "ml-dtypes", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "namex", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "numpy", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "optree", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "packaging", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "rich", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, ] sdist = { url = "https://files.pythonhosted.org/packages/c9/c3/56fc6800c5eab94bd0f5e930751bd4c0fa1ee0aee272fad4a72723ffae87/keras-3.7.0.tar.gz", hash = "sha256:a4451a5591e75dfb414d0b84a3fd2fb9c0240cc87ebe7e397f547ce10b0e67b7", size = 924719 } wheels = [ @@ -947,7 +945,7 @@ name = "markdown-it-py" version = "3.0.0" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "mdurl", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, + { name = "mdurl", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, ] sdist = { url = "https://files.pythonhosted.org/packages/38/71/3b932df36c1a044d397a1f92d1cf91ee0a503d91e470cbd670aa66b07ed0/markdown-it-py-3.0.0.tar.gz", hash = "sha256:e3f60a94fa066dc52ec76661e37c851cb232d92f9886b15cb560aaada2df8feb", size = 74596 } wheels = [ @@ -999,15 +997,15 @@ name = "matplotlib" version = "3.9.3" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "contourpy", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, - { name = "cycler", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, - { name = "fonttools", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, - { name = "kiwisolver", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, - { name = "numpy", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, - { name = "packaging", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, - { name = "pillow", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, - { name = "pyparsing", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, - { name = "python-dateutil", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, + { name = "contourpy", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "cycler", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "fonttools", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "kiwisolver", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "numpy", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "packaging", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "pillow", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "pyparsing", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "python-dateutil", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, ] sdist = { url = "https://files.pythonhosted.org/packages/75/9f/562ed484b11ac9f4bb4f9d2d7546954ec106a8c0f06cc755d6f63e519274/matplotlib-3.9.3.tar.gz", hash = "sha256:cd5dbbc8e25cad5f706845c4d100e2c8b34691b412b93717ce38d8ae803bcfa5", size = 36113438 } wheels = [ @@ -1038,7 +1036,7 @@ name = "matplotlib-inline" version = "0.1.7" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "traitlets", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, + { name = "traitlets", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, ] sdist = { url = "https://files.pythonhosted.org/packages/99/5b/a36a337438a14116b16480db471ad061c36c3694df7c2084a0da7ba538b7/matplotlib_inline-0.1.7.tar.gz", hash = "sha256:8423b23ec666be3d16e16b60bdd8ac4e86e840ebd1dd11a30b9f117f2fa0ab90", size = 8159 } wheels = [ @@ -1059,7 +1057,7 @@ name = "ml-dtypes" version = "0.4.1" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "numpy", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, + { name = "numpy", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, ] sdist = { url = "https://files.pythonhosted.org/packages/fd/15/76f86faa0902836cc133939732f7611ace68cf54148487a99c539c272dc8/ml_dtypes-0.4.1.tar.gz", hash = "sha256:fad5f2de464fd09127e49b7fd1252b9006fb43d2edc1ff112d390c324af5ca7a", size = 692594 } wheels = [ @@ -1076,19 +1074,19 @@ name = "mlflow-skinny" version = "2.18.0" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "cachetools", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, - { name = "click", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, - { name = "cloudpickle", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, - { name = "databricks-sdk", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, - { name = "gitpython", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, - { name = "importlib-metadata", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, - { name = "opentelemetry-api", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, - { name = "opentelemetry-sdk", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, - { name = "packaging", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, - { name = "protobuf", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, - { name = "pyyaml", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, - { name = "requests", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, - { name = "sqlparse", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, + { name = "cachetools", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "click", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "cloudpickle", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "databricks-sdk", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "gitpython", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "importlib-metadata", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "opentelemetry-api", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "opentelemetry-sdk", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "packaging", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "protobuf", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "pyyaml", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "requests", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "sqlparse", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, ] sdist = { url = "https://files.pythonhosted.org/packages/f9/89/3fbcf0e415678029b783d6951373443aa64cb352c4959374f08903710690/mlflow_skinny-2.18.0.tar.gz", hash = "sha256:87e83f56c362a520196b2f0292b24efdca7f8b2068a6a6941f2ec9feb9bfd914", size = 5445516 } wheels = [ @@ -1324,8 +1322,8 @@ name = "opentelemetry-api" version = "1.28.2" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "deprecated", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, - { name = "importlib-metadata", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, + { name = "deprecated", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "importlib-metadata", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, ] sdist = { url = "https://files.pythonhosted.org/packages/51/34/e4e9245c868c6490a46ffedf6bd5b0f512bbc0a848b19e3a51f6bbad648c/opentelemetry_api-1.28.2.tar.gz", hash = "sha256:ecdc70c7139f17f9b0cf3742d57d7020e3e8315d6cffcdf1a12a905d45b19cc0", size = 62796 } wheels = [ @@ -1337,9 +1335,9 @@ name = "opentelemetry-sdk" version = "1.28.2" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "opentelemetry-api", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, - { name = "opentelemetry-semantic-conventions", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, - { name = "typing-extensions", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, + { name = "opentelemetry-api", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "opentelemetry-semantic-conventions", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "typing-extensions", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, ] sdist = { url = "https://files.pythonhosted.org/packages/4b/f4/840a5af4efe48d7fb4c456ad60fd624673e871a60d6494f7ff8a934755d4/opentelemetry_sdk-1.28.2.tar.gz", hash = "sha256:5fed24c5497e10df30282456fe2910f83377797511de07d14cec0d3e0a1a3110", size = 157272 } wheels = [ @@ -1351,8 +1349,8 @@ name = "opentelemetry-semantic-conventions" version = "0.49b2" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "deprecated", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, - { name = "opentelemetry-api", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, + { name = "deprecated", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "opentelemetry-api", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, ] sdist = { url = "https://files.pythonhosted.org/packages/7d/0a/e3b93f94aa3223c6fd8e743502a1fefd4fb3a753d8f501ce2a418f7c0bd4/opentelemetry_semantic_conventions-0.49b2.tar.gz", hash = "sha256:44e32ce6a5bb8d7c0c617f84b9dc1c8deda1045a07dc16a688cc7cbeab679997", size = 95213 } wheels = [ @@ -1373,12 +1371,12 @@ name = "optax" version = "0.2.4" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "absl-py", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, - { name = "chex", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, - { name = "etils", extra = ["epy"], marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, - { name = "jax", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, - { name = "jaxlib", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, - { name = "numpy", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, + { name = "absl-py", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "chex", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "etils", extra = ["epy"], marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "jax", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "jaxlib", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "numpy", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, ] sdist = { url = "https://files.pythonhosted.org/packages/af/b5/f88a0d851547b2e6b2c7e7e6509ad66236b3e7019f1f095bb03dbaa61fa1/optax-0.2.4.tar.gz", hash = "sha256:4e05d3d5307e6dde4c319187ae36e6cd3a0c035d4ed25e9e992449a304f47336", size = 229717 } wheels = [ @@ -1390,7 +1388,7 @@ name = "optree" version = "0.13.1" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "typing-extensions", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, + { name = "typing-extensions", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, ] sdist = { url = "https://files.pythonhosted.org/packages/f7/f2/56afdaeaae36b076659be7db8e72be0924dd64ebd1c131675c77f7e704a6/optree-0.13.1.tar.gz", hash = "sha256:af67856aa8073d237fe67313d84f8aeafac32c1cef7239c628a2768d02679c43", size = 155738 } wheels = [ @@ -1429,18 +1427,18 @@ name = "orbax-checkpoint" version = "0.10.2" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "absl-py", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, - { name = "etils", extra = ["epath", "epy"], marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, - { name = "humanize", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, - { name = "jax", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, - { name = "msgpack", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, - { name = "nest-asyncio", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, - { name = "numpy", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, - { name = "protobuf", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, - { name = "pyyaml", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, - { name = "simplejson", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, - { name = "tensorstore", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, - { name = "typing-extensions", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, + { name = "absl-py", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "etils", extra = ["epath", "epy"], marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "humanize", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "jax", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "msgpack", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "nest-asyncio", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "numpy", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "protobuf", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "pyyaml", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "simplejson", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "tensorstore", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "typing-extensions", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, ] sdist = { url = "https://files.pythonhosted.org/packages/d1/06/c42e2f1563dbaaf5ed1464d7b634324fb9a2da04021073c45777e61af78d/orbax_checkpoint-0.10.2.tar.gz", hash = "sha256:e575ebe1f94e5cb6353ab8c9df81de0ca7cddc118645c3bfc17b8344f19d42f1", size = 248170 } wheels = [ @@ -1470,7 +1468,7 @@ name = "pexpect" version = "4.9.0" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "ptyprocess", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, + { name = "ptyprocess", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, ] sdist = { url = "https://files.pythonhosted.org/packages/42/92/cc564bf6381ff43ce1f4d06852fc19a2f11d180f23dc32d9588bee2f149d/pexpect-4.9.0.tar.gz", hash = "sha256:ee7d41123f3c9911050ea2c2dac107568dc43b2d3b0c7557a33212c398ead30f", size = 166450 } wheels = [ @@ -1537,7 +1535,7 @@ name = "prompt-toolkit" version = "3.0.48" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "wcwidth", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, + { name = "wcwidth", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, ] sdist = { url = "https://files.pythonhosted.org/packages/2d/4f/feb5e137aff82f7c7f3248267b97451da3644f6cdc218edfe549fb354127/prompt_toolkit-3.0.48.tar.gz", hash = "sha256:d6623ab0477a80df74e646bdbc93621143f5caf104206aa29294d53de1a03d90", size = 424684 } wheels = [ @@ -1601,7 +1599,7 @@ name = "pyasn1-modules" version = "0.4.1" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "pyasn1", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, + { name = "pyasn1", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, ] sdist = { url = "https://files.pythonhosted.org/packages/1d/67/6afbf0d507f73c32d21084a79946bfcfca5fbc62a72057e9c23797a737c9/pyasn1_modules-0.4.1.tar.gz", hash = "sha256:c28e2dbf9c06ad61c71a075c7e0f9fd0f1b0bb2d2ad4377f240d33ac2ab60a7c", size = 310028 } wheels = [ @@ -1637,15 +1635,15 @@ wheels = [ [[package]] name = "pyright" -version = "1.1.390" +version = "1.1.391" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "nodeenv", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, - { name = "typing-extensions", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, + { name = "nodeenv", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "typing-extensions", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/ba/42/1e0392f35dd275f9f775baf7c86407cef7f6a0d9b8e099a93e5422a7e571/pyright-1.1.390.tar.gz", hash = "sha256:aad7f160c49e0fbf8209507a15e17b781f63a86a1facb69ca877c71ef2e9538d", size = 21950 } +sdist = { url = "https://files.pythonhosted.org/packages/11/05/4ea52a8a45cc28897edb485b4102d37cbfd5fce8445d679cdeb62bfad221/pyright-1.1.391.tar.gz", hash = "sha256:66b2d42cdf5c3cbab05f2f4b76e8bec8aa78e679bfa0b6ad7b923d9e027cadb2", size = 21965 } wheels = [ - { url = "https://files.pythonhosted.org/packages/43/20/3f492ca789fb17962ad23619959c7fa642082621751514296c58de3bb801/pyright-1.1.390-py3-none-any.whl", hash = "sha256:ecebfba5b6b50af7c1a44c2ba144ba2ab542c227eb49bc1f16984ff714e0e110", size = 18579 }, + { url = "https://files.pythonhosted.org/packages/ad/89/66f49552fbeb21944c8077d11834b2201514a56fd1b7747ffff9630f1bd9/pyright-1.1.391-py3-none-any.whl", hash = "sha256:54fa186f8b3e8a55a44ebfa842636635688670c6896dcf6cf4a7fc75062f4d15", size = 18579 }, ] [[package]] @@ -1653,9 +1651,9 @@ name = "pytest" version = "8.3.4" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "iniconfig", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, - { name = "packaging", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, - { name = "pluggy", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, + { name = "iniconfig", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "packaging", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "pluggy", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, ] sdist = { url = "https://files.pythonhosted.org/packages/05/35/30e0d83068951d90a01852cb1cef56e5d8a09d20c7f511634cc2f7e0372a/pytest-8.3.4.tar.gz", hash = "sha256:965370d062bce11e73868e0335abac31b4d3de0e82f4007408d242b4f8610761", size = 1445919 } wheels = [ @@ -1667,8 +1665,8 @@ name = "pytest-cov" version = "6.0.0" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "coverage", extra = ["toml"], marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, - { name = "pytest", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, + { name = "coverage", extra = ["toml"], marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "pytest", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, ] sdist = { url = "https://files.pythonhosted.org/packages/be/45/9b538de8cef30e17c7b45ef42f538a94889ed6a16f2387a6c89e73220651/pytest-cov-6.0.0.tar.gz", hash = "sha256:fde0b595ca248bb8e2d76f020b465f3b107c9632e6a1d1705f17834c89dcadc0", size = 66945 } wheels = [ @@ -1680,7 +1678,7 @@ name = "python-dateutil" version = "2.9.0.post0" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "six", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, + { name = "six", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, ] sdist = { url = "https://files.pythonhosted.org/packages/66/c0/0c8b6ad9f17a802ee498c46e004a0eb49bc148f2fd230864601a86dcf6db/python-dateutil-2.9.0.post0.tar.gz", hash = "sha256:37dd54208da7e1cd875388217d5e00ebd4179249f90fb72437e91a35459a0ad3", size = 342432 } wheels = [ @@ -1721,7 +1719,7 @@ name = "pyzmq" version = "26.2.0" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "cffi", marker = "(implementation_name == 'pypy' and platform_machine == 'aarch64' and sys_platform == 'linux') or (implementation_name == 'pypy' and platform_machine == 'x86_64' and sys_platform == 'linux') or (implementation_name == 'pypy' and platform_machine == 'arm64' and sys_platform == 'darwin')" }, + { name = "cffi", marker = "(implementation_name == 'pypy' and platform_machine == 'arm64' and sys_platform == 'darwin') or (implementation_name == 'pypy' and platform_machine == 'aarch64' and sys_platform == 'linux') or (implementation_name == 'pypy' and platform_machine == 'x86_64' and sys_platform == 'linux')" }, ] sdist = { url = "https://files.pythonhosted.org/packages/fd/05/bed626b9f7bb2322cdbbf7b4bd8f54b1b617b0d2ab2d3547d6e39428a48e/pyzmq-26.2.0.tar.gz", hash = "sha256:070672c258581c8e4f640b5159297580a9974b026043bd4ab0470be9ed324f1f", size = 271975 } wheels = [ @@ -1768,10 +1766,10 @@ name = "requests" version = "2.32.3" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "certifi", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, - { name = "charset-normalizer", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, - { name = "idna", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, - { name = "urllib3", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, + { name = "certifi", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "charset-normalizer", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "idna", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "urllib3", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, ] sdist = { url = "https://files.pythonhosted.org/packages/63/70/2bf7780ad2d390a8d301ad0b550f1581eadbd9a20f896afe06353c2a2913/requests-2.32.3.tar.gz", hash = "sha256:55365417734eb18255590a9ff9eb97e9e1da868d4ccd6402399eaf68af20a760", size = 131218 } wheels = [ @@ -1783,8 +1781,8 @@ name = "rich" version = "13.9.4" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "markdown-it-py", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, - { name = "pygments", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, + { name = "markdown-it-py", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "pygments", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, ] sdist = { url = "https://files.pythonhosted.org/packages/ab/3a/0316b28d0761c6734d6bc14e770d85506c986c85ffb239e688eeaab2c2bc/rich-13.9.4.tar.gz", hash = "sha256:439594978a49a09530cff7ebc4b5c7103ef57baf48d5ea3184f21d9a2befa098", size = 223149 } wheels = [ @@ -1796,19 +1794,41 @@ name = "rsa" version = "4.9" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "pyasn1", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, + { name = "pyasn1", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, ] sdist = { url = "https://files.pythonhosted.org/packages/aa/65/7d973b89c4d2351d7fb232c2e452547ddfa243e93131e7cfa766da627b52/rsa-4.9.tar.gz", hash = "sha256:e38464a49c6c85d7f1351b0126661487a7e0a14a50f1675ec50eb34d4f20ef21", size = 29711 } wheels = [ { url = "https://files.pythonhosted.org/packages/49/97/fa78e3d2f65c02c8e1268b9aba606569fe97f6c8f7c2d74394553347c145/rsa-4.9-py3-none-any.whl", hash = "sha256:90260d9058e514786967344d0ef75fa8727eed8a7d2e43ce9f4bcf1b536174f7", size = 34315 }, ] +[[package]] +name = "ruff" +version = "0.9.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/67/3e/e89f736f01aa9517a97e2e7e0ce8d34a4d8207087b3cfdec95133fee13b5/ruff-0.9.1.tar.gz", hash = "sha256:fd2b25ecaf907d6458fa842675382c8597b3c746a2dde6717fe3415425df0c17", size = 3498844 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/dc/05/c3a2e0feb3d5d394cdfd552de01df9d3ec8a3a3771bbff247fab7e668653/ruff-0.9.1-py3-none-linux_armv6l.whl", hash = "sha256:84330dda7abcc270e6055551aca93fdde1b0685fc4fd358f26410f9349cf1743", size = 10645241 }, + { url = "https://files.pythonhosted.org/packages/dd/da/59f0a40e5f88ee5c054ad175caaa2319fc96571e1d29ab4730728f2aad4f/ruff-0.9.1-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:3cae39ba5d137054b0e5b472aee3b78a7c884e61591b100aeb544bcd1fc38d4f", size = 10391066 }, + { url = "https://files.pythonhosted.org/packages/b7/fe/85e1c1acf0ba04a3f2d54ae61073da030f7a5dc386194f96f3c6ca444a78/ruff-0.9.1-py3-none-macosx_11_0_arm64.whl", hash = "sha256:50c647ff96f4ba288db0ad87048257753733763b409b2faf2ea78b45c8bb7fcb", size = 10012308 }, + { url = "https://files.pythonhosted.org/packages/6f/9b/780aa5d4bdca8dcea4309264b8faa304bac30e1ce0bcc910422bfcadd203/ruff-0.9.1-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f0c8b149e9c7353cace7d698e1656ffcf1e36e50f8ea3b5d5f7f87ff9986a7ca", size = 10881960 }, + { url = "https://files.pythonhosted.org/packages/12/f4/dac4361afbfe520afa7186439e8094e4884ae3b15c8fc75fb2e759c1f267/ruff-0.9.1-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:beb3298604540c884d8b282fe7625651378e1986c25df51dec5b2f60cafc31ce", size = 10414803 }, + { url = "https://files.pythonhosted.org/packages/f0/a2/057a3cb7999513cb78d6cb33a7d1cc6401c82d7332583786e4dad9e38e44/ruff-0.9.1-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:39d0174ccc45c439093971cc06ed3ac4dc545f5e8bdacf9f067adf879544d969", size = 11464929 }, + { url = "https://files.pythonhosted.org/packages/eb/c6/1ccfcc209bee465ced4874dcfeaadc88aafcc1ea9c9f31ef66f063c187f0/ruff-0.9.1-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:69572926c0f0c9912288915214ca9b2809525ea263603370b9e00bed2ba56dbd", size = 12170717 }, + { url = "https://files.pythonhosted.org/packages/84/97/4a524027518525c7cf6931e9fd3b2382be5e4b75b2b61bec02681a7685a5/ruff-0.9.1-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:937267afce0c9170d6d29f01fcd1f4378172dec6760a9f4dface48cdabf9610a", size = 11708921 }, + { url = "https://files.pythonhosted.org/packages/a6/a4/4e77cf6065c700d5593b25fca6cf725b1ab6d70674904f876254d0112ed0/ruff-0.9.1-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:186c2313de946f2c22bdf5954b8dd083e124bcfb685732cfb0beae0c47233d9b", size = 13058074 }, + { url = "https://files.pythonhosted.org/packages/f9/d6/fcb78e0531e863d0a952c4c5600cc5cd317437f0e5f031cd2288b117bb37/ruff-0.9.1-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3f94942a3bb767675d9a051867c036655fe9f6c8a491539156a6f7e6b5f31831", size = 11281093 }, + { url = "https://files.pythonhosted.org/packages/e4/3b/7235bbeff00c95dc2d073cfdbf2b871b5bbf476754c5d277815d286b4328/ruff-0.9.1-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:728d791b769cc28c05f12c280f99e8896932e9833fef1dd8756a6af2261fd1ab", size = 10882610 }, + { url = "https://files.pythonhosted.org/packages/2a/66/5599d23257c61cf038137f82999ca8f9d0080d9d5134440a461bef85b461/ruff-0.9.1-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:2f312c86fb40c5c02b44a29a750ee3b21002bd813b5233facdaf63a51d9a85e1", size = 10489273 }, + { url = "https://files.pythonhosted.org/packages/78/85/de4aa057e2532db0f9761e2c2c13834991e087787b93e4aeb5f1cb10d2df/ruff-0.9.1-py3-none-musllinux_1_2_i686.whl", hash = "sha256:ae017c3a29bee341ba584f3823f805abbe5fe9cd97f87ed07ecbf533c4c88366", size = 11003314 }, + { url = "https://files.pythonhosted.org/packages/00/42/afedcaa089116d81447347f76041ff46025849fedb0ed2b187d24cf70fca/ruff-0.9.1-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:5dc40a378a0e21b4cfe2b8a0f1812a6572fc7b230ef12cd9fac9161aa91d807f", size = 11342982 }, +] + [[package]] name = "scipy" version = "1.14.1" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "numpy", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, + { name = "numpy", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, ] sdist = { url = "https://files.pythonhosted.org/packages/62/11/4d44a1f274e002784e4dbdb81e0ea96d2de2d1045b2132d5af62cc31fd28/scipy-1.14.1.tar.gz", hash = "sha256:5a275584e726026a5699459aa72f828a610821006228e841b94275c4a7c08417", size = 58620554 } wheels = [ @@ -1918,9 +1938,9 @@ name = "stack-data" version = "0.6.3" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "asttokens", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, - { name = "executing", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, - { name = "pure-eval", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, + { name = "asttokens", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "executing", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "pure-eval", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, ] sdist = { url = "https://files.pythonhosted.org/packages/28/e3/55dcc2cfbc3ca9c29519eb6884dd1415ecb53b0e934862d3559ddcb7e20b/stack_data-0.6.3.tar.gz", hash = "sha256:836a778de4fec4dcd1dcd89ed8abff8a221f58308462e1c4aa2a3cf30148f0b9", size = 44707 } wheels = [ @@ -1932,7 +1952,7 @@ name = "sympy" version = "1.13.1" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "mpmath", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, + { name = "mpmath", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, ] sdist = { url = "https://files.pythonhosted.org/packages/ca/99/5a5b6f19ff9f083671ddf7b9632028436167cd3d33e11015754e41b249a4/sympy-1.13.1.tar.gz", hash = "sha256:9cebf7e04ff162015ce31c9c6c9144daa34a93bd082f54fd8f12deca4f47515f", size = 7533040 } wheels = [ @@ -1944,16 +1964,16 @@ name = "tensorboard" version = "2.18.0" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "absl-py", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, - { name = "grpcio", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, - { name = "markdown", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, - { name = "numpy", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, - { name = "packaging", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, - { name = "protobuf", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, - { name = "setuptools", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, - { name = "six", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, - { name = "tensorboard-data-server", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, - { name = "werkzeug", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, + { name = "absl-py", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "grpcio", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "markdown", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "numpy", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "packaging", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "protobuf", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "setuptools", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "six", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "tensorboard-data-server", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "werkzeug", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, ] wheels = [ { url = "https://files.pythonhosted.org/packages/b1/de/021c1d407befb505791764ad2cbd56ceaaa53a746baed01d2e2143f05f18/tensorboard-2.18.0-py3-none-any.whl", hash = "sha256:107ca4821745f73e2aefa02c50ff70a9b694f39f790b11e6f682f7d326745eab", size = 5503036 }, @@ -1974,28 +1994,28 @@ name = "tensorflow" version = "2.18.0" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "absl-py", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, - { name = "astunparse", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, - { name = "flatbuffers", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, - { name = "gast", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, - { name = "google-pasta", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, - { name = "grpcio", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, - { name = "h5py", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, - { name = "keras", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, - { name = "libclang", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, - { name = "ml-dtypes", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, - { name = "numpy", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, - { name = "opt-einsum", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, - { name = "packaging", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, - { name = "protobuf", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, - { name = "requests", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, - { name = "setuptools", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, - { name = "six", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, - { name = "tensorboard", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, - { name = "tensorflow-io-gcs-filesystem", marker = "(python_full_version < '3.12' and platform_machine == 'aarch64' and sys_platform == 'linux') or (python_full_version < '3.12' and platform_machine == 'x86_64' and sys_platform == 'linux') or (python_full_version < '3.12' and platform_machine == 'arm64' and sys_platform == 'darwin')" }, - { name = "termcolor", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, - { name = "typing-extensions", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, - { name = "wrapt", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, + { name = "absl-py", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "astunparse", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "flatbuffers", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "gast", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "google-pasta", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "grpcio", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "h5py", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "keras", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "libclang", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "ml-dtypes", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "numpy", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "opt-einsum", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "packaging", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "protobuf", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "requests", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "setuptools", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "six", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "tensorboard", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "tensorflow-io-gcs-filesystem", marker = "(python_full_version < '3.12' and platform_machine == 'arm64' and sys_platform == 'darwin') or (python_full_version < '3.12' and platform_machine == 'aarch64' and sys_platform == 'linux') or (python_full_version < '3.12' and platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "termcolor", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "typing-extensions", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "wrapt", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, ] wheels = [ { url = "https://files.pythonhosted.org/packages/26/08/556c4159675c1a30e077ec2a942eeeb81b457cc35c247a5b4a59a1274f05/tensorflow-2.18.0-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:453cb60638a02fd26316fb36c8cbcf1569d33671f17c658ca0cf2b4626f851e7", size = 239492146 }, @@ -2026,8 +2046,8 @@ name = "tensorstore" version = "0.1.69" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "ml-dtypes", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, - { name = "numpy", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, + { name = "ml-dtypes", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "numpy", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, ] sdist = { url = "https://files.pythonhosted.org/packages/12/0e/03b2797d25f70b30c807bfef1aa8ce09731e7a82a406ba84bb67ddfc34df/tensorstore-0.1.69.tar.gz", hash = "sha256:150cdc7e2044b7629ea466bfc8425ab50e52c8300950dbdd4a64445a2d4fbab1", size = 6585496 } wheels = [ @@ -2097,26 +2117,26 @@ name = "torch" version = "2.5.1" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "filelock", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, - { name = "fsspec", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, - { name = "jinja2", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, - { name = "networkx", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, - { name = "nvidia-cublas-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux' and sys_platform == 'linux'" }, - { name = "nvidia-cuda-cupti-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux' and sys_platform == 'linux'" }, - { name = "nvidia-cuda-nvrtc-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux' and sys_platform == 'linux'" }, - { name = "nvidia-cuda-runtime-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux' and sys_platform == 'linux'" }, - { name = "nvidia-cudnn-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux' and sys_platform == 'linux'" }, - { name = "nvidia-cufft-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux' and sys_platform == 'linux'" }, - { name = "nvidia-curand-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux' and sys_platform == 'linux'" }, - { name = "nvidia-cusolver-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux' and sys_platform == 'linux'" }, - { name = "nvidia-cusparse-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux' and sys_platform == 'linux'" }, - { name = "nvidia-nccl-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux' and sys_platform == 'linux'" }, - { name = "nvidia-nvjitlink-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux' and sys_platform == 'linux'" }, - { name = "nvidia-nvtx-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux' and sys_platform == 'linux'" }, - { name = "setuptools", marker = "(python_full_version >= '3.12' and platform_machine == 'aarch64' and sys_platform == 'linux') or (python_full_version >= '3.12' and platform_machine == 'x86_64' and sys_platform == 'linux') or (python_full_version >= '3.12' and platform_machine == 'arm64' and sys_platform == 'darwin')" }, - { name = "sympy", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, - { name = "triton", marker = "python_full_version < '3.13' and platform_machine == 'x86_64' and platform_system == 'Linux' and sys_platform == 'linux'" }, - { name = "typing-extensions", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, + { name = "filelock", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "fsspec", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "jinja2", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "networkx", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "nvidia-cublas-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cuda-cupti-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cuda-nvrtc-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cuda-runtime-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cudnn-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cufft-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-curand-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cusolver-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cusparse-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-nccl-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-nvjitlink-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-nvtx-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "setuptools", marker = "(python_full_version >= '3.12' and platform_machine == 'arm64' and sys_platform == 'darwin') or (python_full_version >= '3.12' and platform_machine == 'aarch64' and sys_platform == 'linux') or (python_full_version >= '3.12' and platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "sympy", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "triton", marker = "python_full_version < '3.13' and platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "typing-extensions", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, ] wheels = [ { url = "https://files.pythonhosted.org/packages/d1/35/e8b2daf02ce933e4518e6f5682c72fd0ed66c15910ea1fb4168f442b71c4/torch-2.5.1-cp311-cp311-manylinux1_x86_64.whl", hash = "sha256:de5b7d6740c4b636ef4db92be922f0edc425b65ed78c5076c43c42d362a45457", size = 906474467 }, @@ -2197,7 +2217,7 @@ name = "werkzeug" version = "3.1.3" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "markupsafe", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'darwin')" }, + { name = "markupsafe", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, ] sdist = { url = "https://files.pythonhosted.org/packages/9f/69/83029f1f6300c5fb2471d621ab06f6ec6b3324685a2ce0f9777fd4a8b71e/werkzeug-3.1.3.tar.gz", hash = "sha256:60723ce945c19328679790e3282cc758aa4a6040e4bb330f53d30fa546d44746", size = 806925 } wheels = [