[SPARK-20100][SQL] Refactor SessionState initialization #17433

Closed · wants to merge 3 commits

File: SessionCatalog.scala
@@ -54,7 +54,8 @@ class SessionCatalog(
functionRegistry: FunctionRegistry,
conf: CatalystConf,
hadoopConf: Configuration,
parser: ParserInterface) extends Logging {
parser: ParserInterface,
functionResourceLoader: FunctionResourceLoader) extends Logging {
import SessionCatalog._
import CatalogTypes.TablePartitionSpec

@@ -69,8 +70,8 @@ class SessionCatalog(
functionRegistry,
conf,
new Configuration(),
CatalystSqlParser)
functionResourceLoader = DummyFunctionResourceLoader
CatalystSqlParser,
DummyFunctionResourceLoader)
}

// For testing only.
@@ -90,9 +91,7 @@ class SessionCatalog(
// check whether the temporary table or function exists, then, if not, operate on
// the corresponding item in the current database.
@GuardedBy("this")
protected var currentDb = formatDatabaseName(DEFAULT_DATABASE)

@volatile var functionResourceLoader: FunctionResourceLoader = _
protected var currentDb: String = formatDatabaseName(DEFAULT_DATABASE)

/**
* Checks if the given name conforms the Hive standard ("[a-zA-z_0-9]+"),
@@ -1059,9 +1058,6 @@ class SessionCatalog(
* by a tuple (resource type, resource uri).
*/
def loadFunctionResources(resources: Seq[FunctionResource]): Unit = {
if (functionResourceLoader == null) {
throw new IllegalStateException("functionResourceLoader has not yet been initialized")
}
resources.foreach(functionResourceLoader.loadResource)
}

@@ -1259,28 +1255,16 @@
}

/**
* Create a new [[SessionCatalog]] with the provided parameters. `externalCatalog` and
* `globalTempViewManager` are `inherited`, while `currentDb` and `tempTables` are copied.
* Copy the current state of the catalog to another catalog.
*
* This function is synchronized on this [[SessionCatalog]] (the source) to make sure the copied
* state is consistent. The target [[SessionCatalog]] is not synchronized, and should not be
* because the target [[SessionCatalog]] should not be published at this point. The caller must
* synchronize on the target if this assumption does not hold.
*/
def newSessionCatalogWith(
conf: CatalystConf,
hadoopConf: Configuration,
functionRegistry: FunctionRegistry,
parser: ParserInterface): SessionCatalog = {
val catalog = new SessionCatalog(
externalCatalog,
globalTempViewManager,
functionRegistry,
conf,
hadoopConf,
parser)

synchronized {
catalog.currentDb = currentDb
// copy over temporary tables
tempTables.foreach(kv => catalog.tempTables.put(kv._1, kv._2))
}

catalog
private[sql] def copyStateTo(target: SessionCatalog): Unit = synchronized {
target.currentDb = currentDb
// copy over temporary tables
tempTables.foreach(kv => target.tempTables.put(kv._1, kv._2))
}
}
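
For reference, a minimal sketch of the copy flow that replaces `newSessionCatalogWith` (hypothetical example mirroring the updated suite further down; the object name is made up, and it has to live under the `org.apache.spark.sql` package because `copyStateTo` is `private[sql]`):

```scala
package org.apache.spark.sql.catalyst.catalog

import org.apache.spark.sql.catalyst.plans.logical.Range

object CopyStateToExample {
  def main(args: Array[String]): Unit = {
    // Source catalog, built with a test-only convenience constructor.
    val original = new SessionCatalog(new InMemoryCatalog)
    val view = Range(1, 10, 1, 10)
    original.createTempView("t1", view, overrideIfExists = false)

    // Instead of newSessionCatalogWith, build the target and copy state into it.
    val clone = new SessionCatalog(original.externalCatalog)
    original.copyStateTo(clone) // copies currentDb and temp views under the source's lock

    assert(clone.getTempView("t1") == Some(view))
  }
}
```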

File: Optimizer.scala
@@ -17,20 +17,14 @@

package org.apache.spark.sql.catalyst.optimizer

import scala.annotation.tailrec
import scala.collection.immutable.HashSet
import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer

import org.apache.spark.api.java.function.FilterFunction
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.{CatalystConf, SimpleCatalystConf}
import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.catalog.{InMemoryCatalog, SessionCatalog}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral}
import org.apache.spark.sql.catalyst.planning.ExtractFiltersAndInnerJoins
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules._
@@ -79,7 +73,7 @@ abstract class Optimizer(sessionCatalog: SessionCatalog, conf: CatalystConf)
Batch("Aggregate", fixedPoint,
RemoveLiteralFromGroupExpressions,
RemoveRepetitionFromGroupExpressions) ::
Batch("Operator Optimizations", fixedPoint,
Batch("Operator Optimizations", fixedPoint, Seq(
// Operator push down
PushProjectionThroughUnion,
ReorderJoin(conf),
@@ -117,7 +111,8 @@ abstract class Optimizer(sessionCatalog: SessionCatalog, conf: CatalystConf)
RemoveRedundantProject,
SimplifyCreateStructOps,
SimplifyCreateArrayOps,
SimplifyCreateMapOps) ::
SimplifyCreateMapOps) ++
extendedOperatorOptimizationRules: _*) ::
Batch("Check Cartesian Products", Once,
CheckCartesianProducts(conf)) ::
Batch("Join Reorder", Once,
@@ -146,6 +141,11 @@ abstract class Optimizer(sessionCatalog: SessionCatalog, conf: CatalystConf)
s.withNewPlan(newPlan)
}
}

/**
* Override to provide additional rules for the operator optimization batch.
*/
def extendedOperatorOptimizationRules: Seq[Rule[LogicalPlan]] = Nil
}

Review comment (Member): Not sure whether we need to split the Operator Optimizations batch into smaller independent batches or move some rules out of this batch in the future. If so, the location of this rule becomes unstable. We might need to explain it in the comment.

Reply (Contributor Author): Anything in catalyst can be changed between Spark versions. This hook included.
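
To illustrate the new hook (hypothetical code, not part of this patch): a custom `Optimizer` subclass can append its own rules to the Operator Optimizations batch.

```scala
import org.apache.spark.sql.catalyst.CatalystConf
import org.apache.spark.sql.catalyst.catalog.SessionCatalog
import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.catalyst.optimizer.Optimizer
import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.types.BooleanType

// Hypothetical rule: remove filters whose condition is the literal `true`.
object RemoveTrueFilter extends Rule[LogicalPlan] {
  override def apply(plan: LogicalPlan): LogicalPlan = plan transform {
    case Filter(Literal(true, BooleanType), child) => child
  }
}

class MyOptimizer(catalog: SessionCatalog, conf: CatalystConf)
  extends Optimizer(catalog, conf) {
  // These rules are appended to the "Operator Optimizations" batch defined above.
  override def extendedOperatorOptimizationRules: Seq[Rule[LogicalPlan]] =
    RemoveTrueFilter :: Nil
}
```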

/**

File: SessionCatalogSuite.scala
@@ -17,8 +17,6 @@

package org.apache.spark.sql.catalyst.catalog

import org.apache.hadoop.conf.Configuration

import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.{FunctionIdentifier, SimpleCatalystConf, TableIdentifier}
import org.apache.spark.sql.catalyst.analysis._
@@ -1331,17 +1329,15 @@ abstract class SessionCatalogSuite extends PlanTest {
}
}

test("clone SessionCatalog - temp views") {
test("copy SessionCatalog state - temp views") {
withEmptyCatalog { original =>
val tempTable1 = Range(1, 10, 1, 10)
original.createTempView("copytest1", tempTable1, overrideIfExists = false)

// check if tables copied over
val clone = original.newSessionCatalogWith(
SimpleCatalystConf(caseSensitiveAnalysis = true),
new Configuration(),
new SimpleFunctionRegistry,
CatalystSqlParser)
val clone = new SessionCatalog(original.externalCatalog)
original.copyStateTo(clone)

assert(original ne clone)
assert(clone.getTempView("copytest1") == Some(tempTable1))

@@ -1355,7 +1351,7 @@ abstract class SessionCatalogSuite extends PlanTest {
}
}

test("clone SessionCatalog - current db") {
test("copy SessionCatalog state - current db") {
withEmptyCatalog { original =>
val db1 = "db1"
val db2 = "db2"
@@ -1368,11 +1364,9 @@ abstract class SessionCatalogSuite extends PlanTest {
original.setCurrentDatabase(db1)

// check if current db copied over
val clone = original.newSessionCatalogWith(
SimpleCatalystConf(caseSensitiveAnalysis = true),
new Configuration(),
new SimpleFunctionRegistry,
CatalystSqlParser)
val clone = new SessionCatalog(original.externalCatalog)
original.copyStateTo(clone)

assert(original ne clone)
assert(clone.getCurrentDatabase == db1)


File: SparkOptimizer.scala
@@ -30,9 +30,17 @@ class SparkOptimizer(
experimentalMethods: ExperimentalMethods)
extends Optimizer(catalog, conf) {

override def batches: Seq[Batch] = super.batches :+
override def batches: Seq[Batch] = (super.batches :+
Batch("Optimize Metadata Only Query", Once, OptimizeMetadataOnlyQuery(catalog, conf)) :+
Batch("Extract Python UDF from Aggregate", Once, ExtractPythonUDFFromAggregate) :+
Batch("Prune File Source Table Partitions", Once, PruneFileSourcePartitions) :+
Batch("Prune File Source Table Partitions", Once, PruneFileSourcePartitions)) ++
postHocOptimizationBatches :+
Batch("User Provided Optimizers", fixedPoint, experimentalMethods.extraOptimizations: _*)

/**
* Optimization batches that are executed after the regular optimization batches, but before the
* batch executing the [[ExperimentalMethods]] optimizer rules. This hook can be used to add
* custom optimizer batches to the Spark optimizer.
*/
def postHocOptimizationBatches: Seq[Batch] = Nil
}
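
A hypothetical sketch of a derived optimizer using the new hook (class and rule names are illustrative, and the constructor mirrors the `SparkOptimizer` signature above):

```scala
import org.apache.spark.sql.ExperimentalMethods
import org.apache.spark.sql.catalyst.catalog.SessionCatalog
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.SparkOptimizer
import org.apache.spark.sql.internal.SQLConf

// Hypothetical no-op rule standing in for a real post-hoc optimization.
object MyCleanupRule extends Rule[LogicalPlan] {
  override def apply(plan: LogicalPlan): LogicalPlan = plan
}

class MySparkOptimizer(
    catalog: SessionCatalog,
    conf: SQLConf,
    experimentalMethods: ExperimentalMethods)
  extends SparkOptimizer(catalog, conf, experimentalMethods) {

  // Runs after the built-in batches and before the "User Provided Optimizers" batch.
  override def postHocOptimizationBatches: Seq[Batch] =
    Batch("Post Hoc Cleanup", Once, MyCleanupRule) :: Nil
}
```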

File: SparkPlanner.scala
@@ -27,13 +27,14 @@ import org.apache.spark.sql.internal.SQLConf
class SparkPlanner(
val sparkContext: SparkContext,
val conf: SQLConf,
val extraStrategies: Seq[Strategy])
val experimentalMethods: ExperimentalMethods)
extends SparkStrategies {

def numPartitions: Int = conf.numShufflePartitions

def strategies: Seq[Strategy] =
extraStrategies ++ (
experimentalMethods.extraStrategies ++
extraPlanningStrategies ++ (
FileSourceStrategy ::
DataSourceStrategy ::
SpecialLimits ::
@@ -42,6 +43,12 @@
InMemoryScans ::
BasicOperators :: Nil)

/**
* Override to add extra planning strategies to the planner. These strategies are tried after
* the strategies defined in [[ExperimentalMethods]], and before the regular strategies.
*/
def extraPlanningStrategies: Seq[Strategy] = Nil

override protected def collectPlaceholders(plan: SparkPlan): Seq[(SparkPlan, LogicalPlan)] = {
plan.collect {
case placeholder @ PlanLater(logicalPlan) => placeholder -> logicalPlan
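
Likewise, a hypothetical planner extension using `extraPlanningStrategies` (illustrative names only):

```scala
import org.apache.spark.SparkContext
import org.apache.spark.sql.{ExperimentalMethods, Strategy}
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.execution.{SparkPlan, SparkPlanner}
import org.apache.spark.sql.internal.SQLConf

// Hypothetical strategy that never matches; a real one would pattern-match on the plan.
object MyScanStrategy extends Strategy {
  override def apply(plan: LogicalPlan): Seq[SparkPlan] = Nil
}

class MyPlanner(sc: SparkContext, conf: SQLConf, experimental: ExperimentalMethods)
  extends SparkPlanner(sc, conf, experimental) {
  // Tried after ExperimentalMethods.extraStrategies and before the built-in strategies.
  override def extraPlanningStrategies: Seq[Strategy] = MyScanStrategy :: Nil
}
```

The IncrementalExecution change below uses this same hook to register the streaming-specific strategies.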

File: IncrementalExecution.scala
@@ -20,8 +20,8 @@ package org.apache.spark.sql.execution.streaming
import java.util.concurrent.atomic.AtomicInteger

import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.expressions.{CurrentBatchTimestamp, Literal}
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.{SparkSession, Strategy}
import org.apache.spark.sql.catalyst.expressions.CurrentBatchTimestamp
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.{QueryExecution, SparkPlan, SparkPlanner, UnaryExecNode}
@@ -40,20 +40,17 @@ class IncrementalExecution(
offsetSeqMetadata: OffsetSeqMetadata)
extends QueryExecution(sparkSession, logicalPlan) with Logging {

// TODO: make this always part of planning.
val streamingExtraStrategies =
sparkSession.sessionState.planner.StatefulAggregationStrategy +:
sparkSession.sessionState.planner.FlatMapGroupsWithStateStrategy +:
sparkSession.sessionState.planner.StreamingRelationStrategy +:
sparkSession.sessionState.planner.StreamingDeduplicationStrategy +:
sparkSession.sessionState.experimentalMethods.extraStrategies

// Modified planner with stateful operations.
override def planner: SparkPlanner =
new SparkPlanner(
override val planner: SparkPlanner = new SparkPlanner(
sparkSession.sparkContext,
sparkSession.sessionState.conf,
streamingExtraStrategies)
sparkSession.sessionState.experimentalMethods) {
override def extraPlanningStrategies: Seq[Strategy] =
StatefulAggregationStrategy ::
FlatMapGroupsWithStateStrategy ::
StreamingRelationStrategy ::
StreamingDeduplicationStrategy :: Nil
}

/**
* See [SPARK-18339]