-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathworkaround.py
107 lines (92 loc) · 3.51 KB
/
workaround.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
import flwr
import grpc
from flwr.client.grpc_client.connection import \
(contextmanager,
GRPC_MAX_MESSAGE_LENGTH,
Iterator,
Tuple,
Callable,
ServerMessage,
Queue,
ClientMessage,
Optional,
log,
INFO,
on_channel_state_change,
FlowerServiceStub,
DEBUG,
)
@contextmanager
def grpc_connection(
    server_address: str,
    max_message_length: int = GRPC_MAX_MESSAGE_LENGTH,
    root_certificates: Optional[bytes] = None,
) -> Iterator[Tuple[Callable[[], ServerMessage], Callable[[ClientMessage], None]]]:
    """Establish a gRPC connection to a Flower server and yield (receive, send).

    The channel is secure (SSL) when ``root_certificates`` is given and
    insecure otherwise. The bidirectional ``Join`` stream is opened with
    ``wait_for_ready=True``. The channel is always closed when the context
    exits, including when setting up the stream fails.

    Parameters
    ----------
    server_address : str
        The address of the server. If the Flower server runs on the same
        machine on port 8080, then `server_address` would be `"[::]:8080"`.
    max_message_length : int
        The maximum length of gRPC messages that can be exchanged with the
        Flower server; it is applied to both the send and the receive limit of
        the channel. The default should be sufficient for most models. Users
        who train very large models might need to increase this value. Note
        that the Flower server needs to be started with the same value
        (see `flwr.server.start_server`), otherwise it will not know about the
        increased limit and block larger messages.
        (default: 536_870_912, this equals 512MB)
    root_certificates : Optional[bytes] (default: None)
        The PEM-encoded root certificates as a byte string. If provided, a
        secure connection using the certificates will be established to an
        SSL-enabled Flower server; otherwise an insecure channel is opened.

    Yields
    ------
    receive, send : Callable, Callable
        ``receive()`` returns the next ``ServerMessage`` from the server's
        response stream (it calls ``next`` on the stream iterator).
        ``send(msg)`` enqueues a ``ClientMessage`` without blocking. The
        outgoing queue holds at most one message. Calling ``send`` again
        before gRPC has taken the previous message raises ``queue.Full``.

    Examples
    --------
    Establishing a SSL-enabled connection to the server:

    >>> from pathlib import Path
    >>> with grpc_connection(
    ...     server_address,
    ...     max_message_length=max_message_length,
    ...     root_certificates=Path("/crts/root.pem").read_bytes(),
    ... ) as conn:
    ...     receive, send = conn
    ...     server_message = receive()
    ...     # do something here
    ...     send(client_message)
    """
    # Possible options:
    # https://github.com/grpc/grpc/blob/v1.43.x/include/grpc/impl/codegen/grpc_types.h
    channel_options = [
        ("grpc.max_send_message_length", max_message_length),
        ("grpc.max_receive_message_length", max_message_length),
    ]

    if root_certificates is not None:
        ssl_channel_credentials = grpc.ssl_channel_credentials(root_certificates)
        channel = grpc.secure_channel(
            server_address, ssl_channel_credentials, options=channel_options
        )
        log(INFO, "Opened secure gRPC connection using certificates")
    else:
        channel = grpc.insecure_channel(server_address, options=channel_options)
        log(INFO, "Opened insecure gRPC connection (no certificates were passed)")

    # All work after the channel exists sits inside ``try``. That way the
    # channel is closed even if stream setup fails before we reach ``yield``.
    try:
        channel.subscribe(on_channel_state_change)

        # Outgoing messages are handed to gRPC through a single-slot queue.
        # ``iter(queue.get, None)`` turns it into the request iterator; putting
        # ``None`` on the queue would end the request stream.
        queue: Queue[ClientMessage] = Queue(  # pylint: disable=unsubscriptable-object
            maxsize=1
        )
        stub = FlowerServiceStub(channel)

        # NOTE(review): ``wait_for_ready=True`` appears to be the purpose of this
        # workaround. Per the gRPC docs, it makes the RPC wait for the channel to
        # become ready instead of failing fast while the server is unreachable.
        server_message_iterator: Iterator[ServerMessage] = stub.Join(
            iter(queue.get, None), wait_for_ready=True,
        )

        def receive() -> ServerMessage:
            """Return the next message from the server's response stream."""
            return next(server_message_iterator)

        def send(msg: ClientMessage) -> None:
            """Enqueue ``msg`` without blocking; raises ``queue.Full`` if busy."""
            queue.put(msg, block=False)

        yield (receive, send)
    finally:
        # Always release the channel, whether the caller's ``with`` body exited
        # normally or raised, or stream setup itself failed.
        channel.close()
        log(DEBUG, "gRPC channel closed")
# Patch it in: rebind the module attribute ``flwr.client.app.grpc_connection``
# to the ``grpc_connection`` defined in this file. Rebinding a module attribute
# only affects code that looks the name up on ``flwr.client.app`` at call time.
setattr(flwr.client.app, "grpc_connection", grpc_connection)