Skip to content

Commit

Permalink
Merge branch 'main' into feat/alibi
Browse files Browse the repository at this point in the history
  • Loading branch information
monk.detective committed Oct 18, 2023
2 parents bb4fa9f + 02ac572 commit 2b4226a
Show file tree
Hide file tree
Showing 24 changed files with 1,424 additions and 850 deletions.
16 changes: 7 additions & 9 deletions .github/workflows/publish.yml
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ jobs:
# manylinux docker image, but I haven't figured out how to install CUDA on manylinux.
os: [ubuntu-20.04]
python-version: ['3.7', '3.8', '3.9', '3.10', '3.11']
torch-version: ['1.12.1', '1.13.1', '2.0.1', '2.1.0.dev20230731']
torch-version: ['1.12.1', '1.13.1', '2.0.1', '2.1.0']
cuda-version: ['11.6.2', '11.7.1', '11.8.0', '12.1.0', '12.2.0']
# We need separate wheels that either uses C++11 ABI (-D_GLIBCXX_USE_CXX11_ABI) or not.
# Pytorch wheels currently don't use it, but nvcr images have Pytorch compiled with C++11 ABI.
Expand All @@ -58,7 +58,7 @@ jobs:
# Pytorch >= 2.0 only supports Python >= 3.8
- torch-version: '2.0.1'
python-version: '3.7'
- torch-version: '2.1.0.dev20230731'
- torch-version: '2.1.0'
python-version: '3.7'
# Pytorch <= 2.0 only supports CUDA <= 11.8
- torch-version: '1.12.1'
Expand All @@ -73,17 +73,15 @@ jobs:
cuda-version: '12.1.0'
- torch-version: '2.0.1'
cuda-version: '12.2.0'
# Pytorch >= 2.1 only supports CUDA >= 12.1
- torch-version: '2.1.0.dev20230731'
# Pytorch >= 2.1 only supports CUDA >= 11.8
- torch-version: '2.1.0'
cuda-version: '11.6.2'
- torch-version: '2.1.0.dev20230731'
- torch-version: '2.1.0'
cuda-version: '11.7.1'
- torch-version: '2.1.0.dev20230731'
cuda-version: '11.8.0'
# Pytorch >= 2.1 with nvcc 12.1.0 segfaults during compilation, so
# we only use CUDA 12.2. setup.py as a special case that will
# download the wheel for CUDA 12.2 instead.
- torch-version: '2.1.0.dev20230731'
- torch-version: '2.1.0'
cuda-version: '12.1.0'

steps:
Expand Down Expand Up @@ -132,7 +130,7 @@ jobs:
# We want to figure out the CUDA version to download pytorch
# e.g. we can have system CUDA version being 11.7 but if torch==1.12 then we need to download the wheel from cu116
# This code is ugly, maybe there's a better way to do this.
export TORCH_CUDA_VERSION=$(python -c "import os; minv = {'1.12': 113, '1.13': 116, '2.0': 117, '2.1': 121}[os.environ['MATRIX_TORCH_VERSION']]; maxv = {'1.12': 116, '1.13': 117, '2.0': 118, '2.1': 121}[os.environ['MATRIX_TORCH_VERSION']]; print(max(min(int(os.environ['MATRIX_CUDA_VERSION']), maxv), minv))")
export TORCH_CUDA_VERSION=$(python -c "import os; minv = {'1.12': 113, '1.13': 116, '2.0': 117, '2.1': 118}[os.environ['MATRIX_TORCH_VERSION']]; maxv = {'1.12': 116, '1.13': 117, '2.0': 118, '2.1': 121}[os.environ['MATRIX_TORCH_VERSION']]; print(max(min(int(os.environ['MATRIX_CUDA_VERSION']), maxv), minv))")
if [[ ${{ matrix.torch-version }} == *"dev"* ]]; then
pip install --no-cache-dir --pre torch==${{ matrix.torch-version }} --index-url https://download.pytorch.org/whl/nightly/cu${TORCH_CUDA_VERSION}
else
Expand Down
105 changes: 101 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ Please cite and credit FlashAttention if you use it.
Requirements:
- CUDA 11.6 and above.
- PyTorch 1.12 and above.
- Linux. Might work for Windows starting v2.3.2 (we've seen a few positive [reports](https://github.com/Dao-AILab/flash-attention/issues/595)) but Windows compilation still requires more testing. If you have ideas on how to set up prebuilt CUDA wheels for Windows, please reach out via Github issue.

We recommend the
[Pytorch](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/pytorch)
Expand Down Expand Up @@ -81,29 +82,35 @@ from flash_attn import flash_attn_qkvpacked_func, flash_attn_func
```

```python
flash_attn_qkvpacked_func(qkv, dropout_p=0.0, softmax_scale=None, causal=False):
flash_attn_qkvpacked_func(qkv, dropout_p=0.0, softmax_scale=None, causal=False, window_size=(-1, -1)):
"""dropout_p should be set to 0.0 during evaluation
If Q, K, V are already stacked into 1 tensor, this function will be faster than
calling flash_attn_func on Q, K, V since the backward pass avoids explicit concatenation
of the gradients of Q, K, V.
If window_size != (-1, -1), implements sliding window local attention. Query at position i
will only attend to keys between [i - window_size[0], i + window_size[1]] inclusive.
Arguments:
qkv: (batch_size, seqlen, 3, nheads, headdim)
dropout_p: float. Dropout probability.
softmax_scale: float. The scaling of QK^T before applying softmax.
Default to 1 / sqrt(headdim).
causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
window_size: (left, right). If not (-1, -1), implements sliding window local attention.
Return:
out: (batch_size, seqlen, nheads, headdim).
"""
```

```python
flash_attn_func(q, k, v, dropout_p=0.0, softmax_scale=None, causal=False):
flash_attn_func(q, k, v, dropout_p=0.0, softmax_scale=None, causal=False, window_size=(-1, -1)):
"""dropout_p should be set to 0.0 during evaluation
Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head
0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V.
If window_size != (-1, -1), implements sliding window local attention. Query at position i
will only attend to keys between
[i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.
Arguments:
q: (batch_size, seqlen, nheads, headdim)
Expand All @@ -113,15 +120,86 @@ Arguments:
softmax_scale: float. The scaling of QK^T before applying softmax.
Default to 1 / sqrt(headdim).
causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
window_size: (left, right). If not (-1, -1), implements sliding window local attention.
Return:
out: (batch_size, seqlen, nheads, headdim).
"""
```

```python
def flash_attn_with_kvcache(
q,
k_cache,
v_cache,
k=None,
v=None,
rotary_cos=None,
rotary_sin=None,
cache_seqlens: Optional[Union[(int, torch.Tensor)]] = None,
cache_batch_idx: Optional[torch.Tensor] = None,
softmax_scale=None,
causal=False,
window_size=(-1, -1), # -1 means infinite context window
rotary_interleaved=True,
):
"""
If k and v are not None, k_cache and v_cache will be updated *inplace* with the new values from
k and v. This is useful for incremental decoding: you can pass in the cached keys/values from
the previous step, and update them with the new keys/values from the current step, and do
attention with the updated cache, all in 1 kernel.
Also apply rotary embedding if rotary_cos and rotary_sin are passed in. The key @k will be
rotated by rotary_cos and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc.
If causal or local (i.e., window_size != (-1, -1)), the query @q will be rotated by rotary_cos
and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc.
If not causal and not local, the query @q will be rotated by rotary_cos and rotary_sin at
indices cache_seqlens only (i.e. we consider all tokens in @q to be at position cache_seqlens).
See tests/test_flash_attn.py::test_flash_attn_kvcache for examples of how to use this function.
Note: Does not support backward pass.
Arguments:
q: (batch_size, seqlen, nheads, headdim)
k_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim)
v_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim)
k [optional]: (batch_size, seqlen_new, nheads_k, headdim). If not None, we concatenate
k with k_cache, starting at the indices specified by cache_seqlens.
v [optional]: (batch_size, seqlen_new, nheads_k, headdim). Similar to k.
rotary_cos [optional]: (seqlen_ro, rotary_dim / 2). If not None, we apply rotary embedding
to k and q. Only applicable if k and v are passed in. rotary_dim must be divisible by 16.
rotary_sin [optional]: (seqlen_ro, rotary_dim / 2). Similar to rotary_cos.
cache_seqlens: int, or (batch_size,), dtype torch.int32. The sequence lengths of the
KV cache.
cache_batch_idx: (batch_size,), dtype torch.int32. The indices used to index into the KV cache.
If None, we assume that the batch indices are [0, 1, 2, ..., batch_size - 1].
If the indices are not distinct, and k and v are provided, the values updated in the cache
might come from any of the duplicate indices.
softmax_scale: float. The scaling of QK^T before applying softmax.
Default to 1 / sqrt(headdim).
causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
window_size: (left, right). If not (-1, -1), implements sliding window local attention.
rotary_interleaved: bool. Only applicable if rotary_cos and rotary_sin are passed in.
If True, rotary embedding will combine dimensions 0 & 1, 2 & 3, etc. If False,
rotary embedding will combine dimensions 0 & rotary_dim / 2, 1 & rotary_dim / 2 + 1
(i.e. GPT-NeoX style).
num_splits: int. If > 1, split the key/value into this many chunks along the sequence.
If num_splits == 1, we don't split the key/value. If num_splits == 0, we use a heuristic
to automatically determine the number of splits.
Don't change this unless you know what you are doing.
Return:
out: (batch_size, seqlen, nheads, headdim).
"""
```

To see how these functions are used in a multi-head attention layer (which
includes QKV projection, output projection), see the MHA [implementation](https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/modules/mha.py).

## Upgrading from FlashAttention (1.x) to FlashAttention-2
## Changelog

### 2.0: Complete rewrite, 2x faster
Upgrading from FlashAttention (1.x) to FlashAttention-2

These functions have been renamed:
- `flash_attn_unpadded_func` -> `flash_attn_varlen_func`
Expand All @@ -136,7 +214,7 @@ flash_attn_qkvpacked_func(qkv, dropout_p=0.0, softmax_scale=None, causal=False)
```python
flash_attn_func(q, k, v, dropout_p=0.0, softmax_scale=None, causal=False)
```
## Changes in v2.1 (compared to v2.0)
### 2.1: Change behavior of causal flag

If seqlen_q != seqlen_k and causal=True, the causal mask is aligned to the
bottom right corner of the attention matrix, instead of the top-left corner.
Expand Down Expand Up @@ -165,6 +243,25 @@ v2.1:
1 1
If the row of the mask is all zero, the output will be zero.

### 2.2: Optimize for inference

Optimize for inference (iterative decoding) when query has very small sequence
length (e.g., query sequence length = 1). The bottleneck here is to load KV
cache as fast as possible, and we split the loading across different thread
blocks, with a separate kernel to combine results.

See the function `flash_attn_with_kvcache` with more features for inference
(perform rotary embedding, updating KV cache inplace).

Thanks to the xformers team, and in particular Daniel Haziza, for this
collaboration.

### 2.3: Local (i.e., sliding window) attention

Implement sliding window attention (i.e., local attention). Thanks to [Mistral
AI](https://mistral.ai/) and in particular Timothée Lacroix for this
contribution. Sliding window was used in the [Mistral 7B](https://mistral.ai/news/announcing-mistral-7b/) model.

## Performance

We present expected speedup (combined forward + backward pass) and memory savings from using FlashAttention against PyTorch standard attention, depending on sequence length, on different GPUs (speedup depends on memory bandwidth - we see more speedup on slower GPU memory).
Expand Down
2 changes: 1 addition & 1 deletion csrc/cutlass
Submodule cutlass updated 26 files
+1 −1 examples/13_two_tensor_op_fusion/kernel/b2b_gemm.h
+1 −1 examples/16_ampere_tensorop_conv2dfprop/ampere_tensorop_conv2dfprop.cu
+1 −1 examples/23_ampere_gemm_operand_reduction_fusion/ampere_gemm_operand_reduction_fusion.cu
+1 −1 examples/44_multi_gemm_ir_and_codegen/ir_gen/gen_kernel.py
+1 −1 examples/45_dual_gemm/kernel/dual_gemm.h
+1 −1 examples/48_hopper_warp_specialized_gemm/48_hopper_warp_specialized_gemm.cu
+1 −1 include/cute/arch/mma_sm90_desc.hpp
+1 −1 include/cutlass/gemm/kernel/sm90_tile_scheduler_stream_k.hpp
+14 −3 tools/library/CMakeLists.txt
+15 −38 tools/library/scripts/generator.py
+0 −365 tools/library/src/reference/gemm.cu
+120 −0 tools/library/src/reference/gemm_e4m3a_e4m3out.cu
+111 −0 tools/library/src/reference/gemm_e4m3a_e5m2out.cu
+111 −0 tools/library/src/reference/gemm_e5m2a_e4m3out.cu
+111 −0 tools/library/src/reference/gemm_e5m2a_e5m2out.cu
+112 −0 tools/library/src/reference/gemm_fp32out.cu
+0 −418 tools/library/src/reference/gemm_fp8.cu
+93 −0 tools/library/src/reference/gemm_fp8in_bf16out.cu
+93 −0 tools/library/src/reference/gemm_fp8in_fp16out.cu
+93 −0 tools/library/src/reference/gemm_fp8in_fp32out.cu
+88 −0 tools/library/src/reference/gemm_fp_other.cu
+129 −0 tools/library/src/reference/gemm_int4.cu
+122 −0 tools/library/src/reference/gemm_int8_canonical.cu
+129 −0 tools/library/src/reference/gemm_int8_interleaved_32.cu
+129 −0 tools/library/src/reference/gemm_int8_interleaved_64.cu
+32 −4 tools/library/src/reference/initialize_reference_operations.cu
Loading

0 comments on commit 2b4226a

Please sign in to comment.