From 29ed50c855459cf9b4cacb610f0a6cb4aac53ccb Mon Sep 17 00:00:00 2001 From: Thomas Wolf Date: Fri, 11 Nov 2022 21:07:38 +0100 Subject: [PATCH] GH-268: Fix ClientConnectionService.sendHeartBeat() If a reply is requested, but none arrives within the timeout, throw an exception and terminate the connection. This rolls back the changes made in ClientConnectionService in commit de2f8fef. It's quite all right to use the synchronous implementation of Session.request() because heartbeats are not sent from I/O threads. However, commit de2f8fef broke the contract specified in interface Session, which says Session.request() must return "the buffer if the request was successful, {@code null} otherwise." The implementation from commit de2f8fef threw an exception instead. This was wrong, and is corrected now in this commit. Note that this means that a caller cannot distinguish between the server replying SSH_MSG_UNIMPLEMENTED or SSH_MSG_REQUEST_FAILURE. Where such distinction is needed, use an asynchronous global request instead. Bug: https://github.com/apache/mina-sshd/issues/268 --- CHANGES.md | 2 + .../session/ClientConnectionService.java | 33 ++---------- .../common/future/GlobalRequestFuture.java | 12 +---- .../common/global/GlobalRequestException.java | 53 +++++++++++++++++++ .../session/helpers/AbstractSession.java | 16 ++++-- .../java/org/apache/sshd/KeepAliveTest.java | 31 +++++++++++ .../common/session/GlobalRequestTest.java | 17 ++++-- 7 files changed, 116 insertions(+), 48 deletions(-) create mode 100644 sshd-core/src/main/java/org/apache/sshd/common/global/GlobalRequestException.java diff --git a/CHANGES.md b/CHANGES.md index 58dd8edb9..d4a3dc172 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -24,6 +24,8 @@ ## Bug fixes +* [GH-268](https://github.com/apache/mina-sshd/issues/268) (Regression in 2.9.0) Heartbeat should throw an exception if no reply arrives within the timeout. + ## Major code re-factoring ## Potential compatibility issues diff --git a/sshd-core/src/main/java/org/apache/sshd/client/session/ClientConnectionService.java b/sshd-core/src/main/java/org/apache/sshd/client/session/ClientConnectionService.java index 25e79baa4..6017d3feb 100644 --- a/sshd-core/src/main/java/org/apache/sshd/client/session/ClientConnectionService.java +++ b/sshd-core/src/main/java/org/apache/sshd/client/session/ClientConnectionService.java @@ -20,8 +20,6 @@ import java.io.IOException; import java.time.Duration; -import java.time.Instant; -import java.util.concurrent.CountDownLatch; import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.ScheduledFuture; import java.util.concurrent.TimeUnit; @@ -30,7 +28,6 @@ import org.apache.sshd.common.FactoryManager; import org.apache.sshd.common.SshConstants; import org.apache.sshd.common.SshException; -import org.apache.sshd.common.future.GlobalRequestFuture; import org.apache.sshd.common.io.IoWriteFuture; import org.apache.sshd.common.session.Session; import org.apache.sshd.common.session.helpers.AbstractConnectionService; @@ -128,33 +125,11 @@ protected boolean sendHeartBeat() { buf.putBoolean(withReply); if (withReply) { - Instant start = Instant.now(); - CountDownLatch replyReceived = new CountDownLatch(1); - GlobalRequestFuture writeFuture = session.request(buf, heartbeatRequest, (cmd, reply) -> { - replyReceived.countDown(); + Buffer reply = session.request(heartbeatRequest, buf, heartbeatReplyMaxWait); + if (reply != null) { if (log.isTraceEnabled()) { - log.trace("sendHeartBeat({}) received reply={} size={} for request={}", session, - SshConstants.getCommandMessageName(cmd), reply.available(), heartbeatRequest); - } - }); - writeFuture.await(heartbeatReplyMaxWait); - Throwable t = writeFuture.getException(); - if (t != null) { - // We couldn't even send the request. - throw new IOException(t.getMessage(), t); - } - Duration elapsed = Duration.between(start, Instant.now()); - if (elapsed.compareTo(heartbeatReplyMaxWait) < 0) { - long toWait = heartbeatReplyMaxWait.minus(elapsed).toMillis(); - if (toWait > 0) { - try { - replyReceived.await(toWait, TimeUnit.MILLISECONDS); - } catch (InterruptedException e) { - if (log.isTraceEnabled()) { - log.trace("sendHeartBeat({}) interrupted waiting for reply to request={}", session, - heartbeatRequest); - } - } + log.trace("sendHeartBeat({}) received reply size={} for request={}", + session, reply.available(), heartbeatRequest); } } } else { diff --git a/sshd-core/src/main/java/org/apache/sshd/common/future/GlobalRequestFuture.java b/sshd-core/src/main/java/org/apache/sshd/common/future/GlobalRequestFuture.java index f3a646c7d..837224e8f 100644 --- a/sshd-core/src/main/java/org/apache/sshd/common/future/GlobalRequestFuture.java +++ b/sshd-core/src/main/java/org/apache/sshd/common/future/GlobalRequestFuture.java @@ -21,7 +21,6 @@ import org.apache.sshd.common.SshConstants; import org.apache.sshd.common.SshException; import org.apache.sshd.common.io.IoWriteFuture; -import org.apache.sshd.common.util.GenericUtils; import org.apache.sshd.common.util.buffer.Buffer; /** @@ -109,15 +108,6 @@ public void setSequenceNumber(long number) { sequenceNumber = number; } - /** - * Fulfills this future, marking it as failed. - * - * @param message An explanation of the failure reason - */ - public void fail(String message) { - setValue(new SshException(GenericUtils.isEmpty(message) ? "Global request failure; unknown reason" : message)); - } - /** * Retrieves the {@link ReplyHandler} of this future, if any. * @@ -163,7 +153,7 @@ public void operationComplete(IoWriteFuture future) { if (ioe != null) { setValue(ioe); } else { - fail("Could not write global request " + getId() + " seqNo=" + getSequenceNumber()); + setValue(new SshException("Could not write global request " + getId() + " seqNo=" + getSequenceNumber())); } } } diff --git a/sshd-core/src/main/java/org/apache/sshd/common/global/GlobalRequestException.java b/sshd-core/src/main/java/org/apache/sshd/common/global/GlobalRequestException.java new file mode 100644 index 000000000..1d1342469 --- /dev/null +++ b/sshd-core/src/main/java/org/apache/sshd/common/global/GlobalRequestException.java @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.sshd.common.global; + +import org.apache.sshd.common.SshConstants; + +/** + * An exception that can be set on a {@link org.apache.sshd.common.future.GlobalRequestFuture} to indicate that the + * server sent back a failure reply. + */ +public class GlobalRequestException extends Exception { + + private static final long serialVersionUID = 225802262556424684L; + + private final int code; + + /** + * Creates a new {@link GlobalRequestException} with the given SSH message code. + * + * @param code SSH message code to set; normally {@link SshConstants#SSH_MSG_UNIMPLEMENTED} or + * {@link SshConstants#SSH_MSG_REQUEST_FAILURE} + */ + public GlobalRequestException(int code) { + super(SshConstants.getCommandMessageName(code)); + this.code = code; + } + + /** + * Retrieves the SSH message code. + * + * @return the code, normally {@link SshConstants#SSH_MSG_UNIMPLEMENTED} or + * {@link SshConstants#SSH_MSG_REQUEST_FAILURE} + */ + public int getCode() { + return code; + } +} diff --git a/sshd-core/src/main/java/org/apache/sshd/common/session/helpers/AbstractSession.java b/sshd-core/src/main/java/org/apache/sshd/common/session/helpers/AbstractSession.java index 8b51d8eba..908041698 100644 --- a/sshd-core/src/main/java/org/apache/sshd/common/session/helpers/AbstractSession.java +++ b/sshd-core/src/main/java/org/apache/sshd/common/session/helpers/AbstractSession.java @@ -66,6 +66,7 @@ import org.apache.sshd.common.future.DefaultSshFuture; import org.apache.sshd.common.future.GlobalRequestFuture; import org.apache.sshd.common.future.KeyExchangeFuture; +import org.apache.sshd.common.global.GlobalRequestException; import org.apache.sshd.common.io.IoSession; import org.apache.sshd.common.io.IoWriteFuture; import org.apache.sshd.common.kex.KexProposalOption; @@ -972,7 +973,7 @@ protected void preClose() { if (debugEnabled) { log.debug("preClose({}): Session closing; failing still pending global request {}", this, future.getId()); } - future.fail("Session is closing"); + future.setValue(new SshException("Session is closing")); } // Fire 'close' event @@ -1179,6 +1180,15 @@ public Buffer request(String request, Buffer buffer, long maxWaitMillis) throws if (!done || result == null) { throw new SocketTimeoutException("No response received after " + maxWaitMillis + "ms for request=" + request); } + // The operation is specified to return null if the request could be made, but got an error reply. + // The caller cannot distinguish between SSH_MSG_UNIMPLEMENTED and SSH_MSG_REQUEST_FAILURE. + if (result instanceof GlobalRequestException) { + if (debugEnabled) { + log.debug("request({}) request={}, requestSeqNo={}: received={}", this, request, future.getSequenceNumber(), + SshConstants.getCommandMessageName(((GlobalRequestException) result).getCode())); + } + return null; + } } if (result instanceof Throwable) { @@ -1290,7 +1300,7 @@ protected boolean doInvokeUnimplementedMessageHandler(int cmd, Buffer buffer) th Buffer resultBuf = ByteArrayBuffer.getCompactClone(buffer.array(), buffer.rpos(), buffer.available()); handler.accept(cmd, resultBuf); } else { - future.fail(SshConstants.getCommandMessageName(cmd)); + future.setValue(new GlobalRequestException(cmd)); } return true; // message handled internally } else if (future != null) { @@ -2204,7 +2214,7 @@ protected void requestFailure(Buffer buffer) throws Exception { Buffer resultBuf = ByteArrayBuffer.getCompactClone(buffer.array(), buffer.rpos(), buffer.available()); handler.accept(SshConstants.SSH_MSG_REQUEST_FAILURE, resultBuf); } else { - request.fail(SshConstants.getCommandMessageName(SshConstants.SSH_MSG_REQUEST_FAILURE)); + request.setValue(new GlobalRequestException(SshConstants.SSH_MSG_REQUEST_FAILURE)); } } } diff --git a/sshd-core/src/test/java/org/apache/sshd/KeepAliveTest.java b/sshd-core/src/test/java/org/apache/sshd/KeepAliveTest.java index 3d5e16fe2..42bdeb015 100644 --- a/sshd-core/src/test/java/org/apache/sshd/KeepAliveTest.java +++ b/sshd-core/src/test/java/org/apache/sshd/KeepAliveTest.java @@ -217,6 +217,37 @@ public Result process( } } + @Test // see GH-268 + public void testTimeoutOnMissingHeartbeatResponse() throws Exception { + CoreModuleProperties.IDLE_TIMEOUT.set(sshd, Duration.ofSeconds(30)); + List> globalHandlers = sshd.getGlobalRequestHandlers(); + sshd.setGlobalRequestHandlers(Collections.singletonList(new AbstractConnectionServiceRequestHandler() { + @Override + public Result process(ConnectionService connectionService, String request, boolean wantReply, Buffer buffer) + throws Exception { + // Never reply; + return Result.Replied; + } + })); + CoreModuleProperties.HEARTBEAT_INTERVAL.set(client, HEARTBEAT); + CoreModuleProperties.HEARTBEAT_REPLY_WAIT.set(client, Duration.ofSeconds(1)); + try (ClientSession session = client.connect(getCurrentTestName(), TEST_LOCALHOST, port).verify(CONNECT_TIMEOUT) + .getSession()) { + session.addPasswordIdentity(getCurrentTestName()); + session.auth().verify(AUTH_TIMEOUT); + + try (ClientChannel channel = session.createChannel(Channel.CHANNEL_SHELL)) { + long waitStart = System.currentTimeMillis(); + Collection result = channel.waitFor(EnumSet.of(ClientChannelEvent.CLOSED), TIMEOUT); + long waitEnd = System.currentTimeMillis(); + assertTrue("Wrong channel state after wait of " + (waitEnd - waitStart) + " ms: " + result, + result.contains(ClientChannelEvent.CLOSED)); + } + } finally { + sshd.setGlobalRequestHandlers(globalHandlers); // restore original + } + } + public static class TestEchoShellFactory extends EchoShellFactory { public TestEchoShellFactory() { super(); diff --git a/sshd-core/src/test/java/org/apache/sshd/common/session/GlobalRequestTest.java b/sshd-core/src/test/java/org/apache/sshd/common/session/GlobalRequestTest.java index 1e99203e1..3765031e9 100644 --- a/sshd-core/src/test/java/org/apache/sshd/common/session/GlobalRequestTest.java +++ b/sshd-core/src/test/java/org/apache/sshd/common/session/GlobalRequestTest.java @@ -35,6 +35,7 @@ import org.apache.sshd.common.channel.RequestHandler; import org.apache.sshd.common.config.keys.KeyUtils; import org.apache.sshd.common.future.GlobalRequestFuture; +import org.apache.sshd.common.global.GlobalRequestException; import org.apache.sshd.common.session.helpers.AbstractConnectionServiceRequestHandler; import org.apache.sshd.common.signature.Signature; import org.apache.sshd.common.util.GenericUtils; @@ -197,7 +198,7 @@ public Result process(ConnectionService connectionService, String request, boole case 0: { j++; Buffer reply = request.getBuffer(); - assertTrue("Expected success for request " + i, reply != null); + assertNotNull("Expected success for request " + i, reply); assertEquals("Expected a success", (byte) ('1' + j - 1), reply.getByte()); break; } @@ -205,14 +206,20 @@ public Result process(ConnectionService connectionService, String request, boole j++; failure = request.getException(); assertNotNull("Expected failure for request " + i, failure); - assertEquals("Unexpected failure reason for request " + i, "SSH_MSG_REQUEST_FAILURE", - failure.getMessage()); + assertTrue("Unexpected failure type", failure instanceof GlobalRequestException); + assertEquals("Unexpected failure reason for request " + i, SshConstants.SSH_MSG_REQUEST_FAILURE, + ((GlobalRequestException) failure).getCode()); + assertTrue("Unexpected failure message for request " + i, + failure.getMessage().contains("SSH_MSG_REQUEST_FAILURE")); break; default: failure = request.getException(); assertNotNull("Expected failure for request " + i, failure); - assertEquals("Unexpected failure reason for request " + i, "SSH_MSG_UNIMPLEMENTED", - failure.getMessage()); + assertTrue("Unexpected failure type", failure instanceof GlobalRequestException); + assertEquals("Unexpected failure reason for request " + i, SshConstants.SSH_MSG_UNIMPLEMENTED, + ((GlobalRequestException) failure).getCode()); + assertTrue("Unexpected failure message for request " + i, + failure.getMessage().contains("SSH_MSG_UNIMPLEMENTED")); break; } }