[GLUTEN-4170][VL] Decouple partitions from plan to avoid driver stalled (apache#4177)
Yohahaha authored Jan 18, 2024
1 parent dacaa01 commit 2fc4503
Showing 33 changed files with 353 additions and 402 deletions.
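Note: the gist of the change is that the driver previously spliced every partition's split infos into the Substrait plan and re-serialized the whole plan once per partition, which could stall the driver when a complex plan meets many partitions. After this commit the Velox backend serializes the plan once per stage and ships each partition's split infos as separate byte arrays. A minimal, self-contained sketch of the before/after schemes (hypothetical stand-in types, not Gluten's real API):

    object DecoupleSketch {
      final case class SplitInfo(bytes: Array[Byte])
      final case class Partition(index: Int, planBytes: Array[Byte], splitBytes: Array[Array[Byte]])

      // Before: the (possibly large) plan is re-serialized for every partition.
      def partitionsBefore(
          serializePlanWith: SplitInfo => Array[Byte],
          splits: Seq[SplitInfo]): Seq[Partition] =
        splits.zipWithIndex.map {
          case (s, i) => Partition(i, serializePlanWith(s), Array.empty) // O(partitions) plan serializations
        }

      // After: the plan is serialized once and shared; each partition carries only its own split bytes.
      def partitionsAfter(planBytes: Array[Byte], splits: Seq[SplitInfo]): Seq[Partition] =
        splits.zipWithIndex.map { case (s, i) => Partition(i, planBytes, Array(s.bytes)) }
    }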
CHIteratorApi.scala
@@ -47,11 +47,6 @@ import scala.collection.JavaConverters._

class CHIteratorApi extends IteratorApi with Logging with LogLevelUtil {

/**
* Generate native row partition.
*
* @return
*/
override def genSplitInfo(
partition: InputPartition,
partitionSchema: StructType,
@@ -95,6 +90,26 @@ class CHIteratorApi extends IteratorApi with Logging with LogLevelUtil {
}
}

/**
* Generate a native partition for each group of split infos.
*/
override def genPartitions(
wsCtx: WholeStageTransformContext,
splitInfos: Seq[Seq[SplitInfo]]): Seq[BaseGlutenPartition] = {
splitInfos.zipWithIndex.map {
case (splitInfos, index) =>
wsCtx.substraitContext.initSplitInfosIndex(0)
wsCtx.substraitContext.setSplitInfos(splitInfos)
val substraitPlan = wsCtx.root.toProtobuf
GlutenPartition(
index,
substraitPlan.toByteArray,
locations = splitInfos.flatMap(_.preferredLocations().asScala).toArray)
}
}
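Note that the ClickHouse backend above still splices each partition's split infos into the plan and serializes the full plan once per partition; only the Velox backend (IteratorApiImpl below) switches to serializing the plan once and shipping split infos separately.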

/**
* Generate Iterator[ColumnarBatch] for first stage.
*
@@ -252,7 +267,7 @@ class CHIteratorApi extends IteratorApi with Logging with LogLevelUtil {
GlutenPartition(
index,
substraitPlan.toByteArray,
splitInfo.preferredLocations().asScala.toArray)
locations = splitInfo.preferredLocations().asScala.toArray)
}
}(t => logInfo(s"Generating the Substrait plan took: $t ms."))
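The locations = ... named-argument changes here and in MixedAffinitySuite below follow from GlutenPartition gaining a new split-info parameter ahead of locations (see the Velox IteratorApiImpl diff), so call sites that skip it must now pass locations by name.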

MixedAffinitySuite.scala
@@ -49,7 +49,7 @@ class MixedAffinitySuite extends QueryTest with SharedSparkSession {
}
val partition = GlutenMergeTreePartition(0, "", "", "", "fakePath", 0, 0)
val locations = affinity.getNativeMergeTreePartitionLocations(partition)
val nativePartition = GlutenPartition(0, PlanBuilder.EMPTY_PLAN, locations)
val nativePartition = GlutenPartition(0, PlanBuilder.EMPTY_PLAN, locations = locations)
assertResult(Set("forced_host_host-0")) {
nativePartition.preferredLocations().toSet
}
IteratorApiImpl.scala
@@ -21,7 +21,7 @@ import io.glutenproject.backendsapi.IteratorApi
import io.glutenproject.execution._
import io.glutenproject.metrics.IMetrics
import io.glutenproject.substrait.plan.PlanNode
import io.glutenproject.substrait.rel.{LocalFilesBuilder, SplitInfo}
import io.glutenproject.substrait.rel.{LocalFilesBuilder, LocalFilesNode, SplitInfo}
import io.glutenproject.substrait.rel.LocalFilesNode.ReadFileFormat
import io.glutenproject.utils._
import io.glutenproject.vectorized._
@@ -46,16 +46,12 @@ import java.lang.{Long => JLong}
import java.nio.charset.StandardCharsets
import java.time.ZoneOffset
import java.util.{ArrayList => JArrayList, HashMap => JHashMap, Map => JMap}
import java.util.concurrent.TimeUnit

import scala.collection.JavaConverters._

class IteratorApiImpl extends IteratorApi with Logging {

/**
* Generate native row partition.
*
* @return
*/
override def genSplitInfo(
partition: InputPartition,
partitionSchema: StructType,
@@ -79,6 +75,24 @@ class IteratorApiImpl extends IteratorApi with Logging {
}
}

/** Generate native partitions, serializing the Substrait plan only once. */
override def genPartitions(
wsCtx: WholeStageTransformContext,
splitInfos: Seq[Seq[SplitInfo]]): Seq[BaseGlutenPartition] = {
// Serialize the plan only once; this saves a lot of driver time when the plan is complex.
val planByteArray = wsCtx.root.toProtobuf.toByteArray

splitInfos.zipWithIndex.map {
case (splitInfos, index) =>
GlutenPartition(
index,
planByteArray,
splitInfos.map(_.asInstanceOf[LocalFilesNode].toProtobuf.toByteArray).toArray,
splitInfos.flatMap(_.preferredLocations().asScala).toArray
)
}
}
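The call sites in this commit are consistent with GlutenPartition now carrying split-info bytes next to the plan bytes; a hypothetical reconstruction of the carrier (simplified, not Gluten's exact definition):

    final case class GlutenPartitionSketch(
        index: Int,
        plan: Array[Byte],
        splitInfosByteArray: Array[Array[Byte]] = Array.empty,
        locations: Array[String] = Array.empty)

The defaults are what let the three-argument callers (the ClickHouse backend and the affinity test) keep compiling once they name the locations argument.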

private def constructSplitInfo(schema: StructType, files: Array[PartitionedFile]) = {
val paths = new JArrayList[String]()
val starts = new JArrayList[JLong]
@@ -124,25 +138,34 @@
transKernel.injectWriteFilesTempPath(path)
}

/**
* Generate Iterator[ColumnarBatch] for first stage.
*
* @return
*/
/** Generate Iterator[ColumnarBatch] for first stage. */
override def genFirstStageIterator(
inputPartition: BaseGlutenPartition,
context: TaskContext,
pipelineTime: SQLMetric,
updateInputMetrics: (InputMetricsWrapper) => Unit,
updateNativeMetrics: IMetrics => Unit,
inputIterators: Seq[Iterator[ColumnarBatch]] = Seq()): Iterator[ColumnarBatch] = {
assert(
inputPartition.isInstanceOf[GlutenPartition],
"Velox backend only accept GlutenPartition.")

val beforeBuild = System.nanoTime()
val columnarNativeIterators =
new JArrayList[GeneralInIterator](inputIterators.map {
iter => new ColumnarBatchInIterator(iter.asJava)
}.asJava)
val transKernel = NativePlanEvaluator.create()

val splitInfoByteArray = inputPartition
.asInstanceOf[GlutenPartition]
.splitInfosByteArray
val resIter: GeneralOutIterator =
transKernel.createKernelWithBatchIterator(inputPartition.plan, columnarNativeIterators)
transKernel.createKernelWithBatchIterator(
inputPartition.plan,
splitInfoByteArray,
columnarNativeIterators)
pipelineTime += TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - beforeBuild)

Iterators
.wrap(resIter.asScala)
@@ -160,11 +183,7 @@

// scalastyle:off argcount

/**
* Generate Iterator[ColumnarBatch] for final stage.
*
* @return
*/
/** Generate Iterator[ColumnarBatch] for final stage. */
override def genFinalStageIterator(
context: TaskContext,
inputIterators: Seq[Iterator[ColumnarBatch]],
@@ -186,7 +205,10 @@
val nativeResultIterator =
transKernel.createKernelWithBatchIterator(
rootNode.toProtobuf.toByteArray,
columnarNativeIterator)
// The final stage contains no scan split, so pass an empty split info array to the native side.
new Array[Array[Byte]](0),
columnarNativeIterator
)
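Passing a zero-length array keeps the native kernel's signature uniform across stages: first-stage tasks supply their scan splits, while shuffle-fed final stages supply none.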

Iterators
.wrap(nativeResultIterator.asScala)
3 changes: 3 additions & 0 deletions cpp/core/compute/Runtime.h
@@ -72,6 +72,8 @@ class Runtime : public std::enable_shared_from_this<Runtime> {

virtual void injectWriteFilesTempPath(const std::string& path) = 0;

virtual void parseSplitInfo(const uint8_t* data, int32_t size) = 0;

// Just for benchmark
::substrait::Plan& getPlan() {
return substraitPlan_;
@@ -140,6 +142,7 @@
protected:
std::unique_ptr<ObjectStore> objStore_ = ObjectStore::create();
::substrait::Plan substraitPlan_;
std::vector<::substrait::ReadRel_LocalFiles> localFiles_;
std::optional<std::string> writeFilesTempPath_;
SparkTaskInfo taskInfo_;
// Session conf map
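The new parseSplitInfo hook lets the runtime accumulate one deserialized ReadRel.LocalFiles per scan into localFiles_, kept alongside the stage-wide plan in substraitPlan_; the JNI wrapper below invokes it once per split-info byte array.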
10 changes: 9 additions & 1 deletion cpp/core/jni/JniWrapper.cc
@@ -368,6 +368,7 @@ Java_io_glutenproject_vectorized_PlanEvaluatorJniWrapper_nativeCreateKernelWithI
jobject wrapper,
jlong memoryManagerHandle,
jbyteArray planArr,
jobjectArray splitInfosArr,
jobjectArray iterArr,
jint stageId,
jint partitionId,
@@ -381,10 +382,17 @@

auto spillDirStr = jStringToCString(env, spillDir);

for (jsize i = 0, splitInfoArraySize = env->GetArrayLength(splitInfosArr); i < splitInfoArraySize; i++) {
jbyteArray splitInfoArray = static_cast<jbyteArray>(env->GetObjectArrayElement(splitInfosArr, i));
jsize splitInfoSize = env->GetArrayLength(splitInfoArray);
auto splitInfoData = reinterpret_cast<const uint8_t*>(env->GetByteArrayElements(splitInfoArray, nullptr));
ctx->parseSplitInfo(splitInfoData, splitInfoSize);
}
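As excerpted, the buffers obtained from GetByteArrayElements for each split info are handed to parseSplitInfo without a visible matching ReleaseByteArrayElements; JNI normally expects each Get to be paired with a Release once the native side has copied the data, so either the release happens outside this excerpt or the copies are deliberately retained for the task's lifetime (the same pattern applies to planArr below).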

auto planData = reinterpret_cast<const uint8_t*>(env->GetByteArrayElements(planArr, nullptr));
auto planSize = env->GetArrayLength(planArr);

ctx->parsePlan(planData, planSize, {stageId, partitionId, taskId});

auto& conf = ctx->getConfMap();

// Handle the Java iters
48 changes: 30 additions & 18 deletions cpp/velox/benchmarks/GenericBenchmark.cc
@@ -109,7 +109,8 @@ void populateWriterMetrics(
} // namespace

auto BM_Generic = [](::benchmark::State& state,
const std::string& substraitJsonFile,
const std::string& planFile,
const std::string& splitFile,
const std::vector<std::string>& inputFiles,
const std::unordered_map<std::string, std::string>& conf,
FileReaderType readerType) {
@@ -119,11 +120,15 @@ auto BM_Generic = [](::benchmark::State& state,
} else {
setCpu(state.thread_index());
}
bool emptySplit = splitFile.empty();
memory::MemoryManager::testingSetInstance({});
auto memoryManager = getDefaultMemoryManager();
auto runtime = Runtime::create(kVeloxRuntimeKind, conf);
const auto& filePath = substraitJsonFile;
auto plan = getPlanFromFile(filePath);
auto plan = getPlanFromFile("Plan", planFile);
std::string split;
if (!emptySplit) {
split = getPlanFromFile("ReadRel.LocalFiles", splitFile);
}
auto startTime = std::chrono::steady_clock::now();
int64_t collectBatchTime = 0;
WriterMetrics writerMetrics{};
@@ -145,6 +150,9 @@
}

runtime->parsePlan(reinterpret_cast<uint8_t*>(plan.data()), plan.size(), {});
if (!emptySplit) {
runtime->parseSplitInfo(reinterpret_cast<uint8_t*>(split.data()), split.size());
}
auto resultIter =
runtime->createResultIterator(memoryManager.get(), "/tmp/test-spill", std::move(inputIters), conf);
auto veloxPlan = dynamic_cast<gluten::VeloxRuntime*>(runtime)->getVeloxPlan();
@@ -231,6 +239,7 @@ int main(int argc, char** argv) {
gflags::ParseCommandLineFlags(&argc, &argv, true);

std::string substraitJsonFile;
std::string splitFile;
std::vector<std::string> inputFiles;
std::unordered_map<std::string, std::string> conf;

@@ -242,18 +251,20 @@
if (argc < 2) {
LOG(INFO)
<< "No input args. Usage: " << std::endl
<< "./generic_benchmark /absolute-path/to/substrait_json_file /absolute-path/to/data_file_1 /absolute-path/to/data_file_2 ...";
<< "./generic_benchmark /absolute-path/to/substrait_json_file /absolute-path/to/split_json_file(optional)"
<< " /absolute-path/to/data_file_1 /absolute-path/to/data_file_2 ...";
LOG(INFO) << "Running example...";
inputFiles.resize(2);
substraitJsonFile = getGeneratedFilePath("example.json");
inputFiles[0] = getGeneratedFilePath("example_orders");
inputFiles[1] = getGeneratedFilePath("example_lineitem");
} else {
substraitJsonFile = argv[1];
splitFile = argv[2];
abortIfFileNotExists(substraitJsonFile);
LOG(INFO) << "Using substrait json file: " << std::endl << substraitJsonFile;
LOG(INFO) << "Using " << argc - 2 << " input data file(s): ";
for (auto i = 2; i < argc; ++i) {
for (auto i = 3; i < argc; ++i) {
inputFiles.emplace_back(argv[i]);
abortIfFileNotExists(inputFiles.back());
LOG(INFO) << inputFiles.back();
@@ -265,19 +276,20 @@
std::exit(EXIT_FAILURE);
}
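With the new positional argument, the documented invocation becomes (hypothetical paths):

    ./generic_benchmark /path/to/substrait_plan.json /path/to/split.json /path/to/data_file_1 /path/to/data_file_2

When a split file is supplied, BM_Generic parses it as a ReadRel.LocalFiles JSON message and feeds it to the runtime through parseSplitInfo, mirroring what the JNI path does in production.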

#define GENERIC_BENCHMARK(NAME, READER_TYPE) \
do { \
auto* bm = ::benchmark::RegisterBenchmark(NAME, BM_Generic, substraitJsonFile, inputFiles, conf, READER_TYPE) \
->MeasureProcessCPUTime() \
->UseRealTime(); \
if (FLAGS_threads > 0) { \
bm->Threads(FLAGS_threads); \
} else { \
bm->ThreadRange(1, std::thread::hardware_concurrency()); \
} \
if (FLAGS_iterations > 0) { \
bm->Iterations(FLAGS_iterations); \
} \
#define GENERIC_BENCHMARK(NAME, READER_TYPE) \
do { \
auto* bm = \
::benchmark::RegisterBenchmark(NAME, BM_Generic, substraitJsonFile, splitFile, inputFiles, conf, READER_TYPE) \
->MeasureProcessCPUTime() \
->UseRealTime(); \
if (FLAGS_threads > 0) { \
bm->Threads(FLAGS_threads); \
} else { \
bm->ThreadRange(1, std::thread::hardware_concurrency()); \
} \
if (FLAGS_iterations > 0) { \
bm->Iterations(FLAGS_iterations); \
} \
} while (0)

DLOG(INFO) << "FLAGS_threads:" << FLAGS_threads;
4 changes: 2 additions & 2 deletions cpp/velox/benchmarks/common/BenchmarkUtils.cc
@@ -47,14 +47,14 @@ void initVeloxBackend() {
initVeloxBackend(bmConfMap);
}

std::string getPlanFromFile(const std::string& filePath) {
std::string getPlanFromFile(const std::string& type, const std::string& filePath) {
// Read json file and resume the binary data.
std::ifstream msgJson(filePath);
std::stringstream buffer;
buffer << msgJson.rdbuf();
std::string msgData = buffer.str();

return gluten::substraitFromJsonToPb("Plan", msgData);
return gluten::substraitFromJsonToPb(type, msgData);
}
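The same helper now deserializes both message types: "Plan" for query plans and "ReadRel.LocalFiles" for scan splits, matching the two getPlanFromFile calls in BM_Generic above.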

velox::dwio::common::FileFormat getFileFormat(const std::string& fileFormat) {
2 changes: 1 addition & 1 deletion cpp/velox/benchmarks/common/BenchmarkUtils.h
@@ -72,7 +72,7 @@ inline std::string getPlanFromFile(const std::string& type, const std::string& filePath);
}

/// Read binary data from a json file.
std::string getPlanFromFile(const std::string& filePath);
std::string getPlanFromFile(const std::string& type, const std::string& filePath);

/// Get the file paths, starts, lengths from a directory.
/// Use fileFormat to specify the format to read, eg., orc, parquet.