[SPARK-48344][SQL] Prepare SQL Scripting for addition of Execution Framework

### What changes were proposed in this pull request?
This PR is the initial refactoring of SQL Scripting to prepare it for the addition of the **Execution Framework**:
- Move all files to their proper directories/paths.
- Convert `SqlScriptingLogicalOperators` to `SqlScriptingLogicalPlans`.
- Remove `CompoundNestedStatementIteratorExec` because it is an unnecessary abstraction.
- Remove `parseScript` because it is no longer needed; parsing of scripts is now done in the `parsePlan` method (see the sketch below).
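
To illustrate the API consolidation, here is a minimal sketch of the new flow, assuming a `CatalystSqlParser`-style entry point and that `spark.sql.scripting.enabled` is set to `true` (the script text is illustrative):

```scala
import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
import org.apache.spark.sql.catalyst.plans.logical.{CompoundBody, LogicalPlan}

// parseScript() is gone: parsePlan() now accepts both a single statement
// and a compound script, because CompoundBody is itself a LogicalPlan.
// Assumes spark.sql.scripting.enabled=true; otherwise parsing a script throws.
val plan: LogicalPlan = CatalystSqlParser.parsePlan(
  """BEGIN
    |  SELECT 1;
    |  SELECT 2;
    |END""".stripMargin)

plan match {
  case body: CompoundBody =>
    println(s"parsed a script with ${body.collection.length} statements")
  case single =>
    println(s"parsed a single statement: ${single.nodeName}")
}
```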

### Why are the changes needed?
These changes are needed so that execution of SQL scripts can be implemented properly.

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
Existing tests.

### Was this patch authored or co-authored using generative AI tooling?
No.

Closes #48879 from miland-db/milan-dankovic_data/refactor-execution-1.

Authored-by: Milan Dankovic <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
miland-db authored and cloud-fan committed Nov 20, 2024
1 parent c149dcb commit 8791767
Showing 15 changed files with 431 additions and 240 deletions.
5 changes: 5 additions & 0 deletions common/utils/src/main/resources/error/error-conditions.json
@@ -5370,6 +5370,11 @@
"<variableName> is a VARIABLE and cannot be updated using the SET statement. Use SET VARIABLE <variableName> = ... instead."
]
},
"SQL_SCRIPTING" : {
"message" : [
"SQL Scripting is under development and not all features are supported. SQL Scripting enables users to write procedural SQL including control flow and error handling. To enable existing features set <sqlScriptingEnabled> to `true`."
]
},
"STATE_STORE_MULTIPLE_COLUMN_FAMILIES" : {
"message" : [
"Creating multiple column families with <stateStoreProvider> is not supported."
@@ -48,7 +48,7 @@ compoundOrSingleStatement
;

singleCompoundStatement
: beginEndCompoundBlock SEMICOLON? EOF
: BEGIN compoundBody END SEMICOLON? EOF
;

beginEndCompoundBlock
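For reference, a sketch of inputs the revised `singleCompoundStatement` rule (`BEGIN compoundBody END SEMICOLON? EOF`) accepts; the statements inside the body are illustrative:

```scala
// Both of these should parse as a singleCompoundStatement,
// since the trailing semicolon after END is optional.
val scriptWithSemicolon =
  """BEGIN
    |  SELECT 1;
    |END;""".stripMargin

val scriptWithoutSemicolon =
  """BEGIN
    |  SELECT 1;
    |END""".stripMargin
```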
@@ -415,6 +415,20 @@ case class UnclosedCommentProcessor(command: String, tokenStream: CommonTokenStream
}
}

override def exitCompoundOrSingleStatement(
ctx: SqlBaseParser.CompoundOrSingleStatementContext): Unit = {
// Same as in exitSingleStatement, we shouldn't parse the comments in SET command.
if (Option(ctx.singleStatement()).forall(
!_.setResetStatement().isInstanceOf[SqlBaseParser.SetConfigurationContext])) {
checkUnclosedComment(tokenStream, command)
}
}

override def exitSingleCompoundStatement(
ctx: SqlBaseParser.SingleCompoundStatementContext): Unit = {
checkUnclosedComment(tokenStream, command)
}

/** check `has_unclosed_bracketed_comment` to find out the unclosed bracketed comment. */
private def checkUnclosedComment(tokenStream: CommonTokenStream, command: String) = {
assert(tokenStream.getTokenSource.isInstanceOf[SqlBaseLexer])
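A hedged sketch of what the new listener hooks guard against: an unclosed bracketed comment inside a compound statement should now be reported at parse time, just as for single statements (the entry point and script text are illustrative):

```scala
import org.apache.spark.sql.catalyst.parser.{CatalystSqlParser, ParseException}

// The bracketed comment below is never closed, so
// exitSingleCompoundStatement -> checkUnclosedComment should reject it.
val badScript =
  """BEGIN
    |  /* this comment is never closed
    |  SELECT 1;
    |END""".stripMargin

try CatalystSqlParser.parsePlan(badScript)
catch { case e: ParseException => println(e.getMessage) }
```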
@@ -21,7 +21,7 @@ import org.antlr.v4.runtime.ParserRuleContext
import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier}
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.parser.ParserUtils.withOrigin
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.plans.logical.{CompoundPlanStatement, LogicalPlan}
import org.apache.spark.sql.catalyst.trees.Origin
import org.apache.spark.sql.errors.QueryParsingErrors

@@ -80,9 +80,10 @@ abstract class AbstractSqlParser extends AbstractParser with ParserInterface

/** Creates LogicalPlan for a given SQL string. */
override def parsePlan(sqlText: String): LogicalPlan = parse(sqlText) { parser =>
val ctx = parser.singleStatement()
val ctx = parser.compoundOrSingleStatement()
withErrorHandling(ctx, Some(sqlText)) {
astBuilder.visitSingleStatement(ctx) match {
astBuilder.visitCompoundOrSingleStatement(ctx) match {
case compoundBody: CompoundPlanStatement => compoundBody
case plan: LogicalPlan => plan
case _ =>
val position = Origin(None, None)
Expand All @@ -91,19 +92,6 @@ abstract class AbstractSqlParser extends AbstractParser with ParserInterface {
}
}

/** Creates [[CompoundBody]] for a given SQL script string. */
override def parseScript(sqlScriptText: String): CompoundBody = parse(sqlScriptText) { parser =>
val ctx = parser.compoundOrSingleStatement()
withErrorHandling(ctx, Some(sqlScriptText)) {
astBuilder.visitCompoundOrSingleStatement(ctx) match {
case body: CompoundBody => body
case _ =>
val position = Origin(None, None)
throw QueryParsingErrors.sqlStatementUnsupportedError(sqlScriptText, position)
}
}
}

def withErrorHandling[T](ctx: ParserRuleContext, sqlText: Option[String])(toResult: => T): T = {
withOrigin(ctx, sqlText) {
try {
@@ -131,19 +131,20 @@ class AstBuilder extends DataTypeAstBuilder
}

override def visitCompoundOrSingleStatement(
ctx: CompoundOrSingleStatementContext): CompoundBody = withOrigin(ctx) {
ctx: CompoundOrSingleStatementContext): LogicalPlan = withOrigin(ctx) {
Option(ctx.singleCompoundStatement()).map { s =>
if (!conf.getConf(SQLConf.SQL_SCRIPTING_ENABLED)) {
throw SqlScriptingErrors.sqlScriptingNotEnabled(CurrentOrigin.get)
}
visit(s).asInstanceOf[CompoundBody]
}.getOrElse {
val logicalPlan = visitSingleStatement(ctx.singleStatement())
CompoundBody(Seq(SingleStatement(parsedPlan = logicalPlan)),
Some(java.util.UUID.randomUUID.toString.toLowerCase(Locale.ROOT)))
visitSingleStatement(ctx.singleStatement())
}
}

override def visitSingleCompoundStatement(ctx: SingleCompoundStatementContext): CompoundBody = {
val labelCtx = new SqlScriptingLabelContext()
visitBeginEndCompoundBlockImpl(ctx.beginEndCompoundBlock(), labelCtx)
visitCompoundBodyImpl(ctx.compoundBody(), None, allowVarDeclare = true, labelCtx)
}

private def visitCompoundBodyImpl(
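To make the new gate in `visitCompoundOrSingleStatement` concrete, a test-style sketch under the assumption that the flag is off (the default) and parsing goes through `CatalystSqlParser`:

```scala
import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
import org.apache.spark.sql.exceptions.SqlScriptingException

// With spark.sql.scripting.enabled=false (the default), visiting a
// compound statement fails fast with UNSUPPORTED_FEATURE.SQL_SCRIPTING.
try {
  CatalystSqlParser.parsePlan("BEGIN SELECT 1; END")
  assert(false, "expected SqlScriptingException")
} catch {
  case e: SqlScriptingException => println(e.getMessage)
}
```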
@@ -62,10 +62,4 @@ trait ParserInterface extends DataTypeParserInterface {
*/
@throws[ParseException]("Text cannot be parsed to a LogicalPlan")
def parseQuery(sqlText: String): LogicalPlan

/**
* Parse a SQL script string to a [[CompoundBody]].
*/
@throws[ParseException]("Text cannot be parsed to a CompoundBody")
def parseScript(sqlScriptText: String): CompoundBody
}
@@ -15,25 +15,25 @@
* limitations under the License.
*/

package org.apache.spark.sql.catalyst.parser
package org.apache.spark.sql.catalyst.plans.logical

import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.trees.{CurrentOrigin, Origin}

import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.trees.{CurrentOrigin, Origin, WithOrigin}

/**
* Trait for all SQL Scripting logical operators that are product of parsing phase.
* These operators will be used by the SQL Scripting interpreter to generate execution nodes.
*/
sealed trait CompoundPlanStatement
sealed trait CompoundPlanStatement extends LogicalPlan

/**
* Logical operator representing result of parsing a single SQL statement
* that is supposed to be executed against Spark.
* @param parsedPlan Result of SQL statement parsing.
*/
case class SingleStatement(parsedPlan: LogicalPlan)
extends CompoundPlanStatement
with WithOrigin {
extends CompoundPlanStatement {

override val origin: Origin = CurrentOrigin.get

@@ -46,6 +46,14 @@ case class SingleStatement(parsedPlan: LogicalPlan)
assert(origin.sqlText.isDefined && origin.startIndex.isDefined && origin.stopIndex.isDefined)
origin.sqlText.get.substring(origin.startIndex.get, origin.stopIndex.get + 1)
}

override def output: Seq[Attribute] = parsedPlan.output

override def children: Seq[LogicalPlan] = parsedPlan.children

override protected def withNewChildrenInternal(
newChildren: IndexedSeq[LogicalPlan]): LogicalPlan =
SingleStatement(parsedPlan.withNewChildren(newChildren))
}

/**
@@ -57,7 +65,15 @@ case class SingleStatement(parsedPlan: LogicalPlan)
*/
case class CompoundBody(
collection: Seq[CompoundPlanStatement],
label: Option[String]) extends CompoundPlanStatement
label: Option[String]) extends Command with CompoundPlanStatement {

override def children: Seq[LogicalPlan] = collection

override protected def withNewChildrenInternal(
newChildren: IndexedSeq[LogicalPlan]): LogicalPlan = {
CompoundBody(newChildren.map(_.asInstanceOf[CompoundPlanStatement]), label)
}
}
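
Because `CompoundBody` is now a `Command` and every `CompoundPlanStatement` a `LogicalPlan`, the standard `TreeNode` machinery applies to scripts. A minimal hand-built sketch, using `OneRowRelation` as a stand-in for real parsed statements:

```scala
import org.apache.spark.sql.catalyst.plans.logical.{
  CompoundBody, OneRowRelation, SingleStatement}

// Two illustrative statements in one labeled body.
val body = CompoundBody(
  Seq(SingleStatement(OneRowRelation()), SingleStatement(OneRowRelation())),
  label = Some("lbl"))

// children is simply the statement collection, so generic TreeNode
// traversals and rewrites now work over scripting nodes too.
assert(body.children.length == 2)
val rebuilt = body.withNewChildren(body.children) // exercises withNewChildrenInternal
assert(rebuilt.isInstanceOf[CompoundBody])
```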

/**
* Logical operator for IF ELSE statement.
@@ -73,6 +89,30 @@ case class IfElseStatement(
conditionalBodies: Seq[CompoundBody],
elseBody: Option[CompoundBody]) extends CompoundPlanStatement {
assert(conditions.length == conditionalBodies.length)

override def output: Seq[Attribute] = Seq.empty

override def children: Seq[LogicalPlan] = Seq.concat(conditions, conditionalBodies, elseBody)

override protected def withNewChildrenInternal(
newChildren: IndexedSeq[LogicalPlan]): LogicalPlan = {
val conditions = newChildren
.filter(_.isInstanceOf[SingleStatement])
.map(_.asInstanceOf[SingleStatement])
var conditionalBodies = newChildren
.filter(_.isInstanceOf[CompoundBody])
.map(_.asInstanceOf[CompoundBody])
var elseBody: Option[CompoundBody] = None

assert(conditions.length == conditionalBodies.length ||
conditions.length + 1 == conditionalBodies.length)

if (conditions.length < conditionalBodies.length) {
elseBody = Some(conditionalBodies.last)
conditionalBodies = conditionalBodies.dropRight(1)
}
IfElseStatement(conditions, conditionalBodies, elseBody)
}
}
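
A quick worked check of the else-branch recovery above: with one condition and an ELSE branch, `children` holds three nodes, and `withNewChildren` must capture the trailing `CompoundBody` as the else body before dropping it from the conditional bodies. A hedged round-trip sketch:

```scala
import org.apache.spark.sql.catalyst.plans.logical._

// Illustrative one-statement body; OneRowRelation stands in for real plans.
def mkBody(): CompoundBody =
  CompoundBody(Seq(SingleStatement(OneRowRelation())), label = None)

val ifElse = IfElseStatement(
  conditions = Seq(SingleStatement(OneRowRelation())),
  conditionalBodies = Seq(mkBody()),
  elseBody = Some(mkBody()))

// 1 condition + 1 conditional body + 1 else body = 3 children.
val rebuilt = ifElse.withNewChildren(ifElse.children)
  .asInstanceOf[IfElseStatement]
assert(rebuilt.conditionalBodies.length == 1 && rebuilt.elseBody.isDefined)
```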

/**
@@ -88,7 +128,21 @@ case class IfElseStatement(
case class WhileStatement(
condition: SingleStatement,
body: CompoundBody,
label: Option[String]) extends CompoundPlanStatement
label: Option[String]) extends CompoundPlanStatement {

override def output: Seq[Attribute] = Seq.empty

override def children: Seq[LogicalPlan] = Seq(condition, body)

override protected def withNewChildrenInternal(
newChildren: IndexedSeq[LogicalPlan]): LogicalPlan = {
assert(newChildren.length == 2)
WhileStatement(
newChildren(0).asInstanceOf[SingleStatement],
newChildren(1).asInstanceOf[CompoundBody],
label)
}
}

/**
* Logical operator for REPEAT statement.
@@ -104,7 +158,21 @@ case class WhileStatement(
case class RepeatStatement(
condition: SingleStatement,
body: CompoundBody,
label: Option[String]) extends CompoundPlanStatement
label: Option[String]) extends CompoundPlanStatement {

override def output: Seq[Attribute] = Seq.empty

override def children: Seq[LogicalPlan] = Seq(condition, body)

override protected def withNewChildrenInternal(
newChildren: IndexedSeq[LogicalPlan]): LogicalPlan = {
assert(newChildren.length == 2)
RepeatStatement(
newChildren(0).asInstanceOf[SingleStatement],
newChildren(1).asInstanceOf[CompoundBody],
label)
}
}

/**
* Logical operator for LEAVE statement.
@@ -113,7 +181,14 @@ case class RepeatStatement(
* with the next statement after the body/loop.
* @param label Label of the compound or loop to leave.
*/
case class LeaveStatement(label: String) extends CompoundPlanStatement
case class LeaveStatement(label: String) extends CompoundPlanStatement {
override def output: Seq[Attribute] = Seq.empty

override def children: Seq[LogicalPlan] = Seq.empty

override protected def withNewChildrenInternal(
newChildren: IndexedSeq[LogicalPlan]): LogicalPlan = LeaveStatement(label)
}

/**
* Logical operator for ITERATE statement.
@@ -122,7 +197,14 @@ case class LeaveStatement(label: String) extends CompoundPlanStatement
* with the next iteration.
* @param label Label of the loop to iterate.
*/
case class IterateStatement(label: String) extends CompoundPlanStatement
case class IterateStatement(label: String) extends CompoundPlanStatement {
override def output: Seq[Attribute] = Seq.empty

override def children: Seq[LogicalPlan] = Seq.empty

override protected def withNewChildrenInternal(
newChildren: IndexedSeq[LogicalPlan]): LogicalPlan = IterateStatement(label)
}

/**
* Logical operator for CASE statement.
@@ -136,6 +218,30 @@ case class CaseStatement(
conditionalBodies: Seq[CompoundBody],
elseBody: Option[CompoundBody]) extends CompoundPlanStatement {
assert(conditions.length == conditionalBodies.length)

override def output: Seq[Attribute] = Seq.empty

override def children: Seq[LogicalPlan] = Seq.concat(conditions, conditionalBodies, elseBody)

override protected def withNewChildrenInternal(
newChildren: IndexedSeq[LogicalPlan]): LogicalPlan = {
val conditions = newChildren
.filter(_.isInstanceOf[SingleStatement])
.map(_.asInstanceOf[SingleStatement])
var conditionalBodies = newChildren
.filter(_.isInstanceOf[CompoundBody])
.map(_.asInstanceOf[CompoundBody])
var elseBody: Option[CompoundBody] = None

assert(conditions.length == conditionalBodies.length ||
conditions.length + 1 == conditionalBodies.length)

if (conditions.length < conditionalBodies.length) {
elseBody = Some(conditionalBodies.last)
conditionalBodies = conditionalBodies.dropRight(1)
}
CaseStatement(conditions, conditionalBodies, elseBody)
}
}

/**
@@ -149,4 +255,15 @@ case class CaseStatement(
*/
case class LoopStatement(
body: CompoundBody,
label: Option[String]) extends CompoundPlanStatement
label: Option[String]) extends CompoundPlanStatement {

override def output: Seq[Attribute] = Seq.empty

override def children: Seq[LogicalPlan] = Seq(body)

override protected def withNewChildrenInternal(
newChildren: IndexedSeq[LogicalPlan]): LogicalPlan = {
assert(newChildren.length == 1)
LoopStatement(newChildren(0).asInstanceOf[CompoundBody], label)
}
}
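
These are the nodes the upcoming Execution Framework will walk. A speculative sketch of interpreter dispatch over the sealed trait (`ExecNode` and friends are placeholders, not part of this PR):

```scala
import org.apache.spark.sql.catalyst.plans.logical._

// Placeholder execution-node types, for illustration only.
sealed trait ExecNode
case class ExecLeaf(statement: SingleStatement) extends ExecNode
case class ExecBlock(nodes: Seq[ExecNode], label: Option[String]) extends ExecNode

// Because CompoundPlanStatement is sealed, the compiler can flag
// missing cases when new statement types are added.
def toExec(node: CompoundPlanStatement): ExecNode = node match {
  case s: SingleStatement         => ExecLeaf(s)
  case CompoundBody(stmts, label) => ExecBlock(stmts.map(toExec), label)
  case other =>
    throw new UnsupportedOperationException(s"not sketched: ${other.nodeName}")
}
```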
@@ -18,9 +18,11 @@
package org.apache.spark.sql.errors

import org.apache.spark.sql.catalyst.trees.Origin
import org.apache.spark.sql.catalyst.util.QuotingUtils.toSQLConf
import org.apache.spark.sql.errors.DataTypeErrors.toSQLId
import org.apache.spark.sql.errors.QueryExecutionErrors.toSQLStmt
import org.apache.spark.sql.exceptions.SqlScriptingException
import org.apache.spark.sql.internal.SQLConf

/**
* Object for grouping error messages thrown during parsing/interpreting phase
@@ -82,6 +84,15 @@
messageParameters = Map("invalidStatement" -> toSQLStmt(stmt)))
}

def sqlScriptingNotEnabled(origin: Origin): Throwable = {
new SqlScriptingException(
errorClass = "UNSUPPORTED_FEATURE.SQL_SCRIPTING",
cause = null,
origin = origin,
messageParameters = Map(
"sqlScriptingEnabled" -> toSQLConf(SQLConf.SQL_SCRIPTING_ENABLED.key)))
}

def booleanStatementWithEmptyRow(
origin: Origin,
stmt: String): Throwable = {
@@ -3423,6 +3423,15 @@ object SQLConf {
.version("2.3.0")
.fallbackConf(org.apache.spark.internal.config.STRING_REDACTION_PATTERN)

val SQL_SCRIPTING_ENABLED =
buildConf("spark.sql.scripting.enabled")
.doc("SQL Scripting feature is under development and its use should be done under this " +
"feature flag. SQL Scripting enables users to write procedural SQL including control " +
"flow and error handling.")
.version("4.0.0")
.booleanConf
.createWithDefault(false)

val CONCAT_BINARY_AS_STRING = buildConf("spark.sql.function.concatBinaryAsString")
.doc("When this option is set to false and all inputs are binary, `functions.concat` returns " +
"an output as binary. Otherwise, it returns as a string.")
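Finally, a usage sketch for the new flag (assumes a running `SparkSession` named `spark`; script execution itself arrives with the later Execution Framework changes):

```scala
// Assumes spark: SparkSession is in scope.
// Programmatic:
spark.conf.set("spark.sql.scripting.enabled", "true")

// Or from SQL:
spark.sql("SET spark.sql.scripting.enabled=true")

// With the flag on, BEGIN ... END blocks pass the parser gate added in
// this PR; with it off they fail with UNSUPPORTED_FEATURE.SQL_SCRIPTING.
```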
