Skip to content

Commit

Permalink
Create jiterator cache dirs recursively (reland) (pytorch#74592)
Browse files Browse the repository at this point in the history
Summary:
Reland of pytorch#74425 with internal compilation error fixed

The change expects the base directories (`HOME/TEMP`, `XDG_CACHE_HOME`, or the user-defined `PYTORCH_KERNEL_CACHE_PATH`) to exist to avoid potentially exploiting the recursive folder creation.

Pull Request resolved: pytorch#74592

Reviewed By: mruberry

Differential Revision: D35066710

Pulled By: malfet

fbshipit-source-id: c26aff826b0a3d6ca99286b031711698a515fbbb
(cherry picked from commit 99479e5)
  • Loading branch information
malfet authored and pytorchmergebot committed Mar 23, 2022
1 parent 85d8647 commit d9f2cf5
Showing 1 changed file with 67 additions and 7 deletions.
74 changes: 67 additions & 7 deletions aten/src/ATen/native/cuda/jit_utils.cpp
Original file line number Diff line number Diff line change
@@ -808,6 +808,65 @@ std::string generate_code(
return code;
}

// Creates directories recursively
bool _r_mkdir(const std::string& dir) {
// Check if current dir exists
const char* p_dir = dir.c_str();
const bool dir_exists = (access(p_dir, F_OK) == 0);
if (dir_exists) {
return true;
}

// Try to create current directory
#ifdef _WIN32
int ret = _mkdir(dir.c_str());
#else
int ret = mkdir(dir.c_str(), S_IRWXU | S_IRWXG | S_IRWXO);
#endif
// Success
if (ret == 0) {
return true;
}

// Find folder separator and check if we are at the top
auto pos = dir.find_last_of("/\\");
if (pos == std::string::npos) {
return false;
}

// Try to create parent directory
if (!(_r_mkdir(dir.substr(0, pos)))) {
return false;
}

// Try to create complete path again
#ifdef _WIN32
ret = _mkdir(dir.c_str());
#else
ret = mkdir(dir.c_str(), S_IRWXU | S_IRWXG | S_IRWXO);
#endif
return ret == 0;
}

// Creates directories recursively assuming that base exists
bool r_mkdir_with_base(std::string& base, std::string& dir){
const char* p_base = base.c_str();
const bool base_exists = (access(p_base, F_OK) == 0);
if (!base_exists) {
return false;
}

// remove trailing '/' or '\\'
if ((base[base.size()-1]=='/') || base[base.size()-1]=='\\') {
base.pop_back();
}
if ((dir[dir.size()-1]=='/') || dir[dir.size()-1]=='\\') {
dir.pop_back();
}

return _r_mkdir(base+dir);
}


// Acquires (possibly creating) the kernel cache directory
c10::optional<std::string> get_cache_dir() {
@@ -822,6 +881,8 @@ c10::optional<std::string> get_cache_dir() {
// Cache path comes from PYTORCH_KERNEL_CACHE_PATH, then TEMP (Windows) or XDG_CACHE_HOME (Linux), then HOME environment variables
std::string cache_dir;
char* ptkcp = std::getenv("PYTORCH_KERNEL_CACHE_PATH");
// Create kernel_cache_dir if needed as we do not want to create the base directory passed by the user
std::string kernels_cache_dir = "";
if (ptkcp != nullptr) {
cache_dir = std::string(ptkcp);
} else {
@@ -832,7 +893,8 @@ c10::optional<std::string> get_cache_dir() {
ptkcp = std::getenv("XDG_CACHE_HOME");
#endif
if (ptkcp != nullptr) {
cache_dir = std::string(ptkcp) + "/torch/kernels";
kernels_cache_dir = "/torch/kernels";
cache_dir = std::string(ptkcp) + kernels_cache_dir;
} else {
// Falls back to HOME/.cache
ptkcp = std::getenv("HOME");
@@ -841,7 +903,8 @@ c10::optional<std::string> get_cache_dir() {
" This disables kernel caching.");
return {};
} else {
cache_dir = std::string(ptkcp) + "/.cache/torch/kernels";
kernels_cache_dir = "/.cache/torch/kernels";
cache_dir = std::string(ptkcp) + kernels_cache_dir;
}
}
}
@@ -850,11 +913,8 @@ c10::optional<std::string> get_cache_dir() {
const char* p_cache_dir = cache_dir.c_str();
const bool cache_dir_exists = (access(p_cache_dir, F_OK) == 0);
if (!cache_dir_exists) {
#ifdef _WIN32
if (_mkdir(p_cache_dir) != 0) {
#else
if (mkdir(p_cache_dir, S_IRWXU | S_IRWXG | S_IRWXO) != 0) {
#endif
std::string s_ptkcp = std::string(ptkcp);
if (!r_mkdir_with_base(s_ptkcp, kernels_cache_dir)) {
TORCH_WARN_ONCE("Specified kernel cache directory could not be created! This disables kernel caching.",
" Specified directory is ", cache_dir, ".",
" This warning will appear only once per process.");

0 comments on commit d9f2cf5

Please sign in to comment.