support datetime type for SchemaRDD
davies committed Jul 26, 2014
1 parent a2715cc commit 96db384
Showing 4 changed files with 50 additions and 8 deletions.
@@ -550,11 +550,11 @@ private[spark] object PythonRDD extends Logging {
def pythonToJavaMap(pyRDD: JavaRDD[Array[Byte]]): JavaRDD[Map[String, _]] = {
pyRDD.rdd.mapPartitions { iter =>
val unpickle = new Unpickler
// TODO: Figure out why flatMap is necessary for pyspark
iter.flatMap { row =>
unpickle.loads(row) match {
// in case objects were pickled in batch mode
case objs: java.util.ArrayList[JMap[String, _] @unchecked] => objs.map(_.toMap)
// In case the partition doesn't have a collection
// not in batch mode
case obj: JMap[String @unchecked, _] => Seq(obj.toMap)
}
}
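In batch mode PySpark pickles many rows into a single payload, so one Unpickler.loads call may return either a java.util.ArrayList of row maps or a single bare map; flatMap lets both shapes flatten into one stream of rows. A minimal standalone sketch of that dispatch, with made-up values (not code from the commit):

    import scala.collection.JavaConverters._
    import java.util.{ArrayList => JArrayList, HashMap => JHashMap}

    val single = new JHashMap[String, Any]()            // one unpickled row
    single.put("int", 1)
    val batch = new JArrayList[JHashMap[String, Any]]() // a pickled batch of rows
    batch.add(single)

    // flatMap lets one payload expand into zero or more rows
    val rows = Seq[Any](batch, single).flatMap {
      case objs: JArrayList[_] =>
        objs.asScala.map(_.asInstanceOf[JHashMap[String, Any]])
      case obj: JHashMap[String @unchecked, _] => Seq(obj)
    }
    // rows: two maps, one from each payload shape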
9 changes: 5 additions & 4 deletions python/pyspark/sql.py
@@ -47,12 +47,13 @@ def __init__(self, sparkContext, sqlContext=None):
...
ValueError:...
- >>> allTypes = sc.parallelize([{"int" : 1, "string" : "string", "double" : 1.0, "long": 1L,
- ... "boolean" : True}])
+ >>> from datetime import datetime
+ >>> allTypes = sc.parallelize([{"int": 1, "string": "string", "double": 1.0, "long": 1L,
+ ... "boolean": True, "time": datetime(2010, 1, 1, 1, 1, 1)}])
>>> srdd = sqlCtx.inferSchema(allTypes).map(lambda x: (x.int, x.string, x.double, x.long,
- ... x.boolean))
+ ... x.boolean, x.time))
>>> srdd.collect()[0]
- (1, u'string', 1.0, 1, True)
+ (1, u'string', 1.0, 1, True, datetime.datetime(2010, 1, 1, 1, 1, 1))
"""
self._sc = sparkContext
self._jsc = self._sc._jsc
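Why the JVM side deals in java.util.Calendar at all: Pyrolite, the pickler Spark uses between Python and the JVM, unpickles a Python datetime.datetime as a java.util.Calendar (assuming its default mapping). The SQLContext changes below recognize that Calendar and convert it, roughly:

    val cal = java.util.Calendar.getInstance()               // what Pyrolite hands over
    val ts = new java.sql.Timestamp(cal.getTime().getTime()) // what TimestampType stores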
40 changes: 38 additions & 2 deletions sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -357,16 +357,52 @@ class SQLContext(@transient val sparkContext: SparkContext)
case c: java.util.Map[_, _] =>
val (key, value) = c.head
MapType(typeFor(key), typeFor(value))
+ case c: java.util.Calendar => TimestampType
case c if c.getClass.isArray =>
val elem = c.asInstanceOf[Array[_]].head
ArrayType(typeFor(elem))
case c => throw new Exception(s"Object of type $c cannot be used")
}
- val schema = rdd.first().map { case (fieldName, obj) =>
+ val firstRow = rdd.first()
+ val schema = firstRow.map { case (fieldName, obj) =>
AttributeReference(fieldName, typeFor(obj), true)()
}.toSeq

- val rowRdd = rdd.mapPartitions { iter =>
+ def needTransform(obj: Any): Boolean = obj match {
+   case c: java.util.List[_] => c.exists(needTransform)
+   case c: java.util.Set[_] => c.exists(needTransform)
+   case c: java.util.Map[_, _] => c.exists {
+     case (key, value) => needTransform(key) || needTransform(value)
+   }
+   case c if c.getClass.isArray =>
+     c.asInstanceOf[Array[_]].exists(needTransform)
+   case c: java.util.Calendar => true
+   case c => false
+ }
+
+ def transform(obj: Any): Any = obj match {
+   case c: java.util.List[_] => c.map(transform)
+   case c: java.util.Set[_] => c.map(transform)
+   case c: java.util.Map[_, _] => c.map {
+     case (key, value) => (transform(key), transform(value))
+   }
+   case c if c.getClass.isArray =>
+     c.asInstanceOf[Array[_]].map(transform)
+   case c: java.util.Calendar =>
+     new java.sql.Timestamp(c.getTime().getTime())
+   case c => c
+ }
+
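  // Note (an assumption, not part of the commit): only the first row is probed
  // below; since the schema is already inferred from rdd.first(), every row is
  // presumed to share those field types.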
+ val need = firstRow.exists { case (key, value) => needTransform(value) }
+ val transformed = if (need) {
+   rdd.mapPartitions { iter =>
+     iter.map { m =>
+       m.map { case (key, value) => (key, transform(value)) }
+     }
+   }
+ } else rdd
+
+ val rowRdd = transformed.mapPartitions { iter =>
iter.map { map =>
new GenericRow(map.values.toArray.asInstanceOf[Array[Any]]): Row
}
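The design choice in this hunk: rather than unconditionally rewriting every value, needTransform probes a single row, and the extra mapPartitions pass is only paid when a Calendar is actually present, possibly nested inside lists, sets, maps, or arrays. A self-contained sketch of the recursive rewrite, using explicit converters instead of the implicit JavaConversions the commit relies on (the name toTimestamps is mine):

    import scala.collection.JavaConverters._

    def toTimestamps(obj: Any): Any = obj match {
      case c: java.util.Calendar => new java.sql.Timestamp(c.getTime().getTime())
      case l: java.util.List[_] => l.asScala.map(toTimestamps).asJava
      case m: java.util.Map[_, _] =>
        m.asScala.map { case (k, v) => (toTimestamps(k), toTimestamps(v)) }.asJava
      case other => other
    }

    val nested = new java.util.ArrayList[Any]()
    nested.add(java.util.Calendar.getInstance())
    toTimestamps(nested) // the list element is now a java.sql.Timestamp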
5 changes: 5 additions & 0 deletions sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala
@@ -395,6 +395,11 @@ class SchemaRDD(
arr.asInstanceOf[Array[Any]].map {
element => rowToMap(element.asInstanceOf[Row], struct)
}
+ case t: java.sql.Timestamp => {
+   val c = java.util.Calendar.getInstance()
+   c.setTimeInMillis(t.getTime())
+   c
+ }
case other => other
}
map.put(attrName, arrayValues)
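This mirrors the SQLContext change in the other direction: before rows are handed back to Python, java.sql.Timestamp values are rewrapped as java.util.Calendar so that Pyrolite can (per the assumed mapping above) pickle them into datetime.datetime objects. The full round trip, sketched from the JVM's point of view:

    // assumed flow, not code from the commit:
    // Python datetime --pickle--> java.util.Calendar --> TimestampType column
    // --> java.sql.Timestamp in each Row --> java.util.Calendar --pickle--> datetime
    val ts = java.sql.Timestamp.valueOf("2010-01-01 01:01:01")
    val back = java.util.Calendar.getInstance()
    back.setTimeInMillis(ts.getTime())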
