Skip to content

Commit

Permalink
Change AsyncSSLSocket constructor used on fallback
Browse files Browse the repository at this point in the history
Summary: Use the `AsyncSocket*` constructor introduced in D21614835 when downgrading from `AsyncFizzServer` to `AsyncSSLSocket`. This will ensure that `move` will be called for any `AsyncSocket::LifecycleObserver` installed on the old socket.

Reviewed By: mingtaoy

Differential Revision: D21614839

fbshipit-source-id: b517d5802a64eb6e5036a0257238d47124b2f08e
  • Loading branch information
bschlinker authored and facebook-github-bot committed Jul 16, 2020
1 parent 1eda2c3 commit ba01a0a
Show file tree
Hide file tree
Showing 3 changed files with 217 additions and 13 deletions.
21 changes: 10 additions & 11 deletions wangle/acceptor/FizzAcceptorHandshakeHelper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -111,23 +111,22 @@ void FizzAcceptorHandshakeHelper::fizzHandshakeError(

folly::AsyncSSLSocket::UniquePtr FizzAcceptorHandshakeHelper::createSSLSocket(
const std::shared_ptr<folly::SSLContext>& context,
folly::EventBase* evb,
int fd) {
return folly::AsyncSSLSocket::UniquePtr(new folly::AsyncSSLSocket(
context, evb, folly::NetworkSocket::fromFd(fd)));
folly::AsyncTransport::UniquePtr transport) {
auto socket = transport->getUnderlyingTransport<folly::AsyncSocket>();
auto sslSocket = folly::AsyncSSLSocket::UniquePtr(
new folly::AsyncSSLSocket(context, CHECK_NOTNULL(socket)));
transport.reset();
return sslSocket;
}

void FizzAcceptorHandshakeHelper::fizzHandshakeAttemptFallback(
std::unique_ptr<folly::IOBuf> clientHello) {
VLOG(3) << "Fallback to OpenSSL";
if (loggingCallback_) {
loggingCallback_->logFizzHandshakeFallback(*transport_, &tinfo_);
}
sslSocket_ = createSSLSocket(sslContext_, std::move(transport_));

auto evb = transport_->getEventBase();
auto fd = transport_->getUnderlyingTransport<folly::AsyncSocket>()
->detachNetworkSocket()
.toFd();
transport_.reset();

sslSocket_ = createSSLSocket(sslContext_, evb, fd);
sslSocket_->setPreReceivedData(std::move(clientHello));
sslSocket_->enableClientHelloParsing();
sslSocket_->forceCacheAddrOnFailure(true);
Expand Down
6 changes: 4 additions & 2 deletions wangle/acceptor/FizzAcceptorHandshakeHelper.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,9 @@ class FizzAcceptorHandshakeHelper
virtual void logFizzHandshakeSuccess(
const fizz::server::AsyncFizzServer&,
const wangle::TransportInfo* tinfo) = 0;
virtual void logFizzHandshakeFallback(
const fizz::server::AsyncFizzServer&,
const wangle::TransportInfo* tinfo) = 0;
virtual void logFizzHandshakeError(
const fizz::server::AsyncFizzServer&,
const folly::exception_wrapper&) = 0;
Expand Down Expand Up @@ -100,8 +103,7 @@ class FizzAcceptorHandshakeHelper

virtual folly::AsyncSSLSocket::UniquePtr createSSLSocket(
const std::shared_ptr<folly::SSLContext>& sslContext,
folly::EventBase* evb,
int fd);
folly::AsyncTransport::UniquePtr transport);

// AsyncFizzServer::HandshakeCallback API
void fizzHandshakeSuccess(
Expand Down
203 changes: 203 additions & 0 deletions wangle/acceptor/test/AcceptorTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,10 @@ class TestAcceptor : public Acceptor {
addConnection(new TestConnection);
getEventBase()->terminateLoopSoon();
}

DefaultToFizzPeekingCallback* getFizzPeeker() override {
return Acceptor::getFizzPeeker();
}
};

enum class TestSSLConfig { NO_SSL, SSL };
Expand Down Expand Up @@ -149,6 +153,33 @@ class MockAcceptObserver : public AcceptObserver {
GMOCK_METHOD1_(, noexcept, , observerDetach, void(Acceptor* const));
};

class MockAsyncSocketLifecycleObserver : public AsyncSocket::LifecycleObserver {
public:
GMOCK_METHOD1_(, noexcept, , observerAttach, void(AsyncTransport*));
GMOCK_METHOD1_(, noexcept, , observerDetach, void(AsyncTransport*));
GMOCK_METHOD1_(, noexcept, , destroy, void(AsyncTransport*));
GMOCK_METHOD1_(, noexcept, , close, void(AsyncTransport*));
GMOCK_METHOD1_(, noexcept, , connect, void(AsyncTransport*));
GMOCK_METHOD1_(, noexcept, , fdDetach, void(AsyncSocket*));
GMOCK_METHOD2_(, noexcept, , move, void(AsyncSocket*, AsyncSocket*));
};

class MockFizzLoggingCallback
: public FizzAcceptorHandshakeHelper::LoggingCallback {
public:
MOCK_METHOD2(
logFizzHandshakeSuccess,
void(const fizz::server::AsyncFizzServer&, const wangle::TransportInfo*));
MOCK_METHOD2(
logFizzHandshakeFallback,
void(const fizz::server::AsyncFizzServer&, const wangle::TransportInfo*));
MOCK_METHOD2(
logFizzHandshakeError,
void(
const fizz::server::AsyncFizzServer&,
const folly::exception_wrapper&));
};

TEST_P(AcceptorTest, AcceptObserver) {
auto [acceptor, serverSocket] = initTestAcceptorAndSocket();
SocketAddress serverAddress;
Expand Down Expand Up @@ -332,3 +363,175 @@ TEST_P(AcceptorTest, AcceptObserverStopAcceptorThenRemoveCallback) {
EXPECT_TRUE(acceptor->removeAcceptObserver(cb.get()));
Mock::VerifyAndClearExpectations(cb.get());
}

/**
* Test if AsyncSocket::LifecycleObserver can track socket during SSL accept.
*
* With Fizz support, the accept process involves transforming the AsyncSocket
* to an AsyncFizzServer. Then, if Fizz falls back to OpenSSL, the
* AsyncFizzServer will be transformed into an AsyncSSLSocket.
*
* During each transformation, the LifecycleObserver::move callback must be
* triggered so that the observer can unsubscribe from events on the old socket
* and subscribe to events on the new socket. This requires Wangle and Fizz to
* use the AsyncSocket(AsyncSocket* oldSocket) constructor when performing the
* transformation.
*
* This test ensures that even in the worst case, where two transformations
* occur, that the observer will be able to track the socket through to the
* completion of the accept, when ready() is triggered.
*/
TEST_P(
AcceptorTest,
AcceptObserverInstallSocketObserverThenFizzThenFallbackToSSL) {
auto [acceptor, serverSocket] = initTestAcceptorAndSocket();
auto onAcceptCb = std::make_unique<StrictMock<MockAcceptObserver>>();
auto lifecycleCb =
std::make_unique<StrictMock<MockAsyncSocketLifecycleObserver>>();
auto fizzLoggingCb = std::make_unique<StrictMock<MockFizzLoggingCallback>>();
acceptor->getFizzPeeker()->setLoggingCallback(fizzLoggingCb.get());

EXPECT_CALL(*onAcceptCb, observerAttach(acceptor.get()));
acceptor->addAcceptObserver(onAcceptCb.get());
Mock::VerifyAndClearExpectations(onAcceptCb.get());

// add connection, expect callbacks
SocketAddress serverAddress;
serverSocket->getAddress(&serverAddress);
auto clientSocket = connectClientSocket(serverAddress);
folly::AsyncTransport* remoteSocket = nullptr;

// we have to EXPECT_EQ below instead of matchers because the remoteSocket
// will change; ByRef() does not work here, due to const AsyncSocket*
Sequence s1;
Sequence s2;
EXPECT_CALL(*onAcceptCb, accept(_))
.InSequence(s1)
.WillOnce(Invoke([&lifecycleCb, &remoteSocket](auto socket) {
remoteSocket = socket;
EXPECT_CALL(*lifecycleCb, observerAttach(socket));
socket->addLifecycleObserver(lifecycleCb.get());
}));

if (GetParam() == TestSSLConfig::SSL) {
// AsyncSocket -> AsyncFizzServer
EXPECT_CALL(*lifecycleCb, fdDetach(_))
.InSequence(s1)
.WillOnce(Invoke([&remoteSocket](folly::AsyncTransport* socket) {
EXPECT_EQ(remoteSocket, socket);
}));
EXPECT_CALL(*lifecycleCb, move(_, _))
.InSequence(s1)
.WillOnce(Invoke([&lifecycleCb, &remoteSocket, &s2](
folly::AsyncSocket* oldSocket,
folly::AsyncSocket* newSocket) {
EXPECT_EQ(remoteSocket, oldSocket);
EXPECT_NE(remoteSocket, newSocket);

// remove LifecycleCallback from old socket
EXPECT_CALL(*lifecycleCb, observerDetach(_))
.InSequence(s2)
.WillOnce(Invoke([&remoteSocket](folly::AsyncTransport* socket) {
EXPECT_EQ(remoteSocket, socket);
}));
EXPECT_TRUE(oldSocket->removeLifecycleObserver(lifecycleCb.get()));
EXPECT_THAT(oldSocket->getLifecycleObservers(), IsEmpty());

// add LifecycleCallback to new socket
EXPECT_THAT(newSocket->getLifecycleObservers(), IsEmpty());
EXPECT_CALL(*lifecycleCb, observerAttach(_));
newSocket->addLifecycleObserver(lifecycleCb.get());
EXPECT_THAT(
newSocket->getLifecycleObservers(),
UnorderedElementsAre(lifecycleCb.get()));

// update remoteSocket
remoteSocket = newSocket;
}));

// AsyncFizzServer -> AsyncSSLSocket
// use logFizzHandshakeFallback to verify that fallback occurred
EXPECT_CALL(*fizzLoggingCb, logFizzHandshakeFallback(_, _))
.InSequence(s1)
.WillOnce(Invoke([&remoteSocket](
const fizz::server::AsyncFizzServer& transport,
const wangle::TransportInfo* /* tinfo */) {
EXPECT_EQ(
remoteSocket,
transport.getUnderlyingTransport<folly::AsyncSocket>());
}));
EXPECT_CALL(*lifecycleCb, fdDetach(_))
.InSequence(s1)
.WillOnce(Invoke([&remoteSocket](folly::AsyncSocket* socket) {
EXPECT_EQ(remoteSocket, socket);
}));
EXPECT_CALL(*lifecycleCb, move(_, _))
.InSequence(s1)
.WillOnce(Invoke([&lifecycleCb, &remoteSocket, &s2](
folly::AsyncSocket* oldSocket,
folly::AsyncSocket* newSocket) {
EXPECT_EQ(remoteSocket, oldSocket);
EXPECT_NE(remoteSocket, newSocket);

// remove LifecycleCallback from old socket
EXPECT_THAT(
oldSocket->getLifecycleObservers(),
UnorderedElementsAre(lifecycleCb.get()));
EXPECT_CALL(*lifecycleCb, observerDetach(_))
.InSequence(s2)
.WillOnce(Invoke([&remoteSocket](folly::AsyncTransport* socket) {
EXPECT_EQ(remoteSocket, socket);
}));
EXPECT_TRUE(oldSocket->removeLifecycleObserver(lifecycleCb.get()));
EXPECT_THAT(oldSocket->getLifecycleObservers(), IsEmpty());

// add LifecycleCallback to new socket
EXPECT_THAT(newSocket->getLifecycleObservers(), IsEmpty());
EXPECT_CALL(*lifecycleCb, observerAttach(newSocket));
newSocket->addLifecycleObserver(lifecycleCb.get());
EXPECT_THAT(
newSocket->getLifecycleObservers(),
UnorderedElementsAre(lifecycleCb.get()));

// update remoteSocket
remoteSocket = newSocket;
}));
}

// the socket will be ready, and then immediately close
EXPECT_CALL(*onAcceptCb, ready(_))
.InSequence(s1)
.WillOnce(
Invoke([&lifecycleCb, &remoteSocket](const auto* const& socket) {
EXPECT_EQ(remoteSocket, socket);
EXPECT_THAT(
socket->getLifecycleObservers(),
UnorderedElementsAre(lifecycleCb.get()));
}));
EXPECT_CALL(*lifecycleCb, close(_))
.InSequence(s1)
.WillOnce(Invoke([&remoteSocket](const auto* const& socket) {
EXPECT_EQ(remoteSocket, socket);
}));
EXPECT_CALL(*lifecycleCb, destroy(_))
.InSequence(s1)
.WillOnce(Invoke([&remoteSocket](const auto* const& socket) {
EXPECT_EQ(remoteSocket, socket);
}));

evb_.loopForever();
Mock::VerifyAndClearExpectations(onAcceptCb.get());
Mock::VerifyAndClearExpectations(fizzLoggingCb.get());
Mock::VerifyAndClearExpectations(lifecycleCb.get());
CHECK_EQ(acceptor->getNumConnections(), 1);
CHECK(acceptor->getState() == Acceptor::State::kRunning);

acceptor->forceStop();
serverSocket->stopAccepting();
evb_.loop();

EXPECT_CALL(*onAcceptCb, observerDetach(acceptor.get()));
EXPECT_TRUE(acceptor->removeAcceptObserver(onAcceptCb.get()));
acceptor = nullptr;
Mock::VerifyAndClearExpectations(onAcceptCb.get());
}

0 comments on commit ba01a0a

Please sign in to comment.