Skip to content

Commit

Permalink
Merge pull request #43 from qingshui/paddlebox
Browse files Browse the repository at this point in the history
remove FLAGS_padbox_auc_runner_mode gflags, remove same merge pv condition
  • Loading branch information
qingshui authored Jun 17, 2022
2 parents a67cdaa + 42e69ee commit 1d813ff
Show file tree
Hide file tree
Showing 4 changed files with 34 additions and 37 deletions.
26 changes: 6 additions & 20 deletions paddle/fluid/framework/data_feed.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1661,10 +1661,8 @@ bool MultiSlotFileInstantDataFeed::ParseOneMiniBatch() {

bool PaddleBoxDataFeed::Start() {
#ifdef _LINUX
// join: 1, update: 0, join_test: 3, update_test: 2
int phase = GetCurrentPhase();
this->CheckSetFileList();
if (enable_pv_merge_ && (phase == 1 || phase == 3)) {
if (enable_pv_merge_) {
// join phase : input_pv_channel to output_pv_channel
if (output_pv_channel_->Size() == 0 && input_pv_channel_->Size() != 0) {
std::vector<PvInstance> data;
Expand Down Expand Up @@ -1693,10 +1691,8 @@ bool PaddleBoxDataFeed::Start() {

int PaddleBoxDataFeed::Next() {
#ifdef _LINUX
// join: 1, update: 0, join_test: 3, update_test: 2
int phase = GetCurrentPhase();
this->CheckStart();
if (enable_pv_merge_ && (phase == 1 || phase == 3)) {
if (enable_pv_merge_) {
// join phase : output_pv_channel to consume_pv_channel
CHECK(output_pv_channel_ != nullptr);
CHECK(consume_pv_channel_ != nullptr);
Expand Down Expand Up @@ -1827,9 +1823,7 @@ void PaddleBoxDataFeed::GetRankOffset(const std::vector<PvInstance>& pv_vec,
void PaddleBoxDataFeed::AssignFeedVar(const Scope& scope) {
MultiSlotInMemoryDataFeed::AssignFeedVar(scope);
// set rank offset memory
// join: 1, update: 0, join_test: 3, update_test: 2
int phase = GetCurrentPhase();
if (enable_pv_merge_ && (phase == 1 || phase == 3)) {
if (enable_pv_merge_) {
rank_offset_ = scope.FindVar(rank_offset_name_)->GetMutable<LoDTensor>();
}
}
Expand Down Expand Up @@ -2341,14 +2335,12 @@ bool SlotPaddleBoxDataFeed::Start() {
return true;
}
int SlotPaddleBoxDataFeed::Next() {
// join: 1, update: 0, join_test: 3, update_test: 2
int phase = GetCurrentPhase();
this->CheckStart();
if (offset_index_ >= static_cast<int>(batch_offsets_.size())) {
return 0;
}
auto& batch = batch_offsets_[offset_index_++];
if (enable_pv_merge_ && (phase == 1 || phase == 3)) {
if (enable_pv_merge_) {
// join phase : output_pv_channel to consume_pv_channel
this->batch_size_ = batch.second;
if (this->batch_size_ != 0) {
Expand All @@ -2364,18 +2356,14 @@ int SlotPaddleBoxDataFeed::Next() {
batch_timer_.Resume();
PutToFeedSlotVec(&records_[batch.first], this->batch_size_);
// update set join q value
if ((phase == 0 || phase == 2) && FLAGS_padbox_slotrecord_extend_dim > 0) {
if (FLAGS_padbox_slotrecord_extend_dim > 0) {
// pcoc
pack_->pack_qvalue();
}
batch_timer_.Pause();
return this->batch_size_;
}
}
bool SlotPaddleBoxDataFeed::EnablePvMerge(void) {
return (enable_pv_merge_ &&
(GetCurrentPhase() == 1 || GetCurrentPhase() == 3));
}
int SlotPaddleBoxDataFeed::GetPackInstance(SlotRecord** ins) {
if (offset_index_ >= static_cast<int>(batch_offsets_.size())) {
return 0;
Expand All @@ -2399,9 +2387,7 @@ void SlotPaddleBoxDataFeed::AssignFeedVar(const Scope& scope) {
scope.FindVar(used_slots_info_[i].slot)->GetMutable<LoDTensor>();
}
// set rank offset memory
// join: 1, update: 0, join_test: 3, update_test: 2
int phase = GetCurrentPhase();
if (enable_pv_merge_ && (phase == 1 || phase == 3)) {
if (enable_pv_merge_) {
rank_offset_ = scope.FindVar(rank_offset_name_)->GetMutable<LoDTensor>();
}
}
Expand Down
17 changes: 9 additions & 8 deletions paddle/fluid/framework/data_feed.h
Original file line number Diff line number Diff line change
Expand Up @@ -847,12 +847,8 @@ struct SlotRecordObject {
};
using SlotRecord = SlotRecordObject*;

inline SlotRecord make_slotrecord() {
static const size_t slot_record_byte_size =
sizeof(SlotRecordObject) +
sizeof(float) * FLAGS_padbox_slotrecord_extend_dim +
sizeof(AucRunnerInfo) * static_cast<int>(FLAGS_padbox_auc_runner_mode);
void* p = malloc(slot_record_byte_size);
inline SlotRecord make_slotrecord(const size_t& byte_size) {
void* p = malloc(byte_size);
new (p) SlotRecordObject;
return reinterpret_cast<SlotRecordObject*>(p);
}
Expand Down Expand Up @@ -963,6 +959,8 @@ class SlotObjPool {
: inited_(true),
max_capacity_(FLAGS_padbox_record_pool_max_size),
alloc_(free_slotrecord) {
slot_record_byte_size_ = sizeof(SlotRecordObject) +
sizeof(float) * FLAGS_padbox_slotrecord_extend_dim;
for (int i = 0; i < FLAGS_padbox_slotpool_thread_num; ++i) {
threads_.push_back(std::thread([this]() { run(); }));
}
Expand All @@ -976,6 +974,9 @@ class SlotObjPool {
t.join();
}
}
void set_slotrecord_size(size_t byte_size) {
slot_record_byte_size_ = byte_size;
}
void disable_pool(bool disable) { disable_pool_ = disable; }
void set_max_capacity(size_t max_capacity) { max_capacity_ = max_capacity; }
void get(std::vector<SlotRecord>* output, size_t n) {
Expand All @@ -993,7 +994,7 @@ class SlotObjPool {
return;
}
for (size_t i = size; i < n; ++i) {
output[i] = make_slotrecord();
output[i] = make_slotrecord(slot_record_byte_size_);
}
}
void put(std::vector<SlotRecord>* input) {
Expand Down Expand Up @@ -1069,6 +1070,7 @@ class SlotObjPool {
bool disable_pool_;
size_t count_; // NOLINT
std::condition_variable cond_;
size_t slot_record_byte_size_ = 0;
};

inline SlotObjPool& SlotRecordPool() {
Expand Down Expand Up @@ -1617,7 +1619,6 @@ class SlotPaddleBoxDataFeed : public DataFeed {
// expand values
void ExpandSlotRecord(SlotRecord* ins);
// pack
bool EnablePvMerge(void);
int GetPackInstance(SlotRecord** ins);
int GetPackPvInstance(SlotPvInstance** pv_ins);
void SetSlotRecordPool(SlotObjPool* pool) { slot_pool_ = pool; }
Expand Down
12 changes: 8 additions & 4 deletions paddle/fluid/framework/data_set.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2327,8 +2327,10 @@ void PadBoxSlotDataset::PrepareTrain(void) {
->GetPvBatchSize();
compute_thread_batch_nccl(thread_num_, GetPvDataSize(), batchsize, &offset);
for (int i = 0; i < thread_num_; ++i) {
reinterpret_cast<SlotPaddleBoxDataFeed*>(readers_[i].get())
->SetPvInstance(&input_pv_ins_[0]);
SlotPaddleBoxDataFeed* feed =
reinterpret_cast<SlotPaddleBoxDataFeed*>(readers_[i].get());
feed->SetEnablePvMerge(enable_pv_merge_);
feed->SetPvInstance(&input_pv_ins_[0]);
}
for (size_t i = 0; i < offset.size(); ++i) {
reinterpret_cast<SlotPaddleBoxDataFeed*>(readers_[i % thread_num_].get())
Expand All @@ -2343,8 +2345,10 @@ void PadBoxSlotDataset::PrepareTrain(void) {
compute_thread_batch_nccl(thread_num_, GetMemoryDataSize(), batchsize,
&offset);
for (int i = 0; i < thread_num_; ++i) {
reinterpret_cast<SlotPaddleBoxDataFeed*>(readers_[i].get())
->SetSlotRecord(&input_records_[0]);
SlotPaddleBoxDataFeed* feed =
reinterpret_cast<SlotPaddleBoxDataFeed*>(readers_[i].get());
feed->SetEnablePvMerge(false);
feed->SetSlotRecord(&input_records_[0]);
}
for (size_t i = 0; i < offset.size(); ++i) {
reinterpret_cast<SlotPaddleBoxDataFeed*>(readers_[i % thread_num_].get())
Expand Down
16 changes: 11 additions & 5 deletions paddle/fluid/framework/fleet/box_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -764,10 +764,15 @@ class BoxWrapper {
void InitializeAucRunner(std::vector<std::vector<std::string>> slot_eval,
int thread_num, int pool_size,
std::vector<std::string> slot_list) {
PADDLE_ENFORCE_EQ(FLAGS_padbox_auc_runner_mode, true,
platform::errors::InvalidArgument(
"you should export FLAGS_padbox_auc_runner_mode=true "
"in auc runner mode."));
// PADDLE_ENFORCE_EQ(FLAGS_padbox_auc_runner_mode, true,
// platform::errors::InvalidArgument(
// "you should export
// FLAGS_padbox_auc_runner_mode=true "
// "in auc runner mode."));
size_t object_bytes = sizeof(SlotRecordObject) +
sizeof(float) * FLAGS_padbox_slotrecord_extend_dim +
sizeof(AucRunnerInfo);
SlotRecordPool().set_slotrecord_size(object_bytes);

mode_ = 1;
phase_num_ = static_cast<int>(slot_eval.size());
Expand All @@ -788,7 +793,8 @@ class BoxWrapper {

VLOG(0) << "AucRunner configuration: thread number[" << thread_num
<< "], pool size[" << pool_size << "], runner_group[" << phase_num_
<< "], eval size:[" << slot_eval_set_.size() << "]";
<< "], eval size:[" << slot_eval_set_.size() << "]"
<< ", object size:[" << object_bytes << "]";
// VLOG(0) << "Slots that need to be evaluated:";
// for (auto e : slot_index_to_replace_) {
// VLOG(0) << e << ": " << slot_list[e];
Expand Down

0 comments on commit 1d813ff

Please sign in to comment.