diff --git a/src/rpcclient/rpcclient/client.py b/src/rpcclient/rpcclient/client.py index 7b6e9b42..7f3c695c 100644 --- a/src/rpcclient/rpcclient/client.py +++ b/src/rpcclient/rpcclient/client.py @@ -200,7 +200,17 @@ def get_dummy_block(self) -> Symbol: def spawn(self, argv: typing.List[str] = None, envp: typing.List[str] = None, stdin=sys.stdin, stdout=sys.stdout, tty=False, background=False): - """ spawn a new process and forward its stdin, stdout & stderr """ + """ + spawn a new process and forward its stdin, stdout & stderr + + :param argv: argv of the process to be executed + :param envp: envp of the process to be executed + :param stdin: either a file object to read from or a string + :param stdout: a file object to write both stdout and stderr to + :param tty: should enable raw tty mode + :param background: should execute process in background + :return: error code + """ if argv is None: argv = self.DEFAULT_ARGV @@ -392,8 +402,20 @@ def _recvall(self, size: int) -> bytes: return buf def _execution_loop(self, stdin=sys.stdin, stdout=sys.stdout): + """ + if stdin is a file object, we need to select between the fds and give higher priority to stdin. + otherwise, we can simply write all stdin contents directly to the process + """ + fds = [] + if hasattr(stdin, 'fileno'): + fds.append(stdin) + else: + # assume it's just raw bytes + self._sock.sendall(stdin) + fds.append(self._sock) + while True: - rlist, _, _ = select([stdin, self._sock], [], []) + rlist, _, _ = select(fds, [], []) for fd in rlist: if fd == sys.stdin: diff --git a/src/rpcclient/tests/test_spawn.py b/src/rpcclient/tests/test_spawn.py new file mode 100644 index 00000000..6b68f456 --- /dev/null +++ b/src/rpcclient/tests/test_spawn.py @@ -0,0 +1,20 @@ +from io import StringIO + +import pytest + + +@pytest.mark.parametrize('argv,expected_stdout,errorcode', [ + [['/bin/echo', 'blat'], 'blat', 0], + [['/bin/ls', 'INVALID_PATH'], 'ls: INVALID_PATH: No such file or directory', 256], +]) +def test_spawn_sanity(client, argv, expected_stdout, errorcode): + stdout = StringIO() + assert errorcode == client.spawn(argv, stdout=stdout, stdin=b'') + + stdout.seek(0) + assert expected_stdout == stdout.read().strip() + + +def test_spawn_bad_value_stress(client): + for i in range(1000): + assert 256 == client.spawn(['/bin/ls', 'INVALID_PATH'], stdout=StringIO(), stdin=b'') diff --git a/src/rpcserver/rpcserver.c b/src/rpcserver/rpcserver.c index 04e7da0a..ab10a80a 100644 --- a/src/rpcserver/rpcserver.c +++ b/src/rpcserver/rpcserver.c @@ -115,16 +115,6 @@ typedef struct u32 cmd_type; } protocol_message_t; -void sigchld_handler(int s) -{ - (void)s; - - // TODO: close socket associated with this pid - while (waitpid(-1, NULL, WNOHANG) > 0) - ; - TRACE("child died."); -} - void *get_in_addr(struct sockaddr *sa) // get sockaddr, IPv4 or IPv6: { return sa->sa_family == AF_INET ? (void *)&(((struct sockaddr_in *)sa)->sin_addr) : (void *)&(((struct sockaddr_in6 *)sa)->sin6_addr); @@ -198,6 +188,42 @@ bool send_reply(int sockfd, cmd_type_t type) return false; } +typedef struct { + int sockfd; + pid_t pid; + int pipe[2]; +} wait_process_exit_thread_t; + +void wait_process_exit_thread(wait_process_exit_thread_t *params) +{ + TRACE("enter"); + u8 byte; + CHECK(sizeof(byte) == write(params->pipe[1], &byte, sizeof(byte))); + TRACE("waitpid"); + + s32 err; + CHECK(-1 != waitpid(params->pid, &err, 0)); + + TRACE("waitpid done with err: %d", err); + + CHECK(sizeof(byte) == read(params->pipe[0], &byte, sizeof(byte))); + + cmd_exec_chunk_t chunk; + chunk.type = CMD_EXEC_CHUNK_TYPE_EXITCODE; + chunk.size = sizeof(err); + + CHECK(sendall(params->sockfd, (char *)&chunk, sizeof(chunk))); + CHECK(sendall(params->sockfd, (char *)&err, chunk.size)); + + TRACE("sent exit code to client fd: %d", params->sockfd); + +error: + close(params->pipe[0]); + close(params->pipe[1]); + free(params); + return; +} + bool handle_exec(int sockfd) { pid_t pid = INVALID_PID; @@ -255,6 +281,23 @@ bool handle_exec(int sockfd) CHECK(sendall(sockfd, (char *)&pid, sizeof(u32))); CHECK(master >= 0); + // create a new thread to wait for the exit of the new process, but lock it until after + // all stdin/stdout/stderr has been forwarded + wait_process_exit_thread_t *thread_params = (wait_process_exit_thread_t *)malloc(sizeof(wait_process_exit_thread_t)); + thread_params->sockfd = sockfd; + thread_params->pid = pid; + CHECK(0 == pipe(thread_params->pipe)); + + pthread_t thread; + CHECK(0 == pthread_create(&thread, NULL, wait_process_exit_thread, thread_params)); + + TRACE("wait for thread to reach waitpid"); + + u8 byte; + CHECK(sizeof(byte) == read(thread_params->pipe[0], &byte, sizeof(byte))); + + TRACE("thread reached waitpid. forwarding fds"); + fd_set readfds; char buf[BUFFERSIZE]; int maxfd = master > sockfd ? master : sockfd; @@ -297,18 +340,13 @@ bool handle_exec(int sockfd) } } - s32 err = 0; + TRACE("notify thread its now okay to send the exit status"); + CHECK(sizeof(byte) == write(thread_params->pipe[1], &byte, sizeof(byte))); - // dont exit on error here so client is notified right away when the process dies - waitpid(pid, &err, 0); - cmd_exec_chunk_t chunk; - chunk.type = CMD_EXEC_CHUNK_TYPE_EXITCODE; - chunk.size = sizeof(err); - - CHECK(sendall(sockfd, (char *)&chunk, sizeof(chunk))); - CHECK(sendall(sockfd, (char *)&err, chunk.size)); - - TRACE("sent exit code to client fd: %d", sockfd); + TRACE("wait for thread to finish"); + CHECK(0 != pthread_join(&thread, NULL)); + + TRACE("thread exit"); result = true; @@ -686,12 +724,6 @@ int main(int argc, const char *argv[]) CHECK(0 == listen(server_fd, MAX_CONNECTIONS)); - struct sigaction sa; - sa.sa_handler = sigchld_handler; // reap all dead processes - sigemptyset(&sa.sa_mask); - sa.sa_flags = SA_RESTART; - CHECK(0 == sigaction(SIGCHLD, &sa, NULL)); - while (1) { struct sockaddr_storage their_addr; // connector's address information