diff --git a/include/matx/core/allocator.h b/include/matx/core/allocator.h index 0efd59df..771caa58 100644 --- a/include/matx/core/allocator.h +++ b/include/matx/core/allocator.h @@ -43,6 +43,8 @@ #include "matx/core/error.h" #include "matx/core/nvtx.h" +#include +#include #pragma once @@ -203,7 +205,7 @@ struct MemTracker { [[maybe_unused]] std::unique_lock lck(memory_mtx); matxMemoryStats.currentBytesAllocated += bytes; matxMemoryStats.totalBytesAllocated += bytes; - matxMemoryStats.maxBytesAllocated = std::max( + matxMemoryStats.maxBytesAllocated = cuda::std::max( matxMemoryStats.maxBytesAllocated, matxMemoryStats.currentBytesAllocated); allocationMap[*ptr] = {bytes, space, stream}; } diff --git a/include/matx/core/half_complex.h b/include/matx/core/half_complex.h index cf31a70d..a5f0383a 100644 --- a/include/matx/core/half_complex.h +++ b/include/matx/core/half_complex.h @@ -147,26 +147,6 @@ template struct alignas(sizeof(T) * 2) matxHalfComplex { return {x, y}; } - /** - * @brief std::complex cast operator - * - * @return std::complex value - */ - __MATX_HOST__ __MATX_DEVICE__ __MATX_INLINE__ operator std::complex() - { - return {x, y}; - } - - /** - * @brief std::complex cast operator - * - * @return std::complex value - */ - __MATX_HOST__ __MATX_DEVICE__ __MATX_INLINE__ operator std::complex() - { - return {x, y}; - } - /** * @brief Copy assignment operator * diff --git a/include/matx/core/pybind.h b/include/matx/core/pybind.h index 4e8dc504..0a42d6b2 100644 --- a/include/matx/core/pybind.h +++ b/include/matx/core/pybind.h @@ -437,12 +437,14 @@ class MatXPybind { } template > + typename CT = matx_convert_cuda_complex_type> std::optional> CompareOutput(const TensorType &ten, const std::string fname, double thresh, bool debug = false) { - using ntype = matx_convert_complex_type; + using raw_type = typename TensorType::scalar_type; + using ntype = matx_convert_complex_type; + using ctype = matx_convert_cuda_complex_type; auto resobj = res_dict[fname.c_str()]; auto ften = pybind11::array_t(resobj); constexpr int RANK = TensorType::Rank(); @@ -453,7 +455,7 @@ class MatXPybind { auto file_val = ften.at(); auto ten_val = ConvertComplex(ten()); if (!CompareVals(ten_val, file_val, thresh, fname, debug)) { - return TestFailResult{Index2Str(0), "0", ten_val, file_val, + return TestFailResult{Index2Str(0), "0", ten_val, file_val, thresh}; } } @@ -468,7 +470,7 @@ class MatXPybind { auto file_val = ften.at(s1, s2, s3, s4); auto ten_val = ConvertComplex(ten(s1, s2, s3, s4)); if (!CompareVals(ten_val, file_val, thresh, fname, debug)) { - return TestFailResult{Index2Str(s1, s2, s3, s4), + return TestFailResult{Index2Str(s1, s2, s3, s4), fname, ten_val, file_val, thresh}; } @@ -478,7 +480,7 @@ class MatXPybind { auto file_val = ften.at(s1, s2, s3); auto ten_val = ConvertComplex(ten(s1, s2, s3)); if (!CompareVals(ten_val, file_val, thresh, fname, debug)) { - return TestFailResult{Index2Str(s1, s2, s3), fname, + return TestFailResult{Index2Str(s1, s2, s3), fname, ten_val, file_val, thresh}; } } @@ -488,7 +490,7 @@ class MatXPybind { auto file_val = ften.at(s1, s2); auto ten_val = ConvertComplex(ten(s1, s2)); if (!CompareVals(ten_val, file_val, thresh, fname, debug)) { - return TestFailResult{Index2Str(s1, s2), fname, ten_val, + return TestFailResult{Index2Str(s1, s2), fname, ten_val, file_val, thresh}; } } @@ -498,7 +500,7 @@ class MatXPybind { auto file_val = ften.at(s1); auto ten_val = ConvertComplex(ten(s1)); if (!CompareVals(ten_val, file_val, thresh, fname, debug)) { - return TestFailResult{Index2Str(s1), fname, ten_val, + return TestFailResult{Index2Str(s1), fname, ten_val, file_val, thresh}; } } diff --git a/include/matx/core/tensor_desc.h b/include/matx/core/tensor_desc.h index d588b3b2..d8b28923 100644 --- a/include/matx/core/tensor_desc.h +++ b/include/matx/core/tensor_desc.h @@ -373,7 +373,7 @@ class static_tensor_desc_t { * @param dim Dimension to retrieve * @return Size of dimension */ - static constexpr auto Size(int dim) { return shape_[dim]; } + static constexpr __MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ auto Size(int dim) { return shape_[dim]; } /** * @brief Get stride of dimension @@ -381,7 +381,7 @@ class static_tensor_desc_t { * @param dim Dimension to retrieve * @return Stride of dimension */ - static constexpr auto Stride(int dim) { return stride_[dim]; } + static constexpr __MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ auto Stride(int dim) { return stride_[dim]; } /** * @brief Return strides contaienr of descriptor diff --git a/include/matx/core/tensor_utils.h b/include/matx/core/tensor_utils.h index 3accf8cc..97e81728 100644 --- a/include/matx/core/tensor_utils.h +++ b/include/matx/core/tensor_utils.h @@ -99,7 +99,7 @@ namespace matx for (int i = 1; i < op.Rank(); i++) { - maxSize = std::max(op.Size(i), maxSize); + maxSize = cuda::std::max(op.Size(i), maxSize); } return maxSize; diff --git a/include/matx/core/type_utils.h b/include/matx/core/type_utils.h index b4cde9f1..78e7e397 100644 --- a/include/matx/core/type_utils.h +++ b/include/matx/core/type_utils.h @@ -800,11 +800,22 @@ struct complex_type_of typename C::value_type>>> { }; +template +struct cuda_complex_type_of + : identity, float, + typename C::value_type>>> { +}; + template using matx_convert_complex_type = typename std::conditional_t, identity, complex_type_of>::type; +template +using matx_convert_cuda_complex_type = + typename std::conditional_t, identity, + cuda_complex_type_of>::type; + template struct value_type { using type = T; diff --git a/include/matx/kernels/channelize_poly.cuh b/include/matx/kernels/channelize_poly.cuh index 055760fa..e195c6db 100644 --- a/include/matx/kernels/channelize_poly.cuh +++ b/include/matx/kernels/channelize_poly.cuh @@ -80,7 +80,7 @@ __global__ void ChannelizePoly1D(OutType output, InType input, FilterType filter constexpr index_t ELEMS_PER_BLOCK = CHANNELIZE_POLY1D_ELEMS_PER_THREAD * THREADS; const index_t first_out_elem = elem_block * CHANNELIZE_POLY1D_ELEMS_PER_THREAD * THREADS; - const index_t last_out_elem = std::min( + const index_t last_out_elem = cuda::std::min( output_len_per_channel - 1, first_out_elem + ELEMS_PER_BLOCK - 1); if (filter_phase_len <= SMEM_MAX_FILTER_TAPS) { @@ -103,7 +103,7 @@ __global__ void ChannelizePoly1D(OutType output, InType input, FilterType filter if (filter_phase_len <= SMEM_MAX_FILTER_TAPS) { for (index_t t = first_out_elem+tid; t <= last_out_elem; t += THREADS) { - const index_t first_ind = std::max(static_cast(0), t - filter_phase_len + 1); + const index_t first_ind = cuda::std::max(static_cast(0), t - filter_phase_len + 1); output_t accum {}; const filter_t *h = smem_filter; // index_t in MatX should be signed (32 or 64 bit), so j-- below will not underflow @@ -134,7 +134,7 @@ __global__ void ChannelizePoly1D(OutType output, InType input, FilterType filter } } else { for (index_t t = first_out_elem+tid; t <= last_out_elem; t += THREADS) { - index_t first_ind = std::max(static_cast(0), t - filter_phase_len + 1); + index_t first_ind = cuda::std::max(static_cast(0), t - filter_phase_len + 1); // If we use the last filter tap for this phase (which is the first index because // the filter is flipped), then it may be a padded zero. If so, increment first_ind // by 1 to avoid using the zero. This prevents a bounds-check in the inner loop. @@ -227,7 +227,7 @@ __global__ void ChannelizePoly1D_Smem(OutType output, InType input, FilterType f const uint32_t smem_input_height = filter_phase_len + by - 1; const index_t start_elem = blockIdx.x * elems_per_channel_per_cta; - const index_t last_elem = std::min(output_len_per_channel-1, (blockIdx.x+1) * elems_per_channel_per_cta - 1); + const index_t last_elem = cuda::std::min(output_len_per_channel-1, (blockIdx.x+1) * elems_per_channel_per_cta - 1); auto indims = BlockToIdx(input, blockIdx.z, 1); auto outdims = BlockToIdx(output, blockIdx.z, 2); outdims[ChannelRank] = chan; @@ -256,7 +256,7 @@ __global__ void ChannelizePoly1D_Smem(OutType output, InType input, FilterType f __syncthreads(); // Load next elems_per_channel_per_cta elements for each channel - const index_t next_last_elem = std::min(next_start_elem + by - 1, last_elem); + const index_t next_last_elem = cuda::std::min(next_start_elem + by - 1, last_elem); const uint32_t out_samples_this_iter = static_cast(next_last_elem - next_start_elem + 1); if (ty < out_samples_this_iter) { indims[InRank-1] = (next_start_elem + ty) * num_channels + chan; @@ -286,7 +286,7 @@ __global__ void ChannelizePoly1D_Smem(OutType output, InType input, FilterType f if (outdims[OutElemRank] <= last_elem) { const filter_t *h = h_start; output_t accum { 0 }; - const int first_end = std::min(cached_input_ind_tail + filter_phase_len - 1, smem_input_height - 1); + const int first_end = cuda::std::min(cached_input_ind_tail + filter_phase_len - 1, smem_input_height - 1); // The footprint of samples involved in the convolution may wrap from the end // to the beginning of smem_input. The prologue below handles the samples from // the current tail to the end of smem_input and the epilogue starts back at the @@ -342,7 +342,7 @@ __global__ void ChannelizePoly1D_FusedChan(OutType output, InType input, FilterT constexpr index_t ELEMS_PER_BLOCK = CHANNELIZE_POLY1D_ELEMS_PER_THREAD * THREADS; const index_t first_out_elem = elem_block * CHANNELIZE_POLY1D_ELEMS_PER_THREAD * THREADS; - const index_t last_out_elem = std::min( + const index_t last_out_elem = cuda::std::min( output_len_per_channel - 1, first_out_elem + ELEMS_PER_BLOCK - 1); // Pre-compute the DFT complex exponentials and store in shared memory @@ -371,7 +371,7 @@ __global__ void ChannelizePoly1D_FusedChan(OutType output, InType input, FilterT for (int i = 0; i < NUM_CHAN; i++) { accum[i] = static_cast(0); } - index_t first_ind = std::max(static_cast(0), t - filter_phase_len + 1); + index_t first_ind = cuda::std::max(static_cast(0), t - filter_phase_len + 1); indims[InRank-1] = t * NUM_CHAN + NUM_CHAN - 1; index_t j_start = t; index_t h_ind { 0 }; diff --git a/include/matx/kernels/filter.cuh b/include/matx/kernels/filter.cuh index 409d83a4..fcd41cca 100644 --- a/include/matx/kernels/filter.cuh +++ b/include/matx/kernels/filter.cuh @@ -23,7 +23,7 @@ #define COMPLEX_TYPE cuComplex -// std::max/min isn't working on template value parameters +// cuda::std::max/min isn't working on template value parameters #define MAX(a, b) ((a) < (b) ? (b) : (a)) #define MIN(a, b) ((a) < (b) ? (a) : (b)) diff --git a/include/matx/kernels/resample_poly.cuh b/include/matx/kernels/resample_poly.cuh index 15b712aa..5c306d84 100644 --- a/include/matx/kernels/resample_poly.cuh +++ b/include/matx/kernels/resample_poly.cuh @@ -189,7 +189,7 @@ __global__ void ResamplePoly1D_PhaseBlock(OutType output, InType input, FilterTy const index_t max_input_ind = input_len - 1; const index_t start_ind = phase_ind + up * (tid + elem_block * elems_per_thread * THREADS); - const index_t last_ind = std::min(output_len - 1, start_ind + elems_per_thread * THREADS * up); + const index_t last_ind = cuda::std::min(output_len - 1, start_ind + elems_per_thread * THREADS * up); for (index_t out_ind = start_ind; out_ind <= last_ind; out_ind += THREADS * up) { // out_ind is the index in the output array and up_ind = out_ind * down is the // corresponding index in the upsampled array @@ -203,9 +203,9 @@ __global__ void ResamplePoly1D_PhaseBlock(OutType output, InType input, FilterTy // of valid samples before input_ind. In the case that the filter is not // long enough to include input_ind, last_filter_ind is left_filter_ind - up // and thus left_h_ind and prologue are both -1. - const index_t prologue = std::min(input_ind, left_h_ind); + const index_t prologue = cuda::std::min(input_ind, left_h_ind); // epilogue is the number of valid samples after input_ind. - const index_t epilogue = std::min(max_input_ind - input_ind, max_h_epilogue); + const index_t epilogue = cuda::std::min(max_input_ind - input_ind, max_h_epilogue); // n is the number of valid samples. If input_ind is not valid because it // precedes the reach of the filter, then prologue = -1 and n is just the // epilogue. @@ -302,12 +302,12 @@ __global__ void ResamplePoly1D_ElemBlock(OutType output, InType input, FilterTyp // whether or not the filter has been loaded to shared memory. const index_t filter_central_tap = (filter_len-1)/2; const index_t start_ind = elem_block * elems_per_thread * THREADS + tid; - const index_t last_ind = std::min(output_len - 1, start_ind + (elems_per_thread-1) * THREADS); + const index_t last_ind = cuda::std::min(output_len - 1, start_ind + (elems_per_thread-1) * THREADS); if (load_filter_to_smem) { for (index_t out_ind = start_ind; out_ind <= last_ind; out_ind += THREADS) { const index_t up_ind = out_ind * down; - const index_t up_start = std::max(static_cast(0), up_ind - filter_len_half); - const index_t up_end = std::min(max_input_ind * up, up_ind + filter_len_half); + const index_t up_start = cuda::std::max(static_cast(0), up_ind - filter_len_half); + const index_t up_end = cuda::std::min(max_input_ind * up, up_ind + filter_len_half); const index_t x_start = (up_start + up - 1) / up; index_t x_end = up_end / up; // Since the filter is in shared memory, we can narrow the index type to 32 bits @@ -333,8 +333,8 @@ __global__ void ResamplePoly1D_ElemBlock(OutType output, InType input, FilterTyp } else { for (index_t out_ind = start_ind; out_ind <= last_ind; out_ind += THREADS) { const index_t up_ind = out_ind * down; - const index_t up_start = std::max(static_cast(0), up_ind - filter_len_half); - const index_t up_end = std::min(max_input_ind * up, up_ind + filter_len_half); + const index_t up_start = cuda::std::max(static_cast(0), up_ind - filter_len_half); + const index_t up_end = cuda::std::min(max_input_ind * up, up_ind + filter_len_half); const index_t x_start = (up_start + up - 1) / up; index_t x_end = up_end / up; index_t h_ind = filter_central_tap + (up_ind - up*x_start); @@ -409,12 +409,12 @@ __global__ void ResamplePoly1D_WarpCentric(OutType output, InType input, FilterT const index_t filter_len_half = filter_len/2; const index_t filter_central_tap = (filter_len-1)/2; const index_t start_ind = elem_block * elems_per_warp * NUM_WARPS; - const index_t last_ind = std::min(output_len - 1, start_ind + elems_per_warp * NUM_WARPS - 1); + const index_t last_ind = cuda::std::min(output_len - 1, start_ind + elems_per_warp * NUM_WARPS - 1); if (load_filter_to_smem) { for (index_t out_ind = start_ind+warp_id; out_ind <= last_ind; out_ind += NUM_WARPS) { const index_t up_ind = out_ind * down; - const index_t up_start = std::max(static_cast(0), up_ind - filter_len_half); - const index_t up_end = std::min(max_input_ind * up, up_ind + filter_len_half); + const index_t up_start = cuda::std::max(static_cast(0), up_ind - filter_len_half); + const index_t up_end = cuda::std::min(max_input_ind * up, up_ind + filter_len_half); const index_t x_start = (up_start + up - 1) / up; index_t x_end = up_end / up; // Since the filter is in shared memory, we can narrow the index type to 32 bits @@ -449,8 +449,8 @@ __global__ void ResamplePoly1D_WarpCentric(OutType output, InType input, FilterT } else { for (index_t out_ind = start_ind+warp_id; out_ind <= last_ind; out_ind += NUM_WARPS) { const index_t up_ind = out_ind * down; - const index_t up_start = std::max(static_cast(0), up_ind - filter_len_half); - const index_t up_end = std::min(max_input_ind * up, up_ind + filter_len_half); + const index_t up_start = cuda::std::max(static_cast(0), up_ind - filter_len_half); + const index_t up_end = cuda::std::min(max_input_ind * up, up_ind + filter_len_half); const index_t x_start = (up_start + up - 1) / up; index_t x_end = up_end / up; index_t h_ind = filter_central_tap + (up_ind - up*x_start); diff --git a/include/matx/operators/conv.h b/include/matx/operators/conv.h index 575e3915..800c69c3 100644 --- a/include/matx/operators/conv.h +++ b/include/matx/operators/conv.h @@ -46,7 +46,7 @@ namespace matx private: using out_t = std::conditional_t, typename OpA::scalar_type, typename OpB::scalar_type>; - constexpr static int max_rank = std::max(OpA::Rank(), OpB::Rank()); + constexpr static int max_rank = cuda::std::max(OpA::Rank(), OpB::Rank()); OpA a_; OpB b_; matxConvCorrMode_t mode_; @@ -82,8 +82,8 @@ namespace matx for (int r = 0; r < Rank(); r++) { const int axis = perm[r]; if (axis == Rank() - 1) { - max_axis = std::max(a_.Size(r), b_.Size(r)); - min_axis = std::min(a_.Size(r), b_.Size(r)); + max_axis = cuda::std::max(a_.Size(r), b_.Size(r)); + min_axis = cuda::std::min(a_.Size(r), b_.Size(r)); if (mode_ == MATX_C_MODE_FULL) { out_dims_[axis] = a_.Size(r) + b_.Size(r) - 1; @@ -112,8 +112,8 @@ namespace matx } } - max_axis = std::max(a_.Size(OpA::Rank()-1), b_.Size(OpB::Rank()-1)); - min_axis = std::min(a_.Size(OpA::Rank()-1), b_.Size(OpB::Rank()-1)); + max_axis = cuda::std::max(a_.Size(OpA::Rank()-1), b_.Size(OpB::Rank()-1)); + min_axis = cuda::std::min(a_.Size(OpA::Rank()-1), b_.Size(OpB::Rank()-1)); if (mode_ == MATX_C_MODE_FULL) { out_dims_[max_rank-1] = max_axis + min_axis - 1; @@ -231,7 +231,7 @@ namespace detail { private: using out_t = std::conditional_t, typename OpA::scalar_type, typename OpB::scalar_type>; - constexpr static int max_rank = std::max(OpA::Rank(), OpB::Rank()); + constexpr static int max_rank = cuda::std::max(OpA::Rank(), OpB::Rank()); OpA a_; OpB b_; matxConvCorrMode_t mode_; @@ -257,8 +257,8 @@ namespace detail { for (int r = 0; r < Rank(); r++) { const int axis = perm[r]; if (axis >= Rank() - 2) { - const auto max_axis = std::max(a_.Size(r), b_.Size(r)); - const auto min_axis = std::min(a_.Size(r), b_.Size(r)); + const auto max_axis = cuda::std::max(a_.Size(r), b_.Size(r)); + const auto min_axis = cuda::std::min(a_.Size(r), b_.Size(r)); if (mode_ == MATX_C_MODE_FULL) { out_dims_[axis] = a_.Size(r) + b_.Size(r) - 1; } @@ -287,8 +287,8 @@ namespace detail { } for (int r = max_rank - 2; r < max_rank; r++) { - const auto max_axis = std::max(a_.Size(r), b_.Size(r)); - const auto min_axis = std::min(a_.Size(r), b_.Size(r)); + const auto max_axis = cuda::std::max(a_.Size(r), b_.Size(r)); + const auto min_axis = cuda::std::min(a_.Size(r), b_.Size(r)); if (mode_ == MATX_C_MODE_FULL) { out_dims_[r] = max_axis + min_axis - 1; } diff --git a/include/matx/operators/corr.h b/include/matx/operators/corr.h index dd1d6c31..36a42038 100644 --- a/include/matx/operators/corr.h +++ b/include/matx/operators/corr.h @@ -46,7 +46,7 @@ namespace matx private: using out_t = std::conditional_t, typename OpA::scalar_type, typename OpB::scalar_type>; - constexpr static int max_rank = std::max(OpA::Rank(), OpB::Rank()); + constexpr static int max_rank = cuda::std::max(OpA::Rank(), OpB::Rank()); OpA a_; OpB b_; matxConvCorrMode_t mode_; @@ -73,8 +73,8 @@ namespace matx for (int r = 0; r < Rank(); r++) { const int axis = perm[r]; if (axis == Rank() - 1) { - const auto max_axis = std::max(a_.Size(r), b_.Size(r)); - const auto min_axis = std::min(a_.Size(r), b_.Size(r)); + const auto max_axis = cuda::std::max(a_.Size(r), b_.Size(r)); + const auto min_axis = cuda::std::min(a_.Size(r), b_.Size(r)); if (mode_ == MATX_C_MODE_FULL) { out_dims_[axis] = a_.Size(r) + b_.Size(r) - 1; @@ -103,8 +103,8 @@ namespace matx } } - const auto max_axis = std::max(a_.Size(OpA::Rank()-1), b_.Size(OpB::Rank()-1)); - const auto min_axis = std::min(a_.Size(OpA::Rank()-1), b_.Size(OpB::Rank()-1)); + const auto max_axis = cuda::std::max(a_.Size(OpA::Rank()-1), b_.Size(OpB::Rank()-1)); + const auto min_axis = cuda::std::min(a_.Size(OpA::Rank()-1), b_.Size(OpB::Rank()-1)); if (mode_ == MATX_C_MODE_FULL) { out_dims_[max_rank-1] = max_axis + min_axis - 1; } diff --git a/include/matx/operators/diag.h b/include/matx/operators/diag.h index 2345a3e7..74b9e53c 100644 --- a/include/matx/operators/diag.h +++ b/include/matx/operators/diag.h @@ -84,7 +84,7 @@ namespace matx return op_.Size(dim); } else { - return std::min(op_.Size(RANK - 1), op_.Size(RANK-2)); + return cuda::std::min(op_.Size(RANK - 1), op_.Size(RANK-2)); } } diff --git a/include/matx/operators/matmul.h b/include/matx/operators/matmul.h index 9a9e36ef..185e9eec 100644 --- a/include/matx/operators/matmul.h +++ b/include/matx/operators/matmul.h @@ -49,7 +49,7 @@ namespace matx float alpha_; float beta_; PermDims perm_; - static constexpr int out_rank = std::max(OpA::Rank(), OpB::Rank()); + static constexpr int out_rank = cuda::std::max(OpA::Rank(), OpB::Rank()); cuda::std::array out_dims_; mutable matx::tensor_t tmp_out_; diff --git a/include/matx/operators/outer.h b/include/matx/operators/outer.h index cc2680a2..9e412275 100644 --- a/include/matx/operators/outer.h +++ b/include/matx/operators/outer.h @@ -48,7 +48,7 @@ namespace matx OpB b_; float alpha_; float beta_; - static constexpr int RANK = std::max(remove_cvref_t::Rank(), remove_cvref_t::Rank()) + 1; + static constexpr int RANK = cuda::std::max(remove_cvref_t::Rank(), remove_cvref_t::Rank()) + 1; cuda::std::array out_dims_; mutable matx::tensor_t tmp_out_; diff --git a/include/matx/operators/scalar_ops.h b/include/matx/operators/scalar_ops.h index d87ab88e..e29b788f 100644 --- a/include/matx/operators/scalar_ops.h +++ b/include/matx/operators/scalar_ops.h @@ -33,6 +33,7 @@ #pragma once #include +#include namespace matx { namespace detail { @@ -207,7 +208,7 @@ MATX_UNARY_OP_GEN(norm, Norm); template static __MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ auto _internal_sin(T v1) { if constexpr (is_matx_type_v) { - return sin(v1); + return matx::sin(v1); } else { return cuda::std::sin(v1); @@ -222,7 +223,7 @@ template using SinOp = UnOp>; template static __MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ auto _internal_cos(T v1) { if constexpr (is_matx_type_v) { - return cos(v1); + return matx::cos(v1); } else { return cuda::std::cos(v1); @@ -555,7 +556,7 @@ template struct MaximumF { static __MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ auto op(T1 v1, T2 v2) { - return std::max(v1, v2); + return cuda::std::max(v1, v2); } }; template using MaximumOp = BinOp>; @@ -565,7 +566,7 @@ template struct MinimumF { static __MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ auto op(T1 v1, T2 v2) { - return std::min(v1, v2); + return cuda::std::min(v1, v2); } }; template using MinimumOp = BinOp>; diff --git a/include/matx/transforms/channelize_poly.h b/include/matx/transforms/channelize_poly.h index deb58ff0..27490cf5 100644 --- a/include/matx/transforms/channelize_poly.h +++ b/include/matx/transforms/channelize_poly.h @@ -97,7 +97,7 @@ inline size_t matxChannelizePoly1DInternal_SmemSizeBytes(const OutType &o, const size_t smem_size = sizeof(filter_t)*(num_channels)*(filter_phase_len) + sizeof(input_t)*(num_channels)*(filter_phase_len + MATX_CHANNELIZE_POLY1D_FULL_SMEM_KERNEL_NOUT_PER_ITER - 1); - const size_t max_sizeof = std::max(sizeof(filter_t), sizeof(input_t)); + const size_t max_sizeof = cuda::std::max(sizeof(filter_t), sizeof(input_t)); if (smem_size % max_sizeof) { smem_size += max_sizeof - (smem_size % max_sizeof); } diff --git a/include/matx/transforms/conv.h b/include/matx/transforms/conv.h index a43a5980..25e4fa76 100644 --- a/include/matx/transforms/conv.h +++ b/include/matx/transforms/conv.h @@ -179,7 +179,7 @@ inline void matxDirectConv1DInternal(OutputType &o, const InType &i, unsigned int num_blocks = (unsigned int)(sig_len + filter.Size(filter.Rank()-1) + work_per_block -1) / work_per_block; // number below was chosen arbitrarily. Cannot be more than 65536. - num_blocks = std::min(num_blocks, 10000U); + num_blocks = cuda::std::min(num_blocks, 10000U); unsigned int grid_size = static_cast(TotalSize(i)/i.Size(i.Rank() - 1)); @@ -238,13 +238,13 @@ inline void conv1d_impl_internal(OutputType &o, const In1Type &i1, const In2Type static_assert(In1Type::Rank() == In2Type::Rank()); if (mode == MATX_C_MODE_SAME) { - MATX_ASSERT_STR(o.Size(OutputType::Rank() - 1) == std::max(i1.Size(i1.Rank()-1), i2.Size(i2.Rank()-1)), matxInvalidSize, + MATX_ASSERT_STR(o.Size(OutputType::Rank() - 1) == cuda::std::max(i1.Size(i1.Rank()-1), i2.Size(i2.Rank()-1)), matxInvalidSize, "Output size for SAME mode convolution must match largest input size"); } if (mode == MATX_C_MODE_VALID) { MATX_ASSERT_STR(o.Size(OutputType::Rank() - 1) == - std::max(i1.Size(i1.Rank()-1), i2.Size(i2.Rank()-1)) - std::min(i1.Size(i1.Rank()-1), i2.Size(i2.Rank()-1)) + 1, matxInvalidSize, + cuda::std::max(i1.Size(i1.Rank()-1), i2.Size(i2.Rank()-1)) - cuda::std::min(i1.Size(i1.Rank()-1), i2.Size(i2.Rank()-1)) + 1, matxInvalidSize, "Output size for VALID mode convolution must be N - L + 1"); } diff --git a/include/matx/transforms/outer.h b/include/matx/transforms/outer.h index a9c88b62..d51e0114 100644 --- a/include/matx/transforms/outer.h +++ b/include/matx/transforms/outer.h @@ -87,7 +87,7 @@ __MATX_INLINE__ void outer_impl(TensorTypeC C, const TensorTypeA A, ac.fill(matxKeepDim); bc.fill(matxKeepDim); - for (int r = 0; r < std::min(A.Rank(), B.Rank()) - 1; r++) { + for (int r = 0; r < cuda::std::min(A.Rank(), B.Rank()) - 1; r++) { MATX_ASSERT_STR(A.Size(r) == B.Size(r), matxInvalidSize, "A and B tensors must match batch sizes"); } diff --git a/include/matx/transforms/qr.h b/include/matx/transforms/qr.h index 5348f722..891610a5 100644 --- a/include/matx/transforms/qr.h +++ b/include/matx/transforms/qr.h @@ -84,7 +84,7 @@ namespace detail { index_t m = A.Size(RANK-2); index_t n = A.Size(RANK-1); - index_t k = std::min(m,n); + index_t k = cuda::std::min(m,n); if(m<=n) k--; // these matrices have one less update since the diagonal ends on the bottom of the matrix auto Qin = cuda::std::get<0>(workspace); diff --git a/include/matx/transforms/reduce.h b/include/matx/transforms/reduce.h index d330c465..c244fe43 100644 --- a/include/matx/transforms/reduce.h +++ b/include/matx/transforms/reduce.h @@ -2117,12 +2117,12 @@ void __MATX_INLINE__ argmax_impl(OutType dest, TensorIndexType &idest, const InT auto ft = [&](auto &&lin, auto &&lout, [[maybe_unused]] auto &&lbegin, [[maybe_unused]] auto &&lend) { if constexpr (OutType::Rank() == 0) { - *lout = std::max_element(lin, lin + TotalSize(in)) - lin; + *lout = cuda::std::max_element(lin, lin + TotalSize(in)) - lin; } else { auto els = lend[0] - lbegin[0]; for (index_t b = 0; b < els; b++) { - lout[b] = std::max_element(lin + lbegin[b], lin + lend[b]) - lin; + lout[b] = cuda::std::max_element(lin + lbegin[b], lin + lend[b]) - lin; } } }; @@ -2264,12 +2264,12 @@ void __MATX_INLINE__ argmin_impl(OutType dest, TensorIndexType &idest, const InT auto ft = [&](auto &&lin, auto &&lout, [[maybe_unused]] auto &&lbegin, [[maybe_unused]] auto &&lend) { if constexpr (OutType::Rank() == 0) { - *lout = std::min_element(lin, lin + TotalSize(in)) - lin; + *lout = cuda::std::min_element(lin, lin + TotalSize(in)) - lin; } else { auto els = lend[1] - lbegin[0]; for (index_t b = 0; b < els; b++) { - lout[b] = std::min_element(lin + lbegin[b], lin + lend[b]) - lin; + lout[b] = cuda::std::min_element(lin + lbegin[b], lin + lend[b]) - lin; } } }; diff --git a/include/matx/transforms/solver.h b/include/matx/transforms/solver.h index c378b4db..ae69df79 100644 --- a/include/matx/transforms/solver.h +++ b/include/matx/transforms/solver.h @@ -1288,7 +1288,7 @@ void det_impl(OutputTensor &out, const InputTensor &a, s[i] = a_new.Size(i); } - s[RANK - 2] = std::min(a_new.Size(RANK - 1), a_new.Size(RANK - 2)); + s[RANK - 2] = cuda::std::min(a_new.Size(RANK - 1), a_new.Size(RANK - 2)); auto piv = make_tensor(s, MATX_ASYNC_DEVICE_MEMORY, stream); auto ac = make_tensor(a_new.Shape(), MATX_ASYNC_DEVICE_MEMORY, stream); diff --git a/include/matx/transforms/svd.h b/include/matx/transforms/svd.h index 20293be0..864f824c 100644 --- a/include/matx/transforms/svd.h +++ b/include/matx/transforms/svd.h @@ -88,7 +88,7 @@ void svdpi_impl(UType &U, SType &S, VTType &VT, AType &A, X0Type &x0, int iterat auto m = A.Size(RANK-2); // rows auto n = A.Size(RANK-1); // cols - auto d = std::min(n,m); // dim for AAT or ATA + auto d = cuda::std::min(n,m); // dim for AAT or ATA // if sentinal found get all singularvalues if( k == -1 ) k = (int) d; @@ -309,7 +309,7 @@ inline auto svdbpi_impl_workspace(const AType &A, cudaStream_t stream) { auto m = A.Size(RANK-2); // rows auto n = A.Size(RANK-1); // cols - auto d = std::min(n,m); // dim for AAT or ATA + auto d = cuda::std::min(n,m); // dim for AAT or ATA auto ATShape = A.Shape(); ATShape[RANK-2] = d; @@ -383,7 +383,7 @@ inline void svdbpi_impl(UType &U, SType &S, VTType &VT, const AType &A, int max_ auto m = A.Size(RANK-2); // rows auto n = A.Size(RANK-1); // cols - auto d = std::min(n,m); // dim for AAT or ATA + auto d = cuda::std::min(n,m); // dim for AAT or ATA // assert batch sizes are the same for(int i = 0 ; i < RANK-2; i++) {