Skip to content

Commit

Permalink
Merge pull request #277 from tigergraph/dev
Browse files Browse the repository at this point in the history
Dev
  • Loading branch information
parkererickson-tg authored Nov 18, 2024
2 parents 9a49902 + d9672c4 commit d80acfa
Show file tree
Hide file tree
Showing 52 changed files with 4,275 additions and 790 deletions.
15 changes: 6 additions & 9 deletions common/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,16 @@
AWSBedrock,
AzureOpenAI,
GoogleVertexAI,
OpenAI,
Groq,
Ollama,
HuggingFaceEndpoint,
LLM_Model,
Ollama,
OpenAI,
IBMWatsonX
)
from common.logs.logwriter import LogWriter
from common.session import SessionHandler
from common.status import StatusManager
from common.logs.logwriter import LogWriter

security = HTTPBasic()
session_handler = SessionHandler()
Expand Down Expand Up @@ -91,8 +92,6 @@
"MILVUS_CONFIG must be a .json file or a JSON string, failed with error: "
+ str(e)
)


if llm_config["embedding_service"]["embedding_model_service"].lower() == "openai":
embedding_service = OpenAI_Embedding(llm_config["embedding_service"])
elif llm_config["embedding_service"]["embedding_model_service"].lower() == "azure":
Expand All @@ -105,7 +104,7 @@
raise Exception("Embedding service not implemented")


def get_llm_service(llm_config):
def get_llm_service(llm_config) -> LLM_Model:
if llm_config["completion_service"]["llm_service"].lower() == "openai":
return OpenAI(llm_config["completion_service"])
elif llm_config["completion_service"]["llm_service"].lower() == "azure":
Expand All @@ -127,11 +126,9 @@ def get_llm_service(llm_config):
else:
raise Exception("LLM Completion Service Not Supported")


LogWriter.info(
f"Milvus enabled for host {milvus_config['host']} at port {milvus_config['port']}"
)

if os.getenv("INIT_EMBED_STORE", "true")=="true":
LogWriter.info("Setting up Milvus embedding store for InquiryAI")
try:
Expand Down Expand Up @@ -190,7 +187,7 @@ def get_llm_service(llm_config):
):
doc_processing_config = {
"chunker": "semantic",
"chunker_config": {"method": "percentile", "threshold": 0.95},
"chunker_config": {"method": "percentile", "threshold": 0.90},
"extractor": "llm",
"extractor_config": {},
}
Expand Down
40 changes: 35 additions & 5 deletions common/embeddings/embedding_services.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
import logging
import os
import time
from typing import List

from langchain.schema.embeddings import Embeddings
import logging
import time

from common.logs.log import req_id_cv
from common.metrics.prometheus_metrics import metrics
from common.logs.logwriter import LogWriter
from common.metrics.prometheus_metrics import metrics

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -87,6 +89,33 @@ def embed_query(self, question: str) -> List[float]:
duration
)

async def aembed_query(self, question: str) -> List[float]:
    """Asynchronously embed a single query string.

    Args:
        question (str):
            The text to embed.

    Returns:
        List[float]: The embedding vector produced by the underlying
        embedding service.
    """
    # NOTE(review): unlike the synchronous embed_query(), no Prometheus
    # metrics are recorded here; the previous commented-out metrics
    # scaffolding was removed — add async-safe metrics if/when needed.
    logger.debug_pii(f"aembed_query() embedding question={question}")
    return await self.embeddings.aembed_query(question)


class AzureOpenAI_Ada002(EmbeddingModel):
"""Azure OpenAI Ada-002 Embedding Model"""
Expand All @@ -105,7 +134,8 @@ def __init__(self, config):
super().__init__(
config, model_name=config.get("model_name", "OpenAI gpt-4-0613")
)
from langchain_openai import OpenAIEmbeddings
# from langchain_openai import OpenAIEmbeddings
from langchain_community.embeddings.openai import OpenAIEmbeddings

self.embeddings = OpenAIEmbeddings()

Expand All @@ -124,8 +154,8 @@ class AWS_Bedrock_Embedding(EmbeddingModel):
"""AWS Bedrock Embedding Model"""

def __init__(self, config):
from langchain_community.embeddings import BedrockEmbeddings
import boto3
from langchain_community.embeddings import BedrockEmbeddings

super().__init__(config=config, model_name=config["embedding_model"])

Expand Down
150 changes: 140 additions & 10 deletions common/embeddings/milvus_embedding_store.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,22 @@
import logging
import traceback
from time import sleep, time
from typing import Iterable, List, Optional, Tuple

from langchain_milvus.vectorstores import Milvus
import Levenshtein as lev
from asyncer import asyncify
from langchain_community.vectorstores import Milvus
from langchain_core.documents.base import Document
from pymilvus import connections, utility
# from langchain_milvus.vectorstores import Milvus
from langchain_community.vectorstores.milvus import Milvus
from pymilvus import MilvusException, connections, utility
from pymilvus.exceptions import MilvusException

from common.embeddings.base_embedding_store import EmbeddingStore
from common.embeddings.embedding_services import EmbeddingModel
from common.logs.log import req_id_cv
from common.metrics.prometheus_metrics import metrics
from common.logs.logwriter import LogWriter
from pymilvus import MilvusException
from common.metrics.prometheus_metrics import metrics

logger = logging.getLogger(__name__)

Expand All @@ -33,6 +37,7 @@ def __init__(
alias: str = "alias",
retry_interval: int = 2,
max_retry_attempts: int = 10,
drop_old=False,
):
self.embedding_service = embedding_service
self.vector_field = vector_field
Expand All @@ -43,6 +48,7 @@ def __init__(
self.milvus_alias = alias
self.retry_interval = retry_interval
self.max_retry_attempts = max_retry_attempts
self.drop_old = drop_old

if host.startswith("http"):
if host.endswith(str(port)):
Expand Down Expand Up @@ -77,7 +83,7 @@ def connect_to_milvus(self):
while retry_attempt < self.max_retry_attempts:
try:
connections.connect(**self.milvus_connection)
metrics.milvus_active_connections.labels(self.collection_name).inc
# metrics.milvus_active_connections.labels(self.collection_name).inc
LogWriter.info(
f"""Initializing Milvus with host={self.milvus_connection.get("host", self.milvus_connection.get("uri", "unknown host"))},
port={self.milvus_connection.get('port', 'unknown')}, username={self.milvus_connection.get('user', 'unknown')}, collection={self.collection_name}"""
Expand All @@ -88,7 +94,7 @@ def connect_to_milvus(self):
collection_name=self.collection_name,
connection_args=self.milvus_connection,
auto_id=True,
drop_old=False,
drop_old=self.drop_old,
text_field=self.text_field,
vector_field=self.vector_field,
)
Expand Down Expand Up @@ -120,6 +126,9 @@ def metadata_func(record: dict, metadata: dict) -> dict:
return metadata

LogWriter.info("Milvus add initial load documents init()")
import os

logger.info(f"*******{os.path.exists('tg_documents')}")
loader = DirectoryLoader(
"./common/tg_documents/",
glob="*.json",
Expand Down Expand Up @@ -216,6 +225,76 @@ def add_embeddings(
error_message = f"An error occurred while registering document: {str(e)}"
LogWriter.error(error_message)

async def aadd_embeddings(
    self,
    embeddings: Iterable[Tuple[str, List[float]]],
    metadatas: List[dict] = None,
):
    """Asynchronously add embeddings to the Milvus embedding store.

    Args:
        embeddings (Iterable[Tuple[str, List[float]]]):
            Iterable of (content, embedding) pairs, one per document.
        metadatas (List[dict]):
            Metadata dict per document, indexed in parallel with
            ``embeddings``.

    Returns:
        str: A success message containing the first inserted primary key.

    Raises:
        Exception: If Milvus reports that nothing was inserted, or if the
            underlying insert fails.
    """
    # Extract texts before the try-block so the except handler can safely
    # reference them; previously a failure raised before `texts` was bound
    # caused a NameError that masked the original exception.
    texts = [text for text, _ in embeddings]
    if metadatas is None:
        metadatas = []

    try:
        # Add fields required by the Milvus schema when they are missing.
        if self.support_ai_instance:
            for metadata in metadatas:
                metadata.setdefault(self.vertex_field, "")
        else:
            for metadata in metadatas:
                metadata.setdefault("seq_num", 1)
                metadata.setdefault("source", "")

        LogWriter.info(
            f"request_id={req_id_cv.get()} Milvus ENTRY aadd_embeddings()"
        )

        added = await self.milvus.aadd_texts(texts=texts, metadatas=metadatas)

        LogWriter.info(
            f"request_id={req_id_cv.get()} Milvus EXIT aadd_embeddings()"
        )

        # Check if registration was successful
        if added:
            success_message = f"Document registered with id: {added[0]}"
            LogWriter.info(success_message)
            return success_message

        error_message = f"Failed to register document {added}"
        LogWriter.error(error_message)
        raise Exception(error_message)

    except Exception as e:
        error_message = f"An error occurred while registering document:{metadatas} ({len(texts)},{len(metadatas)})\nErr: {str(e)}"
        LogWriter.error(error_message)
        LogWriter.error(traceback.format_exc())
        LogWriter.error(f"{texts}")
        raise

def get_pks(
self,
expr: str,
Expand Down Expand Up @@ -509,14 +588,65 @@ def query(self, expr: str, output_fields: List[str]):
return None

try:
query_result = self.milvus.col.query(
expr=expr, output_fields=output_fields
)
query_result = self.milvus.col.query(expr=expr, output_fields=output_fields)
except MilvusException as exc:
LogWriter.error(f"Failed to get outputs: {self.milvus.collection_name} error: {exc}")
LogWriter.error(
f"Failed to get outputs: {self.milvus.collection_name} error: {exc}"
)
raise exc

return query_result

def edit_dist_check(self, a: str, b: str, edit_dist_threshold: float):
    """Return True when two names are lexically close enough to merge.

    Comparison is case-insensitive. Very short strings must match
    exactly; longer strings must be within a Levenshtein distance scaled
    by ``edit_dist_threshold`` (the fraction of the shorter word that
    must match).
    """
    left = a.lower()
    right = b.lower()

    # Short words offer no tolerance for edits — require exact equality.
    if len(left) < 5 and len(right) < 5:
        return left == right

    # Allow at most (1 - threshold) of the shorter word's length in edits.
    max_edits = int(min(len(left), len(right)) * (1 - edit_dist_threshold))
    return lev.distance(left, right) < max_edits

async def aget_k_closest(
    self, v_id: str, k=15, threshold_similarity=0.90, edit_dist_threshold_pct=0.75
) -> set[str]:
    """Find IDs of vertices semantically and lexically close to ``v_id``.

    Args:
        v_id (str): Vertex ID whose stored vectors seed the search.
        k (int): Number of nearest neighbors fetched per stored vector.
        threshold_similarity (float): Minimum similarity; converted to a
            maximum distance of ``1 - threshold_similarity``.
        edit_dist_threshold_pct (float): Fraction of the shorter name that
            must match lexically (see ``edit_dist_check``).

    Returns:
        set[str]: IDs of similar vertices, always including ``v_id``.
        (The previous annotation ``list[Document]`` was incorrect — the
        function builds and returns a set of vertex-id strings.)
    """
    threshold_dist = 1 - threshold_similarity

    # Wrap the blocking pymilvus calls so they don't block the event loop.
    query = asyncify(self.milvus.col.query)
    search = asyncify(self.milvus.similarity_search_with_score_by_vector)

    # Get all vectors stored under this vertex ID.
    verts = await query(
        f'{self.vertex_field} == "{v_id}"',
        output_fields=[self.vertex_field, self.vector_field],
    )
    # NOTE(review): the literals "document_vector"/"vertex_id" below look
    # like they should be self.vector_field/self.vertex_field — confirm
    # before changing, since all current callers may use the defaults.
    result = []
    for v in verts:
        # Get the k closest vertices for this vector.
        sim = await search(
            v["document_vector"],
            k=k,
        )
        # Filter candidates using the similarity threshold and edit distance.
        similar_verts = [
            doc.metadata["vertex_id"]
            for doc, dist in sim
            # check semantic similarity
            if dist < threshold_dist
            # check name similarity (won't merge Apple and Google if they're semantically similar)
            and self.edit_dist_check(
                doc.metadata["vertex_id"],
                v_id,
                edit_dist_threshold_pct,
            )
            # don't have to merge verts with the same id (they're the same)
            and doc.metadata["vertex_id"] != v_id
        ]
        result.extend(similar_verts)
    result.append(v_id)
    return set(result)

def __del__(self):
    # BUG(review): `.dec` is an attribute access, not a call — missing
    # parentheses make this a silent no-op. Note that the matching `.inc`
    # in connect_to_milvus() is commented out, so adding `()` here alone
    # would drive the gauge negative; fix both sides together.
    metrics.milvus_active_connections.labels(self.collection_name).dec
13 changes: 10 additions & 3 deletions common/extractors/BaseExtractor.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,13 @@
class BaseExtractor:
def __init__():
from abc import ABC, abstractmethod

from langchain_community.graphs.graph_document import GraphDocument


class BaseExtractor(ABC):
@abstractmethod
def extract(self, text:str):
pass

def extract(self, text):
@abstractmethod
async def aextract(self, text:str) -> list[GraphDocument]:
pass
Loading

0 comments on commit d80acfa

Please sign in to comment.