[SPARK-23690][ML] Add handleInvalid to VectorAssembler
## What changes were proposed in this pull request?

Introduce a `handleInvalid` parameter in `VectorAssembler` that accepts the options `"keep"`, `"skip"`, and `"error"`. With `"error"`, the transformer throws an error on seeing a row that contains a `null`; with `"skip"`, it filters out all such rows; and with `"keep"`, it substitutes the relevant number of `NaN` values for each null. For `"keep"`, the transformer works out how many `NaN` values each column requires and throws an error when no such number can be determined.
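
For illustration, a minimal usage sketch of the new parameter (the column names, values, and the running `spark` session are assumptions for this example, not part of the patch):

```scala
import org.apache.spark.ml.feature.VectorAssembler
import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{DoubleType, StructField, StructType}

// A toy DataFrame with a null in the "hour" column.
val schema = StructType(Seq(
  StructField("hour", DoubleType, nullable = true),
  StructField("clicked", DoubleType, nullable = false)))
val df = spark.createDataFrame(
  spark.sparkContext.parallelize(Seq(Row(18.0, 1.0), Row(null, 0.0))),
  schema)

val assembler = new VectorAssembler()
  .setInputCols(Array("hour", "clicked"))
  .setOutputCol("features")
  .setHandleInvalid("skip") // "error" (the default) throws; "keep" emits NaN instead

// With "skip" the null row is dropped, so only one row survives.
assembler.transform(df).show()
```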

## How was this patch tested?

Unit tests are added to check the behavior of `assemble` on specific rows, and the transformer is called on `DataFrame`s of different configurations to exercise the relevant corner cases.
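
In that spirit, a rough sketch of the kind of row-level check described (values are illustrative, `intercept` assumes ScalaTest, and this is not the PR's actual test code; since `assemble` is `private[feature]`, such a test must live in the `org.apache.spark.ml.feature` package):

```scala
import org.apache.spark.SparkException
import org.apache.spark.ml.linalg.Vectors

// One scalar column followed by one vector column of length 2.
val lengths = Array(1, 2)

// A fully valid row assembles into a single feature vector.
val dense = VectorAssembler.assemble(lengths, keepInvalid = true)(1.0, Vectors.dense(2.0, 3.0))
assert(dense == Vectors.dense(1.0, 2.0, 3.0))

// With keepInvalid = true, a null scalar cell becomes a single NaN.
val kept = VectorAssembler.assemble(lengths, keepInvalid = true)(null, Vectors.dense(2.0, 3.0))
assert(kept(0).isNaN)

// With keepInvalid = false, the same row throws.
intercept[SparkException] {
  VectorAssembler.assemble(lengths, keepInvalid = false)(null, Vectors.dense(2.0, 3.0))
}
```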

Author: Yogesh Garg <yogesh(dot)garg()databricks(dot)com>
Author: Bago Amirbekian <[email protected]>
Author: Yogesh Garg <[email protected]>

Closes apache#20829 from yogeshg/rformula_handleinvalid.
Yogesh Garg authored and Robert Kruszewski committed Apr 4, 2018
1 parent 83c1da3 commit f9927a5
Showing 3 changed files with 284 additions and 47 deletions.
mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
@@ -234,7 +234,7 @@ class StringIndexerModel (
val metadata = NominalAttribute.defaultAttr
.withName($(outputCol)).withValues(filteredLabels).toMetadata()
// If we are skipping invalid records, filter them out.
val (filteredDataset, keepInvalid) = getHandleInvalid match {
val (filteredDataset, keepInvalid) = $(handleInvalid) match {
case StringIndexer.SKIP_INVALID =>
val filterer = udf { label: String =>
labelToIndex.contains(label)
198 changes: 163 additions & 35 deletions mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala
@@ -17,14 +17,17 @@

package org.apache.spark.ml.feature

import scala.collection.mutable.ArrayBuilder
import java.util.NoSuchElementException

import scala.collection.mutable
import scala.language.existentials

import org.apache.spark.SparkException
import org.apache.spark.annotation.Since
import org.apache.spark.ml.Transformer
import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, NumericAttribute, UnresolvedAttribute}
import org.apache.spark.ml.linalg.{Vector, Vectors, VectorUDT}
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.param.{Param, ParamMap, ParamValidators}
import org.apache.spark.ml.param.shared._
import org.apache.spark.ml.util._
import org.apache.spark.sql.{DataFrame, Dataset, Row}
@@ -33,10 +36,14 @@ import org.apache.spark.sql.types._

/**
* A feature transformer that merges multiple columns into a vector column.
*
* This requires one pass over the entire dataset. When column lengths must be inferred from the
* data, an additional call to the `first` Dataset method is needed; see the 'handleInvalid' parameter.
*/
@Since("1.4.0")
class VectorAssembler @Since("1.4.0") (@Since("1.4.0") override val uid: String)
extends Transformer with HasInputCols with HasOutputCol with DefaultParamsWritable {
extends Transformer with HasInputCols with HasOutputCol with HasHandleInvalid
with DefaultParamsWritable {

@Since("1.4.0")
def this() = this(Identifiable.randomUID("vecAssembler"))
@@ -49,32 +56,63 @@ class VectorAssembler @Since("1.4.0") (@Since("1.4.0") override val uid: String)
@Since("1.4.0")
def setOutputCol(value: String): this.type = set(outputCol, value)

/** @group setParam */
@Since("2.4.0")
def setHandleInvalid(value: String): this.type = set(handleInvalid, value)

/**
* Param for how to handle invalid data (NULL values). Options are 'skip' (filter out rows with
* invalid data), 'error' (throw an error), or 'keep' (return the relevant number of NaN values in
* the output). Column lengths are taken from the size of the ML Attribute Group, which can be set
* using `VectorSizeHint` in a pipeline before `VectorAssembler`. Column lengths can also be
* inferred from the first row of the data, but only in the 'error' or 'skip' cases, where doing
* so is safe.
* Default: "error"
* @group param
*/
@Since("2.4.0")
override val handleInvalid: Param[String] = new Param[String](this, "handleInvalid",
"""Param for how to handle invalid data (NULL values). Options are 'skip' (filter out rows with
|invalid data), 'error' (throw an error), or 'keep' (return relevant number of NaN in the
|output). Column lengths are taken from the size of ML Attribute Group, which can be set using
|`VectorSizeHint` in a pipeline before `VectorAssembler`. Column lengths can also be inferred
|from first rows of the data since it is safe to do so but only in case of 'error' or 'skip'.
|""".stripMargin.replaceAll("\n", " "),
ParamValidators.inArray(VectorAssembler.supportedHandleInvalids))

setDefault(handleInvalid, VectorAssembler.ERROR_INVALID)

@Since("2.0.0")
override def transform(dataset: Dataset[_]): DataFrame = {
transformSchema(dataset.schema, logging = true)
// Schema transformation.
val schema = dataset.schema
lazy val first = dataset.toDF.first()
val attrs = $(inputCols).flatMap { c =>

val vectorCols = $(inputCols).filter { c =>
schema(c).dataType match {
case _: VectorUDT => true
case _ => false
}
}
val vectorColsLengths = VectorAssembler.getLengths(dataset, vectorCols, $(handleInvalid))

val featureAttributesMap = $(inputCols).map { c =>
val field = schema(c)
val index = schema.fieldIndex(c)
field.dataType match {
case DoubleType =>
val attr = Attribute.fromStructField(field)
// If the input column doesn't have ML attribute, assume numeric.
if (attr == UnresolvedAttribute) {
Some(NumericAttribute.defaultAttr.withName(c))
} else {
Some(attr.withName(c))
val attribute = Attribute.fromStructField(field)
attribute match {
case UnresolvedAttribute =>
Seq(NumericAttribute.defaultAttr.withName(c))
case _ =>
Seq(attribute.withName(c))
}
case _: NumericType | BooleanType =>
// If the input column type is a compatible scalar type, assume numeric.
Some(NumericAttribute.defaultAttr.withName(c))
Seq(NumericAttribute.defaultAttr.withName(c))
case _: VectorUDT =>
val group = AttributeGroup.fromStructField(field)
if (group.attributes.isDefined) {
// If attributes are defined, copy them with updated names.
group.attributes.get.zipWithIndex.map { case (attr, i) =>
val attributeGroup = AttributeGroup.fromStructField(field)
if (attributeGroup.attributes.isDefined) {
attributeGroup.attributes.get.zipWithIndex.toSeq.map { case (attr, i) =>
if (attr.name.isDefined) {
// TODO: Define a rigorous naming scheme.
attr.withName(c + "_" + attr.name.get)
Expand All @@ -85,18 +123,25 @@ class VectorAssembler @Since("1.4.0") (@Since("1.4.0") override val uid: String)
} else {
// Otherwise, treat all attributes as numeric. If we cannot get the number of attributes
// from metadata, check the first row.
val numAttrs = group.numAttributes.getOrElse(first.getAs[Vector](index).size)
Array.tabulate(numAttrs)(i => NumericAttribute.defaultAttr.withName(c + "_" + i))
(0 until vectorColsLengths(c)).map { i =>
NumericAttribute.defaultAttr.withName(c + "_" + i)
}
}
case otherType =>
throw new SparkException(s"VectorAssembler does not support the $otherType type")
}
}
val metadata = new AttributeGroup($(outputCol), attrs).toMetadata()

val featureAttributes = featureAttributesMap.flatten[Attribute].toArray
val lengths = featureAttributesMap.map(a => a.length).toArray
val metadata = new AttributeGroup($(outputCol), featureAttributes).toMetadata()
val (filteredDataset, keepInvalid) = $(handleInvalid) match {
case VectorAssembler.SKIP_INVALID => (dataset.na.drop($(inputCols)), false)
case VectorAssembler.KEEP_INVALID => (dataset, true)
case VectorAssembler.ERROR_INVALID => (dataset, false)
}
// Data transformation.
val assembleFunc = udf { r: Row =>
VectorAssembler.assemble(r.toSeq: _*)
VectorAssembler.assemble(lengths, keepInvalid)(r.toSeq: _*)
}.asNondeterministic()
val args = $(inputCols).map { c =>
schema(c).dataType match {
@@ -106,7 +151,7 @@ class VectorAssembler @Since("1.4.0") (@Since("1.4.0") override val uid: String)
}
}

dataset.select(col("*"), assembleFunc(struct(args: _*)).as($(outputCol), metadata))
filteredDataset.select(col("*"), assembleFunc(struct(args: _*)).as($(outputCol), metadata))
}

@Since("1.4.0")
@@ -136,34 +181,117 @@ class VectorAssembler @Since("1.4.0") (@Since("1.4.0") override val uid: String)
@Since("1.6.0")
object VectorAssembler extends DefaultParamsReadable[VectorAssembler] {

private[feature] val SKIP_INVALID: String = "skip"
private[feature] val ERROR_INVALID: String = "error"
private[feature] val KEEP_INVALID: String = "keep"
private[feature] val supportedHandleInvalids: Array[String] =
Array(SKIP_INVALID, ERROR_INVALID, KEEP_INVALID)

/**
* Infers lengths of vector columns from the first row of the dataset.
* @param dataset the dataset
* @param columns names of the vector columns whose lengths need to be inferred
* @return a map from column name to length
*/
private[feature] def getVectorLengthsFromFirstRow(
dataset: Dataset[_],
columns: Seq[String]): Map[String, Int] = {
try {
val firstRow = dataset.toDF().select(columns.map(col): _*).first()
columns.zip(firstRow.toSeq).map {
case (c, x) => c -> x.asInstanceOf[Vector].size
}.toMap
} catch {
case e: NullPointerException => throw new NullPointerException(
s"""Encountered null value while inferring lengths from the first row. Consider using
|VectorSizeHint to add metadata for columns: ${columns.mkString("[", ", ", "]")}. """
.stripMargin.replaceAll("\n", " ") + e.toString)
case e: NoSuchElementException => throw new NoSuchElementException(
s"""Encountered empty dataframe while inferring lengths from the first row. Consider using
|VectorSizeHint to add metadata for columns: ${columns.mkString("[", ", ", "]")}. """
.stripMargin.replaceAll("\n", " ") + e.toString)
}
}

private[feature] def getLengths(
dataset: Dataset[_],
columns: Seq[String],
handleInvalid: String): Map[String, Int] = {
val groupSizes = columns.map { c =>
c -> AttributeGroup.fromStructField(dataset.schema(c)).size
}.toMap
val missingColumns = groupSizes.filter(_._2 == -1).keys.toSeq
val firstSizes = (missingColumns.nonEmpty, handleInvalid) match {
case (true, VectorAssembler.ERROR_INVALID) =>
getVectorLengthsFromFirstRow(dataset, missingColumns)
case (true, VectorAssembler.SKIP_INVALID) =>
getVectorLengthsFromFirstRow(dataset.na.drop(missingColumns), missingColumns)
case (true, VectorAssembler.KEEP_INVALID) => throw new RuntimeException(
s"""Can not infer column lengths with handleInvalid = "keep". Consider using VectorSizeHint
|to add metadata for columns: ${columns.mkString("[", ", ", "]")}."""
.stripMargin.replaceAll("\n", " "))
case (_, _) => Map.empty
}
groupSizes ++ firstSizes
}


@Since("1.6.0")
override def load(path: String): VectorAssembler = super.load(path)

private[feature] def assemble(vv: Any*): Vector = {
val indices = ArrayBuilder.make[Int]
val values = ArrayBuilder.make[Double]
var cur = 0
/**
* Returns a function that has the required information to assemble each row.
* @param lengths an array of lengths of input columns, whose size should be equal to the number
* of cells in the row (vv)
* @param keepInvalid whether to keep nulls (emitting NaN values) or to throw an error when a null is seen
* @return a udf that can be applied on each row
*/
private[feature] def assemble(lengths: Array[Int], keepInvalid: Boolean)(vv: Any*): Vector = {
val indices = mutable.ArrayBuilder.make[Int]
val values = mutable.ArrayBuilder.make[Double]
var featureIndex = 0

var inputColumnIndex = 0
vv.foreach {
case v: Double =>
if (v != 0.0) {
indices += cur
if (v.isNaN && !keepInvalid) {
throw new SparkException(
s"""Encountered NaN while assembling a row with handleInvalid = "error". Consider
|removing NaNs from dataset or using handleInvalid = "keep" or "skip"."""
.stripMargin)
} else if (v != 0.0) {
indices += featureIndex
values += v
}
cur += 1
inputColumnIndex += 1
featureIndex += 1
case vec: Vector =>
vec.foreachActive { case (i, v) =>
if (v != 0.0) {
indices += cur + i
indices += featureIndex + i
values += v
}
}
cur += vec.size
inputColumnIndex += 1
featureIndex += vec.size
case null =>
// TODO: output Double.NaN?
throw new SparkException("Values to assemble cannot be null.")
if (keepInvalid) {
val length: Int = lengths(inputColumnIndex)
Array.range(0, length).foreach { i =>
indices += featureIndex + i
values += Double.NaN
}
inputColumnIndex += 1
featureIndex += length
} else {
throw new SparkException(
s"""Encountered null while assembling a row with handleInvalid = "keep". Consider
|removing nulls from dataset or using handleInvalid = "keep" or "skip"."""
.stripMargin)
}
case o =>
throw new SparkException(s"$o of type ${o.getClass.getName} is not supported.")
}
Vectors.sparse(cur, indices.result(), values.result()).compressed
Vectors.sparse(featureIndex, indices.result(), values.result()).compressed
}
}
(The diff for the third changed file did not load and is not shown.)
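
For the "keep" case, where vector column lengths cannot be inferred from the data, the parameter docs above point to `VectorSizeHint`. A sketch of how the two stages might compose in a pipeline (column names and the size value are hypothetical):

```scala
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.feature.{VectorAssembler, VectorSizeHint}

// VectorSizeHint writes the size of "userFeatures" into the column's ML attribute
// group metadata, which getLengths reads, so "keep" never has to scan the data.
val sizeHint = new VectorSizeHint()
  .setInputCol("userFeatures")
  .setSize(3)
  .setHandleInvalid("keep")

val assembler = new VectorAssembler()
  .setInputCols(Array("hour", "userFeatures"))
  .setOutputCol("features")
  .setHandleInvalid("keep")

// Rows with a null "userFeatures" vector come out with three NaN slots in "features".
val pipeline = new Pipeline().setStages(Array(sizeHint, assembler))
```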