From 324d287d8d0c0d941fa17f800a48a6f5c1ac22b5 Mon Sep 17 00:00:00 2001 From: Yolanda Robla Date: Wed, 29 Jan 2025 16:08:56 +0100 Subject: [PATCH] fix: consider the tool role when being in open interpreter Open Interpreter uses a new message role, "tool", which carries context for using internal tools and the data they provide. We need to take this role into account both when parsing messages and when inserting context messages. Closes: #820 --- src/codegate/pipeline/base.py | 15 +++++++++++---- .../codegate_context_retriever/codegate.py | 14 +++++++++++--- src/codegate/utils/utils.py | 2 +- 3 files changed, 23 insertions(+), 8 deletions(-) diff --git a/src/codegate/pipeline/base.py b/src/codegate/pipeline/base.py index f0e13196..82e1e47f 100644 --- a/src/codegate/pipeline/base.py +++ b/src/codegate/pipeline/base.py @@ -13,6 +13,7 @@ from codegate.db.models import Alert, Output, Prompt from codegate.pipeline.secrets.manager import SecretsManager +from codegate.utils.utils import get_tool_name_from_messages logger = structlog.get_logger("codegate") @@ -260,19 +261,25 @@ def get_last_user_message_block( messages = request["messages"] block_start_index = None + base_tool = get_tool_name_from_messages(request) + accepted_roles = ["user", "assistant"] + if base_tool == "open interpreter": + # open interpreter also uses the role "tool" + accepted_roles.append("tool") + # Iterate in reverse to find the last block of consecutive 'user' messages for i in reversed(range(len(messages))): - if messages[i]["role"] == "user" or messages[i]["role"] == "assistant": - content_str = messages[i].get("content") + if messages[i]["role"] in accepted_roles: + content_str = messages[i].get("content") if content_str is None: continue - if messages[i]["role"] == "user": + if messages[i]["role"] in ["user", "tool"]: user_messages.append(content_str) block_start_index = i # Specifically for Aider, when "Ok." block is found, stop - if content_str == "Ok." 
and messages[i]["role"] == "assistant": + if base_tool == "aider" and content_str == "Ok." and messages[i]["role"] == "assistant": break else: # Stop when a message with a different role is encountered diff --git a/src/codegate/pipeline/codegate_context_retriever/codegate.py b/src/codegate/pipeline/codegate_context_retriever/codegate.py index 27ebdd41..8aaadfbc 100644 --- a/src/codegate/pipeline/codegate_context_retriever/codegate.py +++ b/src/codegate/pipeline/codegate_context_retriever/codegate.py @@ -100,9 +100,9 @@ async def process( ) # split messages into double newlines, to avoid passing so many content in the search - split_messages = re.split(r"|(\n\n)", user_messages) + split_messages = re.split(r"|\n|\\n", user_messages) collected_bad_packages = [] - for item_message in split_messages: + for item_message in filter(None, map(str.strip, split_messages)): # Vector search to find bad packages bad_packages = await storage_engine.search(query=item_message, distance=0.5, limit=100) if bad_packages and len(bad_packages) > 0: @@ -128,12 +128,12 @@ async def process( new_request = request.copy() # perform replacement in all the messages starting from this index + base_tool = get_tool_name_from_messages(request) for i in range(last_user_idx, len(new_request["messages"])): message = new_request["messages"][i] message_str = str(message["content"]) # type: ignore context_msg = message_str # Add the context to the last user message - base_tool = get_tool_name_from_messages(request) if base_tool in ["cline", "kodu"]: match = re.search(r"\s*(.*?)\s*(.*)", message_str, re.DOTALL) if match: @@ -149,6 +149,14 @@ async def process( # Combine updated task content with the rest of the message context_msg = updated_task_content + rest_of_message + elif base_tool == "open interpreter": + # if we find the context in a "tool" role, move it to the previous message + context_msg = f"Context: {context_str} \n\n Query: {message_str}" # type: ignore + if message["role"] == "tool": + if 
i > 0: + message_str = str(new_request["messages"][i-1]["content"]) # type: ignore + context_msg = f"Context: {context_str} \n\n Query: {message_str}" # type: ignore + new_request["messages"][i-1]["content"] = context_msg else: context_msg = f"Context: {context_str} \n\n Query: {message_str}" # type: ignore diff --git a/src/codegate/utils/utils.py b/src/codegate/utils/utils.py index 51b3f931..38d2d77a 100644 --- a/src/codegate/utils/utils.py +++ b/src/codegate/utils/utils.py @@ -45,7 +45,7 @@ def get_tool_name_from_messages(data): Returns: str: The name of the tool found in the messages, or None if no match is found. """ - tools = ["Cline", "Kodu"] + tools = ["Cline", "Kodu", "Open Interpreter"] for message in data.get("messages", []): message_content = str(message.get("content", "")) for tool in tools: