Apply rewrite for normal attention and MQA
Fixes a bug introduced in mlc-ai#1052, where using the `--use-flash-attn-mqa` flag on a model that doesn't use MQA prevented CUTLASS attention from being used at all.
Lunderberg committed Oct 27, 2023
1 parent 24f795e commit f43dd23
Showing 1 changed file with 3 additions and 1 deletion.
mlc_llm/core.py

@@ -454,7 +454,9 @@ def mod_transform_before_build(
         has_cutlass = tvm.get_global_func("relax.ext.cutlass", True)

         if has_cutlass and not args.no_cutlass_attn:
-            mod = rewrite_attention(use_flash_mqa=args.use_flash_attn_mqa)(mod)
+            if args.use_flash_attn_mqa:
+                mod = rewrite_attention(use_flash_mqa=True)(mod)
+            mod = rewrite_attention(use_flash_mqa=False)(mod)
             patterns += get_patterns_with_prefix("cutlass.attention")

         if has_cutlass and not args.no_cutlass_norm:
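For context on why both calls are needed: the MQA-specialized rewrite only matches attention expressed in MQA form, so the old single, flag-controlled call rewrote nothing on a plain multi-head-attention model, and the subsequent `cutlass.attention` pattern matching had nothing to offload. Below is a minimal toy sketch of that dispatch logic; it is illustration only, not the real TVM pattern-rewrite machinery, and `rewrite_attention_toy`, the module dict, and the operator names are invented for this example.

# Toy stand-in for mlc_llm's rewrite_attention: the rewrite fires only for the
# attention "kind" it was specialized for (all names here are hypothetical).
def rewrite_attention_toy(module, use_flash_mqa):
    target = "mqa_attention" if use_flash_mqa else "mha_attention"
    return {
        kind: ("cutlass.attention" if kind == target else op)
        for kind, op in module.items()
    }

# A model that uses ordinary multi-head attention, i.e. no MQA.
model = {"mha_attention": "naive_attention"}

# Old behavior: a single rewrite parameterized by the flag.  With
# --use-flash-attn-mqa set, nothing matches, so no CUTLASS offload happens.
old = rewrite_attention_toy(model, use_flash_mqa=True)
assert old == {"mha_attention": "naive_attention"}

# New behavior: apply the MQA rewrite first (when requested), then always the
# normal rewrite, so regular attention is still lowered to CUTLASS.
new = rewrite_attention_toy(model, use_flash_mqa=True)
new = rewrite_attention_toy(new, use_flash_mqa=False)
assert new == {"mha_attention": "cutlass.attention"}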
