Skip to content

Commit

Permalink
Reintroduce cached history
Browse files Browse the repository at this point in the history
  • Loading branch information
mbertrand committed Dec 19, 2024
1 parent 863835e commit 8ee0425
Show file tree
Hide file tree
Showing 4 changed files with 65 additions and 4 deletions.
46 changes: 45 additions & 1 deletion ai_agents/agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@
import pydantic
import requests
from django.conf import settings
from django.core.cache import caches
from django.utils.module_loading import import_string
from llama_cloud import ChatMessage
from llama_index.agent.openai import OpenAIAgent
from llama_index.core.agent import AgentRunner
from llama_index.core.constants import DEFAULT_TEMPERATURE
Expand Down Expand Up @@ -45,6 +47,7 @@ class BaseChatAgent(ABC):
# For LiteLLM tracking purposes
JOB_ID = "BASECHAT_JOB"
TASK_NAME = "BASECHAT_TASK"
CACHE_PREFIX = "base_ai_"

def __init__(
self,
Expand All @@ -67,6 +70,39 @@ def __init__(
else:
self.proxy = None
self.agent = None
self.save_history = settings.AI_CACHE_HISTORY and self.user_id
if self.save_history:
self.cache = caches[settings.AI_CACHE]
self.cache_timeout = settings.AI_CACHE_TIMEOUT
self.cache_key = f"{self.CACHE_PREFIX}{self.user_id}"

def get_or_create_chat_history_cache(self) -> None:
"""
Get the user chat history from the cache and load it into the
llamaindex agent's chat history (agent.chat_history).
Create an empty cache key if it doesn't exist.
"""
if self.cache_key in self.cache:
try:
for message in json.loads(self.cache.get(self.cache_key)):
self.agent.chat_history.append(ChatMessage(**message))
except json.JSONDecodeError:
self.cache.set(self.cache_key, "[]", timeout=self.cache_timeout)
else:
if self.proxy:
self.proxy.create_proxy_user(self.user_id)
self.cache.set(self.cache_key, "[]", timeout=self.cache_timeout)

def save_chat_history(self) -> None:
"""Save the agent chat history to the cache"""
chat_history = [
message.dict()
for message in self.agent.chat_history
if message.role != "tool" and message.content
]
self.cache.set(
self.cache_key, json.dumps(chat_history), timeout=settings.AI_CACHE_TIMEOUT
)

def create_agent(self) -> AgentRunner:
"""Create an AgentRunner for the relevant AI source"""
Expand All @@ -87,6 +123,9 @@ def create_openai_agent(self) -> OpenAIAgent:
def clear_chat_history(self) -> None:
"""Clear the chat history from the cache"""
self.agent.chat_history.clear()
if self.save_history:
self.cache.delete(self.cache_key)
self.get_or_create_chat_history_cache()

@abstractmethod
def get_comment_metadata(self):
Expand Down Expand Up @@ -127,6 +166,8 @@ def get_completion(self, message: str, *, debug: bool = settings.AI_DEBUG) -> st
log.exception("Error running AI agent")
if debug:
yield f"\n\n<!-- {self.get_comment_metadata()} -->\n\n"
if self.save_history:
self.save_chat_history()


class RecommendationAgent(BaseChatAgent):
Expand Down Expand Up @@ -370,12 +411,15 @@ def create_openai_agent(self) -> OpenAIAgent:
self.proxy.get_additional_kwargs(self) if self.proxy else {}
),
)
return OpenAIAgent.from_tools(
agent = OpenAIAgent.from_tools(
tools=self.create_tools(),
llm=llm,
verbose=True,
system_prompt=self.instructions,
)
if settings.AI_CACHE_HISTORY:
self.get_or_create_chat_history_cache()
return agent

def create_tools(self):
"""Create tools required by the agent"""
Expand Down
8 changes: 8 additions & 0 deletions ai_agents/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,14 @@
from ai_agents.factories import ChatMessageFactory


@pytest.fixture(autouse=True)
def ai_settings(settings):
"""Assign default AI settings"""
settings.AI_PROXY = None
settings.AI_PROXY_URL = None
return settings


@pytest.fixture
def chat_history():
"""Return one round trip chat history for testing."""
Expand Down
12 changes: 10 additions & 2 deletions ai_agents/consumers.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,16 @@ class RecommendationAgentConsumer(AsyncWebsocketConsumer):
async def connect(self):
"""Connect to the websocket and initialize the AI agent."""
user = self.scope.get("user", None)
self.user_id = user.username if user else "anonymous"
log.info("Username is %s", self.user_id)
session = self.scope.get("session", None)

if user and user.username:
self.user_id = user.username
elif session:
if not session.session_key:
session.save()
self.user_id = session.session_key
else:
self.user_id = None

self.agent = RecommendationAgent(self.user_id)
await super().connect()
Expand Down
3 changes: 2 additions & 1 deletion main/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -550,8 +550,9 @@ def get_all_config_keys():

# AI settings
AI_DEBUG = get_bool("AI_DEBUG", False) # noqa: FBT003
AI_CACHE_TIMEOUT = get_int(name="AI_CACHE_TIMEOUT", default=3600)
AI_CACHE = get_string(name="AI_CACHE", default="redis")
AI_CACHE_HISTORY = get_bool(name="AI_CACHE_HISTORY", default=True)
AI_CACHE_TIMEOUT = get_int(name="AI_CACHE_TIMEOUT", default=3600)
AI_MIT_SEARCH_URL = get_string(
name="AI_MIT_SEARCH_URL",
default="https://api.learn.mit.edu/api/v1/learning_resources_search/",
Expand Down

0 comments on commit 8ee0425

Please sign in to comment.