Hey everyone, thanks for developing this library. I'd like to use block-sparse matmul with jax, and this project seems to deliver just what we need 👍 However, I am having trouble getting examples/pallas/blocksparse_matmul.py to run. When installing from HEAD, I run into compatibility problems with ptxas. Help with this would be much appreciated.
As far as I understand, the ptxas shipped with CUDA 11.8 supports PTX version 7.8, and the one shipped with CUDA 12.0 supports 8.0. As noted below, I installed jaxlib against my local CUDA 11.8. Considering the traceback below, I am wondering whether jax-triton requires CUDA 12 in its current form? If so, I would be happy to get a recommendation for which jax and jax-triton commits to install from source.
Traceback
2023-07-27 13:17:05.422836: W external/xla/xla/service/gpu/runtime/support.cc:58] Intercepted XLA runtime error:
INTERNAL: ptxas exited with non-zero error code 65280, output: ptxas /tmp/tempfile-taurusi8032-f5aaecd1-56106-601761aa90836, line 5; fatal : Unsupported .version 8.0; current version is '7.8'
ptxas fatal : Ptx assembly aborted due to errors
2023-07-27 13:17:05.422953: E external/xla/xla/pjrt/pjrt_stream_executor_client.cc:2537] Execution of replica 0 failed: INTERNAL: Failed to execute XLA Runtime executable: run time error: custom call 'xla.gpu.custom_call' failed: ptxas exited with non-zero error code 65280, output: ptxas /tmp/tempfile-taurusi8032-f5aaecd1-56106-601761aa90836, line 5; fatal : Unsupported .version 8.0; current version is '7.8'
ptxas fatal : Ptx assembly aborted due to errors
; current tracing scope: custom-call.0; current profiling annotation: XlaModule:#hlo_module=jit_sdd_matmul,program_id=292#.
Traceback (most recent call last):
File "/filesystem/my/workspace/Code/BlockSparse/jax_triton_blocksparse.py", line 227, in <module>
app.run(main)
File "/filesystem/my/workspace/pythonenv/jax/lib/python3.10/site-packages/absl/app.py", line 308, in run
_run_main(main, args)
File "/filesystem/my/workspace/pythonenv/jax/lib/python3.10/site-packages/absl/app.py", line 254, in _run_main
sys.exit(main(argv))
File "/filesystem/my/workspace/Code/BlockSparse/jax_triton_blocksparse.py", line 201, in main
sdd_matmul(x, y, bn=bn, debug=True).block_until_ready()
File "/filesystem/my/workspace/pythonenv/jax/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 166, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/filesystem/my/workspace/pythonenv/jax/lib/python3.10/site-packages/jax/_src/pjit.py", line 250, in cache_miss
outs, out_flat, out_tree, args_flat, jaxpr = _python_pjit_helper(
File "/filesystem/my/workspace/pythonenv/jax/lib/python3.10/site-packages/jax/_src/pjit.py", line 163, in _python_pjit_helper
out_flat = pjit_p.bind(*args_flat, **params)
File "/filesystem/my/workspace/pythonenv/jax/lib/python3.10/site-packages/jax/_src/core.py", line 2578, in bind
return self.bind_with_trace(top_trace, args, params)
File "/filesystem/my/workspace/pythonenv/jax/lib/python3.10/site-packages/jax/_src/core.py", line 382, in bind_with_trace
out = trace.process_primitive(self, map(trace.full_raise, args), params)
File "/filesystem/my/workspace/pythonenv/jax/lib/python3.10/site-packages/jax/_src/core.py", line 814, in process_primitive
return primitive.impl(*tracers, **params)
File "/filesystem/my/workspace/pythonenv/jax/lib/python3.10/site-packages/jax/_src/pjit.py", line 1223, in _pjit_call_impl
return xc._xla.pjit(name, f, call_impl_cache_miss, [], [], donated_argnums,
File "/filesystem/my/workspace/pythonenv/jax/lib/python3.10/site-packages/jax/_src/pjit.py", line 1207, in call_impl_cache_miss
out_flat, compiled = _pjit_call_impl_python(
File "/filesystem/my/workspace/pythonenv/jax/lib/python3.10/site-packages/jax/_src/pjit.py", line 1163, in _pjit_call_impl_python
return compiled.unsafe_call(*args), compiled
File "/filesystem/my/workspace/pythonenv/jax/lib/python3.10/site-packages/jax/_src/profiler.py", line 314, in wrapper
return func(*args, **kwargs)
File "/filesystem/my/workspace/pythonenv/jax/lib/python3.10/site-packages/jax/_src/interpreters/pxla.py", line 1344, in __call__
results = self.xla_executable.execute_sharded(input_bufs)
jax._src.traceback_util.UnfilteredStackTrace: jaxlib.xla_extension.XlaRuntimeError: INTERNAL: Failed to execute XLA Runtime executable: run time error: custom call 'xla.gpu.custom_call' failed: ptxas exited with non-zero error code 65280, output: ptxas /tmp/tempfile-taurusi8032-f5aaecd1-56106-601761aa90836, line 5; fatal : Unsupported .version 8.0; current version is '7.8'
ptxas fatal : Ptx assembly aborted due to errors
; current tracing scope: custom-call.0; current profiling annotation: XlaModule:#hlo_module=jit_sdd_matmul,program_id=292#.
The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.
--------------------
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/filesystem/my/workspace/Code/BlockSparse/jax_triton_blocksparse.py", line 227, in <module>
app.run(main)
File "/filesystem/my/workspace/pythonenv/jax/lib/python3.10/site-packages/absl/app.py", line 308, in run
_run_main(main, args)
File "/filesystem/my/workspace/pythonenv/jax/lib/python3.10/site-packages/absl/app.py", line 254, in _run_main
sys.exit(main(argv))
File "/filesystem/my/workspace/Code/BlockSparse/jax_triton_blocksparse.py", line 201, in main
sdd_matmul(x, y, bn=bn, debug=True).block_until_ready()
jaxlib.xla_extension.XlaRuntimeError: INTERNAL: Failed to execute XLA Runtime executable: run time error: custom call 'xla.gpu.custom_call' failed: ptxas exited with non-zero error code 65280, output: ptxas /tmp/tempfile-taurusi8032-f5aaecd1-56106-601761aa90836, line 5; fatal : Unsupported .version 8.0; current version is '7.8'
ptxas fatal : Ptx assembly aborted due to errors
; current tracing scope: custom-call.0; current profiling annotation: XlaModule:#hlo_module=jit_sdd_matmul,program_id=292#.
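For anyone hitting a similar failure, the key line in the traceback is the `Unsupported .version 8.0; current version is '7.8'` message: the generated PTX declares version 8.0, while the ptxas that was invoked only understands up to 7.8. A minimal sketch for pulling both versions out of such a message follows; the helper name `parse_ptx_version_mismatch` and the shortened temp-file path in the sample string are made up for illustration and are not part of jax or jax-triton.

```python
import re

# Matches the ptxas complaint seen in the traceback above.
_PTX_VERSION_ERROR = re.compile(
    r"Unsupported \.version (\d+)\.(\d+); current version is '(\d+)\.(\d+)'"
)


def parse_ptx_version_mismatch(message):
    """Return ((requested), (supported)) PTX versions, or None if no match.

    Hypothetical helper for illustration only.
    """
    match = _PTX_VERSION_ERROR.search(message)
    if match is None:
        return None
    req_major, req_minor, sup_major, sup_minor = (int(g) for g in match.groups())
    return (req_major, req_minor), (sup_major, sup_minor)


# Sample message; the temp-file path is shortened and not the real one.
error = (
    "ptxas /tmp/tempfile, line 5; fatal : "
    "Unsupported .version 8.0; current version is '7.8'"
)
mismatch = parse_ptx_version_mismatch(error)
if mismatch is not None:
    requested, supported = mismatch
    print(f"PTX requested: {requested}, ptxas supports up to: {supported}")
```

If the requested version is newer than the supported one, the ptxas being invoked is older than the toolchain that generated the PTX.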
Environment
Working on a university cluster with installed modules for Python 3.10, CUDA 11.8 and cuDNN 8.6. Upon loading the modules, they appear in the $PATH, and $CUDA_HOME is properly set to the directory (e.g. nvcc and ptxas are located here).
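To double-check which ptxas the environment actually resolves and which CUDA release it belongs to, something like the sketch below may help. It assumes that `ptxas --version` prints a line containing `release <major>.<minor>`; the function names are hypothetical, and the script only reports what it finds on `$PATH`.

```python
import re
import shutil
import subprocess


def parse_ptxas_release(version_text):
    """Extract (major, minor) from `ptxas --version` output, or None.

    Assumes the output contains a fragment like 'release 11.8'.
    """
    match = re.search(r"release (\d+)\.(\d+)", version_text)
    if match is None:
        return None
    return int(match.group(1)), int(match.group(2))


def ptxas_release_on_path():
    """Return (path, release) for the first ptxas on $PATH, or None."""
    ptxas = shutil.which("ptxas")
    if ptxas is None:
        return None
    result = subprocess.run(
        [ptxas, "--version"], capture_output=True, text=True, check=False
    )
    return ptxas, parse_ptxas_release(result.stdout)


found = ptxas_release_on_path()
if found is None:
    print("no ptxas found on PATH")
else:
    print("ptxas at %s, CUDA release %s" % found)
```

This only inspects the ptxas found on `$PATH`; a library that bundles its own ptxas binary may still use a different one.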
The issue might be due to Triton depending on CUDA 12; see line 124 of Triton's setup.py. On a cluster with CUDA 12-ready NVIDIA drivers, the issue can be fixed by installing a recent nvcc from conda:
conda install cuda-nvcc -c nvidia
and then installing jaxlib, jax and jax-triton.
Note that I encountered this issue as well, see jax-ml/jax#25344.
In the end, the solution given there was to pass the correct ptxas to triton_lib.
The root cause is that the Triton wheel (3.1.0) ships its own ptxas binary (built with llvm+nvptx+CUDA 12.4), which may differ from the system ptxas or one that is already installed, as Mark pointed out.
Setting the environment variable TRITON_PTXAS_PATH=$CUDA_HOME/bin/ptxas fixes it, as it allows Triton to favor the installed ptxas over the bundled one.
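The same workaround can also be applied from Python rather than the shell. The sketch below is only an illustration: `prefer_system_ptxas` is a made-up helper, and it assumes the variable has to be set before Triton (or jax_triton) is imported, which seems the safe ordering but is not confirmed here.

```python
import os


def prefer_system_ptxas(environ, is_file=os.path.isfile):
    """Point TRITON_PTXAS_PATH at $CUDA_HOME/bin/ptxas if that file exists.

    Hypothetical helper. Leaves an already-set TRITON_PTXAS_PATH untouched.
    Returns the value in effect afterwards, or None if nothing suitable
    was found.
    """
    if "TRITON_PTXAS_PATH" in environ:
        return environ["TRITON_PTXAS_PATH"]
    cuda_home = environ.get("CUDA_HOME")
    if not cuda_home:
        return None
    candidate = os.path.join(cuda_home, "bin", "ptxas")
    if not is_file(candidate):
        return None
    environ["TRITON_PTXAS_PATH"] = candidate
    return candidate


# Assumed ordering: run this before importing jax_triton / triton.
chosen = prefer_system_ptxas(os.environ)
print("TRITON_PTXAS_PATH:", chosen)
```

The `is_file` parameter exists only so the logic can be exercised without a real CUDA installation; in normal use the default is enough.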
Environment
Working on a university cluster with installed modules for Python 3.10, CUDA 11.8 and cuDNN 8.6. Upon loading the modules, they appear in the $PATH, and $CUDA_HOME is properly set to the directory (e.g. nvcc and ptxas are located here).
I installed the jaxlib build matching my CUDA and cuDNN versions:
pip install "jaxlib @ https://storage.googleapis.com/jax-releases/nightly/cuda118/jaxlib-0.4.14.dev20230714+cuda11.cudnn86-cp310-cp310-manylinux2014_x86_64.whl"
Then I installed jax-triton from HEAD, as recommended:
pip install 'jax-triton @ git+https://github.com/jax-ml/jax-triton.git'
Pip list shows (selection):
Executing
ptxas --version
yields
(Side note: Using a stable jaxlib (0.4.13) and stable jax-triton (0.1.3) yields the already reported import error #157.)