diff --git a/docs/references/supported_models.md b/docs/references/supported_models.md
index 9dafc3d2a3d..0c9586c671d 100644
--- a/docs/references/supported_models.md
+++ b/docs/references/supported_models.md
@@ -24,7 +24,7 @@
- InternLM 2
- Exaone 3
- BaiChuan2
-- MiniCPM / MiniCPM 3
+- MiniCPM / MiniCPM 3 / MiniCPMV
- XVERSE / XVERSE MoE
- SmolLM
- GLM-4
diff --git a/python/sglang/lang/chat_template.py b/python/sglang/lang/chat_template.py
index 4a774c4fb6b..845e1e52dda 100644
--- a/python/sglang/lang/chat_template.py
+++ b/python/sglang/lang/chat_template.py
@@ -88,7 +88,6 @@ def get_chat_template_by_model_path(model_path):
)
)
-
register_chat_template(
ChatTemplate(
name="claude",
@@ -101,7 +100,6 @@ def get_chat_template_by_model_path(model_path):
)
)
-
register_chat_template(
ChatTemplate(
name="chatml",
@@ -116,7 +114,6 @@ def get_chat_template_by_model_path(model_path):
)
)
-
register_chat_template(
ChatTemplate(
name="chatml-llava",
@@ -132,7 +129,6 @@ def get_chat_template_by_model_path(model_path):
)
)
-
# There is default system prompt for qwen
# reference: https://modelscope.cn/models/qwen/Qwen2-72B-Instruct/file/view/master?fileName=tokenizer_config.json&status=1
# The chat template is: "{% for message in messages %}{% if loop.first and messages[0]['role'] != 'system' %}{{ '<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n' }}{% endif %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"
@@ -219,6 +215,21 @@ def get_chat_template_by_model_path(model_path):
)
)
+# https://huggingface.co/openbmb/MiniCPM-V-2_6
+register_chat_template(
+ ChatTemplate(
+ name="minicpmv",
+ default_system_prompt=None,
+ role_prefix_and_suffix={
+ "system": ("", " "),
+ "user": ("user:", " "),
+ "assistant": ("assistant:", ""),
+ },
+ stop_str=("<|im_end|>", "<|endoftext|>"),
+ image_token="(./)",
+ )
+)
+
# The difference between "llama-3-instruct-llava" and "llama-3-instruct" is that llava uses a different image_token.
register_chat_template(
ChatTemplate(
diff --git a/python/sglang/srt/configs/model_config.py b/python/sglang/srt/configs/model_config.py
index a2f9b82844e..1472b0f1694 100644
--- a/python/sglang/srt/configs/model_config.py
+++ b/python/sglang/srt/configs/model_config.py
@@ -393,6 +393,7 @@ def is_multimodal_model(model_architectures: List[str]):
or "LlavaVidForCausalLM" in model_architectures
or "MllamaForConditionalGeneration" in model_architectures
or "Qwen2VLForConditionalGeneration" in model_architectures
+ or "MiniCPMV" in model_architectures
):
return True
else:
diff --git a/python/sglang/srt/conversation.py b/python/sglang/srt/conversation.py
index 60dba87cb08..3a775aa1e95 100644
--- a/python/sglang/srt/conversation.py
+++ b/python/sglang/srt/conversation.py
@@ -452,7 +452,6 @@ def generate_chat_conv(
# Add a blank message for the assistant.
conv.append_message(conv.roles[1], None)
-
return conv
@@ -555,3 +554,17 @@ def generate_chat_conv(
image_token="<|vision_start|><|image_pad|><|vision_end|>",
)
)
+
+# Reference: https://huggingface.co/openbmb/MiniCPM-V-2_6#usage
+register_conv_template(
+ Conversation(
+ name="minicpmv",
+ system_message="You are a helpful assistant",
+ system_template="<|im_start|>system\n{system_message}.",
+ roles=("<|im_start|>user", "<|im_start|>assistant"),
+ sep="<|im_end|>\n",
+ sep_style=SeparatorStyle.ADD_NEW_LINE_SINGLE,
+ stop_str=("<|im_end|>", "<|endoftext|>"),
+ image_token="(./)",
+ )
+)
diff --git a/python/sglang/srt/layers/logits_processor.py b/python/sglang/srt/layers/logits_processor.py
index 7ca1d51a756..5008985e8fe 100644
--- a/python/sglang/srt/layers/logits_processor.py
+++ b/python/sglang/srt/layers/logits_processor.py
@@ -129,7 +129,7 @@ def forward(
hidden_states,
lm_head: VocabParallelEmbedding,
logits_metadata: Union[LogitsMetadata, ForwardBatch],
- ):
+ ) -> LogitsProcessorOutput:
if isinstance(logits_metadata, ForwardBatch):
logits_metadata = LogitsMetadata.from_forward_batch(logits_metadata)
diff --git a/python/sglang/srt/managers/image_processor.py b/python/sglang/srt/managers/image_processor.py
index 7120fa48d52..f47acb1aafd 100644
--- a/python/sglang/srt/managers/image_processor.py
+++ b/python/sglang/srt/managers/image_processor.py
@@ -9,6 +9,8 @@
import numpy as np
import transformers
+from decord import VideoReader, cpu
+from PIL import Image
from sglang.srt.hf_transformers_utils import get_processor
from sglang.srt.mm_utils import expand2square, process_anyres_image
@@ -36,6 +38,7 @@ class BaseImageProcessor(ABC):
def __init__(self, hf_config, server_args, _processor):
self.hf_config = hf_config
self._processor = _processor
+ self.server_args = server_args
self.executor = concurrent.futures.ProcessPoolExecutor(
initializer=init_global_processor,
@@ -229,6 +232,126 @@ async def process_images_async(
return image_inputs
+class MiniCPMVImageProcessor(BaseImageProcessor):
+ def __init__(self, hf_config, server_args, _processor):
+ super().__init__(hf_config, server_args, _processor)
+
+ @staticmethod
+ def _process_images_task(images, input_text):
+ result = global_processor.__call__(
+ text=input_text, images=images, return_tensors="pt"
+ )
+ return {
+ "input_ids": result["input_ids"],
+ "pixel_values": result["pixel_values"],
+ "tgt_sizes": result["tgt_sizes"],
+ "image_bound": result["image_bound"],
+ }
+
+ async def _process_images(self, images, input_text):
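+        # Offload the HF processor call to the process pool so that heavy image
+        # preprocessing does not block the event loop.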
+ if self.executor is not None:
+ loop = asyncio.get_event_loop()
+ image_inputs = await loop.run_in_executor(
+ self.executor,
+ MiniCPMVImageProcessor._process_images_task,
+ images,
+ input_text,
+ )
+ else:
+ image_inputs = self._processor(
+ images=images, text=input_text, return_tensors="pt"
+ )
+
+ return image_inputs
+
+ async def process_images_async(
+ self, image_data: List[Union[str, bytes]], input_text, request_obj
+ ):
+ if not image_data:
+ return None
+
+
+ if not isinstance(image_data, list):
+ image_data = [image_data]
+
+ image_hashes, image_sizes = [], []
+ raw_images = []
+ IMAGE_TOKEN = "(./)"
+
+ def encode_video(video_path):
+ if not os.path.exists(video_path):
+ return []
+            MAX_NUM_FRAMES = 3  # reduce this if CUDA runs out of memory
+
+            def uniform_sample(seq, n):
+                gap = len(seq) / n
+                idxs = [int(i * gap + gap / 2) for i in range(n)]
+                return [seq[i] for i in idxs]
+
+            vr = VideoReader(video_path, ctx=cpu(0))
+            # Sample roughly one frame per second, then cap at MAX_NUM_FRAMES below.
+            sample_fps = round(vr.get_avg_fps())
+            frame_idx = list(range(0, len(vr), sample_fps))
+ if len(frame_idx) > MAX_NUM_FRAMES:
+ frame_idx = uniform_sample(frame_idx, MAX_NUM_FRAMES)
+ frames = vr.get_batch(frame_idx).asnumpy()
+ frames = [Image.fromarray(v.astype("uint8")) for v in frames]
+ return frames
+
+ if isinstance(input_text, list):
+ assert len(input_text) and isinstance(input_text[0], int)
+ input_text = self._processor.tokenizer.decode(input_text)
+
+        # MiniCPM-V treats each sampled video frame as a separate image token
+ text_parts = input_text.split(IMAGE_TOKEN)
+ new_text_parts = []
+
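+        # Rebuild the prompt around the IMAGE_TOKEN placeholders so the number of
+        # placeholders matches the number of images / sampled frames passed to the
+        # processor.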
+ for image_index, image in enumerate(image_data):
+ if isinstance(image, str) and image.startswith("video:"):
+ path = image[len("video:") :]
+                frames: List[Image.Image] = encode_video(path)
+                image_sizes += [frames[0].size] * len(frames)
+ else:
+ raw_image, size = load_image(image)
+ image_sizes += [size]
+ frames = [raw_image]
+
+ image_hashes += [hash(image)]
+ raw_images += frames
+ new_text_parts.append(text_parts[image_index])
+ new_text_parts.append(IMAGE_TOKEN * len(frames))
+
+ new_text_parts.append(text_parts[-1])
+ input_text = "".join(new_text_parts)
+ res = await self._process_images(images=raw_images, input_text=input_text)
+ pixel_values = res["pixel_values"]
+ tgt_sizes = res["tgt_sizes"]
+ image_bound = res["image_bound"]
+ input_ids = res["input_ids"]
+
+ # Collect special token ids
+ tokenizer = self._processor.tokenizer
+ im_start_id = [tokenizer.im_start_id]
+ im_end_id = [tokenizer.im_end_id]
+        slice_start_id, slice_end_id = None, None
+        if tokenizer.slice_start_id:
+            slice_start_id = [tokenizer.slice_start_id]
+            slice_end_id = [tokenizer.slice_end_id]
+
+ return {
+ "input_ids": input_ids.flatten().tolist(),
+ "pixel_values": pixel_values,
+ "tgt_sizes": tgt_sizes,
+ "image_hashes": image_hashes,
+ "modalities": request_obj.modalities or ["image"],
+ "im_start_id": im_start_id,
+ "im_end_id": im_end_id,
+ "slice_start_id": slice_start_id,
+ "slice_end_id": slice_end_id,
+ "image_bound": image_bound,
+ }
+
+
class Qwen2VLImageProcessor(BaseImageProcessor):
def __init__(self, hf_config, server_args, _image_processor):
self.hf_config = hf_config
@@ -350,6 +473,8 @@ def get_image_processor(
return MllamaImageProcessor(hf_config, server_args, processor)
elif "Qwen2VLForConditionalGeneration" in hf_config.architectures:
return Qwen2VLImageProcessor(hf_config, server_args, processor.image_processor)
+ elif "MiniCPMV" in hf_config.architectures:
+ return MiniCPMVImageProcessor(hf_config, server_args, processor)
else:
return LlavaImageProcessor(hf_config, server_args, processor.image_processor)
diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py
index 3b056cc5d49..9877f338cb7 100644
--- a/python/sglang/srt/managers/schedule_batch.py
+++ b/python/sglang/srt/managers/schedule_batch.py
@@ -52,7 +52,6 @@
if TYPE_CHECKING:
from sglang.srt.speculative.spec_info import SpecInfo, SpeculativeAlgorithm
-
INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5
# Put some global args for easy access
@@ -67,7 +66,6 @@
"enable_ep_moe": ServerArgs.enable_ep_moe,
}
-
logger = logging.getLogger(__name__)
@@ -148,6 +146,17 @@ class ImageInputs:
image_grid_thws: List[Tuple[int, int, int]] = None
mrope_position_delta: Optional[torch.Tensor] = None
+ # MiniCPMV related
+ # All the images in the batch should share the same special image
+ # bound token ids.
+ im_start_id: Optional[torch.Tensor] = None
+ im_end_id: Optional[torch.Tensor] = None
+ slice_start_id: Optional[torch.Tensor] = None
+ slice_end_id: Optional[torch.Tensor] = None
+ image_bound: Optional[torch.Tensor] = None
+
+ tgt_sizes: Optional[list] = None
+
@staticmethod
def from_dict(obj: dict):
ret = ImageInputs(
@@ -167,6 +176,12 @@ def from_dict(obj: dict):
"aspect_ratio_ids",
"aspect_ratio_mask",
"image_grid_thws",
+ "im_start_id",
+ "im_end_id",
+ "slice_start_id",
+ "slice_end_id",
+ "tgt_sizes",
+ "image_bound",
]
for arg in optional_args:
if arg in obj:
@@ -1136,7 +1151,6 @@ def get_model_worker_batch(self):
global bid
bid += 1
-
return ModelWorkerBatch(
bid=bid,
forward_mode=self.forward_mode,
diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py
index 719db19cd76..e837ea7113f 100644
--- a/python/sglang/srt/model_executor/model_runner.py
+++ b/python/sglang/srt/model_executor/model_runner.py
@@ -220,7 +220,7 @@ def init_torch_distributed(self):
set_custom_all_reduce(not self.server_args.disable_custom_all_reduce)
if not self.is_draft_worker:
- # Only initilzie the distributed environment on the target model worker.
+ # Only initialize the distributed environment on the target model worker.
init_distributed_environment(
backend=backend,
world_size=self.tp_size,
diff --git a/python/sglang/srt/models/llava.py b/python/sglang/srt/models/llava.py
index c8ce9302b4f..867a470b72c 100644
--- a/python/sglang/srt/models/llava.py
+++ b/python/sglang/srt/models/llava.py
@@ -403,7 +403,7 @@ def forward(
pt += 1
return self.language_model(
- input_ids, positions, forward_batch, input_embeds=input_embeds
+ input_ids, positions, forward_batch, inputs_embeds=input_embeds
)
elif forward_batch.forward_mode.is_decode():
return self.language_model(input_ids, positions, forward_batch)
diff --git a/python/sglang/srt/models/minicpmv.py b/python/sglang/srt/models/minicpmv.py
new file mode 100644
index 00000000000..60b803c119d
--- /dev/null
+++ b/python/sglang/srt/models/minicpmv.py
@@ -0,0 +1,1322 @@
+# Adapted from
+# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
+# Copyright 2023 The vLLM team.
+# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
+#
+# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
+# and OPT implementations in this library. It has been modified from its
+# original forms to accommodate minor architectural differences compared
+# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Inference-only MiniCPM-V model compatible with HuggingFace weights."""
+from functools import cached_property, partial
+from typing import (
+ Any,
+ Callable,
+ Iterable,
+ List,
+ Literal,
+ Optional,
+ Tuple,
+ TypedDict,
+ Union,
+)
+
+import torch
+import torch.types
+from PIL import Image
+from torch import nn
+from torch.nn.init import trunc_normal_
+from transformers import PretrainedConfig
+from transformers.models.idefics2.configuration_idefics2 import (
+ Idefics2Config,
+ Idefics2VisionConfig,
+)
+from vllm.distributed import divide, get_tensor_model_parallel_world_size
+from vllm.model_executor.layers.resampler import get_2d_sincos_pos_embed
+from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
+from vllm.model_executor.model_loader.utils import set_default_torch_dtype
+from vllm.model_executor.models.module_mapping import MultiModelKeys
+from vllm.model_executor.sampling_metadata import SamplingMetadata
+
+from sglang.srt.layers.activation import get_act_fn
+from sglang.srt.layers.linear import (
+ ColumnParallelLinear,
+ QKVParallelLinear,
+ ReplicatedLinear,
+ RowParallelLinear,
+)
+from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.quantization.base_config import QuantizationConfig
+from sglang.srt.managers.schedule_batch import ImageInputs
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_loader.utils import set_default_torch_dtype
+from sglang.srt.model_loader.weight_utils import default_weight_loader
+from sglang.srt.models.qwen2 import Qwen2Config, Qwen2ForCausalLM
+
+RawImageType = Union[Image.Image, torch.Tensor]
+
+
+class MultiHeadAttention(nn.Module):
+ """Multi-headed attention without any cache, used for ViT."""
+
+ def __init__(
+ self,
+ num_heads: int,
+ head_size: int,
+ scale: float,
+ num_kv_heads: Optional[int] = None,
+ ):
+ super().__init__()
+ self.num_heads = num_heads
+ self.head_size = head_size
+ self.scale = scale
+ self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
+
+ def forward(
+ self,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ forward_batch: ForwardBatch,
+ ) -> torch.Tensor:
+ """Input shape: batch_size x seq_len x hidden_size"""
+ # TODO(Isotr0py): Use existing backend implementations and support FA2
+ bsz, q_len, _ = query.size()
+ kv_len = key.size(1)
+
+        # No KV cache is needed for the ViT, so use torch's fused SDPA kernel
+        # directly; it expects (batch, heads, seq, head_dim).
+        query = query.view(bsz, q_len, self.num_heads, self.head_size).transpose(1, 2)
+        key = key.view(bsz, kv_len, self.num_kv_heads, self.head_size).transpose(1, 2)
+        value = value.view(bsz, kv_len, self.num_kv_heads, self.head_size).transpose(1, 2)
+
+        out = torch.nn.functional.scaled_dot_product_attention(
+            query, key, value, scale=self.scale
+        )
+        return out.transpose(1, 2).reshape(bsz, q_len, -1)
+
+
+class Idefics2VisionAttention(nn.Module):
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+ def __init__(
+ self,
+ config: Idefics2Config,
+ layer_id: int,
+ quant_config: Optional[QuantizationConfig] = None,
+ prefix: str = "",
+ ) -> None:
+ super().__init__()
+ self.config = config
+ self.embed_dim = config.hidden_size
+ self.num_heads = config.num_attention_heads
+ self.head_dim = self.embed_dim // self.num_heads
+ if self.head_dim * self.num_heads != self.embed_dim:
+ raise ValueError(
+ f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:" # noqa: E501
+ f" {self.num_heads})."
+ )
+ self.scale = self.head_dim**-0.5
+ self.dropout = config.attention_dropout
+ self.qkv_proj = QKVParallelLinear(
+ self.embed_dim,
+ self.head_dim,
+ self.num_heads,
+ quant_config=quant_config,
+ prefix=f"{prefix}.qkv_proj",
+ )
+ self.out_proj = RowParallelLinear(
+ self.embed_dim,
+ self.embed_dim,
+ bias=True,
+ quant_config=quant_config,
+ prefix=f"{prefix}.out_proj",
+ )
+ self.tp_size = get_tensor_model_parallel_world_size()
+ self.num_heads_per_partition = divide(self.num_heads, self.tp_size)
+        # TODO: does forward_batch support two attention modules at the same time?
+ # self.attn = RadixAttention(num_heads=self.num_heads_per_partition,
+ # head_dim=self.head_dim,
+ # scaling=self.scale,
+ # num_kv_heads=self.num_heads,
+ # layer_id=layer_id
+ # )
+ self.attn = MultiHeadAttention(
+ self.num_heads_per_partition, self.head_dim, self.scale
+ )
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ forward_batch: ForwardBatch,
+ ) -> torch.Tensor:
+ qkv, _ = self.qkv_proj(
+ hidden_states
+ ) # batch_size, q_len, 3 * num_heads_per_partition * head_dim
+ query_states, key_states, value_states = qkv.chunk(3, dim=-1)
+ out = self.attn(
+ query_states, key_states, value_states, forward_batch=forward_batch
+ )
+ attn_output, _ = self.out_proj(out)
+ return attn_output
+
+
+class Idefics2VisionMLP(nn.Module):
+
+ def __init__(
+ self,
+ config: Idefics2Config,
+ quant_config: Optional[QuantizationConfig] = None,
+ prefix: str = "",
+ ) -> None:
+ super().__init__()
+ self.config = config
+ self.activation_fn = get_act_fn(config.hidden_act)
+ self.fc1 = ColumnParallelLinear(
+ config.hidden_size,
+ config.intermediate_size,
+ bias=True,
+ quant_config=quant_config,
+ prefix=f"{prefix}.fc1",
+ )
+ self.fc2 = RowParallelLinear(
+ config.intermediate_size,
+ config.hidden_size,
+ bias=True,
+ quant_config=quant_config,
+ prefix=f"{prefix}.fc2",
+ )
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ hidden_states, _ = self.fc1(hidden_states)
+ hidden_states = self.activation_fn(hidden_states)
+ hidden_states, _ = self.fc2(hidden_states)
+ return hidden_states
+
+
+class Idefics2EncoderLayer(nn.Module):
+
+ def __init__(
+ self,
+ config: Idefics2Config,
+ layer_id: int,
+ quant_config: Optional[QuantizationConfig] = None,
+ prefix: str = "",
+ ) -> None:
+ super().__init__()
+ self.embed_dim = config.hidden_size
+ self.self_attn = Idefics2VisionAttention(
+ config,
+ quant_config=quant_config,
+ layer_id=layer_id,
+ prefix=f"{prefix}.self_attn",
+ )
+ self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
+ self.mlp = Idefics2VisionMLP(
+ config, quant_config=quant_config, prefix=f"{prefix}.mlp"
+ )
+ self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
+
+ def forward(
+ self, hidden_states: torch.Tensor, forward_batch: ForwardBatch
+ ) -> torch.Tensor:
+ """
+ Args:
+ hidden_states (`torch.FloatTensor`):
+ Input to the layer of shape `(batch, seq_len, embed_dim)`.
+
+ """
+ residual = hidden_states
+ hidden_states = self.layer_norm1(hidden_states)
+ hidden_states = self.self_attn(hidden_states, forward_batch=forward_batch)
+ hidden_states = residual + hidden_states
+ residual = hidden_states
+ hidden_states = self.layer_norm2(hidden_states)
+ hidden_states = self.mlp(hidden_states)
+ hidden_states = residual + hidden_states
+ return hidden_states
+
+
+class Idefics2Encoder(nn.Module):
+ """
+ Transformer encoder consisting of `config.num_hidden_layers` self attention
+ layers. Each layer is a
+ [`Idefics2EncoderLayer`].
+
+ Args:
+ config: Idefics2Config
+ """
+
+ def __init__(
+ self,
+ config: Idefics2Config,
+ quant_config: Optional[QuantizationConfig] = None,
+ prefix: str = "",
+ ) -> None:
+ super().__init__()
+
+ self.config = config
+ self.layers = nn.ModuleList(
+ [
+ Idefics2EncoderLayer(
+ config,
+ quant_config=quant_config,
+ layer_id=layer_id,
+ prefix=f"{prefix}.layers.{layer_id}",
+ )
+ for layer_id in range(config.num_hidden_layers)
+ ]
+ )
+
+ def forward(
+ self,
+ inputs_embeds: torch.Tensor,
+ forward_batch: ForwardBatch,
+ ) -> torch.Tensor:
+ r"""
+ Args:
+ inputs_embeds (torch.Tensor):
+ Optionally, instead of passing `input_ids` you can choose to
+ directly pass an embedded representation.
+ This is useful if you want more control over how to convert
+                `input_ids` indices into associated vectors than the model's
+ internal embedding lookup matrix.
+ """
+ hidden_states = inputs_embeds
+ for encoder_layer in self.layers:
+ layer_outputs = encoder_layer(hidden_states, forward_batch=forward_batch)
+ hidden_states = layer_outputs
+ return hidden_states
+
+
+class Idefics2VisionEmbeddings(nn.Module):
+ """
+    This is a modified version of `siglip.modeling_siglip.SiglipVisionEmbeddings`
+    to enable images of variable resolution.
+
+ The modifications are adapted from [Patch n' Pack: NaViT, a Vision
+ Transformer for any Aspect Ratio and Resolution](https://arxiv.org/abs/2307.06304)
+ which allows treating images in their native aspect ratio and without the
+ need to resize them to the same fixed size. In particular, we start from the
+    original pre-trained SigLIP model (which uses fixed-size square images)
+    and adapt it by training on images of variable resolutions.
+ """
+
+ def __init__(self, config: Idefics2VisionConfig):
+ super().__init__()
+ self.embed_dim = config.hidden_size
+ self.image_size = config.image_size
+ self.patch_size = config.patch_size
+ self.patch_embedding = nn.Conv2d(
+ in_channels=config.num_channels,
+ out_channels=self.embed_dim,
+ kernel_size=self.patch_size,
+ stride=self.patch_size,
+ padding="valid",
+ )
+ self.num_patches_per_side = self.image_size // self.patch_size
+ self.num_patches = self.num_patches_per_side**2
+ self.num_positions = self.num_patches
+ self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
+
+ def forward(
+ self,
+ pixel_values: torch.FloatTensor,
+ patch_attention_mask: torch.BoolTensor,
+ forward_batch: ForwardBatch,
+ tgt_sizes: Optional[torch.IntTensor] = None,
+ ) -> torch.Tensor:
+ batch_size, _, max_im_h, max_im_w = pixel_values.shape
+ target_dtype = self.patch_embedding.weight.dtype
+ pixel_values = pixel_values.to(
+ device=self.patch_embedding.weight.device, dtype=target_dtype
+ )
+ patch_embeds = self.patch_embedding(pixel_values)
+ embeddings = patch_embeds.flatten(2).transpose(1, 2)
+ max_nb_patches_h, max_nb_patches_w = (
+ max_im_h // self.patch_size,
+ max_im_w // self.patch_size,
+ )
+ boundaries = torch.arange(
+ 1 / self.num_patches_per_side, 1.0, 1 / self.num_patches_per_side
+ )
+ position_ids = torch.full(
+ size=(batch_size, max_nb_patches_h * max_nb_patches_w), fill_value=0
+ )
+
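+        # For each image, map its (possibly non-square) patch grid onto the fixed
+        # num_patches_per_side x num_patches_per_side position table by bucketizing
+        # fractional patch coordinates.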
+ for batch_idx, p_attn_mask in enumerate(patch_attention_mask):
+
+ if tgt_sizes is not None:
+ nb_patches_h = tgt_sizes[batch_idx][0]
+ nb_patches_w = tgt_sizes[batch_idx][1]
+ else:
+ nb_patches_h = p_attn_mask[:, 0].sum()
+ nb_patches_w = p_attn_mask[0].sum()
+ fractional_coords_h = torch.arange(0, 1 - 1e-6, 1 / nb_patches_h)
+ fractional_coords_w = torch.arange(0, 1 - 1e-6, 1 / nb_patches_w)
+ bucket_coords_h = torch.bucketize(
+ fractional_coords_h, boundaries, right=True
+ )
+ bucket_coords_w = torch.bucketize(
+ fractional_coords_w, boundaries, right=True
+ )
+ pos_ids = (
+ bucket_coords_h[:, None] * self.num_patches_per_side + bucket_coords_w
+ ).flatten()
+ position_ids[batch_idx][p_attn_mask.view(-1).cpu()] = pos_ids
+ position_ids = position_ids.to(self.position_embedding.weight.device)
+ embeddings = embeddings + self.position_embedding(position_ids)
+ return embeddings
+
+
+class Idefics2VisionTransformer(nn.Module):
+
+ def __init__(
+ self,
+ config: Idefics2VisionConfig,
+ quant_config: Optional[QuantizationConfig] = None,
+ prefix: str = "",
+ ) -> None:
+ super().__init__()
+
+ embed_dim = config.hidden_size
+ self.config = config
+ self.embeddings = Idefics2VisionEmbeddings(config)
+ self.encoder = Idefics2Encoder(
+ config=config, quant_config=quant_config, prefix=f"{prefix}.encoder"
+ )
+ self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
+
+ def get_input_embeddings(self):
+ return self.embeddings
+
+ def forward(
+ self,
+ pixel_values,
+ forward_batch: ForwardBatch,
+ patch_attention_mask: Optional[torch.BoolTensor] = None,
+ tgt_sizes: Optional[torch.IntTensor] = None,
+ ) -> torch.Tensor:
+ hidden_states = self.embeddings(
+ pixel_values=pixel_values,
+ patch_attention_mask=patch_attention_mask,
+ forward_batch=forward_batch,
+ tgt_sizes=tgt_sizes,
+ )
+ encoder_outputs = self.encoder(hidden_states, forward_batch=forward_batch)
+ last_hidden_state = self.post_layernorm(encoder_outputs)
+ return last_hidden_state
+
+
+class MiniCPMVImagePixelInputs(TypedDict):
+ type: Literal["pixel_values"]
+ data: List[torch.Tensor]
+ """
+ Shape: `(batch_size * num_images, num_channels, height, width)`
+
+ Note that the image size may vary, so we pass it as a list
+ instead of a batched tensor.
+ """
+
+ image_bounds: torch.Tensor
+ """
+ Shape: `(batch_size * num_images, 2)`
+
+ This should be in `(start, stop)` format.
+ """
+
+ tgt_sizes: torch.Tensor
+ """
+ Shape: `(batch_size * num_images, 2)`
+
+ This should be in `(height, width)` format.
+ """
+
+
+class MiniCPMVImageEmbeddingInputs(TypedDict):
+ type: Literal["image_embeds"]
+ data: torch.Tensor
+ """
+ Shape: `(batch_size * num_images, image_feature_size, hidden_size)`
+
+    `hidden_size` must match the hidden size of the language model backbone.
+ """
+
+ image_bounds: torch.Tensor
+ """
+ Shape: `(batch_size * num_images, 2)`
+
+ This should be in `(start, stop)` format.
+ """
+
+
+MiniCPMVImageInputs = Union[MiniCPMVImagePixelInputs, MiniCPMVImageEmbeddingInputs]
+
+DEFAULT_LN = partial(nn.LayerNorm, eps=1e-6)
+
+
+class BaseResampler(nn.Module):
+ """
+    A 2D perceiver-resampler network with a single cross-attention layer,
+    (grid_size**2) learnable queries, and 2D sincos positional embeddings.
+ Outputs:
+ A tensor with the shape of (grid_size**2, embed_dim)
+ """
+
+ def __init__(
+ self,
+ num_queries: int,
+ embed_dim: int,
+ num_heads: int,
+ kv_dim: Optional[int] = None,
+ norm_layer: Callable[[int], nn.LayerNorm] = DEFAULT_LN,
+ do_post_projection: bool = True,
+ quant_config: Optional[QuantizationConfig] = None,
+ prefix: str = "",
+ ) -> None:
+ super().__init__()
+
+ self.num_queries = num_queries
+ self.embed_dim = embed_dim
+ self.num_heads = num_heads
+
+ self.query = nn.Parameter(torch.zeros(self.num_queries, embed_dim))
+ trunc_normal_(self.query, std=0.02)
+ if kv_dim is not None and kv_dim != embed_dim:
+ self.kv_proj = ReplicatedLinear(
+ kv_dim,
+ embed_dim,
+ bias=False,
+ quant_config=quant_config,
+ prefix=f"{prefix}.kv_proj",
+ )
+ else:
+ # Maintain the same return value with ReplicatedLinear.forward
+ self.kv_proj = lambda *args, **kwargs: ( # type: ignore # noqa
+ nn.Identity()(*args, **kwargs),
+ None,
+ )
+ self.attn = nn.MultiheadAttention(embed_dim, num_heads)
+ self.ln_q = norm_layer(embed_dim)
+ self.ln_kv = norm_layer(embed_dim)
+ self.do_post_projection = do_post_projection
+ self.ln_post = norm_layer(embed_dim) if do_post_projection else None
+ self.proj = (
+ nn.Parameter((embed_dim**-0.5) * torch.randn(embed_dim, embed_dim))
+ if do_post_projection
+ else None
+ )
+
+ def _init_weights(self, m: nn.Module) -> None:
+ if isinstance(m, nn.Linear):
+ trunc_normal_(m.weight, std=0.02)
+ if isinstance(m, nn.Linear) and m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+ elif isinstance(m, nn.LayerNorm):
+ nn.init.constant_(m.bias, 0)
+ nn.init.constant_(m.weight, 1.0)
+
+ def _repeat(self, query, N: int):
+ return query.unsqueeze(1).repeat(1, N, 1)
+
+
+class Resampler2_5(BaseResampler):
+
+ def __init__(
+ self,
+ num_queries: int,
+ embed_dim: int,
+ num_heads: int,
+ kv_dim: Optional[int] = None,
+ norm_layer: Callable[[int], nn.LayerNorm] = DEFAULT_LN,
+ max_size: Tuple[int, int] = (70, 70),
+ quant_config: Optional[QuantizationConfig] = None,
+ prefix: str = "",
+ ) -> None:
+ super().__init__(
+ num_queries,
+ embed_dim,
+ num_heads,
+ kv_dim,
+ norm_layer,
+ quant_config=quant_config,
+ prefix=prefix,
+ )
+
+ self.max_size = max_size
+ self._set_2d_pos_cache(self.max_size)
+
+ self.apply(self._init_weights)
+
+ def _set_2d_pos_cache(
+ self, max_size: Tuple[int, int], device: torch.types.Device = "cpu"
+ ) -> None:
+ pos_embed_arr = get_2d_sincos_pos_embed(
+ self.embed_dim, max_size, version=(2, 5)
+ )
+ pos_embed = torch.from_numpy(pos_embed_arr).float().to(device)
+ self.register_buffer("pos_embed", pos_embed, persistent=False)
+
+ def _adjust_pos_cache(
+ self, tgt_sizes: torch.Tensor, device: torch.types.Device
+ ) -> None:
+ max_h = tgt_sizes[:, 0].max().item()
+ max_w = tgt_sizes[:, 1].max().item()
+ assert isinstance(max_h, int) and isinstance(max_w, int)
+
+ if max_h > self.max_size[0] or max_w > self.max_size[1]:
+ self.max_size = (
+ max(max_h, self.max_size[0]),
+ max(max_w, self.max_size[1]),
+ )
+ self._set_2d_pos_cache(self.max_size, device)
+
+ def forward(self, x: torch.Tensor, tgt_sizes: torch.Tensor) -> torch.Tensor:
+ assert x.shape[0] == tgt_sizes.shape[0]
+ bs = x.shape[0]
+
+ device = x.device
+ dtype = x.dtype
+
+ patch_len = tgt_sizes[:, 0] * tgt_sizes[:, 1]
+
+ self._adjust_pos_cache(tgt_sizes, device=device)
+
+ max_patch_len = patch_len.max().item()
+ assert isinstance(max_patch_len, int)
+
+ key_padding_mask = torch.zeros(
+ (bs, max_patch_len), dtype=torch.bool, device=device
+ )
+
+ pos_embed = []
+ for i in range(bs):
+ tgt_h, tgt_w = tgt_sizes[i].tolist()
+ pos_embed.append(
+ self.pos_embed[:tgt_h, :tgt_w, :].reshape((tgt_h * tgt_w, -1)).to(dtype)
+ ) # patches * D
+ key_padding_mask[i, patch_len[i] :] = True
+ pos_embed = torch.nn.utils.rnn.pad_sequence(
+ pos_embed, batch_first=True, padding_value=0.0
+ ).permute(
+ 1, 0, 2
+ ) # BLD => L * B * D
+ x, _ = self.kv_proj(x) # B * L * D
+ x = self.ln_kv(x).permute(1, 0, 2) # L * B * D
+
+ q = self.ln_q(self.query) # Q * D
+
+ out = self.attn(
+ self._repeat(q, bs), # Q * B * D
+ x + pos_embed, # L * B * D + L * B * D
+ x,
+ key_padding_mask=key_padding_mask,
+ )[0]
+ # out: Q * B * D
+ x = out.permute(1, 0, 2) # B * Q * D
+
+ x = self.ln_post(x)
+ x = x @ self.proj
+ return x
+
+
+def get_version_by_config(config: PretrainedConfig) -> Tuple[int, ...]:
+ version_float = getattr(config, "version", None)
+
+    # The old configs do not include a version number
+ # TODO: Remove this after the HF repos are updated
+ if version_float is None:
+ if config.hidden_size == 2304 and config.query_num == 64:
+ return 2, 0
+ return 2, 5
+
+ version_str = str(version_float)
+ return tuple(int(x) for x in version_str.split("."))
+
+
+class MiniCPMVBaseModel(nn.Module):
+ """
+    Abstract base class for MiniCPM-V models; it is meant to be subclassed and
+    cannot be instantiated directly.
+ """
+
+ def __init__(
+ self,
+ *,
+ config: PretrainedConfig,
+ quant_config: Optional[QuantizationConfig] = None,
+ ):
+        super().__init__()
+        # All MiniCPM-V models disable `tie_word_embeddings`, but
+        # `PretrainedConfig.tie_word_embeddings` defaults to True; this cannot be
+        # checked until the MiniCPM-V model and config class are integrated upstream.
+        self.config = config
+
+ self.version = get_version_by_config(self.config)
+ self.llm = self.init_llm(config=config, quant_config=quant_config)
+ self.vpm = self.init_vision_module(config, quant_config)
+ self.vision_dim = (
+ self.vpm.embed_dim
+ if self.version == (2, 0)
+ else self.vpm.embeddings.embed_dim
+ )
+ self.embed_dim = self.config.hidden_size
+
+ self.resampler = self.init_resampler(
+ self.embed_dim, self.vision_dim, quant_config=quant_config
+ )
+
+ self.logits_processor = LogitsProcessor(config)
+
+ @cached_property
+ def sampler(self):
+ if hasattr(self.llm, "sampler"):
+ return self.llm.sampler
+
+ return get_sampler()
+
+ def _get_image_bounds(
+ self,
+ input_ids: torch.Tensor,
+ im_start_id: torch.Tensor,
+ im_end_id: torch.Tensor,
+ slice_start_id: Optional[torch.Tensor] = None,
+ slice_end_id: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ """
+        Returns a tensor of (start, end) token positions for each image in the input.
+ """
+ # All the images in the batch should share the same special image
+ # bound token ids.
+ start_cond = input_ids == im_start_id[0]
+ end_cond = input_ids == im_end_id[0]
+ if slice_start_id is not None:
+ start_cond |= input_ids == slice_start_id[0]
+ end_cond |= input_ids == slice_end_id[0]
+
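+        # Each returned row is (first_placeholder_pos, im_end_pos): start positions
+        # are shifted one past the <im_start>/<slice_start> token, so that
+        # torch.arange(start, end) covers exactly the image placeholder tokens.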
+ (image_start_tokens,) = torch.where(start_cond)
+ image_start_tokens += 1
+ (image_end_tokens,) = torch.where(end_cond)
+ valid_image_nums = min(len(image_start_tokens), len(image_end_tokens))
+
+ if valid_image_nums == 0:
+ return torch.zeros((0, 2), device=input_ids.device)
+
+ # Filter out pairs where start_token >= end_token
+ valid_pairs = []
+ for i in range(valid_image_nums):
+ start_token = image_start_tokens[i]
+ end_token = image_end_tokens[i]
+ if start_token < end_token:
+ valid_pairs.append((start_token, end_token))
+
+ if not valid_pairs:
+ return torch.zeros((0, 2), device=input_ids.device)
+
+ # Convert valid pairs to tensor
+ valid_pairs_tensor = torch.tensor(valid_pairs, device=input_ids.device)
+ return valid_pairs_tensor
+
+ def get_embedding(
+ self,
+ input_ids: torch.Tensor,
+ image_inputs: Optional[MiniCPMVImageInputs],
+ forward_batch: ForwardBatch,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ vlm_embedding: torch.Tensor = self.llm.get_input_embeddings(input_ids)
+
+ if image_inputs is None: # No image
+ vision_hidden_states = torch.tensor([], device=input_ids.device)
+ else:
+ if image_inputs["type"] == "image_embeds":
+ vision_hidden_states = (
+ image_inputs["data"]
+ .type(vlm_embedding.dtype)
+ .to(vlm_embedding.device)
+ )
+ else:
+ vision_hidden_states = self.get_vision_hidden_states(
+ forward_batch, image_inputs
+ )
+
+ # See NOTE in _parse_and_validate_inputs
+ image_bounds = image_inputs["image_bounds"]
+ if len(image_bounds) > 0:
+ image_indices = torch.stack(
+ [
+ torch.arange(start, end, dtype=torch.long)
+ for start, end in image_bounds.tolist()
+ ]
+ ).to(vlm_embedding.device)
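+                # Overwrite the placeholder rows of the text embedding with the
+                # resampled vision features, one row per image placeholder token.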
+ vlm_embedding.scatter_(
+ 0,
+ image_indices.view(-1, 1).repeat(1, vlm_embedding.shape[-1]),
+ vision_hidden_states.view(-1, vision_hidden_states.shape[-1]),
+ )
+
+ return vlm_embedding, vision_hidden_states
+
+ def _parse_and_validate_inputs(
+ self,
+ input_ids: torch.Tensor,
+ **kwargs: object,
+ ) -> Optional[MiniCPMVImageInputs]:
+ pixel_values = kwargs.pop("pixel_values", [])
+ tgt_sizes = kwargs.pop("tgt_sizes", [])
+ im_start_id = kwargs.pop("im_start_id", None)
+ im_end_id = kwargs.pop("im_end_id", None)
+ slice_start_id = kwargs.pop("slice_start_id", None)
+ slice_end_id = kwargs.pop("slice_end_id", None)
+ image_embeds = kwargs.pop("image_embeds", None)
+ image_bound = kwargs.pop("image_bound", None)
+
+ if isinstance(image_bound, list):
+ image_bound = torch.concat(image_bound)
+
+ if image_embeds is not None:
+ if not isinstance(image_embeds, (torch.Tensor, list)):
+ raise ValueError(
+ f"Incorrect type of image embeds. "
+ f"Got type: {type(image_embeds)}"
+ )
+
+ if isinstance(image_embeds, list):
+ image_embeds = torch.concat(image_embeds)
+
+ return MiniCPMVImageEmbeddingInputs(
+ image_bounds=image_bound,
+ data=image_embeds,
+ type="image_embeds",
+ )
+
+ if not isinstance(pixel_values, (torch.Tensor, list)):
+ raise ValueError(
+ "Incorrect type of pixel values. " f"Got type: {type(pixel_values)}"
+ )
+
+ if not isinstance(tgt_sizes, (torch.Tensor, list)):
+ raise ValueError(
+ "Incorrect type of target sizes. " f"Got type: {type(tgt_sizes)}"
+ )
+
+ if len(pixel_values) != len(tgt_sizes):
+ raise ValueError(
+ "Inconsistent batch lengths, found: "
+ f"{len(pixel_values)} vs. {len(tgt_sizes)}"
+ )
+
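+        # Flatten the nested per-request, per-image structure into flat lists of
+        # pixel tensors and their (height, width) patch-grid target sizes.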
+ pixel_values_flat: List[torch.Tensor] = []
+ tgt_sizes_flat: List[torch.Tensor] = []
+ for pixel_b, tgt_b in zip(pixel_values, tgt_sizes):
+ if len(pixel_b) != len(tgt_b):
+ raise ValueError(
+ "Inconsistent N lengths, found: " f"{len(pixel_b)} vs {len(tgt_b)}"
+ )
+
+ for pixel_n, tgt_n in zip(pixel_b, tgt_b):
+ pixel_values_flat += pixel_n
+ tgt_sizes_flat += tgt_n
+
+        # NOTE: Input IDs do not contain image tokens during memory profiling,
+ # so we allow it to be empty
+ if len(pixel_values_flat) != len(tgt_sizes_flat):
+ raise ValueError(
+ "Inconsistent flattened lengths, found: "
+ f"{len(pixel_values_flat)} vs. "
+ f"{len(tgt_sizes_flat)}"
+ )
+
+ if len(pixel_values_flat) == 0:
+ return None
+
+ image_bounds = self._get_image_bounds(
+ input_ids=input_ids,
+ im_start_id=im_start_id,
+ im_end_id=im_end_id,
+ slice_start_id=slice_start_id,
+ slice_end_id=slice_end_id,
+ )
+ return MiniCPMVImagePixelInputs(
+ image_bounds=image_bounds.to(device=input_ids.device),
+ data=pixel_values_flat,
+ tgt_sizes=torch.stack(tgt_sizes_flat),
+ type="pixel_values",
+ )
+
+ def forward(
+ self,
+ input_ids: torch.Tensor,
+ positions: torch.Tensor,
+ forward_batch: ForwardBatch,
+ **kwargs: Any,
+ ) -> torch.Tensor:
+ if forward_batch.image_inputs is not None and forward_batch.image_inputs != [
+ None
+ ]:
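+            # The special image-marker token ids are shared across the batch, so
+            # they are taken from the first request; pixel values and target sizes
+            # are gathered per request.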
+ kwargs.update(
+ {
+ "pixel_values": (
+ None
+ if forward_batch.image_inputs is None
+ else [
+ None if i is None else i.pixel_values
+ for i in forward_batch.image_inputs
+ ]
+ ),
+ "im_start_id": forward_batch.image_inputs[0].im_start_id,
+ "im_end_id": forward_batch.image_inputs[0].im_end_id,
+ "slice_start_id": forward_batch.image_inputs[0].slice_start_id,
+ "slice_end_id": forward_batch.image_inputs[0].slice_end_id,
+ "tgt_sizes": (
+ None
+ if forward_batch.image_inputs is None
+ else [
+ None if i is None else i.tgt_sizes
+ for i in forward_batch.image_inputs
+ ]
+ ),
+ "image_bound": (
+ torch.zeros(0)
+ if forward_batch.image_inputs is None
+ else forward_batch.image_inputs[0].image_bound
+ ),
+ }
+ )
+
+ image_inputs = self._parse_and_validate_inputs(input_ids, **kwargs)
+
+        # Clamp input ids: image-token positions are filled with image hash values
+        # for prefix matching in radix attention. Their values are irrelevant here
+        # because their embeddings are replaced by vision embeddings anyway.
+ input_ids.clamp_(min=0, max=self.config.vocab_size - 1)
+
+ vlm_embeddings, _ = self.get_embedding(input_ids, image_inputs, forward_batch)
+
+ # always pass the input via `inputs_embeds`
+ # to make sure the computation graph is consistent
+ # for `torch.compile` integration
+ input_ids = None
+
+ hidden_states = self.llm.model(
+ input_ids=input_ids,
+ positions=positions,
+ forward_batch=forward_batch,
+ inputs_embeds=vlm_embeddings,
+ )
+
+ return self.logits_processor(
+ input_ids, hidden_states, self.llm.lm_head, forward_batch
+ )
+
+ def compute_logits(
+ self,
+ hidden_states: torch.Tensor,
+ sampling_metadata: SamplingMetadata,
+ ) -> Optional[torch.Tensor]:
+ return self.llm.compute_logits(hidden_states, sampling_metadata)
+
+ def sample(
+ self,
+ logits: torch.Tensor,
+ sampling_metadata: SamplingMetadata,
+ ) -> Optional[SamplerOutput]:
+ next_tokens = self.sampler(logits, sampling_metadata)
+ return next_tokens
+
+ def get_mm_mapping(self) -> MultiModelKeys:
+ """
+ Get the module prefix in multimodal models
+ """
+ return MultiModelKeys.from_string_field(
+ language_model="llm", connector="resampler", tower_model="vpm"
+ )
+
+ def init_llm(
+ self,
+ config: Qwen2Config,
+ quant_config: Optional[QuantizationConfig] = None,
+ ) -> nn.Module:
+ raise NotImplementedError
+
+ def init_vision_module(
+ self,
+ config: PretrainedConfig,
+ quant_config: Optional[QuantizationConfig],
+ ) -> nn.Module:
+ raise NotImplementedError
+
+ def init_resampler(
+ self,
+ embed_dim: int,
+ vision_dim: int,
+ quant_config: Optional[QuantizationConfig] = None,
+ ) -> nn.Module:
+ raise NotImplementedError
+
+ def get_vision_embedding(
+ self,
+ pixel_values: List[torch.Tensor],
+ patch_attn_mask: Optional[torch.Tensor] = None,
+ tgt_sizes: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ raise NotImplementedError
+
+ def get_vision_hidden_states(
+ self, forward_batch: ForwardBatch, data: MiniCPMVImageInputs
+ ) -> torch.Tensor:
+ raise NotImplementedError
+
+
+class MiniCPMV2_6(MiniCPMVBaseModel):
+ packed_modules_mapping = {
+ "qkv_proj": [
+ "q_proj",
+ "k_proj",
+ "v_proj",
+ ],
+ "gate_up_proj": [
+ "gate_proj",
+ "up_proj",
+ ],
+ }
+ # LoRA specific attributes
+ supported_lora_modules = [
+ # vision encoder
+ "fc1",
+ "fc2",
+ "out_proj",
+ # language model
+ "qkv_proj", # same name with vision encoder
+ "o_proj",
+ "gate_up_proj",
+ "down_proj",
+ # resampler
+ "kv_proj",
+ ]
+
+    # BitsAndBytes specific attributes
+ bitsandbytes_stacked_params_mapping = {
+ # shard_name, weight_name, index
+ "q_proj": ("qkv_proj", 0),
+ "k_proj": ("qkv_proj", 1),
+ "v_proj": ("qkv_proj", 2),
+ "gate_proj": ("gate_up_proj", 0),
+ "up_proj": ("gate_up_proj", 1),
+ }
+
+ embedding_modules = {}
+ embedding_padding_modules = []
+
+ def __init__(
+ self,
+ config: PretrainedConfig,
+ quant_config: Optional[QuantizationConfig] = None,
+ ):
+ super().__init__(config=config, quant_config=quant_config)
+ assert self.version == (2, 6)
+
+ def init_llm(
+ self,
+ config: Qwen2Config,
+ quant_config: Optional[QuantizationConfig] = None,
+ ) -> nn.Module:
+ return Qwen2ForCausalLM(config=config, quant_config=quant_config)
+
+ def init_vision_module(
+ self,
+ config: PretrainedConfig,
+ quant_config: Optional[QuantizationConfig],
+ ) -> nn.Module:
+ model = Idefics2VisionTransformer(
+ config=config.vision_config, quant_config=quant_config
+ )
+ if self.config.drop_vision_last_layer:
+ model.encoder.layers = model.encoder.layers[:-1]
+
+ setattr(model, "embed_dim", model.embeddings.embed_dim)
+ setattr(model, "patch_size", model.embeddings.patch_size)
+ return model
+
+ def init_resampler(
+ self,
+ embed_dim: int,
+ vision_dim: int,
+ quant_config: Optional[QuantizationConfig] = None,
+ ) -> nn.Module:
+ with set_default_torch_dtype(torch.float16):
+ # The resampler in 2.6 remains consistent with the one in 2.5.
+ resampler = Resampler2_5(
+ num_queries=self.config.query_num,
+ embed_dim=embed_dim,
+ num_heads=embed_dim // 128,
+ kv_dim=vision_dim,
+ quant_config=quant_config,
+ )
+
+ return resampler.to(device="cuda", dtype=torch.get_default_dtype())
+
+ def get_vision_embedding(
+ self,
+ pixel_values: List[torch.Tensor],
+ patch_attn_mask: Optional[torch.Tensor] = None,
+ tgt_sizes: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ vision_embedding = self.vpm(
+ pixel_values,
+ patch_attention_mask=patch_attn_mask,
+ tgt_sizes=tgt_sizes,
+ )
+ return vision_embedding
+
+ def get_vision_hidden_states(
+ self,
+ forward_batch: ForwardBatch,
+ data: MiniCPMVImageInputs,
+ ) -> torch.Tensor:
+ pixel_values = data["data"]
+ tgt_sizes = data["tgt_sizes"]
+
+ device = self.vpm.embeddings.position_embedding.weight.device
+ dtype = self.vpm.embeddings.position_embedding.weight.dtype
+ all_pixel_values_lst = [
+ i.flatten(end_dim=1).permute(1, 0) for i in pixel_values
+ ]
+
+ max_patches = (tgt_sizes[:, 0] * tgt_sizes[:, 1]).max().item()
+ assert isinstance(max_patches, int)
+
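+        # Images produce different numbers of patches, so pad them to a common
+        # length and build a boolean mask so the encoder ignores the padding.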
+ all_pixel_values = torch.nn.utils.rnn.pad_sequence(
+ all_pixel_values_lst, batch_first=True, padding_value=0.0
+ )
+ B, L, _ = all_pixel_values.shape
+ all_pixel_values = all_pixel_values.permute(0, 2, 1).reshape(B, 3, -1, L)
+ patch_attn_mask = torch.zeros(
+ (B, 1, max_patches), dtype=torch.bool, device=device
+ )
+ for i in range(B):
+ patch_attn_mask[i, 0, : tgt_sizes[i][0] * tgt_sizes[i][1]] = True
+ vision_embedding = self.vpm(
+ all_pixel_values.type(dtype),
+ forward_batch=forward_batch,
+ patch_attention_mask=patch_attn_mask,
+ tgt_sizes=tgt_sizes,
+ )
+
+ return self.resampler(vision_embedding, tgt_sizes)
+
+ def pad_input_ids(self, input_ids: List[int], image_inputs: ImageInputs):
+ if not isinstance(image_inputs.im_start_id, list) or not isinstance(
+ image_inputs.im_end_id, list
+ ):
+ return input_ids
+
+ pad_values = image_inputs.pad_values
+ new_input_ids = []
+ last_idx = 0
+
+ # Get all special token IDs
+ im_start_id = (
+ image_inputs.im_start_id[0].item()
+ if isinstance(image_inputs.im_start_id[0], torch.Tensor)
+ else image_inputs.im_start_id[0]
+ )
+ im_end_id = (
+ image_inputs.im_end_id[0].item()
+ if isinstance(image_inputs.im_end_id[0], torch.Tensor)
+ else image_inputs.im_end_id[0]
+ )
+ slice_start_id = (
+ image_inputs.slice_start_id[0].item()
+ if isinstance(image_inputs.slice_start_id[0], torch.Tensor)
+ else image_inputs.slice_start_id[0]
+ )
+ slice_end_id = (
+ image_inputs.slice_end_id[0].item()
+ if isinstance(image_inputs.slice_end_id[0], torch.Tensor)
+ else image_inputs.slice_end_id[0]
+ )
+
+ # Find all start and end positions for both types
+ start_indices = [
+ i
+ for i, x in enumerate(input_ids)
+ if x == im_start_id or x == slice_start_id
+ ]
+ end_indices = [
+ i for i, x in enumerate(input_ids) if x == im_end_id or x == slice_end_id
+ ]
+
+ if len(start_indices) != len(end_indices):
+ return input_ids
+
+ # Process each region (both image and slice)
+ for start_idx, end_idx in zip(start_indices, end_indices):
+ # Add non-image tokens before this region
+ new_input_ids.extend(
+ input_ids[last_idx : start_idx + 1]
+ ) # include start token
+
+ # Calculate the number of tokens to pad
+ num_tokens = end_idx - start_idx - 1 # exclude start and end tokens
+
+ # Generate pad_ids
+ pad_ids = pad_values * ((num_tokens + len(pad_values)) // len(pad_values))
+ pad_ids = pad_ids[:num_tokens]
+
+ # Add pad_ids
+ new_input_ids.extend(pad_ids)
+
+            # Keep last_idx at the end token so it is included by the next extend
+ last_idx = end_idx
+
+ # Add remaining tokens after last region
+ new_input_ids.extend(input_ids[last_idx:])
+ assert len(input_ids) == len(new_input_ids)
+ return new_input_ids
+
+
+_SUPPORT_VERSION = {(2, 6): MiniCPMV2_6}
+
+
+class MiniCPMV:
+ """
+    Different versions of MiniCPM-V use different vision encoders and LLMs,
+    which does not fit the current LoRA and bitsandbytes integration logic.
+    Therefore, they are dispatched to separate implementation classes.
+ """
+
+ # Ensure that the LoRA support check passes when the class is not
+ # initialized, but set all these attributes to empty.
+ packed_modules_mapping = {}
+ supported_lora_modules = []
+ embedding_modules = {}
+ embedding_padding_modules = []
+
+ minicpmv: nn.Module
+
+ def __init__(
+ self,
+ config: PretrainedConfig,
+ quant_config: Optional[QuantizationConfig] = None,
+ ) -> None:
+ super().__init__()
+
+ if not hasattr(config, "version"):
+ version = (2, 6)
+ else:
+ version = str(config.version).split(".")
+ version = tuple([int(x) for x in version])
+ # Dispatch class based on version
+ instance_class = _SUPPORT_VERSION.get(version)
+ if instance_class is None:
+            raise ValueError("Currently, MiniCPMV only supports version 2.6")
+
+        self.minicpmv = instance_class(config=config, quant_config=quant_config)
+ self.config = config
+
+ def __getattr__(self, name):
+ if name == "minicpmv":
+ return None
+ return getattr(self.minicpmv, name)
+
+ def __call__(self, *args, **kwargs):
+ return self.minicpmv(*args, **kwargs)
+
+ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+ stacked_params_mapping = [
+ # (param_name, shard_name, shard_id)
+ ("qkv_proj", "q_proj", "q"),
+ ("qkv_proj", "k_proj", "k"),
+ ("qkv_proj", "v_proj", "v"),
+ ("gate_up_proj", "gate_proj", 0),
+ ("gate_up_proj", "up_proj", 1),
+ ]
+
+ params_dict = dict(self.minicpmv.named_parameters())
+        for name, loaded_weight in weights:
+            if "rotary_emb.inv_freq" in name or "projector" in name:
+                continue
+ if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
+ # Models trained using ColossalAI may include these tensors in
+ # the checkpoint. Skip them.
+ continue
+ if name.startswith("model.vision_tower") and name not in params_dict:
+ continue
+
+ if "sampler" in name:
+ param = params_dict[name]
+ weight_loader = getattr(param, "weight_loader", default_weight_loader)
+ weight_loader(param, loaded_weight)
+ continue
+
+ for param_name, weight_name, shard_id in stacked_params_mapping:
+ # replace the name and load with customized loader
+ if weight_name not in name:
+ continue
+ name = name.replace(weight_name, param_name)
+ # Skip loading extra bias for GPTQ models.
+ if name.endswith(".bias") and name not in params_dict:
+ continue
+ param = params_dict[name]
+ weight_loader = param.weight_loader
+ weight_loader(param, loaded_weight, shard_id)
+ break
+ else:
+ # Skip loading extra bias for GPTQ models.
+ if name.endswith(".bias") and name not in params_dict:
+ continue
+ param = params_dict[name]
+ weight_loader = getattr(param, "weight_loader", default_weight_loader)
+ weight_loader(param, loaded_weight)
+
+
+EntryClass = MiniCPMV
diff --git a/python/sglang/srt/models/qwen2.py b/python/sglang/srt/models/qwen2.py
index 2a20d6c50de..df8e7e586d9 100644
--- a/python/sglang/srt/models/qwen2.py
+++ b/python/sglang/srt/models/qwen2.py
@@ -242,17 +242,20 @@ def __init__(
)
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+ return self.embed_tokens(input_ids)
+
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
forward_batch: ForwardBatch,
- input_embeds: torch.Tensor = None,
+ inputs_embeds: torch.Tensor = None,
) -> torch.Tensor:
- if input_embeds is None:
+ if inputs_embeds is None:
hidden_states = self.embed_tokens(input_ids)
else:
- hidden_states = input_embeds
+ hidden_states = inputs_embeds
residual = None
for i in range(len(self.layers)):
layer = self.layers[i]
@@ -267,7 +270,6 @@ def forward(
class Qwen2ForCausalLM(nn.Module):
-
# BitandBytes specific attributes
default_bitsandbytes_target_modules = [
".gate_proj.",
@@ -305,16 +307,19 @@ def __init__(
self.logits_processor = LogitsProcessor(config)
self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
+ def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+ return self.model.get_input_embeddings(input_ids)
+
@torch.no_grad()
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
forward_batch: ForwardBatch,
- input_embeds: torch.Tensor = None,
+ inputs_embeds: torch.Tensor = None,
get_embedding: bool = False,
) -> torch.Tensor:
- hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
+ hidden_states = self.model(input_ids, positions, forward_batch, inputs_embeds)
if not get_embedding:
return self.logits_processor(
input_ids, hidden_states, self.lm_head, forward_batch
diff --git a/python/sglang/srt/models/qwen2_vl.py b/python/sglang/srt/models/qwen2_vl.py
index 2e9ec9d8f50..7b25d52f3ab 100644
--- a/python/sglang/srt/models/qwen2_vl.py
+++ b/python/sglang/srt/models/qwen2_vl.py
@@ -661,7 +661,7 @@ def forward(
input_ids=input_ids,
positions=positions,
forward_batch=forward_batch,
- input_embeds=inputs_embeds,
+ inputs_embeds=inputs_embeds,
)
if not get_embedding:
diff --git a/python/sglang/srt/openai_api/adapter.py b/python/sglang/srt/openai_api/adapter.py
index 5056ba22ef9..32c094c3d0c 100644
--- a/python/sglang/srt/openai_api/adapter.py
+++ b/python/sglang/srt/openai_api/adapter.py
@@ -897,7 +897,7 @@ def v1_chat_generate_request(
{"role": message.role, "content": message.content}
)
else:
- content_list = message.dict()["content"]
+ content_list = message["content"]
for content in content_list:
if content["type"] == "text":
openai_compatible_messages.append(
diff --git a/python/sglang/test/test_utils.py b/python/sglang/test/test_utils.py
index 4121deb17cc..46e72960acb 100644
--- a/python/sglang/test/test_utils.py
+++ b/python/sglang/test/test_utils.py
@@ -405,7 +405,7 @@ def popen_launch_server(
base_url: str,
timeout: float,
api_key: Optional[str] = None,
- other_args: tuple = (),
+ other_args: list[str] = (),
env: Optional[dict] = None,
return_stdout_stderr: Optional[tuple] = None,
):
diff --git a/test/README.md b/test/README.md
index 3d739cc0496..868061bbc4a 100644
--- a/test/README.md
+++ b/test/README.md
@@ -25,7 +25,7 @@ export OPENAI_API_KEY=sk-*****
python3 test_openai_backend.py
# Run a single test
-python3 -m unittest test_openai_backend.TestOpenAIBackend.test_few_shot_qa
+python3 -m unittest test_openai_backend.TestOpenAIServer.test_few_shot_qa
# Run a suite with multiple files
python3 run_suite.py --suite per-commit
diff --git a/test/srt/test_vision_openai_server.py b/test/srt/test_vision_openai_server.py
index e19e6b01d51..5af61034b9e 100644
--- a/test/srt/test_vision_openai_server.py
+++ b/test/srt/test_vision_openai_server.py
@@ -171,7 +171,7 @@ def test_multi_images_chat_completion(self):
text = response.choices[0].message.content
assert isinstance(text, str)
print(text)
- assert "man" in text or "cab" in text, text
+ assert "man" in text or "cab" in text or "SUV" in text, text
assert "logo" in text or '"S"' in text or "SG" in text, text
assert response.id
assert response.created
@@ -444,5 +444,24 @@ def test_video_chat_completion(self):
pass
+class TestMinicpmvServer(TestOpenAIVisionServer):
+ @classmethod
+ def setUpClass(cls):
+ cls.model = "openbmb/MiniCPM-V-2_6"
+ cls.base_url = DEFAULT_URL_FOR_TEST
+ cls.api_key = "sk-123456"
+ cls.process = popen_launch_server(
+ cls.model,
+ cls.base_url,
+ timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+ other_args=[
+ "--trust-remote-code",
+ "--chat-template",
+ "minicpmv",
+ ],
+ )
+ cls.base_url += "/v1"
+
+
if __name__ == "__main__":
unittest.main()