diff --git a/src/aioquic/tls.py b/src/aioquic/tls.py index 0766147d..47c0dc44 100644 --- a/src/aioquic/tls.py +++ b/src/aioquic/tls.py @@ -1312,7 +1312,7 @@ def __init__( self._signature_algorithms.append(SignatureAlgorithm.ED25519) if default_backend().ed448_supported(): self._signature_algorithms.append(SignatureAlgorithm.ED448) - self._supported_groups = [Group.SECP256R1] + self._supported_groups = [Group.SECP256R1, Group.SECP384R1] if default_backend().x25519_supported(): self._supported_groups.append(Group.X25519) if default_backend().x448_supported(): @@ -1337,7 +1337,7 @@ def __init__( self._dec_key: Optional[bytes] = None self.__logger = logger - self._ec_private_key: Optional[ec.EllipticCurvePrivateKey] = None + self._ec_private_keys: List[ec.EllipticCurvePrivateKey] = [] self._x25519_private_key: Optional[x25519.X25519PrivateKey] = None self._x448_private_key: Optional[x448.X448PrivateKey] = None @@ -1525,13 +1525,7 @@ def _client_send_hello(self, output_buf: Buffer) -> None: supported_groups: List[int] = [] for group in self._supported_groups: - if group == Group.SECP256R1: - self._ec_private_key = ec.generate_private_key( - GROUP_TO_CURVE[Group.SECP256R1]() - ) - key_share.append(encode_public_key(self._ec_private_key.public_key())) - supported_groups.append(Group.SECP256R1) - elif group == Group.X25519: + if group == Group.X25519: self._x25519_private_key = x25519.X25519PrivateKey.generate() key_share.append( encode_public_key(self._x25519_private_key.public_key()) @@ -1544,6 +1538,11 @@ def _client_send_hello(self, output_buf: Buffer) -> None: elif group == Group.GREASE: key_share.append((Group.GREASE, b"\x00")) supported_groups.append(Group.GREASE) + elif group in GROUP_TO_CURVE: + ec_private_key = ec.generate_private_key(GROUP_TO_CURVE[group]()) + self._ec_private_keys.append(ec_private_key) + key_share.append(encode_public_key(ec_private_key.public_key())) + supported_groups.append(group) assert len(key_share), "no key share entries" @@ -1668,13 +1667,13 @@ def _client_handle_hello(self, input_buf: Buffer, output_buf: Buffer) -> None: and self._x448_private_key is not None ): shared_key = self._x448_private_key.exchange(peer_public_key) - elif ( - isinstance(peer_public_key, ec.EllipticCurvePublicKey) - and self._ec_private_key is not None - and self._ec_private_key.public_key().curve.__class__ - == peer_public_key.curve.__class__ - ): - shared_key = self._ec_private_key.exchange(ec.ECDH(), peer_public_key) + elif isinstance(peer_public_key, ec.EllipticCurvePublicKey): + for ec_private_key in self._ec_private_keys: + if ( + ec_private_key.public_key().curve.__class__ + == peer_public_key.curve.__class__ + ): + shared_key = ec_private_key.exchange(ec.ECDH(), peer_public_key) assert shared_key is not None self.key_schedule.update_hash(input_buf.data) @@ -1989,11 +1988,10 @@ def _server_handle_hello( shared_key = self._x448_private_key.exchange(peer_public_key) break elif isinstance(peer_public_key, ec.EllipticCurvePublicKey): - self._ec_private_key = ec.generate_private_key( - GROUP_TO_CURVE[key_share[0]]() - ) - public_key = self._ec_private_key.public_key() - shared_key = self._ec_private_key.exchange(ec.ECDH(), peer_public_key) + ec_private_key = ec.generate_private_key(GROUP_TO_CURVE[key_share[0]]()) + self._ec_private_keys.append(ec_private_key) + public_key = ec_private_key.public_key() + shared_key = ec_private_key.exchange(ec.ECDH(), peer_public_key) break assert shared_key is not None diff --git a/tests/test_tls.py b/tests/test_tls.py index ded2b8c1..f9d2831d 100644 --- a/tests/test_tls.py +++ b/tests/test_tls.py @@ -116,6 +116,11 @@ def reset_buffers(buffers): class ContextTest(TestCase): + def assertClientHello(self, data: bytes): + self.assertEqual(data[0], tls.HandshakeType.CLIENT_HELLO) + self.assertGreaterEqual(len(data), 191) + self.assertLessEqual(len(data), 564) + def create_client( self, alpn_protocols=None, cadata=None, cafile=SERVER_CACERTFILE, **kwargs ): @@ -379,8 +384,7 @@ def _handshake(self, client, server): client.handle_message(b"", client_buf) self.assertEqual(client.state, State.CLIENT_EXPECT_SERVER_HELLO) server_input = merge_buffers(client_buf) - self.assertGreaterEqual(len(server_input), 181) - self.assertLessEqual(len(server_input), 358) + self.assertClientHello(server_input) reset_buffers(client_buf) # Handle client hello. @@ -445,8 +449,7 @@ def test_handshake_with_certificate_request_no_certificate(self): client.handle_message(b"", client_buf) self.assertEqual(client.state, State.CLIENT_EXPECT_SERVER_HELLO) server_input = merge_buffers(client_buf) - self.assertGreaterEqual(len(server_input), 181) - self.assertLessEqual(len(server_input), 358) + self.assertClientHello(server_input) reset_buffers(client_buf) # Handle client hello. @@ -504,8 +507,7 @@ def test_handshake_with_certificate_request_with_certificate(self): client.handle_message(b"", client_buf) self.assertEqual(client.state, State.CLIENT_EXPECT_SERVER_HELLO) server_input = merge_buffers(client_buf) - self.assertGreaterEqual(len(server_input), 181) - self.assertLessEqual(len(server_input), 358) + self.assertClientHello(server_input) reset_buffers(client_buf) # Handle client hello. @@ -660,6 +662,20 @@ def test_handshake_with_grease_group(self): self._handshake(client, server) + def test_handshake_with_secp256r1_group(self): + client = self.create_client() + client._supported_groups = [tls.Group.SECP256R1] + server = self.create_server() + + self._handshake(client, server) + + def test_handshake_with_secp384r1_group(self): + client = self.create_client() + client._supported_groups = [tls.Group.SECP384R1] + server = self.create_server() + + self._handshake(client, server) + def test_handshake_with_x25519(self): client = self.create_client() client._supported_groups = [tls.Group.X25519] @@ -729,8 +745,7 @@ def second_handshake(): client.handle_message(b"", client_buf) self.assertEqual(client.state, State.CLIENT_EXPECT_SERVER_HELLO) server_input = merge_buffers(client_buf) - self.assertGreaterEqual(len(server_input), 383) - self.assertLessEqual(len(server_input), 483) + self.assertClientHello(server_input) reset_buffers(client_buf) # Handle client hello. @@ -782,8 +797,7 @@ def second_handshake_bad_binder(): client.handle_message(b"", client_buf) self.assertEqual(client.state, State.CLIENT_EXPECT_SERVER_HELLO) server_input = merge_buffers(client_buf) - self.assertGreaterEqual(len(server_input), 383) - self.assertLessEqual(len(server_input), 483) + self.assertClientHello(server_input) reset_buffers(client_buf) # tamper with binder @@ -808,8 +822,7 @@ def second_handshake_bad_pre_shared_key(): client.handle_message(b"", client_buf) self.assertEqual(client.state, State.CLIENT_EXPECT_SERVER_HELLO) server_input = merge_buffers(client_buf) - self.assertGreaterEqual(len(server_input), 383) - self.assertLessEqual(len(server_input), 483) + self.assertClientHello(server_input) reset_buffers(client_buf) # handle client hello