diff --git a/core/src/test/scala/org/apache/spark/shuffle/ShuffleDriverComponentsSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/ShuffleDriverComponentsSuite.scala index dbb954945a8b6..e8372c0458600 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/ShuffleDriverComponentsSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/ShuffleDriverComponentsSuite.scala @@ -22,8 +22,9 @@ import java.util import com.google.common.collect.ImmutableMap import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkEnv, SparkFunSuite} -import org.apache.spark.api.shuffle.{ShuffleDataIO, ShuffleDriverComponents, ShuffleExecutorComponents, ShuffleWriteSupport} +import org.apache.spark.api.shuffle.{ShuffleDataIO, ShuffleDriverComponents, ShuffleExecutorComponents, ShuffleReadSupport, ShuffleWriteSupport} import org.apache.spark.internal.config.SHUFFLE_IO_PLUGIN_CLASS +import org.apache.spark.shuffle.io.DefaultShuffleReadSupport import org.apache.spark.shuffle.sort.io.DefaultShuffleWriteSupport class ShuffleDriverComponentsSuite extends SparkFunSuite with LocalSparkContext { @@ -66,6 +67,13 @@ class TestShuffleExecutorComponents(sparkConf: SparkConf) extends ShuffleExecuto override def writes(): ShuffleWriteSupport = { val blockManager = SparkEnv.get.blockManager val blockResolver = new IndexShuffleBlockResolver(sparkConf, blockManager) - new DefaultShuffleWriteSupport(sparkConf, blockResolver) + new DefaultShuffleWriteSupport(sparkConf, blockResolver, blockManager.shuffleServerId) + } + + override def reads(): ShuffleReadSupport = { + val blockManager = SparkEnv.get.blockManager + val mapOutputTracker = SparkEnv.get.mapOutputTracker + val serializerManager = SparkEnv.get.serializerManager + new DefaultShuffleReadSupport(blockManager, mapOutputTracker, serializerManager, sparkConf) } }