Merge branch 'main' into debug_radixcache_stack_overflow
luzengxiangcn authored Jan 23, 2025
2 parents 2cc5089 + 553f5a3 commit 0eaec72
Showing 34 changed files with 1,185 additions and 61 deletions.
35 changes: 35 additions & 0 deletions .devcontainer/Dockerfile
@@ -0,0 +1,35 @@
FROM lmsysorg/sglang:dev

# Create non-root user with specified UID and GID
# NOTE: Replace with your own UID and GID. This is a workaround from https://github.com/microsoft/vscode-remote-release/issues/49#issuecomment-489060908.
ARG HOST_UID=1003
ARG HOST_GID=1003
RUN groupadd -g $HOST_GID devuser && \
useradd -m -u $HOST_UID -g $HOST_GID -s /bin/zsh devuser

# Give devuser sudo access
RUN apt-get update && apt-get install -y sudo && \
echo "devuser ALL=(ALL) NOPASSWD:ALL" > /etc/sudoers.d/devuser && \
rm -rf /var/lib/apt/lists/* && \
apt-get clean

# Set up oh-my-zsh for devuser
RUN cp -r /root/.oh-my-zsh /home/devuser/.oh-my-zsh && \
cp /root/.zshrc /home/devuser/.zshrc && \
cp /root/.vimrc /home/devuser/.vimrc && \
cp /root/.tmux.conf /home/devuser/.tmux.conf && \
sed -i 's|/root/.oh-my-zsh|/home/devuser/.oh-my-zsh|g' /home/devuser/.zshrc && \
chown -R devuser:devuser /home/devuser/

# Set workspace directory and ownership
WORKDIR /sgl-workspace/sglang
RUN chown -R devuser:devuser /sgl-workspace

# Switch to devuser
USER devuser

# Install uv
RUN curl -LsSf https://astral.sh/uv/install.sh | sh

# Install rust
RUN curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y
3 changes: 2 additions & 1 deletion .devcontainer/devcontainer.json
@@ -1,8 +1,9 @@
{
"name": "sglang",
"build": {
"dockerfile": "../docker/Dockerfile.dev"
"dockerfile": "Dockerfile"
},
"remoteUser": "devuser",
"customizations": {
"vscode": {
"extensions": [
2 changes: 1 addition & 1 deletion .github/workflows/pr-test-sgl-kernel.yml
@@ -38,7 +38,7 @@ jobs:

- name: Install
run: |
pip3 install torch==2.5.1 && pip3 install pytest && pip3 install vllm
pip3 install torch==2.5.1 && pip3 install pytest && pip3 install vllm==0.6.4.post1
pip3 uninstall sgl-kernel -y || true
find . -name index.lock -delete
cd sgl-kernel
@@ -9,6 +9,7 @@
import triton
import triton.language as tl
from einops import rearrange
from sgl_kernel import lightning_attention_decode as sgl_lightning_attention_decode


@triton.jit
@@ -332,7 +333,6 @@ def test_lightning_attention_implementations(model_params):
model_params["num_attention_heads"],
d,
d,
dtype=dtype,
device=device,
)
with torch.no_grad():
@@ -350,30 +350,64 @@ def test_lightning_attention_implementations(model_params):
q = q.transpose(1, 2)
k = k.transpose(1, 2)
v = v.transpose(1, 2)
q = q.contiguous()
k = k.contiguous()
v = v.contiguous()
past_kv = past_kv.contiguous()
slope_rate = slope_rate.contiguous()

# Test Triton implementation
triton_output, triton_new_kv = lightning_attn_decode(q, k, v, past_kv, slope_rate)
triton_output = triton_output.transpose(1, 2).contiguous()
triton_output = triton_output.view(batch_size, seq_len, -1)
triton_output = model_attn.norm(triton_output)
triton_output = torch.sigmoid(model_attn.output_gate(hidden_states)) * triton_output
triton_output = model_attn.out_proj(triton_output)

# Test SGL implementation
sgl_output = torch.empty_like(v)
sgl_new_kv = torch.empty_like(past_kv)
sgl_lightning_attention_decode(q, k, v, past_kv, slope_rate, sgl_output, sgl_new_kv)

sgl_output = sgl_output.transpose(1, 2).contiguous()
sgl_output = sgl_output.view(batch_size, seq_len, -1)
sgl_output = model_attn.norm(sgl_output)
sgl_output = torch.sigmoid(model_attn.output_gate(hidden_states)) * sgl_output
sgl_output = model_attn.out_proj(sgl_output)

# Verify Triton implementation results
torch.testing.assert_close(
model_output,
triton_output,
rtol=1e-3,
atol=1e-2,
msg="Lightning attention implementations produce different output results",
msg="Triton lightning attention implementation produces different output results",
)
torch.testing.assert_close(
new_kv,
triton_new_kv,
rtol=1e-3,
atol=1e-2,
msg="Lightning attention implementations produce different kv results",
msg="Triton lightning attention implementation produces different kv results",
)

print("✅ Two implementations match")
# Verify SGL implementation results
torch.testing.assert_close(
model_output,
sgl_output,
rtol=1e-3,
atol=1e-2,
msg="SGL lightning attention implementation produces different output results",
)
torch.testing.assert_close(
new_kv,
sgl_new_kv,
rtol=1e-3,
atol=1e-2,
msg="SGL lightning attention implementation produces different kv results",
)

print("✅ All implementations match")


def _build_slope_tensor(n_attention_heads: int):
@@ -408,12 +442,13 @@ def get_benchmark():
x_names=["batch_size", "seq_len"],
x_vals=[list(_) for _ in configs],
line_arg="provider",
line_vals=["Original", "Triton"],
line_vals=["Original", "Triton", "SGL"],
line_names=[
"Original PyTorch Implementation",
"Triton Implementation",
"SGL Implementation",
],
styles=[("blue", "-"), ("green", "-")],
styles=[("blue", "-"), ("green", "-"), ("red", "-")],
ylabel="us",
plot_name="lightning-attention-decode-performance",
args={},
@@ -446,7 +481,6 @@ def benchmark(batch_size, seq_len, provider):
params["num_attention_heads"],
d,
d,
dtype=dtype,
device=device,
)

@@ -461,7 +495,7 @@ def benchmark(batch_size, seq_len, provider):
),
quantiles=quantiles,
)
else:
elif provider == "Triton":

def run_triton():
qkv = model_attn.act(model_attn.qkv_proj(hidden_states))
@@ -483,6 +517,33 @@ def run_triton():
run_triton,
quantiles=quantiles,
)
else: # SGL

def run_sgl():
qkv = model_attn.act(model_attn.qkv_proj(hidden_states))
new_shape = qkv.size()[:-1] + (model_attn.num_heads, -1)
qkv = qkv.view(*new_shape)
q, k, v = torch.split(qkv, [model_attn.head_dim] * 3, dim=-1)
q = q.transpose(1, 2).contiguous()
k = k.transpose(1, 2).contiguous()
v = v.transpose(1, 2).contiguous()

output = torch.empty_like(v)
new_kv = torch.empty_like(past_kv)
sgl_lightning_attention_decode(
q, k, v, past_kv, slope_rate, output, new_kv
)

output = output.transpose(1, 2).contiguous()
output = output.view(batch_size, seq_len, -1)
output = model_attn.norm(output)
output = torch.sigmoid(model_attn.output_gate(hidden_states)) * output
return model_attn.out_proj(output)

ms, min_ms, max_ms = triton.testing.do_bench(
run_sgl,
quantiles=quantiles,
)

return 1000 * ms, 1000 * max_ms, 1000 * min_ms

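For orientation, here is a minimal standalone sketch of the `sgl_kernel` decode entry point that the test and benchmark above exercise. The tensor shapes, dtypes, and sizes are assumptions inferred from that code, not a documented contract:

```python
import torch
from sgl_kernel import lightning_attention_decode

# Hypothetical sizes for illustration; the benchmark derives them from the model config.
batch_size, num_heads, head_dim = 2, 8, 64
device = torch.device("cuda")

# Single decode step: q/k/v are [batch, heads, 1, head_dim]; the recurrent state
# past_kv is [batch, heads, head_dim, head_dim] (float32 by default, as in the test).
q = torch.randn(batch_size, num_heads, 1, head_dim, device=device, dtype=torch.bfloat16).contiguous()
k = torch.randn_like(q).contiguous()
v = torch.randn_like(q).contiguous()
past_kv = torch.randn(batch_size, num_heads, head_dim, head_dim, device=device).contiguous()
slope_rate = torch.rand(num_heads, 1, 1, device=device).contiguous()

# The kernel writes into preallocated buffers instead of returning new tensors.
output = torch.empty_like(v)
new_kv = torch.empty_like(past_kv)
lightning_attention_decode(q, k, v, past_kv, slope_rate, output, new_kv)
```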
6 changes: 0 additions & 6 deletions docker/Dockerfile.dev
@@ -67,12 +67,6 @@ RUN wget https://github.com/Kitware/CMake/releases/download/v3.31.1/cmake-3.31.1
&& cp -r cmake-3.31.1-linux-x86_64/share/* /usr/local/share/ \
&& rm -rf cmake-3.31.1-linux-x86_64 cmake-3.31.1-linux-x86_64.tar.gz

# Install uv
RUN curl -LsSf https://astral.sh/uv/install.sh | sh

# Install rust
RUN curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y

# Add yank script
COPY --chown=root:root <<-"EOF" /usr/local/bin/yank
#!/bin/bash
21 changes: 18 additions & 3 deletions docs/references/benchmark_and_profiling.md
@@ -64,16 +64,31 @@ with nvtx.annotate("description", color="color"):
```bash
# set trace path
export SGLANG_TORCH_PROFILER_DIR=/root/sglang/profile_log

# start server
python -m sglang.launch_server --model-path meta-llama/Llama-3.1-8B-Instruct

python -m sglang.bench_serving --backend sglang --model-path meta-llama/Llama-3.1-8B-Instruct --num-prompts 10 --profile
# send profiling request from client
python -m sglang.bench_serving --backend sglang --model-path meta-llama/Llama-3.1-8B-Instruct --num-prompts 10 --sharegpt-output-len 100 --profile
```

Traces can be visualized using https://ui.perfetto.dev/.
Please make sure that `SGLANG_TORCH_PROFILER_DIR` is set on both the server and the client side, otherwise the trace files cannot be generated correctly. A reliable way is to set `SGLANG_TORCH_PROFILER_DIR` in the shell's rc file (e.g. `~/.bashrc` for bash shells).

- To profile offline
```bash
export SGLANG_TORCH_PROFILER_DIR=/root/sglang/profile_log
python -m sglang.bench_offline_throughput --model-path meta-llama/Llama-3.1-8B-Instruct --dataset-name random --num-prompts 10 --profile --mem-frac=0.8
```

- View Traces

Trace files can be loaded and visualized from:
1. https://ui.perfetto.dev/ (any browser)
2. chrome://tracing (Chrome browser only)

If the browser cannot open a trace file because it is too large,
the client can generate a smaller trace file (<100 MB) by limiting the number of prompts and the length of the prompt outputs.
For example, when profiling a server,
```bash
python -m sglang.bench_serving --backend sglang --model-path meta-llama/Llama-3.1-8B-Instruct --num-prompts 2 --sharegpt-output-len 100 --profile
```
sets the number of prompts to 2 with the `--num-prompts` argument and limits the output sequences to 100 tokens with the `--sharegpt-output-len` argument, producing a trace file small enough for the browser to open smoothly.
1 change: 1 addition & 0 deletions python/sglang/srt/configs/load_config.py
@@ -20,6 +20,7 @@ class LoadFormat(str, enum.Enum):
GGUF = "gguf"
BITSANDBYTES = "bitsandbytes"
MISTRAL = "mistral"
LAYERED = "layered"


@dataclass
7 changes: 6 additions & 1 deletion python/sglang/srt/layers/rotary_embedding.py
@@ -1018,7 +1018,12 @@ def get_rope(
head_size, rotary_dim, max_position, base, is_neox_style, dtype
)
else:
scaling_type = rope_scaling["rope_type"]
if "rope_type" in rope_scaling:
scaling_type = rope_scaling["rope_type"]
elif "type" in rope_scaling:
scaling_type = rope_scaling["type"]
else:
raise ValueError("Unknown RoPE scaling type")

if scaling_type == "llama3":
scaling_factor = rope_scaling["factor"]
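As a standalone illustration (not part of the diff) of the fallback this change introduces, assuming a Hugging Face-style `rope_scaling` dict:

```python
# Sketch of the scaling-type lookup added above: newer configs carry "rope_type",
# older ones carry "type"; anything else is rejected.
def resolve_scaling_type(rope_scaling: dict) -> str:
    if "rope_type" in rope_scaling:
        return rope_scaling["rope_type"]
    elif "type" in rope_scaling:
        return rope_scaling["type"]
    raise ValueError("Unknown RoPE scaling type")


assert resolve_scaling_type({"rope_type": "llama3", "factor": 8.0}) == "llama3"
assert resolve_scaling_type({"type": "dynamic", "factor": 2.0}) == "dynamic"
```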
29 changes: 28 additions & 1 deletion python/sglang/srt/layers/sampler.py
@@ -2,12 +2,19 @@
from typing import List

import torch
import torch.distributed as dist
from torch import nn

from sglang.srt.distributed import get_tensor_model_parallel_group
from sglang.srt.layers.dp_attention import get_attention_tp_group
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
from sglang.srt.utils import crash_on_warnings, is_flashinfer_available
from sglang.srt.utils import (
crash_on_warnings,
get_bool_env_var,
is_flashinfer_available,
)

if is_flashinfer_available():
from flashinfer.sampling import (
@@ -20,11 +27,17 @@

logger = logging.getLogger(__name__)

SYNC_TOKEN_IDS_ACROSS_TP = get_bool_env_var("SYNC_TOKEN_IDS_ACROSS_TP")


class Sampler(nn.Module):
def __init__(self):
super().__init__()
self.use_nan_detectioin = global_server_args_dict["enable_nan_detection"]
self.tp_sync_group = get_tensor_model_parallel_group().device_group

if global_server_args_dict["enable_dp_attention"]:
self.tp_sync_group = get_attention_tp_group().device_group

def forward(
self,
@@ -121,6 +134,20 @@ def forward(
batch_next_token_ids,
]

if SYNC_TOKEN_IDS_ACROSS_TP or sampling_info.grammars:
# For performance reasons, SGLang does not sync the final token IDs across TP ranks by default.
# This saves one all-reduce, but the correctness of this approach depends on the determinism of several operators:
# the last all-reduce, the last lm_head matmul, and all sampling kernels.
# These kernels are deterministic in most cases, but there are some rare instances where they are not deterministic.
# In such cases, enable this env variable to prevent hanging due to TP ranks becoming desynchronized.
# When using xgrammar, this becomes more likely so we also do the sync when grammar is used.

torch.distributed.all_reduce(
batch_next_token_ids,
op=dist.ReduceOp.MIN,
group=self.tp_sync_group,
)

return batch_next_token_ids.to(torch.int32)

def _apply_custom_logit_processor(
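For context, a minimal sketch of what the added sync does, assuming an already-initialized `torch.distributed` process group; the helper name is illustrative only:

```python
import torch
import torch.distributed as dist


def sync_sampled_token_ids(batch_next_token_ids: torch.Tensor, group=None) -> torch.Tensor:
    # Every TP rank holds a tensor of the same shape. Reducing with MIN forces all
    # ranks to agree on one token per position, so a rank whose sampling kernel was
    # nondeterministic cannot drift from the others and cause a hang later.
    dist.all_reduce(batch_next_token_ids, op=dist.ReduceOp.MIN, group=group)
    return batch_next_token_ids
```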
18 changes: 12 additions & 6 deletions python/sglang/srt/layers/torchao_utils.py
@@ -5,6 +5,7 @@
import logging
import os
import pwd
from typing import Callable, Optional

import torch

@@ -27,8 +28,18 @@ def save_gemlite_cache(print_error: bool = False) -> bool:
return True


def proj_filter(
module: torch.nn.Module,
fqn: str,
):
"""Filter function for quantizing projection layers."""
return "proj" in fqn


def apply_torchao_config_to_model(
model: torch.nn.Module, torchao_config: str, filter_fn=None
model: torch.nn.Module,
torchao_config: str,
filter_fn: Optional[Callable] = proj_filter,
):
"""Quantize a modelwith torchao quantization specified by torchao_config
@@ -49,11 +60,6 @@ def apply_torchao_config_to_model(
)
from torchao.quantization.observer import PerRow, PerTensor

if filter_fn is None:

def filter_fn(module, fqn):
return "proj" in fqn

if torchao_config == "" or torchao_config is None:
return model
elif "int8wo" in torchao_config:
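A hedged usage sketch of the updated signature: `filter_fn` now defaults to the module-level `proj_filter`, and callers can still pass their own predicate (the `only_gate_proj` helper and the toy model below are hypothetical):

```python
import torch

from sglang.srt.layers.torchao_utils import apply_torchao_config_to_model


def only_gate_proj(module: torch.nn.Module, fqn: str) -> bool:
    # Hypothetical custom filter: restrict quantization to gate projections.
    return "gate_proj" in fqn


# Tiny stand-in model; real callers pass the loaded LLM.
model = torch.nn.Sequential(torch.nn.Linear(16, 16))

# Default filter: every module whose fully qualified name contains "proj".
apply_torchao_config_to_model(model, "int8wo")

# Custom filter instead of the default proj_filter.
apply_torchao_config_to_model(model, "int8wo", filter_fn=only_gate_proj)
```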
9 changes: 6 additions & 3 deletions python/sglang/srt/model_executor/model_runner.py
@@ -185,9 +185,12 @@ def __init__(
self.load_model()

# Apply torchao quantization
apply_torchao_config_to_model(
self.model, global_server_args_dict["torchao_config"]
)
torchao_applied = getattr(self.model, "torchao_applied", False)
# In layered loading, torchao may have been applied
if not torchao_applied:
apply_torchao_config_to_model(
self.model, global_server_args_dict["torchao_config"]
)

# Apply torch TP if the model supports it
supports_torch_tp = getattr(self.model, "supports_torch_tp", False)
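A small sketch of the contract this guard assumes: a layered loader marks the model with a `torchao_applied` attribute so the runner skips the whole-model quantization pass (the loader-marked model below is hypothetical):

```python
import torch


class LayeredQuantizedModel(torch.nn.Module):
    """Hypothetical model whose loader quantizes and marks it layer by layer."""

    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(16, 16)
        # Set by the layered loader once per-layer torchao quantization is done.
        self.torchao_applied = True


model = LayeredQuantizedModel()
if not getattr(model, "torchao_applied", False):
    pass  # would call apply_torchao_config_to_model(model, torchao_config) here
```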