From 91f4e7fd0e985ee31ca13d6a9af9fbf2a01f2971 Mon Sep 17 00:00:00 2001 From: Parker Erickson Date: Thu, 25 Apr 2024 11:56:16 -0500 Subject: [PATCH] feat(copilot): adjust query registration to utilize gsql query descriptors if TG 4+ --- pyTigerGraph/ai/ai.py | 96 +++++++++++++++++++++++++++++++------------ 1 file changed, 69 insertions(+), 27 deletions(-) diff --git a/pyTigerGraph/ai/ai.py b/pyTigerGraph/ai/ai.py index bd35d10a..cb68d105 100644 --- a/pyTigerGraph/ai/ai.py +++ b/pyTigerGraph/ai/ai.py @@ -14,6 +14,7 @@ def __init__(self, conn: TigerGraphConnection) -> None: """ self.conn = conn self.nlqs_host = None + self.if4 = conn.getVer().split(".")[0] >= "4" if conn.tgCloud: # split scheme and host scheme, host = conn.host.split("://") @@ -40,63 +41,104 @@ def configureCoPilotHost(self, hostname: str): """ self.nlqs_host = hostname - def registerCustomQuery(self, function_header: str, description: str, docstring: str, param_types: dict = {}): + def registerCustomQuery(self, query_name: str, description: str = None, docstring: str = None, param_types: dict = None): """ Register a custom query with the InquiryAI service. Args: - function_header (str): - The name of the query being registered. + query_name (str): + The name of the query being registered. Required. description (str): - The high-level description of the query being registered. + The high-level description of the query being registered. Only used when using TigerGraph 3.x. docstring (str): The docstring of the query being registered. Includes information about each parameter. + Only used when using TigerGraph 3.x. param_types (Dict[str, str]): The types of the parameters. In the format {"param_name": "param_type"} + Only used when using TigerGraph 3.x. Returns: Hash of query that was registered. """ - data = { - "function_header": function_header, - "description": description, - "docstring": docstring, - "param_types": param_types - } - url = self.nlqs_host+"/"+self.conn.graphname+"/register_docs" - return self.conn._req("POST", url, authMode="pwd", data = data, jsonData=True, resKey=None) + if self.if4: + if description or docstring or param_types: + warnings.warn( + """When using TigerGraph 4.x, query descriptions, docstrings, and parameter types are not required parameters. + They will be ignored, and the GSQL descriptions of the queries will be used instead.""", + UserWarning) + data = { + "queries": [query_name] + } + url = self.nlqs_host+"/"+self.conn.graphname+"/upsert_from_gsql" + return self.conn._req("POST", url, authMode="pwd", data = data, jsonData=True, resKey=None) + else: + if description is None: + raise ValueError("When using TigerGraph 3.x, query descriptions are required parameters.") + if docstring is None: + raise ValueError("When using TigerGraph 3.x, query docstrings are required parameters.") + if param_types is None: + raise ValueError("When using TigerGraph 3.x, query parameter types are required parameters.") + data = { + "function_header": query_name, + "description": description, + "docstring": docstring, + "param_types": param_types + } + url = self.nlqs_host+"/"+self.conn.graphname+"/register_docs" + return self.conn._req("POST", url, authMode="pwd", data = data, jsonData=True, resKey=None) - def updateCustomQuery(self, function_header: str, description: str, docstring: str, param_types: dict = {}): + def updateCustomQuery(self, query_name: str, description: str = None, docstring: str = None, param_types: dict = None): """ Update a custom query with the InquiryAI service. Args: - function_header (str): - The name of the query being updated. + query_name (str): + The name of the query being updated. Required. description (str): The high-level description of the query being updated. + Only used when using TigerGraph 3.x. docstring (str): The docstring of the query being updated. Includes information about each parameter. + Only used when using TigerGraph 3.x. param_types (Dict[str, str]): The types of the parameters. In the format {"param_name": "param_type"} + Only used when using TigerGraph 3.x. Returns: Hash of query that was updated. """ - data = { - "function_header": function_header, - "description": description, - "docstring": docstring, - "param_types": param_types - } + if self.if4: + if description or docstring or param_types: + warnings.warn( + """When using TigerGraph 4.x, query descriptions, docstrings, and parameter types are not required parameters. + They will be ignored, and the GSQL descriptions of the queries will be used instead.""", + UserWarning) + data = { + "queries": [query_name] + } + url = self.nlqs_host+"/"+self.conn.graphname+"/upsert_from_gsql" + return self.conn._req("POST", url, authMode="pwd", data = data, jsonData=True, resKey=None) + else: + if description is None: + raise ValueError("When using TigerGraph 3.x, query descriptions are required parameters.") + if docstring is None: + raise ValueError("When using TigerGraph 3.x, query docstrings are required parameters.") + if param_types is None: + raise ValueError("When using TigerGraph 3.x, query parameter types are required parameters.") + data = { + "function_header": query_name, + "description": description, + "docstring": docstring, + "param_types": param_types + } - json_payload = {"query_info": data} - url = self.nlqs_host+"/"+self.conn.graphname+"/upsert_docs" - return self.conn._req("POST", url, authMode="pwd", data = json_payload, jsonData=True, resKey=None) + json_payload = {"query_info": data} + url = self.nlqs_host+"/"+self.conn.graphname+"/upsert_docs" + return self.conn._req("POST", url, authMode="pwd", data = json_payload, jsonData=True, resKey=None) - def deleteCustomQuery(self, function_header: str): + def deleteCustomQuery(self, query_name: str): """ Delete a custom query with the InquiryAI service. Args: - function_header (str): + query_name (str): The name of the query being deleted. Returns: Hash of query that was deleted. """ - data = {"ids": [], "expr": "function_header == '"+function_header+"'"} + data = {"ids": [], "expr": "function_header == '"+query_name+"'"} url = self.nlqs_host+"/"+self.conn.graphname+"/delete_docs" return self.conn._req("POST", url, authMode="pwd", data = data, jsonData=True, resKey=None)