Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Handle multiple initialized workspaces in get Completions #91

Open
wants to merge 16 commits into
base: feat/multi_file_context
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
115 changes: 114 additions & 1 deletion crates/llm-ls/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,7 @@ async fn build_prompt(
file_url: &str,
language_id: LanguageId,
snippet_retriever: Arc<RwLock<SnippetRetriever>>,
target_workspace: &str,
) -> Result<String> {
let t = Instant::now();
if fim.enabled {
Expand Down Expand Up @@ -259,6 +260,7 @@ async fn build_prompt(
Compare::Neq,
file_url.into(),
)),
target_workspace
)
.await?;
let context_header = build_context_header(language_id, snippets);
Expand Down Expand Up @@ -307,6 +309,7 @@ async fn build_prompt(
Compare::Neq,
file_url.into(),
)),
target_workspace
)
.await?;
let context_header = build_context_header(language_id, snippets);
Expand Down Expand Up @@ -508,6 +511,25 @@ fn build_url(backend: Backend, model: &str) -> String {
}
}

fn file_uri_to_workspace(workspace_folders: Option<&Vec<WorkspaceFolder>>, file_uri: &str) -> String {
// let folders = self.workspace_folders.read().await;
match workspace_folders {
Some(folders) => {
let parent_workspace = folders
.clone()
.into_iter()
.filter(|folder| file_uri.contains(folder.uri.path()))
.collect::<Vec<WorkspaceFolder>>();
if parent_workspace.is_empty() {
folders[0].name.clone()
} else {
parent_workspace[0].name.clone()
}
}
None => "".to_string(),
}
}

impl LlmService {
async fn get_completions(
&self,
Expand All @@ -520,6 +542,9 @@ impl LlmService {
let document_map = self.document_map.read().await;

let file_url = params.text_document_position.text_document.uri.as_str();
let target_workspace = file_uri_to_workspace(
self.workspace_folders.read().await.as_ref(),
file_url);
let document =
match document_map.get(file_url) {
Some(doc) => doc,
Expand Down Expand Up @@ -578,6 +603,7 @@ impl LlmService {
&file_url.replace("file://", ""),
document.language_id,
self.snippet_retriever.clone(),
&target_workspace,
).await?;

let http_client = if params.tls_skip_verify_insecure {
Expand Down Expand Up @@ -700,12 +726,15 @@ impl LanguageServer for LlmService {
.await;
}
let mut guard = snippet_retriever.write().await;
let workspace_path = &workspace_folders[0].uri.path().to_string();
let workspace_root = file_uri_to_workspace(Some(workspace_folders), &workspace_path);
tokio::select! {
res = guard.build_workspace_snippets(
client.clone(),
config,
token,
workspace_folders[0].uri.path(),
&workspace_root,
&workspace_path,
) => {
if let Err(err) = res {
error!("failed building workspace snippets: {err}");
Expand Down Expand Up @@ -778,13 +807,17 @@ impl LanguageServer for LlmService {
match doc.change(range, &change.text).await {
Ok((start, old_end, new_end)) => {
let start = Position::new(start as u32, 0);
let workspace_folders = self.workspace_folders.read().await;
let target_workspace = file_uri_to_workspace(workspace_folders.as_ref(), path);

if let Err(err) = self
.snippet_retriever
.write()
.await
.remove(
path.to_owned(),
Range::new(start, Position::new(old_end as u32, 0)),
&target_workspace,
)
.await
{
Expand All @@ -797,6 +830,7 @@ impl LanguageServer for LlmService {
.update_document(
path.to_owned(),
Range::new(start, Position::new(new_end as u32, 0)),
&target_workspace,
)
.await
{
Expand Down Expand Up @@ -897,6 +931,7 @@ async fn main() {
.danger_accept_invalid_certs(true)
.build()
.expect("failed to build reqwest unsafe client");
debug!("Reading {:?}", cache_dir);

let config = Arc::new(
load_config(
Expand Down Expand Up @@ -952,3 +987,81 @@ async fn main() {
Server::new(stdin, stdout, socket).serve(service).await;
}
}

#[cfg(test)]
mod tests {
use super::*;

async fn service_setup() -> LspService<LlmService> {
let cache_dir = PathBuf::from(r"idontexist");
let config = Arc::new(LlmLsConfig {
..Default::default()
});
let snippet_retriever = Arc::new(RwLock::new(
SnippetRetriever::new(cache_dir.join("embeddings"), config.model.clone(), 20, 10)
.await
.unwrap(),
));
let (service, _) = LspService::build(|client| LlmService {
cache_dir,
client,
config,
document_map: Arc::new(RwLock::new(HashMap::new())),
http_client: reqwest::Client::new(),
unsafe_http_client: reqwest::Client::new(),
workspace_folders: Arc::new(RwLock::new(None)),
tokenizer_map: Arc::new(RwLock::new(HashMap::new())),
unauthenticated_warn_at: Arc::new(RwLock::new(
Instant::now()
.checked_sub(MAX_WARNING_REPEAT)
.expect("instant to be in bounds"),
)),
snippet_retriever,
supports_progress_bar: Arc::new(RwLock::new(false)),
cancel_snippet_build_tx: Arc::new(RwLock::new(None)),
indexation_handle: Arc::new(RwLock::new(None)),
})
.finish();
service
}

#[tokio::test]
async fn test_file_uri_to_workspace() {
    let service = service_setup().await;

    // No workspace folders registered -> empty workspace name.
    {
        let folders = service.inner().workspace_folders.read().await;
        assert_eq!(file_uri_to_workspace(folders.as_ref(), "/home/test"), "");
    }

    // File inside a registered folder -> that folder's name.
    {
        *service.inner().workspace_folders.write().await = Some(vec![
            WorkspaceFolder {
                name: "other_repo".to_string(),
                uri: Url::from_directory_path("/home/other_test").unwrap(),
            },
            WorkspaceFolder {
                name: "test_repo".to_string(),
                uri: Url::from_directory_path("/home/test").unwrap(),
            },
        ]);
        let folders = service.inner().workspace_folders.read().await;
        assert_eq!(
            file_uri_to_workspace(folders.as_ref(), "/home/test/src/lib/main.py"),
            "test_repo"
        );
    }

    // File outside every registered folder -> falls back to the first folder.
    {
        *service.inner().workspace_folders.write().await = Some(vec![WorkspaceFolder {
            name: "other_repo".to_string(),
            uri: Url::from_directory_path("/home/other_test").unwrap(),
        }]);
        let folders = service.inner().workspace_folders.read().await;
        assert_eq!(
            file_uri_to_workspace(folders.as_ref(), "/home/test/src/lib/main.py"),
            "other_repo"
        );
    }
}
}
78 changes: 51 additions & 27 deletions crates/llm-ls/src/retrieval.rs
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,6 @@ impl TryFrom<&SimilarityResult> for Snippet {

pub(crate) struct SnippetRetriever {
cache_path: PathBuf,
collection_name: String,
db: Option<Db>,
model: Arc<BertModel>,
model_config: ModelConfig,
Expand All @@ -260,13 +259,12 @@ impl SnippetRetriever {
window_size: usize,
window_step: usize,
) -> Result<Self> {
let collection_name = "code-slices".to_owned();
let (model, tokenizer) =
build_model_and_tokenizer(model_config.id.clone(), model_config.revision.clone())
.await?;
Ok(Self {
cache_path,
collection_name,
// collection_name,
db: None,
model: Arc::new(model),
model_config,
Expand All @@ -276,12 +274,20 @@ impl SnippetRetriever {
})
}

pub(crate) async fn initialise_database(&mut self, db_name: &str) -> Result<Db> {
/// Builds the vector-db collection name for a workspace, namespacing
/// snippet collections per workspace as `workspace--<name>`.
fn workspace_name_to_snippet_collections(&self, workspace_root: &str) -> String {
    // `format!` already yields a String; the previous `.to_string()` was a
    // redundant extra allocation (clippy: redundant string conversion).
    format!("workspace--{workspace_root}")
}

pub(crate) async fn initialise_database(
&mut self,
db_name: &str,
workspace_root: &str,
) -> Result<Db> {
let uri = self.cache_path.join(db_name);
let mut db = Db::open(uri).await.expect("failed to open database");
match db
.create_collection(
self.collection_name.clone(),
self.workspace_name_to_snippet_collections(workspace_root),
self.model_config.embeddings_size,
Distance::Cosine,
)
Expand All @@ -303,24 +309,24 @@ impl SnippetRetriever {
config: Arc<LlmLsConfig>,
token: NumberOrString,
workspace_root: &str,
workspace_path: &str,
) -> Result<()> {
debug!("building workspace snippets");
debug!("building snippets for workspace {workspace_root}");
let start = Instant::now();
let workspace_root = PathBuf::from(workspace_root);
if self.db.is_none() {
self.initialise_database(&format!(
"{}--{}",
workspace_root
.file_name()
.ok_or_else(|| Error::NoFinalPath(workspace_root.clone()))?
.to_str()
.ok_or(Error::NonUnicode)?,
self.model_config.id.replace('/', "--"),
))
self.initialise_database(
&format!(
"{}--{}",
"code-slices".to_owned(),
self.model_config.id.replace('/', "--"),
),
workspace_root,
)
.await?;
}
let workspace_path: PathBuf = PathBuf::from(workspace_path);
let mut files = Vec::new();
let mut gitignore = Gitignore::parse(&workspace_root).ok();
let mut gitignore = Gitignore::parse(&workspace_path).ok();
for pattern in config.ignored_paths.iter() {
if let Some(gitignore) = gitignore.as_mut() {
if let Err(err) = gitignore.add_rule(pattern.clone()) {
Expand All @@ -341,7 +347,7 @@ impl SnippetRetriever {
})
.await;
let mut stack = VecDeque::new();
stack.push_back(workspace_root.clone());
stack.push_back(workspace_path.clone());
while let Some(src) = stack.pop_back() {
let mut entries = tokio::fs::read_dir(&src).await?;
while let Some(entry) = entries.next_entry().await? {
Expand All @@ -367,7 +373,7 @@ impl SnippetRetriever {
}
for (i, file) in files.iter().enumerate() {
let file_url = file.to_str().expect("file path should be utf8").to_owned();
self.add_document(file_url).await?;
self.add_document(file_url, workspace_root).await?;
client
.send_notification::<Progress>(ProgressParams {
token: token.clone(),
Expand All @@ -376,7 +382,7 @@ impl SnippetRetriever {
message: Some(format!(
"{i}/{} ({})",
files.len(),
file.strip_prefix(workspace_root.as_path())?
file.strip_prefix(workspace_path.as_path())?
.to_str()
.expect("expect file name to be valid unicode")
)),
Expand All @@ -393,16 +399,23 @@ impl SnippetRetriever {
Ok(())
}

pub(crate) async fn add_document(&self, file_url: String) -> Result<()> {
self.build_and_add_snippets(file_url, 0, None).await?;
/// Indexes an entire file into the given workspace's snippet collection.
///
/// Convenience wrapper over `build_and_add_snippets` covering the whole
/// file: starts at line 0 with no explicit end line.
pub(crate) async fn add_document(&self, file_url: String, workspace_root: &str) -> Result<()> {
    self.build_and_add_snippets(file_url, 0, None, workspace_root).await
}

pub(crate) async fn update_document(&mut self, file_url: String, range: Range) -> Result<()> {
pub(crate) async fn update_document(
&mut self,
file_url: String,
range: Range,
workspace_root: &str,
) -> Result<()> {
self.build_and_add_snippets(
file_url,
range.start.line as usize,
Some(range.end.line as usize),
workspace_root,
)
.await?;
Ok(())
Expand Down Expand Up @@ -459,12 +472,14 @@ impl SnippetRetriever {
&self,
query: &[f32],
filter: Option<FilterBuilder>,
workspace_root: &str,
) -> Result<Vec<Snippet>> {
let db = match self.db.as_ref() {
Some(db) => db.clone(),
None => return Err(Error::UninitialisedDatabase),
};
let col = db.get_collection(&self.collection_name).await?;
let target_collection_name = self.workspace_name_to_snippet_collections(workspace_root);
let col = db.get_collection(&target_collection_name).await?;
let result = col
.read()
.await
Expand All @@ -485,12 +500,19 @@ impl SnippetRetriever {
Ok(())
}

pub(crate) async fn remove(&self, file_url: String, range: Range) -> Result<()> {
pub(crate) async fn remove(
&self,
file_url: String,
range: Range,
target_workspace: &str,
) -> Result<()> {
let db = match self.db.as_ref() {
Some(db) => db.clone(),
None => return Err(Error::UninitialisedDatabase),
};
let col = db.get_collection(&self.collection_name).await?;
let col = db
.get_collection(&self.workspace_name_to_snippet_collections(&target_workspace))
.await?;
col.write().await.remove(Some(
Collection::filter()
.comparison(
Expand Down Expand Up @@ -542,12 +564,14 @@ impl SnippetRetriever {
file_url: String,
start: usize,
end: Option<usize>,
workspace_root: &str,
) -> Result<()> {
let db = match self.db.as_ref() {
Some(db) => db.clone(),
None => return Err(Error::UninitialisedDatabase),
};
let col = db.get_collection("code-slices").await?;
let collection_name = self.workspace_name_to_snippet_collections(workspace_root);
let col = db.get_collection(&collection_name).await?;
let file = tokio::fs::read_to_string(&file_url).await?;
let lines = file.split('\n').collect::<Vec<_>>();
let end = end.unwrap_or(lines.len()).min(lines.len());
Expand Down
Loading
Loading