Skip to content

Commit

Permalink
Fix AggregateRel validate rel without group and functions and functio…
Browse files Browse the repository at this point in the history
…n signature not match (oap-project#87)
  • Loading branch information
jinchengchenghh authored and zhejiangxiaomai committed Apr 20, 2023
1 parent 12a6c05 commit cc36909
Show file tree
Hide file tree
Showing 11 changed files with 282 additions and 30 deletions.
15 changes: 14 additions & 1 deletion velox/substrait/SubstraitParser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,16 @@ std::shared_ptr<SubstraitParser::SubstraitType> SubstraitParser::parseType(
return std::make_shared<SubstraitType>(type);
}

std::string SubstraitParser::parseType(
const std::string& substraitType) {
auto it = typeMap_.find(substraitType);
if (it == typeMap_.end()) {
VELOX_NYI(
"Substrait parsing for type {} not supported.", substraitType);
}
return it->second;
};

std::vector<std::shared_ptr<SubstraitParser::SubstraitType>>
SubstraitParser::parseNamedStruct(const ::substrait::NamedStruct& namedStruct) {
// Nte that "names" are not used.
Expand Down Expand Up @@ -275,7 +285,10 @@ void SubstraitParser::getSubFunctionTypes(
// Split the types with delimiter.
std::string delimiter = "_";
while ((pos = funcTypes.find(delimiter)) != std::string::npos) {
types.emplace_back(funcTypes.substr(0, pos));
auto type = funcTypes.substr(0, pos);
if (type != "opt" && type !="req") {
types.emplace_back(type);
}
funcTypes.erase(0, pos + delimiter.length());
}
types.emplace_back(funcTypes);
Expand Down
18 changes: 18 additions & 0 deletions velox/substrait/SubstraitParser.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,9 @@ class SubstraitParser {
std::shared_ptr<SubstraitType> parseType(
const ::substrait::Type& substraitType);

// Parse substraitType type such as i32.
std::string parseType(const std::string& substraitType);

/// Parse Substrait ReferenceSegment.
int32_t parseReferenceSegment(
const ::substrait::Expression::ReferenceSegment& refSegment);
Expand Down Expand Up @@ -110,6 +113,21 @@ class SubstraitParser {
{"ends_with", "endswith"},
{"starts_with", "startswith"},
{"modulus", "mod"} /*Presto functions.*/};
// The map is uesd for mapping substrait type.
// Key: type in function name.
// Value: substrait type name.
const std::unordered_map<std::string, std::string> typeMap_ = {
{"bool", "BOOLEAN"},
{"i8", "TINYINT"},
{"i16", "SMALLINT"},
{"i32", "INTEGER"},
{"i64", "BIGINT"},
{"fp32", "REAL"},
{"fp64", "DOUBLE"},
{"date", "DATE"},
{"ts", "TIMESTAMP_TZ"},
{"str", "VARCHAR"},
{"vbin", "VARBINARY"}};
};

} // namespace facebook::velox::substrait
47 changes: 24 additions & 23 deletions velox/substrait/SubstraitToVeloxPlan.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,29 +24,6 @@

namespace facebook::velox::substrait {
namespace {
core::AggregationNode::Step toAggregationStep(
const ::substrait::AggregateRel& sAgg) {
if (sAgg.measures().size() == 0) {
// When only groupings exist, set the phase to be Single.
return core::AggregationNode::Step::kSingle;
}

// Use the first measure to set aggregation phase.
const auto& firstMeasure = sAgg.measures()[0];
const auto& aggFunction = firstMeasure.measure();
switch (aggFunction.phase()) {
case ::substrait::AGGREGATION_PHASE_INITIAL_TO_INTERMEDIATE:
return core::AggregationNode::Step::kPartial;
case ::substrait::AGGREGATION_PHASE_INTERMEDIATE_TO_INTERMEDIATE:
return core::AggregationNode::Step::kIntermediate;
case ::substrait::AGGREGATION_PHASE_INTERMEDIATE_TO_RESULT:
return core::AggregationNode::Step::kFinal;
case ::substrait::AGGREGATION_PHASE_INITIAL_TO_RESULT:
return core::AggregationNode::Step::kSingle;
default:
VELOX_FAIL("Aggregate phase is not supported.");
}
}

core::SortOrder toSortOrder(const ::substrait::SortField& sortField) {
switch (sortField.direction()) {
Expand Down Expand Up @@ -250,6 +227,30 @@ core::PlanNodePtr SubstraitVeloxPlanConverter::processEmit(
}
}

core::AggregationNode::Step SubstraitVeloxPlanConverter::toAggregationStep(
const ::substrait::AggregateRel& sAgg) {
if (sAgg.measures().size() == 0) {
// When only groupings exist, set the phase to be Single.
return core::AggregationNode::Step::kSingle;
}

// Use the first measure to set aggregation phase.
const auto& firstMeasure = sAgg.measures()[0];
const auto& aggFunction = firstMeasure.measure();
switch (aggFunction.phase()) {
case ::substrait::AGGREGATION_PHASE_INITIAL_TO_INTERMEDIATE:
return core::AggregationNode::Step::kPartial;
case ::substrait::AGGREGATION_PHASE_INTERMEDIATE_TO_INTERMEDIATE:
return core::AggregationNode::Step::kIntermediate;
case ::substrait::AGGREGATION_PHASE_INTERMEDIATE_TO_RESULT:
return core::AggregationNode::Step::kFinal;
case ::substrait::AGGREGATION_PHASE_INITIAL_TO_RESULT:
return core::AggregationNode::Step::kSingle;
default:
VELOX_FAIL("Aggregate phase is not supported.");
}
}

core::PlanNodePtr SubstraitVeloxPlanConverter::toVeloxPlan(
const ::substrait::JoinRel& sJoin) {
if (!sJoin.has_left()) {
Expand Down
7 changes: 5 additions & 2 deletions velox/substrait/SubstraitToVeloxPlan.h
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,10 @@ class SubstraitVeloxPlanConverter {
std::vector<const ::substrait::Expression::FieldReference*>& leftExprs,
std::vector<const ::substrait::Expression::FieldReference*>& rightExprs);

/// Get aggregation step from AggregateRel.
core::AggregationNode::Step toAggregationStep(
const ::substrait::AggregateRel& sAgg);

private:
/// Range filter recorder for a field is used to make sure only the conditions
/// that can coexist for this field being pushed down with a range filter.
Expand Down Expand Up @@ -480,8 +484,7 @@ class SubstraitVeloxPlanConverter {
remainingFunctions,
const std::vector<::substrait::Expression_SingularOrList>&
singularOrLists,
const std::vector<::substrait::Expression_IfThen>&
ifThens);
const std::vector<::substrait::Expression_IfThen>& ifThens);

/// Connect the left and right expressions with 'and' relation.
core::TypedExprPtr connectWithAnd(
Expand Down
89 changes: 85 additions & 4 deletions velox/substrait/SubstraitToVeloxPlanValidator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

#include "velox/substrait/SubstraitToVeloxPlanValidator.h"
#include "TypeUtils.h"
#include "velox/expression/SignatureBinder.h"

namespace facebook::velox::substrait {

Expand All @@ -41,8 +42,8 @@ bool SubstraitToVeloxPlanValidator::validateInputTypes(
try {
types.emplace_back(toVeloxType(subParser_->parseType(sType)->type));
} catch (const VeloxException& err) {
std::cout << "Type is not supported in ProjectRel due to:"
<< err.message() << std::endl;
std::cout << "Type is not supported due to:" << err.message()
<< std::endl;
return false;
}
}
Expand Down Expand Up @@ -356,6 +357,62 @@ bool SubstraitToVeloxPlanValidator::validate(
return true;
}

bool SubstraitToVeloxPlanValidator::validateAggRelFunctionType(
const ::substrait::AggregateRel& sAgg) {
if (sAgg.measures_size() == 0) {
return true;
}
core::AggregationNode::Step step = planConverter_->toAggregationStep(sAgg);
for (const auto& smea : sAgg.measures()) {
const auto& aggFunction = smea.measure();
auto funcSpec =
planConverter_->findFuncSpec(aggFunction.function_reference());
auto funcName = subParser_->getSubFunctionName(funcSpec);
std::vector<TypePtr> types;
try {
std::vector<std::string> funcTypes;
subParser_->getSubFunctionTypes(funcSpec, funcTypes);
types.reserve(funcTypes.size());
for (auto& type : funcTypes) {
types.emplace_back(toVeloxType(subParser_->parseType(type)));
}
} catch (const VeloxException& err) {
std::cout
<< "Validation failed for input type in AggregateRel function due to:"
<< err.message() << std::endl;
return false;
}
if (auto signatures = exec::getAggregateFunctionSignatures(funcName)) {
for (const auto& signature : signatures.value()) {
exec::SignatureBinder binder(*signature, types);
if (binder.tryBind()) {
auto resolveType = binder.tryResolveType(
exec::isPartialOutput(step) ? signature->intermediateType()
: signature->returnType());
if (resolveType == nullptr) {
std::cout
<< fmt::format(
"Validation failed for function {} resolve type in AggregateRel.",
funcName)
<< std::endl;
return false;
}
return true;
}
}
std::cout
<< fmt::format(
"Validation failed for function {} bind in AggregateRel.",
funcName)
<< std::endl;
return false;
}
}
std::cout << "Validation failed for function resolve in AggregateRel."
<< std::endl;
return false;
}

bool SubstraitToVeloxPlanValidator::validate(
const ::substrait::AggregateRel& sAgg) {
if (sAgg.has_input() && !validate(sAgg.input())) {
Expand All @@ -364,10 +421,10 @@ bool SubstraitToVeloxPlanValidator::validate(

// Validate input types.
if (sAgg.has_advanced_extension()) {
const auto& extension = sAgg.advanced_extension();
std::vector<TypePtr> types;
const auto& extension = sAgg.advanced_extension();
if (!validateInputTypes(extension, types)) {
std::cout << "Validation failed for input types in AggregateRel"
std::cout << "Validation failed for input types in AggregateRel."
<< std::endl;
return false;
}
Expand Down Expand Up @@ -425,6 +482,30 @@ bool SubstraitToVeloxPlanValidator::validate(
return false;
}
}

if (!validateAggRelFunctionType(sAgg)) {
return false;
}

// Validate both groupby and aggregates input are empty, which is corner case.
if (sAgg.measures_size() == 0) {
bool hasExpr = false;
for (const auto& grouping : sAgg.groupings()) {
for (const auto& groupingExpr : grouping.grouping_expressions()) {
hasExpr = true;
break;
}
if (hasExpr) {
break;
}
}
if (!hasExpr) {
std::cout
<< "Validation failed due to aggregation must specify either grouping keys or aggregates."
<< std::endl;
return false;
}
}
return true;
}

Expand Down
2 changes: 2 additions & 0 deletions velox/substrait/SubstraitToVeloxPlanValidator.h
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,8 @@ class SubstraitToVeloxPlanValidator {
bool validateInputTypes(
const ::substrait::extensions::AdvancedExtension& extension,
std::vector<TypePtr>& types);

bool validateAggRelFunctionType(const ::substrait::AggregateRel& sAgg);
};

} // namespace facebook::velox::substrait
1 change: 1 addition & 0 deletions velox/substrait/tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ add_executable(
velox_plan_conversion_test
Substrait2VeloxPlanConversionTest.cpp
Substrait2VeloxValuesNodeConversionTest.cpp
Substrait2VeloxPlanValidatorTest.cpp
FunctionTest.cpp
JsonToProtoConverter.cpp
VeloxSubstraitRoundTripTest.cpp
Expand Down
16 changes: 16 additions & 0 deletions velox/substrait/tests/FunctionTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -208,3 +208,19 @@ TEST_F(FunctionTest, streamIsInput) {
int index = planConverter_->streamIsInput(substraitRel.read());
ASSERT_EQ(index, 0);
}

TEST_F(FunctionTest, getFunctionType) {
std::vector<std::string> types;
substraitParser_->getSubFunctionTypes("sum:opt_i32", types);
ASSERT_EQ("i32", types[0]);

types.clear();
substraitParser_->getSubFunctionTypes("sum:i32", types);
ASSERT_EQ("i32", types[0]);

types.clear();
substraitParser_->getSubFunctionTypes("sum:opt_str_str", types);
ASSERT_EQ(2, types.size());
ASSERT_EQ("str", types[0]);
ASSERT_EQ("str", types[1]);
}
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
#include "velox/exec/tests/utils/PlanBuilder.h"
#include "velox/exec/tests/utils/TempDirectoryPath.h"
#include "velox/substrait/SubstraitToVeloxPlan.h"
#include "velox/substrait/SubstraitToVeloxPlanValidator.h"
#include "velox/type/Type.h"

using namespace facebook::velox;
Expand Down
82 changes: 82 additions & 0 deletions velox/substrait/tests/Substrait2VeloxPlanValidatorTest.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "velox/substrait/tests/JsonToProtoConverter.h"

#include "velox/common/base/tests/GTestUtils.h"
#include "velox/dwio/common/tests/utils/DataFiles.h"
#include "velox/exec/tests/utils/AssertQueryBuilder.h"
#include "velox/exec/tests/utils/HiveConnectorTestBase.h"
#include "velox/exec/tests/utils/PlanBuilder.h"
#include "velox/exec/tests/utils/TempDirectoryPath.h"
#include "velox/substrait/SubstraitToVeloxPlan.h"
#include "velox/substrait/SubstraitToVeloxPlanValidator.h"
#include "velox/type/Type.h"

using namespace facebook::velox;
using namespace facebook::velox::test;
using namespace facebook::velox::connector::hive;
using namespace facebook::velox::exec;
namespace vestrait = facebook::velox::substrait;

class Substrait2VeloxPlanConversionTest
: public exec::test::HiveConnectorTestBase {
protected:

std::shared_ptr<vestrait::SubstraitVeloxPlanConverter> planConverter_ =
std::make_shared<vestrait::SubstraitVeloxPlanConverter>(
memoryPool_.get());

bool validatePlan(std::string file) {
std::string subPlanPath = getDataFilePath("velox/substrait/tests", file);

::substrait::Plan substraitPlan;
JsonToProtoConverter::readFromFile(subPlanPath, substraitPlan);
return validatePlan(substraitPlan);
}

bool validatePlan(::substrait::Plan& plan) {
std::shared_ptr<core::QueryCtx> queryCtx =
std::make_shared<core::QueryCtx>();

// An execution context used for function validation.
std::unique_ptr<core::ExecCtx> execCtx =
std::make_unique<core::ExecCtx>(pool_.get(), queryCtx.get());

auto planValidator = std::make_shared<
facebook::velox::substrait::SubstraitToVeloxPlanValidator>(
pool_.get(), execCtx.get());
return planValidator->validate(plan);
}

private:
std::shared_ptr<memory::MemoryPool> memoryPool_{
memory::getDefaultMemoryPool()};
};

TEST_F(Substrait2VeloxPlanConversionTest, group) {
std::string subPlanPath =
getDataFilePath("velox/substrait/tests", "group.json");

::substrait::Plan substraitPlan;
JsonToProtoConverter::readFromFile(subPlanPath, substraitPlan);

ASSERT_FALSE(validatePlan(substraitPlan));
// Convert to Velox PlanNode.
facebook::velox::substrait::SubstraitVeloxPlanConverter planConverter(
pool_.get());
EXPECT_ANY_THROW(planConverter.toVeloxPlan(substraitPlan));
}
Loading

0 comments on commit cc36909

Please sign in to comment.