
Commit

Modify check_data udf for debug messages.
ueshin committed Nov 2, 2017
1 parent 6872516 commit 1f096bf
Showing 1 changed file with 15 additions and 8 deletions.
python/pyspark/sql/tests.py (23 changes: 15 additions & 8 deletions)
@@ -3484,22 +3484,29 @@ def test_vectorized_udf_timestamps(self):
         f_timestamp_copy = pandas_udf(lambda t: t, returnType=TimestampType())
         df = df.withColumn("timestamp_copy", f_timestamp_copy(col("timestamp")))

-        @pandas_udf(returnType=BooleanType())
+        @pandas_udf(returnType=StringType())
         def check_data(idx, timestamp, timestamp_copy):
+            import pandas as pd
+            msgs = []
             is_equal = timestamp.isnull()  # use this array to check values are equal
             for i in range(len(idx)):
                 # Check that timestamps are as expected in the UDF
-                is_equal[i] = (is_equal[i] and data[idx[i]][1] is None) or \
-                    timestamp[i].to_pydatetime() == data[idx[i]][1]
-            return is_equal
-
-        result = df.withColumn("is_equal", check_data(col("idx"), col("timestamp"),
-                                                      col("timestamp_copy"))).collect()
+                if (is_equal[i] and data[idx[i]][1] is None) or \
+                        timestamp[i].to_pydatetime() == data[idx[i]][1]:
+                    msgs.append(None)
+                else:
+                    msgs.append(
+                        "timestamp values are not equal (timestamp='%s': data[%d][1]='%s')"
+                        % (timestamp[i], idx[i], data[idx[i]][1]))
+            return pd.Series(msgs)
+
+        result = df.withColumn("check_data", check_data(col("idx"), col("timestamp"),
+                                                        col("timestamp_copy"))).collect()
         # Check that collection values are correct
         self.assertEquals(len(data), len(result))
         for i in range(len(result)):
             self.assertEquals(data[i][1], result[i][1])  # "timestamp" col
-            self.assertTrue(result[i][3])  # "is_equal" data in udf was as expected
+            self.assertIsNone(result[i][3])  # "check_data" col

     def test_vectorized_udf_return_timestamp_tz(self):
         from pyspark.sql.functions import pandas_udf, col
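Note on the pattern in this diff: the commit replaces a BooleanType pandas_udf that only reported pass/fail with a StringType one that returns None for rows that match and a descriptive message otherwise, so a failing assertion surfaces the offending values. Below is a minimal, self-contained sketch of the same idea against toy integer data; the session setup, column names, and sample data are illustrative assumptions, not part of the commit.

# Minimal sketch of the debug-message UDF pattern (illustrative names and data,
# not taken from the commit). Requires pyarrow, as any pandas_udf does.
import pandas as pd

from pyspark.sql import SparkSession
from pyspark.sql.functions import col, pandas_udf
from pyspark.sql.types import StringType

spark = SparkSession.builder.master("local[1]").appName("check-data-sketch").getOrCreate()

data = [(0, 10), (1, 20), (2, 30)]  # (idx, value) pairs the UDF re-checks
df = spark.createDataFrame(data, ["idx", "value"])


@pandas_udf(returnType=StringType())
def check_data(idx, value):
    # Return None when a row matches the expected value, otherwise an
    # explanatory string, so the test failure shows the mismatching values.
    msgs = []
    for i in range(len(idx)):
        if value[i] == data[idx[i]][1]:
            msgs.append(None)  # row is as expected
        else:
            msgs.append("value mismatch (value='%s': data[%d][1]='%s')"
                        % (value[i], idx[i], data[idx[i]][1]))
    return pd.Series(msgs)


result = df.withColumn("check_data", check_data(col("idx"), col("value"))).collect()
for row in result:
    assert row["check_data"] is None, row["check_data"]

spark.stop()

Returning None on the happy path keeps a simple assertIsNone-style check usable on the collected column, while a failure immediately prints which row and which values diverged instead of a bare False.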
