Skip to content

Commit

Permalink
fix bce loss decomp bug
Browse files Browse the repository at this point in the history
  • Loading branch information
phlrain committed Jan 8, 2025
1 parent 0f72fab commit deea1f4
Showing 1 changed file with 13 additions and 3 deletions.
16 changes: 13 additions & 3 deletions paddle/fluid/primitive/decomp_rule/decomp_rule/composite.h
Original file line number Diff line number Diff line change
Expand Up @@ -200,9 +200,19 @@ Tensor reciprocal_decomp(const Tensor& x) {

template <typename T>
Tensor bce_loss_decomp(const Tensor& x, const Tensor& label) {
auto one = full_scalar<T>(1, x.dtype(), x.place());
auto ans = full_scalar<T>(-1, x.dtype(), x.place()) *
(label * log<T>(x) + (one - label) * log<T>(one - x));
auto org_dtype = x.dtype();
auto x_mt = ConvertToMT<T>(x);

auto neg_100 = full_scalar<T>(-100, x_mt.dtype(), x.place());
auto one = full_scalar<T>(1, x_mt.dtype(), x.place());

auto log_x = maximum<T>(log<T>(x_mt), neg_100);
auto log_1_x = maximum<T>(log<T>(one - x_mt), neg_100);

auto ans = full_scalar<T>(-1, x_mt.dtype(), x.place()) *
(label * log_x + (one - label) * log_1_x);
ans = ConvertToOrig<T>(ans, org_dtype);

return ans;
}

Expand Down

0 comments on commit deea1f4

Please sign in to comment.