[Improve][Spark] Improve the performance of GraphAr Spark Reader (#84)
1 parent 4b002a4, commit d223858
Showing 8 changed files with 567 additions and 76 deletions.
spark/src/main/scala/com/alibaba/graphar/datasources/GarDataSource.scala
59 changes: 59 additions & 0 deletions
@@ -0,0 +1,59 @@
/** Copyright 2022 Alibaba Group Holding Limited.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.alibaba.graphar.datasources

import org.apache.spark.sql.connector.catalog.Table
import org.apache.spark.sql.execution.datasources._
import org.apache.spark.sql.execution.datasources.csv.CSVFileFormat
import org.apache.spark.sql.execution.datasources.orc.OrcFileFormat
import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat
import org.apache.spark.sql.execution.datasources.v2._
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.util.CaseInsensitiveStringMap

/** GarDataSource is a class to provide gar files as the data source for Spark. */
class GarDataSource extends FileDataSourceV2 {

  /** The default fallback file format is Parquet. */
  override def fallbackFileFormat: Class[_ <: FileFormat] = classOf[ParquetFileFormat]

  /** The string that represents the format name. */
  override def shortName(): String = "gar"

  /** Provide a table from the data source. */
  override def getTable(options: CaseInsensitiveStringMap): Table = {
    val paths = getPaths(options)
    val tableName = getTableName(options, paths)
    val optionsWithoutPaths = getOptionsWithoutPaths(options)
    GarTable(tableName, sparkSession, optionsWithoutPaths, paths, None, getFallbackFileFormat(options))
  }

  /** Provide a table from the data source with a specific schema. */
  override def getTable(options: CaseInsensitiveStringMap, schema: StructType): Table = {
    val paths = getPaths(options)
    val tableName = getTableName(options, paths)
    val optionsWithoutPaths = getOptionsWithoutPaths(options)
    GarTable(tableName, sparkSession, optionsWithoutPaths, paths, Some(schema), getFallbackFileFormat(options))
  }

  // Get the actual fallback file format.
  private def getFallbackFileFormat(options: CaseInsensitiveStringMap): Class[_ <: FileFormat] = options.get("fileFormat") match {
    case "csv" => classOf[CSVFileFormat]
    case "orc" => classOf[OrcFileFormat]
    case "parquet" => classOf[ParquetFileFormat]
    case _ => throw new IllegalArgumentException
  }
}
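For orientation, below is a minimal usage sketch of this data source; it is not part of the commit. It assumes the GraphAr Spark jar is on the classpath and that "./vertex_chunks/" points at a directory of Parquet chunk files written by GraphAr; the path and the "fileFormat" value are illustrative only and mirror the cases handled by getFallbackFileFormat() above.

// Minimal usage sketch (assumptions: GraphAr Spark package on the classpath,
// "./vertex_chunks/" holds Parquet chunk files; values are illustrative).
import org.apache.spark.sql.SparkSession

object GarReadExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName("gar-read-example")
      .master("local[*]")
      .getOrCreate()

    // "fileFormat" selects the payload format of the chunk files,
    // matching the cases in getFallbackFileFormat(): "csv", "orc" or "parquet".
    val df = spark.read
      .format("com.alibaba.graphar.datasources.GarDataSource")
      .option("fileFormat", "parquet")
      .load("./vertex_chunks/")

    df.show()
    spark.stop()
  }
}

Using the fully qualified class name avoids relying on the "gar" short name being registered with Spark's data source lookup.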
spark/src/main/scala/com/alibaba/graphar/datasources/GarScan.scala
237 changes: 237 additions & 0 deletions
@@ -0,0 +1,237 @@
/** Copyright 2022 Alibaba Group Holding Limited.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.alibaba.graphar.datasources

import scala.collection.JavaConverters._
import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer

import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.Path
import org.apache.parquet.hadoop.ParquetInputFormat

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.expressions.{Expression, ExprUtils}
import org.apache.spark.sql.catalyst.csv.CSVOptions
import org.apache.spark.sql.connector.read.PartitionReaderFactory
import org.apache.spark.sql.execution.PartitionedFileUtil
import org.apache.spark.sql.execution.datasources.{FilePartition, PartitioningAwareFileIndex, PartitionedFile}
import org.apache.spark.sql.execution.datasources.csv.CSVDataSource
import org.apache.spark.sql.execution.datasources.parquet.{ParquetOptions, ParquetReadSupport, ParquetWriteSupport}
import org.apache.spark.sql.execution.datasources.v2.FileScan
import org.apache.spark.sql.execution.datasources.v2.parquet.ParquetPartitionReaderFactory
import org.apache.spark.sql.execution.datasources.v2.orc.OrcPartitionReaderFactory
import org.apache.spark.sql.execution.datasources.v2.csv.CSVPartitionReaderFactory
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.sources.Filter
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.util.CaseInsensitiveStringMap
import org.apache.spark.util.SerializableConfiguration

/** GarScan is a class to implement the file scan for GarDataSource. */
case class GarScan(
    sparkSession: SparkSession,
    hadoopConf: Configuration,
    fileIndex: PartitioningAwareFileIndex,
    dataSchema: StructType,
    readDataSchema: StructType,
    readPartitionSchema: StructType,
    pushedFilters: Array[Filter],
    options: CaseInsensitiveStringMap,
    formatName: String,
    partitionFilters: Seq[Expression] = Seq.empty,
    dataFilters: Seq[Expression] = Seq.empty) extends FileScan {

  /** The gar format is not splitable. */
  override def isSplitable(path: Path): Boolean = false

  /** Create the reader factory according to the actual file format. */
  override def createReaderFactory(): PartitionReaderFactory = formatName match {
    case "csv" => createCSVReaderFactory()
    case "orc" => createOrcReaderFactory()
    case "parquet" => createParquetReaderFactory()
    case _ => throw new IllegalArgumentException
  }

  // Create the reader factory for the CSV format.
  private def createCSVReaderFactory(): PartitionReaderFactory = {
    val columnPruning = sparkSession.sessionState.conf.csvColumnPruning &&
      !readDataSchema.exists(_.name == sparkSession.sessionState.conf.columnNameOfCorruptRecord)

    val parsedOptions: CSVOptions = new CSVOptions(
      options.asScala.toMap,
      columnPruning = columnPruning,
      sparkSession.sessionState.conf.sessionLocalTimeZone,
      sparkSession.sessionState.conf.columnNameOfCorruptRecord)

    // Check the field requirement for corrupt records here to throw an exception on the driver side.
    ExprUtils.verifyColumnNameOfCorruptRecord(dataSchema, parsedOptions.columnNameOfCorruptRecord)
    // Don't push any filter which refers to the "virtual" column which cannot be present in the input.
    // Such filters will be applied later on the upper layer.
    val actualFilters =
      pushedFilters.filterNot(_.references.contains(parsedOptions.columnNameOfCorruptRecord))

    val caseSensitiveMap = options.asCaseSensitiveMap.asScala.toMap
    // Hadoop Configurations are case sensitive.
    val hadoopConf = sparkSession.sessionState.newHadoopConfWithOptions(caseSensitiveMap)
    val broadcastedConf = sparkSession.sparkContext.broadcast(
      new SerializableConfiguration(hadoopConf))
    // The partition values are already truncated in `FileScan.partitions`.
    // We should use `readPartitionSchema` as the partition schema here.
    CSVPartitionReaderFactory(sparkSession.sessionState.conf, broadcastedConf,
      dataSchema, readDataSchema, readPartitionSchema, parsedOptions, actualFilters)
  }

  // Create the reader factory for the Orc format.
  private def createOrcReaderFactory(): PartitionReaderFactory = {
    val broadcastedConf = sparkSession.sparkContext.broadcast(
      new SerializableConfiguration(hadoopConf))
    // The partition values are already truncated in `FileScan.partitions`.
    // We should use `readPartitionSchema` as the partition schema here.
    OrcPartitionReaderFactory(sparkSession.sessionState.conf, broadcastedConf,
      dataSchema, readDataSchema, readPartitionSchema, pushedFilters)
  }

  // Create the reader factory for the Parquet format.
  private def createParquetReaderFactory(): PartitionReaderFactory = {
    val readDataSchemaAsJson = readDataSchema.json
    hadoopConf.set(ParquetInputFormat.READ_SUPPORT_CLASS, classOf[ParquetReadSupport].getName)
    hadoopConf.set(
      ParquetReadSupport.SPARK_ROW_REQUESTED_SCHEMA,
      readDataSchemaAsJson)
    hadoopConf.set(
      ParquetWriteSupport.SPARK_ROW_SCHEMA,
      readDataSchemaAsJson)
    hadoopConf.set(
      SQLConf.SESSION_LOCAL_TIMEZONE.key,
      sparkSession.sessionState.conf.sessionLocalTimeZone)
    hadoopConf.setBoolean(
      SQLConf.NESTED_SCHEMA_PRUNING_ENABLED.key,
      sparkSession.sessionState.conf.nestedSchemaPruningEnabled)
    hadoopConf.setBoolean(
      SQLConf.CASE_SENSITIVE.key,
      sparkSession.sessionState.conf.caseSensitiveAnalysis)

    ParquetWriteSupport.setSchema(readDataSchema, hadoopConf)

    // Sets flags for `ParquetToSparkSchemaConverter`.
    hadoopConf.setBoolean(
      SQLConf.PARQUET_BINARY_AS_STRING.key,
      sparkSession.sessionState.conf.isParquetBinaryAsString)
    hadoopConf.setBoolean(
      SQLConf.PARQUET_INT96_AS_TIMESTAMP.key,
      sparkSession.sessionState.conf.isParquetINT96AsTimestamp)

    val broadcastedConf = sparkSession.sparkContext.broadcast(
      new SerializableConfiguration(hadoopConf))
    val sqlConf = sparkSession.sessionState.conf
    ParquetPartitionReaderFactory(
      sqlConf,
      broadcastedConf,
      dataSchema,
      readDataSchema,
      readPartitionSchema,
      pushedFilters,
      new ParquetOptions(options.asCaseSensitiveMap.asScala.toMap, sqlConf))
  }

  /**
   * Override "partitions" of org.apache.spark.sql.execution.datasources.v2.FileScan
   * to disable splitting and sort the files by file paths instead of by file sizes.
   * Note: This implementation does not support partition attributes.
   */
  override protected def partitions: Seq[FilePartition] = {
    val selectedPartitions = fileIndex.listFiles(partitionFilters, dataFilters)
    val maxSplitBytes = FilePartition.maxSplitBytes(sparkSession, selectedPartitions)

    val splitFiles = selectedPartitions.flatMap { partition =>
      val partitionValues = partition.values
      partition.files.flatMap { file =>
        val filePath = file.getPath
        PartitionedFileUtil.splitFiles(
          sparkSession = sparkSession,
          file = file,
          filePath = filePath,
          isSplitable = isSplitable(filePath),
          maxSplitBytes = maxSplitBytes,
          partitionValues = partitionValues
        )
      }.toArray.sortBy(_.filePath)
    }

    getFilePartitions(sparkSession, splitFiles)
  }

  /**
   * Override "getFilePartitions" of org.apache.spark.sql.execution.datasources.FilePartition
   * to assign each chunk file in GraphAr to a single partition.
   */
  private def getFilePartitions(
      sparkSession: SparkSession,
      partitionedFiles: Seq[PartitionedFile]): Seq[FilePartition] = {
    val partitions = new ArrayBuffer[FilePartition]
    val currentFiles = new ArrayBuffer[PartitionedFile]

    /** Close the current partition and move to the next. */
    def closePartition(): Unit = {
      if (currentFiles.nonEmpty) {
        // Copy to a new Array.
        val newPartition = FilePartition(partitions.size, currentFiles.toArray)
        partitions += newPartition
      }
      currentFiles.clear()
    }
    // Assign a file to each partition.
    partitionedFiles.foreach { file =>
      closePartition()
      // Add the given file to the current partition.
      currentFiles += file
    }
    closePartition()
    partitions.toSeq
  }

  /** Check if two objects are equal. */
  override def equals(obj: Any): Boolean = obj match {
    case g: GarScan =>
      super.equals(g) && dataSchema == g.dataSchema && options == g.options &&
        equivalentFilters(pushedFilters, g.pushedFilters) && formatName == g.formatName
    case _ => false
  }

  /** Get the hash code of the object. */
  override def hashCode(): Int = formatName match {
    case "csv" => super.hashCode()
    case "orc" => getClass.hashCode()
    case "parquet" => getClass.hashCode()
    case _ => throw new IllegalArgumentException
  }

  /** Get the description string of the object. */
  override def description(): String = {
    super.description() + ", PushedFilters: " + seqToString(pushedFilters)
  }

  /** Get the metadata map of the object. */
  override def getMetaData(): Map[String, String] = {
    super.getMetaData() ++ Map("PushedFilters" -> seqToString(pushedFilters))
  }

  /** Construct the file scan with filters. */
  override def withFilters(
      partitionFilters: Seq[Expression], dataFilters: Seq[Expression]): FileScan =
    this.copy(partitionFilters = partitionFilters, dataFilters = dataFilters)
}
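The overridden partitions and getFilePartitions above carry the performance change: files are ordered by path and each GraphAr chunk file becomes exactly one Spark partition, so partition indices line up with chunk indices instead of depending on file sizes. Continuing the hypothetical sketch from the GarDataSource section (path and option value illustrative, not from this commit), one way to observe the one-file-per-partition behaviour:

// Hypothetical check: with one chunk file per partition, the partition count
// of the resulting DataFrame should equal the number of chunk files on disk.
val chunks = spark.read
  .format("com.alibaba.graphar.datasources.GarDataSource")
  .option("fileFormat", "parquet")
  .load("./vertex_chunks/")
println(s"number of partitions = ${chunks.rdd.getNumPartitions}")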
spark/src/main/scala/com/alibaba/graphar/datasources/GarScanBuilder.scala
65 changes: 65 additions & 0 deletions
@@ -0,0 +1,65 @@
/** Copyright 2022 Alibaba Group Holding Limited.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.alibaba.graphar.datasources

import scala.collection.JavaConverters._

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.connector.read.Scan
import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex
import org.apache.spark.sql.execution.datasources.v2.FileScanBuilder
import org.apache.spark.sql.sources.Filter
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.util.CaseInsensitiveStringMap

/** GarScanBuilder is a class to build the file scan for GarDataSource. */
case class GarScanBuilder(
    sparkSession: SparkSession,
    fileIndex: PartitioningAwareFileIndex,
    schema: StructType,
    dataSchema: StructType,
    options: CaseInsensitiveStringMap,
    formatName: String)
  extends FileScanBuilder(sparkSession, fileIndex, dataSchema) {
  lazy val hadoopConf = {
    val caseSensitiveMap = options.asCaseSensitiveMap.asScala.toMap
    // Hadoop Configurations are case sensitive.
    sparkSession.sessionState.newHadoopConfWithOptions(caseSensitiveMap)
  }

  // Check if the file format supports nested schema pruning.
  override protected val supportsNestedSchemaPruning: Boolean = formatName match {
    case "csv" => false
    case "orc" => true
    case "parquet" => true
    case _ => throw new IllegalArgumentException
  }

  // Note: This scan builder does not implement "with SupportsPushDownFilters".
  private var filters: Array[Filter] = Array.empty

  // Note: To support pushdown filters, these two methods need to be implemented.

  // override def pushFilters(filters: Array[Filter]): Array[Filter]

  // override def pushedFilters(): Array[Filter]

  /** Build the file scan for GarDataSource. */
  override def build(): Scan = {
    GarScan(sparkSession, hadoopConf, fileIndex, dataSchema, readDataSchema(),
      readPartitionSchema(), filters, options, formatName)
  }
}
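The commented-out pushFilters/pushedFilters hooks above are left unimplemented in this commit, so the `filters` array stays empty and no predicates are pushed into the underlying readers. If pushdown were added later, the builder could mix in Spark's standard SupportsPushDownFilters interface; the following is only a rough sketch of that idea, not part of this change, and the trait name is hypothetical.

// Sketch only: a conservative pushdown mixin that accepts every filter and
// also returns them all, so Spark still re-evaluates the predicates after the
// scan. A real implementation would forward the filters into GarScan.
import org.apache.spark.sql.connector.read.SupportsPushDownFilters
import org.apache.spark.sql.sources.Filter

trait GarPushDownSketch extends SupportsPushDownFilters {
  private var _pushedFilters: Array[Filter] = Array.empty

  override def pushFilters(filters: Array[Filter]): Array[Filter] = {
    _pushedFilters = filters
    filters
  }

  override def pushedFilters(): Array[Filter] = _pushedFilters
}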