Skip to content

Commit

Permalink
fix: bug in insert and delete
Browse files Browse the repository at this point in the history
  • Loading branch information
stephantul committed Dec 11, 2024
1 parent 852c6d4 commit db6cf93
Show file tree
Hide file tree
Showing 5 changed files with 84 additions and 11 deletions.
2 changes: 1 addition & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,6 @@ def vicinity_instance_with_stored_vectors(


@pytest.fixture()
def vicinity_with_basic_backend(vectors: np.ndarray, items: list[str]) -> Vicinity:
def vicinity_with_basic_backend_and_store(vectors: np.ndarray, items: list[str]) -> Vicinity:
"""Fixture providing a BasicBackend instance."""
return Vicinity.from_vectors_and_items(vectors, items, backend_type=Backend.BASIC, store_vectors=True)
39 changes: 34 additions & 5 deletions tests/test_vicinity.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,25 +154,25 @@ def test_vicinity_save_and_load_vector_store(tmp_path: Path, vicinity_instance_w
assert v.vector_store is not None


def test_index_vector_store(vicinity_with_basic_backend: Vicinity, vectors: np.ndarray) -> None:
def test_index_vector_store(vicinity_with_basic_backend_and_store: Vicinity, vectors: np.ndarray) -> None:
"""
Index vectors in the Vicinity instance.
:param vicinity_with_basic_backend_and_store: A Vicinity instance with a basic backend and stored vectors.
:param vectors: Array of vectors to index.
"""
v = vicinity_with_basic_backend.get_vector_by_index(0)
v = vicinity_with_basic_backend_and_store.get_vector_by_index(0)
assert np.allclose(v, vectors[0])

idx = [0, 1, 2, 3, 4, 10]
v = vicinity_with_basic_backend.get_vector_by_index(idx)
v = vicinity_with_basic_backend_and_store.get_vector_by_index(idx)
assert np.allclose(v, vectors[idx])

with pytest.raises(ValueError):
vicinity_with_basic_backend.get_vector_by_index([10_000])
vicinity_with_basic_backend_and_store.get_vector_by_index([10_000])

with pytest.raises(ValueError):
vicinity_with_basic_backend.get_vector_by_index([-1])
vicinity_with_basic_backend_and_store.get_vector_by_index([-1])


def test_vicinity_insert_duplicate(vicinity_instance: Vicinity, query_vector: np.ndarray) -> None:
Expand Down Expand Up @@ -203,6 +203,35 @@ def test_vicinity_delete_nonexistent(vicinity_instance: Vicinity) -> None:
vicinity_instance.delete(["item10002"])


def test_vicinity_insert_with_store(vicinity_with_basic_backend_and_store: Vicinity) -> None:
    """
    Test that Vicinity.insert also adds the new vectors to the vector store.

    Inserting into an instance that keeps a vector store must not raise; the
    store has to grow together with the index so both stay the same length.

    :param vicinity_with_basic_backend_and_store: A Vicinity instance with stored vectors.
    """
    vicinity = vicinity_with_basic_backend_and_store
    assert vicinity.vector_store is not None
    store_size_before = len(vicinity.vector_store)

    new_item = ["item10002"]
    new_vector = np.full((1, vicinity.dim), 0.5)
    vicinity.insert(new_item, new_vector)

    assert vicinity.vector_store is not None
    # The item must be registered and the store must have received exactly one vector.
    assert "item10002" in vicinity.items
    assert len(vicinity.vector_store) == store_size_before + 1
    # Index and store must stay in sync after the insert.
    assert len(vicinity) == len(vicinity.vector_store)


def test_vicinity_delete_with_store(vicinity_with_basic_backend_and_store: Vicinity) -> None:
    """
    Test that Vicinity.delete also removes the deleted vectors from the vector store.

    After deleting an item, it must be gone from the item list and the vector
    store must have the same length as the Vicinity instance.

    :param vicinity_with_basic_backend_and_store: A Vicinity instance with stored vectors.
    """
    vicinity = vicinity_with_basic_backend_and_store
    assert vicinity.vector_store is not None
    store_size_before = len(vicinity.vector_store)

    # Delete "item2" from the Vicinity instance
    vicinity.delete(["item2"])

    # Ensure "item2" is no longer in the items list
    assert "item2" not in vicinity.items
    # Exactly one vector must have been dropped from the store,
    # and index and store must stay in sync.
    assert len(vicinity.vector_store) == store_size_before - 1
    assert len(vicinity) == len(vicinity.vector_store)


def test_vicinity_insert_mismatched_lengths(vicinity_instance: Vicinity, query_vector: np.ndarray) -> None:
"""
Test that Vicinity.insert raises ValueError when tokens and vectors lengths do not match.
Expand Down
42 changes: 41 additions & 1 deletion uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

8 changes: 4 additions & 4 deletions vicinity/backends/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,10 @@ def vectors(self, x: Matrix) -> None:
self._vectors = matrix
self._update_precomputed_data()

def __len__(self) -> int:
"""Return the number of vectors, i.e. the size of the first axis of ``self.vectors``."""
return self.vectors.shape[0]


class BasicBackend(BasicVectorStore, AbstractBackend[BasicArgs], ABC):
argument_class = BasicArgs
Expand All @@ -100,10 +104,6 @@ def __init__(self, vectors: npt.NDArray, arguments: BasicArgs) -> None:
"""Initialize the backend."""
super().__init__(vectors=vectors, arguments=arguments)

def __len__(self) -> int:
"""Return the number of vectors, i.e. the size of the first axis of ``self.vectors``."""
return self.vectors.shape[0]

@property
def backend_type(self) -> Backend:
"""The type of the backend."""
Expand Down
4 changes: 4 additions & 0 deletions vicinity/vicinity.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,8 @@ def insert(self, tokens: Sequence[str], vectors: npt.NDArray) -> None:
raise ValueError(f"Token {token} is already in the vector space.")
self.items.append(token)
self.backend.insert(vectors)
if self.vector_store is not None:
self.vector_store.insert(vectors)

def delete(self, tokens: Sequence[str]) -> None:
"""
Expand All @@ -260,6 +262,8 @@ def delete(self, tokens: Sequence[str]) -> None:
raise ValueError(f"Token {exc} was not in the vector space.") from exc

self.backend.delete(curr_indices)
if self.vector_store is not None:
self.vector_store.delete(curr_indices)

# Delete items starting from the highest index
for index in sorted(curr_indices, reverse=True):
Expand Down

0 comments on commit db6cf93

Please sign in to comment.