From d9f2cf58ee594dd6c960be8c3500f6de942eacf0 Mon Sep 17 00:00:00 2001 From: Nikita Shulga Date: Wed, 23 Mar 2022 14:43:23 -0700 Subject: [PATCH] Create jiterator cache dirs recursively (reland) (#74592) Summary: Reland of https://github.com/pytorch/pytorch/pull/74425 with internal compilation error fixed The change expects the base directories (`HOME/TEMP`, `XDG_CACHE_HOME`, or the user-defined `PYTORCH_KERNEL_CACHE_PATH`) to exist to avoid potentially exploiting the recursive folder creation. Pull Request resolved: https://github.com/pytorch/pytorch/pull/74592 Reviewed By: mruberry Differential Revision: D35066710 Pulled By: malfet fbshipit-source-id: c26aff826b0a3d6ca99286b031711698a515fbbb (cherry picked from commit 99479e5a4fdc7c77080dd806e9dfd96e25f1192d) --- aten/src/ATen/native/cuda/jit_utils.cpp | 74 ++++++++++++++++++++++--- 1 file changed, 67 insertions(+), 7 deletions(-) diff --git a/aten/src/ATen/native/cuda/jit_utils.cpp b/aten/src/ATen/native/cuda/jit_utils.cpp index c8010a6e9b0af..d88a39e261d46 100644 --- a/aten/src/ATen/native/cuda/jit_utils.cpp +++ b/aten/src/ATen/native/cuda/jit_utils.cpp @@ -808,6 +808,65 @@ std::string generate_code( return code; } +// Creates directories recursively +bool _r_mkdir(const std::string& dir) { + // Check if current dir exists + const char* p_dir = dir.c_str(); + const bool dir_exists = (access(p_dir, F_OK) == 0); + if (dir_exists) { + return true; + } + + // Try to create current directory +#ifdef _WIN32 + int ret = _mkdir(dir.c_str()); +#else + int ret = mkdir(dir.c_str(), S_IRWXU | S_IRWXG | S_IRWXO); +#endif + // Success + if (ret == 0) { + return true; + } + + // Find folder separator and check if we are at the top + auto pos = dir.find_last_of("/\\"); + if (pos == std::string::npos) { + return false; + } + + // Try to create parent directory + if (!(_r_mkdir(dir.substr(0, pos)))) { + return false; + } + + // Try to create complete path again +#ifdef _WIN32 + ret = _mkdir(dir.c_str()); +#else + ret = mkdir(dir.c_str(), S_IRWXU | S_IRWXG | S_IRWXO); +#endif + return ret == 0; +} + +// Creates directories recursively assuming that base exists +bool r_mkdir_with_base(std::string& base, std::string& dir){ + const char* p_base = base.c_str(); + const bool base_exists = (access(p_base, F_OK) == 0); + if (!base_exists) { + return false; + } + + // remove trailing '/' or '\\' + if ((base[base.size()-1]=='/') || base[base.size()-1]=='\\') { + base.pop_back(); + } + if ((dir[dir.size()-1]=='/') || dir[dir.size()-1]=='\\') { + dir.pop_back(); + } + + return _r_mkdir(base+dir); +} + // Acquires (possibly creating) the kernel cache directory c10::optional get_cache_dir() { @@ -822,6 +881,8 @@ c10::optional get_cache_dir() { // Cache path comes from PYTORCH_KERNEL_CACHE_PATH, then TEMP (Windows) or XDG_CACHE_HOME (Linux), then HOME environment variables std::string cache_dir; char* ptkcp = std::getenv("PYTORCH_KERNEL_CACHE_PATH"); + // Create kernel_cache_dir if needed as we do not want to create the base directory passed by the user + std::string kernels_cache_dir = ""; if (ptkcp != nullptr) { cache_dir = std::string(ptkcp); } else { @@ -832,7 +893,8 @@ c10::optional get_cache_dir() { ptkcp = std::getenv("XDG_CACHE_HOME"); #endif if (ptkcp != nullptr) { - cache_dir = std::string(ptkcp) + "/torch/kernels"; + kernels_cache_dir = "/torch/kernels"; + cache_dir = std::string(ptkcp) + kernels_cache_dir; } else { // Falls back to HOME/.cache ptkcp = std::getenv("HOME"); @@ -841,7 +903,8 @@ c10::optional get_cache_dir() { " This disables kernel caching."); return {}; } else { - cache_dir = std::string(ptkcp) + "/.cache/torch/kernels"; + kernels_cache_dir = "/.cache/torch/kernels"; + cache_dir = std::string(ptkcp) + kernels_cache_dir; } } } @@ -850,11 +913,8 @@ c10::optional get_cache_dir() { const char* p_cache_dir = cache_dir.c_str(); const bool cache_dir_exists = (access(p_cache_dir, F_OK) == 0); if (!cache_dir_exists) { -#ifdef _WIN32 - if (_mkdir(p_cache_dir) != 0) { -#else - if (mkdir(p_cache_dir, S_IRWXU | S_IRWXG | S_IRWXO) != 0) { -#endif + std::string s_ptkcp = std::string(ptkcp); + if (!r_mkdir_with_base(s_ptkcp, kernels_cache_dir)) { TORCH_WARN_ONCE("Specified kernel cache directory could not be created! This disables kernel caching.", " Specified directory is ", cache_dir, ".", " This warning will appear only once per process.");