Skip to content

Commit

Permalink
Cleanup.
Browse files Browse the repository at this point in the history
  • Loading branch information
trivialfis committed Jun 12, 2023
1 parent a6202d0 commit 117fb97
Show file tree
Hide file tree
Showing 5 changed files with 44 additions and 24 deletions.
11 changes: 6 additions & 5 deletions demo/guide-python/external_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,21 +75,22 @@ def reset(self) -> None:

def main(tmpdir: str) -> xgboost.Booster:
# generate some random data for demo
files = make_batches(2 ** 16, 17, 31, tmpdir)
files = make_batches(1024, 17, 31, tmpdir)
it = Iterator(files)
# For non-data arguments, specify it here once instead of passing them by the `next`
# method.
missing = np.NaN
Xy = xgboost.DMatrix(it, missing=missing, enable_categorical=False)

# Other tree methods including ``hist`` and ``gpu_hist`` also work, see tutorial in
# doc for details.
# Other tree methods including ``approx``, ``hist``, and ``gpu_hist`` are supported,
# see tutorial in doc for details.
booster = xgboost.train(
{"tree_method": "gpu_hist", "max_depth": 6, "sampling_method": "gradient_based", "subsample": 0.5},
{"tree_method": "hist", "max_depth": 4},
Xy,
evals=[(Xy, "Train")],
num_boost_round=2,
num_boost_round=10,
)
return booster


if __name__ == "__main__":
Expand Down
5 changes: 3 additions & 2 deletions doc/tutorials/external_memory.rst
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,9 @@ not supported by ``exact`` tree method.

.. warning::

The implementation of external memory uses ``mmap`` and is not tested against errors
like disconnected network devices. (`SIGBUS`)
The implementation of external memory uses ``mmap`` and is not tested against system
errors like disconnected network devices (`SIGBUS`). In addition, Windows is not yet
supported.

.. note::

Expand Down
3 changes: 1 addition & 2 deletions src/common/io.cc
Original file line number Diff line number Diff line change
Expand Up @@ -170,8 +170,7 @@ void* PrivateMmapStream::Open(StringView path, bool read_only, std::size_t offse
#if defined(__linux__) || defined(__GLIBC__)
ptr = reinterpret_cast<char*>(mmap64(nullptr, length, prot, MAP_PRIVATE, fd_, offset));
#elif defined(_MSC_VER)
// fixme: not yet implemented
ptr = reinterpret_cast<char*>(mmap(nullptr, length, prot, MAP_PRIVATE, fd_, offset));
LOG(FATAL) << "External memory is not implemented for Windows.";
#else
CHECK_LE(offset, std::numeric_limits<off_t>::max())
<< "File size has exceeded the limit on the current system.";
Expand Down
20 changes: 20 additions & 0 deletions src/data/sparse_page_source.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
/**
* Copyright 2023, XGBoost Contributors
*/
#include "sparse_page_source.h"

#include <unistd.h> // for getpagesize

namespace xgboost::data {
std::size_t PadPageForMMAP(std::size_t file_bytes, dmlc::Stream* fo) {
decltype(file_bytes) page_size = getpagesize();
CHECK(page_size != 0 && page_size % 2 == 0) << "Failed to get page size on the current system.";
CHECK_NE(file_bytes, 0) << "Empty page encountered.";
auto n = file_bytes / page_size;
auto padded = (n + !!(file_bytes % page_size != 0)) * page_size;
auto padding = padded - file_bytes;
std::vector<std::uint8_t> padding_bytes(padding, 0);
fo->Write(padding_bytes.data(), padding_bytes.size());
return padded;
}
} // namespace xgboost::data
29 changes: 14 additions & 15 deletions src/data/sparse_page_source.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,6 @@
#ifndef XGBOOST_DATA_SPARSE_PAGE_SOURCE_H_
#define XGBOOST_DATA_SPARSE_PAGE_SOURCE_H_

#include <unistd.h> // for getpagesize

#include <algorithm> // for min
#include <future>
#include <map>
Expand Down Expand Up @@ -34,14 +32,23 @@ inline void TryDeleteCacheFile(const std::string& file) {
}
}

/**
* @brief Pad the output file for a page to make it mmap compatible.
*
* @param file_bytes The size of the output file
* @param fo Stream used to write the file.
*
* @return The file size after being padded.
*/
std::size_t PadPageForMMAP(std::size_t file_bytes, dmlc::Stream* fo);

struct Cache {
// whether the write to the cache is complete
bool written;
std::string name;
std::string format;
// offset into binary cache file.
std::vector<size_t> offset;
std::vector<std::uint64_t> bytes;

Cache(bool w, std::string n, std::string fmt)
: written{w}, name{std::move(n)}, format{std::move(fmt)} {
Expand All @@ -57,7 +64,6 @@ struct Cache {
return ShardName(this->name, this->format);
}
void Push(std::size_t n_bytes) {
bytes.push_back(n_bytes);
offset.push_back(n_bytes);
}

Expand Down Expand Up @@ -139,7 +145,7 @@ class SparsePageSourceImpl : public BatchIteratorImpl<S> {
auto n = self->cache_info_->ShardName();

std::uint64_t offset = self->cache_info_->offset.at(fetch_it);
std::uint64_t length = self->cache_info_->bytes.at(fetch_it);
std::uint64_t length = self->cache_info_->offset.at(fetch_it + 1) - offset;

auto fi = std::make_unique<common::PrivateMmapStream>(n, true, offset, length);
CHECK(fmt->Read(page.get(), fi.get()));
Expand All @@ -151,6 +157,7 @@ class SparsePageSourceImpl : public BatchIteratorImpl<S> {
n_prefetch_batches)
<< "Sparse DMatrix assumes forward iteration.";
page_ = (*ring_)[count_].get();
CHECK(!(*ring_)[count_].valid());
return true;
}

Expand All @@ -169,18 +176,9 @@ class SparsePageSourceImpl : public BatchIteratorImpl<S> {
}

auto bytes = fmt->Write(*page_, fo.get());

// align for mmap
decltype(bytes) page_size = getpagesize();
CHECK(page_size != 0 && page_size % 2 == 0) << "Failed to get page size on the current system.";
auto n = bytes / page_size;
auto padded = (n + 1) * page_size;
auto padding = padded - bytes;
std::vector<std::uint8_t> padding_bytes(padding, 0);
fo->Write(padding_bytes.data(), padding_bytes.size());
auto padded = PadPageForMMAP(bytes, fo.get());

timer.Stop();

LOG(INFO) << static_cast<double>(bytes) / 1024.0 / 1024.0 << " MB written in "
<< timer.ElapsedSeconds() << " seconds.";
cache_info_->Push(padded);
Expand Down Expand Up @@ -280,6 +278,7 @@ class SparsePageSource : public SparsePageSourceImpl<SparsePage> {
}

if (at_end_) {
CHECK_EQ(cache_info_->offset.size(), n_batches_ + 1);
cache_info_->Commit();
if (n_batches_ != 0) {
CHECK_EQ(count_, n_batches_);
Expand Down

0 comments on commit 117fb97

Please sign in to comment.