Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Register merge extract companion agg functions without suffix #468

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 2 additions & 10 deletions velox/core/PlanNode.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -233,16 +233,8 @@ bool AggregationNode::canSpill(const QueryConfig& queryConfig) const {
}
// TODO: add spilling for pre-grouped aggregation later:
// https://github.com/facebookincubator/velox/issues/3264
if ((isFinal() || isSingle()) && queryConfig.aggregationSpillEnabled()) {
return preGroupedKeys().empty();
}

if ((isIntermediate() || isPartial()) &&
queryConfig.partialAggregationSpillEnabled()) {
return preGroupedKeys().empty();
}

return false;
return (isFinal() || isSingle()) && preGroupedKeys().empty() &&
queryConfig.aggregationSpillEnabled();
}

void AggregationNode::addDetails(std::stringstream& stream) const {
Expand Down
8 changes: 0 additions & 8 deletions velox/core/PlanNode.h
Original file line number Diff line number Diff line change
Expand Up @@ -620,14 +620,6 @@ class AggregationNode : public PlanNode {
return step_ == Step::kSingle;
}

bool isIntermediate() const {
return step_ == Step::kIntermediate;
}

bool isPartial() const {
return step_ == Step::kPartial;
}

folly::dynamic serialize() const override;

static PlanNodePtr create(const folly::dynamic& obj, void* context);
Expand Down
13 changes: 1 addition & 12 deletions velox/core/QueryConfig.h
Original file line number Diff line number Diff line change
Expand Up @@ -201,11 +201,6 @@ class QueryConfig {
static constexpr const char* kAggregationSpillEnabled =
"aggregation_spill_enabled";

/// Partial aggregation spilling flag, only applies if "spill_enabled" flag is
/// set.
static constexpr const char* kPartialAggregationSpillEnabled =
"partial_aggregation_spill_enabled";

/// Join spilling flag, only applies if "spill_enabled" flag is set.
static constexpr const char* kJoinSpillEnabled = "join_spill_enabled";

Expand Down Expand Up @@ -551,17 +546,11 @@ class QueryConfig {
}

/// Returns 'is aggregation spilling enabled' flag. Must also check the
/// spillEnabled()!
/// spillEnabled()!g
bool aggregationSpillEnabled() const {
return get<bool>(kAggregationSpillEnabled, true);
}

/// Returns 'is partial aggregation spilling enabled' flag. Must also check
/// the spillEnabled()!
bool partialAggregationSpillEnabled() const {
return get<bool>(kPartialAggregationSpillEnabled, false);
}

/// Returns 'is join spilling enabled' flag. Must also check the
/// spillEnabled()!
bool joinSpillEnabled() const {
Expand Down
79 changes: 38 additions & 41 deletions velox/exec/AggregateCompanionAdapter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -245,10 +245,13 @@ bool CompanionFunctionsRegistrar::registerPartialFunction(
const core::QueryConfig& config)
-> std::unique_ptr<Aggregate> {
if (auto func = getAggregateFunctionEntry(name)) {
core::AggregationNode::Step usedStep{
core::AggregationNode::Step::kPartial};
if (!exec::isRawInput(step)) {
step = core::AggregationNode::Step::kIntermediate;
usedStep = core::AggregationNode::Step::kIntermediate;
}
auto fn = func->factory(step, argTypes, resultType, config);
auto fn =
func->factory(usedStep, argTypes, resultType, config);
VELOX_CHECK_NOT_NULL(fn);
return std::make_unique<
AggregateCompanionAdapter::PartialFunction>(
Expand Down Expand Up @@ -366,56 +369,50 @@ bool CompanionFunctionsRegistrar::registerMergeExtractFunction(
const std::string& name,
const std::vector<AggregateFunctionSignaturePtr>& signatures,
bool overwrite) {
bool registered = false;
if (CompanionSignatures::hasSameIntermediateTypesAcrossSignatures(
signatures)) {
return registerMergeExtractFunctionWithSuffix(name, signatures, overwrite);
registered |=
registerMergeExtractFunctionWithSuffix(name, signatures, overwrite);
}

auto mergeExtractSignatures =
CompanionSignatures::mergeExtractFunctionSignatures(signatures);
if (mergeExtractSignatures.empty()) {
return false;
return registered;
}

auto mergeExtractFunctionName =
CompanionSignatures::mergeExtractFunctionName(name);
return exec::registerAggregateFunction(
mergeExtractFunctionName,
std::move(mergeExtractSignatures),
[name, mergeExtractFunctionName](
core::AggregationNode::Step /*step*/,
const std::vector<TypePtr>& argTypes,
const TypePtr& resultType,
const core::QueryConfig& config)
-> std::unique_ptr<Aggregate> {
const auto& [originalResultType, _] =
resolveAggregateFunction(mergeExtractFunctionName, argTypes);
if (!originalResultType) {
// TODO: limitation -- result type must be resolveable given
// intermediate type of the original UDAF.
VELOX_UNREACHABLE(
"Signatures whose result types are not resolvable given intermediate types should have been excluded.");
}

if (auto func = getAggregateFunctionEntry(name)) {
auto fn = func->factory(
core::AggregationNode::Step::kFinal,
argTypes,
originalResultType,
config);
VELOX_CHECK_NOT_NULL(fn);
return std::make_unique<
AggregateCompanionAdapter::MergeExtractFunction>(
std::move(fn), resultType);
}
VELOX_FAIL(
"Original aggregation function {} not found: {}",
name,
mergeExtractFunctionName);
},
/*registerCompanionFunctions*/ false,
overwrite)
.mainFunction;
registered |=
exec::registerAggregateFunction(
mergeExtractFunctionName,
std::move(mergeExtractSignatures),
[name, mergeExtractFunctionName](
core::AggregationNode::Step /*step*/,
const std::vector<TypePtr>& argTypes,
const TypePtr& resultType,
const core::QueryConfig& config) -> std::unique_ptr<Aggregate> {
if (auto func = getAggregateFunctionEntry(name)) {
auto fn = func->factory(
core::AggregationNode::Step::kFinal,
argTypes,
resultType,
config);
VELOX_CHECK_NOT_NULL(fn);
return std::make_unique<
AggregateCompanionAdapter::MergeExtractFunction>(
std::move(fn), resultType);
}
VELOX_FAIL(
"Original aggregation function {} not found: {}",
name,
mergeExtractFunctionName);
},
/*registerCompanionFunctions*/ false,
overwrite)
.mainFunction;
return registered;
}

bool CompanionFunctionsRegistrar::registerExtractFunctionWithSuffix(
Expand Down
17 changes: 4 additions & 13 deletions velox/exec/GroupingSet.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -702,7 +702,6 @@ bool GroupingSet::getOutput(
}

if (hasSpilled()) {
spill();
return getOutputWithSpill(maxOutputRows, maxOutputBytes, result);
}
VELOX_CHECK(!isDistinct());
Expand Down Expand Up @@ -804,7 +803,7 @@ const HashLookup& GroupingSet::hashLookup() const {
void GroupingSet::ensureInputFits(const RowVectorPtr& input) {
// Spilling is considered if this is a final or single aggregation and
// spillPath is set.
if (spillConfig_ == nullptr) {
if (isPartial_ || spillConfig_ == nullptr) {
return;
}

Expand Down Expand Up @@ -889,7 +888,7 @@ void GroupingSet::ensureOutputFits() {
// to reserve memory for the output as we can't reclaim much memory from this
// operator itself. The output processing can reclaim memory from the other
// operator or query through memory arbitration.
if (spillConfig_ == nullptr || hasSpilled()) {
if (isPartial_ || spillConfig_ == nullptr || hasSpilled()) {
return;
}

Expand Down Expand Up @@ -939,6 +938,7 @@ void GroupingSet::spill() {
if (table_ == nullptr || table_->numDistinct() == 0) {
return;
}

if (!hasSpilled()) {
auto rows = table_->rows();
VELOX_DCHECK(pool_.trackUsage());
Expand Down Expand Up @@ -1013,16 +1013,7 @@ bool GroupingSet::getOutputWithSpill(
if (merge_ == nullptr) {
return false;
}
bool hasData = mergeNext(maxOutputRows, maxOutputBytes, result);
if (!hasData) {
// If spill has been finalized, reset merge stream and spiller. This would
// help partial aggregation replay the spilling procedure once needed again.
merge_ = nullptr;
mergeRows_ = nullptr;
mergeArgs_.clear();
spiller_ = nullptr;
}
return hasData;
return mergeNext(maxOutputRows, maxOutputBytes, result);
}

bool GroupingSet::mergeNext(
Expand Down
47 changes: 0 additions & 47 deletions velox/exec/tests/AggregationTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3592,51 +3592,4 @@ TEST_F(AggregationTest, reclaimFromCompletedAggregation) {
}
}

TEST_F(AggregationTest, reclaimFromPartialAggregation) {
const uint64_t maxQueryCapacity = 20L << 20;
std::vector<RowVectorPtr> vectors =
createVectors(rowType_, 1024, maxQueryCapacity * 2);
createDuckDbTable(vectors);
std::unique_ptr<memory::MemoryManager> memoryManager = createMemoryManager();
const auto spillDirectory = exec::test::TempDirectoryPath::create();
core::PlanNodeId partialAggNodeId;
core::PlanNodeId finalAggNodeId;
std::shared_ptr<core::QueryCtx> queryCtx =
newQueryCtx(memoryManager, executor_, kMemoryCapacity * 2);
auto task =
AssertQueryBuilder(duckDbQueryRunner_)
.spillDirectory(spillDirectory->path)
.config(core::QueryConfig::kSpillEnabled, "true")
.config(core::QueryConfig::kPartialAggregationSpillEnabled, "true")
.config(core::QueryConfig::kAggregationSpillEnabled, "true")
.config(
core::QueryConfig::kMaxPartialAggregationMemory,
std::to_string(1LL << 30)) // disable flush
.config(
core::QueryConfig::kMaxExtendedPartialAggregationMemory,
std::to_string(1LL << 30)) // disable flush
.config(
core::QueryConfig::kAbandonPartialAggregationMinPct,
"200") // avoid abandoning
.config(
core::QueryConfig::kAbandonPartialAggregationMinRows,
std::to_string(1LL << 30)) // avoid abandoning
.queryCtx(queryCtx)
.plan(PlanBuilder()
.values(vectors)
.partialAggregation({"c0"}, {"count(1)"})
.capturePlanNodeId(partialAggNodeId)
.finalAggregation()
.capturePlanNodeId(finalAggNodeId)
.planNode())
.assertResults("SELECT c0, count(1) FROM tmp GROUP BY c0");
auto taskStats = exec::toPlanStats(task->taskStats());
auto& partialStats = taskStats.at(partialAggNodeId);
auto& finalStats = taskStats.at(finalAggNodeId);
ASSERT_GT(partialStats.spilledBytes, 0);
ASSERT_GT(finalStats.spilledBytes, 0);
task.reset();
waitForAllTasksToBeDeleted();
}

} // namespace facebook::velox::exec::test
33 changes: 10 additions & 23 deletions velox/functions/prestosql/aggregates/ApproxPercentileAggregate.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -639,19 +639,6 @@ class ApproxPercentileAggregate : public exec::Aggregate {
DecodedVector decodedDigest_;

private:
bool isConstantVector(const VectorPtr& vec) {
if (vec->isConstantEncoding()) {
return true;
}
VELOX_USER_CHECK(vec->size() > 0);
for (vector_size_t i = 1; i < vec->size(); ++i) {
if (!vec->equalValueAt(vec.get(), i, 0)) {
return false;
}
}
return true;
}

template <bool kSingleGroup, bool checkIntermediateInputs>
void addIntermediateImpl(
std::conditional_t<kSingleGroup, char*, char**> group,
Expand All @@ -663,8 +650,7 @@ class ApproxPercentileAggregate : public exec::Aggregate {
if constexpr (checkIntermediateInputs) {
VELOX_USER_CHECK(rowVec);
for (int i = kPercentiles; i <= kAccuracy; ++i) {
VELOX_USER_CHECK(isConstantVector(
rowVec->childAt(i))); // spilling flats constant encoding
VELOX_USER_CHECK(rowVec->childAt(i)->isConstantEncoding());
}
for (int i = kK; i <= kMaxValue; ++i) {
VELOX_USER_CHECK(rowVec->childAt(i)->isFlatEncoding());
Expand All @@ -691,9 +677,10 @@ class ApproxPercentileAggregate : public exec::Aggregate {
}

DecodedVector percentiles(*rowVec->childAt(kPercentiles), *baseRows);
DecodedVector percentileIsArray(
*rowVec->childAt(kPercentilesIsArray), *baseRows);
DecodedVector accuracy(*rowVec->childAt(kAccuracy), *baseRows);
auto percentileIsArray =
rowVec->childAt(kPercentilesIsArray)->asUnchecked<SimpleVector<bool>>();
auto accuracy =
rowVec->childAt(kAccuracy)->asUnchecked<SimpleVector<double>>();
auto k = rowVec->childAt(kK)->asUnchecked<SimpleVector<int32_t>>();
auto n = rowVec->childAt(kN)->asUnchecked<SimpleVector<int64_t>>();
auto minValue = rowVec->childAt(kMinValue)->asUnchecked<SimpleVector<T>>();
Expand Down Expand Up @@ -723,7 +710,7 @@ class ApproxPercentileAggregate : public exec::Aggregate {
return;
}
int i = decoded.index(row);
if (percentileIsArray.isNullAt(i)) {
if (percentileIsArray->isNullAt(i)) {
return;
}
if (!accumulator) {
Expand All @@ -733,19 +720,19 @@ class ApproxPercentileAggregate : public exec::Aggregate {
percentilesBase->elements()->asFlatVector<double>();
if constexpr (checkIntermediateInputs) {
VELOX_USER_CHECK(percentileBaseElements);
VELOX_USER_CHECK(!percentiles.isNullAt(indexInBaseVector));
VELOX_USER_CHECK(!percentilesBase->isNullAt(indexInBaseVector));
}

bool isArray = percentileIsArray.valueAt<bool>(i);
bool isArray = percentileIsArray->valueAt(i);
const double* data;
vector_size_t len;
std::vector<bool> isNull;
extractPercentiles(
percentilesBase, indexInBaseVector, data, len, isNull);
checkSetPercentile(isArray, data, len, isNull);

if (!accuracy.isNullAt(i)) {
checkSetAccuracy(accuracy.valueAt<double>(i));
if (!accuracy->isNullAt(i)) {
checkSetAccuracy(accuracy->valueAt(i));
}
}
if constexpr (kSingleGroup) {
Expand Down
7 changes: 3 additions & 4 deletions velox/functions/sparksql/aggregates/AverageAggregate.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -444,10 +444,9 @@ exec::AggregateRegistrationResult registerAverage(
if (inputType->isShortDecimal()) {
auto inputPrecision = inputType->asShortDecimal().precision();
auto inputScale = inputType->asShortDecimal().scale();
auto sumType = getDecimalSumType(inputPrecision, inputScale);
if (exec::isPartialOutput(step) ||
(step == core::AggregationNode::Step::kSingle &&
resultType->isRow())) {
auto sumType =
DECIMAL(std::min(38, inputPrecision + 10), inputScale);
if (exec::isPartialOutput(step)) {
return std::make_unique<
DecimalAverageAggregate<int64_t, int64_t>>(
resultType, sumType);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -288,7 +288,9 @@ class BloomFilterAggAggregate : public exec::Aggregate {
} // namespace

exec::AggregateRegistrationResult registerBloomFilterAggAggregate(
const std::string& name) {
const std::string& name,
bool withCompanionFunctions,
bool overwrite) {
std::vector<std::shared_ptr<exec::AggregateFunctionSignature>> signatures{
exec::AggregateFunctionSignatureBuilder()
.argumentType("bigint")
Expand Down Expand Up @@ -318,6 +320,8 @@ exec::AggregateRegistrationResult registerBloomFilterAggAggregate(
const TypePtr& resultType,
const core::QueryConfig& config) -> std::unique_ptr<exec::Aggregate> {
return std::make_unique<BloomFilterAggAggregate>(resultType, config);
});
},
withCompanionFunctions,
overwrite);
}
} // namespace facebook::velox::functions::aggregate::sparksql
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
namespace facebook::velox::functions::aggregate::sparksql {

exec::AggregateRegistrationResult registerBloomFilterAggAggregate(
const std::string& name);
const std::string& name,
bool withCompanionFunctions,
bool overwrite);

} // namespace facebook::velox::functions::aggregate::sparksql
Loading
Loading