Skip to content

Commit

Permalink
server: bugfix: race in spawn error code
Browse files Browse the repository at this point in the history
  • Loading branch information
doronz88 committed Feb 20, 2022
1 parent cae7a13 commit 4ea8374
Show file tree
Hide file tree
Showing 3 changed files with 103 additions and 29 deletions.
26 changes: 24 additions & 2 deletions src/rpcclient/rpcclient/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,17 @@ def get_dummy_block(self) -> Symbol:

def spawn(self, argv: typing.List[str] = None, envp: typing.List[str] = None, stdin=sys.stdin, stdout=sys.stdout,
tty=False, background=False):
""" spawn a new process and forward its stdin, stdout & stderr """
"""
spawn a new process and forward its stdin, stdout & stderr
:param argv: argv of the process to be executed
:param envp: envp of the process to be executed
:param stdin: either a file object to read from or a string
:param stdout: a file object to write both stdout and stderr to
:param tty: should enable raw tty mode
:param background: should execute process in background
:return: error code
"""
if argv is None:
argv = self.DEFAULT_ARGV

Expand Down Expand Up @@ -392,8 +402,20 @@ def _recvall(self, size: int) -> bytes:
return buf

def _execution_loop(self, stdin=sys.stdin, stdout=sys.stdout):
"""
if stdin is a file object, we need to select between the fds and give higher priority to stdin.
otherwise, we can simply write all stdin contents directly to the process
"""
fds = []
if hasattr(stdin, 'fileno'):
fds.append(stdin)
else:
# assume it's just raw bytes
self._sock.sendall(stdin)
fds.append(self._sock)

while True:
rlist, _, _ = select([stdin, self._sock], [], [])
rlist, _, _ = select(fds, [], [])

for fd in rlist:
if fd == sys.stdin:
Expand Down
20 changes: 20 additions & 0 deletions src/rpcclient/tests/test_spawn.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
from io import StringIO

import pytest


@pytest.mark.parametrize('argv,expected_stdout,errorcode', [
[['/bin/echo', 'blat'], 'blat', 0],
[['/bin/ls', 'INVALID_PATH'], 'ls: INVALID_PATH: No such file or directory', 256],
])
def test_spawn_sanity(client, argv, expected_stdout, errorcode):
stdout = StringIO()
assert errorcode == client.spawn(argv, stdout=stdout, stdin=b'')

stdout.seek(0)
assert expected_stdout == stdout.read().strip()


def test_spawn_bad_value_stress(client):
for i in range(1000):
assert 256 == client.spawn(['/bin/ls', 'INVALID_PATH'], stdout=StringIO(), stdin=b'')
86 changes: 59 additions & 27 deletions src/rpcserver/rpcserver.c
Original file line number Diff line number Diff line change
Expand Up @@ -115,16 +115,6 @@ typedef struct
u32 cmd_type;
} protocol_message_t;

void sigchld_handler(int s)
{
(void)s;

// TODO: close socket associated with this pid
while (waitpid(-1, NULL, WNOHANG) > 0)
;
TRACE("child died.");
}

void *get_in_addr(struct sockaddr *sa) // get sockaddr, IPv4 or IPv6:
{
return sa->sa_family == AF_INET ? (void *)&(((struct sockaddr_in *)sa)->sin_addr) : (void *)&(((struct sockaddr_in6 *)sa)->sin6_addr);
Expand Down Expand Up @@ -198,6 +188,42 @@ bool send_reply(int sockfd, cmd_type_t type)
return false;
}

typedef struct {
int sockfd;
pid_t pid;
int pipe[2];
} wait_process_exit_thread_t;

void wait_process_exit_thread(wait_process_exit_thread_t *params)
{
TRACE("enter");
u8 byte;
CHECK(sizeof(byte) == write(params->pipe[1], &byte, sizeof(byte)));
TRACE("waitpid");

s32 err;
CHECK(-1 != waitpid(params->pid, &err, 0));

TRACE("waitpid done with err: %d", err);

CHECK(sizeof(byte) == read(params->pipe[0], &byte, sizeof(byte)));

cmd_exec_chunk_t chunk;
chunk.type = CMD_EXEC_CHUNK_TYPE_EXITCODE;
chunk.size = sizeof(err);

CHECK(sendall(params->sockfd, (char *)&chunk, sizeof(chunk)));
CHECK(sendall(params->sockfd, (char *)&err, chunk.size));

TRACE("sent exit code to client fd: %d", params->sockfd);

error:
close(params->pipe[0]);
close(params->pipe[1]);
free(params);
return;
}

bool handle_exec(int sockfd)
{
pid_t pid = INVALID_PID;
Expand Down Expand Up @@ -255,6 +281,23 @@ bool handle_exec(int sockfd)
CHECK(sendall(sockfd, (char *)&pid, sizeof(u32)));
CHECK(master >= 0);

// create a new thread to wait for the exit of the new process, but lock it until after
// all stdin/stdout/stderr has been forwarded
wait_process_exit_thread_t *thread_params = (wait_process_exit_thread_t *)malloc(sizeof(wait_process_exit_thread_t));
thread_params->sockfd = sockfd;
thread_params->pid = pid;
CHECK(0 == pipe(thread_params->pipe));

pthread_t thread;
CHECK(0 == pthread_create(&thread, NULL, wait_process_exit_thread, thread_params));

TRACE("wait for thread to reach waitpid");

u8 byte;
CHECK(sizeof(byte) == read(thread_params->pipe[0], &byte, sizeof(byte)));

TRACE("thread reached waitpid. forwarding fds");

fd_set readfds;
char buf[BUFFERSIZE];
int maxfd = master > sockfd ? master : sockfd;
Expand Down Expand Up @@ -297,18 +340,13 @@ bool handle_exec(int sockfd)
}
}

s32 err = 0;
TRACE("notify thread its now okay to send the exit status");
CHECK(sizeof(byte) == write(thread_params->pipe[1], &byte, sizeof(byte)));

// dont exit on error here so client is notified right away when the process dies
waitpid(pid, &err, 0);
cmd_exec_chunk_t chunk;
chunk.type = CMD_EXEC_CHUNK_TYPE_EXITCODE;
chunk.size = sizeof(err);

CHECK(sendall(sockfd, (char *)&chunk, sizeof(chunk)));
CHECK(sendall(sockfd, (char *)&err, chunk.size));

TRACE("sent exit code to client fd: %d", sockfd);
TRACE("wait for thread to finish");
CHECK(0 != pthread_join(&thread, NULL));

TRACE("thread exit");

result = true;

Expand Down Expand Up @@ -686,12 +724,6 @@ int main(int argc, const char *argv[])

CHECK(0 == listen(server_fd, MAX_CONNECTIONS));

struct sigaction sa;
sa.sa_handler = sigchld_handler; // reap all dead processes
sigemptyset(&sa.sa_mask);
sa.sa_flags = SA_RESTART;
CHECK(0 == sigaction(SIGCHLD, &sa, NULL));

while (1)
{
struct sockaddr_storage their_addr; // connector's address information
Expand Down

0 comments on commit 4ea8374

Please sign in to comment.