pow_Tensor_Scalar: fix incorrect result when the scalar exponent is a bool. (#638)
Fixes TestBinaryUfuncsXPU.test_pow_xpu_int.

---------

Co-authored-by: Feng Yuan <[email protected]>
yuchengliu1 and fengyuan14 authored Jul 24, 2024
1 parent f716a58 commit d25d3d0
Showing 2 changed files with 11 additions and 4 deletions.
8 changes: 7 additions & 1 deletion src/ATen/native/xpu/Pow.cpp
@@ -58,7 +58,13 @@ Tensor& XPUNativeFunctions::pow_out(
 Tensor XPUNativeFunctions::pow(const Tensor& self, const Scalar& exponent) {
   Tensor out;
   auto iter = pow_tensor_scalar_meta(self, exponent, out);
-  native::xpu::pow_tensor_scalar_kernel(iter, exponent);
+  if (exponent.equal(0.0) || exponent.equal(false)) {
+    iter.output().fill_(1);
+  } else if (exponent.equal(1.0) || exponent.equal(true)) {
+    iter.output().copy_(self);
+  } else {
+    native::xpu::pow_tensor_scalar_kernel(iter, exponent);
+  }
   return iter.output();
 }
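The patch avoids sending a bool exponent into the integral pow kernel (which asserted with "invalid combination of type in Pow function") by short-circuiting the only two values a bool can take, together with their numeric equivalents 0 and 1: `x ** 0` is all ones and `x ** 1` is a copy of the input. A minimal Python sketch of that intended semantics (not the XPU kernel; `pow_tensor_scalar` is a hypothetical name for illustration):

```python
def pow_tensor_scalar(values, exponent):
    """Sketch of the special-casing the patch adds before kernel dispatch."""
    # bool is a subclass of int in Python, so False == 0 and True == 1,
    # mirroring how the patch treats exponent.equal(false)/equal(true)
    # the same as exponent.equal(0.0)/equal(1.0).
    if exponent == 0:               # covers False, 0, and 0.0
        return [1] * len(values)    # like iter.output().fill_(1)
    if exponent == 1:               # covers True, 1, and 1.0
        return list(values)         # like iter.output().copy_(self)
    # everything else falls through to the general pow path
    return [v ** exponent for v in values]
```

Only the fallback branch ever reaches the real kernel, so the dtype combination that triggered the internal assert can no longer occur for bool exponents.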

7 changes: 4 additions & 3 deletions test/xpu/run_test_with_skip.py
@@ -797,12 +797,13 @@ def launch_test(test_case, skip_list=None, exe_list=None):
skip_list = (
"test_fmod_remainder_by_zero_integral_xpu_int64", # zero division is an undefined behavior: different handles on different backends
"test_div_rounding_numpy_xpu_float16", # Calculation error. XPU implementation uses opmath type.
# RuntimeError: false INTERNAL ASSERT FAILED at "torch-xpu-ops/src/ATen/native/xpu/sycl/PowKernels.cpp":233, please report a bug to PyTorch. invalid combination of type in Pow function, common dtype: Short, exp is integral? 0
# fail in complex_exponents=[-1.0 - 1.5j, 3.3j]
# Mismatched elements: 33 / 100 (33.0%)
# Greatest absolute difference: 0.00038337233127094805 at index (4,) (up to 1e-05 allowed)
# Greatest relative difference: 1.9085073290625587e-06 at index (6,) (up to 1.3e-06 allowed)
"test_pow_xpu_int16",
"test_pow_xpu_int32",
"test_pow_xpu_int64",
"test_pow_xpu_int8",
"test_pow_xpu_uint8",
# AssertionError: Jiterator is only supported on CUDA and ROCm GPUs, none are available.
"_jiterator_",
# Unexpected success
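The skip list above is consumed by `launch_test`, whose body is not shown in this hunk. Assuming it deselects tests by name (a common pattern for such runners), the entries, including substring patterns like `"_jiterator_"`, could be turned into a single pytest `-k` expression. A hypothetical sketch under that assumption:

```python
def build_skip_expr(skip_list):
    """Join skip entries into a pytest -k expression that deselects each one.

    pytest's -k matching is substring-based, so a bare pattern such as
    "_jiterator_" excludes every test whose id contains it.
    """
    return " and ".join(f"not {name}" for name in skip_list)
```

For example, `build_skip_expr(("test_pow_xpu_int16", "_jiterator_"))` yields an expression that excludes that one test plus all jiterator tests in one run.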