Skip to content

Commit

Permalink
[SPARK-14013][SQL] Proper temp function support in catalog
Browse files Browse the repository at this point in the history
## What changes were proposed in this pull request?

The session catalog was added in #11750. However, it does not properly support temporary functions: right now we only store their metadata in the form of `CatalogFunction`, which does not make sense for temporary functions because there is no class name.

This patch moves the `FunctionRegistry` into the `SessionCatalog`. With this, the user can call `catalog.createTempFunction` and `catalog.lookupFunction` to use the function they registered previously. This is currently still dead code, however.

## How was this patch tested?

`SessionCatalogSuite`.

Author: Andrew Or <[email protected]>

Closes #11972 from andrewor14/temp-functions.
  • Loading branch information
Andrew Or committed Mar 28, 2016
1 parent 2f98ee6 commit 27aab80
Show file tree
Hide file tree
Showing 12 changed files with 136 additions and 83 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,15 @@ import org.apache.spark.sql.types._
* to resolve attribute references.
*/
object SimpleAnalyzer
extends SimpleAnalyzer(new SimpleCatalystConf(caseSensitiveAnalysis = true))
class SimpleAnalyzer(conf: CatalystConf)
extends Analyzer(new SessionCatalog(new InMemoryCatalog, conf), EmptyFunctionRegistry, conf)
extends SimpleAnalyzer(
EmptyFunctionRegistry,
new SimpleCatalystConf(caseSensitiveAnalysis = true))

class SimpleAnalyzer(functionRegistry: FunctionRegistry, conf: CatalystConf)
extends Analyzer(
new SessionCatalog(new InMemoryCatalog, functionRegistry, conf),
functionRegistry,
conf)

/**
* Provides a logical query plan analyzer, which translates [[UnresolvedAttribute]]s and
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,13 @@ trait FunctionRegistry {

/** Get the [[ExpressionInfo]] of the function registered under the specified name, as an Option. */
def lookupFunction(name: String): Option[ExpressionInfo]

/** Get the [[FunctionBuilder]] of the function registered under the specified name, as an Option. */
def lookupFunctionBuilder(name: String): Option[FunctionBuilder]

/**
* Drop the function registered under the specified name.
*
* @param name name of the function to drop
* @return whether the function existed before the call
*/
def dropFunction(name: String): Boolean

}

class SimpleFunctionRegistry extends FunctionRegistry {
Expand Down Expand Up @@ -76,6 +83,14 @@ class SimpleFunctionRegistry extends FunctionRegistry {
functionBuilders.get(name).map(_._1)
}

/** Return the builder stored under `name`, or `None` when nothing is stored under that name. */
override def lookupFunctionBuilder(name: String): Option[FunctionBuilder] = synchronized {
  // Each entry pairs the expression info with its builder; keep only the builder.
  functionBuilders.get(name).map { case (_, builder) => builder }
}

/** Remove the entry stored under `name` and report whether such an entry was present. */
override def dropFunction(name: String): Boolean = synchronized {
  val removedEntry = functionBuilders.remove(name)
  removedEntry.nonEmpty
}

def copy(): SimpleFunctionRegistry = synchronized {
val registry = new SimpleFunctionRegistry
functionBuilders.iterator.foreach { case (name, (info, builder)) =>
Expand Down Expand Up @@ -106,6 +121,15 @@ object EmptyFunctionRegistry extends FunctionRegistry {
// Unsupported operation: unconditionally throws UnsupportedOperationException.
override def lookupFunction(name: String): Option[ExpressionInfo] = {
throw new UnsupportedOperationException
}

// Unsupported operation: unconditionally throws UnsupportedOperationException.
override def lookupFunctionBuilder(name: String): Option[FunctionBuilder] = {
throw new UnsupportedOperationException
}

// Unsupported operation: unconditionally throws UnsupportedOperationException.
override def dropFunction(name: String): Boolean = {
throw new UnsupportedOperationException
}

}


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@ import scala.collection.mutable
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.{CatalystConf, SimpleCatalystConf}
import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier}
import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, SimpleFunctionRegistry}
import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder
import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionInfo}
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, SubqueryAlias}


Expand All @@ -32,15 +35,22 @@ import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, SubqueryAlias}
*
* This class is not thread-safe.
*/
class SessionCatalog(externalCatalog: ExternalCatalog, conf: CatalystConf) {
class SessionCatalog(
externalCatalog: ExternalCatalog,
functionRegistry: FunctionRegistry,
conf: CatalystConf) {
import ExternalCatalog._

/**
* Auxiliary constructor that delegates to the primary constructor with a
* [[SimpleCatalystConf]] whose case-sensitive analysis flag is set to true.
*/
def this(externalCatalog: ExternalCatalog, functionRegistry: FunctionRegistry) = {
  this(externalCatalog, functionRegistry, new SimpleCatalystConf(caseSensitiveAnalysis = true))
}

// For testing only.
def this(externalCatalog: ExternalCatalog) {
this(externalCatalog, new SimpleCatalystConf(true))
this(externalCatalog, new SimpleFunctionRegistry)
}

protected[this] val tempTables = new mutable.HashMap[String, LogicalPlan]
protected[this] val tempFunctions = new mutable.HashMap[String, CatalogFunction]

// Note: we track current database here because certain operations do not explicitly
// specify the database (e.g. DROP TABLE my_table). In these cases we must first
Expand Down Expand Up @@ -431,6 +441,18 @@ class SessionCatalog(externalCatalog: ExternalCatalog, conf: CatalystConf) {
externalCatalog.alterFunction(db, newFuncDefinition)
}

/**
* Fetch the metadata of a function from the external catalog.
*
* When `name` carries a database, the function is looked up in that database;
* otherwise the lookup falls back to the current database.
*/
def getFunction(name: FunctionIdentifier): CatalogFunction = {
  externalCatalog.getFunction(name.database.getOrElse(currentDb), name.funcName)
}


// ----------------------------------------------------------------
// | Methods that interact with temporary and metastore functions |
// ----------------------------------------------------------------
Expand All @@ -439,14 +461,14 @@ class SessionCatalog(externalCatalog: ExternalCatalog, conf: CatalystConf) {
* Create a temporary function.
* This assumes no database is specified in `funcDefinition`.
*/
def createTempFunction(funcDefinition: CatalogFunction, ignoreIfExists: Boolean): Unit = {
require(funcDefinition.identifier.database.isEmpty,
"attempted to create a temporary function while specifying a database")
val name = funcDefinition.identifier.funcName
if (tempFunctions.contains(name) && !ignoreIfExists) {
def createTempFunction(
name: String,
funcDefinition: FunctionBuilder,
ignoreIfExists: Boolean): Unit = {
if (functionRegistry.lookupFunctionBuilder(name).isDefined && !ignoreIfExists) {
throw new AnalysisException(s"Temporary function '$name' already exists.")
}
tempFunctions.put(name, funcDefinition)
functionRegistry.registerFunction(name, funcDefinition)
}

/**
Expand All @@ -456,11 +478,10 @@ class SessionCatalog(externalCatalog: ExternalCatalog, conf: CatalystConf) {
// Hive has DROP FUNCTION and DROP TEMPORARY FUNCTION. We may want to consolidate
// dropFunction and dropTempFunction.
def dropTempFunction(name: String, ignoreIfNotExists: Boolean): Unit = {
if (!tempFunctions.contains(name) && !ignoreIfNotExists) {
if (!functionRegistry.dropFunction(name) && !ignoreIfNotExists) {
throw new AnalysisException(
s"Temporary function '$name' cannot be dropped because it does not exist!")
}
tempFunctions.remove(name)
}

/**
Expand All @@ -477,42 +498,37 @@ class SessionCatalog(externalCatalog: ExternalCatalog, conf: CatalystConf) {
throw new AnalysisException("rename does not support moving functions across databases")
}
val db = oldName.database.getOrElse(currentDb)
if (oldName.database.isDefined || !tempFunctions.contains(oldName.funcName)) {
val oldBuilder = functionRegistry.lookupFunctionBuilder(oldName.funcName)
if (oldName.database.isDefined || oldBuilder.isEmpty) {
externalCatalog.renameFunction(db, oldName.funcName, newName.funcName)
} else {
val func = tempFunctions(oldName.funcName)
val newFunc = func.copy(identifier = func.identifier.copy(funcName = newName.funcName))
tempFunctions.remove(oldName.funcName)
tempFunctions.put(newName.funcName, newFunc)
val oldExpressionInfo = functionRegistry.lookupFunction(oldName.funcName).get
val newExpressionInfo = new ExpressionInfo(
oldExpressionInfo.getClassName,
newName.funcName,
oldExpressionInfo.getUsage,
oldExpressionInfo.getExtended)
functionRegistry.dropFunction(oldName.funcName)
functionRegistry.registerFunction(newName.funcName, newExpressionInfo, oldBuilder.get)
}
}

/**
* Retrieve the metadata of an existing function.
*
* If a database is specified in `name`, this will return the function in that database.
* If no database is specified, this will first attempt to return a temporary function with
* the same name, then, if that does not exist, return the function in the current database.
* Return an [[Expression]] that represents the specified function, assuming it exists.
* Note: This is currently only used for temporary functions.
*/
def getFunction(name: FunctionIdentifier): CatalogFunction = {
val db = name.database.getOrElse(currentDb)
if (name.database.isDefined || !tempFunctions.contains(name.funcName)) {
externalCatalog.getFunction(db, name.funcName)
} else {
tempFunctions(name.funcName)
}
def lookupFunction(name: String, children: Seq[Expression]): Expression = {
functionRegistry.lookupFunction(name, children)
}

// TODO: implement lookupFunction that returns something from the registry itself

/**
* List all matching functions in the specified database, including temporary functions.
*/
def listFunctions(db: String, pattern: String): Seq[FunctionIdentifier] = {
val dbFunctions =
externalCatalog.listFunctions(db, pattern).map { f => FunctionIdentifier(f, Some(db)) }
val regex = pattern.replaceAll("\\*", ".*").r
val _tempFunctions = tempFunctions.keys.toSeq
val _tempFunctions = functionRegistry.listFunction()
.filter { f => regex.pattern.matcher(f).matches() }
.map { f => FunctionIdentifier(f) }
dbFunctions ++ _tempFunctions
Expand All @@ -521,8 +537,8 @@ class SessionCatalog(externalCatalog: ExternalCatalog, conf: CatalystConf) {
/**
* Return a temporary function. For testing only.
*/
private[catalog] def getTempFunction(name: String): Option[CatalogFunction] = {
tempFunctions.get(name)
private[catalog] def getTempFunction(name: String): Option[FunctionBuilder] = {
functionRegistry.lookupFunctionBuilder(name)
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ trait AnalysisTest extends PlanTest {

private def makeAnalyzer(caseSensitive: Boolean): Analyzer = {
val conf = new SimpleCatalystConf(caseSensitive)
val catalog = new SessionCatalog(new InMemoryCatalog, conf)
val catalog = new SessionCatalog(new InMemoryCatalog, EmptyFunctionRegistry, conf)
catalog.createTempTable("TaBlE", TestRelations.testRelation, overrideIfExists = true)
new Analyzer(catalog, EmptyFunctionRegistry, conf) {
override val extendedResolutionRules = EliminateSubqueryAliases :: Nil
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ import org.apache.spark.sql.types._

class DecimalPrecisionSuite extends PlanTest with BeforeAndAfter {
private val conf = new SimpleCatalystConf(caseSensitiveAnalysis = true)
private val catalog = new SessionCatalog(new InMemoryCatalog, conf)
private val catalog = new SessionCatalog(new InMemoryCatalog, EmptyFunctionRegistry, conf)
private val analyzer = new Analyzer(catalog, EmptyFunctionRegistry, conf)

private val relation = LocalRelation(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.catalog
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier}
import org.apache.spark.sql.catalyst.expressions.{Expression, Literal}
import org.apache.spark.sql.catalyst.plans.logical.{Range, SubqueryAlias}


Expand Down Expand Up @@ -682,20 +683,20 @@ class SessionCatalogSuite extends SparkFunSuite {

test("create temp function") {
val catalog = new SessionCatalog(newBasicCatalog())
val tempFunc1 = newFunc("temp1")
val tempFunc2 = newFunc("temp2")
catalog.createTempFunction(tempFunc1, ignoreIfExists = false)
catalog.createTempFunction(tempFunc2, ignoreIfExists = false)
val tempFunc1 = (e: Seq[Expression]) => e.head
val tempFunc2 = (e: Seq[Expression]) => e.last
catalog.createTempFunction("temp1", tempFunc1, ignoreIfExists = false)
catalog.createTempFunction("temp2", tempFunc2, ignoreIfExists = false)
assert(catalog.getTempFunction("temp1") == Some(tempFunc1))
assert(catalog.getTempFunction("temp2") == Some(tempFunc2))
assert(catalog.getTempFunction("temp3") == None)
val tempFunc3 = (e: Seq[Expression]) => Literal(e.size)
// Temporary function already exists
intercept[AnalysisException] {
catalog.createTempFunction(tempFunc1, ignoreIfExists = false)
catalog.createTempFunction("temp1", tempFunc3, ignoreIfExists = false)
}
// Temporary function is overridden
val tempFunc3 = tempFunc1.copy(className = "something else")
catalog.createTempFunction(tempFunc3, ignoreIfExists = true)
catalog.createTempFunction("temp1", tempFunc3, ignoreIfExists = true)
assert(catalog.getTempFunction("temp1") == Some(tempFunc3))
}

Expand Down Expand Up @@ -725,8 +726,8 @@ class SessionCatalogSuite extends SparkFunSuite {

test("drop temp function") {
val catalog = new SessionCatalog(newBasicCatalog())
val tempFunc = newFunc("func1")
catalog.createTempFunction(tempFunc, ignoreIfExists = false)
val tempFunc = (e: Seq[Expression]) => e.head
catalog.createTempFunction("func1", tempFunc, ignoreIfExists = false)
assert(catalog.getTempFunction("func1") == Some(tempFunc))
catalog.dropTempFunction("func1", ignoreIfNotExists = false)
assert(catalog.getTempFunction("func1") == None)
Expand Down Expand Up @@ -755,20 +756,15 @@ class SessionCatalogSuite extends SparkFunSuite {
}
}

test("get temp function") {
val externalCatalog = newBasicCatalog()
val sessionCatalog = new SessionCatalog(externalCatalog)
val metastoreFunc = externalCatalog.getFunction("db2", "func1")
val tempFunc = newFunc("func1").copy(className = "something weird")
sessionCatalog.createTempFunction(tempFunc, ignoreIfExists = false)
sessionCatalog.setCurrentDatabase("db2")
// If a database is specified, we'll always return the function in that database
assert(sessionCatalog.getFunction(FunctionIdentifier("func1", Some("db2"))) == metastoreFunc)
// If no database is specified, we'll first return temporary functions
assert(sessionCatalog.getFunction(FunctionIdentifier("func1")) == tempFunc)
// Then, if no such temporary function exists, check the current database
sessionCatalog.dropTempFunction("func1", ignoreIfNotExists = false)
assert(sessionCatalog.getFunction(FunctionIdentifier("func1")) == metastoreFunc)
test("lookup temp function") {
val catalog = new SessionCatalog(newBasicCatalog())
val tempFunc1 = (e: Seq[Expression]) => e.head
catalog.createTempFunction("func1", tempFunc1, ignoreIfExists = false)
assert(catalog.lookupFunction("func1", Seq(Literal(1), Literal(2), Literal(3))) == Literal(1))
catalog.dropTempFunction("func1", ignoreIfNotExists = false)
intercept[AnalysisException] {
catalog.lookupFunction("func1", Seq(Literal(1), Literal(2), Literal(3)))
}
}

test("rename function") {
Expand Down Expand Up @@ -813,8 +809,8 @@ class SessionCatalogSuite extends SparkFunSuite {
test("rename temp function") {
val externalCatalog = newBasicCatalog()
val sessionCatalog = new SessionCatalog(externalCatalog)
val tempFunc = newFunc("func1").copy(className = "something weird")
sessionCatalog.createTempFunction(tempFunc, ignoreIfExists = false)
val tempFunc = (e: Seq[Expression]) => e.head
sessionCatalog.createTempFunction("func1", tempFunc, ignoreIfExists = false)
sessionCatalog.setCurrentDatabase("db2")
// If a database is specified, we'll always rename the function in that database
sessionCatalog.renameFunction(
Expand All @@ -825,8 +821,7 @@ class SessionCatalogSuite extends SparkFunSuite {
// If no database is specified, we'll first rename temporary functions
sessionCatalog.createFunction(newFunc("func1", Some("db2")))
sessionCatalog.renameFunction(FunctionIdentifier("func1"), FunctionIdentifier("func4"))
assert(sessionCatalog.getTempFunction("func4") ==
Some(tempFunc.copy(identifier = FunctionIdentifier("func4"))))
assert(sessionCatalog.getTempFunction("func4") == Some(tempFunc))
assert(sessionCatalog.getTempFunction("func1") == None)
assert(externalCatalog.listFunctions("db2", "*").toSet == Set("func1", "func3"))
// Then, if no such temporary function exists, rename the function in the current database
Expand Down Expand Up @@ -858,12 +853,12 @@ class SessionCatalogSuite extends SparkFunSuite {

test("list functions") {
val catalog = new SessionCatalog(newBasicCatalog())
val tempFunc1 = newFunc("func1").copy(className = "march")
val tempFunc2 = newFunc("yes_me").copy(className = "april")
val tempFunc1 = (e: Seq[Expression]) => e.head
val tempFunc2 = (e: Seq[Expression]) => e.last
catalog.createFunction(newFunc("func2", Some("db2")))
catalog.createFunction(newFunc("not_me", Some("db2")))
catalog.createTempFunction(tempFunc1, ignoreIfExists = false)
catalog.createTempFunction(tempFunc2, ignoreIfExists = false)
catalog.createTempFunction("func1", tempFunc1, ignoreIfExists = false)
catalog.createTempFunction("yes_me", tempFunc2, ignoreIfExists = false)
assert(catalog.listFunctions("db1", "*").toSet ==
Set(FunctionIdentifier("func1"),
FunctionIdentifier("yes_me")))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ class BooleanSimplificationSuite extends PlanTest with PredicateHelper {

private val caseInsensitiveConf = new SimpleCatalystConf(false)
private val caseInsensitiveAnalyzer = new Analyzer(
new SessionCatalog(new InMemoryCatalog, caseInsensitiveConf),
new SessionCatalog(new InMemoryCatalog, EmptyFunctionRegistry, caseInsensitiveConf),
EmptyFunctionRegistry,
caseInsensitiveConf)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ import org.apache.spark.sql.catalyst.rules._

class EliminateSortsSuite extends PlanTest {
val conf = new SimpleCatalystConf(caseSensitiveAnalysis = true, orderByOrdinal = false)
val catalog = new SessionCatalog(new InMemoryCatalog, conf)
val catalog = new SessionCatalog(new InMemoryCatalog, EmptyFunctionRegistry, conf)
val analyzer = new Analyzer(catalog, EmptyFunctionRegistry, conf)

object Optimize extends RuleExecutor[LogicalPlan] {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,14 +44,14 @@ private[sql] class SessionState(ctx: SQLContext) {
lazy val experimentalMethods = new ExperimentalMethods

/**
* Internal catalog for managing table and database states.
* Internal catalog for managing functions registered by the user.
*/
lazy val catalog = new SessionCatalog(ctx.externalCatalog, conf)
lazy val functionRegistry: FunctionRegistry = FunctionRegistry.builtin.copy()

/**
* Internal catalog for managing functions registered by the user.
* Internal catalog for managing table and database states.
*/
lazy val functionRegistry: FunctionRegistry = FunctionRegistry.builtin.copy()
lazy val catalog = new SessionCatalog(ctx.externalCatalog, functionRegistry, conf)

/**
* Interface exposed to the user for registering user-defined functions.
Expand Down
Loading

0 comments on commit 27aab80

Please sign in to comment.