Skip to content

Commit

Permalink
[SPARK-7294] Add a `between` operator to Column (Scala DataFrame DSL and PySpark)
Browse files Browse the repository at this point in the history
  • Loading branch information
云峤 committed May 1, 2015
1 parent a9fc505 commit d11d5b9
Show file tree
Hide file tree
Showing 5 changed files with 49 additions and 0 deletions.
12 changes: 12 additions & 0 deletions python/pyspark/sql/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -1289,6 +1289,18 @@ def cast(self, dataType):
raise TypeError("unexpected type: %s" % type(dataType))
return Column(jc)

@ignore_unicode_prefix
def between(self, col1, col2):
    """A boolean expression that is evaluated to true if the value of this
    expression is between ``col1`` and ``col2``, inclusive of both bounds
    (SQL ``BETWEEN`` semantics).

    :param col1: lower bound (a Column or a literal value).
    :param col2: upper bound (a Column or a literal value).
    :return: a boolean :class:`Column`.

    >>> df[df.col1.between(df.col2, df.col3)].collect()
    [Row(col1=5, col2=3, col3=8)]
    """
    # Parentheses are required: Python's `&` binds tighter than `>=`/`<=`,
    # so the unparenthesized form would parse as a chained comparison and
    # produce the wrong expression. The result of `&` on Columns is already
    # a Column, so no extra wrapping is needed.
    return (self >= col1) & (self <= col2)

def __repr__(self):
    # Render as Column<...> using the underlying JVM column's string form
    # (presumably self._jc is the py4j handle to the Java Column — TODO confirm);
    # encoded to bytes for Python 2 `str` compatibility.
    return 'Column<%s>' % self._jc.toString().encode('utf8')

Expand Down
6 changes: 6 additions & 0 deletions python/pyspark/sql/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -426,6 +426,12 @@ def test_rand_functions(self):
for row in rndn:
assert row[1] >= -4.0 and row[1] <= 4.0, "got: %s" % row[1]

def test_between_function(self):
    """between(b, c) on column a should yield one boolean per row."""
    # parallelize lives on SparkContext (self.sc), not on SQLContext;
    # the original `self.sqlCtx.parallelize` would raise AttributeError.
    df = self.sc.parallelize([
        Row(a=1, b=2, c=3), Row(a=2, b=1, c=3), Row(a=4, b=1, c=3)]).toDF()
    # collect() returns Row objects, not raw booleans, so extract the
    # single selected column before comparing.
    self.assertEqual(
        [False, True, False],
        [r[0] for r in df.select(df.a.between(df.b, df.c)).collect()])


def test_save_and_load(self):
df = self.df
tmpPath = tempfile.mkdtemp()
Expand Down
14 changes: 14 additions & 0 deletions sql/core/src/main/scala/org/apache/spark/sql/Column.scala
Original file line number Diff line number Diff line change
Expand Up @@ -295,6 +295,20 @@ class Column(protected[sql] val expr: Expression) extends Logging {
*/
def eqNullSafe(other: Any): Column = this <=> other

/**
 * A boolean expression that is true if the value of this column lies between
 * the values of the columns named `col1` (lower bound) and `col2` (upper
 * bound). Convenience overload that resolves the names and delegates to the
 * `Column`-based `between`.
 *
 * @group java_expr_ops
 */
def between(col1: String, col2: String): Column = between(Column(col1), Column(col2))

/**
 * A boolean expression that is true if the value of this column is between
 * `col1` (lower bound) and `col2` (upper bound), inclusive of both bounds —
 * matching SQL `BETWEEN` semantics.
 *
 * @group java_expr_ops
 */
def between(col1: Column, col2: Column): Column =
  // Use the Column operators rather than raw expressions: `And(...)` alone is
  // an `Expression`, not a `Column`, so the original body did not satisfy the
  // declared return type; and BETWEEN is inclusive, so `>=`/`<=` are required
  // instead of the strict `GreaterThan`/`LessThan`.
  (this >= col1) && (this <= col2)

/**
* True if the current expression is null.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,12 @@ class ColumnExpressionSuite extends QueryTest {
testData2.collect().toSeq.filter(r => r.getInt(0) <= r.getInt(1)))
}

test("between") {
  // SQL BETWEEN is inclusive of both bounds, so the reference predicate must
  // use >= and <= (the original used strict > / <, which row (2, 2, 4) in
  // testData4 would mis-classify).
  checkAnswer(
    testData4.filter($"a".between($"b", $"c")),
    testData4.collect().toSeq.filter(r =>
      r.getInt(0) >= r.getInt(1) && r.getInt(0) <= r.getInt(2)))
}

val booleanData = TestSQLContext.createDataFrame(TestSQLContext.sparkContext.parallelize(
Row(false, false) ::
Row(false, true) ::
Expand Down
11 changes: 11 additions & 0 deletions sql/core/src/test/scala/org/apache/spark/sql/TestData.scala
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,17 @@ object TestData {
TestData2(3, 2) :: Nil, 2).toDF()
testData2.registerTempTable("testData2")

// Fixture for the BETWEEN tests: rows where `a` may or may not lie in [b, c].
case class TestData4(a: Int, b: Int, c: Int)
val testData4 =
  TestSQLContext.sparkContext.parallelize(
    TestData4(0, 1, 2) ::
    TestData4(1, 2, 3) ::
    TestData4(2, 1, 0) ::
    TestData4(2, 2, 4) ::
    TestData4(3, 1, 6) ::
    TestData4(3, 2, 0) :: Nil, 2).toDF()
// Register under the lowercase val name for consistency with the sibling
// tables (e.g. testData2 -> "testData2" above).
testData4.registerTempTable("testData4")

case class DecimalData(a: BigDecimal, b: BigDecimal)

val decimalData =
Expand Down

0 comments on commit d11d5b9

Please sign in to comment.