Skip to content

Commit

Permalink
RDMA: Prevent IO for child process
Browse files Browse the repository at this point in the history
An RDMA MR (memory region) is not forkable: the VMA (virtual memory area)
backing an MR is empty in a child process. Prevent RDMA I/O in a child
process to avoid crashing the server.

As suggested by Viktor, simply close the connection at this stage.

To test for possible fork issues with RDMA, run two valkey-server
instances. After setting random KV pairs on the main server, run the
REPLICAOF command on the other instance. This makes the main server
call fork at an arbitrary time while several clients keep verifying
the KV pairs.

Suggested-by: Viktor Söderqvist <[email protected]>
Signed-off-by: zhenwei pi <[email protected]>
  • Loading branch information
pizhenwei committed Nov 1, 2024
1 parent 1c222f7 commit ad1a956
Show file tree
Hide file tree
Showing 3 changed files with 105 additions and 18 deletions.
36 changes: 31 additions & 5 deletions src/rdma.c
Original file line number Diff line number Diff line change
Expand Up @@ -143,12 +143,34 @@ static void serverRdmaError(char *err, const char *fmt, ...) {
va_end(ap);
}

/* Tell whether the current process may issue RDMA commands.
 *
 * An RDMA MR is not accessible in a forked child process; touching it there
 * would end in a segmentation fault. Refusing the operation lets the caller
 * fail (and close) the connection instead of crashing the server at random.
 *
 * Returns C_OK when server.in_fork_child is CHILD_TYPE_NONE, C_ERR otherwise.
 * Callers in this file use the result directly as a condition and bail out
 * on a non-zero value. */
static inline int connRdmaAllowCommand(void) {
    return (server.in_fork_child == CHILD_TYPE_NONE) ? C_OK : C_ERR;
}

/* Tell whether a read/write may be attempted on `conn`.
 *
 * Returns C_ERR when the connection state is CONN_STATE_ERROR or
 * CONN_STATE_CLOSED; otherwise defers to connRdmaAllowCommand() and returns
 * its result. */
static inline int connRdmaAllowRW(connection *conn) {
    int unusable = (conn->state == CONN_STATE_ERROR) || (conn->state == CONN_STATE_CLOSED);

    return unusable ? C_ERR : connRdmaAllowCommand();
}

static int rdmaPostRecv(RdmaContext *ctx, struct rdma_cm_id *cm_id, ValkeyRdmaCmd *cmd) {
struct ibv_sge sge;
size_t length = sizeof(ValkeyRdmaCmd);
struct ibv_recv_wr recv_wr, *bad_wr;
int ret;

if (connRdmaAllowCommand()) {
return C_ERR;
}

sge.addr = (uint64_t)cmd;
sge.length = length;
sge.lkey = ctx->cmd_mr->lkey;
Expand Down Expand Up @@ -1214,6 +1236,10 @@ static size_t connRdmaSend(connection *conn, const void *data, size_t data_len)
char *remote_addr = ctx->tx_addr + ctx->tx.offset;
int ret;

if (connRdmaAllowCommand()) {
return C_ERR;
}

memcpy(addr, data, data_len);

sge.addr = (uint64_t)addr;
Expand Down Expand Up @@ -1247,7 +1273,7 @@ static int connRdmaWrite(connection *conn, const void *data, size_t data_len) {
RdmaContext *ctx = cm_id->context;
uint32_t towrite;

if (conn->state == CONN_STATE_ERROR || conn->state == CONN_STATE_CLOSED) {
if (connRdmaAllowRW(conn)) {
return C_ERR;
}

Expand Down Expand Up @@ -1290,7 +1316,7 @@ static int connRdmaRead(connection *conn, void *buf, size_t buf_len) {
struct rdma_cm_id *cm_id = rdma_conn->cm_id;
RdmaContext *ctx = cm_id->context;

if (conn->state == CONN_STATE_ERROR || conn->state == CONN_STATE_CLOSED) {
if (connRdmaAllowRW(conn)) {
return C_ERR;
}

Expand All @@ -1312,7 +1338,7 @@ static ssize_t connRdmaSyncWrite(connection *conn, char *ptr, ssize_t size, long
long long start = mstime();
uint32_t towrite;

if (conn->state == CONN_STATE_ERROR || conn->state == CONN_STATE_CLOSED) {
if (connRdmaAllowRW(conn)) {
return C_ERR;
}

Expand Down Expand Up @@ -1355,7 +1381,7 @@ static ssize_t connRdmaSyncRead(connection *conn, char *ptr, ssize_t size, long
long long start = mstime();
uint32_t toread;

if (conn->state == CONN_STATE_ERROR || conn->state == CONN_STATE_CLOSED) {
if (connRdmaAllowRW(conn)) {
return C_ERR;
}

Expand Down Expand Up @@ -1390,7 +1416,7 @@ static ssize_t connRdmaSyncReadLine(connection *conn, char *ptr, ssize_t size, l
char *c;
char nl = 0;

if (conn->state == CONN_STATE_ERROR || conn->state == CONN_STATE_CLOSED) {
if (connRdmaAllowRW(conn)) {
return C_ERR;
}

Expand Down
45 changes: 44 additions & 1 deletion tests/rdma/rdma-test.c
Original file line number Diff line number Diff line change
Expand Up @@ -860,6 +860,7 @@ static RdmaContext *valkeyContextConnectRdma(const char *addr, int port, int tim
}

static int port = 6379;
static int replicaport = 6380;
static char *host = NULL;
static int minkeys = 128;
static int maxkeys = 8192;
Expand Down Expand Up @@ -952,6 +953,39 @@ static void *test_routine(void *arg) {
}
printf("Valkey Over RDMA test thread[%d] GET %d KVs [OK]\n", tid, keys);

/* # round 5, test REPLICAOF, also run REPLICAOF only once */
RdmaContext *replicactx;

replicactx = valkeyContextConnectRdma(host, replicaport, 1000);
if (!replicactx) {
rdmaFatal("RDMA connect to replica failed");
}

char *replicaofcmd = "*3\r\n$9\r\nREPLICAOF\r\n$9\r\n127.0.0.1\r\n$4\r\n6379\r\n";
static int replicaofd;

if (!__atomic_fetch_add(&replicaofd, 1, __ATOMIC_SEQ_CST)) {
valkeyRdmaWrite(replicactx, replicaofcmd, strlen(replicaofcmd));
inbytes = valkeyRdmaReadFull(replicactx, inbuf, strlen(okresp));
assert(!strncmp(okresp, inbuf, inbytes));
printf("Valkey Over RDMA test thread[%d] REPLICAOF [OK]\n", tid);
}

/* # round 6, test GET after REPLICAOF. also verify all the value already set.
* After fork, the child process can't inherit VMA(virtual memory area) of RDMA
* memory region. Test any random crash/error from the valkey-server after REPLICAOF again. */
for (int i = 0; i < keys; i++) {
kv_pair = &kv_pairs[i];
/* build GET command */
outbytes = sprintf(outbuf, "*2\r\n$3\r\nGET\r\n$%ld\r\n%s\r\n",
strlen(kv_pair->key), kv_pair->key);
valkeyRdmaWrite(ctx, outbuf, outbytes);
inbytes = valkeyRdmaReadFull(ctx, inbuf, getrespprexlen + strlen(kv_pair->value) + 2);
assert(!strncmp(getrespprex, inbuf, getrespprexlen));
assert(!strncmp(kv_pair->value, inbuf + getrespprexlen, strlen(kv_pair->value)));
}
printf("Valkey Over RDMA test thread[%d] GET %d KVs after REPLICAOF [OK]\n", tid, keys);

return NULL;
}

Expand All @@ -960,6 +994,7 @@ void usage(char *proc) {
printf("\t--help/-H\n");
printf("\t--host/-h HOSTADDR\n");
printf("\t--port/-p PORT\n");
printf("\t--replica/-P PORT\n");
printf("\t--maxkeys/-M MAXKEYS\n");
printf("\t--minkeys/-M MINKEYS\n");
printf("\t--thread/-t THREADS\n");
Expand All @@ -975,11 +1010,12 @@ int main(int argc, char *argv[])
{ "help", no_argument, NULL, 'H' },
{ "host", required_argument, NULL, 'h' },
{ "port", required_argument, NULL, 'p' },
{ "replica", required_argument, NULL, 'P' },
{ "maxkeys", required_argument, NULL, 'M' },
{ "minkeys", required_argument, NULL, 'm' },
{ "thread", required_argument, NULL, 't' },
};
static char *short_opts = "Hh:p:t:M:m:";
static char *short_opts = "Hh:p:P:t:M:m:";

while (1) {
c = getopt_long(argc, argv, short_opts, long_opts, &args);
Expand All @@ -998,6 +1034,13 @@ int main(int argc, char *argv[])
}
break;

case 'P':
replicaport = atoi(optarg);
if (replicaport <= 0 || replicaport > 65535) {
rdmaFatal("invalid replica port");
}
break;

case 't':
nr_threads = atoi(optarg);
if (nr_threads < 0 || nr_threads > MAX_THREADS) {
Expand Down
42 changes: 30 additions & 12 deletions tests/rdma/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,23 +55,39 @@ def test_rdma(ipaddr):
retval = 0

# step 1, prepare test directory
tmpdir = valkeydir + "/tests/rdma/tmp"
subprocess.Popen("mkdir -p " + tmpdir, shell=True).wait()
tmpdir = valkeydir + "/tests/rdma/tmp/"
subprocess.Popen("mkdir -p " + tmpdir + "6379", shell=True).wait()
subprocess.Popen("mkdir -p " + tmpdir + "6380", shell=True).wait()

# step 2, start server
# step 2, start server and replica
svrpath = valkeydir + "/src/valkey-server"
rdmapath = valkeydir + "/src/valkey-rdma.so"
svrcmd = [svrpath, "--port", "0", "--loglevel", "verbose", "--protected-mode", "yes",
"--appendonly", "no", "--daemonize", "no", "--dir", valkeydir + "/tests/rdma/tmp",
"--loadmodule", rdmapath, "port=6379", "bind=" + ipaddr]
mainsvrcmd = [svrpath, "--port", "6379", "--loglevel", "verbose", "--protected-mode", "yes",
"--appendonly", "no", "--daemonize", "no", "--dir",
valkeydir + "/tests/rdma/tmp/6379",
"--loadmodule", rdmapath, "port=6379", "bind=" + ipaddr]

svr = subprocess.Popen(svrcmd, shell=False, stdout=subprocess.PIPE)
mainsvr = subprocess.Popen(mainsvrcmd, shell=False, stdout=subprocess.PIPE)
try:
if svr.wait(1):
print("Valkey Over RDMA valkey-server runs less than 1s [FAILED]")
if mainsvr.wait(1):
print("Valkey Over RDMA main valkey-server runs less than 1s [FAILED]")
return 1
except subprocess.TimeoutExpired as e:
print("Valkey Over RDMA valkey-server start [OK]")
print("Valkey Over RDMA main valkey-server start [OK]")
pass

replicasvrcmd = [svrpath, "--port", "6380", "--loglevel", "verbose", "--protected-mode", "yes",
"--appendonly", "no", "--daemonize", "no", "--dir",
valkeydir + "/tests/rdma/tmp/6380",
"--loadmodule", rdmapath, "port=6380", "bind=" + ipaddr]

replicasvr = subprocess.Popen(replicasvrcmd, shell=False, stdout=subprocess.PIPE)
try:
if replicasvr.wait(1):
print("Valkey Over RDMA replica valkey-server runs less than 1s [FAILED]")
return 1
except subprocess.TimeoutExpired as e:
print("Valkey Over RDMA replica valkey-server start [OK]")
pass

# step 3, run test client
Expand All @@ -92,8 +108,10 @@ def test_rdma(ipaddr):
retval = 0

# step 4, cleanup
svr.kill()
svr.wait()
mainsvr.kill()
mainsvr.wait()
replicasvr.kill()
replicasvr.wait()
subprocess.Popen("rm -rf " + tmpdir, shell=True).wait()

# step 5, report result
Expand Down

0 comments on commit ad1a956

Please sign in to comment.