Skip to content

Commit

Permalink
[XPU] add some bf16 ops and update xdnn (#59653)
Browse files Browse the repository at this point in the history
  • Loading branch information
houj04 authored Dec 5, 2023
1 parent ab8a16a commit e3a535a
Show file tree
Hide file tree
Showing 7 changed files with 28 additions and 12 deletions.
2 changes: 1 addition & 1 deletion cmake/external/xpu.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ set(XPU_XFT_LIB_NAME "libxft.so")
set(XPU_XPTI_LIB_NAME "libxpti.so")

if(NOT DEFINED XPU_BASE_DATE)
set(XPU_BASE_DATE "20231128")
set(XPU_BASE_DATE "20231203")
endif()
set(XPU_XCCL_BASE_VERSION "1.1.6.1")
if(NOT DEFINED XPU_XFT_BASE_VERSION)
Expand Down
13 changes: 10 additions & 3 deletions paddle/phi/backends/xpu/xpu3_op_list.cc
Original file line number Diff line number Diff line change
Expand Up @@ -814,9 +814,13 @@ XPUOpMap& get_kl3_ops() {
phi::DataType::INT32,
phi::DataType::INT64})},
{"softmax",
XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
XPUKernelSet({phi::DataType::FLOAT32,
phi::DataType::FLOAT16,
phi::DataType::BFLOAT16})},
{"softmax_grad",
XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
XPUKernelSet({phi::DataType::FLOAT32,
phi::DataType::FLOAT16,
phi::DataType::BFLOAT16})},
{"fused_softmax_mask_grad", XPUKernelSet({phi::DataType::FLOAT32})},
{"softmax_with_cross_entropy_grad",
XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
Expand Down Expand Up @@ -1094,11 +1098,14 @@ XPUOpMap& get_kl3_ops() {
{"where_grad",
XPUKernelSet({phi::DataType::INT32,
phi::DataType::INT64,
phi::DataType::FLOAT64,
phi::DataType::FLOAT32,
phi::DataType::FLOAT16,
phi::DataType::FLOAT32})},
phi::DataType::BFLOAT16})},
{"where",
XPUKernelSet({phi::DataType::INT32,
phi::DataType::INT64,
phi::DataType::FLOAT64,
phi::DataType::FLOAT32,
phi::DataType::FLOAT16,
phi::DataType::BFLOAT16})},
Expand Down
3 changes: 2 additions & 1 deletion paddle/phi/kernels/xpu/softmax_grad_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -64,4 +64,5 @@ PD_REGISTER_KERNEL(softmax_grad,
ALL_LAYOUT,
phi::SoftmaxGradKernel,
float,
phi::dtype::float16) {}
phi::dtype::float16,
phi::dtype::bfloat16) {}
9 changes: 7 additions & 2 deletions paddle/phi/kernels/xpu/softmax_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -77,5 +77,10 @@ void SoftmaxKernel(const Context& dev_ctx,

} // namespace phi

// NOTE(review): this is a rendered diff — both the removed and the added
// registration appear back-to-back because the scrape lost the +/- markers.
// Removed form: XPU softmax kernel registered for float and float16 only.
PD_REGISTER_KERNEL(
softmax, XPU, ALL_LAYOUT, phi::SoftmaxKernel, float, phi::dtype::float16) {}
// Added form: same registration reflowed to one dtype per line, with
// phi::dtype::bfloat16 appended — matching the bf16 entries this commit
// adds to the "softmax" set in xpu3_op_list.cc above.
PD_REGISTER_KERNEL(softmax,
XPU,
ALL_LAYOUT,
phi::SoftmaxKernel,
float,
phi::dtype::float16,
phi::dtype::bfloat16) {}
6 changes: 4 additions & 2 deletions paddle/phi/kernels/xpu/where_grad_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,8 @@ PD_REGISTER_KERNEL(where_grad,
ALL_LAYOUT,
phi::WhereGradKernel,
float,
phi::dtype::float16,
double,
int,
int64_t) {}
int64_t,
phi::dtype::float16,
phi::dtype::bfloat16) {}
1 change: 1 addition & 0 deletions paddle/phi/kernels/xpu/where_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ PD_REGISTER_KERNEL(where,
ALL_LAYOUT,
phi::WhereKernel,
float,
double,
int,
int64_t,
phi::dtype::float16,
Expand Down
6 changes: 3 additions & 3 deletions test/xpu/test_softmax_op_xpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def __init__(self):
def dynamic_create_class(self):
base_class = self.TestSoftmaxOp
classes = []
shapes = [[2, 3, 4, 5], [7, 1], [63, 18], [2, 38512], [3, 4095]]
shapes = [[2, 3, 4, 5], [63, 18], [2, 38512], [3, 4095]]
axis = [-1, 0, 1]
for shape in shapes:
for axi in axis:
Expand All @@ -67,9 +67,9 @@ class TestSoftmaxOp(XPUOpTest):
def setUp(self):
self.op_type = "softmax"
if not hasattr(self, 'shape'):
self.shape = [1, 7]
self.shape = [2, 3, 4, 5]
self.axis = -1
self.dtype = np.float32
self.dtype = self.in_type

x = np.random.uniform(-1, 1, self.shape).astype(self.dtype)
out = np.apply_along_axis(stable_softmax, self.axis, x)
Expand Down

0 comments on commit e3a535a

Please sign in to comment.