diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index 772a97e560..3a42e44401 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -107,7 +107,12 @@ include(cmake/thirdparty/get_thread_pool.cmake) # ################################################################################################## # * library targets -------------------------------------------------------------------------------- -file(GLOB SOURCES "src/*.cpp") +set(SOURCES "src/file_handle.cpp") + +if(KvikIO_REMOTE_SUPPORT) + list(APPEND SOURCES "src/remote_handle.cpp") +endif() + add_library(kvikio ${SOURCES}) # To avoid symbol conflicts when statically linking to libcurl.a (see get_libcurl.cmake) and its diff --git a/cpp/include/kvikio/file_handle.hpp b/cpp/include/kvikio/file_handle.hpp index 97c0ba9748..141c17371a 100644 --- a/cpp/include/kvikio/file_handle.hpp +++ b/cpp/include/kvikio/file_handle.hpp @@ -15,14 +15,11 @@ */ #pragma once -#include #include #include -#include #include #include -#include #include #include @@ -37,96 +34,6 @@ #include namespace kvikio { -namespace detail { - -/** - * @brief Parse open file flags given as a string and return oflags - * - * @param flags The flags - * @param o_direct Append O_DIRECT to the open flags - * @return oflags - * - * @throw std::invalid_argument if the specified flags are not supported. - * @throw std::invalid_argument if `o_direct` is true, but `O_DIRECT` is not supported. - */ -inline int open_fd_parse_flags(const std::string& flags, bool o_direct) -{ - int file_flags = -1; - if (flags.empty()) { throw std::invalid_argument("Unknown file open flag"); } - switch (flags[0]) { - case 'r': - file_flags = O_RDONLY; - if (flags[1] == '+') { file_flags = O_RDWR; } - break; - case 'w': - file_flags = O_WRONLY; - if (flags[1] == '+') { file_flags = O_RDWR; } - file_flags |= O_CREAT | O_TRUNC; - break; - case 'a': throw std::invalid_argument("Open flag 'a' isn't supported"); - default: throw std::invalid_argument("Unknown file open flag"); - } - file_flags |= O_CLOEXEC; - if (o_direct) { -#if defined(O_DIRECT) - file_flags |= O_DIRECT; -#else - throw std::invalid_argument("'o_direct' flag unsupported on this platform"); -#endif - } - return file_flags; -} - -/** - * @brief Open file using `open(2)` - * - * @param flags Open flags given as a string - * @param o_direct Append O_DIRECT to `flags` - * @param mode Access modes - * @return File descriptor - */ -inline int open_fd(const std::string& file_path, - const std::string& flags, - bool o_direct, - mode_t mode) -{ - // NOLINTNEXTLINE(cppcoreguidelines-pro-type-vararg) - int fd = ::open(file_path.c_str(), open_fd_parse_flags(flags, o_direct), mode); - if (fd == -1) { throw std::system_error(errno, std::generic_category(), "Unable to open file"); } - return fd; -} - -/** - * @brief Get the flags of the file descriptor (see `open(2)`) - * - * @return Open flags - */ -[[nodiscard]] inline int open_flags(int fd) -{ - int ret = fcntl(fd, F_GETFL); // NOLINT(cppcoreguidelines-pro-type-vararg) - if (ret == -1) { - throw std::system_error(errno, std::generic_category(), "Unable to retrieve open flags"); - } - return ret; -} - -/** - * @brief Get file size from file descriptor `fstat(3)` - * - * @param file_descriptor Open file descriptor - * @return The number of bytes - */ -[[nodiscard]] inline std::size_t get_file_size(int file_descriptor) -{ - struct stat st {}; - int ret = fstat(file_descriptor, &st); - if (ret == -1) { - throw std::system_error(errno, std::generic_category(), "Unable to query file size"); - } - return static_cast(st.st_size); -} - -} // namespace detail /** * @brief Handle of an open file registered with cufile. @@ -166,33 +73,7 @@ class FileHandle { FileHandle(const std::string& file_path, const std::string& flags = "r", mode_t mode = m644, - bool compat_mode = defaults::compat_mode()) - : _fd_direct_off{detail::open_fd(file_path, flags, false, mode)}, - _initialized{true}, - _compat_mode{compat_mode} - { - if (_compat_mode) { - return; // Nothing to do in compatibility mode - } - - // Try to open the file with the O_DIRECT flag. Fall back to compatibility mode, if it fails. - try { - _fd_direct_on = detail::open_fd(file_path, flags, true, mode); - } catch (const std::system_error&) { - _compat_mode = true; - } catch (const std::invalid_argument&) { - _compat_mode = true; - } - - // Create a cuFile handle, if not in compatibility mode - if (!_compat_mode) { - CUfileDescr_t desc{}; // It is important to set to zero! - desc.type = CU_FILE_HANDLE_TYPE_OPAQUE_FD; - // NOLINTNEXTLINE(cppcoreguidelines-pro-type-union-access) - desc.handle.fd = _fd_direct_on; - CUFILE_TRY(cuFileAPI::instance().HandleRegister(&_handle, &desc)); - } - } + bool compat_mode = defaults::compat_mode()); /** * @brief FileHandle support move semantic but isn't copyable @@ -274,7 +155,7 @@ class FileHandle { * * @return File descriptor */ - [[nodiscard]] int fd_open_flags() const { return detail::open_flags(_fd_direct_off); } + [[nodiscard]] int fd_open_flags() const; /** * @brief Get the file size @@ -283,12 +164,7 @@ class FileHandle { * * @return The number of bytes */ - [[nodiscard]] std::size_t nbytes() const - { - if (closed()) { return 0; } - if (_nbytes == 0) { _nbytes = detail::get_file_size(_fd_direct_off); } - return _nbytes; - } + [[nodiscard]] std::size_t nbytes() const; /** * @brief Reads specified bytes from the file into the device memory. diff --git a/cpp/src/file_handle.cpp b/cpp/src/file_handle.cpp new file mode 100644 index 0000000000..c5b7ada59a --- /dev/null +++ b/cpp/src/file_handle.cpp @@ -0,0 +1,158 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +namespace kvikio { + +namespace { + +/** + * @brief Parse open file flags given as a string and return oflags + * + * @param flags The flags + * @param o_direct Append O_DIRECT to the open flags + * @return oflags + * + * @throw std::invalid_argument if the specified flags are not supported. + * @throw std::invalid_argument if `o_direct` is true, but `O_DIRECT` is not supported. + */ +int open_fd_parse_flags(const std::string& flags, bool o_direct) +{ + int file_flags = -1; + if (flags.empty()) { throw std::invalid_argument("Unknown file open flag"); } + switch (flags[0]) { + case 'r': + file_flags = O_RDONLY; + if (flags[1] == '+') { file_flags = O_RDWR; } + break; + case 'w': + file_flags = O_WRONLY; + if (flags[1] == '+') { file_flags = O_RDWR; } + file_flags |= O_CREAT | O_TRUNC; + break; + case 'a': throw std::invalid_argument("Open flag 'a' isn't supported"); + default: throw std::invalid_argument("Unknown file open flag"); + } + file_flags |= O_CLOEXEC; + if (o_direct) { +#if defined(O_DIRECT) + file_flags |= O_DIRECT; +#else + throw std::invalid_argument("'o_direct' flag unsupported on this platform"); +#endif + } + return file_flags; +} + +/** + * @brief Open file using `open(2)` + * + * @param flags Open flags given as a string + * @param o_direct Append O_DIRECT to `flags` + * @param mode Access modes + * @return File descriptor + */ +int open_fd(const std::string& file_path, const std::string& flags, bool o_direct, mode_t mode) +{ + // NOLINTNEXTLINE(cppcoreguidelines-pro-type-vararg) + int fd = ::open(file_path.c_str(), open_fd_parse_flags(flags, o_direct), mode); + if (fd == -1) { throw std::system_error(errno, std::generic_category(), "Unable to open file"); } + return fd; +} + +/** + * @brief Get the flags of the file descriptor (see `open(2)`) + * + * @return Open flags + */ +[[nodiscard]] int open_flags(int fd) +{ + int ret = fcntl(fd, F_GETFL); // NOLINT(cppcoreguidelines-pro-type-vararg) + if (ret == -1) { + throw std::system_error(errno, std::generic_category(), "Unable to retrieve open flags"); + } + return ret; +} + +/** + * @brief Get file size from file descriptor `fstat(3)` + * + * @param file_descriptor Open file descriptor + * @return The number of bytes + */ +[[nodiscard]] std::size_t get_file_size(int file_descriptor) +{ + struct stat st {}; + int ret = fstat(file_descriptor, &st); + if (ret == -1) { + throw std::system_error(errno, std::generic_category(), "Unable to query file size"); + } + return static_cast(st.st_size); +} + +} // namespace + +FileHandle::FileHandle(const std::string& file_path, + const std::string& flags, + mode_t mode, + bool compat_mode) + : _fd_direct_off{open_fd(file_path, flags, false, mode)}, + _initialized{true}, + _compat_mode{compat_mode} +{ + if (_compat_mode) { + return; // Nothing to do in compatibility mode + } + + // Try to open the file with the O_DIRECT flag. Fall back to compatibility mode, if it fails. + try { + _fd_direct_on = open_fd(file_path, flags, true, mode); + } catch (const std::system_error&) { + _compat_mode = true; + } catch (const std::invalid_argument&) { + _compat_mode = true; + } + + // Create a cuFile handle, if not in compatibility mode + if (!_compat_mode) { + CUfileDescr_t desc{}; // It is important to set to zero! + desc.type = CU_FILE_HANDLE_TYPE_OPAQUE_FD; + // NOLINTNEXTLINE(cppcoreguidelines-pro-type-union-access) + desc.handle.fd = _fd_direct_on; + CUFILE_TRY(cuFileAPI::instance().HandleRegister(&_handle, &desc)); + } +} + +[[nodiscard]] int FileHandle::fd_open_flags() const { return open_flags(_fd_direct_off); } + +[[nodiscard]] std::size_t FileHandle::nbytes() const +{ + if (closed()) { return 0; } + if (_nbytes == 0) { _nbytes = get_file_size(_fd_direct_off); } + return _nbytes; +} + +} // namespace kvikio diff --git a/cpp/src/remote_handle.cpp b/cpp/src/remote_handle.cpp index 527811e143..adcf56befc 100644 --- a/cpp/src/remote_handle.cpp +++ b/cpp/src/remote_handle.cpp @@ -19,8 +19,6 @@ #include #include #include -#include -#include #include #include #include