[SPARK-23690][ML] Add handleInvalid to VectorAssembler #20829
Changes from all commits
mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala

@@ -17,14 +17,17 @@
 package org.apache.spark.ml.feature

-import scala.collection.mutable.ArrayBuilder
+import java.util.NoSuchElementException
+
+import scala.collection.mutable
+import scala.language.existentials

 import org.apache.spark.SparkException
 import org.apache.spark.annotation.Since
 import org.apache.spark.ml.Transformer
 import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, NumericAttribute, UnresolvedAttribute}
 import org.apache.spark.ml.linalg.{Vector, Vectors, VectorUDT}
-import org.apache.spark.ml.param.ParamMap
+import org.apache.spark.ml.param.{Param, ParamMap, ParamValidators}
 import org.apache.spark.ml.param.shared._
 import org.apache.spark.ml.util._
 import org.apache.spark.sql.{DataFrame, Dataset, Row}
@@ -33,10 +36,14 @@ import org.apache.spark.sql.types._

 /**
  * A feature transformer that merges multiple columns into a vector column.
+ *
+ * This requires one pass over the entire dataset. In case we need to infer column lengths from the
+ * data we require an additional call to the 'first' Dataset method, see 'handleInvalid' parameter.
  */
 @Since("1.4.0")
 class VectorAssembler @Since("1.4.0") (@Since("1.4.0") override val uid: String)
-  extends Transformer with HasInputCols with HasOutputCol with DefaultParamsWritable {
+  extends Transformer with HasInputCols with HasOutputCol with HasHandleInvalid
+    with DefaultParamsWritable {

   @Since("1.4.0")
   def this() = this(Identifiable.randomUID("vecAssembler"))
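For orientation, a minimal usage sketch of the assembler with the new parameter (editor's example, not part of the diff; the column names and values are illustrative, and `spark` is assumed to be an existing SparkSession):

```scala
import org.apache.spark.ml.feature.VectorAssembler
import spark.implicits._ // assumes a SparkSession named `spark`

// Illustrative data: the second row has a null in the "hour" column.
val df = Seq(
  (0, Option(18.0), 1.0),
  (1, Option.empty[Double], 0.0)
).toDF("id", "hour", "clicked")

val assembler = new VectorAssembler()
  .setInputCols(Array("hour", "clicked"))
  .setOutputCol("features")
  .setHandleInvalid("skip") // added by this PR; the default is "error"

// Only the first row survives: features = [18.0, 1.0]
assembler.transform(df).show()
```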
@@ -49,32 +56,63 @@ class VectorAssembler @Since("1.4.0") (@Since("1.4.0") override val uid: String)
   @Since("1.4.0")
   def setOutputCol(value: String): this.type = set(outputCol, value)

+  /** @group setParam */
+  @Since("2.4.0")
+  def setHandleInvalid(value: String): this.type = set(handleInvalid, value)
+
+  /**
+   * Param for how to handle invalid data (NULL values). Options are 'skip' (filter out rows with
+   * invalid data), 'error' (throw an error), or 'keep' (return relevant number of NaN in the
+   * output). Column lengths are taken from the size of ML Attribute Group, which can be set using
+   * `VectorSizeHint` in a pipeline before `VectorAssembler`. Column lengths can also be inferred
+   * from first rows of the data since it is safe to do so but only in case of 'error' or 'skip'.
+   * Default: "error"
+   * @group param
+   */
+  @Since("2.4.0")
+  override val handleInvalid: Param[String] = new Param[String](this, "handleInvalid",
+    """Param for how to handle invalid data (NULL values). Options are 'skip' (filter out rows with
+      |invalid data), 'error' (throw an error), or 'keep' (return relevant number of NaN in the
+      |output). Column lengths are taken from the size of ML Attribute Group, which can be set using
+      |`VectorSizeHint` in a pipeline before `VectorAssembler`. Column lengths can also be inferred
+      |from first rows of the data since it is safe to do so but only in case of 'error' or 'skip'.
+      |""".stripMargin.replaceAll("\n", " "),
+    ParamValidators.inArray(VectorAssembler.supportedHandleInvalids))
+
+  setDefault(handleInvalid, VectorAssembler.ERROR_INVALID)
+
   @Since("2.0.0")
   override def transform(dataset: Dataset[_]): DataFrame = {
     transformSchema(dataset.schema, logging = true)
     // Schema transformation.
     val schema = dataset.schema
-    lazy val first = dataset.toDF.first()
-    val attrs = $(inputCols).flatMap { c =>
+
+    val vectorCols = $(inputCols).filter { c =>
+      schema(c).dataType match {
+        case _: VectorUDT => true
+        case _ => false
+      }
+    }
+    val vectorColsLengths = VectorAssembler.getLengths(dataset, vectorCols, $(handleInvalid))
+
+    val featureAttributesMap = $(inputCols).map { c =>
       val field = schema(c)
-      val index = schema.fieldIndex(c)
       field.dataType match {
         case DoubleType =>
-          val attr = Attribute.fromStructField(field)
-          // If the input column doesn't have ML attribute, assume numeric.
-          if (attr == UnresolvedAttribute) {
-            Some(NumericAttribute.defaultAttr.withName(c))
-          } else {
-            Some(attr.withName(c))
+          val attribute = Attribute.fromStructField(field)
+          attribute match {
+            case UnresolvedAttribute =>
+              Seq(NumericAttribute.defaultAttr.withName(c))
+            case _ =>
+              Seq(attribute.withName(c))
           }
         case _: NumericType | BooleanType =>
           // If the input column type is a compatible scalar type, assume numeric.
-          Some(NumericAttribute.defaultAttr.withName(c))
+          Seq(NumericAttribute.defaultAttr.withName(c))
         case _: VectorUDT =>
-          val group = AttributeGroup.fromStructField(field)
-          if (group.attributes.isDefined) {
-            // If attributes are defined, copy them with updated names.
-            group.attributes.get.zipWithIndex.map { case (attr, i) =>
+          val attributeGroup = AttributeGroup.fromStructField(field)
+          if (attributeGroup.attributes.isDefined) {
+            attributeGroup.attributes.get.zipWithIndex.toSeq.map { case (attr, i) =>
               if (attr.name.isDefined) {
                 // TODO: Define a rigorous naming scheme.
                 attr.withName(c + "_" + attr.name.get)

Inline review thread on the `handleInvalid` param doc:

Reviewer: It would be good to expand this doc to explain the behavior: how the various types of invalid values are treated (null, NaN, incorrect Vector length) and how computationally expensive the different options can be.

Author: Behavior of the options is already included, the explanation of column lengths is included here, and run-time information is included in the VectorAssembler class documentation. Thanks for the suggestion, this is super important!

Author: Also, we only deal with nulls here; NaNs and incorrect-length vectors are passed through transparently. Do we need to test for those?

Reviewer: I'd recommend we deal with NaNs now. This PR is already dealing with some NaN cases: Dataset.na.drop handles NaNs in NumericType columns (but not VectorUDT columns). I'm OK with postponing incorrect vector lengths until later, or doing that now, since that work will be more separate.

Inline review comment on renaming `group` to `attributeGroup`:

Reviewer: For the future, I'd avoid renaming things like this unless it's really unclear or needed (to keep diffs shorter).
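To make the three modes concrete, a hedged sketch of their behavior on the illustrative `df` from the earlier example (editor's example; with scalar input columns no size inference is needed, so "keep" works without a VectorSizeHint):

```scala
// Same illustrative df as above: a null in "hour" for id 1.
val assembler = new VectorAssembler()
  .setInputCols(Array("hour", "clicked"))
  .setOutputCol("features")

// "skip": the null row is filtered out before assembling.
assembler.setHandleInvalid("skip").transform(df).count() // 1

// "keep": the null scalar cell becomes NaN, so id 1 gets features = [NaN, 0.0].
assembler.setHandleInvalid("keep").transform(df).show()

// "error" (the default): would throw a SparkException on the null row.
// assembler.setHandleInvalid("error").transform(df).count()
```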
@@ -85,18 +123,25 @@ class VectorAssembler @Since("1.4.0") (@Since("1.4.0") override val uid: String)
           } else {
             // Otherwise, treat all attributes as numeric. If we cannot get the number of attributes
             // from metadata, check the first row.
-            val numAttrs = group.numAttributes.getOrElse(first.getAs[Vector](index).size)
-            Array.tabulate(numAttrs)(i => NumericAttribute.defaultAttr.withName(c + "_" + i))
+            (0 until vectorColsLengths(c)).map { i =>
+              NumericAttribute.defaultAttr.withName(c + "_" + i)
+            }
           }
         case otherType =>
           throw new SparkException(s"VectorAssembler does not support the $otherType type")
       }
     }
-    val metadata = new AttributeGroup($(outputCol), attrs).toMetadata()
+    val featureAttributes = featureAttributesMap.flatten[Attribute].toArray
+    val lengths = featureAttributesMap.map(a => a.length).toArray
+    val metadata = new AttributeGroup($(outputCol), featureAttributes).toMetadata()
+    val (filteredDataset, keepInvalid) = $(handleInvalid) match {
+      case VectorAssembler.SKIP_INVALID => (dataset.na.drop($(inputCols)), false)
+      case VectorAssembler.KEEP_INVALID => (dataset, true)
+      case VectorAssembler.ERROR_INVALID => (dataset, false)
+    }
     // Data transformation.
     val assembleFunc = udf { r: Row =>
-      VectorAssembler.assemble(r.toSeq: _*)
-    }
+      VectorAssembler.assemble(lengths, keepInvalid)(r.toSeq: _*)
+    }.asNondeterministic()
     val args = $(inputCols).map { c =>
       schema(c).dataType match {
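The "skip" branch above relies on the Dataset.na.drop behavior the reviewer mentioned: na.drop considers both null and NaN in numeric columns, but does not look inside vector columns. A small check (editor's example):

```scala
// Both rows are numeric; the second has a NaN rather than a null.
val withNaN = Seq((1.0, 0.0), (Double.NaN, 1.0)).toDF("hour", "clicked")

// The NaN row is dropped too, so handleInvalid = "skip" also filters NaNs
// in NumericType input columns.
withNaN.na.drop(Seq("hour")).count() // 1
```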
@@ -106,7 +151,7 @@ class VectorAssembler @Since("1.4.0") (@Since("1.4.0") override val uid: String)
       }
     }

-    dataset.select(col("*"), assembleFunc(struct(args: _*)).as($(outputCol), metadata))
+    filteredDataset.select(col("*"), assembleFunc(struct(args: _*)).as($(outputCol), metadata))
   }

   @Since("1.4.0")
@@ -136,34 +181,117 @@ class VectorAssembler @Since("1.4.0") (@Since("1.4.0") override val uid: String)
 @Since("1.6.0")
 object VectorAssembler extends DefaultParamsReadable[VectorAssembler] {

+  private[feature] val SKIP_INVALID: String = "skip"
+  private[feature] val ERROR_INVALID: String = "error"
+  private[feature] val KEEP_INVALID: String = "keep"
+  private[feature] val supportedHandleInvalids: Array[String] =
+    Array(SKIP_INVALID, ERROR_INVALID, KEEP_INVALID)
+
+  /**
+   * Infers lengths of vector columns from the first row of the dataset
+   * @param dataset the dataset
+   * @param columns name of vector columns whose lengths need to be inferred
+   * @return map of column names to lengths
+   */
+  private[feature] def getVectorLengthsFromFirstRow(
+      dataset: Dataset[_],
+      columns: Seq[String]): Map[String, Int] = {
+    try {
+      val first_row = dataset.toDF().select(columns.map(col): _*).first()
+      columns.zip(first_row.toSeq).map {
+        case (c, x) => c -> x.asInstanceOf[Vector].size
+      }.toMap
+    } catch {
+      case e: NullPointerException => throw new NullPointerException(
+        s"""Encountered null value while inferring lengths from the first row. Consider using
+           |VectorSizeHint to add metadata for columns: ${columns.mkString("[", ", ", "]")}. """
+          .stripMargin.replaceAll("\n", " ") + e.toString)
+      case e: NoSuchElementException => throw new NoSuchElementException(
+        s"""Encountered empty dataframe while inferring lengths from the first row. Consider using
+           |VectorSizeHint to add metadata for columns: ${columns.mkString("[", ", ", "]")}. """
+          .stripMargin.replaceAll("\n", " ") + e.toString)
+    }
+  }
+
+  private[feature] def getLengths(
+      dataset: Dataset[_],
+      columns: Seq[String],
+      handleInvalid: String): Map[String, Int] = {
+    val groupSizes = columns.map { c =>
+      c -> AttributeGroup.fromStructField(dataset.schema(c)).size
+    }.toMap
+    val missingColumns = groupSizes.filter(_._2 == -1).keys.toSeq
+    val firstSizes = (missingColumns.nonEmpty, handleInvalid) match {
+      case (true, VectorAssembler.ERROR_INVALID) =>
+        getVectorLengthsFromFirstRow(dataset, missingColumns)
+      case (true, VectorAssembler.SKIP_INVALID) =>
+        getVectorLengthsFromFirstRow(dataset.na.drop(missingColumns), missingColumns)
+      case (true, VectorAssembler.KEEP_INVALID) => throw new RuntimeException(
+        s"""Can not infer column lengths with handleInvalid = "keep". Consider using VectorSizeHint
+           |to add metadata for columns: ${columns.mkString("[", ", ", "]")}."""
+          .stripMargin.replaceAll("\n", " "))
+      case (_, _) => Map.empty
+    }
+    groupSizes ++ firstSizes
+  }
+
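Since getLengths refuses to infer sizes when handleInvalid is "keep", vector columns need their sizes present in the column metadata for that mode. A sketch of the VectorSizeHint pattern the error messages point to (editor's example; column names are illustrative):

```scala
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.feature.{VectorAssembler, VectorSizeHint}

// "userFeatures" is a vector column whose size is absent from the metadata.
val sizeHint = new VectorSizeHint()
  .setInputCol("userFeatures")
  .setSize(3) // writes the size into the ML Attribute Group metadata

val assembler = new VectorAssembler()
  .setInputCols(Array("hour", "userFeatures"))
  .setOutputCol("features")
  .setHandleInvalid("keep") // legal now: the size comes from the metadata

val pipeline = new Pipeline().setStages(Array(sizeHint, assembler))
```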
   @Since("1.6.0")
   override def load(path: String): VectorAssembler = super.load(path)

-  private[feature] def assemble(vv: Any*): Vector = {
-    val indices = ArrayBuilder.make[Int]
-    val values = ArrayBuilder.make[Double]
-    var cur = 0
+  /**
+   * Returns a function that has the required information to assemble each row.
+   * @param lengths an array of lengths of input columns, whose size should be equal to the number
+   *                of cells in the row (vv)
+   * @param keepInvalid indicate whether to throw an error or not on seeing a null in the rows
+   * @return a udf that can be applied on each row
+   */
+  private[feature] def assemble(lengths: Array[Int], keepInvalid: Boolean)(vv: Any*): Vector = {
+    val indices = mutable.ArrayBuilder.make[Int]
+    val values = mutable.ArrayBuilder.make[Double]
+    var featureIndex = 0
+
+    var inputColumnIndex = 0
     vv.foreach {
       case v: Double =>
-        if (v != 0.0) {
-          indices += cur
+        if (v.isNaN && !keepInvalid) {
+          throw new SparkException(
+            s"""Encountered NaN while assembling a row with handleInvalid = "error". Consider
+               |removing NaNs from dataset or using handleInvalid = "keep" or "skip"."""
+              .stripMargin)
+        } else if (v != 0.0) {
+          indices += featureIndex
           values += v
         }
-        cur += 1
+        inputColumnIndex += 1
+        featureIndex += 1
       case vec: Vector =>
         vec.foreachActive { case (i, v) =>
           if (v != 0.0) {
-            indices += cur + i
+            indices += featureIndex + i
             values += v
           }
         }
-        cur += vec.size
+        inputColumnIndex += 1
+        featureIndex += vec.size
       case null =>
-        // TODO: output Double.NaN?
-        throw new SparkException("Values to assemble cannot be null.")
+        if (keepInvalid) {
+          val length: Int = lengths(inputColumnIndex)
+          Array.range(0, length).foreach { i =>
+            indices += featureIndex + i
+            values += Double.NaN
+          }
+          inputColumnIndex += 1
+          featureIndex += length
+        } else {
+          throw new SparkException(
+            s"""Encountered null while assembling a row with handleInvalid = "error". Consider
+               |removing nulls from dataset or using handleInvalid = "keep" or "skip"."""
+              .stripMargin)
+        }
       case o =>
         throw new SparkException(s"$o of type ${o.getClass.getName} is not supported.")
     }
-    Vectors.sparse(cur, indices.result(), values.result()).compressed
+    Vectors.sparse(featureIndex, indices.result(), values.result()).compressed
   }
 }
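For intuition, a rough worked example of the curried assemble routine (editor's sketch; assemble is private[feature], so this only compiles from inside the org.apache.spark.ml.feature package):

```scala
import org.apache.spark.ml.linalg.Vectors

// One scalar cell (length 1) followed by one vector cell (length 2).
val lengths = Array(1, 2)

// Normal row: [1.0] ++ [0.0, 2.0] assembles into a size-3 vector.
VectorAssembler.assemble(lengths, keepInvalid = true)(1.0, Vectors.dense(0.0, 2.0))
// => [1.0, 0.0, 2.0] (built sparsely, then compressed)

// Null scalar cell with keepInvalid = true: padded with lengths(0) = 1 NaN.
VectorAssembler.assemble(lengths, keepInvalid = true)(null, Vectors.dense(0.0, 2.0))
// => [NaN, 0.0, 2.0]

// Null with keepInvalid = false: throws SparkException (handleInvalid = "error").
```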
Review thread:

Reviewer: Why does this line need to change?

Author: Thanks for picking this out! I changed this because I was matching on `$(handleInvalid)` in VectorAssembler, and that seems to be the recommended way of doing it. Should I include this in the current PR and add a note, or open a separate PR?

Reviewer: OK, it doesn't matter; no need for a separate PR, I think. It's just a minor change.

Reviewer: For the record, in general I would not bother making changes like this. The one exception I do make is IntelliJ style complaints, since those can be annoying for developers.