[SPARK-6986][SQL] Use Serializer2 in more cases. #5849

Closed · wants to merge 5 commits
Exchange.scala
@@ -84,18 +84,8 @@ case class Exchange(
def serializer(
keySchema: Array[DataType],
valueSchema: Array[DataType],
hasKeyOrdering: Boolean,
numPartitions: Int): Serializer = {
// In ExternalSorter's spillToMergeableFile function, key-value pairs are written out
// through write(key) and then write(value) instead of write((key, value)). Because
// SparkSqlSerializer2 assumes that objects passed in are Product2, we cannot safely use
// it when spillToMergeableFile in ExternalSorter will be used.
// So, we will not use SparkSqlSerializer2 when
// - Sort-based shuffle is enabled and the number of reducers (numPartitions) is greater
// then the bypassMergeThreshold; or
// - newOrdering is defined.
val cannotUseSqlSerializer2 =
(sortBasedShuffleOn && numPartitions > bypassMergeThreshold) || newOrdering.nonEmpty

// It is true when there is no field that needs to be written out.
// For now, we will not use SparkSqlSerializer2 when noField is true.
val noField =
@@ -104,14 +94,13 @@ case class Exchange(

val useSqlSerializer2 =
child.sqlContext.conf.useSqlSerializer2 && // SparkSqlSerializer2 is enabled.
!cannotUseSqlSerializer2 && // Safe to use Serializer2.
SparkSqlSerializer2.support(keySchema) && // The schema of key is supported.
SparkSqlSerializer2.support(valueSchema) && // The schema of value is supported.
!noField

val serializer = if (useSqlSerializer2) {
logInfo("Using SparkSqlSerializer2.")
new SparkSqlSerializer2(keySchema, valueSchema)
new SparkSqlSerializer2(keySchema, valueSchema, hasKeyOrdering)
} else {
logInfo("Using SparkSqlSerializer.")
new SparkSqlSerializer(sparkConf)
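The comment deleted above explains the restriction this patch lifts: SparkSqlSerializer2's writeObject assumes every record it receives is a Product2 (a key/value pair), while ExternalSorter's spill path writes keys and values as separate objects, so the serializer used to be disabled whenever spilling (or a newOrdering) was possible. A toy sketch of that assumption, with simplified signatures rather than Spark's real SerializationStream API:

```scala
// Minimal sketch (hypothetical names): Serializer2-style writeObject assumes pairs.
object PairAssumptionDemo {
  def writeObject(t: Any): Unit = {
    val kv = t.asInstanceOf[Product2[Any, Any]] // same cast Serializer2 performs
    println(s"key = ${kv._1}, value = ${kv._2}")
  }

  def main(args: Array[String]): Unit = {
    writeObject(("k", "v")) // fine: the shuffle writer passes (key, value) tuples
    // A spill path that calls writeObject(key) and then writeObject(value) separately
    // would hit a ClassCastException on the cast above, which is why the old code
    // turned Serializer2 off when ExternalSorter might spill or newOrdering was set.
  }
}
```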
@@ -154,7 +143,8 @@ case class Exchange(
}
val keySchema = expressions.map(_.dataType).toArray
val valueSchema = child.output.map(_.dataType).toArray
shuffled.setSerializer(serializer(keySchema, valueSchema, numPartitions))
shuffled.setSerializer(
serializer(keySchema, valueSchema, newOrdering.nonEmpty, numPartitions))

shuffled.map(_._2)

@@ -179,7 +169,8 @@ case class Exchange(
new ShuffledRDD[Row, Null, Null](rdd, part)
}
val keySchema = child.output.map(_.dataType).toArray
shuffled.setSerializer(serializer(keySchema, null, numPartitions))
shuffled.setSerializer(
serializer(keySchema, null, newOrdering.nonEmpty, numPartitions))

shuffled.map(_._1)

@@ -199,7 +190,7 @@ case class Exchange(
val partitioner = new HashPartitioner(1)
val shuffled = new ShuffledRDD[Null, Row, Row](rdd, partitioner)
val valueSchema = child.output.map(_.dataType).toArray
shuffled.setSerializer(serializer(null, valueSchema, 1))
shuffled.setSerializer(serializer(null, valueSchema, false, 1))
shuffled.map(_._2)

case _ => sys.error(s"Exchange not implemented for $newPartitioning")
SparkSqlSerializer2.scala
@@ -27,7 +27,7 @@ import scala.reflect.ClassTag
import org.apache.spark.serializer._
import org.apache.spark.Logging
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.expressions.SpecificMutableRow
import org.apache.spark.sql.catalyst.expressions.{SpecificMutableRow, MutableRow, GenericMutableRow}
import org.apache.spark.sql.types._

/**
@@ -49,9 +49,9 @@ private[sql] class Serializer2SerializationStream(
out: OutputStream)
extends SerializationStream with Logging {

val rowOut = new DataOutputStream(out)
val writeKeyFunc = SparkSqlSerializer2.createSerializationFunction(keySchema, rowOut)
val writeValueFunc = SparkSqlSerializer2.createSerializationFunction(valueSchema, rowOut)
private val rowOut = new DataOutputStream(new BufferedOutputStream(out))
private val writeKeyFunc = SparkSqlSerializer2.createSerializationFunction(keySchema, rowOut)
private val writeValueFunc = SparkSqlSerializer2.createSerializationFunction(valueSchema, rowOut)

override def writeObject[T: ClassTag](t: T): SerializationStream = {
val kv = t.asInstanceOf[Product2[Row, Row]]
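The serialization stream now wraps the raw OutputStream in a BufferedOutputStream before handing it to DataOutputStream (and marks the stream and write functions private). DataOutputStream issues several tiny writes per primitive field, so buffering batches those calls before they reach the underlying shuffle stream. A standalone sketch of the wrapping, using a ByteArrayOutputStream as a stand-in for the real output stream:

```scala
// Buffering sketch: many tiny primitive writes, flushed through one buffer.
import java.io.{BufferedOutputStream, ByteArrayOutputStream, DataOutputStream}

val underlying = new ByteArrayOutputStream()   // stand-in for the shuffle output stream
val rowOut = new DataOutputStream(new BufferedOutputStream(underlying))
rowOut.writeInt(42)        // one call per column, as in the serialization functions
rowOut.writeDouble(3.14)
rowOut.flush()             // push buffered bytes to the underlying stream
println(underlying.size()) // 12 bytes: a 4-byte Int plus an 8-byte Double
```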
@@ -86,41 +86,55 @@ private[sql] class Serializer2SerializationStream(
private[sql] class Serializer2DeserializationStream(
keySchema: Array[DataType],
valueSchema: Array[DataType],
hasKeyOrdering: Boolean,
in: InputStream)
extends DeserializationStream with Logging {

val rowIn = new DataInputStream(new BufferedInputStream(in))
private val rowIn = new DataInputStream(new BufferedInputStream(in))

private def rowGenerator(schema: Array[DataType]): () => (MutableRow) = {
if (schema == null) {
() => null
} else {
if (hasKeyOrdering) {
// We have key ordering specified in a ShuffledRDD, it is not safe to reuse a mutable row.
() => new GenericMutableRow(schema.length)
} else {
// It is safe to reuse the mutable row.
val mutableRow = new SpecificMutableRow(schema)
() => mutableRow
}
}
}

val key = if (keySchema != null) new SpecificMutableRow(keySchema) else null
val value = if (valueSchema != null) new SpecificMutableRow(valueSchema) else null
val readKeyFunc = SparkSqlSerializer2.createDeserializationFunction(keySchema, rowIn, key)
val readValueFunc = SparkSqlSerializer2.createDeserializationFunction(valueSchema, rowIn, value)
// Functions used to return rows for key and value.
private val getKey = rowGenerator(keySchema)
private val getValue = rowGenerator(valueSchema)
// Functions used to read a serialized row from the InputStream and deserialize it.
private val readKeyFunc = SparkSqlSerializer2.createDeserializationFunction(keySchema, rowIn)
private val readValueFunc = SparkSqlSerializer2.createDeserializationFunction(valueSchema, rowIn)

override def readObject[T: ClassTag](): T = {
readKeyFunc()
readValueFunc()

(key, value).asInstanceOf[T]
(readKeyFunc(getKey()), readValueFunc(getValue())).asInstanceOf[T]
}

override def readKey[T: ClassTag](): T = {
readKeyFunc()
key.asInstanceOf[T]
readKeyFunc(getKey()).asInstanceOf[T]
}

override def readValue[T: ClassTag](): T = {
readValueFunc()
value.asInstanceOf[T]
readValueFunc(getValue()).asInstanceOf[T]
}

override def close(): Unit = {
rowIn.close()
}
}

private[sql] class ShuffleSerializerInstance(
private[sql] class SparkSqlSerializer2Instance(
keySchema: Array[DataType],
valueSchema: Array[DataType])
valueSchema: Array[DataType],
hasKeyOrdering: Boolean)
extends SerializerInstance {

def serialize[T: ClassTag](t: T): ByteBuffer =
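The rowGenerator change above is the heart of the hasKeyOrdering flag: when the ShuffledRDD has a key ordering, deserialized rows may be buffered downstream for sorting, so handing out the same SpecificMutableRow for every record would silently corrupt earlier entries; without an ordering, each row is consumed before the next is read and reuse is considered safe. A toy illustration of the aliasing problem, with an Array[Any] standing in for a mutable row:

```scala
// Why reusing one mutable row breaks once records are buffered for ordering.
import scala.collection.mutable.ArrayBuffer

val reused = new Array[Any](1)                 // stand-in for a single reused mutable row
val buffered = new ArrayBuffer[Array[Any]]()   // stand-in for a sorter's in-memory buffer
for (v <- Seq(1, 2, 3)) {
  reused(0) = v        // the deserializer overwrites the same object in place
  buffered += reused   // the buffer keeps only a reference to that object
}
println(buffered.map(_(0)))  // ArrayBuffer(3, 3, 3): every buffered "row" holds the last value
// Allocating a fresh GenericMutableRow per record (the hasKeyOrdering branch) avoids this.
```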
@@ -137,7 +151,7 @@
}

def deserializeStream(s: InputStream): DeserializationStream = {
new Serializer2DeserializationStream(keySchema, valueSchema, s)
new Serializer2DeserializationStream(keySchema, valueSchema, hasKeyOrdering, s)
}
}

@@ -148,12 +162,16 @@ private[sql] class ShuffleSerializerInstance(
* The schema of keys is represented by `keySchema` and that of values is represented by
* `valueSchema`.
*/
private[sql] class SparkSqlSerializer2(keySchema: Array[DataType], valueSchema: Array[DataType])
private[sql] class SparkSqlSerializer2(
keySchema: Array[DataType],
valueSchema: Array[DataType],
hasKeyOrdering: Boolean)
extends Serializer
with Logging
with Serializable{

def newInstance(): SerializerInstance = new ShuffleSerializerInstance(keySchema, valueSchema)
def newInstance(): SerializerInstance =
new SparkSqlSerializer2Instance(keySchema, valueSchema, hasKeyOrdering)

override def supportsRelocationOfSerializedObjects: Boolean = {
// SparkSqlSerializer2 is stateless and writes no stream headers
@@ -323,11 +341,11 @@ private[sql] object SparkSqlSerializer2 {
*/
def createDeserializationFunction(
schema: Array[DataType],
in: DataInputStream,
mutableRow: SpecificMutableRow): () => Unit = {
() => {
// If the schema is null, the returned function does nothing when it get called.
if (schema != null) {
in: DataInputStream): (MutableRow) => Row = {
if (schema == null) {
(mutableRow: MutableRow) => null
} else {
(mutableRow: MutableRow) => {
var i = 0
while (i < schema.length) {
schema(i) match {
Expand Down Expand Up @@ -440,6 +458,8 @@ private[sql] object SparkSqlSerializer2 {
}
i += 1
}

mutableRow
}
}
}
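createDeserializationFunction now returns a (MutableRow) => Row closure and decides up front, rather than on every record, whether there is anything to read; the caller supplies the row to fill (the getKey() / getValue() generators above) and gets the populated row back. A simplified sketch of that shape, with Int standing in for DataType and Array[Any] for MutableRow:

```scala
// Hoist the "is there a schema at all?" branch out of the per-record closure.
def createReader(schema: Array[Int]): Array[Any] => Array[Any] = {
  if (schema == null) {
    (row: Array[Any]) => null       // nothing to deserialize for this side of the pair
  } else {
    (row: Array[Any]) => {
      var i = 0
      while (i < schema.length) {
        row(i) = schema(i)          // stand-in for the per-DataType reads from rowIn
        i += 1
      }
      row                           // return the filled row, as the new code does
    }
  }
}

val readValueFunc = createReader(Array(1, 2, 3))
println(readValueFunc(new Array[Any](3)).mkString(", "))  // 1, 2, 3
```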
SparkSqlSerializer2Suite.scala
@@ -148,6 +148,15 @@ abstract class SparkSqlSerializer2Suite extends QueryTest with BeforeAndAfterAll
table("shuffle").collect())
}

test("key schema is null") {
val aggregations = allColumns.split(",").map(c => s"COUNT($c)").mkString(",")
val df = sql(s"SELECT $aggregations FROM shuffle")
checkSerializer(df.queryExecution.executedPlan, serializerClass)
checkAnswer(
df,
Row(1000, 1000, 0, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000))
}

test("value schema is null") {
val df = sql(s"SELECT col0 FROM shuffle ORDER BY col0")
checkSerializer(df.queryExecution.executedPlan, serializerClass)
@@ -167,29 +176,20 @@ class SparkSqlSerializer2SortShuffleSuite extends SparkSqlSerializer2Suite {
override def beforeAll(): Unit = {
super.beforeAll()
// Sort merge will not be triggered.
sql("set spark.sql.shuffle.partitions = 200")
}

test("key schema is null") {
val aggregations = allColumns.split(",").map(c => s"COUNT($c)").mkString(",")
val df = sql(s"SELECT $aggregations FROM shuffle")
checkSerializer(df.queryExecution.executedPlan, serializerClass)
checkAnswer(
df,
Row(1000, 1000, 0, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000))
val bypassMergeThreshold =
sparkContext.conf.getInt("spark.shuffle.sort.bypassMergeThreshold", 200)
sql(s"set spark.sql.shuffle.partitions=${bypassMergeThreshold-1}")
}
}

/** For now, we will use SparkSqlSerializer for sort based shuffle with sort merge. */
class SparkSqlSerializer2SortMergeShuffleSuite extends SparkSqlSerializer2Suite {

// We are expecting SparkSqlSerializer.
override val serializerClass: Class[Serializer] =
classOf[SparkSqlSerializer].asInstanceOf[Class[Serializer]]

override def beforeAll(): Unit = {
super.beforeAll()
// To trigger the sort merge.
sql("set spark.sql.shuffle.partitions = 201")
val bypassMergeThreshold =
sparkContext.conf.getInt("spark.shuffle.sort.bypassMergeThreshold", 200)
sql(s"set spark.sql.shuffle.partitions=${bypassMergeThreshold + 1}")
}
}
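Finally, both beforeAll overrides now derive spark.sql.shuffle.partitions from the configured spark.shuffle.sort.bypassMergeThreshold instead of hard-coding 200 and 201, so each suite keeps exercising the shuffle path it is meant to test even if that threshold's default changes. A small sketch of the arithmetic, assuming the 200 fallback used above:

```scala
// Assumed default; the suites read the live value from sparkContext.conf instead.
val bypassMergeThreshold = 200

// threshold - 1 partitions: sort merge is not triggered (SparkSqlSerializer2SortShuffleSuite).
val withoutSortMerge = bypassMergeThreshold - 1   // 199

// threshold + 1 partitions: the sort-merge path is triggered (SortMergeShuffleSuite).
val withSortMerge = bypassMergeThreshold + 1      // 201

println(s"set spark.sql.shuffle.partitions=$withoutSortMerge")
println(s"set spark.sql.shuffle.partitions=$withSortMerge")
```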