Add math function bin(a: long): string.
viirya committed Jun 9, 2015
1 parent eacd4a9 commit 50e0c3b
Showing 5 changed files with 70 additions and 13 deletions.
1 change: 1 addition & 0 deletions python/pyspark/sql/functions.py
@@ -136,6 +136,7 @@ def _():
'measured in radians.',

'bitwiseNOT': 'Computes bitwise not.',
'bin': 'Computes the binary format of the given value.',
}

# math functions that take two arguments as input
@@ -17,24 +17,35 @@

package org.apache.spark.sql.catalyst.expressions

import java.lang.{Long => JLong}

import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.types.{DataType, DoubleType}
import org.apache.spark.sql.types.{DataType, DoubleType, LongType, StringType, UTF8String}

/**
* A unary expression specifically for math functions. Math Functions expect a specific type of
* input format, therefore these functions extend `ExpectsInputTypes`.
* @param name The short name of the function
*/
abstract class UnaryMathExpression(f: Double => Double, name: String)
abstract class AbstractUnaryMathExpression[T, U](name: String)
extends UnaryExpression with Serializable with ExpectsInputTypes {
self: Product =>

override def expectedChildTypes: Seq[DataType] = Seq(DoubleType)
override def dataType: DataType = DoubleType
override def foldable: Boolean = child.foldable
override def nullable: Boolean = true
override def toString: String = s"$name($child)"

// name of function in java.lang.Math
def funcName: String = name.toLowerCase
}

abstract class UnaryMathExpression(f: Double => Double, name: String)
extends AbstractUnaryMathExpression[Double, Double](name) {
self: Product =>

override def expectedChildTypes: Seq[DataType] = Seq(DoubleType)
override def dataType: DataType = DoubleType

override def eval(input: Row): Any = {
val evalE = child.eval(input)
if (evalE == null) {
@@ -45,9 +56,6 @@ abstract class UnaryMathExpression(f: Double => Double, name: String)
}
}

// name of function in java.lang.Math
def funcName: String = name.toLowerCase

override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
val eval = child.gen(ctx)
eval.code + s"""
@@ -152,6 +160,26 @@ case class ToRadians(child: Expression) extends UnaryMathExpression(math.toRadia
override def funcName: String = "toRadians"
}

case class Bin(child: Expression)
extends AbstractUnaryMathExpression[Long, String]("BIN") {

override def expectedChildTypes: Seq[DataType] = Seq(LongType)
override def dataType: DataType = StringType

override def eval(input: Row): Any = {
val evalE = child.eval(input)
if (evalE == null) {
null
} else {
UTF8String(JLong.toBinaryString(evalE.asInstanceOf[Long]))
}
}

override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
defineCodeGen(ctx, ev, (c) =>
s"org.apache.spark.sql.types.UTF8String.apply(java.lang.Long.toBinaryString($c))")
}
}

////////////////////////////////////////////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////////////////////////////////////////////
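One semantic note, added here for context rather than taken from the commit: `java.lang.Long.toBinaryString` renders a negative input as its 64-bit two's-complement bit pattern (no minus sign), so that is what `Bin` returns for negative longs. A minimal standalone sketch of the conversion, outside Catalyst; `binString` is a hypothetical helper name used only for illustration:

// Plain-Scala sketch of the conversion Bin performs; not part of the commit.
object BinSketch {
  def binString(value: java.lang.Long): String =
    if (value == null) null                              // null in, null out, as in eval()
    else java.lang.Long.toBinaryString(value.longValue)  // negatives become 64-bit two's complement

  def main(args: Array[String]): Unit = {
    println(binString(13L))  // "1101"
    println(binString(0L))   // "0"
    println(binString(-1L))  // a run of sixty-four '1' characters
  }
}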
@@ -18,7 +18,7 @@
package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.types.DoubleType
import org.apache.spark.sql.types.{DataType, DoubleType, LongType}

class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {

@@ -31,11 +31,12 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
* @param expectNull Whether the given values should return null or not
* @tparam T Generic type for primitives
*/
private def testUnary[T](
private def testUnary[T, U](
c: Expression => Expression,
f: T => T,
f: T => U,
domain: Iterable[T] = (-20 to 20).map(_ * 0.1),
expectNull: Boolean = false): Unit = {
expectNull: Boolean = false,
evalType: DataType = DoubleType): Unit = {
if (expectNull) {
domain.foreach { value =>
checkEvaluation(c(Literal(value)), null, EmptyRow)
@@ -45,7 +46,7 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(c(Literal(value)), f(value), EmptyRow)
}
}
checkEvaluation(c(Literal.create(null, DoubleType)), null, create_row(null))
checkEvaluation(c(Literal.create(null, evalType)), null, create_row(null))
}

/**
@@ -145,7 +146,7 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
}

test("signum") {
testUnary[Double](Signum, math.signum)
testUnary[Double, Double](Signum, math.signum)
}

test("log") {
@@ -163,6 +164,10 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
testUnary(Log1p, math.log1p, (-10 to -2).map(_ * 1.0), expectNull = true)
}

test("bin") {
testUnary(Bin, java.lang.Long.toBinaryString, (-20 to 20).map(_.toLong), evalType = LongType)
}

test("pow") {
testBinary(Pow, math.pow, (-5 to 5).map(v => (v * 1.0, v * 1.0)))
testBinary(Pow, math.pow, Seq((-1.0, 0.9), (-2.2, 1.7), (-2.2, -1.7)), expectNull = true)
15 changes: 15 additions & 0 deletions sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -1299,6 +1299,21 @@ object functions {
*/
def toRadians(columnName: String): Column = toRadians(Column(columnName))

/**
* Computes the binary format of the given value.
*
* @group math_funcs
* @since 1.4.0
*/
def bin(e: Column): Column = Bin(e.expr)

/**
* Computes the binary format of the given value.
*
* @group math_funcs
* @since 1.4.0
*/
def bin(columnName: String): Column = bin(Column(columnName))

//////////////////////////////////////////////////////////////////////////////////////////////
//////////////////////////////////////////////////////////////////////////////////////////////
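For orientation (not part of the diff): once these two overloads exist, `bin` is used like any other math function in `org.apache.spark.sql.functions`. A hedged usage sketch against a Spark 1.4-style SQLContext; the SparkContext setup, the DataFrame, and its column names are illustrative assumptions, not taken from this commit:

import org.apache.spark.SparkContext
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.functions._

// Illustrative local setup; any Spark 1.4-style SQLContext works the same way.
val sc = new SparkContext("local[2]", "bin-example")
val sqlContext = new SQLContext(sc)
import sqlContext.implicits._

// Hypothetical DataFrame with a LongType column "id".
val df = Seq((1L, "one"), (5L, "five"), (255L, "max-byte")).toDF("id", "label")

df.select($"id", bin($"id").as("id_bin"), bin("id").as("id_bin_by_name")).show()
// id = 5   -> id_bin = "101"
// id = 255 -> id_bin = "11111111"

sc.stop()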
@@ -17,6 +17,8 @@

package org.apache.spark.sql

import java.lang.{Long => JLong}

import org.apache.spark.sql.TestData._
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
@@ -90,4 +92,10 @@ class DataFrameFunctionsSuite extends QueryTest {
testData2.select(bitwiseNOT($"a")),
testData2.collect().toSeq.map(r => Row(~r.getInt(0))))
}

test("bin") {
checkAnswer(
testData2.select(bin($"a")),
testData2.collect().toSeq.map(r => Row(JLong.toBinaryString(r.getInt(0).toLong))))
}
}
