diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 index 27b01e0bed1c4..96c170be3d6ae 100644 --- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 +++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 @@ -468,15 +468,15 @@ booleanExpression // https://github.com/antlr/antlr4/issues/780 // https://github.com/antlr/antlr4/issues/781 predicated - : valueExpression predicate[$valueExpression.ctx]? + : valueExpression predicate? ; -predicate[ParserRuleContext value] - : NOT? BETWEEN lower=valueExpression AND upper=valueExpression #between - | NOT? IN '(' expression (',' expression)* ')' #inList - | NOT? IN '(' query ')' #inSubquery - | NOT? like=(RLIKE | LIKE) pattern=valueExpression #like - | IS NOT? NULL #nullPredicate +predicate + : NOT? kind=BETWEEN lower=valueExpression AND upper=valueExpression + | NOT? kind=IN '(' expression (',' expression)* ')' + | NOT? kind=IN '(' query ')' + | NOT? kind=(RLIKE | LIKE) pattern=valueExpression + | IS NOT? kind=NULL ; valueExpression diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 61ea3e401057f..14c90918e6011 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -22,7 +22,7 @@ import scala.collection.JavaConverters._ import scala.collection.mutable.ArrayBuffer import org.antlr.v4.runtime.{ParserRuleContext, Token} -import org.antlr.v4.runtime.tree.{ParseTree, TerminalNode} +import org.antlr.v4.runtime.tree.{ParseTree, RuleNode, TerminalNode} import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.{InternalRow, TableIdentifier} @@ -46,6 +46,19 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { ctx.accept(this).asInstanceOf[T] } + /** + * Override the default behavior for all visit methods. This will only return a non-null result + * when the context has only one child. This is done because there is no generic method to + * combine the results of the context children. In all other cases null is returned. + */ + override def visitChildren(node: RuleNode): AnyRef = { + if (node.getChildCount == 1) { + node.getChild(0).accept(this) + } else { + null + } + } + override def visitSingleStatement(ctx: SingleStatementContext): LogicalPlan = withOrigin(ctx) { visit(ctx.statement).asInstanceOf[LogicalPlan] } @@ -351,7 +364,8 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { string(script), attributes, withFilter, - withScriptIOSchema(inRowFormat, recordWriter, outRowFormat, recordReader, schemaLess)) + withScriptIOSchema( + ctx, inRowFormat, recordWriter, outRowFormat, recordReader, schemaLess)) case SqlBaseParser.SELECT => // Regular select @@ -398,11 +412,14 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { * Create a (Hive based) [[ScriptInputOutputSchema]]. */ protected def withScriptIOSchema( + ctx: QuerySpecificationContext, inRowFormat: RowFormatContext, recordWriter: Token, outRowFormat: RowFormatContext, recordReader: Token, - schemaLess: Boolean): ScriptInputOutputSchema = null + schemaLess: Boolean): ScriptInputOutputSchema = { + throw new ParseException("Script Transform is not supported", ctx) + } /** * Create a logical plan for a given 'FROM' clause. Note that we support multiple (comma @@ -778,17 +795,6 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { trees.asScala.map(expression) } - /** - * Invert a boolean expression if it has a valid NOT clause. - */ - private def invertIfNotDefined(expression: Expression, not: TerminalNode): Expression = { - if (not != null) { - Not(expression) - } else { - expression - } - } - /** * Create a star (i.e. all) expression; this selects all elements (in the specified object). * Both un-targeted (global) and targeted aliases are supported. @@ -909,57 +915,55 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { } /** - * Create a BETWEEN expression. This tests if an expression lies with in the bounds set by two - * other expressions. The inverse can also be created. - */ - override def visitBetween(ctx: BetweenContext): Expression = withOrigin(ctx) { - val value = expression(ctx.value) - val between = And( - GreaterThanOrEqual(value, expression(ctx.lower)), - LessThanOrEqual(value, expression(ctx.upper))) - invertIfNotDefined(between, ctx.NOT) - } - - /** - * Create an IN expression. This tests if the value of the left hand side expression is - * contained by the sequence of expressions on the right hand side. + * Create a predicated expression. A predicated expression is a normal expression with a + * predicate attached to it, for example: + * {{{ + * a + 1 IS NULL + * }}} */ - override def visitInList(ctx: InListContext): Expression = withOrigin(ctx) { - val in = In(expression(ctx.value), ctx.expression().asScala.map(expression)) - invertIfNotDefined(in, ctx.NOT) + override def visitPredicated(ctx: PredicatedContext): Expression = withOrigin(ctx) { + val e = expression(ctx.valueExpression) + if (ctx.predicate != null) { + withPredicate(e, ctx.predicate) + } else { + e + } } /** - * Create an IN expression, where the the right hand side is a query. This is unsupported. + * Add a predicate to the given expression. Supported expressions are: + * - (NOT) BETWEEN + * - (NOT) IN + * - (NOT) LIKE + * - (NOT) RLIKE + * - IS (NOT) NULL. */ - override def visitInSubquery(ctx: InSubqueryContext): Expression = { - throw new ParseException("IN with a Sub-query is currently not supported.", ctx) - } + private def withPredicate(e: Expression, ctx: PredicateContext): Expression = withOrigin(ctx) { + // Invert a predicate if it has a valid NOT clause. + def invertIfNotDefined(e: Expression): Expression = ctx.NOT match { + case null => e + case not => Not(e) + } - /** - * Create a (R)LIKE/REGEXP expression. - */ - override def visitLike(ctx: LikeContext): Expression = { - val left = expression(ctx.value) - val right = expression(ctx.pattern) - val like = ctx.like.getType match { + // Create the predicate. + ctx.kind.getType match { + case SqlBaseParser.BETWEEN => + // BETWEEN is translated to lower <= e && e <= upper + invertIfNotDefined(And( + GreaterThanOrEqual(e, expression(ctx.lower)), + LessThanOrEqual(e, expression(ctx.upper)))) + case SqlBaseParser.IN if ctx.query != null => + throw new ParseException("IN with a Sub-query is currently not supported.", ctx) + case SqlBaseParser.IN => + invertIfNotDefined(In(e, ctx.expression.asScala.map(expression))) case SqlBaseParser.LIKE => - Like(left, right) + invertIfNotDefined(Like(e, expression(ctx.pattern))) case SqlBaseParser.RLIKE => - RLike(left, right) - } - invertIfNotDefined(like, ctx.NOT) - } - - /** - * Create an IS (NOT) NULL expression. - */ - override def visitNullPredicate(ctx: NullPredicateContext): Expression = withOrigin(ctx) { - val value = expression(ctx.value) - if (ctx.NOT != null) { - IsNotNull(value) - } else { - IsNull(value) + invertIfNotDefined(RLike(e, expression(ctx.pattern))) + case SqlBaseParser.NULL if ctx.NOT != null => + IsNotNull(e) + case SqlBaseParser.NULL => + IsNull(e) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala index 23f05ce84667c..9e1660df06cc6 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala @@ -122,16 +122,6 @@ class PlanParserSuite extends PlanTest { table("a").union(table("b")).as("c").select(star())) } - test("transform query spec") { - val p = ScriptTransformation(Seq('a, 'b), "func", Seq.empty, table("e"), null) - assertEqual("select transform(a, b) using 'func' from e where f < 10", - p.copy(child = p.child.where('f < 10), output = Seq('key.string, 'value.string))) - assertEqual("map a, b using 'func' as c, d from e", - p.copy(output = Seq('c.string, 'd.string))) - assertEqual("reduce a, b using 'func' as (c: int, d decimal(10, 0)) from e", - p.copy(output = Seq('c.int, 'd.decimal(10, 0)))) - } - test("multi select query") { assertEqual( "from a select * select * where s < 10", diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala index 8b2a5979e2c58..47e295a7e78bc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.execution.command import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.catalyst.parser.ParseException import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.execution.SparkSqlParser import org.apache.spark.sql.execution.datasources.BucketSpec @@ -781,4 +782,26 @@ class DDLCommandSuite extends PlanTest { comparePlans(parsed1, expected1) comparePlans(parsed2, expected2) } + + test("commands only available in HiveContext") { + intercept[ParseException] { + parser.parsePlan("DROP TABLE D1.T1") + } + intercept[ParseException] { + parser.parsePlan("CREATE VIEW testView AS SELECT id FROM tab") + } + intercept[ParseException] { + parser.parsePlan("ALTER VIEW testView AS SELECT id FROM tab") + } + intercept[ParseException] { + parser.parsePlan( + """ + |CREATE EXTERNAL TABLE parquet_tab2(c1 INT, c2 STRING) + |TBLPROPERTIES('prop1Key '= "prop1Val", ' `prop2Key` '= "prop2Val") + """.stripMargin) + } + intercept[ParseException] { + parser.parsePlan("SELECT TRANSFORM (key, value) USING 'cat' AS (tKey, tValue) FROM testData") + } + } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveSqlParser.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveSqlParser.scala index 12e4f49756c35..55e69f99a41fd 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveSqlParser.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveSqlParser.scala @@ -133,6 +133,18 @@ class HiveSqlAstBuilder extends SparkSqlAstBuilder { } } + /** + * Create a [[CatalogStorageFormat]]. This is part of the [[CreateTableAsSelect]] command. + */ + override def visitCreateFileFormat( + ctx: CreateFileFormatContext): CatalogStorageFormat = withOrigin(ctx) { + if (ctx.storageHandler == null) { + typedVisit[CatalogStorageFormat](ctx.fileFormat) + } else { + visitStorageHandler(ctx.storageHandler) + } + } + /** * Create a [[CreateTableAsSelect]] command. */ @@ -282,6 +294,7 @@ class HiveSqlAstBuilder extends SparkSqlAstBuilder { * Create a [[HiveScriptIOSchema]]. */ override protected def withScriptIOSchema( + ctx: QuerySpecificationContext, inRowFormat: RowFormatContext, recordWriter: Token, outRowFormat: RowFormatContext, @@ -391,7 +404,8 @@ class HiveSqlAstBuilder extends SparkSqlAstBuilder { /** * Storage Handlers are currently not supported in the statements we support (CTAS). */ - override def visitStorageHandler(ctx: StorageHandlerContext): AnyRef = withOrigin(ctx) { + override def visitStorageHandler( + ctx: StorageHandlerContext): CatalogStorageFormat = withOrigin(ctx) { throw new ParseException("Storage Handlers are currently unsupported.", ctx) } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveQlSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveQlSuite.scala index 75108c6d47ea0..a8a0d6b8de364 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveQlSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveQlSuite.scala @@ -18,16 +18,19 @@ package org.apache.spark.sql.hive import org.apache.hadoop.hive.serde.serdeConstants -import org.scalatest.BeforeAndAfterAll -import org.apache.spark.SparkFunSuite import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute import org.apache.spark.sql.catalyst.catalog.{CatalogColumn, CatalogTable, CatalogTableType} +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans +import org.apache.spark.sql.catalyst.dsl.plans.DslLogicalPlan import org.apache.spark.sql.catalyst.expressions.JsonTuple -import org.apache.spark.sql.catalyst.plans.logical.Generate +import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.catalyst.plans.logical.{Generate, ScriptTransformation} import org.apache.spark.sql.hive.execution.HiveSqlParser -class HiveQlSuite extends SparkFunSuite with BeforeAndAfterAll { +class HiveQlSuite extends PlanTest { val parser = HiveSqlParser private def extractTableDesc(sql: String): (CatalogTable, Boolean) = { @@ -201,6 +204,26 @@ class HiveQlSuite extends SparkFunSuite with BeforeAndAfterAll { assert(plan.children.head.asInstanceOf[Generate].generator.isInstanceOf[JsonTuple]) } + test("transform query spec") { + val plan1 = parser.parsePlan("select transform(a, b) using 'func' from e where f < 10") + .asInstanceOf[ScriptTransformation].copy(ioschema = null) + val plan2 = parser.parsePlan("map a, b using 'func' as c, d from e") + .asInstanceOf[ScriptTransformation].copy(ioschema = null) + val plan3 = parser.parsePlan("reduce a, b using 'func' as (c: int, d decimal(10, 0)) from e") + .asInstanceOf[ScriptTransformation].copy(ioschema = null) + + val p = ScriptTransformation( + Seq(UnresolvedAttribute("a"), UnresolvedAttribute("b")), + "func", Seq.empty, plans.table("e"), null) + + comparePlans(plan1, + p.copy(child = p.child.where('f < 10), output = Seq('key.string, 'value.string))) + comparePlans(plan2, + p.copy(output = Seq('c.string, 'd.string))) + comparePlans(plan3, + p.copy(output = Seq('c.int, 'd.decimal(10, 0)))) + } + test("use backticks in output of Script Transform") { val plan = parser.parsePlan( """SELECT `t`.`thing1`