Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add zero-copy interface from MatX to NumPy #653

Merged
merged 1 commit into from
Jun 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 56 additions & 32 deletions include/matx/core/pybind.h
Original file line number Diff line number Diff line change
Expand Up @@ -397,50 +397,74 @@ class MatXPybind {
}

template <typename TensorType>
auto TensorViewToNumpy(const TensorType &ten)
{
auto TensorViewToNumpy(const TensorType &ten) {
using tensor_type = typename TensorType::scalar_type;
using ntype = matx_convert_complex_type<tensor_type>;
constexpr int RANK = TensorType::Rank();
static_assert(RANK <=5, "TensorViewToNumpy only supports max(RANK) = 5 at the moment.");

using ntype = matx_convert_complex_type<typename TensorType::scalar_type>;
auto ften = pybind11::array_t<ntype>(ten.Shape());

for (index_t s1 = 0; s1 < ten.Size(0); s1++) {
if constexpr (RANK > 1) {
for (index_t s2 = 0; s2 < ten.Size(1); s2++) {
if constexpr (RANK > 2) {
for (index_t s3 = 0; s3 < ten.Size(2); s3++) {
if constexpr (RANK > 3) {
for (index_t s4 = 0; s4 < ten.Size(3); s4++) {
if constexpr (RANK > 4) {
for (index_t s5 = 0; s5 < ten.Size(4); s5++) {
ften.mutable_at(s1, s2, s3, s4, s5) =
ConvertComplex(ten(s1, s2, s3, s4, s5));

// If this is a half-precision type pybind/numpy doesn't support it, so we fall back to the
// slow method where we convert everything
if constexpr (is_matx_type<tensor_type>()) {
auto ften = pybind11::array_t<ntype, pybind11::array::c_style | pybind11::array::forcecast>(ten.Shape());

for (index_t s1 = 0; s1 < ten.Size(0); s1++) {
if constexpr (RANK > 1) {
for (index_t s2 = 0; s2 < ten.Size(1); s2++) {
if constexpr (RANK > 2) {
for (index_t s3 = 0; s3 < ten.Size(2); s3++) {
if constexpr (RANK > 3) {
for (index_t s4 = 0; s4 < ten.Size(3); s4++) {
if constexpr (RANK > 4) {
for (index_t s5 = 0; s5 < ten.Size(4); s5++) {
ften.mutable_at(s1, s2, s3, s4, s5) =
ConvertComplex(ten(s1, s2, s3, s4, s5));
}
} else {
ften.mutable_at(s1, s2, s3, s4) =
ConvertComplex(ten(s1, s2, s3, s4));
}
} else {
ften.mutable_at(s1, s2, s3, s4) =
ConvertComplex(ten(s1, s2, s3, s4));
}
}
}
else {
ften.mutable_at(s1, s2, s3) = ConvertComplex(ten(s1, s2, s3));
else {
ften.mutable_at(s1, s2, s3) = ConvertComplex(ten(s1, s2, s3));
}
}
}
}
else {
ften.mutable_at(s1, s2) = ConvertComplex(ten(s1, s2));
else {
ften.mutable_at(s1, s2) = ConvertComplex(ten(s1, s2));
}
}
}
else {
ften.mutable_at(s1) = ConvertComplex(ten(s1));
}
}
else {
ften.mutable_at(s1) = ConvertComplex(ten(s1));
}
}

return ften;
return ften;
}
else {
const auto tshape = ten.Shape();
const auto tstrides = ten.Strides();
std::vector<pybind11::ssize_t> shape{tshape.begin(), tshape.end()};
std::vector<pybind11::ssize_t> strides{tstrides.begin(), tstrides.end()};
std::for_each(strides.begin(), strides.end(), [](pybind11::ssize_t &x) {
x *= sizeof(tensor_type);
});

auto buf = pybind11::buffer_info(
ten.Data(),
sizeof(tensor_type),
pybind11::format_descriptor<ntype>::format(),
RANK,
shape,
strides
);

return pybind11::array_t<ntype, pybind11::array::c_style | pybind11::array::forcecast>(buf);
}
}


template <typename TensorType,
typename CT = matx_convert_cuda_complex_type<typename TensorType::scalar_type>>
std::optional<TestFailResult<CT>>
Expand Down
8 changes: 8 additions & 0 deletions include/matx/core/tensor_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -616,6 +616,14 @@ class tensor_impl_t {
*/
__MATX_INLINE__ auto Shape() const noexcept { return this->desc_.Shape(); }

/**
* Get the strides the tensor from the underlying data
*
* @return
* A shape of the data with the appropriate strides set
*/
__MATX_INLINE__ auto Strides() const noexcept { return this->desc_.Strides(); }

/**
* Set the size of a dimension
*
Expand Down