From 42c5a775f038ef2d88b51532524b24c5bc3c2105 Mon Sep 17 00:00:00 2001
From: David Kyle
Date: Tue, 20 Jun 2023 10:14:11 +0100
Subject: [PATCH] Catch exceptions thrown during inference and report as errors

---
 bin/pytorch_inference/CCommandParser.h |  8 +++++--
 bin/pytorch_inference/Main.cc          | 29 +++++++++++++++++---------
 bin/pytorch_inference/evaluate.py      |  2 +-
 include/core/CCompressedLfuCache.h     | 17 ++++++++++-----
 4 files changed, 38 insertions(+), 18 deletions(-)

diff --git a/bin/pytorch_inference/CCommandParser.h b/bin/pytorch_inference/CCommandParser.h
index c723a1da70..15fe8c06a1 100644
--- a/bin/pytorch_inference/CCommandParser.h
+++ b/bin/pytorch_inference/CCommandParser.h
@@ -19,6 +19,7 @@
 #include 
 #include 
 #include 
+#include <optional>
 #include 
 #include 
@@ -58,7 +59,7 @@ class CCommandParser {
     //! \brief Inference request cache interface.
     class CRequestCacheInterface {
     public:
-        using TComputeResponse = std::function<std::string(SRequest)>;
+        using TComputeResponse = std::function<std::optional<std::string>(SRequest)>;
         using TReadResponse = std::function<void(const std::string&, bool)>;
 
     public:
@@ -102,7 +103,10 @@ class CCommandParser {
         bool lookup(SRequest request,
                     const TComputeResponse& computeResponse,
                     const TReadResponse& readResponse) override {
-            readResponse(computeResponse(std::move(request)), false);
+            auto computed = computeResponse(std::move(request));
+            if (computed) {
+                readResponse(*computed, false);
+            }
             return false;
         }
diff --git a/bin/pytorch_inference/Main.cc b/bin/pytorch_inference/Main.cc
index 2b93ec5fd9..9da1cd7bf1 100644
--- a/bin/pytorch_inference/Main.cc
+++ b/bin/pytorch_inference/Main.cc
@@ -18,6 +18,7 @@
 #include 
 #include 
+#include <optional>
 #include 
 #include 
@@ -78,16 +79,24 @@ bool handleRequest(ml::torch::CCommandParser::CRequestCacheInterface& cache,
         // We time the combination of the cache lookup and (if necessary)
         // the inference.
         ml::core::CStopWatch stopWatch(true);
-        cache.lookup(std::move(capturedRequest),
-                     [&](auto request_) -> std::string {
-                         torch::Tensor results = infer(module_, request_);
-                         return resultWriter.createInnerResult(results);
-                     },
-                     [&](const auto& innerResponseJson_, bool isCacheHit) {
-                         resultWriter.wrapAndWriteInnerResponse(innerResponseJson_,
-                                                                requestId, isCacheHit,
-                                                                stopWatch.stop());
-                     });
+        cache.lookup(
+            std::move(capturedRequest),
+            [&](auto request_) -> std::optional<std::string> {
+                try {
+                    torch::Tensor results = infer(module_, request_);
+                    return resultWriter.createInnerResult(results);
+                } catch (const c10::Error& e) {
+                    resultWriter.writeError(request_.s_RequestId, e.what());
+                    return std::nullopt;
+                } catch (std::runtime_error& e) {
+                    resultWriter.writeError(request_.s_RequestId, e.what());
+                    return std::nullopt;
+                }
+            },
+            [&](const auto& innerResponseJson_, bool isCacheHit) {
+                resultWriter.wrapAndWriteInnerResponse(
+                    innerResponseJson_, requestId, isCacheHit, stopWatch.stop());
+            });
     });
     return true;
 }
diff --git a/bin/pytorch_inference/evaluate.py b/bin/pytorch_inference/evaluate.py
index ac01ded113..6845c906ef 100644
--- a/bin/pytorch_inference/evaluate.py
+++ b/bin/pytorch_inference/evaluate.py
@@ -288,7 +288,7 @@ def test_evaluation(args):
     for result in result_docs:
         if 'error' in result:
-            print(f"Inference failed. Request: {result['error']['request_id']}, Msg: {result['error']['error']}")
+            print(f"Inference failed. Request: {result['request_id']}, Msg: {result['error']['error']}")
             results_match = False
             continue
diff --git a/include/core/CCompressedLfuCache.h b/include/core/CCompressedLfuCache.h
index b19f09ea12..17529ebe68 100644
--- a/include/core/CCompressedLfuCache.h
+++ b/include/core/CCompressedLfuCache.h
@@ -30,6 +30,7 @@
 #include 
 #include 
 #include 
+#include <optional>
 #include 
 #include 
 #include 
@@ -65,7 +66,7 @@ class CCompressedLfuCache {
     using TDictionary = CCompressedDictionary;
     using TCompressedKey = typename TDictionary::CWord;
     using TCompressKey = std::function;
-    using TComputeValueCallback = std::function<VALUE(KEY)>;
+    using TComputeValueCallback = std::function<std::optional<VALUE>(KEY)>;
     using TReadValueCallback = std::function<void(const VALUE&, bool)>;
 
 public:
@@ -96,6 +97,9 @@ class CCompressedLfuCache {
     //! Lookup an item with \p key in the cache or else fall back to computing.
     //!
+    //! \warning If \p computeValue fails to produce a value (returns std::nullopt)
+    //! then \p readValue will not be called.
+    //!
     //! \param[in] key The item key.
     //! \param[in] computeValue Computes the value in the case of a cache miss.
     //! \param[in] readValue Processes the value.
@@ -137,15 +141,18 @@
         }
 
         auto value = computeValue(std::move(key));
+        if (!value) {
+            return false;
+        }
 
-        std::size_t itemMemoryUsage{memory::dynamicSize(value)};
+        std::size_t itemMemoryUsage{memory::dynamicSize(*value)};
 
         if (this->guardWrite(TIME_OUT, [&] {
                 // It is possible that two values with the same key check the cache
                 // before either takes the write lock. So check if this is already
                 // in the cache before going any further.
                 if (m_ItemCache.find(compressedKey) != m_ItemCache.end()) {
-                    readValue(value, true);
+                    readValue(*value, true);
                     this->incrementCount(compressedKey);
                     return;
                 }
                     // It's possible that the cache is empty yet isn't big
                     // enough to hold this new item.
                     if (itemToEvict == m_ItemStats.end()) {
-                        readValue(value, false);
+                        readValue(*value, false);
                         return;
                     }
                     m_RemovedCount += lastEvictedCount;
                     lastEvictedCount = itemToEvict->count();
                     this->removeFromCache(itemToEvict);
                 }
-                readValue(this->insert(compressedKey, value, itemMemoryUsage,
+                readValue(this->insert(compressedKey, *value, itemMemoryUsage,
                                        count + lastEvictedCount),
                           false);
             }) == false) {