Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[XPU] support fp16 for mea and group_norm #70633

Merged
merged 2 commits into from
Jan 7, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion paddle/phi/backends/xpu/xpu3_op_list.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1459,6 +1459,8 @@ XPUOpMap& get_kl3_ops() {
phi::DataType::INT32,
phi::DataType::FLOAT16,
phi::DataType::FLOAT32})},
{"variable_length_memory_efficient_attention",
XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
{"warpctc_grad", XPUKernelSet({phi::DataType::FLOAT32})},
{"warpctc", XPUKernelSet({phi::DataType::FLOAT32})},
{"where_index",
Expand Down Expand Up @@ -1495,7 +1497,8 @@ XPUOpMap& get_kl3_ops() {
phi::DataType::INT32,
phi::DataType::INT64})},
{"randint", XPUKernelSet({phi::DataType::INT32, phi::DataType::INT64})},
{"group_norm", XPUKernelSet({phi::DataType::FLOAT32})},
{"group_norm",
XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
{"group_norm_grad", XPUKernelSet({phi::DataType::FLOAT32})},
{"meshgrid",
XPUKernelSet({phi::DataType::FLOAT32,
Expand Down
80 changes: 56 additions & 24 deletions paddle/phi/kernels/xpu/group_norm_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -57,37 +57,69 @@ void GroupNormKernel(const Context& dev_ctx,
std::multiplies<int>()));

dev_ctx.template Alloc<T>(y);
dev_ctx.template Alloc<T>(mean);
dev_ctx.template Alloc<T>(var);
dev_ctx.template Alloc<float>(mean);
dev_ctx.template Alloc<float>(var);

auto* x_data = x.data<T>();
auto* y_data = y->data<T>();
auto* mean_data = mean->data<T>();
auto* var_data = var->data<T>();
auto* mean_data = mean->data<float>();
auto* var_data = var->data<float>();

const T* scale_data = nullptr;
if (scale_ptr) scale_data = scale_ptr->data<T>();
const T* bias_data = nullptr;
if (bias_ptr) bias_data = bias_ptr->data<T>();
xpu::ctx_guard RAII_GUARD(dev_ctx.x_context());
const float* scale_data = nullptr;
if (scale_ptr) {
if (std::is_same<T, float>::value) {
scale_data = scale_ptr->data<float>();
} else {
float* scale_fp32 = RAII_GUARD.alloc_l3_or_gm<float>(scale_ptr->numel());
int r = xpu::cast<XPUType, float>(
dev_ctx.x_context(),
reinterpret_cast<const XPUType*>(scale_ptr->data<T>()),
scale_fp32,
scale_ptr->numel());
PADDLE_ENFORCE_XDNN_SUCCESS(r, "cast");
scale_data = scale_fp32;
}
}

auto r =
xpu::group_norm<XPUType>(dev_ctx.x_context(),
reinterpret_cast<const XPUType*>(x_data),
reinterpret_cast<XPUType*>(y_data),
N,
C,
L,
1,
groups,
static_cast<XPUType>(epsilon),
reinterpret_cast<const XPUType*>(scale_data),
reinterpret_cast<const XPUType*>(bias_data),
reinterpret_cast<XPUType*>(mean_data),
reinterpret_cast<XPUType*>(var_data),
channel_first);
const float* bias_data = nullptr;
if (bias_ptr) {
if (std::is_same<T, float>::value) {
bias_data = bias_ptr->data<float>();
} else {
float* bias_fp32 = RAII_GUARD.alloc_l3_or_gm<float>(bias_ptr->numel());
int r = xpu::cast<XPUType, float>(
dev_ctx.x_context(),
reinterpret_cast<const XPUType*>(bias_ptr->data<T>()),
bias_fp32,
bias_ptr->numel());
PADDLE_ENFORCE_XDNN_SUCCESS(r, "cast");
bias_data = bias_fp32;
}
}

int r = xpu::group_norm<XPUType>(dev_ctx.x_context(),
reinterpret_cast<const XPUType*>(x_data),
reinterpret_cast<XPUType*>(y_data),
N,
C,
L,
1,
groups,
epsilon,
scale_data,
bias_data,
mean_data,
var_data,
channel_first);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "group_norm");
}

} // namespace phi

PD_REGISTER_KERNEL(group_norm, XPU, ALL_LAYOUT, phi::GroupNormKernel, float) {}
PD_REGISTER_KERNEL(group_norm,
XPU,
ALL_LAYOUT,
phi::GroupNormKernel,
float,
phi::dtype::float16) {}
276 changes: 276 additions & 0 deletions test/xpu/test_variable_length_memory_efficient_attention_xpu.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,276 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import numpy as np

import paddle
from paddle import base
from paddle.incubate.nn.functional import (
variable_length_memory_efficient_attention,
)

paddle.seed(2023)


def create_attn_mask(
mask_type,
batch_size,
seq_lens,
):
max_seq_len = max(seq_lens)
mask = paddle.zeros(
[batch_size, 1, max_seq_len, max_seq_len], dtype=mask_type
)
for i in range(batch_size):
seq_len = seq_lens[i]
mask[i, 0, :seq_len, :seq_len] = (
paddle.tril(paddle.ones(shape=(seq_len, seq_len), dtype=mask_type))
- 1
) * 1e4
return mask


def naive_attention_impl(query, key, value, mask, scale):
batch = query.shape[0]
heads = query.shape[1]
seq_len = query.shape[2]
head_dim = query.shape[3]
kv_head = key.shape[1]

key = key.reshape([batch, kv_head, 1, seq_len, head_dim])
key = paddle.tile(key, [1, 1, heads // kv_head, 1, 1])
key = key.reshape([batch, heads, seq_len, head_dim])

value = value.reshape([batch, kv_head, 1, seq_len, head_dim])
value = paddle.tile(value, [1, 1, heads // kv_head, 1, 1])
value = value.reshape([batch, heads, seq_len, head_dim])

qk_res = paddle.matmul(query, key, transpose_y=True)
attention = qk_res * scale
attention = attention + mask
softmax_result = paddle.nn.functional.softmax(attention, -1)
result = paddle.matmul(softmax_result, value)
return result


class TestMemEffAttentionVariableAPI(unittest.TestCase):
def setUp(self):
self.name = "MemEffAPIVariable_fp32"
self.place = paddle.XPUPlace(0)
self.batch_size = 1
self.num_head = 8
self.kv_num_head = 8
self.seq_len = 64
self.dim_head = 16
self.seq_lens = paddle.to_tensor(
[
self.seq_len,
]
* self.batch_size,
"int32",
)
self.shape = (
self.batch_size,
self.num_head,
self.seq_len,
self.dim_head,
)
self.shape_kv = (
self.batch_size,
self.kv_num_head,
self.seq_len,
self.dim_head,
)
self.dtype = 'float32'
self.attention_mask = create_attn_mask(
self.dtype,
self.batch_size,
[
self.seq_len,
]
* self.batch_size,
)
self.scale = 1.0 / np.sqrt(self.shape[-1])

def test_all(self):
paddle.disable_static()

query = np.random.random(self.shape)
q = paddle.to_tensor(
query, place=self.place, dtype=self.dtype, stop_gradient=False
)
key = np.random.random(self.shape_kv)
k = paddle.to_tensor(
key, place=self.place, dtype=self.dtype, stop_gradient=False
)
value = np.random.random(self.shape_kv)
v = paddle.to_tensor(
value, place=self.place, dtype=self.dtype, stop_gradient=False
)

out_ = naive_attention_impl(q, k, v, self.attention_mask, self.scale)

out = variable_length_memory_efficient_attention(
q,
k,
v,
self.seq_lens,
self.seq_lens,
self.attention_mask,
self.scale,
)

for i in range(self.batch_size):
np.testing.assert_allclose(
out.numpy()[i, :, : self.seq_lens[i], :],
out_[i, :, : self.seq_lens[i], :],
rtol=5e-03,
houj04 marked this conversation as resolved.
Show resolved Hide resolved
atol=1e-03,
)


class TestMemEffAPIVariableDtypeFP16(TestMemEffAttentionVariableAPI):
def setUp(self):
self.name = "MemEffAPIVariable_fp16"
self.place = paddle.XPUPlace(0)
self.batch_size = 3
self.num_head = 16
self.kv_num_head = 16
self.seq_len = 64
self.dim_head = 32
self.seq_lens = paddle.to_tensor(
[
self.seq_len,
]
* self.batch_size,
"int32",
)
self.shape = (
self.batch_size,
self.num_head,
self.seq_len,
self.dim_head,
)
self.shape_kv = (
self.batch_size,
self.kv_num_head,
self.seq_len,
self.dim_head,
)
self.dtype = 'float16'
self.attention_mask = create_attn_mask(
self.dtype,
self.batch_size,
[
self.seq_len,
]
* self.batch_size,
)
self.scale = 1.0 / np.sqrt(self.shape[-1])


class TestMemEffAPIVariableDtypeFP16Static(unittest.TestCase):
def setUp(self):
self.name = "MemEffAPIVariableStatic_fp16"
self.place = paddle.XPUPlace(0)
self.batch_size = 3
self.num_head = 16
self.kv_num_head = 16
self.seq_len = 64
self.dim_head = 32
self.seq_lens = paddle.to_tensor(
[
self.seq_len,
]
* self.batch_size,
"int32",
).numpy()
self.shape = (
self.batch_size,
self.num_head,
self.seq_len,
self.dim_head,
)
self.shape_kv = (
self.batch_size,
self.kv_num_head,
self.seq_len,
self.dim_head,
)
self.dtype = 'float16'
self.attention_mask = create_attn_mask(
self.dtype,
self.batch_size,
[
self.seq_len,
]
* self.batch_size,
).numpy()
self.q = np.random.random(self.shape).astype(self.dtype)
self.k = np.random.random(self.shape_kv).astype(self.dtype)
self.v = np.random.random(self.shape_kv).astype(self.dtype)
self.scale = 1.0 / np.sqrt(self.shape[-1])

self.ref_out = naive_attention_impl(
paddle.to_tensor(self.q),
paddle.to_tensor(self.k),
paddle.to_tensor(self.v),
paddle.to_tensor(self.attention_mask),
self.scale,
)

def test_all(self):
paddle.enable_static()
with paddle.static.program_guard(
paddle.static.Program(), paddle.static.Program()
):
q = paddle.static.data(
name="query", shape=self.shape, dtype=self.dtype
)
k = paddle.static.data(
name="key", shape=self.shape_kv, dtype=self.dtype
)
v = paddle.static.data(
name="value", shape=self.shape_kv, dtype=self.dtype
)
mask = paddle.static.data(
name="mask",
shape=[self.batch_size, 1, self.seq_len, self.seq_len],
dtype=self.dtype,
)
seq_lens = paddle.static.data(
name="seq_lens", shape=[self.batch_size, 1], dtype="int32"
)
out = variable_length_memory_efficient_attention(
q, k, v, seq_lens, seq_lens, mask, self.scale
)
exe = base.Executor()
res = exe.run(
feed={
"query": self.q,
"key": self.k,
"value": self.v,
"seq_lens": self.seq_lens,
"mask": self.attention_mask,
},
fetch_list=[out],
)
paddle.disable_static()
np.testing.assert_allclose(res[0], self.ref_out, rtol=5e-03, atol=1e-03)


if __name__ == '__main__':
unittest.main()
Loading