
[GLUTEN-4170][VL] Decouple partitions from plan to avoid driver stall #4177

Merged · 12 commits · Jan 18, 2024
@@ -47,11 +47,6 @@ import scala.collection.JavaConverters._

class CHIteratorApi extends IteratorApi with Logging with LogLevelUtil {

/**
* Generate native row partition.
*
* @return
*/
override def genSplitInfo(
partition: InputPartition,
partitionSchema: StructType,
@@ -95,6 +90,26 @@ class CHIteratorApi extends IteratorApi with Logging with LogLevelUtil {
}
}

/**
 * Generate Gluten partitions from the grouped split infos, embedding the split
 * infos into the serialized Substrait plan of each partition.
 */
override def genPartitions(
wsCtx: WholeStageTransformContext,
splitInfos: Seq[Seq[SplitInfo]]): Seq[BaseGlutenPartition] = {
splitInfos.zipWithIndex.map {
case (splitInfos, index) =>
wsCtx.substraitContext.initSplitInfosIndex(0)
wsCtx.substraitContext.setSplitInfos(splitInfos)
val substraitPlan = wsCtx.root.toProtobuf
Review comment (Contributor):
I'm assuming the proposed optimization can also be applied to the CH backend. If so, it will need some follow-up work from a CH engineer. @baibaichen

GlutenPartition(
index,
substraitPlan.toByteArray,
locations = splitInfos.flatMap(_.preferredLocations().asScala).toArray)
}
}
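
As the review comment above suggests, the same optimization could later be mirrored in the CH backend. A purely hypothetical Scala sketch of such a follow-up, modeled on the Velox genPartitions that appears later in this diff; it assumes the CH native layer were extended to accept split infos separately from the plan and that LocalFilesNode is imported as in the Velox file — neither is part of this PR:

  // Hypothetical follow-up sketch only: the CH native layer would first need
  // to accept split infos shipped separately from the Substrait plan.
  override def genPartitions(
      wsCtx: WholeStageTransformContext,
      splitInfos: Seq[Seq[SplitInfo]]): Seq[BaseGlutenPartition] = {
    // Serialize the Substrait plan once instead of once per partition.
    val planByteArray = wsCtx.root.toProtobuf.toByteArray
    splitInfos.zipWithIndex.map {
      case (splits, index) =>
        GlutenPartition(
          index,
          planByteArray,
          splits.map(_.asInstanceOf[LocalFilesNode].toProtobuf.toByteArray).toArray,
          splits.flatMap(_.preferredLocations().asScala).toArray)
    }
  }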

/**
* Generate Iterator[ColumnarBatch] for first stage.
*
@@ -252,7 +267,7 @@ class CHIteratorApi extends IteratorApi with Logging with LogLevelUtil {
GlutenPartition(
index,
substraitPlan.toByteArray,
splitInfo.preferredLocations().asScala.toArray)
locations = splitInfo.preferredLocations().asScala.toArray)
}
}(t => logInfo(s"Generating the Substrait plan took: $t ms."))

@@ -49,7 +49,7 @@ class MixedAffinitySuite extends QueryTest with SharedSparkSession {
}
val partition = GlutenMergeTreePartition(0, "", "", "", "fakePath", 0, 0)
val locations = affinity.getNativeMergeTreePartitionLocations(partition)
val nativePartition = GlutenPartition(0, PlanBuilder.EMPTY_PLAN, locations)
val nativePartition = GlutenPartition(0, PlanBuilder.EMPTY_PLAN, locations = locations)
assertResult(Set("forced_host_host-0")) {
nativePartition.preferredLocations().toSet
}
@@ -21,7 +21,7 @@ import io.glutenproject.backendsapi.IteratorApi
import io.glutenproject.execution._
import io.glutenproject.metrics.IMetrics
import io.glutenproject.substrait.plan.PlanNode
import io.glutenproject.substrait.rel.{LocalFilesBuilder, SplitInfo}
import io.glutenproject.substrait.rel.{LocalFilesBuilder, LocalFilesNode, SplitInfo}
import io.glutenproject.substrait.rel.LocalFilesNode.ReadFileFormat
import io.glutenproject.utils.Iterators
import io.glutenproject.vectorized._
@@ -47,16 +47,12 @@ import java.net.URLDecoder
import java.nio.charset.StandardCharsets
import java.time.ZoneOffset
import java.util.{ArrayList => JArrayList, HashMap => JHashMap, Map => JMap}
import java.util.concurrent.TimeUnit

import scala.collection.JavaConverters._

class IteratorApiImpl extends IteratorApi with Logging {

/**
* Generate native row partition.
*
* @return
*/
override def genSplitInfo(
partition: InputPartition,
partitionSchema: StructType,
@@ -80,6 +76,24 @@ class IteratorApiImpl extends IteratorApi with Logging {
}
}

/** Generate Gluten partitions from the grouped split infos. */
override def genPartitions(
wsCtx: WholeStageTransformContext,
splitInfos: Seq[Seq[SplitInfo]]): Seq[BaseGlutenPartition] = {
// Only serialize the plan once; this saves a lot of time when the plan is complex.
val planByteArray = wsCtx.root.toProtobuf.toByteArray
Review comment (Contributor):
If the plan binary is big enough, e.g., bigger than 1 MB, we could broadcast it to reduce the task serialization time.


splitInfos.zipWithIndex.map {
case (splitInfos, index) =>
GlutenPartition(
index,
planByteArray,
splitInfos.map(_.asInstanceOf[LocalFilesNode].toProtobuf.toByteArray).toArray,
splitInfos.flatMap(_.preferredLocations().asScala).toArray
)
}
}
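
On the broadcast suggestion in the review comment above, a minimal sketch of the idea: it assumes an active SparkContext is reachable at partition-generation time and that the partition type could carry a broadcast handle — neither assumption is part of this PR, and the names below are hypothetical.

  import org.apache.spark.SparkContext
  import org.apache.spark.broadcast.Broadcast

  // Sketch: broadcast a large serialized plan so each executor fetches it once,
  // instead of shipping the bytes inside every serialized task partition.
  val planByteArray = wsCtx.root.toProtobuf.toByteArray
  val broadcastThresholdBytes = 1 << 20 // 1 MB, the size the reviewer mentions
  if (planByteArray.length > broadcastThresholdBytes) {
    val planBroadcast: Broadcast[Array[Byte]] =
      SparkContext.getOrCreate().broadcast(planByteArray)
    // Executor side would read planBroadcast.value; GlutenPartition would need
    // a (hypothetical) field carrying the broadcast handle instead of raw bytes.
  }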

private def constructSplitInfo(schema: StructType, files: Array[PartitionedFile]) = {
val paths = new JArrayList[String]()
val starts = new JArrayList[JLong]
@@ -121,25 +135,34 @@
transKernel.injectWriteFilesTempPath(path)
}

/**
* Generate Iterator[ColumnarBatch] for first stage.
*
* @return
*/
/** Generate Iterator[ColumnarBatch] for first stage. */
override def genFirstStageIterator(
inputPartition: BaseGlutenPartition,
context: TaskContext,
pipelineTime: SQLMetric,
updateInputMetrics: (InputMetricsWrapper) => Unit,
updateNativeMetrics: IMetrics => Unit,
inputIterators: Seq[Iterator[ColumnarBatch]] = Seq()): Iterator[ColumnarBatch] = {
assert(
inputPartition.isInstanceOf[GlutenPartition],
"Velox backend only accept GlutenPartition.")

val beforeBuild = System.nanoTime()
val columnarNativeIterators =
new JArrayList[GeneralInIterator](inputIterators.map {
iter => new ColumnarBatchInIterator(iter.asJava)
}.asJava)
val transKernel = NativePlanEvaluator.create()

val splitInfoByteArray = inputPartition
.asInstanceOf[GlutenPartition]
.splitInfosByteArray
val resIter: GeneralOutIterator =
transKernel.createKernelWithBatchIterator(inputPartition.plan, columnarNativeIterators)
transKernel.createKernelWithBatchIterator(
inputPartition.plan,
splitInfoByteArray,
columnarNativeIterators)
pipelineTime += TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - beforeBuild)

Iterators
.wrap(resIter.asScala)
@@ -157,11 +180,7 @@

// scalastyle:off argcount

/**
* Generate Iterator[ColumnarBatch] for final stage.
*
* @return
*/
/** Generate Iterator[ColumnarBatch] for final stage. */
override def genFinalStageIterator(
context: TaskContext,
inputIterators: Seq[Iterator[ColumnarBatch]],
@@ -183,7 +202,10 @@
val nativeResultIterator =
transKernel.createKernelWithBatchIterator(
rootNode.toProtobuf.toByteArray,
columnarNativeIterator)
// The final stage contains no scan split, so pass empty split info to the native side here.
new Array[Array[Byte]](0),
columnarNativeIterator
)

Iterators
.wrap(nativeResultIterator.asScala)
3 changes: 3 additions & 0 deletions cpp/core/compute/Runtime.h
@@ -72,6 +72,8 @@ class Runtime : public std::enable_shared_from_this<Runtime> {

virtual void injectWriteFilesTempPath(const std::string& path) = 0;

virtual void parseSplitInfo(const uint8_t* data, int32_t size) = 0;

// Just for benchmark
::substrait::Plan& getPlan() {
return substraitPlan_;
@@ -140,6 +142,7 @@
protected:
std::unique_ptr<ObjectStore> objStore_ = ObjectStore::create();
::substrait::Plan substraitPlan_;
std::vector<::substrait::ReadRel_LocalFiles> localFiles_;
std::optional<std::string> writeFilesTempPath_;
SparkTaskInfo taskInfo_;
// Session conf map
10 changes: 9 additions & 1 deletion cpp/core/jni/JniWrapper.cc
@@ -368,6 +368,7 @@ Java_io_glutenproject_vectorized_PlanEvaluatorJniWrapper_nativeCreateKernelWithI
jobject wrapper,
jlong memoryManagerHandle,
jbyteArray planArr,
jobjectArray splitInfosArr,
jobjectArray iterArr,
jint stageId,
jint partitionId,
@@ -381,10 +382,17 @@

auto spillDirStr = jStringToCString(env, spillDir);

for (jsize i = 0, splitInfoArraySize = env->GetArrayLength(splitInfosArr); i < splitInfoArraySize; i++) {
jbyteArray splitInfoArray = static_cast<jbyteArray>(env->GetObjectArrayElement(splitInfosArr, i));
jsize splitInfoSize = env->GetArrayLength(splitInfoArray);
auto splitInfoData = reinterpret_cast<const uint8_t*>(env->GetByteArrayElements(splitInfoArray, nullptr));
ctx->parseSplitInfo(splitInfoData, splitInfoSize);
}

auto planData = reinterpret_cast<const uint8_t*>(env->GetByteArrayElements(planArr, nullptr));
auto planSize = env->GetArrayLength(planArr);

ctx->parsePlan(planData, planSize, {stageId, partitionId, taskId});

auto& conf = ctx->getConfMap();

// Handle the Java iters
48 changes: 30 additions & 18 deletions cpp/velox/benchmarks/GenericBenchmark.cc
@@ -109,7 +109,8 @@ void populateWriterMetrics(
} // namespace

auto BM_Generic = [](::benchmark::State& state,
const std::string& substraitJsonFile,
const std::string& planFile,
const std::string& splitFile,
const std::vector<std::string>& inputFiles,
const std::unordered_map<std::string, std::string>& conf,
FileReaderType readerType) {
@@ -119,11 +120,15 @@
} else {
setCpu(state.thread_index());
}
bool emptySplit = splitFile.empty();
memory::MemoryManager::testingSetInstance({});
auto memoryManager = getDefaultMemoryManager();
auto runtime = Runtime::create(kVeloxRuntimeKind, conf);
const auto& filePath = substraitJsonFile;
auto plan = getPlanFromFile(filePath);
auto plan = getPlanFromFile("Plan", planFile);
std::string split;
if (!emptySplit) {
split = getPlanFromFile("ReadRel.LocalFiles", splitFile);
}
auto startTime = std::chrono::steady_clock::now();
int64_t collectBatchTime = 0;
WriterMetrics writerMetrics{};
@@ -145,6 +150,9 @@
}

runtime->parsePlan(reinterpret_cast<uint8_t*>(plan.data()), plan.size(), {});
if (!emptySplit) {
runtime->parseSplitInfo(reinterpret_cast<uint8_t*>(split.data()), split.size());
}
auto resultIter =
runtime->createResultIterator(memoryManager.get(), "/tmp/test-spill", std::move(inputIters), conf);
auto veloxPlan = dynamic_cast<gluten::VeloxRuntime*>(runtime)->getVeloxPlan();
@@ -231,6 +239,7 @@ int main(int argc, char** argv) {
gflags::ParseCommandLineFlags(&argc, &argv, true);

std::string substraitJsonFile;
std::string splitFile;
std::vector<std::string> inputFiles;
std::unordered_map<std::string, std::string> conf;

@@ -242,18 +251,20 @@
if (argc < 2) {
LOG(INFO)
<< "No input args. Usage: " << std::endl
<< "./generic_benchmark /absolute-path/to/substrait_json_file /absolute-path/to/data_file_1 /absolute-path/to/data_file_2 ...";
<< "./generic_benchmark /absolute-path/to/substrait_json_file /absolute-path/to/split_json_file(optional)"
<< " /absolute-path/to/data_file_1 /absolute-path/to/data_file_2 ...";
LOG(INFO) << "Running example...";
inputFiles.resize(2);
substraitJsonFile = getGeneratedFilePath("example.json");
inputFiles[0] = getGeneratedFilePath("example_orders");
inputFiles[1] = getGeneratedFilePath("example_lineitem");
} else {
substraitJsonFile = argv[1];
splitFile = argv[2];
abortIfFileNotExists(substraitJsonFile);
LOG(INFO) << "Using substrait json file: " << std::endl << substraitJsonFile;
LOG(INFO) << "Using " << argc - 2 << " input data file(s): ";
for (auto i = 2; i < argc; ++i) {
for (auto i = 3; i < argc; ++i) {
inputFiles.emplace_back(argv[i]);
abortIfFileNotExists(inputFiles.back());
LOG(INFO) << inputFiles.back();
@@ -265,19 +276,20 @@
std::exit(EXIT_FAILURE);
}

#define GENERIC_BENCHMARK(NAME, READER_TYPE) \
do { \
auto* bm = ::benchmark::RegisterBenchmark(NAME, BM_Generic, substraitJsonFile, inputFiles, conf, READER_TYPE) \
->MeasureProcessCPUTime() \
->UseRealTime(); \
if (FLAGS_threads > 0) { \
bm->Threads(FLAGS_threads); \
} else { \
bm->ThreadRange(1, std::thread::hardware_concurrency()); \
} \
if (FLAGS_iterations > 0) { \
bm->Iterations(FLAGS_iterations); \
} \
#define GENERIC_BENCHMARK(NAME, READER_TYPE) \
do { \
auto* bm = \
::benchmark::RegisterBenchmark(NAME, BM_Generic, substraitJsonFile, splitFile, inputFiles, conf, READER_TYPE) \
->MeasureProcessCPUTime() \
->UseRealTime(); \
if (FLAGS_threads > 0) { \
bm->Threads(FLAGS_threads); \
} else { \
bm->ThreadRange(1, std::thread::hardware_concurrency()); \
} \
if (FLAGS_iterations > 0) { \
bm->Iterations(FLAGS_iterations); \
} \
} while (0)

DLOG(INFO) << "FLAGS_threads:" << FLAGS_threads;
4 changes: 2 additions & 2 deletions cpp/velox/benchmarks/common/BenchmarkUtils.cc
@@ -47,14 +47,14 @@ void initVeloxBackend() {
initVeloxBackend(bmConfMap);
}

std::string getPlanFromFile(const std::string& filePath) {
std::string getPlanFromFile(const std::string& type, const std::string& filePath) {
// Read json file and resume the binary data.
std::ifstream msgJson(filePath);
std::stringstream buffer;
buffer << msgJson.rdbuf();
std::string msgData = buffer.str();

return gluten::substraitFromJsonToPb("Plan", msgData);
return gluten::substraitFromJsonToPb(type, msgData);
}

velox::dwio::common::FileFormat getFileFormat(const std::string& fileFormat) {
2 changes: 1 addition & 1 deletion cpp/velox/benchmarks/common/BenchmarkUtils.h
@@ -72,7 +72,7 @@ inline std::string getGeneratedFilePath(const std::string& fileName) {
}

/// Read binary data from a json file.
std::string getPlanFromFile(const std::string& filePath);
std::string getPlanFromFile(const std::string& type, const std::string& filePath);

/// Get the file paths, starts, lengths from a directory.
/// Use fileFormat to specify the format to read, eg., orc, parquet.