diff --git a/src/libcrun/status.c b/src/libcrun/status.c index 1c7f24f46f..836104cfbe 100644 --- a/src/libcrun/status.c +++ b/src/libcrun/status.c @@ -45,6 +45,16 @@ struct pid_stat unsigned long long starttime; }; +/* If ID is not NULL, then ennsure that it does not contain any slash. */ +static int +validate_id (const char *id, libcrun_error_t *err) +{ + if (id && strchr (id, '/') != NULL) + return crun_make_error (err, 0, "invalid character `/` in the ID `%s`", id); + + return 0; +} + static int get_run_directory (char **out, const char *state_root, libcrun_error_t *err) { @@ -82,6 +92,10 @@ libcrun_get_state_directory (char **out, const char *state_root, const char *id, cleanup_free char *path = NULL; cleanup_free char *root = NULL; + ret = validate_id (id, err); + if (UNLIKELY (ret < 0)) + return ret; + ret = get_run_directory (&root, state_root, err); if (UNLIKELY (ret < 0)) return ret; @@ -102,6 +116,10 @@ get_state_directory_status_file (char **out, const char *state_root, const char char *path = NULL; int ret; + ret = validate_id (id, err); + if (UNLIKELY (ret < 0)) + return ret; + ret = get_run_directory (&root, state_root, err); if (UNLIKELY (ret < 0)) return ret; @@ -551,6 +569,10 @@ libcrun_container_delete_status (const char *state_root, const char *id, libcrun cleanup_close int dfd = -1; cleanup_free char *dir = NULL; + ret = validate_id (id, err); + if (UNLIKELY (ret < 0)) + return ret; + ret = get_run_directory (&dir, state_root, err); if (UNLIKELY (ret < 0)) return ret; diff --git a/tests/test_start.py b/tests/test_start.py index da74820c26..9b0744275f 100755 --- a/tests/test_start.py +++ b/tests/test_start.py @@ -540,6 +540,22 @@ def test_run_keep(): return 0 +def test_invalid_id(): + conf = base_config() + conf['process']['args'] = ['./init', 'echo', 'hello'] + conf['process']['cwd'] = "/sbin" + add_all_namespaces(conf) + try: + out, _ = run_and_get_output(conf, id_container="this/is/invalid") + return -1 + except Exception as e: + err = e.output.decode() + if "invalid character `/` in the ID" in err: + return 0 + sys.stderr.write("Got error: %s\n" % err) + return -1 + return 0 + all_tests = { "start" : test_start, "start-override-config" : test_start_override_config, @@ -562,6 +578,7 @@ def test_run_keep(): "unknown-sysctl": test_unknown_sysctl, "ioprio": test_ioprio, "run-keep": test_run_keep, + "invalid-id": test_invalid_id, } if __name__ == "__main__":