Skip to content

Commit

Permalink
feat(framework) Add set_node_public_key and get_node_public_key methods to `LinkState` (#4765)
Browse files Browse the repository at this point in the history
  • Loading branch information
panh99 authored Jan 9, 2025
1 parent 9446ef2 commit c41de46
Show file tree
Hide file tree
Showing 6 changed files with 111 additions and 183 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -223,5 +223,6 @@ def _create_authenticated_node(
# No `node_id` exists for the provided `public_key`
# Handle `CreateNode` here instead of calling the default method handler
# Note: the innermost `CreateNode` method will never be called
node_id = state.create_node(request.ping_interval, public_key_bytes)
node_id = state.create_node(request.ping_interval)
state.set_node_public_key(node_id, public_key_bytes)
return CreateNodeResponse(node=Node(node_id=node_id, anonymous=False))
Original file line number Diff line number Diff line change
Expand Up @@ -161,9 +161,7 @@ def test_unsuccessful_create_node_with_metadata(self) -> None:
def test_successful_delete_node_with_metadata(self) -> None:
"""Test server interceptor for deleting node."""
# Prepare
node_id = self.state.create_node(
ping_interval=30, public_key=public_key_to_bytes(self._node_public_key)
)
node_id = self._create_node_and_set_public_key()
request = DeleteNodeRequest(node=Node(node_id=node_id))
shared_secret = generate_shared_key(
self._node_private_key, self._server_public_key
Expand Down Expand Up @@ -191,9 +189,7 @@ def test_successful_delete_node_with_metadata(self) -> None:
def test_unsuccessful_delete_node_with_metadata(self) -> None:
"""Test server interceptor for deleting node unsuccessfully."""
# Prepare
node_id = self.state.create_node(
ping_interval=30, public_key=public_key_to_bytes(self._node_public_key)
)
node_id = self._create_node_and_set_public_key()
request = DeleteNodeRequest(node=Node(node_id=node_id))
node_private_key, _ = generate_key_pairs()
shared_secret = generate_shared_key(node_private_key, self._server_public_key)
Expand All @@ -217,9 +213,7 @@ def test_unsuccessful_delete_node_with_metadata(self) -> None:
def test_successful_pull_task_ins_with_metadata(self) -> None:
"""Test server interceptor for pull task ins."""
# Prepare
node_id = self.state.create_node(
ping_interval=30, public_key=public_key_to_bytes(self._node_public_key)
)
node_id = self._create_node_and_set_public_key()
request = PullTaskInsRequest(node=Node(node_id=node_id))
shared_secret = generate_shared_key(
self._node_private_key, self._server_public_key
Expand Down Expand Up @@ -247,9 +241,7 @@ def test_successful_pull_task_ins_with_metadata(self) -> None:
def test_unsuccessful_pull_task_ins_with_metadata(self) -> None:
"""Test server interceptor for pull task ins unsuccessfully."""
# Prepare
node_id = self.state.create_node(
ping_interval=30, public_key=public_key_to_bytes(self._node_public_key)
)
node_id = self._create_node_and_set_public_key()
request = PullTaskInsRequest(node=Node(node_id=node_id))
node_private_key, _ = generate_key_pairs()
shared_secret = generate_shared_key(node_private_key, self._server_public_key)
Expand All @@ -273,9 +265,7 @@ def test_unsuccessful_pull_task_ins_with_metadata(self) -> None:
def test_successful_push_task_res_with_metadata(self) -> None:
"""Test server interceptor for push task res."""
# Prepare
node_id = self.state.create_node(
ping_interval=30, public_key=public_key_to_bytes(self._node_public_key)
)
node_id = self._create_node_and_set_public_key()
run_id = self.state.create_run("", "", "", {}, ConfigsRecord())
# Transition status to running. PushTaskRes is only allowed in running status.
_ = self.state.update_run_status(run_id, RunStatus(Status.STARTING, "", ""))
Expand Down Expand Up @@ -311,9 +301,7 @@ def test_successful_push_task_res_with_metadata(self) -> None:
def test_unsuccessful_push_task_res_with_metadata(self) -> None:
"""Test server interceptor for push task res unsuccessfully."""
# Prepare
node_id = self.state.create_node(
ping_interval=30, public_key=public_key_to_bytes(self._node_public_key)
)
node_id = self._create_node_and_set_public_key()
run_id = self.state.create_run("", "", "", {}, ConfigsRecord())
# Transition status to running. PushTaskRes is only allowed in running status.
_ = self.state.update_run_status(run_id, RunStatus(Status.STARTING, "", ""))
Expand Down Expand Up @@ -344,9 +332,7 @@ def test_unsuccessful_push_task_res_with_metadata(self) -> None:
def test_successful_get_run_with_metadata(self) -> None:
"""Test server interceptor for get run."""
# Prepare
self.state.create_node(
ping_interval=30, public_key=public_key_to_bytes(self._node_public_key)
)
self._create_node_and_set_public_key()
run_id = self.state.create_run("", "", "", {}, ConfigsRecord())
# Transition status to running. GetRun is only allowed in running status.
_ = self.state.update_run_status(run_id, RunStatus(Status.STARTING, "", ""))
Expand Down Expand Up @@ -378,9 +364,7 @@ def test_successful_get_run_with_metadata(self) -> None:
def test_unsuccessful_get_run_with_metadata(self) -> None:
"""Test server interceptor for get run unsuccessfully."""
# Prepare
self.state.create_node(
ping_interval=30, public_key=public_key_to_bytes(self._node_public_key)
)
self._create_node_and_set_public_key()
run_id = self.state.create_run("", "", "", {}, ConfigsRecord())
request = GetRunRequest(run_id=run_id)
node_private_key, _ = generate_key_pairs()
Expand All @@ -405,9 +389,7 @@ def test_unsuccessful_get_run_with_metadata(self) -> None:
def test_successful_ping_with_metadata(self) -> None:
"""Test server interceptor for ping."""
# Prepare
node_id = self.state.create_node(
ping_interval=30, public_key=public_key_to_bytes(self._node_public_key)
)
node_id = self._create_node_and_set_public_key()
request = PingRequest(node=Node(node_id=node_id))
shared_secret = generate_shared_key(
self._node_private_key, self._server_public_key
Expand Down Expand Up @@ -435,9 +417,7 @@ def test_successful_ping_with_metadata(self) -> None:
def test_unsuccessful_ping_with_metadata(self) -> None:
"""Test server interceptor for ping unsuccessfully."""
# Prepare
node_id = self.state.create_node(
ping_interval=30, public_key=public_key_to_bytes(self._node_public_key)
)
node_id = self._create_node_and_set_public_key()
request = PingRequest(node=Node(node_id=node_id))
node_private_key, _ = generate_key_pairs()
shared_secret = generate_shared_key(node_private_key, self._server_public_key)
Expand All @@ -458,65 +438,8 @@ def test_unsuccessful_ping_with_metadata(self) -> None:
),
)

def test_successful_restore_node(self) -> None:
    """Verify that a deleted node gets its previous ID back on re-creation.

    The flow is: create a node with the node public key in the metadata,
    delete it with an HMAC token computed from the shared key, then create
    it again with the same public key and check that the returned
    `node_id` equals the one from the first creation.
    """

    def _encode(key):
        # URL-safe base64 of the serialized key, the form sent in metadata
        return base64.urlsafe_b64encode(public_key_to_bytes(key))

    # Execute: first registration
    node_key_header = _encode(self._node_public_key)
    create_response, create_call = self._create_node.with_call(
        request=CreateNodeRequest(),
        metadata=((_PUBLIC_KEY_HEADER, node_key_header),),
    )
    server_key_metadata = (
        _PUBLIC_KEY_HEADER,
        _encode(self._server_public_key).decode(),
    )
    created_node = create_response.node
    first_node_id = created_node.node_id

    # Assert: the server replies with its own public key
    assert create_call.initial_metadata()[0] == server_key_metadata
    assert isinstance(create_response, CreateNodeResponse)

    # Execute: authenticated deletion of the node created above
    delete_request = DeleteNodeRequest(node=created_node)
    secret = generate_shared_key(self._node_private_key, self._server_public_key)
    token = base64.urlsafe_b64encode(
        compute_hmac(secret, delete_request.SerializeToString(deterministic=True))
    )
    node_key_header = _encode(self._node_public_key)
    delete_response, delete_call = self._delete_node.with_call(
        request=delete_request,
        metadata=(
            (_PUBLIC_KEY_HEADER, node_key_header),
            (_AUTH_TOKEN_HEADER, token),
        ),
    )

    # Assert: deletion succeeded
    assert isinstance(delete_response, DeleteNodeResponse)
    assert grpc.StatusCode.OK == delete_call.code()

    # Execute: second registration with the same public key
    node_key_header = _encode(self._node_public_key)
    recreate_response, recreate_call = self._create_node.with_call(
        request=CreateNodeRequest(),
        metadata=((_PUBLIC_KEY_HEADER, node_key_header),),
    )
    server_key_metadata = (
        _PUBLIC_KEY_HEADER,
        _encode(self._server_public_key).decode(),
    )

    # Assert: same handshake metadata and the original node ID is returned
    assert recreate_call.initial_metadata()[0] == server_key_metadata
    assert isinstance(recreate_response, CreateNodeResponse)
    assert recreate_response.node.node_id == first_node_id
def _create_node_and_set_public_key(self) -> int:
    """Create a node in the link state and attach the test node's public key.

    Returns the `node_id` produced by `create_node`.
    """
    new_node_id = self.state.create_node(ping_interval=30)
    # The key is stored in its serialized (bytes) form
    self.state.set_node_public_key(
        new_node_id, public_key_to_bytes(self._node_public_key)
    )
    return new_node_id
48 changes: 26 additions & 22 deletions src/py/flwr/server/superlink/linkstate/in_memory_linkstate.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ def __init__(self) -> None:
# Map node_id to (online_until, ping_interval)
self.node_ids: dict[int, tuple[float, float]] = {}
self.public_key_to_node_id: dict[bytes, int] = {}
self.node_id_to_public_key: dict[int, bytes] = {}

# Map run_id to RunRecord
self.run_ids: dict[int, RunRecord] = {}
Expand Down Expand Up @@ -306,9 +307,7 @@ def num_task_res(self) -> int:
"""
return len(self.task_res_store)

def create_node(
self, ping_interval: float, public_key: Optional[bytes] = None
) -> int:
def create_node(self, ping_interval: float) -> int:
"""Create, store in the link state, and return `node_id`."""
# Sample a random int64 as node_id
node_id = generate_rand_int_from_bytes(NODE_ID_NUM_BYTES)
Expand All @@ -318,33 +317,18 @@ def create_node(
log(ERROR, "Unexpected node registration failure.")
return 0

if public_key is not None:
if (
public_key in self.public_key_to_node_id
or node_id in self.public_key_to_node_id.values()
):
log(ERROR, "Unexpected node registration failure.")
return 0

self.public_key_to_node_id[public_key] = node_id

self.node_ids[node_id] = (time.time() + ping_interval, ping_interval)
return node_id

def delete_node(self, node_id: int, public_key: Optional[bytes] = None) -> None:
def delete_node(self, node_id: int) -> None:
"""Delete a node."""
with self.lock:
if node_id not in self.node_ids:
raise ValueError(f"Node {node_id} not found")

if public_key is not None:
if (
public_key not in self.public_key_to_node_id
or node_id not in self.public_key_to_node_id.values()
):
raise ValueError("Public key or node_id not found")

del self.public_key_to_node_id[public_key]
# Remove node ID <> public key mappings
if pk := self.node_id_to_public_key.pop(node_id, None):
del self.public_key_to_node_id[pk]

del self.node_ids[node_id]

Expand All @@ -366,6 +350,26 @@ def get_nodes(self, run_id: int) -> set[int]:
if online_until > current_time
}

def set_node_public_key(self, node_id: int, public_key: bytes) -> None:
    """Set `public_key` for the specified `node_id`.

    Both lookup directions (`public_key_to_node_id` and
    `node_id_to_public_key`) are updated under the lock. If the node was
    already bound to a different key, that key's reverse entry is removed
    so the old key no longer resolves to this node.

    Raises
    ------
    ValueError
        If `node_id` is not a known node, or if `public_key` is already
        registered.
    """
    with self.lock:
        if node_id not in self.node_ids:
            raise ValueError(f"Node {node_id} not found")

        if public_key in self.public_key_to_node_id:
            raise ValueError("Public key already in use")

        # Drop the reverse entry of any key previously bound to this node;
        # otherwise the stale key would keep mapping to `node_id` and would
        # never be cleaned up when the node is deleted.
        previous_key = self.node_id_to_public_key.get(node_id)
        if previous_key is not None:
            self.public_key_to_node_id.pop(previous_key, None)

        self.public_key_to_node_id[public_key] = node_id
        self.node_id_to_public_key[node_id] = public_key

def get_node_public_key(self, node_id: int) -> Optional[bytes]:
    """Return the public key stored for `node_id`.

    Returns `None` when the node exists but has no key recorded in
    `node_id_to_public_key`.

    Raises
    ------
    ValueError
        If `node_id` is not present in `node_ids`.
    """
    with self.lock:
        if node_id in self.node_ids:
            return self.node_id_to_public_key.get(node_id)

        raise ValueError(f"Node {node_id} not found")

def get_node_id(self, node_public_key: bytes) -> Optional[int]:
    """Look up the `node_id` registered for `node_public_key`.

    Returns `None` when the key has no entry in `public_key_to_node_id`.
    """
    key_index = self.public_key_to_node_id
    return key_index.get(node_public_key, None)
Expand Down
14 changes: 10 additions & 4 deletions src/py/flwr/server/superlink/linkstate/linkstate.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,13 +154,11 @@ def get_task_ids_from_run_id(self, run_id: int) -> set[UUID]:
"""Get all TaskIns IDs for the given run_id."""

@abc.abstractmethod
def create_node(
self, ping_interval: float, public_key: Optional[bytes] = None
) -> int:
def create_node(self, ping_interval: float) -> int:
"""Create, store in the link state, and return `node_id`."""

@abc.abstractmethod
def delete_node(self, node_id: int, public_key: Optional[bytes] = None) -> None:
def delete_node(self, node_id: int) -> None:
"""Remove `node_id` from the link state."""

@abc.abstractmethod
Expand All @@ -173,6 +171,14 @@ def get_nodes(self, run_id: int) -> set[int]:
an empty `Set` MUST be returned.
"""

@abc.abstractmethod
def set_node_public_key(self, node_id: int, public_key: bytes) -> None:
    """Associate `public_key` with the node identified by `node_id`."""

@abc.abstractmethod
def get_node_public_key(self, node_id: int) -> Optional[bytes]:
    """Return the public key associated with `node_id` (may be `None`)."""

@abc.abstractmethod
def get_node_id(self, node_public_key: bytes) -> Optional[int]:
    """Return the stored `node_id` for `node_public_key` (may be `None`)."""
Expand Down
Loading

0 comments on commit c41de46

Please sign in to comment.