diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala index 23f05ce84667c..9e1660df06cc6 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala @@ -122,16 +122,6 @@ class PlanParserSuite extends PlanTest { table("a").union(table("b")).as("c").select(star())) } - test("transform query spec") { - val p = ScriptTransformation(Seq('a, 'b), "func", Seq.empty, table("e"), null) - assertEqual("select transform(a, b) using 'func' from e where f < 10", - p.copy(child = p.child.where('f < 10), output = Seq('key.string, 'value.string))) - assertEqual("map a, b using 'func' as c, d from e", - p.copy(output = Seq('c.string, 'd.string))) - assertEqual("reduce a, b using 'func' as (c: int, d decimal(10, 0)) from e", - p.copy(output = Seq('c.int, 'd.decimal(10, 0)))) - } - test("multi select query") { assertEqual( "from a select * select * where s < 10", diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveQlSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveQlSuite.scala index 75108c6d47ea0..371a937641ea0 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveQlSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveQlSuite.scala @@ -18,16 +18,17 @@ package org.apache.spark.sql.hive import org.apache.hadoop.hive.serde.serdeConstants -import org.scalatest.BeforeAndAfterAll -import org.apache.spark.SparkFunSuite import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute import org.apache.spark.sql.catalyst.catalog.{CatalogColumn, CatalogTable, CatalogTableType} -import org.apache.spark.sql.catalyst.expressions.JsonTuple -import org.apache.spark.sql.catalyst.plans.logical.Generate +import org.apache.spark.sql.catalyst.expressions.{AttributeReference, JsonTuple} +import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.catalyst.plans.logical.{Generate, ScriptTransformation} import org.apache.spark.sql.hive.execution.HiveSqlParser +import org.apache.spark.sql.types.{DecimalType, IntegerType, StringType} -class HiveQlSuite extends SparkFunSuite with BeforeAndAfterAll { +class HiveQlSuite extends PlanTest { val parser = HiveSqlParser private def extractTableDesc(sql: String): (CatalogTable, Boolean) = { @@ -201,6 +202,31 @@ class HiveQlSuite extends SparkFunSuite with BeforeAndAfterAll { assert(plan.children.head.asInstanceOf[Generate].generator.isInstanceOf[JsonTuple]) } + test("transform query spec") { + val plan1 = parser.parsePlan("select transform(a, b) using 'func' as c, d from e") + val plan2 = parser.parsePlan("map a, b using 'func' as c, d from e") + val plan3 = parser.parsePlan("reduce a, b using 'func' as c, d from e") + comparePlans(plan1, plan2) + comparePlans(plan2, plan3) + + assert(plan1.isInstanceOf[ScriptTransformation]) + assert(plan1.asInstanceOf[ScriptTransformation].input + == Seq(UnresolvedAttribute("a"), UnresolvedAttribute("b"))) + assert(plan1.asInstanceOf[ScriptTransformation].script + == "func") + assert(plan1.asInstanceOf[ScriptTransformation].output.map(_.name) + == Seq("c", "d")) + assert(plan1.asInstanceOf[ScriptTransformation].output.map(_.dataType) + == Seq(StringType, StringType)) + + val plan4 = parser.parsePlan("reduce a, b using 'func' as (c: int, d decimal(10, 0)) from e") + assert(plan4.isInstanceOf[ScriptTransformation]) + assert(plan1.asInstanceOf[ScriptTransformation].output.map(_.name) + == Seq("c", "d")) + assert(plan4.asInstanceOf[ScriptTransformation].output.map(_.dataType) + == Seq(IntegerType, DecimalType(10, 0))) + } + test("use backticks in output of Script Transform") { val plan = parser.parsePlan( """SELECT `t`.`thing1`