from_xml and schema_for_xml support (#429)
This is a continuation of #334

This implements a `from_xml` expression that can turn an XML string column into a parsed structured column. It also opens up `inferSchema` for a `Dataset[String]` as a necessary support function.

The rest is really mild refactoring.
srowen authored and HyukjinKwon committed Jan 7, 2020
1 parent d155b1b commit ef3af6a
Showing 11 changed files with 341 additions and 89 deletions.
2 changes: 1 addition & 1 deletion .travis.yml
@@ -16,7 +16,7 @@ matrix:
- TEST_SPARK_VERSION="2.4.4"
- jdk: openjdk11
env:
- TEST_SPARK_VERSION="3.0.0-preview"
- TEST_SPARK_VERSION="3.0.0-preview2"
script:
- sbt -Dspark.testVersion=$TEST_SPARK_VERSION ++$TRAVIS_SCALA_VERSION clean scalastyle test:scalastyle mimaReportBinaryIssues coverage test coverageReport
after_success:
23 changes: 23 additions & 0 deletions README.md
@@ -51,8 +51,11 @@ $SPARK_HOME/bin/spark-shell --packages com.databricks:spark-xml_2.12:0.7.0
```

## Features

This package allows reading XML files from a local or distributed filesystem as [Spark DataFrames](https://spark.apache.org/docs/latest/sql-programming-guide.html).

When reading files the API accepts several options:

* `path`: Location of files. As with Spark, this can accept standard Hadoop globbing expressions.
* `rowTag`: The row tag of your xml files to treat as a row. For example, in this xml `<books> <book>...</book> ...</books>`, the appropriate value would be `book`. Default is `ROW`.
* `samplingRatio`: Sampling ratio for inferring schema (0.0 ~ 1). Default is 1. The possible inferred types are `StructType`, `ArrayType`, `StringType`, `LongType`, `DoubleType`, `BooleanType`, `TimestampType` and `NullType`, unless the user provides a schema for this.
@@ -78,6 +81,7 @@ it depends on should be added to the Spark executors with
In this case, to use local XSD `/foo/bar.xsd`, call `addFile("/foo/bar.xsd")` and pass just `"bar.xsd"` as `rowValidationXSDPath`. New in 0.8.0.
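A minimal reading sketch of the above, assuming a `SparkSession` named `spark` and a hypothetical input file `books.xml` to be validated against the local XSD:

```scala
// Ship the local XSD to the executors, then reference it by file name only.
spark.sparkContext.addFile("/foo/bar.xsd")

val validated = spark.read
  .format("com.databricks.spark.xml")
  .option("rowTag", "book")
  .option("rowValidationXSDPath", "bar.xsd")
  .load("books.xml")
```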

When writing files the API accepts several options:

* `path`: Location to write files.
* `rowTag`: The row tag of your xml files to treat as a row. For example, in this xml `<books> <book>...</book> ...</books>`, the appropriate value would be `book`. Default is `ROW`.
* `rootTag`: The root tag of your xml files to treat as the root. For example, in this xml `<books> <book>...</book> ...</books>`, the appropriate value would be `books`. Default is `ROWS`.
@@ -88,6 +92,25 @@ When writing files the API accepts several options:

The shortened name is also supported: you can use just `xml` instead of `com.databricks.spark.xml`.
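For instance, a minimal write sketch using the short name (the DataFrame `df` and the output path here are hypothetical):

```scala
df.write
  .format("xml")                      // short name for com.databricks.spark.xml
  .option("rootTag", "books")
  .option("rowTag", "book")
  .save("/tmp/books-out")
```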

### Parsing Nested XML

Although primarily used to convert (portions of) large XML documents into a `DataFrame`, from version 0.8.0 onwards,
`spark-xml` can also parse XML in a string-valued column of an existing DataFrame with `from_xml`, adding the parsed
results as a new struct column.

```scala
import com.databricks.spark.xml.functions.from_xml
import com.databricks.spark.xml.schema_of_xml
import spark.implicits._
val df = ... // DataFrame with XML in column 'payload'
val payloadSchema = schema_of_xml(df.select("payload").as[String])
val parsed = df.withColumn("parsed", from_xml($"payload", payloadSchema))
```

- This can also convert arrays of XML strings to arrays of parsed structs; use `schema_of_xml_array` to infer the schema in that case.
- `com.databricks.spark.xml.from_xml_string` is an alternative that operates on a `String` directly instead of a column, for use in UDFs (see the sketch below).
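A rough sketch of these variants, assuming a DataFrame `df` with a string column `payload` and an array-of-strings column `payloads` (the column names are hypothetical):

```scala
import com.databricks.spark.xml.functions.from_xml
import com.databricks.spark.xml.{from_xml_string, schema_of_xml, schema_of_xml_array}
import org.apache.spark.sql.functions.udf
import spark.implicits._

// Arrays of XML strings: infer an ArrayType(StructType) schema and parse with from_xml.
val arraySchema = schema_of_xml_array(df.select("payloads").as[Array[String]])
val parsedArrays = df.withColumn("parsedArray", from_xml($"payloads", arraySchema))

// Single XML strings in a UDF: from_xml_string returns a Row matching the given schema.
val payloadSchema = schema_of_xml(df.select("payload").as[String])
val parseXml = udf((xml: String) => from_xml_string(xml, payloadSchema), payloadSchema)
val parsedByUdf = df.withColumn("parsed", parseXml($"payload"))
```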

## Structure Conversion

Due to the structure differences between `DataFrame` and XML, there are some conversion rules from XML data to `DataFrame` and from `DataFrame` to XML data. Note that handling attributes can be disabled with the option `excludeAttribute`.
6 changes: 5 additions & 1 deletion build.sbt
@@ -106,7 +106,11 @@ val ignoredABIProblems = {
exclude[DirectMissingMethodProblem](
"com.databricks.spark.xml.util.CompressionCodecs.getCodecClass"),
exclude[IncompatibleMethTypeProblem](
"com.databricks.spark.xml.parsers.StaxXmlGenerator.apply")
"com.databricks.spark.xml.parsers.StaxXmlGenerator.apply"),
ProblemFilters.exclude[DirectMissingMethodProblem](
"com.databricks.spark.xml.parsers.StaxXmlParser.com$databricks$spark$xml$parsers$StaxXmlParser$$failedRecord$default$3$1"),
ProblemFilters.exclude[DirectMissingMethodProblem](
"com.databricks.spark.xml.parsers.StaxXmlParser.com$databricks$spark$xml$parsers$StaxXmlParser$$failedRecord$default$2$1")
)
}

2 changes: 1 addition & 1 deletion scalastyle-config.xml
@@ -65,7 +65,7 @@ This file is divided into 3 sections:
</check>

<check level="error" class="org.scalastyle.scalariform.ObjectNamesChecker" enabled="true">
<parameters><parameter name="regex"><![CDATA[[A-Z][A-Za-z]*]]></parameter></parameters>
<parameters><parameter name="regex"><![CDATA[^[A-Za-z]+$]]></parameter></parameters>
</check>

<check level="error" class="org.scalastyle.scalariform.PackageObjectNamesChecker" enabled="true">
60 changes: 60 additions & 0 deletions src/main/scala/com/databricks/spark/xml/XmlDataToCatalyst.scala
@@ -0,0 +1,60 @@
/*
* Copyright 2019 Databricks
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.databricks.spark.xml

import org.apache.spark.sql.catalyst.CatalystTypeConverters
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression, UnaryExpression}
import org.apache.spark.sql.catalyst.util.GenericArrayData
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String

import com.databricks.spark.xml.parsers.StaxXmlParser

case class XmlDataToCatalyst(
child: Expression,
schema: DataType,
options: XmlOptions)
extends UnaryExpression with CodegenFallback with ExpectsInputTypes {

override lazy val dataType: DataType = schema

@transient
lazy val rowSchema: StructType = schema match {
case st: StructType => st
case ArrayType(st: StructType, _) => st
}

override def nullSafeEval(xml: Any): Any = xml match {
case string: UTF8String =>
CatalystTypeConverters.convertToCatalyst(
StaxXmlParser.parseColumn(string.toString, rowSchema, options))
case string: String =>
StaxXmlParser.parseColumn(string.toString, rowSchema, options)
case arr: GenericArrayData =>
CatalystTypeConverters.convertToCatalyst(
arr.array.map(s => StaxXmlParser.parseColumn(s.toString, rowSchema, options)))
case arr: Array[_] =>
arr.map(s => StaxXmlParser.parseColumn(s.toString, rowSchema, options))
case _ => null
}

override def inputTypes: Seq[DataType] = schema match {
case _: StructType => Seq(StringType)
case ArrayType(_: StructType, _) => Seq(ArrayType(StringType))
}
}
43 changes: 43 additions & 0 deletions src/main/scala/com/databricks/spark/xml/functions.scala
@@ -0,0 +1,43 @@
/*
* Copyright 2019 Databricks
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.databricks.spark.xml

import org.apache.spark.annotation.Experimental
import org.apache.spark.sql.Column
import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
import org.apache.spark.sql.types.DataType

/**
* Support functions for working with XML columns directly.
*/
object functions {

/**
* Parses a column containing an XML string into a `StructType` with the specified schema.
*
* @param e a string column containing XML data
* @param schema the schema to use when parsing the XML string. Must be a StructType if
* column is string-valued, or ArrayType[StructType] if column is an array of strings
* @param options key-value pairs that correspond to those supported by [[XmlOptions]]
*/
@Experimental
def from_xml(e: Column, schema: DataType, options: Map[String, String] = Map.empty): Column = {
val expr = CatalystSqlParser.parseExpression(e.toString())
new Column(XmlDataToCatalyst(expr, schema, XmlOptions(options)))
}

}
46 changes: 42 additions & 4 deletions src/main/scala/com/databricks/spark/xml/package.scala
@@ -15,12 +15,14 @@
*/
package com.databricks.spark

import scala.collection.Map

import org.apache.hadoop.io.compress.CompressionCodec

import org.apache.spark.annotation.Experimental
import org.apache.spark.sql._
import com.databricks.spark.xml.util.XmlFile
import org.apache.spark.sql.types.{ArrayType, StructType}

import com.databricks.spark.xml.parsers.StaxXmlParser
import com.databricks.spark.xml.util.{InferSchema, XmlFile}

package object xml {
/**
@@ -64,7 +66,7 @@ package object xml {
implicit class XmlSchemaRDD(dataFrame: DataFrame) {
@deprecated("Use write.format(\"xml\") or write.xml", "0.4.0")
def saveAsXmlFile(
path: String, parameters: Map[String, String] = Map(),
path: String, parameters: scala.collection.Map[String, String] = Map(),
compressionCodec: Class[_ <: CompressionCodec] = null): Unit = {
val mutableParams = collection.mutable.Map(parameters.toSeq: _*)
val safeCodec = mutableParams.get("codec")
@@ -111,4 +113,40 @@ package object xml {
// Namely, roundtrip in writing and reading can end up in different schema structure.
def xml: String => Unit = writer.format("com.databricks.spark.xml").save
}

/**
* Infers the schema of XML documents as strings.
*
* @param ds Dataset of XML strings
* @param options additional XML parsing options
* @return inferred schema for XML
*/
@Experimental
def schema_of_xml(ds: Dataset[String], options: Map[String, String] = Map.empty): StructType =
InferSchema.infer(ds.rdd, XmlOptions(options))

/**
* Infers the schema of XML documents when inputs are arrays of strings, each an XML doc.
*
* @param ds Dataset of XML strings
* @param options additional XML parsing options
* @return inferred schema for XML. Will be an ArrayType[StructType].
*/
@Experimental
def schema_of_xml_array(ds: Dataset[Array[String]],
options: Map[String, String] = Map.empty): ArrayType =
ArrayType(InferSchema.infer(ds.rdd.flatMap(a => a), XmlOptions(options)))

/**
* @param xml XML document to parse, as string
* @param schema the schema to use when parsing the XML string
* @param options key-value pairs that correspond to those supported by [[XmlOptions]]
* @return [[Row]] representing the parsed XML structure
*/
@Experimental
def from_xml_string(xml: String, schema: StructType,
options: Map[String, String] = Map.empty): Row = {
StaxXmlParser.parseColumn(xml, schema, XmlOptions(options))
}

}
110 changes: 52 additions & 58 deletions src/main/scala/com/databricks/spark/xml/parsers/StaxXmlParser.scala
@@ -17,7 +17,7 @@ package com.databricks.spark.xml.parsers

import java.io.StringReader

import javax.xml.stream.{EventFilter, XMLEventReader, XMLInputFactory, XMLStreamConstants}
import javax.xml.stream.{XMLEventReader, XMLInputFactory}
import javax.xml.stream.events.{Attribute, Characters, EndElement, StartElement, XMLEvent}
import javax.xml.transform.stream.StreamSource

@@ -46,52 +46,7 @@ private[xml] object StaxXmlParser extends Serializable {
schema: StructType,
options: XmlOptions): RDD[Row] = {

// The logic below is borrowed from Apache Spark's FailureSafeParser.
val corruptFieldIndex = Try(schema.fieldIndex(options.columnNameOfCorruptRecord)).toOption
val actualSchema = StructType(schema.filterNot(_.name == options.columnNameOfCorruptRecord))
val resultRow = new Array[Any](schema.length)
val toResultRow: (Option[Row], String) => Row = (row, badRecord) => {
var i = 0
while (i < actualSchema.length) {
val from = actualSchema(i)
resultRow(schema.fieldIndex(from.name)) = row.map(_.get(i)).orNull
i += 1
}
corruptFieldIndex.foreach(index => resultRow(index) = badRecord)
Row.fromSeq(resultRow)
}

def failedRecord(
record: String, cause: Throwable = null, partialResult: Option[Row] = None): Option[Row] = {
// create a row even if no corrupt record column is present
options.parseMode match {
case FailFastMode =>
throw new IllegalArgumentException(
s"Malformed line in FAILFAST mode: ${record.replaceAll("\n", "")}", cause)
case DropMalformedMode =>
val reason = if (cause != null) s"Reason: ${cause.getMessage}" else ""
logger.warn(s"Dropping malformed line: ${record.replaceAll("\n", "")}. $reason")
logger.debug("Malformed line cause:", cause)
None
case PermissiveMode =>
logger.debug("Malformed line cause:", cause)
Some(toResultRow(partialResult, record))
}
}

xml.mapPartitions { iter =>
val factory = XMLInputFactory.newInstance()
factory.setProperty(XMLInputFactory.IS_NAMESPACE_AWARE, false)
factory.setProperty(XMLInputFactory.IS_COALESCING, true)
val filter = new EventFilter {
override def accept(event: XMLEvent): Boolean =
// Ignore comments and processing instructions
event.getEventType match {
case XMLStreamConstants.COMMENT | XMLStreamConstants.PROCESSING_INSTRUCTION => false
case _ => true
}
}

val xsdSchema = Option(options.rowValidationXSDPath).map(ValidatorUtil.getSchema)

iter.flatMap { xml =>
@@ -100,27 +55,66 @@ private[xml] object StaxXmlParser extends Serializable {
schema.newValidator().validate(new StreamSource(new StringReader(xml)))
}

// It does not have to skip for white space, since `XmlInputFormat`
// always finds the root tag without a heading space.
val eventReader = factory.createXMLEventReader(new StringReader(xml))
val parser = factory.createFilteredReader(eventReader, filter)

val rootEvent =
StaxXmlParserUtils.skipUntil(parser, XMLStreamConstants.START_ELEMENT)
val rootAttributes =
rootEvent.asStartElement.getAttributes.asScala.map(_.asInstanceOf[Attribute]).toArray
val parser = StaxXmlParserUtils.filteredReader(xml)
val rootAttributes = StaxXmlParserUtils.gatherRootAttributes(parser)
Some(convertObject(parser, schema, options, rootAttributes))
.orElse(failedRecord(xml))
.orElse(failedRecord(xml, options, schema))
} catch {
case e: PartialResultException =>
failedRecord(xml, e.cause, Some(e.partialResult))
failedRecord(xml, options, schema, e.cause, Some(e.partialResult))
case NonFatal(e) =>
failedRecord(xml, e)
failedRecord(xml, options, schema, e)
}
}
}
}

def parseColumn(xml: String, schema: StructType, options: XmlOptions): Row = {
try {
val parser = StaxXmlParserUtils.filteredReader(xml)
val rootAttributes = StaxXmlParserUtils.gatherRootAttributes(parser)
convertObject(parser, schema, options, rootAttributes)
} catch {
case e: PartialResultException =>
// Null only if mode is DropMalformedMode
failedRecord(xml, options, schema, e.cause, Some(e.partialResult)).orNull
case NonFatal(e) =>
failedRecord(xml, options, schema, e).orNull
}
}

private def failedRecord(record: String,
options: XmlOptions,
schema: StructType,
cause: Throwable = null,
partialResult: Option[Row] = None): Option[Row] = {
// create a row even if no corrupt record column is present
options.parseMode match {
case FailFastMode =>
throw new IllegalArgumentException(
s"Malformed line in FAILFAST mode: ${record.replaceAll("\n", "")}", cause)
case DropMalformedMode =>
val reason = if (cause != null) s"Reason: ${cause.getMessage}" else ""
logger.warn(s"Dropping malformed line: ${record.replaceAll("\n", "")}. $reason")
logger.debug("Malformed line cause:", cause)
None
case PermissiveMode =>
logger.debug("Malformed line cause:", cause)
// The logic below is borrowed from Apache Spark's FailureSafeParser.
val corruptFieldIndex = Try(schema.fieldIndex(options.columnNameOfCorruptRecord)).toOption
val actualSchema = StructType(schema.filterNot(_.name == options.columnNameOfCorruptRecord))
val resultRow = new Array[Any](schema.length)
var i = 0
while (i < actualSchema.length) {
val from = actualSchema(i)
resultRow(schema.fieldIndex(from.name)) = partialResult.map(_.get(i)).orNull
i += 1
}
corruptFieldIndex.foreach(index => resultRow(index) = record)
Some(Row.fromSeq(resultRow))
}
}

/**
* Parse the current token (and related children) according to a desired schema
*/
