From 9200f38b6414255a5c60127aeeae517086ba108b Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Tue, 28 Nov 2017 12:11:27 +0900 Subject: [PATCH] Address comments. --- python/pyspark/sql/session.py | 4 ++++ python/pyspark/sql/tests.py | 2 ++ python/pyspark/sql/types.py | 16 ++++++++-------- 3 files changed, 14 insertions(+), 8 deletions(-) diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py index e1093c8e12511..e2435e09af23d 100644 --- a/python/pyspark/sql/session.py +++ b/python/pyspark/sql/session.py @@ -459,6 +459,8 @@ def _convert_from_pandas(self, pdf, schema, timezone): if isinstance(field.dataType, TimestampType): s = _check_series_convert_timestamps_tz_local(pdf[field.name], timezone) if not copied and s is not pdf[field.name]: + # Copy once if the series is modified to prevent the original Pandas + # DataFrame from being updated pdf = pdf.copy() copied = True pdf[field.name] = s @@ -466,6 +468,8 @@ def _convert_from_pandas(self, pdf, schema, timezone): for column, series in pdf.iteritems(): s = _check_series_convert_timestamps_tz_local(pdf[column], timezone) if not copied and s is not pdf[column]: + # Copy once if the series is modified to prevent the original Pandas + # DataFrame from being updated pdf = pdf.copy() copied = True pdf[column] = s diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index fc6e575c08035..b4d32d8de8a22 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -3291,6 +3291,7 @@ def test_createDataFrame_respect_session_timezone(self): self.assertNotEqual(result_ny, result_la) + # Correct result_la by adjusting 3 hours difference between Los Angeles and New York result_la_corrected = [Row(**{k: v - timedelta(hours=3) if k == '7_timestamp_t' else v for k, v in row.asDict().items()}) for row in result_la] @@ -3834,6 +3835,7 @@ def test_vectorized_udf_timestamps_respect_session_timezone(self): df_la = df.withColumn("tscopy", f_timestamp_copy(col("timestamp"))) \ .withColumn("internal_value", internal_value(col("timestamp"))) result_la = df_la.select(col("idx"), col("internal_value")).collect() + # Correct result_la by adjusting 3 hours difference between Los Angeles and New York diff = 3 * 60 * 60 * 1000 * 1000 * 1000 result_la_corrected = \ df_la.select(col("idx"), col("tscopy"), col("internal_value") + diff).collect() diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index 40f82832950ef..78abc32a35a1c 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -1729,13 +1729,13 @@ def _check_series_convert_timestamps_internal(s, timezone): return s -def _check_series_convert_timestamps_localize(s, fromTimezone, toTimezone): +def _check_series_convert_timestamps_localize(s, from_timezone, to_timezone): """ Convert timestamp to timezone-naive in the specified timezone or local timezone :param s: a pandas.Series - :param fromTimezone: the timezone to convert from. if None then use local timezone - :param toTimezone: the timezone to convert to. if None then use local timezone + :param from_timezone: the timezone to convert from. if None then use local timezone + :param to_timezone: the timezone to convert to. if None then use local timezone :return pandas.Series where if it is a timestamp, has been converted to tz-naive """ try: @@ -1743,14 +1743,14 @@ def _check_series_convert_timestamps_localize(s, fromTimezone, toTimezone): from pandas.api.types import is_datetime64tz_dtype, is_datetime64_dtype except ImportError as e: raise ImportError(_old_pandas_exception_message(e)) - fromTz = fromTimezone or 'tzlocal()' - toTz = toTimezone or 'tzlocal()' + from_tz = from_timezone or 'tzlocal()' + to_tz = to_timezone or 'tzlocal()' # TODO: handle nested timestamps, such as ArrayType(TimestampType())? if is_datetime64tz_dtype(s.dtype): - return s.dt.tz_convert(toTz).dt.tz_localize(None) - elif is_datetime64_dtype(s.dtype) and fromTz != toTz: + return s.dt.tz_convert(to_tz).dt.tz_localize(None) + elif is_datetime64_dtype(s.dtype) and from_tz != to_tz: # `s.dt.tz_localize('tzlocal()')` doesn't work properly when including NaT. - return s.apply(lambda ts: ts.tz_localize(fromTz).tz_convert(toTz).tz_localize(None) + return s.apply(lambda ts: ts.tz_localize(from_tz).tz_convert(to_tz).tz_localize(None) if ts is not pd.NaT else pd.NaT) else: return s