Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve the mixed chunk prefill by lanuch two kernels #2811

Draft
wants to merge 16 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions python/sglang/srt/layers/attention/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,18 @@ def forward_decode(
"""Run a forward for decode."""
raise NotImplementedError()

def forward_mixed(
self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
layer: RadixAttention,
forward_batch: ForwardBatch,
save_kv_cache: bool = True,
):
"""Run a forward for mixed prefill & decode."""
raise NotImplementedError()

def forward_extend(
self,
q: torch.Tensor,
Expand Down
253 changes: 245 additions & 8 deletions python/sglang/srt/layers/attention/flashinfer_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,14 @@ class PrefillMetadata:
extend_no_prefix: bool


@dataclass
class MixedMetadata:
prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper]
decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper]
use_ragged: bool
extend_no_prefix: bool


class FlashInferAttnBackend(AttentionBackend):
"""Flashinfer attention kernels."""

Expand Down Expand Up @@ -91,6 +99,7 @@ def __init__(self, model_runner: ModelRunner):
device=model_runner.device,
)
max_bs = model_runner.req_to_token_pool.size
print(f"max batch size = {max_bs}")
self.kv_indptr = [
torch.zeros((max_bs + 1,), dtype=torch.int32, device=model_runner.device)
for _ in range(self.num_wrappers)
Expand Down Expand Up @@ -138,11 +147,16 @@ def __init__(self, model_runner: ModelRunner):
)

# Other metadata
self.forward_metadata: Union[PrefillMetadata, DecodeMetadata] = None
self.forward_metadata: Union[PrefillMetadata, DecodeMetadata, MixedMetadata] = (
None
)
self.decode_cuda_graph_metadata = {}
self.prefill_cuda_graph_metadata = {}

def init_forward_metadata(self, forward_batch: ForwardBatch):
print(
f"mode={forward_batch.forward_mode},batch_size={forward_batch.batch_size}, req_pool_indices={forward_batch.req_pool_indices}, seq_lens={forward_batch.seq_lens},decode_start_idx={forward_batch.decode_start_idx}"
)
if forward_batch.forward_mode.is_decode():
self.indices_updater_decode.update(
forward_batch.req_pool_indices,
Expand All @@ -153,6 +167,84 @@ def init_forward_metadata(self, forward_batch: ForwardBatch):
spec_info=forward_batch.spec_info,
)
self.forward_metadata = DecodeMetadata(self.decode_wrappers)
elif forward_batch.forward_mode.is_mixed():
# Part 0: prepare
extend_bs = forward_batch.decode_start_idx
assert extend_bs > 0, f"extent_bs = {extend_bs}"
req_pool_indices_extend, req_pool_indices_decode = (
forward_batch.req_pool_indices[:extend_bs],
forward_batch.req_pool_indices[extend_bs:],
)
seq_lens_extend, seq_lens_decode = (
forward_batch.seq_lens[:extend_bs],
forward_batch.seq_lens[extend_bs:],
)
seq_lens_sum_extend, seq_lens_sum_decode = (
forward_batch.seq_lens[:extend_bs].sum().item(),
forward_batch.seq_lens[extend_bs:].sum().item(),
)
assert (
seq_lens_sum_extend + seq_lens_sum_decode == forward_batch.seq_lens_sum
), f"{seq_lens_sum_extend} + {seq_lens_sum_ex} != {forward_batch.seq_lens_sum}"
prefix_lens_extend = forward_batch.extend_prefix_lens[:extend_bs]

encoder_lens_extend = (
forward_batch.encoder_lens[:extend_bs]
if forward_batch.encoder_lens is not None
else None
)
encoder_lens_decode = (
(
forward_batch.encoder_lens[extend_bs:]
if forward_batch.encoder_lens is not None
else None
),
)

# Part1: Prefill
decode_running_bs = (
forward_batch.batch_size - forward_batch.decode_start_idx
)
extend_tokens = forward_batch.extend_num_tokens - decode_running_bs
if extend_tokens >= 4096 and self.num_wrappers == 1:
use_ragged = True
extend_no_prefix = not any(
forward_batch.extend_prefix_lens_cpu[:extend_bs]
)
else:
use_ragged = False
extend_no_prefix = False

self.indices_updater_prefill.update(
req_pool_indices_extend,
seq_lens_extend,
seq_lens_sum_extend,
prefix_lens_extend,
prefill_wrappers=self.prefill_wrappers_paged,
use_ragged=use_ragged,
encoder_lens=encoder_lens_extend,
spec_info=None,
)

# Part2: Decode
self.indices_updater_decode.decode_start_idx = extend_bs
self.indices_updater_decode.update(
req_pool_indices_decode,
seq_lens_decode,
seq_lens_sum_decode,
decode_wrappers=self.decode_wrappers,
encoder_lens=encoder_lens_decode,
spec_info=forward_batch.spec_info,
)
# since the `decode_start_idx` only used in the mixed mode
# Manually update it so we do not need change the update method

self.forward_metadata = MixedMetadata(
self.prefill_wrappers_paged,
self.decode_wrappers,
use_ragged,
extend_no_prefix,
)
elif forward_batch.forward_mode.is_draft_extend():
self.indices_updater_prefill.update(
forward_batch.req_pool_indices,
Expand Down Expand Up @@ -338,6 +430,9 @@ def forward_extend(
forward_batch: ForwardBatch,
save_kv_cache=True,
):
if forward_batch.forward_mode.is_mixed():
return self.forward_mixed(q, k, v, layer, forward_batch, save_kv_cache)

prefill_wrapper_paged = self.forward_metadata.prefill_wrappers[
self._get_wrapper_idx(layer)
]
Expand Down Expand Up @@ -397,6 +492,118 @@ def forward_extend(

return o.view(-1, layer.tp_q_head_num * layer.head_dim)

def forward_mixed(
self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
layer: RadixAttention,
forward_batch: ForwardBatch,
save_kv_cache=True,
):

# Part0: split the prefill and decode
decode_running_bs = forward_batch.batch_size - forward_batch.decode_start_idx
# each decode request only use one token
extend_tokens = forward_batch.extend_num_tokens - decode_running_bs

print(
f"mixed forward: q={q.shape}, k={k.shape}, extend_tokens={extend_tokens}, decode_running_bs={decode_running_bs}"
)

q_extend, q_decode = q[:extend_tokens], q[extend_tokens:]
k_extend, k_decode = k[:extend_tokens], k[extend_tokens:]
v_extend, v_decode = v[:extend_tokens], v[extend_tokens:]

print(f"q_extend={q_extend.shape}, q_decode={q_decode.shape}")
print(f"k_extend={k_extend.shape}, k_decode={k_decode.shape}")

out_cache_loc_extend = forward_batch.out_cache_loc[:extend_tokens]
out_cache_loc_decode = forward_batch.out_cache_loc[extend_tokens:]

## Part1: Prefill
prefill_wrapper_paged = self.forward_metadata.prefill_wrappers[
self._get_wrapper_idx(layer)
]
cache_loc_extend = (
out_cache_loc_extend
if not layer.is_cross_attention
else forward_batch.encoder_out_cache_loc[:extend_tokens]
)

if not self.forward_metadata.use_ragged:
if k_extend is not None:
assert v_extend is not None
if save_kv_cache:
forward_batch.token_to_kv_pool.set_kv_buffer(
layer, cache_loc_extend, k_extend, v_extend
)

o = prefill_wrapper_paged.forward(
q_extend.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
causal=not layer.is_cross_attention,
sm_scale=layer.scaling,
window_left=layer.sliding_window_size,
logits_soft_cap=layer.logit_cap,
)
else:
o1, s1 = self.prefill_wrapper_ragged.forward_return_lse(
q_extend.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
k_extend.contiguous().view(-1, layer.tp_k_head_num, layer.head_dim),
v_extend.contiguous().view(-1, layer.tp_v_head_num, layer.head_dim),
causal=True,
sm_scale=layer.scaling,
logits_soft_cap=layer.logit_cap,
)

if self.forward_metadata.extend_no_prefix:
o = o1
else:
o2, s2 = prefill_wrapper_paged.forward_return_lse(
q_extend.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
causal=False,
sm_scale=layer.scaling,
logits_soft_cap=layer.logit_cap,
)

o, _ = merge_state(o1, s1, o2, s2)

if save_kv_cache:
forward_batch.token_to_kv_pool.set_kv_buffer(
layer, cache_loc_extend, k_extend, v_extend
)

o = o.view(-1, layer.tp_q_head_num * layer.head_dim)

## Part2: decode
decode_wrapper = self.forward_metadata.decode_wrappers[
self._get_wrapper_idx(layer)
]
cache_loc_decode = (
out_cache_loc_decode
if not layer.is_cross_attention
else forward_batch.encoder_out_cache_loc[extend_tokens:]
)

if k_decode is not None:
assert v_decode is not None
if save_kv_cache:
forward_batch.token_to_kv_pool.set_kv_buffer(
layer, cache_loc_decode, k_decode, v_decode
)

o_decode = decode_wrapper.forward(
q_decode.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
sm_scale=layer.scaling,
logits_soft_cap=layer.logit_cap,
)
o_decode = o_decode.view(-1, layer.tp_q_head_num * layer.head_dim)

return torch.cat((o, o_decode), dim=0)

def forward_decode(
self,
q: torch.Tensor,
Expand Down Expand Up @@ -460,6 +667,9 @@ def __init__(self, model_runner: ModelRunner, attn_backend: AttentionBackend):
self.sliding_window_size = model_runner.sliding_window_size
self.attn_backend = attn_backend

# mixed chunk prefill
self.decode_start_idx = 0

# Buffers and wrappers
self.kv_indptr = attn_backend.kv_indptr
self.kv_last_page_len = attn_backend.kv_last_page_len
Expand Down Expand Up @@ -580,22 +790,33 @@ def call_begin_forward(
kv_start_idx: torch.Tensor,
spec_info: Optional[SpecInfo],
):
shift_pos = self.decode_start_idx
if spec_info is None:
bs = len(req_pool_indices)
print("====Decode call_begin_forward====")
print(
f"req_pool_indices={req_pool_indices}, paged_kernel_lens={paged_kernel_lens}, decode_start_idx={self.decode_start_idx}, shift_pos = {shift_pos}"
)
print(f"kv_indptr before modification:, {kv_indptr[:20]}")
kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0)
kv_indptr = kv_indptr[: bs + 1]
kv_indptr[0] = 0
# kv_indptr = kv_indptr[shift_pos : bs + 1 + shift_pos]

kv_indices = torch.empty(
paged_kernel_lens_sum, dtype=torch.int32, device="cuda"
)
create_flashinfer_kv_indices_triton[(bs,)](
self.req_to_token,
req_pool_indices,
paged_kernel_lens,
kv_indptr,
kv_indptr[: bs + 1],
kv_start_idx,
kv_indices,
self.req_to_token.shape[1],
)
print(f"kv_indptr after modification:, {kv_indptr[:20]}")
print(f"kv_indices={kv_indices}")
print("========")
else:
bs, kv_indices, kv_indptr = spec_info.generate_attn_arg_decode(
req_pool_indices,
Expand All @@ -605,16 +826,19 @@ def call_begin_forward(

wrapper.end_forward()
wrapper.begin_forward(
kv_indptr,
kv_indptr[: bs + 1],
kv_indices,
self.kv_last_page_len[:bs],
self.kv_last_page_len[shift_pos : bs + shift_pos],
self.num_qo_heads,
self.num_kv_heads,
self.head_dim,
1,
data_type=self.data_type,
q_data_type=self.q_data_type,
)
print(
f"decode kv_last_page_len = {self.kv_last_page_len[shift_pos:bs+shift_pos]}"
)


class FlashInferIndicesUpdaterPrefill:
Expand Down Expand Up @@ -792,16 +1016,22 @@ def call_begin_forward(
bs = len(req_pool_indices)
if spec_info is None:
# Normal extend
print("****Prefill call_begin_forward****")
print(
f"req_pool_indices={req_pool_indices}, paged_kernel_lens={paged_kernel_lens}"
)
print(f"kv_indptr before modification:, {kv_indptr[:20]}")
kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0)
kv_indptr = kv_indptr[: bs + 1]
kv_indptr[0] = 0
# kv_indptr = kv_indptr[: bs + 1]
kv_indices = torch.empty(
paged_kernel_lens_sum, dtype=torch.int32, device="cuda"
)
create_flashinfer_kv_indices_triton[(bs,)](
self.req_to_token,
req_pool_indices,
paged_kernel_lens,
kv_indptr,
kv_indptr[: bs + 1],
kv_start_idx,
kv_indices,
self.req_to_token.shape[1],
Expand All @@ -810,6 +1040,10 @@ def call_begin_forward(
qo_indptr[1 : bs + 1] = torch.cumsum(seq_lens - prefix_lens, dim=0)
qo_indptr = qo_indptr[: bs + 1]
custom_mask = None
print(f"kv_indptr after modification:, {kv_indptr[:20]}")
print(f"kv_indices={kv_indices}")
print(f"qo_indptr={qo_indptr}")
print("========")
else:
kv_indices, kv_indptr, qo_indptr, custom_mask = (
spec_info.generate_attn_arg_prefill(
Expand All @@ -835,7 +1069,9 @@ def call_begin_forward(
wrapper_paged.end_forward()
wrapper_paged.begin_forward(
qo_indptr,
kv_indptr,
kv_indptr[
: bs + 1
], # TODO(lihu): spec_info is not empty should consider, pass or modified?
kv_indices,
self.kv_last_page_len[:bs],
self.num_qo_heads,
Expand All @@ -845,6 +1081,7 @@ def call_begin_forward(
q_data_type=self.q_data_type,
custom_mask=custom_mask,
)
print(f"forward kv_last_page_len = {self.kv_last_page_len[:bs]}")


@triton.jit
Expand Down
Loading
Loading