diff --git a/op_builder/builder.py b/op_builder/builder.py index 461281d4a569..ab26054bda7d 100644 --- a/op_builder/builder.py +++ b/op_builder/builder.py @@ -415,10 +415,11 @@ def cpu_arch(self): return '-mcpu=native' return '-march=native' - def is_cuda_enable(self): + def get_cuda_compile_flag(self): try: - assert_no_cuda_mismatch(self.name) - return '-D__ENABLE_CUDA__' + if not self.is_rocm_pytorch(): + assert_no_cuda_mismatch(self.name) + return "-D__ENABLE_CUDA__" except MissingCUDAException: print(f"{WARNING} {self.name} cuda is missing or is incompatible with installed torch, " "only cpu ops can be compiled!") @@ -839,7 +840,7 @@ def cxx_args(self): CPU_ARCH = self.cpu_arch() SIMD_WIDTH = self.simd_width() - CUDA_ENABLE = self.is_cuda_enable() + CUDA_ENABLE = self.get_cuda_compile_flag() args += [ CPU_ARCH, '-fopenmp',