From d223858a9458912388ea6096c49a86be6be1500e Mon Sep 17 00:00:00 2001 From: lixueclaire Date: Wed, 18 Jan 2023 14:50:14 +0800 Subject: [PATCH] [Improve][Spark] Improve the performance of GraphAr Spark Reader (#84) --- spark/pom.xml | 2 +- .../graphar/datasources/GarDataSource.scala | 59 +++++ .../alibaba/graphar/datasources/GarScan.scala | 237 ++++++++++++++++++ .../graphar/datasources/GarScanBuilder.scala | 65 +++++ .../graphar/datasources/GarTable.scala | 104 ++++++++ .../alibaba/graphar/reader/EdgeReader.scala | 98 +++----- .../alibaba/graphar/reader/VertexReader.scala | 23 +- .../com/alibaba/graphar/TestReader.scala | 55 +++- 8 files changed, 567 insertions(+), 76 deletions(-) create mode 100644 spark/src/main/scala/com/alibaba/graphar/datasources/GarDataSource.scala create mode 100644 spark/src/main/scala/com/alibaba/graphar/datasources/GarScan.scala create mode 100644 spark/src/main/scala/com/alibaba/graphar/datasources/GarScanBuilder.scala create mode 100644 spark/src/main/scala/com/alibaba/graphar/datasources/GarTable.scala diff --git a/spark/pom.xml b/spark/pom.xml index d350582a8..efcd231f2 100644 --- a/spark/pom.xml +++ b/spark/pom.xml @@ -15,7 +15,7 @@ 2.12 512m 1024m - 3.2.0 + 3.2.2 8 1.8 1.8 diff --git a/spark/src/main/scala/com/alibaba/graphar/datasources/GarDataSource.scala b/spark/src/main/scala/com/alibaba/graphar/datasources/GarDataSource.scala new file mode 100644 index 000000000..50a4de3ac --- /dev/null +++ b/spark/src/main/scala/com/alibaba/graphar/datasources/GarDataSource.scala @@ -0,0 +1,59 @@ +/** Copyright 2022 Alibaba Group Holding Limited. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.graphar.datasources + +import org.apache.spark.sql.connector.catalog.Table +import org.apache.spark.sql.execution.datasources._ +import org.apache.spark.sql.execution.datasources.csv.CSVFileFormat +import org.apache.spark.sql.execution.datasources.orc.OrcFileFormat +import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat +import org.apache.spark.sql.execution.datasources.v2._ +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.util.CaseInsensitiveStringMap + +/** GarDataSource is a class to provide gar files as the data source for spark. */ +class GarDataSource extends FileDataSourceV2 { + + /** The default fallback file format is Parquet. */ + override def fallbackFileFormat: Class[_ <: FileFormat] = classOf[ParquetFileFormat] + + /** The string that represents the format name. */ + override def shortName(): String = "gar" + + /** Provide a table from the data source. */ + override def getTable(options: CaseInsensitiveStringMap): Table = { + val paths = getPaths(options) + val tableName = getTableName(options, paths) + val optionsWithoutPaths = getOptionsWithoutPaths(options) + GarTable(tableName, sparkSession, optionsWithoutPaths, paths, None, getFallbackFileFormat(options)) + } + + /** Provide a table from the data source with specific schema. 
*/ + override def getTable(options: CaseInsensitiveStringMap, schema: StructType): Table = { + val paths = getPaths(options) + val tableName = getTableName(options, paths) + val optionsWithoutPaths = getOptionsWithoutPaths(options) + GarTable(tableName, sparkSession, optionsWithoutPaths, paths, Some(schema), getFallbackFileFormat(options)) + } + + // Get the actual fall back file format. + private def getFallbackFileFormat(options: CaseInsensitiveStringMap): Class[_ <: FileFormat] = options.get("fileFormat") match { + case "csv" => classOf[CSVFileFormat] + case "orc" => classOf[OrcFileFormat] + case "parquet" => classOf[ParquetFileFormat] + case _ => throw new IllegalArgumentException + } +} diff --git a/spark/src/main/scala/com/alibaba/graphar/datasources/GarScan.scala b/spark/src/main/scala/com/alibaba/graphar/datasources/GarScan.scala new file mode 100644 index 000000000..1d60c4689 --- /dev/null +++ b/spark/src/main/scala/com/alibaba/graphar/datasources/GarScan.scala @@ -0,0 +1,237 @@ +/** Copyright 2022 Alibaba Group Holding Limited. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.graphar.datasources + +import scala.collection.JavaConverters._ +import scala.collection.mutable +import scala.collection.mutable.ArrayBuffer + +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.Path +import org.apache.parquet.hadoop.ParquetInputFormat + +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.expressions.{Expression, ExprUtils} +import org.apache.spark.sql.catalyst.csv.CSVOptions +import org.apache.spark.sql.connector.read.PartitionReaderFactory +import org.apache.spark.sql.execution.PartitionedFileUtil +import org.apache.spark.sql.execution.datasources.{FilePartition, PartitioningAwareFileIndex, PartitionedFile} +import org.apache.spark.sql.execution.datasources.csv.CSVDataSource +import org.apache.spark.sql.execution.datasources.parquet.{ParquetOptions, ParquetReadSupport, ParquetWriteSupport} +import org.apache.spark.sql.execution.datasources.v2.FileScan +import org.apache.spark.sql.execution.datasources.v2.parquet.ParquetPartitionReaderFactory +import org.apache.spark.sql.execution.datasources.v2.orc.OrcPartitionReaderFactory +import org.apache.spark.sql.execution.datasources.v2.csv.CSVPartitionReaderFactory +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.sources.Filter +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.util.CaseInsensitiveStringMap +import org.apache.spark.util.SerializableConfiguration + +/** GarScan is a class to implement the file scan for GarDataSource. 
*/ +case class GarScan( + sparkSession: SparkSession, + hadoopConf: Configuration, + fileIndex: PartitioningAwareFileIndex, + dataSchema: StructType, + readDataSchema: StructType, + readPartitionSchema: StructType, + pushedFilters: Array[Filter], + options: CaseInsensitiveStringMap, + formatName: String, + partitionFilters: Seq[Expression] = Seq.empty, + dataFilters: Seq[Expression] = Seq.empty) extends FileScan { + + /** The gar format is not splitable. */ + override def isSplitable(path: Path): Boolean = false + + /** Create the reader factory according to the actual file format. */ + override def createReaderFactory(): PartitionReaderFactory = formatName match { + case "csv" => createCSVReaderFactory() + case "orc" => createOrcReaderFactory() + case "parquet" => createParquetReaderFactory() + case _ => throw new IllegalArgumentException + } + + // Create the reader factory for the CSV format. + private def createCSVReaderFactory(): PartitionReaderFactory = { + val columnPruning = sparkSession.sessionState.conf.csvColumnPruning && + !readDataSchema.exists(_.name == sparkSession.sessionState.conf.columnNameOfCorruptRecord) + + val parsedOptions: CSVOptions = new CSVOptions( + options.asScala.toMap, + columnPruning = columnPruning, + sparkSession.sessionState.conf.sessionLocalTimeZone, + sparkSession.sessionState.conf.columnNameOfCorruptRecord) + + // Check a field requirement for corrupt records here to throw an exception in a driver side + ExprUtils.verifyColumnNameOfCorruptRecord(dataSchema, parsedOptions.columnNameOfCorruptRecord) + // Don't push any filter which refers to the "virtual" column which cannot present in the input. + // Such filters will be applied later on the upper layer. + val actualFilters = + pushedFilters.filterNot(_.references.contains(parsedOptions.columnNameOfCorruptRecord)) + + val caseSensitiveMap = options.asCaseSensitiveMap.asScala.toMap + // Hadoop Configurations are case sensitive. + val hadoopConf = sparkSession.sessionState.newHadoopConfWithOptions(caseSensitiveMap) + val broadcastedConf = sparkSession.sparkContext.broadcast( + new SerializableConfiguration(hadoopConf)) + // The partition values are already truncated in `FileScan.partitions`. + // We should use `readPartitionSchema` as the partition schema here. + CSVPartitionReaderFactory(sparkSession.sessionState.conf, broadcastedConf, + dataSchema, readDataSchema, readPartitionSchema, parsedOptions, actualFilters) + } + + // Create the reader factory for the Orc format. + private def createOrcReaderFactory(): PartitionReaderFactory = { + val broadcastedConf = sparkSession.sparkContext.broadcast( + new SerializableConfiguration(hadoopConf)) + // The partition values are already truncated in `FileScan.partitions`. + // We should use `readPartitionSchema` as the partition schema here. + OrcPartitionReaderFactory(sparkSession.sessionState.conf, broadcastedConf, + dataSchema, readDataSchema, readPartitionSchema, pushedFilters) + } + + // Create the reader factory for the Parquet format. 
+ private def createParquetReaderFactory(): PartitionReaderFactory = { + val readDataSchemaAsJson = readDataSchema.json + hadoopConf.set(ParquetInputFormat.READ_SUPPORT_CLASS, classOf[ParquetReadSupport].getName) + hadoopConf.set( + ParquetReadSupport.SPARK_ROW_REQUESTED_SCHEMA, + readDataSchemaAsJson) + hadoopConf.set( + ParquetWriteSupport.SPARK_ROW_SCHEMA, + readDataSchemaAsJson) + hadoopConf.set( + SQLConf.SESSION_LOCAL_TIMEZONE.key, + sparkSession.sessionState.conf.sessionLocalTimeZone) + hadoopConf.setBoolean( + SQLConf.NESTED_SCHEMA_PRUNING_ENABLED.key, + sparkSession.sessionState.conf.nestedSchemaPruningEnabled) + hadoopConf.setBoolean( + SQLConf.CASE_SENSITIVE.key, + sparkSession.sessionState.conf.caseSensitiveAnalysis) + + ParquetWriteSupport.setSchema(readDataSchema, hadoopConf) + + // Sets flags for `ParquetToSparkSchemaConverter` + hadoopConf.setBoolean( + SQLConf.PARQUET_BINARY_AS_STRING.key, + sparkSession.sessionState.conf.isParquetBinaryAsString) + hadoopConf.setBoolean( + SQLConf.PARQUET_INT96_AS_TIMESTAMP.key, + sparkSession.sessionState.conf.isParquetINT96AsTimestamp) + + val broadcastedConf = sparkSession.sparkContext.broadcast( + new SerializableConfiguration(hadoopConf)) + val sqlConf = sparkSession.sessionState.conf + ParquetPartitionReaderFactory( + sqlConf, + broadcastedConf, + dataSchema, + readDataSchema, + readPartitionSchema, + pushedFilters, + new ParquetOptions(options.asCaseSensitiveMap.asScala.toMap, sqlConf)) + } + + /** + * Override "partitions" of org.apache.spark.sql.execution.datasources.v2.FileScan + * to disable splitting and sort the files by file paths instead of by file sizes. + * Note: This implementation does not support to partition attributes. + */ + override protected def partitions: Seq[FilePartition] = { + val selectedPartitions = fileIndex.listFiles(partitionFilters, dataFilters) + val maxSplitBytes = FilePartition.maxSplitBytes(sparkSession, selectedPartitions) + + val splitFiles = selectedPartitions.flatMap { partition => + val partitionValues = partition.values + partition.files.flatMap { file => + val filePath = file.getPath + PartitionedFileUtil.splitFiles( + sparkSession = sparkSession, + file = file, + filePath = filePath, + isSplitable = isSplitable(filePath), + maxSplitBytes = maxSplitBytes, + partitionValues = partitionValues + ) + }.toArray.sortBy(_.filePath) + } + + getFilePartitions(sparkSession, splitFiles) + } + + /** + * Override "getFilePartitions" of org.apache.spark.sql.execution.datasources.FilePartition + * to assign each chunk file in GraphAr to a single partition. + */ + private def getFilePartitions( + sparkSession: SparkSession, + partitionedFiles: Seq[PartitionedFile]): Seq[FilePartition] = { + val partitions = new ArrayBuffer[FilePartition] + val currentFiles = new ArrayBuffer[PartitionedFile] + + /** Close the current partition and move to the next. */ + def closePartition(): Unit = { + if (currentFiles.nonEmpty) { + // Copy to a new Array. + val newPartition = FilePartition(partitions.size, currentFiles.toArray) + partitions += newPartition + } + currentFiles.clear() + } + // Assign a file to each partition + partitionedFiles.foreach { file => + closePartition() + // Add the given file to the current partition. + currentFiles += file + } + closePartition() + partitions.toSeq + } + + /** Check if two objects are equal. 
*/ + override def equals(obj: Any): Boolean = obj match { + case g: GarScan => + super.equals(g) && dataSchema == g.dataSchema && options == g.options && + equivalentFilters(pushedFilters, g.pushedFilters) && formatName == g.formatName + case _ => false + } + + /** Get the hash code of the object. */ + override def hashCode(): Int = formatName match { + case "csv" => super.hashCode() + case "orc" => getClass.hashCode() + case "parquet" => getClass.hashCode() + case _ => throw new IllegalArgumentException + } + + /** Get the description string of the object. */ + override def description(): String = { + super.description() + ", PushedFilters: " + seqToString(pushedFilters) + } + + /** Get the meata data map of the object. */ + override def getMetaData(): Map[String, String] = { + super.getMetaData() ++ Map("PushedFilters" -> seqToString(pushedFilters)) + } + + /** Construct the file scan with filters. */ + override def withFilters( + partitionFilters: Seq[Expression], dataFilters: Seq[Expression]): FileScan = + this.copy(partitionFilters = partitionFilters, dataFilters = dataFilters) +} diff --git a/spark/src/main/scala/com/alibaba/graphar/datasources/GarScanBuilder.scala b/spark/src/main/scala/com/alibaba/graphar/datasources/GarScanBuilder.scala new file mode 100644 index 000000000..a3452c8ad --- /dev/null +++ b/spark/src/main/scala/com/alibaba/graphar/datasources/GarScanBuilder.scala @@ -0,0 +1,65 @@ +/** Copyright 2022 Alibaba Group Holding Limited. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.graphar.datasources + +import scala.collection.JavaConverters._ + +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.connector.read.Scan +import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex +import org.apache.spark.sql.execution.datasources.v2.FileScanBuilder +import org.apache.spark.sql.sources.Filter +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.util.CaseInsensitiveStringMap + +/** GarScanBuilder is a class to build the file scan for GarDataSource. */ +case class GarScanBuilder( + sparkSession: SparkSession, + fileIndex: PartitioningAwareFileIndex, + schema: StructType, + dataSchema: StructType, + options: CaseInsensitiveStringMap, + formatName: String) + extends FileScanBuilder(sparkSession, fileIndex, dataSchema) { + lazy val hadoopConf = { + val caseSensitiveMap = options.asCaseSensitiveMap.asScala.toMap + // Hadoop Configurations are case sensitive. + sparkSession.sessionState.newHadoopConfWithOptions(caseSensitiveMap) + } + + // Check if the file format supports nested schema pruning. + override protected val supportsNestedSchemaPruning: Boolean = formatName match { + case "csv" => false + case "orc" => true + case "parquet" => true + case _ => throw new IllegalArgumentException + } + + // Note: This scan builder does not implement "with SupportsPushDownFilters". 
+ private var filters: Array[Filter] = Array.empty + + // Note: To support pushdown filters, these two methods need to be implemented. + + // override def pushFilters(filters: Array[Filter]): Array[Filter] + + // override def pushedFilters(): Array[Filter] + + /** Build the file scan for GarDataSource. */ + override def build(): Scan = { + GarScan(sparkSession, hadoopConf, fileIndex, dataSchema, readDataSchema(), + readPartitionSchema(), filters, options, formatName) + } +} diff --git a/spark/src/main/scala/com/alibaba/graphar/datasources/GarTable.scala b/spark/src/main/scala/com/alibaba/graphar/datasources/GarTable.scala new file mode 100644 index 000000000..505cdacdc --- /dev/null +++ b/spark/src/main/scala/com/alibaba/graphar/datasources/GarTable.scala @@ -0,0 +1,104 @@ +/** Copyright 2022 Alibaba Group Holding Limited. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.graphar.datasources + +import scala.collection.JavaConverters._ + +import org.apache.hadoop.fs.FileStatus + +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.connector.write.{LogicalWriteInfo, Write, WriteBuilder} +import org.apache.spark.sql.catalyst.csv.CSVOptions +import org.apache.spark.sql.execution.datasources.FileFormat +import org.apache.spark.sql.execution.datasources.csv.CSVDataSource +import org.apache.spark.sql.execution.datasources.orc.OrcUtils +import org.apache.spark.sql.execution.datasources.parquet.ParquetUtils +import org.apache.spark.sql.execution.datasources.v2.FileTable +import org.apache.spark.sql.execution.datasources.v2.csv.CSVWrite +import org.apache.spark.sql.execution.datasources.v2.orc.OrcWrite +import org.apache.spark.sql.execution.datasources.v2.parquet.ParquetWrite +import org.apache.spark.sql.types._ +import org.apache.spark.sql.util.CaseInsensitiveStringMap + +/** GarTable is a class to represent the graph data in GraphAr as a table. */ +case class GarTable( + name: String, + sparkSession: SparkSession, + options: CaseInsensitiveStringMap, + paths: Seq[String], + userSpecifiedSchema: Option[StructType], + fallbackFileFormat: Class[_ <: FileFormat]) + extends FileTable(sparkSession, options, paths, userSpecifiedSchema) { + + /** Construct a new scan builder. */ + override def newScanBuilder(options: CaseInsensitiveStringMap): GarScanBuilder = + new GarScanBuilder(sparkSession, fileIndex, schema, dataSchema, options, formatName) + + /** Infer the schema of the table through the methods of the actual file format. 
*/ + override def inferSchema(files: Seq[FileStatus]): Option[StructType] = formatName match { + case "csv" => { + val parsedOptions = new CSVOptions( + options.asScala.toMap, + columnPruning = sparkSession.sessionState.conf.csvColumnPruning, + sparkSession.sessionState.conf.sessionLocalTimeZone) + + CSVDataSource(parsedOptions).inferSchema(sparkSession, files, parsedOptions) + } + case "orc" => OrcUtils.inferSchema(sparkSession, files, options.asScala.toMap) + case "parquet" => ParquetUtils.inferSchema(sparkSession, options.asScala.toMap, files) + case _ => throw new IllegalArgumentException + } + + /** Construct a new write builder according to the actual file format. */ + override def newWriteBuilder(info: LogicalWriteInfo): WriteBuilder = formatName match { + case "csv" => new WriteBuilder { + override def build(): Write = CSVWrite(paths, formatName, supportsDataType, info) + } + case "orc" => new WriteBuilder { + override def build(): Write = OrcWrite(paths, formatName, supportsDataType, info) + } + case "parquet" => new WriteBuilder { + override def build(): Write = ParquetWrite(paths, formatName, supportsDataType, info) + } + case _ => throw new IllegalArgumentException + } + + /** + * Check if a data type is supported. + * Note: Currently, the GraphAr data source only supports several atomic data types. + * To support additional data types such as Struct, Array and Map, revise this function + * to handle them case by case as the commented code shows. + */ + override def supportsDataType(dataType: DataType): Boolean = dataType match { + // case _: AnsiIntervalType => false + + case _: AtomicType => true + + // case st: StructType => st.forall { f => supportsDataType(f.dataType) } + + // case ArrayType(elementType, _) => supportsDataType(elementType) + + // case MapType(keyType, valueType, _) => + // supportsDataType(keyType) && supportsDataType(valueType) + + // case udt: UserDefinedType[_] => supportsDataType(udt.sqlType) + + case _ => false + } + + /** The actual file format for storing the data in GraphAr. 
*/ + override def formatName: String = options.get("fileFormat") +} diff --git a/spark/src/main/scala/com/alibaba/graphar/reader/EdgeReader.scala b/spark/src/main/scala/com/alibaba/graphar/reader/EdgeReader.scala index 24c3890bd..4705322f7 100644 --- a/spark/src/main/scala/com/alibaba/graphar/reader/EdgeReader.scala +++ b/spark/src/main/scala/com/alibaba/graphar/reader/EdgeReader.scala @@ -17,8 +17,8 @@ package com.alibaba.graphar.reader import com.alibaba.graphar.utils.{IndexGenerator} import com.alibaba.graphar.{GeneralParams, EdgeInfo, FileType, AdjListType, PropertyGroup} +import com.alibaba.graphar.datasources._ -import org.apache.hadoop.fs.{Path, FileSystem} import org.apache.spark.sql.{DataFrame, SparkSession} import org.apache.spark.sql.types._ import org.apache.spark.sql.functions._ @@ -50,7 +50,7 @@ class EdgeReader(prefix: String, edgeInfo: EdgeInfo, adjListType: AdjListType.V val file_type_in_gar = edgeInfo.getAdjListFileType(adjListType) val file_type = FileType.FileTypeToString(file_type_in_gar) val file_path = prefix + "/" + edgeInfo.getAdjListOffsetFilePath(chunk_index, adjListType) - val df = spark.read.format(file_type).load(file_path) + val df = spark.read.option("fileFormat", file_type).format("com.alibaba.graphar.datasources.GarDataSource").load(file_path) return df } @@ -64,7 +64,7 @@ class EdgeReader(prefix: String, edgeInfo: EdgeInfo, adjListType: AdjListType.V val file_type_in_gar = edgeInfo.getAdjListFileType(adjListType) val file_type = FileType.FileTypeToString(file_type_in_gar) val file_path = prefix + "/" + edgeInfo.getAdjListFilePath(vertex_chunk_index, chunk_index, adjListType) - val df = spark.read.format(file_type).load(file_path) + val df = spark.read.option("fileFormat", file_type).format("com.alibaba.graphar.datasources.GarDataSource").load(file_path) return df } @@ -75,21 +75,15 @@ class EdgeReader(prefix: String, edgeInfo: EdgeInfo, adjListType: AdjListType.V * @return DataFrame of all AdjList chunks of vertices in given vertex chunk. */ def readAdjListForVertexChunk(vertex_chunk_index: Long, addIndex: Boolean = false): DataFrame = { - val part_prefix = prefix + "/" + edgeInfo.getAdjListPathPrefix(vertex_chunk_index, adjListType) - val file_system = FileSystem.get(new Path(part_prefix).toUri(), spark.sparkContext.hadoopConfiguration) - val path_pattern = new Path(part_prefix + "chunk*") - val chunk_number = file_system.globStatus(path_pattern).length - var df = spark.emptyDataFrame - for ( i <- 0 to chunk_number - 1) { - val new_df = readAdjListChunk(vertex_chunk_index, i) - if (i == 0) - df = new_df - else - df = df.union(new_df) + val file_type_in_gar = edgeInfo.getAdjListFileType(adjListType) + val file_type = FileType.FileTypeToString(file_type_in_gar) + val file_path = prefix + "/" + edgeInfo.getAdjListPathPrefix(vertex_chunk_index, adjListType) + val df = spark.read.option("fileFormat", file_type).format("com.alibaba.graphar.datasources.GarDataSource").load(file_path) + if (addIndex) { + return IndexGenerator.generateEdgeIndexColumn(df) + } else { + return df } - if (addIndex) - df = IndexGenerator.generateEdgeIndexColumn(df) - return df } /** Load all AdjList chunks for this edge type as a DataFrame. @@ -98,21 +92,15 @@ class EdgeReader(prefix: String, edgeInfo: EdgeInfo, adjListType: AdjListType.V * @return DataFrame of all AdjList chunks. 
*/ def readAllAdjList(addIndex: Boolean = false): DataFrame = { + val file_type_in_gar = edgeInfo.getAdjListFileType(adjListType) + val file_type = FileType.FileTypeToString(file_type_in_gar) val file_path = prefix + "/" + edgeInfo.getAdjListPathPrefix(adjListType) - val file_system = FileSystem.get(new Path(file_path).toUri(), spark.sparkContext.hadoopConfiguration) - val path_pattern = new Path(file_path + "part*") - val vertex_chunk_number = file_system.globStatus(path_pattern).length - var df = spark.emptyDataFrame - for ( i <- 0 to vertex_chunk_number - 1) { - val new_df = readAdjListForVertexChunk(i) - if (i == 0) - df = new_df - else - df = df.union(new_df) + val df = spark.read.option("fileFormat", file_type).option("recursiveFileLookup", "true").format("com.alibaba.graphar.datasources.GarDataSource").load(file_path) + if (addIndex) { + return IndexGenerator.generateEdgeIndexColumn(df) + } else { + return df } - if (addIndex) - df = IndexGenerator.generateEdgeIndexColumn(df) - return df } /** Load a single edge property chunk as a DataFrame. @@ -126,9 +114,9 @@ class EdgeReader(prefix: String, edgeInfo: EdgeInfo, adjListType: AdjListType.V def readEdgePropertyChunk(propertyGroup: PropertyGroup, vertex_chunk_index: Long, chunk_index: Long): DataFrame = { if (edgeInfo.containPropertyGroup(propertyGroup, adjListType) == false) throw new IllegalArgumentException - val file_type = propertyGroup.getFile_type(); + val file_type = propertyGroup.getFile_type() val file_path = prefix + "/" + edgeInfo.getPropertyFilePath(propertyGroup, adjListType, vertex_chunk_index, chunk_index) - val df = spark.read.format(file_type).load(file_path) + val df = spark.read.option("fileFormat", file_type).format("com.alibaba.graphar.datasources.GarDataSource").load(file_path) return df } @@ -143,21 +131,14 @@ class EdgeReader(prefix: String, edgeInfo: EdgeInfo, adjListType: AdjListType.V def readEdgePropertiesForVertexChunk(propertyGroup: PropertyGroup, vertex_chunk_index: Long, addIndex: Boolean = false): DataFrame = { if (edgeInfo.containPropertyGroup(propertyGroup, adjListType) == false) throw new IllegalArgumentException - val path_prefix = prefix + "/" + edgeInfo.getPropertyGroupPathPrefix(propertyGroup, adjListType, vertex_chunk_index) - val file_system = FileSystem.get(new Path(path_prefix).toUri(), spark.sparkContext.hadoopConfiguration) - val path_pattern = new Path(path_prefix + "chunk*") - val chunk_number = file_system.globStatus(path_pattern).length - var df = spark.emptyDataFrame - for ( i <- 0 to chunk_number - 1) { - val new_df = readEdgePropertyChunk(propertyGroup, vertex_chunk_index, i) - if (i == 0) - df = new_df - else - df = df.union(new_df) + val file_type = propertyGroup.getFile_type() + val file_path = prefix + "/" + edgeInfo.getPropertyGroupPathPrefix(propertyGroup, adjListType, vertex_chunk_index) + val df = spark.read.option("fileFormat", file_type).format("com.alibaba.graphar.datasources.GarDataSource").load(file_path) + if (addIndex) { + return IndexGenerator.generateEdgeIndexColumn(df) + } else { + return df } - if (addIndex) - df = IndexGenerator.generateEdgeIndexColumn(df) - return df } /** Load all chunks for a property group as a DataFrame. 
@@ -170,21 +151,14 @@ class EdgeReader(prefix: String, edgeInfo: EdgeInfo, adjListType: AdjListType.V def readEdgeProperties(propertyGroup: PropertyGroup, addIndex: Boolean = false): DataFrame = { if (edgeInfo.containPropertyGroup(propertyGroup, adjListType) == false) throw new IllegalArgumentException - val property_group_prefix = prefix + "/" + edgeInfo.getPropertyGroupPathPrefix(propertyGroup, adjListType) - val file_system = FileSystem.get(new Path(property_group_prefix).toUri(), spark.sparkContext.hadoopConfiguration) - val path_pattern = new Path(property_group_prefix + "part*") - val vertex_chunk_number = file_system.globStatus(path_pattern).length - var df = spark.emptyDataFrame - for ( i <- 0 to vertex_chunk_number - 1) { - val new_df = readEdgePropertiesForVertexChunk(propertyGroup, i) - if (i == 0) - df = new_df - else - df = df.union(new_df) + val file_type = propertyGroup.getFile_type() + val file_path = prefix + "/" + edgeInfo.getPropertyGroupPathPrefix(propertyGroup, adjListType) + val df = spark.read.option("fileFormat", file_type).option("recursiveFileLookup", "true").format("com.alibaba.graphar.datasources.GarDataSource").load(file_path) + if (addIndex) { + return IndexGenerator.generateEdgeIndexColumn(df) + } else { + return df } - if (addIndex) - df = IndexGenerator.generateEdgeIndexColumn(df) - return df } /** Load the chunks for all property groups of a vertex chunk as a DataFrame. @@ -257,9 +231,9 @@ class EdgeReader(prefix: String, edgeInfo: EdgeInfo, adjListType: AdjListType.V def readEdges(addIndex: Boolean = false): DataFrame = { val adjList_df = readAllAdjList(true) val properties_df = readAllEdgeProperties(true) - var df = adjList_df.join(properties_df, Seq(GeneralParams.edgeIndexCol)).sort(GeneralParams.edgeIndexCol); + var df = adjList_df.join(properties_df, Seq(GeneralParams.edgeIndexCol)).sort(GeneralParams.edgeIndexCol) if (addIndex == false) df = df.drop(GeneralParams.edgeIndexCol) return df } -} \ No newline at end of file +} diff --git a/spark/src/main/scala/com/alibaba/graphar/reader/VertexReader.scala b/spark/src/main/scala/com/alibaba/graphar/reader/VertexReader.scala index 6d68d4135..0ee65aad4 100644 --- a/spark/src/main/scala/com/alibaba/graphar/reader/VertexReader.scala +++ b/spark/src/main/scala/com/alibaba/graphar/reader/VertexReader.scala @@ -17,6 +17,7 @@ package com.alibaba.graphar.reader import com.alibaba.graphar.utils.{IndexGenerator} import com.alibaba.graphar.{GeneralParams, VertexInfo, FileType, PropertyGroup} +import com.alibaba.graphar.datasources._ import org.apache.hadoop.fs.{Path, FileSystem} import org.apache.spark.sql.{DataFrame, SparkSession} @@ -59,7 +60,7 @@ class VertexReader(prefix: String, vertexInfo: VertexInfo, spark: SparkSession) throw new IllegalArgumentException val file_type = propertyGroup.getFile_type() val file_path = prefix + "/" + vertexInfo.getFilePath(propertyGroup, chunk_index) - val df = spark.read.format(file_type).load(file_path) + val df = spark.read.option("fileFormat", file_type).format("com.alibaba.graphar.datasources.GarDataSource").load(file_path) return df } @@ -73,17 +74,15 @@ class VertexReader(prefix: String, vertexInfo: VertexInfo, spark: SparkSession) def readVertexProperties(propertyGroup: PropertyGroup, addIndex: Boolean = false): DataFrame = { if (vertexInfo.containPropertyGroup(propertyGroup) == false) throw new IllegalArgumentException - var df = spark.emptyDataFrame - for ( i <- 0L to chunk_number - 1) { - val new_df = readVertexPropertyChunk(propertyGroup, i) - if (i == 0) - df = new_df 
- else - df = df.union(new_df) + val file_type = propertyGroup.getFile_type() + val file_path = prefix + "/" + vertexInfo.getPathPrefix(propertyGroup) + val df = spark.read.option("fileFormat", file_type).format("com.alibaba.graphar.datasources.GarDataSource").load(file_path) + + if (addIndex) { + return IndexGenerator.generateVertexIndexColumn(df) + } else { + return df } - if (addIndex) - df = IndexGenerator.generateVertexIndexColumn(df) - return df } /** Load the chunks for all property groups as a DataFrame. @@ -108,4 +107,4 @@ class VertexReader(prefix: String, vertexInfo: VertexInfo, spark: SparkSession) df = df.drop(GeneralParams.vertexIndexCol) return df } -} \ No newline at end of file +} diff --git a/spark/src/test/scala/com/alibaba/graphar/TestReader.scala b/spark/src/test/scala/com/alibaba/graphar/TestReader.scala index 4d8ce3c58..a3f5aae31 100644 --- a/spark/src/test/scala/com/alibaba/graphar/TestReader.scala +++ b/spark/src/test/scala/com/alibaba/graphar/TestReader.scala @@ -15,6 +15,7 @@ package com.alibaba.graphar +import com.alibaba.graphar.datasources._ import com.alibaba.graphar.reader.{VertexReader, EdgeReader} import java.io.{File, FileInputStream} @@ -30,22 +31,64 @@ class ReaderSuite extends AnyFunSuite { .master("local[*]") .getOrCreate() + test("read chunk files directly") { + // read vertex chunk files in Parquet + val parquet_file_path = "gar-test/ldbc_sample/parquet" + val parquet_prefix = getClass.getClassLoader.getResource(parquet_file_path).getPath + val parqeut_read_path = parquet_prefix + "/vertex/person/id" + val df1 = spark.read.option("fileFormat", "parquet").format("com.alibaba.graphar.datasources.GarDataSource").load(parqeut_read_path) + // validate reading results + assert(df1.rdd.getNumPartitions == 10) + assert(df1.count() == 903) + // println(df1.rdd.collect().mkString("\n")) + + // read vertex chunk files in Orc + val orc_file_path = "gar-test/ldbc_sample/orc" + val orc_prefix = getClass.getClassLoader.getResource(orc_file_path).getPath + val orc_read_path = orc_prefix + "/vertex/person/id" + val df2 = spark.read.option("fileFormat", "orc").format("com.alibaba.graphar.datasources.GarDataSource").load(orc_read_path) + // validate reading results + assert(df2.rdd.collect().deep == df1.rdd.collect().deep) + + // read adjList chunk files recursively in CSV + val csv_file_path = "gar-test/ldbc_sample/csv" + val csv_prefix = getClass.getClassLoader.getResource(csv_file_path).getPath + val csv_read_path = csv_prefix + "/edge/person_knows_person/ordered_by_source/adj_list" + val df3 = spark.read.option("fileFormat", "csv").option("recursiveFileLookup", "true").format("com.alibaba.graphar.datasources.GarDataSource").load(csv_read_path) + // validate reading results + assert(df3.rdd.getNumPartitions == 11) + assert(df3.count() == 6626) + + // throw an exception for unsupported file formats + assertThrows[IllegalArgumentException](spark.read.option("fileFormat", "invalid").format("com.alibaba.graphar.datasources.GarDataSource").load(csv_read_path)) + } + test("read vertex chunks") { + // construct the vertex information val file_path = "gar-test/ldbc_sample/csv" val prefix = getClass.getClassLoader.getResource(file_path).getPath val vertex_input = getClass.getClassLoader.getResourceAsStream(file_path + "/person.vertex.yml") val vertex_yaml = new Yaml(new Constructor(classOf[VertexInfo])) val vertex_info = vertex_yaml.load(vertex_input).asInstanceOf[VertexInfo] + // construct the vertex reader val reader = new VertexReader(prefix, vertex_info, spark) + + // 
test reading the number of vertices assert(reader.readVerticesNumber() == 903) val property_group = vertex_info.getPropertyGroup("gender") + + // test reading a single property chunk val single_chunk_df = reader.readVertexPropertyChunk(property_group, 0) assert(single_chunk_df.columns.size == 3) assert(single_chunk_df.count() == 100) + + // test reading chunks for a property group val property_df = reader.readVertexProperties(property_group) assert(property_df.columns.size == 3) assert(property_df.count() == 903) + + // test reading chunks for all property groups and optionally adding indices val vertex_df = reader.readAllVertexProperties() vertex_df.show() assert(vertex_df.columns.size == 4) @@ -55,23 +98,30 @@ class ReaderSuite extends AnyFunSuite { assert(vertex_df_with_index.columns.size == 5) assert(vertex_df_with_index.count() == 903) + // throw an exception for non-existing property groups val invalid_property_group= new PropertyGroup() assertThrows[IllegalArgumentException](reader.readVertexPropertyChunk(invalid_property_group, 0)) assertThrows[IllegalArgumentException](reader.readVertexProperties(invalid_property_group)) } test("read edge chunks") { + // construct the edge information val file_path = "gar-test/ldbc_sample/csv" val prefix = getClass.getClassLoader.getResource(file_path).getPath val edge_input = getClass.getClassLoader.getResourceAsStream(file_path + "/person_knows_person.edge.yml") val edge_yaml = new Yaml(new Constructor(classOf[EdgeInfo])) val edge_info = edge_yaml.load(edge_input).asInstanceOf[EdgeInfo] + // construct the edge reader val adj_list_type = AdjListType.ordered_by_source val reader = new EdgeReader(prefix, edge_info, adj_list_type, spark) + + // test reading a offset chunk val offset_df = reader.readOffset(0) assert(offset_df.columns.size == 1) assert(offset_df.count() == 101) + + // test reading adjList chunks val single_adj_list_df = reader.readAdjListChunk(2, 0) assert(single_adj_list_df.columns.size == 2) assert(single_adj_list_df.count() == 1024) @@ -82,6 +132,7 @@ class ReaderSuite extends AnyFunSuite { assert(adj_list_df.columns.size == 2) assert(adj_list_df.count() == 6626) + // test reading property group chunks val property_group = edge_info.getPropertyGroup("creationDate", adj_list_type) val single_property_df = reader.readEdgePropertyChunk(property_group, 2, 0) assert(single_property_df.columns.size == 1) @@ -99,6 +150,7 @@ class ReaderSuite extends AnyFunSuite { assert(all_property_df.columns.size == 1) assert(all_property_df.count() == 6626) + // test reading edges and optionally adding indices val edge_df_chunk_2 = reader.readEdgesForVertexChunk(2) edge_df_chunk_2.show() assert(edge_df_chunk_2.columns.size == 3) @@ -116,13 +168,14 @@ class ReaderSuite extends AnyFunSuite { assert(edge_df_with_index.columns.size == 4) assert(edge_df_with_index.count() == 6626) + // throw an exception for non-existing property groups val invalid_property_group= new PropertyGroup() assertThrows[IllegalArgumentException](reader.readEdgePropertyChunk(invalid_property_group, 0, 0)) assertThrows[IllegalArgumentException](reader.readEdgePropertiesForVertexChunk(invalid_property_group, 0)) assertThrows[IllegalArgumentException](reader.readEdgeProperties(invalid_property_group)) + // throw an exception for non-existing adjList types val invalid_adj_list_type = AdjListType.unordered_by_dest assertThrows[IllegalArgumentException](new EdgeReader(prefix, edge_info, invalid_adj_list_type, spark)) } - }
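
For reference, a minimal usage sketch of the new data source introduced by this patch. The "fileFormat" option and the "com.alibaba.graphar.datasources.GarDataSource" format name are taken from the patch itself; the object name and the load paths below are hypothetical placeholders, not part of the change.

import org.apache.spark.sql.SparkSession

object GarDataSourceExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName("GarDataSourceExample")
      .master("local[*]")
      .getOrCreate()

    // Read a directory of GraphAr vertex chunk files through the new data source.
    // "fileFormat" selects the payload format ("csv", "orc" or "parquet"); GarScan
    // assigns each chunk file to exactly one Spark partition, so the partition
    // count should match the number of chunk files.
    val vertexDf = spark.read
      .option("fileFormat", "parquet")
      .format("com.alibaba.graphar.datasources.GarDataSource")
      .load("/path/to/graphar/vertex/person/id") // hypothetical path

    println(s"partitions = ${vertexDf.rdd.getNumPartitions}, rows = ${vertexDf.count()}")

    // For nested chunk directories (e.g. adjacency lists laid out as part*/chunk*),
    // the patch enables Spark's built-in recursive lookup instead of manual globbing.
    val adjListDf = spark.read
      .option("fileFormat", "csv")
      .option("recursiveFileLookup", "true")
      .format("com.alibaba.graphar.datasources.GarDataSource")
      .load("/path/to/graphar/edge/person_knows_person/ordered_by_source/adj_list") // hypothetical path

    println(s"edges = ${adjListDf.count()}")
    spark.stop()
  }
}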