Skip to content

Commit

Permalink
feat: from_map like optimization for dask arrays (#679)
Browse files Browse the repository at this point in the history
* Token change to get PR number

* Revert "Token change to get PR number"

This reverts commit ce59fe1.

* from_map optimization: Tests pass

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Revert "[pre-commit.ci] auto fixes from pre-commit.com hooks"

This reverts commit 6bc1c39.

* Solve issues with merge and add 2 noqa's

* Changes as suggested in #680

* Use right label for dask arrays

* Import dask.blockwise in extras

* Corrected issue with error message during merge

* Remove special keyword used in tests for assert_eq

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
kkothari2001 and pre-commit-ci[bot] authored Sep 1, 2022
1 parent 92e91f3 commit 49d40bd
Show file tree
Hide file tree
Showing 3 changed files with 200 additions and 24 deletions.
201 changes: 180 additions & 21 deletions src/uproot/_dask.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from collections.abc import Callable, Iterable, Mapping

import numpy

import uproot
Expand Down Expand Up @@ -168,6 +170,163 @@ def dask(
raise NotImplementedError()


class _PackedArgCallable:
"""Wrap a callable such that packed arguments can be unrolled.
Inspired by dask.dataframe.io.io._PackedArgCallable.
"""

def __init__(
self,
func: Callable,
args=None,
kwargs=None,
packed: bool = False,
):
self.func = func
self.args = args
self.kwargs = kwargs
self.packed = packed

def __call__(self, packed_arg):
if not self.packed:
packed_arg = (packed_arg,)
return self.func(
*packed_arg,
*(self.args or []),
**(self.kwargs or {}),
)


class LazyInputsDict(Mapping):
    """Mapping whose keys are ``(i,)`` index tuples over a list of inputs.

    Parameters
    ----------
    inputs : list[Any]
        The list of dictionary values.

    NOTE(review): ``__iter__`` yields the *values*, not the keys as a normal
    ``Mapping`` would — this mirrors the dask helper this class is modeled
    on and is preserved deliberately.
    """

    def __init__(self, inputs, **kwargs) -> None:
        self.inputs = inputs
        self.kwargs = kwargs

    def __len__(self):
        return len(self.inputs)

    def __iter__(self):
        # Yields values (see class docstring).
        return (self[key] for key in self.keys())

    def __getitem__(self, i):
        # Keys are one-tuples such as ``(3,)``; unwrap the index.
        return self.inputs[i[0]]

    def __contains__(self, k):
        return (
            isinstance(k, tuple)
            and isinstance(k[0], int)
            and 0 <= k[0] < len(self)
        )

    def keys(self):
        return ((index,) for index in range(len(self.inputs)))


def _dask_array_from_map(
    func,
    *iterables,
    chunks,
    dtype,
    args=None,
    label=None,
    token=None,
    **kwargs,
):
    """Build a dask Array by mapping ``func`` over ``iterables``, one task
    per element, constructing a single ``Blockwise`` graph layer directly.

    This is a ``from_map``-like optimization: it avoids creating one
    ``dask.delayed`` object per chunk, keeping the task graph small.

    Parameters
    ----------
    func : Callable
        Called once per (zipped) element of ``iterables`` to produce one
        chunk of the output array.
    *iterables : Iterable
        One or more equal-length iterables of per-task inputs.
    chunks : tuple
        Chunk specification forwarded to ``dask.array.core.Array``.
    dtype : numpy.dtype
        dtype of the resulting array.
    args : tuple, optional
        Extra positional arguments appended to every ``func`` call.
    label : str, optional
        Human-readable prefix of the collection name; defaults to
        ``func``'s name.
    token : str, optional
        Unique suffix of the collection name; defaults to a deterministic
        hash of ``func``, the inputs, and ``kwargs``.
    **kwargs
        Extra keyword arguments for ``func``.  ``produces_tasks`` is
        consumed here and not forwarded.

    Raises
    ------
    ValueError
        If ``func`` is not callable, no iterable is given, the iterables
        have differing (or all-zero) lengths, or multiple iterables are
        combined with ``produces_tasks=True``.
    """
    dask = uproot.extras.dask()
    da = uproot.extras.dask_array()
    if not callable(func):
        raise ValueError("`func` argument must be `callable`")
    # Verify all iterables share a single length, materializing un-sized
    # iterables (e.g. generators) into lists so they can be measured.
    lengths = set()
    iters = list(iterables)
    for i, iterable in enumerate(iters):
        if not isinstance(iterable, Iterable):
            raise ValueError(
                f"All elements of `iterables` must be Iterable, got {type(iterable)}"
            )
        try:
            lengths.add(len(iterable))  # type: ignore
        except (AttributeError, TypeError):
            iters[i] = list(iterable)
            lengths.add(len(iters[i]))  # type: ignore
    if len(lengths) == 0:
        raise ValueError("`from_map` requires at least one Iterable input")
    elif len(lengths) > 1:
        raise ValueError("All `iterables` must have the same length")
    if lengths == {0}:
        raise ValueError("All `iterables` must have a non-zero length")

    # Check for `produces_tasks` and `creation_info`
    produces_tasks = kwargs.pop("produces_tasks", False)
    # creation_info = kwargs.pop("creation_info", None)

    if produces_tasks or len(iters) == 1:
        if len(iters) > 1:
            # Tasks are not detected correctly when they are "packed"
            # within an outer list/tuple
            raise ValueError(
                "Multiple iterables not supported when produces_tasks=True"
            )
        inputs = list(iters[0])
        packed = False
    else:
        # Structure inputs such that the tuple of arguments pair each 0th,
        # 1st, 2nd, ... elements together; for example:
        # from_map(f, [1, 2, 3], [4, 5, 6]) --> [f(1, 4), f(2, 5), f(3, 6)]
        inputs = list(zip(*iters))
        packed = True

    # Define collection name
    label = label or dask.utils.funcname(func)
    token = token or dask.base.tokenize(func, iters, **kwargs)
    name = f"{label}-{token}"

    # Define io_func: wrap `func` only when packing or extra arguments make
    # it necessary, so the simple case stays a plain callable.
    if packed or args or kwargs:
        io_func = _PackedArgCallable(
            func,
            args=args,
            kwargs=kwargs,
            packed=packed,
        )
    else:
        io_func = func

    # Graph key (name, i) resolves lazily to inputs[i].
    io_arg_map = dask.blockwise.BlockwiseDepDict(
        mapping=LazyInputsDict(inputs),  # type: ignore
        produces_tasks=produces_tasks,
    )

    # Single Blockwise layer: task i computes io_func(inputs[i]).
    dsk = dask.blockwise.Blockwise(
        output=name,
        output_indices="i",
        dsk={name: (io_func, dask.blockwise.blockwise_token(0))},
        indices=[(io_arg_map, "i")],
        numblocks={},
        annotations=None,
    )

    hlg = dask.highlevelgraph.HighLevelGraph.from_collections(name, dsk)
    return da.core.Array(hlg, name, chunks, dtype=dtype)


class _UprootReadNumpy:
def __init__(self, hasbranches, key) -> None:
self.hasbranches = hasbranches
self.key = key

def __call__(self, i_start_stop):
i, start, stop = i_start_stop
return self.hasbranches[i][self.key].array(
entry_start=start, entry_stop=stop, library="np"
)


def _get_dask_array(
files,
filter_name=no_filter,
Expand All @@ -180,7 +339,6 @@ def _get_dask_array(
allow_missing=False,
real_options=None,
):
dask, da = uproot.extras.dask()
hasbranches = []
common_keys = None
is_self = []
Expand Down Expand Up @@ -254,13 +412,16 @@ def real_filter_branch(branch):

dask_dict = {}

@dask.delayed
def delayed_get_array(ttree, key, start, stop):
return ttree[key].array(library="np", entry_start=start, entry_stop=stop)

for key in common_keys:
dask_arrays = []
for ttree in hasbranches:
dt = hasbranches[0][key].interpretation.numpy_dtype
if dt.subdtype is None:
inner_shape = ()
else:
dt, inner_shape = dt.subdtype

chunks = []
chunk_args = []
for i, ttree in enumerate(hasbranches):
entry_start, entry_stop = _regularize_entries_start_stop(
ttree.tree.num_entries, None, None
)
Expand All @@ -270,26 +431,23 @@ def delayed_get_array(ttree, key, start, stop):
else:
entry_step = ttree.num_entries_for(step_size, expressions=f"{key}")

dt = ttree[key].interpretation.numpy_dtype
if dt.subdtype is None:
inner_shape = ()
else:
dt, inner_shape = dt.subdtype

def foreach(start):
stop = min(start + entry_step, entry_stop) # noqa: B023
length = stop - start

delayed_array = delayed_get_array(ttree, key, start, stop) # noqa: B023
shape = (length,) + inner_shape # noqa: B023
dask_arrays.append( # noqa: B023
da.from_delayed(delayed_array, shape=shape, dtype=dt) # noqa: B023
)
chunks.append(length) # noqa: B023
chunk_args.append((i, start, stop)) # noqa: B023

for start in range(entry_start, entry_stop, entry_step):
foreach(start)

dask_dict[key] = da.concatenate(dask_arrays)
dask_dict[key] = _dask_array_from_map(
_UprootReadNumpy(hasbranches, key),
chunk_args,
chunks=(tuple(chunks),),
dtype=dt,
label=f"{key}-from-uproot",
)

return dask_dict


Expand All @@ -304,7 +462,8 @@ def _get_dask_array_delay_open(
allow_missing=False,
real_options=None,
):
dask, da = uproot.extras.dask()
dask = uproot.extras.dask()
da = uproot.extras.dask_array()
ffile_path, fobject_path = files[0]
obj = uproot._util.regularize_object_path(
ffile_path, fobject_path, custom_classes, allow_missing, real_options
Expand Down
19 changes: 18 additions & 1 deletion src/uproot/extras.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,23 @@ def dask():
"""
try:
import dask
import dask.blockwise
except ModuleNotFoundError as err:
raise ModuleNotFoundError(
"""for uproot.dask with 'library="np"', install the complete 'dask' package with:
pip install "dask[complete]"
or
conda install dask"""
) from err
else:
return dask


def dask_array():
"""
Imports and returns ``dask.array``.
"""
try:
import dask.array as da
except ModuleNotFoundError as err:
raise ModuleNotFoundError(
Expand All @@ -272,7 +289,7 @@ def dask():
conda install dask"""
) from err
else:
return dask, da
return da


def dask_awkward():
Expand Down
4 changes: 2 additions & 2 deletions tests/test_0652_dask-for-awkward.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def test_dask_concatenation():
[test_path1, test_path2, test_path3, test_path4], library="ak"
)

assert_eq(dak_array, ak_array, check_unconcat_form=False)
assert_eq(dak_array, ak_array)


def test_multidim_array():
Expand Down Expand Up @@ -87,4 +87,4 @@ def test_delay_open():
[test_path1, test_path2, test_path3, test_path4], open_files=False, library="ak"
)

assert_eq(dak_array, ak_array, check_unconcat_form=False)
assert_eq(dak_array, ak_array)

0 comments on commit 49d40bd

Please sign in to comment.