Commit: fix
exmy committed Jan 23, 2024
1 parent c6f1270 commit 04834ff
Showing 7 changed files with 138 additions and 86 deletions.
@@ -24,7 +24,6 @@
 import io.glutenproject.substrait.extensions.AdvancedExtensionNode;
 import io.glutenproject.substrait.extensions.ExtensionBuilder;
 import io.glutenproject.substrait.plan.PlanBuilder;
-import io.glutenproject.substrait.plan.PlanNode;
 
 import org.apache.spark.SparkConf;
 import org.apache.spark.sql.internal.SQLConf;
@@ -56,7 +55,7 @@ public void initNative(SparkConf conf) {
     // Get the customer config from SparkConf for each backend
     BackendsApiManager.getTransformerApiInstance().postProcessNativeConfig(nativeConfMap, prefix);
 
-    jniWrapper.nativeInitNative(buildNativeConfNode(nativeConfMap).toProtobuf().toByteArray());
+    jniWrapper.nativeInitNative(buildNativeConf(nativeConfMap));
   }
 
   public void finalizeNative() {
@@ -68,13 +67,18 @@ public boolean doValidate(byte[] subPlan) {
     return jniWrapper.nativeDoValidate(subPlan);
   }
 
-  private PlanNode buildNativeConfNode(Map<String, String> confs) {
+  private byte[] buildNativeConf(Map<String, String> confs) {
     StringMapNode stringMapNode = ExpressionBuilder.makeStringMap(confs);
     AdvancedExtensionNode extensionNode =
         ExtensionBuilder.makeAdvancedExtension(
             BackendsApiManager.getTransformerApiInstance()
                 .packPBMessage(stringMapNode.toProtobuf()));
-    return PlanBuilder.makePlan(extensionNode);
+    return PlanBuilder.makePlan(extensionNode).toProtobuf().toByteArray();
   }
+
+  private Map<String, String> getNativeBackendConf() {
+    return GlutenConfig.getNativeBackendConf(
+        BackendsApiManager.getSettings().getBackendConfigPrefix(), SQLConf.get().getAllConfs());
+  }
 
   // Used by WholeStageTransform to create the native computing pipeline and
@@ -91,12 +95,7 @@ public GeneralOutIterator createKernelWithBatchIterator(
             wsPlan,
             splitInfo,
             iterList.toArray(new GeneralInIterator[0]),
-            buildNativeConfNode(
-                GlutenConfig.getNativeBackendConf(
-                    BackendsApiManager.getSettings().getBackendConfigPrefix(),
-                    SQLConf.get().getAllConfs()))
-                .toProtobuf()
-                .toByteArray(),
+            buildNativeConf(getNativeBackendConf()),
             materializeInput);
     return createOutIterator(handle);
   }
@@ -110,12 +109,7 @@ public GeneralOutIterator createKernelWithBatchIterator(
             wsPlan,
             splitInfo,
             iterList.toArray(new GeneralInIterator[0]),
-            buildNativeConfNode(
-                GlutenConfig.getNativeBackendConf(
-                    BackendsApiManager.getSettings().getBackendConfigPrefix(),
-                    SQLConf.get().getAllConfs()))
-                .toProtobuf()
-                .toByteArray(),
+            buildNativeConf(getNativeBackendConf()),
             false);
     return createOutIterator(handle);
   }
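
A note on the file above (apparently the CH backend's `CHNativeExpressionEvaluator`, given the matching calls in the Scala files below): `buildNativeConfNode` used to return a `PlanNode` that every caller finished with `.toProtobuf().toByteArray()`, and two call sites repeated the whole `GlutenConfig.getNativeBackendConf(...)` lookup inline. The new `buildNativeConf` owns the serialization, and `getNativeBackendConf()` centralizes the lookup. A minimal, self-contained Scala sketch of the same extract-helper pattern, using a stand-in `ConfPlan` instead of the real Gluten builder chain:

```scala
// Stand-in for the Substrait conf plan node; the real types live in
// io.glutenproject.substrait.*.
final case class ConfPlan(confs: Map[String, String]) {
  def toByteArray: Array[Byte] =
    confs.toSeq.sorted.map { case (k, v) => s"$k=$v" }.mkString("\n").getBytes("UTF-8")
}

object NativeConf {
  // Before: the helper returned the plan node and every call site
  // appended .toProtobuf().toByteArray() itself.
  // After: the helper returns the serialized bytes directly.
  def buildNativeConf(confs: Map[String, String]): Array[Byte] =
    ConfPlan(confs).toByteArray
}
```

Call sites then collapse to a single expression, as in `jniWrapper.nativeInitNative(buildNativeConf(nativeConfMap))` above.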
@@ -21,9 +21,9 @@ import io.glutenproject.backendsapi.IteratorApi
 import io.glutenproject.execution._
 import io.glutenproject.metrics.{GlutenTimeMetric, IMetrics, NativeMetrics}
 import io.glutenproject.substrait.plan.PlanNode
-import io.glutenproject.substrait.rel.{ExtensionTableBuilder, LocalFilesBuilder, LocalFilesNode, SplitInfo}
+import io.glutenproject.substrait.rel.{ExtensionTableBuilder, ExtensionTableNode, LocalFilesBuilder, LocalFilesNode, SplitInfo}
 import io.glutenproject.substrait.rel.LocalFilesNode.ReadFileFormat
-import io.glutenproject.utils.{LogLevelUtil, SubstraitPlanPrinterUtil}
+import io.glutenproject.utils.LogLevelUtil
 import io.glutenproject.vectorized.{CHNativeExpressionEvaluator, CloseableCHColumnBatchIterator, GeneralInIterator, GeneralOutIterator}
 
 import org.apache.spark.{InterruptibleIterator, SparkConf, SparkContext, TaskContext}
@@ -86,7 +86,7 @@ class CHIteratorApi extends IteratorApi with Logging with LogLevelUtil {
           fileFormat,
           preferredLocations.toList.asJava)
       case _ =>
-        throw new UnsupportedOperationException(s"Unsupported input partition.")
+        throw new UnsupportedOperationException(s"Unsupported input partition: $partition.")
     }
   }
 
@@ -105,10 +105,16 @@ class CHIteratorApi extends IteratorApi with Logging with LogLevelUtil {
       case (splits, index) =>
         val splitInfosByteArray = splits.zipWithIndex.map {
           case (split, i) =>
-            val filesNode = split.asInstanceOf[LocalFilesNode]
-            filesNode.setFileSchema(scans(i).getDataSchema)
-            filesNode.setFileReadProperties(mapAsJavaMap(scans(i).getProperties))
-            filesNode.toProtobuf.toByteArray
+            split match {
+              case filesNode: LocalFilesNode =>
+                // TODO: set the file schema only when the scan is a HiveTableScan,
+                // and check case-insensitive schema matching
+                filesNode.setFileSchema(scans(i).getDataSchema)
+                filesNode.setFileReadProperties(mapAsJavaMap(scans(i).getProperties))
+                filesNode.toProtobuf.toByteArray
+              case extensionTableNode: ExtensionTableNode =>
+                extensionTableNode.toProtobuf.toByteArray
+            }
         }
 
         GlutenPartition(
@@ -269,27 +275,22 @@ class CHIteratorApi extends IteratorApi with Logging with LogLevelUtil {
   /** Generate Native FileScanRDD, currently only for ClickHouse Backend. */
   override def genNativeFileScanRDD(
       sparkContext: SparkContext,
-      wsCxt: WholeStageTransformContext,
+      wsCtx: WholeStageTransformContext,
       splitInfos: Seq[SplitInfo],
       numOutputRows: SQLMetric,
       numOutputBatches: SQLMetric,
       scanTime: SQLMetric): RDD[ColumnarBatch] = {
     val substraitPlanPartition = GlutenTimeMetric.withMillisTime {
+      val planByteArray = wsCtx.root.toProtobuf.toByteArray
       splitInfos.zipWithIndex.map {
         case (splitInfo, index) =>
-          wsCxt.substraitContext.initSplitInfosIndex(0)
-          wsCxt.substraitContext.setSplitInfos(Seq(splitInfo))
-          val substraitPlan = wsCxt.root.toProtobuf
-          if (index == 0) {
-            logOnLevel(
-              GlutenConfig.getConf.substraitPlanLogLevel,
-              s"The substrait plan for partition $index:\n${SubstraitPlanPrinterUtil
-                .substraitPlanToJson(substraitPlan)}"
-            )
-          }
+          // wsCtx.substraitContext.initSplitInfosIndex(0)
+          // wsCtx.substraitContext.setSplitInfos(Seq(splitInfo))
+          // val substraitPlan = wsCtx.root.toProtobuf
           GlutenPartition(
             index,
-            substraitPlan.toByteArray,
+            planByteArray,
+            Array(splitInfo.asInstanceOf[ExtensionTableNode].toProtobuf.toByteArray),
             locations = splitInfo.preferredLocations().asScala.toArray)
       }
     }(t => logInfo(s"Generating the Substrait plan took: $t ms."))
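
Two things change in `CHIteratorApi`. First, split serialization pattern-matches on the node type instead of unconditionally casting to `LocalFilesNode`, so merge-tree splits (`ExtensionTableNode`) no longer fail with a `ClassCastException`. Second, `genNativeFileScanRDD` serializes the plan once (`planByteArray`) and ships the `ExtensionTableNode` bytes as per-partition split info instead of re-embedding them into the plan for every partition. A hedged Scala sketch of the dispatch, with stand-ins for the Gluten node classes:

```scala
// Stand-ins for io.glutenproject.substrait.rel.{LocalFilesNode, ExtensionTableNode}.
sealed trait SplitNode { def toBytes: Array[Byte] }
final case class LocalFilesNodeStub(paths: Seq[String]) extends SplitNode {
  def toBytes: Array[Byte] = paths.mkString(",").getBytes("UTF-8")
}
final case class ExtensionTableNodeStub(detail: String) extends SplitNode {
  def toBytes: Array[Byte] = detail.getBytes("UTF-8")
}

// The match replaces split.asInstanceOf[LocalFilesNode]: each variant
// serializes itself, and only file splits get schema/properties attached.
def serializeSplit(split: SplitNode): Array[Byte] = split match {
  case lf: LocalFilesNodeStub => lf.toBytes
  case et: ExtensionTableNodeStub => et.toBytes
}
```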
@@ -39,15 +39,21 @@ class NativeFileScanColumnarRDD(
   override def compute(split: Partition, context: TaskContext): Iterator[ColumnarBatch] = {
     val inputPartition = castNativePartition(split)
 
+    assert(
+      inputPartition.isInstanceOf[GlutenPartition],
+      "NativeFileScanColumnarRDD only accepts GlutenPartition.")
+
+    val splitInfoByteArray = inputPartition
+      .asInstanceOf[GlutenPartition]
+      .splitInfosByteArray
+
     val resIter: GeneralOutIterator = GlutenTimeMetric.millis(scanTime) {
       _ =>
         val transKernel = new CHNativeExpressionEvaluator()
         val inBatchIters = new util.ArrayList[GeneralInIterator]()
         transKernel.createKernelWithBatchIterator(
           inputPartition.plan,
-          // The substraitPlanPartition contains ExtensionTableNode
-          // and doesn't need separate SplitInfo
-          new Array[Array[Byte]](0),
+          splitInfoByteArray,
           inBatchIters,
           false
         )
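
Since the plan no longer embeds the extension table, `compute` must forward the partition's own `splitInfosByteArray` to the kernel instead of the former empty array. A small Scala sketch of the assert-and-extract step, with stand-in partition types:

```scala
// Stand-ins for org.apache.spark.Partition and GlutenPartition.
trait PartitionStub
final case class GlutenPartitionStub(
    plan: Array[Byte],
    splitInfosByteArray: Array[Array[Byte]]) extends PartitionStub

def splitInfosOf(p: PartitionStub): Array[Array[Byte]] = {
  assert(p.isInstanceOf[GlutenPartitionStub], "only GlutenPartition is supported")
  p.asInstanceOf[GlutenPartitionStub].splitInfosByteArray
}
```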
27 changes: 18 additions & 9 deletions cpp-ch/local-engine/Parser/MergeTreeRelParser.cpp
@@ -25,6 +25,7 @@
 #include <Common/MergeTreeTool.h>
 
 #include "MergeTreeRelParser.h"
+#include "substrait/algebra.pb.h"
 
 
 namespace DB
@@ -59,17 +60,25 @@ static Int64 findMinPosition(const NameSet & condition_table_columns, const Name
 }
 
 DB::QueryPlanPtr
-MergeTreeRelParser::parse(DB::QueryPlanPtr query_plan, const substrait::Rel & rel_, std::list<const substrait::Rel *> & /*rel_stack_*/)
+MergeTreeRelParser::parse(DB::QueryPlanPtr, const substrait::Rel &, std::list<const substrait::Rel *> &)
 {
+    throw Exception(ErrorCodes::LOGICAL_ERROR, "MergeTreeRelParser can't call parse(), call parseReadRel instead.");
+}
+
+DB::QueryPlanPtr
+MergeTreeRelParser::parseReadRel(
+    DB::QueryPlanPtr query_plan,
+    const substrait::ReadRel & read,
+    const substrait::ReadRel::ExtensionTable * extension_table,
+    std::list<const substrait::Rel *> & /*rel_stack_*/)
+{
-    const auto & rel = rel_.read();
-    assert(rel.has_extension_table());
     google::protobuf::StringValue table;
-    table.ParseFromString(rel.extension_table().detail().value());
+    table.ParseFromString(extension_table->detail().value());
     auto merge_tree_table = local_engine::parseMergeTreeTableString(table.value());
     DB::Block header;
-    if (rel.has_base_schema() && rel.base_schema().names_size())
+    if (read.has_base_schema() && read.base_schema().names_size())
     {
-        header = TypeParser::buildBlockFromNamedStruct(rel.base_schema());
+        header = TypeParser::buildBlockFromNamedStruct(read.base_schema());
     }
     else
     {
@@ -114,11 +123,11 @@ MergeTreeRelParser::parseReadRel(
     auto query_info = buildQueryInfo(names_and_types_list);
 
     std::set<String> non_nullable_columns;
-    if (rel.has_filter())
+    if (read.has_filter())
     {
-        NonNullableColumnsResolver non_nullable_columns_resolver(header, *getPlanParser(), rel.filter());
+        NonNullableColumnsResolver non_nullable_columns_resolver(header, *getPlanParser(), read.filter());
         non_nullable_columns = non_nullable_columns_resolver.resolve();
-        query_info->prewhere_info = parsePreWhereInfo(rel.filter(), header);
+        query_info->prewhere_info = parsePreWhereInfo(read.filter(), header);
     }
     auto data_parts = query_context.custom_storage_merge_tree->getAllDataPartsVector();
     int min_block = merge_tree_table.min_block;
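
In `MergeTreeRelParser` the generic `parse()` entry point now just throws, and the new `parseReadRel` (declared in the header change below) receives the `ReadRel` plus the extension table that the caller parsed out of band, instead of digging it out of the `Rel` itself. The table detail is a `google.protobuf.StringValue` packed inside the extension table's `Any` field. A hypothetical Scala mirror of that decode, assuming substrait's generated Java classes (`io.substrait.proto`) are on the classpath:

```scala
import com.google.protobuf.StringValue
import io.substrait.proto.ReadRel

// Mirrors table.ParseFromString(extension_table->detail().value()) above:
// detail is an Any whose payload bytes are a serialized StringValue.
def tableDetail(extensionTable: ReadRel.ExtensionTable): String =
  StringValue.parseFrom(extensionTable.getDetail.getValue).getValue
```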
9 changes: 7 additions & 2 deletions cpp-ch/local-engine/Parser/MergeTreeRelParser.h
@@ -45,8 +45,13 @@ class MergeTreeRelParser : public RelParser
 
     ~MergeTreeRelParser() override = default;
 
-    DB::QueryPlanPtr
-    parse(DB::QueryPlanPtr query_plan, const substrait::Rel & rel, std::list<const substrait::Rel *> & rel_stack_) override;
+    DB::QueryPlanPtr parse(DB::QueryPlanPtr query_plan, const substrait::Rel & rel, std::list<const substrait::Rel *> & rel_stack_) override;
+
+    DB::QueryPlanPtr parseReadRel(
+        DB::QueryPlanPtr query_plan,
+        const substrait::ReadRel & read,
+        const substrait::ReadRel::ExtensionTable * extension_table,
+        std::list<const substrait::Rel *> & rel_stack_);
 
     const substrait::Rel & getSingleInput(const substrait::Rel &) override
     {
103 changes: 69 additions & 34 deletions cpp-ch/local-engine/Parser/SerializedPlanParser.cpp
@@ -90,6 +90,7 @@
 #include <Common/MergeTreeTool.h>
 #include <Common/logger_useful.h>
 #include <Common/typeid_cast.h>
+#include "substrait/algebra.pb.h"
 
 namespace DB
 {
@@ -251,14 +252,13 @@ bool SerializedPlanParser::isReadRelFromJava(const substrait::ReadRel & rel)
     return rel.has_local_files() && rel.local_files().items().size() == 1 && rel.local_files().items().at(0).uri_file().starts_with("iterator");
 }
 
-QueryPlanStepPtr SerializedPlanParser::parseReadRealWithLocalFile(const substrait::ReadRel & rel)
+QueryPlanStepPtr SerializedPlanParser::parseReadRealWithLocalFile(
+    const substrait::ReadRel & rel,
+    const substrait::ReadRel::LocalFiles * local_files)
 {
-    // only support read one relation
-    assert(split_infos.size() == 1);
     assert(rel.has_base_schema());
     auto header = TypeParser::buildBlockFromNamedStruct(rel.base_schema());
-    auto local_files = parseSplitInfo(split_infos.at(0));
-    auto source = std::make_shared<SubstraitFileSource>(context, header, *local_files);
+    auto source = std::make_shared<SubstraitFileSource>(context, header, local_files ? *local_files : rel.local_files());
     auto source_pipe = Pipe(source);
     auto source_step = std::make_unique<SubstraitFileSourceStep>(context, std::move(source_pipe), "substrait local files");
     source_step->setStepDescription("read local files");
@@ -555,15 +555,8 @@ QueryPlanPtr SerializedPlanParser::parseOp(const substrait::Rel & rel, std::list
     }
     case substrait::Rel::RelTypeCase::kRead: {
         const auto & read = rel.read();
-        if (!read.has_extension_table())
-        {
-            LOG_DEBUG(&Poco::Logger::get("SerializedPlanParser"), "read from local files");
-            QueryPlanStepPtr step;
-            if (isReadRelFromJava(read))
-                step = parseReadRealWithJavaIter(read);
-            else
-                step = parseReadRealWithLocalFile(read);
-
+        QueryPlanStepPtr step;
+        auto update = [&]() {
             query_plan = std::make_unique<QueryPlan>();
             steps.emplace_back(step.get());
             query_plan->addStep(std::move(step));
@@ -576,14 +569,35 @@
                 steps.emplace_back(buffer_step.get());
                 query_plan->addStep(std::move(buffer_step));
             }
+        };
+
+        if (!split_infos.empty())
+        {
+            // only support read one relation
+            assert(split_infos.size() == 1);
+            auto split_info = parseSplitInfo(split_infos.at(0));
+            if (split_info.second == substrait::ReadRel::kExtensionTable)
+            {
+                LOG_DEBUG(&Poco::Logger::get("SerializedPlanParser"), "read from merge tree");
+                MergeTreeRelParser mergeTreeParser(this, context, query_context, global_context);
+                std::list<const substrait::Rel *> stack;
+                query_plan = mergeTreeParser.parseReadRel(std::make_unique<QueryPlan>(), read, static_cast<const substrait::ReadRel::ExtensionTable *>(split_info.first.get()), stack);
+                steps = mergeTreeParser.getSteps();
+            }
+            if (split_info.second == substrait::ReadRel::kLocalFiles)
+            {
+                LOG_DEBUG(&Poco::Logger::get("SerializedPlanParser"), "read from local files");
+                step = parseReadRealWithLocalFile(read, static_cast<const substrait::ReadRel::LocalFiles *>(split_info.first.get()));
+                update();
+            }
+        }
         else
         {
-            LOG_DEBUG(&Poco::Logger::get("SerializedPlanParser"), "read from merge tree");
-            MergeTreeRelParser mergeTreeParser(this, context, query_context, global_context);
-            std::list<const substrait::Rel *> stack;
-            query_plan = mergeTreeParser.parse(std::make_unique<QueryPlan>(), rel, stack);
-            steps = mergeTreeParser.getSteps();
+            if (isReadRelFromJava(read))
+                step = parseReadRealWithJavaIter(read);
+            else
+                step = parseReadRealWithLocalFile(read);
+            update();
         }
         break;
     }
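
The read-path dispatch in `parseOp` now keys off the out-of-band split info rather than `read.has_extension_table()`: when split infos arrive from the JVM side, their parsed type picks the merge-tree or local-files path; otherwise the relation falls back to the Java-iterator / embedded-local-files path, with the shared `update` lambda wiring the step into the plan. A compact Scala sketch of the decision table, with stand-in types:

```scala
// Stand-ins for the variants returned by parseSplitInfo.
sealed trait ParsedSplit
case object ExtensionTableSplit extends ParsedSplit
case object LocalFilesSplit extends ParsedSplit

def readPath(splitInfo: Option[ParsedSplit], fromJavaIter: Boolean): String =
  splitInfo match {
    case Some(ExtensionTableSplit) => "read from merge tree"
    case Some(LocalFilesSplit) => "read from local files"
    case None if fromJavaIter => "read from Java iterator"
    case None => "read from files embedded in the ReadRel"
  }
```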
@@ -1824,27 +1838,48 @@ const ActionsDAG::Node * SerializedPlanParser::parseExpression(ActionsDAGPtr act
     }
 }
 
-LocalFilesPtr SerializedPlanParser::parseSplitInfo(const std::string & split_info)
+SplitInfo SerializedPlanParser::parseSplitInfo(const std::string & split_info)
 {
-    auto local_files = std::make_shared<substrait::ReadRel::LocalFiles>();
+    auto logMessage = [](const std::shared_ptr<google::protobuf::Message> & message) {
+        auto * logger = &Poco::Logger::get("SerializedPlanParser");
+        if (logger->debug())
+        {
+            namespace pb_util = google::protobuf::util;
+            pb_util::JsonOptions options;
+            std::string json;
+            auto s = pb_util::MessageToJsonString(*message, &json, options);
+            if (!s.ok())
+                throw Exception(ErrorCodes::LOGICAL_ERROR, "Can not convert pb message to json");
+            LOG_DEBUG(logger, "split info:\n{}", json);
+        }
+    };
+
+    substrait::ReadRel::ReadTypeCase read_type_case;
+
+    {
+        google::protobuf::io::CodedInputStream coded_in(reinterpret_cast<const uint8_t *>(split_info.data()), static_cast<int>(split_info.size()));
+        coded_in.SetRecursionLimit(100000);
+        auto local_files = std::make_shared<substrait::ReadRel::LocalFiles>();
+        if (local_files->ParseFromCodedStream(&coded_in))
+        {
+            read_type_case = substrait::ReadRel::kLocalFiles;
+            logMessage(local_files);
+            return {local_files, read_type_case};
+        }
+    }
 
     google::protobuf::io::CodedInputStream coded_in(reinterpret_cast<const uint8_t *>(split_info.data()), static_cast<int>(split_info.size()));
     coded_in.SetRecursionLimit(100000);
 
-    auto ok = local_files->ParseFromCodedStream(&coded_in);
-    if (!ok)
-        throw Exception(ErrorCodes::CANNOT_PARSE_PROTOBUF_SCHEMA, "Parse substrait::ReadRel::LocalFiles from string failed");
-    auto * logger = &Poco::Logger::get("SerializedPlanParser");
-    if (logger->debug())
+    auto extension_table = std::make_shared<substrait::ReadRel::ExtensionTable>();
+    if (extension_table->ParseFromCodedStream(&coded_in))
     {
-        namespace pb_util = google::protobuf::util;
-        pb_util::JsonOptions options;
-        std::string json;
-        auto s = pb_util::MessageToJsonString(*local_files, &json, options);
-        if (!s.ok())
-            throw Exception(ErrorCodes::LOGICAL_ERROR, "Can not convert LocalFiles to Json");
-        LOG_DEBUG(logger, "local files:\n{}", json);
+        read_type_case = substrait::ReadRel::kExtensionTable;
+        logMessage(extension_table);
+        return {extension_table, read_type_case};
     }
-    return local_files;
+
+    throw Exception(ErrorCodes::CANNOT_PARSE_PROTOBUF_SCHEMA, "Parse split info from string failed, {}", split_info);
 }
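
`parseSplitInfo` now returns a `(message, ReadTypeCase)` pair, deciding the type by first attempting to parse the bytes as `LocalFiles` and then as `ExtensionTable`. This relies on the first parse failing cleanly for extension-table bytes; protobuf parsing is permissive, so the two message shapes must stay distinguishable. A hypothetical Scala mirror of the fallback, again assuming substrait's generated Java classes:

```scala
import scala.util.Try
import com.google.protobuf.Message
import io.substrait.proto.ReadRel

// Try LocalFiles first, then ExtensionTable, mirroring the C++ fallback;
// parseFrom throws InvalidProtocolBufferException on malformed bytes.
def parseSplitInfo(bytes: Array[Byte]): (Message, ReadRel.ReadTypeCase) =
  Try[(Message, ReadRel.ReadTypeCase)](
    (ReadRel.LocalFiles.parseFrom(bytes), ReadRel.ReadTypeCase.LOCAL_FILES))
    .orElse(Try(
      (ReadRel.ExtensionTable.parseFrom(bytes), ReadRel.ReadTypeCase.EXTENSION_TABLE)))
    .getOrElse(throw new IllegalArgumentException("cannot parse split info"))
```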

