Add parsing of max model context length
michalkulakowski committed Nov 21, 2024
1 parent b7b57e7 commit 9d0624a
Showing 3 changed files with 42 additions and 2 deletions.
34 changes: 33 additions & 1 deletion src/llm/apis/openai_completions.cpp
@@ -389,7 +389,12 @@ void OpenAIChatCompletionsHandler::incrementProcessedTokens(int numTokens) {
 }

 ov::genai::GenerationConfig OpenAIChatCompletionsHandler::createGenerationConfig() const {
-    return request.createGenerationConfig();
+    ov::genai::GenerationConfig config = request.createGenerationConfig();
+    if (maxModelLength.has_value()) {
+        config.max_length = maxModelLength.value();
+        SPDLOG_LOGGER_DEBUG(llm_calculator_logger, "Max model length {}", maxModelLength.value());
+    }
+    return config;
 }

 absl::Status OpenAIChatCompletionsHandler::parseRequest(uint32_t maxTokensLimit, uint32_t bestOfLimit) {
@@ -406,6 +411,33 @@ absl::Status OpenAIChatCompletionsHandler::parseRequest(uint32_t maxTokensLimit,
     return status;
 }

+void OpenAIChatCompletionsHandler::parseMaxModelLength(const std::string& modelsPath) {
+    std::string configPath = modelsPath + "/config.json";
+    if (std::filesystem::exists(configPath)) {
+        std::ifstream ifs(configPath);
+        if (!ifs.is_open()) {
+            return;
+        }
+        rapidjson::Document modelConfig;
+        rapidjson::IStreamWrapper isw(ifs);
+        rapidjson::ParseResult parseResult = modelConfig.ParseStream(isw);
+        if (parseResult.Code()) {
+            return;
+        }
+        std::vector<std::string> maxLengthFields = {"max_position_embeddings", "n_positions", "seq_len", "seq_length", "n_ctx", "sliding_window"};
+        for (const auto& field : maxLengthFields) {
+            if (modelConfig.HasMember(field.c_str()) && modelConfig[field.c_str()].IsUint()) {
+                maxModelLength = modelConfig[field.c_str()].GetUint();
+            }
+        }
+    }
+    return;
+}
+
+std::optional<uint> OpenAIChatCompletionsHandler::getMaxModelLength() {
+    return maxModelLength;
+}
+
 std::string OpenAIChatCompletionsHandler::serializeUnaryResponse(const std::vector<ov::genai::GenerationOutput>& generationOutputs) {
     OVMS_PROFILE_FUNCTION();
     StringBuffer buffer;
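A note on the field scan above: the loop does not break on the first match, so when config.json defines more than one of the recognized fields, the last one in the list takes effect. A minimal standalone sketch (not part of the commit; the JSON content below is hypothetical) that reproduces this precedence with rapidjson:

// Standalone sketch, not part of the commit: shows that the last matching
// field in maxLengthFields wins, since the scan loop never breaks.
#include <iostream>
#include <optional>
#include <string>
#include <vector>

#include <rapidjson/document.h>

int main() {
    // Hypothetical config.json content containing two of the recognized fields.
    const char* json = R"({"max_position_embeddings": 4096, "sliding_window": 2048})";

    rapidjson::Document modelConfig;
    modelConfig.Parse(json);

    std::optional<unsigned> maxModelLength;
    const std::vector<std::string> maxLengthFields = {
        "max_position_embeddings", "n_positions", "seq_len",
        "seq_length", "n_ctx", "sliding_window"};
    for (const auto& field : maxLengthFields) {
        if (modelConfig.HasMember(field.c_str()) && modelConfig[field.c_str()].IsUint()) {
            maxModelLength = modelConfig[field.c_str()].GetUint();  // later matches overwrite earlier ones
        }
    }

    std::cout << maxModelLength.value() << std::endl;  // prints 2048 (sliding_window)
    return 0;
}

With this ordering, a model whose config.json reports both max_position_embeddings and sliding_window ends up using the sliding_window value.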
7 changes: 7 additions & 0 deletions src/llm/apis/openai_completions.hpp
@@ -22,12 +22,15 @@
 #include <string>
 #include <unordered_map>
 #include <vector>
+#include <fstream>

 #include <openvino/genai/generation_config.hpp>
 #include <openvino/genai/generation_handle.hpp>
 #include <openvino/genai/tokenizer.hpp>
 #include <rapidjson/document.h>
 #include <rapidjson/writer.h>
+#include <rapidjson/error/en.h>
+#include <rapidjson/istreamwrapper.h>

 #include "absl/status/status.h"

@@ -156,6 +159,7 @@ class OpenAIChatCompletionsHandler {
     std::chrono::time_point<std::chrono::system_clock> created;
     ov::genai::Tokenizer tokenizer;
     size_t processedTokens = 0; // tracks overall number of tokens processed by the pipeline
+    std::optional<uint> maxModelLength;

     absl::Status parseCompletionsPart();
     absl::Status parseChatCompletionsPart();
@@ -184,6 +188,9 @@

     absl::Status parseRequest(uint32_t maxTokensLimit, uint32_t bestOfLimit);

+    void parseMaxModelLength(const std::string& modelsPath);
+    std::optional<uint> getMaxModelLength();
+
     std::string serializeUnaryResponse(const std::vector<ov::genai::GenerationOutput>& generationOutputs);
     std::string serializeStreamingChunk(const std::string& chunkResponse, ov::genai::GenerationFinishReason finishReason);
     std::string serializeStreamingUsageChunk();
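The new member is an std::optional rather than a plain integer so callers can tell "no context limit found in config.json" apart from an actual limit; createGenerationConfig() only overrides max_length when a value is present. A tiny standalone illustration of that pattern (the values below are made up):

// Standalone illustration of the std::optional pattern used for maxModelLength.
#include <iostream>
#include <optional>

int main() {
    std::optional<unsigned> maxModelLength;  // empty until config.json yields a value

    if (!maxModelLength.has_value()) {
        std::cout << "no model context limit known" << std::endl;
    }

    maxModelLength = 4096;  // hypothetical value parsed from config.json
    if (maxModelLength.has_value()) {
        std::cout << "capping max_length at " << maxModelLength.value() << std::endl;
    }
    return 0;
}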
3 changes: 2 additions & 1 deletion src/llm/http_llm_calculator.cc
@@ -133,6 +133,7 @@ class HttpLLMCalculator : public CalculatorBase {

     std::string finalPrompt = "";
     bool encodeAddSpecialTokens = false;
+
     switch (endpoint) {
         case Endpoint::CHAT_COMPLETIONS: {
             if (!TextProcessor::applyChatTemplate(this->nodeResources->textProcessor, this->nodeResources->modelsPath, payload.body, finalPrompt)) {
@@ -160,7 +161,7 @@
     ov::Tensor finalPromptIds = nodeResources->cbPipe->get_tokenizer().encode(finalPrompt, ov::genai::add_special_tokens(encodeAddSpecialTokens)).input_ids;
     this->apiHandler->setPromptTokensUsage(finalPromptIds.get_size());
     SPDLOG_LOGGER_TRACE(llm_calculator_logger, "{}", getPromptTokensString(finalPromptIds));
-
+    this->apiHandler->parseMaxModelLength(this->nodeResources->modelsPath);
     this->generationHandle = nodeResources->cbPipe->add_request(
         currentRequestId++, /*to be removed from API?*/
         finalPromptIds,
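Taken together, the calculator parses the model directory's config.json before adding the request to the pipeline, and the parsed limit is applied when the generation config is built. A simplified stand-in sketch of that flow (the structs below are illustrative placeholders, not the OVMS OpenAIChatCompletionsHandler or ov::genai::GenerationConfig APIs):

// Simplified stand-in sketch, not the real OVMS/ov::genai types: shows how the
// parsed context length flows from parseMaxModelLength() into the generation config.
#include <cstddef>
#include <iostream>
#include <optional>

struct GenerationConfig {      // placeholder for ov::genai::GenerationConfig
    std::size_t max_length = 0;
};

struct Handler {               // placeholder for OpenAIChatCompletionsHandler
    std::optional<unsigned> maxModelLength;

    void parseMaxModelLength(unsigned valueFromConfigJson) {
        // In the commit this value is read from <modelsPath>/config.json.
        maxModelLength = valueFromConfigJson;
    }

    GenerationConfig createGenerationConfig() const {
        GenerationConfig config;
        if (maxModelLength.has_value()) {
            config.max_length = maxModelLength.value();  // cap at the model's context window
        }
        return config;
    }
};

int main() {
    Handler handler;
    handler.parseMaxModelLength(4096);  // hypothetical context length
    GenerationConfig config = handler.createGenerationConfig();
    std::cout << config.max_length << std::endl;  // 4096
    return 0;
}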
