Skip to content

Commit

Permalink
Fixing sort allocation (#764)
Browse files Browse the repository at this point in the history
  • Loading branch information
cliffburdick authored Oct 14, 2024
1 parent 690d4ab commit ab9e7a1
Showing 1 changed file with 17 additions and 10 deletions.
27 changes: 17 additions & 10 deletions include/matx/transforms/cub.h
Original file line number Diff line number Diff line change
Expand Up @@ -1418,20 +1418,27 @@ void sort_impl(OutputTensor &a_out, const InputOperator &a,

cudaStream_t stream = exec.getStream();

using a_type = typename InputOperator::value_type;
a_type *out_ptr = nullptr;
detail::tensor_impl_t<a_type, InputOperator::Rank()> tmp_in;

// sorting currently requires a contiguous tensor view, so allocate a temporary
// tensor to copy the input if necessary.
auto a_contig = [&a, &stream]() -> auto {
if constexpr (is_tensor_view_v<InputOperator>) {
if (a.IsContiguous()) {
return a;
}
bool done = false;
if constexpr (is_tensor_view_v<InputOperator>) {
if (a.IsContiguous()) {
make_tensor(tmp_in, a.Data(), a.Shape());
done = true;
}
auto a_tmp = make_tensor<typename InputOperator::value_type>(a.Shape(), MATX_ASYNC_DEVICE_MEMORY, stream);
(a_tmp = a).run(stream);
return a_tmp;
}();
}

if (!done) {
matxAlloc((void**)&out_ptr, TotalSize(a) * sizeof(a_type), MATX_ASYNC_DEVICE_MEMORY, exec.getStream());
make_tensor(tmp_in, out_ptr, a.Shape());
(tmp_in = a).run(exec);
}

detail::sort_impl_inner(a_out, a_contig, dir, exec);
detail::sort_impl_inner(a_out, tmp_in, dir, exec);
#endif
}

Expand Down

0 comments on commit ab9e7a1

Please sign in to comment.