
[3.2.x] ptx_get_version cannot handle CUDA>12.6 #5737

Open
h-vetinari opened this issue Jan 29, 2025 · 8 comments

@h-vetinari

NVIDIA recently released CUDA 12.8, and I'm seeing failures when running Triton if it is present:

_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
cuda_version = '12.8'

    @functools.lru_cache()
    def ptx_get_version(cuda_version) -> int:
        '''
        Get the highest PTX version supported by the current CUDA driver.
        '''
        assert isinstance(cuda_version, str)
        major, minor = map(int, cuda_version.split('.'))
        if major == 12:
            if minor < 6:
                return 80 + minor
            elif minor == 6:
                return 85
        if major == 11:
            return 70 + minor
        if major == 10:
            return 63 + minor
>       raise RuntimeError("Triton only support CUDA 10.0 or higher, but got CUDA version: " + cuda_version)
E       RuntimeError: Triton only support CUDA 10.0 or higher, but got CUDA version: 12.8

../../../../../lib/python3.11/site-packages/triton/backends/nvidia/compiler.py:57: RuntimeError

IMO it would be more appropriate to use >= 6 in

    if major == 12:
        if minor < 6:
            return 80 + minor
        elif minor == 6:
            return 85

since falling back to an older PTX version is far less of a problem than the whole thing erroring out.
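For concreteness, the proposed one-line change against the 3.2.x code would look like this (a sketch, not the exact patch):

```python
import functools


@functools.lru_cache()
def ptx_get_version(cuda_version: str) -> int:
    '''
    Get the highest PTX version supported by the given CUDA version.
    With `>=` instead of `==`, minors newer than 12.6 fall back to
    PTX 8.5 rather than raising.
    '''
    assert isinstance(cuda_version, str)
    major, minor = map(int, cuda_version.split('.'))
    if major == 12:
        if minor < 6:
            return 80 + minor
        elif minor >= 6:  # was `minor == 6`; now also covers 12.7, 12.8, ...
            return 85
    if major == 11:
        return 70 + minor
    if major == 10:
        return 63 + minor
    raise RuntimeError("Triton only support CUDA 10.0 or higher, but got CUDA version: " + cuda_version)


print(ptx_get_version('12.8'))  # 85, instead of a RuntimeError
```

The generated PTX would then target ISA 8.5 rather than the newest version the toolchain knows about, which is the conservative-but-working behavior.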

@h-vetinari h-vetinari changed the title ptx_get_version cannot handle CUDA>=12.6 ptx_get_version cannot handle CUDA>12.6 Jan 29, 2025
@h-vetinari h-vetinari changed the title ptx_get_version cannot handle CUDA>12.6 [3.2.x] ptx_get_version cannot handle CUDA>12.6 Jan 29, 2025
@ThomasRaoux
Collaborator

ToT (top-of-tree) should be fine. Here is the code:

@functools.lru_cache()
def ptx_get_version(cuda_version) -> int:
    '''
    Get the highest PTX version supported by the current CUDA driver.
    '''
    assert isinstance(cuda_version, str)
    major, minor = map(int, cuda_version.split('.'))
    if major == 12:
        if minor < 6:
            return 80 + minor
        else:
            return 80 + minor - 1
    if major == 11:
        return 70 + minor
    if major == 10:
        return 63 + minor
    raise RuntimeError("Triton only support CUDA 10.0 or higher, but got CUDA version: " + cuda_version)
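So on top-of-tree, any CUDA 12.x minor maps to a PTX version instead of raising. Restating just that branch (a minimal sketch, not Triton's actual module; the `minor - 1` offset reflects that CUDA 12.6 shipped PTX ISA 8.5):

```python
def ptx_for_cuda12(minor: int) -> int:
    # Mirrors the top-of-tree 12.x branch above:
    # 12.0-12.5 -> PTX 8.<minor>, 12.6 and newer -> PTX 8.<minor - 1>
    return 80 + minor if minor < 6 else 80 + minor - 1


print([ptx_for_cuda12(m) for m in (4, 5, 6, 8)])  # [84, 85, 85, 87]
```

Unlike the hard fallback to 85 proposed for 3.2.x, this keeps tracking newer PTX versions as new CUDA 12.x minors appear.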

@h-vetinari
Author

h-vetinari commented Jan 29, 2025

I see that b39c1e1 landed on main, which is almost certainly too big for backporting, but I'm wondering if

    if major == 12:
        if minor < 6:
            return 80 + minor
-       elif minor == 6:
+       elif minor >= 6:
            return 85
    if major == 11:
        return 70 + minor

would be acceptable for 3.2.x.

@h-vetinari
Author

h-vetinari commented Jan 29, 2025

We're going to need Triton 3.2 for PyTorch 2.6, and it would be a pity if it couldn't be used with CUDA 12.8. I'm not talking about sm100 support, just being able to use a CUDA 12.8 toolchain.

@ThomasRaoux
Collaborator

@bertmaher is handling the release branch; I'll defer to him.

@h-vetinari
Author

Thanks. Can you please reopen the issue in the meantime? Otherwise the reduced visibility makes it all too easy for this to fall through the cracks.

@bertmaher
Collaborator

@atalman Can we still patch this into release/3.2.x in time for PyTorch 2.6? People will almost certainly be using CUDA 12.8 soon, and it'll be really frustrating if torch.compile doesn't work there because of this.

@h-vetinari
Author

h-vetinari commented Jan 30, 2025

From our limited testing, I can confirm that

    if major == 12:
        if minor < 6:
            return 80 + minor
-       elif minor == 6:
+       elif minor >= 6:
            return 85
    if major == 11:
        return 70 + minor

works.

@bertmaher
Collaborator

Proposing #5765 as a cherry-pick to Triton 3.2, but since we just pushed PT 2.6, I think it'll be a while before we can get this into a patch release. @atalman can clarify, hopefully.
