[SPARK-12004] Preserve the RDD partitioner through RDD checkpointing #9983

Closed · wants to merge 4 commits
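
For context, the end-to-end effect of this change can be sketched as follows. This is a minimal example, assuming a SparkContext `sc` with a checkpoint directory already configured; it mirrors the new CheckpointSuite test below, and note that sc.checkpointFile is spark-internal, so the snippet only compiles inside the org.apache.spark package, as the test does.

import org.apache.spark.HashPartitioner

// Pair RDD with a known partitioner, then checkpoint it
// (assumes sc.setCheckpointDir(...) has already been called).
val rdd = sc.makeRDD(1 to 4).map(_ -> 1).partitionBy(new HashPartitioner(2))
rdd.checkpoint()
rdd.count()  // the first action materializes the checkpoint

// Recover the RDD from the checkpoint files. With this patch the partitioner is
// written to a "_partitioner" file alongside the part files and restored here,
// so operations on the recovered RDD can avoid an unnecessary shuffle.
val recovered = sc.checkpointFile[(Int, Int)](rdd.getCheckpointFile.get)
assert(recovered.partitioner === rdd.partitioner)
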
122 changes: 116 additions & 6 deletions core/src/main/scala/org/apache/spark/rdd/ReliableCheckpointRDD.scala
@@ -20,21 +20,22 @@ package org.apache.spark.rdd
import java.io.IOException

import scala.reflect.ClassTag
import scala.util.control.NonFatal

import org.apache.hadoop.fs.Path

import org.apache.spark._
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.util.{SerializableConfiguration, Utils}

/**
* An RDD that reads from checkpoint files previously written to reliable storage.
*/
private[spark] class ReliableCheckpointRDD[T: ClassTag](
sc: SparkContext,
val checkpointPath: String)
extends CheckpointRDD[T](sc) {
val checkpointPath: String,
_partitioner: Option[Partitioner] = None
) extends CheckpointRDD[T](sc) {

@transient private val hadoopConf = sc.hadoopConfiguration
@transient private val cpath = new Path(checkpointPath)
@@ -47,7 +48,13 @@ private[spark] class ReliableCheckpointRDD[T: ClassTag](
/**
* Return the path of the checkpoint directory this RDD reads data from.
*/
override def getCheckpointFile: Option[String] = Some(checkpointPath)
override val getCheckpointFile: Option[String] = Some(checkpointPath)

override val partitioner: Option[Partitioner] = {
_partitioner.orElse {
ReliableCheckpointRDD.readCheckpointedPartitionerFile(context, checkpointPath)
}
}

/**
* Return partitions described by the files in the checkpoint directory.
@@ -100,10 +107,52 @@ private[spark] object ReliableCheckpointRDD extends Logging {
"part-%05d".format(partitionIndex)
}

private def checkpointPartitionerFileName(): String = {
"_partitioner"
}

Review comment (Contributor): just make this a val
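
A minimal sketch of the reviewer's suggestion (an assumed follow-up, not part of this diff):

// Suggested simplification (assumption): a constant val instead of a def.
private val checkpointPartitionerFileName = "_partitioner"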

/**
* Write RDD to checkpoint files and return a ReliableCheckpointRDD representing the RDD.
*/
def writeRDDToCheckpointDirectory[T: ClassTag](
originalRDD: RDD[T],
checkpointDir: String,
blockSize: Int = -1): ReliableCheckpointRDD[T] = {

val sc = originalRDD.sparkContext

// Create the output path for the checkpoint
val checkpointDirPath = new Path(checkpointDir)
val fs = checkpointDirPath.getFileSystem(sc.hadoopConfiguration)
if (!fs.mkdirs(checkpointDirPath)) {
throw new SparkException(s"Failed to create checkpoint path $checkpointDirPath")
}

// Save to file, and reload it as an RDD
val broadcastedConf = sc.broadcast(
new SerializableConfiguration(sc.hadoopConfiguration))
// TODO: This is expensive because it computes the RDD again unnecessarily (SPARK-8582)
sc.runJob(originalRDD,
writePartitionToCheckpointFile[T](checkpointDirPath.toString, broadcastedConf) _)

if (originalRDD.partitioner.nonEmpty) {
writePartitionerToCheckpointDir(sc, originalRDD.partitioner.get, checkpointDirPath)
}

val newRDD = new ReliableCheckpointRDD[T](
sc, checkpointDirPath.toString, originalRDD.partitioner)
if (newRDD.partitions.length != originalRDD.partitions.length) {
throw new SparkException(
s"Checkpoint RDD $newRDD(${newRDD.partitions.length}) has different " +
s"number of partitions from original RDD $originalRDD(${originalRDD.partitions.length})")
}
newRDD
}
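
For reference, the caller side (ReliableRDDCheckpointData.doCheckpoint, shown in the second file of this diff) now reduces to a single call:

// Writes the partition data, writes the optional _partitioner file, and returns
// the ReliableCheckpointRDD that replaces the original lineage.
val newRDD = ReliableCheckpointRDD.writeRDDToCheckpointDirectory(rdd, cpDir)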

/**
* Write this partition's values to a checkpoint file.
* Write an RDD partition's data to a checkpoint file.
*/
def writeCheckpointFile[T: ClassTag](
def writePartitionToCheckpointFile[T: ClassTag](

Review comment (Contributor): private

path: String,
broadcastedConf: Broadcast[SerializableConfiguration],
blockSize: Int = -1)(ctx: TaskContext, iterator: Iterator[T]) {
@@ -151,6 +200,67 @@ private[spark] object ReliableCheckpointRDD extends Logging {
}
}

/**
* Write a partitioner to the given RDD checkpoint directory. This is done on a best-effort
* basis; any exception while writing the partitioner is caught, logged and ignored.
*/
private def writePartitionerToCheckpointDir(
sc: SparkContext, partitioner: Partitioner, checkpointDirPath: Path): Unit = {

Review comment (Contributor): private, also style:

def writePartitionerToCheckpointDir(
    sc: SparkContext,
    partitioner: Partitioner,
    checkpointDirPath: Path): Unit

Review comment (Contributor): can you add java doc? (why do we need this)

Review comment (Contributor): also, can you just call this writePartitionerToCheckpoint

try {
val partitionerFilePath = new Path(checkpointDirPath, checkpointPartitionerFileName)
val bufferSize = sc.conf.getInt("spark.buffer.size", 65536)
val fs = partitionerFilePath.getFileSystem(sc.hadoopConfiguration)
val fileOutputStream = fs.create(partitionerFilePath, false, bufferSize)
val serializer = SparkEnv.get.serializer.newInstance()
val serializeStream = serializer.serializeStream(fileOutputStream)
Utils.tryWithSafeFinally {
serializeStream.writeObject(partitioner)
} {
serializeStream.close()
}
logDebug(s"Written partitioner to $partitionerFilePath")
} catch {
case NonFatal(e) =>
logWarning(s"Error writing partitioner $partitioner to $checkpointDirPath")
}
}


/**
* Read a partitioner from the given RDD checkpoint directory, if it exists.
* This is done on a best-effort basis; any exception while reading the partitioner is
* caught, logged and ignored.
*/
private def readCheckpointedPartitionerFile(
sc: SparkContext,
checkpointDirPath: String): Option[Partitioner] = {
try {
val bufferSize = sc.conf.getInt("spark.buffer.size", 65536)
val partitionerFilePath = new Path(checkpointDirPath, checkpointPartitionerFileName)
val fs = partitionerFilePath.getFileSystem(sc.hadoopConfiguration)
if (fs.exists(partitionerFilePath)) {
val fileInputStream = fs.open(partitionerFilePath, bufferSize)
val serializer = SparkEnv.get.serializer.newInstance()
val deserializeStream = serializer.deserializeStream(fileInputStream)
val partitioner = Utils.tryWithSafeFinally[Partitioner] {
deserializeStream.readObject[Partitioner]
} {
deserializeStream.close()
}
logDebug(s"Read partitioner from $partitionerFilePath")
Some(partitioner)
} else {
logDebug("No partitioner file")
None
}
} catch {
case NonFatal(e) =>
logWarning(s"Error reading partitioner from $checkpointDirPath, " +
s"partitioner will not be recovered which may lead to performance loss", e)
None
}
}
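
The best-effort contract above is what the new CheckpointSuite test exercises; a short sketch of the observable behavior, using hypothetical RDD names (recoveredRDD, corruptedRecoveryRDD, originalRDD) that mirror that test:

// Scenario 1: _partitioner file readable -> the partitioner survives recovery.
assert(recoveredRDD.partitioner === originalRDD.partitioner)

// Scenario 2: _partitioner file corrupted or missing -> data still recovers,
// but the partitioner falls back to None (a later join may re-shuffle).
assert(corruptedRecoveryRDD.collect().toSet === originalRDD.collect().toSet)
assert(corruptedRecoveryRDD.partitioner === None)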

/**
* Read the content of the specified checkpoint file.
*/
core/src/main/scala/org/apache/spark/rdd/ReliableRDDCheckpointData.scala
@@ -55,25 +55,7 @@ private[spark] class ReliableRDDCheckpointData[T: ClassTag](@transient private v
* This is called immediately after the first action invoked on this RDD has completed.
*/
protected override def doCheckpoint(): CheckpointRDD[T] = {

// Create the output path for the checkpoint
val path = new Path(cpDir)
val fs = path.getFileSystem(rdd.context.hadoopConfiguration)
if (!fs.mkdirs(path)) {
throw new SparkException(s"Failed to create checkpoint path $cpDir")
}

// Save to file, and reload it as an RDD
val broadcastedConf = rdd.context.broadcast(
new SerializableConfiguration(rdd.context.hadoopConfiguration))
// TODO: This is expensive because it computes the RDD again unnecessarily (SPARK-8582)
rdd.context.runJob(rdd, ReliableCheckpointRDD.writeCheckpointFile[T](cpDir, broadcastedConf) _)
val newRDD = new ReliableCheckpointRDD[T](rdd.context, cpDir)
if (newRDD.partitions.length != rdd.partitions.length) {
throw new SparkException(
s"Checkpoint RDD $newRDD(${newRDD.partitions.length}) has different " +
s"number of partitions from original RDD $rdd(${rdd.partitions.length})")
}
val newRDD = ReliableCheckpointRDD.writeRDDToCheckpointDirectory(rdd, cpDir)

Review comment (Contributor, author): All this code has been moved into ReliableCheckpointRDD.createCheckpointedRDD


// Optionally clean our checkpoint files if the reference is out of scope
if (rdd.conf.getBoolean("spark.cleaner.referenceTracking.cleanCheckpoints", false)) {
@@ -83,7 +65,6 @@ private[spark] class ReliableRDDCheckpointData[T: ClassTag](@transient private v
}

logInfo(s"Done checkpointing RDD ${rdd.id} to $cpDir, new parent is RDD ${newRDD.id}")

newRDD
}

61 changes: 56 additions & 5 deletions core/src/test/scala/org/apache/spark/CheckpointSuite.scala
@@ -21,7 +21,8 @@ import java.io.File

import scala.reflect.ClassTag

import org.apache.spark.CheckpointSuite._
import org.apache.hadoop.fs.Path

import org.apache.spark.rdd._
import org.apache.spark.storage.{BlockId, StorageLevel, TestBlockId}
import org.apache.spark.util.Utils
@@ -74,8 +75,10 @@ trait RDDCheckpointTester { self: SparkFunSuite =>

// Test whether the checkpoint file has been created
if (reliableCheckpoint) {
assert(
collectFunc(sparkContext.checkpointFile[U](operatedRDD.getCheckpointFile.get)) === result)
assert(operatedRDD.getCheckpointFile.nonEmpty)
val recoveredRDD = sparkContext.checkpointFile[U](operatedRDD.getCheckpointFile.get)
assert(collectFunc(recoveredRDD) === result)
assert(recoveredRDD.partitioner === operatedRDD.partitioner)
}

// Test whether dependencies have been changed from its earlier parent RDD
@@ -211,9 +214,14 @@ trait RDDCheckpointTester { self: SparkFunSuite =>
}

/** Run a test twice, once for local checkpointing and once for reliable checkpointing. */
protected def runTest(name: String)(body: Boolean => Unit): Unit = {
protected def runTest(
name: String,
skipLocalCheckpoint: Boolean = false
)(body: Boolean => Unit): Unit = {
test(name + " [reliable checkpoint]")(body(true))
test(name + " [local checkpoint]")(body(false))
if (!skipLocalCheckpoint) {
test(name + " [local checkpoint]")(body(false))
}
}

/**
@@ -264,6 +272,49 @@ class CheckpointSuite extends SparkFunSuite with RDDCheckpointTester with LocalS
assert(flatMappedRDD.collect() === result)
}

runTest("checkpointing partitioners", skipLocalCheckpoint = true) { _: Boolean =>

def testPartitionerCheckpointing(
partitioner: Partitioner,
corruptPartitionerFile: Boolean = false
): Unit = {
val rddWithPartitioner = sc.makeRDD(1 to 4).map { _ -> 1 }.partitionBy(partitioner)
rddWithPartitioner.checkpoint()
rddWithPartitioner.count()
assert(rddWithPartitioner.getCheckpointFile.get.nonEmpty,
"checkpointing was not successful")

if (corruptPartitionerFile) {
// Overwrite the partitioner file with garbage data
val checkpointDir = new Path(rddWithPartitioner.getCheckpointFile.get)
val fs = checkpointDir.getFileSystem(sc.hadoopConfiguration)
val partitionerFile = fs.listStatus(checkpointDir)
.find(_.getPath.getName.contains("partitioner"))
.map(_.getPath)
require(partitionerFile.nonEmpty, "could not find the partitioner file for testing")
val output = fs.create(partitionerFile.get, true)
output.write(100)
output.close()
}

val newRDD = sc.checkpointFile[(Int, Int)](rddWithPartitioner.getCheckpointFile.get)
assert(newRDD.collect().toSet === rddWithPartitioner.collect().toSet, "RDD not recovered")

if (!corruptPartitionerFile) {
assert(newRDD.partitioner != None, "partitioner not recovered")
assert(newRDD.partitioner === rddWithPartitioner.partitioner,
"recovered partitioner does not match")
} else {
assert(newRDD.partitioner == None, "partitioner unexpectedly recovered")
}
}

testPartitionerCheckpointing(partitioner)

// Test that corrupted partitioner file does not prevent recovery of RDD
testPartitionerCheckpointing(partitioner, corruptPartitionerFile = true)
}

runTest("RDDs with one-to-one dependencies") { reliableCheckpoint: Boolean =>
testRDD(_.map(x => x.toString), reliableCheckpoint)
testRDD(_.flatMap(x => 1 to x), reliableCheckpoint)