Skip to content

Commit

Permalink
feat: add ModelConfig to LlmLsConfig
Browse files Browse the repository at this point in the history
  • Loading branch information
McPatate committed Feb 28, 2024
1 parent 8b87df6 commit baedf85
Show file tree
Hide file tree
Showing 3 changed files with 66 additions and 19 deletions.
41 changes: 35 additions & 6 deletions crates/llm-ls/src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,31 +6,60 @@ use tokio::fs::write;

use crate::error::Result;

/// Configuration for the embedding model used for snippet retrieval.
///
/// `Debug` is derived alongside the existing traits so the config can be
/// logged and inspected during troubleshooting.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub(crate) struct ModelConfig {
    /// Model identifier on the Hugging Face Hub (e.g. "intfloat/multilingual-e5-small").
    pub(crate) id: String,
    /// Hub revision to download (branch, tag or commit hash), e.g. "main".
    pub(crate) revision: String,
    /// Dimensionality of the embedding vectors; used when creating the vector collection.
    pub(crate) embeddings_size: usize,
    /// Maximum token count accepted by the model; longer encodings are truncated to this length.
    pub(crate) max_input_size: usize,
}

impl Default for ModelConfig {
fn default() -> Self {
Self {
id: "intfloat/multilingual-e5-small".to_string(),
revision: "main".to_string(),
embeddings_size: 384,
max_input_size: 512,
}
}
}

/// Top-level llm-ls configuration, loaded from `config.yaml` in the cache
/// directory with `LLM_LS_*` environment variables layered on top (see
/// `load_config`).
#[derive(Deserialize, Serialize)]
pub(crate) struct LlmLsConfig {
    /// Embedding model settings used by the snippet retriever.
    pub(crate) model: ModelConfig,
    /// .gitignore-like glob patterns to exclude from indexing
    pub(crate) ignored_paths: Vec<String>,
}

impl Default for LlmLsConfig {
fn default() -> Self {
Self {
model: ModelConfig::default(),
ignored_paths: vec![".git".into(), ".idea".into(), ".DS_Store".into()],
}
}
}

pub async fn load_config(cache_path: &str) -> Result<LlmLsConfig> {
/// Loads configuration from a file and environment variables.
///
/// If the file does not exist, it will be created with the default configuration.
///
/// # Arguments
///
/// * `cache_path` - Path to the directory where the configuration file will be stored.
pub(crate) async fn load_config(cache_path: &str) -> Result<LlmLsConfig> {
let config_file_path = Path::new(cache_path).join("config.yaml");
if config_file_path.exists() {
Ok(Config::builder()
let config = if config_file_path.exists() {
Config::builder()
.add_source(config::File::with_name(&format!("{cache_path}/config")))
.add_source(config::Environment::with_prefix("LLM_LS"))
.build()?
.try_deserialize()?)
.try_deserialize()?
} else {
let config = LlmLsConfig::default();
write(config_file_path, serde_yaml::to_string(&config)?.as_bytes()).await?;
Ok(config)
}
config
};
Ok(config)
}
10 changes: 5 additions & 5 deletions crates/llm-ls/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -882,11 +882,6 @@ async fn main() {
.build()
.expect("failed to build reqwest unsafe client");

let snippet_retriever = Arc::new(RwLock::new(
SnippetRetriever::new(cache_dir.join("database"), 20, 10)
.await
.expect("failed to initialise snippet retriever"),
));
let config = Arc::new(
load_config(
cache_dir
Expand All @@ -896,6 +891,11 @@ async fn main() {
.await
.expect("failed to load config file"),
);
let snippet_retriever = Arc::new(RwLock::new(
SnippetRetriever::new(cache_dir.join("embeddings"), config.model.clone(), 20, 10)
.await
.expect("failed to initialise snippet retriever"),
));
let (service, socket) = LspService::build(|client| LlmService {
cache_dir,
client,
Expand Down
34 changes: 26 additions & 8 deletions crates/llm-ls/src/retrieval.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::config::LlmLsConfig;
use crate::config::{LlmLsConfig, ModelConfig};
use crate::error::{Error, Result};
use candle::utils::{cuda_is_available, metal_is_available};
use candle::{Device, Tensor};
Expand Down Expand Up @@ -138,11 +138,12 @@ fn is_code_file(file_name: &Path) -> bool {
}
}

async fn build_model_and_tokenizer() -> Result<(BertModel, Tokenizer)> {
async fn build_model_and_tokenizer(
model_id: String,
revision: String,
) -> Result<(BertModel, Tokenizer)> {
let start = Instant::now();
let device = device(false)?;
let model_id = "intfloat/multilingual-e5-small".to_string();
let revision = "main".to_string();
let repo = Repo::with_revision(model_id, RepoType::Model, revision);
let (config_filename, tokenizer_filename, weights_filename) = {
let api = Api::new()?;
Expand Down Expand Up @@ -218,6 +219,7 @@ pub(crate) struct SnippetRetriever {
collection_name: String,
db: Option<Db>,
model: Arc<BertModel>,
model_config: ModelConfig,
tokenizer: Tokenizer,
window_size: usize,
window_step: usize,
Expand All @@ -229,16 +231,20 @@ impl SnippetRetriever {
/// Panics if the database cannot be initialised.
pub(crate) async fn new(
cache_path: PathBuf,
model_config: ModelConfig,
window_size: usize,
window_step: usize,
) -> Result<Self> {
let collection_name = "code-slices".to_owned();
let (model, tokenizer) = build_model_and_tokenizer().await?;
let (model, tokenizer) =
build_model_and_tokenizer(model_config.id.clone(), model_config.revision.clone())
.await?;
Ok(Self {
cache_path,
collection_name,
db: None,
model: Arc::new(model),
model_config,
tokenizer,
window_size,
window_step,
Expand All @@ -249,7 +255,11 @@ impl SnippetRetriever {
let uri = self.cache_path.join(db_name);
let mut db = Db::open(uri).await.expect("failed to open database");
match db
.create_collection(self.collection_name.clone(), 384, Distance::Cosine)
.create_collection(
self.collection_name.clone(),
self.model_config.embeddings_size,
Distance::Cosine,
)
.await
{
Ok(_)
Expand Down Expand Up @@ -378,7 +388,11 @@ impl SnippetRetriever {
};
let col = db.get_collection(&self.collection_name).await?;
let mut encoding = self.tokenizer.encode(snippet.clone(), true)?;
encoding.truncate(512, 1, TruncationDirection::Right);
encoding.truncate(
self.model_config.max_input_size,
1,
TruncationDirection::Right,
);
let query = self
.generate_embedding(encoding, self.model.clone())
.await?;
Expand Down Expand Up @@ -495,7 +509,11 @@ impl SnippetRetriever {
}

let mut encoding = self.tokenizer.encode(snippet.clone(), true)?;
encoding.truncate(512, 1, TruncationDirection::Right);
encoding.truncate(
self.model_config.max_input_size,
1,
TruncationDirection::Right,
);
let result = self.generate_embedding(encoding, self.model.clone()).await;
let embedding = match result {
Ok(e) => e,
Expand Down

0 comments on commit baedf85

Please sign in to comment.