Skip to content

Commit

Permalink
Folder: expression
Browse files Browse the repository at this point in the history
Related PRs:

Allow decimal in casting string to int facebookincubator#215
Add mapping from named_struct to row_constructor facebookincubator#214
Fix semantic issues in cast function facebookincubator#280
  • Loading branch information
zhejiangxiaomai committed May 31, 2023
1 parent 65aee45 commit 26e686f
Show file tree
Hide file tree
Showing 4 changed files with 218 additions and 20 deletions.
152 changes: 139 additions & 13 deletions velox/expression/CastExpr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
#include "velox/expression/StringWriter.h"
#include "velox/external/date/tz.h"
#include "velox/functions/lib/RowsTranslationUtil.h"
#include "velox/type/DecimalUtilOp.h"
#include "velox/vector/ComplexVector.h"
#include "velox/vector/FunctionVector.h"
#include "velox/vector/SelectivityVector.h"
Expand All @@ -42,17 +43,25 @@ namespace {
/// @param input The input vector (of type From)
/// @param result The output vector (of type To)
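/// @param nullOutput Output flag passed to the converter; the converted value is written to 'result' only when this flag is false afterwards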
template <typename To, typename From, bool Truncate>
template <typename To, typename From, bool Truncate, bool AllowDecimal>
void applyCastKernel(
vector_size_t row,
const SimpleVector<From>* input,
FlatVector<To>* result,
bool& nullOutput) {
// Special handling for string target type
if constexpr (CppToType<To>::typeKind == TypeKind::VARCHAR) {
auto output =
util::Converter<CppToType<To>::typeKind, void, Truncate>::cast(
input->valueAt(row), nullOutput);
std::string output;
if (input->type()->isDecimal()) {
output = util::
Converter<CppToType<To>::typeKind, void, Truncate, AllowDecimal>::
cast(input->valueAt(row), nullOutput, input->type());
} else {
output = util::
Converter<CppToType<To>::typeKind, void, Truncate, AllowDecimal>::
cast(input->valueAt(row), nullOutput);
}

if (!nullOutput) {
// Write the result output to the output vector
auto writer = exec::StringWriter<>(result, row);
Expand All @@ -63,11 +72,20 @@ void applyCastKernel(
writer.finalize();
}
} else {
auto output =
util::Converter<CppToType<To>::typeKind, void, Truncate>::cast(
input->valueAt(row), nullOutput);
if (!nullOutput) {
result->set(row, output);
if (input->type()->isDecimal()) {
auto output = util::
Converter<CppToType<To>::typeKind, void, Truncate, AllowDecimal>::
cast(input->valueAt(row), nullOutput, input->type());
if (!nullOutput) {
result->set(row, output);
}
} else {
auto output = util::
Converter<CppToType<To>::typeKind, void, Truncate, AllowDecimal>::
cast(input->valueAt(row), nullOutput);
if (!nullOutput) {
result->set(row, output);
}
}
}
}
Expand Down Expand Up @@ -134,6 +152,78 @@ void applyIntToDecimalCastKernel(
}
});
}

/// Casts the selected rows of a DATE vector into a decimal result vector.
///
/// For every selected row the value of Date::days() is handed to
/// DecimalUtil::rescaleInt together with the precision and scale taken from
/// 'toType'. When no rescaled value comes back, the row is marked null in
/// 'castResult'; otherwise the value is stored in the raw values buffer.
///
/// @tparam TInput Source integer type forwarded to DecimalUtil::rescaleInt.
/// @tparam TOutput Native type of the values held by 'castResult'.
/// @param rows Rows to process.
/// @param input Source vector, accessed as SimpleVector<Date>.
/// @param context Evaluation context; rows are visited through
/// applyToSelectedNoThrow.
/// @param toType Target type that supplies precision and scale.
/// @param castResult Result vector, accessed as FlatVector<TOutput> without a
/// type check.
template <typename TInput, typename TOutput>
void applyDateToDecimalCastKernel(
    const SelectivityVector& rows,
    const BaseVector& input,
    exec::EvalCtx& context,
    const TypePtr& toType,
    VectorPtr castResult) {
  auto dates = input.as<SimpleVector<Date>>();
  auto rawResults =
      castResult->asUnchecked<FlatVector<TOutput>>()->mutableRawValues();
  const auto& precisionAndScale = getDecimalPrecisionScale(*toType);

  context.applyToSelectedNoThrow(rows, [&](vector_size_t row) {
    auto rescaled = DecimalUtil::rescaleInt<TInput, TOutput>(
        dates->valueAt(row).days(),
        precisionAndScale.first,
        precisionAndScale.second);
    // No value means the input could not be represented; emit a null.
    if (!rescaled.has_value()) {
      castResult->setNull(row, true);
      return;
    }
    rawResults[row] = rescaled.value();
  });
}

/// Casts the selected rows of a floating-point vector into a decimal result
/// vector.
///
/// Each selected value is passed to DecimalUtilOp::rescaleDouble together
/// with the precision and scale taken from 'toType'. When no rescaled value
/// comes back, the row is marked null in 'castResult'; otherwise the value is
/// stored in the raw values buffer.
///
/// @tparam TInput Element type of the source vector (instantiated by the
/// caller with float or double).
/// @tparam TOutput Native type of the values held by 'castResult'.
/// @param rows Rows to process.
/// @param input Source vector, accessed as SimpleVector<TInput>.
/// @param context Evaluation context; rows are visited through
/// applyToSelectedNoThrow.
/// @param toType Target type that supplies precision and scale.
/// @param castResult Result vector, accessed as FlatVector<TOutput> without a
/// type check.
template <typename TInput, typename TOutput>
void applyDoubleToDecimalCastKernel(
    const SelectivityVector& rows,
    const BaseVector& input,
    exec::EvalCtx& context,
    const TypePtr& toType,
    VectorPtr castResult) {
  auto values = input.as<SimpleVector<TInput>>();
  auto rawResults =
      castResult->asUnchecked<FlatVector<TOutput>>()->mutableRawValues();
  const auto& precisionAndScale = getDecimalPrecisionScale(*toType);

  context.applyToSelectedNoThrow(rows, [&](vector_size_t row) {
    auto rescaled = DecimalUtilOp::rescaleDouble<TInput, TOutput>(
        values->valueAt(row),
        precisionAndScale.first,
        precisionAndScale.second);
    // No value means the input could not be represented; emit a null.
    if (!rescaled.has_value()) {
      castResult->setNull(row, true);
      return;
    }
    rawResults[row] = rescaled.value();
  });
}

/// Casts the selected rows of a string vector into a decimal result vector.
///
/// Each selected StringView is passed to DecimalUtilOp::rescaleVarchar
/// together with the precision and scale taken from 'toType'. When no
/// rescaled value comes back, the row is marked null in 'castResult';
/// otherwise the value is stored in the raw values buffer.
///
/// @tparam TOutput Native type of the values held by 'castResult'.
/// @param rows Rows to process.
/// @param input Source vector, accessed as SimpleVector<StringView>.
/// @param context Evaluation context; rows are visited through
/// applyToSelectedNoThrow.
/// @param toType Target type that supplies precision and scale.
/// @param castResult Result vector, accessed as FlatVector<TOutput> without a
/// type check.
template <typename TOutput>
void applyVarCharToDecimalCastKernel(
    const SelectivityVector& rows,
    const BaseVector& input,
    exec::EvalCtx& context,
    const TypePtr& toType,
    VectorPtr castResult) {
  auto strings = input.as<SimpleVector<StringView>>();
  auto rawResults =
      castResult->asUnchecked<FlatVector<TOutput>>()->mutableRawValues();
  const auto& precisionAndScale = getDecimalPrecisionScale(*toType);

  context.applyToSelectedNoThrow(rows, [&](vector_size_t row) {
    auto rescaled = DecimalUtilOp::rescaleVarchar<TOutput>(
        strings->valueAt(row),
        precisionAndScale.first,
        precisionAndScale.second);
    // No value means the input could not be converted; emit a null.
    if (!rescaled.has_value()) {
      castResult->setNull(row, true);
      return;
    }
    rawResults[row] = rescaled.value();
  });
}
} // namespace

template <typename To, typename From>
Expand All @@ -143,6 +233,7 @@ void CastExpr::applyCastWithTry(
const BaseVector& input,
FlatVector<To>* resultFlatVector) {
const auto& queryConfig = context.execCtx()->queryCtx()->queryConfig();
const bool isCastIntAllowDecimal = queryConfig.isCastIntAllowDecimal();

auto* inputSimpleVector = input.as<SimpleVector<From>>();

Expand All @@ -151,8 +242,13 @@ void CastExpr::applyCastWithTry(
bool nullOutput = false;
try {
// Passing a false truncate flag
applyCastKernel<To, From, false>(
row, inputSimpleVector, resultFlatVector, nullOutput);
if (isCastIntAllowDecimal) {
applyCastKernel<To, From, false, true>(
row, inputSimpleVector, resultFlatVector, nullOutput);
} else {
applyCastKernel<To, From, false, false>(
row, inputSimpleVector, resultFlatVector, nullOutput);
}
} catch (const VeloxRuntimeError& re) {
VELOX_FAIL(
makeErrorMessage(input, row, resultFlatVector->type()) + " " +
Expand All @@ -176,8 +272,13 @@ void CastExpr::applyCastWithTry(
bool nullOutput = false;
try {
// Passing a true truncate flag
applyCastKernel<To, From, true>(
row, inputSimpleVector, resultFlatVector, nullOutput);
if (isCastIntAllowDecimal) {
applyCastKernel<To, From, true, true>(
row, inputSimpleVector, resultFlatVector, nullOutput);
} else {
applyCastKernel<To, From, true, false>(
row, inputSimpleVector, resultFlatVector, nullOutput);
}
} catch (const VeloxRuntimeError& re) {
VELOX_FAIL(
makeErrorMessage(input, row, resultFlatVector->type()) + " " +
Expand Down Expand Up @@ -272,6 +373,11 @@ void CastExpr::applyCast(
return applyCastWithTry<To, Timestamp>(
rows, context, input, resultFlatVector);
}
case TypeKind::HUGEINT: {
return applyCastWithTry<To, int128_t>(
rows, context, input, resultFlatVector);
}

default: {
VELOX_UNSUPPORTED("Invalid from type in casting: {}", fromType);
}
Expand Down Expand Up @@ -513,6 +619,10 @@ VectorPtr CastExpr::applyDecimal(
(*castResult).clearNulls(rows);
// toType is a decimal
switch (fromType->kind()) {
case TypeKind::BOOLEAN:
applyIntToDecimalCastKernel<bool, toDecimalType>(
rows, input, context, toType, castResult);
break;
case TypeKind::TINYINT:
applyIntToDecimalCastKernel<int8_t, toDecimalType>(
rows, input, context, toType, castResult);
Expand Down Expand Up @@ -542,6 +652,22 @@ VectorPtr CastExpr::applyDecimal(
break;
}
}
case TypeKind::DATE:
applyDateToDecimalCastKernel<int32_t, toDecimalType>(
rows, input, context, toType, castResult);
break;
case TypeKind::REAL:
applyDoubleToDecimalCastKernel<float, toDecimalType>(
rows, input, context, toType, castResult);
break;
case TypeKind::DOUBLE:
applyDoubleToDecimalCastKernel<double, toDecimalType>(
rows, input, context, toType, castResult);
break;
case TypeKind::VARCHAR:
applyVarCharToDecimalCastKernel<toDecimalType>(
rows, input, context, toType, castResult);
break;
default:
VELOX_UNSUPPORTED(
"Cast from {} to {} is not supported",
Expand Down
25 changes: 25 additions & 0 deletions velox/expression/ExprCompiler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ using core::TypedExprPtr;
const char* const kAnd = "and";
const char* const kOr = "or";
const char* const kRowConstructor = "row_constructor";
// Name matched by getSpecialForm() and used by
// getRowConstructorWithNullExpr() to look up the vector function factory.
const char* const kRowConstructorWithNull = "row_constructor_with_null";

struct ITypedExprHasher {
size_t operator()(const ITypedExpr* expr) const {
Expand Down Expand Up @@ -212,6 +213,25 @@ ExprPtr getRowConstructorExpr(
trackCpuUsage);
}

/// Builds the expression node for the "row_constructor_with_null" function.
///
/// The vector function is obtained from the vector function factory registry
/// under the name kRowConstructorWithNull. The lookup runs while initializing
/// a function-local static, so a successfully created function is reused by
/// later calls.
///
/// @param type Result type of the expression.
/// @param compiledChildren Compiled child expressions; moved into the new
/// Expr.
/// @param trackCpuUsage Forwarded to the Expr constructor.
/// @return The newly created expression.
/// @throws VeloxRuntimeError (via VELOX_CHECK) if no factory is registered
/// under kRowConstructorWithNull.
ExprPtr getRowConstructorWithNullExpr(
    const TypePtr& type,
    std::vector<ExprPtr>&& compiledChildren,
    bool trackCpuUsage) {
  static auto rowConstructorVectorFunction =
      vectorFunctionFactories().withRLock([](auto& functionMap) {
        auto functionIterator = functionMap.find(exec::kRowConstructorWithNull);
        // Fail with a clear error instead of dereferencing end() when the
        // function has not been registered.
        VELOX_CHECK(
            functionIterator != functionMap.end(),
            "Vector function is not registered: {}",
            exec::kRowConstructorWithNull);
        return functionIterator->second.factory(
            exec::kRowConstructorWithNull, {});
      });

  return std::make_shared<Expr>(
      type,
      std::move(compiledChildren),
      rowConstructorVectorFunction,
      exec::kRowConstructorWithNull,
      trackCpuUsage);
}

ExprPtr getSpecialForm(
const std::string& name,
const TypePtr& type,
Expand All @@ -222,6 +242,11 @@ ExprPtr getSpecialForm(
type, std::move(compiledChildren), trackCpuUsage);
}

if (name == kRowConstructorWithNull) {
return getRowConstructorWithNullExpr(
type, std::move(compiledChildren), trackCpuUsage);
}

// If we just check the output of constructSpecialForm we'll have moved
// compiledChildren, and if the function isn't a special form we'll still need
// compiledChildren. Splitting the check in two avoids this use after move.
Expand Down
14 changes: 7 additions & 7 deletions velox/expression/ExprToSubfieldFilter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -439,44 +439,44 @@ std::unique_ptr<common::Filter> leafCallToSubfieldFilter(
common::Subfield& subfield,
core::ExpressionEvaluator* evaluator,
bool negated) {
if (call.name() == "eq") {
if (call.name() == "eq" || call.name() == "equalto") {
if (auto field = asField(&call, 0)) {
if (toSubfield(field, subfield)) {
return negated ? makeNotEqualFilter(call.inputs()[1], evaluator)
: makeEqualFilter(call.inputs()[1], evaluator);
}
}
} else if (call.name() == "neq") {
} else if (call.name() == "neq" || call.name() == "notequalto") {
if (auto field = asField(&call, 0)) {
if (toSubfield(field, subfield)) {
return negated ? makeEqualFilter(call.inputs()[1], evaluator)
: makeNotEqualFilter(call.inputs()[1], evaluator);
}
}
} else if (call.name() == "lte") {
} else if (call.name() == "lte" || call.name() == "lessthanorequal") {
if (auto field = asField(&call, 0)) {
if (toSubfield(field, subfield)) {
return negated ? makeGreaterThanFilter(call.inputs()[1], evaluator)
: makeLessThanOrEqualFilter(call.inputs()[1], evaluator);
}
}
} else if (call.name() == "lt") {
} else if (call.name() == "lt" || call.name() == "lessthan") {
if (auto field = asField(&call, 0)) {
if (toSubfield(field, subfield)) {
return negated
? makeGreaterThanOrEqualFilter(call.inputs()[1], evaluator)
: makeLessThanFilter(call.inputs()[1], evaluator);
}
}
} else if (call.name() == "gte") {
} else if (call.name() == "gte" || call.name() == "greaterthanorequal") {
if (auto field = asField(&call, 0)) {
if (toSubfield(field, subfield)) {
return negated
? makeLessThanFilter(call.inputs()[1], evaluator)
: makeGreaterThanOrEqualFilter(call.inputs()[1], evaluator);
}
}
} else if (call.name() == "gt") {
} else if (call.name() == "gt" || call.name() == "greaterthan") {
if (auto field = asField(&call, 0)) {
if (toSubfield(field, subfield)) {
return negated ? makeLessThanOrEqualFilter(call.inputs()[1], evaluator)
Expand All @@ -496,7 +496,7 @@ std::unique_ptr<common::Filter> leafCallToSubfieldFilter(
return makeInFilter(call.inputs()[1], evaluator, negated);
}
}
} else if (call.name() == "is_null") {
} else if (call.name() == "is_null" || call.name() == "isnull") {
if (auto field = asField(&call, 0)) {
if (toSubfield(field, subfield)) {
if (negated) {
Expand Down
Loading

0 comments on commit 26e686f

Please sign in to comment.