Skip to content

Commit

Permalink
Merge pull request #90 from doronz88/bugfix/reconnect_double_close
Browse files Browse the repository at this point in the history
server: bugfix: several possible edge cases and races
  • Loading branch information
doronz88 authored Feb 22, 2022
2 parents 8d81bf9 + 756eda0 commit 84b283d
Show file tree
Hide file tree
Showing 11 changed files with 295 additions and 166 deletions.
9 changes: 7 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,13 @@ python3 -m pip install --user -U .

To execute the server:

```shell
./server [-p port]
```
Usage: ./rpcserver [-p port] [-o (stdout|syslog|file:filename)]
-h show this help message
-o output. can be all of the following: stdout, syslog and file:filename. can be passed multiple times
Example usage:
./rpcserver -p 5910 -o syslog -o stdout -o file:/tmp/log.txt
```

Connecting via:
Expand Down
2 changes: 1 addition & 1 deletion src/rpcclient/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,4 @@ cached-property
dataclasses; python_version<"3.7"
pygments
objc_types_decoder
pycrashreport>=0.0.7
pycrashreport>=0.0.8
34 changes: 23 additions & 11 deletions src/rpcclient/rpcclient/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,9 @@ def poke(self, address: int, data: bytes):
'data': {'address': address, 'size': len(data), 'data': data},
})
self._sock.sendall(message)
reply = protocol_message_t.parse(self._recvall(reply_protocol_message_t.sizeof()))
if reply.cmd_type == cmd_type_t.CMD_REPLY_ERROR:
raise ArgumentError(f'failed to write {len(data)} bytes to {address}')

def get_dummy_block(self) -> Symbol:
""" get an address for a stub block containing nothing """
Expand All @@ -210,7 +213,7 @@ def spawn(self, argv: typing.List[str] = None, envp: typing.List[str] = None, st
:param argv: argv of the process to be executed
:param envp: envp of the process to be executed
:param stdin: either a file object to read from OR a string
:param stdout: a file object to write both stdout and stderr to
:param stdout: a file object to write both stdout and stderr to. None if background is requested
:param raw_tty: should enable raw tty mode
:param background: should execute process in background
:return: a SpawnResult. error is None if background is requested
Expand All @@ -222,35 +225,35 @@ def spawn(self, argv: typing.List[str] = None, envp: typing.List[str] = None, st
envp = self.DEFAULT_ENVP

try:
pid = self._execute(argv, envp)
pid = self._execute(argv, envp, background=background)
except SpawnError:
# depends on where the error occurred, the socket might be closed
self.reconnect()
raise

logging.info(f'shell process started as pid: {pid}')

if background:
# if in background was requested, we can just detach this connection
self.reconnect()
return SpawnResult(error=None, pid=pid, stdout=stdout)

self._sock.setblocking(False)
return SpawnResult(error=None, pid=pid, stdout=None)

if raw_tty:
self._prepare_terminal()
try:
# the socket must be non-blocking for using select()
self._sock.setblocking(False)
error = self._execution_loop(stdin, stdout)
except Exception: # noqa: E722
self._sock.setblocking(True)
# this is important to really catch every exception here, even exceptions not inheriting from Exception
# so the controlling terminal will remain working with its previous settings
if raw_tty:
self._restore_terminal()
self.reconnect()
raise

if raw_tty:
self._restore_terminal()

# TODO: we should be able to return here without the need to reconnect but from some reason the
# socket goes out of sync when doing so
self.reconnect()

return SpawnResult(error=error, pid=pid, stdout=stdout)
Expand Down Expand Up @@ -356,8 +359,17 @@ def _ipython_run_cell_hook(self, info):
symbol
)

def close(self):
message = protocol_message_t.build({
'cmd_type': cmd_type_t.CMD_CLOSE,
'data': None,
})
self._sock.sendall(message)
self._sock.close()

def reconnect(self):
""" close current socket and attempt to reconnect """
self.close()
self._sock = socket()
self._sock.connect((self._hostname, self._port))
magic = self._recvall(len(SERVER_MAGIC_VERSION))
Expand All @@ -367,10 +379,10 @@ def reconnect(self):

self._recvall(UNAME_VERSION_LEN)

def _execute(self, argv: typing.List[str], envp: typing.List[str]) -> int:
def _execute(self, argv: typing.List[str], envp: typing.List[str], background=False) -> int:
message = protocol_message_t.build({
'cmd_type': cmd_type_t.CMD_EXEC,
'data': {'argv': argv, 'envp': envp},
'data': {'background': background, 'argv': argv, 'envp': envp},
})
self._sock.sendall(message)
pid = pid_t.parse(self._sock.recv(pid_t.sizeof()))
Expand Down
3 changes: 3 additions & 0 deletions src/rpcclient/rpcclient/darwin/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,9 @@ def disconnect(self):
self._client.symbols.Apple80211Disassociate(self._interface)

def _set(self, is_on: bool):
with self._client.preferences.sc.get_preferences_object('com.apple.wifi.plist') as pref:
pref.set('AllowEnable', int(is_on))

if not is_on:
if self._client.symbols.WiFiManagerClientDisable(self._wifi_manager_client):
raise BadReturnValueError(f'WiFiManagerClientDisable failed ({self._client.last_error})')
Expand Down
4 changes: 4 additions & 0 deletions src/rpcclient/rpcclient/ios/client.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,17 @@
import typing

from rpcclient.darwin.client import DarwinClient
from rpcclient.darwin.reports import Reports
from rpcclient.ios.backlight import Backlight

CRASH_REPORTS_DIR = 'Library/Logs/CrashReporter'


class IosClient(DarwinClient):
def __init__(self, sock, sysname: str, hostname: str, port: int = None):
super().__init__(sock, sysname, hostname, port)
self.backlight = Backlight(self)
self.reports = Reports(self, CRASH_REPORTS_DIR)

@property
def roots(self) -> typing.List[str]:
Expand Down
5 changes: 4 additions & 1 deletion src/rpcclient/rpcclient/protocol.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from construct import Struct, Int32ul, PrefixedArray, Const, Enum, this, PascalString, Switch, PaddedString, Bytes, \
Int64ul
Int64ul, Int8ul

cmd_type_t = Enum(Int32ul,
CMD_EXEC=0,
Expand All @@ -12,6 +12,8 @@
CMD_REPLY_ERROR=7,
CMD_REPLY_PEEK=8,
CMD_GET_DUMMY_BLOCK=9,
CMD_CLOSE=10,
CMD_REPLY_POKE=11,
)
DEFAULT_PORT = 5910
SERVER_MAGIC_VERSION = Int32ul.build(0x88888800)
Expand All @@ -20,6 +22,7 @@
UNAME_VERSION_LEN = 256

cmd_exec_t = Struct(
'background' / Int8ul,
'argv' / PrefixedArray(Int32ul, PascalString(Int32ul, 'utf8')),
'envp' / PrefixedArray(Int32ul, PascalString(Int32ul, 'utf8')),
)
Expand Down
6 changes: 5 additions & 1 deletion src/rpcclient/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,11 @@

@pytest.fixture
def client():
return create_client('127.0.0.1')
try:
c = create_client('127.0.0.1')
yield c
finally:
c.close()


def pytest_addoption(parser):
Expand Down
33 changes: 25 additions & 8 deletions src/rpcclient/tests/test_spawn.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,32 +3,49 @@
import pytest


def test_spawn_fds(client):
pid = client.spawn(['/bin/sleep', '5'], stdout=StringIO(), stdin='', background=True).pid

# should only have: stdin, stdout and stderr
assert len(client.processes.get_fds(pid)) == 3

client.processes.kill(pid)


@pytest.mark.parametrize('argv,expected_stdout,errorcode', [
[['/bin/sleep', '0'], '', 0],
[['/bin/echo', 'blat'], 'blat', 0],
[['/bin/ls', 'INVALID_PATH'], 'ls: INVALID_PATH: No such file or directory', 256],
])
def test_spawn_sanity(client, argv, expected_stdout, errorcode):
def test_spawn_foreground_sanity(client, argv, expected_stdout, errorcode):
stdout = StringIO()
assert errorcode == client.spawn(argv, stdout=stdout, stdin='').error

stdout.seek(0)
assert expected_stdout == stdout.read().strip()


def test_spawn_bad_value_stress(client):
@pytest.mark.parametrize('argv,expected_stdout,errorcode', [
[['/bin/sleep', '0'], '', 0],
[['/bin/echo', 'blat'], 'blat', 0],
[['/bin/ls', 'INVALID_PATH'], 'ls: INVALID_PATH: No such file or directory', 256],
])
def test_spawn_foreground_stress(client, argv, expected_stdout, errorcode):
for i in range(1000):
stdout = StringIO()
assert 256 == client.spawn(['/bin/ls', 'INVALID_PATH'], stdout=stdout, stdin='').error

stdout.seek(0)
assert 'ls: INVALID_PATH: No such file or directory' == stdout.read().strip()
test_spawn_foreground_sanity(client, argv, expected_stdout, errorcode)


def test_spawn_background(client):
def test_spawn_background_sanity(client):
spawn_result = client.spawn(['/bin/sleep', '5'], stdout=StringIO(), stdin='', background=True)

# when running in background, no error is returned
assert spawn_result.error is None
assert spawn_result.stdout is None

# instead, we can just make sure it ran by sending it a kill and don't fail
client.processes.kill(spawn_result.pid)


def test_spawn_background_stress(client):
for i in range(1000):
test_spawn_background_sanity(client)
50 changes: 44 additions & 6 deletions src/rpcserver/common.c
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#include <stdbool.h>
#include <unistd.h>
#include <syslog.h>
#include <execinfo.h>
#include <sys/socket.h>

#include "common.h"
Expand All @@ -12,6 +13,35 @@ bool g_stdout = false;
bool g_syslog = false;
FILE *g_file = NULL;

#define BT_BUF_SIZE (100)

void print_backtrace()
{
int nptrs;
void *buffer[BT_BUF_SIZE];
char **strings;

nptrs = backtrace(buffer, BT_BUF_SIZE);
trace("BACKTRACE", "backtrace() returned %d addresses", nptrs);

/* The call backtrace_symbols_fd(buffer, nptrs, STDOUT_FILENO)
would produce similar output to the following: */

strings = backtrace_symbols(buffer, nptrs);
if (strings == NULL)
{
perror("backtrace_symbols");
return;
}

for (int j = 0; j < nptrs; j++)
{
trace("BACKTRACE:\t", "%s", strings[j]);
}

free(strings);
}

void trace(const char *prefix, const char *fmt, ...)
{
if (!g_stdout && !g_syslog)
Expand Down Expand Up @@ -45,14 +75,21 @@ void trace(const char *prefix, const char *fmt, ...)
}
}

bool recvall(int sockfd, char *buf, size_t len)
bool recvall_ext(int sockfd, char *buf, size_t len, bool *disconnected)
{
size_t total_bytes = 0;
size_t bytes = 0;
*disconnected = false;

while (len > 0)
{
bytes = recv(sockfd, buf + total_bytes, len, 0);
if (0 == bytes)
{
TRACE("client fd: %d disconnected", sockfd);
*disconnected = true;
return false;
}
CHECK(bytes > 0);

total_bytes += bytes;
Expand All @@ -62,14 +99,15 @@ bool recvall(int sockfd, char *buf, size_t len)
return true;

error:
if (0 == bytes)
{
TRACE("client fd: %d disconnected", sockfd);
}

return false;
}

bool recvall(int sockfd, char *buf, size_t len)
{
bool disconnected;
return recvall_ext(sockfd, buf, len, &disconnected);
}

bool sendall(int sockfd, const char *buf, size_t len)
{
size_t total_bytes = 0;
Expand Down
11 changes: 9 additions & 2 deletions src/rpcserver/common.h
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
#ifndef __COMMON_H_
#define __COMMON_H_

#include <stdio.h>
#include <stdarg.h>
#include <stdlib.h>
#include <stdbool.h>
#include <errno.h>
#include <string.h>

typedef unsigned char u8;
typedef unsigned short u16;
Expand All @@ -21,12 +23,17 @@ FILE *g_file;
#define CHECK(expression) \
if (!(expression)) \
{ \
if (errno) perror(__PRETTY_FUNCTION__); \
if (errno) \
{ \
trace(__PRETTY_FUNCTION__, "ERROR: errno: %d (%s)", errno, strerror(errno)); \
} \
print_backtrace(); \
goto error; \
}

void print_backtrace();
void trace(const char *prefix, const char *fmt, ...);

bool recvall_ext(int sockfd, char *buf, size_t len, bool *disconnected);
bool recvall(int sockfd, char *buf, size_t len);
bool sendall(int sockfd, const char *buf, size_t len);
bool writeall(int fd, const char *buf, size_t len);
Expand Down
Loading

0 comments on commit 84b283d

Please sign in to comment.