diff --git a/dataset-assistant/Program.cs b/dataset-assistant/Program.cs index 08d77c1..e5a0001 100644 --- a/dataset-assistant/Program.cs +++ b/dataset-assistant/Program.cs @@ -331,6 +331,19 @@ private static async Task Main(string[] args) #endregion + #region Add Metadata + + db.Metadata.Add(new Metadata + { + DatasetName = datasetNameOptionValue, + CreatedAt = DateTime.UtcNow, + ChunkMaxSize = maxtokens + }); + + await db.SaveChangesAsync(); + + #endregion + globalProgressBar.Tick("Done"); }); diff --git a/galaxygpt-api/Program.cs b/galaxygpt-api/Program.cs index b50a116..9bf1387 100644 --- a/galaxygpt-api/Program.cs +++ b/galaxygpt-api/Program.cs @@ -4,7 +4,12 @@ using System.Reflection; using Asp.Versioning.Builder; using galaxygpt; +using galaxygpt.Database; +using Microsoft.AspNetCore.Diagnostics; using Microsoft.Extensions.Options; +using Microsoft.ML.Tokenizers; +using OpenAI; +using Sentry.Profiling; using Swashbuckle.AspNetCore.SwaggerGen; namespace galaxygpt_api; @@ -30,10 +35,49 @@ public static void Main(string[] args) builder.Services.AddTransient, ConfigureSwaggerOptions>(); builder.Services.AddSwaggerGen(options => options.OperationFilter()); - builder.Services.AddMemoryCache(); - builder.Configuration.AddJsonFile("appsettings.json", optional: false, reloadOnChange: true); + #region Configuration + + IConfigurationRoot configuration = new ConfigurationBuilder() + .AddJsonFile("appsettings.json", optional: true, reloadOnChange: true) + .AddEnvironmentVariables() + .AddUserSecrets() + .Build(); + + builder.Configuration.Sources.Clear(); + builder.Configuration.AddConfiguration(configuration); + + #endregion + + builder.WebHost.UseSentry(o => + { + o.Dsn = configuration["SENTRY_DSN"] ?? 
"https://1df72bed08400836796f15c03748d195@o4507833886834688.ingest.us.sentry.io/4507833934544896"; +#if DEBUG + o.Debug = true; +#endif + o.TracesSampleRate = 1.0; + o.ProfilesSampleRate = 1.0; + o.AddIntegration(new ProfilingIntegration()); + }); + + #region GalaxyGPT Services + + var openAiClient = new OpenAIClient(configuration["OPENAI_API_KEY"] ?? throw new InvalidOperationException()); + string gptModel = configuration["GPT_MODEL"] ?? "gpt-4o-mini"; + string textEmbeddingModel = configuration["TEXT_EMBEDDING_MODEL"] ?? "text-embedding-3-small"; + string moderationModel = configuration["MODERATION_MODEL"] ?? "text-moderation-stable"; + + builder.Services.AddSingleton(new VectorDb()); + builder.Services.AddSingleton(openAiClient.GetChatClient(gptModel)); + builder.Services.AddSingleton(openAiClient.GetEmbeddingClient(textEmbeddingModel)); + builder.Services.AddSingleton(openAiClient.GetModerationClient(moderationModel)); + builder.Services.AddKeyedSingleton("gptTokenizer", TiktokenTokenizer.CreateForModel("gpt-4o-mini")); + builder.Services.AddKeyedSingleton("embeddingsTokenizer", TiktokenTokenizer.CreateForModel("text-embedding-3-small")); + builder.Services.AddSingleton(); + builder.Services.AddSingleton(); + + #endregion WebApplication app = builder.Build(); IVersionedEndpointRouteBuilder versionedApi = app.NewVersionedApi("galaxygpt"); @@ -47,13 +91,16 @@ public static void Main(string[] args) #region API RouteGroupBuilder v1 = versionedApi.MapGroup("/api/v{version:apiVersion}").HasApiVersion(1.0); + var galaxyGpt = app.Services.GetRequiredService(); + var contextManager = app.Services.GetRequiredService(); + v1.MapPost("ask", async (AskPayload askPayload) => { if (string.IsNullOrEmpty(askPayload.Prompt)) return Results.BadRequest("The question cannot be empty."); - (string, int) context = await GalaxyGpt.FetchContext(askPayload.Prompt, "text-embedding-3-small"); - string answer = await GalaxyGpt.AnswerQuestion(askPayload.Prompt, context.Item1, 
askPayload.Model ?? app.Configuration["MODEL"] ?? throw new InvalidOperationException(), 4096, 4096, username: askPayload.Username); + (string, int) context = await contextManager.FetchContext(askPayload.Prompt); + string answer = await galaxyGpt.AnswerQuestion(askPayload.Prompt, context.Item1, 4096, username: askPayload.Username); var results = new Dictionary { diff --git a/galaxygpt-api/galaxygpt-api.csproj b/galaxygpt-api/galaxygpt-api.csproj index 0203a9c..a50e5ee 100644 --- a/galaxygpt-api/galaxygpt-api.csproj +++ b/galaxygpt-api/galaxygpt-api.csproj @@ -11,6 +11,8 @@ + + diff --git a/galaxygpt/AiClient.cs b/galaxygpt/AiClient.cs new file mode 100644 index 0000000..93c932d --- /dev/null +++ b/galaxygpt/AiClient.cs @@ -0,0 +1,71 @@ +// Copyright (c) smallketchup82. Licensed under the GPLv3 Licence. +// See the LICENCE file in the repository root for full licence text. + +using System.ClientModel; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.ML.Tokenizers; +using OpenAI.Chat; +using OpenAI.Moderations; + +namespace galaxygpt; + +public class AiClient(ChatClient chatClient, [FromKeyedServices("gptTokenizer")] TiktokenTokenizer gptTokenizer, ModerationClient? moderationClient = null) +{ + public async Task<string> AnswerQuestion(string question, string context, int maxInputTokens, string? username = null, int? 
maxOutputTokens = null) + { + #region Sanitize & Check the question + + question = question.Trim(); + + if (string.IsNullOrWhiteSpace(question)) + throw new ArgumentException("The question cannot be empty."); + + if (gptTokenizer.CountTokens(question) > maxInputTokens) + throw new ArgumentException("The question is too long to be answered."); + + // Throw the question into the moderation API + if (moderationClient != null) + { + ClientResult<ModerationResult> moderation = await moderationClient.ClassifyTextInputAsync(question); + + if (moderation.Value.Flagged) + throw new InvalidOperationException("The question was flagged by the moderation API."); + } else + Console.WriteLine("Warning: No moderation client was provided. Skipping moderation check. This can be dangerous"); + #endregion + + List<ChatMessage> messages = + [ + new SystemChatMessage(""" + You are GalaxyGPT, a helpful assistant that answers questions about Galaxy, a ROBLOX Space Game. + The Galaxypedia is the game's official wiki and it is your creator. + The Galaxypedia's slogans are "The new era of the Galaxy Wiki" and "A hub for all things Galaxy". + Answer the question based on the supplied context. If the question cannot be answered, politely say you don't know the answer and ask the user for clarification, or if they have any further questions about Galaxy. + If the user has a username, it will be provided and you can address them by it. If a username is not provided (it shows as N/A), do not address/refer the user apart from "you" or "your". + Do not reference or mention the "context provided" in your response, no matter what. + The context will be given in the format of wikitext. You will be given multiple different pages in your context to work with. The different pages will be separated by "###". + If a ship infobox is present in the context, prefer using data from within the infobox. An infobox can be found by looking for a wikitext template that has the word "infobox" in its name. 
+ If the user is not asking a question (e.g. "thank you", "thanks for the help"): Respond to it and ask the user if they have any further questions. + Respond to greetings (e.g. "hi", "hello") with (in this exact order): A greeting, a brief description of yourself, and a question addressed to the user if they have a question or need assistance. + Above all, be polite and helpful to the user. + + Steps for responding: + First check if the user is asking about a ship (e.g. "what is the deity?", "how much shield does the theia have?"), if so, use the ship's wiki page (supplied in the context) and the statistics from the ship's infobox to answer the question. + If you determine the user is not asking about a ship (e.g. "who is ?", "what is ?"), do your best to answer the question based on the context provided. + """), + new UserChatMessage($"Context:\n{context.Trim()}\n\n---\n\nQuestion: {question}\nUsername: {username ?? "N/A"}") + { + ParticipantName = username ?? null + } + ]; + + ClientResult<ChatCompletion>? idk = await chatClient.CompleteChatAsync(messages, new ChatCompletionOptions + { + MaxTokens = maxOutputTokens + }); + messages.Add(new AssistantChatMessage(idk)); + + return messages[^1].Content[0].Text; + } + +} \ No newline at end of file diff --git a/galaxygpt/ContextManager.cs b/galaxygpt/ContextManager.cs new file mode 100644 index 0000000..816550c --- /dev/null +++ b/galaxygpt/ContextManager.cs @@ -0,0 +1,92 @@ +// Copyright (c) smallketchup82. Licensed under the GPLv3 Licence. +// See the LICENCE file in the repository root for full licence text. 
+ +using System.ClientModel; +using System.Numerics.Tensors; +using System.Text; +using galaxygpt.Database; +using Microsoft.EntityFrameworkCore; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.ML.Tokenizers; +using OpenAI; +using OpenAI.Embeddings; + +namespace galaxygpt; + +/// <summary> +/// Handles context management +/// </summary> +public class ContextManager(VectorDb db, EmbeddingClient embeddingClient, [FromKeyedServices("gptTokenizer")] TiktokenTokenizer gptTokenizer, [FromKeyedServices("embeddingsTokenizer")] TiktokenTokenizer embeddingsTokenizer) +{ + /// <summary> + /// Load all pages from the database into memory + /// </summary> + /// <remarks> + /// Honestly, I tried to avoid this, but considering we'll be doing cosine similarity on everything anyway, it's better to load everything into memory. + /// </remarks> + private List<Page> _pages = db.Pages.Include(chunk => chunk.Chunks).ToList(); + + public async Task<(string, int)> FetchContext(string question, int? maxLength = null) + { + question = question.Trim(); + + if (string.IsNullOrEmpty(question)) + throw new ArgumentException("The question cannot be empty."); + + if (!db.Pages.Any()) + throw new InvalidOperationException("The database is empty. Please load a dataset first."); + + ClientResult<Embedding>? 
questionEmbeddings = await embeddingClient.GenerateEmbeddingAsync(question); + + var pageEmbeddings = new List<(Page page, float[] embeddings, int chunkId, float distance)>(); + + foreach (Page page in db.Pages.Include(chunk => chunk.Chunks)) + { + if (page.Chunks == null || page.Chunks.Count == 0) + { + if (page.Embeddings == null) continue; + + float distance = TensorPrimitives.CosineSimilarity(questionEmbeddings.Value.Vector.ToArray(), page.Embeddings.ToArray()); + pageEmbeddings.Add((page, page.Embeddings.ToArray(), -1, distance)); + } + else if (page.Chunks != null) + { + foreach (Chunk chunk in page.Chunks) + { + if (chunk.Embeddings == null) continue; + + float distance = TensorPrimitives.CosineSimilarity(questionEmbeddings.Value.Vector.ToArray(), chunk.Embeddings.ToArray()); + pageEmbeddings.Add((page, chunk.Embeddings.ToArray(), chunk.Id, distance)); + } + } + } + + pageEmbeddings.Sort((a, b) => b.distance.CompareTo(a.distance)); + + StringBuilder context = new(); + int tokenCount = gptTokenizer.CountTokens(question); + int iterations = 0; + + foreach ((Page page, float[] _, int chunkId, float _) in pageEmbeddings) + { + string content = chunkId == -1|| page.Chunks == null || page.Chunks.Count == 0 ? page.Content : page.Chunks.First(chunk => chunk.Id == chunkId).Content; + + if (maxLength == null) + { + if (iterations >= 5) + break; + } + else + { + tokenCount += gptTokenizer.CountTokens(content); + if (tokenCount > maxLength) + break; + } + + context.Append($"Page: {page.Title}\nContent: {content}\n\n###\n\n"); + iterations++; + } + + return (context.ToString(), embeddingsTokenizer.CountTokens(question)); + } +} \ No newline at end of file diff --git a/galaxygpt/Database/Metadata.cs b/galaxygpt/Database/Metadata.cs new file mode 100644 index 0000000..6cf0d0a --- /dev/null +++ b/galaxygpt/Database/Metadata.cs @@ -0,0 +1,24 @@ +// Copyright (c) smallketchup82. Licensed under the GPLv3 Licence. 
+// See the LICENCE file in the repository root for full licence text. + +namespace galaxygpt.Database; + +public class Metadata +{ + public int Id { get; init; } + + /// <summary> + /// The name of the dataset (typically something like "dataset-v1") + /// </summary> + public required string DatasetName { get; init; } + + /// <summary> + /// The date and time the dataset was created at. Use UTC time. + /// </summary> + public required DateTime CreatedAt { get; init; } + + /// <summary> + /// The maximum size of each chunk + /// </summary> + public required int ChunkMaxSize { get; init; } +} \ No newline at end of file diff --git a/galaxygpt/Database/VectorDB.cs b/galaxygpt/Database/VectorDB.cs index 5ce0e49..1aaf471 100644 --- a/galaxygpt/Database/VectorDB.cs +++ b/galaxygpt/Database/VectorDB.cs @@ -9,6 +9,7 @@ public class VectorDb(string? path = null) : DbContext { public DbSet<Page> Pages { get; set; } public DbSet<Chunk> Chunks { get; set; } + public DbSet<Metadata> Metadata { get; set; } private readonly string _dbPath = "Data Source=" + (path ?? Path.Join(Environment.GetFolderPath(Environment.SpecialFolder.LocalApplicationData), "embeddings.db")); protected override void OnConfiguring(DbContextOptionsBuilder options)