Skip to content

Commit

Permalink
Add SendKeyShare::AlwaysDefaultShares mode
Browse files Browse the repository at this point in the history
Summary: Add new `SendKeyShare` enum value `AlwaysDefaultShares` which sends all default keyshares instead of only the group on the PSK.

Reviewed By: knekritz

Differential Revision: D69614697

fbshipit-source-id: 805df78653d56d43b1439e2282e1aa3ace7fc274
  • Loading branch information
Jolene Tan authored and facebook-github-bot committed Feb 19, 2025
1 parent 9f0eecd commit b4fe516
Show file tree
Hide file tree
Showing 3 changed files with 71 additions and 2 deletions.
3 changes: 2 additions & 1 deletion fizz/client/ClientProtocol.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -830,7 +830,8 @@ EventHandler<ClientTypes, StateEnum::Uninitialized, Event::Connect>::handle(
// If we have a saved PSK, use the group to choose which groups to
// send by default
std::vector<NamedGroup> selectedShares;
if (psk && psk->group &&
if (context->getSendKeyShare() != SendKeyShare::AlwaysDefaultShares && psk &&
psk->group &&
std::find(
context->getSupportedGroups().begin(),
context->getSupportedGroups().end(),
Expand Down
8 changes: 8 additions & 0 deletions fizz/client/FizzClientContext.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,17 @@
namespace fizz {
namespace client {

/**
* Controls keyshare sending behaviour in the presence of a saved PSK.
* Always: If saved PSK has a group, always send that group and only that group.
* WhenNecessary: Omit sending keyshare iff saved PSK has no group i.e. psk_ke
* AlwaysDefaultShares: Send all default shares regardless of saved PSK and its
* group
*/
enum class SendKeyShare {
Always,
WhenNecessary,
AlwaysDefaultShares,
};

class FizzClientContext {
Expand Down
62 changes: 61 additions & 1 deletion fizz/client/test/ClientProtocolTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2744,7 +2744,67 @@ TEST_F(ClientProtocolTest, TestConnectPskKeAlwaysShares) {
encodedHello->trimStart(4);
auto decodedHello = decode<ClientHello>(std::move(encodedHello));
auto keyShare = getExtension<ClientKeyShare>(decodedHello.extensions);
EXPECT_TRUE(!keyShare->client_shares.empty());
EXPECT_EQ(
keyShare->client_shares.size(), context_->getDefaultShares().size());
EXPECT_TRUE(!state_.keyExchangers()->empty());
}

TEST_F(ClientProtocolTest, TestConnectPskDheKeAlwaysShares) {
Connect connect;
context_->setSendKeyShare(SendKeyShare::Always);
context_->setDefaultShares({NamedGroup::secp256r1, NamedGroup::x25519});
connect.context = context_;
auto psk = getCachedPsk();
psk.group = NamedGroup::x25519;
connect.cachedPsk = psk;
fizz::Param param = std::move(connect);
auto actions = detail::processEvent(state_, param);
expectActions<MutateState, WriteToSocket>(actions);
processStateMutations(actions);
EXPECT_EQ(state_.state(), StateEnum::ExpectingServerHello);

auto& encodedHello = *state_.encodedClientHello();

// Get rid of handshake header (type + version)
encodedHello->trimStart(4);
auto decodedHello = decode<ClientHello>(std::move(encodedHello));
auto keyShare = getExtension<ClientKeyShare>(decodedHello.extensions);
const auto& clientShares = keyShare->client_shares;
EXPECT_EQ(clientShares.size(), 1);
EXPECT_EQ(clientShares[0].group, NamedGroup::x25519);
EXPECT_TRUE(!state_.keyExchangers()->empty());
}

TEST_F(ClientProtocolTest, TestConnectPskDheKeAlwaysDefaultShares) {
Connect connect;
context_->setSendKeyShare(SendKeyShare::AlwaysDefaultShares);
std::vector<NamedGroup> defaultShares = {
NamedGroup::secp256r1, NamedGroup::x25519};
context_->setDefaultShares(defaultShares);
connect.context = context_;
auto psk = getCachedPsk();
psk.group = NamedGroup::x25519;
connect.cachedPsk = psk;
fizz::Param param = std::move(connect);
auto actions = detail::processEvent(state_, param);
expectActions<MutateState, WriteToSocket>(actions);
processStateMutations(actions);
EXPECT_EQ(state_.state(), StateEnum::ExpectingServerHello);

auto& encodedHello = *state_.encodedClientHello();

// Get rid of handshake header (type + version)
encodedHello->trimStart(4);
auto decodedHello = decode<ClientHello>(std::move(encodedHello));
auto keyShare = getExtension<ClientKeyShare>(decodedHello.extensions);
const auto& clientShares = keyShare->client_shares;
EXPECT_EQ(clientShares.size(), 2);
std::vector<NamedGroup> clientSharesNamedGroups;
for (const auto& clientShare : clientShares) {
clientSharesNamedGroups.push_back(clientShare.group);
}
std::sort(clientSharesNamedGroups.begin(), clientSharesNamedGroups.end());
EXPECT_EQ(clientSharesNamedGroups, defaultShares);
EXPECT_TRUE(!state_.keyExchangers()->empty());
}

Expand Down

0 comments on commit b4fe516

Please sign in to comment.