Merge pull request #269 from tigergraph/GML-1860-graphrag-load-tuning
stability improvements
parkererickson-tg authored Aug 29, 2024
2 parents 415eec9 + d1197bc commit 81b0052
Showing 6 changed files with 297 additions and 218 deletions.
2 changes: 1 addition & 1 deletion common/config.py
@@ -187,7 +187,7 @@ def get_llm_service(llm_config) -> LLM_Model:
 ):
     doc_processing_config = {
         "chunker": "semantic",
-        "chunker_config": {"method": "percentile", "threshold": 0.95},
+        "chunker_config": {"method": "percentile", "threshold": 0.90},
         "extractor": "graphrag",
         "extractor_config": {},
     }
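The only functional change here lowers the semantic chunker's split threshold from the 95th to the 90th percentile, which should produce more split points and therefore smaller chunks (and smaller downstream upsert payloads). For context, a minimal sketch of percentile-based semantic chunking — this illustrates the general technique, not the repo's actual chunker, and `embed` is a placeholder for any sentence-embedding function:

import numpy as np

def semantic_chunks(sentences, embed, threshold=0.90):
    # Embed each sentence; `embed` is a placeholder for any sentence encoder.
    vecs = np.array([embed(s) for s in sentences])
    # Cosine distance between each adjacent sentence pair.
    sims = np.sum(vecs[:-1] * vecs[1:], axis=1) / (
        np.linalg.norm(vecs[:-1], axis=1) * np.linalg.norm(vecs[1:], axis=1)
    )
    dists = 1.0 - sims
    # Split wherever the distance exceeds the Nth-percentile cutoff;
    # 0.90 yields a lower cutoff than 0.95, hence more (smaller) chunks.
    cutoff = np.percentile(dists, threshold * 100)
    chunks, current = [], [sentences[0]]
    for sentence, dist in zip(sentences[1:], dists):
        if dist > cutoff:
            chunks.append(" ".join(current))
            current = []
        current.append(sentence)
    chunks.append(" ".join(current))
    return chunks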
@@ -1,4 +1,4 @@
-CREATE DISTRIBUTED QUERY GraphRAG_CommunityRetriever(INT community_level=2) FOR GRAPH pyTigerGraphRAG {
+CREATE DISTRIBUTED QUERY GraphRAG_Community_Retriever(INT community_level=2) {
   comms = {Community.*};
 
   selected_comms = SELECT c FROM comms:c WHERE c.iteration == community_level;
4 changes: 2 additions & 2 deletions copilot/app/supportai/retrievers/GraphRAG.py
@@ -40,10 +40,10 @@ def __init__(
         connection: TigerGraphConnectionProxy,
     ):
         super().__init__(embedding_service, embedding_store, llm_service, connection)
-        self._check_query_install("GraphRAG_CommunityRetriever")
+        self._check_query_install("GraphRAG_Community_Retriever")
 
     def search(self, question, community_level: int):
-        res = self.conn.runInstalledQuery("GraphRAG_CommunityRetriever", {"community_level": community_level}, usePost=True)
+        res = self.conn.runInstalledQuery("GraphRAG_Community_Retriever", {"community_level": community_level}, usePost=True)
         return res
 
     async def _generate_candidate(self, question, context):
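For reference, the renamed query can also be invoked directly through pyTigerGraph — a usage sketch in which the host and graph name are placeholders, not values from this repo:

from pyTigerGraph import TigerGraphConnection

# Host, graph name, and credentials are placeholders.
conn = TigerGraphConnection(host="https://your-instance", graphname="YourGraph")
res = conn.runInstalledQuery(
    "GraphRAG_Community_Retriever",
    {"community_level": 2},  # same parameter the retriever passes
    usePost=True,
)
print(res)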
27 changes: 16 additions & 11 deletions eventual-consistency-service/app/graphrag/graph_rag.py
@@ -14,6 +14,7 @@
     http_timeout,
     init,
     load_q,
+    loading_event,
     make_headers,
     stream_ids,
     tg_sem,
@@ -124,7 +125,7 @@ async def upsert(upsert_chan: Channel):
 async def load(conn: TigerGraphConnection):
     logger.info("Reading from load_q")
     dd = lambda: defaultdict(dd)  # infinite default dict
-    batch_size = 1000
+    batch_size = 500
     # while the load q is still open or has contents
     while not load_q.closed() or not load_q.empty():
         if load_q.closed():
@@ -169,7 +170,11 @@ async def load(conn: TigerGraphConnection):
             logger.info(
                 f"Upserting batch size of {size}. ({n_verts} verts | {n_edges} edges. {len(data.encode())/1000:,} kb)"
             )
+
+            loading_event.clear()
             await upsert_batch(conn, data)
+            await asyncio.sleep(5)
+            loading_event.set()
         else:
             await asyncio.sleep(1)
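The clear/sleep/set sequence is one half of a hand-off: load() closes the gate while a batch is upserting, then waits five seconds before reopening it. The other half — workers blocking on the event — isn't visible in this hunk; a sketch of what that side presumably looks like (the worker body here is illustrative, not the repo's code):

import asyncio

loading_event = asyncio.Event()
loading_event.set()  # gate open: workers may proceed

async def process(job):  # stand-in for the real per-item work
    await asyncio.sleep(0)

async def worker(jobs: asyncio.Queue):
    while True:
        job = await jobs.get()
        # Block whenever load() has cleared the event, so no new TigerGraph
        # requests start while a bulk upsert is in flight.
        await loading_event.wait()
        await process(job)
        jobs.task_done()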

@@ -435,12 +440,12 @@ async def run(graphname: str, conn: TigerGraphConnection):
     if doc_process_switch:
         logger.info("Doc Processing Start")
         docs_chan = Channel(1)
-        embed_chan = Channel(100)
-        upsert_chan = Channel(100)
-        extract_chan = Channel(100)
+        embed_chan = Channel()
+        upsert_chan = Channel()
+        extract_chan = Channel()
         async with asyncio.TaskGroup() as grp:
             # get docs
-            grp.create_task(stream_docs(conn, docs_chan, 10))
+            grp.create_task(stream_docs(conn, docs_chan, 100))
             # process docs
             grp.create_task(
                 chunk_docs(conn, docs_chan, embed_chan, upsert_chan, extract_chan)
@@ -462,8 +467,8 @@ async def run(graphname: str, conn: TigerGraphConnection):

     if entity_resolution_switch:
         logger.info("Entity Processing Start")
-        entities_chan = Channel(100)
-        upsert_chan = Channel(100)
+        entities_chan = Channel()
+        upsert_chan = Channel()
         load_q.reopen()
         async with asyncio.TaskGroup() as grp:
             grp.create_task(stream_entities(conn, entities_chan, 50))
@@ -487,10 +492,10 @@ async def run(graphname: str, conn: TigerGraphConnection):
     community_start = time.perf_counter()
     if community_detection_switch:
         logger.info("Community Processing Start")
-        upsert_chan = Channel(10)
-        comm_process_chan = Channel(100)
-        upsert_chan = Channel(100)
-        embed_chan = Channel(100)
+        upsert_chan = Channel()
+        comm_process_chan = Channel()
+        upsert_chan = Channel()
+        embed_chan = Channel()
         load_q.reopen()
         async with asyncio.TaskGroup() as grp:
             # run louvain
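Every inter-stage channel except docs_chan loses its explicit capacity in this change. Assuming Channel follows asyncio.Queue semantics (its implementation lives in reusable_channel, which this diff doesn't show), no capacity means unbounded, so producers stop blocking on full queues and back-pressure shifts to tg_sem and loading_event instead:

import asyncio

q_bounded = asyncio.Queue(maxsize=100)  # put() suspends once 100 items are queued
q_unbounded = asyncio.Queue()           # maxsize=0 means unlimited; put() never blocks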
35 changes: 24 additions & 11 deletions eventual-consistency-service/app/graphrag/util.py
@@ -6,6 +6,9 @@
 from glob import glob
 
 import httpx
+from graphrag import reusable_channel, workers
+from pyTigerGraph import TigerGraphConnection
+
 from common.config import (
     doc_processing_config,
     embedding_service,
@@ -17,15 +20,17 @@
 from common.extractors import GraphExtractor, LLMEntityRelationshipExtractor
 from common.extractors.BaseExtractor import BaseExtractor
 from common.logs.logwriter import LogWriter
-from graphrag import reusable_channel, workers
-from pyTigerGraph import TigerGraphConnection
 
 logger = logging.getLogger(__name__)
 http_timeout = httpx.Timeout(15.0)
 
-tg_sem = asyncio.Semaphore(20)
+tg_sem = asyncio.Semaphore(2)
 load_q = reusable_channel.ReuseableChannel()
 
+# will pause workers until the event is false
+loading_event = asyncio.Event()
+loading_event.set()  # set the event to true to allow the workers to run
+
 async def install_queries(
     requried_queries: list[str],
     conn: TigerGraphConnection,
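Two throttles now govern traffic to TigerGraph: tg_sem, cut from 20 to 2 concurrent requests, and the new loading_event gate. A hedged sketch of how a request helper might compose them (illustrative only; the repo's actual request helpers elsewhere in util.py and workers.py are not shown in this diff):

import asyncio
import httpx

tg_sem = asyncio.Semaphore(2)  # at most 2 in-flight RESTPP calls
loading_event = asyncio.Event()
loading_event.set()

async def tg_get(url: str, headers: dict) -> httpx.Response:
    await loading_event.wait()  # hold off while a bulk upsert is running
    async with tg_sem:          # then respect the concurrency cap
        async with httpx.AsyncClient(timeout=httpx.Timeout(15.0)) as client:
            return await client.get(url, headers=headers)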
@@ -207,7 +212,6 @@ async def upsert_batch(conn: TigerGraphConnection, data: str):
             res.raise_for_status()
 
 
-
 async def check_vertex_exists(conn, v_id: str):
     headers = make_headers(conn)
     async with httpx.AsyncClient(timeout=http_timeout) as client:
@@ -219,7 +223,8 @@ async def check_vertex_exists(conn, v_id: str):
             )
 
         except Exception as e:
-            logger.error(f"Check err:\n{e}")
+            err = traceback.format_exc()
+            logger.error(f"Check err:\n{err}")
             return {"error": True}
 
     try:
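Switching from {e} to traceback.format_exc() logs the full stack trace instead of just the exception's message, which for some exception types is an empty string (this assumes traceback is imported at the top of util.py, which the truncated diff doesn't show). The difference in one self-contained example:

import logging
import traceback

try:
    raise ConnectionError()  # str(e) here is "" -- not much to log
except Exception as e:
    logging.error("message only: %s", e)
    logging.error("full trace:\n%s", traceback.format_exc())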
@@ -264,17 +269,25 @@ async def get_commuinty_children(conn, i: int, c: str):
     headers = make_headers(conn)
     async with httpx.AsyncClient(timeout=None) as client:
         async with tg_sem:
-            resp = await client.get(
-                f"{conn.restppUrl}/query/{conn.graphname}/get_community_children",
-                params={"comm": c, "iter": i},
-                headers=headers,
-            )
+            try:
+                resp = await client.get(
+                    f"{conn.restppUrl}/query/{conn.graphname}/get_community_children",
+                    params={"comm": c, "iter": i},
+                    headers=headers,
+                )
+            except:
+                logger.error(f"Get Children err:\n{traceback.format_exc()}")
     try:
         resp.raise_for_status()
     except Exception as e:
         logger.error(f"Get Children err:\n{e}")
     descrs = []
-    for d in resp.json()["results"][0]["children"]:
+    try:
+        res = resp.json()["results"][0]["children"]
+    except Exception as e:
+        logger.error(f"Get Children err:\n{e}")
+        res = []
+    for d in res:
         desc = d["attributes"]["description"]
         # if it's the entity iteration
         if i == 1: