Catch exceptions thrown during inference and report as errors
davidkyle committed Jun 20, 2023
1 parent 1d29863 commit 42c5a77
Showing 4 changed files with 38 additions and 18 deletions.
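The commit threads std::optional<std::string> through the inference compute path: the compute callback now catches c10::Error and std::runtime_error, writes an error document for the failing request, and returns std::nullopt, and both cache paths skip the read callback when no value was produced. Below is a minimal, self-contained sketch of that pattern; SRequest, the lookup function, and the std::cerr logging are simplified stand-ins for the types in this diff, not the repository's actual implementation.

#include <functional>
#include <iostream>
#include <optional>
#include <stdexcept>
#include <string>

// Simplified stand-in for ml::torch::CCommandParser::SRequest.
struct SRequest {
    std::string s_RequestId;
};

using TComputeResponse = std::function<std::optional<std::string>(SRequest)>;
using TReadResponse = std::function<void(const std::string&, bool)>;

// Mirrors the no-cache lookup in CCommandParser.h after this commit: when
// the compute callback fails (and has already reported the error itself),
// the read callback is never invoked.
bool lookup(SRequest request,
            const TComputeResponse& computeResponse,
            const TReadResponse& readResponse) {
    auto computed = computeResponse(std::move(request));
    if (computed) {
        readResponse(*computed, false);
    }
    return false; // this implementation never serves a cache hit
}

int main() {
    // Stand-in for infer(): always throws, as a failing model evaluation would.
    auto infer = [](const SRequest&) -> std::string {
        throw std::runtime_error{"inference failed"};
    };

    lookup(SRequest{"req-1"},
           [&](SRequest request_) -> std::optional<std::string> {
               try {
                   return infer(request_);
               } catch (const std::runtime_error& e) {
                   // Stand-in for resultWriter.writeError(...).
                   std::cerr << "error for request " << request_.s_RequestId
                             << ": " << e.what() << '\n';
                   return std::nullopt;
               }
           },
           [](const std::string& innerResponseJson_, bool isCacheHit) {
               std::cout << innerResponseJson_
                         << (isCacheHit ? " (cache hit)" : "") << '\n';
           });
    return 0;
}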
bin/pytorch_inference/CCommandParser.h (6 additions, 2 deletions)
@@ -19,6 +19,7 @@
 #include <functional>
 #include <iosfwd>
 #include <memory>
+#include <optional>
 #include <string>
 #include <vector>

@@ -58,7 +59,7 @@ class CCommandParser {
     //! \brief Inference request cache interface.
     class CRequestCacheInterface {
     public:
-        using TComputeResponse = std::function<std::string(SRequest)>;
+        using TComputeResponse = std::function<std::optional<std::string>(SRequest)>;
         using TReadResponse = std::function<void(const std::string&, bool)>;

     public:
@@ -102,7 +103,10 @@ class CCommandParser {
         bool lookup(SRequest request,
                     const TComputeResponse& computeResponse,
                     const TReadResponse& readResponse) override {
-            readResponse(computeResponse(std::move(request)), false);
+            auto computed = computeResponse(std::move(request));
+            if (computed) {
+                readResponse(*computed, false);
+            }
             return false;
         }

bin/pytorch_inference/Main.cc (19 additions, 10 deletions)
@@ -18,6 +18,7 @@
 #include <core/CStringUtils.h>
 #include <core/Concurrency.h>

+#include <optional>
 #include <seccomp/CSystemCallFilter.h>

 #include <ver/CBuildInfo.h>
@@ -78,16 +79,24 @@ bool handleRequest(ml::torch::CCommandParser::CRequestCacheInterface& cache,
         // We time the combination of the cache lookup and (if necessary)
         // the inference.
         ml::core::CStopWatch stopWatch(true);
-        cache.lookup(std::move(capturedRequest),
-                     [&](auto request_) -> std::string {
-                         torch::Tensor results = infer(module_, request_);
-                         return resultWriter.createInnerResult(results);
-                     },
-                     [&](const auto& innerResponseJson_, bool isCacheHit) {
-                         resultWriter.wrapAndWriteInnerResponse(innerResponseJson_,
-                                                                requestId, isCacheHit,
-                                                                stopWatch.stop());
-                     });
+        cache.lookup(
+            std::move(capturedRequest),
+            [&](auto request_) -> std::optional<std::string> {
+                try {
+                    torch::Tensor results = infer(module_, request_);
+                    return resultWriter.createInnerResult(results);
+                } catch (const c10::Error& e) {
+                    resultWriter.writeError(request_.s_RequestId, e.what());
+                    return std::nullopt;
+                } catch (std::runtime_error& e) {
+                    resultWriter.writeError(request_.s_RequestId, e.what());
+                    return std::nullopt;
+                }
+            },
+            [&](const auto& innerResponseJson_, bool isCacheHit) {
+                resultWriter.wrapAndWriteInnerResponse(
+                    innerResponseJson_, requestId, isCacheHit, stopWatch.stop());
+            });
     });
     return true;
 }
bin/pytorch_inference/evaluate.py (1 addition, 1 deletion)
@@ -288,7 +288,7 @@ def test_evaluation(args):
     for result in result_docs:

         if 'error' in result:
-            print(f"Inference failed. Request: {result['error']['request_id']}, Msg: {result['error']['error']}")
+            print(f"Inference failed. Request: {result['request_id']}, Msg: {result['error']['error']}")
             results_match = False
             continue

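The one-line evaluate.py fix tracks the shape of the error documents the C++ process writes: request_id now sits at the top level of the result document, with the message nested under error.error. Below is a hedged sketch of building that document with rapidjson; makeErrorDoc is a hypothetical helper, not this repository's actual writer, and only the field names are taken from the Python change above.

#include <rapidjson/stringbuffer.h>
#include <rapidjson/writer.h>

#include <string>

// Hypothetical helper illustrating the assumed error document shape:
// {"request_id": "...", "error": {"error": "..."}}
std::string makeErrorDoc(const std::string& requestId, const std::string& message) {
    rapidjson::StringBuffer buffer;
    rapidjson::Writer<rapidjson::StringBuffer> writer{buffer};
    writer.StartObject();
    writer.Key("request_id");
    writer.String(requestId.c_str());
    writer.Key("error");
    writer.StartObject();
    writer.Key("error");
    writer.String(message.c_str());
    writer.EndObject();
    writer.EndObject();
    return buffer.GetString();
}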
include/core/CCompressedLfuCache.h (12 additions, 5 deletions)
@@ -30,6 +30,7 @@
 #include <limits>
 #include <memory>
 #include <mutex>
+#include <optional>
 #include <set>
 #include <shared_mutex>
 #include <string>
@@ -65,7 +66,7 @@ class CCompressedLfuCache {
     using TDictionary = CCompressedDictionary<COMPRESSED_KEY_BITS / 64>;
     using TCompressedKey = typename TDictionary::CWord;
     using TCompressKey = std::function<TCompressedKey(const TDictionary&, const KEY&)>;
-    using TComputeValueCallback = std::function<VALUE(KEY)>;
+    using TComputeValueCallback = std::function<std::optional<VALUE>(KEY)>;
     using TReadValueCallback = std::function<void(const VALUE&, bool)>;

 public:
@@ -96,6 +97,9 @@ class CCompressedLfuCache {

     //! Lookup an item with \p key in the cache or else fall back to computing.
     //!
+    //! \warning If \p computeValue fails to produce a value (returns std::nullopt)
+    //! then \p readValue will not be called.
+    //!
     //! \param[in] key The item key.
     //! \param[in] computeValue Computes the value in the case of a cache miss.
     //! \param[in] readValue Processes the value.
@@ -137,15 +141,18 @@
         }

         auto value = computeValue(std::move(key));
+        if (!value) {
+            return false;
+        }

-        std::size_t itemMemoryUsage{memory::dynamicSize(value)};
+        std::size_t itemMemoryUsage{memory::dynamicSize(*value)};

         if (this->guardWrite(TIME_OUT, [&] {
                 // It is possible that two values with the same key check the cache
                 // before either takes the write lock. So check if this is already
                 // in the cache before going any further.
                 if (m_ItemCache.find(compressedKey) != m_ItemCache.end()) {
-                    readValue(value, true);
+                    readValue(*value, true);
                     this->incrementCount(compressedKey);
                     return;
                 }
@@ -158,14 +165,14 @@
                 // It's possible that the cache is empty yet isn't big
                 // enough to hold this new item.
                 if (itemToEvict == m_ItemStats.end()) {
-                    readValue(value, false);
+                    readValue(*value, false);
                     return;
                 }
                 m_RemovedCount += lastEvictedCount;
                 lastEvictedCount = itemToEvict->count();
                 this->removeFromCache(itemToEvict);
             }
-            readValue(this->insert(compressedKey, value, itemMemoryUsage,
+            readValue(this->insert(compressedKey, *value, itemMemoryUsage,
                                    count + lastEvictedCount),
                       false);
         }) == false) {
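The new \warning documents the contract change: a computeValue that returns std::nullopt caches nothing, skips readValue entirely, and makes lookup return false. A minimal sketch of that contract, assuming a plain std::unordered_map in place of CCompressedLfuCache (no compression, LFU eviction, memory accounting, or locking):

#include <functional>
#include <optional>
#include <unordered_map>

// Simplified stand-in for CCompressedLfuCache::lookup: a failed compute
// (std::nullopt) caches nothing and never invokes readValue; the return
// value reports whether the lookup was served from the cache.
template <typename KEY, typename VALUE>
class CSimpleCache {
public:
    using TComputeValueCallback = std::function<std::optional<VALUE>(KEY)>;
    using TReadValueCallback = std::function<void(const VALUE&, bool)>;

    bool lookup(KEY key,
                const TComputeValueCallback& computeValue,
                const TReadValueCallback& readValue) {
        auto hit = m_Cache.find(key);
        if (hit != m_Cache.end()) {
            readValue(hit->second, true);
            return true;
        }
        auto value = computeValue(key);
        if (!value) {
            return false; // compute failed: nothing cached, readValue skipped
        }
        readValue(m_Cache.emplace(key, std::move(*value)).first->second, false);
        return false;
    }

private:
    std::unordered_map<KEY, VALUE> m_Cache;
};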
