use LocalSparkSession
cloud-fan committed Jun 11, 2018
1 parent 3f59ca2 commit 04d7f78
Showing 2 changed files with 34 additions and 55 deletions.
@@ -30,11 +30,15 @@ trait LocalSparkSession extends BeforeAndAfterEach with BeforeAndAfterAll { self
   override def beforeAll() {
     super.beforeAll()
     InternalLoggerFactory.setDefaultFactory(Slf4JLoggerFactory.INSTANCE)
+    SparkSession.clearActiveSession()
+    SparkSession.clearDefaultSession()
   }
 
   override def afterEach() {
     try {
       resetSparkContext()
+      SparkSession.clearActiveSession()
+      SparkSession.clearDefaultSession()
     } finally {
       super.afterEach()
     }
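As context for the hunk above: a suite that mixes in LocalSparkSession only needs to assign the trait's spark variable, and the beforeAll/afterEach hooks shown here clear the active and default sessions around each test. A minimal sketch, assuming the trait exposes the spark var the way UnsafeRowSerializerSuite uses it below (ExampleSuite and its test body are illustrative, not part of this commit):

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.{LocalSparkSession, SparkSession}

class ExampleSuite extends SparkFunSuite with LocalSparkSession {
  test("runs against a fresh SparkSession") {
    // Assign the trait's `spark` var; no manual stop() or clearActiveSession() is needed,
    // because LocalSparkSession.afterEach resets the session and clears the references.
    spark = SparkSession.builder().master("local").appName("example").getOrCreate()
    assert(spark.range(10).count() === 10L)
  }
}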
@@ -23,12 +23,11 @@ import java.util.Properties
 import org.apache.spark._
 import org.apache.spark.memory.TaskMemoryManager
 import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{Row, SparkSession}
+import org.apache.spark.sql.{LocalSparkSession, Row, SparkSession}
 import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
 import org.apache.spark.sql.catalyst.expressions.{UnsafeProjection, UnsafeRow}
 import org.apache.spark.sql.types._
 import org.apache.spark.storage.ShuffleBlockId
-import org.apache.spark.util.Utils
 import org.apache.spark.util.collection.ExternalSorter
 
 /**
@@ -42,15 +41,7 @@ class ClosableByteArrayInputStream(buf: Array[Byte]) extends ByteArrayInputStrea
   }
 }
 
-class UnsafeRowSerializerSuite extends SparkFunSuite with LocalSparkContext {
-
-  override def beforeAll() {
-    super.beforeAll()
-    // This test suite calls `UnsafeProjection.create` which accesses `SQLConf.get`, we should make
-    // sure active session is cleaned so that `SQLConf.get` won't refer to a stopped session.
-    SparkSession.clearActiveSession()
-    SparkSession.clearDefaultSession()
-  }
+class UnsafeRowSerializerSuite extends SparkFunSuite with LocalSparkSession {
 
   private def toUnsafeRow(row: Row, schema: Array[DataType]): UnsafeRow = {
     val converter = unsafeRowConverter(schema)
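The comment deleted in this hunk carries the rationale for the whole change: `UnsafeProjection.create` reads `SQLConf.get`, which resolves against the active SparkSession, so a stopped session left registered as active can hand out stale configuration. A rough sketch of that failure mode and of the cleanup now done centrally by LocalSparkSession (the flow and object name are illustrative, not code from this commit):

import org.apache.spark.sql.SparkSession

object StaleSessionDemo {
  def main(args: Array[String]): Unit = {
    val session = SparkSession.builder().master("local").appName("stale-session-demo").getOrCreate()
    // ... code under test reads SQLConf.get, which resolves to this session's conf ...
    session.stop()
    // If the stopped session stays registered as active/default, later SQLConf.get lookups
    // may still resolve to it; clearing the references avoids serving its stale settings.
    SparkSession.clearActiveSession()
    SparkSession.clearDefaultSession()
  }
}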
@@ -104,59 +95,43 @@ class UnsafeRowSerializerSuite extends SparkFunSuite with LocalSparkContext {
   }
 
   test("SPARK-10466: external sorter spilling with unsafe row serializer") {
-    var sc: SparkContext = null
-    var outputFile: File = null
-    val oldEnv = SparkEnv.get // save the old SparkEnv, as it will be overwritten
-    Utils.tryWithSafeFinally {
-      val conf = new SparkConf()
-        .set("spark.shuffle.spill.initialMemoryThreshold", "1")
-        .set("spark.shuffle.sort.bypassMergeThreshold", "0")
-        .set("spark.testing.memory", "80000")
-
-      sc = new SparkContext("local", "test", conf)
-      outputFile = File.createTempFile("test-unsafe-row-serializer-spill", "")
-      // prepare data
-      val converter = unsafeRowConverter(Array(IntegerType))
-      val data = (1 to 10000).iterator.map { i =>
-        (i, converter(Row(i)))
-      }
-      val taskMemoryManager = new TaskMemoryManager(sc.env.memoryManager, 0)
-      val taskContext = new TaskContextImpl(0, 0, 0, 0, 0, taskMemoryManager, new Properties, null)
-
-      val sorter = new ExternalSorter[Int, UnsafeRow, UnsafeRow](
-        taskContext,
-        partitioner = Some(new HashPartitioner(10)),
-        serializer = new UnsafeRowSerializer(numFields = 1))
-
-      // Ensure we spilled something and have to merge them later
-      assert(sorter.numSpills === 0)
-      sorter.insertAll(data)
-      assert(sorter.numSpills > 0)
+    val conf = new SparkConf()
+      .set("spark.shuffle.spill.initialMemoryThreshold", "1")
+      .set("spark.shuffle.sort.bypassMergeThreshold", "0")
+      .set("spark.testing.memory", "80000")
+    spark = SparkSession.builder().master("local").appName("test").config(conf).getOrCreate()
+    val outputFile = File.createTempFile("test-unsafe-row-serializer-spill", "")
+    outputFile.deleteOnExit()
+    // prepare data
+    val converter = unsafeRowConverter(Array(IntegerType))
+    val data = (1 to 10000).iterator.map { i =>
+      (i, converter(Row(i)))
+    }
+    val taskMemoryManager = new TaskMemoryManager(spark.sparkContext.env.memoryManager, 0)
+    val taskContext = new TaskContextImpl(0, 0, 0, 0, 0, taskMemoryManager, new Properties, null)
 
-      // Merging spilled files should not throw assertion error
-      sorter.writePartitionedFile(ShuffleBlockId(0, 0, 0), outputFile)
-    } {
-      // Clean up
-      if (sc != null) {
-        sc.stop()
-      }
+    val sorter = new ExternalSorter[Int, UnsafeRow, UnsafeRow](
+      taskContext,
+      partitioner = Some(new HashPartitioner(10)),
+      serializer = new UnsafeRowSerializer(numFields = 1))
 
-      // restore the spark env
-      SparkEnv.set(oldEnv)
+    // Ensure we spilled something and have to merge them later
+    assert(sorter.numSpills === 0)
+    sorter.insertAll(data)
+    assert(sorter.numSpills > 0)
 
-      if (outputFile != null) {
-        outputFile.delete()
-      }
-    }
+    // Merging spilled files should not throw assertion error
+    sorter.writePartitionedFile(ShuffleBlockId(0, 0, 0), outputFile)
   }
 
   test("SPARK-10403: unsafe row serializer with SortShuffleManager") {
     val conf = new SparkConf().set("spark.shuffle.manager", "sort")
-    sc = new SparkContext("local", "test", conf)
+    spark = SparkSession.builder().master("local").appName("test").config(conf).getOrCreate()
     val row = Row("Hello", 123)
     val unsafeRow = toUnsafeRow(row, Array(StringType, IntegerType))
-    val rowsRDD = sc.parallelize(Seq((0, unsafeRow), (1, unsafeRow), (0, unsafeRow)))
-      .asInstanceOf[RDD[Product2[Int, InternalRow]]]
+    val rowsRDD = spark.sparkContext.parallelize(
+      Seq((0, unsafeRow), (1, unsafeRow), (0, unsafeRow))
+    ).asInstanceOf[RDD[Product2[Int, InternalRow]]]
     val dependency =
       new ShuffleDependency[Int, InternalRow, InternalRow](
         rowsRDD,