Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
wbo4958 committed Jun 12, 2024
1 parent 5629e17 commit 88bf925
Show file tree
Hide file tree
Showing 5 changed files with 84 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -420,8 +420,17 @@ private[spark] object NewXGBoost extends StageLevelScheduling {
require(tracker.start(), "FAULT: Failed to start tracker")
try {
block(tracker)
} catch {
case t: Throwable =>
logger.error(t)
throw t
} finally {
tracker.stop()
try {
tracker.stop()
} catch {
// Swallow (but log) any exception from stop() so it cannot replace an exception thrown by the block above
case _ => logger.error("Failed to stop tracker ...")
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ package ml.dmlc.xgboost4j.scala.spark

import ml.dmlc.xgboost4j.scala.spark.params.{ClassificationParams, HasGroupCol, SparkParams, XGBoostParams}
import ml.dmlc.xgboost4j.scala.spark.util.DataUtils.MLVectorToXGBLabeledPoint
import ml.dmlc.xgboost4j.scala.spark.util.Utils
import ml.dmlc.xgboost4j.scala.{Booster, DMatrix}
import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}
import org.apache.spark.ml.linalg.Vector
Expand All @@ -37,7 +38,8 @@ import scala.collection.mutable.ArrayBuffer
private case class ColumnIndexes(label: String, features: String,
weight: Option[String] = None,
baseMargin: Option[String] = None,
group: Option[String] = None)
group: Option[String] = None,
valiation: Option[String] = None)

private[spark] abstract class XGBoostEstimator[
Learner <: XGBoostEstimator[Learner, M],
Expand Down Expand Up @@ -90,6 +92,15 @@ private[spark] abstract class XGBoostEstimator[
None
}

// TODO: check the validation indicator column
val validationName = if (isDefined(validationIndicatorCol) &&
getValidationIndicatorCol.nonEmpty) {
selectedCols.append(col(getValidationIndicatorCol))
Some(getValidationIndicatorCol)
} else {
None
}

var groupName: Option[String] = None
this match {
case p: HasGroupCol =>
Expand Down Expand Up @@ -118,7 +129,8 @@ private[spark] abstract class XGBoostEstimator[
getFeaturesCol,
weight = weightName,
baseMargin = baseMarginName,
group = groupName)
group = groupName,
valiation = validationName)
(input, columnIndexes)
}

Expand All @@ -145,12 +157,59 @@ private[spark] abstract class XGBoostEstimator[
// TODO support sparse vector.
// TODO support array
val values = features.toArray.map(_.toFloat)
XGBLabeledPoint(label, values.length, null, values, weight, group.toInt, baseMargin)
val isValidation = columnIndexes.valiation.exists(v =>
row.getBoolean(row.fieldIndex(v)))

(isValidation,
XGBLabeledPoint(label, values.length, null, values, weight, group.toInt, baseMargin))
}


labeledPointRDD.mapPartitions { iter =>
val datasets: ArrayBuffer[DMatrix] = ArrayBuffer.empty
val names: ArrayBuffer[String] = ArrayBuffer.empty
val validations: ArrayBuffer[XGBLabeledPoint] = ArrayBuffer.empty

val trainIter = if (columnIndexes.valiation.isDefined) {
val dataIter = new Iterator[XGBLabeledPoint] {
private var tmp: Option[XGBLabeledPoint] = None

override def hasNext: Boolean = {
if (tmp.isDefined) {
return true
}
while (iter.hasNext) {
val (isVal, labelPoint) = iter.next()
if (isVal) {
validations.append(labelPoint)
} else {
tmp = Some(labelPoint)
return true
}
}
false
}

override def next(): XGBLabeledPoint = {
val xgbLabeledPoint = tmp.get
tmp = None
xgbLabeledPoint
}
}
dataIter
} else {
iter.map(_._2)
}

datasets.append(new DMatrix(trainIter))
names.append(Utils.TRAIN_NAME)
if (columnIndexes.valiation.isDefined) {
datasets.append(new DMatrix(validations.toIterator))
names.append(Utils.VALIDATION_NAME)
}

// TODO: 1. support external memory; 2. rework or remove Watches
val watches = new Watches(Array(new DMatrix(iter)), Array("train"), None)
val watches = new Watches(datasets.toArray, names.toArray, None)
Iterator.single(watches)
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,4 +68,7 @@ object Utils {
FullTypeHints(List(Utils.classForName(className)))
}.getOrElse(NoTypeHints)
}

val TRAIN_NAME = "train"
val VALIDATION_NAME = "eval"
}
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,12 @@ class NewXGBoostClassifierSuite extends AnyFunSuite with PerTest with TmpFolderP
// df = df.withColumn("base_margin", lit(20))
// .withColumn("weight", rand(1))

var Array(trainDf, validationDf) = df.randomSplit(Array(0.8, 0.2), seed = 1)

trainDf = trainDf.withColumn("validation", lit(false))
validationDf = validationDf.withColumn("validationDf", lit(true))

df = trainDf.union(validationDf)

// Assemble the feature columns into a single vector column
val assembler = new VectorAssembler()
Expand All @@ -44,6 +50,7 @@ class NewXGBoostClassifierSuite extends AnyFunSuite with PerTest with TmpFolderP
// .setWeightCol("weight")
// .setBaseMarginCol("base_margin")
.setLabelCol(labelCol)
.setValidationIndicatorCol("validation")
// .setPredictionCol("")
.setRawPredictionCol("")
.setProbabilityCol("xxxx")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ public boolean start() throws XGBoostError {
this.trackerDaemon = new Thread(() -> {
try {
waitFor(0);
} catch (XGBoostError ex) {
} catch (Exception ex) {
logger.error(ex);
return; // exit the thread
}
Expand Down

0 comments on commit 88bf925

Please sign in to comment.