diff --git a/src/Pythonize.cxx b/src/Pythonize.cxx index 9dc78d2..9a06e98 100644 --- a/src/Pythonize.cxx +++ b/src/Pythonize.cxx @@ -52,6 +52,24 @@ bool HasAttrDirect(PyObject* pyclass, PyObject* pyname, bool mustBeCPyCppyy = fa return false; } +template +T Get_IndexValue(Py_buffer *view, int i) +{ + if (!view || !view->buf) + { + // Handle the error, e.g., throw an exception or return a default value + PyErr_SetString(PyExc_RuntimeError, "Buffer is not valid"); + } + + if (i < 0 || i >= view->len / view->itemsize) + { + // Handle the error, e.g., throw an exception or return a default value + PyErr_SetString(PyExc_IndexError, "Index out of range"); + } + // get the value at index i in the view + return *(T *)((char *)view->buf + i * view->strides[0]); +} + PyObject* GetAttrDirect(PyObject* pyclass, PyObject* pyname) { // get an attribute without causing getattr lookups PyObject* dct = PyObject_GetAttr(pyclass, PyStrings::gDict); @@ -459,7 +477,96 @@ PyObject* VectorIAdd(PyObject* self, PyObject* args, PyObject* /* kwds */) PyErr_SetString(PyExc_TypeError, "argument is not iterable"); return nullptr; // error already set } +PyObject *recursive_vector_init(PyObject *self, Py_buffer *view, PyObject *result, int ndim) +{ + // PyObject *tmp_result = PyList_New(0); + if (ndim == 1) + { + if (!result) + return nullptr; + + Py_ssize_t fillsz = view->len / view->itemsize; + PyObject *pb_call = PyObject_GetAttrString(self, "push_back"); + for (Py_ssize_t i = 0; i < fillsz; i++) + { + int val = Get_IndexValue(view, i); + PyObject *item = PyLong_FromLong(val); + + if (!item) + { + Py_DECREF(result); + Py_XDECREF(pb_call); + return nullptr; + } + + PyObject *pbres = PyObject_CallFunctionObjArgs(pb_call, item, nullptr); + Py_DECREF(item); + + if (!pbres) + { + Py_DECREF(result); + Py_XDECREF(pb_call); + return nullptr; + } + + Py_DECREF(pbres); + } + + Py_XDECREF(pb_call); + return result; + } + + if (!result) + return nullptr; + + Py_ssize_t *shape = (Py_ssize_t *)view->shape; + Py_ssize_t *strides = (Py_ssize_t *)view->strides; + Py_ssize_t *subshape = shape + 1; + Py_ssize_t *substrides = strides + 1; + + for (Py_ssize_t i = 0; i < view->ndim; i++) + { + Py_buffer subview; + subview.buf = (void *)((char *)view->buf + i * strides[0]); + subview.obj = NULL; + subview.len = subshape[0] * substrides[0]; + subview.readonly = view->readonly; + subview.itemsize = view->itemsize; + subview.format = view->format; + subview.ndim = ndim - 1; + subview.shape = subshape; + subview.strides = substrides; + subview.suboffsets = view->suboffsets; + subview.internal = view->internal; + + PyObject *subresult = recursive_vector_init(self, &subview, result, ndim - 1); + if (!subresult) + { + Py_DECREF(result); + return nullptr; + } + PyObject *pb_call = PyObject_GetAttrString(self, "push_back"); + + + PyObject *pbres = PyObject_CallFunctionObjArgs(pb_call, subresult, nullptr); + + if (!pbres) + { + Py_DECREF(result); + Py_XDECREF(pb_call); + return nullptr; + } + + + Py_DECREF(pbres); + Py_DECREF(pb_call); + + Py_DECREF(subresult); + } + + return result; +} PyObject* VectorInit(PyObject* self, PyObject* args, PyObject* /* kwds */) { @@ -490,10 +597,41 @@ PyObject* VectorInit(PyObject* self, PyObject* args, PyObject* /* kwds */) return result; } -// The given argument wasn't iterable: simply forward to regular constructor - PyObject* realInit = PyObject_GetAttr(self, PyStrings::gRealInit); - if (realInit) { - PyObject* result = PyObject_Call(realInit, args, nullptr); + // get the first argument + PyObject* fi = PyTuple_GET_ITEM(args, 0); + if (fi == Py_None) + { + // empty vector + return PyObject_CallMethodNoArgs(self, PyStrings::gRealInit); + } + + // check if numpy is passed + if (PyObject_CheckBuffer(fi)){ + // create a memoryview + PyObject* memoryview = PyMemoryView_FromObject(fi); + Py_buffer* view = PyMemoryView_GET_BUFFER(memoryview); + + // check if memoryview is valid + if (view->buf == NULL) return nullptr; + + PyObject *result = PyObject_CallMethodNoArgs(self, PyStrings::gRealInit); + + result = recursive_vector_init(self, view, result, view->ndim); + + // dereference the memoryview buffer + PyBuffer_Release(view); + Py_DECREF(memoryview); + Py_DECREF(fi); + + return result; + + } + + // The given argument wasn't iterable or a numpy array: simply forward to regular constructor + PyObject *realInit = PyObject_GetAttr(self, PyStrings::gRealInit); + if (realInit) + { + PyObject *result = PyObject_Call(realInit, args, nullptr); Py_DECREF(realInit); return result; }