diff --git a/integration_tests/src/main/python/orc_write_test.py b/integration_tests/src/main/python/orc_write_test.py
index 103cae474a3..b3c82675869 100644
--- a/integration_tests/src/main/python/orc_write_test.py
+++ b/integration_tests/src/main/python/orc_write_test.py
@@ -360,6 +360,23 @@ def create_empty_df(spark, path):
         conf={'spark.rapids.sql.format.orc.write.enabled': True})
 
 
+hold_gpu_configs = [True, False]
+@pytest.mark.parametrize('hold_gpu', hold_gpu_configs, ids=idfn)
+def test_async_writer(spark_tmp_path, hold_gpu):
+    data_path = spark_tmp_path + '/ORC_DATA'
+    num_rows = 2048
+    num_cols = 10
+    orc_gen = [int_gen for _ in range(num_cols)]
+    gen_list = [('_c' + str(i), gen) for i, gen in enumerate(orc_gen)]
+    assert_gpu_and_cpu_writes_are_equal_collect(
+        lambda spark, path: gen_df(spark, gen_list, length=num_rows).write.orc(path),
+        lambda spark, path: spark.read.orc(path).orderBy([('_c' + str(i)) for i in range(num_cols)]),
+        data_path,
+        conf={"spark.rapids.sql.asyncWrite.queryOutput.enabled": "true",
+              "spark.rapids.sql.batchSizeBytes": 4 * num_cols * 100,  # 100 rows per batch
+              "spark.rapids.sql.queryOutput.holdGpuInTask": hold_gpu})
+
+
 @ignore_order
 @pytest.mark.skipif(is_before_spark_320(), reason="is only supported in Spark 320+")
 def test_concurrent_writer(spark_tmp_path):
diff --git a/integration_tests/src/main/python/parquet_write_test.py b/integration_tests/src/main/python/parquet_write_test.py
index e5719d267b4..9b43fabd26d 100644
--- a/integration_tests/src/main/python/parquet_write_test.py
+++ b/integration_tests/src/main/python/parquet_write_test.py
@@ -705,8 +705,8 @@ def test_async_writer(spark_tmp_path, hold_gpu):
     parquet_gen = [int_gen for _ in range(num_cols)]
     gen_list = [('_c' + str(i), gen) for i, gen in enumerate(parquet_gen)]
     assert_gpu_and_cpu_writes_are_equal_collect(
-        lambda spark, path: gen_df(spark, gen_list, length=num_rows).coalesce(1).write.parquet(path),
-        lambda spark, path: spark.read.parquet(path),
+        lambda spark, path: gen_df(spark, gen_list, length=num_rows).write.parquet(path),
+        lambda spark, path: spark.read.parquet(path).orderBy([('_c' + str(i)) for i in range(num_cols)]),
         data_path,
         copy_and_update(
             writer_confs,
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsConf.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsConf.scala
index c4199e3ea75..aecf3dad2b3 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsConf.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsConf.scala
@@ -2453,7 +2453,7 @@ val SHUFFLE_COMPRESSION_LZ4_CHUNK_SIZE = conf("spark.rapids.shuffle.compression.
     .doc("Option to turn on the async query output write. During the final output write, the " +
       "task first copies the output to the host memory, and then writes it into the storage. " +
       "When this option is enabled, the task will asynchronously write the output in the host " +
-      "memory to the storage. Only the Parquet format is supported currently.")
+      "memory to the storage. Only the Parquet and ORC formats are supported currently.")
     .internal()
     .booleanConf
     .createWithDefault(false)
diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuOrcFileFormat.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuOrcFileFormat.scala
index 5ac2aa1fe98..422f6c2337e 100644
--- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuOrcFileFormat.scala
+++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuOrcFileFormat.scala
@@ -169,7 +169,8 @@ class GpuOrcFileFormat extends ColumnarFileFormat with Logging {
                             options: Map[String, String],
                             dataSchema: StructType): ColumnarOutputWriterFactory = {
 
-    val orcOptions = new OrcOptions(options, sparkSession.sessionState.conf)
+    val sqlConf = sparkSession.sessionState.conf
+    val orcOptions = new OrcOptions(options, sqlConf)
 
     val conf = job.getConfiguration
 
@@ -180,12 +181,18 @@ class GpuOrcFileFormat extends ColumnarFileFormat with Logging {
     conf.asInstanceOf[JobConf]
       .setOutputFormat(classOf[org.apache.orc.mapred.OrcOutputFormat[OrcStruct]])
 
+    val asyncOutputWriteEnabled = RapidsConf.ENABLE_ASYNC_OUTPUT_WRITE.get(sqlConf)
+    // holdGpuBetweenBatches is on by default if asyncOutputWriteEnabled is on
+    val holdGpuBetweenBatches = RapidsConf.ASYNC_QUERY_OUTPUT_WRITE_HOLD_GPU_IN_TASK.get(sqlConf)
+      .getOrElse(asyncOutputWriteEnabled)
+
     new ColumnarOutputWriterFactory {
       override def newInstance(path: String,
                                dataSchema: StructType,
                                context: TaskAttemptContext,
                                debugOutputPath: Option[String]): ColumnarOutputWriter = {
-        new GpuOrcWriter(path, dataSchema, context, debugOutputPath)
+        new GpuOrcWriter(path, dataSchema, context, debugOutputPath, holdGpuBetweenBatches,
+          asyncOutputWriteEnabled)
       }
 
       override def getFileExtension(context: TaskAttemptContext): String = {
@@ -204,11 +211,15 @@ class GpuOrcFileFormat extends ColumnarFileFormat with Logging {
   }
 }
 
-class GpuOrcWriter(override val path: String,
-                   dataSchema: StructType,
-                   context: TaskAttemptContext,
-                   debugOutputPath: Option[String])
-  extends ColumnarOutputWriter(context, dataSchema, "ORC", true, debugOutputPath) {
+class GpuOrcWriter(
+    override val path: String,
+    dataSchema: StructType,
+    context: TaskAttemptContext,
+    debugOutputPath: Option[String],
+    holdGpuBetweenBatches: Boolean,
+    useAsyncWrite: Boolean)
+  extends ColumnarOutputWriter(context, dataSchema, "ORC", true, debugOutputPath,
+    holdGpuBetweenBatches, useAsyncWrite) {
 
   override val tableWriter: TableWriter = {
     val builder = SchemaUtils