From e1160cfceb249db8071181620871a25f7a910a91 Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Thu, 7 May 2015 15:45:11 -0700 Subject: [PATCH] fix tests --- .../spark/ml/classification/DecisionTreeClassifier.scala | 1 + .../org/apache/spark/ml/classification/GBTClassifier.scala | 1 + .../org/apache/spark/ml/regression/DecisionTreeRegressor.scala | 3 ++- .../scala/org/apache/spark/ml/regression/GBTRegressor.scala | 1 + .../src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala | 2 +- 5 files changed, 6 insertions(+), 2 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala index bc3ba7851aded..d34ffd3990c90 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala @@ -140,6 +140,7 @@ private[ml] object DecisionTreeClassificationModel { s"Cannot convert non-classification DecisionTreeModel (old API) to" + s" DecisionTreeClassificationModel (new API). Algo is: ${oldModel.algo}") val rootNode = Node.fromOld(oldModel.topNode, categoricalFeatures) + val uid = if (parent != null) parent.uid else Identifiable.randomUID("dtc") new DecisionTreeClassificationModel(parent.uid, rootNode) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala index 1cecada8c21f2..441f1bddc9192 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala @@ -212,6 +212,7 @@ private[ml] object GBTClassificationModel { // parent, fittingParamMap for each tree is null since there are no good ways to set these. DecisionTreeRegressionModel.fromOld(tree, null, categoricalFeatures) } + val uid = if (parent != null) parent.uid else Identifiable.randomUID("gbtc") new GBTClassificationModel(parent.uid, newTrees, oldModel.treeWeights) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala index d71a56cf702ae..e67df21b2e4ae 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala @@ -130,6 +130,7 @@ private[ml] object DecisionTreeRegressionModel { s"Cannot convert non-regression DecisionTreeModel (old API) to" + s" DecisionTreeRegressionModel (new API). Algo is: ${oldModel.algo}") val rootNode = Node.fromOld(oldModel.topNode, categoricalFeatures) - new DecisionTreeRegressionModel(parent.uid, rootNode) + val uid = if (parent != null) parent.uid else Identifiable.randomUID("dtr") + new DecisionTreeRegressionModel(uid, rootNode) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala index 11717c412dfc7..050ae0f6bc10a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala @@ -201,6 +201,7 @@ private[ml] object GBTRegressionModel { // parent, fittingParamMap for each tree is null since there are no good ways to set these. DecisionTreeRegressionModel.fromOld(tree, null, categoricalFeatures) } + val uid = if (parent != null) parent.uid else Identifiable.randomUID("gbtr") new GBTRegressionModel(parent.uid, newTrees, oldModel.treeWeights) } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala index 6056e7d3f6ff8..e30eab9e005f7 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala @@ -27,7 +27,7 @@ class ParamsSuite extends FunSuite { assert(maxIter.name === "maxIter") assert(maxIter.doc === "max number of iterations (>= 0)") - assert(maxIter.parent.eq(solver)) + assert(maxIter.parent === solver.uid) assert(maxIter.toString === "maxIter: max number of iterations (>= 0) (default: 10)") assert(!maxIter.isValid(-1)) assert(maxIter.isValid(0))