From f65e2257bfb54fee57043d0512aa320783059a26 Mon Sep 17 00:00:00 2001 From: Vahid Tavanashad Date: Tue, 8 Aug 2023 17:08:58 -0500 Subject: [PATCH] address reviewer's comments --- dpctl/tensor/_elementwise_funcs.py | 8 +++++- .../kernels/elementwise_functions/maximum.hpp | 26 ++++++++++--------- .../kernels/elementwise_functions/minimum.hpp | 26 ++++++++++--------- 3 files changed, 35 insertions(+), 25 deletions(-) diff --git a/dpctl/tensor/_elementwise_funcs.py b/dpctl/tensor/_elementwise_funcs.py index d6cb497236..fe85a183ba 100644 --- a/dpctl/tensor/_elementwise_funcs.py +++ b/dpctl/tensor/_elementwise_funcs.py @@ -1196,7 +1196,7 @@ Default: "K". Returns: usm_narray: - An array containing the element-wise products. The data type of + An array containing the element-wise maxima. The data type of the returned array is determined by the Type Promotion Rules. """ maximum = BinaryElementwiseFunc( @@ -1429,6 +1429,12 @@ First input array, expected to have a real-valued data type. x2 (usm_ndarray): Second input array, also expected to have a real-valued data type. + out ({None, usm_ndarray}, optional): + Output array to populate. + Array must have the correct shape and the expected data type. + order ("C","F","A","K", optional): + Memory layout of the new output array, if parameter `out` is `None`. + Default: "K". Returns: usm_ndarray: an array containing the element-wise remainders. 
The data type of diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/maximum.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/maximum.hpp index 8beda909d3..6d12477f66 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/maximum.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/maximum.hpp @@ -71,18 +71,16 @@ template struct MaximumFunctor realT imag1 = std::imag(in1); realT imag2 = std::imag(in2); - if (std::isnan(real1) || std::isnan(imag1)) - return in1; - else if (std::isnan(real2) || std::isnan(imag2)) - return in2; - else if (real1 == real2) - return imag1 > imag2 ? in1 : in2; - else - return real1 > real2 ? in1 : in2; - } - else { - return (in1 != in1 || in1 > in2) ? in1 : in2; + bool gt = (real1 == real2) ? (imag1 > imag2) + : (real1 > real2 && !std::isnan(imag1) && + !std::isnan(imag2)); + return (std::isnan(real1) || std::isnan(imag1) || gt) ? in1 : in2; } + else if constexpr (std::is_floating_point_v || + std::is_same_v) + return (std::isnan(in1) || in1 > in2) ? in1 : in2; + else + return (in1 > in2) ? in1 : in2; } template @@ -92,7 +90,11 @@ template struct MaximumFunctor sycl::vec res; #pragma unroll for (int i = 0; i < vec_sz; ++i) { - res[i] = (in1[i] != in1[i] || in1[i] > in2[i]) ? in1[i] : in2[i]; + if constexpr (std::is_floating_point_v) + res[i] = + (sycl::isnan(in1[i]) || in1[i] > in2[i]) ? in1[i] : in2[i]; + else + res[i] = (in1[i] > in2[i]) ? 
in1[i] : in2[i]; } return res; } diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/minimum.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/minimum.hpp index 652f3da392..baddbe388d 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/minimum.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/minimum.hpp @@ -71,18 +71,16 @@ template struct MinimumFunctor realT imag1 = std::imag(in1); realT imag2 = std::imag(in2); - if (std::isnan(real1) || std::isnan(imag1)) - return in1; - else if (std::isnan(real2) || std::isnan(imag2)) - return in2; - else if (real1 == real2) - return imag1 < imag2 ? in1 : in2; - else - return real1 < real2 ? in1 : in2; - } - else { - return (in1 != in1 || in1 < in2) ? in1 : in2; + bool lt = (real1 == real2) ? (imag1 < imag2) + : (real1 < real2 && !std::isnan(imag1) && + !std::isnan(imag2)); + return (std::isnan(real1) || std::isnan(imag1) || lt) ? in1 : in2; } + else if constexpr (std::is_floating_point_v || + std::is_same_v) + return (std::isnan(in1) || in1 < in2) ? in1 : in2; + else + return (in1 < in2) ? in1 : in2; } template @@ -92,7 +90,11 @@ template struct MinimumFunctor sycl::vec res; #pragma unroll for (int i = 0; i < vec_sz; ++i) { - res[i] = (in1[i] != in1[i] || in1[i] < in2[i]) ? in1[i] : in2[i]; + if constexpr (std::is_floating_point_v) + res[i] = + (sycl::isnan(in1[i]) || in1[i] < in2[i]) ? in1[i] : in2[i]; + else + res[i] = (in1[i] < in2[i]) ? in1[i] : in2[i]; } return res; }