Skip to content

Commit

Permalink
Factor comparison functors out of LeastGreatest.cpp (#302)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #302

SQL comparisons (<, >, and ==) will be used in many functions. Ideally we'd have
these orderings defined in one common locations.

In this diff, I put these in a separate header Comparisons.h and update
LeastGreatest.cpp to use them. This required a little bit of refactoring.

Reviewed By: beroyfb

Differential Revision: D31185337

fbshipit-source-id: bc8b4534284b47cf4110efae04d7559f2d31623d
  • Loading branch information
funrollloops authored and facebook-github-bot committed Sep 28, 2021
1 parent cac6072 commit 7325060
Show file tree
Hide file tree
Showing 3 changed files with 117 additions and 100 deletions.
62 changes: 62 additions & 0 deletions velox/functions/sparksql/Comparisons.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "velox/expression/VectorFunction.h"

namespace facebook::velox::functions::sparksql {

// Comparison functions that implement SparkSQL semantics.
// Intended to be used with scalar types (including strings and timestamps).
template <typename T>
static inline bool isNan(const T& value) {
if constexpr (std::is_floating_point<T>::value) {
return std::isnan(value);
} else {
return false;
}
}

template <typename T>
struct Less {
constexpr bool operator()(const T& a, const T& b) const {
if (isNan(a)) {
return false;
}
if (isNan(b)) {
return true;
}
return a < b;
}
};

template <typename T>
struct Greater : private Less<T> {
constexpr bool operator()(const T& a, const T& b) const {
return Less<T>::operator()(b, a);
}
};

template <typename T>
struct Equal {
constexpr bool operator()(const T& a, const T& b) const {
// In SparkSQL, NaN is defined to equal NaN.
if (isNan(a)) {
return isNan(b);
}
return a == b;
}
};

} // namespace facebook::velox::functions::sparksql
151 changes: 51 additions & 100 deletions velox/functions/sparksql/LeastGreatest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,22 +13,20 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "velox/functions/sparksql/LeastGreatest.h"

#include "velox/expression/EvalCtx.h"
#include "velox/expression/Expr.h"
#include "velox/expression/VectorFunction.h"
#include "velox/functions/sparksql/Comparisons.h"
#include "velox/type/Type.h"

#include "velox/functions/sparksql/LeastGreatest.h"

namespace facebook::velox::functions::sparksql {
namespace {

enum CompareType { MIN, MAX };
template <typename Cmp, TypeKind kind>
class LeastGreatestFunction final : public exec::VectorFunction {
using T = typename TypeTraits<kind>::NativeType;

template <CompareType type>
class LeastGreatestFunction : public exec::VectorFunction {
public:
bool isDefaultNullBehavior() const override {
return false;
}
Expand All @@ -39,52 +37,6 @@ class LeastGreatestFunction : public exec::VectorFunction {
exec::Expr* caller,
exec::EvalCtx* context,
VectorPtr* result) const override {
VELOX_CHECK(args.size() >= 2);
for (size_t i = 1; i < args.size(); i++) {
VELOX_CHECK(args[i]->typeKind() == args[0]->typeKind());
}

switch (args[0]->typeKind()) {
case TypeKind::TINYINT:
applyTyped<int8_t>(rows, args, caller, context, result);
break;
case TypeKind::SMALLINT:
applyTyped<int16_t>(rows, args, caller, context, result);
break;
case TypeKind::INTEGER:
applyTyped<int32_t>(rows, args, caller, context, result);
break;
case TypeKind::BIGINT:
applyTyped<int64_t>(rows, args, caller, context, result);
break;
case TypeKind::REAL:
applyTyped<float>(rows, args, caller, context, result);
break;
case TypeKind::DOUBLE:
applyTyped<double>(rows, args, caller, context, result);
break;
case TypeKind::BOOLEAN:
applyTyped<bool>(rows, args, caller, context, result);
break;
case TypeKind::VARCHAR:
applyTyped<StringView>(rows, args, caller, context, result);
break;
case TypeKind::TIMESTAMP:
applyTyped<Timestamp>(rows, args, caller, context, result);
break;
default:
VELOX_CHECK(false, "Bad type for least/greatest");
}
}

private:
template <typename T>
void applyTyped(
const SelectivityVector& rows,
const std::vector<VectorPtr>& args,
exec::Expr* caller,
exec::EvalCtx* context,
VectorPtr* result) const {
auto isFlatVector = [](const VectorPtr vp) -> bool {
return vp->encoding() == VectorEncoding::Simple::FLAT;
};
Expand Down Expand Up @@ -121,63 +73,69 @@ class LeastGreatestFunction : public exec::VectorFunction {
}
}

template <typename T>
inline T getValue(const FlatVector<T>& v, vector_size_t i) const {
return v.valueAt(i);
}

template <typename T>
inline T getValue(const DecodedVector& v, vector_size_t i) const {
return v.valueAt<T>(i);
}

template <typename T>
static bool isNaN(const T& value) {
if constexpr (std::is_floating_point<T>::value) {
// In C++ NaN is not comparable (i.e. all comparisons return false)
// to any value (including itself).
return std::isnan(value);
} else {
return false;
}
}

template <typename T, typename VecType>
template <typename VecType>
void cmpAndReplace(
FlatVector<T>& dst,
const VecType& src,
SelectivityVector& rows) const {
const Cmp cmp;
rows.applyToSelected([&](vector_size_t row) {
const auto srcVal = getValue<T>(src, row);
const auto dstVal = getValue<T>(dst, row);

if (dst.isNullAt(row)) {
const auto srcVal = getValue(src, row);
if (dst.isNullAt(row) || cmp(srcVal, getValue(dst, row))) {
dst.set(row, srcVal);
} else {
// NaN is treated as the largest number
if constexpr (type == CompareType::MIN) {
if (isNaN(dstVal)) {
dst.set(row, srcVal);
} else if (!isNaN(srcVal)) {
dst.set(row, std::min(dstVal, srcVal));
}
} else {
if (isNaN(srcVal)) {
dst.set(row, srcVal);
} else if (!isNaN(dstVal)) {
dst.set(row, std::max(dstVal, srcVal));
}
}
}
});
}
};

template <template <typename> class Cmp>
std::shared_ptr<exec::VectorFunction> makeImpl(
const std::string& functionName,
const std::vector<exec::VectorFunctionArg>& args) {
VELOX_CHECK_GE(args.size(), 2);
for (size_t i = 1; i < args.size(); i++) {
VELOX_CHECK(*args[i].type == *args[0].type);
}

switch (args[0].type->kind()) {
#define SCALAR_CASE(kind) \
case TypeKind::kind: \
return std::make_shared<LeastGreatestFunction< \
Cmp<TypeTraits<TypeKind::kind>::NativeType>, \
TypeKind::kind>>();
SCALAR_CASE(BOOLEAN);
SCALAR_CASE(TINYINT);
SCALAR_CASE(SMALLINT);
SCALAR_CASE(INTEGER);
SCALAR_CASE(BIGINT);
SCALAR_CASE(REAL);
SCALAR_CASE(DOUBLE);
SCALAR_CASE(VARCHAR);
SCALAR_CASE(VARBINARY);
SCALAR_CASE(TIMESTAMP);
#undef SCALAR_CASE
default:
VELOX_NYI(
"{} does not support arguments of type {}",
functionName,
args[0].type->kind());
}
}

} // namespace

std::shared_ptr<exec::VectorFunction> makeLeast(
const std::string& /**/,
const std::vector<exec::VectorFunctionArg>& /**/) {
return std::make_shared<LeastGreatestFunction<CompareType::MIN>>();
const std::string& functionName,
const std::vector<exec::VectorFunctionArg>& args) {
return makeImpl<Less>(functionName, args);
}

std::vector<std::shared_ptr<exec::FunctionSignature>> leastSignatures() {
Expand All @@ -192,20 +150,13 @@ std::vector<std::shared_ptr<exec::FunctionSignature>> leastSignatures() {
}

std::shared_ptr<exec::VectorFunction> makeGreatest(
const std::string& /**/,
const std::vector<exec::VectorFunctionArg>& /**/) {
return std::make_shared<LeastGreatestFunction<CompareType::MAX>>();
const std::string& functionName,
const std::vector<exec::VectorFunctionArg>& args) {
return makeImpl<Greater>(functionName, args);
}

std::vector<std::shared_ptr<exec::FunctionSignature>> greatestSignatures() {
// T, T... -> T
return {exec::FunctionSignatureBuilder()
.typeVariable("T")
.returnType("T")
.argumentType("T")
.argumentType("T")
.variableArity()
.build()};
return leastSignatures();
}

} // namespace facebook::velox::functions::sparksql
4 changes: 4 additions & 0 deletions velox/functions/sparksql/LeastGreatest.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,10 @@

#pragma once

#include <memory>

#include "velox/expression/VectorFunction.h"

namespace facebook::velox::functions::sparksql {

/// Supported Types:
Expand Down

0 comments on commit 7325060

Please sign in to comment.