diff --git a/pyTigerGraph/ai/ai.py b/pyTigerGraph/ai/ai.py index 9db320a5..cb68d105 100644 --- a/pyTigerGraph/ai/ai.py +++ b/pyTigerGraph/ai/ai.py @@ -1,8 +1,9 @@ import json import warnings +from pyTigerGraph import TigerGraphConnection class AI: - def __init__(self, conn: "TigerGraphConnection") -> None: + def __init__(self, conn: TigerGraphConnection) -> None: """NO DOC: Initiate an AI object. Currently in beta testing. Args: conn (TigerGraphConnection): @@ -13,9 +14,15 @@ def __init__(self, conn: "TigerGraphConnection") -> None: """ self.conn = conn self.nlqs_host = None - + self.if4 = conn.getVer().split(".")[0] >= "4" + if conn.tgCloud: + # split scheme and host + scheme, host = conn.host.split("://") + self.nlqs_host = scheme + "://copilot-" + host + def configureInquiryAIHost(self, hostname: str): """ DEPRECATED: Configure the hostname of the InquiryAI service. + Not recommended to use. Use configureCoPilotHost() instead. Args: hostname (str): The hostname (and port number) of the InquiryAI serivce. @@ -27,69 +34,111 @@ def configureInquiryAIHost(self, hostname: str): def configureCoPilotHost(self, hostname: str): """ Configure the hostname of the CoPilot service. + Not necessary if using TigerGraph CoPilot on TigerGraph Cloud. Args: hostname (str): The hostname (and port number) of the CoPilot serivce. """ self.nlqs_host = hostname - def registerCustomQuery(self, function_header: str, description: str, docstring: str, param_types: dict = {}): + def registerCustomQuery(self, query_name: str, description: str = None, docstring: str = None, param_types: dict = None): """ Register a custom query with the InquiryAI service. Args: - function_header (str): - The name of the query being registered. + query_name (str): + The name of the query being registered. Required. description (str): - The high-level description of the query being registered. + The high-level description of the query being registered. Only used when using TigerGraph 3.x. docstring (str): The docstring of the query being registered. Includes information about each parameter. + Only used when using TigerGraph 3.x. param_types (Dict[str, str]): The types of the parameters. In the format {"param_name": "param_type"} + Only used when using TigerGraph 3.x. Returns: Hash of query that was registered. """ - data = { - "function_header": function_header, - "description": description, - "docstring": docstring, - "param_types": param_types - } - url = self.nlqs_host+"/"+self.conn.graphname+"/register_docs" - return self.conn._req("POST", url, authMode="pwd", data = data, jsonData=True, resKey=None) + if self.if4: + if description or docstring or param_types: + warnings.warn( + """When using TigerGraph 4.x, query descriptions, docstrings, and parameter types are not required parameters. + They will be ignored, and the GSQL descriptions of the queries will be used instead.""", + UserWarning) + data = { + "queries": [query_name] + } + url = self.nlqs_host+"/"+self.conn.graphname+"/upsert_from_gsql" + return self.conn._req("POST", url, authMode="pwd", data = data, jsonData=True, resKey=None) + else: + if description is None: + raise ValueError("When using TigerGraph 3.x, query descriptions are required parameters.") + if docstring is None: + raise ValueError("When using TigerGraph 3.x, query docstrings are required parameters.") + if param_types is None: + raise ValueError("When using TigerGraph 3.x, query parameter types are required parameters.") + data = { + "function_header": query_name, + "description": description, + "docstring": docstring, + "param_types": param_types + } + url = self.nlqs_host+"/"+self.conn.graphname+"/register_docs" + return self.conn._req("POST", url, authMode="pwd", data = data, jsonData=True, resKey=None) - def updateCustomQuery(self, function_header: str, description: str, docstring: str, param_types: dict = {}): + def updateCustomQuery(self, query_name: str, description: str = None, docstring: str = None, param_types: dict = None): """ Update a custom query with the InquiryAI service. Args: - function_header (str): - The name of the query being updated. + query_name (str): + The name of the query being updated. Required. description (str): The high-level description of the query being updated. + Only used when using TigerGraph 3.x. docstring (str): The docstring of the query being updated. Includes information about each parameter. + Only used when using TigerGraph 3.x. param_types (Dict[str, str]): The types of the parameters. In the format {"param_name": "param_type"} + Only used when using TigerGraph 3.x. Returns: Hash of query that was updated. """ - data = { - "function_header": function_header, - "description": description, - "docstring": docstring, - "param_types": param_types - } + if self.if4: + if description or docstring or param_types: + warnings.warn( + """When using TigerGraph 4.x, query descriptions, docstrings, and parameter types are not required parameters. + They will be ignored, and the GSQL descriptions of the queries will be used instead.""", + UserWarning) + data = { + "queries": [query_name] + } + url = self.nlqs_host+"/"+self.conn.graphname+"/upsert_from_gsql" + return self.conn._req("POST", url, authMode="pwd", data = data, jsonData=True, resKey=None) + else: + if description is None: + raise ValueError("When using TigerGraph 3.x, query descriptions are required parameters.") + if docstring is None: + raise ValueError("When using TigerGraph 3.x, query docstrings are required parameters.") + if param_types is None: + raise ValueError("When using TigerGraph 3.x, query parameter types are required parameters.") + data = { + "function_header": query_name, + "description": description, + "docstring": docstring, + "param_types": param_types + } - json_payload = {"query_info": data} - url = self.nlqs_host+"/"+self.conn.graphname+"/upsert_docs" - return self.conn._req("POST", url, authMode="pwd", data = json_payload, jsonData=True, resKey=None) + json_payload = {"query_info": data} + url = self.nlqs_host+"/"+self.conn.graphname+"/upsert_docs" + return self.conn._req("POST", url, authMode="pwd", data = json_payload, jsonData=True, resKey=None) - def deleteCustomQuery(self, function_header: str): + def deleteCustomQuery(self, query_name: str): """ Delete a custom query with the InquiryAI service. Args: - function_header (str): + query_name (str): The name of the query being deleted. Returns: Hash of query that was deleted. """ - data = {"ids": [], "expr": "function_header == '"+function_header+"'"} + data = {"ids": [], "expr": "function_header == '"+query_name+"'"} url = self.nlqs_host+"/"+self.conn.graphname+"/delete_docs" return self.conn._req("POST", url, authMode="pwd", data = data, jsonData=True, resKey=None) @@ -231,4 +280,12 @@ def forceConsistencyUpdate(self): JSON response from the consistency update. """ url = self.nlqs_host+"/"+self.conn.graphname+"/supportai/forceupdate" + return self.conn._req("GET", url, authMode="pwd", resKey=None) + + def checkConsistencyProgress(self): + """ Check the progress of the consistency update. + Returns: + JSON response from the consistency update progress. + """ + url = self.nlqs_host+"/"+self.conn.graphname+"/supportai/consistency_status" return self.conn._req("GET", url, authMode="pwd", resKey=None) \ No newline at end of file