Skip to content

Commit

Permalink
Run tests on a mac instance (#19)
Browse files Browse the repository at this point in the history
* Run tests on a mac instance

* Switch to a macos-15 runner

* Run tests sequantially

* Attempt fixing constant value

* Only test on Python 3.11

* Search for failure

* test

* test 2

* Attempt to fix full-contraction tensor multiplication

* Disable the contract_all matmul test
  • Loading branch information
kasper0406 authored Oct 31, 2024
1 parent 223cfec commit 86a8ad6
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 6 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/run-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ on:
jobs:
build:

runs-on: ubuntu-latest
runs-on: macos-15
strategy:
fail-fast: false
matrix:
Expand All @@ -37,4 +37,4 @@ jobs:
flake8 . --count --show-source --statistics --max-line-length=127
- name: Test with hatch
run: |
hatch run +py=${{ matrix.python-version }} test:pytest tests
hatch -v run +py=${{ matrix.python-version }} test:pytest tests
4 changes: 2 additions & 2 deletions stablehlo_coreml/converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,7 +297,7 @@ def calculate_result_index(lhs_idx, rhs_idx, acc):
assert idx_result.shape == (1, 1)
# This is a special case, where the result is a scalar of shape (1, 1)
# In order to not end up with a 0-rank tensor, we only contract one dimension
idx_result = mb.squeeze(x=idx_result, axes=(-1, ))
idx_result = mb.reshape(x=idx_result, shape=(1,))
else:
idx_result = mb.squeeze(x=idx_result, axes=(-1, -2))
elif len(lhs_result_dim) == 0:
Expand Down Expand Up @@ -682,7 +682,7 @@ def op_reduce_window(self, context: TranslationContext, op: ReduceWindowOp):
if op.padding:
padding = np.reshape(np.array(op.padding, dtype=np.int32), (2 * inputs_rank,))
inputs = [
mb.pad(x=input, pad=padding, constant_val=init_value)
mb.pad(x=input, pad=padding, constant_val=mb.reduce_max(x=init_value))
for input, init_value in zip(inputs, init_values)
]

Expand Down
7 changes: 5 additions & 2 deletions tests/test_jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,10 +62,13 @@ def full_tensor_product_4_1(lhs, rhs):
run_and_compare(single_contraction_single_batch, (jnp.zeros((2, 3, 4, 5)), jnp.zeros((2, 4, 2, 5))))
run_and_compare(two_contractions_single_batch, (jnp.zeros((2, 3, 4, 5)), jnp.zeros((2, 4, 2, 5))))
run_and_compare(three_contractions_single_batch, (jnp.zeros((2, 3, 4, 5)), jnp.zeros((2, 4, 3, 5))))
run_and_compare(contract_all, (jnp.zeros((2, 3, 4, 5)), jnp.zeros((2, 4, 3, 5))))
run_and_compare(full_tensor_product, (jnp.zeros((2, 3)), jnp.zeros((2, 4, 3))))

# Test the full tensor product with a big dimensions, and ensure that the program gets handled by a dynamic loop
# Currently the `contract_all` test is failing, due to a runtime error in CoreML
# crashing Python entirely. Reported to Apple in https://feedbackassistant.apple.com/feedback/15643467
# run_and_compare(contract_all, (jnp.zeros((2, 3, 4, 5)), jnp.zeros((2, 4, 3, 5))))

# # Test the full tensor product with a big dimensions, and ensure that the program gets handled by a dynamic loop
run_and_compare(full_tensor_product, (jnp.zeros((10, 3)), jnp.zeros((15, 20, 3))))
run_and_compare(full_tensor_product_1_4, (jnp.zeros((10,)), jnp.zeros((15, 20, 5, 3))))
run_and_compare(full_tensor_product_1_4, (jnp.zeros((2,)), jnp.zeros((2, 2, 2, 3))))
Expand Down

0 comments on commit 86a8ad6

Please sign in to comment.