Skip to content

Commit

Permalink
address reviewer's comments
Browse files Browse the repository at this point in the history
  • Loading branch information
vtavana committed Aug 9, 2023
1 parent b989d36 commit 0c9fbeb
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 25 deletions.
8 changes: 7 additions & 1 deletion dpctl/tensor/_elementwise_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -1196,7 +1196,7 @@
Default: "K".
Returns:
usm_narray:
An array containing the element-wise products. The data type of
An array containing the element-wise result. The data type of
the returned array is determined by the Type Promotion Rules.
"""
maximum = BinaryElementwiseFunc(
Expand Down Expand Up @@ -1429,6 +1429,12 @@
First input array, expected to have a real-valued data type.
x2 (usm_ndarray):
Second input array, also expected to have a real-valued data type.
out ({None, usm_ndarray}, optional):
Output array to populate.
Array have the correct shape and the expected data type.
order ("C","F","A","K", optional):
Memory layout of the newly output array, if parameter `out` is `None`.
Default: "K".
Returns:
usm_ndarray:
an array containing the element-wise remainders. The data type of
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,18 +71,16 @@ template <typename argT1, typename argT2, typename resT> struct MaximumFunctor
realT imag1 = std::imag(in1);
realT imag2 = std::imag(in2);

if (std::isnan(real1) || std::isnan(imag1))
return in1;
else if (std::isnan(real2) || std::isnan(imag2))
return in2;
else if (real1 == real2)
return imag1 > imag2 ? in1 : in2;
else
return real1 > real2 ? in1 : in2;
}
else {
return (in1 != in1 || in1 > in2) ? in1 : in2;
bool gt = (real1 == real2) ? (imag1 > imag2)
: (real1 > real2 && !std::isnan(imag1) &&
!std::isnan(imag2));
return (std::isnan(real1) || std::isnan(imag1) || gt) ? in1 : in2;
}
else if constexpr (std::is_floating_point_v<argT1> ||
std::is_same_v<argT1, sycl::half>)
return (std::isnan(in1) || in1 > in2) ? in1 : in2;
else
return (in1 > in2) ? in1 : in2;
}

template <int vec_sz>
Expand All @@ -92,7 +90,11 @@ template <typename argT1, typename argT2, typename resT> struct MaximumFunctor
sycl::vec<resT, vec_sz> res;
#pragma unroll
for (int i = 0; i < vec_sz; ++i) {
res[i] = (in1[i] != in1[i] || in1[i] > in2[i]) ? in1[i] : in2[i];
if constexpr (std::is_floating_point_v<argT1>)
res[i] =
(sycl::isnan(in1[i]) || in1[i] > in2[i]) ? in1[i] : in2[i];
else
res[i] = (in1[i] > in2[i]) ? in1[i] : in2[i];
}
return res;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,18 +71,16 @@ template <typename argT1, typename argT2, typename resT> struct MinimumFunctor
realT imag1 = std::imag(in1);
realT imag2 = std::imag(in2);

if (std::isnan(real1) || std::isnan(imag1))
return in1;
else if (std::isnan(real2) || std::isnan(imag2))
return in2;
else if (real1 == real2)
return imag1 < imag2 ? in1 : in2;
else
return real1 < real2 ? in1 : in2;
}
else {
return (in1 != in1 || in1 < in2) ? in1 : in2;
bool lt = (real1 == real2) ? (imag1 < imag2)
: (real1 < real2 && !std::isnan(imag1) &&
!std::isnan(imag2));
return (std::isnan(real1) || std::isnan(imag1) || lt) ? in1 : in2;
}
else if constexpr (std::is_floating_point_v<argT1> ||
std::is_same_v<argT1, sycl::half>)
return (std::isnan(in1) || in1 < in2) ? in1 : in2;
else
return (in1 < in2) ? in1 : in2;
}

template <int vec_sz>
Expand All @@ -92,7 +90,11 @@ template <typename argT1, typename argT2, typename resT> struct MinimumFunctor
sycl::vec<resT, vec_sz> res;
#pragma unroll
for (int i = 0; i < vec_sz; ++i) {
res[i] = (in1[i] != in1[i] || in1[i] < in2[i]) ? in1[i] : in2[i];
if constexpr (std::is_floating_point_v<argT1>)
res[i] =
(sycl::isnan(in1[i]) || in1[i] < in2[i]) ? in1[i] : in2[i];
else
res[i] = (in1[i] < in2[i]) ? in1[i] : in2[i];
}
return res;
}
Expand Down

0 comments on commit 0c9fbeb

Please sign in to comment.