Skip to content

Commit

Permalink
Modify a test for date type.
Browse files Browse the repository at this point in the history
  • Loading branch information
ueshin committed Feb 5, 2018
1 parent 223d0a0 commit 57ab41b
Showing 1 changed file with 34 additions and 9 deletions.
43 changes: 34 additions & 9 deletions python/pyspark/sql/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -4062,18 +4062,42 @@ def test_vectorized_udf_unsupported_types(self):
with self.assertRaisesRegexp(Exception, 'Unsupported data type'):
df.select(f(col('map'))).collect()

def test_vectorized_udf_null_date(self):
def test_vectorized_udf_dates(self):
from pyspark.sql.functions import pandas_udf, col
from datetime import date
schema = StructType().add("date", DateType())
data = [(date(1969, 1, 1),),
(date(2012, 2, 2),),
(None,),
(date(2100, 4, 4),)]
schema = StructType().add("idx", LongType()).add("date", DateType())
data = [(0, date(1969, 1, 1),),
(1, date(2012, 2, 2),),
(2, None,),
(3, date(2100, 4, 4),)]
df = self.spark.createDataFrame(data, schema=schema)
date_f = pandas_udf(lambda t: t, returnType=DateType())
res = df.select(date_f(col("date")))
self.assertEquals(df.collect(), res.collect())

date_copy = pandas_udf(lambda t: t, returnType=DateType())
df = df.withColumn("date_copy", date_copy(col("date")))

@pandas_udf(returnType=StringType())
def check_data(idx, date, date_copy):
    """Compare the dates seen inside the UDF with the source ``data`` rows.

    For every position in ``idx`` the value of ``date`` is compared with
    ``data[idx[i]][1]``. The result is a pandas Series holding ``None``
    where the values agree (a null date paired with a ``None`` expectation
    counts as agreement) and a message describing the mismatch otherwise.
    ``date_copy`` is accepted but not inspected here.
    """
    import pandas as pd

    # Null entries cannot be matched with ``==``, so track them separately.
    null_mask = date.isnull()

    def describe_mismatch(row):
        expected = data[idx[row]][1]
        both_null = null_mask[row] and expected is None
        if both_null or date[row] == expected:
            return None
        return ("date values are not equal (date='%s': data[%d][1]='%s')"
                % (date[row], idx[row], expected))

    return pd.Series([describe_mismatch(row) for row in range(len(idx))])

result = df.withColumn("check_data",
check_data(col("idx"), col("date"), col("date_copy"))).collect()

self.assertEquals(len(data), len(result))
for i in range(len(result)):
self.assertEquals(data[i][1], result[i][1]) # "date" col
self.assertEquals(data[i][1], result[i][2]) # "date_copy" col
self.assertIsNone(result[i][3]) # "check_data" col

def test_vectorized_udf_timestamps(self):
from pyspark.sql.functions import pandas_udf, col
Expand Down Expand Up @@ -4114,6 +4138,7 @@ def check_data(idx, timestamp, timestamp_copy):
self.assertEquals(len(data), len(result))
for i in range(len(result)):
self.assertEquals(data[i][1], result[i][1]) # "timestamp" col
self.assertEquals(data[i][1], result[i][2]) # "timestamp_copy" col
self.assertIsNone(result[i][3]) # "check_data" col

def test_vectorized_udf_return_timestamp_tz(self):
Expand Down

0 comments on commit 57ab41b

Please sign in to comment.