Skip to content

Commit

Permalink
[SPARK-5588] [SQL] support select/filter by SQL expression
Browse files Browse the repository at this point in the history
```
df.selectExpr('a + 1', 'abs(age)')
df.filter('age > 3')
df[ df.age > 3 ]
df[ ['age', 'name'] ]
```

Author: Davies Liu <[email protected]>

Closes #4359 from davies/select_expr and squashes the following commits:

d99856b [Davies Liu] support select/filter by SQL expression
  • Loading branch information
Davies Liu authored and rxin committed Feb 4, 2015
1 parent 38a416f commit ac0b2b7
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 22 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,10 @@

package org.apache.spark.api.python

import java.io.{File, InputStream, IOException, OutputStream}
import java.io.{File}
import java.util.{List => JList}

import scala.collection.JavaConversions._
import scala.collection.mutable.ArrayBuffer

import org.apache.spark.SparkContext
Expand All @@ -44,4 +46,11 @@ private[spark] object PythonUtils {
def generateRDDWithNull(sc: JavaSparkContext): JavaRDD[String] = {
  // Three-element fixture whose middle entry is null.
  val values = List("a", null, "b")
  sc.parallelize(values)
}

/**
 * Turn a Java list of T into a Scala Seq of T, preserving element order.
 *
 * Exists so that callers holding a `java.util.List` can invoke APIs that take varargs.
 */
def toSeq[T](cols: JList[T]): Seq[T] = {
  // The JavaConversions implicits view the Java list as a Scala collection.
  val asScalaList = cols.toList
  asScalaList.toSeq
}
}
53 changes: 43 additions & 10 deletions python/pyspark/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -2128,7 +2128,7 @@ def sort(self, *cols):
raise ValueError("should sort by at least one column")
jcols = ListConverter().convert([_to_java_column(c) for c in cols],
self._sc._gateway._gateway_client)
jdf = self._jdf.sort(self._sc._jvm.Dsl.toColumns(jcols))
jdf = self._jdf.sort(self._sc._jvm.PythonUtils.toSeq(jcols))
return DataFrame(jdf, self.sql_ctx)

sortBy = sort
Expand Down Expand Up @@ -2159,13 +2159,20 @@ def __getitem__(self, item):
>>> df['age'].collect()
[Row(age=2), Row(age=5)]
>>> df[ ["name", "age"]].collect()
[Row(name=u'Alice', age=2), Row(name=u'Bob', age=5)]
>>> df[ df.age > 3 ].collect()
[Row(age=5, name=u'Bob')]
"""
if isinstance(item, basestring):
jc = self._jdf.apply(item)
return Column(jc, self.sql_ctx)

# TODO projection
raise IndexError
elif isinstance(item, Column):
return self.filter(item)
elif isinstance(item, list):
return self.select(*item)
else:
raise IndexError("unexpected index: %s" % item)

def __getattr__(self, name):
""" Return the column by given name
Expand Down Expand Up @@ -2194,18 +2201,44 @@ def select(self, *cols):
cols = ["*"]
jcols = ListConverter().convert([_to_java_column(c) for c in cols],
self._sc._gateway._gateway_client)
jdf = self._jdf.select(self.sql_ctx._sc._jvm.Dsl.toColumns(jcols))
jdf = self._jdf.select(self.sql_ctx._sc._jvm.PythonUtils.toSeq(jcols))
return DataFrame(jdf, self.sql_ctx)

def selectExpr(self, *expr):
    """
    Project a set of SQL expressions given as strings.

    This is a variant of `select` that accepts SQL expressions.

    >>> df.selectExpr("age * 2", "abs(age)").collect()
    [Row(('age * 2)=4, Abs('age)=2), Row(('age * 2)=10, Abs('age)=5)]
    """
    gateway_client = self._sc._gateway._gateway_client
    # Hand the expression strings to the JVM as a Java list, then turn that
    # list into a Scala Seq so it can be passed to the varargs selectExpr.
    java_exprs = ListConverter().convert(expr, gateway_client)
    scala_exprs = self._sc._jvm.PythonUtils.toSeq(java_exprs)
    return DataFrame(self._jdf.selectExpr(scala_exprs), self.sql_ctx)

def filter(self, condition):
""" Filtering rows using the given condition.
""" Filtering rows using the given condition, which could be
a Column expression or a string containing a SQL expression.
where() is an alias for filter().
>>> df.filter(df.age > 3).collect()
[Row(age=5, name=u'Bob')]
>>> df.where(df.age == 2).collect()
[Row(age=2, name=u'Alice')]
>>> df.filter("age > 3").collect()
[Row(age=5, name=u'Bob')]
>>> df.where("age = 2").collect()
[Row(age=2, name=u'Alice')]
"""
return DataFrame(self._jdf.filter(condition._jc), self.sql_ctx)
if isinstance(condition, basestring):
jdf = self._jdf.filter(condition)
elif isinstance(condition, Column):
jdf = self._jdf.filter(condition._jc)
else:
raise TypeError("condition should be string or Column")
return DataFrame(jdf, self.sql_ctx)

where = filter

Expand All @@ -2223,7 +2256,7 @@ def groupBy(self, *cols):
"""
jcols = ListConverter().convert([_to_java_column(c) for c in cols],
self._sc._gateway._gateway_client)
jdf = self._jdf.groupBy(self.sql_ctx._sc._jvm.Dsl.toColumns(jcols))
jdf = self._jdf.groupBy(self.sql_ctx._sc._jvm.PythonUtils.toSeq(jcols))
return GroupedDataFrame(jdf, self.sql_ctx)

def agg(self, *exprs):
Expand Down Expand Up @@ -2338,7 +2371,7 @@ def agg(self, *exprs):
assert all(isinstance(c, Column) for c in exprs), "all exprs should be Column"
jcols = ListConverter().convert([c._jc for c in exprs[1:]],
self.sql_ctx._sc._gateway._gateway_client)
jdf = self._jdf.agg(exprs[0]._jc, self.sql_ctx._sc._jvm.Dsl.toColumns(jcols))
jdf = self._jdf.agg(exprs[0]._jc, self.sql_ctx._sc._jvm.PythonUtils.toSeq(jcols))
return DataFrame(jdf, self.sql_ctx)

@dfapi
Expand Down Expand Up @@ -2633,7 +2666,7 @@ def countDistinct(col, *cols):
jcols = ListConverter().convert([_to_java_column(c) for c in cols],
sc._gateway._gateway_client)
jc = sc._jvm.Dsl.countDistinct(_to_java_column(col),
sc._jvm.Dsl.toColumns(jcols))
sc._jvm.PythonUtils.toSeq(jcols))
return Column(jc)

@staticmethod
Expand Down
11 changes: 0 additions & 11 deletions sql/core/src/main/scala/org/apache/spark/sql/Dsl.scala
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,8 @@

package org.apache.spark.sql

import java.util.{List => JList}

import scala.language.implicitConversions
import scala.reflect.runtime.universe.{TypeTag, typeTag}
import scala.collection.JavaConversions._

import org.apache.spark.sql.catalyst.ScalaReflection
import org.apache.spark.sql.catalyst.expressions._
Expand Down Expand Up @@ -169,14 +166,6 @@ object Dsl {
/**
 * Computes the absolute value.
 *
 * Passes the column's underlying expression (`e.expr`) to `Abs`.
 * NOTE(review): the `Abs(...)` result is returned as a `Column`, presumably through an
 * implicit conversion declared elsewhere in this object -- confirm.
 */
def abs(e: Column): Column = Abs(e.expr)

/**
 * Private helper for the Python API: turns a Java list of columns into a Scala Seq.
 * TODO: move this to a private package
 */
def toColumns(cols: JList[Column]): Seq[Column] = {
  // The JavaConversions implicits view the Java list as a Scala collection.
  val scalaCols = cols.toList
  scalaCols.toSeq
}

//////////////////////////////////////////////////////////////////////////////////////////////
//////////////////////////////////////////////////////////////////////////////////////////////

Expand Down

0 comments on commit ac0b2b7

Please sign in to comment.