From e43d8e9448daa9be0686b3c17aa685969d3e74cf Mon Sep 17 00:00:00 2001 From: Seunghoon Lee Date: Sat, 11 Jan 2025 23:36:51 +0900 Subject: [PATCH] check hipblaslt availability in windows --- installer.py | 16 +++++---- modules/rocm.py | 18 ++++------ modules/zluda_hijacks.py | 2 +- modules/zluda_installer.py | 67 ++++++++++++++++++++------------------ 4 files changed, 53 insertions(+), 50 deletions(-) diff --git a/installer.py b/installer.py index 7b351c799..91496ac49 100644 --- a/installer.py +++ b/installer.py @@ -569,6 +569,7 @@ def install_rocm_zluda(): msg += f', using agent {device.name}' log.info(msg) torch_command = '' + if sys.platform == "win32": # TODO install: enable ROCm for windows when available @@ -584,17 +585,20 @@ def install_rocm_zluda(): try: if args.reinstall: zluda_installer.uninstall() - zluda_path = zluda_installer.get_path() - zluda_installer.install(zluda_path) - zluda_installer.make_copy(zluda_path) + zluda_installer.install() except Exception as e: error = e log.warning(f'Failed to install ZLUDA: {e}') + if error is None: try: - zluda_installer.load(zluda_path) + if device is not None and zluda_installer.get_blaslt_enabled(): + log.debug(f'ROCm hipBLASLt: arch={device.name} available={device.blaslt_supported}') + zluda_installer.set_blaslt_enabled(device.blaslt_supported) + zluda_installer.make_copy() + zluda_installer.load() torch_command = os.environ.get('TORCH_COMMAND', f'torch=={zluda_installer.get_default_torch_version(device)} torchvision --index-url https://download.pytorch.org/whl/cu118') - log.info(f'Using ZLUDA in {zluda_path}') + log.info(f'Using ZLUDA in {zluda_installer.path}') except Exception as e: error = e log.warning(f'Failed to load ZLUDA: {e}') @@ -631,7 +635,7 @@ def install_rocm_zluda(): #elif not args.experimental: # uninstall('flash-attn') - if device is not None and rocm.version != "6.2" and rocm.version == rocm.version_torch and rocm.get_blaslt_enabled(): + if device is not None and rocm.version != "6.2" and rocm.get_blaslt_enabled(): log.debug(f'ROCm hipBLASLt: arch={device.name} available={device.blaslt_supported}') rocm.set_blaslt_enabled(device.blaslt_supported) diff --git a/modules/rocm.py b/modules/rocm.py index ef76a1cfa..704ae19c8 100644 --- a/modules/rocm.py +++ b/modules/rocm.py @@ -8,10 +8,6 @@ from enum import Enum -HIPBLASLT_TENSILE_LIBPATH = os.environ.get("HIPBLASLT_TENSILE_LIBPATH", None if sys.platform == "win32" # not available - else "/opt/rocm/lib/hipblaslt/library") - - def resolve_link(path_: str) -> str: if not os.path.islink(path_): return path_ @@ -55,8 +51,7 @@ class Agent: gfx_version: int arch: MicroArchitecture is_apu: bool - if sys.platform != "win32": - blaslt_supported: bool + blaslt_supported: bool @staticmethod def parse_gfx_version(name: str) -> int: @@ -83,8 +78,7 @@ def __init__(self, name: str): else: self.arch = MicroArchitecture.GCN self.is_apu = (self.gfx_version & 0xFFF0 == 0x1150) or self.gfx_version in (0x801, 0x902, 0x90c, 0x1013, 0x1033, 0x1035, 0x1036, 0x1103,) - if sys.platform != "win32": - self.blaslt_supported = os.path.exists(os.path.join(HIPBLASLT_TENSILE_LIBPATH, f"extop_{name}.co")) + self.blaslt_supported = os.path.exists(os.path.join(blaslt_tensile_libpath, f"Kernels.so-000-{name}.hsaco" if sys.platform == "win32" else f"extop_{name}.co")) def get_gfx_version(self) -> Union[str, None]: if self.gfx_version >= 0x1200: @@ -163,6 +157,7 @@ def get_agents() -> List[Agent]: return [Agent(x.split(' ')[-1].strip()) for x in spawn("hipinfo", cwd=os.path.join(path, 'bin')).split("\n") if x.startswith('gcnArchName:')] is_wsl: bool = False + version_torch = None else: def find() -> Union[str, None]: rocm_path = shutil.which("hipconfig") @@ -199,12 +194,12 @@ def load_hsa_runtime() -> None: def set_blaslt_enabled(enabled: bool) -> None: if enabled: load_library_global("/opt/rocm/lib/libhipblaslt.so") # Preload hipBLASLt. - os.environ["HIPBLASLT_TENSILE_LIBPATH"] = HIPBLASLT_TENSILE_LIBPATH + os.environ["HIPBLASLT_TENSILE_LIBPATH"] = blaslt_tensile_libpath else: os.environ["TORCH_BLAS_PREFER_HIPBLASLT"] = "0" def get_blaslt_enabled() -> bool: - return bool(int(os.environ.get("TORCH_BLAS_PREFER_HIPBLASLT", "1"))) + return version == version_torch and bool(int(os.environ.get("TORCH_BLAS_PREFER_HIPBLASLT", "1"))) def get_flash_attention_command(agent: Agent): if os.environ.get("FLASH_ATTENTION_USE_TRITON_ROCM", "FALSE") == "TRUE": @@ -215,10 +210,11 @@ def get_flash_attention_command(agent: Agent): return os.environ.get("FLASH_ATTENTION_PACKAGE", default) is_wsl: bool = os.environ.get('WSL_DISTRO_NAME', 'unknown' if spawn('wslpath -w /') else None) is not None + version_torch = get_version_torch() path = find() +blaslt_tensile_libpath = os.environ.get("HIPBLASLT_TENSILE_LIBPATH", os.path.join(path, "bin" if sys.platform == "win32" else "lib", "hipblaslt", "library")) is_installed = False version = None -version_torch = get_version_torch() if path is not None: is_installed = True version = get_version() diff --git a/modules/zluda_hijacks.py b/modules/zluda_hijacks.py index cd41bcb72..01ca4dec8 100644 --- a/modules/zluda_hijacks.py +++ b/modules/zluda_hijacks.py @@ -36,5 +36,5 @@ def do_hijack(): torch.fft.ifftn = fft_ifftn torch.fft.rfftn = fft_rfftn - if not zluda_installer.experimental_hipBLASLt_support: + if not zluda_installer.get_blaslt_enabled(): torch.jit.script = jit_script diff --git a/modules/zluda_installer.py b/modules/zluda_installer.py index a3cc6d7ab..6c74ee6b8 100644 --- a/modules/zluda_installer.py +++ b/modules/zluda_installer.py @@ -16,12 +16,10 @@ } HIPSDK_TARGETS = ['rocblas.dll', 'rocsolver.dll'] ZLUDA_TARGETS = ('nvcuda.dll', 'nvml.dll',) -experimental_hipBLASLt_support: bool = False -default_agent: Union[rocm.Agent, None] = None - -def get_path() -> str: - return os.path.abspath(os.environ.get('ZLUDA', '.zluda')) +path = os.path.abspath(os.environ.get('ZLUDA', '.zluda')) +default_agent: Union[rocm.Agent, None] = None +hipBLASLt_enabled = os.path.exists(os.path.join(rocm.path, "bin", "hipblaslt.dll")) and os.path.exists(rocm.blaslt_tensile_libpath) and os.path.exists(os.path.join(path, 'cublasLt.dll')) def set_default_agent(agent: rocm.Agent): @@ -29,9 +27,8 @@ def set_default_agent(agent: rocm.Agent): default_agent = agent -def install(zluda_path: os.PathLike) -> None: - if os.path.exists(zluda_path): - __initialize(zluda_path) +def install() -> None: + if os.path.exists(path): return platform = "windows" @@ -46,7 +43,6 @@ def install(zluda_path: os.PathLike) -> None: info.filename = os.path.basename(info.filename) archive.extract(info, '.zluda') os.remove('_zluda') - __initialize(zluda_path) def uninstall() -> None: @@ -54,27 +50,45 @@ def uninstall() -> None: shutil.rmtree('.zluda') -def make_copy(zluda_path: os.PathLike) -> None: - __initialize(zluda_path) +def set_blaslt_enabled(enabled: bool): + global hipBLASLt_enabled # pylint: disable=global-statement + hipBLASLt_enabled = enabled + + +def get_blaslt_enabled() -> bool: + return hipBLASLt_enabled + + +def link_or_copy(src: os.PathLike, dst: os.PathLike): + try: + os.link(src, dst) + except Exception: + shutil.copyfile(src, dst) + +def make_copy() -> None: for k, v in DLL_MAPPING.items(): - if not os.path.exists(os.path.join(zluda_path, v)): - try: - os.link(os.path.join(zluda_path, k), os.path.join(zluda_path, v)) - except Exception: - shutil.copyfile(os.path.join(zluda_path, k), os.path.join(zluda_path, v)) + if not os.path.exists(os.path.join(path, v)): + link_or_copy(os.path.join(path, k), os.path.join(path, v)) + + if hipBLASLt_enabled and not os.path.exists(os.path.join(path, 'cublasLt64_11.dll')): + link_or_copy(os.path.join(path, 'cublasLt.dll'), os.path.join(path, 'cublasLt64_11.dll')) -def load(zluda_path: os.PathLike) -> None: +def load() -> None: os.environ["ZLUDA_COMGR_LOG_LEVEL"] = "1" os.environ["ZLUDA_NVRTC_LIB"] = os.path.join([v for v in site.getsitepackages() if v.endswith("site-packages")][0], "torch", "lib", "nvrtc64_112_0.dll") for v in HIPSDK_TARGETS: ctypes.windll.LoadLibrary(os.path.join(rocm.path, 'bin', v)) for v in ZLUDA_TARGETS: - ctypes.windll.LoadLibrary(os.path.join(zluda_path, v)) + ctypes.windll.LoadLibrary(os.path.join(path, v)) for v in DLL_MAPPING.values(): - ctypes.windll.LoadLibrary(os.path.join(zluda_path, v)) + ctypes.windll.LoadLibrary(os.path.join(path, v)) + + if hipBLASLt_enabled: + ctypes.windll.LoadLibrary(os.path.join(rocm.path, 'bin', 'hipblaslt.dll')) + ctypes.windll.LoadLibrary(os.path.join(path, 'cublasLt64_11.dll')) def conceal(): import torch # pylint: disable=unused-import @@ -94,18 +108,7 @@ def _join_rocm_home(*paths) -> str: def get_default_torch_version(agent: Optional[rocm.Agent]) -> str: if agent is not None: if agent.arch in (rocm.MicroArchitecture.RDNA, rocm.MicroArchitecture.CDNA,): - return "2.4.1" if experimental_hipBLASLt_support else "2.3.1" + return "2.4.1" if hipBLASLt_enabled else "2.3.1" elif agent.arch == rocm.MicroArchitecture.GCN: return "2.2.1" - return "2.4.1" if experimental_hipBLASLt_support else "2.3.1" - - -def __initialize(zluda_path: os.PathLike): - global experimental_hipBLASLt_support # pylint: disable=global-statement - experimental_hipBLASLt_support = os.path.exists(os.path.join(zluda_path, 'cublasLt.dll')) - - if experimental_hipBLASLt_support: - HIPSDK_TARGETS.append('hipblaslt.dll') - DLL_MAPPING['cublasLt.dll'] = 'cublasLt64_11.dll' - else: - HIPSDK_TARGETS.append(f'hiprtc{"".join([v.zfill(2) for v in rocm.version.split(".")])}.dll') + return "2.4.1" if hipBLASLt_enabled else "2.3.1"