Commit 568918c
Generalized PeriodicGraphCheckpointer to PeriodicCheckpointer, with subclasses for RDDs and Graphs.

jkbradley committed Jul 28, 2015
1 parent daa1964 commit 568918c
Showing 6 changed files with 471 additions and 95 deletions.
LDAOptimizer.scala
@@ -188,7 +188,7 @@ final class EMLDAOptimizer extends LDAOptimizer {
// Update the vertex descriptors with the new counts.
val newGraph = GraphImpl.fromExistingRDDs(docTopicDistributions, graph.edges)
graph = newGraph
graphCheckpointer.updateGraph(newGraph)
graphCheckpointer.update(newGraph)
globalTopicTotals = computeGlobalTopicTotals()
this
}
PeriodicCheckpointer.scala (new file)
@@ -0,0 +1,184 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.mllib.impl

import scala.collection.mutable

import org.apache.hadoop.fs.{Path, FileSystem}

import org.apache.spark.{SparkContext, Logging}
import org.apache.spark.storage.StorageLevel


/**
* This abstraction helps with persisting and checkpointing RDDs and types derived from RDDs
* (such as Graphs and DataFrames). In documentation, we use the phrase "Dataset" to refer to
* the distributed data type (RDD, Graph, etc.).
*
* Specifically, this abstraction automatically handles persisting and (optionally) checkpointing,
* as well as unpersisting and removing checkpoint files.
*
* Users should call update() when a new Dataset has been created,
* before the Dataset has been materialized. After updating [[PeriodicCheckpointer]], users are
* responsible for materializing the Dataset to ensure that persisting and checkpointing actually
* occur.
*
* When update() is called, this does the following:
* - Persist new Dataset (if not yet persisted), and put in queue of persisted Datasets.
* - Unpersist Datasets from queue until there are at most 3 persisted Datasets.
* - If using checkpointing and the checkpoint interval has been reached,
* - Checkpoint the new Dataset, and put in a queue of checkpointed Datasets.
* - Remove older checkpoints.
*
* WARNINGS:
* - This class should NOT be copied (since copies may conflict on which Datasets should be
* checkpointed).
* - This class removes checkpoint files once later Datasets have been checkpointed.
* However, references to the older Datasets will still return isCheckpointed = true.
*
* Example usage:
* {{{
* val (data1, data2, data3, ...) = ...
* val cp = new PeriodicCheckpointer(data1, 2, sc)
* data1.count();
* // persisted: data1
* cp.update(data2)
* data2.count();
* // persisted: data1, data2
* // checkpointed: data2
* cp.update(data3)
* data3.count();
* // persisted: data1, data2, data3
* // checkpointed: data2
* cp.update(data4)
* data4.count();
* // persisted: data2, data3, data4
* // checkpointed: data4
* cp.update(data5)
* data5.count();
* // persisted: data3, data4, data5
* // checkpointed: data4
* }}}
*
* @param currentData Initial Dataset
* @param checkpointInterval Datasets will be checkpointed at this interval
* @param sc SparkContext for the Datasets given to this checkpointer
* @tparam T Dataset type, such as RDD[Double]
*/
private[mllib] abstract class PeriodicCheckpointer[T](
var currentData: T,
val checkpointInterval: Int,
val sc: SparkContext) extends Logging {

/** FIFO queue of past checkpointed Datasets */
private val checkpointQueue = mutable.Queue[T]()

/** FIFO queue of past persisted Datasets */
private val persistedQueue = mutable.Queue[T]()

/** Number of times [[update()]] has been called */
private var updateCount = 0

update(currentData)

/**
* Update [[currentData]] with a new Dataset. Handle persistence and checkpointing as needed.
* Since this handles persistence and checkpointing, this should be called before the Dataset
* has been materialized.
*
* @param newData New Dataset created from previous Datasets in the lineage.
*/
def update(newData: T): Unit = {
persist(newData)
persistedQueue.enqueue(newData)
// We try to keep 2 previously materialized Datasets in persistedQueue, plus the newly
// added (not yet materialized) Dataset, to support the semantics of this class:
// Users should call [[update()]] when a new Dataset has been created,
// before the Dataset has been materialized.
while (persistedQueue.size > 3) {
val dataToUnpersist = persistedQueue.dequeue()
unpersist(dataToUnpersist)
}
updateCount += 1

// Handle checkpointing (after persisting)
if ((updateCount % checkpointInterval) == 0 && sc.getCheckpointDir.nonEmpty) {
// Add new checkpoint before removing old checkpoints.
checkpoint(newData)
checkpointQueue.enqueue(newData)
// Remove checkpoints before the latest one.
var canDelete = true
while (checkpointQueue.size > 1 && canDelete) {
// Delete the oldest checkpoint only if the next checkpoint exists.
if (isCheckpointed(checkpointQueue.get(1).get)) {
removeCheckpointFile()
} else {
canDelete = false
}
}
}

currentData = newData
}

/** Checkpoint the Dataset */
def checkpoint(data: T): Unit

/** Return true iff the Dataset is checkpointed */
def isCheckpointed(data: T): Boolean

/**
* Persist the Dataset.
* Note: This should handle checking the current [[StorageLevel]] of the Dataset.
*/
def persist(data: T): Unit

/** Unpersist the Dataset */
def unpersist(data: T): Unit

/** Get the list of checkpoint files for the given Dataset */
def getCheckpointFiles(data: T): Iterable[String]

/**
* Call this at the end to delete any remaining checkpoint files.
*/
def deleteAllCheckpoints(): Unit = {
while (checkpointQueue.nonEmpty) {
removeCheckpointFile()
}
}

/**
* Dequeue the oldest checkpointed Dataset, and remove its checkpoint files.
* This prints a warning but does not fail if the files cannot be removed.
*/
private def removeCheckpointFile(): Unit = {
val old = checkpointQueue.dequeue()
// Since the old checkpoint is not deleted by Spark, we manually delete it.
val fs = FileSystem.get(sc.hadoopConfiguration)
getCheckpointFiles(old).foreach { checkpointFile =>
try {
fs.delete(new Path(checkpointFile), true)
} catch {
case e: Exception =>
logWarning("PeriodicCheckpointer could not remove old checkpoint file: " +
checkpointFile)
}
}
}

}
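The commit message also adds an RDD-based subclass, but that file is among the changed files not expanded on this page. Based on the abstract methods above, a minimal sketch of such a subclass could look like the following; the class name and method bodies here are illustrative assumptions, not the committed file.

package org.apache.spark.mllib.impl

import org.apache.spark.rdd.RDD
import org.apache.spark.storage.StorageLevel

// Hypothetical sketch of the RDD subclass described in the commit message.
private[mllib] class PeriodicRDDCheckpointer[T](
    initData: RDD[T],
    checkpointInterval: Int)
  extends PeriodicCheckpointer[RDD[T]](initData, checkpointInterval, initData.sparkContext) {

  override def checkpoint(data: RDD[T]): Unit = data.checkpoint()

  override def isCheckpointed(data: RDD[T]): Boolean = data.isCheckpointed

  override def persist(data: RDD[T]): Unit = {
    // Persist only if the RDD is not already persisted at some StorageLevel.
    if (data.getStorageLevel == StorageLevel.NONE) {
      data.persist()
    }
  }

  override def unpersist(data: RDD[T]): Unit = data.unpersist(blocking = false)

  override def getCheckpointFiles(data: RDD[T]): Iterable[String] = {
    // RDD.getCheckpointFile returns an Option[String]; expose it as an Iterable.
    data.getCheckpointFile.toList
  }
}

Keeping every Dataset-type-specific call behind these five small methods is what lets the queueing and checkpoint-file-deletion logic in PeriodicCheckpointer be shared between Graphs and RDDs.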
PeriodicGraphCheckpointer.scala
@@ -17,11 +17,6 @@

package org.apache.spark.mllib.impl

import scala.collection.mutable

import org.apache.hadoop.fs.{Path, FileSystem}

import org.apache.spark.Logging
import org.apache.spark.graphx.Graph
import org.apache.spark.storage.StorageLevel

@@ -31,12 +26,12 @@ import org.apache.spark.storage.StorageLevel
* Specifically, it automatically handles persisting and (optionally) checkpointing, as well as
* unpersisting and removing checkpoint files.
*
* Users should call [[PeriodicGraphCheckpointer.updateGraph()]] when a new graph has been created,
* Users should call update() when a new graph has been created,
* before the graph has been materialized. After updating [[PeriodicGraphCheckpointer]], users are
* responsible for materializing the graph to ensure that persisting and checkpointing actually
* occur.
*
* When [[PeriodicGraphCheckpointer.updateGraph()]] is called, this does the following:
* When update() is called, this does the following:
* - Persist new graph (if not yet persisted), and put in queue of persisted graphs.
* - Unpersist graphs from queue until there are at most 3 persisted graphs.
* - If using checkpointing and the checkpoint interval has been reached,
@@ -73,99 +68,30 @@ import org.apache.spark.storage.StorageLevel
* // checkpointed: graph4
* }}}
*
* @param currentGraph Initial graph
* @param initGraph Initial graph
* @param checkpointInterval Graphs will be checkpointed at this interval
* @tparam VD Vertex descriptor type
* @tparam ED Edge descriptor type
*
* TODO: Generalize this for Graphs and RDDs, and move it out of MLlib.
* TODO: Move this out of MLlib?
*/
private[mllib] class PeriodicGraphCheckpointer[VD, ED](
var currentGraph: Graph[VD, ED],
val checkpointInterval: Int) extends Logging {

/** FIFO queue of past checkpointed RDDs */
private val checkpointQueue = mutable.Queue[Graph[VD, ED]]()

/** FIFO queue of past persisted RDDs */
private val persistedQueue = mutable.Queue[Graph[VD, ED]]()
initGraph: Graph[VD, ED],
checkpointInterval: Int)
extends PeriodicCheckpointer[Graph[VD, ED]](initGraph, checkpointInterval,
initGraph.vertices.sparkContext) {

/** Number of times [[updateGraph()]] has been called */
private var updateCount = 0
override def checkpoint(data: Graph[VD, ED]): Unit = data.checkpoint()

/**
* Spark Context for the Graphs given to this checkpointer.
* NOTE: This code assumes that only one SparkContext is used for the given graphs.
*/
private val sc = currentGraph.vertices.sparkContext
override def isCheckpointed(data: Graph[VD, ED]): Boolean = data.isCheckpointed

updateGraph(currentGraph)

/**
* Update [[currentGraph]] with a new graph. Handle persistence and checkpointing as needed.
* Since this handles persistence and checkpointing, this should be called before the graph
* has been materialized.
*
* @param newGraph New graph created from previous graphs in the lineage.
*/
def updateGraph(newGraph: Graph[VD, ED]): Unit = {
if (newGraph.vertices.getStorageLevel == StorageLevel.NONE) {
newGraph.persist()
}
persistedQueue.enqueue(newGraph)
// We try to maintain 2 Graphs in persistedQueue to support the semantics of this class:
// Users should call [[updateGraph()]] when a new graph has been created,
// before the graph has been materialized.
while (persistedQueue.size > 3) {
val graphToUnpersist = persistedQueue.dequeue()
graphToUnpersist.unpersist(blocking = false)
}
updateCount += 1

// Handle checkpointing (after persisting)
if ((updateCount % checkpointInterval) == 0 && sc.getCheckpointDir.nonEmpty) {
// Add new checkpoint before removing old checkpoints.
newGraph.checkpoint()
checkpointQueue.enqueue(newGraph)
// Remove checkpoints before the latest one.
var canDelete = true
while (checkpointQueue.size > 1 && canDelete) {
// Delete the oldest checkpoint only if the next checkpoint exists.
if (checkpointQueue.get(1).get.isCheckpointed) {
removeCheckpointFile()
} else {
canDelete = false
}
}
}
}

/**
* Call this at the end to delete any remaining checkpoint files.
*/
def deleteAllCheckpoints(): Unit = {
while (checkpointQueue.size > 0) {
removeCheckpointFile()
override def persist(data: Graph[VD, ED]): Unit = {
if (data.vertices.getStorageLevel == StorageLevel.NONE) {
data.persist()
}
}

/**
* Dequeue the oldest checkpointed Graph, and remove its checkpoint files.
* This prints a warning but does not fail if the files cannot be removed.
*/
private def removeCheckpointFile(): Unit = {
val old = checkpointQueue.dequeue()
// Since the old checkpoint is not deleted by Spark, we manually delete it.
val fs = FileSystem.get(sc.hadoopConfiguration)
old.getCheckpointFiles.foreach { checkpointFile =>
try {
fs.delete(new Path(checkpointFile), true)
} catch {
case e: Exception =>
logWarning("PeriodicGraphCheckpointer could not remove old checkpoint file: " +
checkpointFile)
}
}
}
override def unpersist(data: Graph[VD, ED]): Unit = data.unpersist(blocking = false)

override def getCheckpointFiles(data: Graph[VD, ED]): Iterable[String] = data.getCheckpointFiles
}
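For context, typical use of the Graph subclass mirrors the EMLDAOptimizer hunk above: construct the checkpointer with the initial graph, call update() on each newly derived graph before materializing it, and delete remaining checkpoint files at the end. The loop below is only a hedged usage sketch: the helper names and iteration structure are assumptions, and it presumes code living inside the org.apache.spark.mllib package, since the class is private[mllib].

import org.apache.spark.graphx.Graph
import org.apache.spark.mllib.impl.PeriodicGraphCheckpointer

// Illustrative iterative algorithm; `initialGraph`, `step`, and `numIterations` are
// placeholders, not identifiers from this commit.
def runIterations[VD, ED](
    initialGraph: Graph[VD, ED],
    step: Graph[VD, ED] => Graph[VD, ED],
    numIterations: Int,
    checkpointInterval: Int): Graph[VD, ED] = {
  val checkpointer =
    new PeriodicGraphCheckpointer[VD, ED](initialGraph, checkpointInterval)
  // The constructor already registered initialGraph; materialize it as well.
  initialGraph.vertices.count()
  initialGraph.edges.count()
  var graph = initialGraph
  for (_ <- 0 until numIterations) {
    val newGraph = step(graph)
    // Register the new graph before materializing it, as the scaladoc requires.
    checkpointer.update(newGraph)
    // Materialize so that persisting and checkpointing actually take effect.
    newGraph.vertices.count()
    newGraph.edges.count()
    graph = newGraph
  }
  // Clean up any checkpoint files that are still on disk.
  checkpointer.deleteAllCheckpoints()
  graph
}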
(The remaining changed files were not loaded and are not shown here.)
