Skip to content

Commit

Permalink
Use presto abs
Browse files Browse the repository at this point in the history
  • Loading branch information
jinchengchenghh committed Jul 3, 2023
1 parent b32cc11 commit 0a4b182
Show file tree
Hide file tree
Showing 3 changed files with 0 additions and 75 deletions.
67 changes: 0 additions & 67 deletions velox/functions/sparksql/Decimal.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -215,62 +215,6 @@ class RoundDecimalFunction final : public exec::VectorFunction {
}
};

template <typename TInput>
class AbsFunction final : public exec::VectorFunction {
void apply(
const SelectivityVector& rows,
std::vector<VectorPtr>& args, // Not using const ref so we can reuse args
const TypePtr& outputType,
exec::EvalCtx& context,
VectorPtr& resultRef) const final {
VELOX_CHECK_EQ(args.size(), 1);
auto inputType = args[0]->type();
VELOX_CHECK(
inputType->isShortDecimal() || inputType->isLongDecimal(),
"ShortDecimal or LongDecimal type is required.");

exec::DecodedArgs decodedArgs(rows, args, context);
auto decimalVector = decodedArgs.at(0);
if (inputType->isShortDecimal()) {
auto decimalType = inputType->asShortDecimal();
context.ensureWritable(
rows,
DECIMAL(decimalType.precision(), decimalType.scale()),
resultRef);
auto result =
resultRef->asUnchecked<FlatVector<int64_t>>()->mutableRawValues();
rows.applyToSelected([&](int row) {
auto unscaled = std::abs(decimalVector->valueAt<int64_t>(row));
if (unscaled >= DecimalUtil::kShortDecimalMin &&
unscaled <= DecimalUtil::kShortDecimalMax) {
result[row] = unscaled;
} else {
// TODO: adjust the bahavior according to ANSI.
resultRef->setNull(row, true);
}
});
} else {
auto decimalType = inputType->asLongDecimal();
context.ensureWritable(
rows,
DECIMAL(decimalType.precision(), decimalType.scale()),
resultRef);
auto result =
resultRef->asUnchecked<FlatVector<int128_t>>()->mutableRawValues();
rows.applyToSelected([&](int row) {
auto unscaled = std::abs(decimalVector->valueAt<int128_t>(row));
if (unscaled >= DecimalUtil::kLongDecimalMin &&
unscaled <= DecimalUtil::kLongDecimalMax) {
result[row] = unscaled;
} else {
// TODO: adjust the bahavior according to ANSI.
resultRef->setNull(row, true);
}
});
}
}
};

class UnscaledValueFunction final : public exec::VectorFunction {
void apply(
const SelectivityVector& rows,
Expand Down Expand Up @@ -328,17 +272,6 @@ std::vector<std::shared_ptr<exec::FunctionSignature>> roundDecimalSignatures() {
.build()};
}

std::vector<std::shared_ptr<exec::FunctionSignature>> absSignatures() {
return {exec::FunctionSignatureBuilder()
.integerVariable("a_precision")
.integerVariable("a_scale")
.integerVariable("r_precision", "min(38, a_precision)")
.integerVariable("r_scale", "min(38, a_scale)")
.returnType("DECIMAL(r_precision, r_scale)")
.argumentType("DECIMAL(a_precision, a_scale)")
.build()};
}

std::vector<std::shared_ptr<exec::FunctionSignature>>
unscaledValueSignatures() {
return {exec::FunctionSignatureBuilder()
Expand Down
6 changes: 0 additions & 6 deletions velox/functions/sparksql/Decimal.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,12 +38,6 @@ std::shared_ptr<exec::VectorFunction> makeRoundDecimal(
const std::string& name,
const std::vector<exec::VectorFunctionArg>& inputArgs);

std::vector<std::shared_ptr<exec::FunctionSignature>> absSignatures();

std::shared_ptr<exec::VectorFunction> makeAbs(
const std::string& name,
const std::vector<exec::VectorFunctionArg>& inputArgs);

std::vector<std::shared_ptr<exec::FunctionSignature>> unscaledValueSignatures();

std::shared_ptr<exec::VectorFunction> makeUnscaledValue(
Expand Down
2 changes: 0 additions & 2 deletions velox/functions/sparksql/Register.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -172,8 +172,6 @@ void registerFunctions(const std::string& prefix) {
prefix + "make_decimal", makeDecimalSignatures(), makeMakeDecimal);
exec::registerStatefulVectorFunction(
prefix + "decimal_round", roundDecimalSignatures(), makeRoundDecimal);
exec::registerStatefulVectorFunction(
prefix + "abs", absSignatures(), makeAbs);
exec::registerStatefulVectorFunction(
prefix + "unscaled_value", unscaledValueSignatures(), makeUnscaledValue);
// Register date functions.
Expand Down

0 comments on commit 0a4b182

Please sign in to comment.