-
Notifications
You must be signed in to change notification settings - Fork 28.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[SPARK-29808][ML][PYTHON] StopWordsRemover should support multi-cols #26480
Changes from 2 commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -22,23 +22,29 @@ import java.util.Locale | |
import org.apache.spark.annotation.Since | ||
import org.apache.spark.ml.Transformer | ||
import org.apache.spark.ml.param._ | ||
import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol} | ||
import org.apache.spark.ml.param.shared.{HasInputCol, HasInputCols, HasOutputCol, HasOutputCols} | ||
import org.apache.spark.ml.util._ | ||
import org.apache.spark.sql.{DataFrame, Dataset} | ||
import org.apache.spark.sql.functions.{col, udf} | ||
import org.apache.spark.sql.types.{ArrayType, StringType, StructType} | ||
import org.apache.spark.sql.types.{ArrayType, StringType, StructField, StructType} | ||
|
||
/** | ||
* A feature transformer that filters out stop words from input. | ||
* | ||
* Since 3.0.0, | ||
* `StopWordsRemover` can filter out multiple columns at once by setting the `inputCols` | ||
* parameter. Note that when both the `inputCol` and `inputCols` parameters are set, an Exception | ||
* will be thrown. | ||
* | ||
* @note null values from input array are preserved unless adding null to stopWords | ||
* explicitly. | ||
* | ||
* @see <a href="http://en.wikipedia.org/wiki/Stop_words">Stop words (Wikipedia)</a> | ||
*/ | ||
@Since("1.5.0") | ||
class StopWordsRemover @Since("1.5.0") (@Since("1.5.0") override val uid: String) | ||
extends Transformer with HasInputCol with HasOutputCol with DefaultParamsWritable { | ||
extends Transformer with HasInputCol with HasOutputCol with HasInputCols with HasOutputCols | ||
with DefaultParamsWritable { | ||
|
||
@Since("1.5.0") | ||
def this() = this(Identifiable.randomUID("stopWords")) | ||
|
@@ -51,6 +57,14 @@ class StopWordsRemover @Since("1.5.0") (@Since("1.5.0") override val uid: String | |
@Since("1.5.0") | ||
def setOutputCol(value: String): this.type = set(outputCol, value) | ||
|
||
/** @group setParam */ | ||
@Since("3.0.0") | ||
def setInputCols(value: Array[String]): this.type = set(inputCols, value) | ||
|
||
/** @group setParam */ | ||
@Since("3.0.0") | ||
def setOutputCols(value: Array[String]): this.type = set(outputCols, value) | ||
|
||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I am debating if I should add |
||
/** | ||
* The words to be filtered out. | ||
* Default: English stop words | ||
|
@@ -121,6 +135,15 @@ class StopWordsRemover @Since("1.5.0") (@Since("1.5.0") override val uid: String | |
} | ||
} | ||
|
||
/** Returns the input and output column names corresponding in pair. */ | ||
private[feature] def getInOutCols(): (Array[String], Array[String]) = { | ||
if (isSet(inputCol)) { | ||
(Array($(inputCol)), Array($(outputCol))) | ||
} else { | ||
($(inputCols), $(outputCols)) | ||
} | ||
} | ||
|
||
setDefault(stopWords -> StopWordsRemover.loadDefaultStopWords("english"), | ||
caseSensitive -> false, locale -> getDefaultOrUS.toString) | ||
|
||
|
@@ -142,16 +165,40 @@ class StopWordsRemover @Since("1.5.0") (@Since("1.5.0") override val uid: String | |
terms.filter(s => !lowerStopWords.contains(toLower(s))) | ||
} | ||
} | ||
val metadata = outputSchema($(outputCol)).metadata | ||
dataset.select(col("*"), t(col($(inputCol))).as($(outputCol), metadata)) | ||
|
||
val (inputColNames, outputColNames) = getInOutCols() | ||
val ouputCols = inputColNames.map { inputColName => | ||
srowen marked this conversation as resolved.
Show resolved
Hide resolved
|
||
t(col(inputColName)) | ||
} | ||
val ouputMetadata = outputColNames.map(outputSchema(_).metadata) | ||
dataset.withColumns(outputColNames, ouputCols, ouputMetadata) | ||
} | ||
|
||
@Since("1.5.0") | ||
override def transformSchema(schema: StructType): StructType = { | ||
val inputType = schema($(inputCol)).dataType | ||
require(inputType.sameType(ArrayType(StringType)), "Input type must be " + | ||
s"${ArrayType(StringType).catalogString} but got ${inputType.catalogString}.") | ||
SchemaUtils.appendColumn(schema, $(outputCol), inputType, schema($(inputCol)).nullable) | ||
ParamValidators.checkSingleVsMultiColumnParams(this, Seq(outputCol), | ||
Seq(outputCols)) | ||
|
||
if (isSet(inputCols)) { | ||
require(getInputCols.length == getOutputCols.length, | ||
s"StopWordsRemover $this has mismatched Params " + | ||
s"for multi-column transform. Params (inputCols, outputCols) should have " + | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Nit: you don't need interpolation on these two lines. |
||
s"equal lengths, but they have different lengths: " + | ||
s"(${getInputCols.length}, ${getOutputCols.length}).") | ||
} | ||
|
||
val (inputColNames, outputColNames) = getInOutCols() | ||
var outputFields = schema.fields | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It will hardly matter unless the number of cols is large, but is it as easy and a little faster to .map the .zip below to the new output fields, and then append them once to schema.fields? |
||
inputColNames.zip(outputColNames).foreach { case (inputColName, outputColName) => | ||
require(!schema.fieldNames.contains(outputColName), | ||
s"Output Column $outputColName already exists.") | ||
val inputType = schema(inputColName).dataType | ||
require(inputType.sameType(ArrayType(StringType)), "Input type must be " + | ||
s"${ArrayType(StringType).catalogString} but got ${inputType.catalogString}.") | ||
val outputField = StructField(outputColName, inputType, schema(inputColName).nullable) | ||
outputFields :+= outputField | ||
} | ||
StructType(outputFields) | ||
} | ||
|
||
@Since("1.5.0") | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -19,6 +19,7 @@ package org.apache.spark.ml.feature | |
|
||
import java.util.Locale | ||
|
||
import org.apache.spark.ml.Pipeline | ||
import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest} | ||
import org.apache.spark.sql.{DataFrame, Row} | ||
|
||
|
@@ -181,12 +182,19 @@ class StopWordsRemoverSuite extends MLTest with DefaultReadWriteTest { | |
} | ||
|
||
test("read/write") { | ||
val t = new StopWordsRemover() | ||
val t1 = new StopWordsRemover() | ||
.setInputCol("myInputCol") | ||
.setOutputCol("myOutputCol") | ||
.setStopWords(Array("the", "a")) | ||
.setCaseSensitive(true) | ||
testDefaultReadWrite(t) | ||
testDefaultReadWrite(t1) | ||
|
||
val t2 = new StopWordsRemover() | ||
.setInputCols(Array("input1", "input2", "input3")) | ||
.setOutputCols(Array("result1", "result2", "result3")) | ||
.setStopWords(Array("the", "a")) | ||
.setCaseSensitive(true) | ||
testDefaultReadWrite(t2) | ||
} | ||
|
||
test("StopWordsRemover output column already exists") { | ||
|
@@ -199,7 +207,7 @@ class StopWordsRemoverSuite extends MLTest with DefaultReadWriteTest { | |
testTransformerByInterceptingException[(Array[String], Array[String])]( | ||
dataSet, | ||
remover, | ||
s"requirement failed: Column $outputCol already exists.", | ||
s"requirement failed: Output Column $outputCol already exists.", | ||
"expected") | ||
} | ||
|
||
|
@@ -217,4 +225,123 @@ class StopWordsRemoverSuite extends MLTest with DefaultReadWriteTest { | |
Locale.setDefault(oldDefault) | ||
} | ||
} | ||
|
||
test("Multiple Columns: StopWordsRemover default") { | ||
val remover = new StopWordsRemover() | ||
.setInputCols(Array("raw1", "raw2")) | ||
.setOutputCols(Array("filtered1", "filtered2")) | ||
val df = Seq( | ||
(Seq("test", "test"), Seq("test1", "test2"), Seq("test", "test"), Seq("test1", "test2")), | ||
(Seq("a", "b", "c", "d"), Seq("a", "b"), Seq("b", "c", "d"), Seq("b")), | ||
(Seq("a", "the", "an"), Seq("the", "an"), Seq(), Seq()), | ||
(Seq("A", "The", "AN"), Seq("A", "The"), Seq(), Seq()), | ||
(Seq(null), Seq(null), Seq(null), Seq(null)), | ||
(Seq(), Seq(), Seq(), Seq()) | ||
).toDF("raw1", "raw2", "expected1", "expected2") | ||
|
||
remover.transform(df) | ||
.select("filtered1", "expected1", "filtered2", "expected2") | ||
.collect().foreach { | ||
case Row(r1: Seq[String], e1: Seq[String], r2: Seq[String], e2: Seq[String]) => | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Small nit: indent this more |
||
assert(r1 === e1, | ||
s"The result value is not correct after bucketing. Expected $e1 but found $r1") | ||
assert(r2 === e2, | ||
s"The result value is not correct after bucketing. Expected $e2 but found $r2") | ||
} | ||
} | ||
|
||
test("Multiple Columns: StopWordsRemover with particular stop words list") { | ||
val stopWords = Array("test", "a", "an", "the") | ||
val remover = new StopWordsRemover() | ||
.setInputCols(Array("raw1", "raw2")) | ||
.setOutputCols(Array("filtered1", "filtered2")) | ||
.setStopWords(stopWords) | ||
val df = Seq( | ||
(Seq("test", "test"), Seq("test1", "test2"), Seq(), Seq("test1", "test2")), | ||
(Seq("a", "b", "c", "d"), Seq("a", "b"), Seq("b", "c", "d"), Seq("b")), | ||
(Seq("a", "the", "an"), Seq("a", "the", "test1"), Seq(), Seq("test1")), | ||
(Seq("A", "The", "AN"), Seq("A", "The", "AN"), Seq(), Seq()), | ||
(Seq(null), Seq(null), Seq(null), Seq(null)), | ||
(Seq(), Seq(), Seq(), Seq()) | ||
).toDF("raw1", "raw2", "expected1", "expected2") | ||
|
||
remover.transform(df) | ||
.select("filtered1", "expected1", "filtered2", "expected2") | ||
.collect().foreach { | ||
case Row(r1: Seq[String], e1: Seq[String], r2: Seq[String], e2: Seq[String]) => | ||
assert(r1 === e1, | ||
s"The result value is not correct after bucketing. Expected $e1 but found $r1") | ||
assert(r2 === e2, | ||
s"The result value is not correct after bucketing. Expected $e2 but found $r2") | ||
} | ||
} | ||
|
||
test("Compare single/multiple column(s) StopWordsRemover in pipeline") { | ||
val df = Seq( | ||
(Seq("test", "test"), Seq("test1", "test2")), | ||
(Seq("a", "b", "c", "d"), Seq("a", "b")), | ||
(Seq("a", "the", "an"), Seq("a", "the", "test1")), | ||
(Seq("A", "The", "AN"), Seq("A", "The", "AN")), | ||
(Seq(null), Seq(null)), | ||
(Seq(), Seq()) | ||
).toDF("input1", "input2") | ||
|
||
val multiColsRemover = new StopWordsRemover() | ||
.setInputCols(Array("input1", "input2")) | ||
.setOutputCols(Array("output1", "output2")) | ||
|
||
val plForMultiCols = new Pipeline() | ||
.setStages(Array(multiColsRemover)) | ||
.fit(df) | ||
|
||
val removerForCol1 = new StopWordsRemover() | ||
.setInputCol("input1") | ||
.setOutputCol("output1") | ||
val removerForCol2 = new StopWordsRemover() | ||
.setInputCol("input2") | ||
.setOutputCol("output2") | ||
|
||
val plForSingleCol = new Pipeline() | ||
.setStages(Array(removerForCol1, removerForCol2)) | ||
.fit(df) | ||
|
||
val resultForSingleCol = plForSingleCol.transform(df) | ||
.select("output1", "output2") | ||
.collect() | ||
val resultForMultiCols = plForMultiCols.transform(df) | ||
.select("output1", "output2") | ||
.collect() | ||
|
||
resultForSingleCol.zip(resultForMultiCols).foreach { | ||
case (rowForSingle, rowForMultiCols) => | ||
assert(rowForSingle === rowForMultiCols) | ||
} | ||
} | ||
|
||
test("Multiple Columns: Mismatched sizes of inputCols/outputCols") { | ||
val remover = new StopWordsRemover() | ||
.setInputCols(Array("input1")) | ||
.setOutputCols(Array("result1", "result2")) | ||
val df = Seq( | ||
(Seq("A"), Seq("A")), | ||
(Seq("The", "the"), Seq("The")) | ||
).toDF("input1", "input2") | ||
intercept[IllegalArgumentException] { | ||
remover.transform(df).count() | ||
} | ||
} | ||
|
||
test("Multiple Columns: Set both of inputCol/inputCols") { | ||
val remover = new StopWordsRemover() | ||
.setInputCols(Array("input1", "input2")) | ||
.setOutputCols(Array("result1", "result2")) | ||
.setInputCol("input1") | ||
val df = Seq( | ||
(Seq("A"), Seq("A")), | ||
(Seq("The", "the"), Seq("The")) | ||
).toDF("input1", "input2") | ||
intercept[IllegalArgumentException] { | ||
remover.transform(df).count() | ||
} | ||
} | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't feel strongly, but you could remove this.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sorry, I accidentally broke the line, but I prefer to keep it. When other features added multi-column support, a "Since x.x.x" note was added to the doc; I am just trying to be consistent with the others.