[SPARK-11691][SQL] Support setting hadoop compression codecs in DataFrameWriter#option #11324

Closed
wants to merge 12 commits
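In short, this PR adds a "compressionCodec" option to DataFrameWriter that is forwarded to the Hadoop output configuration. A minimal usage sketch, modeled on the TextSuite test added below; the SQLContext setup and paths are illustrative only, not part of this PR:

import org.apache.spark.sql.SaveMode

val df = sqlContext.read.text("/tmp/input")                // hypothetical input path

// Short names ("bzip2", "deflate", "gzip", "snappy") are mapped to Hadoop codec classes;
// a fully-qualified codec class name is passed through to Hadoop unchanged.
df.write
  .option("compressionCodec", "gzip")
  .mode(SaveMode.Overwrite)
  .text("/tmp/output-gzip")                                // hypothetical output path
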
22 changes: 13 additions & 9 deletions core/src/main/scala/org/apache/spark/io/CompressionCodec.scala
@@ -25,7 +25,7 @@ import org.xerial.snappy.{Snappy, SnappyInputStream, SnappyOutputStream}

import org.apache.spark.SparkConf
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.util.Utils
import org.apache.spark.util.{ShortCompressionCodecNameMapper, Utils}

/**
* :: DeveloperApi ::
@@ -53,10 +53,14 @@ private[spark] object CompressionCodec {
|| codec.isInstanceOf[LZ4CompressionCodec])
}

private val shortCompressionCodecNames = Map(
"lz4" -> classOf[LZ4CompressionCodec].getName,
"lzf" -> classOf[LZFCompressionCodec].getName,
"snappy" -> classOf[SnappyCompressionCodec].getName)
/** Maps the short versions of compression codec names to fully-qualified class names. */
private val shortCompressionCodecNameMapper = new ShortCompressionCodecNameMapper {
override def lz4: Option[String] = Some(classOf[LZ4CompressionCodec].getName)
override def lzf: Option[String] = Some(classOf[LZFCompressionCodec].getName)
override def snappy: Option[String] = Some(classOf[SnappyCompressionCodec].getName)
}

private val shortCompressionCodecMap = shortCompressionCodecNameMapper.getAsMap

def getCodecName(conf: SparkConf): String = {
conf.get(configKey, DEFAULT_COMPRESSION_CODEC)
@@ -67,7 +71,7 @@ private[spark] object CompressionCodec {
}

def createCodec(conf: SparkConf, codecName: String): CompressionCodec = {
val codecClass = shortCompressionCodecNames.getOrElse(codecName.toLowerCase, codecName)
val codecClass = shortCompressionCodecNameMapper.get(codecName).getOrElse(codecName)
val codec = try {
val ctor = Utils.classForName(codecClass).getConstructor(classOf[SparkConf])
Some(ctor.newInstance(conf).asInstanceOf[CompressionCodec])
@@ -84,18 +88,18 @@
* If it is already a short name, just return it.
*/
def getShortName(codecName: String): String = {
if (shortCompressionCodecNames.contains(codecName)) {
if (shortCompressionCodecMap.contains(codecName)) {
codecName
} else {
shortCompressionCodecNames
shortCompressionCodecMap
.collectFirst { case (k, v) if v == codecName => k }
.getOrElse { throw new IllegalArgumentException(s"No short name for codec $codecName.") }
}
}

val FALLBACK_COMPRESSION_CODEC = "snappy"
val DEFAULT_COMPRESSION_CODEC = "lz4"
val ALL_COMPRESSION_CODECS = shortCompressionCodecNames.values.toSeq
val ALL_COMPRESSION_CODECS = shortCompressionCodecMap.values.toSeq
}

/**
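For context, a quick sketch of how the refactored object resolves codec names (this is Spark-internal API; the sketch assumes the usual spark.io.compression.codec key behind configKey):

import org.apache.spark.SparkConf
import org.apache.spark.io.CompressionCodec

val conf = new SparkConf().set("spark.io.compression.codec", "lz4")   // short name
// getCodecName falls back to DEFAULT_COMPRESSION_CODEC ("lz4") when the key is unset;
// createCodec resolves the short name through shortCompressionCodecNameMapper, or treats
// an unrecognized value as a fully-qualified class name.
val codec = CompressionCodec.createCodec(conf, CompressionCodec.getCodecName(conf))
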
45 changes: 45 additions & 0 deletions core/src/main/scala/org/apache/spark/util/Utils.scala
@@ -60,6 +60,51 @@ private[spark] object CallSite {
val empty = CallSite("", "")
}

/** A utility class that maps short compression codec names to fully-qualified ones. */
private[spark] class ShortCompressionCodecNameMapper {

def get(codecName: String): Option[String] = codecName.toLowerCase match {
case "none" => none
case "uncompressed" => uncompressed
case "bzip2" => bzip2
case "deflate" => deflate
case "gzip" => gzip
case "lzo" => lzo
case "lz4" => lz4
case "lzf" => lzf
case "snappy" => snappy
case _ => None
}

def getAsMap: Map[String, String] = {
Seq(
("none", none),
("uncompressed", uncompressed),
("bzip2", bzip2),
("deflate", deflate),
("gzip", gzip),
("lzo", lzo),
("lz4", lz4),
("lzf", lzf),
("snappy", snappy)
).flatMap { case (shortCodecName, codecName) =>
if (codecName.isDefined) Some(shortCodecName, codecName.get) else None
}.toMap
}

// To support short codec names, derived classes need to override the methods below that return
// corresponding qualified codec names.
def none: Option[String] = None
def uncompressed: Option[String] = None
def bzip2: Option[String] = None
def deflate: Option[String] = None
def gzip: Option[String] = None
def lzo: Option[String] = None
def lz4: Option[String] = None
def lzf: Option[String] = None
def snappy: Option[String] = None
}

/**
* Various utility methods used by Spark.
*/
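A short sketch of how a data source is expected to use this mapper; the gzip-and-snappy-only mapping below is hypothetical:

// Only the codecs a caller supports are overridden; everything else stays None.
val myMapper = new ShortCompressionCodecNameMapper {
  override def gzip: Option[String] = Some("org.apache.hadoop.io.compress.GzipCodec")
  override def snappy: Option[String] = Some("org.apache.hadoop.io.compress.SnappyCodec")
}

myMapper.get("GZIP")    // Some("org.apache.hadoop.io.compress.GzipCodec"); names are lower-cased first
myMapper.get("lzo")     // None; not overridden, so the caller decides how to fall back
myMapper.getAsMap       // Map("gzip" -> ..., "snappy" -> ...); only defined codecs appear
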
30 changes: 15 additions & 15 deletions sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
@@ -1,19 +1,19 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql

@@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.datasources
import java.io.IOException

import org.apache.hadoop.fs.Path
import org.apache.hadoop.io.compress.CompressionCodec
import org.apache.hadoop.mapreduce._
import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat

@@ -58,7 +59,8 @@ import org.apache.spark.util.Utils
private[sql] case class InsertIntoHadoopFsRelation(
@transient relation: HadoopFsRelation,
@transient query: LogicalPlan,
mode: SaveMode)
mode: SaveMode,
compressionCodec: Option[Class[_ <: CompressionCodec]] = None)
extends RunnableCommand {

override def run(sqlContext: SQLContext): Seq[Row] = {
@@ -126,7 +128,7 @@ private[sql] case class InsertIntoHadoopFsRelation(
""".stripMargin)

val writerContainer = if (partitionColumns.isEmpty && relation.maybeBucketSpec.isEmpty) {
new DefaultWriterContainer(relation, job, isAppend)
new DefaultWriterContainer(relation, job, isAppend, compressionCodec)
} else {
val output = df.queryExecution.executedPlan.output
val (partitionOutput, dataOutput) =
@@ -140,7 +142,8 @@
output,
PartitioningUtils.DEFAULT_PARTITION_NAME,
sqlContext.conf.getConf(SQLConf.PARTITION_MAX_FILES),
isAppend)
isAppend,
compressionCodec)
}

// This call shouldn't be put into the `try` block below because it only initializes and
@@ -24,6 +24,7 @@ import scala.language.{existentials, implicitConversions}
import scala.util.{Failure, Success, Try}

import org.apache.hadoop.fs.Path
import org.apache.hadoop.io.compress._
import org.apache.hadoop.util.StringUtils

import org.apache.spark.Logging
@@ -32,7 +33,7 @@ import org.apache.spark.sql.{AnalysisException, DataFrame, SaveMode, SQLContext}
import org.apache.spark.sql.execution.streaming.{Sink, Source}
import org.apache.spark.sql.sources._
import org.apache.spark.sql.types.{CalendarIntervalType, StructType}
import org.apache.spark.util.Utils
import org.apache.spark.util.{ShortCompressionCodecNameMapper, Utils}

case class ResolvedDataSource(provider: Class[_], relation: BaseRelation)

@@ -49,6 +50,14 @@ object ResolvedDataSource extends Logging {
"org.apache.spark.sql.parquet.DefaultSource" -> classOf[parquet.DefaultSource].getCanonicalName
)

/** Maps the short versions of compression codec names to fully-qualified class names. */
private val hadoopShortCodecNameMapper = new ShortCompressionCodecNameMapper {
override def bzip2: Option[String] = Some(classOf[BZip2Codec].getCanonicalName)
override def deflate: Option[String] = Some(classOf[DeflateCodec].getCanonicalName)
override def gzip: Option[String] = Some(classOf[GzipCodec].getCanonicalName)
override def snappy: Option[String] = Some(classOf[SnappyCodec].getCanonicalName)
}

/** Given a provider name, look up the data source class definition. */
def lookupDataSource(provider0: String): Class[_] = {
val provider = backwardCompatibilityMap.getOrElse(provider0, provider0)
@@ -286,14 +295,25 @@
bucketSpec,
caseInsensitiveOptions)

val compressionCodec = options
.get("compressionCodec")
.map { codecName =>
val codecFactory = new CompressionCodecFactory(
sqlContext.sparkContext.hadoopConfiguration)
val resolvedCodecName = hadoopShortCodecNameMapper.get(codecName).getOrElse(codecName)
Option(codecFactory.getCodecClassByName(resolvedCodecName))
}
.getOrElse(None)

// For partitioned relation r, r.schema's column ordering can be different from the column
// ordering of data.logicalPlan (partition columns are all moved after data column). This
// will be adjusted within InsertIntoHadoopFsRelation.
sqlContext.executePlan(
InsertIntoHadoopFsRelation(
r,
data.logicalPlan,
mode)).toRdd
mode,
compressionCodec)).toRdd
r
case _ =>
sys.error(s"${clazz.getCanonicalName} does not allow create table as select.")
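The block above accepts either a short codec name or a fully-qualified Hadoop codec class name; a standalone sketch of that lookup, using a bare Configuration instead of sqlContext.sparkContext.hadoopConfiguration:

import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.io.compress.CompressionCodecFactory

val codecFactory = new CompressionCodecFactory(new Configuration())

// "gzip" is first mapped to org.apache.hadoop.io.compress.GzipCodec by hadoopShortCodecNameMapper;
// an unrecognized value is handed to getCodecClassByName as-is.
val codecClass = codecFactory.getCodecClassByName("org.apache.hadoop.io.compress.GzipCodec")
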
@@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.datasources
import java.util.{Date, UUID}

import org.apache.hadoop.fs.Path
import org.apache.hadoop.io.compress.CompressionCodec
import org.apache.hadoop.mapreduce._
import org.apache.hadoop.mapreduce.lib.output.{FileOutputCommitter => MapReduceFileOutputCommitter}
import org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl
@@ -39,7 +40,8 @@ import org.apache.spark.util.SerializableConfiguration
private[sql] abstract class BaseWriterContainer(
@transient val relation: HadoopFsRelation,
@transient private val job: Job,
isAppend: Boolean)
isAppend: Boolean,
compressionCodec: Option[Class[_ <: CompressionCodec]] = None)
extends Logging with Serializable {

protected val dataSchema = relation.dataSchema
@@ -207,6 +209,11 @@ private[sql] abstract class BaseWriterContainer(
serializableConf.value.set("mapred.task.id", taskAttemptId.toString)
serializableConf.value.setBoolean("mapred.task.is.map", true)
serializableConf.value.setInt("mapred.task.partition", 0)
compressionCodec.map { codecClass =>
serializableConf.value.set("mapred.output.compress", "true")
serializableConf.value.set("mapred.output.compression.codec", codecClass.getCanonicalName)
serializableConf.value.set("mapred.output.compression.type", "BLOCK")
}
}

def commitTask(): Unit = {
@@ -239,8 +246,9 @@
private[sql] class DefaultWriterContainer(
relation: HadoopFsRelation,
job: Job,
isAppend: Boolean)
extends BaseWriterContainer(relation, job, isAppend) {
isAppend: Boolean,
compressionCodec: Option[Class[_ <: CompressionCodec]])
extends BaseWriterContainer(relation, job, isAppend, compressionCodec) {

def writeRows(taskContext: TaskContext, iterator: Iterator[InternalRow]): Unit = {
executorSideSetup(taskContext)
@@ -308,8 +316,9 @@ private[sql] class DynamicPartitionWriterContainer(
inputSchema: Seq[Attribute],
defaultPartitionName: String,
maxOpenFiles: Int,
isAppend: Boolean)
extends BaseWriterContainer(relation, job, isAppend) {
isAppend: Boolean,
compressionCodec: Option[Class[_ <: CompressionCodec]])
extends BaseWriterContainer(relation, job, isAppend, compressionCodec) {

private val bucketSpec = relation.maybeBucketSpec

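The three properties set in executorSideSetup above are the classic MapReduce output-compression switches; in isolation the effect is equivalent to the following (GzipCodec is used here only as an example):

import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.io.compress.GzipCodec

val hadoopConf = new Configuration()
hadoopConf.set("mapred.output.compress", "true")
hadoopConf.set("mapred.output.compression.codec", classOf[GzipCodec].getCanonicalName)
hadoopConf.set("mapred.output.compression.type", "BLOCK")   // SequenceFile-style block compression
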
@@ -49,7 +49,7 @@ import org.apache.spark.sql.catalyst.util.LegacyTypeStringParser
import org.apache.spark.sql.execution.datasources.{PartitionSpec, _}
import org.apache.spark.sql.sources._
import org.apache.spark.sql.types.{DataType, StructType}
import org.apache.spark.util.{SerializableConfiguration, Utils}
import org.apache.spark.util.{SerializableConfiguration, ShortCompressionCodecNameMapper, Utils}

private[sql] class DefaultSource extends BucketedHadoopFsRelationProvider with DataSourceRegister {

@@ -283,10 +283,8 @@ private[sql] class ParquetRelation(
conf.set(
ParquetOutputFormat.COMPRESSION,
ParquetRelation
.shortParquetCompressionCodecNames
.getOrElse(
sqlContext.conf.parquetCompressionCodec.toUpperCase,
CompressionCodecName.UNCOMPRESSED).name())
.parquetShortCodecNameMapper.get(sqlContext.conf.parquetCompressionCodec)
.getOrElse(CompressionCodecName.UNCOMPRESSED.name()))

new BucketedOutputWriterFactory {
override def newInstance(
@@ -902,11 +900,12 @@ private[sql] object ParquetRelation extends Logging {
}
}

// The parquet compression short names
val shortParquetCompressionCodecNames = Map(
"NONE" -> CompressionCodecName.UNCOMPRESSED,
"UNCOMPRESSED" -> CompressionCodecName.UNCOMPRESSED,
"SNAPPY" -> CompressionCodecName.SNAPPY,
"GZIP" -> CompressionCodecName.GZIP,
"LZO" -> CompressionCodecName.LZO)
/** Maps the short versions of compression codec names to qualified compression names. */
val parquetShortCodecNameMapper = new ShortCompressionCodecNameMapper {
override def none: Option[String] = Some(CompressionCodecName.UNCOMPRESSED.name())
override def uncompressed: Option[String] = Some(CompressionCodecName.UNCOMPRESSED.name())
override def gzip: Option[String] = Some(CompressionCodecName.GZIP.name())
override def lzo: Option[String] = Some(CompressionCodecName.LZO.name())
override def snappy: Option[String] = Some(CompressionCodecName.SNAPPY.name())
}
}
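
Note that Parquet output compression still comes from the SQL configuration (sqlContext.conf.parquetCompressionCodec) rather than the new writer option; a sketch, assuming the usual spark.sql.parquet.compression.codec key backs that setting:

// Short names resolve through parquetShortCodecNameMapper; anything unrecognized
// falls back to CompressionCodecName.UNCOMPRESSED.
sqlContext.setConf("spark.sql.parquet.compression.codec", "gzip")
sqlContext.range(10).write.mode("overwrite").parquet("/tmp/parquet-gzip")   // hypothetical path
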
@@ -17,7 +17,7 @@

package org.apache.spark.sql.execution.datasources.text

import org.apache.spark.sql.{AnalysisException, DataFrame, QueryTest, Row}
import org.apache.spark.sql.{AnalysisException, DataFrame, QueryTest, Row, SaveMode}
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types.{StringType, StructType}
import org.apache.spark.util.Utils
@@ -57,6 +57,15 @@ class TextSuite extends QueryTest with SharedSQLContext {
}
}

test("compression") {
Seq("bzip2", "deflate", "gzip").map { codecName =>
val tempDirPath = Utils.createTempDir().getAbsolutePath
val df = sqlContext.read.text(testFile)
df.write.option("compressionCodec", codecName).mode(SaveMode.Overwrite).text(tempDirPath)
verifyFrame(sqlContext.read.text(tempDirPath))
}
}

private def testFile: String = {
Thread.currentThread().getContextClassLoader.getResource("text-suite.txt").toString
}
@@ -306,7 +306,7 @@ class ParquetMetastoreSuite extends ParquetPartitioningTest {

val df = sql("INSERT INTO TABLE test_insert_parquet SELECT a FROM jt")
df.queryExecution.sparkPlan match {
case ExecutedCommand(InsertIntoHadoopFsRelation(_: ParquetRelation, _, _)) => // OK
case ExecutedCommand(InsertIntoHadoopFsRelation(_: ParquetRelation, _, _, _)) => // OK
case o => fail("test_insert_parquet should be converted to a " +
s"${classOf[ParquetRelation].getCanonicalName} and " +
s"${classOf[InsertIntoDataSource].getCanonicalName} is expcted as the SparkPlan. " +
@@ -336,7 +336,7 @@ class ParquetMetastoreSuite extends ParquetPartitioningTest {

val df = sql("INSERT INTO TABLE test_insert_parquet SELECT a FROM jt_array")
df.queryExecution.sparkPlan match {
case ExecutedCommand(InsertIntoHadoopFsRelation(r: ParquetRelation, _, _)) => // OK
case ExecutedCommand(InsertIntoHadoopFsRelation(r: ParquetRelation, _, _, _)) => // OK
case o => fail("test_insert_parquet should be converted to a " +
s"${classOf[ParquetRelation].getCanonicalName} and " +
s"${classOf[InsertIntoDataSource].getCanonicalName} is expcted as the SparkPlan." +