diff --git a/.devcontainer/Dockerfile b/.devcontainer/Dockerfile new file mode 100644 index 00000000000..0c061cd1871 --- /dev/null +++ b/.devcontainer/Dockerfile @@ -0,0 +1,35 @@ +From lmsysorg/sglang:dev + +# Create non-root user with specified UID and GID +# NOTE: Replace with your own UID and GID. This is a workaround from https://github.com/microsoft/vscode-remote-release/issues/49#issuecomment-489060908. +ARG HOST_UID=1003 +ARG HOST_GID=1003 +RUN groupadd -g $HOST_GID devuser && \ + useradd -m -u $HOST_UID -g $HOST_GID -s /bin/zsh devuser + +# Give devuser sudo access +RUN apt-get update && apt-get install -y sudo && \ + echo "devuser ALL=(ALL) NOPASSWD:ALL" > /etc/sudoers.d/devuser && \ + rm -rf /var/lib/apt/lists/* && \ + apt-get clean + +# Set up oh-my-zsh for devuser +RUN cp -r /root/.oh-my-zsh /home/devuser/.oh-my-zsh && \ + cp /root/.zshrc /home/devuser/.zshrc && \ + cp /root/.vimrc /home/devuser/.vimrc && \ + cp /root/.tmux.conf /home/devuser/.tmux.conf && \ + sed -i 's|/root/.oh-my-zsh|/home/devuser/.oh-my-zsh|g' /home/devuser/.zshrc && \ + chown -R devuser:devuser /home/devuser/ + +# Set workspace directory and ownership +WORKDIR /sgl-workspace/sglang +RUN chown -R devuser:devuser /sgl-workspace + +# Switch to devuser +USER devuser + +# Install uv +RUN curl -LsSf https://astral.sh/uv/install.sh | sh + +# Install rust +RUN curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y diff --git a/.devcontainer/devcontainer.json b/.devcontainer/devcontainer.json index 66f7aecbf82..5767aa2631a 100644 --- a/.devcontainer/devcontainer.json +++ b/.devcontainer/devcontainer.json @@ -1,8 +1,9 @@ { "name": "sglang", "build": { - "dockerfile": "../docker/Dockerfile.dev" + "dockerfile": "Dockerfile" }, + "remoteUser": "devuser", "customizations": { "vscode": { "extensions": [ diff --git a/.github/workflows/pr-test-sgl-kernel.yml b/.github/workflows/pr-test-sgl-kernel.yml index 794a73f3661..55eb636d64f 100644 --- a/.github/workflows/pr-test-sgl-kernel.yml +++ b/.github/workflows/pr-test-sgl-kernel.yml @@ -38,7 +38,7 @@ jobs: - name: Install run: | - pip3 install torch==2.5.1 && pip3 install pytest && pip3 install vllm + pip3 install torch==2.5.1 && pip3 install pytest && pip3 install vllm==0.6.4.post1 pip3 uninstall sgl-kernel -y || true find . 
-name index.lock -delete cd sgl-kernel diff --git a/benchmark/kernels/minmax-text-01-lightning_attention/benchmark_lightning_attention_decode.py b/benchmark/kernels/minmax-text-01-lightning_attention/benchmark_lightning_attention_decode.py index a2d1e10f662..57fbcfddf2c 100644 --- a/benchmark/kernels/minmax-text-01-lightning_attention/benchmark_lightning_attention_decode.py +++ b/benchmark/kernels/minmax-text-01-lightning_attention/benchmark_lightning_attention_decode.py @@ -9,6 +9,7 @@ import triton import triton.language as tl from einops import rearrange +from sgl_kernel import lightning_attention_decode as sgl_lightning_attention_decode @triton.jit @@ -332,7 +333,6 @@ def test_lightning_attention_implementations(model_params): model_params["num_attention_heads"], d, d, - dtype=dtype, device=device, ) with torch.no_grad(): @@ -350,7 +350,13 @@ def test_lightning_attention_implementations(model_params): q = q.transpose(1, 2) k = k.transpose(1, 2) v = v.transpose(1, 2) + q = q.contiguous() + k = k.contiguous() + v = v.contiguous() + past_kv = past_kv.contiguous() + slope_rate = slope_rate.contiguous() + # Test Triton implementation triton_output, triton_new_kv = lightning_attn_decode(q, k, v, past_kv, slope_rate) triton_output = triton_output.transpose(1, 2).contiguous() triton_output = triton_output.view(batch_size, seq_len, -1) @@ -358,22 +364,50 @@ def test_lightning_attention_implementations(model_params): triton_output = torch.sigmoid(model_attn.output_gate(hidden_states)) * triton_output triton_output = model_attn.out_proj(triton_output) + # Test SGL implementation + sgl_output = torch.empty_like(v) + sgl_new_kv = torch.empty_like(past_kv) + sgl_lightning_attention_decode(q, k, v, past_kv, slope_rate, sgl_output, sgl_new_kv) + + sgl_output = sgl_output.transpose(1, 2).contiguous() + sgl_output = sgl_output.view(batch_size, seq_len, -1) + sgl_output = model_attn.norm(sgl_output) + sgl_output = torch.sigmoid(model_attn.output_gate(hidden_states)) * sgl_output + sgl_output = model_attn.out_proj(sgl_output) + + # Verify Triton implementation results torch.testing.assert_close( model_output, triton_output, rtol=1e-3, atol=1e-2, - msg="Lightning attention implementations produce different output results", + msg="Triton lightning attention implementation produces different output results", ) torch.testing.assert_close( new_kv, triton_new_kv, rtol=1e-3, atol=1e-2, - msg="Lightning attention implementations produce different kv results", + msg="Triton lightning attention implementation produces different kv results", ) - print("✅ Two implementations match") + # Verify SGL implementation results + torch.testing.assert_close( + model_output, + sgl_output, + rtol=1e-3, + atol=1e-2, + msg="SGL lightning attention implementation produces different output results", + ) + torch.testing.assert_close( + new_kv, + sgl_new_kv, + rtol=1e-3, + atol=1e-2, + msg="SGL lightning attention implementation produces different kv results", + ) + + print("✅ All implementations match") def _build_slope_tensor(n_attention_heads: int): @@ -408,12 +442,13 @@ def get_benchmark(): x_names=["batch_size", "seq_len"], x_vals=[list(_) for _ in configs], line_arg="provider", - line_vals=["Original", "Triton"], + line_vals=["Original", "Triton", "SGL"], line_names=[ "Original PyTorch Implementation", "Triton Implementation", + "SGL Implementation", ], - styles=[("blue", "-"), ("green", "-")], + styles=[("blue", "-"), ("green", "-"), ("red", "-")], ylabel="us", plot_name="lightning-attention-decode-performance", args={}, @@ 
-446,7 +481,6 @@ def benchmark(batch_size, seq_len, provider): params["num_attention_heads"], d, d, - dtype=dtype, device=device, ) @@ -461,7 +495,7 @@ def benchmark(batch_size, seq_len, provider): ), quantiles=quantiles, ) - else: + elif provider == "Triton": def run_triton(): qkv = model_attn.act(model_attn.qkv_proj(hidden_states)) @@ -483,6 +517,33 @@ def run_triton(): run_triton, quantiles=quantiles, ) + else: # SGL + + def run_sgl(): + qkv = model_attn.act(model_attn.qkv_proj(hidden_states)) + new_shape = qkv.size()[:-1] + (model_attn.num_heads, -1) + qkv = qkv.view(*new_shape) + q, k, v = torch.split(qkv, [model_attn.head_dim] * 3, dim=-1) + q = q.transpose(1, 2).contiguous() + k = k.transpose(1, 2).contiguous() + v = v.transpose(1, 2).contiguous() + + output = torch.empty_like(v) + new_kv = torch.empty_like(past_kv) + sgl_lightning_attention_decode( + q, k, v, past_kv, slope_rate, output, new_kv + ) + + output = output.transpose(1, 2).contiguous() + output = output.view(batch_size, seq_len, -1) + output = model_attn.norm(output) + output = torch.sigmoid(model_attn.output_gate(hidden_states)) * output + return model_attn.out_proj(output) + + ms, min_ms, max_ms = triton.testing.do_bench( + run_sgl, + quantiles=quantiles, + ) return 1000 * ms, 1000 * max_ms, 1000 * min_ms diff --git a/docker/Dockerfile.dev b/docker/Dockerfile.dev index 9d05ee5997e..5ff1fa7a51a 100644 --- a/docker/Dockerfile.dev +++ b/docker/Dockerfile.dev @@ -67,12 +67,6 @@ RUN wget https://github.com/Kitware/CMake/releases/download/v3.31.1/cmake-3.31.1 && cp -r cmake-3.31.1-linux-x86_64/share/* /usr/local/share/ \ && rm -rf cmake-3.31.1-linux-x86_64 cmake-3.31.1-linux-x86_64.tar.gz -# Install uv -RUN curl -LsSf https://astral.sh/uv/install.sh | sh - -# Install rust -RUN curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y - # Add yank script COPY --chown=root:root <<-"EOF" /usr/local/bin/yank #!/bin/bash diff --git a/docs/references/benchmark_and_profiling.md b/docs/references/benchmark_and_profiling.md index 87ac5177424..0600b192b4f 100644 --- a/docs/references/benchmark_and_profiling.md +++ b/docs/references/benchmark_and_profiling.md @@ -64,16 +64,31 @@ with nvtx.annotate("description", color="color"): ```bash # set trace path export SGLANG_TORCH_PROFILER_DIR=/root/sglang/profile_log + # start server python -m sglang.launch_server --model-path meta-llama/Llama-3.1-8B-Instruct -python -m sglang.bench_serving --backend sglang --model-path meta-llama/Llama-3.1-8B-Instruct --num-prompts 10 --profile +# send profiling request from client +python -m sglang.bench_serving --backend sglang --model-path meta-llama/Llama-3.1-8B-Instruct --num-prompts 10 --sharegpt-output-len 100 --profile ``` - -Traces can be visualized using https://ui.perfetto.dev/. +Please make sure that `SGLANG_TORCH_PROFILER_DIR` is set on both the server and the client side; otherwise the trace file cannot be generated correctly. A reliable way is to set `SGLANG_TORCH_PROFILER_DIR` in the shell's `.*rc` file (e.g. `~/.bashrc` for bash). - To profile offline ```bash export SGLANG_TORCH_PROFILER_DIR=/root/sglang/profile_log python -m sglang.bench_offline_throughput --model-path meta-llama/Llama-3.1-8B-Instruct --dataset-name random --num-prompts 10 --profile --mem-frac=0.8 ``` + +- View Traces + +Trace files can be loaded and visualized from: +1. https://ui.perfetto.dev/ (any browser) +2. 
chrome://tracing (Chrome browser only) + +If the browser cannot open the trace file due to its large size, +you can generate a smaller trace file (<100MB) by limiting the number of prompts and the length of the prompt outputs. +For example, when profiling a server, +```bash +python -m sglang.bench_serving --backend sglang --model-path meta-llama/Llama-3.1-8B-Instruct --num-prompts 2 --sharegpt-output-len 100 --profile +``` +sets the number of prompts to 2 with the `--num-prompts` argument and limits the length of the output sequences to 100 with the `--sharegpt-output-len` argument, which produces a trace file small enough for the browser to open smoothly. diff --git a/python/sglang/srt/configs/load_config.py b/python/sglang/srt/configs/load_config.py index 2b2b341faeb..6cb35ab47c6 100644 --- a/python/sglang/srt/configs/load_config.py +++ b/python/sglang/srt/configs/load_config.py @@ -20,6 +20,7 @@ class LoadFormat(str, enum.Enum): GGUF = "gguf" BITSANDBYTES = "bitsandbytes" MISTRAL = "mistral" + LAYERED = "layered" @dataclass diff --git a/python/sglang/srt/layers/rotary_embedding.py b/python/sglang/srt/layers/rotary_embedding.py index 43478f39d2c..ad265830f8f 100644 --- a/python/sglang/srt/layers/rotary_embedding.py +++ b/python/sglang/srt/layers/rotary_embedding.py @@ -1018,7 +1018,12 @@ def get_rope( head_size, rotary_dim, max_position, base, is_neox_style, dtype ) else: - scaling_type = rope_scaling["rope_type"] + if "rope_type" in rope_scaling: + scaling_type = rope_scaling["rope_type"] + elif "type" in rope_scaling: + scaling_type = rope_scaling["type"] + else: + raise ValueError("Unknown RoPE scaling type") if scaling_type == "llama3": scaling_factor = rope_scaling["factor"] diff --git a/python/sglang/srt/layers/sampler.py b/python/sglang/srt/layers/sampler.py index f3c376ed1eb..3173d533d16 100644 --- a/python/sglang/srt/layers/sampler.py +++ b/python/sglang/srt/layers/sampler.py @@ -2,12 +2,19 @@ from typing import List import torch +import torch.distributed as dist from torch import nn +from sglang.srt.distributed import get_tensor_model_parallel_group +from sglang.srt.layers.dp_attention import get_attention_tp_group from sglang.srt.layers.logits_processor import LogitsProcessorOutput from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo -from sglang.srt.utils import crash_on_warnings, is_flashinfer_available +from sglang.srt.utils import ( + crash_on_warnings, + get_bool_env_var, + is_flashinfer_available, +) if is_flashinfer_available(): from flashinfer.sampling import ( @@ -20,11 +27,17 @@ logger = logging.getLogger(__name__) +SYNC_TOKEN_IDS_ACROSS_TP = get_bool_env_var("SYNC_TOKEN_IDS_ACROSS_TP") + class Sampler(nn.Module): def __init__(self): super().__init__() self.use_nan_detectioin = global_server_args_dict["enable_nan_detection"] + self.tp_sync_group = get_tensor_model_parallel_group().device_group + + if global_server_args_dict["enable_dp_attention"]: + self.tp_sync_group = get_attention_tp_group().device_group def forward( self, @@ -121,6 +134,20 @@ def forward( batch_next_token_ids, ] + if SYNC_TOKEN_IDS_ACROSS_TP or sampling_info.grammars: + # For performance reasons, SGLang does not sync the final token IDs across TP ranks by default. + # This saves one all-reduce, but the correctness of this approach depends on the determinism of several operators: + # the last all-reduce, the last lm_head matmul, and all sampling kernels. 
+ # These kernels are deterministic in most cases, but there are some rare instances where they are not deterministic. + # In such cases, enable this env variable to prevent hanging due to TP ranks becoming desynchronized. + # When using xgrammar, this becomes more likely so we also do the sync when grammar is used. + + torch.distributed.all_reduce( + batch_next_token_ids, + op=dist.ReduceOp.MIN, + group=self.tp_sync_group, + ) + return batch_next_token_ids.to(torch.int32) def _apply_custom_logit_processor( diff --git a/python/sglang/srt/layers/torchao_utils.py b/python/sglang/srt/layers/torchao_utils.py index c5bca25df37..e08abd5ae1d 100644 --- a/python/sglang/srt/layers/torchao_utils.py +++ b/python/sglang/srt/layers/torchao_utils.py @@ -5,6 +5,7 @@ import logging import os import pwd +from typing import Callable, Optional import torch @@ -27,8 +28,18 @@ def save_gemlite_cache(print_error: bool = False) -> bool: return True +def proj_filter( + module: torch.nn.Module, + fqn: str, +): + """Filter function for quantizing projection layers.""" + return "proj" in fqn + + def apply_torchao_config_to_model( - model: torch.nn.Module, torchao_config: str, filter_fn=None + model: torch.nn.Module, + torchao_config: str, + filter_fn: Optional[Callable] = proj_filter, ): """Quantize a modelwith torchao quantization specified by torchao_config @@ -49,11 +60,6 @@ def apply_torchao_config_to_model( ) from torchao.quantization.observer import PerRow, PerTensor - if filter_fn is None: - - def filter_fn(module, fqn): - return "proj" in fqn - if torchao_config == "" or torchao_config is None: return model elif "int8wo" in torchao_config: diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index d5cdcf2beb0..e7dc6bd66c5 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -185,9 +185,12 @@ def __init__( self.load_model() # Apply torchao quantization - apply_torchao_config_to_model( - self.model, global_server_args_dict["torchao_config"] - ) + torchao_applied = getattr(self.model, "torchao_applied", False) + # In layered loading, torchao may have been applied + if not torchao_applied: + apply_torchao_config_to_model( + self.model, global_server_args_dict["torchao_config"] + ) # Apply torch TP if the model supports it supports_torch_tp = getattr(self.model, "supports_torch_tp", False) diff --git a/python/sglang/srt/model_loader/loader.py b/python/sglang/srt/model_loader/loader.py index 677d716d43b..9e6b09488e6 100644 --- a/python/sglang/srt/model_loader/loader.py +++ b/python/sglang/srt/model_loader/loader.py @@ -374,6 +374,78 @@ def load_model( return model.eval() +class LayeredModelLoader(DefaultModelLoader): + """Model loader that loads weights layer by layer so that one can quantize a + layer before loading another to make the peak memory envelope smaller.""" + + def __init__(self, load_config: LoadConfig): + # Back to the default load format + load_config.load_format = LoadFormat.AUTO + super().__init__(load_config) + + def load_model( + self, + *, + model_config: ModelConfig, + device_config: DeviceConfig, + ) -> nn.Module: + from sglang.srt.layers.torchao_utils import apply_torchao_config_to_model + from sglang.srt.managers.schedule_batch import global_server_args_dict + + torchao_config = global_server_args_dict.get("torchao_config") + target_device = torch.device(device_config.device) + + with set_default_torch_dtype(model_config.dtype): + # Create model on meta device + 
with torch.device("meta"): + model = _initialize_model( + model_config, + self.load_config, + ) + + # Check model's layered load support + if not hasattr(model, "load_weights_to_module"): + raise ValueError( + "LayeredModelLoader requires the model to have a " + "`load_weights_to_module` method. " + f"{model_config.model_path} does not support it." + ) + + # Get all weights from disk + weights = self._get_all_weights(model_config, model) + + # Helper function to recursively fill the weights of a module + def fill_module(module, fqn: List[str], weights): + """ + fqn: list of strings representing the fully qualified name of `module`. + """ + # Layer by layer + for name, submod in module.named_children(): + fill_module(submod, fqn + [name], weights) + + # First materialize on target device + module.to_empty(device=target_device, recurse=False) + fqn_path = ".".join(fqn) + # Fill weights + model.load_weights_to_module( + fqn_path, + weights, + ) + # Quantize weights if applicable + if torchao_config and "proj" in fqn_path: + # Note: `None` here is needed to indicate no filter, see + # `apply_torchao_config_to_model` for details. + apply_torchao_config_to_model(module, torchao_config, None) + + # Start calling on root module + fill_module(model, [], weights) + + if torchao_config: + model.torchao_applied = True + + return model.eval() + + class DummyModelLoader(BaseModelLoader): """Model loader that will set model weights to random values.""" @@ -1149,4 +1221,7 @@ def get_model_loader(load_config: LoadConfig) -> BaseModelLoader: if load_config.load_format == LoadFormat.GGUF: return GGUFModelLoader(load_config) + if load_config.load_format == LoadFormat.LAYERED: + return LayeredModelLoader(load_config) + return DefaultModelLoader(load_config) diff --git a/python/sglang/srt/model_loader/weight_utils.py b/python/sglang/srt/model_loader/weight_utils.py index 77c3fcbee74..f2f67ecab1d 100644 --- a/python/sglang/srt/model_loader/weight_utils.py +++ b/python/sglang/srt/model_loader/weight_utils.py @@ -27,6 +27,7 @@ import numpy as np import torch from huggingface_hub import HfFileSystem, hf_hub_download, snapshot_download +from pydantic import BaseModel, ConfigDict, ValidationInfo, model_validator from safetensors.torch import load_file, safe_open, save_file from tqdm.auto import tqdm @@ -650,6 +651,81 @@ def maybe_remap_kv_scale_name(name: str, params_dict: dict) -> Optional[str]: return name +# Adapted from https://github.com/vllm-project/vllm/blob/68ad4e3a8d8a66fb2a43be57471ee13a8bec4ec0/vllm/model_executor/layers/quantization/schema.py +class KVCacheQuantSchema(BaseModel): + dtype: str + # Each key is a TP rank. Each value is a dictionary mapping a TP rank's + # layer indices to their per-tensor KV cache scaling factor. + # TODO: Consider pulling this and its validation methods out into its + # own schema class (tricky as its members are variable) + scaling_factor: Dict[int, Dict[int, float]] + + @model_validator(mode="after") + def check_is_fp8(self) -> "KVCacheQuantSchema": + assert self.dtype == "float8_e4m3fn", ( + "Loaded scaling factors intended for KV cache dtype = " + f"{self.dtype} rather than float8_e4m3fn!" 
+ ) + return self + + @model_validator(mode="after") + def check_tp_ranks(self, info: ValidationInfo) -> "KVCacheQuantSchema": + context = info.context + if context: + tp_size = context["tp_size"] + num_hidden_layers = context["num_hidden_layers"] + assert len(self.scaling_factor) == tp_size, ( + f"Loaded dictionary has TP size {len(self.scaling_factor)} " + f"but LLM engine is currently running with TP size {tp_size}." + ) + for tp_rank, layer_maps in self.scaling_factor.items(): + assert len(layer_maps) == num_hidden_layers, ( + f"KV cache scales map for TP rank {tp_rank} is malformed. " + f"Expected {num_hidden_layers} layers, got " + f"{len(layer_maps)}." + ) + for i in range(tp_size): + assert ( + i in self.scaling_factor + ), f"KV cache scales map for TP rank {i} not found." + return self + + @model_validator(mode="after") + def check_current_rank(self, info: ValidationInfo) -> "KVCacheQuantSchema": + context = info.context + if context: + tp_rank = context["tp_rank"] + num_hidden_layers = context["num_hidden_layers"] + layer_scales_map = self.scaling_factor[tp_rank] + for i in range(num_hidden_layers): + assert i in layer_scales_map, ( + f"Could not find KV cache scales for layer {i} in " + f"TP rank {tp_rank}." + ) + return self + + +class QuantParamSchema(BaseModel): + # TODO: Generalize and extend with more fields + # (e.g. weights/activations params) once functionality is enabled + model_config = ConfigDict(protected_namespaces=()) + model_type: Optional[str] + kv_cache: KVCacheQuantSchema + + @model_validator(mode="after") + def check_model_type(self, info: ValidationInfo) -> "QuantParamSchema": + context = info.context + if context: + model_type = context.get("model_type", None) + if model_type is not None: + assert model_type == self.model_type, ( + f"Model type is {model_type} but loaded " + f"scaling factors belonging to different " + f"model type {self.model_type}!" 
+ ) + return self + + def kv_cache_scales_loader( filename: str, tp_rank: int, @@ -681,7 +757,7 @@ def kv_cache_scales_loader( except json.JSONDecodeError: logger.error("Error decoding JSON in file '%s'.", filename) except Exception: - logger.exception("An error occurred while reading '%s'.", filename) + logger.error("An error occurred while reading '%s'.", filename) # This section is reached if and only if any of the excepts are hit # Return an empty iterable (list) => no KV cache scales are loaded # which ultimately defaults to 1.0 scales diff --git a/python/sglang/srt/models/torch_native_llama.py b/python/sglang/srt/models/torch_native_llama.py index 024a6f317fa..7b3e5bc5ddd 100644 --- a/python/sglang/srt/models/torch_native_llama.py +++ b/python/sglang/srt/models/torch_native_llama.py @@ -460,7 +460,12 @@ def get_num_params(self): params_dict = dict(self.named_parameters()) return len(params_dict) - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + def load_weights_to_module( + self, + fqn: str, + weights: Iterable[Tuple[str, torch.Tensor]], + ): + """Load weights onto submodule pointed by path `fqn`.""" stacked_params_mapping = [ # (param_name, shard_name, shard_id) (".qkv_proj", ".q_proj", "q"), @@ -469,7 +474,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): (".gate_up_proj", ".gate_proj", 0), (".gate_up_proj", ".up_proj", 1), ] - params_dict = dict(self.named_parameters()) + module = self.get_submodule(fqn) + params_dict = dict(module.named_parameters(prefix=fqn, recurse=False)) for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name or "projector" in name: @@ -486,7 +492,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): continue name = name.replace(weight_name, param_name) # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: + if name.endswith(".bias") or name not in params_dict: continue param = params_dict[name] weight_loader = param.weight_loader @@ -494,12 +500,19 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): break else: # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: + if name.endswith(".bias") or name not in params_dict: continue param = params_dict[name] weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) + def load_weights( + self, + weights: Iterable[Tuple[str, torch.Tensor]], + ): + """Load weights onto the full model.""" + self.load_weights_to_module("", weights) + class TorchNativePhi3ForCausalLM(TorchNativeLlamaForCausalLM): pass diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 4a7a28751db..330c3813288 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -317,6 +317,7 @@ def add_cli_args(parser: argparse.ArgumentParser): "dummy", "gguf", "bitsandbytes", + "layered", ], help="The format of the model weights to load. " '"auto" will try to load the weights in the safetensors format ' @@ -330,7 +331,10 @@ def add_cli_args(parser: argparse.ArgumentParser): "which is mainly for profiling." '"gguf" will load the weights in the gguf format. ' '"bitsandbytes" will load the weights using bitsandbytes ' - "quantization.", + "quantization." 
+ '"layered" loads weights layer by layer so that one can quantize a ' + "layer before loading another to make the peak memory envelope " + "smaller.", ) parser.add_argument( "--trust-remote-code", diff --git a/sgl-kernel/Makefile b/sgl-kernel/Makefile index c7641bb5fee..9261b896934 100644 --- a/sgl-kernel/Makefile +++ b/sgl-kernel/Makefile @@ -19,7 +19,7 @@ clean: @rm -rf build dist *.egg-info test: - @find tests -name "test_*.py" | xargs -n 1 python3 + @find tests -name "test_*.py" | xargs -n 1 python3 && pytest tests/test_norm.py && pytest tests/test_activation.py format: @find src tests -name '*.cc' -o -name '*.cu' -o -name '*.cuh' -o -name '*.h' -o -name '*.hpp' | xargs clang-format -i && find src tests -name '*.py' | xargs isort && find src tests -name '*.py' | xargs black diff --git a/sgl-kernel/benchmark/bench_lightning_attention_decode.py b/sgl-kernel/benchmark/bench_lightning_attention_decode.py new file mode 100644 index 00000000000..24872e61a4d --- /dev/null +++ b/sgl-kernel/benchmark/bench_lightning_attention_decode.py @@ -0,0 +1,299 @@ +import itertools +import math + +import torch +import triton +import triton.language as tl +from sgl_kernel import lightning_attention_decode + + +def next_power_of_2(n): + return 2 ** (int(math.ceil(math.log(n, 2)))) + + +@triton.jit +def _decode_kernel( + Q, + K, + V, + KV, + Out, + S, + b: tl.constexpr, + h: tl.constexpr, + n: tl.constexpr, + d: tl.constexpr, + d_original: tl.constexpr, + e: tl.constexpr, + e_original: tl.constexpr, +): + off_bh = tl.program_id(0) + off_h = off_bh % h + + qk_offset = off_bh * n * d + v_offset = off_bh * n * e + o_offset = off_bh * n * e + kv_offset = off_bh * d * e + + s = tl.load(S + off_h) + ratio = tl.exp(-s) + + d_idx = tl.arange(0, d) + e_idx = tl.arange(0, e) + + # Create masks for original dimensions + d_mask = d_idx < d_original + e_mask = e_idx < e_original + + # Load with masking + q = tl.load(Q + qk_offset + d_idx, mask=d_mask, other=0.0) + k = tl.load(K + qk_offset + d_idx, mask=d_mask, other=0.0) + v = tl.load(V + v_offset + e_idx, mask=e_mask, other=0.0) + + # Load KV with 2D masking + kv = tl.load( + KV + kv_offset + d_idx[:, None] * e + e_idx[None, :], + mask=(d_mask[:, None] & e_mask[None, :]), + other=0.0, + ) + + # Compute outer product using element-wise operations + k_v_prod = k[:, None] * v[None, :] + kv = ratio * kv + k_v_prod + + # Store KV with 2D masking + tl.store( + KV + kv_offset + d_idx[:, None] * e + e_idx[None, :], + kv.to(KV.dtype.element_ty), + mask=(d_mask[:, None] & e_mask[None, :]), + ) + + # Compute matrix-vector multiplication using element-wise operations and reduction + o = tl.sum(q[:, None] * kv, axis=0) + + # Store output with masking + tl.store(Out + o_offset + e_idx, o.to(Out.dtype.element_ty), mask=e_mask) + + +def triton_lightning_attn_decode(q, k, v, kv, s): + """Triton implementation of Lightning Attention decode operation""" + b, h, n, d = q.shape + e = v.shape[-1] + assert n == 1, "Sequence length must be 1 in decode mode" + + # Get padded dimensions (power of 2) + d_padded = next_power_of_2(d) + e_padded = next_power_of_2(e) + + # Create output tensor (padded) + o_padded = torch.empty(b, h, n, e_padded, dtype=v.dtype, device=v.device) + + # Create padded tensors without actually padding the data + q_padded = torch.empty(b, h, n, d_padded, dtype=q.dtype, device=q.device) + k_padded = torch.empty(b, h, n, d_padded, dtype=k.dtype, device=k.device) + v_padded = torch.empty(b, h, n, e_padded, dtype=v.dtype, device=v.device) + kv_padded = torch.empty( + b, h, 
d_padded, e_padded, dtype=torch.float32, device=kv.device + ) + + # Copy data to padded tensors + q_padded[..., :d] = q + k_padded[..., :d] = k + v_padded[..., :e] = v + kv_padded[..., :d, :e] = kv + + # Launch kernel + grid = (b * h, 1) + _decode_kernel[grid]( + q_padded, + k_padded, + v_padded, + kv_padded, + o_padded, + s, + b=b, + h=h, + n=n, + d=d_padded, + d_original=d, + e=e_padded, + e_original=e, + ) + + # Get unpadded outputs + o = o_padded[..., :e] + kv_out = kv_padded[..., :d, :e] + + return o, kv_out + + +def lightning_attention_decode_naive(q, k, v, past_kv, slope): + """Naive implementation of lightning attention decode""" + original_dtype = q.dtype + ratio = torch.exp(-slope) # [h, 1, 1] + + kv = past_kv + b, h, n, d = q.shape + + output = [] + for i in range(n): + kv = ratio * kv.to(torch.float32) + torch.einsum( + "... n d, ... n e -> ... d e", + k[:, :, i : i + 1], + v[:, :, i : i + 1], + ) + qkv = torch.einsum( + "... n e, ... e d -> ... n d", + q[:, :, i : i + 1].to(torch.float32), + kv.to(torch.float32), + ) + output.append(qkv) + output = torch.concat(output, dim=-2) + + return output.to(original_dtype), kv + + +def lightning_attention_decode_kernel(q, k, v, past_kv, slope, output, new_kv): + return lightning_attention_decode(q, k, v, past_kv, slope, output, new_kv) + + +def calculate_diff(batch_size): + dtype = torch.bfloat16 + device = torch.device("cuda") + num_heads = 64 + head_dim = 96 + seq_len = 1 + + q = torch.randn( + batch_size, num_heads, seq_len, head_dim, device=device, dtype=dtype + ) + k = torch.randn( + batch_size, num_heads, seq_len, head_dim, device=device, dtype=dtype + ) + v = torch.randn( + batch_size, num_heads, seq_len, head_dim, device=device, dtype=dtype + ) + past_kv = torch.randn(batch_size, num_heads, head_dim, head_dim, device=device) + slope = torch.randn(num_heads, 1, 1, device=device) + + output_naive, new_kv_naive = lightning_attention_decode_naive( + q.clone(), k.clone(), v.clone(), past_kv.clone(), slope.clone() + ) + + output_kernel = torch.empty_like(output_naive) + new_kv_kernel = torch.empty_like(new_kv_naive) + lightning_attention_decode_kernel( + q.clone(), + k.clone(), + v.clone(), + past_kv.clone(), + slope.clone(), + output_kernel, + new_kv_kernel, + ) + + output_triton, new_kv_triton = triton_lightning_attn_decode( + q.clone(), k.clone(), v.clone(), past_kv.clone(), slope.clone() + ) + + if ( + torch.allclose(output_naive, output_kernel, atol=1e-2, rtol=1e-2) + and torch.allclose(output_naive, output_triton, atol=1e-2, rtol=1e-2) + and torch.allclose(new_kv_naive, new_kv_kernel, atol=1e-2, rtol=1e-2) + and torch.allclose(new_kv_naive, new_kv_triton, atol=1e-2, rtol=1e-2) + ): + print("✅ All implementations match") + else: + print("❌ Implementations differ") + + +batch_size_range = [i for i in range(1, 65)] # 1 to 128 +configs = [(bs,) for bs in batch_size_range] + + +@triton.testing.perf_report( + triton.testing.Benchmark( + x_names=["batch_size"], + x_vals=[list(_) for _ in configs], + line_arg="provider", + line_vals=["naive", "kernel", "triton"], + line_names=["PyTorch Naive", "SGL Kernel", "Triton"], + styles=[("blue", "-"), ("red", "-"), ("green", "-")], + ylabel="us", + plot_name="lightning-attention-decode-performance", + args={}, + ) +) +def benchmark(batch_size, provider): + dtype = torch.bfloat16 + device = torch.device("cuda") + num_heads = 64 + head_dim = 96 + seq_len = 1 + + q = torch.randn( + batch_size, num_heads, seq_len, head_dim, device=device, dtype=dtype + ) + k = torch.randn( + batch_size, num_heads, 
seq_len, head_dim, device=device, dtype=dtype + ) + v = torch.randn( + batch_size, num_heads, seq_len, head_dim, device=device, dtype=dtype + ) + past_kv = torch.randn(batch_size, num_heads, head_dim, head_dim, device=device) + slope = torch.randn(num_heads, 1, 1, device=device) + + quantiles = [0.5, 0.2, 0.8] + + if provider == "naive": + ms, min_ms, max_ms = triton.testing.do_bench( + lambda: lightning_attention_decode_naive( + q.clone(), k.clone(), v.clone(), past_kv.clone(), slope.clone() + ), + quantiles=quantiles, + ) + elif provider == "kernel": + output = torch.empty( + batch_size, num_heads, seq_len, head_dim, device=device, dtype=dtype + ) + new_kv = torch.empty(batch_size, num_heads, head_dim, head_dim, device=device) + ms, min_ms, max_ms = triton.testing.do_bench( + lambda: lightning_attention_decode_kernel( + q.clone(), + k.clone(), + v.clone(), + past_kv.clone(), + slope.clone(), + output, + new_kv, + ), + quantiles=quantiles, + ) + elif provider == "triton": + ms, min_ms, max_ms = triton.testing.do_bench( + lambda: triton_lightning_attn_decode( + q.clone(), k.clone(), v.clone(), past_kv.clone(), slope.clone() + ), + quantiles=quantiles, + ) + + return 1000 * ms, 1000 * max_ms, 1000 * min_ms + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument( + "--save_path", + type=str, + default="./configs/benchmark_ops/lightning_attention_decode_sgl/", + help="Path to save lightning attention decode benchmark results", + ) + args = parser.parse_args() + + # Run correctness test + calculate_diff(batch_size=4) + + # Run performance benchmark + benchmark.run(print_data=True) diff --git a/sgl-kernel/developer_guide.md b/sgl-kernel/developer_guide.md new file mode 100644 index 00000000000..8afb6b0e460 --- /dev/null +++ b/sgl-kernel/developer_guide.md @@ -0,0 +1,51 @@ +# Developer Guide for sgl-kernel + +## Development Environment Setup + +Use Docker to set up the development environment. See [Docker setup guide](https://github.com/sgl-project/sglang/blob/main/docs/developer/development_guide_using_docker.md#setup-docker-container). + +Create and enter development container: +```bash +docker run -itd --shm-size 32g --gpus all -v $HOME/.cache:/root/.cache --ipc=host --name sglang_zhyncs lmsysorg/sglang:dev /bin/zsh +docker exec -it sglang_zhyncs /bin/zsh +``` + +## Project Structure + +### Dependencies + +Third-party libraries: + +- [CCCL](https://github.com/NVIDIA/cccl) +- [CUTLASS](https://github.com/NVIDIA/cutlass) +- [FlashInfer](https://github.com/flashinfer-ai/flashinfer) + +### Kernel Development + +Steps to add a new kernel: + +1. Implement in [src/sgl-kernel/csrc/](https://github.com/sgl-project/sglang/tree/main/sgl-kernel/src/sgl-kernel/csrc) +2. Expose interface in [csrc/sgl_kernel_ops.cu](https://github.com/sgl-project/sglang/blob/main/sgl-kernel/src/sgl-kernel/csrc/sgl_kernel_ops.cu) with pybind11 +3. Create Python wrapper in [src/sgl-kernel/ops/__init__.py](https://github.com/sgl-project/sglang/blob/main/sgl-kernel/src/sgl-kernel/ops/__init__.py) +4. Expose Python interface in [src/sgl-kernel/__init__.py](https://github.com/sgl-project/sglang/blob/main/sgl-kernel/src/sgl-kernel/__init__.py) +5. Update [setup.py](https://github.com/sgl-project/sglang/blob/main/sgl-kernel/setup.py) to include new CUDA source + +### Build & Install + +Development build: + +```bash +make build +pip3 install dist/*whl --force-reinstall --no-deps +# Or use: make install (runs pip install -e .) +``` + +### Testing & Benchmarking + +1. 
Add pytest tests in [tests/](https://github.com/sgl-project/sglang/tree/main/sgl-kernel/tests) +2. Add benchmarks using [triton benchmark](https://triton-lang.org/main/python-api/generated/triton.testing.Benchmark.html) in [benchmark/](https://github.com/sgl-project/sglang/tree/main/sgl-kernel/benchmark) +3. Run test suite + +### Release new version + +Update version in [pyproject.toml](https://github.com/sgl-project/sglang/blob/main/sgl-kernel/pyproject.toml) diff --git a/sgl-kernel/pyproject.toml b/sgl-kernel/pyproject.toml index ab9d68b44c8..11e9880a5af 100644 --- a/sgl-kernel/pyproject.toml +++ b/sgl-kernel/pyproject.toml @@ -14,9 +14,7 @@ classifiers = [ "License :: OSI Approved :: Apache Software License", "Environment :: GPU :: NVIDIA CUDA" ] -dependencies = [ - "torch", -] +dependencies = [] [project.urls] "Homepage" = "https://github.com/sgl-project/sglang" diff --git a/sgl-kernel/setup.py b/sgl-kernel/setup.py index b9324c35543..c51fd704504 100644 --- a/sgl-kernel/setup.py +++ b/sgl-kernel/setup.py @@ -62,12 +62,22 @@ def get_device_sm(): "-std=c++17", "-use_fast_math", "-DFLASHINFER_ENABLE_F16", - "-DFLASHINFER_ENABLE_BF16", ] if cuda_version >= (12, 0) and sm_version >= 90: nvcc_flags.append("-gencode=arch=compute_90a,code=sm_90a") +if sm_version >= 90: + nvcc_flags.extend( + [ + "-DFLASHINFER_ENABLE_FP8", + "-DFLASHINFER_ENABLE_FP8_E4M3", + "-DFLASHINFER_ENABLE_FP8_E5M2", + ] + ) +if sm_version >= 80: + nvcc_flags.append("-DFLASHINFER_ENABLE_BF16") + for flag in [ "-D__CUDA_NO_HALF_OPERATORS__", "-D__CUDA_NO_HALF_CONVERSIONS__", @@ -90,6 +100,7 @@ def get_device_sm(): "src/sgl-kernel/csrc/moe_align_kernel.cu", "src/sgl-kernel/csrc/int8_gemm_kernel.cu", "src/sgl-kernel/csrc/sampling_scaling_penalties.cu", + "src/sgl-kernel/csrc/lightning_attention_decode_kernel.cu", "src/sgl-kernel/csrc/sgl_kernel_ops.cu", "src/sgl-kernel/csrc/rotary_embedding.cu", "3rdparty/flashinfer/csrc/activation.cu", @@ -116,7 +127,6 @@ def get_device_sm(): package_dir={"": "src"}, ext_modules=ext_modules, cmdclass={"build_ext": BuildExtension}, - install_requires=["torch"], ) update_wheel_platform_tag() diff --git a/sgl-kernel/src/sgl-kernel/__init__.py b/sgl-kernel/src/sgl-kernel/__init__.py index bdbc0ce846c..9eaa64e5083 100644 --- a/sgl-kernel/src/sgl-kernel/__init__.py +++ b/sgl-kernel/src/sgl-kernel/__init__.py @@ -1,23 +1,31 @@ from sgl_kernel.ops import ( + bmm_fp8, custom_dispose, custom_reduce, fused_add_rmsnorm, + gelu_and_mul, + gelu_tanh_and_mul, gemma_fused_add_rmsnorm, gemma_rmsnorm, get_graph_buffer_ipc_meta, init_custom_reduce, int8_scaled_mm, + lightning_attention_decode, moe_align_block_size, register_graph_buffers, rmsnorm, rotary_embedding, sampling_scaling_penalties, + silu_and_mul, ) __all__ = [ + "bmm_fp8", "custom_dispose", "custom_reduce", "fused_add_rmsnorm", + "gelu_and_mul", + "gelu_tanh_and_mul", "gemma_fused_add_rmsnorm", "gemma_rmsnorm", "get_graph_buffer_ipc_meta", @@ -28,4 +36,6 @@ "rmsnorm", "rotary_embedding", "sampling_scaling_penalties", + "lightning_attention_decode", + "silu_and_mul", ] diff --git a/sgl-kernel/src/sgl-kernel/csrc/int8_gemm_kernel.cu b/sgl-kernel/src/sgl-kernel/csrc/int8_gemm_kernel.cu index 8e3f7275702..c77851c32b6 100644 --- a/sgl-kernel/src/sgl-kernel/csrc/int8_gemm_kernel.cu +++ b/sgl-kernel/src/sgl-kernel/csrc/int8_gemm_kernel.cu @@ -16,7 +16,7 @@ #include "cutlass_extensions/epilogue/epilogue_per_row_per_col_scale.h" #include "cutlass_extensions/gemm/gemm_universal_base_compat.h" #include "cutlass_extensions/gemm/gemm_with_epilogue_visitor.h" 
-#include "utils.hpp" +#include "utils.h" using namespace cute; diff --git a/sgl-kernel/src/sgl-kernel/csrc/lightning_attention_decode_kernel.cu b/sgl-kernel/src/sgl-kernel/csrc/lightning_attention_decode_kernel.cu new file mode 100644 index 00000000000..eb79373b22c --- /dev/null +++ b/sgl-kernel/src/sgl-kernel/csrc/lightning_attention_decode_kernel.cu @@ -0,0 +1,119 @@ +#include +#include +#include +#include +#include + +#include "utils.h" + +#define THREADS_PER_BLOCK 128 + +template +__global__ void lightning_attention_decode_kernel(const T* __restrict__ q, // [b, h, 1, d] + const T* __restrict__ k, // [b, h, 1, d] + const T* __restrict__ v, // [b, h, 1, e] + const float* __restrict__ past_kv, // [b, h, d, e] + const float* __restrict__ slope, // [h, 1, 1] + T* __restrict__ output, // [b, h, 1, e] + float* __restrict__ new_kv, // [b, h, d, e] + const int batch_size, const int num_heads, const int qk_dim, + const int v_dim) { + extern __shared__ char smem[]; + T* q_shared = reinterpret_cast(smem); + T* k_shared = reinterpret_cast(smem + qk_dim * sizeof(T)); + T* v_shared = reinterpret_cast(smem + 2 * qk_dim * sizeof(T)); + float* new_kv_shared = reinterpret_cast(smem + (2 * qk_dim + v_dim) * sizeof(T)); + T* output_shared = + reinterpret_cast(smem + (2 * qk_dim + v_dim) * sizeof(T) + qk_dim * (v_dim + 1) * sizeof(float)); + + const int32_t tid = threadIdx.x; + const int32_t current_head = blockIdx.x; + const int32_t b = current_head / num_heads; + const int32_t h = current_head % num_heads; + + if (b >= batch_size) return; + + const int32_t qk_offset = b * num_heads * qk_dim + h * qk_dim; + const int32_t v_offset = b * num_heads * v_dim + h * v_dim; + const int32_t kv_offset = b * num_heads * qk_dim * v_dim + h * qk_dim * v_dim; + + for (int d = tid; d < qk_dim; d += blockDim.x) { + q_shared[d] = q[qk_offset + d]; + k_shared[d] = k[qk_offset + d]; + } + for (int e = tid; e < v_dim; e += blockDim.x) { + v_shared[e] = v[v_offset + e]; + } + + __syncthreads(); + + const float ratio = expf(-1.0f * slope[h]); + + for (int d = tid; d < qk_dim; d += blockDim.x) { + T k_val = k_shared[d]; + for (int e = 0; e < v_dim; ++e) { + int past_kv_idx = kv_offset + d * v_dim + e; + T v_val = v_shared[e]; + float new_val = ratio * past_kv[past_kv_idx] + k_val * v_val; + int shared_idx = d * (v_dim + 1) + e; + new_kv_shared[shared_idx] = new_val; + } + } + + __syncthreads(); + + for (int idx = tid; idx < qk_dim * v_dim; idx += blockDim.x) { + int d = idx / v_dim; + int e = idx % v_dim; + int shared_idx = d * (v_dim + 1) + e; + int global_idx = kv_offset + idx; + new_kv[global_idx] = new_kv_shared[shared_idx]; + } + + __syncthreads(); + + for (int e = tid; e < v_dim; e += blockDim.x) { + float sum = 0.0f; + for (int d = 0; d < qk_dim; ++d) { + int shared_idx = d * (v_dim + 1) + e; + sum += q_shared[d] * new_kv_shared[shared_idx]; + } + output_shared[e] = static_cast(sum); + } + + __syncthreads(); + + if (tid == 0) { + for (int e = 0; e < v_dim; ++e) { + output[v_offset + e] = output_shared[e]; + } + } +} + +void lightning_attention_decode(const torch::Tensor& q, const torch::Tensor& k, const torch::Tensor& v, + const torch::Tensor& past_kv, const torch::Tensor& slope, torch::Tensor output, + torch::Tensor new_kv) { + TORCH_CHECK(q.is_contiguous(), "q must be contiguous"); + TORCH_CHECK(k.is_contiguous(), "k must be contiguous"); + TORCH_CHECK(v.is_contiguous(), "v must be contiguous"); + TORCH_CHECK(past_kv.is_contiguous(), "past_kv must be contiguous"); + + auto batch_size = q.size(0); + auto num_heads = 
q.size(1); + auto qk_dim = q.size(3); + auto v_dim = v.size(3); + + dim3 block(THREADS_PER_BLOCK); + dim3 grid(batch_size * num_heads); + + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + AT_DISPATCH_FLOATING_TYPES_AND2( + at::ScalarType::Half, at::ScalarType::BFloat16, q.scalar_type(), "lightning_attention_decode_kernel", ([&] { + size_t smem_size = (2 * qk_dim + 2 * v_dim) * sizeof(scalar_t) + qk_dim * (v_dim + 1) * sizeof(float); + lightning_attention_decode_kernel<<>>( + q.data_ptr(), k.data_ptr(), v.data_ptr(), past_kv.data_ptr(), + slope.data_ptr(), output.data_ptr(), new_kv.data_ptr(), batch_size, num_heads, + qk_dim, v_dim); + })); +} diff --git a/sgl-kernel/src/sgl-kernel/csrc/moe_align_kernel.cu b/sgl-kernel/src/sgl-kernel/csrc/moe_align_kernel.cu index c7faf9d3775..83861aee071 100644 --- a/sgl-kernel/src/sgl-kernel/csrc/moe_align_kernel.cu +++ b/sgl-kernel/src/sgl-kernel/csrc/moe_align_kernel.cu @@ -6,7 +6,7 @@ #include -#include "utils.hpp" +#include "utils.h" #ifdef USE_ROCM #include diff --git a/sgl-kernel/src/sgl-kernel/csrc/sampling_scaling_penalties.cu b/sgl-kernel/src/sgl-kernel/csrc/sampling_scaling_penalties.cu index a61d4b86059..2f53bb1a99f 100644 --- a/sgl-kernel/src/sgl-kernel/csrc/sampling_scaling_penalties.cu +++ b/sgl-kernel/src/sgl-kernel/csrc/sampling_scaling_penalties.cu @@ -4,7 +4,7 @@ #include -#include "utils.hpp" +#include "utils.h" #include "vectorization.cuh" template diff --git a/sgl-kernel/src/sgl-kernel/csrc/sgl_kernel_ops.cu b/sgl-kernel/src/sgl-kernel/csrc/sgl_kernel_ops.cu index 8f9d1ae5333..cd5df07895a 100644 --- a/sgl-kernel/src/sgl-kernel/csrc/sgl_kernel_ops.cu +++ b/sgl-kernel/src/sgl-kernel/csrc/sgl_kernel_ops.cu @@ -1,6 +1,6 @@ #include -#include "utils.hpp" +#include "utils.h" // trt_reduce using fptr_t = int64_t; @@ -26,6 +26,11 @@ torch::Tensor int8_scaled_mm(const torch::Tensor& mat_a, const torch::Tensor& ma const torch::Tensor& scales_b, const torch::Dtype& out_dtype, const c10::optional& bias); +// lightning_attention_decode +void lightning_attention_decode(const torch::Tensor& q, const torch::Tensor& k, const torch::Tensor& v, + const torch::Tensor& past_kv, const torch::Tensor& slope, torch::Tensor output, + torch::Tensor new_kv); + // rotary embedding void rotary_embedding(torch::Tensor& positions, torch::Tensor& query, torch::Tensor& key, int64_t head_size, torch::Tensor& cos_sin_cache, bool is_neox); @@ -43,6 +48,19 @@ void gemma_rmsnorm(at::Tensor& output, at::Tensor& input, at::Tensor& weight, do void gemma_fused_add_rmsnorm(at::Tensor& input, at::Tensor& residual, at::Tensor& weight, double eps, int64_t cuda_stream); +// silu and mul +void silu_and_mul(at::Tensor& out, at::Tensor& input, int64_t cuda_stream); + +// gelu tanh and mul +void gelu_tanh_and_mul(at::Tensor& out, at::Tensor& input, int64_t cuda_stream); + +// gelu and mul +void gelu_and_mul(at::Tensor& out, at::Tensor& input, int64_t cuda_stream); + +// bmm fp8 +void bmm_fp8(at::Tensor A, at::Tensor B, at::Tensor D, at::Tensor A_scale, at::Tensor B_scale, + at::Tensor workspace_buffer, int64_t cublas_handle, int64_t cuda_stream); + PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { // trt_reduce m.def("init_custom_ar", &init_custom_ar, "init custom allreduce meta (CUDA)"); @@ -56,6 +74,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("sampling_scaling_penalties", &sampling_scaling_penalties, "Sampling scaling penalties (CUDA)"); // int8_scaled_mm m.def("int8_scaled_mm", &int8_scaled_mm, "INT8 scaled matmul (CUDA)"); + // lightning_attention_decode + 
m.def("lightning_attention_decode", &lightning_attention_decode, "Lightning Attention Ddecode (CUDA)"); // rotary embedding m.def("rotary_embedding", &rotary_embedding, "Rotary Embedding (CUDA)"); // rms norm @@ -66,4 +86,12 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("gemma_rmsnorm", &gemma_rmsnorm, "Gemma RMSNorm (CUDA)"); // fused gemma rms norm m.def("gemma_fused_add_rmsnorm", &gemma_fused_add_rmsnorm, "Gemma Fused Add RMSNorm (CUDA)"); + // silu and mul + m.def("silu_and_mul", &silu_and_mul, "Silu and Mul (CUDA)"); + // gelu tanh and mul + m.def("gelu_tanh_and_mul", &gelu_tanh_and_mul, "Gelu Tanh and Mul (CUDA)"); + // gelu and mul + m.def("gelu_and_mul", &gelu_and_mul, "Gelu and Mul (CUDA)"); + // bmm fp8 + m.def("bmm_fp8", &bmm_fp8, "BMM FP8 (CUDA)"); } diff --git a/sgl-kernel/src/sgl-kernel/csrc/trt_reduce_internal.cuh b/sgl-kernel/src/sgl-kernel/csrc/trt_reduce_internal.cuh index 9d6f9722eb5..22ba0e414fc 100644 --- a/sgl-kernel/src/sgl-kernel/csrc/trt_reduce_internal.cuh +++ b/sgl-kernel/src/sgl-kernel/csrc/trt_reduce_internal.cuh @@ -21,7 +21,7 @@ #include #include -#include "utils.hpp" +#include "utils.h" namespace trt_llm { constexpr size_t WARP_SIZE = 32; diff --git a/sgl-kernel/src/sgl-kernel/csrc/utils.hpp b/sgl-kernel/src/sgl-kernel/csrc/utils.h similarity index 100% rename from sgl-kernel/src/sgl-kernel/csrc/utils.hpp rename to sgl-kernel/src/sgl-kernel/csrc/utils.h diff --git a/sgl-kernel/src/sgl-kernel/ops/__init__.py b/sgl-kernel/src/sgl-kernel/ops/__init__.py index bbfd76878a7..0aead260bc4 100644 --- a/sgl-kernel/src/sgl-kernel/ops/__init__.py +++ b/sgl-kernel/src/sgl-kernel/ops/__init__.py @@ -2,8 +2,11 @@ import torch from sgl_kernel.ops._kernels import all_reduce as _all_reduce +from sgl_kernel.ops._kernels import bmm_fp8 as _bmm_fp8 from sgl_kernel.ops._kernels import dispose as _dispose from sgl_kernel.ops._kernels import fused_add_rmsnorm as _fused_add_rmsnorm +from sgl_kernel.ops._kernels import gelu_and_mul as _gelu_and_mul +from sgl_kernel.ops._kernels import gelu_tanh_and_mul as _gelu_tanh_and_mul from sgl_kernel.ops._kernels import gemma_fused_add_rmsnorm as _gemma_fused_add_rmsnorm from sgl_kernel.ops._kernels import gemma_rmsnorm as _gemma_rmsnorm from sgl_kernel.ops._kernels import ( @@ -11,6 +14,9 @@ ) from sgl_kernel.ops._kernels import init_custom_ar as _init_custom_ar from sgl_kernel.ops._kernels import int8_scaled_mm as _int8_scaled_mm +from sgl_kernel.ops._kernels import ( + lightning_attention_decode as _lightning_attention_decode, +) from sgl_kernel.ops._kernels import moe_align_block_size as _moe_align_block_size from sgl_kernel.ops._kernels import register_graph_buffers as _register_graph_buffers from sgl_kernel.ops._kernels import rmsnorm as _rmsnorm @@ -18,10 +24,8 @@ from sgl_kernel.ops._kernels import ( sampling_scaling_penalties as _sampling_scaling_penalties, ) - - -def get_cuda_stream(device: torch.device) -> int: - return torch.cuda.current_stream(device).cuda_stream +from sgl_kernel.ops._kernels import silu_and_mul as _silu_and_mul +from sgl_kernel.ops.utils import _get_cache_buf, _get_cuda_stream def init_custom_reduce( @@ -85,10 +89,16 @@ def int8_scaled_mm(mat_a, mat_b, scales_a, scales_b, out_dtype, bias=None): ) +def lightning_attention_decode(q, k, v, past_kv, slope, output, new_kv): + _lightning_attention_decode(q, k, v, past_kv, slope, output, new_kv) + + def rotary_embedding(positions, query, key, head_size, cos_sin_cache, is_neox): return _rotary_embedding(positions, query, key, head_size, cos_sin_cache, is_neox) +# 
These implementations extensively draw from and build upon the FlashInfer project https://github.com/flashinfer-ai/flashinfer +# Kudos to @yzh119 def rmsnorm( input: torch.Tensor, weight: torch.Tensor, @@ -98,7 +108,7 @@ def rmsnorm( with input.device as device: if out is None: out = torch.empty_like(input) - _rmsnorm(out, input, weight, eps, get_cuda_stream(device)) + _rmsnorm(out, input, weight, eps, _get_cuda_stream(device)) return out @@ -106,7 +116,7 @@ def fused_add_rmsnorm( input: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6 ) -> None: with input.device as device: - _fused_add_rmsnorm(input, residual, weight, eps, get_cuda_stream(device)) + _fused_add_rmsnorm(input, residual, weight, eps, _get_cuda_stream(device)) def gemma_rmsnorm( @@ -118,7 +128,7 @@ def gemma_rmsnorm( with input.device as device: if out is None: out = torch.empty_like(input) - _gemma_rmsnorm(out, input, weight, eps, get_cuda_stream(device)) + _gemma_rmsnorm(out, input, weight, eps, _get_cuda_stream(device)) return out @@ -126,4 +136,103 @@ def gemma_fused_add_rmsnorm( input: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6 ) -> None: with input.device as device: - _gemma_fused_add_rmsnorm(input, residual, weight, eps, get_cuda_stream(device)) + _gemma_fused_add_rmsnorm(input, residual, weight, eps, _get_cuda_stream(device)) + + +def _check_shape(input: torch.Tensor, output: torch.Tensor) -> None: + assert input.ndim == output.ndim, f"{input.ndim} != {output.ndim}" + assert ( + input.shape[:-1] == output.shape[:-1] + ), f"{input.shape[:-1]} != {output.shape[:-1]}" + assert ( + input.shape[-1] == 2 * output.shape[-1] + ), f"{input.shape[-1]} != {2 * output.shape[-1]}" + + +def silu_and_mul(input: torch.Tensor, out: torch.Tensor = None) -> torch.Tensor: + if input.shape[-1] * input.dtype.itemsize % 16 != 0: + raise ValueError("The pointers must be multiple of 16 bytes.") + if out is not None: + _check_shape(input, out) + else: + out = torch.empty( + input.shape[:-1] + (input.shape[-1] // 2,), + device=input.device, + dtype=input.dtype, + ) + with input.device as device: + _silu_and_mul(out, input, _get_cuda_stream(device)) + return out + + +def gelu_tanh_and_mul(input: torch.Tensor, out: torch.Tensor = None) -> torch.Tensor: + if input.shape[-1] * input.dtype.itemsize % 16 != 0: + raise ValueError("The pointers must be multiple of 16 bytes.") + if out is not None: + _check_shape(input, out) + else: + out = torch.empty( + input.shape[:-1] + (input.shape[-1] // 2,), + device=input.device, + dtype=input.dtype, + ) + with input.device as device: + _gelu_tanh_and_mul(out, input, _get_cuda_stream(device)) + return out + + +def gelu_and_mul(input: torch.Tensor, out: torch.Tensor = None) -> torch.Tensor: + if input.shape[-1] * input.dtype.itemsize % 16 != 0: + raise ValueError("The pointers must be multiple of 16 bytes.") + if out is not None: + _check_shape(input, out) + else: + out = torch.empty( + input.shape[:-1] + (input.shape[-1] // 2,), + device=input.device, + dtype=input.dtype, + ) + with input.device as device: + _gelu_and_mul(out, input, _get_cuda_stream(device)) + return out + + +def _bmm_fp8_internal( + workspace_buffer: torch.Tensor, + A: torch.Tensor, + B: torch.Tensor, + D: torch.Tensor, + A_scale: torch.Tensor, + B_scale: torch.Tensor, +) -> None: + with A.device as device: + cublas_handle = torch.cuda.current_blas_handle() + _bmm_fp8( + A, + B, + D, + A_scale, + B_scale, + workspace_buffer, + cublas_handle, + _get_cuda_stream(device), + ) + + 
+def bmm_fp8( + A: torch.Tensor, + B: torch.Tensor, + A_scale: torch.Tensor, + B_scale: torch.Tensor, + dtype: torch.dtype, + out: Optional[torch.Tensor] = None, +) -> torch.Tensor: + if out is None: + out = torch.empty( + (A.shape[0], A.shape[1], B.shape[2]), + device=A.device, + dtype=dtype, + ) + workspace_buffer = _get_cache_buf("bmm_fp8_workspace", 32 * 1024 * 1024, A.device) + _bmm_fp8_internal(workspace_buffer, A, B, out, A_scale, B_scale) + return out diff --git a/sgl-kernel/src/sgl-kernel/ops/utils.py b/sgl-kernel/src/sgl-kernel/ops/utils.py new file mode 100644 index 00000000000..af5fccbb786 --- /dev/null +++ b/sgl-kernel/src/sgl-kernel/ops/utils.py @@ -0,0 +1,19 @@ +from typing import Dict, Tuple + +import torch + + +def _get_cuda_stream(device: torch.device) -> int: + return torch.cuda.current_stream(device).cuda_stream + + +_cache_buf: Dict[Tuple[str, torch.device], torch.Tensor] = {} + + +def _get_cache_buf(name: str, bytes: int, device: torch.device) -> torch.Tensor: + key = (name, device) + buf = _cache_buf.get(key) + if buf is None: + buf = torch.empty(bytes, dtype=torch.uint8, device=device) + _cache_buf[key] = buf + return buf diff --git a/sgl-kernel/tests/test_activation.py b/sgl-kernel/tests/test_activation.py new file mode 100644 index 00000000000..f71f36b513d --- /dev/null +++ b/sgl-kernel/tests/test_activation.py @@ -0,0 +1,38 @@ +# Adapted from https://github.com/flashinfer-ai/flashinfer/blob/4e8eb1879f9c3ba6d75511e5893183bf8f289a62/tests/test_activation.py + +import pytest +import sgl_kernel +import torch + + +@pytest.mark.parametrize("dim", [128, 256, 512, 2048, 4096, 11008, 16384]) +@pytest.mark.parametrize("batch_size", [1, 2, 4, 8, 16]) +@pytest.mark.parametrize("seq_len", [1, 2, 4, 8, 16, 32, 64, 128, 512]) +def test_fused_silu_mul(dim, batch_size, seq_len): + x = torch.randn(batch_size, seq_len, 2 * dim).to(0).to(torch.float16) + y_ref = x[..., dim:] * torch.nn.functional.silu(x[..., :dim]) + y = sgl_kernel.silu_and_mul(x) + torch.testing.assert_close(y_ref, y, rtol=1e-3, atol=1e-3) + + +@pytest.mark.parametrize("dim", [128, 256, 512, 2048, 4096, 11008, 16384]) +@pytest.mark.parametrize("batch_size", [1, 2, 4, 8, 16]) +@pytest.mark.parametrize("seq_len", [1, 2, 4, 8, 16, 32, 64, 128, 512]) +def test_fused_gelu_tanh_mul(dim, batch_size, seq_len): + x = torch.randn(batch_size, seq_len, 2 * dim).to(0).to(torch.float16) + y_ref = x[..., dim:] * torch.nn.functional.gelu(x[..., :dim], approximate="tanh") + y = sgl_kernel.gelu_tanh_and_mul(x) + torch.testing.assert_close(y_ref, y, rtol=1e-3, atol=1e-3) + + +@pytest.mark.parametrize("dim", [128, 256, 512, 2048, 4096, 11008, 16384]) +@pytest.mark.parametrize("batch_size", [1, 2, 4, 8, 16]) +@pytest.mark.parametrize("seq_len", [1, 2, 4, 8, 16, 32, 64, 128, 512]) +def test_fused_gelu_mul(dim, batch_size, seq_len): + x = torch.randn(batch_size, seq_len, 2 * dim).to(0).to(torch.float16) + y_ref = x[..., dim:] * torch.nn.functional.gelu(x[..., :dim], approximate="none") + y = sgl_kernel.gelu_and_mul(x) + torch.testing.assert_close(y_ref, y, rtol=1e-3, atol=1e-3) + + +test_fused_silu_mul(128, 1, 1) diff --git a/sgl-kernel/tests/test_bmm_fp8.py b/sgl-kernel/tests/test_bmm_fp8.py new file mode 100644 index 00000000000..e0be92896f6 --- /dev/null +++ b/sgl-kernel/tests/test_bmm_fp8.py @@ -0,0 +1,43 @@ +# Adapted from https://github.com/flashinfer-ai/flashinfer/blob/4e8eb1879f9c3ba6d75511e5893183bf8f289a62/tests/test_bmm_fp8.py + +import pytest +import torch +import torch.nn.functional as F +from sgl_kernel import bmm_fp8 + + 
+def to_float8(x, dtype=torch.float8_e4m3fn): + finfo = torch.finfo(dtype) + min_val, max_val = x.aminmax() + amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12) + scale = finfo.max / amax + x_scl_sat = (x * scale).clamp(min=finfo.min, max=finfo.max) + return x_scl_sat.to(dtype), scale.float().reciprocal() + + +@pytest.mark.parametrize("input_dtype", [torch.float8_e4m3fn, torch.float8_e5m2]) +@pytest.mark.parametrize("mat2_dtype", [torch.float8_e4m3fn, torch.float8_e5m2]) +@pytest.mark.parametrize("res_dtype", [torch.bfloat16, torch.float16]) +def test_bmm_fp8(input_dtype, mat2_dtype, res_dtype): + if input_dtype == torch.float8_e5m2 and mat2_dtype == torch.float8_e5m2: + pytest.skip("Invalid combination: both input and mat2 are e5m2") + + input = torch.randn([16, 48, 64], device="cuda", dtype=torch.bfloat16) + input_fp8, input_inv_s = to_float8(input, dtype=input_dtype) + + # mat2 row major -> column major + mat2 = torch.randn([16, 80, 64], device="cuda", dtype=torch.bfloat16).transpose( + -2, -1 + ) + mat2_fp8, mat2_inv_s = to_float8(mat2, dtype=mat2_dtype) + + res = torch.empty([16, 48, 80], device="cuda", dtype=res_dtype) + bmm_fp8(input_fp8, mat2_fp8, input_inv_s, mat2_inv_s, res_dtype, res) + + reference = torch.bmm(input, mat2) + cos_sim = F.cosine_similarity(reference.reshape(-1), res.reshape(-1), dim=0) + assert cos_sim > 0.99 + + +if __name__ == "__main__": + pytest.main([__file__]) diff --git a/sgl-kernel/tests/test_lightning_attention_decode.py b/sgl-kernel/tests/test_lightning_attention_decode.py new file mode 100644 index 00000000000..74af78e27b5 --- /dev/null +++ b/sgl-kernel/tests/test_lightning_attention_decode.py @@ -0,0 +1,84 @@ +import pytest +import torch +from sgl_kernel import lightning_attention_decode + + +def naive_lightning_attention_decode(q, k, v, past_kv, slope): + """Naive implementation of lightning attention decode""" + original_dtype = q.dtype + ratio = torch.exp(-slope) # [h, 1, 1] + + kv = past_kv + b, h, n, d = q.shape + + output = [] + for i in range(n): + kv = ratio * kv.to(torch.float32) + torch.einsum( + "... n d, ... n e -> ... d e", + k[:, :, i : i + 1], + v[:, :, i : i + 1], + ) + qkv = torch.einsum( + "... n e, ... e d -> ... 
n d", + q[:, :, i : i + 1].to(torch.float32), + kv.to(torch.float32), + ) + output.append(qkv) + output = torch.concat(output, dim=-2) + + return output.to(original_dtype), kv + + +configs = [ + # (batch_size, num_heads, dim, embed_dim) + (1, 8, 64, 64), + (2, 8, 64, 64), + (1, 32, 32, 64), + (2, 32, 32, 64), + (4, 32, 64, 64), + (4, 32, 64, 64), + (16, 64, 96, 96), + (64, 64, 96, 96), +] + +dtypes = [torch.float32, torch.float16, torch.bfloat16] + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +@pytest.mark.parametrize("dtype", dtypes) +@pytest.mark.parametrize("batch_size,num_heads,dim,embed_dim", configs) +def test_lightning_attention_decode(dtype, batch_size, num_heads, dim, embed_dim): + device = torch.device("cuda") + + q = torch.randn(batch_size, num_heads, 1, dim, device=device, dtype=dtype) + k = torch.randn(batch_size, num_heads, 1, dim, device=device, dtype=dtype) + v = torch.randn(batch_size, num_heads, 1, embed_dim, device=device, dtype=dtype) + past_kv = torch.randn(batch_size, num_heads, dim, embed_dim, device=device) + slope = torch.randn(num_heads, 1, 1, device=device) + + ref_output, ref_new_kv = naive_lightning_attention_decode(q, k, v, past_kv, slope) + + output = torch.empty_like(ref_output) + new_kv = torch.empty_like(ref_new_kv) + lightning_attention_decode(q, k, v, past_kv, slope, output, new_kv) + + rtol = 1e-2 + atol = 1e-2 + + torch.testing.assert_close( + output, + ref_output, + rtol=rtol, + atol=atol, + msg=f"Output mismatch for batch_size={batch_size}, num_heads={num_heads}, " + f"dim={dim}, embed_dim={embed_dim}, dtype={dtype}", + ) + + torch.testing.assert_close( + new_kv, + ref_new_kv, + rtol=rtol, + atol=atol, + msg=f"New KV mismatch for batch_size={batch_size}, num_heads={num_heads}, " + f"dim={dim}, embed_dim={embed_dim}, dtype={dtype}", + ) diff --git a/sgl-router/src/router.rs b/sgl-router/src/router.rs index 5bbffc74ccf..a189ff9eb88 100644 --- a/sgl-router/src/router.rs +++ b/sgl-router/src/router.rs @@ -238,12 +238,12 @@ impl Router { loop { if start_time.elapsed() > Duration::from_secs(timeout_secs) { error!( - "Timeout {}s waiting for workers to become healthy", - timeout_secs + "Timeout {}s waiting for workers {:?} to become healthy. Please set --router-worker-startup-timeout-secs (sglang_router.launch_server) or --worker-startup-timeout-secs (sglang_worker.router) to a larger value", + timeout_secs, worker_urls ); return Err(format!( - "Timeout {}s waiting for workers to become healthy", - timeout_secs + "Timeout {}s waiting for workers {:?} to become healthy. Please set --router-worker-startup-timeout-secs (sglang_router.launch_server) or --worker-startup-timeout-secs (sglang_worker.router) to a larger value", + timeout_secs, worker_urls )); } @@ -644,11 +644,11 @@ impl Router { loop { if start_time.elapsed() > Duration::from_secs(timeout_secs) { error!( - "Timeout {}s waiting for worker {} to become healthy", + "Timeout {}s waiting for worker {} to become healthy. Please set --router-worker-startup-timeout-secs (sglang_router.launch_server) or --worker-startup-timeout-secs (sglang_worker.router) to a larger value", timeout_secs, worker_url ); return Err(format!( - "Timeout {}s waiting for worker {} to become healthy", + "Timeout {}s waiting for worker {} to become healthy. Please set --router-worker-startup-timeout-secs (sglang_router.launch_server) or --worker-startup-timeout-secs (sglang_worker.router) to a larger value", timeout_secs, worker_url )); }