
[ETHOSN] Throw error message when inference fails #13022

Merged: 2 commits, Nov 3, 2022.
Changes from 1 commit
3 changes: 3 additions & 0 deletions CMakeLists.txt
@@ -677,6 +677,9 @@ if(GTEST_FOUND)
  if(DEFINED LLVM_LIBS)
    target_link_libraries(cpptest PRIVATE ${LLVM_LIBS})
  endif()
+  if(DEFINED ETHOSN_RUNTIME_LIBRARY)
+    target_link_libraries(cpptest PRIVATE ${ETHOSN_RUNTIME_LIBRARY})
+  endif()
  set_target_properties(cpptest PROPERTIES EXCLUDE_FROM_ALL 1)
  set_target_properties(cpptest PROPERTIES EXCLUDE_FROM_DEFAULT_BUILD 1)
  if(USE_RELAY_DEBUG)
67 changes: 41 additions & 26 deletions src/runtime/contrib/ethosn/ethosn_device.cc
@@ -32,6 +32,7 @@

#include <algorithm>
#include <memory>
+#include <string>

#include "ethosn_driver_library/Buffer.hpp"
#include "ethosn_runtime.h"
@@ -48,7 +49,7 @@ namespace ethosn {

namespace dl = ::ethosn::driver_library;

-bool WaitForInference(dl::Inference* inference, int timeout) {
+WaitStatus WaitForInference(dl::Inference* inference, int timeout) {
  // Wait for inference to complete
  int fd = inference->GetFileDescriptor();
  struct pollfd fds;
@@ -58,20 +59,29 @@ bool WaitForInference(dl::Inference* inference, int timeout) {

  const int ms_per_seconds = 1000;
  int poll_result = poll(&fds, 1, timeout * ms_per_seconds);
-  if (poll_result > 0) {
-    dl::InferenceResult result;
-    if (read(fd, &result, sizeof(result)) != sizeof(result)) {
-      return false;
-    }
-    if (result != dl::InferenceResult::Completed) {
-      return false;
-    }
+  int poll_error_code = errno;
+
+  if (poll_result < 0) {
+    return WaitStatus(WaitErrorCode::Error, "Error while waiting for the inference to complete (" +
+                                                std::string(strerror(poll_error_code)) + ")");
  } else if (poll_result == 0) {
-    return false;
-  } else {
-    return false;
+    return WaitStatus(WaitErrorCode::Timeout,
+                      "Timed out while waiting for the inference to complete.");
  }
-  return true;
+
+  // poll_result > 0
+  dl::InferenceResult npu_result;
+  if (read(fd, &npu_result, sizeof(npu_result)) != static_cast<ssize_t>(sizeof(npu_result))) {
+    return WaitStatus(WaitErrorCode::Error, "Failed to read inference result status (" +
+                                                std::string(strerror(poll_error_code)) + ")");
+  }
+
+  if (npu_result != dl::InferenceResult::Completed) {
+    return WaitStatus(WaitErrorCode::Error, "Inference failed with status " +
+                                                std::to_string(static_cast<uint32_t>(npu_result)));
+  }
+
+  return WaitStatus(WaitErrorCode::Success);
}

void CreateBuffers(std::vector<std::shared_ptr<dl::Buffer>>* fm,
@@ -123,21 +133,26 @@ bool Inference(tvm::runtime::TVMArgs args, dl::Network* npu,
}

  // Execute the inference.
-  std::unique_ptr<dl::Inference> result(
+  std::unique_ptr<dl::Inference> inference(
      npu->ScheduleInference(ifm_raw, n_inputs, ofm_raw, n_outputs));
-  bool inferenceCompleted = WaitForInference(result.get(), 60);
-  if (inferenceCompleted) {
-    for (size_t i = 0; i < n_outputs; i++) {
-      DLTensor* tensor = outputs[i];
-      dl::Buffer* source_buffer = ofm_raw[i];
-      uint8_t* dest_buffer = static_cast<uint8_t*>(tensor->data);
-      size_t size = source_buffer->GetSize();
-      uint8_t* source_buffer_data = source_buffer->Map();
-      std::copy(source_buffer_data, source_buffer_data + size, dest_buffer);
-      source_buffer->Unmap();
-    }
+  WaitStatus result = WaitForInference(inference.get(), 60);

+  if (result.GetErrorCode() != WaitErrorCode::Success) {
+    LOG(FATAL) << "An error has occurred waiting for the inference of a sub-graph on the NPU: "
+               << result.GetErrorDescription();
+  }
+
+  for (size_t i = 0; i < n_outputs; i++) {
+    DLTensor* tensor = outputs[i];
+    dl::Buffer* source_buffer = ofm_raw[i];
+    uint8_t* dest_buffer = static_cast<uint8_t*>(tensor->data);
+    size_t size = source_buffer->GetSize();
+    uint8_t* source_buffer_data = source_buffer->Map();
+    std::copy(source_buffer_data, source_buffer_data + size, dest_buffer);
+    source_buffer->Unmap();
  }
-  return inferenceCompleted;
+
+  return true;
}

} // namespace ethosn
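For illustration only (none of this is TVM or PR code), the standalone sketch below reproduces the poll-then-read pattern that the new WaitForInference follows, using plain POSIX calls on a pipe so it can be compiled and run on its own. The names WaitOutcome and WaitOnFd are invented for the example, and a uint32_t stands in for dl::InferenceResult; the key detail it mirrors is that errno is captured immediately after poll() so a later call cannot overwrite it.

#include <poll.h>
#include <unistd.h>

#include <cerrno>
#include <cstdint>
#include <cstring>
#include <iostream>
#include <string>

// Mirrors the three outcomes WaitForInference distinguishes.
enum class WaitOutcome { Success, Timeout, Error };

WaitOutcome WaitOnFd(int fd, int timeout_ms, std::string* message) {
  pollfd fds{};
  fds.fd = fd;
  fds.events = POLLIN;

  int poll_result = poll(&fds, 1, timeout_ms);
  // Capture errno immediately so later calls cannot overwrite it.
  int poll_error_code = errno;

  if (poll_result < 0) {
    *message = std::string("poll failed (") + std::strerror(poll_error_code) + ")";
    return WaitOutcome::Error;
  }
  if (poll_result == 0) {
    *message = "timed out";
    return WaitOutcome::Timeout;
  }

  uint32_t status = 0;  // stand-in for dl::InferenceResult
  if (read(fd, &status, sizeof(status)) != static_cast<ssize_t>(sizeof(status))) {
    *message = std::string("failed to read status (") + std::strerror(errno) + ")";
    return WaitOutcome::Error;
  }
  *message = "completed with status " + std::to_string(status);
  return WaitOutcome::Success;
}

int main() {
  int pipe_fds[2];
  if (pipe(pipe_fds) != 0) return 1;

  // Simulate the driver reporting a completed inference.
  uint32_t completed = 0;
  if (write(pipe_fds[1], &completed, sizeof(completed)) !=
      static_cast<ssize_t>(sizeof(completed))) {
    return 1;
  }

  std::string message;
  WaitOutcome outcome = WaitOnFd(pipe_fds[0], /*timeout_ms=*/1000, &message);
  std::cout << (outcome == WaitOutcome::Success ? "success: " : "failure: ") << message << "\n";

  close(pipe_fds[0]);
  close(pipe_fds[1]);
  return 0;
}

Running it prints "success: completed with status 0", the analogue of the WaitErrorCode::Success path; closing the write end without writing instead exercises the error path, and an empty, still-open pipe exercises the timeout path.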
32 changes: 32 additions & 0 deletions src/runtime/contrib/ethosn/ethosn_runtime.h
@@ -107,6 +107,38 @@ class EthosnModule : public ModuleNode {
  std::map<std::string, OrderedCompiledNetwork> network_map_;
};

/*!
 * \brief Error codes for evaluating the result of inference on the NPU.
 */
enum class WaitErrorCode { Success = 0, Timeout = 1, Error = 2 };

/*!
 * \brief A helper class holding the status of inference on the NPU and
 * associated error message(s) if any occurred.
 */
class WaitStatus {
 public:
  WaitStatus() : error_code_(WaitErrorCode::Success), error_description_("") {}

  explicit WaitStatus(WaitErrorCode errorCode, std::string errorDescription = "")
      : error_code_(errorCode), error_description_(errorDescription) {}

  WaitStatus(const WaitStatus&) = default;
  WaitStatus(WaitStatus&&) = default;
  WaitStatus& operator=(const WaitStatus&) = default;
  WaitStatus& operator=(WaitStatus&&) = default;

  explicit operator bool() const noexcept { return error_code_ == WaitErrorCode::Success; }

  WaitErrorCode GetErrorCode() const { return error_code_; }

  std::string GetErrorDescription() const { return error_description_; }

 private:
  WaitErrorCode error_code_;
  std::string error_description_;
};

} // namespace ethosn
} // namespace runtime
} // namespace tvm
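In this patch WaitStatus is consumed through GetErrorCode(), but the explicit operator bool also lets a status be tested directly in an if. The snippet below is a self-contained illustration of that pattern using stand-in names (DemoErrorCode and DemoStatus are hypothetical, not part of the header); it shows why the conversion is marked explicit: the object reads naturally in a condition but never converts silently to an integer elsewhere.

#include <iostream>
#include <string>
#include <utility>

// Stand-ins for WaitErrorCode / WaitStatus so the sketch compiles on its own.
enum class DemoErrorCode { Success = 0, Timeout = 1, Error = 2 };

class DemoStatus {
 public:
  DemoStatus() : code_(DemoErrorCode::Success) {}
  explicit DemoStatus(DemoErrorCode code, std::string description = "")
      : code_(code), description_(std::move(description)) {}

  // `explicit` keeps the status testable in an `if` without allowing
  // accidental conversions to int in other contexts.
  explicit operator bool() const noexcept { return code_ == DemoErrorCode::Success; }

  DemoErrorCode GetErrorCode() const { return code_; }
  std::string GetErrorDescription() const { return description_; }

 private:
  DemoErrorCode code_;
  std::string description_;
};

int main() {
  DemoStatus ok;
  DemoStatus timed_out(DemoErrorCode::Timeout, "timed out waiting for the NPU");

  for (const DemoStatus& status : {ok, timed_out}) {
    if (status) {
      std::cout << "inference completed\n";
    } else {
      std::cout << "inference failed: " << status.GetErrorDescription() << "\n";
    }
  }
  return 0;
}

Running it prints "inference completed" for the default-constructed status and the failure message for the timeout one.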
60 changes: 60 additions & 0 deletions tests/cpp/runtime/contrib/ethosn/inference_test.cc
@@ -0,0 +1,60 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file tests/cpp/runtime/contrib/ethosn/inference_test.cc
 * \brief Tests to check runtime components used during inference.
 */

#ifdef ETHOSN_HW

#include <gtest/gtest.h>

#include "../../../../../src/runtime/contrib/ethosn/ethosn_device.cc"

namespace tvm {
namespace runtime {
namespace ethosn {

TEST(WaitForInference, FailedResultRead) {
  const int inference_error = 3;
  const int timeout = 0;
  dl::Inference inference = dl::Inference(inference_error);
  WaitStatus result = WaitForInference(&inference, timeout);

  ASSERT_EQ(result.GetErrorCode(), WaitErrorCode::Error);
  ICHECK_EQ(result.GetErrorDescription(),
            "Failed to read inference result status (No such file or directory)");
}

TEST(WaitForInference, InferenceTimeout) {
  const int inference_scheduled = 0;
  const int timeout = 0;
  dl::Inference inference = dl::Inference(inference_scheduled);
  WaitStatus result = WaitForInference(&inference, timeout);

  ASSERT_EQ(result.GetErrorCode(), WaitErrorCode::Timeout);
  ICHECK_EQ(result.GetErrorDescription(), "Timed out while waiting for the inference to complete.");
}

} // namespace ethosn
} // namespace runtime
} // namespace tvm

#endif
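The two tests above cover the read-failure and timeout paths. A hypothetical companion test, sketched below and not part of this patch, could cover the success path by pre-loading a pipe with dl::InferenceResult::Completed and handing its read end to dl::Inference; this assumes, as the tests above appear to, that the dl::Inference constructor used here simply wraps the raw file descriptor that GetFileDescriptor() later returns.

// Hypothetical sketch only -- it would live alongside the tests above, inside the
// same tvm::runtime::ethosn namespace and ETHOSN_HW guard, and needs <unistd.h>.
TEST(WaitForInference, InferenceCompleted) {
  int pipe_fds[2];
  ASSERT_EQ(pipe(pipe_fds), 0);

  // Pre-load the pipe with a "Completed" result so poll() and read() both succeed.
  dl::InferenceResult npu_result = dl::InferenceResult::Completed;
  ASSERT_EQ(write(pipe_fds[1], &npu_result, sizeof(npu_result)),
            static_cast<ssize_t>(sizeof(npu_result)));

  dl::Inference inference = dl::Inference(pipe_fds[0]);
  WaitStatus result = WaitForInference(&inference, /*timeout=*/1);

  ASSERT_EQ(result.GetErrorCode(), WaitErrorCode::Success);

  close(pipe_fds[0]);
  close(pipe_fds[1]);
}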