Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-14013][SQL] Proper temp function support in catalog #11972

Closed
wants to merge 6 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,15 @@ import org.apache.spark.sql.types._
* to resolve attribute references.
*/
object SimpleAnalyzer
extends SimpleAnalyzer(new SimpleCatalystConf(caseSensitiveAnalysis = true))
class SimpleAnalyzer(conf: CatalystConf)
extends Analyzer(new SessionCatalog(new InMemoryCatalog, conf), EmptyFunctionRegistry, conf)
extends SimpleAnalyzer(
EmptyFunctionRegistry,
new SimpleCatalystConf(caseSensitiveAnalysis = true))

class SimpleAnalyzer(functionRegistry: FunctionRegistry, conf: CatalystConf)
extends Analyzer(
new SessionCatalog(new InMemoryCatalog, functionRegistry, conf),
functionRegistry,
conf)

/**
* Provides a logical query plan analyzer, which translates [[UnresolvedAttribute]]s and
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,13 @@ trait FunctionRegistry {

/* Get the class of the registered function by specified name. */
def lookupFunction(name: String): Option[ExpressionInfo]

/* Get the builder of the registered function by specified name. */
def lookupFunctionBuilder(name: String): Option[FunctionBuilder]

/** Drop a function and return whether the function existed. */
def dropFunction(name: String): Boolean

}

class SimpleFunctionRegistry extends FunctionRegistry {
Expand Down Expand Up @@ -76,6 +83,14 @@ class SimpleFunctionRegistry extends FunctionRegistry {
functionBuilders.get(name).map(_._1)
}

override def lookupFunctionBuilder(name: String): Option[FunctionBuilder] = synchronized {
functionBuilders.get(name).map(_._2)
}

override def dropFunction(name: String): Boolean = synchronized {
functionBuilders.remove(name).isDefined
}

def copy(): SimpleFunctionRegistry = synchronized {
val registry = new SimpleFunctionRegistry
functionBuilders.iterator.foreach { case (name, (info, builder)) =>
Expand Down Expand Up @@ -106,6 +121,15 @@ object EmptyFunctionRegistry extends FunctionRegistry {
override def lookupFunction(name: String): Option[ExpressionInfo] = {
throw new UnsupportedOperationException
}

override def lookupFunctionBuilder(name: String): Option[FunctionBuilder] = {
throw new UnsupportedOperationException
}

override def dropFunction(name: String): Boolean = {
throw new UnsupportedOperationException
}

}


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@ import scala.collection.mutable
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.{CatalystConf, SimpleCatalystConf}
import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier}
import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, SimpleFunctionRegistry}
import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder
import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionInfo}
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, SubqueryAlias}


Expand All @@ -32,15 +35,22 @@ import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, SubqueryAlias}
*
* This class is not thread-safe.
*/
class SessionCatalog(externalCatalog: ExternalCatalog, conf: CatalystConf) {
class SessionCatalog(
externalCatalog: ExternalCatalog,
functionRegistry: FunctionRegistry,
conf: CatalystConf) {
import ExternalCatalog._

def this(externalCatalog: ExternalCatalog, functionRegistry: FunctionRegistry) {
this(externalCatalog, functionRegistry, new SimpleCatalystConf(true))
}

// For testing only.
def this(externalCatalog: ExternalCatalog) {
this(externalCatalog, new SimpleCatalystConf(true))
this(externalCatalog, new SimpleFunctionRegistry)
}

protected[this] val tempTables = new mutable.HashMap[String, LogicalPlan]
protected[this] val tempFunctions = new mutable.HashMap[String, CatalogFunction]

// Note: we track current database here because certain operations do not explicitly
// specify the database (e.g. DROP TABLE my_table). In these cases we must first
Expand Down Expand Up @@ -431,6 +441,18 @@ class SessionCatalog(externalCatalog: ExternalCatalog, conf: CatalystConf) {
externalCatalog.alterFunction(db, newFuncDefinition)
}

/**
* Retrieve the metadata of a metastore function.
*
* If a database is specified in `name`, this will return the function in that database.
* If no database is specified, this will return the function in the current database.
*/
def getFunction(name: FunctionIdentifier): CatalogFunction = {
val db = name.database.getOrElse(currentDb)
externalCatalog.getFunction(db, name.funcName)
}


// ----------------------------------------------------------------
// | Methods that interact with temporary and metastore functions |
// ----------------------------------------------------------------
Expand All @@ -439,14 +461,14 @@ class SessionCatalog(externalCatalog: ExternalCatalog, conf: CatalystConf) {
* Create a temporary function.
* This assumes no database is specified in `funcDefinition`.
*/
def createTempFunction(funcDefinition: CatalogFunction, ignoreIfExists: Boolean): Unit = {
require(funcDefinition.identifier.database.isEmpty,
"attempted to create a temporary function while specifying a database")
val name = funcDefinition.identifier.funcName
if (tempFunctions.contains(name) && !ignoreIfExists) {
def createTempFunction(
name: String,
funcDefinition: FunctionBuilder,
ignoreIfExists: Boolean): Unit = {
if (functionRegistry.lookupFunctionBuilder(name).isDefined && !ignoreIfExists) {
throw new AnalysisException(s"Temporary function '$name' already exists.")
}
tempFunctions.put(name, funcDefinition)
functionRegistry.registerFunction(name, funcDefinition)
}

/**
Expand All @@ -456,11 +478,10 @@ class SessionCatalog(externalCatalog: ExternalCatalog, conf: CatalystConf) {
// Hive has DROP FUNCTION and DROP TEMPORARY FUNCTION. We may want to consolidate
// dropFunction and dropTempFunction.
def dropTempFunction(name: String, ignoreIfNotExists: Boolean): Unit = {
if (!tempFunctions.contains(name) && !ignoreIfNotExists) {
if (!functionRegistry.dropFunction(name) && !ignoreIfNotExists) {
throw new AnalysisException(
s"Temporary function '$name' cannot be dropped because it does not exist!")
}
tempFunctions.remove(name)
}

/**
Expand All @@ -477,42 +498,37 @@ class SessionCatalog(externalCatalog: ExternalCatalog, conf: CatalystConf) {
throw new AnalysisException("rename does not support moving functions across databases")
}
val db = oldName.database.getOrElse(currentDb)
if (oldName.database.isDefined || !tempFunctions.contains(oldName.funcName)) {
val oldBuilder = functionRegistry.lookupFunctionBuilder(oldName.funcName)
if (oldName.database.isDefined || oldBuilder.isEmpty) {
externalCatalog.renameFunction(db, oldName.funcName, newName.funcName)
} else {
val func = tempFunctions(oldName.funcName)
val newFunc = func.copy(identifier = func.identifier.copy(funcName = newName.funcName))
tempFunctions.remove(oldName.funcName)
tempFunctions.put(newName.funcName, newFunc)
val oldExpressionInfo = functionRegistry.lookupFunction(oldName.funcName).get
val newExpressionInfo = new ExpressionInfo(
oldExpressionInfo.getClassName,
newName.funcName,
oldExpressionInfo.getUsage,
oldExpressionInfo.getExtended)
functionRegistry.dropFunction(oldName.funcName)
functionRegistry.registerFunction(newName.funcName, newExpressionInfo, oldBuilder.get)
}
}

/**
* Retrieve the metadata of an existing function.
*
* If a database is specified in `name`, this will return the function in that database.
* If no database is specified, this will first attempt to return a temporary function with
* the same name, then, if that does not exist, return the function in the current database.
* Return an [[Expression]] that represents the specified function, assuming it exists.
* Note: This is currently only used for temporary functions.
*/
def getFunction(name: FunctionIdentifier): CatalogFunction = {
val db = name.database.getOrElse(currentDb)
if (name.database.isDefined || !tempFunctions.contains(name.funcName)) {
externalCatalog.getFunction(db, name.funcName)
} else {
tempFunctions(name.funcName)
}
def lookupFunction(name: String, children: Seq[Expression]): Expression = {
functionRegistry.lookupFunction(name, children)
}

// TODO: implement lookupFunction that returns something from the registry itself

/**
* List all matching functions in the specified database, including temporary functions.
*/
def listFunctions(db: String, pattern: String): Seq[FunctionIdentifier] = {
val dbFunctions =
externalCatalog.listFunctions(db, pattern).map { f => FunctionIdentifier(f, Some(db)) }
val regex = pattern.replaceAll("\\*", ".*").r
val _tempFunctions = tempFunctions.keys.toSeq
val _tempFunctions = functionRegistry.listFunction()
.filter { f => regex.pattern.matcher(f).matches() }
.map { f => FunctionIdentifier(f) }
dbFunctions ++ _tempFunctions
Expand All @@ -521,8 +537,8 @@ class SessionCatalog(externalCatalog: ExternalCatalog, conf: CatalystConf) {
/**
* Return a temporary function. For testing only.
*/
private[catalog] def getTempFunction(name: String): Option[CatalogFunction] = {
tempFunctions.get(name)
private[catalog] def getTempFunction(name: String): Option[FunctionBuilder] = {
functionRegistry.lookupFunctionBuilder(name)
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ trait AnalysisTest extends PlanTest {

private def makeAnalyzer(caseSensitive: Boolean): Analyzer = {
val conf = new SimpleCatalystConf(caseSensitive)
val catalog = new SessionCatalog(new InMemoryCatalog, conf)
val catalog = new SessionCatalog(new InMemoryCatalog, EmptyFunctionRegistry, conf)
catalog.createTempTable("TaBlE", TestRelations.testRelation, overrideIfExists = true)
new Analyzer(catalog, EmptyFunctionRegistry, conf) {
override val extendedResolutionRules = EliminateSubqueryAliases :: Nil
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ import org.apache.spark.sql.types._

class DecimalPrecisionSuite extends PlanTest with BeforeAndAfter {
private val conf = new SimpleCatalystConf(caseSensitiveAnalysis = true)
private val catalog = new SessionCatalog(new InMemoryCatalog, conf)
private val catalog = new SessionCatalog(new InMemoryCatalog, EmptyFunctionRegistry, conf)
private val analyzer = new Analyzer(catalog, EmptyFunctionRegistry, conf)

private val relation = LocalRelation(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.catalog
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier}
import org.apache.spark.sql.catalyst.expressions.{Expression, Literal}
import org.apache.spark.sql.catalyst.plans.logical.{Range, SubqueryAlias}


Expand Down Expand Up @@ -682,20 +683,20 @@ class SessionCatalogSuite extends SparkFunSuite {

test("create temp function") {
val catalog = new SessionCatalog(newBasicCatalog())
val tempFunc1 = newFunc("temp1")
val tempFunc2 = newFunc("temp2")
catalog.createTempFunction(tempFunc1, ignoreIfExists = false)
catalog.createTempFunction(tempFunc2, ignoreIfExists = false)
val tempFunc1 = (e: Seq[Expression]) => e.head
val tempFunc2 = (e: Seq[Expression]) => e.last
catalog.createTempFunction("temp1", tempFunc1, ignoreIfExists = false)
catalog.createTempFunction("temp2", tempFunc2, ignoreIfExists = false)
assert(catalog.getTempFunction("temp1") == Some(tempFunc1))
assert(catalog.getTempFunction("temp2") == Some(tempFunc2))
assert(catalog.getTempFunction("temp3") == None)
val tempFunc3 = (e: Seq[Expression]) => Literal(e.size)
// Temporary function already exists
intercept[AnalysisException] {
catalog.createTempFunction(tempFunc1, ignoreIfExists = false)
catalog.createTempFunction("temp1", tempFunc3, ignoreIfExists = false)
}
// Temporary function is overridden
val tempFunc3 = tempFunc1.copy(className = "something else")
catalog.createTempFunction(tempFunc3, ignoreIfExists = true)
catalog.createTempFunction("temp1", tempFunc3, ignoreIfExists = true)
assert(catalog.getTempFunction("temp1") == Some(tempFunc3))
}

Expand Down Expand Up @@ -725,8 +726,8 @@ class SessionCatalogSuite extends SparkFunSuite {

test("drop temp function") {
val catalog = new SessionCatalog(newBasicCatalog())
val tempFunc = newFunc("func1")
catalog.createTempFunction(tempFunc, ignoreIfExists = false)
val tempFunc = (e: Seq[Expression]) => e.head
catalog.createTempFunction("func1", tempFunc, ignoreIfExists = false)
assert(catalog.getTempFunction("func1") == Some(tempFunc))
catalog.dropTempFunction("func1", ignoreIfNotExists = false)
assert(catalog.getTempFunction("func1") == None)
Expand Down Expand Up @@ -755,20 +756,15 @@ class SessionCatalogSuite extends SparkFunSuite {
}
}

test("get temp function") {
val externalCatalog = newBasicCatalog()
val sessionCatalog = new SessionCatalog(externalCatalog)
val metastoreFunc = externalCatalog.getFunction("db2", "func1")
val tempFunc = newFunc("func1").copy(className = "something weird")
sessionCatalog.createTempFunction(tempFunc, ignoreIfExists = false)
sessionCatalog.setCurrentDatabase("db2")
// If a database is specified, we'll always return the function in that database
assert(sessionCatalog.getFunction(FunctionIdentifier("func1", Some("db2"))) == metastoreFunc)
// If no database is specified, we'll first return temporary functions
assert(sessionCatalog.getFunction(FunctionIdentifier("func1")) == tempFunc)
// Then, if no such temporary function exist, check the current database
sessionCatalog.dropTempFunction("func1", ignoreIfNotExists = false)
assert(sessionCatalog.getFunction(FunctionIdentifier("func1")) == metastoreFunc)
test("lookup temp function") {
val catalog = new SessionCatalog(newBasicCatalog())
val tempFunc1 = (e: Seq[Expression]) => e.head
catalog.createTempFunction("func1", tempFunc1, ignoreIfExists = false)
assert(catalog.lookupFunction("func1", Seq(Literal(1), Literal(2), Literal(3))) == Literal(1))
catalog.dropTempFunction("func1", ignoreIfNotExists = false)
intercept[AnalysisException] {
catalog.lookupFunction("func1", Seq(Literal(1), Literal(2), Literal(3)))
}
}

test("rename function") {
Expand Down Expand Up @@ -813,8 +809,8 @@ class SessionCatalogSuite extends SparkFunSuite {
test("rename temp function") {
val externalCatalog = newBasicCatalog()
val sessionCatalog = new SessionCatalog(externalCatalog)
val tempFunc = newFunc("func1").copy(className = "something weird")
sessionCatalog.createTempFunction(tempFunc, ignoreIfExists = false)
val tempFunc = (e: Seq[Expression]) => e.head
sessionCatalog.createTempFunction("func1", tempFunc, ignoreIfExists = false)
sessionCatalog.setCurrentDatabase("db2")
// If a database is specified, we'll always rename the function in that database
sessionCatalog.renameFunction(
Expand All @@ -825,8 +821,7 @@ class SessionCatalogSuite extends SparkFunSuite {
// If no database is specified, we'll first rename temporary functions
sessionCatalog.createFunction(newFunc("func1", Some("db2")))
sessionCatalog.renameFunction(FunctionIdentifier("func1"), FunctionIdentifier("func4"))
assert(sessionCatalog.getTempFunction("func4") ==
Some(tempFunc.copy(identifier = FunctionIdentifier("func4"))))
assert(sessionCatalog.getTempFunction("func4") == Some(tempFunc))
assert(sessionCatalog.getTempFunction("func1") == None)
assert(externalCatalog.listFunctions("db2", "*").toSet == Set("func1", "func3"))
// Then, if no such temporary function exist, rename the function in the current database
Expand Down Expand Up @@ -858,12 +853,12 @@ class SessionCatalogSuite extends SparkFunSuite {

test("list functions") {
val catalog = new SessionCatalog(newBasicCatalog())
val tempFunc1 = newFunc("func1").copy(className = "march")
val tempFunc2 = newFunc("yes_me").copy(className = "april")
val tempFunc1 = (e: Seq[Expression]) => e.head
val tempFunc2 = (e: Seq[Expression]) => e.last
catalog.createFunction(newFunc("func2", Some("db2")))
catalog.createFunction(newFunc("not_me", Some("db2")))
catalog.createTempFunction(tempFunc1, ignoreIfExists = false)
catalog.createTempFunction(tempFunc2, ignoreIfExists = false)
catalog.createTempFunction("func1", tempFunc1, ignoreIfExists = false)
catalog.createTempFunction("yes_me", tempFunc2, ignoreIfExists = false)
assert(catalog.listFunctions("db1", "*").toSet ==
Set(FunctionIdentifier("func1"),
FunctionIdentifier("yes_me")))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ class BooleanSimplificationSuite extends PlanTest with PredicateHelper {

private val caseInsensitiveConf = new SimpleCatalystConf(false)
private val caseInsensitiveAnalyzer = new Analyzer(
new SessionCatalog(new InMemoryCatalog, caseInsensitiveConf),
new SessionCatalog(new InMemoryCatalog, EmptyFunctionRegistry, caseInsensitiveConf),
EmptyFunctionRegistry,
caseInsensitiveConf)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ import org.apache.spark.sql.catalyst.rules._

class EliminateSortsSuite extends PlanTest {
val conf = new SimpleCatalystConf(caseSensitiveAnalysis = true, orderByOrdinal = false)
val catalog = new SessionCatalog(new InMemoryCatalog, conf)
val catalog = new SessionCatalog(new InMemoryCatalog, EmptyFunctionRegistry, conf)
val analyzer = new Analyzer(catalog, EmptyFunctionRegistry, conf)

object Optimize extends RuleExecutor[LogicalPlan] {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,14 +44,14 @@ private[sql] class SessionState(ctx: SQLContext) {
lazy val experimentalMethods = new ExperimentalMethods

/**
* Internal catalog for managing table and database states.
* Internal catalog for managing functions registered by the user.
*/
lazy val catalog = new SessionCatalog(ctx.externalCatalog, conf)
lazy val functionRegistry: FunctionRegistry = FunctionRegistry.builtin.copy()

/**
* Internal catalog for managing functions registered by the user.
* Internal catalog for managing table and database states.
*/
lazy val functionRegistry: FunctionRegistry = FunctionRegistry.builtin.copy()
lazy val catalog = new SessionCatalog(ctx.externalCatalog, functionRegistry, conf)

/**
* Interface exposed to the user for registering user-defined functions.
Expand Down
Loading