Add check for GPU availability in attention (NVIDIA#1287)
* check if GPU is available

Signed-off-by: Charlene Yang <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Signed-off-by: Charlene Yang <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
cyanguwa and pre-commit-ci[bot] authored Oct 29, 2024
1 parent d710c24 commit 8bdb54f
Showing 1 changed file with 5 additions and 3 deletions.
8 changes: 5 additions & 3 deletions transformer_engine/pytorch/attention.py
@@ -134,7 +134,7 @@ def _get_supported_versions(version_min, version_max):
 try:
     _flash_attn_version = PkgVersion(get_pkg_version("flash-attn"))
 except PackageNotFoundError:
-    if get_device_compute_capability() >= (8, 0) and _NVTE_FLASH_ATTN:
+    if torch.cuda.is_available() and get_device_compute_capability() >= (8, 0) and _NVTE_FLASH_ATTN:
         fa_logger.debug(
             "flash-attn v2 is not installed. To use, please install it by"
             """ "pip install flash-attn".""",
@@ -158,7 +158,9 @@ def _get_supported_versions(version_min, version_max):
         _flash_attn_2_4_1_plus = _flash_attn_version >= PkgVersion("2.4.1")
         _flash_attn_2_5_7_plus = _flash_attn_version >= PkgVersion("2.5.7")
         _flash_attn_2_6_0_plus = _flash_attn_version >= PkgVersion("2.6.0")
-    elif get_device_compute_capability() >= (8, 0) and _NVTE_FLASH_ATTN:
+    elif (
+        torch.cuda.is_available() and get_device_compute_capability() >= (8, 0) and _NVTE_FLASH_ATTN
+    ):
         fa_logger.warning(
             "Supported flash-attn versions are %s. Found flash-attn %s.",
             _get_supported_versions(
@@ -183,7 +185,7 @@ def _get_supported_versions(version_min, version_max):
 try:
     _flash_attn_3_version = PkgVersion(get_pkg_version("flashattn-hopper"))
 except PackageNotFoundError:
-    if get_device_compute_capability() >= (9, 0) and _NVTE_FLASH_ATTN:
+    if torch.cuda.is_available() and get_device_compute_capability() >= (9, 0) and _NVTE_FLASH_ATTN:
         fa_logger.debug(
             "flash-attn v3 is not installed. To use, please install it by \n%s",
             _flash_attn_3_installation_steps,
