Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Gml 1660 support jwt token auth in py tiger graph #224

Merged
merged 27 commits into from
Jun 12, 2024
Merged
Show file tree
Hide file tree
Changes from 13 commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
e3aac00
GML-1660 add jwttoken support for authentication for restpp and gsql
May 10, 2024
5537745
GML-1660 add dbversion to set authMode
May 10, 2024
b584eb2
GML-1660 use the gerVer() to get dbversion
May 10, 2024
b12759c
GML-1660 use the gerVer() to get dbversion
May 10, 2024
030656c
GML-1660 add logic to check if supporting JWT on RestPP and GSQL
May 24, 2024
80f0503
GML-1660 fix the version configuration
May 24, 2024
ce0d746
add unit test
Jun 6, 2024
3fef8fc
Merge branch 'master' into GML-1660-support-jwt-token-auth-in-py-tige…
Jun 6, 2024
1ad4209
GML-1660 add unit test
Jun 7, 2024
1cf1615
GML-1660 correct error message for test
Jun 7, 2024
e7eaaaa
GML-1660 add connection error
Jun 7, 2024
4f3537f
switch to runtimerror
Jun 8, 2024
dea665f
GML-1660 test authheader
Jun 9, 2024
6fabbc5
GML-1660 create tests graph for jwtauth
Jun 10, 2024
0d20920
GML-1660 move graph creation
Jun 10, 2024
34f05cb
GML-1660 create new graph and generate token
Jun 10, 2024
185d1d4
GML-1660 create jwttoken graph first and generate token on it
Jun 10, 2024
3476060
GML-1660 make connection()
Jun 10, 2024
163236a
GML-1660 swithc to graph Cora
Jun 11, 2024
5bc4f3e
GML-1660 setToken is True
Jun 11, 2024
089165b
GML-1660 change authheader to token
Jun 11, 2024
c55cca4
GML-1660 change authheader to token
Jun 11, 2024
eebf27e
GML-1660 test getVer()
Jun 11, 2024
3403e3c
GML-1660 change the importing make_connection
Jun 12, 2024
f5ec141
GML-1660 remove print
Jun 12, 2024
e69505b
GML-1660 add jwtToken in config file and set it to empty.
Jun 12, 2024
39bca3a
GML-1660 add jwtToken in config file and set it to empty.
Jun 12, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions pyTigerGraph/pyTigerGraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,9 @@ def __init__(self, host: str = "http://127.0.0.1", graphname: str = "MyGraph",
tgCloud: bool = False, restppPort: Union[int, str] = "9000",
gsPort: Union[int, str] = "14240", gsqlVersion: str = "", version: str = "",
apiToken: str = "", useCert: bool = None, certPath: str = None, debug: bool = None,
sslPort: Union[int, str] = "443", gcp: bool = False):
sslPort: Union[int, str] = "443", gcp: bool = False, jwtToken: str = ""):
super().__init__(host, graphname, gsqlSecret, username, password, tgCloud, restppPort,
gsPort, gsqlVersion, version, apiToken, useCert, certPath, debug, sslPort, gcp)
gsPort, gsqlVersion, version, apiToken, useCert, certPath, debug, sslPort, gcp, jwtToken)

self.gds = None
self.ai = None
Expand Down
166 changes: 144 additions & 22 deletions pyTigerGraph/pyTigerGraphBase.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,10 @@
import json
import logging
import sys
import re
import warnings
from typing import Union
from urllib.parse import urlparse

import requests

from pyTigerGraph.pyTigerGraphException import TigerGraphException
Expand All @@ -36,7 +36,7 @@ def __init__(self, host: str = "http://127.0.0.1", graphname: str = "MyGraph",
tgCloud: bool = False, restppPort: Union[int, str] = "9000",
gsPort: Union[int, str] = "14240", gsqlVersion: str = "", version: str = "",
apiToken: str = "", useCert: bool = None, certPath: str = None, debug: bool = None,
sslPort: Union[int, str] = "443", gcp: bool = False):
sslPort: Union[int, str] = "443", gcp: bool = False, jwtToken: str = ""):
"""Initiate a connection object.

Args:
Expand Down Expand Up @@ -76,6 +76,8 @@ def __init__(self, host: str = "http://127.0.0.1", graphname: str = "MyGraph",
Port for fetching SSL certificate in case of firewall.
gcp:
DEPRECATED. Previously used for connecting to databases provisioned on GCP in TigerGraph Cloud.
jwtToken:
The JWT token generated from customer side for authentication

Raises:
TigerGraphException: In case on invalid URL scheme.
Expand All @@ -100,12 +102,19 @@ def __init__(self, host: str = "http://127.0.0.1", graphname: str = "MyGraph",
self.graphname = graphname
self.responseConfigHeader = {}
self.awsIamHeaders={}

self.jwtToken = jwtToken
self.apiToken = apiToken
self.base64_credential = base64.b64encode(
"{0}:{1}".format(self.username, self.password).encode("utf-8")).decode("utf-8")

self.authHeader = self._set_auth_header()

# TODO Remove apiToken parameter
if apiToken:
warnings.warn(
"The `apiToken` parameter is deprecated; use `getToken()` function instead.",
DeprecationWarning)
self.apiToken = apiToken

# TODO Eliminate version and use gsqlVersion only, meaning TigerGraph server version
if gsqlVersion:
Expand All @@ -117,12 +126,6 @@ def __init__(self, host: str = "http://127.0.0.1", graphname: str = "MyGraph",
self.version = version
else:
self.version = ""
self.base64_credential = base64.b64encode(
"{0}:{1}".format(self.username, self.password).encode("utf-8")).decode("utf-8")
if self.apiToken:
self.authHeader = {"Authorization": "Bearer " + self.apiToken}
else:
self.authHeader = {"Authorization": "Basic {0}".format(self.base64_credential)}

if debug is not None:
warnings.warn(
Expand Down Expand Up @@ -200,8 +203,39 @@ def __init__(self, host: str = "http://127.0.0.1", graphname: str = "MyGraph",
self.awsIamHeaders["X-Amz-Security-Token"] = request.headers["X-Amz-Security-Token"]
self.awsIamHeaders["Authorization"] = request.headers["Authorization"]

if self.jwtToken:
self._verify_jwt_token_support()

logger.info("exit: __init__")

def _set_auth_header(self):
"""Set the authentication header based on available tokens or credentials."""
if self.jwtToken:
return {"Authorization": "Bearer " + self.jwtToken}
elif self.apiToken:
return {"Authorization": "Bearer " + self.apiToken}
else:
return {"Authorization": "Basic {0}".format(self.base64_credential)}

def _verify_jwt_token_support(self):
try:
# Check JWT support for RestPP server
logger.debug("Attempting to verify JWT token support with getVer() on RestPP server.")
logger.debug(f"Using auth header: {self.authHeader}")
version = self.getVer()
# logger.info(f"Database version: {version}")

# Check JWT support for GSQL server
logger.debug(f"Attempting to get auth info with URL: {self.gsUrl + '/gsqlserver/gsql/simpleauth'}")
self._get(f"{self.gsUrl}/gsqlserver/gsql/simpleauth", authMode="token", resKey=None)
except requests.exceptions.ConnectionError as e:
logger.error(f"Connection error: {e}.")
raise RuntimeError(f"Connection error: {e}.") from e
except Exception as e:
message = "The JWT token might be invalid or expired or DB version doesn't support JWT token. Please generate new JWT token or switch to API token or username/password."
logger.error(f"Error occurred: {e}. {message}")
raise RuntimeError(message) from e

def _locals(self, _locals: dict) -> str:
del _locals["self"]
return str(_locals)
Expand Down Expand Up @@ -257,20 +291,30 @@ def _req(self, method: str, url: str, authMode: str = "token", headers: dict = N
if logger.level == logging.DEBUG:
logger.debug("params: " + self._locals(locals()))

if authMode == "token" and str(self.apiToken) != "":
if isinstance(self.apiToken, tuple):
self.apiToken = self.apiToken[0]
self.authHeader = {'Authorization': "Bearer " + self.apiToken}
_headers = self.authHeader
else:
self.authHeader = {'Authorization': 'Basic {0}'.format(self.base64_credential)}
_headers = self.authHeader
authMode = 'pwd'
# If JWT token is provided, always use jwtToken as token
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is jwtToken also used when authMode is pwd?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

currently no, still use pwd. in the init, if the db doesn't support jwt, it will raise an error asking users to use username and password.

if authMode == "token":
if isinstance(self.jwtToken, str) and self.jwtToken.strip() != "":
token = self.jwtToken
elif isinstance(self.apiToken, tuple):
token = self.apiToken[0]
elif isinstance(self.apiToken, str) and self.apiToken.strip() != "":
token = self.apiToken
else:
token = None

if token:
self.authHeader = {'Authorization': "Bearer " + token}
_headers = self.authHeader
else:
self.authHeader = {'Authorization': 'Basic {0}'.format(self.base64_credential)}
_headers = self.authHeader
authMode = 'pwd'

if authMode == "pwd":
_auth = (self.username, self.password)
else:
_auth = None
if self.jwtToken:
_headers = {'Authorization': "Bearer " + self.jwtToken}
else:
_headers = {'Authorization': 'Basic {0}'.format(self.base64_credential)}
if headers:
_headers.update(headers)
if self.awsIamHeaders:
Expand Down Expand Up @@ -448,4 +492,82 @@ def customizeHeader(self, timeout:int = 16_000, responseSize:int = 3.2e+7):
Returns:
Nothing. Sets `responseConfigHeader` class attribute.
"""
self.responseConfigHeader = {"GSQL-TIMEOUT": str(timeout), "RESPONSE-LIMIT": str(responseSize)}
self.responseConfigHeader = {"GSQL-TIMEOUT": str(timeout), "RESPONSE-LIMIT": str(responseSize)}

def getVersion(self, raw: bool = False) -> Union[str, list]:
"""Retrieves the git versions of all components of the system.

Args:
raw:
Return unprocessed version info string, or extract version info for each component
into a list.

Returns:
Either an unprocessed string containing the version info details, or a list with version
info for each component.

Endpoint:
- `GET /version`
See xref:tigergraph-server:API:built-in-endpoints.adoc#_show_component_versions[Show component versions]
"""
logger.info("entry: getVersion")
if logger.level == logging.DEBUG:
logger.debug("params: " + self._locals(locals()))

response = self._get(self.restppUrl+"/version", strictJson=False, resKey="message")

if raw:
return response
res = response.split("\n")
components = []
for i in range(len(res)):
if 2 < i < len(res) - 1:
m = res[i].split()
component = {"name": m[0], "version": m[1], "hash": m[2],
"datetime": m[3] + " " + m[4] + " " + m[5]}
components.append(component)

if logger.level == logging.DEBUG:
logger.debug("return: " + str(components))
logger.info("exit: getVersion")

return components

def getVer(self, component: str = "product", full: bool = False) -> str:
"""Gets the version information of a specific component.

Get the full list of components using `getVersion()`.

Args:
component:
One of TigerGraph's components (e.g. product, gpe, gse).
full:
Return the full version string (with timestamp, etc.) or just X.Y.Z.

Returns:
Version info for specified component.

Raises:
`TigerGraphException` if invalid/non-existent component is specified.
"""
logger.info("entry: getVer")
if logger.level == logging.DEBUG:
logger.debug("params: " + self._locals(locals()))

ret = ""
for v in self.getVersion():
if v["name"] == component.lower():
ret = v["version"]
if ret != "":
if full:
return ret
ret = re.search("_.+_", ret)
ret = ret.group().strip("_")

if logger.level == logging.DEBUG:
logger.debug("return: " + str(ret))
logger.info("exit: getVer")

return ret
else:
raise TigerGraphException("\"" + component + "\" is not a valid component.", None)
79 changes: 0 additions & 79 deletions pyTigerGraph/pyTigerGraphUtils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
"""
import json
import logging
import re
import urllib
from typing import Any, Union
from urllib.parse import urlparse
Expand Down Expand Up @@ -72,84 +71,6 @@ def echo(self, usePost: bool = False) -> str:

return ret

def getVersion(self, raw: bool = False) -> Union[str, list]:
"""Retrieves the git versions of all components of the system.

Args:
raw:
Return unprocessed version info string, or extract version info for each component
into a list.

Returns:
Either an unprocessed string containing the version info details, or a list with version
info for each component.

Endpoint:
- `GET /version`
See xref:tigergraph-server:API:built-in-endpoints.adoc#_show_component_versions[Show component versions]
"""
logger.info("entry: getVersion")
if logger.level == logging.DEBUG:
logger.debug("params: " + self._locals(locals()))

response = self._get(self.restppUrl+"/version", strictJson=False, resKey="message")

if raw:
return response
res = response.split("\n")
components = []
for i in range(len(res)):
if 2 < i < len(res) - 1:
m = res[i].split()
component = {"name": m[0], "version": m[1], "hash": m[2],
"datetime": m[3] + " " + m[4] + " " + m[5]}
components.append(component)

if logger.level == logging.DEBUG:
logger.debug("return: " + str(components))
logger.info("exit: getVersion")

return components

def getVer(self, component: str = "product", full: bool = False) -> str:
"""Gets the version information of a specific component.

Get the full list of components using `getVersion()`.

Args:
component:
One of TigerGraph's components (e.g. product, gpe, gse).
full:
Return the full version string (with timestamp, etc.) or just X.Y.Z.

Returns:
Version info for specified component.

Raises:
`TigerGraphException` if invalid/non-existent component is specified.
"""
logger.info("entry: getVer")
if logger.level == logging.DEBUG:
logger.debug("params: " + self._locals(locals()))

ret = ""
for v in self.getVersion():
if v["name"] == component.lower():
ret = v["version"]
if ret != "":
if full:
return ret
ret = re.search("_.+_", ret)
ret = ret.group().strip("_")

if logger.level == logging.DEBUG:
logger.debug("return: " + str(ret))
logger.info("exit: getVer")

return ret
else:
raise TigerGraphException("\"" + component + "\" is not a valid component.", None)

def getLicenseInfo(self) -> dict:
"""Returns the expiration date and remaining days of the license.

Expand Down
2 changes: 2 additions & 0 deletions tests/pyTigerGraphUnitTest.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ def make_connection(graphname: str = None):
"sslPort": "443",
"tgCloud": False,
"gcp": False,
"jwtToken": ""
}

path = os.path.dirname(os.path.realpath(__file__))
Expand All @@ -42,6 +43,7 @@ def make_connection(graphname: str = None):
certPath=server_config["certPath"],
sslPort=server_config["sslPort"],
gcp=server_config["gcp"],
jwtToken=server_config["jwtToken"]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we support this in our current testing pipeline?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, we need to modify the config file to add the jwtToken field but set it to empty str to avoid actually using it now.

Meanwhile, we need to figure out with QE team how to config the DB to use jwtToken

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just double check here: does server_config contains this key "jwtToken"? I don't see the config file modified.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure. I just add the jwtToken to the config file and set it to empty.

Just double check here: does server_config contains this key "jwtToken"? I don't see the config file modified.

)
if server_config.get("getToken", False):
conn.getToken(conn.createSecret())
Expand Down
Loading