[WIP] FFPA: Yet another Faster Flash Prefill Attention with O(1) SRAM complexity & O(d/4) or O(1) register complexity for large headdim (D > 256), almost 1.5x~2x faster than SDPA EA with or without MMA Acc F32 on many devices, such as NVIDIA L20, A30, 4090, and 3080 Laptop (experimental). The FFPA kernels are modified from my repo CUDA-Learn-Notes.
NOTE: This project is still in its early dev stages and currently provides a few experimental kernels and benchmarks for reference. More features will be added in the future. Welcome to star this repo to support me ~
@misc{cuffpa-py2025,
title={FFPA: Yet another Faster Flash Prefill Attention for large headdim.},
url={https://github.com/DefTruth/cuffpa-py.git},
note={Open-source software available at https://github.com/DefTruth/cuffpa-py.git},
author={DefTruth and others},
year={2025}
}
- Installation
- Python Testing
- FFPA L1~L3 Design
- FFPA L1: L20 ~1.7x↑
- FFPA L1: A30 ~1.5x↑
- FFPA L1: 3080 ~2.5x↑
- FFPA L1: 4090 ~1.8x↑
We have extended FlashAttention for large headdim (D > 256) by implementing Fine-grained Tiling at the MMA level (GEMM style) for the Q@K^T and P@V matmuls. This approach results in a constant SRAM usage of Br * 16 or Bc * 16 for Q, K, and V, leading to an overall SRAM complexity of O(Br * 16) ≈ O(1) and a register complexity of O(d/4) or O(1). Consequently, this method allows us to extend headdim beyond 256 and achieve faster performance than SDPA with or without MMA Accumulation F32 (almost 1.5x~2x faster than SDPA EA).
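As a minimal illustration of the tiling scheme (not the actual CUDA kernels in this repo), the NumPy sketch below runs a FlashAttention-style online softmax while processing the head dimension d in 16-wide slices for both Q@K^T and P@V, so the per-step Q/K/V working set is Br x 16 (or Bc x 16) regardless of d. The tile sizes Br = Bc = 64 and all names here are illustrative assumptions.

```python
# Reference sketch (single head, NumPy): online softmax with fine-grained
# tiling of the head dimension into 16-wide slices. Tile sizes and names are
# illustrative only, not the actual kernel parameters used in this repo.
import numpy as np

def ffpa_like_reference(Q, K, V, Br=64, Bc=64, dtile=16):
    N, d = Q.shape
    scale = 1.0 / np.sqrt(d)
    O = np.zeros((N, d), dtype=np.float32)
    for i in range(0, N, Br):
        Qi = Q[i:i + Br]                                        # (Br, d) logical tile
        m = np.full((Qi.shape[0],), -np.inf, dtype=np.float32)  # running row max
        l = np.zeros((Qi.shape[0],), dtype=np.float32)          # running row sum
        Oi = np.zeros((Qi.shape[0], d), dtype=np.float32)
        for j in range(0, N, Bc):
            Kj, Vj = K[j:j + Bc], V[j:j + Bc]                   # (Bc, d) logical tiles
            # Q @ K^T accumulated over 16-wide slices of d (the fine-grained part)
            S = np.zeros((Qi.shape[0], Kj.shape[0]), dtype=np.float32)
            for k0 in range(0, d, dtile):
                S += Qi[:, k0:k0 + dtile] @ Kj[:, k0:k0 + dtile].T
            S *= scale
            # online softmax rescaling, FlashAttention-2 style
            m_new = np.maximum(m, S.max(axis=1))
            P = np.exp(S - m_new[:, None])
            alpha = np.exp(m - m_new)
            l = alpha * l + P.sum(axis=1)
            Oi *= alpha[:, None]
            # P @ V, again one 16-wide slice of d at a time
            for k0 in range(0, d, dtile):
                Oi[:, k0:k0 + dtile] += P @ Vj[:, k0:k0 + dtile]
            m = m_new
        O[i:i + Br] = Oi / l[:, None]
    return O
```

Checking the output against a naive softmax(Q@K^T/sqrt(d))@V reference agrees to fp32 rounding error; the actual kernels map the same dataflow onto MMA tiles in registers and shared memory.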
We have named this new attention tiling technique FFPA: Faster Flash Prefill Attention. We have designed three levels (L1~L3) of FFPA based on SRAM and register complexity considerations. None of the levels introduces any additional VRAM requirements, so the HBM memory complexity remains the same as FlashAttention.
- L1: level 1, O(Brx16)≈O(1) SRAM complexity, ≈O(d/4) register complexity.
- L2: level 2, O(Brx16)≈O(1) SRAM complexity, ≈O(1) register complexity + Q@K^T recomputation.
- L3: level 3, O(Brx16)≈O(1) SRAM complexity, ≈O(1) register complexity + scaling O via HBM offloading.
By leveraging this approach, we can achieve better performance for large headdim (D > 256), striking a balance between FlashAttention (which is not designed to support D > 256) and SDPA EA. An approximate SRAM and register complexity analysis for L1~L3 is as follows (d = headdim; C, Br, Bc = constants; Br = Bc):
|Complexity|FFPA L1|FFPA L2|FFPA L3|FA-2|
|:---:|:---:|:---:|:---:|:---:|
|SRAM|O(2xBrx16)≈O(1)|O(2xBrx16)≈O(1)|O(2xBrx16)≈O(1)|≈O(3xBrxd), d↑|
|Register|≈O(d/4), d↑|O((Bc/16)x4+2C)≈O(1)|O((Bc/16)x4+2C)≈O(1)|≈O(d/2), d↑|
|HBM|≈FA2|≈FA2|≈FA2|=FA2|
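To make the table concrete, here is a rough back-of-the-envelope fp16 SRAM estimate at d = 1024, using an illustrative tile size Br = Bc = 64 (the actual kernel tile sizes may differ):

```python
# Rough fp16 SRAM estimate per tile step at d = 1024 (illustrative Br = Bc = 64).
Br, d, bytes_per_half = 64, 1024, 2
fa2_sram  = 3 * Br * d  * bytes_per_half   # FA-2: Q, K, V tiles of shape (Br, d) -> 384 KB
ffpa_sram = 2 * Br * 16 * bytes_per_half   # FFPA: two 16-wide slices             ->   4 KB
print(fa2_sram // 1024, "KB vs", ffpa_sram // 1024, "KB")
```

Under these assumptions, 384 KB would already exceed the shared memory available per SM on current GPUs, while the FFPA footprint stays constant as d grows.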
- Python >= 3.10
- PyTorch >= 2.4.0, CUDA >= 12.4
- Recommended: PyTorch 2.5.1, CUDA 12.5
The FFPA implemented in this repo can be installed as a Python library, namely the cuffpa-py library (optional).
git clone https://github.com/DefTruth/cuffpa-py.git
# after cloning, either run bash .dev/install.sh directly or run the commands below:
python3 setup.py bdist_wheel && rm -rf *.egg-info # build 'cuffpa-py' from sources
cd dist && python3 -m pip install cuffpa_py-*-linux_x86_64.whl # pip uninstall cuffpa-py -y
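After the wheel is installed, a minimal smoke test is simply importing the package. The module name `cuffpa_py` below is assumed from the wheel file name; the Python API is still evolving in this WIP repo, so see tests/test.py for the up-to-date entry points.

```python
# Minimal post-install smoke test (module name assumed from the wheel file name).
import torch
import cuffpa_py  # assumed import name; check tests/test.py for the actual API

print(torch.cuda.is_available())  # FFPA kernels require a CUDA-capable device
print(cuffpa_py.__file__)         # confirms the extension package was installed
```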
L1: level 1, O(2xBrx16)≈O(1) SRAM complexity, O(d/4) register complexity, the same GPU HBM memory complexity as FlashAttention. B=1, H=48, N=8192, D=320-1024 (FA2 not supported). (Notes: * = MMA Acc F32, ^ = MMA Acc F16, Softmax Acc dtype is always F32, T = TFLOPS.)
- NVIDIA L20 (* = MMA Acc F32, ^ = MMA Acc F16, T = TFLOPS, ~1.7x↑)
|Algorithm|320|384|448|512|576|640|704|768|832|896|960|1024|
|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|
|SDPA EA|56T|64T|58T|58T|55T|56T|54T|55T|54T|55T|54T|56T|
|FFPA L1*|98T|100T|102T|94T|94T|93T|93T|92T|90T|91T|90T|91T|
|Speedup|1.75x|1.56x|1.76x|1.62x|1.71x|1.66x|1.72x|1.67x|1.67x|1.65x|1.67x|1.62x|
|FFPA L1^|96T|97T|101T|98T|100T|92T|92T|90T|90T|90T|89T|89T|
|Speedup|1.71x|1.52x|1.74x|1.69x|1.82x|1.64x|1.7x|1.64x|1.67x|1.64x|1.65x|1.59x|
- NVIDIA A30 (* = MMA Acc F32, ^ = MMA Acc F16, T = TFLOPS, ~1.5x↑)
|Algorithm|320|384|448|512|576|640|704|768|832|896|960|1024|
|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|
|SDPA EA|25T|25T|19T|22T|23T|23T|20T|22T|22T|22T|22T|18T|
|FFPA L1*|31T|31T|31T|30T|31T|30T|30T|30T|29T|28T|29T|28T|
|Speedup|1.24x|1.24x|1.63x|1.36x|1.35x|1.3x|1.5x|1.36x|1.32x|1.27x|1.32x|1.56x|
|FFPA L1^|31T|31T|32T|31T|31T|31T|31T|30T|30T|30T|29T|29T|
|Speedup|1.24x|1.24x|1.68x|1.41x|1.35x|1.35x|1.55x|1.36x|1.36x|1.36x|1.32x|1.61x|
- NVIDIA RTX 3080 Laptop (* = MMA Acc F32, ^ = MMA Acc F16, T = TFLOPS, ~2.5x↑)
|Algorithm|320|384|448|512|576|640|704|768|832|896|960|1024|
|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|
|SDPA EA|13T|16T|12T|16T|15T|15T|15T|15T|15T|15T|15T|15T|
|FFPA L1*|32T|30T|30T|28T|28T|27T|26T|25T|25T|25T|25T|24T|
|Speedup|2.48x|1.88x|2.55x|1.75x|1.90x|1.77x|1.73x|1.67x|1.66x|1.66x|1.66x|1.54x|
|FFPA L1^|40T|38T|39T|36T|35T|34T|33T|32T|31T|31T|28T|27T|
|Speedup|3.07x|2.42x|3.33x|2.24x|2.35x|2.19x|2.19x|2.13x|2.03x|2.03x|1.90x|1.74x|
- NVIDIA RTX 4090 (* = MMA Acc F32, ^ = MMA Acc F16, T = TFLOPS, ~1.8x↑)
|Algorithm|320|384|448|512|576|640|704|768|832|896|960|1024|
|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|
|SDPA EA|80T|94T|86T|85T|79T|81T|79T|81T|79T|80T|79T|72T|
|FFPA L1*|135T|140T|143T|135T|134T|134T|134T|134T|131T|131T|130T|131T|
|Speedup|1.69x|1.49x|1.66x|1.59x|1.7x|1.65x|1.7x|1.65x|1.66x|1.64x|1.65x|1.82x|
|FFPA L1^|153T|155T|157T|157T|159T|157T|157T|156T|151T|151T|150T|153T|
|Speedup|1.91x|1.65x|1.83x|1.85x|2.01x|1.94x|1.99x|1.93x|1.91x|1.89x|1.9x|2.12x|
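In the tables above, the Speedup rows correspond to the FFPA TFLOPS divided by the SDPA EA baseline at the same headdim; for example, for D=320 on the RTX 4090:

```python
# Speedup = FFPA TFLOPS / SDPA EA TFLOPS at the same headdim (D=320, RTX 4090 row).
sdpa_ea, ffpa_l1_f32 = 80, 135
print(f"{ffpa_l1_f32 / sdpa_ea:.2f}x")  # 1.69x, matching the table
```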
You can test many custom FFPA kernels via Python and compare their performance.
# You can test on many devices, such as Volta, Ampere, Ada, Hopper, ...
cd tests && python3 test.py --B 1 --H 48 --N 8192 --show-all --D 320
- case: B=1, H=48, N=8192, D=320 (FA2 not supported), Device=NVIDIA RTX 4090.
python3 test.py --B 1 --H 48 --N 8192 --show-all --D 320 # NVIDIA RTX 4090
-----------------------------B=1, H=48, N=8192, D=320, Warmup: 1, Iters: 5-----------------------
(sdpa): ['-0.01750183 '], time:50.36ms, TFLOPS:82.19 (+0.00 %)(~1.00x)
(ffpa+acc+f32+L1+stage1): ['-0.01754761 '], time:40.23ms, TFLOPS:102.87(+25.17%)(~1.25x)
(ffpa+acc+f32+L1+stage2): ['-0.01754761 '], time:30.35ms, TFLOPS:136.34(+32.54%)(~1.66x)
(ffpa+acc+f16+L1+stage1): ['-0.01747131 '], time:31.03ms, TFLOPS:133.27(+0.00 %)(~1.62x)
(ffpa+acc+f16+L1+stage2): ['-0.01747131 '], time:26.98ms, TFLOPS:153.41(+12.51%)(~1.87x)
-------------------------------------------------------------------------------------------------
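The TFLOPS column in the log above can be reproduced from the measured time, assuming the usual 4*B*H*N^2*D FLOP count for the two attention matmuls (Q@K^T and P@V):

```python
# Reproduce the SDPA EA TFLOPS from the log above (B=1, H=48, N=8192, D=320).
B, H, N, D = 1, 48, 8192, 320
time_ms = 50.36                            # SDPA EA time from the log
flops = 4 * B * H * N * N * D              # ~4.12e12 FLOPs for Q@K^T + P@V
print(f"{flops / (time_ms * 1e-3) / 1e12:.2f} TFLOPS")  # ~81.9, close to the 82.19 reported
```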
- case: Generate a benchmark table on your own device (welcome to PR your benchmark table).
python3 test.py --gen-bench --show-all # NVIDIA RTX 4090
|Algorithm|320|384|448|512|576|640|704|768|832|896|960|1024|
|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|
|SDPA EA|80T|94T|86T|85T|79T|81T|79T|81T|79T|80T|79T|72T|
|FFPA L1*|135T|140T|143T|135T|134T|134T|134T|134T|131T|131T|130T|131T|
|Speedup|1.69x|1.49x|1.66x|1.59x|1.7x|1.65x|1.7x|1.65x|1.66x|1.64x|1.65x|1.82x|
|FFPA L1^|153T|155T|157T|157T|159T|157T|157T|156T|151T|151T|150T|153T|
|Speedup|1.91x|1.65x|1.83x|1.85x|2.01x|1.94x|1.99x|1.93x|1.91x|1.89x|1.9x|2.12x|
GNU General Public License v3.0
How to contribute? Welcome to star this repo to support me ~