Skip to content

Commit

Permalink
[SPARK-22966][PYTHON][SQL] Python UDFs with returnType=StringType sho…
Browse files Browse the repository at this point in the history
…uld treat return values of datetime.date or datetime.datetime as unconvertible

Add conversion to PySpark to mark Python UDFs that declared returnType=StringType() but actually returned a datatime.date or datetime.datetime as unconvertible, i.e. converting it to null.

Also added a new unit test to pyspark/sql/tests.py to reflect current semantics of Python UDFs returning a value of mismatched type with the declared returnType.
  • Loading branch information
rednaxelafx committed Jan 12, 2018
1 parent 186bf8f commit d307cee
Show file tree
Hide file tree
Showing 2 changed files with 75 additions and 0 deletions.
73 changes: 73 additions & 0 deletions python/pyspark/sql/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -485,6 +485,79 @@ def test_udf_with_array_type(self):
self.assertEqual(list(range(3)), l1)
self.assertEqual(1, l2)

def test_udf_returning_date_time(self):
from pyspark.sql.functions import udf
from pyspark.sql.types import DateType

data = self.spark.createDataFrame([(2017, 10, 30)], ['year', 'month', 'day'])

expected_date = datetime.date(2017, 10, 30)
expected_datetime = datetime.datetime(2017, 10, 30)

# test Python UDF with default returnType=StringType()
# Returning a date or datetime object at runtime with such returnType declaration
# is a mismatch, which results in a null, as PySpark treats it as unconvertible.
py_date_str, py_datetime_str = udf(datetime.date), udf(datetime.datetime)
query = data.select(
py_date_str(data.year, data.month, data.day).isNull(),
py_datetime_str(data.year, data.month, data.day).isNull())
[row] = query.collect()
self.assertEqual(row[0], True)
self.assertEqual(row[1], True)

query = data.select(
py_date_str(data.year, data.month, data.day),
py_datetime_str(data.year, data.month, data.day))
[row] = query.collect()
self.assertEqual(row[0], None)
self.assertEqual(row[1], None)

# test Python UDF with specific returnType matching actual result
py_date, py_datetime = udf(datetime.date, DateType()), udf(datetime.datetime, 'timestamp')
query = data.select(
py_date(data.year, data.month, data.day) == lit(expected_date),
py_datetime(data.year, data.month, data.day) == lit(expected_datetime))
[row] = query.collect()
self.assertEqual(row[0], True)
self.assertEqual(row[1], True)

query = data.select(
py_date(data.year, data.month, data.day),
py_datetime(data.year, data.month, data.day))
[row] = query.collect()
self.assertEqual(row[0], expected_date)
self.assertEqual(row[1], expected_datetime)

# test semantic matching of datetime with timezone
# class in __main__ is not serializable
from pyspark.sql.tests import UTCOffsetTimezone
datetime_with_utc0 = datetime.datetime(2017, 10, 30, tzinfo=UTCOffsetTimezone(0))
datetime_with_utc1 = datetime.datetime(2017, 10, 30, tzinfo=UTCOffsetTimezone(1))
test_udf = udf(lambda: datetime_with_utc0, 'timestamp')
query = data.select(
test_udf() == lit(datetime_with_utc0),
test_udf() > lit(datetime_with_utc1),
test_udf()
)
[row] = query.collect()
self.assertEqual(row[0], True)
self.assertEqual(row[1], True)
# Note: datetime returned from PySpark is always naive (timezone unaware).
# It currently respects Python's current local timezone.
self.assertEqual(row[2].tzinfo, None)

# tzinfo=None is really the same as not specifying it: a naive datetime object
# Just adding a test case for it here for completeness
datetime_with_null_timezone = datetime.datetime(2017, 10, 30, tzinfo=None)
test_udf = udf(lambda: datetime_with_null_timezone, 'timestamp')
query = data.select(
test_udf() == lit(datetime_with_null_timezone),
test_udf()
)
[row] = query.collect()
self.assertEqual(row[0], True)
self.assertEqual(row[1], datetime_with_null_timezone)

def test_broadcast_in_udf(self):
bar = {"a": "aa", "b": "bb", "c": "abc"}
foo = self.sc.broadcast(bar)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.python

import java.io.OutputStream
import java.nio.charset.StandardCharsets
import java.util.Calendar

import scala.collection.JavaConverters._

Expand Down Expand Up @@ -144,6 +145,7 @@ object EvaluatePython {
}

case StringType => (obj: Any) => nullSafeConvert(obj) {
case _: Calendar => null
case _ => UTF8String.fromString(obj.toString)
}

Expand Down

0 comments on commit d307cee

Please sign in to comment.