Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
wbo4958 committed Jun 14, 2024
1 parent 6f696ba commit ebcb98c
Showing 1 changed file with 11 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,14 @@ package ml.dmlc.xgboost4j.scala.spark
import scala.collection.mutable

import org.apache.spark.ml.linalg.{Vector, Vectors}
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.util.{DefaultParamsReadable, Identifiable, MLReadable, MLReader, SchemaUtils}
import org.apache.spark.ml.xgboost.SparkUtils
import org.apache.spark.sql.Dataset
import org.apache.spark.sql.functions.{col, udf}

import ml.dmlc.xgboost4j.scala.Booster
import ml.dmlc.xgboost4j.scala.spark.params.ClassificationParams
import org.apache.spark.sql.types.{DoubleType, StructType}


class XGBoostClassifier(override val uid: String,
Expand All @@ -42,11 +42,6 @@ class XGBoostClassifier(override val uid: String,

xgboost2SparkParams(xgboostParams)

override def transformSchema(schema: StructType): StructType = {
SparkUtils.appendColumn(schema, $(predictionCol), DoubleType)
// SparkUtils.appendColumn(schema,$(rawPredictionCol), Vec)
}

/**
* Validate the parameters before training, throw exception if possible
*/
Expand Down Expand Up @@ -94,14 +89,17 @@ object XGBoostClassifier extends DefaultParamsReadable[XGBoostClassifier] {
override def load(path: String): XGBoostClassifier = super.load(path)
}

// TODO add num classes
class XGBoostClassificationModel(
uid: String,
booster: Booster,
model: Booster,
trainingSummary: Option[XGBoostTrainingSummary] = None
)
extends XGBoostModel[XGBoostClassificationModel](uid, booster, trainingSummary)
extends XGBoostModel[XGBoostClassificationModel](uid, model, trainingSummary)
with ClassificationParams[XGBoostClassificationModel] {

def this(uid: String) = this(uid, null)

// Copied from Spark
private def probability2prediction(probability: Vector): Double = {
if (!isDefined(thresholds)) {
Expand Down Expand Up @@ -142,6 +140,11 @@ class XGBoostClassificationModel(
}
output.drop(TMP_TRANSFORMED_COL)
}

override def copy(extra: ParamMap): XGBoostClassificationModel = {
val newModel = copyValues(new XGBoostClassificationModel(uid, model, trainingSummary), extra)
newModel.setParent(parent)
}
}

object XGBoostClassificationModel extends MLReadable[XGBoostClassificationModel] {
Expand Down

0 comments on commit ebcb98c

Please sign in to comment.