From c1661675df87163c49d6fc3692a984e93cd12eab Mon Sep 17 00:00:00 2001 From: Shilei Date: Fri, 12 Jun 2015 14:20:17 +0800 Subject: [PATCH] Add md5 function --- .../catalyst/analysis/FunctionRegistry.scala | 3 ++ .../spark/sql/catalyst/expressions/misc.scala | 52 +++++++++++++++++++ .../expressions/MiscFunctionsSuite.scala | 32 ++++++++++++ .../org/apache/spark/sql/functions.scala | 19 +++++++ .../spark/sql/DataFrameFunctionsSuite.scala | 11 ++++ 5 files changed, 117 insertions(+) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscFunctionsSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 04e306da23e4c..4d20576b8539f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -130,6 +130,9 @@ object FunctionRegistry { expression[ToDegrees]("degrees"), expression[ToRadians]("radians"), + // misc functions + expression[Md5]("md5"), + // aggregate functions expression[Average]("avg"), expression[Count]("count"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala new file mode 100644 index 0000000000000..fbdceb72c1ae7 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import java.security.MessageDigest + +import org.apache.commons.codec.digest.DigestUtils +import org.apache.spark.sql.types.{BinaryType, StringType, DataType} +import org.apache.spark.unsafe.types.UTF8String + +/** + * A function that calculates an MD5 128-bit checksum for the string or binary. + * Defined for String and Binary types. + */ +case class Md5(child: Expression) + extends UnaryExpression with ExpectsInputTypes { + + override def dataType: DataType = StringType + + override def expectedChildTypes: Seq[DataType] = + if (child.dataType == BinaryType) Seq(BinaryType) else Seq(StringType) + + override def children: Seq[Expression] = child :: Nil + + override def eval(input: Row): Any = { + val value = child.eval(input) + if (value == null) { + null + } else if (child.dataType == BinaryType) { + UTF8String.fromString(DigestUtils.md5Hex(value.asInstanceOf[Array[Byte]])) + } else { + UTF8String.fromString(DigestUtils.md5Hex(value.asInstanceOf[UTF8String].getBytes)) + } + } + + override def toString: String = s"md5($child)" +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscFunctionsSuite.scala new file mode 100644 index 0000000000000..f3dec71c5b756 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscFunctionsSuite.scala @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.dsl.expressions._ + +class MiscFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { + + test("md5") { + val s1 = 'a.string.at(0) + val s2 = 'a.binary.at(0) + checkEvaluation(Md5(s1), "902fbdd2b1df0c4f70b4a5d23525e932", create_row("ABC")) + checkEvaluation(Md5(s2), "6ac1e56bc78f031059be7be854522c4c", create_row(Array[Byte](1,2,3,4,5,6))) + } + +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index c5b77724aae17..ab37576415a41 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -36,6 +36,7 @@ import org.apache.spark.util.Utils * @groupname sort_funcs Sorting functions * @groupname normal_funcs Non-aggregate functions * @groupname math_funcs Math functions + * @groupname misc_funcs Misc functions * @groupname window_funcs Window functions * @groupname string_funcs String functions * @groupname Ungrouped Support functions for DataFrames. @@ -1334,6 +1335,24 @@ object functions { */ def toRadians(columnName: String): Column = toRadians(Column(columnName)) + ////////////////////////////////////////////////////////////////////////////////////////////// + // Misc functions + ////////////////////////////////////////////////////////////////////////////////////////////// + + /** + * Calculates an MD5 128-bit checksum for the string or binary + * @group misc_funcs + * @since 1.5.0 + */ + def md5(e: Column): Column = Md5(e.expr) + + /** + * Calculates an MD5 128-bit checksum for the string or binary + * @group misc_funcs + * @since 1.5.0 + */ + def md5(columnName: String): Column = md5(Column(columnName)) + ////////////////////////////////////////////////////////////////////////////////////////////// // String functions ////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index cfd23867a9bba..6d7750248aa73 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -123,6 +123,17 @@ class DataFrameFunctionsSuite extends QueryTest { Row("x", "y", null)) } + test("misc md5 function") { + val df = Seq(("ABC", Array[Byte](1, 2, 3, 4, 5, 6))).toDF("a", "b") + checkAnswer( + df.select(md5($"a"), md5("b")), + Row("902fbdd2b1df0c4f70b4a5d23525e932", "6ac1e56bc78f031059be7be854522c4c")) + + checkAnswer( + df.selectExpr("md5(a)", "md5(b)"), + Row("902fbdd2b1df0c4f70b4a5d23525e932", "6ac1e56bc78f031059be7be854522c4c")) + } + test("string length function") { checkAnswer( nullStrings.select(strlen($"s"), strlen("s")),