Skip to content

Commit

Permalink
Threadsafe batcher is templated, unsure if this is better
Browse files Browse the repository at this point in the history
  • Loading branch information
Jerin Philip committed Aug 5, 2021
1 parent bdb610e commit 68ae377
Show file tree
Hide file tree
Showing 9 changed files with 75 additions and 57 deletions.
1 change: 0 additions & 1 deletion src/translator/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ add_library(bergamot-translator STATIC
batch.cpp
annotation.cpp
service.cpp
threadsafe_batcher.cpp
)
if (USE_WASM_COMPATIBLE_SOURCE)
# Using wasm compatible sources should include this compile definition;
Expand Down
20 changes: 13 additions & 7 deletions src/translator/aggregate_batching_pool.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,26 +4,32 @@
namespace marian {
namespace bergamot {

void AggregateBatchingPool::addRequest(Ptr<TranslationModel> model, Ptr<Request> request) {
// Constructs an empty aggregate pool.
//
// `options` is accepted but not read anywhere in this body yet; the aggregate
// limits it is meant to configure are still pending (see TODO below).
AggregateBatchingPool::AggregateBatchingPool(Ptr<Options> options) {
// TODO(@jerinphilip): Set aggregate limits
}

// Registers `request` with `model` and notes that `model` has pending work.
//
// The model is appended to the aggregate queue, which holds weak references,
// so queueing a model here does not by itself keep it alive.
//
// Returns the value reported by `request->numSegments()`.
size_t AggregateBatchingPool::addRequest(Ptr<TranslationModel> model, Ptr<Request> request) {
  // Order matters: the model must own the request before it is advertised
  // through the queue.
  model->addRequest(request);
  aggregateQueue_.push(model);

  const size_t segmentCount = request->numSegments();
  return segmentCount;
}
bool AggregateBatchingPool::generateBatch(Ptr<TranslationModel>& model, Batch& batch) {

// Produces the next batch from the first queued model that still has work.
//
// Walks the aggregate queue from the front. An entry is discarded when its
// model has already been destroyed (the weak reference no longer locks) or
// when the model reports that it has nothing left to batch; the search then
// continues with the next entry. An entry that yields a batch stays at the
// front so that subsequent calls keep draining the same model.
//
// Parameters:
//   model - out: set to the model that filled `batch`. Left untouched when no
//           batch could be produced, so a value carried over from a previous
//           call does not prevent the queue from being searched.
//   batch - out: filled by the selected model's generateBatch().
//
// Returns the number of sentences placed in `batch`, or 0 when no queued
// model had any work.
size_t AggregateBatchingPool::generateBatch(Ptr<TranslationModel>& model, Batch& batch) {
  while (!aggregateQueue_.empty()) {
    Ptr<TranslationModel> candidate = aggregateQueue_.front().lock();
    if (candidate) {
      const size_t numSentences = candidate->generateBatch(batch);
      if (numSentences > 0) {
        model = candidate;
        return numSentences;
      }
    }

    // Either the model has expired or it is drained. Drop the entry so the
    // loop makes progress instead of re-inspecting the same front forever.
    aggregateQueue_.pop();
  }

  return 0;
}

} // namespace bergamot
Expand Down
5 changes: 3 additions & 2 deletions src/translator/aggregate_batching_pool.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,9 @@ namespace bergamot {

class AggregateBatchingPool {
public:
void addRequest(Ptr<TranslationModel> model, Ptr<Request> request);
bool generateBatch(Ptr<TranslationModel>& model, Batch& batch);
AggregateBatchingPool(Ptr<Options> options);
size_t addRequest(Ptr<TranslationModel> model, Ptr<Request> request);
size_t generateBatch(Ptr<TranslationModel>& model, Batch& batch);

private:
std::queue<std::weak_ptr<TranslationModel>> aggregateQueue_;
Expand Down
11 changes: 6 additions & 5 deletions src/translator/batching_pool.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ BatchingPool::BatchingPool(Ptr<Options> options) {
"longer than what can fit in a batch.");
}

bool BatchingPool::generateBatch(Batch &batch) {
size_t BatchingPool::generateBatch(Batch &batch) {
// For now simply iterates on buckets and converts batches greedily. This
// has to be enhanced with optimizing over priority. The baseline
// implementation should at least be as fast as marian's maxi-batch with full
Expand All @@ -35,22 +35,23 @@ bool BatchingPool::generateBatch(Batch &batch) {
} else {
// Check if elements exist
assert(batch.size() > 0);
return true;
return batch.size();
}
}
}

bool isValidBatch = batch.size() > 0;
return isValidBatch;
return batch.size();
}

void BatchingPool::addRequest(Ptr<Request> request) {
// Breaks `request` into its individual segments and files each one into the
// bucket whose index equals that segment's token count.
//
// A token count that does not fit inside `bucket_` is only caught by the
// assert below, i.e. in builds where asserts are enabled.
//
// Returns the value reported by `request->numSegments()`.
size_t BatchingPool::addRequest(Ptr<Request> request) {
  for (size_t segment = 0; segment < request->numSegments(); ++segment) {
    RequestSentence sentence(segment, request);

    // Sentences are bucketed by length.
    const size_t tokenCount = sentence.numTokens();
    assert(tokenCount < bucket_.size());
    bucket_[tokenCount].insert(sentence);
  }

  return request->numSegments();
}

} // namespace bergamot
Expand Down
4 changes: 2 additions & 2 deletions src/translator/batching_pool.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,11 @@ class BatchingPool {
// RequestSentence incorporates (tentative) notions of priority with each
// sentence. This method inserts the sentence into the internal data-structure
// which maintains priority among sentences from multiple concurrent requests.
void addRequest(Ptr<Request> request);
size_t addRequest(Ptr<Request> request);

// Loads batch with sentences compiled from (tentatively) multiple
// requests, optimizing for both padding and priority.
bool generateBatch(Batch &batch);
size_t generateBatch(Batch &batch);

private:
size_t miniBatchWords;
Expand Down
14 changes: 9 additions & 5 deletions src/translator/service.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,11 @@ namespace marian {
namespace bergamot {

Service::Service(Ptr<Options> options)
: requestId_(0), options_(options), numWorkers_(std::max<int>(1, options->get<int>("cpu-threads"))) {
: requestId_(0),
options_(options),
numWorkers_(std::max<int>(1, options->get<int>("cpu-threads"))),
aggregateBatchingPool_(options),
aggregateBatchingPoolAccess_(aggregateBatchingPool_) {
// translationModel_ = New<TranslationModel>(options_, std::move(memoryBundle), /*replicas=*/numWorkers_);
#ifndef WASM_COMPATIBLE_SOURCE
workers_.reserve(numWorkers_);
Expand All @@ -19,7 +23,7 @@ Service::Service(Ptr<Options> options)
Batch batch;
// Run thread mainloop
Ptr<TranslationModel> translationModel{nullptr};
while (batcher_.generateBatch(translationModel, batch)) {
while (aggregateBatchingPoolAccess_.generateBatch(translationModel, batch)) {
translateBatch(cpuId, translationModel, batch);
}
});
Expand All @@ -43,7 +47,7 @@ std::vector<Response> Service::translateMultiple(Ptr<TranslationModel> translati
Batch batch;
// There's no need to do shutdown here because it's single threaded.
Ptr<TranslationModel> model{nullptr};
while (batcher_.generateBatch(model, batch)) {
while (aggregateBatchingPoolAccess_.generateBatch(model, batch)) {
translateBatch(/*deviceId=*/0, model, batch);
}

Expand All @@ -61,7 +65,7 @@ void Service::queueRequest(Ptr<TranslationModel> translationModel, std::string &
ResponseBuilder responseBuilder(responseOptions, std::move(source), translationModel->vocabs(), std::move(callback));
Ptr<Request> request = New<Request>(requestId_++, std::move(segments), std::move(responseBuilder));

batcher_.addRequest(translationModel, request);
aggregateBatchingPoolAccess_.addRequest(translationModel, request);
}

void Service::translate(Ptr<TranslationModel> translationModel, std::string &&input, CallbackType &&callback,
Expand All @@ -71,7 +75,7 @@ void Service::translate(Ptr<TranslationModel> translationModel, std::string &&in

Service::~Service() {
#ifndef WASM_COMPATIBLE_SOURCE
batcher_.shutdown();
aggregateBatchingPoolAccess_.shutdown();
for (std::thread &worker : workers_) {
assert(worker.joinable());
worker.join();
Expand Down
9 changes: 3 additions & 6 deletions src/translator/service.h
Original file line number Diff line number Diff line change
Expand Up @@ -117,12 +117,9 @@ class Service {

size_t requestId_;

#ifdef WASM_COMPATIBLE_SOURCE
AggregateBatchingPool batcher_;
#else
/// Batcher handles generation of batches from a request, subject to
/// packing-efficiency and priority optimization heuristics.
ThreadsafeAggregateBatchingPool batcher_;
AggregateBatchingPool aggregateBatchingPool_;
#ifndef WASM_COMPATIBLE_SOURCE
GuardedBatchingPoolAccess<AggregateBatchingPool> aggregateBatchingPoolAccess_;
#endif

// The following constructs are available providing full capabilities on a non
Expand Down
38 changes: 24 additions & 14 deletions src/translator/threadsafe_batcher.cpp
Original file line number Diff line number Diff line change
@@ -1,38 +1,48 @@
#ifndef WASM_COMPATIBLE_SOURCE
#include "threadsafe_batcher.h"

#ifndef SRC_BERGAMOT_THREADSAFE_BATCHER_IMPL
#error "This is an impl file and must not be included directly!"
#endif

#include <cassert>

namespace marian {
namespace bergamot {

ThreadsafeAggregateBatchingPool::ThreadsafeAggregateBatchingPool() : enqueued_(0), shutdown_(false) {}
// Wraps an existing batching pool with guarded access.
//
// Only a reference to `backend` is kept, so the pool must outlive this
// object. The wrapper starts with no enqueued sentences and not shut down.
template <class BatchingPoolType>
GuardedBatchingPoolAccess<BatchingPoolType>::GuardedBatchingPoolAccess(BatchingPoolType &backend)
    : backend_(backend),
      enqueued_(0),
      shutdown_(false) {
}

ThreadsafeAggregateBatchingPool::~ThreadsafeAggregateBatchingPool() { shutdown(); }
// Requests shutdown on destruction: shutdown() raises the shutdown flag and
// notifies every thread waiting on the condition variable.
template <class BatchingPoolType>
GuardedBatchingPoolAccess<BatchingPoolType>::~GuardedBatchingPoolAccess() {
  this->shutdown();
}

void ThreadsafeAggregateBatchingPool::addRequest(Ptr<TranslationModel> translationModel, Ptr<Request> request) {
// Mutex-guarded front for the backend's addRequest().
//
// Forwards `args` unchanged to `backend_.addRequest(...)` while holding the
// mutex, adds the count the backend returns to `enqueued_`, and then notifies
// all threads waiting for work.
//
// Must not be called once shutdown() has been requested; this is checked by
// an assert only, so it is enforced only in builds with asserts enabled.
template <class BatchingPoolType>
template <class... Args>
void GuardedBatchingPoolAccess<BatchingPoolType>::addRequest(Args &&... args) {
  std::unique_lock<std::mutex> lock(mutex_);
  assert(!shutdown_);

  // The backend reports how many sentences it accepted; that count is the
  // single source of truth for the bookkeeping, so it is added exactly once.
  enqueued_ += backend_.addRequest(std::forward<Args>(args)...);
  work_.notify_all();
}

void ThreadsafeAggregateBatchingPool::shutdown() {
// Marks the wrapper as shut down and wakes every thread waiting on `work_`.
//
// The flag is flipped under the mutex so that a thread evaluating the wait
// predicate in generateBatch() observes it consistently.
template <class BatchingPoolType>
void GuardedBatchingPoolAccess<BatchingPoolType>::shutdown() {
  std::lock_guard<std::mutex> guard(mutex_);
  shutdown_ = true;
  work_.notify_all();
}

bool ThreadsafeAggregateBatchingPool::generateBatch(Ptr<TranslationModel> &translationModel, Batch &batch) {
// Mutex-guarded, blocking front for the backend's generateBatch().
//
// Blocks until at least one sentence is enqueued or shutdown has been
// requested, then forwards `args` unchanged to `backend_.generateBatch(...)`
// and subtracts the number of sentences it returns from `enqueued_`.
//
// Returns the number of sentences the backend placed in the batch. A return
// value of 0 is only expected after shutdown (asserted); the worker loops in
// this change use it as their exit condition.
template <class BatchingPoolType>
template <class... Args>
size_t GuardedBatchingPoolAccess<BatchingPoolType>::generateBatch(Args &&... args) {
  std::unique_lock<std::mutex> lock(mutex_);

  // The predicate form of wait() re-checks the condition after every wakeup,
  // which also covers spurious wakeups.
  work_.wait(lock, [this]() { return enqueued_ || shutdown_; });

  const size_t sentencesInBatch = backend_.generateBatch(std::forward<Args>(args)...);
  assert(sentencesInBatch > 0 || shutdown_);

  enqueued_ -= sentencesInBatch;
  return sentencesInBatch;
}

} // namespace bergamot
} // namespace marian
#endif // WASM_COMPATIBLE_SOURCE
30 changes: 15 additions & 15 deletions src/translator/threadsafe_batcher.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,29 +11,26 @@
#ifndef WASM_COMPATIBLE_SOURCE
#include <condition_variable>
#include <mutex>
#endif

namespace marian {
namespace bergamot {

#ifndef WASM_COMPATIBLE_SOURCE

class ThreadsafeAggregateBatchingPool {
template <class BatchingPoolType>
class GuardedBatchingPoolAccess {
public:
explicit ThreadsafeAggregateBatchingPool();
GuardedBatchingPoolAccess(BatchingPoolType &backend);
~GuardedBatchingPoolAccess();

~ThreadsafeAggregateBatchingPool();
template <class... Args>
void addRequest(Args &&... args);

// Add sentences to be translated by calling these (see Batcher). When
// done, call shutdown.
void addRequest(Ptr<TranslationModel> translationModel, Ptr<Request> request);
void shutdown();
template <class... Args>
size_t generateBatch(Args &&... args);

// Get a batch out of the batcher. Return false to shutdown worker.
bool generateBatch(Ptr<TranslationModel> &translationModel, Batch &batch);
void shutdown();

private:
AggregateBatchingPool backend_;
BatchingPoolType &backend_;

// Number of sentences in backend_;
size_t enqueued_;
Expand All @@ -48,9 +45,12 @@ class ThreadsafeAggregateBatchingPool {
std::condition_variable work_;
};

#endif

} // namespace bergamot
} // namespace marian

#define SRC_BERGAMOT_THREADSAFE_BATCHER_IMPL
#include "threadsafe_batcher.cpp"
#undef SRC_BERGAMOT_THREADSAFE_BATCHER_IMPL

#endif // WASM_COMPATIBLE_SOURCE
#endif // SRC_BERGAMOT_THREADSAFE_BATCHER_H_

0 comments on commit 68ae377

Please sign in to comment.