[XPU] support fp16 for mea and group_norm (#70633)
* [XPU] support fp16 for mea and group_norm

* fix typo
lj970926 authored Jan 7, 2025
1 parent c0284b5 commit 2ba3933
Showing 3 changed files with 336 additions and 25 deletions.
5 changes: 4 additions & 1 deletion paddle/phi/backends/xpu/xpu3_op_list.cc
@@ -1459,6 +1459,8 @@ XPUOpMap& get_kl3_ops() {
phi::DataType::INT32,
phi::DataType::FLOAT16,
phi::DataType::FLOAT32})},
{"variable_length_memory_efficient_attention",
XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
{"warpctc_grad", XPUKernelSet({phi::DataType::FLOAT32})},
{"warpctc", XPUKernelSet({phi::DataType::FLOAT32})},
{"where_index",
@@ -1495,7 +1497,8 @@ XPUOpMap& get_kl3_ops() {
phi::DataType::INT32,
phi::DataType::INT64})},
{"randint", XPUKernelSet({phi::DataType::INT32, phi::DataType::INT64})},
{"group_norm", XPUKernelSet({phi::DataType::FLOAT32})},
{"group_norm",
XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
{"group_norm_grad", XPUKernelSet({phi::DataType::FLOAT32})},
{"meshgrid",
XPUKernelSet({phi::DataType::FLOAT32,
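With the entry above, variable_length_memory_efficient_attention (the "mea" in the commit title) can now be dispatched in FLOAT16 on KL3 XPUs. A minimal dynamic-graph usage sketch, assuming a PaddlePaddle build with XPU support and one visible XPU device; the shapes simply mirror the unit test added in this commit:

import numpy as np
import paddle
from paddle.incubate.nn.functional import (
    variable_length_memory_efficient_attention,
)

place = paddle.XPUPlace(0)
batch, heads, seq_len, head_dim = 1, 8, 64, 16

# Query/key/value in float16, laid out as [batch, num_heads, seq_len, head_dim].
q = paddle.to_tensor(np.random.random([batch, heads, seq_len, head_dim]),
                     dtype='float16', place=place)
k = paddle.to_tensor(np.random.random([batch, heads, seq_len, head_dim]),
                     dtype='float16', place=place)
v = paddle.to_tensor(np.random.random([batch, heads, seq_len, head_dim]),
                     dtype='float16', place=place)
seq_lens = paddle.to_tensor([seq_len] * batch, 'int32')

# Additive causal mask: 0 on and below the diagonal, -1e4 above it.
mask = (paddle.tril(paddle.ones([seq_len, seq_len], dtype='float16')) - 1) * 1e4
mask = mask.reshape([1, 1, seq_len, seq_len]).tile([batch, 1, 1, 1])

out = variable_length_memory_efficient_attention(
    q, k, v, seq_lens, seq_lens, mask, 1.0 / np.sqrt(head_dim)
)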
80 changes: 56 additions & 24 deletions paddle/phi/kernels/xpu/group_norm_kernel.cc
@@ -57,37 +57,69 @@ void GroupNormKernel(const Context& dev_ctx,
std::multiplies<int>()));

dev_ctx.template Alloc<T>(y);
-  dev_ctx.template Alloc<T>(mean);
-  dev_ctx.template Alloc<T>(var);
+  dev_ctx.template Alloc<float>(mean);
+  dev_ctx.template Alloc<float>(var);

auto* x_data = x.data<T>();
auto* y_data = y->data<T>();
-  auto* mean_data = mean->data<T>();
-  auto* var_data = var->data<T>();
+  auto* mean_data = mean->data<float>();
+  auto* var_data = var->data<float>();

-  const T* scale_data = nullptr;
-  if (scale_ptr) scale_data = scale_ptr->data<T>();
-  const T* bias_data = nullptr;
-  if (bias_ptr) bias_data = bias_ptr->data<T>();
+  xpu::ctx_guard RAII_GUARD(dev_ctx.x_context());
+  const float* scale_data = nullptr;
+  if (scale_ptr) {
+    if (std::is_same<T, float>::value) {
+      scale_data = scale_ptr->data<float>();
+    } else {
+      float* scale_fp32 = RAII_GUARD.alloc_l3_or_gm<float>(scale_ptr->numel());
+      int r = xpu::cast<XPUType, float>(
+          dev_ctx.x_context(),
+          reinterpret_cast<const XPUType*>(scale_ptr->data<T>()),
+          scale_fp32,
+          scale_ptr->numel());
+      PADDLE_ENFORCE_XDNN_SUCCESS(r, "cast");
+      scale_data = scale_fp32;
+    }
+  }

-  auto r =
-      xpu::group_norm<XPUType>(dev_ctx.x_context(),
-                               reinterpret_cast<const XPUType*>(x_data),
-                               reinterpret_cast<XPUType*>(y_data),
-                               N,
-                               C,
-                               L,
-                               1,
-                               groups,
-                               static_cast<XPUType>(epsilon),
-                               reinterpret_cast<const XPUType*>(scale_data),
-                               reinterpret_cast<const XPUType*>(bias_data),
-                               reinterpret_cast<XPUType*>(mean_data),
-                               reinterpret_cast<XPUType*>(var_data),
-                               channel_first);
+  const float* bias_data = nullptr;
+  if (bias_ptr) {
+    if (std::is_same<T, float>::value) {
+      bias_data = bias_ptr->data<float>();
+    } else {
+      float* bias_fp32 = RAII_GUARD.alloc_l3_or_gm<float>(bias_ptr->numel());
+      int r = xpu::cast<XPUType, float>(
+          dev_ctx.x_context(),
+          reinterpret_cast<const XPUType*>(bias_ptr->data<T>()),
+          bias_fp32,
+          bias_ptr->numel());
+      PADDLE_ENFORCE_XDNN_SUCCESS(r, "cast");
+      bias_data = bias_fp32;
+    }
+  }
+
+  int r = xpu::group_norm<XPUType>(dev_ctx.x_context(),
+                                   reinterpret_cast<const XPUType*>(x_data),
+                                   reinterpret_cast<XPUType*>(y_data),
+                                   N,
+                                   C,
+                                   L,
+                                   1,
+                                   groups,
+                                   epsilon,
+                                   scale_data,
+                                   bias_data,
+                                   mean_data,
+                                   var_data,
+                                   channel_first);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "group_norm");
}

} // namespace phi

-PD_REGISTER_KERNEL(group_norm, XPU, ALL_LAYOUT, phi::GroupNormKernel, float) {}
+PD_REGISTER_KERNEL(group_norm,
+                   XPU,
+                   ALL_LAYOUT,
+                   phi::GroupNormKernel,
+                   float,
+                   phi::dtype::float16) {}
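
With the float16 registration above, group_norm runs end-to-end in fp16 on XPU: the output keeps the input dtype, the per-group mean/variance outputs are allocated and computed in float32, and a float16 scale/bias is cast to float32 on the fly via xpu::cast. A minimal sketch, assuming an XPU build of PaddlePaddle and one visible device; the layer is kept parameter-free (weight_attr=False, bias_attr=False) only for brevity:

import paddle

paddle.device.set_device('xpu:0')

# NCHW float16 input; 32 channels normalized in 8 groups.
x = paddle.randn([2, 32, 16, 16]).astype('float16')
gn = paddle.nn.GroupNorm(num_groups=8, num_channels=32,
                         weight_attr=False, bias_attr=False)
y = gn(x)       # dispatches to the fp16 XPU kernel registered above
print(y.dtype)  # paddle.float16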
276 changes: 276 additions & 0 deletions test/xpu/test_variable_length_memory_efficient_attention_xpu.py
@@ -0,0 +1,276 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import numpy as np

import paddle
from paddle import base
from paddle.incubate.nn.functional import (
variable_length_memory_efficient_attention,
)

paddle.seed(2023)


def create_attn_mask(
mask_type,
batch_size,
seq_lens,
):
max_seq_len = max(seq_lens)
mask = paddle.zeros(
[batch_size, 1, max_seq_len, max_seq_len], dtype=mask_type
)
for i in range(batch_size):
seq_len = seq_lens[i]
mask[i, 0, :seq_len, :seq_len] = (
paddle.tril(paddle.ones(shape=(seq_len, seq_len), dtype=mask_type))
- 1
) * 1e4
return mask


def naive_attention_impl(query, key, value, mask, scale):
batch = query.shape[0]
heads = query.shape[1]
seq_len = query.shape[2]
head_dim = query.shape[3]
kv_head = key.shape[1]

key = key.reshape([batch, kv_head, 1, seq_len, head_dim])
key = paddle.tile(key, [1, 1, heads // kv_head, 1, 1])
key = key.reshape([batch, heads, seq_len, head_dim])

value = value.reshape([batch, kv_head, 1, seq_len, head_dim])
value = paddle.tile(value, [1, 1, heads // kv_head, 1, 1])
value = value.reshape([batch, heads, seq_len, head_dim])

qk_res = paddle.matmul(query, key, transpose_y=True)
attention = qk_res * scale
attention = attention + mask
softmax_result = paddle.nn.functional.softmax(attention, -1)
result = paddle.matmul(softmax_result, value)
return result


class TestMemEffAttentionVariableAPI(unittest.TestCase):
def setUp(self):
self.name = "MemEffAPIVariable_fp32"
self.place = paddle.XPUPlace(0)
self.batch_size = 1
self.num_head = 8
self.kv_num_head = 8
self.seq_len = 64
self.dim_head = 16
self.seq_lens = paddle.to_tensor(
[
self.seq_len,
]
* self.batch_size,
"int32",
)
self.shape = (
self.batch_size,
self.num_head,
self.seq_len,
self.dim_head,
)
self.shape_kv = (
self.batch_size,
self.kv_num_head,
self.seq_len,
self.dim_head,
)
self.dtype = 'float32'
self.attention_mask = create_attn_mask(
self.dtype,
self.batch_size,
[
self.seq_len,
]
* self.batch_size,
)
self.scale = 1.0 / np.sqrt(self.shape[-1])

def test_all(self):
paddle.disable_static()

query = np.random.random(self.shape)
q = paddle.to_tensor(
query, place=self.place, dtype=self.dtype, stop_gradient=False
)
key = np.random.random(self.shape_kv)
k = paddle.to_tensor(
key, place=self.place, dtype=self.dtype, stop_gradient=False
)
value = np.random.random(self.shape_kv)
v = paddle.to_tensor(
value, place=self.place, dtype=self.dtype, stop_gradient=False
)

out_ = naive_attention_impl(q, k, v, self.attention_mask, self.scale)

out = variable_length_memory_efficient_attention(
q,
k,
v,
self.seq_lens,
self.seq_lens,
self.attention_mask,
self.scale,
)

for i in range(self.batch_size):
np.testing.assert_allclose(
out.numpy()[i, :, : self.seq_lens[i], :],
out_[i, :, : self.seq_lens[i], :],
rtol=5e-03,
atol=1e-03,
)


class TestMemEffAPIVariableDtypeFP16(TestMemEffAttentionVariableAPI):
def setUp(self):
self.name = "MemEffAPIVariable_fp16"
self.place = paddle.XPUPlace(0)
self.batch_size = 3
self.num_head = 16
self.kv_num_head = 16
self.seq_len = 64
self.dim_head = 32
self.seq_lens = paddle.to_tensor(
[
self.seq_len,
]
* self.batch_size,
"int32",
)
self.shape = (
self.batch_size,
self.num_head,
self.seq_len,
self.dim_head,
)
self.shape_kv = (
self.batch_size,
self.kv_num_head,
self.seq_len,
self.dim_head,
)
self.dtype = 'float16'
self.attention_mask = create_attn_mask(
self.dtype,
self.batch_size,
[
self.seq_len,
]
* self.batch_size,
)
self.scale = 1.0 / np.sqrt(self.shape[-1])


class TestMemEffAPIVariableDtypeFP16Static(unittest.TestCase):
def setUp(self):
self.name = "MemEffAPIVariableStatic_fp16"
self.place = paddle.XPUPlace(0)
self.batch_size = 3
self.num_head = 16
self.kv_num_head = 16
self.seq_len = 64
self.dim_head = 32
self.seq_lens = paddle.to_tensor(
[
self.seq_len,
]
* self.batch_size,
"int32",
).numpy()
self.shape = (
self.batch_size,
self.num_head,
self.seq_len,
self.dim_head,
)
self.shape_kv = (
self.batch_size,
self.kv_num_head,
self.seq_len,
self.dim_head,
)
self.dtype = 'float16'
self.attention_mask = create_attn_mask(
self.dtype,
self.batch_size,
[
self.seq_len,
]
* self.batch_size,
).numpy()
self.q = np.random.random(self.shape).astype(self.dtype)
self.k = np.random.random(self.shape_kv).astype(self.dtype)
self.v = np.random.random(self.shape_kv).astype(self.dtype)
self.scale = 1.0 / np.sqrt(self.shape[-1])

self.ref_out = naive_attention_impl(
paddle.to_tensor(self.q),
paddle.to_tensor(self.k),
paddle.to_tensor(self.v),
paddle.to_tensor(self.attention_mask),
self.scale,
)

def test_all(self):
paddle.enable_static()
with paddle.static.program_guard(
paddle.static.Program(), paddle.static.Program()
):
q = paddle.static.data(
name="query", shape=self.shape, dtype=self.dtype
)
k = paddle.static.data(
name="key", shape=self.shape_kv, dtype=self.dtype
)
v = paddle.static.data(
name="value", shape=self.shape_kv, dtype=self.dtype
)
mask = paddle.static.data(
name="mask",
shape=[self.batch_size, 1, self.seq_len, self.seq_len],
dtype=self.dtype,
)
seq_lens = paddle.static.data(
name="seq_lens", shape=[self.batch_size, 1], dtype="int32"
)
out = variable_length_memory_efficient_attention(
q, k, v, seq_lens, seq_lens, mask, self.scale
)
exe = base.Executor()
res = exe.run(
feed={
"query": self.q,
"key": self.k,
"value": self.v,
"seq_lens": self.seq_lens,
"mask": self.attention_mask,
},
fetch_list=[out],
)
paddle.disable_static()
np.testing.assert_allclose(res[0], self.ref_out, rtol=5e-03, atol=1e-03)


if __name__ == '__main__':
unittest.main()
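
The new test can also be run on its own with the standard unittest runner; a sketch, assuming it is launched from the repository root on a machine with an XPU build of Paddle and at least one XPU device:

import unittest

suite = unittest.defaultTestLoader.discover(
    start_dir="test/xpu",
    pattern="test_variable_length_memory_efficient_attention_xpu.py",
)
unittest.TextTestRunner(verbosity=2).run(suite)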
