diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/CompressionCodecs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/CompressionCodecs.scala
index e683a95ed2aef..bc8ef4ad7e236 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/CompressionCodecs.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/CompressionCodecs.scala
@@ -17,6 +17,8 @@
 
 package org.apache.spark.sql.execution.datasources
 
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.io.SequenceFile.CompressionType
 import org.apache.hadoop.io.compress.{BZip2Codec, GzipCodec, Lz4Codec, SnappyCodec}
 
 import org.apache.spark.util.Utils
@@ -44,4 +46,16 @@ private[datasources] object CompressionCodecs {
           s"is not available. Known codecs are ${shortCompressionCodecNames.keys.mkString(", ")}.")
     }
   }
+
+  /**
+   * Set compression configurations to Hadoop `Configuration`.
+   * `codec` should be a full class path
+   */
+  def setCodecConfiguration(conf: Configuration, codec: String): Unit = {
+    conf.set("mapreduce.output.fileoutputformat.compress", "true")
+    conf.set("mapreduce.output.fileoutputformat.compress.type", CompressionType.BLOCK.toString)
+    conf.set("mapreduce.output.fileoutputformat.compress.codec", codec)
+    conf.set("mapreduce.map.output.compress", "true")
+    conf.set("mapreduce.map.output.compress.codec", codec)
+  }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala
index da945c44cde1c..e9afee1cc5142 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala
@@ -24,7 +24,6 @@ import scala.util.control.NonFatal
 import com.google.common.base.Objects
 import org.apache.hadoop.fs.{FileStatus, Path}
 import org.apache.hadoop.io.{LongWritable, NullWritable, Text}
-import org.apache.hadoop.io.SequenceFile.CompressionType
 import org.apache.hadoop.mapred.TextInputFormat
 import org.apache.hadoop.mapreduce.{Job, TaskAttemptContext}
 import org.apache.hadoop.mapreduce.RecordWriter
@@ -34,6 +33,7 @@ import org.apache.spark.Logging
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql._
 import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.execution.datasources.CompressionCodecs
 import org.apache.spark.sql.sources._
 import org.apache.spark.sql.types._
 
@@ -50,16 +50,16 @@ private[sql] class CSVRelation(
     case None => inferSchema(paths)
   }
 
-  private val params = new CSVOptions(parameters)
+  private val options = new CSVOptions(parameters)
 
   @transient
   private var cachedRDD: Option[RDD[String]] = None
 
   private def readText(location: String): RDD[String] = {
-    if (Charset.forName(params.charset) == Charset.forName("UTF-8")) {
+    if (Charset.forName(options.charset) == Charset.forName("UTF-8")) {
       sqlContext.sparkContext.textFile(location)
     } else {
-      val charset = params.charset
+      val charset = options.charset
       sqlContext.sparkContext.hadoopFile[LongWritable, Text, TextInputFormat](location)
         .mapPartitions { _.map { pair =>
           new String(pair._2.getBytes, 0, pair._2.getLength, charset)
@@ -81,8 +81,8 @@ private[sql] class CSVRelation(
   private def tokenRdd(header: Array[String], inputPaths: Array[String]): RDD[Array[String]] = {
     val rdd = baseRdd(inputPaths)
     // Make sure firstLine is materialized before sending to executors
-    val firstLine = if (params.headerFlag) findFirstLine(rdd) else null
-    CSVRelation.univocityTokenizer(rdd, header, firstLine, params)
+    val firstLine = if (options.headerFlag) findFirstLine(rdd) else null
+    CSVRelation.univocityTokenizer(rdd, header, firstLine, options)
   }
 
   /**
@@ -96,20 +96,16 @@ private[sql] class CSVRelation(
     val pathsString = inputs.map(_.getPath.toUri.toString)
     val header = schema.fields.map(_.name)
     val tokenizedRdd = tokenRdd(header, pathsString)
-    CSVRelation.parseCsv(tokenizedRdd, schema, requiredColumns, inputs, sqlContext, params)
+    CSVRelation.parseCsv(tokenizedRdd, schema, requiredColumns, inputs, sqlContext, options)
   }
 
   override def prepareJobForWrite(job: Job): OutputWriterFactory = {
     val conf = job.getConfiguration
-    params.compressionCodec.foreach { codec =>
-      conf.set("mapreduce.output.fileoutputformat.compress", "true")
-      conf.set("mapreduce.output.fileoutputformat.compress.type", CompressionType.BLOCK.toString)
-      conf.set("mapreduce.output.fileoutputformat.compress.codec", codec)
-      conf.set("mapreduce.map.output.compress", "true")
-      conf.set("mapreduce.map.output.compress.codec", codec)
+    options.compressionCodec.foreach { codec =>
+      CompressionCodecs.setCodecConfiguration(conf, codec)
     }
 
-    new CSVOutputWriterFactory(params)
+    new CSVOutputWriterFactory(options)
   }
 
   override def hashCode(): Int = Objects.hashCode(paths.toSet, dataSchema, schema, partitionColumns)
@@ -129,17 +125,17 @@ private[sql] class CSVRelation(
   private def inferSchema(paths: Array[String]): StructType = {
     val rdd = baseRdd(paths)
     val firstLine = findFirstLine(rdd)
-    val firstRow = new LineCsvReader(params).parseLine(firstLine)
+    val firstRow = new LineCsvReader(options).parseLine(firstLine)
 
-    val header = if (params.headerFlag) {
+    val header = if (options.headerFlag) {
       firstRow
     } else {
       firstRow.zipWithIndex.map { case (value, index) => s"C$index" }
     }
 
     val parsedRdd = tokenRdd(header, paths)
-    if (params.inferSchemaFlag) {
-      CSVInferSchema.infer(parsedRdd, header, params.nullValue)
+    if (options.inferSchemaFlag) {
+      CSVInferSchema.infer(parsedRdd, header, options.nullValue)
     } else {
       // By default fields are assumed to be StringType
       val schemaFields = header.map { fieldName =>
@@ -153,8 +149,8 @@ private[sql] class CSVRelation(
    * Returns the first line of the first non-empty file in path
    */
   private def findFirstLine(rdd: RDD[String]): String = {
-    if (params.isCommentSet) {
-      val comment = params.comment.toString
+    if (options.isCommentSet) {
+      val comment = options.comment.toString
       rdd.filter { line =>
         line.trim.nonEmpty && !line.startsWith(comment)
       }.first()
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala
index c893558136549..28136911fe240 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala
@@ -165,11 +165,7 @@ private[sql] class JSONRelation(
   override def prepareJobForWrite(job: Job): BucketedOutputWriterFactory = {
     val conf = job.getConfiguration
     options.compressionCodec.foreach { codec =>
-      conf.set("mapreduce.output.fileoutputformat.compress", "true")
-      conf.set("mapreduce.output.fileoutputformat.compress.type", CompressionType.BLOCK.toString)
-      conf.set("mapreduce.output.fileoutputformat.compress.codec", codec)
-      conf.set("mapreduce.map.output.compress", "true")
-      conf.set("mapreduce.map.output.compress.codec", codec)
+      CompressionCodecs.setCodecConfiguration(conf, codec)
     }
 
     new BucketedOutputWriterFactory {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/DefaultSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/DefaultSource.scala
index 430257f60d9fe..60155b32349a7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/DefaultSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/DefaultSource.scala
@@ -31,7 +31,7 @@ import org.apache.spark.sql.{AnalysisException, Row, SQLContext}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.UnsafeRow
 import org.apache.spark.sql.catalyst.expressions.codegen.{BufferHolder, UnsafeRowWriter}
-import org.apache.spark.sql.execution.datasources.PartitionSpec
+import org.apache.spark.sql.execution.datasources.{CompressionCodecs, PartitionSpec}
 import org.apache.spark.sql.sources._
 import org.apache.spark.sql.types.{StringType, StructType}
 import org.apache.spark.util.SerializableConfiguration
@@ -48,7 +48,7 @@ class DefaultSource extends HadoopFsRelationProvider with DataSourceRegister {
       partitionColumns: Option[StructType],
       parameters: Map[String, String]): HadoopFsRelation = {
     dataSchema.foreach(verifySchema)
-    new TextRelation(None, dataSchema, partitionColumns, paths)(sqlContext)
+    new TextRelation(None, dataSchema, partitionColumns, paths, parameters)(sqlContext)
   }
 
   override def shortName(): String = "text"
@@ -114,6 +114,15 @@ private[sql] class TextRelation(
 
   /** Write path. */
   override def prepareJobForWrite(job: Job): OutputWriterFactory = {
+    val conf = job.getConfiguration
+    val compressionCodec = {
+      val name = parameters.get("compression").orElse(parameters.get("codec"))
+      name.map(CompressionCodecs.getCodecClassName)
+    }
+    compressionCodec.foreach { codec =>
+      CompressionCodecs.setCodecConfiguration(conf, codec)
+    }
+
     new OutputWriterFactory {
       override def newInstance(
           path: String,
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/text/TextSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/text/TextSuite.scala
index f95272530d585..6ae42a30fb00c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/text/TextSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/text/TextSuite.scala
@@ -57,6 +57,21 @@ class TextSuite extends QueryTest with SharedSQLContext {
     }
   }
 
+  test("SPARK-13503 Support to specify the option for compression codec for TEXT") {
+    val df = sqlContext.read.text(testFile).withColumnRenamed("value", "adwrasdf")
+
+    val tempFile = Utils.createTempDir()
+    tempFile.delete()
+    df.write
+      .option("compression", "gZiP")
+      .text(tempFile.getCanonicalPath)
+    val compressedFiles = tempFile.listFiles()
+    assert(compressedFiles.exists(_.getName.endsWith(".gz")))
+    verifyFrame(sqlContext.read.text(tempFile.getCanonicalPath))
+
+    Utils.deleteRecursively(tempFile)
+  }
+
   private def testFile: String = {
     Thread.currentThread().getContextClassLoader.getResource("text-suite.txt").toString
   }