Skip to content

Commit

Permalink
Address comments
Browse files Browse the repository at this point in the history
  • Loading branch information
HyukjinKwon committed Feb 8, 2018
1 parent 68662ec commit 36617e4
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 23 deletions.
23 changes: 12 additions & 11 deletions python/pyspark/sql/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -3718,7 +3718,7 @@ def foo(x):
@pandas_udf(returnType='double', functionType=PandasUDFType.SCALAR)
def foo(x):
return x
self.assertEqual(foo.returnType, schema[0].dataType)
self.assertEqual(foo.returnType, DoubleType())
self.assertEqual(foo.evalType, PythonEvalType.SQL_SCALAR_PANDAS_UDF)

@pandas_udf(returnType=schema, functionType=PandasUDFType.GROUPED_MAP)
Expand Down Expand Up @@ -4032,7 +4032,7 @@ def test_vectorized_udf_wrong_return_type(self):
with QuietTest(self.sc):
with self.assertRaisesRegexp(
NotImplementedError,
'Invalid returnType.*a scalar Pandas UDF.*MapType'):
'Invalid returnType.*scalar Pandas UDF.*MapType'):
pandas_udf(lambda x: x * 1.0, MapType(LongType(), LongType()))

def test_vectorized_udf_return_scalar(self):
Expand Down Expand Up @@ -4072,13 +4072,13 @@ def test_vectorized_udf_unsupported_types(self):
with QuietTest(self.sc):
with self.assertRaisesRegexp(
NotImplementedError,
'Invalid returnType.*a scalar Pandas UDF.*MapType'):
'Invalid returnType.*scalar Pandas UDF.*MapType'):
pandas_udf(lambda x: x, MapType(StringType(), IntegerType()))

with QuietTest(self.sc):
with self.assertRaisesRegexp(
NotImplementedError,
'Invalid returnType.*a scalar Pandas UDF.*BinaryType'):
'Invalid returnType.*scalar Pandas UDF.*BinaryType'):
pandas_udf(lambda x: x, BinaryType())

def test_vectorized_udf_dates(self):
Expand Down Expand Up @@ -4296,7 +4296,7 @@ def data(self):
.withColumn("vs", array([lit(i) for i in range(20, 30)])) \
.withColumn("v", explode(col('vs'))).drop('vs')

def test_simple(self):
def test_supported_types(self):
from pyspark.sql.functions import pandas_udf, PandasUDFType, array, col
df = self.data.withColumn("arr", array(col("id")))

Expand Down Expand Up @@ -4412,7 +4412,7 @@ def test_wrong_return_type(self):
with QuietTest(self.sc):
with self.assertRaisesRegexp(
NotImplementedError,
'Invalid returnType.*a grouped map Pandas UDF.*MapType'):
'Invalid returnType.*grouped map Pandas UDF.*MapType'):
pandas_udf(
lambda pdf: pdf,
'id long, v map<int, int>',
Expand Down Expand Up @@ -4448,7 +4448,7 @@ def test_unsupported_types(self):
with QuietTest(self.sc):
with self.assertRaisesRegexp(
NotImplementedError,
'Invalid returnType.*a grouped map Pandas UDF.*MapType'):
'Invalid returnType.*grouped map Pandas UDF.*MapType'):
pandas_udf(lambda x: x, schema, PandasUDFType.GROUPED_MAP)

schema = StructType(
Expand All @@ -4457,7 +4457,7 @@ def test_unsupported_types(self):
with QuietTest(self.sc):
with self.assertRaisesRegexp(
NotImplementedError,
'Invalid returnType.*a grouped map Pandas UDF.*ArrayType.*TimestampType'):
'Invalid returnType.*grouped map Pandas UDF.*ArrayType.*TimestampType'):
pandas_udf(lambda x: x, schema, PandasUDFType.GROUPED_MAP)


Expand Down Expand Up @@ -4590,9 +4590,10 @@ def test_unsupported_types(self):

with QuietTest(self.sc):
with self.assertRaisesRegexp(NotImplementedError, 'not supported'):
@pandas_udf(ArrayType(ArrayType(TimestampType())), PandasUDFType.GROUPED_AGG)
def mean_and_std_udf(v):
return v
pandas_udf(
lambda x: x,
ArrayType(ArrayType(TimestampType())),
PandasUDFType.GROUPED_AGG)

with QuietTest(self.sc):
with self.assertRaisesRegexp(NotImplementedError, 'not supported'):
Expand Down
24 changes: 12 additions & 12 deletions python/pyspark/sql/udf.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,30 +112,30 @@ def returnType(self):
else:
self._returnType_placeholder = _parse_datatype_string(self._returnType)

if self.evalType == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF:
if self.evalType == PythonEvalType.SQL_SCALAR_PANDAS_UDF:
try:
to_arrow_type(self._returnType_placeholder)
except TypeError:
raise NotImplementedError(
"Invalid returnType with scalar Pandas UDFs: %s is "
"not supported" % str(self._returnType_placeholder))
elif self.evalType == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF:
if isinstance(self._returnType_placeholder, StructType):
try:
to_arrow_schema(self._returnType_placeholder)
except TypeError:
raise NotImplementedError(
"Invalid returnType with a grouped map Pandas UDF: "
"Invalid returnType with grouped map Pandas UDFs: "
"%s is not supported" % str(self._returnType_placeholder))
else:
raise TypeError("Invalid returnType for a grouped map Pandas "
"UDF: returnType must be a StructType.")
elif self.evalType == PythonEvalType.SQL_SCALAR_PANDAS_UDF:
try:
to_arrow_type(self._returnType_placeholder)
except TypeError:
raise NotImplementedError(
"Invalid returnType with a scalar Pandas UDF: %s is "
"not supported" % str(self._returnType_placeholder))
raise TypeError("Invalid returnType for grouped map Pandas "
"UDFs: returnType must be a StructType.")
elif self.evalType == PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF:
try:
to_arrow_type(self._returnType_placeholder)
except TypeError:
raise NotImplementedError(
"Invalid returnType with a grouped aggregate Pandas UDF: "
"Invalid returnType with grouped aggregate Pandas UDFs: "
"%s is not supported" % str(self._returnType_placeholder))

return self._returnType_placeholder
Expand Down

0 comments on commit 36617e4

Please sign in to comment.