[SPARK-12901][SQL] Refactor options for JSON and CSV datasource (not case class and same format). #10895

Closed · wants to merge 1 commit
@@ -19,13 +19,12 @@ package org.apache.spark.sql.execution.datasources.csv

import java.nio.charset.Charset

import org.apache.hadoop.io.compress._

import org.apache.spark.Logging
import org.apache.spark.sql.execution.datasources.CompressionCodecs
import org.apache.spark.util.Utils

private[sql] case class CSVParameters(@transient parameters: Map[String, String]) extends Logging {
private[sql] class CSVOptions(
@transient parameters: Map[String, String])
Contributor: I think you need @transient private val here.

Contributor: Yep, that's the right fix. (A sketch of the suggested form follows this hunk.)

extends Logging with Serializable {

private def getChar(paramName: String, default: Char): Char = {
val paramValue = parameters.get(paramName)
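To make the review suggestion concrete: with a plain constructor parameter, @transient annotates only the parameter, not the private field Scala generates to capture it once the class body references it, so the raw option map can still end up in the serialized object. A minimal sketch of the suggested form, not part of this diff (the derived option fields are elided):

// Sketch only: mark the captured field itself as transient so the raw map is
// dropped on serialization, while the derived option values stay serializable.
private[sql] class CSVOptions(
    @transient private val parameters: Map[String, String])
  extends Logging with Serializable {

  // Option values are computed eagerly from `parameters` (e.g. via the getChar
  // helper above), so they survive serialization even though the map does not.
  // ...
}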
@@ -29,7 +29,7 @@ import org.apache.spark.Logging
* @param params Parameters object
* @param headers headers for the columns
*/
private[sql] abstract class CsvReader(params: CSVParameters, headers: Seq[String]) {
private[sql] abstract class CsvReader(params: CSVOptions, headers: Seq[String]) {

protected lazy val parser: CsvParser = {
val settings = new CsvParserSettings()
@@ -58,7 +58,7 @@ private[sql] abstract class CsvReader(params: CSVParameters, headers: Seq[String
* @param params Parameters object for configuration
* @param headers headers for columns
*/
private[sql] class LineCsvWriter(params: CSVParameters, headers: Seq[String]) extends Logging {
private[sql] class LineCsvWriter(params: CSVOptions, headers: Seq[String]) extends Logging {
private val writerSettings = new CsvWriterSettings
private val format = writerSettings.getFormat

@@ -93,7 +93,7 @@ private[sql] class LineCsvWriter(params: CSVParameters, headers: Seq[String]) ex
*
* @param params Parameters object
*/
private[sql] class LineCsvReader(params: CSVParameters)
private[sql] class LineCsvReader(params: CSVOptions)
extends CsvReader(params, null) {
/**
* parse a line
@@ -118,7 +118,7 @@ private[sql] class LineCsvReader(params: CSVParameters)
*/
private[sql] class BulkCsvReader(
iter: Iterator[String],
params: CSVParameters,
params: CSVOptions,
headers: Seq[String])
extends CsvReader(params, headers) with Iterator[Array[String]] {

@@ -43,14 +43,14 @@ private[csv] class CSVRelation(
private val maybeDataSchema: Option[StructType],
override val userDefinedPartitionColumns: Option[StructType],
private val parameters: Map[String, String])
(@transient val sqlContext: SQLContext) extends HadoopFsRelation with Serializable {
(@transient val sqlContext: SQLContext) extends HadoopFsRelation {

override lazy val dataSchema: StructType = maybeDataSchema match {
case Some(structType) => structType
case None => inferSchema(paths)
}

private val params = new CSVParameters(parameters)
private val params = new CSVOptions(parameters)

@transient
private var cachedRDD: Option[RDD[String]] = None
@@ -170,7 +170,7 @@ object CSVRelation extends Logging {
file: RDD[String],
header: Seq[String],
firstLine: String,
params: CSVParameters): RDD[Array[String]] = {
params: CSVOptions): RDD[Array[String]] = {
// If header is set, make sure firstLine is materialized before sending to executors.
file.mapPartitionsWithIndex({
case (split, iter) => new BulkCsvReader(
@@ -186,7 +186,7 @@ object CSVRelation extends Logging {
requiredColumns: Array[String],
inputs: Array[FileStatus],
sqlContext: SQLContext,
params: CSVParameters): RDD[Row] = {
params: CSVOptions): RDD[Row] = {

val schemaFields = schema.fields
val requiredFields = StructType(requiredColumns.map(schema(_))).fields
@@ -249,7 +249,7 @@ object CSVRelation extends Logging {
}
}

private[sql] class CSVOutputWriterFactory(params: CSVParameters) extends OutputWriterFactory {
private[sql] class CSVOutputWriterFactory(params: CSVOptions) extends OutputWriterFactory {
override def newInstance(
path: String,
dataSchema: StructType,
@@ -262,7 +262,7 @@ private[sql] class CsvOutputWriter(
path: String,
dataSchema: StructType,
context: TaskAttemptContext,
params: CSVParameters) extends OutputWriter with Logging {
params: CSVOptions) extends OutputWriter with Logging {

// create the Generator without separator inserted between 2 records
private[this] val text = new Text()
@@ -26,16 +26,30 @@ import org.apache.spark.sql.execution.datasources.CompressionCodecs
*
* Most of these map directly to Jackson's internal options, specified in [[JsonParser.Feature]].
*/
case class JSONOptions(
samplingRatio: Double = 1.0,
primitivesAsString: Boolean = false,
allowComments: Boolean = false,
allowUnquotedFieldNames: Boolean = false,
allowSingleQuotes: Boolean = true,
allowNumericLeadingZeros: Boolean = false,
allowNonNumericNumbers: Boolean = false,
allowBackslashEscapingAnyCharacter: Boolean = false,
compressionCodec: Option[String] = None) {
private[sql] class JSONOptions(
@transient parameters: Map[String, String])
Contributor: Same here; this constructor parameter also needs @transient private val.

extends Serializable {

val samplingRatio =
parameters.get("samplingRatio").map(_.toDouble).getOrElse(1.0)
val primitivesAsString =
parameters.get("primitivesAsString").map(_.toBoolean).getOrElse(false)
val allowComments =
parameters.get("allowComments").map(_.toBoolean).getOrElse(false)
val allowUnquotedFieldNames =
parameters.get("allowUnquotedFieldNames").map(_.toBoolean).getOrElse(false)
val allowSingleQuotes =
parameters.get("allowSingleQuotes").map(_.toBoolean).getOrElse(true)
val allowNumericLeadingZeros =
parameters.get("allowNumericLeadingZeros").map(_.toBoolean).getOrElse(false)
val allowNonNumericNumbers =
parameters.get("allowNonNumericNumbers").map(_.toBoolean).getOrElse(true)
val allowBackslashEscapingAnyCharacter =
parameters.get("allowBackslashEscapingAnyCharacter").map(_.toBoolean).getOrElse(false)
val compressionCodec = {
val name = parameters.get("compression").orElse(parameters.get("codec"))
name.map(CompressionCodecs.getCodecClassName)
}

/** Sets config options on a Jackson [[JsonFactory]]. */
def setJacksonOptions(factory: JsonFactory): Unit = {
@@ -48,28 +62,3 @@ case class JSONOptions(
allowBackslashEscapingAnyCharacter)
}
}

object JSONOptions {
def createFromConfigMap(parameters: Map[String, String]): JSONOptions = JSONOptions(
samplingRatio =
parameters.get("samplingRatio").map(_.toDouble).getOrElse(1.0),
primitivesAsString =
parameters.get("primitivesAsString").map(_.toBoolean).getOrElse(false),
allowComments =
parameters.get("allowComments").map(_.toBoolean).getOrElse(false),
allowUnquotedFieldNames =
parameters.get("allowUnquotedFieldNames").map(_.toBoolean).getOrElse(false),
allowSingleQuotes =
parameters.get("allowSingleQuotes").map(_.toBoolean).getOrElse(true),
allowNumericLeadingZeros =
parameters.get("allowNumericLeadingZeros").map(_.toBoolean).getOrElse(false),
allowNonNumericNumbers =
parameters.get("allowNonNumericNumbers").map(_.toBoolean).getOrElse(true),
allowBackslashEscapingAnyCharacter =
parameters.get("allowBackslashEscapingAnyCharacter").map(_.toBoolean).getOrElse(false),
compressionCodec = {
val name = parameters.get("compression").orElse(parameters.get("codec"))
name.map(CompressionCodecs.getCodecClassName)
}
)
}
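With the factory object gone, callers build the options directly from the datasource's string-keyed parameter map. A small usage sketch, not taken from the diff: the key names are the ones read above, while the values are illustrative, and the code assumes it lives under org.apache.spark.sql since the class is private[sql].

import com.fasterxml.jackson.core.JsonFactory

// Sketch: construct options from the same map a user would pass to the datasource.
val options = new JSONOptions(Map(
  "samplingRatio" -> "0.5",   // fraction of the input sampled for schema inference
  "allowComments" -> "true",  // forwarded to JsonParser.Feature.ALLOW_COMMENTS
  "compression" -> "gzip"     // resolved via CompressionCodecs.getCodecClassName
))

val factory = new JsonFactory()
options.setJacksonOptions(factory)  // pushes the boolean flags onto the Jackson factory

The test updates below follow the same pattern with an empty map, which simply yields the defaults.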
@@ -75,7 +75,7 @@ private[sql] class JSONRelation(
(@transient val sqlContext: SQLContext)
extends HadoopFsRelation(maybePartitionSpec, parameters) {

val options: JSONOptions = JSONOptions.createFromConfigMap(parameters)
val options: JSONOptions = new JSONOptions(parameters)

/** Constraints to be imposed on schema to be stored. */
private def checkConstraints(schema: StructType): Unit = {
@@ -1240,7 +1240,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData {

test("SPARK-6245 JsonRDD.inferSchema on empty RDD") {
// This is really a test that it doesn't throw an exception
val emptySchema = InferSchema.infer(empty, "", JSONOptions())
val emptySchema = InferSchema.infer(empty, "", new JSONOptions(Map()))
assert(StructType(Seq()) === emptySchema)
}

@@ -1264,7 +1264,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData {
}

test("SPARK-8093 Erase empty structs") {
val emptySchema = InferSchema.infer(emptyRecords, "", JSONOptions())
val emptySchema = InferSchema.infer(emptyRecords, "", new JSONOptions(Map()))
assert(StructType(Seq()) === emptySchema)
}

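As a quick sanity sketch of what new JSONOptions(Map()) gives the tests above, based on the getOrElse fallbacks in the class (illustrative asserts, not part of the suite):

// Defaults when no options are supplied, per the fallbacks in JSONOptions.
val defaults = new JSONOptions(Map())
assert(defaults.samplingRatio == 1.0)
assert(defaults.allowSingleQuotes)          // defaults to true
assert(defaults.allowNonNumericNumbers)     // defaults to true
assert(!defaults.allowComments)             // defaults to false
assert(defaults.compressionCodec.isEmpty)   // no compression unless requested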