From 8e09567d1610bda71b8d8c6374ce83f82ff78743 Mon Sep 17 00:00:00 2001 From: Michal Warda Date: Fri, 23 Feb 2024 21:01:13 +0100 Subject: [PATCH] Handle chat! through non stream! --- apps/api/lib/buildel/blocks/chat.ex | 146 +++++++++--------- .../chat_completion_message_formatter.ex | 54 +++++-- apps/api/lib/buildel/langchain/usage.ex | 15 +- .../chat_completion_controller.ex | 19 ++- 4 files changed, 139 insertions(+), 95 deletions(-) diff --git a/apps/api/lib/buildel/blocks/chat.ex b/apps/api/lib/buildel/blocks/chat.ex index 4e3a16033..4527ea996 100644 --- a/apps/api/lib/buildel/blocks/chat.ex +++ b/apps/api/lib/buildel/blocks/chat.ex @@ -1,5 +1,6 @@ defmodule Buildel.Blocks.Chat do require Logger + alias Buildel.Langchain.TokenUsage alias Buildel.Blocks.Utils.ChatMemory use Buildel.Blocks.Block use Buildel.Blocks.Utils.TakeLatest @@ -185,6 +186,14 @@ defmodule Buildel.Blocks.Chat do GenServer.cast(pid, {:save_tool_result, tool_name, content}) end + defp save_usage(pid, usage) do + GenServer.cast(pid, {:save_usage, usage}) + end + + defp usage(pid) do + GenServer.call(pid, {:usage}) + end + # Server @impl true @@ -225,7 +234,8 @@ defmodule Buildel.Blocks.Chat do initial_messages: initial_messages(state), type: memory_type }) - )} + ) + |> Map.put(:usage, Buildel.Langchain.TokenUsage.new!())} end @impl true @@ -304,6 +314,11 @@ defmodule Buildel.Blocks.Chat do {:noreply, state |> send_stream_stop()} end + @impl true + def handle_cast({:save_usage, usage}, state) do + {:noreply, %{state | usage: TokenUsage.add(state.usage, usage)}} + end + @impl true def handle_call({:function, %{block_name: block_name}}, _, state) do pid = self() @@ -367,10 +382,13 @@ defmodule Buildel.Blocks.Chat do @impl true def handle_call( - {:chat_completion, %{messages: messages, model: _model, stream: true, stream_to: pid}}, + {:chat_completion, + %{messages: messages, model: _model, stream: true, stream_to: stream_to}}, _from, state ) do + pid = self() + tools = state[:tool_connections] |> Enum.map(fn connection -> @@ -381,87 +399,65 @@ defmodule Buildel.Blocks.Chat do completion_id = "chatcmpl-#{:crypto.strong_rand_bytes(32) |> Base.encode64()}" Task.start(fn -> - chat().stream_chat(%{ - context: %{messages: messages}, - on_message: fn - %LangChain.MessageDelta{} = message -> - message = - Buildel.Blocks.Utils.ChatCompletionMessageFormatter.format_message_delta( - message, - completion_id, - state[:opts].model - ) - - send(pid, {:chat_completion, message}) + {:ok, _chain, last_message} = + chat().stream_chat(%{ + context: %{messages: messages}, + on_message: fn + %LangChain.MessageDelta{} = message -> + message = + Buildel.Blocks.Utils.ChatCompletionMessageFormatter.format_message_delta( + message, + completion_id, + state[:opts].model + ) + + send(stream_to, {:chat_completion, message}) + + %Buildel.Langchain.TokenUsage{} = usage -> + save_usage(pid, usage) + + _ -> + nil + end, + on_content: fn _content -> nil end, + on_tool_call: fn _tool_name, _arguments, _message -> nil end, + on_tool_content: fn _tool_name, _content, _message -> nil end, + on_cost: fn token_summary -> + chat_cost = Buildel.Costs.CostCalculator.calculate_chat_cost(token_summary) + + block_context().create_run_cost( + state[:context_id], + state[:block_name], + chat_cost + ) + end, + on_end: fn -> nil end, + on_error: fn _ -> nil end, + api_key: state[:api_key], + model: state[:opts].model, + temperature: state[:opts].temperature, + tools: tools, + endpoint: state[:opts].endpoint, + api_type: state[:opts].api_type + }) - %LangChain.Message{} = message -> - message = - Buildel.Blocks.Utils.ChatCompletionMessageFormatter.format_message( - message, - completion_id, - state[:opts].model - ) - - send(pid, {:chat_end, message}) + message = + Buildel.Blocks.Utils.ChatCompletionMessageFormatter.format_message( + last_message, + completion_id, + state[:opts].model, + usage(pid) + ) - _ -> - nil - end, - on_content: fn _content -> nil end, - on_tool_call: fn _tool_name, _arguments, _message -> nil end, - on_tool_content: fn _tool_name, _content, _message -> nil end, - on_cost: fn _token_summary -> nil end, - on_end: fn -> nil end, - on_error: fn _ -> nil end, - api_key: state[:api_key], - model: state[:opts].model, - temperature: state[:opts].temperature, - tools: tools, - endpoint: state[:opts].endpoint, - api_type: state[:opts].api_type - }) + send(stream_to, {:chat_end, message}) end) {:reply, {:ok, "streaming"}, state} end @impl true - def handle_call({:chat_completion, %{messages: messages, model: _model}}, _from, state) do - tools = - state[:tool_connections] - |> Enum.map(fn connection -> - pid = block_context().block_pid(state[:context_id], connection.from.block_name) - Buildel.Blocks.Block.function(pid, %{block_name: state.block_name}) - end) - - completion_id = "chatcmpl-#{:crypto.strong_rand_bytes(32) |> Base.encode64()}" - - with {:ok, _chain, message} = - chat().stream_chat(%{ - context: %{messages: messages}, - on_message: fn _message -> nil end, - on_content: fn _content -> nil end, - on_tool_call: fn _tool_name, _arguments, _message -> nil end, - on_tool_content: fn _tool_name, _content, _message -> nil end, - on_cost: fn _token_summary -> nil end, - on_end: fn -> nil end, - on_error: fn _ -> nil end, - api_key: state[:api_key], - model: state[:opts].model, - temperature: state[:opts].temperature, - tools: tools, - endpoint: state[:opts].endpoint, - api_type: state[:opts].api_type - }) do - message = - Buildel.Blocks.Utils.ChatCompletionMessageFormatter.format_message( - message, - completion_id, - state[:opts].model - ) - - {:reply, {:ok, message}, state} - end + def handle_call({:usage}, _from, state) do + {:reply, state.usage, state} end @impl true diff --git a/apps/api/lib/buildel/blocks/utils/chat_completion_message_formatter.ex b/apps/api/lib/buildel/blocks/utils/chat_completion_message_formatter.ex index 7205140bb..8fba14a81 100644 --- a/apps/api/lib/buildel/blocks/utils/chat_completion_message_formatter.ex +++ b/apps/api/lib/buildel/blocks/utils/chat_completion_message_formatter.ex @@ -4,13 +4,43 @@ defmodule Buildel.Blocks.Utils.ChatCompletionMessageFormatter do delta = case reason do - "stop" -> %{} - _ -> %{"content" => message_delta.content} + "stop" -> + %{ + "role" => "assistant" + } + + _ -> + %{"content" => message_delta.content} + end + + delta = + case message_delta.function_name do + nil -> + delta + + function_name -> + %{ + "content" => "Calling: #{function_name} " + } + end + + delta = + case message_delta.arguments do + nil -> + delta + + "" -> + delta + + arguments -> + %{ + "content" => arguments + } end choices = [ %{ - "finish_reason" => finish_reason(message_delta), + "finish_reason" => reason, "index" => message_delta.index, "delta" => delta, "logprobs" => nil @@ -26,7 +56,7 @@ defmodule Buildel.Blocks.Utils.ChatCompletionMessageFormatter do } end - def format_message(message, completion_id, model) do + def format_message(message, completion_id, model, usage) do choices = [ %{ "finish_reason" => finish_reason(message), @@ -45,12 +75,7 @@ defmodule Buildel.Blocks.Utils.ChatCompletionMessageFormatter do "id" => completion_id, "model" => model, "object" => "chat.completion", - # TODO: Add usage - "usage" => %{ - "completion_tokens" => 17, - "prompt_tokens" => 57, - "total_tokens" => 74 - } + "usage" => usage(usage) } end @@ -72,7 +97,16 @@ defmodule Buildel.Blocks.Utils.ChatCompletionMessageFormatter do defp role(message) do case message.role do :function_call -> "function_call" + :assistant -> "assistant" _ -> "other" end end + + defp usage(%Buildel.Langchain.TokenUsage{} = usage) do + %{ + "completion_tokens" => usage.completion_tokens, + "prompt_tokens" => usage.prompt_tokens, + "total_tokens" => usage.total_tokens + } + end end diff --git a/apps/api/lib/buildel/langchain/usage.ex b/apps/api/lib/buildel/langchain/usage.ex index fbfc3f504..681dfc7b4 100644 --- a/apps/api/lib/buildel/langchain/usage.ex +++ b/apps/api/lib/buildel/langchain/usage.ex @@ -7,9 +7,9 @@ defmodule Buildel.Langchain.TokenUsage do @primary_key false embedded_schema do - field :completion_tokens, :integer - field :prompt_tokens, :integer - field :total_tokens, :integer + field :completion_tokens, :integer, default: 0 + field :prompt_tokens, :integer, default: 0 + field :total_tokens, :integer, default: 0 end @type t :: %TokenUsage{} @@ -40,4 +40,13 @@ defmodule Buildel.Langchain.TokenUsage do {:error, changeset} -> raise LangChainError, changeset end end + + @spec add(usage :: t(), another_usage :: t()) :: t() + def add(usage, another_usage) do + %TokenUsage{ + completion_tokens: usage.completion_tokens + another_usage.completion_tokens, + prompt_tokens: usage.prompt_tokens + another_usage.prompt_tokens, + total_tokens: usage.total_tokens + another_usage.total_tokens + } + end end diff --git a/apps/api/lib/buildel_web/controllers/organizations/pipelines/chat_completions/chat_completion_controller.ex b/apps/api/lib/buildel_web/controllers/organizations/pipelines/chat_completions/chat_completion_controller.ex index 01f252d6b..8ddf12636 100644 --- a/apps/api/lib/buildel_web/controllers/organizations/pipelines/chat_completions/chat_completion_controller.ex +++ b/apps/api/lib/buildel_web/controllers/organizations/pipelines/chat_completions/chat_completion_controller.ex @@ -92,13 +92,18 @@ defmodule BuildelWeb.OrganizationPipelineChatCompletionController do Pipelines.get_organization_pipeline(organization, pipeline_id), {:ok, config} <- Pipelines.get_pipeline_config(pipeline, "latest"), {:ok, run} <- Pipelines.create_run(%{pipeline_id: pipeline_id, config: config}), - {:ok, run} <- Pipelines.Runner.start_run(run), - {:ok, chat_completion} <- - Pipelines.Runner.create_chat_completion(run, params), - {:ok, _run} <- Pipelines.Runner.stop_run(run) do - conn - |> put_status(:ok) - |> json(chat_completion) + {:ok, run} <- Pipelines.Runner.start_run(run) do + {:ok, _} = + Pipelines.Runner.create_chat_completion_stream(run, params |> Map.put(:stream, true)) + + receive do + {:chat_end, message} -> + Pipelines.Runner.stop_run(run) + + conn + |> put_status(:ok) + |> json(message) + end end end end