From 17168295cedcf60dba89e3600ec87fb74857c92d Mon Sep 17 00:00:00 2001 From: Quentin Maire Date: Tue, 27 Feb 2024 15:55:53 +0100 Subject: [PATCH 01/16] KB is now processed in batch --- crates/llm-ls/src/retrieval.rs | 50 ++++++++++++++++++++++++++-------- 1 file changed, 39 insertions(+), 11 deletions(-) diff --git a/crates/llm-ls/src/retrieval.rs b/crates/llm-ls/src/retrieval.rs index 6bea3f8..d52a77e 100644 --- a/crates/llm-ls/src/retrieval.rs +++ b/crates/llm-ls/src/retrieval.rs @@ -16,15 +16,16 @@ use tokenizers::{ Encoding, PaddingDirection, PaddingParams, PaddingStrategy, Tokenizer, TruncationDirection, }; use tokio::io::AsyncReadExt; -use tokio::task::spawn_blocking; +use tokio::task::{spawn_blocking}; use tokio::time::Instant; use tower_lsp::lsp_types::notification::Progress; use tower_lsp::lsp_types::{ NumberOrString, ProgressParams, ProgressParamsValue, Range, WorkDoneProgress, WorkDoneProgressReport, }; +use std::iter::zip; use tower_lsp::Client; -use tracing::{debug, error, warn}; +use tracing::{debug, info, error, warn}; // TODO: // - create sliding window and splitting of files logic @@ -198,6 +199,23 @@ fn device(cpu: bool) -> Result { } } +async fn initialse_database(cache_path: PathBuf) -> Db { + let uri = cache_path.join("database"); + let mut db = Db::open(uri).await.expect("failed to open database"); + match db + .create_collection("code-slices".to_owned(), 384, Distance::Cosine) + .await + { + Ok(_) + | Err(tinyvec_embed::error::Error::Collection( + tinyvec_embed::error::Collection::UniqueViolation, + )) => (), + Err(err) => panic!("failed to create collection: {err}"), + } + db +} + +#[derive(Default)] pub(crate) struct Snippet { pub(crate) file_url: String, pub(crate) code: String, @@ -464,15 +482,24 @@ impl SnippetRetriever { Some(db) => db.clone(), None => return Err(Error::UninitialisedDatabase), }; - let col = db.get_collection(&self.collection_name).await?; - let result = col - .read() - .await - .get(query, 5, filter) - .await? - .iter() - .map(TryInto::try_into) - .collect::>>()?; + let col = self.db.get_collection("code-slices").await?; + let mut encoding = self.tokenizer.encode(snippet.clone(), true)?; + encoding.truncate(512, 1, TruncationDirection::Right); + let batch = vec![encoding]; + let query = self + .generate_embedding(batch, self.model.clone()) + .await?; + let result = match query.first() { + Some(res) => col + .read() + .await + .get(res, 5, filter) + .await? + .iter() + .map(TryInto::try_into) + .collect::>>()?, + _ => vec![Snippet {..Default::default()}] + }; Ok(result) } @@ -555,6 +582,7 @@ impl SnippetRetriever { debug!("Building embeddings for {file_url}"); for start_line in (start..end).step_by(self.window_step) { let end_line = (start_line + self.window_size - 1).min(lines.len()); + debug!("Going from line {start_line} to {end_line} in {file_url}"); if !col .read() .await From 49edf95bd86a1b9940004572218e1552ad617f42 Mon Sep 17 00:00:00 2001 From: Wats0ns Date: Wed, 28 Feb 2024 10:17:41 +0100 Subject: [PATCH 02/16] Update crates/llm-ls/src/retrieval.rs Co-authored-by: Luc Georges --- crates/llm-ls/src/retrieval.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/crates/llm-ls/src/retrieval.rs b/crates/llm-ls/src/retrieval.rs index d52a77e..9d109f4 100644 --- a/crates/llm-ls/src/retrieval.rs +++ b/crates/llm-ls/src/retrieval.rs @@ -498,7 +498,7 @@ impl SnippetRetriever { .iter() .map(TryInto::try_into) .collect::>>()?, - _ => vec![Snippet {..Default::default()}] + _ => vec![] }; Ok(result) } From 6813b01b78f63bf8d66da2fa403c0c91ba2d321b Mon Sep 17 00:00:00 2001 From: Wats0ns Date: Wed, 28 Feb 2024 10:46:33 +0100 Subject: [PATCH 03/16] Update crates/llm-ls/src/retrieval.rs Co-authored-by: Luc Georges --- crates/llm-ls/src/retrieval.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/crates/llm-ls/src/retrieval.rs b/crates/llm-ls/src/retrieval.rs index 9d109f4..4dc8f4a 100644 --- a/crates/llm-ls/src/retrieval.rs +++ b/crates/llm-ls/src/retrieval.rs @@ -582,7 +582,6 @@ impl SnippetRetriever { debug!("Building embeddings for {file_url}"); for start_line in (start..end).step_by(self.window_step) { let end_line = (start_line + self.window_size - 1).min(lines.len()); - debug!("Going from line {start_line} to {end_line} in {file_url}"); if !col .read() .await From b2623f7749c69196486e5eef73ba685dbd0610eb Mon Sep 17 00:00:00 2001 From: Quentin Maire Date: Wed, 28 Feb 2024 11:55:28 +0100 Subject: [PATCH 04/16] Corrections after feedbacks --- crates/llm-ls/src/backend.rs | 10 ++++++++-- crates/llm-ls/src/retrieval.rs | 6 +++--- crates/tinyvec-embed/src/db.rs | 9 +++++++++ 3 files changed, 20 insertions(+), 5 deletions(-) diff --git a/crates/llm-ls/src/backend.rs b/crates/llm-ls/src/backend.rs index dba1fab..f24c13b 100644 --- a/crates/llm-ls/src/backend.rs +++ b/crates/llm-ls/src/backend.rs @@ -4,6 +4,7 @@ use reqwest::header::{HeaderMap, HeaderValue, AUTHORIZATION, USER_AGENT}; use serde::{Deserialize, Serialize}; use serde_json::{json, Map, Value}; use std::fmt::Display; +use tracing::{debug}; use crate::error::{Error, Result}; @@ -168,11 +169,15 @@ fn build_openai_headers(api_token: Option<&String>, ide: Ide) -> Result Result> { - match serde_json::from_str(text)? { + match serde_json::from_str(text).unwrap() { OpenAIAPIResponse::Generation(completion) => { Ok(completion.choices.into_iter().map(|x| x.into()).collect()) } - OpenAIAPIResponse::Error(err) => Err(Error::OpenAI(err)), + OpenAIAPIResponse::Error(err) => { + debug!("Got {text}"); + Err(Error::OpenAI(err)) + }, + } } @@ -215,6 +220,7 @@ pub(crate) fn build_headers( } pub(crate) fn parse_generations(backend: &Backend, text: &str) -> Result> { + debug!("Got {text}"); match backend { Backend::HuggingFace { .. } => parse_api_text(text), Backend::Ollama { .. } => parse_ollama_text(text), diff --git a/crates/llm-ls/src/retrieval.rs b/crates/llm-ls/src/retrieval.rs index 4dc8f4a..35802cf 100644 --- a/crates/llm-ls/src/retrieval.rs +++ b/crates/llm-ls/src/retrieval.rs @@ -25,7 +25,7 @@ use tower_lsp::lsp_types::{ }; use std::iter::zip; use tower_lsp::Client; -use tracing::{debug, info, error, warn}; +use tracing::{debug, info, error, warn, error}; // TODO: // - create sliding window and splitting of files logic @@ -215,7 +215,6 @@ async fn initialse_database(cache_path: PathBuf) -> Db { db } -#[derive(Default)] pub(crate) struct Snippet { pub(crate) file_url: String, pub(crate) code: String, @@ -487,7 +486,7 @@ impl SnippetRetriever { encoding.truncate(512, 1, TruncationDirection::Right); let batch = vec![encoding]; let query = self - .generate_embedding(batch, self.model.clone()) + .generate_embeddings(batch, self.model.clone()) .await?; let result = match query.first() { Some(res) => col @@ -546,6 +545,7 @@ impl SnippetRetriever { encodings: Vec, model: Arc, ) -> Result>> { + // Embedding order has to be preserved and stay the same as encoding input let start = Instant::now(); let embedding = spawn_blocking(move || -> Result>> { let tokens = encodings diff --git a/crates/tinyvec-embed/src/db.rs b/crates/tinyvec-embed/src/db.rs index dd20995..a5cbb52 100644 --- a/crates/tinyvec-embed/src/db.rs +++ b/crates/tinyvec-embed/src/db.rs @@ -235,6 +235,7 @@ impl Eq for Embedding {} pub enum Value { String(String), Number(f32), + Usize(usize), } impl Display for Value { @@ -242,6 +243,7 @@ impl Display for Value { match self { Self::String(s) => write!(f, "{s}"), Self::Number(n) => write!(f, "{n}"), + Self::Usize(u) => write!(f, "{u}"), } } } @@ -253,6 +255,13 @@ impl Value { _ => Err(Error::ValueNotString(self.to_owned())), } } + + pub fn inner_value(&self) -> Result { + match self { + Self::Usize(s) => Ok(s.to_owned()), + _ => Err(Error::ValueNotString(self.to_owned())), + } + } } impl TryInto for &Value { From f2e575ae6cf8353bf6272a393df5d171aba46e19 Mon Sep 17 00:00:00 2001 From: Quentin Maire Date: Wed, 28 Feb 2024 13:44:01 +0100 Subject: [PATCH 05/16] Corrections after feedbacks --- crates/llm-ls/src/backend.rs | 10 ++-------- crates/llm-ls/src/retrieval.rs | 1 - 2 files changed, 2 insertions(+), 9 deletions(-) diff --git a/crates/llm-ls/src/backend.rs b/crates/llm-ls/src/backend.rs index f24c13b..007ee6c 100644 --- a/crates/llm-ls/src/backend.rs +++ b/crates/llm-ls/src/backend.rs @@ -4,7 +4,6 @@ use reqwest::header::{HeaderMap, HeaderValue, AUTHORIZATION, USER_AGENT}; use serde::{Deserialize, Serialize}; use serde_json::{json, Map, Value}; use std::fmt::Display; -use tracing::{debug}; use crate::error::{Error, Result}; @@ -169,15 +168,11 @@ fn build_openai_headers(api_token: Option<&String>, ide: Ide) -> Result Result> { - match serde_json::from_str(text).unwrap() { + match serde_json::from_str(text)? { OpenAIAPIResponse::Generation(completion) => { Ok(completion.choices.into_iter().map(|x| x.into()).collect()) } - OpenAIAPIResponse::Error(err) => { - debug!("Got {text}"); - Err(Error::OpenAI(err)) - }, - + OpenAIAPIResponse::Error(err) => Err(Error::OpenAI(err)) } } @@ -220,7 +215,6 @@ pub(crate) fn build_headers( } pub(crate) fn parse_generations(backend: &Backend, text: &str) -> Result> { - debug!("Got {text}"); match backend { Backend::HuggingFace { .. } => parse_api_text(text), Backend::Ollama { .. } => parse_ollama_text(text), diff --git a/crates/llm-ls/src/retrieval.rs b/crates/llm-ls/src/retrieval.rs index 35802cf..631d0cb 100644 --- a/crates/llm-ls/src/retrieval.rs +++ b/crates/llm-ls/src/retrieval.rs @@ -545,7 +545,6 @@ impl SnippetRetriever { encodings: Vec, model: Arc, ) -> Result>> { - // Embedding order has to be preserved and stay the same as encoding input let start = Instant::now(); let embedding = spawn_blocking(move || -> Result>> { let tokens = encodings From 1835d0c7bf3fb9859425d29810745dd36694b93d Mon Sep 17 00:00:00 2001 From: Quentin Maire Date: Thu, 29 Feb 2024 10:44:50 +0100 Subject: [PATCH 06/16] Corrections after feedbacks --- crates/tinyvec-embed/src/db.rs | 15 +++++++++------ crates/tinyvec-embed/src/error.rs | 2 ++ 2 files changed, 11 insertions(+), 6 deletions(-) diff --git a/crates/tinyvec-embed/src/db.rs b/crates/tinyvec-embed/src/db.rs index a5cbb52..8a7c0b8 100644 --- a/crates/tinyvec-embed/src/db.rs +++ b/crates/tinyvec-embed/src/db.rs @@ -235,7 +235,6 @@ impl Eq for Embedding {} pub enum Value { String(String), Number(f32), - Usize(usize), } impl Display for Value { @@ -243,7 +242,6 @@ impl Display for Value { match self { Self::String(s) => write!(f, "{s}"), Self::Number(n) => write!(f, "{n}"), - Self::Usize(u) => write!(f, "{u}"), } } } @@ -255,11 +253,16 @@ impl Value { _ => Err(Error::ValueNotString(self.to_owned())), } } +} - pub fn inner_value(&self) -> Result { - match self { - Self::Usize(s) => Ok(s.to_owned()), - _ => Err(Error::ValueNotString(self.to_owned())), +impl TryInto for &Value { + type Error = Error; + + fn try_into(self) -> Result { + if let Value::Number(n) = self { + Ok(n.clone() as usize) + } else { + Err(Error::ValueNotUsize(self.to_owned())) } } } diff --git a/crates/tinyvec-embed/src/error.rs b/crates/tinyvec-embed/src/error.rs index 8dc0f54..60e41b4 100644 --- a/crates/tinyvec-embed/src/error.rs +++ b/crates/tinyvec-embed/src/error.rs @@ -36,6 +36,8 @@ pub enum Error { ValueNotNumber(Value), #[error("expected value to be string, got: {0}")] ValueNotString(Value), + #[error("expected value to be a valid size, got: {0}")] + ValueNotUsize(Value), } pub type Result = std::result::Result; From 025ebdb575316b8fcbcfc6ded965285a9dd62bc5 Mon Sep 17 00:00:00 2001 From: Wats0ns Date: Thu, 29 Feb 2024 10:47:46 +0100 Subject: [PATCH 07/16] Update crates/llm-ls/src/backend.rs Co-authored-by: Luc Georges --- crates/llm-ls/src/backend.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/crates/llm-ls/src/backend.rs b/crates/llm-ls/src/backend.rs index 007ee6c..dba1fab 100644 --- a/crates/llm-ls/src/backend.rs +++ b/crates/llm-ls/src/backend.rs @@ -172,7 +172,7 @@ fn parse_openai_text(text: &str) -> Result> { OpenAIAPIResponse::Generation(completion) => { Ok(completion.choices.into_iter().map(|x| x.into()).collect()) } - OpenAIAPIResponse::Error(err) => Err(Error::OpenAI(err)) + OpenAIAPIResponse::Error(err) => Err(Error::OpenAI(err)), } } From 66ed16c847d9a594e87bde81ee86bf3fd160de64 Mon Sep 17 00:00:00 2001 From: Wats0ns Date: Thu, 29 Feb 2024 11:42:46 +0100 Subject: [PATCH 08/16] Update crates/tinyvec-embed/src/db.rs Co-authored-by: Luc Georges --- crates/tinyvec-embed/src/db.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/crates/tinyvec-embed/src/db.rs b/crates/tinyvec-embed/src/db.rs index 8a7c0b8..8c5a99a 100644 --- a/crates/tinyvec-embed/src/db.rs +++ b/crates/tinyvec-embed/src/db.rs @@ -260,7 +260,7 @@ impl TryInto for &Value { fn try_into(self) -> Result { if let Value::Number(n) = self { - Ok(n.clone() as usize) + Ok(n as usize) } else { Err(Error::ValueNotUsize(self.to_owned())) } From c0e76660a7e7ea9c765da10f807c2eaa38571d3e Mon Sep 17 00:00:00 2001 From: Wats0ns Date: Thu, 29 Feb 2024 11:42:59 +0100 Subject: [PATCH 09/16] Update crates/tinyvec-embed/src/db.rs Co-authored-by: Luc Georges --- crates/tinyvec-embed/src/db.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/crates/tinyvec-embed/src/db.rs b/crates/tinyvec-embed/src/db.rs index 8c5a99a..bef9f47 100644 --- a/crates/tinyvec-embed/src/db.rs +++ b/crates/tinyvec-embed/src/db.rs @@ -262,7 +262,7 @@ impl TryInto for &Value { if let Value::Number(n) = self { Ok(n as usize) } else { - Err(Error::ValueNotUsize(self.to_owned())) + Err(Error::ValueNotUsize(self)) } } } From d8dc17faff1943e45690bdf84a8c7b541997a9cf Mon Sep 17 00:00:00 2001 From: Wats0ns Date: Thu, 29 Feb 2024 11:43:27 +0100 Subject: [PATCH 10/16] Update crates/tinyvec-embed/src/error.rs Co-authored-by: Luc Georges --- crates/tinyvec-embed/src/error.rs | 2 -- 1 file changed, 2 deletions(-) diff --git a/crates/tinyvec-embed/src/error.rs b/crates/tinyvec-embed/src/error.rs index 60e41b4..8dc0f54 100644 --- a/crates/tinyvec-embed/src/error.rs +++ b/crates/tinyvec-embed/src/error.rs @@ -36,8 +36,6 @@ pub enum Error { ValueNotNumber(Value), #[error("expected value to be string, got: {0}")] ValueNotString(Value), - #[error("expected value to be a valid size, got: {0}")] - ValueNotUsize(Value), } pub type Result = std::result::Result; From 29d9d99836b69c9bd68ec615968a77ef01c7ec51 Mon Sep 17 00:00:00 2001 From: Quentin Maire Date: Thu, 29 Feb 2024 11:50:19 +0100 Subject: [PATCH 11/16] Corrections after feedbacks --- crates/llm-ls/src/retrieval.rs | 45 ++++++++-------------------------- crates/tinyvec-embed/src/db.rs | 15 +++++------- 2 files changed, 16 insertions(+), 44 deletions(-) diff --git a/crates/llm-ls/src/retrieval.rs b/crates/llm-ls/src/retrieval.rs index 631d0cb..3471ef4 100644 --- a/crates/llm-ls/src/retrieval.rs +++ b/crates/llm-ls/src/retrieval.rs @@ -25,7 +25,7 @@ use tower_lsp::lsp_types::{ }; use std::iter::zip; use tower_lsp::Client; -use tracing::{debug, info, error, warn, error}; +use tracing::{debug, error, warn}; // TODO: // - create sliding window and splitting of files logic @@ -199,22 +199,6 @@ fn device(cpu: bool) -> Result { } } -async fn initialse_database(cache_path: PathBuf) -> Db { - let uri = cache_path.join("database"); - let mut db = Db::open(uri).await.expect("failed to open database"); - match db - .create_collection("code-slices".to_owned(), 384, Distance::Cosine) - .await - { - Ok(_) - | Err(tinyvec_embed::error::Error::Collection( - tinyvec_embed::error::Collection::UniqueViolation, - )) => (), - Err(err) => panic!("failed to create collection: {err}"), - } - db -} - pub(crate) struct Snippet { pub(crate) file_url: String, pub(crate) code: String, @@ -481,24 +465,15 @@ impl SnippetRetriever { Some(db) => db.clone(), None => return Err(Error::UninitialisedDatabase), }; - let col = self.db.get_collection("code-slices").await?; - let mut encoding = self.tokenizer.encode(snippet.clone(), true)?; - encoding.truncate(512, 1, TruncationDirection::Right); - let batch = vec![encoding]; - let query = self - .generate_embeddings(batch, self.model.clone()) - .await?; - let result = match query.first() { - Some(res) => col - .read() - .await - .get(res, 5, filter) - .await? - .iter() - .map(TryInto::try_into) - .collect::>>()?, - _ => vec![] - }; + let col = db.get_collection("code-slices").await?; + let result = col + .read() + .await + .get(query, 5, filter) + .await? + .iter() + .map(TryInto::try_into) + .collect::>>()?; Ok(result) } diff --git a/crates/tinyvec-embed/src/db.rs b/crates/tinyvec-embed/src/db.rs index bef9f47..a5cbb52 100644 --- a/crates/tinyvec-embed/src/db.rs +++ b/crates/tinyvec-embed/src/db.rs @@ -235,6 +235,7 @@ impl Eq for Embedding {} pub enum Value { String(String), Number(f32), + Usize(usize), } impl Display for Value { @@ -242,6 +243,7 @@ impl Display for Value { match self { Self::String(s) => write!(f, "{s}"), Self::Number(n) => write!(f, "{n}"), + Self::Usize(u) => write!(f, "{u}"), } } } @@ -253,16 +255,11 @@ impl Value { _ => Err(Error::ValueNotString(self.to_owned())), } } -} -impl TryInto for &Value { - type Error = Error; - - fn try_into(self) -> Result { - if let Value::Number(n) = self { - Ok(n as usize) - } else { - Err(Error::ValueNotUsize(self)) + pub fn inner_value(&self) -> Result { + match self { + Self::Usize(s) => Ok(s.to_owned()), + _ => Err(Error::ValueNotString(self.to_owned())), } } } From 7df31e5be97de7f4b493d7f7108d44a4a8a468c3 Mon Sep 17 00:00:00 2001 From: Quentin Maire Date: Wed, 6 Mar 2024 11:33:10 +0100 Subject: [PATCH 12/16] Corrections after new feedbacks, rollbacking Value conversion changes --- crates/llm-ls/src/retrieval.rs | 2 +- crates/tinyvec-embed/src/db.rs | 15 +++++++++------ crates/tinyvec-embed/src/error.rs | 2 ++ 3 files changed, 12 insertions(+), 7 deletions(-) diff --git a/crates/llm-ls/src/retrieval.rs b/crates/llm-ls/src/retrieval.rs index 3471ef4..79c02ad 100644 --- a/crates/llm-ls/src/retrieval.rs +++ b/crates/llm-ls/src/retrieval.rs @@ -465,7 +465,7 @@ impl SnippetRetriever { Some(db) => db.clone(), None => return Err(Error::UninitialisedDatabase), }; - let col = db.get_collection("code-slices").await?; + let col = db.get_collection(&self.collection_name).await?; let result = col .read() .await diff --git a/crates/tinyvec-embed/src/db.rs b/crates/tinyvec-embed/src/db.rs index a5cbb52..8a7c0b8 100644 --- a/crates/tinyvec-embed/src/db.rs +++ b/crates/tinyvec-embed/src/db.rs @@ -235,7 +235,6 @@ impl Eq for Embedding {} pub enum Value { String(String), Number(f32), - Usize(usize), } impl Display for Value { @@ -243,7 +242,6 @@ impl Display for Value { match self { Self::String(s) => write!(f, "{s}"), Self::Number(n) => write!(f, "{n}"), - Self::Usize(u) => write!(f, "{u}"), } } } @@ -255,11 +253,16 @@ impl Value { _ => Err(Error::ValueNotString(self.to_owned())), } } +} - pub fn inner_value(&self) -> Result { - match self { - Self::Usize(s) => Ok(s.to_owned()), - _ => Err(Error::ValueNotString(self.to_owned())), +impl TryInto for &Value { + type Error = Error; + + fn try_into(self) -> Result { + if let Value::Number(n) = self { + Ok(n.clone() as usize) + } else { + Err(Error::ValueNotUsize(self.to_owned())) } } } diff --git a/crates/tinyvec-embed/src/error.rs b/crates/tinyvec-embed/src/error.rs index 8dc0f54..60e41b4 100644 --- a/crates/tinyvec-embed/src/error.rs +++ b/crates/tinyvec-embed/src/error.rs @@ -36,6 +36,8 @@ pub enum Error { ValueNotNumber(Value), #[error("expected value to be string, got: {0}")] ValueNotString(Value), + #[error("expected value to be a valid size, got: {0}")] + ValueNotUsize(Value), } pub type Result = std::result::Result; From 9073d5d75ef4b4e79cd551462d8e49cbb1c6ff22 Mon Sep 17 00:00:00 2001 From: Quentin Maire Date: Wed, 6 Mar 2024 11:56:30 +0100 Subject: [PATCH 13/16] Corrections after new feedbacks --- crates/llm-ls/src/retrieval.rs | 3 +-- crates/tinyvec-embed/src/db.rs | 2 +- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/crates/llm-ls/src/retrieval.rs b/crates/llm-ls/src/retrieval.rs index 79c02ad..6bea3f8 100644 --- a/crates/llm-ls/src/retrieval.rs +++ b/crates/llm-ls/src/retrieval.rs @@ -16,14 +16,13 @@ use tokenizers::{ Encoding, PaddingDirection, PaddingParams, PaddingStrategy, Tokenizer, TruncationDirection, }; use tokio::io::AsyncReadExt; -use tokio::task::{spawn_blocking}; +use tokio::task::spawn_blocking; use tokio::time::Instant; use tower_lsp::lsp_types::notification::Progress; use tower_lsp::lsp_types::{ NumberOrString, ProgressParams, ProgressParamsValue, Range, WorkDoneProgress, WorkDoneProgressReport, }; -use std::iter::zip; use tower_lsp::Client; use tracing::{debug, error, warn}; diff --git a/crates/tinyvec-embed/src/db.rs b/crates/tinyvec-embed/src/db.rs index 8a7c0b8..a6dfdb5 100644 --- a/crates/tinyvec-embed/src/db.rs +++ b/crates/tinyvec-embed/src/db.rs @@ -262,7 +262,7 @@ impl TryInto for &Value { if let Value::Number(n) = self { Ok(n.clone() as usize) } else { - Err(Error::ValueNotUsize(self.to_owned())) + Err(Error::ValueNotNumber(self.to_owned())) } } } From fba84317ff1e7d4ac1e9d9d9613ece781bdef038 Mon Sep 17 00:00:00 2001 From: Quentin Maire Date: Thu, 7 Mar 2024 11:17:30 +0100 Subject: [PATCH 14/16] Fetch workspace name from file URI in completion route --- crates/llm-ls/src/main.rs | 104 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 104 insertions(+) diff --git a/crates/llm-ls/src/main.rs b/crates/llm-ls/src/main.rs index 38996dd..e64d142 100644 --- a/crates/llm-ls/src/main.rs +++ b/crates/llm-ls/src/main.rs @@ -509,6 +509,27 @@ fn build_url(backend: Backend, model: &str) -> String { } impl LlmService { + async fn file_uri_to_workspace(&self, file_uri: String) -> String { + debug!("From file to workspace {}", file_uri); + debug!("With workspaces {:?}", self.workspace_folders); + let folders = self.workspace_folders.read().await; + match folders.as_ref() { + Some(folders) => { + let parent_workspace = folders + .clone() + .into_iter() + .filter(|folder| file_uri.contains(folder.uri.path())) + .collect::>(); + if parent_workspace.is_empty() { + folders[0].name.clone() + } else { + parent_workspace[0].name.clone() + } + } + None => "".to_string(), + } + } + async fn get_completions( &self, params: GetCompletionsParams, @@ -520,6 +541,7 @@ impl LlmService { let document_map = self.document_map.read().await; let file_url = params.text_document_position.text_document.uri.as_str(); + let target_workspace = self.file_uri_to_workspace(file_url.to_string()).await; let document = match document_map.get(file_url) { Some(doc) => doc, @@ -897,6 +919,7 @@ async fn main() { .danger_accept_invalid_certs(true) .build() .expect("failed to build reqwest unsafe client"); + debug!("Reading {:?}", cache_dir); let config = Arc::new( load_config( @@ -952,3 +975,84 @@ async fn main() { Server::new(stdin, stdout, socket).serve(service).await; } } + +#[cfg(test)] +mod tests { + use super::*; + + async fn service_setup() -> LspService { + let cache_dir = PathBuf::from(r"idontexist"); + let config = Arc::new(LlmLsConfig { + ..Default::default() + }); + let snippet_retriever = Arc::new(RwLock::new( + SnippetRetriever::new(cache_dir.join("embeddings"), config.model.clone(), 20, 10) + .await + .unwrap(), + )); + let (service, _) = LspService::build(|client| LlmService { + cache_dir, + client, + config, + document_map: Arc::new(RwLock::new(HashMap::new())), + http_client: reqwest::Client::new(), + unsafe_http_client: reqwest::Client::new(), + workspace_folders: Arc::new(RwLock::new(None)), + tokenizer_map: Arc::new(RwLock::new(HashMap::new())), + unauthenticated_warn_at: Arc::new(RwLock::new( + Instant::now() + .checked_sub(MAX_WARNING_REPEAT) + .expect("instant to be in bounds"), + )), + snippet_retriever, + supports_progress_bar: Arc::new(RwLock::new(false)), + cancel_snippet_build_tx: Arc::new(RwLock::new(None)), + indexation_handle: Arc::new(RwLock::new(None)), + }) + .finish(); + service + } + + #[tokio::test] + async fn test_file_uri_to_workspace() { + // let (service, socket) = LspService::new(|client| LlmService { client }); + let service = service_setup().await; + { + let inn = service + .inner() + .file_uri_to_workspace("/home/test".to_string()) + .await; + assert_eq!(inn, ""); + } + { + *service.inner().workspace_folders.write().await = vec![ + WorkspaceFolder { + name: "other_repo".to_string(), + uri: Url::from_directory_path("/home/other_test").unwrap(), + }, + WorkspaceFolder { + name: "test_repo".to_string(), + uri: Url::from_directory_path("/home/test").unwrap(), + }, + ] + .into(); + let inn = service + .inner() + .file_uri_to_workspace("/home/test/src/lib/main.py".to_string()) + .await; + assert_eq!(inn, "test_repo"); + } + { + *service.inner().workspace_folders.write().await = vec![WorkspaceFolder { + name: "other_repo".to_string(), + uri: Url::from_directory_path("/home/other_test").unwrap(), + }] + .into(); + let inn = service + .inner() + .file_uri_to_workspace("/home/test/src/lib/main.py".to_string()) + .await; + assert_eq!(inn, "other_repo"); + } + } +} From c529e0df196096f67f27a5783f6493512ef9d42b Mon Sep 17 00:00:00 2001 From: Quentin Maire Date: Fri, 15 Mar 2024 16:13:37 +0100 Subject: [PATCH 15/16] Handle multiple initialized workspaces calling completions request --- crates/llm-ls/src/main.rs | 73 +++++++++++++++++-------------- crates/llm-ls/src/retrieval.rs | 78 ++++++++++++++++++++++------------ 2 files changed, 92 insertions(+), 59 deletions(-) diff --git a/crates/llm-ls/src/main.rs b/crates/llm-ls/src/main.rs index e64d142..d854956 100644 --- a/crates/llm-ls/src/main.rs +++ b/crates/llm-ls/src/main.rs @@ -190,6 +190,7 @@ async fn build_prompt( file_url: &str, language_id: LanguageId, snippet_retriever: Arc>, + target_workspace: &str, ) -> Result { let t = Instant::now(); if fim.enabled { @@ -259,6 +260,7 @@ async fn build_prompt( Compare::Neq, file_url.into(), )), + target_workspace ) .await?; let context_header = build_context_header(language_id, snippets); @@ -307,6 +309,7 @@ async fn build_prompt( Compare::Neq, file_url.into(), )), + target_workspace ) .await?; let context_header = build_context_header(language_id, snippets); @@ -508,28 +511,26 @@ fn build_url(backend: Backend, model: &str) -> String { } } -impl LlmService { - async fn file_uri_to_workspace(&self, file_uri: String) -> String { - debug!("From file to workspace {}", file_uri); - debug!("With workspaces {:?}", self.workspace_folders); - let folders = self.workspace_folders.read().await; - match folders.as_ref() { - Some(folders) => { - let parent_workspace = folders - .clone() - .into_iter() - .filter(|folder| file_uri.contains(folder.uri.path())) - .collect::>(); - if parent_workspace.is_empty() { - folders[0].name.clone() - } else { - parent_workspace[0].name.clone() - } +fn file_uri_to_workspace(workspace_folders: Option<&Vec>, file_uri: &str) -> String { + // let folders = self.workspace_folders.read().await; + match workspace_folders { + Some(folders) => { + let parent_workspace = folders + .clone() + .into_iter() + .filter(|folder| file_uri.contains(folder.uri.path())) + .collect::>(); + if parent_workspace.is_empty() { + folders[0].name.clone() + } else { + parent_workspace[0].name.clone() } - None => "".to_string(), } + None => "".to_string(), } +} +impl LlmService { async fn get_completions( &self, params: GetCompletionsParams, @@ -541,7 +542,9 @@ impl LlmService { let document_map = self.document_map.read().await; let file_url = params.text_document_position.text_document.uri.as_str(); - let target_workspace = self.file_uri_to_workspace(file_url.to_string()).await; + let target_workspace = file_uri_to_workspace( + self.workspace_folders.read().await.as_ref(), + file_url); let document = match document_map.get(file_url) { Some(doc) => doc, @@ -600,6 +603,7 @@ impl LlmService { &file_url.replace("file://", ""), document.language_id, self.snippet_retriever.clone(), + &target_workspace, ).await?; let http_client = if params.tls_skip_verify_insecure { @@ -722,12 +726,15 @@ impl LanguageServer for LlmService { .await; } let mut guard = snippet_retriever.write().await; + let workspace_path = &workspace_folders[0].uri.path().to_string(); + let workspace_root = file_uri_to_workspace(Some(workspace_folders), &workspace_path); tokio::select! { res = guard.build_workspace_snippets( client.clone(), config, token, - workspace_folders[0].uri.path(), + &workspace_root, + &workspace_path, ) => { if let Err(err) = res { error!("failed building workspace snippets: {err}"); @@ -800,6 +807,9 @@ impl LanguageServer for LlmService { match doc.change(range, &change.text).await { Ok((start, old_end, new_end)) => { let start = Position::new(start as u32, 0); + let workspace_folders = self.workspace_folders.read().await; + let target_workspace = file_uri_to_workspace(workspace_folders.as_ref(), path); + if let Err(err) = self .snippet_retriever .write() @@ -807,6 +817,7 @@ impl LanguageServer for LlmService { .remove( path.to_owned(), Range::new(start, Position::new(old_end as u32, 0)), + &target_workspace, ) .await { @@ -819,6 +830,7 @@ impl LanguageServer for LlmService { .update_document( path.to_owned(), Range::new(start, Position::new(new_end as u32, 0)), + &target_workspace, ) .await { @@ -1018,10 +1030,10 @@ mod tests { // let (service, socket) = LspService::new(|client| LlmService { client }); let service = service_setup().await; { - let inn = service - .inner() - .file_uri_to_workspace("/home/test".to_string()) - .await; + let folders = service.inner().workspace_folders.read().await; + let inn = file_uri_to_workspace(folders.as_ref(), + "/home/test"); + assert_eq!(inn, ""); } { @@ -1036,10 +1048,9 @@ mod tests { }, ] .into(); - let inn = service - .inner() - .file_uri_to_workspace("/home/test/src/lib/main.py".to_string()) - .await; + let folders = service.inner().workspace_folders.read().await; + let inn = file_uri_to_workspace(folders.as_ref(), + "/home/test/src/lib/main.py"); assert_eq!(inn, "test_repo"); } { @@ -1048,10 +1059,8 @@ mod tests { uri: Url::from_directory_path("/home/other_test").unwrap(), }] .into(); - let inn = service - .inner() - .file_uri_to_workspace("/home/test/src/lib/main.py".to_string()) - .await; + let folders = service.inner().workspace_folders.read().await; + let inn = file_uri_to_workspace(folders.as_ref(), "/home/test/src/lib/main.py"); assert_eq!(inn, "other_repo"); } } diff --git a/crates/llm-ls/src/retrieval.rs b/crates/llm-ls/src/retrieval.rs index 6bea3f8..2a5f93d 100644 --- a/crates/llm-ls/src/retrieval.rs +++ b/crates/llm-ls/src/retrieval.rs @@ -241,7 +241,6 @@ impl TryFrom<&SimilarityResult> for Snippet { pub(crate) struct SnippetRetriever { cache_path: PathBuf, - collection_name: String, db: Option, model: Arc, model_config: ModelConfig, @@ -260,13 +259,12 @@ impl SnippetRetriever { window_size: usize, window_step: usize, ) -> Result { - let collection_name = "code-slices".to_owned(); let (model, tokenizer) = build_model_and_tokenizer(model_config.id.clone(), model_config.revision.clone()) .await?; Ok(Self { cache_path, - collection_name, + // collection_name, db: None, model: Arc::new(model), model_config, @@ -276,12 +274,20 @@ impl SnippetRetriever { }) } - pub(crate) async fn initialise_database(&mut self, db_name: &str) -> Result { + fn workspace_name_to_snippet_collections(&self, workspace_root: &str) -> String { + format!("{}--{}", "workspace", workspace_root).to_string() + } + + pub(crate) async fn initialise_database( + &mut self, + db_name: &str, + workspace_root: &str, + ) -> Result { let uri = self.cache_path.join(db_name); let mut db = Db::open(uri).await.expect("failed to open database"); match db .create_collection( - self.collection_name.clone(), + self.workspace_name_to_snippet_collections(workspace_root), self.model_config.embeddings_size, Distance::Cosine, ) @@ -303,24 +309,24 @@ impl SnippetRetriever { config: Arc, token: NumberOrString, workspace_root: &str, + workspace_path: &str, ) -> Result<()> { - debug!("building workspace snippets"); + debug!("building snippets for workspace {workspace_root}"); let start = Instant::now(); - let workspace_root = PathBuf::from(workspace_root); if self.db.is_none() { - self.initialise_database(&format!( - "{}--{}", - workspace_root - .file_name() - .ok_or_else(|| Error::NoFinalPath(workspace_root.clone()))? - .to_str() - .ok_or(Error::NonUnicode)?, - self.model_config.id.replace('/', "--"), - )) + self.initialise_database( + &format!( + "{}--{}", + "code-slices".to_owned(), + self.model_config.id.replace('/', "--"), + ), + workspace_root, + ) .await?; } + let workspace_path: PathBuf = PathBuf::from(workspace_path); let mut files = Vec::new(); - let mut gitignore = Gitignore::parse(&workspace_root).ok(); + let mut gitignore = Gitignore::parse(&workspace_path).ok(); for pattern in config.ignored_paths.iter() { if let Some(gitignore) = gitignore.as_mut() { if let Err(err) = gitignore.add_rule(pattern.clone()) { @@ -341,7 +347,7 @@ impl SnippetRetriever { }) .await; let mut stack = VecDeque::new(); - stack.push_back(workspace_root.clone()); + stack.push_back(workspace_path.clone()); while let Some(src) = stack.pop_back() { let mut entries = tokio::fs::read_dir(&src).await?; while let Some(entry) = entries.next_entry().await? { @@ -367,7 +373,7 @@ impl SnippetRetriever { } for (i, file) in files.iter().enumerate() { let file_url = file.to_str().expect("file path should be utf8").to_owned(); - self.add_document(file_url).await?; + self.add_document(file_url, workspace_root).await?; client .send_notification::(ProgressParams { token: token.clone(), @@ -376,7 +382,7 @@ impl SnippetRetriever { message: Some(format!( "{i}/{} ({})", files.len(), - file.strip_prefix(workspace_root.as_path())? + file.strip_prefix(workspace_path.as_path())? .to_str() .expect("expect file name to be valid unicode") )), @@ -393,16 +399,23 @@ impl SnippetRetriever { Ok(()) } - pub(crate) async fn add_document(&self, file_url: String) -> Result<()> { - self.build_and_add_snippets(file_url, 0, None).await?; + pub(crate) async fn add_document(&self, file_url: String, workspace_root: &str) -> Result<()> { + self.build_and_add_snippets(file_url, 0, None, workspace_root) + .await?; Ok(()) } - pub(crate) async fn update_document(&mut self, file_url: String, range: Range) -> Result<()> { + pub(crate) async fn update_document( + &mut self, + file_url: String, + range: Range, + workspace_root: &str, + ) -> Result<()> { self.build_and_add_snippets( file_url, range.start.line as usize, Some(range.end.line as usize), + workspace_root, ) .await?; Ok(()) @@ -459,12 +472,14 @@ impl SnippetRetriever { &self, query: &[f32], filter: Option, + workspace_root: &str, ) -> Result> { let db = match self.db.as_ref() { Some(db) => db.clone(), None => return Err(Error::UninitialisedDatabase), }; - let col = db.get_collection(&self.collection_name).await?; + let target_collection_name = self.workspace_name_to_snippet_collections(workspace_root); + let col = db.get_collection(&target_collection_name).await?; let result = col .read() .await @@ -485,12 +500,19 @@ impl SnippetRetriever { Ok(()) } - pub(crate) async fn remove(&self, file_url: String, range: Range) -> Result<()> { + pub(crate) async fn remove( + &self, + file_url: String, + range: Range, + target_workspace: &str, + ) -> Result<()> { let db = match self.db.as_ref() { Some(db) => db.clone(), None => return Err(Error::UninitialisedDatabase), }; - let col = db.get_collection(&self.collection_name).await?; + let col = db + .get_collection(&self.workspace_name_to_snippet_collections(&target_workspace)) + .await?; col.write().await.remove(Some( Collection::filter() .comparison( @@ -542,12 +564,14 @@ impl SnippetRetriever { file_url: String, start: usize, end: Option, + workspace_root: &str, ) -> Result<()> { let db = match self.db.as_ref() { Some(db) => db.clone(), None => return Err(Error::UninitialisedDatabase), }; - let col = db.get_collection("code-slices").await?; + let collection_name = self.workspace_name_to_snippet_collections(workspace_root); + let col = db.get_collection(&collection_name).await?; let file = tokio::fs::read_to_string(&file_url).await?; let lines = file.split('\n').collect::>(); let end = end.unwrap_or(lines.len()).min(lines.len()); From 90013ef054a6ffbb07572e85267a6934d05f91fd Mon Sep 17 00:00:00 2001 From: Luc Georges Date: Wed, 6 Mar 2024 12:25:12 +0100 Subject: [PATCH 16/16] refactor: cleanup unused error variant --- crates/tinyvec-embed/src/error.rs | 2 -- 1 file changed, 2 deletions(-) diff --git a/crates/tinyvec-embed/src/error.rs b/crates/tinyvec-embed/src/error.rs index 60e41b4..8dc0f54 100644 --- a/crates/tinyvec-embed/src/error.rs +++ b/crates/tinyvec-embed/src/error.rs @@ -36,8 +36,6 @@ pub enum Error { ValueNotNumber(Value), #[error("expected value to be string, got: {0}")] ValueNotString(Value), - #[error("expected value to be a valid size, got: {0}")] - ValueNotUsize(Value), } pub type Result = std::result::Result;