Skip to content

Commit

Permalink
Removed special handling for avg (oap-project#31)
Browse files Browse the repository at this point in the history
  • Loading branch information
rui-mo authored and zhejiangxiaomai committed Jul 26, 2022
1 parent f819b0c commit 0494705
Show file tree
Hide file tree
Showing 5 changed files with 82 additions and 231 deletions.
17 changes: 12 additions & 5 deletions velox/substrait/SubstraitParser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,13 +46,20 @@ std::shared_ptr<SubstraitParser::SubstraitType> SubstraitParser::parseType(
break;
}
case ::substrait::Type::KindCase::kStruct: {
// TODO: Support for Struct is not fully added.
typeName = "STRUCT";
// The type name of struct is in the format of:
// STRUCT:type0_type1...typen.
typeName = "STRUCT:";
const auto& sStruct = substraitType.struct_();
const auto& substraitType = sStruct.types();
for (const auto& type : substraitType) {
parseType(type);
const auto& substraitTypes = sStruct.types();
for (int idx = 0; idx < substraitTypes.size() - 1; idx++) {
std::string childTypeWithSuffix =
parseType(substraitTypes[idx])->type + "_";
typeName += childTypeWithSuffix;
}
std::string lastType =
parseType(substraitTypes[substraitTypes.size() - 1])->type;
typeName += lastType;
nullability = substraitType.struct_().nullability();
break;
}
case ::substrait::Type::KindCase::kString: {
Expand Down
38 changes: 32 additions & 6 deletions velox/substrait/SubstraitToVeloxExpr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,30 @@ SubstraitVeloxExprConverter::toExtractExpr(
VELOX_FAIL("Constant is expected to be the first parameter in extract.");
}

/// Builds a "row_constructor" call expression whose output type is a RowType
/// derived from a Substrait struct type name.
///
/// @param params Argument expressions of the call. They are copied into the
/// resulting expression.
/// @param typeName Substrait type name of the struct; it is split into the
/// type names of its children by SubstraitParser::getSubFunctionTypes.
/// @return A CallTypedExpr named "row_constructor". Its ROW type has one child
/// per parsed type name, with children named "col_0", "col_1", ...
/// Fails the VELOX_CHECK when no child type name could be parsed.
std::shared_ptr<const core::ITypedExpr>
SubstraitVeloxExprConverter::toRowConstructorExpr(
    const std::vector<std::shared_ptr<const core::ITypedExpr>>& params,
    const std::string& typeName) {
  std::vector<std::string> structTypeNames;
  subParser_->getSubFunctionTypes(typeName, structTypeNames);
  VELOX_CHECK(
      !structTypeNames.empty(), "At least one type name is expected.");

  // Preparation for the conversion from struct types to RowType.
  const size_t numChildren = structTypeNames.size();
  std::vector<std::string> names;
  std::vector<TypePtr> rowTypes;
  names.reserve(numChildren);
  rowTypes.reserve(numChildren);
  for (size_t idx = 0; idx < numChildren; ++idx) {
    names.emplace_back("col_" + std::to_string(idx));
    // toVeloxType returns a temporary, so no std::move is needed here.
    rowTypes.emplace_back(toVeloxType(structTypeNames[idx]));
  }

  // 'params' is a const reference: it can only be copied, so it is passed
  // as-is rather than through a misleading std::move.
  return std::make_shared<const core::CallTypedExpr>(
      ROW(std::move(names), std::move(rowTypes)), params, "row_constructor");
}

std::shared_ptr<const core::ITypedExpr>
SubstraitVeloxExprConverter::toVeloxExpr(
const ::substrait::Expression::ScalarFunction& sFunc,
Expand All @@ -114,21 +138,23 @@ SubstraitVeloxExprConverter::toVeloxExpr(
}
const auto& veloxFunction =
subParser_->findVeloxFunction(functionMap_, sFunc.function_reference());
const auto& veloxType =
toVeloxType(subParser_->parseType(sFunc.output_type())->type);
std::string typeName = subParser_->parseType(sFunc.output_type())->type;

if (veloxFunction == "extract") {
return toExtractExpr(params, veloxType);
return toExtractExpr(std::move(params), toVeloxType(typeName));
}
if (veloxFunction == "alias") {
return toAliasExpr(params);
return toAliasExpr(std::move(params));
}
if (veloxFunction == "is_not_null") {
return toIsNotNullExpr(params, veloxType);
return toIsNotNullExpr(std::move(params), toVeloxType(typeName));
}
if (veloxFunction == "row_constructor") {
return toRowConstructorExpr(std::move(params), typeName);
}

return std::make_shared<const core::CallTypedExpr>(
veloxType, std::move(params), veloxFunction);
toVeloxType(typeName), std::move(params), veloxFunction);
}

std::shared_ptr<const core::ConstantTypedExpr>
Expand Down
5 changes: 5 additions & 0 deletions velox/substrait/SubstraitToVeloxExpr.h
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,11 @@ class SubstraitVeloxExprConverter {
const std::vector<std::shared_ptr<const core::ITypedExpr>>& params,
const TypePtr& outputType);

/// Create the expression for the "row_constructor" function, which combines
/// its arguments into a single row value.
/// @param params Argument expressions passed to row_constructor.
/// @param typeName Substrait type name of the output struct. The
/// implementation splits it into child type names and converts each one to a
/// Velox type; the children of the resulting row type are named "col_<index>".
/// @return A call expression named "row_constructor" with a ROW output type.
std::shared_ptr<const core::ITypedExpr> toRowConstructorExpr(
const std::vector<std::shared_ptr<const core::ITypedExpr>>& params,
const std::string& typeName);

/// Used to convert Substrait Literal into Velox Expression.
std::shared_ptr<const core::ConstantTypedExpr> toVeloxExpr(
const ::substrait::Expression::Literal& substraitLit);
Expand Down
240 changes: 31 additions & 209 deletions velox/substrait/SubstraitToVeloxPlan.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -192,175 +192,15 @@ std::shared_ptr<const core::PlanNode> SubstraitVeloxPlanConverter::toVeloxPlan(
std::shared_ptr<const core::PlanNode> childNode;
if (sAgg.has_input()) {
childNode = toVeloxPlan(sAgg.input());

} else {
VELOX_FAIL("Child Rel is expected in AggregateRel.");
}

core::AggregationNode::Step aggStep;
// Get the aggregation phase and check whether there are input columns that
// need to be combined into a row.
if (needsRowConstruct(sAgg, aggStep)) {
return toVeloxAggWithRowConstruct(sAgg, childNode, aggStep);
}
setPhase(sAgg, aggStep);
return toVeloxAgg(sAgg, childNode, aggStep);
}

/// Converts a Substrait AggregateRel into a Project node followed by an
/// Aggregation node. The Project node forwards the grouping fields and, for
/// "avg", combines the two arguments (sum and count) into one row through
/// "row_constructor", so that every aggregate consumes exactly one column.
///
/// @param sAgg The Substrait aggregate relation to convert.
/// @param childNode The already converted input of the aggregation.
/// @param aggStep The aggregation step used for the Aggregation node.
/// @return The Aggregation node, whose source is the created Project node.
/// Fails when "avg" does not have exactly two arguments or when any other
/// aggregate function does not have exactly one argument.
std::shared_ptr<const core::PlanNode>
SubstraitVeloxPlanConverter::toVeloxAggWithRowConstruct(
    const ::substrait::AggregateRel& sAgg,
    const std::shared_ptr<const core::PlanNode>& childNode,
    const core::AggregationNode::Step& aggStep) {
  const auto& measures = sAgg.measures();
  const auto& constructInputType = childNode->outputType();

  // Expressions of the Project node added before the Aggregation node to
  // combine columns into rows. Grouping fields come first, then one
  // expression per measure.
  std::vector<std::shared_ptr<const core::ITypedExpr>> constructExprs;

  // Handle groupings.
  uint32_t groupingOutIdx = 0;
  for (const auto& grouping : sAgg.groupings()) {
    for (const auto& groupingExpr : grouping.grouping_expressions()) {
      // Velox's groupings are limited to be Field.
      constructExprs.emplace_back(exprConverter_->toVeloxExpr(
          groupingExpr.selection(), constructInputType));
      ++groupingOutIdx;
    }
  }

  // Get the aggregation masks. Exactly one entry is added per measure (a
  // null pointer when the measure has no usable filter) so that the masks
  // stay aligned with the aggregate expressions built below.
  std::vector<core::FieldAccessTypedExprPtr> aggregateMasks;
  aggregateMasks.reserve(measures.size());
  for (const auto& measure : measures) {
    core::FieldAccessTypedExprPtr aggregateMask;
    if (measure.has_filter()) {
      // Bind by reference to avoid copying the protobuf message.
      const auto& substraitAggMask = measure.filter();
      if (substraitAggMask.ByteSizeLong() != 0) {
        // NOTE(review): the cast yields a null mask when the filter is not a
        // plain field access; confirm whether that should be an error.
        aggregateMask =
            std::dynamic_pointer_cast<const core::FieldAccessTypedExpr>(
                exprConverter_->toVeloxExpr(
                    substraitAggMask, constructInputType));
      }
    }
    aggregateMasks.push_back(std::move(aggregateMask));
  }

  // Handle aggregations: record the function name and output type of every
  // measure and add its (possibly combined) input to the Project node.
  std::vector<std::string> aggFuncNames;
  aggFuncNames.reserve(measures.size());
  std::vector<TypePtr> aggOutTypes;
  aggOutTypes.reserve(measures.size());

  for (const auto& smea : measures) {
    const auto& aggFunction = smea.measure();
    std::string funcName = subParser_->findVeloxFunction(
        functionMap_, aggFunction.function_reference());
    aggFuncNames.emplace_back(funcName);
    aggOutTypes.emplace_back(
        toVeloxType(subParser_->parseType(aggFunction.output_type())->type));
    if (funcName == "avg") {
      // Will use row constructor to combine the sum and count columns into
      // row.
      if (aggFunction.args().size() != 2) {
        VELOX_FAIL("Final average should have two args.");
      }
      std::vector<std::shared_ptr<const core::ITypedExpr>> aggParams;
      aggParams.reserve(aggFunction.args().size());
      for (const auto& arg : aggFunction.args()) {
        aggParams.emplace_back(
            exprConverter_->toVeloxExpr(arg, constructInputType));
      }
      constructExprs.emplace_back(std::make_shared<const core::CallTypedExpr>(
          ROW({"sum", "count"}, {DOUBLE(), BIGINT()}),
          std::move(aggParams),
          "row_constructor"));
    } else {
      if (aggFunction.args().size() != 1) {
        VELOX_FAIL("Expect only one arg.");
      }
      for (const auto& arg : aggFunction.args()) {
        constructExprs.emplace_back(
            exprConverter_->toVeloxExpr(arg, constructInputType));
      }
    }
  }

  // Get the output names of row construct. They are generated before
  // nextPlanNodeId() is called for the Project node.
  const uint32_t totalOutColNum = constructExprs.size();
  std::vector<std::string> constructOutNames;
  constructOutNames.reserve(totalOutColNum);
  for (uint32_t colIdx = 0; colIdx < totalOutColNum; colIdx++) {
    constructOutNames.emplace_back(
        subParser_->makeNodeName(planNodeId_, colIdx));
  }

  // Create the row construct node.
  auto constructNode = std::make_shared<core::ProjectNode>(
      nextPlanNodeId(),
      std::move(constructOutNames),
      std::move(constructExprs),
      childNode);

  // Get the output names of Aggregate node.
  std::vector<std::string> aggOutNames;
  aggOutNames.reserve(totalOutColNum - groupingOutIdx);
  for (uint32_t idx = groupingOutIdx; idx < totalOutColNum; idx++) {
    aggOutNames.emplace_back(subParser_->makeNodeName(planNodeId_, idx));
  }

  // Build one aggregate call per measure on top of the Project output.
  const auto& constructOutType = constructNode->outputType();
  std::vector<core::CallTypedExprPtr> aggExprs;
  aggExprs.reserve(measures.size());
  for (uint32_t colIdx = groupingOutIdx; colIdx < totalOutColNum; colIdx++) {
    std::vector<std::shared_ptr<const core::ITypedExpr>> aggArgs;
    aggArgs.reserve(1);
    // Use the colIdx to access the columns after grouping columns.
    aggArgs.emplace_back(std::make_shared<const core::FieldAccessTypedExpr>(
        constructOutType->childAt(colIdx), constructOutType->names()[colIdx]));
    // Use another index to access the types and names of aggregation
    // columns.
    aggExprs.emplace_back(std::make_shared<const core::CallTypedExpr>(
        aggOutTypes[colIdx - groupingOutIdx],
        std::move(aggArgs),
        aggFuncNames[colIdx - groupingOutIdx]));
  }

  // Get the grouping expressions.
  std::vector<std::shared_ptr<const core::FieldAccessTypedExpr>> groupingExprs;
  groupingExprs.reserve(groupingOutIdx);
  for (uint32_t colIdx = 0; colIdx < groupingOutIdx; colIdx++) {
    // Velox's groupings are limited to be Field.
    groupingExprs.emplace_back(
        std::make_shared<const core::FieldAccessTypedExpr>(
            constructOutType->childAt(colIdx),
            constructOutType->names()[colIdx]));
  }

  // Create the Aggregation node.
  const bool ignoreNullKeys = false;
  std::vector<std::shared_ptr<const core::FieldAccessTypedExpr>>
      preGroupingExprs = {};
  return std::make_shared<core::AggregationNode>(
      nextPlanNodeId(),
      aggStep,
      groupingExprs,
      preGroupingExprs,
      aggOutNames,
      aggExprs,
      aggregateMasks,
      ignoreNullKeys,
      constructNode);
}

std::shared_ptr<const core::PlanNode> SubstraitVeloxPlanConverter::toVeloxAgg(
const ::substrait::AggregateRel& sAgg,
const std::shared_ptr<const core::PlanNode>& childNode,
Expand All @@ -381,7 +221,7 @@ std::shared_ptr<const core::PlanNode> SubstraitVeloxPlanConverter::toVeloxAgg(
}

// Parse measures and get the aggregate expressions.
uint32_t aggOutIdx = groupingOutIdx;
// Each measure represents one aggregate expression.
std::vector<std::shared_ptr<const core::CallTypedExpr>> aggExprs;
aggExprs.reserve(sAgg.measures().size());
for (const auto& smea : sAgg.measures()) {
Expand All @@ -395,33 +235,22 @@ std::shared_ptr<const core::PlanNode> SubstraitVeloxPlanConverter::toVeloxAgg(
}
auto aggVeloxType =
toVeloxType(subParser_->parseType(aggFunction.output_type())->type);
if (funcName == "avg") {
// Will use sum and count to calculate the partial avg.
auto sumExpr = std::make_shared<const core::CallTypedExpr>(
aggVeloxType, aggParams, "sum");
auto countExpr = std::make_shared<const core::CallTypedExpr>(
BIGINT(), aggParams, "count");
aggExprs.emplace_back(sumExpr);
aggExprs.emplace_back(countExpr);
aggOutIdx += 2;
} else {
auto aggExpr = std::make_shared<const core::CallTypedExpr>(
aggVeloxType, std::move(aggParams), funcName);
aggExprs.emplace_back(aggExpr);
aggOutIdx += 1;
}
auto aggExpr = std::make_shared<const core::CallTypedExpr>(
aggVeloxType, std::move(aggParams), funcName);
aggExprs.emplace_back(aggExpr);
}

bool ignoreNullKeys = false;
std::vector<std::shared_ptr<const core::FieldAccessTypedExpr>> aggregateMasks(
aggOutIdx - groupingOutIdx);
sAgg.measures().size());
std::vector<std::shared_ptr<const core::FieldAccessTypedExpr>>
preGroupingExprs = {};

// Get the output names of Aggregation.
std::vector<std::string> aggOutNames;
aggOutNames.reserve(aggOutIdx - groupingOutIdx);
for (int idx = groupingOutIdx; idx < aggOutIdx; idx++) {
aggOutNames.reserve(sAgg.measures().size());
for (int idx = groupingOutIdx; idx < groupingOutIdx + sAgg.measures().size();
idx++) {
aggOutNames.emplace_back(subParser_->makeNodeName(planNodeId_, idx));
}

Expand Down Expand Up @@ -819,41 +648,34 @@ std::string SubstraitVeloxPlanConverter::findFuncSpec(uint64_t id) {
return subParser_->findSubstraitFuncSpec(functionMap_, id);
}

bool SubstraitVeloxPlanConverter::needsRowConstruct(
void SubstraitVeloxPlanConverter::setPhase(
const ::substrait::AggregateRel& sAgg,
core::AggregationNode::Step& aggStep) {
if (sAgg.measures().size() == 0) {
// When only groupings exist, set the phase to be Single.
aggStep = core::AggregationNode::Step::kSingle;
return false;
return;
}
for (const auto& smea : sAgg.measures()) {
auto aggFunction = smea.measure();
std::string funcName = subParser_->findVeloxFunction(
functionMap_, aggFunction.function_reference());
// Set the aggregation phase.
switch (aggFunction.phase()) {
case ::substrait::AGGREGATION_PHASE_INITIAL_TO_INTERMEDIATE:
aggStep = core::AggregationNode::Step::kPartial;
break;
case ::substrait::AGGREGATION_PHASE_INTERMEDIATE_TO_INTERMEDIATE:
aggStep = core::AggregationNode::Step::kIntermediate;
break;
case ::substrait::AGGREGATION_PHASE_INTERMEDIATE_TO_RESULT:
aggStep = core::AggregationNode::Step::kFinal;
// Only Final Average needs row construct currently.
if (funcName == "avg") {
return true;
}
break;
case ::substrait::AGGREGATION_PHASE_INITIAL_TO_RESULT:
aggStep = core::AggregationNode::Step::kSingle;
break;
default:
throw std::runtime_error("Aggregate phase is not supported.");
}

// Use the first measure to set aggregation phase.
const auto& smea = sAgg.measures()[0];
const auto& aggFunction = smea.measure();
switch (aggFunction.phase()) {
case ::substrait::AGGREGATION_PHASE_INITIAL_TO_INTERMEDIATE:
aggStep = core::AggregationNode::Step::kPartial;
break;
case ::substrait::AGGREGATION_PHASE_INTERMEDIATE_TO_INTERMEDIATE:
aggStep = core::AggregationNode::Step::kIntermediate;
break;
case ::substrait::AGGREGATION_PHASE_INTERMEDIATE_TO_RESULT:
aggStep = core::AggregationNode::Step::kFinal;
break;
case ::substrait::AGGREGATION_PHASE_INITIAL_TO_RESULT:
aggStep = core::AggregationNode::Step::kSingle;
break;
default:
VELOX_FAIL("Aggregate phase is not supported.");
}
return false;
}

int32_t SubstraitVeloxPlanConverter::streamIsInput(
Expand Down
Loading

0 comments on commit 0494705

Please sign in to comment.