
Commit f0599b0

remove tests for sets and tuple in sql, fix list of list
davies committed Jul 28, 2014
1 parent c9d607a commit f0599b0
Showing 3 changed files with 25 additions and 45 deletions.
8 changes: 4 additions & 4 deletions python/pyspark/sql.py
@@ -95,8 +95,8 @@ def inferSchema(self, rdd):
True
>>> srdd = sqlCtx.inferSchema(nestedRdd2)
- >>> srdd.collect() == [{"f1" : [[1, 2], [2, 3]], "f2" : [1, 2], "f3" : [1, 2]},
- ...                    {"f1" : [[2, 3], [3, 4]], "f2" : [2, 3], "f3" : [2, 3]}]
+ >>> srdd.collect() == [{"f1" : [[1, 2], [2, 3]], "f2" : [1, 2]},
+ ...                    {"f1" : [[2, 3], [3, 4]], "f2" : [2, 3]}]
True
"""
if (rdd.__class__ is SchemaRDD):
@@ -511,8 +511,8 @@ def _test():
{"f1": array('i', [1, 2]), "f2": {"row1": 1.0}},
{"f1": array('i', [2, 3]), "f2": {"row2": 2.0}}])
globs['nestedRdd2'] = sc.parallelize([
{"f1": [[1, 2], [2, 3]], "f2": set([1, 2]), "f3": [1, 2]},
{"f1": [[2, 3], [3, 4]], "f2": set([2, 3]), "f3": [2, 3]}])
{"f1": [[1, 2], [2, 3]], "f2": [1, 2]},
{"f1": [[2, 3], [3, 4]], "f2": [2, 3]}])
(failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS)
globs['sc'].stop()
if failure_count:
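The updated doctest above exercises the "fix list of list" part of this commit: inferSchema now gets a record whose f1 field is a list of lists, so type inference has to recurse into element types instead of stopping at the outer collection. The following minimal, self-contained Scala sketch illustrates that recursion with toy types; the Sketch* names and typeForSketch are hypothetical stand-ins for Catalyst's data types and the real typeFor in SQLContext.scala shown below.

import scala.collection.JavaConverters._

object TypeForSketch extends App {
  sealed trait SketchType
  case object SketchIntType extends SketchType
  case class SketchArrayType(elementType: SketchType) extends SketchType
  case class SketchMapType(keyType: SketchType, valueType: SketchType) extends SketchType

  // Infer a toy type from a Java value by looking at its first element,
  // recursing into nested lists so [[1, 2], [2, 3]] becomes a nested array type.
  def typeForSketch(obj: Any): SketchType = obj match {
    case _: java.lang.Integer => SketchIntType
    case c: java.util.List[_] => SketchArrayType(typeForSketch(c.asScala.head))
    case c: java.util.Map[_, _] =>
      val (key, value) = c.asScala.head
      SketchMapType(typeForSketch(key), typeForSketch(value))
  }

  // A Java list of lists of Integers infers to a nested array type.
  val listOfLists: java.util.List[java.util.List[Integer]] =
    java.util.Arrays.asList(
      java.util.Arrays.asList(Int.box(1), Int.box(2)),
      java.util.Arrays.asList(Int.box(2), Int.box(3)))

  assert(typeForSketch(listOfLists) == SketchArrayType(SketchArrayType(SketchIntType)))
}

Inferring from the first element mirrors the c.head calls in the real typeFor, so an empty collection would fail in this sketch just as it would there.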
9 changes: 4 additions & 5 deletions sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -352,12 +352,13 @@ class SQLContext(@transient val sparkContext: SparkContext)
case c: java.lang.Long => LongType
case c: java.lang.Double => DoubleType
case c: java.lang.Boolean => BooleanType
+ case c: java.math.BigDecimal => DecimalType
+ case c: java.sql.Timestamp => TimestampType
+ case c: java.util.Calendar => TimestampType
case c: java.util.List[_] => ArrayType(typeFor(c.head))
- case c: java.util.Set[_] => ArrayType(typeFor(c.head))
case c: java.util.Map[_, _] =>
val (key, value) = c.head
MapType(typeFor(key), typeFor(value))
- case c: java.util.Calendar => TimestampType
case c if c.getClass.isArray =>
val elem = c.asInstanceOf[Array[_]].head
ArrayType(typeFor(elem))
@@ -370,18 +371,16 @@ class SQLContext(@transient val sparkContext: SparkContext)

def needTransform(obj: Any): Boolean = obj match {
case c: java.util.List[_] => true
- case c: java.util.Set[_] => true
case c: java.util.Map[_, _] => true
case c if c.getClass.isArray => true
case c: java.util.Calendar => true
case c => false
}

- // convert JList, JSet into Seq, convert JMap into Map
+ // convert JList, JArray into Seq, convert JMap into Map
// convert Calendar into Timestamp
def transform(obj: Any): Any = obj match {
case c: java.util.List[_] => c.map(transform).toSeq
- case c: java.util.Set[_] => c.map(transform).toSet.toSeq
case c: java.util.Map[_, _] => c.map {
case (key, value) => (key, transform(value))
}.toMap
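The transform helper above turns the Java collections handed over by Pyrolite (JList, JMap, arrays, Calendar) into Scala equivalents before rows reach Catalyst; with the Set case removed, lists and map values are still converted recursively, which is what makes nested lists work end to end. A simplified, self-contained sketch of that recursion under assumed names (transformSketch is hypothetical, and the array and Calendar cases of the real method are omitted here):

import scala.collection.JavaConverters._

object TransformSketch extends App {
  // Recursively convert Java collections (as produced by Pyrolite) into
  // Scala collections, descending into nested lists and map values.
  def transformSketch(obj: Any): Any = obj match {
    case c: java.util.List[_]   => c.asScala.map(transformSketch).toSeq
    case c: java.util.Map[_, _] => c.asScala.map { case (k, v) => (k, transformSketch(v)) }.toMap
    case other                  => other
  }

  // A JList of JLists becomes a Seq of Seqs; the inner lists are converted too.
  val inner1 = java.util.Arrays.asList(Int.box(1), Int.box(2))
  val inner2 = java.util.Arrays.asList(Int.box(2), Int.box(3))
  val nested = java.util.Arrays.asList(inner1, inner2)
  assert(transformSketch(nested) == Seq(Seq(1, 2), Seq(2, 3)))
}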
53 changes: 17 additions & 36 deletions sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala
@@ -32,7 +32,7 @@ import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.plans.{Inner, JoinType}
- import org.apache.spark.sql.catalyst.types.{ArrayType, BooleanType, StructType, MapType}
+ import org.apache.spark.sql.catalyst.types.{DataType, ArrayType, BooleanType, StructType, MapType}
import org.apache.spark.sql.execution.{ExistingRdd, SparkLogicalPlan}
import org.apache.spark.api.java.JavaRDD

@@ -376,46 +376,27 @@ class SchemaRDD(
* Converts a JavaRDD to a PythonRDD. It is used by pyspark.
*/
private[sql] def javaToPython: JavaRDD[Array[Byte]] = {
+ def toJava(obj: Any, dataType: DataType): Any = dataType match {
+   case struct: StructType => rowToMap(obj.asInstanceOf[Row], struct)
+   case array: ArrayType => obj match {
+     case seq: Seq[Any] => seq.map(x => toJava(x, array.elementType)).asJava
+     case list: JList[_] => list.map(x => toJava(x, array.elementType)).asJava
+     case arr if arr != null && arr.getClass.isArray =>
+       arr.asInstanceOf[Array[Any]].map(x => toJava(x, array.elementType))
+     case other => other
+   }
+   case mt: MapType => obj.asInstanceOf[Map[_, _]].map {
+     case (k, v) => (k, toJava(v, mt.valueType)) // key should be primitive type
+   }.asJava
+   // Pyrolite can handle Timestamp
+   case other => obj
+ }
def rowToMap(row: Row, structType: StructType): JMap[String, Any] = {
val fields = structType.fields.map(field => (field.name, field.dataType))
val map: JMap[String, Any] = new java.util.HashMap
row.zip(fields).foreach {
- case (obj, (attrName, dataType)) =>
-   dataType match {
-     case struct: StructType => map.put(attrName, rowToMap(obj.asInstanceOf[Row], struct))
-     case array @ ArrayType(struct: StructType) =>
-       val arrayValues = obj match {
-         case seq: Seq[Any] =>
-           seq.map(element => rowToMap(element.asInstanceOf[Row], struct)).asJava
-         case list: JList[_] =>
-           list.map(element => rowToMap(element.asInstanceOf[Row], struct)).asJava
-         case set: JSet[_] =>
-           set.map(element => rowToMap(element.asInstanceOf[Row], struct)).asJava
-         case arr if arr != null && arr.getClass.isArray =>
-           arr.asInstanceOf[Array[Any]].map {
-             element => rowToMap(element.asInstanceOf[Row], struct)
-           }
-         case other => other
-       }
-       map.put(attrName, arrayValues)
-     case m @ MapType(_, struct: StructType) =>
-       val nm = obj.asInstanceOf[Map[_,_]].map {
-         case (k, v) => (k, rowToMap(v.asInstanceOf[Row], struct))
-       }.asJava
-       map.put(attrName, nm)
-     case array: ArrayType => {
-       val arrayValues = obj match {
-         case seq: Seq[Any] => seq.asJava
-         case other => other
-       }
-       map.put(attrName, arrayValues)
-     }
-     case m: MapType => map.put(attrName, obj.asInstanceOf[Map[_,_]].asJava)
-     // Pyrolite can handle Timestamp
-     case other => map.put(attrName, obj)
-   }
+ case (obj, (attrName, dataType)) => map.put(attrName, toJava(obj, dataType))
}

map
}

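The rewritten javaToPython above adds toJava, which walks the Catalyst DataType alongside each value so that nested ArrayType and MapType fields are converted element by element into Java collections Pyrolite can pickle, while rowToMap now just dispatches every field to toJava. A minimal, self-contained sketch of the same schema-driven recursion with toy types; the Sketch* names and toJavaSketch are hypothetical stand-ins for Catalyst's DataType, ArrayType, and MapType:

import scala.collection.JavaConverters._

object ToJavaSketch extends App {
  sealed trait SketchDataType
  case object SketchIntType extends SketchDataType
  case class SketchArrayType(elementType: SketchDataType) extends SketchDataType
  case class SketchMapType(valueType: SketchDataType) extends SketchDataType

  // Walk the schema and the value together, converting Seq fields to
  // java.util.List and Map fields to java.util.Map, recursing into elements.
  def toJavaSketch(obj: Any, dataType: SketchDataType): Any = dataType match {
    case SketchArrayType(elementType) =>
      obj.asInstanceOf[Seq[Any]].map(toJavaSketch(_, elementType)).asJava
    case SketchMapType(valueType) =>
      obj.asInstanceOf[Map[Any, Any]].map { case (k, v) => (k, toJavaSketch(v, valueType)) }.asJava
    case _ => obj // primitives pass through unchanged
  }

  // A nested Scala Seq is converted into nested java.util.Lists.
  val converted = toJavaSketch(Seq(Seq(1, 2), Seq(2, 3)), SketchArrayType(SketchArrayType(SketchIntType)))
  println(converted) // prints [[1, 2], [2, 3]]; a java.util.List whose elements are java.util.Lists
}

Driving the conversion from the schema rather than the runtime class is why the removed per-case branches in rowToMap are no longer needed: one recursive function covers structs, arrays, and maps at any nesting depth.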
