Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(framework) Add set_node_public_key and get_node_public_key methods to LinkState #4765

Merged
merged 8 commits into from
Jan 9, 2025
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -223,5 +223,6 @@ def _create_authenticated_node(
# No `node_id` exists for the provided `public_key`
# Handle `CreateNode` here instead of calling the default method handler
# Note: the innermost `CreateNode` method will never be called
panh99 marked this conversation as resolved.
Show resolved Hide resolved
node_id = state.create_node(request.ping_interval, public_key_bytes)
node_id = state.create_node(request.ping_interval)
state.set_node_public_key(node_id, public_key_bytes)
panh99 marked this conversation as resolved.
Show resolved Hide resolved
danielnugraha marked this conversation as resolved.
Show resolved Hide resolved
return CreateNodeResponse(node=Node(node_id=node_id, anonymous=False))
Original file line number Diff line number Diff line change
Expand Up @@ -161,9 +161,7 @@ def test_unsuccessful_create_node_with_metadata(self) -> None:
def test_successful_delete_node_with_metadata(self) -> None:
"""Test server interceptor for deleting node."""
# Prepare
node_id = self.state.create_node(
ping_interval=30, public_key=public_key_to_bytes(self._node_public_key)
)
node_id = self._create_node_and_set_public_key()
request = DeleteNodeRequest(node=Node(node_id=node_id))
shared_secret = generate_shared_key(
self._node_private_key, self._server_public_key
Expand Down Expand Up @@ -191,9 +189,7 @@ def test_successful_delete_node_with_metadata(self) -> None:
def test_unsuccessful_delete_node_with_metadata(self) -> None:
"""Test server interceptor for deleting node unsuccessfully."""
# Prepare
node_id = self.state.create_node(
ping_interval=30, public_key=public_key_to_bytes(self._node_public_key)
)
node_id = self._create_node_and_set_public_key()
request = DeleteNodeRequest(node=Node(node_id=node_id))
node_private_key, _ = generate_key_pairs()
shared_secret = generate_shared_key(node_private_key, self._server_public_key)
Expand All @@ -217,9 +213,7 @@ def test_unsuccessful_delete_node_with_metadata(self) -> None:
def test_successful_pull_task_ins_with_metadata(self) -> None:
"""Test server interceptor for pull task ins."""
# Prepare
node_id = self.state.create_node(
ping_interval=30, public_key=public_key_to_bytes(self._node_public_key)
)
node_id = self._create_node_and_set_public_key()
request = PullTaskInsRequest(node=Node(node_id=node_id))
shared_secret = generate_shared_key(
self._node_private_key, self._server_public_key
Expand Down Expand Up @@ -247,9 +241,7 @@ def test_successful_pull_task_ins_with_metadata(self) -> None:
def test_unsuccessful_pull_task_ins_with_metadata(self) -> None:
"""Test server interceptor for pull task ins unsuccessfully."""
# Prepare
node_id = self.state.create_node(
ping_interval=30, public_key=public_key_to_bytes(self._node_public_key)
)
node_id = self._create_node_and_set_public_key()
request = PullTaskInsRequest(node=Node(node_id=node_id))
node_private_key, _ = generate_key_pairs()
shared_secret = generate_shared_key(node_private_key, self._server_public_key)
Expand All @@ -273,9 +265,7 @@ def test_unsuccessful_pull_task_ins_with_metadata(self) -> None:
def test_successful_push_task_res_with_metadata(self) -> None:
"""Test server interceptor for push task res."""
# Prepare
node_id = self.state.create_node(
ping_interval=30, public_key=public_key_to_bytes(self._node_public_key)
)
node_id = self._create_node_and_set_public_key()
run_id = self.state.create_run("", "", "", {}, ConfigsRecord())
# Transition status to running. PushTaskRes is only allowed in running status.
_ = self.state.update_run_status(run_id, RunStatus(Status.STARTING, "", ""))
Expand Down Expand Up @@ -311,9 +301,7 @@ def test_successful_push_task_res_with_metadata(self) -> None:
def test_unsuccessful_push_task_res_with_metadata(self) -> None:
"""Test server interceptor for push task res unsuccessfully."""
# Prepare
node_id = self.state.create_node(
ping_interval=30, public_key=public_key_to_bytes(self._node_public_key)
)
node_id = self._create_node_and_set_public_key()
run_id = self.state.create_run("", "", "", {}, ConfigsRecord())
# Transition status to running. PushTaskRes is only allowed in running status.
_ = self.state.update_run_status(run_id, RunStatus(Status.STARTING, "", ""))
Expand Down Expand Up @@ -344,9 +332,7 @@ def test_unsuccessful_push_task_res_with_metadata(self) -> None:
def test_successful_get_run_with_metadata(self) -> None:
"""Test server interceptor for get run."""
# Prepare
self.state.create_node(
ping_interval=30, public_key=public_key_to_bytes(self._node_public_key)
)
self._create_node_and_set_public_key()
run_id = self.state.create_run("", "", "", {}, ConfigsRecord())
# Transition status to running. GetRun is only allowed in running status.
_ = self.state.update_run_status(run_id, RunStatus(Status.STARTING, "", ""))
Expand Down Expand Up @@ -378,9 +364,7 @@ def test_successful_get_run_with_metadata(self) -> None:
def test_unsuccessful_get_run_with_metadata(self) -> None:
"""Test server interceptor for get run unsuccessfully."""
# Prepare
self.state.create_node(
ping_interval=30, public_key=public_key_to_bytes(self._node_public_key)
)
self._create_node_and_set_public_key()
run_id = self.state.create_run("", "", "", {}, ConfigsRecord())
request = GetRunRequest(run_id=run_id)
node_private_key, _ = generate_key_pairs()
Expand All @@ -405,9 +389,7 @@ def test_unsuccessful_get_run_with_metadata(self) -> None:
def test_successful_ping_with_metadata(self) -> None:
"""Test server interceptor for ping."""
# Prepare
node_id = self.state.create_node(
ping_interval=30, public_key=public_key_to_bytes(self._node_public_key)
)
node_id = self._create_node_and_set_public_key()
request = PingRequest(node=Node(node_id=node_id))
shared_secret = generate_shared_key(
self._node_private_key, self._server_public_key
Expand Down Expand Up @@ -435,9 +417,7 @@ def test_successful_ping_with_metadata(self) -> None:
def test_unsuccessful_ping_with_metadata(self) -> None:
"""Test server interceptor for ping unsuccessfully."""
# Prepare
node_id = self.state.create_node(
ping_interval=30, public_key=public_key_to_bytes(self._node_public_key)
)
node_id = self._create_node_and_set_public_key()
request = PingRequest(node=Node(node_id=node_id))
node_private_key, _ = generate_key_pairs()
shared_secret = generate_shared_key(node_private_key, self._server_public_key)
Expand All @@ -458,65 +438,8 @@ def test_unsuccessful_ping_with_metadata(self) -> None:
),
)

def test_successful_restore_node(self) -> None:
panh99 marked this conversation as resolved.
Show resolved Hide resolved
"""Test server interceptor for restoring node."""
public_key_bytes = base64.urlsafe_b64encode(
public_key_to_bytes(self._node_public_key)
)
response, call = self._create_node.with_call(
request=CreateNodeRequest(),
metadata=((_PUBLIC_KEY_HEADER, public_key_bytes),),
)

expected_metadata = (
_PUBLIC_KEY_HEADER,
base64.urlsafe_b64encode(
public_key_to_bytes(self._server_public_key)
).decode(),
)

node = response.node
node_node_id = node.node_id

assert call.initial_metadata()[0] == expected_metadata
assert isinstance(response, CreateNodeResponse)

request = DeleteNodeRequest(node=node)
shared_secret = generate_shared_key(
self._node_private_key, self._server_public_key
)
hmac_value = base64.urlsafe_b64encode(
compute_hmac(shared_secret, request.SerializeToString(deterministic=True))
)
public_key_bytes = base64.urlsafe_b64encode(
public_key_to_bytes(self._node_public_key)
)
response, call = self._delete_node.with_call(
request=request,
metadata=(
(_PUBLIC_KEY_HEADER, public_key_bytes),
(_AUTH_TOKEN_HEADER, hmac_value),
),
)

assert isinstance(response, DeleteNodeResponse)
assert grpc.StatusCode.OK == call.code()

public_key_bytes = base64.urlsafe_b64encode(
public_key_to_bytes(self._node_public_key)
)
response, call = self._create_node.with_call(
request=CreateNodeRequest(),
metadata=((_PUBLIC_KEY_HEADER, public_key_bytes),),
)

expected_metadata = (
_PUBLIC_KEY_HEADER,
base64.urlsafe_b64encode(
public_key_to_bytes(self._server_public_key)
).decode(),
)

assert call.initial_metadata()[0] == expected_metadata
assert isinstance(response, CreateNodeResponse)
assert response.node.node_id == node_node_id
def _create_node_and_set_public_key(self) -> int:
node_id = self.state.create_node(ping_interval=30)
pk_bytes = public_key_to_bytes(self._node_public_key)
self.state.set_node_public_key(node_id, pk_bytes)
return node_id
48 changes: 26 additions & 22 deletions src/py/flwr/server/superlink/linkstate/in_memory_linkstate.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ def __init__(self) -> None:
# Map node_id to (online_until, ping_interval)
self.node_ids: dict[int, tuple[float, float]] = {}
self.public_key_to_node_id: dict[bytes, int] = {}
danielnugraha marked this conversation as resolved.
Show resolved Hide resolved
self.node_id_to_public_key: dict[int, bytes] = {}

# Map run_id to RunRecord
self.run_ids: dict[int, RunRecord] = {}
Expand Down Expand Up @@ -306,9 +307,7 @@ def num_task_res(self) -> int:
"""
return len(self.task_res_store)

def create_node(
self, ping_interval: float, public_key: Optional[bytes] = None
) -> int:
def create_node(self, ping_interval: float) -> int:
"""Create, store in the link state, and return `node_id`."""
# Sample a random int64 as node_id
node_id = generate_rand_int_from_bytes(NODE_ID_NUM_BYTES)
Expand All @@ -318,33 +317,18 @@ def create_node(
log(ERROR, "Unexpected node registration failure.")
return 0

if public_key is not None:
if (
public_key in self.public_key_to_node_id
or node_id in self.public_key_to_node_id.values()
):
log(ERROR, "Unexpected node registration failure.")
return 0

self.public_key_to_node_id[public_key] = node_id

self.node_ids[node_id] = (time.time() + ping_interval, ping_interval)
return node_id

def delete_node(self, node_id: int, public_key: Optional[bytes] = None) -> None:
def delete_node(self, node_id: int) -> None:
"""Delete a node."""
with self.lock:
if node_id not in self.node_ids:
raise ValueError(f"Node {node_id} not found")

if public_key is not None:
if (
public_key not in self.public_key_to_node_id
or node_id not in self.public_key_to_node_id.values()
):
raise ValueError("Public key or node_id not found")

del self.public_key_to_node_id[public_key]
# Remove node ID <> public key mappings
if pk := self.node_id_to_public_key.pop(node_id, None):
del self.public_key_to_node_id[pk]

del self.node_ids[node_id]

Expand All @@ -366,6 +350,26 @@ def get_nodes(self, run_id: int) -> set[int]:
if online_until > current_time
}

def set_node_public_key(self, node_id: int, public_key: bytes) -> None:
"""Set `public_key` for the specified `node_id`."""
with self.lock:
if node_id not in self.node_ids:
raise ValueError(f"Node {node_id} not found")

if public_key in self.public_key_to_node_id:
raise ValueError("Public key already in use")

self.public_key_to_node_id[public_key] = node_id
self.node_id_to_public_key[node_id] = public_key

def get_node_public_key(self, node_id: int) -> Optional[bytes]:
"""Get `public_key` for the specified `node_id`."""
with self.lock:
if node_id not in self.node_ids:
raise ValueError(f"Node {node_id} not found")

return self.node_id_to_public_key.get(node_id)

def get_node_id(self, node_public_key: bytes) -> Optional[int]:
"""Retrieve stored `node_id` filtered by `node_public_keys`."""
return self.public_key_to_node_id.get(node_public_key)
Expand Down
14 changes: 10 additions & 4 deletions src/py/flwr/server/superlink/linkstate/linkstate.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,13 +154,11 @@ def get_task_ids_from_run_id(self, run_id: int) -> set[UUID]:
"""Get all TaskIns IDs for the given run_id."""

@abc.abstractmethod
def create_node(
self, ping_interval: float, public_key: Optional[bytes] = None
) -> int:
def create_node(self, ping_interval: float) -> int:
"""Create, store in the link state, and return `node_id`."""

@abc.abstractmethod
def delete_node(self, node_id: int, public_key: Optional[bytes] = None) -> None:
def delete_node(self, node_id: int) -> None:
"""Remove `node_id` from the link state."""

@abc.abstractmethod
Expand All @@ -173,6 +171,14 @@ def get_nodes(self, run_id: int) -> set[int]:
an empty `Set` MUST be returned.
"""

@abc.abstractmethod
def set_node_public_key(self, node_id: int, public_key: bytes) -> None:
"""Set `public_key` for the specified `node_id`."""

@abc.abstractmethod
def get_node_public_key(self, node_id: int) -> Optional[bytes]:
"""Get `public_key` for the specified `node_id`."""

@abc.abstractmethod
def get_node_id(self, node_public_key: bytes) -> Optional[int]:
"""Retrieve stored `node_id` filtered by `node_public_keys`."""
Expand Down
44 changes: 17 additions & 27 deletions src/py/flwr/server/superlink/linkstate/linkstate_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -588,7 +588,8 @@ def test_create_node_public_key(self) -> None:
run_id = state.create_run(None, None, "9f86d08", {}, ConfigsRecord())

# Execute
node_id = state.create_node(ping_interval=10, public_key=public_key)
node_id = state.create_node(ping_interval=10)
state.set_node_public_key(node_id, public_key)
retrieved_node_ids = state.get_nodes(run_id)
retrieved_node_id = state.get_node_id(public_key)

Expand All @@ -602,15 +603,21 @@ def test_create_node_public_key_twice(self) -> None:
state: LinkState = self.state_factory()
public_key = b"mock"
run_id = state.create_run(None, None, "9f86d08", {}, ConfigsRecord())
node_id = state.create_node(ping_interval=10, public_key=public_key)
node_id = state.create_node(ping_interval=10)
state.set_node_public_key(node_id, public_key)

# Execute
new_node_id = state.create_node(ping_interval=10, public_key=public_key)
new_node_id = state.create_node(ping_interval=10)
try:
state.set_node_public_key(new_node_id, public_key)
except ValueError:
state.delete_node(new_node_id)
else:
raise AssertionError("Should have raised ValueError")
retrieved_node_ids = state.get_nodes(run_id)
retrieved_node_id = state.get_node_id(public_key)

# Assert
assert new_node_id == 0
assert len(retrieved_node_ids) == 1
assert retrieved_node_id == node_id

Expand Down Expand Up @@ -639,10 +646,11 @@ def test_delete_node_public_key(self) -> None:
state: LinkState = self.state_factory()
public_key = b"mock"
run_id = state.create_run(None, None, "9f86d08", {}, ConfigsRecord())
node_id = state.create_node(ping_interval=10, public_key=public_key)
node_id = state.create_node(ping_interval=10)
state.set_node_public_key(node_id, public_key)

# Execute
state.delete_node(node_id, public_key=public_key)
state.delete_node(node_id)
retrieved_node_ids = state.get_nodes(run_id)
retrieved_node_id = state.get_node_id(public_key)

Expand All @@ -660,33 +668,14 @@ def test_delete_node_public_key_none(self) -> None:

# Execute & Assert
with self.assertRaises(ValueError):
state.delete_node(node_id, public_key=public_key)
state.delete_node(node_id)

retrieved_node_ids = state.get_nodes(run_id)
retrieved_node_id = state.get_node_id(public_key)

assert len(retrieved_node_ids) == 0
assert retrieved_node_id is None

def test_delete_node_wrong_public_key(self) -> None:
"""Test deleting a client node with wrong public key."""
# Prepare
state: LinkState = self.state_factory()
public_key = b"mock"
wrong_public_key = b"mock_mock"
run_id = state.create_run(None, None, "9f86d08", {}, ConfigsRecord())
node_id = state.create_node(ping_interval=10, public_key=public_key)

# Execute & Assert
with self.assertRaises(ValueError):
state.delete_node(node_id, public_key=wrong_public_key)

retrieved_node_ids = state.get_nodes(run_id)
retrieved_node_id = state.get_node_id(public_key)

assert len(retrieved_node_ids) == 1
assert retrieved_node_id == node_id

def test_get_node_id_wrong_public_key(self) -> None:
"""Test retrieving a client node with wrong public key."""
# Prepare
Expand All @@ -696,7 +685,8 @@ def test_get_node_id_wrong_public_key(self) -> None:
run_id = state.create_run(None, None, "9f86d08", {}, ConfigsRecord())

# Execute
state.create_node(ping_interval=10, public_key=public_key)
node_id = state.create_node(ping_interval=10)
state.set_node_public_key(node_id, public_key)
retrieved_node_ids = state.get_nodes(run_id)
retrieved_node_id = state.get_node_id(wrong_public_key)

Expand Down
Loading