diff --git a/pyTigerGraph/pyTigerGraph.py b/pyTigerGraph/pyTigerGraph.py index 3e3cbb4c..5c2e909b 100644 --- a/pyTigerGraph/pyTigerGraph.py +++ b/pyTigerGraph/pyTigerGraph.py @@ -31,9 +31,9 @@ def __init__(self, host: str = "http://127.0.0.1", graphname: str = "MyGraph", tgCloud: bool = False, restppPort: Union[int, str] = "9000", gsPort: Union[int, str] = "14240", gsqlVersion: str = "", version: str = "", apiToken: str = "", useCert: bool = None, certPath: str = None, debug: bool = None, - sslPort: Union[int, str] = "443", gcp: bool = False): + sslPort: Union[int, str] = "443", gcp: bool = False, jwtToken: str = ""): super().__init__(host, graphname, gsqlSecret, username, password, tgCloud, restppPort, - gsPort, gsqlVersion, version, apiToken, useCert, certPath, debug, sslPort, gcp) + gsPort, gsqlVersion, version, apiToken, useCert, certPath, debug, sslPort, gcp, jwtToken) self.gds = None self.ai = None diff --git a/pyTigerGraph/pyTigerGraphBase.py b/pyTigerGraph/pyTigerGraphBase.py index e4cc91ab..3e3cc27a 100644 --- a/pyTigerGraph/pyTigerGraphBase.py +++ b/pyTigerGraph/pyTigerGraphBase.py @@ -8,10 +8,10 @@ import json import logging import sys +import re import warnings from typing import Union from urllib.parse import urlparse - import requests from pyTigerGraph.pyTigerGraphException import TigerGraphException @@ -36,7 +36,7 @@ def __init__(self, host: str = "http://127.0.0.1", graphname: str = "MyGraph", tgCloud: bool = False, restppPort: Union[int, str] = "9000", gsPort: Union[int, str] = "14240", gsqlVersion: str = "", version: str = "", apiToken: str = "", useCert: bool = None, certPath: str = None, debug: bool = None, - sslPort: Union[int, str] = "443", gcp: bool = False): + sslPort: Union[int, str] = "443", gcp: bool = False, jwtToken: str = ""): """Initiate a connection object. Args: @@ -76,6 +76,8 @@ def __init__(self, host: str = "http://127.0.0.1", graphname: str = "MyGraph", Port for fetching SSL certificate in case of firewall. gcp: DEPRECATED. Previously used for connecting to databases provisioned on GCP in TigerGraph Cloud. + jwtToken: + The JWT token generated from customer side for authentication Raises: TigerGraphException: In case on invalid URL scheme. @@ -100,12 +102,19 @@ def __init__(self, host: str = "http://127.0.0.1", graphname: str = "MyGraph", self.graphname = graphname self.responseConfigHeader = {} self.awsIamHeaders={} + + self.jwtToken = jwtToken + self.apiToken = apiToken + self.base64_credential = base64.b64encode( + "{0}:{1}".format(self.username, self.password).encode("utf-8")).decode("utf-8") + + self.authHeader = self._set_auth_header() + # TODO Remove apiToken parameter if apiToken: warnings.warn( "The `apiToken` parameter is deprecated; use `getToken()` function instead.", DeprecationWarning) - self.apiToken = apiToken # TODO Eliminate version and use gsqlVersion only, meaning TigerGraph server version if gsqlVersion: @@ -117,12 +126,6 @@ def __init__(self, host: str = "http://127.0.0.1", graphname: str = "MyGraph", self.version = version else: self.version = "" - self.base64_credential = base64.b64encode( - "{0}:{1}".format(self.username, self.password).encode("utf-8")).decode("utf-8") - if self.apiToken: - self.authHeader = {"Authorization": "Bearer " + self.apiToken} - else: - self.authHeader = {"Authorization": "Basic {0}".format(self.base64_credential)} if debug is not None: warnings.warn( @@ -200,8 +203,39 @@ def __init__(self, host: str = "http://127.0.0.1", graphname: str = "MyGraph", self.awsIamHeaders["X-Amz-Security-Token"] = request.headers["X-Amz-Security-Token"] self.awsIamHeaders["Authorization"] = request.headers["Authorization"] + if self.jwtToken: + self._verify_jwt_token_support() + logger.info("exit: __init__") + def _set_auth_header(self): + """Set the authentication header based on available tokens or credentials.""" + if self.jwtToken: + return {"Authorization": "Bearer " + self.jwtToken} + elif self.apiToken: + return {"Authorization": "Bearer " + self.apiToken} + else: + return {"Authorization": "Basic {0}".format(self.base64_credential)} + + def _verify_jwt_token_support(self): + try: + # Check JWT support for RestPP server + logger.debug("Attempting to verify JWT token support with getVer() on RestPP server.") + logger.debug(f"Using auth header: {self.authHeader}") + version = self.getVer() + logger.info(f"Database version: {version}") + + # Check JWT support for GSQL server + logger.debug(f"Attempting to get auth info with URL: {self.gsUrl + '/gsqlserver/gsql/simpleauth'}") + self._get(f"{self.gsUrl}/gsqlserver/gsql/simpleauth", authMode="token", resKey=None) + except requests.exceptions.ConnectionError as e: + logger.error(f"Connection error: {e}.") + raise RuntimeError(f"Connection error: {e}.") from e + except Exception as e: + message = "The JWT token might be invalid or expired or DB version doesn't support JWT token. Please generate new JWT token or switch to API token or username/password." + logger.error(f"Error occurred: {e}. {message}") + raise RuntimeError(message) from e + def _locals(self, _locals: dict) -> str: del _locals["self"] return str(_locals) @@ -257,20 +291,31 @@ def _req(self, method: str, url: str, authMode: str = "token", headers: dict = N if logger.level == logging.DEBUG: logger.debug("params: " + self._locals(locals())) - if authMode == "token" and str(self.apiToken) != "": - if isinstance(self.apiToken, tuple): - self.apiToken = self.apiToken[0] - self.authHeader = {'Authorization': "Bearer " + self.apiToken} - _headers = self.authHeader - else: - self.authHeader = {'Authorization': 'Basic {0}'.format(self.base64_credential)} - _headers = self.authHeader - authMode = 'pwd' + # If JWT token is provided, always use jwtToken as token + if authMode == "token": + if isinstance(self.jwtToken, str) and self.jwtToken.strip() != "": + token = self.jwtToken + elif isinstance(self.apiToken, tuple): + token = self.apiToken[0] + elif isinstance(self.apiToken, str) and self.apiToken.strip() != "": + token = self.apiToken + else: + token = None + + if token: + self.authHeader = {'Authorization': "Bearer " + token} + _headers = self.authHeader + else: + self.authHeader = {'Authorization': 'Basic {0}'.format(self.base64_credential)} + _headers = self.authHeader + authMode = 'pwd' if authMode == "pwd": - _auth = (self.username, self.password) - else: - _auth = None + if self.jwtToken: + _headers = {'Authorization': "Bearer " + self.jwtToken} + else: + _headers = {'Authorization': 'Basic {0}'.format(self.base64_credential)} + if headers: _headers.update(headers) if self.awsIamHeaders: @@ -448,4 +493,82 @@ def customizeHeader(self, timeout:int = 16_000, responseSize:int = 3.2e+7): Returns: Nothing. Sets `responseConfigHeader` class attribute. """ - self.responseConfigHeader = {"GSQL-TIMEOUT": str(timeout), "RESPONSE-LIMIT": str(responseSize)} \ No newline at end of file + self.responseConfigHeader = {"GSQL-TIMEOUT": str(timeout), "RESPONSE-LIMIT": str(responseSize)} + + def getVersion(self, raw: bool = False) -> Union[str, list]: + """Retrieves the git versions of all components of the system. + + Args: + raw: + Return unprocessed version info string, or extract version info for each component + into a list. + + Returns: + Either an unprocessed string containing the version info details, or a list with version + info for each component. + + Endpoint: + - `GET /version` + See xref:tigergraph-server:API:built-in-endpoints.adoc#_show_component_versions[Show component versions] + """ + logger.info("entry: getVersion") + if logger.level == logging.DEBUG: + logger.debug("params: " + self._locals(locals())) + + response = self._get(self.restppUrl+"/version", strictJson=False, resKey="message") + + if raw: + return response + res = response.split("\n") + components = [] + for i in range(len(res)): + if 2 < i < len(res) - 1: + m = res[i].split() + component = {"name": m[0], "version": m[1], "hash": m[2], + "datetime": m[3] + " " + m[4] + " " + m[5]} + components.append(component) + + if logger.level == logging.DEBUG: + logger.debug("return: " + str(components)) + logger.info("exit: getVersion") + + return components + + def getVer(self, component: str = "product", full: bool = False) -> str: + """Gets the version information of a specific component. + + Get the full list of components using `getVersion()`. + + Args: + component: + One of TigerGraph's components (e.g. product, gpe, gse). + full: + Return the full version string (with timestamp, etc.) or just X.Y.Z. + + Returns: + Version info for specified component. + + Raises: + `TigerGraphException` if invalid/non-existent component is specified. + """ + logger.info("entry: getVer") + if logger.level == logging.DEBUG: + logger.debug("params: " + self._locals(locals())) + + ret = "" + for v in self.getVersion(): + if v["name"] == component.lower(): + ret = v["version"] + if ret != "": + if full: + return ret + ret = re.search("_.+_", ret) + ret = ret.group().strip("_") + + if logger.level == logging.DEBUG: + logger.debug("return: " + str(ret)) + logger.info("exit: getVer") + + return ret + else: + raise TigerGraphException("\"" + component + "\" is not a valid component.", None) \ No newline at end of file diff --git a/pyTigerGraph/pyTigerGraphUtils.py b/pyTigerGraph/pyTigerGraphUtils.py index 0a53353e..aaa28a7e 100644 --- a/pyTigerGraph/pyTigerGraphUtils.py +++ b/pyTigerGraph/pyTigerGraphUtils.py @@ -5,7 +5,6 @@ """ import json import logging -import re import urllib from typing import Any, Union from urllib.parse import urlparse @@ -72,84 +71,6 @@ def echo(self, usePost: bool = False) -> str: return ret - def getVersion(self, raw: bool = False) -> Union[str, list]: - """Retrieves the git versions of all components of the system. - - Args: - raw: - Return unprocessed version info string, or extract version info for each component - into a list. - - Returns: - Either an unprocessed string containing the version info details, or a list with version - info for each component. - - Endpoint: - - `GET /version` - See xref:tigergraph-server:API:built-in-endpoints.adoc#_show_component_versions[Show component versions] - """ - logger.info("entry: getVersion") - if logger.level == logging.DEBUG: - logger.debug("params: " + self._locals(locals())) - - response = self._get(self.restppUrl+"/version", strictJson=False, resKey="message") - - if raw: - return response - res = response.split("\n") - components = [] - for i in range(len(res)): - if 2 < i < len(res) - 1: - m = res[i].split() - component = {"name": m[0], "version": m[1], "hash": m[2], - "datetime": m[3] + " " + m[4] + " " + m[5]} - components.append(component) - - if logger.level == logging.DEBUG: - logger.debug("return: " + str(components)) - logger.info("exit: getVersion") - - return components - - def getVer(self, component: str = "product", full: bool = False) -> str: - """Gets the version information of a specific component. - - Get the full list of components using `getVersion()`. - - Args: - component: - One of TigerGraph's components (e.g. product, gpe, gse). - full: - Return the full version string (with timestamp, etc.) or just X.Y.Z. - - Returns: - Version info for specified component. - - Raises: - `TigerGraphException` if invalid/non-existent component is specified. - """ - logger.info("entry: getVer") - if logger.level == logging.DEBUG: - logger.debug("params: " + self._locals(locals())) - - ret = "" - for v in self.getVersion(): - if v["name"] == component.lower(): - ret = v["version"] - if ret != "": - if full: - return ret - ret = re.search("_.+_", ret) - ret = ret.group().strip("_") - - if logger.level == logging.DEBUG: - logger.debug("return: " + str(ret)) - logger.info("exit: getVer") - - return ret - else: - raise TigerGraphException("\"" + component + "\" is not a valid component.", None) - def getLicenseInfo(self) -> dict: """Returns the expiration date and remaining days of the license. diff --git a/tests/pyTigerGraphUnitTest.py b/tests/pyTigerGraphUnitTest.py index e62e8a30..cc59ceb8 100644 --- a/tests/pyTigerGraphUnitTest.py +++ b/tests/pyTigerGraphUnitTest.py @@ -20,6 +20,7 @@ def make_connection(graphname: str = None): "sslPort": "443", "tgCloud": False, "gcp": False, + "jwtToken": "" } path = os.path.dirname(os.path.realpath(__file__)) @@ -42,6 +43,7 @@ def make_connection(graphname: str = None): certPath=server_config["certPath"], sslPort=server_config["sslPort"], gcp=server_config["gcp"], + jwtToken=server_config["jwtToken"] ) if server_config.get("getToken", False): conn.getToken(conn.createSecret()) diff --git a/tests/test_jwtAuth.py b/tests/test_jwtAuth.py new file mode 100644 index 00000000..a5a49b5e --- /dev/null +++ b/tests/test_jwtAuth.py @@ -0,0 +1,90 @@ +import unittest +import requests +import json + +from pyTigerGraphUnitTest import make_connection +from pyTigerGraph import TigerGraphConnection + + +class TestJWTTokenAuth(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.conn = make_connection(graphname="Cora") + + + def test_jwtauth(self): + dbversion = self.conn.getVer() + + if "3.9" in str(dbversion): + self._test_jwtauth_3_9() + elif "4.1" in str(dbversion): + self._test_jwtauth_4_1_success() + self._test_jwtauth_4_1_fail() + else: + pass + + + def _requestJWTToken(self): + # Define the URL + url = f"{self.conn.host}:{self.conn.gsPort}/gsqlserver/requestjwttoken" + # Define the data payload + payload = json.dumps({"lifetime": "1000000000"}) + # Define the headers for the request + headers = { + 'Content-Type': 'application/json' + } + # Make the POST request with basic authentication + response = requests.post(url, data=payload, headers=headers, auth=(self.conn.username, self.conn.password)) + return response.json()['token'] + + + def _test_jwtauth_3_9(self): + with self.assertRaises(RuntimeError) as context: + TigerGraphConnection( + host=self.conn.host, + jwtToken="fake.JWT.Token" + ) + + # Verify the exception message + self.assertIn("switch to API token or username/password.", str(context.exception)) + + + def _test_jwtauth_4_1_success(self): + jwt_token = self._requestJWTToken() + + newconn = TigerGraphConnection( + host=self.conn.host, + jwtToken=jwt_token + ) + + authheader = newconn.authHeader + print (f"authheader from new conn: {authheader}") + + # restpp on port 9000 + dbversion = newconn.getVer() + print (f"dbversion from new conn: {dbversion}") + self.assertIn("4.1", str(dbversion)) + + # gsql on port 14240 + res = newconn._get(f"{self.conn.host}:{self.conn.gsPort}/gsqlserver/gsql/simpleauth", authMode="token", resKey=None) + self.assertIn("privileges", res) + + + def _test_jwtauth_4_1_fail(self): + with self.assertRaises(RuntimeError) as context: + TigerGraphConnection( + host=self.conn.host, + jwtToken="invalid.JWT.Token" + ) + + # Verify the exception message + self.assertIn("Please generate new JWT token", str(context.exception)) + + +if __name__ == '__main__': + # unittest.main() + suite = unittest.TestSuite() + suite.addTest(TestJWTTokenAuth("test_jwtauth")) + + runner = unittest.TextTestRunner(verbosity=2, failfast=True) + runner.run(suite) \ No newline at end of file diff --git a/tests/testserver.json b/tests/testserver.json index 0cc0c510..7e8cf115 100644 --- a/tests/testserver.json +++ b/tests/testserver.json @@ -4,5 +4,6 @@ "password": "tigergraph", "restppPort": "9000", "gsPort": "14240", - "getToken": true + "getToken": true, + "jwtToken": "" }