diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/logaddexp.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/logaddexp.hpp
index 02375d5313..639a5e5988 100644
--- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/logaddexp.hpp
+++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/logaddexp.hpp
@@ -69,11 +69,20 @@ template <typename argT1, typename argT2, typename resT> struct LogAddExpFunctor
                                        const sycl::vec<argT2, vec_sz> &in2)
     {
         sycl::vec<resT, vec_sz> res;
-        auto diff = in1 - in2;
+        auto diff = in1 - in2; // take advantage of faster vec arithmetic
 
 #pragma unroll
         for (int i = 0; i < vec_sz; ++i) {
-            res[i] = impl<resT>(in1[i], in2[i]);
+            if (std::isfinite(diff[i])) {
+                // finite difference implies both operands are finite:
+                // log(exp(a) + exp(b)) == b + log(1 + exp(a - b))
+                res[i] = in2[i] + impl_finite<resT>(diff[i]);
+            }
+            else {
+                // an operand is inf/NaN, or the difference overflowed;
+                // the scalar routine handles these cases
+                res[i] = impl<resT>(in1[i], in2[i]);
+            }
         }
 
         return res;
@@ -82,19 +91,31 @@ template <typename argT1, typename argT2, typename resT> struct LogAddExpFunctor
 private:
     template <typename T> T impl(T const &in1, T const &in2)
     {
-        T max = std::max<T>(in1, in2);
-        if (std::isnan(max)) {
-            return std::numeric_limits<T>::quiet_NaN();
+        // log(exp(in1) + exp(in2)), robust to infinite and NaN operands
+        if (in1 == in2) { // handle signed infinities
+            const T log2 = std::log(T(2));
+            return in1 + log2;
         }
         else {
-            if (std::isinf(max)) {
-                // if both args are -inf, and hence max is -inf
-                // the result is -inf as well
-                return max;
+            const T tmp = in1 - in2;
+            if (tmp > 0) {
+                return in1 + std::log1p(std::exp(-tmp));
+            }
+            else if (tmp <= 0) {
+                return in2 + std::log1p(std::exp(tmp));
+            }
+            else {
+                // tmp is NaN (an operand is NaN): neither comparison holds
+                return std::numeric_limits<T>::quiet_NaN();
             }
         }
-        T min = std::min<T>(in1, in2);
-        return max + std::log1p(std::exp(min - max));
+    }
+
+    // log(1 + exp(in)) for finite `in`, keeping the exp argument non-positive
+    template <typename T> T impl_finite(T const &in)
+    {
+        return (in > 0) ? (in + std::log1p(std::exp(-in)))
+                        : std::log1p(std::exp(in));
     }
 };
 