[XPU] add pos_weight for sigmoid_cross_entropy_with_logits. #55001

Merged
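No PR description is preserved in this capture. For context, the change threads an optional pos_weight input through the XPU forward and backward kernels so the log term of the elementwise loss can be rescaled. Below is a minimal NumPy sketch of the formulation checked by the new test case (TestSigmoidCrossEntropyWithLogitsOp7, added at the bottom of the diff); the function name and sample values are illustrative, not part of the PR:

import numpy as np

def sigmoid_ce_with_logits(x, label, pos_weight=None):
    # Numerically stable elementwise loss; mirrors the expected output of the new test.
    log_term = np.log1p(np.exp(-np.abs(x)))
    if pos_weight is not None:
        log_term = log_term * pos_weight
    return np.maximum(x, 0) - x * label + log_term

x = np.array([-1.5, 0.0, 2.0])
label = np.array([0.0, 1.0, 1.0])
print(sigmoid_ce_with_logits(x, label, pos_weight=np.array([1.0, 2.0, 0.5])))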
@@ -46,7 +46,12 @@ void SigmoidCrossEntropyWithLogitsGradKernel(
int* hit = RAII_GUARD.alloc_l3_or_gm<int>(x.numel());
PADDLE_ENFORCE_NOT_NULL(
hit, errors::External("XPU alloc_l3_or_gm returns nullptr"));

auto pos_weight_data =
(pos_weight.get_ptr() == nullptr ? nullptr
: pos_weight.get_ptr()->data<T>());
// int sigmoid_cross_entropy_with_logits_grad(Context* ctx, const T* x, const
// T* label, const T* dy, T* dx, int64_t m, int64_t n, TH* hit = nullptr,
// int64_t ignore_index = -100, const T* pos_weight = nullptr);
int r = xpu::sigmoid_cross_entropy_with_logits_grad(
dev_ctx.x_context(),
reinterpret_cast<const XPUType*>(x.data<T>()),
@@ -56,7 +61,8 @@ void SigmoidCrossEntropyWithLogitsGradKernel(
1,
x.numel(),
hit,
- ignore_index);
+ ignore_index,
+ reinterpret_cast<const XPUType*>(pos_weight_data));
PADDLE_ENFORCE_XDNN_SUCCESS(r, "sigmoid_cross_entropy_with_logits");
if (normalize) {
int* non_zero = RAII_GUARD.alloc_l3_or_gm<int>(1);
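The backward kernel above only forwards the (possibly null) pos_weight pointer to the xdnn primitive, so the gradient formula itself is not visible in this diff. As a rough sanity check of what dx should contain, one can finite-difference the forward formulation used by the new test; this sketch is illustrative only and assumes that formulation:

import numpy as np

def loss(x, label, pw):
    # Same elementwise formulation as the new test case; not taken from the kernel.
    return np.maximum(x, 0) - x * label + np.log1p(np.exp(-np.abs(x))) * pw

# Central finite difference around a single element; sample values and eps are arbitrary.
x, label, pw, eps = 0.3, 1.0, 2.0, 1e-5
num_grad = (loss(x + eps, label, pw) - loss(x - eps, label, pw)) / (2 * eps)
print(num_grad)  # approximately what dx should hold when dy == 1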
@@ -43,7 +43,12 @@ void SigmoidCrossEntropyWithLogitsKernel(
int* hit = RAII_GUARD.alloc_l3_or_gm<int>(x.numel());
PADDLE_ENFORCE_NOT_NULL(
hit, errors::External("XPU alloc_l3_or_gm returns nullptr"));

auto pos_weight_data =
(pos_weight.get_ptr() == nullptr ? nullptr
: pos_weight.get_ptr()->data<T>());
// int sigmoid_cross_entropy_with_logits(Context* ctx, const T* x, const T*
// label, T* y, int64_t m, int64_t n, TH* hit = nullptr, int64_t ignore_index
// = -100, const T* pos_weight = nullptr);
int r = xpu::sigmoid_cross_entropy_with_logits(
dev_ctx.x_context(),
reinterpret_cast<const XPUType*>(x.data<T>()),
@@ -52,7 +57,8 @@ void SigmoidCrossEntropyWithLogitsKernel(
1,
x.numel(),
hit,
- ignore_index);
+ ignore_index,
+ reinterpret_cast<const XPUType*>(pos_weight_data));
PADDLE_ENFORCE_XDNN_SUCCESS(r, "sigmoid_cross_entropy_with_logits");
if (normalize) {
int* non_zero = RAII_GUARD.alloc_l3_or_gm<int>(1);
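At the Python level, the kernels patched above typically back paddle.nn.functional.binary_cross_entropy_with_logits, which already exposes a pos_weight argument. A usage sketch, assuming a Paddle build with XPU support; the exact dispatch path from this API down to the new kernel argument is not shown in this diff:

import paddle
import paddle.nn.functional as F

paddle.set_device("xpu")  # requires an XPU device; drop this line to run on the default place

logit = paddle.randn([4, 20])
logit.stop_gradient = False
label = paddle.randint(0, 2, [4, 20]).astype("float32")
pos_weight = paddle.full([20], 2.0)  # up-weight the positive class per output channel

loss = F.binary_cross_entropy_with_logits(
    logit, label, pos_weight=pos_weight, reduction="mean"
)
loss.backward()  # should end up in the grad kernel patched above
print(float(loss))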
test/xpu/test_sigmoid_cross_entropy_with_logits_op_xpu.py (35 additions, 0 deletions)
@@ -240,6 +240,41 @@ def set_output(self):
term2 = (1 - self.inputs['Label']) * np.log(1 - sigmoid_X)
self.outputs = {'Out': -term1 - term2}

class TestSigmoidCrossEntropyWithLogitsOp7(
TestSigmoidCrossEntropyWithLogitsOp
):
"""Test sigmoid_cross_entropy_with_logit_op with binary label"""

def set_inputs(self):
batch_size = [10, 10]
num_classes = 20
self.inputs = {
'X': logit(
np.random.uniform(
0, 1, tuple(batch_size + [num_classes])
).astype(self.dtype)
),
'Label': np.random.randint(
0, 2, tuple(batch_size + [num_classes])
).astype(self.dtype),
'pos_weight': np.random.uniform(
0, 1, tuple(batch_size + [num_classes])
).astype(self.dtype),
}
self.attrs = {'num_classes': num_classes, 'batch_size': batch_size}

def set_output(self):
# Fw Pass is implemented as elementwise sigmoid followed by
# elementwise logistic loss
# Label * -log(sigmoid(X)) + (1 - label) * -log(1 - sigmoid(X))
term1 = np.maximum(self.inputs['X'], 0)
term2 = self.inputs['X'] * self.inputs['Label']
term3 = (
np.log(1 + np.exp(-1 * np.abs(self.inputs['X'])))
* self.inputs['pos_weight']
)
self.outputs = {'Out': term1 - term2 + term3}

class TestSigmoidCrossEntropyWithLogitsNorm(
TestSigmoidCrossEntropyWithLogitsOp
):