-
Notifications
You must be signed in to change notification settings - Fork 28.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[SPARK-30144][ML][PySpark] Make MultilayerPerceptronClassificationModel extend MultilayerPerceptronParams #26838
Changes from all commits
fc2cc5a
09bca1e
14ce378
7590bf8
f98de6b
2844d79
6be731d
7a98ffb
fdfeb6b
94b51a7
1833754
07267ff
fa1797e
40fc5da
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -30,6 +30,7 @@ import org.apache.spark.ml.param.shared._ | |
import org.apache.spark.ml.util._ | ||
import org.apache.spark.ml.util.Instrumentation.instrumented | ||
import org.apache.spark.sql.{Dataset, Row} | ||
import org.apache.spark.util.VersionUtils.majorMinorVersion | ||
|
||
/** Params for Multilayer Perceptron. */ | ||
private[classification] trait MultilayerPerceptronParams extends ProbabilisticClassifierParams | ||
|
@@ -247,7 +248,7 @@ class MultilayerPerceptronClassifier @Since("1.5.0") ( | |
} | ||
trainer.setStackSize($(blockSize)) | ||
val mlpModel = trainer.train(data) | ||
new MultilayerPerceptronClassificationModel(uid, myLayers, mlpModel.weights) | ||
new MultilayerPerceptronClassificationModel(uid, mlpModel.weights) | ||
} | ||
} | ||
|
||
|
@@ -273,31 +274,22 @@ object MultilayerPerceptronClassifier | |
* Each layer has sigmoid activation function, output layer has softmax. | ||
* | ||
* @param uid uid | ||
* @param layers array of layer sizes including input and output layers | ||
* @param weights the weights of layers | ||
*/ | ||
@Since("1.5.0") | ||
class MultilayerPerceptronClassificationModel private[ml] ( | ||
@Since("1.5.0") override val uid: String, | ||
@Since("1.5.0") val layers: Array[Int], | ||
@Since("2.0.0") val weights: Vector) | ||
extends ProbabilisticClassificationModel[Vector, MultilayerPerceptronClassificationModel] | ||
with Serializable with MLWritable { | ||
with MultilayerPerceptronParams with Serializable with MLWritable { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Not related to this change, but do we use MultilayerPerceptronClassificationModel in executors? Not every classification model extends Serializable. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I am not sure about this. It seems only the tree-related models extend Serializable. |
||
|
||
@Since("1.6.0") | ||
override val numFeatures: Int = layers.head | ||
override lazy val numFeatures: Int = $(layers).head | ||
|
||
private[ml] val mlpModel = FeedForwardTopology | ||
.multiLayerPerceptron(layers, softmaxOnTop = true) | ||
@transient private[ml] lazy val mlpModel = FeedForwardTopology | ||
.multiLayerPerceptron($(layers), softmaxOnTop = true) | ||
.model(weights) | ||
|
||
/** | ||
* Returns layers in a Java List. | ||
*/ | ||
private[ml] def javaLayers: java.util.List[Int] = { | ||
layers.toList.asJava | ||
} | ||
|
||
/** | ||
* Predict label for the given features. | ||
* This internal method is used to implement `transform()` and output [[predictionCol]]. | ||
|
@@ -308,7 +300,8 @@ class MultilayerPerceptronClassificationModel private[ml] ( | |
|
||
@Since("1.5.0") | ||
override def copy(extra: ParamMap): MultilayerPerceptronClassificationModel = { | ||
val copied = new MultilayerPerceptronClassificationModel(uid, layers, weights).setParent(parent) | ||
val copied = new MultilayerPerceptronClassificationModel(uid, weights) | ||
.setParent(parent) | ||
copyValues(copied, extra) | ||
} | ||
|
||
|
@@ -322,11 +315,11 @@ class MultilayerPerceptronClassificationModel private[ml] ( | |
|
||
override protected def predictRaw(features: Vector): Vector = mlpModel.predictRaw(features) | ||
|
||
override def numClasses: Int = layers.last | ||
override def numClasses: Int = $(layers).last | ||
|
||
@Since("3.0.0") | ||
override def toString: String = { | ||
s"MultilayerPerceptronClassificationModel: uid=$uid, numLayers=${layers.length}, " + | ||
s"MultilayerPerceptronClassificationModel: uid=$uid, numLayers=${$(layers).length}, " + | ||
s"numClasses=$numClasses, numFeatures=$numFeatures" | ||
} | ||
} | ||
|
@@ -347,13 +340,13 @@ object MultilayerPerceptronClassificationModel | |
class MultilayerPerceptronClassificationModelWriter( | ||
instance: MultilayerPerceptronClassificationModel) extends MLWriter { | ||
|
||
private case class Data(layers: Array[Int], weights: Vector) | ||
private case class Data(weights: Vector) | ||
|
||
override protected def saveImpl(path: String): Unit = { | ||
// Save metadata and Params | ||
DefaultParamsWriter.saveMetadata(instance, path, sc) | ||
// Save model data: layers, weights | ||
val data = Data(instance.layers, instance.weights) | ||
// Save model data: weights | ||
val data = Data(instance.weights) | ||
val dataPath = new Path(path, "data").toString | ||
sparkSession.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) | ||
} | ||
|
@@ -367,13 +360,21 @@ object MultilayerPerceptronClassificationModel | |
|
||
override def load(path: String): MultilayerPerceptronClassificationModel = { | ||
val metadata = DefaultParamsReader.loadMetadata(path, sc, className) | ||
val (majorVersion, _) = majorMinorVersion(metadata.sparkVersion) | ||
|
||
val dataPath = new Path(path, "data").toString | ||
val data = sparkSession.read.parquet(dataPath).select("layers", "weights").head() | ||
val layers = data.getAs[Seq[Int]](0).toArray | ||
val weights = data.getAs[Vector](1) | ||
val model = new MultilayerPerceptronClassificationModel(metadata.uid, layers, weights) | ||
|
||
val df = sparkSession.read.parquet(dataPath) | ||
val model = if (majorVersion < 3) { // model prior to 3.0.0 | ||
val data = df.select("layers", "weights").head() | ||
val layers = data.getAs[Seq[Int]](0).toArray | ||
val weights = data.getAs[Vector](1) | ||
val model = new MultilayerPerceptronClassificationModel(metadata.uid, weights) | ||
model.set("layers", layers) | ||
} else { | ||
val data = df.select("weights").head() | ||
val weights = data.getAs[Vector](0) | ||
new MultilayerPerceptronClassificationModel(metadata.uid, weights) | ||
srowen marked this conversation as resolved.
Show resolved
Hide resolved
|
||
} | ||
metadata.getAndSetParams(model) | ||
model | ||
} | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
{"class":"org.apache.spark.ml.feature.HashingTF","timestamp":1577833408759,"sparkVersion":"2.4.4","uid":"hashingTF_f4565fe7f7da","paramMap":{"numFeatures":100,"outputCol":"features","inputCol":"words","binary":true},"defaultParamMap":{"numFeatures":262144,"outputCol":"hashingTF_f4565fe7f7da__output","binary":false}} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
{"class":"org.apache.spark.ml.classification.MultilayerPerceptronClassificationModel","timestamp":1577833765310,"sparkVersion":"2.4.4","uid":"mlpc_30aa2f44dacc","paramMap":{},"defaultParamMap":{"rawPredictionCol":"rawPrediction","predictionCol":"prediction","probabilityCol":"probability","labelCol":"label","featuresCol":"features"}} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
{"class":"org.apache.spark.ml.feature.StringIndexerModel","timestamp":1577831053235,"sparkVersion":"2.4.4","uid":"myStringIndexerModel","paramMap":{"inputCol":"myInputCol","outputCol":"myOutputCol","handleInvalid":"skip"},"defaultParamMap":{"outputCol":"myStringIndexerModel__output","handleInvalid":"error"}} |
This file was deleted.
This file was deleted.
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -229,4 +229,17 @@ class MultilayerPerceptronClassifierSuite extends MLTest with DefaultReadWriteTe | |
assert(expected.weights === actual.weights) | ||
} | ||
} | ||
|
||
test("Load MultilayerPerceptronClassificationModel prior to Spark 3.0") { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Just a general question: we also have other algorithms that modified the load/save method (like NaiveBayes); do we need to add test suites for them like this? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I am ok either way. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. so add similar test in There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Maybe needed, but in other PRs. |
||
val mlpPath = testFile("ml-models/mlp-2.4.4") | ||
val model = MultilayerPerceptronClassificationModel.load(mlpPath) | ||
val layers = model.getLayers | ||
assert(layers(0) === 4) | ||
assert(layers(1) === 5) | ||
assert(layers(2) === 2) | ||
|
||
val metadata = spark.read.json(s"$mlpPath/metadata") | ||
val sparkVersionStr = metadata.select("sparkVersion").first().getString(0) | ||
assert(sparkVersionStr == "2.4.4") | ||
} | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -459,13 +459,13 @@ class StringIndexerSuite extends MLTest with DefaultReadWriteTest { | |
} | ||
|
||
test("Load StringIndexderModel prior to Spark 3.0") { | ||
val modelPath = testFile("test-data/strIndexerModel") | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. strIndexerModel-2.4.4? |
||
val modelPath = testFile("ml-models/strIndexerModel-2.4.4") | ||
|
||
val loadedModel = StringIndexerModel.load(modelPath) | ||
assert(loadedModel.labelsArray === Array(Array("b", "c", "a"))) | ||
|
||
val metadata = spark.read.json(s"$modelPath/metadata") | ||
val sparkVersionStr = metadata.select("sparkVersion").first().getString(0) | ||
assert(sparkVersionStr == "2.4.1-SNAPSHOT") | ||
assert(sparkVersionStr == "2.4.4") | ||
} | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -328,6 +328,10 @@ object MimaExcludes { | |
// [SPARK-26457] Show hadoop configurations in HistoryServer environment tab | ||
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.status.api.v1.ApplicationEnvironmentInfo.this"), | ||
|
||
// [SPARK-30144][ML] Make MultilayerPerceptronClassificationModel extend MultilayerPerceptronParams | ||
ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.classification.MultilayerPerceptronClassificationModel.layers"), | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Just a question. Is it worth breaking the API for this, @huaxingao? |
||
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.MultilayerPerceptronClassificationModel.this"), | ||
|
||
// Data Source V2 API changes | ||
(problem: Problem) => problem match { | ||
case MissingClassProblem(cls) => | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Shall we update the migration guide?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@srowen Sean, this question is for you.