Skip to content

Commit

Permalink
Introduce Usage to mistral_ai
Browse files Browse the repository at this point in the history
  • Loading branch information
michalwarda committed Feb 18, 2024
1 parent 618154b commit 3b13559
Show file tree
Hide file tree
Showing 3 changed files with 65 additions and 12 deletions.
5 changes: 5 additions & 0 deletions apps/api/lib/buildel/clients/chat.ex
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ end

defmodule Buildel.Clients.Chat do
require Logger
alias Buildel.Langchain.TokenUsage
alias Buildel.Langchain.ChatModels.ChatMistralAI
alias Buildel.LangChain.ChatModels.ChatGoogleAI
alias Buildel.Langchain.ChatGptTokenizer
Expand Down Expand Up @@ -104,6 +105,10 @@ defmodule Buildel.Clients.Chat do
%Message{} ->
nil

%TokenUsage{} = usage ->
IO.inspect(usage)
nil

{:error, reason} ->
on_error.(reason)
nil
Expand Down
29 changes: 17 additions & 12 deletions apps/api/lib/buildel/langchain/chat_mistral_ai.ex
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ defmodule Buildel.Langchain.ChatModels.ChatMistralAI do
require Logger
import Ecto.Changeset
import LangChain.Utils.ApiOverride
alias Buildel.Langchain.TokenUsage
alias LangChain.ForOpenAIApi
alias __MODULE__
alias LangChain.Config
Expand Down Expand Up @@ -239,7 +240,9 @@ defmodule Buildel.Langchain.ChatModels.ChatMistralAI do
headers: get_headers(mistral),
receive_timeout: mistral.receive_timeout
)
|> Req.post(into: Utils.handle_stream_fn(mistral, &do_process_response/1, callback_fn))
|> Req.post(
into: Utils.handle_stream_fn(mistral, &do_process_response(&1, callback_fn), callback_fn)
)
|> case do
{:ok, %Req.Response{body: data}} ->
data
Expand All @@ -264,6 +267,11 @@ defmodule Buildel.Langchain.ChatModels.ChatMistralAI do
end
end

def do_process_response(data, callback_fn) do
call_callback_with_token_usage(data, callback_fn)
do_process_response(data)
end

# Parse a new message response
@doc false
@spec do_process_response(data :: %{String.t() => any()} | {:error, any()}) ::
Expand All @@ -272,7 +280,7 @@ defmodule Buildel.Langchain.ChatModels.ChatMistralAI do
| MessageDelta.t()
| [MessageDelta.t()]
| {:error, String.t()}
def do_process_response(%{"choices" => choices}) when is_list(choices) do
def do_process_response(%{"choices" => choices} = _msg) when is_list(choices) do
# process each response individually. Return a list of all processed choices
for choice <- choices do
do_process_response(choice)
Expand Down Expand Up @@ -301,18 +309,8 @@ defmodule Buildel.Langchain.ChatModels.ChatMistralAI do
nil
end

# more explicitly interpret the role. We treat a "function_call" as a a role
# while OpenAI addresses it as an "assistant". Technically, they are correct
# that the assistant is issuing the function_call.
role =
case delta_body do
%{"role" => role} -> role
_other -> "unknown"
end

data =
delta_body
|> Map.put("role", role)
|> Map.put("index", index)
|> Map.put("status", status)

Expand Down Expand Up @@ -370,4 +368,11 @@ defmodule Buildel.Langchain.ChatModels.ChatMistralAI do
Logger.error("Trying to process an unexpected response. #{inspect(other)}")
{:error, "Unexpected response"}
end

defp call_callback_with_token_usage(_data, nil), do: nil

defp call_callback_with_token_usage(%{"usage" => usage}, callback_fn) when is_map(usage),
do: callback_fn.(TokenUsage.new!(usage))

defp call_callback_with_token_usage(_data, _callback_fn), do: nil
end
43 changes: 43 additions & 0 deletions apps/api/lib/buildel/langchain/usage.ex
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
defmodule Buildel.Langchain.TokenUsage do
use Ecto.Schema
import Ecto.Changeset

alias LangChain.LangChainError
alias __MODULE__

@primary_key false
embedded_schema do
field :completion_tokens, :integer
field :prompt_tokens, :integer
field :total_tokens, :integer
end

@type t :: %TokenUsage{}

@create_fields [
:completion_tokens,
:prompt_tokens,
:total_tokens
]
@required_fields [
:completion_tokens,
:prompt_tokens,
:total_tokens
]

@spec new(attrs :: map()) :: {:ok, t} | {:error, Ecto.Changeset.t()}
def new(%{} = attrs \\ %{}) do
%TokenUsage{}
|> cast(attrs, @create_fields)
|> validate_required(@required_fields)
|> apply_action(:inserts)
end

@spec new!(attrs :: map()) :: t() | no_return()
def new!(%{} = attrs \\ %{}) do
case new(attrs) do
{:ok, usage} -> usage
{:error, changeset} -> raise LangChainError, changeset
end
end
end

0 comments on commit 3b13559

Please sign in to comment.