Skip to content

Commit

Permalink
[SPARK-28866][ML] Persist item factors RDD when checkpointing in ALS
Browse files Browse the repository at this point in the history
### What changes were proposed in this pull request?

In ALS ML implementation, for non-implicit case, we checkpoint the RDD of item factors, between intervals. Before checkpointing (.checkpoint()) and materializing (.count()) RDD, this RDD was not persisted. It causes recomputation. In an experiment, there is performance difference between persisting and no persisting before checkpointing the RDD.

The performance difference is not big, but this change is not big too. The actual performance difference varies depending the interval of checkpoint, training dataset, etc.

### Why are the changes needed?

Persisting the RDD before checkpointing the RDD of item factors can avoid recomputation.

### Does this PR introduce any user-facing change?

No

### How was this patch tested?

Manual check RDD recomputation or not.

Taking 30% MovieLens 20M Dataset as training dataset. Setting checkpoint dir for SparkContext. Fitting an ALS model like:

```scala
val als = new ALS()
      .setMaxIter(100)
      .setCheckpointInterval(5)
      .setRegParam(0.01)
      .setUserCol("userId")
      .setItemCol("movieId")
      .setRatingCol("rating")

val t0 = System.currentTimeMillis()
val model = als.fit(training)
val t1 = System.currentTimeMillis()
```

Before this patch:  65.386 s
After this patch: 61.022 s

Closes #25576 from viirya/persist-item-factors.

Authored-by: Liang-Chi Hsieh <[email protected]>
Signed-off-by: Sean Owen <[email protected]>
  • Loading branch information
viirya authored and srowen committed Aug 30, 2019
1 parent 8279693 commit 2bd02e2
Showing 1 changed file with 6 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -990,16 +990,21 @@ object ALS extends DefaultParamsReadable[ALS] with Logging {
previousUserFactors.unpersist()
}
} else {
var previousCachedItemFactors: Option[RDD[(Int, FactorBlock)]] = None
for (iter <- 0 until maxIter) {
itemFactors = computeFactors(userFactors, userOutBlocks, itemInBlocks, rank, regParam,
userLocalIndexEncoder, solver = solver)
if (shouldCheckpoint(iter)) {
itemFactors.setName(s"itemFactors-$iter").persist(intermediateRDDStorageLevel)
val deps = itemFactors.dependencies
itemFactors.checkpoint()
itemFactors.count() // checkpoint item factors and cut lineage
ALS.cleanShuffleDependencies(sc, deps)
deletePreviousCheckpointFile()

previousCachedItemFactors.foreach(_.unpersist())
previousCheckpointFile = itemFactors.getCheckpointFile
previousCachedItemFactors = Option(itemFactors)
}
userFactors = computeFactors(itemFactors, itemOutBlocks, userInBlocks, rank, regParam,
itemLocalIndexEncoder, solver = solver)
Expand Down Expand Up @@ -1029,8 +1034,8 @@ object ALS extends DefaultParamsReadable[ALS] with Logging {
.persist(finalRDDStorageLevel)
if (finalRDDStorageLevel != StorageLevel.NONE) {
userIdAndFactors.count()
itemFactors.unpersist()
itemIdAndFactors.count()
itemFactors.unpersist()
userInBlocks.unpersist()
userOutBlocks.unpersist()
itemInBlocks.unpersist()
Expand Down

0 comments on commit 2bd02e2

Please sign in to comment.