From 4bdd901c1e3baa77b7808177b843aae1ad9e9224 Mon Sep 17 00:00:00 2001 From: libra Date: Mon, 6 Jan 2025 11:07:50 +0800 Subject: [PATCH 01/11] Improve the mixed chunk prefill launch two kernels for one batch, --- .../sglang/srt/layers/attention/__init__.py | 12 + .../layers/attention/flashinfer_backend.py | 213 +++++++++++++++++- python/sglang/srt/managers/schedule_batch.py | 6 + .../srt/model_executor/forward_batch_info.py | 4 + 4 files changed, 230 insertions(+), 5 deletions(-) diff --git a/python/sglang/srt/layers/attention/__init__.py b/python/sglang/srt/layers/attention/__init__.py index 140755ff5e6..bc4d3210866 100644 --- a/python/sglang/srt/layers/attention/__init__.py +++ b/python/sglang/srt/layers/attention/__init__.py @@ -80,6 +80,18 @@ def forward_decode( """Run a forward for decode.""" raise NotImplementedError() + def forward_mixed( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + layer: RadixAttention, + forward_batch: ForwardBatch, + save_kv_cache: bool = True, + ): + """Run a forward for mixed prefill & decode.""" + raise NotImplementedError() + def forward_extend( self, q: torch.Tensor, diff --git a/python/sglang/srt/layers/attention/flashinfer_backend.py b/python/sglang/srt/layers/attention/flashinfer_backend.py index 8b823cc5a5d..7ccfb4710d2 100644 --- a/python/sglang/srt/layers/attention/flashinfer_backend.py +++ b/python/sglang/srt/layers/attention/flashinfer_backend.py @@ -52,6 +52,14 @@ class PrefillMetadata: extend_no_prefix: bool +@dataclass +class MixedMetadata: + prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper] + decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper] + use_ragged: bool + extend_no_prefix: bool + + class FlashInferAttnBackend(AttentionBackend): """Flashinfer attention kernels.""" @@ -138,7 +146,9 @@ def __init__(self, model_runner: ModelRunner): ) # Other metadata - self.forward_metadata: Union[PrefillMetadata, DecodeMetadata] = None + self.forward_metadata: Union[PrefillMetadata, DecodeMetadata, MixedMetadata] = ( + None + ) self.decode_cuda_graph_metadata = {} self.prefill_cuda_graph_metadata = {} @@ -153,6 +163,80 @@ def init_forward_metadata(self, forward_batch: ForwardBatch): spec_info=forward_batch.spec_info, ) self.forward_metadata = DecodeMetadata(self.decode_wrappers) + elif forward_batch.forward_mode.is_mixed(): + # Part 0: prepare + extend_bs = forward_batch.decode_start_idx + print( + f"init_forward_metadata: batch_size={forward_batch.batch_size}, running_bs={running_bs}, extend_bs={extend_bs}" + ) + + req_pool_indices_extend = forward_batch.req_pool_indices[:extend_bs] + req_pool_indices_decode = (forward_batch.req_pool_indices[extend_bs:],) + seq_lens_extend, seq_lens_decode = ( + forward_batch.seq_lens[:extend_bs], + forward_batch.seq_lens[extend_bs:], + ) + + seq_lens_sum_extend = forward_batch.seq_lens[:extend_bs].sum().item() + seq_lens_sum_decode = forward_batch.seq_lens[extend_bs:].sum().item() + prefix_lens_extend = forward_batch.extend_prefix_lens[:extend_bs] + + extend_prefix_lens_origin = len(forward_batch.extend_prefix_lens) + print( + f"origin prefix lens = {extend_prefix_lens_origin}, extend_bs={extend_bs}" + ) + assert extend_prefix_lens_origin == extend_bs + encoder_lens_extend = ( + forward_batch.encoder_lens[:extend_bs] + if forward_batch.encoder_lens is not None + else None + ) + encoder_lens_decode = ( + ( + forward_batch.encoder_lens[extend_bs:] + if forward_batch.encoder_lens is not None + else None + ), + ) + self.indices_updater_decode.decode_indices = extend_bs + + # Part1: 
Prefill + if forward_batch.decode_start_idx >= 4096 and self.num_wrappers == 1: + use_ragged = True + extend_no_prefix = not any( + forward_batch.extend_prefix_lens_cpu[:extend_bs] + ) + else: + use_ragged = False + extend_no_prefix = False + + self.indices_updater_prefill.update( + req_pool_indices_extend, + seq_lens_extend, + seq_lens_sum_extend, + prefix_lens_extend, + prefill_wrappers=self.prefill_wrappers_paged, + use_ragged=use_ragged, + encoder_lens=encoder_lens_extend, + spec_info=None, + ) + + # Part2: Decode + self.indices_updater_decode.update( + req_pool_indices_decode, + seq_lens_decode, + seq_lens_sum_decode, + decode_wrappers=self.decode_wrappers, + encoder_lens=encoder_lens_decode, + spec_info=forward_batch.spec_info, + ) + + self.forward_metadata = MixedMetadata( + self.prefill_wrappers_paged, + self.decode_wrappers, + use_ragged, + extend_no_prefix, + ) elif forward_batch.forward_mode.is_draft_extend(): self.indices_updater_prefill.update( forward_batch.req_pool_indices, @@ -338,6 +422,9 @@ def forward_extend( forward_batch: ForwardBatch, save_kv_cache=True, ): + if forward_batch.forward_mode.is_mixed(): + return self.forward_mixed(q, k, v, layer, forward_batch, save_kv_cache) + prefill_wrapper_paged = self.forward_metadata.prefill_wrappers[ self._get_wrapper_idx(layer) ] @@ -389,6 +476,110 @@ def forward_extend( return o.view(-1, layer.tp_q_head_num * layer.head_dim) + def forward_mixed( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + layer: RadixAttention, + forward_batch: ForwardBatch, + save_kv_cache=True, + ): + # Part0: split the prefill and decode + extend_tokens = forward_batch.decode_start_idx + + print(f"forward_mixed: extend_tokens={extend_tokens}") + + k_extend, k_decode = k[:extend_tokens], k[extend_tokens:] + v_extend, v_decode = v[:extend_tokens], v[extend_tokens:] + q_extend, q_decode = q[:extend_tokens], q[extend_tokens:] + + out_cache_loc_extend = forward_batch.out_cache_loc[:extend_tokens] + out_cache_loc_decode = forward_batch.out_cache_loc[extend_tokens:] + + ## Part1: Prefill + prefill_wrapper_paged = self.forward_metadata.prefill_wrappers[ + self._get_wrapper_idx(layer) + ] + cache_loc_extend = ( + out_cache_loc_extend + if not layer.is_cross_attention + else forward_batch.encoder_out_cache_loc[:extend_tokens] + ) + + if not self.forward_metadata.use_ragged: + if k_extend is not None: + assert v_extend is not None + if save_kv_cache: + forward_batch.token_to_kv_pool.set_kv_buffer( + layer, cache_loc_extend, k_extend, v_extend + ) + + o = prefill_wrapper_paged.forward( + q_extend.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim), + forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id), + causal=not layer.is_cross_attention, + sm_scale=layer.scaling, + window_left=layer.sliding_window_size, + logits_soft_cap=layer.logit_cap, + ) + else: + o1, s1 = self.prefill_wrapper_ragged.forward_return_lse( + q_extend.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim), + k_extend.contiguous().view(-1, layer.tp_k_head_num, layer.head_dim), + v_extend.contiguous().view(-1, layer.tp_v_head_num, layer.head_dim), + causal=True, + sm_scale=layer.scaling, + logits_soft_cap=layer.logit_cap, + ) + + if self.forward_metadata.extend_no_prefix: + o = o1 + else: + o2, s2 = prefill_wrapper_paged.forward_return_lse( + q_extend.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim), + forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id), + causal=False, + sm_scale=layer.scaling, + logits_soft_cap=layer.logit_cap, + ) + + o, _ 
= merge_state(o1, s1, o2, s2) + + if save_kv_cache: + forward_batch.token_to_kv_pool.set_kv_buffer( + layer, cache_loc_extend, k_extend, v_extend + ) + + o = o.view(-1, layer.tp_q_head_num * layer.head_dim) + + ## Part2: decode + decode_wrapper = self.forward_metadata.decode_wrappers[ + self._get_wrapper_idx(layer) + ] + cache_loc_decode = ( + out_cache_loc_decode + if not layer.is_cross_attention + else forward_batch.encoder_out_cache_loc[extend_tokens:] + ) + + if k_decode is not None: + assert v_decode is not None + if save_kv_cache: + forward_batch.token_to_kv_pool.set_kv_buffer( + layer, cache_loc_decode, k_decode, v_decode + ) + + o_decode = decode_wrapper.forward( + q_decode.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim), + forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id), + sm_scale=layer.scaling, + logits_soft_cap=layer.logit_cap, + ) + o_decode = o_decode.view(-1, layer.tp_q_head_num * layer.head_dim) + + return torch.cat((o, o_decode), dim=0) + def forward_decode( self, q: torch.Tensor, @@ -570,8 +761,18 @@ def call_begin_forward( ): if spec_info is None: bs = len(req_pool_indices) - kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0) - kv_indptr = kv_indptr[: bs + 1] + if self.decode_indices > 0: + kv_indptr[1 + self.decode_indices] = 0 + kv_indptr[2 + self.decode_indices : 2 + self.decode_indices + bs] = ( + torch.cumsum(paged_kernel_lens, dim=0) + ) + kv_indptr_decode = kv_indptr[ + 1 + self.decode_indices : 2 + self.decode_indices + bs + ] + self.decode_indices = 0 + else: + kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0) + kv_indptr_decode = kv_indptr[: bs + 1] kv_indices = torch.empty( paged_kernel_lens_sum, dtype=torch.int32, device="cuda" ) @@ -579,7 +780,7 @@ def call_begin_forward( self.req_to_token, req_pool_indices, paged_kernel_lens, - kv_indptr, + kv_indptr_decode, kv_start_idx, kv_indices, self.req_to_token.shape[1], @@ -590,10 +791,12 @@ def call_begin_forward( paged_kernel_lens, self.req_to_token, ) + # TODO(lihu): fix this ? 
+ kv_indptr_decode = kv_indptr[: bs + 1] wrapper.end_forward() wrapper.begin_forward( - kv_indptr, + kv_indptr_decode, kv_indices, self.kv_last_page_len[:bs], self.num_qo_heads, diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index 2a5db90842f..c3dc508bcfb 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -558,6 +558,7 @@ class ScheduleBatch: extend_num_tokens: int = None decoding_reqs: List[Req] = None extend_logprob_start_lens: List[int] = None + decode_start_idx: int = 0 # For encoder-decoder encoder_cached: Optional[List[bool]] = None @@ -1100,6 +1101,7 @@ def merge_batch(self, other: "ScheduleBatch"): self.req_pool_indices = torch.concat( [self.req_pool_indices, other.req_pool_indices] ) + self.decode_start_idx = len(self.seq_lens) self.seq_lens = torch.concat([self.seq_lens, other.seq_lens]) self.out_cache_loc = None self.seq_lens_sum += other.seq_lens_sum @@ -1163,6 +1165,7 @@ def get_model_worker_batch(self): input_embeds=self.input_embeds, spec_algorithm=self.spec_algorithm, spec_info=self.spec_info, + decode_start_idx=self.decode_start_idx, ) def copy(self): @@ -1216,6 +1219,9 @@ class ModelWorkerBatch: extend_prefix_lens: Optional[List[int]] extend_logprob_start_lens: Optional[List[int]] + # For mixed chunked prefill + decode_start_idx: int = 0 + # For multimodal image_inputs: Optional[List[ImageInputs]] diff --git a/python/sglang/srt/model_executor/forward_batch_info.py b/python/sglang/srt/model_executor/forward_batch_info.py index 9269611491c..261d99bd573 100644 --- a/python/sglang/srt/model_executor/forward_batch_info.py +++ b/python/sglang/srt/model_executor/forward_batch_info.py @@ -171,6 +171,9 @@ class ForwardBatch: gathered_buffer: Optional[torch.Tensor] = None can_run_dp_cuda_graph: bool = False + # For mixed chunked prefill + decode_start_idx: int = 0 + # Speculative decoding spec_info: SpecInfo = None spec_algorithm: SpeculativeAlgorithm = None @@ -266,6 +269,7 @@ def init_new( spec_algorithm=batch.spec_algorithm, spec_info=batch.spec_info, input_embeds=batch.input_embeds, + decode_start_idx=self.decode_start_idx, ) if ret.global_num_tokens is not None: From 7dd7ce5c4fb61f34d592cbdc1b6782018e74713c Mon Sep 17 00:00:00 2001 From: libra Date: Fri, 10 Jan 2025 18:33:08 +0800 Subject: [PATCH 02/11] Improve the index in the FlashInferIndicesUpdaterDecode --- .../layers/attention/flashinfer_backend.py | 33 ++++++++++++------- 1 file changed, 22 insertions(+), 11 deletions(-) diff --git a/python/sglang/srt/layers/attention/flashinfer_backend.py b/python/sglang/srt/layers/attention/flashinfer_backend.py index 7ccfb4710d2..4556ec33ff0 100644 --- a/python/sglang/srt/layers/attention/flashinfer_backend.py +++ b/python/sglang/srt/layers/attention/flashinfer_backend.py @@ -58,6 +58,7 @@ class MixedMetadata: decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper] use_ragged: bool extend_no_prefix: bool + decode_start_idx: int class FlashInferAttnBackend(AttentionBackend): @@ -167,18 +168,21 @@ def init_forward_metadata(self, forward_batch: ForwardBatch): # Part 0: prepare extend_bs = forward_batch.decode_start_idx print( - f"init_forward_metadata: batch_size={forward_batch.batch_size}, running_bs={running_bs}, extend_bs={extend_bs}" + f"init_forward_metadata: batch_size={forward_batch.batch_size}, extend_bs={extend_bs}" ) - req_pool_indices_extend = forward_batch.req_pool_indices[:extend_bs] - req_pool_indices_decode = 
(forward_batch.req_pool_indices[extend_bs:],) + req_pool_indices_extend, req_pool_indices_decode = ( + forward_batch.req_pool_indices[:extend_bs], + forward_batch.req_pool_indices[extend_bs:], + ) seq_lens_extend, seq_lens_decode = ( forward_batch.seq_lens[:extend_bs], forward_batch.seq_lens[extend_bs:], ) - - seq_lens_sum_extend = forward_batch.seq_lens[:extend_bs].sum().item() - seq_lens_sum_decode = forward_batch.seq_lens[extend_bs:].sum().item() + seq_lens_sum_extend, seq_lens_sum_decode = ( + forward_batch.seq_lens[:extend_bs].sum().item(), + forward_batch.seq_lens[extend_bs:].sum().item(), + ) prefix_lens_extend = forward_batch.extend_prefix_lens[:extend_bs] extend_prefix_lens_origin = len(forward_batch.extend_prefix_lens) @@ -198,7 +202,6 @@ def init_forward_metadata(self, forward_batch: ForwardBatch): else None ), ) - self.indices_updater_decode.decode_indices = extend_bs # Part1: Prefill if forward_batch.decode_start_idx >= 4096 and self.num_wrappers == 1: @@ -230,12 +233,16 @@ def init_forward_metadata(self, forward_batch: ForwardBatch): encoder_lens=encoder_lens_decode, spec_info=forward_batch.spec_info, ) + # since the `decode_start_idx` only used in the mixed mode + # Manually update it so we do not need change the update method + self.indices_updater_decode.decode_start_idx = extend_bs self.forward_metadata = MixedMetadata( self.prefill_wrappers_paged, self.decode_wrappers, use_ragged, extend_no_prefix, + extend_bs, ) elif forward_batch.forward_mode.is_draft_extend(): self.indices_updater_prefill.update( @@ -639,6 +646,9 @@ def __init__(self, model_runner: ModelRunner, attn_backend: AttentionBackend): self.sliding_window_size = model_runner.sliding_window_size self.attn_backend = attn_backend + # mixed chunk prefill + self.decode_start_idx = 0 + # Buffers and wrappers self.kv_indptr = attn_backend.kv_indptr self.kv_last_page_len = attn_backend.kv_last_page_len @@ -769,7 +779,6 @@ def call_begin_forward( kv_indptr_decode = kv_indptr[ 1 + self.decode_indices : 2 + self.decode_indices + bs ] - self.decode_indices = 0 else: kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0) kv_indptr_decode = kv_indptr[: bs + 1] @@ -982,9 +991,11 @@ def call_begin_forward( ): bs = len(req_pool_indices) if spec_info is None: - # Normal extend - kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0) - kv_indptr = kv_indptr[: bs + 1] + shift_pos = self.decode_start_idx + kv_indptr[1 + shift_pos : bs + 1 + shift_pos] = torch.cumsum( + paged_kernel_lens, dim=0 + ) + kv_indptr = kv_indptr[shift_pos : bs + 1 + shift_pos] kv_indices = torch.empty( paged_kernel_lens_sum, dtype=torch.int32, device="cuda" ) From 5435a015c27885a20e8db9495ffb1a85d1865045 Mon Sep 17 00:00:00 2001 From: libra Date: Fri, 10 Jan 2025 18:39:20 +0800 Subject: [PATCH 03/11] improve the MixedMetadata --- python/sglang/srt/layers/attention/flashinfer_backend.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/python/sglang/srt/layers/attention/flashinfer_backend.py b/python/sglang/srt/layers/attention/flashinfer_backend.py index 5fbac249e0b..f01d699db7c 100644 --- a/python/sglang/srt/layers/attention/flashinfer_backend.py +++ b/python/sglang/srt/layers/attention/flashinfer_backend.py @@ -58,7 +58,6 @@ class MixedMetadata: decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper] use_ragged: bool extend_no_prefix: bool - decode_start_idx: int class FlashInferAttnBackend(AttentionBackend): @@ -242,7 +241,6 @@ def init_forward_metadata(self, forward_batch: ForwardBatch): self.decode_wrappers, use_ragged, 
extend_no_prefix, - extend_bs, ) elif forward_batch.forward_mode.is_draft_extend(): self.indices_updater_prefill.update( From cb2d4f7d9f4feba64e23aaf7bfb0be296aa05683 Mon Sep 17 00:00:00 2001 From: libra Date: Fri, 10 Jan 2025 18:46:13 +0800 Subject: [PATCH 04/11] fix the FlashInferIndicesUpdaterPrefill --- .../layers/attention/flashinfer_backend.py | 29 +++++++------------ 1 file changed, 10 insertions(+), 19 deletions(-) diff --git a/python/sglang/srt/layers/attention/flashinfer_backend.py b/python/sglang/srt/layers/attention/flashinfer_backend.py index f01d699db7c..5e927d09e2d 100644 --- a/python/sglang/srt/layers/attention/flashinfer_backend.py +++ b/python/sglang/srt/layers/attention/flashinfer_backend.py @@ -771,17 +771,12 @@ def call_begin_forward( ): if spec_info is None: bs = len(req_pool_indices) - if self.decode_indices > 0: - kv_indptr[1 + self.decode_indices] = 0 - kv_indptr[2 + self.decode_indices : 2 + self.decode_indices + bs] = ( - torch.cumsum(paged_kernel_lens, dim=0) - ) - kv_indptr_decode = kv_indptr[ - 1 + self.decode_indices : 2 + self.decode_indices + bs - ] - else: - kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0) - kv_indptr_decode = kv_indptr[: bs + 1] + shift_pos = self.decode_start_idx + kv_indptr[1 + shift_pos : bs + 1 + shift_pos] = torch.cumsum( + paged_kernel_lens, dim=0 + ) + kv_indptr = kv_indptr[shift_pos : bs + 1 + shift_pos] + kv_indices = torch.empty( paged_kernel_lens_sum, dtype=torch.int32, device="cuda" ) @@ -800,12 +795,10 @@ def call_begin_forward( paged_kernel_lens, self.req_to_token, ) - # TODO(lihu): fix this ? - kv_indptr_decode = kv_indptr[: bs + 1] wrapper.end_forward() wrapper.begin_forward( - kv_indptr_decode, + kv_indptr, kv_indices, self.kv_last_page_len[:bs], self.num_qo_heads, @@ -991,11 +984,9 @@ def call_begin_forward( ): bs = len(req_pool_indices) if spec_info is None: - shift_pos = self.decode_start_idx - kv_indptr[1 + shift_pos : bs + 1 + shift_pos] = torch.cumsum( - paged_kernel_lens, dim=0 - ) - kv_indptr = kv_indptr[shift_pos : bs + 1 + shift_pos] + # Normal extend + kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0) + kv_indptr = kv_indptr[: bs + 1] kv_indices = torch.empty( paged_kernel_lens_sum, dtype=torch.int32, device="cuda" ) From bd85c2cdacf6f6ecf4c12a841420e53ff59f0630 Mon Sep 17 00:00:00 2001 From: libra Date: Fri, 10 Jan 2025 18:55:09 +0800 Subject: [PATCH 05/11] Fix some bugs --- python/sglang/srt/layers/attention/flashinfer_backend.py | 5 +---- python/sglang/srt/model_executor/forward_batch_info.py | 2 +- 2 files changed, 2 insertions(+), 5 deletions(-) diff --git a/python/sglang/srt/layers/attention/flashinfer_backend.py b/python/sglang/srt/layers/attention/flashinfer_backend.py index 5e927d09e2d..17e4932fdc4 100644 --- a/python/sglang/srt/layers/attention/flashinfer_backend.py +++ b/python/sglang/srt/layers/attention/flashinfer_backend.py @@ -185,9 +185,6 @@ def init_forward_metadata(self, forward_batch: ForwardBatch): prefix_lens_extend = forward_batch.extend_prefix_lens[:extend_bs] extend_prefix_lens_origin = len(forward_batch.extend_prefix_lens) - print( - f"origin prefix lens = {extend_prefix_lens_origin}, extend_bs={extend_bs}" - ) assert extend_prefix_lens_origin == extend_bs encoder_lens_extend = ( forward_batch.encoder_lens[:extend_bs] @@ -784,7 +781,7 @@ def call_begin_forward( self.req_to_token, req_pool_indices, paged_kernel_lens, - kv_indptr_decode, + kv_indptr, kv_start_idx, kv_indices, self.req_to_token.shape[1], diff --git 
a/python/sglang/srt/model_executor/forward_batch_info.py b/python/sglang/srt/model_executor/forward_batch_info.py index 666ae0eccd3..2a2561f7121 100644 --- a/python/sglang/srt/model_executor/forward_batch_info.py +++ b/python/sglang/srt/model_executor/forward_batch_info.py @@ -289,7 +289,7 @@ def init_new( spec_info=batch.spec_info, capture_hidden_mode=batch.capture_hidden_mode, input_embeds=batch.input_embeds, - decode_start_idx=self.decode_start_idx, + decode_start_idx=batch.decode_start_idx, ) if ret.global_num_tokens is not None: From d092fdcc06e987eacb7f365577bad80e5812c32d Mon Sep 17 00:00:00 2001 From: libra Date: Sat, 11 Jan 2025 08:56:58 +0800 Subject: [PATCH 06/11] Fix the init order --- python/sglang/srt/managers/schedule_batch.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index a66e292acdd..2d34f53275e 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -1165,12 +1165,12 @@ def get_model_worker_batch(self): input_embeds=self.input_embeds, spec_algorithm=self.spec_algorithm, spec_info=self.spec_info, - decode_start_idx=self.decode_start_idx, capture_hidden_mode=( getattr(self.spec_info, "capture_hidden_mode", CaptureHiddenMode.NULL) if self.spec_info else CaptureHiddenMode.NULL ), + decode_start_idx=self.decode_start_idx, ) def copy(self): @@ -1224,9 +1224,6 @@ class ModelWorkerBatch: extend_prefix_lens: Optional[List[int]] extend_logprob_start_lens: Optional[List[int]] - # For mixed chunked prefill - decode_start_idx: int = 0 - # For multimodal image_inputs: Optional[List[ImageInputs]] @@ -1250,6 +1247,9 @@ class ModelWorkerBatch: spec_info: Optional[SpecInfo] = None capture_hidden_mode: CaptureHiddenMode = None + # For mixed chunked prefill + decode_start_idx: int = 0 + @triton.jit def write_req_to_token_pool_triton( From 2b3c48671469030a3f5f0ba7b1dd5cbcb3e44129 Mon Sep 17 00:00:00 2001 From: libra Date: Mon, 13 Jan 2025 12:50:16 +0800 Subject: [PATCH 07/11] Enable mixed chunk by default --- python/sglang/srt/layers/attention/flashinfer_backend.py | 4 +++- python/sglang/srt/server_args.py | 3 ++- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/layers/attention/flashinfer_backend.py b/python/sglang/srt/layers/attention/flashinfer_backend.py index 17e4932fdc4..451e5200702 100644 --- a/python/sglang/srt/layers/attention/flashinfer_backend.py +++ b/python/sglang/srt/layers/attention/flashinfer_backend.py @@ -185,7 +185,9 @@ def init_forward_metadata(self, forward_batch: ForwardBatch): prefix_lens_extend = forward_batch.extend_prefix_lens[:extend_bs] extend_prefix_lens_origin = len(forward_batch.extend_prefix_lens) - assert extend_prefix_lens_origin == extend_bs + assert ( + extend_prefix_lens_origin == extend_bs + ), f"Assertion failed: extend_prefix_lens_origin={extend_prefix_lens_origin}, extend_bs={extend_bs}" encoder_lens_extend = ( forward_batch.encoder_lens[:extend_bs] if forward_batch.encoder_lens is not None diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 09d1a3edebc..96cf7bee1e5 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -142,7 +142,7 @@ class ServerArgs: disable_custom_all_reduce: bool = False disable_mla: bool = False disable_overlap_schedule: bool = False - enable_mixed_chunk: bool = False + enable_mixed_chunk: bool = True enable_dp_attention: bool = False enable_ep_moe: bool 
= False
     enable_torch_compile: bool = False
@@ -775,6 +775,7 @@ def add_cli_args(parser: argparse.ArgumentParser):
     parser.add_argument(
         "--enable-mixed-chunk",
         action="store_true",
+        default=True,
         help="Enabling mixing prefill and decode in a batch when using chunked prefill.",
     )
     parser.add_argument(

From 62ab35612d46116d78dfa1ee7cdd3dc295e2ec27 Mon Sep 17 00:00:00 2001
From: libra
Date: Mon, 13 Jan 2025 18:43:42 +0800
Subject: [PATCH 08/11] Remove the unused assert

---
 python/sglang/srt/layers/attention/flashinfer_backend.py | 4 ----
 1 file changed, 4 deletions(-)

diff --git a/python/sglang/srt/layers/attention/flashinfer_backend.py b/python/sglang/srt/layers/attention/flashinfer_backend.py
index 451e5200702..5892819a47d 100644
--- a/python/sglang/srt/layers/attention/flashinfer_backend.py
+++ b/python/sglang/srt/layers/attention/flashinfer_backend.py
@@ -184,10 +184,6 @@ def init_forward_metadata(self, forward_batch: ForwardBatch):
             )
             prefix_lens_extend = forward_batch.extend_prefix_lens[:extend_bs]

-            extend_prefix_lens_origin = len(forward_batch.extend_prefix_lens)
-            assert (
-                extend_prefix_lens_origin == extend_bs
-            ), f"Assertion failed: extend_prefix_lens_origin={extend_prefix_lens_origin}, extend_bs={extend_bs}"
             encoder_lens_extend = (
                 forward_batch.encoder_lens[:extend_bs]
                 if forward_batch.encoder_lens is not None

From 5ad9c1d602321b28bc811403038b5229831d72cd Mon Sep 17 00:00:00 2001
From: libra
Date: Wed, 15 Jan 2025 11:20:00 +0800
Subject: [PATCH 09/11] Add some asserts

---
 python/sglang/srt/layers/attention/flashinfer_backend.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/python/sglang/srt/layers/attention/flashinfer_backend.py b/python/sglang/srt/layers/attention/flashinfer_backend.py
index 5ee19078ad9..a10e223625f 100644
--- a/python/sglang/srt/layers/attention/flashinfer_backend.py
+++ b/python/sglang/srt/layers/attention/flashinfer_backend.py
@@ -166,10 +166,7 @@ def init_forward_metadata(self, forward_batch: ForwardBatch):
         elif forward_batch.forward_mode.is_mixed():
             # Part 0: prepare
             extend_bs = forward_batch.decode_start_idx
-            print(
-                f"init_forward_metadata: batch_size={forward_batch.batch_size}, extend_bs={extend_bs}"
-            )
-
+            assert extend_bs > 0, f"extend_bs = {extend_bs}"
             req_pool_indices_extend, req_pool_indices_decode = (
                 forward_batch.req_pool_indices[:extend_bs],
                 forward_batch.req_pool_indices[extend_bs:],
@@ -182,6 +179,9 @@ def init_forward_metadata(self, forward_batch: ForwardBatch):
                 forward_batch.seq_lens[:extend_bs].sum().item(),
                 forward_batch.seq_lens[extend_bs:].sum().item(),
             )
+            assert (
+                seq_lens_sum_extend + seq_lens_sum_decode == forward_batch.seq_lens_sum
+            ), f"{seq_lens_sum_extend} + {seq_lens_sum_decode} != {forward_batch.seq_lens_sum}"
             prefix_lens_extend = forward_batch.extend_prefix_lens[:extend_bs]

             encoder_lens_extend = (
                 forward_batch.encoder_lens[:extend_bs]
                 if forward_batch.encoder_lens is not None

From 83341bea248211688a36d31996e27e37df5c77d8 Mon Sep 17 00:00:00 2001
From: libra
Date: Thu, 16 Jan 2025 15:24:39 +0800
Subject: [PATCH 10/11] Fix the errors

---
 .../layers/attention/flashinfer_backend.py   | 69 +++++++++++++++----
 python/sglang/srt/managers/schedule_batch.py | 21 ++++++
 2 files changed, 75 insertions(+), 15 deletions(-)

diff --git a/python/sglang/srt/layers/attention/flashinfer_backend.py b/python/sglang/srt/layers/attention/flashinfer_backend.py
index a10e223625f..09632e1f52b 100644
--- a/python/sglang/srt/layers/attention/flashinfer_backend.py
+++ b/python/sglang/srt/layers/attention/flashinfer_backend.py
@@ -99,6 +99,7 @@ def __init__(self, model_runner: ModelRunner):
device=model_runner.device, ) max_bs = model_runner.req_to_token_pool.size + print(f"max batch size = {max_bs}") self.kv_indptr = [ torch.zeros((max_bs + 1,), dtype=torch.int32, device=model_runner.device) for _ in range(self.num_wrappers) @@ -153,6 +154,9 @@ def __init__(self, model_runner: ModelRunner): self.prefill_cuda_graph_metadata = {} def init_forward_metadata(self, forward_batch: ForwardBatch): + print( + f"mode={forward_batch.forward_mode},batch_size={forward_batch.batch_size}, req_pool_indices={forward_batch.req_pool_indices}, seq_lens={forward_batch.seq_lens},decode_start_idx={forward_batch.decode_start_idx}" + ) if forward_batch.forward_mode.is_decode(): self.indices_updater_decode.update( forward_batch.req_pool_indices, @@ -198,7 +202,11 @@ def init_forward_metadata(self, forward_batch: ForwardBatch): ) # Part1: Prefill - if forward_batch.decode_start_idx >= 4096 and self.num_wrappers == 1: + decode_running_bs = ( + forward_batch.batch_size - forward_batch.decode_start_idx + ) + extend_tokens = forward_batch.extend_num_tokens - decode_running_bs + if extend_tokens >= 4096 and self.num_wrappers == 1: use_ragged = True extend_no_prefix = not any( forward_batch.extend_prefix_lens_cpu[:extend_bs] @@ -219,6 +227,7 @@ def init_forward_metadata(self, forward_batch: ForwardBatch): ) # Part2: Decode + self.indices_updater_decode.decode_start_idx = extend_bs self.indices_updater_decode.update( req_pool_indices_decode, seq_lens_decode, @@ -229,7 +238,6 @@ def init_forward_metadata(self, forward_batch: ForwardBatch): ) # since the `decode_start_idx` only used in the mixed mode # Manually update it so we do not need change the update method - self.indices_updater_decode.decode_start_idx = extend_bs self.forward_metadata = MixedMetadata( self.prefill_wrappers_paged, @@ -493,14 +501,22 @@ def forward_mixed( forward_batch: ForwardBatch, save_kv_cache=True, ): + # Part0: split the prefill and decode - extend_tokens = forward_batch.decode_start_idx + decode_running_bs = forward_batch.batch_size - forward_batch.decode_start_idx + # each decode request only use one token + extend_tokens = forward_batch.extend_num_tokens - decode_running_bs - print(f"forward_mixed: extend_tokens={extend_tokens}") + print( + f"mixed forward: q={q.shape}, k={k.shape}, extend_tokens={extend_tokens}, decode_running_bs={decode_running_bs}" + ) + q_extend, q_decode = q[:extend_tokens], q[extend_tokens:] k_extend, k_decode = k[:extend_tokens], k[extend_tokens:] v_extend, v_decode = v[:extend_tokens], v[extend_tokens:] - q_extend, q_decode = q[:extend_tokens], q[extend_tokens:] + + print(f"q_extend={q_extend.shape}, q_decode={q_decode.shape}") + print(f"k_extend={k_extend.shape}, k_decode={k_decode.shape}") out_cache_loc_extend = forward_batch.out_cache_loc[:extend_tokens] out_cache_loc_decode = forward_batch.out_cache_loc[extend_tokens:] @@ -774,13 +790,17 @@ def call_begin_forward( kv_start_idx: torch.Tensor, spec_info: Optional[SpecInfo], ): + shift_pos = self.decode_start_idx if spec_info is None: bs = len(req_pool_indices) - shift_pos = self.decode_start_idx - kv_indptr[1 + shift_pos : bs + 1 + shift_pos] = torch.cumsum( - paged_kernel_lens, dim=0 + print("====Decode call_begin_forward====") + print( + f"req_pool_indices={req_pool_indices}, paged_kernel_lens={paged_kernel_lens}, decode_start_idx={self.decode_start_idx}, shift_pos = {shift_pos}" ) - kv_indptr = kv_indptr[shift_pos : bs + 1 + shift_pos] + print(f"kv_indptr before modification:, {kv_indptr[:20]}") + kv_indptr[1 : bs + 1] = 
torch.cumsum(paged_kernel_lens, dim=0) + kv_indptr[0] = 0 + # kv_indptr = kv_indptr[shift_pos : bs + 1 + shift_pos] kv_indices = torch.empty( paged_kernel_lens_sum, dtype=torch.int32, device="cuda" @@ -789,11 +809,14 @@ def call_begin_forward( self.req_to_token, req_pool_indices, paged_kernel_lens, - kv_indptr, + kv_indptr[: bs + 1], kv_start_idx, kv_indices, self.req_to_token.shape[1], ) + print(f"kv_indptr after modification:, {kv_indptr[:20]}") + print(f"kv_indices={kv_indices}") + print("========") else: bs, kv_indices, kv_indptr = spec_info.generate_attn_arg_decode( req_pool_indices, @@ -803,9 +826,9 @@ def call_begin_forward( wrapper.end_forward() wrapper.begin_forward( - kv_indptr, + kv_indptr[: bs + 1], kv_indices, - self.kv_last_page_len[:bs], + self.kv_last_page_len[shift_pos : bs + shift_pos], self.num_qo_heads, self.num_kv_heads, self.head_dim, @@ -813,6 +836,9 @@ def call_begin_forward( data_type=self.data_type, q_data_type=self.q_data_type, ) + print( + f"decode kv_last_page_len = {self.kv_last_page_len[shift_pos:bs+shift_pos]}" + ) class FlashInferIndicesUpdaterPrefill: @@ -990,8 +1016,14 @@ def call_begin_forward( bs = len(req_pool_indices) if spec_info is None: # Normal extend + print("****Prefill call_begin_forward****") + print( + f"req_pool_indices={req_pool_indices}, paged_kernel_lens={paged_kernel_lens}" + ) + print(f"kv_indptr before modification:, {kv_indptr[:20]}") kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0) - kv_indptr = kv_indptr[: bs + 1] + kv_indptr[0] = 0 + # kv_indptr = kv_indptr[: bs + 1] kv_indices = torch.empty( paged_kernel_lens_sum, dtype=torch.int32, device="cuda" ) @@ -999,7 +1031,7 @@ def call_begin_forward( self.req_to_token, req_pool_indices, paged_kernel_lens, - kv_indptr, + kv_indptr[: bs + 1], kv_start_idx, kv_indices, self.req_to_token.shape[1], @@ -1008,6 +1040,10 @@ def call_begin_forward( qo_indptr[1 : bs + 1] = torch.cumsum(seq_lens - prefix_lens, dim=0) qo_indptr = qo_indptr[: bs + 1] custom_mask = None + print(f"kv_indptr after modification:, {kv_indptr[:20]}") + print(f"kv_indices={kv_indices}") + print(f"qo_indptr={qo_indptr}") + print("========") else: kv_indices, kv_indptr, qo_indptr, custom_mask = ( spec_info.generate_attn_arg_prefill( @@ -1033,7 +1069,9 @@ def call_begin_forward( wrapper_paged.end_forward() wrapper_paged.begin_forward( qo_indptr, - kv_indptr, + kv_indptr[ + : bs + 1 + ], # TODO(lihu): spec_info is not empty should consider, pass or modified? 
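+            # NOTE: kv_indptr is no longer sliced in place above (kv_indptr[0] is just
+            # reset to 0), so only its first bs + 1 entries are meaningful; an explicit
+            # [: bs + 1] view is passed wherever it is consumed here.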
kv_indices, self.kv_last_page_len[:bs], self.num_qo_heads, @@ -1043,6 +1081,7 @@ def call_begin_forward( q_data_type=self.q_data_type, custom_mask=custom_mask, ) + print(f"forward kv_last_page_len = {self.kv_last_page_len[:bs]}") @triton.jit diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index 2d34f53275e..a3839111f89 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -841,6 +841,13 @@ def mix_with_running(self, running_batch: "ScheduleBatch"): self.input_ids = input_ids self.out_cache_loc = out_cache_loc + print("After merging:") + print(f"decode_start_idx={self.decode_start_idx}") + print(f"seq_lens: {self.seq_lens}") + print(f"seq_lens_sum: {self.seq_lens_sum}") + print(f"batch_size: {self.batch_size()}") + print(f"self.out_cache_loc: {self.out_cache_loc}") + # For overlap scheduler, the output_ids has one step delay delta = 0 if self.enable_overlap else -1 @@ -1088,6 +1095,19 @@ def filter_batch( self.sampling_info.filter_batch(keep_indices, new_indices) def merge_batch(self, other: "ScheduleBatch"): + print("Before merging:") + print(f"self.bs={self.batch_size()}, other.bs={other.batch_size()}") + print( + f"self.req_pool_indices: {self.req_pool_indices}, other.req_pool_indices: {other.req_pool_indices}" + ) + print(f"self.seq_lens: {self.seq_lens}, other.seq_lens: {other.seq_lens}") + print( + f"self.seq_lens_sum: {self.seq_lens_sum}, other.seq_lens_sum: {other.seq_lens_sum}" + ) + print( + f"self.out_cache_loc: {self.out_cache_loc}, other.out_cache_loc: {other.out_cache_loc}" + ) + # Penalizer orchestrator must be merged before Batch.reqs is merged. This is because # orchestrator.merge() depends on Batch.reqs during preparation of each penalizers, so it # needs to be called with pre-merged Batch.reqs. @@ -1105,6 +1125,7 @@ def merge_batch(self, other: "ScheduleBatch"): self.seq_lens = torch.concat([self.seq_lens, other.seq_lens]) self.out_cache_loc = None self.seq_lens_sum += other.seq_lens_sum + if self.output_ids is not None: self.output_ids = torch.concat([self.output_ids, other.output_ids]) if self.return_logprob and other.return_logprob: From ccee4a2501f569cb33b39285d75ffc3163dcf8a4 Mon Sep 17 00:00:00 2001 From: libra Date: Thu, 16 Jan 2025 15:27:24 +0800 Subject: [PATCH 11/11] restore the default parameter --- python/sglang/srt/server_args.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index c48eb3211c1..be85a3670d4 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -143,7 +143,7 @@ class ServerArgs: disable_custom_all_reduce: bool = False disable_mla: bool = False disable_overlap_schedule: bool = False - enable_mixed_chunk: bool = True + enable_mixed_chunk: bool = False enable_dp_attention: bool = False enable_ep_moe: bool = False enable_torch_compile: bool = False @@ -778,7 +778,6 @@ def add_cli_args(parser: argparse.ArgumentParser): parser.add_argument( "--enable-mixed-chunk", action="store_true", - default=True, help="Enabling mixing prefill and decode in a batch when using chunked prefill.", ) parser.add_argument(