From 44d84977571cba20531830b19ecb4186c93caa8f Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Wed, 10 Jun 2015 23:33:26 -0700 Subject: [PATCH] add timezone support for DateType --- python/pyspark/sql/tests.py | 7 +++++-- python/pyspark/sql/types.py | 23 +++++++++++++++-------- 2 files changed, 20 insertions(+), 10 deletions(-) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 8d06a2e0a8de1..b5fbb7d098820 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -604,14 +604,17 @@ def test_filter_with_datetime(self): self.assertEqual(0, df.filter(df.time > time).count()) def test_time_with_timezone(self): + day = datetime.date.today() now = datetime.datetime.now() ts = time.mktime(now.timetuple()) + now.microsecond / 1e6 # class in __main__ is not serializable from pyspark.sql.tests import UTC utc = UTC() utcnow = datetime.datetime.fromtimestamp(ts, utc) - df = self.sqlCtx.createDataFrame([(now, utcnow)]) - now1, utcnow1 = df.first() + df = self.sqlCtx.createDataFrame([(day, now, utcnow)]) + day1, now1, utcnow1 = df.first() + # Pyrolite serializes java.sql.Date as datetime; will be fixed in a new version + self.assertEqual(day1.date(), day) # Pyrolite does not support microsecond, the error should be # less than 1 millisecond self.assertTrue(now - now1 < datetime.timedelta(0.001)) diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index c532640bb6e39..23d9adb0daea1 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -655,12 +655,15 @@ def _need_python_to_sql_conversion(dataType): _need_python_to_sql_conversion(dataType.valueType) elif isinstance(dataType, UserDefinedType): return True - elif isinstance(dataType, TimestampType): + elif isinstance(dataType, (DateType, TimestampType)): return True else: return False +EPOCH_ORDINAL = datetime.datetime(1970, 1, 1).toordinal() + + def _python_to_sql_converter(dataType): """ Returns a converter that converts a Python object into a SQL 
datum for the given type. @@ -698,26 +701,30 @@ def converter(obj): return tuple(c(d.get(n)) for n, c in zip(names, converters)) else: return tuple(c(v) for c, v in zip(converters, obj)) - else: + elif obj is not None: raise ValueError("Unexpected tuple %r with type %r" % (obj, dataType)) return converter elif isinstance(dataType, ArrayType): element_converter = _python_to_sql_converter(dataType.elementType) - return lambda a: [element_converter(v) for v in a] + return lambda a: a and [element_converter(v) for v in a] elif isinstance(dataType, MapType): key_converter = _python_to_sql_converter(dataType.keyType) value_converter = _python_to_sql_converter(dataType.valueType) - return lambda m: dict([(key_converter(k), value_converter(v)) for k, v in m.items()]) + return lambda m: m and dict([(key_converter(k), value_converter(v)) for k, v in m.items()]) elif isinstance(dataType, UserDefinedType): - return lambda obj: dataType.serialize(obj) + return lambda obj: obj and dataType.serialize(obj) + + elif isinstance(dataType, DateType): + return lambda d: d and d.toordinal() - EPOCH_ORDINAL elif isinstance(dataType, TimestampType): def to_posix_timstamp(dt): - seconds = (calendar.timegm(dt.utctimetuple()) if dt.tzinfo - else time.mktime(dt.timetuple())) - return int(seconds * 1e7 + dt.microsecond * 10) + if dt: + seconds = (calendar.timegm(dt.utctimetuple()) if dt.tzinfo + else time.mktime(dt.timetuple())) + return int(seconds * 1e7 + dt.microsecond * 10) return to_posix_timstamp else: