From f61649949fac17a734c0e4dda40c19da0c9834f0 Mon Sep 17 00:00:00 2001 From: "Mads R. B. Kristensen" Date: Mon, 7 Nov 2022 12:30:56 +0100 Subject: [PATCH 01/31] Porting spillabe buffer and manager from #11553 --- ci/gpu/build.sh | 4 + .../source/developer_guide/library_design.md | 9 +- python/cudf/cudf/_lib/binaryop.pyx | 3 + python/cudf/cudf/_lib/column.pxd | 6 +- python/cudf/cudf/_lib/column.pyx | 104 +++- python/cudf/cudf/_lib/copying.pyx | 21 +- python/cudf/cudf/_lib/groupby.pyx | 6 +- python/cudf/cudf/_lib/transform.pyx | 6 +- python/cudf/cudf/_lib/transpose.pyx | 6 +- python/cudf/cudf/_lib/unary.pyx | 7 + python/cudf/cudf/core/buffer/__init__.py | 3 +- python/cudf/cudf/core/buffer/buffer.py | 22 + python/cudf/cudf/core/buffer/spill_manager.py | 306 +++++++++++ .../cudf/cudf/core/buffer/spillable_buffer.py | 473 ++++++++++++++++++ python/cudf/cudf/core/buffer/utils.py | 65 ++- python/cudf/cudf/core/column/column.py | 4 +- python/cudf/cudf/core/column/decimal.py | 4 +- python/cudf/cudf/core/df_protocol.py | 22 +- python/cudf/cudf/core/groupby/groupby.py | 4 + python/cudf/cudf/options.py | 71 +++ python/cudf/cudf/tests/conftest.py | 14 + python/cudf/cudf/tests/test_buffer.py | 12 +- python/cudf/cudf/tests/test_groupby.py | 6 +- python/cudf/cudf/tests/test_spilling.py | 464 +++++++++++++++++ python/cudf/cudf/utils/utils.py | 2 +- .../strings_udf/_lib/cudf_jit_udf.pyx | 2 +- 26 files changed, 1602 insertions(+), 44 deletions(-) create mode 100644 python/cudf/cudf/core/buffer/spill_manager.py create mode 100644 python/cudf/cudf/core/buffer/spillable_buffer.py create mode 100644 python/cudf/cudf/tests/test_spilling.py diff --git a/ci/gpu/build.sh b/ci/gpu/build.sh index 500c3bdbcc5..516d369f5d9 100755 --- a/ci/gpu/build.sh +++ b/ci/gpu/build.sh @@ -282,6 +282,10 @@ conda list gpuci_logger "Python py.test for cuDF" py.test -n 8 --cache-clear --basetemp="$WORKSPACE/cudf-cuda-tmp" --ignore="$WORKSPACE/python/cudf/cudf/benchmarks" --junitxml="$WORKSPACE/junit-cudf.xml" 
-v --cov-config="$WORKSPACE/python/cudf/.coveragerc" --cov=cudf --cov-report=xml:"$WORKSPACE/python/cudf/cudf-coverage.xml" --cov-report term --dist=loadscope tests +gpuci_logger "Python py.tests for cuDF with spilling (CUDF_SPILL_DEVICE_LIMIT=1)" +# Due to time concerns, we only run a limited set of tests +CUDF_SPILL=on CUDF_SPILL_DEVICE_LIMIT=1 py.test -n 8 --cache-clear --basetemp="$WORKSPACE/cudf-cuda-tmp" --ignore="$WORKSPACE/python/cudf/cudf/benchmarks" -v --cov-config="$WORKSPACE/python/cudf/.coveragerc" --cov-append --cov=cudf --cov-report=xml:"$WORKSPACE/python/cudf/cudf-coverage.xml" --cov-report term --dist=loadscope tests/test_binops.py tests/test_dataframe.py tests/test_buffer.py tests/test_onehot.py tests/test_reshape.py + cd "$WORKSPACE/python/dask_cudf" gpuci_logger "Python py.test for dask-cudf" py.test -n 8 --cache-clear --basetemp="$WORKSPACE/dask-cudf-cuda-tmp" --junitxml="$WORKSPACE/junit-dask-cudf.xml" -v --cov-config=.coveragerc --cov=dask_cudf --cov-report=xml:"$WORKSPACE/python/dask_cudf/dask-cudf-coverage.xml" --cov-report term dask_cudf diff --git a/docs/cudf/source/developer_guide/library_design.md b/docs/cudf/source/developer_guide/library_design.md index 2f0fb5d86fc..be233edf200 100644 --- a/docs/cudf/source/developer_guide/library_design.md +++ b/docs/cudf/source/developer_guide/library_design.md @@ -203,7 +203,6 @@ For instance, all numerical types (floats and ints of different widths) are all ### Buffer - `Column`s are in turn composed of one or more `Buffer`s. A `Buffer` represents a single, contiguous, device memory allocation owned by another object. A `Buffer` constructed from a preexisting device memory allocation (such as a CuPy array) will view that memory. @@ -212,6 +211,14 @@ Conversely, when constructed from a host object, The data is then copied from the host object into the newly allocated device memory. You can read more about [device memory allocation with RMM here](https://github.com/rapidsai/rmm). 
+ +### Spilling to host memory + +Setting the environment variable `CUDF_SPILL=on` enables automatic spilling (and "unspilling") of buffers from +device to host to enable out-of-memory computation, i.e., computing on objects that occupy more memory than is +available on the GPU. + + ## The Cython layer The lowest level of cuDF is its interaction with `libcudf` via Cython. diff --git a/python/cudf/cudf/_lib/binaryop.pyx b/python/cudf/cudf/_lib/binaryop.pyx index 995fdc7e315..9455565a74f 100644 --- a/python/cudf/cudf/_lib/binaryop.pyx +++ b/python/cudf/cudf/_lib/binaryop.pyx @@ -22,6 +22,7 @@ from cudf._lib.cpp.types cimport data_type, type_id from cudf._lib.types cimport dtype_to_data_type, underlying_type_t_type_id from cudf.api.types import is_scalar, is_string_dtype +from cudf.core.buffer import with_spill_lock cimport cudf._lib.cpp.binaryop as cpp_binaryop from cudf._lib.cpp.binaryop cimport binary_operator @@ -156,6 +157,7 @@ cdef binaryop_s_v(DeviceScalar lhs, Column rhs, return Column.from_unique_ptr(move(c_result)) +@with_spill_lock() def binaryop(lhs, rhs, op, dtype): """ Dispatches a binary op call to the appropriate libcudf function: @@ -203,6 +205,7 @@ def binaryop(lhs, rhs, op, dtype): return result +@with_spill_lock() def binaryop_udf(Column lhs, Column rhs, udf_ptx, dtype): """ Apply a user-defined binary operator (a UDF) defined in `udf_ptx` on diff --git a/python/cudf/cudf/_lib/column.pxd b/python/cudf/cudf/_lib/column.pxd index 2df958466c6..f8f851bfe0f 100644 --- a/python/cudf/cudf/_lib/column.pxd +++ b/python/cudf/cudf/_lib/column.pxd @@ -1,4 +1,4 @@ -# Copyright (c) 2020, NVIDIA CORPORATION. +# Copyright (c) 2020-2022, NVIDIA CORPORATION. 
from libcpp cimport bool from libcpp.memory cimport unique_ptr @@ -28,7 +28,9 @@ cdef class Column: cdef mutable_column_view mutable_view(self) except * @staticmethod - cdef Column from_unique_ptr(unique_ptr[column] c_col) + cdef Column from_unique_ptr( + unique_ptr[column] c_col, bint data_ptr_exposed=* + ) @staticmethod cdef Column from_column_view(column_view, object) diff --git a/python/cudf/cudf/_lib/column.pyx b/python/cudf/cudf/_lib/column.pyx index 918d786fb83..9e5b62ab404 100644 --- a/python/cudf/cudf/_lib/column.pyx +++ b/python/cudf/cudf/_lib/column.pyx @@ -8,7 +8,14 @@ import rmm import cudf import cudf._lib as libcudf from cudf.api.types import is_categorical_dtype -from cudf.core.buffer import Buffer, as_buffer +from cudf.core.buffer import ( + Buffer, + SpillableBuffer, + SpillLock, + as_buffer, + get_spill_lock, + with_spill_lock, +) from cpython.buffer cimport PyObject_CheckBuffer from libc.stdint cimport uintptr_t @@ -95,7 +102,11 @@ cdef class Column: if self._data is None: start = self.offset * self.dtype.itemsize end = start + self.size * self.dtype.itemsize - self._data = self.base_data[start:end] + if start == 0 and end == self.base_data.size: + # `data` spans all of `base_data` + self._data = self.base_data + else: + self._data = self.base_data[start:end] return self._data @property @@ -249,7 +260,8 @@ cdef class Column: @property def null_count(self): if self._null_count is None: - self._null_count = self.compute_null_count() + with with_spill_lock(): + self._null_count = self.compute_null_count() return self._null_count @property @@ -381,7 +393,14 @@ cdef class Column: cdef vector[column_view] children cdef void* data - data = (col.base_data_ptr) + if col.base_data is None: + data = NULL + elif isinstance(col.base_data, SpillableBuffer): + data = (col.base_data).get_ptr( + spill_lock=get_spill_lock() + ) + else: + data = (col.base_data.ptr) cdef Column child_column if col.base_children: @@ -406,7 +425,16 @@ cdef class Column: children) 
@staticmethod - cdef Column from_unique_ptr(unique_ptr[column] c_col): + cdef Column from_unique_ptr( + unique_ptr[column] c_col, bint data_ptr_exposed=False + ): + """Create a Column from a column + + Typically, this is called on the result of a libcudf operation. + If the data of the libcudf result has been exposed, set + `data_ptr_exposed=True` to expose the memory of the returned Column + as well. + """ cdef column_view view = c_col.get()[0].view() cdef libcudf_types.type_id tid = view.type().id() cdef libcudf_types.data_type c_dtype @@ -431,20 +459,30 @@ cdef class Column: # After call to release(), c_col is unusable cdef column_contents contents = move(c_col.get()[0].release()) - data = DeviceBuffer.c_from_unique_ptr(move(contents.data)) - data = as_buffer(data) + data = as_buffer( + DeviceBuffer.c_from_unique_ptr(move(contents.data)), + exposed=data_ptr_exposed + ) if null_count > 0: - mask = DeviceBuffer.c_from_unique_ptr(move(contents.null_mask)) - mask = as_buffer(mask) + mask = as_buffer( + DeviceBuffer.c_from_unique_ptr(move(contents.null_mask)), + exposed=data_ptr_exposed + ) else: mask = None cdef vector[unique_ptr[column]] c_children = move(contents.children) - children = () + children = [] if c_children.size() != 0: - children = tuple(Column.from_unique_ptr(move(c_children[i])) - for i in range(c_children.size())) + # Because of a bug in Cython, we cannot set the optional + # `data_ptr_exposed` argument within a comprehension. 
+ for i in range(c_children.size()): + child = Column.from_unique_ptr( + move(c_children[i]), + data_ptr_exposed=data_ptr_exposed + ) + children.append(child) return cudf.core.column.build_column( data, @@ -452,7 +490,7 @@ cdef class Column: mask=mask, size=size, null_count=null_count, - children=children + children=tuple(children) ) @staticmethod @@ -474,6 +512,7 @@ cdef class Column: size = cv.size() offset = cv.offset() dtype = dtype_from_column_view(cv) + dtype_itemsize = dtype.itemsize if hasattr(dtype, "itemsize") else 1 data_ptr = (cv.head[void]()) data = None @@ -484,19 +523,45 @@ cdef class Column: data_owner = owner.base_data mask_owner = mask_owner.base_mask base_size = owner.base_size - + base_nbytes = base_size * dtype_itemsize if data_ptr: if data_owner is None: data = as_buffer( rmm.DeviceBuffer(ptr=data_ptr, - size=(size+offset) * dtype.itemsize) + size=(size+offset) * dtype_itemsize) ) + elif ( + # This is an optimization to avoid creating a new + # SpillableBuffer that represent the same memory + # as the owner. + column_owner and + isinstance(data_owner, SpillableBuffer) and + # We have to make sure that `data_owner` is already spill + # locked and that its pointer is the same as `data_ptr` + # _without_ exposing the buffer permanently. + not data_owner.spillable and + data_owner.get_ptr(spill_lock=SpillLock()) == data_ptr and + data_owner.size == base_nbytes + ): + data = data_owner else: + # At this point we don't know the relationship between data_ptr + # and data_owner thus we mark both of them exposed. + # TODO: try to discover their relationship and create a + # SpillableBufferSlice instead. 
data = as_buffer( - data=data_ptr, - size=(base_size) * dtype.itemsize, - owner=data_owner + data_ptr, + size=base_nbytes, + owner=data_owner, + exposed=True, ) + if isinstance(data_owner, SpillableBuffer): + if data_owner.is_spilled: + raise ValueError( + f"{data_owner} is spilled, which invalidates " + f"the exposed data_ptr ({hex(data_ptr)})" + ) + data_owner.ptr # accessing the pointer marks it exposed. else: data = as_buffer( rmm.DeviceBuffer(ptr=data_ptr, size=0) @@ -538,7 +603,8 @@ cdef class Column: mask = as_buffer( data=mask_ptr, size=bitmask_allocation_size_bytes(base_size), - owner=mask_owner + owner=mask_owner, + exposed=True ) if cv.has_nulls(): diff --git a/python/cudf/cudf/_lib/copying.pyx b/python/cudf/cudf/_lib/copying.pyx index d9a7a5b8754..7cd811caa26 100644 --- a/python/cudf/cudf/_lib/copying.pyx +++ b/python/cudf/cudf/_lib/copying.pyx @@ -12,7 +12,7 @@ from libcpp.vector cimport vector from rmm._lib.device_buffer cimport DeviceBuffer import cudf -from cudf.core.buffer import Buffer, as_buffer +from cudf.core.buffer import Buffer, as_buffer, with_spill_lock from cudf._lib.column cimport Column @@ -64,6 +64,7 @@ def _gather_map_is_valid( return gm_min >= -nrows and gm_max < nrows +@with_spill_lock() def copy_column(Column input_column): """ Deep copies a column @@ -132,6 +133,7 @@ def _copy_range(Column input_column, return Column.from_unique_ptr(move(c_result)) +@with_spill_lock() def copy_range(Column input_column, Column target_column, size_type input_begin, @@ -164,6 +166,7 @@ def copy_range(Column input_column, input_begin, input_end, target_begin) +@with_spill_lock() def gather( list columns, Column gather_map, @@ -231,6 +234,7 @@ cdef scatter_column(list source_columns, return columns_from_unique_ptr(move(c_result)) +@with_spill_lock() def scatter(list sources, Column scatter_map, list target_columns, bool bounds_check=True): """ @@ -271,6 +275,7 @@ def scatter(list sources, Column scatter_map, list target_columns, ) +@with_spill_lock() 
def column_empty_like(Column input_column): cdef column_view input_column_view = input_column.view() @@ -282,6 +287,7 @@ def column_empty_like(Column input_column): return Column.from_unique_ptr(move(c_result)) +@with_spill_lock() def column_allocate_like(Column input_column, size=None): cdef size_type c_size = 0 @@ -306,6 +312,7 @@ def column_allocate_like(Column input_column, size=None): return Column.from_unique_ptr(move(c_result)) +@with_spill_lock() def columns_empty_like(list input_columns): cdef table_view input_table_view = table_view_from_columns(input_columns) cdef unique_ptr[table] c_result @@ -316,6 +323,7 @@ def columns_empty_like(list input_columns): return columns_from_unique_ptr(move(c_result)) +@with_spill_lock() def column_slice(Column input_column, object indices): cdef column_view input_column_view = input_column.view() @@ -345,6 +353,7 @@ def column_slice(Column input_column, object indices): return result +@with_spill_lock() def columns_slice(list input_columns, list indices): """ Given a list of input columns, return columns sliced by ``indices``. 
@@ -371,6 +380,7 @@ def columns_slice(list input_columns, list indices): ] +@with_spill_lock() def column_split(Column input_column, object splits): cdef column_view input_column_view = input_column.view() @@ -402,6 +412,7 @@ def column_split(Column input_column, object splits): return result +@with_spill_lock() def columns_split(list input_columns, object splits): cdef table_view input_table_view = table_view_from_columns(input_columns) @@ -508,6 +519,7 @@ def _copy_if_else_scalar_scalar(DeviceScalar lhs, return Column.from_unique_ptr(move(c_result)) +@with_spill_lock() def copy_if_else(object lhs, object rhs, Column boolean_mask): if isinstance(lhs, Column): @@ -575,6 +587,7 @@ def _boolean_mask_scatter_scalar(list input_scalars, list target_columns, return columns_from_unique_ptr(move(c_result)) +@with_spill_lock() def boolean_mask_scatter(list input_, list target_columns, Column boolean_mask): """Copy the target columns, replacing masked rows with input data. @@ -607,6 +620,7 @@ def boolean_mask_scatter(list input_, list target_columns, ) +@with_spill_lock() def shift(Column input, int offset, object fill_value=None): cdef DeviceScalar fill @@ -643,6 +657,7 @@ def shift(Column input, int offset, object fill_value=None): return Column.from_unique_ptr(move(c_output)) +@with_spill_lock() def get_element(Column input_column, size_type index): cdef column_view col_view = input_column.view() @@ -657,6 +672,7 @@ def get_element(Column input_column, size_type index): ) +@with_spill_lock() def segmented_gather(Column source_column, Column gather_map): cdef shared_ptr[lists_column_view] source_LCV = ( make_shared[lists_column_view](source_column.view()) @@ -724,7 +740,8 @@ cdef class _CPackedColumns: gpu_data = as_buffer( data=self.gpu_data_ptr, size=self.gpu_data_size, - owner=self + owner=self, + exposed=True ) data_header, data_frames = gpu_data.serialize() header["data"] = data_header diff --git a/python/cudf/cudf/_lib/groupby.pyx b/python/cudf/cudf/_lib/groupby.pyx 
index e6fbefaeee9..bea39c06387 100644 --- a/python/cudf/cudf/_lib/groupby.pyx +++ b/python/cudf/cudf/_lib/groupby.pyx @@ -10,6 +10,7 @@ from cudf.api.types import ( is_string_dtype, is_struct_dtype, ) +from cudf.core.buffer import with_spill_lock from libcpp cimport bool from libcpp.memory cimport unique_ptr @@ -86,13 +87,16 @@ cdef class GroupBy: def __cinit__(self, list keys, bool dropna=True, *args, **kwargs): cdef libcudf_types.null_policy c_null_handling + cdef table_view keys_view if dropna: c_null_handling = libcudf_types.null_policy.EXCLUDE else: c_null_handling = libcudf_types.null_policy.INCLUDE - cdef table_view keys_view = table_view_from_columns(keys) + with with_spill_lock() as spill_lock: + keys_view = table_view_from_columns(keys) + self._spill_lock = spill_lock with nogil: self.c_obj.reset( diff --git a/python/cudf/cudf/_lib/transform.pyx b/python/cudf/cudf/_lib/transform.pyx index b95bce0db58..1fa68282c3d 100644 --- a/python/cudf/cudf/_lib/transform.pyx +++ b/python/cudf/cudf/_lib/transform.pyx @@ -146,7 +146,11 @@ def one_hot_encode(Column input_column, Column categories): libcudf_transform.one_hot_encode(c_view_input, c_view_categories) ) - owner = Column.from_unique_ptr(move(c_result.first)) + # Notice, the data pointer of `owner` has been exposed + # through `c_result.second` at this point. + owner = Column.from_unique_ptr( + move(c_result.first), data_ptr_exposed=True + ) pylist_categories = categories.to_arrow().to_pylist() encodings, _ = data_from_table_view( diff --git a/python/cudf/cudf/_lib/transpose.pyx b/python/cudf/cudf/_lib/transpose.pyx index b9eea6169bd..51e49b1f27a 100644 --- a/python/cudf/cudf/_lib/transpose.pyx +++ b/python/cudf/cudf/_lib/transpose.pyx @@ -20,7 +20,11 @@ def transpose(list source_columns): with nogil: c_result = move(cpp_transpose(c_input)) - result_owner = Column.from_unique_ptr(move(c_result.first)) + # Notice, the data pointer of `result_owner` has been exposed + # through `c_result.second` at this point. 
+ result_owner = Column.from_unique_ptr( + move(c_result.first), data_ptr_exposed=True + ) return columns_from_table_view( c_result.second, owners=[result_owner] * c_result.second.num_columns() diff --git a/python/cudf/cudf/_lib/unary.pyx b/python/cudf/cudf/_lib/unary.pyx index 52f0a804b2a..b1f5e3bd101 100644 --- a/python/cudf/cudf/_lib/unary.pyx +++ b/python/cudf/cudf/_lib/unary.pyx @@ -3,6 +3,7 @@ from enum import IntEnum from cudf.api.types import is_decimal_dtype +from cudf.core.buffer import with_spill_lock from libcpp.memory cimport unique_ptr from libcpp.utility cimport move @@ -43,6 +44,7 @@ class UnaryOp(IntEnum): NOT = unary_operator.NOT +@with_spill_lock() def unary_operation(Column input, object op): cdef column_view c_input = input.view() cdef unary_operator c_op = ( @@ -60,6 +62,7 @@ def unary_operation(Column input, object op): return Column.from_unique_ptr(move(c_result)) +@with_spill_lock() def is_null(Column input): cdef column_view c_input = input.view() cdef unique_ptr[column] c_result @@ -70,6 +73,7 @@ def is_null(Column input): return Column.from_unique_ptr(move(c_result)) +@with_spill_lock() def is_valid(Column input): cdef column_view c_input = input.view() cdef unique_ptr[column] c_result @@ -80,6 +84,7 @@ def is_valid(Column input): return Column.from_unique_ptr(move(c_result)) +@with_spill_lock() def cast(Column input, object dtype=np.float64): cdef column_view c_input = input.view() cdef data_type c_dtype = dtype_to_data_type(dtype) @@ -95,6 +100,7 @@ def cast(Column input, object dtype=np.float64): return result +@with_spill_lock() def is_nan(Column input): cdef column_view c_input = input.view() cdef unique_ptr[column] c_result @@ -105,6 +111,7 @@ def is_nan(Column input): return Column.from_unique_ptr(move(c_result)) +@with_spill_lock() def is_non_nan(Column input): cdef column_view c_input = input.view() cdef unique_ptr[column] c_result diff --git a/python/cudf/cudf/core/buffer/__init__.py b/python/cudf/cudf/core/buffer/__init__.py 
index a73bc69ffb5..044f2fa0478 100644 --- a/python/cudf/cudf/core/buffer/__init__.py +++ b/python/cudf/cudf/core/buffer/__init__.py @@ -1,4 +1,5 @@ # Copyright (c) 2022, NVIDIA CORPORATION. from cudf.core.buffer.buffer import Buffer, cuda_array_interface_wrapper -from cudf.core.buffer.utils import as_buffer +from cudf.core.buffer.spillable_buffer import SpillableBuffer, SpillLock +from cudf.core.buffer.utils import as_buffer, get_spill_lock, with_spill_lock diff --git a/python/cudf/cudf/core/buffer/buffer.py b/python/cudf/cudf/core/buffer/buffer.py index 73e589ebb8e..29534ab5529 100644 --- a/python/cudf/cudf/core/buffer/buffer.py +++ b/python/cudf/cudf/core/buffer/buffer.py @@ -148,6 +148,28 @@ def _from_host_memory(cls: Type[T], data: Any) -> T: # Create from device memory return cls._from_device_memory(buf) + @classmethod + def _from_any_memory(cls: Type[T], data: Any) -> T: + """Create a Buffer from device or host memory + + If data exposes `__cuda_array_interface__`, we delegate to the + `_from_device_memory` constructor otherwise `_from_host_memory`. + + Parameters + ---------- + data : Any + An object that represents device or host memory. + + Returns + ------- + Buffer + Buffer representing `data`. + """ + + if hasattr(data, "__cuda_array_interface__"): + return cls._from_device_memory(data) + return cls._from_host_memory(data) + def _getitem(self, offset: int, size: int) -> Buffer: """ Sub-classes can overwrite this to implement __getitem__ diff --git a/python/cudf/cudf/core/buffer/spill_manager.py b/python/cudf/cudf/core/buffer/spill_manager.py new file mode 100644 index 00000000000..821cae54128 --- /dev/null +++ b/python/cudf/cudf/core/buffer/spill_manager.py @@ -0,0 +1,306 @@ +# Copyright (c) 2022, NVIDIA CORPORATION. 
+ +from __future__ import annotations + +import gc +import io +import threading +import traceback +import warnings +import weakref +from typing import List, Optional, Tuple + +import rmm.mr + +from cudf.core.buffer.spillable_buffer import SpillableBuffer +from cudf.options import get_option +from cudf.utils.string import format_bytes + + +def get_traceback() -> str: + """Pretty print current traceback to a string""" + with io.StringIO() as f: + traceback.print_stack(file=f) + f.seek(0) + return f.read() + + +def get_rmm_memory_resource_stack( + mr: rmm.mr.DeviceMemoryResource, +) -> List[rmm.mr.DeviceMemoryResource]: + """Get the RMM resource stack + + Parameters + ---------- + mr : rmm.mr.DeviceMemoryResource + Top of the resource stack + + Return + ------ + list + List of RMM resources + """ + + if hasattr(mr, "upstream_mr"): + return [mr] + get_rmm_memory_resource_stack(mr.upstream_mr) + return [mr] + + +class SpillManager: + """Manager of spillable buffers. + + This class implements tracking of all known spillable buffers, on-demand + spilling of said buffers, and (optionally) maintains a memory usage limit. + + When `spill_on_demand=True`, the manager registers an RMM out-of-memory + error handler, which will spill spillable buffers in order to free up + memory. + + When `device_memory_limit` is not None, the manager will try to keep the + device memory usage below the specified limit by continuously spilling + spillable buffers, which will introduce a modest overhead. + + Parameters + ---------- + spill_on_demand : bool + Enable spill on demand. The global manager sets this to the value of + `CUDF_SPILL_ON_DEMAND` or False. + device_memory_limit: int, optional + If not None, this is the device memory limit in bytes that triggers + device to host spilling. The global manager sets this to the value + of `CUDF_SPILL_DEVICE_LIMIT` or None. 
+ """ + + _base_buffers: weakref.WeakValueDictionary[int, SpillableBuffer] + + def __init__( + self, + *, + spill_on_demand: bool = False, + device_memory_limit: int = None, + ) -> None: + self._lock = threading.Lock() + self._base_buffers = weakref.WeakValueDictionary() + self._id_counter = 0 + self._spill_on_demand = spill_on_demand + self._device_memory_limit = device_memory_limit + + if self._spill_on_demand: + # Set the RMM out-of-memory handle if not already set + mr = rmm.mr.get_current_device_resource() + if all( + not isinstance(m, rmm.mr.FailureCallbackResourceAdaptor) + for m in get_rmm_memory_resource_stack(mr) + ): + rmm.mr.set_current_device_resource( + rmm.mr.FailureCallbackResourceAdaptor( + mr, self._out_of_memory_handle + ) + ) + + def _out_of_memory_handle(self, nbytes: int, *, retry_once=True) -> bool: + """Try to handle an out-of-memory error by spilling + + This can by used as the callback function to RMM's + `FailureCallbackResourceAdaptor` + + Parameters + ---------- + nbytes : int + Number of bytes to try to spill. + retry_once : bool, optional + If True, call `gc.collect()` and retry once. + + Return + ------ + bool + True if any buffers were freed otherwise False. + + Warning + ------- + In order to avoid deadlock, this function should not lock + already locked buffers. + """ + + # Keep spilling until `nbytes` been spilled + total_spilled = 0 + while total_spilled < nbytes: + spilled = self.spill_device_memory() + if spilled == 0: + break # No more to spill! 
+ total_spilled += spilled + + if total_spilled > 0: + return True # Ask RMM to retry the allocation + + if retry_once: + # Let's collect garbage and try one more time + gc.collect() + return self._out_of_memory_handle(nbytes, retry_once=False) + + # TODO: write to log instead of stdout + print( + f"[WARNING] RMM allocation of {format_bytes(nbytes)} bytes " + "failed, spill-on-demand couldn't find any device memory to " + f"spill:\n{repr(self)}\ntraceback:\n{get_traceback()}" + ) + return False # Since we didn't find anything to spill, we give up + + def add(self, buffer: SpillableBuffer) -> None: + """Add buffer to the set of managed buffers + + The manager keeps a weak reference to the buffer + + Parameters + ---------- + buffer : SpillableBuffer + The buffer to manage + """ + if buffer.size > 0 and not buffer.exposed: + with self._lock: + self._base_buffers[self._id_counter] = buffer + self._id_counter += 1 + self.spill_to_device_limit() + + def base_buffers( + self, order_by_access_time: bool = False + ) -> Tuple[SpillableBuffer, ...]: + """Get all managed buffers + + Parameters + ---------- + order_by_access_time : bool, optional + Order the buffers by access time (ascending order) + + Return + ------ + tuple + Tuple of buffers + """ + with self._lock: + ret = tuple(self._base_buffers.values()) + if order_by_access_time: + ret = tuple(sorted(ret, key=lambda b: b.last_accessed)) + return ret + + def spill_device_memory(self) -> int: + """Try to spill device memory + + This function is safe to call during spill-on-demand + since it does not lock buffers already locked. + + Return + ------ + int + Number of bytes spilled. 
+ """ + for buf in self.base_buffers(order_by_access_time=True): + if buf.lock.acquire(blocking=False): + try: + if not buf.is_spilled and buf.spillable: + buf.__spill__(target="cpu") + return buf.size + finally: + buf.lock.release() + return 0 + + def spill_to_device_limit(self, device_limit: int = None) -> int: + """Spill until device limit + + Notice, by default this is a no-op. + + Parameters + ---------- + device_limit : int, optional + Limit in bytes. If None, the value of the environment variable + `CUDF_SPILL_DEVICE_LIMIT` is used. If this is not set, the method + does nothing and returns 0. + + Return + ------ + int + The number of bytes spilled. + """ + limit = ( + self._device_memory_limit if device_limit is None else device_limit + ) + if limit is None: + return 0 + ret = 0 + while True: + unspilled = sum( + buf.size for buf in self.base_buffers() if not buf.is_spilled + ) + if unspilled < limit: + break + nbytes = self.spill_device_memory() + if nbytes == 0: + break # No more to spill + ret += nbytes + return ret + + def lookup_address_range( # TODO: remove, only for debugging + self, ptr: int, size: int + ) -> List[SpillableBuffer]: + ret = [] + for buf in self.base_buffers(): + if buf.is_overlapping(ptr, size): + ret.append(buf) + return ret + + def __repr__(self) -> str: + spilled = sum( + buf.size for buf in self.base_buffers() if buf.is_spilled + ) + unspilled = sum( + buf.size for buf in self.base_buffers() if not buf.is_spilled + ) + unspillable = 0 + for buf in self.base_buffers(): + if not (buf.is_spilled or buf.spillable): + unspillable += buf.size + unspillable_ratio = unspillable / unspilled if unspilled else 0 + + return ( + f"" + ) + + +# The global manager has three states: +# - Uninitialized +# - Initialized to None (spilling disabled) +# - Initialized to a SpillManager instance (spilling enabled) +_global_manager_uninitialized: bool = True +_global_manager: Optional[SpillManager] = None + + +def set_global_manager(manager: 
Optional[SpillManager]) -> None: + """Set the global manager, which if None disables spilling""" + + global _global_manager, _global_manager_uninitialized + if _global_manager is not None: + gc.collect() + base_buffers = _global_manager.base_buffers() + if len(base_buffers) > 0: + warnings.warn(f"overwriting non-empty manager: {base_buffers}") + + _global_manager = manager + _global_manager_uninitialized = False + + +def get_global_manager() -> Optional[SpillManager]: + """Get the global manager or None if spilling is disabled""" + global _global_manager_uninitialized + if _global_manager_uninitialized: + manager = None + if get_option("spill"): + manager = SpillManager( + spill_on_demand=get_option("spill_on_demand"), + device_memory_limit=get_option("spill_device_limit"), + ) + set_global_manager(manager) + return _global_manager diff --git a/python/cudf/cudf/core/buffer/spillable_buffer.py b/python/cudf/cudf/core/buffer/spillable_buffer.py new file mode 100644 index 00000000000..50ac6b4e653 --- /dev/null +++ b/python/cudf/cudf/core/buffer/spillable_buffer.py @@ -0,0 +1,473 @@ +# Copyright (c) 2022, NVIDIA CORPORATION. + +from __future__ import annotations + +import collections.abc +import pickle +import time +import weakref +from threading import RLock +from typing import TYPE_CHECKING, Any, Dict, List, Tuple, Type, TypeVar + +import numpy + +import rmm + +from cudf.core.buffer.buffer import Buffer, cuda_array_interface_wrapper +from cudf.utils.string import format_bytes + +if TYPE_CHECKING: + from cudf.core.buffer.spill_manager import SpillManager + + +T = TypeVar("T", bound="SpillableBuffer") + + +class SpillLock: + pass + + +class DelayedPointerTuple(collections.abc.Sequence): + """ + A delayed version of the "data" field in __cuda_array_interface__. + + The idea is to delay the access to `Buffer.ptr` until the user + actually accesses the data pointer. 
+ + For instance, in many cases __cuda_array_interface__ is accessed + only to determine whether an object is a CUDA object or not. + + TODO: this doesn't support libraries such as PyTorch that declare + the tuple of __cuda_array_interface__["data"] in Cython. In such + cases, Cython will raise an error because DelayedPointerTuple + isn't a "real" tuple. + """ + + def __init__(self, buffer) -> None: + self._buf = buffer + + def __len__(self): + return 2 + + def __getitem__(self, i): + if i == 0: + return self._buf.ptr + elif i == 1: + return False + raise IndexError("tuple index out of range") + + +class SpillableBuffer(Buffer): + """A spillable buffer that implements DeviceBufferLike. + + This buffer supports spilling the represented data to host memory. + Spilling can be done manually by calling `.__spill__(target="cpu")` but + usually the associated spilling manager triggers spilling based on current + device memory usage see `cudf.core.buffer.spill_manager.SpillManager`. + Unspill is triggered automatically when accessing the data of the buffer. + + The buffer might not be spillable, which is based on the "expose" status + of the buffer. We say that the buffer has been exposed if the device + pointer (integer or void*) has been accessed outside of SpillableBuffer. + In this case, we cannot invalidate the device pointer by moving the data + to host. + + A buffer can be exposed permanently at creation or by accessing the `.ptr` + property. To avoid this, one can use `.get_ptr()` instead, which support + exposing the buffer temporarily. + + Use the factory function `as_buffer` to create a SpillableBuffer instance. 
+ """ + + _lock: RLock + _spill_locks: weakref.WeakSet + _last_accessed: float + _ptr_desc: Dict[str, Any] + _exposed: bool + _manager: SpillManager + + def _finalize_init(self, ptr_desc: Dict[str, Any], exposed: bool) -> None: + from cudf.core.buffer.spill_manager import get_global_manager + + self._lock = RLock() + self._spill_locks = weakref.WeakSet() + self._last_accessed = time.monotonic() + self._ptr_desc = ptr_desc + self._exposed = exposed + manager = get_global_manager() + if manager is None: + raise ValueError( + f"cannot create {self.__class__} with a global spill manager" + ) + + if self._ptr: + # TODO: run the following asserts in "debug mode" or not at all. + # Assert that any buffers `data` may refer to has been exposed + # already. If this is not the case, it means that somewhere we + # are accessing a buffer's device pointer without marking it as + # exposed, which would be a bug. + bases = manager.lookup_address_range(self._ptr, self._size) + assert all(b.exposed for b in bases) + # Assert that if `data` refers to any existing base buffers, it + # must itself be exposed. + assert len(bases) == 0 or exposed + + self._manager = manager + self._manager.add(self) + + @classmethod + def _from_device_memory( + cls: Type[T], data: Any, *, exposed: bool = False + ) -> T: + """Create a spillabe buffer from device memory. + + No data is being copied. + + Parameters + ---------- + data : device-buffer-like + An object implementing the CUDA Array Interface. + exposed : bool, optional + Mark the buffer as permanently exposed (unspillable). + + Returns + ------- + SpillableBuffer + Buffer representing the same device memory as `data` + """ + ret = super(SpillableBuffer, cls)._from_device_memory(data) + ret._finalize_init(ptr_desc={"type": "gpu"}, exposed=exposed) + return ret + + @classmethod + def _from_host_memory(cls: Type[T], data: Any) -> T: + """Create a spillabe buffer from host memory. 
+ + Data must implement `__array_interface__`, the buffer protocol, and/or + be convertible to a buffer object using `numpy.array()` + + The new buffer is marked as spilled to host memory already. + + Raises ValueError if array isn't C-contiguous. + + Parameters + ---------- + data : Any + An object that represens host memory. + + Returns + ------- + SpillableBuffer + Buffer representing a copy of `data`. + """ + + # Convert to a memoryview using numpy array, this will not copy data + # in most cases. + data = memoryview(numpy.array(data, copy=False, subok=True)) + if not data.c_contiguous: + raise ValueError("Buffer data must be C-contiguous") + + # Create an already spilled buffer + ret = cls.__new__(cls) + ret._owner = None + ret._ptr = 0 + ret._size = data.nbytes + ret._finalize_init( + ptr_desc={"type": "cpu", "memoryview": data}, exposed=False + ) + return ret + + @property + def lock(self) -> RLock: + return self._lock + + @property + def is_spilled(self) -> bool: + return self._ptr_desc["type"] != "gpu" + + def __spill__(self, target: str = "cpu") -> None: + """Spill or un-spill this buffer in-place + + Parameters + ---------- + target : str + The target of the spilling. + """ + + with self._lock: + ptr_type = self._ptr_desc["type"] + if ptr_type == target: + return + + if not self.spillable: + raise ValueError( + f"Cannot in-place move an unspillable buffer: {self}" + ) + + if (ptr_type, target) == ("gpu", "cpu"): + host_mem = memoryview(bytearray(self.size)) + rmm._lib.device_buffer.copy_ptr_to_host(self._ptr, host_mem) + self._ptr_desc["memoryview"] = host_mem + self._ptr = 0 + self._owner = None + elif (ptr_type, target) == ("cpu", "gpu"): + # Notice, this operation is prone to deadlock because the RMM + # allocation might trigger spilling-on-demand which in turn + # trigger a new call to this buffer's `__spill__()`. + # Therefore, it is important that spilling-on-demand doesn't + # tries to unspill an already locked buffer! 
+ dev_mem = rmm.DeviceBuffer.to_device( + self._ptr_desc.pop("memoryview") + ) + self._ptr = dev_mem.ptr + self._owner = dev_mem + assert self._size == dev_mem.size + else: + # TODO: support moving to disk + raise ValueError(f"Unknown target: {target}") + self._ptr_desc["type"] = target + + @property + def ptr(self) -> int: + """Access the memory directly + + Notice, this will mark the buffer as "exposed" and make + it unspillable permanently. + + Consider using `.get_ptr()` instead. + """ + + self._manager.spill_to_device_limit() + with self._lock: + self.__spill__(target="gpu") + self._exposed = True + self._last_accessed = time.monotonic() + return self._ptr + + def spill_lock(self, spill_lock: SpillLock = None) -> SpillLock: + if spill_lock is None: + spill_lock = SpillLock() + with self._lock: + self.__spill__(target="gpu") + self._spill_locks.add(spill_lock) + return spill_lock + + def get_ptr(self, spill_lock: SpillLock = None) -> int: + """Get a device pointer to the memory of the buffer. + + If spill_lock is not None, a reference to this buffer is added + to spill_lock, which disable spilling of this buffer while + spill_lock is alive. + + Parameters + ---------- + spill_lock : SpillLock, optional + Adding a reference of this buffer to the spill lock. 
+ + Return + ------ + int + The device pointer as an integer + """ + + if spill_lock is None: + return self.ptr # expose the buffer permanently + + self.spill_lock(spill_lock) + self._last_accessed = time.monotonic() + return self._ptr + + @property + def owner(self) -> Any: + return self._owner + + @property + def exposed(self) -> bool: + return self._exposed + + @property + def spillable(self) -> bool: + return not self._exposed and len(self._spill_locks) == 0 + + @property + def size(self) -> int: + return self._size + + @property + def nbytes(self) -> int: + return self._size + + @property + def last_accessed(self) -> float: + return self._last_accessed + + @property + def __cuda_array_interface__(self) -> dict: + return { + "data": DelayedPointerTuple(self), + "shape": (self.size,), + "strides": None, + "typestr": "|u1", + "version": 0, + } + + def memoryview(self, *, offset: int = 0, size: int = None) -> memoryview: + size = self._size if size is None else size + with self._lock: + if self.spillable: + self.__spill__(target="cpu") + return self._ptr_desc["memoryview"][offset : offset + size] + else: + assert self._ptr_desc["type"] == "gpu" + ret = memoryview(bytearray(size)) + rmm._lib.device_buffer.copy_ptr_to_host( + self._ptr + offset, ret + ) + return ret + + def _getitem(self, offset: int, size: int) -> Buffer: + return SpillableBufferSlice(base=self, offset=offset, size=size) + + def serialize(self) -> Tuple[dict, list]: + """Serialize the Buffer + + Normally, we would use `[self]` as the frames. This would work but + also mean that `self` becomes exposed permanently if the frames are + later accessed through `__cuda_array_interface__`, which is exactly + what libraries like Dask+UCX would do when communicating! + + The sound solution is to modify Dask et al. so that they access the + frames through `.get_ptr()` and holds on to the `spill_lock` until + the frame has been transferred. 
However, until this adaptation we + use a hack where the frame is a `Buffer` with a `spill_lock` as the + owner, which makes `self` unspillable while the frame is alive but + doesn't expose `self` when `__cuda_array_interface__` is accessed. + + Warning, this hack means that the returned frame must be copied before + given to `.deserialize()`, otherwise we would have a `Buffer` pointing + to memory already owned by an existing `SpillableBuffer`. + """ + header: Dict[Any, Any] + frames: List[Buffer | memoryview] + with self._lock: + header = {} + header["type-serialized"] = pickle.dumps(self.__class__) + header["frame_count"] = 1 + if self.is_spilled: + frames = [self.memoryview()] + else: + # TODO: Use `frames=[self]` instead of this hack, see doc above + spill_lock = SpillLock() + ptr = self.get_ptr(spill_lock=spill_lock) + frames = [ + Buffer._from_device_memory( + cuda_array_interface_wrapper( + ptr=ptr, + size=self.size, + owner=(self._owner, spill_lock), + ) + ) + ] + return header, frames + + def is_overlapping(self, ptr: int, size: int): + with self._lock: + return ( + not self.is_spilled + and (ptr + size) > self._ptr + and (self._ptr + self._size) > ptr + ) + + def __repr__(self) -> str: + if self._ptr_desc["type"] != "gpu": + ptr_info = str(self._ptr_desc) + else: + ptr_info = str(hex(self._ptr)) + return ( + f"" + ) + + +class SpillableBufferSlice(SpillableBuffer): + """A slice of a spillable buffer + + This buffer applies the slicing and then delegates all + operations to its base buffer. 
+ + Parameters + ---------- + base : SpillableBuffer + The base of the view + offset : int + Memory offset into the base buffer + size : int + Size of the view (in bytes) + """ + + def __init__(self, base: SpillableBuffer, offset: int, size: int) -> None: + if size < 0: + raise ValueError("size cannot be negative") + if offset < 0: + raise ValueError("offset cannot be negative") + if offset + size > base.size: + raise ValueError( + "offset+size cannot be greater than the size of base" + ) + self._base = base + self._offset = offset + self._size = size + self._owner = base + self._lock = base.lock + + @property + def ptr(self) -> int: + return self._base.ptr + self._offset + + def get_ptr(self, spill_lock: SpillLock = None) -> int: + return self._base.get_ptr(spill_lock=spill_lock) + self._offset + + def _getitem(self, offset: int, size: int) -> Buffer: + return SpillableBufferSlice( + base=self._base, offset=offset + self._offset, size=size + ) + + @classmethod + def deserialize(cls, header: dict, frames: list): + # TODO: because of the hack in `SpillableBuffer.serialize()` where + # frames are of type `Buffer`, we always deserialize as if they are + # `SpillableBufferbuffer`. In the future, we should be able to + # deserialize into `SpillableBufferSlice` when the frames hasn't been + # copied. 
+ return SpillableBuffer.deserialize(header, frames) + + def memoryview(self, *, offset: int = 0, size: int = None) -> memoryview: + size = self._size if size is None else size + return self._base.memoryview(offset=self._offset + offset, size=size) + + def __repr__(self) -> str: + return ( + f" None: + return self._base.__spill__(target=target) + + @property + def is_spilled(self) -> bool: + return self._base.is_spilled + + @property + def exposed(self) -> bool: + return self._base.exposed + + @property + def spillable(self) -> bool: + return self._base.spillable + + def spill_lock(self, spill_lock: SpillLock = None) -> SpillLock: + return self._base.spill_lock(spill_lock=spill_lock) diff --git a/python/cudf/cudf/core/buffer/utils.py b/python/cudf/cudf/core/buffer/utils.py index 5e017c4bc92..3da1d610ca1 100644 --- a/python/cudf/cudf/core/buffer/utils.py +++ b/python/cudf/cudf/core/buffer/utils.py @@ -2,9 +2,13 @@ from __future__ import annotations -from typing import Any, Union +import threading +from contextlib import ContextDecorator +from typing import Any, Dict, Optional, Tuple, Union from cudf.core.buffer.buffer import Buffer, cuda_array_interface_wrapper +from cudf.core.buffer.spill_manager import get_global_manager +from cudf.core.buffer.spillable_buffer import SpillableBuffer, SpillLock def as_buffer( @@ -12,6 +16,7 @@ def as_buffer( *, size: int = None, owner: object = None, + exposed: bool = False, ) -> Buffer: """Factory function to wrap `data` in a Buffer object. @@ -37,6 +42,10 @@ def as_buffer( owner : object, optional Python object to which the lifetime of the memory allocation is tied. A reference to this object is kept in the returned Buffer. + exposed : bool, optional + Mark the buffer as permanently exposed (unspillable). This is ignored + unless spilling is enabled and the data represents device memory, see + SpillableBuffer. 
Return ------ @@ -62,6 +71,60 @@ def as_buffer( "`data` is a buffer-like or array-like object" ) + if get_global_manager() is not None: + if hasattr(data, "__cuda_array_interface__"): + return SpillableBuffer._from_device_memory(data, exposed=exposed) + if exposed: + raise ValueError("cannot created exposed host memory") + return SpillableBuffer._from_host_memory(data) + if hasattr(data, "__cuda_array_interface__"): return Buffer._from_device_memory(data) return Buffer._from_host_memory(data) + + +_thread_spill_locks: Dict[int, Tuple[Optional[SpillLock], int]] = {} + + +def _push_thread_spill_lock() -> None: + _id = threading.get_ident() + spill_lock, count = _thread_spill_locks.get(_id, (None, 0)) + if spill_lock is None: + spill_lock = SpillLock() + _thread_spill_locks[_id] = (spill_lock, count + 1) + + +def _pop_thread_spill_lock() -> None: + _id = threading.get_ident() + spill_lock, count = _thread_spill_locks[_id] + if count == 1: + spill_lock = None + _thread_spill_locks[_id] = (spill_lock, count - 1) + + +class with_spill_lock(ContextDecorator): + """Decorator and context to set spill lock automatically. + + All calls to `get_spill_lock()` within the decorated function or context + will return a spill lock with a lifetime bound to the function or context. + """ + + def __enter__(self) -> Optional[SpillLock]: + _push_thread_spill_lock() + return get_spill_lock() + + def __exit__(self, *exc): + _pop_thread_spill_lock() + + +def get_spill_lock() -> Union[SpillLock, None]: + """Return a spill lock within the context of `with_spill_lock` or None + + Returns None, if spilling is disabled. 
+ """ + + if get_global_manager() is None: + return None + _id = threading.get_ident() + spill_lock, _ = _thread_spill_locks.get(_id, (None, 0)) + return spill_lock diff --git a/python/cudf/cudf/core/column/column.py b/python/cudf/cudf/core/column/column.py index 6c17b492f8a..d16df7ea1c0 100644 --- a/python/cudf/cudf/core/column/column.py +++ b/python/cudf/cudf/core/column/column.py @@ -1764,7 +1764,7 @@ def as_column( ): arbitrary = cupy.ascontiguousarray(arbitrary) - data = as_buffer(arbitrary) + data = as_buffer(arbitrary, exposed=True) col = build_column(data, dtype=current_dtype, mask=mask) if dtype is not None: @@ -2221,7 +2221,7 @@ def _mask_from_cuda_array_interface_desc(obj) -> Union[Buffer, None]: typecode = typestr[1] if typecode == "t": mask_size = bitmask_allocation_size_bytes(nelem) - mask = as_buffer(data=ptr, size=mask_size, owner=obj) + mask = as_buffer(data=ptr, size=mask_size, owner=obj, exposed=True) elif typecode == "b": col = as_column(mask) mask = bools_to_mask(col) diff --git a/python/cudf/cudf/core/column/decimal.py b/python/cudf/cudf/core/column/decimal.py index 5ee9024a0d8..77ca3f9688b 100644 --- a/python/cudf/cudf/core/column/decimal.py +++ b/python/cudf/cudf/core/column/decimal.py @@ -203,7 +203,7 @@ def from_arrow(cls, data: pa.Array): data_128 = cp.array(np.frombuffer(data.buffers()[1]).view("int32")) data_32 = data_128[::4].copy() return cls( - data=as_buffer(data_32.view("uint8")), + data=as_buffer(data_32.view("uint8"), exposed=True), size=len(data), dtype=dtype, offset=data.offset, @@ -290,7 +290,7 @@ def from_arrow(cls, data: pa.Array): data_128 = cp.array(np.frombuffer(data.buffers()[1]).view("int64")) data_64 = data_128[::2].copy() return cls( - data=as_buffer(data_64.view("uint8")), + data=as_buffer(data_64.view("uint8"), exposed=True), size=len(data), dtype=dtype, offset=data.offset, diff --git a/python/cudf/cudf/core/df_protocol.py b/python/cudf/cudf/core/df_protocol.py index b29fc41e5b4..b38d3048ed7 100644 --- 
a/python/cudf/cudf/core/df_protocol.py +++ b/python/cudf/cudf/core/df_protocol.py @@ -721,7 +721,9 @@ def _protocol_to_cudf_column_numeric( _dbuffer, _ddtype = buffers["data"] _check_buffer_is_on_gpu(_dbuffer) cudfcol_num = build_column( - as_buffer(data=_dbuffer.ptr, size=_dbuffer.bufsize, owner=None), + as_buffer( + data=_dbuffer.ptr, size=_dbuffer.bufsize, owner=None, exposed=True + ), protocol_dtype_to_cupy_dtype(_ddtype), ) return _set_missing_values(col, cudfcol_num), buffers @@ -751,7 +753,11 @@ def _set_missing_values( valid_mask = protocol_col.get_buffers()["validity"] if valid_mask is not None: bitmask = cp.asarray( - as_buffer(data=valid_mask[0].ptr, size=valid_mask[0].bufsize), + as_buffer( + data=valid_mask[0].ptr, + size=valid_mask[0].bufsize, + exposed=True, + ), cp.bool8, ) cudf_col[~bitmask] = None @@ -790,7 +796,9 @@ def _protocol_to_cudf_column_categorical( _check_buffer_is_on_gpu(codes_buffer) cdtype = protocol_dtype_to_cupy_dtype(codes_dtype) codes = build_column( - as_buffer(data=codes_buffer.ptr, size=codes_buffer.bufsize), + as_buffer( + data=codes_buffer.ptr, size=codes_buffer.bufsize, exposed=True + ), cdtype, ) @@ -822,7 +830,9 @@ def _protocol_to_cudf_column_string( data_buffer, data_dtype = buffers["data"] _check_buffer_is_on_gpu(data_buffer) encoded_string = build_column( - as_buffer(data=data_buffer.ptr, size=data_buffer.bufsize), + as_buffer( + data=data_buffer.ptr, size=data_buffer.bufsize, exposed=True + ), protocol_dtype_to_cupy_dtype(data_dtype), ) @@ -832,7 +842,9 @@ def _protocol_to_cudf_column_string( offset_buffer, offset_dtype = buffers["offsets"] _check_buffer_is_on_gpu(offset_buffer) offsets = build_column( - as_buffer(data=offset_buffer.ptr, size=offset_buffer.bufsize), + as_buffer( + data=offset_buffer.ptr, size=offset_buffer.bufsize, exposed=True + ), protocol_dtype_to_cupy_dtype(offset_dtype), ) diff --git a/python/cudf/cudf/core/groupby/groupby.py b/python/cudf/cudf/core/groupby/groupby.py index 
e4ea59c1f15..371c0566166 100644 --- a/python/cudf/cudf/core/groupby/groupby.py +++ b/python/cudf/cudf/core/groupby/groupby.py @@ -259,6 +259,10 @@ def __init__( else: self.grouping = _Grouping(obj, by, level) + self._groupby = libgroupby.GroupBy( + [*self.grouping.keys._columns], dropna=self._dropna + ) + def __iter__(self): group_names, offsets, _, grouped_values = self._grouped() if isinstance(group_names, cudf.BaseIndex): diff --git a/python/cudf/cudf/options.py b/python/cudf/cudf/options.py index 7f6a6f10e25..4a0a0437e00 100644 --- a/python/cudf/cudf/options.py +++ b/python/cudf/cudf/options.py @@ -1,5 +1,6 @@ # Copyright (c) 2022, NVIDIA CORPORATION. +import os import textwrap from collections.abc import Container from dataclasses import dataclass @@ -17,6 +18,26 @@ class Option: _OPTIONS: Dict[str, Option] = {} +def _env_get_int(name, default): + try: + return int(os.getenv(name, default)) + except (ValueError, TypeError): + return default + + +def _env_get_bool(name, default): + env = os.getenv(name) + if env is None: + return default + as_a_int = _env_get_int(name, None) + env = env.lower().strip() + if env == "true" or env == "on" or as_a_int: + return True + if env == "false" or env == "off" or as_a_int == 0: + return False + return default + + def _register_option( name: str, default_value: Any, description: str, validator: Callable ): @@ -129,6 +150,16 @@ def _validator(val): return _validator +def _integer_and_none_validator(val): + try: + if val is None or int(val): + return + except ValueError: + raise ValueError( + f"{val} is not a valid option. " f"Must be an integer or None." + ) + + _register_option( "default_integer_bitwidth", None, @@ -163,3 +194,43 @@ def _validator(val): ), _make_contains_validator([None, 32, 64]), ) + + +_register_option( + "spill", + _env_get_bool("CUDF_SPILL", False), + textwrap.dedent( + """ + Enables spilling. + \tValid values are True or False. Default is False. 
+ """ + ), + _make_contains_validator([False, True]), +) + +_register_option( + "spill_on_demand", + _env_get_bool("CUDF_SPILL_ON_DEMAND", True), + textwrap.dedent( + """ + Enables spilling on demand using an RMM out-of-memory error handler. + This has no effect if spilling is disabled, see the "spill" option. + \tValid values are True or False. Default is True. + """ + ), + _make_contains_validator([False, True]), +) + +_register_option( + "spill_device_limit", + _env_get_int("CUDF_SPILL_DEVICE_LIMIT", None), + textwrap.dedent( + """ + Enforce a device memory limit in bytes. + This has no effect if spilling is disabled, see the "spill" option. + \tValid values are any positive integer or None (disabled). + \tDefault is None. + """ + ), + _integer_and_none_validator, +) diff --git a/python/cudf/cudf/tests/conftest.py b/python/cudf/cudf/tests/conftest.py index 258b628305d..bf565ed9a47 100644 --- a/python/cudf/cudf/tests/conftest.py +++ b/python/cudf/cudf/tests/conftest.py @@ -158,3 +158,17 @@ def default_float_bitwidth(request): cudf.set_option("default_float_bitwidth", request.param) yield request.param cudf.set_option("default_float_bitwidth", old_default) + + +@pytest.hookimpl(tryfirst=True, hookwrapper=True) +def pytest_runtest_makereport(item, call): + """Hook to make result information available in fixtures + + See + """ + outcome = yield + rep = outcome.get_result() + + # Set a report attribute for each phase of a call, which can + # be "setup", "call", "teardown" + setattr(item, "report", {rep.when: rep}) diff --git a/python/cudf/cudf/tests/test_buffer.py b/python/cudf/cudf/tests/test_buffer.py index 5ed5750f29b..6ff715db761 100644 --- a/python/cudf/cudf/tests/test_buffer.py +++ b/python/cudf/cudf/tests/test_buffer.py @@ -48,15 +48,21 @@ def test_buffer_from_cuda_iface_dtype(data, dtype): def test_buffer_creation_from_any(): ary = cp.arange(arr_len) - b = as_buffer(ary) + b = as_buffer(ary, exposed=True) assert isinstance(b, Buffer) - assert 
ary.__cuda_array_interface__["data"][0] == b.ptr + assert ary.data.ptr == b.ptr assert ary.nbytes == b.size with pytest.raises( ValueError, match="size must be specified when `data` is an integer" ): - as_buffer(42) + as_buffer(ary.data.ptr) + + b = as_buffer(ary.data.ptr, size=ary.nbytes, owner=ary, exposed=True) + assert isinstance(b, Buffer) + assert ary.data.ptr == b.ptr + assert ary.nbytes == b.size + assert b.owner.owner is ary @pytest.mark.parametrize( diff --git a/python/cudf/cudf/tests/test_groupby.py b/python/cudf/cudf/tests/test_groupby.py index b00e31115c9..3898db1c9fa 100644 --- a/python/cudf/cudf/tests/test_groupby.py +++ b/python/cudf/cudf/tests/test_groupby.py @@ -1456,7 +1456,11 @@ def test_groupby_attribute_error(): class TestGroupBy(cudf.core.groupby.GroupBy): @property def _groupby(self): - raise AttributeError("Test error message") + raise AttributeError(err_msg) + + @_groupby.setter + def _groupby(self, _): + pass a = cudf.DataFrame({"a": [1, 2], "b": [2, 3]}) gb = TestGroupBy(a, a["a"]) diff --git a/python/cudf/cudf/tests/test_spilling.py b/python/cudf/cudf/tests/test_spilling.py new file mode 100644 index 00000000000..b38d90f3178 --- /dev/null +++ b/python/cudf/cudf/tests/test_spilling.py @@ -0,0 +1,464 @@ +# Copyright (c) 2022, NVIDIA CORPORATION. 
+ +import importlib +import random +import time +import warnings +from concurrent.futures import ThreadPoolExecutor +from typing import Tuple + +import cupy +import numpy as np +import pandas +import pandas.testing +import pytest + +import rmm + +import cudf +import cudf.core.buffer.spill_manager +import cudf.options +from cudf.core.abc import Serializable +from cudf.core.buffer import Buffer, as_buffer, get_spill_lock, with_spill_lock +from cudf.core.buffer.spill_manager import ( + SpillManager, + get_global_manager, + get_rmm_memory_resource_stack, + set_global_manager, +) +from cudf.core.buffer.spillable_buffer import SpillableBuffer, SpillLock +from cudf.testing._utils import assert_eq + + +def gen_df(target="gpu") -> cudf.DataFrame: + ret = cudf.DataFrame({"a": [1, 2, 3]}) + if target != "gpu": + gen_df.buffer(ret).__spill__(target=target) + return ret + + +gen_df.buffer = lambda df: df._data._data["a"].data +gen_df.is_spilled = lambda df: gen_df.buffer(df).is_spilled +gen_df.is_spillable = lambda df: gen_df.buffer(df).spillable +gen_df.buffer_size = gen_df.buffer(gen_df()).size + + +def spilled_and_unspilled(manager: SpillManager) -> Tuple[int, int]: + """Get bytes spilled and unspilled known by the manager""" + spilled = sum(buf.size for buf in manager.base_buffers() if buf.is_spilled) + unspilled = sum( + buf.size for buf in manager.base_buffers() if not buf.is_spilled + ) + return spilled, unspilled + + +@pytest.fixture +def manager(request): + """Fixture to enable and make a spilling manager availabe""" + kwargs = dict(getattr(request, "param", {})) + with warnings.catch_warnings(): + warnings.simplefilter("error") + set_global_manager(manager=SpillManager(**kwargs)) + yield get_global_manager() + # Retrieving the test result using the `pytest_runtest_makereport` + # hook from conftest.py + if request.node.report["call"].failed: + # Ignore `overwriting non-empty manager` errors when + # test is failing. 
+ warnings.simplefilter("ignore") + set_global_manager(manager=None) + + +def test_spillable_buffer(manager: SpillManager): + buf = as_buffer(data=rmm.DeviceBuffer(size=10), exposed=False) + assert isinstance(buf, SpillableBuffer) + assert buf.spillable + buf.ptr # Expose pointer + assert buf.exposed + assert not buf.spillable + buf = as_buffer(data=rmm.DeviceBuffer(size=10), exposed=False) + # Notice, accessing `__cuda_array_interface__` itself doesn't + # expose the pointer, only accessing the "data" field exposes + # the pointer. + iface = buf.__cuda_array_interface__ + assert not buf.exposed + assert buf.spillable + iface["data"][0] # Expose pointer + assert buf.exposed + assert not buf.spillable + + +@pytest.mark.parametrize( + "attribute", + [ + "ptr", + "get_ptr", + "memoryview", + "is_spilled", + "exposed", + "spillable", + "spill_lock", + "__spill__", + ], +) +def test_spillable_buffer_view_attributes(manager: SpillManager, attribute): + base = as_buffer(data=rmm.DeviceBuffer(size=10), exposed=False) + view = base[:] + attr_base = getattr(base, attribute) + attr_view = getattr(view, attribute) + if callable(attr_view): + pass + else: + assert attr_base == attr_view + + +def test_from_pandas(manager: SpillManager): + pdf1 = pandas.DataFrame({"x": [1, 2, 3]}) + df = cudf.from_pandas(pdf1) + assert df._data._data["x"].data.spillable + pdf2 = df.to_pandas() + pandas.testing.assert_frame_equal(pdf1, pdf2) + + +def test_creations(manager: SpillManager): + df = cudf.datasets.timeseries() + assert isinstance(df._data._data["x"].data, SpillableBuffer) + assert df._data._data["x"].data.spillable + df = cudf.DataFrame({"x": [1, 2, 3]}) + assert df._data._data["x"].data.spillable + df = cudf.datasets.randomdata(10) + assert df._data._data["x"].data.spillable + + +def test_spillable_df_groupby(manager: SpillManager): + df = cudf.DataFrame({"x": [1, 1, 1]}) + gb = df.groupby("x") + # `gb` holds a reference to the device memory, which makes + # the buffer unspillable + 
assert len(df._data._data["x"].data._spill_locks) == 1 + assert not df._data._data["x"].data.spillable + del gb + assert df._data._data["x"].data.spillable + + +def test_spilling_buffer(manager: SpillManager): + buf = as_buffer(rmm.DeviceBuffer(size=10), exposed=False) + buf.__spill__(target="cpu") + assert buf.is_spilled + buf.ptr # Expose pointer and trigger unspill + assert not buf.is_spilled + with pytest.raises(ValueError, match="unspillable buffer"): + buf.__spill__(target="cpu") + + +def test_environment_variables(monkeypatch): + def reload_options(): + # In order to enabling monkey patching of the environment variables + # mark the global manager as uninitialized. + set_global_manager(None) + cudf.core.buffer.spill_manager._global_manager_uninitialized = True + importlib.reload(cudf.options) + + monkeypatch.setenv("CUDF_SPILL_ON_DEMAND", "off") + monkeypatch.setenv("CUDF_SPILL", "off") + reload_options() + assert get_global_manager() is None + + monkeypatch.setenv("CUDF_SPILL", "on") + reload_options() + manager = get_global_manager() + assert isinstance(manager, SpillManager) + assert manager._spill_on_demand is False + assert manager._device_memory_limit is None + + monkeypatch.setenv("CUDF_SPILL_DEVICE_LIMIT", "1000") + reload_options() + manager = get_global_manager() + assert isinstance(manager, SpillManager) + assert manager._device_memory_limit == 1000 + + +def test_spill_device_memory(manager: SpillManager): + df = gen_df() + assert spilled_and_unspilled(manager) == (0, gen_df.buffer_size) + manager.spill_device_memory() + assert spilled_and_unspilled(manager) == (gen_df.buffer_size, 0) + del df + assert spilled_and_unspilled(manager) == (0, 0) + df1 = gen_df() + df2 = gen_df() + manager.spill_device_memory() + assert gen_df.is_spilled(df1) + assert not gen_df.is_spilled(df2) + manager.spill_device_memory() + assert gen_df.is_spilled(df1) + assert gen_df.is_spilled(df2) + df3 = df1 + df2 + assert not gen_df.is_spilled(df1) + assert not 
gen_df.is_spilled(df2) + assert not gen_df.is_spilled(df3) + manager.spill_device_memory() + assert gen_df.is_spilled(df1) + assert not gen_df.is_spilled(df2) + assert not gen_df.is_spilled(df3) + df2.abs() # Should change the access time + manager.spill_device_memory() + assert gen_df.is_spilled(df1) + assert not gen_df.is_spilled(df2) + assert gen_df.is_spilled(df3) + + +def test_spill_to_device_limit(manager: SpillManager): + df1 = gen_df() + df2 = gen_df() + assert spilled_and_unspilled(manager) == (0, gen_df.buffer_size * 2) + manager.spill_to_device_limit(device_limit=0) + assert spilled_and_unspilled(manager) == (gen_df.buffer_size * 2, 0) + df3 = df1 + df2 + manager.spill_to_device_limit(device_limit=0) + assert spilled_and_unspilled(manager) == (gen_df.buffer_size * 3, 0) + assert gen_df.is_spilled(df1) + assert gen_df.is_spilled(df2) + assert gen_df.is_spilled(df3) + + +@pytest.mark.parametrize( + "manager", [{"device_memory_limit": 0}], indirect=True +) +def test_zero_device_limit(manager: SpillManager): + assert manager._device_memory_limit == 0 + df1 = gen_df() + df2 = gen_df() + assert spilled_and_unspilled(manager) == (gen_df.buffer_size * 2, 0) + df1 + df2 + # Notice, while performing the addintion both df1 and df2 are unspillable + assert spilled_and_unspilled(manager) == (0, gen_df.buffer_size * 2) + manager.spill_to_device_limit() + assert spilled_and_unspilled(manager) == (gen_df.buffer_size * 2, 0) + + +def test_lookup_address_range(manager: SpillManager): + df = gen_df() + buf = gen_df.buffer(df) + buffers = manager.base_buffers() + assert len(buffers) == 1 + (buf,) = buffers + assert gen_df.buffer(df) is buf + assert manager.lookup_address_range(buf.ptr, buf.size)[0] is buf + assert manager.lookup_address_range(buf.ptr + 1, buf.size - 1)[0] is buf + assert manager.lookup_address_range(buf.ptr + 1, buf.size + 1)[0] is buf + assert manager.lookup_address_range(buf.ptr - 1, buf.size - 1)[0] is buf + assert manager.lookup_address_range(buf.ptr - 
1, buf.size + 1)[0] is buf + assert not manager.lookup_address_range(buf.ptr + buf.size, buf.size) + assert not manager.lookup_address_range(buf.ptr - buf.size, buf.size) + + +def test_external_memory_never_spills(manager): + """ + Test that external data, i.e., data not managed by RMM, + is never spilled + """ + + cupy.cuda.set_allocator() # uses default allocator + + a = cupy.asarray([1, 2, 3]) + s = cudf.Series(a) + assert len(manager.base_buffers()) == 0 + assert not s._data[None].data.spillable + + +def test_spilling_df_views(manager): + df = gen_df(target="cpu") + assert gen_df.is_spilled(df) + df_view = df.loc[1:] + assert gen_df.is_spillable(df_view) + assert gen_df.is_spillable(df) + + +def test_modify_spilled_views(manager): + df = gen_df() + df_view = df.iloc[1:] + buf = gen_df.buffer(df) + buf.__spill__(target="cpu") + + # modify the spilled df and check that the changes are reflected + # in the view + df.iloc[1:] = 0 + assert_eq(df_view, df.iloc[1:]) + + # now, modify the view and check that the changes are reflected in + # the df + df_view.iloc[:] = -1 + assert_eq(df_view, df.iloc[1:]) + + +def test_ptr_restricted(manager: SpillManager): + buf = as_buffer(data=rmm.DeviceBuffer(size=10), exposed=False) + assert buf.spillable + assert len(buf._spill_locks) == 0 + slock1 = SpillLock() + buf.get_ptr(spill_lock=slock1) + assert not buf.spillable + assert len(buf._spill_locks) == 1 + slock2 = buf.spill_lock() + buf.get_ptr(spill_lock=slock2) + assert not buf.spillable + assert len(buf._spill_locks) == 2 + del slock1 + assert len(buf._spill_locks) == 1 + del slock2 + assert len(buf._spill_locks) == 0 + assert buf.spillable + + +def test_get_spill_lock(manager: SpillManager): + @with_spill_lock() + def f(sleep=False, nest=0): + if sleep: + time.sleep(random.random() / 100) + if nest: + return f(nest=nest - 1) + return get_spill_lock() + + assert get_spill_lock() is None + slock = f() + assert isinstance(slock, SpillLock) + assert get_spill_lock() is None + 
slock = f(nest=2) + assert isinstance(slock, SpillLock) + assert get_spill_lock() is None + + with ThreadPoolExecutor(max_workers=2) as executor: + futures_with_spill_lock = [] + futures_without_spill_lock = [] + for _ in range(100): + futures_with_spill_lock.append( + executor.submit(f, sleep=True, nest=1) + ) + futures_without_spill_lock.append( + executor.submit(f, sleep=True, nest=1) + ) + all(isinstance(f.result(), SpillLock) for f in futures_with_spill_lock) + all(f is None for f in futures_without_spill_lock) + + +def test_get_spill_lock_no_manager(): + """When spilling is disabled, get_spill_lock() should return None always""" + + @with_spill_lock() + def f(): + return get_spill_lock() + + assert get_spill_lock() is None + assert f() is None + + +@pytest.mark.parametrize("target", ["gpu", "cpu"]) +@pytest.mark.parametrize("view", [None, slice(0, 2), slice(1, 3)]) +def test_serialize_device(manager, target, view): + df1 = gen_df() + if view is not None: + df1 = df1.iloc[view] + gen_df.buffer(df1).__spill__(target=target) + + header, frames = df1.device_serialize() + assert len(frames) == 1 + if target == "gpu": + assert isinstance(frames[0], Buffer) + assert not gen_df.is_spilled(df1) + assert not gen_df.is_spillable(df1) + frames[0] = cupy.array(frames[0], copy=True) + else: + assert isinstance(frames[0], memoryview) + assert gen_df.is_spilled(df1) + assert gen_df.is_spillable(df1) + + df2 = Serializable.device_deserialize(header, frames) + assert_eq(df1, df2) + + +@pytest.mark.parametrize("target", ["gpu", "cpu"]) +@pytest.mark.parametrize("view", [None, slice(0, 2), slice(1, 3)]) +def test_serialize_host(manager, target, view): + df1 = gen_df() + if view is not None: + df1 = df1.iloc[view] + gen_df.buffer(df1).__spill__(target=target) + + # Unspilled df becomes spilled after host serialization + header, frames = df1.host_serialize() + assert all(isinstance(f, memoryview) for f in frames) + df2 = Serializable.host_deserialize(header, frames) + assert 
gen_df.is_spilled(df2) + assert_eq(df1, df2) + + +def test_serialize_dask_dataframe(manager: SpillManager): + protocol = pytest.importorskip("distributed.protocol") + + df1 = gen_df(target="gpu") + header, frames = protocol.serialize( + df1, serializers=("dask",), on_error="raise" + ) + buf: SpillableBuffer = gen_df.buffer(df1) + assert len(frames) == 1 + assert isinstance(frames[0], memoryview) + # Check that the memoryview and frames is the same memory + assert ( + np.array(buf.memoryview()).__array_interface__["data"] + == np.array(frames[0]).__array_interface__["data"] + ) + + df2 = protocol.deserialize(header, frames) + assert gen_df.is_spilled(df2) + assert_eq(df1, df2) + + +def test_serialize_cuda_dataframe(manager: SpillManager): + protocol = pytest.importorskip("distributed.protocol") + + df1 = gen_df(target="gpu") + header, frames = protocol.serialize( + df1, serializers=("cuda",), on_error="raise" + ) + buf: SpillableBuffer = gen_df.buffer(df1) + assert len(buf._spill_locks) == 1 + assert len(frames) == 1 + assert isinstance(frames[0], Buffer) + assert frames[0].ptr == buf.ptr + + frames[0] = cupy.array(frames[0], copy=True) + df2 = protocol.deserialize(header, frames) + assert_eq(df1, df2) + + +def test_get_rmm_memory_resource_stack(): + mr1 = rmm.mr.get_current_device_resource() + assert all( + not isinstance(m, rmm.mr.FailureCallbackResourceAdaptor) + for m in get_rmm_memory_resource_stack(mr1) + ) + + mr2 = rmm.mr.FailureCallbackResourceAdaptor(mr1, lambda x: False) + assert get_rmm_memory_resource_stack(mr2)[0] is mr2 + assert get_rmm_memory_resource_stack(mr2)[1] is mr1 + + mr3 = rmm.mr.FixedSizeMemoryResource(mr2) + assert get_rmm_memory_resource_stack(mr3)[0] is mr3 + assert get_rmm_memory_resource_stack(mr3)[1] is mr2 + assert get_rmm_memory_resource_stack(mr3)[2] is mr1 + + mr4 = rmm.mr.FailureCallbackResourceAdaptor(mr3, lambda x: False) + assert get_rmm_memory_resource_stack(mr4)[0] is mr4 + assert get_rmm_memory_resource_stack(mr4)[1] is mr3 
+ assert get_rmm_memory_resource_stack(mr4)[2] is mr2 + assert get_rmm_memory_resource_stack(mr4)[3] is mr1 + + +def test_df_transpose(manager: SpillManager): + df1 = cudf.DataFrame({"x": [1, 2]}) + df2 = df1.transpose() + # For now, all buffers are marked as exposed + assert df1._data._data["x"].data.exposed + assert df2._data._data[0].data.exposed + assert df2._data._data[1].data.exposed diff --git a/python/cudf/cudf/utils/utils.py b/python/cudf/cudf/utils/utils.py index c5f4629483a..65a86484207 100644 --- a/python/cudf/cudf/utils/utils.py +++ b/python/cudf/cudf/utils/utils.py @@ -294,7 +294,7 @@ def pa_mask_buffer_to_mask(mask_buf, size): dbuf = rmm.DeviceBuffer(size=mask_size) dbuf.copy_from_host(np.asarray(mask_buf).view("u1")) return as_buffer(dbuf) - return as_buffer(mask_buf) + return as_buffer(mask_buf, exposed=True) def _isnat(val): diff --git a/python/strings_udf/strings_udf/_lib/cudf_jit_udf.pyx b/python/strings_udf/strings_udf/_lib/cudf_jit_udf.pyx index 4fc9e473fa3..bf459f22c16 100644 --- a/python/strings_udf/strings_udf/_lib/cudf_jit_udf.pyx +++ b/python/strings_udf/strings_udf/_lib/cudf_jit_udf.pyx @@ -24,7 +24,7 @@ def column_to_string_view_array(Column strings_col): c_buffer = move(cpp_to_string_view_array(input_view)) device_buffer = DeviceBuffer.c_from_unique_ptr(move(c_buffer)) - return as_buffer(device_buffer) + return as_buffer(device_buffer, exposed=True) def column_from_udf_string_array(DeviceBuffer d_buffer): From b34ac14b9c33bcdd1909830d7ef4ea1ce617df09 Mon Sep 17 00:00:00 2001 From: "Mads R. B. 
Kristensen" Date: Wed, 9 Nov 2022 18:44:14 +0100 Subject: [PATCH 02/31] remove debug code --- python/cudf/cudf/core/buffer/spill_manager.py | 9 --------- .../cudf/cudf/core/buffer/spillable_buffer.py | 20 ------------------- python/cudf/cudf/tests/test_spilling.py | 16 --------------- 3 files changed, 45 deletions(-) diff --git a/python/cudf/cudf/core/buffer/spill_manager.py b/python/cudf/cudf/core/buffer/spill_manager.py index 821cae54128..020167c97fa 100644 --- a/python/cudf/cudf/core/buffer/spill_manager.py +++ b/python/cudf/cudf/core/buffer/spill_manager.py @@ -239,15 +239,6 @@ def spill_to_device_limit(self, device_limit: int = None) -> int: ret += nbytes return ret - def lookup_address_range( # TODO: remove, only for debugging - self, ptr: int, size: int - ) -> List[SpillableBuffer]: - ret = [] - for buf in self.base_buffers(): - if buf.is_overlapping(ptr, size): - ret.append(buf) - return ret - def __repr__(self) -> str: spilled = sum( buf.size for buf in self.base_buffers() if buf.is_spilled diff --git a/python/cudf/cudf/core/buffer/spillable_buffer.py b/python/cudf/cudf/core/buffer/spillable_buffer.py index 50ac6b4e653..ae77ad7375b 100644 --- a/python/cudf/cudf/core/buffer/spillable_buffer.py +++ b/python/cudf/cudf/core/buffer/spillable_buffer.py @@ -100,18 +100,6 @@ def _finalize_init(self, ptr_desc: Dict[str, Any], exposed: bool) -> None: f"cannot create {self.__class__} with a global spill manager" ) - if self._ptr: - # TODO: run the following asserts in "debug mode" or not at all. - # Assert that any buffers `data` may refer to has been exposed - # already. If this is not the case, it means that somewhere we - # are accessing a buffer's device pointer without marking it as - # exposed, which would be a bug. - bases = manager.lookup_address_range(self._ptr, self._size) - assert all(b.exposed for b in bases) - # Assert that if `data` refers to any existing base buffers, it - # must itself be exposed. 
- assert len(bases) == 0 or exposed - self._manager = manager self._manager.add(self) @@ -370,14 +358,6 @@ def serialize(self) -> Tuple[dict, list]: ] return header, frames - def is_overlapping(self, ptr: int, size: int): - with self._lock: - return ( - not self.is_spilled - and (ptr + size) > self._ptr - and (self._ptr + self._size) > ptr - ) - def __repr__(self) -> str: if self._ptr_desc["type"] != "gpu": ptr_info = str(self._ptr_desc) diff --git a/python/cudf/cudf/tests/test_spilling.py b/python/cudf/cudf/tests/test_spilling.py index b38d90f3178..591560f1983 100644 --- a/python/cudf/cudf/tests/test_spilling.py +++ b/python/cudf/cudf/tests/test_spilling.py @@ -237,22 +237,6 @@ def test_zero_device_limit(manager: SpillManager): assert spilled_and_unspilled(manager) == (gen_df.buffer_size * 2, 0) -def test_lookup_address_range(manager: SpillManager): - df = gen_df() - buf = gen_df.buffer(df) - buffers = manager.base_buffers() - assert len(buffers) == 1 - (buf,) = buffers - assert gen_df.buffer(df) is buf - assert manager.lookup_address_range(buf.ptr, buf.size)[0] is buf - assert manager.lookup_address_range(buf.ptr + 1, buf.size - 1)[0] is buf - assert manager.lookup_address_range(buf.ptr + 1, buf.size + 1)[0] is buf - assert manager.lookup_address_range(buf.ptr - 1, buf.size - 1)[0] is buf - assert manager.lookup_address_range(buf.ptr - 1, buf.size + 1)[0] is buf - assert not manager.lookup_address_range(buf.ptr + buf.size, buf.size) - assert not manager.lookup_address_range(buf.ptr - buf.size, buf.size) - - def test_external_memory_never_spills(manager): """ Test that external data, i.e., data not managed by RMM, From d23649fe12b74180cc8864e09ddf4f4b0c86c23f Mon Sep 17 00:00:00 2001 From: "Mads R. B. 
Kristensen" Date: Fri, 11 Nov 2022 08:40:44 +0100 Subject: [PATCH 03/31] removed _from_any_memory --- python/cudf/cudf/_lib/column.pyx | 6 +++--- python/cudf/cudf/core/buffer/buffer.py | 22 ---------------------- 2 files changed, 3 insertions(+), 25 deletions(-) diff --git a/python/cudf/cudf/_lib/column.pyx b/python/cudf/cudf/_lib/column.pyx index 9e5b62ab404..9b415498ef3 100644 --- a/python/cudf/cudf/_lib/column.pyx +++ b/python/cudf/cudf/_lib/column.pyx @@ -531,9 +531,9 @@ cdef class Column: size=(size+offset) * dtype_itemsize) ) elif ( - # This is an optimization to avoid creating a new - # SpillableBuffer that represent the same memory - # as the owner. + # This is an optimization of the most common case where + # from_column_view creates a "view" that is identical to + # the owner. column_owner and isinstance(data_owner, SpillableBuffer) and # We have to make sure that `data_owner` is already spill diff --git a/python/cudf/cudf/core/buffer/buffer.py b/python/cudf/cudf/core/buffer/buffer.py index 29534ab5529..73e589ebb8e 100644 --- a/python/cudf/cudf/core/buffer/buffer.py +++ b/python/cudf/cudf/core/buffer/buffer.py @@ -148,28 +148,6 @@ def _from_host_memory(cls: Type[T], data: Any) -> T: # Create from device memory return cls._from_device_memory(buf) - @classmethod - def _from_any_memory(cls: Type[T], data: Any) -> T: - """Create a Buffer from device or host memory - - If data exposes `__cuda_array_interface__`, we deligate to the - `_from_device_memory` constructor otherwise `_from_host_memory`. - - Parameters - ---------- - data : Any - An object that represens device or host memory. - - Returns - ------- - Buffer - Buffer representing `data`. 
- """ - - if hasattr(data, "__cuda_array_interface__"): - return cls._from_device_memory(data) - return cls._from_host_memory(data) - def _getitem(self, offset: int, size: int) -> Buffer: """ Sub-classes can overwrite this to implement __getitem__ From 05f103b5c95c4dde1932ddeef88e490a36f526d9 Mon Sep 17 00:00:00 2001 From: "Mads R. B. Kristensen" Date: Fri, 11 Nov 2022 08:54:44 +0100 Subject: [PATCH 04/31] doc Co-authored-by: Vyas Ramasubramani --- python/cudf/cudf/core/buffer/spill_manager.py | 3 +-- python/cudf/cudf/core/buffer/spillable_buffer.py | 6 +++--- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/python/cudf/cudf/core/buffer/spill_manager.py b/python/cudf/cudf/core/buffer/spill_manager.py index 020167c97fa..b4e3f5f1ac7 100644 --- a/python/cudf/cudf/core/buffer/spill_manager.py +++ b/python/cudf/cudf/core/buffer/spill_manager.py @@ -63,8 +63,7 @@ class SpillManager: Parameters ---------- spill_on_demand : bool - Enable spill on demand. The global manager sets this to the value of - `CUDF_SPILL_ON_DEMAND` or False. + Enable spill on demand. device_memory_limit: int, optional If not None, this is the device memory limit in bytes that triggers device to host spilling. 
The global manager sets this to the value diff --git a/python/cudf/cudf/core/buffer/spillable_buffer.py b/python/cudf/cudf/core/buffer/spillable_buffer.py index ae77ad7375b..accb5902e11 100644 --- a/python/cudf/cudf/core/buffer/spillable_buffer.py +++ b/python/cudf/cudf/core/buffer/spillable_buffer.py @@ -97,7 +97,7 @@ def _finalize_init(self, ptr_desc: Dict[str, Any], exposed: bool) -> None: manager = get_global_manager() if manager is None: raise ValueError( - f"cannot create {self.__class__} with a global spill manager" + f"cannot create {self.__class__} without a global spill manager" ) self._manager = manager @@ -203,7 +203,7 @@ def __spill__(self, target: str = "cpu") -> None: # allocation might trigger spilling-on-demand which in turn # trigger a new call to this buffer's `__spill__()`. # Therefore, it is important that spilling-on-demand doesn't - # tries to unspill an already locked buffer! + # try to unspill an already locked buffer! dev_mem = rmm.DeviceBuffer.to_device( self._ptr_desc.pop("memoryview") ) @@ -418,7 +418,7 @@ def _getitem(self, offset: int, size: int) -> Buffer: def deserialize(cls, header: dict, frames: list): # TODO: because of the hack in `SpillableBuffer.serialize()` where # frames are of type `Buffer`, we always deserialize as if they are - # `SpillableBufferbuffer`. In the future, we should be able to + # `SpillableBuffer`. In the future, we should be able to # deserialize into `SpillableBufferSlice` when the frames hasn't been # copied. return SpillableBuffer.deserialize(header, frames) From 249bc7d6acc215b0b0158a273ccffa7bab76409e Mon Sep 17 00:00:00 2001 From: "Mads R. B. 
Kristensen" Date: Fri, 11 Nov 2022 09:04:52 +0100 Subject: [PATCH 05/31] doc --- python/cudf/cudf/core/buffer/spillable_buffer.py | 3 ++- python/cudf/cudf/tests/conftest.py | 6 +++++- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/python/cudf/cudf/core/buffer/spillable_buffer.py b/python/cudf/cudf/core/buffer/spillable_buffer.py index accb5902e11..cf4acbc150c 100644 --- a/python/cudf/cudf/core/buffer/spillable_buffer.py +++ b/python/cudf/cudf/core/buffer/spillable_buffer.py @@ -97,7 +97,8 @@ def _finalize_init(self, ptr_desc: Dict[str, Any], exposed: bool) -> None: manager = get_global_manager() if manager is None: raise ValueError( - f"cannot create {self.__class__} without a global spill manager" + f"cannot create {self.__class__} without " + "a global spill manager" ) self._manager = manager diff --git a/python/cudf/cudf/tests/conftest.py b/python/cudf/cudf/tests/conftest.py index bf565ed9a47..30d8f1c8422 100644 --- a/python/cudf/cudf/tests/conftest.py +++ b/python/cudf/cudf/tests/conftest.py @@ -164,7 +164,11 @@ def default_float_bitwidth(request): def pytest_runtest_makereport(item, call): """Hook to make result information available in fixtures - See + This makes it possible for a pytest.fixture to access the current test + state through `request.node.report`. + See the `manager` fixture in `test_spilling.py` for an example. + + Pytest doc: """ outcome = yield rep = outcome.get_result() From f3f73282d4a77c8ded94716f13e2d31cfa1a9709 Mon Sep 17 00:00:00 2001 From: "Mads R. B. 
Kristensen" Date: Fri, 11 Nov 2022 09:36:26 +0100 Subject: [PATCH 06/31] rename base_buffers to buffers --- python/cudf/cudf/core/buffer/spill_manager.py | 28 +++++++++---------- python/cudf/cudf/tests/test_spilling.py | 6 ++-- 2 files changed, 16 insertions(+), 18 deletions(-) diff --git a/python/cudf/cudf/core/buffer/spill_manager.py b/python/cudf/cudf/core/buffer/spill_manager.py index b4e3f5f1ac7..64c43a8f411 100644 --- a/python/cudf/cudf/core/buffer/spill_manager.py +++ b/python/cudf/cudf/core/buffer/spill_manager.py @@ -70,7 +70,7 @@ class SpillManager: of `CUDF_SPILL_DEVICE_LIMIT` or None. """ - _base_buffers: weakref.WeakValueDictionary[int, SpillableBuffer] + _buffers: weakref.WeakValueDictionary[int, SpillableBuffer] def __init__( self, @@ -79,7 +79,7 @@ def __init__( device_memory_limit: int = None, ) -> None: self._lock = threading.Lock() - self._base_buffers = weakref.WeakValueDictionary() + self._buffers = weakref.WeakValueDictionary() self._id_counter = 0 self._spill_on_demand = spill_on_demand self._device_memory_limit = device_memory_limit @@ -157,11 +157,11 @@ def add(self, buffer: SpillableBuffer) -> None: """ if buffer.size > 0 and not buffer.exposed: with self._lock: - self._base_buffers[self._id_counter] = buffer + self._buffers[self._id_counter] = buffer self._id_counter += 1 self.spill_to_device_limit() - def base_buffers( + def buffers( self, order_by_access_time: bool = False ) -> Tuple[SpillableBuffer, ...]: """Get all managed buffers @@ -177,7 +177,7 @@ def base_buffers( Tuple of buffers """ with self._lock: - ret = tuple(self._base_buffers.values()) + ret = tuple(self._buffers.values()) if order_by_access_time: ret = tuple(sorted(ret, key=lambda b: b.last_accessed)) return ret @@ -193,7 +193,7 @@ def spill_device_memory(self) -> int: int Number of bytes spilled. 
""" - for buf in self.base_buffers(order_by_access_time=True): + for buf in self.buffers(order_by_access_time=True): if buf.lock.acquire(blocking=False): try: if not buf.is_spilled and buf.spillable: @@ -228,7 +228,7 @@ def spill_to_device_limit(self, device_limit: int = None) -> int: ret = 0 while True: unspilled = sum( - buf.size for buf in self.base_buffers() if not buf.is_spilled + buf.size for buf in self.buffers() if not buf.is_spilled ) if unspilled < limit: break @@ -239,14 +239,12 @@ def spill_to_device_limit(self, device_limit: int = None) -> int: return ret def __repr__(self) -> str: - spilled = sum( - buf.size for buf in self.base_buffers() if buf.is_spilled - ) + spilled = sum(buf.size for buf in self.buffers() if buf.is_spilled) unspilled = sum( - buf.size for buf in self.base_buffers() if not buf.is_spilled + buf.size for buf in self.buffers() if not buf.is_spilled ) unspillable = 0 - for buf in self.base_buffers(): + for buf in self.buffers(): if not (buf.is_spilled or buf.spillable): unspillable += buf.size unspillable_ratio = unspillable / unspilled if unspilled else 0 @@ -274,9 +272,9 @@ def set_global_manager(manager: Optional[SpillManager]) -> None: global _global_manager, _global_manager_uninitialized if _global_manager is not None: gc.collect() - base_buffers = _global_manager.base_buffers() - if len(base_buffers) > 0: - warnings.warn(f"overwriting non-empty manager: {base_buffers}") + buffers = _global_manager.buffers() + if len(buffers) > 0: + warnings.warn(f"overwriting non-empty manager: {buffers}") _global_manager = manager _global_manager_uninitialized = False diff --git a/python/cudf/cudf/tests/test_spilling.py b/python/cudf/cudf/tests/test_spilling.py index 591560f1983..d849e787bb2 100644 --- a/python/cudf/cudf/tests/test_spilling.py +++ b/python/cudf/cudf/tests/test_spilling.py @@ -45,9 +45,9 @@ def gen_df(target="gpu") -> cudf.DataFrame: def spilled_and_unspilled(manager: SpillManager) -> Tuple[int, int]: """Get bytes spilled and 
unspilled known by the manager""" - spilled = sum(buf.size for buf in manager.base_buffers() if buf.is_spilled) + spilled = sum(buf.size for buf in manager.buffers() if buf.is_spilled) unspilled = sum( - buf.size for buf in manager.base_buffers() if not buf.is_spilled + buf.size for buf in manager.buffers() if not buf.is_spilled ) return spilled, unspilled @@ -247,7 +247,7 @@ def test_external_memory_never_spills(manager): a = cupy.asarray([1, 2, 3]) s = cudf.Series(a) - assert len(manager.base_buffers()) == 0 + assert len(manager.buffers()) == 0 assert not s._data[None].data.spillable From eaf8350d4457870837710c010d84c68cd12c7e0a Mon Sep 17 00:00:00 2001 From: "Mads R. B. Kristensen" Date: Fri, 11 Nov 2022 10:13:19 +0100 Subject: [PATCH 07/31] Now, spill_device_memory() can spill multiple buffers --- python/cudf/cudf/core/buffer/spill_manager.py | 31 ++++++++++--------- python/cudf/cudf/tests/test_spilling.py | 10 +++--- 2 files changed, 22 insertions(+), 19 deletions(-) diff --git a/python/cudf/cudf/core/buffer/spill_manager.py b/python/cudf/cudf/core/buffer/spill_manager.py index 64c43a8f411..1b72bea5648 100644 --- a/python/cudf/cudf/core/buffer/spill_manager.py +++ b/python/cudf/cudf/core/buffer/spill_manager.py @@ -121,15 +121,10 @@ def _out_of_memory_handle(self, nbytes: int, *, retry_once=True) -> bool: already locked buffers. """ - # Keep spilling until `nbytes` been spilled - total_spilled = 0 - while total_spilled < nbytes: - spilled = self.spill_device_memory() - if spilled == 0: - break # No more to spill! 
- total_spilled += spilled - - if total_spilled > 0: + # Let's try to spill device memory + spilled = self.spill_device_memory(nbytes=nbytes) + + if spilled > 0: return True # Ask RMM to retry the allocation if retry_once: @@ -182,26 +177,34 @@ def buffers( ret = tuple(sorted(ret, key=lambda b: b.last_accessed)) return ret - def spill_device_memory(self) -> int: + def spill_device_memory(self, nbytes: int) -> int: """Try to spill device memory This function is safe to call doing spill-on-demand since it does not lock buffers already locked. + Parameters + ---------- + nbytes : int + Number of bytes to try to spill + Return ------ int - Number of bytes spilled. + Number of actually bytes spilled. """ + spilled = 0 for buf in self.buffers(order_by_access_time=True): if buf.lock.acquire(blocking=False): try: if not buf.is_spilled and buf.spillable: buf.__spill__(target="cpu") - return buf.size + spilled += buf.size + if spilled >= nbytes: + break finally: buf.lock.release() - return 0 + return spilled def spill_to_device_limit(self, device_limit: int = None) -> int: """Spill until device limit @@ -232,7 +235,7 @@ def spill_to_device_limit(self, device_limit: int = None) -> int: ) if unspilled < limit: break - nbytes = self.spill_device_memory() + nbytes = self.spill_device_memory(nbytes=limit - unspilled) if nbytes == 0: break # No more to spill ret += nbytes diff --git a/python/cudf/cudf/tests/test_spilling.py b/python/cudf/cudf/tests/test_spilling.py index d849e787bb2..0ad9ff57123 100644 --- a/python/cudf/cudf/tests/test_spilling.py +++ b/python/cudf/cudf/tests/test_spilling.py @@ -181,28 +181,28 @@ def reload_options(): def test_spill_device_memory(manager: SpillManager): df = gen_df() assert spilled_and_unspilled(manager) == (0, gen_df.buffer_size) - manager.spill_device_memory() + manager.spill_device_memory(nbytes=1) assert spilled_and_unspilled(manager) == (gen_df.buffer_size, 0) del df assert spilled_and_unspilled(manager) == (0, 0) df1 = gen_df() df2 = 
gen_df() - manager.spill_device_memory() + manager.spill_device_memory(nbytes=1) assert gen_df.is_spilled(df1) assert not gen_df.is_spilled(df2) - manager.spill_device_memory() + manager.spill_device_memory(nbytes=1) assert gen_df.is_spilled(df1) assert gen_df.is_spilled(df2) df3 = df1 + df2 assert not gen_df.is_spilled(df1) assert not gen_df.is_spilled(df2) assert not gen_df.is_spilled(df3) - manager.spill_device_memory() + manager.spill_device_memory(nbytes=1) assert gen_df.is_spilled(df1) assert not gen_df.is_spilled(df2) assert not gen_df.is_spilled(df3) df2.abs() # Should change the access time - manager.spill_device_memory() + manager.spill_device_memory(nbytes=1) assert gen_df.is_spilled(df1) assert not gen_df.is_spilled(df2) assert gen_df.is_spilled(df3) From 3e7bb39c24a8e58b4aa6eedc73eedaf43a574b6e Mon Sep 17 00:00:00 2001 From: "Mads R. B. Kristensen" Date: Fri, 11 Nov 2022 10:23:51 +0100 Subject: [PATCH 08/31] Clean up --- python/cudf/cudf/_lib/column.pyx | 6 +----- python/cudf/cudf/tests/test_spilling.py | 12 ++++++++---- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/python/cudf/cudf/_lib/column.pyx b/python/cudf/cudf/_lib/column.pyx index 9b415498ef3..2f8c9c9ec7d 100644 --- a/python/cudf/cudf/_lib/column.pyx +++ b/python/cudf/cudf/_lib/column.pyx @@ -102,11 +102,7 @@ cdef class Column: if self._data is None: start = self.offset * self.dtype.itemsize end = start + self.size * self.dtype.itemsize - if start == 0 and end == self.base_data.size: - # `data` spans all of `base_data` - self._data = self.base_data - else: - self._data = self.base_data[start:end] + self._data = self.base_data[start:end] return self._data @property diff --git a/python/cudf/cudf/tests/test_spilling.py b/python/cudf/cudf/tests/test_spilling.py index 0ad9ff57123..33e775784c6 100644 --- a/python/cudf/cudf/tests/test_spilling.py +++ b/python/cudf/cudf/tests/test_spilling.py @@ -26,7 +26,11 @@ get_rmm_memory_resource_stack, set_global_manager, ) -from 
cudf.core.buffer.spillable_buffer import SpillableBuffer, SpillLock +from cudf.core.buffer.spillable_buffer import ( + SpillableBuffer, + SpillableBufferSlice, + SpillLock, +) from cudf.testing._utils import assert_eq @@ -135,7 +139,7 @@ def test_spillable_df_groupby(manager: SpillManager): gb = df.groupby("x") # `gb` holds a reference to the device memory, which makes # the buffer unspillable - assert len(df._data._data["x"].data._spill_locks) == 1 + assert len(df._data._data["x"].base_data._spill_locks) == 1 assert not df._data._data["x"].data.spillable del gb assert df._data._data["x"].data.spillable @@ -405,8 +409,8 @@ def test_serialize_cuda_dataframe(manager: SpillManager): header, frames = protocol.serialize( df1, serializers=("cuda",), on_error="raise" ) - buf: SpillableBuffer = gen_df.buffer(df1) - assert len(buf._spill_locks) == 1 + buf: SpillableBufferSlice = gen_df.buffer(df1) + assert len(buf._base._spill_locks) == 1 assert len(frames) == 1 assert isinstance(frames[0], Buffer) assert frames[0].ptr == buf.ptr From f4e82be4807895c8f477eb1b446c3adadb8160a4 Mon Sep 17 00:00:00 2001 From: "Mads R. B. 
Kristensen" Date: Fri, 11 Nov 2022 10:46:20 +0100 Subject: [PATCH 09/31] mutable_view(): use .get_ptr() --- python/cudf/cudf/_lib/column.pyx | 9 ++++++++- python/cudf/cudf/_lib/copying.pyx | 1 + python/cudf/cudf/_lib/filling.pyx | 3 +++ python/cudf/cudf/_lib/replace.pyx | 2 ++ 4 files changed, 14 insertions(+), 1 deletion(-) diff --git a/python/cudf/cudf/_lib/column.pyx b/python/cudf/cudf/_lib/column.pyx index 2f8c9c9ec7d..7d9ef9f0d30 100644 --- a/python/cudf/cudf/_lib/column.pyx +++ b/python/cudf/cudf/_lib/column.pyx @@ -336,7 +336,14 @@ cdef class Column: cdef vector[mutable_column_view] children cdef void* data - data = (col.base_data_ptr) + if col.base_data is None: + data = NULL + elif isinstance(col.base_data, SpillableBuffer): + data = (col.base_data).get_ptr( + spill_lock=get_spill_lock() + ) + else: + data = (col.base_data.ptr) cdef Column child_column if col.base_children: diff --git a/python/cudf/cudf/_lib/copying.pyx b/python/cudf/cudf/_lib/copying.pyx index 7cd811caa26..d73278a991c 100644 --- a/python/cudf/cudf/_lib/copying.pyx +++ b/python/cudf/cudf/_lib/copying.pyx @@ -86,6 +86,7 @@ def copy_column(Column input_column): return Column.from_unique_ptr(move(c_result)) +@with_spill_lock() def _copy_range_in_place(Column input_column, Column target_column, size_type input_begin, diff --git a/python/cudf/cudf/_lib/filling.pyx b/python/cudf/cudf/_lib/filling.pyx index 891da82821c..be92abdb0e8 100644 --- a/python/cudf/cudf/_lib/filling.pyx +++ b/python/cudf/cudf/_lib/filling.pyx @@ -1,5 +1,7 @@ # Copyright (c) 2020-2022, NVIDIA CORPORATION. 
+from cudf.core.buffer import with_spill_lock + from libcpp.memory cimport unique_ptr from libcpp.utility cimport move @@ -15,6 +17,7 @@ from cudf._lib.scalar cimport DeviceScalar from cudf._lib.utils cimport columns_from_unique_ptr, table_view_from_columns +@with_spill_lock() def fill_in_place(Column destination, int begin, int end, DeviceScalar value): cdef mutable_column_view c_destination = destination.mutable_view() cdef size_type c_begin = begin diff --git a/python/cudf/cudf/_lib/replace.pyx b/python/cudf/cudf/_lib/replace.pyx index e4311b356ec..700f6637a44 100644 --- a/python/cudf/cudf/_lib/replace.pyx +++ b/python/cudf/cudf/_lib/replace.pyx @@ -4,6 +4,7 @@ from libcpp.memory cimport unique_ptr from libcpp.utility cimport move from cudf.api.types import is_scalar +from cudf.core.buffer import with_spill_lock from cudf._lib.column cimport Column @@ -210,6 +211,7 @@ def clip(Column input_col, object lo, object hi): return clamp(input_col, lo_scalar, hi_scalar) +@with_spill_lock() def normalize_nans_and_zeros_inplace(Column input_col): """ Inplace normalizing From 8a06512c256ba06369e69cbb1c9d9bd6472f2d2f Mon Sep 17 00:00:00 2001 From: "Mads R. B. Kristensen" Date: Fri, 11 Nov 2022 10:59:33 +0100 Subject: [PATCH 10/31] Column: removed the base_data_ptr property --- python/cudf/cudf/_lib/column.pyi | 2 -- python/cudf/cudf/_lib/column.pyx | 7 ------- python/cudf/cudf/_lib/replace.pyx | 2 +- python/cudf/cudf/testing/_utils.py | 8 ++++++-- 4 files changed, 7 insertions(+), 12 deletions(-) diff --git a/python/cudf/cudf/_lib/column.pyi b/python/cudf/cudf/_lib/column.pyi index c38c560b982..612f3cdf95a 100644 --- a/python/cudf/cudf/_lib/column.pyi +++ b/python/cudf/cudf/_lib/column.pyi @@ -42,8 +42,6 @@ class Column: @property def base_data(self) -> Optional[Buffer]: ... @property - def base_data_ptr(self) -> int: ... - @property def data(self) -> Optional[Buffer]: ... @property def data_ptr(self) -> int: ... 
diff --git a/python/cudf/cudf/_lib/column.pyx b/python/cudf/cudf/_lib/column.pyx index 7d9ef9f0d30..4ce0ea70e33 100644 --- a/python/cudf/cudf/_lib/column.pyx +++ b/python/cudf/cudf/_lib/column.pyx @@ -88,13 +88,6 @@ cdef class Column: def base_data(self): return self._base_data - @property - def base_data_ptr(self): - if self.base_data is None: - return 0 - else: - return self.base_data.ptr - @property def data(self): if self.base_data is None: diff --git a/python/cudf/cudf/_lib/replace.pyx b/python/cudf/cudf/_lib/replace.pyx index 700f6637a44..44654f895d2 100644 --- a/python/cudf/cudf/_lib/replace.pyx +++ b/python/cudf/cudf/_lib/replace.pyx @@ -1,4 +1,4 @@ -# Copyright (c) 2020-2021, NVIDIA CORPORATION. +# Copyright (c) 2020-2022, NVIDIA CORPORATION. from libcpp.memory cimport unique_ptr from libcpp.utility cimport move diff --git a/python/cudf/cudf/testing/_utils.py b/python/cudf/cudf/testing/_utils.py index 259257c257f..5465462d7c2 100644 --- a/python/cudf/cudf/testing/_utils.py +++ b/python/cudf/cudf/testing/_utils.py @@ -352,8 +352,12 @@ def assert_column_memory_eq( children to the same constraints. Also fails check if the number of children mismatches at any level. """ - assert lhs.base_data_ptr == rhs.base_data_ptr - assert lhs.base_mask_ptr == rhs.base_mask_ptr + + def get_ptr(x) -> int: + return x.ptr if x else 0 + + assert get_ptr(lhs.base_data) == get_ptr(rhs.base_data) + assert get_ptr(lhs.base_mask) == get_ptr(rhs.base_mask) assert lhs.base_size == rhs.base_size assert lhs.offset == rhs.offset assert lhs.size == rhs.size From 9b71670bcbd1ff8d75fc0979888449560e018b17 Mon Sep 17 00:00:00 2001 From: "Mads R. B. 
Kristensen" Date: Fri, 11 Nov 2022 11:07:57 +0100 Subject: [PATCH 11/31] renamed __spill__ to spill --- python/cudf/cudf/core/buffer/spill_manager.py | 2 +- python/cudf/cudf/core/buffer/spillable_buffer.py | 16 ++++++++-------- python/cudf/cudf/tests/test_spilling.py | 14 +++++++------- 3 files changed, 16 insertions(+), 16 deletions(-) diff --git a/python/cudf/cudf/core/buffer/spill_manager.py b/python/cudf/cudf/core/buffer/spill_manager.py index 1b72bea5648..5ea1b90928b 100644 --- a/python/cudf/cudf/core/buffer/spill_manager.py +++ b/python/cudf/cudf/core/buffer/spill_manager.py @@ -198,7 +198,7 @@ def spill_device_memory(self, nbytes: int) -> int: if buf.lock.acquire(blocking=False): try: if not buf.is_spilled and buf.spillable: - buf.__spill__(target="cpu") + buf.spill(target="cpu") spilled += buf.size if spilled >= nbytes: break diff --git a/python/cudf/cudf/core/buffer/spillable_buffer.py b/python/cudf/cudf/core/buffer/spillable_buffer.py index cf4acbc150c..e68427e0813 100644 --- a/python/cudf/cudf/core/buffer/spillable_buffer.py +++ b/python/cudf/cudf/core/buffer/spillable_buffer.py @@ -61,7 +61,7 @@ class SpillableBuffer(Buffer): """A spillable buffer that implements DeviceBufferLike. This buffer supports spilling the represented data to host memory. - Spilling can be done manually by calling `.__spill__(target="cpu")` but + Spilling can be done manually by calling `.spill(target="cpu")` but usually the associated spilling manager triggers spilling based on current device memory usage see `cudf.core.buffer.spill_manager.SpillManager`. Unspill is triggered automatically when accessing the data of the buffer. 
@@ -174,7 +174,7 @@ def lock(self) -> RLock: def is_spilled(self) -> bool: return self._ptr_desc["type"] != "gpu" - def __spill__(self, target: str = "cpu") -> None: + def spill(self, target: str = "cpu") -> None: """Spill or un-spill this buffer in-place Parameters @@ -202,7 +202,7 @@ def __spill__(self, target: str = "cpu") -> None: elif (ptr_type, target) == ("cpu", "gpu"): # Notice, this operation is prone to deadlock because the RMM # allocation might trigger spilling-on-demand which in turn - # trigger a new call to this buffer's `__spill__()`. + # trigger a new call to this buffer's `spill()`. # Therefore, it is important that spilling-on-demand doesn't # try to unspill an already locked buffer! dev_mem = rmm.DeviceBuffer.to_device( @@ -228,7 +228,7 @@ def ptr(self) -> int: self._manager.spill_to_device_limit() with self._lock: - self.__spill__(target="gpu") + self.spill(target="gpu") self._exposed = True self._last_accessed = time.monotonic() return self._ptr @@ -237,7 +237,7 @@ def spill_lock(self, spill_lock: SpillLock = None) -> SpillLock: if spill_lock is None: spill_lock = SpillLock() with self._lock: - self.__spill__(target="gpu") + self.spill(target="gpu") self._spill_locks.add(spill_lock) return spill_lock @@ -304,7 +304,7 @@ def memoryview(self, *, offset: int = 0, size: int = None) -> memoryview: size = self._size if size is None else size with self._lock: if self.spillable: - self.__spill__(target="cpu") + self.spill(target="cpu") return self._ptr_desc["memoryview"][offset : offset + size] else: assert self._ptr_desc["type"] == "gpu" @@ -435,8 +435,8 @@ def __repr__(self) -> str: ) # The rest of the methods delegate to the base buffer. 
- def __spill__(self, target: str = "cpu") -> None: - return self._base.__spill__(target=target) + def spill(self, target: str = "cpu") -> None: + return self._base.spill(target=target) @property def is_spilled(self) -> bool: diff --git a/python/cudf/cudf/tests/test_spilling.py b/python/cudf/cudf/tests/test_spilling.py index 33e775784c6..03a920e67a2 100644 --- a/python/cudf/cudf/tests/test_spilling.py +++ b/python/cudf/cudf/tests/test_spilling.py @@ -37,7 +37,7 @@ def gen_df(target="gpu") -> cudf.DataFrame: ret = cudf.DataFrame({"a": [1, 2, 3]}) if target != "gpu": - gen_df.buffer(ret).__spill__(target=target) + gen_df.buffer(ret).spill(target=target) return ret @@ -102,7 +102,7 @@ def test_spillable_buffer(manager: SpillManager): "exposed", "spillable", "spill_lock", - "__spill__", + "spill", ], ) def test_spillable_buffer_view_attributes(manager: SpillManager, attribute): @@ -147,12 +147,12 @@ def test_spillable_df_groupby(manager: SpillManager): def test_spilling_buffer(manager: SpillManager): buf = as_buffer(rmm.DeviceBuffer(size=10), exposed=False) - buf.__spill__(target="cpu") + buf.spill(target="cpu") assert buf.is_spilled buf.ptr # Expose pointer and trigger unspill assert not buf.is_spilled with pytest.raises(ValueError, match="unspillable buffer"): - buf.__spill__(target="cpu") + buf.spill(target="cpu") def test_environment_variables(monkeypatch): @@ -267,7 +267,7 @@ def test_modify_spilled_views(manager): df = gen_df() df_view = df.iloc[1:] buf = gen_df.buffer(df) - buf.__spill__(target="cpu") + buf.spill(target="cpu") # modify the spilled df and check that the changes are reflected # in the view @@ -347,7 +347,7 @@ def test_serialize_device(manager, target, view): df1 = gen_df() if view is not None: df1 = df1.iloc[view] - gen_df.buffer(df1).__spill__(target=target) + gen_df.buffer(df1).spill(target=target) header, frames = df1.device_serialize() assert len(frames) == 1 @@ -371,7 +371,7 @@ def test_serialize_host(manager, target, view): df1 = gen_df() if 
view is not None: df1 = df1.iloc[view] - gen_df.buffer(df1).__spill__(target=target) + gen_df.buffer(df1).spill(target=target) # Unspilled df becomes spilled after host serialization header, frames = df1.host_serialize() From f0dab57e8a409ec72e3f57b40fe45383467528b9 Mon Sep 17 00:00:00 2001 From: "Mads R. B. Kristensen" Date: Fri, 11 Nov 2022 11:19:23 +0100 Subject: [PATCH 12/31] SpillableBuffer: .lock is now a regular attribute --- .../cudf/cudf/core/buffer/spillable_buffer.py | 20 ++++++++----------- 1 file changed, 8 insertions(+), 12 deletions(-) diff --git a/python/cudf/cudf/core/buffer/spillable_buffer.py b/python/cudf/cudf/core/buffer/spillable_buffer.py index e68427e0813..ef59057e305 100644 --- a/python/cudf/cudf/core/buffer/spillable_buffer.py +++ b/python/cudf/cudf/core/buffer/spillable_buffer.py @@ -79,7 +79,7 @@ class SpillableBuffer(Buffer): Use the factory function `as_buffer` to create a SpillableBuffer instance. """ - _lock: RLock + lock: RLock _spill_locks: weakref.WeakSet _last_accessed: float _ptr_desc: Dict[str, Any] @@ -89,7 +89,7 @@ class SpillableBuffer(Buffer): def _finalize_init(self, ptr_desc: Dict[str, Any], exposed: bool) -> None: from cudf.core.buffer.spill_manager import get_global_manager - self._lock = RLock() + self.lock = RLock() self._spill_locks = weakref.WeakSet() self._last_accessed = time.monotonic() self._ptr_desc = ptr_desc @@ -166,10 +166,6 @@ def _from_host_memory(cls: Type[T], data: Any) -> T: ) return ret - @property - def lock(self) -> RLock: - return self._lock - @property def is_spilled(self) -> bool: return self._ptr_desc["type"] != "gpu" @@ -183,7 +179,7 @@ def spill(self, target: str = "cpu") -> None: The target of the spilling. 
""" - with self._lock: + with self.lock: ptr_type = self._ptr_desc["type"] if ptr_type == target: return @@ -227,7 +223,7 @@ def ptr(self) -> int: """ self._manager.spill_to_device_limit() - with self._lock: + with self.lock: self.spill(target="gpu") self._exposed = True self._last_accessed = time.monotonic() @@ -236,7 +232,7 @@ def ptr(self) -> int: def spill_lock(self, spill_lock: SpillLock = None) -> SpillLock: if spill_lock is None: spill_lock = SpillLock() - with self._lock: + with self.lock: self.spill(target="gpu") self._spill_locks.add(spill_lock) return spill_lock @@ -302,7 +298,7 @@ def __cuda_array_interface__(self) -> dict: def memoryview(self, *, offset: int = 0, size: int = None) -> memoryview: size = self._size if size is None else size - with self._lock: + with self.lock: if self.spillable: self.spill(target="cpu") return self._ptr_desc["memoryview"][offset : offset + size] @@ -338,7 +334,7 @@ def serialize(self) -> Tuple[dict, list]: """ header: Dict[Any, Any] frames: List[Buffer | memoryview] - with self._lock: + with self.lock: header = {} header["type-serialized"] = pickle.dumps(self.__class__) header["frame_count"] = 1 @@ -401,7 +397,7 @@ def __init__(self, base: SpillableBuffer, offset: int, size: int) -> None: self._offset = offset self._size = size self._owner = base - self._lock = base.lock + self.lock = base.lock @property def ptr(self) -> int: From e3da837c26f136e42c7fff3c51ab72912372fe8f Mon Sep 17 00:00:00 2001 From: "Mads R. B. 
Kristensen" Date: Fri, 11 Nov 2022 13:27:57 +0100 Subject: [PATCH 13/31] clean up spill_lock --- .../cudf/cudf/core/buffer/spillable_buffer.py | 18 ++++++++++++++---- python/cudf/cudf/tests/test_spilling.py | 3 ++- 2 files changed, 16 insertions(+), 5 deletions(-) diff --git a/python/cudf/cudf/core/buffer/spillable_buffer.py b/python/cudf/cudf/core/buffer/spillable_buffer.py index ef59057e305..3c3d8ef06b2 100644 --- a/python/cudf/cudf/core/buffer/spillable_buffer.py +++ b/python/cudf/cudf/core/buffer/spillable_buffer.py @@ -229,13 +229,23 @@ def ptr(self) -> int: self._last_accessed = time.monotonic() return self._ptr - def spill_lock(self, spill_lock: SpillLock = None) -> SpillLock: + def spill_lock(self, spill_lock: SpillLock) -> None: + """Spill lock the buffer + + Mark the buffer as unspillable while `spill_lock` is alive, + which is tracked by monitoring a weakref to `spill_lock`. + + Parameters + ---------- + spill_lock : SpillLock + The object that defines the scope of the lock. + """ + if spill_lock is None: spill_lock = SpillLock() with self.lock: self.spill(target="gpu") self._spill_locks.add(spill_lock) - return spill_lock def get_ptr(self, spill_lock: SpillLock = None) -> int: """Get a device pointer to the memory of the buffer. 
@@ -446,5 +456,5 @@ def exposed(self) -> bool: def spillable(self) -> bool: return self._base.spillable - def spill_lock(self, spill_lock: SpillLock = None) -> SpillLock: - return self._base.spill_lock(spill_lock=spill_lock) + def spill_lock(self, spill_lock: SpillLock) -> None: + self._base.spill_lock(spill_lock=spill_lock) diff --git a/python/cudf/cudf/tests/test_spilling.py b/python/cudf/cudf/tests/test_spilling.py index 03a920e67a2..348d8a82c67 100644 --- a/python/cudf/cudf/tests/test_spilling.py +++ b/python/cudf/cudf/tests/test_spilling.py @@ -288,7 +288,8 @@ def test_ptr_restricted(manager: SpillManager): buf.get_ptr(spill_lock=slock1) assert not buf.spillable assert len(buf._spill_locks) == 1 - slock2 = buf.spill_lock() + slock2 = SpillLock() + buf.spill_lock(spill_lock=slock2) buf.get_ptr(spill_lock=slock2) assert not buf.spillable assert len(buf._spill_locks) == 2 From e4b5c900acf97b3dae4892a83821a723c059eb9d Mon Sep 17 00:00:00 2001 From: "Mads R. B. Kristensen" Date: Fri, 11 Nov 2022 13:49:50 +0100 Subject: [PATCH 14/31] undone the _groupby change --- python/cudf/cudf/_lib/groupby.pyx | 1 + python/cudf/cudf/core/groupby/groupby.py | 4 ---- python/cudf/cudf/tests/test_groupby.py | 4 ---- python/cudf/cudf/tests/test_spilling.py | 5 +++-- 4 files changed, 4 insertions(+), 10 deletions(-) diff --git a/python/cudf/cudf/_lib/groupby.pyx b/python/cudf/cudf/_lib/groupby.pyx index bea39c06387..b1012a7619a 100644 --- a/python/cudf/cudf/_lib/groupby.pyx +++ b/python/cudf/cudf/_lib/groupby.pyx @@ -96,6 +96,7 @@ cdef class GroupBy: with with_spill_lock() as spill_lock: keys_view = table_view_from_columns(keys) + # We spill lock the columns while this GroupBy instance is alive. 
self._spill_lock = spill_lock with nogil: diff --git a/python/cudf/cudf/core/groupby/groupby.py b/python/cudf/cudf/core/groupby/groupby.py index 371c0566166..e4ea59c1f15 100644 --- a/python/cudf/cudf/core/groupby/groupby.py +++ b/python/cudf/cudf/core/groupby/groupby.py @@ -259,10 +259,6 @@ def __init__( else: self.grouping = _Grouping(obj, by, level) - self._groupby = libgroupby.GroupBy( - [*self.grouping.keys._columns], dropna=self._dropna - ) - def __iter__(self): group_names, offsets, _, grouped_values = self._grouped() if isinstance(group_names, cudf.BaseIndex): diff --git a/python/cudf/cudf/tests/test_groupby.py b/python/cudf/cudf/tests/test_groupby.py index 3898db1c9fa..dd1f726c783 100644 --- a/python/cudf/cudf/tests/test_groupby.py +++ b/python/cudf/cudf/tests/test_groupby.py @@ -1458,10 +1458,6 @@ class TestGroupBy(cudf.core.groupby.GroupBy): def _groupby(self): raise AttributeError(err_msg) - @_groupby.setter - def _groupby(self, _): - pass - a = cudf.DataFrame({"a": [1, 2], "b": [2, 3]}) gb = TestGroupBy(a, a["a"]) diff --git a/python/cudf/cudf/tests/test_spilling.py b/python/cudf/cudf/tests/test_spilling.py index 348d8a82c67..4a5e317c6e4 100644 --- a/python/cudf/cudf/tests/test_spilling.py +++ b/python/cudf/cudf/tests/test_spilling.py @@ -137,8 +137,9 @@ def test_creations(manager: SpillManager): def test_spillable_df_groupby(manager: SpillManager): df = cudf.DataFrame({"x": [1, 1, 1]}) gb = df.groupby("x") - # `gb` holds a reference to the device memory, which makes - # the buffer unspillable + assert len(df._data._data["x"].base_data._spill_locks) == 0 + gb._groupby + # `gb._groupby`, which is cached on `gb`, holds a spill lock assert len(df._data._data["x"].base_data._spill_locks) == 1 assert not df._data._data["x"].data.spillable del gb From d25b9525a87aa5c305c71937eac3f79943c0f91c Mon Sep 17 00:00:00 2001 From: "Mads R. B. 
Kristensen" Date: Fri, 11 Nov 2022 14:38:22 +0100 Subject: [PATCH 15/31] doc --- python/cudf/cudf/core/buffer/spillable_buffer.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/python/cudf/cudf/core/buffer/spillable_buffer.py b/python/cudf/cudf/core/buffer/spillable_buffer.py index 3c3d8ef06b2..87275c55e61 100644 --- a/python/cudf/cudf/core/buffer/spillable_buffer.py +++ b/python/cudf/cudf/core/buffer/spillable_buffer.py @@ -87,6 +87,19 @@ class SpillableBuffer(Buffer): _manager: SpillManager def _finalize_init(self, ptr_desc: Dict[str, Any], exposed: bool) -> None: + """Finish initialization of the spillable buffer + + This implements the common initialization that `_from_device_memory` + and `_from_host_memory` are missing. + + Parameters + ---------- + ptr_desc : dict + Description of the memory. + exposed : bool, optional + Mark the buffer as permanently exposed (unspillable). + """ + from cudf.core.buffer.spill_manager import get_global_manager self.lock = RLock() From 561c135ce67ec99ae4a667eb940928e861cd32c0 Mon Sep 17 00:00:00 2001 From: "Mads R. B. Kristensen" Date: Fri, 11 Nov 2022 14:43:39 +0100 Subject: [PATCH 16/31] use super()._from_device_memory(data) --- python/cudf/cudf/core/buffer/spillable_buffer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/cudf/cudf/core/buffer/spillable_buffer.py b/python/cudf/cudf/core/buffer/spillable_buffer.py index 87275c55e61..768de3df55a 100644 --- a/python/cudf/cudf/core/buffer/spillable_buffer.py +++ b/python/cudf/cudf/core/buffer/spillable_buffer.py @@ -137,7 +137,7 @@ def _from_device_memory( SpillableBuffer Buffer representing the same device memory as `data` """ - ret = super(SpillableBuffer, cls)._from_device_memory(data) + ret = super()._from_device_memory(data) ret._finalize_init(ptr_desc={"type": "gpu"}, exposed=exposed) return ret From eaf19dd92c32114b1485eddfa6a2aedfb2eb5f07 Mon Sep 17 00:00:00 2001 From: "Mads R. B. 
Kristensen" Date: Fri, 11 Nov 2022 16:00:20 +0100 Subject: [PATCH 17/31] doc --- .../source/developer_guide/library_design.md | 24 +++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/docs/cudf/source/developer_guide/library_design.md b/docs/cudf/source/developer_guide/library_design.md index be233edf200..8d5705e9859 100644 --- a/docs/cudf/source/developer_guide/library_design.md +++ b/docs/cudf/source/developer_guide/library_design.md @@ -219,6 +219,30 @@ device to host to enable out-of-memory computation, i.e., computing on objects t available on the GPU. +Spilling can be enabled in two ways (it is disabled by default): + - setting the environment variable `CUDF_SPILL=on`, or + - setting the `spill` option in `cudf` by doing `cudf.set_option("spill", True)`. + +Additionally, parameters are: + - `CUDF_SPILL_ON_DEMAND=ON` / `cudf.set_option("spill_on_demand", True)`, which registers an RMM out-of-memory error handler that spills buffers in order to free up memory. + - `CUDF_SPILL_DEVICE_LIMIT=...` / `cudf.set_option("spill_device_limit", ...)`, which sets a device memory limit in bytes. + + +#### Design + +Spilling consists of two components: + - A new buffer sub-class, `SpillableBuffer`, that implements moving of its data from host to device memory in-place. + - A spill manager that tracks all instances of `SpillableBuffer` and spills them on demand. +A global spill manager is used throughout cudf when spilling is enabled, which makes `as_buffer()` return `SpillableBuffer` instead of the default `Buffer` instances. + +Accessing `Buffer.ptr`, we get the device memory pointer of the buffer. This is unproblematic in the case of `Buffer` but what happens when accessing `SpillableBuffer.ptr`, which might have spilled its device memory? In this case, `SpillableBuffer` needs to unspill the memory before returning its device memory pointer. 
Furthermore, while this device memory pointer is being used (or could be used), `SpillableBuffer` cannot spill its memory back to host memory because doing so would invalidate the device pointer. + +To address this, we mark the `SpillableBuffer` as unspillable, we say that the buffer has been _exposed_. This can either be permanent if the device pointer is exposed to external projects or temporary while `libcudf` accesses the device memory. + +The `SpillableBuffer.get_ptr()` returns the device pointer of the buffer memory just like `.ptr` but if given an instance of `SpillLock`, the buffer is only unspillable as long as the instance of `SpillLock` is alive. + +For convenience, one can use the decorator/context `with_spill_lock` to associate a `SpillLock` with a lifetime bound to the context automatically. + ## The Cython layer The lowest level of cuDF is its interaction with `libcudf` via Cython. From 2707191ec4dcff7407542d87f7957521f770d5d1 Mon Sep 17 00:00:00 2001 From: "Mads R. B. Kristensen" Date: Mon, 14 Nov 2022 08:53:17 +0100 Subject: [PATCH 18/31] doc Co-authored-by: Vyas Ramasubramani --- docs/cudf/source/developer_guide/library_design.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/cudf/source/developer_guide/library_design.md b/docs/cudf/source/developer_guide/library_design.md index 8d5705e9859..ac3ace20ba3 100644 --- a/docs/cudf/source/developer_guide/library_design.md +++ b/docs/cudf/source/developer_guide/library_design.md @@ -235,7 +235,7 @@ Spilling consists of two components: - A spill manager that tracks all instances of `SpillableBuffer` and spills them on demand. A global spill manager is used throughout cudf when spilling is enabled, which makes `as_buffer()` return `SpillableBuffer` instead of the default `Buffer` instances. -Accessing `Buffer.ptr`, we get the device memory pointer of the buffer. 
This is unproblematic in the case of `Buffer` but what happens when accessing `SpillableBuffer.ptr`, which might have spilled its device memory? In this case, `SpillableBuffer` needs to unspill the memory before returning its device memory pointer. Furthermore, while this device memory pointer is being used (or could be used), `SpillableBuffer` cannot spill its memory back to host memory because doing so would invalidate the device pointer. +Accessing `Buffer.ptr`, we get the device memory pointer of the buffer. This is unproblematic in the case of `Buffer` but what happens when accessing `SpillableBuffer.ptr`, which might have spilled its device memory. In this case, `SpillableBuffer` needs to unspill the memory before returning its device memory pointer. Furthermore, while this device memory pointer is being used (or could be used), `SpillableBuffer` cannot spill its memory back to host memory because doing so would invalidate the device pointer. To address this, we mark the `SpillableBuffer` as unspillable, we say that the buffer has been _exposed_. This can either be permanent if the device pointer is exposed to external projects or temporary while `libcudf` accesses the device memory. From 90f2ec6a7a64897c39378f821ca79cdc16cfdff1 Mon Sep 17 00:00:00 2001 From: "Mads R. B. Kristensen" Date: Mon, 14 Nov 2022 08:53:44 +0100 Subject: [PATCH 19/31] doc --- python/cudf/cudf/_lib/column.pyx | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/python/cudf/cudf/_lib/column.pyx b/python/cudf/cudf/_lib/column.pyx index 4ce0ea70e33..2bad7ca2d86 100644 --- a/python/cudf/cudf/_lib/column.pyx +++ b/python/cudf/cudf/_lib/column.pyx @@ -532,9 +532,10 @@ cdef class Column: # the owner. column_owner and isinstance(data_owner, SpillableBuffer) and - # We have to make sure that `data_owner` is already spill - # locked and that its pointer is the same as `data_ptr` - # _without_ exposing the buffer permanently. 
+ # We check that `data_owner` is spill locked (not spillable) + # and that its pointer is the same as `data_ptr` _without_ + # exposing the buffer permanently (calling get_ptr with a + # dummy SpillLock). not data_owner.spillable and data_owner.get_ptr(spill_lock=SpillLock()) == data_ptr and data_owner.size == base_nbytes From 4b372e68617f7c7f9e1953abb29bb8070e4c0b1f Mon Sep 17 00:00:00 2001 From: "Mads R. B. Kristensen" Date: Mon, 14 Nov 2022 09:25:23 +0100 Subject: [PATCH 20/31] removing some exposed=True --- python/cudf/cudf/core/column/decimal.py | 4 ++-- python/cudf/cudf/utils/utils.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/python/cudf/cudf/core/column/decimal.py b/python/cudf/cudf/core/column/decimal.py index 77ca3f9688b..5ee9024a0d8 100644 --- a/python/cudf/cudf/core/column/decimal.py +++ b/python/cudf/cudf/core/column/decimal.py @@ -203,7 +203,7 @@ def from_arrow(cls, data: pa.Array): data_128 = cp.array(np.frombuffer(data.buffers()[1]).view("int32")) data_32 = data_128[::4].copy() return cls( - data=as_buffer(data_32.view("uint8"), exposed=True), + data=as_buffer(data_32.view("uint8")), size=len(data), dtype=dtype, offset=data.offset, @@ -290,7 +290,7 @@ def from_arrow(cls, data: pa.Array): data_128 = cp.array(np.frombuffer(data.buffers()[1]).view("int64")) data_64 = data_128[::2].copy() return cls( - data=as_buffer(data_64.view("uint8"), exposed=True), + data=as_buffer(data_64.view("uint8")), size=len(data), dtype=dtype, offset=data.offset, diff --git a/python/cudf/cudf/utils/utils.py b/python/cudf/cudf/utils/utils.py index 65a86484207..c5f4629483a 100644 --- a/python/cudf/cudf/utils/utils.py +++ b/python/cudf/cudf/utils/utils.py @@ -294,7 +294,7 @@ def pa_mask_buffer_to_mask(mask_buf, size): dbuf = rmm.DeviceBuffer(size=mask_size) dbuf.copy_from_host(np.asarray(mask_buf).view("u1")) return as_buffer(dbuf) - return as_buffer(mask_buf, exposed=True) + return as_buffer(mask_buf) def _isnat(val): From 
1ce9d19bc34e4ff2364c52149961708d30990526 Mon Sep 17 00:00:00 2001 From: "Mads R. B. Kristensen" Date: Mon, 14 Nov 2022 10:19:17 +0100 Subject: [PATCH 21/31] fix memoryview() of sliced spillable buffer --- python/cudf/cudf/core/buffer/spillable_buffer.py | 1 + python/cudf/cudf/tests/test_spilling.py | 12 ++++++++++++ 2 files changed, 13 insertions(+) diff --git a/python/cudf/cudf/core/buffer/spillable_buffer.py b/python/cudf/cudf/core/buffer/spillable_buffer.py index 768de3df55a..c42216be279 100644 --- a/python/cudf/cudf/core/buffer/spillable_buffer.py +++ b/python/cudf/cudf/core/buffer/spillable_buffer.py @@ -168,6 +168,7 @@ def _from_host_memory(cls: Type[T], data: Any) -> T: data = memoryview(numpy.array(data, copy=False, subok=True)) if not data.c_contiguous: raise ValueError("Buffer data must be C-contiguous") + data = data.cast("B") # Make sure itemsize==1 # Create an already spilled buffer ret = cls.__new__(cls) diff --git a/python/cudf/cudf/tests/test_spilling.py b/python/cudf/cudf/tests/test_spilling.py index 4a5e317c6e4..3617cd2f6ff 100644 --- a/python/cudf/cudf/tests/test_spilling.py +++ b/python/cudf/cudf/tests/test_spilling.py @@ -452,3 +452,15 @@ def test_df_transpose(manager: SpillManager): assert df1._data._data["x"].data.exposed assert df2._data._data[0].data.exposed assert df2._data._data[1].data.exposed + + +@pytest.mark.parametrize("dtype", ["uint8", "uint64"]) +def test_memoryview_slice(manager: SpillManager, dtype): + """Check .memoryview() of a sliced spillable buffer""" + + data = np.arange(10, dtype=dtype) + # memoryview of a sliced spillable buffer + m1 = as_buffer(data=data)[1:-1].memoryview() + # sliced memoryview of data as bytes + m2 = memoryview(data).cast("B")[1:-1] + assert m1 == m2 From 96747127b667e6915f5b059c7c961059191e56b4 Mon Sep 17 00:00:00 2001 From: "Mads R. B. 
Kristensen" Date: Mon, 14 Nov 2022 10:46:54 +0100 Subject: [PATCH 22/31] abc: fix __reduce_ex__ by converting frames to numpy arrays --- ci/gpu/build.sh | 2 +- python/cudf/cudf/core/abc.py | 8 +++++++- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/ci/gpu/build.sh b/ci/gpu/build.sh index 516d369f5d9..825b0bd565a 100755 --- a/ci/gpu/build.sh +++ b/ci/gpu/build.sh @@ -284,7 +284,7 @@ py.test -n 8 --cache-clear --basetemp="$WORKSPACE/cudf-cuda-tmp" --ignore="$WORK gpuci_logger "Python py.tests for cuDF with spilling (CUDF_SPILL_DEVICE_LIMIT=1)" # Due to time concerns, we only run a limited set of tests -CUDF_SPILL=on CUDF_SPILL_DEVICE_LIMIT=1 py.test -n 8 --cache-clear --basetemp="$WORKSPACE/cudf-cuda-tmp" --ignore="$WORKSPACE/python/cudf/cudf/benchmarks" -v --cov-config="$WORKSPACE/python/cudf/.coveragerc" --cov-append --cov=cudf --cov-report=xml:"$WORKSPACE/python/cudf/cudf-coverage.xml" --cov-report term --dist=loadscope tests/test_binops.py tests/test_dataframe.py tests/test_buffer.py tests/test_onehot.py tests/test_reshape.py +CUDF_SPILL=on CUDF_SPILL_DEVICE_LIMIT=1 py.test -n 8 --cache-clear --basetemp="$WORKSPACE/cudf-cuda-tmp" --ignore="$WORKSPACE/python/cudf/cudf/benchmarks" -v --cov-config="$WORKSPACE/python/cudf/.coveragerc" --cov-append --cov=cudf --cov-report=xml:"$WORKSPACE/python/cudf/cudf-coverage.xml" --cov-report term --dist=loadscope tests/test_binops.py tests/test_dataframe.py tests/test_buffer.py tests/test_onehot.py tests/test_reshape.py tests/test_pickling.py cd "$WORKSPACE/python/dask_cudf" gpuci_logger "Python py.test for dask-cudf" diff --git a/python/cudf/cudf/core/abc.py b/python/cudf/cudf/core/abc.py index 1c8874a2abd..adf9fe39e4f 100644 --- a/python/cudf/cudf/core/abc.py +++ b/python/cudf/cudf/core/abc.py @@ -3,6 +3,8 @@ import pickle +import numpy + import cudf @@ -176,5 +178,9 @@ def host_deserialize(cls, header, frames): def __reduce_ex__(self, protocol): header, frames = self.host_serialize() - frames = [f.obj for f in 
frames] + + # Since memoryviews are not pickable, we convert them to numpy + # arrays (zero-copy). This works seamlessly because host_deserialize + # converts the frames back into memoryviews. + frames = [numpy.asarray(f) for f in frames] return self.host_deserialize, (header, frames) From d2d2ca11930cdcec5fe54eb01fd23e478b0892f8 Mon Sep 17 00:00:00 2001 From: "Mads R. B. Kristensen" Date: Mon, 14 Nov 2022 19:28:14 +0100 Subject: [PATCH 23/31] added some with_spill_lock --- python/cudf/cudf/_lib/datetime.pyx | 12 ++++++++++++ python/cudf/cudf/_lib/filling.pyx | 3 +++ python/cudf/cudf/_lib/labeling.pyx | 3 +++ python/cudf/cudf/_lib/null_mask.pyx | 9 +++++---- python/cudf/cudf/_lib/partitioning.pyx | 3 +++ python/cudf/cudf/_lib/quantiles.pyx | 3 +++ python/cudf/cudf/_lib/reduce.pyx | 4 ++++ 7 files changed, 33 insertions(+), 4 deletions(-) diff --git a/python/cudf/cudf/_lib/datetime.pyx b/python/cudf/cudf/_lib/datetime.pyx index cb0a245b915..3402e950763 100644 --- a/python/cudf/cudf/_lib/datetime.pyx +++ b/python/cudf/cudf/_lib/datetime.pyx @@ -1,5 +1,7 @@ # Copyright (c) 2020-2022, NVIDIA CORPORATION. 
+from cudf.core.buffer import with_spill_lock + from libcpp.memory cimport unique_ptr from libcpp.utility cimport move @@ -12,6 +14,7 @@ from cudf._lib.cpp.types cimport size_type from cudf._lib.scalar cimport DeviceScalar +@with_spill_lock() def add_months(Column col, Column months): # months must be int16 dtype cdef unique_ptr[column] c_result @@ -29,6 +32,7 @@ def add_months(Column col, Column months): return Column.from_unique_ptr(move(c_result)) +@with_spill_lock() def extract_datetime_component(Column col, object field): cdef unique_ptr[column] c_result @@ -99,6 +103,7 @@ cdef libcudf_datetime.rounding_frequency _get_rounding_frequency(object freq): return freq_val +@with_spill_lock() def ceil_datetime(Column col, object freq): cdef unique_ptr[column] c_result cdef column_view col_view = col.view() @@ -112,6 +117,7 @@ def ceil_datetime(Column col, object freq): return result +@with_spill_lock() def floor_datetime(Column col, object freq): cdef unique_ptr[column] c_result cdef column_view col_view = col.view() @@ -125,6 +131,7 @@ def floor_datetime(Column col, object freq): return result +@with_spill_lock() def round_datetime(Column col, object freq): cdef unique_ptr[column] c_result cdef column_view col_view = col.view() @@ -138,6 +145,7 @@ def round_datetime(Column col, object freq): return result +@with_spill_lock() def is_leap_year(Column col): """Returns a boolean indicator whether the year of the date is a leap year """ @@ -150,6 +158,7 @@ def is_leap_year(Column col): return Column.from_unique_ptr(move(c_result)) +@with_spill_lock() def date_range(DeviceScalar start, size_type n, offset): cdef unique_ptr[column] c_result cdef size_type months = ( @@ -166,6 +175,7 @@ def date_range(DeviceScalar start, size_type n, offset): return Column.from_unique_ptr(move(c_result)) +@with_spill_lock() def extract_quarter(Column col): """ Returns a column which contains the corresponding quarter of the year @@ -180,6 +190,7 @@ def extract_quarter(Column col): return 
Column.from_unique_ptr(move(c_result)) +@with_spill_lock() def days_in_month(Column col): """Extracts the number of days in the month of the date """ @@ -192,6 +203,7 @@ def days_in_month(Column col): return Column.from_unique_ptr(move(c_result)) +@with_spill_lock() def last_day_of_month(Column col): cdef unique_ptr[column] c_result cdef column_view col_view = col.view() diff --git a/python/cudf/cudf/_lib/filling.pyx b/python/cudf/cudf/_lib/filling.pyx index be92abdb0e8..e7727f12bd1 100644 --- a/python/cudf/cudf/_lib/filling.pyx +++ b/python/cudf/cudf/_lib/filling.pyx @@ -32,6 +32,7 @@ def fill_in_place(Column destination, int begin, int end, DeviceScalar value): ) +@with_spill_lock() def fill(Column destination, int begin, int end, DeviceScalar value): cdef column_view c_destination = destination.view() cdef size_type c_begin = begin @@ -50,6 +51,7 @@ def fill(Column destination, int begin, int end, DeviceScalar value): return Column.from_unique_ptr(move(c_result)) +@with_spill_lock() def repeat(list inp, object count): if isinstance(count, Column): return _repeat_via_column(inp, count) @@ -84,6 +86,7 @@ def _repeat_via_size_type(list inp, size_type count): return columns_from_unique_ptr(move(c_result)) +@with_spill_lock() def sequence(int size, DeviceScalar init, DeviceScalar step): cdef size_type c_size = size cdef const scalar* c_init = init.get_raw_ptr() diff --git a/python/cudf/cudf/_lib/labeling.pyx b/python/cudf/cudf/_lib/labeling.pyx index ed5033c08a5..6362072b26e 100644 --- a/python/cudf/cudf/_lib/labeling.pyx +++ b/python/cudf/cudf/_lib/labeling.pyx @@ -1,5 +1,7 @@ # Copyright (c) 2021-2022, NVIDIA CORPORATION. 
+from cudf.core.buffer import with_spill_lock + from libcpp cimport bool as cbool from libcpp.memory cimport unique_ptr from libcpp.utility cimport move @@ -13,6 +15,7 @@ from cudf._lib.cpp.labeling cimport inclusive, label_bins as cpp_label_bins # Note that the parameter input shadows a Python built-in in the local scope, # but I'm not too concerned about that since there's no use-case for actual # input in this context. +@with_spill_lock() def label_bins(Column input, Column left_edges, cbool left_inclusive, Column right_edges, cbool right_inclusive): cdef inclusive c_left_inclusive = \ diff --git a/python/cudf/cudf/_lib/null_mask.pyx b/python/cudf/cudf/_lib/null_mask.pyx index 61988019c70..a4154db7b6e 100644 --- a/python/cudf/cudf/_lib/null_mask.pyx +++ b/python/cudf/cudf/_lib/null_mask.pyx @@ -2,12 +2,14 @@ from enum import Enum +from rmm._lib.device_buffer cimport DeviceBuffer, device_buffer + +from cudf.core.buffer import as_buffer, with_spill_lock + from libcpp.memory cimport make_unique, unique_ptr from libcpp.pair cimport pair from libcpp.utility cimport move -from rmm._lib.device_buffer cimport DeviceBuffer, device_buffer - from cudf._lib.column cimport Column from cudf._lib.cpp.column.column_view cimport column_view from cudf._lib.cpp.null_mask cimport ( @@ -22,8 +24,6 @@ from cudf._lib.cpp.table.table_view cimport table_view from cudf._lib.cpp.types cimport mask_state, size_type from cudf._lib.utils cimport table_view_from_columns -from cudf.core.buffer import as_buffer - class MaskState(Enum): """ @@ -35,6 +35,7 @@ class MaskState(Enum): ALL_NULL = mask_state.ALL_NULL +@with_spill_lock() def copy_bitmask(Column col): """ Copies column's validity mask buffer into a new buffer, shifting by the diff --git a/python/cudf/cudf/_lib/partitioning.pyx b/python/cudf/cudf/_lib/partitioning.pyx index 233551c5134..f707c6a426b 100644 --- a/python/cudf/cudf/_lib/partitioning.pyx +++ b/python/cudf/cudf/_lib/partitioning.pyx @@ -1,5 +1,7 @@ # Copyright (c) 2020-2022, 
NVIDIA CORPORATION. +from cudf.core.buffer import with_spill_lock + from libcpp.memory cimport unique_ptr from libcpp.pair cimport pair from libcpp.utility cimport move @@ -16,6 +18,7 @@ from cudf._lib.stream_compaction import distinct_count as cpp_distinct_count cimport cudf._lib.cpp.types as libcudf_types +@with_spill_lock() def partition(list source_columns, Column partition_map, object num_partitions): diff --git a/python/cudf/cudf/_lib/quantiles.pyx b/python/cudf/cudf/_lib/quantiles.pyx index 62706367c4f..4862d1d9c34 100644 --- a/python/cudf/cudf/_lib/quantiles.pyx +++ b/python/cudf/cudf/_lib/quantiles.pyx @@ -1,5 +1,7 @@ # Copyright (c) 2020-2022, NVIDIA CORPORATION. +from cudf.core.buffer import with_spill_lock + from libcpp cimport bool from libcpp.memory cimport unique_ptr from libcpp.utility cimport move @@ -27,6 +29,7 @@ from cudf._lib.cpp.types cimport interpolation, null_order, order, sorted from cudf._lib.utils cimport columns_from_unique_ptr, table_view_from_columns +@with_spill_lock() def quantile( Column input, object q, diff --git a/python/cudf/cudf/_lib/reduce.pyx b/python/cudf/cudf/_lib/reduce.pyx index e46d724ed9d..2f708ce42dd 100644 --- a/python/cudf/cudf/_lib/reduce.pyx +++ b/python/cudf/cudf/_lib/reduce.pyx @@ -3,6 +3,7 @@ from cython.operator import dereference import cudf +from cudf.core.buffer import with_spill_lock from libcpp.memory cimport unique_ptr from libcpp.utility cimport move, pair @@ -23,6 +24,7 @@ from cudf._lib.scalar cimport DeviceScalar from cudf._lib.types cimport dtype_to_data_type, is_decimal_type_id +@with_spill_lock() def reduce(reduction_op, Column incol, dtype=None, **kwargs): """ Top level Cython reduce function wrapping libcudf reductions. @@ -79,6 +81,7 @@ def reduce(reduction_op, Column incol, dtype=None, **kwargs): return py_result.value +@with_spill_lock() def scan(scan_op, Column incol, inclusive, **kwargs): """ Top level Cython scan function wrapping libcudf scans. 
@@ -110,6 +113,7 @@ def scan(scan_op, Column incol, inclusive, **kwargs): return py_result +@with_spill_lock() def minmax(Column incol): """ Top level Cython minmax function wrapping libcudf minmax. From 2d84e141ac4ffa9c24081bf1a8ea602798d49d96 Mon Sep 17 00:00:00 2001 From: "Mads R. B. Kristensen" Date: Mon, 14 Nov 2022 19:47:30 +0100 Subject: [PATCH 24/31] Clean up Co-authored-by: GALI PREM SAGAR --- python/cudf/cudf/_lib/column.pyx | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/cudf/cudf/_lib/column.pyx b/python/cudf/cudf/_lib/column.pyx index 2bad7ca2d86..27cb362c285 100644 --- a/python/cudf/cudf/_lib/column.pyx +++ b/python/cudf/cudf/_lib/column.pyx @@ -508,7 +508,7 @@ cdef class Column: size = cv.size() offset = cv.offset() dtype = dtype_from_column_view(cv) - dtype_itemsize = dtype.itemsize if hasattr(dtype, "itemsize") else 1 + dtype_itemsize = getattr(dtype, "itemsize", 1) data_ptr = (cv.head[void]()) data = None @@ -547,7 +547,7 @@ cdef class Column: # TODO: try to discover their relationship and create a # SpillableBufferSlice instead. data = as_buffer( - data_ptr, + data=data_ptr, size=base_nbytes, owner=data_owner, exposed=True, From 989f2606046ef6e879952f6e9188ace764305dd6 Mon Sep 17 00:00:00 2001 From: "Mads R. B. 
Kristensen" Date: Mon, 14 Nov 2022 19:56:23 +0100 Subject: [PATCH 25/31] moved spill lock to compute_null_count() --- python/cudf/cudf/_lib/column.pyx | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/python/cudf/cudf/_lib/column.pyx b/python/cudf/cudf/_lib/column.pyx index 27cb362c285..1846042ae86 100644 --- a/python/cudf/cudf/_lib/column.pyx +++ b/python/cudf/cudf/_lib/column.pyx @@ -249,8 +249,7 @@ cdef class Column: @property def null_count(self): if self._null_count is None: - with with_spill_lock(): - self._null_count = self.compute_null_count() + self._null_count = self.compute_null_count() return self._null_count @property @@ -315,7 +314,8 @@ cdef class Column: return other_col cdef libcudf_types.size_type compute_null_count(self) except? 0: - return self._view(libcudf_types.UNKNOWN_NULL_COUNT).null_count() + with with_spill_lock(): + return self._view(libcudf_types.UNKNOWN_NULL_COUNT).null_count() cdef mutable_column_view mutable_view(self) except *: if is_categorical_dtype(self.dtype): From 9c9ec922b62ad67986eba622139fde6ba26408ac Mon Sep 17 00:00:00 2001 From: "Mads R. B. 
Kristensen" Date: Tue, 15 Nov 2022 08:49:05 +0100 Subject: [PATCH 26/31] rename with_spill_lock => acquire_spill_lock --- python/cudf/cudf/_lib/binaryop.pyx | 6 ++-- python/cudf/cudf/_lib/column.pyx | 4 +-- python/cudf/cudf/_lib/copying.pyx | 36 ++++++++++++------------ python/cudf/cudf/_lib/datetime.pyx | 22 +++++++-------- python/cudf/cudf/_lib/filling.pyx | 10 +++---- python/cudf/cudf/_lib/groupby.pyx | 4 +-- python/cudf/cudf/_lib/labeling.pyx | 4 +-- python/cudf/cudf/_lib/null_mask.pyx | 4 +-- python/cudf/cudf/_lib/partitioning.pyx | 4 +-- python/cudf/cudf/_lib/quantiles.pyx | 4 +-- python/cudf/cudf/_lib/reduce.pyx | 8 +++--- python/cudf/cudf/_lib/replace.pyx | 4 +-- python/cudf/cudf/_lib/unary.pyx | 14 ++++----- python/cudf/cudf/core/buffer/__init__.py | 6 +++- python/cudf/cudf/core/buffer/utils.py | 4 +-- python/cudf/cudf/tests/test_spilling.py | 11 ++++++-- 16 files changed, 77 insertions(+), 68 deletions(-) diff --git a/python/cudf/cudf/_lib/binaryop.pyx b/python/cudf/cudf/_lib/binaryop.pyx index 9455565a74f..ecc73059ba3 100644 --- a/python/cudf/cudf/_lib/binaryop.pyx +++ b/python/cudf/cudf/_lib/binaryop.pyx @@ -22,7 +22,7 @@ from cudf._lib.cpp.types cimport data_type, type_id from cudf._lib.types cimport dtype_to_data_type, underlying_type_t_type_id from cudf.api.types import is_scalar, is_string_dtype -from cudf.core.buffer import with_spill_lock +from cudf.core.buffer import acquire_spill_lock cimport cudf._lib.cpp.binaryop as cpp_binaryop from cudf._lib.cpp.binaryop cimport binary_operator @@ -157,7 +157,7 @@ cdef binaryop_s_v(DeviceScalar lhs, Column rhs, return Column.from_unique_ptr(move(c_result)) -@with_spill_lock() +@acquire_spill_lock() def binaryop(lhs, rhs, op, dtype): """ Dispatches a binary op call to the appropriate libcudf function: @@ -205,7 +205,7 @@ def binaryop(lhs, rhs, op, dtype): return result -@with_spill_lock() +@acquire_spill_lock() def binaryop_udf(Column lhs, Column rhs, udf_ptx, dtype): """ Apply a user-defined binary operator 
(a UDF) defined in `udf_ptx` on diff --git a/python/cudf/cudf/_lib/column.pyx b/python/cudf/cudf/_lib/column.pyx index 1846042ae86..ec7d2570708 100644 --- a/python/cudf/cudf/_lib/column.pyx +++ b/python/cudf/cudf/_lib/column.pyx @@ -12,9 +12,9 @@ from cudf.core.buffer import ( Buffer, SpillableBuffer, SpillLock, + acquire_spill_lock, as_buffer, get_spill_lock, - with_spill_lock, ) from cpython.buffer cimport PyObject_CheckBuffer @@ -314,7 +314,7 @@ cdef class Column: return other_col cdef libcudf_types.size_type compute_null_count(self) except? 0: - with with_spill_lock(): + with acquire_spill_lock(): return self._view(libcudf_types.UNKNOWN_NULL_COUNT).null_count() cdef mutable_column_view mutable_view(self) except *: diff --git a/python/cudf/cudf/_lib/copying.pyx b/python/cudf/cudf/_lib/copying.pyx index 36b3fdd19fe..9f0b294b10c 100644 --- a/python/cudf/cudf/_lib/copying.pyx +++ b/python/cudf/cudf/_lib/copying.pyx @@ -12,7 +12,7 @@ from libcpp.vector cimport vector from rmm._lib.device_buffer cimport DeviceBuffer import cudf -from cudf.core.buffer import Buffer, as_buffer, with_spill_lock +from cudf.core.buffer import Buffer, acquire_spill_lock, as_buffer from cudf._lib.column cimport Column @@ -64,7 +64,7 @@ def _gather_map_is_valid( return gm_min >= -nrows and gm_max < nrows -@with_spill_lock() +@acquire_spill_lock() def copy_column(Column input_column): """ Deep copies a column @@ -86,7 +86,7 @@ def copy_column(Column input_column): return Column.from_unique_ptr(move(c_result)) -@with_spill_lock() +@acquire_spill_lock() def _copy_range_in_place(Column input_column, Column target_column, size_type input_begin, @@ -134,7 +134,7 @@ def _copy_range(Column input_column, return Column.from_unique_ptr(move(c_result)) -@with_spill_lock() +@acquire_spill_lock() def copy_range(Column source_column, Column target_column, size_type source_begin, @@ -168,7 +168,7 @@ def copy_range(Column source_column, source_begin, source_end, target_begin) -@with_spill_lock() 
+@acquire_spill_lock() def gather( list columns, Column gather_map, @@ -236,7 +236,7 @@ cdef scatter_column(list source_columns, return columns_from_unique_ptr(move(c_result)) -@with_spill_lock() +@acquire_spill_lock() def scatter(list sources, Column scatter_map, list target_columns, bool bounds_check=True): """ @@ -277,7 +277,7 @@ def scatter(list sources, Column scatter_map, list target_columns, ) -@with_spill_lock() +@acquire_spill_lock() def column_empty_like(Column input_column): cdef column_view input_column_view = input_column.view() @@ -289,7 +289,7 @@ def column_empty_like(Column input_column): return Column.from_unique_ptr(move(c_result)) -@with_spill_lock() +@acquire_spill_lock() def column_allocate_like(Column input_column, size=None): cdef size_type c_size = 0 @@ -314,7 +314,7 @@ def column_allocate_like(Column input_column, size=None): return Column.from_unique_ptr(move(c_result)) -@with_spill_lock() +@acquire_spill_lock() def columns_empty_like(list input_columns): cdef table_view input_table_view = table_view_from_columns(input_columns) cdef unique_ptr[table] c_result @@ -325,7 +325,7 @@ def columns_empty_like(list input_columns): return columns_from_unique_ptr(move(c_result)) -@with_spill_lock() +@acquire_spill_lock() def column_slice(Column input_column, object indices): cdef column_view input_column_view = input_column.view() @@ -355,7 +355,7 @@ def column_slice(Column input_column, object indices): return result -@with_spill_lock() +@acquire_spill_lock() def columns_slice(list input_columns, list indices): """ Given a list of input columns, return columns sliced by ``indices``. 
@@ -382,7 +382,7 @@ def columns_slice(list input_columns, list indices): ] -@with_spill_lock() +@acquire_spill_lock() def column_split(Column input_column, object splits): cdef column_view input_column_view = input_column.view() @@ -414,7 +414,7 @@ def column_split(Column input_column, object splits): return result -@with_spill_lock() +@acquire_spill_lock() def columns_split(list input_columns, object splits): cdef table_view input_table_view = table_view_from_columns(input_columns) @@ -521,7 +521,7 @@ def _copy_if_else_scalar_scalar(DeviceScalar lhs, return Column.from_unique_ptr(move(c_result)) -@with_spill_lock() +@acquire_spill_lock() def copy_if_else(object lhs, object rhs, Column boolean_mask): if isinstance(lhs, Column): @@ -589,7 +589,7 @@ def _boolean_mask_scatter_scalar(list input_scalars, list target_columns, return columns_from_unique_ptr(move(c_result)) -@with_spill_lock() +@acquire_spill_lock() def boolean_mask_scatter(list input_, list target_columns, Column boolean_mask): """Copy the target columns, replacing masked rows with input data. 
@@ -622,7 +622,7 @@ def boolean_mask_scatter(list input_, list target_columns, ) -@with_spill_lock() +@acquire_spill_lock() def shift(Column input, int offset, object fill_value=None): cdef DeviceScalar fill @@ -659,7 +659,7 @@ def shift(Column input, int offset, object fill_value=None): return Column.from_unique_ptr(move(c_output)) -@with_spill_lock() +@acquire_spill_lock() def get_element(Column input_column, size_type index): cdef column_view col_view = input_column.view() @@ -674,7 +674,7 @@ def get_element(Column input_column, size_type index): ) -@with_spill_lock() +@acquire_spill_lock() def segmented_gather(Column source_column, Column gather_map): cdef shared_ptr[lists_column_view] source_LCV = ( make_shared[lists_column_view](source_column.view()) diff --git a/python/cudf/cudf/_lib/datetime.pyx b/python/cudf/cudf/_lib/datetime.pyx index 3402e950763..81949dbaa20 100644 --- a/python/cudf/cudf/_lib/datetime.pyx +++ b/python/cudf/cudf/_lib/datetime.pyx @@ -1,6 +1,6 @@ # Copyright (c) 2020-2022, NVIDIA CORPORATION. 
-from cudf.core.buffer import with_spill_lock +from cudf.core.buffer import acquire_spill_lock from libcpp.memory cimport unique_ptr from libcpp.utility cimport move @@ -14,7 +14,7 @@ from cudf._lib.cpp.types cimport size_type from cudf._lib.scalar cimport DeviceScalar -@with_spill_lock() +@acquire_spill_lock() def add_months(Column col, Column months): # months must be int16 dtype cdef unique_ptr[column] c_result @@ -32,7 +32,7 @@ def add_months(Column col, Column months): return Column.from_unique_ptr(move(c_result)) -@with_spill_lock() +@acquire_spill_lock() def extract_datetime_component(Column col, object field): cdef unique_ptr[column] c_result @@ -103,7 +103,7 @@ cdef libcudf_datetime.rounding_frequency _get_rounding_frequency(object freq): return freq_val -@with_spill_lock() +@acquire_spill_lock() def ceil_datetime(Column col, object freq): cdef unique_ptr[column] c_result cdef column_view col_view = col.view() @@ -117,7 +117,7 @@ def ceil_datetime(Column col, object freq): return result -@with_spill_lock() +@acquire_spill_lock() def floor_datetime(Column col, object freq): cdef unique_ptr[column] c_result cdef column_view col_view = col.view() @@ -131,7 +131,7 @@ def floor_datetime(Column col, object freq): return result -@with_spill_lock() +@acquire_spill_lock() def round_datetime(Column col, object freq): cdef unique_ptr[column] c_result cdef column_view col_view = col.view() @@ -145,7 +145,7 @@ def round_datetime(Column col, object freq): return result -@with_spill_lock() +@acquire_spill_lock() def is_leap_year(Column col): """Returns a boolean indicator whether the year of the date is a leap year """ @@ -158,7 +158,7 @@ def is_leap_year(Column col): return Column.from_unique_ptr(move(c_result)) -@with_spill_lock() +@acquire_spill_lock() def date_range(DeviceScalar start, size_type n, offset): cdef unique_ptr[column] c_result cdef size_type months = ( @@ -175,7 +175,7 @@ def date_range(DeviceScalar start, size_type n, offset): return 
Column.from_unique_ptr(move(c_result)) -@with_spill_lock() +@acquire_spill_lock() def extract_quarter(Column col): """ Returns a column which contains the corresponding quarter of the year @@ -190,7 +190,7 @@ def extract_quarter(Column col): return Column.from_unique_ptr(move(c_result)) -@with_spill_lock() +@acquire_spill_lock() def days_in_month(Column col): """Extracts the number of days in the month of the date """ @@ -203,7 +203,7 @@ def days_in_month(Column col): return Column.from_unique_ptr(move(c_result)) -@with_spill_lock() +@acquire_spill_lock() def last_day_of_month(Column col): cdef unique_ptr[column] c_result cdef column_view col_view = col.view() diff --git a/python/cudf/cudf/_lib/filling.pyx b/python/cudf/cudf/_lib/filling.pyx index e7727f12bd1..63549f08cbd 100644 --- a/python/cudf/cudf/_lib/filling.pyx +++ b/python/cudf/cudf/_lib/filling.pyx @@ -1,6 +1,6 @@ # Copyright (c) 2020-2022, NVIDIA CORPORATION. -from cudf.core.buffer import with_spill_lock +from cudf.core.buffer import acquire_spill_lock from libcpp.memory cimport unique_ptr from libcpp.utility cimport move @@ -17,7 +17,7 @@ from cudf._lib.scalar cimport DeviceScalar from cudf._lib.utils cimport columns_from_unique_ptr, table_view_from_columns -@with_spill_lock() +@acquire_spill_lock() def fill_in_place(Column destination, int begin, int end, DeviceScalar value): cdef mutable_column_view c_destination = destination.mutable_view() cdef size_type c_begin = begin @@ -32,7 +32,7 @@ def fill_in_place(Column destination, int begin, int end, DeviceScalar value): ) -@with_spill_lock() +@acquire_spill_lock() def fill(Column destination, int begin, int end, DeviceScalar value): cdef column_view c_destination = destination.view() cdef size_type c_begin = begin @@ -51,7 +51,7 @@ def fill(Column destination, int begin, int end, DeviceScalar value): return Column.from_unique_ptr(move(c_result)) -@with_spill_lock() +@acquire_spill_lock() def repeat(list inp, object count): if isinstance(count, Column): 
return _repeat_via_column(inp, count) @@ -86,7 +86,7 @@ def _repeat_via_size_type(list inp, size_type count): return columns_from_unique_ptr(move(c_result)) -@with_spill_lock() +@acquire_spill_lock() def sequence(int size, DeviceScalar init, DeviceScalar step): cdef size_type c_size = size cdef const scalar* c_init = init.get_raw_ptr() diff --git a/python/cudf/cudf/_lib/groupby.pyx b/python/cudf/cudf/_lib/groupby.pyx index b1012a7619a..a8b7fef6a57 100644 --- a/python/cudf/cudf/_lib/groupby.pyx +++ b/python/cudf/cudf/_lib/groupby.pyx @@ -10,7 +10,7 @@ from cudf.api.types import ( is_string_dtype, is_struct_dtype, ) -from cudf.core.buffer import with_spill_lock +from cudf.core.buffer import acquire_spill_lock from libcpp cimport bool from libcpp.memory cimport unique_ptr @@ -94,7 +94,7 @@ cdef class GroupBy: else: c_null_handling = libcudf_types.null_policy.INCLUDE - with with_spill_lock() as spill_lock: + with acquire_spill_lock() as spill_lock: keys_view = table_view_from_columns(keys) # We spill lock the columns while this GroupBy instance is alive. self._spill_lock = spill_lock diff --git a/python/cudf/cudf/_lib/labeling.pyx b/python/cudf/cudf/_lib/labeling.pyx index 6362072b26e..2c2538ab0af 100644 --- a/python/cudf/cudf/_lib/labeling.pyx +++ b/python/cudf/cudf/_lib/labeling.pyx @@ -1,6 +1,6 @@ # Copyright (c) 2021-2022, NVIDIA CORPORATION. -from cudf.core.buffer import with_spill_lock +from cudf.core.buffer import acquire_spill_lock from libcpp cimport bool as cbool from libcpp.memory cimport unique_ptr @@ -15,7 +15,7 @@ from cudf._lib.cpp.labeling cimport inclusive, label_bins as cpp_label_bins # Note that the parameter input shadows a Python built-in in the local scope, # but I'm not too concerned about that since there's no use-case for actual # input in this context. 
-@with_spill_lock() +@acquire_spill_lock() def label_bins(Column input, Column left_edges, cbool left_inclusive, Column right_edges, cbool right_inclusive): cdef inclusive c_left_inclusive = \ diff --git a/python/cudf/cudf/_lib/null_mask.pyx b/python/cudf/cudf/_lib/null_mask.pyx index a4154db7b6e..140fa75f569 100644 --- a/python/cudf/cudf/_lib/null_mask.pyx +++ b/python/cudf/cudf/_lib/null_mask.pyx @@ -4,7 +4,7 @@ from enum import Enum from rmm._lib.device_buffer cimport DeviceBuffer, device_buffer -from cudf.core.buffer import as_buffer, with_spill_lock +from cudf.core.buffer import acquire_spill_lock, as_buffer from libcpp.memory cimport make_unique, unique_ptr from libcpp.pair cimport pair @@ -35,7 +35,7 @@ class MaskState(Enum): ALL_NULL = mask_state.ALL_NULL -@with_spill_lock() +@acquire_spill_lock() def copy_bitmask(Column col): """ Copies column's validity mask buffer into a new buffer, shifting by the diff --git a/python/cudf/cudf/_lib/partitioning.pyx b/python/cudf/cudf/_lib/partitioning.pyx index f707c6a426b..083407954b3 100644 --- a/python/cudf/cudf/_lib/partitioning.pyx +++ b/python/cudf/cudf/_lib/partitioning.pyx @@ -1,6 +1,6 @@ # Copyright (c) 2020-2022, NVIDIA CORPORATION. -from cudf.core.buffer import with_spill_lock +from cudf.core.buffer import acquire_spill_lock from libcpp.memory cimport unique_ptr from libcpp.pair cimport pair @@ -18,7 +18,7 @@ from cudf._lib.stream_compaction import distinct_count as cpp_distinct_count cimport cudf._lib.cpp.types as libcudf_types -@with_spill_lock() +@acquire_spill_lock() def partition(list source_columns, Column partition_map, object num_partitions): diff --git a/python/cudf/cudf/_lib/quantiles.pyx b/python/cudf/cudf/_lib/quantiles.pyx index 4862d1d9c34..d3a02fa7cbf 100644 --- a/python/cudf/cudf/_lib/quantiles.pyx +++ b/python/cudf/cudf/_lib/quantiles.pyx @@ -1,6 +1,6 @@ # Copyright (c) 2020-2022, NVIDIA CORPORATION. 
-from cudf.core.buffer import with_spill_lock +from cudf.core.buffer import acquire_spill_lock from libcpp cimport bool from libcpp.memory cimport unique_ptr @@ -29,7 +29,7 @@ from cudf._lib.cpp.types cimport interpolation, null_order, order, sorted from cudf._lib.utils cimport columns_from_unique_ptr, table_view_from_columns -@with_spill_lock() +@acquire_spill_lock() def quantile( Column input, object q, diff --git a/python/cudf/cudf/_lib/reduce.pyx b/python/cudf/cudf/_lib/reduce.pyx index 2f708ce42dd..f11bacd5d1e 100644 --- a/python/cudf/cudf/_lib/reduce.pyx +++ b/python/cudf/cudf/_lib/reduce.pyx @@ -3,7 +3,7 @@ from cython.operator import dereference import cudf -from cudf.core.buffer import with_spill_lock +from cudf.core.buffer import acquire_spill_lock from libcpp.memory cimport unique_ptr from libcpp.utility cimport move, pair @@ -24,7 +24,7 @@ from cudf._lib.scalar cimport DeviceScalar from cudf._lib.types cimport dtype_to_data_type, is_decimal_type_id -@with_spill_lock() +@acquire_spill_lock() def reduce(reduction_op, Column incol, dtype=None, **kwargs): """ Top level Cython reduce function wrapping libcudf reductions. @@ -81,7 +81,7 @@ def reduce(reduction_op, Column incol, dtype=None, **kwargs): return py_result.value -@with_spill_lock() +@acquire_spill_lock() def scan(scan_op, Column incol, inclusive, **kwargs): """ Top level Cython scan function wrapping libcudf scans. @@ -113,7 +113,7 @@ def scan(scan_op, Column incol, inclusive, **kwargs): return py_result -@with_spill_lock() +@acquire_spill_lock() def minmax(Column incol): """ Top level Cython minmax function wrapping libcudf minmax. 
diff --git a/python/cudf/cudf/_lib/replace.pyx b/python/cudf/cudf/_lib/replace.pyx index 44654f895d2..06e94934ef5 100644 --- a/python/cudf/cudf/_lib/replace.pyx +++ b/python/cudf/cudf/_lib/replace.pyx @@ -4,7 +4,7 @@ from libcpp.memory cimport unique_ptr from libcpp.utility cimport move from cudf.api.types import is_scalar -from cudf.core.buffer import with_spill_lock +from cudf.core.buffer import acquire_spill_lock from cudf._lib.column cimport Column @@ -211,7 +211,7 @@ def clip(Column input_col, object lo, object hi): return clamp(input_col, lo_scalar, hi_scalar) -@with_spill_lock() +@acquire_spill_lock() def normalize_nans_and_zeros_inplace(Column input_col): """ Inplace normalizing diff --git a/python/cudf/cudf/_lib/unary.pyx b/python/cudf/cudf/_lib/unary.pyx index b1f5e3bd101..7ef4d00b9ff 100644 --- a/python/cudf/cudf/_lib/unary.pyx +++ b/python/cudf/cudf/_lib/unary.pyx @@ -3,7 +3,7 @@ from enum import IntEnum from cudf.api.types import is_decimal_dtype -from cudf.core.buffer import with_spill_lock +from cudf.core.buffer import acquire_spill_lock from libcpp.memory cimport unique_ptr from libcpp.utility cimport move @@ -44,7 +44,7 @@ class UnaryOp(IntEnum): NOT = unary_operator.NOT -@with_spill_lock() +@acquire_spill_lock() def unary_operation(Column input, object op): cdef column_view c_input = input.view() cdef unary_operator c_op = ( @@ -62,7 +62,7 @@ def unary_operation(Column input, object op): return Column.from_unique_ptr(move(c_result)) -@with_spill_lock() +@acquire_spill_lock() def is_null(Column input): cdef column_view c_input = input.view() cdef unique_ptr[column] c_result @@ -73,7 +73,7 @@ def is_null(Column input): return Column.from_unique_ptr(move(c_result)) -@with_spill_lock() +@acquire_spill_lock() def is_valid(Column input): cdef column_view c_input = input.view() cdef unique_ptr[column] c_result @@ -84,7 +84,7 @@ def is_valid(Column input): return Column.from_unique_ptr(move(c_result)) -@with_spill_lock() +@acquire_spill_lock() def 
cast(Column input, object dtype=np.float64): cdef column_view c_input = input.view() cdef data_type c_dtype = dtype_to_data_type(dtype) @@ -100,7 +100,7 @@ def cast(Column input, object dtype=np.float64): return result -@with_spill_lock() +@acquire_spill_lock() def is_nan(Column input): cdef column_view c_input = input.view() cdef unique_ptr[column] c_result @@ -111,7 +111,7 @@ def is_nan(Column input): return Column.from_unique_ptr(move(c_result)) -@with_spill_lock() +@acquire_spill_lock() def is_non_nan(Column input): cdef column_view c_input = input.view() cdef unique_ptr[column] c_result diff --git a/python/cudf/cudf/core/buffer/__init__.py b/python/cudf/cudf/core/buffer/__init__.py index 044f2fa0478..49f2c57b17f 100644 --- a/python/cudf/cudf/core/buffer/__init__.py +++ b/python/cudf/cudf/core/buffer/__init__.py @@ -2,4 +2,8 @@ from cudf.core.buffer.buffer import Buffer, cuda_array_interface_wrapper from cudf.core.buffer.spillable_buffer import SpillableBuffer, SpillLock -from cudf.core.buffer.utils import as_buffer, get_spill_lock, with_spill_lock +from cudf.core.buffer.utils import ( + acquire_spill_lock, + as_buffer, + get_spill_lock, +) diff --git a/python/cudf/cudf/core/buffer/utils.py b/python/cudf/cudf/core/buffer/utils.py index 3da1d610ca1..cc11a9b9678 100644 --- a/python/cudf/cudf/core/buffer/utils.py +++ b/python/cudf/cudf/core/buffer/utils.py @@ -102,7 +102,7 @@ def _pop_thread_spill_lock() -> None: _thread_spill_locks[_id] = (spill_lock, count - 1) -class with_spill_lock(ContextDecorator): +class acquire_spill_lock(ContextDecorator): """Decorator and context to set spill lock automatically. All calls to `get_spill_lock()` within the decorated function or context @@ -118,7 +118,7 @@ def __exit__(self, *exc): def get_spill_lock() -> Union[SpillLock, None]: - """Return a spill lock within the context of `with_spill_lock` or None + """Return a spill lock within the context of `acquire_spill_lock` or None Returns None, if spilling is disabled. 
""" diff --git a/python/cudf/cudf/tests/test_spilling.py b/python/cudf/cudf/tests/test_spilling.py index 3617cd2f6ff..81097599807 100644 --- a/python/cudf/cudf/tests/test_spilling.py +++ b/python/cudf/cudf/tests/test_spilling.py @@ -19,7 +19,12 @@ import cudf.core.buffer.spill_manager import cudf.options from cudf.core.abc import Serializable -from cudf.core.buffer import Buffer, as_buffer, get_spill_lock, with_spill_lock +from cudf.core.buffer import ( + Buffer, + acquire_spill_lock, + as_buffer, + get_spill_lock, +) from cudf.core.buffer.spill_manager import ( SpillManager, get_global_manager, @@ -302,7 +307,7 @@ def test_ptr_restricted(manager: SpillManager): def test_get_spill_lock(manager: SpillManager): - @with_spill_lock() + @acquire_spill_lock() def f(sleep=False, nest=0): if sleep: time.sleep(random.random() / 100) @@ -335,7 +340,7 @@ def f(sleep=False, nest=0): def test_get_spill_lock_no_manager(): """When spilling is disabled, get_spill_lock() should return None always""" - @with_spill_lock() + @acquire_spill_lock() def f(): return get_spill_lock() From 824a7c36e48dfb35b187b95f32a699ead73e9ec1 Mon Sep 17 00:00:00 2001 From: "Mads R. B. 
Kristensen" Date: Tue, 15 Nov 2022 09:01:09 +0100 Subject: [PATCH 27/31] skip tests when CUDF_SPILL=on --- python/cudf/cudf/tests/test_cuda_array_interface.py | 8 ++++++++ python/cudf/cudf/tests/test_spilling.py | 6 ++++++ 2 files changed, 14 insertions(+) diff --git a/python/cudf/cudf/tests/test_cuda_array_interface.py b/python/cudf/cudf/tests/test_cuda_array_interface.py index 9b9709b52c3..e81f4ec795a 100644 --- a/python/cudf/cudf/tests/test_cuda_array_interface.py +++ b/python/cudf/cudf/tests/test_cuda_array_interface.py @@ -10,6 +10,7 @@ from numba import cuda import cudf +from cudf.core.buffer.spill_manager import get_global_manager from cudf.testing._utils import DATETIME_TYPES, NUMERIC_TYPES, assert_eq @@ -169,6 +170,13 @@ def test_column_from_ephemeral_cupy_try_lose_reference(): assert_eq(pd.Series([1, 2, 3]), a.to_pandas()) +@pytest.mark.xfail( + get_global_manager() is not None, + reason=( + "spilling doesn't support PyTorch, see " + "`cudf.core.buffer.spillable_buffer.DelayedPointerTuple`" + ), +) def test_cuda_array_interface_pytorch(): torch = pytest.importorskip("torch", minversion="1.6.0") if not torch.cuda.is_available(): diff --git a/python/cudf/cudf/tests/test_spilling.py b/python/cudf/cudf/tests/test_spilling.py index 81097599807..6f790600d92 100644 --- a/python/cudf/cudf/tests/test_spilling.py +++ b/python/cudf/cudf/tests/test_spilling.py @@ -38,6 +38,12 @@ ) from cudf.testing._utils import assert_eq +if get_global_manager() is not None: + pytest.skip( + "cannot test spilling when enabled globally, set `CUDF_SPILL=off`", + allow_module_level=True, + ) + def gen_df(target="gpu") -> cudf.DataFrame: ret = cudf.DataFrame({"a": [1, 2, 3]}) From f19d8fd349aaabb9c92108c784f3a4057f39e8da Mon Sep 17 00:00:00 2001 From: "Mads R. B. 
Kristensen" Date: Tue, 15 Nov 2022 10:29:55 +0100 Subject: [PATCH 28/31] pytest mark: spilling mark benchmark a good candidate to run with `CUDF_SPILL=ON` --- ci/gpu/build.sh | 4 ++-- python/cudf/cudf/tests/pytest.ini | 5 +++++ python/cudf/cudf/tests/test_binops.py | 2 ++ python/cudf/cudf/tests/test_buffer.py | 2 ++ python/cudf/cudf/tests/test_dataframe.py | 2 ++ python/cudf/cudf/tests/test_onehot.py | 2 ++ python/cudf/cudf/tests/test_pickling.py | 10 ++-------- python/cudf/cudf/tests/test_reshape.py | 2 ++ 8 files changed, 19 insertions(+), 10 deletions(-) create mode 100644 python/cudf/cudf/tests/pytest.ini diff --git a/ci/gpu/build.sh b/ci/gpu/build.sh index 825b0bd565a..9e0dd884060 100755 --- a/ci/gpu/build.sh +++ b/ci/gpu/build.sh @@ -283,8 +283,8 @@ gpuci_logger "Python py.test for cuDF" py.test -n 8 --cache-clear --basetemp="$WORKSPACE/cudf-cuda-tmp" --ignore="$WORKSPACE/python/cudf/cudf/benchmarks" --junitxml="$WORKSPACE/junit-cudf.xml" -v --cov-config="$WORKSPACE/python/cudf/.coveragerc" --cov=cudf --cov-report=xml:"$WORKSPACE/python/cudf/cudf-coverage.xml" --cov-report term --dist=loadscope tests gpuci_logger "Python py.tests for cuDF with spilling (CUDF_SPILL_DEVICE_LIMIT=1)" -# Due to time concerns, we only run a limited set of tests -CUDF_SPILL=on CUDF_SPILL_DEVICE_LIMIT=1 py.test -n 8 --cache-clear --basetemp="$WORKSPACE/cudf-cuda-tmp" --ignore="$WORKSPACE/python/cudf/cudf/benchmarks" -v --cov-config="$WORKSPACE/python/cudf/.coveragerc" --cov-append --cov=cudf --cov-report=xml:"$WORKSPACE/python/cudf/cudf-coverage.xml" --cov-report term --dist=loadscope tests/test_binops.py tests/test_dataframe.py tests/test_buffer.py tests/test_onehot.py tests/test_reshape.py tests/test_pickling.py +# Due to time concerns, we only run tests marked "spilling" +CUDF_SPILL=on CUDF_SPILL_DEVICE_LIMIT=1 py.test -n 8 --cache-clear --basetemp="$WORKSPACE/cudf-cuda-tmp" --ignore="$WORKSPACE/python/cudf/cudf/benchmarks" -v --cov-config="$WORKSPACE/python/cudf/.coveragerc" 
--cov-append --cov=cudf --cov-report=xml:"$WORKSPACE/python/cudf/cudf-coverage.xml" --cov-report term --dist=loadscope -m spilling tests cd "$WORKSPACE/python/dask_cudf" gpuci_logger "Python py.test for dask-cudf" diff --git a/python/cudf/cudf/tests/pytest.ini b/python/cudf/cudf/tests/pytest.ini new file mode 100644 index 00000000000..7adbdb72d72 --- /dev/null +++ b/python/cudf/cudf/tests/pytest.ini @@ -0,0 +1,5 @@ +# Copyright (c) 2022, NVIDIA CORPORATION. + +[pytest] +markers = + spilling: mark benchmark a good candidate to run with `CUDF_SPILL=ON` diff --git a/python/cudf/cudf/tests/test_binops.py b/python/cudf/cudf/tests/test_binops.py index 2229bcc1938..589755ce980 100644 --- a/python/cudf/cudf/tests/test_binops.py +++ b/python/cudf/cudf/tests/test_binops.py @@ -47,6 +47,8 @@ operator.ge, ] +pytestmark = pytest.mark.spilling + @pytest.mark.parametrize("obj_class", ["Series", "Index"]) @pytest.mark.parametrize("binop", _binops) diff --git a/python/cudf/cudf/tests/test_buffer.py b/python/cudf/cudf/tests/test_buffer.py index 6ff715db761..df7152d53a6 100644 --- a/python/cudf/cudf/tests/test_buffer.py +++ b/python/cudf/cudf/tests/test_buffer.py @@ -5,6 +5,8 @@ from cudf.core.buffer import Buffer, as_buffer +pytestmark = pytest.mark.spilling + arr_len = 10 diff --git a/python/cudf/cudf/tests/test_dataframe.py b/python/cudf/cudf/tests/test_dataframe.py index 58bee95326f..d0c88e3a1e7 100644 --- a/python/cudf/cudf/tests/test_dataframe.py +++ b/python/cudf/cudf/tests/test_dataframe.py @@ -39,6 +39,8 @@ gen_rand, ) +pytestmark = pytest.mark.spilling + def test_init_via_list_of_tuples(): data = [ diff --git a/python/cudf/cudf/tests/test_onehot.py b/python/cudf/cudf/tests/test_onehot.py index 3c067975566..d42b0e85d28 100644 --- a/python/cudf/cudf/tests/test_onehot.py +++ b/python/cudf/cudf/tests/test_onehot.py @@ -10,6 +10,8 @@ from cudf import DataFrame from cudf.testing import _utils as utils +pytestmark = pytest.mark.spilling + @pytest.mark.parametrize( "data, index", 
diff --git a/python/cudf/cudf/tests/test_pickling.py b/python/cudf/cudf/tests/test_pickling.py index 21343f19d79..8ce818e7a3d 100644 --- a/python/cudf/cudf/tests/test_pickling.py +++ b/python/cudf/cudf/tests/test_pickling.py @@ -1,6 +1,6 @@ # Copyright (c) 2018-2022, NVIDIA CORPORATION. -import sys +import pickle import numpy as np import pandas as pd @@ -10,13 +10,7 @@ from cudf.core.buffer import as_buffer from cudf.testing._utils import assert_eq -if sys.version_info < (3, 8): - try: - import pickle5 as pickle - except ImportError: - import pickle -else: - import pickle +pytestmark = pytest.mark.spilling def check_serialization(df): diff --git a/python/cudf/cudf/tests/test_reshape.py b/python/cudf/cudf/tests/test_reshape.py index 181bff8512a..280b619c305 100644 --- a/python/cudf/cudf/tests/test_reshape.py +++ b/python/cudf/cudf/tests/test_reshape.py @@ -16,6 +16,8 @@ assert_eq, ) +pytestmark = pytest.mark.spilling + @pytest.mark.parametrize("num_id_vars", [0, 1, 2]) @pytest.mark.parametrize("num_value_vars", [0, 1, 2]) From 0c814defd0d325e63d224537bc290f2b1c6ea7c6 Mon Sep 17 00:00:00 2001 From: "Mads R. B. Kristensen" Date: Tue, 15 Nov 2022 10:42:07 +0100 Subject: [PATCH 29/31] doc --- python/cudf/cudf/core/buffer/utils.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/python/cudf/cudf/core/buffer/utils.py b/python/cudf/cudf/core/buffer/utils.py index cc11a9b9678..71d8ce9853c 100644 --- a/python/cudf/cudf/core/buffer/utils.py +++ b/python/cudf/cudf/core/buffer/utils.py @@ -107,6 +107,12 @@ class acquire_spill_lock(ContextDecorator): All calls to `get_spill_lock()` within the decorated function or context will return a spill lock with a lifetime bound to the function or context. + + Developer Notes + --------------- + We use the global variable `_thread_spill_locks` to track the global spill + lock state. To support concurrency, each thread tracks its own state by + pushing and popping from `_thread_spill_locks` using its thread ID.
""" def __enter__(self) -> Optional[SpillLock]: From 040b11d34dad33196bb03231d3211a8a83ce6b87 Mon Sep 17 00:00:00 2001 From: "Mads R. B. Kristensen" Date: Tue, 15 Nov 2022 11:51:49 +0100 Subject: [PATCH 30/31] more acquire_spill_lock() --- python/cudf/cudf/_lib/hash.pyx | 4 ++++ python/cudf/cudf/_lib/join.pyx | 9 +++++++-- python/cudf/cudf/_lib/lists.pyx | 13 +++++++++++++ python/cudf/cudf/_lib/null_mask.pyx | 2 ++ python/cudf/cudf/_lib/nvtext/edit_distance.pyx | 4 ++++ python/cudf/cudf/_lib/nvtext/generate_ngrams.pyx | 4 ++++ python/cudf/cudf/_lib/nvtext/ngrams_tokenize.pyx | 3 +++ python/cudf/cudf/_lib/nvtext/normalize.pyx | 4 ++++ python/cudf/cudf/_lib/nvtext/replace.pyx | 4 ++++ python/cudf/cudf/_lib/nvtext/stemmer.pyx | 5 +++++ python/cudf/cudf/_lib/nvtext/subword_tokenize.pyx | 4 ++++ python/cudf/cudf/_lib/nvtext/tokenize.pyx | 8 ++++++++ python/cudf/cudf/_lib/replace.pyx | 8 ++++++++ python/cudf/cudf/_lib/reshape.pyx | 4 ++++ python/cudf/cudf/_lib/rolling.pyx | 3 +++ python/cudf/cudf/_lib/round.pyx | 3 +++ python/cudf/cudf/_lib/search.pyx | 4 ++++ python/cudf/cudf/_lib/sort.pyx | 6 ++++++ python/cudf/cudf/_lib/stream_compaction.pyx | 6 ++++++ python/cudf/cudf/_lib/strings/attributes.pyx | 5 +++++ python/cudf/cudf/_lib/strings/capitalize.pyx | 5 +++++ python/cudf/cudf/_lib/strings/case.pyx | 5 +++++ python/cudf/cudf/_lib/strings/char_types.pyx | 12 ++++++++++++ python/cudf/cudf/_lib/strings/combine.pyx | 6 ++++++ python/cudf/cudf/_lib/strings/contains.pyx | 7 +++++++ .../_lib/strings/convert/convert_fixed_point.pyx | 5 +++++ .../cudf/_lib/strings/convert/convert_floats.pyx | 3 +++ .../cudf/_lib/strings/convert/convert_integers.pyx | 3 +++ .../cudf/_lib/strings/convert/convert_lists.pyx | 3 +++ .../cudf/cudf/_lib/strings/convert/convert_urls.pyx | 4 ++++ python/cudf/cudf/_lib/strings/extract.pyx | 3 +++ python/cudf/cudf/_lib/strings/find.pyx | 10 ++++++++++ python/cudf/cudf/_lib/strings/find_multiple.pyx | 3 +++ python/cudf/cudf/_lib/strings/findall.pyx | 3 
+++ python/cudf/cudf/_lib/strings/json.pyx | 3 +++ python/cudf/cudf/_lib/strings/padding.pyx | 10 ++++++++-- python/cudf/cudf/_lib/strings/repeat.pyx | 4 ++++ python/cudf/cudf/_lib/strings/replace.pyx | 6 ++++++ python/cudf/cudf/_lib/strings/replace_re.pyx | 5 +++++ python/cudf/cudf/_lib/strings/split/partition.pyx | 4 ++++ python/cudf/cudf/_lib/strings/split/split.pyx | 10 ++++++++++ python/cudf/cudf/_lib/strings/strip.pyx | 5 +++++ python/cudf/cudf/_lib/strings/substring.pyx | 12 ++++++++---- python/cudf/cudf/_lib/strings/translate.pyx | 4 ++++ python/cudf/cudf/_lib/strings/wrap.pyx | 3 +++ python/cudf/cudf/_lib/transform.pyx | 12 ++++++++---- 46 files changed, 241 insertions(+), 12 deletions(-) diff --git a/python/cudf/cudf/_lib/hash.pyx b/python/cudf/cudf/_lib/hash.pyx index 03033cd1a7e..1264a9b2126 100644 --- a/python/cudf/cudf/_lib/hash.pyx +++ b/python/cudf/cudf/_lib/hash.pyx @@ -1,5 +1,7 @@ # Copyright (c) 2020-2022, NVIDIA CORPORATION. +from cudf.core.buffer import acquire_spill_lock + from libcpp.memory cimport unique_ptr from libcpp.pair cimport pair from libcpp.utility cimport move @@ -15,6 +17,7 @@ from cudf._lib.cpp.table.table_view cimport table_view from cudf._lib.utils cimport columns_from_unique_ptr, table_view_from_columns +@acquire_spill_lock() def hash_partition(list source_columns, object columns_to_hash, int num_partitions): cdef vector[libcudf_types.size_type] c_columns_to_hash = columns_to_hash @@ -37,6 +40,7 @@ def hash_partition(list source_columns, object columns_to_hash, ) +@acquire_spill_lock() def hash(list source_columns, str method, int seed=0): cdef table_view c_source_view = table_view_from_columns(source_columns) cdef unique_ptr[column] c_result diff --git a/python/cudf/cudf/_lib/join.pyx b/python/cudf/cudf/_lib/join.pyx index ff5f6e1afcc..da03e8dcdd1 100644 --- a/python/cudf/cudf/_lib/join.pyx +++ b/python/cudf/cudf/_lib/join.pyx @@ -1,5 +1,7 @@ # Copyright (c) 2020-2022, NVIDIA CORPORATION. 
+from cudf.core.buffer import acquire_spill_lock + from libcpp.memory cimport make_unique, unique_ptr from libcpp.pair cimport pair from libcpp.utility cimport move @@ -14,7 +16,9 @@ from cudf._lib.utils cimport table_view_from_columns # The functions below return the *gathermaps* that represent # the join result when joining on the keys `lhs` and `rhs`. -cpdef join(list lhs, list rhs, how=None): + +@acquire_spill_lock() +def join(list lhs, list rhs, how=None): cdef pair[cpp_join.gather_map_type, cpp_join.gather_map_type] c_result cdef table_view c_lhs = table_view_from_columns(lhs) cdef table_view c_rhs = table_view_from_columns(rhs) @@ -36,7 +40,8 @@ cpdef join(list lhs, list rhs, how=None): return left_rows, right_rows -cpdef semi_join(list lhs, list rhs, how=None): +@acquire_spill_lock() +def semi_join(list lhs, list rhs, how=None): # left-semi and left-anti joins cdef cpp_join.gather_map_type c_result cdef table_view c_lhs = table_view_from_columns(lhs) diff --git a/python/cudf/cudf/_lib/lists.pyx b/python/cudf/cudf/_lib/lists.pyx index 8a7b4be3be9..47e9dccc8e6 100644 --- a/python/cudf/cudf/_lib/lists.pyx +++ b/python/cudf/cudf/_lib/lists.pyx @@ -1,5 +1,7 @@ # Copyright (c) 2021-2022, NVIDIA CORPORATION. 
+from cudf.core.buffer import acquire_spill_lock + from libcpp cimport bool from libcpp.memory cimport make_shared, shared_ptr, unique_ptr from libcpp.utility cimport move @@ -35,6 +37,7 @@ from cudf._lib.scalar cimport DeviceScalar from cudf._lib.utils cimport columns_from_unique_ptr, table_view_from_columns +@acquire_spill_lock() def count_elements(Column col): # shared_ptr required because lists_column_view has no default @@ -51,6 +54,7 @@ def count_elements(Column col): return result +@acquire_spill_lock() def explode_outer( list source_columns, int explode_column_idx ): @@ -65,6 +69,7 @@ def explode_outer( return columns_from_unique_ptr(move(c_result)) +@acquire_spill_lock() def distinct(Column col, bool nulls_equal, bool nans_all_equal): """ nulls_equal == True indicates that libcudf should treat any two nulls as @@ -93,6 +98,7 @@ def distinct(Column col, bool nulls_equal, bool nans_all_equal): return Column.from_unique_ptr(move(c_result)) +@acquire_spill_lock() def sort_lists(Column col, bool ascending, str na_position): cdef shared_ptr[lists_column_view] list_view = ( make_shared[lists_column_view](col.view()) @@ -114,6 +120,7 @@ def sort_lists(Column col, bool ascending, str na_position): return Column.from_unique_ptr(move(c_result)) +@acquire_spill_lock() def extract_element_scalar(Column col, size_type index): # shared_ptr required because lists_column_view has no default # ctor @@ -130,6 +137,7 @@ def extract_element_scalar(Column col, size_type index): return result +@acquire_spill_lock() def extract_element_column(Column col, Column index): cdef shared_ptr[lists_column_view] list_view = ( make_shared[lists_column_view](col.view()) @@ -146,6 +154,7 @@ def extract_element_column(Column col, Column index): return result +@acquire_spill_lock() def contains_scalar(Column col, object py_search_key): cdef DeviceScalar search_key = py_search_key.device_value @@ -166,6 +175,7 @@ def contains_scalar(Column col, object py_search_key): return result 
+@acquire_spill_lock() def index_of_scalar(Column col, object py_search_key): cdef DeviceScalar search_key = py_search_key.device_value @@ -185,6 +195,7 @@ def index_of_scalar(Column col, object py_search_key): return Column.from_unique_ptr(move(c_result)) +@acquire_spill_lock() def index_of_column(Column col, Column search_keys): cdef column_view keys_view = search_keys.view() @@ -203,6 +214,7 @@ def index_of_column(Column col, Column search_keys): return Column.from_unique_ptr(move(c_result)) +@acquire_spill_lock() def concatenate_rows(list source_columns): cdef unique_ptr[column] c_result @@ -216,6 +228,7 @@ def concatenate_rows(list source_columns): return Column.from_unique_ptr(move(c_result)) +@acquire_spill_lock() def concatenate_list_elements(Column input_column, dropna=False): cdef concatenate_null_policy policy = ( concatenate_null_policy.IGNORE if dropna diff --git a/python/cudf/cudf/_lib/null_mask.pyx b/python/cudf/cudf/_lib/null_mask.pyx index 140fa75f569..c41ae98b9bd 100644 --- a/python/cudf/cudf/_lib/null_mask.pyx +++ b/python/cudf/cudf/_lib/null_mask.pyx @@ -103,6 +103,7 @@ def create_null_mask(size_type size, state=MaskState.UNINITIALIZED): return buf +@acquire_spill_lock() def bitmask_and(columns: list): cdef table_view c_view = table_view_from_columns(columns) cdef pair[device_buffer, size_type] c_result @@ -115,6 +116,7 @@ def bitmask_and(columns: list): return buf, c_result.second +@acquire_spill_lock() def bitmask_or(columns: list): cdef table_view c_view = table_view_from_columns(columns) cdef pair[device_buffer, size_type] c_result diff --git a/python/cudf/cudf/_lib/nvtext/edit_distance.pyx b/python/cudf/cudf/_lib/nvtext/edit_distance.pyx index c8dc6edd6e2..984c8e84d7c 100644 --- a/python/cudf/cudf/_lib/nvtext/edit_distance.pyx +++ b/python/cudf/cudf/_lib/nvtext/edit_distance.pyx @@ -1,5 +1,7 @@ # Copyright (c) 2020-2022, NVIDIA CORPORATION. 
+from cudf.core.buffer import acquire_spill_lock + from libcpp.memory cimport unique_ptr from libcpp.utility cimport move @@ -12,6 +14,7 @@ from cudf._lib.cpp.nvtext.edit_distance cimport ( ) +@acquire_spill_lock() def edit_distance(Column strings, Column targets): cdef column_view c_strings = strings.view() cdef column_view c_targets = targets.view() @@ -23,6 +26,7 @@ def edit_distance(Column strings, Column targets): return Column.from_unique_ptr(move(c_result)) +@acquire_spill_lock() def edit_distance_matrix(Column strings): cdef column_view c_strings = strings.view() cdef unique_ptr[column] c_result diff --git a/python/cudf/cudf/_lib/nvtext/generate_ngrams.pyx b/python/cudf/cudf/_lib/nvtext/generate_ngrams.pyx index 5fcec570dcb..c2c32314f49 100644 --- a/python/cudf/cudf/_lib/nvtext/generate_ngrams.pyx +++ b/python/cudf/cudf/_lib/nvtext/generate_ngrams.pyx @@ -1,5 +1,7 @@ # Copyright (c) 2018-2020, NVIDIA CORPORATION. +from cudf.core.buffer import acquire_spill_lock + from libcpp.memory cimport unique_ptr from libcpp.utility cimport move @@ -15,6 +17,7 @@ from cudf._lib.cpp.types cimport size_type from cudf._lib.scalar cimport DeviceScalar +@acquire_spill_lock() def generate_ngrams(Column strings, int ngrams, object py_separator): cdef DeviceScalar separator = py_separator.device_value @@ -37,6 +40,7 @@ def generate_ngrams(Column strings, int ngrams, object py_separator): return Column.from_unique_ptr(move(c_result)) +@acquire_spill_lock() def generate_character_ngrams(Column strings, int ngrams): cdef column_view c_strings = strings.view() cdef size_type c_ngrams = ngrams diff --git a/python/cudf/cudf/_lib/nvtext/ngrams_tokenize.pyx b/python/cudf/cudf/_lib/nvtext/ngrams_tokenize.pyx index 1e9e0e39ff1..104741f2ee8 100644 --- a/python/cudf/cudf/_lib/nvtext/ngrams_tokenize.pyx +++ b/python/cudf/cudf/_lib/nvtext/ngrams_tokenize.pyx @@ -1,5 +1,7 @@ # Copyright (c) 2018-2020, NVIDIA CORPORATION. 
+from cudf.core.buffer import acquire_spill_lock + from libcpp.memory cimport unique_ptr from libcpp.utility cimport move @@ -14,6 +16,7 @@ from cudf._lib.cpp.types cimport size_type from cudf._lib.scalar cimport DeviceScalar +@acquire_spill_lock() def ngrams_tokenize( Column strings, int ngrams, diff --git a/python/cudf/cudf/_lib/nvtext/normalize.pyx b/python/cudf/cudf/_lib/nvtext/normalize.pyx index e475f0cd996..fa86e580aca 100644 --- a/python/cudf/cudf/_lib/nvtext/normalize.pyx +++ b/python/cudf/cudf/_lib/nvtext/normalize.pyx @@ -1,5 +1,7 @@ # Copyright (c) 2018-2020, NVIDIA CORPORATION. +from cudf.core.buffer import acquire_spill_lock + from libcpp cimport bool from libcpp.memory cimport unique_ptr from libcpp.utility cimport move @@ -13,6 +15,7 @@ from cudf._lib.cpp.nvtext.normalize cimport ( ) +@acquire_spill_lock() def normalize_spaces(Column strings): cdef column_view c_strings = strings.view() cdef unique_ptr[column] c_result @@ -23,6 +26,7 @@ def normalize_spaces(Column strings): return Column.from_unique_ptr(move(c_result)) +@acquire_spill_lock() def normalize_characters(Column strings, bool do_lower=True): cdef column_view c_strings = strings.view() cdef unique_ptr[column] c_result diff --git a/python/cudf/cudf/_lib/nvtext/replace.pyx b/python/cudf/cudf/_lib/nvtext/replace.pyx index b4f37ac3ec7..535816a6066 100644 --- a/python/cudf/cudf/_lib/nvtext/replace.pyx +++ b/python/cudf/cudf/_lib/nvtext/replace.pyx @@ -1,5 +1,7 @@ # Copyright (c) 2020, NVIDIA CORPORATION. 
+from cudf.core.buffer import acquire_spill_lock + from libcpp.memory cimport unique_ptr from libcpp.utility cimport move @@ -15,6 +17,7 @@ from cudf._lib.cpp.types cimport size_type from cudf._lib.scalar cimport DeviceScalar +@acquire_spill_lock() def replace_tokens(Column strings, Column targets, Column replacements, @@ -49,6 +52,7 @@ def replace_tokens(Column strings, return Column.from_unique_ptr(move(c_result)) +@acquire_spill_lock() def filter_tokens(Column strings, size_type min_token_length, object py_replacement, diff --git a/python/cudf/cudf/_lib/nvtext/stemmer.pyx b/python/cudf/cudf/_lib/nvtext/stemmer.pyx index 89d4b07b7ad..c8a93f8e67d 100644 --- a/python/cudf/cudf/_lib/nvtext/stemmer.pyx +++ b/python/cudf/cudf/_lib/nvtext/stemmer.pyx @@ -1,5 +1,7 @@ # Copyright (c) 2020, NVIDIA CORPORATION. +from cudf.core.buffer import acquire_spill_lock + from libcpp.memory cimport unique_ptr from libcpp.utility cimport move @@ -22,6 +24,7 @@ class LetterType(IntEnum): VOWEL = letter_type.VOWEL +@acquire_spill_lock() def porter_stemmer_measure(Column strings): cdef column_view c_strings = strings.view() cdef unique_ptr[column] c_result @@ -32,6 +35,7 @@ def porter_stemmer_measure(Column strings): return Column.from_unique_ptr(move(c_result)) +@acquire_spill_lock() def is_letter(Column strings, object ltype, size_type index): @@ -47,6 +51,7 @@ def is_letter(Column strings, return Column.from_unique_ptr(move(c_result)) +@acquire_spill_lock() def is_letter_multi(Column strings, object ltype, Column indices): diff --git a/python/cudf/cudf/_lib/nvtext/subword_tokenize.pyx b/python/cudf/cudf/_lib/nvtext/subword_tokenize.pyx index dd8bbd6d7b6..dbd23d91cc5 100644 --- a/python/cudf/cudf/_lib/nvtext/subword_tokenize.pyx +++ b/python/cudf/cudf/_lib/nvtext/subword_tokenize.pyx @@ -1,6 +1,9 @@ # Copyright (c) 2020-2022, NVIDIA CORPORATION. 
from libc.stdint cimport uint32_t + +from cudf.core.buffer import acquire_spill_lock + from libcpp cimport bool from libcpp.memory cimport unique_ptr from libcpp.string cimport string @@ -26,6 +29,7 @@ cdef class Hashed_Vocabulary: self.c_obj = move(cpp_load_vocabulary_file(c_hash_file)) +@acquire_spill_lock() def subword_tokenize_inmem_hash( Column strings, Hashed_Vocabulary hashed_vocabulary, diff --git a/python/cudf/cudf/_lib/nvtext/tokenize.pyx b/python/cudf/cudf/_lib/nvtext/tokenize.pyx index 00f63b9cf7c..2bb4fa8e108 100644 --- a/python/cudf/cudf/_lib/nvtext/tokenize.pyx +++ b/python/cudf/cudf/_lib/nvtext/tokenize.pyx @@ -1,5 +1,7 @@ # Copyright (c) 2018-2022, NVIDIA CORPORATION. +from cudf.core.buffer import acquire_spill_lock + from libcpp.memory cimport unique_ptr from libcpp.utility cimport move @@ -16,6 +18,7 @@ from cudf._lib.cpp.scalar.scalar cimport string_scalar from cudf._lib.scalar cimport DeviceScalar +@acquire_spill_lock() def _tokenize_scalar(Column strings, object py_delimiter): cdef DeviceScalar delimiter = py_delimiter.device_value @@ -36,6 +39,7 @@ def _tokenize_scalar(Column strings, object py_delimiter): return Column.from_unique_ptr(move(c_result)) +@acquire_spill_lock() def _tokenize_column(Column strings, Column delimiters): cdef column_view c_strings = strings.view() cdef column_view c_delimiters = delimiters.view() @@ -52,6 +56,7 @@ def _tokenize_column(Column strings, Column delimiters): return Column.from_unique_ptr(move(c_result)) +@acquire_spill_lock() def _count_tokens_scalar(Column strings, object py_delimiter): cdef DeviceScalar delimiter = py_delimiter.device_value @@ -72,6 +77,7 @@ def _count_tokens_scalar(Column strings, object py_delimiter): return Column.from_unique_ptr(move(c_result)) +@acquire_spill_lock() def _count_tokens_column(Column strings, Column delimiters): cdef column_view c_strings = strings.view() cdef column_view c_delimiters = delimiters.view() @@ -88,6 +94,7 @@ def _count_tokens_column(Column strings, 
Column delimiters): return Column.from_unique_ptr(move(c_result)) +@acquire_spill_lock() def character_tokenize(Column strings): cdef column_view c_strings = strings.view() cdef unique_ptr[column] c_result @@ -99,6 +106,7 @@ def character_tokenize(Column strings): return Column.from_unique_ptr(move(c_result)) +@acquire_spill_lock() def detokenize(Column strings, Column indices, object py_separator): cdef DeviceScalar separator = py_separator.device_value diff --git a/python/cudf/cudf/_lib/replace.pyx b/python/cudf/cudf/_lib/replace.pyx index 06e94934ef5..c763a86d6e5 100644 --- a/python/cudf/cudf/_lib/replace.pyx +++ b/python/cudf/cudf/_lib/replace.pyx @@ -23,6 +23,7 @@ from cudf._lib.cpp.scalar.scalar cimport scalar from cudf._lib.scalar cimport DeviceScalar +@acquire_spill_lock() def replace(Column input_col, Column values_to_replace, Column replacement_values): """ @@ -49,6 +50,7 @@ def replace(Column input_col, Column values_to_replace, return Column.from_unique_ptr(move(c_result)) +@acquire_spill_lock() def replace_nulls_column(Column input_col, Column replacement_values): """ Replaces null values in input_col with corresponding values from @@ -71,6 +73,7 @@ def replace_nulls_column(Column input_col, Column replacement_values): return Column.from_unique_ptr(move(c_result)) +@acquire_spill_lock() def replace_nulls_scalar(Column input_col, DeviceScalar replacement_value): """ Replaces null values in input_col with replacement_value @@ -93,6 +96,7 @@ def replace_nulls_scalar(Column input_col, DeviceScalar replacement_value): return Column.from_unique_ptr(move(c_result)) +@acquire_spill_lock() def replace_nulls_fill(Column input_col, object method): """ Replaces null values in input_col with replacement_value @@ -146,6 +150,7 @@ def replace_nulls( return replace_nulls_column(input_col, replacement) +@acquire_spill_lock() def clamp(Column input_col, DeviceScalar lo, DeviceScalar lo_replace, DeviceScalar hi, DeviceScalar hi_replace): """ @@ -176,6 +181,7 @@ def 
clamp(Column input_col, DeviceScalar lo, DeviceScalar lo_replace, return Column.from_unique_ptr(move(c_result)) +@acquire_spill_lock() def clamp(Column input_col, DeviceScalar lo, DeviceScalar hi): """ Clip the input_col such that values < lo will be replaced by lo @@ -199,6 +205,7 @@ def clamp(Column input_col, DeviceScalar lo, DeviceScalar hi): return Column.from_unique_ptr(move(c_result)) +@acquire_spill_lock() def clip(Column input_col, object lo, object hi): """ Clip the input_col such that values < lo will be replaced by lo @@ -222,6 +229,7 @@ def normalize_nans_and_zeros_inplace(Column input_col): cpp_normalize_nans_and_zeros(input_col_view) +@acquire_spill_lock() def normalize_nans_and_zeros_column(Column input_col): """ Returns a new normalized Column diff --git a/python/cudf/cudf/_lib/reshape.pyx b/python/cudf/cudf/_lib/reshape.pyx index 84bad039199..c237b7b1389 100644 --- a/python/cudf/cudf/_lib/reshape.pyx +++ b/python/cudf/cudf/_lib/reshape.pyx @@ -1,5 +1,7 @@ # Copyright (c) 2019-2022, NVIDIA CORPORATION. 
+from cudf.core.buffer import acquire_spill_lock + from libcpp.memory cimport unique_ptr from libcpp.utility cimport move @@ -15,6 +17,7 @@ from cudf._lib.cpp.types cimport size_type from cudf._lib.utils cimport columns_from_unique_ptr, table_view_from_columns +@acquire_spill_lock() def interleave_columns(list source_columns): cdef table_view c_view = table_view_from_columns(source_columns) cdef unique_ptr[column] c_result @@ -25,6 +28,7 @@ def interleave_columns(list source_columns): return Column.from_unique_ptr(move(c_result)) +@acquire_spill_lock() def tile(list source_columns, size_type count): cdef size_type c_count = count cdef table_view c_view = table_view_from_columns(source_columns) diff --git a/python/cudf/cudf/_lib/rolling.pyx b/python/cudf/cudf/_lib/rolling.pyx index 7b0da6957a0..8c4751e3084 100644 --- a/python/cudf/cudf/_lib/rolling.pyx +++ b/python/cudf/cudf/_lib/rolling.pyx @@ -1,5 +1,7 @@ # Copyright (c) 2020-2022, NVIDIA CORPORATION. +from cudf.core.buffer import acquire_spill_lock + from libcpp.memory cimport unique_ptr from libcpp.utility cimport move @@ -11,6 +13,7 @@ from cudf._lib.cpp.rolling cimport rolling_window as cpp_rolling_window from cudf._lib.cpp.types cimport size_type +@acquire_spill_lock() def rolling(Column source_column, Column pre_column_window, Column fwd_column_window, diff --git a/python/cudf/cudf/_lib/round.pyx b/python/cudf/cudf/_lib/round.pyx index c5c565561a9..b62b5a4bb34 100644 --- a/python/cudf/cudf/_lib/round.pyx +++ b/python/cudf/cudf/_lib/round.pyx @@ -1,5 +1,7 @@ # Copyright (c) 2021, NVIDIA CORPORATION. 
+from cudf.core.buffer import acquire_spill_lock + from libcpp.memory cimport unique_ptr from libcpp.utility cimport move @@ -12,6 +14,7 @@ from cudf._lib.cpp.round cimport ( ) +@acquire_spill_lock() def round(Column input_col, int decimal_places=0, how="half_even"): """ Round column values to the given number of decimal places diff --git a/python/cudf/cudf/_lib/search.pyx b/python/cudf/cudf/_lib/search.pyx index b8abe3d0dab..fef3a08c6d7 100644 --- a/python/cudf/cudf/_lib/search.pyx +++ b/python/cudf/cudf/_lib/search.pyx @@ -1,5 +1,7 @@ # Copyright (c) 2020-2022, NVIDIA CORPORATION. +from cudf.core.buffer import acquire_spill_lock + from libcpp.memory cimport unique_ptr from libcpp.utility cimport move from libcpp.vector cimport vector @@ -13,6 +15,7 @@ from cudf._lib.cpp.table.table_view cimport table_view from cudf._lib.utils cimport table_view_from_columns +@acquire_spill_lock() def search_sorted( list source, list values, side, ascending=True, na_position="last" ): @@ -73,6 +76,7 @@ def search_sorted( return Column.from_unique_ptr(move(c_result)) +@acquire_spill_lock() def contains(Column haystack, Column needles): """Check whether column contains multiple values diff --git a/python/cudf/cudf/_lib/sort.pyx b/python/cudf/cudf/_lib/sort.pyx index eb3aed80700..3b96cc618dd 100644 --- a/python/cudf/cudf/_lib/sort.pyx +++ b/python/cudf/cudf/_lib/sort.pyx @@ -1,5 +1,7 @@ # Copyright (c) 2020-2022, NVIDIA CORPORATION. 
+from cudf.core.buffer import acquire_spill_lock + from libcpp cimport bool from libcpp.memory cimport unique_ptr from libcpp.utility cimport move @@ -23,6 +25,7 @@ from cudf._lib.cpp.types cimport null_order, null_policy, order from cudf._lib.utils cimport table_view_from_columns +@acquire_spill_lock() def is_sorted( list source_columns, object ascending=None, object null_position=None ): @@ -98,6 +101,7 @@ def is_sorted( return c_result +@acquire_spill_lock() def order_by(list columns_from_table, object ascending, str na_position): """ Get index to sort the table in ascending/descending order. @@ -139,6 +143,7 @@ def order_by(list columns_from_table, object ascending, str na_position): return Column.from_unique_ptr(move(c_result)) +@acquire_spill_lock() def digitize(list source_columns, list bins, bool right=False): """ Return the indices of the bins to which each value in source_table belongs. @@ -189,6 +194,7 @@ def digitize(list source_columns, list bins, bool right=False): return Column.from_unique_ptr(move(c_result)) +@acquire_spill_lock() def rank_columns(list source_columns, object method, str na_option, bool ascending, bool pct ): diff --git a/python/cudf/cudf/_lib/stream_compaction.pyx b/python/cudf/cudf/_lib/stream_compaction.pyx index 38cead87e76..143999e52ef 100644 --- a/python/cudf/cudf/_lib/stream_compaction.pyx +++ b/python/cudf/cudf/_lib/stream_compaction.pyx @@ -1,5 +1,7 @@ # Copyright (c) 2020-2022, NVIDIA CORPORATION. +from cudf.core.buffer import acquire_spill_lock + from libcpp cimport bool from libcpp.memory cimport unique_ptr from libcpp.utility cimport move @@ -28,6 +30,7 @@ from cudf._lib.cpp.types cimport ( from cudf._lib.utils cimport columns_from_unique_ptr, table_view_from_columns +@acquire_spill_lock() def drop_nulls(list columns, how="any", keys=None, thresh=None): """ Drops null rows from cols depending on key columns. 
@@ -71,6 +74,7 @@ def drop_nulls(list columns, how="any", keys=None, thresh=None): return columns_from_unique_ptr(move(c_result)) +@acquire_spill_lock() def apply_boolean_mask(list columns, Column boolean_mask): """ Drops the rows which correspond to False in boolean_mask. @@ -100,6 +104,7 @@ def apply_boolean_mask(list columns, Column boolean_mask): return columns_from_unique_ptr(move(c_result)) +@acquire_spill_lock() def drop_duplicates(list columns, object keys=None, object keep='first', @@ -184,6 +189,7 @@ def drop_duplicates(list columns, return columns_from_unique_ptr(move(c_result)) +@acquire_spill_lock() def distinct_count(Column source_column, ignore_nulls=True, nan_as_null=False): """ Finds number of unique rows in `source_column` diff --git a/python/cudf/cudf/_lib/strings/attributes.pyx b/python/cudf/cudf/_lib/strings/attributes.pyx index 8720fad7455..4add4aa8e8c 100644 --- a/python/cudf/cudf/_lib/strings/attributes.pyx +++ b/python/cudf/cudf/_lib/strings/attributes.pyx @@ -1,5 +1,7 @@ # Copyright (c) 2018-2020, NVIDIA CORPORATION. 
+from cudf.core.buffer import acquire_spill_lock + from libcpp.memory cimport unique_ptr from libcpp.utility cimport move @@ -13,6 +15,7 @@ from cudf._lib.cpp.strings.attributes cimport ( ) +@acquire_spill_lock() def count_characters(Column source_strings): """ Returns an integer numeric column containing the @@ -27,6 +30,7 @@ def count_characters(Column source_strings): return Column.from_unique_ptr(move(c_result)) +@acquire_spill_lock() def count_bytes(Column source_strings): """ Returns an integer numeric column containing the @@ -41,6 +45,7 @@ def count_bytes(Column source_strings): return Column.from_unique_ptr(move(c_result)) +@acquire_spill_lock() def code_points(Column source_strings): """ Creates a numeric column with code point values (integers) diff --git a/python/cudf/cudf/_lib/strings/capitalize.pyx b/python/cudf/cudf/_lib/strings/capitalize.pyx index 0bbdfa462e2..cfb0feee26c 100644 --- a/python/cudf/cudf/_lib/strings/capitalize.pyx +++ b/python/cudf/cudf/_lib/strings/capitalize.pyx @@ -1,5 +1,7 @@ # Copyright (c) 2018-2021, NVIDIA CORPORATION. 
+from cudf.core.buffer import acquire_spill_lock + from libcpp.memory cimport unique_ptr from libcpp.utility cimport move @@ -13,6 +15,7 @@ from cudf._lib.cpp.strings.capitalize cimport ( ) +@acquire_spill_lock() def capitalize(Column source_strings): cdef unique_ptr[column] c_result cdef column_view source_view = source_strings.view() @@ -23,6 +26,7 @@ def capitalize(Column source_strings): return Column.from_unique_ptr(move(c_result)) +@acquire_spill_lock() def title(Column source_strings): cdef unique_ptr[column] c_result cdef column_view source_view = source_strings.view() @@ -33,6 +37,7 @@ def title(Column source_strings): return Column.from_unique_ptr(move(c_result)) +@acquire_spill_lock() def is_title(Column source_strings): cdef unique_ptr[column] c_result cdef column_view source_view = source_strings.view() diff --git a/python/cudf/cudf/_lib/strings/case.pyx b/python/cudf/cudf/_lib/strings/case.pyx index 13679f3fb02..fbf328f9f9f 100644 --- a/python/cudf/cudf/_lib/strings/case.pyx +++ b/python/cudf/cudf/_lib/strings/case.pyx @@ -1,5 +1,7 @@ # Copyright (c) 2018-2020, NVIDIA CORPORATION. 
+from cudf.core.buffer import acquire_spill_lock + from libcpp.memory cimport unique_ptr from libcpp.utility cimport move @@ -13,6 +15,7 @@ from cudf._lib.cpp.strings.case cimport ( ) +@acquire_spill_lock() def to_upper(Column source_strings): cdef unique_ptr[column] c_result cdef column_view source_view = source_strings.view() @@ -23,6 +26,7 @@ def to_upper(Column source_strings): return Column.from_unique_ptr(move(c_result)) +@acquire_spill_lock() def to_lower(Column source_strings): cdef unique_ptr[column] c_result cdef column_view source_view = source_strings.view() @@ -33,6 +37,7 @@ def to_lower(Column source_strings): return Column.from_unique_ptr(move(c_result)) +@acquire_spill_lock() def swapcase(Column source_strings): cdef unique_ptr[column] c_result cdef column_view source_view = source_strings.view() diff --git a/python/cudf/cudf/_lib/strings/char_types.pyx b/python/cudf/cudf/_lib/strings/char_types.pyx index 3ef9db2345d..25294d0d626 100644 --- a/python/cudf/cudf/_lib/strings/char_types.pyx +++ b/python/cudf/cudf/_lib/strings/char_types.pyx @@ -1,9 +1,12 @@ # Copyright (c) 2021, NVIDIA CORPORATION. + from libcpp cimport bool from libcpp.memory cimport unique_ptr from libcpp.utility cimport move +from cudf.core.buffer import acquire_spill_lock + from cudf._lib.column cimport Column from cudf._lib.cpp.column.column cimport column from cudf._lib.cpp.column.column_view cimport column_view @@ -16,6 +19,7 @@ from cudf._lib.cpp.strings.char_types cimport ( from cudf._lib.scalar cimport DeviceScalar +@acquire_spill_lock() def filter_alphanum(Column source_strings, object py_repl, bool keep=True): """ Returns a Column of strings keeping only alphanumeric character types. 
@@ -42,6 +46,7 @@ def filter_alphanum(Column source_strings, object py_repl, bool keep=True): return Column.from_unique_ptr(move(c_result)) +@acquire_spill_lock() def is_decimal(Column source_strings): """ Returns a Column of boolean values with True for `source_strings` @@ -61,6 +66,7 @@ def is_decimal(Column source_strings): return Column.from_unique_ptr(move(c_result)) +@acquire_spill_lock() def is_alnum(Column source_strings): """ Returns a Column of boolean values with True for `source_strings` @@ -81,6 +87,7 @@ def is_alnum(Column source_strings): return Column.from_unique_ptr(move(c_result)) +@acquire_spill_lock() def is_alpha(Column source_strings): """ Returns a Column of boolean values with True for `source_strings` @@ -99,6 +106,7 @@ def is_alpha(Column source_strings): return Column.from_unique_ptr(move(c_result)) +@acquire_spill_lock() def is_digit(Column source_strings): """ Returns a Column of boolean values with True for `source_strings` @@ -117,6 +125,7 @@ def is_digit(Column source_strings): return Column.from_unique_ptr(move(c_result)) +@acquire_spill_lock() def is_numeric(Column source_strings): """ Returns a Column of boolean values with True for `source_strings` @@ -136,6 +145,7 @@ def is_numeric(Column source_strings): return Column.from_unique_ptr(move(c_result)) +@acquire_spill_lock() def is_upper(Column source_strings): """ Returns a Column of boolean values with True for `source_strings` @@ -154,6 +164,7 @@ def is_upper(Column source_strings): return Column.from_unique_ptr(move(c_result)) +@acquire_spill_lock() def is_lower(Column source_strings): """ Returns a Column of boolean values with True for `source_strings` @@ -172,6 +183,7 @@ def is_lower(Column source_strings): return Column.from_unique_ptr(move(c_result)) +@acquire_spill_lock() def is_space(Column source_strings): """ Returns a Column of boolean values with True for `source_strings` diff --git a/python/cudf/cudf/_lib/strings/combine.pyx 
b/python/cudf/cudf/_lib/strings/combine.pyx index 141732b4c75..f38f4c5f847 100644 --- a/python/cudf/cudf/_lib/strings/combine.pyx +++ b/python/cudf/cudf/_lib/strings/combine.pyx @@ -1,5 +1,7 @@ # Copyright (c) 2020-2022, NVIDIA CORPORATION. +from cudf.core.buffer import acquire_spill_lock + from libcpp.memory cimport unique_ptr from libcpp.utility cimport move @@ -19,6 +21,7 @@ from cudf._lib.scalar cimport DeviceScalar from cudf._lib.utils cimport table_view_from_columns +@acquire_spill_lock() def concatenate(list source_strings, object sep, object na_rep): @@ -49,6 +52,7 @@ def concatenate(list source_strings, return Column.from_unique_ptr(move(c_result)) +@acquire_spill_lock() def join(Column source_strings, object sep, object na_rep): @@ -80,6 +84,7 @@ def join(Column source_strings, return Column.from_unique_ptr(move(c_result)) +@acquire_spill_lock() def join_lists_with_scalar( Column source_strings, object py_separator, @@ -115,6 +120,7 @@ def join_lists_with_scalar( return Column.from_unique_ptr(move(c_result)) +@acquire_spill_lock() def join_lists_with_column( Column source_strings, Column separator_strings, diff --git a/python/cudf/cudf/_lib/strings/contains.pyx b/python/cudf/cudf/_lib/strings/contains.pyx index 41c4b54d8b1..7ca93b83921 100644 --- a/python/cudf/cudf/_lib/strings/contains.pyx +++ b/python/cudf/cudf/_lib/strings/contains.pyx @@ -1,6 +1,9 @@ # Copyright (c) 2020-2022, NVIDIA CORPORATION. 
from libc.stdint cimport uint32_t + +from cudf.core.buffer import acquire_spill_lock + from libcpp.memory cimport unique_ptr from libcpp.string cimport string from libcpp.utility cimport move @@ -19,6 +22,7 @@ from cudf._lib.cpp.strings.regex_flags cimport regex_flags from cudf._lib.scalar cimport DeviceScalar +@acquire_spill_lock() def contains_re(Column source_strings, object reg_ex, uint32_t flags): """ Returns a Column of boolean values with True for `source_strings` @@ -40,6 +44,7 @@ def contains_re(Column source_strings, object reg_ex, uint32_t flags): return Column.from_unique_ptr(move(c_result)) +@acquire_spill_lock() def count_re(Column source_strings, object reg_ex, uint32_t flags): """ Returns a Column with count of occurrences of `reg_ex` in @@ -61,6 +66,7 @@ def count_re(Column source_strings, object reg_ex, uint32_t flags): return Column.from_unique_ptr(move(c_result)) +@acquire_spill_lock() def match_re(Column source_strings, object reg_ex, uint32_t flags): """ Returns a Column with each value True if the string matches `reg_ex` @@ -82,6 +88,7 @@ def match_re(Column source_strings, object reg_ex, uint32_t flags): return Column.from_unique_ptr(move(c_result)) +@acquire_spill_lock() def like(Column source_strings, object py_pattern, object py_escape): """ Returns a Column with each value True if the string matches the diff --git a/python/cudf/cudf/_lib/strings/convert/convert_fixed_point.pyx b/python/cudf/cudf/_lib/strings/convert/convert_fixed_point.pyx index fc07cf6462a..177cbffddb0 100644 --- a/python/cudf/cudf/_lib/strings/convert/convert_fixed_point.pyx +++ b/python/cudf/cudf/_lib/strings/convert/convert_fixed_point.pyx @@ -5,6 +5,8 @@ import cudf from libcpp.memory cimport unique_ptr from libcpp.utility cimport move +from cudf.core.buffer import acquire_spill_lock + from cudf._lib.column cimport Column from cudf._lib.cpp.column.column cimport column from cudf._lib.cpp.column.column_view cimport column_view @@ -16,6 +18,7 @@ from 
cudf._lib.cpp.strings.convert.convert_fixed_point cimport ( from cudf._lib.cpp.types cimport DECIMAL32, DECIMAL64, DECIMAL128, data_type +@acquire_spill_lock() def from_decimal(Column input_col): """ Converts a `Decimal64Column` to a `StringColumn`. @@ -38,6 +41,7 @@ def from_decimal(Column input_col): return Column.from_unique_ptr(move(c_result)) +@acquire_spill_lock() def to_decimal(Column input_col, object out_type): """ Returns a `Decimal64Column` from the provided `StringColumn` @@ -75,6 +79,7 @@ def to_decimal(Column input_col, object out_type): return result +@acquire_spill_lock() def is_fixed_point(Column input_col, object dtype): """ Returns a Column of boolean values with True for `input_col` diff --git a/python/cudf/cudf/_lib/strings/convert/convert_floats.pyx b/python/cudf/cudf/_lib/strings/convert/convert_floats.pyx index f9d028c5eb5..d1617d85593 100644 --- a/python/cudf/cudf/_lib/strings/convert/convert_floats.pyx +++ b/python/cudf/cudf/_lib/strings/convert/convert_floats.pyx @@ -3,6 +3,8 @@ from libcpp.memory cimport unique_ptr from libcpp.utility cimport move +from cudf.core.buffer import acquire_spill_lock + from cudf._lib.column cimport Column from cudf._lib.cpp.column.column cimport column from cudf._lib.cpp.column.column_view cimport column_view @@ -11,6 +13,7 @@ from cudf._lib.cpp.strings.convert.convert_floats cimport ( ) +@acquire_spill_lock() def is_float(Column source_strings): """ Returns a Column of boolean values with True for `source_strings` diff --git a/python/cudf/cudf/_lib/strings/convert/convert_integers.pyx b/python/cudf/cudf/_lib/strings/convert/convert_integers.pyx index 220cbd0f760..dc560c42182 100644 --- a/python/cudf/cudf/_lib/strings/convert/convert_integers.pyx +++ b/python/cudf/cudf/_lib/strings/convert/convert_integers.pyx @@ -3,6 +3,8 @@ from libcpp.memory cimport unique_ptr from libcpp.utility cimport move +from cudf.core.buffer import acquire_spill_lock + from cudf._lib.column cimport Column from 
cudf._lib.cpp.column.column cimport column from cudf._lib.cpp.column.column_view cimport column_view @@ -11,6 +13,7 @@ from cudf._lib.cpp.strings.convert.convert_integers cimport ( ) +@acquire_spill_lock() def is_integer(Column source_strings): """ Returns a Column of boolean values with True for `source_strings` diff --git a/python/cudf/cudf/_lib/strings/convert/convert_lists.pyx b/python/cudf/cudf/_lib/strings/convert/convert_lists.pyx index 7ffa69cd680..61869411183 100644 --- a/python/cudf/cudf/_lib/strings/convert/convert_lists.pyx +++ b/python/cudf/cudf/_lib/strings/convert/convert_lists.pyx @@ -3,6 +3,8 @@ from libcpp.memory cimport unique_ptr from libcpp.utility cimport move +from cudf.core.buffer import acquire_spill_lock + from cudf._lib.column cimport Column from cudf._lib.cpp.column.column cimport column from cudf._lib.cpp.column.column_view cimport column_view @@ -16,6 +18,7 @@ from cudf._lib.scalar import as_device_scalar from cudf._lib.scalar cimport DeviceScalar +@acquire_spill_lock() def format_list_column(Column source_list, Column separators): """ Format a list column of strings into a strings column. diff --git a/python/cudf/cudf/_lib/strings/convert/convert_urls.pyx b/python/cudf/cudf/_lib/strings/convert/convert_urls.pyx index 8d673de12b8..bc8123281f0 100644 --- a/python/cudf/cudf/_lib/strings/convert/convert_urls.pyx +++ b/python/cudf/cudf/_lib/strings/convert/convert_urls.pyx @@ -3,6 +3,8 @@ from libcpp.memory cimport unique_ptr from libcpp.utility cimport move +from cudf.core.buffer import acquire_spill_lock + from cudf._lib.column cimport Column from cudf._lib.cpp.column.column cimport column from cudf._lib.cpp.column.column_view cimport column_view @@ -12,6 +14,7 @@ from cudf._lib.cpp.strings.convert.convert_urls cimport ( ) +@acquire_spill_lock() def url_decode(Column source_strings): """ Decode each string in column. No format checking is performed. 
@@ -37,6 +40,7 @@ def url_decode(Column source_strings): ) +@acquire_spill_lock() def url_encode(Column source_strings): """ Encode each string in column. No format checking is performed. diff --git a/python/cudf/cudf/_lib/strings/extract.pyx b/python/cudf/cudf/_lib/strings/extract.pyx index 439c1546381..7d16e3e839d 100644 --- a/python/cudf/cudf/_lib/strings/extract.pyx +++ b/python/cudf/cudf/_lib/strings/extract.pyx @@ -5,6 +5,8 @@ from libcpp.memory cimport unique_ptr from libcpp.string cimport string from libcpp.utility cimport move +from cudf.core.buffer import acquire_spill_lock + from cudf._lib.column cimport Column from cudf._lib.cpp.column.column_view cimport column_view from cudf._lib.cpp.strings.extract cimport extract as cpp_extract @@ -13,6 +15,7 @@ from cudf._lib.cpp.table.table cimport table from cudf._lib.utils cimport data_from_unique_ptr +@acquire_spill_lock() def extract(Column source_strings, object pattern, uint32_t flags): """ Returns data which contains extracted capture groups provided in diff --git a/python/cudf/cudf/_lib/strings/find.pyx b/python/cudf/cudf/_lib/strings/find.pyx index 788c0a2524a..fb7acb54293 100644 --- a/python/cudf/cudf/_lib/strings/find.pyx +++ b/python/cudf/cudf/_lib/strings/find.pyx @@ -3,6 +3,8 @@ from libcpp.memory cimport unique_ptr from libcpp.utility cimport move +from cudf.core.buffer import acquire_spill_lock + from cudf._lib.column cimport Column from cudf._lib.cpp.column.column cimport column from cudf._lib.cpp.column.column_view cimport column_view @@ -18,6 +20,7 @@ from cudf._lib.cpp.types cimport size_type from cudf._lib.scalar cimport DeviceScalar +@acquire_spill_lock() def contains(Column source_strings, object py_target): """ Returns a Column of boolean values with True for `source_strings` @@ -41,6 +44,7 @@ def contains(Column source_strings, object py_target): return Column.from_unique_ptr(move(c_result)) +@acquire_spill_lock() def contains_multiple(Column source_strings, Column target_strings): """ 
Returns a Column of boolean values with True for `source_strings` @@ -59,6 +63,7 @@ def contains_multiple(Column source_strings, Column target_strings): return Column.from_unique_ptr(move(c_result)) +@acquire_spill_lock() def endswith(Column source_strings, object py_target): """ Returns a Column of boolean values with True for `source_strings` @@ -83,6 +88,7 @@ def endswith(Column source_strings, object py_target): return Column.from_unique_ptr(move(c_result)) +@acquire_spill_lock() def endswith_multiple(Column source_strings, Column target_strings): """ Returns a Column of boolean values with True for `source_strings` @@ -102,6 +108,7 @@ def endswith_multiple(Column source_strings, Column target_strings): return Column.from_unique_ptr(move(c_result)) +@acquire_spill_lock() def startswith(Column source_strings, object py_target): """ Returns a Column of boolean values with True for `source_strings` @@ -126,6 +133,7 @@ def startswith(Column source_strings, object py_target): return Column.from_unique_ptr(move(c_result)) +@acquire_spill_lock() def startswith_multiple(Column source_strings, Column target_strings): """ Returns a Column of boolean values with True for `source_strings` @@ -145,6 +153,7 @@ def startswith_multiple(Column source_strings, Column target_strings): return Column.from_unique_ptr(move(c_result)) +@acquire_spill_lock() def find(Column source_strings, object py_target, size_type start, @@ -176,6 +185,7 @@ def find(Column source_strings, return Column.from_unique_ptr(move(c_result)) +@acquire_spill_lock() def rfind(Column source_strings, object py_target, size_type start, diff --git a/python/cudf/cudf/_lib/strings/find_multiple.pyx b/python/cudf/cudf/_lib/strings/find_multiple.pyx index 4ac86ce4ef5..9a0c0576a4b 100644 --- a/python/cudf/cudf/_lib/strings/find_multiple.pyx +++ b/python/cudf/cudf/_lib/strings/find_multiple.pyx @@ -3,6 +3,8 @@ from libcpp.memory cimport unique_ptr from libcpp.utility cimport move +from cudf.core.buffer import 
acquire_spill_lock + from cudf._lib.column cimport Column from cudf._lib.cpp.column.column cimport column from cudf._lib.cpp.column.column_view cimport column_view @@ -11,6 +13,7 @@ from cudf._lib.cpp.strings.find_multiple cimport ( ) +@acquire_spill_lock() def find_multiple(Column source_strings, Column target_strings): """ Returns a column with character position values where each diff --git a/python/cudf/cudf/_lib/strings/findall.pyx b/python/cudf/cudf/_lib/strings/findall.pyx index be34ce1fb18..4080d346142 100644 --- a/python/cudf/cudf/_lib/strings/findall.pyx +++ b/python/cudf/cudf/_lib/strings/findall.pyx @@ -5,6 +5,8 @@ from libcpp.memory cimport unique_ptr from libcpp.string cimport string from libcpp.utility cimport move +from cudf.core.buffer import acquire_spill_lock + from cudf._lib.column cimport Column from cudf._lib.cpp.column.column cimport column from cudf._lib.cpp.column.column_view cimport column_view @@ -12,6 +14,7 @@ from cudf._lib.cpp.strings.findall cimport findall as cpp_findall from cudf._lib.cpp.strings.regex_flags cimport regex_flags +@acquire_spill_lock() def findall(Column source_strings, object pattern, uint32_t flags): """ Returns data with all non-overlapping matches of `pattern` diff --git a/python/cudf/cudf/_lib/strings/json.pyx b/python/cudf/cudf/_lib/strings/json.pyx index 9dbc932d842..861e0daa6e3 100644 --- a/python/cudf/cudf/_lib/strings/json.pyx +++ b/python/cudf/cudf/_lib/strings/json.pyx @@ -3,6 +3,8 @@ from libcpp.memory cimport unique_ptr from libcpp.utility cimport move +from cudf.core.buffer import acquire_spill_lock + from cudf._lib.column cimport Column from cudf._lib.cpp.column.column cimport column from cudf._lib.cpp.column.column_view cimport column_view @@ -14,6 +16,7 @@ from cudf._lib.cpp.strings.json cimport ( from cudf._lib.scalar cimport DeviceScalar +@acquire_spill_lock() def get_json_object( Column col, object py_json_path, GetJsonObjectOptions options): """ diff --git 
a/python/cudf/cudf/_lib/strings/padding.pyx b/python/cudf/cudf/_lib/strings/padding.pyx index f53feab7936..340d7eb52d8 100644 --- a/python/cudf/cudf/_lib/strings/padding.pyx +++ b/python/cudf/cudf/_lib/strings/padding.pyx @@ -1,16 +1,17 @@ # Copyright (c) 2020-2022, NVIDIA CORPORATION. from libcpp.memory cimport unique_ptr +from libcpp.string cimport string from libcpp.utility cimport move +from cudf.core.buffer import acquire_spill_lock + from cudf._lib.column cimport Column from cudf._lib.cpp.column.column_view cimport column_view from cudf._lib.cpp.types cimport size_type from enum import IntEnum -from libcpp.string cimport string - from cudf._lib.cpp.column.column cimport column from cudf._lib.cpp.strings.padding cimport pad as cpp_pad, zfill as cpp_zfill from cudf._lib.cpp.strings.side_type cimport ( @@ -25,6 +26,7 @@ class SideType(IntEnum): BOTH = side_type.BOTH +@acquire_spill_lock() def pad(Column source_strings, size_type width, fill_char, @@ -55,6 +57,7 @@ def pad(Column source_strings, return Column.from_unique_ptr(move(c_result)) +@acquire_spill_lock() def zfill(Column source_strings, size_type width): """ @@ -73,6 +76,7 @@ def zfill(Column source_strings, return Column.from_unique_ptr(move(c_result)) +@acquire_spill_lock() def center(Column source_strings, size_type width, fill_char): @@ -97,6 +101,7 @@ def center(Column source_strings, return Column.from_unique_ptr(move(c_result)) +@acquire_spill_lock() def ljust(Column source_strings, size_type width, fill_char): @@ -120,6 +125,7 @@ def ljust(Column source_strings, return Column.from_unique_ptr(move(c_result)) +@acquire_spill_lock() def rjust(Column source_strings, size_type width, fill_char): diff --git a/python/cudf/cudf/_lib/strings/repeat.pyx b/python/cudf/cudf/_lib/strings/repeat.pyx index 49a46f418b1..608e207f9c3 100644 --- a/python/cudf/cudf/_lib/strings/repeat.pyx +++ b/python/cudf/cudf/_lib/strings/repeat.pyx @@ -3,6 +3,8 @@ from libcpp.memory cimport unique_ptr from libcpp.utility cimport 
move +from cudf.core.buffer import acquire_spill_lock + from cudf._lib.column cimport Column from cudf._lib.cpp.column.column cimport column from cudf._lib.cpp.column.column_view cimport column_view @@ -10,6 +12,7 @@ from cudf._lib.cpp.strings cimport repeat as cpp_repeat from cudf._lib.cpp.types cimport size_type +@acquire_spill_lock() def repeat_scalar(Column source_strings, size_type repeats): """ @@ -29,6 +32,7 @@ def repeat_scalar(Column source_strings, return Column.from_unique_ptr(move(c_result)) +@acquire_spill_lock() def repeat_sequence(Column source_strings, Column repeats): """ diff --git a/python/cudf/cudf/_lib/strings/replace.pyx b/python/cudf/cudf/_lib/strings/replace.pyx index 72d66d9a8e3..80c9ba95fd8 100644 --- a/python/cudf/cudf/_lib/strings/replace.pyx +++ b/python/cudf/cudf/_lib/strings/replace.pyx @@ -4,6 +4,8 @@ from libc.stdint cimport int32_t from libcpp.memory cimport unique_ptr from libcpp.utility cimport move +from cudf.core.buffer import acquire_spill_lock + from cudf._lib.column cimport Column from cudf._lib.cpp.column.column cimport column from cudf._lib.cpp.column.column_view cimport column_view @@ -16,6 +18,7 @@ from cudf._lib.cpp.types cimport size_type from cudf._lib.scalar cimport DeviceScalar +@acquire_spill_lock() def slice_replace(Column source_strings, size_type start, size_type stop, @@ -46,6 +49,7 @@ def slice_replace(Column source_strings, return Column.from_unique_ptr(move(c_result)) +@acquire_spill_lock() def insert(Column source_strings, size_type start, object py_repl): @@ -74,6 +78,7 @@ def insert(Column source_strings, return Column.from_unique_ptr(move(c_result)) +@acquire_spill_lock() def replace(Column source_strings, object py_target, object py_repl, @@ -107,6 +112,7 @@ def replace(Column source_strings, return Column.from_unique_ptr(move(c_result)) +@acquire_spill_lock() def replace_multi(Column source_strings, Column target_strings, Column repl_strings): diff --git a/python/cudf/cudf/_lib/strings/replace_re.pyx 
b/python/cudf/cudf/_lib/strings/replace_re.pyx index 20fb903c60c..c37b13c8aa0 100644 --- a/python/cudf/cudf/_lib/strings/replace_re.pyx +++ b/python/cudf/cudf/_lib/strings/replace_re.pyx @@ -5,6 +5,8 @@ from libcpp.string cimport string from libcpp.utility cimport move from libcpp.vector cimport vector +from cudf.core.buffer import acquire_spill_lock + from cudf._lib.column cimport Column from cudf._lib.cpp.column.column cimport column from cudf._lib.cpp.column.column_view cimport column_view @@ -17,6 +19,7 @@ from cudf._lib.cpp.types cimport size_type from cudf._lib.scalar cimport DeviceScalar +@acquire_spill_lock() def replace_re(Column source_strings, object pattern, object py_repl, @@ -48,6 +51,7 @@ def replace_re(Column source_strings, return Column.from_unique_ptr(move(c_result)) +@acquire_spill_lock() def replace_with_backrefs( Column source_strings, object pattern, @@ -73,6 +77,7 @@ def replace_with_backrefs( return Column.from_unique_ptr(move(c_result)) +@acquire_spill_lock() def replace_multi_re(Column source_strings, object patterns, Column repl_strings): diff --git a/python/cudf/cudf/_lib/strings/split/partition.pyx b/python/cudf/cudf/_lib/strings/split/partition.pyx index b17ea4e608d..281d131372a 100644 --- a/python/cudf/cudf/_lib/strings/split/partition.pyx +++ b/python/cudf/cudf/_lib/strings/split/partition.pyx @@ -3,6 +3,8 @@ from libcpp.memory cimport unique_ptr from libcpp.utility cimport move +from cudf.core.buffer import acquire_spill_lock + from cudf._lib.column cimport Column from cudf._lib.cpp.column.column_view cimport column_view from cudf._lib.cpp.scalar.scalar cimport string_scalar @@ -15,6 +17,7 @@ from cudf._lib.scalar cimport DeviceScalar from cudf._lib.utils cimport data_from_unique_ptr +@acquire_spill_lock() def partition(Column source_strings, object py_delimiter): """ @@ -42,6 +45,7 @@ def partition(Column source_strings, ) +@acquire_spill_lock() def rpartition(Column source_strings, object py_delimiter): """ diff --git 
a/python/cudf/cudf/_lib/strings/split/split.pyx b/python/cudf/cudf/_lib/strings/split/split.pyx index e96c911e83a..7a84cf75e37 100644 --- a/python/cudf/cudf/_lib/strings/split/split.pyx +++ b/python/cudf/cudf/_lib/strings/split/split.pyx @@ -4,6 +4,8 @@ from libcpp.memory cimport unique_ptr from libcpp.string cimport string from libcpp.utility cimport move +from cudf.core.buffer import acquire_spill_lock + from cudf._lib.column cimport Column from cudf._lib.cpp.column.column cimport column from cudf._lib.cpp.column.column_view cimport column_view @@ -24,6 +26,7 @@ from cudf._lib.scalar cimport DeviceScalar from cudf._lib.utils cimport data_from_unique_ptr +@acquire_spill_lock() def split(Column source_strings, object py_delimiter, size_type maxsplit): @@ -54,6 +57,7 @@ def split(Column source_strings, ) +@acquire_spill_lock() def split_record(Column source_strings, object py_delimiter, size_type maxsplit): @@ -83,6 +87,7 @@ def split_record(Column source_strings, ) +@acquire_spill_lock() def rsplit(Column source_strings, object py_delimiter, size_type maxsplit): @@ -113,6 +118,7 @@ def rsplit(Column source_strings, ) +@acquire_spill_lock() def rsplit_record(Column source_strings, object py_delimiter, size_type maxsplit): @@ -142,6 +148,7 @@ def rsplit_record(Column source_strings, ) +@acquire_spill_lock() def split_re(Column source_strings, object pattern, size_type maxsplit): @@ -166,6 +173,7 @@ def split_re(Column source_strings, ) +@acquire_spill_lock() def rsplit_re(Column source_strings, object pattern, size_type maxsplit): @@ -191,6 +199,7 @@ def rsplit_re(Column source_strings, ) +@acquire_spill_lock() def split_record_re(Column source_strings, object pattern, size_type maxsplit): @@ -214,6 +223,7 @@ def split_record_re(Column source_strings, ) +@acquire_spill_lock() def rsplit_record_re(Column source_strings, object pattern, size_type maxsplit): diff --git a/python/cudf/cudf/_lib/strings/strip.pyx b/python/cudf/cudf/_lib/strings/strip.pyx index 
da3efe33786..2c53782d6ba 100644 --- a/python/cudf/cudf/_lib/strings/strip.pyx +++ b/python/cudf/cudf/_lib/strings/strip.pyx @@ -3,6 +3,8 @@ from libcpp.memory cimport unique_ptr from libcpp.utility cimport move +from cudf.core.buffer import acquire_spill_lock + from cudf._lib.column cimport Column from cudf._lib.cpp.column.column cimport column from cudf._lib.cpp.column.column_view cimport column_view @@ -12,6 +14,7 @@ from cudf._lib.cpp.strings.strip cimport strip as cpp_strip from cudf._lib.scalar cimport DeviceScalar +@acquire_spill_lock() def strip(Column source_strings, object py_repl): """ @@ -39,6 +42,7 @@ def strip(Column source_strings, return Column.from_unique_ptr(move(c_result)) +@acquire_spill_lock() def lstrip(Column source_strings, object py_repl): """ @@ -66,6 +70,7 @@ def lstrip(Column source_strings, return Column.from_unique_ptr(move(c_result)) +@acquire_spill_lock() def rstrip(Column source_strings, object py_repl): """ diff --git a/python/cudf/cudf/_lib/strings/substring.pyx b/python/cudf/cudf/_lib/strings/substring.pyx index 761e9503aba..20446f267e0 100644 --- a/python/cudf/cudf/_lib/strings/substring.pyx +++ b/python/cudf/cudf/_lib/strings/substring.pyx @@ -1,16 +1,17 @@ # Copyright (c) 2020, NVIDIA CORPORATION. 
+import numpy as np + from libcpp.memory cimport unique_ptr from libcpp.utility cimport move +from cudf.core.buffer import acquire_spill_lock + from cudf._lib.column cimport Column from cudf._lib.cpp.column.column cimport column from cudf._lib.cpp.column.column_view cimport column_view -from cudf._lib.cpp.types cimport size_type - -import numpy as np - from cudf._lib.cpp.strings.substring cimport slice_strings as cpp_slice_strings +from cudf._lib.cpp.types cimport size_type from cudf._lib.scalar import as_device_scalar @@ -18,6 +19,7 @@ from cudf._lib.cpp.scalar.scalar cimport numeric_scalar from cudf._lib.scalar cimport DeviceScalar +@acquire_spill_lock() def slice_strings(Column source_strings, object start, object end, @@ -54,6 +56,7 @@ def slice_strings(Column source_strings, return Column.from_unique_ptr(move(c_result)) +@acquire_spill_lock() def slice_from(Column source_strings, Column starts, Column stops): @@ -77,6 +80,7 @@ def slice_from(Column source_strings, return Column.from_unique_ptr(move(c_result)) +@acquire_spill_lock() def get(Column source_strings, object index): """ diff --git a/python/cudf/cudf/_lib/strings/translate.pyx b/python/cudf/cudf/_lib/strings/translate.pyx index 7a5cf502ba3..3f4bd3d44fb 100644 --- a/python/cudf/cudf/_lib/strings/translate.pyx +++ b/python/cudf/cudf/_lib/strings/translate.pyx @@ -6,6 +6,8 @@ from libcpp.pair cimport pair from libcpp.utility cimport move from libcpp.vector cimport vector +from cudf.core.buffer import acquire_spill_lock + from cudf._lib.column cimport Column from cudf._lib.cpp.column.column cimport column from cudf._lib.cpp.column.column_view cimport column_view @@ -19,6 +21,7 @@ from cudf._lib.cpp.types cimport char_utf8 from cudf._lib.scalar cimport DeviceScalar +@acquire_spill_lock() def translate(Column source_strings, object mapping_table): """ @@ -51,6 +54,7 @@ def translate(Column source_strings, return Column.from_unique_ptr(move(c_result)) +@acquire_spill_lock() def filter_characters(Column 
source_strings, object mapping_table, bool keep, diff --git a/python/cudf/cudf/_lib/strings/wrap.pyx b/python/cudf/cudf/_lib/strings/wrap.pyx index 5ebc33f77ef..0d840f59c5e 100644 --- a/python/cudf/cudf/_lib/strings/wrap.pyx +++ b/python/cudf/cudf/_lib/strings/wrap.pyx @@ -3,6 +3,8 @@ from libcpp.memory cimport unique_ptr from libcpp.utility cimport move +from cudf.core.buffer import acquire_spill_lock + from cudf._lib.column cimport Column from cudf._lib.cpp.column.column cimport column from cudf._lib.cpp.column.column_view cimport column_view @@ -10,6 +12,7 @@ from cudf._lib.cpp.strings.wrap cimport wrap as cpp_wrap from cudf._lib.cpp.types cimport size_type +@acquire_spill_lock() def wrap(Column source_strings, size_type width): """ diff --git a/python/cudf/cudf/_lib/transform.pyx b/python/cudf/cudf/_lib/transform.pyx index 1fa68282c3d..3787f1405b7 100644 --- a/python/cudf/cudf/_lib/transform.pyx +++ b/python/cudf/cudf/_lib/transform.pyx @@ -5,7 +5,7 @@ from numba.np import numpy_support import cudf from cudf._lib.types import SUPPORTED_NUMPY_TO_LIBCUDF_TYPES from cudf.core._internals.expressions import parse_expression -from cudf.core.buffer import as_buffer +from cudf.core.buffer import acquire_spill_lock, as_buffer from cudf.utils import cudautils from cython.operator cimport dereference @@ -34,6 +34,7 @@ from cudf._lib.utils cimport ( ) +@acquire_spill_lock() def bools_to_mask(Column col): """ Given an int8 (boolean) column, compress the data from booleans to bits and @@ -88,6 +89,7 @@ def nans_to_nulls(Column input): return buffer +@acquire_spill_lock() def transform(Column input, op): cdef column_view c_input = input.view() cdef string c_str @@ -132,8 +134,10 @@ def table_encode(list source_columns): with nogil: c_result = move(libcudf_transform.encode(c_input)) - return columns_from_unique_ptr( - move(c_result.first)), Column.from_unique_ptr(move(c_result.second)) + return ( + columns_from_unique_ptr(move(c_result.first)), + 
Column.from_unique_ptr(move(c_result.second)) + ) def one_hot_encode(Column input_column, Column categories): @@ -160,10 +164,10 @@ def one_hot_encode(Column input_column, Column categories): x if x is not None else 'null' for x in pylist_categories ] ) - return encodings +@acquire_spill_lock() def compute_column(list columns, tuple column_names, expr: str): """Compute a new column by evaluating an expression on a set of columns. From 92548c50ba45c57bf054602e32ffa4dccc1bbfea Mon Sep 17 00:00:00 2001 From: "Mads R. B. Kristensen" Date: Tue, 15 Nov 2022 14:54:42 +0100 Subject: [PATCH 31/31] copyrights --- python/cudf/cudf/_lib/nvtext/generate_ngrams.pyx | 2 +- python/cudf/cudf/_lib/nvtext/ngrams_tokenize.pyx | 2 +- python/cudf/cudf/_lib/nvtext/normalize.pyx | 2 +- python/cudf/cudf/_lib/nvtext/replace.pyx | 2 +- python/cudf/cudf/_lib/nvtext/stemmer.pyx | 2 +- python/cudf/cudf/_lib/round.pyx | 2 +- python/cudf/cudf/_lib/strings/attributes.pyx | 2 +- python/cudf/cudf/_lib/strings/capitalize.pyx | 2 +- python/cudf/cudf/_lib/strings/case.pyx | 2 +- python/cudf/cudf/_lib/strings/char_types.pyx | 2 +- python/cudf/cudf/_lib/strings/convert/convert_lists.pyx | 2 +- python/cudf/cudf/_lib/strings/find.pyx | 2 +- python/cudf/cudf/_lib/strings/find_multiple.pyx | 2 +- python/cudf/cudf/_lib/strings/repeat.pyx | 2 +- python/cudf/cudf/_lib/strings/replace_re.pyx | 2 +- python/cudf/cudf/_lib/strings/substring.pyx | 2 +- python/cudf/cudf/_lib/strings/translate.pyx | 2 +- python/cudf/cudf/_lib/strings/wrap.pyx | 2 +- 18 files changed, 18 insertions(+), 18 deletions(-) diff --git a/python/cudf/cudf/_lib/nvtext/generate_ngrams.pyx b/python/cudf/cudf/_lib/nvtext/generate_ngrams.pyx index c2c32314f49..7be3b0f7c03 100644 --- a/python/cudf/cudf/_lib/nvtext/generate_ngrams.pyx +++ b/python/cudf/cudf/_lib/nvtext/generate_ngrams.pyx @@ -1,4 +1,4 @@ -# Copyright (c) 2018-2020, NVIDIA CORPORATION. +# Copyright (c) 2018-2022, NVIDIA CORPORATION. 
from cudf.core.buffer import acquire_spill_lock diff --git a/python/cudf/cudf/_lib/nvtext/ngrams_tokenize.pyx b/python/cudf/cudf/_lib/nvtext/ngrams_tokenize.pyx index 104741f2ee8..3e7911c8ae8 100644 --- a/python/cudf/cudf/_lib/nvtext/ngrams_tokenize.pyx +++ b/python/cudf/cudf/_lib/nvtext/ngrams_tokenize.pyx @@ -1,4 +1,4 @@ -# Copyright (c) 2018-2020, NVIDIA CORPORATION. +# Copyright (c) 2018-2022, NVIDIA CORPORATION. from cudf.core.buffer import acquire_spill_lock diff --git a/python/cudf/cudf/_lib/nvtext/normalize.pyx b/python/cudf/cudf/_lib/nvtext/normalize.pyx index fa86e580aca..80c6ef792ab 100644 --- a/python/cudf/cudf/_lib/nvtext/normalize.pyx +++ b/python/cudf/cudf/_lib/nvtext/normalize.pyx @@ -1,4 +1,4 @@ -# Copyright (c) 2018-2020, NVIDIA CORPORATION. +# Copyright (c) 2018-2022, NVIDIA CORPORATION. from cudf.core.buffer import acquire_spill_lock diff --git a/python/cudf/cudf/_lib/nvtext/replace.pyx b/python/cudf/cudf/_lib/nvtext/replace.pyx index 535816a6066..289e5611010 100644 --- a/python/cudf/cudf/_lib/nvtext/replace.pyx +++ b/python/cudf/cudf/_lib/nvtext/replace.pyx @@ -1,4 +1,4 @@ -# Copyright (c) 2020, NVIDIA CORPORATION. +# Copyright (c) 2020-2022, NVIDIA CORPORATION. from cudf.core.buffer import acquire_spill_lock diff --git a/python/cudf/cudf/_lib/nvtext/stemmer.pyx b/python/cudf/cudf/_lib/nvtext/stemmer.pyx index c8a93f8e67d..7a76052ffe4 100644 --- a/python/cudf/cudf/_lib/nvtext/stemmer.pyx +++ b/python/cudf/cudf/_lib/nvtext/stemmer.pyx @@ -1,4 +1,4 @@ -# Copyright (c) 2020, NVIDIA CORPORATION. +# Copyright (c) 2020-2022, NVIDIA CORPORATION. from cudf.core.buffer import acquire_spill_lock diff --git a/python/cudf/cudf/_lib/round.pyx b/python/cudf/cudf/_lib/round.pyx index b62b5a4bb34..7eddb1b8cbd 100644 --- a/python/cudf/cudf/_lib/round.pyx +++ b/python/cudf/cudf/_lib/round.pyx @@ -1,4 +1,4 @@ -# Copyright (c) 2021, NVIDIA CORPORATION. +# Copyright (c) 2021-2022, NVIDIA CORPORATION. 
from cudf.core.buffer import acquire_spill_lock diff --git a/python/cudf/cudf/_lib/strings/attributes.pyx b/python/cudf/cudf/_lib/strings/attributes.pyx index 4add4aa8e8c..c1b69dda353 100644 --- a/python/cudf/cudf/_lib/strings/attributes.pyx +++ b/python/cudf/cudf/_lib/strings/attributes.pyx @@ -1,4 +1,4 @@ -# Copyright (c) 2018-2020, NVIDIA CORPORATION. +# Copyright (c) 2018-2022, NVIDIA CORPORATION. from cudf.core.buffer import acquire_spill_lock diff --git a/python/cudf/cudf/_lib/strings/capitalize.pyx b/python/cudf/cudf/_lib/strings/capitalize.pyx index cfb0feee26c..f6a80ac8fbe 100644 --- a/python/cudf/cudf/_lib/strings/capitalize.pyx +++ b/python/cudf/cudf/_lib/strings/capitalize.pyx @@ -1,4 +1,4 @@ -# Copyright (c) 2018-2021, NVIDIA CORPORATION. +# Copyright (c) 2018-2022, NVIDIA CORPORATION. from cudf.core.buffer import acquire_spill_lock diff --git a/python/cudf/cudf/_lib/strings/case.pyx b/python/cudf/cudf/_lib/strings/case.pyx index fbf328f9f9f..09af1178946 100644 --- a/python/cudf/cudf/_lib/strings/case.pyx +++ b/python/cudf/cudf/_lib/strings/case.pyx @@ -1,4 +1,4 @@ -# Copyright (c) 2018-2020, NVIDIA CORPORATION. +# Copyright (c) 2018-2022, NVIDIA CORPORATION. from cudf.core.buffer import acquire_spill_lock diff --git a/python/cudf/cudf/_lib/strings/char_types.pyx b/python/cudf/cudf/_lib/strings/char_types.pyx index 25294d0d626..eb03d7c2192 100644 --- a/python/cudf/cudf/_lib/strings/char_types.pyx +++ b/python/cudf/cudf/_lib/strings/char_types.pyx @@ -1,4 +1,4 @@ -# Copyright (c) 2021, NVIDIA CORPORATION. +# Copyright (c) 2021-2022, NVIDIA CORPORATION. from libcpp cimport bool diff --git a/python/cudf/cudf/_lib/strings/convert/convert_lists.pyx b/python/cudf/cudf/_lib/strings/convert/convert_lists.pyx index 61869411183..33f6d4a4af7 100644 --- a/python/cudf/cudf/_lib/strings/convert/convert_lists.pyx +++ b/python/cudf/cudf/_lib/strings/convert/convert_lists.pyx @@ -1,4 +1,4 @@ -# Copyright (c) 2021, NVIDIA CORPORATION. 
+# Copyright (c) 2021-2022, NVIDIA CORPORATION. from libcpp.memory cimport unique_ptr from libcpp.utility cimport move diff --git a/python/cudf/cudf/_lib/strings/find.pyx b/python/cudf/cudf/_lib/strings/find.pyx index fb7acb54293..f6dd3b80de9 100644 --- a/python/cudf/cudf/_lib/strings/find.pyx +++ b/python/cudf/cudf/_lib/strings/find.pyx @@ -1,4 +1,4 @@ -# Copyright (c) 2020, NVIDIA CORPORATION. +# Copyright (c) 2020-2022, NVIDIA CORPORATION. from libcpp.memory cimport unique_ptr from libcpp.utility cimport move diff --git a/python/cudf/cudf/_lib/strings/find_multiple.pyx b/python/cudf/cudf/_lib/strings/find_multiple.pyx index 9a0c0576a4b..c2a97a4fd7c 100644 --- a/python/cudf/cudf/_lib/strings/find_multiple.pyx +++ b/python/cudf/cudf/_lib/strings/find_multiple.pyx @@ -1,4 +1,4 @@ -# Copyright (c) 2020, NVIDIA CORPORATION. +# Copyright (c) 2020-2022, NVIDIA CORPORATION. from libcpp.memory cimport unique_ptr from libcpp.utility cimport move diff --git a/python/cudf/cudf/_lib/strings/repeat.pyx b/python/cudf/cudf/_lib/strings/repeat.pyx index 608e207f9c3..4896fb74f41 100644 --- a/python/cudf/cudf/_lib/strings/repeat.pyx +++ b/python/cudf/cudf/_lib/strings/repeat.pyx @@ -1,4 +1,4 @@ -# Copyright (c) 2021, NVIDIA CORPORATION. +# Copyright (c) 2021-2022, NVIDIA CORPORATION. from libcpp.memory cimport unique_ptr from libcpp.utility cimport move diff --git a/python/cudf/cudf/_lib/strings/replace_re.pyx b/python/cudf/cudf/_lib/strings/replace_re.pyx index c37b13c8aa0..73911538db2 100644 --- a/python/cudf/cudf/_lib/strings/replace_re.pyx +++ b/python/cudf/cudf/_lib/strings/replace_re.pyx @@ -1,4 +1,4 @@ -# Copyright (c) 2020, NVIDIA CORPORATION. +# Copyright (c) 2020-2022, NVIDIA CORPORATION. 
from libcpp.memory cimport unique_ptr from libcpp.string cimport string diff --git a/python/cudf/cudf/_lib/strings/substring.pyx b/python/cudf/cudf/_lib/strings/substring.pyx index 20446f267e0..57bca09ee0e 100644 --- a/python/cudf/cudf/_lib/strings/substring.pyx +++ b/python/cudf/cudf/_lib/strings/substring.pyx @@ -1,4 +1,4 @@ -# Copyright (c) 2020, NVIDIA CORPORATION. +# Copyright (c) 2020-2022, NVIDIA CORPORATION. import numpy as np diff --git a/python/cudf/cudf/_lib/strings/translate.pyx b/python/cudf/cudf/_lib/strings/translate.pyx index 3f4bd3d44fb..262d479d914 100644 --- a/python/cudf/cudf/_lib/strings/translate.pyx +++ b/python/cudf/cudf/_lib/strings/translate.pyx @@ -1,4 +1,4 @@ -# Copyright (c) 2018-2020, NVIDIA CORPORATION. +# Copyright (c) 2018-2022, NVIDIA CORPORATION. from libcpp cimport bool from libcpp.memory cimport unique_ptr diff --git a/python/cudf/cudf/_lib/strings/wrap.pyx b/python/cudf/cudf/_lib/strings/wrap.pyx index 0d840f59c5e..8b0c367e791 100644 --- a/python/cudf/cudf/_lib/strings/wrap.pyx +++ b/python/cudf/cudf/_lib/strings/wrap.pyx @@ -1,4 +1,4 @@ -# Copyright (c) 2020, NVIDIA CORPORATION. +# Copyright (c) 2020-2022, NVIDIA CORPORATION. from libcpp.memory cimport unique_ptr from libcpp.utility cimport move