Skip to content

Commit

Permalink
Add parsing of max model context length
Browse files Browse the repository at this point in the history
  • Loading branch information
michalkulakowski committed Dec 12, 2024
1 parent eb054f5 commit c3e3fe7
Show file tree
Hide file tree
Showing 7 changed files with 185 additions and 5 deletions.
1 change: 1 addition & 0 deletions src/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -1906,6 +1906,7 @@ cc_test(
"//conditions:default" : [
# LLM logic uses Python for processing Jinja templates
"test/llmnode_test.cpp",
"test/max_model_length_test.cpp",
"test/llmtemplate_test.cpp",
"test/text_streamer_test.cpp",],
}) + select({
Expand Down
9 changes: 7 additions & 2 deletions src/llm/apis/openai_completions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -388,8 +388,13 @@ void OpenAIChatCompletionsHandler::incrementProcessedTokens(int numTokens) {
usage.completionTokens += numTokens;
}

ov::genai::GenerationConfig OpenAIChatCompletionsHandler::createGenerationConfig() const {
return request.createGenerationConfig();
ov::genai::GenerationConfig OpenAIChatCompletionsHandler::createGenerationConfig(std::optional<uint> maxModelLength) const {
ov::genai::GenerationConfig config = request.createGenerationConfig();
if (maxModelLength.has_value()) {
config.max_length = maxModelLength.value();
SPDLOG_LOGGER_TRACE(llm_calculator_logger, "Parsed max model length {}", maxModelLength.value());
}
return config;
}

absl::Status OpenAIChatCompletionsHandler::parseRequest(uint32_t maxTokensLimit, uint32_t bestOfLimit) {
Expand Down
3 changes: 2 additions & 1 deletion src/llm/apis/openai_completions.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,7 @@ class OpenAIChatCompletionsHandler {
std::chrono::time_point<std::chrono::system_clock> created;
ov::genai::Tokenizer tokenizer;
size_t processedTokens = 0; // tracks overall number of tokens processed by the pipeline
std::optional<uint> maxModelLength;

absl::Status parseCompletionsPart();
absl::Status parseChatCompletionsPart();
Expand All @@ -180,7 +181,7 @@ class OpenAIChatCompletionsHandler {

void incrementProcessedTokens(int numTokens = 1);

ov::genai::GenerationConfig createGenerationConfig() const;
ov::genai::GenerationConfig createGenerationConfig(std::optional<uint> maxModelLength) const;

absl::Status parseRequest(uint32_t maxTokensLimit, uint32_t bestOfLimit);

Expand Down
3 changes: 1 addition & 2 deletions src/llm/http_llm_calculator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -160,11 +160,10 @@ class HttpLLMCalculator : public CalculatorBase {
ov::Tensor finalPromptIds = nodeResources->cbPipe->get_tokenizer().encode(finalPrompt, ov::genai::add_special_tokens(encodeAddSpecialTokens)).input_ids;
this->apiHandler->setPromptTokensUsage(finalPromptIds.get_size());
SPDLOG_LOGGER_TRACE(llm_calculator_logger, "{}", getPromptTokensString(finalPromptIds));

this->generationHandle = nodeResources->cbPipe->add_request(
currentRequestId++, /*to be removed from API?*/
finalPromptIds,
this->apiHandler->createGenerationConfig());
this->apiHandler->createGenerationConfig(this->nodeResources->maxModelLength));

// TODO: Revert when drogon adds disconnection callbacks: https://github.com/drogonframework/drogon/pull/2204
// this->client->registerDisconnectionCallback([genHandle = this->generationHandle]() {
Expand Down
30 changes: 30 additions & 0 deletions src/llm/llmnoderesources.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,11 @@
#include "mediapipe/framework/calculator_graph.h"
#pragma GCC diagnostic pop

#include <fstream>

#include <rapidjson/error/en.h>
#include <rapidjson/istreamwrapper.h>

#include "../mediapipe_internal/mediapipe_utils.hpp"
#include "src/llm/llm_calculator.pb.h"
#include "src/llm/llm_executor.hpp"
Expand Down Expand Up @@ -119,6 +124,30 @@ void LLMNodeResources::loadTextProcessor(LLMNodeResources& nodeResources, const
}
}

std::optional<uint> LLMNodeResources::parseMaxModelLength(std::string& modelsPath) {
std::string configPath = modelsPath + "/config.json";
std::optional<uint> maxModelLength;
if (std::filesystem::exists(configPath.c_str())) {
std::ifstream ifs(configPath);
if (!ifs.is_open()) {
return maxModelLength;
}
rapidjson::Document modelConfig;
rapidjson::IStreamWrapper isw(ifs);
rapidjson::ParseResult parseResult = modelConfig.ParseStream(isw);
if (parseResult.Code()) {
return maxModelLength;
}
std::vector<std::string> maxLengthFields = {"max_position_embeddings", "n_positions", "seq_len", "seq_length", "n_ctx", "sliding_window"};
for (auto field : maxLengthFields) {
if (modelConfig.HasMember(field.c_str()) && modelConfig[field.c_str()].IsUint()) {
maxModelLength = modelConfig[field.c_str()].GetUint();
}
}
}
return maxModelLength;
}

Status LLMNodeResources::initializeLLMNodeResources(LLMNodeResources& nodeResources, const ::mediapipe::CalculatorGraphConfig::Node& graphNodeConfig, std::string graphPath) {
mediapipe::LLMCalculatorOptions nodeOptions;
graphNodeConfig.node_options(0).UnpackTo(&nodeOptions);
Expand All @@ -144,6 +173,7 @@ Status LLMNodeResources::initializeLLMNodeResources(LLMNodeResources& nodeResour
SPDLOG_LOGGER_ERROR(modelmanager_logger, "LLM node models_path: {} is not a directory. ", basePath);
return StatusCode::LLM_NODE_DIRECTORY_DOES_NOT_EXIST;
}
nodeResources.maxModelLength = parseMaxModelLength(basePath);

nodeResources.schedulerConfig = {
.max_num_batched_tokens = nodeOptions.max_num_batched_tokens(),
Expand Down
2 changes: 2 additions & 0 deletions src/llm/llmnoderesources.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -112,9 +112,11 @@ struct LLMNodeResources {
TextProcessor textProcessor;
int maxTokensLimit;
int bestOfLimit;
std::optional<uint> maxModelLength;

static Status initializeLLMNodeResources(LLMNodeResources& nodeResources, const ::mediapipe::CalculatorGraphConfig::Node& graphNode, std::string graphPath);
static void loadTextProcessor(LLMNodeResources& nodeResources, const std::string& chatTemplateDirectory);
static std::optional<uint> parseMaxModelLength(std::string& modelsPath);

LLMNodeResources(const LLMNodeResources&) = delete;
LLMNodeResources& operator=(LLMNodeResources&) = delete;
Expand Down
142 changes: 142 additions & 0 deletions src/test/max_model_length_test.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
//*****************************************************************************
// Copyright 2024 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include <rapidjson/document.h>

#include "../llm/llmnoderesources.hpp"
#include "test_utils.hpp"

using namespace ovms;

class MaxModelLengthTest : public TestWithTempDir {
protected:
std::string configFilePath;
rapidjson::Document doc;
ov::genai::Tokenizer dummyTokenizer;

void SetUp() {
TestWithTempDir::SetUp();
configFilePath = directoryPath + "/config.json";
}
};

TEST_F(MaxModelLengthTest, maxModelLength_MaxPositionEmbeddings_VALID) {
std::string modelConfigContent = R"({"max_position_embeddings" : 5})";
createConfigFileWithContent(modelConfigContent, configFilePath);
auto maxModelLength = LLMNodeResources::parseMaxModelLength(directoryPath);
ASSERT_TRUE(maxModelLength.has_value());
EXPECT_EQ(maxModelLength.value(), 5);
}

TEST_F(MaxModelLengthTest, maxModelLength_MaxPositionEmbeddings_INVALID) {
std::string modelConfigContent = R"({"max_position_embeddings" : "INVALID"})";
createConfigFileWithContent(modelConfigContent, configFilePath);
auto maxModelLength = LLMNodeResources::parseMaxModelLength(directoryPath);
EXPECT_FALSE(maxModelLength.has_value());
}

TEST_F(MaxModelLengthTest, maxModelLength_nPositions_VALID) {
std::string modelConfigContent = R"({"n_positions" : 5})";
createConfigFileWithContent(modelConfigContent, configFilePath);
auto maxModelLength = LLMNodeResources::parseMaxModelLength(directoryPath);
ASSERT_TRUE(maxModelLength.has_value());
EXPECT_EQ(maxModelLength.value(), 5);
}

TEST_F(MaxModelLengthTest, maxModelLength_nPositions_INVALID) {
std::string modelConfigContent = R"({"n_positions" : "INVALID"})";
createConfigFileWithContent(modelConfigContent, configFilePath);
auto maxModelLength = LLMNodeResources::parseMaxModelLength(directoryPath);
EXPECT_FALSE(maxModelLength.has_value());
}

TEST_F(MaxModelLengthTest, maxModelLength_seqLen_VALID) {
std::string modelConfigContent = R"({"seq_len" : 5})";
createConfigFileWithContent(modelConfigContent, configFilePath);
auto maxModelLength = LLMNodeResources::parseMaxModelLength(directoryPath);
ASSERT_TRUE(maxModelLength.has_value());
EXPECT_EQ(maxModelLength.value(), 5);
}

TEST_F(MaxModelLengthTest, maxModelLength_seqLen_INVALID) {
std::string modelConfigContent = R"({"seq_len" : "INVALID"})";
createConfigFileWithContent(modelConfigContent, configFilePath);
auto maxModelLength = LLMNodeResources::parseMaxModelLength(directoryPath);
EXPECT_FALSE(maxModelLength.has_value());
}

TEST_F(MaxModelLengthTest, maxModelLength_seqLength_VALID) {
std::string modelConfigContent = R"({"seq_length" : 5})";
createConfigFileWithContent(modelConfigContent, configFilePath);
auto maxModelLength = LLMNodeResources::parseMaxModelLength(directoryPath);
ASSERT_TRUE(maxModelLength.has_value());
EXPECT_EQ(maxModelLength.value(), 5);
}

TEST_F(MaxModelLengthTest, maxModelLength_seqLength_INVALID) {
std::string modelConfigContent = R"({"seq_length" : "INVALID"})";
createConfigFileWithContent(modelConfigContent, configFilePath);
auto maxModelLength = LLMNodeResources::parseMaxModelLength(directoryPath);
EXPECT_FALSE(maxModelLength.has_value());
}

TEST_F(MaxModelLengthTest, maxModelLength_nCtx_VALID) {
std::string modelConfigContent = R"({"n_ctx" : 5})";
createConfigFileWithContent(modelConfigContent, configFilePath);
auto maxModelLength = LLMNodeResources::parseMaxModelLength(directoryPath);
ASSERT_TRUE(maxModelLength.has_value());
EXPECT_EQ(maxModelLength.value(), 5);
}

TEST_F(MaxModelLengthTest, maxModelLength_nCtx_INVALID) {
std::string modelConfigContent = R"({"n_ctx" : "INVALID"})";
createConfigFileWithContent(modelConfigContent, configFilePath);
auto maxModelLength = LLMNodeResources::parseMaxModelLength(directoryPath);
EXPECT_FALSE(maxModelLength.has_value());
}

TEST_F(MaxModelLengthTest, maxModelLength_slidingWindow_VALID) {
std::string modelConfigContent = R"({"sliding_window" : 5})";
createConfigFileWithContent(modelConfigContent, configFilePath);
auto maxModelLength = LLMNodeResources::parseMaxModelLength(directoryPath);
ASSERT_TRUE(maxModelLength.has_value());
EXPECT_EQ(maxModelLength.value(), 5);
}

TEST_F(MaxModelLengthTest, maxModelLength_slidingWindow_INVALID) {
std::string modelConfigContent = R"({"sliding_window" : "INVALID"})";
createConfigFileWithContent(modelConfigContent, configFilePath);
auto maxModelLength = LLMNodeResources::parseMaxModelLength(directoryPath);
EXPECT_FALSE(maxModelLength.has_value());
}

TEST_F(MaxModelLengthTest, maxModelLength_emptyConfig) {
std::string modelConfigContent = R"({})";
createConfigFileWithContent(modelConfigContent, configFilePath);
auto maxModelLength = LLMNodeResources::parseMaxModelLength(directoryPath);
EXPECT_FALSE(maxModelLength.has_value());
}

TEST_F(MaxModelLengthTest, maxModelLength_parsingOrder) {
std::string modelConfigContent = R"({"max_position_embeddings" : 5, "seq_length" : 6, "n_positions" : 7, "sliding_window" : 8, "seq_len" : 9, "n_ctx" : 10})";
createConfigFileWithContent(modelConfigContent, configFilePath);
auto maxModelLength = LLMNodeResources::parseMaxModelLength(directoryPath);
ASSERT_TRUE(maxModelLength.has_value());
EXPECT_EQ(maxModelLength.value(), 8);
}

// TODO: Add e2e test

0 comments on commit c3e3fe7

Please sign in to comment.