Skip to content

Commit

Permalink
Bump redisvl and fix bugs (#26)
Browse files Browse the repository at this point in the history
* round scores to handle any float match discrepancies

* bump redisvl

* add dtype param to custom vectorizer in cache

* fix spacing and linting

* fix formatting

* update formatting and linting

* fix linting
  • Loading branch information
tylerhutcherson authored Oct 18, 2024
1 parent f7d624a commit 3757d98
Show file tree
Hide file tree
Showing 6 changed files with 98 additions and 26 deletions.
50 changes: 39 additions & 11 deletions libs/redis/langchain_redis/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from redisvl.extensions.llmcache import ( # type: ignore[import]
SemanticCache as RedisVLSemanticCache,
)
from redisvl.schema.fields import VectorDataType # type: ignore[import]
from redisvl.utils.vectorize import BaseVectorizer # type: ignore[import]

from langchain_redis.version import __full_lib_name__
Expand All @@ -36,22 +37,49 @@ def __init__(self, embeddings: Embeddings):
dims = len(embeddings.embed_query("test"))
super().__init__(model="custom_embeddings", dims=dims, embeddings=embeddings)

def encode(self, texts: Union[str, List[str]]) -> np.ndarray:
def encode(
self,
texts: Union[str, List[str]],
dtype: Union[str, VectorDataType],
**kwargs: Any,
) -> np.ndarray:
if isinstance(dtype, VectorDataType):
dtype = dtype.value.lower()
if isinstance(texts, str):
return np.array(self.embeddings.embed_query(texts), dtype=np.float32)
return np.array(self.embeddings.embed_documents(texts), dtype=np.float32)
return np.array(self.embeddings.embed_query(texts), dtype=dtype)
return np.array(self.embeddings.embed_documents(texts), dtype=dtype)

def embed(self, text: str) -> List[float]:
return self.encode(text).tolist()
def embed(
self,
text: str,
dtype: Union[str, VectorDataType] = "float32",
**kwargs: Any,
) -> List[float]:
return self.encode(text, dtype, **kwargs).tolist()

def embed_many(self, texts: List[str]) -> List[List[float]]:
return self.encode(texts).tolist()
def embed_many(
self,
texts: List[str],
dtype: Union[str, VectorDataType] = "float32",
**kwargs: Any,
) -> List[List[float]]:
return self.encode(texts, dtype, **kwargs).tolist()

async def aembed(self, text: str) -> List[float]:
return await asyncio.to_thread(self.embed, text)
async def aembed(
self,
text: str,
dtype: Union[str, VectorDataType] = "float32",
**kwargs: Any,
) -> List[float]:
return await asyncio.to_thread(self.embed, text, dtype, **kwargs)

async def aembed_many(self, texts: List[str]) -> List[List[float]]:
return await asyncio.to_thread(self.embed_many, texts)
async def aembed_many(
self,
texts: List[str],
dtype: Union[str, VectorDataType] = "float32",
**kwargs: Any,
) -> List[List[float]]:
return await asyncio.to_thread(self.embed_many, texts, dtype, **kwargs)


class RedisCache(BaseCache):
Expand Down
10 changes: 8 additions & 2 deletions libs/redis/langchain_redis/vectorstores.py
Original file line number Diff line number Diff line change
Expand Up @@ -965,7 +965,10 @@ def similarity_search_with_score_by_vector(
doc_embeddings_dict = {
doc_id: doc[self.config.embedding_field]
if self.config.storage_type == StorageType.JSON.value
else buffer_to_array(doc[self.config.embedding_field])
else buffer_to_array(
doc[self.config.embedding_field],
dtype=self.config.vector_datatype,
)
for doc_id, doc in zip(doc_ids, docs_from_storage)
}

Expand Down Expand Up @@ -1039,7 +1042,10 @@ def similarity_search_with_score_by_vector(
},
),
float(result.get("vector_distance", 0)),
buffer_to_array(doc.get(self.config.embedding_field)),
buffer_to_array(
doc.get(self.config.embedding_field),
dtype=self.config.vector_datatype,
),
)
for doc, result in zip(full_docs, results)
]
Expand Down
58 changes: 48 additions & 10 deletions libs/redis/poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion libs/redis/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ license = "MIT"
[tool.poetry.dependencies]
python = ">=3.9,<4.0"
langchain-core = "^0.3"
redisvl = "^0.3.3"
redisvl = "^0.3.5"
numpy = "^1"
python-ulid = "^2.7.0"

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -867,7 +867,7 @@ def test_similarity_search_with_score_with_limit_distance(redis_url: str) -> Non

# Print and verify the scores
for doc, score in output: # type: ignore[misc]
assert score >= 0 # Ensure score is non-negative
assert float(round(score, 4)) >= 0 # Ensure score is non-negative

# Clean up
vector_store.index.delete(drop=True)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -897,7 +897,7 @@ def test_similarity_search_with_score_with_limit_distance(redis_url: str) -> Non

# Print and verify the scores
for doc, score in output: # type: ignore[misc]
assert score >= 0 # Ensure score is non-negative
assert float(round(score, 4)) >= 0 # Ensure score is non-negative

# Clean up
vector_store.index.delete(drop=True)
Expand Down

0 comments on commit 3757d98

Please sign in to comment.