Skip to content

Commit

Permalink
Add support for .cu->.cpp copy
Browse files Browse the repository at this point in the history
Signed-off-by: Luka Govedič <[email protected]>
  • Loading branch information
Luka Govedič committed Jan 29, 2025
1 parent 598a000 commit c9da588
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 9 deletions.
36 changes: 27 additions & 9 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -284,21 +284,39 @@ elseif (VLLM_GPU_LANG STREQUAL "HIP")
list(APPEND FA3_GEN_SRCS_CU ${FILE_HIP})
endforeach ()

# TODO: copy cpp->cu for correct hipification
# - try copying into gen/ or maybe even directly into build-tree (make sure that it's where hipify would copy it)
# These files are "converted" to .cu before being passed to torch.build_extension on upstream.
# We need to do the same so that hipify treats them correctly. We copy the files in the source tree like upstream.
set(VLLM_FA2_CPP_CU_SRCS
# csrc/flash_attn_ck/flash_api.cpp # only contains declarations & PyBind
csrc/flash_attn_ck/flash_common.cpp
csrc/flash_attn_ck/mha_bwd.cpp
csrc/flash_attn_ck/mha_fwd_kvcache.cpp
csrc/flash_attn_ck/mha_fwd.cpp
csrc/flash_attn_ck/mha_varlen_bwd.cpp
csrc/flash_attn_ck/mha_varlen_fwd.cpp
)

foreach(CPP_FILE ${VLLM_FA2_CPP_CU_SRCS})
string(REGEX REPLACE "\.cpp$" ".cu" CU_FILE ${CPP_FILE})
set(CU_FILE_ABS ${CMAKE_CURRENT_SOURCE_DIR}/${CU_FILE})
set(CPP_FILE_ABS ${CMAKE_CURRENT_SOURCE_DIR}/${CPP_FILE})
add_custom_command(
OUTPUT ${CU_FILE_ABS}
COMMAND ${CMAKE_COMMAND} -E copy ${CPP_FILE_ABS} ${CU_FILE_ABS}
DEPENDS ${CPP_FILE_ABS}
COMMENT "Copying ${CPP_FILE} to ${CU_FILE_ABS}"
)
list(APPEND VLLM_FA2_CU_SRCS ${CU_FILE}) # relative to source dir
endforeach ()

# This target automatically depends on the copy by depending on copied files
define_gpu_extension_target(
_vllm_fa2_C
DESTINATION vllm_flash_attn
LANGUAGE ${VLLM_GPU_LANG}
SOURCES
# csrc/flash_attn_ck/flash_api.cu # only contains declarations & PyBind
csrc/flash_attn_ck/flash_api_torch_lib.cpp
csrc/flash_attn_ck/flash_common.cu
csrc/flash_attn_ck/mha_bwd.cu
csrc/flash_attn_ck/mha_fwd_kvcache.cu
csrc/flash_attn_ck/mha_fwd.cu
csrc/flash_attn_ck/mha_varlen_bwd.cu
csrc/flash_attn_ck/mha_varlen_fwd.cu
${VLLM_FA2_CU_SRCS}
${FA3_GEN_SRCS_CU}
COMPILE_FLAGS ${VLLM_FA_GPU_FLAGS}
USE_SABI 3
Expand Down
2 changes: 2 additions & 0 deletions csrc/flash_attn_ck/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
# Renamed from .cpp during build
*.cu

0 comments on commit c9da588

Please sign in to comment.