Skip to content

Commit

Permalink
Add hash seed parameter to sparksql hash functions (facebookincubator…
Browse files Browse the repository at this point in the history
  • Loading branch information
marin-ma authored and zhejiangxiaomai committed May 26, 2023
1 parent 0099149 commit c905faa
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 15 deletions.
50 changes: 35 additions & 15 deletions velox/functions/sparksql/Hash.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,16 +33,22 @@ void applyWithType(
std::vector<VectorPtr>& args, // Not using const ref so we can reuse args
exec::EvalCtx& context,
VectorPtr& resultRef) {
constexpr SeedType kSeed = 42;
SeedType seed = 42;
auto hashIdx = 0;
if (args[0]->isConstantEncoding()) {
seed = args[0]->as<ConstantVector<SeedType>>()->valueAt(0);
hashIdx = 1;
}

HashClass hash;

auto& result = *resultRef->as<FlatVector<ReturnType>>();
rows.applyToSelected([&](int row) { result.set(row, kSeed); });
rows.applyToSelected([&](int row) { result.set(row, seed); });

exec::LocalSelectivityVector selectedMinusNulls(context);

exec::DecodedArgs decodedArgs(rows, args, context);
for (auto i = 0; i < args.size(); i++) {
for (auto i = hashIdx; i < args.size(); i++) {
auto decoded = decodedArgs.at(i);
const SelectivityVector* selected = &rows;
if (args[i]->mayHaveNulls()) {
Expand Down Expand Up @@ -193,7 +199,7 @@ class Murmur3HashFunction final : public exec::VectorFunction {
exec::EvalCtx& context,
VectorPtr& resultRef) const final {
context.ensureWritable(rows, INTEGER(), resultRef);
applyWithType<int32_t, Murmur3Hash, uint32_t>(
applyWithType<int32_t, Murmur3Hash, int32_t>(
rows, args, context, resultRef);
}
};
Expand Down Expand Up @@ -359,18 +365,25 @@ class XxHash64Function final : public exec::VectorFunction {
exec::EvalCtx& context,
VectorPtr& resultRef) const final {
context.ensureWritable(rows, BIGINT(), resultRef);
applyWithType<int64_t, XxHash64, uint64_t>(rows, args, context, resultRef);
applyWithType<int64_t, XxHash64, int64_t>(rows, args, context, resultRef);
}
};

} // namespace

std::vector<std::shared_ptr<exec::FunctionSignature>> hashSignatures() {
return {exec::FunctionSignatureBuilder()
.returnType("integer")
.argumentType("any")
.variableArity()
.build()};
return {
exec::FunctionSignatureBuilder()
.returnType("integer")
.argumentType("any")
.variableArity()
.build(),
exec::FunctionSignatureBuilder()
.returnType("integer")
.constantArgumentType("integer")
.argumentType("any")
.variableArity()
.build()};
}

std::shared_ptr<exec::VectorFunction> makeHash(
Expand All @@ -381,11 +394,18 @@ std::shared_ptr<exec::VectorFunction> makeHash(
}

std::vector<std::shared_ptr<exec::FunctionSignature>> xxhash64Signatures() {
return {exec::FunctionSignatureBuilder()
.returnType("integer")
.argumentType("any")
.variableArity()
.build()};
return {
exec::FunctionSignatureBuilder()
.returnType("integer")
.argumentType("any")
.variableArity()
.build(),
exec::FunctionSignatureBuilder()
.returnType("integer")
.constantArgumentType("bigint")
.argumentType("any")
.variableArity()
.build()};
}

std::shared_ptr<exec::VectorFunction> makeXxHash64(
Expand Down
27 changes: 27 additions & 0 deletions velox/functions/sparksql/tests/XxHash64Test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,11 @@ class XxHash64Test : public SparkFunctionBaseTest {
std::optional<int64_t> xxhash64(std::optional<T> arg) {
return evaluateOnce<int64_t>("xxhash64(c0)", arg);
}

template <typename T, typename Seed>
std::optional<int64_t> xxhash64WithSeed(Seed seed, std::optional<T> arg) {
return evaluateOnce<int64_t>(fmt::format("xxhash64({}, c0)", seed), arg);
}
};

// The expected result was obtained by running SELECT xxhash64("Spark") query
Expand Down Expand Up @@ -118,5 +123,27 @@ TEST_F(XxHash64Test, float) {
EXPECT_EQ(xxhash64<float>(limits::infinity()), -5940311692336719973);
EXPECT_EQ(xxhash64<float>(-limits::infinity()), -7580553461823983095);
}

TEST_F(XxHash64Test, hashSeed) {
using long_limits = std::numeric_limits<int64_t>;
using int_limits = std::numeric_limits<int32_t>;
std::vector<int64_t> seeds = {
long_limits::min(),
static_cast<int64_t>(int_limits::min()) - 1L,
0L,
static_cast<int64_t>(int_limits::max()) + 1L,
long_limits::max()};
std::vector<int64_t> expected = {
-6671470883434376173,
8765374525824963196,
-5379971487550586029,
8810073187160811495,
4605443450566835086};

for (auto i = 0; i < seeds.size(); ++i) {
EXPECT_EQ(xxhash64WithSeed<int64_t>(seeds[i], 42), expected[i]);
}
}

} // namespace
} // namespace facebook::velox::functions::sparksql::test

0 comments on commit c905faa

Please sign in to comment.