Skip to content

Commit

Permalink
[SPARK-49488][SQL] Improve the DS V2 pushdown framework for DayOfWeek and WeekDay
Browse files Browse the repository at this point in the history
  • Loading branch information
beliefer committed Sep 2, 2024
1 parent 30152d0 commit eeca7e3
Show file tree
Hide file tree
Showing 11 changed files with 65 additions and 23 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,8 @@ class DB2IntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JDBCTest {
"scan with aggregate push-down: REGR_INTERCEPT with DISTINCT",
"scan with aggregate push-down: REGR_SLOPE with DISTINCT",
"scan with aggregate push-down: REGR_R2 with DISTINCT",
"scan with aggregate push-down: REGR_SXY with DISTINCT")
"scan with aggregate push-down: REGR_SXY with DISTINCT",
"scan with filter push-down with date time functions")

override val catalogName: String = "db2"
override val namespaceOpt: Option[String] = Some("DB2INST1")
Expand All @@ -68,6 +69,9 @@ class DB2IntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JDBCTest {
|)
""".stripMargin
).executeUpdate()
connection.prepareStatement(
"CREATE TABLE datetime (name VARCHAR(32), date1 DATE, time1 TIMESTAMP)")
.executeUpdate()
}

override def testUpdateColumnType(tbl: String): Unit = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,11 @@ abstract class DockerJDBCIntegrationV2Suite extends DockerJDBCIntegrationSuite {
connection.prepareStatement("INSERT INTO pattern_testing_table "
+ "VALUES ('special_character_underscorenot_present')")
.executeUpdate()

connection.prepareStatement("INSERT INTO datetime VALUES " +
"('amy', '2022-05-19', '2022-05-19 00:00:00')").executeUpdate()
connection.prepareStatement("INSERT INTO datetime VALUES " +
"('alex', '2022-05-18', '2022-05-18 00:00:00')").executeUpdate()
}

def tablePreparation(connection: Connection): Unit
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,8 @@ class MsSqlServerIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JD
"scan with aggregate push-down: REGR_R2 with DISTINCT",
"scan with aggregate push-down: REGR_R2 without DISTINCT",
"scan with aggregate push-down: REGR_SXY with DISTINCT",
"scan with aggregate push-down: REGR_SXY without DISTINCT")
"scan with aggregate push-down: REGR_SXY without DISTINCT",
"scan with filter push-down with date time functions")

override val catalogName: String = "mssql"
override val db = new MsSQLServerDatabaseOnDocker
Expand All @@ -76,6 +77,9 @@ class MsSqlServerIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JD
|)
""".stripMargin
).executeUpdate()
connection.prepareStatement(
"CREATE TABLE datetime (name VARCHAR(32), date1 DATE, time1 TIMESTAMP)")
.executeUpdate()
}

override def notSupportsTableComment: Boolean = true
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,8 +77,11 @@ class MySQLIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JDBCTest
s"""CREATE TABLE pattern_testing_table (
|pattern_testing_col LONGTEXT
|)
""".stripMargin
|""".stripMargin
).executeUpdate()
connection.prepareStatement(
"CREATE TABLE datetime (name VARCHAR(32), date1 DATE, time1 TIMESTAMP)")
.executeUpdate()
}

override def testUpdateColumnType(tbl: String): Unit = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,8 @@ class OracleIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JDBCTes
"scan with aggregate push-down: REGR_INTERCEPT with DISTINCT",
"scan with aggregate push-down: REGR_SLOPE with DISTINCT",
"scan with aggregate push-down: REGR_R2 with DISTINCT",
"scan with aggregate push-down: REGR_SXY with DISTINCT")
"scan with aggregate push-down: REGR_SXY with DISTINCT",
"scan with filter push-down with date time functions")

override val catalogName: String = "oracle"
override val namespaceOpt: Option[String] = Some("SYSTEM")
Expand Down Expand Up @@ -99,6 +100,9 @@ class OracleIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JDBCTes
|)
""".stripMargin
).executeUpdate()
connection.prepareStatement(
"CREATE TABLE datetime (name VARCHAR(32), date1 DATE, time1 TIMESTAMP)")
.executeUpdate()
}

override def testUpdateColumnType(tbl: String): Unit = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,10 @@ import org.apache.spark.tags.DockerTest
*/
@DockerTest
class PostgresIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JDBCTest {

override def excluded: Seq[String] = Seq(
"scan with filter push-down with date time functions")

override val catalogName: String = "postgresql"
override val db = new DatabaseOnDocker {
override val imageName = sys.env.getOrElse("POSTGRES_DOCKER_IMAGE_NAME", "postgres:16.3-alpine")
Expand Down Expand Up @@ -65,6 +69,9 @@ class PostgresIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JDBCT
|)
""".stripMargin
).executeUpdate()
connection.prepareStatement(
"CREATE TABLE datetime (name VARCHAR(32), date1 DATE, time1 TIMESTAMP)")
.executeUpdate()
}

override def testUpdateColumnType(tbl: String): Unit = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -980,4 +980,14 @@ private[v2] trait V2JDBCTest extends SharedSparkSession with DockerIntegrationFu
)
}
}

test("scan with filter push-down with date time functions") {
  // Both seeded rows ('amy' -> 2022-05-19, 'alex' -> 2022-05-18) satisfy
  // dayofyear > 100 and dayofmonth > 10, so the filter must be pushed down
  // and both rows returned.
  val df1 = sql(s"SELECT name FROM $catalogAndNamespace.${caseConvert("datetime")} WHERE " +
    "dayofyear(date1) > 100 AND dayofmonth(date1) > 10 ")
  checkFilterPushed(df1)
  val rows = df1.collect()
  assert(rows.length === 2)
  // The query has no ORDER BY, so the JDBC source makes no row-order guarantee;
  // compare the returned names as a set rather than by position to avoid flakiness.
  assert(rows.map(_.getString(0)).toSet === Set("amy", "alex"))
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, Avg,
import org.apache.spark.sql.connector.expressions.filter.{AlwaysFalse, AlwaysTrue, And => V2And, Not => V2Not, Or => V2Or, Predicate => V2Predicate}
import org.apache.spark.sql.execution.datasources.PushableExpression
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{BooleanType, DataType, IntegerType, StringType}
import org.apache.spark.sql.types.{BooleanType, DataType, StringType}

/**
* The builder to generate V2 expressions from catalyst expressions.
Expand Down Expand Up @@ -279,18 +279,10 @@ class V2ExpressionBuilder(e: Expression, isPredicate: Boolean = false) extends L
generateExpression(child).map(v => new V2Extract("QUARTER", v))
case Year(child) =>
generateExpression(child).map(v => new V2Extract("YEAR", v))
// DayOfWeek uses Sunday = 1, Monday = 2, ... and ISO standard is Monday = 1, ...,
// so we use the formula ((ISO_standard % 7) + 1) to do translation.
case DayOfWeek(child) =>
generateExpression(child).map(v => new GeneralScalarExpression("+",
Array[V2Expression](new GeneralScalarExpression("%",
Array[V2Expression](new V2Extract("DAY_OF_WEEK", v), LiteralValue(7, IntegerType))),
LiteralValue(1, IntegerType))))
// WeekDay uses Monday = 0, Tuesday = 1, ... and ISO standard is Monday = 1, ...,
// so we use the formula (ISO_standard - 1) to do translation.
generateExpression(child).map(v => new V2Extract("DAY_OF_WEEK", v))
case WeekDay(child) =>
generateExpression(child).map(v => new GeneralScalarExpression("-",
Array[V2Expression](new V2Extract("DAY_OF_WEEK", v), LiteralValue(1, IntegerType))))
generateExpression(child).map(v => new V2Extract("WEEK_DAY", v))
case DayOfMonth(child) =>
generateExpression(child).map(v => new V2Extract("DAY", v))
case DayOfYear(child) =>
Expand Down
16 changes: 10 additions & 6 deletions sql/core/src/main/scala/org/apache/spark/sql/jdbc/H2Dialect.scala
Original file line number Diff line number Diff line change
Expand Up @@ -273,13 +273,17 @@ private[sql] case class H2Dialect() extends JdbcDialect with NoLegacyJDBCError {
}

override def visitExtract(field: String, source: String): String = {
  // Translate Spark's EXTRACT field names into H2's ISO-prefixed equivalents.
  field match {
    // Spark's DayOfWeek uses Sunday = 1, Monday = 2, ... while H2's
    // ISO_DAY_OF_WEEK is Monday = 1, ..., Sunday = 7; the formula
    // ((ISO % 7) + 1) performs the translation.
    case "DAY_OF_WEEK" => s"(EXTRACT(ISO_DAY_OF_WEEK FROM $source) % 7) + 1"
    // Spark's WeekDay uses Monday = 0, Tuesday = 1, ... while the ISO standard
    // is Monday = 1, ..., so subtract 1 to translate.
    case "WEEK_DAY" => s"EXTRACT(ISO_DAY_OF_WEEK FROM $source) - 1"
    case "WEEK" => s"EXTRACT(ISO_WEEK FROM $source)"
    case "YEAR_OF_WEEK" => s"EXTRACT(ISO_WEEK_YEAR FROM $source)"
    // All other fields are supported by H2's EXTRACT as-is.
    case _ => s"EXTRACT($field FROM $source)"
  }
}

override def visitSQLFunction(funcName: String, inputs: Array[String]): String = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,15 @@ private case class MySQLDialect() extends JdbcDialect with SQLConfHelper with No
supportedFunctions.contains(funcName)

class MySQLSQLBuilder extends JDBCSQLBuilder {
override def visitExtract(field: String, source: String): String = {
  // Map Spark's EXTRACT field names onto MySQL's native date functions where
  // MySQL's EXTRACT(...) does not support the unit directly.
  field match {
    case "DAY_OF_YEAR" => s"DAYOFYEAR($source)"
    // MySQL's EXTRACT has no ISO_DAY_OF_WEEK unit (that is H2 syntax), so the
    // generic EXTRACT form would fail at the database. MySQL's DAYOFWEEK()
    // already returns Sunday = 1, ..., Saturday = 7, which matches Spark's
    // DayOfWeek semantics exactly — no translation arithmetic is needed.
    case "DAY_OF_WEEK" => s"DAYOFWEEK($source)"
    // MySQL's WEEKDAY() returns Monday = 0, ..., Sunday = 6, which matches
    // Spark's WeekDay semantics exactly.
    case "WEEK_DAY" => s"WEEKDAY($source)"
    case _ => super.visitExtract(field, source)
  }
}

override def visitSortOrder(
sortKey: String, sortDirection: SortDirection, nullOrdering: NullOrdering): String = {
(sortDirection, nullOrdering) match {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1579,15 +1579,15 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
"weekday(date1) = 2")
checkFiltersRemoved(df7)
val expectedPlanFragment7 =
"PushedFilters: [DATE1 IS NOT NULL, (EXTRACT(DAY_OF_WEEK FROM DATE1) - 1) = 2]"
"PushedFilters: [DATE1 IS NOT NULL, EXTRACT(WEEK_DAY FROM DATE1) = 2]"
checkPushedInfo(df7, expectedPlanFragment7)
checkAnswer(df7, Seq(Row("alex")))

val df8 = sql("SELECT name FROM h2.test.datetime WHERE " +
"dayofweek(date1) = 4")
checkFiltersRemoved(df8)
val expectedPlanFragment8 =
"PushedFilters: [DATE1 IS NOT NULL, ((EXTRACT(DAY_OF_WEEK FROM DATE1) % 7) + 1) = 4]"
"PushedFilters: [DATE1 IS NOT NULL, EXTRACT(DAY_OF_WEEK FROM DATE1) = 4]"
checkPushedInfo(df8, expectedPlanFragment8)
checkAnswer(df8, Seq(Row("alex")))

Expand Down

0 comments on commit eeca7e3

Please sign in to comment.