Skip to content

Commit

Permalink
nGraph EP Optimizations (#1630)
Browse files Browse the repository at this point in the history
* Added check for unnecessary function initializations, and removed lock from unneeded areas of code.

* Added LRU cache to EP.

* Bugfixes for nGraph EP Optimization PR

* Changed default cache size to 500 and refactored mutex usage for readability.

* Fixed unsafe environment variable fetch on Windows.

* Cleaned up Windows environment functions and cleaned up mutexes.
  • Loading branch information
tvtrimel authored and jywu-msft committed Aug 21, 2019
1 parent a68a20e commit 97d0a46
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 6 deletions.
53 changes: 48 additions & 5 deletions onnxruntime/core/providers/ngraph/ngraph_custom_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@
namespace onnxruntime {
namespace ngraph_ep {

#define NGRAPH_EP_LRU_CACHE_DEFAULT_SIZE 500

static bool check_ngraph_dump_ops() {
#ifdef _WIN32
size_t env_name_len = 0;
Expand Down Expand Up @@ -80,7 +82,45 @@ void NGRAPHCustomOp::Initialize(const OrtCustomOpApi* api, OrtKernelContext* con
uniq_input_shape.append(reinterpret_cast<const char*>(tensor_shape.data()), ndim * sizeof(int64_t));
}

auto it = ng_exe_map_.insert({uniq_input_shape, nullptr}); //TODO: Limit the size of map with configurable size.
// Get cache size from environment
std::string tempSize;
#ifdef _WIN32
char *buf{nullptr};
size_t bufSize = 0;
if (!_dupenv_s(&buf, &bufSize, "ONNXRUNTIME_NGRAPH_LRU_CACHE_SIZE") && buf) {
tempSize = buf;
free(buf);
}
#else
if (std::getenv("ONNXRUNTIME_NGRAPH_LRU_CACHE_SIZE")) {
tempSize = std::getenv("ONNXRUNTIME_NGRAPH_LRU_CACHE_SIZE");
}
#endif
size_t cacheSize = tempSize.empty() ? NGRAPH_EP_LRU_CACHE_DEFAULT_SIZE : std::stoi(tempSize);

// Not in cache
if (ng_exe_map_.find(uniq_input_shape) == ng_exe_map_.end()) {
// Check if full
if (keyCache.size() == cacheSize) {
// Delete least recently used element
std::string last = keyCache.back();

// Pop the last element
keyCache.pop_back();

// Erase the last element from cache
ng_exe_map_.erase(ng_exe_map_.find(last));
}
}

// Found in cache
else {
keyCache.remove(uniq_input_shape);
}

// update reference
keyCache.push_front(uniq_input_shape);
auto it = ng_exe_map_.insert({uniq_input_shape, nullptr});

//ng_exe with current shape already exists
if (!it.second) {
Expand Down Expand Up @@ -141,11 +181,11 @@ void NGRAPHCustomOp::Initialize(const OrtCustomOpApi* api, OrtKernelContext* con
Status NGRAPHCustomOp::Compute(const OrtCustomOpApi* api, OrtKernelContext* context) const {
Ort::CustomOpApi ort{*api};

//TODO: Minimize locked region
std::lock_guard<std::mutex> lock(compute_lock_);

// Initialize nGraph function if it is not already initialized.
Initialize(api, context);
{
std::lock_guard<std::mutex> lock(compute_lock_);
Initialize(api, context);
}

ORT_ENFORCE(ng_curr_exe_ != nullptr);

Expand All @@ -158,6 +198,7 @@ Status NGRAPHCustomOp::Compute(const OrtCustomOpApi* api, OrtKernelContext* cont
for (const auto& ng_param : ng_curr_exe_->get_parameters()) {
const OrtValue* input_tensor = ort.KernelContext_GetInput(context, input_index++);
void* input_data = const_cast<void*>(ort.GetTensorData<void>(input_tensor));
std::lock_guard<std::mutex> lock(compute_lock_);
ng_inputs.emplace_back(ng_backend_->create_tensor(ng_param->get_output_element_type(0), ng_param->get_output_shape(0), input_data));
}
} catch (const std::exception& exp) {
Expand All @@ -177,6 +218,7 @@ Status NGRAPHCustomOp::Compute(const OrtCustomOpApi* api, OrtKernelContext* cont
std::vector<int64_t> ort_shape{shape.begin(), shape.end()};
OrtValue* output_tensor = ort.KernelContext_GetOutput(context, output_index++, ort_shape.data(), ort_shape.size());
void* output_data = ort.GetTensorMutableData<void>(output_tensor);
std::lock_guard<std::mutex> lock(compute_lock_);
ng_outputs.emplace_back(ng_backend_->create_tensor(dtype, shape, output_data));
}
} catch (const std::exception& exp) {
Expand All @@ -187,6 +229,7 @@ Status NGRAPHCustomOp::Compute(const OrtCustomOpApi* api, OrtKernelContext* cont

// Run the graph through nGraph.
try {
std::lock_guard<std::mutex> lock(compute_lock_);
if (!ng_curr_exe_->call(ng_outputs, ng_inputs))
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, name_ + ": Error while executing nGraph computation");
} catch (const std::exception& exp) {
Expand Down
3 changes: 2 additions & 1 deletion onnxruntime/core/providers/ngraph/ngraph_custom_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,8 @@ class NGRAPHCustomOp {
key = [3,1,2,3,2,4,5]
*/
mutable std::unordered_map<std::string, std::shared_ptr<ngraph::runtime::Executable>> ng_exe_map_;

mutable std::list<std::string> keyCache;

mutable std::mutex compute_lock_;

mutable ONNX_NAMESPACE::ModelProto model_proto_;
Expand Down

0 comments on commit 97d0a46

Please sign in to comment.