Commit

update
wbo4958 committed Jun 14, 2024
1 parent 71b0162 commit 76a9181
Showing 6 changed files with 28 additions and 293 deletions.
@@ -14,13 +14,31 @@
limitations under the License.
*/

-package ml.dmlc.xgboost4j.scala.spark.util
+package ml.dmlc.xgboost4j.scala.spark

+import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}
+import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vector}
import org.json4s.{DefaultFormats, FullTypeHints, JField, JValue, NoTypeHints, TypeHints}

// based on a copy/paste from org.apache.spark.util
object Utils {

+  private[spark] implicit class MLVectorToXGBLabeledPoint(val v: Vector) extends AnyVal {
+    /**
+     * Converts a [[Vector]] to a data point with a dummy label.
+     *
+     * This is needed for constructing a [[ml.dmlc.xgboost4j.scala.DMatrix]]
+     * for prediction.
+     */
+    def asXGB: XGBLabeledPoint = v match {
+      case v: DenseVector =>
+        XGBLabeledPoint(0.0f, v.size, null, v.values.map(_.toFloat))
+      case v: SparseVector =>
+        // Use the sparse values array so it stays aligned with v.indices;
+        // densifying here would yield a values array longer than indices.
+        XGBLabeledPoint(0.0f, v.size, v.indices, v.values.map(_.toFloat))
+    }
+  }

  def getSparkClassLoader: ClassLoader = getClass.getClassLoader

  def getContextOrSparkClassLoader: ClassLoader =
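The relocated MLVectorToXGBLabeledPoint implicit is the piece downstream code now imports from the new Utils location. A minimal sketch of how it is used when assembling a DMatrix for prediction (the vectors and variable names below are illustrative, not part of this commit):

    import ml.dmlc.xgboost4j.scala.DMatrix
    import ml.dmlc.xgboost4j.scala.spark.Utils.MLVectorToXGBLabeledPoint
    import org.apache.spark.ml.linalg.Vectors

    // Dense and sparse MLlib vectors both convert to XGBoost labeled points
    // with a dummy 0.0f label via the implicit's asXGB method.
    val dense = Vectors.dense(1.0, 0.0, 3.0).asXGB
    val sparse = Vectors.sparse(3, Array(0, 2), Array(1.0, 3.0)).asXGB

    // A prediction DMatrix can then be built from an iterator of points.
    val dmatrix = new DMatrix(Iterator(dense, sparse))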
@@ -16,13 +16,12 @@

package ml.dmlc.xgboost4j.scala.spark

-import ml.dmlc.xgboost4j.java.{Communicator, ITracker, RabitTracker, XGBoostError}
+import ml.dmlc.xgboost4j.java.{Communicator, RabitTracker, XGBoostError}
import ml.dmlc.xgboost4j.scala.{XGBoost => SXGBoost, _}
import org.apache.commons.io.FileUtils
import org.apache.commons.logging.LogFactory
import org.apache.spark.rdd.RDD
import org.apache.spark.resource.{ResourceProfileBuilder, TaskResourceRequests}
-import org.apache.spark.sql.SparkSession
import org.apache.spark.{SparkConf, SparkContext, TaskContext}

import java.io.File
@@ -117,14 +116,14 @@ private[spark] trait StageLevelScheduling extends Serializable {
   * on a single executor simultaneously.
   *
   * @param sc the spark context
-  * @param rdd which rdd to be applied with new resource profile
-  * @return the original rdd or the changed rdd
+  * @param rdd the RDD to which the new resource profile is applied
+  * @return the original RDD, or the RDD with the new resource profile applied
   */
-  private[spark] def tryStageLevelScheduling(
+  private[spark] def tryStageLevelScheduling[T](
      sc: SparkContext,
      xgbExecParams: RuntimeParams,
-      rdd: RDD[(Booster, Map[String, Array[Float]])]
-  ): RDD[(Booster, Map[String, Array[Float]])] = {
+      rdd: RDD[T]
+  ): RDD[T] = {

    val conf = sc.getConf
    if (skipStageLevelScheduling(sc.version, xgbExecParams.runOnGpu, conf)) {
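Making tryStageLevelScheduling generic in T lets the same stage-level scheduling hook wrap any RDD, not just the (Booster, metrics) pair produced by training. The ResourceProfileBuilder and TaskResourceRequests imports point at the mechanism underneath; a minimal, hypothetical sketch of that pattern follows (the resource amounts are illustrative, not this method's actual logic):

    import org.apache.spark.rdd.RDD
    import org.apache.spark.resource.{ResourceProfileBuilder, TaskResourceRequests}

    // Ask Spark to run each task of the wrapped stage with 1 CPU and 1 GPU,
    // leaving all other stages on the default resource profile.
    def withGpuProfile[T](rdd: RDD[T]): RDD[T] = {
      val taskReqs = new TaskResourceRequests().cpus(1).resource("gpu", 1)
      val profile = new ResourceProfileBuilder().require(taskReqs).build()
      rdd.withResources(profile)
    }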
@@ -16,11 +16,9 @@

package ml.dmlc.xgboost4j.scala.spark

-import ml.dmlc.xgboost4j.java.{Booster => JBooster}
+import ml.dmlc.xgboost4j.scala.spark.Utils.MLVectorToXGBLabeledPoint
import ml.dmlc.xgboost4j.scala.{Booster, DMatrix, XGBoost => SXGBoost}
import ml.dmlc.xgboost4j.scala.spark.params.{ClassificationParams, HasGroupCol, ParamMapConversion, SparkParams, XGBoostParams}
-import ml.dmlc.xgboost4j.scala.spark.util.DataUtils.MLVectorToXGBLabeledPoint
-import ml.dmlc.xgboost4j.scala.spark.util.Utils
import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}
import org.apache.commons.logging.LogFactory
import org.apache.hadoop.fs.Path
@@ -32,7 +30,7 @@ import org.apache.spark.ml.{Estimator, Model}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql._
import org.apache.spark.sql.functions.col
-import org.apache.spark.sql.types.{ArrayType, FloatType, Metadata, StructField, StructType}
+import org.apache.spark.sql.types.{ArrayType, FloatType, StructField, StructType}

import scala.collection.mutable.ArrayBuffer

@@ -17,8 +17,7 @@
package ml.dmlc.xgboost4j.scala.spark.params

import ml.dmlc.xgboost4j.scala.{EvalTrait, ObjectiveTrait}
-import ml.dmlc.xgboost4j.scala.spark.TrackerConf
-import ml.dmlc.xgboost4j.scala.spark.util.Utils
+import ml.dmlc.xgboost4j.scala.spark.{TrackerConf, Utils}

import org.apache.spark.ml.param.{Param, ParamPair, Params}
import org.json4s.{DefaultFormats, Extraction, NoTypeHints}

This file was deleted.

@@ -18,14 +18,10 @@ package ml.dmlc.xgboost4j.scala.spark

import java.io.{File, FileInputStream}

-import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}

import org.apache.spark.SparkContext
import org.apache.spark.sql._
import org.scalatest.BeforeAndAfterEach
import org.scalatest.funsuite.AnyFunSuite
-import scala.math.min
-import scala.util.Random

import org.apache.commons.io.IOUtils

@@ -71,45 +67,6 @@ trait PerTest extends BeforeAndAfterEach { self: AnyFunSuite =>
}
}

-  protected def buildDataFrame(
-      labeledPoints: Seq[XGBLabeledPoint],
-      numPartitions: Int = numWorkers): DataFrame = {
-    import ml.dmlc.xgboost4j.scala.spark.util.DataUtils._
-    val it = labeledPoints.iterator.zipWithIndex
-      .map { case (labeledPoint: XGBLabeledPoint, id: Int) =>
-        (id, labeledPoint.label, labeledPoint.features)
-      }
-
-    ss.createDataFrame(sc.parallelize(it.toList, numPartitions))
-      .toDF("id", "label", "features")
-  }
-
-  protected def buildDataFrameWithRandSort(
-      labeledPoints: Seq[XGBLabeledPoint],
-      numPartitions: Int = numWorkers): DataFrame = {
-    val df = buildDataFrame(labeledPoints, numPartitions)
-    val rndSortedRDD = df.rdd.mapPartitions { iter =>
-      iter.map(_ -> Random.nextDouble()).toList
-        .sortBy(_._2)
-        .map(_._1).iterator
-    }
-    ss.createDataFrame(rndSortedRDD, df.schema)
-  }
-
-  protected def buildDataFrameWithGroup(
-      labeledPoints: Seq[XGBLabeledPoint],
-      numPartitions: Int = numWorkers): DataFrame = {
-    import ml.dmlc.xgboost4j.scala.spark.util.DataUtils._
-    val it = labeledPoints.iterator.zipWithIndex
-      .map { case (labeledPoint: XGBLabeledPoint, id: Int) =>
-        (id, labeledPoint.label, labeledPoint.features, labeledPoint.group)
-      }
-
-    ss.createDataFrame(sc.parallelize(it.toList, numPartitions))
-      .toDF("id", "label", "features", "group")
-  }
-
  protected def compareTwoFiles(lhs: String, rhs: String): Boolean = {
    withResource(new FileInputStream(lhs)) { lfis =>
      withResource(new FileInputStream(rhs)) { rfis =>
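With the buildDataFrame* helpers gone, tests presumably construct their input DataFrames inline. A hypothetical sketch of the equivalent direct construction (ss and sc are the trait's SparkSession and SparkContext; the data and column names mirror the deleted helpers, not code in this commit):

    import org.apache.spark.ml.linalg.Vectors

    // Build a small (id, label, features) DataFrame without the helper.
    val df = ss.createDataFrame(sc.parallelize(Seq(
      (0, 1.0f, Vectors.dense(1.0, 2.0)),
      (1, 0.0f, Vectors.dense(3.0, 4.0))
    ))).toDF("id", "label", "features")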
