diff --git a/core/src/main/scala/org/apache/spark/network/nio/Connection.scala b/core/src/main/scala/org/apache/spark/network/nio/Connection.scala index f368209980f93..4f6f5e235811d 100644 --- a/core/src/main/scala/org/apache/spark/network/nio/Connection.scala +++ b/core/src/main/scala/org/apache/spark/network/nio/Connection.scala @@ -20,11 +20,14 @@ package org.apache.spark.network.nio import java.net._ import java.nio._ import java.nio.channels._ +import java.util.concurrent.ConcurrentLinkedQueue import java.util.LinkedList import org.apache.spark._ +import scala.collection.JavaConversions._ import scala.collection.mutable.{ArrayBuffer, HashMap} +import scala.util.control.NonFatal private[nio] abstract class Connection(val channel: SocketChannel, val selector: Selector, @@ -51,7 +54,7 @@ abstract class Connection(val channel: SocketChannel, val selector: Selector, @volatile private var closed = false var onCloseCallback: Connection => Unit = null - var onExceptionCallback: (Connection, Exception) => Unit = null + val onExceptionCallbacks = new ConcurrentLinkedQueue[(Connection, Throwable) => Unit] var onKeyInterestChangeCallback: (Connection, Int) => Unit = null val remoteAddress = getRemoteAddress() @@ -130,20 +133,24 @@ abstract class Connection(val channel: SocketChannel, val selector: Selector, onCloseCallback = callback } - def onException(callback: (Connection, Exception) => Unit) { - onExceptionCallback = callback + def onException(callback: (Connection, Throwable) => Unit) { + onExceptionCallbacks.add(callback) } def onKeyInterestChange(callback: (Connection, Int) => Unit) { onKeyInterestChangeCallback = callback } - def callOnExceptionCallback(e: Exception) { - if (onExceptionCallback != null) { - onExceptionCallback(this, e) - } else { - logError("Error in connection to " + getRemoteConnectionManagerId() + - " and OnExceptionCallback not registered", e) + def callOnExceptionCallbacks(e: Throwable) { + onExceptionCallbacks foreach { + callback => + try { + callback(this, e) + } catch { + case NonFatal(e) => { + logWarning("Ignored error in onExceptionCallback", e) + } + } } } @@ -323,7 +330,7 @@ class SendingConnection(val address: InetSocketAddress, selector_ : Selector, } catch { case e: Exception => { logError("Error connecting to " + address, e) - callOnExceptionCallback(e) + callOnExceptionCallbacks(e) } } } @@ -348,7 +355,7 @@ class SendingConnection(val address: InetSocketAddress, selector_ : Selector, } catch { case e: Exception => { logWarning("Error finishing connection to " + address, e) - callOnExceptionCallback(e) + callOnExceptionCallbacks(e) } } true @@ -393,7 +400,7 @@ class SendingConnection(val address: InetSocketAddress, selector_ : Selector, } catch { case e: Exception => { logWarning("Error writing in connection to " + getRemoteConnectionManagerId(), e) - callOnExceptionCallback(e) + callOnExceptionCallbacks(e) close() return false } @@ -420,7 +427,7 @@ class SendingConnection(val address: InetSocketAddress, selector_ : Selector, case e: Exception => logError("Exception while reading SendingConnection to " + getRemoteConnectionManagerId(), e) - callOnExceptionCallback(e) + callOnExceptionCallbacks(e) close() } @@ -577,7 +584,7 @@ private[spark] class ReceivingConnection( } catch { case e: Exception => { logWarning("Error reading from connection to " + getRemoteConnectionManagerId(), e) - callOnExceptionCallback(e) + callOnExceptionCallbacks(e) close() return false } diff --git a/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala b/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala index 01cd27a907eea..6b00190c5eccc 100644 --- a/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala +++ b/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala @@ -34,6 +34,8 @@ import scala.language.postfixOps import org.apache.spark._ import org.apache.spark.util.Utils +import scala.util.Try +import scala.util.control.NonFatal private[nio] class ConnectionManager( port: Int, @@ -51,14 +53,23 @@ private[nio] class ConnectionManager( class MessageStatus( val message: Message, val connectionManagerId: ConnectionManagerId, - completionHandler: MessageStatus => Unit) { + completionHandler: Try[Message] => Unit) { - /** This is non-None if message has been ack'd */ - var ackMessage: Option[Message] = None + def success(ackMessage: Message) { + if (ackMessage == null) { + failure(new NullPointerException) + } + else { + completionHandler(scala.util.Success(ackMessage)) + } + } - def markDone(ackMessage: Option[Message]) { - this.ackMessage = ackMessage - completionHandler(this) + def failWithoutAck() { + completionHandler(scala.util.Failure(new IOException("Failed without being ACK'd"))) + } + + def failure(e: Throwable) { + completionHandler(scala.util.Failure(e)) } } @@ -72,14 +83,32 @@ private[nio] class ConnectionManager( conf.getInt("spark.core.connection.handler.threads.max", 60), conf.getInt("spark.core.connection.handler.threads.keepalive", 60), TimeUnit.SECONDS, new LinkedBlockingDeque[Runnable](), - Utils.namedThreadFactory("handle-message-executor")) + Utils.namedThreadFactory("handle-message-executor")) { + + override def afterExecute(r: Runnable, t: Throwable): Unit = { + super.afterExecute(r, t) + if (t != null && NonFatal(t)) { + logError("Error in handleMessageExecutor is not handled properly", t) + } + } + + } private val handleReadWriteExecutor = new ThreadPoolExecutor( conf.getInt("spark.core.connection.io.threads.min", 4), conf.getInt("spark.core.connection.io.threads.max", 32), conf.getInt("spark.core.connection.io.threads.keepalive", 60), TimeUnit.SECONDS, new LinkedBlockingDeque[Runnable](), - Utils.namedThreadFactory("handle-read-write-executor")) + Utils.namedThreadFactory("handle-read-write-executor")) { + + override def afterExecute(r: Runnable, t: Throwable): Unit = { + super.afterExecute(r, t) + if (t != null && NonFatal(t)) { + logError("Error in handleReadWriteExecutor is not handled properly", t) + } + } + + } // Use a different, yet smaller, thread pool - infrequently used with very short lived tasks : // which should be executed asap @@ -153,17 +182,24 @@ private[nio] class ConnectionManager( } handleReadWriteExecutor.execute(new Runnable { override def run() { - var register: Boolean = false try { - register = conn.write() - } finally { - writeRunnableStarted.synchronized { - writeRunnableStarted -= key - val needReregister = register || conn.resetForceReregister() - if (needReregister && conn.changeInterestForWrite()) { - conn.registerInterest() + var register: Boolean = false + try { + register = conn.write() + } finally { + writeRunnableStarted.synchronized { + writeRunnableStarted -= key + val needReregister = register || conn.resetForceReregister() + if (needReregister && conn.changeInterestForWrite()) { + conn.registerInterest() + } } } + } catch { + case NonFatal(e) => { + logError("Error when writing to " + conn.getRemoteConnectionManagerId(), e) + conn.callOnExceptionCallbacks(e) + } } } } ) @@ -187,16 +223,23 @@ private[nio] class ConnectionManager( } handleReadWriteExecutor.execute(new Runnable { override def run() { - var register: Boolean = false try { - register = conn.read() - } finally { - readRunnableStarted.synchronized { - readRunnableStarted -= key - if (register && conn.changeInterestForRead()) { - conn.registerInterest() + var register: Boolean = false + try { + register = conn.read() + } finally { + readRunnableStarted.synchronized { + readRunnableStarted -= key + if (register && conn.changeInterestForRead()) { + conn.registerInterest() + } } } + } catch { + case NonFatal(e) => { + logError("Error when reading from " + conn.getRemoteConnectionManagerId(), e) + conn.callOnExceptionCallbacks(e) + } } } } ) @@ -213,19 +256,25 @@ private[nio] class ConnectionManager( handleConnectExecutor.execute(new Runnable { override def run() { + try { + var tries: Int = 10 + while (tries >= 0) { + if (conn.finishConnect(false)) return + // Sleep ? + Thread.sleep(1) + tries -= 1 + } - var tries: Int = 10 - while (tries >= 0) { - if (conn.finishConnect(false)) return - // Sleep ? - Thread.sleep(1) - tries -= 1 + // fallback to previous behavior : we should not really come here since this method was + // triggered since channel became connectable : but at times, the first finishConnect need + // not succeed : hence the loop to retry a few 'times'. + conn.finishConnect(true) + } catch { + case NonFatal(e) => { + logError("Error when finishConnect for " + conn.getRemoteConnectionManagerId(), e) + conn.callOnExceptionCallbacks(e) + } } - - // fallback to previous behavior : we should not really come here since this method was - // triggered since channel became connectable : but at times, the first finishConnect need - // not succeed : hence the loop to retry a few 'times'. - conn.finishConnect(true) } } ) } @@ -246,16 +295,16 @@ private[nio] class ConnectionManager( handleConnectExecutor.execute(new Runnable { override def run() { try { - conn.callOnExceptionCallback(e) + conn.callOnExceptionCallbacks(e) } catch { // ignore exceptions - case e: Exception => logDebug("Ignoring exception", e) + case NonFatal(e) => logDebug("Ignoring exception", e) } try { conn.close() } catch { // ignore exceptions - case e: Exception => logDebug("Ignoring exception", e) + case NonFatal(e) => logDebug("Ignoring exception", e) } } }) @@ -448,7 +497,7 @@ private[nio] class ConnectionManager( messageStatuses.values.filter(_.connectionManagerId == sendingConnectionManagerId) .foreach(status => { logInfo("Notifying " + status) - status.markDone(None) + status.failWithoutAck() }) messageStatuses.retain((i, status) => { @@ -477,7 +526,7 @@ private[nio] class ConnectionManager( for (s <- messageStatuses.values if s.connectionManagerId == sendingConnectionManagerId) { logInfo("Notifying " + s) - s.markDone(None) + s.failWithoutAck() } messageStatuses.retain((i, status) => { @@ -492,7 +541,7 @@ private[nio] class ConnectionManager( } } - def handleConnectionError(connection: Connection, e: Exception) { + def handleConnectionError(connection: Connection, e: Throwable) { logInfo("Handling connection error on connection to " + connection.getRemoteConnectionManagerId()) removeConnection(connection) @@ -510,9 +559,17 @@ private[nio] class ConnectionManager( val runnable = new Runnable() { val creationTime = System.currentTimeMillis def run() { - logDebug("Handler thread delay is " + (System.currentTimeMillis - creationTime) + " ms") - handleMessage(connectionManagerId, message, connection) - logDebug("Handling delay is " + (System.currentTimeMillis - creationTime) + " ms") + try { + logDebug("Handler thread delay is " + (System.currentTimeMillis - creationTime) + " ms") + handleMessage(connectionManagerId, message, connection) + logDebug("Handling delay is " + (System.currentTimeMillis - creationTime) + " ms") + } catch { + case NonFatal(e) => { + logError("Error when handling messages from " + + connection.getRemoteConnectionManagerId(), e) + connection.callOnExceptionCallbacks(e) + } + } } } handleMessageExecutor.execute(runnable) @@ -651,7 +708,7 @@ private[nio] class ConnectionManager( messageStatuses.get(bufferMessage.ackId) match { case Some(status) => { messageStatuses -= bufferMessage.ackId - status.markDone(Some(message)) + status.success(message) } case None => { /** @@ -770,6 +827,12 @@ private[nio] class ConnectionManager( val newConnectionId = new ConnectionId(id, idCount.getAndIncrement.intValue) val newConnection = new SendingConnection(inetSocketAddress, selector, connectionManagerId, newConnectionId, securityManager) + newConnection.onException { + case (conn, e) => { + logError("Exception while sending message.", e) + reportSendingMessageFailure(message.id, e) + } + } logTrace("creating new sending connection: " + newConnectionId) registerRequests.enqueue(newConnection) @@ -782,13 +845,36 @@ private[nio] class ConnectionManager( "connectionid: " + connection.connectionId) if (authEnabled) { - checkSendAuthFirst(connectionManagerId, connection) + try { + checkSendAuthFirst(connectionManagerId, connection) + } catch { + case NonFatal(e) => { + reportSendingMessageFailure(message.id, e) + } + } } logDebug("Sending [" + message + "] to [" + connectionManagerId + "]") connection.send(message) wakeupSelector() } + private def reportSendingMessageFailure(messageId: Int, e: Throwable): Unit = { + // need to tell sender it failed + messageStatuses.synchronized { + val s = messageStatuses.get(messageId) + s match { + case Some(msgStatus) => { + messageStatuses -= messageId + logInfo("Notifying " + msgStatus.connectionManagerId) + msgStatus.failure(e) + } + case None => { + logError("no messageStatus for failed message id: " + messageId) + } + } + } + } + private def wakeupSelector() { selector.wakeup() } @@ -807,9 +893,11 @@ private[nio] class ConnectionManager( override def run(): Unit = { messageStatuses.synchronized { messageStatuses.remove(message.id).foreach ( s => { - promise.failure( - new IOException("sendMessageReliably failed because ack " + - s"was not received within $ackTimeout sec")) + val e = new IOException("sendMessageReliably failed because ack " + + s"was not received within $ackTimeout sec") + if (!promise.tryFailure(e)) { + logWarning("Ignore error because promise is completed", e) + } }) } } @@ -817,15 +905,23 @@ private[nio] class ConnectionManager( val status = new MessageStatus(message, connectionManagerId, s => { timeoutTask.cancel() - s.ackMessage match { - case None => // Indicates a failure where we either never sent or never got ACK'd - promise.failure(new IOException("sendMessageReliably failed without being ACK'd")) - case Some(ackMessage) => + s match { + case scala.util.Failure(e) => + // Indicates a failure where we either never sent or never got ACK'd + if (!promise.tryFailure(e)) { + logWarning("Ignore error because promise is completed", e) + } + case scala.util.Success(ackMessage) => if (ackMessage.hasError) { - promise.failure( - new IOException("sendMessageReliably failed with ACK that signalled a remote error")) + val e = new IOException( + "sendMessageReliably failed with ACK that signalled a remote error") + if (!promise.tryFailure(e)) { + logWarning("Ignore error because promise is completed", e) + } } else { - promise.success(ackMessage) + if (!promise.trySuccess(ackMessage)) { + logWarning("Drop ackMessage because promise is completed") + } } } })