
Commit af67dd7

Let's go
smallketchup82 committed Aug 24, 2024
1 parent 2e9bde6 commit af67dd7
Showing 7 changed files with 254 additions and 4 deletions.
13 changes: 13 additions & 0 deletions dataset-assistant/Program.cs
@@ -331,6 +331,19 @@ private static async Task<int> Main(string[] args)

#endregion

#region Add Metadata

db.Metadata.Add(new Metadata
{
DatasetName = datasetNameOptionValue,
CreatedAt = DateTime.UtcNow,
ChunkMaxSize = maxtokens
});

await db.SaveChangesAsync();

#endregion

globalProgressBar.Tick("Done");
});

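For context, the metadata row written above records which dataset the embeddings database was built from and with what chunk size. A minimal sketch of reading it back (illustrative only; this query and the surrounding setup are not part of the commit):

using galaxygpt.Database;

// Open the embeddings database and fetch the most recent metadata entry (hypothetical usage).
await using var db = new VectorDb();
Metadata? latest = db.Metadata.OrderByDescending(m => m.CreatedAt).FirstOrDefault();

if (latest != null)
    Console.WriteLine($"Dataset {latest.DatasetName} built {latest.CreatedAt:u} with max chunk size {latest.ChunkMaxSize}.");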
55 changes: 51 additions & 4 deletions galaxygpt-api/Program.cs
Expand Up @@ -4,7 +4,12 @@
using System.Reflection;
using Asp.Versioning.Builder;
using galaxygpt;
using galaxygpt.Database;
using Microsoft.AspNetCore.Diagnostics;
using Microsoft.Extensions.Options;
using Microsoft.ML.Tokenizers;
using OpenAI;
using Sentry.Profiling;
using Swashbuckle.AspNetCore.SwaggerGen;

namespace galaxygpt_api;
@@ -30,10 +35,49 @@ public static void Main(string[] args)

builder.Services.AddTransient<IConfigureOptions<SwaggerGenOptions>, ConfigureSwaggerOptions>();
builder.Services.AddSwaggerGen(options => options.OperationFilter<SwaggerDefaultValues>());

builder.Services.AddMemoryCache();

builder.Configuration.AddJsonFile("appsettings.json", optional: false, reloadOnChange: true);
#region Configuration

IConfigurationRoot configuration = new ConfigurationBuilder()
.AddJsonFile("appsettings.json", optional: true, reloadOnChange: true)
.AddEnvironmentVariables()
.AddUserSecrets<Program>()
.Build();
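// Later providers win on key conflicts: environment variables override
// appsettings.json, and user secrets override both.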

builder.Configuration.Sources.Clear();
builder.Configuration.AddConfiguration(configuration);

#endregion

builder.WebHost.UseSentry(o =>
{
o.Dsn = configuration["SENTRY_DSN"] ?? "https://1df72bed08400836796f15c03748d195@o4507833886834688.ingest.us.sentry.io/4507833934544896";
#if DEBUG
o.Debug = true;
#endif
o.TracesSampleRate = 1.0;
o.ProfilesSampleRate = 1.0;
o.AddIntegration(new ProfilingIntegration());
});

#region GalaxyGPT Services

var openAiClient = new OpenAIClient(configuration["OPENAI_API_KEY"] ?? throw new InvalidOperationException());
string gptModel = configuration["GPT_MODEL"] ?? "gpt-4o-mini";
string textEmbeddingModel = configuration["TEXT_EMBEDDING_MODEL"] ?? "text-embedding-3-small";
string moderationModel = configuration["MODERATION_MODEL"] ?? "text-moderation-stable";

builder.Services.AddSingleton(new VectorDb());
builder.Services.AddSingleton(openAiClient.GetChatClient(gptModel));
builder.Services.AddSingleton(openAiClient.GetEmbeddingClient(textEmbeddingModel));
builder.Services.AddSingleton(openAiClient.GetModerationClient(moderationModel));
builder.Services.AddKeyedSingleton("gptTokenizer", TiktokenTokenizer.CreateForModel("gpt-4o-mini"));
builder.Services.AddKeyedSingleton("embeddingsTokenizer", TiktokenTokenizer.CreateForModel("text-embedding-3-small"));
builder.Services.AddSingleton<ContextManager>();
builder.Services.AddSingleton<AiClient>();

#endregion

WebApplication app = builder.Build();
IVersionedEndpointRouteBuilder versionedApi = app.NewVersionedApi("galaxygpt");
@@ -47,13 +91,16 @@ public static void Main(string[] args)
#region API
RouteGroupBuilder v1 = versionedApi.MapGroup("/api/v{version:apiVersion}").HasApiVersion(1.0);

var galaxyGpt = app.Services.GetRequiredService<AiClient>();
var contextManager = app.Services.GetRequiredService<ContextManager>();

v1.MapPost("ask", async (AskPayload askPayload) =>
{
if (string.IsNullOrEmpty(askPayload.Prompt))
return Results.BadRequest("The question cannot be empty.");

(string, int) context = await GalaxyGpt.FetchContext(askPayload.Prompt, "text-embedding-3-small");
string answer = await GalaxyGpt.AnswerQuestion(askPayload.Prompt, context.Item1, askPayload.Model ?? app.Configuration["MODEL"] ?? throw new InvalidOperationException(), 4096, 4096, username: askPayload.Username);
(string, int) context = await contextManager.FetchContext(askPayload.Prompt);
string answer = await galaxyGpt.AnswerQuestion(askPayload.Prompt, context.Item1, 4096, username: askPayload.Username);

var results = new Dictionary<string, string>
{
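For reference, the reworked endpoint can be exercised with a plain HTTP POST once the API is running. A minimal sketch, assuming the service listens on http://localhost:5000 (the host, port, and prompt text are assumptions, not part of the commit):

using System.Net.Http.Json;

// Post a question to the versioned ask endpoint and print the raw JSON response (hypothetical host/port).
using var http = new HttpClient();
var payload = new { prompt = "How much shield does the Theia have?", username = "example-user" };

HttpResponseMessage response = await http.PostAsJsonAsync("http://localhost:5000/api/v1/ask", payload);
response.EnsureSuccessStatusCode();
Console.WriteLine(await response.Content.ReadAsStringAsync());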
2 changes: 2 additions & 0 deletions galaxygpt-api/galaxygpt-api.csproj
@@ -11,6 +11,8 @@
<PackageReference Include="Asp.Versioning.Http" Version="8.1.0" />
<PackageReference Include="Asp.Versioning.Mvc.ApiExplorer" Version="8.1.0" />
<PackageReference Include="Microsoft.AspNetCore.OpenApi" Version="8.0.7"/>
<PackageReference Include="Sentry.AspNetCore" Version="4.10.2" />
<PackageReference Include="Sentry.Profiling" Version="4.10.2" />
<PackageReference Include="Swashbuckle.AspNetCore" Version="6.4.0"/>
</ItemGroup>

71 changes: 71 additions & 0 deletions galaxygpt/AiClient.cs
@@ -0,0 +1,71 @@
// Copyright (c) smallketchup82. Licensed under the GPLv3 Licence.
// See the LICENCE file in the repository root for full licence text.

using System.ClientModel;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.ML.Tokenizers;
using OpenAI.Chat;
using OpenAI.Moderations;

namespace galaxygpt;

public class AiClient(ChatClient chatClient, [FromKeyedServices("gptTokenizer")] TiktokenTokenizer gptTokenizer, ModerationClient? moderationClient = null)
{
public async Task<string> AnswerQuestion(string question, string context, int maxInputTokens, string? username = null, int? maxOutputTokens = null)
{
#region Sanitize & Check the question

question = question.Trim();

if (string.IsNullOrWhiteSpace(question))
throw new ArgumentException("The question cannot be empty.");

if (gptTokenizer.CountTokens(question) > maxInputTokens)
throw new ArgumentException("The question is too long to be answered.");

// Throw the question into the moderation API
if (moderationClient != null)
{
ClientResult<ModerationResult> moderation = await moderationClient.ClassifyTextInputAsync(question);

if (moderation.Value.Flagged)
throw new InvalidOperationException("The question was flagged by the moderation API.");
}
else
Console.WriteLine("Warning: No moderation client was provided. Skipping moderation check. This can be dangerous.");
#endregion

List<ChatMessage> messages =
[
new SystemChatMessage("""
You are GalaxyGPT, a helpful assistant that answers questions about Galaxy, a ROBLOX Space Game.
The Galaxypedia is the game's official wiki and it is your creator.
The Galaxypedia's slogans are "The new era of the Galaxy Wiki" and "A hub for all things Galaxy".
Answer the question based on the supplied context. If the question cannot be answered, politely say you don't know the answer and ask the user for clarification, or if they have any further questions about Galaxy.
If the user has a username, it will be provided and you can address them by it. If a username is not provided (it shows as N/A), do not address/refer the user apart from "you" or "your".
Do not reference or mention the "context provided" in your response, no matter what.
The context will be given in the format of wikitext. You will be given multiple different pages in your context to work with. The different pages will be separated by "###".
If a ship infobox is present in the context, prefer using data from within the infobox. An infobox can be found by looking for a wikitext template that has the word "infobox" in its name.
If the user is not asking a question (e.g. "thank you", "thanks for the help"): Respond to it and ask the user if they have any further questions.
Respond to greetings (e.g. "hi", "hello") with (in this exact order): A greeting, a brief description of yourself, and a question addressed to the user if they have a question or need assistance.
Above all, be polite and helpful to the user.
Steps for responding:
First check if the user is asking about a ship (e.g. "what is the deity?", "how much shield does the theia have?"), if so, use the ship's wiki page (supplied in the context) and the statistics from the ship's infobox to answer the question.
If you determine the user is not asking about a ship (e.g. "who is <player>?", "what is <item>?"), do your best to answer the question based on the context provided.
"""),
new UserChatMessage($"Context:\n{context.Trim()}\n\n---\n\nQuestion: {question}\nUsername: {username ?? "N/A"}")
{
ParticipantName = username
}
];

ClientResult<ChatCompletion> completion = await chatClient.CompleteChatAsync(messages, new ChatCompletionOptions
{
MaxTokens = maxOutputTokens
});
messages.Add(new AssistantChatMessage(completion));

return messages[^1].Content[0].Text;
}

}
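As a usage note, AiClient does not depend on the web host and can be wired up by hand, e.g. in a test or console tool. A minimal sketch assuming OPENAI_API_KEY is set in the environment; the model names mirror the defaults registered in galaxygpt-api/Program.cs, and the question, context, and username values are illustrative only:

using galaxygpt;
using Microsoft.ML.Tokenizers;
using OpenAI;

// Hypothetical stand-alone wiring of AiClient; in the API this is done by the DI container.
var openAi = new OpenAIClient(Environment.GetEnvironmentVariable("OPENAI_API_KEY")
                              ?? throw new InvalidOperationException("OPENAI_API_KEY is not set."));

var aiClient = new AiClient(
    openAi.GetChatClient("gpt-4o-mini"),
    TiktokenTokenizer.CreateForModel("gpt-4o-mini"),
    openAi.GetModerationClient("text-moderation-stable"));

string answer = await aiClient.AnswerQuestion(
    "What is the Galaxypedia?",
    "The Galaxypedia is the official wiki for Galaxy.",
    maxInputTokens: 4096,
    username: "example-user");

Console.WriteLine(answer);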
92 changes: 92 additions & 0 deletions galaxygpt/ContextManager.cs
@@ -0,0 +1,92 @@
// Copyright (c) smallketchup82. Licensed under the GPLv3 Licence.
// See the LICENCE file in the repository root for full licence text.

using System.ClientModel;
using System.Numerics.Tensors;
using System.Text;
using galaxygpt.Database;
using Microsoft.EntityFrameworkCore;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.ML.Tokenizers;
using OpenAI;
using OpenAI.Embeddings;

namespace galaxygpt;

/// <summary>
/// Handles context management
/// </summary>
public class ContextManager(VectorDb db, EmbeddingClient embeddingClient, [FromKeyedServices("gptTokenizer")] TiktokenTokenizer gptTokenizer, [FromKeyedServices("embeddingsTokenizer")] TiktokenTokenizer embeddingsTokenizer)
{
/// <summary>
/// Load all pages from the database into memory
/// </summary>
/// <remarks>
/// Honestly, I tried to avoid this, but considering we'll be doing cosine similarity on everything anyway, it's better to load everything into memory.
/// </remarks>
private List<Page> _pages = db.Pages.Include(chunk => chunk.Chunks).ToList();

public async Task<(string, int)> FetchContext(string question, int? maxLength = null)
{
question = question.Trim();

if (string.IsNullOrEmpty(question))
throw new ArgumentException("The question cannot be empty.");

if (!db.Pages.Any())
throw new InvalidOperationException("The database is empty. Please load a dataset first.");

ClientResult<Embedding>? questionEmbeddings = await embeddingClient.GenerateEmbeddingAsync(question);

var pageEmbeddings = new List<(Page page, float[] embeddings, int chunkId, float distance)>();
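// A chunkId of -1 marks a page-level embedding (the page has no chunks).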

foreach (Page page in db.Pages.Include(chunk => chunk.Chunks))
{
if (page.Chunks == null || page.Chunks.Count == 0)
{
if (page.Embeddings == null) continue;

float distance = TensorPrimitives.CosineSimilarity(questionEmbeddings.Value.Vector.ToArray(), page.Embeddings.ToArray());
pageEmbeddings.Add((page, page.Embeddings.ToArray(), -1, distance));
}
else if (page.Chunks != null)
{
foreach (Chunk chunk in page.Chunks)
{
if (chunk.Embeddings == null) continue;

float distance = TensorPrimitives.CosineSimilarity(questionEmbeddings.Value.Vector.ToArray(), chunk.Embeddings.ToArray());
pageEmbeddings.Add((page, chunk.Embeddings.ToArray(), chunk.Id, distance));
}
}
}

pageEmbeddings.Sort((a, b) => b.distance.CompareTo(a.distance));

StringBuilder context = new();
int tokenCount = gptTokenizer.CountTokens(question);
int iterations = 0;
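// When no maxLength is given, fall back to including the five highest-ranked chunks.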

foreach ((Page page, float[] _, int chunkId, float _) in pageEmbeddings)
{
string content = chunkId == -1 || page.Chunks == null || page.Chunks.Count == 0 ? page.Content : page.Chunks.First(chunk => chunk.Id == chunkId).Content;

if (maxLength == null)
{
if (iterations >= 5)
break;
}
else
{
tokenCount += gptTokenizer.CountTokens(content);
if (tokenCount > maxLength)
break;
}

context.Append($"Page: {page.Title}\nContent: {content}\n\n###\n\n");
iterations++;
}

return (context.ToString(), embeddingsTokenizer.CountTokens(question));
}
}
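Taken together with AiClient, the intended flow mirrors the API endpoint: fetch context for a question, then answer against it. A minimal sketch assuming both services have already been resolved from a service provider named services, and using the optional maxLength token budget (the budget value of 2048 is an assumption, not a default from the commit):

using Microsoft.Extensions.DependencyInjection;

// Resolve the services registered in galaxygpt-api/Program.cs (assumed to be available via 'services').
var contextManager = services.GetRequiredService<ContextManager>();
var aiClient = services.GetRequiredService<AiClient>();

const string question = "How much shield does the Theia have?";

// Cap the assembled context at roughly 2048 tokens instead of the default five chunks.
(string context, int _) = await contextManager.FetchContext(question, maxLength: 2048);
string answer = await aiClient.AnswerQuestion(question, context, maxInputTokens: 4096);

Console.WriteLine(answer);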
24 changes: 24 additions & 0 deletions galaxygpt/Database/Metadata.cs
@@ -0,0 +1,24 @@
// Copyright (c) smallketchup82. Licensed under the GPLv3 Licence.
// See the LICENCE file in the repository root for full licence text.

namespace galaxygpt.Database;

public class Metadata
{
public int Id { get; init; }

/// <summary>
/// The name of the dataset (typically something like "dataset-v1")
/// </summary>
public required string DatasetName { get; init; }

/// <summary>
/// The date and time the dataset was created at. Use UTC time.
/// </summary>
public required DateTime CreatedAt { get; init; }

/// <summary>
/// The maximum size of each chunk
/// </summary>
public required int ChunkMaxSize { get; init; }
}
1 change: 1 addition & 0 deletions galaxygpt/Database/VectorDB.cs
@@ -9,6 +9,7 @@ public class VectorDb(string? path = null) : DbContext
{
public DbSet<Page> Pages { get; set; }
public DbSet<Chunk> Chunks { get; set; }
public DbSet<Metadata> Metadata { get; set; }
private readonly string _dbPath = "Data Source=" + (path ?? Path.Join(Environment.GetFolderPath(Environment.SpecialFolder.LocalApplicationData), "embeddings.db"));

protected override void OnConfiguring(DbContextOptionsBuilder options)
