diff --git a/galaxygpt-api/Program.cs b/galaxygpt-api/Program.cs
index b284931..de5035f 100644
--- a/galaxygpt-api/Program.cs
+++ b/galaxygpt-api/Program.cs
@@ -111,26 +111,24 @@ public static void Main(string[] args)
var requestStart = Stopwatch.StartNew();
- (string context, int contextTokenCount, int questionTokenCount) context = await contextManager.FetchContext(askPayload.Prompt, askPayload.MaxContextLength ?? 5);
+ (string, int) context = await contextManager.FetchContext(askPayload.Prompt, askPayload.MaxContextLength ?? 5);
// hash the username to prevent any potential privacy issues
// string? username = askPayload.Username != null ? Convert.ToHexString(SHA256.HashData(Encoding.UTF8.GetBytes(askPayload.Username))) : null;
try
{
- (string output, int promptTokenCount, int responseTokenCount) answer = await galaxyGpt.AnswerQuestion(askPayload.Prompt, context.Item1, username: askPayload.Username, maxOutputTokens: askPayload.MaxLength);
+ (string, int) answer = await galaxyGpt.AnswerQuestion(askPayload.Prompt, context.Item1, username: askPayload.Username, maxOutputTokens: askPayload.MaxLength);
requestStart.Stop();
return Results.Json(new AskResponse
{
Answer = answer.Item1.Trim(),
- Context = context.context,
+ Context = context.Item1,
Duration = requestStart.ElapsedMilliseconds.ToString(),
Version = version,
- PromptTokens = answer.promptTokenCount.ToString(),
- ContextTokens = context.contextTokenCount.ToString(),
- QuestionTokens = context.questionTokenCount.ToString(),
- ResponseTokens = answer.responseTokenCount.ToString()
+ QuestionTokens = context.Item2.ToString(),
+ ResponseTokens = answer.Item2.ToString()
});
}
catch (BonkedException e)
diff --git a/galaxygpt-api/Types/AskQuestion/AskResponse.cs b/galaxygpt-api/Types/AskQuestion/AskResponse.cs
index 51fa162..2903d3f 100644
--- a/galaxygpt-api/Types/AskQuestion/AskResponse.cs
+++ b/galaxygpt-api/Types/AskQuestion/AskResponse.cs
@@ -13,19 +13,7 @@ public class AskResponse
public required string Version { get; init; }
///
- /// The amount of tokens in the system prompt
- ///
- [JsonPropertyName("context_tokens")]
- public required string PromptTokens { get; init; }
-
- ///
- /// The amount of tokens in the context
- ///
- [JsonPropertyName("context_tokens")]
- public required string ContextTokens { get; init; }
-
- ///
- /// The amount of tokens in the user's question
+    /// The amount of tokens in the user's question (counted with the embeddings tokenizer)
///
[JsonPropertyName("question_tokens")]
public required string QuestionTokens { get; init; }
diff --git a/galaxygpt-tests/AiClientTests.cs b/galaxygpt-tests/AiClientTests.cs
index e247db4..32faf27 100644
--- a/galaxygpt-tests/AiClientTests.cs
+++ b/galaxygpt-tests/AiClientTests.cs
@@ -34,7 +34,7 @@ public AiClientTests(ITestOutputHelper output)
ChatCompletion chatCompletion = OpenAIChatModelFactory.ChatCompletion(content:
[
ChatMessageContentPart.CreateTextPart("goofy ahh uncle productions")
- ], role: ChatMessageRole.Assistant, usage: OpenAIChatModelFactory.ChatTokenUsage(100, 100, 100));
+ ], role: ChatMessageRole.Assistant);
Mock> chatClientResultMock = new(null!, Mock.Of());
@@ -78,14 +78,13 @@ public async Task TestAnswersQuestion()
int? maxOutputTokens = 100;
// Act
- (string output, int promptTokenCount, int answerTokenCount) result =
+ (string output, int tokencount) result =
await _aiClient.AnswerQuestion(question, context, maxInputTokens, username, maxOutputTokens);
// Assert
Assert.NotNull(result.output);
Assert.False(string.IsNullOrWhiteSpace(result.output));
- Assert.True(result.promptTokenCount > 0);
- Assert.True(result.answerTokenCount > 0);
+ Assert.True(result.tokencount > 0);
_output.WriteLine(result.Item1);
}
diff --git a/galaxygpt/AiClient.cs b/galaxygpt/AiClient.cs
index 73ea9e6..e739a70 100644
--- a/galaxygpt/AiClient.cs
+++ b/galaxygpt/AiClient.cs
@@ -43,7 +43,7 @@ public partial class AiClient(
///
///
/// The moderation API flagged the response
- public async Task<(string output, int promptTokenCount, int responseTokenCount)> AnswerQuestion(string question, string context, int? maxInputTokens = null,
+ public async Task<(string output, int tokencount)> AnswerQuestion(string question, string context, int? maxInputTokens = null,
string? username = null, int? maxOutputTokens = null)
{
question = question.Trim();
@@ -73,7 +73,7 @@ public partial class AiClient(
string finalMessage = messages[^1].Content[0].Text;
await ModerateText(finalMessage, moderationClient);
- return (finalMessage, gptTokenizer.CountTokens(messages.First().Content[0].Text), clientResult.Value.Usage.OutputTokenCount);
+ return (finalMessage, gptTokenizer.CountTokens(finalMessage));
}
///
diff --git a/galaxygpt/ContextManager.cs b/galaxygpt/ContextManager.cs
index 4eaaeb5..1425223 100644
--- a/galaxygpt/ContextManager.cs
+++ b/galaxygpt/ContextManager.cs
@@ -38,7 +38,7 @@ public ContextManager(EmbeddingClient embeddingClient,
qdrantUrlAndPort.Length > 1 ? int.Parse(qdrantUrlAndPort[1]) : 6334);
}
- public async Task<(string context, int contextTokenCount, int questionTokenCount)> FetchContext(string question, ulong maxResults = 5)
+ public async Task<(string, int)> FetchContext(string question, ulong maxResults = 5)
{
if (_qdrantClient == null)
throw new InvalidOperationException("The Qdrant client is not available.");
@@ -63,7 +63,6 @@ public ContextManager(EmbeddingClient embeddingClient,
context.Append($"Page: {searchResult.Payload["title"].StringValue}\nContent: {searchResult.Payload["content"].StringValue}\n\n###\n\n");
}
- string contextString = context.ToString();
- return (contextString, _embeddingsTokenizer.CountTokens(contextString), _embeddingsTokenizer.CountTokens(question));
+ return (context.ToString(), _embeddingsTokenizer.CountTokens(question));
}
}
\ No newline at end of file