Skip to content

Commit

Permalink
apacheGH-268: Fix ClientConnectionService.sendHeartBeat()
Browse files Browse the repository at this point in the history
If a reply is requested, but none arrives within the timeout, throw an
exception and terminate the connection.

This rolls back the changes made in ClientConnectionService in commit
de2f8fe. It's quite all right to use the synchronous implementation of
Session.request() because heartbeats are not sent from I/O threads.

However, commit de2f8fe broke the contract specified in interface
Session, which says Session.request() must return "the buffer if the
request was successful, {@code null} otherwise." The implementation
from commit de2f8fe threw an exception instead.

This was wrong, and is corrected now in this commit. Note that this
means that a caller cannot distinguish between the server replying
SSH_MSG_UNIMPLEMENTED or SSH_MSG_REQUEST_FAILURE.

Where such distinction is needed, use an asynchronous global request
instead.

Bug: apache#268
  • Loading branch information
tomaswolf committed Nov 23, 2022
1 parent 45fd3a4 commit 29ed50c
Show file tree
Hide file tree
Showing 7 changed files with 116 additions and 48 deletions.
2 changes: 2 additions & 0 deletions CHANGES.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@

## Bug fixes

* [GH-268](https://github.com/apache/mina-sshd/issues/268) (Regression in 2.9.0) Heartbeat should throw an exception if no reply arrives within the timeout.

## Major code re-factoring

## Potential compatibility issues
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,6 @@

import java.io.IOException;
import java.time.Duration;
import java.time.Instant;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.ScheduledFuture;
import java.util.concurrent.TimeUnit;
Expand All @@ -30,7 +28,6 @@
import org.apache.sshd.common.FactoryManager;
import org.apache.sshd.common.SshConstants;
import org.apache.sshd.common.SshException;
import org.apache.sshd.common.future.GlobalRequestFuture;
import org.apache.sshd.common.io.IoWriteFuture;
import org.apache.sshd.common.session.Session;
import org.apache.sshd.common.session.helpers.AbstractConnectionService;
Expand Down Expand Up @@ -128,33 +125,11 @@ protected boolean sendHeartBeat() {
buf.putBoolean(withReply);

if (withReply) {
Instant start = Instant.now();
CountDownLatch replyReceived = new CountDownLatch(1);
GlobalRequestFuture writeFuture = session.request(buf, heartbeatRequest, (cmd, reply) -> {
replyReceived.countDown();
Buffer reply = session.request(heartbeatRequest, buf, heartbeatReplyMaxWait);
if (reply != null) {
if (log.isTraceEnabled()) {
log.trace("sendHeartBeat({}) received reply={} size={} for request={}", session,
SshConstants.getCommandMessageName(cmd), reply.available(), heartbeatRequest);
}
});
writeFuture.await(heartbeatReplyMaxWait);
Throwable t = writeFuture.getException();
if (t != null) {
// We couldn't even send the request.
throw new IOException(t.getMessage(), t);
}
Duration elapsed = Duration.between(start, Instant.now());
if (elapsed.compareTo(heartbeatReplyMaxWait) < 0) {
long toWait = heartbeatReplyMaxWait.minus(elapsed).toMillis();
if (toWait > 0) {
try {
replyReceived.await(toWait, TimeUnit.MILLISECONDS);
} catch (InterruptedException e) {
if (log.isTraceEnabled()) {
log.trace("sendHeartBeat({}) interrupted waiting for reply to request={}", session,
heartbeatRequest);
}
}
log.trace("sendHeartBeat({}) received reply size={} for request={}",
session, reply.available(), heartbeatRequest);
}
}
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
import org.apache.sshd.common.SshConstants;
import org.apache.sshd.common.SshException;
import org.apache.sshd.common.io.IoWriteFuture;
import org.apache.sshd.common.util.GenericUtils;
import org.apache.sshd.common.util.buffer.Buffer;

/**
Expand Down Expand Up @@ -109,15 +108,6 @@ public void setSequenceNumber(long number) {
sequenceNumber = number;
}

/**
* Fulfills this future, marking it as failed.
*
* @param message An explanation of the failure reason
*/
public void fail(String message) {
setValue(new SshException(GenericUtils.isEmpty(message) ? "Global request failure; unknown reason" : message));
}

/**
* Retrieves the {@link ReplyHandler} of this future, if any.
*
Expand Down Expand Up @@ -163,7 +153,7 @@ public void operationComplete(IoWriteFuture future) {
if (ioe != null) {
setValue(ioe);
} else {
fail("Could not write global request " + getId() + " seqNo=" + getSequenceNumber());
setValue(new SshException("Could not write global request " + getId() + " seqNo=" + getSequenceNumber()));
}
}
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.sshd.common.global;

import org.apache.sshd.common.SshConstants;

/**
* An exception that can be set on a {@link org.apache.sshd.common.future.GlobalRequestFuture} to indicate that the
* server sent back a failure reply.
*/
public class GlobalRequestException extends Exception {

private static final long serialVersionUID = 225802262556424684L;

private final int code;

/**
* Creates a new {@link GlobalRequestException} with the given SSH message code.
*
* @param code SSH message code to set; normally {@link SshConstants#SSH_MSG_UNIMPLEMENTED} or
* {@link SshConstants#SSH_MSG_REQUEST_FAILURE}
*/
public GlobalRequestException(int code) {
super(SshConstants.getCommandMessageName(code));
this.code = code;
}

/**
* Retrieves the SSH message code.
*
* @return the code, normally {@link SshConstants#SSH_MSG_UNIMPLEMENTED} or
* {@link SshConstants#SSH_MSG_REQUEST_FAILURE}
*/
public int getCode() {
return code;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@
import org.apache.sshd.common.future.DefaultSshFuture;
import org.apache.sshd.common.future.GlobalRequestFuture;
import org.apache.sshd.common.future.KeyExchangeFuture;
import org.apache.sshd.common.global.GlobalRequestException;
import org.apache.sshd.common.io.IoSession;
import org.apache.sshd.common.io.IoWriteFuture;
import org.apache.sshd.common.kex.KexProposalOption;
Expand Down Expand Up @@ -972,7 +973,7 @@ protected void preClose() {
if (debugEnabled) {
log.debug("preClose({}): Session closing; failing still pending global request {}", this, future.getId());
}
future.fail("Session is closing");
future.setValue(new SshException("Session is closing"));
}

// Fire 'close' event
Expand Down Expand Up @@ -1179,6 +1180,15 @@ public Buffer request(String request, Buffer buffer, long maxWaitMillis) throws
if (!done || result == null) {
throw new SocketTimeoutException("No response received after " + maxWaitMillis + "ms for request=" + request);
}
// The operation is specified to return null if the request could be made, but got an error reply.
// The caller cannot distinguish between SSH_MSG_UNIMPLEMENTED and SSH_MSG_REQUEST_FAILURE.
if (result instanceof GlobalRequestException) {
if (debugEnabled) {
log.debug("request({}) request={}, requestSeqNo={}: received={}", this, request, future.getSequenceNumber(),
SshConstants.getCommandMessageName(((GlobalRequestException) result).getCode()));
}
return null;
}
}

if (result instanceof Throwable) {
Expand Down Expand Up @@ -1290,7 +1300,7 @@ protected boolean doInvokeUnimplementedMessageHandler(int cmd, Buffer buffer) th
Buffer resultBuf = ByteArrayBuffer.getCompactClone(buffer.array(), buffer.rpos(), buffer.available());
handler.accept(cmd, resultBuf);
} else {
future.fail(SshConstants.getCommandMessageName(cmd));
future.setValue(new GlobalRequestException(cmd));
}
return true; // message handled internally
} else if (future != null) {
Expand Down Expand Up @@ -2204,7 +2214,7 @@ protected void requestFailure(Buffer buffer) throws Exception {
Buffer resultBuf = ByteArrayBuffer.getCompactClone(buffer.array(), buffer.rpos(), buffer.available());
handler.accept(SshConstants.SSH_MSG_REQUEST_FAILURE, resultBuf);
} else {
request.fail(SshConstants.getCommandMessageName(SshConstants.SSH_MSG_REQUEST_FAILURE));
request.setValue(new GlobalRequestException(SshConstants.SSH_MSG_REQUEST_FAILURE));
}
}
}
Expand Down
31 changes: 31 additions & 0 deletions sshd-core/src/test/java/org/apache/sshd/KeepAliveTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,37 @@ public Result process(
}
}

@Test // see GH-268
public void testTimeoutOnMissingHeartbeatResponse() throws Exception {
CoreModuleProperties.IDLE_TIMEOUT.set(sshd, Duration.ofSeconds(30));
List<RequestHandler<ConnectionService>> globalHandlers = sshd.getGlobalRequestHandlers();
sshd.setGlobalRequestHandlers(Collections.singletonList(new AbstractConnectionServiceRequestHandler() {
@Override
public Result process(ConnectionService connectionService, String request, boolean wantReply, Buffer buffer)
throws Exception {
// Never reply;
return Result.Replied;
}
}));
CoreModuleProperties.HEARTBEAT_INTERVAL.set(client, HEARTBEAT);
CoreModuleProperties.HEARTBEAT_REPLY_WAIT.set(client, Duration.ofSeconds(1));
try (ClientSession session = client.connect(getCurrentTestName(), TEST_LOCALHOST, port).verify(CONNECT_TIMEOUT)
.getSession()) {
session.addPasswordIdentity(getCurrentTestName());
session.auth().verify(AUTH_TIMEOUT);

try (ClientChannel channel = session.createChannel(Channel.CHANNEL_SHELL)) {
long waitStart = System.currentTimeMillis();
Collection<ClientChannelEvent> result = channel.waitFor(EnumSet.of(ClientChannelEvent.CLOSED), TIMEOUT);
long waitEnd = System.currentTimeMillis();
assertTrue("Wrong channel state after wait of " + (waitEnd - waitStart) + " ms: " + result,
result.contains(ClientChannelEvent.CLOSED));
}
} finally {
sshd.setGlobalRequestHandlers(globalHandlers); // restore original
}
}

public static class TestEchoShellFactory extends EchoShellFactory {
public TestEchoShellFactory() {
super();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
import org.apache.sshd.common.channel.RequestHandler;
import org.apache.sshd.common.config.keys.KeyUtils;
import org.apache.sshd.common.future.GlobalRequestFuture;
import org.apache.sshd.common.global.GlobalRequestException;
import org.apache.sshd.common.session.helpers.AbstractConnectionServiceRequestHandler;
import org.apache.sshd.common.signature.Signature;
import org.apache.sshd.common.util.GenericUtils;
Expand Down Expand Up @@ -197,22 +198,28 @@ public Result process(ConnectionService connectionService, String request, boole
case 0: {
j++;
Buffer reply = request.getBuffer();
assertTrue("Expected success for request " + i, reply != null);
assertNotNull("Expected success for request " + i, reply);
assertEquals("Expected a success", (byte) ('1' + j - 1), reply.getByte());
break;
}
case 1:
j++;
failure = request.getException();
assertNotNull("Expected failure for request " + i, failure);
assertEquals("Unexpected failure reason for request " + i, "SSH_MSG_REQUEST_FAILURE",
failure.getMessage());
assertTrue("Unexpected failure type", failure instanceof GlobalRequestException);
assertEquals("Unexpected failure reason for request " + i, SshConstants.SSH_MSG_REQUEST_FAILURE,
((GlobalRequestException) failure).getCode());
assertTrue("Unexpected failure message for request " + i,
failure.getMessage().contains("SSH_MSG_REQUEST_FAILURE"));
break;
default:
failure = request.getException();
assertNotNull("Expected failure for request " + i, failure);
assertEquals("Unexpected failure reason for request " + i, "SSH_MSG_UNIMPLEMENTED",
failure.getMessage());
assertTrue("Unexpected failure type", failure instanceof GlobalRequestException);
assertEquals("Unexpected failure reason for request " + i, SshConstants.SSH_MSG_UNIMPLEMENTED,
((GlobalRequestException) failure).getCode());
assertTrue("Unexpected failure message for request " + i,
failure.getMessage().contains("SSH_MSG_UNIMPLEMENTED"));
break;
}
}
Expand Down

0 comments on commit 29ed50c

Please sign in to comment.