Implementation of from_xml function #334

Closed
wants to merge 9 commits into from
56 changes: 56 additions & 0 deletions src/main/scala/com/databricks/spark/xml/XmlDataToCatalyst.scala
@@ -0,0 +1,56 @@
/*
* Copyright 2014 Databricks
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.databricks.spark.xml

import com.databricks.spark.xml.parsers.StaxXmlParser
import org.apache.spark.sql.catalyst.CatalystTypeConverters
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression, UnaryExpression}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String

case class XmlDataToCatalyst(child: Expression,
                             schema: DataType,
                             options: XmlOptions)
Member:

nit:

case class XmlDataToCatalyst(
    child: Expression,
    schema: DataType,
    options: XmlOptions)
  extends ...

(per the guide I pointed out above)

Author:

@HyukjinKwon I'm having a challenging time getting IntelliJ to enforce this. I'm going to try swinging back around to it after I commit the more sizable change requests.

Member:

Yeah, it is a bit of a grind to manually fix the style, I admit. This suggestion wouldn't block the PR, but it would be good to follow the style guide (https://github.com/databricks/scala-style-guide) since Databricks repositories comply with it. (BTW, I'm not a Databricks guy, so I'd suggest following this guide to be safe.)

  extends UnaryExpression with CodegenFallback with ExpectsInputTypes {

  override lazy val dataType: DataType = schema

  override def checkInputDataTypes(): TypeCheckResult = schema match {
    case _: StructType | ArrayType(_: StructType, _) =>
      super.checkInputDataTypes()
    case _ => TypeCheckResult.TypeCheckFailure(
      s"Input schema ${schema.simpleString} must be a struct or an array of structs.")
  }

  @transient
  lazy val rowSchema: StructType = schema match {
    case st: StructType => st
    case ArrayType(st: StructType, _) => st
Member:

I think the XML implementation does not support an array of structs, so this case can be removed.

Author:

Would you like me to add a case _ => throw SomeExceptionClass('string') statement? Otherwise, rowSchema must be of type DataType.

Member:

Ah, I think we can change DataType to StructType.

  }

  override def nullSafeEval(xml: Any): Any = xml match {
    case string: UTF8String =>
      CatalystTypeConverters.convertToCatalyst(
        StaxXmlParser.parseColumn(string.toString, rowSchema, options))
    case string: String =>
      StaxXmlParser.parseColumn(string.toString, rowSchema, options)
    case _ => null
  }

  override def inputTypes: Seq[DataType] = StringType :: Nil
}
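For reference, here is a minimal sketch of how the expression could look if the two review suggestions above (style-guide indentation and typing schema as StructType) were applied. This is an illustration against the Spark 2.x Catalyst API, not code from this PR's diff.

import com.databricks.spark.xml.XmlOptions
import com.databricks.spark.xml.parsers.StaxXmlParser

import org.apache.spark.sql.catalyst.CatalystTypeConverters
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression, UnaryExpression}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String

// Hypothetical variant, not the code in this PR: with `schema` typed as StructType,
// the rowSchema pattern match and the array-of-structs case become unnecessary.
case class XmlDataToCatalyst(
    child: Expression,
    schema: StructType,
    options: XmlOptions)
  extends UnaryExpression with CodegenFallback with ExpectsInputTypes {

  override lazy val dataType: DataType = schema

  // The input column must be a string; ExpectsInputTypes performs the type check.
  override def inputTypes: Seq[DataType] = StringType :: Nil

  override def nullSafeEval(xml: Any): Any = xml match {
    case string: UTF8String =>
      CatalystTypeConverters.convertToCatalyst(
        StaxXmlParser.parseColumn(string.toString, schema, options))
    case _ => null
  }
}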
39 changes: 37 additions & 2 deletions src/main/scala/com/databricks/spark/xml/package.scala
@@ -17,12 +17,47 @@ package com.databricks.spark

import scala.collection.Map

import com.databricks.spark.xml.util.XmlFile
import org.apache.hadoop.io.compress.CompressionCodec

import org.apache.spark.annotation.Experimental
import org.apache.spark.sql._
import com.databricks.spark.xml.util.XmlFile
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
import org.apache.spark.sql.types.StructType


package object xml {

  private def withExpr(expr: Expression): Column = new Column(expr)

  /**
   * Parses a column containing an XML string into a `StructType` with the specified schema.
   *
   * @param e a string column containing XML data
   * @param schema the schema to use when parsing the XML string
   */
  @Experimental
  implicit def from_xml(e: Column, schema: StructType): Column = {
    from_xml(e, schema, Map.empty[String, String])
  }

  /**
   * Parses a column containing an XML string into a `StructType` with the specified schema.
   *
   * @param e a string column containing XML data
   * @param schema the schema to use when parsing the XML string
   * @param options key-value pairs that correspond to those supported by [[XmlOptions]]
   */
  @Experimental
  implicit def from_xml(e: Column, schema: StructType, options: Map[String, String]): Column =
    withExpr {
      val map: Map[String, String] = options + ("isFunction" -> "true")
      val expr: Expression = CatalystSqlParser.parseExpression(e.toString())
      XmlDataToCatalyst(expr, schema, XmlOptions(map.toMap))
    }

  /**
   * Adds a method, `xmlFile`, to [[SQLContext]] that allows reading XML data.
   */
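As a usage illustration only (not part of the PR's diff): with this change on the classpath, the implicit from_xml could be used roughly as follows. The SparkSession setup, the DataFrame, and the column names here are made up for the example.

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.types.{StringType, StructType}

import com.databricks.spark.xml._  // brings the implicit from_xml into scope

val spark = SparkSession.builder().master("local[*]").appName("from_xml example").getOrCreate()
import spark.implicits._

val xmlSchema = new StructType().add("pid", StringType).add("name", StringType)

// One XML record per row in a string column named "payload" (example data).
val df = Seq("<parent><pid>14ft3</pid><name>dave guy</name></parent>").toDF("payload")

val decoded = df.withColumn("decoded", from_xml(df.col("payload"), xmlSchema))
decoded.select("decoded.pid", "decoded.name").show()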
64 changes: 45 additions & 19 deletions src/main/scala/com/databricks/spark/xml/parsers/StaxXmlParser.scala
@@ -16,31 +16,54 @@
package com.databricks.spark.xml.parsers

import java.io.StringReader

import javax.xml.stream.events.{Attribute, XMLEvent}
import javax.xml.stream.events._
import javax.xml.stream._
import javax.xml.stream.{XMLEventReader, _}

import scala.collection.mutable.ArrayBuffer
import scala.collection.JavaConverters._
import scala.util.control.NonFatal
import scala.util.Try

import org.slf4j.LoggerFactory

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.Row
import org.apache.spark.sql.types._
import com.databricks.spark.xml.util.TypeCast._
import com.databricks.spark.xml.XmlOptions
import com.databricks.spark.xml.util._
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, GenericRowWithSchema}
import org.apache.spark.unsafe.types.UTF8String

/**
 * Wraps parser to iteration process.
 */
private[xml] object StaxXmlParser extends Serializable {
  private val logger = LoggerFactory.getLogger(StaxXmlParser.getClass)

  private def filteredReader(xml: String, factory: XMLInputFactory): XMLEventReader = {
    val filter: EventFilter = new EventFilter {
      override def accept(event: XMLEvent): Boolean =
        // Ignore comments. This library does not treat comments.
        event.getEventType != XMLStreamConstants.COMMENT
    }

    // It does not have to skip for white space, since `XmlInputFormat`
    // always finds the root tag without a heading space.
    val eventReader = factory.createXMLEventReader(new StringReader(xml))

    factory.createFilteredReader(eventReader, filter)
  }

  private def gatherRootAttributes(xmlEventReader: XMLEventReader): Array[Attribute] = {
    val rootEvent =
      StaxXmlParserUtils.skipUntil(xmlEventReader, XMLStreamConstants.START_ELEMENT)

    rootEvent.asStartElement.getAttributes.asScala.map(_.asInstanceOf[Attribute]).toArray
  }

  def parse(
      xml: RDD[String],
      schema: StructType,
@@ -80,25 +103,14 @@ private[xml] object StaxXmlParser extends Serializable {
}

    xml.mapPartitions { iter =>
      val factory = XMLInputFactory.newInstance()
      factory.setProperty(XMLInputFactory.IS_NAMESPACE_AWARE, false)
      factory.setProperty(XMLInputFactory.IS_COALESCING, true)
      val filter = new EventFilter {
        override def accept(event: XMLEvent): Boolean =
          // Ignore comments. This library does not treat comments.
          event.getEventType != XMLStreamConstants.COMMENT
      }
      val partitionFactory: XMLInputFactory = XMLInputFactory.newInstance()
      partitionFactory.setProperty(XMLInputFactory.IS_NAMESPACE_AWARE, false)
      partitionFactory.setProperty(XMLInputFactory.IS_COALESCING, true)

      iter.flatMap { xml =>
        // It does not have to skip for white space, since `XmlInputFormat`
        // always finds the root tag without a heading space.
        val eventReader = factory.createXMLEventReader(new StringReader(xml))
        val parser = factory.createFilteredReader(eventReader, filter)
        val parser = filteredReader(xml, partitionFactory)
        try {
          val rootEvent =
            StaxXmlParserUtils.skipUntil(parser, XMLStreamConstants.START_ELEMENT)
          val rootAttributes =
            rootEvent.asStartElement.getAttributes.asScala.map(_.asInstanceOf[Attribute]).toArray
          val rootAttributes = gatherRootAttributes(parser)
          Some(convertObject(parser, schema, options, rootAttributes))
            .orElse(failedRecord(xml))
        } catch {
@@ -111,6 +123,20 @@
}
}

  def parseColumn(xml: String,
                  schema: StructType,
                  options: XmlOptions): Row = {
Member (@HyukjinKwon, Jan 2, 2019):

nit:

def parseColumn(
    xml: String,
    ...): Row = {

    val factory: XMLInputFactory = XMLInputFactory.newInstance()
Member:

Let's make a private function for:

    val factory: XMLInputFactory = XMLInputFactory.newInstance()
    factory.setProperty(XMLInputFactory.IS_NAMESPACE_AWARE, false)
    factory.setProperty(XMLInputFactory.IS_COALESCING, true)

    val filter = new EventFilter {
      override def accept(event: XMLEvent): Boolean =
      // Ignore comments. This library does not treat comments.
        event.getEventType != XMLStreamConstants.COMMENT
    }

and deduplicate the logic with parse above.

Also, I would make a function to deduplicate

    val eventReader = factory.createXMLEventReader(new StringReader(xml))
    val parser = factory.createFilteredReader(eventReader, filter)

    val rootEvent =
      StaxXmlParserUtils.skipUntil(parser, XMLStreamConstants.START_ELEMENT)
    val rootAttributes =
      rootEvent.asStartElement.getAttributes.asScala.map(_.asInstanceOf[Attribute]).toArray
    convertObject(parser, schema, options, rootAttributes)

as well.

Member (@HyukjinKwon, Feb 15, 2019):

Currently, for each input row:

  1. an XML string comes in
  2. an XMLInputFactory instance is created
  3. the XML is parsed to a Row

I think this can be avoided by:

class XmlDataToCatalyst(..) {
  @transient
  lazy val factory: XMLInputFactory = XMLInputFactory.newInstance()

  def nullSafeEval(...) = {
    factory.createXMLEventReader(new StringReader(xml))
  }
}

In this way, we will create one factory first, and then reuse it for every string input. (BTW, it's quite critical to fix this.).

    factory.setProperty(XMLInputFactory.IS_NAMESPACE_AWARE, false)
    factory.setProperty(XMLInputFactory.IS_COALESCING, true)

    val parser = filteredReader(xml, factory)
    val rootAttributes = gatherRootAttributes(parser)

    convertObject(parser, schema, options, rootAttributes)
  }


  /**
   * Parse the current token (and related children) according to a desired schema
   */
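To sketch the factory-reuse idea from the review comment above: a hypothetical parseColumn overload that accepts a pre-built XMLInputFactory, so XmlDataToCatalyst could hold the factory in a @transient lazy val and hand it in for every row instead of rebuilding it per input string. The extra parameter and this overload are assumptions for illustration, not part of the PR; filteredReader, gatherRootAttributes, and convertObject are the existing members of StaxXmlParser shown above.

  // Hypothetical overload inside StaxXmlParser; not part of this PR's diff.
  def parseColumn(
      xml: String,
      schema: StructType,
      options: XmlOptions,
      factory: XMLInputFactory): Row = {
    // The caller owns the factory, so no XMLInputFactory.newInstance() call per input string.
    val parser = filteredReader(xml, factory)
    val rootAttributes = gatherRootAttributes(parser)
    convertObject(parser, schema, options, rootAttributes)
  }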
@@ -1,6 +1,8 @@
package com.databricks.spark.xml.parsers

import javax.xml.stream.XMLEventReader
import java.io.ByteArrayInputStream

import javax.xml.stream.{EventFilter, XMLEventReader, XMLInputFactory, XMLStreamConstants}
import javax.xml.stream.events._

import scala.annotation.tailrec
27 changes: 27 additions & 0 deletions src/test/scala/com/databricks/spark/xml/XmlSuite.scala
@@ -32,6 +32,9 @@ import com.databricks.spark.xml.util._
import org.apache.spark.sql.types._
import org.apache.spark.sql.{Row, SaveMode, SparkSession}
import org.apache.spark.SparkException
import org.apache.spark.sql.{DataFrame, Row, SQLContext, SaveMode}
import org.apache.spark.{SparkConf, SparkContext, SparkException}
import org.apache.spark.sql.functions._

final class XmlSuite extends FunSuite with BeforeAndAfterAll {

@@ -1019,4 +1022,28 @@ final class XmlSuite extends FunSuite with BeforeAndAfterAll {
<integer_value int="Ten">Ten</integer_value>.toString))
}


test("from_xml roundtrip happy path") {

val xmlData =
"""
| <parent><pid>14ft3</pid>
| <name>dave guy</name>
| </parent>
""".stripMargin

val xmlSchema: StructType = new StructType().add("pid", StringType).add("name", StringType)
val rowSchema: StructType = StructType(
Seq(StructField("number", IntegerType, true), StructField("payload", StringType, true)))
val expectedSchema: StructType = rowSchema.add("decoded", xmlSchema)

val df: DataFrame = spark.createDataFrame(
spark.sparkContext.parallelize(List(Row(8, xmlData))), rowSchema)

val result: DataFrame = df.withColumn("decoded",
from_xml(df.col("payload"), xmlSchema))

assert(expectedSchema == result.schema)
assert(result.where(col("decoded").isNotNull).count() > 0)
Member:

I think the actual output should also be asserted and checked, for completeness.

  }
}
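Following the last review comment, a hedged sketch of an additional value-level check that could be appended to the test above. The expected strings are taken from the xmlData literal in the test; exact whitespace handling by the parser is an assumption here.

    // Hypothetical extra assertion: verify the decoded struct's contents,
    // not only the result schema and the non-null count.
    val decodedRow = result.select("decoded.pid", "decoded.name").head()
    assert(decodedRow === Row("14ft3", "dave guy"))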