Skip to content
This repository has been archived by the owner on Sep 18, 2023. It is now read-only.

Commit

Permalink
[NSE-469] Lazy Read: Iterator objects are not correctly released (#470)
Browse files Browse the repository at this point in the history
* [NSE-469] Lazy Read: Iterator objects are not correctly released

* style

* debugging

* debugging

* fixup
  • Loading branch information
zhztheplayer authored Aug 17, 2021
1 parent 5ab09b9 commit 18f2a19
Show file tree
Hide file tree
Showing 5 changed files with 148 additions and 62 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.intel.oap.execution;

import com.intel.oap.expression.ConverterUtils;
import org.apache.arrow.dataset.jni.NativeSerializedRecordBatchIterator;
import org.apache.arrow.dataset.jni.UnsafeRecordBatchSerializer;
import org.apache.arrow.vector.ipc.message.ArrowRecordBatch;
import org.apache.spark.sql.vectorized.ColumnarBatch;

import java.util.Iterator;

public class ColumnarNativeIterator implements NativeSerializedRecordBatchIterator {
private final Iterator<ColumnarBatch> delegated;
private ColumnarBatch nextBatch = null;

public ColumnarNativeIterator(Iterator<ColumnarBatch> delegated) {
this.delegated = delegated;
}

@Override
public boolean hasNext() {
while (delegated.hasNext()) {
nextBatch = delegated.next();
if (nextBatch.numRows() > 0) {
return true;
}
}
return false;
}

@Override
public byte[] next() {
ColumnarBatch dep_cb = nextBatch;
if (dep_cb.numRows() > 0) {
ArrowRecordBatch dep_rb = ConverterUtils.createArrowRecordBatch(dep_cb);
return serialize(dep_rb);
} else {
throw new IllegalStateException();
}
}

private byte[] serialize(ArrowRecordBatch batch) {
return UnsafeRecordBatchSerializer.serializeUnsafe(batch);
}

@Override
public void close() throws Exception {

}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,20 +18,15 @@
package com.intel.oap.vectorized;

import com.intel.oap.ColumnarPluginConfig;
import com.intel.oap.execution.ColumnarNativeIterator;
import com.intel.oap.spark.sql.execution.datasources.v2.arrow.Spiller;
import org.apache.arrow.dataset.jni.NativeSerializedRecordBatchIterator;
import org.apache.arrow.memory.ArrowBuf;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.nio.channels.Channels;
import java.util.List;

import org.apache.arrow.dataset.jni.NativeMemoryPool;
import org.apache.arrow.dataset.jni.UnsafeRecordBatchSerializer;
import org.apache.arrow.gandiva.evaluator.SelectionVectorInt16;
import org.apache.arrow.gandiva.exceptions.GandivaException;
import org.apache.arrow.gandiva.expression.ExpressionTree;
import org.apache.arrow.gandiva.ipc.GandivaTypes;
import org.apache.arrow.memory.ArrowBuf;
import org.apache.arrow.vector.ipc.WriteChannel;
import org.apache.arrow.vector.ipc.message.ArrowBuffer;
import org.apache.arrow.vector.ipc.message.ArrowRecordBatch;
Expand All @@ -40,6 +35,11 @@
import org.apache.spark.memory.MemoryConsumer;
import org.apache.spark.sql.execution.datasources.v2.arrow.SparkMemoryUtils;

import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.nio.channels.Channels;
import java.util.List;

public class ExpressionEvaluator implements AutoCloseable {
private long nativeHandler = 0;
private ExpressionEvaluatorJniWrapper jniWrapper;
Expand Down Expand Up @@ -144,7 +144,7 @@ public ArrowRecordBatch[] evaluate(ArrowRecordBatch recordBatch) throws RuntimeE
return evaluate(recordBatch, null);
}

public void evaluate(NativeSerializedRecordBatchIterator batchItr)
public void evaluate(ColumnarNativeIterator batchItr)
throws RuntimeException, IOException {
jniWrapper.nativeEvaluateWithIterator(nativeHandler,
batchItr);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

package com.intel.oap.vectorized;

import org.apache.arrow.dataset.jni.NativeSerializedRecordBatchIterator;
import com.intel.oap.execution.ColumnarNativeIterator;
import org.apache.arrow.vector.ipc.message.ArrowRecordBatch;
import org.apache.spark.memory.MemoryConsumer;

Expand Down Expand Up @@ -156,7 +156,7 @@ native ArrowRecordBatchBuilder[] nativeEvaluate(long nativeHandler, int numRows,
* @param nativeHandler a iterator instance carrying input record batches
*/
native void nativeEvaluateWithIterator(long nativeHandler,
NativeSerializedRecordBatchIterator batchItr) throws RuntimeException;
ColumnarNativeIterator batchItr) throws RuntimeException;

/**
* Get native kernel signature by the nativeHandler.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,6 @@ import org.apache.spark.util.{ExecutorManager, UserAddedJarUtils}
import scala.collection.JavaConverters._
import scala.collection.mutable.ListBuffer

import org.apache.arrow.dataset.jni.NativeSerializedRecordBatchIterator
import org.apache.arrow.dataset.jni.UnsafeRecordBatchSerializer
import org.apache.arrow.vector.ipc.message.ArrowRecordBatch

case class ColumnarCodegenContext(inputSchema: Schema, outputSchema: Schema, root: TreeNode) {}

trait ColumnarCodegenSupport extends SparkPlan {
Expand Down Expand Up @@ -419,31 +415,7 @@ case class ColumnarWholeStageCodegenExec(child: SparkPlan)(val codegenStageId: I

if (enableColumnarSortMergeJoinLazyRead) {
// Used as ABI to prevent from serializing buffer data
val serializedItr: NativeSerializedRecordBatchIterator = {
new NativeSerializedRecordBatchIterator {

override def hasNext: Boolean = {
depIter.hasNext
}

override def next(): Array[Byte] = {
val dep_cb = depIter.next()
if (dep_cb.numRows > 0) {
val dep_rb = ConverterUtils.createArrowRecordBatch(dep_cb)
serialize(dep_rb)
} else {
throw new IllegalStateException()
}
}

private def serialize(batch: ArrowRecordBatch) = {
UnsafeRecordBatchSerializer.serializeUnsafe(batch)
}

override def close(): Unit = {
}
}
}
val serializedItr = new ColumnarNativeIterator(depIter.asJava)
cachedRelationKernel.evaluate(serializedItr)
} else {
while (depIter.hasNext) {
Expand Down
94 changes: 71 additions & 23 deletions native-sql-engine/cpp/src/jni/jni_wrapper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -63,15 +63,15 @@ static jmethodID serializable_obj_builder_constructor;
static jclass split_result_class;
static jmethodID split_result_constructor;

jclass serialized_record_batch_iterator_class;
static jclass serialized_record_batch_iterator_class;
static jclass metrics_builder_class;
static jmethodID metrics_builder_constructor;

static jclass unsafe_row_class;
static jmethodID unsafe_row_class_constructor;
static jmethodID unsafe_row_class_point_to;
jmethodID serialized_record_batch_iterator_hasNext;
jmethodID serialized_record_batch_iterator_next;
static jmethodID serialized_record_batch_iterator_hasNext;
static jmethodID serialized_record_batch_iterator_next;

using arrow::jni::ConcurrentMap;
static ConcurrentMap<std::shared_ptr<arrow::Buffer>> buffer_holder_;
Expand Down Expand Up @@ -124,30 +124,80 @@ arrow::Result<std::shared_ptr<arrow::RecordBatch>> FromBytes(
return batch;
}

class JavaRecordBatchIterator {
public:
explicit JavaRecordBatchIterator(JavaVM* vm,
jobject java_serialized_record_batch_iterator,
std::shared_ptr<arrow::Schema> schema)
: vm_(vm),
java_serialized_record_batch_iterator_(java_serialized_record_batch_iterator),
schema_(std::move(schema)) {}

// singleton, avoid stack instantiation
JavaRecordBatchIterator(const JavaRecordBatchIterator& itr) = delete;
JavaRecordBatchIterator(JavaRecordBatchIterator&& itr) = delete;

virtual ~JavaRecordBatchIterator() {
JNIEnv* env;
if (vm_->GetEnv(reinterpret_cast<void**>(&env), JNI_VERSION) == JNI_OK) {
#ifdef DEBUG
std::cout << "DELETING GLOBAL ITERATOR REF "
<< reinterpret_cast<long>(java_serialized_record_batch_iterator_) << "..."
<< std::endl;
#endif
env->DeleteGlobalRef(java_serialized_record_batch_iterator_);
}
}

arrow::Result<std::shared_ptr<arrow::RecordBatch>> Next() {
JNIEnv* env;
if (vm_->GetEnv(reinterpret_cast<void**>(&env), JNI_VERSION) != JNI_OK) {
return arrow::Status::Invalid("JNIEnv was not attached to current thread");
}
#ifdef DEBUG
std::cout << "PICKING ITERATOR REF "
<< reinterpret_cast<long>(java_serialized_record_batch_iterator_) << "..."
<< std::endl;
#endif
if (!env->CallBooleanMethod(java_serialized_record_batch_iterator_,
serialized_record_batch_iterator_hasNext)) {
return nullptr; // stream ended
}
auto bytes = (jbyteArray)env->CallObjectMethod(java_serialized_record_batch_iterator_,
serialized_record_batch_iterator_next);
RETURN_NOT_OK(arrow::jniutil::CheckException(env));
ARROW_ASSIGN_OR_RAISE(auto batch, FromBytes(env, schema_, bytes));
return batch;
}

private:
JavaVM* vm_;
jobject java_serialized_record_batch_iterator_;
std::shared_ptr<arrow::Schema> schema_;
};

class JavaRecordBatchIteratorWrapper {
public:
explicit JavaRecordBatchIteratorWrapper(
std::shared_ptr<JavaRecordBatchIterator> delegated)
: delegated_(std::move(delegated)) {}

arrow::Result<std::shared_ptr<arrow::RecordBatch>> Next() { return delegated_->Next(); }

private:
std::shared_ptr<JavaRecordBatchIterator> delegated_;
};

// See Java class
// org/apache/arrow/dataset/jni/NativeSerializedRecordBatchIterator
//
arrow::Result<arrow::RecordBatchIterator> MakeJavaRecordBatchIterator(
JavaVM* vm, jobject java_serialized_record_batch_iterator,
std::shared_ptr<arrow::Schema> schema) {
std::shared_ptr<arrow::Schema> schema_moved = std::move(schema);
arrow::RecordBatchIterator itr = arrow::MakeFunctionIterator(
[vm, java_serialized_record_batch_iterator,
schema_moved]() -> arrow::Result<std::shared_ptr<arrow::RecordBatch>> {
JNIEnv* env;
if (vm->GetEnv(reinterpret_cast<void**>(&env), JNI_VERSION) != JNI_OK) {
return arrow::Status::Invalid("JNIEnv was not attached to current thread");
}
if (!env->CallBooleanMethod(java_serialized_record_batch_iterator,
serialized_record_batch_iterator_hasNext)) {
return nullptr; // stream ended
}
auto bytes = (jbyteArray)env->CallObjectMethod(
java_serialized_record_batch_iterator, serialized_record_batch_iterator_next);
RETURN_NOT_OK(arrow::jniutil::CheckException(env));
ARROW_ASSIGN_OR_RAISE(auto batch, FromBytes(env, schema_moved, bytes));
return batch;
});
arrow::RecordBatchIterator itr = arrow::Iterator<std::shared_ptr<arrow::RecordBatch>>(
JavaRecordBatchIteratorWrapper(std::make_shared<JavaRecordBatchIterator>(
vm, java_serialized_record_batch_iterator, schema_moved)));
return itr;
}

Expand Down Expand Up @@ -253,9 +303,7 @@ jint JNI_OnLoad(JavaVM* vm, void* reserved) {
unsafe_row_class_constructor = GetMethodID(env, unsafe_row_class, "<init>", "(I)V");
unsafe_row_class_point_to = GetMethodID(env, unsafe_row_class, "pointTo", "([BI)V");
serialized_record_batch_iterator_class =
CreateGlobalClassReference(env,
"Lorg/apache/arrow/"
"dataset/jni/NativeSerializedRecordBatchIterator;");
CreateGlobalClassReference(env, "Lcom/intel/oap/execution/ColumnarNativeIterator;");
serialized_record_batch_iterator_hasNext =
GetMethodID(env, serialized_record_batch_iterator_class, "hasNext", "()Z");
serialized_record_batch_iterator_next =
Expand Down

0 comments on commit 18f2a19

Please sign in to comment.