udt finally working
jkbradley committed Nov 2, 2014
1 parent 50f9726 commit 893ee4c
Showing 5 changed files with 28 additions and 19 deletions.
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala

@@ -96,17 +96,28 @@ object ScalaReflection {
   def convertToScala(a: Any, dataType: DataType): Any = (a, dataType) match {
     // TODO: Why does this not need to flatMap stuff? Does it not support nesting?
     // TODO: What about Option and Product?
-    case (s: Seq[_], arrayType: ArrayType) => s.map(convertToScala(_, arrayType.elementType))
-    case (m: Map[_, _], mapType: MapType) => m.map { case (k, v) =>
+    case (s: Seq[_], arrayType: ArrayType) =>
+      println("convertToScala: Seq")
+      s.map(convertToScala(_, arrayType.elementType))
+    case (m: Map[_, _], mapType: MapType) =>
+      println("convertToScala: Map")
+      m.map { case (k, v) =>
       convertToScala(k, mapType.keyType) -> convertToScala(v, mapType.valueType)
     }
-    case d: Decimal => d.toBigDecimal
-    case (udt: Row, udtType: UserDefinedType[_]) => udtType.deserialize(udt)
-    case (other, _) => other
+    case (d: Decimal, DecimalType) => d.toBigDecimal
+    case (udt: Row, udtType: UserDefinedType[_]) =>
+      println("convertToScala: udt")
+      udtType.deserialize(udt)
+    case (other, _) =>
+      println("convertToScala: other")
+      other
   }

   def convertRowToScala(r: Row, schema: StructType): Row = {
-    new GenericRow(r.toArray.map(convertToScala(_, schema)))
+    println(s"convertRowToScala called with schema: $schema")
+    new GenericRow(
+      r.zip(schema.fields.map(_.dataType))
+        .map(r_dt => convertToScala(r_dt._1, r_dt._2)).toArray)
   }

   /** Returns a Sequence of attributes for the given case class type. */
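Note that the convertRowToScala change also fixes a subtle bug: the old version passed the entire StructType to convertToScala for every column, while the new one zips each value with its own field's dataType. For the UDT case, convertToScala hands the stored Row back to the type's deserialize. Below is a minimal, self-contained sketch of that serialize/deserialize round trip, using a simplified stand-in trait (SimpleUDT, with Seq[Any] as the catalyst-side encoding) rather than the actual Catalyst UserDefinedType API:

    object UdtRoundTrip extends App {
      // Stand-in for UserDefinedType: converts between a user class and a
      // row-like Seq[Any] encoding.
      trait SimpleUDT[T] {
        def serialize(obj: T): Seq[Any]
        def deserialize(datum: Seq[Any]): T
      }

      case class DenseVector(values: Array[Double])

      object DenseVectorUDT extends SimpleUDT[DenseVector] {
        def serialize(v: DenseVector): Seq[Any] = v.values.toSeq
        def deserialize(datum: Seq[Any]): DenseVector =
          DenseVector(datum.map(_.asInstanceOf[Double]).toArray)
      }

      val original = DenseVector(Array(0.1, 1.0))
      val stored   = DenseVectorUDT.serialize(original)  // what the Row holds
      val restored = DenseVectorUDT.deserialize(stored)  // what convertToScala returns
      println(restored.values.mkString(","))             // 0.1,1.0
    }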
@@ -129,9 +140,6 @@
   def schemaFor(tpe: `Type`, udtRegistry: scala.collection.Map[Any, UserDefinedType[_]]): Schema = {
     println(s"schemaFor: $tpe")
     tpe match {
-      case t if udtRegistry.contains(t) =>
-        println(s" schemaFor T matched udtRegistry")
-        Schema(udtRegistry(t), nullable = true)
       case t if t <:< typeOf[Option[_]] =>
         val TypeRef(_, _, Seq(optType)) = t
         Schema(schemaFor(optType, udtRegistry).dataType, nullable = true)
@@ -178,6 +186,9 @@ object ScalaReflection {
       case t if t <:< definitions.ShortTpe => Schema(ShortType, nullable = false)
       case t if t <:< definitions.ByteTpe => Schema(ByteType, nullable = false)
       case t if t <:< definitions.BooleanTpe => Schema(BooleanType, nullable = false)
+      case t if udtRegistry.contains(t) =>
+        println(s" schemaFor T matched udtRegistry")
+        Schema(udtRegistry(t), nullable = true)
     }
   }
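The net effect of the two schemaFor hunks: the udtRegistry lookup moves from the top of the match to the bottom. Scala tries match cases in order, so the built-in mappings (Option, Seq, Map, primitives, and the rest) now take precedence, and the registry is only consulted as a fallback. A toy demonstration of that ordering rule, with plain values standing in for the reflected types:

    object CaseOrder extends App {
      val registry = Set[Any](1, "one")  // stand-in for udtRegistry.contains(t)

      def schemaFor(x: Any): String = x match {
        case _: Int                    => "IntegerType"
        case _: String                 => "StringType"
        case t if registry.contains(t) => "UserDefinedType"  // fallback only
      }

      // With the registry case listed first (the old order), both calls
      // would have returned "UserDefinedType" instead.
      println(schemaFor(1))      // IntegerType
      println(schemaFor("one"))  // StringType
    }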

sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala

@@ -106,7 +106,9 @@ class SQLContext(@transient val sparkContext: SparkContext)
println(s"createSchemaRDD called")
val attributeSeq = ScalaReflection.attributesFor[A](udtRegistry)
val schema = StructType.fromAttributes(attributeSeq)
new SchemaRDD(this, LogicalRDD(attributeSeq, RDDConversions.productToRowRdd(rdd, schema))(self))
val rowRDD = RDDConversions.productToRowRdd(rdd, schema)
println("done with productToRowRdd")
new SchemaRDD(this, LogicalRDD(attributeSeq, rowRDD)(self))
}

/**
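createSchemaRDD reflects a schema from the case class type A, converts the RDD of Products into an RDD of Catalyst rows, and wraps the result in a SchemaRDD; the change simply hoists the conversion into a local val so its completion can be logged. A simplified, self-contained sketch of that reflect-then-convert flow (Field and toRowLike are illustrative stand-ins, not the Catalyst API, and a plain Seq replaces the RDD):

    object CreateSchemaRddSketch extends App {
      case class Person(name: String, age: Int)

      // Roughly what ScalaReflection.attributesFor derives for Person:
      case class Field(name: String, dataType: String)
      val attributes = Seq(Field("name", "StringType"), Field("age", "IntegerType"))

      // Roughly what RDDConversions.productToRowRdd does per element:
      def toRowLike(p: Product): Seq[Any] = p.productIterator.toSeq

      val people = Seq(Person("ada", 36), Person("alan", 41))
      val rowRDD = people.map(toRowLike)
      println(s"schema: $attributes")
      println(s"rows:   $rowRDD")  // List(List(ada, 36), List(alan, 41))
    }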
3 changes: 2 additions & 1 deletion sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala
@@ -125,7 +125,7 @@ class SchemaRDD(
    *
    * @group schema
    */
-  def schema: StructType = queryExecution.analyzed.schema
+  val schema: StructType = queryExecution.analyzed.schema

   // =======================================================================
   // Query DSL
@@ -148,6 +148,7 @@
       case (ne: NamedExpression, _) => ne
       case (e, i) => Alias(e, s"c$i")()
     }
+    assert(sqlContext != null)
     new SchemaRDD(sqlContext, Project(aliases, logicalPlan))
   }
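Two small changes in this file: schema becomes a val, so the analyzed schema is computed once at construction instead of on every access (which also forces the lazy queryExecution to run eagerly when the SchemaRDD is built), and select gains a sqlContext null check in the same debugging vein as the asserts being removed from SchemaRDDLike below. A quick demo of the def-versus-val evaluation difference:

    object DefVsVal extends App {
      var evaluations = 0
      def expensiveSchema(): String = { evaluations += 1; "StructType(...)" }

      class WithDef { def schema: String = expensiveSchema() }
      class WithVal { val schema: String = expensiveSchema() }

      val d = new WithDef
      d.schema; d.schema
      println(s"def recomputed: $evaluations times")  // 2: once per access

      evaluations = 0
      val v = new WithVal  // computed here, once
      v.schema; v.schema
      println(s"val computed: $evaluations time")     // 1: at construction
    }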

sql/core/src/main/scala/org/apache/spark/sql/SchemaRDDLike.scala
@@ -29,8 +29,6 @@ private[sql] trait SchemaRDDLike {
   @transient val sqlContext: SQLContext
   @transient val baseLogicalPlan: LogicalPlan

-  assert(sqlContext != null)
-
   private[sql] def baseSchemaRDD: SchemaRDD

/**
@@ -51,10 +49,7 @@
    */
   @transient
   @DeveloperApi
-  lazy val queryExecution = {
-    assert(sqlContext != null)
-    sqlContext.executePlan(baseLogicalPlan)
-  }
+  lazy val queryExecution = sqlContext.executePlan(baseLogicalPlan)

   @transient protected[spark] val logicalPlan: LogicalPlan = baseLogicalPlan match {
     // For various commands (like DDL) and queries with side effects, we force query optimization to
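The deleted asserts look like scaffolding left over from chasing a null sqlContext. sqlContext is an abstract val on this trait, and abstract vals are a classic initialization-order trap: code that runs in the trait body can observe them as null, while a lazy val such as queryExecution defers the access until after construction. A minimal demonstration of the pitfall (toy traits, not the Spark classes):

    object InitOrder extends App {
      trait Eager    { val ctx: String; val len: Int = ctx.length }      // reads ctx during trait init
      trait Deferred { val ctx: String; lazy val len: Int = ctx.length } // reads ctx on first use

      try new Eager { val ctx = "spark" }
      catch { case _: NullPointerException => println("eager init saw ctx == null") }

      val ok = new Deferred { val ctx = "spark" }
      println(ok.len)  // 5: ctx is assigned by the time len is first accessed
    }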
sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala
@@ -82,15 +82,15 @@ class UserDefinedTypeSuite extends QueryTest {
println("Done converting to SchemaRDD")

println("testing labels")
val labels: RDD[Double] = pointsRDD.select('label).map { case Row(v: Double) => v}
val labels: RDD[Double] = pointsRDD.select('label).map { case Row(v: Double) => v }
val labelsArrays: Array[Double] = labels.collect()
assert(labelsArrays.size === 2)
assert(labelsArrays.contains(1.0))
assert(labelsArrays.contains(0.0))

println("testing features")
val features: RDD[DenseVector] =
pointsRDD.select('features).map { case Row(v: DenseVector) => v}
pointsRDD.select('features).map { case Row(v: DenseVector) => v }
val featuresArrays: Array[DenseVector] = features.collect()
assert(featuresArrays.size === 2)
assert(featuresArrays.contains(new DenseVector(Array(0.1, 1.0))))
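The test changes are whitespace only ("v}" becomes "v }"), apparently to match the surrounding closure style; the interesting part is the pattern the test exercises, pulling typed values back out of each Row after the UDT round trip. A stripped-down sketch of that extraction pattern, with plain Seq[Any] values standing in for Rows and a Seq for the RDD:

    object ExtractionPattern extends App {
      // Stand-in rows: each "Row" is just a Seq[Any] with one column.
      val labelRows: Seq[Seq[Any]] = Seq(Seq(1.0), Seq(0.0))

      // Same shape as pointsRDD.select('label).map { case Row(v: Double) => v }
      val labels: Seq[Double] = labelRows.map { case Seq(v: Double) => v }

      assert(labels.size == 2)
      assert(labels.contains(1.0) && labels.contains(0.0))
      println(labels)  // List(1.0, 0.0)
    }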
