Merge pull request #271 from tigergraph/GML-1866-entity-type-upserts
feat(graphrag): add type information upsert
parkererickson-tg authored Sep 5, 2024
2 parents e51d35a + e6065f6 commit 1d8e9b8
Showing 6 changed files with 132 additions and 31 deletions.
3 changes: 3 additions & 0 deletions common/gsql/supportai/SupportAI_Schema.gsql
@@ -6,6 +6,7 @@ CREATE SCHEMA_CHANGE JOB add_supportai_schema {
ADD VERTEX Relationship(PRIMARY_ID id STRING, definition STRING, short_name STRING, epoch_added UINT, epoch_processing UINT, epoch_processed UINT) WITH STATS="OUTDEGREE_BY_EDGETYPE", PRIMARY_ID_AS_ATTRIBUTE="true";
ADD VERTEX DocumentCollection(PRIMARY_ID id STRING, epoch_added UINT) WITH STATS="OUTDEGREE_BY_EDGETYPE", PRIMARY_ID_AS_ATTRIBUTE="true";
ADD VERTEX Content(PRIMARY_ID id STRING, text STRING, epoch_added UINT) WITH STATS="OUTDEGREE_BY_EDGETYPE", PRIMARY_ID_AS_ATTRIBUTE="true";
ADD VERTEX EntityType(PRIMARY_ID id STRING, description STRING, epoch_added UINT) WITH STATS="OUTDEGREE_BY_EDGETYPE", PRIMARY_ID_AS_ATTRIBUTE="true";
ADD DIRECTED EDGE HAS_CONTENT(FROM Document, TO Content|FROM DocumentChunk, TO Content) WITH REVERSE_EDGE="reverse_HAS_CONTENT";
ADD DIRECTED EDGE IS_CHILD_OF(FROM Concept, TO Concept) WITH REVERSE_EDGE="reverse_IS_CHILD_OF";
ADD DIRECTED EDGE IS_HEAD_OF(FROM Entity, TO Relationship) WITH REVERSE_EDGE="reverse_IS_HEAD_OF";
@@ -18,6 +19,8 @@ CREATE SCHEMA_CHANGE JOB add_supportai_schema {
ADD DIRECTED EDGE HAS_CHILD(FROM Document, TO DocumentChunk) WITH REVERSE_EDGE="reverse_HAS_CHILD";
ADD DIRECTED EDGE HAS_RELATIONSHIP(FROM Concept, TO Concept, relation_type STRING) WITH REVERSE_EDGE="reverse_HAS_RELATIONSHIP";
ADD DIRECTED EDGE CONTAINS_DOCUMENT(FROM DocumentCollection, TO Document) WITH REVERSE_EDGE="reverse_CONTAINS_DOCUMENT";
ADD DIRECTED EDGE ENTITY_HAS_TYPE(FROM Entity, TO EntityType) WITH REVERSE_EDGE="reverse_ENTITY_HAS_TYPE";
ADD DIRECTED EDGE RELATIONSHIP_TYPE(FROM EntityType, TO EntityType, DISCRIMINATOR(relation_type STRING), frequency INT) WITH REVERSE_EDGE="reverse_RELATIONSHIP_TYPE";

// GraphRAG
ADD VERTEX Community (PRIMARY_ID id STRING, iteration UINT, description STRING) WITH PRIMARY_ID_AS_ATTRIBUTE="true";
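The new EntityType vertex and the ENTITY_HAS_TYPE edge can be exercised directly once the schema change job has run. A minimal sketch with pyTigerGraph, assuming an already-authenticated connection `conn` for the target graph; the type and entity ids below are purely illustrative:

import time
from pyTigerGraph import TigerGraphConnection

# conn is assumed to be an authenticated TigerGraphConnection for this graph.
conn: TigerGraphConnection = ...

# Upsert an entity type and attach an existing Entity vertex to it.
conn.upsertVertex(
    "EntityType",
    "person",  # illustrative type id
    {"description": "A human being", "epoch_added": int(time.time())},
)
conn.upsertEdge("Entity", "ada_lovelace", "ENTITY_HAS_TYPE", "EntityType", "person")

RELATIONSHIP_TYPE edges between EntityType vertices are produced in bulk by the new query below rather than upserted one at a time.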
20 changes: 20 additions & 0 deletions common/gsql/supportai/create_entity_type_relationships.gsql
@@ -0,0 +1,20 @@
CREATE DISTRIBUTED QUERY create_entity_type_relationships(/* Parameters here */) SYNTAX v2{
MapAccum<STRING, MapAccum<STRING, SumAccum<INT>>> @rel_type_count; // entity type, relationship type for entity type, frequency
SumAccum<INT> @@rels_inserted;
ents = {Entity.*};
accum_types = SELECT et FROM ents:e -(RELATIONSHIP>:r)- Entity:e2 -(ENTITY_HAS_TYPE>:eht)- EntityType:et
WHERE r.relation_type != "DOC_CHUNK_COOCCURRENCE"
ACCUM
e.@rel_type_count += (et.id -> (r.relation_type -> 1));

ets = SELECT et FROM ents:e -(ENTITY_HAS_TYPE>:eht)- EntityType:et
ACCUM
FOREACH (entity_type, rel_type_freq) IN e.@rel_type_count DO
FOREACH (rel_type, freq) IN e.@rel_type_count.get(entity_type) DO
INSERT INTO RELATIONSHIP_TYPE VALUES (et.id, entity_type, rel_type, freq),
@@rels_inserted += 1
END
END;

PRINT @@rels_inserted as relationships_inserted;
}
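Once installed, the query takes no parameters and reports how many RELATIONSHIP_TYPE edges it inserted. A minimal sketch of invoking it with pyTigerGraph, assuming `conn` is an authenticated connection and the query has been installed on the graph (util.py below calls the same REST++ endpoint directly with httpx):

# Run the installed query and read back the inserted-edge count.
res = conn.runInstalledQuery("create_entity_type_relationships")
inserted = res[0]["relationships_inserted"]
print(f"RELATIONSHIP_TYPE edges inserted: {inserted}")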
17 changes: 10 additions & 7 deletions eventual-consistency-service/app/graphrag/community_summarizer.py
@@ -35,10 +35,13 @@ async def summarize(self, name: str, text: list[str]) -> CommunitySummary:

# remove iteration tags from name
name = id_pat.sub("", name)
summary = await chain.ainvoke(
{
"entity_name": name,
"description_list": text,
}
)
return summary.summary
try:
summary = await chain.ainvoke(
{
"entity_name": name,
"description_list": text,
}
)
except Exception as e:
return {"error": True, "summary": ""}
return {"error": False, "summary": summary.summary}
12 changes: 12 additions & 0 deletions eventual-consistency-service/app/graphrag/graph_rag.py
@@ -19,6 +19,7 @@
stream_ids,
tg_sem,
upsert_batch,
add_rels_between_types
)
from pyTigerGraph import TigerGraphConnection

@@ -462,6 +463,16 @@ async def run(graphname: str, conn: TigerGraphConnection):
init_end = time.perf_counter()
logger.info("Doc Processing End")

# Type Resolution
type_start = time.perf_counter()
logger.info("Type Processing Start")
res = await add_rels_between_types(conn)
if res["error"]:
logger.error(f"Error adding relationships between types: {res}")
else:
logger.info(f"Added relationships between types: {res}")
logger.info("Type Processing End")
type_end = time.perf_counter()
# Entity Resolution
entity_start = time.perf_counter()

@@ -516,6 +527,7 @@ async def run(graphname: str, conn: TigerGraphConnection):
end = time.perf_counter()
logger.info(f"DONE. graphrag system initializer dT: {init_end-init_start}")
logger.info(f"DONE. graphrag entity resolution dT: {entity_end-entity_start}")
logger.info(f"DONE. graphrag type creation dT: {type_end-type_start}")
logger.info(
f"DONE. graphrag community initializer dT: {community_end-community_start}"
)
35 changes: 29 additions & 6 deletions eventual-consistency-service/app/graphrag/util.py
@@ -73,6 +73,7 @@ async def init(
"common/gsql/graphRAG/louvain/graphrag_louvain_communities",
"common/gsql/graphRAG/louvain/modularity",
"common/gsql/graphRAG/louvain/stream_community",
"common/gsql/supportai/create_entity_type_relationships"
]
# add louvain to queries
q = [x.split(".gsql")[0] for x in glob("common/gsql/graphRAG/louvain/*")]
@@ -206,9 +207,13 @@ async def upsert_batch(conn: TigerGraphConnection, data: str):
headers = make_headers(conn)
async with httpx.AsyncClient(timeout=http_timeout) as client:
async with tg_sem:
res = await client.post(
f"{conn.restppUrl}/graph/{conn.graphname}", data=data, headers=headers
)
try:
res = await client.post(
f"{conn.restppUrl}/graph/{conn.graphname}", data=data, headers=headers
)
            except Exception as e:
                err = traceback.format_exc()
                logger.error(f"Upsert err:\n{err}")
                return  # res is undefined if the post itself failed
    res.raise_for_status()


@@ -225,14 +230,12 @@ async def check_vertex_exists(conn, v_id: str):
except Exception as e:
err = traceback.format_exc()
logger.error(f"Check err:\n{err}")
return {"error": True}

try:
res.raise_for_status()
return res.json()
except Exception as e:
logger.error(f"Check err:\n{e}\n{res.text}")
return {"error": True}
return {"error": True, "message": res.text}


async def upsert_edge(
@@ -321,6 +324,26 @@ async def check_all_ents_resolved(conn):

return res

async def add_rels_between_types(conn):
headers = make_headers(conn)
async with httpx.AsyncClient(timeout=None) as client:
async with tg_sem:
resp = await client.get(
f"{conn.restppUrl}/query/{conn.graphname}/create_entity_type_relationships",
headers=headers,
)
try:
resp.raise_for_status()
except Exception as e:
logger.error(f"Check Vert EntityType err:\n{e}")

if resp.status_code != 200:
logger.error(f"Check Vert EntityType err:\n{resp.text}")
return {"error": True, "message": resp.text}
else:
        res = resp.json()["results"][0]["relationships_inserted"]
        logger.info(resp.json()["results"])
        return {"error": False, "relationships_inserted": res}

async def check_vertex_has_desc(conn, i: int):
headers = make_headers(conn)
76 changes: 58 additions & 18 deletions eventual-consistency-service/app/graphrag/workers.py
@@ -15,7 +15,7 @@
from common.config import milvus_config
from common.embeddings.embedding_services import EmbeddingModel
from common.embeddings.milvus_embedding_store import MilvusEmbeddingStore
from common.extractors.BaseExtractor import BaseExtractor
from common.extractors import BaseExtractor, LLMEntityRelationshipExtractor
from common.logs.logwriter import LogWriter

vertex_field = milvus_config.get("vertex_field", "vertex_id")
@@ -178,7 +178,7 @@ async def get_vert_desc(conn, v_id, node: Node):
desc = [node.properties.get("description", "")]
exists = await util.check_vertex_exists(conn, v_id)
# if vertex exists, get description content and append this description to it
if not exists["error"]:
if not exists.get("error", False):
# deduplicate descriptions
desc.extend(exists["results"][0]["attributes"]["description"])
desc = list(set(desc))
@@ -242,6 +242,39 @@ async def extract(
),
)
)
if isinstance(extractor, LLMEntityRelationshipExtractor):
logger.info("extract writes type vert to upsert")
type_id = util.process_id(node.type)
if len(type_id) == 0:
continue
await upsert_chan.put(
(
util.upsert_vertex, # func to call
(
conn,
"EntityType", # v_type
type_id, # v_id
{ # attrs
"epoch_added": int(time.time()),
},
)
)
)
logger.info("extract writes entity_has_type edge to upsert")
await upsert_chan.put(
(
util.upsert_edge,
(
conn,
"Entity", # src_type
v_id, # src_id
"ENTITY_HAS_TYPE", # edgeType
"EntityType", # tgt_type
type_id, # tgt_id
None, # attributes
),
)
)

# link the entity to the chunk it came from
logger.info("extract writes contains edge to upsert")
@@ -445,6 +478,7 @@ async def process_community(
# get the children of the community
children = await util.get_commuinty_children(conn, i, comm_id)
comm_id = util.process_id(comm_id)
err = False

# if the community only has one child, use its description
if len(children) == 1:
@@ -453,22 +487,28 @@
llm = ecc_util.get_llm_service()
summarizer = community_summarizer.CommunitySummarizer(llm)
summary = await summarizer.summarize(comm_id, children)

logger.debug(f"Community {comm_id}: {children}, {summary}")
await upsert_chan.put(
(
util.upsert_vertex, # func to call
if summary["error"]:
logger.error(f"Failed to summarize community {comm_id}")
err = True
else:
summary = summary["summary"]

if not err:
logger.debug(f"Community {comm_id}: {children}, {summary}")
await upsert_chan.put(
(
conn,
"Community", # v_type
comm_id, # v_id
{ # attrs
"description": summary,
"iteration": i,
},
),
util.upsert_vertex, # func to call
(
conn,
"Community", # v_type
comm_id, # v_id
{ # attrs
"description": summary,
"iteration": i,
},
),
)
)
)

# (v_id, content, index_name)
await embed_chan.put((comm_id, summary, "Community"))
# (v_id, content, index_name)
await embed_chan.put((comm_id, summary, "Community"))
