Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

GML 1748 fixing existing gsql api calls #234

Merged
Show file tree
Hide file tree
Changes from 30 commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
e3aac00
GML-1660 add jwttoken support for authentication for restpp and gsql
May 10, 2024
5537745
GML-1660 add dbversion to set authMode
May 10, 2024
b584eb2
GML-1660 use the gerVer() to get dbversion
May 10, 2024
b12759c
GML-1660 use the gerVer() to get dbversion
May 10, 2024
030656c
GML-1660 add logic to check if supporting JWT on RestPP and GSQL
May 24, 2024
80f0503
GML-1660 fix the version configuration
May 24, 2024
ce0d746
add unit test
Jun 6, 2024
3fef8fc
Merge branch 'master' into GML-1660-support-jwt-token-auth-in-py-tige…
Jun 6, 2024
1ad4209
GML-1660 add unit test
Jun 7, 2024
1cf1615
GML-1660 correct error message for test
Jun 7, 2024
e7eaaaa
GML-1660 add connection error
Jun 7, 2024
4f3537f
switch to runtimerror
Jun 8, 2024
dea665f
GML-1660 test authheader
Jun 9, 2024
6fabbc5
GML-1660 create tests graph for jwtauth
Jun 10, 2024
0d20920
GML-1660 move graph creation
Jun 10, 2024
34f05cb
GML-1660 create new graph and generate token
Jun 10, 2024
185d1d4
GML-1660 create jwttoken graph first and generate token on it
Jun 10, 2024
3476060
GML-1660 make connection()
Jun 10, 2024
163236a
GML-1660 swithc to graph Cora
Jun 11, 2024
5bc4f3e
GML-1660 setToken is True
Jun 11, 2024
089165b
GML-1660 change authheader to token
Jun 11, 2024
c55cca4
GML-1660 change authheader to token
Jun 11, 2024
eebf27e
GML-1660 test getVer()
Jun 11, 2024
3403e3c
GML-1660 change the importing make_connection
Jun 12, 2024
f5ec141
GML-1660 remove print
Jun 12, 2024
e69505b
GML-1660 add jwtToken in config file and set it to empty.
Jun 12, 2024
39bca3a
GML-1660 add jwtToken in config file and set it to empty.
Jun 12, 2024
26919bd
Merge pull request #224 from tigergraph/GML-1660-support-jwt-token-au…
luzhoutg Jun 12, 2024
e3fb819
naive changing of endpoints and adding >4.1 version checking
Jun 17, 2024
d659a1c
adding changes to GSQL endpoints
Jun 18, 2024
fd5d141
changed version REST API url
Jun 19, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions pyTigerGraph/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,3 @@
__version__ = "1.6.2"

__license__ = "Apache 2"

#dummy comment
4 changes: 2 additions & 2 deletions pyTigerGraph/pyTigerGraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,9 @@ def __init__(self, host: str = "http://127.0.0.1", graphname: str = "MyGraph",
tgCloud: bool = False, restppPort: Union[int, str] = "9000",
gsPort: Union[int, str] = "14240", gsqlVersion: str = "", version: str = "",
apiToken: str = "", useCert: bool = None, certPath: str = None, debug: bool = None,
sslPort: Union[int, str] = "443", gcp: bool = False):
sslPort: Union[int, str] = "443", gcp: bool = False, jwtToken: str = ""):
super().__init__(host, graphname, gsqlSecret, username, password, tgCloud, restppPort,
gsPort, gsqlVersion, version, apiToken, useCert, certPath, debug, sslPort, gcp)
gsPort, gsqlVersion, version, apiToken, useCert, certPath, debug, sslPort, gcp, jwtToken)

self.gds = None
self.ai = None
Expand Down
177 changes: 155 additions & 22 deletions pyTigerGraph/pyTigerGraphBase.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,10 @@
import json
import logging
import sys
import re
import warnings
from typing import Union
from urllib.parse import urlparse

import requests

from pyTigerGraph.pyTigerGraphException import TigerGraphException
Expand All @@ -36,7 +36,7 @@ def __init__(self, host: str = "http://127.0.0.1", graphname: str = "MyGraph",
tgCloud: bool = False, restppPort: Union[int, str] = "9000",
gsPort: Union[int, str] = "14240", gsqlVersion: str = "", version: str = "",
apiToken: str = "", useCert: bool = None, certPath: str = None, debug: bool = None,
sslPort: Union[int, str] = "443", gcp: bool = False):
sslPort: Union[int, str] = "443", gcp: bool = False, jwtToken: str = ""):
"""Initiate a connection object.

Args:
Expand Down Expand Up @@ -76,6 +76,8 @@ def __init__(self, host: str = "http://127.0.0.1", graphname: str = "MyGraph",
Port for fetching SSL certificate in case of firewall.
gcp:
DEPRECATED. Previously used for connecting to databases provisioned on GCP in TigerGraph Cloud.
jwtToken:
The JWT token generated from customer side for authentication

Raises:
TigerGraphException: In case on invalid URL scheme.
Expand All @@ -100,12 +102,19 @@ def __init__(self, host: str = "http://127.0.0.1", graphname: str = "MyGraph",
self.graphname = graphname
self.responseConfigHeader = {}
self.awsIamHeaders={}

self.jwtToken = jwtToken
self.apiToken = apiToken
self.base64_credential = base64.b64encode(
"{0}:{1}".format(self.username, self.password).encode("utf-8")).decode("utf-8")

self.authHeader = self._set_auth_header()

# TODO Remove apiToken parameter
if apiToken:
warnings.warn(
"The `apiToken` parameter is deprecated; use `getToken()` function instead.",
DeprecationWarning)
self.apiToken = apiToken

# TODO Eliminate version and use gsqlVersion only, meaning TigerGraph server version
if gsqlVersion:
Expand All @@ -117,12 +126,6 @@ def __init__(self, host: str = "http://127.0.0.1", graphname: str = "MyGraph",
self.version = version
else:
self.version = ""
self.base64_credential = base64.b64encode(
"{0}:{1}".format(self.username, self.password).encode("utf-8")).decode("utf-8")
if self.apiToken:
self.authHeader = {"Authorization": "Bearer " + self.apiToken}
else:
self.authHeader = {"Authorization": "Basic {0}".format(self.base64_credential)}

if debug is not None:
warnings.warn(
Expand Down Expand Up @@ -200,8 +203,43 @@ def __init__(self, host: str = "http://127.0.0.1", graphname: str = "MyGraph",
self.awsIamHeaders["X-Amz-Security-Token"] = request.headers["X-Amz-Security-Token"]
self.awsIamHeaders["Authorization"] = request.headers["Authorization"]

if self.jwtToken:
self._verify_jwt_token_support()

logger.info("exit: __init__")

def _set_auth_header(self):
"""Set the authentication header based on available tokens or credentials."""
if self.jwtToken:
return {"Authorization": "Bearer " + self.jwtToken}
elif self.apiToken:
return {"Authorization": "Bearer " + self.apiToken}
else:
return {"Authorization": "Basic {0}".format(self.base64_credential)}

def _verify_jwt_token_support(self):
try:
# Check JWT support for RestPP server
logger.debug("Attempting to verify JWT token support with getVer() on RestPP server.")
logger.debug(f"Using auth header: {self.authHeader}")
version = self.getVer()
logger.info(f"Database version: {version}")

# Check JWT support for GSQL server
if self._versionGreaterThan4_0():
logger.debug(f"Attempting to get auth info with URL: {self.gsUrl + '/gsqlserver/gsql/simpleauth'}")
self._get(f"{self.gsUrl}/gsqlserver/gsql/v1/auth/simple", authMode="token", resKey=None)
else:
logger.debug(f"Attempting to get auth info with URL: {self.gsUrl + '/gsqlserver/gsql/simpleauth'}")
self._get(f"{self.gsUrl}/gsqlserver/gsql/simpleauth", authMode="token", resKey=None)
except requests.exceptions.ConnectionError as e:
logger.error(f"Connection error: {e}.")
raise RuntimeError(f"Connection error: {e}.") from e
except Exception as e:
message = "The JWT token might be invalid or expired or DB version doesn't support JWT token. Please generate new JWT token or switch to API token or username/password."
logger.error(f"Error occurred: {e}. {message}")
raise RuntimeError(message) from e

def _locals(self, _locals: dict) -> str:
del _locals["self"]
return str(_locals)
Expand Down Expand Up @@ -257,20 +295,31 @@ def _req(self, method: str, url: str, authMode: str = "token", headers: dict = N
if logger.level == logging.DEBUG:
logger.debug("params: " + self._locals(locals()))

if authMode == "token" and str(self.apiToken) != "":
if isinstance(self.apiToken, tuple):
self.apiToken = self.apiToken[0]
self.authHeader = {'Authorization': "Bearer " + self.apiToken}
_headers = self.authHeader
else:
self.authHeader = {'Authorization': 'Basic {0}'.format(self.base64_credential)}
_headers = self.authHeader
authMode = 'pwd'
# If JWT token is provided, always use jwtToken as token
if authMode == "token":
if isinstance(self.jwtToken, str) and self.jwtToken.strip() != "":
token = self.jwtToken
elif isinstance(self.apiToken, tuple):
token = self.apiToken[0]
elif isinstance(self.apiToken, str) and self.apiToken.strip() != "":
token = self.apiToken
else:
token = None

if token:
self.authHeader = {'Authorization': "Bearer " + token}
_headers = self.authHeader
else:
self.authHeader = {'Authorization': 'Basic {0}'.format(self.base64_credential)}
_headers = self.authHeader
authMode = 'pwd'

if authMode == "pwd":
_auth = (self.username, self.password)
else:
_auth = None
if self.jwtToken:
_headers = {'Authorization': "Bearer " + self.jwtToken}
else:
_headers = {'Authorization': 'Basic {0}'.format(self.base64_credential)}

if headers:
_headers.update(headers)
if self.awsIamHeaders:
Expand Down Expand Up @@ -448,4 +497,88 @@ def customizeHeader(self, timeout:int = 16_000, responseSize:int = 3.2e+7):
Returns:
Nothing. Sets `responseConfigHeader` class attribute.
"""
self.responseConfigHeader = {"GSQL-TIMEOUT": str(timeout), "RESPONSE-LIMIT": str(responseSize)}
self.responseConfigHeader = {"GSQL-TIMEOUT": str(timeout), "RESPONSE-LIMIT": str(responseSize)}

def getVersion(self, raw: bool = False) -> Union[str, list]:
"""Retrieves the git versions of all components of the system.

Args:
raw:
Return unprocessed version info string, or extract version info for each component
into a list.

Returns:
Either an unprocessed string containing the version info details, or a list with version
info for each component.

Endpoint:
- `GET /version`
See xref:tigergraph-server:API:built-in-endpoints.adoc#_show_component_versions[Show component versions]
"""
logger.info("entry: getVersion")
if logger.level == logging.DEBUG:
logger.debug("params: " + self._locals(locals()))

response = self._get(self.restppUrl+"/version", strictJson=False, resKey="message")

if raw:
return response
res = response.split("\n")
components = []
for i in range(len(res)):
if 2 < i < len(res) - 1:
m = res[i].split()
component = {"name": m[0], "version": m[1], "hash": m[2],
"datetime": m[3] + " " + m[4] + " " + m[5]}
components.append(component)

if logger.level == logging.DEBUG:
logger.debug("return: " + str(components))
logger.info("exit: getVersion")

return components

def getVer(self, component: str = "product", full: bool = False) -> str:
"""Gets the version information of a specific component.

Get the full list of components using `getVersion()`.

Args:
component:
One of TigerGraph's components (e.g. product, gpe, gse).
full:
Return the full version string (with timestamp, etc.) or just X.Y.Z.

Returns:
Version info for specified component.

Raises:
`TigerGraphException` if invalid/non-existent component is specified.
"""
logger.info("entry: getVer")
if logger.level == logging.DEBUG:
logger.debug("params: " + self._locals(locals()))

ret = ""
for v in self.getVersion():
if v["name"] == component.lower():
ret = v["version"]
if ret != "":
if full:
return ret
ret = re.search("_.+_", ret)
ret = ret.group().strip("_")

if logger.level == logging.DEBUG:
logger.debug("return: " + str(ret))
logger.info("exit: getVer")

return ret
else:
raise TigerGraphException("\"" + component + "\" is not a valid component.", None)

def _versionGreaterThan4_0(self):
version = self.getVer().split('.')
if version[0]>="4" and version[1]>"0":
return True
return False
56 changes: 42 additions & 14 deletions pyTigerGraph/pyTigerGraphGSQL.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,11 +73,18 @@ def clean_res(resp: list) -> str:
if str(graphname).upper() == "GLOBAL" or str(graphname).upper() == "":
graphname = ""

res = self._req("POST",
self.gsUrl + "/gsqlserver/gsql/file",
if self._versionGreaterThan4_0():
res = self._req("POST",
self.gsUrl + "/gsqlserver/gsql/v1/statements",
data=quote_plus(query.encode("utf-8")),
authMode="pwd", resKey=None, skipCheck=True,
jsonResponse=False)
else:
res = self._req("POST",
self.gsUrl + "/gsqlserver/gsql/file",
data=quote_plus(query.encode("utf-8")),
authMode="pwd", resKey=None, skipCheck=True,
jsonResponse=False)


if isinstance(res, list):
Expand Down Expand Up @@ -120,9 +127,15 @@ def installUDF(self, ExprFunctions: str = "", ExprUtil: str = "") -> None:
# A local file: read from disk.
with open(ExprFunctions) as infile:
data = infile.read()
res = self._req("PUT",
url="{}/gsqlserver/gsql/userdefinedfunction?filename=ExprFunctions".format(
self.gsUrl), authMode="pwd", data=data, resKey="")

if self._versionGreaterThan4_0():
res = self._req("PUT",
url="{}/gsqlserver/gsql/v1/udt/files/ExprFunctions".format(
self.gsUrl), authMode="pwd", data=data, resKey="")
else:
res = self._req("PUT",
url="{}/gsqlserver/gsql/userdefinedfunction?filename=ExprFunctions".format(
self.gsUrl), authMode="pwd", data=data, resKey="")
if not res["error"]:
logger.info("ExprFunctions installed successfully")
else:
Expand All @@ -137,9 +150,14 @@ def installUDF(self, ExprFunctions: str = "", ExprUtil: str = "") -> None:
# A local file: read from disk.
with open(ExprUtil) as infile:
data = infile.read()
res = self._req("PUT",
url="{}/gsqlserver/gsql/userdefinedfunction?filename=ExprUtil".format(self.gsUrl),
authMode="pwd", data=data, resKey="")
if self._versionGreaterThan4_0():
res = self._req("PUT",
url="{}/gsqlserver/gsql/v1/udt/files/ExprUtil".format(self.gsUrl),
authMode="pwd", data=data, resKey="")
else:
res = self._req("PUT",
url="{}/gsqlserver/gsql/userdefinedfunction?filename=ExprUtil".format(self.gsUrl),
authMode="pwd", data=data, resKey="")
if not res["error"]:
logger.info("ExprUtil installed successfully")
else:
Expand Down Expand Up @@ -172,9 +190,14 @@ def getUDF(self, ExprFunctions: bool = True, ExprUtil: bool = True) -> Union[str

functions_ret = None
if ExprFunctions:
resp = self._get(
"{}/gsqlserver/gsql/userdefinedfunction".format(self.gsUrl),
params={"filename": "ExprFunctions"}, resKey="")
if self._versionGreaterThan4_0():
resp = self._get(
"{}/gsqlserver/gsql/v1/udt/files/ExprFunctions".format(self.gsUrl),
resKey="")
else:
resp = self._get(
"{}/gsqlserver/gsql/userdefinedfunction".format(self.gsUrl),
params={"filename": "ExprFunctions"}, resKey="")
if not resp["error"]:
logger.info("ExprFunctions get successfully")
functions_ret = resp["results"]
Expand All @@ -184,9 +207,14 @@ def getUDF(self, ExprFunctions: bool = True, ExprUtil: bool = True) -> Union[str

util_ret = None
if ExprUtil:
resp = self._get(
"{}/gsqlserver/gsql/userdefinedfunction".format(self.gsUrl),
params={"filename": "ExprUtil"}, resKey="")
if self._versionGreaterThan4_0():
resp = self._get(
"{}/gsqlserver/gsql/v1/udt/files/ExprUtil".format(self.gsUrl),
resKey="")
else:
resp = self._get(
"{}/gsqlserver/gsql/userdefinedfunction".format(self.gsUrl),
params={"filename": "ExprUtil"}, resKey="")
if not resp["error"]:
logger.info("ExprUtil get successfully")
util_ret = resp["results"]
Expand Down
Loading