diff --git a/pyTigerGraph/pyTigerGraph.py b/pyTigerGraph/pyTigerGraph.py index 3e3cbb4c..5c2e909b 100644 --- a/pyTigerGraph/pyTigerGraph.py +++ b/pyTigerGraph/pyTigerGraph.py @@ -31,9 +31,9 @@ def __init__(self, host: str = "http://127.0.0.1", graphname: str = "MyGraph", tgCloud: bool = False, restppPort: Union[int, str] = "9000", gsPort: Union[int, str] = "14240", gsqlVersion: str = "", version: str = "", apiToken: str = "", useCert: bool = None, certPath: str = None, debug: bool = None, - sslPort: Union[int, str] = "443", gcp: bool = False): + sslPort: Union[int, str] = "443", gcp: bool = False, jwtToken: str = ""): super().__init__(host, graphname, gsqlSecret, username, password, tgCloud, restppPort, - gsPort, gsqlVersion, version, apiToken, useCert, certPath, debug, sslPort, gcp) + gsPort, gsqlVersion, version, apiToken, useCert, certPath, debug, sslPort, gcp, jwtToken) self.gds = None self.ai = None diff --git a/pyTigerGraph/pyTigerGraphBase.py b/pyTigerGraph/pyTigerGraphBase.py index ae4874e1..b7743374 100644 --- a/pyTigerGraph/pyTigerGraphBase.py +++ b/pyTigerGraph/pyTigerGraphBase.py @@ -36,7 +36,7 @@ def __init__(self, host: str = "http://127.0.0.1", graphname: str = "MyGraph", tgCloud: bool = False, restppPort: Union[int, str] = "9000", gsPort: Union[int, str] = "14240", gsqlVersion: str = "", version: str = "", apiToken: str = "", useCert: bool = None, certPath: str = None, debug: bool = None, - sslPort: Union[int, str] = "443", gcp: bool = False): + sslPort: Union[int, str] = "443", gcp: bool = False, jwtToken: str = ""): """Initiate a connection object. Args: @@ -76,6 +76,8 @@ def __init__(self, host: str = "http://127.0.0.1", graphname: str = "MyGraph", Port for fetching SSL certificate in case of firewall. gcp: DEPRECATED. Previously used for connecting to databases provisioned on GCP in TigerGraph Cloud. + jwtToken: + The JWT token generated from customer side for authentication Raises: TigerGraphException: In case on invalid URL scheme. @@ -124,6 +126,12 @@ def __init__(self, host: str = "http://127.0.0.1", graphname: str = "MyGraph", else: self.authHeader = {"Authorization": "Basic {0}".format(self.base64_credential)} + # If JWT token is provided, set authMode to "token", and overwrite authMode = "pwd" for GSQL authentication as well + if jwtToken: + self.authMode = "token" + self.jwtToken = jwtToken + self.authHeader = {"Authorization": "Bearer " + self.jwtToken} + if debug is not None: warnings.warn( "The `debug` parameter is deprecated; configure standard logging in your app.", @@ -257,15 +265,23 @@ def _req(self, method: str, url: str, authMode: str = "token", headers: dict = N if logger.level == logging.DEBUG: logger.debug("params: " + self._locals(locals())) - if authMode == "token" and str(self.apiToken) != "": - if isinstance(self.apiToken, tuple): - self.apiToken = self.apiToken[0] - self.authHeader = {'Authorization': "Bearer " + self.apiToken} - _headers = self.authHeader - else: - self.authHeader = {'Authorization': 'Basic {0}'.format(self.base64_credential)} - _headers = self.authHeader - authMode = 'pwd' + if authMode == "token": + if isinstance(self.jwtToken, str) and self.jwtToken.strip() != "": + token = self.jwtToken + elif isinstance(self.apiToken, tuple): + token = self.apiToken[0] + elif isinstance(self.apiToken, str) and self.apiToken.strip() != "": + token = self.apiToken + else: + token = None + + if token is not None: + self.authHeader = {'Authorization': "Bearer " + token} + _headers = self.authHeader + else: + self.authHeader = {'Authorization': 'Basic {0}'.format(self.base64_credential)} + _headers = self.authHeader + authMode = 'pwd' if authMode == "pwd": _auth = (self.username, self.password)