Add tools functionality to vLLM (kserve#4033)
* Add tools to chat template

Signed-off-by: Arjun Bhalla <[email protected]>

Linting

Signed-off-by: Arjun Bhalla <[email protected]>

add test

Signed-off-by: Arjun Bhalla <[email protected]>

Fix linting manually

Signed-off-by: Arjun Bhalla <[email protected]>

* Fix linting

Signed-off-by: Arjun Bhalla <[email protected]>

---------

Signed-off-by: Arjun Bhalla <[email protected]>
Signed-off-by: Arjun Bhalla <[email protected]>
Co-authored-by: Arjun Bhalla <[email protected]>
ArjunBhalla98 and Arjun Bhalla authored Nov 5, 2024
1 parent 9f71bbb commit 505cede
Showing 6 changed files with 70 additions and 5 deletions.
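What this commit enables, in one self-contained sketch before the per-file hunks: chat templates can now reference a tools variable next to messages. The snippet below is illustrative only — it assumes jinja2 is installed, and both the template and the tool definition are invented for the demo, not taken from this diff.

from jinja2 import Template

# Illustrative tool in the OpenAI function-calling shape that the new
# `tools` parameter accepts once pydantic models are dumped to dicts.
tool = {
    "type": "function",
    "function": {"name": "get_current_weather", "description": "Get the current weather"},
}

# Toy chat template; real ones ship with the tokenizer. Jinja resolves
# `t.function.name` via item lookup on plain dicts.
template = Template(
    "{% for m in messages %}{{ m.content }}\n{% endfor %}"
    "Available tools:{% for t in tools %} {{ t.function.name }}{% endfor %}"
)

print(
    template.render(
        messages=[{"role": "user", "content": "weather in Ithaca, NY"}],
        tools=[tool],
    )
)
# -> weather in Ithaca, NY
#    Available tools: get_current_weather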
@@ -36,6 +36,7 @@
     CompletionRequest,
     OpenAIChatAdapterModel,
 )
+from kserve.protocol.rest.openai.types.openapi import ChatCompletionTool
 from kserve.protocol.rest.openai.types import (
     ChatCompletionRequestMessage,
     Completion,
@@ -387,6 +388,7 @@ def apply_chat_template(
         self,
         messages: Iterable[ChatCompletionRequestMessage],
         chat_template: Optional[str] = None,
+        tools: Optional[list[ChatCompletionTool]] = None,
     ) -> ChatPrompt:
         """
         Given a list of chat completion messages, convert them to a prompt.
@@ -399,6 +401,7 @@
             chat_template=chat_template,
             tokenize=False,
             add_generation_prompt=True,
+            tools=[tool.model_dump() for tool in tools] if tools else None,
         ),
     )
 )
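The tools=[tool.model_dump() for tool in tools] if tools else None line above exists because tokenizer chat templates consume plain dicts, while kserve hands the method pydantic ChatCompletionTool objects. A minimal sketch of that conversion, using stand-in pydantic v2 models (the field names here are assumptions, not the exact kserve types):

from typing import Optional
from pydantic import BaseModel

# Stand-ins for the kserve openapi types (assumed shape, pydantic v2).
class Function(BaseModel):
    name: str
    description: Optional[str] = None
    parameters: Optional[dict] = None

class ChatCompletionTool(BaseModel):
    type: str
    function: Function

tools = [
    ChatCompletionTool(
        type="function",
        function=Function(name="get_current_weather", description="Get the current weather"),
    )
]

# Mirrors the guard in the hunk above: dump to dicts, or pass None through
# so templates without a tools block are unaffected.
payload = [tool.model_dump() for tool in tools] if tools else None
print(payload[0]["function"]["name"])  # get_current_weather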
@@ -50,6 +50,7 @@
     CreateCompletionRequest,
     CreateCompletionResponse as Completion,
     Logprobs,
+    ChatCompletionTool,
 )
 from kserve.protocol.rest.openai.errors import OpenAIError, create_error_response
 from kserve.protocol.rest.openai import ChatCompletionRequestMessage, CompletionRequest
@@ -93,7 +94,6 @@ def logit_bias_logits_processor(


 class OpenAIServingCompletion:
-
     def __init__(self, engine: AsyncLLMEngine, request_logger: RequestLogger = None):
         self.engine = engine

@@ -365,14 +365,16 @@ def request_output_to_completion_response(

     def apply_chat_template(
         self,
-        messages: Iterable[ChatCompletionRequestMessage,],
+        messages: Iterable[ChatCompletionRequestMessage],
         chat_template: Optional[str] = None,
+        tools: Optional[list[ChatCompletionTool]] = None,
     ):
         return self.tokenizer.apply_chat_template(
             conversation=messages,
             chat_template=chat_template,
             tokenize=False,
             add_generation_prompt=True,
+            tools=tools,
         )

     async def _post_init(self):
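This path forwards tools straight into HuggingFace's tokenizer.apply_chat_template. A hedged sketch of that underlying call — it assumes a transformers version recent enough to accept the tools keyword (added around v4.42), uses gpt2 only as a convenient small tokenizer (first run downloads it), and supplies an inline template since gpt2 has no tool-aware one:

from transformers import AutoTokenizer

# gpt2 is just a small tokenizer to demo the call; the inline template
# stands in for a real tool-aware chat template.
tokenizer = AutoTokenizer.from_pretrained("gpt2")

prompt = tokenizer.apply_chat_template(
    conversation=[{"role": "user", "content": "weather in Ithaca, NY"}],
    chat_template=(
        "{% for m in messages %}{{ m['content'] }}{% endfor %}"
        "{% for t in tools %} [{{ t['function']['name'] }}]{% endfor %}"
    ),
    tools=[{"type": "function", "function": {"name": "get_current_weather"}}],
    tokenize=False,
    add_generation_prompt=True,
)
print(prompt)  # -> weather in Ithaca, NY [get_current_weather]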
6 changes: 4 additions & 2 deletions python/huggingfaceserver/huggingfaceserver/vllm/vllm_model.py
@@ -26,6 +26,7 @@
     CompletionRequest,
     OpenAIChatAdapterModel,
 )
+from kserve.protocol.rest.openai.types.openapi import ChatCompletionTool
 from kserve.protocol.rest.openai.types import Completion
 from vllm.engine.async_llm_engine import AsyncLLMEngine
 from vllm import AsyncEngineArgs
@@ -68,15 +69,16 @@ async def healthy(self) -> bool:

     def apply_chat_template(
         self,
-        messages: Iterable[ChatCompletionRequestMessage,],
+        messages: Iterable[ChatCompletionRequestMessage],
         chat_template: Optional[str] = None,
+        tools: Optional[list[ChatCompletionTool]] = None,
     ) -> ChatPrompt:
         """
         Given a list of chat completion messages, convert them to a prompt.
         """
         return ChatPrompt(
             prompt=self.openai_serving_completion.apply_chat_template(
-                messages, chat_template
+                messages, chat_template, tools
             )
         )

54 changes: 54 additions & 0 deletions python/huggingfaceserver/tests/test_model.py
@@ -507,3 +507,57 @@ async def test_input_padding_with_pad_token_not_specified(
         == "west , and the sun sets in the west . \n the sun rises in the"
     )
     assert "a member of the royal family ." in response.choices[1].text
+
+
+@pytest.mark.asyncio
+async def test_tools_chat_completion(bloom_model: HuggingfaceGenerativeModel):
+    messages = [
+        {
+            "role": "system",
+            "content": "You are a friendly chatbot whose purpose is to tell me what the weather is.",
+        },
+        {
+            "role": "user",
+            "content": "weather in Ithaca, NY",
+        },
+    ]
+
+    tools = [
+        {
+            "type": "function",
+            "function": {
+                "name": "get_current_weather",
+                "description": "Get the current weather",
+                "parameters": {
+                    "type": "dict",
+                    "properties": {
+                        "location": {
+                            "type": "string",
+                            "description": "The city and state, e.g. San Francisco, CA",
+                        },
+                        "format": {
+                            "type": "string",
+                            "enum": ["celsius", "fahrenheit"],
+                            "description": "The temperature unit to use. Infer this from the users location.",
+                        },
+                    },
+                    "required": ["location", "format"],
+                },
+            },
+        }
+    ]
+    params = CreateChatCompletionRequest(
+        model="bloom-560m",
+        messages=messages,
+        stream=False,
+        max_tokens=100,
+        tools=tools,
+        tool_choice="auto",
+        chat_template="{% for message in messages %}"
+        "{{ message.content }} You have these tools: {% for tool in tools %} {{ eos_token }}"
+        "{% endfor %}{% endfor %}",
+    )
+    request = ChatCompletionRequest(params=params, context={})
+    response = await bloom_model.create_chat_completion(request)
+
+    assert response.choices[0].message.content
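For intuition, the test's chat_template can be rendered outside the server. This sketch assumes bloom-560m's eos_token is "</s>" (a property of its tokenizer config, not something the test asserts), and uses a single message for brevity:

from jinja2 import Template

# Renders the test's template standalone to show the prompt shape.
template = Template(
    "{% for message in messages %}"
    "{{ message.content }} You have these tools: {% for tool in tools %} {{ eos_token }}"
    "{% endfor %}{% endfor %}"
)
print(
    template.render(
        messages=[{"role": "user", "content": "weather in Ithaca, NY"}],
        tools=[{"type": "function", "function": {"name": "get_current_weather"}}],
        eos_token="</s>",
    )
)
# -> weather in Ithaca, NY You have these tools:  </s>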
@@ -204,7 +204,9 @@ async def create_chat_completion(
             raise InvalidInput("n != 1 is not supported")

         # Convert the messages into a prompt
-        chat_prompt = self.apply_chat_template(params.messages, params.chat_template)
+        chat_prompt = self.apply_chat_template(
+            params.messages, params.chat_template, params.tools
+        )
         # Translate the chat completion request to a completion request
         completion_params = self.chat_completion_params_to_completion_params(
             params, chat_prompt.prompt
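The adapter hunk above is the glue step: tools from the chat request are folded into the prompt before the chat request is translated into a plain completion request. A toy stand-in for that flow (illustrative only, not the kserve class):

from typing import Optional

class ToyAdapter:
    def apply_chat_template(
        self,
        messages: list[dict],
        chat_template: Optional[str] = None,
        tools: Optional[list[dict]] = None,
    ) -> str:
        # Fold tool names into the prompt the way a real template might.
        text = " ".join(m["content"] for m in messages)
        names = ", ".join(t["function"]["name"] for t in tools or [])
        return f"{text} [tools: {names}]" if names else text

    def create_chat_completion(self, params: dict) -> str:
        # Step 1: messages (+ tools) -> prompt, mirroring the hunk above.
        prompt = self.apply_chat_template(
            params["messages"], params.get("chat_template"), params.get("tools")
        )
        # Step 2 would translate this into a completion request.
        return prompt

print(
    ToyAdapter().create_chat_completion(
        {
            "messages": [{"role": "user", "content": "weather in Ithaca, NY"}],
            "tools": [{"type": "function", "function": {"name": "get_current_weather"}}],
        }
    )
)  # -> weather in Ithaca, NY [tools: get_current_weather]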
2 changes: 2 additions & 0 deletions python/kserve/test/test_openai.py
@@ -30,6 +30,7 @@
 )
 from kserve.protocol.rest.openai.errors import OpenAIError
 from kserve.protocol.rest.openai.types.openapi import (
+    ChatCompletionTool,
     CreateChatCompletionRequest,
     Error,
     ErrorResponse,
@@ -87,6 +88,7 @@ def apply_chat_template(
         self,
         messages: Iterable[ChatCompletionRequestMessage],
         chat_template: Optional[str] = None,
+        tools: Optional[list[ChatCompletionTool]] = None,
     ) -> ChatPrompt:
         return ChatPrompt(prompt="hello")

