diff --git a/velox/expression/Expr.cpp b/velox/expression/Expr.cpp index 8aa1c7ab44ab..4ffda1995c3e 100644 --- a/velox/expression/Expr.cpp +++ b/velox/expression/Expr.cpp @@ -103,15 +103,14 @@ bool hasConditionals(Expr* expr) { return false; } -void updateAllAndMultiplyReferencedFields( - std::set& allFields, +void updateMultiplyReferencedFields( + const std::vector& allFields, std::set& multiRefFields, const std::vector& fieldsToAdd) { for (auto* newField : fieldsToAdd) { - if (allFields.find(newField) != allFields.end()) { + if (isMember(allFields, *newField)) { multiRefFields.insert(newField); } - allFields.insert(newField); } } @@ -198,14 +197,13 @@ void Expr::computeMetadata() { deterministic_ = vectorFunction_->isDeterministic(); } - std::set allFields; for (auto& input : inputs_) { input->computeMetadata(); deterministic_ &= input->deterministic_; propagatesNulls_ &= input->propagatesNulls_; + updateMultiplyReferencedFields( + distinctFields_, multiRefFields_, input->distinctFields_); mergeFields(distinctFields_, input->distinctFields_); - updateAllAndMultiplyReferencedFields( - allFields, multiRefFields_, input->distinctFields_); } if (isSpecialForm()) { propagatesNulls_ = propagatesNulls(); @@ -1333,10 +1331,11 @@ ExprSet::ExprSet( : execCtx_(execCtx) { exprs_ = compileExpressions( std::move(sources), execCtx, this, enableConstantFolding); - std::set allFields; + std::vector allFields; for (auto& expr : exprs_) { - updateAllAndMultiplyReferencedFields( + updateMultiplyReferencedFields( allFields, multiRefFields_, expr->distinctFields()); + mergeFields(allFields, expr->distinctFields()); } } diff --git a/velox/expression/tests/ExprTest.cpp b/velox/expression/tests/ExprTest.cpp index 7a482f1e7618..b596101ea781 100644 --- a/velox/expression/tests/ExprTest.cpp +++ b/velox/expression/tests/ExprTest.cpp @@ -942,12 +942,11 @@ TEST_F(ExprTest, selectiveLazyLoadingOr) { TEST_F(ExprTest, lazyVectorAccessTwiceWithDifferentRows) { const vector_size_t size = 4; - // [1, 1, 1, null] - auto inputC0 = makeNullableFlatVector({1, 1, 1, std::nullopt}); + auto c0 = makeNullableFlatVector({1, 1, 1, std::nullopt}); // [0, 1, 2, 3] if fully loaded std::vector loadedRows; auto valueAt = [](auto row) { return row; }; - VectorPtr inputC1 = std::make_shared( + VectorPtr c1 = std::make_shared( pool_.get(), BIGINT(), size, @@ -958,17 +957,12 @@ TEST_F(ExprTest, lazyVectorAccessTwiceWithDifferentRows) { return makeFlatVector(rows.back() + 1, valueAt); })); - // isFinalSelection_ == true auto result = evaluate( - "row_constructor(c0 + c1, if (c1 >= 0, c1, 0))", - makeRowVector({inputC0, inputC1})); - - // [1, 2, 3, null] - auto outputCol0 = makeNullableFlatVector({1, 2, 3, std::nullopt}); - // [0, 1, 2, 3] - auto outputCol1 = makeNullableFlatVector({0, 1, 2, 3}); - // [(1, 0), (2, 1), (3, 2), (null, 3)] - auto expected = ExprTest::makeRowVector({outputCol0, outputCol1}); + "row_constructor(c0 + c1, if (c1 >= 0, c1, 0))", makeRowVector({c0, c1})); + + auto expected = makeRowVector( + {makeNullableFlatVector({1, 2, 3, std::nullopt}), + makeNullableFlatVector({0, 1, 2, 3})}); assertEqualVectors(expected, result); }