From baf839b4a4aa8d7d4ab8cdb1a5b82affd3ce376e Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E4=BA=91=E5=B3=A4?=
Date: Sat, 2 May 2015 09:39:17 +0800
Subject: [PATCH] [SPARK-7294] ADD BETWEEN

---
 python/pyspark/sql/dataframe.py                    |  8 +++-----
 python/pyspark/sql/tests.py                        |  6 +++---
 .../main/scala/org/apache/spark/sql/Column.scala   | 13 +++++++++----
 .../apache/spark/sql/ColumnExpressionSuite.scala   |  2 +-
 4 files changed, 16 insertions(+), 13 deletions(-)

diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index a4cbc7396e386..8c09bf23f3cc0 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -1290,15 +1290,13 @@ def cast(self, dataType):
         return Column(jc)
 
     @ignore_unicode_prefix
-    def between(self, col1, col2):
+    def between(self, lowerBound, upperBound):
         """ A boolean expression that is evaluated to true if the value of this
         expression is between the given columns.
 
-        >>> df[df.col1.between(col2, col3)].collect()
+        >>> df[df.col2.between(df.col1, df.col3)].collect()
         [Row(col1=5, col2=6, col3=8)]
         """
-        #sc = SparkContext._active_spark_context
-        jc = self > col1 & self < col2
-        return Column(jc)
+        return (self >= lowerBound) & (self <= upperBound)
 
     def __repr__(self):
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 206e3b7fd08f2..b5faedfe15e46 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -427,6 +427,6 @@ def test_rand_functions(self):
             assert row[1] >= -4.0 and row[1] <= 4.0, "got: %s" % row[1]
 
     def test_between_function(self):
-        df = self.sqlCtx.parallelize([Row(a=1, b=2, c=3), Row(a=2, b=1, c=3), Row(a=4, b=1, c=3)]).toDF()
-        self.assertEqual([False, True, False],
-                         df.select(df.a.between(df.b, df.c)).collect())
+        df = self.sc.parallelize([Row(a=1, b=2, c=3), Row(a=2, b=1, c=3), Row(a=4, b=1, c=4)]).toDF()
+        self.assertEqual([Row(a=2, b=1, c=3), Row(a=4, b=1, c=4)],
+                         df.filter(df.a.between(df.b, df.c)).collect())
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
index 8e0eab7918131..b51b6368eeb56 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
@@ -296,18 +296,23 @@ class Column(protected[sql] val expr: Expression) extends Logging {
   def eqNullSafe(other: Any): Column = this <=> other
 
   /**
-   * Between col1 and col2.
+   * True if the current column is between the lower bound and upper bound, inclusive.
    *
    * @group java_expr_ops
    */
-  def between(col1: String, col2: String): Column = between(Column(col1), Column(col2))
+  def between(lowerBound: String, upperBound: String): Column = {
+    between(Column(lowerBound), Column(upperBound))
+  }
 
   /**
-   * Between col1 and col2.
+   * True if the current column is between the lower bound and upper bound, inclusive.
    *
    * @group java_expr_ops
    */
-  def between(col1: Column, col2: Column): Column = And(GreaterThan(this.expr, col1.expr), LessThan(this.expr, col2.expr))
+  def between(lowerBound: Column, upperBound: Column): Column = {
+    And(GreaterThanOrEqual(this.expr, lowerBound.expr),
+      LessThanOrEqual(this.expr, upperBound.expr))
+  }
 
   /**
    * True if the current expression is null.
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
index 0a81f884e9a16..b63c1814adc3d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
@@ -211,7 +211,7 @@ class ColumnExpressionSuite extends QueryTest {
   test("between") {
     checkAnswer(
       testData4.filter($"a".between($"b", $"c")),
-      testData4.collect().toSeq.filter(r => r.getInt(0) > r.getInt(1) && r.getInt(0) < r.getInt(2)))
+      testData4.collect().toSeq.filter(r => r.getInt(0) >= r.getInt(1) && r.getInt(0) <= r.getInt(2)))
   }
 
   val booleanData = TestSQLContext.createDataFrame(TestSQLContext.sparkContext.parallelize(
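
Usage note (not part of the patch): a minimal sketch of how the new API reads
from PySpark once this is applied, assuming a local SparkContext/SQLContext;
the app name and the a/b/c column names are illustrative, mirroring the test
fixture above.

    from pyspark import SparkContext
    from pyspark.sql import SQLContext, Row

    sc = SparkContext("local", "between-demo")  # hypothetical app name
    sqlCtx = SQLContext(sc)  # constructing it also installs RDD.toDF()

    df = sc.parallelize([Row(a=1, b=2, c=3),
                         Row(a=2, b=1, c=3),
                         Row(a=4, b=1, c=4)]).toDF()

    # between is inclusive on both ends: a >= b AND a <= c
    df.filter(df.a.between(df.b, df.c)).collect()
    # [Row(a=2, b=1, c=3), Row(a=4, b=1, c=4)]

    # literal bounds work too, since >= and <= lift Python literals
    df.filter(df.a.between(1, 2)).collect()
    # [Row(a=1, b=2, c=3), Row(a=2, b=1, c=3)]

This also changes semantics versus the earlier draft: the old expression used
strict > and <, so values equal to either bound were excluded; the patched
version uses >= and <=, matching SQL's inclusive BETWEEN.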