Skip to content

Commit

Permalink
Folder: function
Browse files Browse the repository at this point in the history
Related PRs:

Fix replace SparkSQL function facebookincubator#277
Support kPreceeding & kFollowing for window range frame type facebookincubator#287
support timestamp hash facebookincubator#269
Spark sum can overflow facebookincubator#101
Support float & double types in pmod function facebookincubator#157
Implement datetime functions in velox/sparksql. facebookincubator#81
Fix type check in MapFunction facebookincubator#273
Let function validation fail for lookaround pattern in RE2-based implementation facebookincubator#124
Register lpad/rpad functions for Spark SQL. facebookincubator#63
Support substring_index sql function facebookincubator#189
Fix First/Last aggregate functions intermediate type and support decimal facebookincubator#245
Support date_add spark sql function facebookincubator#144
  • Loading branch information
zhejiangxiaomai committed Jul 3, 2023
1 parent e7134e3 commit 44b2919
Show file tree
Hide file tree
Showing 67 changed files with 5,887 additions and 156 deletions.
3 changes: 2 additions & 1 deletion velox/functions/FunctionRegistry.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,8 @@ std::shared_ptr<const Type> resolveCallableSpecialForm(
const std::string& functionName,
const std::vector<TypePtr>& argTypes) {
// TODO Replace with struct_pack
if (functionName == "row_constructor") {
if (functionName == "row_constructor" ||
functionName == "row_constructor_with_null") {
auto numInput = argTypes.size();
std::vector<TypePtr> types(numInput);
std::vector<std::string> names(numInput);
Expand Down
3 changes: 2 additions & 1 deletion velox/functions/lib/aggregates/BitwiseAggregateBase.h
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,8 @@ exec::AggregateRegistrationResult registerBitwise(const std::string& name) {
name,
inputType->kindName());
}
});
},
true);
}

} // namespace facebook::velox::functions::aggregate
8 changes: 8 additions & 0 deletions velox/functions/lib/string/StringCore.h
Original file line number Diff line number Diff line change
Expand Up @@ -299,6 +299,7 @@ inline int64_t findNthInstanceByteIndexFromEnd(
/// each character. When inputString is empty, the result is empty.
/// replace("", "", "x") = ""
/// replace("aa", "", "x") = "xaxax"
template <bool ignoreEmptyReplaced = false>
inline static size_t replace(
char* outputString,
const std::string_view& inputString,
Expand All @@ -309,6 +310,13 @@ inline static size_t replace(
return 0;
}

if (ignoreEmptyReplaced && replaced.size() == 0) {
if (!inPlace) {
std::memcpy(outputString, inputString.data(), inputString.size());
}
return inputString.size();
}

size_t readPosition = 0;
size_t writePosition = 0;
// Copy needed in out of place replace, and when replaced and replacement are
Expand Down
14 changes: 10 additions & 4 deletions velox/functions/lib/string/StringImpl.h
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,10 @@ stringPosition(const T& string, const T& subString, int64_t instance = 0) {

/// Replace replaced with replacement in inputString and write results to
/// outputString.
template <typename TOutString, typename TInString>
template <
bool ignoreEmptyReplaced = false,
typename TOutString,
typename TInString>
FOLLY_ALWAYS_INLINE void replace(
TOutString& outputString,
const TInString& inputString,
Expand All @@ -200,7 +203,7 @@ FOLLY_ALWAYS_INLINE void replace(
(inputString.size() / replaced.size()) * replacement.size());
}

auto outputSize = stringCore::replace(
auto outputSize = stringCore::replace<ignoreEmptyReplaced>(
outputString.data(),
std::string_view(inputString.data(), inputString.size()),
std::string_view(replaced.data(), replaced.size()),
Expand All @@ -211,14 +214,17 @@ FOLLY_ALWAYS_INLINE void replace(
}

/// Replace replaced with replacement in place in string.
template <typename TInOutString, typename TInString>
template <
bool ignoreEmptyReplaced = false,
typename TInOutString,
typename TInString>
FOLLY_ALWAYS_INLINE void replaceInPlace(
TInOutString& string,
const TInString& replaced,
const TInString& replacement) {
assert(replacement.size() <= replaced.size() && "invalid inplace replace");

auto outputSize = stringCore::replace(
auto outputSize = stringCore::replace<ignoreEmptyReplaced>(
string.data(),
std::string_view(string.data(), string.size()),
std::string_view(replaced.data(), replaced.size()),
Expand Down
20 changes: 12 additions & 8 deletions velox/functions/lib/tests/DateTimeFormatterTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -547,11 +547,13 @@ TEST_F(JodaDateTimeFormatterTest, parseYear) {
EXPECT_THROW(parseJoda("++100", "y"), VeloxUserError);

// Probe the year range
EXPECT_THROW(parseJoda("-292275056", "y"), VeloxUserError);
EXPECT_THROW(parseJoda("292278995", "y"), VeloxUserError);
EXPECT_EQ(
util::fromTimestampString("292278994-01-01"),
parseJoda("292278994", "y").timestamp);
// Temporarily disabled to match Spark semantics (years with more than 7
// digits are not allowed).
// EXPECT_THROW(parseJoda("-292275056", "y"), VeloxUserError);
// EXPECT_THROW(parseJoda("292278995", "y"), VeloxUserError);
// EXPECT_EQ(
// util::fromTimestampString("292278994-01-01"),
// parseJoda("292278994", "y").timestamp);
}

TEST_F(JodaDateTimeFormatterTest, parseWeekYear) {
Expand Down Expand Up @@ -626,9 +628,11 @@ TEST_F(JodaDateTimeFormatterTest, parseWeekYear) {

TEST_F(JodaDateTimeFormatterTest, parseCenturyOfEra) {
// Probe century range
EXPECT_EQ(
util::fromTimestampString("292278900-01-01 00:00:00"),
parseJoda("2922789", "CCCCCCC").timestamp);
// Temporarily disabled to match Spark semantics (years with more than 7
// digits are not allowed).
// EXPECT_EQ(
// util::fromTimestampString("292278900-01-01 00:00:00"),
// parseJoda("2922789", "CCCCCCC").timestamp);
EXPECT_EQ(
util::fromTimestampString("00-01-01 00:00:00"),
parseJoda("0", "C").timestamp);
Expand Down
35 changes: 35 additions & 0 deletions velox/functions/lib/window/tests/WindowTestBase.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,41 @@ void WindowTestBase::testWindowFunction(
}
}

void WindowTestBase::testKRangeFrames(const std::string& function) {
// The current support for k Range frames is limited to ascending sort
// orders without null values. Frames clauses generating empty frames
// are also not supported.

// For deterministic results its expected that rows have a fixed ordering
// in the partition so that the range frames are predictable. So the
// input table.
vector_size_t size = 100;

auto vectors = makeRowVector({
makeFlatVector<int32_t>(size, [](auto row) { return row % 10; }),
makeFlatVector<int64_t>(size, [](auto row) { return row; }),
makeFlatVector<int32_t>(size, [](auto row) { return row % 7 + 1; }),
makeFlatVector<int64_t>(size, [](auto row) { return row % 4 + 1; }),
});

const std::string overClause = "partition by c0 order by c1";
const std::vector<std::string> kRangeFrames = {
"range between 5 preceding and current row",
"range between current row and 5 following",
"range between 5 preceding and 5 following",
"range between unbounded preceding and 5 following",
"range between 5 preceding and unbounded following",

"range between c3 preceding and current row",
"range between current row and c3 following",
"range between c2 preceding and c3 following",
"range between unbounded preceding and c3 following",
"range between c3 preceding and unbounded following",
};

testWindowFunction({vectors}, function, {overClause}, kRangeFrames);
}

void WindowTestBase::assertWindowFunctionError(
const std::vector<RowVectorPtr>& input,
const std::string& function,
Expand Down
2 changes: 2 additions & 0 deletions velox/functions/lib/window/tests/WindowTestBase.h
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,8 @@ class WindowTestBase : public exec::test::OperatorTestBase {
const std::vector<std::string>& frameClauses = {""},
bool createTable = true);

void testKRangeFrames(const std::string& function);

/// This function tests the SQL query for the window function and overClause
/// combination with the input RowVectors. It is expected that query execution
/// will throw an exception with the errorMessage specified.
Expand Down
9 changes: 7 additions & 2 deletions velox/functions/prestosql/ArithmeticImpl.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,15 @@ round(const TNum& number, const TDecimals& decimals = 0) {
}

double factor = std::pow(10, decimals);
double variance = 0.1;
if (number < 0) {
return (std::round(number * factor * -1) / factor) * -1;
return (std::round(
std::nextafter(number, number - variance) * factor * -1) /
factor) *
-1;
}
return std::round(number * factor) / factor;
return std::round(std::nextafter(number, number + variance) * factor) /
factor;
}

// This is used by Velox for floating points plus.
Expand Down
1 change: 1 addition & 0 deletions velox/functions/prestosql/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ add_library(
Repeat.cpp
Reverse.cpp
RowFunction.cpp
RowFunctionWithNull.cpp
Sequence.cpp
Slice.cpp
Split.cpp
Expand Down
72 changes: 72 additions & 0 deletions velox/functions/prestosql/RowFunctionWithNull.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "velox/expression/Expr.h"
#include "velox/expression/VectorFunction.h"

namespace facebook::velox::functions {
namespace {

/// Vector function that packs its arguments into a RowVector and marks a row
/// of the result as null whenever any argument is null at that row.
class RowFunctionWithNull : public exec::VectorFunction {
 public:
  /// Builds the struct result for the selected `rows`.
  ///
  /// @param rows Rows to evaluate; only these are inspected for nulls.
  /// @param args Child vectors of the resulting struct, in field order.
  /// @param outputType Type handed to the RowVector constructor.
  /// @param context Supplies the memory pool and receives the result.
  /// @param result Output vector, filled via EvalCtx::moveOrCopyResult.
  void apply(
      const SelectivityVector& rows,
      std::vector<VectorPtr>& args,
      const TypePtr& outputType,
      exec::EvalCtx& context,
      VectorPtr& result) const override {
    // Copy the pointers: the copy is moved into the RowVector below, so
    // `args` itself is left untouched.
    auto children = args;

    // Start with every bit set, i.e. every row not-null. Filling each byte
    // with 0x01 instead would leave seven out of every eight rows flagged as
    // null, so unselected rows would carry stray nulls and the null count
    // passed to the RowVector would not match the buffer.
    BufferPtr nulls = AlignedBuffer::allocate<char>(
        bits::nbytes(rows.size()), context.pool(), '\xff');
    auto* rawNulls = nulls->asMutable<uint64_t>();

    vector_size_t nullCount = 0;
    rows.applyToSelected([&](vector_size_t row) {
      for (const auto& child : children) {
        if (child->mayHaveNulls() && child->isNullAt(row)) {
          // If any argument of the struct is null, set the struct as null.
          bits::setNull(rawNulls, row, true);
          ++nullCount;
          break;
        }
      }
    });

    RowVectorPtr localResult = std::make_shared<RowVector>(
        context.pool(),
        outputType,
        nulls,
        rows.size(),
        std::move(children),
        nullCount);
    context.moveOrCopyResult(localResult, rows, result);
  }

  /// Opts out of default null behavior: apply() inspects null arguments
  /// itself and nulls the whole struct for such rows.
  bool isDefaultNullBehavior() const override {
    return false;
  }
};
} // namespace

// Declares the vector function under the tag `udf_concat_row_with_null`, with
// an empty signature list and a RowFunctionWithNull instance as the
// implementation.
// NOTE(review): presumably the result type is resolved through the
// "row_constructor_with_null" special-form path rather than through these
// signatures -- confirm against the registration site.
VELOX_DECLARE_VECTOR_FUNCTION(
udf_concat_row_with_null,
std::vector<std::shared_ptr<exec::FunctionSignature>>{},
std::make_unique<RowFunctionWithNull>());

} // namespace facebook::velox::functions
18 changes: 15 additions & 3 deletions velox/functions/prestosql/StringFunctions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -284,7 +284,8 @@ class ConcatFunction : public exec::VectorFunction {
* If search is an empty string, inserts replace in front of every character
*and at the end of the string.
**/
class Replace : public exec::VectorFunction {
template <bool ignoreEmptyReplaced>
class ReplaceBase : public exec::VectorFunction {
private:
template <
typename StringReader,
Expand All @@ -298,7 +299,7 @@ class Replace : public exec::VectorFunction {
FlatVector<StringView>* results) const {
rows.applyToSelected([&](int row) {
auto proxy = exec::StringWriter<>(results, row);
stringImpl::replace(
stringImpl::replace<ignoreEmptyReplaced>(
proxy, stringReader(row), searchReader(row), replaceReader(row));
proxy.finalize();
});
Expand All @@ -317,7 +318,8 @@ class Replace : public exec::VectorFunction {
rows.applyToSelected([&](int row) {
auto proxy = exec::StringWriter<true /*reuseInput*/>(
results, row, stringReader(row) /*reusedInput*/, true /*inPlace*/);
stringImpl::replaceInPlace(proxy, searchReader(row), replaceReader(row));
stringImpl::replaceInPlace<ignoreEmptyReplaced>(
proxy, searchReader(row), replaceReader(row));
proxy.finalize();
});
}
Expand Down Expand Up @@ -429,6 +431,11 @@ class Replace : public exec::VectorFunction {
return {{0, 2}};
}
};

/// Replace with ignoreEmptyReplaced disabled: an empty search string inserts
/// the replacement in front of every character and at the end of the string.
class Replace : public ReplaceBase<false /*ignoreEmptyReplaced*/> {};

/// Replace with ignoreEmptyReplaced enabled: the flag is forwarded to
/// stringImpl::replace / replaceInPlace.
/// NOTE(review): with the flag set, an empty search string appears to leave
/// the input unchanged instead of inserting the replacement around every
/// character -- confirm against stringCore::replace.
class ReplaceIgnoreEmptyReplaced
: public ReplaceBase<true /*ignoreEmptyReplaced*/> {};
} // namespace

VELOX_DECLARE_VECTOR_FUNCTION(
Expand All @@ -454,4 +461,9 @@ VELOX_DECLARE_VECTOR_FUNCTION(
Replace::signatures(),
std::make_unique<Replace>());

// Declares the vector function under the tag
// `udf_replace_ignore_empty_replaced`, using the signatures and implementation
// of ReplaceIgnoreEmptyReplaced.
VELOX_DECLARE_VECTOR_FUNCTION(
udf_replace_ignore_empty_replaced,
ReplaceIgnoreEmptyReplaced::signatures(),
std::make_unique<ReplaceIgnoreEmptyReplaced>());

} // namespace facebook::velox::functions
22 changes: 20 additions & 2 deletions velox/functions/prestosql/aggregates/AverageAggregate.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -101,10 +101,16 @@ class AverageAggregate : public exec::Aggregate {
rows.applyToSelected([&](vector_size_t i) {
updateNonNullValue(groups[i], TAccumulator(value));
});
} else {
// Spark expects the result of partial avg to be non-nullable.
rows.applyToSelected(
[&](vector_size_t i) { exec::Aggregate::clearNull(groups[i]); });
}
} else if (decodedRaw_.mayHaveNulls()) {
rows.applyToSelected([&](vector_size_t i) {
if (decodedRaw_.isNullAt(i)) {
// Spark expects the result of partial avg to be non-nullable.
exec::Aggregate::clearNull(groups[i]);
return;
}
updateNonNullValue(
Expand Down Expand Up @@ -135,12 +141,18 @@ class AverageAggregate : public exec::Aggregate {
const TInput value = decodedRaw_.valueAt<TInput>(0);
const auto numRows = rows.countSelected();
updateNonNullValue(group, numRows, TAccumulator(value) * numRows);
} else {
// Spark expects the result of partial avg to be non-nullable.
exec::Aggregate::clearNull(group);
}
} else if (decodedRaw_.mayHaveNulls()) {
rows.applyToSelected([&](vector_size_t i) {
if (!decodedRaw_.isNullAt(i)) {
updateNonNullValue(
group, TAccumulator(decodedRaw_.valueAt<TInput>(i)));
} else {
// Spark expects the result of partial avg to be non-nullable.
exec::Aggregate::clearNull(group);
}
});
} else if (!exec::Aggregate::numNulls_ && decodedRaw_.isIdentityMapping()) {
Expand Down Expand Up @@ -337,9 +349,15 @@ class AverageAggregate : public exec::Aggregate {
if (isNull(group)) {
vector->setNull(i, true);
} else {
clearNull(rawNulls, i);
auto* sumCount = accumulator(group);
rawValues[i] = TResult(sumCount->sum) / sumCount->count;
if (sumCount->count == 0) {
// To align with Spark, if all input are nulls, count will be 0,
// and the result of final avg will be null.
vector->setNull(i, true);
} else {
clearNull(rawNulls, i);
rawValues[i] = (TResult)sumCount->sum / sumCount->count;
}
}
}
}
Expand Down
Loading

0 comments on commit 44b2919

Please sign in to comment.