From e6a8026489cedee3d39f76e1fbfecda9f278e424 Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Mon, 18 Sep 2023 16:08:44 -0700 Subject: [PATCH 01/23] [Gen] Rename max_sequence_len->max_seqlen, sequence_len_offset->seqlen_offset --- flash_attn/modules/mha.py | 48 ++++++++---------- flash_attn/utils/generation.py | 93 +++++++++++++++++----------------- tests/models/test_gpt.py | 6 +-- 3 files changed, 72 insertions(+), 75 deletions(-) diff --git a/flash_attn/modules/mha.py b/flash_attn/modules/mha.py index 233f84faf..4894daca5 100644 --- a/flash_attn/modules/mha.py +++ b/flash_attn/modules/mha.py @@ -300,7 +300,7 @@ def _update_kv_cache(kv, inference_params, layer_idx): if layer_idx not in inference_params.key_value_memory_dict: kv_cache = torch.empty( inference_params.max_batch_size, - inference_params.max_sequence_len, + inference_params.max_seqlen, 2, num_heads, head_dim, @@ -313,7 +313,7 @@ def _update_kv_cache(kv, inference_params, layer_idx): # Adjust key and value for inference batch_start = inference_params.batch_size_offset batch_end = batch_start + kv.shape[0] - sequence_start = inference_params.sequence_len_offset + sequence_start = inference_params.seqlen_offset sequence_end = sequence_start + kv.shape[1] assert batch_end <= (kv_cache.shape[0] if kv_cache is not None else v_cache.shape[0]) assert sequence_end <= (kv_cache.shape[1] if kv_cache is not None else v_cache.shape[2]) @@ -445,12 +445,12 @@ def _apply_rotary_update_kvcache_attention(self, q, kv, inference_params): q: (batch_size, seqlen_q, nheads, head_dim) kv: (batch_size, seqlen_k, 2, nheads_kv, head_dim) """ - assert inference_params is not None and inference_params.sequence_len_offset > 0 + assert inference_params is not None and inference_params.seqlen_offset > 0 assert self.use_flash_attn if self.rotary_emb_dim > 0: assert self.rotary_emb.scale is None, "This code path does not support xPos" self.rotary_emb._update_cos_sin_cache( - inference_params.max_sequence_len, device=q.device, dtype=q.dtype + inference_params.max_seqlen, device=q.device, dtype=q.dtype ) rotary_cos, rotary_sin = self.rotary_emb._cos_cached, self.rotary_emb._sin_cached else: @@ -460,7 +460,7 @@ def _apply_rotary_update_kvcache_attention(self, q, kv, inference_params): cache_seqlens = ( inference_params.lengths_per_sample[:batch] if inference_params.lengths_per_sample is not None - else inference_params.sequence_len_offset + else inference_params.seqlen_offset ) context = flash_attn_with_kvcache( q, @@ -480,11 +480,11 @@ def _apply_rotary_update_kvcache_attention(self, q, kv, inference_params): def _update_kvcache_attention(self, q, kv, inference_params): """Write kv to inference_params, then do attention""" if ( - inference_params.sequence_len_offset == 0 + inference_params.seqlen_offset == 0 or flash_attn_with_kvcache is None or not self.use_flash_attn ): - # TODO: this only uses sequence_len_offset and not lengths_per_sample. + # TODO: this only uses seqlen_offset and not lengths_per_sample. 
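For orientation, the bookkeeping that the renamed seqlen_offset field drives can be written as a small standalone function. This is a simplified sketch of the _update_kv_cache pattern above, with hypothetical names and shapes rather than the module's exact code:

import torch

def update_kv_cache(kv, kv_cache, seqlen_offset):
    # kv: (batch, seqlen_new, 2, nheads, head_dim), the keys/values produced this step
    # kv_cache: (max_batch, max_seqlen, 2, nheads, head_dim), preallocated buffer
    batch, seqlen_new = kv.shape[:2]
    # New tokens are written starting at the current offset ...
    kv_cache[:batch, seqlen_offset : seqlen_offset + seqlen_new] = kv
    # ... and attention then reads everything cached so far.
    return kv_cache[:batch, : seqlen_offset + seqlen_new]

# Prefill: seqlen_offset == 0 and the whole prompt is written at once.
# Decoding: seqlen_offset > 0 and one (or a few) tokens are appended per step.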
kv = self._update_kv_cache(kv, inference_params) return self.inner_cross_attn(q, kv) else: @@ -493,7 +493,7 @@ def _update_kvcache_attention(self, q, kv, inference_params): cache_seqlens = ( inference_params.lengths_per_sample[:batch] if inference_params.lengths_per_sample is not None - else inference_params.sequence_len_offset + else inference_params.seqlen_offset ) return flash_attn_with_kvcache( q, @@ -561,12 +561,10 @@ def forward( else ( inference_params.lengths_per_sample if inference_params.lengths_per_sample is not None - else inference_params.sequence_len_offset + else inference_params.seqlen_offset ) ) - rotary_max_seqlen = ( - inference_params.max_sequence_len if inference_params is not None else None - ) + rotary_max_seqlen = inference_params.max_seqlen if inference_params is not None else None batch, seqlen = x.shape[:2] if not self.cross_attn and self.num_heads_kv == self.num_heads: assert x_kv is None and mixer_subset is None @@ -581,7 +579,7 @@ def forward( qkv = rearrange(qkv, "... (three h d) -> ... three h d", three=3, d=self.head_dim) if ( inference_params is None - or inference_params.sequence_len_offset == 0 + or inference_params.seqlen_offset == 0 or (self.rotary_emb_dim == 0 or self.rotary_emb_dim % 16 != 0) or not self.use_flash_attn ): @@ -632,7 +630,7 @@ def forward( ).contiguous() if ( inference_params is None - or inference_params.sequence_len_offset == 0 + or inference_params.seqlen_offset == 0 or (self.rotary_emb_dim == 0 or self.rotary_emb_dim % 16 != 0) or not self.use_flash_attn ): @@ -772,12 +770,12 @@ def _apply_rotary_update_kvcache_attention(self, q, kv, inference_params): q: (batch_size, seqlen_q, nheads, head_dim) kv: (batch_size, seqlen_k, 2, nheads_kv, head_dim) """ - assert inference_params is not None and inference_params.sequence_len_offset > 0 + assert inference_params is not None and inference_params.seqlen_offset > 0 assert self.use_flash_attn if self.rotary_emb_dim > 0: assert self.rotary_emb.scale is None, "This code path does not support xPos" self.rotary_emb._update_cos_sin_cache( - inference_params.max_sequence_len, device=q.device, dtype=q.dtype + inference_params.max_seqlen, device=q.device, dtype=q.dtype ) rotary_cos, rotary_sin = self.rotary_emb._cos_cached, self.rotary_emb._sin_cached else: @@ -787,7 +785,7 @@ def _apply_rotary_update_kvcache_attention(self, q, kv, inference_params): cache_seqlens = ( inference_params.lengths_per_sample[:batch] if inference_params.lengths_per_sample is not None - else inference_params.sequence_len_offset + else inference_params.seqlen_offset ) context = flash_attn_with_kvcache( q, @@ -806,8 +804,8 @@ def _apply_rotary_update_kvcache_attention(self, q, kv, inference_params): def _update_kvcache_attention(self, q, kv, inference_params): """Write kv to inference_params, then do attention""" - if inference_params.sequence_len_offset == 0 or not self.use_flash_attn: - # TODO: this only uses sequence_len_offset and not lengths_per_sample. + if inference_params.seqlen_offset == 0 or not self.use_flash_attn: + # TODO: this only uses seqlen_offset and not lengths_per_sample. 
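When flash_attn_with_kvcache is available and we are past prefill, the cache write and the attention are fused into one kernel call instead of going through _update_kv_cache. A minimal sketch of that call, assuming the flash-attn 2.x interface and causal decoding (other arguments left at their defaults):

import torch
from flash_attn import flash_attn_with_kvcache

def decode_step(q, kv, kv_cache, cache_seqlens):
    # q: (batch, 1, nheads, head_dim); kv: (batch, 1, 2, nheads_kv, head_dim)
    # kv_cache: (batch, max_seqlen, 2, nheads_kv, head_dim)
    # cache_seqlens: int or (batch,) int32 tensor of current cache lengths
    # The kernel appends kv at position cache_seqlens and attends over the whole cache.
    return flash_attn_with_kvcache(
        q,
        kv_cache[:, :, 0],  # k_cache
        kv_cache[:, :, 1],  # v_cache
        kv[:, :, 0],
        kv[:, :, 1],
        cache_seqlens=cache_seqlens,
        causal=True,
    )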
kv = self._update_kv_cache(kv, inference_params) return self.inner_cross_attn(q, kv) else: @@ -816,7 +814,7 @@ def _update_kvcache_attention(self, q, kv, inference_params): cache_seqlens = ( inference_params.lengths_per_sample[:batch] if inference_params.lengths_per_sample is not None - else inference_params.sequence_len_offset + else inference_params.seqlen_offset ) context = flash_attn_with_kvcache( q, @@ -847,17 +845,15 @@ def forward(self, x, seqlen=None, inference_params=None, **kwargs): else ( inference_params.lengths_per_sample if inference_params.lengths_per_sample is not None - else inference_params.sequence_len_offset + else inference_params.seqlen_offset ) ) - rotary_max_seqlen = ( - inference_params.max_sequence_len if inference_params is not None else None - ) + rotary_max_seqlen = inference_params.max_seqlen if inference_params is not None else None if self.num_heads_kv == self.num_heads: qkv = rearrange(qkv, "b s (three h d) -> b s three h d", three=3, d=self.head_dim) if ( inference_params is None - or inference_params.sequence_len_offset == 0 + or inference_params.seqlen_offset == 0 or (self.rotary_emb_dim == 0 or self.rotary_emb_dim % 16 != 0) or not self.use_flash_attn ): @@ -892,7 +888,7 @@ def forward(self, x, seqlen=None, inference_params=None, **kwargs): ) if ( inference_params is None - or inference_params.sequence_len_offset == 0 + or inference_params.seqlen_offset == 0 or (self.rotary_emb_dim == 0 or self.rotary_emb_dim % 16 != 0) or not self.use_flash_attn ): diff --git a/flash_attn/utils/generation.py b/flash_attn/utils/generation.py index 6e671cef4..7afcd92cb 100644 --- a/flash_attn/utils/generation.py +++ b/flash_attn/utils/generation.py @@ -20,13 +20,20 @@ class InferenceParams: """Inference parameters that are passed to the main model in order to efficienly calculate and store the context during inference.""" - max_sequence_len: int + max_seqlen: int max_batch_size: int - sequence_len_offset: int = 0 + seqlen_offset: int = 0 batch_size_offset: int = 0 key_value_memory_dict: dict = field(default_factory=dict) lengths_per_sample: Optional[Tensor] = None + def reset(self, max_seqlen, max_batch_size): + self.max_seqlen = max_seqlen + self.max_batch_size = max_batch_size + self.seqlen_offset = 0 + if self.lengths_per_sample is not None: + self.lengths_per_sample.zero_() + # https://github.com/NVIDIA/Megatron-LM/blob/0bb597b42c53355a567aba2a1357cc34b9d99ddd/megatron/text_generation/sampling.py # https://github.com/huggingface/transformers/blob/a44985b41cfa2de48a5e1de7f1f93b7483da25d1/src/transformers/generation/logits_process.py#L231 @@ -127,19 +134,16 @@ def decode( tensor_parallel=tensor_parallel, ) inference_params = model._decoding_cache.inference_params - inference_params.max_sequence_len = max_length - inference_params.max_batch_size = batch_size - inference_params.sequence_len_offset = 0 - inference_params.lengths_per_sample.zero_() + inference_params.reset(max_length, batch_size) else: - inference_params = InferenceParams(max_sequence_len=max_length, max_batch_size=batch_size) + inference_params = InferenceParams(max_seqlen=max_length, max_batch_size=batch_size) def get_logits(input_ids, inference_params): - decoding = inference_params.sequence_len_offset > 0 + decoding = inference_params.seqlen_offset > 0 if decoding: position_ids = torch.full( (batch_size, 1), - inference_params.sequence_len_offset, + inference_params.seqlen_offset, dtype=torch.long, device=input_ids.device, ) @@ -154,24 +158,24 @@ def get_logits(input_ids, inference_params): 
).logits.squeeze(dim=1) else: logits = model._decoding_cache.run( - input_ids, position_ids, inference_params.sequence_len_offset + input_ids, position_ids, inference_params.seqlen_offset ).clone() return logits[..., :vocab_size] if vocab_size is not None else logits def sample_tokens(logits, inference_params): - if teacher_outputs is None or teacher_output_len <= inference_params.sequence_len_offset: + if teacher_outputs is None or teacher_output_len <= inference_params.seqlen_offset: token = sample(logits, top_k=top_k, top_p=top_p, temperature=temperature) else: - token = teacher_outputs[:, inference_params.sequence_len_offset] + token = teacher_outputs[:, inference_params.seqlen_offset] # return rearrange(token, "b -> b 1") return token.unsqueeze(1) def should_stop(current_token, inference_params): - if inference_params.sequence_len_offset == 0: + if inference_params.seqlen_offset == 0: return False if eos_token_id is not None and (current_token == eos_token_id).all(): return True - if inference_params.sequence_len_offset >= max_length - 1: + if inference_params.seqlen_offset >= max_length - 1: return True return False @@ -185,7 +189,7 @@ def should_stop(current_token, inference_params): scores, sequences = [], [input_ids] while not should_stop(sequences[-1], inference_params): scores.append(get_logits(sequences[-1], inference_params)) - inference_params.sequence_len_offset += sequences[-1].shape[1] + inference_params.seqlen_offset += sequences[-1].shape[1] sequences.append(sample_tokens(scores[-1], inference_params)) if enable_timing: end.record() @@ -256,6 +260,7 @@ def sample_speculative(logits, logits_draft, tokens_draft, top_k=1, top_p=0.0, t return tokens, first_rejected_idx + 1 +@torch.inference_mode() def decode_speculative( input_ids, model, @@ -303,15 +308,11 @@ def decode_speculative( tensor_parallel=tensor_parallel, ) inference_params_draft = model_draft._decoding_cache.inference_params - inference_params_draft.max_sequence_len = max_length - inference_params_draft.max_batch_size = batch_size - inference_params_draft.sequence_len_offset = 0 - inference_params = InferenceParams(max_sequence_len=max_length, max_batch_size=batch_size) + inference_params_draft.reset(max_length, batch_size) + inference_params = InferenceParams(max_seqlen=max_length, max_batch_size=batch_size) else: - inference_params_draft = InferenceParams( - max_sequence_len=max_length, max_batch_size=batch_size - ) - inference_params = InferenceParams(max_sequence_len=max_length, max_batch_size=batch_size) + inference_params_draft = InferenceParams(max_seqlen=max_length, max_batch_size=batch_size) + inference_params = InferenceParams(max_seqlen=max_length, max_batch_size=batch_size) def logits_forward_fn(model, input_ids, position_ids, inference_params, cg=False): if not cg: @@ -323,7 +324,7 @@ def logits_forward_fn(model, input_ids, position_ids, inference_params, cg=False ).logits.squeeze(dim=1) else: return model._decoding_cache.run( - input_ids, position_ids, inference_params.sequence_len_offset + input_ids, position_ids, inference_params.seqlen_offset ).clone() logits_postprocess_fn = ( @@ -365,13 +366,13 @@ def sample_tokens( assert seqlen == 1 position_ids = repeat( torch.arange(seqlen, dtype=torch.long, device=input_ids.device) - + inference_params.sequence_len_offset, + + inference_params.seqlen_offset, "s -> b s", b=batch_size, ) # position_ids = torch.full( # (batch_size, 1), - # inference_params.sequence_len_offset, + # inference_params.seqlen_offset, # dtype=torch.long, # device=input_ids.device, 
# ) @@ -380,7 +381,7 @@ def sample_tokens( logits = logits_postprocess_fn( logits_forward_fn(model, input_ids, position_ids, inference_params, cg=decoding and cg) ) - inference_params.sequence_len_offset += input_ids.shape[1] + inference_params.seqlen_offset += input_ids.shape[1] scores = [logits] next_token = sample_fn(logits) sequences.append(next_token) @@ -388,7 +389,7 @@ def sample_tokens( if i < num_tokens - 1 or last_token_logits: position_ids = torch.full( (batch_size, 1), - inference_params_draft.sequence_len_offset, + inference_params_draft.seqlen_offset, dtype=torch.long, device=input_ids.device, ) @@ -401,7 +402,7 @@ def sample_tokens( cg=cg, ) ) - inference_params.sequence_len_offset += 1 + inference_params.seqlen_offset += 1 scores.append(logits) if i < num_tokens - 1: next_token = sample_fn(logits) @@ -476,8 +477,8 @@ def sample_tokens( scores.append(logits[:1, : num_generated_tokens[0]]) # Note that @model has not evaluated the last sampled token yet, so we'll need to pass # that in the next time we call @model. - inference_params.sequence_len_offset = seqlen_og + num_generated_tokens[0].item() - 1 - inference_params_draft.sequence_len_offset = inference_params.sequence_len_offset + inference_params.seqlen_offset = seqlen_og + num_generated_tokens[0].item() - 1 + inference_params_draft.seqlen_offset = inference_params.seqlen_offset if debug: cur_ids = torch.cat([input_ids, sequences[-1]], dim=1) scores_ref = model( @@ -486,10 +487,10 @@ def sample_tokens( print((scores[-1] - scores_ref[:, :-1]).abs().max()) while True: - # sequence_len_offset is total length generated - 1 - if inference_params.sequence_len_offset >= max_length - 1: + # seqlen_offset is total length generated - 1 + if inference_params.seqlen_offset >= max_length - 1: break - if inference_params.sequence_len_offset >= max_length - 2: + if inference_params.seqlen_offset >= max_length - 2: # Don't do speculative sampling, just sample 1 token from the model tokens, scores_new = sample_tokens_main(sequences[-1][:, -1:], num_tokens=1) sequences.append(tokens) @@ -497,7 +498,7 @@ def sample_tokens( break # Sample from draft model n_spec_tokens = min( - speculative_lookahead, max_length - inference_params_draft.sequence_len_offset - 2 + speculative_lookahead, max_length - inference_params_draft.seqlen_offset - 2 ) tokens_draft, scores_draft = sample_tokens_draft( sequences[-1][:, -1:], num_tokens=n_spec_tokens @@ -510,9 +511,9 @@ def sample_tokens( # Evaluate the draft tokens with the model position_ids = repeat( torch.arange( - inference_params.sequence_len_offset, + inference_params.seqlen_offset, # 1 extra token from last time that hasn't been passed through model - inference_params.sequence_len_offset + n_spec_tokens + 1, + inference_params.seqlen_offset + n_spec_tokens + 1, dtype=torch.long, device=input_ids.device, ), @@ -525,7 +526,7 @@ def sample_tokens( inference_params=inference_params, ).logits # (batch, n_spec_tokens, vocab_size) logits = logits_postprocess_fn(logits) - inference_params.sequence_len_offset += 1 + inference_params.seqlen_offset += 1 if debug: logits_ref = model( torch.cat([cur_ids, tokens_draft], dim=1), num_last_tokens=n_spec_tokens + 1 @@ -539,8 +540,8 @@ def sample_tokens( print(num_generated_tokens) sequences.append(tokens[:1, : num_generated_tokens[0]]) scores.append(logits[:1, : num_generated_tokens[0]]) - inference_params.sequence_len_offset += num_generated_tokens[0].item() - 1 - inference_params_draft.sequence_len_offset = inference_params.sequence_len_offset + 
inference_params.seqlen_offset += num_generated_tokens[0].item() - 1 + inference_params_draft.seqlen_offset = inference_params.seqlen_offset # breakpoint() if debug: cur_ids = torch.cat([cur_ids, sequences[-1]], dim=1) @@ -679,9 +680,9 @@ def update_graph_cache( ) lengths_per_sample = torch.full((batch_size,), seqlen_og, dtype=torch.int32, device=device) cache.inference_params = InferenceParams( - max_sequence_len=max_seqlen, + max_seqlen=max_seqlen, max_batch_size=batch_size, - sequence_len_offset=seqlen_og, + seqlen_offset=seqlen_og, key_value_memory_dict=inf_cache, lengths_per_sample=lengths_per_sample, ) @@ -705,7 +706,7 @@ def dispatch(input_ids, position_ids, seqlen): ) cache.run = dispatch - cache.inference_params.sequence_len_offset = 0 # Reset so it's not confusing + cache.inference_params.seqlen_offset = 0 # Reset so it's not confusing return cache @@ -713,10 +714,10 @@ def capture_graph(model, inference_params, batch_size, max_seqlen, mempool=None, device = next(iter(model.parameters())).device input_ids = torch.full((batch_size, 1), 0, dtype=torch.long, device=device) position_ids = torch.full((batch_size, 1), 0, dtype=torch.long, device=device) - sequence_len_offset_og = inference_params.sequence_len_offset + seqlen_offset_og = inference_params.seqlen_offset # TD [2023-04-14]: important for correctness of the FT's attention kernel, as seqlen_cpu is # used to determine the size of smem. Hence seqlen_cpu must be >= lengths_per_sample. - inference_params.sequence_len_offset = max_seqlen - 1 + inference_params.seqlen_offset = max_seqlen - 1 inference_params.lengths_per_sample[:] = max_seqlen - 1 # Warmup before capture @@ -755,5 +756,5 @@ def run(new_input_ids, new_position_ids, seqlen): graph.replay() return logits.clone() - inference_params.sequence_len_offset = sequence_len_offset_og + inference_params.seqlen_offset = seqlen_offset_og return run diff --git a/tests/models/test_gpt.py b/tests/models/test_gpt.py index 2e1845147..09c6556bf 100644 --- a/tests/models/test_gpt.py +++ b/tests/models/test_gpt.py @@ -364,14 +364,14 @@ def test_gpt2_multiple_token_generation(model_name, optimized): logits_ref = model(input_ids).logits # Run 10 tokens, then pass in another 4, then another 6, to see if we get the same logits - inference_params = InferenceParams(max_sequence_len=20, max_batch_size=1) + inference_params = InferenceParams(max_seqlen=20, max_batch_size=1) logits_10 = model(input_ids[:, :10], inference_params=inference_params).logits - inference_params.sequence_len_offset += 10 + inference_params.seqlen_offset += 10 position_ids = torch.arange(10, 14, dtype=torch.long, device=device) logits_1014 = model( input_ids[:, 10:14], position_ids=position_ids, inference_params=inference_params ).logits - inference_params.sequence_len_offset += 4 + inference_params.seqlen_offset += 4 position_ids = torch.arange(14, 20, dtype=torch.long, device=device) logits_1420 = model( input_ids[:, 14:20], position_ids=position_ids, inference_params=inference_params From e0fbaa7016e30dff62992706f39cab4a3dade7c4 Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Mon, 18 Sep 2023 22:57:20 -0700 Subject: [PATCH 02/23] [Gen] Simplify decode_speculative --- flash_attn/utils/generation.py | 461 ++++++++++++++++----------------- tests/models/test_gpt.py | 12 +- 2 files changed, 229 insertions(+), 244 deletions(-) diff --git a/flash_attn/utils/generation.py b/flash_attn/utils/generation.py index 7afcd92cb..bbbe34bc1 100644 --- a/flash_attn/utils/generation.py +++ b/flash_attn/utils/generation.py @@ -159,7 +159,7 @@ 
def get_logits(input_ids, inference_params): else: logits = model._decoding_cache.run( input_ids, position_ids, inference_params.seqlen_offset - ).clone() + ).squeeze(dim=1) return logits[..., :vocab_size] if vocab_size is not None else logits def sample_tokens(logits, inference_params): @@ -305,256 +305,250 @@ def decode_speculative( batch_size, seqlen_og, max_length, + # draft model needs to process either 1 or 2 tokens at a time + decoding_seqlens=(1, 2), tensor_parallel=tensor_parallel, ) inference_params_draft = model_draft._decoding_cache.inference_params inference_params_draft.reset(max_length, batch_size) - inference_params = InferenceParams(max_seqlen=max_length, max_batch_size=batch_size) + if not hasattr(model, "_decoding_cache"): + model._decoding_cache = None + model._decoding_cache = update_graph_cache( + model, + model._decoding_cache, + batch_size, + seqlen_og, + max_length, + decoding_seqlens=range(1, speculative_lookahead + 2), + tensor_parallel=tensor_parallel, + ) + inference_params = model._decoding_cache.inference_params + inference_params.reset(max_length, batch_size) else: inference_params_draft = InferenceParams(max_seqlen=max_length, max_batch_size=batch_size) inference_params = InferenceParams(max_seqlen=max_length, max_batch_size=batch_size) - def logits_forward_fn(model, input_ids, position_ids, inference_params, cg=False): - if not cg: - return model( + def get_logits(input_ids, inference_params, model, num_last_tokens=1, cg=False): + decoding = inference_params.seqlen_offset > 0 + if decoding: + seqlen = input_ids.shape[1] + # if inference_params.lengths_per_sample is None: + # TODO: in the case of batched decoding where each sequence has a different length, + # we need to compute the position_ids for each sequence using lengths_per_sample + if True: + cache_seqlens = torch.full( + (input_ids.shape[0],), + inference_params.seqlen_offset, + dtype=torch.int32, + device=input_ids.device, + ) + else: + cache_seqlens = inference_params.lengths_per_sample + position_ids = cache_seqlens[:, None] + torch.arange( + seqlen, dtype=torch.long, device=input_ids.device + ) + else: + position_ids = None + if not cg or not decoding: + logits = model( input_ids, position_ids=position_ids, inference_params=inference_params, - num_last_tokens=1, - ).logits.squeeze(dim=1) + num_last_tokens=num_last_tokens, + ).logits else: - return model._decoding_cache.run( + # NOTE: careful, CUDA graph is set to have num_last_tokens=input_ids.shape[1]. + # This might not be compatible the num_last_tokens used here. + assert num_last_tokens <= input_ids.shape[1] + logits = model._decoding_cache.run( input_ids, position_ids, inference_params.seqlen_offset - ).clone() - - logits_postprocess_fn = ( - lambda logits: logits[..., :vocab_size] if vocab_size is not None else logits - ) + )[:, -num_last_tokens:] + return logits[..., :vocab_size] if vocab_size is not None else logits - def sample_tokens( - input_ids, - model, - inference_params, - sample_fn, - num_tokens=1, - cg=False, - decoding=True, - last_token_logits=False, - ): + def sample_tokens(input_ids, get_logits_fn, inference_params, sample_fn, num_tokens=1): """Sample `num_tokens` tokens from the model, given the previous logits. Also return the logits of the sampled tokens. Arguments: input_ids: (batch, seqlen) - decoding: whether we're in the decoding phase or the prefilling phase. Prefill doesn't - need special position_ids. - last_token_logits: whether to return the logits of the last token. Normally we don't need this. 
- However, for speculative sampling, if the main model accepts all the draft tokens, plus it - samples one new token, then by right at the next iteration the draft model need to evaluate - the logits of the last draft token and the logits of the newly sampled token. - This makes implementation more complicated. So here we just evaluate the logits of the last - token in the draft model to simplify the implementation. Return: tokens: (batch, num_tokens) scores: (batch, num_tokens), which contains @previous_logits and the logits of the next - (num_tokens - 1) tokens. The logits of the last token isn't computed unless last_token_logits=True. - In which case we have scores of shape (batch, num_tokens + 1) + (num_tokens - 1) tokens. The logits of the last token isn't computed. """ - batch_size, seqlen = input_ids.shape assert num_tokens >= 1 - sequences = [] - if decoding: - assert seqlen == 1 - position_ids = repeat( - torch.arange(seqlen, dtype=torch.long, device=input_ids.device) - + inference_params.seqlen_offset, - "s -> b s", - b=batch_size, - ) - # position_ids = torch.full( - # (batch_size, 1), - # inference_params.seqlen_offset, - # dtype=torch.long, - # device=input_ids.device, - # ) - else: - position_ids = None - logits = logits_postprocess_fn( - logits_forward_fn(model, input_ids, position_ids, inference_params, cg=decoding and cg) - ) - inference_params.seqlen_offset += input_ids.shape[1] - scores = [logits] - next_token = sample_fn(logits) - sequences.append(next_token) + sequences, scores = [input_ids], [] for i in range(num_tokens): - if i < num_tokens - 1 or last_token_logits: - position_ids = torch.full( - (batch_size, 1), - inference_params_draft.seqlen_offset, - dtype=torch.long, - device=input_ids.device, - ) - logits = logits_postprocess_fn( - logits_forward_fn( - model, - rearrange(next_token, "b -> b 1"), - position_ids, - inference_params, - cg=cg, - ) - ) - inference_params.seqlen_offset += 1 - scores.append(logits) - if i < num_tokens - 1: - next_token = sample_fn(logits) - sequences.append(next_token) - return torch.stack(sequences, dim=1), torch.stack(scores, dim=1) + scores.append(get_logits_fn(sequences[-1], inference_params)[:, -1]) + inference_params.seqlen_offset += sequences[-1].shape[1] + sequences.append(sample_fn(scores[-1]).unsqueeze(1)) + return torch.cat(sequences[1:], dim=1), torch.stack(scores, dim=1) sampling_kwargs = dict(top_k=top_k, top_p=top_p, temperature=temperature) sample_fn = partial(sample, **sampling_kwargs) + get_logits_main = partial(get_logits, model=model, cg=cg) + get_logits_draft = partial(get_logits, model=model_draft, cg=cg) sample_tokens_main = partial( - sample_tokens, model=model, sample_fn=sample_fn, inference_params=inference_params, cg=False - ) # main model doesn't use CUDA graph + sample_tokens, + get_logits_fn=get_logits_main, + sample_fn=sample_fn, + inference_params=inference_params, + ) sample_tokens_draft = partial( sample_tokens, - model=model_draft, + get_logits_fn=get_logits_draft, sample_fn=sample_fn, - last_token_logits=True, inference_params=inference_params_draft, - cg=cg, ) if debug: from transformers import AutoTokenizer tokenizer = AutoTokenizer.from_pretrained("gpt2") - sequences = [input_ids] - scores = [] - with torch.inference_mode(): - if enable_timing: - if tensor_parallel > 1: - torch.distributed.barrier() - torch.cuda.synchronize() - start = time.time() - - if seqlen_og >= max_length - 1: + if enable_timing: + if tensor_parallel > 1: + torch.distributed.barrier() + torch.cuda.synchronize() + start = 
time.time() + + sequences, scores = [input_ids], [] + num_main_model_calls = 0 + num_draft_tokens = 0 + num_accepted_tokens_history = [] + if seqlen_og >= max_length - 1: + # Don't do speculative sampling, just sample 1 token from the model + tokens, scores_new = sample_tokens_main(input_ids, num_tokens=1) + sequences.append(tokens) + scores.append(scores_new) + else: + # Sample from draft model, which produces @n_spec_tokens, and @model + # will then use to produce between 1 and 1 + @n_spec_tokens tokens. + # We want seqlen_og + 1 + @n_spec_tokens to be <= @max_length. + n_spec_tokens = min(speculative_lookahead, max_length - seqlen_og - 1) + tokens_draft, scores_draft = sample_tokens_draft(input_ids, num_tokens=n_spec_tokens) + num_draft_tokens += n_spec_tokens + if debug: + scores_draft_ref = model_draft( + torch.cat([input_ids, tokens_draft], dim=1), num_last_tokens=n_spec_tokens + 1 + ).logits + print((scores_draft - scores_draft_ref[:, :-1]).abs().max()) + + # Evaluate the draft tokens with the model + logits = get_logits_main( + torch.cat([input_ids, tokens_draft], dim=1), + inference_params, + num_last_tokens=n_spec_tokens + 1, + ) + num_main_model_calls += 1 + if debug: + logits_ref = model( + torch.cat([input_ids, tokens_draft], dim=1), num_last_tokens=n_spec_tokens + 1 + ).logits + print((logits - logits_ref).abs().max()) + # breakpoint() + tokens, num_generated_tokens = sample_speculative( + logits, scores_draft, tokens_draft, **sampling_kwargs + ) + num_accepted_tokens_history.append(num_generated_tokens - 1) + if debug: + print(tokens) + print(num_generated_tokens) + # breakpoint() + # TODO: we're using the fact that batch_size == 1 + # TODO: check eos_token_id + sequences.append(tokens[:1, : num_generated_tokens[0]]) + scores.append(logits[:1, : num_generated_tokens[0]]) + # Note that @model has not evaluated the last sampled token yet, so we'll need to pass + # that in the next time we call @model. + num_generated = num_generated_tokens[0].item() + inference_params.seqlen_offset = seqlen_og + num_generated - 1 + inference_params_draft.seqlen_offset = ( + inference_params.seqlen_offset - 1 + if num_generated > 1 + else inference_params.seqlen_offset + ) + if debug: + cur_ids = torch.cat([input_ids, sequences[-1]], dim=1) + scores_ref = model(cur_ids, num_last_tokens=num_generated_tokens[0].item() + 1).logits + print((scores[-1] - scores_ref[:, :-1]).abs().max()) + # breakpoint() + + while True: + # seqlen_offset is total length generated - 1 + if inference_params.seqlen_offset >= max_length - 1: + break + if inference_params.seqlen_offset >= max_length - 2: # Don't do speculative sampling, just sample 1 token from the model - tokens, scores_new = sample_tokens_main(input_ids, num_tokens=1, decoding=False) + tokens, scores_new = sample_tokens_main(sequences[-1][:, -1:], num_tokens=1) sequences.append(tokens) scores.append(scores_new) - else: - # Sample from draft model, which produces @n_spec_tokens, and @model - # will then use to produce between 1 and 1 + @n_spec_tokens tokens. - # We want seqlen_og + 1 + @n_spec_tokens to be <= @max_length. 
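The accept/reject step inside sample_speculative follows the standard speculative-sampling rule, which preserves the main model's output distribution. A simplified sketch of that rule, ignoring the top-k/top-p filtering and the extra bonus token that the real function also handles:

import torch

def accept_draft_tokens(probs_main, probs_draft, tokens_draft):
    # probs_main, probs_draft: (batch, n_spec, vocab) probabilities that the main and
    # draft models assign at the draft positions; tokens_draft: (batch, n_spec).
    p = torch.gather(probs_main, -1, tokens_draft.unsqueeze(-1)).squeeze(-1)
    q = torch.gather(probs_draft, -1, tokens_draft.unsqueeze(-1)).squeeze(-1)
    # Accept draft token i with probability min(1, p_i / q_i)
    accepted = torch.rand_like(p) * q < p                       # (batch, n_spec)
    # Keep only the leading run of accepted tokens
    num_accepted = accepted.long().cumprod(dim=-1).sum(dim=-1)  # (batch,)
    # At the first rejected position, resample from the residual max(p_main - p_draft, 0)
    residual = torch.clamp(probs_main - probs_draft, min=0.0)
    residual = residual / residual.sum(dim=-1, keepdim=True)
    return num_accepted, residual

This is why the acceptance rate printed at the end of decode_speculative is the quantity to watch: higher acceptance means fewer main-model calls per generated token.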
- n_spec_tokens = min(speculative_lookahead, max_length - seqlen_og - 1) - tokens_draft, scores_draft = sample_tokens_draft( - input_ids, - num_tokens=n_spec_tokens, - decoding=False, - ) - if debug: - scores_draft_ref = model_draft( - torch.cat([input_ids, tokens_draft], dim=1), num_last_tokens=n_spec_tokens + 1 - ).logits - print((scores_draft[:, :-1] - scores_draft_ref[:, :-1]).abs().max()) - - # Evaluate the draft tokens with the model - logits = model( - torch.cat([input_ids, tokens_draft], dim=1), - inference_params=inference_params, - num_last_tokens=n_spec_tokens + 1, + break + # Sample from draft model + n_spec_tokens = min( + speculative_lookahead, max_length - inference_params_draft.seqlen_offset - 2 + ) + # If the main model accepts all the draft tokens, plus it samples one new token, + # then at the next iteration the draft model need to evaluate the logits of the last draft + # token and the logits of the newly sampled token. So here we pass in the last 2 tokens + # of sequences[-1]. + # This exception is when the main model rejects all the draft tokens, in which case we + # will only have 1 token to pass in. + tokens_draft, scores_draft = sample_tokens_draft( + sequences[-1][:, -2:], num_tokens=n_spec_tokens + ) + num_draft_tokens += n_spec_tokens + if debug: + scores_draft_ref = model_draft( + torch.cat([cur_ids, tokens_draft], dim=1), num_last_tokens=n_spec_tokens + 1 ).logits - logits = logits_postprocess_fn(logits) - tokens, num_generated_tokens = sample_speculative( - logits, scores_draft[:, :-1], tokens_draft, **sampling_kwargs - ) - if debug: - print(tokens) - print(num_generated_tokens) - # breakpoint() - # TODO: we're using the fact that batch_size == 1 - # TODO: check eos_token_id - sequences.append(tokens[:1, : num_generated_tokens[0]]) - scores.append(logits[:1, : num_generated_tokens[0]]) - # Note that @model has not evaluated the last sampled token yet, so we'll need to pass - # that in the next time we call @model. 
- inference_params.seqlen_offset = seqlen_og + num_generated_tokens[0].item() - 1 - inference_params_draft.seqlen_offset = inference_params.seqlen_offset - if debug: - cur_ids = torch.cat([input_ids, sequences[-1]], dim=1) - scores_ref = model( - cur_ids, num_last_tokens=num_generated_tokens[0].item() + 1 - ).logits - print((scores[-1] - scores_ref[:, :-1]).abs().max()) - - while True: - # seqlen_offset is total length generated - 1 - if inference_params.seqlen_offset >= max_length - 1: - break - if inference_params.seqlen_offset >= max_length - 2: - # Don't do speculative sampling, just sample 1 token from the model - tokens, scores_new = sample_tokens_main(sequences[-1][:, -1:], num_tokens=1) - sequences.append(tokens) - scores.append(scores_new) - break - # Sample from draft model - n_spec_tokens = min( - speculative_lookahead, max_length - inference_params_draft.seqlen_offset - 2 - ) - tokens_draft, scores_draft = sample_tokens_draft( - sequences[-1][:, -1:], num_tokens=n_spec_tokens - ) - if debug: - scores_draft_ref = model_draft( - torch.cat([cur_ids, tokens_draft], dim=1), num_last_tokens=n_spec_tokens + 1 - ).logits - print((scores_draft[:, :-1] - scores_draft_ref[:, :-1]).abs().max()) - # Evaluate the draft tokens with the model - position_ids = repeat( - torch.arange( - inference_params.seqlen_offset, - # 1 extra token from last time that hasn't been passed through model - inference_params.seqlen_offset + n_spec_tokens + 1, - dtype=torch.long, - device=input_ids.device, - ), - "s -> b s", - b=batch_size, - ) - logits = model( - torch.cat([sequences[-1][:, -1:], tokens_draft], dim=1), - position_ids=position_ids, - inference_params=inference_params, - ).logits # (batch, n_spec_tokens, vocab_size) - logits = logits_postprocess_fn(logits) - inference_params.seqlen_offset += 1 - if debug: - logits_ref = model( - torch.cat([cur_ids, tokens_draft], dim=1), num_last_tokens=n_spec_tokens + 1 - ).logits - print((logits - logits_ref).abs().max()) - tokens, num_generated_tokens = sample_speculative( - logits, scores_draft[:, :-1], tokens_draft, **sampling_kwargs - ) - if debug: - print(tokens) - print(num_generated_tokens) - sequences.append(tokens[:1, : num_generated_tokens[0]]) - scores.append(logits[:1, : num_generated_tokens[0]]) - inference_params.seqlen_offset += num_generated_tokens[0].item() - 1 - inference_params_draft.seqlen_offset = inference_params.seqlen_offset + print((scores_draft - scores_draft_ref[:, :-1]).abs().max()) + # breakpoint() + # Evaluate the draft tokens with the model + logits = get_logits_main( + torch.cat([sequences[-1][:, -1:], tokens_draft], dim=1), + inference_params, + num_last_tokens=n_spec_tokens + 1, + ) # (batch, n_spec_tokens + 1, vocab_size) + num_main_model_calls += 1 + if debug: + logits_ref = model( + torch.cat([cur_ids, tokens_draft], dim=1), num_last_tokens=n_spec_tokens + 1 + ).logits + print((logits - logits_ref).abs().max()) + # breakpoint() + tokens, num_generated_tokens = sample_speculative( + logits, scores_draft, tokens_draft, **sampling_kwargs + ) + num_accepted_tokens_history.append(num_generated_tokens - 1) + if debug: + print(tokens) + print(num_generated_tokens) # breakpoint() - if debug: - cur_ids = torch.cat([cur_ids, sequences[-1]], dim=1) - scores_ref = model( - cur_ids, num_last_tokens=num_generated_tokens[0].item() + 1 - ).logits - print((scores[-1] - scores_ref[:, :-1]).abs().max()) - - if enable_timing: - if tensor_parallel > 1: - torch.distributed.barrier() - torch.cuda.synchronize() - print(f"Prompt processing + decoding 
time: {(time.time() - start) * 1000:.0f}ms") + sequences.append(tokens[:1, : num_generated_tokens[0]]) + scores.append(logits[:1, : num_generated_tokens[0]]) + # We've evaluated 1 token from sequences[-1][:, -1:] above, plus + # num_generated_tokens[0].item() - 1 tokens from the draft model. + num_generated = num_generated_tokens[0].item() + inference_params.seqlen_offset += num_generated + inference_params_draft.seqlen_offset = ( + inference_params.seqlen_offset - 1 + if num_generated > 1 + else inference_params.seqlen_offset + ) + if debug: + cur_ids = torch.cat([cur_ids, sequences[-1]], dim=1) + scores_ref = model(cur_ids, num_last_tokens=num_generated_tokens[0].item() + 1).logits + print((scores[-1] - scores_ref[:, :-1]).abs().max()) + # breakpoint() + + if enable_timing: + if tensor_parallel > 1: + torch.distributed.barrier() + torch.cuda.synchronize() + print(f"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms") + print(f"Number of calls to main model: {num_main_model_calls}") + print( + f"Acceptance rate: {torch.cat(num_accepted_tokens_history).sum().item() / num_draft_tokens * 100:.2f}%" + ) sequences = torch.cat(sequences, dim=1) scores = torch.cat(scores, dim=1) if debug: @@ -607,20 +601,6 @@ def allocate_inference_cache( return {i: torch.empty(kv_cache_shape, device=device, dtype=dtype) for i in layers} -def seqlen_to_seqlen_type(seqlen: int) -> int: - """Convert sequence length to a seqlen_type. - This is used to determine which cuda graph to use. - Arguments: - seqlen: int - """ - return 0 - - -def seqlen_type_to_max_seqlen(seqlen_type: int) -> int: - assert seqlen_type in [0] - return 2**32 - - @dataclass class DecodingCGCache: max_batch_size: int = 0 @@ -640,6 +620,7 @@ def update_graph_cache( batch_size, seqlen_og, max_seqlen, + decoding_seqlens=(1,), tensor_parallel=1, dtype=None, n_warmups=2, @@ -687,38 +668,36 @@ def update_graph_cache( lengths_per_sample=lengths_per_sample, ) cache.mempool = torch.cuda.graphs.graph_pool_handle() - for s_type in range(seqlen_to_seqlen_type(seqlen_og), seqlen_to_seqlen_type(max_seqlen) + 1): - if (batch_size, s_type) not in cache.callables: - max_seqlen_ = min(max(seqlen_og, seqlen_type_to_max_seqlen(s_type)), max_seqlen) - cache.callables[batch_size, s_type] = capture_graph( + for decoding_seqlen in decoding_seqlens: + if (batch_size, decoding_seqlen) not in cache.callables: + cache.callables[batch_size, decoding_seqlen] = capture_graph( model, cache.inference_params, batch_size, - max_seqlen_, + max_seqlen, + decoding_seqlen=decoding_seqlen, mempool=cache.mempool, n_warmups=n_warmups, ) def dispatch(input_ids, position_ids, seqlen): - batch_size = input_ids.shape[0] - return cache.callables[batch_size, seqlen_to_seqlen_type(seqlen)]( - input_ids, position_ids, seqlen - ) + batch_size, decoding_seqlen = input_ids.shape[:2] + return cache.callables[batch_size, decoding_seqlen](input_ids, position_ids, seqlen) cache.run = dispatch cache.inference_params.seqlen_offset = 0 # Reset so it's not confusing return cache -def capture_graph(model, inference_params, batch_size, max_seqlen, mempool=None, n_warmups=2): +def capture_graph( + model, inference_params, batch_size, max_seqlen, decoding_seqlen=1, mempool=None, n_warmups=2 +): device = next(iter(model.parameters())).device - input_ids = torch.full((batch_size, 1), 0, dtype=torch.long, device=device) - position_ids = torch.full((batch_size, 1), 0, dtype=torch.long, device=device) + input_ids = torch.full((batch_size, decoding_seqlen), 0, dtype=torch.long, device=device) 
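capture_graph follows the usual CUDA-graph recipe: warm up on a side stream, capture one forward pass that uses static input buffers, then replay the graph after copying fresh values into those buffers. A generic sketch of that pattern (not the exact function here, which additionally manages inference_params and the shared memory pool):

import torch

def make_graphed(fn, static_inputs, n_warmups=2, mempool=None):
    # Warm up on a side stream so the capture does not record one-time setup work
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        for _ in range(n_warmups):
            static_out = fn(*static_inputs)
    torch.cuda.current_stream().wait_stream(s)
    # Capture a single call; all tensors involved must keep the same addresses
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph, pool=mempool):
        static_out = fn(*static_inputs)

    def run(*new_inputs):
        for dst, src in zip(static_inputs, new_inputs):
            dst.copy_(src)  # refresh the static buffers in place
        graph.replay()
        return static_out.clone()

    return run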
+ position_ids = torch.full((batch_size, decoding_seqlen), 0, dtype=torch.long, device=device) seqlen_offset_og = inference_params.seqlen_offset - # TD [2023-04-14]: important for correctness of the FT's attention kernel, as seqlen_cpu is - # used to determine the size of smem. Hence seqlen_cpu must be >= lengths_per_sample. - inference_params.seqlen_offset = max_seqlen - 1 - inference_params.lengths_per_sample[:] = max_seqlen - 1 + inference_params.seqlen_offset = max_seqlen - decoding_seqlen + inference_params.lengths_per_sample[:] = inference_params.seqlen_offset # Warmup before capture s = torch.cuda.Stream() @@ -729,7 +708,7 @@ def capture_graph(model, inference_params, batch_size, max_seqlen, mempool=None, input_ids, position_ids=position_ids, inference_params=inference_params, - num_last_tokens=1, + num_last_tokens=decoding_seqlen, ).logits s.synchronize() # This might be needed for correctness if we run with NCCL_GRAPH_MIXING_SUPPORT=0, @@ -746,8 +725,8 @@ def capture_graph(model, inference_params, batch_size, max_seqlen, mempool=None, input_ids, position_ids=position_ids, inference_params=inference_params, - num_last_tokens=1, - ).logits.squeeze(dim=1) + num_last_tokens=decoding_seqlen, + ).logits def run(new_input_ids, new_position_ids, seqlen): inference_params.lengths_per_sample[:] = seqlen diff --git a/tests/models/test_gpt.py b/tests/models/test_gpt.py index 09c6556bf..982203005 100644 --- a/tests/models/test_gpt.py +++ b/tests/models/test_gpt.py @@ -383,11 +383,14 @@ def test_gpt2_multiple_token_generation(model_name, optimized): @pytest.mark.parametrize("cg", [False, True]) -# @pytest.mark.parametrize("optimized", [False, True]) -@pytest.mark.parametrize("optimized", [True]) +# @pytest.mark.parametrize("cg", [True]) +@pytest.mark.parametrize("optimized", [False, True]) +# @pytest.mark.parametrize("optimized", [True]) # @pytest.mark.parametrize("model_name", ["gpt2-medium"]) @pytest.mark.parametrize("model_name", ["gpt2-xl"]) def test_gpt2_speculative_decoding(model_name, optimized, cg): + if cg and not optimized: + pytest.skip() # CG requires use_flash_attn dtype = torch.float16 device = "cuda" rtol, atol = 3e-3, 3e-1 @@ -421,6 +424,7 @@ def test_gpt2_speculative_decoding(model_name, optimized, cg): from flash_attn.utils.generation import decode_speculative torch.manual_seed(42) + print(f"Speculative decoding, {optimized = }") out = decode_speculative( input_ids, model, @@ -430,13 +434,15 @@ def test_gpt2_speculative_decoding(model_name, optimized, cg): cg=cg, speculative_lookahead=4, enable_timing=True, + # debug=True, ) print(tokenizer.batch_decode(out.sequences)) + print(f"Without speculative decoding, {cg = }") out_og = model.generate( input_ids, max_length=max_length, top_k=5, - cg=False, + cg=cg, enable_timing=True, return_dict_in_generate=True, ) From 0705d2718dd39a39507dbdac85c538189a8436a1 Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Wed, 20 Sep 2023 23:36:46 -0700 Subject: [PATCH 03/23] [Llama] Fix some tests, add tests for Llama 2 and CodeLlama --- flash_attn/models/llama.py | 31 +-- flash_attn/modules/mha.py | 2 +- tests/models/test_gpt_generation_parallel.py | 1 + tests/models/test_llama.py | 251 ++++++++----------- 4 files changed, 124 insertions(+), 161 deletions(-) diff --git a/flash_attn/models/llama.py b/flash_attn/models/llama.py index 7bea141b1..2841efd76 100644 --- a/flash_attn/models/llama.py +++ b/flash_attn/models/llama.py @@ -13,6 +13,8 @@ from sentencepiece import SentencePieceProcessor from transformers import GPT2Config, LlamaConfig +from einops 
import rearrange + def remap_state_dict_meta_llama( state_dict: dict[str, torch.Tensor], config: GPT2Config @@ -30,9 +32,7 @@ def key_mapping_layers(key): # Word embedding def key_mapping_emb(key): return re.sub( - r"^transformer.tok_embeddings.", - "transformer.embeddings.word_embeddings.", - key, + r"^transformer.tok_embeddings.", "transformer.embeddings.word_embeddings.", key ) state_dict = OrderedDict((key_mapping_emb(k), v) for k, v in state_dict.items()) @@ -113,7 +113,7 @@ def key_mapping_attn(key): def remap_state_dict_hf_llama( - state_dict: dict[str, torch.Tensor], config: GPT2Config, multi_query: bool = False + state_dict: dict[str, torch.Tensor], config: GPT2Config ) -> dict[str, torch.Tensor]: """Convert the state_dict in Hugging Face format to standard GPT format. @@ -186,13 +186,11 @@ def key_mapping_ln(key): state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items()) - def inv_permute(w, first_dim=None): + def inv_permute(w): # Inverse of permute implemented in: # https://github.com/huggingface/transformers/blob/b42010bb1d3cbf262d27e0a328661885be46dfdb/src/transformers/models/llama/convert_llama_weights_to_hf.py#L114 - return ( - w.reshape(first_dim or config.n_head, 2, -1, config.n_embd) - .transpose(1, 2) - .reshape(-1, config.n_embd) + return rearrange( + w, "(h two d) n -> (h d two) n", d=config.n_embd // config.n_head // 2, two=2 ) # Attention @@ -202,8 +200,7 @@ def inv_permute(w, first_dim=None): Wv = state_dict.pop(f"model.layers.{l}.self_attn.v_proj.weight") state_dict[f"transformer.layers.{l}.mixer.Wqkv.weight"] = torch.cat( - (inv_permute(Wq), inv_permute(Wk, getattr(config, "n_head_kv")), Wv), - dim=0, + [inv_permute(Wq), inv_permute(Wk), Wv], dim=0 ) # We don't store these state_dict.pop(f"model.layers.{l}.self_attn.rotary_emb.inv_freq", None) @@ -220,7 +217,7 @@ def key_mapping_attn(key): def inv_remap_state_dict_hf_llama( - state_dict: dict[str, torch.Tensor], config: GPT2Config, multi_query: bool = False + state_dict: dict[str, torch.Tensor], config: GPT2Config ) -> dict[str, torch.Tensor]: """Convert the state_dict in standard GPT format to Hugging Face format. 
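inv_permute above undoes the q/k weight permutation that HF's LLaMA conversion script applies for its rotary-embedding convention, and the permute defined further down in this patch is its exact inverse, so remapping to GPT format and back is lossless. A quick self-contained round-trip check with toy sizes (hypothetical variable names):

import torch
from einops import rearrange

n_embd, n_head = 16, 2              # toy sizes, just for the check
head_dim = n_embd // n_head
w = torch.randn(n_embd, n_embd)     # a q/k projection weight, (out_features, in_features)

to_hf = rearrange(w, "(h d two) n -> (h two d) n", d=head_dim // 2, two=2)     # permute
back = rearrange(to_hf, "(h two d) n -> (h d two) n", d=head_dim // 2, two=2)  # inv_permute
assert torch.equal(back, w)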
@@ -293,11 +290,9 @@ def key_mapping_ln(key): state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items()) - def permute(w, first_dim=None): - return ( - w.view(first_dim or config.n_head, -1, 2, config.n_embd) - .transpose(1, 2) - .reshape(-1, config.n_embd) + def permute(w): + return rearrange( + w, "(h d two) n -> (h two d) n", d=config.n_embd // config.n_head // 2, two=2 ) n_head = config.n_head @@ -316,7 +311,7 @@ def permute(w, first_dim=None): Wk = Wqkv[q_dim : q_dim + k_dim] Wv = Wqkv[q_dim + k_dim : q_dim + k_dim + v_dim] state_dict[f"model.layers.{l}.self_attn.q_proj.weight"] = permute(Wq) - state_dict[f"model.layers.{l}.self_attn.k_proj.weight"] = permute(Wk, n_head_kv) + state_dict[f"model.layers.{l}.self_attn.k_proj.weight"] = permute(Wk) state_dict[f"model.layers.{l}.self_attn.v_proj.weight"] = Wv state_dict.pop(f"transformer.layers.{l}.attention.inner_attention.rope.freqs", None) diff --git a/flash_attn/modules/mha.py b/flash_attn/modules/mha.py index 4894daca5..976bd3d2c 100644 --- a/flash_attn/modules/mha.py +++ b/flash_attn/modules/mha.py @@ -725,7 +725,7 @@ def __init__( process_group, bias=qkv_proj_bias, sequence_parallel=sequence_parallel, - multiple_of=self.head_dim * (self.num_heads_per_rank + 2 * self.num_heads_kv_per_rank), + multiple_of=self.head_dim * (self.num_heads // self.num_heads_kv + 2), **factory_kwargs, ) inner_attn_cls = FlashSelfAttention if use_flash_attn else SelfAttention diff --git a/tests/models/test_gpt_generation_parallel.py b/tests/models/test_gpt_generation_parallel.py index b398bf968..bcf2bf513 100644 --- a/tests/models/test_gpt_generation_parallel.py +++ b/tests/models/test_gpt_generation_parallel.py @@ -160,6 +160,7 @@ def test_tensor_parallel(model_name, rotary, world_size): assert torch.allclose( torch.stack(out.scores, dim=1), torch.stack(scores, dim=1), rtol=rtol, atol=atol ) + assert torch.equal(torch.stack(out.scores, dim=1), torch.stack(out_cg.scores, dim=1)) if not rotary: assert torch.all(out.sequences == out_ref.sequences) assert torch.all(out.sequences == out_hf.sequences) diff --git a/tests/models/test_llama.py b/tests/models/test_llama.py index e0b8c305e..32e9cd211 100644 --- a/tests/models/test_llama.py +++ b/tests/models/test_llama.py @@ -1,6 +1,6 @@ # Copyright (c) 2023, Tri Dao. 
-# To run the huggingface implementation, we first need to convert the weights: +# To run the huggingface implementation of LLaMa (1), we first need to convert the weights: # https://github.com/huggingface/transformers/pull/21955 # python -m transformers.models.llama.convert_llama_weights_to_hf --input_dir $CHECKPOINT_DIR/llama --model_size 7B --output_dir $CHECKPOINT_DIR/llama/7B-hf # and repeat for 13B, 30B, 65B @@ -30,6 +30,7 @@ from flash_attn.utils.pretrained import state_dict_from_pretrained from transformers import LlamaConfig, LlamaTokenizer from transformers.models.llama.modeling_llama import LlamaForCausalLM +from transformers import AutoConfig def _pretrained_state_dict_from_checkpoint(checkpoint_path, model_name, config, checkpoint_format): @@ -60,9 +61,38 @@ def test_llama_state_dict(model_name): assert state_dict[k].shape == pretrained_state_dict[k].shape -@pytest.mark.parametrize("model_name", ["7B", "13B"]) -@pytest.mark.parametrize("checkpoint_format", ["meta", "hf"]) -def test_llama_optimized(model_name, checkpoint_format): +# TinyLlama-1.1B is to test MQA +@pytest.mark.parametrize( + "model_name", ["meta-llama/Llama-2-7b-hf", "PY007/TinyLlama-1.1B-step-50K-105b"] +) +def test_inv_remap_state_dict_hf_llama(model_name): + config = llama_config_to_gpt2_config( + AutoConfig.from_pretrained(model_name, trust_remote_code=True) + ) + state_dict = state_dict_from_pretrained(model_name) + # inv_remap_state_dict_hf_llama should be the inverse of remap_state_dict_hf_llama + state_dict = {key: val for key, val in state_dict.items() if "rotary_emb.inv_freq" not in key} + pretrained_state_dict = remap_state_dict_hf_llama(state_dict, config) + state_dict_recover = inv_remap_state_dict_hf_llama(pretrained_state_dict, config) + assert set(state_dict_recover.keys()) == set(state_dict.keys()) + for key in state_dict_recover.keys(): + torch.testing.assert_close(state_dict_recover[key], state_dict[key]) + + +# TinyLlama-1.1B is to test MQA +@pytest.mark.parametrize( + "model_name", + [ + "7B", # Llama 1 + "13B", # Llama 1 + "meta-llama/Llama-2-13b-hf", + "codellama/CodeLlama-7b-hf", + "codellama/CodeLlama-13b-hf", + "codellama/CodeLlama-34b-hf", + "PY007/TinyLlama-1.1B-step-50K-105b", + ], +) +def test_llama_optimized(model_name): """Check that our implementation of LLaMa (with all optimizations enabled) matches the HF implementation: the output of our forward pass in fp16 should be around the same as the HF forward pass in fp16, when compared to the HF forward pass in fp32. 
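TinyLlama-1.1B exercises the multi-query / grouped-query path, where the checkpoint has fewer key/value heads than query heads and each kv head is shared by a group of query heads. Conceptually the sharing amounts to repeating kv heads, as in the sketch below; the repo handles this either natively inside the FlashAttention kernel or by repeating kv heads on the non-flash path, so this is only an illustration:

import torch

def expand_kv_heads(kv, nheads_q):
    # kv: (batch, seqlen, 2, nheads_kv, head_dim); each kv head serves
    # nheads_q // nheads_kv query heads.
    nheads_kv = kv.shape[3]
    return torch.repeat_interleave(kv, nheads_q // nheads_kv, dim=3)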
@@ -73,17 +103,27 @@ def test_llama_optimized(model_name, checkpoint_format): dtype = torch.float16 device = "cuda" - config = config_from_checkpoint(checkpoint_path, model_name, checkpoint_format) - config = llama_config_to_gpt2_config(config) + if "/" in model_name: # Download from HF + config = llama_config_to_gpt2_config( + AutoConfig.from_pretrained(model_name, trust_remote_code=True) + ) + else: + config = config_from_checkpoint(checkpoint_path, model_name, checkpoint_format="meta") + config = llama_config_to_gpt2_config(config) config.use_flash_attn = True config.fused_bias_fc = True config.fused_mlp = False # We don't have fused GatedMLP yet config.fused_dropout_add_ln = True config.residual_in_fp32 = True - pretrained_state_dict = _pretrained_state_dict_from_checkpoint( - checkpoint_path, model_name, config, checkpoint_format - ) + if "/" in model_name: # Download from HF + pretrained_state_dict = remap_state_dict_hf_llama( + state_dict_from_pretrained(model_name), config + ) + else: + pretrained_state_dict = _pretrained_state_dict_from_checkpoint( + checkpoint_path, model_name, config, checkpoint_format="meta" + ) model = GPTLMHeadModel(config, device=device, dtype=dtype) model.load_state_dict(pretrained_state_dict) model.eval() @@ -103,7 +143,8 @@ def test_llama_optimized(model_name, checkpoint_format): # Without device_map, the model is loaded on the CPU, which is very slow # Need auto here since the 13B fp32 model doesn't fit in memory on a A100 40GB model_ref = LlamaForCausalLM.from_pretrained( - Path(checkpoint_path) / f"{model_name}-hf", device_map="auto" + model_name if "/" in model_name else Path(checkpoint_path) / f"{model_name}-hf", + device_map="auto", ) model_ref.eval() with torch.no_grad(): @@ -112,7 +153,9 @@ def test_llama_optimized(model_name, checkpoint_format): del model_ref model_hf = LlamaForCausalLM.from_pretrained( - Path(checkpoint_path) / f"{model_name}-hf", torch_dtype=dtype, device_map={"": device} + model_name if "/" in model_name else Path(checkpoint_path) / f"{model_name}-hf", + torch_dtype=dtype, + device_map={"": device}, ) model_hf.eval() with torch.no_grad(): @@ -135,77 +178,12 @@ def test_llama_optimized(model_name, checkpoint_format): ).abs().max().item() - - -@pytest.mark.parametrize("model_name", ["PY007/TinyLlama-1.1B-step-50K-105b"]) -def test_mqa_optimized(model_name): - """Check that our implementation of Llama with MQA/GQA (with all optimizations enabled) matches the - HF implementation: the output of our forward pass in fp16 should be around the same as the HF - forward pass in fp16, when compared to the HF forward pass in fp32. 
- """ - dtype = torch.float16 - device = "cuda" - config = llama_config_to_gpt2_config(LlamaConfig.from_pretrained(model_name)) - config.use_flash_attn = True # FlashAttention-2 supports headdim 256 - config.fused_bias_fc = True - config.fused_mlp = False - config.fused_dropout_add_ln = True - config.residual_in_fp32 = True - - # Without device_map, the model is loaded on the CPU, which is very slow - model_ref = LlamaForCausalLM.from_pretrained(model_name, device_map={"": device}) - model_ref.eval() - - model = GPTLMHeadModel(config, device=device, dtype=dtype) - model.load_state_dict(remap_state_dict_hf_llama(model_ref.state_dict(), config)) - model.eval() - - torch.manual_seed(0) - batch_size = 2 - max_seqlen = 256 - input_ids = torch.randint( - 0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long, device=device - ) - with torch.no_grad(): - out = model.transformer(input_ids) - logits = model(input_ids).logits - del model - - with torch.no_grad(): - out_ref = model_ref.model(input_ids).last_hidden_state - logits_ref = model_ref(input_ids).logits - del model_ref - - model_hf = LlamaForCausalLM.from_pretrained( - model_name, torch_dtype=dtype, device_map={"": device} - ) - model_hf.eval() - out_hf = model_hf.model(input_ids).last_hidden_state - logits_hf = model_hf(input_ids).logits - del model_hf - - print(f"Output max diff: {(out - out_ref).abs().max().item()}") - print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") - print(f"HF fp16 max diff: {(out_hf - out_ref).abs().max().item()}") - print(f"HF fp16 mean diff: {(out_hf - out_ref).abs().mean().item()}") - assert (out - out_ref).abs().max().item() < 3 * ( - out_hf - out_ref - ).abs().max().item() - - print(f"Logits max diff: {(logits - logits_ref).abs().max().item()}") - print(f"Logits mean diff: {(logits - logits_ref).abs().mean().item()}") - print(f"HF fp16 max diff: {(logits_hf - logits_ref).abs().max().item()}") - print(f"HF fp16 mean diff: {(logits_hf - logits_ref).abs().mean().item()}") - assert (logits - logits_ref).abs().max().item() < 3 * ( - logits_hf - logits_ref - ).abs().max().item() - - # torchrun --no_python --nproc_per_node=2 pytest -q -s tests/models/test_llama.py -k "parallel" @pytest.mark.parametrize("world_size", [2]) -@pytest.mark.parametrize("model_name", ["13B"]) -@pytest.mark.parametrize("checkpoint_format", ["meta", "hf"]) -def test_llama_parallel(model_name, world_size, checkpoint_format): +@pytest.mark.parametrize( + "model_name", ["13B", "meta-llama/Llama-2-13b-hf", "codellama/CodeLlama-34b-hf"] +) +def test_llama_parallel(model_name, world_size): """Check that our implementation of LLaMa (with all optimizations enabled) matches the HF implementation: the output of our forward pass in fp16 should be around the same as the HF forward pass in fp16, when compared to the HF forward pass in fp32. 
@@ -217,8 +195,13 @@ def test_llama_parallel(model_name, world_size, checkpoint_format): ) dtype = torch.float16 - config = config_from_checkpoint(checkpoint_path, model_name, checkpoint_format) - config = llama_config_to_gpt2_config(config) + if "/" in model_name: # Download from HF + config = llama_config_to_gpt2_config( + AutoConfig.from_pretrained(model_name, trust_remote_code=True) + ) + else: + config = config_from_checkpoint(checkpoint_path, model_name, checkpoint_format="meta") + config = llama_config_to_gpt2_config(config) config.use_flash_attn = True config.fused_bias_fc = True config.fused_mlp = False # We don't have fused GatedMLP yet @@ -233,9 +216,14 @@ def test_llama_parallel(model_name, world_size, checkpoint_format): rank = parallel_state.get_tensor_model_parallel_rank() process_group = parallel_state.get_tensor_model_parallel_group() - pretrained_state_dict = _pretrained_state_dict_from_checkpoint( - checkpoint_path, model_name, config, checkpoint_format - ) + if "/" in model_name: # Download from HF + pretrained_state_dict = remap_state_dict_hf_llama( + state_dict_from_pretrained(model_name), config + ) + else: + pretrained_state_dict = _pretrained_state_dict_from_checkpoint( + checkpoint_path, model_name, config, checkpoint_format="meta" + ) model = GPTLMHeadModel(config, process_group=process_group, device=device, dtype=dtype) model.load_state_dict(shard_state_dict_tp(pretrained_state_dict, config, world_size, rank)) model.eval() @@ -260,7 +248,8 @@ def test_llama_parallel(model_name, world_size, checkpoint_format): if rank == 0: # Without device_map, the model is loaded on the CPU, which is very slow model_ref = LlamaForCausalLM.from_pretrained( - Path(checkpoint_path) / f"{model_name}-hf", device_map="auto" + model_name if "/" in model_name else Path(checkpoint_path) / f"{model_name}-hf", + device_map="auto", ) model_ref.eval() with torch.no_grad(): @@ -269,7 +258,9 @@ def test_llama_parallel(model_name, world_size, checkpoint_format): del model_ref model_hf = LlamaForCausalLM.from_pretrained( - Path(checkpoint_path) / f"{model_name}-hf", torch_dtype=dtype, device_map="auto" + model_name if "/" in model_name else Path(checkpoint_path) / f"{model_name}-hf", + torch_dtype=dtype, + device_map="auto", ) model_hf.eval() with torch.no_grad(): @@ -405,9 +396,10 @@ def test_llama_generation(model_name, checkpoint_format): # torchrun --no_python --nproc_per_node=2 pytest -q -s tests/models/test_llama.py -k "llama_parallel_generation" @pytest.mark.parametrize("world_size", [2]) -@pytest.mark.parametrize("model_name", ["13B"]) -@pytest.mark.parametrize("checkpoint_format", ["meta", "hf"]) -def test_llama_parallel_generation(model_name, world_size, checkpoint_format): +@pytest.mark.parametrize( + "model_name", ["13B", "meta-llama/Llama-2-13b-hf", "codellama/CodeLlama-34b-hf"] +) +def test_llama_parallel_generation(model_name, world_size): """Check that our implementation matches the HF implementation: the scores in fp16 should be around the same as the HF scores in fp16, when compared to the HF scores in fp32. 
@@ -419,12 +411,17 @@ def test_llama_parallel_generation(model_name, world_size, checkpoint_format): ) dtype = torch.float16 - config = config_from_checkpoint(checkpoint_path, model_name, checkpoint_format) - config = llama_config_to_gpt2_config(config) - config.use_flash_attn = False + if "/" in model_name: # Download from HF + config = llama_config_to_gpt2_config( + AutoConfig.from_pretrained(model_name, trust_remote_code=True) + ) + else: + config = config_from_checkpoint(checkpoint_path, model_name, checkpoint_format="meta") + config = llama_config_to_gpt2_config(config) + config.use_flash_attn = True config.fused_bias_fc = True config.fused_mlp = False # We don't have fused GatedMLP yet - config.fused_dropout_add_ln = False + config.fused_dropout_add_ln = True config.residual_in_fp32 = True config.pad_vocab_size_multiple = 8 * world_size config.sequence_parallel = False # Need to set this to False for generation @@ -450,9 +447,14 @@ def test_llama_parallel_generation(model_name, world_size, checkpoint_format): # GPU0 and GPU1 and things would hang torch.cuda.set_device(device) - pretrained_state_dict = _pretrained_state_dict_from_checkpoint( - checkpoint_path, model_name, config, checkpoint_format - ) + if "/" in model_name: # Download from HF + pretrained_state_dict = remap_state_dict_hf_llama( + state_dict_from_pretrained(model_name), config + ) + else: + pretrained_state_dict = _pretrained_state_dict_from_checkpoint( + checkpoint_path, model_name, config, checkpoint_format="meta" + ) model = GPTLMHeadModel(config, process_group=process_group, device=device, dtype=dtype) model.load_state_dict(shard_state_dict_tp(pretrained_state_dict, config, world_size, rank)) model.eval() @@ -490,7 +492,9 @@ def test_llama_parallel_generation(model_name, world_size, checkpoint_format): if rank == 0: # Without device_map, the model is loaded on the CPU, which is very slow model_hf = LlamaForCausalLM.from_pretrained( - Path(checkpoint_path) / f"{model_name}-hf", torch_dtype=dtype, device_map="auto" + model_name if "/" in model_name else Path(checkpoint_path) / f"{model_name}-hf", + torch_dtype=dtype, + device_map="auto", ) model_hf.eval() print("HF fp16") @@ -508,7 +512,8 @@ def test_llama_parallel_generation(model_name, world_size, checkpoint_format): del model_hf model_ref = LlamaForCausalLM.from_pretrained( - Path(checkpoint_path) / f"{model_name}-hf", device_map="auto" + model_name if "/" in model_name else Path(checkpoint_path) / f"{model_name}-hf", + device_map="auto", ) model_ref.eval() with torch.inference_mode(): @@ -594,15 +599,16 @@ def test_llama_parallel_uneven_num_heads(world_size): if rank == 0: model_ref = LlamaForCausalLM.from_pretrained( - Path(checkpoint_path) / f"{model_name}-hf", device_map="auto" + Path(checkpoint_path) / f"{model_name}-hf", device_map={"": device} ) + model_ref = model_ref.to(device=device) model_ref.eval() - out_ref = model_ref.model(input_ids).last_hidden_state.to(device=device) - logits_ref = model_ref(input_ids).logits.to(device=device) + out_ref = model_ref.model(input_ids).last_hidden_state + logits_ref = model_ref(input_ids).logits del model_ref model_hf = LlamaForCausalLM.from_pretrained( - Path(checkpoint_path) / f"{model_name}-hf", torch_dtype=dtype, device_map="auto" + Path(checkpoint_path) / f"{model_name}-hf", torch_dtype=dtype, device_map={"": device} ) model_hf.eval() out_hf = model_hf.model(input_ids).last_hidden_state.to(device=device) @@ -625,42 +631,3 @@ def test_llama_parallel_uneven_num_heads(world_size): if os.path.exists(checkpoint_path 
/ f"{model_name}-hf"): shutil.rmtree(checkpoint_path / f"{model_name}-hf") - - -@torch.no_grad() -def test_inv_remap_state_dict_hf_llama(): - checkpoint_path = ( - Path(os.environ.get("CHECKPOINT_DIR", current_dir.parent.parent / "checkpoints")) / "llama" - ) - model_name = f"teeny" - - llama_config = LlamaConfig( - num_attention_heads=2, - hidden_size=256 * 2, - intermediate_size=256 * 2 * 4, - num_hidden_layers=4, - ) - config = llama_config_to_gpt2_config(llama_config) - config.use_flash_attn = True - config.fused_bias_fc = True - config.fused_mlp = False # We don't have fused GatedMLP yet - config.fused_dropout_add_ln = True - config.residual_in_fp32 = True - - # Set up. - LlamaForCausalLM(config=llama_config).save_pretrained(checkpoint_path / f"{model_name}-hf") - - # inv_remap_state_dict_hf_llama should be the inverse of remap_state_dict_hf_llama - state_dict = state_dict_from_pretrained(checkpoint_path / f"{model_name}-hf") - state_dict = {key: val for key, val in state_dict.items() if "rotary_emb.inv_freq" not in key} - pretrained_state_dict = remap_state_dict_hf_llama(state_dict, config) - state_dict_recover = inv_remap_state_dict_hf_llama(pretrained_state_dict, config) - - assert set(state_dict_recover.keys()) == set(state_dict.keys()) - - for key in state_dict_recover.keys(): - torch.testing.assert_close(state_dict_recover[key], state_dict[key]) - - # Tear down. - if os.path.exists(checkpoint_path / f"{model_name}-hf"): - shutil.rmtree(checkpoint_path / f"{model_name}-hf") From 2d8ea9a5303b7de8865279b66c8f7e8ed2a59aee Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Wed, 20 Sep 2023 23:38:22 -0700 Subject: [PATCH 04/23] Swap seqlen_q and ngroups when seqlen_q=1 (h/t Daniel Haziza) --- csrc/flash_attn/flash_api.cpp | 55 ++++++++++--------- .../src/flash_bwd_launch_template.h | 2 - csrc/flash_attn/src/flash_fwd_kernel.h | 30 +++++----- .../src/flash_fwd_launch_template.h | 24 ++++---- tests/test_flash_attn.py | 4 +- 5 files changed, 62 insertions(+), 53 deletions(-) diff --git a/csrc/flash_attn/flash_api.cpp b/csrc/flash_attn/flash_api.cpp index 8b4df5bfe..d62bd0dba 100644 --- a/csrc/flash_attn/flash_api.cpp +++ b/csrc/flash_attn/flash_api.cpp @@ -282,11 +282,14 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size if (seqlen_q == 1) { is_causal = false; } // causal=true is the same as causal=false in this case - // Faster to transpose q from (b, 1, h, d) to (b, h, 1, d) in this case - const int seqlenq_nheads_swapped = seqlen_q == 1 && num_heads_k == 1 && num_heads > 1 and p_dropout == 0.f and head_size_og % 8 == 0; - if (seqlenq_nheads_swapped) { - q = q.transpose(1, 2); - std::swap(seqlen_q, num_heads); + // Faster to transpose q from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d) in this case + // H/t Daniel Haziza + const int seqlenq_ngroups_swapped = seqlen_q == 1 && num_heads > num_heads_k && p_dropout == 0.f && head_size_og % 8 == 0; + if (seqlenq_ngroups_swapped) { + const int ngroups = num_heads / num_heads_k; + q = q.reshape({batch_size, num_heads_k, ngroups, head_size_og}).transpose(1, 2); + seqlen_q = ngroups; + num_heads = num_heads_k; } CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size_og); @@ -353,9 +356,7 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size is_causal); // This needs to match with run_mha_fwd_splitkv_dispatch - const int block_n = is_sm90 || is_sm8x - ? (head_size <= 64 ? 256 : (head_size <= 160 ? 128 : 64)) - : (head_size <= 64 ? 256 : (head_size <= 128 ? 
128 : 64)); + const int block_n = head_size <= 64 ? 256 : (head_size <= 128 ? 128 : 64); const int num_n_blocks = (seqlen_k + block_n - 1) / block_n; // Technically kBlockM = 64 only for the splitKV kernels, not the standard kernel. // In any case we don't expect seqlen_q to be larger than 64 for inference. @@ -369,6 +370,7 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size params.softmax_lseaccum_ptr = softmax_lse_accum.data_ptr(); params.oaccum_ptr = out_accum.data_ptr(); } + TORCH_CHECK(params.num_splits <= 128, "num_splits > 128 not supported"); } // number of times random will be generated per thread, to offset philox counter in thc random @@ -397,11 +399,11 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size if (out_.has_value()) { out_.value().copy_(out); } } - if (seqlenq_nheads_swapped) { - out = out.transpose(1, 2); - out_padded = out_padded.transpose(1, 2); - q_padded = q_padded.transpose(1, 2); - softmax_lse = softmax_lse.transpose(1, 2); + if (seqlenq_ngroups_swapped) { + out = out.transpose(1, 2).reshape({batch_size, 1, num_heads_k * seqlen_q, head_size_og}); + out_padded = out_padded.transpose(1, 2).reshape({batch_size, 1, num_heads_k * seqlen_q, head_size_og}); + q_padded = q_padded.transpose(1, 2).reshape({batch_size, 1, num_heads_k * seqlen_q, head_size_og}); + softmax_lse = softmax_lse.reshape({batch_size, num_heads_k * seqlen_q, 1}); } return {out, q_padded, k_padded, v_padded, out_padded, softmax_lse, p, rng_state}; } @@ -1050,11 +1052,14 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he if (seqlen_q == 1) { is_causal = false; } // causal=true is the same as causal=false in this case - // Faster to transpose q from (b, 1, h, d) to (b, h, 1, d) in this case - const int seqlenq_nheads_swapped = seqlen_q == 1 && num_heads_k == 1 && num_heads > 1; - if (seqlenq_nheads_swapped) { - q = q.transpose(1, 2); - std::swap(seqlen_q, num_heads); + // Faster to transpose q from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d) in this case + // H/t Daniel Haziza + const int seqlenq_ngroups_swapped = seqlen_q == 1 && num_heads > num_heads_k && head_size_og % 8 == 0; + if (seqlenq_ngroups_swapped) { + const int ngroups = num_heads / num_heads_k; + q = q.reshape({batch_size, num_heads_k, ngroups, head_size_og}).transpose(1, 2); + seqlen_q = ngroups; + num_heads = num_heads_k; } CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size_og); @@ -1184,12 +1189,9 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he params.rotary_dim = 0; } - // This needs to match with run_mha_fwd_splitkv_dispatch - const int block_n = is_sm90 || is_sm8x - ? (head_size <= 64 ? 256 : (head_size <= 160 ? 128 : 64)) - : (head_size <= 64 ? 256 : (head_size <= 128 ? 128 : 64)); - const int num_n_blocks = (seqlen_k + (params.knew_ptr == nullptr ? 0 : seqlen_q) + block_n - 1) / block_n; + const int block_n = head_size <= 64 ? 256 : (head_size <= 128 ? 128 : 64); + const int num_n_blocks = (seqlen_k + block_n - 1) / block_n; // Technically kBlockM = 64 only for the splitKV kernels, not the standard kernel. // In any case we don't expect seqlen_q to be larger than 64 for inference. 
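For orientation, the sizes computed here drive the split-KV decision: block_n now depends only on the head dimension (the per-architecture branch is gone), and the number of splits is chosen to keep the SMs busy, capped at 128. A rough Python approximation of that sizing logic (the real num_splits_heuristic in the CUDA code is more refined; the helper name and the saturation rule below are only a sketch):

    def approx_num_splits(batch, num_heads, seqlen_q, seqlen_k, head_size, num_sms, max_splits=128):
        # Same block sizes as the dispatch above: 256 for head_size <= 64, 128 up to 128, else 64.
        block_n = 256 if head_size <= 64 else (128 if head_size <= 128 else 64)
        num_n_blocks = (seqlen_k + block_n - 1) // block_n
        num_m_blocks = (seqlen_q + 64 - 1) // 64  # kBlockM = 64 on the split-KV path
        work_units = batch * num_heads * num_m_blocks
        if work_units >= num_sms:
            return 1  # one wave already fills the GPU, no need to split the K/V sequence
        # Otherwise split K/V until the SMs are roughly saturated, never beyond the
        # number of K/V blocks or the hard cap enforced by the TORCH_CHECK below.
        return min(max(1, num_sms // work_units), num_n_blocks, max_splits)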
const int num_m_blocks = (seqlen_q + 64 - 1) / 64; @@ -1197,6 +1199,7 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he if (num_splits < 1) { params.num_splits = num_splits_heuristic(batch_size * num_heads * num_m_blocks, dprops->multiProcessorCount, num_n_blocks, 128); } + TORCH_CHECK(params.num_splits <= 128, "num_splits > 128 not supported"); if (params.num_splits > 1) { at::Tensor softmax_lse_accum = torch::empty({params.num_splits, batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat)); at::Tensor out_accum = torch::empty({params.num_splits, batch_size, num_heads, seqlen_q, head_size_rounded}, opts.dtype(at::kFloat)); @@ -1219,9 +1222,9 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he } } - if (seqlenq_nheads_swapped) { - out = out.transpose(1, 2); - softmax_lse = softmax_lse.transpose(1, 2); + if (seqlenq_ngroups_swapped) { + out = out.transpose(1, 2).reshape({batch_size, 1, num_heads_k * seqlen_q, head_size_og}); + softmax_lse = softmax_lse.reshape({batch_size, num_heads_k * seqlen_q, 1}); } return {out, softmax_lse}; } diff --git a/csrc/flash_attn/src/flash_bwd_launch_template.h b/csrc/flash_attn/src/flash_bwd_launch_template.h index f4f2388b9..fa45398d2 100644 --- a/csrc/flash_attn/src/flash_bwd_launch_template.h +++ b/csrc/flash_attn/src/flash_bwd_launch_template.h @@ -123,14 +123,12 @@ void run_flash_bwd_seqq_parallel(Flash_bwd_params ¶ms, cudaStream_t stream, kernel_dkv<<>>(params); C10_CUDA_KERNEL_LAUNCH_CHECK(); } -// template void run_flash_bwd(Flash_bwd_params ¶ms, cudaStream_t stream, const bool configure) { if (configure) return; run_flash_bwd_seqk_parallel(params, stream, configure); } -// template void run_mha_bwd_hdim32(Flash_bwd_params ¶ms, cudaStream_t stream, const bool configure) { diff --git a/csrc/flash_attn/src/flash_fwd_kernel.h b/csrc/flash_attn/src/flash_fwd_kernel.h index d3736bea4..68d613431 100644 --- a/csrc/flash_attn/src/flash_fwd_kernel.h +++ b/csrc/flash_attn/src/flash_fwd_kernel.h @@ -1141,19 +1141,18 @@ inline __device__ void compute_attn_splitkv(const Params ¶ms) { //////////////////////////////////////////////////////////////////////////////////////////////////// -template +template inline __device__ void combine_attn_seqk_parallel(const Params ¶ms) { using Element = typename Kernel_traits::Element; using ElementAccum = typename Kernel_traits::ElementAccum; using index_t = typename Kernel_traits::index_t; constexpr int kMaxSplits = 1 << Log_max_splits; - constexpr int kBlockM = 16; constexpr int kHeadDim = Kernel_traits::kHeadDim; + constexpr int kNThreads = Kernel_traits::kNThreads; static_assert(kMaxSplits <= 128, "kMaxSplits must be <= 128"); - // static_assert(kMaxSplits <= 8, "kMaxSplits must be <= 8 for now, will extend layer"); - static_assert(kBlockM == 16 || kBlockM == 32, "kBlockM must be 16 or 32"); - static_assert(Kernel_traits::kNThreads == 128, "We assume that each block has 128 threads"); + static_assert(kBlockM == 4 || kBlockM == 8 || kBlockM == 16 || kBlockM == 32, "kBlockM must be 4, 8, 16 or 32"); + static_assert(kNThreads == 128, "We assume that each block has 128 threads"); // Shared memory. // kBlockM + 1 instead of kBlockM to reduce bank conflicts. 
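The combine kernel being reworked here merges the per-split partial results with the standard log-sum-exp trick. A reference sketch in PyTorch of what it computes for one query row, given the per-split outputs and their log-sum-exps (the kernel does the same thing blockwise in registers and shared memory rather than materializing these tensors):

    import torch

    def combine_splits(o_parts, lse_parts):
        # o_parts: (num_splits, head_dim) partial attention outputs
        # lse_parts: (num_splits,) partial log-sum-exps of the attention scores
        lse_max = lse_parts.max()
        lse = lse_max + torch.log(torch.exp(lse_parts - lse_max).sum())  # combined LSE
        scales = torch.exp(lse_parts - lse)           # contribution of each split
        out = (scales[:, None] * o_parts).sum(dim=0)  # rescaled sum of partial outputs
        return out, lse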
@@ -1169,17 +1168,17 @@ inline __device__ void combine_attn_seqk_parallel(const Params ¶ms) { make_stride(params.b * params.h * params.seqlen_q, _1{})); Tensor gLSE = make_tensor(make_gmem_ptr(reinterpret_cast(params.softmax_lse_ptr) + row_offset_lse), Shape>{}, Stride<_1>{}); - constexpr int kNLsePerThread = (kMaxSplits * kBlockM + Kernel_traits::kNThreads - 1) / Kernel_traits::kNThreads; + constexpr int kNLsePerThread = (kMaxSplits * kBlockM + kNThreads - 1) / kNThreads; // Read the LSE values from gmem and store them in shared memory, then tranpose them. - constexpr int kRowsPerLoadLSE = Kernel_traits::kNThreads / kBlockM; + constexpr int kRowsPerLoadLSE = kNThreads / kBlockM; #pragma unroll for (int l = 0; l < kNLsePerThread; ++l) { const int row = l * kRowsPerLoadLSE + tidx / kBlockM; const int col = tidx % kBlockM; ElementAccum lse = (row < params.num_splits && col < params.b * params.h * params.seqlen_q - bidx * kBlockM) ? gLSEaccum(row, col) : -INFINITY; if (row < kMaxSplits) { sLSE[row][col] = lse; } - // if (bidx == 0 && tidx < 32) { printf("tidx = %d, row = %d, col = %d, lse = %f\n", tidx, row, col, lse_accum(l)); } + // if (bidx == 0 && tidx < 32) { printf("tidx = %d, row = %d, col = %d, lse = %f\n", tidx, row, col, lse); } } // if (bidx == 1 && tidx < 32) { printf("tidx = %d, row_offset_lse = %d, lse = %f\n", tidx, row_offset_lse, lse_accum(0)); } __syncthreads(); @@ -1187,7 +1186,7 @@ inline __device__ void combine_attn_seqk_parallel(const Params ¶ms) { constexpr int kRowsPerLoadTranspose = std::min(kRowsPerLoadLSE, kMaxSplits); // To make sure that kMaxSplits is within 1 warp: we decide how many elements within kMaxSplits // each thread should hold. If kMaxSplits = 16, then each thread holds 2 elements (128 threads, - // 16 rows, so each time we load we can load 8 rows). + // kBlockM rows, so each time we load we can load 128 / kBlockM rows). 
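With kBlockM now a template parameter, the thread-to-tile mapping scales with it: 128 threads cover kBlockM columns of LSE values per load, so a smaller kBlockM means more split rows per pass. A small sketch of the derived constants, mirroring the constexpr logic in this kernel and the kBlockM choice added in the launch code further below (4 when the head dim is divisible by 128, 8 when divisible by 64, else 16); the helper name is invented for illustration:

    def combine_kernel_tiling(head_dim, max_splits, n_threads=128):
        if head_dim % 128 == 0:
            k_block_m = 4
        elif head_dim % 64 == 0:
            k_block_m = 8
        else:
            k_block_m = 16
        rows_per_load_lse = n_threads // k_block_m                                 # kRowsPerLoadLSE
        n_lse_per_thread = (max_splits * k_block_m + n_threads - 1) // n_threads   # kNLsePerThread
        return k_block_m, rows_per_load_lse, n_lse_per_thread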
// constexpr int kThreadsPerSplit = kMaxSplits / kRowsPerLoadTranspose; // static_assert(kThreadsPerSplit <= 32); static_assert(kRowsPerLoadTranspose <= 32); @@ -1230,7 +1229,13 @@ inline __device__ void combine_attn_seqk_parallel(const Params ¶ms) { Tensor gOaccum = make_tensor(make_gmem_ptr(reinterpret_cast(params.oaccum_ptr) + row_offset_oaccum), Shape, Int>{}, Stride, _1>{}); - typename Kernel_traits::GmemTiledCopyOaccum gmem_tiled_copy_Oaccum; + constexpr int kBlockN = kNThreads / kBlockM; + using GmemLayoutAtomOaccum = Layout, Int>, Stride, _1>>; + using GmemTiledCopyOaccum = decltype( + make_tiled_copy(Copy_Atom{}, + GmemLayoutAtomOaccum{}, + Layout>{})); // Val layout, 4 vals per store + GmemTiledCopyOaccum gmem_tiled_copy_Oaccum; auto gmem_thr_copy_Oaccum = gmem_tiled_copy_Oaccum.get_thread_slice(tidx); Tensor tOgOaccum = gmem_thr_copy_Oaccum.partition_S(gOaccum); Tensor tOrO = make_tensor(shape(tOgOaccum)); @@ -1247,7 +1252,6 @@ inline __device__ void combine_attn_seqk_parallel(const Params ¶ms) { for (int k = 0; k < size(tOpOaccum); ++k) { tOpOaccum(k) = get<1>(tOcOaccum(0, 0, k)) < params.d; } } // Load Oaccum in then scale and accumulate to O - #pragma unroll 2 for (int split = 0; split < params.num_splits; ++split) { flash::copy( gmem_tiled_copy_Oaccum, tOgOaccum, tOrOaccum, tOcOaccum, tOpOaccum, params.b * params.h * params.seqlen_q - bidx * kBlockM @@ -1263,11 +1267,11 @@ inline __device__ void combine_attn_seqk_parallel(const Params ¶ms) { tOrO(i, m, k) += lse_scale * tOrOaccum(i, m, k); } } - // if (cute::thread0()) { printf("lse_scale = %f, %f\n", sLSE[split][0], sLSE[split][1]); print(tOrOaccum); print(tOrO); } + // if (cute::thread0()) { printf("lse_scale = %f, %f\n", sLSE[split][0], sLSE[split][1]); print(tOrOaccum); } } tOgOaccum.data() = tOgOaccum.data() + params.b * params.h * params.seqlen_q * params.d_rounded; } - // if (cute::thread0()) { print(tOrO); } + // if (cute::thread0()) { print_tensor(tOrO); } Tensor rO = flash::convert_type(tOrO); // Write to gO diff --git a/csrc/flash_attn/src/flash_fwd_launch_template.h b/csrc/flash_attn/src/flash_fwd_launch_template.h index 9c8c750c5..51d75768b 100644 --- a/csrc/flash_attn/src/flash_fwd_launch_template.h +++ b/csrc/flash_attn/src/flash_fwd_launch_template.h @@ -20,10 +20,10 @@ __global__ void flash_fwd_splitkv_kernel(Flash_fwd_params params) { flash::compute_attn_splitkv(params); } -template +template __global__ void flash_fwd_splitkv_combine_kernel(Flash_fwd_params params) { static_assert(Log_max_splits >= 1); - flash::combine_attn_seqk_parallel(params); + flash::combine_attn_seqk_parallel(params); } template @@ -93,22 +93,26 @@ void run_flash_splitkv_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { }); }); if (params.num_splits > 1) { - dim3 grid_combine((params.b * params.h * params.seqlen_q + 16 - 1) / 16); + // We want kBlockM to be as small as possible for more parallelism. + // With 128 threads we can load 512 elements at a time, so if headdim is divisible by 128, kBlockM = 4. + // If headdim is divisible by 64, then we set kBlockM = 8, etc. + constexpr int kBlockM = Kernel_traits::kHeadDim % 128 == 0 ? 4 : (Kernel_traits::kHeadDim % 64 == 0 ? 
8 : 16); + dim3 grid_combine((params.b * params.h * params.seqlen_q + kBlockM - 1) / kBlockM); BOOL_SWITCH(is_even_K, IsEvenKConst, [&] { if (params.num_splits <= 2) { - flash_fwd_splitkv_combine_kernel<<>>(params); + flash_fwd_splitkv_combine_kernel<<>>(params); } else if (params.num_splits <= 4) { - flash_fwd_splitkv_combine_kernel<<>>(params); + flash_fwd_splitkv_combine_kernel<<>>(params); } else if (params.num_splits <= 8) { - flash_fwd_splitkv_combine_kernel<<>>(params); + flash_fwd_splitkv_combine_kernel<<>>(params); } else if (params.num_splits <= 16) { - flash_fwd_splitkv_combine_kernel<<>>(params); + flash_fwd_splitkv_combine_kernel<<>>(params); } else if (params.num_splits <= 32) { - flash_fwd_splitkv_combine_kernel<<>>(params); + flash_fwd_splitkv_combine_kernel<<>>(params); } else if (params.num_splits <= 64) { - flash_fwd_splitkv_combine_kernel<<>>(params); + flash_fwd_splitkv_combine_kernel<<>>(params); } else if (params.num_splits <= 128) { - flash_fwd_splitkv_combine_kernel<<>>(params); + flash_fwd_splitkv_combine_kernel<<>>(params); } C10_CUDA_KERNEL_LAUNCH_CHECK(); }); diff --git a/tests/test_flash_attn.py b/tests/test_flash_attn.py index 17651851c..d37c5c7e5 100644 --- a/tests/test_flash_attn.py +++ b/tests/test_flash_attn.py @@ -1505,12 +1505,12 @@ def test_flash_attn_splitkv(seqlen_q, seqlen_k, swap_sq_sk, d, causal, dtype): @pytest.mark.parametrize("rotary_interleaved", [False, True]) # @pytest.mark.parametrize("rotary_interleaved", [False]) @pytest.mark.parametrize("rotary_fraction", [0.0, 0.5, 1.0]) -# @pytest.mark.parametrize("rotary_fraction", [1.0]) +# @pytest.mark.parametrize("rotary_fraction", [0.0]) @pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256]) # @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192, 224, 256]) # @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192]) # @pytest.mark.parametrize('d', [56, 80]) -# @pytest.mark.parametrize("d", [64]) +# @pytest.mark.parametrize("d", [128]) @pytest.mark.parametrize( "seqlen_q,seqlen_k", [ From 229080b9d2af9ffd5657cf65056af7181dcea7f1 Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Wed, 20 Sep 2023 23:39:38 -0700 Subject: [PATCH 05/23] Bump to v2.2.4 --- flash_attn/__init__.py | 2 +- training/Dockerfile | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/flash_attn/__init__.py b/flash_attn/__init__.py index 15ea32e7a..ef30834f9 100644 --- a/flash_attn/__init__.py +++ b/flash_attn/__init__.py @@ -1,4 +1,4 @@ -__version__ = "2.2.3.post2" +__version__ = "2.2.4" from flash_attn.flash_attn_interface import ( flash_attn_func, diff --git a/training/Dockerfile b/training/Dockerfile index fbea6cbc2..a00fcd0d3 100644 --- a/training/Dockerfile +++ b/training/Dockerfile @@ -85,11 +85,11 @@ RUN pip install transformers==4.25.1 datasets==2.8.0 pytorch-lightning==1.8.6 tr RUN pip install git+https://github.com/mlcommons/logging.git@2.1.0 # Install FlashAttention -RUN pip install flash-attn==2.2.3.post2 +RUN pip install flash-attn==2.2.4 # Install CUDA extensions for fused dense, layer norm RUN git clone https://github.com/HazyResearch/flash-attention \ - && cd flash-attention && git checkout v2.2.3.post2 \ + && cd flash-attention && git checkout v2.2.4 \ && cd csrc/layer_norm && pip install . && cd ../../ \ && cd csrc/fused_dense_lib && pip install . && cd ../../ \ && cd .. 
&& rm -rf flash-attention From 187c2a06358d421e3d350fac8ff8714013c1f1cd Mon Sep 17 00:00:00 2001 From: Yuchao Dai <3407450+icyblade@users.noreply.github.com> Date: Fri, 22 Sep 2023 02:48:23 +0800 Subject: [PATCH 06/23] Fix E1136 (#563) --- flash_attn/models/gpt.py | 3 ++- flash_attn/models/llama.py | 16 ++++++++-------- 2 files changed, 10 insertions(+), 9 deletions(-) diff --git a/flash_attn/models/gpt.py b/flash_attn/models/gpt.py index e82202811..b2403dc34 100644 --- a/flash_attn/models/gpt.py +++ b/flash_attn/models/gpt.py @@ -6,6 +6,7 @@ from collections import OrderedDict, namedtuple from collections.abc import Sequence from functools import partial +from typing import Dict, List import torch import torch.nn as nn @@ -810,7 +811,7 @@ def shard_qkv_headdim(state_dict, key): return state_dict -def combine_state_dicts_tp(state_dicts: list[dict[str, torch.Tensor]], config: GPT2Config): +def combine_state_dicts_tp(state_dicts: List[Dict[str, torch.Tensor]], config: GPT2Config): """Convert the list of sharded state_dict of a GPT model with tensor parallel to the state_dict of a standard GPT model. diff --git a/flash_attn/models/llama.py b/flash_attn/models/llama.py index 2841efd76..3bfb51d17 100644 --- a/flash_attn/models/llama.py +++ b/flash_attn/models/llama.py @@ -6,7 +6,7 @@ import re from collections import OrderedDict from pathlib import Path -from typing import Union +from typing import Dict, List, Union import torch import torch.nn.functional as F @@ -17,8 +17,8 @@ def remap_state_dict_meta_llama( - state_dict: dict[str, torch.Tensor], config: GPT2Config -) -> dict[str, torch.Tensor]: + state_dict: Dict[str, torch.Tensor], config: GPT2Config +) -> Dict[str, torch.Tensor]: """Convert the state_dict in Meta format to standard GPT format. This function modifies state_dict in place. @@ -113,8 +113,8 @@ def key_mapping_attn(key): def remap_state_dict_hf_llama( - state_dict: dict[str, torch.Tensor], config: GPT2Config -) -> dict[str, torch.Tensor]: + state_dict: Dict[str, torch.Tensor], config: GPT2Config +) -> Dict[str, torch.Tensor]: """Convert the state_dict in Hugging Face format to standard GPT format. This function modifies state_dict in place. @@ -217,8 +217,8 @@ def key_mapping_attn(key): def inv_remap_state_dict_hf_llama( - state_dict: dict[str, torch.Tensor], config: GPT2Config -) -> dict[str, torch.Tensor]: + state_dict: Dict[str, torch.Tensor], config: GPT2Config +) -> Dict[str, torch.Tensor]: """Convert the state_dict in standard GPT format to Hugging Face format. 
This function is meant to be the inverse of remap_state_dict_hf_llama, up to a @@ -382,7 +382,7 @@ def config_from_checkpoint( def state_dicts_from_checkpoint( checkpoint_path: Union[str, os.PathLike], model_name: str -) -> list[dict]: +) -> List[dict]: # Need to sort, otherwise we mess up the ordering and the weights are wrong return [ torch.load(path, map_location="cpu") From bff3147175e2f05ae0a4d44495aef97ca45ce531 Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Thu, 21 Sep 2023 23:55:25 -0700 Subject: [PATCH 07/23] Re-enable compilation for Hopper --- flash_attn/__init__.py | 2 +- setup.py | 8 ++++---- training/Dockerfile | 4 ++-- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/flash_attn/__init__.py b/flash_attn/__init__.py index ef30834f9..79b9dcc4c 100644 --- a/flash_attn/__init__.py +++ b/flash_attn/__init__.py @@ -1,4 +1,4 @@ -__version__ = "2.2.4" +__version__ = "2.2.4.post1" from flash_attn.flash_attn_interface import ( flash_attn_func, diff --git a/setup.py b/setup.py index 75fb15285..33f3813db 100644 --- a/setup.py +++ b/setup.py @@ -122,10 +122,10 @@ def append_nvcc_threads(nvcc_extra_args): # cc_flag.append("arch=compute_75,code=sm_75") cc_flag.append("-gencode") cc_flag.append("arch=compute_80,code=sm_80") - # if CUDA_HOME is not None: - # if bare_metal_version >= Version("11.8"): - # cc_flag.append("-gencode") - # cc_flag.append("arch=compute_90,code=sm_90") + if CUDA_HOME is not None: + if bare_metal_version >= Version("11.8"): + cc_flag.append("-gencode") + cc_flag.append("arch=compute_90,code=sm_90") # HACK: The compiler flag -D_GLIBCXX_USE_CXX11_ABI is set to be the same as # torch._C._GLIBCXX_USE_CXX11_ABI diff --git a/training/Dockerfile b/training/Dockerfile index a00fcd0d3..30b219ded 100644 --- a/training/Dockerfile +++ b/training/Dockerfile @@ -85,11 +85,11 @@ RUN pip install transformers==4.25.1 datasets==2.8.0 pytorch-lightning==1.8.6 tr RUN pip install git+https://github.com/mlcommons/logging.git@2.1.0 # Install FlashAttention -RUN pip install flash-attn==2.2.4 +RUN pip install flash-attn==2.2.4.post1 # Install CUDA extensions for fused dense, layer norm RUN git clone https://github.com/HazyResearch/flash-attention \ - && cd flash-attention && git checkout v2.2.4 \ + && cd flash-attention && git checkout v2.2.4.post1 \ && cd csrc/layer_norm && pip install . && cd ../../ \ && cd csrc/fused_dense_lib && pip install . && cd ../../ \ && cd .. 
&& rm -rf flash-attention From dd9a6fa45a9b90ff954d2b3f3f44241b9216190e Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Fri, 22 Sep 2023 02:31:00 -0700 Subject: [PATCH 08/23] Add placeholder for inference example --- examples/inference/README.md | 2 ++ 1 file changed, 2 insertions(+) create mode 100644 examples/inference/README.md diff --git a/examples/inference/README.md b/examples/inference/README.md new file mode 100644 index 000000000..695f04b1c --- /dev/null +++ b/examples/inference/README.md @@ -0,0 +1,2 @@ +# Example of LLM inference using FlashAttention + From 1879e089c72050c0d472491ed98971dffc280ba8 Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Sat, 23 Sep 2023 22:24:30 -0700 Subject: [PATCH 09/23] Reduce number of templates for headdim > 128 --- csrc/flash_attn/src/flash_bwd_launch_template.h | 3 ++- csrc/flash_attn/src/flash_fwd_launch_template.h | 6 ++++-- flash_attn/layers/rotary.py | 1 + setup.py | 2 +- 4 files changed, 8 insertions(+), 4 deletions(-) diff --git a/csrc/flash_attn/src/flash_bwd_launch_template.h b/csrc/flash_attn/src/flash_bwd_launch_template.h index fa45398d2..d13f1d53b 100644 --- a/csrc/flash_attn/src/flash_bwd_launch_template.h +++ b/csrc/flash_attn/src/flash_bwd_launch_template.h @@ -63,7 +63,8 @@ void run_flash_bwd_seqk_parallel(Flash_bwd_params ¶ms, cudaStream_t stream, BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] { BOOL_SWITCH(is_even_K, IsEvenKConst, [&] { // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates. - auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel; + // If head dim > 128, set IsEvenMNConst to false to reduce number of templates + auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel; // auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel; if (smem_size_dq_dk_dv >= 48 * 1024) { C10_CUDA_CHECK(cudaFuncSetAttribute( diff --git a/csrc/flash_attn/src/flash_fwd_launch_template.h b/csrc/flash_attn/src/flash_fwd_launch_template.h index 51d75768b..4c336ec94 100644 --- a/csrc/flash_attn/src/flash_fwd_launch_template.h +++ b/csrc/flash_attn/src/flash_fwd_launch_template.h @@ -45,7 +45,9 @@ void run_flash_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { BOOL_SWITCH(return_softmax, ReturnSoftmaxConst, [&] { // Will only return softmax if dropout, to reduce compilation time. // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates. - auto kernel = &flash_fwd_kernel; + // If return_softmax, set IsEvenMNConst to false to reduce number of templates + // If head dim > 128, set IsEvenMNConst to false to reduce number of templates + auto kernel = &flash_fwd_kernel; // auto kernel = &flash_fwd_kernel; if (smem_size >= 48 * 1024) { C10_CUDA_CHECK(cudaFuncSetAttribute( @@ -78,7 +80,7 @@ void run_flash_splitkv_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { BOOL_SWITCH(params.knew_ptr != nullptr, Append_KV, [&] { // If Append_KV, then we must have seqlen_offsets, which means cu_seqlens_k != nullptr. // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates. 
- auto kernel = &flash_fwd_splitkv_kernel; + auto kernel = &flash_fwd_splitkv_kernel; // auto kernel = &flash_fwd_splitkv_kernel; // auto kernel = &flash_fwd_splitkv_kernel; if (smem_size >= 48 * 1024) { diff --git a/flash_attn/layers/rotary.py b/flash_attn/layers/rotary.py index 71259d020..bd05258f7 100644 --- a/flash_attn/layers/rotary.py +++ b/flash_attn/layers/rotary.py @@ -371,6 +371,7 @@ def _update_cos_sin_cache(self, seqlen, device=None, dtype=None): # or if we're switching from inference mode to training if ( seqlen > self._seq_len_cached + or self._cos_cached is None or self._cos_cached.device != device or self._cos_cached.dtype != dtype or (self.training and self._cos_cached.is_inference()) diff --git a/setup.py b/setup.py index 33f3813db..f5e17a4a2 100644 --- a/setup.py +++ b/setup.py @@ -201,7 +201,7 @@ def append_nvcc_threads(nvcc_extra_args): "--use_fast_math", # "--ptxas-options=-v", # "--ptxas-options=-O2", - "-lineinfo", + # "-lineinfo", ] + generator_flag + cc_flag From 65c234ed9071d0fb3fd87b2a758f10431ce0d5e5 Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Sun, 24 Sep 2023 00:36:07 -0700 Subject: [PATCH 10/23] Don't over-allocate dq_accum in case of varlen --- csrc/flash_attn/flash_api.cpp | 12 ++++++++++-- csrc/flash_attn/src/flash_bwd_kernel.h | 24 ++++++++++++++---------- tests/test_flash_attn.py | 6 +++--- 3 files changed, 27 insertions(+), 15 deletions(-) diff --git a/csrc/flash_attn/flash_api.cpp b/csrc/flash_attn/flash_api.cpp index d62bd0dba..6757f186c 100644 --- a/csrc/flash_attn/flash_api.cpp +++ b/csrc/flash_attn/flash_api.cpp @@ -713,7 +713,7 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si at::Tensor dq_accum; at::Tensor dk_accum, dv_accum; if (loop) { - dq_accum = torch::empty({batch_size, num_heads, seqlen_q_rounded, head_size_rounded}, opts.dtype(at::kFloat)); + dq_accum = torch::empty({batch_size, seqlen_q_rounded, num_heads, head_size_rounded}, opts.dtype(at::kFloat)); // dk_accum = torch::empty({batch_size, num_heads_k, seqlen_k_rounded, head_size_rounded}, opts.dtype(at::kFloat)); // dv_accum = torch::empty({batch_size, num_heads_k, seqlen_k_rounded, head_size_rounded}, opts.dtype(at::kFloat)); } @@ -923,7 +923,15 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size auto softmax_d = torch::empty({batch_size, num_heads, seqlen_q_rounded}, opts.dtype(at::kFloat)); at::Tensor dq_accum; if (loop) { - dq_accum = torch::empty({batch_size, num_heads, seqlen_q_rounded, head_size_rounded}, opts.dtype(at::kFloat)); + // We don't want to allocate dq_accum of size (batch, seqlen_q_rounded, num_heads, head_size_rounded) + // because that would be too large if there is a very long sequence and the rest of the sequences are short. + // Instead, we allocate dq_accum of size (total_q + 128 * batch, num_heads, head_size_rounded). + // Note that 128 is the max block size on the seqlen_q dimension. + // For dQ, the i-th sequence is stored in indices from cu_seqlens[i] + 128 * i to + // cu_seqlens[i + 1] * 128 * i - 1. This ensures that the i-th sequence and (i + 1)-th sequence will + // be at least 128 apart. It's ok for us to do atomicAdds up to 128 rows beyond what we're normally + // allowed to do. So we won't have to do any bound checking, and performance should stay the same. 
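In other words, each sequence gets its normal rows plus a private 128-row pad in the over-allocated (total_q + 128 * batch, num_heads, head_size_rounded) buffer allocated just below. A small sketch of the resulting row layout (a clarifying illustration of the comment above, not code from the repo):

    def dq_accum_rows(cu_seqlens_q, i):
        # First row reserved for sequence i in dq_accum.
        start = cu_seqlens_q[i] + 128 * i
        # First row reserved for the next sequence. Between the true end of
        # sequence i and next_start there are always at least 128 spare rows,
        # so atomicAdds that run up to 128 rows (the max kBlockM) past the end
        # of a sequence never touch another sequence's rows; no bound checks needed.
        next_start = cu_seqlens_q[i + 1] + 128 * (i + 1)
        return start, next_start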
+ dq_accum = torch::empty({total_q + 128 * batch_size, num_heads, head_size_rounded}, opts.dtype(at::kFloat)); } at::Tensor dk_expanded, dv_expanded; diff --git a/csrc/flash_attn/src/flash_bwd_kernel.h b/csrc/flash_attn/src/flash_bwd_kernel.h index 6bece9b6f..3a0a84731 100644 --- a/csrc/flash_attn/src/flash_bwd_kernel.h +++ b/csrc/flash_attn/src/flash_bwd_kernel.h @@ -127,7 +127,8 @@ inline __device__ void compute_dot_do_o(const Params ¶ms) { + m_block * kBlockM * params.do_row_stride + bidh * params.do_head_stride; const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb) + m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride; - const index_t row_offset_dq_accum = ((bidb * params.h + bidh) * params.seqlen_q_rounded + m_block * kBlockM) * params.d_rounded; + const index_t row_offset_dq_accum = binfo.q_offset(params.seqlen_q_rounded * params.h * params.d_rounded, params.h * params.d_rounded, bidb) + + (m_block * kBlockM + (params.cu_seqlens_q == nullptr ? 0 : 128 * bidb)) * params.h * params.d_rounded + bidh * params.d_rounded; const index_t row_offset_dpsum = (bidb * params.h + bidh) * params.seqlen_q_rounded + m_block * kBlockM; Tensor gdO = make_tensor(make_gmem_ptr(reinterpret_cast(params.do_ptr) + row_offset_do), @@ -137,7 +138,8 @@ inline __device__ void compute_dot_do_o(const Params ¶ms) { Shape, Int>{}, make_stride(params.o_row_stride, _1{})); Tensor gdQaccum = make_tensor(make_gmem_ptr(reinterpret_cast(params.dq_accum_ptr) + row_offset_dq_accum), - Shape, Int>{}, Stride, _1>{}); + Shape, Int>{}, + make_stride(params.h * params.d_rounded, _1{})); Tensor dP_sum = make_tensor(make_gmem_ptr(reinterpret_cast(params.dsoftmax_sum) + row_offset_dpsum), Shape>{}, Stride<_1>{}); @@ -175,6 +177,8 @@ inline __device__ void compute_dot_do_o(const Params ¶ms) { dot_do_o(tdOrdO, tdOrO, dP_sum, Kernel_traits::kNThreads / (Kernel_traits::kGmemThreadsPerRow), params.p_dropout); if (Clear_dQaccum) { + // We're actually not zero'ing out all of dQaccum, but only the part that we're going to + // do atomicAdds on. Tensor zero = make_fragment_like(tdQgdQaccum); clear(zero); cute::copy(gmem_tiled_copy_dQaccum, zero, tdQgdQaccum); @@ -248,15 +252,15 @@ inline __device__ void convert_dQ(const Params ¶ms) { const index_t row_offset_dq = binfo.q_offset(params.dq_batch_stride, params.dq_row_stride, bidb) + m_block * kBlockM * params.dq_row_stride + bidh * params.dq_head_stride; - const index_t row_offset_dq_accum = ((bidb * params.h + bidh) * params.seqlen_q_rounded - + m_block * kBlockM) * params.d_rounded; + const index_t row_offset_dq_accum = binfo.q_offset(params.seqlen_q_rounded * params.h * params.d_rounded, params.h * params.d_rounded, bidb) + + (m_block * kBlockM + (params.cu_seqlens_q == nullptr ? 
0 : 128 * bidb)) * params.h * params.d_rounded + bidh * params.d_rounded; Tensor gdQ = make_tensor(make_gmem_ptr(reinterpret_cast(params.dq_ptr) + row_offset_dq), Shape, Int>{}, make_stride(params.dq_row_stride, _1{})); Tensor gdQaccum = make_tensor(make_gmem_ptr(reinterpret_cast(params.dq_accum_ptr) + row_offset_dq_accum), Shape, Int>{}, - Stride, _1>{}); + make_stride(params.h * params.d_rounded, _1{})); Tensor sdQ = make_tensor(make_smem_ptr(reinterpret_cast(smem_)), typename Kernel_traits::SmemLayoutdQ{}); @@ -456,8 +460,8 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in + (m_block_max - 1) * kBlockM * params.o_row_stride + bidh * params.o_head_stride; const index_t row_offset_dq = binfo.q_offset(params.dq_batch_stride, params.dq_row_stride, bidb) + (m_block_max - 1) * kBlockM * params.dq_row_stride + bidh * params.dq_head_stride; - const index_t row_offset_dq_accum = ((bidb * params.h + bidh) * params.seqlen_q_rounded - + (m_block_max - 1) * kBlockM) * params.d_rounded; + const index_t row_offset_dq_accum = binfo.q_offset(params.seqlen_q_rounded * params.h * params.d_rounded, params.h * params.d_rounded, bidb) + + ((m_block_max - 1) * kBlockM + (params.cu_seqlens_q == nullptr ? 0 : 128 * bidb)) * params.h * params.d_rounded + bidh * params.d_rounded; const index_t row_offset_lse = (bidb * params.h + bidh) * params.seqlen_q + (m_block_max - 1) * kBlockM; const index_t row_offset_dpsum = (bidb * params.h + bidh) * params.seqlen_q_rounded @@ -483,7 +487,7 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in make_stride(params.dq_row_stride, _1{})); Tensor gdQaccum = make_tensor(make_gmem_ptr(reinterpret_cast(params.dq_accum_ptr) + row_offset_dq_accum), Shape, Int>{}, - Stride, _1>{}); + make_stride(params.h * params.d_rounded, _1{})); Tensor gLSE = make_tensor(make_gmem_ptr(reinterpret_cast(params.softmax_lse_ptr) + row_offset_lse), Shape>{}, Stride<_1>{}); Tensor gdPsum = make_tensor(make_gmem_ptr(reinterpret_cast(params.dsoftmax_sum) + row_offset_dpsum), @@ -648,7 +652,7 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in // We'll advance gdQ and gdQaccum before the 1st read/write. tdQgdQ.data() = tdQgdQ.data() + kBlockM * params.dq_row_stride; - tdQgdQaccum.data() = tdQgdQaccum.data() + kBlockM * params.d_rounded; + tdQgdQaccum.data() = tdQgdQaccum.data() + kBlockM * params.h * params.d_rounded; int m_block = m_block_max - 1; int m_block_min = !Is_causal ? 
0 : std::max(0, (n_block * kBlockN + binfo.actual_seqlen_q - binfo.actual_seqlen_k) / kBlockM); @@ -857,7 +861,7 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in // if (cute::thread0()) { print(dS); } Tensor acc_dq = partition_fragment_C(tiled_mma_dq, Shape, Int>{}); // MMA, MMA_N, MMA_K - tdQgdQaccum.data() = tdQgdQaccum.data() + (-int(kBlockM * params.d_rounded)); + tdQgdQaccum.data() = tdQgdQaccum.data() + (-int(kBlockM * params.h * params.d_rounded)); if (Is_first || Seq_parallel) { clear(acc_dq); } else { diff --git a/tests/test_flash_attn.py b/tests/test_flash_attn.py index d37c5c7e5..11daa43e6 100644 --- a/tests/test_flash_attn.py +++ b/tests/test_flash_attn.py @@ -492,16 +492,16 @@ def get_dropout_fraction( @pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16])) # @pytest.mark.parametrize('dtype', [torch.float16]) @pytest.mark.parametrize("causal", [False, True]) -# @pytest.mark.parametrize('causal', [True]) +# @pytest.mark.parametrize('causal', [False]) @pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256]) # @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192, 224, 256]) # @pytest.mark.parametrize('d', [32, 64, 96, 128]) # @pytest.mark.parametrize('d', [64]) # @pytest.mark.parametrize('seqlen', [128, 256, 384, 512, 768, 1024, 2048]) @pytest.mark.parametrize("seqlen", [97, 128, 200, 256, 257, 384, 512, 768, 1024, 1025, 2048]) -# @pytest.mark.parametrize('seqlen', [97]) +# @pytest.mark.parametrize('seqlen', [128]) @pytest.mark.parametrize("dropout_p", [0.0, 0.17]) -# @pytest.mark.parametrize('dropout_p', [0.17]) +# @pytest.mark.parametrize('dropout_p', [0.0]) def test_flash_attn_qkvpacked(seqlen, d, dropout_p, causal, dtype): if seqlen >= 2048 and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30: pytest.skip() # Reference implementation OOM From 812cb1c990f4ea91bfa083c4113ca4ac69d10439 Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Sun, 24 Sep 2023 00:42:50 -0700 Subject: [PATCH 11/23] Switch cutlass to newer commit to avoid compilation warning --- csrc/cutlass | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/cutlass b/csrc/cutlass index 34fd98056..e0aaa3c3b 160000 --- a/csrc/cutlass +++ b/csrc/cutlass @@ -1 +1 @@ -Subproject commit 34fd98056b69fbf7f0929b3f734bb5f00642e2c9 +Subproject commit e0aaa3c3b38db9a89c31f04fef91e92123ad5e2e From 0a1d03c7eacdf1ea5cb5848a9929908ef55175f7 Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Sun, 24 Sep 2023 00:54:03 -0700 Subject: [PATCH 12/23] Bump to v2.2.5 --- flash_attn/__init__.py | 2 +- training/Dockerfile | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/flash_attn/__init__.py b/flash_attn/__init__.py index 79b9dcc4c..f3d32304d 100644 --- a/flash_attn/__init__.py +++ b/flash_attn/__init__.py @@ -1,4 +1,4 @@ -__version__ = "2.2.4.post1" +__version__ = "2.2.5" from flash_attn.flash_attn_interface import ( flash_attn_func, diff --git a/training/Dockerfile b/training/Dockerfile index 30b219ded..b9eed40c3 100644 --- a/training/Dockerfile +++ b/training/Dockerfile @@ -85,11 +85,11 @@ RUN pip install transformers==4.25.1 datasets==2.8.0 pytorch-lightning==1.8.6 tr RUN pip install git+https://github.com/mlcommons/logging.git@2.1.0 # Install FlashAttention -RUN pip install flash-attn==2.2.4.post1 +RUN pip install flash-attn==2.2.5 # Install CUDA extensions for fused dense, layer norm RUN git clone https://github.com/HazyResearch/flash-attention \ - && cd flash-attention && git 
checkout v2.2.4.post1 \ + && cd flash-attention && git checkout v2.2.5 \ && cd csrc/layer_norm && pip install . && cd ../../ \ && cd csrc/fused_dense_lib && pip install . && cd ../../ \ && cd .. && rm -rf flash-attention From 4c8ff9154e76c68e7114292bd527c22f45fbf586 Mon Sep 17 00:00:00 2001 From: Katherine Crowson Date: Mon, 25 Sep 2023 10:47:34 -0700 Subject: [PATCH 13/23] Fix NameError and typo in ApplyRotaryEmbQKV_ (#569) --- flash_attn/layers/rotary.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/flash_attn/layers/rotary.py b/flash_attn/layers/rotary.py index bd05258f7..4ec049e00 100644 --- a/flash_attn/layers/rotary.py +++ b/flash_attn/layers/rotary.py @@ -193,16 +193,16 @@ def backward(ctx, dqkv): sin_k = sin if sin_k is None else sin_k dq, dk = dqkv[:, :, 0], dqkv[:, :, 1] apply_rotary( - dq, cos, sin, seqlen_offsets, interleaved=interleaved, inplace=True, conjugate=True + dq, cos, sin, seqlen_offsets, interleaved=ctx.interleaved, inplace=True, conjugate=True ) apply_rotary( dk, cos_k, sin_k, seqlen_offsets, - interleaved=interleaved, + interleaved=ctx.interleaved, inplace=True, - conjudate=True, + conjugate=True, ) return dqkv, None, None, None, None, None, None From 083e8f525f1f8e1dde5044afdac79f9588302207 Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Sun, 24 Sep 2023 22:48:39 -0700 Subject: [PATCH 14/23] Implement local attention Co-authored-by: Timothee Lacroix --- csrc/flash_attn/flash_api.cpp | 58 ++- csrc/flash_attn/src/flash.h | 3 + csrc/flash_attn/src/flash_bwd_kernel.h | 66 ++- .../src/flash_bwd_launch_template.h | 28 +- csrc/flash_attn/src/flash_fwd_kernel.h | 114 +++-- .../src/flash_fwd_launch_template.h | 74 ++-- csrc/flash_attn/src/softmax.h | 23 +- flash_attn/flash_attn_interface.py | 184 ++++++-- tests/test_flash_attn.py | 411 +++++++++++++----- 9 files changed, 706 insertions(+), 255 deletions(-) diff --git a/csrc/flash_attn/flash_api.cpp b/csrc/flash_attn/flash_api.cpp index 6757f186c..91cc370fe 100644 --- a/csrc/flash_attn/flash_api.cpp +++ b/csrc/flash_attn/flash_api.cpp @@ -40,7 +40,8 @@ void set_params_fprop(Flash_fwd_params ¶ms, void *softmax_lse_d, float p_dropout, float softmax_scale, - bool is_causal) { + int window_size_left, + int window_size_right) { // Reset the parameters memset(¶ms, 0, sizeof(params)); @@ -105,7 +106,15 @@ void set_params_fprop(Flash_fwd_params ¶ms, params.scale_softmax_rp_dropout = params.rp_dropout * params.scale_softmax; TORCH_CHECK(p_dropout < 1.f); - params.is_causal = is_causal; + // Causal is the special case where window_size_right == 0 and window_size_left < 0. + // Local is the more general case where window_size_right >= 0 or window_size_left >= 0. 
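Concretely, under the convention set up here a query at position i may attend keys j in [i + seqlen_k - seqlen_q - window_size_left, i + seqlen_k - seqlen_q + window_size_right], with a negative window size meaning unbounded on that side; (-1, 0) therefore recovers causal masking aligned to the bottom-right corner. A dense-mask sketch of these semantics (illustration only; the kernels never materialize such a mask, and the helper name is invented):

    import torch

    def local_attention_mask(seqlen_q, seqlen_k, window_left, window_right):
        # keep[i, j] is True where query i may attend key j.
        i = torch.arange(seqlen_q)[:, None]
        j = torch.arange(seqlen_k)[None, :]
        shift = seqlen_k - seqlen_q  # bottom-right alignment of the two sequences
        keep = torch.ones(seqlen_q, seqlen_k, dtype=torch.bool)
        if window_left >= 0:
            keep &= j >= i + shift - window_left
        if window_right >= 0:
            keep &= j <= i + shift + window_right
        return keep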
+ params.is_causal = window_size_left < 0 && window_size_right == 0; + + if (window_size_left < 0 && window_size_right >= 0) { window_size_left = seqlen_k; } + if (window_size_left >= 0 && window_size_right < 0) { window_size_right = seqlen_k; } + params.window_size_left = window_size_left; + params.window_size_right = window_size_right; + params.is_seqlens_k_cumulative = true; } @@ -138,7 +147,8 @@ void set_params_dgrad(Flash_bwd_params ¶ms, void *dsoftmax_sum_d, float p_dropout, float softmax_scale, - bool is_causal) { + int window_size_left, + int window_size_right) { set_params_fprop(params, b, seqlen_q, seqlen_k, seqlen_q_rounded, seqlen_k_rounded, h, h_k, d, d_rounded, @@ -149,7 +159,8 @@ void set_params_dgrad(Flash_bwd_params ¶ms, softmax_lse_d, p_dropout, softmax_scale, - is_causal); + window_size_left, + window_size_right); // Set the pointers and strides. params.do_ptr = dout.data_ptr(); @@ -242,6 +253,8 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size const float p_dropout, const float softmax_scale, bool is_causal, + const int window_size_left, + int window_size_right, const bool return_softmax, c10::optional gen_) { @@ -281,10 +294,11 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); if (seqlen_q == 1) { is_causal = false; } // causal=true is the same as causal=false in this case + if (is_causal) { window_size_right = 0; } // Faster to transpose q from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d) in this case // H/t Daniel Haziza - const int seqlenq_ngroups_swapped = seqlen_q == 1 && num_heads > num_heads_k && p_dropout == 0.f && head_size_og % 8 == 0; + const int seqlenq_ngroups_swapped = seqlen_q == 1 && num_heads > num_heads_k && window_size_left < 0 && window_size_right < 0 && p_dropout == 0.f && head_size_og % 8 == 0; if (seqlenq_ngroups_swapped) { const int ngroups = num_heads / num_heads_k; q = q.reshape({batch_size, num_heads_k, ngroups, head_size_og}).transpose(1, 2); @@ -353,7 +367,8 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size softmax_lse.data_ptr(), p_dropout, softmax_scale, - is_causal); + window_size_left, + window_size_right); // This needs to match with run_mha_fwd_splitkv_dispatch const int block_n = head_size <= 64 ? 256 : (head_size <= 128 ? 
128 : 64); @@ -421,9 +436,12 @@ mha_varlen_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q const float softmax_scale, const bool zero_tensors, const bool is_causal, + const int window_size_left, + int window_size_right, const bool return_softmax, c10::optional gen_) { + if (is_causal) { window_size_right = 0; } auto dprops = at::cuda::getCurrentDeviceProperties(); // bool is_sm75 = dprops->major == 7 && dprops->minor == 5; bool is_sm8x = dprops->major == 8 && dprops->minor >= 0; @@ -534,7 +552,8 @@ mha_varlen_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q softmax_lse.data_ptr(), p_dropout, softmax_scale, - is_causal); + window_size_left, + window_size_right); // number of times random will be generated per thread, to offset philox counter in thc random // state @@ -600,8 +619,12 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si const float p_dropout, // probability to drop const float softmax_scale, const bool is_causal, + const int window_size_left, + int window_size_right, c10::optional gen_, c10::optional &rng_state) { + + if (is_causal) { window_size_right = 0; } auto dprops = at::cuda::getCurrentDeviceProperties(); // bool is_sm75 = dprops->major == 7 && dprops->minor == 5; bool is_sm8x = dprops->major == 8 && dprops->minor >= 0; @@ -748,7 +771,8 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si softmax_d.data_ptr(), p_dropout, softmax_scale, - is_causal); + window_size_left, + window_size_right); auto launch = &run_mha_bwd; // launch(params, stream, /*configure=*/true); @@ -804,9 +828,12 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size const float softmax_scale, const bool zero_tensors, const bool is_causal, + const int window_size_left, + int window_size_right, c10::optional gen_, - c10::optional &rng_state -) { + c10::optional &rng_state) { + + if (is_causal) { window_size_right = 0; } auto dprops = at::cuda::getCurrentDeviceProperties(); // bool is_sm75 = dprops->major == 7 && dprops->minor == 5; bool is_sm8x = dprops->major == 8 && dprops->minor >= 0; @@ -969,7 +996,8 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size softmax_d.data_ptr(), p_dropout, softmax_scale, - is_causal); + window_size_left, + window_size_right); auto launch = &run_mha_bwd; // launch(params, stream, /*configure=*/true); @@ -1019,6 +1047,8 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he c10::optional &out_, // batch_size x seqlen_q x num_heads x head_size const float softmax_scale, bool is_causal, + const int window_size_left, + int window_size_right, bool is_rotary_interleaved, // if true, rotary combines indices 0 & 1, else indices 0 & rotary_dim / 2 int num_splits ) { @@ -1059,10 +1089,11 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); if (seqlen_q == 1) { is_causal = false; } // causal=true is the same as causal=false in this case + if (is_causal) { window_size_right = 0; } // Faster to transpose q from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d) in this case // H/t Daniel Haziza - const int seqlenq_ngroups_swapped = seqlen_q == 1 && num_heads > num_heads_k && head_size_og % 8 == 0; + const int seqlenq_ngroups_swapped = seqlen_q == 1 && num_heads > num_heads_k && window_size_left < 0 && window_size_right < 0 && head_size_og % 8 == 0; if (seqlenq_ngroups_swapped) { const int 
ngroups = num_heads / num_heads_k; q = q.reshape({batch_size, num_heads_k, ngroups, head_size_og}).transpose(1, 2); @@ -1125,7 +1156,8 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he softmax_lse.data_ptr(), /*p_dropout=*/0.f, softmax_scale, - is_causal); + window_size_left, + window_size_right); at::Tensor k, v, k_padded, v_padded; if (k_.has_value()) { diff --git a/csrc/flash_attn/src/flash.h b/csrc/flash_attn/src/flash.h index e04507f0a..81f33d9a3 100644 --- a/csrc/flash_attn/src/flash.h +++ b/csrc/flash_attn/src/flash.h @@ -105,6 +105,9 @@ struct Flash_fwd_params : public Qkv_params { float rp_dropout; float scale_softmax_rp_dropout; + // Local window size + int window_size_left, window_size_right; + // Random state. at::PhiloxCudaState philox_args; diff --git a/csrc/flash_attn/src/flash_bwd_kernel.h b/csrc/flash_attn/src/flash_bwd_kernel.h index 3a0a84731..69dde7ec2 100644 --- a/csrc/flash_attn/src/flash_bwd_kernel.h +++ b/csrc/flash_attn/src/flash_bwd_kernel.h @@ -422,7 +422,7 @@ inline __device__ void convert_dKV(const Params ¶ms) { //////////////////////////////////////////////////////////////////////////////////////////////////// -template +template inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const int bidb, const int bidh, const int n_block) { using Element = typename Kernel_traits::Element; @@ -447,6 +447,9 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in if (n_block * kBlockN >= binfo.actual_seqlen_k || binfo.actual_seqlen_q == 0) return; int m_block_max = cute::ceil_div(binfo.actual_seqlen_q, kBlockM); + if (Is_local) { + m_block_max = std::min(m_block_max, cute::ceil_div((n_block + 1) * kBlockN + binfo.actual_seqlen_q - binfo.actual_seqlen_k + params.window_size_left, kBlockM)); + } const index_t row_offset_q = binfo.q_offset(params.q_batch_stride, params.q_row_stride, bidb) + (m_block_max - 1) * kBlockM * params.q_row_stride + bidh * params.q_head_stride; @@ -655,14 +658,53 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in tdQgdQaccum.data() = tdQgdQaccum.data() + kBlockM * params.h * params.d_rounded; int m_block = m_block_max - 1; - int m_block_min = !Is_causal ? 0 : std::max(0, (n_block * kBlockN + binfo.actual_seqlen_q - binfo.actual_seqlen_k) / kBlockM); - // We're guaranteed that m_block_min <= m_block: + int m_block_min = (!Is_causal && !Is_local) + ? 0 + : std::max(0, (n_block * kBlockN + binfo.actual_seqlen_q - binfo.actual_seqlen_k - params.window_size_right) / kBlockM); + // If not local, we're guaranteed that m_block_min <= m_block: // We checked earlier that n_block * kBlockN < actual_seqlen_k, so in the causal case, // n_block * kBlockN + binfo.actual_seqlen_q - binfo.actual_seqlen_k < actual_seqlen_q. // So m_block_min <= (actual_seqlen_q - 1) / kBlockM. // Recall that m_block_max = cute::ceil_div(binfo.actual_seqlen_q, kBlockM) = (actual_seqlen_q + kBlockM - 1) / kBlockM. // So m_block_m - 1 = (actual_seqlen_q - 1) / kBlockM. // We conclude that m_block_min <= m_block, so we will always have at least 1 iteration of the for loop. + // However, if local, then this possible to have some blocks of K & V not attending to any query. + // We might need to exit early and write 0 to dK and dV for those blocks. + // Otherwise we get wrong result for the case where we don't enter the for loop. + // And we might read OOB elements from gQ and gdO. 
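With a local window, a given block of K and V may be visible to no query block at all, which is why the early exit below is needed. A Python sketch of the query-block range computed in this kernel for key block n_block (an invented helper that mirrors the m_block_min / m_block_max logic above, up to integer-division rounding at the edges):

    def visible_query_blocks(n_block, block_m, block_n, seqlen_q, seqlen_k,
                             window_left, window_right):
        def ceil_div(a, b):
            return (a + b - 1) // b

        m_block_max = min(
            ceil_div(seqlen_q, block_m),
            ceil_div((n_block + 1) * block_n + seqlen_q - seqlen_k + window_left, block_m),
        )
        m_block_min = max(
            0, (n_block * block_n + seqlen_q - seqlen_k - window_right) // block_m
        )
        # If m_block_min >= m_block_max, no query attends this key block; the kernel
        # then just writes zeros to its dK / dV tiles and returns.
        return m_block_min, m_block_max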
+ if (Is_local && m_block < m_block_min) { + const index_t row_offset_dk = binfo.k_offset(params.dk_batch_stride, params.dk_row_stride, bidb) + + n_block * kBlockN * params.dk_row_stride + bidh * params.dk_head_stride; + const index_t row_offset_dv = binfo.k_offset(params.dv_batch_stride, params.dv_row_stride, bidb) + + n_block * kBlockN * params.dv_row_stride + bidh * params.dv_head_stride; + Tensor gdK = make_tensor(make_gmem_ptr(reinterpret_cast(params.dk_ptr) + row_offset_dk), + Shape, Int>{}, + make_stride(params.dk_row_stride, _1{})); + Tensor gdV = make_tensor(make_gmem_ptr(reinterpret_cast(params.dv_ptr) + row_offset_dv), + Shape, Int>{}, + make_stride(params.dv_row_stride, _1{})); + typename Kernel_traits::GmemTiledCopydKV gmem_tiled_copy_dKV; + auto gmem_thr_copy_dKV = gmem_tiled_copy_dKV.get_thread_slice(tidx); + Tensor tdKgdK = gmem_thr_copy_dKV.partition_D(gdK); + Tensor tdVgdV = gmem_thr_copy_dKV.partition_D(gdV); + Tensor tdKrdK = make_tensor(shape(tdKgdK)); + Tensor tdVrdV = make_tensor(shape(tdVgdV)); + clear(tdKrdK); + clear(tdVrdV); + Tensor cdKV = make_identity_tensor(make_shape(size<0>(gdK), size<1>(gdK))); // (BLK_N,BLK_K) -> (blk_n,blk_k) + Tensor tdKVcdKV = gmem_thr_copy_dKV.partition_D(cdKV); + Tensor tdKVpdKV = make_tensor(make_shape(size<2>(tdKgdK))); + #pragma unroll + for (int k = 0; k < size(tdKVpdKV); ++k) { tdKVpdKV(k) = get<1>(tdKVcdKV(0, 0, k)) < params.d; } + // Clear_OOB_K must be false since we don't want to write zeros to gmem + flash::copy( + gmem_tiled_copy_dKV, tdKrdK, tdKgdK, tdKVcdKV, tdKVpdKV, binfo.actual_seqlen_k - n_block * kBlockN + ); + flash::copy( + gmem_tiled_copy_dKV, tdVrdV, tdVgdV, tdKVcdKV, tdKVpdKV, binfo.actual_seqlen_k - n_block * kBlockN + ); + return; + } if (Double_buffer && m_block % 2 == 1) { // Double buffer for sQ tQsQ.data() = tQsQ.data() + size(sQ); @@ -777,12 +819,12 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in // However, it's possible that the values in acc_s are so large that they overflow // when we multiply with dP and convert to fp16, resulting in Inf in dS and NaNs in dQ. // So we need to mask out the elements beyond actual_seqlen_k. - if (!Is_causal) { + if (!Is_causal && !Is_local) { if (!Is_even_MN && (n_block + 1) * kBlockN >= binfo.actual_seqlen_k) { flash::apply_mask(scores, binfo.actual_seqlen_k, n_block * kBlockN + (tidx / 32 / AtomLayoutMS) * MMA_N_SdP * 16); } - } else { + } else if (Is_causal) { // Putting this causal masking right after acc_s is *much* slower for some reason. // TD [2023-08-16]: We need the 2nd condition because if seqlen_q is long and seqlen_k is short // (e.g., 256 and 2), the 2nd block of seqlen_q (from 128 to 255), we're not doing causal masking. 
@@ -795,6 +837,16 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in // binfo.actual_seqlen_k, m_block * kBlockM + (tidx / 32) % AtomLayoutMS * 16 + (tidx % 32) / 4, AtomLayoutMS * 16); } + } else if (Is_local) { + if (m_block * kBlockM < (n_block + 1) * kBlockN + binfo.actual_seqlen_q - binfo.actual_seqlen_k - params.window_size_right + || (m_block + 1) * kBlockM >= n_block * kBlockN + binfo.actual_seqlen_q - binfo.actual_seqlen_k + params.window_size_left + || (!Is_even_MN && (n_block + 1) * kBlockN >= binfo.actual_seqlen_k)) { + flash::apply_mask_local(scores, n_block * kBlockN + (tidx / 32 / AtomLayoutMS) * MMA_N_SdP * 16, + binfo.actual_seqlen_k, m_block * kBlockM + get<0>(taccScS_row(0)), + binfo.actual_seqlen_q, AtomLayoutMS * 16, + params.window_size_left, params.window_size_right); + } + } // if (cute::thread(32, 0)) { print(scores); } // Compute the exponential value. @@ -1510,7 +1562,7 @@ inline __device__ void compute_dq_dk_dv(const Params ¶ms) { //////////////////////////////////////////////////////////////////////////////////////////////////// -template +template inline __device__ void compute_dq_dk_dv_seqk_parallel(const Params ¶ms) { const int n_block = blockIdx.x; @@ -1519,7 +1571,7 @@ inline __device__ void compute_dq_dk_dv_seqk_parallel(const Params ¶ms) { // The block index for the head. const int bidh = blockIdx.z; - compute_dq_dk_dv_1colblock(params, bidb, bidh, n_block); + compute_dq_dk_dv_1colblock(params, bidb, bidh, n_block); } //////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/csrc/flash_attn/src/flash_bwd_launch_template.h b/csrc/flash_attn/src/flash_bwd_launch_template.h index d13f1d53b..a2a41679d 100644 --- a/csrc/flash_attn/src/flash_bwd_launch_template.h +++ b/csrc/flash_attn/src/flash_bwd_launch_template.h @@ -23,9 +23,10 @@ __global__ void flash_bwd_dq_dk_dv_loop_kernel(Flash_bwd_params params) { flash::compute_dq_dk_dv(params); } -template +template __global__ void flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel(Flash_bwd_params params) { - flash::compute_dq_dk_dv_seqk_parallel(params); + static_assert(!(Is_causal && Is_local)); // If Is_local is true, Is_causal should be false + flash::compute_dq_dk_dv_seqk_parallel(params); } template @@ -62,16 +63,19 @@ void run_flash_bwd_seqk_parallel(Flash_bwd_params ¶ms, cudaStream_t stream, BOOL_SWITCH(params.is_causal, IsCausalConst, [&] { BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] { BOOL_SWITCH(is_even_K, IsEvenKConst, [&] { - // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates. - // If head dim > 128, set IsEvenMNConst to false to reduce number of templates - auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel; - // auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel; - if (smem_size_dq_dk_dv >= 48 * 1024) { - C10_CUDA_CHECK(cudaFuncSetAttribute( - kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_dq_dk_dv)); - } - kernel<<>>(params); - C10_CUDA_KERNEL_LAUNCH_CHECK(); + BOOL_SWITCH(params.window_size_left >= 0 || params.window_size_right >= 0, Is_local, [&] { + // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates. 
+ // If head dim > 128, set IsEvenMNConst to false to reduce number of templates + // If Is_local, set Is_causal to false + auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel; + // auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel; + if (smem_size_dq_dk_dv >= 48 * 1024) { + C10_CUDA_CHECK(cudaFuncSetAttribute( + kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_dq_dk_dv)); + } + kernel<<>>(params); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + }); }); }); }); diff --git a/csrc/flash_attn/src/flash_fwd_kernel.h b/csrc/flash_attn/src/flash_fwd_kernel.h index 68d613431..312b4dda6 100644 --- a/csrc/flash_attn/src/flash_fwd_kernel.h +++ b/csrc/flash_attn/src/flash_fwd_kernel.h @@ -71,7 +71,7 @@ inline __device__ void write_softmax_to_gmem( //////////////////////////////////////////////////////////////////////////////////////////////////// -template +template inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bidb, const int bidh, const int m_block) { using Element = typename Kernel_traits::Element; @@ -93,16 +93,17 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi const BlockInfo binfo(params, bidb); if (m_block * kBlockM >= binfo.actual_seqlen_q || binfo.actual_seqlen_k == 0) return; + const int n_block_min = !Is_local ? 0 : std::max(0, (m_block * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q - params.window_size_left) / kBlockN); int n_block_max = cute::ceil_div(binfo.actual_seqlen_k, kBlockN); - if (Is_causal) { + if (Is_causal || Is_local) { n_block_max = std::min(n_block_max, - cute::ceil_div((m_block + 1) * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q, kBlockN)); + cute::ceil_div((m_block + 1) * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q + params.window_size_right, kBlockN)); // if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) { // printf("m_block = %d, n_block_max = %d\n", m_block, n_block_max); // } // We exit early and write 0 to gO and gLSE. // Otherwise we might read OOB elements from gK and gV. - if (n_block_max <= 0) { + if (n_block_max <= n_block_min) { // Save seed and offset for backward. If we don't have this here, the 0-th thread block might // exit early and no one saves the rng state. if (Is_dropout && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0 && tidx == 0) { @@ -145,6 +146,7 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi return; } } + // if (tidx == 0) { printf("m_block = %d, n_block_min = %d, n_block_max = %d\n", m_block, n_block_min, n_block_max); } // We iterate over the blocks in reverse order. This is because the last block is the only one // that needs masking when we read K and V from global memory. Moreover, iterating in reverse @@ -326,9 +328,9 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi // If not even_N, then seqlen_k might end in the middle of a block. In that case we need to // mask 2 blocks (e.g. when kBlockM == kBlockN), not just 1. - constexpr int n_masking_steps = !Is_causal + constexpr int n_masking_steps = (!Is_causal && !Is_local) ? 1 - : (Is_even_MN ? cute::ceil_div(kBlockM, kBlockN) : cute::ceil_div(kBlockM, kBlockN) + 1); + : ((Is_even_MN && Is_causal) ? 
cute::ceil_div(kBlockM, kBlockN) : cute::ceil_div(kBlockM, kBlockN) + 1); #pragma unroll for (int masking_step = 0; masking_step < n_masking_steps; ++masking_step, --n_block) { Tensor acc_s = partition_fragment_C(tiled_mma, Shape, Int>{}); // (MMA=4, MMA_M, MMA_N) @@ -356,11 +358,11 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi // Reshape acc_s from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N)) Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout())); - // if (cute::thread0()) { print(scores); } + // if (cute::thread0()) { print_tensor(scores); } // We don't put the masking before the matmul S = Q K^T because we don't clear sK // for rows outside actual_seqlen_k. So those rows could have Inf / NaN, and the matmul // can produce Inf / NaN. - if (!Is_causal) { + if (!Is_causal && !Is_local) { if (!Is_even_MN) { flash::apply_mask(scores, binfo.actual_seqlen_k - n_block * kBlockN); } } else { // Tensor caccS = make_identity_tensor(Shape, Int>{}); // (BLK_M,BLK_N) -> (blk_m,blk_n) @@ -374,18 +376,21 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi // Idk why it's get<1> and not get<0> of the stride. // if (cute::thread0()) { print(idx_row.layout()); print(stride<1>(idx_row)); printf("stride = %d \n", get<1>(stride<1>(idx_row))); } // I can't get the stride from idx_row - flash::apply_mask_causal(scores, n_block * kBlockN, binfo.actual_seqlen_k, - // m_block * kBlockM + get<0>(idx_row(0)), - m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, - binfo.actual_seqlen_q, - kNWarps * 16); - // m_block * kBlockM + (tidx / 32) * 16, kNWarps * 16); - // m_block * kBlockM + (tidx / 32) * (kBlockM / kNWarps), 16); + flash::apply_mask_local( + scores, n_block * kBlockN, binfo.actual_seqlen_k, + // m_block * kBlockM + get<0>(idx_row(0)), + m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, + binfo.actual_seqlen_q, kNWarps * 16, + params.window_size_left, params.window_size_right + // m_block * kBlockM + (tidx / 32) * 16, kNWarps * 16 + // m_block * kBlockM + (tidx / 32) * (kBlockM / kNWarps), 16 + ); + // if (cute::thread0()) { print_tensor(scores); } } flash::cp_async_wait<0>(); __syncthreads(); - if (n_block > 0) { + if (n_block > n_block_min) { // Advance gK tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride)); flash::copy(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV); @@ -396,8 +401,8 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi // TODO: when we have key_padding_mask we'll need to Check_inf masking_step == 0 - ? softmax_rescale_o(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2) - : softmax_rescale_o(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2); + ? 
softmax_rescale_o(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2) + : softmax_rescale_o(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2); // Convert scores from fp32 to fp16/bf16 Tensor rP = flash::convert_type(scores); @@ -426,14 +431,14 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi // if (cute::thread0()) { print(scores); } // This check is at the end of the loop since we always have at least 1 iteration - if (n_masking_steps > 1 && n_block <= 0) { + if (n_masking_steps > 1 && n_block <= n_block_min) { --n_block; break; } } // These are the iterations where we don't need masking on S - for (; n_block >= 0; --n_block) { + for (; n_block >= n_block_min; --n_block) { Tensor acc_s = partition_fragment_C(tiled_mma, Shape, Int>{}); // (MMA=4, MMA_M, MMA_N) clear(acc_s); flash::cp_async_wait<0>(); @@ -450,7 +455,7 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi flash::cp_async_wait<0>(); __syncthreads(); - if (n_block > 0) { + if (n_block > n_block_min) { // Advance gK tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride)); flash::copy(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV); @@ -461,7 +466,15 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi // Reshape acc_s from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N)) Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout())); - softmax_rescale_o(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2); + if (Is_local && n_block * kBlockN < (m_block + 1) * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q + params.window_size_right) { + flash::apply_mask_local( + scores, n_block * kBlockN, binfo.actual_seqlen_k, + m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, + binfo.actual_seqlen_q, kNWarps * 16, + params.window_size_left, params.window_size_right + ); + } + softmax_rescale_o(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2); Tensor rP = flash::convert_type(scores); // Reshape rP from (nrow=(2, MMA_M), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_M, MMA_N / 2) @@ -568,7 +581,7 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi //////////////////////////////////////////////////////////////////////////////////////////////////// -template +template inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, const int bidb, const int bidh, const int m_block, const int n_split_idx, const int num_n_splits) { using Element = typename Kernel_traits::Element; @@ -599,11 +612,13 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons if (m_block * kBlockM >= binfo.actual_seqlen_q) return; const int n_blocks_per_split = ((params.seqlen_k + kBlockN - 1) / kBlockN + num_n_splits - 1) / num_n_splits; - const int n_block_min = n_split_idx * n_blocks_per_split; + const int n_block_min = !Is_local + ? 
n_split_idx * n_blocks_per_split + : std::max(n_split_idx * n_blocks_per_split, (m_block * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q - params.window_size_left) / kBlockN); int n_block_max = std::min(cute::ceil_div(binfo.actual_seqlen_k, kBlockN), (n_split_idx + 1) * n_blocks_per_split); - if (Is_causal) { + if (Is_causal || Is_local) { n_block_max = std::min(n_block_max, - cute::ceil_div((m_block + 1) * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q, kBlockN)); + cute::ceil_div((m_block + 1) * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q + params.window_size_right, kBlockN)); } if (n_block_min >= n_block_max) { // This also covers the case where n_block_max <= 0 // We exit early and write 0 to gOaccum and -inf to gLSEaccum. @@ -842,21 +857,21 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons flash::copy(gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ, binfo.actual_seqlen_q - m_block * kBlockM); } else { - const index_t row_offset_cossin = (binfo.seqlen_k_cache + (Is_causal ? m_block * kBlockM : 0)) * (params.rotary_dim / 2); + const index_t row_offset_cossin = (binfo.seqlen_k_cache + (Is_causal || Is_local ? m_block * kBlockM : 0)) * (params.rotary_dim / 2); // If not causal, all the queries get the same the cos/sin, taken at location seqlen_k_cache. // We do this by setting the row stride of gCos / gSin to 0. Tensor gCos = make_tensor(make_gmem_ptr(reinterpret_cast(params.rotary_cos_ptr) + row_offset_cossin), Shape, Int>{}, - make_stride(Is_causal ? params.rotary_dim / 2 : 0, _1{})); + make_stride(Is_causal || Is_local ? params.rotary_dim / 2 : 0, _1{})); Tensor gSin = make_tensor(make_gmem_ptr(reinterpret_cast(params.rotary_sin_ptr) + row_offset_cossin), Shape, Int>{}, - make_stride(Is_causal ? params.rotary_dim / 2 : 0, _1{})); + make_stride(Is_causal || Is_local ? params.rotary_dim / 2 : 0, _1{})); Tensor gCosCont = make_tensor(make_gmem_ptr(reinterpret_cast(params.rotary_cos_ptr) + row_offset_cossin), Shape, Int>{}, - make_stride(Is_causal ? params.rotary_dim / 2 : 0, _1{})); + make_stride(Is_causal || Is_local ? params.rotary_dim / 2 : 0, _1{})); Tensor gSinCont = make_tensor(make_gmem_ptr(reinterpret_cast(params.rotary_sin_ptr) + row_offset_cossin), Shape, Int>{}, - make_stride(Is_causal ? params.rotary_dim / 2 : 0, _1{})); + make_stride(Is_causal || Is_local ? params.rotary_dim / 2 : 0, _1{})); Tensor tRgCos = gmem_thr_copy_rotary.partition_S(gCos); Tensor tRgSin = gmem_thr_copy_rotary.partition_S(gSin); Tensor tRgCosCont = gmem_thr_copy_rotary_cont.partition_S(gCosCont); @@ -895,9 +910,9 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons // If not even_N, then seqlen_k might end in the middle of a block. In that case we need to // mask 2 blocks (e.g. when kBlockM == kBlockN), not just 1. - constexpr int n_masking_steps = !Is_causal + constexpr int n_masking_steps = (!Is_causal && !Is_local) ? 1 - : (Is_even_MN ? cute::ceil_div(kBlockM, kBlockN) : cute::ceil_div(kBlockM, kBlockN) + 1); + : ((Is_even_MN && Is_causal) ? cute::ceil_div(kBlockM, kBlockN) : cute::ceil_div(kBlockM, kBlockN) + 1); #pragma unroll for (int masking_step = 0; masking_step < n_masking_steps; ++masking_step, --n_block) { Tensor acc_s = partition_fragment_C(tiled_mma, Shape, Int>{}); // (MMA=4, MMA_M, MMA_N) @@ -929,13 +944,14 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons // We don't put the masking before the matmul S = Q K^T because we don't clear sK // for rows outside actual_seqlen_k. 
So those rows could have Inf / NaN, and the matmul // can produce Inf / NaN. - if (!Is_causal) { + if (!Is_causal && !Is_local) { if (!Is_even_MN) { flash::apply_mask(scores, binfo.actual_seqlen_k - n_block * kBlockN); } } else { - flash::apply_mask_causal(scores, n_block * kBlockN, binfo.actual_seqlen_k, - m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, - binfo.actual_seqlen_q, - kNWarps * 16); + flash::apply_mask_local(scores, n_block * kBlockN, binfo.actual_seqlen_k, + m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, + binfo.actual_seqlen_q, kNWarps * 16, + params.window_size_left, params.window_size_right + ); } flash::cp_async_wait<0>(); @@ -954,8 +970,8 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons // We have key_padding_mask so we'll need to Check_inf masking_step == 0 - ? softmax_rescale_o(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2) - : softmax_rescale_o(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2); + ? softmax_rescale_o(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2) + : softmax_rescale_o(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2); // if (cute::thread0()) { print(scores_max); print(scores_sum); print(scores); } // Convert scores from fp32 to fp16/bf16 @@ -1003,7 +1019,15 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons // Reshape acc_s from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N)) Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout())); - softmax_rescale_o(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2); + if (Is_local && n_block * kBlockN < (m_block + 1) * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q + params.window_size_right) { + flash::apply_mask_local( + scores, n_block * kBlockN, binfo.actual_seqlen_k, + m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, + binfo.actual_seqlen_q, kNWarps * 16, + params.window_size_left, params.window_size_right + ); + } + softmax_rescale_o(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2); Tensor rP = flash::convert_type(scores); // Reshape rP from (nrow=(2, MMA_M), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_M, MMA_N / 2) @@ -1106,7 +1130,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons //////////////////////////////////////////////////////////////////////////////////////////////////// -template +template inline __device__ void compute_attn(const Params ¶ms) { const int m_block = blockIdx.x; // The block index for the batch. @@ -1122,12 +1146,12 @@ inline __device__ void compute_attn(const Params ¶ms) { // the attention matrix. This way, as long as we have the batch, head, and the location of // the 16 x 32 block within the attention matrix, we can generate the exact same dropout pattern. - flash::compute_attn_1rowblock(params, bidb, bidh, m_block); + flash::compute_attn_1rowblock(params, bidb, bidh, m_block); } //////////////////////////////////////////////////////////////////////////////////////////////////// -template +template inline __device__ void compute_attn_splitkv(const Params ¶ms) { const int m_block = blockIdx.x; // The block index for the batch. @@ -1136,7 +1160,7 @@ inline __device__ void compute_attn_splitkv(const Params ¶ms) { const int bidh = Split ? blockIdx.z - bidb * params.h : blockIdx.z; const int n_split_idx = Split ? blockIdx.y : 0; const int num_n_splits = Split ? 
gridDim.y : 1; - flash::compute_attn_1rowblock_splitkv(params, bidb, bidh, m_block, n_split_idx, num_n_splits); + flash::compute_attn_1rowblock_splitkv(params, bidb, bidh, m_block, n_split_idx, num_n_splits); } //////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/csrc/flash_attn/src/flash_fwd_launch_template.h b/csrc/flash_attn/src/flash_fwd_launch_template.h index 4c336ec94..fbf3cda22 100644 --- a/csrc/flash_attn/src/flash_fwd_launch_template.h +++ b/csrc/flash_attn/src/flash_fwd_launch_template.h @@ -10,14 +10,15 @@ #include "flash.h" #include "flash_fwd_kernel.h" -template +template __global__ void flash_fwd_kernel(Flash_fwd_params params) { - flash::compute_attn(params); + static_assert(!(Is_causal && Is_local)); // If Is_local is true, Is_causal should be false + flash::compute_attn(params); } -template +template __global__ void flash_fwd_splitkv_kernel(Flash_fwd_params params) { - flash::compute_attn_splitkv(params); + flash::compute_attn_splitkv(params); } template @@ -42,23 +43,25 @@ void run_flash_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { const bool return_softmax = params.p_ptr != nullptr; BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] { BOOL_SWITCH(is_even_K, IsEvenKConst, [&] { - BOOL_SWITCH(return_softmax, ReturnSoftmaxConst, [&] { - // Will only return softmax if dropout, to reduce compilation time. - // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates. - // If return_softmax, set IsEvenMNConst to false to reduce number of templates - // If head dim > 128, set IsEvenMNConst to false to reduce number of templates - auto kernel = &flash_fwd_kernel; - // auto kernel = &flash_fwd_kernel; - if (smem_size >= 48 * 1024) { - C10_CUDA_CHECK(cudaFuncSetAttribute( - kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); - } - // int ctas_per_sm; - // cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor( - // &ctas_per_sm, kernel, Kernel_traits::kNThreads, smem_size); - // printf("smem_size = %d, CTAs per SM = %d\n", int(smem_size), ctas_per_sm); - kernel<<>>(params); - C10_CUDA_KERNEL_LAUNCH_CHECK(); + BOOL_SWITCH(params.window_size_left >= 0 || params.window_size_right >= 0, Is_local, [&] { + BOOL_SWITCH(return_softmax, ReturnSoftmaxConst, [&] { + // Will only return softmax if dropout, to reduce compilation time. + // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates. 
+ // If return_softmax, set IsEvenMNConst to false to reduce number of templates + // If head dim > 128, set IsEvenMNConst to false to reduce number of templates + // If Is_local, set Is_causal to false + auto kernel = &flash_fwd_kernel; + if (smem_size >= 48 * 1024) { + C10_CUDA_CHECK(cudaFuncSetAttribute( + kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + } + // int ctas_per_sm; + // cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor( + // &ctas_per_sm, kernel, Kernel_traits::kNThreads, smem_size); + // printf("smem_size = %d, CTAs per SM = %d\n", int(smem_size), ctas_per_sm); + kernel<<>>(params); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + }); }); }); }); @@ -76,19 +79,22 @@ void run_flash_splitkv_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { BOOL_SWITCH(params.is_causal, Is_causal, [&] { BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] { BOOL_SWITCH(is_even_K, IsEvenKConst, [&] { - BOOL_SWITCH(params.num_splits > 1, Split, [&] { - BOOL_SWITCH(params.knew_ptr != nullptr, Append_KV, [&] { - // If Append_KV, then we must have seqlen_offsets, which means cu_seqlens_k != nullptr. - // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates. - auto kernel = &flash_fwd_splitkv_kernel; - // auto kernel = &flash_fwd_splitkv_kernel; - // auto kernel = &flash_fwd_splitkv_kernel; - if (smem_size >= 48 * 1024) { - C10_CUDA_CHECK(cudaFuncSetAttribute( - kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); - } - kernel<<>>(params); - C10_CUDA_KERNEL_LAUNCH_CHECK(); + BOOL_SWITCH(params.window_size_left >= 0 || params.window_size_right >= 0, Is_local, [&] { + BOOL_SWITCH(params.num_splits > 1, Split, [&] { + BOOL_SWITCH(params.knew_ptr != nullptr, Append_KV, [&] { + // If Append_KV, then we must have seqlen_offsets, which means cu_seqlens_k != nullptr. + // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates. 
+ // If Is_local, set Is_causal to false + auto kernel = &flash_fwd_splitkv_kernel; + // auto kernel = &flash_fwd_splitkv_kernel; + // auto kernel = &flash_fwd_splitkv_kernel; + if (smem_size >= 48 * 1024) { + C10_CUDA_CHECK(cudaFuncSetAttribute( + kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + } + kernel<<>>(params); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + }); }); }); }); diff --git a/csrc/flash_attn/src/softmax.h b/csrc/flash_attn/src/softmax.h index 987f5efe7..221a37be8 100644 --- a/csrc/flash_attn/src/softmax.h +++ b/csrc/flash_attn/src/softmax.h @@ -139,10 +139,11 @@ inline __device__ void apply_mask(Tensor &tensor, const int max_ } } -template -inline __device__ void apply_mask_causal(Tensor &tensor, const int col_idx_offset_, - const int max_seqlen_k, const int row_idx_offset_, - const int max_seqlen_q, const int warp_row_stride) { +template +inline __device__ void apply_mask_local(Tensor &tensor, const int col_idx_offset_, + const int max_seqlen_k, const int row_idx_offset_, + const int max_seqlen_q, const int warp_row_stride, + const int window_size_left, const int window_size_right) { // tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N)) static_assert(Layout::rank == 2, "Only support 2D Tensor"); const int lane_id = threadIdx.x % 32; @@ -155,14 +156,15 @@ inline __device__ void apply_mask_causal(Tensor &tensor, const i #pragma unroll for (int i = 0; i < size<0, 0>(tensor); ++i) { const int row_idx = row_idx_base + i * 8; - const int col_idx_limit = std::min(max_seqlen_k, row_idx + 1 + max_seqlen_k - max_seqlen_q); + const int col_idx_limit_left = std::max(0, row_idx + max_seqlen_k - max_seqlen_q - window_size_left); + const int col_idx_limit_right = std::min(max_seqlen_k, row_idx + 1 + max_seqlen_k - max_seqlen_q + window_size_right); #pragma unroll for (int nj = 0; nj < size<1, 1>(tensor); ++nj) { const int col_idx_base = col_idx_offset + nj * 8; #pragma unroll for (int j = 0; j < size<1, 0>(tensor); ++j) { const int col_idx = col_idx_base + j; - if (col_idx >= col_idx_limit) { + if (col_idx >= col_idx_limit_right || (HasWSLeft && col_idx < col_idx_limit_left)) { tensor(make_coord(i, mi), make_coord(j, nj)) = -INFINITY; } } @@ -176,6 +178,15 @@ inline __device__ void apply_mask_causal(Tensor &tensor, const i } } +template +inline __device__ void apply_mask_causal(Tensor &tensor, const int col_idx_offset_, + const int max_seqlen_k, const int row_idx_offset_, + const int max_seqlen_q, const int warp_row_stride) { + // Causal masking is equivalent to local masking with window_size_left = infinity and window_size_right = 0 + apply_mask_local(tensor, col_idx_offset_, max_seqlen_k, row_idx_offset_, + max_seqlen_q, warp_row_stride, -1, 0); +} + template inline __device__ void apply_mask_causal_w_idx( Tensor &tensor, Tensor const &idx_rowcol, diff --git a/flash_attn/flash_attn_interface.py b/flash_attn/flash_attn_interface.py index 1fe3b89f0..eba913ad6 100644 --- a/flash_attn/flash_attn_interface.py +++ b/flash_attn/flash_attn_interface.py @@ -41,11 +41,21 @@ def _get_block_size(device, head_dim, is_dropout, is_causal): return (128, 64) if is_sm80 else (64, 64) -def _flash_attn_forward(q, k, v, dropout_p, softmax_scale, causal, return_softmax): +def _flash_attn_forward(q, k, v, dropout_p, softmax_scale, causal, window_size, return_softmax): maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x q, k, v = [maybe_contiguous(x) for x in (q, k, v)] out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = flash_attn_cuda.fwd( - q, k, v, None, dropout_p, 
softmax_scale, causal, return_softmax, None + q, + k, + v, + None, + dropout_p, + softmax_scale, + causal, + window_size[0], + window_size[1], + return_softmax, + None, ) return out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state @@ -61,6 +71,7 @@ def _flash_attn_varlen_forward( dropout_p, softmax_scale, causal, + window_size, return_softmax, ): maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x @@ -78,6 +89,8 @@ def _flash_attn_varlen_forward( softmax_scale, False, causal, + window_size[0], + window_size[1], return_softmax, None, ) @@ -87,7 +100,20 @@ def _flash_attn_varlen_forward( def _flash_attn_backward( - dout, q, k, v, out, softmax_lse, dq, dk, dv, dropout_p, softmax_scale, causal, rng_state=None + dout, + q, + k, + v, + out, + softmax_lse, + dq, + dk, + dv, + dropout_p, + softmax_scale, + causal, + window_size, + rng_state=None, ): maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x # dq, dk, dv are allocated by us so they should already be contiguous @@ -105,6 +131,8 @@ def _flash_attn_backward( dropout_p, softmax_scale, causal, + window_size[0], + window_size[1], None, rng_state, ) @@ -128,6 +156,7 @@ def _flash_attn_varlen_backward( dropout_p, softmax_scale, causal, + window_size, rng_state=None, ): maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x @@ -151,6 +180,8 @@ def _flash_attn_varlen_backward( softmax_scale, False, causal, + window_size[0], + window_size[1], None, rng_state, ) @@ -161,7 +192,7 @@ def _flash_attn_varlen_backward( class FlashAttnQKVPackedFunc(torch.autograd.Function): @staticmethod - def forward(ctx, qkv, dropout_p, softmax_scale, causal, return_softmax): + def forward(ctx, qkv, dropout_p, softmax_scale, causal, window_size, return_softmax): if softmax_scale is None: softmax_scale = qkv.shape[-1] ** (-0.5) out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_forward( @@ -171,12 +202,14 @@ def forward(ctx, qkv, dropout_p, softmax_scale, causal, return_softmax): dropout_p, softmax_scale, causal=causal, + window_size=window_size, return_softmax=return_softmax and dropout_p > 0, ) ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state) ctx.dropout_p = dropout_p ctx.softmax_scale = softmax_scale ctx.causal = causal + ctx.window_size = window_size return out if not return_softmax else (out, softmax_lse, S_dmask) @staticmethod @@ -197,15 +230,26 @@ def backward(ctx, dout, *args): ctx.dropout_p, ctx.softmax_scale, ctx.causal, + ctx.window_size, rng_state=rng_state, ) dqkv = dqkv[..., : dout.shape[-1]] # We could have padded the head dimension - return dqkv, None, None, None, None + return dqkv, None, None, None, None, None class FlashAttnVarlenQKVPackedFunc(torch.autograd.Function): @staticmethod - def forward(ctx, qkv, cu_seqlens, max_seqlen, dropout_p, softmax_scale, causal, return_softmax): + def forward( + ctx, + qkv, + cu_seqlens, + max_seqlen, + dropout_p, + softmax_scale, + causal, + window_size, + return_softmax, + ): if softmax_scale is None: softmax_scale = qkv.shape[-1] ** (-0.5) out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_varlen_forward( @@ -219,6 +263,7 @@ def forward(ctx, qkv, cu_seqlens, max_seqlen, dropout_p, softmax_scale, causal, dropout_p, softmax_scale, causal=causal, + window_size=window_size, return_softmax=return_softmax and dropout_p > 0, ) ctx.save_for_backward(q, k, v, out_padded, softmax_lse, cu_seqlens, rng_state) @@ -226,6 +271,7 @@ def forward(ctx, qkv, cu_seqlens, max_seqlen, dropout_p, softmax_scale, causal, 
ctx.max_seqlen = max_seqlen ctx.softmax_scale = softmax_scale ctx.causal = causal + ctx.window_size = window_size return out if not return_softmax else (out, softmax_lse, S_dmask) @staticmethod @@ -250,15 +296,16 @@ def backward(ctx, dout, *args): ctx.dropout_p, ctx.softmax_scale, ctx.causal, + ctx.window_size, rng_state=rng_state, ) dqkv = dqkv[..., : dout.shape[-1]] # We could have padded the head dimension - return dqkv, None, None, None, None, None, None + return dqkv, None, None, None, None, None, None, None class FlashAttnKVPackedFunc(torch.autograd.Function): @staticmethod - def forward(ctx, q, kv, dropout_p, softmax_scale, causal, return_softmax): + def forward(ctx, q, kv, dropout_p, softmax_scale, causal, window_size, return_softmax): if softmax_scale is None: softmax_scale = q.shape[-1] ** (-0.5) out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_forward( @@ -268,12 +315,14 @@ def forward(ctx, q, kv, dropout_p, softmax_scale, causal, return_softmax): dropout_p, softmax_scale, causal=causal, + window_size=window_size, return_softmax=return_softmax and dropout_p > 0, ) ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state) ctx.dropout_p = dropout_p ctx.softmax_scale = softmax_scale ctx.causal = causal + ctx.window_size = window_size return out if not return_softmax else (out, softmax_lse, S_dmask) @staticmethod @@ -295,11 +344,12 @@ def backward(ctx, dout, *args): ctx.dropout_p, ctx.softmax_scale, ctx.causal, + ctx.window_size, rng_state=rng_state, ) dq = dq[..., : dout.shape[-1]] # We could have padded the head dimension dkv = dkv[..., : dout.shape[-1]] - return dq, dkv, None, None, None, None + return dq, dkv, None, None, None, None, None class FlashAttnVarlenKVPackedFunc(torch.autograd.Function): @@ -315,6 +365,7 @@ def forward( dropout_p, softmax_scale, causal, + window_size, return_softmax, ): if softmax_scale is None: @@ -330,6 +381,7 @@ def forward( dropout_p, softmax_scale, causal=causal, + window_size=window_size, return_softmax=return_softmax and dropout_p > 0, ) ctx.save_for_backward( @@ -340,6 +392,7 @@ def forward( ctx.max_seqlen_k = max_seqlen_k ctx.softmax_scale = softmax_scale ctx.causal = causal + ctx.window_size = window_size return out if not return_softmax else (out, softmax_lse, S_dmask) @staticmethod @@ -365,16 +418,17 @@ def backward(ctx, dout, *args): ctx.dropout_p, ctx.softmax_scale, ctx.causal, + ctx.window_size, rng_state=rng_state, ) dq = dq[..., : dout.shape[-1]] # We could have padded the head dimension dkv = dkv[..., : dout.shape[-1]] - return dq, dkv, None, None, None, None, None, None, None, None + return dq, dkv, None, None, None, None, None, None, None, None, None class FlashAttnFunc(torch.autograd.Function): @staticmethod - def forward(ctx, q, k, v, dropout_p, softmax_scale, causal, return_softmax): + def forward(ctx, q, k, v, dropout_p, softmax_scale, causal, window_size, return_softmax): if softmax_scale is None: softmax_scale = q.shape[-1] ** (-0.5) out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_forward( @@ -384,12 +438,14 @@ def forward(ctx, q, k, v, dropout_p, softmax_scale, causal, return_softmax): dropout_p, softmax_scale, causal=causal, + window_size=window_size, return_softmax=return_softmax and dropout_p > 0, ) ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state) ctx.dropout_p = dropout_p ctx.softmax_scale = softmax_scale ctx.causal = causal + ctx.window_size = window_size return out if not return_softmax else (out, softmax_lse, S_dmask) @staticmethod @@ -409,12 
+465,13 @@ def backward(ctx, dout, *args): ctx.dropout_p, ctx.softmax_scale, ctx.causal, + ctx.window_size, rng_state=rng_state, ) dq = dq[..., : dout.shape[-1]] # We could have padded the head dimension dk = dk[..., : dout.shape[-1]] dv = dv[..., : dout.shape[-1]] - return dq, dk, dv, None, None, None, None, None, None, None, None + return dq, dk, dv, None, None, None, None, None, None, None, None, None class FlashAttnVarlenFunc(torch.autograd.Function): @@ -431,6 +488,7 @@ def forward( dropout_p, softmax_scale, causal, + window_size, return_softmax, ): if softmax_scale is None: @@ -446,6 +504,7 @@ def forward( dropout_p, softmax_scale, causal=causal, + window_size=window_size, return_softmax=return_softmax and dropout_p > 0, ) ctx.save_for_backward( @@ -456,6 +515,7 @@ def forward( ctx.max_seqlen_k = max_seqlen_k ctx.softmax_scale = softmax_scale ctx.causal = causal + ctx.window_size = window_size return out if not return_softmax else (out, softmax_lse, S_dmask) @staticmethod @@ -479,16 +539,22 @@ def backward(ctx, dout, *args): ctx.dropout_p, ctx.softmax_scale, ctx.causal, + ctx.window_size, rng_state=rng_state, ) dq = dq[..., : dout.shape[-1]] # We could have padded the head dimension dk = dk[..., : dout.shape[-1]] dv = dv[..., : dout.shape[-1]] - return dq, dk, dv, None, None, None, None, None, None, None, None + return dq, dk, dv, None, None, None, None, None, None, None, None, None def flash_attn_qkvpacked_func( - qkv, dropout_p=0.0, softmax_scale=None, causal=False, return_attn_probs=False + qkv, + dropout_p=0.0, + softmax_scale=None, + causal=False, + window_size=(-1, -1), # -1 means infinite context window + return_attn_probs=False, ): """dropout_p should be set to 0.0 during evaluation If Q, K, V are already stacked into 1 tensor, this function will be faster than @@ -497,12 +563,16 @@ def flash_attn_qkvpacked_func( For multi-query and grouped-query attention (MQA/GQA), please see flash_attn_kvpacked_func and flash_attn_func. + If window_size != (-1, -1), implements sliding window local attention. Query at position i + will only attend to keys between [i - window_size[0], i + window_size[1]] inclusive. + Arguments: qkv: (batch_size, seqlen, 3, nheads, headdim) dropout_p: float. Dropout probability. softmax_scale: float. The scaling of QK^T before applying softmax. Default to 1 / sqrt(headdim). causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling). + window_size: (left, right). If not (-1, -1), implements sliding window local attention. return_attn_probs: bool. Whether to return the attention probabilities. This option is for testing only. The returned probabilities are not guaranteed to be correct (they might not have the right scaling). @@ -515,11 +585,19 @@ def flash_attn_qkvpacked_func( The output of softmax (possibly with different scaling). It also encodes the dropout pattern (negative means that location was dropped, nonnegative means it was kept). 
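Since the docstring above now documents window_size, a short usage sketch may help (illustrative only; it assumes a CUDA device with fp16 support, and the shapes are arbitrary). Note that window_size=(256, 0) by itself already gives a causal sliding window: each query attends to itself and the 256 previous tokens.

```python
import torch
from flash_attn.flash_attn_interface import flash_attn_qkvpacked_func

# (batch, seqlen, 3, nheads, headdim) packed QKV; shapes chosen only for the example.
qkv = torch.randn(2, 1024, 3, 8, 64, device="cuda", dtype=torch.float16, requires_grad=True)
out = flash_attn_qkvpacked_func(qkv, dropout_p=0.0, window_size=(256, 0))
out.sum().backward()  # sliding-window attention supports the backward pass in this path
```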
""" - return FlashAttnQKVPackedFunc.apply(qkv, dropout_p, softmax_scale, causal, return_attn_probs) + return FlashAttnQKVPackedFunc.apply( + qkv, dropout_p, softmax_scale, causal, window_size, return_attn_probs + ) def flash_attn_kvpacked_func( - q, kv, dropout_p=0.0, softmax_scale=None, causal=False, return_attn_probs=False + q, + kv, + dropout_p=0.0, + softmax_scale=None, + causal=False, + window_size=(-1, -1), # -1 means infinite context window + return_attn_probs=False, ): """dropout_p should be set to 0.0 during evaluation If K, V are already stacked into 1 tensor, this function will be faster than @@ -542,6 +620,10 @@ def flash_attn_kvpacked_func( 1 1 If the row of the mask is all zero, the output will be zero. + If window_size != (-1, -1), implements sliding window local attention. Query at position i + will only attend to keys between + [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive. + Arguments: q: (batch_size, seqlen, nheads, headdim) kv: (batch_size, seqlen, 2, nheads_k, headdim) @@ -549,6 +631,7 @@ def flash_attn_kvpacked_func( softmax_scale: float. The scaling of QK^T before applying softmax. Default to 1 / sqrt(headdim). causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling). + window_size: (left, right). If not (-1, -1), implements sliding window local attention. return_attn_probs: bool. Whether to return the attention probabilities. This option is for testing only. The returned probabilities are not guaranteed to be correct (they might not have the right scaling). @@ -561,11 +644,20 @@ def flash_attn_kvpacked_func( The output of softmax (possibly with different scaling). It also encodes the dropout pattern (negative means that location was dropped, nonnegative means it was kept). """ - return FlashAttnKVPackedFunc.apply(q, kv, dropout_p, softmax_scale, causal, return_attn_probs) + return FlashAttnKVPackedFunc.apply( + q, kv, dropout_p, softmax_scale, causal, window_size, return_attn_probs + ) def flash_attn_func( - q, k, v, dropout_p=0.0, softmax_scale=None, causal=False, return_attn_probs=False + q, + k, + v, + dropout_p=0.0, + softmax_scale=None, + causal=False, + window_size=(-1, -1), # -1 means infinite context window + return_attn_probs=False, ): """dropout_p should be set to 0.0 during evaluation Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads @@ -585,6 +677,10 @@ def flash_attn_func( 1 1 If the row of the mask is all zero, the output will be zero. + If window_size != (-1, -1), implements sliding window local attention. Query at position i + will only attend to keys between + [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive. + Arguments: q: (batch_size, seqlen, nheads, headdim) k: (batch_size, seqlen, nheads_k, headdim) @@ -593,6 +689,7 @@ def flash_attn_func( softmax_scale: float. The scaling of QK^T before applying softmax. Default to 1 / sqrt(headdim). causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling). + window_size: (left, right). If not (-1, -1), implements sliding window local attention. return_attn_probs: bool. Whether to return the attention probabilities. This option is for testing only. The returned probabilities are not guaranteed to be correct (they might not have the right scaling). @@ -605,7 +702,9 @@ def flash_attn_func( The output of softmax (possibly with different scaling). 
It also encodes the dropout pattern (negative means that location was dropped, nonnegative means it was kept). """ - return FlashAttnFunc.apply(q, k, v, dropout_p, softmax_scale, causal, return_attn_probs) + return FlashAttnFunc.apply( + q, k, v, dropout_p, softmax_scale, causal, window_size, return_attn_probs + ) def flash_attn_varlen_qkvpacked_func( @@ -615,6 +714,7 @@ def flash_attn_varlen_qkvpacked_func( dropout_p=0.0, softmax_scale=None, causal=False, + window_size=(-1, -1), # -1 means infinite context window return_attn_probs=False, ): """dropout_p should be set to 0.0 during evaluation @@ -624,6 +724,9 @@ def flash_attn_varlen_qkvpacked_func( For multi-query and grouped-query attention (MQA/GQA), please see flash_attn_varlen_kvpacked_func and flash_attn_varlen_func. + If window_size != (-1, -1), implements sliding window local attention. Query at position i + will only attend to keys between [i - window_size[0], i + window_size[1]] inclusive. + Arguments: qkv: (total, 3, nheads, headdim), where total = total number of tokens in the batch. cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths @@ -633,6 +736,7 @@ def flash_attn_varlen_qkvpacked_func( softmax_scale: float. The scaling of QK^T before applying softmax. Default to 1 / sqrt(headdim). causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling). + window_size: (left, right). If not (-1, -1), implements sliding window local attention. return_attn_probs: bool. Whether to return the attention probabilities. This option is for testing only. The returned probabilities are not guaranteed to be correct (they might not have the right scaling). @@ -646,7 +750,14 @@ def flash_attn_varlen_qkvpacked_func( pattern (negative means that location was dropped, nonnegative means it was kept). """ return FlashAttnVarlenQKVPackedFunc.apply( - qkv, cu_seqlens, max_seqlen, dropout_p, softmax_scale, causal, return_attn_probs + qkv, + cu_seqlens, + max_seqlen, + dropout_p, + softmax_scale, + causal, + window_size, + return_attn_probs, ) @@ -660,6 +771,7 @@ def flash_attn_varlen_kvpacked_func( dropout_p=0.0, softmax_scale=None, causal=False, + window_size=(-1, -1), # -1 means infinite context window return_attn_probs=False, ): """dropout_p should be set to 0.0 during evaluation @@ -683,6 +795,10 @@ def flash_attn_varlen_kvpacked_func( 1 1 If the row of the mask is all zero, the output will be zero. + If window_size != (-1, -1), implements sliding window local attention. Query at position i + will only attend to keys between + [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive. + Arguments: q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch. kv: (total_k, 2, nheads_k, headdim), where total_k = total number of key tokens in the batch. @@ -696,6 +812,7 @@ def flash_attn_varlen_kvpacked_func( softmax_scale: float. The scaling of QK^T before applying softmax. Default to 1 / sqrt(headdim). causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling). + window_size: (left, right). If not (-1, -1), implements sliding window local attention. return_attn_probs: bool. Whether to return the attention probabilities. This option is for testing only. The returned probabilities are not guaranteed to be correct (they might not have the right scaling). 
@@ -718,6 +835,7 @@ def flash_attn_varlen_kvpacked_func( dropout_p, softmax_scale, causal, + window_size, return_attn_probs, ) @@ -733,6 +851,7 @@ def flash_attn_varlen_func( dropout_p=0.0, softmax_scale=None, causal=False, + window_size=(-1, -1), # -1 means infinite context window return_attn_probs=False, ): """dropout_p should be set to 0.0 during evaluation @@ -753,6 +872,10 @@ def flash_attn_varlen_func( 1 1 If the row of the mask is all zero, the output will be zero. + If window_size != (-1, -1), implements sliding window local attention. Query at position i + will only attend to keys between + [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive. + Arguments: q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch. k: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch. @@ -767,6 +890,7 @@ def flash_attn_varlen_func( softmax_scale: float. The scaling of QK^T before applying softmax. Default to 1 / sqrt(headdim). causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling). + window_size: (left, right). If not (-1, -1), implements sliding window local attention. return_attn_probs: bool. Whether to return the attention probabilities. This option is for testing only. The returned probabilities are not guaranteed to be correct (they might not have the right scaling). @@ -790,6 +914,7 @@ def flash_attn_varlen_func( dropout_p, softmax_scale, causal, + window_size, return_attn_probs, ) @@ -805,6 +930,7 @@ def flash_attn_with_kvcache( cache_seqlens: Optional[Union[(int, torch.Tensor)]] = None, softmax_scale=None, causal=False, + window_size=(-1, -1), # -1 means infinite context window rotary_interleaved=True, num_splits=0, ): @@ -818,11 +944,12 @@ def flash_attn_with_kvcache( For example, the KV cache could be pre-allocated with the max sequence length, and you can use cache_seqlens to keep track of the current sequence lengths of each sequence in the batch. - Also apply rotary embedding if rotary_cos and rotary_sin are passed in. The key @k will be rotated - by rotary_cos and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc. - If causal, the query @q will be rotated by rotary_cos and rotary_sin at indices cache_seqlens, - cache_seqlens + 1, etc. If not causal, the query @q will be rotated by rotary_cos and rotary_sin - at indices cache_seqlens only (i.e. we consider all tokens in @q to be at position cache_seqlens). + Also apply rotary embedding if rotary_cos and rotary_sin are passed in. The key @k will be + rotated by rotary_cos and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc. + If causal or local (i.e., window_size != (-1, -1)), the query @q will be rotated by rotary_cos + and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc. + If not causal and not local, the query @q will be rotated by rotary_cos and rotary_sin at + indices cache_seqlens only (i.e. we consider all tokens in @q to be at position cache_seqlens). See tests/test_flash_attn.py::test_flash_attn_kvcache for examples of how to use this function. @@ -843,6 +970,10 @@ def flash_attn_with_kvcache( 1 1 If the row of the mask is all zero, the output will be zero. + If window_size != (-1, -1), implements sliding window local attention. Query at position i + will only attend to keys between + [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive. + Note: Does not support backward pass. 
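For the incremental-decoding path, here is a sketch of a single decode step with a local window (illustrative assumptions: CUDA device, fp16, a cache pre-allocated to 4096 tokens, and 1000 tokens already cached per sequence).

```python
import torch
from flash_attn.flash_attn_interface import flash_attn_with_kvcache

batch, nheads, headdim, max_len = 2, 8, 64, 4096
k_cache = torch.zeros(batch, max_len, nheads, headdim, device="cuda", dtype=torch.float16)
v_cache = torch.zeros_like(k_cache)
cache_seqlens = torch.full((batch,), 1000, device="cuda", dtype=torch.int32)

q = torch.randn(batch, 1, nheads, headdim, device="cuda", dtype=torch.float16)
k_new = torch.randn_like(q)  # key/value of the token being generated
v_new = torch.randn_like(q)

# k_new/v_new are written into the cache in place starting at index cache_seqlens;
# with window_size=(512, 0) the query attends to roughly the most recent 512 cache entries.
out = flash_attn_with_kvcache(
    q, k_cache, v_cache, k=k_new, v=v_new,
    cache_seqlens=cache_seqlens, causal=True, window_size=(512, 0),
)
```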
Arguments: @@ -860,6 +991,7 @@ def flash_attn_with_kvcache( softmax_scale: float. The scaling of QK^T before applying softmax. Default to 1 / sqrt(headdim). causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling). + window_size: (left, right). If not (-1, -1), implements sliding window local attention. rotary_interleaved: bool. Only applicable if rotary_cos and rotary_sin are passed in. If True, rotary embedding will combine dimensions 0 & 1, 2 & 3, etc. If False, rotary embedding will combine dimensions 0 & rotary_dim / 2, 1 & rotary_dim / 2 + 1 @@ -894,6 +1026,8 @@ def flash_attn_with_kvcache( None, softmax_scale, causal, + window_size[0], + window_size[1], rotary_interleaved, num_splits, ) diff --git a/tests/test_flash_attn.py b/tests/test_flash_attn.py index 11daa43e6..4fe330979 100644 --- a/tests/test_flash_attn.py +++ b/tests/test_flash_attn.py @@ -150,8 +150,13 @@ def generate_qkv( ) -def construct_causal_mask( - seqlen_q, seqlen_k, query_padding_mask=None, key_padding_mask=None, device=None +def construct_local_mask( + seqlen_q, + seqlen_k, + window_size=(-1, -1), # -1 means infinite window size + query_padding_mask=None, + key_padding_mask=None, + device=None, ): row_idx = rearrange(torch.arange(seqlen_q, device=device, dtype=torch.long), "s -> s 1") col_idx = torch.arange(seqlen_k, device=device, dtype=torch.long) @@ -165,7 +170,14 @@ def construct_causal_mask( if query_padding_mask is None else rearrange(query_padding_mask.sum(-1), "b -> b 1 1 1") ) - return col_idx > row_idx + sk - sq + if window_size[0] < 0: + return col_idx > row_idx + sk - sq + window_size[1] + else: + sk = torch.full_like(col_idx, seqlen_k) if key_padding_mask is None else sk + return torch.logical_or( + col_idx > torch.minimum(row_idx + sk - sq + window_size[1], sk), + col_idx < row_idx + sk - sq - window_size[0], + ) def attention_ref( @@ -177,6 +189,7 @@ def attention_ref( dropout_p=0.0, dropout_mask=None, causal=False, + window_size=(-1, -1), # -1 means infinite window size upcast=True, reorder_ops=False, ): @@ -189,6 +202,8 @@ def attention_ref( key_padding_mask: (batch_size, seqlen_k) dropout_p: float dropout_mask: (batch_size, nheads, seqlen_q, seqlen_k) + causal: whether to apply causal masking + window_size: (int, int), left and right window size upcast: whether to cast all inputs to fp32, do all computation in fp32, then cast output back to fp16/bf16. reorder_ops: whether to change the order of operations (scaling k instead of scaling k, etc.) 
@@ -198,6 +213,8 @@ def attention_ref( output: (batch_size, seqlen_q, nheads, head_dim) attention: (batch_size, nheads, seqlen_q, seqlen_k), softmax after dropout """ + if causal: + window_size = (window_size[0], 0) dtype_og = q.dtype if upcast: q, k, v = q.float(), k.float(), v.float() @@ -211,17 +228,24 @@ def attention_ref( scores = torch.einsum("bthd,bshd->bhts", q, k / math.sqrt(d)) if key_padding_mask is not None: scores.masked_fill_(rearrange(~key_padding_mask, "b s -> b 1 1 s"), float("-inf")) - if causal: - # causal_mask = torch.triu( - # torch.ones(seqlen_q, seqlen_k, dtype=torch.bool, device=q.device), 1 - # ) - causal_mask = construct_causal_mask( - seqlen_q, seqlen_k, query_padding_mask, key_padding_mask, q.device + if window_size[0] >= 0 or window_size[1] >= 0: + local_mask = construct_local_mask( + seqlen_q, + seqlen_k, + window_size, + query_padding_mask, + key_padding_mask, + q.device, ) - scores.masked_fill_(causal_mask, float("-inf")) + scores.masked_fill_(local_mask, float("-inf")) attention = torch.softmax(scores, dim=-1) - if causal: # Some rows are completely masked out so we fill them with zero instead of NaN - attention = attention.masked_fill(torch.all(causal_mask, dim=-1, keepdim=True), 0.0) + # Some rows might be completely masked out so we fill them with zero instead of NaN + if window_size[0] >= 0 or window_size[1] >= 0: + attention = attention.masked_fill(torch.all(local_mask, dim=-1, keepdim=True), 0.0) + # We want to mask here so that the attention matrix doesn't have any NaNs + # Otherwise we'll get NaN in dV + if query_padding_mask is not None: + attention = attention.masked_fill(rearrange(~query_padding_mask, "b s -> b 1 s 1"), 0.0) dropout_scaling = 1.0 / (1 - dropout_p) # attention_drop = attention.masked_fill(~dropout_mask, 0.0) * dropout_scaling # output = torch.einsum('bhts,bshd->bthd', attention_drop , v) @@ -232,7 +256,6 @@ def attention_ref( output = torch.einsum("bhts,bshd->bthd", attention_drop, v * dropout_scaling) if query_padding_mask is not None: output.masked_fill_(rearrange(~query_padding_mask, "b s -> b s 1 1"), 0.0) - attention = attention.masked_fill(rearrange(~query_padding_mask, "b s -> b 1 s 1"), 0.0) return output.to(dtype=dtype_og), attention.to(dtype=dtype_og) @@ -244,6 +267,7 @@ def attention_kvpacked_ref( dropout_p=0.0, dropout_mask=None, causal=False, + window_size=(-1, -1), # -1 means infinite window size upcast=True, reorder_ops=False, ): @@ -257,6 +281,7 @@ def attention_kvpacked_ref( dropout_mask, upcast=upcast, causal=causal, + window_size=window_size, reorder_ops=reorder_ops, ) @@ -267,6 +292,7 @@ def attention_qkvpacked_ref( dropout_p=0.0, dropout_mask=None, causal=False, + window_size=(-1, -1), # -1 means infinite window size upcast=True, reorder_ops=False, ): @@ -280,6 +306,7 @@ def attention_qkvpacked_ref( dropout_mask, upcast=upcast, causal=causal, + window_size=window_size, reorder_ops=reorder_ops, ) @@ -327,7 +354,15 @@ def attention_blocksparse_ref(qkv, blockmask, attn_mask, dropout_p, dropout_mask def convert_flash_attn_S_to_softmax( - S, seqlen_q, seqlen_k, query_padding_mask, key_padding_mask, head_dim, is_dropout, causal=False + S, + seqlen_q, + seqlen_k, + query_padding_mask, + key_padding_mask, + head_dim, + is_dropout, + causal=False, + window_size=(-1, -1), # -1 means infinite window size ): """FlashAttention stores the S matrix in a different way. 
Arguments: @@ -335,6 +370,8 @@ def convert_flash_attn_S_to_softmax( query_padding_mask: (batch_size, seqlen_q_rounded) key_padding_mask: (batch_size, seqlen_k_rounded) """ + if causal: + window_size = (window_size[0], 0) seqlen_q_rounded, seqlen_k_rounded = S.shape[-2:] warps_n = 4 blocksize_m, blocksize_n = _get_block_size(S.device, head_dim, is_dropout, causal) @@ -359,19 +396,21 @@ def convert_flash_attn_S_to_softmax( four=4, ) - if causal: - # causal_mask = torch.triu( - # torch.ones(seqlen_q_rounded, seqlen_k_rounded, dtype=torch.bool, device=q.device), 1 - # ) - causal_mask = construct_causal_mask( - seqlen_q, seqlen_k, query_padding_mask, key_padding_mask, S.device + if window_size[0] >= 0 or window_size[1] >= 0: + local_mask = construct_local_mask( + seqlen_q, + seqlen_k, + window_size, + query_padding_mask, + key_padding_mask, + S.device, ) - causal_mask = F.pad( - causal_mask, + local_mask = F.pad( + local_mask, (0, seqlen_k_rounded - seqlen_k, 0, seqlen_q_rounded - seqlen_q), value=True, ) - S_converted.masked_fill_(causal_mask, 0.0) + S_converted.masked_fill_(local_mask, 0.0) # Need to zero out things not in attention_mask in case S was initialized with random values # and some of those values aren't overwritten. @@ -399,6 +438,7 @@ def normalize_flash_attn_S( key_padding_mask=None, is_dropout=False, causal=False, + window_size=(-1, -1), # -1 means infinite window size ): """ Arguments: @@ -409,20 +449,24 @@ def normalize_flash_attn_S( softmax_lse: (batch_size, nheads, seqlen_q) softmax_max: (batch_size, nheads, seqlen_q) """ + if causal: + window_size = (window_size[0], 0) q, k, v = q.float(), k.float(), v.float() _, seqlen_q, _, head_dim = q.shape seqlen_k = k.shape[1] scores = torch.einsum("bthd,bshd->bhts", q / math.sqrt(head_dim), k) if key_padding_mask is not None: scores.masked_fill_(rearrange(~key_padding_mask, "b s -> b 1 1 s"), float("-inf")) - if causal: - # causal_mask = torch.triu( - # torch.ones(seqlen_q, seqlen_k, dtype=torch.bool, device=q.device), 1 - # ) - causal_mask = construct_causal_mask( - seqlen_q, seqlen_k, query_padding_mask, key_padding_mask, q.device + if window_size[0] >= 0 or window_size[1] >= 0: + local_mask = construct_local_mask( + seqlen_q, + seqlen_k, + window_size, + query_padding_mask, + key_padding_mask, + q.device, ) - scores.masked_fill_(causal_mask, float("-inf")) + scores.masked_fill_(local_mask, float("-inf")) _, block_size_n = _get_block_size(scores.device, head_dim, is_dropout, causal) scores_block = scores.split(block_size_n, dim=-1) lse_block = torch.stack([torch.logsumexp(s, dim=-1) for s in scores_block], dim=-1) @@ -446,79 +490,84 @@ def normalize_flash_attn_S( def get_dropout_fraction( - dropout_mask, query_padding_mask=None, key_padding_mask=None, causal=False + dropout_mask, + query_padding_mask=None, + key_padding_mask=None, + causal=False, + window_size=(-1, -1), # -1 means infinite window size ): """ dropout_mask: (batch_size, nheads, seqlen_q, seqlen_k), bool. True means keep, False means drop. 
query_padding_mask: (batch_size, seqlen_q) key_padding_mask: (batch_size, seqlen_k) """ + if causal: + window_size = (window_size[0], 0) batch_size, nheads, seqlen_q, seqlen_k = dropout_mask.shape dropped = ~dropout_mask + valid = torch.ones_like(dropout_mask) if query_padding_mask is not None: dropped.masked_fill_(rearrange(~query_padding_mask, "b s -> b 1 s 1"), False) + valid.masked_fill_(rearrange(~query_padding_mask, "b s -> b 1 s 1"), False) if key_padding_mask is not None: dropped.masked_fill_(rearrange(~key_padding_mask, "b s -> b 1 1 s"), False) - if causal: - # causal_mask = torch.triu( - # torch.ones(seqlen_q, seqlen_k, dtype=torch.bool, device=dropout_mask.device), 1 - # ) - causal_mask = construct_causal_mask( - seqlen_q, seqlen_k, query_padding_mask, key_padding_mask, dropout_mask.device + valid.masked_fill_(rearrange(~key_padding_mask, "b s -> b 1 1 s"), False) + if window_size[0] >= 0 or window_size[1] >= 0: + local_mask = construct_local_mask( + seqlen_q, + seqlen_k, + window_size, + query_padding_mask, + key_padding_mask, + dropout_mask.device, ) - dropped.masked_fill_(causal_mask, False) + dropped.masked_fill_(local_mask, False) + valid.masked_fill_(local_mask, False) dropped_total = dropped.sum() - query_lengths = ( - query_padding_mask.sum(dim=-1) - if query_padding_mask is not None - else torch.full((batch_size,), seqlen_q, device=dropout_mask.device) - ) - key_lengths = ( - key_padding_mask.sum(dim=-1) - if key_padding_mask is not None - else torch.full((batch_size,), seqlen_k, device=dropout_mask.device) - ) - if not causal: - numel_per_batch = query_lengths * key_lengths - else: - numel_per_batch = torch.where( - key_lengths <= query_lengths, - key_lengths * (key_lengths + 1) / 2, - query_lengths * key_lengths - (query_lengths * (query_lengths - 1) / 2), - ) - return dropped_total / (numel_per_batch.sum() * nheads) + return dropped.sum() / valid.sum() @pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16])) -# @pytest.mark.parametrize('dtype', [torch.float16]) +# @pytest.mark.parametrize("dtype", [torch.float16]) +@pytest.mark.parametrize("local", [False, True]) +# @pytest.mark.parametrize("local", [True]) @pytest.mark.parametrize("causal", [False, True]) -# @pytest.mark.parametrize('causal', [False]) +# @pytest.mark.parametrize("causal", [False]) @pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256]) -# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192, 224, 256]) +# @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256]) # @pytest.mark.parametrize('d', [32, 64, 96, 128]) -# @pytest.mark.parametrize('d', [64]) +# @pytest.mark.parametrize("d", [64]) # @pytest.mark.parametrize('seqlen', [128, 256, 384, 512, 768, 1024, 2048]) @pytest.mark.parametrize("seqlen", [97, 128, 200, 256, 257, 384, 512, 768, 1024, 1025, 2048]) -# @pytest.mark.parametrize('seqlen', [128]) +# @pytest.mark.parametrize("seqlen", [128]) @pytest.mark.parametrize("dropout_p", [0.0, 0.17]) -# @pytest.mark.parametrize('dropout_p', [0.0]) -def test_flash_attn_qkvpacked(seqlen, d, dropout_p, causal, dtype): +# @pytest.mark.parametrize("dropout_p", [0.0]) +def test_flash_attn_qkvpacked(seqlen, d, dropout_p, causal, local, dtype): if seqlen >= 2048 and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30: pytest.skip() # Reference implementation OOM device = "cuda" # set seed torch.random.manual_seed(0) - batch_size = 16 + batch_size = 13 nheads = 9 + window_size = (-1, -1) if not local else 
torch.randint(0, seqlen, (2,)) qkv = torch.randn( batch_size, seqlen, 3, nheads, d, device=device, dtype=dtype, requires_grad=True ) out, lse, S_dmask = flash_attn_qkvpacked_func( - qkv, dropout_p, return_attn_probs=True, causal=causal + qkv, dropout_p, causal=causal, window_size=window_size, return_attn_probs=True ) if dropout_p > 0.0: S_dmask_converted = convert_flash_attn_S_to_softmax( - S_dmask, seqlen, seqlen, None, None, d, dropout_p > 0.0, causal=causal + S_dmask, + seqlen, + seqlen, + None, + None, + d, + dropout_p > 0.0, + causal=causal, + window_size=window_size, ) dropout_mask = S_dmask_converted >= 0 attn_unnorm = S_dmask_converted.abs() @@ -531,15 +580,27 @@ def test_flash_attn_qkvpacked(seqlen, d, dropout_p, causal, dtype): None, dropout_p > 0.0, causal=causal, + window_size=window_size, ) - dropout_fraction = get_dropout_fraction(dropout_mask, None, None, causal=causal).item() + dropout_fraction = get_dropout_fraction( + dropout_mask, None, None, causal=causal, window_size=window_size + ).item() print(f"Actual dropout fraction: {dropout_fraction}") else: dropout_mask = None - out_ref, attn_ref = attention_qkvpacked_ref(qkv, None, dropout_p, dropout_mask, causal=causal) + out_ref, attn_ref = attention_qkvpacked_ref( + qkv, None, dropout_p, dropout_mask, causal=causal, window_size=window_size + ) out_pt, attn_pt = attention_qkvpacked_ref( - qkv, None, dropout_p, dropout_mask, causal=causal, upcast=False, reorder_ops=True + qkv, + None, + dropout_p, + dropout_mask, + causal=causal, + window_size=window_size, + upcast=False, + reorder_ops=True, ) # v = qkv[:, :, 2].float() # qk = torch.einsum('bshd,bthd->bhst', qkv[:, :, 0], qkv[:, :, 1]).float() @@ -590,7 +651,7 @@ def test_flash_attn_qkvpacked(seqlen, d, dropout_p, causal, dtype): if dropout_p > 0.0: assert (attn - attn_ref).abs().max().item() <= 2 * (attn_pt - attn_ref).abs().max().item() - assert abs(dropout_fraction - dropout_p) <= 0.01 + assert abs(dropout_fraction - dropout_p) <= (0.01 if not local else 0.025) if d <= MAX_HEADDIM_SM8x or (is_sm80 or is_sm90): assert (dqkv - dqkv_ref).abs().max().item() <= 2 * (dqkv_pt - dqkv_ref).abs().max().item() @@ -598,15 +659,18 @@ def test_flash_attn_qkvpacked(seqlen, d, dropout_p, causal, dtype): @pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16])) # @pytest.mark.parametrize('dtype', [torch.float16]) +@pytest.mark.parametrize("local", [False, True]) +# @pytest.mark.parametrize("local", [True]) @pytest.mark.parametrize("causal", [False, True]) # @pytest.mark.parametrize('causal', [False]) @pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256]) +# @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256]) # @pytest.mark.parametrize('d', [64]) @pytest.mark.parametrize("seqlen", [97, 128, 200, 256, 257, 384, 512, 768, 1024, 1025, 2048]) # @pytest.mark.parametrize('seqlen', [128]) @pytest.mark.parametrize("dropout_p", [0.0, 0.17]) # @pytest.mark.parametrize('dropout_p', [0.0]) -def test_flash_attn_varlen_qkvpacked(seqlen, d, dropout_p, causal, dtype): +def test_flash_attn_varlen_qkvpacked(seqlen, d, dropout_p, causal, local, dtype): if seqlen >= 2048 and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30: pytest.skip() # Reference implementation OOM device = "cuda" @@ -614,6 +678,7 @@ def test_flash_attn_varlen_qkvpacked(seqlen, d, dropout_p, causal, dtype): torch.random.manual_seed(0) batch_size = 5 nheads = 6 + window_size = (-1, -1) if not local else torch.randint(0, seqlen, (2,)) 
qkv = torch.randn( batch_size, seqlen, 3, nheads, d, device=device, dtype=dtype, requires_grad=True ) @@ -626,7 +691,13 @@ def test_flash_attn_varlen_qkvpacked(seqlen, d, dropout_p, causal, dtype): ) out_unpad, sm_lse, S_dmask = flash_attn_varlen_qkvpacked_func( - qkv_unpad, cu_seqlens, max_seqlen, dropout_p, return_attn_probs=True, causal=causal + qkv_unpad, + cu_seqlens, + max_seqlen, + dropout_p, + causal=causal, + window_size=window_size, + return_attn_probs=True, ) out = output_pad_fn(out_unpad) if dropout_p > 0.0: @@ -639,6 +710,7 @@ def test_flash_attn_varlen_qkvpacked(seqlen, d, dropout_p, causal, dtype): d, dropout_p > 0.0, causal=causal, + window_size=window_size, ) dropout_mask = S_dmask_converted >= 0 attn_unnorm = S_dmask_converted.abs() @@ -651,16 +723,17 @@ def test_flash_attn_varlen_qkvpacked(seqlen, d, dropout_p, causal, dtype): key_padding_mask, dropout_p > 0.0, causal=causal, + window_size=window_size, ) dropout_fraction = get_dropout_fraction( - dropout_mask, key_padding_mask, key_padding_mask, causal=causal + dropout_mask, key_padding_mask, key_padding_mask, causal=causal, window_size=window_size ).item() print(f"Actual dropout fraction: {dropout_fraction}") else: dropout_mask = None out_ref, attn_ref = attention_qkvpacked_ref( - qkv, key_padding_mask, dropout_p, dropout_mask, causal=causal + qkv, key_padding_mask, dropout_p, dropout_mask, causal=causal, window_size=window_size ) out_pt, attn_pt = attention_qkvpacked_ref( qkv, @@ -668,6 +741,7 @@ def test_flash_attn_varlen_qkvpacked(seqlen, d, dropout_p, causal, dtype): dropout_p, dropout_mask, causal=causal, + window_size=window_size, upcast=False, reorder_ops=True, ) @@ -700,7 +774,7 @@ def test_flash_attn_varlen_qkvpacked(seqlen, d, dropout_p, causal, dtype): if dropout_p > 0.0: assert (attn - attn_ref).abs().max().item() <= 2 * (attn_pt - attn_ref).abs().max().item() - assert abs(dropout_fraction - dropout_p) <= 0.01 + assert abs(dropout_fraction - dropout_p) <= (0.01 if not local else 0.025) if d <= MAX_HEADDIM_SM8x or (is_sm80 or is_sm90): assert (dqkv - dqkv_ref).abs().max().item() <= 2 * (dqkv_pt - dqkv_ref).abs().max().item() @@ -712,10 +786,12 @@ def test_flash_attn_varlen_qkvpacked(seqlen, d, dropout_p, causal, dtype): # @pytest.mark.parametrize("dtype", [torch.bfloat16]) @pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"]) # @pytest.mark.parametrize("mha_type", ["mha"]) +@pytest.mark.parametrize("local", [False, True]) +# @pytest.mark.parametrize("local", [True]) @pytest.mark.parametrize("causal", [False, True]) # @pytest.mark.parametrize("causal", [True]) @pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256]) -# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192, 224, 256]) +# @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256]) # @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192]) # @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192]) # @pytest.mark.parametrize('d', [56, 80]) @@ -738,7 +814,9 @@ def test_flash_attn_varlen_qkvpacked(seqlen, d, dropout_p, causal, dtype): # @pytest.mark.parametrize('seqlen_q,seqlen_k', [(256, 128)]) @pytest.mark.parametrize("dropout_p", [0.0, 0.17]) # @pytest.mark.parametrize("dropout_p", [0.17]) -def test_flash_attn_output(seqlen_q, seqlen_k, d, dropout_p, causal, mha_type, dtype, kvpacked): +def test_flash_attn_output( + seqlen_q, seqlen_k, d, dropout_p, causal, local, mha_type, dtype, kvpacked +): if ( max(seqlen_q, seqlen_k) >= 2048 and torch.cuda.get_device_properties("cuda").total_memory 
<= 16 * 2**30 @@ -747,10 +825,11 @@ def test_flash_attn_output(seqlen_q, seqlen_k, d, dropout_p, causal, mha_type, d device = "cuda" # set seed torch.random.manual_seed(0) - batch_size = 16 + batch_size = 13 nheads = 9 nheads_k = nheads if mha_type == "mha" else (1 if mha_type == "mqa" else 3) assert nheads % nheads_k == 0 + window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,)) q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True) if kvpacked: kv = torch.randn( @@ -766,15 +845,23 @@ def test_flash_attn_output(seqlen_q, seqlen_k, d, dropout_p, causal, mha_type, d if kvpacked: out, lse, S_dmask = flash_attn_kvpacked_func( - q, kv, dropout_p, return_attn_probs=True, causal=causal + q, kv, dropout_p, causal=causal, window_size=window_size, return_attn_probs=True ) else: out, lse, S_dmask = flash_attn_func( - q, k, v, dropout_p, return_attn_probs=True, causal=causal + q, k, v, dropout_p, causal=causal, window_size=window_size, return_attn_probs=True ) if dropout_p > 0.0: S_dmask_converted = convert_flash_attn_S_to_softmax( - S_dmask, seqlen_q, seqlen_k, None, None, d, dropout_p > 0.0, causal=causal + S_dmask, + seqlen_q, + seqlen_k, + None, + None, + d, + dropout_p > 0.0, + causal=causal, + window_size=window_size, ) dropout_mask = S_dmask_converted >= 0 attn_unnorm = S_dmask_converted.abs() @@ -785,16 +872,33 @@ def test_flash_attn_output(seqlen_q, seqlen_k, d, dropout_p, causal, mha_type, d k_rep = repeat(k, "b s h d -> b s (h g) d", g=nheads // nheads_k) v_rep = repeat(v, "b s h d -> b s (h g) d", g=nheads // nheads_k) attn = normalize_flash_attn_S( - attn_unnorm, q, k_rep, v_rep, None, None, dropout_p > 0.0, causal=causal + attn_unnorm, + q, + k_rep, + v_rep, + None, + None, + dropout_p > 0.0, + causal=causal, + window_size=window_size, ) - dropout_fraction = get_dropout_fraction(dropout_mask, None, None, causal=causal).item() + dropout_fraction = get_dropout_fraction( + dropout_mask, None, None, causal=causal, window_size=window_size + ).item() print(f"Actual dropout fraction: {dropout_fraction}") else: dropout_mask = None if kvpacked: out_ref, attn_ref = attention_kvpacked_ref( - q, kv, None, None, dropout_p, dropout_mask, causal=causal + q, + kv, + None, + None, + dropout_p, + dropout_mask, + causal=causal, + window_size=window_size, ) out_pt, attn_pt = attention_kvpacked_ref( q, @@ -804,12 +908,21 @@ def test_flash_attn_output(seqlen_q, seqlen_k, d, dropout_p, causal, mha_type, d dropout_p, dropout_mask, causal=causal, + window_size=window_size, upcast=False, reorder_ops=True, ) else: out_ref, attn_ref = attention_ref( - q, k, v, None, None, dropout_p, dropout_mask, causal=causal + q, + k, + v, + None, + None, + dropout_p, + dropout_mask, + causal=causal, + window_size=window_size, ) out_pt, attn_pt = attention_ref( q, @@ -820,6 +933,7 @@ def test_flash_attn_output(seqlen_q, seqlen_k, d, dropout_p, causal, mha_type, d dropout_p, dropout_mask, causal=causal, + window_size=window_size, upcast=False, reorder_ops=True, ) @@ -886,7 +1000,7 @@ def test_flash_attn_output(seqlen_q, seqlen_k, d, dropout_p, causal, mha_type, d if dropout_p > 0.0: assert (attn - attn_ref).abs().max().item() <= 2 * (attn_pt - attn_ref).abs().max().item() - assert abs(dropout_fraction - dropout_p) <= 0.01 + assert abs(dropout_fraction - dropout_p) <= (0.01 if not local else 0.025) if d <= MAX_HEADDIM_SM8x or (is_sm80 or is_sm90): assert (dq - dq_ref).abs().max().item() <= 2 * (dq_pt - dq_ref).abs().max().item() @@ -900,10 +1014,12 @@ def 
test_flash_attn_output(seqlen_q, seqlen_k, d, dropout_p, causal, mha_type, d # @pytest.mark.parametrize('dtype', [torch.float16]) @pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"]) # @pytest.mark.parametrize('mha_type', ["mqa"]) +@pytest.mark.parametrize("local", [False, True]) +# @pytest.mark.parametrize("local", [True]) @pytest.mark.parametrize("causal", [False, True]) # @pytest.mark.parametrize('causal', [True]) @pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256]) -# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192, 224, 256]) +# @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256]) # @pytest.mark.parametrize('d', [64]) @pytest.mark.parametrize( "seqlen_q,seqlen_k", @@ -925,7 +1041,7 @@ def test_flash_attn_output(seqlen_q, seqlen_k, d, dropout_p, causal, mha_type, d @pytest.mark.parametrize("dropout_p", [0.0, 0.17]) # @pytest.mark.parametrize('dropout_p', [0.0]) def test_flash_attn_varlen_output( - seqlen_q, seqlen_k, d, dropout_p, causal, mha_type, dtype, kvpacked + seqlen_q, seqlen_k, d, dropout_p, causal, local, mha_type, dtype, kvpacked ): if ( max(seqlen_q, seqlen_k) >= 2048 @@ -935,10 +1051,11 @@ def test_flash_attn_varlen_output( device = "cuda" # set seed torch.random.manual_seed(0) - batch_size = 16 + batch_size = 13 nheads = 9 nheads_k = nheads if mha_type == "mha" else (1 if mha_type == "mqa" else 3) assert nheads % nheads_k == 0 + window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,)) q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True) if kvpacked: kv = torch.randn( @@ -980,6 +1097,7 @@ def test_flash_attn_varlen_output( dropout_p, return_attn_probs=True, causal=causal, + window_size=window_size, ) else: ( @@ -1008,6 +1126,7 @@ def test_flash_attn_varlen_output( dropout_p, return_attn_probs=True, causal=causal, + window_size=window_size, ) out = output_pad_fn(out_unpad) if dropout_p > 0.0: @@ -1020,6 +1139,7 @@ def test_flash_attn_varlen_output( d, dropout_p > 0.0, causal=causal, + window_size=window_size, ) dropout_mask = S_dmask_converted >= 0 attn_unnorm = S_dmask_converted.abs() @@ -1038,9 +1158,14 @@ def test_flash_attn_varlen_output( key_padding_mask, dropout_p > 0.0, causal=causal, + window_size=window_size, ) dropout_fraction = get_dropout_fraction( - dropout_mask, query_padding_mask, key_padding_mask, causal=causal + dropout_mask, + query_padding_mask, + key_padding_mask, + causal=causal, + window_size=window_size, ).item() print(f"Actual dropout fraction: {dropout_fraction}") else: @@ -1048,7 +1173,14 @@ def test_flash_attn_varlen_output( if kvpacked: out_ref, attn_ref = attention_kvpacked_ref( - q, kv, query_padding_mask, key_padding_mask, dropout_p, dropout_mask, causal=causal + q, + kv, + query_padding_mask, + key_padding_mask, + dropout_p, + dropout_mask, + causal=causal, + window_size=window_size, ) out_pt, attn_pt = attention_kvpacked_ref( q, @@ -1058,12 +1190,21 @@ def test_flash_attn_varlen_output( dropout_p, dropout_mask, causal=causal, + window_size=window_size, upcast=False, reorder_ops=True, ) else: out_ref, attn_ref = attention_ref( - q, k, v, query_padding_mask, key_padding_mask, dropout_p, dropout_mask, causal=causal + q, + k, + v, + query_padding_mask, + key_padding_mask, + dropout_p, + dropout_mask, + causal=causal, + window_size=window_size, ) out_pt, attn_pt = attention_ref( q, @@ -1074,6 +1215,7 @@ def test_flash_attn_varlen_output( dropout_p, dropout_mask, causal=causal, + window_size=window_size, upcast=False, 
reorder_ops=True, ) @@ -1142,7 +1284,7 @@ def test_flash_attn_varlen_output( if dropout_p > 0.0: assert (attn - attn_ref).abs().max().item() <= 2 * (attn_pt - attn_ref).abs().max().item() - assert abs(dropout_fraction - dropout_p) <= 0.01 + assert abs(dropout_fraction - dropout_p) <= (0.01 if not local else 0.025) if d <= MAX_HEADDIM_SM8x or (is_sm80 or is_sm90): assert (dq - dq_ref).abs().max().item() <= 2 * (dq_pt - dq_ref).abs().max().item() @@ -1152,8 +1294,10 @@ def test_flash_attn_varlen_output( @pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16])) # @pytest.mark.parametrize("dtype", [torch.bfloat16]) +@pytest.mark.parametrize("local", [False, True]) +# @pytest.mark.parametrize("local", [True]) @pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256]) -# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192, 224, 256]) +# @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256]) # @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192]) # @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192]) # @pytest.mark.parametrize('d', [56, 80]) @@ -1176,7 +1320,7 @@ def test_flash_attn_varlen_output( ], ) # @pytest.mark.parametrize('seqlen_q,seqlen_k', [(256, 128)]) -def test_flash_attn_causal(seqlen_q, seqlen_k, swap_sq_sk, d, dtype): +def test_flash_attn_causal(seqlen_q, seqlen_k, swap_sq_sk, d, local, dtype): if ( max(seqlen_q, seqlen_k) >= 2048 and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30 @@ -1188,13 +1332,16 @@ def test_flash_attn_causal(seqlen_q, seqlen_k, swap_sq_sk, d, dtype): causal = True # set seed torch.random.manual_seed(0) - batch_size = 16 + batch_size = 13 nheads = 9 + window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,)) q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True) k = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True) v = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True) - out = flash_attn_func(q, k, v, 0.0, causal=causal) - out_ref, attn_ref = attention_ref(q, k, v, None, None, 0.0, None, causal=causal) + out = flash_attn_func(q, k, v, 0.0, causal=causal, window_size=window_size) + out_ref, attn_ref = attention_ref( + q, k, v, None, None, 0.0, None, causal=causal, window_size=window_size + ) out_pt, attn_pt = attention_ref( q, k, @@ -1204,6 +1351,7 @@ def test_flash_attn_causal(seqlen_q, seqlen_k, swap_sq_sk, d, dtype): 0.0, None, causal=causal, + window_size=window_size, upcast=False, reorder_ops=True, ) @@ -1256,12 +1404,14 @@ def test_flash_attn_causal(seqlen_q, seqlen_k, swap_sq_sk, d, dtype): @pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16])) # @pytest.mark.parametrize("dtype", [torch.bfloat16]) +@pytest.mark.parametrize("local", [False, True]) +# @pytest.mark.parametrize("local", [True]) @pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256]) -# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192, 224, 256]) +# @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256]) # @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192]) # @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192]) # @pytest.mark.parametrize('d', [56, 80]) -# @pytest.mark.parametrize("d", [128]) +# @pytest.mark.parametrize("d", [64]) @pytest.mark.parametrize("swap_sq_sk", [False, True]) # 
@pytest.mark.parametrize("swap_sq_sk", [True]) @pytest.mark.parametrize( @@ -1280,7 +1430,7 @@ def test_flash_attn_causal(seqlen_q, seqlen_k, swap_sq_sk, d, dtype): ], ) # @pytest.mark.parametrize("seqlen_q,seqlen_k", [(256, 128)]) -def test_flash_attn_varlen_causal(seqlen_q, seqlen_k, swap_sq_sk, d, dtype): +def test_flash_attn_varlen_causal(seqlen_q, seqlen_k, swap_sq_sk, d, local, dtype): if ( max(seqlen_q, seqlen_k) >= 2048 and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30 @@ -1292,8 +1442,9 @@ def test_flash_attn_varlen_causal(seqlen_q, seqlen_k, swap_sq_sk, d, dtype): causal = True # set seed torch.random.manual_seed(0) - batch_size = 16 + batch_size = 13 nheads = 9 + window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,)) q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True) k = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True) v = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True) @@ -1324,10 +1475,19 @@ def test_flash_attn_varlen_causal(seqlen_q, seqlen_k, swap_sq_sk, d, dtype): max_seqlen_k, 0.0, causal=causal, + window_size=window_size, ) out = output_pad_fn(out_unpad) out_ref, attn_ref = attention_ref( - q, k, v, query_padding_mask, key_padding_mask, 0.0, None, causal=causal + q, + k, + v, + query_padding_mask, + key_padding_mask, + 0.0, + None, + causal=causal, + window_size=window_size, ) out_pt, attn_pt = attention_ref( q, @@ -1338,6 +1498,7 @@ def test_flash_attn_varlen_causal(seqlen_q, seqlen_k, swap_sq_sk, d, dtype): 0.0, None, causal=causal, + window_size=window_size, upcast=False, reorder_ops=True, ) @@ -1393,6 +1554,8 @@ def test_flash_attn_varlen_causal(seqlen_q, seqlen_k, swap_sq_sk, d, dtype): @pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16])) # @pytest.mark.parametrize("dtype", [torch.float16]) +@pytest.mark.parametrize("local", [False, True]) +# @pytest.mark.parametrize("local", [True]) @pytest.mark.parametrize("causal", [False, True]) # @pytest.mark.parametrize("causal", [True]) @pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256]) @@ -1418,7 +1581,7 @@ def test_flash_attn_varlen_causal(seqlen_q, seqlen_k, swap_sq_sk, d, dtype): ], ) # @pytest.mark.parametrize('seqlen_q,seqlen_k', [(256, 128)]) -def test_flash_attn_splitkv(seqlen_q, seqlen_k, swap_sq_sk, d, causal, dtype): +def test_flash_attn_splitkv(seqlen_q, seqlen_k, swap_sq_sk, d, causal, local, dtype): if swap_sq_sk: seqlen_q, seqlen_k = seqlen_k, seqlen_q device = "cuda" @@ -1426,11 +1589,16 @@ def test_flash_attn_splitkv(seqlen_q, seqlen_k, swap_sq_sk, d, causal, dtype): torch.random.manual_seed(0) batch_size = 1 nheads = 12 + window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,)) q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True) k = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True) v = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True) - out, lse, _ = flash_attn_func(q, k, v, 0.0, causal=causal, return_attn_probs=True) - out_ref, attn_ref = attention_ref(q, k, v, None, None, 0.0, None, causal=causal) + out, lse, _ = flash_attn_func( + q, k, v, 0.0, causal=causal, window_size=window_size, return_attn_probs=True + ) + out_ref, attn_ref = attention_ref( + q, k, v, None, None, 0.0, None, causal=causal, 
window_size=window_size + ) out_pt, attn_pt = attention_ref( q, k, @@ -1440,6 +1608,7 @@ def test_flash_attn_splitkv(seqlen_q, seqlen_k, swap_sq_sk, d, causal, dtype): 0.0, None, causal=causal, + window_size=window_size, upcast=False, reorder_ops=True, ) @@ -1498,6 +1667,8 @@ def test_flash_attn_splitkv(seqlen_q, seqlen_k, swap_sq_sk, d, causal, dtype): # @pytest.mark.parametrize("mha_type", ["mha"]) @pytest.mark.parametrize("new_kv", [False, True]) # @pytest.mark.parametrize("new_kv", [True]) +@pytest.mark.parametrize("local", [False, True]) +# @pytest.mark.parametrize("local", [True]) @pytest.mark.parametrize("causal", [False, True]) # @pytest.mark.parametrize("causal", [True]) @pytest.mark.parametrize("seqlen_new_eq_seqlen_q", [True, False]) @@ -1506,7 +1677,7 @@ def test_flash_attn_splitkv(seqlen_q, seqlen_k, swap_sq_sk, d, causal, dtype): # @pytest.mark.parametrize("rotary_interleaved", [False]) @pytest.mark.parametrize("rotary_fraction", [0.0, 0.5, 1.0]) # @pytest.mark.parametrize("rotary_fraction", [0.0]) -@pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256]) +@pytest.mark.parametrize("d", [32, 59, 64, 80, 96, 128, 160, 192, 224, 256]) # @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192, 224, 256]) # @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192]) # @pytest.mark.parametrize('d', [56, 80]) @@ -1536,6 +1707,7 @@ def test_flash_attn_kvcache( rotary_interleaved, seqlen_new_eq_seqlen_q, causal, + local, new_kv, mha_type, num_splits, @@ -1554,6 +1726,7 @@ def test_flash_attn_kvcache( rotary_dim = math.floor(int(rotary_fraction * d) / 16) * 16 nheads_k = nheads if mha_type == "mha" else (1 if mha_type == "mqa" else 3) assert nheads % nheads_k == 0 + window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,)) q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype) seqlen_new = seqlen_q if seqlen_new_eq_seqlen_q else torch.randint(1, seqlen_q + 1, (1,)).item() if new_kv: @@ -1566,7 +1739,7 @@ def test_flash_attn_kvcache( cache_seqlens = torch.randint( 0, # If we don't use seqlen_q in the case of causal and rotary, cos/sin won't be long enough - (seqlen_k - (seqlen_q if causal and rotary_dim > 1 else seqlen_new) + 1) + (seqlen_k - (seqlen_q if (causal or local) and rotary_dim > 1 else seqlen_new) + 1) if new_kv else (seqlen_k + 1), (batch_size,), @@ -1578,7 +1751,7 @@ def test_flash_attn_kvcache( angle = torch.rand(seqlen_k, rotary_dim // 2, device=device) * 2 * math.pi cos = torch.cos(angle).to(dtype=dtype) sin = torch.sin(angle).to(dtype=dtype) - if causal: + if causal or local: q_ro = apply_rotary_emb( q, cos, sin, seqlen_offsets=cache_seqlens, interleaved=rotary_interleaved ) @@ -1624,11 +1797,14 @@ def test_flash_attn_kvcache( sin, cache_seqlens, causal=causal, + window_size=window_size, rotary_interleaved=rotary_interleaved, num_splits=num_splits, ) - # out = flash_attn_with_kvcache(q, k_cache, v_cache, cache_seqlens=cache_seqlens, causal=causal) - # out = flash_attn_with_kvcache(q, k_cache, v_cache, causal=causal) + # out = flash_attn_with_kvcache( + # q, k_cache, v_cache, cache_seqlens=cache_seqlens, causal=causal, window_size=window_size + # ) + # out = flash_attn_with_kvcache(q, k_cache, v_cache, causal=causal, window_size=window_size) # qk = torch.einsum("bqhd,bkhd->bhqk", q, k_cache_ref) # m = qk.amax(-1, keepdim=True) # s_tmp = torch.exp((qk - m) / math.sqrt(d)) @@ -1637,7 +1813,15 @@ def test_flash_attn_kvcache( # probs = torch.softmax(qk, dim=-1) key_padding_mask = arange < 
cache_seqlens_expanded + (seqlen_new if new_kv else 0) out_ref, _ = attention_ref( - q_ro, k_cache_rep, v_cache_rep, None, key_padding_mask, 0.0, None, causal=causal + q_ro, + k_cache_rep, + v_cache_rep, + None, + key_padding_mask, + 0.0, + None, + causal=causal, + window_size=window_size, ) out_pt, _ = attention_ref( q_ro, @@ -1648,6 +1832,7 @@ def test_flash_attn_kvcache( 0.0, None, causal=causal, + window_size=window_size, upcast=False, reorder_ops=True, ) From 601b4dc48dbe9d87c468daa2b4c0c8388b83753c Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Tue, 26 Sep 2023 22:08:29 -0700 Subject: [PATCH 15/23] Bump to v2.3.0 --- flash_attn/__init__.py | 2 +- training/Dockerfile | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/flash_attn/__init__.py b/flash_attn/__init__.py index f3d32304d..686e8d0f7 100644 --- a/flash_attn/__init__.py +++ b/flash_attn/__init__.py @@ -1,4 +1,4 @@ -__version__ = "2.2.5" +__version__ = "2.3.0" from flash_attn.flash_attn_interface import ( flash_attn_func, diff --git a/training/Dockerfile b/training/Dockerfile index b9eed40c3..f3ee6422c 100644 --- a/training/Dockerfile +++ b/training/Dockerfile @@ -85,11 +85,11 @@ RUN pip install transformers==4.25.1 datasets==2.8.0 pytorch-lightning==1.8.6 tr RUN pip install git+https://github.com/mlcommons/logging.git@2.1.0 # Install FlashAttention -RUN pip install flash-attn==2.2.5 +RUN pip install flash-attn==2.3.0 # Install CUDA extensions for fused dense, layer norm RUN git clone https://github.com/HazyResearch/flash-attention \ - && cd flash-attention && git checkout v2.2.5 \ + && cd flash-attention && git checkout v2.3.0 \ && cd csrc/layer_norm && pip install . && cd ../../ \ && cd csrc/fused_dense_lib && pip install . && cd ../../ \ && cd .. && rm -rf flash-attention From e279bf8ed967c7c3e62380e5ff7f69f4d343275f Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Tue, 3 Oct 2023 16:27:26 -0700 Subject: [PATCH 16/23] [Gen] Accept cache_batch_idx to index into the KV cache --- csrc/flash_attn/flash_api.cpp | 21 +++++++++++++++------ csrc/flash_attn/src/flash.h | 3 +++ csrc/flash_attn/src/flash_fwd_kernel.h | 5 +++-- flash_attn/flash_attn_interface.py | 12 ++++++++++-- tests/test_flash_attn.py | 25 ++++++++++++++++++------- 5 files changed, 49 insertions(+), 17 deletions(-) diff --git a/csrc/flash_attn/flash_api.cpp b/csrc/flash_attn/flash_api.cpp index 91cc370fe..bf8cdcb6c 100644 --- a/csrc/flash_attn/flash_api.cpp +++ b/csrc/flash_attn/flash_api.cpp @@ -1037,13 +1037,14 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size std::vector mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size - const at::Tensor &kcache, // batch_size x seqlen_k x num_heads_k x head_size - const at::Tensor &vcache, // batch_size x seqlen_k x num_heads_k x head_size + const at::Tensor &kcache, // batch_size_c x seqlen_k x num_heads_k x head_size + const at::Tensor &vcache, // batch_size_c x seqlen_k x num_heads_k x head_size c10::optional &k_, // batch_size x seqlen_knew x num_heads_k x head_size c10::optional &v_, // batch_size x seqlen_knew x num_heads_k x head_size c10::optional &seqlens_k_, // batch_size c10::optional &rotary_cos_, // seqlen_ro x (rotary_dim / 2) c10::optional &rotary_sin_, // seqlen_ro x (rotary_dim / 2) + c10::optional &cache_batch_idx_, // indices to index into the KV cache c10::optional &out_, // batch_size x seqlen_q x num_heads x head_size const float softmax_scale, bool is_causal, @@ -1084,6 +1085,7 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x 
num_he const int head_size_og = sizes[3]; const int seqlen_k = kcache.size(1); const int num_heads_k = kcache.size(2); + const int batch_size_c = kcache.size(0); TORCH_CHECK(batch_size > 0, "batch size must be postive"); TORCH_CHECK(head_size_og <= 256, "FlashAttention forward only supports head dimension at most 256"); TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); @@ -1102,8 +1104,8 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he } CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size_og); - CHECK_SHAPE(kcache, batch_size, seqlen_k, num_heads_k, head_size_og); - CHECK_SHAPE(vcache, batch_size, seqlen_k, num_heads_k, head_size_og); + CHECK_SHAPE(kcache, batch_size_c, seqlen_k, num_heads_k, head_size_og); + CHECK_SHAPE(vcache, batch_size_c, seqlen_k, num_heads_k, head_size_og); at::Tensor q_padded, kcache_padded, vcache_padded; if (head_size_og % 8 != 0) { @@ -1229,6 +1231,13 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he params.rotary_dim = 0; } + if (cache_batch_idx_.has_value()) { + auto cache_batch_idx = cache_batch_idx_.value(); + CHECK_DEVICE(cache_batch_idx); + CHECK_CONTIGUOUS(cache_batch_idx); + TORCH_CHECK(cache_batch_idx.scalar_type() == torch::kInt32, "cache_batch_idx must have dtype int32"); + params.cache_batch_idx = reinterpret_cast(cache_batch_idx.data_ptr()); + } // This needs to match with run_mha_fwd_splitkv_dispatch const int block_n = head_size <= 64 ? 256 : (head_size <= 128 ? 128 : 64); const int num_n_blocks = (seqlen_k + block_n - 1) / block_n; @@ -1248,8 +1257,8 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he } auto stream = at::cuda::getCurrentCUDAStream().stream(); - // Only split kernel supports appending to KV cache - run_mha_fwd(params, stream, /*force_split_kernel=*/k_.has_value()); + // Only split kernel supports appending to KV cache, or indexing to the cache with cache_batch_idx + run_mha_fwd(params, stream, /*force_split_kernel=*/k_.has_value() || cache_batch_idx_.has_value()); if (head_size_og % 8 != 0) { out = out.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)}); diff --git a/csrc/flash_attn/src/flash.h b/csrc/flash_attn/src/flash.h index 81f33d9a3..fe0fe3fab 100644 --- a/csrc/flash_attn/src/flash.h +++ b/csrc/flash_attn/src/flash.h @@ -95,6 +95,9 @@ struct Flash_fwd_params : public Qkv_params { void * __restrict__ rotary_cos_ptr; void * __restrict__ rotary_sin_ptr; + // The indices to index into the KV cache. + int *__restrict__ cache_batch_idx; + // The dropout probability (probability of keeping an activation). float p_dropout; // uint32_t p_dropout_in_uint; diff --git a/csrc/flash_attn/src/flash_fwd_kernel.h b/csrc/flash_attn/src/flash_fwd_kernel.h index 312b4dda6..323068e10 100644 --- a/csrc/flash_attn/src/flash_fwd_kernel.h +++ b/csrc/flash_attn/src/flash_fwd_kernel.h @@ -668,9 +668,10 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons const index_t row_offset_q = binfo.q_offset(params.q_batch_stride, params.q_row_stride, bidb) + m_block * kBlockM * params.q_row_stride + bidh * params.q_head_stride; // We move K and V to the last block. - const index_t row_offset_k = binfo.k_offset(params.k_batch_stride, params.k_row_stride, bidb) + const int bidb_cache = params.cache_batch_idx == nullptr ? 
bidb : params.cache_batch_idx[bidb]; + const index_t row_offset_k = binfo.k_offset(params.k_batch_stride, params.k_row_stride, bidb_cache) + (n_block_max - 1) * kBlockN * params.k_row_stride + (bidh / params.h_h_k_ratio) * params.k_head_stride; - const index_t row_offset_v = binfo.k_offset(params.v_batch_stride, params.v_row_stride, bidb) + const index_t row_offset_v = binfo.k_offset(params.v_batch_stride, params.v_row_stride, bidb_cache) + (n_block_max - 1) * kBlockN * params.v_row_stride + (bidh / params.h_h_k_ratio) * params.v_head_stride; Tensor gQ = make_tensor(make_gmem_ptr(reinterpret_cast(params.q_ptr) + row_offset_q), diff --git a/flash_attn/flash_attn_interface.py b/flash_attn/flash_attn_interface.py index eba913ad6..e0444cdae 100644 --- a/flash_attn/flash_attn_interface.py +++ b/flash_attn/flash_attn_interface.py @@ -928,6 +928,7 @@ def flash_attn_with_kvcache( rotary_cos=None, rotary_sin=None, cache_seqlens: Optional[Union[(int, torch.Tensor)]] = None, + cache_batch_idx: Optional[torch.Tensor] = None, softmax_scale=None, causal=False, window_size=(-1, -1), # -1 means infinite context window @@ -978,8 +979,8 @@ def flash_attn_with_kvcache( Arguments: q: (batch_size, seqlen, nheads, headdim) - k_cache: (batch_size, seqlen_cache, nheads_k, headdim) - v_cache: (batch_size, seqlen_cache, nheads_k, headdim) + k_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim) + v_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim) k [optional]: (batch_size, seqlen_new, nheads_k, headdim). If not None, we concatenate k with k_cache, starting at the indices specified by cache_seqlens. v [optional]: (batch_size, seqlen_new, nheads_k, headdim). Similar to k. @@ -988,6 +989,10 @@ def flash_attn_with_kvcache( rotary_sin [optional]: (seqlen_ro, rotary_dim / 2). Similar to rotary_cos. cache_seqlens: int, or (batch_size,), dtype torch.int32. The sequence lengths of the KV cache. + cache_batch_idx: (batch_size,), dtype torch.int32. The indices used to index into the KV cache. + If None, we assume that the batch indices are [0, 1, 2, ..., batch_size - 1]. + If the indices are not distinct, and k and v are provided, the values updated in the cache + might come from any of the duplicate indices. softmax_scale: float. The scaling of QK^T before applying softmax. Default to 1 / sqrt(headdim). causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling). 
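To make the new `cache_batch_idx` argument concrete, here is a minimal usage sketch. It is not part of the patch: all shapes, names, and values below are made up for illustration, and it relies only on the `flash_attn_with_kvcache` signature documented above (a pre-allocated KV-cache pool with more slots than the live batch, with `cache_batch_idx` selecting which slot each sequence reads from and appends to).

```python
# Hypothetical sketch: incremental decoding against a pre-allocated KV-cache pool.
import torch
from flash_attn import flash_attn_with_kvcache

batch, pool_size, nheads, nheads_k, d, max_seqlen = 2, 4, 8, 8, 64, 256  # assumed sizes
device, dtype = "cuda", torch.float16

k_cache = torch.zeros(pool_size, max_seqlen, nheads_k, d, device=device, dtype=dtype)
v_cache = torch.zeros(pool_size, max_seqlen, nheads_k, d, device=device, dtype=dtype)

q = torch.randn(batch, 1, nheads, d, device=device, dtype=dtype)      # one new query token per sequence
k_new = torch.randn(batch, 1, nheads_k, d, device=device, dtype=dtype)
v_new = torch.randn(batch, 1, nheads_k, d, device=device, dtype=dtype)

cache_seqlens = torch.tensor([17, 42], dtype=torch.int32, device=device)  # tokens already in the cache
cache_batch_idx = torch.tensor([3, 0], dtype=torch.int32, device=device)  # pool slots used by this batch

# Appends k_new/v_new into slots 3 and 0 at positions 17 and 42 respectively,
# then attends over the updated cache, all in a single call.
out = flash_attn_with_kvcache(
    q,
    k_cache,
    v_cache,
    k=k_new,
    v=v_new,
    cache_seqlens=cache_seqlens,
    cache_batch_idx=cache_batch_idx,
    causal=True,
)
```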
@@ -1014,6 +1019,8 @@ def flash_attn_with_kvcache( cache_seqlens = torch.full( (k_cache.shape[0],), cache_seqlens, dtype=torch.int32, device=k_cache.device ) + cache_seqlens = maybe_contiguous(cache_seqlens) + cache_batch_idx = maybe_contiguous(cache_batch_idx) out, softmax_lse = flash_attn_cuda.fwd_kvcache( q, k_cache, @@ -1023,6 +1030,7 @@ def flash_attn_with_kvcache( cache_seqlens, rotary_cos, rotary_sin, + cache_batch_idx, None, softmax_scale, causal, diff --git a/tests/test_flash_attn.py b/tests/test_flash_attn.py index 4fe330979..90e589907 100644 --- a/tests/test_flash_attn.py +++ b/tests/test_flash_attn.py @@ -1668,7 +1668,7 @@ def test_flash_attn_splitkv(seqlen_q, seqlen_k, swap_sq_sk, d, causal, local, dt @pytest.mark.parametrize("new_kv", [False, True]) # @pytest.mark.parametrize("new_kv", [True]) @pytest.mark.parametrize("local", [False, True]) -# @pytest.mark.parametrize("local", [True]) +# @pytest.mark.parametrize("local", [False]) @pytest.mark.parametrize("causal", [False, True]) # @pytest.mark.parametrize("causal", [True]) @pytest.mark.parametrize("seqlen_new_eq_seqlen_q", [True, False]) @@ -1677,6 +1677,8 @@ def test_flash_attn_splitkv(seqlen_q, seqlen_k, swap_sq_sk, d, causal, local, dt # @pytest.mark.parametrize("rotary_interleaved", [False]) @pytest.mark.parametrize("rotary_fraction", [0.0, 0.5, 1.0]) # @pytest.mark.parametrize("rotary_fraction", [0.0]) +@pytest.mark.parametrize("has_batch_idx", [False, True]) +# @pytest.mark.parametrize("has_batch_idx", [True]) @pytest.mark.parametrize("d", [32, 59, 64, 80, 96, 128, 160, 192, 224, 256]) # @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192, 224, 256]) # @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192]) @@ -1703,6 +1705,7 @@ def test_flash_attn_kvcache( seqlen_q, seqlen_k, d, + has_batch_idx, rotary_fraction, rotary_interleaved, seqlen_new_eq_seqlen_q, @@ -1721,6 +1724,7 @@ def test_flash_attn_kvcache( # set seed torch.random.manual_seed(0) batch_size = 2 + batch_size_cache = batch_size if not has_batch_idx else batch_size * 2 nheads = 6 # rotary_dim must be a multiple of 16, and must be <= d rotary_dim = math.floor(int(rotary_fraction * d) / 16) * 16 @@ -1734,8 +1738,8 @@ def test_flash_attn_kvcache( v = torch.randn(batch_size, seqlen_new, nheads_k, d, device=device, dtype=dtype) else: k, v = None, None - k_cache = torch.randn(batch_size, seqlen_k, nheads_k, d, device=device, dtype=dtype) - v_cache = torch.randn(batch_size, seqlen_k, nheads_k, d, device=device, dtype=dtype) + k_cache = torch.randn(batch_size_cache, seqlen_k, nheads_k, d, device=device, dtype=dtype) + v_cache = torch.randn(batch_size_cache, seqlen_k, nheads_k, d, device=device, dtype=dtype) cache_seqlens = torch.randint( 0, # If we don't use seqlen_q in the case of causal and rotary, cos/sin won't be long enough @@ -1746,6 +1750,10 @@ def test_flash_attn_kvcache( dtype=torch.int32, device=device, ) + if has_batch_idx: + cache_batch_idx = torch.randperm(batch_size_cache, dtype=torch.int32, device=device)[:batch_size] + else: + cache_batch_idx = None # cache_seqlens = torch.tensor([64], dtype=torch.int32, device=device) if rotary_dim > 0: angle = torch.rand(seqlen_k, rotary_dim // 2, device=device) * 2 * math.pi @@ -1775,8 +1783,8 @@ def test_flash_attn_kvcache( cos, sin = None, None q_ro, k_ro = q, k # k_cache[:, 64:] = -1 - k_cache_ref = k_cache.clone() - v_cache_ref = v_cache.clone() + k_cache_ref = (k_cache if not has_batch_idx else k_cache[cache_batch_idx]).clone() + v_cache_ref = (v_cache if not has_batch_idx else 
v_cache[cache_batch_idx]).clone() arange = rearrange(torch.arange(seqlen_k, device=device), "s -> 1 s") cache_seqlens_expanded = rearrange(cache_seqlens, "b -> b 1") if new_kv: @@ -1796,6 +1804,7 @@ def test_flash_attn_kvcache( cos, sin, cache_seqlens, + cache_batch_idx, causal=causal, window_size=window_size, rotary_interleaved=rotary_interleaved, @@ -1844,8 +1853,10 @@ def test_flash_attn_kvcache( # Check that FlashAttention's numerical error is at most twice the numerical error # of a Pytorch implementation. if new_kv: - assert torch.allclose(k_cache, k_cache_ref, rtol=1e-3, atol=1e-3) - assert torch.equal(v_cache, v_cache_ref) + k_cache_select = k_cache if not has_batch_idx else k_cache[cache_batch_idx] + v_cache_select = v_cache if not has_batch_idx else v_cache[cache_batch_idx] + assert torch.allclose(k_cache_select, k_cache_ref, rtol=1e-3, atol=1e-3) + assert torch.equal(v_cache_select, v_cache_ref) assert (out - out_ref).abs().max().item() <= 3 * (out_pt - out_ref).abs().max().item() + 1e-5 From 21c3b0d8f656049267791c6a6bca0ce3080789ea Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Tue, 3 Oct 2023 19:56:45 -0700 Subject: [PATCH 17/23] Bump to v2.3.1 --- flash_attn/__init__.py | 2 +- training/Dockerfile | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/flash_attn/__init__.py b/flash_attn/__init__.py index 686e8d0f7..19950f89e 100644 --- a/flash_attn/__init__.py +++ b/flash_attn/__init__.py @@ -1,4 +1,4 @@ -__version__ = "2.3.0" +__version__ = "2.3.1" from flash_attn.flash_attn_interface import ( flash_attn_func, diff --git a/training/Dockerfile b/training/Dockerfile index f3ee6422c..ee2a5e79b 100644 --- a/training/Dockerfile +++ b/training/Dockerfile @@ -85,11 +85,11 @@ RUN pip install transformers==4.25.1 datasets==2.8.0 pytorch-lightning==1.8.6 tr RUN pip install git+https://github.com/mlcommons/logging.git@2.1.0 # Install FlashAttention -RUN pip install flash-attn==2.3.0 +RUN pip install flash-attn==2.3.1 # Install CUDA extensions for fused dense, layer norm RUN git clone https://github.com/HazyResearch/flash-attention \ - && cd flash-attention && git checkout v2.3.0 \ + && cd flash-attention && git checkout v2.3.1 \ && cd csrc/layer_norm && pip install . && cd ../../ \ && cd csrc/fused_dense_lib && pip install . && cd ../../ \ && cd .. && rm -rf flash-attention From 5e525a8dc8a473f418c8aaf82f5322eb225d9ee5 Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Tue, 3 Oct 2023 22:18:11 -0700 Subject: [PATCH 18/23] [CI] Use official Pytorch 2.1, add CUDA 11.8 for Pytorch 2.1 --- .github/workflows/publish.yml | 16 +++++++--------- flash_attn/__init__.py | 2 +- setup.py | 3 ++- training/Dockerfile | 4 ++-- 4 files changed, 12 insertions(+), 13 deletions(-) diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml index bb8587be9..6e82bdbbd 100644 --- a/.github/workflows/publish.yml +++ b/.github/workflows/publish.yml @@ -44,7 +44,7 @@ jobs: # manylinux docker image, but I haven't figured out how to install CUDA on manylinux. os: [ubuntu-20.04] python-version: ['3.7', '3.8', '3.9', '3.10', '3.11'] - torch-version: ['1.12.1', '1.13.1', '2.0.1', '2.1.0.dev20230731'] + torch-version: ['1.12.1', '1.13.1', '2.0.1', '2.1.0'] cuda-version: ['11.6.2', '11.7.1', '11.8.0', '12.1.0', '12.2.0'] # We need separate wheels that either uses C++11 ABI (-D_GLIBCXX_USE_CXX11_ABI) or not. # Pytorch wheels currently don't use it, but nvcr images have Pytorch compiled with C++11 ABI. 
@@ -58,7 +58,7 @@ jobs: # Pytorch >= 2.0 only supports Python >= 3.8 - torch-version: '2.0.1' python-version: '3.7' - - torch-version: '2.1.0.dev20230731' + - torch-version: '2.1.0' python-version: '3.7' # Pytorch <= 2.0 only supports CUDA <= 11.8 - torch-version: '1.12.1' @@ -73,17 +73,15 @@ jobs: cuda-version: '12.1.0' - torch-version: '2.0.1' cuda-version: '12.2.0' - # Pytorch >= 2.1 only supports CUDA >= 12.1 - - torch-version: '2.1.0.dev20230731' + # Pytorch >= 2.1 only supports CUDA >= 11.8 + - torch-version: '2.1.0' cuda-version: '11.6.2' - - torch-version: '2.1.0.dev20230731' + - torch-version: '2.1.0' cuda-version: '11.7.1' - - torch-version: '2.1.0.dev20230731' - cuda-version: '11.8.0' # Pytorch >= 2.1 with nvcc 12.1.0 segfaults during compilation, so # we only use CUDA 12.2. setup.py as a special case that will # download the wheel for CUDA 12.2 instead. - - torch-version: '2.1.0.dev20230731' + - torch-version: '2.1.0' cuda-version: '12.1.0' steps: @@ -132,7 +130,7 @@ jobs: # We want to figure out the CUDA version to download pytorch # e.g. we can have system CUDA version being 11.7 but if torch==1.12 then we need to download the wheel from cu116 # This code is ugly, maybe there's a better way to do this. - export TORCH_CUDA_VERSION=$(python -c "import os; minv = {'1.12': 113, '1.13': 116, '2.0': 117, '2.1': 121}[os.environ['MATRIX_TORCH_VERSION']]; maxv = {'1.12': 116, '1.13': 117, '2.0': 118, '2.1': 121}[os.environ['MATRIX_TORCH_VERSION']]; print(max(min(int(os.environ['MATRIX_CUDA_VERSION']), maxv), minv))") + export TORCH_CUDA_VERSION=$(python -c "import os; minv = {'1.12': 113, '1.13': 116, '2.0': 117, '2.1': 118}[os.environ['MATRIX_TORCH_VERSION']]; maxv = {'1.12': 116, '1.13': 117, '2.0': 118, '2.1': 121}[os.environ['MATRIX_TORCH_VERSION']]; print(max(min(int(os.environ['MATRIX_CUDA_VERSION']), maxv), minv))") if [[ ${{ matrix.torch-version }} == *"dev"* ]]; then pip install --no-cache-dir --pre torch==${{ matrix.torch-version }} --index-url https://download.pytorch.org/whl/nightly/cu${TORCH_CUDA_VERSION} else diff --git a/flash_attn/__init__.py b/flash_attn/__init__.py index 19950f89e..3a6f611fe 100644 --- a/flash_attn/__init__.py +++ b/flash_attn/__init__.py @@ -1,4 +1,4 @@ -__version__ = "2.3.1" +__version__ = "2.3.1.post1" from flash_attn.flash_attn_interface import ( flash_attn_func, diff --git a/setup.py b/setup.py index f5e17a4a2..d85b7259e 100644 --- a/setup.py +++ b/setup.py @@ -233,7 +233,8 @@ def get_wheel_url(): # _, cuda_version_raw = get_cuda_bare_metal_version(CUDA_HOME) torch_cuda_version = parse(torch.version.cuda) torch_version_raw = parse(torch.__version__) - if torch_version_raw.major == 2 and torch_version_raw.minor == 1: + # Workaround for nvcc 12.1 segfaults when compiling with Pytorch 2.1 + if torch_version_raw.major == 2 and torch_version_raw.minor == 1 and torch_cuda_version.major == 12: torch_cuda_version = parse("12.2") python_version = f"cp{sys.version_info.major}{sys.version_info.minor}" platform_name = get_platform() diff --git a/training/Dockerfile b/training/Dockerfile index ee2a5e79b..c218cc649 100644 --- a/training/Dockerfile +++ b/training/Dockerfile @@ -85,11 +85,11 @@ RUN pip install transformers==4.25.1 datasets==2.8.0 pytorch-lightning==1.8.6 tr RUN pip install git+https://github.com/mlcommons/logging.git@2.1.0 # Install FlashAttention -RUN pip install flash-attn==2.3.1 +RUN pip install flash-attn==2.3.1.post1 # Install CUDA extensions for fused dense, layer norm RUN git clone https://github.com/HazyResearch/flash-attention \ - && 
cd flash-attention && git checkout v2.3.1 \ + && cd flash-attention && git checkout v2.3.1.post1 \ && cd csrc/layer_norm && pip install . && cd ../../ \ && cd csrc/fused_dense_lib && pip install . && cd ../../ \ && cd .. && rm -rf flash-attention From aa4fd2d16654a2105e92b761df73ed8222a098dd Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Thu, 5 Oct 2023 14:00:45 -0700 Subject: [PATCH 19/23] Clarify that Windows is not supported right now --- README.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/README.md b/README.md index cccaf836d..382fb7fa8 100644 --- a/README.md +++ b/README.md @@ -31,6 +31,8 @@ Please cite and credit FlashAttention if you use it. Requirements: - CUDA 11.6 and above. - PyTorch 1.12 and above. +- Linux. Windows is not supported for now. If you have ideas on how to modify + the code to support Windows, please reach out via Github issue. We recommend the [Pytorch](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/pytorch) From 3a9fe7b0faaa9d648394026c9c20231c07bf999d Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Thu, 5 Oct 2023 14:18:14 -0700 Subject: [PATCH 20/23] Add change log --- README.md | 107 +++++++++++++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 101 insertions(+), 6 deletions(-) diff --git a/README.md b/README.md index 382fb7fa8..c4e704c1f 100644 --- a/README.md +++ b/README.md @@ -31,8 +31,7 @@ Please cite and credit FlashAttention if you use it. Requirements: - CUDA 11.6 and above. - PyTorch 1.12 and above. -- Linux. Windows is not supported for now. If you have ideas on how to modify - the code to support Windows, please reach out via Github issue. +- Linux. Windows is not supported for now. If you have ideas on how to modify the code to support Windows, please reach out via Github issue. We recommend the [Pytorch](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/pytorch) @@ -83,29 +82,35 @@ from flash_attn import flash_attn_qkvpacked_func, flash_attn_func ``` ```python -flash_attn_qkvpacked_func(qkv, dropout_p=0.0, softmax_scale=None, causal=False): +flash_attn_qkvpacked_func(qkv, dropout_p=0.0, softmax_scale=None, causal=False, window_size=(-1, -1)): """dropout_p should be set to 0.0 during evaluation If Q, K, V are already stacked into 1 tensor, this function will be faster than calling flash_attn_func on Q, K, V since the backward pass avoids explicit concatenation of the gradients of Q, K, V. +If window_size != (-1, -1), implements sliding window local attention. Query at position i +will only attend to keys between [i - window_size[0], i + window_size[1]] inclusive. Arguments: qkv: (batch_size, seqlen, 3, nheads, headdim) dropout_p: float. Dropout probability. softmax_scale: float. The scaling of QK^T before applying softmax. Default to 1 / sqrt(headdim). causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling). + window_size: (left, right). If not (-1, -1), implements sliding window local attention. Return: out: (batch_size, seqlen, nheads, headdim). """ ``` ```python -flash_attn_func(q, k, v, dropout_p=0.0, softmax_scale=None, causal=False): +flash_attn_func(q, k, v, dropout_p=0.0, softmax_scale=None, causal=False, window_size=(-1, -1)): """dropout_p should be set to 0.0 during evaluation Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads than Q. Note that the number of heads in Q must be divisible by the number of heads in KV. 
For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head 0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V. +If window_size != (-1, -1), implements sliding window local attention. Query at position i +will only attend to keys between +[i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive. Arguments: q: (batch_size, seqlen, nheads, headdim) @@ -115,15 +120,86 @@ Arguments: softmax_scale: float. The scaling of QK^T before applying softmax. Default to 1 / sqrt(headdim). causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling). + window_size: (left, right). If not (-1, -1), implements sliding window local attention. Return: out: (batch_size, seqlen, nheads, headdim). """ ``` +```python +def flash_attn_with_kvcache( + q, + k_cache, + v_cache, + k=None, + v=None, + rotary_cos=None, + rotary_sin=None, + cache_seqlens: Optional[Union[(int, torch.Tensor)]] = None, + cache_batch_idx: Optional[torch.Tensor] = None, + softmax_scale=None, + causal=False, + window_size=(-1, -1), # -1 means infinite context window + rotary_interleaved=True, +): + """ + If k and v are not None, k_cache and v_cache will be updated *inplace* with the new values from + k and v. This is useful for incremental decoding: you can pass in the cached keys/values from + the previous step, and update them with the new keys/values from the current step, and do + attention with the updated cache, all in 1 kernel. + + Also apply rotary embedding if rotary_cos and rotary_sin are passed in. The key @k will be + rotated by rotary_cos and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc. + If causal or local (i.e., window_size != (-1, -1)), the query @q will be rotated by rotary_cos + and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc. + If not causal and not local, the query @q will be rotated by rotary_cos and rotary_sin at + indices cache_seqlens only (i.e. we consider all tokens in @q to be at position cache_seqlens). + + See tests/test_flash_attn.py::test_flash_attn_kvcache for examples of how to use this function. + + Note: Does not support backward pass. + + Arguments: + q: (batch_size, seqlen, nheads, headdim) + k_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim) + v_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim) + k [optional]: (batch_size, seqlen_new, nheads_k, headdim). If not None, we concatenate + k with k_cache, starting at the indices specified by cache_seqlens. + v [optional]: (batch_size, seqlen_new, nheads_k, headdim). Similar to k. + rotary_cos [optional]: (seqlen_ro, rotary_dim / 2). If not None, we apply rotary embedding + to k and q. Only applicable if k and v are passed in. rotary_dim must be divisible by 16. + rotary_sin [optional]: (seqlen_ro, rotary_dim / 2). Similar to rotary_cos. + cache_seqlens: int, or (batch_size,), dtype torch.int32. The sequence lengths of the + KV cache. + cache_batch_idx: (batch_size,), dtype torch.int32. The indices used to index into the KV cache. + If None, we assume that the batch indices are [0, 1, 2, ..., batch_size - 1]. + If the indices are not distinct, and k and v are provided, the values updated in the cache + might come from any of the duplicate indices. + softmax_scale: float. The scaling of QK^T before applying softmax. + Default to 1 / sqrt(headdim). + causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling). + window_size: (left, right). 
If not (-1, -1), implements sliding window local attention. + rotary_interleaved: bool. Only applicable if rotary_cos and rotary_sin are passed in. + If True, rotary embedding will combine dimensions 0 & 1, 2 & 3, etc. If False, + rotary embedding will combine dimensions 0 & rotary_dim / 2, 1 & rotary_dim / 2 + 1 + (i.e. GPT-NeoX style). + num_splits: int. If > 1, split the key/value into this many chunks along the sequence. + If num_splits == 1, we don't split the key/value. If num_splits == 0, we use a heuristic + to automatically determine the number of splits. + Don't change this unless you know what you are doing. + + Return: + out: (batch_size, seqlen, nheads, headdim). + """ +``` + To see how these functions are used in a multi-head attention layer (which includes QKV projection, output projection), see the MHA [implementation](https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/modules/mha.py). -## Upgrading from FlashAttention (1.x) to FlashAttention-2 +## Changelog + +### 2.0 +Upgrading from FlashAttention (1.x) to FlashAttention-2 These functions have been renamed: - `flash_attn_unpadded_func` -> `flash_attn_varlen_func` @@ -138,7 +214,7 @@ flash_attn_qkvpacked_func(qkv, dropout_p=0.0, softmax_scale=None, causal=False) ```python flash_attn_func(q, k, v, dropout_p=0.0, softmax_scale=None, causal=False) ``` -## Changes in v2.1 (compared to v2.0) +### 2.1 If seqlen_q != seqlen_k and causal=True, the causal mask is aligned to the bottom right corner of the attention matrix, instead of the top-left corner. @@ -167,6 +243,25 @@ v2.1: 1 1 If the row of the mask is all zero, the output will be zero. +### 2.2 + +Optimize for inference (iterative decoding) when query has very small sequence +length (e.g., query sequence length = 1). The bottleneck here is to load KV +cache as fast as possible, and we split the loading across different thread +blocks, with a separate kernel to combine results. + +See the function `flash_attn_with_kvcache` with more features for inference +(perform rotary embedding, updating KV cache inplace). + +Thanks to the xformers team, and in particular Daniel Haziza, for this +collaboration. + +### 2.3 + +Implement sliding window attention (i.e., local attention). Thanks to [Mistral +AI](https://mistral.ai/) and in particular Timothée Lacroix for this +contribution. Sliding window was used in the [Mistral 7B](https://mistral.ai/news/announcing-mistral-7b/) model. + ## Performance We present expected speedup (combined forward + backward pass) and memory savings from using FlashAttention against PyTorch standard attention, depending on sequence length, on different GPUs (speedup depends on memory bandwidth - we see more speedup on slower GPU memory). 
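To make the `window_size` convention in the docstrings above easier to verify, here is a small reference sketch. It is not part of the patch and ignores padding masks; `local_attention_mask` is a made-up name, and the only facts it encodes are taken from the documentation above: query `i` may attend to keys `j` with `i + seqlen_k - seqlen_q - window_size[0] <= j <= i + seqlen_k - seqlen_q + window_size[1]`, where `-1` means that side of the window is unbounded.

```python
import torch


def local_attention_mask(seqlen_q, seqlen_k, window_size=(-1, -1), device="cpu"):
    """Return a (seqlen_q, seqlen_k) bool mask where True means "masked out"."""
    row = torch.arange(seqlen_q, device=device).unsqueeze(-1)  # query position i
    col = torch.arange(seqlen_k, device=device)                # key position j
    shift = seqlen_k - seqlen_q                                # bottom-right alignment (see the 2.1 note)
    mask = torch.zeros(seqlen_q, seqlen_k, dtype=torch.bool, device=device)
    if window_size[1] >= 0:  # keys beyond the right edge of the window
        mask |= col > row + shift + window_size[1]
    if window_size[0] >= 0:  # keys beyond the left edge of the window
        mask |= col < row + shift - window_size[0]
    return mask


# Causal masking is the special case window_size = (-1, 0).
assert local_attention_mask(4, 4, (-1, 0)).equal(torch.triu(torch.ones(4, 4, dtype=torch.bool), 1))
```

A reference implementation would apply this mask with `scores.masked_fill_(mask, float("-inf"))` before the softmax, which is essentially what the updated tests do via `construct_local_mask`.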
From 5a834254428fbdc2371ffb23a9cde40a287a7ff6 Mon Sep 17 00:00:00 2001
From: Tri Dao
Date: Sun, 8 Oct 2023 16:26:33 -0700
Subject: [PATCH 21/23] Change constexpr int to constexpr static int

---
 README.md                           |  8 +++----
 .../src/flash_bwd_launch_template.h | 16 +++++++-------
 .../src/flash_fwd_launch_template.h | 22 +++++++++----------
 3 files changed, 23 insertions(+), 23 deletions(-)

diff --git a/README.md b/README.md
index c4e704c1f..56ec4dfa3 100644
--- a/README.md
+++ b/README.md
@@ -198,7 +198,7 @@ includes QKV projection, output projection), see the MHA [implementation](https:
 
 ## Changelog
 
-### 2.0
+### 2.0: Complete rewrite, 2x faster
 Upgrading from FlashAttention (1.x) to FlashAttention-2
 
 These functions have been renamed:
@@ -214,7 +214,7 @@ flash_attn_qkvpacked_func(qkv, dropout_p=0.0, softmax_scale=None, causal=False)
 ```python
 flash_attn_func(q, k, v, dropout_p=0.0, softmax_scale=None, causal=False)
 ```
 
-### 2.1
+### 2.1: Change behavior of causal flag
 If seqlen_q != seqlen_k and causal=True, the causal mask is aligned to the
 bottom right corner of the attention matrix, instead of the top-left corner.
@@ -243,7 +243,7 @@ v2.1:
     1 1
 If the row of the mask is all zero, the output will be zero.
 
-### 2.2
+### 2.2: Optimize for inference
 
 Optimize for inference (iterative decoding) when query has very small
 sequence length (e.g., query sequence length = 1). The bottleneck here is to load KV
@@ -256,7 +256,7 @@ See the function `flash_attn_with_kvcache` with more features for inference
 
 Thanks to the xformers team, and in particular Daniel Haziza, for this
 collaboration.
 
-### 2.3
+### 2.3: Local (i.e., sliding window) attention
 Implement sliding window attention (i.e., local attention). Thanks to
 [Mistral AI](https://mistral.ai/) and in particular Timothée Lacroix for this
diff --git a/csrc/flash_attn/src/flash_bwd_launch_template.h b/csrc/flash_attn/src/flash_bwd_launch_template.h
index a2a41679d..744c1d53f 100644
--- a/csrc/flash_attn/src/flash_bwd_launch_template.h
+++ b/csrc/flash_attn/src/flash_bwd_launch_template.h
@@ -137,7 +137,7 @@ void run_flash_bwd(Flash_bwd_params &params, cudaStream_t stream, const bool con
 
 template
 void run_mha_bwd_hdim32(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {
-    constexpr int Headdim = 32;
+    constexpr static int Headdim = 32;
     int device;
     cudaGetDevice(&device);
     int max_smem_per_block;
@@ -158,7 +158,7 @@ void run_mha_bwd_hdim32(Flash_bwd_params &params, cudaStream_t stream, const boo
 
 template
 void run_mha_bwd_hdim64(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {
-    constexpr int Headdim = 64;
+    constexpr static int Headdim = 64;
     int device;
     cudaGetDevice(&device);
     int max_smem_per_block;
@@ -201,7 +201,7 @@ void run_mha_bwd_hdim64(Flash_bwd_params &params, cudaStream_t stream, const boo
 
 template
 void run_mha_bwd_hdim96(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {
-    constexpr int Headdim = 96;
+    constexpr static int Headdim = 96;
     int device;
     cudaGetDevice(&device);
     int max_smem_per_block;
@@ -228,7 +228,7 @@ void run_mha_bwd_hdim96(Flash_bwd_params &params, cudaStream_t stream, const boo
 
 template
 void run_mha_bwd_hdim128(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {
-    constexpr int Headdim = 128;
+    constexpr static int Headdim = 128;
     int device;
     cudaGetDevice(&device);
     int max_smem_per_block;
@@ -264,7 +264,7 @@ void run_mha_bwd_hdim128(Flash_bwd_params &params, cudaStream_t stream, const bo
 
 template
 void run_mha_bwd_hdim160(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {
-    constexpr int Headdim = 160;
+    constexpr static int Headdim = 160;
     int device;
     cudaGetDevice(&device);
     int max_smem_per_block;
@@ -281,7 +281,7 @@ void run_mha_bwd_hdim160(Flash_bwd_params &params, cudaStream_t stream, const bo
 
 template
 void run_mha_bwd_hdim192(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {
-    constexpr int Headdim = 192;
+    constexpr static int Headdim = 192;
     int device;
     cudaGetDevice(&device);
     int max_smem_per_block;
@@ -298,7 +298,7 @@ void run_mha_bwd_hdim192(Flash_bwd_params &params, cudaStream_t stream, const bo
 
 template
 void run_mha_bwd_hdim224(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {
-    constexpr int Headdim = 224;
+    constexpr static int Headdim = 224;
     BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
         run_flash_bwd, Is_dropout>(params, stream, configure);
     });
@@ -306,7 +306,7 @@ void run_mha_bwd_hdim224(Flash_bwd_params &params, cudaStream_t stream, const bo
 
 template
 void run_mha_bwd_hdim256(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {
-    constexpr int Headdim = 256;
+    constexpr static int Headdim = 256;
     int device;
     cudaGetDevice(&device);
     int max_smem_per_block;
diff --git a/csrc/flash_attn/src/flash_fwd_launch_template.h b/csrc/flash_attn/src/flash_fwd_launch_template.h
index fbf3cda22..4a1192780 100644
--- a/csrc/flash_attn/src/flash_fwd_launch_template.h
+++ b/csrc/flash_attn/src/flash_fwd_launch_template.h
@@ -104,7 +104,7 @@ void run_flash_splitkv_fwd(Flash_fwd_params &params, cudaStream_t stream) {
     // We want kBlockM to be as small as possible for more parallelism.
     // With 128 threads we can load 512 elements at a time, so if headdim is divisible by 128, kBlockM = 4.
     // If headdim is divisible by 64, then we set kBlockM = 8, etc.
-    constexpr int kBlockM = Kernel_traits::kHeadDim % 128 == 0 ? 4 : (Kernel_traits::kHeadDim % 64 == 0 ? 8 : 16);
+    constexpr static int kBlockM = Kernel_traits::kHeadDim % 128 == 0 ? 4 : (Kernel_traits::kHeadDim % 64 == 0 ? 8 : 16);
     dim3 grid_combine((params.b * params.h * params.seqlen_q + kBlockM - 1) / kBlockM);
     BOOL_SWITCH(is_even_K, IsEvenKConst, [&] {
         if (params.num_splits <= 2) {
@@ -129,17 +129,17 @@ void run_flash_splitkv_fwd(Flash_fwd_params &params, cudaStream_t stream) {
 
 template
 void run_mha_fwd_splitkv_dispatch(Flash_fwd_params &params, cudaStream_t stream) {
-    constexpr int kBlockM = 64; // Fixed for all head dimensions
+    constexpr static int kBlockM = 64; // Fixed for all head dimensions
     // TD [2023-08-28]: nvcc segfaults for headdim 96 with block size 64 x 256,
     // and for headdim 192 with block size 64 x 128.
     // Also for headdim 160 with block size 64 x 128 after the rotary addition.
-    constexpr int kBlockN = Headdim <= 64 ? 256 : (Headdim <= 128 ? 128 : 64);
+    constexpr static int kBlockN = Headdim <= 64 ? 256 : (Headdim <= 128 ? 128 : 64);
     run_flash_splitkv_fwd>(params, stream);
 }
 
 template
 void run_mha_fwd_hdim32(Flash_fwd_params &params, cudaStream_t stream) {
-    constexpr int Headdim = 32;
+    constexpr static int Headdim = 32;
     BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
         BOOL_SWITCH(params.is_causal, Is_causal, [&] {
             run_flash_fwd, Is_dropout, Is_causal>(params, stream);
@@ -149,7 +149,7 @@ void run_mha_fwd_hdim32(Flash_fwd_params &params, cudaStream_t stream) {
 
 template
 void run_mha_fwd_hdim64(Flash_fwd_params &params, cudaStream_t stream) {
-    constexpr int Headdim = 64;
+    constexpr static int Headdim = 64;
     BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
         BOOL_SWITCH(params.is_causal, Is_causal, [&] {
             if constexpr(!Is_dropout) {
@@ -171,7 +171,7 @@ void run_mha_fwd_hdim64(Flash_fwd_params &params, cudaStream_t stream) {
 
 template
 void run_mha_fwd_hdim96(Flash_fwd_params &params, cudaStream_t stream) {
-    constexpr int Headdim = 96;
+    constexpr static int Headdim = 96;
     auto dprops = at::cuda::getCurrentDeviceProperties();
     bool is_sm8x = dprops->major == 8 && dprops->minor > 0;
     BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
@@ -197,7 +197,7 @@ void run_mha_fwd_hdim96(Flash_fwd_params &params, cudaStream_t stream) {
 
 template
 void run_mha_fwd_hdim128(Flash_fwd_params &params, cudaStream_t stream) {
-    constexpr int Headdim = 128;
+    constexpr static int Headdim = 128;
     auto dprops = at::cuda::getCurrentDeviceProperties();
     bool is_sm8x = dprops->major == 8 && dprops->minor > 0;
     BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
@@ -234,7 +234,7 @@ void run_mha_fwd_hdim128(Flash_fwd_params &params, cudaStream_t stream) {
 
 template
 void run_mha_fwd_hdim160(Flash_fwd_params &params, cudaStream_t stream) {
-    constexpr int Headdim = 160;
+    constexpr static int Headdim = 160;
     auto dprops = at::cuda::getCurrentDeviceProperties();
     bool is_sm8x = dprops->major == 8 && dprops->minor > 0;
     BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
@@ -264,7 +264,7 @@ void run_mha_fwd_hdim160(Flash_fwd_params &params, cudaStream_t stream) {
 
 template
 void run_mha_fwd_hdim192(Flash_fwd_params &params, cudaStream_t stream) {
-    constexpr int Headdim = 192;
+    constexpr static int Headdim = 192;
     BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
         BOOL_SWITCH(params.is_causal, Is_causal, [&] {
             if constexpr(!Is_dropout) {
@@ -283,7 +283,7 @@ void run_mha_fwd_hdim192(Flash_fwd_params &params, cudaStream_t stream) {
 
 template
 void run_mha_fwd_hdim224(Flash_fwd_params &params, cudaStream_t stream) {
-    constexpr int Headdim = 224;
+    constexpr static int Headdim = 224;
     int device;
     cudaGetDevice(&device);
     int max_smem_per_block;
@@ -309,7 +309,7 @@ void run_mha_fwd_hdim224(Flash_fwd_params &params, cudaStream_t stream) {
 
 template
 void run_mha_fwd_hdim256(Flash_fwd_params &params, cudaStream_t stream) {
-    constexpr int Headdim = 256;
+    constexpr static int Headdim = 256;
     int device;
     cudaGetDevice(&device);
     int max_smem_per_sm, max_smem_per_block;

From 7f31e7c16a58227e04e0a7aed9ca0066ec8126fe Mon Sep 17 00:00:00 2001
From: Tri Dao
Date: Sun, 8 Oct 2023 17:21:29 -0700
Subject: [PATCH 22/23] Bump to v2.3.2

---
 README.md              | 2 +-
 flash_attn/__init__.py | 2 +-
 training/Dockerfile    | 4 ++--
 3 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/README.md b/README.md
index 56ec4dfa3..582f1154a 100644
--- a/README.md
+++ b/README.md
@@ -31,7 +31,7 @@ Please cite and credit FlashAttention if you use it.
 Requirements:
 - CUDA 11.6 and above.
 - PyTorch 1.12 and above.
-- Linux. Windows is not supported for now. If you have ideas on how to modify the code to support Windows, please reach out via Github issue.
+- Linux. Might work for Windows starting v2.3.2 (we've seen a few positive [reports](https://github.com/Dao-AILab/flash-attention/issues/595)) but Windows compilation still requires more testing. If you have ideas on how to set up prebuilt CUDA wheels for Windows, please reach out via Github issue.
 
 We recommend the
 [Pytorch](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/pytorch)
diff --git a/flash_attn/__init__.py b/flash_attn/__init__.py
index 3a6f611fe..971ebc6ba 100644
--- a/flash_attn/__init__.py
+++ b/flash_attn/__init__.py
@@ -1,4 +1,4 @@
-__version__ = "2.3.1.post1"
+__version__ = "2.3.2"
 
 from flash_attn.flash_attn_interface import (
     flash_attn_func,
diff --git a/training/Dockerfile b/training/Dockerfile
index c218cc649..807e5787a 100644
--- a/training/Dockerfile
+++ b/training/Dockerfile
@@ -85,11 +85,11 @@ RUN pip install transformers==4.25.1 datasets==2.8.0 pytorch-lightning==1.8.6 tr
 RUN pip install git+https://github.com/mlcommons/logging.git@2.1.0
 
 # Install FlashAttention
-RUN pip install flash-attn==2.3.1.post1
+RUN pip install flash-attn==2.3.2
 
 # Install CUDA extensions for fused dense, layer norm
 RUN git clone https://github.com/HazyResearch/flash-attention \
-    && cd flash-attention && git checkout v2.3.1.post1 \
+    && cd flash-attention && git checkout v2.3.2 \
     && cd csrc/layer_norm && pip install . && cd ../../ \
     && cd csrc/fused_dense_lib && pip install . && cd ../../ \
     && cd .. && rm -rf flash-attention

From 02ac572f3ffc4f402e4183aaa6824b45859d3ed3 Mon Sep 17 00:00:00 2001
From: Tri Dao
Date: Thu, 12 Oct 2023 10:14:58 -0700
Subject: [PATCH 23/23] Clarify inference README is a placeholder

---
 examples/inference/README.md | 1 +
 1 file changed, 1 insertion(+)

diff --git a/examples/inference/README.md b/examples/inference/README.md
index 695f04b1c..267db86e2 100644
--- a/examples/inference/README.md
+++ b/examples/inference/README.md
@@ -1,2 +1,3 @@
 # Example of LLM inference using FlashAttention
 
+Example script of using FlashAttention for inference coming soon.
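
A note on the 2.1 and 2.3 changelog entries touched by PATCH 21/23 above: the sketch below shows how the bottom-right-aligned causal mask and the sliding-window (local) attention option surface in the public API. The tensor shapes, dtypes, and the window width are illustrative assumptions, not values taken from these patches.

```python
import torch
from flash_attn import flash_attn_func

batch, nheads, headdim = 2, 16, 64
seqlen_q, seqlen_k = 1, 128  # seqlen_q != seqlen_k, as in iterative decoding
q = torch.randn(batch, seqlen_q, nheads, headdim, device="cuda", dtype=torch.float16)
k = torch.randn(batch, seqlen_k, nheads, headdim, device="cuda", dtype=torch.float16)
v = torch.randn(batch, seqlen_k, nheads, headdim, device="cuda", dtype=torch.float16)

# Since 2.1, with causal=True the mask is aligned to the bottom-right corner,
# so this single query row attends to all 128 keys.
out = flash_attn_func(q, k, v, causal=True)

# Since 2.3, window_size=(left, right) restricts attention to a sliding window;
# (64, 0) keeps only keys within 64 positions to the left of the query here.
out_local = flash_attn_func(q, k, v, causal=True, window_size=(64, 0))
```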
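
Until the inference example script promised in PATCH 23/23 lands, here is a minimal sketch of the iterative-decoding pattern that `flash_attn_with_kvcache` (the 2.2 inference path named in the changelog) is built for. The cache size, the random stand-in tensors, and the fixed 8-step loop are placeholder assumptions; a real model would run this inside each attention layer with its own KV cache and projections.

```python
import torch
from flash_attn import flash_attn_with_kvcache

batch, nheads, headdim = 2, 16, 64
max_seqlen = 256
device, dtype = "cuda", torch.float16

# Pre-allocated KV cache (one per attention layer in a real model).
k_cache = torch.zeros(batch, max_seqlen, nheads, headdim, device=device, dtype=dtype)
v_cache = torch.zeros(batch, max_seqlen, nheads, headdim, device=device, dtype=dtype)

# Prefill: process the whole prompt once (seqlen_q = prompt length).
prompt_len = 32
q = torch.randn(batch, prompt_len, nheads, headdim, device=device, dtype=dtype)
k = torch.randn(batch, prompt_len, nheads, headdim, device=device, dtype=dtype)
v = torch.randn(batch, prompt_len, nheads, headdim, device=device, dtype=dtype)
cache_seqlens = torch.zeros(batch, dtype=torch.int32, device=device)
out = flash_attn_with_kvcache(q, k_cache, v_cache, k, v,
                              cache_seqlens=cache_seqlens, causal=True)
cache_seqlens += prompt_len

# Decoding: one new token per step (seqlen_q = 1); the new k/v are appended
# to the cache inside the same kernel call.
for _ in range(8):
    q1 = torch.randn(batch, 1, nheads, headdim, device=device, dtype=dtype)
    k1 = torch.randn(batch, 1, nheads, headdim, device=device, dtype=dtype)
    v1 = torch.randn(batch, 1, nheads, headdim, device=device, dtype=dtype)
    out = flash_attn_with_kvcache(q1, k_cache, v_cache, k1, v1,
                                  cache_seqlens=cache_seqlens, causal=True)
    cache_seqlens += 1
```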