Skip to content

Commit

Permalink
[feat] tune block size for L1 persist kv g2s (#54)
Browse files Browse the repository at this point in the history
* Update env.py

* Update launch_templates.cuh
  • Loading branch information
DefTruth authored Jan 18, 2025
1 parent 73c89fc commit ef8cb83
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 0 deletions.
12 changes: 12 additions & 0 deletions csrc/cuffpa/launch_templates.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,24 @@ void launch_ffpa_mma_L1_template(torch::Tensor Q,
// Needs more SRAM: use a small 64x64 tile for large headdim
// and a large tile for small headdim. headdim > 256 will
// use ffpa-attn, not flash-attn.
// TODO: tune block sizes for L20/4090/3080, etc.
// Prefer a small block size on NVIDIA L20 devices.
#ifdef BUILD_FFPA_ATTN_MMA_L20
constexpr int kMmaTileSeqLenQ = (kHeadDim <= 128 || kHeadDim > 256) ? 4 : 4;
constexpr int kMmaTileSeqLenK = 1;
constexpr int kMmaTileSeqLenP = (kHeadDim <= 128 || kHeadDim > 256) ? 4 : 4;
constexpr int kMmaTileHeadDimV = 1;
constexpr int kWarpTileSeqLenQ = 1;
constexpr int kWarpTileSeqLenK = (kHeadDim <= 128 || kHeadDim > 256) ? 8 : 8;
#else
constexpr int kMmaTileSeqLenQ = (kHeadDim <= 128 || kHeadDim > 256) ? 8 : 4;
constexpr int kMmaTileSeqLenK = 1;
constexpr int kMmaTileSeqLenP = (kHeadDim <= 128 || kHeadDim > 256) ? 8 : 4;
constexpr int kMmaTileHeadDimV = 1;
constexpr int kWarpTileSeqLenQ = 1;
constexpr int kWarpTileSeqLenK = (kHeadDim <= 128 || kHeadDim > 256) ? 16 : 8;
#endif

#else
// O(1) SRAM complexity, always use large tile for large headdim.
constexpr int kMmaTileSeqLenQ = (kHeadDim <= 128) ? 8 : 8;
Expand Down
10 changes: 10 additions & 0 deletions env.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,6 +279,7 @@ def csrc(sub_dir, filename):

@staticmethod
def get_build_cuda_cflags(build_pkg: bool = False):
device_name = ENV.get_device_name()
extra_cuda_cflags = []
extra_cuda_cflags.append("-O3")
extra_cuda_cflags.append("-std=c++17")
Expand All @@ -295,6 +296,15 @@ def get_build_cuda_cflags(build_pkg: bool = False):
extra_cuda_cflags.append(
"-Xptxas -v" if not build_pkg else "--ptxas-options=-O3"
)
extra_cuda_cflags.append(
"-DBUILD_FFPA_ATTN_MMA_L20" if "L20" in device_name else ""
)
extra_cuda_cflags.append(
"-DBUILD_FFPA_ATTN_MMA_4090" if "4090" in device_name else ""
)
extra_cuda_cflags.append(
"-DBUILD_FFPA_ATTN_MMA_3080" if "3080" in device_name else ""
)
extra_cuda_cflags.extend(ENV.env_cuda_cflags())
extra_cuda_cflags.append(f"-I {ENV.project_dir()}/include")
extra_cuda_cflags.append(f"-I {ENV.project_dir()}/csrc/cuffpa")
Expand Down

0 comments on commit ef8cb83

Please sign in to comment.