From b1776438dbadc893750bbe45a7cb66044669f8e6 Mon Sep 17 00:00:00 2001 From: Heng Pan Date: Sat, 4 Jan 2025 13:18:06 +0000 Subject: [PATCH 1/3] init --- .../server/superlink/linkstate/in_memory_linkstate.py | 6 +++++- src/py/flwr/server/superlink/linkstate/linkstate.py | 8 +++++--- 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/src/py/flwr/server/superlink/linkstate/in_memory_linkstate.py b/src/py/flwr/server/superlink/linkstate/in_memory_linkstate.py index d22072b41621..c4cbf6aaf754 100644 --- a/src/py/flwr/server/superlink/linkstate/in_memory_linkstate.py +++ b/src/py/flwr/server/superlink/linkstate/in_memory_linkstate.py @@ -307,7 +307,7 @@ def num_task_res(self) -> int: return len(self.task_res_store) def create_node( - self, ping_interval: float, public_key: Optional[bytes] = None + self, ping_interval: float ) -> int: """Create, store in the link state, and return `node_id`.""" # Sample a random int64 as node_id @@ -366,6 +366,10 @@ def get_nodes(self, run_id: int) -> set[int]: if online_until > current_time } + def set_node_public_key(self, node_id: int, public_key: bytes) -> None: + """Store `public_key` for the specified `node_id`.""" + self.public_key_to_node_id[public_key] = node_id + def get_node_id(self, node_public_key: bytes) -> Optional[int]: """Retrieve stored `node_id` filtered by `node_public_keys`.""" return self.public_key_to_node_id.get(node_public_key) diff --git a/src/py/flwr/server/superlink/linkstate/linkstate.py b/src/py/flwr/server/superlink/linkstate/linkstate.py index 4f3c16a5460a..1458ee4e4a3a 100644 --- a/src/py/flwr/server/superlink/linkstate/linkstate.py +++ b/src/py/flwr/server/superlink/linkstate/linkstate.py @@ -154,9 +154,7 @@ def get_task_ids_from_run_id(self, run_id: int) -> set[UUID]: """Get all TaskIns IDs for the given run_id.""" @abc.abstractmethod - def create_node( - self, ping_interval: float, public_key: Optional[bytes] = None - ) -> int: + def create_node(self, ping_interval: float) -> int: """Create, store in the link state, and return `node_id`.""" @abc.abstractmethod @@ -173,6 +171,10 @@ def get_nodes(self, run_id: int) -> set[int]: an empty `Set` MUST be returned. """ + @abc.abstractmethod + def set_node_public_key(self, node_id: int, public_key: bytes) -> None: + """Store `public_key` for the specified `node_id`.""" + @abc.abstractmethod def get_node_id(self, node_public_key: bytes) -> Optional[int]: """Retrieve stored `node_id` filtered by `node_public_keys`.""" From 77352c3762cd90d45e18948610d06cf1d8b824bb Mon Sep 17 00:00:00 2001 From: Heng Pan Date: Sun, 5 Jan 2025 16:26:23 +0000 Subject: [PATCH 2/3] add and implement two new methods --- .../fleet/grpc_rere/server_interceptor.py | 3 +- .../grpc_rere/server_interceptor_test.py | 107 +++--------------- .../linkstate/in_memory_linkstate.py | 48 ++++---- .../server/superlink/linkstate/linkstate.py | 8 +- .../superlink/linkstate/linkstate_test.py | 44 +++---- .../superlink/linkstate/sqlite_linkstate.py | 62 ++++++---- 6 files changed, 106 insertions(+), 166 deletions(-) diff --git a/src/py/flwr/server/superlink/fleet/grpc_rere/server_interceptor.py b/src/py/flwr/server/superlink/fleet/grpc_rere/server_interceptor.py index c07ee0788493..6cafaaa21459 100644 --- a/src/py/flwr/server/superlink/fleet/grpc_rere/server_interceptor.py +++ b/src/py/flwr/server/superlink/fleet/grpc_rere/server_interceptor.py @@ -223,5 +223,6 @@ def _create_authenticated_node( # No `node_id` exists for the provided `public_key` # Handle `CreateNode` here instead of calling the default method handler # Note: the innermost `CreateNode` method will never be called - node_id = state.create_node(request.ping_interval, public_key_bytes) + node_id = state.create_node(request.ping_interval) + state.set_node_public_key(node_id, public_key_bytes) return CreateNodeResponse(node=Node(node_id=node_id, anonymous=False)) diff --git a/src/py/flwr/server/superlink/fleet/grpc_rere/server_interceptor_test.py b/src/py/flwr/server/superlink/fleet/grpc_rere/server_interceptor_test.py index a0ff7a77304a..9984b93f3e84 100644 --- a/src/py/flwr/server/superlink/fleet/grpc_rere/server_interceptor_test.py +++ b/src/py/flwr/server/superlink/fleet/grpc_rere/server_interceptor_test.py @@ -161,9 +161,7 @@ def test_unsuccessful_create_node_with_metadata(self) -> None: def test_successful_delete_node_with_metadata(self) -> None: """Test server interceptor for deleting node.""" # Prepare - node_id = self.state.create_node( - ping_interval=30, public_key=public_key_to_bytes(self._node_public_key) - ) + node_id = self._create_node_and_set_public_key() request = DeleteNodeRequest(node=Node(node_id=node_id)) shared_secret = generate_shared_key( self._node_private_key, self._server_public_key @@ -191,9 +189,7 @@ def test_successful_delete_node_with_metadata(self) -> None: def test_unsuccessful_delete_node_with_metadata(self) -> None: """Test server interceptor for deleting node unsuccessfully.""" # Prepare - node_id = self.state.create_node( - ping_interval=30, public_key=public_key_to_bytes(self._node_public_key) - ) + node_id = self._create_node_and_set_public_key() request = DeleteNodeRequest(node=Node(node_id=node_id)) node_private_key, _ = generate_key_pairs() shared_secret = generate_shared_key(node_private_key, self._server_public_key) @@ -217,9 +213,7 @@ def test_unsuccessful_delete_node_with_metadata(self) -> None: def test_successful_pull_task_ins_with_metadata(self) -> None: """Test server interceptor for pull task ins.""" # Prepare - node_id = self.state.create_node( - ping_interval=30, public_key=public_key_to_bytes(self._node_public_key) - ) + node_id = self._create_node_and_set_public_key() request = PullTaskInsRequest(node=Node(node_id=node_id)) shared_secret = generate_shared_key( self._node_private_key, self._server_public_key @@ -247,9 +241,7 @@ def test_successful_pull_task_ins_with_metadata(self) -> None: def test_unsuccessful_pull_task_ins_with_metadata(self) -> None: """Test server interceptor for pull task ins unsuccessfully.""" # Prepare - node_id = self.state.create_node( - ping_interval=30, public_key=public_key_to_bytes(self._node_public_key) - ) + node_id = self._create_node_and_set_public_key() request = PullTaskInsRequest(node=Node(node_id=node_id)) node_private_key, _ = generate_key_pairs() shared_secret = generate_shared_key(node_private_key, self._server_public_key) @@ -273,9 +265,7 @@ def test_unsuccessful_pull_task_ins_with_metadata(self) -> None: def test_successful_push_task_res_with_metadata(self) -> None: """Test server interceptor for push task res.""" # Prepare - node_id = self.state.create_node( - ping_interval=30, public_key=public_key_to_bytes(self._node_public_key) - ) + node_id = self._create_node_and_set_public_key() run_id = self.state.create_run("", "", "", {}, ConfigsRecord()) # Transition status to running. PushTaskRes is only allowed in running status. _ = self.state.update_run_status(run_id, RunStatus(Status.STARTING, "", "")) @@ -311,9 +301,7 @@ def test_successful_push_task_res_with_metadata(self) -> None: def test_unsuccessful_push_task_res_with_metadata(self) -> None: """Test server interceptor for push task res unsuccessfully.""" # Prepare - node_id = self.state.create_node( - ping_interval=30, public_key=public_key_to_bytes(self._node_public_key) - ) + node_id = self._create_node_and_set_public_key() run_id = self.state.create_run("", "", "", {}, ConfigsRecord()) # Transition status to running. PushTaskRes is only allowed in running status. _ = self.state.update_run_status(run_id, RunStatus(Status.STARTING, "", "")) @@ -344,9 +332,7 @@ def test_unsuccessful_push_task_res_with_metadata(self) -> None: def test_successful_get_run_with_metadata(self) -> None: """Test server interceptor for get run.""" # Prepare - self.state.create_node( - ping_interval=30, public_key=public_key_to_bytes(self._node_public_key) - ) + self._create_node_and_set_public_key() run_id = self.state.create_run("", "", "", {}, ConfigsRecord()) # Transition status to running. GetRun is only allowed in running status. _ = self.state.update_run_status(run_id, RunStatus(Status.STARTING, "", "")) @@ -378,9 +364,7 @@ def test_successful_get_run_with_metadata(self) -> None: def test_unsuccessful_get_run_with_metadata(self) -> None: """Test server interceptor for get run unsuccessfully.""" # Prepare - self.state.create_node( - ping_interval=30, public_key=public_key_to_bytes(self._node_public_key) - ) + self._create_node_and_set_public_key() run_id = self.state.create_run("", "", "", {}, ConfigsRecord()) request = GetRunRequest(run_id=run_id) node_private_key, _ = generate_key_pairs() @@ -405,9 +389,7 @@ def test_unsuccessful_get_run_with_metadata(self) -> None: def test_successful_ping_with_metadata(self) -> None: """Test server interceptor for ping.""" # Prepare - node_id = self.state.create_node( - ping_interval=30, public_key=public_key_to_bytes(self._node_public_key) - ) + node_id = self._create_node_and_set_public_key() request = PingRequest(node=Node(node_id=node_id)) shared_secret = generate_shared_key( self._node_private_key, self._server_public_key @@ -435,9 +417,7 @@ def test_successful_ping_with_metadata(self) -> None: def test_unsuccessful_ping_with_metadata(self) -> None: """Test server interceptor for ping unsuccessfully.""" # Prepare - node_id = self.state.create_node( - ping_interval=30, public_key=public_key_to_bytes(self._node_public_key) - ) + node_id = self._create_node_and_set_public_key() request = PingRequest(node=Node(node_id=node_id)) node_private_key, _ = generate_key_pairs() shared_secret = generate_shared_key(node_private_key, self._server_public_key) @@ -458,65 +438,8 @@ def test_unsuccessful_ping_with_metadata(self) -> None: ), ) - def test_successful_restore_node(self) -> None: - """Test server interceptor for restoring node.""" - public_key_bytes = base64.urlsafe_b64encode( - public_key_to_bytes(self._node_public_key) - ) - response, call = self._create_node.with_call( - request=CreateNodeRequest(), - metadata=((_PUBLIC_KEY_HEADER, public_key_bytes),), - ) - - expected_metadata = ( - _PUBLIC_KEY_HEADER, - base64.urlsafe_b64encode( - public_key_to_bytes(self._server_public_key) - ).decode(), - ) - - node = response.node - node_node_id = node.node_id - - assert call.initial_metadata()[0] == expected_metadata - assert isinstance(response, CreateNodeResponse) - - request = DeleteNodeRequest(node=node) - shared_secret = generate_shared_key( - self._node_private_key, self._server_public_key - ) - hmac_value = base64.urlsafe_b64encode( - compute_hmac(shared_secret, request.SerializeToString(deterministic=True)) - ) - public_key_bytes = base64.urlsafe_b64encode( - public_key_to_bytes(self._node_public_key) - ) - response, call = self._delete_node.with_call( - request=request, - metadata=( - (_PUBLIC_KEY_HEADER, public_key_bytes), - (_AUTH_TOKEN_HEADER, hmac_value), - ), - ) - - assert isinstance(response, DeleteNodeResponse) - assert grpc.StatusCode.OK == call.code() - - public_key_bytes = base64.urlsafe_b64encode( - public_key_to_bytes(self._node_public_key) - ) - response, call = self._create_node.with_call( - request=CreateNodeRequest(), - metadata=((_PUBLIC_KEY_HEADER, public_key_bytes),), - ) - - expected_metadata = ( - _PUBLIC_KEY_HEADER, - base64.urlsafe_b64encode( - public_key_to_bytes(self._server_public_key) - ).decode(), - ) - - assert call.initial_metadata()[0] == expected_metadata - assert isinstance(response, CreateNodeResponse) - assert response.node.node_id == node_node_id + def _create_node_and_set_public_key(self) -> int: + node_id = self.state.create_node(ping_interval=30) + pk_bytes = public_key_to_bytes(self._node_public_key) + self.state.set_node_public_key(node_id, pk_bytes) + return node_id diff --git a/src/py/flwr/server/superlink/linkstate/in_memory_linkstate.py b/src/py/flwr/server/superlink/linkstate/in_memory_linkstate.py index c4cbf6aaf754..ccce6cdd6e05 100644 --- a/src/py/flwr/server/superlink/linkstate/in_memory_linkstate.py +++ b/src/py/flwr/server/superlink/linkstate/in_memory_linkstate.py @@ -62,6 +62,7 @@ def __init__(self) -> None: # Map node_id to (online_until, ping_interval) self.node_ids: dict[int, tuple[float, float]] = {} self.public_key_to_node_id: dict[bytes, int] = {} + self.node_id_to_public_key: dict[int, bytes] = {} # Map run_id to RunRecord self.run_ids: dict[int, RunRecord] = {} @@ -306,9 +307,7 @@ def num_task_res(self) -> int: """ return len(self.task_res_store) - def create_node( - self, ping_interval: float - ) -> int: + def create_node(self, ping_interval: float) -> int: """Create, store in the link state, and return `node_id`.""" # Sample a random int64 as node_id node_id = generate_rand_int_from_bytes(NODE_ID_NUM_BYTES) @@ -318,33 +317,18 @@ def create_node( log(ERROR, "Unexpected node registration failure.") return 0 - if public_key is not None: - if ( - public_key in self.public_key_to_node_id - or node_id in self.public_key_to_node_id.values() - ): - log(ERROR, "Unexpected node registration failure.") - return 0 - - self.public_key_to_node_id[public_key] = node_id - self.node_ids[node_id] = (time.time() + ping_interval, ping_interval) return node_id - def delete_node(self, node_id: int, public_key: Optional[bytes] = None) -> None: + def delete_node(self, node_id: int) -> None: """Delete a node.""" with self.lock: if node_id not in self.node_ids: raise ValueError(f"Node {node_id} not found") - if public_key is not None: - if ( - public_key not in self.public_key_to_node_id - or node_id not in self.public_key_to_node_id.values() - ): - raise ValueError("Public key or node_id not found") - - del self.public_key_to_node_id[public_key] + # Remove node ID <> public key mappings + if pk := self.node_id_to_public_key.pop(node_id, None): + del self.public_key_to_node_id[pk] del self.node_ids[node_id] @@ -367,8 +351,24 @@ def get_nodes(self, run_id: int) -> set[int]: } def set_node_public_key(self, node_id: int, public_key: bytes) -> None: - """Store `public_key` for the specified `node_id`.""" - self.public_key_to_node_id[public_key] = node_id + """Set `public_key` for the specified `node_id`.""" + with self.lock: + if node_id not in self.node_ids: + raise ValueError(f"Node {node_id} not found") + + if public_key in self.public_key_to_node_id: + raise ValueError("Public key already in use") + + self.public_key_to_node_id[public_key] = node_id + self.node_id_to_public_key[node_id] = public_key + + def get_node_public_key(self, node_id: int) -> Optional[bytes]: + """Get `public_key` for the specified `node_id`.""" + with self.lock: + if node_id not in self.node_ids: + raise ValueError(f"Node {node_id} not found") + + return self.node_id_to_public_key.get(node_id) def get_node_id(self, node_public_key: bytes) -> Optional[int]: """Retrieve stored `node_id` filtered by `node_public_keys`.""" diff --git a/src/py/flwr/server/superlink/linkstate/linkstate.py b/src/py/flwr/server/superlink/linkstate/linkstate.py index 1458ee4e4a3a..e1eccf2b8b2f 100644 --- a/src/py/flwr/server/superlink/linkstate/linkstate.py +++ b/src/py/flwr/server/superlink/linkstate/linkstate.py @@ -158,7 +158,7 @@ def create_node(self, ping_interval: float) -> int: """Create, store in the link state, and return `node_id`.""" @abc.abstractmethod - def delete_node(self, node_id: int, public_key: Optional[bytes] = None) -> None: + def delete_node(self, node_id: int) -> None: """Remove `node_id` from the link state.""" @abc.abstractmethod @@ -173,7 +173,11 @@ def get_nodes(self, run_id: int) -> set[int]: @abc.abstractmethod def set_node_public_key(self, node_id: int, public_key: bytes) -> None: - """Store `public_key` for the specified `node_id`.""" + """Set `public_key` for the specified `node_id`.""" + + @abc.abstractmethod + def get_node_public_key(self, node_id: int) -> Optional[bytes]: + """Get `public_key` for the specified `node_id`.""" @abc.abstractmethod def get_node_id(self, node_public_key: bytes) -> Optional[int]: diff --git a/src/py/flwr/server/superlink/linkstate/linkstate_test.py b/src/py/flwr/server/superlink/linkstate/linkstate_test.py index 3edaf72ec20c..d3e391c5b62a 100644 --- a/src/py/flwr/server/superlink/linkstate/linkstate_test.py +++ b/src/py/flwr/server/superlink/linkstate/linkstate_test.py @@ -588,7 +588,8 @@ def test_create_node_public_key(self) -> None: run_id = state.create_run(None, None, "9f86d08", {}, ConfigsRecord()) # Execute - node_id = state.create_node(ping_interval=10, public_key=public_key) + node_id = state.create_node(ping_interval=10) + state.set_node_public_key(node_id, public_key) retrieved_node_ids = state.get_nodes(run_id) retrieved_node_id = state.get_node_id(public_key) @@ -602,15 +603,21 @@ def test_create_node_public_key_twice(self) -> None: state: LinkState = self.state_factory() public_key = b"mock" run_id = state.create_run(None, None, "9f86d08", {}, ConfigsRecord()) - node_id = state.create_node(ping_interval=10, public_key=public_key) + node_id = state.create_node(ping_interval=10) + state.set_node_public_key(node_id, public_key) # Execute - new_node_id = state.create_node(ping_interval=10, public_key=public_key) + new_node_id = state.create_node(ping_interval=10) + try: + state.set_node_public_key(new_node_id, public_key) + except ValueError: + state.delete_node(new_node_id) + else: + raise AssertionError("Should have raised ValueError") retrieved_node_ids = state.get_nodes(run_id) retrieved_node_id = state.get_node_id(public_key) # Assert - assert new_node_id == 0 assert len(retrieved_node_ids) == 1 assert retrieved_node_id == node_id @@ -639,10 +646,11 @@ def test_delete_node_public_key(self) -> None: state: LinkState = self.state_factory() public_key = b"mock" run_id = state.create_run(None, None, "9f86d08", {}, ConfigsRecord()) - node_id = state.create_node(ping_interval=10, public_key=public_key) + node_id = state.create_node(ping_interval=10) + state.set_node_public_key(node_id, public_key) # Execute - state.delete_node(node_id, public_key=public_key) + state.delete_node(node_id) retrieved_node_ids = state.get_nodes(run_id) retrieved_node_id = state.get_node_id(public_key) @@ -660,7 +668,7 @@ def test_delete_node_public_key_none(self) -> None: # Execute & Assert with self.assertRaises(ValueError): - state.delete_node(node_id, public_key=public_key) + state.delete_node(node_id) retrieved_node_ids = state.get_nodes(run_id) retrieved_node_id = state.get_node_id(public_key) @@ -668,25 +676,6 @@ def test_delete_node_public_key_none(self) -> None: assert len(retrieved_node_ids) == 0 assert retrieved_node_id is None - def test_delete_node_wrong_public_key(self) -> None: - """Test deleting a client node with wrong public key.""" - # Prepare - state: LinkState = self.state_factory() - public_key = b"mock" - wrong_public_key = b"mock_mock" - run_id = state.create_run(None, None, "9f86d08", {}, ConfigsRecord()) - node_id = state.create_node(ping_interval=10, public_key=public_key) - - # Execute & Assert - with self.assertRaises(ValueError): - state.delete_node(node_id, public_key=wrong_public_key) - - retrieved_node_ids = state.get_nodes(run_id) - retrieved_node_id = state.get_node_id(public_key) - - assert len(retrieved_node_ids) == 1 - assert retrieved_node_id == node_id - def test_get_node_id_wrong_public_key(self) -> None: """Test retrieving a client node with wrong public key.""" # Prepare @@ -696,7 +685,8 @@ def test_get_node_id_wrong_public_key(self) -> None: run_id = state.create_run(None, None, "9f86d08", {}, ConfigsRecord()) # Execute - state.create_node(ping_interval=10, public_key=public_key) + node_id = state.create_node(ping_interval=10) + state.set_node_public_key(node_id, public_key) retrieved_node_ids = state.get_nodes(run_id) retrieved_node_id = state.get_node_id(wrong_public_key) diff --git a/src/py/flwr/server/superlink/linkstate/sqlite_linkstate.py b/src/py/flwr/server/superlink/linkstate/sqlite_linkstate.py index e8311dfaac5e..cc773f7b93de 100644 --- a/src/py/flwr/server/superlink/linkstate/sqlite_linkstate.py +++ b/src/py/flwr/server/superlink/linkstate/sqlite_linkstate.py @@ -72,14 +72,14 @@ SQL_CREATE_TABLE_CREDENTIAL = """ CREATE TABLE IF NOT EXISTS credential( - private_key BLOB PRIMARY KEY, - public_key BLOB + private_key BLOB PRIMARY KEY, + public_key BLOB ); """ SQL_CREATE_TABLE_PUBLIC_KEY = """ CREATE TABLE IF NOT EXISTS public_key( - public_key BLOB UNIQUE + public_key BLOB PRIMARY KEY ); """ @@ -635,9 +635,7 @@ def get_task_ids_from_run_id(self, run_id: int) -> set[UUID]: return {UUID(row["task_id"]) for row in rows} - def create_node( - self, ping_interval: float, public_key: Optional[bytes] = None - ) -> int: + def create_node(self, ping_interval: float) -> int: """Create, store in the link state, and return `node_id`.""" # Sample a random uint64 as node_id uint64_node_id = generate_rand_int_from_bytes(NODE_ID_NUM_BYTES) @@ -645,13 +643,6 @@ def create_node( # Convert the uint64 value to sint64 for SQLite sint64_node_id = convert_uint64_to_sint64(uint64_node_id) - query = "SELECT node_id FROM node WHERE public_key = :public_key;" - row = self.query(query, {"public_key": public_key}) - - if len(row) > 0: - log(ERROR, "Unexpected node registration failure.") - return 0 - query = ( "INSERT INTO node " "(node_id, online_until, ping_interval, public_key) " @@ -665,7 +656,7 @@ def create_node( sint64_node_id, time.time() + ping_interval, ping_interval, - public_key, + b"", # Initialize with an empty public key ), ) except sqlite3.IntegrityError: @@ -675,7 +666,7 @@ def create_node( # Note: we need to return the uint64 value of the node_id return uint64_node_id - def delete_node(self, node_id: int, public_key: Optional[bytes] = None) -> None: + def delete_node(self, node_id: int) -> None: """Delete a node.""" # Convert the uint64 value to sint64 for SQLite sint64_node_id = convert_uint64_to_sint64(node_id) @@ -683,10 +674,6 @@ def delete_node(self, node_id: int, public_key: Optional[bytes] = None) -> None: query = "DELETE FROM node WHERE node_id = ?" params = (sint64_node_id,) - if public_key is not None: - query += " AND public_key = ?" - params += (public_key,) # type: ignore - if self.conn is None: raise AttributeError("LinkState is not initialized.") @@ -694,7 +681,7 @@ def delete_node(self, node_id: int, public_key: Optional[bytes] = None) -> None: with self.conn: rows = self.conn.execute(query, params) if rows.rowcount < 1: - raise ValueError("Public key or node_id not found") + raise ValueError(f"Node {node_id} not found") except KeyError as exc: log(ERROR, {"query": query, "data": params, "exception": exc}) @@ -722,6 +709,41 @@ def get_nodes(self, run_id: int) -> set[int]: result: set[int] = {convert_sint64_to_uint64(row["node_id"]) for row in rows} return result + def set_node_public_key(self, node_id: int, public_key: bytes) -> None: + """Set `public_key` for the specified `node_id`.""" + # Convert the uint64 value to sint64 for SQLite + sint64_node_id = convert_uint64_to_sint64(node_id) + + # Check if the node exists in the `node` table + query = "SELECT 1 FROM node WHERE node_id = ?" + if not self.query(query, (sint64_node_id,)): + raise ValueError(f"Node {node_id} not found") + + # Check if the public key is already in use in the `node` table + query = "SELECT 1 FROM node WHERE public_key = ?" + if self.query(query, (public_key,)): + raise ValueError("Public key already in use") + + # Update the `node` table to set the public key for the given node ID + query = "UPDATE node SET public_key = ? WHERE node_id = ?" + self.query(query, (public_key, sint64_node_id)) + + def get_node_public_key(self, node_id: int) -> Optional[bytes]: + """Get `public_key` for the specified `node_id`.""" + # Convert the uint64 value to sint64 for SQLite + sint64_node_id = convert_uint64_to_sint64(node_id) + + # Query the public key for the given node_id + query = "SELECT public_key FROM node WHERE node_id = ?" + rows = self.query(query, (sint64_node_id,)) + + # If no result is found, return None + if not rows: + raise ValueError(f"Node {node_id} not found") + + # Return the public key if it is not empty, otherwise return None + return rows[0]["public_key"] or None + def get_node_id(self, node_public_key: bytes) -> Optional[int]: """Retrieve stored `node_id` filtered by `node_public_keys`.""" query = "SELECT node_id FROM node WHERE public_key = :public_key;" From dc260850f3f358fa3c3aaeafcb7f4e3eccf6b1bf Mon Sep 17 00:00:00 2001 From: Heng Pan Date: Sun, 5 Jan 2025 18:06:40 +0000 Subject: [PATCH 3/3] rm unnecessary test --- .../superlink/linkstate/linkstate_test.py | 18 ------------------ 1 file changed, 18 deletions(-) diff --git a/src/py/flwr/server/superlink/linkstate/linkstate_test.py b/src/py/flwr/server/superlink/linkstate/linkstate_test.py index d3e391c5b62a..fd1051e1cbfc 100644 --- a/src/py/flwr/server/superlink/linkstate/linkstate_test.py +++ b/src/py/flwr/server/superlink/linkstate/linkstate_test.py @@ -658,24 +658,6 @@ def test_delete_node_public_key(self) -> None: assert len(retrieved_node_ids) == 0 assert retrieved_node_id is None - def test_delete_node_public_key_none(self) -> None: - """Test deleting a client node with public key.""" - # Prepare - state: LinkState = self.state_factory() - public_key = b"mock" - run_id = state.create_run(None, None, "9f86d08", {}, ConfigsRecord()) - node_id = 0 - - # Execute & Assert - with self.assertRaises(ValueError): - state.delete_node(node_id) - - retrieved_node_ids = state.get_nodes(run_id) - retrieved_node_id = state.get_node_id(public_key) - - assert len(retrieved_node_ids) == 0 - assert retrieved_node_id is None - def test_get_node_id_wrong_public_key(self) -> None: """Test retrieving a client node with wrong public key.""" # Prepare