Support for PERMISSIVE/DROPMALFORMED mode and corrupt record option.
databricks/spark-xml#105

Currently, this library does not support `PERMISSIVE` parse mode. As with the JSON data source, this can be done in the same way with `_corrupt_record`.

This PR adds support for `PERMISSIVE` mode and makes the behaviour consistent with the other data sources that support parse modes (the JSON and CSV data sources).

Also, this PR adds support for `_corrupt_record`.

This PR is similar to apache/spark#11756 and apache/spark#11881.
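
For illustration, a hedged usage sketch of the three modes through the standard reader API (the `SQLContext` setup and the file path are assumptions, not part of this change):

```scala
import org.apache.spark.sql.SQLContext

val sqlContext: SQLContext = ???  // assumes an existing SQLContext

// PERMISSIVE (the new default): a malformed row keeps null for every
// regular field and the raw XML string lands in `_corrupt_record`.
val permissive = sqlContext.read
  .format("com.databricks.spark.xml")
  .option("rowTag", "ROW")
  .option("mode", "PERMISSIVE")
  .option("columnNameOfCorruptRecord", "_corrupt_record")
  .load("cars-malformed.xml")

// DROPMALFORMED: malformed rows are dropped, with a warning in the logs.
val dropped = sqlContext.read
  .format("com.databricks.spark.xml")
  .option("rowTag", "ROW")
  .option("mode", "DROPMALFORMED")
  .load("cars-malformed.xml")

// FAILFAST: the first malformed row raises an exception.
```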

Author: hyukjinkwon <[email protected]>

Closes #107 from HyukjinKwon/ISSUE-105-permissive.
beluisterql authored and HyukjinKwon committed Sep 10, 2016
1 parent 982f5f9 commit bdb9ea3
Showing 8 changed files with 154 additions and 87 deletions.
6 changes: 5 additions & 1 deletion README.md
@@ -50,7 +50,11 @@ When reading files the API accepts several options:
 * `samplingRatio`: Sampling ratio for inferring schema (0.0 ~ 1). Default is 1. Possible types are `StructType`, `ArrayType`, `StringType`, `LongType`, `DoubleType`, `BooleanType`, `TimestampType` and `NullType`, unless the user provides a schema for this.
 * `excludeAttribute` : Whether you want to exclude attributes in elements or not. Default is false.
 * `treatEmptyValuesAsNulls` : Whether you want to treat whitespaces as a null value. Default is false.
-* `failFast` : Whether you want to fail when it fails to parse malformed rows in XML files, instead of dropping the rows. Default is false.
+* `mode`: The mode for dealing with corrupt records during parsing. Default is `PERMISSIVE`.
+  * `PERMISSIVE` : sets other fields to `null` when it encounters a corrupted record, and puts the malformed string into a new field configured by `columnNameOfCorruptRecord`. When a schema is set by the user, it sets `null` for extra fields.
+  * `DROPMALFORMED` : ignores whole corrupted records.
+  * `FAILFAST` : throws an exception when it encounters corrupted records.
+* `columnNameOfCorruptRecord`: The name of the new field where malformed strings are stored. Default is `_corrupt_record`.
 * `attributePrefix`: The prefix for attributes so that we can differentiate attributes and elements. This will be the prefix for field names. Default is `_`.
 * `valueTag`: The tag used for the value when there are attributes in the element having no child. Default is `_VALUE`.
 * `charset`: Defaults to 'UTF-8' but can be set to other valid charset names
25 changes: 24 additions & 1 deletion src/main/scala/com/databricks/spark/xml/XmlOptions.scala
@@ -15,12 +15,17 @@
  */
 package com.databricks.spark.xml
 
+import org.slf4j.LoggerFactory
+
+import com.databricks.spark.xml.util.ParseModes
 
 /**
  * Options for the XML data source.
  */
 private[xml] class XmlOptions(
     @transient private val parameters: Map[String, String])
   extends Serializable {
+  private val logger = LoggerFactory.getLogger(XmlRelation.getClass)
+
   val charset = parameters.getOrElse("charset", XmlOptions.DEFAULT_CHARSET)
   val codec = parameters.get("compression").orElse(parameters.get("codec")).orNull
@@ -30,11 +35,29 @@ private[xml] class XmlOptions(
   val excludeAttributeFlag = parameters.get("excludeAttribute").map(_.toBoolean).getOrElse(false)
   val treatEmptyValuesAsNulls =
     parameters.get("treatEmptyValuesAsNulls").map(_.toBoolean).getOrElse(false)
-  val failFastFlag = parameters.get("failFast").map(_.toBoolean).getOrElse(false)
   val attributePrefix =
     parameters.getOrElse("attributePrefix", XmlOptions.DEFAULT_ATTRIBUTE_PREFIX)
   val valueTag = parameters.getOrElse("valueTag", XmlOptions.DEFAULT_VALUE_TAG)
   val nullValue = parameters.getOrElse("nullValue", XmlOptions.DEFAULT_NULL_VALUE)
+  val columnNameOfCorruptRecord =
+    parameters.getOrElse("columnNameOfCorruptRecord", "_corrupt_record")
 
+  // Leave this option for backwards compatibility.
+  private val failFastFlag = parameters.get("failFast").map(_.toBoolean).getOrElse(false)
+  private val parseMode = if (failFastFlag) {
+    parameters.getOrElse("mode", ParseModes.FAIL_FAST_MODE)
+  } else {
+    parameters.getOrElse("mode", ParseModes.PERMISSIVE_MODE)
+  }
+
+  // Parse mode flags
+  if (!ParseModes.isValidMode(parseMode)) {
+    logger.warn(s"$parseMode is not a valid parse mode. Using ${ParseModes.DEFAULT}.")
+  }
+
+  val failFast = ParseModes.isFailFastMode(parseMode)
+  val dropMalformed = ParseModes.isDropMalformedMode(parseMode)
+  val permissive = ParseModes.isPermissiveMode(parseMode)
+
   require(rowTag.nonEmpty, "'rowTag' option should not be empty string.")
   require(attributePrefix.nonEmpty, "'attributePrefix' option should not be empty string.")
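
The interaction between the legacy `failFast` flag and the new `mode` option deserves a note: the flag only changes the default, while an explicit `mode` always wins. A standalone sketch of that resolution rule (a re-creation for illustration, not the library code itself):

```scala
object ParseModeResolution {
  // Mirrors the option-resolution logic in XmlOptions above.
  def resolve(parameters: Map[String, String]): String = {
    val failFastFlag = parameters.get("failFast").exists(_.toBoolean)
    if (failFastFlag) parameters.getOrElse("mode", "FAILFAST")
    else parameters.getOrElse("mode", "PERMISSIVE")
  }

  def main(args: Array[String]): Unit = {
    assert(resolve(Map("failFast" -> "true")) == "FAILFAST")
    assert(resolve(Map("failFast" -> "true", "mode" -> "DROPMALFORMED")) == "DROPMALFORMED")
    assert(resolve(Map.empty) == "PERMISSIVE")
  }
}
```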
10 changes: 10 additions & 0 deletions src/main/scala/com/databricks/spark/xml/XmlReader.scala
@@ -62,6 +62,11 @@ class XmlReader extends Serializable {
     this
   }
 
+  def withParseMode(mode: String): XmlReader = {
+    parameters += ("mode" -> mode)
+    this
+  }
+
   def withAttributePrefix(attributePrefix: String): XmlReader = {
     parameters += ("attributePrefix" -> attributePrefix)
     this
@@ -72,6 +77,11 @@
     this
   }
 
+  def withColumnNameOfCorruptRecord(name: String): XmlReader = {
+    parameters += ("columnNameOfCorruptRecord" -> name)
+    this
+  }
+
   def withSchema(schema: StructType): XmlReader = {
     this.schema = schema
     this
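
A hedged sketch of the two new builder methods used together (the `sqlContext` and the file path are assumed for illustration):

```scala
val df = new XmlReader()
  .withParseMode("PERMISSIVE")
  .withColumnNameOfCorruptRecord("_malformed_records")
  .xmlFile(sqlContext, "cars-malformed.xml")

// Malformed inputs survive as rows whose only non-null column is
// `_malformed_records`, holding the raw XML fragment.
df.select("_malformed_records").show()
```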
22 changes: 3 additions & 19 deletions src/main/scala/com/databricks/spark/xml/XmlRelation.scala
@@ -61,28 +61,12 @@ case class XmlRelation protected[spark] (
     val schemaFields = schema.fields
     if (schemaFields.deep == requiredFields.deep) {
       buildScan()
-    } else if (options.failFastFlag) {
-      val safeRequestedSchema = StructType(requiredFields)
-      StaxXmlParser.parse(
-        baseRDD(),
-        safeRequestedSchema,
-        options)
     } else {
-      // If `failFast` is disabled, then it needs to parse all the values
-      // so that we can decide which row is malformed.
-      val safeRequestedSchema = StructType(
-        requiredFields ++ schema.fields.filterNot(requiredFields.contains(_)))
-      val rows = StaxXmlParser.parse(
+      val requestedSchema = StructType(requiredFields)
+      StaxXmlParser.parse(
         baseRDD(),
-        safeRequestedSchema,
+        requestedSchema,
         options)
-
-      val rowSize = requiredFields.length
-      rows.mapPartitions { iter =>
-        iter.flatMap { xml =>
-          Some(Row.fromSeq(xml.toSeq.take(rowSize)))
-        }
-      }
     }
   }

41 changes: 26 additions & 15 deletions src/main/scala/com/databricks/spark/xml/parsers/StaxXmlParser.scala
@@ -41,7 +41,25 @@ private[xml] object StaxXmlParser {
       xml: RDD[String],
       schema: StructType,
       options: XmlOptions): RDD[Row] = {
-    val failFast = options.failFastFlag
+    def failedRecord(record: String): Option[Row] = {
+      // create a row even if no corrupt record column is present
+      if (options.failFast) {
+        throw new RuntimeException(
+          s"Malformed line in FAILFAST mode: ${record.replaceAll("\n", "")}")
+      } else if (options.dropMalformed) {
+        logger.warn(s"Dropping malformed line: ${record.replaceAll("\n", "")}")
+        None
+      } else {
+        val row = new Array[Any](schema.length)
+        val nameToIndex = schema.map(_.name).zipWithIndex.toMap
+        nameToIndex.get(options.columnNameOfCorruptRecord).foreach { corruptIndex =>
+          require(schema(corruptIndex).dataType == StringType)
+          row.update(corruptIndex, record)
+        }
+        Some(Row.fromSeq(row))
+      }
+    }
+
     xml.mapPartitions { iter =>
       val factory = XMLInputFactory.newInstance()
       factory.setProperty(XMLInputFactory.IS_NAMESPACE_AWARE, false)
@@ -63,22 +81,15 @@
           StaxXmlParserUtils.skipUntil(parser, XMLStreamConstants.START_ELEMENT)
           val rootAttributes =
             rootEvent.asStartElement.getAttributes.map(_.asInstanceOf[Attribute]).toArray
-
           Some(convertObject(parser, schema, options, rootAttributes))
+            .orElse(failedRecord(xml))
         } catch {
-          case _: java.lang.NumberFormatException if !failFast =>
-            logger.warn("Number format exception. " +
-              s"Dropping malformed line: ${xml.replaceAll("\n", "")}")
-            None
-          case _: java.text.ParseException | _: IllegalArgumentException if !failFast =>
-            logger.warn("Parse exception. " +
-              s"Dropping malformed line: ${xml.replaceAll("\n", "")}")
-            None
-          case _: XMLStreamException if failFast =>
-            throw new RuntimeException(s"Malformed row (failing fast): ${xml.replaceAll("\n", "")}")
-          case _: XMLStreamException if !failFast =>
-            logger.warn(s"Dropping malformed row: ${xml.replaceAll("\n", "")}")
-            None
+          case _: java.lang.NumberFormatException =>
+            failedRecord(xml)
+          case _: java.text.ParseException | _: IllegalArgumentException =>
+            failedRecord(xml)
+          case _: XMLStreamException =>
+            failedRecord(xml)
         }
       }
     }
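
To make the `failedRecord` contract concrete, here is a standalone re-creation of its PERMISSIVE branch (a sketch under an assumed schema; the real method lives inside the parser above):

```scala
import org.apache.spark.sql.Row
import org.apache.spark.sql.types._

// Every regular field stays null; the raw record is stored only if the
// schema actually carries the corrupt-record column.
def permissiveRecord(record: String, schema: StructType, corruptCol: String): Row = {
  val row = new Array[Any](schema.length)
  schema.fieldNames.zipWithIndex.toMap.get(corruptCol).foreach { i =>
    require(schema(i).dataType == StringType)
    row.update(i, record)
  }
  Row.fromSeq(row)
}

val demoSchema = StructType(Seq(
  StructField("_corrupt_record", StringType),
  StructField("make", StringType),
  StructField("model", StringType)))

// Prints: [<ROW>bad</ROW>,null,null]
println(permissiveRecord("<ROW>bad</ROW>", demoSchema, "_corrupt_record"))
```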
src/main/scala/com/databricks/spark/xml/util/InferSchema.scala
@@ -69,12 +69,12 @@
   def infer(xml: RDD[String], options: XmlOptions): StructType = {
     require(options.samplingRatio > 0,
       s"samplingRatio (${options.samplingRatio}) should be greater than 0")
+    val shouldHandleCorruptRecord = options.permissive
     val schemaData = if (options.samplingRatio > 0.99) {
       xml
     } else {
       xml.sample(withReplacement = false, options.samplingRatio, 1)
     }
-    val failFast = options.failFastFlag
     // perform schema inference on each row and merge afterwards
     val rootType = schemaData.mapPartitions { iter =>
       val factory = XMLInputFactory.newInstance()
@@ -100,11 +100,10 @@

         Some(inferObject(parser, options, rootAttributes))
       } catch {
-        case _: XMLStreamException if !failFast =>
-          logger.warn(s"Dropping malformed row: ${xml.replaceAll("\n", "")}")
+        case _: XMLStreamException if shouldHandleCorruptRecord =>
+          Some(StructType(Seq(StructField(options.columnNameOfCorruptRecord, StringType))))
+        case _: XMLStreamException =>
           None
-        case _: XMLStreamException if failFast =>
-          throw new RuntimeException(s"Malformed row (failing fast): ${xml.replaceAll("\n", "")}")
       }
     }
   }.treeAggregate[DataType](StructType(Seq()))(
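
The user-visible effect of this inference change: in PERMISSIVE mode, rows that fail to parse contribute the corrupt-record column, which is then merged with the fields inferred from well-formed rows. A hedged sketch; the file path and the exact field order of the output are illustrative only:

```scala
val df = sqlContext.read
  .format("com.databricks.spark.xml")
  .option("rowTag", "ROW")
  .load("cars-malformed.xml")

df.printSchema()
// root
//  |-- _corrupt_record: string (nullable = true)
//  |-- comment: string (nullable = true)
//  |-- make: string (nullable = true)
//  |-- model: string (nullable = true)
//  |-- year: long (nullable = true)
```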
39 changes: 39 additions & 0 deletions src/main/scala/com/databricks/spark/xml/util/ParseModes.scala
@@ -0,0 +1,39 @@
/*
* Copyright 2014 Databricks
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.databricks.spark.xml.util

private[xml] object ParseModes {
  val PERMISSIVE_MODE = "PERMISSIVE"
  val DROP_MALFORMED_MODE = "DROPMALFORMED"
  val FAIL_FAST_MODE = "FAILFAST"

  val DEFAULT = PERMISSIVE_MODE

  def isValidMode(mode: String): Boolean = {
    mode.toUpperCase match {
      case PERMISSIVE_MODE | DROP_MALFORMED_MODE | FAIL_FAST_MODE => true
      case _ => false
    }
  }

  def isDropMalformedMode(mode: String): Boolean = mode.toUpperCase == DROP_MALFORMED_MODE
  def isFailFastMode(mode: String): Boolean = mode.toUpperCase == FAIL_FAST_MODE
  def isPermissiveMode(mode: String): Boolean = if (isValidMode(mode)) {
    mode.toUpperCase == PERMISSIVE_MODE
  } else {
    true // We default to permissive if the mode string is not valid
  }
}
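
Note that mode matching is case-insensitive, and an unrecognized string falls back to permissive (after the warning logged by `XmlOptions`). A quick REPL-style illustration:

```scala
ParseModes.isValidMode("dropmalformed")         // true: comparison is upper-cased
ParseModes.isDropMalformedMode("DropMalformed") // true
ParseModes.isFailFastMode("failfast")           // true
ParseModes.isPermissiveMode("bogus")            // true: invalid modes fall back to permissive
```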
89 changes: 43 additions & 46 deletions src/test/scala/com/databricks/spark/xml/XmlSuite.scala
@@ -31,6 +31,7 @@ import org.apache.spark.{SparkException, SparkContext}
 import org.apache.spark.sql.{SaveMode, Row, SQLContext}
 import org.apache.spark.sql.types._
 import com.databricks.spark.xml.XmlOptions._
+import com.databricks.spark.xml.util.ParseModes
 
 class XmlSuite extends FunSuite with BeforeAndAfterAll {
   val tempEmptyDir = "target/test/empty/"
@@ -211,58 +212,52 @@

test("DSL test for parsing a malformed XML file") {
val results = new XmlReader()
.withFailFast(false)
.withParseMode(ParseModes.DROP_MALFORMED_MODE)
.xmlFile(sqlContext, carsMalformedFile)

assert(results.count() === 1)
}

test("DSL test for dropping malformed rows") {
val schema = new StructType(
Array(
StructField("color", IntegerType, true),
StructField("make", TimestampType, true),
StructField("model", DoubleType, true),
StructField("comment", StringType, true),
StructField("year", DoubleType, true)
)
)
val results = new XmlReader()
.withSchema(schema)
.xmlFile(sqlContext, carsUnbalancedFile)
.count()
val cars = new XmlReader()
.withParseMode(ParseModes.DROP_MALFORMED_MODE)
.xmlFile(sqlContext, carsMalformedFile)

assert(results === 0)
assert(cars.count() == 1)
assert(cars.head().toSeq === Seq("Chevy", "Volt", 2015))
}

test("DSL test for failing fast") {
// Fail fast in type inference
val exceptionInSchema = intercept[SparkException] {
new XmlReader()
.withFailFast(true)
.xmlFile(sqlContext, carsMalformedFile)
.printSchema()
}
assert(exceptionInSchema.getMessage.contains("Malformed row (failing fast)"))

// Fail fast in parsing data
val schema = new StructType(
Array(
StructField("color", StringType, true),
StructField("make", StringType, true),
StructField("model", StringType, true),
StructField("comment", StringType, true),
StructField("year", StringType, true)
)
)
val exceptionInParse = intercept[SparkException] {
new XmlReader()
.withFailFast(true)
.withSchema(schema)
.xmlFile(sqlContext, carsMalformedFile)
.collect()
}
assert(exceptionInParse.getMessage.contains("Malformed row (failing fast)"))
assert(exceptionInParse.getMessage.contains("Malformed line in FAILFAST mode"))
}

test("DSL test for permissive mode for corrupt records") {
val carsDf = new XmlReader()
.withParseMode(ParseModes.PERMISSIVE_MODE)
.withColumnNameOfCorruptRecord("_malformed_records")
.xmlFile(sqlContext, carsMalformedFile)
val cars = carsDf.collect()
assert(cars.length == 3)

val malformedRowOne = carsDf.select("_malformed_records").first().toSeq.head.toString
val malformedRowTwo = carsDf.select("_malformed_records").take(2).last.toSeq.head.toString
val expectedMalformedRowOne = "<ROW><year>2012</year><make>Tesla</make><model>>S" +
"<comment>No comment</comment></ROW>"
val expectedMalformedRowTwo = "<ROW></year><make>Ford</make><model>E350</model>model></model>" +
"<comment>Go get one now they are going fast</comment></ROW>"

assert(malformedRowOne.replaceAll("\\s", "") === expectedMalformedRowOne.replaceAll("\\s", ""))
assert(malformedRowTwo.replaceAll("\\s", "") === expectedMalformedRowTwo.replaceAll("\\s", ""))
assert(cars(2).toSeq.head === null)
assert(cars(0).toSeq.takeRight(3) === Seq(null, null, null))
assert(cars(1).toSeq.takeRight(3) === Seq(null, null, null))
assert(cars(2).toSeq.takeRight(3) === Seq("Chevy", "Volt", 2015))
}

test("DSL test with empty file and known schema") {
@@ -421,7 +416,7 @@
     val schemaCopy = StructType(
       List(StructField("a", ArrayType(
         StructType(List(StructField("item", ArrayType(StringType), nullable = true)))),
-          nullable = true)))
+        nullable = true)))
     val dfCopy = sqlContext.xmlFile(copyFilePath + "/")
 
     assert(dfCopy.count == df.count)
@@ -461,8 +456,8 @@
     df.saveAsXmlFile(copyFilePath)
 
     val dfCopy = new XmlReader()
-        .withSchema(schema)
-        .xmlFile(sqlContext, copyFilePath + "/")
+      .withSchema(schema)
+      .xmlFile(sqlContext, copyFilePath + "/")
 
     assert(dfCopy.collect() === df.collect())
     assert(dfCopy.schema === df.schema)
@@ -551,11 +546,11 @@
StructField("price", DoubleType, nullable = true),
StructField("publish_dates", StructType(
List(StructField("publish_date",
ArrayType(StructType(
List(StructField(s"${DEFAULT_ATTRIBUTE_PREFIX}tag", StringType, nullable = true),
StructField("day", LongType, nullable = true),
StructField("month", LongType, nullable = true),
StructField("year", LongType, nullable = true))))))),
ArrayType(StructType(
List(StructField(s"${DEFAULT_ATTRIBUTE_PREFIX}tag", StringType, nullable = true),
StructField("day", LongType, nullable = true),
StructField("month", LongType, nullable = true),
StructField("year", LongType, nullable = true))))))),
nullable = true),
StructField("title", StringType, nullable = true))
))
@@ -656,9 +651,11 @@
   }
 
   test("DSL test nullable fields") {
+    val schema = StructType(
+      StructField("name", StringType, false) ::
+      StructField("age", StringType, true) :: Nil)
     val results = new XmlReader()
-      .withSchema(StructType(List(StructField("name", StringType, false),
-        StructField("age", StringType, true))))
+      .withSchema(schema)
       .xmlFile(sqlContext, nullNumbersFile)
       .collect()
