diff --git a/crates/llm-ls/src/config.rs b/crates/llm-ls/src/config.rs index db87f34..66abd0b 100644 --- a/crates/llm-ls/src/config.rs +++ b/crates/llm-ls/src/config.rs @@ -6,8 +6,28 @@ use tokio::fs::write; use crate::error::Result; +#[derive(Clone, Deserialize, Serialize)] +pub(crate) struct ModelConfig { + pub(crate) id: String, + pub(crate) revision: String, + pub(crate) embeddings_size: usize, + pub(crate) max_input_size: usize, +} + +impl Default for ModelConfig { + fn default() -> Self { + Self { + id: "intfloat/multilingual-e5-small".to_string(), + revision: "main".to_string(), + embeddings_size: 384, + max_input_size: 512, + } + } +} + #[derive(Deserialize, Serialize)] pub(crate) struct LlmLsConfig { + pub(crate) model: ModelConfig, /// .gitignore-like glob patterns to exclude from indexing pub(crate) ignored_paths: Vec, } @@ -15,22 +35,31 @@ pub(crate) struct LlmLsConfig { impl Default for LlmLsConfig { fn default() -> Self { Self { + model: ModelConfig::default(), ignored_paths: vec![".git".into(), ".idea".into(), ".DS_Store".into()], } } } -pub async fn load_config(cache_path: &str) -> Result { +/// Loads configuration from a file and environment variables. +/// +/// If the file does not exist, it will be created with the default configuration. +/// +/// # Arguments +/// +/// * `cache_path` - Path to the directory where the configuration file will be stored. +pub(crate) async fn load_config(cache_path: &str) -> Result { let config_file_path = Path::new(cache_path).join("config.yaml"); - if config_file_path.exists() { - Ok(Config::builder() + let config = if config_file_path.exists() { + Config::builder() .add_source(config::File::with_name(&format!("{cache_path}/config"))) .add_source(config::Environment::with_prefix("LLM_LS")) .build()? - .try_deserialize()?) + .try_deserialize()? } else { let config = LlmLsConfig::default(); write(config_file_path, serde_yaml::to_string(&config)?.as_bytes()).await?; - Ok(config) - } + config + }; + Ok(config) } diff --git a/crates/llm-ls/src/main.rs b/crates/llm-ls/src/main.rs index c26d6f6..0ad9ecf 100644 --- a/crates/llm-ls/src/main.rs +++ b/crates/llm-ls/src/main.rs @@ -882,11 +882,6 @@ async fn main() { .build() .expect("failed to build reqwest unsafe client"); - let snippet_retriever = Arc::new(RwLock::new( - SnippetRetriever::new(cache_dir.join("database"), 20, 10) - .await - .expect("failed to initialise snippet retriever"), - )); let config = Arc::new( load_config( cache_dir @@ -896,6 +891,11 @@ async fn main() { .await .expect("failed to load config file"), ); + let snippet_retriever = Arc::new(RwLock::new( + SnippetRetriever::new(cache_dir.join("embeddings"), config.model.clone(), 20, 10) + .await + .expect("failed to initialise snippet retriever"), + )); let (service, socket) = LspService::build(|client| LlmService { cache_dir, client, diff --git a/crates/llm-ls/src/retrieval.rs b/crates/llm-ls/src/retrieval.rs index 65bf1aa..8860b2d 100644 --- a/crates/llm-ls/src/retrieval.rs +++ b/crates/llm-ls/src/retrieval.rs @@ -1,4 +1,4 @@ -use crate::config::LlmLsConfig; +use crate::config::{LlmLsConfig, ModelConfig}; use crate::error::{Error, Result}; use candle::utils::{cuda_is_available, metal_is_available}; use candle::{Device, Tensor}; @@ -138,11 +138,12 @@ fn is_code_file(file_name: &Path) -> bool { } } -async fn build_model_and_tokenizer() -> Result<(BertModel, Tokenizer)> { +async fn build_model_and_tokenizer( + model_id: String, + revision: String, +) -> Result<(BertModel, Tokenizer)> { let start = Instant::now(); let device = device(false)?; - let model_id = "intfloat/multilingual-e5-small".to_string(); - let revision = "main".to_string(); let repo = Repo::with_revision(model_id, RepoType::Model, revision); let (config_filename, tokenizer_filename, weights_filename) = { let api = Api::new()?; @@ -218,6 +219,7 @@ pub(crate) struct SnippetRetriever { collection_name: String, db: Option, model: Arc, + model_config: ModelConfig, tokenizer: Tokenizer, window_size: usize, window_step: usize, @@ -229,16 +231,20 @@ impl SnippetRetriever { /// Panics if the database cannot be initialised. pub(crate) async fn new( cache_path: PathBuf, + model_config: ModelConfig, window_size: usize, window_step: usize, ) -> Result { let collection_name = "code-slices".to_owned(); - let (model, tokenizer) = build_model_and_tokenizer().await?; + let (model, tokenizer) = + build_model_and_tokenizer(model_config.id.clone(), model_config.revision.clone()) + .await?; Ok(Self { cache_path, collection_name, db: None, model: Arc::new(model), + model_config, tokenizer, window_size, window_step, @@ -249,7 +255,11 @@ impl SnippetRetriever { let uri = self.cache_path.join(db_name); let mut db = Db::open(uri).await.expect("failed to open database"); match db - .create_collection(self.collection_name.clone(), 384, Distance::Cosine) + .create_collection( + self.collection_name.clone(), + self.model_config.embeddings_size, + Distance::Cosine, + ) .await { Ok(_) @@ -378,7 +388,11 @@ impl SnippetRetriever { }; let col = db.get_collection(&self.collection_name).await?; let mut encoding = self.tokenizer.encode(snippet.clone(), true)?; - encoding.truncate(512, 1, TruncationDirection::Right); + encoding.truncate( + self.model_config.max_input_size, + 1, + TruncationDirection::Right, + ); let query = self .generate_embedding(encoding, self.model.clone()) .await?; @@ -495,7 +509,11 @@ impl SnippetRetriever { } let mut encoding = self.tokenizer.encode(snippet.clone(), true)?; - encoding.truncate(512, 1, TruncationDirection::Right); + encoding.truncate( + self.model_config.max_input_size, + 1, + TruncationDirection::Right, + ); let result = self.generate_embedding(encoding, self.model.clone()).await; let embedding = match result { Ok(e) => e,