[jvm-packages] fix the scalability issue of prediction (#4033)
CodingCat authored Dec 30, 2018
1 parent 15fe2f1 commit f368d0d
Showing 3 changed files with 30 additions and 34 deletions.
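What changed: the old prediction code split each partition's row iterator with Iterator.duplicate, drained one copy to the end (rowItr2.map(...).toList) to build the DMatrix, and kept the other copy (rowItr1) for assembling the output rows. Scala backs the two duplicated iterators with a shared buffer, so draining one ahead of the other holds every Row of the partition in memory at once — this is the scalability issue. The fix computes the prediction iterators in a second RDD and reunites them with the untouched input rows via zipPartitions, streaming each partition instead of buffering it. The commit also drops the jacoco-maven-plugin from jvm-packages/pom.xml.

A standalone sketch of the buffering behaviour (hypothetical demo code, not from the commit; the names rowItr1/rowItr2 just echo the old diff):

```scala
// Iterator.duplicate backs both halves with a shared queue: advancing one
// half ahead of the other buffers every intervening element. The old code
// drained one half completely (toList) before the other was read, so a
// whole partition sat in memory at once.
object DuplicateBufferingDemo extends App {
  val (rowItr1, rowItr2) = Iterator.range(0, 5000000).duplicate
  val drained = rowItr2.toList // forces all 5M elements into the shared buffer
  println(s"buffered ${drained.size} elements before rowItr1 was touched")
  println(s"rowItr1 replays from the buffer: ${rowItr1.take(3).mkString(", ")}")
}
```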
19 changes: 0 additions & 19 deletions jvm-packages/pom.xml
@@ -335,25 +335,6 @@
         </execution>
       </executions>
     </plugin>
-    <plugin>
-      <groupId>org.jacoco</groupId>
-      <artifactId>jacoco-maven-plugin</artifactId>
-      <version>0.7.9</version>
-      <executions>
-        <execution>
-          <goals>
-            <goal>prepare-agent</goal>
-          </goals>
-        </execution>
-        <execution>
-          <id>report</id>
-          <phase>test</phase>
-          <goals>
-            <goal>report</goal>
-          </goals>
-        </execution>
-      </executions>
-    </plugin>
   </plugins>
 </build>
 <dependencies>
@@ -285,12 +285,12 @@ class XGBoostClassificationModel private[ml](
     val bBooster = dataset.sparkSession.sparkContext.broadcast(_booster)
     val appName = dataset.sparkSession.sparkContext.appName

-    val rdd = dataset.asInstanceOf[Dataset[Row]].rdd.mapPartitions { rowIterator =>
+    val inputRDD = dataset.asInstanceOf[Dataset[Row]].rdd
+    val predictionRDD = dataset.asInstanceOf[Dataset[Row]].rdd.mapPartitions { rowIterator =>
       if (rowIterator.hasNext) {
         val rabitEnv = Array("DMLC_TASK_ID" -> TaskContext.getPartitionId().toString).toMap
         Rabit.init(rabitEnv.asJava)
-        val (rowItr1, rowItr2) = rowIterator.duplicate
-        val featuresIterator = rowItr2.map(row => row.getAs[Vector](
+        val featuresIterator = rowIterator.map(row => row.getAs[Vector](
           $(featuresCol))).toList.iterator
         import DataUtils._
         val cacheInfo = {
@@ -307,19 +307,27 @@
           val Array(rawPredictionItr, probabilityItr, predLeafItr, predContribItr) =
             producePredictionItrs(bBooster, dm)
           Rabit.shutdown()
-          produceResultIterator(rowItr1, rawPredictionItr, probabilityItr, predLeafItr,
+          Iterator(rawPredictionItr, probabilityItr, predLeafItr,
            predContribItr)
         } finally {
           dm.delete()
         }
       } else {
-        Iterator[Row]()
+        Iterator()
       }
     }
+    val resultRDD = inputRDD.zipPartitions(predictionRDD, preservesPartitioning = true) {
+      case (inputIterator, predictionItr) =>
+        if (inputIterator.hasNext) {
+          produceResultIterator(inputIterator, predictionItr.next(), predictionItr.next(),
+            predictionItr.next(), predictionItr.next())
+        } else {
+          Iterator()
+        }
+    }

     bBooster.unpersist(blocking = false)
-
-    dataset.sparkSession.createDataFrame(rdd, generateResultSchema(schema))
+    dataset.sparkSession.createDataFrame(resultRDD, generateResultSchema(schema))
   }

   private def produceResultIterator(
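The new pattern is easier to see in isolation. Below is a minimal, self-contained sketch (assumed names throughout: a local SparkContext, and _ * 2.0 standing in for Booster.predict; only the zipPartitions shape mirrors the commit):

```scala
import org.apache.spark.{SparkConf, SparkContext}

// Hypothetical demo of the commit's pattern: compute predictions in a
// second RDD derived from the input, then zip the two back together
// partition-wise instead of duplicating the row iterator in one pass.
object ZipPartitionsDemo extends App {
  val sc = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("zip-demo"))
  val inputRDD = sc.parallelize(1 to 10, numSlices = 2)

  // Each non-empty partition yields a single element: the iterator of
  // its predictions. Wrapping it in an outer Iterator keeps partition
  // counts aligned for zipPartitions.
  val predictionRDD = inputRDD.mapPartitions { rows =>
    if (rows.hasNext) Iterator(rows.map(_ * 2.0)) // stand-in for Booster.predict
    else Iterator()
  }

  // Pair rows with predictions partition-wise; the demo buffers neither side.
  val resultRDD = inputRDD.zipPartitions(predictionRDD, preservesPartitioning = true) {
    case (rows, predItr) =>
      if (rows.hasNext) rows.zip(predItr.next()) else Iterator()
  }

  resultRDD.collect().foreach(println)
  sc.stop()
}
```

In the real classifier code the per-partition element count is four rather than one: predictionRDD packs the raw-prediction, probability, predicted-leaf and contribution iterators, and the zipped function pulls them off with four predictionItr.next() calls before delegating to produceResultIterator.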
@@ -257,13 +257,12 @@ class XGBoostRegressionModel private[ml] (

     val bBooster = dataset.sparkSession.sparkContext.broadcast(_booster)
     val appName = dataset.sparkSession.sparkContext.appName
-
-    val rdd = dataset.asInstanceOf[Dataset[Row]].rdd.mapPartitions { rowIterator =>
+    val inputRDD = dataset.asInstanceOf[Dataset[Row]].rdd
+    val predictionRDD = dataset.asInstanceOf[Dataset[Row]].rdd.mapPartitions { rowIterator =>
       if (rowIterator.hasNext) {
         val rabitEnv = Array("DMLC_TASK_ID" -> TaskContext.getPartitionId().toString).toMap
         Rabit.init(rabitEnv.asJava)
-        val (rowItr1, rowItr2) = rowIterator.duplicate
-        val featuresIterator = rowItr2.map(row => row.getAs[Vector](
+        val featuresIterator = rowIterator.map(row => row.getAs[Vector](
           $(featuresCol))).toList.iterator
         import DataUtils._
         val cacheInfo = {
@@ -273,24 +272,32 @@
             null
           }
         }
-
         val dm = new DMatrix(
           XGBoost.removeMissingValues(featuresIterator.map(_.asXGB), $(missing)),
           cacheInfo)
         try {
           val Array(originalPredictionItr, predLeafItr, predContribItr) =
             producePredictionItrs(bBooster, dm)
           Rabit.shutdown()
-          produceResultIterator(rowItr1, originalPredictionItr, predLeafItr, predContribItr)
+          Iterator(originalPredictionItr, predLeafItr, predContribItr)
         } finally {
           dm.delete()
         }
       } else {
-        Iterator[Row]()
+        Iterator()
       }
     }
+    val resultRDD = inputRDD.zipPartitions(predictionRDD, preservesPartitioning = true) {
+      case (inputIterator, predictionItr) =>
+        if (inputIterator.hasNext) {
+          produceResultIterator(inputIterator, predictionItr.next(), predictionItr.next(),
+            predictionItr.next())
+        } else {
+          Iterator()
+        }
+    }
     bBooster.unpersist(blocking = false)
-    dataset.sparkSession.createDataFrame(rdd, generateResultSchema(schema))
+    dataset.sparkSession.createDataFrame(resultRDD, generateResultSchema(schema))
   }

   private def produceResultIterator(
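The regression model receives the same restructuring, differing only in arity: producePredictionItrs yields three iterators here (original prediction, predicted leaf, contribution) against the classifier's four, so the zipped function calls predictionItr.next() three times. One design consequence worth noting: inputRDD and predictionRDD are both derived from dataset.rdd with no caching in between, so each partition is scanned twice — once to build the DMatrix, once as the input side of the zip — trading recomputation for the bounded memory that the duplicate-based version could not guarantee.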