
[SPARK-23377][ML] Fixes Bucketizer with multiple columns persistence bug #20594

Status: Closed
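For context, a minimal reproduction of the bug described in SPARK-23377 (a sketch; the save path, column names, and data are illustrative):

import org.apache.spark.ml.feature.Bucketizer
import spark.implicits._  // assumes an active SparkSession named `spark`

val bucketizer = new Bucketizer()
  .setInputCols(Array("f1", "f2"))
  .setOutputCols(Array("b1", "b2"))
  .setSplitsArray(Array(Array(0.0, 0.5, 1.0), Array(0.0, 10.0, 20.0)))

bucketizer.write.overwrite().save("/tmp/spark-23377-bucketizer")
val loaded = Bucketizer.load("/tmp/spark-23377-bucketizer")

// Before this patch, the default value of `outputCol` was persisted as a
// user-supplied param, so the loaded instance appeared to set both
// `outputCol` and `outputCols`, and the exclusive-params check failed here.
val df = Seq((0.2, 5.0), (0.7, 15.0)).toDF("f1", "f2")
loaded.transform(df).show()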
24 changes: 24 additions & 0 deletions mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala
@@ -213,6 +213,9 @@ final class Bucketizer @Since("1.4.0") (@Since("1.4.0") override val uid: String
override def copy(extra: ParamMap): Bucketizer = {
defaultCopy[Bucketizer](extra).setParent(parent)
}

@Since("2.3.0")
Review comment (Member):
No need for this @Since annotation; the signature isn't changed in 2.3.0.

override def write: MLWriter = new Bucketizer.BucketizerWriter(this)
}

@Since("1.6.0")
@@ -290,6 +293,27 @@ object Bucketizer extends DefaultParamsReadable[Bucketizer] {
}
}


private[Bucketizer] class BucketizerWriter(instance: Bucketizer) extends MLWriter {

override protected def saveImpl(path: String): Unit = {
// SPARK-23377: The default params will be saved and loaded as user-supplied params.
// Once `inputCols` is set, the default value of the `outputCol` param causes an error
// when checking exclusive params. As a temporary fix, we remove the default
// value of `outputCol` if `inputCols` is set before saving.
// TODO: If we modify the persistence mechanism later to better handle default params,
// we can get rid of this.
var removedOutputCol: Option[String] = None
Review comment (Contributor):
I wonder whether a "lock" is needed here, since the approach is "clear the default value first, then save the model, then restore the default value". Maybe wrapping this code block in synchronized would be safer? (A synchronized variant is sketched after this file's diff below.)

Review comment (Member Author, @viirya, Feb 13, 2018):
I was thinking about this too, but it looks like we don't add locks in the places where we might change params in ML. I guess we assume usage of ML models is single-threaded, so I'm leaving it as is. I will add a lock if others think it is required too.

Review comment (Contributor):
Yep. But I have some new thoughts; see my comments at the bottom. :)

if (instance.isSet(instance.inputCols)) {
Review comment (Contributor, @mgaido91, Feb 13, 2018):
This can create a lot of issues with the Python API; please see #20410 for reference. Thus I am against this fix unless we first fix the problem I linked.

Review comment (Member Author):
Why? I think they are orthogonal, and this shouldn't cause an issue on the Python side. Besides, as PySpark multi-column support has not been added yet (it was reverted), I don't think we hit the Python API issue. This is a quick fix for the persistence bug; I'm not sure we should be blocked on it.

Review comment (Contributor):
Yes, I think #20410 is not related to this PR for now. But I am afraid that in the future, as we add more functionality, potential bugs could be triggered. Still, I don't think we need to care about the order in which the two are merged. :)

removedOutputCol = instance.getDefault(instance.outputCol)
instance.clearDefault(instance.outputCol)
}
DefaultParamsWriter.saveMetadata(instance, path, sc)
// Add the default param back.
removedOutputCol.map(instance.setDefault(instance.outputCol, _))
Review comment (Member Author):
Although the saving logic is the same as QuantileDiscretizerWriter's, I am leaving it duplicated for now since this is a quick fix. If there is a strong preference, I can extract a common class. (A sketch of such a class appears after the QuantileDiscretizer diff below.)

}
}

@Since("1.6.0")
override def load(path: String): Bucketizer = super.load(path)
}
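On the thread-safety question raised in the review thread above: a hypothetical variant of saveImpl that wraps the clear-save-restore sequence in synchronized and restores the default even if saving throws. This is a sketch, not part of the PR:

override protected def saveImpl(path: String): Unit = instance.synchronized {
  var removedOutputCol: Option[String] = None
  if (instance.isSet(instance.inputCols)) {
    removedOutputCol = instance.getDefault(instance.outputCol)
    instance.clearDefault(instance.outputCol)
  }
  try {
    DefaultParamsWriter.saveMetadata(instance, path, sc)
  } finally {
    // Restore the removed default even when saveMetadata fails.
    removedOutputCol.foreach(instance.setDefault(instance.outputCol, _))
  }
}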
mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala
@@ -249,11 +249,35 @@ final class QuantileDiscretizer @Since("1.6.0") (@Since("1.6.0") override val uid: String

@Since("1.6.0")
override def copy(extra: ParamMap): QuantileDiscretizer = defaultCopy(extra)

@Since("2.3.0")
override def write: MLWriter = new QuantileDiscretizer.QuantileDiscretizerWriter(this)
}

@Since("1.6.0")
object QuantileDiscretizer extends DefaultParamsReadable[QuantileDiscretizer] with Logging {

private[QuantileDiscretizer]
class QuantileDiscretizerWriter(instance: QuantileDiscretizer) extends MLWriter {

override protected def saveImpl(path: String): Unit = {
// SPARK-23377: The default params will be saved and loaded as user-supplied params.
// Once `inputCols` is set, the default value of the `outputCol` param causes an error
// when checking exclusive params. As a temporary fix, we remove the default
// value of `outputCol` if `inputCols` is set before saving.
// TODO: If we modify the persistence mechanism later to better handle default params,
// we can get rid of this.
var removedOutputCol: Option[String] = None
if (instance.isSet(instance.inputCols)) {
removedOutputCol = instance.getDefault(instance.outputCol)
instance.clearDefault(instance.outputCol)
}
DefaultParamsWriter.saveMetadata(instance, path, sc)
// Add the default param back.
removedOutputCol.map(instance.setDefault(instance.outputCol, _))
}
}

@Since("1.6.0")
override def load(path: String): QuantileDiscretizer = super.load(path)
}
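On the duplication the author notes above: a sketch of what the common class might look like. The trait and method names here are hypothetical, and it assumes the shared-param traits HasInputCols and HasOutputCol:

import org.apache.spark.ml.param.Params
import org.apache.spark.ml.param.shared.{HasInputCols, HasOutputCol}

// Hypothetical: mixed into Bucketizer and QuantileDiscretizer, so the
// protected `setDefault` stays accessible when restoring the default.
private[ml] trait DropsOutputColDefault extends Params with HasInputCols with HasOutputCol {

  /** Runs `body` with the default of `outputCol` temporarily cleared while
   *  `inputCols` is set (SPARK-23377), restoring it afterwards. */
  protected def withoutOutputColDefault[T](body: => T): T = {
    val removed =
      if (isSet(inputCols)) {
        val d = getDefault(outputCol)
        clearDefault(outputCol)
        d
      } else None
    try body finally removed.foreach(setDefault(outputCol, _))
  }
}

Each writer's saveImpl would then reduce to wrapping the metadata call, e.g. instance.withoutOutputColDefault { DefaultParamsWriter.saveMetadata(instance, path, sc) }.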
9 changes: 9 additions & 0 deletions mllib/src/main/scala/org/apache/spark/ml/param/params.scala
@@ -830,6 +830,15 @@ trait Params extends Identifiable with Serializable {
defaultParamMap.contains(param)
}

/**
* Clears the default value for the input param.
*/
final def clearDefault[T](param: Param[T]): this.type = {
shouldOwn(param)
defaultParamMap.remove(param)
this
}

/**
* Creates a copy of this instance with the same UID and some extra params.
* Subclasses should implement this method and set the return type properly.
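A quick illustration of how the new clearDefault interacts with the existing default-param accessors, using Bucketizer (the default of outputCol comes from HasOutputCol and is uid + "__output"; results are shown in comments):

import org.apache.spark.ml.feature.Bucketizer

val b = new Bucketizer("buck")
b.hasDefault(b.outputCol)   // true
b.getDefault(b.outputCol)   // Some("buck__output")

b.clearDefault(b.outputCol) // removes the default in place
b.hasDefault(b.outputCol)   // false
b.getDefault(b.outputCol)   // None

// Restoring has to happen from inside the class (or its companion), since
// `setDefault` is protected; that is exactly what the writers above do.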
mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala
@@ -172,7 +172,10 @@ class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest
.setInputCol("myInputCol")
.setOutputCol("myOutputCol")
.setSplits(Array(0.1, 0.8, 0.9))
testDefaultReadWrite(t)

val bucketizer = testDefaultReadWrite(t)
val data = Seq((1.0, 2.0), (10.0, 100.0), (101.0, -1.0)).toDF("myInputCol", "myInputCol2")
bucketizer.transform(data)
}

test("Bucket numeric features") {
@@ -327,7 +330,12 @@ class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest
.setInputCols(Array("myInputCol"))
.setOutputCols(Array("myOutputCol"))
.setSplitsArray(Array(Array(0.1, 0.8, 0.9)))
testDefaultReadWrite(t)

val bucketizer = testDefaultReadWrite(t)
val data = Seq((1.0, 2.0), (10.0, 100.0), (101.0, -1.0)).toDF("myInputCol", "myInputCol2")
bucketizer.transform(data)
assert(t.hasDefault(t.outputCol))
assert(bucketizer.hasDefault(bucketizer.outputCol))
}

test("Bucketizer in a pipeline") {
mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala
@@ -27,6 +27,8 @@ import org.apache.spark.sql.functions.udf
class QuantileDiscretizerSuite
extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {

import testImplicits._

test("Test observed number of buckets and their sizes match expected values") {
val spark = this.spark
import spark.implicits._
@@ -132,7 +134,10 @@ class QuantileDiscretizerSuite
.setInputCol("myInputCol")
.setOutputCol("myOutputCol")
.setNumBuckets(6)
testDefaultReadWrite(t)

val readDiscretizer = testDefaultReadWrite(t)
val data = sc.parallelize(1 to 100).map(Tuple1.apply).toDF("myInputCol")
readDiscretizer.fit(data)
}

test("Verify resulting model has parent") {
@@ -379,7 +384,12 @@ class QuantileDiscretizerSuite
.setInputCols(Array("input1", "input2"))
.setOutputCols(Array("result1", "result2"))
.setNumBucketsArray(Array(5, 10))
testDefaultReadWrite(discretizer)

val readDiscretizer = testDefaultReadWrite(discretizer)
val data = Seq((1.0, 2.0), (2.0, 3.0), (3.0, 4.0)).toDF("input1", "input2")
readDiscretizer.fit(data)
assert(discretizer.hasDefault(discretizer.outputCol))
assert(readDiscretizer.hasDefault(readDiscretizer.outputCol))
}

test("Multiple Columns: Both inputCol and inputCols are set") {
5 changes: 4 additions & 1 deletion project/MimaExcludes.scala
@@ -135,7 +135,10 @@ object MimaExcludes {
ProblemFilters.exclude[FinalMethodProblem]("org.apache.spark.ml.feature.Bucketizer.getHandleInvalid"),
ProblemFilters.exclude[FinalMethodProblem]("org.apache.spark.ml.feature.StringIndexer.getHandleInvalid"),
ProblemFilters.exclude[FinalMethodProblem]("org.apache.spark.ml.feature.QuantileDiscretizer.getHandleInvalid"),
ProblemFilters.exclude[FinalMethodProblem]("org.apache.spark.ml.feature.StringIndexerModel.getHandleInvalid")
ProblemFilters.exclude[FinalMethodProblem]("org.apache.spark.ml.feature.StringIndexerModel.getHandleInvalid"),

// [SPARK-23377][ML] Fixes Bucketizer with multiple columns persistence bug
ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.param.Params.clearDefault")
)

// Exclude rules for 2.2.x