diff --git a/galaxygpt-api/Program.cs b/galaxygpt-api/Program.cs
index de5035f..b284931 100644
--- a/galaxygpt-api/Program.cs
+++ b/galaxygpt-api/Program.cs
@@ -111,24 +111,26 @@ public static void Main(string[] args)
var requestStart = Stopwatch.StartNew();
- (string, int) context = await contextManager.FetchContext(askPayload.Prompt, askPayload.MaxContextLength ?? 5);
+ (string context, int contextTokenCount, int questionTokenCount) context = await contextManager.FetchContext(askPayload.Prompt, askPayload.MaxContextLength ?? 5);
// hash the username to prevent any potential privacy issues
// string? username = askPayload.Username != null ? Convert.ToHexString(SHA256.HashData(Encoding.UTF8.GetBytes(askPayload.Username))) : null;
try
{
- (string, int) answer = await galaxyGpt.AnswerQuestion(askPayload.Prompt, context.Item1, username: askPayload.Username, maxOutputTokens: askPayload.MaxLength);
+ (string output, int promptTokenCount, int responseTokenCount) answer = await galaxyGpt.AnswerQuestion(askPayload.Prompt, context.context, username: askPayload.Username, maxOutputTokens: askPayload.MaxLength);
requestStart.Stop();
return Results.Json(new AskResponse
{
Answer = answer.Item1.Trim(),
- Context = context.Item1,
+ Context = context.context,
Duration = requestStart.ElapsedMilliseconds.ToString(),
Version = version,
- QuestionTokens = context.Item2.ToString(),
- ResponseTokens = answer.Item2.ToString()
+ PromptTokens = answer.promptTokenCount.ToString(),
+ ContextTokens = context.contextTokenCount.ToString(),
+ QuestionTokens = context.questionTokenCount.ToString(),
+ ResponseTokens = answer.responseTokenCount.ToString()
});
}
catch (BonkedException e)
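With the tuple elements named, call sites no longer have to rely on positional Item accessors; the result can also be deconstructed directly. A minimal sketch of that call-site pattern, assuming only the new FetchContext signature (the prompt string and local names are illustrative):

```csharp
// Deconstruct the named tuple returned by the new FetchContext signature.
// Element names come from the diff above; the prompt is a made-up example.
(string context, int contextTokenCount, int questionTokenCount) =
    await contextManager.FetchContext("example question", 5);

Console.WriteLine($"context: {contextTokenCount} tokens, question: {questionTokenCount} tokens");
```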
diff --git a/galaxygpt-api/Types/AskQuestion/AskResponse.cs b/galaxygpt-api/Types/AskQuestion/AskResponse.cs
index 2903d3f..51fa162 100644
--- a/galaxygpt-api/Types/AskQuestion/AskResponse.cs
+++ b/galaxygpt-api/Types/AskQuestion/AskResponse.cs
@@ -13,7 +13,19 @@ public class AskResponse
public required string Version { get; init; }
/// <summary>
- /// The combined amount of tokens in the system prompt, context, and user's question
+ /// The number of tokens in the system prompt
+ /// </summary>
+ [JsonPropertyName("prompt_tokens")]
+ public required string PromptTokens { get; init; }
+
+ /// <summary>
+ /// The number of tokens in the context
+ /// </summary>
+ [JsonPropertyName("context_tokens")]
+ public required string ContextTokens { get; init; }
+
+ /// <summary>
+ /// The number of tokens in the user's question
/// </summary>
[JsonPropertyName("question_tokens")]
public required string QuestionTokens { get; init; }
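One constraint worth keeping in mind with these attributes: System.Text.Json refuses to serialize a type whose properties collide on the same JSON name, so each [JsonPropertyName] value must be unique within the class. A minimal sketch of the failure mode (the Demo type is hypothetical):

```csharp
using System.Text.Json;
using System.Text.Json.Serialization;

// Serializing throws InvalidOperationException because A and B both map to "tokens".
Console.WriteLine(JsonSerializer.Serialize(new Demo()));

// Hypothetical type with two properties mapped to the same JSON name.
class Demo
{
    [JsonPropertyName("tokens")] public string A { get; init; } = "1";
    [JsonPropertyName("tokens")] public string B { get; init; } = "2";
}
```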
diff --git a/galaxygpt-tests/AiClientTests.cs b/galaxygpt-tests/AiClientTests.cs
index 643db34..7ed4656 100644
--- a/galaxygpt-tests/AiClientTests.cs
+++ b/galaxygpt-tests/AiClientTests.cs
@@ -34,7 +34,7 @@ public AiClientTests(ITestOutputHelper output)
ChatCompletion chatCompletion = OpenAIChatModelFactory.ChatCompletion(content:
[
ChatMessageContentPart.CreateTextPart("goofy ahh uncle productions")
- ], role: ChatMessageRole.Assistant);
+ ], role: ChatMessageRole.Assistant, usage: OpenAIChatModelFactory.ChatTokenUsage(100, 100, 100));
Mock<ClientResult<ChatCompletion>> chatClientResultMock = new(null!, Mock.Of<PipelineResponse>());
@@ -78,13 +78,14 @@ public async void TestAnswersQuestion()
int? maxOutputTokens = 100;
// Act
- (string output, int tokencount) result =
+ (string output, int promptTokenCount, int answerTokenCount) result =
await _aiClient.AnswerQuestion(question, context, maxInputTokens, username, maxOutputTokens);
// Assert
Assert.NotNull(result.output);
Assert.False(string.IsNullOrWhiteSpace(result.output));
- Assert.True(result.tokencount > 0);
+ Assert.True(result.promptTokenCount > 0);
+ Assert.True(result.answerTokenCount > 0);
_output.WriteLine(result.Item1);
}
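The three positional 100s passed to ChatTokenUsage are easy to misread. If the factory's parameter names match the OpenAI .NET library's (an assumption worth verifying: outputTokenCount, inputTokenCount, totalTokenCount), named arguments make the fake usage data self-documenting:

```csharp
// Same fake usage data as in the test above, with each figure labeled.
// Parameter names are assumed from the OpenAI .NET model factory.
ChatTokenUsage usage = OpenAIChatModelFactory.ChatTokenUsage(
    outputTokenCount: 100,
    inputTokenCount: 100,
    totalTokenCount: 100);
```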
diff --git a/galaxygpt/AiClient.cs b/galaxygpt/AiClient.cs
index e739a70..73ea9e6 100644
--- a/galaxygpt/AiClient.cs
+++ b/galaxygpt/AiClient.cs
@@ -43,7 +43,7 @@ public partial class AiClient(
///
///
/// <exception cref="BonkedException">The moderation API flagged the response</exception>
- public async Task<(string output, int tokencount)> AnswerQuestion(string question, string context, int? maxInputTokens = null,
+ public async Task<(string output, int promptTokenCount, int responseTokenCount)> AnswerQuestion(string question, string context, int? maxInputTokens = null,
string? username = null, int? maxOutputTokens = null)
{
question = question.Trim();
@@ -73,7 +73,7 @@ public partial class AiClient(
string finalMessage = messages[^1].Content[0].Text;
await ModerateText(finalMessage, moderationClient);
- return (finalMessage, gptTokenizer.CountTokens(finalMessage));
+ return (finalMessage, gptTokenizer.CountTokens(messages.First().Content[0].Text), clientResult.Value.Usage.OutputTokenCount);
}
///
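Note the asymmetry in the new return value: responseTokenCount comes from the server's usage report, while promptTokenCount is recomputed locally from only the first message. If the server-reported figure is acceptable, the same usage object also carries an input-token count; a sketch of that alternative (assuming the OpenAI .NET ChatTokenUsage.InputTokenCount property, and reusing the method's clientResult and finalMessage locals):

```csharp
// Alternative: take both counts from the server-side usage report.
// InputTokenCount covers the entire prompt (all messages), so it can differ
// from tokenizing just the first message locally.
ChatTokenUsage usage = clientResult.Value.Usage;
return (finalMessage, usage.InputTokenCount, usage.OutputTokenCount);
```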
diff --git a/galaxygpt/ContextManager.cs b/galaxygpt/ContextManager.cs
index 1425223..4eaaeb5 100644
--- a/galaxygpt/ContextManager.cs
+++ b/galaxygpt/ContextManager.cs
@@ -38,7 +38,7 @@ public ContextManager(EmbeddingClient embeddingClient,
qdrantUrlAndPort.Length > 1 ? int.Parse(qdrantUrlAndPort[1]) : 6334);
}
- public async Task<(string, int)> FetchContext(string question, ulong maxResults = 5)
+ public async Task<(string context, int contextTokenCount, int questionTokenCount)> FetchContext(string question, ulong maxResults = 5)
{
if (_qdrantClient == null)
throw new InvalidOperationException("The Qdrant client is not available.");
@@ -63,6 +63,7 @@ public ContextManager(EmbeddingClient embeddingClient,
context.Append($"Page: {searchResult.Payload["title"].StringValue}\nContent: {searchResult.Payload["content"].StringValue}\n\n###\n\n");
}
- return (context.ToString(), _embeddingsTokenizer.CountTokens(question));
+ string contextString = context.ToString();
+ return (contextString, _embeddingsTokenizer.CountTokens(contextString), _embeddingsTokenizer.CountTokens(question));
}
}
\ No newline at end of file
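Both new counts come from _embeddingsTokenizer, so they measure the text in the embedding model's tokens rather than the chat model's, and the two vocabularies can produce different counts for the same string. A sketch of that difference, assuming Microsoft.ML.Tokenizers and illustrative model names:

```csharp
using Microsoft.ML.Tokenizers;

// The same string can tokenize to different counts under different models
// (e.g. cl100k_base for embeddings vs o200k_base for newer chat models).
Tokenizer embeddings = TiktokenTokenizer.CreateForModel("text-embedding-3-small");
Tokenizer chat = TiktokenTokenizer.CreateForModel("gpt-4o-mini");

string sample = "Page: Example\nContent: example context\n\n###\n\n";
Console.WriteLine($"embedding tokens: {embeddings.CountTokens(sample)}");
Console.WriteLine($"chat tokens: {chat.CountTokens(sample)}");
```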