From 8b819513c379254397ae779c32f1606455c67f99 Mon Sep 17 00:00:00 2001 From: nohehf Date: Wed, 22 May 2024 09:18:11 +0200 Subject: [PATCH] feat: add status --- src/classes.rs | 95 +++++++++++++++++++++++++++++++-- src/lib.rs | 4 +- src/stack_graphs_wrapper/mod.rs | 50 +++++++++++++++++ stack_graphs_python.pyi | 36 +++++++++++++ tests/js_ok_test.py | 6 ++- tests/ts_ko_test.py | 22 +++++--- 6 files changed, 201 insertions(+), 12 deletions(-) diff --git a/src/classes.rs b/src/classes.rs index 0908bdb..b42d297 100644 --- a/src/classes.rs +++ b/src/classes.rs @@ -6,7 +6,9 @@ use stack_graphs::storage::{SQLiteReader, SQLiteWriter}; use tree_sitter_stack_graphs::cli::util::{SourcePosition, SourceSpan}; use tree_sitter_stack_graphs::loader::Loader; -use crate::stack_graphs_wrapper::{index_all, new_loader, query_definition}; +use crate::stack_graphs_wrapper::{ + get_status, get_status_all, index_all, new_loader, query_definition, +}; #[pyclass] #[derive(Clone)] @@ -17,6 +19,77 @@ pub enum Language { Java, } +#[pyclass] +#[derive(Clone)] +pub enum FileStatus { + Missing, + Indexed, + Error, +} + +#[pyclass] +#[derive(Clone)] +pub struct FileEntry { + #[pyo3(get)] + pub path: String, + #[pyo3(get)] + pub tag: String, + #[pyo3(get)] + pub status: FileStatus, + // As pyo3 does not support string enums, we use Option here instead. + #[pyo3(get)] + pub error: Option, +} + +impl From for FileEntry { + fn from(entry: stack_graphs::storage::FileEntry) -> Self { + let status = match entry.status { + stack_graphs::storage::FileStatus::Missing => FileStatus::Missing, + stack_graphs::storage::FileStatus::Indexed => FileStatus::Indexed, + stack_graphs::storage::FileStatus::Error(_) => FileStatus::Error, + }; + + let error = match entry.status { + stack_graphs::storage::FileStatus::Error(e) => Some(e), + _ => None, + }; + + FileEntry { + path: entry.path.to_str().unwrap().to_string(), + tag: entry.tag, + status, + error, + } + } +} + +#[pymethods] +impl FileEntry { + fn __repr__(&self) -> String { + match self { + FileEntry { + path, + tag, + status, + error, + } => { + let error = match error { + Some(e) => format!("(\"{}\")", e), + None => "".to_string(), + }; + + format!( + "FileEntry(path=\"{}\", tag=\"{}\", status={}{})", + path, + tag, + status.__pyo3__repr__(), + error + ) + } + } + } +} + #[pyclass] #[derive(Clone)] pub struct Position { @@ -66,6 +139,7 @@ impl Querier { #[pyclass] pub struct Indexer { db_writer: SQLiteWriter, + db_reader: SQLiteReader, db_path: String, loader: Loader, } @@ -76,6 +150,7 @@ impl Indexer { pub fn new(db_path: String, languages: Vec) -> Self { Indexer { db_writer: SQLiteWriter::open(db_path.clone()).unwrap(), + db_reader: SQLiteReader::open(db_path.clone()).unwrap(), db_path: db_path, loader: new_loader(languages), } @@ -91,8 +166,22 @@ impl Indexer { } } - // @TODO: Add a method to retrieve the status of the files (indexed, failed, etc.) - // This might be done on a separate class (Database / Storage), as it is tied to the storage, not a specific indexer + pub fn status(&mut self, paths: Vec) -> PyResult> { + let paths: Vec = + paths.iter().map(|p| std::path::PathBuf::from(p)).collect(); + + get_status(paths, &mut self.db_reader)? + .into_iter() + .map(|e| Ok(e.into())) + .collect() + } + + pub fn status_all(&mut self) -> PyResult> { + get_status_all(&mut self.db_reader)? + .into_iter() + .map(|e| Ok(e.into())) + .collect() + } fn __repr__(&self) -> String { format!("Indexer(db_path=\"{}\")", self.db_path) diff --git a/src/lib.rs b/src/lib.rs index 63e0e54..c808aee 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -3,7 +3,7 @@ use pyo3::prelude::*; mod classes; mod stack_graphs_wrapper; -use classes::{Indexer, Language, Position, Querier}; +use classes::{FileEntry, FileStatus, Indexer, Language, Position, Querier}; /// Formats the sum of two numbers as string. #[pyfunction] @@ -34,6 +34,8 @@ fn stack_graphs_python(_py: Python, m: &PyModule) -> PyResult<()> { m.add_function(wrap_pyfunction!(index, m)?)?; m.add_class::()?; m.add_class::()?; + m.add_class::()?; + m.add_class::()?; m.add_class::()?; m.add_class::()?; Ok(()) diff --git a/src/stack_graphs_wrapper/mod.rs b/src/stack_graphs_wrapper/mod.rs index 4e713fa..38701fc 100644 --- a/src/stack_graphs_wrapper/mod.rs +++ b/src/stack_graphs_wrapper/mod.rs @@ -15,6 +15,12 @@ pub struct StackGraphsError { message: String, } +impl StackGraphsError { + pub fn from(message: String) -> StackGraphsError { + StackGraphsError { message } + } +} + impl std::convert::From for PyErr { fn from(err: StackGraphsError) -> PyErr { PyException::new_err(err.message) @@ -141,3 +147,47 @@ fn canonicalize_paths(paths: Vec) -> Vec { .collect::, _>>() .unwrap() } + +pub fn get_status_all( + db_reader: &mut SQLiteReader, +) -> Result, StackGraphsError> { + let mut files = db_reader + .list_all() + .map_err(|e| StackGraphsError::from(e.to_string()))?; + let iter = files + .try_iter() + .map_err(|e| StackGraphsError::from(e.to_string()))?; + + let results = iter + .collect::, _>>() + .map_err(|e| StackGraphsError::from(e.to_string()))?; + + Ok(results) +} + +pub fn get_status( + paths: Vec, + db_reader: &mut SQLiteReader, +) -> Result, StackGraphsError> { + let paths = canonicalize_paths(paths); + + let mut entries: Vec = Vec::new(); + + for path in paths { + let mut files = db_reader + .list_file_or_directory(&path) + .map_err(|e| StackGraphsError::from(e.to_string()))?; + + let iter = files + .try_iter() + .map_err(|e| StackGraphsError::from(e.to_string()))?; + + let results = iter + .collect::, _>>() + .map_err(|e| StackGraphsError::from(e.to_string()))?; + + entries.extend(results) + } + + Ok(entries) +} diff --git a/stack_graphs_python.pyi b/stack_graphs_python.pyi index ae78ba8..073c39a 100644 --- a/stack_graphs_python.pyi +++ b/stack_graphs_python.pyi @@ -6,6 +6,26 @@ class Language(Enum): TypeScript = 2 Java = 3 +class FileStatus(Enum): + Indexed = 0 + Missing = 1 + Error = 2 + +class FileEntry: + """ + An entry in the stack graphs database for a given file: + """ + + path: str + tag: str + status: FileStatus + error: str | None + """ + Error message if status is FileStatus.Error + """ + + def __repr__(self) -> str: ... + class Position: """ A position in a given file: @@ -51,6 +71,22 @@ class Indexer: Index all the files in the given paths, recursively """ ... + + def status(self, paths: list[str]) -> list[FileEntry]: + """ + Get the status of the given files + - paths: the paths to the files or directories + - returns: a list of FileEntry objects + """ + ... + + def status_all(self) -> list[FileEntry]: + """ + Get the status of all the files in the database + - returns: a list of FileEntry objects + """ + ... + def __repr__(self) -> str: ... def index(paths: list[str], db_path: str, language: Language) -> None: diff --git a/tests/js_ok_test.py b/tests/js_ok_test.py index c149eeb..6d6911d 100644 --- a/tests/js_ok_test.py +++ b/tests/js_ok_test.py @@ -1,5 +1,5 @@ from helpers.virtual_files import string_to_virtual_files -from stack_graphs_python import index, Indexer, Querier, Language +from stack_graphs_python import index, Indexer, Querier, Language, FileStatus import os code = """ @@ -24,6 +24,10 @@ def test_js_ok(): db_path = os.path.join(dir, "db.sqlite") indexer = Indexer(db_path, [Language.JavaScript]) indexer.index_all([dir]) + status = indexer.status_all() + assert len(status) == 2 + assert status[0].path == os.path.join(dir, "index.js") + assert status[0].status == FileStatus.Indexed querier = Querier(db_path) source_reference = positions["query"] results = querier.definitions(source_reference) diff --git a/tests/ts_ko_test.py b/tests/ts_ko_test.py index 503b711..f4e8464 100644 --- a/tests/ts_ko_test.py +++ b/tests/ts_ko_test.py @@ -1,7 +1,6 @@ from helpers.virtual_files import string_to_virtual_files -from stack_graphs_python import Indexer, Language +from stack_graphs_python import Indexer, Language, FileStatus import os -import pytest ok_code = """ ;---index.ts--- @@ -27,17 +26,26 @@ class A { def test_ts_ok(): with string_to_virtual_files(ok_code) as (dir, _): - db_path = os.path.abspath("./db.sqlite") + db_path = os.path.join(dir, "db.sqlite") dir = os.path.abspath(dir) indexer = Indexer(db_path, [Language.TypeScript]) indexer.index_all([dir]) + status = indexer.status_all() + assert len(status) == 1 + assert status[0].path == os.path.join(dir, "index.ts") + assert status[0].status == FileStatus.Indexed -@pytest.mark.skip("WIP: add a way to check for errors indexing errors") def test_ts_ko(): with string_to_virtual_files(ko_code) as (dir, _): - print("here") - db_path = os.path.abspath("./db.sqlite") + db_path = os.path.join(dir, "db.sqlite") dir = os.path.abspath(dir) indexer = Indexer(db_path, [Language.TypeScript]) - indexer.index_all([dir], db_path, language=Language.TypeScript) + indexer.index_all([dir]) + status = indexer.status_all() + assert len(status) == 1 + assert status[0].path == os.path.join(dir, "index.ts") + assert status[0].status == FileStatus.Error + assert status[0].error is not None + # TODO(@nohehf): Add logs when we fail to index a file + assert status[0].error != "Error parsing source"