From 3770ffba766f014dc7a8b2332d7d8ab40dc17750 Mon Sep 17 00:00:00 2001
From: hyukjinkwon
Date: Mon, 25 Jan 2016 13:54:44 +0900
Subject: [PATCH] Refactor CSVParameters and JSONOptions.

---
 .../{CSVParameters.scala => CSVOptions.scala} |  7 +--
 .../execution/datasources/csv/CSVParser.scala |  8 +--
 .../datasources/csv/CSVRelation.scala         | 12 ++--
 .../datasources/json/JSONOptions.scala        | 59 ++++++++-----------
 .../datasources/json/JSONRelation.scala       |  2 +-
 .../datasources/json/JsonSuite.scala          |  4 +-
 6 files changed, 40 insertions(+), 52 deletions(-)
 rename sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/{CSVParameters.scala => CSVOptions.scala} (95%)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVParameters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala
similarity index 95%
rename from sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVParameters.scala
rename to sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala
index 0278675aa61b0..5d0e99d7601dc 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVParameters.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala
@@ -19,13 +19,12 @@ package org.apache.spark.sql.execution.datasources.csv
 
 import java.nio.charset.Charset
 
-import org.apache.hadoop.io.compress._
-
 import org.apache.spark.Logging
 import org.apache.spark.sql.execution.datasources.CompressionCodecs
-import org.apache.spark.util.Utils
 
-private[sql] case class CSVParameters(@transient parameters: Map[String, String]) extends Logging {
+private[sql] class CSVOptions(
+    @transient parameters: Map[String, String])
+  extends Logging with Serializable {
 
   private def getChar(paramName: String, default: Char): Char = {
     val paramValue = parameters.get(paramName)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVParser.scala
index ba1cc42f3e446..8f1421844c648 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVParser.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVParser.scala
@@ -29,7 +29,7 @@ import org.apache.spark.Logging
  * @param params Parameters object
  * @param headers headers for the columns
  */
-private[sql] abstract class CsvReader(params: CSVParameters, headers: Seq[String]) {
+private[sql] abstract class CsvReader(params: CSVOptions, headers: Seq[String]) {
 
   protected lazy val parser: CsvParser = {
     val settings = new CsvParserSettings()
@@ -58,7 +58,7 @@ private[sql] abstract class CsvReader(params: CSVParameters, headers: Seq[String
  * @param params Parameters object for configuration
  * @param headers headers for columns
  */
-private[sql] class LineCsvWriter(params: CSVParameters, headers: Seq[String]) extends Logging {
+private[sql] class LineCsvWriter(params: CSVOptions, headers: Seq[String]) extends Logging {
 
   private val writerSettings = new CsvWriterSettings
   private val format = writerSettings.getFormat
@@ -93,7 +93,7 @@ private[sql] class LineCsvWriter(params: CSVParameters, headers: Seq[String]) ex
  *
  * @param params Parameters object
  */
-private[sql] class LineCsvReader(params: CSVParameters)
+private[sql] class LineCsvReader(params: CSVOptions)
   extends CsvReader(params, null) {
   /**
    * parse a line
@@ -118,7 +118,7 @@ private[sql] class LineCsvReader(params: CSVParameters)
  */
 private[sql] class BulkCsvReader(
     iter: Iterator[String],
-    params: CSVParameters,
+    params: CSVOptions,
     headers: Seq[String])
   extends CsvReader(params, headers) with Iterator[Array[String]] {
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala
index 1502501c3b89e..5959f7cc5051b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala
@@ -43,14 +43,14 @@ private[csv] class CSVRelation(
     private val maybeDataSchema: Option[StructType],
     override val userDefinedPartitionColumns: Option[StructType],
     private val parameters: Map[String, String])
-    (@transient val sqlContext: SQLContext) extends HadoopFsRelation with Serializable {
+    (@transient val sqlContext: SQLContext) extends HadoopFsRelation {
 
   override lazy val dataSchema: StructType = maybeDataSchema match {
     case Some(structType) => structType
     case None => inferSchema(paths)
   }
 
-  private val params = new CSVParameters(parameters)
+  private val params = new CSVOptions(parameters)
 
   @transient
   private var cachedRDD: Option[RDD[String]] = None
@@ -170,7 +170,7 @@ object CSVRelation extends Logging {
       file: RDD[String],
       header: Seq[String],
       firstLine: String,
-      params: CSVParameters): RDD[Array[String]] = {
+      params: CSVOptions): RDD[Array[String]] = {
     // If header is set, make sure firstLine is materialized before sending to executors.
     file.mapPartitionsWithIndex({
       case (split, iter) => new BulkCsvReader(
@@ -186,7 +186,7 @@ object CSVRelation extends Logging {
       requiredColumns: Array[String],
       inputs: Array[FileStatus],
       sqlContext: SQLContext,
-      params: CSVParameters): RDD[Row] = {
+      params: CSVOptions): RDD[Row] = {
 
     val schemaFields = schema.fields
     val requiredFields = StructType(requiredColumns.map(schema(_))).fields
@@ -249,7 +249,7 @@ object CSVRelation extends Logging {
   }
 }
 
-private[sql] class CSVOutputWriterFactory(params: CSVParameters) extends OutputWriterFactory {
+private[sql] class CSVOutputWriterFactory(params: CSVOptions) extends OutputWriterFactory {
   override def newInstance(
       path: String,
       dataSchema: StructType,
@@ -262,7 +262,7 @@ private[sql] class CsvOutputWriter(
     path: String,
     dataSchema: StructType,
     context: TaskAttemptContext,
-    params: CSVParameters) extends OutputWriter with Logging {
+    params: CSVOptions) extends OutputWriter with Logging {
 
   // create the Generator without separator inserted between 2 records
   private[this] val text = new Text()
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONOptions.scala
index e74a76c532367..0a083b5e3598e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONOptions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONOptions.scala
@@ -26,16 +26,30 @@ import org.apache.spark.sql.execution.datasources.CompressionCodecs
  *
  * Most of these map directly to Jackson's internal options, specified in [[JsonParser.Feature]].
  */
-case class JSONOptions(
-    samplingRatio: Double = 1.0,
-    primitivesAsString: Boolean = false,
-    allowComments: Boolean = false,
-    allowUnquotedFieldNames: Boolean = false,
-    allowSingleQuotes: Boolean = true,
-    allowNumericLeadingZeros: Boolean = false,
-    allowNonNumericNumbers: Boolean = false,
-    allowBackslashEscapingAnyCharacter: Boolean = false,
-    compressionCodec: Option[String] = None) {
+private[sql] class JSONOptions(
+    @transient parameters: Map[String, String])
+  extends Serializable {
+
+  val samplingRatio =
+    parameters.get("samplingRatio").map(_.toDouble).getOrElse(1.0)
+  val primitivesAsString =
+    parameters.get("primitivesAsString").map(_.toBoolean).getOrElse(false)
+  val allowComments =
+    parameters.get("allowComments").map(_.toBoolean).getOrElse(false)
+  val allowUnquotedFieldNames =
+    parameters.get("allowUnquotedFieldNames").map(_.toBoolean).getOrElse(false)
+  val allowSingleQuotes =
+    parameters.get("allowSingleQuotes").map(_.toBoolean).getOrElse(true)
+  val allowNumericLeadingZeros =
+    parameters.get("allowNumericLeadingZeros").map(_.toBoolean).getOrElse(false)
+  val allowNonNumericNumbers =
+    parameters.get("allowNonNumericNumbers").map(_.toBoolean).getOrElse(true)
+  val allowBackslashEscapingAnyCharacter =
+    parameters.get("allowBackslashEscapingAnyCharacter").map(_.toBoolean).getOrElse(false)
+  val compressionCodec = {
+    val name = parameters.get("compression").orElse(parameters.get("codec"))
+    name.map(CompressionCodecs.getCodecClassName)
+  }
 
   /** Sets config options on a Jackson [[JsonFactory]]. */
   def setJacksonOptions(factory: JsonFactory): Unit = {
@@ -48,28 +62,3 @@ case class JSONOptions(
       allowBackslashEscapingAnyCharacter)
   }
 }
-
-object JSONOptions {
-  def createFromConfigMap(parameters: Map[String, String]): JSONOptions = JSONOptions(
-    samplingRatio =
-      parameters.get("samplingRatio").map(_.toDouble).getOrElse(1.0),
-    primitivesAsString =
-      parameters.get("primitivesAsString").map(_.toBoolean).getOrElse(false),
-    allowComments =
-      parameters.get("allowComments").map(_.toBoolean).getOrElse(false),
-    allowUnquotedFieldNames =
-      parameters.get("allowUnquotedFieldNames").map(_.toBoolean).getOrElse(false),
-    allowSingleQuotes =
-      parameters.get("allowSingleQuotes").map(_.toBoolean).getOrElse(true),
-    allowNumericLeadingZeros =
-      parameters.get("allowNumericLeadingZeros").map(_.toBoolean).getOrElse(false),
-    allowNonNumericNumbers =
-      parameters.get("allowNonNumericNumbers").map(_.toBoolean).getOrElse(true),
-    allowBackslashEscapingAnyCharacter =
-      parameters.get("allowBackslashEscapingAnyCharacter").map(_.toBoolean).getOrElse(false),
-    compressionCodec = {
-      val name = parameters.get("compression").orElse(parameters.get("codec"))
-      name.map(CompressionCodecs.getCodecClassName)
-    }
-  )
-}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala
index 93727abcc7de9..c893558136549 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala
@@ -75,7 +75,7 @@ private[sql] class JSONRelation(
     (@transient val sqlContext: SQLContext)
   extends HadoopFsRelation(maybePartitionSpec, parameters) {
 
-  val options: JSONOptions = JSONOptions.createFromConfigMap(parameters)
+  val options: JSONOptions = new JSONOptions(parameters)
 
   /** Constraints to be imposed on schema to be stored. */
   private def checkConstraints(schema: StructType): Unit = {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala
index d22fa7905aec1..00eaeb0d34e87 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala
@@ -1240,7 +1240,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData {
 
   test("SPARK-6245 JsonRDD.inferSchema on empty RDD") {
     // This is really a test that it doesn't throw an exception
-    val emptySchema = InferSchema.infer(empty, "", JSONOptions())
+    val emptySchema = InferSchema.infer(empty, "", new JSONOptions(Map()))
     assert(StructType(Seq()) === emptySchema)
   }
 
@@ -1264,7 +1264,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData {
   }
 
   test("SPARK-8093 Erase empty structs") {
-    val emptySchema = InferSchema.infer(emptyRecords, "", JSONOptions())
+    val emptySchema = InferSchema.infer(emptyRecords, "", new JSONOptions(Map()))
    assert(StructType(Seq()) === emptySchema)
  }
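
A note on the shared shape of this refactor, with a sketch below: after the patch, both CSVOptions and JSONOptions are plain Serializable classes over a @transient Map[String, String], with each option eagerly resolved into a typed val at construction time, so the parsed values rather than the raw map are what ship to executors. The following is a minimal self-contained sketch of that pattern, not the patch's code: the class and option names are illustrative, and the real CSVOptions.getChar shown in the first hunk handles additional value shapes.

    // Sketch of the map-backed options pattern shared by CSVOptions and JSONOptions.
    // The map is @transient: every val below is evaluated during construction,
    // so only the resolved option values travel with the serialized object.
    class OptionsSketch(@transient parameters: Map[String, String]) extends Serializable {

      // Resolve a one-character option with a default, in the spirit of CSVOptions.getChar.
      private def getChar(paramName: String, default: Char): Char =
        parameters.get(paramName) match {
          case None => default
          case Some(value) if value.length == 1 => value.charAt(0)
          case Some(_) =>
            throw new RuntimeException(s"$paramName cannot be more than one character")
        }

      val delimiter: Char = getChar("delimiter", ',')
      val quote: Char = getChar("quote", '"')
    }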
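
For the JSON side, a hypothetical caller-side sketch (assuming the calling code lives under org.apache.spark.sql, since the class is private[sql]): the options map now goes straight to the JSONOptions constructor instead of through the removed JSONOptions.createFromConfigMap, and setJacksonOptions applies the resolved flags to a Jackson factory.

    import com.fasterxml.jackson.core.JsonFactory

    // Keys JSONOptions does not recognize are ignored; absent keys fall back to the
    // defaults in the patch (e.g. allowSingleQuotes -> true, allowComments -> false).
    val options = new JSONOptions(Map(
      "samplingRatio" -> "0.5",    // read with _.toDouble
      "allowComments" -> "true",   // read with _.toBoolean
      "compression" -> "gzip"))    // resolved via CompressionCodecs.getCodecClassName

    val factory = new JsonFactory()
    options.setJacksonOptions(factory)  // sets the corresponding JsonParser.Feature flags

Constructing with an empty map, new JSONOptions(Map()), yields all defaults, which is exactly what the updated JsonSuite tests rely on.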