Skip to content

Commit

Permalink
Merge pull request #270 from tigergraph/GML-1804-graphrag-improvements
Browse files Browse the repository at this point in the history
feat(graphrag): add descriptions to all upserts, cooccurrence edges
  • Loading branch information
RobRossmiller-TG authored Aug 30, 2024
2 parents 81b0052 + e76fbd3 commit e51d35a
Show file tree
Hide file tree
Showing 4 changed files with 184 additions and 36 deletions.
2 changes: 1 addition & 1 deletion common/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ def get_llm_service(llm_config) -> LLM_Model:
doc_processing_config = {
"chunker": "semantic",
"chunker_config": {"method": "percentile", "threshold": 0.90},
"extractor": "graphrag",
"extractor": "llm",
"extractor_config": {},
}
elif DOC_PROCESSING_CONFIG.endswith(".json"):
Expand Down
198 changes: 165 additions & 33 deletions common/extractors/LLMEntityRelationshipExtractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
from common.extractors.BaseExtractor import BaseExtractor
from common.llm_services import LLM_Model
from common.py_schemas import KnowledgeGraph

from langchain_community.graphs.graph_document import Node, Relationship, GraphDocument
from langchain_core.documents import Document

class LLMEntityRelationshipExtractor(BaseExtractor):
def __init__(
Expand All @@ -19,42 +20,116 @@ def __init__(
self.allowed_edge_types = allowed_relationship_types
self.strict_mode = strict_mode

def _extract_kg_from_doc(self, doc, chain, parser):
"""
returns:
{
"nodes": [
{
"id": "str",
"type": "string",
"definition": "string"
}
],
"rels": [
{
"source":{
"id": "str",
"type": "string",
"definition": "string"
}
"target":{
"id": "str",
"type": "string",
"definition": "string"
async def _aextract_kg_from_doc(self, doc, chain, parser) -> list[GraphDocument]:
    """Async: run the extraction chain on one document and convert the LLM's
    JSON output into a single-element list of GraphDocument.

    Args:
        doc: raw text of the document chunk to extract from.
        chain: runnable (prompt | model) invoked with the doc and format hints.
        parser: output parser supplying ``get_format_instructions()``.

    Returns:
        ``[GraphDocument]`` with the extracted nodes/relationships; any failure
        (LLM call, JSON parsing, schema mismatch) degrades to a GraphDocument
        with empty nodes/relationships rather than raising.
    """
    empty = [
        GraphDocument(nodes=[], relationships=[], source=Document(page_content=doc))
    ]
    try:
        out = await chain.ainvoke(
            {"input": doc, "format_instructions": parser.get_format_instructions()}
        )
    except Exception:
        # LLM invocation failed — best-effort: return an empty graph for this doc.
        return empty
    try:
        # The model may wrap its JSON in a ```json fence; unwrap if present.
        if "```json" not in out.content:
            # NOTE(review): str.strip("content=") strips a *character set*, not
            # the literal prefix "content=". Harmless for JSON bodies (which
            # start with "{"), but probably meant removeprefix — confirm intent.
            json_out = json.loads(out.content.strip("content="))
        else:
            json_out = json.loads(
                out.content.split("```")[1].strip("```").strip("json").strip()
            )

        def _endpoint_id(endpoint):
            # Endpoints come back either as a bare id string or as an
            # {"id": ..., ...} object; normalize both to the id string.
            if isinstance(endpoint, dict):
                return endpoint["id"]
            if isinstance(endpoint, str):
                return endpoint
            raise ValueError("Relationship parsing error")

        formatted_rels = [
            {
                "source": _endpoint_id(rel["source"]),
                "target": _endpoint_id(rel["target"]),
                "type": rel["relation_type"].replace(" ", "_").upper(),
                "definition": rel["definition"],
            }
            for rel in json_out["rels"]
        ]
        formatted_nodes = [
            {
                "id": node["id"],
                "type": node["node_type"].replace(" ", "_").capitalize(),
                "definition": node["definition"],
            }
            for node in json_out["nodes"]
        ]

        # filter relationships and nodes based on allowed types
        if self.strict_mode:
            if self.allowed_vertex_types:
                formatted_nodes = [
                    node
                    for node in formatted_nodes
                    if node["type"] in self.allowed_vertex_types
                ]
            if self.allowed_edge_types:
                formatted_rels = [
                    rel
                    for rel in formatted_rels
                    if rel["type"] in self.allowed_edge_types
                ]

        nodes = [
            Node(
                id=node["id"],
                type=node["type"],
                properties={"description": node["definition"]},
            )
            for node in formatted_nodes
        ]
        # Fix: endpoint Nodes previously set type=rel["source"]/rel["target"]
        # (the *id*, not the node type). Resolve the real type when the
        # endpoint was also extracted as a node, falling back to the old
        # behavior (id used as type) for endpoints we did not extract.
        node_types = {node["id"]: node["type"] for node in formatted_nodes}
        relationships = [
            Relationship(
                source=Node(
                    id=rel["source"],
                    type=node_types.get(rel["source"], rel["source"]),
                    properties={"description": rel["definition"]},
                ),
                target=Node(
                    id=rel["target"],
                    type=node_types.get(rel["target"], rel["target"]),
                    properties={"description": rel["definition"]},
                ),
                type=rel["type"],
            )
            for rel in formatted_rels
        ]

        return [
            GraphDocument(
                nodes=nodes,
                relationships=relationships,
                source=Document(page_content=doc),
            )
        ]
    except Exception:
        # Malformed JSON or unexpected schema — degrade to an empty graph
        # (was a bare `except:`, which would also swallow KeyboardInterrupt).
        return empty

def _extract_kg_from_doc(self, doc, chain, parser) -> list[GraphDocument]:
try:
out = chain.invoke(
{"input": doc, "format_instructions": parser.get_format_instructions()}
)
except Exception as e:
print("Error: ", e)
return {"nodes": [], "rels": []}
return [GraphDocument(nodes=[], relationships=[], source=Document(page_content=doc))]
try:
if "```json" not in out.content:
json_out = json.loads(out.content.strip("content="))
Expand Down Expand Up @@ -133,15 +208,67 @@ def _extract_kg_from_doc(self, doc, chain, parser):
for rel in formatted_rels
if rel["type"] in self.allowed_edge_types
]
return {"nodes": formatted_nodes, "rels": formatted_rels}

nodes = []
for node in formatted_nodes:
nodes.append(Node(id=node["id"],
type=node["type"],
properties={"description": node["definition"]}))
relationships = []
for rel in formatted_rels:
relationships.append(Relationship(source=Node(id=rel["source"], type=rel["source"],
properties={"description": rel["definition"]}),
target=Node(id=rel["target"], type=rel["target"],
properties={"description": rel["definition"]}), type=rel["type"]))

return [GraphDocument(nodes=nodes, relationships=relationships, source=Document(page_content=doc))]

except:
print("Error Processing: ", out)
return {"nodes": [], "rels": []}
return [GraphDocument(nodes=[], relationships=[], source=Document(page_content=doc))]

async def adocument_er_extraction(self, document):
    """Build the entity/relationship extraction prompt, run it through the
    LLM asynchronously, and return the extraction result for *document*."""
    from langchain.output_parsers import PydanticOutputParser
    from langchain.prompts import ChatPromptTemplate

    output_parser = PydanticOutputParser(pydantic_object=KnowledgeGraph)
    messages = [
        ("system", self.llm_service.entity_relationship_extraction_prompt),
        (
            "human",
            "Tip: Make sure to answer in the correct format and do "
            "not include any explanations. "
            "Use the given format to extract information from the "
            "following input: {input}",
        ),
        (
            "human",
            "Mandatory: Make sure to answer in the correct format, specified here: {format_instructions}",
        ),
    ]
    # Steer the model toward the configured schema, when one was provided.
    if self.allowed_vertex_types or self.allowed_edge_types:
        messages.append(
            (
                "human",
                "Tip: Make sure to use the following types if they are applicable. "
                "If the input does not contain any of the types, you may create your own.",
            )
        )
    if self.allowed_vertex_types:
        messages.append(("human", f"Allowed Node Types: {self.allowed_vertex_types}"))
    if self.allowed_edge_types:
        messages.append(("human", f"Allowed Edge Types: {self.allowed_edge_types}"))

    # The parser is passed to the extractor for format instructions only;
    # it is deliberately not part of the chain.
    chain = ChatPromptTemplate.from_messages(messages) | self.llm_service.model
    return await self._aextract_kg_from_doc(document, chain, output_parser)


def document_er_extraction(self, document):
from langchain.prompts import ChatPromptTemplate
from langchain.output_parsers import PydanticOutputParser


parser = PydanticOutputParser(pydantic_object=KnowledgeGraph)
prompt = [
("system", self.llm_service.entity_relationship_extraction_prompt),
Expand Down Expand Up @@ -176,3 +303,8 @@ def document_er_extraction(self, document):

def extract(self, text):
    """Synchronous entry point: extract entities/relationships from *text*.

    Thin delegation to document_er_extraction (see aextract for the async path).
    """
    return self.document_er_extraction(text)

async def aextract(self, text) -> list[GraphDocument]:
    """Async entry point: extract entities/relationships from *text*.

    Thin delegation to adocument_er_extraction; returns the GraphDocument list
    produced by _aextract_kg_from_doc.
    """
    return await self.adocument_er_extraction(text)


2 changes: 1 addition & 1 deletion eventual-consistency-service/app/graphrag/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ def map_attrs(attributes: dict):


def process_id(v_id: str):
v_id = v_id.replace(" ", "_").replace("/", "")
v_id = v_id.replace(" ", "_").replace("/", "").replace("%", "percent")

has_func = re.compile(r"(.*)\(").findall(v_id)
if len(has_func) > 0:
Expand Down
18 changes: 17 additions & 1 deletion eventual-consistency-service/app/graphrag/workers.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,7 @@ async def extract(

# upsert nodes and edges to the graph
for doc in extracted:
for node in doc.nodes:
for i, node in enumerate(doc.nodes):
logger.info(f"extract writes entity vert to upsert\nNode: {node.id}")
v_id = util.process_id(str(node.id))
if len(v_id) == 0:
Expand Down Expand Up @@ -259,6 +259,22 @@ async def extract(
),
)
)
for node2 in doc.nodes[i + 1:]:
v_id2 = util.process_id(str(node2.id))
await upsert_chan.put(
(
util.upsert_edge,
(
conn,
"Entity", # src_type
v_id, # src_id
"RELATIONSHIP", # edgeType
"Entity", # tgt_type
v_id2, # tgt_id
{"relation_type": "DOC_CHUNK_COOCCURRENCE"}, # attributes
),
)
)

for edge in doc.relationships:
logger.info(
Expand Down

0 comments on commit e51d35a

Please sign in to comment.