diff --git a/CMakeLists.txt b/CMakeLists.txt index 47d30a89d2d1..be7814f03cf1 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -677,6 +677,9 @@ if(GTEST_FOUND) if(DEFINED LLVM_LIBS) target_link_libraries(cpptest PRIVATE ${LLVM_LIBS}) endif() + if(DEFINED ETHOSN_RUNTIME_LIBRARY) + target_link_libraries(cpptest PRIVATE ${ETHOSN_RUNTIME_LIBRARY}) + endif() set_target_properties(cpptest PROPERTIES EXCLUDE_FROM_ALL 1) set_target_properties(cpptest PROPERTIES EXCLUDE_FROM_DEFAULT_BUILD 1) if(USE_RELAY_DEBUG) diff --git a/src/runtime/contrib/ethosn/ethosn_device.cc b/src/runtime/contrib/ethosn/ethosn_device.cc index 612f4b4cec39..0d79f69815fa 100644 --- a/src/runtime/contrib/ethosn/ethosn_device.cc +++ b/src/runtime/contrib/ethosn/ethosn_device.cc @@ -32,6 +32,7 @@ #include #include +#include #include "ethosn_driver_library/Buffer.hpp" #include "ethosn_runtime.h" @@ -48,7 +49,7 @@ namespace ethosn { namespace dl = ::ethosn::driver_library; -bool WaitForInference(dl::Inference* inference, int timeout) { +InferenceWaitStatus WaitForInference(dl::Inference* inference, int timeout) { // Wait for inference to complete int fd = inference->GetFileDescriptor(); struct pollfd fds; @@ -58,20 +59,32 @@ bool WaitForInference(dl::Inference* inference, int timeout) { const int ms_per_seconds = 1000; int poll_result = poll(&fds, 1, timeout * ms_per_seconds); - if (poll_result > 0) { - dl::InferenceResult result; - if (read(fd, &result, sizeof(result)) != sizeof(result)) { - return false; - } - if (result != dl::InferenceResult::Completed) { - return false; - } + int poll_error_code = errno; + + if (poll_result < 0) { + return InferenceWaitStatus(InferenceWaitErrorCode::kError, + "Error while waiting for the inference to complete (" + + std::string(strerror(poll_error_code)) + ")"); } else if (poll_result == 0) { - return false; - } else { - return false; + return InferenceWaitStatus(InferenceWaitErrorCode::kTimeout, + "Timed out while waiting for the inference to complete."); } - 
return true; + + // poll_result > 0 + dl::InferenceResult npu_result; + if (read(fd, &npu_result, sizeof(npu_result)) != static_cast(sizeof(npu_result))) { + return InferenceWaitStatus( + InferenceWaitErrorCode::kError, + "Failed to read inference result status (" + std::string(strerror(poll_error_code)) + ")"); + } + + if (npu_result != dl::InferenceResult::Completed) { + return InferenceWaitStatus( + InferenceWaitErrorCode::kError, + "Inference failed with status " + std::to_string(static_cast(npu_result))); + } + + return InferenceWaitStatus(InferenceWaitErrorCode::kSuccess); } void CreateBuffers(std::vector>* fm, @@ -123,21 +136,26 @@ bool Inference(tvm::runtime::TVMArgs args, dl::Network* npu, } // Execute the inference. - std::unique_ptr result( + std::unique_ptr inference( npu->ScheduleInference(ifm_raw, n_inputs, ofm_raw, n_outputs)); - bool inferenceCompleted = WaitForInference(result.get(), 60); - if (inferenceCompleted) { - for (size_t i = 0; i < n_outputs; i++) { - DLTensor* tensor = outputs[i]; - dl::Buffer* source_buffer = ofm_raw[i]; - uint8_t* dest_buffer = static_cast(tensor->data); - size_t size = source_buffer->GetSize(); - uint8_t* source_buffer_data = source_buffer->Map(); - std::copy(source_buffer_data, source_buffer_data + size, dest_buffer); - source_buffer->Unmap(); - } + InferenceWaitStatus result = WaitForInference(inference.get(), 60); + + if (result.GetErrorCode() != InferenceWaitErrorCode::kSuccess) { + LOG(FATAL) << "An error has occured waiting for the inference of a sub-graph on the NPU: " + << result.GetErrorDescription(); + } + + for (size_t i = 0; i < n_outputs; i++) { + DLTensor* tensor = outputs[i]; + dl::Buffer* source_buffer = ofm_raw[i]; + uint8_t* dest_buffer = static_cast(tensor->data); + size_t size = source_buffer->GetSize(); + uint8_t* source_buffer_data = source_buffer->Map(); + std::copy(source_buffer_data, source_buffer_data + size, dest_buffer); + source_buffer->Unmap(); } - return inferenceCompleted; + + return 
/*!
 * \brief Outcome categories produced when waiting for an inference on the NPU.
 */
enum class InferenceWaitErrorCode {
  kSuccess = 0,
  kTimeout = 1,
  kError = 2,
};

/*!
 * \brief Small value type that pairs an InferenceWaitErrorCode with an optional
 * human-readable description of what went wrong.
 *
 * A default-constructed status represents success with an empty description.
 * The status converts (explicitly) to true only when the code is kSuccess.
 *
 * Similar to the implementation of 'WaitStatus' in the driver stack:
 * https://github.com/ARM-software/ethos-n-driver-stack/blob/22.08/armnn-ethos-n-backend/workloads/EthosNPreCompiledWorkload.cpp#L48
 */
class InferenceWaitStatus {
 public:
  /*! \brief Build a successful status with no description. */
  InferenceWaitStatus() : InferenceWaitStatus(InferenceWaitErrorCode::kSuccess) {}

  /*!
   * \brief Build a status from a code and an optional description.
   * \param errorCode The outcome category to store.
   * \param errorDescription Text describing the outcome; empty by default.
   */
  explicit InferenceWaitStatus(InferenceWaitErrorCode errorCode, std::string errorDescription = "")
      : error_code_(errorCode), error_description_(errorDescription) {}

  // Copy and move operations are the compiler-generated ones; both members are
  // plain value types, so nothing needs to be spelled out here.

  /*! \brief True only when the stored code is kSuccess. */
  explicit operator bool() const { return GetErrorCode() == InferenceWaitErrorCode::kSuccess; }

  /*! \return The stored outcome category. */
  InferenceWaitErrorCode GetErrorCode() const { return error_code_; }

  /*! \return A copy of the stored description (may be empty). */
  std::string GetErrorDescription() const { return error_description_; }

 private:
  /*! \brief Outcome category of the wait. */
  InferenceWaitErrorCode error_code_;
  /*! \brief Free-form description accompanying the outcome. */
  std::string error_description_;
};
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file tests/cpp/runtime/contrib/ethosn/inference_test.cc
 * \brief Tests to check Arm(R) Ethos(TM)-N runtime components used during inference.
 */

#ifdef ETHOSN_HW

#include <gtest/gtest.h>

#include <string>

// The implementation file is included directly so that WaitForInference can be
// called from the tests in this translation unit.
#include "../../../../../src/runtime/contrib/ethosn/ethosn_device.cc"

namespace tvm {
namespace runtime {
namespace ethosn {

/*! \brief With a zero timeout, WaitForInference is expected to report kTimeout. */
TEST(WaitForInference, InferenceScheduled) {
  const int inference_result = 0 /* Scheduled */;
  const int timeout = 0;

  dl::Inference inference(inference_result);
  InferenceWaitStatus result = WaitForInference(&inference, timeout);

  ASSERT_EQ(result.GetErrorCode(), InferenceWaitErrorCode::kTimeout);
  EXPECT_EQ(result.GetErrorDescription(), "Timed out while waiting for the inference to complete.");
}

/*! \brief With a zero timeout, WaitForInference is expected to report kTimeout. */
TEST(WaitForInference, InferenceRunning) {
  const int inference_result = 1 /* Running */;
  const int timeout = 0;

  dl::Inference inference(inference_result);
  InferenceWaitStatus result = WaitForInference(&inference, timeout);

  ASSERT_EQ(result.GetErrorCode(), InferenceWaitErrorCode::kTimeout);
  EXPECT_EQ(result.GetErrorDescription(), "Timed out while waiting for the inference to complete.");
}

/*! \brief WaitForInference is expected to report kError when the result cannot be read. */
TEST(WaitForInference, InferenceError) {
  const int inference_result = 3 /* Error */;
  const int timeout = 0;

  dl::Inference inference(inference_result);
  InferenceWaitStatus result = WaitForInference(&inference, timeout);

  ASSERT_EQ(result.GetErrorCode(), InferenceWaitErrorCode::kError);

  // Only the fixed part of the message is checked: the text in parentheses is a
  // strerror() string, so it depends on the errno value observed at run time.
  const std::string expected_prefix = "Failed to read inference result status (";
  const std::string description = result.GetErrorDescription();
  EXPECT_EQ(description.compare(0, expected_prefix.size(), expected_prefix), 0)
      << "Unexpected error description: " << description;
}

}  // namespace ethosn
}  // namespace runtime
}  // namespace tvm

#endif