diff --git a/common/config.py b/common/config.py index 718d2c43..8812016c 100644 --- a/common/config.py +++ b/common/config.py @@ -16,15 +16,16 @@ AWSBedrock, AzureOpenAI, GoogleVertexAI, - OpenAI, Groq, - Ollama, HuggingFaceEndpoint, + LLM_Model, + Ollama, + OpenAI, IBMWatsonX ) +from common.logs.logwriter import LogWriter from common.session import SessionHandler from common.status import StatusManager -from common.logs.logwriter import LogWriter security = HTTPBasic() session_handler = SessionHandler() @@ -105,7 +106,7 @@ raise Exception("Embedding service not implemented") -def get_llm_service(llm_config): +def get_llm_service(llm_config) -> LLM_Model: if llm_config["completion_service"]["llm_service"].lower() == "openai": return OpenAI(llm_config["completion_service"]) elif llm_config["completion_service"]["llm_service"].lower() == "azure": @@ -191,7 +192,7 @@ def get_llm_service(llm_config): doc_processing_config = { "chunker": "semantic", "chunker_config": {"method": "percentile", "threshold": 0.95}, - "extractor": "llm", + "extractor": "graphrag", "extractor_config": {}, } elif DOC_PROCESSING_CONFIG.endswith(".json"): diff --git a/common/embeddings/embedding_services.py b/common/embeddings/embedding_services.py index 7195edf4..8020b97f 100644 --- a/common/embeddings/embedding_services.py +++ b/common/embeddings/embedding_services.py @@ -1,11 +1,13 @@ +import logging import os +import time from typing import List + from langchain.schema.embeddings import Embeddings -import logging -import time + from common.logs.log import req_id_cv -from common.metrics.prometheus_metrics import metrics from common.logs.logwriter import LogWriter +from common.metrics.prometheus_metrics import metrics logger = logging.getLogger(__name__) @@ -87,6 +89,33 @@ def embed_query(self, question: str) -> List[float]: duration ) + async def aembed_query(self, question: str) -> List[float]: + """Embed Query Async. + Embed a string. + + Args: + question (str): + A string to embed. 
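# Editor's note: a minimal, hedged sketch of how the new async embedding path might
# be called; `embedding_service` stands in for whatever EmbeddingModel subclass the
# config builds, and the question text is illustrative only.
import asyncio

async def embed_one(embedding_service, question: str) -> list[float]:
    # aembed_query awaits the underlying LangChain embeddings client, so several
    # embeddings can be computed concurrently during document processing.
    return await embedding_service.aembed_query(question)

# vector = asyncio.run(embed_one(embedding_service, "What is GraphRAG?"))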
+ """ + # start_time = time.time() + # metrics.llm_inprogress_requests.labels(self.model_name).inc() + + # try: + logger.debug_pii(f"aembed_query() embedding question={question}") + query_embedding = await self.embeddings.aembed_query(question) + # metrics.llm_success_response_total.labels(self.model_name).inc() + return query_embedding + # except Exception as e: + # # metrics.llm_query_error_total.labels(self.model_name).inc() + # raise e + # finally: + # metrics.llm_request_total.labels(self.model_name).inc() + # metrics.llm_inprogress_requests.labels(self.model_name).dec() + # duration = time.time() - start_time + # metrics.llm_request_duration_seconds.labels(self.model_name).observe( + # duration + # ) + class AzureOpenAI_Ada002(EmbeddingModel): """Azure OpenAI Ada-002 Embedding Model""" @@ -105,7 +134,8 @@ def __init__(self, config): super().__init__( config, model_name=config.get("model_name", "OpenAI gpt-4-0613") ) - from langchain_openai import OpenAIEmbeddings + # from langchain_openai import OpenAIEmbeddings + from langchain_community.embeddings.openai import OpenAIEmbeddings self.embeddings = OpenAIEmbeddings() @@ -124,8 +154,8 @@ class AWS_Bedrock_Embedding(EmbeddingModel): """AWS Bedrock Embedding Model""" def __init__(self, config): - from langchain_community.embeddings import BedrockEmbeddings import boto3 + from langchain_community.embeddings import BedrockEmbeddings super().__init__(config=config, model_name=config["embedding_model"]) diff --git a/common/embeddings/milvus_embedding_store.py b/common/embeddings/milvus_embedding_store.py index e9bbdfe1..de7812fd 100644 --- a/common/embeddings/milvus_embedding_store.py +++ b/common/embeddings/milvus_embedding_store.py @@ -1,18 +1,22 @@ import logging +import traceback from time import sleep, time from typing import Iterable, List, Optional, Tuple -from langchain_milvus.vectorstores import Milvus +import Levenshtein as lev +from asyncer import asyncify +from langchain_community.vectorstores import Milvus from langchain_core.documents.base import Document -from pymilvus import connections, utility +# from langchain_milvus.vectorstores import Milvus +from langchain_community.vectorstores.milvus import Milvus +from pymilvus import MilvusException, connections, utility from pymilvus.exceptions import MilvusException from common.embeddings.base_embedding_store import EmbeddingStore from common.embeddings.embedding_services import EmbeddingModel from common.logs.log import req_id_cv -from common.metrics.prometheus_metrics import metrics from common.logs.logwriter import LogWriter -from pymilvus import MilvusException +from common.metrics.prometheus_metrics import metrics logger = logging.getLogger(__name__) @@ -33,6 +37,7 @@ def __init__( alias: str = "alias", retry_interval: int = 2, max_retry_attempts: int = 10, + drop_old=False, ): self.embedding_service = embedding_service self.vector_field = vector_field @@ -43,6 +48,7 @@ def __init__( self.milvus_alias = alias self.retry_interval = retry_interval self.max_retry_attempts = max_retry_attempts + self.drop_old = drop_old if host.startswith("http"): if host.endswith(str(port)): @@ -77,7 +83,7 @@ def connect_to_milvus(self): while retry_attempt < self.max_retry_attempts: try: connections.connect(**self.milvus_connection) - metrics.milvus_active_connections.labels(self.collection_name).inc + # metrics.milvus_active_connections.labels(self.collection_name).inc LogWriter.info( f"""Initializing Milvus with host={self.milvus_connection.get("host", self.milvus_connection.get("uri", 
"unknown host"))}, port={self.milvus_connection.get('port', 'unknown')}, username={self.milvus_connection.get('user', 'unknown')}, collection={self.collection_name}""" @@ -88,7 +94,7 @@ def connect_to_milvus(self): collection_name=self.collection_name, connection_args=self.milvus_connection, auto_id=True, - drop_old=False, + drop_old=self.drop_old, text_field=self.text_field, vector_field=self.vector_field, ) @@ -120,6 +126,9 @@ def metadata_func(record: dict, metadata: dict) -> dict: return metadata LogWriter.info("Milvus add initial load documents init()") + import os + + logger.info(f"*******{os.path.exists('tg_documents')}") loader = DirectoryLoader( "./common/tg_documents/", glob="*.json", @@ -216,6 +225,76 @@ def add_embeddings( error_message = f"An error occurred while registering document: {str(e)}" LogWriter.error(error_message) + async def aadd_embeddings( + self, + embeddings: Iterable[Tuple[str, List[float]]], + metadatas: List[dict] = None, + ): + """Async Add Embeddings. + Add embeddings to the Embedding store. + Args: + embeddings (Iterable[Tuple[str, List[float]]]): + Iterable of content and embedding of the document. + metadatas (List[Dict]): + List of dictionaries containing the metadata for each document. + The embeddings and metadatas list need to have identical indexing. + """ + try: + if metadatas is None: + metadatas = [] + + # add fields required by Milvus if they do not exist + if self.support_ai_instance: + for metadata in metadatas: + if self.vertex_field not in metadata: + metadata[self.vertex_field] = "" + else: + for metadata in metadatas: + if "seq_num" not in metadata: + metadata["seq_num"] = 1 + if "source" not in metadata: + metadata["source"] = "" + + LogWriter.info( + f"request_id={req_id_cv.get()} Milvus ENTRY aadd_embeddings()" + ) + texts = [text for text, _ in embeddings] + + # operation_type = "add_texts" + # metrics.milvus_query_total.labels( + # self.collection_name, operation_type + # ).inc() + # start_time = time() + + added = await self.milvus.aadd_texts(texts=texts, metadatas=metadatas) + + # duration = time() - start_time + # metrics.milvus_query_duration_seconds.labels( + # self.collection_name, operation_type + # ).observe(duration) + + LogWriter.info( + f"request_id={req_id_cv.get()} Milvus EXIT aadd_embeddings()" + ) + + # Check if registration was successful + if added: + success_message = f"Document registered with id: {added[0]}" + LogWriter.info(success_message) + return success_message + else: + error_message = f"Failed to register document {added}" + LogWriter.error(error_message) + raise Exception(error_message) + + except Exception as e: + error_message = f"An error occurred while registering document:{metadatas} ({len(texts)},{len(metadatas)})\nErr: {str(e)}" + LogWriter.error(error_message) + exc = traceback.format_exc() + LogWriter.error(exc) + LogWriter.error(f"{texts}") + raise e + def get_pks( self, expr: str, @@ -509,14 +588,65 @@ def query(self, expr: str, output_fields: List[str]): return None try: - query_result = self.milvus.col.query( - expr=expr, output_fields=output_fields - ) + query_result = self.milvus.col.query(expr=expr, output_fields=output_fields) except MilvusException as exc: - LogWriter.error(f"Failed to get outputs: {self.milvus.collection_name} error: {exc}") + LogWriter.error( + f"Failed to get outputs: {self.milvus.collection_name} error: {exc}" + ) raise exc return query_result + def edit_dist_check(self, a: str, b: str, edit_dist_threshold: float): + a = a.lower() + b = b.lower() + # if the words are 
short, they should be the same + if len(a) < 5 and len(b) < 5: + return a == b + + # edit_dist_threshold (as a percent) of word must match + threshold = int(min(len(a), len(b)) * (1 - edit_dist_threshold)) + return lev.distance(a, b) < threshold + + async def aget_k_closest( + self, v_id: str, k=15, threshold_similarity=0.90, edit_dist_threshold_pct=0.75 + ) -> list[Document]: + threshold_dist = 1 - threshold_similarity + + # asyncify necessary funcs + query = asyncify(self.milvus.col.query) + search = asyncify(self.milvus.similarity_search_with_score_by_vector) + + # Get all vectors with this ID + verts = await query( + f'{self.vertex_field} == "{v_id}"', + output_fields=[self.vertex_field, self.vector_field], + ) + result = [] + for v in verts: + # get the k closest verts + sim = await search( + v["document_vector"], + k=k, + ) + # filter verts using similiarity threshold and leven_dist + similar_verts = [ + doc.metadata["vertex_id"] + for doc, dist in sim + # check semantic similarity + if dist < threshold_dist + # check name similarity (won't merge Apple and Google if they're semantically similar) + and self.edit_dist_check( + doc.metadata["vertex_id"], + v_id, + edit_dist_threshold_pct, + ) + # don't have to merge verts with the same id (they're the same) + and doc.metadata["vertex_id"] != v_id + ] + result.extend(similar_verts) + result.append(v_id) + return set(result) + def __del__(self): metrics.milvus_active_connections.labels(self.collection_name).dec diff --git a/common/extractors/BaseExtractor.py b/common/extractors/BaseExtractor.py index 3f1ec92b..e8638665 100644 --- a/common/extractors/BaseExtractor.py +++ b/common/extractors/BaseExtractor.py @@ -1,6 +1,13 @@ -class BaseExtractor: - def __init__(): +from abc import ABC, abstractmethod + +from langchain_community.graphs.graph_document import GraphDocument + + +class BaseExtractor(ABC): + @abstractmethod + def extract(self, text:str): pass - def extract(self, text): + @abstractmethod + async def aextract(self, text:str) -> list[GraphDocument]: pass diff --git a/common/extractors/GraphExtractor.py b/common/extractors/GraphExtractor.py new file mode 100644 index 00000000..2a7ba505 --- /dev/null +++ b/common/extractors/GraphExtractor.py @@ -0,0 +1,70 @@ +from langchain_community.graphs.graph_document import GraphDocument +from langchain_core.documents import Document +from langchain_experimental.graph_transformers import LLMGraphTransformer + +from common.config import get_llm_service, llm_config +from common.extractors.BaseExtractor import BaseExtractor + + +class GraphExtractor(BaseExtractor): + def __init__(self): + llm = get_llm_service(llm_config).llm + self.transformer = LLMGraphTransformer( + llm=llm, + node_properties=["description"], + relationship_properties=["description"], + ) + + def extract(self, text) -> list[GraphDocument]: + """ + returns a list of GraphDocument: + Each doc is: + nodes=[ + Node( + id='Marie Curie', + type='Person', + properties={ + 'description': 'A Polish and naturalised-French physicist and chemist who conducted pioneering research on radioactivity.' + } + ), + ... + ], + relationships=[ + Relationship( + source=Node(id='Marie Curie', type='Person'), + target=Node(id='Pierre Curie', type='Person'), + type='SPOUSE' + ), + ... 
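# Editor's note: a self-contained check of the edit_dist_check() logic above, using the
# Levenshtein package pinned in common/requirements.txt; the example strings are made up.
# With edit_dist_threshold=0.75, roughly 75% of the shorter name must match before two
# vertex ids are considered the same entity, and very short names must match exactly.
import Levenshtein as lev

def edit_dist_check(a: str, b: str, edit_dist_threshold: float) -> bool:
    a, b = a.lower(), b.lower()
    if len(a) < 5 and len(b) < 5:  # short words must be identical
        return a == b
    threshold = int(min(len(a), len(b)) * (1 - edit_dist_threshold))
    return lev.distance(a, b) < threshold

print(edit_dist_check("Apple Inc.", "Apple Inc", 0.75))   # True  -> candidates to merge
print(edit_dist_check("Apple Inc.", "Google LLC", 0.75))  # False -> kept separate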
+ ] + """ + doc = Document(page_content=text) + graph_docs = self.transformer.convert_to_graph_documents([doc]) + return graph_docs + + async def aextract(self, text:str) -> list[GraphDocument]: + """ + returns a list of GraphDocument: + Each doc is: + nodes=[ + Node( + id='Marie Curie', + type='Person', + properties={ + 'description': 'A Polish and naturalised-French physicist and chemist who conducted pioneering research on radioactivity.' + } + ), + ... + ], + relationships=[ + Relationship( + source=Node(id='Marie Curie', type='Person'), + target=Node(id='Pierre Curie', type='Person'), + type='SPOUSE' + ), + ... + ] + """ + doc = Document(page_content=text) + graph_docs = await self.transformer.aconvert_to_graph_documents([doc]) + return graph_docs diff --git a/common/extractors/LLMEntityRelationshipExtractor.py b/common/extractors/LLMEntityRelationshipExtractor.py index d5a0a970..415c3235 100644 --- a/common/extractors/LLMEntityRelationshipExtractor.py +++ b/common/extractors/LLMEntityRelationshipExtractor.py @@ -1,8 +1,9 @@ -from common.llm_services import LLM_Model +import json +from typing import List + from common.extractors.BaseExtractor import BaseExtractor +from common.llm_services import LLM_Model from common.py_schemas import KnowledgeGraph -from typing import List -import json class LLMEntityRelationshipExtractor(BaseExtractor): @@ -19,6 +20,34 @@ def __init__( self.strict_mode = strict_mode def _extract_kg_from_doc(self, doc, chain, parser): + """ + returns: + { + "nodes": [ + { + "id": "str", + "type": "string", + "definition": "string" + } + ], + "rels": [ + { + "source":{ + "id": "str", + "type": "string", + "definition": "string" + } + "target":{ + "id": "str", + "type": "string", + "definition": "string" + } + "definition" + } + ] + } + """ + try: out = chain.invoke( {"input": doc, "format_instructions": parser.get_format_instructions()} diff --git a/common/extractors/__init__.py b/common/extractors/__init__.py index ced539e4..e2f0bcdf 100644 --- a/common/extractors/__init__.py +++ b/common/extractors/__init__.py @@ -1,3 +1,4 @@ +from common.extractors.GraphExtractor import GraphExtractor from common.extractors.LLMEntityRelationshipExtractor import ( LLMEntityRelationshipExtractor, ) diff --git a/common/gsql/graphRAG/ResolveRelationships.gsql b/common/gsql/graphRAG/ResolveRelationships.gsql new file mode 100644 index 00000000..6a0e515d --- /dev/null +++ b/common/gsql/graphRAG/ResolveRelationships.gsql @@ -0,0 +1,26 @@ +CREATE DISTRIBUTED QUERY ResolveRelationships(BOOL printResults=FALSE) SYNTAX V2 { + /* + * RE1 <- entity -RELATES-> entity -> RE2 + * to + * RE1 -resolved-> RE + * + * Combines all of a Resolved entity's children's relationships into + * RESOLVED_RELATIONSHIP + */ + REs = {ResolvedEntity.*}; + + + REs = SELECT re1 FROM REs:re1 -(:rel)- Entity:e_tgt -(RESOLVES_TO>:r)- ResolvedEntity:re2 + // Connect the The first RE to the second RE + ACCUM + INSERT INTO RESOLVED_RELATIONSHIP(FROM,TO, relation_type) VALUES(re1, re2, rel.relation_type); + + + IF printResults THEN + // show which entities didn't get resolved + Ents = {Entity.*}; + rEnts = SELECT e FROM Ents:e -(RESOLVES_TO>)- _; + ents = Ents minus rEnts; + PRINT ents; + END; +} diff --git a/common/gsql/graphRAG/SetEpochProcessing.gsql b/common/gsql/graphRAG/SetEpochProcessing.gsql new file mode 100644 index 00000000..9a92ecf9 --- /dev/null +++ b/common/gsql/graphRAG/SetEpochProcessing.gsql @@ -0,0 +1,7 @@ +CREATE DISTRIBUTED QUERY SetEpochProcessing(Vertex v_id) { + Verts = {v_id}; + + // mark the vertex 
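# Editor's note: a hedged sketch of driving the new GraphExtractor end to end. It assumes
# common.config can construct the configured LLM service (as GraphExtractor.__init__ does)
# and the sample sentence is illustrative; the node/relationship fields come from the
# GraphDocument shape documented in the extract() docstring above.
from common.extractors import GraphExtractor

extractor = GraphExtractor()
for gdoc in extractor.extract("Marie Curie was married to Pierre Curie."):
    for node in gdoc.nodes:
        print(node.id, node.type, node.properties.get("description"))
    for rel in gdoc.relationships:
        print(rel.source.id, rel.type, rel.target.id)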
as processed + Verts = SELECT v FROM Verts:v + POST-ACCUM v.epoch_processed = datetime_to_epoch(now()); +} diff --git a/common/gsql/graphRAG/StreamDocContent.gsql b/common/gsql/graphRAG/StreamDocContent.gsql new file mode 100644 index 00000000..a2845148 --- /dev/null +++ b/common/gsql/graphRAG/StreamDocContent.gsql @@ -0,0 +1,8 @@ +CREATE DISTRIBUTED QUERY StreamDocContent(Vertex doc) { + Doc = {doc}; + + // Get the document's content and mark it as processed + DocContent = SELECT c FROM Doc:d -(HAS_CONTENT)-> Content:c + POST-ACCUM d.epoch_processed = datetime_to_epoch(now()); + PRINT DocContent; +} diff --git a/common/gsql/graphRAG/StreamDocIds.gsql b/common/gsql/graphRAG/StreamDocIds.gsql new file mode 100644 index 00000000..2fb4a9c4 --- /dev/null +++ b/common/gsql/graphRAG/StreamDocIds.gsql @@ -0,0 +1,16 @@ +CREATE DISTRIBUTED QUERY StreamDocIds(INT current_batch, INT ttl_batches) { + /* + * Get the IDs of documents that have not already been processed (one + * batch at a time) + */ + ListAccum @@doc_ids; + Docs = {Document.*}; + + Docs = SELECT d FROM Docs:d + WHERE vertex_to_int(d) % ttl_batches == current_batch + AND d.epoch_processed == 0 + ACCUM @@doc_ids += d.id + POST-ACCUM d.epoch_processing = datetime_to_epoch(now()); // set the processing time + + PRINT @@doc_ids; +} diff --git a/common/gsql/graphRAG/StreamIds.gsql b/common/gsql/graphRAG/StreamIds.gsql new file mode 100644 index 00000000..41181007 --- /dev/null +++ b/common/gsql/graphRAG/StreamIds.gsql @@ -0,0 +1,16 @@ +CREATE DISTRIBUTED QUERY StreamIds(INT current_batch, INT ttl_batches, STRING v_type) { + /* + * Get the IDs of entities that have not already been processed + * (one batch at a time) + */ + ListAccum @@ids; + Verts = {v_type}; + + Verts = SELECT v FROM Verts:v + WHERE vertex_to_int(v) % ttl_batches == current_batch + AND v.epoch_processed == 0 + ACCUM @@ids += v.id + POST-ACCUM v.epoch_processing = datetime_to_epoch(now()); // set the processing time + + PRINT @@ids; +} diff --git a/common/gsql/graphRAG/communities_have_desc.gsql b/common/gsql/graphRAG/communities_have_desc.gsql new file mode 100644 index 00000000..f5cda70e --- /dev/null +++ b/common/gsql/graphRAG/communities_have_desc.gsql @@ -0,0 +1,14 @@ +CREATE DISTRIBUTED QUERY communities_have_desc(UINT iter) SYNTAX V2{ + SumAccum @@descrs; + Comms = {Community.*}; + Comms = SELECT c FROM Comms:c + WHERE c.iteration == iter + ACCUM + IF length(c.description) > 0 THEN + @@descrs += 1 + END; + + + PRINT (@@descrs == Comms.size()) as all_have_desc; + PRINT @@descrs, Comms.size(); +} diff --git a/common/gsql/graphRAG/get_community_children.gsql b/common/gsql/graphRAG/get_community_children.gsql new file mode 100644 index 00000000..7913e1b7 --- /dev/null +++ b/common/gsql/graphRAG/get_community_children.gsql @@ -0,0 +1,12 @@ +CREATE DISTRIBUTED QUERY get_community_children(Vertex comm, UINT iter) SYNTAX V2{ + Comms = {comm}; + + IF iter > 1 THEN + Comms = SELECT t FROM Comms:c -()- ResolvedEntity -(_>)- Entity:t; + + PRINT Ents[Ents.description as description] as children; + END; +} diff --git a/common/gsql/graphRAG/louvain/graphrag_louvain_communities.gsql b/common/gsql/graphRAG/louvain/graphrag_louvain_communities.gsql new file mode 100644 index 00000000..241ccaf0 --- /dev/null +++ b/common/gsql/graphRAG/louvain/graphrag_louvain_communities.gsql @@ -0,0 +1,198 @@ +CREATE DISTRIBUTED QUERY graphrag_louvain_communities(UINT iteration=1, UINT max_hop = 10, UINT n_batches = 1) SYNTAX V2{ + /* + * This is the same query as tg_louvain, just that Paper-related 
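# Editor's note: a speculative consumer loop for StreamDocIds/StreamIds, assuming the
# queries are installed and `conn` is a pyTigerGraph connection; the batch count of 4
# and the placeholder processing step are illustrative only.
ttl_batches = 4
for current_batch in range(ttl_batches):
    res = conn.runInstalledQuery(
        "StreamDocIds",
        params={"current_batch": current_batch, "ttl_batches": ttl_batches},
    )
    for doc_id in res[0]["@@doc_ids"]:  # PRINT @@doc_ids surfaces the accumulator
        ...  # fetch text via StreamDocContent and hand it to the extractor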
schema + * are changed to Community-related schema + * + * For the first call to this query, iteration = 1 + */ + TYPEDEF TUPLE community, STRING ext_vid> Move; + SumAccum @@m; // the sum of the weights of all the links in the network + MinAccum> @community_id; // the community ID of the node + MinAccum @community_vid; // the community ID of the node + SumAccum @k; // the sum of the weights of the links incident to the node + SumAccum @k_in; // the sum of the weights of the links inside the previous community of the node + SumAccum @k_self_loop; // the weight of the self-loop link + MapAccum, SumAccum> @community_k_in_map; // the community of the neighbors of the nodes -> the sum of the weights of the links inside the community + MapAccum, SumAccum> @@community_sum_total_map; // community ID C -> the sum of the weights of the links incident to nodes in C + SumAccum @community_sum_total; // the sum of the weights of the links incident to nodes in the community of the node + MapAccum, SumAccum> @@community_sum_in_map; // community ID -> the sum of the weights of the links inside the community + MapAccum>> @@source_target_k_in_map; // source community ID -> (target community ID -> the sum of the weights of the links from the source community to the target community) + SumAccum @delta_Q_remove; // delta Q to remove the node from the previous community + MaxAccum @best_move; // best move of the node with the highest delta Q to move the isolated node into the new community + MaxAccum @@min_double; // used to reset the @best_move + SumAccum @@move_cnt; + OrAccum @to_change_community, @is_current_iter, @has_parent; + SumAccum @batch_id; + MinAccum @vid; + + AllNodes = {Community.*}; + + // Get communities of the current iteration + AllNodes = SELECT s FROM AllNodes:s + WHERE s.iteration == iteration + ACCUM s.@is_current_iter += TRUE; + + // init + z = SELECT s FROM AllNodes:s -(_>:e)- Community:t + WHERE s.@is_current_iter AND t.@is_current_iter + ACCUM s.@k += e.weight, + @@m += e.weight/2, + IF s == t THEN // self loop + s.@k_self_loop += e.weight + END + POST-ACCUM + s.@community_id = s, // assign node to its own community + s.@community_vid = to_string(s.id), // external id + s.@vid = getvid(s), // internal id (used in batching) + s.@batch_id = s.@vid % n_batches; // get batch number + + IF @@m < 0.00000000001 THEN + PRINT "Warning: the sum of the weights in the edges should be greater than zero!"; + RETURN; + END; + + // Local moving + INT hop = 0; + Candidates = AllNodes; + WHILE Candidates.size() > 0 AND hop < max_hop DO + hop += 1; + IF hop == 1 THEN // first iteration + ChangedNodes = SELECT s FROM Candidates:s -(_>:e)- Community:t + WHERE s.@community_id != t.@community_id // can't move within the same community + AND s.@is_current_iter AND t.@is_current_iter // only use Communities in the current iteration + ACCUM + DOUBLE dq = 1 - s.@k * t.@k / (2 * @@m), + s.@best_move += Move(dq, t.@community_id, t.@community_vid) // find the best move + POST-ACCUM + IF s.@best_move.delta_q > 0 THEN // if the move increases dq + s.@to_change_community += TRUE + END + HAVING s.@to_change_community == TRUE; // only select nodes that will move + ELSE // other iterations + // Calculate sum_total of links in each community + Tmp = SELECT s FROM AllNodes:s + POST-ACCUM + @@community_sum_total_map += (s.@community_id -> s.@k); + // store community's total edges in each vert (easier access) + Tmp = SELECT s FROM AllNodes:s + POST-ACCUM + s.@community_sum_total = 
@@community_sum_total_map.get(s.@community_id); + @@community_sum_total_map.clear(); + + // find the best move + ChangedNodes = {}; + + // process nodes in batch + FOREACH batch_id IN RANGE[0, n_batches-1] DO + Nodes = SELECT s FROM Candidates:s -(_>:e)- Community:t + WHERE s.@batch_id == batch_id + AND s.@is_current_iter AND t.@is_current_iter // only use Communities in the current iteration + ACCUM + IF s.@community_id == t.@community_id THEN + // add edge weights connected to s + s.@k_in += e.weight + ELSE + // add edge weights connecetd to t + s.@community_k_in_map += (t.@community_id -> e.weight) + END + POST-ACCUM + // ∆Q if s is moved out of its current community + s.@delta_Q_remove = 2 * s.@k_self_loop - 2 * s.@k_in + s.@k * (s.@community_sum_total - s.@k) / @@m, + s.@k_in = 0, + s.@best_move = Move(@@min_double, s, to_string(s.id)); // reset best move + + // find the best move + Nodes = SELECT s FROM Nodes:s -(_>:E)- Community:t + WHERE s.@community_id != t.@community_id + AND s.@is_current_iter AND t.@is_current_iter // only use Communities in the current iteration + ACCUM + DOUBLE dq = 2 * s.@community_k_in_map.get(t.@community_id) - s.@k * t.@community_sum_total / @@m, + s.@best_move += Move(dq, t.@community_id, t.@community_vid) // find the best move + POST-ACCUM + IF s.@delta_Q_remove + s.@best_move.delta_q > 0 THEN // if the move increases dq + s.@to_change_community = TRUE// s should move + END, + s.@community_k_in_map.clear() + HAVING s.@to_change_community == TRUE; // only select nodes that will move + + // Add nodes that will move to ChangedNodes + ChangedNodes = ChangedNodes UNION Nodes; + END; + END; + // If two nodes swap, only change the community of one of them + SwapNodes = SELECT s FROM ChangedNodes:s -(_>:e)- Community:t + WHERE s.@best_move.community == t.@community_id + AND s.@is_current_iter AND t.@is_current_iter // only use Communities in the current iteration + AND t.@to_change_community + AND t.@best_move.community == s.@community_id + // if delta Q are the same, only change the one with larger delta Q or the one with smaller @vid + AND ( + s.@delta_Q_remove + s.@best_move.delta_q < t.@delta_Q_remove + t.@best_move.delta_q + OR ( + abs( + (s.@delta_Q_remove + s.@best_move.delta_q) + - (t.@delta_Q_remove + t.@best_move.delta_q) + ) < 0.00000000001 + AND s.@vid > t.@vid + ) + ) + POST-ACCUM + s.@to_change_community = FALSE; + + // remove SwapNodes (don't need to be changed) + ChangedNodes = ChangedNodes MINUS SwapNodes; + + // Update node communities (based on max ∆Q) + SwapNodes = SELECT s FROM ChangedNodes:s + POST-ACCUM + s.@community_id = s.@best_move.community, // move the node + s.@community_vid = s.@best_move.ext_vid, // move the node (external v_id update) + s.@to_change_community = FALSE; + @@move_cnt += ChangedNodes.size(); + + // Get all neighbours of the changed node that do not belong to the node’s new community + Candidates = SELECT t FROM ChangedNodes:s -(_>:e)- Community:t + WHERE t.@community_id != s.@community_id + AND s.@is_current_iter AND t.@is_current_iter; // only use Communities in the current iteration + END; + + // Coarsening + @@community_sum_total_map.clear(); + Tmp = SELECT s FROM AllNodes:s -(_>:e)- Community:t + WHERE s.@is_current_iter AND t.@is_current_iter // only use Communities in the current iteration + ACCUM + IF s.@community_id == t.@community_id THEN + // keep track of how many edges are within the community + @@community_sum_in_map += (s.@community_id -> e.weight) + ELSE + // get LINKS_TO edge weights (how many edges 
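# Editor's note: the move test used in the local-moving phase above, transcribed from the
# ACCUM/POST-ACCUM arithmetic (per-node accumulator values become plain parameters here);
# this mirrors the query's own gain computation rather than a textbook formulation.
def should_move(k, k_in_old, k_self_loop, sum_total_old, k_in_new, sum_total_new, m):
    delta_q_remove = 2 * k_self_loop - 2 * k_in_old + k * (sum_total_old - k) / m
    delta_q_add = 2 * k_in_new - k * sum_total_new / m
    return delta_q_remove + delta_q_add > 0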
are between communities) + @@source_target_k_in_map += (s.@community_vid -> (t.@community_vid -> e.weight)) + END, + t.@has_parent += TRUE // Used to help find unattached partitions + POST-ACCUM + // Write the results to a new community vertex (iteration + 1) + // ID , iter, edges within the community + INSERT INTO Community VALUES (s.id+"_"+to_string(iteration+1), iteration+1, ""), + INSERT INTO HAS_PARENT VALUES (s, s.@community_vid+"_"+to_string(iteration+1)) // link Community's child/parent community + ; + + // Continue community hierarchy for unattached partitions + Tmp = SELECT s FROM AllNodes:s + WHERE s.@is_current_iter + AND NOT s.@has_parent + POST-ACCUM + // if s is a part of an unattached partition, add to its community hierarchy to maintain parity with rest of graph + INSERT INTO Community VALUES (s.id+"_"+to_string(iteration+1), iteration+1, ""), + INSERT INTO HAS_PARENT VALUES (s, s.id+"_"+to_string(iteration+1)) // link Community's child/parent community + ; + + // link communities + // "If two communities have an edge between them, their parents should also have an edge bewtween them" + Tmp = SELECT s FROM AllNodes:s -(_>:e)- Community:t + WHERE s.@community_vid != t.@community_vid + AND s.@is_current_iter AND t.@is_current_iter // only use Communities in the current iteration + ACCUM + DOUBLE w = @@source_target_k_in_map.get(s.@community_vid).get(t.@community_vid)/2, + INSERT INTO LINKS_TO VALUES (s.@community_vid+"_"+to_string(iteration+1), t.@community_vid+"_"+to_string(iteration+1), w) + ; +} diff --git a/common/gsql/graphRAG/louvain/graphrag_louvain_init.gsql b/common/gsql/graphRAG/louvain/graphrag_louvain_init.gsql new file mode 100644 index 00000000..42e9108d --- /dev/null +++ b/common/gsql/graphRAG/louvain/graphrag_louvain_init.gsql @@ -0,0 +1,180 @@ +CREATE DISTRIBUTED QUERY graphrag_louvain_init(UINT max_hop = 10, UINT n_batches = 1) { + /* + * Initialize GraphRAG's hierarchical communities. 
+ */ + TYPEDEF TUPLE community, STRING ext_vid> Move; + SumAccum @@m; // the sum of the weights of all the links in the network + MinAccum> @community_id; // the community ID of the node + MinAccum @community_vid; // the community ID of the node + SumAccum @k; // the sum of the weights of the links incident to the node + SumAccum @k_in; // the sum of the weights of the links inside the previous community of the node + SumAccum @k_self_loop; // the weight of the self-loop link + MapAccum, SumAccum> @community_k_in_map; // the community of the neighbors of the nodes -> the sum of the weights of the links inside the community + MapAccum, SumAccum> @@community_sum_total_map; // community ID C -> the sum of the weights of the links incident to nodes in C + SumAccum @community_sum_total; // the sum of the weights of the links incident to nodes in the community of the node + MapAccum, SumAccum> @@community_sum_in_map; // community ID -> the sum of the weights of the links inside the community + MapAccum>> @@source_target_k_in_map; // source community ID -> (target community ID -> the sum of the weights of the links from the source community to the target community) + SumAccum @delta_Q_remove; // delta Q to remove the node from the previous community + MaxAccum @best_move; // best move of the node with the highest delta Q to move the isolated node into the new community + MaxAccum @@min_double; // used to reset the @best_move + SumAccum @@move_cnt; + OrAccum @to_change_community; + SumAccum @batch_id; + MinAccum @vid; + + AllNodes = {ResolvedEntity.*}; + DOUBLE wt = 1.0; + + // prevent multiple init runs + z = SELECT s FROM AllNodes:s -(_)-> Community:t; + IF z.size() > 0 THEN + EXCEPTION reinit(400001); + RAISE reinit("ERROR: the hierarchical communities have already been initialized"); + END; + + // init + z = SELECT s FROM AllNodes:s + ACCUM + s.@community_id = s, // assign node to its own community + s.@community_vid = s.id, // external id + s.@vid = getvid(s), // internal id (used in batching) + s.@batch_id = s.@vid % n_batches; // get batch number + z = SELECT s FROM AllNodes:s -(_)-> ResolvedEntity:t + ACCUM s.@k += wt, + @@m += 1; + + PRINT z.size(); + PRINT z; + + // Local moving + INT hop = 0; + Candidates = AllNodes; + WHILE Candidates.size() > 0 AND hop < max_hop DO + hop += 1; + IF hop == 1 THEN // first iteration + ChangedNodes = SELECT s FROM Candidates:s -(_:e)-> ResolvedEntity:t + WHERE s.@community_id != t.@community_id // can't move within the same community + ACCUM + DOUBLE dq = 1 - s.@k * t.@k / (2 * @@m), + s.@best_move += Move(dq, t.@community_id, t.@community_vid) // find the best move + POST-ACCUM + IF s.@best_move.delta_q > 0 THEN // if the move increases dq + s.@to_change_community += TRUE + END + HAVING s.@to_change_community == TRUE; // only select nodes that will move + PRINT ChangedNodes.size(); + ELSE // other iterations + // Calculate sum_total of links in each community + Tmp = SELECT s FROM AllNodes:s + POST-ACCUM + @@community_sum_total_map += (s.@community_id -> s.@k); + // store community's total edges in each vert (easier access) + Tmp = SELECT s FROM AllNodes:s + POST-ACCUM + s.@community_sum_total = @@community_sum_total_map.get(s.@community_id); + @@community_sum_total_map.clear(); + + // find the best move + ChangedNodes = {}; + + // process nodes in batch + FOREACH batch_id IN RANGE[0, n_batches-1] DO + Nodes = SELECT s FROM Candidates:s -(_:e)-> ResolvedEntity:t + WHERE s.@batch_id == batch_id + ACCUM + IF s.@community_id == t.@community_id THEN + // 
add edge weights connected to s + s.@k_in += wt + ELSE + // add edge weights connecetd to t + s.@community_k_in_map += (t.@community_id -> wt) + END + POST-ACCUM + // ∆Q if s is moved out of its current community + s.@delta_Q_remove = 2 * s.@k_self_loop - 2 * s.@k_in + s.@k * (s.@community_sum_total - s.@k) / @@m, + s.@k_in = 0, + s.@best_move = Move(@@min_double, s, to_string(s.id)); // reset best move + + // find the best move + Nodes = SELECT s FROM Nodes:s -(_:e)-> ResolvedEntity:t + WHERE s.@community_id != t.@community_id + ACCUM + DOUBLE dq = 2 * s.@community_k_in_map.get(t.@community_id) - s.@k * t.@community_sum_total / @@m, + s.@best_move += Move(dq, t.@community_id, t.@community_vid) // find the best move + POST-ACCUM + IF s.@delta_Q_remove + s.@best_move.delta_q > 0 THEN // if the move increases dq + s.@to_change_community = TRUE// s should move + END, + s.@community_k_in_map.clear() + HAVING s.@to_change_community == TRUE; // only select nodes that will move + + // Add nodes that will move to ChangedNodes + ChangedNodes = ChangedNodes UNION Nodes; + END; + END; + // If two nodes swap, only change the community of one of them + SwapNodes = SELECT s FROM ChangedNodes:s -(_:e)-> ResolvedEntity:t + WHERE s.@best_move.community == t.@community_id + AND t.@to_change_community + AND t.@best_move.community == s.@community_id + // if delta Q are the same, only change the one with larger delta Q or the one with smaller @vid + AND ( + s.@delta_Q_remove + s.@best_move.delta_q < t.@delta_Q_remove + t.@best_move.delta_q + OR ( + abs( + (s.@delta_Q_remove + s.@best_move.delta_q) + - (t.@delta_Q_remove + t.@best_move.delta_q) + ) < 0.00000000001 + AND s.@vid > t.@vid + ) + ) + POST-ACCUM + s.@to_change_community = FALSE; + + // remove SwapNodes (don't need to be changed) + ChangedNodes = ChangedNodes MINUS SwapNodes; + + // Update node communities (based on max ∆Q) + SwapNodes = SELECT s FROM ChangedNodes:s + POST-ACCUM + s.@community_id = s.@best_move.community, // move the node + s.@community_vid = s.@best_move.ext_vid, // move the node (external v_id update) + s.@to_change_community = FALSE; + @@move_cnt += ChangedNodes.size(); + + // Get all neighbours of the changed node that do not belong to the node’s new community + Candidates = SELECT t FROM ChangedNodes:s -(_:e)-> ResolvedEntity:t + WHERE t.@community_id != s.@community_id; + END; + + // Coarsening + UINT new_layer = 0; + @@community_sum_total_map.clear(); + Tmp = SELECT s FROM AllNodes:s -(_:e)-> ResolvedEntity:t + ACCUM + IF s.@community_id == t.@community_id THEN + // keep track of how many edges are within the community + @@community_sum_in_map += (s.@community_id -> wt) + ELSE + // get LINKS_TO edge weights (how many edges are between communities) + @@source_target_k_in_map += (s.@community_vid -> (t.@community_vid -> 1)) + END + POST-ACCUM + // ID , iter, edges within the community + INSERT INTO Community VALUES (s.@community_vid+"_1", 1, ""), + INSERT INTO IN_COMMUNITY VALUES (s, s.@community_vid+"_1") // link entity to it's first community + ; + + PRINT @@source_target_k_in_map; + + @@community_sum_total_map.clear(); + // link communities + Tmp = SELECT s FROM AllNodes:s -(_:e)-> ResolvedEntity:t + WHERE s.@community_vid != t.@community_vid + ACCUM + DOUBLE w = @@source_target_k_in_map.get(s.@community_vid).get(t.@community_vid), + INSERT INTO LINKS_TO VALUES (s.@community_vid+"_1", t.@community_vid+"_1", w); + + + PRINT @@source_target_k_in_map; +} diff --git a/common/gsql/graphRAG/louvain/modularity.gsql 
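# Editor's note: a speculative driver loop for the hierarchical Louvain queries above;
# this diff does not show how they are scheduled, so the level count, batch count, and
# the pyTigerGraph connection `conn` are assumptions.
max_levels = 5
conn.runInstalledQuery("graphrag_louvain_init", params={"n_batches": 1})
for iteration in range(1, max_levels):
    # collapse the communities of this level into the next level of the hierarchy
    conn.runInstalledQuery(
        "graphrag_louvain_communities",
        params={"iteration": iteration, "n_batches": 1},
    )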
b/common/gsql/graphRAG/louvain/modularity.gsql new file mode 100644 index 00000000..3aaad826 --- /dev/null +++ b/common/gsql/graphRAG/louvain/modularity.gsql @@ -0,0 +1,49 @@ +CREATE DISTRIBUTED QUERY modularity(UINT iteration=1) SYNTAX V2 { + SumAccum @@sum_weight; // the sum of the weights of all the links in the network + MinAccum @community_id; // the community ID of the node + MapAccum> @@community_total_weight_map; // community ID C -> the sum of the weights of the links incident to nodes in C + MapAccum> @@community_in_weight_map; // community ID -> the sum of the weights of the links inside the community + SumAccum @@modularity; + MinAccum @parent; + DOUBLE wt = 1.0; + Comms = {Community.*}; + + // Assign Entities to their correct community (given the specified iteration level) + IF iteration > 1 THEN + Comms = SELECT t FROM Comms:c -()- ResolvedEntity:t + ACCUM t.@community_id = c.@parent; + + ELSE + Entities = SELECT t FROM Comms:c -(_>)- ResolvedEntity:t + WHERE c.iteration == iteration + ACCUM t.@community_id = c.id; + END; + + Nodes = SELECT s FROM Entities:s -(_>:e)- ResolvedEntity:t + ACCUM + IF s.@community_id == t.@community_id THEN + @@community_in_weight_map += (s.@community_id -> wt) + END, + @@community_total_weight_map += (s.@community_id -> wt), + @@sum_weight += wt; + + @@modularity = 0; + FOREACH (community, total_weight) IN @@community_total_weight_map DO + DOUBLE in_weight = 0; + IF @@community_in_weight_map.containsKey(community) THEN + in_weight = @@community_in_weight_map.get(community); + END; + @@modularity += in_weight / @@sum_weight - pow(total_weight / @@sum_weight, 2); + END; + + PRINT @@modularity as mod; +} diff --git a/common/gsql/graphRAG/louvain/stream_community.gsql b/common/gsql/graphRAG/louvain/stream_community.gsql new file mode 100644 index 00000000..d01959d2 --- /dev/null +++ b/common/gsql/graphRAG/louvain/stream_community.gsql @@ -0,0 +1,9 @@ +CREATE DISTRIBUTED QUERY stream_community(UINT iter) { + Comms = {Community.*}; + + // Get communities of the current iteration + Comms = SELECT s FROM Comms:s + WHERE s.iteration == iter; + + PRINT Comms; +} diff --git a/common/gsql/supportai/Scan_For_Updates.gsql b/common/gsql/supportai/Scan_For_Updates.gsql index 03ced2ec..7d9d1b83 100644 --- a/common/gsql/supportai/Scan_For_Updates.gsql +++ b/common/gsql/supportai/Scan_For_Updates.gsql @@ -24,10 +24,10 @@ CREATE DISTRIBUTED QUERY Scan_For_Updates(STRING v_type = "Document", res = SELECT s FROM start:s -(HAS_CONTENT)-> Content:c ACCUM @@v_and_text += (s.id -> c.text) POST-ACCUM s.epoch_processing = datetime_to_epoch(now()); - ELSE IF v_type == "Concept" THEN - res = SELECT s FROM start:s - POST-ACCUM @@v_and_text += (s.id -> s.description), - s.epoch_processing = datetime_to_epoch(now()); + // ELSE IF v_type == "Concept" THEN + // res = SELECT s FROM start:s + // POST-ACCUM @@v_and_text += (s.id -> s.description), + // s.epoch_processing = datetime_to_epoch(now()); ELSE IF v_type == "Entity" THEN res = SELECT s FROM start:s POST-ACCUM @@v_and_text += (s.id -> s.definition), @@ -42,4 +42,4 @@ CREATE DISTRIBUTED QUERY Scan_For_Updates(STRING v_type = "Document", POST-ACCUM s.epoch_processing = datetime_to_epoch(now()); END; PRINT @@v_and_text; -} \ No newline at end of file +} diff --git a/common/gsql/supportai/SupportAI_Schema.gsql b/common/gsql/supportai/SupportAI_Schema.gsql index 061993bb..718ab1a7 100644 --- a/common/gsql/supportai/SupportAI_Schema.gsql +++ b/common/gsql/supportai/SupportAI_Schema.gsql @@ -2,7 +2,7 @@ CREATE SCHEMA_CHANGE JOB 
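# Editor's note: the modularity.gsql query above reduces to Q = sum_c(in_c/m - (tot_c/m)^2)
# over unit-weight directed edges. A tiny stand-alone sketch with made-up edges and
# community labels, mirroring the query's accumulator logic:
from collections import defaultdict

edges = [("a", "b"), ("b", "a"), ("a", "c"), ("c", "a"), ("d", "e"), ("e", "d")]
community = {"a": "c1", "b": "c1", "c": "c1", "d": "c2", "e": "c2"}

m = 0.0
in_w = defaultdict(float)   # @@community_in_weight_map
tot_w = defaultdict(float)  # @@community_total_weight_map
for s, t in edges:
    wt = 1.0
    m += wt
    tot_w[community[s]] += wt
    if community[s] == community[t]:
        in_w[community[s]] += wt

modularity = sum(in_w[c] / m - (tot_w[c] / m) ** 2 for c in tot_w)
print(round(modularity, 4))  # 0.4444 for this toy graph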
add_supportai_schema { ADD VERTEX DocumentChunk(PRIMARY_ID id STRING, idx INT, epoch_added UINT, epoch_processing UINT, epoch_processed UINT) WITH STATS="OUTDEGREE_BY_EDGETYPE", PRIMARY_ID_AS_ATTRIBUTE="true"; ADD VERTEX Document(PRIMARY_ID id STRING, epoch_added UINT, epoch_processing UINT, epoch_processed UINT) WITH STATS="OUTDEGREE_BY_EDGETYPE", PRIMARY_ID_AS_ATTRIBUTE="true"; ADD VERTEX Concept(PRIMARY_ID id STRING, description STRING, concept_type STRING, human_curated BOOL, epoch_added UINT, epoch_processing UINT, epoch_processed UINT) WITH STATS="OUTDEGREE_BY_EDGETYPE", PRIMARY_ID_AS_ATTRIBUTE="true"; - ADD VERTEX Entity(PRIMARY_ID id STRING, definition STRING, epoch_added UINT, epoch_processing UINT, epoch_processed UINT) WITH STATS="OUTDEGREE_BY_EDGETYPE", PRIMARY_ID_AS_ATTRIBUTE="true"; + ADD VERTEX Entity(PRIMARY_ID id STRING, definition STRING, description SET, entity_type STRING, epoch_added UINT, epoch_processing UINT, epoch_processed UINT) WITH STATS="OUTDEGREE_BY_EDGETYPE", PRIMARY_ID_AS_ATTRIBUTE="true"; ADD VERTEX Relationship(PRIMARY_ID id STRING, definition STRING, short_name STRING, epoch_added UINT, epoch_processing UINT, epoch_processed UINT) WITH STATS="OUTDEGREE_BY_EDGETYPE", PRIMARY_ID_AS_ATTRIBUTE="true"; ADD VERTEX DocumentCollection(PRIMARY_ID id STRING, epoch_added UINT) WITH STATS="OUTDEGREE_BY_EDGETYPE", PRIMARY_ID_AS_ATTRIBUTE="true"; ADD VERTEX Content(PRIMARY_ID id STRING, text STRING, epoch_added UINT) WITH STATS="OUTDEGREE_BY_EDGETYPE", PRIMARY_ID_AS_ATTRIBUTE="true"; @@ -18,4 +18,16 @@ CREATE SCHEMA_CHANGE JOB add_supportai_schema { ADD DIRECTED EDGE HAS_CHILD(FROM Document, TO DocumentChunk) WITH REVERSE_EDGE="reverse_HAS_CHILD"; ADD DIRECTED EDGE HAS_RELATIONSHIP(FROM Concept, TO Concept, relation_type STRING) WITH REVERSE_EDGE="reverse_HAS_RELATIONSHIP"; ADD DIRECTED EDGE CONTAINS_DOCUMENT(FROM DocumentCollection, TO Document) WITH REVERSE_EDGE="reverse_CONTAINS_DOCUMENT"; -} \ No newline at end of file + + // GraphRAG + ADD VERTEX Community (PRIMARY_ID id STRING, iteration UINT, description STRING) WITH PRIMARY_ID_AS_ATTRIBUTE="true"; + ADD VERTEX ResolvedEntity(PRIMARY_ID id STRING, entity_type STRING) WITH PRIMARY_ID_AS_ATTRIBUTE="true"; + + ADD DIRECTED EDGE RELATIONSHIP(FROM Entity, TO Entity, relation_type STRING) WITH REVERSE_EDGE="reverse_RELATIONSHIP"; + ADD DIRECTED EDGE RESOLVES_TO(FROM Entity, TO ResolvedEntity, relation_type STRING) WITH REVERSE_EDGE="reverse_RESOLVES_TO"; // Connect ResolvedEntities with their children entities + ADD DIRECTED EDGE RESOLVED_RELATIONSHIP(FROM ResolvedEntity, TO ResolvedEntity, relation_type STRING) WITH REVERSE_EDGE="reverse_RESOLVED_RELATIONSHIP"; // store edges between entities after they're resolved + + ADD DIRECTED EDGE IN_COMMUNITY(FROM ResolvedEntity, TO Community) WITH REVERSE_EDGE="reverse_IN_COMMUNITY"; + ADD DIRECTED EDGE LINKS_TO (from Community, to Community, weight DOUBLE) WITH REVERSE_EDGE="reverse_LINKS_TO"; + ADD DIRECTED EDGE HAS_PARENT (from Community, to Community) WITH REVERSE_EDGE="reverse_HAS_PARENT"; +} diff --git a/common/llm_services/openai_service.py b/common/llm_services/openai_service.py index 81d3281e..aad5d44f 100644 --- a/common/llm_services/openai_service.py +++ b/common/llm_services/openai_service.py @@ -1,6 +1,11 @@ import logging import os +if os.getenv("ECC"): + from langchain_openai.chat_models import ChatOpenAI +else: + from langchain_community.chat_models import ChatOpenAI + from common.llm_services import LLM_Model from common.logs.log import req_id_cv from 
common.logs.logwriter import LogWriter @@ -16,8 +21,6 @@ def __init__(self, config): auth_detail ] - from langchain_community.chat_models import ChatOpenAI - model_name = config["llm_model"] self.llm = ChatOpenAI( temperature=config["model_kwargs"]["temperature"], model_name=model_name diff --git a/common/py_schemas/schemas.py b/common/py_schemas/schemas.py index e5dd1faf..a58d4660 100644 --- a/common/py_schemas/schemas.py +++ b/common/py_schemas/schemas.py @@ -15,11 +15,9 @@ class SupportAIQuestion(BaseModel): method_params: dict = {} -class SupportAIInitConfig(BaseModel): - chunker: str - chunker_params: dict - extractor: str - extractor_params: dict +class SupportAIMethod(enum.StrEnum): + SUPPORTAI = enum.auto() + GRAPHRAG = enum.auto() class GSQLQueryInfo(BaseModel): @@ -126,15 +124,18 @@ class QueryUpsertRequest(BaseModel): id: Optional[str] query_info: Optional[GSQLQueryInfo] + class MessageContext(BaseModel): # TODO: fix this to contain proper message context user: str content: str + class ReportQuestions(BaseModel): question: str reasoning: str + class ReportSection(BaseModel): section_name: str description: str @@ -142,6 +143,7 @@ class ReportSection(BaseModel): copilot_fortify: bool = True actions: Optional[List[str]] = None + class ReportCreationRequest(BaseModel): topic: str sections: Union[List[ReportSection], str] = None @@ -150,6 +152,7 @@ class ReportCreationRequest(BaseModel): conversation_id: Optional[str] = None message_context: Optional[List[MessageContext]] = None + class Role(enum.StrEnum): SYSTEM = enum.auto() USER = enum.auto() diff --git a/common/py_schemas/tool_io_schemas.py b/common/py_schemas/tool_io_schemas.py index 1fe16de4..4ca91b3d 100644 --- a/common/py_schemas/tool_io_schemas.py +++ b/common/py_schemas/tool_io_schemas.py @@ -1,10 +1,8 @@ +from typing import Dict, List, Optional + from langchain.pydantic_v1 import BaseModel, Field -from typing import Optional -from langchain_community.graphs.graph_document import ( - Node as BaseNode, - Relationship as BaseRelationship, -) -from typing import List, Dict, Type +from langchain_community.graphs.graph_document import Node as BaseNode +from langchain_community.graphs.graph_document import Relationship as BaseRelationship class MapQuestionToSchemaResponse(BaseModel): @@ -81,14 +79,27 @@ class KnowledgeGraph(BaseModel): ..., description="List of relationships in the knowledge graph" ) + class ReportQuestion(BaseModel): question: str = Field("The question to be asked") reasoning: str = Field("The reasoning behind the question") + class ReportSection(BaseModel): section: str = Field("Name of the section") description: str = Field("Description of the section") - questions: List[ReportQuestion] = Field("List of questions and reasoning for the section") + questions: List[ReportQuestion] = Field( + "List of questions and reasoning for the section" + ) + class ReportSections(BaseModel): - sections: List[ReportSection] = Field("List of sections for the report") \ No newline at end of file + sections: List[ReportSection] = Field("List of sections for the report") + + +class CommunitySummary(BaseModel): + """Generate a summary of the documents that are within this community.""" + + summary: str = Field( + ..., description="The community summary derived from the input documents" + ) diff --git a/common/requirements.txt b/common/requirements.txt index d45f2a60..af45c357 100644 --- a/common/requirements.txt +++ b/common/requirements.txt @@ -1,155 +1,177 @@ -aiohttp==3.9.3 +aiohappyeyeballs==2.3.5 +aiohttp==3.10.3 
aiosignal==1.3.1 -annotated-types==0.5.0 -anyio==3.7.1 +annotated-types==0.7.0 +anyio==4.4.0 appdirs==1.4.4 argon2-cffi==23.1.0 argon2-cffi-bindings==21.2.0 async-timeout==4.0.3 asyncer==0.0.7 -attrs==23.1.0 -azure-core==1.30.1 -azure-storage-blob==12.19.1 +attrs==24.2.0 +azure-core==1.30.2 +azure-storage-blob==12.22.0 backoff==2.2.1 -beautifulsoup4==4.12.2 -boto3==1.28.83 -botocore==1.31.83 -cachetools==5.3.2 -certifi==2023.7.22 -cffi==1.16.0 +beautifulsoup4==4.12.3 +boto3==1.34.159 +botocore==1.34.159 +cachetools==5.4.0 +certifi==2024.7.4 +cffi==1.17.0 chardet==5.2.0 -charset-normalizer==3.2.0 +charset-normalizer==3.3.2 click==8.1.7 -cryptography==42.0.5 -dataclasses-json==0.5.14 -distro==1.8.0 +contourpy==1.2.1 +cryptography==43.0.0 +cycler==0.12.1 +dataclasses-json==0.6.7 +deepdiff==7.0.1 +distro==1.9.0 docker-pycreds==0.4.0 docstring_parser==0.16 -emoji==2.8.0 +emoji==2.12.1 environs==9.5.0 -exceptiongroup==1.1.3 -fastapi==0.103.1 +exceptiongroup==1.2.2 +fastapi==0.112.0 filelock==3.15.4 filetype==1.2.0 -frozenlist==1.4.0 -fsspec==2024.6.0 +fonttools==4.53.1 +frozenlist==1.4.1 +fsspec==2024.6.1 gitdb==4.0.11 -GitPython==3.1.40 -google-api-core==2.14.0 -google-auth==2.23.4 -google-cloud-aiplatform==1.52.0 -google-cloud-bigquery==3.13.0 -google-cloud-core==2.3.3 -google-cloud-resource-manager==1.10.4 -google-cloud-storage==2.13.0 +GitPython==3.1.43 +google-api-core==2.19.1 +google-auth==2.33.0 +google-cloud-aiplatform==1.61.0 +google-cloud-bigquery==3.25.0 +google-cloud-core==2.4.1 +google-cloud-resource-manager==1.12.5 +google-cloud-storage==2.18.2 google-crc32c==1.5.0 -google-resumable-media==2.6.0 -googleapis-common-protos==1.61.0 -greenlet==2.0.2 -groq==0.5.0 -grpc-google-iam-v1==0.12.7 -grpcio==1.59.2 -grpcio-status==1.59.2 +google-resumable-media==2.7.2 +googleapis-common-protos==1.63.2 +greenlet==3.0.3 +groq==0.9.0 +grpc-google-iam-v1==0.13.1 +grpcio==1.63.0 +grpcio-status==1.63.0 h11==0.14.0 -httpcore==0.18.0 -httptools==0.6.0 -httpx==0.25.0 -huggingface-hub==0.23.0 +httpcore==1.0.5 +httptools==0.6.1 +httpx==0.27.0 +huggingface-hub==0.24.5 ibm-cos-sdk==2.13.6 ibm-cos-sdk-core==2.13.6 ibm-cos-sdk-s3transfer==2.13.6 -ibm_watsonx_ai==1.0.11 -idna==3.4 -importlib_metadata==8.0.0 +ibm_watsonx_ai==1.1.5 +idna==3.7 +importlib_metadata==8.2.0 iniconfig==2.0.0 isodate==0.6.1 +jiter==0.5.0 jmespath==1.0.1 -joblib==1.3.2 -jq==1.6.0 +joblib==1.4.2 +jq==1.7.0 jsonpatch==1.33 -jsonpointer==2.4 -langchain==0.2.11 -langchain-community==0.2.10 -langchain-core==0.2.25 -langchain-experimental==0.0.63 -langchain-groq==0.1.8 -langchain-ibm==0.1.11 +jsonpath-python==1.0.6 +jsonpointer==3.0.0 +kiwisolver==1.4.5 +langchain==0.2.13 +langchain-community==0.2.12 +langchain-core==0.2.30 +langchain-experimental==0.0.64 +langchain-groq==0.1.9 +langchain-ibm==0.1.12 +langchain-milvus==0.1.4 +langchain-openai==0.1.21 langchain-text-splitters==0.2.2 -langchain_milvus==0.1.3 -langchain_openai==0.1.19 -langchainhub==0.1.20 +langchainhub==0.1.21 langdetect==1.0.9 -langgraph==0.1.16 -langsmith==0.1.94 +langgraph==0.2.3 +langgraph-checkpoint==1.0.2 +langsmith==0.1.99 +Levenshtein==0.25.1 lomond==0.3.3 -lxml==4.9.3 -marshmallow==3.20.1 -matplotlib==3.9.1 -minio==7.2.5 -multidict==6.0.4 +lxml==5.3.0 +marshmallow==3.21.3 +matplotlib==3.9.2 +milvus-lite==2.4.9 +minio==7.2.7 +multidict==6.0.5 mypy-extensions==1.0.0 -nltk==3.8.1 +nest-asyncio==1.6.0 +nltk==3.8.2 numpy==1.26.4 -openai==1.37.1 -orjson==3.9.15 -packaging==23.2 -pandas==2.1.1 +openai==1.40.6 +ordered-set==4.1.0 +orjson==3.10.7 +packaging==24.1 +pandas==2.1.4 
pathtools==0.1.2 +pillow==10.4.0 +platformdirs==4.2.2 pluggy==1.5.0 prometheus_client==0.20.0 -proto-plus==1.22.3 -protobuf==4.24.4 -psutil==5.9.6 -pyarrow==15.0.1 -pyasn1==0.5.0 -pyasn1-modules==0.3.0 -pycparser==2.21 +proto-plus==1.24.0 +protobuf==5.27.3 +psutil==6.0.0 +pyarrow==17.0.0 +pyasn1==0.6.0 +pyasn1_modules==0.4.0 +pycparser==2.22 pycryptodome==3.20.0 -pydantic==2.3.0 -pydantic_core==2.6.3 -pygit2==1.13.2 -pymilvus==2.4.4 -pytest==8.2.0 +pydantic==2.8.2 +pydantic_core==2.20.1 +pygit2==1.15.1 +pymilvus==2.4.5 +pyparsing==3.1.2 +pypdf==4.3.1 +pytest==8.3.2 python-dateutil==2.9.0.post0 -python-dotenv==1.0.0 -python-iso639==2023.6.15 +python-dotenv==1.0.1 +python-iso639==2024.4.27 python-magic==0.4.27 pyTigerDriver==1.0.15 -pyTigerGraph==1.6.2 -pytz==2023.3.post1 -PyYAML==6.0.1 -rapidfuzz==3.4.0 -regex==2023.10.3 +pyTigerGraph==1.6.5 +pytz==2024.1 +PyYAML==6.0.2 +rapidfuzz==3.9.6 +regex==2024.7.24 requests==2.32.2 +requests-toolbelt==1.0.0 rsa==4.9 -s3transfer==0.7.0 +s3transfer==0.10.2 scikit-learn==1.5.1 -sentry-sdk==1.32.0 +scipy==1.14.0 +sentry-sdk==2.13.0 setproctitle==1.3.3 -shapely==2.0.2 +shapely==2.0.5 six==1.16.0 smmap==5.0.1 -sniffio==1.3.0 -soupsieve==2.5 -SQLAlchemy==2.0.20 -starlette==0.27.0 +sniffio==1.3.1 +soupsieve==2.6 +SQLAlchemy==2.0.32 +starlette==0.37.2 tabulate==0.9.0 -tenacity==8.2.3 +tenacity==8.5.0 +threadpoolctl==3.5.0 tiktoken==0.7.0 -tqdm==4.66.1 -types-requests==2.31.0.6 +tqdm==4.66.5 +types-requests==2.32.0.20240712 types-urllib3==1.26.25.14 typing-inspect==0.9.0 -typing_extensions==4.8.0 -tzdata==2023.3 -ujson==5.9.0 -unstructured==0.10.23 -urllib3==1.26.18 -uvicorn==0.23.2 -uvloop==0.17.0 -validators==0.22.0 -wandb==0.15.12 -watchfiles==0.20.0 -websockets==11.0.3 -yarl==1.9.2 -zipp==3.19.2 +typing_extensions==4.12.2 +tzdata==2024.1 +ujson==5.10.0 +unstructured==0.15.1 +unstructured-client==0.25.5 +urllib3==2.2.2 +uvicorn==0.30.6 +uvloop==0.19.0 +validators==0.33.0 +wandb==0.17.6 +watchfiles==0.23.0 +websockets==12.0 +wrapt==1.16.0 +yarl==1.9.4 +zipp==3.20.0 diff --git a/copilot/app/routers/supportai.py b/copilot/app/routers/supportai.py index 3f599b26..0eff3c41 100644 --- a/copilot/app/routers/supportai.py +++ b/copilot/app/routers/supportai.py @@ -1,22 +1,39 @@ import json import logging -import uuid from typing import Annotated -from fastapi import APIRouter, BackgroundTasks, Depends, Request, HTTPException +from fastapi import APIRouter, BackgroundTasks, Depends, Request, Response, status from fastapi.security.http import HTTPBase +from supportai import supportai from supportai.concept_management.create_concepts import ( - CommunityConceptCreator, EntityConceptCreator, HigherLevelConceptCreator, - RelationshipConceptCreator) -from supportai.retrievers import (EntityRelationshipRetriever, - HNSWOverlapRetriever, HNSWRetriever, - HNSWSiblingRetriever) - -from common.config import (db_config, embedding_service, embedding_store, - get_llm_service, llm_config, service_status) + CommunityConceptCreator, + EntityConceptCreator, + HigherLevelConceptCreator, + RelationshipConceptCreator, +) +from supportai.retrievers import ( + EntityRelationshipRetriever, + HNSWOverlapRetriever, + HNSWRetriever, + HNSWSiblingRetriever, +) + +from common.config import ( + db_config, + embedding_service, + embedding_store, + get_llm_service, + llm_config, + service_status, +) from common.logs.logwriter import LogWriter -from common.py_schemas.schemas import (CoPilotResponse, CreateIngestConfig, - LoadingInfo, SupportAIQuestion) +from common.py_schemas.schemas import ( # 
SupportAIInitConfig,; SupportAIMethod, + CoPilotResponse, + CreateIngestConfig, + LoadingInfo, + SupportAIMethod, + SupportAIQuestion, +) logger = logging.getLogger(__name__) router = APIRouter(tags=["SupportAI"]) @@ -27,57 +44,20 @@ def check_embedding_store_status(): if service_status["embedding_store"]["error"]: return HTTPException( - status_code=503, - detail=service_status["embedding_store"]["error"] + status_code=503, detail=service_status["embedding_store"]["error"] ) - + @router.post("/{graphname}/supportai/initialize") def initialize( - graphname, conn: Request, credentials: Annotated[HTTPBase, Depends(security)] + graphname, + conn: Request, + credentials: Annotated[HTTPBase, Depends(security)], ): conn = conn.state.conn - # need to open the file using the absolute path - file_path = "common/gsql/supportai/SupportAI_Schema.gsql" - with open(file_path, "r") as f: - schema = f.read() - schema_res = conn.gsql( - """USE GRAPH {}\n{}\nRUN SCHEMA_CHANGE JOB add_supportai_schema""".format( - graphname, schema - ) - ) - - file_path = "common/gsql/supportai/SupportAI_IndexCreation.gsql" - with open(file_path) as f: - index = f.read() - index_res = conn.gsql( - """USE GRAPH {}\n{}\nRUN SCHEMA_CHANGE JOB add_supportai_indexes""".format( - graphname, index - ) - ) - - file_path = "common/gsql/supportai/Scan_For_Updates.gsql" - with open(file_path) as f: - scan_for_updates = f.read() - res = conn.gsql( - "USE GRAPH " - + conn.graphname - + "\n" - + scan_for_updates - + "\n INSTALL QUERY Scan_For_Updates" - ) - - file_path = "common/gsql/supportai/Update_Vertices_Processing_Status.gsql" - with open(file_path) as f: - update_vertices = f.read() - res = conn.gsql( - "USE GRAPH " - + conn.graphname - + "\n" - + update_vertices - + "\n INSTALL QUERY Update_Vertices_Processing_Status" - ) + resp = supportai.init_supportai(conn, graphname) + schema_res, index_res = resp[0], resp[1] return { "host_name": conn._tg_connection.host, # include host_name for debugging from client. 
Their pyTG conn might not have the same host as what's configured in copilot "schema_creation_status": json.dumps(schema_res), @@ -88,132 +68,13 @@ def initialize( @router.post("/{graphname}/supportai/create_ingest") def create_ingest( graphname, - ingest_config: CreateIngestConfig, + cfg: CreateIngestConfig, conn: Request, credentials: Annotated[HTTPBase, Depends(security)], ): conn = conn.state.conn - if ingest_config.file_format.lower() == "json": - file_path = "common/gsql/supportai/SupportAI_InitialLoadJSON.gsql" - - with open(file_path) as f: - ingest_template = f.read() - ingest_template = ingest_template.replace("@uuid@", str(uuid.uuid4().hex)) - doc_id = ingest_config.loader_config.get("doc_id_field", "doc_id") - doc_text = ingest_config.loader_config.get("content_field", "content") - ingest_template = ingest_template.replace('"doc_id"', '"{}"'.format(doc_id)) - ingest_template = ingest_template.replace('"content"', '"{}"'.format(doc_text)) - - if ingest_config.file_format.lower() == "csv": - file_path = "common/gsql/supportai/SupportAI_InitialLoadCSV.gsql" - - with open(file_path) as f: - ingest_template = f.read() - ingest_template = ingest_template.replace("@uuid@", str(uuid.uuid4().hex)) - separator = ingest_config.get("separator", "|") - header = ingest_config.get("header", "true") - eol = ingest_config.get("eol", "\n") - quote = ingest_config.get("quote", "double") - ingest_template = ingest_template.replace('"|"', '"{}"'.format(separator)) - ingest_template = ingest_template.replace('"true"', '"{}"'.format(header)) - ingest_template = ingest_template.replace('"\\n"', '"{}"'.format(eol)) - ingest_template = ingest_template.replace('"double"', '"{}"'.format(quote)) - - file_path = "common/gsql/supportai/SupportAI_DataSourceCreation.gsql" - - with open(file_path) as f: - data_stream_conn = f.read() - - # assign unique identifier to the data stream connection - - data_stream_conn = data_stream_conn.replace( - "@source_name@", "SupportAI_" + graphname + "_" + str(uuid.uuid4().hex) - ) - - # check the data source and create the appropriate connection - if ingest_config.data_source.lower() == "s3": - data_conn = ingest_config.data_source_config - if ( - data_conn.get("aws_access_key") is None - or data_conn.get("aws_secret_key") is None - ): - raise Exception("AWS credentials not provided") - connector = { - "type": "s3", - "access.key": data_conn["aws_access_key"], - "secret.key": data_conn["aws_secret_key"], - } - - data_stream_conn = data_stream_conn.replace( - "@source_config@", json.dumps(connector) - ) - - elif ingest_config.data_source.lower() == "azure": - if ingest_config.data_source_config.get("account_key") is not None: - connector = { - "type": "abs", - "account.key": ingest_config.data_source_config["account_key"], - } - elif ingest_config.data_source_config.get("client_id") is not None: - # verify that the client secret is also provided - if ingest_config.data_source_config.get("client_secret") is None: - raise Exception("Client secret not provided") - # verify that the tenant id is also provided - if ingest_config.data_source_config.get("tenant_id") is None: - raise Exception("Tenant id not provided") - connector = { - "type": "abs", - "client.id": ingest_config.data_source_config["client_id"], - "client.secret": ingest_config.data_source_config["client_secret"], - "tenant.id": ingest_config.data_source_config["tenant_id"], - } - else: - raise Exception("Azure credentials not provided") - data_stream_conn = data_stream_conn.replace( - "@source_config@", 
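# Editor's note: a hedged example of the request body that create_ingest (now delegated to
# supportai.create_ingest) consumes for an S3 source; field names come from the code above,
# while the key values are placeholders.
ingest_config = {
    "file_format": "json",
    "loader_config": {"doc_id_field": "doc_id", "content_field": "content"},
    "data_source": "s3",
    "data_source_config": {
        "aws_access_key": "AKIA...placeholder",
        "aws_secret_key": "placeholder-secret",
    },
}
# POST /{graphname}/supportai/create_ingest with this JSON selects
# SupportAI_InitialLoadJSON.gsql and creates the matching S3 data source.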
json.dumps(connector) - ) - elif ingest_config.data_source.lower() == "gcs": - # verify that the correct fields are provided - if ingest_config.data_source_config.get("project_id") is None: - raise Exception("Project id not provided") - if ingest_config.data_source_config.get("private_key_id") is None: - raise Exception("Private key id not provided") - if ingest_config.data_source_config.get("private_key") is None: - raise Exception("Private key not provided") - if ingest_config.data_source_config.get("client_email") is None: - raise Exception("Client email not provided") - connector = { - "type": "gcs", - "project_id": ingest_config.data_source_config["project_id"], - "private_key_id": ingest_config.data_source_config["private_key_id"], - "private_key": ingest_config.data_source_config["private_key"], - "client_email": ingest_config.data_source_config["client_email"], - } - data_stream_conn = data_stream_conn.replace( - "@source_config@", json.dumps(connector) - ) - else: - raise Exception("Data source not implemented") - - load_job_created = conn.gsql("USE GRAPH {}\n".format(graphname) + ingest_template) - - data_source_created = conn.gsql( - "USE GRAPH {}\n".format(graphname) + data_stream_conn - ) - - return { - "load_job_id": load_job_created.split(":")[1] - .strip(" [") - .strip(" ") - .strip(".") - .strip("]"), - "data_source_id": data_source_created.split(":")[1] - .strip(" [") - .strip(" ") - .strip(".") - .strip("]"), - } + return supportai.create_ingest(graphname, cfg, conn) @router.post("/{graphname}/supportai/ingest") @@ -397,18 +258,24 @@ def build_concepts( return {"status": "success"} -@router.get("/{graphname}/supportai/forceupdate") -def ecc( - graphname, +@router.get("/{graphname}/{method}/forceupdate") +def supportai_update( + graphname: str, + method: str, conn: Request, credentials: Annotated[HTTPBase, Depends(security)], bg_tasks: BackgroundTasks, + response: Response, ): + if method != SupportAIMethod.SUPPORTAI and method != SupportAIMethod.GRAPHRAG: + response.status_code = status.HTTP_404_NOT_FOUND + return f"{method} is not a valid method. 
{SupportAIMethod.SUPPORTAI} or {SupportAIMethod.GRAPHRAG}" + from httpx import get as http_get ecc = ( - db_config.get("ecc", "http://eventual-consistency-service:8001") - + f"/{graphname}/consistency_status" + db_config.get("ecc", "http://localhost:8001") + + f"/{graphname}/consistency_status/{method}" ) LogWriter.info(f"Sending ECC request to: {ecc}") bg_tasks.add_task( diff --git a/copilot/app/supportai/supportai.py b/copilot/app/supportai/supportai.py new file mode 100644 index 00000000..e96663a3 --- /dev/null +++ b/copilot/app/supportai/supportai.py @@ -0,0 +1,185 @@ +import json +import uuid + +from pyTigerGraph import TigerGraphConnection + +from common.py_schemas.schemas import ( + # CoPilotResponse, + CreateIngestConfig, + # LoadingInfo, + # SupportAIInitConfig, + # SupportAIMethod, + # SupportAIQuestion, +) + + +def init_supportai(conn: TigerGraphConnection, graphname: str) -> tuple[dict, dict]: + # need to open the file using the absolute path + file_path = "common/gsql/supportai/SupportAI_Schema.gsql" + with open(file_path, "r") as f: + schema = f.read() + schema_res = conn.gsql( + """USE GRAPH {}\n{}\nRUN SCHEMA_CHANGE JOB add_supportai_schema""".format( + graphname, schema + ) + ) + + file_path = "common/gsql/supportai/SupportAI_IndexCreation.gsql" + with open(file_path) as f: + index = f.read() + index_res = conn.gsql( + """USE GRAPH {}\n{}\nRUN SCHEMA_CHANGE JOB add_supportai_indexes""".format( + graphname, index + ) + ) + + file_path = "common/gsql/supportai/Scan_For_Updates.gsql" + with open(file_path) as f: + scan_for_updates = f.read() + res = conn.gsql( + "USE GRAPH " + + conn.graphname + + "\n" + + scan_for_updates + + "\n INSTALL QUERY Scan_For_Updates" + ) + + file_path = "common/gsql/supportai/Update_Vertices_Processing_Status.gsql" + with open(file_path) as f: + update_vertices = f.read() + res = conn.gsql( + "USE GRAPH " + + conn.graphname + + "\n" + + update_vertices + + "\n INSTALL QUERY Update_Vertices_Processing_Status" + ) + + return schema_res, index_res + + +def create_ingest( + graphname: str, + ingest_config: CreateIngestConfig, + conn: TigerGraphConnection, +): + if ingest_config.file_format.lower() == "json": + file_path = "common/gsql/supportai/SupportAI_InitialLoadJSON.gsql" + + with open(file_path) as f: + ingest_template = f.read() + ingest_template = ingest_template.replace("@uuid@", str(uuid.uuid4().hex)) + doc_id = ingest_config.loader_config.get("doc_id_field", "doc_id") + doc_text = ingest_config.loader_config.get("content_field", "content") + ingest_template = ingest_template.replace('"doc_id"', '"{}"'.format(doc_id)) + ingest_template = ingest_template.replace('"content"', '"{}"'.format(doc_text)) + + if ingest_config.file_format.lower() == "csv": + file_path = "common/gsql/supportai/SupportAI_InitialLoadCSV.gsql" + + with open(file_path) as f: + ingest_template = f.read() + ingest_template = ingest_template.replace("@uuid@", str(uuid.uuid4().hex)) + separator = ingest_config.get("separator", "|") + header = ingest_config.get("header", "true") + eol = ingest_config.get("eol", "\n") + quote = ingest_config.get("quote", "double") + ingest_template = ingest_template.replace('"|"', '"{}"'.format(separator)) + ingest_template = ingest_template.replace('"true"', '"{}"'.format(header)) + ingest_template = ingest_template.replace('"\\n"', '"{}"'.format(eol)) + ingest_template = ingest_template.replace('"double"', '"{}"'.format(quote)) + + file_path = "common/gsql/supportai/SupportAI_DataSourceCreation.gsql" + + with open(file_path) as f: + 
data_stream_conn = f.read() + + # assign unique identifier to the data stream connection + + data_stream_conn = data_stream_conn.replace( + "@source_name@", "SupportAI_" + graphname + "_" + str(uuid.uuid4().hex) + ) + + # check the data source and create the appropriate connection + if ingest_config.data_source.lower() == "s3": + data_conn = ingest_config.data_source_config + if ( + data_conn.get("aws_access_key") is None + or data_conn.get("aws_secret_key") is None + ): + raise Exception("AWS credentials not provided") + connector = { + "type": "s3", + "access.key": data_conn["aws_access_key"], + "secret.key": data_conn["aws_secret_key"], + } + + data_stream_conn = data_stream_conn.replace( + "@source_config@", json.dumps(connector) + ) + + elif ingest_config.data_source.lower() == "azure": + if ingest_config.data_source_config.get("account_key") is not None: + connector = { + "type": "abs", + "account.key": ingest_config.data_source_config["account_key"], + } + elif ingest_config.data_source_config.get("client_id") is not None: + # verify that the client secret is also provided + if ingest_config.data_source_config.get("client_secret") is None: + raise Exception("Client secret not provided") + # verify that the tenant id is also provided + if ingest_config.data_source_config.get("tenant_id") is None: + raise Exception("Tenant id not provided") + connector = { + "type": "abs", + "client.id": ingest_config.data_source_config["client_id"], + "client.secret": ingest_config.data_source_config["client_secret"], + "tenant.id": ingest_config.data_source_config["tenant_id"], + } + else: + raise Exception("Azure credentials not provided") + data_stream_conn = data_stream_conn.replace( + "@source_config@", json.dumps(connector) + ) + elif ingest_config.data_source.lower() == "gcs": + # verify that the correct fields are provided + if ingest_config.data_source_config.get("project_id") is None: + raise Exception("Project id not provided") + if ingest_config.data_source_config.get("private_key_id") is None: + raise Exception("Private key id not provided") + if ingest_config.data_source_config.get("private_key") is None: + raise Exception("Private key not provided") + if ingest_config.data_source_config.get("client_email") is None: + raise Exception("Client email not provided") + connector = { + "type": "gcs", + "project_id": ingest_config.data_source_config["project_id"], + "private_key_id": ingest_config.data_source_config["private_key_id"], + "private_key": ingest_config.data_source_config["private_key"], + "client_email": ingest_config.data_source_config["client_email"], + } + data_stream_conn = data_stream_conn.replace( + "@source_config@", json.dumps(connector) + ) + else: + raise Exception("Data source not implemented") + + load_job_created = conn.gsql("USE GRAPH {}\n".format(graphname) + ingest_template) + + data_source_created = conn.gsql( + "USE GRAPH {}\n".format(graphname) + data_stream_conn + ) + + return { + "load_job_id": load_job_created.split(":")[1] + .strip(" [") + .strip(" ") + .strip(".") + .strip("]"), + "data_source_id": data_source_created.split(":")[1] + .strip(" [") + .strip(" ") + .strip(".") + .strip("]"), + } diff --git a/copilot/requirements.txt b/copilot/requirements.txt index 7a8bd83f..4a5ac3d1 100644 --- a/copilot/requirements.txt +++ b/copilot/requirements.txt @@ -1,155 +1,179 @@ -aiohttp==3.9.3 +aiochannel==1.2.1 +aiohappyeyeballs==2.3.5 +aiohttp==3.10.3 aiosignal==1.3.1 -annotated-types==0.5.0 -anyio==3.7.1 +annotated-types==0.7.0 +anyio==4.4.0 appdirs==1.4.4 
argon2-cffi==23.1.0 argon2-cffi-bindings==21.2.0 async-timeout==4.0.3 asyncer==0.0.7 -attrs==23.1.0 -azure-core==1.30.1 -azure-storage-blob==12.19.1 +attrs==24.2.0 +azure-core==1.30.2 +azure-storage-blob==12.22.0 backoff==2.2.1 -beautifulsoup4==4.12.2 -boto3==1.28.83 -botocore==1.31.83 -cachetools==5.3.2 -certifi==2023.7.22 -cffi==1.16.0 +beautifulsoup4==4.12.3 +boto3==1.34.160 +botocore==1.34.160 +cachetools==5.4.0 +certifi==2024.7.4 +cffi==1.17.0 chardet==5.2.0 -charset-normalizer==3.2.0 +charset-normalizer==3.3.2 click==8.1.7 -cryptography==42.0.5 -dataclasses-json==0.5.14 -distro==1.8.0 +contourpy==1.2.1 +cryptography==43.0.0 +cycler==0.12.1 +dataclasses-json==0.6.7 +deepdiff==7.0.1 +distro==1.9.0 docker-pycreds==0.4.0 -docstring_parser==0.16 -emoji==2.8.0 +docstring-parser==0.16 +emoji==2.12.1 environs==9.5.0 -exceptiongroup==1.1.3 -fastapi==0.103.1 +exceptiongroup==1.2.2 +fastapi==0.112.0 filelock==3.15.4 filetype==1.2.0 -frozenlist==1.4.0 -fsspec==2024.6.0 +fonttools==4.53.1 +frozenlist==1.4.1 +fsspec==2024.6.1 gitdb==4.0.11 -GitPython==3.1.40 -google-api-core==2.14.0 -google-auth==2.23.4 -google-cloud-aiplatform==1.52.0 -google-cloud-bigquery==3.13.0 -google-cloud-core==2.3.3 -google-cloud-resource-manager==1.10.4 -google-cloud-storage==2.13.0 +gitpython==3.1.43 +google-api-core==2.19.1 +google-auth==2.33.0 +google-cloud-aiplatform==1.62.0 +google-cloud-bigquery==3.25.0 +google-cloud-core==2.4.1 +google-cloud-resource-manager==1.12.5 +google-cloud-storage==2.18.2 google-crc32c==1.5.0 -google-resumable-media==2.6.0 -googleapis-common-protos==1.61.0 -greenlet==2.0.2 -groq==0.5.0 -grpc-google-iam-v1==0.12.7 -grpcio==1.59.2 -grpcio-status==1.59.2 +google-resumable-media==2.7.2 +googleapis-common-protos==1.63.2 +greenlet==3.0.3 +groq==0.9.0 +grpc-google-iam-v1==0.13.1 +grpcio==1.63.0 +grpcio-status==1.63.0 h11==0.14.0 -httpcore==0.18.0 -httptools==0.6.0 -httpx==0.25.0 -huggingface-hub==0.23.0 +httpcore==1.0.5 +httptools==0.6.1 +httpx==0.27.0 +huggingface-hub==0.24.5 ibm-cos-sdk==2.13.6 ibm-cos-sdk-core==2.13.6 ibm-cos-sdk-s3transfer==2.13.6 -ibm_watsonx_ai==1.0.11 -idna==3.4 -importlib_metadata==8.0.0 +ibm-watsonx-ai==1.1.5 +idna==3.7 +importlib-metadata==8.2.0 iniconfig==2.0.0 isodate==0.6.1 +jiter==0.5.0 jmespath==1.0.1 -joblib==1.3.2 -jq==1.6.0 +joblib==1.4.2 +jq==1.7.0 jsonpatch==1.33 -jsonpointer==2.4 -langchain==0.2.11 -langchain-community==0.2.10 -langchain-core==0.2.25 -langchain-experimental==0.0.63 -langchain-groq==0.1.8 -langchain-ibm==0.1.11 +jsonpath-python==1.0.6 +jsonpointer==3.0.0 +kiwisolver==1.4.5 +langchain==0.2.13 +langchain-community==0.2.12 +langchain-core==0.2.30 +langchain-experimental==0.0.64 +langchain-groq==0.1.9 +langchain-ibm==0.1.12 +langchain-milvus==0.1.4 +langchain-openai==0.1.21 langchain-text-splitters==0.2.2 -langchain_milvus==0.1.3 -langchain_openai==0.1.19 -langchainhub==0.1.20 +langchainhub==0.1.21 langdetect==1.0.9 -langgraph==0.1.16 -langsmith==0.1.94 +langgraph==0.2.3 +langgraph-checkpoint==1.0.2 +langsmith==0.1.99 +levenshtein==0.25.1 lomond==0.3.3 -lxml==4.9.3 -marshmallow==3.20.1 -matplotlib==3.9.1 -minio==7.2.5 -multidict==6.0.4 +lxml==5.3.0 +marshmallow==3.21.3 +matplotlib==3.9.2 +milvus-lite==2.4.9 +minio==7.2.7 +multidict==6.0.5 mypy-extensions==1.0.0 -nltk==3.8.1 +nest-asyncio==1.6.0 +nltk==3.8.2 numpy==1.26.4 -openai==1.37.1 -orjson==3.9.15 -packaging==23.2 -pandas==2.1.1 +openai==1.40.6 +ordered-set==4.1.0 +orjson==3.10.7 +packaging==24.1 +pandas==2.1.4 pathtools==0.1.2 +pillow==10.4.0 +platformdirs==4.2.2 pluggy==1.5.0 
-prometheus_client==0.20.0 -proto-plus==1.22.3 -protobuf==4.24.4 -psutil==5.9.6 -pyarrow==15.0.1 -pyasn1==0.5.0 -pyasn1-modules==0.3.0 -pycparser==2.21 +prometheus-client==0.20.0 +proto-plus==1.24.0 +protobuf==5.27.3 +psutil==6.0.0 +pyarrow==17.0.0 +pyasn1==0.6.0 +pyasn1-modules==0.4.0 +pycparser==2.22 pycryptodome==3.20.0 -pydantic==2.3.0 -pydantic_core==2.6.3 -pygit2==1.13.2 -pymilvus==2.4.4 -pytest==8.2.0 +pydantic==2.8.2 +pydantic-core==2.20.1 +pygit2==1.15.1 +pymilvus==2.4.5 +pyparsing==3.1.2 +pypdf==4.3.1 +pytest==8.3.2 python-dateutil==2.9.0.post0 -python-dotenv==1.0.0 -python-iso639==2023.6.15 +python-dotenv==1.0.1 +python-iso639==2024.4.27 python-magic==0.4.27 -pyTigerDriver==1.0.15 -pyTigerGraph==1.6.2 -pytz==2023.3.post1 -PyYAML==6.0.1 -rapidfuzz==3.4.0 -regex==2023.10.3 +pytigerdriver==1.0.15 +pytigergraph==1.6.5 +pytz==2024.1 +pyyaml==6.0.2 +rapidfuzz==3.9.6 +regex==2024.7.24 requests==2.32.2 +requests-toolbelt==1.0.0 rsa==4.9 -s3transfer==0.7.0 +s3transfer==0.10.2 scikit-learn==1.5.1 -sentry-sdk==1.32.0 +scipy==1.14.0 +sentry-sdk==2.13.0 setproctitle==1.3.3 -shapely==2.0.2 +setuptools==72.2.0 +shapely==2.0.5 six==1.16.0 smmap==5.0.1 -sniffio==1.3.0 -soupsieve==2.5 -SQLAlchemy==2.0.20 -starlette==0.27.0 +sniffio==1.3.1 +soupsieve==2.6 +sqlalchemy==2.0.32 +starlette==0.37.2 tabulate==0.9.0 -tenacity==8.2.3 +tenacity==8.5.0 +threadpoolctl==3.5.0 tiktoken==0.7.0 -tqdm==4.66.1 -types-requests==2.31.0.6 +tqdm==4.66.5 +types-requests==2.32.0.20240712 types-urllib3==1.26.25.14 +typing-extensions==4.12.2 typing-inspect==0.9.0 -typing_extensions==4.8.0 -tzdata==2023.3 -ujson==5.9.0 -unstructured==0.10.23 -urllib3==1.26.18 -uvicorn==0.23.2 -uvloop==0.17.0 -validators==0.22.0 -wandb==0.15.12 -watchfiles==0.20.0 -websockets==11.0.3 -yarl==1.9.2 -zipp==3.19.2 \ No newline at end of file +tzdata==2024.1 +ujson==5.10.0 +unstructured==0.15.1 +unstructured-client==0.25.5 +urllib3==2.2.2 +uvicorn==0.30.6 +uvloop==0.19.0 +validators==0.33.0 +wandb==0.17.6 +watchfiles==0.23.0 +websockets==12.0 +wrapt==1.16.0 +yarl==1.9.4 +zipp==3.20.0 diff --git a/docker-compose.yml b/docker-compose.yml index 2d03dcbe..058c0d77 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -21,7 +21,7 @@ services: eventual-consistency-service: image: tigergraphml/ecc:latest - container_name: eventual-consistency-service + # container_name: eventual-consistency-service build: context: . dockerfile: eventual-consistency-service/Dockerfile @@ -54,14 +54,12 @@ services: # image: tigergraphml/report-service:latest # container_name: report-service # build: - # context: . 
- # dockerfile: report-service/Dockerfile + # context: chat-history/ + # dockerfile: Dockerfile # ports: # - 8002:8002 # environment: - # LLM_CONFIG: "/code/configs/llm_config.json" - # DB_CONFIG: "/code/configs/db_config.json" - # MILVUS_CONFIG: "/code/configs/milvus_config.json" + # CONFIG: "/configs/config.json" # LOGLEVEL: "INFO" # volumes: # - ./configs/:/code/configs diff --git a/eventual-consistency-service/app/ecc_util.py b/eventual-consistency-service/app/ecc_util.py new file mode 100644 index 00000000..bccadd77 --- /dev/null +++ b/eventual-consistency-service/app/ecc_util.py @@ -0,0 +1,55 @@ +from common.chunkers import character_chunker, regex_chunker, semantic_chunker +from common.config import doc_processing_config, embedding_service, llm_config +from common.llm_services import ( + AWS_SageMaker_Endpoint, + AWSBedrock, + AzureOpenAI, + GoogleVertexAI, + Groq, + HuggingFaceEndpoint, + Ollama, + OpenAI, +) + + +def get_chunker(): + if doc_processing_config.get("chunker") == "semantic": + chunker = semantic_chunker.SemanticChunker( + embedding_service, + doc_processing_config["chunker_config"].get("method", "percentile"), + doc_processing_config["chunker_config"].get("threshold", 0.95), + ) + elif doc_processing_config.get("chunker") == "regex": + chunker = regex_chunker.RegexChunker( + pattern=doc_processing_config["chunker_config"].get("pattern", "\\r?\\n") + ) + elif doc_processing_config.get("chunker") == "character": + chunker = character_chunker.CharacterChunker( + chunk_size=doc_processing_config["chunker_config"].get("chunk_size", 1024), + overlap_size=doc_processing_config["chunker_config"].get("overlap_size", 0), + ) + else: + raise ValueError("Invalid chunker type") + + return chunker + + +def get_llm_service(): + if llm_config["completion_service"]["llm_service"].lower() == "openai": + llm_provider = OpenAI(llm_config["completion_service"]) + elif llm_config["completion_service"]["llm_service"].lower() == "azure": + llm_provider = AzureOpenAI(llm_config["completion_service"]) + elif llm_config["completion_service"]["llm_service"].lower() == "sagemaker": + llm_provider = AWS_SageMaker_Endpoint(llm_config["completion_service"]) + elif llm_config["completion_service"]["llm_service"].lower() == "vertexai": + llm_provider = GoogleVertexAI(llm_config["completion_service"]) + elif llm_config["completion_service"]["llm_service"].lower() == "bedrock": + llm_provider = AWSBedrock(llm_config["completion_service"]) + elif llm_config["completion_service"]["llm_service"].lower() == "groq": + llm_provider = Groq(llm_config["completion_service"]) + elif llm_config["completion_service"]["llm_service"].lower() == "ollama": + llm_provider = Ollama(llm_config["completion_service"]) + elif llm_config["completion_service"]["llm_service"].lower() == "huggingface": + llm_provider = HuggingFaceEndpoint(llm_config["completion_service"]) + + return llm_provider diff --git a/eventual-consistency-service/app/eventual_consistency_checker.py b/eventual-consistency-service/app/eventual_consistency_checker.py index 007330bd..fa16694e 100644 --- a/eventual-consistency-service/app/eventual_consistency_checker.py +++ b/eventual-consistency-service/app/eventual_consistency_checker.py @@ -1,4 +1,3 @@ -import json import logging import time from typing import Dict, List @@ -367,4 +366,4 @@ def get_status(self): )[0] LogWriter.info(f"ECC_Status for graphname {self.graphname}: {status}") statuses[v_type] = status - return statuses \ No newline at end of file + return statuses diff --git 
a/eventual-consistency-service/app/graphrag/__init__.py b/eventual-consistency-service/app/graphrag/__init__.py new file mode 100644 index 00000000..953b2a0b --- /dev/null +++ b/eventual-consistency-service/app/graphrag/__init__.py @@ -0,0 +1 @@ +from .graph_rag import * diff --git a/eventual-consistency-service/app/graphrag/community_summarizer.py b/eventual-consistency-service/app/graphrag/community_summarizer.py new file mode 100644 index 00000000..2bef4095 --- /dev/null +++ b/eventual-consistency-service/app/graphrag/community_summarizer.py @@ -0,0 +1,44 @@ +import re + +from langchain_core.prompts import PromptTemplate + +from common.llm_services import LLM_Model +from common.py_schemas import CommunitySummary + +# src: https://github.com/microsoft/graphrag/blob/main/graphrag/index/graph/extractors/summarize/prompts.py +SUMMARIZE_PROMPT = PromptTemplate.from_template(""" +You are a helpful assistant responsible for generating a comprehensive summary of the data provided below. +Given one or two entities, and a list of descriptions, all related to the same entity or group of entities. +Please concatenate all of these into a single, comprehensive description. Make sure to include information collected from all the descriptions. +If the provided descriptions are contradictory, please resolve the contradictions and provide a single, coherent summary, but do not add any information that is not in the description. +Make sure it is written in third person, and include the entity names so we have the full context. + +####### +-Data- +Community Title: {entity_name} +Description List: {description_list} +""") + +id_pat = re.compile(r"[_\d]*") + + +class CommunitySummarizer: + def __init__( + self, + llm_service: LLM_Model, + ): + self.llm_service = llm_service + + async def summarize(self, name: str, text: list[str]) -> CommunitySummary: + structured_llm = self.llm_service.model.with_structured_output(CommunitySummary) + chain = SUMMARIZE_PROMPT | structured_llm + + # remove iteration tags from name + name = id_pat.sub("", name) + summary = await chain.ainvoke( + { + "entity_name": name, + "description_list": text, + } + ) + return summary.summary diff --git a/eventual-consistency-service/app/graphrag/graph_rag.py b/eventual-consistency-service/app/graphrag/graph_rag.py new file mode 100644 index 00000000..ecca36b2 --- /dev/null +++ b/eventual-consistency-service/app/graphrag/graph_rag.py @@ -0,0 +1,441 @@ +import asyncio +import logging +import time +import traceback + +import httpx +from aiochannel import Channel +from graphrag import workers +from graphrag.util import ( + check_vertex_has_desc, + http_timeout, + init, + make_headers, + stream_ids, +) +from pyTigerGraph import TigerGraphConnection + +from common.config import embedding_service +from common.embeddings.milvus_embedding_store import MilvusEmbeddingStore +from common.extractors.BaseExtractor import BaseExtractor + +logger = logging.getLogger(__name__) + +consistency_checkers = {} + + +async def stream_docs( + conn: TigerGraphConnection, + docs_chan: Channel, + ttl_batches: int = 10, +): + """ + Streams the document contents into the docs_chan + """ + logger.info("streaming docs") + headers = make_headers(conn) + async with httpx.AsyncClient(timeout=http_timeout) as client: + for i in range(ttl_batches): + doc_ids = await stream_ids(conn, "Document", i, ttl_batches) + if doc_ids["error"]: + # continue to the next batch. + # These docs will not be marked as processed, so the ecc will process it eventually.
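The new graph_rag.py module wires its stages together as a producer/consumer pipeline: each stage reads from one aiochannel Channel and writes into the next inside an asyncio.TaskGroup, and the sole sender closes its channel so downstream async-for loops terminate. A minimal sketch of that pattern, with illustrative channel sizes and trivial worker bodies rather than the module's real ones:

import asyncio
from aiochannel import Channel

async def produce(chan: Channel):
    # sole sender: put items, then close so consumers stop iterating
    for i in range(5):
        await chan.put(i)
    chan.close()

async def consume(chan: Channel):
    # iteration ends once the channel is closed and drained
    async for item in chan:
        print("consumed", item)

async def main():
    chan = Channel(2)  # bounded buffer gives backpressure, like docs_chan = Channel(1)
    async with asyncio.TaskGroup() as grp:  # Python 3.11+
        grp.create_task(produce(chan))
        grp.create_task(consume(chan))

asyncio.run(main())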
+ continue + + for d in doc_ids["ids"]: + try: + res = await client.get( + f"{conn.restppUrl}/query/{conn.graphname}/StreamDocContent/", + params={"doc": d}, + headers=headers, + ) + if res.status_code != 200: + # continue to the next doc. + # This doc will not be marked as processed, so the ecc will process it eventually. + continue + logger.info("steam_docs writes to docs") + await docs_chan.put(res.json()["results"][0]["DocContent"][0]) + except Exception as e: + exc = traceback.format_exc() + logger.error(f"Error retrieving doc: {d} --> {e}\n{exc}") + continue # try retrieving the next doc + + logger.info("stream_docs done") + # close the docs chan -- this function is the only sender + logger.info("closing docs chan") + docs_chan.close() + + +async def chunk_docs( + conn: TigerGraphConnection, + docs_chan: Channel, + embed_chan: Channel, + upsert_chan: Channel, + extract_chan: Channel, +): + """ + Creates and starts one worker for each document + in the docs channel. + """ + logger.info("Reading from docs channel") + doc_tasks = [] + async with asyncio.TaskGroup() as grp: + async for content in docs_chan: + v_id = content["v_id"] + txt = content["attributes"]["text"] + # send the document to be embedded + logger.info("chunk writes to extract") + await embed_chan.put((v_id, txt, "Document")) + + task = grp.create_task( + workers.chunk_doc(conn, content, upsert_chan, embed_chan, extract_chan) + ) + doc_tasks.append(task) + + logger.info("chunk_docs done") + + # close the extract chan -- chunk_doc is the only sender + # and chunk_doc calls are kicked off from here + logger.info("closing extract_chan") + extract_chan.close() + + +async def upsert(upsert_chan: Channel): + """ + Creates and starts one worker for each upsert job + chan expects: + (func, args) <- q.get() + """ + + logger.info("Reading from upsert channel") + # consume task queue + async with asyncio.TaskGroup() as grp: + async for func, args in upsert_chan: + logger.info(f"{func.__name__}, {args[1]}") + # execute the task + grp.create_task(func(*args)) + + logger.info(f"upsert done") + + +async def embed( + embed_chan: Channel, index_stores: dict[str, MilvusEmbeddingStore], graphname: str +): + """ + Creates and starts one worker for each embed job + chan expects: + (v_id, content, index_name) <- q.get() + """ + logger.info("Reading from embed channel") + async with asyncio.TaskGroup() as grp: + # consume task queue + async for v_id, content, index_name in embed_chan: + embedding_store = index_stores[f"{graphname}_{index_name}"] + logger.info(f"Embed to {graphname}_{index_name}: {v_id}") + grp.create_task( + workers.embed( + embedding_service, + embedding_store, + v_id, + content, + ) + ) + + logger.info(f"embed done") + + +async def extract( + extract_chan: Channel, + upsert_chan: Channel, + embed_chan: Channel, + extractor: BaseExtractor, + conn: TigerGraphConnection, +): + """ + Creates and starts one worker for each extract job + chan expects: + (chunk , chunk_id) <- q.get() + """ + logger.info("Reading from extract channel") + # consume task queue + async with asyncio.TaskGroup() as grp: + async for item in extract_chan: + grp.create_task( + workers.extract(upsert_chan, embed_chan, extractor, conn, *item) + ) + + logger.info(f"extract done") + + logger.info("closing upsert and embed chan") + upsert_chan.close() + embed_chan.close() + + +async def stream_entities( + conn: TigerGraphConnection, + entity_chan: Channel, + ttl_batches: int = 50, +): + """ + Streams entity IDs from the grpah + """ + logger.info("streaming 
entities") + for i in range(ttl_batches): + ids = await stream_ids(conn, "Entity", i, ttl_batches) + if ids["error"]: + # continue to the next batch. + # These docs will not be marked as processed, so the ecc will process it eventually. + continue + + for i in ids["ids"]: + if len(i) > 0: + await entity_chan.put(i) + + logger.info("stream_enities done") + # close the docs chan -- this function is the only sender + logger.info("closing entities chan") + entity_chan.close() + + +async def resolve_entities( + conn: TigerGraphConnection, + emb_store: MilvusEmbeddingStore, + entity_chan: Channel, + upsert_chan: Channel, +): + """ + Merges entities into their ResolvedEntity form + Groups what should be the same entity into a resolved entity (e.g. V_type and VType should be merged) + + Copies edges between entities to their respective ResolvedEntities + """ + async with asyncio.TaskGroup() as grp: + # for every entity + async for entity_id in entity_chan: + grp.create_task( + workers.resolve_entity(conn, upsert_chan, emb_store, entity_id) + ) + logger.info("closing upsert_chan") + upsert_chan.close() + + # Copy RELATIONSHIP edges to RESOLVED_RELATIONSHIP + headers = make_headers(conn) + async with httpx.AsyncClient(timeout=http_timeout) as client: + res = await client.get( + f"{conn.restppUrl}/query/{conn.graphname}/ResolveRelationships/", + headers=headers, + ) + res.raise_for_status() + + +async def communities(conn: TigerGraphConnection, comm_process_chan: Channel): + """ + Run louvain + """ + # first pass: Group ResolvedEntities into Communities + logger.info("Initializing Communities (first louvain pass)") + headers = make_headers(conn) + async with httpx.AsyncClient(timeout=None) as client: + res = await client.get( + f"{conn.restppUrl}/query/{conn.graphname}/graphrag_louvain_init", + params={"n_batches": 1}, + headers=headers, + ) + res.raise_for_status() + # get the modularity + async with httpx.AsyncClient(timeout=None) as client: + res = await client.get( + f"{conn.restppUrl}/query/{conn.graphname}/modularity", + params={"iteration": 1, "batch_num": 1}, + headers=headers, + ) + res.raise_for_status() + mod = res.json()["results"][0]["mod"] + logger.info(f"****mod pass 1: {mod}") + await stream_communities(conn, 1, comm_process_chan) + + # nth pass: Iterate on Resolved Entities until modularity stops increasing + prev_mod = -10 + i = 0 + while abs(prev_mod - mod) > 0.0000001 and prev_mod != 0: + prev_mod = mod + i += 1 + logger.info(f"Running louvain on Communities (iteration: {i})") + # louvain pass + async with httpx.AsyncClient(timeout=None) as client: + res = await client.get( + f"{conn.restppUrl}/query/{conn.graphname}/graphrag_louvain_communities", + params={"n_batches": 1, "iteration": i}, + headers=headers, + ) + + res.raise_for_status() + + # get the modularity + async with httpx.AsyncClient(timeout=None) as client: + res = await client.get( + f"{conn.restppUrl}/query/{conn.graphname}/modularity", + params={"iteration": i + 1, "batch_num": 1}, + headers=headers, + ) + res.raise_for_status() + mod = res.json()["results"][0]["mod"] + logger.info(f"*** mod pass {i+1}: {mod} (diff= {abs(prev_mod - mod)})") + + # write iter to chan for layer to be processed + await stream_communities(conn, i + 1, comm_process_chan) + + # TODO: erase last run since it's ∆q to the run before it will be small + logger.info("closing communities chan") + comm_process_chan.close() + + +async def stream_communities( + conn: TigerGraphConnection, + i: int, + comm_process_chan: Channel, +): + """ + Streams 
Community IDs from the grpah for a given iteration (from the channel) + """ + logger.info("streaming communities") + + headers = make_headers(conn) + # TODO: + # can only do one layer at a time to ensure that every child community has their descriptions + + # async for i in community_chan: + # get the community from that layer + async with httpx.AsyncClient(timeout=None) as client: + resp = await client.get( + f"{conn.restppUrl}/query/{conn.graphname}/stream_community", + params={"iter": i}, + headers=headers, + ) + resp.raise_for_status() + comms = resp.json()["results"][0]["Comms"] + + for c in comms: + await comm_process_chan.put((i, c["v_id"])) + + # Wait for all communities for layer i to be processed before doing next layer + # all community descriptions must be populated before the next layer can be processed + if len(comms) > 0: + while not await check_vertex_has_desc(conn, i): + logger.info(f"Waiting for layer{i} to finish processing") + await asyncio.sleep(5) + await asyncio.sleep(3) + + logger.info("stream_communities done") + logger.info("closing comm_process_chan") + + +async def summarize_communities( + conn: TigerGraphConnection, + comm_process_chan: Channel, + upsert_chan: Channel, + embed_chan: Channel, +): + async with asyncio.TaskGroup() as tg: + async for c in comm_process_chan: + tg.create_task(workers.process_community(conn, upsert_chan, embed_chan, *c)) + + logger.info("closing upsert_chan") + upsert_chan.close() + embed_chan.close() + + +async def run(graphname: str, conn: TigerGraphConnection): + """ + Set up GraphRAG: + - Install necessary queries. + - Process the documents into: + - chunks + - embeddings + - entities/relationships (and their embeddings) + - upsert everything to the graph + - Resolve Entities + Ex: "Vincent van Gogh" and "van Gogh" should be resolved to "Vincent van Gogh" + """ + + extractor, index_stores = await init(conn) + init_start = time.perf_counter() + + doc_process_switch = True + entity_resolution_switch = True + community_detection_switch = True + if doc_process_switch: + logger.info("Doc Processing Start") + docs_chan = Channel(1) + embed_chan = Channel(100) + upsert_chan = Channel(100) + extract_chan = Channel(100) + async with asyncio.TaskGroup() as grp: + # get docs + grp.create_task(stream_docs(conn, docs_chan, 10)) + # process docs + grp.create_task( + chunk_docs(conn, docs_chan, embed_chan, upsert_chan, extract_chan) + ) + # upsert chunks + grp.create_task(upsert(upsert_chan)) + # embed + grp.create_task(embed(embed_chan, index_stores, graphname)) + # extract entities + grp.create_task( + extract(extract_chan, upsert_chan, embed_chan, extractor, conn) + ) + init_end = time.perf_counter() + logger.info("Doc Processing End") + + # Entity Resolution + entity_start = time.perf_counter() + + if entity_resolution_switch: + logger.info("Entity Processing Start") + entities_chan = Channel(100) + upsert_chan = Channel(100) + async with asyncio.TaskGroup() as grp: + grp.create_task(stream_entities(conn, entities_chan, 50)) + grp.create_task( + resolve_entities( + conn, + index_stores[f"{conn.graphname}_Entity"], + entities_chan, + upsert_chan, + ) + ) + grp.create_task(upsert(upsert_chan)) + entity_end = time.perf_counter() + logger.info("Entity Processing End") + + # Community Detection + community_start = time.perf_counter() + if community_detection_switch: + logger.info("Community Processing Start") + upsert_chan = Channel(10) + comm_process_chan = Channel(100) + upsert_chan = Channel(100) + embed_chan = Channel(100) + async with 
asyncio.TaskGroup() as grp: + # run louvain + # grp.create_task(communities(conn, communities_chan)) + grp.create_task(communities(conn, comm_process_chan)) + # get the communities + # grp.create_task( stream_communities(conn, communities_chan, comm_process_chan)) + # summarize each community + grp.create_task( + summarize_communities(conn, comm_process_chan, upsert_chan, embed_chan) + ) + grp.create_task(upsert(upsert_chan)) + grp.create_task(embed(embed_chan, index_stores, graphname)) + + community_end = time.perf_counter() + logger.info("Community Processing End") + + # Community Summarization + end = time.perf_counter() + logger.info(f"DONE. graphrag system initializer dT: {init_end-init_start}") + logger.info(f"DONE. graphrag entity resolution dT: {entity_end-entity_start}") + logger.info(f"DONE. graphrag community initializer dT: {community_end-community_start}") + logger.info(f"DONE. graphrag.run() total time elaplsed: {end-init_start}") diff --git a/eventual-consistency-service/app/graphrag/util.py b/eventual-consistency-service/app/graphrag/util.py new file mode 100644 index 00000000..186ab11a --- /dev/null +++ b/eventual-consistency-service/app/graphrag/util.py @@ -0,0 +1,293 @@ +import asyncio +import base64 +import json +import logging +import re +import traceback +from glob import glob + +import httpx +from graphrag import workers +from pyTigerGraph import TigerGraphConnection + +from common.config import ( + doc_processing_config, + embedding_service, + get_llm_service, + llm_config, + milvus_config, +) +from common.embeddings.milvus_embedding_store import MilvusEmbeddingStore +from common.extractors import GraphExtractor, LLMEntityRelationshipExtractor +from common.extractors.BaseExtractor import BaseExtractor +from common.logs.logwriter import LogWriter + +logger = logging.getLogger(__name__) +http_timeout = httpx.Timeout(15.0) + + +async def install_queries( + requried_queries: list[str], + conn: TigerGraphConnection, +): + # queries that are currently installed + installed_queries = [q.split("/")[-1] for q in conn.getEndpoints(dynamic=True)] + + # doesn't need to be parallel since tg only does it one at a time + for q in requried_queries: + # only install n queries at a time (n=n_workers) + q_name = q.split("/")[-1] + # if the query is not installed, install it + if q_name not in installed_queries: + res = await workers.install_query(conn, q) + # stop system if a required query doesn't install + if res["error"]: + raise Exception(res["message"]) + + +async def init_embedding_index(s: MilvusEmbeddingStore, vertex_field: str): + content = "init" + vec = embedding_service.embed_query(content) + await s.aadd_embeddings([(content, vec)], [{vertex_field: content}]) + s.remove_embeddings(expr=f"{vertex_field} in ['{content}']") + + +async def init( + conn: TigerGraphConnection, +) -> tuple[BaseExtractor, dict[str, MilvusEmbeddingStore]]: + # install requried queries + requried_queries = [ + # "common/gsql/supportai/Scan_For_Updates", + # "common/gsql/supportai/Update_Vertices_Processing_Status", + # "common/gsql/supportai/ECC_Status", + # "common/gsql/supportai/Check_Nonexistent_Vertices", + "common/gsql/graphRAG/StreamIds", + "common/gsql/graphRAG/StreamDocContent", + "common/gsql/graphRAG/SetEpochProcessing", + "common/gsql/graphRAG/ResolveRelationships", + "common/gsql/graphRAG/get_community_children", + "common/gsql/graphRAG/communities_have_desc", + "common/gsql/graphRAG/louvain/graphrag_louvain_init", + "common/gsql/graphRAG/louvain/graphrag_louvain_communities", + 
"common/gsql/graphRAG/louvain/modularity", + "common/gsql/graphRAG/louvain/stream_community", + ] + # add louvain to queries + q = [x.split(".gsql")[0] for x in glob("common/gsql/graphRAG/louvain/*")] + requried_queries.extend(q) + await install_queries(requried_queries, conn) + + # extractor + if doc_processing_config.get("extractor") == "graphrag": + extractor = GraphExtractor() + elif doc_processing_config.get("extractor") == "llm": + extractor = LLMEntityRelationshipExtractor(get_llm_service(llm_config)) + else: + raise ValueError("Invalid extractor type") + vertex_field = milvus_config.get("vertex_field", "vertex_id") + index_names = milvus_config.get( + "indexes", + [ + "Document", + "DocumentChunk", + "Entity", + "Relationship", + # "Concept", + "Community", + ], + ) + index_stores = {} + async with asyncio.TaskGroup() as tg: + for index_name in index_names: + name = conn.graphname + "_" + index_name + s = MilvusEmbeddingStore( + embedding_service, + host=milvus_config["host"], + port=milvus_config["port"], + support_ai_instance=True, + collection_name=name, + username=milvus_config.get("username", ""), + password=milvus_config.get("password", ""), + vector_field=milvus_config.get("vector_field", "document_vector"), + text_field=milvus_config.get("text_field", "document_content"), + vertex_field=vertex_field, + drop_old=False, + ) + + LogWriter.info(f"Initializing {name}") + # init collection if it doesn't exist + if not s.check_collection_exists(): + tg.create_task(init_embedding_index(s, vertex_field)) + + index_stores[name] = s + + return extractor, index_stores + + +def make_headers(conn: TigerGraphConnection): + if conn.apiToken is None or conn.apiToken == "": + tkn = base64.b64encode(f"{conn.username}:{conn.password}".encode()).decode() + headers = {"Authorization": f"Basic {tkn}"} + else: + headers = {"Authorization": f"Bearer {conn.apiToken}"} + + return headers + + +async def stream_ids( + conn: TigerGraphConnection, v_type: str, current_batch: int, ttl_batches: int +) -> dict[str, str | list[str]]: + headers = make_headers(conn) + + try: + async with httpx.AsyncClient(timeout=http_timeout) as client: + res = await client.post( + f"{conn.restppUrl}/query/{conn.graphname}/StreamIds", + params={ + "current_batch": current_batch, + "ttl_batches": ttl_batches, + "v_type": v_type, + }, + headers=headers, + ) + ids = res.json()["results"][0]["@@ids"] + return {"error": False, "ids": ids} + + except Exception as e: + exc = traceback.format_exc() + LogWriter.error(f"/{conn.graphname}/query/StreamIds\nException Trace:\n{exc}") + + return {"error": True, "message": str(e)} + + +def map_attrs(attributes: dict): + # map attrs + attrs = {} + for k, v in attributes.items(): + if isinstance(v, tuple): + attrs[k] = {"value": v[0], "op": v[1]} + elif isinstance(v, dict): + attrs[k] = { + "value": {"keylist": list(v.keys()), "valuelist": list(v.values())} + } + else: + attrs[k] = {"value": v} + return attrs + + +def process_id(v_id: str): + v_id = v_id.replace(" ", "_").replace("/", "") + + has_func = re.compile(r"(.*)\(").findall(v_id) + if len(has_func) > 0: + v_id = has_func[0] + if v_id == "''" or v_id == '""': + return "" + + return v_id + + +async def upsert_vertex( + conn: TigerGraphConnection, + vertex_type: str, + vertex_id: str, + attributes: dict, +): + logger.info(f"Upsert vertex: {vertex_type} {vertex_id}") + vertex_id = vertex_id.replace(" ", "_") + attrs = map_attrs(attributes) + data = json.dumps({"vertices": {vertex_type: {vertex_id: attrs}}}) + headers = make_headers(conn) 
+ async with httpx.AsyncClient(timeout=http_timeout) as client: + res = await client.post( + f"{conn.restppUrl}/graph/{conn.graphname}", data=data, headers=headers + ) + + res.raise_for_status() + + +async def check_vertex_exists(conn, v_id: str): + headers = make_headers(conn) + async with httpx.AsyncClient(timeout=http_timeout) as client: + res = await client.get( + f"{conn.restppUrl}/graph/{conn.graphname}/vertices/Entity/{v_id}", + headers=headers, + ) + + res.raise_for_status() + return res.json() + + +async def upsert_edge( + conn: TigerGraphConnection, + src_v_type: str, + src_v_id: str, + edge_type: str, + tgt_v_type: str, + tgt_v_id: str, + attributes: dict = None, +): + if attributes is None: + attrs = {} + else: + attrs = map_attrs(attributes) + src_v_id = src_v_id.replace(" ", "_") + tgt_v_id = tgt_v_id.replace(" ", "_") + data = json.dumps( + { + "edges": { + src_v_type: { + src_v_id: { + edge_type: { + tgt_v_type: { + tgt_v_id: attrs, + } + } + }, + } + } + } + ) + headers = make_headers(conn) + async with httpx.AsyncClient(timeout=http_timeout) as client: + res = await client.post( + f"{conn.restppUrl}/graph/{conn.graphname}", data=data, headers=headers + ) + res.raise_for_status() + + +async def get_commuinty_children(conn, i: int, c: str): + headers = make_headers(conn) + async with httpx.AsyncClient(timeout=None) as client: + resp = await client.get( + f"{conn.restppUrl}/query/{conn.graphname}/get_community_children", + params={"comm": c, "iter": i}, + headers=headers, + ) + resp.raise_for_status() + descrs = [] + for d in resp.json()["results"][0]["children"]: + desc = d["attributes"]["description"] + if i == 1 and all(len(x) == 0 for x in desc): + desc = [d["v_id"]] + elif len(desc) == 0: + desc = d["v_id"] + + descrs.append(desc) + + return descrs + + +async def check_vertex_has_desc(conn, i: int): + headers = make_headers(conn) + async with httpx.AsyncClient(timeout=None) as client: + resp = await client.get( + f"{conn.restppUrl}/query/{conn.graphname}/communities_have_desc", + params={"iter": i}, + headers=headers, + ) + resp.raise_for_status() + + res = resp.json()["results"][0]["all_have_desc"] + + return res diff --git a/eventual-consistency-service/app/graphrag/workers.py b/eventual-consistency-service/app/graphrag/workers.py new file mode 100644 index 00000000..755b1085 --- /dev/null +++ b/eventual-consistency-service/app/graphrag/workers.py @@ -0,0 +1,392 @@ +import base64 +import logging +import time +from urllib.parse import quote_plus + +import ecc_util +import httpx +from aiochannel import Channel +from graphrag import community_summarizer, util +from langchain_community.graphs.graph_document import GraphDocument, Node +from pyTigerGraph import TigerGraphConnection + +from common.config import milvus_config +from common.embeddings.embedding_services import EmbeddingModel +from common.embeddings.milvus_embedding_store import MilvusEmbeddingStore +from common.extractors.BaseExtractor import BaseExtractor +from common.logs.logwriter import LogWriter + +vertex_field = milvus_config.get("vertex_field", "vertex_id") + +logger = logging.getLogger(__name__) + + +async def install_query( + conn: TigerGraphConnection, query_path: str +) -> dict[str, httpx.Response | str | None]: + LogWriter.info(f"Installing query {query_path}") + with open(f"{query_path}.gsql", "r") as f: + query = f.read() + + query_name = query_path.split("/")[-1] + query = f"""\ +USE GRAPH {conn.graphname} +{query} +INSTALL QUERY {query_name}""" + tkn = 
base64.b64encode(f"{conn.username}:{conn.password}".encode()).decode() + headers = {"Authorization": f"Basic {tkn}"} + + async with httpx.AsyncClient(timeout=None) as client: + res = await client.post( + conn.gsUrl + "/gsqlserver/gsql/file", + data=quote_plus(query.encode("utf-8")), + headers=headers, + ) + + if "error" in res.text.lower(): + LogWriter.error(res.text) + return { + "result": None, + "error": True, + "message": f"Failed to install query {query_name}", + } + + return {"result": res, "error": False} + + +async def chunk_doc( + conn: TigerGraphConnection, + doc: dict[str, str], + upsert_chan: Channel, + embed_chan: Channel, + extract_chan: Channel, +): + """ + Chunks a document. + Places the resulting chunks into the upsert channel (to be upserted to TG) + and the embed channel (to be embedded and written to the vector store) + """ + chunker = ecc_util.get_chunker() + chunks = chunker.chunk(doc["attributes"]["text"]) + v_id = util.process_id(doc["v_id"]) + logger.info(f"Chunking {v_id}") + for i, chunk in enumerate(chunks): + chunk_id = f"{v_id}_chunk_{i}" + # send chunks to be upserted (func, args) + logger.info("chunk writes to upsert_chan") + await upsert_chan.put((upsert_chunk, (conn, v_id, chunk_id, chunk))) + + # send chunks to be embedded + logger.info("chunk writes to embed_chan") + await embed_chan.put((v_id, chunk, "DocumentChunk")) + + # send chunks to have entities extracted + logger.info("chunk writes to extract_chan") + await extract_chan.put((chunk, chunk_id)) + + return doc["v_id"] + + +async def upsert_chunk(conn: TigerGraphConnection, doc_id, chunk_id, chunk): + logger.info(f"Upserting chunk {chunk_id}") + date_added = int(time.time()) + await util.upsert_vertex( + conn, + "DocumentChunk", + chunk_id, + attributes={"epoch_added": date_added, "idx": int(chunk_id.split("_")[-1])}, + ) + await util.upsert_vertex( + conn, + "Content", + chunk_id, + attributes={"text": chunk, "epoch_added": date_added}, + ) + await util.upsert_edge( + conn, "DocumentChunk", chunk_id, "HAS_CONTENT", "Content", chunk_id + ) + await util.upsert_edge( + conn, "Document", doc_id, "HAS_CHILD", "DocumentChunk", chunk_id + ) + if int(chunk_id.split("_")[-1]) > 0: + await util.upsert_edge( + conn, + "DocumentChunk", + chunk_id, + "IS_AFTER", + "DocumentChunk", + doc_id + "_chunk_" + str(int(chunk_id.split("_")[-1]) - 1), + ) + + +async def embed( + embed_svc: EmbeddingModel, + embed_store: MilvusEmbeddingStore, + v_id: str, + content: str, +): + """ + Args: + graphname: str + the name of the graph the documents are in + embed_svc: EmbeddingModel + The class used to vectorize text + embed_store: + The class used to store the vectore to a vector DB + v_id: str + the vertex id that will be embedded + content: str + the content of the document/chunk + index_name: str + the vertex index to write to + """ + logger.info(f"Embedding {v_id}") + + vec = await embed_svc.aembed_query(content) + await embed_store.aadd_embeddings([(content, vec)], [{vertex_field: v_id}]) + + +async def get_vert_desc(conn, v_id, node: Node): + desc = [node.properties.get("description", "")] + exists = await util.check_vertex_exists(conn, v_id) + # if vertex exists, get description content and append this description to it + if not exists["error"]: + # deduplicate descriptions + desc.extend(exists["results"][0]["attributes"]["description"]) + desc = list(set(desc)) + return desc + + +async def extract( + upsert_chan: Channel, + embed_chan: Channel, + extractor: BaseExtractor, + conn: TigerGraphConnection, + chunk: str, + 
chunk_id: str, +): + logger.info(f"Extracting chunk: {chunk_id}") + extracted: list[GraphDocument] = await extractor.aextract(chunk) + # upsert nodes and edges to the graph + for doc in extracted: + for node in doc.nodes: + logger.info(f"extract writes entity vert to upsert\nNode: {node.id}") + v_id = util.process_id(str(node.id)) + if len(v_id) == 0: + continue + desc = await get_vert_desc(conn, v_id, node) + + # embed the entity + # embed with the v_id if the description is blank + if len(desc[0]): + await embed_chan.put((v_id, v_id, "Entity")) + else: + # (v_id, content, index_name) + await embed_chan.put((v_id, desc[0], "Entity")) + + await upsert_chan.put( + ( + util.upsert_vertex, # func to call + ( + conn, + "Entity", # v_type + v_id, # v_id + { # attrs + "description": desc, + "epoch_added": int(time.time()), + }, + ), + ) + ) + + # link the entity to the chunk it came from + logger.info("extract writes contains edge to upsert") + await upsert_chan.put( + ( + util.upsert_edge, + ( + conn, + "DocumentChunk", # src_type + chunk_id, # src_id + "CONTAINS_ENTITY", # edge_type + "Entity", # tgt_type + v_id, # tgt_id + None, # attributes + ), + ) + ) + + for edge in doc.relationships: + logger.info( + f"extract writes relates edge to upsert\n{edge.source.id} -({edge.type})-> {edge.target.id}" + ) + # upsert verts first to make sure their ID becomes an attr + v_id = util.process_id(edge.source.id) # src_id + if len(v_id) == 0: + continue + desc = await get_vert_desc(conn, v_id, edge.source) + await upsert_chan.put( + ( + util.upsert_vertex, # func to call + ( + conn, + "Entity", # v_type + v_id, + { # attrs + "description": desc, + "epoch_added": int(time.time()), + }, + ), + ) + ) + v_id = util.process_id(edge.target.id) + if len(v_id) == 0: + continue + desc = await get_vert_desc(conn, v_id, edge.target) + await upsert_chan.put( + ( + util.upsert_vertex, # func to call + ( + conn, + "Entity", # v_type + v_id, # src_id + { # attrs + "description": desc, + "epoch_added": int(time.time()), + }, + ), + ) + ) + + # upsert the edge between the two entities + await upsert_chan.put( + ( + util.upsert_edge, + ( + conn, + "Entity", # src_type + util.process_id(edge.source.id), # src_id + "RELATIONSHIP", # edgeType + "Entity", # tgt_type + util.process_id(edge.target.id), # tgt_id + {"relation_type": edge.type}, # attributes + ), + ) + ) + # embed "Relationship", + # (v_id, content, index_name) + + +async def resolve_entity( + conn: TigerGraphConnection, + upsert_chan: Channel, + emb_store: MilvusEmbeddingStore, + entity_id: str, +): + """ + get all vectors of E (one name can have multiple discriptions) + get ents close to E + for e in ents: + if e is 95% similar to E and edit_dist(E,e) <=3: + merge + mark e as processed + + mark as processed + """ + results = await emb_store.aget_k_closest(entity_id) + if len(results) == 0: + logger.error( + f"aget_k_closest should, minimally, return the entity itself.\n{results}" + ) + raise Exception() + + # merge all entities into the ResolvedEntity vertex + # use the longest v_id as the resolved entity's v_id + resolved_entity_id = entity_id + for v in results: + if len(v) > len(resolved_entity_id): + resolved_entity_id = v + + # upsert the resolved entity + await upsert_chan.put( + ( + util.upsert_vertex, # func to call + ( + conn, + "ResolvedEntity", # v_type + resolved_entity_id, # v_id + { # attrs + }, + ), + ) + ) + + # create RESOLVES_TO edges from each entity to the ResolvedEntity + for v in results: + await upsert_chan.put( + ( + util.upsert_edge, + ( 
+ conn, + "Entity", # src_type + v, # src_id + "RESOLVES_TO", # edge_type + "ResolvedEntity", # tgt_type + resolved_entity_id, # tgt_id + None, # attributes + ), + ) + ) + + +async def process_community( + conn: TigerGraphConnection, + upsert_chan: Channel, + embed_chan: Channel, + i: int, + comm_id: str, +): + """ + https://github.com/microsoft/graphrag/blob/main/graphrag/prompt_tune/template/community_report_summarization.py + + Get children verts (Entity for layer-1 Communities, Community otherwise) + if the commuinty only has one child, use its description -- no need to summarize + + embed summaries + """ + + logger.info(f"Processing Community: {comm_id}") + # get the children of the community + children = await util.get_commuinty_children(conn, i, comm_id) + if i == 1: + tmp = [] + for c in children: + tmp.extend(c) + children = list(filter(lambda x: len(x) > 0, tmp)) + comm_id = util.process_id(comm_id) + + # if the community only has one child, use its description + if len(children) == 1: + summary = children[0] + else: + llm = ecc_util.get_llm_service() + summarizer = community_summarizer.CommunitySummarizer(llm) + summary = await summarizer.summarize(comm_id, children) + + await upsert_chan.put( + ( + util.upsert_vertex, # func to call + ( + conn, + "Community", # v_type + comm_id, # v_id + { # attrs + "description": summary, + "iteration": i, + }, + ), + ) + ) + + # (v_id, content, index_name) + await embed_chan.put((comm_id, summary, "Community")) diff --git a/eventual-consistency-service/app/main.py b/eventual-consistency-service/app/main.py index 51843a04..2ccc10e2 100644 --- a/eventual-consistency-service/app/main.py +++ b/eventual-consistency-service/app/main.py @@ -1,54 +1,82 @@ +import os + +os.environ["ECC"] = "true" +import json import logging -from typing import Annotated +from contextlib import asynccontextmanager +from threading import Thread +from typing import Annotated, Callable -from fastapi import Depends, FastAPI, BackgroundTasks +import ecc_util +import graphrag +from eventual_consistency_checker import EventualConsistencyChecker +from fastapi import BackgroundTasks, Depends, FastAPI, Response, status from fastapi.security.http import HTTPBase from common.config import ( db_config, + doc_processing_config, embedding_service, get_llm_service, llm_config, milvus_config, security, - doc_processing_config, ) +from common.db.connections import elevate_db_connection_to_token from common.embeddings.milvus_embedding_store import MilvusEmbeddingStore from common.logs.logwriter import LogWriter from common.metrics.tg_proxy import TigerGraphConnectionProxy -from common.db.connections import elevate_db_connection_to_token -from eventual_consistency_checker import EventualConsistencyChecker -import json -from threading import Thread +from common.py_schemas.schemas import SupportAIMethod logger = logging.getLogger(__name__) consistency_checkers = {} -app = FastAPI() -@app.on_event("startup") -def startup_event(): - if not db_config.get("enable_consistency_checker", True): - LogWriter.info("Eventual consistency checker disabled") - return +@asynccontextmanager +async def lifespan(_: FastAPI): + if not db_config.get("enable_consistency_checker", False): + LogWriter.info("Eventual Consistency Checker not run on startup") + + else: + startup_checkers = db_config.get("graph_names", []) + for graphname in startup_checkers: + conn = elevate_db_connection_to_token( + db_config["hostname"], + db_config["username"], + db_config["password"], + graphname, + ) + 
start_ecc_in_thread(graphname, conn) + yield + LogWriter.info("ECC Shutdown") + + +app = FastAPI(lifespan=lifespan) - startup_checkers = db_config.get("graph_names", []) - for graphname in startup_checkers: - conn = elevate_db_connection_to_token(db_config["hostname"], db_config["username"], db_config["password"], graphname) - start_ecc_in_thread(graphname, conn) def start_ecc_in_thread(graphname: str, conn: TigerGraphConnectionProxy): - thread = Thread(target=initialize_eventual_consistency_checker, args=(graphname, conn), daemon=True) + thread = Thread( + target=initialize_eventual_consistency_checker, + args=(graphname, conn), + daemon=True, + ) thread.start() LogWriter.info(f"Eventual consistency checker started for graph {graphname}") -def initialize_eventual_consistency_checker(graphname: str, conn: TigerGraphConnectionProxy): + +def initialize_eventual_consistency_checker( + graphname: str, conn: TigerGraphConnectionProxy +): if graphname in consistency_checkers: return consistency_checkers[graphname] try: - process_interval_seconds = milvus_config.get("process_interval_seconds", 1800) # default 30 minutes - cleanup_interval_seconds = milvus_config.get("cleanup_interval_seconds", 86400) # default 30 days, + process_interval_seconds = milvus_config.get( + "process_interval_seconds", 1800 + ) # default 30 minutes + cleanup_interval_seconds = milvus_config.get( + "cleanup_interval_seconds", 86400 + ) # default 30 days, batch_size = milvus_config.get("batch_size", 10) vector_indices = {} vertex_field = None @@ -71,38 +99,10 @@ def initialize_eventual_consistency_checker(graphname: str, conn: TigerGraphConn vector_field=milvus_config.get("vector_field", "document_vector"), text_field=milvus_config.get("text_field", "document_content"), vertex_field=vertex_field, - alias=milvus_config.get("alias", "default") + alias=milvus_config.get("alias", "default"), ) - if doc_processing_config.get("chunker") == "semantic": - from common.chunkers.semantic_chunker import SemanticChunker - - chunker = SemanticChunker( - embedding_service, - doc_processing_config["chunker_config"].get("method", "percentile"), - doc_processing_config["chunker_config"].get("threshold", 0.95), - ) - elif doc_processing_config.get("chunker") == "regex": - from common.chunkers.regex_chunker import RegexChunker - - chunker = RegexChunker( - pattern=doc_processing_config["chunker_config"].get( - "pattern", "\\r?\\n" - ) - ) - elif doc_processing_config.get("chunker") == "character": - from common.chunkers.character_chunker import CharacterChunker - - chunker = CharacterChunker( - chunk_size=doc_processing_config["chunker_config"].get( - "chunk_size", 1024 - ), - overlap_size=doc_processing_config["chunker_config"].get( - "overlap_size", 0 - ), - ) - else: - raise ValueError("Invalid chunker type") + chunker = ecc_util.get_chunker() if doc_processing_config.get("extractor") == "llm": from common.extractors import LLMEntityRelationshipExtractor @@ -112,7 +112,9 @@ def initialize_eventual_consistency_checker(graphname: str, conn: TigerGraphConn raise ValueError("Invalid extractor type") if vertex_field is None: - raise ValueError("vertex_field is not defined. Ensure Milvus is enabled in the configuration.") + raise ValueError( + "vertex_field is not defined. Ensure Milvus is enabled in the configuration." 
+ ) checker = EventualConsistencyChecker( process_interval_seconds, @@ -125,7 +127,7 @@ def initialize_eventual_consistency_checker(graphname: str, conn: TigerGraphConn conn, chunker, extractor, - batch_size + batch_size, ) consistency_checkers[graphname] = checker @@ -139,22 +141,61 @@ def initialize_eventual_consistency_checker(graphname: str, conn: TigerGraphConn return checker except Exception as e: - LogWriter.error(f"Failed to start eventual consistency checker for graph {graphname}: {e}") + LogWriter.error( + f"Failed to start eventual consistency checker for graph {graphname}: {e}" + ) + + +def start_func_in_thread(f: Callable, *args, **kwargs): + thread = Thread( + target=f, + args=args, + kwargs=kwargs, + daemon=True, + ) + thread.start() + LogWriter.info(f'Thread started for function: "{f.__name__}"') + @app.get("/") def root(): LogWriter.info(f"Healthcheck") return {"status": "ok"} -@app.get("/{graphname}/consistency_status") -def consistency_status(graphname: str, credentials: Annotated[HTTPBase, Depends(security)]): - if graphname in consistency_checkers: - ecc = consistency_checkers[graphname] - status = json.dumps(ecc.get_status()) - else: - conn = elevate_db_connection_to_token(db_config["hostname"], credentials.username, credentials.password, graphname) - start_ecc_in_thread(graphname, conn) - status = f"Eventual consistency checker started for graph {graphname}" - LogWriter.info(f"Returning consistency status for {graphname}: {status}") - return status +@app.get("/{graphname}/consistency_status/{ecc_method}") +def consistency_status( + graphname: str, + ecc_method: str, + background: BackgroundTasks, + credentials: Annotated[HTTPBase, Depends(security)], + response: Response, +): + conn = elevate_db_connection_to_token( + db_config["hostname"], + credentials.username, + credentials.password, + graphname, + ) + match ecc_method: + case SupportAIMethod.SUPPORTAI: + if graphname in consistency_checkers: + ecc = consistency_checkers[graphname] + ecc_status = json.dumps(ecc.get_status()) + else: + start_ecc_in_thread(graphname, conn) + ecc_status = ( + f"Eventual consistency checker started for graph {graphname}" + ) + + LogWriter.info(f"Returning consistency status for {graphname}: {ecc_status}") + case SupportAIMethod.GRAPHRAG: + background.add_task(graphrag.run, graphname, conn) + import time + + ecc_status = f"GraphRAG initialization on {conn.graphname} {time.ctime()}" + case _: + response.status_code = status.HTTP_404_NOT_FOUND + return f"Method unsupported, must be {SupportAIMethod.SUPPORTAI} or {SupportAIMethod.GRAPHRAG}" + + return ecc_status
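Taken together, the new routes let a client kick off GraphRAG end to end: CoPilot's /{graphname}/{method}/forceupdate forwards to the ECC service's /{graphname}/consistency_status/{method}, which schedules graphrag.run() as a background task. A rough usage sketch; the hosts, ports, credentials, and graph name are assumptions, and the method segment assumes SupportAIMethod.GRAPHRAG serializes to "graphrag":

import httpx

auth = ("tigergraph", "tigergraph")  # hypothetical credentials

# via the CoPilot service
r = httpx.get("http://localhost:8000/MyGraph/graphrag/forceupdate", auth=auth)
print(r.status_code, r.text)

# or directly against the eventual-consistency-service
r = httpx.get("http://localhost:8001/MyGraph/consistency_status/graphrag", auth=auth)
print(r.text)  # e.g. "GraphRAG initialization on MyGraph <timestamp>"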