[SPARK-3988][SQL] add public API for date type
Add JSON and Python APIs for the date type.
When serialized with Pickle, `java.sql.Date` is represented as a calendar and recognized in Python as `datetime.datetime`.
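
A minimal end-to-end sketch of the new Python API (assuming an existing SparkContext `sc`, as in the doctests below; the column and variable names are illustrative only): a column declared as DateType in `applySchema` accepts `datetime.date` values and returns them as `datetime.date` on `collect()`.

import datetime
from pyspark.sql import SQLContext, StructType, StructField, StringType, DateType

sqlCtx = SQLContext(sc)

# One row containing a plain Python date; the schema declares that column as DateType.
rdd = sc.parallelize([("Alice", datetime.date(2014, 10, 10))])
schema = StructType([
    StructField("name", StringType(), False),
    StructField("birthday", DateType(), False)])

srdd = sqlCtx.applySchema(rdd, schema)
row = srdd.collect()[0]
# Comes back as datetime.date(2014, 10, 10), not datetime.datetime:
# the Python side converts the unpickled value back to a date.
print row.birthday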

Author: Daoyuan Wang <[email protected]>

Closes #2901 from adrian-wang/spark3988 and squashes the following commits:

c51a24d [Daoyuan Wang] convert datetime to date
5670626 [Daoyuan Wang] minor line combine
f760d8e [Daoyuan Wang] fix indent
444f100 [Daoyuan Wang] fix a typo
1d74448 [Daoyuan Wang] fix scala style
8d7dd22 [Daoyuan Wang] add json and python api for date type
adrian-wang authored and marmbrus committed Oct 28, 2014
1 parent 5807cb4 commit 47a40f6
Showing 10 changed files with 87 additions and 36 deletions.
57 changes: 39 additions & 18 deletions python/pyspark/sql.py
@@ -49,7 +49,7 @@


__all__ = [
"StringType", "BinaryType", "BooleanType", "TimestampType", "DecimalType",
"StringType", "BinaryType", "BooleanType", "DateType", "TimestampType", "DecimalType",
"DoubleType", "FloatType", "ByteType", "IntegerType", "LongType",
"ShortType", "ArrayType", "MapType", "StructField", "StructType",
"SQLContext", "HiveContext", "SchemaRDD", "Row"]
@@ -132,6 +132,14 @@ class BooleanType(PrimitiveType):
"""


class DateType(PrimitiveType):

"""Spark SQL DateType
The data type representing datetime.date values.
"""


class TimestampType(PrimitiveType):

"""Spark SQL TimestampType
@@ -438,7 +446,7 @@ def _parse_datatype_json_value(json_value):
return _all_complex_types[json_value["type"]].fromJson(json_value)


# Mapping Python types to Spark SQL DateType
# Mapping Python types to Spark SQL DataType
_type_mappings = {
bool: BooleanType,
int: IntegerType,
@@ -448,8 +456,8 @@ def _parse_datatype_json_value(json_value):
unicode: StringType,
bytearray: BinaryType,
decimal.Decimal: DecimalType,
datetime.date: DateType,
datetime.datetime: TimestampType,
datetime.date: TimestampType,
datetime.time: TimestampType,
}
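
With `datetime.date` now mapped to DateType (and `datetime.datetime` still mapped to TimestampType), schema inference can tell dates and timestamps apart. A hedged sketch, assuming the usual doctest setup (`sc`, `sqlCtx`) and this version's `inferSchema`/`schema()` API:

import datetime
from pyspark.sql import Row

rdd = sc.parallelize([Row(d=datetime.date(2014, 10, 10),
                          t=datetime.datetime(2014, 10, 10, 12, 0, 0))])
srdd = sqlCtx.inferSchema(rdd)
# Expect DateType for field "d" and TimestampType for field "t".
print srdd.schema()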

@@ -656,10 +664,10 @@ def _infer_schema_type(obj, dataType):
"""
Fill the dataType with types infered from obj
>>> schema = _parse_schema_abstract("a b c")
>>> row = (1, 1.0, "str")
>>> schema = _parse_schema_abstract("a b c d")
>>> row = (1, 1.0, "str", datetime.date(2014, 10, 10))
>>> _infer_schema_type(row, schema)
StructType...IntegerType...DoubleType...StringType...
StructType...IntegerType...DoubleType...StringType...DateType...
>>> row = [[1], {"key": (1, 2.0)}]
>>> schema = _parse_schema_abstract("a[] b{c d}")
>>> _infer_schema_type(row, schema)
@@ -703,6 +711,7 @@ def _infer_schema_type(obj, dataType):
DecimalType: (decimal.Decimal,),
StringType: (str, unicode),
BinaryType: (bytearray,),
DateType: (datetime.date,),
TimestampType: (datetime.datetime,),
ArrayType: (list, tuple, array),
MapType: (dict,),
@@ -740,7 +749,7 @@ def _verify_type(obj, dataType):

# subclass of them can not be deserialized in JVM
if type(obj) not in _acceptable_types[_type]:
raise TypeError("%s can not accept abject in type %s"
raise TypeError("%s can not accept object in type %s"
% (dataType, type(obj)))

if isinstance(dataType, ArrayType):
@@ -767,7 +776,7 @@ def _restore_object(dataType, obj):
""" Restore object during unpickling. """
# use id(dataType) as key to speed up lookup in dict
# Because of batched pickling, dataType will be the
# same object in mose cases.
# same object in most cases.
k = id(dataType)
cls = _cached_cls.get(k)
if cls is None:
@@ -782,6 +791,10 @@ def _restore_object(dataType, obj):

def _create_object(cls, v):
""" Create an customized object with class `cls`. """
# datetime.date would be deserialized as datetime.datetime
# from java type, so we need to set it back.
if cls is datetime.date and isinstance(v, datetime.datetime):
return v.date()
return cls(v) if v is not None else v
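
The check just added to `_create_object` handles the return path: values of a DateType column are unpickled from the JVM as `datetime.datetime`, so the field class (`datetime.date`) narrows them back on access. An isolated illustration of that logic in plain Python (the helper name is hypothetical; this is not the actual pyspark call path):

import datetime

def narrow_to_date(cls, v):
    # Same rule as above: a DateType field whose unpickled value arrived as
    # datetime.datetime is converted back to a date; other values pass through.
    if cls is datetime.date and isinstance(v, datetime.datetime):
        return v.date()
    return cls(v) if v is not None else v

print narrow_to_date(datetime.date, datetime.datetime(2014, 10, 10, 0, 0))
# datetime.date(2014, 10, 10)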


@@ -795,14 +808,16 @@ def getter(self):
return getter


def _has_struct(dt):
"""Return whether `dt` is or has StructType in it"""
def _has_struct_or_date(dt):
"""Return whether `dt` is or has StructType/DateType in it"""
if isinstance(dt, StructType):
return True
elif isinstance(dt, ArrayType):
return _has_struct(dt.elementType)
return _has_struct_or_date(dt.elementType)
elif isinstance(dt, MapType):
return _has_struct(dt.valueType)
return _has_struct_or_date(dt.valueType)
elif isinstance(dt, DateType):
return True
return False


@@ -815,7 +830,7 @@ def _create_properties(fields):
or keyword.iskeyword(name)):
warnings.warn("field name %s can not be accessed in Python,"
"use position to access it instead" % name)
if _has_struct(f.dataType):
if _has_struct_or_date(f.dataType):
# delay creating object until accessing it
getter = _create_getter(f.dataType, i)
else:
@@ -870,6 +885,9 @@ def Dict(d):

return Dict

elif isinstance(dataType, DateType):
return datetime.date

elif not isinstance(dataType, StructType):
raise Exception("unexpected data type: %s" % dataType)

@@ -1068,8 +1086,9 @@ def applySchema(self, rdd, schema):
>>> srdd2.collect()
[Row(field1=1, field2=u'row1'),..., Row(field1=3, field2=u'row3')]
>>> from datetime import datetime
>>> from datetime import date, datetime
>>> rdd = sc.parallelize([(127, -128L, -32768, 32767, 2147483647L, 1.0,
... date(2010, 1, 1),
... datetime(2010, 1, 1, 1, 1, 1),
... {"a": 1}, (2,), [1, 2, 3], None)])
>>> schema = StructType([
@@ -1079,6 +1098,7 @@
... StructField("short2", ShortType(), False),
... StructField("int", IntegerType(), False),
... StructField("float", FloatType(), False),
... StructField("date", DateType(), False),
... StructField("time", TimestampType(), False),
... StructField("map",
... MapType(StringType(), IntegerType(), False), False),
@@ -1088,10 +1108,11 @@
... StructField("null", DoubleType(), True)])
>>> srdd = sqlCtx.applySchema(rdd, schema)
>>> results = srdd.map(
... lambda x: (x.byte1, x.byte2, x.short1, x.short2, x.int, x.float, x.time,
... x.map["a"], x.struct.b, x.list, x.null))
>>> results.collect()[0]
(127, -128, -32768, 32767, 2147483647, 1.0, ...(2010, 1, 1, 1, 1, 1), 1, 2, [1, 2, 3], None)
... lambda x: (x.byte1, x.byte2, x.short1, x.short2, x.int, x.float, x.date,
... x.time, x.map["a"], x.struct.b, x.list, x.null))
>>> results.collect()[0] # doctest: +NORMALIZE_WHITESPACE
(127, -128, -32768, 32767, 2147483647, 1.0, datetime.date(2010, 1, 1),
datetime.datetime(2010, 1, 1, 1, 1, 1), 1, 2, [1, 2, 3], None)
>>> srdd.registerTempTable("table2")
>>> sqlCtx.sql(
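
The doctest above is truncated here; as a separate, hypothetical follow-up (not part of the original doctest), the registered table can be queried through SQL and the DateType column still comes back as a Python `datetime.date`:

results = sqlCtx.sql("SELECT date FROM table2").collect()
print results[0].date   # datetime.date(2010, 1, 1)
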
@@ -112,6 +112,7 @@ object ScalaReflection {
case obj: FloatType.JvmType => FloatType
case obj: DoubleType.JvmType => DoubleType
case obj: DecimalType.JvmType => DecimalType
case obj: DateType.JvmType => DateType
case obj: TimestampType.JvmType => TimestampType
case null => NullType
// For other cases, there is no obvious mapping from the type of the given object to a
@@ -91,6 +91,7 @@ object DataType {
| "BinaryType" ^^^ BinaryType
| "BooleanType" ^^^ BooleanType
| "DecimalType" ^^^ DecimalType
| "DateType" ^^^ DateType
| "TimestampType" ^^^ TimestampType
)

@@ -198,7 +199,8 @@ trait PrimitiveType extends DataType {
}

object PrimitiveType {
private[sql] val all = Seq(DecimalType, TimestampType, BinaryType) ++ NativeType.all
private[sql] val all = Seq(DecimalType, DateType, TimestampType, BinaryType) ++
NativeType.all

private[sql] val nameToType = all.map(t => t.typeName -> t).toMap
}
@@ -18,7 +18,7 @@
package org.apache.spark.sql.catalyst

import java.math.BigInteger
import java.sql.Timestamp
import java.sql.{Date, Timestamp}

import org.scalatest.FunSuite

@@ -43,6 +43,7 @@ case class NullableData(
booleanField: java.lang.Boolean,
stringField: String,
decimalField: BigDecimal,
dateField: Date,
timestampField: Timestamp,
binaryField: Array[Byte])

@@ -96,6 +97,7 @@ class ScalaReflectionSuite extends FunSuite {
StructField("booleanField", BooleanType, nullable = true),
StructField("stringField", StringType, nullable = true),
StructField("decimalField", DecimalType, nullable = true),
StructField("dateField", DateType, nullable = true),
StructField("timestampField", TimestampType, nullable = true),
StructField("binaryField", BinaryType, nullable = true))),
nullable = true))
@@ -199,8 +201,11 @@ class ScalaReflectionSuite extends FunSuite {
// DecimalType
assert(DecimalType === typeOfObject(BigDecimal("1.7976931348623157E318")))

// DateType
assert(DateType === typeOfObject(Date.valueOf("2014-07-25")))

// TimestampType
assert(TimestampType === typeOfObject(java.sql.Timestamp.valueOf("2014-07-25 10:26:00")))
assert(TimestampType === typeOfObject(Timestamp.valueOf("2014-07-25 10:26:00")))

// NullType
assert(NullType === typeOfObject(null))
10 changes: 7 additions & 3 deletions sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -444,6 +444,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
case ByteType => true
case ShortType => true
case FloatType => true
case DateType => true
case TimestampType => true
case ArrayType(_, _) => true
case MapType(_, _, _) => true
@@ -452,9 +453,9 @@
}

// Converts value to the type specified by the data type.
// Because Python does not have data types for TimestampType, FloatType, ShortType, and
// ByteType, we need to explicitly convert values in columns of these data types to the desired
// JVM data types.
// Because Python does not have data types for DateType, TimestampType, FloatType, ShortType,
// and ByteType, we need to explicitly convert values in columns of these data types to the
// desired JVM data types.
def convert(obj: Any, dataType: DataType): Any = (obj, dataType) match {
// TODO: We should check nullable
case (null, _) => null
@@ -474,6 +475,9 @@
case (e, f) => convert(e, f.dataType)
}): Row

case (c: java.util.Calendar, DateType) =>
new java.sql.Date(c.getTime().getTime())

case (c: java.util.Calendar, TimestampType) =>
new java.sql.Timestamp(c.getTime().getTime())

20 changes: 14 additions & 6 deletions sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.json
import scala.collection.Map
import scala.collection.convert.Wrappers.{JMapWrapper, JListWrapper}
import scala.math.BigDecimal
import java.sql.Timestamp
import java.sql.{Date, Timestamp}

import com.fasterxml.jackson.core.JsonProcessingException
import com.fasterxml.jackson.databind.ObjectMapper
@@ -372,13 +372,20 @@ private[sql] object JsonRDD extends Logging {
}
}

private def toDate(value: Any): Date = {
value match {
// only support string as date
case value: java.lang.String => Date.valueOf(value)
}
}

private def toTimestamp(value: Any): Timestamp = {
value match {
case value: java.lang.Integer => new Timestamp(value.asInstanceOf[Int].toLong)
case value: java.lang.Long => new Timestamp(value)
case value: java.lang.String => Timestamp.valueOf(value)
}
}

private[json] def enforceCorrectType(value: Any, desiredType: DataType): Any ={
if (value == null) {
@@ -396,6 +403,7 @@ private[sql] object JsonRDD extends Logging {
case ArrayType(elementType, _) =>
value.asInstanceOf[Seq[Any]].map(enforceCorrectType(_, elementType))
case struct: StructType => asRow(value.asInstanceOf[Map[String, Any]], struct)
case DateType => toDate(value)
case TimestampType => toTimestamp(value)
}
}
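
Because `toDate` above only accepts a string, a JSON document carries the date as a string and the user-supplied schema tells Spark SQL to interpret it as DateType. A hedged pyspark sketch (assuming `sqlCtx.jsonRDD` accepts an explicit schema in this version, as `applySchema` does; names are illustrative):

from pyspark.sql import StructType, StructField, StringType, DateType

json_rdd = sc.parallelize(['{"name": "Alice", "birthday": "2014-10-15"}'])
schema = StructType([
    StructField("name", StringType(), True),
    StructField("birthday", DateType(), True)])

srdd = sqlCtx.jsonRDD(json_rdd, schema)
# enforceCorrectType routes the string through toDate on the JVM side,
# so the collected value should be a Python datetime.date.
print srdd.collect()[0].birthday
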
@@ -18,6 +18,7 @@
package org.apache.spark.sql.api.java;

import java.math.BigDecimal;
import java.sql.Date;
import java.sql.Timestamp;
import java.util.Arrays;
import java.util.HashMap;
@@ -39,6 +40,7 @@ public class JavaRowSuite {
private boolean booleanValue;
private String stringValue;
private byte[] binaryValue;
private Date dateValue;
private Timestamp timestampValue;

@Before
@@ -53,6 +55,7 @@ public void setUp() {
booleanValue = true;
stringValue = "this is a string";
binaryValue = stringValue.getBytes();
dateValue = Date.valueOf("2014-06-30");
timestampValue = Timestamp.valueOf("2014-06-30 09:20:00.0");
}

@@ -76,6 +79,7 @@ public void constructSimpleRow() {
new Boolean(booleanValue),
stringValue, // StringType
binaryValue, // BinaryType
dateValue, // DateType
timestampValue, // TimestampType
null // null
);
@@ -114,9 +118,10 @@ public void constructSimpleRow() {
Assert.assertEquals(stringValue, simpleRow.getString(15));
Assert.assertEquals(stringValue, simpleRow.get(15));
Assert.assertEquals(binaryValue, simpleRow.get(16));
Assert.assertEquals(timestampValue, simpleRow.get(17));
Assert.assertEquals(true, simpleRow.isNullAt(18));
Assert.assertEquals(null, simpleRow.get(18));
Assert.assertEquals(dateValue, simpleRow.get(17));
Assert.assertEquals(timestampValue, simpleRow.get(18));
Assert.assertEquals(true, simpleRow.isNullAt(19));
Assert.assertEquals(null, simpleRow.get(19));
}

@Test
@@ -39,6 +39,7 @@ public void createDataTypes() {
checkDataType(DataType.StringType);
checkDataType(DataType.BinaryType);
checkDataType(DataType.BooleanType);
checkDataType(DataType.DateType);
checkDataType(DataType.TimestampType);
checkDataType(DataType.DecimalType);
checkDataType(DataType.DoubleType);
@@ -38,6 +38,7 @@ class ScalaSideDataTypeConversionSuite extends FunSuite {
checkDataType(org.apache.spark.sql.StringType)
checkDataType(org.apache.spark.sql.BinaryType)
checkDataType(org.apache.spark.sql.BooleanType)
checkDataType(org.apache.spark.sql.DateType)
checkDataType(org.apache.spark.sql.TimestampType)
checkDataType(org.apache.spark.sql.DecimalType)
checkDataType(org.apache.spark.sql.DoubleType)
@@ -25,7 +25,7 @@ import org.apache.spark.sql.SQLConf
import org.apache.spark.sql.test.TestSQLContext
import org.apache.spark.sql.test.TestSQLContext._

import java.sql.Timestamp
import java.sql.{Date, Timestamp}

class JsonSuite extends QueryTest {
import TestJsonData._
@@ -58,8 +58,11 @@ class JsonSuite extends QueryTest {
checkTypePromotion(new Timestamp(intNumber), enforceCorrectType(intNumber, TimestampType))
checkTypePromotion(new Timestamp(intNumber.toLong),
enforceCorrectType(intNumber.toLong, TimestampType))
val strDate = "2014-09-30 12:34:56"
checkTypePromotion(Timestamp.valueOf(strDate), enforceCorrectType(strDate, TimestampType))
val strTime = "2014-09-30 12:34:56"
checkTypePromotion(Timestamp.valueOf(strTime), enforceCorrectType(strTime, TimestampType))

val strDate = "2014-10-15"
checkTypePromotion(Date.valueOf(strDate), enforceCorrectType(strDate, DateType))
}

test("Get compatible type") {
