Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
mzegla committed Jan 17, 2025
1 parent 9a1f68a commit 6912a7d
Showing 1 changed file with 9 additions and 10 deletions.
19 changes: 9 additions & 10 deletions src/llm/llmnoderesources.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -156,8 +156,8 @@ Status LLMNodeResources::initializeLLMNodeResources(LLMNodeResources& nodeResour
auto draftSchedulerConfig = prepareDraftModelSchedulerConfig(nodeOptions);
auto draftModelConfig = ov::genai::draft_model(nodeOptions.draft_models_path(), nodeOptions.draft_device(),
ov::genai::scheduler_config(draftSchedulerConfig));
nodeResources->pluginConfig.insert(draftModelConfig);
nodeResources->isSpeculativePipeline = true;
nodeResources.pluginConfig.insert(draftModelConfig);
nodeResources.isSpeculativePipeline = true;
}

auto status = JsonParser::parsePluginConfig(nodeOptions.plugin_config(), nodeResources.pluginConfig);
Expand Down Expand Up @@ -214,14 +214,13 @@ std::unordered_map<std::string, std::string> LLMNodeResources::prepareLLMNodeIni
}

// Builds the scheduler configuration for the draft model used in speculative decoding.
// For every tunable field, a draft-specific option (draft_*) takes precedence when set;
// otherwise the main model's corresponding setting is inherited, so the draft scheduler
// stays consistent with the target model unless explicitly overridden.
// NOTE(review): block_size is no longer assigned here (it was in the previous revision) —
// presumably dropped from / defaulted by ov::genai::SchedulerConfig; confirm against the
// GenAI API version in use.
ov::genai::SchedulerConfig LLMNodeResources::prepareDraftModelSchedulerConfig(const mediapipe::LLMCalculatorOptions& nodeOptions) {
    ov::genai::SchedulerConfig config;
    config.max_num_batched_tokens = nodeOptions.has_draft_max_num_batched_tokens() ? nodeOptions.draft_max_num_batched_tokens() : nodeOptions.max_num_batched_tokens();
    config.cache_size = nodeOptions.has_draft_cache_size() ? nodeOptions.draft_cache_size() : nodeOptions.cache_size();
    config.dynamic_split_fuse = nodeOptions.has_draft_dynamic_split_fuse() ? nodeOptions.draft_dynamic_split_fuse() : nodeOptions.dynamic_split_fuse();
    config.max_num_seqs = nodeOptions.has_draft_max_num_seqs() ? nodeOptions.draft_max_num_seqs() : nodeOptions.max_num_seqs();
    // No draft-specific override exists for prefix caching; always mirror the main model.
    config.enable_prefix_caching = nodeOptions.enable_prefix_caching();
    return config;
}

} // namespace ovms

0 comments on commit 6912a7d

Please sign in to comment.