Skip to content

Commit

Permalink
Merge pull request #1 from davnn/add-tests
Browse files Browse the repository at this point in the history
✅ Add dispatch tests and remove exports
  • Loading branch information
davnn authored Oct 7, 2023
2 parents e0948d2 + c86d31f commit ee01199
Show file tree
Hide file tree
Showing 6 changed files with 123 additions and 19 deletions.
10 changes: 2 additions & 8 deletions Taskfile.yml
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,10 @@ env:

tasks:
env-create:
cmd: $CONDA env create -n $PROJECT -f env.yml
cmd: $CONDA env create -n $PROJECT --file env.yml --yes

env-remove:
cmd: $CONDA env remove -n $PROJECT
cmd: $CONDA env remove -n $PROJECT --yes

poetry-install:
cmd: curl -sSL https://install.python-poetry.org | python -
Expand All @@ -21,12 +21,6 @@ tasks:
poetry-update-dev:
cmd: poetry add pytest@latest pytest-html@latest hypothesis@latest coverage@latest pytest-cov@latest coverage-badge@latest ruff@latest pre-commit@latest black@latest pyright@latest typing-extensions@latest bandit@latest safety@latest numpy@latest torch@latest jax@latest -G dev

install-env:
cmd: $CONDA env create -n $PROJECT -f env.lock.yml

install-env-dev:
cmd: $CONDA env create -n $PROJECT -f env.dev.lock.yml

pre-commit-install:
cmd: poetry run pre-commit install

Expand Down
6 changes: 3 additions & 3 deletions assets/coverage.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ build-backend = "poetry.core.masonry.api"

[tool.poetry]
name = "safecheck"
version = "0.0.1"
version = "0.0.2"
description = "Utilities for typechecking, shapechecking and dispatch."
readme = "README.md"
authors = ["David Muhr <[email protected]>"]
Expand Down
10 changes: 4 additions & 6 deletions safecheck/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,29 +112,27 @@
"UInt64",
"Complex64",
"Complex128",
# package version
"version",
]

try:
from numpy import ndarray as NumpyArray # noqa: F401, N812

__all__.append("NumpyArray")
except ImportError:
except ImportError: # pragma: no cover
...

try:
from torch import Tensor as TorchArray # noqa: F401

__all__.append("TorchArray")
except ImportError:
except ImportError: # pragma: no cover
...

try:
from jax import Array as JaxArray # noqa: F401

__all__.append("JaxArray")
except ImportError:
except ImportError: # pragma: no cover
...


Expand All @@ -153,4 +151,4 @@ def get_version() -> str:
return "unknown"


version: str = get_version()
__version__: str = get_version()
Empty file added safecheck/py.typed
Empty file.
114 changes: 113 additions & 1 deletion tests/test_safecheck.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,17 @@
jax_array = jax.random.randint(key=jax.random.PRNGKey(0), minval=0, maxval=1, shape=(1,))

array_types = {TorchArray: torch_array, NumpyArray: np_array, JaxArray: jax_array}

array_types_str = {TorchArray: "torch", NumpyArray: "numpy", JaxArray: "jax"}
data_types = {
TorchArray: {Float: torch_array.float(), Integer: torch_array.int(), Bool: torch_array.bool()},
NumpyArray: {Float: np_array.astype(float), Integer: np_array.astype(int), Bool: np_array.astype(bool)},
JaxArray: {Float: jax_array.astype(float), Integer: jax_array.astype(int), Bool: jax_array.astype(bool)},
}
data_types_str = {
TorchArray: {Float: "torch_float", Integer: "torch_integer", Bool: "torch_bool"},
NumpyArray: {Float: "numpy_float", Integer: "numpy_integer", Bool: "numpy_bool"},
JaxArray: {Float: "jax_float", Integer: "jax_integer", Bool: "jax_bool"},
}


@pytest.mark.parametrize("array_type", array_types.keys())
Expand Down Expand Up @@ -56,3 +61,110 @@ def f(array: data_type[array_type, "..."]) -> data_type[array_type, "..."]:

with pytest.raises(BeartypeCallHintParamViolation):
f(current_array)


@pytest.mark.parametrize("array_type", data_types.keys())
def test_array_type_dispatch(array_type):
dispatch = Dispatcher()

@dispatch
def f(_: Shaped[NumpyArray, "..."]) -> str:
return "numpy"

@dispatch
def f(_: Shaped[TorchArray, "..."]) -> str:
return "torch"

@dispatch
def f(_: Shaped[JaxArray, "..."]) -> str:
return "jax"

assert array_types_str[array_type] == f(array_types[array_type])


@pytest.mark.parametrize("array_type", data_types.keys())
def test_array_type_dispatch_with_typecheck(array_type):
dispatch = Dispatcher()

@dispatch
@typecheck
def f(_: Shaped[NumpyArray, "..."]) -> str:
return "numpy"

@dispatch
@typecheck
def f(_: Shaped[TorchArray, "..."]) -> str:
return "torch"

@dispatch
@typecheck
def f(_: Shaped[JaxArray, "..."]) -> str:
return "jax"

assert array_types_str[array_type] == f(array_types[array_type])


@pytest.mark.parametrize("array_type", data_types.keys())
def test_array_type_dispatch_with_shapecheck(array_type):
dispatch = Dispatcher()

@dispatch
@shapecheck
def f(_: Shaped[NumpyArray, "..."]) -> str:
return "numpy"

@dispatch
@shapecheck
def f(_: Shaped[TorchArray, "..."]) -> str:
return "torch"

@dispatch
@shapecheck
def f(_: Shaped[JaxArray, "..."]) -> str:
return "jax"

assert array_types_str[array_type] == f(array_types[array_type])


@pytest.mark.parametrize("array_type", data_types.keys())
@pytest.mark.parametrize("data_type", next(iter(data_types.values())).keys())
def test_data_type_dispatch(array_type, data_type):
dispatch = Dispatcher()

@dispatch
def f(_: Float[NumpyArray, "..."]) -> str:
return "numpy_float"

@dispatch
def f(_: Integer[NumpyArray, "..."]) -> str:
return "numpy_integer"

@dispatch
def f(_: Bool[NumpyArray, "..."]) -> str:
return "numpy_bool"

@dispatch
def f(_: Float[TorchArray, "..."]) -> str:
return "torch_float"

@dispatch
def f(_: Integer[TorchArray, "..."]) -> str:
return "torch_integer"

@dispatch
def f(_: Bool[TorchArray, "..."]) -> str:
return "torch_bool"

@dispatch
def f(_: Float[JaxArray, "..."]) -> str:
return "jax_float"

@dispatch
def f(_: Integer[JaxArray, "..."]) -> str:
return "jax_integer"

@dispatch
def f(_: Bool[JaxArray, "..."]) -> str:
return "jax_bool"

assert data_types_str[array_type][data_type] == f(data_types[array_type][data_type])

0 comments on commit ee01199

Please sign in to comment.