Skip to content

Commit

Permalink
Add zero-copy interface from MatX to NumPy (#653)
Browse files Browse the repository at this point in the history
  • Loading branch information
cliffburdick authored Jun 24, 2024
1 parent 7d1debb commit ed09e1c
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 32 deletions.
88 changes: 56 additions & 32 deletions include/matx/core/pybind.h
Original file line number Diff line number Diff line change
Expand Up @@ -397,50 +397,74 @@ class MatXPybind {
}

template <typename TensorType>
auto TensorViewToNumpy(const TensorType &ten)
{
auto TensorViewToNumpy(const TensorType &ten) {
using tensor_type = typename TensorType::scalar_type;
using ntype = matx_convert_complex_type<tensor_type>;
constexpr int RANK = TensorType::Rank();
static_assert(RANK <=5, "TensorViewToNumpy only supports max(RANK) = 5 at the moment.");

using ntype = matx_convert_complex_type<typename TensorType::scalar_type>;
auto ften = pybind11::array_t<ntype>(ten.Shape());

for (index_t s1 = 0; s1 < ten.Size(0); s1++) {
if constexpr (RANK > 1) {
for (index_t s2 = 0; s2 < ten.Size(1); s2++) {
if constexpr (RANK > 2) {
for (index_t s3 = 0; s3 < ten.Size(2); s3++) {
if constexpr (RANK > 3) {
for (index_t s4 = 0; s4 < ten.Size(3); s4++) {
if constexpr (RANK > 4) {
for (index_t s5 = 0; s5 < ten.Size(4); s5++) {
ften.mutable_at(s1, s2, s3, s4, s5) =
ConvertComplex(ten(s1, s2, s3, s4, s5));

// If this is a half-precision type pybind/numpy doesn't support it, so we fall back to the
// slow method where we convert everything
if constexpr (is_matx_type<tensor_type>()) {
auto ften = pybind11::array_t<ntype, pybind11::array::c_style | pybind11::array::forcecast>(ten.Shape());

for (index_t s1 = 0; s1 < ten.Size(0); s1++) {
if constexpr (RANK > 1) {
for (index_t s2 = 0; s2 < ten.Size(1); s2++) {
if constexpr (RANK > 2) {
for (index_t s3 = 0; s3 < ten.Size(2); s3++) {
if constexpr (RANK > 3) {
for (index_t s4 = 0; s4 < ten.Size(3); s4++) {
if constexpr (RANK > 4) {
for (index_t s5 = 0; s5 < ten.Size(4); s5++) {
ften.mutable_at(s1, s2, s3, s4, s5) =
ConvertComplex(ten(s1, s2, s3, s4, s5));
}
} else {
ften.mutable_at(s1, s2, s3, s4) =
ConvertComplex(ten(s1, s2, s3, s4));
}
} else {
ften.mutable_at(s1, s2, s3, s4) =
ConvertComplex(ten(s1, s2, s3, s4));
}
}
}
else {
ften.mutable_at(s1, s2, s3) = ConvertComplex(ten(s1, s2, s3));
else {
ften.mutable_at(s1, s2, s3) = ConvertComplex(ten(s1, s2, s3));
}
}
}
}
else {
ften.mutable_at(s1, s2) = ConvertComplex(ten(s1, s2));
else {
ften.mutable_at(s1, s2) = ConvertComplex(ten(s1, s2));
}
}
}
else {
ften.mutable_at(s1) = ConvertComplex(ten(s1));
}
}
else {
ften.mutable_at(s1) = ConvertComplex(ten(s1));
}
}

return ften;
return ften;
}
else {
const auto tshape = ten.Shape();
const auto tstrides = ten.Strides();
std::vector<pybind11::ssize_t> shape{tshape.begin(), tshape.end()};
std::vector<pybind11::ssize_t> strides{tstrides.begin(), tstrides.end()};
std::for_each(strides.begin(), strides.end(), [](pybind11::ssize_t &x) {
x *= sizeof(tensor_type);
});

auto buf = pybind11::buffer_info(
ten.Data(),
sizeof(tensor_type),
pybind11::format_descriptor<ntype>::format(),
RANK,
shape,
strides
);

return pybind11::array_t<ntype, pybind11::array::c_style | pybind11::array::forcecast>(buf);
}
}


template <typename TensorType,
typename CT = matx_convert_cuda_complex_type<typename TensorType::scalar_type>>
std::optional<TestFailResult<CT>>
Expand Down
8 changes: 8 additions & 0 deletions include/matx/core/tensor_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -616,6 +616,14 @@ class tensor_impl_t {
*/
__MATX_INLINE__ auto Shape() const noexcept { return this->desc_.Shape(); }

/**
* Get the strides the tensor from the underlying data
*
* @return
* A shape of the data with the appropriate strides set
*/
__MATX_INLINE__ auto Strides() const noexcept { return this->desc_.Strides(); }

/**
* Set the size of a dimension
*
Expand Down

0 comments on commit ed09e1c

Please sign in to comment.