Skip to content

Commit

Permalink
Fixed bugs and added support for writeObject
Browse files Browse the repository at this point in the history
  • Loading branch information
tdas committed Jun 3, 2015
1 parent cafd505 commit 50a608d
Show file tree
Hide file tree
Showing 2 changed files with 133 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

package org.apache.spark.serializer

import java.io.{NotSerializableException, ObjectOutput, ObjectStreamClass, ObjectStreamField}
import java.io._
import java.lang.reflect.{Field, Method}
import java.security.AccessController

Expand Down Expand Up @@ -145,17 +145,25 @@ private[spark] object SerializationDebugger extends Logging {
// An object contains multiple slots in serialization.
// Get the slots and visit fields in all of them.
val (finalObj, desc) = findObjectAndDescriptor(o)

if (!finalObj.eq(o)) {
return visit(finalObj, s"writeReplace data (class: ${finalObj.getClass.getName})" :: stack)
}

val slotDescs = desc.getSlotDescs
var i = 0
while (i < slotDescs.length) {
val slotDesc = slotDescs(i)
if (slotDesc.hasWriteObjectMethod) {
// TODO: Handle classes that specify writeObject method.
val childStack = visitSerializableWithWriteObjectMethod(finalObj, slotDesc, stack)
if (childStack.nonEmpty) {
return childStack
}
} else {
val fields: Array[ObjectStreamField] = slotDesc.getFields
val objFieldValues: Array[Object] = new Array[Object](slotDesc.getNumObjFields)
val numPrims = fields.length - objFieldValues.length
desc.getObjFieldValues(finalObj, objFieldValues)
slotDesc.getObjFieldValues(finalObj, objFieldValues)

var j = 0
while (j < objFieldValues.length) {
Expand All @@ -169,12 +177,39 @@ private[spark] object SerializationDebugger extends Logging {
}
j += 1
}

}
i += 1
}
return List.empty
}

private def visitSerializableWithWriteObjectMethod(
o: Object, slotDesc: ObjectStreamClass, stack: List[String]): List[String] = {
println(">>> processing serializable with writeObject" + o)
val innerObjectsCatcher = new ListObjectOutputStream
var notSerializableFound = false
try {
innerObjectsCatcher.writeObject(o)
} catch {
case io: IOException =>
notSerializableFound = true
}
if (notSerializableFound) {
val innerObjects = innerObjectsCatcher.outputArray
var k = 0
while (k < innerObjects.length) {
val elem = s"writeObject data (class: ${slotDesc.getName})"
val childStack = visit(innerObjects(k), elem :: stack)
if (childStack.nonEmpty) {
return childStack
}
k += 1
}
} else {
visited ++= innerObjectsCatcher.outputArray
}
return List.empty
}
}

/**
Expand Down Expand Up @@ -220,6 +255,27 @@ private[spark] object SerializationDebugger extends Logging {
override def writeByte(i: Int): Unit = {}
}

/** An output stream that emulates /dev/null */
private class NullOutputStream extends OutputStream {
override def write(b: Int) { }
}

/**
* A dummy [[ObjectOutputStream]] that saves the list of objects written to it and returns
* them through `outputArray`.
*/
private class ListObjectOutputStream extends ObjectOutputStream(new NullOutputStream) {
private val output = new mutable.ArrayBuffer[Any]
this.enableReplaceObject(true)

def outputArray: Array[Any] = output.toArray

override def replaceObject(obj: Object): Object = {
output += obj
obj
}
}

/** An implicit class that allows us to call private methods of ObjectStreamClass. */
implicit class ObjectStreamClassMethods(val desc: ObjectStreamClass) extends AnyVal {
def getSlotDescs: Array[ObjectStreamClass] = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

package org.apache.spark.serializer

import java.io.{ObjectOutput, ObjectInput}
import java.io.{IOException, ObjectOutputStream, ObjectOutput, ObjectInput}

import org.scalatest.BeforeAndAfterEach

Expand Down Expand Up @@ -98,14 +98,59 @@ class SerializationDebuggerSuite extends SparkFunSuite with BeforeAndAfterEach {
}

test("externalizable class writing out not serializable object") {
val s = find(new ExternalizableClass)
val s = find(new ExternalizableClass(new SerializableClass2(new NotSerializable)))
assert(s.size === 5)
assert(s(0).contains("NotSerializable"))
assert(s(1).contains("objectField"))
assert(s(2).contains("SerializableClass2"))
assert(s(3).contains("writeExternal"))
assert(s(4).contains("ExternalizableClass"))
}

test("externalizable class writing out serializable objects") {
assert(find(new ExternalizableClass(new SerializableClass1)).isEmpty)
}

test("object containing writeReplace() which returns not serializable object") {
val s = find(new SerializableClassWithWriteReplace(new NotSerializable))
println("-----\n" + s.zipWithIndex.mkString("\n") + "\n----")
assert(s.size === 3)
assert(s(0).contains("NotSerializable"))
assert(s(1).contains("writeReplace"))
assert(s(2).contains("SerializableClassWithWriteReplace"))
}

test("object containing writeObject() and not serializable field") {
val s = find(new SerializableClassWithWriteObject(new NotSerializable))
println("-----\n" + s.zipWithIndex.mkString("\n") + "\n----")
assert(s.size === 3)
assert(s(0).contains("NotSerializable"))
assert(s(1).contains("writeObject data"))
assert(s(2).contains("SerializableClassWithWriteObject"))
}

test("object containing writeObject() and serializable field") {
assert(find(new SerializableClassWithWriteObject(new SerializableClass1)).isEmpty)
}


test("object of serializable subclass with more fields than superclass (SPARK-7180)") {
// This should not throw ArrayOutOfBoundsException
find(new SerializableSubclass(new SerializableClass1))
}

test("crazy nested objects") {
val s = find(
new SerializableClassWithWriteReplace(
new ExternalizableClass(
new SerializableSubclass(
new SerializableArray(
Array(new SerializableClass1, new SerializableClass2(new NotSerializable))
))))
)
assert(s.nonEmpty)
assert(s.head.contains("NotSerializable"))
}
}


Expand All @@ -118,10 +163,34 @@ class SerializableClass2(val objectField: Object) extends Serializable
class SerializableArray(val arrayField: Array[Object]) extends Serializable


class ExternalizableClass extends java.io.Externalizable {
class SerializableSubclass(val objectField: Object) extends SerializableClass1


class SerializableClassWithWriteObject(val objectField: Object) extends Serializable {
val serializableObjectField = new SerializableClass1

@throws(classOf[IOException])
private def writeObject(oos: ObjectOutputStream): Unit = {
oos.defaultWriteObject()
}
}


class SerializableClassWithWriteReplace(@transient replacementFieldObject: Object)
extends Serializable {
private def writeReplace(): Object = {
replacementFieldObject
}
}


class ExternalizableClass(objectField: Object) extends java.io.Externalizable {
val serializableObjectField = new SerializableClass1

override def writeExternal(out: ObjectOutput): Unit = {
out.writeInt(1)
out.writeObject(new SerializableClass2(new NotSerializable))
out.writeObject(serializableObjectField)
out.writeObject(objectField)
}

override def readExternal(in: ObjectInput): Unit = {}
Expand Down

0 comments on commit 50a608d

Please sign in to comment.