Skip to content

Commit

Permalink
fix: consider the tool role when being in open interpreter
Browse files Browse the repository at this point in the history
It uses a new message role tool, that contains some context for using
internal tools and the data it provides. So we need to consider it when
parsing messages, or when inserting context messages

Closes: #820
  • Loading branch information
yrobla committed Jan 29, 2025
1 parent e7ab015 commit 324d287
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 8 deletions.
15 changes: 11 additions & 4 deletions src/codegate/pipeline/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

from codegate.db.models import Alert, Output, Prompt
from codegate.pipeline.secrets.manager import SecretsManager
from codegate.utils.utils import get_tool_name_from_messages

logger = structlog.get_logger("codegate")

Expand Down Expand Up @@ -260,19 +261,25 @@ def get_last_user_message_block(
messages = request["messages"]
block_start_index = None

base_tool = get_tool_name_from_messages(request)
accepted_roles = ["user", "assistant"]
if base_tool == "open interpreter":
# open interpreter also uses the role "tool"
accepted_roles.append("tool")

# Iterate in reverse to find the last block of consecutive 'user' messages
for i in reversed(range(len(messages))):
if messages[i]["role"] == "user" or messages[i]["role"] == "assistant":
content_str = messages[i].get("content")
if messages[i]["role"] in accepted_roles:
content_str = messages[i].get("content")
if content_str is None:
continue

if messages[i]["role"] == "user":
if messages[i]["role"] in ["user", "tool"]:
user_messages.append(content_str)
block_start_index = i

# Specifically for Aider, when "Ok." block is found, stop
if content_str == "Ok." and messages[i]["role"] == "assistant":
if base_tool == "aider" and content_str == "Ok." and messages[i]["role"] == "assistant":
break
else:
# Stop when a message with a different role is encountered
Expand Down
14 changes: 11 additions & 3 deletions src/codegate/pipeline/codegate_context_retriever/codegate.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,9 +100,9 @@ async def process(
)

# split messages into double newlines, to avoid passing so many content in the search
split_messages = re.split(r"</?task>|(\n\n)", user_messages)
split_messages = re.split(r"</?task>|\n|\\n", user_messages)
collected_bad_packages = []
for item_message in split_messages:
for item_message in filter(None, map(str.strip, split_messages)):
# Vector search to find bad packages
bad_packages = await storage_engine.search(query=item_message, distance=0.5, limit=100)
if bad_packages and len(bad_packages) > 0:
Expand All @@ -128,12 +128,12 @@ async def process(
new_request = request.copy()

# perform replacement in all the messages starting from this index
base_tool = get_tool_name_from_messages(request)
for i in range(last_user_idx, len(new_request["messages"])):
message = new_request["messages"][i]
message_str = str(message["content"]) # type: ignore
context_msg = message_str
# Add the context to the last user message
base_tool = get_tool_name_from_messages(request)
if base_tool in ["cline", "kodu"]:
match = re.search(r"<task>\s*(.*?)\s*</task>(.*)", message_str, re.DOTALL)
if match:
Expand All @@ -149,6 +149,14 @@ async def process(
# Combine updated task content with the rest of the message
context_msg = updated_task_content + rest_of_message

elif base_tool == "open interpreter":
# if we find the context in a "tool" role, move it to the previous message
context_msg = f"Context: {context_str} \n\n Query: {message_str}" # type: ignore
if message["role"] == "tool":
if i > 0:
message_str = str(new_request["messages"][i-1]["content"]) # type: ignore
context_msg = f"Context: {context_str} \n\n Query: {message_str}" # type: ignore
new_request["messages"][i-1]["content"] = context_msg
else:
context_msg = f"Context: {context_str} \n\n Query: {message_str}" # type: ignore

Expand Down
2 changes: 1 addition & 1 deletion src/codegate/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def get_tool_name_from_messages(data):
Returns:
str: The name of the tool found in the messages, or None if no match is found.
"""
tools = ["Cline", "Kodu"]
tools = ["Cline", "Kodu", "Open Interpreter"]
for message in data.get("messages", []):
message_content = str(message.get("content", ""))
for tool in tools:
Expand Down

0 comments on commit 324d287

Please sign in to comment.