[SPARK-23690][ML] Add handleinvalid to VectorAssembler #20829
Conversation
Vector assembler stuff
Test build #88249 has finished for PR 20829 at commit
@@ -234,7 +234,7 @@ class StringIndexerModel (
val metadata = NominalAttribute.defaultAttr
.withName($(outputCol)).withValues(filteredLabels).toMetadata()
// If we are skipping invalid records, filter them out.
val (filteredDataset, keepInvalid) = getHandleInvalid match {
Why does this line need to change?
Thanks for picking this out! I changed this because I was matching on `$(handleInvalid)` in VectorAssembler, and that seems to be the recommended way of doing this. Should I include this in the current PR and add a note, or open a separate PR?
OK, it doesn't matter; no need for a separate PR, I think. It's just a minor change.
For the record, in general, I would not bother making changes like this. The one exception I do make is IntelliJ style complaints since those can be annoying for developers.
Test build #88277 has finished for PR 20829 at commit
Test build #88278 has finished for PR 20829 at commit
test this please
Test build #88286 has finished for PR 20829 at commit
Test build #88287 has finished for PR 20829 at commit
*/
@Since("1.6.0")
override val handleInvalid: Param[String] = new Param[String](this, "handleInvalid",
"Hhow to handle invalid data (NULL values). Options are 'skip' (filter out rows with " +
Hhow -> How
override val handleInvalid: Param[String] = new Param[String](this, "handleInvalid",
"Hhow to handle invalid data (NULL values). Options are 'skip' (filter out rows with " +
"invalid data), 'error' (throw an error), or 'keep' (return relevant number of NaN " +
"in the * output).", ParamValidators.inArray(VectorAssembler.supportedHandleInvalids))
"in the * output" -> "in the output"
@@ -49,32 +51,65 @@ class VectorAssembler @Since("1.4.0") (@Since("1.4.0") override val uid: String)
@Since("1.4.0")
def setOutputCol(value: String): this.type = set(outputCol, value)

/** @group setParam */
@Since("1.6.0")
@Since("2.4.0")
|VectorAssembler cannot determine the size of empty vectors. Consider applying
|VectorSizeHint to ${c} so that this transformer can be used to transform empty
|columns.
""".stripMargin.replaceAll("\n", " "))
I think in this case, VectorSizeHint also cannot help provide the vector size.
val lengths = featureAttributesMap.map(a => a.length)
val metadata = new AttributeGroup($(outputCol), featureAttributes.toArray).toMetadata()
val (filteredDataset, keepInvalid) = $(handleInvalid) match {
case StringIndexer.SKIP_INVALID => (dataset.na.drop("any", $(inputCols)), false)
you can directly use dataset.na.drop($(inputCols))
Ah, good point! I do think that keeping "any" might make it easier to read, though that may not necessarily hold for experienced people :P
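(For illustration, a minimal sketch of the equivalence being discussed — `how = "any"` is the default, so both calls drop the same rows; the data here is made up:)

```scala
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local").appName("na-drop-sketch").getOrCreate()
import spark.implicits._

// One row with a null in "x"; Option encodes as a nullable column.
val df = Seq((Some(1.0), Some(2.0)), (None, Some(3.0))).toDF("x", "y")

// "any" is the default `how`, so these two calls are equivalent:
val verbose = df.na.drop("any", Seq("x", "y"))
val concise = df.na.drop(Seq("x", "y"))
assert(verbose.count() == concise.count())  // both keep only the first row
```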
""".stripMargin.replaceAll("\n", " ")) | ||
} | ||
if (isMissingNumAttrs) { | ||
val column = dataset.select(c).na.drop() |
- The var name `column` isn't good; `colDataset` is better.
- An optional optimization is to scan the dataset in one pass and count the non-null rows for each "missing num attrs" column.
Good catch! That name was bothering me too :P
@MrBago and I are thinking of another way to do this more efficiently.
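(A sketch of the one-pass idea under discussion, not necessarily the PR's final approach: SQL `count(column)` skips nulls, so a single `select` can produce non-null counts for every column in one scan. `dataset` and the column list are placeholders:)

```scala
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions.{col, count}

// Count the non-null entries of each given column in one pass over the dataset;
// count(col) ignores nulls, so one select covers all columns at once.
def nonNullCounts(dataset: DataFrame, cols: Seq[String]): Map[String, Long] = {
  val counts = dataset.select(cols.map(c => count(col(c)).as(c)): _*).first()
  cols.map(c => c -> counts.getAs[Long](c)).toMap
}
```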
What A Mess !
…On Mar 14, 2018 8:17 PM, "UCB AMPLab" wrote:
Merged build finished. Test FAILed.
Test build #88390 has finished for PR 20829 at commit
Test build #88392 has finished for PR 20829 at commit
Test build #88395 has finished for PR 20829 at commit
@hootoconnor Please refrain from making non-constructive comments. If you did not intend to leave the comment here, please remove it. Thanks.
Thanks for the PR! I made a pass, mainly looking at tests.
@Since("1.6.0") | ||
override def load(path: String): VectorAssembler = super.load(path) | ||
|
||
private[feature] def assemble(vv: Any*): Vector = { | ||
private[feature] def assemble(lengths: Seq[Int], keepInvalid: Boolean)(vv: Any*): Vector = { |
nit: Use Array[Int] for faster access
Also, I'd add doc explaining requirements, especially that this assumes that lengths and vv have the same length.
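(A sketch of what such a doc might say, also folding in the `Array[Int]` nit above; the wording and the exact contract are assumptions, not the PR's final text:)

```scala
import org.apache.spark.ml.linalg.Vector

/**
 * Assembles the raw values of one row into a single vector (doc sketch only).
 *
 * Assumed requirements:
 *  - `lengths` and `vv` must have the same number of elements; `lengths(i)` is
 *    the declared width of `vv(i)` (1 for scalars, the vector size otherwise).
 *  - If `keepInvalid` is true, a null in `vv(i)` is replaced by `lengths(i)`
 *    NaN values; if false, a null raises an error.
 */
private[feature] def assemble(lengths: Array[Int], keepInvalid: Boolean)(vv: Any*): Vector = ???
```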
}
}

test("assemble should compress vectors") {
import org.apache.spark.ml.feature.VectorAssembler.assemble
val v1 = assemble(0.0, 0.0, 0.0, Vectors.dense(4.0))
val v1 = assemble(Seq(1, 1, 1, 4), true)(0.0, 0.0, 0.0, Vectors.dense(4.0))
We probably want this to fail, right? It expects a Vector of length 4 but is given a Vector of length 1.
That's a typo, thanks for pointing it out! That number is not used when we do not have nulls, which is why the test passes.
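(For clarity, the corrected call would presumably be the following, since `Vectors.dense(4.0)` has length 1; as noted above, the lengths are only consulted when nulls are present, which is why the typo'd version still passed:)

```scala
import org.apache.spark.ml.feature.VectorAssembler.assemble  // as in the suite
import org.apache.spark.ml.linalg.Vectors

// Declared lengths now match the inputs: three scalars and a length-1 vector.
val v1 = assemble(Seq(1, 1, 1, 1), true)(0.0, 0.0, 0.0, Vectors.dense(4.0))
```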
def setHandleInvalid(value: String): this.type = set(handleInvalid, value)

/**
* Param for how to handle invalid data (NULL values). Options are 'skip' (filter out rows with
It would be good to expand this doc to explain the behavior: how various types of invalid values are treated (null, NaN, incorrect Vector length) and how computationally expensive different options can be.
The behavior of the options is already included, an explanation of column lengths is included here, and runtime information is included in the VectorAssembler class documentation. Thanks for the suggestion, this is super important!
Also, we only deal with nulls here; NaNs and incorrect-length vectors are passed through transparently. Do we need to test for those?
I'd recommend we deal with NaNs now. This PR is already dealing with some NaN cases: Dataset.na.drop handles NaNs in NumericType columns (but not VectorUDT columns).
I'm Ok with postponing incorrect vector lengths until later or doing that now since that work will be more separate.
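(To illustrate the distinction: `na.drop` treats NaN in a numeric column as missing, but it does not look inside VectorUDT columns. A small sketch with made-up data:)

```scala
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local").appName("nan-drop-sketch").getOrCreate()
import spark.implicits._

val df = Seq(
  (Double.NaN, Vectors.dense(1.0)),  // NaN in a numeric column: treated as missing
  (1.0, Vectors.dense(Double.NaN)),  // NaN inside a vector: invisible to na.drop
  (2.0, Vectors.dense(2.0))
).toDF("x", "v")

assert(df.na.drop(Seq("x", "v")).count() == 2)  // only the first row is dropped
```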
@@ -147,4 +149,72 @@ class VectorAssemblerSuite
.filter(vectorUDF($"features") > 1)
.count() == 1)
}

test("assemble should keep nulls") {
make more explicit: + " when keepInvalid = true"
Vectors.dense(Double.NaN, Double.NaN))
}

test("get lengths function") {
This is great that you're testing this carefully, but I recommend we make sure to pass better exceptions to users. E.g., they won't know what to do with a NullPointerException, so we could instead tell them something like: "Column x in the first row of the dataset has a null entry, but VectorAssembler expected a non-null entry. This can be fixed by explicitly specifying the expected size using VectorSizeHint."
Thanks! We do throw a descriptive error here; I added more detail to it and made assertions on those messages in the tests.
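(A sketch of the kind of actionable message being asked for; the helper name and exact wording are hypothetical:)

```scala
import org.apache.spark.SparkException

// Hypothetical helper: fail with an actionable message instead of a bare NPE.
def firstRowNullError(column: String): Nothing = throw new SparkException(
  s"""Column $column in the first row of the dataset has a null entry, but
     |VectorAssembler expected a non-null entry. This can be fixed by explicitly
     |specifying the expected size using VectorSizeHint.""".stripMargin.replaceAll("\n", " "))
```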
}

test("Handle Invalid should behave properly") {
val df = Seq[(Long, Long, java.lang.Double, Vector, String, Vector, Long)](
Since this is shared across multiple tests, just make it a shared value. See e.g. https://github.com/apache/spark/blob/master/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala#L55
Also, if there are "trash" columns not used by VectorAssembler, maybe name them as such and add a few null values in them for better testing.
Thanks, good idea! This helped me catch the `na.drop()` bug that might drop everything.
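(Roughly what such a shared fixture could look like — the column names and values are invented; the point is that the null in the unused "trash" column must not cause the row to be skipped:)

```scala
import java.lang.{Double => JDouble}
import org.apache.spark.ml.linalg.{Vector, Vectors}

// Hypothetical shared fixture (inside the suite, with testImplicits in scope).
lazy val dfWithNulls = Seq[(Long, JDouble, Vector, String, Vector, Long)](
  (1L, 0.0, Vectors.dense(1.0, 2.0), "a", Vectors.sparse(2, Array(1), Array(3.0)), 7L),
  (2L, null, Vectors.dense(3.0, 4.0), null, Vectors.sparse(2, Array(0), Array(5.0)), 8L)
).toDF("id1", "x", "y", "trash", "z", "n")
// Row 2 is invalid because of the null in input column "x"; the null in
// "trash" is irrelevant, so skip mode must use na.drop($(inputCols)),
// not na.drop() over all columns.
```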

// behavior when first row has information
assert(assembler.setHandleInvalid("skip").transform(df).count() == 1)
intercept[RuntimeException](assembler.setHandleInvalid("keep").transform(df).collect())
Should this fail? I thought it should pad with NaNs.
It fails because the vector size hint is not given; I'm adding a section with VectorSizeHint.
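(The pattern being added, sketched: declare the vector column's size up front so that "keep" mode knows how many NaNs to substitute for a null vector. Column names here are placeholders:)

```scala
import org.apache.spark.ml.feature.{VectorAssembler, VectorSizeHint}

// Declare that column "y" always holds vectors of size 2.
val sizeHint = new VectorSizeHint()
  .setInputCol("y")
  .setSize(2)
  .setHandleInvalid("optimistic")  // trust the declared size; "error" would validate

val assembler = new VectorAssembler()
  .setInputCols(Array("x", "y"))
  .setOutputCol("features")
  .setHandleInvalid("keep")

// With the size known from metadata, "keep" no longer needs a non-null first row.
val assembled = assembler.transform(sizeHint.transform(df))
```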
intercept[RuntimeException](assembler.setHandleInvalid("keep").transform(df).collect())
intercept[SparkException](assembler.setHandleInvalid("error").transform(df).collect())

// numeric column is all null
Did you want to test:
- extraction of metadata from the first row (which is what this is testing, I believe), or
- transformation on an all-null column (which this never reaches)?
I was testing extraction of metadata for the numeric column (which is always 1). Not relevant in the new framework.
intercept[RuntimeException](
assembler.setHandleInvalid("keep").transform(df.filter("id1==3")).count() == 1)

// vector column is all null
ditto
…and style review wip adding an all null column should not break anything; bugfix review wip update test logic
Test build #88494 has finished for PR 20829 at commit
Test build #88495 has finished for PR 20829 at commit
Thanks for the updates! I mostly have style comments at this point.
import org.apache.spark.ml.param.shared._
import org.apache.spark.ml.util._
import org.apache.spark.sql.{DataFrame, Dataset, Row}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._

/**
* A feature transformer that merges multiple columns into a vector column.
* A feature transformer that merges multiple columns into a vector column. This requires one pass
style nit: Move new text here into a new paragraph below. That will give nicer "pyramid-style" formatting with essential info separated from details.
lazy val first = dataset.toDF.first()
val attrs = $(inputCols).flatMap { c =>

val vectorCols = $(inputCols).toSeq.filter { c =>
nit: Is toSeq extraneous?
}
val vectorColsLengths = VectorAssembler.getLengths(dataset, vectorCols, $(handleInvalid))

val featureAttributesMap = $(inputCols).toSeq.map { c =>
I think the flatMap is simpler, or at least a more common pattern in Spark and Scala (rather than having nested sequences which are then flattened).
We need the `map` to find out the lengths of the vectors; unless there's a way to do this in one pass, I think this might be better than calling a `map` first and then a `flatMap`.
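(The two shapes under discussion, schematically; `attrsFor` is a stand-in for the real per-column attribute logic:)

```scala
// Stand-in for the logic that expands one input column into its output attributes.
def attrsFor(c: String): Seq[String] = Seq(s"${c}_0", s"${c}_1")

val inputCols = Seq("x", "y")

// Reviewer's shape: a single flatMap, no intermediate nested Seq.
val flat = inputCols.flatMap(attrsFor)

// PR's shape: map first, because the per-column lengths are needed separately.
val nested = inputCols.map(attrsFor)
val lengths = nested.map(_.length)  // the reason for keeping the nested form
assert(nested.flatten == flat)
```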
if (group.attributes.isDefined) {
// If attributes are defined, copy them with updated names.
group.attributes.get.zipWithIndex.map { case (attr, i) =>
val attributeGroup = AttributeGroup.fromStructField(field)
for the future, I'd avoid renaming things like this unless it's really unclear or needed (to make diffs shorter)
@Since("1.6.0") | ||
override def load(path: String): VectorAssembler = super.load(path) | ||
|
||
private[feature] def assemble(vv: Any*): Vector = { | ||
/** | ||
* Returns a UDF that has the required information to assemble each row. |
nit: When people say "UDF," they generally mean a Spark SQL UDF. This is just a function, not a SQL UDF.
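(The distinction, illustrated: a curried Scala function is not a SQL UDF until it is wrapped with `functions.udf`. A toy function, not the PR's code:)

```scala
import org.apache.spark.sql.functions.udf

// A plain (curried) Scala function; on its own it is not a Spark SQL UDF.
def scale(factor: Double)(x: Double): Double = factor * x

// Wrapping it produces an actual SQL UDF usable in a select.
val scaleUdf = udf(scale(2.0) _)
```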
val df = Seq(
(0, 0.0, Vectors.dense(1.0, 2.0), "a", Vectors.sparse(2, Array(1), Array(3.0)), 10L)
).toDF("id", "x", "y", "name", "z", "n")
val df = dfWithNulls.filter("id1 == 1").withColumn("id", col("id1"))
nit: If this is for consolidation, I'm actually against this little change since it obscures what this test is doing and moves the input Row farther from the expected output row.
.setInputCols(Array("x", "y", "z", "n"))
.setOutputCol("features")

def run_with_metadata(mode: String, additional_filter: String = "true"): Dataset[_] = {
style: use camelCase

def run_with_metadata(mode: String, additional_filter: String = "true"): Dataset[_] = {
val attributeY = new AttributeGroup("y", 2)
val subAttributesOfZ = Array(NumericAttribute.defaultAttr, NumericAttribute.defaultAttr)
unused
output.collect()
output
}
def run_with_first_row(mode: String): Dataset[_] = {
style: Put empty line between functions
The changes look good, but there are a few unaddressed comments.
c -> AttributeGroup.fromStructField(dataset.schema(c)).size
}.toMap
val missing_columns: Seq[String] = group_sizes.filter(_._2 == -1).keys.toSeq
val first_sizes: Map[String, Int] = (missing_columns.nonEmpty, handleInvalid) match {
ping
case (true, VectorAssembler.SKIP_INVALID) =>
getVectorLengthsFromFirstRow(dataset.na.drop(missing_columns), missing_columns)
case (true, VectorAssembler.KEEP_INVALID) => throw new RuntimeException(
s"""Can not infer column lengths for 'keep invalid' mode. Consider using
ping
Test build #88829 has finished for PR 20829 at commit
Test build #88834 has finished for PR 20829 at commit
Test build #88835 has finished for PR 20829 at commit
LGTM
## What changes were proposed in this pull request?

Introduce a `handleInvalid` parameter in `VectorAssembler` that can take in `"keep", "skip", "error"` options. "error" throws an error on seeing a row containing a `null`, "skip" filters out all such rows, and "keep" adds the relevant number of NaNs. "keep" uses an example row to figure out how many NaNs should be added and throws an error when no such row can be found.

## How was this patch tested?

Unit tests are added to check the behavior of `assemble` on specific rows, and the transformer is called on `DataFrame`s of different configurations to test different corner cases.

Author: Yogesh Garg <yogesh(dot)garg()databricks(dot)com>
Author: Bago Amirbekian <[email protected]>
Author: Yogesh Garg <[email protected]>

Closes apache#20829 from yogeshg/rformula_handleinvalid.
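(A minimal end-to-end example of the merged behavior, with invented data:)

```scala
import org.apache.spark.ml.feature.VectorAssembler
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local").appName("handle-invalid-demo").getOrCreate()
import spark.implicits._

val df = Seq((Some(1.0), Some(2.0)), (None, Some(3.0))).toDF("x", "y")

val assembler = new VectorAssembler()
  .setInputCols(Array("x", "y"))
  .setOutputCol("features")

assembler.setHandleInvalid("skip").transform(df).count()  // 1: the null row is filtered out
assembler.setHandleInvalid("keep").transform(df).show()   // the null becomes NaN in "features"
// The default, "error", throws a SparkException on the null row.
```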