diff --git a/velox/substrait/SubstraitToVeloxPlan.cpp b/velox/substrait/SubstraitToVeloxPlan.cpp index 9c3673934e1f4..2b389afa71bf5 100644 --- a/velox/substrait/SubstraitToVeloxPlan.cpp +++ b/velox/substrait/SubstraitToVeloxPlan.cpp @@ -376,21 +376,18 @@ std::shared_ptr SubstraitVeloxPlanConverter::toVeloxAgg( aggregateMasks.reserve(aggRel.measures().size()); for (const auto& smea : aggRel.measures()) { core::FieldAccessTypedExprPtr aggregateMask; - ::substrait::Expression substraitAggMask = smea.filter(); // Get Aggregation Masks. - if (smea.has_filter()) { - if (substraitAggMask.ByteSizeLong() == 0) { - aggregateMask = {}; - } else { - aggregateMask = - std::dynamic_pointer_cast( - exprConverter_->toVeloxExpr(substraitAggMask, inputType)); - VELOX_CHECK( - aggregateMask != nullptr, - " the agg filter expression in Aggregate Operator only support field"); - } - aggregateMasks.push_back(aggregateMask); + if (!smea.has_filter()) { + aggregateMask = {}; + } else { + aggregateMask = + std::dynamic_pointer_cast( + exprConverter_->toVeloxExpr(smea.filter(), inputType)); + VELOX_CHECK( + aggregateMask != nullptr, + " the agg filter expression in Aggregate Operator only support field"); } + aggregateMasks.push_back(aggregateMask); const auto& aggFunction = smea.measure(); std::string funcName = subParser_->findVeloxFunction( functionMap_, aggFunction.function_reference()); diff --git a/velox/substrait/SubstraitToVeloxPlanValidator.cpp b/velox/substrait/SubstraitToVeloxPlanValidator.cpp index 81a17fb1f3495..caf6ae392f012 100644 --- a/velox/substrait/SubstraitToVeloxPlanValidator.cpp +++ b/velox/substrait/SubstraitToVeloxPlanValidator.cpp @@ -718,9 +718,15 @@ bool SubstraitToVeloxPlanValidator::validate( } const auto& aggFunction = smea.measure(); - funcSpecs.emplace_back( - planConverter_->findFuncSpec(aggFunction.function_reference())); + const auto& functionSpec = planConverter_->findFuncSpec(aggFunction.function_reference()); + funcSpecs.emplace_back(functionSpec); toVeloxType(subParser_->parseType(aggFunction.output_type())->type); + // Validate the size of arguments. + if (subParser_->getSubFunctionName(functionSpec) == "count" && + aggFunction.arguments().size() > 1) { + // Count accepts only one argument. + return false; + } for (const auto& arg : aggFunction.arguments()) { auto typeCase = arg.value().rex_type_case(); switch (typeCase) {