Skip to content

Commit

Permalink
[OPPRO-113] Improve C++ exception handling (facebookincubator#177)
Browse files Browse the repository at this point in the history
  • Loading branch information
zhztheplayer authored Jun 9, 2022
1 parent f442785 commit 7a46ecc
Show file tree
Hide file tree
Showing 10 changed files with 265 additions and 187 deletions.
14 changes: 7 additions & 7 deletions cpp/gazelle-cpp/compute/substrait_arrow.cc
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,8 @@ std::shared_ptr<gluten::RecordBatchResultIterator> ArrowExecBackend::GetResultIt
std::vector<std::shared_ptr<gluten::RecordBatchResultIterator>> inputs) {
GLUTEN_ASSIGN_OR_THROW(auto decls, arrow::engine::ConvertPlan(plan_));
if (decls.size() != 1) {
throw gluten::JniPendingException("Expected 1 decl, but got " +
std::to_string(decls.size()));
throw gluten::GlutenException("Expected 1 decl, but got " +
std::to_string(decls.size()));
}
decl_ = std::make_shared<arrow::compute::Declaration>(std::move(decls[0]));

Expand All @@ -64,8 +64,8 @@ std::shared_ptr<gluten::RecordBatchResultIterator> ArrowExecBackend::GetResultIt
for (auto i = 0; i < inputs.size(); ++i) {
auto it = schema_map_.find(i);
if (it == schema_map_.end()) {
throw gluten::JniPendingException("Schema not found for input batch iterator " +
std::to_string(i));
throw gluten::GlutenException("Schema not found for input batch iterator " +
std::to_string(i));
}
auto batch_it = MakeMapIterator(
[](const std::shared_ptr<arrow::RecordBatch>& batch) {
Expand Down Expand Up @@ -179,8 +179,8 @@ void ArrowExecBackend::FieldPathToName(arrow::compute::Expression* expression,
*expr =
arrow::compute::field_ref(schema->field((field_path->indices())[0])->name());
} else {
throw gluten::JniPendingException("Field Ref is not field path: " +
field_ref->ToString());
throw gluten::GlutenException("Field Ref is not field path: " +
field_ref->ToString());
}
}
}
Expand All @@ -207,7 +207,7 @@ void ArrowExecBackend::ReplaceSourceDecls(
}

if (source_indexes.size() != source_decls.size()) {
throw gluten::JniPendingException(
throw gluten::GlutenException(
"Wrong number of source declarations. " + std::to_string(source_indexes.size()) +
" source(s) needed by source declarations, but got " +
std::to_string(source_decls.size()) + " from input batches.");
Expand Down
6 changes: 6 additions & 0 deletions cpp/gazelle-cpp/jni/jni_wrapper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

#include "compute/substrait_arrow.h"
#include "compute/substrait_utils.h"
#include "jni/jni_errors.h"

static jint JNI_VERSION = JNI_VERSION_1_8;

Expand All @@ -31,6 +32,7 @@ jint JNI_OnLoad(JavaVM* vm, void* reserved) {
if (vm->GetEnv(reinterpret_cast<void**>(&env), JNI_VERSION) != JNI_OK) {
return JNI_ERR;
}
gluten::GetJniErrorsState()->Initialize(env);
std::cout << "loaded gazelle_cpp" << std::endl;
return JNI_VERSION;
}
Expand All @@ -43,15 +45,19 @@ void JNI_OnUnload(JavaVM* vm, void* reserved) {
JNIEXPORT void JNICALL
Java_io_glutenproject_vectorized_ExpressionEvaluatorJniWrapper_nativeInitNative(
JNIEnv* env, jobject obj) {
JNI_METHOD_START
gazellecpp::compute::Initialize();
gluten::SetBackendFactory(
[] { return std::make_shared<gazellecpp::compute::ArrowExecBackend>(); });
JNI_METHOD_END()
}

JNIEXPORT jboolean JNICALL
Java_io_glutenproject_vectorized_ExpressionEvaluatorJniWrapper_nativeDoValidate(
JNIEnv* env, jobject obj, jbyteArray planArray) {
JNI_METHOD_START
return true;
JNI_METHOD_END(false)
}

#ifdef __cplusplus
Expand Down
28 changes: 14 additions & 14 deletions cpp/src/compute/substrait_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,20 +34,20 @@ namespace compute {
class SubstraitParser : public ExecBackendBase {
public:
SubstraitParser();
void ParseLiteral(const substrait::Expression::Literal& slit);
void ParseScalarFunction(const substrait::Expression::ScalarFunction& sfunc);
void ParseReferenceSegment(const substrait::Expression::ReferenceSegment& sref);
void ParseFieldReference(const substrait::Expression::FieldReference& sfield);
void ParseExpression(const substrait::Expression& sexpr);
void ParseType(const substrait::Type& stype);
void ParseNamedStruct(const substrait::NamedStruct& named_struct);
void ParseAggregateRel(const substrait::AggregateRel& sagg);
void ParseProjectRel(const substrait::ProjectRel& sproject);
void ParseFilterRel(const substrait::FilterRel& sfilter);
void ParseReadRel(const substrait::ReadRel& sread);
void ParseRelRoot(const substrait::RelRoot& sroot);
void ParseRel(const substrait::Rel& srel);
void ParsePlan(const substrait::Plan& splan);
void ParseLiteral(const ::substrait::Expression::Literal& slit);
void ParseScalarFunction(const ::substrait::Expression::ScalarFunction& sfunc);
void ParseReferenceSegment(const ::substrait::Expression::ReferenceSegment& sref);
void ParseFieldReference(const ::substrait::Expression::FieldReference& sfield);
void ParseExpression(const ::substrait::Expression& sexpr);
void ParseType(const ::substrait::Type& stype);
void ParseNamedStruct(const ::substrait::NamedStruct& named_struct);
void ParseAggregateRel(const ::substrait::AggregateRel& sagg);
void ParseProjectRel(const ::substrait::ProjectRel& sproject);
void ParseFilterRel(const ::substrait::FilterRel& sfilter);
void ParseReadRel(const ::substrait::ReadRel& sread);
void ParseRelRoot(const ::substrait::RelRoot& sroot);
void ParseRel(const ::substrait::Rel& srel);
void ParsePlan(const ::substrait::Plan& splan);
std::shared_ptr<RecordBatchResultIterator> GetResultIterator() override;
std::shared_ptr<RecordBatchResultIterator> GetResultIterator(
std::vector<std::shared_ptr<RecordBatchResultIterator>> inputs) override;
Expand Down
4 changes: 2 additions & 2 deletions cpp/src/jni/exec_backend.h
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ class RecordBatchResultIterator : public ResultIteratorBase<arrow::RecordBatch>

inline void CheckValid() {
if (iter_ == nullptr) {
throw JniPendingException(
throw GlutenException(
"RecordBatchResultIterator: the underlying iterator has expired.");
}
}
Expand Down Expand Up @@ -139,7 +139,7 @@ class ExecBackendBase : public std::enable_shared_from_this<ExecBackendBase> {
// TODO: remove arrow::Status
GLUTEN_THROW_NOT_OK(GetIterInputSchemaFromRel(sroot.input()));
} else {
throw JniPendingException("Expect Rel as input.");
throw GlutenException("Expect Rel as input.");
}
}
if (srel.has_rel()) {
Expand Down
29 changes: 6 additions & 23 deletions cpp/src/jni/jni_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,12 @@
* limitations under the License.
*/

#pragma once

#include <arrow/builder.h>
#include <arrow/io/memory.h>
#include <arrow/ipc/reader.h>
#include <arrow/ipc/writer.h>
#include <arrow/pretty_print.h>
#include <arrow/record_batch.h>
#include <arrow/status.h>
Expand All @@ -37,43 +42,21 @@
#include "compute/protobuf_utils.h"
#include "compute/substrait_utils.h"

static jclass io_exception_class;
static jclass runtime_exception_class;
static jclass unsupportedoperation_exception_class;
static jclass illegal_access_exception_class;
static jclass illegal_argument_exception_class;

jclass CreateGlobalClassReference(JNIEnv* env, const char* class_name) {
jclass local_class = env->FindClass(class_name);
jclass global_class = (jclass)env->NewGlobalRef(local_class);
env->DeleteLocalRef(local_class);
if (global_class == nullptr) {
std::string error_message =
"Unable to createGlobalClassReference for" + std::string(class_name);
env->ThrowNew(illegal_access_exception_class, error_message.c_str());
}
return global_class;
}

jmethodID GetMethodID(JNIEnv* env, jclass this_class, const char* name, const char* sig) {
jmethodID ret = env->GetMethodID(this_class, name, sig);
if (ret == nullptr) {
std::string error_message = "Unable to find method " + std::string(name) +
" within signature" + std::string(sig);
env->ThrowNew(illegal_access_exception_class, error_message.c_str());
}

return ret;
}

jmethodID GetStaticMethodID(JNIEnv* env, jclass this_class, const char* name,
const char* sig) {
jmethodID ret = env->GetStaticMethodID(this_class, name, sig);
if (ret == nullptr) {
std::string error_message = "Unable to find static method " + std::string(name) +
" within signature" + std::string(sig);
env->ThrowNew(illegal_access_exception_class, error_message.c_str());
}
return ret;
}

Expand Down Expand Up @@ -316,7 +299,7 @@ jbyteArray ToSchemaByteArray(JNIEnv* env, std::shared_ptr<arrow::Schema> schema)
if (!status.ok()) {
std::string error_message =
"Unable to convert schema to byte array, err is " + status.message();
env->ThrowNew(io_exception_class, error_message.c_str());
throw gluten::GlutenException(error_message);
}
auto buffer = *std::move(maybe_buffer);
jbyteArray out = env->NewByteArray(buffer->size());
Expand Down
126 changes: 126 additions & 0 deletions cpp/src/jni/jni_errors.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once

#include <stdexcept>

#include "jni_common.h"
#include "utils/exception.h"

#define JNI_METHOD_START try {
// macro ended

#define JNI_METHOD_END(fallback_expr) \
} \
catch (std::exception & e) { \
env->ThrowNew(gluten::GetJniErrorsState()->RuntimeExceptionClass(), e.what()); \
return fallback_expr; \
}
// macro ended

namespace gluten {

class JniPendingException : public std::runtime_error {
public:
explicit JniPendingException(const std::string& arg) : runtime_error(arg) {}
};

void ThrowPendingException(const std::string& message) {
throw JniPendingException(message);
}

template <typename T>
T JniGetOrThrow(arrow::Result<T> result) {
if (!result.status().ok()) {
ThrowPendingException(result.status().message());
}
return std::move(result).ValueOrDie();
}

template <typename T>
T JniGetOrThrow(arrow::Result<T> result, const std::string& message) {
if (!result.status().ok()) {
ThrowPendingException(message + " - " + result.status().message());
}
return std::move(result).ValueOrDie();
}

void JniAssertOkOrThrow(arrow::Status status) {
if (!status.ok()) {
ThrowPendingException(status.message());
}
}

void JniAssertOkOrThrow(arrow::Status status, const std::string& message) {
if (!status.ok()) {
ThrowPendingException(message + " - " + status.message());
}
}

void JniThrow(const std::string& message) { ThrowPendingException(message); }

static struct JniErrorsGlobalState {
public:
virtual ~JniErrorsGlobalState() = default;

void Initialize(JNIEnv* env) {
std::lock_guard<std::mutex> lock_guard(mtx_);
io_exception_class_ = CreateGlobalClassReference(env, "Ljava/io/IOException;");
runtime_exception_class_ =
CreateGlobalClassReference(env, "Ljava/lang/RuntimeException;");
unsupportedoperation_exception_class_ =
CreateGlobalClassReference(env, "Ljava/lang/UnsupportedOperationException;");
illegal_access_exception_class_ =
CreateGlobalClassReference(env, "Ljava/lang/IllegalAccessException;");
illegal_argument_exception_class_ =
CreateGlobalClassReference(env, "Ljava/lang/IllegalArgumentException;");
}

jclass RuntimeExceptionClass() {
std::lock_guard<std::mutex> lock_guard(mtx_);
if (runtime_exception_class_ == nullptr) {
throw gluten::GlutenException(
"Fatal: JniGlobalState::Initialize(...) was not called before using the "
"utility");
}
return runtime_exception_class_;
}

jclass IllegalAccessExceptionClass() {
std::lock_guard<std::mutex> lock_guard(mtx_);
if (illegal_access_exception_class_ == nullptr) {
throw gluten::GlutenException(
"Fatal: JniGlobalState::Initialize(...) was not called before using the "
"utility");
}
return illegal_access_exception_class_;
}

private:
jclass io_exception_class_ = nullptr;
jclass runtime_exception_class_ = nullptr;
jclass unsupportedoperation_exception_class_ = nullptr;
jclass illegal_access_exception_class_ = nullptr;
jclass illegal_argument_exception_class_ = nullptr;
std::mutex mtx_;

} jni_errors_state;

static JniErrorsGlobalState* GetJniErrorsState() { return &jni_errors_state; }

} // namespace gluten
Loading

0 comments on commit 7a46ecc

Please sign in to comment.