[SPARK-24959][SQL] Speed up count() for JSON and CSV #21909

Closed
wants to merge 26 commits
Commits
bc4ce26
Added a benchmark for count()
MaxGekk Jul 28, 2018
91250d2
Added a CSV benchmark for count()
MaxGekk Jul 28, 2018
bdc5ea5
Speed up count()
MaxGekk Jul 28, 2018
d40f9bb
Updating CSV and JSON benchmarks for count()
MaxGekk Jul 28, 2018
abd8572
Fix benchmark's output
MaxGekk Jul 28, 2018
359c4fc
Uncomment other benchmarks
MaxGekk Jul 28, 2018
168eb99
A SQL config for bypassing parser in the case of empty schema
MaxGekk Aug 3, 2018
05c8dbb
Making Scala style checker happy
MaxGekk Aug 3, 2018
6248c01
Merge remote-tracking branch 'origin/master' into empty-schema-optimi…
MaxGekk Aug 3, 2018
0e245a7
Merge remote-tracking branch 'origin/master' into empty-schema-optimi…
MaxGekk Aug 5, 2018
4a8a2eb
Put config to the legacy namespace
MaxGekk Aug 5, 2018
3f8fc5e
Updating the migration guide
MaxGekk Aug 5, 2018
c40dc3d
Merge remote-tracking branch 'origin/master' into empty-schema-optimi…
MaxGekk Aug 14, 2018
da16234
Revert unnecessary changes
MaxGekk Aug 14, 2018
900bd0e
Test for malformed JSON input
MaxGekk Aug 14, 2018
f5f13fa
Test for malformed CSV input
MaxGekk Aug 14, 2018
12d50d0
Handle errors caused by wrong input
MaxGekk Aug 15, 2018
2f74059
Adding tests for count and wrong encoding for input json
MaxGekk Aug 15, 2018
6b98f3e
Migration guide is updated
MaxGekk Aug 15, 2018
2998363
Merge remote-tracking branch 'origin/master' into empty-schema-optimi…
MaxGekk Aug 16, 2018
6b34018
Fix a typo
MaxGekk Aug 16, 2018
3240405
Skip parsing for the PERMISSIVE mode only
MaxGekk Aug 17, 2018
2d8e754
Revert test for invalid encoding
MaxGekk Aug 17, 2018
96a94cc
Removing an unnecessary note in migration guide
MaxGekk Aug 17, 2018
50a0ef0
Removing the SQL config
MaxGekk Aug 18, 2018
050c8ce
Renaming optimizeEmptySchema to isMultiLine
MaxGekk Aug 18, 2018
@@ -18,6 +18,7 @@
package org.apache.spark.sql.catalyst.json

import java.io.{ByteArrayOutputStream, CharConversionException}
import java.nio.charset.MalformedInputException

import scala.collection.mutable.ArrayBuffer
import scala.util.Try
@@ -402,7 +403,7 @@ class JacksonParser(
}
}
} catch {
case e @ (_: RuntimeException | _: JsonProcessingException) =>
case e @ (_: RuntimeException | _: JsonProcessingException | _: MalformedInputException) =>
Member

Is this change related, @MaxGekk? Let's not add unrelated changes next time.

// JSON parser currently doesn't support partial results for corrupted records.
// For such records, all fields other than the field configured by
// `columnNameOfCorruptRecord` are set to `null`.
@@ -450,7 +450,8 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
input => rawParser.parse(input, createParser, UTF8String.fromString),
parsedOptions.parseMode,
schema,
parsedOptions.columnNameOfCorruptRecord)
parsedOptions.columnNameOfCorruptRecord,
parsedOptions.multiLine)
iter.flatMap(parser.parse)
}
sparkSession.internalCreateDataFrame(parsed, schema, isStreaming = jsonDataset.isStreaming)
@@ -521,7 +522,8 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
input => Seq(rawParser.parse(input)),
parsedOptions.parseMode,
schema,
parsedOptions.columnNameOfCorruptRecord)
parsedOptions.columnNameOfCorruptRecord,
parsedOptions.multiLine)
iter.flatMap(parser.parse)
}
sparkSession.internalCreateDataFrame(parsed, schema, isStreaming = csvDataset.isStreaming)
@@ -21,14 +21,16 @@ import org.apache.spark.SparkException
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.GenericInternalRow
import org.apache.spark.sql.catalyst.util._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.StructType
import org.apache.spark.unsafe.types.UTF8String

class FailureSafeParser[IN](
rawParser: IN => Seq[InternalRow],
mode: ParseMode,
schema: StructType,
columnNameOfCorruptRecord: String) {
columnNameOfCorruptRecord: String,
isMultiLine: Boolean) {

private val corruptFieldIndex = schema.getFieldIndex(columnNameOfCorruptRecord)
private val actualSchema = StructType(schema.filterNot(_.name == columnNameOfCorruptRecord))
@@ -56,9 +58,15 @@ class FailureSafeParser[IN](
}
}

private val skipParsing = !isMultiLine && mode == PermissiveMode && schema.isEmpty
Member

not a big deal but I would leave a comment to explain why it's permissive and non-multiline only. I assume counts are only known when the input is actually parsed in multiline cases, and counts should be given in any case when the mode is permissive, right?
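For reference, a minimal sketch of the kind of inline comment being suggested, assuming the reasoning above holds (the wording is not part of the PR):

// Parsing can be skipped only when nothing has to be materialized per record:
// - schema.isEmpty: no columns are required, e.g. for count() after column pruning;
// - PermissiveMode: a malformed record still produces a row, so the row count does not
//   depend on whether the record parses (FAILFAST or DROPMALFORMED would change the count);
// - !isMultiLine: with multiLine enabled, record boundaries are not known without parsing,
//   so one input value cannot be assumed to produce exactly one row.
private val skipParsing = !isMultiLine && mode == PermissiveMode && schema.isEmpty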


def parse(input: IN): Iterator[InternalRow] = {
try {
rawParser.apply(input).toIterator.map(row => toResultRow(Some(row), () => null))
if (skipParsing) {
Iterator.single(InternalRow.empty)

nit: Iterator.empty

Member Author

It is not the same. If you return an empty iterator, count() will always return 0.
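A minimal sketch of the difference, using plain Scala iterators (the record values are illustrative):

import org.apache.spark.sql.catalyst.InternalRow

// With skipParsing, every input record must still contribute exactly one (empty) row,
// otherwise the downstream count() would see no rows at all.
val skipped = Iterator("rec1", "rec2", "rec3").flatMap(_ => Iterator.single(InternalRow.empty))
assert(skipped.size == 3)  // count() over three input records is 3

val dropped = Iterator("rec1", "rec2", "rec3").flatMap(_ => Iterator.empty)
assert(dropped.size == 0)  // with Iterator.empty, count() would always return 0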


ohh yes my bad!

} else {
rawParser.apply(input).toIterator.map(row => toResultRow(Some(row), () => null))
Member

If there are broken records that the parser can't parse, this skipping won't detect them?

Member Author
@MaxGekk MaxGekk Aug 1, 2018

Yes. To detect them with a 100% guarantee, the parser must fully parse such records, and column values must be cast according to the types in the data schema. We actually don't do that due to the column pruning mechanisms in both datasources, CSV and JSON.
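A minimal sketch of that point, in the style of the test suites further down, assuming a SparkSession named spark (the column name and value are illustrative):

import org.apache.spark.sql.Encoders
import org.apache.spark.sql.types.{IntegerType, StructType}

// "not a number" cannot be cast to IntegerType, i.e. the record is semantically broken,
// but count() never casts the value, so the record is still counted in PERMISSIVE mode.
val schema = new StructType().add("a", IntegerType)
val broken = spark.createDataset(Seq("""{"a": "not a number"}"""))(Encoders.STRING)
assert(spark.read.schema(schema).json(broken).count() == 1)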

Member

Could you add a test case for counting both the CSV and JSON sources when the files have broken records? Any behavior change after this PR?

Member Author

... when the files have broken records?

Syntactically broken or semantically (wrong types for example)?

Any behavior change after this PR?

We have many tests in CSVSuite and JsonSuite for broken records. I have found a behavior change in only one case: https://github.com/apache/spark/pull/21909/files#diff-fde14032b0e6ef8086461edf79a27c5dL2227 . This is because the Jackson parser touches the first few bytes of the input stream even if it is not called; Jackson checks the encoding eagerly.

Member

both?

If we introduce a behavior change, we need to document it in the migration guide and add a conf. Users can set the conf to revert to the previous behavior.

Member Author

I added the tests

}
} catch {
case e: BadRecordException => mode match {
case PermissiveMode =>
@@ -203,19 +203,11 @@ class UnivocityParser(
}
}

private val doParse = if (requiredSchema.nonEmpty) {

Are the changes here https://github.com/apache/spark/pull/21909/files#diff-3a4dc120191f7052e5d98db11934bfb5R63 replacing the need for the requiredSchema.nonEmpty check?

Member Author

The introduced optimization works in the case when multiLine is disabled. In that case, this removed code was used. It is not needed anymore because it just duplicates the optimization in some sense.

(input: String) => convert(tokenizer.parseLine(input))
} else {
// If `columnPruning` enabled and partition attributes scanned only,
// `schema` gets empty.
(_: String) => InternalRow.empty
}

/**
* Parses a single CSV string and turns it into either one resulting row or no row (if the
* record is malformed).
*/
def parse(input: String): InternalRow = doParse(input)
def parse(input: String): InternalRow = convert(tokenizer.parseLine(input))

private val getToken = if (options.columnPruning) {
(tokens: Array[String], index: Int) => tokens(index)
@@ -293,7 +285,8 @@ private[csv] object UnivocityParser {
input => Seq(parser.convert(input)),
parser.options.parseMode,
schema,
parser.options.columnNameOfCorruptRecord)
parser.options.columnNameOfCorruptRecord,
parser.options.multiLine)
convertStream(inputStream, shouldDropHeader, tokenizer, checkHeader) { tokens =>
safeParser.parse(tokens)
}.flatten
@@ -341,7 +334,8 @@ private[csv] object UnivocityParser {
input => Seq(parser.parse(input)),
parser.options.parseMode,
schema,
parser.options.columnNameOfCorruptRecord)
parser.options.columnNameOfCorruptRecord,
parser.options.multiLine)
filteredLines.flatMap(safeParser.parse)
}
}
@@ -139,7 +139,8 @@ object TextInputJsonDataSource extends JsonDataSource {
input => parser.parse(input, textParser, textToUTF8String),
parser.options.parseMode,
schema,
parser.options.columnNameOfCorruptRecord)
parser.options.columnNameOfCorruptRecord,
parser.options.multiLine)
linesReader.flatMap(safeParser.parse)
}

@@ -223,7 +224,8 @@ object MultiLineJsonDataSource extends JsonDataSource {
input => parser.parse[InputStream](input, streamParser, partitionedFileString),
parser.options.parseMode,
schema,
parser.options.columnNameOfCorruptRecord)
parser.options.columnNameOfCorruptRecord,
parser.options.multiLine)

safeParser.parse(
CodecStreams.createInputStreamWithCloseResource(conf, new Path(new URI(file.filePath))))
@@ -119,8 +119,47 @@ object CSVBenchmarks {
}
}

def countBenchmark(rowsNum: Int): Unit = {
val colsNum = 10
val benchmark = new Benchmark(s"Count a dataset with $colsNum columns", rowsNum)

withTempPath { path =>
val fields = Seq.tabulate(colsNum)(i => StructField(s"col$i", IntegerType))
val schema = StructType(fields)

spark.range(rowsNum)
.select(Seq.tabulate(colsNum)(i => lit(i).as(s"col$i")): _*)
.write
.csv(path.getAbsolutePath)

val ds = spark.read.schema(schema).csv(path.getAbsolutePath)

benchmark.addCase(s"Select $colsNum columns + count()", 3) { _ =>
ds.select("*").filter((_: Row) => true).count()
}
benchmark.addCase(s"Select 1 column + count()", 3) { _ =>
ds.select($"col1").filter((_: Row) => true).count()
Member

does this benchmark result vary if we select col2 or col10?

}
benchmark.addCase(s"count()", 3) { _ =>
ds.count()
}

/*
Intel(R) Core(TM) i7-7700HQ CPU @ 2.80GHz

Count a dataset with 10 columns: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
---------------------------------------------------------------------------------------------
Select 10 columns + count() 12598 / 12740 0.8 1259.8 1.0X
Select 1 column + count() 7960 / 8175 1.3 796.0 1.6X
count() 2332 / 2386 4.3 233.2 5.4X
*/
benchmark.run()
}
}

def main(args: Array[String]): Unit = {
quotedValuesBenchmark(rowsNum = 50 * 1000, numIters = 3)
multiColumnsBenchmark(rowsNum = 1000 * 1000)
countBenchmark(10 * 1000 * 1000)
}
}
@@ -1641,4 +1641,30 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils with Te
}
}
}

test("count() for malformed input") {
def countForMalformedCSV(expected: Long, input: Seq[String]): Unit = {
val schema = new StructType().add("a", IntegerType)
val strings = spark.createDataset(input)
val df = spark.read.schema(schema).option("header", false).csv(strings)

assert(df.count() == expected)
}
def checkCount(expected: Long): Unit = {
val validRec = "1"
val inputs = Seq(
Seq("{-}", validRec),
Seq(validRec, "?"),
Seq("0xAC", validRec),
Seq(validRec, "0.314"),
Seq("\\\\\\", validRec)
)
inputs.foreach { input =>
countForMalformedCSV(expected, input)
}
}

checkCount(2)
countForMalformedCSV(0, Seq(""))
}
}
@@ -19,8 +19,9 @@ package org.apache.spark.sql.execution.datasources.json
import java.io.File

import org.apache.spark.SparkConf
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.types.{LongType, StringType, StructType}
import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.functions.lit
import org.apache.spark.sql.types._
import org.apache.spark.util.{Benchmark, Utils}

/**
@@ -171,9 +172,49 @@ object JSONBenchmarks {
}
}

def countBenchmark(rowsNum: Int): Unit = {
val colsNum = 10
val benchmark = new Benchmark(s"Count a dataset with $colsNum columns", rowsNum)

withTempPath { path =>
val fields = Seq.tabulate(colsNum)(i => StructField(s"col$i", IntegerType))
val schema = StructType(fields)
val columnNames = schema.fieldNames

spark.range(rowsNum)
.select(Seq.tabulate(colsNum)(i => lit(i).as(s"col$i")): _*)
.write
.json(path.getAbsolutePath)

val ds = spark.read.schema(schema).json(path.getAbsolutePath)

benchmark.addCase(s"Select $colsNum columns + count()", 3) { _ =>
ds.select("*").filter((_: Row) => true).count()
}
benchmark.addCase(s"Select 1 column + count()", 3) { _ =>
ds.select($"col1").filter((_: Row) => true).count()
}
benchmark.addCase(s"count()", 3) { _ =>
ds.count()
}

/*
Intel(R) Core(TM) i7-7700HQ CPU @ 2.80GHz

Count a dataset with 10 columns: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
---------------------------------------------------------------------------------------------
Select 10 columns + count() 9961 / 10006 1.0 996.1 1.0X
Select 1 column + count() 8355 / 8470 1.2 835.5 1.2X
count() 2104 / 2156 4.8 210.4 4.7X
*/
benchmark.run()
}
}

def main(args: Array[String]): Unit = {
schemaInferring(100 * 1000 * 1000)
perlineParsing(100 * 1000 * 1000)
perlineParsingOfWideColumn(10 * 1000 * 1000)
countBenchmark(10 * 1000 * 1000)
}
}
@@ -2223,7 +2223,6 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData {
checkAnswer(jsonDF, Seq(Row("Chris", "Baird")))
}


test("SPARK-23723: specified encoding is not matched to actual encoding") {
val fileName = "test-data/utf16LE.json"
val schema = new StructType().add("firstName", StringType).add("lastName", StringType)
@@ -2490,4 +2489,30 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData {
assert(exception.getMessage.contains("encoding must not be included in the blacklist"))
}
}

test("count() for malformed input") {
def countForMalformedJSON(expected: Long, input: Seq[String]): Unit = {
val schema = new StructType().add("a", StringType)
val strings = spark.createDataset(input)
val df = spark.read.schema(schema).json(strings)

assert(df.count() == expected)
}
def checkCount(expected: Long): Unit = {
val validRec = """{"a":"b"}"""
val inputs = Seq(
Seq("{-}", validRec),
Seq(validRec, "?"),
Seq("}", validRec),
Seq(validRec, """{"a": [1, 2, 3]}"""),
Seq("""{"a": {"a": "b"}}""", validRec)
)
inputs.foreach { input =>
countForMalformedJSON(expected, input)
}
}

checkCount(2)
countForMalformedJSON(0, Seq(""))
}
}