Skip to content

Commit

Permalink
Add structured-output (response_format) and tool-calling support to the Qwen model backend
Browse files Browse the repository at this point in the history
  • Loading branch information
WHALEEYE committed Jan 13, 2025
1 parent e8c8148 commit 520c29a
Show file tree
Hide file tree
Showing 6 changed files with 96 additions and 30 deletions.
12 changes: 6 additions & 6 deletions camel/agents/chat_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -379,9 +379,9 @@ def step(

break

self._log_final_output(response.output_messages)
self._record_final_output(response.output_messages)

return self._parse_chatagent_response(
return self._convert_to_chatagent_response(
response, tool_call_records, num_tokens
)

Expand Down Expand Up @@ -441,13 +441,13 @@ async def astep(

break

self._log_final_output(response.output_messages)
self._record_final_output(response.output_messages)

return self._parse_chatagent_response(
return self._convert_to_chatagent_response(
response, tool_call_records, num_tokens
)

def _parse_chatagent_response(
def _convert_to_chatagent_response(
self,
response: ModelResponse,
tool_call_records: List[ToolCallingRecord],
Expand All @@ -469,7 +469,7 @@ def _parse_chatagent_response(
info=info,
)

def _log_final_output(self, output_messages: List[BaseMessage]) -> None:
def _record_final_output(self, output_messages: List[BaseMessage]) -> None:
r"""Log final messages or warnings about multiple responses."""
if len(output_messages) == 1:
self.record_message(output_messages[0])
Expand Down
27 changes: 14 additions & 13 deletions camel/configs/qwen_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,11 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
from __future__ import annotations

from typing import ClassVar, Optional, Union
from typing import Dict, List, Optional, Union

from pydantic import Field

from camel.configs.base_config import BaseConfig
from camel.types import NOT_GIVEN, NotGiven


class QwenConfig(BaseConfig):
Expand Down Expand Up @@ -52,16 +53,16 @@ class QwenConfig(BaseConfig):
keeping other parameters unchanged, the model is likely to return
the same result.
(default: :obj:`None`)
stop (str or list, optional): Using the stop parameter, the model will
automatically stop generating text when it is about to include the
specified string or token_id. You can use the stop parameter to
control the output of the model by passing sensitive words.
(default: :obj:`None`)
tools (list, optional): Specifies an array of tools that the model can
stop (Union[str, List], optional): Using the stop parameter, the model
will automatically stop generating text when it is about to
include the specified string or token_id. You can use the stop
parameter to control the output of the model by passing sensitive
words. (default: :obj:`None`)
tools (List, optional): Specifies an array of tools that the model can
call. It can contain one or more tool objects. During a function
call process, the model will select one tool from the array.
(default: :obj:`None`)
extra_body (dict, optional): Additional parameters to be sent to the
extra_body (Dict, optional): Additional parameters to be sent to the
Qwen API. If you want to enable internet search, you can set this
parameter to `{"enable_search": True}`.
(default: :obj:`{"enable_search": False}`)
Expand All @@ -74,11 +75,11 @@ class QwenConfig(BaseConfig):
temperature: float = 0.3
top_p: float = 0.9
presence_penalty: float = 0.0
response_format: ClassVar[dict] = {"type": "text"}
max_tokens: Union[int, NotGiven] = NOT_GIVEN
response_format: Dict = Field(default_factory=lambda: {"type": "text"})
max_tokens: Optional[int] = None
seed: Optional[int] = None
stop: Optional[Union[str, list]] = None
extra_body: ClassVar[dict] = {"enable_search": False}
stop: Optional[Union[str, List]] = None
extra_body: Dict = Field(default_factory=lambda: {"enable_search": False})

def __init__(self, include_usage: bool = True, **kwargs):
super().__init__(**kwargs)
Expand Down
7 changes: 1 addition & 6 deletions camel/models/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,12 +117,7 @@ def run(
`ChatCompletion` in the non-stream mode, or
`Stream[ChatCompletionChunk]` in the stream mode.
"""
response_format = (
self.model_config_dict.get("response_format", None)
or response_format
)
# If tools are empty, make it None
tools = self.model_config_dict.get("tools", None) or tools or None
tools = tools or self.model_config_dict.get("tools", None)
return self._run(messages, response_format, tools)

async def arun(
Expand Down
10 changes: 8 additions & 2 deletions camel/models/openai_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,9 @@ def _run(
`ChatCompletion` in the non-stream mode, or
`Stream[ChatCompletionChunk]` in the stream mode.
"""
response_format = response_format or self.model_config_dict.get(
"response_format", None
)
if response_format:
return self._request_parse(messages, response_format, tools)
else:
Expand All @@ -177,6 +180,9 @@ async def _arun(
`ChatCompletion` in the non-stream mode, or
`AsyncStream[ChatCompletionChunk]` in the stream mode.
"""
response_format = response_format or self.model_config_dict.get(
"response_format", None
)
if response_format:
return await self._arequest_parse(messages, response_format, tools)
else:
Expand All @@ -189,7 +195,7 @@ def _request_chat_completion(
) -> Union[ChatCompletion, Stream[ChatCompletionChunk]]:
request_config = self.model_config_dict.copy()

if tools is not None:
if tools:
for tool in tools:
function_dict = tool.get('function', {})
function_dict.pop("strict", None)
Expand All @@ -208,7 +214,7 @@ async def _arequest_chat_completion(
) -> Union[ChatCompletion, AsyncStream[ChatCompletionChunk]]:
request_config = self.model_config_dict.copy()

if tools is not None:
if tools:
for tool in tools:
function_dict = tool.get('function', {})
function_dict.pop("strict", None)
Expand Down
38 changes: 35 additions & 3 deletions camel/models/qwen_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,8 +116,8 @@ async def _arun(
def _run(
self,
messages: List[OpenAIMessage],
response_format: Optional[Type[BaseModel]] = None,
tools: Optional[List[Dict[str, Any]]] = None,
response_format: Optional[Type[BaseModel]],
tools: Optional[List[Dict[str, Any]]],
) -> Union[ChatCompletion, Stream[ChatCompletionChunk]]:
r"""Runs inference of Qwen chat completion.
Expand All @@ -130,13 +130,45 @@ def _run(
`ChatCompletion` in the non-stream mode, or
`Stream[ChatCompletionChunk]` in the stream mode.
"""
request_config = self._prepare_request(
messages, response_format, tools
)

response = self._client.chat.completions.create(
messages=messages,
model=self.model_type,
**self.model_config_dict,
**request_config,
)
return response

def _prepare_request(
self,
messages: List[OpenAIMessage],
response_format: Optional[Type[BaseModel]],
tools: Optional[List[Dict[str, Any]]],
) -> Dict[str, Any]:
request_config = self.model_config_dict.copy()

if tools:
request_config["tools"] = tools

if response_format is None:
return request_config

# get all keys of the response_format
response_format_keys = response_format.model_fields.keys()
additional_prompt = (
"The response should be in JSON format with the following keys: "
f"{', '.join(response_format_keys)}."
)
user_message = messages[-1]
user_message["content"] = (
f"{user_message['content']}\n{additional_prompt}"
)

request_config["response_format"] = {"type": "json_object"}
return request_config

@property
def token_counter(self) -> BaseTokenCounter:
r"""Initialize the token counter for the model backend.
Expand Down
32 changes: 32 additions & 0 deletions examples/simple_agent.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
"""Example: Qwen agent combining tool calling with structured JSON output.

The agent answers a weather question using ``WeatherToolkit`` and the
final reply is coerced into the ``ResponseFormat`` schema via the
``response_format`` argument of ``ChatAgent.step``.
"""
from pydantic import BaseModel

from camel.agents import ChatAgent
from camel.models import ModelFactory
from camel.toolkits import WeatherToolkit
from camel.types import ModelPlatformType, ModelType


class ResponseFormat(BaseModel):
    # Field names become the JSON keys requested from the model.
    weather: str
    time: str


def main() -> None:
    """Run the structured-output weather demo against Qwen Turbo."""
    # NOTE(review): presumably requires a Qwen/DashScope API key in the
    # environment — confirm against ModelFactory's Qwen backend.
    model = ModelFactory.create(
        model_platform=ModelPlatformType.QWEN,
        model_type=ModelType.QWEN_TURBO,
    )

    agent = ChatAgent(model=model, tools=[WeatherToolkit().get_weather_data])

    resp = agent.step(
        "What's the current weather in New York?",
        response_format=ResponseFormat,
    )
    print(resp.msg.content)

    # Uncomment to reformat the previous answer into the schema instead:
    # resp = agent.step(
    #     "Format your last response.",
    #     response_format=ResponseFormat,
    # )
    # print(resp.msg.content)


if __name__ == "__main__":
    # Guarded so importing this example performs no network calls.
    main()

0 comments on commit 520c29a

Please sign in to comment.