From a7bbc135374e593496f73c327112da64677ad10c Mon Sep 17 00:00:00 2001 From: DoronZ Date: Tue, 22 Feb 2022 00:10:23 +0200 Subject: [PATCH] server: bugfix: several possible edge cases and races --- src/rpcclient/rpcclient/client.py | 34 +++- src/rpcclient/rpcclient/protocol.py | 5 +- src/rpcclient/tests/conftest.py | 6 +- src/rpcclient/tests/test_spawn.py | 33 ++- src/rpcserver/common.c | 50 ++++- src/rpcserver/common.h | 11 +- src/rpcserver/rpcserver.c | 304 ++++++++++++++++------------ 7 files changed, 280 insertions(+), 163 deletions(-) diff --git a/src/rpcclient/rpcclient/client.py b/src/rpcclient/rpcclient/client.py index b0579d49..f843a13b 100644 --- a/src/rpcclient/rpcclient/client.py +++ b/src/rpcclient/rpcclient/client.py @@ -192,6 +192,9 @@ def poke(self, address: int, data: bytes): 'data': {'address': address, 'size': len(data), 'data': data}, }) self._sock.sendall(message) + reply = protocol_message_t.parse(self._recvall(reply_protocol_message_t.sizeof())) + if reply.cmd_type == cmd_type_t.CMD_REPLY_ERROR: + raise ArgumentError(f'failed to write {len(data)} bytes to {address}') def get_dummy_block(self) -> Symbol: """ get an address for a stub block containing nothing """ @@ -210,7 +213,7 @@ def spawn(self, argv: typing.List[str] = None, envp: typing.List[str] = None, st :param argv: argv of the process to be executed :param envp: envp of the process to be executed :param stdin: either a file object to read from OR a string - :param stdout: a file object to write both stdout and stderr to + :param stdout: a file object to write both stdout and stderr to. None if background is requested :param raw_tty: should enable raw tty mode :param background: should execute process in background :return: a SpawnResult. error is None if background is requested @@ -222,35 +225,35 @@ def spawn(self, argv: typing.List[str] = None, envp: typing.List[str] = None, st envp = self.DEFAULT_ENVP try: - pid = self._execute(argv, envp) + pid = self._execute(argv, envp, background=background) except SpawnError: # depends on where the error occurred, the socket might be closed - self.reconnect() raise logging.info(f'shell process started as pid: {pid}') if background: - # if in background was requested, we can just detach this connection - self.reconnect() - return SpawnResult(error=None, pid=pid, stdout=stdout) - - self._sock.setblocking(False) + return SpawnResult(error=None, pid=pid, stdout=None) if raw_tty: self._prepare_terminal() try: + # the socket must be non-blocking for using select() + self._sock.setblocking(False) error = self._execution_loop(stdin, stdout) except Exception: # noqa: E722 + self._sock.setblocking(True) # this is important to really catch every exception here, even exceptions not inheriting from Exception # so the controlling terminal will remain working with its previous settings if raw_tty: self._restore_terminal() - self.reconnect() raise if raw_tty: self._restore_terminal() + + # TODO: we should be able to return here without the need to reconnect but from some reason the + # socket goes out of sync when doing so self.reconnect() return SpawnResult(error=error, pid=pid, stdout=stdout) @@ -356,8 +359,17 @@ def _ipython_run_cell_hook(self, info): symbol ) + def close(self): + message = protocol_message_t.build({ + 'cmd_type': cmd_type_t.CMD_CLOSE, + 'data': None, + }) + self._sock.sendall(message) + self._sock.close() + def reconnect(self): """ close current socket and attempt to reconnect """ + self.close() self._sock = socket() self._sock.connect((self._hostname, self._port)) magic = self._recvall(len(SERVER_MAGIC_VERSION)) @@ -367,10 +379,10 @@ def reconnect(self): self._recvall(UNAME_VERSION_LEN) - def _execute(self, argv: typing.List[str], envp: typing.List[str]) -> int: + def _execute(self, argv: typing.List[str], envp: typing.List[str], background=False) -> int: message = protocol_message_t.build({ 'cmd_type': cmd_type_t.CMD_EXEC, - 'data': {'argv': argv, 'envp': envp}, + 'data': {'background': background, 'argv': argv, 'envp': envp}, }) self._sock.sendall(message) pid = pid_t.parse(self._sock.recv(pid_t.sizeof())) diff --git a/src/rpcclient/rpcclient/protocol.py b/src/rpcclient/rpcclient/protocol.py index 8f663a0d..d0a0c9a9 100644 --- a/src/rpcclient/rpcclient/protocol.py +++ b/src/rpcclient/rpcclient/protocol.py @@ -1,5 +1,5 @@ from construct import Struct, Int32ul, PrefixedArray, Const, Enum, this, PascalString, Switch, PaddedString, Bytes, \ - Int64ul + Int64ul, Int8ul cmd_type_t = Enum(Int32ul, CMD_EXEC=0, @@ -12,6 +12,8 @@ CMD_REPLY_ERROR=7, CMD_REPLY_PEEK=8, CMD_GET_DUMMY_BLOCK=9, + CMD_CLOSE=10, + CMD_REPLY_POKE=11, ) DEFAULT_PORT = 5910 SERVER_MAGIC_VERSION = Int32ul.build(0x88888800) @@ -20,6 +22,7 @@ UNAME_VERSION_LEN = 256 cmd_exec_t = Struct( + 'background' / Int8ul, 'argv' / PrefixedArray(Int32ul, PascalString(Int32ul, 'utf8')), 'envp' / PrefixedArray(Int32ul, PascalString(Int32ul, 'utf8')), ) diff --git a/src/rpcclient/tests/conftest.py b/src/rpcclient/tests/conftest.py index 414aa534..94932181 100644 --- a/src/rpcclient/tests/conftest.py +++ b/src/rpcclient/tests/conftest.py @@ -5,7 +5,11 @@ @pytest.fixture def client(): - return create_client('127.0.0.1') + try: + c = create_client('127.0.0.1') + yield c + finally: + c.close() def pytest_addoption(parser): diff --git a/src/rpcclient/tests/test_spawn.py b/src/rpcclient/tests/test_spawn.py index 0692a791..f380361e 100644 --- a/src/rpcclient/tests/test_spawn.py +++ b/src/rpcclient/tests/test_spawn.py @@ -3,11 +3,21 @@ import pytest +def test_spawn_fds(client): + pid = client.spawn(['/bin/sleep', '5'], stdout=StringIO(), stdin='', background=True).pid + + # should only have: stdin, stdout and stderr + assert len(client.processes.get_fds(pid)) == 3 + + client.processes.kill(pid) + + @pytest.mark.parametrize('argv,expected_stdout,errorcode', [ [['/bin/sleep', '0'], '', 0], + [['/bin/echo', 'blat'], 'blat', 0], [['/bin/ls', 'INVALID_PATH'], 'ls: INVALID_PATH: No such file or directory', 256], ]) -def test_spawn_sanity(client, argv, expected_stdout, errorcode): +def test_spawn_foreground_sanity(client, argv, expected_stdout, errorcode): stdout = StringIO() assert errorcode == client.spawn(argv, stdout=stdout, stdin='').error @@ -15,20 +25,27 @@ def test_spawn_sanity(client, argv, expected_stdout, errorcode): assert expected_stdout == stdout.read().strip() -def test_spawn_bad_value_stress(client): +@pytest.mark.parametrize('argv,expected_stdout,errorcode', [ + [['/bin/sleep', '0'], '', 0], + [['/bin/echo', 'blat'], 'blat', 0], + [['/bin/ls', 'INVALID_PATH'], 'ls: INVALID_PATH: No such file or directory', 256], +]) +def test_spawn_foreground_stress(client, argv, expected_stdout, errorcode): for i in range(1000): - stdout = StringIO() - assert 256 == client.spawn(['/bin/ls', 'INVALID_PATH'], stdout=stdout, stdin='').error - - stdout.seek(0) - assert 'ls: INVALID_PATH: No such file or directory' == stdout.read().strip() + test_spawn_foreground_sanity(client, argv, expected_stdout, errorcode) -def test_spawn_background(client): +def test_spawn_background_sanity(client): spawn_result = client.spawn(['/bin/sleep', '5'], stdout=StringIO(), stdin='', background=True) # when running in background, no error is returned assert spawn_result.error is None + assert spawn_result.stdout is None # instead, we can just make sure it ran by sending it a kill and don't fail client.processes.kill(spawn_result.pid) + + +def test_spawn_background_stress(client): + for i in range(1000): + test_spawn_background_sanity(client) diff --git a/src/rpcserver/common.c b/src/rpcserver/common.c index ba42d7a9..23bc3483 100644 --- a/src/rpcserver/common.c +++ b/src/rpcserver/common.c @@ -4,6 +4,7 @@ #include #include #include +#include #include #include "common.h" @@ -12,6 +13,35 @@ bool g_stdout = false; bool g_syslog = false; FILE *g_file = NULL; +#define BT_BUF_SIZE (100) + +void print_backtrace() +{ + int nptrs; + void *buffer[BT_BUF_SIZE]; + char **strings; + + nptrs = backtrace(buffer, BT_BUF_SIZE); + trace("BACKTRACE", "backtrace() returned %d addresses", nptrs); + + /* The call backtrace_symbols_fd(buffer, nptrs, STDOUT_FILENO) + would produce similar output to the following: */ + + strings = backtrace_symbols(buffer, nptrs); + if (strings == NULL) + { + perror("backtrace_symbols"); + return; + } + + for (int j = 0; j < nptrs; j++) + { + trace("BACKTRACE:\t", "%s", strings[j]); + } + + free(strings); +} + void trace(const char *prefix, const char *fmt, ...) { if (!g_stdout && !g_syslog) @@ -45,14 +75,21 @@ void trace(const char *prefix, const char *fmt, ...) } } -bool recvall(int sockfd, char *buf, size_t len) +bool recvall_ext(int sockfd, char *buf, size_t len, bool *disconnected) { size_t total_bytes = 0; size_t bytes = 0; + *disconnected = false; while (len > 0) { bytes = recv(sockfd, buf + total_bytes, len, 0); + if (0 == bytes) + { + TRACE("client fd: %d disconnected", sockfd); + *disconnected = true; + return false; + } CHECK(bytes > 0); total_bytes += bytes; @@ -62,14 +99,15 @@ bool recvall(int sockfd, char *buf, size_t len) return true; error: - if (0 == bytes) - { - TRACE("client fd: %d disconnected", sockfd); - } - return false; } +bool recvall(int sockfd, char *buf, size_t len) +{ + bool disconnected; + return recvall_ext(sockfd, buf, len, &disconnected); +} + bool sendall(int sockfd, const char *buf, size_t len) { size_t total_bytes = 0; diff --git a/src/rpcserver/common.h b/src/rpcserver/common.h index baab78ba..76ff137d 100644 --- a/src/rpcserver/common.h +++ b/src/rpcserver/common.h @@ -1,10 +1,12 @@ #ifndef __COMMON_H_ #define __COMMON_H_ +#include #include #include #include #include +#include typedef unsigned char u8; typedef unsigned short u16; @@ -21,12 +23,17 @@ FILE *g_file; #define CHECK(expression) \ if (!(expression)) \ { \ - if (errno) perror(__PRETTY_FUNCTION__); \ + if (errno) \ + { \ + trace(__PRETTY_FUNCTION__, "ERROR: errno: %d (%s)", errno, strerror(errno)); \ + } \ + print_backtrace(); \ goto error; \ } +void print_backtrace(); void trace(const char *prefix, const char *fmt, ...); - +bool recvall_ext(int sockfd, char *buf, size_t len, bool *disconnected); bool recvall(int sockfd, char *buf, size_t len); bool sendall(int sockfd, const char *buf, size_t len); bool writeall(int fd, const char *buf, size_t len); diff --git a/src/rpcserver/rpcserver.c b/src/rpcserver/rpcserver.c index ca4971b1..1de40a82 100644 --- a/src/rpcserver/rpcserver.c +++ b/src/rpcserver/rpcserver.c @@ -63,6 +63,8 @@ typedef enum CMD_REPLY_ERROR = 7, CMD_REPLY_PEEK = 8, CMD_GET_DUMMY_BLOCK = 9, + CMD_CLOSE = 10, + CMD_REPLY_POKE = 11, } cmd_type_t; typedef enum @@ -125,27 +127,12 @@ void *get_in_addr(struct sockaddr *sa) // get sockaddr, IPv4 or IPv6: return sa->sa_family == AF_INET ? (void *)&(((struct sockaddr_in *)sa)->sin_addr) : (void *)&(((struct sockaddr_in6 *)sa)->sin6_addr); } -int internal_spawn(char *const *argv, char *const *envp, pid_t *pid) +bool internal_spawn(bool background, char *const *argv, char *const *envp, pid_t *pid, int *master_fd) { - int master_fd = -1; + bool success = false; int slave_fd = -1; - int res = 0; - - // We need a new pseudoterminal to avoid bufferring problems. The 'atos' tool - // in particular detects when it's talking to a pipe and forgets to flush the - // output stream after sending a response. - master_fd = posix_openpt(O_RDWR); - CHECK(-1 != master_fd); - CHECK(0 == grantpt(master_fd)); - CHECK(0 == unlockpt(master_fd)); - - char slave_pty_name[128]; - CHECK(0 == ptsname_r(master_fd, slave_pty_name, sizeof(slave_pty_name))); - - TRACE("slave_pty_name: %s", slave_pty_name); - - slave_fd = open(slave_pty_name, O_RDWR); - CHECK(-1 != slave_fd); + *master_fd = -1; + *pid = INVALID_PID; // call setsid() on child so Ctrl-C and all other control characters are set in a different terminal // and process group @@ -155,33 +142,60 @@ int internal_spawn(char *const *argv, char *const *envp, pid_t *pid) posix_spawn_file_actions_t actions; CHECK(0 == posix_spawn_file_actions_init(&actions)); - CHECK(0 == posix_spawn_file_actions_adddup2(&actions, slave_fd, STDIN_FILENO)); - CHECK(0 == posix_spawn_file_actions_adddup2(&actions, slave_fd, STDOUT_FILENO)); - CHECK(0 == posix_spawn_file_actions_adddup2(&actions, slave_fd, STDERR_FILENO)); - CHECK(0 == posix_spawn_file_actions_addclose(&actions, slave_fd)); - CHECK(0 == posix_spawn_file_actions_addclose(&actions, master_fd)); + + if (!background) + { + // We need a new pseudoterminal to avoid bufferring problems. The 'atos' tool + // in particular detects when it's talking to a pipe and forgets to flush the + // output stream after sending a response. + *master_fd = posix_openpt(O_RDWR); + CHECK(-1 != *master_fd); + CHECK(0 == grantpt(*master_fd)); + CHECK(0 == unlockpt(*master_fd)); + + char slave_pty_name[128]; + CHECK(0 == ptsname_r(*master_fd, slave_pty_name, sizeof(slave_pty_name))); + + TRACE("slave_pty_name: %s", slave_pty_name); + + slave_fd = open(slave_pty_name, O_RDWR); + CHECK(-1 != slave_fd); + + CHECK(0 == posix_spawn_file_actions_adddup2(&actions, slave_fd, STDIN_FILENO)); + CHECK(0 == posix_spawn_file_actions_adddup2(&actions, slave_fd, STDOUT_FILENO)); + CHECK(0 == posix_spawn_file_actions_adddup2(&actions, slave_fd, STDERR_FILENO)); + CHECK(0 == posix_spawn_file_actions_addclose(&actions, slave_fd)); + CHECK(0 == posix_spawn_file_actions_addclose(&actions, *master_fd)); + } + else + { + CHECK(0 == posix_spawn_file_actions_addopen(&actions, STDIN_FILENO, "/dev/null", O_RDONLY, 0)); + CHECK(0 == posix_spawn_file_actions_addopen(&actions, STDOUT_FILENO, "/dev/null", O_WRONLY, 0)); + CHECK(0 == posix_spawn_file_actions_addopen(&actions, STDERR_FILENO, "/dev/null", O_WRONLY, 0)); + } CHECK(0 == posix_spawnp(pid, argv[0], &actions, &attr, argv, envp)); + CHECK(*pid != INVALID_PID); posix_spawnattr_destroy(&attr); posix_spawn_file_actions_destroy(&actions); - close(slave_fd); - slave_fd = -1; - - return master_fd; + success = true; error: - if (master_fd != -1) - { - close(master_fd); - } if (slave_fd != -1) { close(slave_fd); } - *pid = INVALID_PID; - return -1; + if (!success) + { + if (*master_fd != -1) + { + close(*master_fd); + } + *pid = INVALID_PID; + } + return success; } bool send_reply(int sockfd, cmd_type_t type) @@ -196,47 +210,30 @@ bool send_reply(int sockfd, cmd_type_t type) typedef struct { int sockfd; pid_t pid; - int pipe[2]; -} wait_process_exit_thread_t; +} thread_notify_client_spawn_error_t; -void wait_process_exit_thread(wait_process_exit_thread_t *params) +void thread_waitpid(pid_t pid) { TRACE("enter"); - u8 byte; - CHECK(sizeof(byte) == write(params->pipe[1], &byte, sizeof(byte))); - TRACE("waitpid"); - s32 err; - CHECK(-1 != waitpid(params->pid, &err, 0)); - - TRACE("waitpid done with err: %d", err); - - CHECK(sizeof(byte) == read(params->pipe[0], &byte, sizeof(byte))); - - cmd_exec_chunk_t chunk; - chunk.type = CMD_EXEC_CHUNK_TYPE_EXITCODE; - chunk.size = sizeof(err); - - CHECK(sendall(params->sockfd, (char *)&chunk, sizeof(chunk))); - CHECK(sendall(params->sockfd, (char *)&err, chunk.size)); - - TRACE("sent exit code to client fd: %d", params->sockfd); - -error: - return; + waitpid(pid, &err, 0); } bool handle_exec(int sockfd) { + u8 byte; pthread_t thread = 0; - wait_process_exit_thread_t *thread_params = NULL; + thread_notify_client_spawn_error_t *thread_params = NULL; pid_t pid = INVALID_PID; int master = -1; - int result = false; + int success = false; char **argv = NULL; char **envp = NULL; u32 argc; u32 envc; + u8 background; + + CHECK(recvall(sockfd, (char *)&background, sizeof(background))); CHECK(recvall(sockfd, (char *)&argc, sizeof(argc))); CHECK(argc > 0); @@ -281,96 +278,90 @@ bool handle_exec(int sockfd) envp[i][len] = '\0'; } - master = internal_spawn((char *const *)argv, envc ? (char *const *)envp : environ, &pid); + CHECK(internal_spawn(background, (char *const *)argv, envc ? (char *const *)envp : environ, &pid, &master)); CHECK(sendall(sockfd, (char *)&pid, sizeof(u32))); - CHECK(master >= 0); - - // create a new thread to wait for the exit of the new process, but lock it until after - // all stdin/stdout/stderr has been forwarded - thread_params = (wait_process_exit_thread_t *)malloc(sizeof(wait_process_exit_thread_t)); - thread_params->sockfd = sockfd; - thread_params->pid = pid; - CHECK(0 == pipe(thread_params->pipe)); - - CHECK(0 == pthread_create(&thread, NULL, (void * (*)(void *))wait_process_exit_thread, thread_params)); - - TRACE("wait for thread to reach waitpid"); - - u8 byte; - CHECK(sizeof(byte) == read(thread_params->pipe[0], &byte, sizeof(byte))); - TRACE("thread reached waitpid. forwarding fds"); + if (background) + { + CHECK(0 == pthread_create(&thread, NULL, (void * (*)(void *))thread_waitpid, (void *)(intptr_t)pid)); + } + else + { + // make sure we have the process fd for its stdout and stderr + CHECK(master >= 0); - fd_set readfds; - char buf[BUFFERSIZE]; - int maxfd = master > sockfd ? master : sockfd; - int nbytes = 0; + fd_set readfds; + char buf[BUFFERSIZE]; + int maxfd = master > sockfd ? master : sockfd; + int nbytes = 0; - fd_set errfds; + fd_set errfds; - while (true) - { - FD_ZERO(&readfds); - FD_SET(master, &readfds); - FD_SET(sockfd, &readfds); + while (true) + { + FD_ZERO(&readfds); + FD_SET(master, &readfds); + FD_SET(sockfd, &readfds); - CHECK(select(maxfd + 1, &readfds, NULL, &errfds, NULL) != -1); + CHECK(select(maxfd + 1, &readfds, NULL, &errfds, NULL) != -1); - if (FD_ISSET(master, &readfds)) - { - nbytes = read(master, buf, BUFFERSIZE); - if (nbytes < 1) + if (FD_ISSET(master, &readfds)) { - TRACE("read master failed. break"); - break; - } + nbytes = read(master, buf, BUFFERSIZE); + if (nbytes < 1) + { + TRACE("read master failed. break"); + break; + } - TRACE("master->sock"); + TRACE("master->sock"); - cmd_exec_chunk_t chunk; - chunk.type = CMD_EXEC_CHUNK_TYPE_STDOUT; - chunk.size = nbytes; + cmd_exec_chunk_t chunk; + chunk.type = CMD_EXEC_CHUNK_TYPE_STDOUT; + chunk.size = nbytes; - CHECK(sendall(sockfd, (char *)&chunk, sizeof(chunk))); - CHECK(sendall(sockfd, buf, chunk.size)); - } + CHECK(sendall(sockfd, (char *)&chunk, sizeof(chunk))); + CHECK(sendall(sockfd, buf, chunk.size)); + } - if (FD_ISSET(sockfd, &readfds)) - { - nbytes = recv(sockfd, buf, BUFFERSIZE, 0); - if (nbytes < 1) + if (FD_ISSET(sockfd, &readfds)) { - break; - } + nbytes = recv(sockfd, buf, BUFFERSIZE, 0); + if (nbytes < 1) + { + break; + } - TRACE("sock->master"); + TRACE("sock->master"); - CHECK(writeall(master, buf, nbytes)); + CHECK(writeall(master, buf, nbytes)); + } } - } - TRACE("notify thread its now okay to send the exit status"); - CHECK(sizeof(byte) == write(thread_params->pipe[1], &byte, sizeof(byte))); - - TRACE("wait for thread to finish"); - CHECK(0 != pthread_join(thread, NULL)); + TRACE("wait for process to finish"); + s32 error; + CHECK(pid == waitpid(pid, &error, 0)); + + cmd_exec_chunk_t chunk; + chunk.type = CMD_EXEC_CHUNK_TYPE_EXITCODE; + chunk.size = sizeof(error); + + CHECK(sendall(sockfd, (const char *)&chunk, sizeof(chunk))); + CHECK(sendall(sockfd, (const char *)&error, sizeof(error))); + } - thread = NULL; - - TRACE("thread exit"); - - result = true; + success = true; error: if (thread_params) { - close(thread_params->pipe[0]); - close(thread_params->pipe[1]); free(thread_params); } if (INVALID_PID == pid) { + TRACE("invalid pid"); + // failed to create process somewhere in the prolog, at least notify sendall(sockfd, (char *)&pid, sizeof(u32)); } @@ -408,7 +399,7 @@ bool handle_exec(int sockfd) } } - return result; + return success; } bool handle_dlopen(int sockfd) @@ -447,8 +438,10 @@ bool handle_dlsym(int sockfd) cmd_dlsym_t cmd; CHECK(recvall(sockfd, (char *)&cmd, sizeof(cmd))); - u64 err = (u64)dlsym((void *)cmd.lib, cmd.symbol_name); - CHECK(sendall(sockfd, (char *)&err, sizeof(err))); + u64 ptr = (u64)dlsym((void *)cmd.lib, cmd.symbol_name); + CHECK(sendall(sockfd, (char *)&ptr, sizeof(ptr))); + + TRACE("%s = %p", cmd.symbol_name, ptr); result = true; @@ -481,6 +474,8 @@ bool handle_call(int sockfd) argv = (u64 *)malloc(sizeof(u64) * cmd.argc); CHECK(recvall(sockfd, (char *)argv, sizeof(u64) * cmd.argc)); + TRACE("address: %p", cmd.address); + switch (cmd.argc) { case 0: @@ -542,7 +537,6 @@ bool handle_peek(int sockfd) cmd_peek_t cmd; #ifdef __APPLE__ - kern_return_t rc; mach_port_t task; vm_offset_t data; mach_msg_type_number_t size; @@ -579,20 +573,47 @@ bool handle_poke(int sockfd) { TRACE("enter"); s64 err = 0; - int result = false; + int success = false; u64 *argv = NULL; + char *data = NULL; cmd_poke_t cmd; + +#ifdef __APPLE__ + mach_port_t task; + CHECK(task_for_pid(mach_task_self(), getpid(), &task) == KERN_SUCCESS); + CHECK(recvall(sockfd, (char *)&cmd, sizeof(cmd))); + + // TODO: consider splitting recieve chunks + data = malloc(cmd.size); + CHECK(data); + CHECK(recvall(sockfd, data, cmd.size)); + + if (vm_write(task, cmd.address, (vm_offset_t)data, cmd.size) == KERN_SUCCESS) + { + CHECK(send_reply(sockfd, CMD_REPLY_POKE)); + } + else + { + CHECK(send_reply(sockfd, CMD_REPLY_ERROR)); + } +#else // __APPLE__ CHECK(recvall(sockfd, (char *)&cmd, sizeof(cmd))); CHECK(recvall(sockfd, (char *)cmd.address, cmd.size)); + CHECK(send_reply(sockfd, CMD_REPLY_POKE)); +#endif // __APPLE__ - result = true; + success = true; error: if (argv) { free(argv); } - return result; + if (data) + { + free(data); + } + return success; } #if __APPLE__ @@ -620,6 +641,7 @@ bool handle_get_dummy_block(int sockfd) void handle_client(int sockfd) { + bool disconnected = false; TRACE("enter. fd: %d", sockfd); // send MAGIC @@ -635,7 +657,10 @@ void handle_client(int sockfd) { protocol_message_t cmd; TRACE("recv"); - CHECK(recvall(sockfd, (char *)&cmd, sizeof(cmd))); + if (!recvall_ext(sockfd, (char *)&cmd, sizeof(cmd), &disconnected)) + { + goto error; + } CHECK(cmd.magic == MAGIC); TRACE("client fd: %d, cmd type: %d", sockfd, cmd.cmd_type); @@ -677,6 +702,11 @@ void handle_client(int sockfd) handle_get_dummy_block(sockfd); break; } + case CMD_CLOSE: + { + // client requested to close connection + goto error; + } default: { TRACE("unknown cmd"); @@ -685,10 +715,14 @@ void handle_client(int sockfd) } error: - TRACE("close client fd: %d", sockfd); - if (0 != close(sockfd)) + if (!disconnected) { - perror("close"); + // if client was disconnected, then os has already closed this fd + TRACE("close client fd: %d", sockfd); + if (0 != close(sockfd)) + { + perror("close"); + } } } @@ -756,6 +790,7 @@ int main(int argc, const char *argv[]) server_fd = socket(servinfo2->ai_family, servinfo2->ai_socktype, servinfo2->ai_protocol); CHECK(server_fd >= 0); + CHECK(-1 != fcntl(server_fd, F_SETFD, FD_CLOEXEC)); int yes_1 = 1; CHECK(0 == setsockopt(server_fd, SOL_SOCKET, SO_REUSEADDR, &yes_1, sizeof(yes_1))); @@ -771,6 +806,7 @@ int main(int argc, const char *argv[]) socklen_t addr_size = sizeof(their_addr); int client_fd = accept(server_fd, (struct sockaddr *)&their_addr, &addr_size); CHECK(client_fd >= 0); + CHECK(-1 != fcntl(client_fd, F_SETFD, FD_CLOEXEC)); char ipstr[INET6_ADDRSTRLEN]; CHECK(inet_ntop(their_addr.ss_family, get_in_addr((struct sockaddr *)&their_addr), ipstr, sizeof(ipstr)));