Skip to content

Commit

Permalink
Merge branch 'develop' into async_refine
Browse files Browse the repository at this point in the history
  • Loading branch information
zhaocaibei123 authored Apr 1, 2022
2 parents 80f95df + 0b0c276 commit 03b6c3b
Show file tree
Hide file tree
Showing 89 changed files with 1,568 additions and 1,526 deletions.
4 changes: 2 additions & 2 deletions paddle/fluid/distributed/ps/service/brpc_ps_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1460,7 +1460,7 @@ void sparse_local_merge(ValueAccessor *accessor, float *merge_data,
merge_data_shell[i] = merge_data + i;
another_data_shell[i] = another_data + i;
}
accessor->merge(merge_data_shell, another_data_shell, 1);
accessor->Merge(merge_data_shell, another_data_shell, 1);
}

int BrpcPsClient::PushSparseAsyncShardMerge(
Expand Down Expand Up @@ -1699,7 +1699,7 @@ void BrpcPsClient::PushDenseTaskConsume() {
async_task]() -> int {
auto &tmp_task_vec = *(async_task->data());
const float *merge_data = tmp_task_vec.data();
accessor->merge(&total_send_data, &merge_data,
accessor->Merge(&total_send_data, &merge_data,
total_send_data_size);
#pragma optimize("", off)
auto *debug_closure = closure;
Expand Down
6 changes: 4 additions & 2 deletions paddle/fluid/distributed/ps/service/brpc_ps_server.cc
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,9 @@ int32_t BrpcPsService::PullDense(Table *table, const PsRequestMessage &request,
}

auto res_data = butil::get_object<std::vector<float>>();
res_data->resize(num * table->ValueAccesor()->select_size() / sizeof(float));
res_data->resize(num * table->value_accesor()->GetTableInfo(SELECT_SIZE) /
sizeof(float));

TableContext table_context;
table_context.value_type = Dense;
table_context.pull_context.values = res_data->data();
Expand Down Expand Up @@ -382,7 +384,7 @@ int32_t BrpcPsService::PullSparse(Table *table, const PsRequestMessage &request,

CostTimer timer("pserver_server_pull_sparse");
uint32_t num = *(uint32_t *)(request.params(0).c_str());
auto dim = table->ValueAccesor()->select_dim();
auto dim = table->value_accesor()->GetTableInfo(SELECT_DIM);

thread_local std::string req_buffer;
req_buffer.reserve(req_buffer_size);
Expand Down
4 changes: 2 additions & 2 deletions paddle/fluid/distributed/ps/service/ps_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,8 @@ int32_t PSClient::Configure(
auto *accessor = CREATE_PSCORE_CLASS(
ValueAccessor,
work_param.downpour_table_param(i).accessor().accessor_class());
accessor->configure(work_param.downpour_table_param(i).accessor());
accessor->initialize();
accessor->Configure(work_param.downpour_table_param(i).accessor());
accessor->Initialize();
_table_accessors[work_param.downpour_table_param(i).table_id()].reset(
accessor);
}
Expand Down
10 changes: 7 additions & 3 deletions paddle/fluid/distributed/ps/service/ps_local_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,9 @@ ::std::future<int32_t> PsLocalClient::PullDense(Region* regions,
auto* accessor = GetTableAccessor(table_id);
auto* table_ptr = GetTable(table_id);

uint32_t num_per_shard = DenseDimPerShard(accessor->fea_dim(), 1);
uint32_t num_per_shard =
dense_dim_per_shard(accessor->GetTableInfo(FEA_DIM), 1);

std::vector<float> region_buffer;
region_buffer.resize(num_per_shard);
table_ptr->PullDense(region_buffer.data(), region_buffer.size());
Expand Down Expand Up @@ -144,7 +146,8 @@ ::std::future<int32_t> PsLocalClient::PushDenseParam(const Region* regions,
auto* table_ptr = GetTable(table_id);

std::vector<float> region_buffer;
region_buffer.resize(DenseDimPerShard(accessor->fea_dim(), 1), 0);
region_buffer.resize(DenseDimPerShard(accessor->GetTableInfo(FEA_DIM), 1), 0);

for (size_t i = 0, offset = 0; i < region_num; ++i) {
uint32_t data_num = regions[i].size / sizeof(float);
memcpy(region_buffer.data() + offset, regions[i].data, regions[i].size);
Expand Down Expand Up @@ -177,7 +180,8 @@ ::std::future<int32_t> PsLocalClient::PushDense(const Region* regions,
auto* table_ptr = GetTable(table_id);

std::vector<float> region_buffer;
region_buffer.resize(DenseDimPerShard(accessor->fea_dim(), 1));
region_buffer.resize(DenseDimPerShard(accessor->GetTableInfo(FEA_DIM), 1));

size_t data_size = region_buffer.size();
for (size_t i = 0, offset = 0; i < region_num; ++i) {
uint32_t data_num = regions[i].size / sizeof(float);
Expand Down
69 changes: 23 additions & 46 deletions paddle/fluid/distributed/ps/table/accessor.h
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ class ValueAccessor {
ValueAccessor() {}
virtual ~ValueAccessor() {}

virtual int configure(const TableAccessorParameter& parameter) {
virtual int Configure(const TableAccessorParameter& parameter) {
_config = parameter;
// data_convert结构体初始化
if (_config.table_accessor_save_param_size() != 0) {
Expand All @@ -88,38 +88,15 @@ class ValueAccessor {
}
return 0;
}
virtual int initialize() = 0;
virtual int Initialize() = 0;

virtual void SetTableInfo(AccessorInfo& info) = 0;
virtual size_t GetTableInfo(InfoKey key) = 0;

// value维度
virtual size_t dim() = 0;
// value各个维度的size
virtual size_t dim_size(size_t dim) = 0;
// value各维度相加总size
virtual size_t size() = 0;

// value中mf动态长度部分总size大小, sparse下生效
virtual size_t mf_size() { return 0; }
virtual bool need_extend_mf(float* value) { return false; }
virtual bool has_mf(size_t size) { return false; }
// pull value维度
virtual size_t select_dim() = 0;
// pull value各个维度的size
virtual size_t select_dim_size(size_t dim) = 0;
// pull value各维度相加总size
virtual size_t select_size() = 0;
// push value维度
virtual size_t update_dim() = 0;
// push value各个维度的size
virtual size_t update_dim_size(size_t dim) = 0;
// push value各维度相加总size
virtual size_t update_size() = 0;
// fea total for dense
virtual size_t fea_dim() { return _config.fea_dim(); }
virtual bool NeedExtendMF(float* value) { return false; }
virtual bool HasMF(size_t size) { return false; }
// converter for save
virtual std::string get_converter(int param) {
virtual std::string GetConverter(int param) {
auto itr = _data_coverter_map.find(param);
if (itr == _data_coverter_map.end()) {
return "";
Expand All @@ -128,7 +105,7 @@ class ValueAccessor {
}
}
// deconverter for load
virtual std::string get_deconverter(int param) {
virtual std::string GetDeconverter(int param) {
auto itr = _data_coverter_map.find(param);
if (itr == _data_coverter_map.end()) {
return "";
Expand All @@ -137,47 +114,47 @@ class ValueAccessor {
}
}
// 判断该value是否进行shrink
virtual bool shrink(float* value) = 0;
virtual bool Shrink(float* value) = 0;

// 判断该value是否在save阶段dump,
// param作为参数用于标识save阶段,如downpour的xbox与batch_model
virtual bool save(float* value, int param) = 0;
virtual bool Save(float* value, int param) = 0;
// update delta_score and unseen_days after save
virtual void update_stat_after_save(float* value, int param) {}
virtual void UpdateStatAfterSave(float* value, int param) {}

// keys不存在时,为values生成随机值
virtual int32_t create(float** value, size_t num) = 0;
virtual bool create_value(int type, const float* value) { return true; }
virtual int32_t Create(float** value, size_t num) = 0;
virtual bool CreateValue(int type, const float* value) { return true; }
// 从values中选取到select_values中
virtual int32_t select(float** select_values, const float** values,
virtual int32_t Select(float** select_values, const float** values,
size_t num) = 0;
// 将update_values聚合到一起
virtual int32_t merge(float** update_values,
virtual int32_t Merge(float** update_values,
const float** other_update_values, size_t num) = 0;
// 将update_values聚合到一起,通过it.next判定是否进入下一个key
// virtual int32_t merge(float** update_values, iterator it);
// virtual int32_t Merge(float** update_values, iterator it);
// 将update_values更新应用到values中
virtual int32_t update(float** values, const float** update_values,
virtual int32_t Update(float** values, const float** update_values,
size_t num) = 0;

// used to save model, will filter feature
virtual std::string parse_to_string(const float* value, int param) = 0;
virtual std::string ParseToString(const float* value, int param) = 0;
// parse value from string, used to load model
virtual int32_t parse_from_string(const std::string& data, float* value) = 0;
virtual int32_t ParseFromString(const std::string& data, float* value) = 0;

virtual FsDataConverter converter(int param) {
virtual FsDataConverter Converter(int param) {
FsDataConverter data_convert;
data_convert.converter = this->get_converter(param);
data_convert.deconverter = this->get_deconverter(param);
data_convert.converter = this->GetConverter(param);
data_convert.deconverter = this->GetDeconverter(param);
return data_convert;
}

virtual int set_weight(float** values, const float** update_values,
size_t num) {
virtual int SetWeight(float** values, const float** update_values,
size_t num) {
return 0;
}

virtual float get_field(float* value, const std::string& name) { return 0.0; }
virtual float GetField(float* value, const std::string& name) { return 0.0; }
#define DEFINE_GET_INDEX(class, field) \
virtual int get_##field##_index() override { return class ::field##_index(); }

Expand Down
8 changes: 4 additions & 4 deletions paddle/fluid/distributed/ps/table/common_dense_table.cc
Original file line number Diff line number Diff line change
Expand Up @@ -232,9 +232,9 @@ int32_t CommonDenseTable::Load(const std::string& path,
int load_param = atoi(param.c_str());
FsChannelConfig channel_config;

channel_config.converter = _value_accesor->converter(load_param).converter;
channel_config.converter = _value_accesor->Converter(load_param).converter;
channel_config.deconverter =
_value_accesor->converter(load_param).deconverter;
_value_accesor->Converter(load_param).deconverter;
bool is_read_failed = false;
int err_no = 0;
int retry_num = 0;
Expand Down Expand Up @@ -329,9 +329,9 @@ int32_t CommonDenseTable::Save(const std::string& path,
"%s/part-%03d", TableDir(path).c_str(), _shard_idx);
}
_afs_client.remove(channel_config.path);
channel_config.converter = _value_accesor->converter(save_param).converter;
channel_config.converter = _value_accesor->Converter(save_param).converter;
channel_config.deconverter =
_value_accesor->converter(save_param).deconverter;
_value_accesor->Converter(save_param).deconverter;

bool is_write_failed = false;
std::vector<std::vector<std::string>> result_buffer_param(
Expand Down
Loading

0 comments on commit 03b6c3b

Please sign in to comment.