[Feature] Support minicpmv v2.6
mickqian committed Jan 10, 2025
1 parent 11fffbc commit 42f09c0
Showing 16 changed files with 1,534 additions and 23 deletions.
2 changes: 1 addition & 1 deletion docs/references/supported_models.md
@@ -24,7 +24,7 @@
- InternLM 2
- Exaone 3
- BaiChuan2
- MiniCPM / MiniCPM 3
- MiniCPM / MiniCPM 3 / MiniCPMV
- XVERSE / XVERSE MoE
- SmolLM
- GLM-4
19 changes: 15 additions & 4 deletions python/sglang/lang/chat_template.py
@@ -88,7 +88,6 @@ def get_chat_template_by_model_path(model_path):
)
)


register_chat_template(
ChatTemplate(
name="claude",
@@ -101,7 +100,6 @@ def get_chat_template_by_model_path(model_path):
)
)


register_chat_template(
ChatTemplate(
name="chatml",
@@ -116,7 +114,6 @@ def get_chat_template_by_model_path(model_path):
)
)


register_chat_template(
ChatTemplate(
name="chatml-llava",
@@ -132,7 +129,6 @@ def get_chat_template_by_model_path(model_path):
)
)


# There is default system prompt for qwen
# reference: https://modelscope.cn/models/qwen/Qwen2-72B-Instruct/file/view/master?fileName=tokenizer_config.json&status=1
# The chat template is: "{% for message in messages %}{% if loop.first and messages[0]['role'] != 'system' %}{{ '<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n' }}{% endif %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"
@@ -219,6 +215,21 @@ def get_chat_template_by_model_path(model_path):
)
)

# https://huggingface.co/openbmb/MiniCPM-V-2_6
register_chat_template(
ChatTemplate(
name="minicpmv",
default_system_prompt=None,
role_prefix_and_suffix={
"system": ("", " "),
"user": ("user:", " "),
"assistant": ("assistant:", "</s>"),
},
stop_str=("<|im_end|>", "<|endoftext|>"),
image_token="(<image>./</image>)",
)
)

# The difference between "llama-3-instruct-llava" and "llama-3-instruct" is that llava uses a different image_token.
register_chat_template(
ChatTemplate(
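As a quick illustration (not part of this commit), the newly registered "minicpmv" template can be assembled into a prompt by hand from its role prefixes and suffixes. This sketch assumes the get_chat_template helper from sglang.lang.chat_template; the resulting string is only indicative.

from sglang.lang.chat_template import get_chat_template

# Hypothetical single-turn prompt built from the registered template.
tmpl = get_chat_template("minicpmv")
user_prefix, user_suffix = tmpl.role_prefix_and_suffix["user"]
assistant_prefix, _ = tmpl.role_prefix_and_suffix["assistant"]
prompt = (
    user_prefix
    + tmpl.image_token  # "(<image>./</image>)" marks where the image goes
    + "Describe this picture."
    + user_suffix
    + assistant_prefix
)
# Expected roughly: "user:(<image>./</image>)Describe this picture. assistant:"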
15 changes: 14 additions & 1 deletion python/sglang/srt/conversation.py
@@ -452,7 +452,6 @@ def generate_chat_conv(

# Add a blank message for the assistant.
conv.append_message(conv.roles[1], None)

return conv


@@ -555,3 +554,17 @@ def generate_chat_conv(
image_token="<|vision_start|><|image_pad|><|vision_end|>",
)
)

# Reference: https://huggingface.co/openbmb/MiniCPM-V-2_6#usage
register_conv_template(
Conversation(
name="minicpmv",
system_message="You are a helpful assistant",
system_template="<|im_start|>system\n{system_message}.",
roles=("<|im_start|>user", "<|im_start|>assistant"),
sep="<|im_end|>\n",
sep_style=SeparatorStyle.ADD_NEW_LINE_SINGLE,
stop_str=("<|im_end|>", "<|endoftext|>"),
image_token="(<image>./</image>)",
)
)
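As a rough sketch of what this conversation template renders (assuming the module's chat_templates registry and the standard Conversation helpers copy/append_message/get_prompt, none of which are shown in this diff):

from sglang.srt.conversation import chat_templates

# Build a one-turn MiniCPM-V conversation and print the resulting prompt.
conv = chat_templates["minicpmv"].copy()
conv.append_message(conv.roles[0], conv.image_token + "What is in this image?")
conv.append_message(conv.roles[1], None)
print(conv.get_prompt())
# Expected roughly (with ADD_NEW_LINE_SINGLE separators):
#   <|im_start|>system
#   You are a helpful assistant.<|im_end|>
#   <|im_start|>user
#   (<image>./</image>)What is in this image?<|im_end|>
#   <|im_start|>assistant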
2 changes: 1 addition & 1 deletion python/sglang/srt/layers/logits_processor.py
@@ -129,7 +129,7 @@ def forward(
hidden_states,
lm_head: VocabParallelEmbedding,
logits_metadata: Union[LogitsMetadata, ForwardBatch],
):
) -> LogitsProcessorOutput:
if isinstance(logits_metadata, ForwardBatch):
logits_metadata = LogitsMetadata.from_forward_batch(logits_metadata)

125 changes: 125 additions & 0 deletions python/sglang/srt/managers/image_processor.py
@@ -9,6 +9,8 @@

import numpy as np
import transformers
from decord import VideoReader, cpu
from PIL import Image

from sglang.srt.hf_transformers_utils import get_processor
from sglang.srt.mm_utils import expand2square, process_anyres_image
@@ -36,6 +38,7 @@ class BaseImageProcessor(ABC):
def __init__(self, hf_config, server_args, _processor):
self.hf_config = hf_config
self._processor = _processor
self.server_args = server_args

self.executor = concurrent.futures.ProcessPoolExecutor(
initializer=init_global_processor,
@@ -229,6 +232,126 @@ async def process_images_async(
return image_inputs


class MiniCPMVImageProcessor(BaseImageProcessor):
def __init__(self, hf_config, server_args, _processor):
super().__init__(hf_config, server_args, _processor)

@staticmethod
def _process_images_task(images, input_text):
# print("_process_images_task...")
result = global_processor.__call__(
text=input_text, images=images, return_tensors="pt"
)
return {
"input_ids": result["input_ids"],
"pixel_values": result["pixel_values"],
"tgt_sizes": result["tgt_sizes"],
"image_bound": result["image_bound"],
}

async def _process_images(self, images, input_text):
if self.executor is not None:
loop = asyncio.get_event_loop()
image_inputs = await loop.run_in_executor(
self.executor,
MiniCPMVImageProcessor._process_images_task,
images,
input_text,
)
else:
image_inputs = self._processor(
images=images, text=input_text, return_tensors="pt"
)

return image_inputs

async def process_images_async(
self, image_data: List[Union[str, bytes]], input_text, request_obj
):
if not image_data:
return None

# print("process_images_async...")

if not isinstance(image_data, list):
image_data = [image_data]

image_hashes, image_sizes = [], []
raw_images = []
IMAGE_TOKEN = "(<image>./</image>)"

def encode_video(video_path):
if not os.path.exists(video_path):
return []
MAX_NUM_FRAMES = 3 # if cuda OOM set a smaller number

def uniform_sample(l, n):
gap = len(l) / n
idxs = [int(i * gap + gap / 2) for i in range(n)]
return [l[i] for i in idxs]

vr = VideoReader(video_path, ctx=cpu(0))
sample_fps = round(vr.get_avg_fps() / 1) # FPS
frame_idx = [i for i in range(0, len(vr), sample_fps)]
if len(frame_idx) > MAX_NUM_FRAMES:
frame_idx = uniform_sample(frame_idx, MAX_NUM_FRAMES)
frames = vr.get_batch(frame_idx).asnumpy()
frames = [Image.fromarray(v.astype("uint8")) for v in frames]
return frames

if isinstance(input_text, list):
assert len(input_text) and isinstance(input_text[0], int)
input_text = self._processor.tokenizer.decode(input_text)

# MiniCPMV requires each frame of video as a single image token
text_parts = input_text.split(IMAGE_TOKEN)
new_text_parts = []

for image_index, image in enumerate(image_data):
if isinstance(image, str) and image.startswith("video:"):
path = image[len("video:") :]
frames: List[Image] = encode_video(path)
image_sizes += frames[0].size * len(frames)
else:
raw_image, size = load_image(image)
image_sizes += [size]
frames = [raw_image]

image_hashes += [hash(image)]
raw_images += frames
new_text_parts.append(text_parts[image_index])
new_text_parts.append(IMAGE_TOKEN * len(frames))

new_text_parts.append(text_parts[-1])
input_text = "".join(new_text_parts)
res = await self._process_images(images=raw_images, input_text=input_text)
pixel_values = res["pixel_values"]
tgt_sizes = res["tgt_sizes"]
image_bound = res["image_bound"]
input_ids = res["input_ids"]

# Collect special token ids
tokenizer = self._processor.tokenizer
im_start_id = [tokenizer.im_start_id]
im_end_id = [tokenizer.im_end_id]
if tokenizer.slice_start_id:
slice_start_id = [tokenizer.slice_start_id]
slice_end_id = [tokenizer.slice_end_id]

return {
"input_ids": input_ids.flatten().tolist(),
"pixel_values": pixel_values,
"tgt_sizes": tgt_sizes,
"image_hashes": image_hashes,
"modalities": request_obj.modalities or ["image"],
"im_start_id": im_start_id,
"im_end_id": im_end_id,
"slice_start_id": slice_start_id,
"slice_end_id": slice_end_id,
"image_bound": image_bound,
}


class Qwen2VLImageProcessor(BaseImageProcessor):
def __init__(self, hf_config, server_args, _image_processor):
self.hf_config = hf_config
@@ -350,6 +473,8 @@ def get_image_processor(
return MllamaImageProcessor(hf_config, server_args, processor)
elif "Qwen2VLForConditionalGeneration" in hf_config.architectures:
return Qwen2VLImageProcessor(hf_config, server_args, processor.image_processor)
elif "MiniCPMV" in hf_config.architectures:
return MiniCPMVImageProcessor(hf_config, server_args, processor)
else:
return LlavaImageProcessor(hf_config, server_args, processor.image_processor)

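As a standalone sketch of the frame-sampling rule used in encode_video above (roughly one frame per second, then a uniform subsample capped at MAX_NUM_FRAMES), the index selection can be written as:

def sample_frame_indices(num_frames: int, avg_fps: float, max_num_frames: int = 3):
    # Take roughly one frame per second of video.
    step = max(round(avg_fps), 1)
    idxs = list(range(0, num_frames, step))
    # If still too many, pick max_num_frames indices spread uniformly.
    if len(idxs) > max_num_frames:
        gap = len(idxs) / max_num_frames
        idxs = [idxs[int(i * gap + gap / 2)] for i in range(max_num_frames)]
    return idxs

print(sample_frame_indices(num_frames=300, avg_fps=30.0))  # [30, 150, 240]

Each sampled frame is then treated as its own image, which is why an entry prefixed with "video:" expands into one "(<image>./</image>)" token per frame in the prompt.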
20 changes: 17 additions & 3 deletions python/sglang/srt/managers/schedule_batch.py
@@ -52,7 +52,6 @@
if TYPE_CHECKING:
from sglang.srt.speculative.spec_info import SpecInfo, SpeculativeAlgorithm


INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5

# Put some global args for easy access
@@ -67,7 +66,6 @@
"enable_ep_moe": ServerArgs.enable_ep_moe,
}


logger = logging.getLogger(__name__)


@@ -148,6 +146,17 @@ class ImageInputs:
image_grid_thws: List[Tuple[int, int, int]] = None
mrope_position_delta: Optional[torch.Tensor] = None

# MiniCPMV related
# All the images in the batch should share the same special image
# bound token ids.
im_start_id: Optional[torch.Tensor] = None
im_end_id: Optional[torch.Tensor] = None
slice_start_id: Optional[torch.Tensor] = None
slice_end_id: Optional[torch.Tensor] = None
image_bound: Optional[torch.Tensor] = None

tgt_sizes: Optional[list] = None

@staticmethod
def from_dict(obj: dict):
ret = ImageInputs(
@@ -167,6 +176,12 @@ def from_dict(obj: dict):
"aspect_ratio_ids",
"aspect_ratio_mask",
"image_grid_thws",
"im_start_id",
"im_end_id",
"slice_start_id",
"slice_end_id",
"tgt_sizes",
"image_bound",
]
for arg in optional_args:
if arg in obj:
@@ -1136,7 +1151,6 @@ def get_model_worker_batch(self):

global bid
bid += 1

return ModelWorkerBatch(
bid=bid,
forward_mode=self.forward_mode,
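For context, a minimal sketch (with dummy tensors and made-up token ids, purely illustrative) of how the dict returned by MiniCPMVImageProcessor maps onto the new ImageInputs fields via from_dict, assuming pixel_values and image_hashes are the only required constructor fields:

import torch

from sglang.srt.managers.schedule_batch import ImageInputs

processor_output = {
    "pixel_values": torch.zeros(1, 3, 448, 448),  # illustrative shape only
    "image_hashes": [1234567890],
    "tgt_sizes": [torch.tensor([32, 32])],
    "image_bound": torch.tensor([[5, 69]]),       # start/end positions of image tokens (illustrative)
    "im_start_id": [151646],                      # made-up special token ids
    "im_end_id": [151647],
    "slice_start_id": [151648],
    "slice_end_id": [151649],
    "modalities": ["image"],
}
image_inputs = ImageInputs.from_dict(processor_output)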
2 changes: 1 addition & 1 deletion python/sglang/srt/model_executor/model_runner.py
@@ -220,7 +220,7 @@ def init_torch_distributed(self):
set_custom_all_reduce(not self.server_args.disable_custom_all_reduce)

if not self.is_draft_worker:
# Only initilzie the distributed environment on the target model worker.
# Only initialize the distributed environment on the target model worker.
init_distributed_environment(
backend=backend,
world_size=self.tp_size,
2 changes: 1 addition & 1 deletion python/sglang/srt/models/llava.py
@@ -403,7 +403,7 @@ def forward(
pt += 1

return self.language_model(
input_ids, positions, forward_batch, input_embeds=input_embeds
input_ids, positions, forward_batch, inputs_embeds=input_embeds
)
elif forward_batch.forward_mode.is_decode():
return self.language_model(input_ids, positions, forward_batch)
(Diffs for the remaining changed files are not shown here.)