Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
wbo4958 committed Jun 13, 2024
1 parent 3a503a1 commit 71b0162
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -142,9 +142,11 @@ object XGBoostClassificationModel extends MLReadable[XGBoostClassificationModel]

private class ModelReader extends XGBoostModelReader[XGBoostClassificationModel] {
override def load(path: String): XGBoostClassificationModel = {
val model = loadBooster(path)
val xgbModel = loadBooster(path)
val meta = SparkUtils.loadMetadata(path, sc)
new XGBoostClassificationModel(meta.uid, model)
val model = new XGBoostClassificationModel(meta.uid, xgbModel)
meta.getAndSetParams(model)
model
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -457,10 +457,10 @@ private[spark] class XGBoostModelWriter[M <: XGBoostModel[M]](instance: M) exten

// Save model data
val dataPath = new Path(path, "data").toString
val internalPath = new Path(dataPath, "model." + JBooster.DEFAULT_FORMAT)
val internalPath = new Path(dataPath, "model")
val outputStream = internalPath.getFileSystem(sc.hadoopConfiguration).create(internalPath)
try {
instance.nativeBooster.saveModel(outputStream, JBooster.DEFAULT_FORMAT)
instance.nativeBooster.saveModel(outputStream)
} finally {
outputStream.close()
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
package ml.dmlc.xgboost4j.scala.spark

import org.apache.spark.ml.feature.VectorAssembler
import org.apache.spark.sql.functions.{array, col, lit, rand}
import org.apache.spark.sql.functions.lit
import org.scalatest.funsuite.AnyFunSuite

class NewXGBoostClassifierSuite extends AnyFunSuite with PerTest with TmpFolderPerSuite {
Expand Down Expand Up @@ -48,7 +48,7 @@ class NewXGBoostClassifierSuite extends AnyFunSuite with PerTest with TmpFolderP
val est = new XGBoostClassifier()
.setNumWorkers(1)
.setNumRound(2)
.setMaxDepth(2)
.setMaxDepth(3)
// .setWeightCol("weight")
// .setBaseMarginCol("base_margin")
.setLabelCol(labelCol)
Expand All @@ -62,8 +62,18 @@ class NewXGBoostClassifierSuite extends AnyFunSuite with PerTest with TmpFolderP
// est.fit(arrayInput)
est.write.overwrite().save("/tmp/abcdef")
val loadedEst = XGBoostClassifier.load("/tmp/abcdef")
println(loadedEst.getNumRound)
println(loadedEst.getMaxDepth)

val model = loadedEst.fit(dataset)
println("-----------------------")
println(model.getNumRound)
println(model.getMaxDepth)

model.write.overwrite().save("/tmp/model/")
val loadedModel = XGBoostClassificationModel.load("/tmp/model")
println(loadedModel.getNumRound)
println(loadedModel.getMaxDepth)
model.transform(dataset).drop(features: _*).show(150, false)
}

Expand Down

0 comments on commit 71b0162

Please sign in to comment.