diff --git a/jax_triton/triton_lib.py b/jax_triton/triton_lib.py
index ea201d7..c2e6f85 100644
--- a/jax_triton/triton_lib.py
+++ b/jax_triton/triton_lib.py
@@ -38,6 +38,8 @@
 from jax._src.lib.mlir import ir
 import jax.dlpack
 import jax.extend as jex
+from jax.interpreters import ad
+from jax.interpreters import batching
 from jax.interpreters import mlir
 from jax.interpreters import xla
 import jax.numpy as jnp
@@ -675,6 +677,31 @@ def prune_configs(configs, named_args, **kwargs):
     platform="rocm",
 )
 
+
+def triton_kernel_call_raise_on_jvp(*args, **kwargs):
+  del args, kwargs  # unused
+  raise NotImplementedError(
+      "jax_triton.triton_call does not support automatic differentiation. Use "
+      "jax.custom_jvp or jax.custom_vjp to implement a custom automatic "
+      "differentiation rule for your kernel."
+  )
+
+ad.primitive_jvps[triton_kernel_call_p] = triton_kernel_call_raise_on_jvp
+
+
+def triton_kernel_call_raise_on_vmap(*args, **kwargs):
+  del args, kwargs  # unused
+  raise NotImplementedError(
+      "jax_triton.triton_call does not support batching with jax.vmap. Use "
+      "jax.custom_batching.custom_vmap to implement a custom batching rule for "
+      "your kernel."
+  )
+
+batching.primitive_batchers[triton_kernel_call_p] = (
+    triton_kernel_call_raise_on_vmap
+)
+
+
 class ShapeDtype(Protocol):
 
   @property
diff --git a/tests/triton_call_test.py b/tests/triton_call_test.py
index a1dfd78..239b090 100644
--- a/tests/triton_call_test.py
+++ b/tests/triton_call_test.py
@@ -531,6 +531,24 @@ def test_autotune_with_input_output_aliasing(self):
     out = add(x, y, kernel=kernel, input_output_aliases={0: 0})
     np.testing.assert_allclose(out, expected)
 
+  def test_autodiff_exception(self):
+    x, y = create_random_inputs([10, 100], dtype="float32")
+    with self.assertRaisesRegex(
+        NotImplementedError,
+        r"jax_triton.triton_call does not support automatic differentiation.*"
+        r"jax\.custom_jvp or jax\.custom_vjp.*",
+    ):
+      jax.grad(lambda x, y: jnp.sum(add(x, y, BLOCK_SIZE=32)))(x, y)
+
+  def test_batching_exception(self):
+    x, y = create_random_inputs([10, 100], dtype="float32")
+    with self.assertRaisesRegex(
+        NotImplementedError,
+        r"jax_triton.triton_call does not support batching.*"
+        r"jax\.custom_batching\.custom_vmap.*",
+    ):
+      jax.vmap(lambda x, y: add(x, y, BLOCK_SIZE=32))(x, y)
+
 
 if __name__ == "__main__":
   os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.5"