Skip to content

Commit

Permalink
Merge pull request #27 from Stas-Hrytsyshyn/ds-2085_common_ssh_tunnel…
Browse files Browse the repository at this point in the history
…_class

[DS-2085] - Create common SSH tunnel component class
  • Loading branch information
Stas-Hrytsyshyn authored Mar 31, 2021
2 parents f9b2b0c + 96e0082 commit 3e926df
Show file tree
Hide file tree
Showing 6 changed files with 227 additions and 2 deletions.
1 change: 1 addition & 0 deletions panoply/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
from .datasource import *
from .sdk import *
from .ssh import SSHTunnel
2 changes: 1 addition & 1 deletion panoply/constants.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
__version__ = "2.0.2"
__version__ = "2.0.3"
__package_name__ = "panoply-python-sdk"
5 changes: 5 additions & 0 deletions panoply/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,8 @@ class PanoplyException(Exception):
def __init__(self, args=None, retryable=True):
super(PanoplyException, self).__init__(args)
self.retryable = retryable


class IncorrectParamError(Exception):
def __init__(self, msg: str = "Incorrect input parametr"):
super().__init__(msg)
158 changes: 158 additions & 0 deletions panoply/ssh.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,158 @@
"""
Module for storing SSH related stuff
"""
from typing import Dict
from paramiko import RSAKey
from sshtunnel import SSHTunnelForwarder
from io import StringIO

from .errors import IncorrectParamError


class SSHTunnel:
"""
General SSH tunnel class-component
for working with ssh in python Data Sources.
Arguments:
host (str):
Host where resource(database in most cases) is located.
Example:
default host will be: 127.0.0.1
port (int):
Port of resource you want to connect to.
Example:
default port for `mongo` will be: 27017
ssh_tunnel (dictionary):
General UI object with all required
information for connecting through ssh tunnel.
Structure:
active (bool):
Defines if we will use ssh tunnel for connecting.
if True -> using.
if False (by default) -> ignoring.
port (int):
SSH port, by default = 22.
host (str):
SSH host of remote server.
username (str):
Name of SSH user.
password (str):
Password for remote SSH server.
Empty by default.
privateKey (str):
String representation of private key
for connecting to ssh server.
* Here should not be any newline characters
platform_ssh (bool, True by default):
Flag that you can find in source object.
if False -> we will use python SSHTunnel logic.
if True -> we will use platform SSH.
How to Use:
1) tunnel = SSHTunnel('127.0.0.1', 27017, {...params}, False)
server = tunnel.server
... your code
server.stop() - important, don't forget to close the socket
2) with SSHTunnel('127.0.0.1', 27017, {..params}, False) as server:
... your code here - tunnel will be closed automatically
"""
def __init__(self, host: str, port: int,
ssh_tunnel: Dict, platform_ssh: bool = True):
self.host = host
self.port = port
self.tunnel = ssh_tunnel
self._server = self._get_server(platform_ssh)

@property
def server(self):
return self._server

@property
def port(self):
return self._port

@port.setter
def port(self, value):
if not isinstance(value, int):
raise IncorrectParamError("Port should be instance of `int`")
if value < 0 or value > 65535:
raise IncorrectParamError("Port should be in range [0: 65535]")

self._port = value

@property
def tunnel(self):
return self._tunnel

@tunnel.setter
def tunnel(self, value):
if not isinstance(value, dict):
raise IncorrectParamError("SSH tunnel should be `dict` object")

required_keys = [
"active", "host",
"username", "privateKey"
]

for key in required_keys:
if not value.get(key):
msg = f"SSH tunnel object should contain `{key}` property"
raise IncorrectParamError(msg)

if not value.get("active", False):
msg = "To use SSH tunnel connection, property `active` should be `True`" # noqa
raise IncorrectParamError(msg)

value["port"] = int(value.get("port", 22))
self._tunnel = value

def _get_server(self, platform_ssh: bool):
"""Method for getting and starting ssh server."""
if platform_ssh:
return None

pkey = RSAKey.from_private_key(StringIO(self.tunnel["privateKey"]))

server = SSHTunnelForwarder(
ssh_address_or_host=(self.tunnel["host"], self.tunnel["port"]),
ssh_username=self.tunnel["username"],
ssh_password=self.tunnel.get("password"),
ssh_pkey=pkey,
remote_bind_address=(self.host, self.port)
)
server.start()

return server

def __enter__(self):
return self._server

def __exit__(self, exc_type, exc_val, exc_tb):
if self._server:
self._server.stop()

def __str__(self):
return "SSH tunnel to {}, for user {}".format(
self.tunnel["host"], self.tunnel["username"]
)
4 changes: 3 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,9 @@
install_requires=[
"requests==2.21.0",
"oauth2client==4.1.1",
"backoff==1.10.0"
"backoff==1.10.0",
"sshtunnel==0.1.5",
"paramiko==2.7.2",
],
extras_require={
"test": [
Expand Down
59 changes: 59 additions & 0 deletions test.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from unittest import TestCase
from unittest.mock import patch, MagicMock
import panoply
import base64

Expand All @@ -19,3 +20,61 @@ def test_write(self):
sdk.write('table', {'data': 2})

self.assertEqual(sdk._buffer.qsize(), 2)


class TestSSHTunnel(TestCase):

MOCKED_TUNNEL_OBJECT = {
"active": True,
"username": "panoply-user",
"host": "233.23.11.223",
"port": 22,
"privateKey": "dasd2dsd"
}

class MockedSSHTunnel:
def __init__(self, kwargs):
self.server = kwargs

@staticmethod
def start():
print("server started")

@staticmethod
def stop():
print("server stopped")

def test_tunnel_with_incorrect_port(self):
mocked_message = "Port should be in range [0: 65535]"
try:
tunnel = panoply.SSHTunnel('127.0.0.1', -5, {}, False)
print(tunnel)
except panoply.errors.IncorrectParamError as err:
self.assertEqual(str(err), mocked_message)

def test_tunnel_with_incorrect_tunnel_object(self):
mocked_message = "SSH tunnel object should contain `active` property"
try:
tunnel = panoply.SSHTunnel('127.0.0.1', 27017, {"key": "v"}, False)
print(tunnel)
except panoply.errors.IncorrectParamError as err:
self.assertEqual(str(err), mocked_message)

def test_tunnel_with_incorrect_platform_flag(self):
tunnel = panoply.SSHTunnel(
'127.0.0.1', 27017, self.MOCKED_TUNNEL_OBJECT, True
)
self.assertIsNone(tunnel.server)

@patch("paramiko.RSAKey.from_private_key")
@patch("panoply.ssh.SSHTunnelForwarder")
def test_tunnel_ctxt_manager(self, SSHTunnelForwarder, from_private_key):
SSHTunnelForwarder.return_value = self.MockedSSHTunnel(
self.MOCKED_TUNNEL_OBJECT
)
SSHTunnelForwarder.start.return_value = self.MockedSSHTunnel.start
SSHTunnelForwarder.stop.return_value = self.MockedSSHTunnel.stop
from_private_key.return_value = self.MOCKED_TUNNEL_OBJECT["privateKey"]

with panoply.SSHTunnel("127.0.0.1", 22, self.MOCKED_TUNNEL_OBJECT, False) as tunnel: # noqa
self.assertEqual(tunnel.server, self.MOCKED_TUNNEL_OBJECT)

0 comments on commit 3e926df

Please sign in to comment.