Skip to content

Commit

Permalink
Merge pull request #219 from tigergraph/GML-1648-copilot-polish
Browse files Browse the repository at this point in the history
Gml 1648 copilot polish
  • Loading branch information
parkererickson-tg authored Apr 25, 2024
2 parents f91a238 + 91f4e7f commit b623632
Showing 1 changed file with 86 additions and 29 deletions.
115 changes: 86 additions & 29 deletions pyTigerGraph/ai/ai.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import json
import warnings
from pyTigerGraph import TigerGraphConnection

class AI:
def __init__(self, conn: "TigerGraphConnection") -> None:
def __init__(self, conn: TigerGraphConnection) -> None:
"""NO DOC: Initiate an AI object. Currently in beta testing.
Args:
conn (TigerGraphConnection):
Expand All @@ -13,9 +14,15 @@ def __init__(self, conn: "TigerGraphConnection") -> None:
"""
self.conn = conn
self.nlqs_host = None

self.if4 = conn.getVer().split(".")[0] >= "4"
if conn.tgCloud:
# split scheme and host
scheme, host = conn.host.split("://")
self.nlqs_host = scheme + "://copilot-" + host

def configureInquiryAIHost(self, hostname: str):
""" DEPRECATED: Configure the hostname of the InquiryAI service.
Not recommended to use. Use configureCoPilotHost() instead.
Args:
hostname (str):
The hostname (and port number) of the InquiryAI serivce.
Expand All @@ -27,69 +34,111 @@ def configureInquiryAIHost(self, hostname: str):

def configureCoPilotHost(self, hostname: str):
""" Configure the hostname of the CoPilot service.
Not necessary if using TigerGraph CoPilot on TigerGraph Cloud.
Args:
hostname (str):
The hostname (and port number) of the CoPilot serivce.
"""
self.nlqs_host = hostname

def registerCustomQuery(self, function_header: str, description: str, docstring: str, param_types: dict = {}):
def registerCustomQuery(self, query_name: str, description: str = None, docstring: str = None, param_types: dict = None):
""" Register a custom query with the InquiryAI service.
Args:
function_header (str):
The name of the query being registered.
query_name (str):
The name of the query being registered. Required.
description (str):
The high-level description of the query being registered.
The high-level description of the query being registered. Only used when using TigerGraph 3.x.
docstring (str):
The docstring of the query being registered. Includes information about each parameter.
Only used when using TigerGraph 3.x.
param_types (Dict[str, str]):
The types of the parameters. In the format {"param_name": "param_type"}
Only used when using TigerGraph 3.x.
Returns:
Hash of query that was registered.
"""
data = {
"function_header": function_header,
"description": description,
"docstring": docstring,
"param_types": param_types
}
url = self.nlqs_host+"/"+self.conn.graphname+"/register_docs"
return self.conn._req("POST", url, authMode="pwd", data = data, jsonData=True, resKey=None)
if self.if4:
if description or docstring or param_types:
warnings.warn(
"""When using TigerGraph 4.x, query descriptions, docstrings, and parameter types are not required parameters.
They will be ignored, and the GSQL descriptions of the queries will be used instead.""",
UserWarning)
data = {
"queries": [query_name]
}
url = self.nlqs_host+"/"+self.conn.graphname+"/upsert_from_gsql"
return self.conn._req("POST", url, authMode="pwd", data = data, jsonData=True, resKey=None)
else:
if description is None:
raise ValueError("When using TigerGraph 3.x, query descriptions are required parameters.")
if docstring is None:
raise ValueError("When using TigerGraph 3.x, query docstrings are required parameters.")
if param_types is None:
raise ValueError("When using TigerGraph 3.x, query parameter types are required parameters.")
data = {
"function_header": query_name,
"description": description,
"docstring": docstring,
"param_types": param_types
}
url = self.nlqs_host+"/"+self.conn.graphname+"/register_docs"
return self.conn._req("POST", url, authMode="pwd", data = data, jsonData=True, resKey=None)

def updateCustomQuery(self, function_header: str, description: str, docstring: str, param_types: dict = {}):
def updateCustomQuery(self, query_name: str, description: str = None, docstring: str = None, param_types: dict = None):
""" Update a custom query with the InquiryAI service.
Args:
function_header (str):
The name of the query being updated.
query_name (str):
The name of the query being updated. Required.
description (str):
The high-level description of the query being updated.
Only used when using TigerGraph 3.x.
docstring (str):
The docstring of the query being updated. Includes information about each parameter.
Only used when using TigerGraph 3.x.
param_types (Dict[str, str]):
The types of the parameters. In the format {"param_name": "param_type"}
Only used when using TigerGraph 3.x.
Returns:
Hash of query that was updated.
"""
data = {
"function_header": function_header,
"description": description,
"docstring": docstring,
"param_types": param_types
}
if self.if4:
if description or docstring or param_types:
warnings.warn(
"""When using TigerGraph 4.x, query descriptions, docstrings, and parameter types are not required parameters.
They will be ignored, and the GSQL descriptions of the queries will be used instead.""",
UserWarning)
data = {
"queries": [query_name]
}
url = self.nlqs_host+"/"+self.conn.graphname+"/upsert_from_gsql"
return self.conn._req("POST", url, authMode="pwd", data = data, jsonData=True, resKey=None)
else:
if description is None:
raise ValueError("When using TigerGraph 3.x, query descriptions are required parameters.")
if docstring is None:
raise ValueError("When using TigerGraph 3.x, query docstrings are required parameters.")
if param_types is None:
raise ValueError("When using TigerGraph 3.x, query parameter types are required parameters.")
data = {
"function_header": query_name,
"description": description,
"docstring": docstring,
"param_types": param_types
}

json_payload = {"query_info": data}
url = self.nlqs_host+"/"+self.conn.graphname+"/upsert_docs"
return self.conn._req("POST", url, authMode="pwd", data = json_payload, jsonData=True, resKey=None)
json_payload = {"query_info": data}
url = self.nlqs_host+"/"+self.conn.graphname+"/upsert_docs"
return self.conn._req("POST", url, authMode="pwd", data = json_payload, jsonData=True, resKey=None)

def deleteCustomQuery(self, function_header: str):
def deleteCustomQuery(self, query_name: str):
""" Delete a custom query with the InquiryAI service.
Args:
function_header (str):
query_name (str):
The name of the query being deleted.
Returns:
Hash of query that was deleted.
"""
data = {"ids": [], "expr": "function_header == '"+function_header+"'"}
data = {"ids": [], "expr": "function_header == '"+query_name+"'"}
url = self.nlqs_host+"/"+self.conn.graphname+"/delete_docs"
return self.conn._req("POST", url, authMode="pwd", data = data, jsonData=True, resKey=None)

Expand Down Expand Up @@ -231,4 +280,12 @@ def forceConsistencyUpdate(self):
JSON response from the consistency update.
"""
url = self.nlqs_host+"/"+self.conn.graphname+"/supportai/forceupdate"
return self.conn._req("GET", url, authMode="pwd", resKey=None)

def checkConsistencyProgress(self):
""" Check the progress of the consistency update.
Returns:
JSON response from the consistency update progress.
"""
url = self.nlqs_host+"/"+self.conn.graphname+"/supportai/consistency_status"
return self.conn._req("GET", url, authMode="pwd", resKey=None)

0 comments on commit b623632

Please sign in to comment.