Skip to content

Commit

Permalink
update reducer & fix a few small bugs
Browse files Browse the repository at this point in the history
  • Loading branch information
vbarda committed Jul 29, 2024
1 parent 7580724 commit 2f49c5f
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 10 deletions.
29 changes: 27 additions & 2 deletions backend/app/graphs/new_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
)
from langchain_core.runnables import RunnableConfig
from langgraph.graph import END, StateGraph
from langgraph.graph.message import add_messages
from langgraph.graph.message import Messages, add_messages
from langgraph.managed.few_shot import FewShotExamples
from langgraph.prebuilt import ToolNode

Expand All @@ -26,8 +26,33 @@ def filter_by_assistant_id(config: RunnableConfig) -> Dict[str, Any]:
return {}


def custom_add_messages(left: Messages, right: Messages):
    """Reducer: merge two message lists, then flatten AIMessage content to text.

    Delegates the actual merge/dedup logic to langgraph's ``add_messages``,
    then post-processes the result so that any ``AIMessage`` whose content is
    a list of content blocks (as produced by streaming ``AIMessageChunk``
    aggregation) ends up with a plain-string content containing only the
    concatenated ``"text"`` blocks.

    Args:
        left: Existing messages (any shape accepted by ``add_messages``).
        right: Incoming messages to merge in.

    Returns:
        The combined message list, with AIMessage list-content collapsed to a
        single string where text blocks were present.
    """
    combined_messages = add_messages(left, right)
    for message in combined_messages:
        # TODO: handle this correctly in ChatAnthropic.
        # this is needed to handle content blocks in AIMessageChunk when using
        # streaming with the graph. if we don't have that, all of the AIMessages
        # will have list of dicts in the content
        if isinstance(message, AIMessage) and isinstance(message.content, list):
            text_content_blocks = [
                content_block
                for content_block in message.content
                # Content lists may contain plain strings as well as dicts,
                # and dicts are not guaranteed to carry a "type" key.
                # Guard before indexing to avoid TypeError/KeyError.
                if isinstance(content_block, dict)
                and content_block.get("type") == "text"
            ]
            if text_content_blocks:
                message.content = "".join(
                    content_block["text"] for content_block in text_content_blocks
                )

    return combined_messages


class BaseState(TypedDict):
messages: Annotated[list[AnyMessage], add_messages]
messages: Annotated[list[AnyMessage], custom_add_messages]
examples: Annotated[
list, FewShotExamples.configure(metadata_filter=filter_by_assistant_id)
]
Expand Down
8 changes: 2 additions & 6 deletions backend/app/graphs/new_rag.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,9 +65,7 @@ def _get_messages(messages, system_message=DEFAULT_SYSTEM_MESSAGE):
@chain
async def get_search_query(messages: Sequence[BaseMessage], config):
llm = get_llm(
config["configurable"].get(
"agent==chat_retrieval/llm_type", LLMType.GPT_35_TURBO
)
config["configurable"].get("type==chat_retrieval/llm_type", LLMType.GPT_4O_MINI)
)
convo = []
for m in messages:
Expand Down Expand Up @@ -137,9 +135,7 @@ async def retrieve(state: AgentState, config):
def call_model(state: AgentState, config):
messages = state["messages"]
llm = get_llm(
config["configurable"].get(
"agent==chat_retrieval/llm_type", LLMType.GPT_4O_MINI
)
config["configurable"].get("type==chat_retrieval/llm_type", LLMType.GPT_4O_MINI)
)
response = llm.invoke(
_get_messages(
Expand Down
5 changes: 3 additions & 2 deletions backend/tests/unit_tests/app/helpers.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from contextlib import asynccontextmanager

from app.lifespan import lifespan
from httpx import AsyncClient, ASGITransport
from httpx import ASGITransport, AsyncClient
from typing_extensions import AsyncGenerator

from app.lifespan import lifespan


@asynccontextmanager
async def get_client() -> AsyncGenerator[AsyncClient, None]:
Expand Down

0 comments on commit 2f49c5f

Please sign in to comment.