diff --git a/velox/connectors/Connector.h b/velox/connectors/Connector.h index 89994c2d68b24..777231c8d94b3 100644 --- a/velox/connectors/Connector.h +++ b/velox/connectors/Connector.h @@ -19,6 +19,9 @@ #include "velox/core/Context.h" #include "velox/vector/ComplexVector.h" +namespace facebook::velox::common { +class Filter; +} namespace facebook::velox::core { class ITypedExpr; } @@ -85,6 +88,14 @@ class DataSource { // processed. virtual RowVectorPtr next(uint64_t size) = 0; + // Add dynamically generated filter. + // @param outputChannel index into outputType specified in + // Connector::createDataSource() that identifies the column this filter + // applies to. + virtual void addDynamicFilter( + ChannelIndex outputChannel, + const std::shared_ptr& filter) = 0; + // Returns the number of input bytes processed so far. virtual uint64_t getCompletedBytes() = 0; diff --git a/velox/connectors/hive/HiveConnector.cpp b/velox/connectors/hive/HiveConnector.cpp index 923f22b4ebe18..cb772aebc6e4f 100644 --- a/velox/connectors/hive/HiveConnector.cpp +++ b/velox/connectors/hive/HiveConnector.cpp @@ -236,6 +236,26 @@ bool testFilters( } } // namespace +void HiveDataSource::addDynamicFilter( + ChannelIndex outputChannel, + const std::shared_ptr& filter) { + pendingDynamicFilters_.emplace(outputChannel, filter); +} + +void HiveDataSource::addPendingDynamicFilters() { + for (const auto& entry : pendingDynamicFilters_) { + common::Subfield subfield{outputType_->nameOf(entry.first)}; + auto fieldSpec = scanSpec_->getOrCreateChild(subfield); + if (fieldSpec->filter()) { + fieldSpec->filter()->mergeWith(entry.second.get()); + } else { + fieldSpec->setFilter(entry.second->clone()); + } + } + scanSpec_->resetCachedValues(); + pendingDynamicFilters_.clear(); +}; + void HiveDataSource::addSplit(std::shared_ptr split) { VELOX_CHECK( split_ == nullptr, @@ -245,6 +265,8 @@ void HiveDataSource::addSplit(std::shared_ptr split) { VLOG(1) << "Adding split " << split_->toString(); + 
addPendingDynamicFilters(); + fileHandle_ = fileHandleFactory_->generate(split_->filePath); if (dataCache_) { auto dataCacheConfig = std::make_shared(); diff --git a/velox/connectors/hive/HiveConnector.h b/velox/connectors/hive/HiveConnector.h index 3994bed9fa3b9..976129a9489ca 100644 --- a/velox/connectors/hive/HiveConnector.h +++ b/velox/connectors/hive/HiveConnector.h @@ -128,6 +128,10 @@ class HiveDataSource : public DataSource { void addSplit(std::shared_ptr split) override; + void addDynamicFilter( + ChannelIndex outputChannel, + const std::shared_ptr& filter) override; + RowVectorPtr next(uint64_t size) override; uint64_t getCompletedRows() override { @@ -157,12 +161,14 @@ class HiveDataSource : public DataSource { void setNullConstantValue(common::ScanSpec* spec, const TypePtr& type) const; + void addPendingDynamicFilters(); + const std::shared_ptr outputType_; FileHandleFactory* fileHandleFactory_; velox::memory::MemoryPool* pool_; std::vector regularColumns_; std::unique_ptr columnReaderFactory_; - std::unique_ptr scanSpec_ = nullptr; + std::unique_ptr scanSpec_; std::shared_ptr split_; dwio::common::ReaderOptions readerOpts_; dwio::common::RowReaderOptions rowReaderOpts_; @@ -173,6 +179,10 @@ class HiveDataSource : public DataSource { std::shared_ptr readerOutputType_; bool emptySplit_; + // Dynamically pushed down filters to be added to scanSpec_ on next split. + std::unordered_map> + pendingDynamicFilters_; + // Number of splits skipped based on statistics. int64_t skippedSplits_{0}; diff --git a/velox/dwio/dwrf/reader/ScanSpec.h b/velox/dwio/dwrf/reader/ScanSpec.h index 64a565d46034a..a2624ffe0ab41 100644 --- a/velox/dwio/dwrf/reader/ScanSpec.h +++ b/velox/dwio/dwrf/reader/ScanSpec.h @@ -233,6 +233,15 @@ class ScanSpec { // result of runtime adaptation. bool hasFilter() const; + // Resets cached values after this or children were updated, e.g. a new filter + // was added or existing filter was modified. 
+ void resetCachedValues() const { + hasFilter_.clear(); + for (auto& child : children_) { + child->resetCachedValues(); + } + } + void setEnableFilterReorder(bool enableFilterReorder) { enableFilterReorder_ = enableFilterReorder; } diff --git a/velox/dwio/dwrf/reader/SelectiveColumnReader.cpp b/velox/dwio/dwrf/reader/SelectiveColumnReader.cpp index d5d2406d03f96..11f6197c51f43 100644 --- a/velox/dwio/dwrf/reader/SelectiveColumnReader.cpp +++ b/velox/dwio/dwrf/reader/SelectiveColumnReader.cpp @@ -118,7 +118,7 @@ void SelectiveColumnReader::seekTo(vector_size_t offset, bool readsNullsOnly) { } readOffset_ = offset; } else { - VELOX_CHECK(false, "Seeking backward on a ColumnReader"); + VELOX_FAIL("Seeking backward on a ColumnReader"); } } @@ -219,7 +219,7 @@ void SelectiveColumnReader::getFlatValues( VectorPtr* result, const TypePtr& type, bool isFinal) { - VELOX_CHECK(valueSize_ != kNoValueSize); + VELOX_CHECK_NE(valueSize_, kNoValueSize); VELOX_CHECK(mayGetValues_); if (isFinal) { mayGetValues_ = false; @@ -258,7 +258,7 @@ void SelectiveColumnReader::getFlatValues( bool isFinal) { constexpr int32_t kWidth = V8::VSize; static_assert(kWidth == 32); - VELOX_CHECK(valueSize_ == sizeof(int8_t)); + VELOX_CHECK_EQ(valueSize_, sizeof(int8_t)); compactScalarValues(rows, isFinal); auto boolValues = AlignedBuffer::allocate(numValues_, &memoryPool, false); @@ -283,12 +283,12 @@ void SelectiveColumnReader::getFlatValues( template void SelectiveColumnReader::upcastScalarValues(RowSet rows) { - VELOX_CHECK(rows.size() <= numValues_); + VELOX_CHECK_LE(rows.size(), numValues_); VELOX_CHECK(!rows.empty()); if (!values_) { return; } - VELOX_CHECK(sizeof(TVector) > sizeof(T)); + VELOX_CHECK_GT(sizeof(TVector), sizeof(T)); // Since upcast is not going to be a common path, allocate buffer to copy // upcasted values to and then copy back to the values buffer. 
std::vector buf; @@ -338,7 +338,7 @@ void SelectiveColumnReader::upcastScalarValues(RowSet rows) { template void SelectiveColumnReader::compactScalarValues(RowSet rows, bool isFinal) { - VELOX_CHECK(rows.size() <= numValues_); + VELOX_CHECK_LE(rows.size(), numValues_); VELOX_CHECK(!rows.empty()); if (!values_ || (rows.size() == numValues_ && sizeof(T) == sizeof(TVector))) { if (values_) { @@ -346,7 +346,7 @@ void SelectiveColumnReader::compactScalarValues(RowSet rows, bool isFinal) { } return; } - VELOX_CHECK(sizeof(TVector) <= sizeof(T)); + VELOX_CHECK_LE(sizeof(TVector), sizeof(T)); T* typedSourceValues = reinterpret_cast(rawValues_); TVector* typedDestValues = reinterpret_cast(rawValues_); RowSet sourceRows; @@ -784,7 +784,7 @@ class ColumnVisitor { return 0; } if (nextNonNull < 64) { - VELOX_CHECK(rowIndex_ <= rowOfNullWord + nextNonNull); + VELOX_CHECK_LE(rowIndex_, rowOfNullWord + nextNonNull); rowIndex_ = rowOfNullWord + nextNonNull; current = currentRow(); return 0; @@ -1101,9 +1101,8 @@ class SelectiveByteRleColumnReader : public SelectiveColumnReader { getFlatValues(rows, result); break; default: - VELOX_CHECK( - false, - "Result type {} not supported in ByteRLE encoding", + VELOX_FAIL( + "Result type not supported in ByteRLE encoding: {}", requestedType_->toString()); } } @@ -2081,7 +2080,7 @@ void SelectiveIntegerDictionaryColumnReader::readWithVisitor( RowSet rows, ColumnVisitor visitor) { vector_size_t numRows = rows.back() + 1; - VELOX_CHECK(rleVersion_ == RleVersion_1); + VELOX_CHECK_EQ(rleVersion_, RleVersion_1); auto reader = reinterpret_cast*>(dataReader_.get()); if (nullsInReadRange_) { reader->readWithVisitor(nullsInReadRange_->as(), visitor); @@ -3595,8 +3594,9 @@ static void scatter(RowSet rows, VectorPtr* result) { } void ColumnLoader::load(RowSet rows, ValueHook* hook, VectorPtr* result) { - VELOX_CHECK( - version_ == structReader_->numReads(), + VELOX_CHECK_EQ( + version_, + structReader_->numReads(), "Loading LazyVector after the 
enclosing reader has moved"); auto offset = structReader_->lazyVectorReadOffset(); auto incomingNulls = structReader_->nulls(); @@ -4204,7 +4204,7 @@ std::unique_ptr SelectiveColumnReader::build( case TypeKind::MAP: if (stripe.getEncoding(ek).kind() == proto::ColumnEncoding_Kind_MAP_FLAT) { - VELOX_CHECK(false, "SelectiveColumnReader does not support flat maps"); + VELOX_UNSUPPORTED("SelectiveColumnReader does not support flat maps"); } return std::make_unique( ek, requestedType, dataType, stripe, scanSpec); diff --git a/velox/exec/Driver.cpp b/velox/exec/Driver.cpp index 46556c0775322..5c8469b2e167e 100644 --- a/velox/exec/Driver.cpp +++ b/velox/exec/Driver.cpp @@ -216,6 +216,68 @@ Driver::Driver( ctx_->driver = this; } +namespace { +/// Checks if output channel is produced using identity projection and returns +/// input channel if so. +std::optional getIdentityProjection( + const std::vector& projections, + ChannelIndex outputChannel) { + for (const auto& projection : projections) { + if (projection.outputChannel == outputChannel) { + return projection.inputChannel; + } + } + return std::nullopt; +} +} // namespace + +void Driver::pushdownFilters(int operatorIndex) { + auto op = operators_[operatorIndex].get(); + const auto& filters = op->getDynamicFilters(); + if (filters.empty()) { + return; + } + + op->stats().addRuntimeStat("dynamicFiltersProduced", filters.size()); + + // Walk operator list upstream and find a place to install the filters. + for (const auto& entry : filters) { + auto channel = entry.first; + for (auto i = operatorIndex - 1; i >= 0; --i) { + auto prevOp = operators_[i].get(); + + if (i == 0) { + // Source operator. 
+ VELOX_CHECK( + prevOp->canAddDynamicFilter(), + "Cannot push down dynamic filters produced by {}", + op->toString()); + prevOp->addDynamicFilter(channel, entry.second); + prevOp->stats().addRuntimeStat("dynamicFiltersAccepted", 1); + break; + } + + const auto& identityProjections = prevOp->identityProjections(); + auto inputChannel = getIdentityProjection(identityProjections, channel); + if (!inputChannel.has_value()) { + // Filter channel is not an identity projection. + VELOX_CHECK( + prevOp->canAddDynamicFilter(), + "Cannot push down dynamic filters produced by {}", + op->toString()); + prevOp->addDynamicFilter(channel, entry.second); + prevOp->stats().addRuntimeStat("dynamicFiltersAccepted", 1); + break; + } + + // Continue walking upstream. + channel = inputChannel.value(); + } + } + + op->clearDynamicFilters(); +} + core::StopReason Driver::runInternal( std::shared_ptr& self, std::shared_ptr* blockingState) { @@ -296,6 +358,7 @@ core::StopReason Driver::runInternal( op->stats().outputBytes += resultBytes; } } + pushdownFilters(i); if (result) { OperationTimer timer(nextOp->stats().addInputTiming); nextOp->stats().inputPositions += result->size(); @@ -338,8 +401,11 @@ core::StopReason Driver::runInternal( // control here so it can advance. If it is again blocked, // this will be detected when trying to add input and we // will come back here after this is again on thread. 
- OperationTimer timer(op->stats().getOutputTiming); - op->getOutput(); + { + OperationTimer timer(op->stats().getOutputTiming); + op->getOutput(); + } + pushdownFilters(i); continue; } if (i == 0) { @@ -437,7 +503,7 @@ bool Driver::terminate() { return false; } -bool Driver::mayPushdownAggregation(Operator* aggregation) { +bool Driver::mayPushdownAggregation(Operator* aggregation) const { for (auto i = 1; i < operators_.size(); ++i) { auto op = operators_[i].get(); if (aggregation == op) { @@ -447,8 +513,58 @@ bool Driver::mayPushdownAggregation(Operator* aggregation) { return false; } } - VELOX_CHECK(false, "{} not found in its Driver", aggregation->toString()); - return false; + VELOX_FAIL( + "Aggregation operator not found in its Driver: {}", + aggregation->toString()); +} + +std::unordered_set Driver::canPushdownFilters( + Operator* FOLLY_NONNULL filterSource, + const std::vector& channels) const { + int filterSourceIndex = -1; + for (auto i = 0; i < operators_.size(); ++i) { + auto op = operators_[i].get(); + if (filterSource == op) { + filterSourceIndex = i; + break; + } + } + VELOX_CHECK_GE( + filterSourceIndex, + 0, + "Operator not found in its Driver: {}", + filterSource->toString()); + + std::unordered_set supportedChannels; + for (auto i = 0; i < channels.size(); ++i) { + auto channel = channels[i]; + for (auto j = filterSourceIndex - 1; j >= 0; --j) { + auto prevOp = operators_[j].get(); + + if (j == 0) { + // Source operator. + if (prevOp->canAddDynamicFilter()) { + supportedChannels.emplace(channels[i]); + } + break; + } + + const auto& identityProjections = prevOp->identityProjections(); + auto inputChannel = getIdentityProjection(identityProjections, channel); + if (!inputChannel.has_value()) { + // Filter channel is not an identity projection. + if (prevOp->canAddDynamicFilter()) { + supportedChannels.emplace(channels[i]); + } + break; + } + + // Continue walking upstream. 
+ channel = inputChannel.value(); + } + } + + return supportedChannels; } Operator* FOLLY_NULLABLE diff --git a/velox/exec/Driver.h b/velox/exec/Driver.h index b0adf7765414f..6996d38c38b8f 100644 --- a/velox/exec/Driver.h +++ b/velox/exec/Driver.h @@ -154,7 +154,13 @@ class Driver { // Returns true if all operators between the source and 'aggregation' are // order-preserving and do not increase cardinality. - bool mayPushdownAggregation(Operator* FOLLY_NONNULL aggregation); + bool mayPushdownAggregation(Operator* FOLLY_NONNULL aggregation) const; + + // Returns a subset of channels for which there are operators upstream from + // filterSource that accept dynamically generated filters. + std::unordered_set canPushdownFilters( + Operator* FOLLY_NONNULL filterSource, + const std::vector& channels) const; // Returns the Operator with 'planNodeId.' or nullptr if not // found. For example, hash join probe accesses the corresponding @@ -176,6 +182,10 @@ class Driver { void close(); + // Push down dynamic filters produced by the operator at the specified + // position in the pipeline. 
+ void pushdownFilters(int operatorIndex); + std::unique_ptr ctx_; std::shared_ptr task_; core::CancelPoolPtr cancelPool_; diff --git a/velox/exec/HashProbe.cpp b/velox/exec/HashProbe.cpp index ec75f7c2b6fab..5804b47f54129 100644 --- a/velox/exec/HashProbe.cpp +++ b/velox/exec/HashProbe.cpp @@ -175,6 +175,20 @@ BlockingReason HashProbe::isBlocked(ContinueFuture* future) { joinType_ == core::JoinType::kSemi) { isFinishing_ = true; } + } else if ( + joinType_ == core::JoinType::kInner && + table_->hashMode() != BaseHashTable::HashMode::kHash) { + const auto& buildHashers = table_->hashers(); + auto channels = operatorCtx_->driverCtx()->driver->canPushdownFilters( + this, keyChannels_); + dynamicFilterBuilders_.resize(keyChannels_.size()); + for (auto i = 0; i < keyChannels_.size(); i++) { + auto it = channels.find(keyChannels_[i]); + if (it != channels.end()) { + dynamicFilterBuilders_[i].emplace(DynamicFilterBuilder( + *(buildHashers[i].get()), keyChannels_[i], dynamicFilters_)); + } + } } } @@ -195,6 +209,16 @@ void HashProbe::addInput(RowVectorPtr input) { nonNullRows_.setAll(); deselectRowsWithNulls(*input_, keyChannels_, nonNullRows_); + auto getDynamicFilterBuilder = [&](auto i) -> DynamicFilterBuilder* { + if (!dynamicFilterBuilders_.empty()) { + auto& builder = dynamicFilterBuilders_[i]; + if (builder.has_value() && builder->isActive()) { + return &(builder.value()); + } + } + return nullptr; + }; + activeRows_ = nonNullRows_; lookup_->hashes.resize(input_->size()); auto mode = table_->hashMode(); @@ -202,9 +226,18 @@ void HashProbe::addInput(RowVectorPtr input) { for (auto i = 0; i < keyChannels_.size(); ++i) { auto key = input_->loadedChildAt(keyChannels_[i]); if (mode != BaseHashTable::HashMode::kHash) { + auto* dynamicFilterBuilder = getDynamicFilterBuilder(i); + if (dynamicFilterBuilder) { + dynamicFilterBuilder->addInput(activeRows_.countSelected()); + } + valueIdDecoder_.decode(*key, activeRows_); buildHashers[i]->lookupValueIds( valueIdDecoder_, 
activeRows_, deduppedHashes_, &lookup_->hashes); + + if (dynamicFilterBuilder) { + dynamicFilterBuilder->addOutput(activeRows_.countSelected()); + } } else { hashers_[i]->hash(*key, activeRows_, i > 0, &lookup_->hashes); } diff --git a/velox/exec/HashProbe.h b/velox/exec/HashProbe.h index fe77018f679ec..ecd296f0dcd0f 100644 --- a/velox/exec/HashProbe.h +++ b/velox/exec/HashProbe.h @@ -63,6 +63,54 @@ class HashProbe : public Operator { // Channel of probe keys in 'input_'. std::vector keyChannels_; + // Tracks selectivity of a given VectorHasher from the build side and creates + // a filter to push down upstream if the hasher is somewhat selective. + class DynamicFilterBuilder { + public: + DynamicFilterBuilder( + const VectorHasher& buildHasher, + ChannelIndex channel, + std::unordered_map>& + dynamicFilters) + : buildHasher_{buildHasher}, + channel_{channel}, + dynamicFilters_{dynamicFilters} {} + + bool isActive() const { + return isActive_; + } + + void addInput(uint64_t numIn) { + numIn_ += numIn; + } + + void addOutput(uint64_t numOut) { + numOut_ += numOut; + + // Add filter if VectorHasher is somewhat selective, e.g. dropped at least + // 1/3 of the rows. Make sure we have seen at least 10K rows. + if (isActive_ && numIn_ >= 10'000 && numOut_ < 0.66 * numIn_) { + if (auto filter = buildHasher_.getFilter(false)) { + dynamicFilters_.emplace(channel_, std::move(filter)); + } + isActive_ = false; + } + } + + private: + const VectorHasher& buildHasher_; + const ChannelIndex channel_; + std::unordered_map>& + dynamicFilters_; + uint64_t numIn_{0}; + uint64_t numOut_{0}; + bool isActive_{true}; + }; + + // List of DynamicFilterBuilders aligned with keyChannels_. Contains a valid + // entry if the driver can push down a filter on the corresponding join key. 
+ std::vector> dynamicFilterBuilders_; + std::vector> hashers_; // Table shared between other HashProbes in other Drivers of the diff --git a/velox/exec/Operator.h b/velox/exec/Operator.h index 33b19073728c9..949cd55487e3c 100644 --- a/velox/exec/Operator.h +++ b/velox/exec/Operator.h @@ -17,6 +17,7 @@ #include "velox/common/time/CpuWallTimer.h" #include "velox/core/PlanNode.h" #include "velox/exec/Driver.h" +#include "velox/type/Filter.h" namespace facebook::velox::exec { @@ -179,6 +180,10 @@ class OperatorCtx { return driverCtx_->execCtx.get(); } + core::QueryCtx* queryCtx() const { + return driverCtx_->execCtx->queryCtx(); + } + Driver* driver() const { return driverCtx_->driver; } @@ -246,6 +251,44 @@ class Operator { return isFinishing_; } + // Returns single-column dynamically generated filters to be pushed down to + // upstream operators. Used to push down filters on join keys from broadcast + // hash join into probe-side table scan. Can also be used to push down TopN + // cutoff. + virtual const std:: + unordered_map>& + getDynamicFilters() const { + return dynamicFilters_; + } + + // Clears dynamically generated filters. Called after filters were pushed + // down. + virtual void clearDynamicFilters() { + dynamicFilters_.clear(); + } + + // Returns true if this operator would accept a filter dynamically generated + // by a downstream operator. + virtual bool canAddDynamicFilter() const { + return false; + } + + // Adds a filter dynamically generated by a downstream operator. Called only + // if canAddDynamicFilter() returns true. + virtual void addDynamicFilter( + ChannelIndex outputChannel, + const std::shared_ptr& filter) { + VELOX_UNSUPPORTED( + "This operator doesn't support dynamic filter pushdown: {}", + toString()); + } + + // Returns a list of identity projections, i.e. columns that are projected + // as-is possibly after applying a filter. 
+ const std::vector& identityProjections() const { + return identityProjections_; + } + // Frees all resources associated with 'this'. No other methods // should be called after this. virtual void close() { @@ -329,7 +372,10 @@ class Operator { // i.e. one could copy directly from input to output if no // cardinality change. bool isIdentityProjection_ = false; -}; + + std::unordered_map> + dynamicFilters_; +}; // class Operator constexpr ChannelIndex kConstantChannel = std::numeric_limits::max(); diff --git a/velox/exec/TableScan.cpp b/velox/exec/TableScan.cpp index e7c4973dac6fc..bc5c740e37e6b 100644 --- a/velox/exec/TableScan.cpp +++ b/velox/exec/TableScan.cpp @@ -75,6 +75,10 @@ RowVectorPtr TableScan::getOutput() { tableHandle_, columnHandles_, connectorQueryCtx_.get()); + for (const auto& entry : pendingDynamicFilters_) { + dataSource_->addDynamicFilter(entry.first, entry.second); + } + pendingDynamicFilters_.clear(); } else { VELOX_CHECK( connector_->connectorId() == connectorSplit->connectorId, @@ -103,6 +107,16 @@ RowVectorPtr TableScan::getOutput() { } } +void TableScan::addDynamicFilter( + ChannelIndex outputChannel, + const std::shared_ptr& filter) { + if (dataSource_) { + dataSource_->addDynamicFilter(outputChannel, filter); + } else { + pendingDynamicFilters_.emplace(outputChannel, filter); + } +} + void TableScan::close() { // TODO Implement } diff --git a/velox/exec/TableScan.h b/velox/exec/TableScan.h index ddfa76f6d318c..ac9e85f00d957 100644 --- a/velox/exec/TableScan.h +++ b/velox/exec/TableScan.h @@ -43,6 +43,16 @@ class TableScan : public SourceOperator { close(); } + bool canAddDynamicFilter() const override { + // TODO Consult with the connector. Return true only if connector can accept + // dynamic filters. 
+ return true; + } + + void addDynamicFilter( + ChannelIndex outputChannel, + const std::shared_ptr& filter) override; + void close() override; private: @@ -63,5 +73,8 @@ class TableScan : public SourceOperator { bool noMoreSplits_ = false; // The bucketed group id we are in the middle of processing. int32_t currentSplitGroupId_{-1}; + // Dynamic filters to add to the data source when it gets created. + std::unordered_map> + pendingDynamicFilters_; }; } // namespace facebook::velox::exec diff --git a/velox/exec/tests/HashJoinTest.cpp b/velox/exec/tests/HashJoinTest.cpp index 9b3514633422a..d3e4f9ef3d881 100644 --- a/velox/exec/tests/HashJoinTest.cpp +++ b/velox/exec/tests/HashJoinTest.cpp @@ -18,6 +18,8 @@ #include "velox/exec/tests/Cursor.h" #include "velox/exec/tests/HiveConnectorTestBase.h" #include "velox/exec/tests/PlanBuilder.h" +#include "velox/type/tests/FilterBuilder.h" +#include "velox/type/tests/SubfieldFiltersBuilder.h" using namespace facebook::velox; using namespace facebook::velox::exec; @@ -87,6 +89,27 @@ class HashJoinTest : public HiveConnectorTestBase { std::iota(channels.begin(), channels.end(), 0); return channels; } + + static RuntimeMetric getFiltersProduced( + const std::shared_ptr& task, + int operatorIndex) { + auto stats = task->taskStats().pipelineStats.front().operatorStats; + return stats[operatorIndex].runtimeStats["dynamicFiltersProduced"]; + }; + + static RuntimeMetric getFiltersAccepted( + const std::shared_ptr& task, + int operatorIndex) { + auto stats = task->taskStats().pipelineStats.front().operatorStats; + return stats[operatorIndex].runtimeStats["dynamicFiltersAccepted"]; + }; + + static uint64_t getInputPositions( + const std::shared_ptr& task, + int operatorIndex) { + auto stats = task->taskStats().pipelineStats.front().operatorStats; + return stats[operatorIndex].inputPositions; + }; }; TEST_F(HashJoinTest, bigintArray) { @@ -436,3 +459,143 @@ TEST_F(HashJoinTest, antiJoin) { assertQuery(op, "SELECT t.c1 FROM t WHERE t.c0 
NOT IN (SELECT c0 FROM u)"); } + +TEST_F(HashJoinTest, dynamicFilters) { + std::vector leftVectors; + leftVectors.reserve(20); + + auto leftFiles = makeFilePaths(20); + + for (int i = 0; i < 20; i++) { + auto rowVector = makeRowVector({ + makeFlatVector(1'024, [&](auto row) { return row - i * 10; }), + makeFlatVector(1'024, [](auto row) { return row; }), + }); + leftVectors.push_back(rowVector); + writeToFile(leftFiles[i]->path, kWriter, rowVector); + } + + // 100 key values in [35, 233] range. + auto rightVectors = {makeRowVector( + {makeFlatVector(100, [](auto row) { return 35 + row * 2; })})}; + + createDuckDbTable("t", {leftVectors}); + createDuckDbTable("u", {rightVectors}); + + auto probeType = ROW({"c0", "c1"}, {INTEGER(), BIGINT()}); + + // Basic push-down. + { + auto op = PlanBuilder(10) + .tableScan(probeType) + .hashJoin( + {0}, + {0}, + PlanBuilder(0).values(rightVectors).planNode(), + "", + {1}) + .project({"c1 + 1"}) + .planNode(); + + auto task = assertQuery( + op, {{10, leftFiles}}, "SELECT t.c1 + 1 FROM t, u WHERE t.c0 = u.c0"); + EXPECT_EQ(1, getFiltersProduced(task, 1).sum); + EXPECT_EQ(1, getFiltersAccepted(task, 0).sum); + EXPECT_LT(getInputPositions(task, 1), 1024 * 20); + } + + // Push-down that requires merging filters. + { + auto filters = + common::test::singleSubfieldFilter("c0", common::test::lessThan(500)); + auto op = PlanBuilder(10) + .tableScan( + probeType, + makeTableHandle(std::move(filters)), + allRegularColumns(probeType)) + .hashJoin( + {0}, + {0}, + PlanBuilder(0).values(rightVectors).planNode(), + "", + {1}) + .project({"c1 + 1"}) + .planNode(); + + auto task = assertQuery( + op, + {{10, leftFiles}}, + "SELECT t.c1 + 1 FROM t, u WHERE t.c0 = u.c0 AND t.c0 < 500"); + EXPECT_EQ(1, getFiltersProduced(task, 1).sum); + EXPECT_EQ(1, getFiltersAccepted(task, 0).sum); + } + + // Disable filter push-down by using highly selective filter in the scan. 
+ { + auto filters = + common::test::singleSubfieldFilter("c0", common::test::lessThan(200)); + auto op = PlanBuilder(10) + .tableScan( + probeType, + makeTableHandle(std::move(filters)), + allRegularColumns(probeType)) + .hashJoin( + {0}, + {0}, + PlanBuilder(0).values(rightVectors).planNode(), + "", + {1}) + .project({"c1 + 1"}) + .planNode(); + + auto task = assertQuery( + op, + {{10, leftFiles}}, + "SELECT t.c1 + 1 FROM t, u WHERE t.c0 = u.c0 AND t.c0 < 200"); + EXPECT_EQ(0, getFiltersProduced(task, 1).sum); + EXPECT_EQ(0, getFiltersAccepted(task, 0).sum); + } + + // Disable filter push-down by using values in place of scan. + { + auto op = PlanBuilder(10) + .values(leftVectors) + .hashJoin( + {0}, + {0}, + PlanBuilder(0).values(rightVectors).planNode(), + "", + {1}) + .project({"c1 + 1"}) + .planNode(); + + auto task = assertQuery(op, "SELECT t.c1 + 1 FROM t, u WHERE t.c0 = u.c0"); + EXPECT_EQ(0, getFiltersProduced(task, 1).sum); + EXPECT_EQ(0, getFiltersAccepted(task, 0).sum); + EXPECT_EQ(getInputPositions(task, 1), 1024 * 20); + } + + // Disable filter push-down by using an expression as the join key on the + // probe side. + { + auto op = PlanBuilder(10) + .tableScan(probeType) + .project({"c0 + 1", "c1"}) + .hashJoin( + {0}, + {0}, + PlanBuilder(0).values(rightVectors).planNode(), + "", + {1}) + .project({"p1 + 1"}) + .planNode(); + + auto task = assertQuery( + op, + {{10, leftFiles}}, + "SELECT t.c1 + 1 FROM t, u WHERE (t.c0 + 1) = u.c0"); + EXPECT_EQ(0, getFiltersProduced(task, 1).sum); + EXPECT_EQ(0, getFiltersAccepted(task, 0).sum); + EXPECT_EQ(getInputPositions(task, 1), 1024 * 20); + } +}