Skip to content

Commit

Permalink
Merge pull request #213 from tigergraph/master
Browse files Browse the repository at this point in the history
v1.5.2 patch
  • Loading branch information
parkererickson-tg authored Feb 15, 2024
2 parents 1cd09d6 + 061e105 commit d4afb4e
Show file tree
Hide file tree
Showing 5 changed files with 97 additions and 4 deletions.
2 changes: 1 addition & 1 deletion pyTigerGraph/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from pyTigerGraph.pyTigerGraph import TigerGraphConnection

__version__ = "1.5.1"
__version__ = "1.5.2"

__license__ = "Apache 2"
Empty file added pyTigerGraph/ai/__init__.py
Empty file.
77 changes: 77 additions & 0 deletions pyTigerGraph/ai/ai.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
import json

class AI:
def __init__(self, conn: "TigerGraphConnection") -> None:
"""NO DOC: Initiate an AI object. Currently in beta testing.
Args:
conn (TigerGraphConnection):
Accept a TigerGraphConnection to run queries with
Returns:
None
"""
self.conn = conn
self.nlqs_host = None

def configureInquiryAIHost(self, hostname: str):
""" Configure the hostname of the InquiryAI service.
Args:
hostname (str):
The hostname (and port number) of the InquiryAI serivce.
"""
self.nlqs_host = hostname

def registerCustomQuery(self, function_header: str, description: str, docstring: str, param_types: dict = {}):
""" Register a custom query with the InquiryAI service.
Args:
function_header (str):
The name of the query being registered.
description (str):
The high-level description of the query being registered.
docstring (str):
The docstring of the query being registered. Includes information about each parameter.
param_types (Dict[str, str]):
The types of the parameters. In the format {"param_name": "param_type"}
Returns:
Hash of query that was registered.
"""
data = {
"function_header": function_header,
"description": description,
"docstring": docstring,
"param_types": param_types
}
url = self.nlqs_host+"/"+self.conn.graphname+"/registercustomquery"
return self.conn._req("POST", url, authMode="pwd", data = data, jsonData=True, resKey=None)

def retrieveDocs(self, query:str, top_k:int = 3):
""" Retrieve docs from the vector store.
Args:
query (str):
The natural language query to retrieve docs with.
top_k (int):
The number of docs to retrieve.
Returns:
List of docs retrieved.
"""
data = {
"query": query
}

url = self.nlqs_host+"/"+self.conn.graphname+"/retrievedocs?top_k="+str(top_k)
return self.conn._req("POST", url, authMode="pwd", data = data, jsonData=True, resKey=None, skipCheck=True)

def query(self, query):
""" Query the database with natural language.
Args:
query (str):
Natural language query to ask about the database.
Returns:
JSON including the natural language response, a answered_question flag, and answer sources.
"""
data = {
"query": query
}

url = self.nlqs_host+"/"+self.conn.graphname+"/query"
return self.conn._req("POST", url, authMode="pwd", data = data, jsonData=True, resKey=None)
13 changes: 13 additions & 0 deletions pyTigerGraph/pyTigerGraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ def __init__(self, host: str = "http://127.0.0.1", graphname: str = "MyGraph",
gsPort, gsqlVersion, version, apiToken, useCert, certPath, debug, sslPort, gcp)

self.gds = None
self.ai = None

def __getattribute__(self, name):
if name == "gds":
Expand All @@ -50,6 +51,18 @@ def __getattribute__(self, name):
"Check the https://docs.tigergraph.com/pytigergraph/current/getting-started/install#_install_pytigergraphgds for more details.")
else:
return super().__getattribute__(name)
elif name == "ai":
if super().__getattribute__(name) is None:
try:
from .ai import ai
self.ai = ai.AI(self)
return super().__getattribute__(name)
except Exception as e:
raise Exception(
"Error importing AI submodule. "+str(e)
)
else:
return super().__getattribute__(name)
else:
return super().__getattribute__(name)

Expand Down
9 changes: 6 additions & 3 deletions pyTigerGraph/pyTigerGraphAuth.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ def dropSecret(self, alias: Union[str, list], ignoreErrors: bool = True) -> str:

return res

def getToken(self, secret: str = None, setToken: bool = True, lifetime: int = None) -> tuple:
def getToken(self, secret: str = None, setToken: bool = True, lifetime: int = None) -> Union[tuple, str]:
"""Requests an authorization token.
This function returns a token only if REST++ authentication is enabled. If not, an exception
Expand All @@ -194,9 +194,12 @@ def getToken(self, secret: str = None, setToken: bool = True, lifetime: int = No
Duration of token validity (in seconds, default 30 days = 2,592,000 seconds).
Returns:
A tuple of `(<token>, <expiration_timestamp_unixtime>, <expiration_timestamp_ISO8601>)`.
If your TigerGraph instance is running version 3.5 or before, the return value is
a tuple of `(<token>, <expiration_timestamp_unixtime>, <expiration_timestamp_ISO8601>)`.
The return value can be ignored, as the token is automatically set for the connection after this call.
If your TigerGraph instance is running version 3.6 or later, the return value is just the token.
[NOTE]
The expiration timestamp's time zone might be different from your computer's local time
zone.
Expand Down Expand Up @@ -231,7 +234,7 @@ def getToken(self, secret: str = None, setToken: bool = True, lifetime: int = No
elif not(success) and not(secret):
res = self._post(self.restppUrl+"/requesttoken", authMode="pwd", data=str({"graph": self.graphname}), resKey="results")
success = True
else:
elif not(success) and (int(s) < 3 or (int(s) == 3 and int(m) < 5)):
raise TigerGraphException("Cannot request a token with username/password for versions < 3.5.")


Expand Down

0 comments on commit d4afb4e

Please sign in to comment.