
🤖 [WIP] FFPA: Yet another Faster Flash Prefill Attention with O(1) SRAM complexity & O(d/4) or O(1) register complexity for large headdim (D > 256), almost 1.5x~2x 🎉 faster than SDPA EA with or without MMA Acc F32 on many devices, such as NVIDIA L20, A30, 4090, and 3080 Laptop (Experimental 👀~). The FFPA kernels are adapted from my repo 📖CUDA-Learn-Notes.

NOTE: This project is still in its early development stage and currently provides a few experimental kernels and benchmarks for reference. More features will be added in the future. Welcome to 🌟👆🏻 star this repo to support me ~ 🎉🎉

ยฉ๏ธCitations๐ŸŽ‰๐ŸŽ‰

@misc{cuffpa-py@2025,
  title={FFPA: Yet another Faster Flash Prefill Attention for large headdim.},
  url={https://github.com/DefTruth/cuffpa-py.git},
  note={Open-source software available at https://github.com/DefTruth/cuffpa-py.git},
  author={DefTruth et al.},
  year={2025}
}

📖 Contents

  • 📖 FFPA L1~L3: FlashAttention + QKV Fine-grained Tiling
  • 📖 Prerequisites
  • 📖 Installation
  • 📖 FFPA L1 (Level 1): Benchmark
  • 📖 Python Testing
  • ©️ License
  • 🎉 Contribute
  • 📖 References

📖 FFPA L1~L3: FlashAttention + QKV Fine-grained Tiling at MMA level 🔑️

We have extended FlashAttention to support large headdim (D > 256) by implementing fine-grained tiling at the MMA level (GEMM style) for the Q@K^T and P@V matmuls. This approach keeps the SRAM usage for Q, K, and V constant at Br * 16 or Bc * 16 per tile, giving an overall SRAM complexity of O(Br * 16) ≈ O(1) and a register complexity of O(d/4) or O(1). As a result, FFPA can extend headdim beyond 256 while remaining faster than SDPA EA, with or without MMA Acc F32 (almost 1.5x~2x 🎉 faster).

We have named this new attention tiling technique FFPA: Faster Flash Prefill Attention. Based on SRAM and register complexity considerations, we have designed three levels (L1~L3) of FFPA. None of these levels introduces any additional VRAM requirements, so the HBM memory complexity remains the same as FlashAttention. 👇

  • 📚L1: level 1, O(Brx16)≈O(1) SRAM complexity, ≈O(d/4) register complexity.
  • 📚L2: level 2, O(Brx16)≈O(1) SRAM complexity, ≈O(1) register complexity + Q@K^T recomputation.
  • 📚L3: level 3, O(Brx16)≈O(1) SRAM complexity, ≈O(1) register complexity + scaling O via HBM offloading.

By leveraging this approach, we can achieve better performance for large headdim (D > 256) than both FlashAttention (which is not designed to support D > 256) and SDPA EA. An approximate SRAM and register complexity analysis for L1~L3 is as follows (d=headdim; C, Br, Bc=constants; Br=Bc): 👇

|📚Complexity|📚FFPA L1|📚FFPA L2|📚FFPA L3|📚FA-2|
|:---:|:---:|:---:|:---:|:---:|
|SRAM|O(2xBrx16)≈O(1)|O(2xBrx16)≈O(1)|O(2xBrx16)≈O(1)|≈O(3xBrxd), d↑|
|Register|≈O(d/4), d↑|O((Bc/16)x4+2C)≈O(1)|O((Bc/16)x4+2C)≈O(1)|≈O(d/2), d↑|
|HBM|≈FA2|≈FA2|≈FA2|=FA2|
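
To make the tiling idea concrete, here is a minimal PyTorch reference sketch (single batch and head, computed in FP32) of how Q@K^T and P@V can be accumulated over 16-wide headdim chunks inside a FlashAttention-2-style online-softmax loop. This is only an illustration of the tiling scheme described above, not the CUDA kernels in this repo; the names Br, Bc and the 16-wide d-chunks follow the notation in the table.

# A minimal PyTorch reference of FFPA-style fine-grained d-tiling (illustration only,
# NOT the CUDA kernels in this repo). Q@K^T and P@V are accumulated over 16-wide
# headdim chunks inside a FlashAttention-2-style online-softmax loop, so at any moment
# only a Br x 16 (or Bc x 16) slice of Q/K/V needs to be resident.
import torch

def ffpa_style_reference(q, k, v, Br=64, Bc=64, d_chunk=16):
    # q, k, v: [N, d] for a single (batch, head); d may be > 256.
    q, k, v = q.float(), k.float(), v.float()  # FP32 reference for simplicity
    N, d = q.shape
    scale = d ** -0.5
    o = torch.zeros_like(q)
    for i in range(0, N, Br):
        qi = q[i:i + Br]  # Q row block ("SRAM"-resident in the real kernel)
        m = torch.full((qi.shape[0], 1), float("-inf"), device=q.device)
        l = torch.zeros((qi.shape[0], 1), device=q.device)
        acc = torch.zeros((qi.shape[0], d), device=q.device)
        for j in range(0, N, Bc):
            kj, vj = k[j:j + Bc], v[j:j + Bc]
            # Q@K^T accumulated over 16-wide d-chunks (fine-grained MMA-level tiling)
            s = torch.zeros((qi.shape[0], kj.shape[0]), device=q.device)
            for c in range(0, d, d_chunk):
                s += qi[:, c:c + d_chunk] @ kj[:, c:c + d_chunk].T
            s *= scale
            # Online softmax rescaling, as in FlashAttention-2
            m_new = torch.maximum(m, s.max(dim=-1, keepdim=True).values)
            p = torch.exp(s - m_new)
            alpha = torch.exp(m - m_new)
            l = alpha * l + p.sum(dim=-1, keepdim=True)
            acc *= alpha
            # P@V also chunked over d, so only a Bc x 16 slice of V is live at once
            for c in range(0, d, d_chunk):
                acc[:, c:c + d_chunk] += p @ vj[:, c:c + d_chunk]
            m = m_new
        o[i:i + Br] = acc / l
    return o

For small shapes this matches torch.nn.functional.scaled_dot_product_attention up to floating-point error, which is a convenient way to sanity-check the tiling logic before reading the CUDA kernels.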

📖 Prerequisites

  • Python >= 3.10
  • PyTorch >= 2.4.0, CUDA >= 12.4
  • Recommended: PyTorch 2.5.1, CUDA 12.5
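
A quick way to check your local versions against these requirements:

# Print the local PyTorch / CUDA / GPU info to compare against the prerequisites above.
import torch
print(torch.__version__)              # expect >= 2.4.0 (2.5.1 recommended)
print(torch.version.cuda)             # expect >= 12.4 (12.5 recommended)
print(torch.cuda.get_device_name(0))  # e.g. NVIDIA L20 / A30 / RTX 4090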

📖 Installation

The FFPA kernels implemented in this repo can be installed as a Python library, namely, the cuffpa-py library (optional).

git clone https://github.com/DefTruth/cuffpa-py.git
# clone, then, run bash .dev/install.sh directly or run commands:
python3 setup.py bdist_wheel && rm -rf *.egg-info # build 'cuffpa-py' from sources
cd dist && python3 -m pip install cuffpa_py-*-linux_x86_64.whl # pip uninstall cuffpa-py -y
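
After installation, a minimal check that the wheel is visible to Python (the distribution name "cuffpa-py" is assumed from the wheel filename above; the import-level API may differ):

# Query the installed wheel's metadata; assumes the distribution name is "cuffpa-py".
from importlib.metadata import version
print(version("cuffpa-py"))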

📖 FFPA L1 (Level 1): Benchmark 🎉🎉

L1: level 1, O(2xBrx16)≈O(1) SRAM complexity, O(d/4) register complexity, and the same GPU HBM memory complexity as FlashAttention. B=1, H=48, N=8192, D=320-1024 (FA2 not supported 👀). (Notes: *=MMA Acc F32, ^=MMA Acc F16, the Softmax Acc dtype is always F32, T=TFLOPS, 👇Benchmark)

  • 📚 NVIDIA L20 (*=MMA Acc F32, ^=MMA Acc F16, T=TFLOPS, ~1.7x↑🎉)

|Algorithm|320|384|448|512|576|640|704|768|832|896|960|1024|
|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|
|SDPA EA|56T|64T|58T|58T|55T|56T|54T|55T|54T|55T|54T|56T|
|FFPA L1*|98T|100T|102T|94T|94T|93T|93T|92T|90T|91T|90T|91T|
|Speedup|1.75x|1.56x|1.76x|1.62x|1.71x|1.66x|1.72x|1.67x|1.67x|1.65x|1.67x|1.62x|
|FFPA L1^|96T|97T|101T|98T|100T|92T|92T|90T|90T|90T|89T|89T|
|Speedup|1.71x|1.52x|1.74x|1.69x|1.82x|1.64x|1.7x|1.64x|1.67x|1.64x|1.65x|1.59x|
  • 📚 NVIDIA A30 (*=MMA Acc F32, ^=MMA Acc F16, T=TFLOPS, ~1.5x↑🎉)

|Algorithm|320|384|448|512|576|640|704|768|832|896|960|1024|
|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|
|SDPA EA|25T|25T|19T|22T|23T|23T|20T|22T|22T|22T|22T|18T|
|FFPA L1*|31T|31T|31T|30T|31T|30T|30T|30T|29T|28T|29T|28T|
|Speedup|1.24x|1.24x|1.63x|1.36x|1.35x|1.3x|1.5x|1.36x|1.32x|1.27x|1.32x|1.56x|
|FFPA L1^|31T|31T|32T|31T|31T|31T|31T|30T|30T|30T|29T|29T|
|Speedup|1.24x|1.24x|1.68x|1.41x|1.35x|1.35x|1.55x|1.36x|1.36x|1.36x|1.32x|1.61x|
  • 📚 NVIDIA RTX 3080 Laptop (*=MMA Acc F32, ^=MMA Acc F16, T=TFLOPS, ~2.5x↑🎉)

|Algorithm|320|384|448|512|576|640|704|768|832|896|960|1024|
|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|
|SDPA EA|13T|16T|12T|16T|15T|15T|15T|15T|15T|15T|15T|15T|
|FFPA L1*|32T|30T|30T|28T|28T|27T|26T|25T|25T|25T|25T|24T|
|Speedup|2.48x|1.88x|2.55x|1.75x|1.90x|1.77x|1.73x|1.67x|1.66x|1.66x|1.66x|1.54x|
|FFPA L1^|40T|38T|39T|36T|35T|34T|33T|32T|31T|31T|28T|27T|
|Speedup|3.07x|2.42x|3.33x|2.24x|2.35x|2.19x|2.19x|2.13x|2.03x|2.03x|1.90x|1.74x|
  • 📚 NVIDIA RTX 4090 (*=MMA Acc F32, ^=MMA Acc F16, T=TFLOPS, ~1.8x↑🎉)

|Algorithm|320|384|448|512|576|640|704|768|832|896|960|1024|
|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|
|SDPA EA|80T|94T|86T|85T|79T|81T|79T|81T|79T|80T|79T|72T|
|FFPA L1*|135T|140T|143T|135T|134T|134T|134T|134T|131T|131T|130T|131T|
|Speedup|1.69x|1.49x|1.66x|1.59x|1.7x|1.65x|1.7x|1.65x|1.66x|1.64x|1.65x|1.82x|
|FFPA L1^|153T|155T|157T|157T|159T|157T|157T|156T|151T|151T|150T|153T|
|Speedup|1.91x|1.65x|1.83x|1.85x|2.01x|1.94x|1.99x|1.93x|1.91x|1.89x|1.9x|2.12x|
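
For reference, the SDPA EA rows above correspond to PyTorch's memory-efficient attention backend. Below is a hedged sketch of how such a baseline can be timed and converted to TFLOPS, using the standard 4*B*H*N^2*D FLOP estimate for attention; the exact timing methodology used for the tables lives in tests/test.py and may differ.

# Time PyTorch's memory-efficient SDPA backend (SDPA EA) on the benchmark shapes and
# convert to TFLOPS with the standard 4*B*H*N^2*D estimate. Sketch only.
import torch
import torch.nn.functional as F
from torch.nn.attention import sdpa_kernel, SDPBackend

B, H, N, D = 1, 48, 8192, 320
q, k, v = (torch.randn(B, H, N, D, device="cuda", dtype=torch.half) for _ in range(3))

iters = 5
with sdpa_kernel(SDPBackend.EFFICIENT_ATTENTION):  # force the EA backend
    F.scaled_dot_product_attention(q, k, v)        # warmup
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        o = F.scaled_dot_product_attention(q, k, v)
    end.record()
    torch.cuda.synchronize()

ms = start.elapsed_time(end) / iters
tflops = 4 * B * H * N * N * D / (ms * 1e-3) / 1e12
print(f"SDPA EA: {ms:.2f} ms/iter, {tflops:.2f} TFLOPS")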

📖 Python Testing

👇 You can test many custom FFPA kernels via Python and compare their performance.

# You can test on many devices, such as Volta, Ampere, Ada, Hopper, ...
cd tests && python3 test.py --B 1 --H 48 --N 8192 --show-all --D 320
  • 📚 case: B=1, H=48, N=8192, D=320 (FA2 not supported), Device=NVIDIA RTX 4090.
python3 test.py --B 1 --H 48 --N 8192 --show-all --D 320 # NVIDIA RTX 4090
-----------------------------B=1, H=48, N=8192, D=320, Warmup: 1, Iters: 5-----------------------
                   (sdpa): ['-0.01750183 '], time:50.36ms, TFLOPS:82.19 (+0.00 %)(~1.00x)
 (ffpa+acc+f32+L1+stage1): ['-0.01754761 '], time:40.23ms, TFLOPS:102.87(+25.17%)(~1.25x)
 (ffpa+acc+f32+L1+stage2): ['-0.01754761 '], time:30.35ms, TFLOPS:136.34(+32.54%)(~1.66x)
 (ffpa+acc+f16+L1+stage1): ['-0.01747131 '], time:31.03ms, TFLOPS:133.27(+0.00 %)(~1.62x)
 (ffpa+acc+f16+L1+stage2): ['-0.01747131 '], time:26.98ms, TFLOPS:153.41(+12.51%)(~1.87x)
-------------------------------------------------------------------------------------------------
  • 📚 case: Generate the benchmark table on your own device (Welcome to PR your benchmark table 🎉🎉)
python3 test.py --gen-bench --show-all # NVIDIA RTX 4090
|Algorithm|320|384|448|512|576|640|704|768|832|896|960|1024|
|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|
|SDPA EA|80T|94T|86T|85T|79T|81T|79T|81T|79T|80T|79T|72T|
|FFPA L1*|135T|140T|143T|135T|134T|134T|134T|134T|131T|131T|130T|131T|
|Speedup|1.69x|1.49x|1.66x|1.59x|1.7x|1.65x|1.7x|1.65x|1.66x|1.64x|1.65x|1.82x|
|FFPA L1^|153T|155T|157T|157T|159T|157T|157T|156T|151T|151T|150T|153T|
|Speedup|1.91x|1.65x|1.83x|1.85x|2.01x|1.94x|1.99x|1.93x|1.91x|1.89x|1.9x|2.12x|

ยฉ๏ธLicense

GNU General Public License v3.0

🎉Contribute

How to contribute? Welcome to star ⭐️ this repo to support me 👆🏻 ~

📖 References
