Skip to content

Commit

Permalink
issue=baidu#936 optimize sdk performance
Browse files Browse the repository at this point in the history
  • Loading branch information
00k committed Jul 12, 2016
1 parent 401afb3 commit c9b7aae
Show file tree
Hide file tree
Showing 4 changed files with 130 additions and 54 deletions.
87 changes: 70 additions & 17 deletions src/sdk/sdk_task.cc
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@

#include <glog/logging.h>

#include "utils/timer.h"

namespace tera {

int64_t SdkTask::GetRef() {
Expand Down Expand Up @@ -34,33 +36,52 @@ void SdkTask::ExcludeOtherRef() {
CHECK_EQ(_ref, 1);
}

bool SdkTaskHashMap::PutTask(SdkTask* task) {
int64_t GetSdkTaskId(SdkTask* task) {
return task->GetId();
}

uint64_t GetSdkTaskDueTime(SdkTask* task) {
return task->DueTime();
}

SdkTimeoutManager::SdkTimeoutManager(ThreadPool* thread_pool)
: _thread_pool(thread_pool) {
CheckTimeout();
}

bool SdkTimeoutManager::PutTask(SdkTask* task, int64_t timeout,
SdkTask::TimeoutFunc timeout_func) {
int64_t task_id = task->GetId();
CHECK_GE(task_id, 0);
if (timeout > 0) {
task->SetDueTime(get_millis() + timeout);
task->SetTimeoutFunc(timeout_func);
}

uint32_t shard_id = Shard(task_id);
TaskHashMap& map = _map_shard[shard_id];
TaskMap& map = _map_shard[shard_id];
Mutex& mutex = _mutex_shard[shard_id];

MutexLock l(&mutex);
std::pair<TaskHashMap::iterator, bool> insert_ret;
insert_ret = map.insert(std::pair<int64_t, SdkTask*>(task_id, task));
std::pair<TaskMap::iterator, bool> insert_ret;
insert_ret = map.insert(task);
bool insert_success = insert_ret.second;
if (insert_success) {
task->IncRef();
}
return insert_success;
}

SdkTask* SdkTaskHashMap::GetTask(int64_t task_id) {
SdkTask* SdkTimeoutManager::GetTask(int64_t task_id) {
uint32_t shard_id = Shard(task_id);
TaskHashMap& map = _map_shard[shard_id];
TaskMap& map = _map_shard[shard_id];
Mutex& mutex = _mutex_shard[shard_id];

MutexLock l(&mutex);
TaskHashMap::iterator it = map.find(task_id);
if (it != map.end()) {
SdkTask* task = it->second;
TaskIdIndex& id_index = map.get<INDEX_BY_ID>();
TaskIdIndex::iterator it = id_index.find(task_id);
if (it != id_index.end()) {
SdkTask* task = *it;
CHECK_EQ(task->GetId(), task_id);
task->IncRef();
return task;
Expand All @@ -69,25 +90,57 @@ SdkTask* SdkTaskHashMap::GetTask(int64_t task_id) {
}
}

SdkTask* SdkTaskHashMap::PopTask(int64_t task_id) {
SdkTask* SdkTimeoutManager::PopTask(int64_t task_id) {
uint32_t shard_id = Shard(task_id);
TaskHashMap& map = _map_shard[shard_id];
TaskMap& map = _map_shard[shard_id];
Mutex& mutex = _mutex_shard[shard_id];

MutexLock l(&mutex);
TaskHashMap::iterator it = map.find(task_id);
if (it != map.end()) {
SdkTask* task = it->second;
TaskIdIndex& id_index = map.get<INDEX_BY_ID>();
TaskIdIndex::iterator it = id_index.find(task_id);
if (it != id_index.end()) {
SdkTask* task = *it;
CHECK_EQ(task->GetId(), task_id);
map.erase(it);
id_index.erase(it);
return task;
} else {
return NULL;
}
}

uint32_t SdkTaskHashMap::Shard(int64_t task_id) {
return (uint64_t)task_id >> (64 - kShardBits);
void SdkTimeoutManager::CheckTimeout() {
int64_t now_ms = get_millis();
for (uint32_t shard_id = 0; shard_id < kShardNum; shard_id++) {
TaskMap& map = _map_shard[shard_id];
Mutex& mutex = _mutex_shard[shard_id];

MutexLock l(&mutex);
while (!map.empty()) {
TaskDueTimeIndex& due_time_index = map.get<INDEX_BY_DUE_TIME>();
TaskDueTimeIndex::iterator it = due_time_index.begin();
SdkTask* task = *it;
if (task->DueTime() > (uint64_t)now_ms) {
break;
}
due_time_index.erase(it);
mutex.Unlock();
_thread_pool->AddTask(boost::bind(&SdkTimeoutManager::RunTimeoutFunc, this, task));
mutex.Lock();
}
}
if (get_millis() == now_ms) {
_thread_pool->DelayTask(1, boost::bind(&SdkTimeoutManager::CheckTimeout, this));
} else {
_thread_pool->AddTask(boost::bind(&SdkTimeoutManager::CheckTimeout, this));
}
}

void SdkTimeoutManager::RunTimeoutFunc(SdkTask* sdk_task) {
sdk_task->GetTimeoutFunc()(sdk_task);
}

uint32_t SdkTimeoutManager::Shard(int64_t task_id) {
return (uint64_t)task_id & ((1ull << kShardBits) - 1);
}

} // namespace tera
62 changes: 50 additions & 12 deletions src/sdk/sdk_task.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,25 @@
#ifndef TERA_SDK_SDK_TASK_H_
#define TERA_SDK_SDK_TASK_H_

#include <boost/unordered_map.hpp>

#include <boost/function.hpp>
#include <boost/multi_index_container.hpp>
#include <boost/multi_index/global_fun.hpp>
#include <boost/multi_index/hashed_index.hpp>
#include <boost/multi_index/indexed_by.hpp>
#include <boost/multi_index/ordered_index.hpp>

#include "common/base/stdint.h"
#include "common/mutex.h"
#include "common/thread_pool.h"

#include "proto/table_meta.pb.h"
#include "sdk/tera.h"

namespace tera {

class SdkTask {

public:
typedef boost::function<void (SdkTask*)> TimeoutFunc;
enum TYPE {
READ,
MUTATION,
Expand All @@ -33,10 +40,11 @@ class SdkTask {
void SetId(int64_t id) { _id = id; }
int64_t GetId() { return _id; }

void SetTimeout(int64_t timeout) { _timeout = timeout; }
int64_t Timeout() { return _timeout; }
void SetDueTime(uint64_t due_time) { _due_time_ms = due_time; }
uint64_t DueTime() { return _due_time_ms; }

bool operator<(const SdkTask& rhs) const { return _timeout < rhs._timeout; }
void SetTimeoutFunc(TimeoutFunc timeout_func) { _timeout_func = timeout_func; }
TimeoutFunc GetTimeoutFunc() { return _timeout_func; }

int64_t GetRef();
void IncRef();
Expand All @@ -49,7 +57,7 @@ class SdkTask {
_internal_err(kTabletNodeOk),
_meta_timestamp(0),
_id(-1),
_timeout(0),
_due_time_ms(UINT64_MAX),
_cond(&_mutex),
_ref(1) {}
virtual ~SdkTask() {}
Expand All @@ -59,28 +67,58 @@ class SdkTask {
StatusCode _internal_err;
int64_t _meta_timestamp;
int64_t _id;
int64_t _timeout;
uint64_t _due_time_ms; // timestamp of timeout
TimeoutFunc _timeout_func;

Mutex _mutex;
CondVar _cond;
int64_t _ref;
};

class SdkTaskHashMap {
int64_t GetSdkTaskId(SdkTask* task);

uint64_t GetSdkTaskDueTime(SdkTask* task);

class SdkTimeoutManager {
public:
bool PutTask(SdkTask* task);
SdkTimeoutManager(ThreadPool* thread_pool);

// timeout <= 0 means NEVER timeout
bool PutTask(SdkTask* task, int64_t timeout = 0,
SdkTask::TimeoutFunc timeout_func = NULL);
SdkTask* GetTask(int64_t task_id);
SdkTask* PopTask(int64_t task_id);

void CheckTimeout();
void RunTimeoutFunc(SdkTask* sdk_task);

private:
uint32_t Shard(int64_t task_id);

private:
const static uint32_t kShardBits = 6;
const static uint32_t kShardNum = (1 << kShardBits);
typedef boost::unordered_map<int64_t, SdkTask*> TaskHashMap;
TaskHashMap _map_shard[kShardNum];
typedef boost::multi_index_container<
SdkTask*,
boost::multi_index::indexed_by<
// hashed on SdkTask::_id
boost::multi_index::hashed_unique<
boost::multi_index::global_fun<SdkTask*, int64_t, &GetSdkTaskId> >,

// sort by less<int64_t> on SdkTask::_due_time_ms
boost::multi_index::ordered_non_unique<
boost::multi_index::global_fun<SdkTask*, uint64_t, &GetSdkTaskDueTime> >
>
> TaskMap;
enum {
INDEX_BY_ID = 0,
INDEX_BY_DUE_TIME = 1,
};
typedef TaskMap::nth_index<INDEX_BY_ID>::type TaskIdIndex;
typedef TaskMap::nth_index<INDEX_BY_DUE_TIME>::type TaskDueTimeIndex;
TaskMap _map_shard[kShardNum];
mutable Mutex _mutex_shard[kShardNum];
ThreadPool* _thread_pool;
};

} // namespace tera
Expand Down
29 changes: 7 additions & 22 deletions src/sdk/table_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ TableImpl::TableImpl(const std::string& table_name,
_meta_updating_count(0),
_table_meta_cond(&_table_meta_mutex),
_table_meta_updating(false),
_task_pool(thread_pool),
_thread_pool(thread_pool),
_cluster(cluster),
_cluster_private(false),
Expand Down Expand Up @@ -496,19 +497,15 @@ void TableImpl::DistributeMutations(const std::vector<RowMutationImpl*>& mu_list
_perf_counter.mutate_cnt.Inc();
if (called_by_user) {
row_mutation->SetId(_next_task_id.Inc());
_task_pool.PutTask(row_mutation);

int64_t row_timeout = -1;
if (!row_mutation->IsAsync()) {
row_timeout = sync_min_timeout;
} else {
row_timeout = row_mutation->TimeOut() > 0 ? row_mutation->TimeOut() : _timeout;
}
if (row_timeout > 0) {
ThreadPool::Task task =
boost::bind(&TableImpl::MutationTimeout, this, row_mutation->GetId());
_thread_pool->DelayTask(row_timeout, task);
}
SdkTask::TimeoutFunc task = boost::bind(&TableImpl::MutationTimeout, this, _1);
_task_pool.PutTask(row_mutation, row_timeout, task);
}

// flow control
Expand Down Expand Up @@ -811,11 +808,7 @@ void TableImpl::MutateCallBack(std::vector<int64_t>* mu_id_list,
delete mu_id_list;
}

void TableImpl::MutationTimeout(int64_t mutation_id) {
SdkTask* task = _task_pool.PopTask(mutation_id);
if (task == NULL) {
return;
}
void TableImpl::MutationTimeout(SdkTask* task) {
_perf_counter.mutate_timeout_cnt.Inc();
CHECK_NOTNULL(task);
CHECK_EQ(task->Type(), SdkTask::MUTATION);
Expand Down Expand Up @@ -883,17 +876,13 @@ void TableImpl::DistributeReaders(const std::vector<RowReaderImpl*>& row_reader_
RowReaderImpl* row_reader = (RowReaderImpl*)row_reader_list[i];
if (called_by_user) {
row_reader->SetId(_next_task_id.Inc());
_task_pool.PutTask(row_reader);

int64_t row_timeout = sync_min_timeout;
if (row_reader->IsAsync()) {
row_timeout = row_reader->TimeOut() > 0 ? row_reader->TimeOut() : _timeout;
}
if (row_timeout >= 0) {
ThreadPool::Task task =
boost::bind(&TableImpl::ReaderTimeout, this, row_reader->GetId());
_thread_pool->DelayTask(row_timeout, task);
}
SdkTask::TimeoutFunc task = boost::bind(&TableImpl::ReaderTimeout, this, _1);
_task_pool.PutTask(row_reader, row_timeout, task);
}

// flow control
Expand Down Expand Up @@ -1182,11 +1171,7 @@ void TableImpl::DistributeReadersById(std::vector<int64_t>* reader_id_list) {
delete reader_id_list;
}

void TableImpl::ReaderTimeout(int64_t reader_id) {
SdkTask* task = _task_pool.PopTask(reader_id);
if (task == NULL) {
return;
}
void TableImpl::ReaderTimeout(SdkTask* task) {
_perf_counter.reader_timeout_cnt.Inc();
CHECK_NOTNULL(task);
CHECK_EQ(task->Type(), SdkTask::READ);
Expand Down
6 changes: 3 additions & 3 deletions src/sdk/table_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,7 @@ class TableImpl : public Table {
bool failed, int error_code);

// mutation到达用户设置的超时时间但尚未处理完
void MutationTimeout(int64_t mutation_id);
void MutationTimeout(SdkTask* sdk_task);

// 将一批reader根据rowkey分配给各个TS
void DistributeReaders(const std::vector<RowReaderImpl*>& row_reader_list,
Expand Down Expand Up @@ -294,7 +294,7 @@ class TableImpl : public Table {
bool failed, int error_code);

// reader到达用户设置的超时时间但尚未处理完
void ReaderTimeout(int64_t mutation_id);
void ReaderTimeout(SdkTask* sdk_task);

void ScanTabletAsync(ScanTask* scan_task, bool called_by_user);

Expand Down Expand Up @@ -423,7 +423,7 @@ class TableImpl : public Table {
TableSchema _table_schema;
// end of table meta managerment

SdkTaskHashMap _task_pool;
SdkTimeoutManager _task_pool;
Counter _next_task_id;

master::MasterClient* _master_client;
Expand Down

0 comments on commit c9b7aae

Please sign in to comment.