# Copyright 2022-2023 XProbe Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import asyncio
import concurrent.futures as futures
import functools
import logging
import os
import weakref
from typing import Any, Callable, Coroutine, Dict, List, Optional, Tuple, Type
import cloudpickle
import numpy as np
from ...nvutils import get_cuda_context, get_index_and_uuid
from ...serialization import deserialize
from ...serialization.aio import BUFFER_SIZES_NAME, AioSerializer, get_header_length
from ...utils import classproperty, implements, is_cuda_buffer, is_v6_ip, lazy_import
from ..message import _MessageBase
from .base import Channel, ChannelType, Client, Server
from .core import register_client, register_server
from .errors import ChannelClosed
ucp = lazy_import("ucxx")
numba_cuda = lazy_import("numba.cuda")
rmm = lazy_import("rmm")
_warning_suffix = (
"This is often the result of a CUDA-enabled library calling a CUDA runtime function before "
"spawning worker processes. Please make sure any such function calls don't happen "
"at import time or in the global scope of a program."
)
logger = logging.getLogger(__name__)
def synchronize_stream(stream: int = 0):
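    """Synchronize a CUDA stream (the default stream when ``stream`` is 0).

    Uses numba.cuda to obtain the current context, wrap the raw stream handle and
    block until all work queued on that stream has completed.
    """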
ctx = numba_cuda.current_context()
cu_stream = numba_cuda.driver.drvapi.cu_stream(stream)
stream = numba_cuda.driver.Stream(ctx, cu_stream, None)
stream.synchronize() # type: ignore
class UCXInitializer:
_inited = False
@staticmethod
def _get_options(ucx_config: dict) -> Tuple[dict, dict]:
"""
Get options and envs from ucx options in oscar config
"""
options = dict()
envs = dict()
        # If any of these flags are set to a truthy value, we configure basic TLS
        # settings for UCX; otherwise we leave UCX to its default configuration.
if any(ucx_config.get(name) for name in ["tcp", "nvlink", "infiniband"]):
if ucx_config.get("rdmacm"): # pragma: no cover
tls = "tcp"
tls_priority = "rdmacm"
else:
tls = "tcp"
tls_priority = "tcp"
# CUDA COPY can optionally be used with ucx -- we rely on the user
# to define when messages will include CUDA objects. Note:
# defining only the Infiniband flag will not enable cuda_copy
if any(
ucx_config.get(name) for name in ["nvlink", "cuda-copy"]
): # pragma: no cover
tls += ",cuda_copy"
if ucx_config.get("infiniband"): # pragma: no cover
tls = "ib," + tls
if ucx_config.get("nvlink"): # pragma: no cover
tls += ",cuda_ipc"
options["TLS"] = tls
options["SOCKADDR_TLS_PRIORITY"] = tls_priority
elif "UCX_TLS" in os.environ: # pragma: no cover
options["TLS"] = os.environ["UCX_TLS"]
for k, v in ucx_config.get("environment", dict()).items(): # pragma: no cover
# {"some-name": value} is translated to {"UCX_SOME_NAME": value}
key = f'UCX_{"_".join(s.upper() for s in k.split("-"))}'
opt_key = key[4:]
if opt_key in options:
logger.warning(
f"Ignoring {k}={v} (key={key}) in ucx.environment, "
f"preferring {opt_key}={options[opt_key]} "
"from high level options"
)
elif key in os.environ:
# This is only info because setting UCX configuration via
# environment variables is a reasonably common approach
logger.info(
f"Ignoring {k}={v} (key={key}) in ucx.environment, "
f"preferring {key}={os.environ[key]} from external environment"
)
else:
envs[key] = v
return options, envs
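    # Illustrative translation (assumed input, not executed): a config such as
    #   {"tcp": True, "nvlink": True, "infiniband": False,
    #    "environment": {"memtype-cache": "n"}}
    # would yield roughly
    #   options == {"TLS": "tcp,cuda_copy,cuda_ipc", "SOCKADDR_TLS_PRIORITY": "tcp"}
    #   envs == {"UCX_MEMTYPE_CACHE": "n"}
    # provided UCX_MEMTYPE_CACHE is not already set in the external environment.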
@staticmethod
def init(ucx_config: dict):
if UCXInitializer._inited:
return
options, envs = UCXInitializer._get_options(ucx_config)
# We ensure the CUDA context is created before initializing UCX. This can't
# be safely handled externally because communications start before
# preload scripts run.
# Precedence:
# 1. external environment
# 2. ucx_config (high level settings passed to ucp.init)
# 3. ucx_environment (low level settings equivalent to environment variables)
ucx_tls = os.environ.get("UCX_TLS", options.get("TLS", envs.get("UCX_TLS", "")))
if (
ucx_config.get("create-cuda-contex") is True
# This is not foolproof, if UCX_TLS=all we might require CUDA
# depending on configuration of UCX, but this is better than
# nothing
or ("cuda" in ucx_tls and "^cuda" not in ucx_tls)
):
if numba_cuda is None: # pragma: no cover
raise ImportError(
"CUDA support with UCX requires Numba for context management"
)
pre_existing_cuda_context = get_cuda_context()
if pre_existing_cuda_context.has_context:
dev = pre_existing_cuda_context.device_info
assert dev is not None
logger.warning(
f"A CUDA context for device {dev.device_index} ({str(dev.uuid)}) "
f"already exists on process ID {os.getpid()}. {_warning_suffix}"
)
numba_cuda.current_context()
cuda_context_created = get_cuda_context()
cuda_visible_device = get_index_and_uuid(
os.environ.get("CUDA_VISIBLE_DEVICES", "0").split(",")[0]
)
if (
cuda_context_created.has_context
and cuda_context_created.device_info.uuid != cuda_visible_device.uuid # type: ignore
): # pragma: no cover
cuda_context_created_dev = cuda_context_created.device_info
assert cuda_context_created_dev is not None
logger.warning(
f"Worker with process ID {os.getpid()} should have a CUDA context assigned to device "
f"{cuda_visible_device.device_index} ({str(cuda_visible_device.uuid)}), " # type: ignore
f"but instead the CUDA context is on device {cuda_context_created_dev.device_index} "
f"({str(cuda_context_created_dev.uuid)}). {_warning_suffix}"
)
original_environ = os.environ
new_environ = os.environ.copy()
new_environ.update(envs)
os.environ = new_environ # type: ignore
try:
# let UCX determine the appropriate transports
ucp.init()
finally:
os.environ = original_environ
UCXInitializer._inited = True
@staticmethod
def reset():
ucp.reset()
UCXInitializer._inited = False
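# Illustrative usage (assumed config shape, not executed at import time):
#   UCXInitializer.init({"tcp": True, "infiniband": False, "nvlink": False})
# initializes ucxx once per process; later calls are no-ops until reset() is called.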
class UCXChannel(Channel):
__slots__ = (
"ucp_endpoint",
"_closed",
"_has_close_callback",
"_send_lock",
"_recv_lock",
"__weakref__",
)
name = "ucx"
def __init__(
self,
ucp_endpoint: "ucp.Endpoint", # type: ignore
local_address: str | None = None,
dest_address: str | None = None,
compression: str | None = None,
):
super().__init__(
local_address=local_address,
dest_address=dest_address,
compression=compression,
)
self.ucp_endpoint = ucp_endpoint
self._send_lock = asyncio.Lock()
self._recv_lock = asyncio.Lock()
# When the UCX endpoint closes or errors the registered callback
# is called.
if hasattr(self.ucp_endpoint, "set_close_callback"):
ref = weakref.ref(self)
self.ucp_endpoint.set_close_callback(
functools.partial(UCXChannel._close_channel, ref)
)
self._closed = False
self._has_close_callback = True
        else:  # pragma: no cover
            self._closed = False
            self._has_close_callback = False
@staticmethod
def _close_channel(channel_ref: weakref.ReferenceType):
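        """Mark the channel as closed when the endpoint's close callback fires.

        A ``weakref`` is used so that the callback registered on the UCX endpoint
        does not keep the channel object alive.
        """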
channel = channel_ref()
if channel is not None:
channel._closed = True
async def _serialize(self, message: Any) -> List[bytes]:
compress = self.compression or 0
serializer = AioSerializer(message, compress=compress)
return await serializer.run()
@property
@implements(Channel.type)
def type(self) -> int:
return ChannelType.remote
@implements(Channel.send)
async def send(self, message: Any):
if self.closed:
raise ChannelClosed("UCX Endpoint is closed, unable to send message")
buffers = await self._serialize(message)
return await self.send_buffers(buffers)
@implements(Channel.recv)
async def recv(self):
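        """Receive one serialized message from the peer.

        Wire layout (mirroring ``send``): a small fixed-size info buffer encoding the
        header length, then the pickled header, then one buffer per payload entry
        (received into RMM device buffers for CUDA payloads, host memory otherwise).
        """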
async with self._recv_lock:
try:
info_buffer = np.empty(11, dtype="u1").data
await self.ucp_endpoint.recv(info_buffer)
head_length = get_header_length(info_buffer)
header_buffer = np.empty(head_length, dtype="u1").data
await self.ucp_endpoint.recv(header_buffer)
header = cloudpickle.loads(header_buffer)
is_cuda_buffers = header[0].get("is_cuda_buffers")
buffer_sizes = header[0].pop(BUFFER_SIZES_NAME)
buffers = []
for is_cuda_buffer, buf_size in zip(is_cuda_buffers, buffer_sizes):
if buf_size == 0: # pragma: no cover
buffers.append(bytes())
elif is_cuda_buffer:
cuda_buffer = rmm.DeviceBuffer(size=buf_size)
await self.ucp_endpoint.recv(cuda_buffer)
buffers.append(cuda_buffer)
else:
buffer = np.empty(buf_size, dtype="u1").data
await self.ucp_endpoint.recv(buffer)
buffers.append(buffer)
except BaseException as e:
if not self._closed:
                    # In addition to UCX exceptions, this may be a CancelledError or
                    # another "low-level" exception. The only safe thing to do is to abort.
self.abort()
raise ChannelClosed(
f"Connection closed by writer.\nInner exception: {e!r}"
) from e
else:
raise EOFError("Server closed already")
return deserialize(header, buffers)
async def send_buffers(self, buffers: list, meta: Optional[_MessageBase] = None):
try:
            # It is necessary to synchronize the default stream before we start
            # sending. We synchronize the default stream because UCX is not
            # stream-ordered and syncing the default stream will wait for other
            # non-blocking CUDA streams. Note this is only sufficient if the memory
            # being sent is not currently in use on non-blocking CUDA streams.
if any(is_cuda_buffer(buf) for buf in buffers):
# has GPU buffer
synchronize_stream(0)
meta_buffers = None
if meta:
meta_buffers = await self._serialize(meta)
async with self._send_lock:
if meta_buffers:
for buf in meta_buffers:
await self.ucp_endpoint.send(buf)
for buffer in buffers:
await self.ucp_endpoint.send(buffer)
except ucp.exceptions.UCXError: # pragma: no cover
self.abort()
raise ChannelClosed("While writing, the connection was closed")
async def recv_buffers(self, buffers: list):
async with self._recv_lock:
try:
for buffer in buffers:
await self.ucp_endpoint.recv(buffer)
except BaseException as e: # pragma: no cover
if not self._closed:
                    # In addition to UCX exceptions, this may be a CancelledError or
                    # another "low-level" exception. The only safe thing to do is to abort.
self.abort()
raise ChannelClosed(
f"Connection closed by writer.\nInner exception: {e!r}"
) from e
else:
raise EOFError("Server closed already")
def abort(self):
self._closed = True
if self.ucp_endpoint is not None:
self.ucp_endpoint.abort()
self.ucp_endpoint = None
@implements(Channel.close)
async def close(self):
self._closed = True
if self.ucp_endpoint is not None:
await self.ucp_endpoint.close()
# abort
self.ucp_endpoint.abort()
self.ucp_endpoint = None
@property
@implements(Channel.closed)
def closed(self):
        if self._has_close_callback:
            # The self._closed flag is separate from the endpoint's lifetime: even when
            # the endpoint has closed or errored, there may be messages in its buffer
            # still to be received, even though sending is not possible anymore.
return self._closed
else:
return self.ucp_endpoint is None
@register_server
class UCXServer(Server):
__slots__ = "host", "port", "_ucp_listener", "_channels", "_closed"
scheme = "ucx"
_ucp_listener: "ucp.Listener" # type: ignore
_channels: set[UCXChannel]
def __init__(
self,
host: str,
port: int,
ucp_listener: "ucp.Listener", # type: ignore
channel_handler: Callable[[Channel], Coroutine] | None = None,
):
super().__init__(f"{UCXServer.scheme}://{host}:{port}", channel_handler)
self.host = host
self.port = port
self._ucp_listener = ucp_listener
self._channels = set()
self._closed = asyncio.Event()
@classproperty
@implements(Server.client_type)
def client_type(self) -> Type["Client"]:
return UCXClient
@property
@implements(Server.channel_type)
def channel_type(self) -> int:
return ChannelType.remote
@staticmethod
async def create(config: Dict) -> "Server":
config = config.copy()
if "address" in config:
address = config.pop("address")
prefix = f"{UCXServer.scheme}://"
if address.startswith(prefix):
address = address[len(prefix) :]
host, port = address.rsplit(":", 1)
port = int(port)
else:
host = config.pop("host")
port = int(config.pop("port"))
_host = host
if config.pop("listen_elastic_ip", False):
            # The Actor.address is announced to clients but is not an address on this
            # host, so we cannot actually listen on it; keep SocketServer.host untouched
            # so that Actor.address is not changed, and listen on the wildcard instead.
if is_v6_ip(host):
_host = "::"
else:
_host = "0.0.0.0"
handle_channel = config.pop("handle_channel")
        # initialize UCX using the ucx section of the config
UCXInitializer.init(config.get("ucx", dict()))
async def serve_forever(client_ucp_endpoint: "ucp.Endpoint"): # type: ignore
try:
await server.on_connected(
client_ucp_endpoint, local_address="%s:%d" % (_host, port)
)
except ChannelClosed: # pragma: no cover
logger.exception("Connection closed before handshake completed")
return
ucp_listener = ucp.create_listener(serve_forever, port=port)
# get port of the ucp listener if not specified
if not port:
port = ucp_listener.port
server = UCXServer(host, port, ucp_listener, channel_handler=handle_channel)
return server
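    # Illustrative sketch (assumed config keys, mirroring the parsing above):
    #   server = await UCXServer.create({
    #       "address": "ucx://127.0.0.1:11111",
    #       "handle_channel": handle_channel,  # coroutine taking a Channel
    #       "ucx": {"tcp": True},
    #   })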
@classmethod
def parse_config(cls, config: dict) -> dict:
return config
@implements(Server.start)
async def start(self):
pass
@implements(Server.join)
async def join(self, timeout=None):
wait_coro = self._closed.wait()
try:
await asyncio.wait_for(wait_coro, timeout=timeout)
except (futures.TimeoutError, asyncio.TimeoutError):
pass
@implements(Server.on_connected)
async def on_connected(self, *args, **kwargs):
(ucp_endpoint,) = args
local_address = kwargs.pop("local_address", None)
dest_address = kwargs.pop("dest_address", None)
if kwargs: # pragma: no cover
raise TypeError(
f"{type(self).__name__} got unexpected "
f'arguments: {",".join(kwargs)}'
)
channel = UCXChannel(
ucp_endpoint, local_address=local_address, dest_address=dest_address
)
self._channels.add(channel)
        # hand the channel over to the channel handler
try:
await self.channel_handler(channel)
finally:
if not channel.closed:
await channel.close()
            # remove the channel once the handler exits
self._channels.discard(channel)
logger.debug("Channel exit: %s", channel.info)
@implements(Server.stop)
async def stop(self):
self._ucp_listener.close()
# close all channels
await asyncio.gather(
*(channel.close() for channel in self._channels if not channel.closed)
)
self._channels.clear()
self._ucp_listener = None
self._closed.set()
@property
@implements(Server.stopped)
def stopped(self) -> bool:
return self._ucp_listener is None
@register_client
class UCXClient(Client):
__slots__ = ()
scheme = UCXServer.scheme
channel: UCXChannel
@classmethod
def parse_config(cls, config: dict) -> dict:
return config
@staticmethod
@implements(Client.connect)
async def connect(
dest_address: str, local_address: str | None = None, **kwargs
) -> "Client":
prefix = f"{UCXClient.scheme}://"
if dest_address.startswith(prefix):
dest_address = dest_address[len(prefix) :]
host, port_str = dest_address.rsplit(":", 1)
port = int(port_str)
kwargs = kwargs.copy()
ucx_config = kwargs.pop("config", dict()).get("ucx", dict())
UCXInitializer.init(ucx_config)
try:
ucp_endpoint = await ucp.create_endpoint(host, port)
except ucp.exceptions.UCXError as e: # pragma: no cover
raise ChannelClosed(
f"Connection closed before handshake completed, "
f"local address: {local_address}, dest address: {dest_address}"
) from e
channel = UCXChannel(
ucp_endpoint, local_address=local_address, dest_address=dest_address
)
return UCXClient(local_address, dest_address, channel)
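    # Illustrative sketch (assumed usage):
    #   client = await UCXClient.connect("ucx://127.0.0.1:11111")
    #   await client.channel.send(message)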
async def send_buffers(self, buffers: list, meta: _MessageBase):
return await self.channel.send_buffers(buffers, meta)