Support feature names and feature types #16

Merged
merged 1 commit on Jun 24, 2024
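
With this change, feature names and feature types can be set directly on the estimator and are forwarded to the training DMatrix and the trained Booster. The sketch below is a minimal, hedged usage example: the column names, the toy data, and the local SparkSession setup are assumptions for illustration; the setters (setFeatureNames / setFeatureTypes), the "q"/"c" type codes, and the nativeBooster.toByteArray("json") call come from this PR and its test.

    import org.apache.spark.ml.feature.VectorAssembler
    import org.apache.spark.sql.SparkSession

    import ml.dmlc.xgboost4j.scala.spark.XGBoostClassifier

    object FeatureInfoUsage {
      def main(args: Array[String]): Unit = {
        val spark = SparkSession.builder().master("local[1]").appName("feature-info").getOrCreate()
        import spark.implicits._

        // Toy data: two numeric features and one integer-encoded categorical feature.
        val raw = Seq(
          (25.0, 3000.0, 0.0, 0.0),
          (40.0, 5200.0, 1.0, 1.0),
          (31.0, 4100.0, 2.0, 0.0)
        ).toDF("age", "income", "city", "label")

        val df = new VectorAssembler()
          .setInputCols(Array("age", "income", "city"))
          .setOutputCol("features")
          .transform(raw)

        val xgb = new XGBoostClassifier()
          .setFeatureNames(Array("age", "income", "city")) // one entry per assembled feature
          .setFeatureTypes(Array("q", "q", "c"))           // "q" = numeric, "c" = categorical
          .setNumRound(5)

        val model = xgb.fit(df)
        // The names and types are attached to the DMatrix and stored in the native
        // JSON model under learner.feature_names / learner.feature_types.
        println(new String(model.nativeBooster.toByteArray("json")).take(300))

        spark.stop()
      }
    }
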
@@ -260,11 +260,15 @@ private[spark] abstract class XGBoostEstimator[
   private[spark] def toRdd(dataset: Dataset[_], columnIndices: ColumnIndices): RDD[Watches] = {
     val trainRDD = toXGBLabeledPoint(dataset, columnIndices)
 
+    val x: Array[String] = Array.empty
+    val featureNames = if (getFeatureNames.isEmpty) None else Some(getFeatureNames)
+    val featureTypes = if (getFeatureTypes.isEmpty) None else Some(getFeatureTypes)
+
     // transform the labeledpoint to get margins and build DMatrix
     // TODO support basemargin for multiclassification
     // TODO, move it into JNI
     def buildDMatrix(iter: Iterator[XGBLabeledPoint]) = {
-      if (columnIndices.marginId.isDefined) {
+      val dmatrix = if (columnIndices.marginId.isDefined) {
         val trainMargins = new mutable.ArrayBuilder.ofFloat
         val transformedIter = iter.map { labeledPoint =>
           trainMargins += labeledPoint.baseMargin
@@ -276,6 +280,9 @@
       } else {
         new DMatrix(iter)
       }
+      featureTypes.foreach(dmatrix.setFeatureTypes)
+      featureNames.foreach(dmatrix.setFeatureNames)
+      dmatrix
     }
 
     getEvalDataset().map { eval =>
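
The attach-only-if-configured pattern used in buildDMatrix above can also be exercised against a standalone DMatrix. A minimal sketch, assuming xgboost4j's dense Array[Float] DMatrix constructor (nrow, ncol, missing); the data, labels, and feature names here are made up:

    import ml.dmlc.xgboost4j.scala.DMatrix

    object DMatrixFeatureInfoSketch {
      def main(args: Array[String]): Unit = {
        // Two rows, three columns, row-major dense layout; NaN marks missing values.
        val data = Array(1.0f, 0.0f, 2.0f, 0.5f, 1.0f, 3.0f)
        val dmatrix = new DMatrix(data, 2, 3, Float.NaN)
        dmatrix.setLabel(Array(0.0f, 1.0f))

        // Empty arrays mean "not configured"; only non-empty values are forwarded,
        // mirroring the Option handling in toRdd/buildDMatrix.
        val names = Array("f0", "f1", "f2")
        val types = Array("q", "q", "q")
        val featureNames = if (names.isEmpty) None else Some(names)
        val featureTypes = if (types.isEmpty) None else Some(types)
        featureNames.foreach(dmatrix.setFeatureNames)
        featureTypes.foreach(dmatrix.setFeatureTypes)
      }
    }
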
@@ -164,13 +164,30 @@ private[spark] trait SparkParams[T <: Params] extends HasFeaturesCols with HasFe
 
   final def getCustomEval: EvalTrait = $(customEval)
 
+  /** Feature names. They will be set on the DMatrix and the Booster, and stored in the
+   * final native JSON model. In native code, the parameter name is feature_name.
+   */
+  final val featureNames = new StringArrayParam(this, "feature_names",
+    "an array of feature names")
+
+  final def getFeatureNames: Array[String] = $(featureNames)
+
+  /** Feature types: "q" marks a numeric feature and "c" marks a categorical feature.
+   * In native code, the parameter name is feature_type.
+   */
+  final val featureTypes = new StringArrayParam(this, "feature_types",
+    "an array of feature types")
+
+  final def getFeatureTypes: Array[String] = $(featureTypes)
+
   setDefault(numRound -> 100, numWorkers -> 1, inferBatchSize -> (32 << 10),
     numEarlyStoppingRounds -> 0, forceRepartition -> false, missing -> Float.NaN,
-    featuresCols -> Array.empty, customObj -> null, customEval -> null)
+    featuresCols -> Array.empty, customObj -> null, customEval -> null,
+    featureNames -> Array.empty, featureTypes -> Array.empty)
 
   addNonXGBoostParam(numWorkers, numRound, numEarlyStoppingRounds, inferBatchSize, featuresCol,
     labelCol, baseMarginCol, weightCol, predictionCol, leafPredictionCol, contribPredictionCol,
-    forceRepartition, missing, featuresCols, customEval, customObj)
+    forceRepartition, missing, featuresCols, customEval, customObj, featureTypes, featureNames)
 
   final def getNumWorkers: Int = $(numWorkers)
 
@@ -209,6 +226,10 @@ private[spark] trait SparkParams[T <: Params] extends HasFeaturesCols with HasFe
   def setRabitTrackerHostIp(value: String): T = set(rabitTrackerHostIp, value).asInstanceOf[T]
 
   def setRabitTrackerPort(value: Int): T = set(rabitTrackerPort, value).asInstanceOf[T]
+
+  def setFeatureNames(value: Array[String]): T = set(featureNames, value).asInstanceOf[T]
+
+  def setFeatureTypes(value: Array[String]): T = set(featureTypes, value).asInstanceOf[T]
 }
 
 private[spark] trait SchemaValidationTrait {
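
Per the parameter docs above, the type string is per-feature: "q" for numeric, "c" for categorical. A small sketch of deriving both arrays from a hypothetical column list (the column names and the categorical set are assumptions), to be passed to setFeatureNames / setFeatureTypes as in the usage sketch near the top of this page:

    // Hypothetical feature columns; which of them are categorical is an assumption.
    val featureCols = Seq("age", "income", "city", "device")
    val categoricalCols = Set("city", "device")

    val featureNames = featureCols.toArray
    val featureTypes = featureCols.map(c => if (categoricalCols.contains(c)) "c" else "q").toArray
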
@@ -22,6 +22,8 @@ import java.util.Arrays
 import scala.collection.mutable.ArrayBuffer
 
 import org.apache.spark.ml.linalg.Vectors
+import org.json4s.jackson.parseJson
+import org.json4s.{DefaultFormats, Formats}
 import org.scalatest.funsuite.AnyFunSuite
 
 import ml.dmlc.xgboost4j.scala.DMatrix
@@ -460,4 +462,25 @@ class XGBoostEstimatorSuite extends AnyFunSuite with PerTest with TmpFolderPerSu
     }
   }
 
+  test("native json model file should store feature_name and feature_type") {
+    val featureNames = (1 to 33).map(idx => s"feature_${idx}").toArray
+    val featureTypes = (1 to 33).map(idx => "q").toArray
+    val trainingDF = buildDataFrame(MultiClassification.train)
+    val xgb = new XGBoostClassifier()
+      .setNumWorkers(numWorkers)
+      .setFeatureNames(featureNames)
+      .setFeatureTypes(featureTypes)
+      .setNumRound(2)
+    val model = xgb.fit(trainingDF)
+    val modelStr = new String(model.nativeBooster.toByteArray("json"))
+    val jsonModel = parseJson(modelStr)
+    implicit val formats: Formats = DefaultFormats
+    val featureNamesInModel = (jsonModel \ "learner" \ "feature_names").extract[List[String]]
+    val featureTypesInModel = (jsonModel \ "learner" \ "feature_types").extract[List[String]]
+    assert(featureNamesInModel.length == 33)
+    assert(featureTypesInModel.length == 33)
+    assert(featureNames sameElements featureNamesInModel)
+    assert(featureTypes sameElements featureTypesInModel)
+  }
+
 }