[SPARK-23387][SQL][PYTHON][TEST][BRANCH-2.3] Backport assertPandasEqual to branch-2.3.

## What changes were proposed in this pull request?

When backporting a PR whose tests use `assertPandasEqual` from master to branch-2.3, the tests fail because `assertPandasEqual` doesn't exist in branch-2.3.
We should backport `assertPandasEqual` to branch-2.3 to avoid these failures.
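
For reference, a minimal self-contained sketch of the helper being backported: the `assertPandasEqual` body matches the diff below, while the `unittest` scaffolding and the sample test here are illustrative only (in Spark's `tests.py` the helper is defined on the shared `ReusedSQLTestCase` base).

```python
import unittest

import pandas as pd


class AssertPandasEqualExample(unittest.TestCase):
    # Illustrative stand-in for Spark's ReusedSQLTestCase.

    def assertPandasEqual(self, expected, result):
        # On mismatch, show both frames together with their dtypes, since
        # dtype drift after a conversion is a common source of inequality.
        msg = ("DataFrames are not equal: " +
               "\n\nExpected:\n%s\n%s" % (expected, expected.dtypes) +
               "\n\nResult:\n%s\n%s" % (result, result.dtypes))
        self.assertTrue(expected.equals(result), msg=msg)

    def test_identical_frames_pass(self):
        pdf = pd.DataFrame({"id": [1, 2], "v": [1.0, 2.0]})
        self.assertPandasEqual(pdf, pdf.copy())


if __name__ == "__main__":
    unittest.main()
```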

## How was this patch tested?

Modified tests.

Author: Takuya UESHIN <[email protected]>

Closes #20577 from ueshin/issues/SPARK-23387/branch-2.3.
ueshin authored and HyukjinKwon committed Feb 11, 2018
1 parent 9fa7b0e commit 8875e47
Showing 1 changed file with 19 additions and 25 deletions.

python/pyspark/sql/tests.py
@@ -195,6 +195,12 @@ def tearDownClass(cls):
         ReusedPySparkTestCase.tearDownClass()
         cls.spark.stop()
 
+    def assertPandasEqual(self, expected, result):
+        msg = ("DataFrames are not equal: " +
+               "\n\nExpected:\n%s\n%s" % (expected, expected.dtypes) +
+               "\n\nResult:\n%s\n%s" % (result, result.dtypes))
+        self.assertTrue(expected.equals(result), msg=msg)
+
 
 class DataTypeTests(unittest.TestCase):
     # regression test for SPARK-6055
@@ -3422,12 +3428,6 @@ def tearDownClass(cls):
         time.tzset()
         ReusedSQLTestCase.tearDownClass()
 
-    def assertFramesEqual(self, df_with_arrow, df_without):
-        msg = ("DataFrame from Arrow is not equal" +
-               ("\n\nWith Arrow:\n%s\n%s" % (df_with_arrow, df_with_arrow.dtypes)) +
-               ("\n\nWithout:\n%s\n%s" % (df_without, df_without.dtypes)))
-        self.assertTrue(df_without.equals(df_with_arrow), msg=msg)
-
     def create_pandas_data_frame(self):
         import pandas as pd
         import numpy as np
@@ -3466,8 +3466,8 @@ def test_toPandas_arrow_toggle(self):
         df = self.spark.createDataFrame(self.data, schema=self.schema)
         pdf, pdf_arrow = self._toPandas_arrow_toggle(df)
         expected = self.create_pandas_data_frame()
-        self.assertFramesEqual(expected, pdf)
-        self.assertFramesEqual(expected, pdf_arrow)
+        self.assertPandasEqual(expected, pdf)
+        self.assertPandasEqual(expected, pdf_arrow)
 
     def test_toPandas_respect_session_timezone(self):
         df = self.spark.createDataFrame(self.data, schema=self.schema)
@@ -3478,11 +3478,11 @@ def test_toPandas_respect_session_timezone(self):
             self.spark.conf.set("spark.sql.execution.pandas.respectSessionTimeZone", "false")
             try:
                 pdf_la, pdf_arrow_la = self._toPandas_arrow_toggle(df)
-                self.assertFramesEqual(pdf_arrow_la, pdf_la)
+                self.assertPandasEqual(pdf_arrow_la, pdf_la)
             finally:
                 self.spark.conf.set("spark.sql.execution.pandas.respectSessionTimeZone", "true")
             pdf_ny, pdf_arrow_ny = self._toPandas_arrow_toggle(df)
-            self.assertFramesEqual(pdf_arrow_ny, pdf_ny)
+            self.assertPandasEqual(pdf_arrow_ny, pdf_ny)
 
             self.assertFalse(pdf_ny.equals(pdf_la))
 
@@ -3492,15 +3492,15 @@ def test_toPandas_respect_session_timezone(self):
                 if isinstance(field.dataType, TimestampType):
                     pdf_la_corrected[field.name] = _check_series_convert_timestamps_local_tz(
                         pdf_la_corrected[field.name], timezone)
-            self.assertFramesEqual(pdf_ny, pdf_la_corrected)
+            self.assertPandasEqual(pdf_ny, pdf_la_corrected)
         finally:
             self.spark.conf.set("spark.sql.session.timeZone", orig_tz)
 
     def test_pandas_round_trip(self):
         pdf = self.create_pandas_data_frame()
         df = self.spark.createDataFrame(self.data, schema=self.schema)
         pdf_arrow = df.toPandas()
-        self.assertFramesEqual(pdf_arrow, pdf)
+        self.assertPandasEqual(pdf_arrow, pdf)
 
     def test_filtered_frame(self):
         df = self.spark.range(3).toDF("i")
@@ -3558,7 +3558,7 @@ def test_createDataFrame_with_schema(self):
         df = self.spark.createDataFrame(pdf, schema=self.schema)
         self.assertEquals(self.schema, df.schema)
         pdf_arrow = df.toPandas()
-        self.assertFramesEqual(pdf_arrow, pdf)
+        self.assertPandasEqual(pdf_arrow, pdf)
 
     def test_createDataFrame_with_incorrect_schema(self):
         pdf = self.create_pandas_data_frame()
@@ -4318,12 +4318,6 @@ def test_timestamp_dst(self):
     _pandas_requirement_message or _pyarrow_requirement_message)
 class GroupedMapPandasUDFTests(ReusedSQLTestCase):
 
-    def assertFramesEqual(self, expected, result):
-        msg = ("DataFrames are not equal: " +
-               ("\n\nExpected:\n%s\n%s" % (expected, expected.dtypes)) +
-               ("\n\nResult:\n%s\n%s" % (result, result.dtypes)))
-        self.assertTrue(expected.equals(result), msg=msg)
-
     @property
     def data(self):
         from pyspark.sql.functions import array, explode, col, lit
@@ -4347,7 +4341,7 @@ def test_simple(self):
 
         result = df.groupby('id').apply(foo_udf).sort('id').toPandas()
         expected = df.toPandas().groupby('id').apply(foo_udf.func).reset_index(drop=True)
-        self.assertFramesEqual(expected, result)
+        self.assertPandasEqual(expected, result)
 
     def test_register_grouped_map_udf(self):
         from pyspark.sql.functions import pandas_udf, PandasUDFType
@@ -4371,7 +4365,7 @@ def foo(pdf):
 
         result = df.groupby('id').apply(foo).sort('id').toPandas()
         expected = df.toPandas().groupby('id').apply(foo.func).reset_index(drop=True)
-        self.assertFramesEqual(expected, result)
+        self.assertPandasEqual(expected, result)
 
     def test_coerce(self):
         from pyspark.sql.functions import pandas_udf, PandasUDFType
@@ -4386,7 +4380,7 @@ def test_coerce(self):
         result = df.groupby('id').apply(foo).sort('id').toPandas()
         expected = df.toPandas().groupby('id').apply(foo.func).reset_index(drop=True)
         expected = expected.assign(v=expected.v.astype('float64'))
-        self.assertFramesEqual(expected, result)
+        self.assertPandasEqual(expected, result)
 
     def test_complex_groupby(self):
         from pyspark.sql.functions import pandas_udf, col, PandasUDFType
@@ -4405,7 +4399,7 @@ def normalize(pdf):
         expected = pdf.groupby(pdf['id'] % 2 == 0).apply(normalize.func)
         expected = expected.sort_values(['id', 'v']).reset_index(drop=True)
         expected = expected.assign(norm=expected.norm.astype('float64'))
-        self.assertFramesEqual(expected, result)
+        self.assertPandasEqual(expected, result)
 
     def test_empty_groupby(self):
         from pyspark.sql.functions import pandas_udf, col, PandasUDFType
@@ -4424,7 +4418,7 @@ def normalize(pdf):
         expected = normalize.func(pdf)
         expected = expected.sort_values(['id', 'v']).reset_index(drop=True)
         expected = expected.assign(norm=expected.norm.astype('float64'))
-        self.assertFramesEqual(expected, result)
+        self.assertPandasEqual(expected, result)
 
     def test_datatype_string(self):
         from pyspark.sql.functions import pandas_udf, PandasUDFType
@@ -4438,7 +4432,7 @@ def test_datatype_string(self):
 
         result = df.groupby('id').apply(foo_udf).sort('id').toPandas()
         expected = df.toPandas().groupby('id').apply(foo_udf.func).reset_index(drop=True)
-        self.assertFramesEqual(expected, result)
+        self.assertPandasEqual(expected, result)
 
     def test_wrong_return_type(self):
         from pyspark.sql.functions import pandas_udf, PandasUDFType
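
For context, the `GroupedMapPandasUDFTests` cases in the diff all exercise the same pattern: apply a `PandasUDFType.GROUPED_MAP` UDF per group, then compare the result against the plain Python function applied to the `toPandas()` output. Below is a minimal sketch of that pattern, assuming a local Spark 2.3+ session with pandas and pyarrow installed; `subtract_mean` is a hypothetical example UDF, not one from the tests.

```python
from pyspark.sql import SparkSession
from pyspark.sql.functions import pandas_udf, PandasUDFType

spark = SparkSession.builder.master("local[2]").appName("grouped-map-sketch").getOrCreate()
df = spark.createDataFrame([(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0)], ("id", "v"))

# A grouped-map UDF receives one pandas DataFrame per group and returns
# a pandas DataFrame that must match the declared schema.
@pandas_udf("id long, v double", PandasUDFType.GROUPED_MAP)
def subtract_mean(pdf):
    return pdf.assign(v=pdf.v - pdf.v.mean())

result = df.groupby("id").apply(subtract_mean).sort("id", "v").toPandas()

# Mirror how the tests build `expected`: run the wrapped function (`.func`)
# in plain pandas, then normalize ordering and the index before comparing.
expected = (df.toPandas().groupby("id").apply(subtract_mean.func)
            .sort_values(["id", "v"]).reset_index(drop=True))
print(result.equals(expected))  # True when the Arrow and pandas paths agree

spark.stop()
```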
