diff --git a/galaxygpt-api/Program.cs b/galaxygpt-api/Program.cs index de5035f..b284931 100644 --- a/galaxygpt-api/Program.cs +++ b/galaxygpt-api/Program.cs @@ -111,24 +111,26 @@ public static void Main(string[] args) var requestStart = Stopwatch.StartNew(); - (string, int) context = await contextManager.FetchContext(askPayload.Prompt, askPayload.MaxContextLength ?? 5); + (string context, int contextTokenCount, int questionTokenCount) context = await contextManager.FetchContext(askPayload.Prompt, askPayload.MaxContextLength ?? 5); // hash the username to prevent any potential privacy issues // string? username = askPayload.Username != null ? Convert.ToHexString(SHA256.HashData(Encoding.UTF8.GetBytes(askPayload.Username))) : null; try { - (string, int) answer = await galaxyGpt.AnswerQuestion(askPayload.Prompt, context.Item1, username: askPayload.Username, maxOutputTokens: askPayload.MaxLength); + (string output, int promptTokenCount, int responseTokenCount) answer = await galaxyGpt.AnswerQuestion(askPayload.Prompt, context.Item1, username: askPayload.Username, maxOutputTokens: askPayload.MaxLength); requestStart.Stop(); return Results.Json(new AskResponse { Answer = answer.Item1.Trim(), - Context = context.Item1, + Context = context.context, Duration = requestStart.ElapsedMilliseconds.ToString(), Version = version, - QuestionTokens = context.Item2.ToString(), - ResponseTokens = answer.Item2.ToString() + PromptTokens = answer.promptTokenCount.ToString(), + ContextTokens = context.contextTokenCount.ToString(), + QuestionTokens = context.questionTokenCount.ToString(), + ResponseTokens = answer.responseTokenCount.ToString() }); } catch (BonkedException e) diff --git a/galaxygpt-api/Types/AskQuestion/AskResponse.cs b/galaxygpt-api/Types/AskQuestion/AskResponse.cs index 2903d3f..51fa162 100644 --- a/galaxygpt-api/Types/AskQuestion/AskResponse.cs +++ b/galaxygpt-api/Types/AskQuestion/AskResponse.cs @@ -13,7 +13,19 @@ public class AskResponse public required string Version { get; init; } /// - /// The combined amount of tokens in the system prompt, context, and user's question + /// The amount of tokens in the system prompt + /// + [JsonPropertyName("context_tokens")] + public required string PromptTokens { get; init; } + + /// + /// The amount of tokens in the context + /// + [JsonPropertyName("context_tokens")] + public required string ContextTokens { get; init; } + + /// + /// The amount of tokens in the user's question /// [JsonPropertyName("question_tokens")] public required string QuestionTokens { get; init; } diff --git a/galaxygpt-tests/AiClientTests.cs b/galaxygpt-tests/AiClientTests.cs index 643db34..7ed4656 100644 --- a/galaxygpt-tests/AiClientTests.cs +++ b/galaxygpt-tests/AiClientTests.cs @@ -34,7 +34,7 @@ public AiClientTests(ITestOutputHelper output) ChatCompletion chatCompletion = OpenAIChatModelFactory.ChatCompletion(content: [ ChatMessageContentPart.CreateTextPart("goofy ahh uncle productions") - ], role: ChatMessageRole.Assistant); + ], role: ChatMessageRole.Assistant, usage: OpenAIChatModelFactory.ChatTokenUsage(100, 100, 100)); Mock> chatClientResultMock = new(null!, Mock.Of()); @@ -78,13 +78,14 @@ public async void TestAnswersQuestion() int? maxOutputTokens = 100; // Act - (string output, int tokencount) result = + (string output, int promptTokenCount, int answerTokenCount) result = await _aiClient.AnswerQuestion(question, context, maxInputTokens, username, maxOutputTokens); // Assert Assert.NotNull(result.output); Assert.False(string.IsNullOrWhiteSpace(result.output)); - Assert.True(result.tokencount > 0); + Assert.True(result.promptTokenCount > 0); + Assert.True(result.answerTokenCount > 0); _output.WriteLine(result.Item1); } diff --git a/galaxygpt/AiClient.cs b/galaxygpt/AiClient.cs index e739a70..73ea9e6 100644 --- a/galaxygpt/AiClient.cs +++ b/galaxygpt/AiClient.cs @@ -43,7 +43,7 @@ public partial class AiClient( /// /// /// The moderation API flagged the response - public async Task<(string output, int tokencount)> AnswerQuestion(string question, string context, int? maxInputTokens = null, + public async Task<(string output, int promptTokenCount, int responseTokenCount)> AnswerQuestion(string question, string context, int? maxInputTokens = null, string? username = null, int? maxOutputTokens = null) { question = question.Trim(); @@ -73,7 +73,7 @@ public partial class AiClient( string finalMessage = messages[^1].Content[0].Text; await ModerateText(finalMessage, moderationClient); - return (finalMessage, gptTokenizer.CountTokens(finalMessage)); + return (finalMessage, gptTokenizer.CountTokens(messages.First().Content[0].Text), clientResult.Value.Usage.OutputTokenCount); } /// diff --git a/galaxygpt/ContextManager.cs b/galaxygpt/ContextManager.cs index 1425223..4eaaeb5 100644 --- a/galaxygpt/ContextManager.cs +++ b/galaxygpt/ContextManager.cs @@ -38,7 +38,7 @@ public ContextManager(EmbeddingClient embeddingClient, qdrantUrlAndPort.Length > 1 ? int.Parse(qdrantUrlAndPort[1]) : 6334); } - public async Task<(string, int)> FetchContext(string question, ulong maxResults = 5) + public async Task<(string context, int contextTokenCount, int questionTokenCount)> FetchContext(string question, ulong maxResults = 5) { if (_qdrantClient == null) throw new InvalidOperationException("The Qdrant client is not available."); @@ -63,6 +63,7 @@ public ContextManager(EmbeddingClient embeddingClient, context.Append($"Page: {searchResult.Payload["title"].StringValue}\nContent: {searchResult.Payload["content"].StringValue}\n\n###\n\n"); } - return (context.ToString(), _embeddingsTokenizer.CountTokens(question)); + string contextString = context.ToString(); + return (contextString, _embeddingsTokenizer.CountTokens(contextString), _embeddingsTokenizer.CountTokens(question)); } } \ No newline at end of file