[SPARK-23690][ML] Add handleInvalid to VectorAssembler #20829

Closed · wants to merge 25 commits
@@ -234,7 +234,7 @@ class StringIndexerModel (
val metadata = NominalAttribute.defaultAttr
.withName($(outputCol)).withValues(filteredLabels).toMetadata()
// If we are skipping invalid records, filter them out.
val (filteredDataset, keepInvalid) = getHandleInvalid match {
Contributor

Why does this line need to change?

Contributor Author

Thanks for catching this! I changed it because I was matching on $(handleInvalid) in VectorAssembler, and that seems to be the recommended way of doing this. Should I include this change in the current PR and add a note, or open a separate PR?

Contributor

OK, it doesn't matter; no need for a separate PR, I think. It's just a minor change.

Member

For the record, in general, I would not bother making changes like this. The one exception I do make is IntelliJ style complaints since those can be annoying for developers.

val (filteredDataset, keepInvalid) = $(handleInvalid) match {
case StringIndexer.SKIP_INVALID =>
val filterer = udf { label: String =>
labelToIndex.contains(label)
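As an aside on the exchange above, the `$(param)` shorthand comes from the `Params` trait, where it expands to `getOrDefault(param)`; matching on `$(handleInvalid)` is therefore the idiomatic way to branch on a param's value. A minimal sketch of the pattern (the `MyParams` class and its messages are hypothetical, for illustration only):

```scala
import org.apache.spark.ml.param.{Param, ParamMap, Params}
import org.apache.spark.ml.util.Identifiable

// Hypothetical params holder; the point is the $(...) access pattern.
class MyParams(override val uid: String) extends Params {
  def this() = this(Identifiable.randomUID("myParams"))

  val handleInvalid: Param[String] =
    new Param[String](this, "handleInvalid", "how to handle invalid data")
  setDefault(handleInvalid, "error")

  // $(param) expands to getOrDefault(param), so defaults are honored.
  def describe(): String = $(handleInvalid) match {
    case "skip"  => "rows with invalid values are filtered out"
    case "error" => "invalid values raise an exception"
    case "keep"  => "invalid values pass through as NaNs"
  }

  override def copy(extra: ParamMap): MyParams = defaultCopy(extra)
}
```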
198 changes: 163 additions & 35 deletions mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala
@@ -17,14 +17,17 @@

package org.apache.spark.ml.feature

import scala.collection.mutable.ArrayBuilder
import java.util.NoSuchElementException

import scala.collection.mutable
import scala.language.existentials

import org.apache.spark.SparkException
import org.apache.spark.annotation.Since
import org.apache.spark.ml.Transformer
import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, NumericAttribute, UnresolvedAttribute}
import org.apache.spark.ml.linalg.{Vector, Vectors, VectorUDT}
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.param.{Param, ParamMap, ParamValidators}
import org.apache.spark.ml.param.shared._
import org.apache.spark.ml.util._
import org.apache.spark.sql.{DataFrame, Dataset, Row}
@@ -33,10 +36,14 @@ import org.apache.spark.sql.types._

/**
* A feature transformer that merges multiple columns into a vector column.
*
* This transformer requires one pass over the entire dataset. When column lengths must be
* inferred from the data, it makes an additional call to the `first` Dataset method; see the
* 'handleInvalid' parameter.
*/
@Since("1.4.0")
class VectorAssembler @Since("1.4.0") (@Since("1.4.0") override val uid: String)
extends Transformer with HasInputCols with HasOutputCol with DefaultParamsWritable {
extends Transformer with HasInputCols with HasOutputCol with HasHandleInvalid
with DefaultParamsWritable {

@Since("1.4.0")
def this() = this(Identifiable.randomUID("vecAssembler"))
@@ -49,32 +56,63 @@ class VectorAssembler @Since("1.4.0") (@Since("1.4.0") override val uid: String)
@Since("1.4.0")
def setOutputCol(value: String): this.type = set(outputCol, value)

/** @group setParam */
@Since("2.4.0")
def setHandleInvalid(value: String): this.type = set(handleInvalid, value)

/**
* Param for how to handle invalid data (NULL values). Options are 'skip' (filter out rows with
Member

It would be good to expand this doc to explain the behavior: how various types of invalid values are treated (null, NaN, incorrect Vector length) and how computationally expensive different options can be.

Contributor Author

The behavior of each option is already included, the explanation of column lengths is included here, and the runtime information is in the VectorAssembler class documentation. Thanks for the suggestion, this is super important!

Contributor Author

Also, we only deal with nulls here; NaNs and incorrect-length vectors are passed through transparently. Do we need to test for those?

Member

I'd recommend we deal with NaNs now. This PR is already dealing with some NaN cases: Dataset.na.drop handles NaNs in NumericType columns (but not VectorUDT columns).

I'm OK with postponing incorrect vector lengths until later, or doing that now, since that work is more self-contained.

* invalid data), 'error' (throw an error), or 'keep' (return the relevant number of NaNs in the
* output). Column lengths are taken from the size of the ML Attribute Group, which can be set
* using `VectorSizeHint` in a pipeline before `VectorAssembler`. Column lengths can also be
* inferred from the first row of the data, but only in the 'error' and 'skip' cases, where it
* is safe to do so.
* Default: "error"
* @group param
*/
@Since("2.4.0")
override val handleInvalid: Param[String] = new Param[String](this, "handleInvalid",
"""Param for how to handle invalid data (NULL values). Options are 'skip' (filter out rows with
|invalid data), 'error' (throw an error), or 'keep' (return relevant number of NaN in the
|output). Column lengths are taken from the size of ML Attribute Group, which can be set using
|`VectorSizeHint` in a pipeline before `VectorAssembler`. Column lengths can also be inferred
|from first rows of the data since it is safe to do so but only in case of 'error' or 'skip'.
|""".stripMargin.replaceAll("\n", " "),
ParamValidators.inArray(VectorAssembler.supportedHandleInvalids))

setDefault(handleInvalid, VectorAssembler.ERROR_INVALID)
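To make the new param concrete, here is a hedged usage sketch (the DataFrame `df`, with numeric columns `a` and `b` and a vector column `vec`, is assumed; `VectorSizeHint` supplies the size metadata that 'keep' mode needs):

```scala
import org.apache.spark.ml.feature.{VectorAssembler, VectorSizeHint}

// 'keep' cannot infer vector sizes from a possibly-null first row,
// so declare the size of the vector column up front.
val sized = new VectorSizeHint()
  .setInputCol("vec")
  .setSize(2)
  .transform(df)

val assembled = new VectorAssembler()
  .setInputCols(Array("a", "b", "vec"))
  .setOutputCol("features")
  .setHandleInvalid("keep") // nulls become NaNs in the output vector
  .transform(sized)
```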

@Since("2.0.0")
override def transform(dataset: Dataset[_]): DataFrame = {
transformSchema(dataset.schema, logging = true)
// Schema transformation.
val schema = dataset.schema
lazy val first = dataset.toDF.first()
val attrs = $(inputCols).flatMap { c =>

val vectorCols = $(inputCols).filter { c =>
schema(c).dataType match {
case _: VectorUDT => true
case _ => false
}
}
val vectorColsLengths = VectorAssembler.getLengths(dataset, vectorCols, $(handleInvalid))

val featureAttributesMap = $(inputCols).map { c =>
val field = schema(c)
val index = schema.fieldIndex(c)
field.dataType match {
case DoubleType =>
val attr = Attribute.fromStructField(field)
// If the input column doesn't have ML attribute, assume numeric.
if (attr == UnresolvedAttribute) {
Some(NumericAttribute.defaultAttr.withName(c))
} else {
Some(attr.withName(c))
val attribute = Attribute.fromStructField(field)
attribute match {
case UnresolvedAttribute =>
Seq(NumericAttribute.defaultAttr.withName(c))
case _ =>
Seq(attribute.withName(c))
}
case _: NumericType | BooleanType =>
// If the input column type is a compatible scalar type, assume numeric.
Some(NumericAttribute.defaultAttr.withName(c))
Seq(NumericAttribute.defaultAttr.withName(c))
case _: VectorUDT =>
val group = AttributeGroup.fromStructField(field)
if (group.attributes.isDefined) {
// If attributes are defined, copy them with updated names.
group.attributes.get.zipWithIndex.map { case (attr, i) =>
val attributeGroup = AttributeGroup.fromStructField(field)
Member

For the future, I'd avoid renaming things like this unless it's really unclear or needed (to keep diffs shorter).

if (attributeGroup.attributes.isDefined) {
attributeGroup.attributes.get.zipWithIndex.toSeq.map { case (attr, i) =>
if (attr.name.isDefined) {
// TODO: Define a rigorous naming scheme.
attr.withName(c + "_" + attr.name.get)
Expand All @@ -85,18 +123,25 @@ class VectorAssembler @Since("1.4.0") (@Since("1.4.0") override val uid: String)
} else {
// Otherwise, treat all attributes as numeric. If we cannot get the number of attributes
// from metadata, check the first row.
val numAttrs = group.numAttributes.getOrElse(first.getAs[Vector](index).size)
Array.tabulate(numAttrs)(i => NumericAttribute.defaultAttr.withName(c + "_" + i))
(0 until vectorColsLengths(c)).map { i =>
NumericAttribute.defaultAttr.withName(c + "_" + i)
}
}
case otherType =>
throw new SparkException(s"VectorAssembler does not support the $otherType type")
}
}
val metadata = new AttributeGroup($(outputCol), attrs).toMetadata()

val featureAttributes = featureAttributesMap.flatten[Attribute].toArray
val lengths = featureAttributesMap.map(a => a.length).toArray
val metadata = new AttributeGroup($(outputCol), featureAttributes).toMetadata()
val (filteredDataset, keepInvalid) = $(handleInvalid) match {
case VectorAssembler.SKIP_INVALID => (dataset.na.drop($(inputCols)), false)
case VectorAssembler.KEEP_INVALID => (dataset, true)
case VectorAssembler.ERROR_INVALID => (dataset, false)
}
// Data transformation.
val assembleFunc = udf { r: Row =>
VectorAssembler.assemble(r.toSeq: _*)
VectorAssembler.assemble(lengths, keepInvalid)(r.toSeq: _*)
}.asNondeterministic()
val args = $(inputCols).map { c =>
schema(c).dataType match {
Expand All @@ -106,7 +151,7 @@ class VectorAssembler @Since("1.4.0") (@Since("1.4.0") override val uid: String)
}
}

dataset.select(col("*"), assembleFunc(struct(args: _*)).as($(outputCol), metadata))
filteredDataset.select(col("*"), assembleFunc(struct(args: _*)).as($(outputCol), metadata))
}
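For intuition about the three-way match above, a toy sketch of the observable behavior (assuming a SparkSession `spark`; the expected outcomes follow from the match and the `assemble` helper below):

```scala
import org.apache.spark.ml.feature.VectorAssembler

// Toy frame: the second row has a null in column "b".
val df = spark.createDataFrame(Seq(
  (1.0, Some(2.0)),
  (3.0, Option.empty[Double])
)).toDF("a", "b")

val base = new VectorAssembler()
  .setInputCols(Array("a", "b"))
  .setOutputCol("features")

base.setHandleInvalid("skip").transform(df).count()    // 1: the null row is dropped
base.setHandleInvalid("keep").transform(df).collect()  // second row -> [3.0, NaN]
base.setHandleInvalid("error").transform(df).collect() // throws SparkException
```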

@Since("1.4.0")
@@ -136,34 +181,117 @@ class VectorAssembler @Since("1.4.0") (@Since("1.4.0") override val uid: String)
@Since("1.6.0")
object VectorAssembler extends DefaultParamsReadable[VectorAssembler] {

private[feature] val SKIP_INVALID: String = "skip"
private[feature] val ERROR_INVALID: String = "error"
private[feature] val KEEP_INVALID: String = "keep"
private[feature] val supportedHandleInvalids: Array[String] =
Array(SKIP_INVALID, ERROR_INVALID, KEEP_INVALID)

/**
* Infers lengths of vector columns from the first row of the dataset.
* @param dataset the dataset
* @param columns names of the vector columns whose lengths need to be inferred
* @return a map from column name to length
*/
private[feature] def getVectorLengthsFromFirstRow(
dataset: Dataset[_],
columns: Seq[String]): Map[String, Int] = {
try {
val firstRow = dataset.toDF().select(columns.map(col): _*).first()
columns.zip(firstRow.toSeq).map {
case (c, x) => c -> x.asInstanceOf[Vector].size
}.toMap
} catch {
case e: NullPointerException => throw new NullPointerException(
s"""Encountered null value while inferring lengths from the first row. Consider using
|VectorSizeHint to add metadata for columns: ${columns.mkString("[", ", ", "]")}. """
.stripMargin.replaceAll("\n", " ") + e.toString)
case e: NoSuchElementException => throw new NoSuchElementException(
s"""Encountered empty dataframe while inferring lengths from the first row. Consider using
|VectorSizeHint to add metadata for columns: ${columns.mkString("[", ", ", "]")}. """
.stripMargin.replaceAll("\n", " ") + e.toString)
}
}

private[feature] def getLengths(
dataset: Dataset[_],
columns: Seq[String],
handleInvalid: String): Map[String, Int] = {
val groupSizes = columns.map { c =>
c -> AttributeGroup.fromStructField(dataset.schema(c)).size
}.toMap
val missingColumns = groupSizes.filter(_._2 == -1).keys.toSeq
val firstSizes = (missingColumns.nonEmpty, handleInvalid) match {
case (true, VectorAssembler.ERROR_INVALID) =>
getVectorLengthsFromFirstRow(dataset, missingColumns)
case (true, VectorAssembler.SKIP_INVALID) =>
getVectorLengthsFromFirstRow(dataset.na.drop(missingColumns), missingColumns)
case (true, VectorAssembler.KEEP_INVALID) => throw new RuntimeException(
s"""Can not infer column lengths with handleInvalid = "keep". Consider using VectorSizeHint
|to add metadata for columns: ${columns.mkString("[", ", ", "]")}."""
.stripMargin.replaceAll("\n", " "))
case (_, _) => Map.empty
}
groupSizes ++ firstSizes
}
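For reference, the metadata path in `getLengths` reads the `AttributeGroup` size that a stage such as `VectorSizeHint` attaches to the schema; `-1` is the sentinel for an unknown size. A one-line sketch of that lookup (assuming a DataFrame `df` with a vector column `vec`):

```scala
import org.apache.spark.ml.attribute.AttributeGroup

// Returns the declared vector size, or -1 when no size metadata is attached;
// -1 is exactly the sentinel getLengths uses to mark columns for inference.
val declaredSize = AttributeGroup.fromStructField(df.schema("vec")).size
```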


@Since("1.6.0")
override def load(path: String): VectorAssembler = super.load(path)

private[feature] def assemble(vv: Any*): Vector = {
val indices = ArrayBuilder.make[Int]
val values = ArrayBuilder.make[Double]
var cur = 0
/**
* Returns a function that has the required information to assemble each row.
* @param lengths an array of lengths of input columns, whose size should be equal to the number
* of cells in the row (vv)
* @param keepInvalid whether to emit NaNs for null values (true) or throw an error on seeing
* a null (false)
* @return a function that can be applied to each row (wrapped in a udf at the call site)
*/
private[feature] def assemble(lengths: Array[Int], keepInvalid: Boolean)(vv: Any*): Vector = {
val indices = mutable.ArrayBuilder.make[Int]
val values = mutable.ArrayBuilder.make[Double]
var featureIndex = 0

var inputColumnIndex = 0
vv.foreach {
case v: Double =>
if (v != 0.0) {
indices += cur
if (v.isNaN && !keepInvalid) {
throw new SparkException(
s"""Encountered NaN while assembling a row with handleInvalid = "error". Consider
|removing NaNs from dataset or using handleInvalid = "keep" or "skip"."""
.stripMargin)
} else if (v != 0.0) {
indices += featureIndex
values += v
}
cur += 1
inputColumnIndex += 1
featureIndex += 1
case vec: Vector =>
vec.foreachActive { case (i, v) =>
if (v != 0.0) {
indices += cur + i
indices += featureIndex + i
values += v
}
}
cur += vec.size
inputColumnIndex += 1
featureIndex += vec.size
case null =>
// TODO: output Double.NaN?
throw new SparkException("Values to assemble cannot be null.")
if (keepInvalid) {
val length: Int = lengths(inputColumnIndex)
Array.range(0, length).foreach { i =>
indices += featureIndex + i
values += Double.NaN
}
inputColumnIndex += 1
featureIndex += length
} else {
throw new SparkException(
s"""Encountered null while assembling a row with handleInvalid = "keep". Consider
|removing nulls from dataset or using handleInvalid = "keep" or "skip"."""
.stripMargin)
}
case o =>
throw new SparkException(s"$o of type ${o.getClass.getName} is not supported.")
}
Vectors.sparse(cur, indices.result(), values.result()).compressed
Vectors.sparse(featureIndex, indices.result(), values.result()).compressed
}
}
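To illustrate the assembly semantics, here is a trimmed-down standalone re-implementation of the keepInvalid = true path (illustrative only, since the real helper is private[feature]; the name `assembleKeep` is hypothetical):

```scala
import scala.collection.mutable

import org.apache.spark.ml.linalg.{Vector, Vectors}

// Sketch of assemble(..., keepInvalid = true): a null cell expands to
// lengths(col) NaNs, which is why "keep" needs known column lengths.
def assembleKeep(lengths: Array[Int])(cells: Any*): Vector = {
  val indices = mutable.ArrayBuilder.make[Int]
  val values = mutable.ArrayBuilder.make[Double]
  var featureIndex = 0
  cells.zipWithIndex.foreach {
    case (v: Double, _) =>
      if (v != 0.0) { indices += featureIndex; values += v }
      featureIndex += 1
    case (vec: Vector, _) =>
      vec.foreachActive { case (i, v) =>
        if (v != 0.0) { indices += featureIndex + i; values += v }
      }
      featureIndex += vec.size
    case (null, col) =>
      // Expand the null cell to its declared length, filled with NaNs.
      (0 until lengths(col)).foreach { i =>
        indices += featureIndex + i
        values += Double.NaN
      }
      featureIndex += lengths(col)
    case (other, _) =>
      throw new IllegalArgumentException(s"unsupported cell: $other")
  }
  Vectors.sparse(featureIndex, indices.result(), values.result()).compressed
}

// One double cell plus a null cell declared as length 2 -> [3.0, NaN, NaN]
assembleKeep(Array(1, 2))(3.0, null)
```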