Skip to content

Commit

Permalink
Alternate SSE endpoint
Browse files Browse the repository at this point in the history
  • Loading branch information
mbertrand committed Dec 19, 2024
1 parent 8ee0425 commit 9c6202e
Show file tree
Hide file tree
Showing 8 changed files with 158 additions and 34 deletions.
10 changes: 4 additions & 6 deletions ai_agents/agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@
from django.conf import settings
from django.core.cache import caches
from django.utils.module_loading import import_string
from llama_cloud import ChatMessage
from llama_index.agent.openai import OpenAIAgent
from llama_index.core.agent import AgentRunner
from llama_index.core.base.llms.types import ChatMessage
from llama_index.core.constants import DEFAULT_TEMPERATURE
from llama_index.core.tools import FunctionTool, ToolMetadata
from llama_index.llms.openai import OpenAI
Expand Down Expand Up @@ -335,8 +335,6 @@ def __init__(
)
self.search_parameters = []
self.search_results = []

self.agent = self.create_agent()
self.create_agent()

def search_courses(self, q: str, **kwargs) -> str:
Expand Down Expand Up @@ -411,15 +409,15 @@ def create_openai_agent(self) -> OpenAIAgent:
self.proxy.get_additional_kwargs(self) if self.proxy else {}
),
)
agent = OpenAIAgent.from_tools(
self.agent = OpenAIAgent.from_tools(
tools=self.create_tools(),
llm=llm,
verbose=True,
system_prompt=self.instructions,
)
if settings.AI_CACHE_HISTORY:
if self.save_history:
self.get_or_create_chat_history_cache()
return agent
return self.agent

def create_tools(self):
"""Create tools required by the agent"""
Expand Down
125 changes: 105 additions & 20 deletions ai_agents/consumers.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
import json
import logging

from channels.generic.http import AsyncHttpConsumer
from channels.generic.websocket import AsyncWebsocketConsumer
from channels.layers import get_channel_layer
from django.utils.text import slugify
from llama_index.core.base.llms.types import ChatMessage

from ai_agents.agents import RecommendationAgent
Expand All @@ -10,7 +13,33 @@
log = logging.getLogger(__name__)


class RecommendationAgentConsumer(AsyncWebsocketConsumer):
def process_message(message_json, agent) -> str:
    """
    Validate an incoming chat request and apply any per-request
    overrides (history reset, model, temperature, system instructions)
    to the wrapped llama_index agent.

    Args:
        message_json: Raw JSON string containing the chat request payload.
        agent: Wrapper object whose ``agent`` attribute is the underlying
            llama_index agent to mutate.

    Returns:
        The validated message text to forward to the agent.

    Raises:
        rest_framework.exceptions.ValidationError: If the payload fails
            ``ChatRequestSerializer`` validation.
    """
    text_data_json = json.loads(message_json)
    serializer = ChatRequestSerializer(data=text_data_json)
    serializer.is_valid(raise_exception=True)
    message_text = serializer.validated_data.pop("message", "")
    clear_history = serializer.validated_data.pop("clear_history", False)
    temperature = serializer.validated_data.pop("temperature", None)
    instructions = serializer.validated_data.pop("instructions", None)
    model = serializer.validated_data.pop("model", None)

    if clear_history:
        agent.agent.clear_chat_history()
    if model:
        agent.agent.agent_worker._llm.model = model  # noqa: SLF001
    # Explicit None check: a temperature of 0 is a valid (deterministic)
    # setting but is falsy, so a bare truthiness test would ignore it.
    if temperature is not None:
        agent.agent.agent_worker._llm.temperature = temperature  # noqa: SLF001
    if instructions:
        agent.agent.agent_worker.prefix_messages = [
            ChatMessage(content=instructions, role="system")
        ]
    return message_text


class RecommendationAgentWSConsumer(AsyncWebsocketConsumer):
"""
Async websocket consumer for the recommendation agent.
"""
Expand All @@ -36,25 +65,7 @@ async def receive(self, text_data: str) -> str:
"""Send the message to the AI agent and return its response."""

try:
text_data_json = json.loads(text_data)
serializer = ChatRequestSerializer(data=text_data_json)
serializer.is_valid(raise_exception=True)
message_text = serializer.validated_data.pop("message", "")
clear_history = serializer.validated_data.pop("clear_history", False)
temperature = serializer.validated_data.pop("temperature", None)
instructions = serializer.validated_data.pop("instructions", None)
model = serializer.validated_data.pop("model", None)

if clear_history:
self.agent.clear_chat_history()
if model:
self.agent.agent.agent_worker._llm.model = model # noqa: SLF001
if temperature:
self.agent.agent.agent_worker._llm.temperature = temperature # noqa: SLF001
if instructions:
self.agent.agent.agent_worker.prefix_messages = [
ChatMessage(content=instructions, role="system")
]
message_text = process_message(text_data, self.agent)

for chunk in self.agent.get_completion(message_text):
await self.send(text_data=chunk)
Expand All @@ -63,3 +74,77 @@ async def receive(self, text_data: str) -> str:
finally:
# This is a bit hacky, but it works for now
await self.send(text_data="!endResponse")


class RecommendationAgentSSEConsumer(AsyncHttpConsumer):
    """
    Async HTTP consumer that streams recommendation-agent responses to
    the client as Server-Sent Events (SSE).
    """

    async def handle(self, message: str):
        """
        Resolve a user id from the scope, join the channel group, send
        SSE headers, then stream the agent's completion as events.

        Args:
            message: Raw request body containing the chat request JSON.
        """
        user = self.scope.get("user", None)
        session = self.scope.get("session", None)

        # Prefer an authenticated username; fall back to the (slugified)
        # session key, then to a shared anonymous id.
        if user and user.username and user.username != "AnonymousUser":
            self.user_id = user.username
        elif session:
            if not session.session_key:
                # Persist the session so a key exists to identify the client.
                session.save()
            self.user_id = slugify(session.session_key)[:100]
        else:
            log.info("Anon user, no session")
            self.user_id = "Anonymous"

        agent = RecommendationAgent(self.user_id)

        self.channel_layer = get_channel_layer()
        self.room_name = "recommendation_bot"
        self.room_group_name = f"recommendation_bot_{self.user_id}"
        await self.channel_layer.group_add(self.room_group_name, self.channel_name)

        await self.send_headers(
            headers=[
                (b"Cache-Control", b"no-cache"),
                (
                    b"Content-Type",
                    b"text/event-stream",
                ),
                (
                    b"Transfer-Encoding",
                    b"chunked",
                ),
                (b"Connection", b"keep-alive"),
            ]
        )
        # Headers are only sent after the first body event, so emit an
        # initial ping. This must be a single string — the original
        # comma-separated form built a tuple, and tuple.encode() raises
        # AttributeError. "more_body" tells the interface server not to
        # finish the response yet.
        payload = "\nevent: ping\ndata: null\n\n\n"
        await self.send_body(payload.encode("utf-8"), more_body=True)

        try:
            message_text = process_message(message, agent)

            for chunk in agent.get_completion(message_text):
                await self.send_event(event=chunk)
        except Exception:
            log.exception("Error in RecommendationAgentSSEConsumer")
        finally:
            # disconnect() is a coroutine; without await it never runs.
            await self.disconnect()

    async def disconnect(self):
        """Leave the channel group joined in handle()."""
        # Must discard from the same group name used in group_add above;
        # the original used an "sse_" prefix and never left the group.
        await self.channel_layer.group_discard(
            self.room_group_name, self.channel_name
        )

    async def send_event(self, event: str):
        """Encode one agent response chunk as an SSE event and send it."""
        log.info(event)
        data = f"event: agent_response\ndata: {event}\n\n"
        await self.send_body(data.encode("utf-8"), more_body=True)

    async def http_request(self, message):
        """
        Receives an SSE request and holds the connection open
        until the client or server chooses to disconnect.
        """
        try:
            await self.handle(message.get("body"))
        finally:
            pass
2 changes: 1 addition & 1 deletion ai_agents/consumers_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def agent_user():
@pytest.fixture
def recommendation_consumer(agent_user):
"""Return a recommendation consumer."""
consumer = consumers.RecommendationAgentConsumer()
consumer = consumers.RecommendationAgentWSConsumer()
consumer.scope = {"user": agent_user}
return consumer

Expand Down
14 changes: 11 additions & 3 deletions ai_agents/routing.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,19 @@

from ai_agents import consumers

websocket_urlpatterns = [
websocket_patterns = [
# websocket URLs go here
re_path(
r"ws/recommendation_agent/",
consumers.RecommendationAgentConsumer.as_asgi(),
name="recommendation_agent",
consumers.RecommendationAgentWSConsumer.as_asgi(),
name="recommendation_agent_ws",
),
]

http_patterns = [
re_path(
r"sse/recommendation_agent/",
consumers.RecommendationAgentSSEConsumer.as_asgi(),
name="recommendation_agent_sse",
),
]
8 changes: 5 additions & 3 deletions main/asgi.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,15 @@

os.environ.setdefault("DJANGO_SETTINGS_MODULE", "main.settings")

import ai_agents.routing
django_asgi_app = get_asgi_application()

import ai_agents.routing # noqa: E402

application = ProtocolTypeRouter(
{
"http": get_asgi_application(),
"http": AuthMiddlewareStack(URLRouter(ai_agents.routing.http_patterns)),
"websocket": AuthMiddlewareStack(
URLRouter(ai_agents.routing.websocket_urlpatterns)
URLRouter(ai_agents.routing.websocket_patterns)
),
}
)
9 changes: 9 additions & 0 deletions main/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -548,6 +548,15 @@ def get_all_config_keys():
KEYCLOAK_ADMIN_SECURE = get_bool("KEYCLOAK_ADMIN_SECURE", True) # noqa: FBT003


# Django Channels layer backed by Redis; used by the chat consumers to
# fan out agent responses via channel groups.
# NOTE(review): host/port are hardcoded to ("redis", 6379) — presumably
# the docker-compose service name. Other settings in this file are read
# from the environment; consider doing the same here — TODO confirm
# deployment topology before relying on this outside docker.
CHANNEL_LAYERS = {
    "default": {
        "BACKEND": "channels_redis.core.RedisChannelLayer",
        "CONFIG": {
            "hosts": [("redis", 6379)],
        },
    },
}

# AI settings
AI_DEBUG = get_bool("AI_DEBUG", False) # noqa: FBT003
AI_CACHE = get_string(name="AI_CACHE", default="redis")
Expand Down
23 changes: 22 additions & 1 deletion poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ uvicorn = {extras = ["standard"], version = "^0.32.1"}
django-guardian = "^2.4.0"
named-enum = "^1.4.0"
ulid-py = "^0.2.0"
channels-redis = "^4.2.1"

[tool.poetry.group.dev.dependencies]
bpython = "^0.24"
Expand Down

0 comments on commit 9c6202e

Please sign in to comment.