Skip to content

Commit

Permalink
feat(copilot): adjust query registration to utilize gsql query descri…
Browse files Browse the repository at this point in the history
…ptors if TG 4+
  • Loading branch information
parkererickson-tg committed Apr 25, 2024
1 parent bd15deb commit 91f4e7f
Showing 1 changed file with 69 additions and 27 deletions.
96 changes: 69 additions & 27 deletions pyTigerGraph/ai/ai.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ def __init__(self, conn: TigerGraphConnection) -> None:
"""
self.conn = conn
self.nlqs_host = None
self.if4 = conn.getVer().split(".")[0] >= "4"
if conn.tgCloud:
# split scheme and host
scheme, host = conn.host.split("://")
Expand All @@ -40,63 +41,104 @@ def configureCoPilotHost(self, hostname: str):
"""
self.nlqs_host = hostname

def registerCustomQuery(self, function_header: str, description: str, docstring: str, param_types: dict = {}):
def registerCustomQuery(self, query_name: str, description: str = None, docstring: str = None, param_types: dict = None):
""" Register a custom query with the InquiryAI service.
Args:
function_header (str):
The name of the query being registered.
query_name (str):
The name of the query being registered. Required.
description (str):
The high-level description of the query being registered.
The high-level description of the query being registered. Only used when using TigerGraph 3.x.
docstring (str):
The docstring of the query being registered. Includes information about each parameter.
Only used when using TigerGraph 3.x.
param_types (Dict[str, str]):
The types of the parameters. In the format {"param_name": "param_type"}
Only used when using TigerGraph 3.x.
Returns:
Hash of query that was registered.
"""
data = {
"function_header": function_header,
"description": description,
"docstring": docstring,
"param_types": param_types
}
url = self.nlqs_host+"/"+self.conn.graphname+"/register_docs"
return self.conn._req("POST", url, authMode="pwd", data = data, jsonData=True, resKey=None)
if self.if4:
if description or docstring or param_types:
warnings.warn(
"""When using TigerGraph 4.x, query descriptions, docstrings, and parameter types are not required parameters.
They will be ignored, and the GSQL descriptions of the queries will be used instead.""",
UserWarning)
data = {
"queries": [query_name]
}
url = self.nlqs_host+"/"+self.conn.graphname+"/upsert_from_gsql"
return self.conn._req("POST", url, authMode="pwd", data = data, jsonData=True, resKey=None)
else:
if description is None:
raise ValueError("When using TigerGraph 3.x, query descriptions are required parameters.")
if docstring is None:
raise ValueError("When using TigerGraph 3.x, query docstrings are required parameters.")
if param_types is None:
raise ValueError("When using TigerGraph 3.x, query parameter types are required parameters.")
data = {
"function_header": query_name,
"description": description,
"docstring": docstring,
"param_types": param_types
}
url = self.nlqs_host+"/"+self.conn.graphname+"/register_docs"
return self.conn._req("POST", url, authMode="pwd", data = data, jsonData=True, resKey=None)

def updateCustomQuery(self, function_header: str, description: str, docstring: str, param_types: dict = {}):
def updateCustomQuery(self, query_name: str, description: str = None, docstring: str = None, param_types: dict = None):
""" Update a custom query with the InquiryAI service.
Args:
function_header (str):
The name of the query being updated.
query_name (str):
The name of the query being updated. Required.
description (str):
The high-level description of the query being updated.
Only used when using TigerGraph 3.x.
docstring (str):
The docstring of the query being updated. Includes information about each parameter.
Only used when using TigerGraph 3.x.
param_types (Dict[str, str]):
The types of the parameters. In the format {"param_name": "param_type"}
Only used when using TigerGraph 3.x.
Returns:
Hash of query that was updated.
"""
data = {
"function_header": function_header,
"description": description,
"docstring": docstring,
"param_types": param_types
}
if self.if4:
if description or docstring or param_types:
warnings.warn(
"""When using TigerGraph 4.x, query descriptions, docstrings, and parameter types are not required parameters.
They will be ignored, and the GSQL descriptions of the queries will be used instead.""",
UserWarning)
data = {
"queries": [query_name]
}
url = self.nlqs_host+"/"+self.conn.graphname+"/upsert_from_gsql"
return self.conn._req("POST", url, authMode="pwd", data = data, jsonData=True, resKey=None)
else:
if description is None:
raise ValueError("When using TigerGraph 3.x, query descriptions are required parameters.")
if docstring is None:
raise ValueError("When using TigerGraph 3.x, query docstrings are required parameters.")
if param_types is None:
raise ValueError("When using TigerGraph 3.x, query parameter types are required parameters.")
data = {
"function_header": query_name,
"description": description,
"docstring": docstring,
"param_types": param_types
}

json_payload = {"query_info": data}
url = self.nlqs_host+"/"+self.conn.graphname+"/upsert_docs"
return self.conn._req("POST", url, authMode="pwd", data = json_payload, jsonData=True, resKey=None)
json_payload = {"query_info": data}
url = self.nlqs_host+"/"+self.conn.graphname+"/upsert_docs"
return self.conn._req("POST", url, authMode="pwd", data = json_payload, jsonData=True, resKey=None)

def deleteCustomQuery(self, function_header: str):
def deleteCustomQuery(self, query_name: str):
""" Delete a custom query with the InquiryAI service.
Args:
function_header (str):
query_name (str):
The name of the query being deleted.
Returns:
Hash of query that was deleted.
"""
data = {"ids": [], "expr": "function_header == '"+function_header+"'"}
data = {"ids": [], "expr": "function_header == '"+query_name+"'"}
url = self.nlqs_host+"/"+self.conn.graphname+"/delete_docs"
return self.conn._req("POST", url, authMode="pwd", data = data, jsonData=True, resKey=None)

Expand Down

0 comments on commit 91f4e7f

Please sign in to comment.