Skip to content

Commit

Permalink
More tests for column split and vertical federated learning (#8985)
Browse files Browse the repository at this point in the history
Added some more tests for the learner and fit_stump, for both column-wise distributed learning and vertical federated learning.

Also moved the `IsRowSplit` and `IsColumnSplit` methods from the `DMatrix` to the `MetaInfo` since in some places we only have access to the `MetaInfo`. Added a new convenience method `IsVerticalFederatedLearning`.

Some refactoring of the testing fixtures.
  • Loading branch information
rongou authored Mar 28, 2023
1 parent 401ce5c commit ff26cd3
Show file tree
Hide file tree
Showing 18 changed files with 210 additions and 92 deletions.
26 changes: 16 additions & 10 deletions include/xgboost/data.h
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,22 @@ class MetaInfo {
*/
void SynchronizeNumberOfColumns();

/*! \brief Whether the data is split row-wise. */
bool IsRowSplit() const {
return data_split_mode == DataSplitMode::kRow;
}

/*! \brief Whether the data is split column-wise. */
bool IsColumnSplit() const {
return data_split_mode == DataSplitMode::kCol;
}

/*!
* \brief A convenient method to check if we are doing vertical federated learning, which requires
* some special processing.
*/
bool IsVerticalFederated() const;

private:
void SetInfoFromHost(Context const& ctx, StringView key, Json arr);
void SetInfoFromCUDA(Context const& ctx, StringView key, Json arr);
Expand Down Expand Up @@ -542,16 +558,6 @@ class DMatrix {
return Info().num_nonzero_ == Info().num_row_ * Info().num_col_;
}

/*! \brief Whether the data is split row-wise. */
bool IsRowSplit() const {
return Info().data_split_mode == DataSplitMode::kRow;
}

/*! \brief Whether the data is split column-wise. */
bool IsColumnSplit() const {
return Info().data_split_mode == DataSplitMode::kCol;
}

/*!
* \brief Load DMatrix from URI.
* \param uri The URI of input.
Expand Down
4 changes: 2 additions & 2 deletions src/common/hist_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -46,15 +46,15 @@ HistogramCuts SketchOnDMatrix(DMatrix *m, int32_t max_bins, int32_t n_threads, b
if (!use_sorted) {
HostSketchContainer container(max_bins, m->Info().feature_types.ConstHostSpan(), reduced,
HostSketchContainer::UseGroup(info),
m->IsColumnSplit(), n_threads);
m->Info().IsColumnSplit(), n_threads);
for (auto const& page : m->GetBatches<SparsePage>()) {
container.PushRowPage(page, info, hessian);
}
container.MakeCuts(&out);
} else {
SortedSketchContainer container{max_bins, m->Info().feature_types.ConstHostSpan(), reduced,
HostSketchContainer::UseGroup(info),
m->IsColumnSplit(), n_threads};
m->Info().IsColumnSplit(), n_threads};
for (auto const& page : m->GetBatches<SortedCSCPage>()) {
container.PushColPage(page, info, hessian);
}
Expand Down
6 changes: 5 additions & 1 deletion src/data/data.cc
Original file line number Diff line number Diff line change
Expand Up @@ -704,7 +704,7 @@ void MetaInfo::Extend(MetaInfo const& that, bool accumulate_rows, bool check_col
}

void MetaInfo::SynchronizeNumberOfColumns() {
if (collective::IsFederated() && data_split_mode == DataSplitMode::kCol) {
if (IsVerticalFederated()) {
collective::Allreduce<collective::Operation::kSum>(&num_col_, 1);
} else {
collective::Allreduce<collective::Operation::kMax>(&num_col_, 1);
Expand Down Expand Up @@ -770,6 +770,10 @@ void MetaInfo::Validate(std::int32_t device) const {
void MetaInfo::SetInfoFromCUDA(Context const&, StringView, Json) { common::AssertGPUSupport(); }
#endif // !defined(XGBOOST_USE_CUDA)

bool MetaInfo::IsVerticalFederated() const {
return collective::IsFederated() && IsColumnSplit();
}

using DMatrixThreadLocal =
dmlc::ThreadLocalStore<std::map<DMatrix const *, XGBAPIThreadLocalEntry>>;

Expand Down
2 changes: 1 addition & 1 deletion src/data/iterative_dmatrix.cc
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,7 @@ void IterativeDMatrix::InitFromCPU(DataIterHandle iter_handle, float missing,
SyncFeatureType(&h_ft);
p_sketch.reset(new common::HostSketchContainer{
batch_param_.max_bin, h_ft, column_sizes, !proxy->Info().group_ptr_.empty(),
proxy->IsColumnSplit(), ctx_.Threads()});
proxy->Info().IsColumnSplit(), ctx_.Threads()});
}
HostAdapterDispatch(proxy, [&](auto const& batch) {
proxy->Info().num_nonzero_ = batch_nnz[i];
Expand Down
2 changes: 1 addition & 1 deletion src/data/simple_dmatrix.cc
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ DMatrix* SimpleDMatrix::SliceCol(int num_slices, int slice_id) {
}

void SimpleDMatrix::ReindexFeatures() {
if (collective::IsFederated() && info_.data_split_mode == DataSplitMode::kCol) {
if (info_.IsVerticalFederated()) {
std::vector<uint64_t> buffer(collective::GetWorldSize());
buffer[collective::GetRank()] = info_.num_col_;
collective::Allgather(buffer.data(), buffer.size() * sizeof(uint64_t));
Expand Down
6 changes: 3 additions & 3 deletions src/learner.cc
Original file line number Diff line number Diff line change
Expand Up @@ -860,9 +860,9 @@ class LearnerConfiguration : public Learner {

void InitEstimation(MetaInfo const& info, linalg::Tensor<float, 1>* base_score) {
// Special handling for vertical federated learning.
if (collective::IsFederated() && info.data_split_mode == DataSplitMode::kCol) {
if (info.IsVerticalFederated()) {
// We assume labels are only available on worker 0, so the estimation is calculated there
// and added to other workers.
// and broadcast to other workers.
if (collective::GetRank() == 0) {
UsePtr(obj_)->InitEstimation(info, base_score);
collective::Broadcast(base_score->Data()->HostPointer(),
Expand Down Expand Up @@ -1487,7 +1487,7 @@ class LearnerImpl : public LearnerIO {
void GetGradient(HostDeviceVector<bst_float> const& preds, MetaInfo const& info, int iteration,
HostDeviceVector<GradientPair>* out_gpair) {
// Special handling for vertical federated learning.
if (collective::IsFederated() && info.data_split_mode == DataSplitMode::kCol) {
if (info.IsVerticalFederated()) {
// We assume labels are only available on worker 0, so the gradients are calculated there
// and broadcast to other workers.
if (collective::GetRank() == 0) {
Expand Down
2 changes: 1 addition & 1 deletion src/predictor/cpu_predictor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -605,7 +605,7 @@ class CPUPredictor : public Predictor {
protected:
void PredictDMatrix(DMatrix *p_fmat, std::vector<bst_float> *out_preds,
gbm::GBTreeModel const &model, int32_t tree_begin, int32_t tree_end) const {
if (p_fmat->IsColumnSplit()) {
if (p_fmat->Info().IsColumnSplit()) {
ColumnSplitHelper helper(this->ctx_->Threads(), model, tree_begin, tree_end);
helper.PredictDMatrix(p_fmat, out_preds);
return;
Expand Down
3 changes: 1 addition & 2 deletions src/tree/fit_stump.cc
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,7 @@ void FitStump(Context const* ctx, MetaInfo const& info,
}
CHECK(h_sum.CContiguous());

// In vertical federated learning, only worker 0 needs to call this, no need to do an allreduce.
if (!collective::IsFederated() || info.data_split_mode != DataSplitMode::kCol) {
if (info.IsRowSplit()) {
collective::Allreduce<collective::Operation::kSum>(
reinterpret_cast<double*>(h_sum.Values().data()), h_sum.Size() * 2);
}
Expand Down
2 changes: 1 addition & 1 deletion src/tree/hist/evaluate_splits.h
Original file line number Diff line number Diff line change
Expand Up @@ -449,7 +449,7 @@ class HistEvaluator {
param_{param},
column_sampler_{std::move(sampler)},
tree_evaluator_{*param, static_cast<bst_feature_t>(info.num_col_), Context::kCpuId},
is_col_split_{info.data_split_mode == DataSplitMode::kCol} {
is_col_split_{info.IsColumnSplit()} {
interaction_constraints_.Configure(*param, info.num_col_);
column_sampler_->Init(ctx, info.num_col_, info.feature_weights.HostVector(),
param_->colsample_bynode, param_->colsample_bylevel,
Expand Down
7 changes: 4 additions & 3 deletions src/tree/updater_approx.cc
Original file line number Diff line number Diff line change
Expand Up @@ -72,12 +72,13 @@ class GloablApproxBuilder {
} else {
CHECK_EQ(n_total_bins, page.cut.TotalBins());
}
partitioner_.emplace_back(this->ctx_, page.Size(), page.base_rowid, p_fmat->IsColumnSplit());
partitioner_.emplace_back(this->ctx_, page.Size(), page.base_rowid,
p_fmat->Info().IsColumnSplit());
n_batches_++;
}

histogram_builder_.Reset(n_total_bins, BatchSpec(*param_, hess), ctx_->Threads(), n_batches_,
collective::IsDistributed(), p_fmat->IsColumnSplit());
collective::IsDistributed(), p_fmat->Info().IsColumnSplit());
monitor_->Stop(__func__);
}

Expand All @@ -91,7 +92,7 @@ class GloablApproxBuilder {
for (auto const &g : gpair) {
root_sum.Add(g);
}
if (p_fmat->IsRowSplit()) {
if (p_fmat->Info().IsRowSplit()) {
collective::Allreduce<collective::Operation::kSum>(reinterpret_cast<double *>(&root_sum), 2);
}
std::vector<CPUExpandEntry> nodes{best};
Expand Down
9 changes: 5 additions & 4 deletions src/tree/updater_quantile_hist.cc
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ class MultiTargetHistBuilder {
} else {
CHECK_EQ(n_total_bins, page.cut.TotalBins());
}
partitioner_.emplace_back(ctx_, page.Size(), page.base_rowid, p_fmat->IsColumnSplit());
partitioner_.emplace_back(ctx_, page.Size(), page.base_rowid, p_fmat->Info().IsColumnSplit());
page_id++;
}

Expand All @@ -167,7 +167,7 @@ class MultiTargetHistBuilder {
for (std::size_t i = 0; i < n_targets; ++i) {
histogram_builder_.emplace_back();
histogram_builder_.back().Reset(n_total_bins, HistBatch(param_), ctx_->Threads(), page_id,
collective::IsDistributed(), p_fmat->IsColumnSplit());
collective::IsDistributed(), p_fmat->Info().IsColumnSplit());
}

evaluator_ = std::make_unique<HistMultiEvaluator>(ctx_, p_fmat->Info(), param_, col_sampler_);
Expand Down Expand Up @@ -388,11 +388,12 @@ class HistBuilder {
} else {
CHECK_EQ(n_total_bins, page.cut.TotalBins());
}
partitioner_.emplace_back(this->ctx_, page.Size(), page.base_rowid, fmat->IsColumnSplit());
partitioner_.emplace_back(this->ctx_, page.Size(), page.base_rowid,
fmat->Info().IsColumnSplit());
++page_id;
}
histogram_builder_->Reset(n_total_bins, HistBatch(param_), ctx_->Threads(), page_id,
collective::IsDistributed(), fmat->IsColumnSplit());
collective::IsDistributed(), fmat->Info().IsColumnSplit());
evaluator_ = std::make_unique<HistEvaluator<CPUExpandEntry>>(ctx_, this->param_, fmat->Info(),
col_sampler_);
p_last_tree_ = p_tree;
Expand Down
12 changes: 3 additions & 9 deletions tests/cpp/helpers.cc
Original file line number Diff line number Diff line change
Expand Up @@ -191,15 +191,9 @@ double GetMultiMetricEval(xgboost::Metric* metric,
}

namespace xgboost {
bool IsNear(std::vector<xgboost::bst_float>::const_iterator _beg1,
std::vector<xgboost::bst_float>::const_iterator _end1,
std::vector<xgboost::bst_float>::const_iterator _beg2) {
for (auto iter1 = _beg1, iter2 = _beg2; iter1 != _end1; ++iter1, ++iter2) {
if (std::abs(*iter1 - *iter2) > xgboost::kRtEps){
return false;
}
}
return true;

float GetBaseScore(Json const &config) {
return std::stof(get<String const>(config["learner"]["learner_model_param"]["base_score"]));
}

SimpleLCG::StateType SimpleLCG::operator()() {
Expand Down
5 changes: 2 additions & 3 deletions tests/cpp/helpers.h
Original file line number Diff line number Diff line change
Expand Up @@ -101,9 +101,8 @@ double GetMultiMetricEval(xgboost::Metric* metric,
std::vector<xgboost::bst_uint> groups = {});

namespace xgboost {
bool IsNear(std::vector<xgboost::bst_float>::const_iterator _beg1,
std::vector<xgboost::bst_float>::const_iterator _end1,
std::vector<xgboost::bst_float>::const_iterator _beg2);

float GetBaseScore(Json const &config);

/*!
* \brief Linear congruential generator.
Expand Down
33 changes: 24 additions & 9 deletions tests/cpp/plugin/helpers.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,18 +52,33 @@ class BaseFederatedTest : public ::testing::Test {
server_thread_->join();
}

void InitCommunicator(int rank) {
Json config{JsonObject()};
config["xgboost_communicator"] = String("federated");
config["federated_server_address"] = String(server_address_);
config["federated_world_size"] = kWorldSize;
config["federated_rank"] = rank;
xgboost::collective::Init(config);
}

static int const kWorldSize{3};
std::string server_address_;
std::unique_ptr<std::thread> server_thread_;
std::unique_ptr<grpc::Server> server_;
};

template <typename Function, typename... Args>
void RunWithFederatedCommunicator(int32_t world_size, std::string const& server_address,
Function&& function, Args&&... args) {
std::vector<std::thread> threads;
for (auto rank = 0; rank < world_size; rank++) {
threads.emplace_back([&, rank]() {
Json config{JsonObject()};
config["xgboost_communicator"] = String("federated");
config["federated_server_address"] = String(server_address);
config["federated_world_size"] = world_size;
config["federated_rank"] = rank;
xgboost::collective::Init(config);

std::forward<Function>(function)(std::forward<Args>(args)...);

xgboost::collective::Finalize();
});
}
for (auto& thread : threads) {
thread.join();
}
}

} // namespace xgboost
62 changes: 25 additions & 37 deletions tests/cpp/plugin/test_federated_data.cc
Original file line number Diff line number Diff line change
@@ -1,12 +1,9 @@
/*!
* Copyright 2023 XGBoost contributors
*/
#include <dmlc/parameter.h>
#include <gtest/gtest.h>
#include <xgboost/data.h>

#include <fstream>
#include <iostream>
#include <thread>

#include "../../../plugin/federated/federated_server.h"
Expand All @@ -17,49 +14,40 @@

namespace xgboost {

class FederatedDataTest : public BaseFederatedTest {
public:
void VerifyLoadUri(int rank) {
InitCommunicator(rank);
class FederatedDataTest : public BaseFederatedTest {};

size_t constexpr kRows{16};
size_t const kCols = 8 + rank;
void VerifyLoadUri() {
auto const rank = collective::GetRank();

dmlc::TemporaryDirectory tmpdir;
std::string path = tmpdir.path + "/small" + std::to_string(rank) + ".csv";
CreateTestCSV(path, kRows, kCols);
size_t constexpr kRows{16};
size_t const kCols = 8 + rank;

std::unique_ptr<DMatrix> dmat;
std::string uri = path + "?format=csv";
dmat.reset(DMatrix::Load(uri, false, DataSplitMode::kCol));
dmlc::TemporaryDirectory tmpdir;
std::string path = tmpdir.path + "/small" + std::to_string(rank) + ".csv";
CreateTestCSV(path, kRows, kCols);

ASSERT_EQ(dmat->Info().num_col_, 8 * kWorldSize + 3);
ASSERT_EQ(dmat->Info().num_row_, kRows);
std::unique_ptr<DMatrix> dmat;
std::string uri = path + "?format=csv";
dmat.reset(DMatrix::Load(uri, false, DataSplitMode::kCol));

for (auto const& page : dmat->GetBatches<SparsePage>()) {
auto entries = page.GetView().data;
auto index = 0;
int offsets[] = {0, 8, 17};
int offset = offsets[rank];
for (auto row = 0; row < kRows; row++) {
for (auto col = 0; col < kCols; col++) {
EXPECT_EQ(entries[index].index, col + offset);
index++;
}
ASSERT_EQ(dmat->Info().num_col_, 8 * collective::GetWorldSize() + 3);
ASSERT_EQ(dmat->Info().num_row_, kRows);

for (auto const& page : dmat->GetBatches<SparsePage>()) {
auto entries = page.GetView().data;
auto index = 0;
int offsets[] = {0, 8, 17};
int offset = offsets[rank];
for (auto row = 0; row < kRows; row++) {
for (auto col = 0; col < kCols; col++) {
EXPECT_EQ(entries[index].index, col + offset);
index++;
}
}

xgboost::collective::Finalize();
}
};
}

TEST_F(FederatedDataTest, LoadUri) {
std::vector<std::thread> threads;
for (auto rank = 0; rank < kWorldSize; rank++) {
threads.emplace_back(&FederatedDataTest_LoadUri_Test::VerifyLoadUri, this, rank);
}
for (auto& thread : threads) {
thread.join();
}
RunWithFederatedCommunicator(kWorldSize, server_address_, &VerifyLoadUri);
}
} // namespace xgboost
Loading

0 comments on commit ff26cd3

Please sign in to comment.