Skip to content

Commit

Permalink
[SPARK-37867][SQL][FOLLOWUP] Compile aggregate functions for build-in…
Browse files Browse the repository at this point in the history
… DB2 dialect

### What changes were proposed in this pull request?
This PR follows up apache#35166.
The previously referenced DB2 documentation is incorrect, resulting in the lack of compile that supports some aggregate functions.

The correct documentation is https://www.ibm.com/docs/en/db2/11.5?topic=af-regression-functions-regr-avgx-regr-avgy-regr-count

### Why are the changes needed?
Make build-in DB2 dialect support complete aggregate push-down more aggregate functions.

### Does this PR introduce _any_ user-facing change?
'Yes'.
Users could use complete aggregate push-down with build-in DB2 dialect.

### How was this patch tested?
New tests.

Closes apache#35520 from beliefer/SPARK-37867_followup.

Authored-by: Jiaan Geng <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
  • Loading branch information
beliefer authored and chenzhx committed Feb 22, 2022
1 parent a4e8813 commit 4ff7e30
Show file tree
Hide file tree
Showing 11 changed files with 123 additions and 83 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -97,4 +97,13 @@ class DB2IntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JDBCTest {
override def caseConvert(tableName: String): String = tableName.toUpperCase(Locale.ROOT)

testVarPop()
testVarPop(true)
testVarSamp()
testVarSamp(true)
testStddevPop()
testStddevPop(true)
testStddevSamp()
testStddevSamp(true)
testCovarPop()
testCovarSamp()
}
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,11 @@ class MsSqlServerIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JD
}

testVarPop()
testVarPop(true)
testVarSamp()
testVarSamp(true)
testStddevPop()
testStddevPop(true)
testStddevSamp()
testStddevSamp(true)
}
Original file line number Diff line number Diff line change
Expand Up @@ -91,10 +91,17 @@ class PostgresIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JDBCT
override def indexOptions: String = "FILLFACTOR=70"

testVarPop()
testVarPop(true)
testVarSamp()
testVarSamp(true)
testStddevPop()
testStddevPop(true)
testStddevSamp()
testStddevSamp(true)
testCovarPop()
testCovarPop(true)
testCovarSamp()
testCovarSamp(true)
testCorr()
testCorr(true)
}
Original file line number Diff line number Diff line change
Expand Up @@ -386,10 +386,11 @@ private[v2] trait V2JDBCTest extends SharedSparkSession with DockerIntegrationFu

protected def caseConvert(tableName: String): String = tableName

protected def testVarPop(): Unit = {
test(s"scan with aggregate push-down: VAR_POP") {
val df = sql(s"SELECT VAR_POP(bonus) FROM $catalogAndNamespace.${caseConvert("employee")}" +
" WHERE dept > 0 GROUP BY dept ORDER BY dept")
protected def testVarPop(isDistinct: Boolean = false): Unit = {
val distinct = if (isDistinct) "DISTINCT " else ""
test(s"scan with aggregate push-down: VAR_POP with distinct: $isDistinct") {
val df = sql(s"SELECT VAR_POP(${distinct}bonus) FROM $catalogAndNamespace." +
s"${caseConvert("employee")} WHERE dept > 0 GROUP BY dept ORDER BY dept")
checkFilterPushed(df)
checkAggregateRemoved(df)
checkAggregatePushed(df, "VAR_POP")
Expand All @@ -401,11 +402,12 @@ private[v2] trait V2JDBCTest extends SharedSparkSession with DockerIntegrationFu
}
}

protected def testVarSamp(): Unit = {
test(s"scan with aggregate push-down: VAR_SAMP") {
protected def testVarSamp(isDistinct: Boolean = false): Unit = {
val distinct = if (isDistinct) "DISTINCT " else ""
test(s"scan with aggregate push-down: VAR_SAMP with distinct: $isDistinct") {
val df = sql(
s"SELECT VAR_SAMP(bonus) FROM $catalogAndNamespace.${caseConvert("employee")}" +
" WHERE dept > 0 GROUP BY dept ORDER BY dept")
s"SELECT VAR_SAMP(${distinct}bonus) FROM $catalogAndNamespace." +
s"${caseConvert("employee")} WHERE dept > 0 GROUP BY dept ORDER BY dept")
checkFilterPushed(df)
checkAggregateRemoved(df)
checkAggregatePushed(df, "VAR_SAMP")
Expand All @@ -417,11 +419,12 @@ private[v2] trait V2JDBCTest extends SharedSparkSession with DockerIntegrationFu
}
}

protected def testStddevPop(): Unit = {
test("scan with aggregate push-down: STDDEV_POP") {
protected def testStddevPop(isDistinct: Boolean = false): Unit = {
val distinct = if (isDistinct) "DISTINCT " else ""
test(s"scan with aggregate push-down: STDDEV_POP with distinct: $isDistinct") {
val df = sql(
s"SELECT STDDEV_POP(bonus) FROM $catalogAndNamespace.${caseConvert("employee")}" +
" WHERE dept > 0 GROUP BY dept ORDER BY dept")
s"SELECT STDDEV_POP(${distinct}bonus) FROM $catalogAndNamespace." +
s"${caseConvert("employee")} WHERE dept > 0 GROUP BY dept ORDER BY dept")
checkFilterPushed(df)
checkAggregateRemoved(df)
checkAggregatePushed(df, "STDDEV_POP")
Expand All @@ -433,11 +436,12 @@ private[v2] trait V2JDBCTest extends SharedSparkSession with DockerIntegrationFu
}
}

protected def testStddevSamp(): Unit = {
test("scan with aggregate push-down: STDDEV_SAMP") {
protected def testStddevSamp(isDistinct: Boolean = false): Unit = {
val distinct = if (isDistinct) "DISTINCT " else ""
test(s"scan with aggregate push-down: STDDEV_SAMP with distinct: $isDistinct") {
val df = sql(
s"SELECT STDDEV_SAMP(bonus) FROM $catalogAndNamespace.${caseConvert("employee")}" +
" WHERE dept > 0 GROUP BY dept ORDER BY dept")
s"SELECT STDDEV_SAMP(${distinct}bonus) FROM $catalogAndNamespace." +
s"${caseConvert("employee")} WHERE dept > 0 GROUP BY dept ORDER BY dept")
checkFilterPushed(df)
checkAggregateRemoved(df)
checkAggregatePushed(df, "STDDEV_SAMP")
Expand All @@ -449,11 +453,12 @@ private[v2] trait V2JDBCTest extends SharedSparkSession with DockerIntegrationFu
}
}

protected def testCovarPop(): Unit = {
test("scan with aggregate push-down: COVAR_POP") {
protected def testCovarPop(isDistinct: Boolean = false): Unit = {
val distinct = if (isDistinct) "DISTINCT " else ""
test(s"scan with aggregate push-down: COVAR_POP with distinct: $isDistinct") {
val df = sql(
s"SELECT COVAR_POP(bonus, bonus) FROM $catalogAndNamespace.${caseConvert("employee")}" +
" WHERE dept > 0 GROUP BY dept ORDER BY dept")
s"SELECT COVAR_POP(${distinct}bonus, bonus) FROM $catalogAndNamespace." +
s"${caseConvert("employee")} WHERE dept > 0 GROUP BY dept ORDER BY dept")
checkFilterPushed(df)
checkAggregateRemoved(df)
checkAggregatePushed(df, "COVAR_POP")
Expand All @@ -465,11 +470,12 @@ private[v2] trait V2JDBCTest extends SharedSparkSession with DockerIntegrationFu
}
}

protected def testCovarSamp(): Unit = {
test("scan with aggregate push-down: COVAR_SAMP") {
protected def testCovarSamp(isDistinct: Boolean = false): Unit = {
val distinct = if (isDistinct) "DISTINCT " else ""
test(s"scan with aggregate push-down: COVAR_SAMP with distinct: $isDistinct") {
val df = sql(
s"SELECT COVAR_SAMP(bonus, bonus) FROM $catalogAndNamespace.${caseConvert("employee")}" +
" WHERE dept > 0 GROUP BY dept ORDER BY dept")
s"SELECT COVAR_SAMP(${distinct}bonus, bonus) FROM $catalogAndNamespace." +
s"${caseConvert("employee")} WHERE dept > 0 GROUP BY dept ORDER BY dept")
checkFilterPushed(df)
checkAggregateRemoved(df)
checkAggregatePushed(df, "COVAR_SAMP")
Expand All @@ -481,11 +487,12 @@ private[v2] trait V2JDBCTest extends SharedSparkSession with DockerIntegrationFu
}
}

protected def testCorr(): Unit = {
test("scan with aggregate push-down: CORR") {
protected def testCorr(isDistinct: Boolean = false): Unit = {
val distinct = if (isDistinct) "DISTINCT " else ""
test(s"scan with aggregate push-down: CORR with distinct: $isDistinct") {
val df = sql(
s"SELECT CORR(bonus, bonus) FROM $catalogAndNamespace.${caseConvert("employee")}" +
" WHERE dept > 0 GROUP BY dept ORDER BY dept")
s"SELECT CORR(${distinct}bonus, bonus) FROM $catalogAndNamespace." +
s"${caseConvert("employee")} WHERE dept > 0 GROUP BY dept ORDER BY dept")
checkFilterPushed(df)
checkAggregateRemoved(df)
checkAggregatePushed(df, "CORR")
Expand Down
19 changes: 19 additions & 0 deletions sql/core/src/main/scala/org/apache/spark/sql/jdbc/DB2Dialect.scala
Original file line number Diff line number Diff line change
Expand Up @@ -30,13 +30,32 @@ private object DB2Dialect extends JdbcDialect {
override def canHandle(url: String): Boolean =
url.toLowerCase(Locale.ROOT).startsWith("jdbc:db2")

// See https://www.ibm.com/docs/en/db2/11.5?topic=functions-aggregate
override def compileAggregate(aggFunction: AggregateFunc): Option[String] = {
super.compileAggregate(aggFunction).orElse(
aggFunction match {
case f: GeneralAggregateFunc if f.name() == "VAR_POP" =>
assert(f.inputs().length == 1)
val distinct = if (f.isDistinct) "DISTINCT " else ""
Some(s"VARIANCE($distinct${f.inputs().head})")
case f: GeneralAggregateFunc if f.name() == "VAR_SAMP" =>
assert(f.inputs().length == 1)
val distinct = if (f.isDistinct) "DISTINCT " else ""
Some(s"VARIANCE_SAMP($distinct${f.inputs().head})")
case f: GeneralAggregateFunc if f.name() == "STDDEV_POP" =>
assert(f.inputs().length == 1)
val distinct = if (f.isDistinct) "DISTINCT " else ""
Some(s"STDDEV($distinct${f.inputs().head})")
case f: GeneralAggregateFunc if f.name() == "STDDEV_SAMP" =>
assert(f.inputs().length == 1)
val distinct = if (f.isDistinct) "DISTINCT " else ""
Some(s"STDDEV_SAMP($distinct${f.inputs().head})")
case f: GeneralAggregateFunc if f.name() == "COVAR_POP" && f.isDistinct == false =>
assert(f.inputs().length == 2)
Some(s"COVARIANCE(${f.inputs().head}, ${f.inputs().last})")
case f: GeneralAggregateFunc if f.name() == "COVAR_SAMP" && f.isDistinct == false =>
assert(f.inputs().length == 2)
Some(s"COVARIANCE_SAMP(${f.inputs().head}, ${f.inputs().last})")
case _ => None
}
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,25 +30,22 @@ private object DerbyDialect extends JdbcDialect {
override def canHandle(url: String): Boolean =
url.toLowerCase(Locale.ROOT).startsWith("jdbc:derby")

// See https://db.apache.org/derby/docs/10.15/ref/index.html
override def compileAggregate(aggFunction: AggregateFunc): Option[String] = {
super.compileAggregate(aggFunction).orElse(
aggFunction match {
case f: GeneralAggregateFunc if f.name() == "VAR_POP" =>
case f: GeneralAggregateFunc if f.name() == "VAR_POP" && f.isDistinct == false =>
assert(f.inputs().length == 1)
val distinct = if (f.isDistinct) "DISTINCT " else ""
Some(s"VAR_POP($distinct${f.inputs().head})")
case f: GeneralAggregateFunc if f.name() == "VAR_SAMP" =>
Some(s"VAR_POP(${f.inputs().head})")
case f: GeneralAggregateFunc if f.name() == "VAR_SAMP" && f.isDistinct == false =>
assert(f.inputs().length == 1)
val distinct = if (f.isDistinct) "DISTINCT " else ""
Some(s"VAR_SAMP($distinct${f.inputs().head})")
case f: GeneralAggregateFunc if f.name() == "STDDEV_POP" =>
Some(s"VAR_SAMP(${f.inputs().head})")
case f: GeneralAggregateFunc if f.name() == "STDDEV_POP" && f.isDistinct == false =>
assert(f.inputs().length == 1)
val distinct = if (f.isDistinct) "DISTINCT " else ""
Some(s"STDDEV_POP($distinct${f.inputs().head})")
case f: GeneralAggregateFunc if f.name() == "STDDEV_SAMP" =>
Some(s"STDDEV_POP(${f.inputs().head})")
case f: GeneralAggregateFunc if f.name() == "STDDEV_SAMP" && f.isDistinct == false =>
assert(f.inputs().length == 1)
val distinct = if (f.isDistinct) "DISTINCT " else ""
Some(s"STDDEV_SAMP($distinct${f.inputs().head})")
Some(s"STDDEV_SAMP(${f.inputs().head})")
case _ => None
}
)
Expand All @@ -72,7 +69,7 @@ private object DerbyDialect extends JdbcDialect {

override def isCascadingTruncateTable(): Option[Boolean] = Some(false)

// See https://db.apache.org/derby/docs/10.5/ref/rrefsqljrenametablestatement.html
// See https://db.apache.org/derby/docs/10.15/ref/rrefsqljrenametablestatement.html
override def renameTable(oldTable: String, newTable: String): String = {
s"RENAME TABLE $oldTable TO $newTable"
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,9 @@ private object MsSqlServerDialect extends JdbcDialect {
override def canHandle(url: String): Boolean =
url.toLowerCase(Locale.ROOT).startsWith("jdbc:sqlserver")

// scalastyle:off line.size.limit
// See https://docs.microsoft.com/en-us/sql/t-sql/functions/aggregate-functions-transact-sql?view=sql-server-ver15
// scalastyle:on line.size.limit
override def compileAggregate(aggFunction: AggregateFunc): Option[String] = {
super.compileAggregate(aggFunction).orElse(
aggFunction match {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,25 +38,22 @@ private case object MySQLDialect extends JdbcDialect with SQLConfHelper {
override def canHandle(url : String): Boolean =
url.toLowerCase(Locale.ROOT).startsWith("jdbc:mysql")

// See https://dev.mysql.com/doc/refman/8.0/en/aggregate-functions.html
override def compileAggregate(aggFunction: AggregateFunc): Option[String] = {
super.compileAggregate(aggFunction).orElse(
aggFunction match {
case f: GeneralAggregateFunc if f.name() == "VAR_POP" =>
case f: GeneralAggregateFunc if f.name() == "VAR_POP" && f.isDistinct == false =>
assert(f.inputs().length == 1)
val distinct = if (f.isDistinct) "DISTINCT " else ""
Some(s"VAR_POP($distinct${f.inputs().head})")
case f: GeneralAggregateFunc if f.name() == "VAR_SAMP" =>
Some(s"VAR_POP(${f.inputs().head})")
case f: GeneralAggregateFunc if f.name() == "VAR_SAMP" && f.isDistinct == false =>
assert(f.inputs().length == 1)
val distinct = if (f.isDistinct) "DISTINCT " else ""
Some(s"VAR_SAMP($distinct${f.inputs().head})")
case f: GeneralAggregateFunc if f.name() == "STDDEV_POP" =>
Some(s"VAR_SAMP(${f.inputs().head})")
case f: GeneralAggregateFunc if f.name() == "STDDEV_POP" && f.isDistinct == false =>
assert(f.inputs().length == 1)
val distinct = if (f.isDistinct) "DISTINCT " else ""
Some(s"STDDEV_POP($distinct${f.inputs().head})")
case f: GeneralAggregateFunc if f.name() == "STDDEV_SAMP" =>
Some(s"STDDEV_POP(${f.inputs().head})")
case f: GeneralAggregateFunc if f.name() == "STDDEV_SAMP" && f.isDistinct == false =>
assert(f.inputs().length == 1)
val distinct = if (f.isDistinct) "DISTINCT " else ""
Some(s"STDDEV_SAMP($distinct${f.inputs().head})")
Some(s"STDDEV_SAMP(${f.inputs().head})")
case _ => None
}
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,37 +34,33 @@ private case object OracleDialect extends JdbcDialect {
override def canHandle(url: String): Boolean =
url.toLowerCase(Locale.ROOT).startsWith("jdbc:oracle")

// scalastyle:off line.size.limit
// https://docs.oracle.com/en/database/oracle/oracle-database/19/sqlrf/Aggregate-Functions.html#GUID-62BE676B-AF18-4E63-BD14-25206FEA0848
// scalastyle:on line.size.limit
override def compileAggregate(aggFunction: AggregateFunc): Option[String] = {
super.compileAggregate(aggFunction).orElse(
aggFunction match {
case f: GeneralAggregateFunc if f.name() == "VAR_POP" =>
case f: GeneralAggregateFunc if f.name() == "VAR_POP" && f.isDistinct == false =>
assert(f.inputs().length == 1)
val distinct = if (f.isDistinct) "DISTINCT " else ""
Some(s"VAR_POP($distinct${f.inputs().head})")
case f: GeneralAggregateFunc if f.name() == "VAR_SAMP" =>
Some(s"VAR_POP(${f.inputs().head})")
case f: GeneralAggregateFunc if f.name() == "VAR_SAMP" && f.isDistinct == false =>
assert(f.inputs().length == 1)
val distinct = if (f.isDistinct) "DISTINCT " else ""
Some(s"VAR_SAMP($distinct${f.inputs().head})")
case f: GeneralAggregateFunc if f.name() == "STDDEV_POP" =>
Some(s"VAR_SAMP(${f.inputs().head})")
case f: GeneralAggregateFunc if f.name() == "STDDEV_POP" && f.isDistinct == false =>
assert(f.inputs().length == 1)
val distinct = if (f.isDistinct) "DISTINCT " else ""
Some(s"STDDEV_POP($distinct${f.inputs().head})")
case f: GeneralAggregateFunc if f.name() == "STDDEV_SAMP" =>
Some(s"STDDEV_POP(${f.inputs().head})")
case f: GeneralAggregateFunc if f.name() == "STDDEV_SAMP" && f.isDistinct == false =>
assert(f.inputs().length == 1)
val distinct = if (f.isDistinct) "DISTINCT " else ""
Some(s"STDDEV_SAMP($distinct${f.inputs().head})")
case f: GeneralAggregateFunc if f.name() == "COVAR_POP" =>
Some(s"STDDEV_SAMP(${f.inputs().head})")
case f: GeneralAggregateFunc if f.name() == "COVAR_POP" && f.isDistinct == false =>
assert(f.inputs().length == 2)
val distinct = if (f.isDistinct) "DISTINCT " else ""
Some(s"COVAR_POP($distinct${f.inputs().head}, ${f.inputs().last})")
case f: GeneralAggregateFunc if f.name() == "COVAR_SAMP" =>
Some(s"COVAR_POP(${f.inputs().head}, ${f.inputs().last})")
case f: GeneralAggregateFunc if f.name() == "COVAR_SAMP" && f.isDistinct == false =>
assert(f.inputs().length == 2)
val distinct = if (f.isDistinct) "DISTINCT " else ""
Some(s"COVAR_SAMP($distinct${f.inputs().head}, ${f.inputs().last})")
case f: GeneralAggregateFunc if f.name() == "CORR" =>
Some(s"COVAR_SAMP(${f.inputs().head}, ${f.inputs().last})")
case f: GeneralAggregateFunc if f.name() == "CORR" && f.isDistinct == false =>
assert(f.inputs().length == 2)
val distinct = if (f.isDistinct) "DISTINCT " else ""
Some(s"CORR($distinct${f.inputs().head}, ${f.inputs().last})")
Some(s"CORR(${f.inputs().head}, ${f.inputs().last})")
case _ => None
}
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ private object PostgresDialect extends JdbcDialect with SQLConfHelper {
override def canHandle(url: String): Boolean =
url.toLowerCase(Locale.ROOT).startsWith("jdbc:postgresql")

// See https://www.postgresql.org/docs/8.4/functions-aggregate.html
override def compileAggregate(aggFunction: AggregateFunc): Option[String] = {
super.compileAggregate(aggFunction).orElse(
aggFunction match {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@ private case object TeradataDialect extends JdbcDialect {
override def canHandle(url: String): Boolean =
url.toLowerCase(Locale.ROOT).startsWith("jdbc:teradata")

// scalastyle:off line.size.limit
// See https://docs.teradata.com/r/Teradata-VantageTM-SQL-Functions-Expressions-and-Predicates/March-2019/Aggregate-Functions
// scalastyle:on line.size.limit
override def compileAggregate(aggFunction: AggregateFunc): Option[String] = {
super.compileAggregate(aggFunction).orElse(
aggFunction match {
Expand All @@ -47,18 +50,15 @@ private case object TeradataDialect extends JdbcDialect {
assert(f.inputs().length == 1)
val distinct = if (f.isDistinct) "DISTINCT " else ""
Some(s"STDDEV_SAMP($distinct${f.inputs().head})")
case f: GeneralAggregateFunc if f.name() == "COVAR_POP" =>
case f: GeneralAggregateFunc if f.name() == "COVAR_POP" && f.isDistinct == false =>
assert(f.inputs().length == 2)
val distinct = if (f.isDistinct) "DISTINCT " else ""
Some(s"COVAR_POP($distinct${f.inputs().head}, ${f.inputs().last})")
case f: GeneralAggregateFunc if f.name() == "COVAR_SAMP" =>
Some(s"COVAR_POP(${f.inputs().head}, ${f.inputs().last})")
case f: GeneralAggregateFunc if f.name() == "COVAR_SAMP" && f.isDistinct == false =>
assert(f.inputs().length == 2)
val distinct = if (f.isDistinct) "DISTINCT " else ""
Some(s"COVAR_SAMP($distinct${f.inputs().head}, ${f.inputs().last})")
case f: GeneralAggregateFunc if f.name() == "CORR" =>
Some(s"COVAR_SAMP(${f.inputs().head}, ${f.inputs().last})")
case f: GeneralAggregateFunc if f.name() == "CORR" && f.isDistinct == false =>
assert(f.inputs().length == 2)
val distinct = if (f.isDistinct) "DISTINCT " else ""
Some(s"CORR($distinct${f.inputs().head}, ${f.inputs().last})")
Some(s"CORR(${f.inputs().head}, ${f.inputs().last})")
case _ => None
}
)
Expand Down

0 comments on commit 4ff7e30

Please sign in to comment.