Skip to content

Commit

Permalink
fix: issue with low k in various backends (#52)
Browse files Browse the repository at this point in the history
  • Loading branch information
stephantul authored Dec 27, 2024
1 parent 6a51be5 commit 3e82218
Show file tree
Hide file tree
Showing 5 changed files with 14 additions and 2 deletions.
8 changes: 8 additions & 0 deletions tests/test_vicinity.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,10 @@ def test_vicinity_query(vicinity_instance: Vicinity, query_vector: np.ndarray) -

assert len(results) == 1

results = vicinity_instance.query(np.stack([query_vector, query_vector]), k=2)

assert results[0] == results[1]


def test_vicinity_query_threshold(vicinity_instance: Vicinity, query_vector: np.ndarray) -> None:
"""
Expand All @@ -70,6 +74,10 @@ def test_vicinity_query_threshold(vicinity_instance: Vicinity, query_vector: np.

assert len(results) >= 1

results = vicinity_instance.query_threshold(np.stack([query_vector, query_vector]), threshold=0.7)

assert results[0] == results[1]


def test_vicinity_insert(vicinity_instance: Vicinity, query_vector: np.ndarray) -> None:
"""
Expand Down
1 change: 1 addition & 0 deletions vicinity/backends/faiss.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,7 @@ def dim(self) -> int:

def query(self, vectors: npt.NDArray, k: int) -> QueryResult:
"""Perform a k-NN search in the FAISS index."""
k = min(len(self), k)
if self.arguments.metric == "cosine":
vectors = normalize(vectors)
distances, indices = self.index.search(vectors, k)
Expand Down
1 change: 1 addition & 0 deletions vicinity/backends/hnsw.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ def save(self, base_path: Path) -> None:

def query(self, vectors: npt.NDArray, k: int) -> QueryResult:
"""Query the backend."""
k = min(k, len(self))
return list(zip(*self.index.knn_query(vectors, k)))

def insert(self, vectors: npt.NDArray) -> None:
Expand Down
5 changes: 3 additions & 2 deletions vicinity/backends/usearch.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,9 +115,10 @@ def save(self, base_path: Path) -> None:

def query(self, vectors: npt.NDArray, k: int) -> QueryResult:
"""Query the backend and return results as tuples of keys and distances."""
k = min(k, len(self))
results = self.index.search(vectors, k)
keys = np.array(results.keys).reshape(-1, k)
distances = np.array(results.distances, dtype=np.float32).reshape(-1, k)
keys = np.atleast_2d(results.keys)
distances = np.atleast_2d(results.distances)
return list(zip(keys, distances))

def insert(self, vectors: npt.NDArray) -> None:
Expand Down
1 change: 1 addition & 0 deletions vicinity/backends/voyager.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ def from_vectors(

def query(self, query: npt.NDArray, k: int) -> QueryResult:
"""Query the backend for the nearest neighbors."""
k = min(k, len(self))
indices, distances = self.index.query(query, k)
return list(zip(indices, distances))

Expand Down

0 comments on commit 3e82218

Please sign in to comment.