Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: use nplike.asarray #2134

Merged
merged 12 commits into from
Jan 18, 2023
8 changes: 4 additions & 4 deletions src/awkward/_backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@
import awkward_cpp

import awkward as ak
from awkward._kernels import CupyKernel, JaxKernel, NumpyKernel
from awkward._kernels import CupyKernel, JaxKernel, NumpyKernel, TypeTracerKernel
jpivarski marked this conversation as resolved.
Show resolved Hide resolved
from awkward._nplikes import Cupy, Jax, Numpy, NumpyLike, NumpyMetadata, nplike_of
from awkward._singleton import Singleton
from awkward._typetracer import NoKernel, TypeTracer
from awkward._typetracer import TypeTracer
from awkward.typing import Callable, Final, Tuple, TypeAlias, TypeVar, Unpack

np = NumpyMetadata.instance()
Expand Down Expand Up @@ -160,8 +160,8 @@ def index_nplike(self) -> TypeTracer:
def __init__(self):
self._typetracer = TypeTracer.instance()

def __getitem__(self, index: KernelKeyType) -> NoKernel:
return NoKernel(index)
def __getitem__(self, index: KernelKeyType) -> TypeTracerKernel:
return TypeTracerKernel(index)


def _backend_for_nplike(nplike: ak._nplikes.NumpyLike) -> Backend:
Expand Down
2 changes: 1 addition & 1 deletion src/awkward/_broadcasting.py
Original file line number Diff line number Diff line change
Expand Up @@ -537,7 +537,7 @@ def continuation():

combos = backend.index_nplike.stack(tagslist, axis=-1)

all_combos = backend.index_nplike.array(
all_combos = backend.index_nplike.asarray(
list(itertools.product(*[range(x) for x in numtags])),
dtype=[(str(i), combos.dtype) for i in range(len(tagslist))],
)
Expand Down
10 changes: 7 additions & 3 deletions src/awkward/_connect/jax/reducers.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def apply(cls, array, parents, outlength):

if array.dtype.kind == "m":
return ak.contents.NumpyArray(
array.backend.nplike.asarray(result, array.dtype)
array.backend.nplike.asarray(result, dtype=array.dtype)
)
elif array.dtype.type in (np.complex128, np.complex64):
return ak.contents.NumpyArray(result.view(array.dtype))
Expand Down Expand Up @@ -184,7 +184,9 @@ def apply(cls, array, parents, outlength):

if array.dtype.type in (np.complex128, np.complex64):
return ak.contents.NumpyArray(
array.backend.nplike.array(result.view(array.dtype), array.dtype),
array.backend.nplike.asarray(
result.view(array.dtype), dtype=array.dtype
),
backend=array.backend,
)
else:
Expand Down Expand Up @@ -230,7 +232,9 @@ def apply(cls, array, parents, outlength):
result = jax.numpy.maximum(result, cls._max_initial(cls.initial, array.dtype))
if array.dtype.type in (np.complex128, np.complex64):
return ak.contents.NumpyArray(
array.backend.nplike.array(result.view(array.dtype), array.dtype),
array.backend.nplike.asarray(
result.view(array.dtype), dtype=array.dtype
),
backend=array.backend,
)
else:
Expand Down
7 changes: 4 additions & 3 deletions src/awkward/_connect/numba/arrayview.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import numba
import numba.core.typing
import numba.core.typing.ctypes_utils
import numpy

import awkward as ak

Expand Down Expand Up @@ -865,7 +866,7 @@ def array_supported(dtype):
) or isinstance(dtype, (numba.types.NPDatetime, numba.types.NPTimedelta))


@numba.extending.overload(ak._nplikes.numpy.array)
@numba.extending.overload(numpy.array)
def overload_np_array(array, dtype=None):
if isinstance(array, ArrayViewType):
ndim = array.type.ndim
Expand Down Expand Up @@ -934,7 +935,7 @@ def array_impl(array, dtype=None):
)


@numba.extending.type_callable(ak._nplikes.numpy.asarray)
@numba.extending.type_callable(numpy.asarray)
def type_asarray(context):
def typer(arrayview):
if (
Expand All @@ -948,7 +949,7 @@ def typer(arrayview):
return typer


@numba.extending.lower_builtin(ak._nplikes.numpy.asarray, ArrayViewType)
@numba.extending.lower_builtin(numpy.asarray, ArrayViewType)
def lower_asarray(context, builder, sig, args):
rettype, (viewtype,) = sig.return_type, sig.args
(viewval,) = args
Expand Down
4 changes: 1 addition & 3 deletions src/awkward/_connect/numba/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,11 @@
import numba
import numba.core.typing
import numba.core.typing.ctypes_utils
import numpy
from awkward_cpp import libawkward

import awkward as ak

numpy = ak._nplikes.Numpy.instance()


dynamic_addrs = {}


Expand Down
3 changes: 0 additions & 3 deletions src/awkward/_connect/numba/layout.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,6 @@

import awkward as ak

np = ak._nplikes.NumpyMetadata.instance()
numpy = ak._nplikes.Numpy.instance()


@numba.extending.typeof_impl.register(ak.contents.Content)
@numba.extending.typeof_impl.register(ak.index.Index)
Expand Down
2 changes: 1 addition & 1 deletion src/awkward/_lookup.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def arrayptr(x):
self.nplike = layout.backend.nplike
self.generator = generator
self.positions = positions
self.arrayptrs = self.nplike.array(
self.arrayptrs = self.nplike.asarray(
[arrayptr(x) for x in positions], dtype=np.intp
)

Expand Down
27 changes: 20 additions & 7 deletions src/awkward/_nplikes.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,13 +81,26 @@ class NumpyLike(Singleton):

############################ array creation

def array(self, *args, **kwargs):
# data[, dtype=[, copy=]]
return self._module.array(*args, **kwargs)

def asarray(self, *args, **kwargs):
# array[, dtype=][, order=]
return self._module.asarray(*args, **kwargs)
def asarray(
self,
obj,
*,
dtype: numpy.dtype | None = None,
copy: bool | None = None,
):
if copy:
return self._module.array(obj, dtype=dtype, copy=True)
elif copy is None:
return self._module.asarray(obj, dtype=dtype)
else:
if getattr(obj, "dtype", dtype) != dtype:
raise ak._errors.wrap_error(
ValueError(
"asarray was called with copy=False for an array of a different dtype"
)
)
else:
return self._module.asarray(obj, dtype=dtype)

def ascontiguousarray(self, *args, **kwargs):
# array[, dtype=]
Expand Down
18 changes: 11 additions & 7 deletions src/awkward/_reducers.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,7 +334,7 @@ def apply(self, array, parents, outlength):

if array.dtype.kind == "m":
return ak.contents.NumpyArray(
array.backend.nplike.asarray(result, array.dtype)
array.backend.nplike.asarray(result, dtype=array.dtype)
)
elif array.dtype.type in (np.complex128, np.complex64):
return ak.contents.NumpyArray(result.view(array.dtype))
Expand All @@ -348,7 +348,7 @@ def identity_for(self, dtype: np.dtype | None):
if dtype in {np.timedelta64, np.datetime64}:
return np.timedelta64(0)
else:
return numpy.array(0, dtype=dtype)[()]
return numpy.asarray(0, dtype=dtype)[()]


class Prod(Reducer):
Expand Down Expand Up @@ -435,7 +435,7 @@ def identity_for(self, dtype: np.dtype | None):
if dtype in {np.timedelta64, np.datetime64}:
return np.timedelta64(0)
else:
return numpy.array(1, dtype=dtype)[()]
return numpy.asarray(1, dtype=dtype)[()]


class Any(Reducer):
Expand Down Expand Up @@ -633,11 +633,13 @@ def apply(self, array, parents, outlength):
)
if array.dtype.type in (np.complex128, np.complex64):
return ak.contents.NumpyArray(
array.backend.nplike.array(result.view(array.dtype), array.dtype)
array.backend.nplike.asarray(
result.view(array.dtype), dtype=array.dtype
)
)
else:
return ak.contents.NumpyArray(
array.backend.nplike.array(result, array.dtype)
array.backend.nplike.asarray(result, dtype=array.dtype)
)


Expand Down Expand Up @@ -734,9 +736,11 @@ def apply(self, array, parents, outlength):
)
if array.dtype.type in (np.complex128, np.complex64):
return ak.contents.NumpyArray(
array.backend.nplike.array(result.view(array.dtype), array.dtype)
array.backend.nplike.asarray(
result.view(array.dtype), dtype=array.dtype
)
)
else:
return ak.contents.NumpyArray(
array.backend.nplike.array(result, array.dtype)
array.backend.nplike.asarray(result, dtype=array.dtype)
)
61 changes: 23 additions & 38 deletions src/awkward/_typetracer.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# BSD 3-Clause License; see https://github.com/scikit-hep/awkward-1.0/blob/main/LICENSE
from __future__ import annotations

import numbers

Expand All @@ -12,32 +13,6 @@
np = _nplikes.NumpyMetadata.instance()


class NoError:
def __init__(self):
self.str = None
self.filename = None
self.pass_through = False
self.attempt = ak._util.kSliceNone
self.id = ak._util.kSliceNone


class NoKernel:
def __init__(self, index):
self._name_and_types = index

def __call__(self, *args):
for x in args:
try_touch_data(x)
return NoError()

def __repr__(self):
return "<{} {}{}>".format(
type(self).__name__,
self._name_and_types[0],
"".join(", " + str(numpy.dtype(x)) for x in self._name_and_types[1:]),
)


class UnknownLengthType:
def __repr__(self):
return "UnknownLength"
Expand Down Expand Up @@ -712,9 +687,6 @@ def to_rectilinear(self, array, *args, **kwargs):
try_touch_shape(array)
raise ak._errors.wrap_error(NotImplementedError)

def __getitem__(self, name_and_types):
return NoKernel(name_and_types)

@property
def ma(self):
raise ak._errors.wrap_error(NotImplementedError)
Expand Down Expand Up @@ -747,15 +719,28 @@ def raw(self, array, nplike):

############################ array creation

def array(self, data, dtype=None, **kwargs):
# data[, dtype=[, copy=]]
try_touch_data(data)
return TypeTracerArray.from_array(data, dtype=dtype)

def asarray(self, array, dtype=None, **kwargs):
# array[, dtype=][, order=]
try_touch_data(array)
return TypeTracerArray.from_array(array, dtype=dtype)
def asarray(
self,
obj,
*,
dtype: numpy.dtype | None = None,
copy: bool | None = None,
):
try_touch_data(obj)
result = TypeTracerArray.from_array(obj, dtype=dtype)
# If we want a copy, by the dtypes don't match
if (
not (copy is None or copy)
and dtype is not None
and getattr(obj, "dtype", dtype) != dtype
):
raise ak._errors.wrap_error(
ValueError(
"asarray was called with copy=False for an array of a different dtype"
)
)
else:
return result

def ascontiguousarray(self, array, dtype=None, **kwargs):
# array[, dtype=]
Expand Down
8 changes: 4 additions & 4 deletions src/awkward/contents/indexedoptionarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -1416,7 +1416,7 @@ def _pad_none(self, target, axis, depth, clip):
)

def _to_arrow(self, pyarrow, mask_node, validbytes, length, options):
index = numpy.array(self._index, copy=True)
index = numpy.asarray(self._index, copy=True)
this_validbytes = self.mask_as_bool(valid_when=True)
index[~this_validbytes] = 0

Expand Down Expand Up @@ -1468,9 +1468,9 @@ def _to_backend_array(self, allow_missing, backend):
elif issubclass(content.dtype.type, np.integer):
data[mask0] = np.iinfo(content.dtype).max
elif issubclass(content.dtype.type, (np.datetime64, np.timedelta64)):
data[mask0] = nplike.array([np.iinfo(np.int64).max], content.dtype)[
0
]
data[mask0] = nplike.asarray(
[np.iinfo(np.int64).max], dtype=content.dtype
)[0]
else:
raise ak._errors.wrap_error(
AssertionError(f"unrecognized dtype: {content.dtype}")
Expand Down
8 changes: 4 additions & 4 deletions src/awkward/contents/listoffsetarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,7 @@ def _getitem_range(self, where):
offsets = self._offsets[start : stop + 1]
if offsets.length == 0:
offsets = Index(
self._backend.index_nplike.array([0], dtype=self._offsets.dtype),
self._backend.index_nplike.asarray([0], dtype=self._offsets.dtype),
nplike=self._backend.index_nplike,
)
return ListOffsetArray(offsets, self._content, parameters=self._parameters)
Expand Down Expand Up @@ -1862,8 +1862,8 @@ def _to_arrow(self, pyarrow, mask_node, validbytes, length, options):
nonzeros = npoffsets[1:] != npoffsets[:-1]
maskedbytes = validbytes == 0
if numpy.any(maskedbytes & nonzeros): # null and count > 0
new_starts = numpy.array(npoffsets[:-1], copy=True)
new_stops = numpy.array(npoffsets[1:], copy=True)
new_starts = numpy.asarray(npoffsets[:-1], copy=True)
new_stops = numpy.asarray(npoffsets[1:], copy=True)
new_starts[maskedbytes] = 0
new_stops[maskedbytes] = 0
next = ak.contents.ListArray(
Expand Down Expand Up @@ -1952,7 +1952,7 @@ def _to_arrow(self, pyarrow, mask_node, validbytes, length, options):
def _to_backend_array(self, allow_missing, backend):
array_param = self.parameter("__array__")
if array_param in {"bytestring", "string"}:
return backend.nplike.array(self.to_list())
return backend.nplike.asarray(self.to_list())

return self.to_RegularArray()._to_backend_array(allow_missing, backend)

Expand Down
8 changes: 4 additions & 4 deletions src/awkward/contents/numpyarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,8 +191,8 @@ def to_RegularArray(self):
def maybe_to_NumpyArray(self) -> Self:
return self

def __array__(self, *args, **kwargs):
return self._backend.nplike.asarray(self._data, *args, **kwargs)
def __array__(self, dtype=None):
return self._backend.nplike.asarray(self._data, dtype=dtype)

def __iter__(self):
return iter(self._data)
Expand Down Expand Up @@ -685,7 +685,7 @@ def _unique(self, negaxis, starts, parents, outlength):
)

return ak.contents.NumpyArray(
self._backend.nplike.asarray(out[: nextlength[0]], self.dtype),
self._backend.nplike.asarray(out[: nextlength[0]], dtype=self.dtype),
parameters=None,
backend=self._backend,
)
Expand Down Expand Up @@ -997,7 +997,7 @@ def _sort_next(self, negaxis, starts, parents, outlength, ascending, stable):
)
)
return ak.contents.NumpyArray(
self._backend.nplike.asarray(out, self.dtype),
self._backend.nplike.asarray(out, dtype=self.dtype),
parameters=None,
backend=self._backend,
)
Expand Down
2 changes: 1 addition & 1 deletion src/awkward/contents/regulararray.py
Original file line number Diff line number Diff line change
Expand Up @@ -1141,7 +1141,7 @@ def _pad_none(self, target, axis, depth, clip):
def _to_backend_array(self, allow_missing, backend):
array_param = self.parameter("__array__")
if array_param in {"bytestring", "string"}:
return backend.nplike.array(self.to_list())
return backend.nplike.asarray(self.to_list())

out = self._content._to_backend_array(allow_missing, backend)
shape = (self._length, self._size) + out.shape[1:]
Expand Down
Loading