Skip to content

Commit

Permalink
Add QuicPskCache to FizzClientQuicHandshakeContext
Browse files Browse the repository at this point in the history
  • Loading branch information
deadalnix committed Feb 26, 2020
1 parent 37dfc0f commit ce57445
Show file tree
Hide file tree
Showing 4 changed files with 87 additions and 15 deletions.
25 changes: 22 additions & 3 deletions quic/client/handshake/FizzClientQuicHandshakeContext.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,34 @@ namespace quic {

FizzClientQuicHandshakeContext::FizzClientQuicHandshakeContext(
std::shared_ptr<const fizz::client::FizzClientContext> context,
std::shared_ptr<const fizz::CertificateVerifier> verifier)
: context_(std::move(context)), verifier_(std::move(verifier)) {}
std::shared_ptr<const fizz::CertificateVerifier> verifier,
std::shared_ptr<QuicPskCache> pskCache)
: context_(std::move(context)),
verifier_(std::move(verifier)),
pskCache_(std::move(pskCache)) {}

std::unique_ptr<ClientHandshake>
FizzClientQuicHandshakeContext::makeClientHandshake(
QuicClientConnectionState* conn) {
return std::make_unique<FizzClientHandshake>(conn, shared_from_this());
}

folly::Optional<QuicCachedPsk> FizzClientQuicHandshakeContext::getPsk(
const folly::Optional<std::string>& hostname) {
if (!hostname || !pskCache_) {
return folly::none;
}

return pskCache_->getPsk(*hostname);
}

void FizzClientQuicHandshakeContext::removePsk(
const folly::Optional<std::string>& hostname) {
if (hostname && pskCache_) {
pskCache_->removePsk(*hostname);
}
}

std::shared_ptr<FizzClientQuicHandshakeContext>
FizzClientQuicHandshakeContext::Builder::build() {
if (!context_) {
Expand All @@ -35,7 +54,7 @@ FizzClientQuicHandshakeContext::Builder::build() {

return std::shared_ptr<FizzClientQuicHandshakeContext>(
new FizzClientQuicHandshakeContext(
std::move(context_), std::move(verifier_)));
std::move(context_), std::move(verifier_), std::move(pskCache_)));
}

} // namespace quic
16 changes: 15 additions & 1 deletion quic/client/handshake/FizzClientQuicHandshakeContext.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@

#include <quic/client/handshake/ClientHandshakeFactory.h>

#include <quic/client/handshake/QuicPskCache.h>

#include <fizz/client/FizzClientContext.h>
#include <fizz/protocol/DefaultCertificateVerifier.h>

Expand All @@ -34,6 +36,10 @@ class FizzClientQuicHandshakeContext
return verifier_;
}

folly::Optional<QuicCachedPsk> getPsk(
const folly::Optional<std::string>& hostname);
void removePsk(const folly::Optional<std::string>& hostname);

private:
/**
* We make the constructor private so that users have to use the Builder
Expand All @@ -45,10 +51,12 @@ class FizzClientQuicHandshakeContext
*/
FizzClientQuicHandshakeContext(
std::shared_ptr<const fizz::client::FizzClientContext> context,
std::shared_ptr<const fizz::CertificateVerifier> verifier);
std::shared_ptr<const fizz::CertificateVerifier> verifier,
std::shared_ptr<QuicPskCache> pskCache);

std::shared_ptr<const fizz::client::FizzClientContext> context_;
std::shared_ptr<const fizz::CertificateVerifier> verifier_;
std::shared_ptr<QuicPskCache> pskCache_;

public:
class Builder {
Expand All @@ -65,11 +73,17 @@ class FizzClientQuicHandshakeContext
return *this;
}

Builder& setPskCache(std::shared_ptr<QuicPskCache> pskCache) {
pskCache_ = std::move(pskCache);
return *this;
}

std::shared_ptr<FizzClientQuicHandshakeContext> build();

private:
std::shared_ptr<const fizz::client::FizzClientContext> context_;
std::shared_ptr<const fizz::CertificateVerifier> verifier_;
std::shared_ptr<QuicPskCache> pskCache_;
};
};

Expand Down
11 changes: 11 additions & 0 deletions quic/client/handshake/test/ClientHandshakeTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ class ClientHandshakeTest : public Test, public boost::static_visitor<> {
auto handshakeFactory = FizzClientQuicHandshakeContext::Builder()
.setFizzClientContext(clientCtx)
.setCertificateVerifier(verifier)
.setPskCache(getPskCache())
.build();
conn.reset(new QuicClientConnectionState(handshakeFactory));
cryptoState = conn->cryptoState.get();
Expand Down Expand Up @@ -117,6 +118,10 @@ class ClientHandshakeTest : public Test, public boost::static_visitor<> {
fizzServer->accept(&evb, serverCtx, serverTransportParameters);
}

virtual std::shared_ptr<QuicPskCache> getPskCache() {
return nullptr;
}

void clientServerRound() {
auto writableBytes = getHandshakeWriteBytes();
serverReadBuf.append(std::move(writableBytes));
Expand Down Expand Up @@ -460,6 +465,12 @@ class ClientHandshakeZeroRttTest : public ClientHandshakeTest {
setupZeroRttServer();
}

std::shared_ptr<QuicPskCache> getPskCache() override {
auto pskCache = std::make_shared<BasicQuicPskCache>();
pskCache->putPsk(hostname, psk);
return pskCache;
}

void connect() override {
handshake->connect(
hostname,
Expand Down
50 changes: 39 additions & 11 deletions quic/client/test/QuicClientTransportTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,7 @@ class QuicClientTransportIntegrationTest : public TestWithParam<TestingParams> {
auto fizzClientContext = FizzClientQuicHandshakeContext::Builder()
.setFizzClientContext(clientCtx)
.setCertificateVerifier(verifier)
.setPskCache(pskCache_)
.build();
client = std::make_shared<TestingQuicClientTransport>(
&eventbase_,
Expand Down Expand Up @@ -1114,7 +1115,9 @@ INSTANTIATE_TEST_CASE_P(
// from the server.
class FakeOneRttHandshakeLayer : public ClientHandshake {
public:
explicit FakeOneRttHandshakeLayer(QuicClientConnectionState* conn)
explicit FakeOneRttHandshakeLayer(
QuicClientConnectionState* conn,
std::shared_ptr<FizzClientQuicHandshakeContext> fizzContext)
: ClientHandshake(conn) {}

void connect(
Expand Down Expand Up @@ -1303,17 +1306,29 @@ class QuicClientTransportTest : public Test {
QuicClientTransportTest()
: eventbase_(std::make_unique<folly::EventBase>()) {}

std::shared_ptr<FizzClientQuicHandshakeContext> getFizzClientContext() {
if (!fizzClientContext) {
fizzClientContext =
FizzClientQuicHandshakeContext::Builder()
.setCertificateVerifier(createTestCertificateVerifier())
.setPskCache(getPskCache())
.build();
}

return fizzClientContext;
}

virtual std::shared_ptr<QuicPskCache> getPskCache() {
return nullptr;
}

void SetUp() override final {
auto socket = std::make_unique<NiceMock<folly::test::MockAsyncUDPSocket>>(
eventbase_.get());
sock = socket.get();

auto fizzClientContext =
FizzClientQuicHandshakeContext::Builder()
.setCertificateVerifier(createTestCertificateVerifier())
.build();
client = TestingQuicClientTransport::newClient<TestingQuicClientTransport>(
eventbase_.get(), std::move(socket), std::move(fizzClientContext));
eventbase_.get(), std::move(socket), getFizzClientContext());
destructionCallback = std::make_shared<DestructionCallback>();
client->setDestructionCallback(destructionCallback);
client->setSupportedVersions(
Expand Down Expand Up @@ -1359,8 +1374,8 @@ class QuicClientTransportTest : public Test {

virtual void setupCryptoLayer() {
// Fake that the handshake has already occured and fix the keys.
mockClientHandshake =
new FakeOneRttHandshakeLayer(&client->getNonConstConn());
mockClientHandshake = new FakeOneRttHandshakeLayer(
&client->getNonConstConn(), getFizzClientContext());
client->getNonConstConn().clientHandshakeLayer = mockClientHandshake;
client->getNonConstConn().handshakeLayer.reset(mockClientHandshake);
setFakeHandshakeCiphers();
Expand Down Expand Up @@ -1698,6 +1713,7 @@ class QuicClientTransportTest : public Test {
SocketAddress serverAddr{"127.0.0.1", 443};
AsyncUDPSocket::ReadCallback* networkReadCallback{nullptr};
FakeOneRttHandshakeLayer* mockClientHandshake;
std::shared_ptr<FizzClientQuicHandshakeContext> fizzClientContext;
std::shared_ptr<TestingQuicClientTransport> client;
PacketNum initialPacketNum{0}, handshakePacketNum{0}, appDataPacketNum{0};
std::unique_ptr<ConnectionIdAlgo> connIdAlgo_;
Expand Down Expand Up @@ -4965,11 +4981,15 @@ class QuicClientTransportPskCacheTest
: public QuicClientTransportAfterStartTestBase {
public:
void SetUpChild() override {
mockPskCache_ = std::make_shared<NiceMock<MockQuicPskCache>>();
client->setPskCache(mockPskCache_);
QuicClientTransportAfterStartTestBase::SetUpChild();
}

std::shared_ptr<QuicPskCache> getPskCache() override {
mockPskCache_ = std::make_shared<NiceMock<MockQuicPskCache>>();
return mockPskCache_;
}

protected:
std::shared_ptr<MockQuicPskCache> mockPskCache_;
};
Expand Down Expand Up @@ -5058,12 +5078,16 @@ class QuicZeroRttClientTest : public QuicClientTransportAfterStartTestBase {
test::createNoOpHeaderCipher());
}

std::shared_ptr<QuicPskCache> getPskCache() override {
mockQuicPskCache_ = std::make_shared<MockQuicPskCache>();
return mockQuicPskCache_;
}

void start() override {
TransportSettings clientSettings;
// Ignore path mtu to test negotiation.
clientSettings.canIgnorePathMTU = true;
client->setTransportSettings(clientSettings);
mockQuicPskCache_ = std::make_shared<NiceMock<MockQuicPskCache>>();
client->setPskCache(mockQuicPskCache_);
}

Expand Down Expand Up @@ -5331,7 +5355,6 @@ class QuicZeroRttHappyEyeballsClientTransportTest
public:
void SetUpChild() override {
client->setHostname(hostname_);
mockQuicPskCache_ = std::make_shared<NiceMock<MockQuicPskCache>>();
client->setPskCache(mockQuicPskCache_);

auto secondSocket =
Expand All @@ -5351,6 +5374,11 @@ class QuicZeroRttHappyEyeballsClientTransportTest
setupCryptoLayer();
}

std::shared_ptr<QuicPskCache> getPskCache() override {
mockQuicPskCache_ = std::make_shared<MockQuicPskCache>();
return mockQuicPskCache_;
}

protected:
folly::test::MockAsyncUDPSocket* secondSock;
SocketAddress firstAddress{"::1", 443};
Expand Down

0 comments on commit ce57445

Please sign in to comment.