Commit 3a1004f

Dsl -> functions, toDF()

Davies Liu authored and rxin committed Feb 13, 2015
1 parent fb256af commit 3a1004f
Showing 7 changed files with 234 additions and 155 deletions.
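In short, this commit replaces the Python `Dsl` helper class with a `pyspark.sql.functions` module and monkey-patches `RDD.toDF()` as a shorthand for `sqlCtx.createDataFrame(rdd, ...)`. A minimal sketch of the resulting API, assuming an active SparkContext `sc` (variable names here are illustrative, not part of the diff):

from pyspark.sql import SQLContext, Row
from pyspark.sql import functions as F

sqlCtx = SQLContext(sc)                       # constructor attaches toDF() to RDD via _monkey_patch_RDD
rdd = sc.parallelize([Row(name=u'Alice', age=1), Row(name=u'Bob', age=5)])
df = rdd.toDF()                               # shorthand for sqlCtx.createDataFrame(rdd)
df.agg(F.min(df.age)).collect()               # previously df.agg(Dsl.min(df.age))
df.withColumn('age2', df.age + 2).collect()   # previously df.addColumn('age2', df.age + 2)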
8 changes: 8 additions & 0 deletions python/docs/pyspark.sql.rst
@@ -16,3 +16,11 @@ pyspark.sql.types module
:members:
:undoc-members:
:show-inheritance:


pyspark.sql.functions module
----------------------------
.. automodule:: pyspark.sql.functions
:members:
:undoc-members:
:show-inheritance:
3 changes: 1 addition & 2 deletions python/pyspark/sql/__init__.py
@@ -34,9 +34,8 @@

from pyspark.sql.context import SQLContext, HiveContext
from pyspark.sql.types import Row
from pyspark.sql.dataframe import DataFrame, GroupedData, Column, Dsl, SchemaRDD
from pyspark.sql.dataframe import DataFrame, GroupedData, Column, SchemaRDD

__all__ = [
'SQLContext', 'HiveContext', 'DataFrame', 'GroupedData', 'Column', 'Row',
'Dsl',
]
23 changes: 22 additions & 1 deletion python/pyspark/sql/context.py
@@ -38,6 +38,25 @@
__all__ = ["SQLContext", "HiveContext"]


def _monkey_patch_RDD(sqlCtx):
def toDF(self, schema=None, sampleRatio=None):
"""
Convert the current :class:`RDD` into a :class:`DataFrame`
This is a shorthand for `sqlCtx.createDataFrame(rdd, schema, sampleRatio)`
:param schema: a StructType or list of names of columns
:param sampleRatio: the sample ratio of rows used for inferring the schema
:return: a DataFrame
>>> rdd.toDF().collect()
[Row(name=u'Alice', age=1)]
"""
return sqlCtx.createDataFrame(self, schema, sampleRatio)

RDD.toDF = toDF


class SQLContext(object):

"""Main entry point for Spark SQL functionality.
@@ -70,6 +89,7 @@ def __init__(self, sparkContext, sqlContext=None):
self._jsc = self._sc._jsc
self._jvm = self._sc._jvm
self._scala_SQLContext = sqlContext
_monkey_patch_RDD(self)

@property
def _ssql_ctx(self):
@@ -800,7 +820,8 @@ def _test():
Row(field1=2, field2="row2"),
Row(field1=3, field2="row3")]
)
globs['df'] = sqlCtx.createDataFrame(rdd)
_monkey_patch_RDD(sqlCtx)
globs['df'] = rdd.toDF()
jsonStrings = [
'{"field1": 1, "field2": "row1", "field3":{"field4":11}}',
'{"field1" : 2, "field3":{"field4":22, "field5": [10, 11]},'
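For context, a short usage sketch of the monkey-patched shorthand added above; it assumes an active SparkContext `sc` and an SQLContext built from it (whose constructor runs `_monkey_patch_RDD`), and the column names are illustrative:

rdd = sc.parallelize([(u'Alice', 1), (u'Bob', 5)])
df1 = rdd.toDF(['name', 'age'])   # schema given as a list of column names
df2 = rdd.toDF()                  # schema inferred by createDataFrame from the data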
159 changes: 21 additions & 138 deletions python/pyspark/sql/dataframe.py
@@ -21,21 +21,19 @@
import random
import os
from tempfile import NamedTemporaryFile
from itertools import imap

from py4j.java_collections import ListConverter, MapConverter

from pyspark.context import SparkContext
from pyspark.rdd import RDD, _prepare_for_python_RDD
from pyspark.serializers import BatchedSerializer, AutoBatchedSerializer, PickleSerializer, \
UTF8Deserializer
from pyspark.rdd import RDD
from pyspark.serializers import BatchedSerializer, PickleSerializer, UTF8Deserializer
from pyspark.storagelevel import StorageLevel
from pyspark.traceback_utils import SCCallSiteSync
from pyspark.sql.types import *
from pyspark.sql.types import _create_cls, _parse_datatype_json_string


__all__ = ["DataFrame", "GroupedData", "Column", "Dsl", "SchemaRDD"]
__all__ = ["DataFrame", "GroupedData", "Column", "SchemaRDD"]


class DataFrame(object):
@@ -310,8 +308,9 @@ def take(self, num):
return self.limit(num).collect()

def map(self, f):
""" Return a new RDD by applying a function to each Row, it's a
shorthand for df.rdd.map()
""" Return a new RDD by applying a function to each Row
It's a shorthand for df.rdd.map()
>>> df.map(lambda p: p.name).collect()
[u'Alice', u'Bob']
@@ -586,8 +585,8 @@ def agg(self, *exprs):
>>> df.agg({"age": "max"}).collect()
[Row(MAX(age#0)=5)]
>>> from pyspark.sql import Dsl
>>> df.agg(Dsl.min(df.age)).collect()
>>> from pyspark.sql import functions as F
>>> df.agg(F.min(df.age)).collect()
[Row(MIN(age#0)=2)]
"""
return self.groupBy().agg(*exprs)
@@ -616,18 +615,18 @@ def subtract(self, other):
"""
return DataFrame(getattr(self._jdf, "except")(other._jdf), self.sql_ctx)

def addColumn(self, colName, col):
def withColumn(self, colName, col):
""" Return a new :class:`DataFrame` by adding a column.
>>> df.addColumn('age2', df.age + 2).collect()
>>> df.withColumn('age2', df.age + 2).collect()
[Row(age=2, name=u'Alice', age2=4), Row(age=5, name=u'Bob', age2=7)]
"""
return self.select('*', col.alias(colName))

def renameColumn(self, existing, new):
def withColumnRenamed(self, existing, new):
""" Rename an existing column to a new name
>>> df.renameColumn('age', 'age2').collect()
>>> df.withColumnRenamed('age', 'age2').collect()
[Row(age2=2, name=u'Alice'), Row(age2=5, name=u'Bob')]
"""
cols = [Column(_to_java_column(c), self.sql_ctx).alias(new)
@@ -689,8 +688,9 @@ def agg(self, *exprs):
>>> gdf = df.groupBy(df.name)
>>> gdf.agg({"age": "max"}).collect()
[Row(name=u'Bob', MAX(age#0)=5), Row(name=u'Alice', MAX(age#0)=2)]
>>> from pyspark.sql import Dsl
>>> gdf.agg(Dsl.min(df.age)).collect()
>>> from pyspark.sql import functions as F
>>> gdf.agg(F.min(df.age)).collect()
[Row(MIN(age#0)=5), Row(MIN(age#0)=2)]
"""
assert exprs, "exprs should not be empty"
@@ -742,12 +742,12 @@ def sum(self):

def _create_column_from_literal(literal):
sc = SparkContext._active_spark_context
return sc._jvm.Dsl.lit(literal)
return sc._jvm.functions.lit(literal)


def _create_column_from_name(name):
sc = SparkContext._active_spark_context
return sc._jvm.Dsl.col(name)
return sc._jvm.functions.col(name)


def _to_java_column(col):
@@ -767,9 +767,9 @@ def _(self):
return _


def _dsl_op(name, doc=''):
def _func_op(name, doc=''):
def _(self):
jc = getattr(self._sc._jvm.Dsl, name)(self._jc)
jc = getattr(self._sc._jvm.functions, name)(self._jc)
return Column(jc, self.sql_ctx)
_.__doc__ = doc
return _
@@ -818,7 +818,7 @@ def __init__(self, jc, sql_ctx=None):
super(Column, self).__init__(jc, sql_ctx)

# arithmetic operators
__neg__ = _dsl_op("negate")
__neg__ = _func_op("negate")
__add__ = _bin_op("plus")
__sub__ = _bin_op("minus")
__mul__ = _bin_op("multiply")
@@ -842,7 +842,7 @@ def __init__(self, jc, sql_ctx=None):
# so use bitwise operators as boolean operators
__and__ = _bin_op('and')
__or__ = _bin_op('or')
__invert__ = _dsl_op('not')
__invert__ = _func_op('not')
__rand__ = _bin_op("and")
__ror__ = _bin_op("or")

@@ -934,123 +934,6 @@ def to_pandas(self):
return pd.Series(data)


def _aggregate_func(name, doc=""):
""" Create a function for aggregator by name"""
def _(col):
sc = SparkContext._active_spark_context
jc = getattr(sc._jvm.Dsl, name)(_to_java_column(col))
return Column(jc)
_.__name__ = name
_.__doc__ = doc
return staticmethod(_)


class UserDefinedFunction(object):
def __init__(self, func, returnType):
self.func = func
self.returnType = returnType
self._broadcast = None
self._judf = self._create_judf()

def _create_judf(self):
f = self.func # put it in closure `func`
func = lambda _, it: imap(lambda x: f(*x), it)
ser = AutoBatchedSerializer(PickleSerializer())
command = (func, None, ser, ser)
sc = SparkContext._active_spark_context
pickled_command, broadcast_vars, env, includes = _prepare_for_python_RDD(sc, command, self)
ssql_ctx = sc._jvm.SQLContext(sc._jsc.sc())
jdt = ssql_ctx.parseDataType(self.returnType.json())
judf = sc._jvm.UserDefinedPythonFunction(f.__name__, bytearray(pickled_command), env,
includes, sc.pythonExec, broadcast_vars,
sc._javaAccumulator, jdt)
return judf

def __del__(self):
if self._broadcast is not None:
self._broadcast.unpersist()
self._broadcast = None

def __call__(self, *cols):
sc = SparkContext._active_spark_context
jcols = ListConverter().convert([_to_java_column(c) for c in cols],
sc._gateway._gateway_client)
jc = self._judf.apply(sc._jvm.PythonUtils.toSeq(jcols))
return Column(jc)


class Dsl(object):
"""
A collection of builtin aggregators
"""
DSLS = {
'lit': 'Creates a :class:`Column` of literal value.',
'col': 'Returns a :class:`Column` based on the given column name.',
'column': 'Returns a :class:`Column` based on the given column name.',
'upper': 'Converts a string expression to upper case.',
'lower': 'Converts a string expression to lower case.',
'sqrt': 'Computes the square root of the specified float value.',
'abs': 'Computes the absolute value.',

'max': 'Aggregate function: returns the maximum value of the expression in a group.',
'min': 'Aggregate function: returns the minimum value of the expression in a group.',
'first': 'Aggregate function: returns the first value in a group.',
'last': 'Aggregate function: returns the last value in a group.',
'count': 'Aggregate function: returns the number of items in a group.',
'sum': 'Aggregate function: returns the sum of all values in the expression.',
'avg': 'Aggregate function: returns the average of the values in a group.',
'mean': 'Aggregate function: returns the average of the values in a group.',
'sumDistinct': 'Aggregate function: returns the sum of distinct values in the expression.',
}

for _name, _doc in DSLS.items():
locals()[_name] = _aggregate_func(_name, _doc)
del _name, _doc

@staticmethod
def countDistinct(col, *cols):
""" Return a new Column for distinct count of (col, *cols)
>>> from pyspark.sql import Dsl
>>> df.agg(Dsl.countDistinct(df.age, df.name).alias('c')).collect()
[Row(c=2)]
>>> df.agg(Dsl.countDistinct("age", "name").alias('c')).collect()
[Row(c=2)]
"""
sc = SparkContext._active_spark_context
jcols = ListConverter().convert([_to_java_column(c) for c in cols],
sc._gateway._gateway_client)
jc = sc._jvm.Dsl.countDistinct(_to_java_column(col),
sc._jvm.PythonUtils.toSeq(jcols))
return Column(jc)

@staticmethod
def approxCountDistinct(col, rsd=None):
""" Return a new Column for approxiate distinct count of (col, *cols)
>>> from pyspark.sql import Dsl
>>> df.agg(Dsl.approxCountDistinct(df.age).alias('c')).collect()
[Row(c=2)]
"""
sc = SparkContext._active_spark_context
if rsd is None:
jc = sc._jvm.Dsl.approxCountDistinct(_to_java_column(col))
else:
jc = sc._jvm.Dsl.approxCountDistinct(_to_java_column(col), rsd)
return Column(jc)

@staticmethod
def udf(f, returnType=StringType()):
"""Create a user defined function (UDF)
>>> slen = Dsl.udf(lambda s: len(s), IntegerType())
>>> df.select(slen(df.name).alias('slen')).collect()
[Row(slen=5), Row(slen=3)]
"""
return UserDefinedFunction(f, returnType)


def _test():
import doctest
from pyspark.context import SparkContext
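The diff for the new python/pyspark/sql/functions.py (and the remaining changed files) is not shown above. Based on the Dsl helpers removed from dataframe.py, the sketch below shows one plausible shape for such a module; the names `_create_function` and `_functions` are assumptions made for illustration, not taken from this commit:

# Hypothetical sketch of python/pyspark/sql/functions.py, adapted from the removed
# _aggregate_func / Dsl.DSLS pattern above. Names and the selection of entries are
# illustrative assumptions, not the actual file contents.
from pyspark.context import SparkContext
from pyspark.sql.dataframe import Column, _to_java_column


def _create_function(name, doc=""):
    """ Create a module-level wrapper that forwards to the JVM `functions` object. """
    def _(col):
        sc = SparkContext._active_spark_context
        jc = getattr(sc._jvm.functions, name)(_to_java_column(col))
        return Column(jc)
    _.__name__ = name
    _.__doc__ = doc
    return _


_functions = {
    'lit': 'Creates a :class:`Column` of literal value.',
    'col': 'Returns a :class:`Column` based on the given column name.',
    'min': 'Aggregate function: returns the minimum value of the expression in a group.',
    'max': 'Aggregate function: returns the maximum value of the expression in a group.',
    # ... remaining entries from the removed Dsl.DSLS dict
}

for _name, _doc in _functions.items():
    globals()[_name] = _create_function(_name, _doc)
del _name, _doc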
