Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-3814][SQL] Support for Bitwise AND(&), OR(|), XOR(^), NOT(~) in Spark HQL and SQL #2961

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ class SqlLexical(val keywords: Seq[String]) extends StdLexical {

delimiters += (
"@", "*", "+", "-", "<", "=", "<>", "!=", "<=", ">=", ">", "/", "(", ")",
",", ";", "%", "{", "}", ":", "[", "]", "."
",", ";", "%", "{", "}", ":", "[", "]", ".", "&", "|", "^", "~"
)

override lazy val token: Parser[Token] =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,9 @@ class SqlParser extends AbstractSparkSQLParser {
( "*" ^^^ { (e1: Expression, e2: Expression) => Multiply(e1, e2) }
| "/" ^^^ { (e1: Expression, e2: Expression) => Divide(e1, e2) }
| "%" ^^^ { (e1: Expression, e2: Expression) => Remainder(e1, e2) }
| "&" ^^^ { (e1: Expression, e2: Expression) => BitwiseAnd(e1, e2) }
| "|" ^^^ { (e1: Expression, e2: Expression) => BitwiseOr(e1, e2) }
| "^" ^^^ { (e1: Expression, e2: Expression) => BitwiseXor(e1, e2) }
)

protected lazy val function: Parser[Expression] =
Expand Down Expand Up @@ -370,6 +373,7 @@ class SqlParser extends AbstractSparkSQLParser {
| dotExpressionHeader
| ident ^^ UnresolvedAttribute
| signedPrimary
| "~" ~> expression ^^ BitwiseNot
)

protected lazy val dotExpressionHeader: Parser[Expression] =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,12 +62,16 @@ package object dsl {

// Unary prefix operators: `-e`, `!e` and `~e` wrap `expr` in the matching
// UnaryMinus, Not and BitwiseNot expression nodes.
def unary_- = UnaryMinus(expr)
def unary_! = Not(expr)
def unary_~ = BitwiseNot(expr)

// Binary arithmetic and bitwise operators: each builds an expression node with
// `expr` as the left operand and `other` as the right operand.
def + (other: Expression) = Add(expr, other)
def - (other: Expression) = Subtract(expr, other)
def * (other: Expression) = Multiply(expr, other)
def / (other: Expression) = Divide(expr, other)
def % (other: Expression) = Remainder(expr, other)
def & (other: Expression) = BitwiseAnd(expr, other)
def | (other: Expression) = BitwiseOr(expr, other)
def ^ (other: Expression) = BitwiseXor(expr, other)

// Logical connectives (distinct from the single-character bitwise `&` / `|` above).
def && (other: Expression) = And(expr, other)
def || (other: Expression) = Or(expr, other)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,23 @@ abstract class BinaryArithmetic extends BinaryExpression {
}
left.dataType
}

/**
 * Null-propagating evaluation: yields null as soon as either operand evaluates to
 * null, otherwise hands both operand values to `evalInternal`. The right operand
 * is only evaluated when the left one is non-null.
 */
override def eval(input: Row): Any = {
  val leftValue = left.eval(input)
  if (leftValue != null) {
    val rightValue = right.eval(input)
    if (rightValue != null) {
      evalInternal(leftValue, rightValue)
    } else {
      null
    }
  } else {
    null
  }
}

/**
 * Hook invoked by `eval` with two non-null operand values. This default
 * implementation always fails; subclasses are expected to override either
 * `eval` or this method.
 */
def evalInternal(evalE1: EvaluatedType, evalE2: EvaluatedType): Any = {
  sys.error("BinaryExpressions must either override eval or evalInternal")
}
}

case class Add(left: Expression, right: Expression) extends BinaryArithmetic {
Expand Down Expand Up @@ -100,6 +117,78 @@ case class Remainder(left: Expression, right: Expression) extends BinaryArithmet
override def eval(input: Row): Any = i2(input, left, right, _.rem(_, _))
}

/**
 * Bitwise AND (`&`) of two operands.
 *
 * Supported for Byte, Short, Int and Long result types; any other data type
 * raises an error at evaluation time.
 */
case class BitwiseAnd(left: Expression, right: Expression) extends BinaryArithmetic {
  def symbol = "&"

  override def evalInternal(evalE1: EvaluatedType, evalE2: EvaluatedType): Any = {
    dataType match {
      case LongType =>
        evalE1.asInstanceOf[Long] & evalE2.asInstanceOf[Long]
      case IntegerType =>
        evalE1.asInstanceOf[Int] & evalE2.asInstanceOf[Int]
      case ShortType =>
        // `&` on Short operands produces an Int, so narrow the result back.
        (evalE1.asInstanceOf[Short] & evalE2.asInstanceOf[Short]).toShort
      case ByteType =>
        // `&` on Byte operands produces an Int, so narrow the result back.
        (evalE1.asInstanceOf[Byte] & evalE2.asInstanceOf[Byte]).toByte
      case unsupported =>
        sys.error(s"Unsupported bitwise & operation on $unsupported")
    }
  }
}

/**
 * A function that calculates bitwise or(|) of two numbers.
 *
 * Supported for Byte, Short, Int and Long result types; any other data type
 * raises an error at evaluation time. Null operands are handled by the
 * inherited `eval`, so `evalInternal` only ever sees non-null values.
 */
case class BitwiseOr(left: Expression, right: Expression) extends BinaryArithmetic {
  // The operator symbol must be "|"; it was previously "&", copied from BitwiseAnd.
  def symbol = "|"

  override def evalInternal(evalE1: EvaluatedType, evalE2: EvaluatedType): Any = dataType match {
    // `|` on Byte/Short operands produces an Int, so narrow the result back.
    case ByteType => (evalE1.asInstanceOf[Byte] | evalE2.asInstanceOf[Byte]).toByte
    case ShortType => (evalE1.asInstanceOf[Short] | evalE2.asInstanceOf[Short]).toShort
    case IntegerType => evalE1.asInstanceOf[Int] | evalE2.asInstanceOf[Int]
    case LongType => evalE1.asInstanceOf[Long] | evalE2.asInstanceOf[Long]
    case other => sys.error(s"Unsupported bitwise | operation on ${other}")
  }
}

/**
 * Bitwise exclusive OR (`^`) of two operands.
 *
 * Supported for Byte, Short, Int and Long result types; any other data type
 * raises an error at evaluation time.
 */
case class BitwiseXor(left: Expression, right: Expression) extends BinaryArithmetic {
  def symbol = "^"

  override def evalInternal(evalE1: EvaluatedType, evalE2: EvaluatedType): Any = {
    dataType match {
      case LongType =>
        evalE1.asInstanceOf[Long] ^ evalE2.asInstanceOf[Long]
      case IntegerType =>
        evalE1.asInstanceOf[Int] ^ evalE2.asInstanceOf[Int]
      case ShortType =>
        // `^` on Short operands produces an Int, so narrow the result back.
        (evalE1.asInstanceOf[Short] ^ evalE2.asInstanceOf[Short]).toShort
      case ByteType =>
        // `^` on Byte operands produces an Int, so narrow the result back.
        (evalE1.asInstanceOf[Byte] ^ evalE2.asInstanceOf[Byte]).toByte
      case unsupported =>
        sys.error(s"Unsupported bitwise ^ operation on $unsupported")
    }
  }
}

/**
 * A function that calculates bitwise not(~) of a number.
 *
 * The result has the same data type, foldability and nullability as the child.
 * Evaluation returns null when the child evaluates to null. Supported for Byte,
 * Short, Int and Long; any other data type raises an error at evaluation time.
 */
case class BitwiseNot(child: Expression) extends UnaryExpression {
  type EvaluatedType = Any

  def dataType = child.dataType
  override def foldable = child.foldable
  def nullable = child.nullable
  // Render with the bitwise-not symbol; the previous "-" prefix made the
  // expression read as an arithmetic negation.
  override def toString = s"~$child"

  override def eval(input: Row): Any = {
    val evalE = child.eval(input)
    if (evalE == null) {
      null
    } else {
      dataType match {
        // `~` on Byte/Short operands produces an Int, so narrow the result back.
        case ByteType => (~(evalE.asInstanceOf[Byte])).toByte
        case ShortType => (~(evalE.asInstanceOf[Short])).toShort
        case IntegerType => ~(evalE.asInstanceOf[Int])
        case LongType => ~(evalE.asInstanceOf[Long])
        case other => sys.error(s"Unsupported bitwise ~ operation on ${other}")
      }
    }
  }
}

case class MaxOf(left: Expression, right: Expression) extends Expression {
type EvaluatedType = Any

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -680,4 +680,36 @@ class ExpressionEvaluationSuite extends FunSuite {

checkEvaluation(Sqrt(Literal(null, DoubleType)), null, new GenericRow(Array[Any](null)))
}

test("Bitwise operations") {
  // Row layout: ordinals 0..2 hold the ints 1, 2 and 3; ordinal 3 holds null.
  val row = new GenericRow(Array[Any](1, 2, 3, null))
  val one = 'a.int.at(0)
  val two = 'a.int.at(1)
  val three = 'a.int.at(2)
  val nullColumn = 'a.int.at(3)
  val nullInt = Literal(null, IntegerType)

  // AND: 1 & 2 == 0; any null operand yields null.
  checkEvaluation(BitwiseAnd(one, nullColumn), null, row)
  checkEvaluation(BitwiseAnd(one, two), 0, row)
  checkEvaluation(BitwiseAnd(one, nullInt), null, row)
  checkEvaluation(BitwiseAnd(nullInt, nullInt), null, row)

  // OR: 1 | 2 == 3; any null operand yields null.
  checkEvaluation(BitwiseOr(one, nullColumn), null, row)
  checkEvaluation(BitwiseOr(one, two), 3, row)
  checkEvaluation(BitwiseOr(one, nullInt), null, row)
  checkEvaluation(BitwiseOr(nullInt, nullInt), null, row)

  // XOR: 1 ^ 2 == 3; any null operand yields null.
  checkEvaluation(BitwiseXor(one, nullColumn), null, row)
  checkEvaluation(BitwiseXor(one, two), 3, row)
  checkEvaluation(BitwiseXor(one, nullInt), null, row)
  checkEvaluation(BitwiseXor(nullInt, nullInt), null, row)

  // NOT: ~1 == -2; a null operand yields null.
  checkEvaluation(BitwiseNot(nullColumn), null, row)
  checkEvaluation(BitwiseNot(one), -2, row)
  checkEvaluation(BitwiseNot(nullInt), null, row)

  // The DSL operators must build the same expressions as the constructors above.
  checkEvaluation(one & two, 0, row)
  checkEvaluation(one | two, 3, row)
  checkEvaluation(one ^ two, 3, row)
  checkEvaluation(~one, -2, row)
}
}
16 changes: 16 additions & 0 deletions sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -883,4 +883,20 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
jsonRDD(data).registerTempTable("records")
sql("SELECT `key?number1` FROM records")
}

test("SPARK-3814 Support Bitwise & operator") {
  // 1 & 1 == 1
  val andResult = sql("SELECT key&1 FROM testData WHERE key = 1 ")
  checkAnswer(andResult, 1)
}

test("SPARK-3814 Support Bitwise | operator") {
  // 1 | 0 == 1
  val orResult = sql("SELECT key|0 FROM testData WHERE key = 1 ")
  checkAnswer(orResult, 1)
}

test("SPARK-3814 Support Bitwise ^ operator") {
  // 1 ^ 0 == 1
  val xorResult = sql("SELECT key^0 FROM testData WHERE key = 1 ")
  checkAnswer(xorResult, 1)
}

test("SPARK-3814 Support Bitwise ~ operator") {
  // ~1 == -2
  val notResult = sql("SELECT ~key FROM testData WHERE key = 1 ")
  checkAnswer(notResult, -2)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -951,13 +951,17 @@ private[hive] object HiveQl {

/* Arithmetic */
case Token("-", child :: Nil) => UnaryMinus(nodeToExpr(child))
case Token("~", child :: Nil) => BitwiseNot(nodeToExpr(child))
case Token("+", left :: right:: Nil) => Add(nodeToExpr(left), nodeToExpr(right))
case Token("-", left :: right:: Nil) => Subtract(nodeToExpr(left), nodeToExpr(right))
case Token("*", left :: right:: Nil) => Multiply(nodeToExpr(left), nodeToExpr(right))
case Token("/", left :: right:: Nil) => Divide(nodeToExpr(left), nodeToExpr(right))
case Token(DIV(), left :: right:: Nil) =>
Cast(Divide(nodeToExpr(left), nodeToExpr(right)), LongType)
case Token("%", left :: right:: Nil) => Remainder(nodeToExpr(left), nodeToExpr(right))
case Token("&", left :: right:: Nil) => BitwiseAnd(nodeToExpr(left), nodeToExpr(right))
case Token("|", left :: right:: Nil) => BitwiseOr(nodeToExpr(left), nodeToExpr(right))
case Token("^", left :: right:: Nil) => BitwiseXor(nodeToExpr(left), nodeToExpr(right))
case Token("TOK_FUNCTION", Token(SQRT(), Nil) :: arg :: Nil) => Sqrt(nodeToExpr(arg))

/* Comparisons */
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,4 +75,28 @@ class SQLQuerySuite extends QueryTest {
sql("SELECT a.key FROM (SELECT key FROM src) `a`"),
sql("SELECT `key` FROM src").collect().toSeq)
}

test("SPARK-3814 Support Bitwise & operator") {
  // The CASE expression yields 1 for every row of src when 1&1 evaluates to 1.
  val actual = sql("SELECT case when 1&1=1 then 1 else 0 end FROM src")
  val expected = sql("SELECT 1 FROM src").collect().toSeq
  checkAnswer(actual, expected)
}

test("SPARK-3814 Support Bitwise | operator") {
  // The CASE expression yields 1 for every row of src when 1|0 evaluates to 1.
  val actual = sql("SELECT case when 1|0=1 then 1 else 0 end FROM src")
  val expected = sql("SELECT 1 FROM src").collect().toSeq
  checkAnswer(actual, expected)
}

test("SPARK-3814 Support Bitwise ^ operator") {
  // The CASE expression yields 1 for every row of src when 1^0 evaluates to 1.
  val actual = sql("SELECT case when 1^0=1 then 1 else 0 end FROM src")
  val expected = sql("SELECT 1 FROM src").collect().toSeq
  checkAnswer(actual, expected)
}

test("SPARK-3814 Support Bitwise ~ operator") {
  // The CASE expression yields 1 for every row of src when ~1 evaluates to -2.
  val actual = sql("SELECT case when ~1=-2 then 1 else 0 end FROM src")
  val expected = sql("SELECT 1 FROM src").collect().toSeq
  checkAnswer(actual, expected)
}
}