Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Integrate logging to application/features/. #38

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 27 additions & 4 deletions application/features/Audio.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,9 @@
from .Connection import Connection
from .. import app
from ..utils import find_free_port, get_headers_dict_from_str, local_auth
import logging.config

logger = logging.getLogger(__name__)

AUDIO_CONNECTIONS = {}

Expand All @@ -57,13 +60,16 @@ def __del__(self):
super().__del__()

def connect(self, *args, **kwargs):
logger.debug("Audio: Establishing Audio connection")
return super().connect(*args, **kwargs)

def launch_audio(self):
try:
logger.debug("Audio: Launching Audio connection. Forwarding request to 127.0.0.1, port 0.")
self.transport = self.client.get_transport()
self.remote_port = self.transport.request_port_forward('127.0.0.1', 0)
except Exception as e:
logger.warning("Audio: exception raised during launch audio: %s", e)
return False, str(e)

self.id = uuid.uuid4().hex
Expand All @@ -83,11 +89,12 @@ def handleConnected(self):
headers = get_headers_dict_from_str(headers)
if not local_auth(headers=headers, abort_func=self.close):
# local auth failure
logger.warning("AudioWebSocket: Local Authentication Failure")
return

audio_id = self.request.path[1:]
if audio_id not in AUDIO_CONNECTIONS:
print(f'AudioWebSocket: Requested audio_id={audio_id} does not exist.')
logger.warning("AudioWebSocket: Requested audio_id=%s does not exist.", audio_id)
self.close()
return

Expand All @@ -103,26 +110,35 @@ def handleConnected(self):
f'module-null-sink sink_name={sink_name} '
exit_status, _, stdout, _ = self.audio.exec_command_blocking(load_module_command)
if exit_status != 0:
print(f'AudioWebSocket: audio_id={audio_id}: unable to load pactl module-null-sink sink_name={sink_name}')
logger.warning(
"AudioWebSocket: audio_id=%s: unable to load pactl module-null-sink sink_name=%s",
audio_id,
sink_name
)
IreneLime marked this conversation as resolved.
Show resolved Hide resolved
return
load_module_stdout_lines = stdout.readlines()
logger.debug("AudioWebSocket: Load Module: %s", load_module_stdout_lines)
self.module_id = int(load_module_stdout_lines[0])

keep_launching_ffmpeg = True

def ffmpeg_launcher():
logger.debug("AudioWebSocket: ffmpeg_launcher thread started")
# TODO: support requesting audio format from the client
launch_ffmpeg_command = f'killall ffmpeg; ffmpeg -f pulse -i "{sink_name}.monitor" ' \
f'-ac 2 -acodec pcm_s16le -ar 44100 -f s16le "tcp://127.0.0.1:{self.audio.remote_port}"'
# keep launching if the connection is not accepted in the writer() below
while keep_launching_ffmpeg:
logger.debug("AudioWebSocket: Launch ffmpeg: %s", launch_ffmpeg_command)
_, ffmpeg_stdout, _ = self.audio.client.exec_command(launch_ffmpeg_command)
ffmpeg_stdout.channel.recv_exit_status()
# if `ffmpeg` launches successfully, `ffmpeg_stdout.channel.recv_exit_status` should not return
logger.debug("AudioWebSocket: ffmpeg_launcher thread ended")

ffmpeg_launcher_thread = threading.Thread(target=ffmpeg_launcher)

def writer():
logger.debug("AudioWebSocket: writer thread started")
channel = self.audio.transport.accept(FFMPEG_LOAD_TIME * TRY_FFMPEG_MAX_COUNT)

nonlocal keep_launching_ffmpeg
Expand All @@ -138,14 +154,17 @@ def writer():
while True:
data = channel.recv(AUDIO_BUFFER_SIZE)
if not data:
logger.debug("AudioWebSocket: Close audio socket connection")
self.close()
break
buffer += data
if len(buffer) >= AUDIO_BUFFER_SIZE:
compressed = zlib.compress(buffer, level=4)
logger.debug("AudioWebSocket: Send compressed message of size %s", len(compressed))
self.sendMessage(compressed)
# print(len(compressed) / len(buffer) * 100)
buffer = b''
logger.debug("AudioWebSocket: write thread ended")

writer_thread = threading.Thread(target=writer)

Expand All @@ -155,8 +174,10 @@ def writer():
def handleClose(self):
if self.module_id is not None:
# unload the module before leaving
logger.debug("AudioWebSocket: Unload module %s", self.module_id)
self.audio.client.exec_command(f'pactl unload-module {self.module_id}')

logger.debug("AudioWebSocket: End audio socket %s connection", self.audio.id)
del AUDIO_CONNECTIONS[self.audio.id]
del self.audio

Expand All @@ -166,18 +187,20 @@ def handleClose(self):
# if we are in debug mode, run the server in the second round
if not app.debug or os.environ.get("WERKZEUG_RUN_MAIN") == "true":
AUDIO_PORT = find_free_port()
print("AUDIO_PORT =", AUDIO_PORT)
logger.debug("Audio: Audio port %s", AUDIO_PORT)

if os.environ.get('SSL_CERT_PATH') is None:
logger.debug("Audio: SSL Certification Path not set. Generating self-signing certificate")
# no certificate provided, generate self-signing certificate
audio_server = SimpleSSLWebSocketServer('127.0.0.1', AUDIO_PORT, AudioWebSocket,
ssl_context=generate_adhoc_ssl_context())
else:
logger.debug("Audio: SSL Certification Path exists")
import ssl

audio_server = SimpleSSLWebSocketServer('0.0.0.0', AUDIO_PORT, AudioWebSocket,
certfile=os.environ.get('SSL_CERT_PATH'),
keyfile=os.environ.get('SSL_KEY_PATH'),
version=ssl.PROTOCOL_TLS)

threading.Thread(target=audio_server.serveforever, daemon=True).start()
threading.Thread(target=audio_server.serveforever, daemon=True).start()
IreneLime marked this conversation as resolved.
Show resolved Hide resolved
32 changes: 31 additions & 1 deletion application/features/Connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,15 @@
import paramiko
import select

import logging.config

logger = logging.getLogger(__name__)

class ForwardServerHandler(socketserver.BaseRequestHandler):
def handle(self):
junhaoliao marked this conversation as resolved.
Show resolved Hide resolved
self.server: ForwardServer
try:
logger.debug("Connection: Open forward server channel")
chan = self.server.ssh_transport.open_channel(
"direct-tcpip",
("127.0.0.1", self.server.chain_port),
Expand All @@ -49,6 +53,12 @@ def handle(self):
("127.0.0.1", self.server.chain_port),
)
)
logger.debug(
"Connected! Tunnel open %r -> %r -> %r",
self.request.getpeername(),
chan.getpeername(),
("127.0.0.1", self.server.chain_port),
)

try:
while True:
Expand All @@ -67,6 +77,7 @@ def handle(self):
print(e)

try:
logger.debug("Connection: Close forward server channel")
chan.close()
self.server.shutdown()
except Exception as e:
Expand Down Expand Up @@ -102,6 +113,9 @@ def __del__(self):
def _client_connect(self, client: paramiko.SSHClient,
host, username,
password=None, key_filename=None, private_key_str=None):
if self._jump_channel is not None:
logger.debug("Connection: Connection initialized through Jump Channel")
logger.debug("Connection: Connecting to %s@%s", username, host)
if password is not None:
client.connect(host, username=username, password=password, timeout=15, sock=self._jump_channel)
elif key_filename is not None:
Expand All @@ -128,13 +142,16 @@ def _init_jump_channel(self, host, username, **auth_methods):

self._jump_client = paramiko.SSHClient()
self._jump_client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
logger.debug("Connection: Initialize Jump Client for connection to %[email protected]", username)
self._client_connect(self._jump_client, 'remote.ecf.utoronto.ca', username, **auth_methods)
logger.debug("Connection: Open Jump channel connection to %s at port 22", host)
self._jump_channel = self._jump_client.get_transport().open_channel('direct-tcpip',
(host, 22),
('127.0.0.1', 22))

def connect(self, host: str, username: str, **auth_methods):
junhaoliao marked this conversation as resolved.
Show resolved Hide resolved
try:
logger.debug("Connection: Connection attempt to %s@%s", username, host)
self._init_jump_channel(host, username, **auth_methods)
self._client_connect(self.client, host, username, **auth_methods)
except Exception as e:
Expand All @@ -145,6 +162,7 @@ def connect(self, host: str, username: str, **auth_methods):
self.host = host
self.username = username

logger.debug("Connection: Successfully connected to %s@%s", username, host)
return True, ''

@staticmethod
Expand All @@ -160,9 +178,11 @@ def ssh_keygen(key_filename=None, key_file_obj=None, public_key_comment=''):

# save the private key
if key_filename is not None:
logger.debug("Connection: RSA SSH private key written to %s", key_filename)
rsa_key.write_private_key_file(key_filename)
elif key_file_obj is not None:
rsa_key.write_private_key(key_file_obj)
logger.debug("Connection: RSA SSH private key written to %s", key_file_obj)
else:
raise ValueError('Neither key_filename nor key_file_obj is provided.')

Expand Down Expand Up @@ -192,6 +212,7 @@ def save_keys(self, key_filename=None, key_file_obj=None, public_key_comment='')
"mkdir -p ~/.ssh && chmod 700 ~/.ssh && echo '%s' >> ~/.ssh/authorized_keys" % pub_key)
if exit_status != 0:
return False, "Connection::save_keys: unable to save public key; Check for disk quota and permissions with any conventional SSH clients. "
logger.debug("Connection: Public ssh key saved to remove server ~/.ssh/authorized_keys")

return True, ""

Expand All @@ -217,22 +238,28 @@ def exec_command_blocking_large(self, command):
return '\n'.join(stdout) + '\n' + '\n'.join(stderr)

def _port_forward_thread(self, local_port, remote_port):
logger.debug("Connection: Port forward thread started")
forward_server = ForwardServer(("", local_port), ForwardServerHandler)

forward_server.ssh_transport = self.client.get_transport()
forward_server.chain_port = remote_port

forward_server.serve_forever()
forward_server.server_close()
logger.debug("Connection: Port forward thread ended")

def port_forward(self, *args):
forwarding_thread = threading.Thread(target=self._port_forward_thread, args=args)
forwarding_thread.start()

def is_eecg(self):
if 'eecg' in self.host:
logger.debug("Connection: Target host is eecg")
return 'eecg' in self.host

def is_ecf(self):
if 'ecf' in self.host:
logger.debug("Connection: Target host is ecf")
return 'ecf' in self.host

def is_uoft(self):
Expand All @@ -256,6 +283,9 @@ def is_load_high(self):

my_pts_count = len(output) - 1 # -1: excluding the `uptime` output

logger.debug("Connection: pts count: %s; my pts count: %s", pts_count, my_pts_count)
logger.debug("Connection: load sum: %s", load_sum)
IreneLime marked this conversation as resolved.
Show resolved Hide resolved

if pts_count > my_pts_count: # there are more terminals than mine
return True
elif load_sum > 1.0:
Expand All @@ -265,4 +295,4 @@ def is_load_high(self):
# it is considered a high load
return True

return False
return False
IreneLime marked this conversation as resolved.
Show resolved Hide resolved
20 changes: 16 additions & 4 deletions application/features/SFTP.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,9 @@
from paramiko.sftp_client import SFTPClient

from .Connection import Connection
import logging.config
IreneLime marked this conversation as resolved.
Show resolved Hide resolved

logger = logging.getLogger(__name__)

class SFTP(Connection):
def __init__(self):
Expand All @@ -41,11 +43,13 @@ def __del__(self):
super().__del__()

def connect(self, *args, **kwargs):
logger.debug("SFTP: Establishing SFTP connection")
status, reason = super().connect(*args, **kwargs)
if not status:
return status, reason

try:
logger.debug("SFTP: Open SFTP client connection")
IreneLime marked this conversation as resolved.
Show resolved Hide resolved
self.sftp = self.client.open_sftp()
self.sftp.chdir(".")
except Exception as e:
Expand All @@ -59,6 +63,7 @@ def ls(self, path=""):
self.sftp.chdir(path)
cwd = self.sftp.getcwd()
attrs = self.sftp.listdir_attr(cwd)
logger.debug("SFTP: ls %s: %s", cwd, attrs)
IreneLime marked this conversation as resolved.
Show resolved Hide resolved

file_list = []
# TODO: should support uid and gid later
Expand Down Expand Up @@ -99,10 +104,10 @@ def _zip_dir_recurse(self, z, parent, file):
try:
mode = self.sftp.stat(fullpath).st_mode
if stat.S_ISREG(mode):
# print(fullpath, 'is file')
logger.debug("SFTP: %s is a file", fullpath)
z.write_iter(fullpath, self.dl_generator(fullpath))
elif stat.S_ISDIR(mode):
# print(fullpath, 'is dir')
logger.debug("SFTP: %s is a directory", fullpath)
# TODO: support writing an empty directory if len(dir_ls)==0
# That will involve modifying the zipstream library
dir_ls = self.sftp.listdir(fullpath)
Expand All @@ -116,10 +121,12 @@ def _zip_dir_recurse(self, z, parent, file):
return

def zip_generator(self, cwd, file_list):
logger.debug("SFTP: zip_generator on directory: %s", cwd)
self.sftp.chdir(cwd)
z = zipstream.ZipFile(compression=zipstream.ZIP_DEFLATED, allowZip64=True)

for file in file_list:
logger.debug("SFTP: zip_generator on file: %s", file)
self._zip_dir_recurse(z, '', file)

return z
Expand All @@ -128,6 +135,7 @@ def rename(self, cwd, old, new):
try:
self.sftp.chdir(cwd)
self.sftp.rename(old, new)
logger.debug("SFTP: Rename %s in directory %s to %s", old, cwd, new)
except Exception as e:
return False, repr(e)

Expand All @@ -136,9 +144,10 @@ def rename(self, cwd, old, new):
def chmod(self, path, mode, recursive):
_, _, _, stderr = self.exec_command_blocking(
f'chmod {"-R" if recursive else ""} {"{0:0{1}o}".format(mode, 3)} "{path}"')
logger.debug("SFTP: Change permission on %s to '%s'", path, "{0:0{1}o}".format(mode, 3))
stderr_lines = stderr.readlines()
if len(stderr_lines) != 0:
print(stderr_lines)
logger.warning("SFTP: chmod failed due to %s", stderr_lines)
IreneLime marked this conversation as resolved.
Show resolved Hide resolved
IreneLime marked this conversation as resolved.
Show resolved Hide resolved
return False, 'Some files were not applied with the request mode due to permission issues.'

return True, ''
Expand All @@ -159,6 +168,7 @@ def rm(self, cwd, file_list):

counter += 1
if counter == 50:
logger.debug("SFTP: Execute Command %s", ' '.join(cmd_list))
IreneLime marked this conversation as resolved.
Show resolved Hide resolved
_, _, stderr = self.client.exec_command(" ".join(cmd_list))
stderr_lines = stderr.readlines()
if len(stderr_lines) != 0:
Expand All @@ -169,6 +179,7 @@ def rm(self, cwd, file_list):
counter = 0
cmd_list = [f'cd "{cwd}" && rm -rf']

logger.debug("SFTP: Execute Command %s", ' '.join(cmd_list))
_, _, stderr = self.client.exec_command(" ".join(cmd_list))
stderr_lines = stderr.readlines()
if len(stderr_lines) != 0:
Expand All @@ -180,8 +191,9 @@ def rm(self, cwd, file_list):
return True, ''

def mkdir(self, cwd, name):
logger.debug("SFTP: Make directory %s at %s", name, cwd)
IreneLime marked this conversation as resolved.
Show resolved Hide resolved
_, _, _, stderr = self.exec_command_blocking(f'cd "{cwd}"&& mkdir "{name}"')
stderr_lines = stderr.readlines()
if len(stderr_lines) != 0:
return False, stderr_lines[0]
return True, ''
return True, ''
IreneLime marked this conversation as resolved.
Show resolved Hide resolved
Loading
Loading