unable to pass messages as history
luv-singh-ai committed Jan 11, 2025
1 parent 46a8c56 commit 25241b9
Showing 4 changed files with 185 additions and 121 deletions.
Binary file modified db/conversations.db
Binary file not shown.
174 changes: 89 additions & 85 deletions new.py
@@ -49,7 +49,7 @@ def load_prompts(file_path='prompts.yaml'):
# model = ChatAnthropic(model="claude-3-5-haiku-20241022", api_key=anthropic_api_key, max_tokens=300, max_retries=2, temperature=0.7)

# in_memory_store = InMemoryStore()
store = InMemoryStore()
# store = InMemoryStore()

# Define the state
class State(TypedDict):
@@ -58,54 +58,54 @@ class State(TypedDict):
# Define the conversational prompt
conversational_prompt = ChatPromptTemplate.from_messages([
("system", prompts['MYCA']),
("system", "Context from past conversations:\n{memory_context}"),
# ("system", "Context from past conversations:\n{memory_context}"),
("human", "{input}"),
])

class ConversationalAgent:
def __init__(self, model, store: BaseStore, max_memories: int = 10):
def __init__(self, model): # , store: BaseStore, max_memories: int = 10
# self.prompt_template = prompt_template
self.model = model
self.store = store
self.max_memories = max_memories
# self.store = store
# self.max_memories = max_memories
# Cache the memory prompt template for reuse
self.conversational_prompt = ChatPromptTemplate.from_messages([
("system", prompts['MYCA']),
("system", "Context from past conversations:\n{memory_context}"),
# ("system", "Context from past conversations:\n{memory_context}"),
("human", "{input}")
])

# Initialize SQLite database
self._init_database()
# self._init_database()

def _init_database(self):
"""Initialize SQLite database with required table"""
conn = sqlite3.connect('db/conversations.db')
cursor = conn.cursor()
cursor.execute('''
CREATE TABLE IF NOT EXISTS conversations (
id INTEGER PRIMARY KEY AUTOINCREMENT,
namespace TEXT,
message TEXT,
timestamp DATETIME DEFAULT CURRENT_TIMESTAMP
)
''')
conn.commit()
conn.close()
# def _init_database(self):
# """Initialize SQLite database with required table"""
# conn = sqlite3.connect('db/conversations.db')
# cursor = conn.cursor()
# cursor.execute('''
# CREATE TABLE IF NOT EXISTS conversations (
# id INTEGER PRIMARY KEY AUTOINCREMENT,
# namespace TEXT,
# message TEXT,
# timestamp DATETIME DEFAULT CURRENT_TIMESTAMP
# )
# ''')
# conn.commit()
# conn.close()

def fetch_conversations_from_sqlite(self, namespace):
# Implement the logic to fetch past conversations from SQLite
# For example, retrieve the last N messages for the user
try:
conn = sqlite3.connect('db/conversations.db')
cursor = conn.cursor()
cursor.execute("SELECT message FROM conversations WHERE namespace=? ORDER BY timestamp DESC LIMIT ?", (str(namespace),self.max_memories))
rows = cursor.fetchall()
conn.close()
return [row[0] for row in rows]
except Exception as e:
logging.error(f"Error fetching conversations: {str(e)}")
return []
# def fetch_conversations_from_sqlite(self, namespace):
# # Implement the logic to fetch past conversations from SQLite
# # For example, retrieve the last N messages for the user
# try:
# conn = sqlite3.connect('db/conversations.db')
# cursor = conn.cursor()
# cursor.execute("SELECT message FROM conversations WHERE namespace=? ORDER BY timestamp DESC LIMIT ?", (str(namespace),self.max_memories))
# rows = cursor.fetchall()
# conn.close()
# return [row[0] for row in rows]
# except Exception as e:
# logging.error(f"Error fetching conversations: {str(e)}")
# return []

def summarize_conversations(self, conversations):
try:
@@ -122,68 +122,68 @@ def summarize_conversations(self, conversations):
return ""

# Fixed method to handle different message types
def store_conversation_in_sqlite(self, namespace, messages):
try:
conn = sqlite3.connect('db/conversations.db')
cursor = conn.cursor()
for message in messages:
# Handle different message types
content = message.content if hasattr(message, 'content') else str(message)
cursor.execute(
"INSERT INTO conversations (namespace, message, timestamp) VALUES (?, ?, ?)",
(str(namespace), content, datetime.now())
)
conn.commit()
conn.close()
# def store_conversation_in_sqlite(self, namespace, messages):
# try:
# conn = sqlite3.connect('db/conversations.db')
# cursor = conn.cursor()
# for message in messages:
# # Handle different message types
# content = message.content if hasattr(message, 'content') else str(message)
# cursor.execute(
# "INSERT INTO conversations (namespace, message, timestamp) VALUES (?, ?, ?)",
# (str(namespace), content, datetime.now())
# )
# conn.commit()
# conn.close()

# Cleanup old messages
self._cleanup_old_messages(namespace)
except Exception as e:
logging.error(f"Error storing conversation: {str(e)}")
# # Cleanup old messages
# self._cleanup_old_messages(namespace)
# except Exception as e:
# logging.error(f"Error storing conversation: {str(e)}")

# Added new method to cleanup old messages
def _cleanup_old_messages(self, namespace):
"""Keep only the latest max_memories messages"""
try:
conn = sqlite3.connect('db/conversations.db')
cursor = conn.cursor()
cursor.execute("""
DELETE FROM conversations
WHERE namespace = ?
AND id NOT IN (
SELECT id FROM conversations
WHERE namespace = ?
ORDER BY timestamp DESC
LIMIT ?
)
""", (str(namespace), str(namespace), self.max_memories))
conn.commit()
conn.close()
except Exception as e:
logging.error(f"Error cleaning up messages: {str(e)}")
# def _cleanup_old_messages(self, namespace):
# """Keep only the latest max_memories messages"""
# try:
# conn = sqlite3.connect('db/conversations.db')
# cursor = conn.cursor()
# cursor.execute("""
# DELETE FROM conversations
# WHERE namespace = ?
# AND id NOT IN (
# SELECT id FROM conversations
# WHERE namespace = ?
# ORDER BY timestamp DESC
# LIMIT ?
# )
# """, (str(namespace), str(namespace), self.max_memories))
# conn.commit()
# conn.close()
# except Exception as e:
# logging.error(f"Error cleaning up messages: {str(e)}")

# Updated main method to use class methods and handle streaming
def run_conversational_agent(self, state: State):
try:
# Get user ID from state config
user_id = state.get("configurable", {}).get("user_id", "default")
namespace = f"memories_user_{user_id}"
# user_id = state.get("configurable", {}).get("user_id", "default")
# namespace = f"memories_user_{user_id}"

# Fetch and summarize past conversations
past_conversations = self.fetch_conversations_from_sqlite(namespace)
summary = self.summarize_conversations(past_conversations)
# # Fetch and summarize past conversations
# past_conversations = self.fetch_conversations_from_sqlite(namespace)
# summary = self.summarize_conversations(past_conversations)

# Format messages with context
formatted_messages = self.conversational_prompt.format_messages(
memory_context=summary,
input=state["messages"][-1].content
)
# formatted_messages = self.conversational_prompt.format_messages(
# memory_context=summary,
# input=state["messages"][-1].content
# )

# Get streaming response from model
response = self.model.invoke(formatted_messages)

# response = self.model.invoke(formatted_messages)
response = self.model.invoke(state["messages"][-1].content)
# Store the conversation
self.store_conversation_in_sqlite(namespace, state["messages"] + [response])
# self.store_conversation_in_sqlite(namespace, state["messages"] + [response])

return {"messages": state["messages"] + [response]}

Expand All @@ -192,7 +192,8 @@ def run_conversational_agent(self, state: State):
return {"messages": [AIMessage(content="I'm here to help. Could you please rephrase that?")]}

# Instantiate the conversational agent with tools
conversational_agent = ConversationalAgent(model, store, max_memories=10) # conversational_prompt, llm_with_tools
# conversational_agent = ConversationalAgent(model, store, max_memories=10) # conversational_prompt, llm_with_tools
conversational_agent = ConversationalAgent(model)
# agent = ConversationalAgent(prompt_template=conversational_prompt,model=model,store=store,max_memories=10)

# # Define the router function
@@ -213,7 +214,8 @@ def run_conversational_agent(self, state: State):

# Compile the graph
memory = MemorySaver()
graph = workflow.compile(checkpointer=memory, store=store)
graph = workflow.compile(checkpointer=memory)
# graph = workflow.compile(checkpointer=memory, store=store)

# Function to run a conversation turn
def chat(message: str, config: dict, history: List):
@@ -224,9 +226,11 @@ def chat(message: str, config: dict, history: List):
for msg in history:
# Add user message if it exists and is not empty
if msg.get("user"):
logging.info(f"Adding user message to history: {msg['user']}")
messages.append(HumanMessage(content=msg["user"]))
# Add AI response if it exists and is not empty
if msg.get("response"):
logging.info(f"Adding AI response to history: {msg['response']}")
messages.append(AIMessage(content=msg["response"]))

# Add current message
@@ -235,7 +239,7 @@ def chat(message: str, config: dict, history: List):

# Add current message
messages.append(HumanMessage(content=message))

logging.info(f"\n current FULL MESSAGE is: \n{messages}\n")
# Invoke the model with messages and config
try:
result = graph.invoke({"messages": messages}, config=config)
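For reference, a minimal sketch of how the revised chat() helper is meant to be driven after this change: prior turns arrive as a list of dicts with "user" and "response" keys, chat() rebuilds them into HumanMessage/AIMessage objects, appends the current message, and invokes the graph. The import path, sample history, and thread_id value below are illustrative assumptions, not part of the commit.

```python
# Illustrative driver for the revised chat() helper; the import path and
# config values are placeholders, not taken from this commit.
from new import chat  # new.py assumed importable as module "new"

history = [
    {"user": "Hi, I've been feeling anxious lately.",
     "response": "I'm here with you. What's been weighing on you?"},
    {"user": "Mostly work pressure.",
     "response": "That sounds heavy. Want to talk it through?"},
]

# LangGraph checkpointer config; thread_id is a placeholder value.
config = {"configurable": {"thread_id": "demo-thread-1"}}

reply = chat("Can you remind me what we talked about?", config, history)
print(reply.content)  # chat() returns a message object with .content (see sukoon_api.py)
```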
1 change: 1 addition & 0 deletions sukoon_api.py
@@ -89,6 +89,7 @@ async def process_query(request: MYCARequest, supabase: SupabaseManager = Depend

# Process chat. If chat() is blocking, consider running it in a threadpool using run_in_executor.
history = supabase.get_chat_history(mobile=mobile)
logger.info("Retrieved chat history for mobile %s: %s", mobile, history)
response = chat(user_input, config, history)
chat_response = response.content

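The inline comment above notes that chat() is blocking and suggests running it in a threadpool via run_in_executor. A minimal sketch of that idea, assuming user_input, config, and history are in scope as in process_query; the helper name and import path are illustrative, not part of the commit.

```python
import asyncio
from functools import partial

from new import chat  # import path assumed; chat() is the blocking helper above

async def run_chat_in_threadpool(user_input: str, config: dict, history: list):
    """Run the blocking chat() call without stalling the event loop (sketch)."""
    loop = asyncio.get_running_loop()
    # None selects the default ThreadPoolExecutor; partial binds the arguments.
    response = await loop.run_in_executor(
        None, partial(chat, user_input, config, history)
    )
    return response.content
```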