[SQL] Tighten the visibility of various SQLConf methods and renamed setter/getters #1794

Closed · wants to merge 2 commits
55 changes: 27 additions & 28 deletions sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
@@ -17,17 +17,17 @@

package org.apache.spark.sql

import scala.collection.immutable
import scala.collection.JavaConversions._
Review comment (Contributor): Nit: put these under the java import and switch the order of these two imports.


import java.util.Properties

import scala.collection.JavaConverters._

object SQLConf {
private[spark] object SQLConf {
val COMPRESS_CACHED = "spark.sql.inMemoryColumnarStorage.compressed"
val AUTO_BROADCASTJOIN_THRESHOLD = "spark.sql.autoBroadcastJoinThreshold"
val DEFAULT_SIZE_IN_BYTES = "spark.sql.defaultSizeInBytes"
val AUTO_CONVERT_JOIN_SIZE = "spark.sql.auto.convert.join.size"
val SHUFFLE_PARTITIONS = "spark.sql.shuffle.partitions"
val JOIN_BROADCAST_TABLES = "spark.sql.join.broadcastTables"
val CODEGEN_ENABLED = "spark.sql.codegen"
val DIALECT = "spark.sql.dialect"

@@ -66,13 +66,13 @@ trait SQLConf {
* Note that the choice of dialect does not affect things like what tables are available or
* how query execution is performed.
*/
private[spark] def dialect: String = get(DIALECT, "sql")
private[spark] def dialect: String = getConf(DIALECT, "sql")

/** When true tables cached using the in-memory columnar caching will be compressed. */
private[spark] def useCompression: Boolean = get(COMPRESS_CACHED, "false").toBoolean
private[spark] def useCompression: Boolean = getConf(COMPRESS_CACHED, "false").toBoolean

/** Number of partitions to use for shuffle operators. */
private[spark] def numShufflePartitions: Int = get(SHUFFLE_PARTITIONS, "200").toInt
private[spark] def numShufflePartitions: Int = getConf(SHUFFLE_PARTITIONS, "200").toInt

/**
* When set to true, Spark SQL will use the Scala compiler at runtime to generate custom bytecode
@@ -84,7 +84,7 @@ trait SQLConf {
* Defaults to false as this feature is currently experimental.
*/
private[spark] def codegenEnabled: Boolean =
if (get(CODEGEN_ENABLED, "false") == "true") true else false
if (getConf(CODEGEN_ENABLED, "false") == "true") true else false

/**
* Upper bound on the sizes (in bytes) of the tables qualified for the auto conversion to
@@ -94,49 +94,48 @@ trait SQLConf {
* Hive setting: hive.auto.convert.join.noconditionaltask.size, whose default value is also 10000.
*/
private[spark] def autoBroadcastJoinThreshold: Int =
get(AUTO_BROADCASTJOIN_THRESHOLD, "10000").toInt
getConf(AUTO_BROADCASTJOIN_THRESHOLD, "10000").toInt

/**
* The default size in bytes to assign to a logical operator's estimation statistics. By default,
* it is set to a larger value than `autoConvertJoinSize`, hence any logical operator without a
* properly implemented estimation of this statistic will not be incorrectly broadcasted in joins.
*/
private[spark] def defaultSizeInBytes: Long =
getOption(DEFAULT_SIZE_IN_BYTES).map(_.toLong).getOrElse(autoBroadcastJoinThreshold + 1)
getConf(DEFAULT_SIZE_IN_BYTES, (autoBroadcastJoinThreshold + 1).toString).toLong

/** ********************** SQLConf functionality methods ************ */

def set(props: Properties): Unit = {
settings.synchronized {
props.asScala.foreach { case (k, v) => settings.put(k, v) }
}
/** Set Spark SQL configuration properties. */
def setConf(props: Properties): Unit = settings.synchronized {
props.foreach { case (k, v) => settings.put(k, v) }
}

def set(key: String, value: String): Unit = {
/** Set the given Spark SQL configuration property. */
def setConf(key: String, value: String): Unit = {
require(key != null, "key cannot be null")
require(value != null, s"value cannot be null for key: $key")
settings.put(key, value)
}

def get(key: String): String = {
/** Return the value of Spark SQL configuration property for the given key. */
def getConf(key: String): String = {
Option(settings.get(key)).getOrElse(throw new NoSuchElementException(key))
}

def get(key: String, defaultValue: String): String = {
/**
* Return the value of Spark SQL configuration property for the given key. If the key is not set
* yet, return `defaultValue`.
*/
def getConf(key: String, defaultValue: String): String = {
Option(settings.get(key)).getOrElse(defaultValue)
}

def getAll: Array[(String, String)] = settings.synchronized { settings.asScala.toArray }

def getOption(key: String): Option[String] = Option(settings.get(key))

def contains(key: String): Boolean = settings.containsKey(key)

def toDebugString: String = {
settings.synchronized {
settings.asScala.toArray.sorted.map{ case (k, v) => s"$k=$v" }.mkString("\n")
}
}
/**
* Return all the configuration properties that have been set (i.e. not the default).
* This creates a new copy of the config properties in the form of a Map.
*/
def getAllConfs: immutable.Map[String, String] = settings.synchronized { settings.toMap }

private[spark] def clear() {
settings.clear()
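
For context (not part of the diff): a minimal sketch of how calling code maps onto the renamed accessors, assuming a SQLContext instance named sqlContext; the old set/get calls become setConf/getConf:

// Set a property and read it back through the renamed accessors.
sqlContext.setConf("spark.sql.shuffle.partitions", "10")
val partitions: String = sqlContext.getConf("spark.sql.shuffle.partitions")        // "10"
val fallback: String = sqlContext.getConf("spark.sql.some.unset.key", "default")   // default used for unset keys

// getConf without a default throws NoSuchElementException for unset keys,
// and getAllConfs replaces the old getAll/getOption/contains trio with an immutable Map.
val confs: Map[String, String] = sqlContext.getAllConfs
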
@@ -53,10 +53,10 @@ case class SetCommand(
if (k == SQLConf.Deprecated.MAPRED_REDUCE_TASKS) {
logWarning(s"Property ${SQLConf.Deprecated.MAPRED_REDUCE_TASKS} is deprecated, " +
s"automatically converted to ${SQLConf.SHUFFLE_PARTITIONS} instead.")
context.set(SQLConf.SHUFFLE_PARTITIONS, v)
context.setConf(SQLConf.SHUFFLE_PARTITIONS, v)
Array(s"${SQLConf.SHUFFLE_PARTITIONS}=$v")
} else {
context.set(k, v)
context.setConf(k, v)
Array(s"$k=$v")
}

@@ -77,14 +77,14 @@
"system:sun.java.command=shark.SharkServer2")
}
else {
Array(s"$k=${context.getOption(k).getOrElse("<undefined>")}")
Array(s"$k=${context.getConf(k, "<undefined>")}")
}

// Query all key-value pairs that are set in the SQLConf of the context.
case (None, None) =>
context.getAll.map { case (k, v) =>
context.getAllConfs.map { case (k, v) =>
s"$k=$v"
}
}.toSeq

case _ =>
throw new IllegalArgumentException()
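
To illustrate the SET command semantics above (a hedged sketch in the style of the test suites; context is assumed to be a SQLContext, and the code lives inside the org.apache.spark.sql package so the now private[spark] SQLConf constants are visible):

// A plain SET updates the conf and echoes the assignment.
context.sql("SET spark.sql.shuffle.partitions=20")
assert(context.getConf(SQLConf.SHUFFLE_PARTITIONS) == "20")

// The deprecated Hive key is redirected to the Spark SQL key, with a warning logged.
context.sql(s"SET ${SQLConf.Deprecated.MAPRED_REDUCE_TASKS}=10")
assert(context.getConf(SQLConf.SHUFFLE_PARTITIONS) == "10")

// Asking for a single unset key falls back to "<undefined>" via getConf(key, default).
context.sql("SET spark.sql.nonexistent.key").collect()
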
33 changes: 15 additions & 18 deletions sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala
@@ -29,50 +29,47 @@ class SQLConfSuite extends QueryTest {

test("programmatic ways of basic setting and getting") {
clear()
assert(getOption(testKey).isEmpty)
assert(getAll.toSet === Set())
assert(getAllConfs.size === 0)

set(testKey, testVal)
assert(get(testKey) == testVal)
assert(get(testKey, testVal + "_") == testVal)
assert(getOption(testKey) == Some(testVal))
assert(contains(testKey))
setConf(testKey, testVal)
assert(getConf(testKey) == testVal)
assert(getConf(testKey, testVal + "_") == testVal)
assert(getAllConfs.contains(testKey))

// Tests SQLConf as accessed from a SQLContext is mutable after
// the latter is initialized, unlike SparkConf inside a SparkContext.
assert(TestSQLContext.get(testKey) == testVal)
assert(TestSQLContext.get(testKey, testVal + "_") == testVal)
assert(TestSQLContext.getOption(testKey) == Some(testVal))
assert(TestSQLContext.contains(testKey))
assert(TestSQLContext.getConf(testKey) == testVal)
assert(TestSQLContext.getConf(testKey, testVal + "_") == testVal)
assert(TestSQLContext.getAllConfs.contains(testKey))

clear()
}

test("parse SQL set commands") {
clear()
sql(s"set $testKey=$testVal")
assert(get(testKey, testVal + "_") == testVal)
assert(TestSQLContext.get(testKey, testVal + "_") == testVal)
assert(getConf(testKey, testVal + "_") == testVal)
assert(TestSQLContext.getConf(testKey, testVal + "_") == testVal)

sql("set some.property=20")
assert(get("some.property", "0") == "20")
assert(getConf("some.property", "0") == "20")
sql("set some.property = 40")
assert(get("some.property", "0") == "40")
assert(getConf("some.property", "0") == "40")

val key = "spark.sql.key"
val vs = "val0,val_1,val2.3,my_table"
sql(s"set $key=$vs")
assert(get(key, "0") == vs)
assert(getConf(key, "0") == vs)

sql(s"set $key=")
assert(get(key, "0") == "")
assert(getConf(key, "0") == "")

clear()
}

test("deprecated property") {
clear()
sql(s"set ${SQLConf.Deprecated.MAPRED_REDUCE_TASKS}=10")
assert(get(SQLConf.SHUFFLE_PARTITIONS) == "10")
assert(getConf(SQLConf.SHUFFLE_PARTITIONS) == "10")
}
}
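
One behavioural nuance of the new getAllConfs worth noting (a hedged sketch; sqlContext is an assumed SQLContext): because it returns an immutable copy, it is a point-in-time snapshot rather than a live view of the settings.

val snapshot = sqlContext.getAllConfs              // immutable copy of what is currently set
sqlContext.setConf("some.new.key", "value")
assert(!snapshot.contains("some.new.key"))         // the earlier snapshot is unaffected
assert(sqlContext.getAllConfs.contains("some.new.key"))
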
@@ -60,9 +60,9 @@ class LocalHiveContext(sc: SparkContext) extends HiveContext(sc) {

/** Sets up the system initially or after a RESET command */
protected def configure() {
set("javax.jdo.option.ConnectionURL",
setConf("javax.jdo.option.ConnectionURL",
s"jdbc:derby:;databaseName=$metastorePath;create=true")
set("hive.metastore.warehouse.dir", warehousePath)
setConf("hive.metastore.warehouse.dir", warehousePath)
}

configure() // Must be called before initializing the catalog below.
@@ -76,7 +76,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) {
self =>

// Change the default SQL dialect to HiveQL
override private[spark] def dialect: String = get(SQLConf.DIALECT, "hiveql")
override private[spark] def dialect: String = getConf(SQLConf.DIALECT, "hiveql")

override protected[sql] def executePlan(plan: LogicalPlan): this.QueryExecution =
new this.QueryExecution { val logical = plan }
@@ -224,15 +224,15 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) {
@transient protected[hive] lazy val hiveconf = new HiveConf(classOf[SessionState])
@transient protected[hive] lazy val sessionState = {
val ss = new SessionState(hiveconf)
set(hiveconf.getAllProperties) // Have SQLConf pick up the initial set of HiveConf.
setConf(hiveconf.getAllProperties) // Have SQLConf pick up the initial set of HiveConf.
ss
}

sessionState.err = new PrintStream(outputBuffer, true, "UTF-8")
sessionState.out = new PrintStream(outputBuffer, true, "UTF-8")

override def set(key: String, value: String): Unit = {
super.set(key, value)
override def setConf(key: String, value: String): Unit = {
super.setConf(key, value)
runSqlHive(s"SET $key=$value")
}

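
The overridden setConf above keeps the Hive session in sync by also running a Hive SET statement after updating SQLConf. The same hook can be extended further in a subclass; a minimal, hypothetical sketch (LoggingHiveContext is not part of this change, and it assumes HiveContext exposes Spark's Logging methods as usual):

import org.apache.spark.SparkContext
import org.apache.spark.sql.hive.HiveContext

class LoggingHiveContext(sc: SparkContext) extends HiveContext(sc) {
  // Every configuration change still flows through SQLConf and Hive,
  // with one extra line of logging layered on top.
  override def setConf(key: String, value: String): Unit = {
    logInfo(s"Setting $key=$value")
    super.setConf(key, value)   // updates SQLConf, then runs "SET key=value" against Hive
  }
}
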
@@ -65,9 +65,9 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) {

/** Sets up the system initially or after a RESET command */
protected def configure() {
set("javax.jdo.option.ConnectionURL",
setConf("javax.jdo.option.ConnectionURL",
s"jdbc:derby:;databaseName=$metastorePath;create=true")
set("hive.metastore.warehouse.dir", warehousePath)
setConf("hive.metastore.warehouse.dir", warehousePath)
}

configure() // Must be called before initializing the catalog below.
@@ -75,9 +75,9 @@ class HiveQuerySuite extends HiveComparisonTest {
"SELECT 2 / 1, 1 / 2, 1 / 3, 1 / COUNT(*) FROM src LIMIT 1")

test("Query expressed in SQL") {
set("spark.sql.dialect", "sql")
setConf("spark.sql.dialect", "sql")
assert(sql("SELECT 1").collect() === Array(Seq(1)))
set("spark.sql.dialect", "hiveql")
setConf("spark.sql.dialect", "hiveql")

}

@@ -436,18 +436,18 @@ class HiveQuerySuite extends HiveComparisonTest {
val testVal = "val0,val_1,val2.3,my_table"

sql(s"set $testKey=$testVal")
assert(get(testKey, testVal + "_") == testVal)
assert(getConf(testKey, testVal + "_") == testVal)

sql("set some.property=20")
assert(get("some.property", "0") == "20")
assert(getConf("some.property", "0") == "20")
sql("set some.property = 40")
assert(get("some.property", "0") == "40")
assert(getConf("some.property", "0") == "40")

sql(s"set $testKey=$testVal")
assert(get(testKey, "0") == testVal)
assert(getConf(testKey, "0") == testVal)

sql(s"set $testKey=")
assert(get(testKey, "0") == "")
assert(getConf(testKey, "0") == "")
}

test("SET commands semantics for a HiveContext") {
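
Finally, because the SQLConf object is now private[spark], application code outside the org.apache.spark package can no longer reference constants such as SQLConf.SHUFFLE_PARTITIONS and should use raw string keys with the public setConf/getConf methods (a small hedged sketch, assuming a user-side SQLContext named sqlContext):

// User application code (outside org.apache.spark): string keys only.
sqlContext.setConf("spark.sql.shuffle.partitions", "4")
assert(sqlContext.getConf("spark.sql.shuffle.partitions") == "4")
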