
[SPARK-23690][ML] Add handleinvalid to VectorAssembler #20829

Closed
wants to merge 25 commits into from

Conversation

yogeshg
Contributor

@yogeshg yogeshg commented Mar 15, 2018

What changes were proposed in this pull request?

Introduce a handleInvalid parameter in VectorAssembler that accepts the options "keep", "skip", and "error". "error" throws an exception on seeing a row containing a null, "skip" filters out all such rows, and "keep" pads the output with the appropriate number of NaN values. For "keep", the transformer inspects an example row to determine how many NaNs should be added, and throws an error when that number cannot be determined.

How was this patch tested?

Unit tests check the behavior of assemble on specific rows, and the transformer is exercised on DataFrames of various configurations to cover corner cases.
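The three modes can be illustrated with a small, plain-Python sketch (a hypothetical helper with made-up names, not the actual Spark code, which operates on DataFrames):

```python
import math

def handle_rows(rows, mode):
    """Sketch of the three handleInvalid modes; each row is a list of
    floats or None (standing in for a SQL null)."""
    if mode == "skip":
        # drop any row containing a null
        return [r for r in rows if all(v is not None for v in r)]
    if mode == "keep":
        # replace each null with NaN
        return [[math.nan if v is None else v for v in r] for r in rows]
    if mode == "error":
        if any(v is None for r in rows for v in r):
            raise ValueError("null value seen with handleInvalid = 'error'")
        return rows
    raise ValueError("unsupported mode: " + mode)
```

For example, `handle_rows([[1.0, None], [2.0, 3.0]], "skip")` keeps only the second row, while `"keep"` turns the `None` into a NaN.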

@yogeshg yogeshg changed the title [SPARK-23690] [ML] Add handleinvalid to VectorAssembler [SPARK-23690][ML] Add handleinvalid to VectorAssembler Mar 15, 2018
@SparkQA

SparkQA commented Mar 15, 2018

Test build #88249 has finished for PR 20829 at commit c0c0e3d.

  • This patch fails Spark unit tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@@ -234,7 +234,7 @@ class StringIndexerModel (
val metadata = NominalAttribute.defaultAttr
.withName($(outputCol)).withValues(filteredLabels).toMetadata()
// If we are skipping invalid records, filter them out.
val (filteredDataset, keepInvalid) = getHandleInvalid match {
Contributor

Why need change this line ?

Contributor Author

Thanks for picking this out! I changed this because I was matching on $(handleInvalid) in VectorAssembler and that seems to be the recommended way of doing this. Should I include this in the current PR and add a note or open a separate PR?

Contributor

OK, it doesn't matter; no need for a separate PR, I think. Just a minor change.

Member

For the record, in general, I would not bother making changes like this. The one exception I do make is IntelliJ style complaints since those can be annoying for developers.

@SparkQA

SparkQA commented Mar 15, 2018

Test build #88277 has finished for PR 20829 at commit bf2f5b3.

  • This patch fails Spark unit tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@yogeshg
Contributor Author

yogeshg commented Mar 15, 2018

I fixed the code paths that failed tests; waiting for @SparkQA. An offline talk with @MrBago suggests that we can perhaps decrease the number of maps in the transform method. Looking into that.

@SparkQA

SparkQA commented Mar 15, 2018

Test build #88278 has finished for PR 20829 at commit 5ce7671.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@yogeshg
Contributor Author

yogeshg commented Mar 16, 2018

test this please

@SparkQA

SparkQA commented Mar 16, 2018

Test build #88286 has finished for PR 20829 at commit 482225f.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@SparkQA

SparkQA commented Mar 16, 2018

Test build #88287 has finished for PR 20829 at commit 482225f.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

*/
@Since("1.6.0")
override val handleInvalid: Param[String] = new Param[String](this, "handleInvalid",
"Hhow to handle invalid data (NULL values). Options are 'skip' (filter out rows with " +
Contributor

Hhow -> How

override val handleInvalid: Param[String] = new Param[String](this, "handleInvalid",
"Hhow to handle invalid data (NULL values). Options are 'skip' (filter out rows with " +
"invalid data), 'error' (throw an error), or 'keep' (return relevant number of NaN " +
"in the * output).", ParamValidators.inArray(VectorAssembler.supportedHandleInvalids))
Contributor

"in the * output" -> "in the output"

@@ -49,32 +51,65 @@ class VectorAssembler @Since("1.4.0") (@Since("1.4.0") override val uid: String)
@Since("1.4.0")
def setOutputCol(value: String): this.type = set(outputCol, value)

/** @group setParam */
@Since("1.6.0")
Contributor

@Since("2.4.0")

|VectorAssembler cannot determine the size of empty vectors. Consider applying
|VectorSizeHint to ${c} so that this transformer can be used to transform empty
|columns.
""".stripMargin.replaceAll("\n", " "))
Contributor

I think in this case VectorSizeHint also cannot help to provide the vector size.

val lengths = featureAttributesMap.map(a => a.length)
val metadata = new AttributeGroup($(outputCol), featureAttributes.toArray).toMetadata()
val (filteredDataset, keepInvalid) = $(handleInvalid) match {
case StringIndexer.SKIP_INVALID => (dataset.na.drop("any", $(inputCols)), false)
Contributor

you can directly use dataset.na.drop($(inputCols))

Contributor Author

Ah, good point! I do think that keeping "any" might make it easier to read, though that may not hold for experienced people :P

""".stripMargin.replaceAll("\n", " "))
}
if (isMissingNumAttrs) {
val column = dataset.select(c).na.drop()
Contributor

  • The var name column isn't good. colDataset is better.

  • An optional optimization is one-pass scanning the dataset and count non-null rows for each "missing num attrs" columns.

Contributor Author

Good catch! That name was bothering me too :P
@MrBago and I are thinking of another way to do this more efficiently.

@ghost

ghost commented Mar 16, 2018 via email

@SparkQA

SparkQA commented Mar 20, 2018

Test build #88390 has finished for PR 20829 at commit 2b1fd4e.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@SparkQA

SparkQA commented Mar 20, 2018

Test build #88392 has finished for PR 20829 at commit 9624061.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@SparkQA

SparkQA commented Mar 20, 2018

Test build #88395 has finished for PR 20829 at commit ab91545.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@jkbradley
Member

@hootoconnor Please refrain from making non-constructive comments. If you did not intend to leave the comment here, please remove it. Thanks.

@jkbradley jkbradley left a comment

Thanks for the PR! I made a pass, mainly looking at tests.


@Since("1.6.0")
override def load(path: String): VectorAssembler = super.load(path)

private[feature] def assemble(vv: Any*): Vector = {
private[feature] def assemble(lengths: Seq[Int], keepInvalid: Boolean)(vv: Any*): Vector = {
Member
nit: Use Array[Int] for faster access

Member

Also, I'd add doc explaining requirements, especially that this assumes that lengths and vv have the same length.

}
}

test("assemble should compress vectors") {
import org.apache.spark.ml.feature.VectorAssembler.assemble
val v1 = assemble(0.0, 0.0, 0.0, Vectors.dense(4.0))
val v1 = assemble(Seq(1, 1, 1, 4), true)(0.0, 0.0, 0.0, Vectors.dense(4.0))
Member

We probably want this to fail, right? It expects a Vector of length 4 but is given a Vector of length 1.

Contributor Author

That's a typo, thanks for pointing it out! That number is not used when there are no nulls, which is why the test passes.
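To make the signature under discussion concrete, here is a plain-Python sketch (hypothetical names, not the real implementation, which assembles a compressed ML Vector) of how assemble(lengths, keepInvalid)(vv) consults the declared lengths only when padding nulls:

```python
import math

def assemble(lengths, keep_invalid, values):
    """lengths[i] declares how many output slots values[i] occupies; it is
    consulted only when a value is null and must be padded with NaNs."""
    assert len(lengths) == len(values), "lengths and values must align"
    out = []
    for n, v in zip(lengths, values):
        if v is None:
            if not keep_invalid:
                raise ValueError("null value in row (handleInvalid = 'error')")
            out.extend([math.nan] * n)  # pad with the declared number of NaNs
        elif isinstance(v, list):       # list stands in for a Vector
            out.extend(v)
        else:
            out.append(float(v))
    return out
```

This also shows why the typo above went unnoticed: when a row has no nulls, the declared lengths are never consulted, so a wrong length cannot make the test fail.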

def setHandleInvalid(value: String): this.type = set(handleInvalid, value)

/**
* Param for how to handle invalid data (NULL values). Options are 'skip' (filter out rows with
Member

It would be good to expand this doc to explain the behavior: how various types of invalid values are treated (null, NaN, incorrect Vector length) and how computationally expensive different options can be.

Contributor Author

The behavior of the options is already included, an explanation of column lengths is included here, and runtime information is included in the VectorAssembler class documentation. Thanks for the suggestion, this is super important!

Contributor Author

Also, we deal only with nulls here; NaNs and incorrect-length vectors are passed through transparently. Do we need to test for those?

Member

I'd recommend we deal with NaNs now. This PR is already dealing with some NaN cases: Dataset.na.drop handles NaNs in NumericType columns (but not VectorUDT columns).

I'm Ok with postponing incorrect vector lengths until later or doing that now since that work will be more separate.

@@ -147,4 +149,72 @@ class VectorAssemblerSuite
.filter(vectorUDF($"features") > 1)
.count() == 1)
}

test("assemble should keep nulls") {
Member

make more explicit: + " when keepInvalid = true"

Vectors.dense(Double.NaN, Double.NaN))
}

test("get lengths function") {
Member

This is great that you're testing this carefully, but I recommend we make sure to pass better exceptions to users. E.g., they won't know what to do with a NullPointerException, so we could instead tell them something like: "Column x in the first row of the dataset has a null entry, but VectorAssembler expected a non-null entry. This can be fixed by explicitly specifying the expected size using VectorSizeHint."

Contributor Author

Thanks! We do throw a descriptive error here; I added more detail to it and made the tests assert on those messages.

}

test("Handle Invalid should behave properly") {
val df = Seq[(Long, Long, java.lang.Double, Vector, String, Vector, Long)](

Member

Also, if there are "trash" columns not used by VectorAssembler, maybe name them as such and add a few null values in them for better testing.

Contributor Author

Thanks, good idea! This helped me catch the na.drop() bug that might drop everything.


// behavior when first row has information
assert(assembler.setHandleInvalid("skip").transform(df).count() == 1)
intercept[RuntimeException](assembler.setHandleInvalid("keep").transform(df).collect())
Member

Should this fail? I thought it should pad with NaNs.

Contributor Author

It fails because the vector size hint is not given; adding a section with VectorSizeHints.

intercept[RuntimeException](assembler.setHandleInvalid("keep").transform(df).collect())
intercept[SparkException](assembler.setHandleInvalid("error").transform(df).collect())

// numeric column is all null
Member

Did you want to test:

  • extraction of metadata from the first row (which is what this is testing, I believe), or
  • transformation on an all-null column (which this never reaches)?

Contributor Author

I was testing extraction of metadata for the numeric column (whose length is always 1). Not relevant in the new framework.

intercept[RuntimeException](
assembler.setHandleInvalid("keep").transform(df.filter("id1==3")).count() == 1)

// vector column is all null
Member

ditto

Yogesh Garg added 2 commits March 21, 2018 18:05
…and style

review wip

adding an all null column should not break anything; bugfix

review wip

update test logic
@SparkQA

SparkQA commented Mar 22, 2018

Test build #88494 has finished for PR 20829 at commit e7e26f0.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@SparkQA

SparkQA commented Mar 22, 2018

Test build #88495 has finished for PR 20829 at commit 4c99003.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@jkbradley jkbradley left a comment

Thanks for the updates! I mostly have style comments at this point.

import org.apache.spark.ml.param.shared._
import org.apache.spark.ml.util._
import org.apache.spark.sql.{DataFrame, Dataset, Row}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._

/**
* A feature transformer that merges multiple columns into a vector column.
* A feature transformer that merges multiple columns into a vector column. This requires one pass
Member

style nit: Move new text here into a new paragraph below. That will give nicer "pyramid-style" formatting with essential info separated from details.


lazy val first = dataset.toDF.first()
val attrs = $(inputCols).flatMap { c =>

val vectorCols = $(inputCols).toSeq.filter { c =>
Member

nit: Is toSeq extraneous?

}
val vectorColsLengths = VectorAssembler.getLengths(dataset, vectorCols, $(handleInvalid))

val featureAttributesMap = $(inputCols).toSeq.map { c =>
Member

I think the flatMap is simpler, or at least a more common pattern in Spark and Scala (rather than having nested sequences which are then flattened).

Contributor Author

We need the map to find the lengths of the vectors; unless there is a way to do this in a single mapping, I think this is better than calling a map and then a flatMap.

if (group.attributes.isDefined) {
// If attributes are defined, copy them with updated names.
group.attributes.get.zipWithIndex.map { case (attr, i) =>
val attributeGroup = AttributeGroup.fromStructField(field)
Member

for the future, I'd avoid renaming things like this unless it's really unclear or needed (to make diffs shorter)

@Since("1.6.0")
override def load(path: String): VectorAssembler = super.load(path)

private[feature] def assemble(vv: Any*): Vector = {
/**
* Returns a UDF that has the required information to assemble each row.
Member

nit: When people say "UDF," they generally mean a Spark SQL UDF. This is just a function, not a SQL UDF.

val df = Seq(
(0, 0.0, Vectors.dense(1.0, 2.0), "a", Vectors.sparse(2, Array(1), Array(3.0)), 10L)
).toDF("id", "x", "y", "name", "z", "n")
val df = dfWithNulls.filter("id1 == 1").withColumn("id", col("id1"))
Member

nit: If this is for consolidation, I'm actually against this little change since it obscures what this test is doing and moves the input Row farther from the expected output row.

.setInputCols(Array("x", "y", "z", "n"))
.setOutputCol("features")

def run_with_metadata(mode: String, additional_filter: String = "true"): Dataset[_] = {
Member

style: use camelCase


def run_with_metadata(mode: String, additional_filter: String = "true"): Dataset[_] = {
val attributeY = new AttributeGroup("y", 2)
val subAttributesOfZ = Array(NumericAttribute.defaultAttr, NumericAttribute.defaultAttr)
Member

unused

output.collect()
output
}
def run_with_first_row(mode: String): Dataset[_] = {
Member

style: Put empty line between functions

@jkbradley jkbradley left a comment

The changes look good, but there are a few unaddressed comments.

c -> AttributeGroup.fromStructField(dataset.schema(c)).size
}.toMap
val missing_columns: Seq[String] = group_sizes.filter(_._2 == -1).keys.toSeq
val first_sizes: Map[String, Int] = (missing_columns.nonEmpty, handleInvalid) match {
Member

ping

case (true, VectorAssembler.SKIP_INVALID) =>
getVectorLengthsFromFirstRow(dataset.na.drop(missing_columns), missing_columns)
case (true, VectorAssembler.KEEP_INVALID) => throw new RuntimeException(
s"""Can not infer column lengths for 'keep invalid' mode. Consider using
Member

ping
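The length-inference step being reviewed here can be sketched in plain Python (hypothetical names; the real code reads sizes from column metadata and falls back to the dataset's first row):

```python
def infer_lengths(metadata_sizes, first_row):
    """metadata_sizes maps column -> declared vector size, with -1 meaning
    unknown; first_row maps column -> first value (a list) or None."""
    sizes = {}
    for col, size in metadata_sizes.items():
        if size >= 0:
            sizes[col] = size                 # size known from metadata
        elif first_row.get(col) is not None:
            sizes[col] = len(first_row[col])  # infer from the first row
        else:
            raise RuntimeError("Can not infer length of column " + col +
                               "; consider using VectorSizeHint")
    return sizes
```

An all-null column cannot be sized from data, mirroring the error the PR raises for the 'keep' mode.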

@SparkQA

SparkQA commented Apr 2, 2018

Test build #88829 has finished for PR 20829 at commit 081b5c0.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@SparkQA

SparkQA commented Apr 2, 2018

Test build #88834 has finished for PR 20829 at commit 134bd1e.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@SparkQA

SparkQA commented Apr 2, 2018

Test build #88835 has finished for PR 20829 at commit bf277be.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@jkbradley
Member

LGTM
Merging with master
Thanks @yogeshg for the PR and @WeichenXu123 for taking a look!

@asfgit asfgit closed this in a135182 Apr 2, 2018
robert3005 pushed a commit to palantir/spark that referenced this pull request Apr 4, 2018
## What changes were proposed in this pull request?

Introduce a `handleInvalid` parameter in `VectorAssembler` that accepts the options `"keep"`, `"skip"`, and `"error"`. "error" throws an exception on seeing a row containing a `null`, "skip" filters out all such rows, and "keep" pads the output with the appropriate number of NaN values. For "keep", the transformer inspects an example row to determine how many NaNs should be added, and throws an error when that number cannot be determined.

## How was this patch tested?

Unit tests are added to check the behavior of `assemble` on specific rows and the transformer is called on `DataFrame`s of different configurations to test different corner cases.

Author: Yogesh Garg <yogesh(dot)garg()databricks(dot)com>
Author: Bago Amirbekian <[email protected]>
Author: Yogesh Garg <[email protected]>

Closes apache#20829 from yogeshg/rformula_handleinvalid.
mshtelma pushed a commit to mshtelma/spark that referenced this pull request Apr 5, 2018