Skip to content

Commit

Permalink
Implement decimal comparisons as vector functions (facebookincubator#…
Browse files Browse the repository at this point in the history
…4638)

Summary:
Decimal Simple Functions require separate signatures/registration for
ShortDecimal and LongDecimal types.
We want to view Decimal Types as a single entity/type. Simple Functions break this view.
To achieve the single entity view, all Decimal functions must be implemented as Vector Functions.
The Decimal comparison functions are currently implemented as Simple Functions.
The scope of this PR is to implement them as Vector Functions.

Pull Request resolved: facebookincubator#4638

Reviewed By: mbasmanova

Differential Revision: D45193675

Pulled By: Yuhta

fbshipit-source-id: 5f78d04464ec910577ca8c353493fa66cb6890ab
  • Loading branch information
majetideepak authored and facebook-github-bot committed Apr 26, 2023
1 parent df927bd commit c125342
Show file tree
Hide file tree
Showing 8 changed files with 152 additions and 103 deletions.
20 changes: 2 additions & 18 deletions velox/core/SimpleFunctionMetadata.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,6 @@
#include "velox/type/Type.h"
#include "velox/type/Variant.h"

inline constexpr char kPrecisionVariable[] = "a_precision";
inline constexpr char kScaleVariable[] = "a_scale";

namespace facebook::velox::core {

// Most UDFs are deterministic, hence this default value.
Expand Down Expand Up @@ -194,14 +191,8 @@ struct TypeAnalysis {
SimpleTypeTrait<T>::isPrimitiveType ||
SimpleTypeTrait<T>::typeKind == TypeKind::OPAQUE);
results.stats.concreteCount++;
if (isDecimalKind(SimpleTypeTrait<T>::typeKind)) {
results.out << detail::strToLowerCopy(
std::string(SimpleTypeTrait<T>::name))
<< "(" << kPrecisionVariable << "," << kScaleVariable << ")";
} else {
results.out << detail::strToLowerCopy(
std::string(SimpleTypeTrait<T>::name));
}
results.out << detail::strToLowerCopy(
std::string(SimpleTypeTrait<T>::name));
}
};

Expand Down Expand Up @@ -443,15 +434,8 @@ class SimpleFunctionMetadata : public ISimpleFunctionMetadata {

builder.returnType(analysis.outputType);

bool isDecimalArg = false;
for (const auto& arg : analysis.argsTypes) {
builder.argumentType(arg);
isDecimalArg |= isDecimalTypeSignature(arg);
}

if (isDecimalArg) {
builder.integerVariable(kPrecisionVariable);
builder.integerVariable(kScaleVariable);
}

for (const auto& variable : analysis.variables) {
Expand Down
2 changes: 1 addition & 1 deletion velox/functions/prestosql/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ add_library(
ArraySort.cpp
ArraySum.cpp
Comparisons.cpp
DecimalArithmetic.cpp
DecimalVectorFunctions.cpp
ElementAt.cpp
FilterFunctions.cpp
FromUnixTime.cpp
Expand Down
84 changes: 54 additions & 30 deletions velox/functions/prestosql/Comparisons.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -97,9 +97,10 @@ struct SimdComparator {
template <
TypeKind kind,
typename std::enable_if_t<
xsimd::has_simd_register<
typename TypeTraits<kind>::NativeType>::value &&
kind != TypeKind::BOOLEAN,
(xsimd::has_simd_register<
typename TypeTraits<kind>::NativeType>::value &&
kind != TypeKind::BOOLEAN) ||
kind == TypeKind::SHORT_DECIMAL || kind == TypeKind::LONG_DECIMAL,
int> = 0>
void applyComparison(
const SelectivityVector& rows,
Expand All @@ -112,11 +113,16 @@ struct SimdComparator {
auto resultVector = result->asUnchecked<FlatVector<bool>>();
auto rawResult = resultVector->mutableRawValues<uint8_t>();

// UnscaledShortDecimal and UnscaledLongDecimal will soon be replaced with
// int64_t and int128_t respectively. This change will be removed then.
constexpr bool isDecimal =
(std::is_same_v<T, UnscaledShortDecimal> ||
std::is_same_v<T, UnscaledLongDecimal>);
auto isSimdizable = (lhs.isConstantEncoding() || lhs.isFlatEncoding()) &&
(rhs.isConstantEncoding() || rhs.isFlatEncoding()) &&
rows.isAllSelected();

if (!isSimdizable) {
if (!isSimdizable || isDecimal) {
exec::LocalDecodedVector lhsDecoded(context, lhs, rows);
exec::LocalDecodedVector rhsDecoded(context, rhs, rows);

Expand All @@ -128,38 +134,40 @@ struct SimdComparator {
});
return;
}
if constexpr (!isDecimal) {
if (lhs.isConstantEncoding() && rhs.isConstantEncoding()) {
auto l = lhs.asUnchecked<ConstantVector<T>>()->valueAt(0);
auto r = rhs.asUnchecked<ConstantVector<T>>()->valueAt(0);
applySimdComparison<T, true, true>(
rows.begin(), rows.end(), &l, &r, rawResult);
} else if (lhs.isConstantEncoding()) {
auto l = lhs.asUnchecked<ConstantVector<T>>()->valueAt(0);
auto rawRhs = rhs.asUnchecked<FlatVector<T>>()->rawValues();
applySimdComparison<T, true, false>(
rows.begin(), rows.end(), &l, rawRhs, rawResult);
} else if (rhs.isConstantEncoding()) {
auto rawLhs = lhs.asUnchecked<FlatVector<T>>()->rawValues();
auto r = rhs.asUnchecked<ConstantVector<T>>()->valueAt(0);
applySimdComparison<T, false, true>(
rows.begin(), rows.end(), rawLhs, &r, rawResult);
} else {
auto rawLhs = lhs.asUnchecked<FlatVector<T>>()->rawValues();
auto rawRhs = rhs.asUnchecked<FlatVector<T>>()->rawValues();
applySimdComparison<T, false, false>(
rows.begin(), rows.end(), rawLhs, rawRhs, rawResult);
}

if (lhs.isConstantEncoding() && rhs.isConstantEncoding()) {
auto l = lhs.asUnchecked<ConstantVector<T>>()->valueAt(0);
auto r = rhs.asUnchecked<ConstantVector<T>>()->valueAt(0);
applySimdComparison<T, true, true>(
rows.begin(), rows.end(), &l, &r, rawResult);
} else if (lhs.isConstantEncoding()) {
auto l = lhs.asUnchecked<ConstantVector<T>>()->valueAt(0);
auto rawRhs = rhs.asUnchecked<FlatVector<T>>()->rawValues();
applySimdComparison<T, true, false>(
rows.begin(), rows.end(), &l, rawRhs, rawResult);
} else if (rhs.isConstantEncoding()) {
auto rawLhs = lhs.asUnchecked<FlatVector<T>>()->rawValues();
auto r = rhs.asUnchecked<ConstantVector<T>>()->valueAt(0);
applySimdComparison<T, false, true>(
rows.begin(), rows.end(), rawLhs, &r, rawResult);
} else {
auto rawLhs = lhs.asUnchecked<FlatVector<T>>()->rawValues();
auto rawRhs = rhs.asUnchecked<FlatVector<T>>()->rawValues();
applySimdComparison<T, false, false>(
rows.begin(), rows.end(), rawLhs, rawRhs, rawResult);
resultVector->clearNulls(rows);
}

resultVector->clearNulls(rows);
}

template <
TypeKind kind,
typename std::enable_if_t<
!xsimd::has_simd_register<
typename TypeTraits<kind>::NativeType>::value ||
kind == TypeKind::BOOLEAN,
(!xsimd::has_simd_register<
typename TypeTraits<kind>::NativeType>::value ||
kind == TypeKind::BOOLEAN) &&
kind != TypeKind::SHORT_DECIMAL && kind != TypeKind::LONG_DECIMAL,
int> = 0>
void applyComparison(
const SelectivityVector& /* rows */,
Expand Down Expand Up @@ -187,6 +195,16 @@ class ComparisonSimdFunction : public exec::VectorFunction {
context.ensureWritable(rows, outputType, result);
auto comparator = SimdComparator<ComparisonOp>{};

if (args[0]->type()->isShortDecimal()) {
comparator.template applyComparison<TypeKind::SHORT_DECIMAL>(
rows, *args[0], *args[1], context, result);
return;
} else if (args[0]->type()->isLongDecimal()) {
comparator.template applyComparison<TypeKind::LONG_DECIMAL>(
rows, *args[0], *args[1], context, result);
return;
}

VELOX_DYNAMIC_SCALAR_TYPE_DISPATCH(
comparator.template applyComparison,
args[0]->typeKind(),
Expand All @@ -208,7 +226,13 @@ class ComparisonSimdFunction : public exec::VectorFunction {
.argumentType(inputType)
.build());
}

signatures.push_back(exec::FunctionSignatureBuilder()
.integerVariable("a_precision")
.integerVariable("a_scale")
.returnType("boolean")
.argumentType("DECIMAL(a_precision, a_scale)")
.argumentType("DECIMAL(a_precision, a_scale)")
.build());
return signatures;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -146,10 +146,59 @@ class DecimalUnaryBaseFunction : public exec::VectorFunction {
return result->asUnchecked<FlatVector<R>>()->mutableRawValues();
}

private:
const uint8_t aRescale_;
};

template <typename A /* Argument */>
class DecimalBetweenFunction : public exec::VectorFunction {
public:
DecimalBetweenFunction() {}
void apply(
const SelectivityVector& rows,
std::vector<VectorPtr>& args,
const TypePtr& resultType,
exec::EvalCtx& context,
VectorPtr& result) const override {
prepareResults(rows, resultType, context, result);
// Second and third arguments must always be constant.
VELOX_CHECK(args[1]->isConstantEncoding() && args[2]->isConstantEncoding());
auto constantB = args[1]->asUnchecked<SimpleVector<A>>()->valueAt(0);
auto constantC = args[2]->asUnchecked<SimpleVector<A>>()->valueAt(0);
if (args[0]->isFlatEncoding()) {
// Fast path if first argument is flat.
auto flatA = args[0]->asUnchecked<FlatVector<A>>();
auto rawA = flatA->mutableRawValues();
context.applyToSelectedNoThrow(rows, [&](auto row) {
result->asUnchecked<FlatVector<bool>>()->set(
row,
rawA[row].unscaledValue() >= constantB.unscaledValue() &&
rawA[row].unscaledValue() <= constantC.unscaledValue());
});
} else {
// Path if first argument is encoded.
exec::DecodedArgs decodedArgs(rows, args, context);
auto a = decodedArgs.at(0);
context.applyToSelectedNoThrow(rows, [&](auto row) {
auto value = a->valueAt<A>(row);
result->asUnchecked<FlatVector<bool>>()->set(
row,
value.unscaledValue() >= constantB.unscaledValue() &&
value.unscaledValue() <= constantC.unscaledValue());
});
}
}

private:
void prepareResults(
const SelectivityVector& rows,
const TypePtr& resultType,
exec::EvalCtx& context,
VectorPtr& result) const {
context.ensureWritable(rows, resultType, result);
result->clearNulls(rows);
}
};

class Addition {
public:
template <typename R, typename A, typename B>
Expand Down Expand Up @@ -431,6 +480,18 @@ decimalAbsNegateSignature() {
.build()};
}

std::vector<std::shared_ptr<exec::FunctionSignature>>
decimalBetweenSignature() {
return {exec::FunctionSignatureBuilder()
.integerVariable("a_precision")
.integerVariable("a_scale")
.returnType("BOOLEAN")
.argumentType("DECIMAL(a_precision, a_scale)")
.argumentType("DECIMAL(a_precision, a_scale)")
.argumentType("DECIMAL(a_precision, a_scale)")
.build()};
}

template <typename Operation>
std::shared_ptr<exec::VectorFunction> createDecimalUnary(
const std::string& /*name*/,
Expand Down Expand Up @@ -516,6 +577,26 @@ std::shared_ptr<exec::VectorFunction> createDecimalFunction(
}
VELOX_UNSUPPORTED();
}

std::shared_ptr<exec::VectorFunction> createDecimalBetweenFunction(
const std::string& name,
const std::vector<exec::VectorFunctionArg>& inputArgs) {
auto aType = inputArgs[0].type;
auto bType = inputArgs[1].type;
auto cType = inputArgs[2].type;
if (aType->kind() == TypeKind::SHORT_DECIMAL) {
VELOX_CHECK(bType->kind() == TypeKind::SHORT_DECIMAL);
VELOX_CHECK(cType->kind() == TypeKind::SHORT_DECIMAL);
// Arguments are short decimals.
return std::make_shared<DecimalBetweenFunction<UnscaledShortDecimal>>();
} else {
VELOX_CHECK(bType->kind() == TypeKind::LONG_DECIMAL);
VELOX_CHECK(cType->kind() == TypeKind::LONG_DECIMAL);
// Arguments are long decimals.
return std::make_shared<DecimalBetweenFunction<UnscaledLongDecimal>>();
}
VELOX_UNSUPPORTED();
}
}; // namespace

VELOX_DECLARE_STATEFUL_VECTOR_FUNCTION(
Expand Down Expand Up @@ -552,4 +633,9 @@ VELOX_DECLARE_STATEFUL_VECTOR_FUNCTION(
udf_decimal_negate,
decimalAbsNegateSignature(),
createDecimalUnary<Negate>);

VELOX_DECLARE_STATEFUL_VECTOR_FUNCTION(
udf_decimal_between,
decimalBetweenSignature(),
createDecimalBetweenFunction);
}; // namespace facebook::velox::functions
Original file line number Diff line number Diff line change
Expand Up @@ -58,47 +58,8 @@ void registerComparisonFunctions(const std::string& prefix) {
{prefix + "between"});
registerFunction<BetweenFunction, bool, Date, Date, Date>(
{prefix + "between"});
registerFunction<
BetweenFunction,
bool,
UnscaledShortDecimal,
UnscaledShortDecimal,
UnscaledShortDecimal>({prefix + "between"});
registerFunction<
BetweenFunction,
bool,
UnscaledLongDecimal,
UnscaledLongDecimal,
UnscaledLongDecimal>({prefix + "between"});
registerFunction<
GtFunction,
bool,
UnscaledShortDecimal,
UnscaledShortDecimal>({prefix + "gt"});
registerFunction<GtFunction, bool, UnscaledLongDecimal, UnscaledLongDecimal>(
{prefix + "gt"});
registerFunction<
LtFunction,
bool,
UnscaledShortDecimal,
UnscaledShortDecimal>({prefix + "lt"});
registerFunction<LtFunction, bool, UnscaledLongDecimal, UnscaledLongDecimal>(
{prefix + "lt"});

registerFunction<
GteFunction,
bool,
UnscaledShortDecimal,
UnscaledShortDecimal>({prefix + "gte"});
registerFunction<GteFunction, bool, UnscaledLongDecimal, UnscaledLongDecimal>(
{prefix + "gte"});
registerFunction<
LteFunction,
bool,
UnscaledShortDecimal,
UnscaledShortDecimal>({prefix + "lte"});
registerFunction<LteFunction, bool, UnscaledLongDecimal, UnscaledLongDecimal>(
{prefix + "lte"});
VELOX_REGISTER_VECTOR_FUNCTION(udf_decimal_between, prefix + "between");
}

} // namespace facebook::velox::functions
9 changes: 7 additions & 2 deletions velox/functions/prestosql/tests/ComparisonsTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ TEST_F(ComparisonsTest, betweenDecimal) {
VectorPtr input,
VectorPtr expectedResult) {
auto actual = evaluate<SimpleVector<bool>>(exprStr, makeRowVector({input}));
test::assertEqualVectors(actual, expectedResult);
test::assertEqualVectors(expectedResult, actual);
};

auto shortFlat = makeNullableShortDecimalFlatVector(
Expand All @@ -120,6 +120,11 @@ TEST_F(ComparisonsTest, betweenDecimal) {
auto longFlat = makeNullableLongDecimalFlatVector(
{100, 250, 300, 500, std::nullopt}, DECIMAL(20, 2));

runAndCompare(
"c0 between cast(2.00 as DECIMAL(20, 2)) and cast(3.00 as DECIMAL(20, 2))",
longFlat,
expectedResult);

// Comparing LONG_DECIMAL and SHORT_DECIMAL must throw error.
VELOX_ASSERT_THROW(
runAndCompare("c0 between 2.00 and 3.00", longFlat, expectedResult),
Expand Down Expand Up @@ -159,7 +164,7 @@ TEST_F(ComparisonsTest, gtLtDecimal) {
std::vector<VectorPtr>& inputs,
VectorPtr expectedResult) {
auto actual = evaluate<SimpleVector<bool>>(expr, makeRowVector(inputs));
test::assertEqualVectors(actual, expectedResult);
test::assertEqualVectors(expectedResult, actual);
};

// Short Decimals test.
Expand Down
9 changes: 0 additions & 9 deletions velox/type/Type.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,15 +47,6 @@ bool isDecimalName(const std::string& typeName) {
typeNameUpper == TypeTraits<TypeKind::LONG_DECIMAL>::name);
}

bool isDecimalTypeSignature(const std::string& arg) {
auto upper = boost::algorithm::to_upper_copy(arg);
return (
upper.find(TypeTraits<TypeKind::SHORT_DECIMAL>::name) !=
std::string::npos ||
upper.find(TypeTraits<TypeKind::LONG_DECIMAL>::name) !=
std::string::npos);
}

// Static variable intialization is not thread safe for non
// constant-initialization, but scoped static initialization is thread safe.
const std::unordered_map<std::string, TypeKind>& getTypeStringMap() {
Expand Down
Loading

0 comments on commit c125342

Please sign in to comment.