From 143e859037ddf7925e96f8f1d106d6c54efc643b Mon Sep 17 00:00:00 2001 From: doron zarhi Date: Mon, 4 Apr 2022 16:11:07 +0300 Subject: [PATCH] client: make client thread-safe --- src/rpcclient/rpcclient/client.py | 111 +++++++++++++--------- src/rpcclient/rpcclient/client_factory.py | 2 +- src/rpcclient/rpcclient/protocol.py | 6 +- src/rpcclient/tests/test_thread_safe.py | 31 ++++++ 4 files changed, 99 insertions(+), 51 deletions(-) create mode 100644 src/rpcclient/tests/test_thread_safe.py diff --git a/src/rpcclient/rpcclient/client.py b/src/rpcclient/rpcclient/client.py index bb4e9769..dbd8ae3b 100644 --- a/src/rpcclient/rpcclient/client.py +++ b/src/rpcclient/rpcclient/client.py @@ -4,6 +4,7 @@ import logging import os import sys +import threading import typing from collections import namedtuple from enum import Enum @@ -71,6 +72,7 @@ def __init__(self, sock, sysname: str, arch: arch_t, hostname: str, port: int = self._endianness = '<' self._sysname = sysname self._dlsym_global_handle = -1 # RTLD_NEXT + self._lock = threading.Lock() # whether the system uses inode structs of 64 bits self.inode64 = False @@ -112,9 +114,10 @@ def dlopen(self, filename: str, mode: int) -> Symbol: 'cmd_type': cmd_type_t.CMD_DLOPEN, 'data': {'filename': filename, 'mode': mode}, }) - self._sock.sendall(message) - err = Int64sl.parse(self._recvall(Int64sl.sizeof())) - return self.symbol(err) + with self._lock: + self._sock.sendall(message) + address = Int64sl.parse(self._recvall(Int64sl.sizeof())) + return self.symbol(address) def dlclose(self, lib: int): """ call dlclose() at remote and return its handle. see the man page for more details. """ @@ -123,8 +126,9 @@ def dlclose(self, lib: int): 'cmd_type': cmd_type_t.CMD_DLCLOSE, 'data': {'lib': lib}, }) - self._sock.sendall(message) - err = Int64sl.parse(self._recvall(Int64sl.sizeof())) + with self._lock: + self._sock.sendall(message) + err = Int64sl.parse(self._recvall(Int64sl.sizeof())) return err def dlsym(self, lib: int, symbol_name: str): @@ -134,9 +138,10 @@ def dlsym(self, lib: int, symbol_name: str): 'cmd_type': cmd_type_t.CMD_DLSYM, 'data': {'lib': lib, 'symbol_name': symbol_name}, }) - self._sock.sendall(message) - err = Int64sl.parse(self._recvall(Int64sl.sizeof())) - return err + with self._lock: + self._sock.sendall(message) + address = Int64sl.parse(self._recvall(Int64sl.sizeof())) + return address def call(self, address: int, argv: typing.List[int] = None, return_float64=False, return_float32=False, return_float16=False, return_raw=False) -> Symbol: @@ -178,9 +183,10 @@ def call(self, address: int, argv: typing.List[int] = None, return_float64=False 'cmd_type': cmd_type_t.CMD_CALL, 'data': {'address': address, 'argv': fixed_argv}, }) - self._sock.sendall(message) - response = call_response_t.parse(self._recvall(call_response_t_size)) + with self._lock: + self._sock.sendall(message) + response = call_response_t.parse(self._recvall(call_response_t_size)) for f in free_list: self.symbols.free(f) @@ -213,11 +219,12 @@ def peek(self, address: int, size: int) -> bytes: 'cmd_type': cmd_type_t.CMD_PEEK, 'data': {'address': address, 'size': size}, }) - self._sock.sendall(message) - reply = protocol_message_t.parse(self._recvall(reply_protocol_message_t.sizeof())) - if reply.cmd_type == cmd_type_t.CMD_REPLY_ERROR: - raise ArgumentError(f'failed to read {size} bytes from {address}') - return self._recvall(size) + with self._lock: + self._sock.sendall(message) + reply = protocol_message_t.parse(self._recvall(reply_protocol_message_t.sizeof())) + if reply.cmd_type == cmd_type_t.CMD_REPLY_ERROR: + raise ArgumentError(f'failed to read {size} bytes from {address}') + return self._recvall(size) def poke(self, address: int, data: bytes): """ poke data at given address """ @@ -225,10 +232,11 @@ def poke(self, address: int, data: bytes): 'cmd_type': cmd_type_t.CMD_POKE, 'data': {'address': address, 'size': len(data), 'data': data}, }) - self._sock.sendall(message) - reply = protocol_message_t.parse(self._recvall(reply_protocol_message_t.sizeof())) - if reply.cmd_type == cmd_type_t.CMD_REPLY_ERROR: - raise ArgumentError(f'failed to write {len(data)} bytes to {address}') + with self._lock: + self._sock.sendall(message) + reply = protocol_message_t.parse(self._recvall(reply_protocol_message_t.sizeof())) + if reply.cmd_type == cmd_type_t.CMD_REPLY_ERROR: + raise ArgumentError(f'failed to write {len(data)} bytes to {address}') def get_dummy_block(self) -> Symbol: """ get an address for a stub block containing nothing """ @@ -236,8 +244,12 @@ def get_dummy_block(self) -> Symbol: 'cmd_type': cmd_type_t.CMD_GET_DUMMY_BLOCK, 'data': None, }) - self._sock.sendall(message) - return self.symbol(dummy_block_t.parse(self._recvall(8))) + + with self._lock: + self._sock.sendall(message) + result = dummy_block_t.parse(self._recvall(dummy_block_t.sizeof())) + + return self.symbol(result) def spawn(self, argv: typing.List[str] = None, envp: typing.List[str] = None, stdin: io_or_str = sys.stdin, stdout=sys.stdout, raw_tty=False, background=False) -> SpawnResult: @@ -258,30 +270,31 @@ def spawn(self, argv: typing.List[str] = None, envp: typing.List[str] = None, st if envp is None: envp = self.DEFAULT_ENVP - try: - pid = self._execute(argv, envp, background=background) - except SpawnError: - # depends on where the error occurred, the socket might be closed - raise + with self._lock: + try: + pid = self._execute(argv, envp, background=background) + except SpawnError: + # depends on where the error occurred, the socket might be closed + raise - logging.info(f'shell process started as pid: {pid}') + logging.info(f'shell process started as pid: {pid}') - if background: - return SpawnResult(error=None, pid=pid, stdout=None) + if background: + return SpawnResult(error=None, pid=pid, stdout=None) - if raw_tty: - self._prepare_terminal() - try: - # the socket must be non-blocking for using select() - self._sock.setblocking(False) - error = self._execution_loop(stdin, stdout) - except Exception: # noqa: E722 - self._sock.setblocking(True) - # this is important to really catch every exception here, even exceptions not inheriting from Exception - # so the controlling terminal will remain working with its previous settings if raw_tty: - self._restore_terminal() - raise + self._prepare_terminal() + try: + # the socket must be non-blocking for using select() + self._sock.setblocking(False) + error = self._execution_loop(stdin, stdout) + except Exception: # noqa: E722 + self._sock.setblocking(True) + # this is important to really catch every exception here, even exceptions not inheriting from Exception + # so the controlling terminal will remain working with its previous settings + if raw_tty: + self._restore_terminal() + raise if raw_tty: self._restore_terminal() @@ -405,7 +418,7 @@ def _ipython_run_cell_hook(self, info): symbol ) - def close(self): + def _close(self): message = protocol_message_t.build({ 'cmd_type': cmd_type_t.CMD_CLOSE, 'data': None, @@ -413,13 +426,17 @@ def close(self): self._sock.sendall(message) self._sock.close() + def close(self): + with self._lock: + self._close() + def reconnect(self): """ close current socket and attempt to reconnect """ - self.close() - self._sock = socket() - self._sock.connect((self._hostname, self._port)) - - handshake = protocol_handshake_t.parse(self._recvall(protocol_handshake_t.sizeof())) + with self._lock: + self._close() + self._sock = socket() + self._sock.connect((self._hostname, self._port)) + handshake = protocol_handshake_t.parse(self._recvall(protocol_handshake_t.sizeof())) if handshake.magic != SERVER_MAGIC_VERSION: raise InvalidServerVersionMagicError() diff --git a/src/rpcclient/rpcclient/client_factory.py b/src/rpcclient/rpcclient/client_factory.py index 5849da47..de3408f3 100644 --- a/src/rpcclient/rpcclient/client_factory.py +++ b/src/rpcclient/rpcclient/client_factory.py @@ -32,7 +32,7 @@ def create_client(hostname: str, port: int = DEFAULT_PORT): handshake = protocol_handshake_t.parse(recvall(sock, protocol_handshake_t.sizeof())) if handshake.magic != SERVER_MAGIC_VERSION: - raise InvalidServerVersionMagicError() + raise InvalidServerVersionMagicError(f'got {handshake.magic:x} instead of {SERVER_MAGIC_VERSION:x}') sysname = handshake.sysname.lower() arch = handshake.arch diff --git a/src/rpcclient/rpcclient/protocol.py b/src/rpcclient/rpcclient/protocol.py index 61a61149..7557eff6 100644 --- a/src/rpcclient/rpcclient/protocol.py +++ b/src/rpcclient/rpcclient/protocol.py @@ -1,5 +1,5 @@ from construct import Struct, Int32ul, PrefixedArray, Const, Enum, this, PascalString, Switch, PaddedString, Bytes, \ - Int64ul, Int8ul, IfThenElse, Float64l, Array, Union + Int64ul, Int8ul, IfThenElse, Float64l, Array, Union, Hex cmd_type_t = Enum(Int32ul, CMD_EXEC=0, @@ -28,7 +28,7 @@ MAX_PATH_LEN = 1024 protocol_handshake_t = Struct( - 'magic' / Int32ul, + 'magic' / Hex(Int32ul), 'arch' / arch_t, 'sysname' / PaddedString(256, 'utf8'), ) @@ -79,7 +79,7 @@ ) protocol_message_t = Struct( - 'magic' / Const(MAGIC, Int32ul), + 'magic' / Const(MAGIC, Hex(Int32ul)), 'cmd_type' / cmd_type_t, 'data' / Switch(this.cmd_type, { cmd_type_t.CMD_EXEC: cmd_exec_t, diff --git a/src/rpcclient/tests/test_thread_safe.py b/src/rpcclient/tests/test_thread_safe.py new file mode 100644 index 00000000..1a9d308f --- /dev/null +++ b/src/rpcclient/tests/test_thread_safe.py @@ -0,0 +1,31 @@ +import time +from threading import Thread + + +def test_same_socket_different_threads(client): + # get an expected result once on main thread + expected_result = client.fs.listdir('/') + + # global to tell threads when to exit + should_exit = False + + def listdir_thread(client): + while not should_exit: + assert client.fs.listdir('/') == expected_result + return 0 + + # launch the two threads + t1 = Thread(target=listdir_thread, args=(client,)) + t2 = Thread(target=listdir_thread, args=(client,)) + + t1.start() + t2.start() + + # wait 10 seconds + time.sleep(10) + + # tell threads they + should_exit = True + + t1.join() + t2.join()