Don't cache the RDD broadcast variable.
rxin committed Jul 30, 2014
1 parent d256b45 commit cc152fc
Showing 6 changed files with 13 additions and 28 deletions.
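
Context for the diff below: RDD previously cached its serialized, broadcast form in a lazy val named broadcasted, so the bytes were captured once, on first use, and reused for the lifetime of the RDD. This commit replaces the lazy val with a createBroadcastBinary() method, and the DAGScheduler calls it once per stage submission. The standalone sketch below, using stand-in classes rather than Spark's own, illustrates the staleness hazard of the cached form when state referenced by the RDD (such as a Hadoop JobConf, per the scaladoc in the first hunk) is mutated after the first serialization.

    // A minimal sketch; FakeRdd and Config are stand-ins, not Spark classes.
    import java.io.{ByteArrayOutputStream, ObjectOutputStream}

    class Config(var key: String) extends Serializable

    class FakeRdd(val conf: Config) extends Serializable {
      private def serialize(): Array[Byte] = {
        val bos = new ByteArrayOutputStream()
        val oos = new ObjectOutputStream(bos)
        oos.writeObject(this)
        oos.close()
        bos.toByteArray
      }
      // Before this commit: bytes captured on first access, reused forever.
      @transient lazy val cachedBinary: Array[Byte] = serialize()
      // After: a fresh serialization on every call.
      def createBinary(): Array[Byte] = synchronized { serialize() }
    }

    object Demo extends App {
      val rdd = new FakeRdd(new Config("a"))
      val first = rdd.cachedBinary
      rdd.conf.key = "b"                   // mutate state the closure references
      assert(rdd.cachedBinary eq first)    // stale: the cached bytes never change
      assert(rdd.createBinary() ne first)  // fresh bytes per call
    }
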
core/src/main/scala/org/apache/spark/rdd/RDD.scala (2 changes: 1 addition & 1 deletion)

@@ -1225,7 +1225,7 @@ abstract class RDD[T: ClassTag](
    * might modify state of objects referenced in their closures. This is necessary in Hadoop
    * where the JobConf/Configuration object is not thread-safe.
    */
-  @transient private[spark] lazy val broadcasted: Broadcast[Array[Byte]] = {
+  @transient private[spark] def createBroadcastBinary(): Broadcast[Array[Byte]] = synchronized {
     val ser = SparkEnv.get.closureSerializer.newInstance()
     val bytes = ser.serialize(this).array()
     val size = Utils.bytesToString(bytes.length)
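
Two details of the new signature: turning the lazy val into a def removes the caching, and the explicit synchronized replaces the once-only initialization lock the lazy val used to provide, keeping concurrent serialization of the same RDD safe. A consequence, sketched below with hypothetical stand-in types, is that every call now yields a distinct broadcast, so a caller must create one handle per stage submission and share it, rather than calling the method per task.

    // FakeBroadcast and FakeRdd are hypothetical stand-ins, not the Spark API.
    final case class FakeBroadcast(id: Long, payload: Array[Byte])

    class FakeRdd {
      private var nextId = 0L
      def createBroadcastBinary(): FakeBroadcast = synchronized {
        nextId += 1
        FakeBroadcast(nextId, Array[Byte](1, 2, 3))  // re-serialized every call
      }
    }

    object Check extends App {
      val rdd = new FakeRdd
      // Two calls produce two distinct broadcasts; a scheduler must call once
      // per stage submission and pass the same handle to every task.
      assert(rdd.createBroadcastBinary().id != rdd.createBroadcastBinary().id)
    }
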

core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala (11 changes: 7 additions & 4 deletions)

@@ -694,18 +694,21 @@ class DAGScheduler(
     // Get our pending tasks and remember them in our pendingTasks entry
     stage.pendingTasks.clear()
     var tasks = ArrayBuffer[Task[_]]()
+    val broadcastRddBinary = stage.rdd.createBroadcastBinary()
     if (stage.isShuffleMap) {
       for (p <- 0 until stage.numPartitions if stage.outputLocs(p) == Nil) {
         val locs = getPreferredLocs(stage.rdd, p)
-        tasks += new ShuffleMapTask(stage.id, stage.rdd, stage.shuffleDep.get, p, locs)
+        val part = stage.rdd.partitions(p)
+        tasks += new ShuffleMapTask(stage.id, broadcastRddBinary, stage.shuffleDep.get, part, locs)
       }
     } else {
       // This is a final stage; figure out its job's missing partitions
       val job = stage.resultOfJob.get
       for (id <- 0 until job.numPartitions if !job.finished(id)) {
-        val partition = job.partitions(id)
-        val locs = getPreferredLocs(stage.rdd, partition)
-        tasks += new ResultTask(stage.id, stage.rdd, job.func, partition, locs, id)
+        val p: Int = job.partitions(id)
+        val part = stage.rdd.partitions(p)
+        val locs = getPreferredLocs(stage.rdd, p)
+        tasks += new ResultTask(stage.id, broadcastRddBinary, job.func, part, locs, id)
       }
     }
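
The scheduler side of the change: submitMissingTasks now creates the broadcast once per stage submission and hands the same handle to every task. Tasks also receive a resolved Partition object instead of an Int index, presumably because they no longer carry the live RDD from which to look up rdd.partitions(index). A simplified sketch of that construction pattern, with hypothetical types:

    // SchedulerSketch and TaskStub are hypothetical, simplified types.
    object SchedulerSketch {
      trait Partition { def index: Int }
      final case class TaskStub(stageId: Int, rddBinaryId: Long, partition: Partition)

      // One broadcast per stage submission, shared by every task in the stage.
      def submitMissingTasks(
          stageId: Int,
          createBroadcast: () => Long,      // stands in for rdd.createBroadcastBinary()
          partitions: Seq[Partition]): Seq[TaskStub] = {
        val broadcastId = createBroadcast() // called once, not once per task
        partitions.map(p => TaskStub(stageId, broadcastId, p))
      }
    }
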

core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala (10 changes: 0 additions & 10 deletions)

@@ -50,16 +50,6 @@ private[spark] class ResultTask[T, U](
   // TODO: Should we also broadcast func? For that we would need a place to
   // keep a reference to it (perhaps in DAGScheduler's job object).
 
-  def this(
-      stageId: Int,
-      rdd: RDD[T],
-      func: (TaskContext, Iterator[T]) => U,
-      partitionId: Int,
-      locs: Seq[TaskLocation],
-      outputId: Int) = {
-    this(stageId, rdd.broadcasted, func, rdd.partitions(partitionId), locs, outputId)
-  }
-
   @transient private[this] val preferredLocs: Seq[TaskLocation] = {
     if (locs == null) Nil else locs.toSet.toSeq
   }

core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala (9 changes: 0 additions & 9 deletions)

@@ -47,15 +47,6 @@ private[spark] class ShuffleMapTask(
   // TODO: Should we also broadcast the ShuffleDependency? For that we would need a place to
   // keep a reference to it (perhaps in Stage).
 
-  def this(
-      stageId: Int,
-      rdd: RDD[_],
-      dep: ShuffleDependency[_, _, _],
-      partitionId: Int,
-      locs: Seq[TaskLocation]) = {
-    this(stageId, rdd.broadcasted, dep, rdd.partitions(partitionId), locs)
-  }
-
   /** A constructor used only in test suites. This does not require passing in an RDD. */
   def this(partitionId: Int) {
     this(0, null, null, new Partition { override def index = 0 }, null)
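
Both removed auxiliary constructors existed only to bridge the old call sites: they accepted an RDD plus a partition index and reached into rdd.broadcasted themselves. With the cached lazy val gone, an equivalent bridge calling createBroadcastBinary() would serialize and broadcast the RDD once per task constructed, which is exactly the duplication the DAGScheduler change avoids. A toy sketch of the cost difference, with a counter standing in for the broadcast machinery:

    // Hypothetical cost sketch; the counter stands in for broadcast creation.
    object CostSketch extends App {
      var broadcasts = 0
      def createBroadcastBinary(): Int = { broadcasts += 1; broadcasts }

      // Old shape: a per-task bridge would trigger one broadcast per task.
      val perTask = (0 until 4).map(_ => createBroadcastBinary())
      assert(broadcasts == 4)

      // New shape: one broadcast, shared by all tasks of the submission.
      broadcasts = 0
      val shared = createBroadcastBinary()
      val tasks = (0 until 4).map(p => (shared, p))
      assert(broadcasts == 1)
    }
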

core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala (2 changes: 1 addition & 1 deletion)

@@ -146,7 +146,7 @@ class ContextCleanerSuite extends FunSuite with BeforeAndAfter with LocalSparkContext {

     // Test that GC causes broadcast task data cleanup after dereferencing the RDD.
     val postGCTester = new CleanerTester(sc,
-      broadcastIds = Seq(rdd.broadcasted.id, rdd.firstParent.broadcasted.id))
+      broadcastIds = Seq(rdd.createBroadcastBinary.id, rdd.firstParent.createBroadcastBinary.id))
     rdd = null
     runGC()
     postGCTester.assertCleanup()
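
The suite previously read broadcast ids off the cached rdd.broadcasted; each createBroadcastBinary call here instead registers a brand-new broadcast whose id the tester then watches. The test's premise is weak-reference-driven cleanup: once rdd = null drops the last strong reference and a GC runs, the tracked broadcasts become eligible for removal. A plain-JVM sketch of that mechanism (not the Spark cleaner itself):

    import java.lang.ref.WeakReference

    object GcSketch extends App {
      var payload: Array[Byte] = new Array[Byte](1 << 20)
      val weak = new WeakReference(payload)
      payload = null                   // like `rdd = null` in the suite
      System.gc()                      // like the suite's runGC() helper
      Thread.sleep(100)
      // Usually true after GC, though the JVM makes no hard guarantee.
      println(s"collected = ${weak.get == null}")
    }
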

core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala (7 changes: 4 additions & 3 deletions)

@@ -38,13 +38,14 @@ class TaskContextSuite extends FunSuite with BeforeAndAfter with LocalSparkContext {
         sys.error("failed")
       }
     }
-    val func = (c: TaskContext, i: Iterator[String]) => i.next
-    val task = new ResultTask[String, String](0, rdd, func, 0, Seq(), 0)
+    val func = (c: TaskContext, i: Iterator[String]) => i.next()
+    val task = new ResultTask[String, String](
+      0, rdd.createBroadcastBinary(), func, rdd.partitions(0), Seq(), 0)
     intercept[RuntimeException] {
       task.run(0)
     }
     assert(completed === true)
   }
 
-  case class StubPartition(val index: Int) extends Partition
+  case class StubPartition(index: Int) extends Partition
 }
