From 3e822189dc99ca977b33d3c626d7f2119855c78e Mon Sep 17 00:00:00 2001 From: Stephan Tulkens Date: Fri, 27 Dec 2024 13:49:07 +0100 Subject: [PATCH] fix: issue with low k in various backends (#52) --- tests/test_vicinity.py | 8 ++++++++ vicinity/backends/faiss.py | 1 + vicinity/backends/hnsw.py | 1 + vicinity/backends/usearch.py | 5 +++-- vicinity/backends/voyager.py | 1 + 5 files changed, 14 insertions(+), 2 deletions(-) diff --git a/tests/test_vicinity.py b/tests/test_vicinity.py index 22c25ee..daae1ca 100644 --- a/tests/test_vicinity.py +++ b/tests/test_vicinity.py @@ -58,6 +58,10 @@ def test_vicinity_query(vicinity_instance: Vicinity, query_vector: np.ndarray) - assert len(results) == 1 + results = vicinity_instance.query(np.stack([query_vector, query_vector]), k=2) + + assert results[0] == results[1] + def test_vicinity_query_threshold(vicinity_instance: Vicinity, query_vector: np.ndarray) -> None: """ @@ -70,6 +74,10 @@ def test_vicinity_query_threshold(vicinity_instance: Vicinity, query_vector: np. assert len(results) >= 1 + results = vicinity_instance.query_threshold(np.stack([query_vector, query_vector]), threshold=0.7) + + assert results[0] == results[1] + def test_vicinity_insert(vicinity_instance: Vicinity, query_vector: np.ndarray) -> None: """ diff --git a/vicinity/backends/faiss.py b/vicinity/backends/faiss.py index fc0ca31..d6b5d27 100644 --- a/vicinity/backends/faiss.py +++ b/vicinity/backends/faiss.py @@ -146,6 +146,7 @@ def dim(self) -> int: def query(self, vectors: npt.NDArray, k: int) -> QueryResult: """Perform a k-NN search in the FAISS index.""" + k = min(len(self), k) if self.arguments.metric == "cosine": vectors = normalize(vectors) distances, indices = self.index.search(vectors, k) diff --git a/vicinity/backends/hnsw.py b/vicinity/backends/hnsw.py index e9c124b..df1ac93 100644 --- a/vicinity/backends/hnsw.py +++ b/vicinity/backends/hnsw.py @@ -93,6 +93,7 @@ def save(self, base_path: Path) -> None: def query(self, vectors: npt.NDArray, k: int) -> QueryResult: """Query the backend.""" + k = min(k, len(self)) return list(zip(*self.index.knn_query(vectors, k))) def insert(self, vectors: npt.NDArray) -> None: diff --git a/vicinity/backends/usearch.py b/vicinity/backends/usearch.py index 2781cdc..41b7b85 100644 --- a/vicinity/backends/usearch.py +++ b/vicinity/backends/usearch.py @@ -115,9 +115,10 @@ def save(self, base_path: Path) -> None: def query(self, vectors: npt.NDArray, k: int) -> QueryResult: """Query the backend and return results as tuples of keys and distances.""" + k = min(k, len(self)) results = self.index.search(vectors, k) - keys = np.array(results.keys).reshape(-1, k) - distances = np.array(results.distances, dtype=np.float32).reshape(-1, k) + keys = np.atleast_2d(results.keys) + distances = np.atleast_2d(results.distances) return list(zip(keys, distances)) def insert(self, vectors: npt.NDArray) -> None: diff --git a/vicinity/backends/voyager.py b/vicinity/backends/voyager.py index 5b2c316..b09b707 100644 --- a/vicinity/backends/voyager.py +++ b/vicinity/backends/voyager.py @@ -68,6 +68,7 @@ def from_vectors( def query(self, query: npt.NDArray, k: int) -> QueryResult: """Query the backend for the nearest neighbors.""" + k = min(k, len(self)) indices, distances = self.index.query(query, k) return list(zip(indices, distances))