diff --git a/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala index fa8a337ad63a8..c5f6062a926e7 100644 --- a/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala @@ -39,7 +39,12 @@ private[spark] class JavaSerializationStream(out: OutputStream, counterReset: In * the stream 'resets' object class descriptions have to be re-written) */ def writeObject[T: ClassTag](t: T): SerializationStream = { - objOut.writeObject(t) + try { + objOut.writeObject(t) + } catch { + case e: NotSerializableException => + throw SerializationDebugger.improveException(t, e) + } counter += 1 if (counterReset > 0 && counter >= counterReset) { objOut.reset() diff --git a/core/src/main/scala/org/apache/spark/serializer/SerializationDebugger.scala b/core/src/main/scala/org/apache/spark/serializer/SerializationDebugger.scala new file mode 100644 index 0000000000000..cea7d2a864bef --- /dev/null +++ b/core/src/main/scala/org/apache/spark/serializer/SerializationDebugger.scala @@ -0,0 +1,306 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.serializer + +import java.io.{NotSerializableException, ObjectOutput, ObjectStreamClass, ObjectStreamField} +import java.lang.reflect.{Field, Method} +import java.security.AccessController + +import scala.annotation.tailrec +import scala.collection.mutable + +import org.apache.spark.Logging + +private[serializer] object SerializationDebugger extends Logging { + + /** + * Improve the given NotSerializableException with the serialization path leading from the given + * object to the problematic object. + */ + def improveException(obj: Any, e: NotSerializableException): NotSerializableException = { + if (enableDebugging && reflect != null) { + new NotSerializableException( + e.getMessage + "\nSerialization stack:\n" + find(obj).map("\t- " + _).mkString("\n")) + } else { + e + } + } + + /** + * Find the path leading to a not serializable object. This method is modeled after OpenJDK's + * serialization mechanism, and handles the following cases: + * - primitives + * - arrays of primitives + * - arrays of non-primitive objects + * - Serializable objects + * - Externalizable objects + * - writeReplace + * + * It does not yet handle writeObject override, but that shouldn't be too hard to do either. + */ + def find(obj: Any): List[String] = { + new SerializationDebugger().visit(obj, List.empty) + } + + private[serializer] var enableDebugging: Boolean = { + !AccessController.doPrivileged(new sun.security.action.GetBooleanAction( + "sun.io.serialization.extendedDebugInfo")).booleanValue() + } + + private class SerializationDebugger { + + /** A set to track the list of objects we have visited, to avoid cycles in the graph. */ + private val visited = new mutable.HashSet[Any] + + /** + * Visit the object and its fields and stop when we find an object that is not serializable. + * Return the path as a list. If everything can be serialized, return an empty list. + */ + def visit(o: Any, stack: List[String]): List[String] = { + if (o == null) { + List.empty + } else if (visited.contains(o)) { + List.empty + } else { + visited += o + o match { + // Primitive value, string, and primitive arrays are always serializable + case _ if o.getClass.isPrimitive => List.empty + case _: String => List.empty + case _ if o.getClass.isArray && o.getClass.getComponentType.isPrimitive => List.empty + + // Traverse non primitive array. + case a: Array[_] if o.getClass.isArray && !o.getClass.getComponentType.isPrimitive => + val elem = s"array (class ${a.getClass.getName}, size ${a.length})" + visitArray(o.asInstanceOf[Array[_]], elem :: stack) + + case e: java.io.Externalizable => + val elem = s"externalizable object (class ${e.getClass.getName}, $e)" + visitExternalizable(e, elem :: stack) + + case s: Object with java.io.Serializable => + val elem = s"object (class ${s.getClass.getName}, $s)" + visitSerializable(s, elem :: stack) + + case _ => + // Found an object that is not serializable! + s"object not serializable (class: ${o.getClass.getName}, value: $o)" :: stack + } + } + } + + private def visitArray(o: Array[_], stack: List[String]): List[String] = { + var i = 0 + while (i < o.length) { + val childStack = visit(o(i), s"element of array (index: $i)" :: stack) + if (childStack.nonEmpty) { + return childStack + } + i += 1 + } + return List.empty + } + + private def visitExternalizable(o: java.io.Externalizable, stack: List[String]): List[String] = + { + val fieldList = new ListObjectOutput + o.writeExternal(fieldList) + val childObjects = fieldList.outputArray + var i = 0 + while (i < childObjects.length) { + val childStack = visit(childObjects(i), "writeExternal data" :: stack) + if (childStack.nonEmpty) { + return childStack + } + i += 1 + } + return List.empty + } + + private def visitSerializable(o: Object, stack: List[String]): List[String] = { + // An object contains multiple slots in serialization. + // Get the slots and visit fields in all of them. + val (finalObj, desc) = findObjectAndDescriptor(o) + val slotDescs = desc.getSlotDescs + var i = 0 + while (i < slotDescs.length) { + val slotDesc = slotDescs(i) + if (slotDesc.hasWriteObjectMethod) { + // TODO: Handle classes that specify writeObject method. + } else { + val fields: Array[ObjectStreamField] = slotDesc.getFields + val objFieldValues: Array[Object] = new Array[Object](slotDesc.getNumObjFields) + val numPrims = fields.length - objFieldValues.length + desc.getObjFieldValues(finalObj, objFieldValues) + + var j = 0 + while (j < objFieldValues.length) { + val fieldDesc = fields(numPrims + j) + val elem = s"field (class: ${slotDesc.getName}" + + s", name: ${fieldDesc.getName}" + + s", type: ${fieldDesc.getType})" + val childStack = visit(objFieldValues(j), elem :: stack) + if (childStack.nonEmpty) { + return childStack + } + j += 1 + } + + } + i += 1 + } + return List.empty + } + } + + /** + * Find the object to serialize and the associated [[ObjectStreamClass]]. This method handles + * writeReplace in Serializable. It starts with the object itself, and keeps calling the + * writeReplace method until there is no more + */ + @tailrec + private def findObjectAndDescriptor(o: Object): (Object, ObjectStreamClass) = { + val cl = o.getClass + val desc = ObjectStreamClass.lookupAny(cl) + if (!desc.hasWriteReplaceMethod) { + (o, desc) + } else { + // write place + findObjectAndDescriptor(desc.invokeWriteReplace(o)) + } + } + + /** + * A dummy [[ObjectOutput]] that simply saves the list of objects written by a writeExternal + * call, and returns them through `outputArray`. + */ + private class ListObjectOutput extends ObjectOutput { + private val output = new mutable.ArrayBuffer[Any] + def outputArray: Array[Any] = output.toArray + override def writeObject(o: Any): Unit = output += o + override def flush(): Unit = {} + override def write(i: Int): Unit = {} + override def write(bytes: Array[Byte]): Unit = {} + override def write(bytes: Array[Byte], i: Int, i1: Int): Unit = {} + override def close(): Unit = {} + override def writeFloat(v: Float): Unit = {} + override def writeChars(s: String): Unit = {} + override def writeDouble(v: Double): Unit = {} + override def writeUTF(s: String): Unit = {} + override def writeShort(i: Int): Unit = {} + override def writeInt(i: Int): Unit = {} + override def writeBoolean(b: Boolean): Unit = {} + override def writeBytes(s: String): Unit = {} + override def writeChar(i: Int): Unit = {} + override def writeLong(l: Long): Unit = {} + override def writeByte(i: Int): Unit = {} + } + + /** An implicit class that allows us to call private methods of ObjectStreamClass. */ + implicit class ObjectStreamClassMethods(val desc: ObjectStreamClass) extends AnyVal { + def getSlotDescs: Array[ObjectStreamClass] = { + reflect.GetClassDataLayout.invoke(desc).asInstanceOf[Array[Object]].map { + classDataSlot => reflect.DescField.get(classDataSlot).asInstanceOf[ObjectStreamClass] + } + } + + def hasWriteObjectMethod: Boolean = { + reflect.HasWriteObjectMethod.invoke(desc).asInstanceOf[Boolean] + } + + def hasWriteReplaceMethod: Boolean = { + reflect.HasWriteReplaceMethod.invoke(desc).asInstanceOf[Boolean] + } + + def invokeWriteReplace(obj: Object): Object = { + reflect.InvokeWriteReplace.invoke(desc, obj) + } + + def getNumObjFields: Int = { + reflect.GetNumObjFields.invoke(desc).asInstanceOf[Int] + } + + def getObjFieldValues(obj: Object, out: Array[Object]): Unit = { + reflect.GetObjFieldValues.invoke(desc, obj, out) + } + } + + /** + * Object to hold all the reflection objects. If we run on a JVM that we cannot understand, + * this field will be null and this the debug helper should be disabled. + */ + private val reflect: ObjectStreamClassReflection = try { + new ObjectStreamClassReflection + } catch { + case e: Exception => + logWarning("Cannot find private methods using reflection", e) + null + } + + private class ObjectStreamClassReflection { + /** ObjectStreamClass.getClassDataLayout */ + val GetClassDataLayout: Method = { + val f = classOf[ObjectStreamClass].getDeclaredMethod("getClassDataLayout") + f.setAccessible(true) + f + } + + /** ObjectStreamClass.hasWriteObjectMethod */ + val HasWriteObjectMethod: Method = { + val f = classOf[ObjectStreamClass].getDeclaredMethod("hasWriteObjectMethod") + f.setAccessible(true) + f + } + + /** ObjectStreamClass.hasWriteReplaceMethod */ + val HasWriteReplaceMethod: Method = { + val f = classOf[ObjectStreamClass].getDeclaredMethod("hasWriteReplaceMethod") + f.setAccessible(true) + f + } + + /** ObjectStreamClass.invokeWriteReplace */ + val InvokeWriteReplace: Method = { + val f = classOf[ObjectStreamClass].getDeclaredMethod("invokeWriteReplace", classOf[Object]) + f.setAccessible(true) + f + } + + /** ObjectStreamClass.getNumObjFields */ + val GetNumObjFields: Method = { + val f = classOf[ObjectStreamClass].getDeclaredMethod("getNumObjFields") + f.setAccessible(true) + f + } + + /** ObjectStreamClass.getObjFieldValues */ + val GetObjFieldValues: Method = { + val f = classOf[ObjectStreamClass].getDeclaredMethod( + "getObjFieldValues", classOf[Object], classOf[Array[Object]]) + f.setAccessible(true) + f + } + + /** ObjectStreamClass$ClassDataSlot.desc field */ + val DescField: Field = { + val f = Class.forName("java.io.ObjectStreamClass$ClassDataSlot").getDeclaredField("desc") + f.setAccessible(true) + f + } + } +} diff --git a/core/src/test/scala/org/apache/spark/serializer/SerializationDebuggerSuite.scala b/core/src/test/scala/org/apache/spark/serializer/SerializationDebuggerSuite.scala new file mode 100644 index 0000000000000..e62828c4fbac6 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/serializer/SerializationDebuggerSuite.scala @@ -0,0 +1,139 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.serializer + +import java.io.{ObjectOutput, ObjectInput} + +import org.scalatest.{BeforeAndAfterEach, FunSuite} + + +class SerializationDebuggerSuite extends FunSuite with BeforeAndAfterEach { + + import SerializationDebugger.find + + override def beforeEach(): Unit = { + SerializationDebugger.enableDebugging = true + } + + test("primitives, strings, and nulls") { + assert(find(1) === List.empty) + assert(find(1L) === List.empty) + assert(find(1.toShort) === List.empty) + assert(find(1.0) === List.empty) + assert(find("1") === List.empty) + assert(find(null) === List.empty) + } + + test("primitive arrays") { + assert(find(Array[Int](1, 2)) === List.empty) + assert(find(Array[Long](1, 2)) === List.empty) + } + + test("non-primitive arrays") { + assert(find(Array("aa", "bb")) === List.empty) + assert(find(Array(new SerializableClass1)) === List.empty) + } + + test("serializable object") { + assert(find(new Foo(1, "b", 'c', 'd', null, null, null)) === List.empty) + } + + test("nested arrays") { + val foo1 = new Foo(1, "b", 'c', 'd', null, null, null) + val foo2 = new Foo(1, "b", 'c', 'd', null, Array(foo1), null) + assert(find(new Foo(1, "b", 'c', 'd', null, Array(foo2), null)) === List.empty) + } + + test("nested objects") { + val foo1 = new Foo(1, "b", 'c', 'd', null, null, null) + val foo2 = new Foo(1, "b", 'c', 'd', null, null, foo1) + assert(find(new Foo(1, "b", 'c', 'd', null, null, foo2)) === List.empty) + } + + test("cycles (should not loop forever)") { + val foo1 = new Foo(1, "b", 'c', 'd', null, null, null) + foo1.g = foo1 + assert(find(new Foo(1, "b", 'c', 'd', null, null, foo1)) === List.empty) + } + + test("root object not serializable") { + val s = find(new NotSerializable) + assert(s.size === 1) + assert(s.head.contains("NotSerializable")) + } + + test("array containing not serializable element") { + val s = find(new SerializableArray(Array(new NotSerializable))) + assert(s.size === 5) + assert(s(0).contains("NotSerializable")) + assert(s(1).contains("element of array")) + assert(s(2).contains("array")) + assert(s(3).contains("arrayField")) + assert(s(4).contains("SerializableArray")) + } + + test("object containing not serializable field") { + val s = find(new SerializableClass2(new NotSerializable)) + assert(s.size === 3) + assert(s(0).contains("NotSerializable")) + assert(s(1).contains("objectField")) + assert(s(2).contains("SerializableClass2")) + } + + test("externalizable class writing out not serializable object") { + val s = find(new ExternalizableClass) + assert(s.size === 5) + assert(s(0).contains("NotSerializable")) + assert(s(1).contains("objectField")) + assert(s(2).contains("SerializableClass2")) + assert(s(3).contains("writeExternal")) + assert(s(4).contains("ExternalizableClass")) + } +} + + +class SerializableClass1 extends Serializable + + +class SerializableClass2(val objectField: Object) extends Serializable + + +class SerializableArray(val arrayField: Array[Object]) extends Serializable + + +class ExternalizableClass extends java.io.Externalizable { + override def writeExternal(out: ObjectOutput): Unit = { + out.writeInt(1) + out.writeObject(new SerializableClass2(new NotSerializable)) + } + + override def readExternal(in: ObjectInput): Unit = {} +} + + +class Foo( + a: Int, + b: String, + c: Char, + d: Byte, + e: Array[Int], + f: Array[Object], + var g: Foo) extends Serializable + + +class NotSerializable