From b1c6187734fba702a55a92598f6a6c381539dd38 Mon Sep 17 00:00:00 2001
From: Wenchen Fan
Date: Thu, 17 Dec 2015 02:46:23 +0800
Subject: [PATCH] fix tests

---
 .../expressions/aggregate/interfaces.scala    |  1 -
 .../spark/sql/catalyst/plans/QueryPlan.scala  |  2 +
 .../catalyst/util/TreeNodeJsonFormatter.scala | 45 +++++++++++++++----
 .../spark/sql/execution/ExistingRDD.scala     |  4 +-
 .../datasources/LogicalRelation.scala         |  4 +-
 .../org/apache/spark/sql/QueryTest.scala      | 29 +++++++++++-
 6 files changed, 69 insertions(+), 16 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
index 3b441de34a49f..e6fd726e74716 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
@@ -17,7 +17,6 @@
 
 package org.apache.spark.sql.catalyst.expressions.aggregate
 
-import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, CodeGenContext}
 import org.apache.spark.sql.catalyst.InternalRow
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
index b9db7838db08a..d2626440b9434 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
@@ -88,6 +88,7 @@ abstract class QueryPlan[PlanType <: TreeNode[PlanType]] extends TreeNode[PlanTy
       case d: DataType => d // Avoid unpacking Structs
       case seq: Traversable[_] => seq.map(recursiveTransform)
       case other: AnyRef => other
+      case null => null
     }
 
     val newArgs = productIterator.map(recursiveTransform).toArray
@@ -120,6 +121,7 @@ abstract class QueryPlan[PlanType <: TreeNode[PlanType]] extends TreeNode[PlanTy
       case d: DataType => d // Avoid unpacking Structs
       case seq: Traversable[_] => seq.map(recursiveTransform)
       case other: AnyRef => other
+      case null => null
     }
 
     val newArgs = productIterator.map(recursiveTransform).toArray
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TreeNodeJsonFormatter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TreeNodeJsonFormatter.scala
index a1390910a939f..dc1cff1ad1a7b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TreeNodeJsonFormatter.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TreeNodeJsonFormatter.scala
@@ -24,11 +24,14 @@ import org.json4s.JsonDSL._
 import org.json4s.jackson.JsonMethods._
 
 import org.apache.spark.util.Utils
+import org.apache.spark.unsafe.types.UTF8String
 import org.apache.spark.sql.types._
 import org.apache.spark.sql.catalyst.trees.TreeNode
 import org.apache.spark.sql.catalyst.ScalaReflection._
 import org.apache.spark.sql.catalyst.ScalaReflectionLock
 import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.aggregate._
+import org.apache.spark.sql.catalyst.plans._
 
 object TreeNodeJsonFormatter {
 
@@ -78,14 +81,24 @@ object TreeNodeJsonFormatter {
     case f: Float => JDouble(f)
     case d: Double => JDouble(d)
     case s: String => JString(s)
+    case s: UTF8String => JString(s.toString)
     case dt: DataType => dt.jsonValue
     case m: Metadata => m.jsonValue
     case e: ExprId => ("id" -> e.id) ~ ("jvmId" -> e.jvmId.toString)
     case s: SortDirection => s.toString
+    case j: JoinType => j.toString
+    case a: AggregateMode => a.toString
     case n: TreeNode[_] => jsonValue(n)
     case o: Option[_] => o.map(parseToJson).getOrElse(JNull)
     case t: Seq[_] => JArray(t.map(parseToJson).toList)
-    case _ => throw new RuntimeException(s"Do not support type ${obj.getClass}.")
+    case _ =>
+      val clsName = obj.getClass.getName
+      if (clsName.contains("RDD") || clsName.contains("SQLContext") ||
+          clsName.contains("Relation")) {
+        JNull
+      } else {
+        throw new RuntimeException(s"Do not support type ${obj.getClass}.")
+      }
   }
 
   def fromJSON(json: String): TreeNode[_] = {
@@ -157,21 +170,34 @@ object TreeNodeJsonFormatter {
         value.asInstanceOf[JDouble].num: java.lang.Double
       case t if t <:< localTypeOf[java.lang.String] =>
         value.asInstanceOf[JString].s
+      case t if t <:< localTypeOf[UTF8String] =>
+        UTF8String.fromString(value.asInstanceOf[JString].s)
       case t if t <:< localTypeOf[DataType] => DataType.parseDataType(value)
       case t if t <:< localTypeOf[Metadata] => Metadata.fromJObject(value.asInstanceOf[JObject])
       case t if t <:< localTypeOf[ExprId] =>
         val JInt(id) = value \ "id"
         val JString(jvmId) = value \ "jvmId"
         ExprId(id.toInt, UUID.fromString(jvmId))
-      case t if t <:< localTypeOf[SortDirection] =>
-        val JString(direction) = value
-        if (direction == Ascending.toString) {
-          Ascending
-        } else if (direction == Descending.toString) {
-          Descending
-        } else {
-          throw new RuntimeException(s"$direction is not a valid SortDirection string.")
+      case t if t <:< localTypeOf[SortDirection] => value.asInstanceOf[JString].s match {
+        case "Ascending" => Ascending
+        case "Descending" => Descending
+        case other => throw new RuntimeException(s"$other is not a valid SortDirection string.")
       }
+      case t if t <:< localTypeOf[JoinType] => value.asInstanceOf[JString].s match {
+        case "Inner" => Inner
+        case "LeftOuter" => LeftOuter
+        case "RightOuter" => RightOuter
+        case "FullOuter" => FullOuter
+        case "LeftSemi" => LeftSemi
+        case other => throw new RuntimeException(s"$other is not a valid JoinType string.")
+      }
+      case t if t <:< localTypeOf[AggregateMode] => value.asInstanceOf[JString].s match {
+        case "Partial" => Partial
+        case "PartialMerge" => PartialMerge
+        case "Final" => Final
+        case "Complete" => Complete
+        case other => throw new RuntimeException(s"$other is not a valid AggregateMode string.")
+      }
       case t if t <:< localTypeOf[TreeNode[_]] => value match {
         case JInt(i) => children(i.toInt)
         case arr: JArray => reconstruct(arr)
@@ -193,6 +219,7 @@ object TreeNodeJsonFormatter {
       case JInt(i) => i
       case JDouble(d) => d: java.lang.Double
      case JString(s) => s
+      case JNull => null
       case _ => throw new RuntimeException(s"Do not support type $expectedType.")
     }
   }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala
index 18ee91bfc44a0..3b0c468d219f3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala
@@ -74,9 +74,7 @@ private[sql] case class LogicalRDD(
 
   override def children: Seq[LogicalPlan] = Nil
 
-  override protected[sql] final def otherCopyArgs: Seq[AnyRef] = {
-    sqlContext :: Nil
-  }
+  override protected[sql] final def otherCopyArgs: Seq[AnyRef] = sqlContext :: Nil
 
   override def newInstance(): LogicalRDD.this.type =
     LogicalRDD(output.map(_.newInstance()), rdd)(sqlContext).asInstanceOf[this.type]
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/LogicalRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/LogicalRelation.scala
index 219dae88e515d..1c165ba243d9a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/LogicalRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/LogicalRelation.scala
@@ -33,7 +33,7 @@ case class LogicalRelation(
     expectedOutputAttributes: Option[Seq[Attribute]] = None)
   extends LeafNode with MultiInstanceRelation {
 
-  override val output: Seq[AttributeReference] = {
+  override lazy val output: Seq[AttributeReference] = {
     val attrs = relation.schema.toAttributes
     expectedOutputAttributes.map { expectedAttrs =>
       assert(expectedAttrs.length == attrs.length)
@@ -72,7 +72,7 @@ case class LogicalRelation(
   )
 
   /** Used to lookup original attribute capitalization */
-  val attributeMap: AttributeMap[AttributeReference] = AttributeMap(output.map(o => (o, o)))
+  lazy val attributeMap: AttributeMap[AttributeReference] = AttributeMap(output.map(o => (o, o)))
 
   def newInstance(): this.type = LogicalRelation(relation).asInstanceOf[this.type]
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
index bc22fb8b7bdb4..7ad8f6f7c54f4 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
@@ -23,8 +23,10 @@ import scala.collection.JavaConverters._
 
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.util._
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.execution.columnar.InMemoryRelation
-import org.apache.spark.sql.execution.Queryable
+import org.apache.spark.sql.execution.datasources.LogicalRelation
+import org.apache.spark.sql.execution.{LogicalRDD, Queryable}
 
 abstract class QueryTest extends PlanTest {
 
@@ -123,6 +125,31 @@ abstract class QueryTest extends PlanTest {
           |""".stripMargin)
     }
 
+    import TreeNodeJsonFormatter._
+    val logicalPlan = analyzedDF.queryExecution.analyzed
+    var logicalRDDs = logicalPlan.collect { case l: LogicalRDD => l }
+    var logicalRelations = logicalPlan.collect { case l: LogicalRelation => l }
+    try {
+      val jsonBackPlan = fromJSON(toJSON(logicalPlan)).asInstanceOf[LogicalPlan]
+      val normalized = jsonBackPlan transformDown {
+        case l: LogicalRDD =>
+          val replace = logicalRDDs.head
+          logicalRDDs = logicalRDDs.drop(1)
+          replace
+        case l: LogicalRelation =>
+          val replace = logicalRelations.head
+          logicalRelations = logicalRelations.drop(1)
+          replace
+      }
+      assert(logicalRDDs.isEmpty)
+      assert(logicalRelations.isEmpty)
+      comparePlans(logicalPlan, normalized)
+    } catch {
+      case e =>
+        println(logicalPlan.treeString)
+        throw e
+    }
+
     QueryTest.checkAnswer(analyzedDF, expectedAnswer) match {
       case Some(errorMessage) => fail(errorMessage)
       case None =>
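
Background on the pattern used above (not part of the patch): the new JoinType and AggregateMode
handling in TreeNodeJsonFormatter serializes each case object via toString and rebuilds it with an
exhaustive string match, and the QueryTest change verifies the same idea end to end by round-tripping
a plan through toJSON/fromJSON and comparing. A minimal, self-contained Scala sketch of that
round-trip, using a hypothetical MyJoinType ADT in place of Spark's classes:

object JsonEnumRoundTripSketch {
  // Hypothetical stand-in for Spark's sealed JoinType hierarchy.
  sealed trait MyJoinType
  case object Inner extends MyJoinType
  case object LeftOuter extends MyJoinType
  case object RightOuter extends MyJoinType

  // Serialization: the case object's name is its JSON string value.
  def toJsonString(j: MyJoinType): String = j.toString

  // Deserialization: an exhaustive string match rebuilds the singleton and
  // fails loudly on anything it does not recognize, like the patch does.
  def fromJsonString(s: String): MyJoinType = s match {
    case "Inner" => Inner
    case "LeftOuter" => LeftOuter
    case "RightOuter" => RightOuter
    case other => throw new RuntimeException(s"$other is not a valid MyJoinType string.")
  }

  def main(args: Array[String]): Unit = {
    // Round trip, mirroring what the new QueryTest check does for whole plans.
    assert(fromJsonString(toJsonString(LeftOuter)) == LeftOuter)
  }
}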