Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optionally use NumPy to allocate buffers #5750

Merged
merged 8 commits into from
Feb 24, 2022
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions distributed/comm/asyncio_tcp.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from .addressing import parse_host_port, unparse_host_port
from .core import Comm, CommClosedError, Connector, Listener
from .registry import Backend
from .utils import ensure_concrete_host, from_frames, to_frames
from .utils import ensure_concrete_host, from_frames, host_array, to_frames

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -121,7 +121,7 @@ def __init__(self, on_connection=None, min_read_size=128 * 1024):
self._using_default_buffer = True

self._default_len = max(min_read_size, 16) # need at least 16 bytes of buffer
self._default_buffer = memoryview(bytearray(self._default_len))
self._default_buffer = host_array(self._default_len)
# Index in default_buffer pointing to the first unparsed byte
self._default_start = 0
# Index in default_buffer pointing to the last written byte
Expand Down Expand Up @@ -258,7 +258,7 @@ def _parse_frame_lengths(self):
self._default_start += 8 * n_read

if n_read == needed:
self._frames = [memoryview(bytearray(n)) for n in self._frame_lengths]
self._frames = [host_array(n) for n in self._frame_lengths]
self._frame_index = 0
self._frame_nbytes_needed = (
self._frame_lengths[0] if self._frame_lengths else 0
Expand Down
10 changes: 8 additions & 2 deletions distributed/comm/tcp.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,13 @@
from .addressing import parse_host_port, unparse_host_port
from .core import Comm, CommClosedError, Connector, FatalCommClosedError, Listener
from .registry import Backend
from .utils import ensure_concrete_host, from_frames, get_tcp_server_address, to_frames
from .utils import (
ensure_concrete_host,
from_frames,
get_tcp_server_address,
host_array,
to_frames,
)

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -205,7 +211,7 @@ async def read(self, deserializers=None):
frames_nbytes = await stream.read_bytes(fmt_size)
(frames_nbytes,) = struct.unpack(fmt, frames_nbytes)

frames = memoryview(bytearray(frames_nbytes))
frames = host_array(frames_nbytes)
# Workaround for OpenSSL 1.0.2 (can drop with OpenSSL 1.1.1)
for i, j in sliding_window(
2, range(0, frames_nbytes + C_INT_MAX, C_INT_MAX)
Expand Down
13 changes: 2 additions & 11 deletions distributed/comm/ucx.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from .addressing import parse_host_port, unparse_host_port
from .core import Comm, CommClosedError, Connector, Listener
from .registry import Backend, backends
from .utils import ensure_concrete_host, from_frames, to_frames
from .utils import ensure_concrete_host, from_frames, host_array, to_frames

logger = logging.getLogger(__name__)

Expand All @@ -41,7 +41,6 @@
ucx_create_endpoint = None # type: ignore
ucx_create_listener = None # type: ignore

host_array = None
device_array = None
pre_existing_cuda_context = False
cuda_context_created = False
Expand All @@ -57,7 +56,7 @@ def synchronize_stream(stream=0):


def init_once():
global ucp, host_array, device_array
global ucp, device_array
global ucx_create_endpoint, ucx_create_listener
global pre_existing_cuda_context, cuda_context_created

Expand Down Expand Up @@ -115,14 +114,6 @@ def init_once():

ucp.init(options=ucx_config, env_takes_precedence=True)

# Find the function, `host_array()`, to use when allocating new host arrays
try:
import numpy

host_array = lambda n: numpy.empty((n,), dtype="u1")
except ImportError:
host_array = lambda n: bytearray(n)

# Find the function, `cuda_array()`, to use when allocating new CUDA arrays
try:
import rmm
Expand Down
20 changes: 20 additions & 0 deletions distributed/comm/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,26 @@
OFFLOAD_THRESHOLD = parse_bytes(OFFLOAD_THRESHOLD)


# Find the function, `host_array()`, to use when allocating new host arrays
try:
# Use NumPy, when available, to avoid memory initialization cost.
# A `bytearray` is zero-initialized using `calloc`, which we don't need.
# `np.empty` both skips the zero-initialization, and
# uses hugepages when available ( https://github.com/numpy/numpy/pull/14216 ).
import numpy
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We'd been hoping to avoid importing NumPy when it's not needed (#5729). This change feels like a fine reason to me to say "NumPy is a required import of distributed" and give up on that goal, but wanted to note it. I suppose we could defer the import into the host_array function, but that doesn't really gain us anything. cc @crusaderky

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Deferring the import to a function would just ensure that line is run every time we create a buffer, which adds a (small) performance hit (though larger on the first read).

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Exactly. I think we should leave the import at the top-level, just wanted to point it out.


def numpy_host_array(n: int) -> memoryview:
    """Return a writable ``memoryview`` over a new ``n``-byte NumPy buffer.

    The backing array comes from ``numpy.empty``, so its contents are
    uninitialized (not zero-filled) until something is written into it.
    """
    # dtype "u1" is one unsigned byte per element, so the view is n bytes long.
    backing = numpy.empty((n,), dtype="u1")
    return memoryview(backing)  # type: ignore

host_array = numpy_host_array
except ImportError:

def builtin_host_array(n: int) -> memoryview:
    """Return a writable ``memoryview`` over a new ``n``-byte ``bytearray``.

    Uses only built-in types; the ``bytearray`` starts out zero-filled.
    """
    storage = bytearray(n)
    return memoryview(storage)

host_array = builtin_host_array


async def to_frames(
msg,
allow_offload=True,
Expand Down