Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Integrate logging to application/features/. #38

Open
wants to merge 10 commits into
base: main
Choose a base branch
from
31 changes: 27 additions & 4 deletions application/features/Audio.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,9 @@
from .Connection import Connection
from .. import app
from ..utils import find_free_port, get_headers_dict_from_str, local_auth
import logging

logger = logging.getLogger(__name__)

AUDIO_CONNECTIONS = {}

Expand All @@ -57,13 +60,16 @@ def __del__(self):
super().__del__()

def connect(self, *args, **kwargs):
logger.debug("Audio: Establishing Audio connection")
return super().connect(*args, **kwargs)

def launch_audio(self):
try:
logger.debug("Audio: Launching Audio connection.")
self.transport = self.client.get_transport()
self.remote_port = self.transport.request_port_forward('127.0.0.1', 0)
except Exception as e:
logger.exception("Audio: exception raised during launch audio")
return False, str(e)

self.id = uuid.uuid4().hex
Expand All @@ -83,11 +89,12 @@ def handleConnected(self):
headers = get_headers_dict_from_str(headers)
if not local_auth(headers=headers, abort_func=self.close):
# local auth failure
logger.warning("AudioWebSocket: Local Authentication Failure")
IreneLime marked this conversation as resolved.
Show resolved Hide resolved
return

audio_id = self.request.path[1:]
if audio_id not in AUDIO_CONNECTIONS:
print(f'AudioWebSocket: Requested audio_id={audio_id} does not exist.')
logger.warning("AudioWebSocket: Requested audio_id=%s does not exist", audio_id)
self.close()
return

Expand All @@ -103,26 +110,35 @@ def handleConnected(self):
f'module-null-sink sink_name={sink_name} '
exit_status, _, stdout, _ = self.audio.exec_command_blocking(load_module_command)
if exit_status != 0:
print(f'AudioWebSocket: audio_id={audio_id}: unable to load pactl module-null-sink sink_name={sink_name}')
logger.warning(
"AudioWebSocket: audio_id=%s: unable to load pactl module-null-sink sink_name=%s",
audio_id,
sink_name
)
IreneLime marked this conversation as resolved.
Show resolved Hide resolved
return
load_module_stdout_lines = stdout.readlines()
logger.debug("AudioWebSocket: Load Module: %s", load_module_stdout_lines)
self.module_id = int(load_module_stdout_lines[0])

keep_launching_ffmpeg = True

def ffmpeg_launcher():
logger.debug("AudioWebSocket: ffmpeg_launcher thread started")
# TODO: support requesting audio format from the client
launch_ffmpeg_command = f'killall ffmpeg; ffmpeg -f pulse -i "{sink_name}.monitor" ' \
f'-ac 2 -acodec pcm_s16le -ar 44100 -f s16le "tcp://127.0.0.1:{self.audio.remote_port}"'
# keep launching if the connection is not accepted in the writer() below
while keep_launching_ffmpeg:
logger.debug("AudioWebSocket: Launch ffmpeg: %s", launch_ffmpeg_command)
_, ffmpeg_stdout, _ = self.audio.client.exec_command(launch_ffmpeg_command)
ffmpeg_stdout.channel.recv_exit_status()
# if `ffmpeg` launches successfully, `ffmpeg_stdout.channel.recv_exit_status` should not return
logger.debug("AudioWebSocket: ffmpeg_launcher thread ended")

ffmpeg_launcher_thread = threading.Thread(target=ffmpeg_launcher)

def writer():
logger.debug("AudioWebSocket: writer thread started")
channel = self.audio.transport.accept(FFMPEG_LOAD_TIME * TRY_FFMPEG_MAX_COUNT)

nonlocal keep_launching_ffmpeg
Expand All @@ -138,14 +154,17 @@ def writer():
while True:
data = channel.recv(AUDIO_BUFFER_SIZE)
if not data:
logger.debug("AudioWebSocket: Close audio socket connection")
self.close()
break
buffer += data
if len(buffer) >= AUDIO_BUFFER_SIZE:
compressed = zlib.compress(buffer, level=4)
logger.debug("AudioWebSocket: Send compressed message of size %d", compressed)
self.sendMessage(compressed)
# print(len(compressed) / len(buffer) * 100)
logger.debug("Audio: Audio port %s", AUDIO_PORT)
buffer = b''
logger.debug("AudioWebSocket: write thread ended")

writer_thread = threading.Thread(target=writer)

Expand All @@ -155,8 +174,10 @@ def writer():
def handleClose(self):
if self.module_id is not None:
# unload the module before leaving
logger.debug("AudioWebSocket: Unload module %d", self.module_id)
self.audio.client.exec_command(f'pactl unload-module {self.module_id}')

logger.debug("AudioWebSocket: End audio socket %s connection", self.audio.id)
del AUDIO_CONNECTIONS[self.audio.id]
del self.audio

Expand All @@ -166,13 +187,15 @@ def handleClose(self):
# if we are in debug mode, run the server in the second round
if not app.debug or os.environ.get("WERKZEUG_RUN_MAIN") == "true":
AUDIO_PORT = find_free_port()
print("AUDIO_PORT =", AUDIO_PORT)
logger.debug("Audio: Audio port %s", AUDIO_PORT)
IreneLime marked this conversation as resolved.
Show resolved Hide resolved

if os.environ.get('SSL_CERT_PATH') is None:
logger.debug("Audio: SSL Certification Path not set. Generating self-signing certificate")
# no certificate provided, generate self-signing certificate
audio_server = SimpleSSLWebSocketServer('127.0.0.1', AUDIO_PORT, AudioWebSocket,
ssl_context=generate_adhoc_ssl_context())
else:
logger.debug("Audio: SSL Certification Path exists")
import ssl

audio_server = SimpleSSLWebSocketServer('0.0.0.0', AUDIO_PORT, AudioWebSocket,
Expand Down
47 changes: 36 additions & 11 deletions application/features/Connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,27 +27,30 @@
import paramiko
import select

import logging

logger = logging.getLogger(__name__)

class ForwardServerHandler(socketserver.BaseRequestHandler):
def handle(self):
junhaoliao marked this conversation as resolved.
Show resolved Hide resolved
self.server: ForwardServer
try:
logger.debug("Connection: Open forward server channel")
chan = self.server.ssh_transport.open_channel(
"direct-tcpip",
("127.0.0.1", self.server.chain_port),
self.request.getpeername(),
)
except Exception as e:
logger.exception("Connection: Incoming request to 127.0.0.1:%d failed", self.server.chain_port)
return False, "Incoming request to %s:%d failed: %s" % (
"127.0.0.1", self.server.chain_port, repr(e))

print(
"Connected! Tunnel open %r -> %r -> %r"
% (
self.request.getpeername(),
chan.getpeername(),
("127.0.0.1", self.server.chain_port),
)
logger.info(
"Connected! Tunnel open %r -> %r -> %r",
self.request.getpeername(),
chan.getpeername(),
("127.0.0.1", self.server.chain_port),
)

try:
Expand All @@ -64,13 +67,15 @@ def handle(self):
break
self.request.send(data)
except Exception as e:
print(e)
logger.exception("Connection: Error occurred during data transfer")

try:
logger.debug("Connection: Close forward server channel")
chan.close()
self.server.shutdown()
except Exception as e:
print(e)
IreneLime marked this conversation as resolved.
Show resolved Hide resolved
logger.exception("Connection: Close forward server channel failed")


class ForwardServer(socketserver.ThreadingTCPServer):
Expand Down Expand Up @@ -102,6 +107,9 @@ def __del__(self):
def _client_connect(self, client: paramiko.SSHClient,
host, username,
password=None, key_filename=None, private_key_str=None):
if self._jump_channel is not None:
logger.debug("Connection: Connection initialized through Jump Channel")
logger.debug("Connection: Connecting to %s@%s", username, host)
if password is not None:
client.connect(host, username=username, password=password, timeout=15, sock=self._jump_channel)
elif key_filename is not None:
Expand All @@ -128,23 +136,26 @@ def _init_jump_channel(self, host, username, **auth_methods):

self._jump_client = paramiko.SSHClient()
self._jump_client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
logger.debug("Connection: Initialize Jump Client for connection to %[email protected]", username)
self._client_connect(self._jump_client, 'remote.ecf.utoronto.ca', username, **auth_methods)
logger.debug("Connection: Open Jump channel connection to %s at port 22", host)
self._jump_channel = self._jump_client.get_transport().open_channel('direct-tcpip',
(host, 22),
('127.0.0.1', 22))

def connect(self, host: str, username: str, **auth_methods):
junhaoliao marked this conversation as resolved.
Show resolved Hide resolved
try:
logger.debug("Connection: Connection attempt to %s@%s", username, host)
self._init_jump_channel(host, username, **auth_methods)
self._client_connect(self.client, host, username, **auth_methods)
except Exception as e:
# raise e
# print('Connection::connect() exception:')
logger.exception("Connection: Connection attempt to %s@%s failed", username, host)
return False, str(e)

self.host = host
self.username = username

logger.debug("Connection: Successfully connected to %s@%s", username, host)
return True, ''

@staticmethod
Expand All @@ -160,9 +171,11 @@ def ssh_keygen(key_filename=None, key_file_obj=None, public_key_comment=''):

# save the private key
if key_filename is not None:
logger.debug("Connection: RSA SSH private key written to %s", key_filename)
rsa_key.write_private_key_file(key_filename)
elif key_file_obj is not None:
rsa_key.write_private_key(key_file_obj)
logger.debug("Connection: RSA SSH private key written to %s", key_file_obj)
else:
raise ValueError('Neither key_filename nor key_file_obj is provided.')

Expand Down Expand Up @@ -192,6 +205,7 @@ def save_keys(self, key_filename=None, key_file_obj=None, public_key_comment='')
"mkdir -p ~/.ssh && chmod 700 ~/.ssh && echo '%s' >> ~/.ssh/authorized_keys" % pub_key)
if exit_status != 0:
return False, "Connection::save_keys: unable to save public key; Check for disk quota and permissions with any conventional SSH clients. "
logger.debug("Connection: Public ssh key saved to remove server ~/.ssh/authorized_keys")

return True, ""

Expand All @@ -217,22 +231,30 @@ def exec_command_blocking_large(self, command):
return '\n'.join(stdout) + '\n' + '\n'.join(stderr)

def _port_forward_thread(self, local_port, remote_port):
logger.debug("Connection: Port forward thread started")
forward_server = ForwardServer(("", local_port), ForwardServerHandler)

forward_server.ssh_transport = self.client.get_transport()
forward_server.chain_port = remote_port

forward_server.serve_forever()
forward_server.server_close()
logger.debug("Connection: Port forward thread ended")

def port_forward(self, *args):
forwarding_thread = threading.Thread(target=self._port_forward_thread, args=args)
forwarding_thread.start()

def is_eecg(self):
return 'eecg' in self.host
if 'eecg' in self.host:
logger.debug("Connection: Target host is eecg")
return True

return False

def is_ecf(self):
if 'ecf' in self.host:
logger.debug("Connection: Target host is ecf")
return 'ecf' in self.host
IreneLime marked this conversation as resolved.
Show resolved Hide resolved

def is_uoft(self):
Expand All @@ -256,6 +278,9 @@ def is_load_high(self):

my_pts_count = len(output) - 1 # -1: excluding the `uptime` output

logger.debug("Connection: pts count: %d; my pts count: %d", pts_count, my_pts_count)
logger.debug("Connection: load sum: %d", load_sum)

if pts_count > my_pts_count: # there are more terminals than mine
return True
elif load_sum > 1.0:
Expand Down
Loading
Loading