Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add context tokens to api response #91

Merged
merged 9 commits into from
Oct 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 7 additions & 5 deletions galaxygpt-api/Program.cs
Original file line number Diff line number Diff line change
Expand Up @@ -111,24 +111,26 @@ public static void Main(string[] args)

var requestStart = Stopwatch.StartNew();

(string, int) context = await contextManager.FetchContext(askPayload.Prompt, askPayload.MaxContextLength ?? 5);
(string context, int contextTokenCount, int questionTokenCount) context = await contextManager.FetchContext(askPayload.Prompt, askPayload.MaxContextLength ?? 5);

// hash the username to prevent any potential privacy issues
// string? username = askPayload.Username != null ? Convert.ToHexString(SHA256.HashData(Encoding.UTF8.GetBytes(askPayload.Username))) : null;

try
{
(string, int) answer = await galaxyGpt.AnswerQuestion(askPayload.Prompt, context.Item1, username: askPayload.Username, maxOutputTokens: askPayload.MaxLength);
(string output, int promptTokenCount, int responseTokenCount) answer = await galaxyGpt.AnswerQuestion(askPayload.Prompt, context.Item1, username: askPayload.Username, maxOutputTokens: askPayload.MaxLength);
requestStart.Stop();

return Results.Json(new AskResponse
{
Answer = answer.Item1.Trim(),
Context = context.Item1,
Context = context.context,
Duration = requestStart.ElapsedMilliseconds.ToString(),
Version = version,
QuestionTokens = context.Item2.ToString(),
ResponseTokens = answer.Item2.ToString()
PromptTokens = answer.promptTokenCount.ToString(),
ContextTokens = context.contextTokenCount.ToString(),
QuestionTokens = context.questionTokenCount.ToString(),
ResponseTokens = answer.responseTokenCount.ToString()
});
}
catch (BonkedException e)
Expand Down
14 changes: 13 additions & 1 deletion galaxygpt-api/Types/AskQuestion/AskResponse.cs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,19 @@ public class AskResponse
public required string Version { get; init; }

/// <summary>
/// The amount of tokens in the system prompt
/// </summary>
// BUG FIX: was [JsonPropertyName("context_tokens")], which duplicated the
// ContextTokens property's JSON name below. System.Text.Json throws an
// InvalidOperationException when serializing a type with two properties
// mapped to the same JSON name, so every /ask response would have failed.
[JsonPropertyName("prompt_tokens")]
public required string PromptTokens { get; init; }

/// <summary>
/// The amount of tokens in the context
/// </summary>
[JsonPropertyName("context_tokens")]
public required string ContextTokens { get; init; }

/// <summary>
/// The amount of tokens in the user's question
/// </summary>
[JsonPropertyName("question_tokens")]
public required string QuestionTokens { get; init; }
Expand Down
7 changes: 4 additions & 3 deletions galaxygpt-tests/AiClientTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ public AiClientTests(ITestOutputHelper output)
ChatCompletion chatCompletion = OpenAIChatModelFactory.ChatCompletion(content:
[
ChatMessageContentPart.CreateTextPart("goofy ahh uncle productions")
], role: ChatMessageRole.Assistant);
], role: ChatMessageRole.Assistant, usage: OpenAIChatModelFactory.ChatTokenUsage(100, 100, 100));

Mock<ClientResult<ChatCompletion>> chatClientResultMock = new(null!, Mock.Of<PipelineResponse>());

Expand Down Expand Up @@ -78,13 +78,14 @@ public async void TestAnswersQuestion()
int? maxOutputTokens = 100;

// Act
(string output, int tokencount) result =
(string output, int promptTokenCount, int answerTokenCount) result =
await _aiClient.AnswerQuestion(question, context, maxInputTokens, username, maxOutputTokens);

// Assert
Assert.NotNull(result.output);
Assert.False(string.IsNullOrWhiteSpace(result.output));
Assert.True(result.tokencount > 0);
Assert.True(result.promptTokenCount > 0);
Assert.True(result.answerTokenCount > 0);
_output.WriteLine(result.Item1);
}

Expand Down
4 changes: 2 additions & 2 deletions galaxygpt/AiClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ public partial class AiClient(
/// <exception cref="ArgumentException"></exception>
/// <exception cref="InvalidOperationException"></exception>
/// <exception cref="BonkedException">The moderation API flagged the response</exception>
public async Task<(string output, int tokencount)> AnswerQuestion(string question, string context, int? maxInputTokens = null,
public async Task<(string output, int promptTokenCount, int responseTokenCount)> AnswerQuestion(string question, string context, int? maxInputTokens = null,
string? username = null, int? maxOutputTokens = null)
{
question = question.Trim();
Expand Down Expand Up @@ -73,7 +73,7 @@ public partial class AiClient(

string finalMessage = messages[^1].Content[0].Text;
await ModerateText(finalMessage, moderationClient);
return (finalMessage, gptTokenizer.CountTokens(finalMessage));
return (finalMessage, gptTokenizer.CountTokens(messages.First().Content[0].Text), clientResult.Value.Usage.OutputTokenCount);
smallketchup82 marked this conversation as resolved.
Show resolved Hide resolved
}

/// <summary>
Expand Down
5 changes: 3 additions & 2 deletions galaxygpt/ContextManager.cs
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ public ContextManager(EmbeddingClient embeddingClient,
qdrantUrlAndPort.Length > 1 ? int.Parse(qdrantUrlAndPort[1]) : 6334);
}

public async Task<(string, int)> FetchContext(string question, ulong maxResults = 5)
public async Task<(string context, int contextTokenCount, int questionTokenCount)> FetchContext(string question, ulong maxResults = 5)
{
if (_qdrantClient == null)
throw new InvalidOperationException("The Qdrant client is not available.");
Expand All @@ -63,6 +63,7 @@ public ContextManager(EmbeddingClient embeddingClient,
context.Append($"Page: {searchResult.Payload["title"].StringValue}\nContent: {searchResult.Payload["content"].StringValue}\n\n###\n\n");
}

return (context.ToString(), _embeddingsTokenizer.CountTokens(question));
string contextString = context.ToString();
return (contextString, _embeddingsTokenizer.CountTokens(contextString), _embeddingsTokenizer.CountTokens(question));
}
}