From ecf1e22a19cc702b561949fa4e9eb7be997cec05 Mon Sep 17 00:00:00 2001 From: Bobby Wang Date: Mon, 24 Jun 2024 15:11:54 +0800 Subject: [PATCH] Support feature names and feature types --- .../scala/spark/XGBoostEstimator.scala | 9 ++++++- .../scala/spark/params/XGBoostParams.scala | 25 +++++++++++++++++-- .../scala/spark/XGBoostEstimatorSuite.scala | 23 +++++++++++++++++ 3 files changed, 54 insertions(+), 3 deletions(-) diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostEstimator.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostEstimator.scala index 915282a3222a..d7e77eac1cff 100644 --- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostEstimator.scala +++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostEstimator.scala @@ -260,11 +260,15 @@ private[spark] abstract class XGBoostEstimator[ private[spark] def toRdd(dataset: Dataset[_], columnIndices: ColumnIndices): RDD[Watches] = { val trainRDD = toXGBLabeledPoint(dataset, columnIndices) + val x: Array[String] = Array.empty + val featureNames = if (getFeatureNames.isEmpty) None else Some(getFeatureNames) + val featureTypes = if (getFeatureTypes.isEmpty) None else Some(getFeatureTypes) + // transform the labeledpoint to get margins and build DMatrix // TODO support basemargin for multiclassification // TODO, move it into JNI def buildDMatrix(iter: Iterator[XGBLabeledPoint]) = { - if (columnIndices.marginId.isDefined) { + val dmatrix = if (columnIndices.marginId.isDefined) { val trainMargins = new mutable.ArrayBuilder.ofFloat val transformedIter = iter.map { labeledPoint => trainMargins += labeledPoint.baseMargin @@ -276,6 +280,9 @@ private[spark] abstract class XGBoostEstimator[ } else { new DMatrix(iter) } + featureTypes.foreach(dmatrix.setFeatureTypes) + featureNames.foreach(dmatrix.setFeatureNames) + dmatrix } getEvalDataset().map { eval => diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/XGBoostParams.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/XGBoostParams.scala index a795f749d43b..155e12f56461 100644 --- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/XGBoostParams.scala +++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/XGBoostParams.scala @@ -164,13 +164,30 @@ private[spark] trait SparkParams[T <: Params] extends HasFeaturesCols with HasFe final def getCustomEval: EvalTrait = $(customEval) + /** Feature's name, it will be set to DMatrix and Booster, and in the final native json model. + * In native code, the parameter name is feature_name. + * */ + final val featureNames = new StringArrayParam(this, "feature_names", + "an array of feature names") + + final def getFeatureNames: Array[String] = $(featureNames) + + /** Feature types, q is numeric and c is categorical. + * In native code, the parameter name is feature_type + * */ + final val featureTypes = new StringArrayParam(this, "feature_types", + "an array of feature types") + + final def getFeatureTypes: Array[String] = $(featureTypes) + setDefault(numRound -> 100, numWorkers -> 1, inferBatchSize -> (32 << 10), numEarlyStoppingRounds -> 0, forceRepartition -> false, missing -> Float.NaN, - featuresCols -> Array.empty, customObj -> null, customEval -> null) + featuresCols -> Array.empty, customObj -> null, customEval -> null, + featureNames -> Array.empty, featureTypes -> Array.empty) addNonXGBoostParam(numWorkers, numRound, numEarlyStoppingRounds, inferBatchSize, featuresCol, labelCol, baseMarginCol, weightCol, predictionCol, leafPredictionCol, contribPredictionCol, - forceRepartition, missing, featuresCols, customEval, customObj) + forceRepartition, missing, featuresCols, customEval, customObj, featureTypes, featureNames) final def getNumWorkers: Int = $(numWorkers) @@ -209,6 +226,10 @@ private[spark] trait SparkParams[T <: Params] extends HasFeaturesCols with HasFe def setRabitTrackerHostIp(value: String): T = set(rabitTrackerHostIp, value).asInstanceOf[T] def setRabitTrackerPort(value: Int): T = set(rabitTrackerPort, value).asInstanceOf[T] + + def setFeatureNames(value: Array[String]): T = set(featureNames, value).asInstanceOf[T] + + def setFeatureTypes(value: Array[String]): T = set(featureTypes, value).asInstanceOf[T] } private[spark] trait SchemaValidationTrait { diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostEstimatorSuite.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostEstimatorSuite.scala index 4091271d14c8..606a2ef1ffe5 100644 --- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostEstimatorSuite.scala +++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostEstimatorSuite.scala @@ -22,6 +22,8 @@ import java.util.Arrays import scala.collection.mutable.ArrayBuffer import org.apache.spark.ml.linalg.Vectors +import org.json4s.jackson.parseJson +import org.json4s.{DefaultFormats, Formats} import org.scalatest.funsuite.AnyFunSuite import ml.dmlc.xgboost4j.scala.DMatrix @@ -460,4 +462,25 @@ class XGBoostEstimatorSuite extends AnyFunSuite with PerTest with TmpFolderPerSu } } + test("native json model file should store feature_name and feature_type") { + val featureNames = (1 to 33).map(idx => s"feature_${idx}").toArray + val featureTypes = (1 to 33).map(idx => "q").toArray + val trainingDF = buildDataFrame(MultiClassification.train) + val xgb = new XGBoostClassifier() + .setNumWorkers(numWorkers) + .setFeatureNames(featureNames) + .setFeatureTypes(featureTypes) + .setNumRound(2) + val model = xgb.fit(trainingDF) + val modelStr = new String(model.nativeBooster.toByteArray("json")) + val jsonModel = parseJson(modelStr) + implicit val formats: Formats = DefaultFormats + val featureNamesInModel = (jsonModel \ "learner" \ "feature_names").extract[List[String]] + val featureTypesInModel = (jsonModel \ "learner" \ "feature_types").extract[List[String]] + assert(featureNamesInModel.length == 33) + assert(featureTypesInModel.length == 33) + assert(featureNames sameElements featureNamesInModel) + assert(featureTypes sameElements featureTypesInModel) + } + }