diff --git a/galaxygpt-api/Program.cs b/galaxygpt-api/Program.cs index b284931..de5035f 100644 --- a/galaxygpt-api/Program.cs +++ b/galaxygpt-api/Program.cs @@ -111,26 +111,24 @@ public static void Main(string[] args) var requestStart = Stopwatch.StartNew(); - (string context, int contextTokenCount, int questionTokenCount) context = await contextManager.FetchContext(askPayload.Prompt, askPayload.MaxContextLength ?? 5); + (string, int) context = await contextManager.FetchContext(askPayload.Prompt, askPayload.MaxContextLength ?? 5); // hash the username to prevent any potential privacy issues // string? username = askPayload.Username != null ? Convert.ToHexString(SHA256.HashData(Encoding.UTF8.GetBytes(askPayload.Username))) : null; try { - (string output, int promptTokenCount, int responseTokenCount) answer = await galaxyGpt.AnswerQuestion(askPayload.Prompt, context.Item1, username: askPayload.Username, maxOutputTokens: askPayload.MaxLength); + (string, int) answer = await galaxyGpt.AnswerQuestion(askPayload.Prompt, context.Item1, username: askPayload.Username, maxOutputTokens: askPayload.MaxLength); requestStart.Stop(); return Results.Json(new AskResponse { Answer = answer.Item1.Trim(), - Context = context.context, + Context = context.Item1, Duration = requestStart.ElapsedMilliseconds.ToString(), Version = version, - PromptTokens = answer.promptTokenCount.ToString(), - ContextTokens = context.contextTokenCount.ToString(), - QuestionTokens = context.questionTokenCount.ToString(), - ResponseTokens = answer.responseTokenCount.ToString() + QuestionTokens = context.Item2.ToString(), + ResponseTokens = answer.Item2.ToString() }); } catch (BonkedException e) diff --git a/galaxygpt-api/Types/AskQuestion/AskResponse.cs b/galaxygpt-api/Types/AskQuestion/AskResponse.cs index 51fa162..2903d3f 100644 --- a/galaxygpt-api/Types/AskQuestion/AskResponse.cs +++ b/galaxygpt-api/Types/AskQuestion/AskResponse.cs @@ -13,19 +13,7 @@ public class AskResponse public required string Version { get; init; } /// - /// The amount of tokens in the system prompt - /// - [JsonPropertyName("context_tokens")] - public required string PromptTokens { get; init; } - - /// - /// The amount of tokens in the context - /// - [JsonPropertyName("context_tokens")] - public required string ContextTokens { get; init; } - - /// - /// The amount of tokens in the user's question + /// The combined amount of tokens in the system prompt, context, and user's question /// [JsonPropertyName("question_tokens")] public required string QuestionTokens { get; init; } diff --git a/galaxygpt-tests/AiClientTests.cs b/galaxygpt-tests/AiClientTests.cs index e247db4..32faf27 100644 --- a/galaxygpt-tests/AiClientTests.cs +++ b/galaxygpt-tests/AiClientTests.cs @@ -34,7 +34,7 @@ public AiClientTests(ITestOutputHelper output) ChatCompletion chatCompletion = OpenAIChatModelFactory.ChatCompletion(content: [ ChatMessageContentPart.CreateTextPart("goofy ahh uncle productions") - ], role: ChatMessageRole.Assistant, usage: OpenAIChatModelFactory.ChatTokenUsage(100, 100, 100)); + ], role: ChatMessageRole.Assistant); Mock> chatClientResultMock = new(null!, Mock.Of()); @@ -78,14 +78,13 @@ public async Task TestAnswersQuestion() int? maxOutputTokens = 100; // Act - (string output, int promptTokenCount, int answerTokenCount) result = + (string output, int tokencount) result = await _aiClient.AnswerQuestion(question, context, maxInputTokens, username, maxOutputTokens); // Assert Assert.NotNull(result.output); Assert.False(string.IsNullOrWhiteSpace(result.output)); - Assert.True(result.promptTokenCount > 0); - Assert.True(result.answerTokenCount > 0); + Assert.True(result.tokencount > 0); _output.WriteLine(result.Item1); } diff --git a/galaxygpt/AiClient.cs b/galaxygpt/AiClient.cs index 73ea9e6..e739a70 100644 --- a/galaxygpt/AiClient.cs +++ b/galaxygpt/AiClient.cs @@ -43,7 +43,7 @@ public partial class AiClient( /// /// /// The moderation API flagged the response - public async Task<(string output, int promptTokenCount, int responseTokenCount)> AnswerQuestion(string question, string context, int? maxInputTokens = null, + public async Task<(string output, int tokencount)> AnswerQuestion(string question, string context, int? maxInputTokens = null, string? username = null, int? maxOutputTokens = null) { question = question.Trim(); @@ -73,7 +73,7 @@ public partial class AiClient( string finalMessage = messages[^1].Content[0].Text; await ModerateText(finalMessage, moderationClient); - return (finalMessage, gptTokenizer.CountTokens(messages.First().Content[0].Text), clientResult.Value.Usage.OutputTokenCount); + return (finalMessage, gptTokenizer.CountTokens(finalMessage)); } /// diff --git a/galaxygpt/ContextManager.cs b/galaxygpt/ContextManager.cs index 4eaaeb5..1425223 100644 --- a/galaxygpt/ContextManager.cs +++ b/galaxygpt/ContextManager.cs @@ -38,7 +38,7 @@ public ContextManager(EmbeddingClient embeddingClient, qdrantUrlAndPort.Length > 1 ? int.Parse(qdrantUrlAndPort[1]) : 6334); } - public async Task<(string context, int contextTokenCount, int questionTokenCount)> FetchContext(string question, ulong maxResults = 5) + public async Task<(string, int)> FetchContext(string question, ulong maxResults = 5) { if (_qdrantClient == null) throw new InvalidOperationException("The Qdrant client is not available."); @@ -63,7 +63,6 @@ public ContextManager(EmbeddingClient embeddingClient, context.Append($"Page: {searchResult.Payload["title"].StringValue}\nContent: {searchResult.Payload["content"].StringValue}\n\n###\n\n"); } - string contextString = context.ToString(); - return (contextString, _embeddingsTokenizer.CountTokens(contextString), _embeddingsTokenizer.CountTokens(question)); + return (context.ToString(), _embeddingsTokenizer.CountTokens(question)); } } \ No newline at end of file