Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(add): numpy ndarray--template-less init #255

Open
wants to merge 8 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 46 additions & 0 deletions python/cppyy/_cpython_cppyy.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,14 @@ def __call__(self, *args):
# most common cases are covered
if args:
args0 = args[0]

if (
type(args0).__module__ == "numpy"
and type(args0).__name__ == "ndarray"
and hasattr(args0, "dtype")
):
# Handle arrays of arbitrary dimension recursively
return _np_vector(args0)
if args0 and (type(args0) is tuple or type(args0) is list):
t = type(args0[0])
if t is float: t = 'double'
Expand Down Expand Up @@ -209,3 +217,41 @@ def _end_capture_stderr():
pass
return "C++ issued an error message that could not be decoded (%s)" % str(original_error)
return ""

def _np_vector(arr):
CPP_EXPLICIT_TYPES = {"float64": "double", "int64": "long"}

def build_nested_vector_type(ndim, base_type, cache={}):
key = (ndim, base_type)
if key not in cache:
vector_t = gbl.std.vector[base_type]
for _ in range(ndim - 1):
vector_t = gbl.std.vector[vector_t]
cache[key] = vector_t
return cache[key]

def convert(arr):
ndim = arr.ndim
if arr.size > 0:
base_type = CPP_EXPLICIT_TYPES.get(
arr.dtype.type.__name__, type(arr.flat[0].item())
)
else:
base_type = float

if ndim == 1:
vector = build_nested_vector_type(1, base_type)()
vector.reserve(arr.size)
for elem in arr.flat:
vector.push_back(elem.item())
return vector

vector_type = build_nested_vector_type(ndim, base_type)
result = vector_type()
result.reserve(arr.shape[0])
for subarr in arr:
result.push_back(convert(subarr))

return result

return convert(arr)
39 changes: 39 additions & 0 deletions test/test_stltypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -789,6 +789,45 @@ def test23_copy_conversion(self):
for f, d in zip(x, v):
assert f == d

def test_ndarray_template_less(self):
import cppyy

try:
import numpy as np
except ImportError:
self.skipTest("numpy is not installed")
dtype_mappings = {
np.int32: "int",
np.int64: "long",
np.float32: "float",
np.float64: "double",
}

shapes = [
(10,), # 1D array
(5, 5), # 2D array
(4, 4, 4), # 3D array
(2, 3, 3, 3), # 4D array
]

for np_dtype, cpp_dtype in dtype_mappings.items():
for shape in shapes:
rng = np.random.default_rng(seed=42)

if np.issubdtype(np_dtype, np.integer):
x = rng.integers(low=0, high=100, size=shape, dtype=np_dtype)
else:
x = rng.random(size=shape).astype(np_dtype)

cpp_vector = cppyy.gbl.std.vector(x)
assert len(cpp_vector) == shape[0]

if len(shape) > 1:
assert len(cpp_vector[0]) == shape[1]
if len(shape) > 2:
assert len(cpp_vector[0][0]) == shape[2]
if len(shape) > 3:
assert len(cpp_vector[0][0][0]) == shape[3]

class TestSTLSTRING:
def setup_class(cls):
Expand Down