Skip to content

Commit

Permalink
client: make client thread-safe
Browse files Browse the repository at this point in the history
  • Loading branch information
doronz88 committed Apr 4, 2022
1 parent c3331b1 commit 143e859
Show file tree
Hide file tree
Showing 4 changed files with 99 additions and 51 deletions.
111 changes: 64 additions & 47 deletions src/rpcclient/rpcclient/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import logging
import os
import sys
import threading
import typing
from collections import namedtuple
from enum import Enum
Expand Down Expand Up @@ -71,6 +72,7 @@ def __init__(self, sock, sysname: str, arch: arch_t, hostname: str, port: int =
self._endianness = '<'
self._sysname = sysname
self._dlsym_global_handle = -1 # RTLD_NEXT
self._lock = threading.Lock()

# whether the system uses inode structs of 64 bits
self.inode64 = False
Expand Down Expand Up @@ -112,9 +114,10 @@ def dlopen(self, filename: str, mode: int) -> Symbol:
'cmd_type': cmd_type_t.CMD_DLOPEN,
'data': {'filename': filename, 'mode': mode},
})
self._sock.sendall(message)
err = Int64sl.parse(self._recvall(Int64sl.sizeof()))
return self.symbol(err)
with self._lock:
self._sock.sendall(message)
address = Int64sl.parse(self._recvall(Int64sl.sizeof()))
return self.symbol(address)

def dlclose(self, lib: int):
""" call dlclose() at remote and return its handle. see the man page for more details. """
Expand All @@ -123,8 +126,9 @@ def dlclose(self, lib: int):
'cmd_type': cmd_type_t.CMD_DLCLOSE,
'data': {'lib': lib},
})
self._sock.sendall(message)
err = Int64sl.parse(self._recvall(Int64sl.sizeof()))
with self._lock:
self._sock.sendall(message)
err = Int64sl.parse(self._recvall(Int64sl.sizeof()))
return err

def dlsym(self, lib: int, symbol_name: str):
Expand All @@ -134,9 +138,10 @@ def dlsym(self, lib: int, symbol_name: str):
'cmd_type': cmd_type_t.CMD_DLSYM,
'data': {'lib': lib, 'symbol_name': symbol_name},
})
self._sock.sendall(message)
err = Int64sl.parse(self._recvall(Int64sl.sizeof()))
return err
with self._lock:
self._sock.sendall(message)
address = Int64sl.parse(self._recvall(Int64sl.sizeof()))
return address

def call(self, address: int, argv: typing.List[int] = None, return_float64=False, return_float32=False,
return_float16=False, return_raw=False) -> Symbol:
Expand Down Expand Up @@ -178,9 +183,10 @@ def call(self, address: int, argv: typing.List[int] = None, return_float64=False
'cmd_type': cmd_type_t.CMD_CALL,
'data': {'address': address, 'argv': fixed_argv},
})
self._sock.sendall(message)

response = call_response_t.parse(self._recvall(call_response_t_size))
with self._lock:
self._sock.sendall(message)
response = call_response_t.parse(self._recvall(call_response_t_size))

for f in free_list:
self.symbols.free(f)
Expand Down Expand Up @@ -213,31 +219,37 @@ def peek(self, address: int, size: int) -> bytes:
'cmd_type': cmd_type_t.CMD_PEEK,
'data': {'address': address, 'size': size},
})
self._sock.sendall(message)
reply = protocol_message_t.parse(self._recvall(reply_protocol_message_t.sizeof()))
if reply.cmd_type == cmd_type_t.CMD_REPLY_ERROR:
raise ArgumentError(f'failed to read {size} bytes from {address}')
return self._recvall(size)
with self._lock:
self._sock.sendall(message)
reply = protocol_message_t.parse(self._recvall(reply_protocol_message_t.sizeof()))
if reply.cmd_type == cmd_type_t.CMD_REPLY_ERROR:
raise ArgumentError(f'failed to read {size} bytes from {address}')
return self._recvall(size)

def poke(self, address: int, data: bytes):
""" poke data at given address """
message = protocol_message_t.build({
'cmd_type': cmd_type_t.CMD_POKE,
'data': {'address': address, 'size': len(data), 'data': data},
})
self._sock.sendall(message)
reply = protocol_message_t.parse(self._recvall(reply_protocol_message_t.sizeof()))
if reply.cmd_type == cmd_type_t.CMD_REPLY_ERROR:
raise ArgumentError(f'failed to write {len(data)} bytes to {address}')
with self._lock:
self._sock.sendall(message)
reply = protocol_message_t.parse(self._recvall(reply_protocol_message_t.sizeof()))
if reply.cmd_type == cmd_type_t.CMD_REPLY_ERROR:
raise ArgumentError(f'failed to write {len(data)} bytes to {address}')

def get_dummy_block(self) -> Symbol:
""" get an address for a stub block containing nothing """
message = protocol_message_t.build({
'cmd_type': cmd_type_t.CMD_GET_DUMMY_BLOCK,
'data': None,
})
self._sock.sendall(message)
return self.symbol(dummy_block_t.parse(self._recvall(8)))

with self._lock:
self._sock.sendall(message)
result = dummy_block_t.parse(self._recvall(dummy_block_t.sizeof()))

return self.symbol(result)

def spawn(self, argv: typing.List[str] = None, envp: typing.List[str] = None, stdin: io_or_str = sys.stdin,
stdout=sys.stdout, raw_tty=False, background=False) -> SpawnResult:
Expand All @@ -258,30 +270,31 @@ def spawn(self, argv: typing.List[str] = None, envp: typing.List[str] = None, st
if envp is None:
envp = self.DEFAULT_ENVP

try:
pid = self._execute(argv, envp, background=background)
except SpawnError:
# depends on where the error occurred, the socket might be closed
raise
with self._lock:
try:
pid = self._execute(argv, envp, background=background)
except SpawnError:
# depends on where the error occurred, the socket might be closed
raise

logging.info(f'shell process started as pid: {pid}')
logging.info(f'shell process started as pid: {pid}')

if background:
return SpawnResult(error=None, pid=pid, stdout=None)
if background:
return SpawnResult(error=None, pid=pid, stdout=None)

if raw_tty:
self._prepare_terminal()
try:
# the socket must be non-blocking for using select()
self._sock.setblocking(False)
error = self._execution_loop(stdin, stdout)
except Exception: # noqa: E722
self._sock.setblocking(True)
# this is important to really catch every exception here, even exceptions not inheriting from Exception
# so the controlling terminal will remain working with its previous settings
if raw_tty:
self._restore_terminal()
raise
self._prepare_terminal()
try:
# the socket must be non-blocking for using select()
self._sock.setblocking(False)
error = self._execution_loop(stdin, stdout)
except Exception: # noqa: E722
self._sock.setblocking(True)
# this is important to really catch every exception here, even exceptions not inheriting from Exception
# so the controlling terminal will remain working with its previous settings
if raw_tty:
self._restore_terminal()
raise

if raw_tty:
self._restore_terminal()
Expand Down Expand Up @@ -405,21 +418,25 @@ def _ipython_run_cell_hook(self, info):
symbol
)

def close(self):
def _close(self):
message = protocol_message_t.build({
'cmd_type': cmd_type_t.CMD_CLOSE,
'data': None,
})
self._sock.sendall(message)
self._sock.close()

def close(self):
with self._lock:
self._close()

def reconnect(self):
""" close current socket and attempt to reconnect """
self.close()
self._sock = socket()
self._sock.connect((self._hostname, self._port))

handshake = protocol_handshake_t.parse(self._recvall(protocol_handshake_t.sizeof()))
with self._lock:
self._close()
self._sock = socket()
self._sock.connect((self._hostname, self._port))
handshake = protocol_handshake_t.parse(self._recvall(protocol_handshake_t.sizeof()))

if handshake.magic != SERVER_MAGIC_VERSION:
raise InvalidServerVersionMagicError()
Expand Down
2 changes: 1 addition & 1 deletion src/rpcclient/rpcclient/client_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def create_client(hostname: str, port: int = DEFAULT_PORT):
handshake = protocol_handshake_t.parse(recvall(sock, protocol_handshake_t.sizeof()))

if handshake.magic != SERVER_MAGIC_VERSION:
raise InvalidServerVersionMagicError()
raise InvalidServerVersionMagicError(f'got {handshake.magic:x} instead of {SERVER_MAGIC_VERSION:x}')

sysname = handshake.sysname.lower()
arch = handshake.arch
Expand Down
6 changes: 3 additions & 3 deletions src/rpcclient/rpcclient/protocol.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from construct import Struct, Int32ul, PrefixedArray, Const, Enum, this, PascalString, Switch, PaddedString, Bytes, \
Int64ul, Int8ul, IfThenElse, Float64l, Array, Union
Int64ul, Int8ul, IfThenElse, Float64l, Array, Union, Hex

cmd_type_t = Enum(Int32ul,
CMD_EXEC=0,
Expand Down Expand Up @@ -28,7 +28,7 @@
MAX_PATH_LEN = 1024

protocol_handshake_t = Struct(
'magic' / Int32ul,
'magic' / Hex(Int32ul),
'arch' / arch_t,
'sysname' / PaddedString(256, 'utf8'),
)
Expand Down Expand Up @@ -79,7 +79,7 @@
)

protocol_message_t = Struct(
'magic' / Const(MAGIC, Int32ul),
'magic' / Const(MAGIC, Hex(Int32ul)),
'cmd_type' / cmd_type_t,
'data' / Switch(this.cmd_type, {
cmd_type_t.CMD_EXEC: cmd_exec_t,
Expand Down
31 changes: 31 additions & 0 deletions src/rpcclient/tests/test_thread_safe.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
import time
from threading import Thread


def test_same_socket_different_threads(client):
# get an expected result once on main thread
expected_result = client.fs.listdir('/')

# global to tell threads when to exit
should_exit = False

def listdir_thread(client):
while not should_exit:
assert client.fs.listdir('/') == expected_result
return 0

# launch the two threads
t1 = Thread(target=listdir_thread, args=(client,))
t2 = Thread(target=listdir_thread, args=(client,))

t1.start()
t2.start()

# wait 10 seconds
time.sleep(10)

# tell threads they
should_exit = True

t1.join()
t2.join()

0 comments on commit 143e859

Please sign in to comment.