Skip to content

Commit

Permalink
[SPARK-12004] Preserve the RDD partitioner through RDD checkpointing
Browse files Browse the repository at this point in the history
The solution is the save the RDD partitioner in a separate file in the RDD checkpoint directory. That is, `<checkpoint dir>/_partitioner`.  In most cases, whether the RDD partitioner was recovered or not, does not affect the correctness, only reduces performance. So this solution makes a best-effort attempt to save and recover the partitioner. If either fails, the checkpointing is not affected. This makes this patch safe and backward compatible.

Author: Tathagata Das <[email protected]>

Closes #9983 from tdas/SPARK-12004.
  • Loading branch information
tdas authored and Andrew Or committed Dec 1, 2015
1 parent 2cef1cd commit 60b541e
Show file tree
Hide file tree
Showing 3 changed files with 173 additions and 31 deletions.
122 changes: 116 additions & 6 deletions core/src/main/scala/org/apache/spark/rdd/ReliableCheckpointRDD.scala
Original file line number Diff line number Diff line change
Expand Up @@ -20,21 +20,22 @@ package org.apache.spark.rdd
import java.io.IOException

import scala.reflect.ClassTag
import scala.util.control.NonFatal

import org.apache.hadoop.fs.Path

import org.apache.spark._
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.util.{SerializableConfiguration, Utils}

/**
* An RDD that reads from checkpoint files previously written to reliable storage.
*/
private[spark] class ReliableCheckpointRDD[T: ClassTag](
sc: SparkContext,
val checkpointPath: String)
extends CheckpointRDD[T](sc) {
val checkpointPath: String,
_partitioner: Option[Partitioner] = None
) extends CheckpointRDD[T](sc) {

@transient private val hadoopConf = sc.hadoopConfiguration
@transient private val cpath = new Path(checkpointPath)
Expand All @@ -47,7 +48,13 @@ private[spark] class ReliableCheckpointRDD[T: ClassTag](
/**
* Return the path of the checkpoint directory this RDD reads data from.
*/
override def getCheckpointFile: Option[String] = Some(checkpointPath)
override val getCheckpointFile: Option[String] = Some(checkpointPath)

override val partitioner: Option[Partitioner] = {
_partitioner.orElse {
ReliableCheckpointRDD.readCheckpointedPartitionerFile(context, checkpointPath)
}
}

/**
* Return partitions described by the files in the checkpoint directory.
Expand Down Expand Up @@ -100,10 +107,52 @@ private[spark] object ReliableCheckpointRDD extends Logging {
"part-%05d".format(partitionIndex)
}

private def checkpointPartitionerFileName(): String = {
"_partitioner"
}

/**
* Write RDD to checkpoint files and return a ReliableCheckpointRDD representing the RDD.
*/
def writeRDDToCheckpointDirectory[T: ClassTag](
originalRDD: RDD[T],
checkpointDir: String,
blockSize: Int = -1): ReliableCheckpointRDD[T] = {

val sc = originalRDD.sparkContext

// Create the output path for the checkpoint
val checkpointDirPath = new Path(checkpointDir)
val fs = checkpointDirPath.getFileSystem(sc.hadoopConfiguration)
if (!fs.mkdirs(checkpointDirPath)) {
throw new SparkException(s"Failed to create checkpoint path $checkpointDirPath")
}

// Save to file, and reload it as an RDD
val broadcastedConf = sc.broadcast(
new SerializableConfiguration(sc.hadoopConfiguration))
// TODO: This is expensive because it computes the RDD again unnecessarily (SPARK-8582)
sc.runJob(originalRDD,
writePartitionToCheckpointFile[T](checkpointDirPath.toString, broadcastedConf) _)

if (originalRDD.partitioner.nonEmpty) {
writePartitionerToCheckpointDir(sc, originalRDD.partitioner.get, checkpointDirPath)
}

val newRDD = new ReliableCheckpointRDD[T](
sc, checkpointDirPath.toString, originalRDD.partitioner)
if (newRDD.partitions.length != originalRDD.partitions.length) {
throw new SparkException(
s"Checkpoint RDD $newRDD(${newRDD.partitions.length}) has different " +
s"number of partitions from original RDD $originalRDD(${originalRDD.partitions.length})")
}
newRDD
}

/**
* Write this partition's values to a checkpoint file.
* Write a RDD partition's data to a checkpoint file.
*/
def writeCheckpointFile[T: ClassTag](
def writePartitionToCheckpointFile[T: ClassTag](
path: String,
broadcastedConf: Broadcast[SerializableConfiguration],
blockSize: Int = -1)(ctx: TaskContext, iterator: Iterator[T]) {
Expand Down Expand Up @@ -151,6 +200,67 @@ private[spark] object ReliableCheckpointRDD extends Logging {
}
}

/**
* Write a partitioner to the given RDD checkpoint directory. This is done on a best-effort
* basis; any exception while writing the partitioner is caught, logged and ignored.
*/
private def writePartitionerToCheckpointDir(
sc: SparkContext, partitioner: Partitioner, checkpointDirPath: Path): Unit = {
try {
val partitionerFilePath = new Path(checkpointDirPath, checkpointPartitionerFileName)
val bufferSize = sc.conf.getInt("spark.buffer.size", 65536)
val fs = partitionerFilePath.getFileSystem(sc.hadoopConfiguration)
val fileOutputStream = fs.create(partitionerFilePath, false, bufferSize)
val serializer = SparkEnv.get.serializer.newInstance()
val serializeStream = serializer.serializeStream(fileOutputStream)
Utils.tryWithSafeFinally {
serializeStream.writeObject(partitioner)
} {
serializeStream.close()
}
logDebug(s"Written partitioner to $partitionerFilePath")
} catch {
case NonFatal(e) =>
logWarning(s"Error writing partitioner $partitioner to $checkpointDirPath")
}
}


/**
* Read a partitioner from the given RDD checkpoint directory, if it exists.
* This is done on a best-effort basis; any exception while reading the partitioner is
* caught, logged and ignored.
*/
private def readCheckpointedPartitionerFile(
sc: SparkContext,
checkpointDirPath: String): Option[Partitioner] = {
try {
val bufferSize = sc.conf.getInt("spark.buffer.size", 65536)
val partitionerFilePath = new Path(checkpointDirPath, checkpointPartitionerFileName)
val fs = partitionerFilePath.getFileSystem(sc.hadoopConfiguration)
if (fs.exists(partitionerFilePath)) {
val fileInputStream = fs.open(partitionerFilePath, bufferSize)
val serializer = SparkEnv.get.serializer.newInstance()
val deserializeStream = serializer.deserializeStream(fileInputStream)
val partitioner = Utils.tryWithSafeFinally[Partitioner] {
deserializeStream.readObject[Partitioner]
} {
deserializeStream.close()
}
logDebug(s"Read partitioner from $partitionerFilePath")
Some(partitioner)
} else {
logDebug("No partitioner file")
None
}
} catch {
case NonFatal(e) =>
logWarning(s"Error reading partitioner from $checkpointDirPath, " +
s"partitioner will not be recovered which may lead to performance loss", e)
None
}
}

/**
* Read the content of the specified checkpoint file.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,25 +55,7 @@ private[spark] class ReliableRDDCheckpointData[T: ClassTag](@transient private v
* This is called immediately after the first action invoked on this RDD has completed.
*/
protected override def doCheckpoint(): CheckpointRDD[T] = {

// Create the output path for the checkpoint
val path = new Path(cpDir)
val fs = path.getFileSystem(rdd.context.hadoopConfiguration)
if (!fs.mkdirs(path)) {
throw new SparkException(s"Failed to create checkpoint path $cpDir")
}

// Save to file, and reload it as an RDD
val broadcastedConf = rdd.context.broadcast(
new SerializableConfiguration(rdd.context.hadoopConfiguration))
// TODO: This is expensive because it computes the RDD again unnecessarily (SPARK-8582)
rdd.context.runJob(rdd, ReliableCheckpointRDD.writeCheckpointFile[T](cpDir, broadcastedConf) _)
val newRDD = new ReliableCheckpointRDD[T](rdd.context, cpDir)
if (newRDD.partitions.length != rdd.partitions.length) {
throw new SparkException(
s"Checkpoint RDD $newRDD(${newRDD.partitions.length}) has different " +
s"number of partitions from original RDD $rdd(${rdd.partitions.length})")
}
val newRDD = ReliableCheckpointRDD.writeRDDToCheckpointDirectory(rdd, cpDir)

// Optionally clean our checkpoint files if the reference is out of scope
if (rdd.conf.getBoolean("spark.cleaner.referenceTracking.cleanCheckpoints", false)) {
Expand All @@ -83,7 +65,6 @@ private[spark] class ReliableRDDCheckpointData[T: ClassTag](@transient private v
}

logInfo(s"Done checkpointing RDD ${rdd.id} to $cpDir, new parent is RDD ${newRDD.id}")

newRDD
}

Expand Down
61 changes: 56 additions & 5 deletions core/src/test/scala/org/apache/spark/CheckpointSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@ import java.io.File

import scala.reflect.ClassTag

import org.apache.spark.CheckpointSuite._
import org.apache.hadoop.fs.Path

import org.apache.spark.rdd._
import org.apache.spark.storage.{BlockId, StorageLevel, TestBlockId}
import org.apache.spark.util.Utils
Expand Down Expand Up @@ -74,8 +75,10 @@ trait RDDCheckpointTester { self: SparkFunSuite =>

// Test whether the checkpoint file has been created
if (reliableCheckpoint) {
assert(
collectFunc(sparkContext.checkpointFile[U](operatedRDD.getCheckpointFile.get)) === result)
assert(operatedRDD.getCheckpointFile.nonEmpty)
val recoveredRDD = sparkContext.checkpointFile[U](operatedRDD.getCheckpointFile.get)
assert(collectFunc(recoveredRDD) === result)
assert(recoveredRDD.partitioner === operatedRDD.partitioner)
}

// Test whether dependencies have been changed from its earlier parent RDD
Expand Down Expand Up @@ -211,9 +214,14 @@ trait RDDCheckpointTester { self: SparkFunSuite =>
}

/** Run a test twice, once for local checkpointing and once for reliable checkpointing. */
protected def runTest(name: String)(body: Boolean => Unit): Unit = {
protected def runTest(
name: String,
skipLocalCheckpoint: Boolean = false
)(body: Boolean => Unit): Unit = {
test(name + " [reliable checkpoint]")(body(true))
test(name + " [local checkpoint]")(body(false))
if (!skipLocalCheckpoint) {
test(name + " [local checkpoint]")(body(false))
}
}

/**
Expand Down Expand Up @@ -264,6 +272,49 @@ class CheckpointSuite extends SparkFunSuite with RDDCheckpointTester with LocalS
assert(flatMappedRDD.collect() === result)
}

runTest("checkpointing partitioners", skipLocalCheckpoint = true) { _: Boolean =>

def testPartitionerCheckpointing(
partitioner: Partitioner,
corruptPartitionerFile: Boolean = false
): Unit = {
val rddWithPartitioner = sc.makeRDD(1 to 4).map { _ -> 1 }.partitionBy(partitioner)
rddWithPartitioner.checkpoint()
rddWithPartitioner.count()
assert(rddWithPartitioner.getCheckpointFile.get.nonEmpty,
"checkpointing was not successful")

if (corruptPartitionerFile) {
// Overwrite the partitioner file with garbage data
val checkpointDir = new Path(rddWithPartitioner.getCheckpointFile.get)
val fs = checkpointDir.getFileSystem(sc.hadoopConfiguration)
val partitionerFile = fs.listStatus(checkpointDir)
.find(_.getPath.getName.contains("partitioner"))
.map(_.getPath)
require(partitionerFile.nonEmpty, "could not find the partitioner file for testing")
val output = fs.create(partitionerFile.get, true)
output.write(100)
output.close()
}

val newRDD = sc.checkpointFile[(Int, Int)](rddWithPartitioner.getCheckpointFile.get)
assert(newRDD.collect().toSet === rddWithPartitioner.collect().toSet, "RDD not recovered")

if (!corruptPartitionerFile) {
assert(newRDD.partitioner != None, "partitioner not recovered")
assert(newRDD.partitioner === rddWithPartitioner.partitioner,
"recovered partitioner does not match")
} else {
assert(newRDD.partitioner == None, "partitioner unexpectedly recovered")
}
}

testPartitionerCheckpointing(partitioner)

// Test that corrupted partitioner file does not prevent recovery of RDD
testPartitionerCheckpointing(partitioner, corruptPartitionerFile = true)
}

runTest("RDDs with one-to-one dependencies") { reliableCheckpoint: Boolean =>
testRDD(_.map(x => x.toString), reliableCheckpoint)
testRDD(_.flatMap(x => 1 to x), reliableCheckpoint)
Expand Down

0 comments on commit 60b541e

Please sign in to comment.