diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala
index 7a668b75c3c73..40afdbe456fc0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala
@@ -55,21 +55,6 @@ class DataSourceV2Strategy(session: SparkSession) extends Strategy with Predicat
   import DataSourceV2Implicits._
   import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._
 
-  private def withProjectAndFilter(
-      project: Seq[NamedExpression],
-      filters: Seq[Expression],
-      scan: LeafExecNode,
-      needsUnsafeConversion: Boolean): SparkPlan = {
-    val filterCondition = filters.reduceLeftOption(And)
-    val withFilter = filterCondition.map(FilterExec(_, scan)).getOrElse(scan)
-
-    if (withFilter.output != project || needsUnsafeConversion) {
-      ProjectExec(project, withFilter)
-    } else {
-      withFilter
-    }
-  }
-
   private def refreshCache(r: DataSourceV2Relation)(): Unit = {
     session.sharedState.cacheManager.recacheByPlan(session, r)
   }
@@ -130,12 +115,14 @@ class DataSourceV2Strategy(session: SparkSession) extends Strategy with Predicat
         unsafeRowRDD,
         v1Relation,
         tableIdentifier)
-      withProjectAndFilter(project, filters, dsScan, needsUnsafeConversion = false) :: Nil
+      DataSourceV2Strategy.withProjectAndFilter(
+        project, filters, dsScan, needsUnsafeConversion = false) :: Nil
 
     case PhysicalOperation(project, filters,
         DataSourceV2ScanRelation(_, scan: LocalScan, output, _, _)) =>
       val localScanExec = LocalTableScanExec(output, scan.rows().toImmutableArraySeq)
-      withProjectAndFilter(project, filters, localScanExec, needsUnsafeConversion = false) :: Nil
+      DataSourceV2Strategy.withProjectAndFilter(
+        project, filters, localScanExec, needsUnsafeConversion = false) :: Nil
 
     case PhysicalOperation(project, filters, relation: DataSourceV2ScanRelation) =>
       // projection and filters were already pushed down in the optimizer.
@@ -148,7 +135,8 @@ class DataSourceV2Strategy(session: SparkSession) extends Strategy with Predicat
       val batchExec = BatchScanExec(relation.output, relation.scan, runtimeFilters,
        relation.ordering, relation.relation.table,
        StoragePartitionJoinParams(relation.keyGroupedPartitioning))
-      withProjectAndFilter(project, postScanFilters, batchExec, !batchExec.supportsColumnar) :: Nil
+      DataSourceV2Strategy.withProjectAndFilter(
+        project, postScanFilters, batchExec, !batchExec.supportsColumnar) :: Nil
 
     case PhysicalOperation(p, f, r: StreamingDataSourceV2ScanRelation)
       if r.startOffset.isDefined && r.endOffset.isDefined =>
@@ -158,7 +146,7 @@ class DataSourceV2Strategy(session: SparkSession) extends Strategy with Predicat
         r.output, r.scan, microBatchStream, r.startOffset.get, r.endOffset.get)
 
       // Add a Project here to make sure we produce unsafe rows.
-      withProjectAndFilter(p, f, scanExec, !scanExec.supportsColumnar) :: Nil
+      DataSourceV2Strategy.withProjectAndFilter(p, f, scanExec, !scanExec.supportsColumnar) :: Nil
 
     case PhysicalOperation(p, f, r: StreamingDataSourceV2ScanRelation)
       if r.startOffset.isDefined && r.endOffset.isEmpty =>
@@ -167,7 +155,7 @@ class DataSourceV2Strategy(session: SparkSession) extends Strategy with Predicat
       val scanExec = ContinuousScanExec(r.output, r.scan, continuousStream, r.startOffset.get)
 
       // Add a Project here to make sure we produce unsafe rows.
-      withProjectAndFilter(p, f, scanExec, !scanExec.supportsColumnar) :: Nil
+      DataSourceV2Strategy.withProjectAndFilter(p, f, scanExec, !scanExec.supportsColumnar) :: Nil
 
     case WriteToDataSourceV2(relationOpt, writer, query, customMetrics) =>
       val invalidateCacheFunc: () => Unit = () => relationOpt match {
@@ -654,6 +642,21 @@ private[sql] object DataSourceV2Strategy extends Logging {
       logWarning(log"Can't translate ${MDC(EXPR, other)} to source filter, unsupported expression")
       None
   }
+
+  private def withProjectAndFilter(
+      project: Seq[NamedExpression],
+      filters: Seq[Expression],
+      scan: LeafExecNode,
+      needsUnsafeConversion: Boolean): SparkPlan = {
+    val filterCondition = filters.reduceLeftOption(And)
+    val withFilter = filterCondition.map(FilterExec(_, scan)).getOrElse(scan)
+
+    if (withFilter.output != project || needsUnsafeConversion) {
+      ProjectExec(project, withFilter)
+    } else {
+      withFilter
+    }
+  }
 }
 
 /**
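Note: this diff is a pure move. `withProjectAndFilter` relocates, body unchanged, from the `DataSourceV2Strategy` class to its companion object, and every call site is rewritten to the qualified `DataSourceV2Strategy.withProjectAndFilter`. The helper never touched the strategy's `SparkSession`, so making it a companion-object method keeps it `private` yet callable without a strategy instance (and, presumably, reusable from other code in the same file). The Scala visibility rule this relies on is that a class and its companion object can access each other's private members. A minimal sketch of that rule, using hypothetical `Planner`/`withPrefix` names rather than Spark's:

```scala
// Hypothetical names; a minimal sketch of companion-object private access,
// not Spark code.
class Planner(prefix: String) {
  def plan(names: Seq[String]): Seq[String] =
    // A class may call a `private` member of its companion object, so the
    // helper below stays hidden from all other callers.
    names.map(n => Planner.withPrefix(prefix, n))
}

object Planner {
  // Private to the companion pair, and stateless: needs no Planner instance.
  private def withPrefix(prefix: String, name: String): String =
    s"$prefix$name"
}
```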