diff --git a/sql/catalyst/pom.xml b/sql/catalyst/pom.xml
index 61d6fc63554bb..ff36dcbdf875b 100644
--- a/sql/catalyst/pom.xml
+++ b/sql/catalyst/pom.xml
@@ -103,6 +103,22 @@
+      <plugin>
+        <groupId>com.thoughtworks.paranamer</groupId>
+        <artifactId>paranamer-maven-plugin</artifactId>
+        <executions>
+          <execution>
+            <id>run</id>
+            <configuration>
+              <sourceDirectory>${project.build.sourceDirectory}</sourceDirectory>
+              <outputDirectory>${project.build.outputDirectory}</outputDirectory>
+            </configuration>
+            <goals>
+              <goal>generate</goal>
+            </goals>
+          </execution>
+        </executions>
+      </plugin>
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
index c8ee87e8819f2..022aa69d28773 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
@@ -21,6 +21,8 @@ import java.beans.{PropertyDescriptor, Introspector}
 import java.lang.{Iterable => JIterable}
 import java.util.{Iterator => JIterator, Map => JMap, List => JList}
 
+import com.thoughtworks.paranamer.{BytecodeReadingParanamer, CachingParanamer, Paranamer}
+
 import scala.language.existentials
 
 import com.google.common.reflect.TypeToken
@@ -405,4 +407,11 @@ object JavaTypeInference {
       }
     }
   }
+
+  private val paranamer: Paranamer = new BytecodeReadingParanamer()
+
+  def getConstructorParaNames(cls: Class[_]): Seq[String] = {
+    val ctr = cls.getConstructors.maxBy(_.getParameterCount)
+    paranamer.lookupParameterNames(ctr)
+  }
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
index 26b6aca79971e..eefd9c7482553 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
@@ -262,6 +262,10 @@ case class AttributeReference(
     }
   }
 
+  override protected final def otherCopyArgs: Seq[AnyRef] = {
+    exprId :: qualifiers :: Nil
+  }
+
   override def toString: String = s"$name#${exprId.id}$typeSuffix"
 
   // Since the expression id is not in the first constructor it is missing from the default
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
index d838d845d20fd..a48f8df8a2e93 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
@@ -17,7 +17,12 @@
 
 package org.apache.spark.sql.catalyst.trees
 
+import org.apache.spark.sql.catalyst.JavaTypeInference
+
 import scala.collection.Map
+import org.json4s.JsonAST._
+import org.json4s.JsonDSL._
+import org.json4s.jackson.JsonMethods._
 
 import org.apache.spark.sql.catalyst.errors._
 import org.apache.spark.sql.types.{StructType, DataType}
@@ -463,4 +468,40 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
     }
     s"$nodeName(${args.mkString(",")})"
   }
+
+  private[sql] def jsonValue: JValue = {
+    val fieldNames = JavaTypeInference.getConstructorParaNames(getClass)
+    val fieldValues = productIterator.toSeq ++ otherCopyArgs
+    assert(fieldNames.length == fieldValues.length, s"${getClass.getSimpleName} fields: " +
+      fieldNames.mkString(", ") + s", values: " + fieldValues.map(_.toString).mkString(", "))
+
+    val jsonFields: Seq[JField] = fieldNames.zip(fieldValues).flatMap {
+      case (name, value: TreeNode[_]) if containsChild(value) => None
+      case (name, value: TreeNode[_]) => Some(name -> value.allJsonValues)
+      case (name, value: Seq[BaseType]) if value.toSet.subsetOf(children.toSet) => None
+      case (name, value: Seq[_]) =>
+        if (value.length > 0 && value(0).isInstanceOf[TreeNode[_]]) {
+          Some(name -> JArray(value.map(_.asInstanceOf[TreeNode[_]].allJsonValues).toList))
+        } else {
+          Some(name -> JArray(value.map(v => JString(v.toString)).toList))
+        }
+      case (name, value: Set[_]) =>
+        Some(name -> JArray(value.map(v => JString(v.toString)).toList))
+      case (name, value) => Some(name -> JString(value.toString))
+    }
+
+    JObject(("node-name" -> JString(nodeName)) :: jsonFields.toList)
+  }
+
+  private def allJsonValues: JValue = {
+    val jsonValues = scala.collection.mutable.ArrayBuffer.empty[JValue]
+    def collectJsonValue(node: BaseType): Unit = {
+      jsonValues += node.jsonValue.merge(JObject("num-children" -> JInt(node.children.length)))
+      node.children.foreach(collectJsonValue)
+    }
+    collectJsonValue(this)
+    jsonValues
+  }
+
+  def toJSON: String = pretty(render(allJsonValues))
 }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/TreeNodeJsonFormatSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/TreeNodeJsonFormatSuite.scala
new file mode 100644
index 0000000000000..632290d6737ee
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/TreeNodeJsonFormatSuite.scala
@@ -0,0 +1,30 @@
+/*
+* Licensed to the Apache Software Foundation (ASF) under one or more
+* contributor license agreements. See the NOTICE file distributed with
+* this work for additional information regarding copyright ownership.
+* The ASF licenses this file to You under the Apache License, Version 2.0
+* (the "License"); you may not use this file except in compliance with
+* the License. You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*/
+
+package org.apache.spark.sql.catalyst.util
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.catalyst.dsl.plans._
+import org.apache.spark.sql.catalyst.dsl.expressions._
+
+class TreeNodeJsonFormatSuite extends SparkFunSuite {
+  import org.apache.spark.sql.catalyst.analysis.TestRelations._
+
+  test("logical plan json format") {
+    println(testRelation.select(('a + 1).as("i"), ('a * 2).as("j")).sortBy('i.asc).toJSON)
+  }
+}
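A note on the Paranamer piece of this patch: `BytecodeReadingParanamer` recovers constructor parameter names from the debug information that scalac and javac emit by default, which is what lets `getConstructorParaNames` map a `TreeNode`'s constructor slots back to field names for `jsonValue`. A minimal standalone sketch of the same lookup, assuming the class was compiled with debug info (`Point` and `ParanamerDemo` are hypothetical names, not part of the patch):

```scala
import com.thoughtworks.paranamer.BytecodeReadingParanamer

// Hypothetical demo class; any class compiled with debug info works.
case class Point(x: Int, y: Int)

object ParanamerDemo {
  def main(args: Array[String]): Unit = {
    val paranamer = new BytecodeReadingParanamer
    // Pick the constructor with the most parameters, as getConstructorParaNames does.
    val ctor = classOf[Point].getConstructors.maxBy(_.getParameterCount)
    // Prints "x, y" -- the names as written in source, not arg0/arg1.
    println(paranamer.lookupParameterNames(ctor).mkString(", "))
  }
}
```

The `paranamer-maven-plugin` step in the pom presumably exists so the names survive even when classes are compiled without debug info; `CachingParanamer` is imported but unused here, likely anticipating a cached wrapper around the bytecode reader.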
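The `otherCopyArgs` override in `AttributeReference` matters because `exprId` and `qualifiers` live in a second, curried parameter list, and `productIterator` only walks the first parameter list of a case class; `jsonValue` therefore appends `otherCopyArgs` to see every field. A toy illustration of that Scala behavior (`Attr` is a hypothetical stand-in, not the real class):

```scala
object CurriedProductDemo {
  // Hypothetical stand-in for AttributeReference's curried constructor shape.
  case class Attr(name: String)(val id: Long)

  def main(args: Array[String]): Unit = {
    val a = Attr("a")(42L)
    // Only the first parameter list is part of the Product view:
    println(a.productIterator.toList) // List(a)
    // The curried id is invisible here, which is why jsonValue uses
    // productIterator.toSeq ++ otherCopyArgs to recover all fields.
  }
}
```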
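Finally, the shape of `toJSON`'s output: `allJsonValues` flattens the tree in pre-order into a single JSON array, tagging each node with `num-children`. A self-contained sketch of that encoding on a toy tree (`Node`, `flatten`, and `JsonTreeDemo` are illustrative, not the patch's API):

```scala
import org.json4s.JsonAST._
import org.json4s.jackson.JsonMethods._

// Toy tree mirroring the pre-order flattening that allJsonValues performs.
case class Node(name: String, children: Seq[Node] = Nil)

object JsonTreeDemo {
  def flatten(root: Node): JValue = {
    val buffer = scala.collection.mutable.ArrayBuffer.empty[JValue]
    def collect(n: Node): Unit = {
      // Each node carries num-children so the flat list can be re-treed.
      buffer += JObject(
        "node-name" -> JString(n.name),
        "num-children" -> JInt(n.children.length))
      n.children.foreach(collect)
    }
    collect(root)
    JArray(buffer.toList)
  }

  def main(args: Array[String]): Unit = {
    val plan = Node("Sort", Seq(Node("Project", Seq(Node("LocalRelation")))))
    println(pretty(render(flatten(plan))))
    // Pre-order: Sort (1 child), Project (1 child), LocalRelation (0 children).
  }
}
```

The flat array plus `num-children` avoids deeply nested JSON and presumably makes reconstruction straightforward: a reader consumes nodes in order and attaches the next `num-children` entries as children.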