Skip to content

Commit

Permalink
Fixed log-add-exp per review feedback
Browse files Browse the repository at this point in the history
  • Loading branch information
oleksandr-pavlyk committed Aug 8, 2023
1 parent 8343edc commit 739475d
Showing 1 changed file with 25 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -69,11 +69,16 @@ template <typename argT1, typename argT2, typename resT> struct LogAddExpFunctor
const sycl::vec<argT2, vec_sz> &in2)
{
sycl::vec<resT, vec_sz> res;
auto diff = in1 - in2;
auto diff = in1 - in2; // take advantange of faster vec arithmetic

#pragma unroll
for (int i = 0; i < vec_sz; ++i) {
res[i] = impl<resT>(in1[i], in2[i]);
if (std::isfinite(diff[i])) {
res[i] = in2[i] + impl_finite<resT>(diff[i]);
}
else {
res[i] = impl<resT>(in[i], in[2]);
}
}

return res;
Expand All @@ -82,19 +87,28 @@ template <typename argT1, typename argT2, typename resT> struct LogAddExpFunctor
private:
template <typename T> T impl(T const &in1, T const &in2)
{
T max = std::max<T>(in1, in2);
if (std::isnan(max)) {
return std::numeric_limits<T>::quiet_NaN();
if (in1 == in2) { // handle signed infinities
const T log2 = std::log(T(2));
return in1 + log2;
}
else {
if (std::isinf(max)) {
// if both args are -inf, and hence max is -inf
// the result is -inf as well
return max;
const T tmp = in1 - in2;
if (tmp > 0) {
return in1 + std::log1p(std::exp(-tmp));
}
else if (tmp <= 0) {
return in2 + std::log1p(std::exp(tmp));
}
else {
return std::numeric_limits<T>::quiet_NaN();
}
}
T min = std::min<T>(in1, in2);
return max + std::log1p(std::exp(min - max));
}

template <typename T> T impl_finite(T const &in)
{
return (in > 0) ? (in + std::log1p(std::exp(-in)))
: std::log1p(std::exp(in));
}
};

Expand Down

0 comments on commit 739475d

Please sign in to comment.