Skip to content

Commit

Permalink
Added euclidean metric to basic backend
Browse files Browse the repository at this point in the history
  • Loading branch information
Pringled committed Nov 29, 2024
1 parent f131c19 commit 00e4af2
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 32 deletions.
92 changes: 60 additions & 32 deletions vicinity/backends/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from dataclasses import dataclass
from pathlib import Path
from typing import Any
from typing import Any, Literal

import numpy as np
from numpy import typing as npt
Expand All @@ -13,7 +13,8 @@


@dataclass
class BasicArgs(BaseArgs): ...
class BasicArgs(BaseArgs):
metric: Literal["cosine", "euclidean"] = "cosine"


class BasicBackend(AbstractBackend[BasicArgs]):
Expand All @@ -24,6 +25,8 @@ def __init__(self, vectors: npt.NDArray, arguments: BasicArgs) -> None:
super().__init__(arguments)
self._vectors = vectors
self._norm_vectors: npt.NDArray | None = None
self._squared_norm_vectors: npt.NDArray | None = None
self._update_precomputed_data()

def __len__(self) -> int:
"""Get the number of vectors."""
Expand All @@ -37,7 +40,8 @@ def backend_type(self) -> Backend:
@classmethod
def from_vectors(cls: type[BasicBackend], vectors: npt.NDArray, **kwargs: Any) -> BasicBackend:
"""Create a new instance from vectors."""
return cls(vectors, BasicArgs())
arguments = BasicArgs(**kwargs)
return cls(vectors, arguments)

@classmethod
def load(cls: type[BasicBackend], folder: Path) -> BasicBackend:
Expand Down Expand Up @@ -70,9 +74,18 @@ def vectors(self, x: Matrix) -> None:
if not np.ndim(matrix) == 2:
raise ValueError(f"Your array does not have 2 dimensions: {np.ndim(matrix)}")
self._vectors = matrix
# Make sure norm vectors is updated.
if self._norm_vectors is not None:
self._norm_vectors = normalize_or_copy(matrix)
self._update_precomputed_data()

def squared_norm(self, x: np.ndarray) -> np.ndarray:
"""Compute the squared norm of a matrix."""
return (x**2).sum(1)

def _update_precomputed_data(self) -> None:
"""Update precomputed data based on the metric."""
if self.arguments.metric == "cosine":
self._norm_vectors = normalize_or_copy(self._vectors)
elif self.arguments.metric == "euclidean":
self._squared_norm_vectors = self.squared_norm(self._vectors)

@property
def norm_vectors(self) -> npt.NDArray:
Expand All @@ -85,17 +98,24 @@ def norm_vectors(self) -> npt.NDArray:
self._norm_vectors = normalize_or_copy(self.vectors)
return self._norm_vectors

@property
def squared_norm_vectors(self) -> npt.NDArray:
"""The squared norms of the vectors."""
if self._squared_norm_vectors is None:
self._squared_norm_vectors = self.squared_norm(self.vectors)
return self._squared_norm_vectors

def threshold(
self,
vectors: npt.NDArray,
threshold: float,
) -> list[npt.NDArray]:
"""Batched cosine similarity."""
"""Batched distance thresholding."""
out: list[npt.NDArray] = []
for i in range(0, len(vectors), 1024):
batch = vectors[i : i + 1024]
distances = self._dist(batch, self.norm_vectors)
for _, sims in enumerate(distances):
distances = self._dist(batch)
for sims in distances:
indices = np.flatnonzero(sims <= threshold)
sorted_indices = indices[np.argsort(sims[indices])]
out.append(sorted_indices)
Expand All @@ -107,43 +127,51 @@ def query(
vectors: npt.NDArray,
k: int,
) -> QueryResult:
"""Batched cosine distance."""
"""Batched distance query."""
if k < 1:
raise ValueError("num should be >= 1, is now {num}")
raise ValueError(f"k should be >= 1, is now {k}")

out: QueryResult = []
num_vectors = len(self.vectors)
effective_k = min(k, num_vectors)

for index in range(0, len(vectors), 1024):
batch = vectors[index : index + 1024]
distances = self._dist(batch, self.norm_vectors)
if k == 1:
sorted_indices = np.argmin(distances, 1, keepdims=True)
elif k >= len(self.vectors):
# If we want more than we have, just sort everything.
sorted_indices = np.stack([np.arange(len(self.vectors))] * len(vectors))
else:
sorted_indices = np.argpartition(distances, kth=k, axis=1)
sorted_indices = sorted_indices[:, :k]
for lidx, indices in enumerate(sorted_indices):
dists_for_word = distances[lidx, indices]
word_index = np.argsort(dists_for_word)
i = indices[word_index]
d = dists_for_word[word_index]
out.append((i, d))
distances = self._dist(batch)

return out
# Use argpartition for efficiency
indices = np.argpartition(distances, kth=effective_k - 1, axis=1)[:, :effective_k]
sorted_indices = np.take_along_axis(
indices, np.argsort(np.take_along_axis(distances, indices, axis=1)), axis=1
)
sorted_distances = np.take_along_axis(distances, sorted_indices, axis=1)

@classmethod
def _dist(cls, x: npt.NDArray, y: npt.NDArray) -> npt.NDArray:
"""Cosine distance function. This assumes y is normalized."""
sim = normalize(x).dot(y.T)
out.extend(zip(sorted_indices, sorted_distances))

return out

return 1 - sim
def _dist(self, x: npt.NDArray) -> npt.NDArray:
"""Compute distances between x and self._vectors based on the given metric."""
if self.arguments.metric == "cosine":
x_norm = normalize(x)
sim = x_norm.dot(self.norm_vectors.T)
return 1 - sim
elif self.arguments.metric == "euclidean":
x_norm = self.squared_norm(x)
dists_squared = (x_norm[:, None] + self.squared_norm_vectors[None, :]) - 2 * (x @ self._vectors.T)

# Ensure non-negative distances
dists_squared = np.maximum(dists_squared, 1e-12)
return np.sqrt(dists_squared)
else:
raise ValueError(f"Unsupported metric: {self.arguments.metric}")

def insert(self, vectors: npt.NDArray) -> None:
"""Insert vectors into the vector space."""
self._vectors = np.vstack([self._vectors, vectors])
self._update_precomputed_data()

def delete(self, indices: list[int]) -> None:
"""Deletes specific indices from the vector space."""
self._vectors = np.delete(self._vectors, indices, axis=0)
self._update_precomputed_data()
1 change: 1 addition & 0 deletions vicinity/vicinity.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from __future__ import annotations

import logging
import time
from io import open
from pathlib import Path
from typing import Any, Sequence, Union
Expand Down

0 comments on commit 00e4af2

Please sign in to comment.