Commit 9200f38
Address comments.
ueshin committed Nov 28, 2017
1 parent 40a9735 commit 9200f38
Showing 3 changed files with 14 additions and 8 deletions.
4 changes: 4 additions & 0 deletions python/pyspark/sql/session.py
@@ -459,13 +459,17 @@ def _convert_from_pandas(self, pdf, schema, timezone):
                     if isinstance(field.dataType, TimestampType):
                         s = _check_series_convert_timestamps_tz_local(pdf[field.name], timezone)
                         if not copied and s is not pdf[field.name]:
+                            # Copy once if the series is modified to prevent the original Pandas
+                            # DataFrame from being updated
                             pdf = pdf.copy()
                             copied = True
                         pdf[field.name] = s
             else:
                 for column, series in pdf.iteritems():
                     s = _check_series_convert_timestamps_tz_local(pdf[column], timezone)
                     if not copied and s is not pdf[column]:
+                        # Copy once if the series is modified to prevent the original Pandas
+                        # DataFrame from being updated
                         pdf = pdf.copy()
                         copied = True
                     pdf[column] = s
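Note (not part of the diff): a minimal standalone sketch of the copy-on-first-modification pattern the new comments describe, assuming pandas is installed. to_utc_naive and convert_columns are hypothetical stand-ins for _check_series_convert_timestamps_tz_local and _convert_from_pandas.

import pandas as pd

def to_utc_naive(series):
    # Hypothetical stand-in: returns a new Series only when a conversion
    # is needed, otherwise returns its input unchanged.
    if pd.api.types.is_datetime64tz_dtype(series.dtype):
        return series.dt.tz_convert('UTC').dt.tz_localize(None)
    return series

def convert_columns(pdf):
    copied = False
    for column in pdf.columns:
        s = to_utc_naive(pdf[column])
        if not copied and s is not pdf[column]:
            # Copy once if the series is modified to prevent the original
            # pandas DataFrame from being updated.
            pdf = pdf.copy()
            copied = True
        pdf[column] = s
    return pdf

The identity check `s is not pdf[column]` relies on the helper returning its input object unchanged when no conversion is needed, so a DataFrame with no timestamp columns is never copied.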
2 changes: 2 additions & 0 deletions python/pyspark/sql/tests.py
@@ -3291,6 +3291,7 @@ def test_createDataFrame_respect_session_timezone(self):
 
         self.assertNotEqual(result_ny, result_la)
 
+        # Correct result_la by adjusting 3 hours difference between Los Angeles and New York
         result_la_corrected = [Row(**{k: v - timedelta(hours=3) if k == '7_timestamp_t' else v
                                       for k, v in row.asDict().items()})
                                for row in result_la]
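Note (not part of the diff): the comprehension above rebuilds each Row with only the timestamp field shifted. A minimal sketch of that pattern, assuming PySpark is installed; the field name '7_timestamp_t' comes from the test's schema:

from datetime import datetime, timedelta
from pyspark.sql import Row

row = Row(**{'7_timestamp_t': datetime(2017, 11, 28, 12, 0), 'idx': 1})
corrected = Row(**{k: v - timedelta(hours=3) if k == '7_timestamp_t' else v
                   for k, v in row.asDict().items()})
print(corrected['7_timestamp_t'])  # 2017-11-28 09:00:00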
@@ -3834,6 +3835,7 @@ def test_vectorized_udf_timestamps_respect_session_timezone(self):
         df_la = df.withColumn("tscopy", f_timestamp_copy(col("timestamp"))) \
             .withColumn("internal_value", internal_value(col("timestamp")))
         result_la = df_la.select(col("idx"), col("internal_value")).collect()
+        # Correct result_la by adjusting 3 hours difference between Los Angeles and New York
         diff = 3 * 60 * 60 * 1000 * 1000 * 1000
         result_la_corrected = \
             df_la.select(col("idx"), col("tscopy"), col("internal_value") + diff).collect()
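Note (not part of the diff): the literal above is 3 hours expressed in nanoseconds, consistent with pandas' nanosecond Timestamp.value representation (the excerpt does not show the internal_value UDF, so the unit is inferred). A quick check:

from datetime import timedelta

# 3 h * 60 min/h * 60 s/min * 1e9 ns/s = 10_800_000_000_000 ns
diff = 3 * 60 * 60 * 1000 * 1000 * 1000
assert diff == int(timedelta(hours=3).total_seconds()) * 1000 * 1000 * 1000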
16 changes: 8 additions & 8 deletions python/pyspark/sql/types.py
@@ -1729,28 +1729,28 @@ def _check_series_convert_timestamps_internal(s, timezone):
         return s
 
 
-def _check_series_convert_timestamps_localize(s, fromTimezone, toTimezone):
+def _check_series_convert_timestamps_localize(s, from_timezone, to_timezone):
     """
     Convert timestamp to timezone-naive in the specified timezone or local timezone
 
     :param s: a pandas.Series
-    :param fromTimezone: the timezone to convert from. if None then use local timezone
-    :param toTimezone: the timezone to convert to. if None then use local timezone
+    :param from_timezone: the timezone to convert from. if None then use local timezone
+    :param to_timezone: the timezone to convert to. if None then use local timezone
     :return pandas.Series where if it is a timestamp, has been converted to tz-naive
     """
     try:
         import pandas as pd
         from pandas.api.types import is_datetime64tz_dtype, is_datetime64_dtype
     except ImportError as e:
         raise ImportError(_old_pandas_exception_message(e))
-    fromTz = fromTimezone or 'tzlocal()'
-    toTz = toTimezone or 'tzlocal()'
+    from_tz = from_timezone or 'tzlocal()'
+    to_tz = to_timezone or 'tzlocal()'
     # TODO: handle nested timestamps, such as ArrayType(TimestampType())?
     if is_datetime64tz_dtype(s.dtype):
-        return s.dt.tz_convert(toTz).dt.tz_localize(None)
-    elif is_datetime64_dtype(s.dtype) and fromTz != toTz:
+        return s.dt.tz_convert(to_tz).dt.tz_localize(None)
+    elif is_datetime64_dtype(s.dtype) and from_tz != to_tz:
         # `s.dt.tz_localize('tzlocal()')` doesn't work properly when including NaT.
-        return s.apply(lambda ts: ts.tz_localize(fromTz).tz_convert(toTz).tz_localize(None)
+        return s.apply(lambda ts: ts.tz_localize(from_tz).tz_convert(to_tz).tz_localize(None)
                        if ts is not pd.NaT else pd.NaT)
     else:
         return s
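Note (not part of the diff): a minimal sketch of the NaT-safe elif branch above, assuming pandas is installed; 'US/Pacific' and 'US/Eastern' are arbitrary example zones.

import pandas as pd

# A tz-naive datetime64[ns] Series containing a NaT, as in the elif branch:
# each value is localized, converted, and stripped back to tz-naive.
s = pd.Series(pd.to_datetime(['2017-11-28 09:00:00', None]))

converted = s.apply(
    lambda ts: ts.tz_localize('US/Pacific').tz_convert('US/Eastern').tz_localize(None)
    if ts is not pd.NaT else pd.NaT)

print(converted[0])  # 2017-11-28 12:00:00 (09:00 Pacific == 12:00 Eastern)
print(converted[1])  # NaT passes through untouched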
