From 4f423a15eba874ff059b6562455428f2d682ce21 Mon Sep 17 00:00:00 2001
From: lj970926 <1783973490@qq.com>
Date: Mon, 6 Jan 2025 06:25:11 +0000
Subject: [PATCH 1/2] [XPU] support fp16 for mea and group_norm

---
 paddle/phi/backends/xpu/xpu3_op_list.cc       |   5 +-
 paddle/phi/kernels/xpu/group_norm_kernel.cc   |  80 +++--
 ...e_length_memory_efficient_attention_xpu.py | 276 ++++++++++++++++++
 3 files changed, 336 insertions(+), 25 deletions(-)
 create mode 100644 test/xpu/test_variable_length_memory_efficient_attention_xpu.py

diff --git a/paddle/phi/backends/xpu/xpu3_op_list.cc b/paddle/phi/backends/xpu/xpu3_op_list.cc
index a25403a8df8db..01f894688e54a 100644
--- a/paddle/phi/backends/xpu/xpu3_op_list.cc
+++ b/paddle/phi/backends/xpu/xpu3_op_list.cc
@@ -1459,6 +1459,8 @@ XPUOpMap& get_kl3_ops() {
                    phi::DataType::INT32,
                    phi::DataType::FLOAT16,
                    phi::DataType::FLOAT32})},
+    {"variable_length_memory_efficient_attention",
+     XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
     {"warpctc_grad", XPUKernelSet({phi::DataType::FLOAT32})},
     {"warpctc", XPUKernelSet({phi::DataType::FLOAT32})},
     {"where_index",
@@ -1495,7 +1497,8 @@ XPUOpMap& get_kl3_ops() {
                    phi::DataType::INT32,
                    phi::DataType::INT64})},
     {"randint", XPUKernelSet({phi::DataType::INT32, phi::DataType::INT64})},
-    {"group_norm", XPUKernelSet({phi::DataType::FLOAT32})},
+    {"group_norm",
+     XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
     {"group_norm_grad", XPUKernelSet({phi::DataType::FLOAT32})},
     {"meshgrid",
      XPUKernelSet({phi::DataType::FLOAT32,
diff --git a/paddle/phi/kernels/xpu/group_norm_kernel.cc b/paddle/phi/kernels/xpu/group_norm_kernel.cc
index 01435f82b2cef..037bb8c14b76b 100644
--- a/paddle/phi/kernels/xpu/group_norm_kernel.cc
+++ b/paddle/phi/kernels/xpu/group_norm_kernel.cc
@@ -57,37 +57,69 @@ void GroupNormKernel(const Context& dev_ctx,
                                        std::multiplies<int>()));
 
   dev_ctx.template Alloc<T>(y);
-  dev_ctx.template Alloc<T>(mean);
-  dev_ctx.template Alloc<T>(var);
+  dev_ctx.template Alloc<float>(mean);
+  dev_ctx.template Alloc<float>(var);
 
   auto* x_data = x.data<T>();
   auto* y_data = y->data<T>();
-  auto* mean_data = mean->data<T>();
-  auto* var_data = var->data<T>();
+  auto* mean_data = mean->data<float>();
+  auto* var_data = var->data<float>();
 
-  const T* scale_data = nullptr;
-  if (scale_ptr) scale_data = scale_ptr->data<T>();
-  const T* bias_data = nullptr;
-  if (bias_ptr) bias_data = bias_ptr->data<T>();
+  xpu::ctx_guard RAII_GUARD(dev_ctx.x_context());
+  const float* scale_data = nullptr;
+  if (scale_ptr) {
+    if (std::is_same<T, float>::value) {
+      scale_data = scale_ptr->data<float>();
+    } else {
+      float* scale_fp32 = RAII_GUARD.alloc_l3_or_gm<float>(scale_ptr->numel());
+      int r = xpu::cast(
+          dev_ctx.x_context(),
+          reinterpret_cast<const XPUType*>(scale_ptr->data<T>()),
+          scale_fp32,
+          scale_ptr->numel());
+      PADDLE_ENFORCE_XDNN_SUCCESS(r, "cast");
+      scale_data = scale_fp32;
+    }
+  }
 
-  auto r =
-      xpu::group_norm(dev_ctx.x_context(),
-                      reinterpret_cast<const float*>(x_data),
-                      reinterpret_cast<float*>(y_data),
-                      N,
-                      C,
-                      L,
-                      1,
-                      groups,
-                      static_cast<float>(epsilon),
-                      reinterpret_cast<const float*>(scale_data),
-                      reinterpret_cast<const float*>(bias_data),
-                      reinterpret_cast<float*>(mean_data),
-                      reinterpret_cast<float*>(var_data),
-                      channel_first);
+  const float* bias_data = nullptr;
+  if (bias_ptr) {
+    if (std::is_same<T, float>::value) {
+      bias_data = bias_ptr->data<float>();
+    } else {
+      float* bias_fp32 = RAII_GUARD.alloc_l3_or_gm<float>(bias_ptr->numel());
+      int r = xpu::cast(
+          dev_ctx.x_context(),
+          reinterpret_cast<const XPUType*>(bias_ptr->data<T>()),
+          bias_fp32,
+          bias_ptr->numel());
+      PADDLE_ENFORCE_XDNN_SUCCESS(r, "cast");
+      bias_data = bias_fp32;
+    }
+  }
+
+  int r = xpu::group_norm<XPUType>(dev_ctx.x_context(),
+                                   reinterpret_cast<const XPUType*>(x_data),
+                                   reinterpret_cast<XPUType*>(y_data),
+                                   N,
+                                   C,
+                                   L,
+                                   1,
+                                   groups,
+                                   static_cast<float>(epsilon),
+                                   scale_data,
+                                   bias_data,
+                                   mean_data,
+                                   var_data,
+                                   channel_first);
   PADDLE_ENFORCE_XDNN_SUCCESS(r, "group_norm");
 }
 }  // namespace phi
 
-PD_REGISTER_KERNEL(group_norm, XPU, ALL_LAYOUT, phi::GroupNormKernel, float) {}
+PD_REGISTER_KERNEL(group_norm,
+                   XPU,
+                   ALL_LAYOUT,
+                   phi::GroupNormKernel,
+                   float,
+                   phi::dtype::float16) {}
diff --git a/test/xpu/test_variable_length_memory_efficient_attention_xpu.py b/test/xpu/test_variable_length_memory_efficient_attention_xpu.py
new file mode 100644
index 0000000000000..027a35add46f3
--- /dev/null
+++ b/test/xpu/test_variable_length_memory_efficient_attention_xpu.py
@@ -0,0 +1,276 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import numpy as np
+
+import paddle
+from paddle import base
+from paddle.incubate.nn.functional import (
+    variable_length_memory_efficient_attention,
+)
+
+paddle.seed(2023)
+
+
+def create_attn_mask(
+    mask_type,
+    batch_size,
+    seq_lens,
+):
+    max_seq_len = max(seq_lens)
+    mask = paddle.zeros(
+        [batch_size, 1, max_seq_len, max_seq_len], dtype=mask_type
+    )
+    for i in range(batch_size):
+        seq_len = seq_lens[i]
+        mask[i, 0, :seq_len, :seq_len] = (
+            paddle.tril(paddle.ones(shape=(seq_len, seq_len), dtype=mask_type))
+            - 1
+        ) * 1e4
+    return mask
+
+
+def naive_attention_impl(query, key, value, mask, scale):
+    batch = query.shape[0]
+    heads = query.shape[1]
+    seq_len = query.shape[2]
+    head_dim = query.shape[3]
+    kv_head = key.shape[1]
+
+    key = key.reshape([batch, kv_head, 1, seq_len, head_dim])
+    key = paddle.tile(key, [1, 1, heads // kv_head, 1, 1])
+    key = key.reshape([batch, heads, seq_len, head_dim])
+
+    value = value.reshape([batch, kv_head, 1, seq_len, head_dim])
+    value = paddle.tile(value, [1, 1, heads // kv_head, 1, 1])
+    value = value.reshape([batch, heads, seq_len, head_dim])
+
+    qk_res = paddle.matmul(query, key, transpose_y=True)
+    attention = qk_res * scale
+    attention = attention + mask
+    softmax_result = paddle.nn.functional.softmax(attention, -1)
+    result = paddle.matmul(softmax_result, value)
+    return result
+
+
+class TestMemEffAttentionVariableAPI(unittest.TestCase):
+    def setUp(self):
+        self.name = "MemEffAPIVariable_fp32"
+        self.place = paddle.XPUPlace(0)
+        self.batch_size = 1
+        self.num_head = 8
+        self.kv_num_head = 8
+        self.seq_len = 64
+        self.dim_head = 16
+        self.seq_lens = paddle.to_tensor(
+            [
+                self.seq_len,
+            ]
+            * self.batch_size,
+            "int32",
+        )
+        self.shape = (
+            self.batch_size,
+            self.num_head,
+            self.seq_len,
+            self.dim_head,
+        )
+        self.shape_kv = (
+            self.batch_size,
+            self.kv_num_head,
+            self.seq_len,
+            self.dim_head,
+        )
+        self.dtype = 'float32'
+        self.attention_mask = create_attn_mask(
+            self.dtype,
+            self.batch_size,
+            [
+                self.seq_len,
+            ]
+            * self.batch_size,
+        )
+        self.scale = 1.0 / np.sqrt(self.shape[-1])
+
+    def test_all(self):
+        paddle.disable_static()
+
+        query = np.random.random(self.shape)
+        q = paddle.to_tensor(
+            query, place=self.place, dtype=self.dtype, stop_gradient=False
+        )
+        key = np.random.random(self.shape_kv)
+        k = paddle.to_tensor(
+            key, place=self.place, dtype=self.dtype, stop_gradient=False
+        )
+        value = np.random.random(self.shape_kv)
+        v = paddle.to_tensor(
+            value, place=self.place, dtype=self.dtype, stop_gradient=False
+        )
+
+        out_ = naive_attention_impl(q, k, v, self.attention_mask, self.scale)
+
+        out = variable_length_memory_efficient_attention(
+            q,
+            k,
+            v,
+            self.seq_lens,
+            self.seq_lens,
+            self.attention_mask,
+            self.scale,
+        )
+
+        for i in range(self.batch_size):
+            np.testing.assert_allclose(
+                out.numpy()[i, :, : self.seq_lens[i], :],
+                out_[i, :, : self.seq_lens[i], :],
+                rtol=5e-03,
+                atol=1e-03,
+            )
+
+
+class TestMemEffAPIVariableDtypeFP16(TestMemEffAttentionVariableAPI):
+    def setUp(self):
+        self.name = "MemEffAPIVariable_fp16"
+        self.place = paddle.XPUPlace(0)
+        self.batch_size = 3
+        self.num_head = 16
+        self.kv_num_head = 16
+        self.seq_len = 64
+        self.dim_head = 32
+        self.seq_lens = paddle.to_tensor(
+            [
+                self.seq_len,
+            ]
+            * self.batch_size,
+            "int32",
+        )
+        self.shape = (
+            self.batch_size,
+            self.num_head,
+            self.seq_len,
+            self.dim_head,
+        )
+        self.shape_kv = (
+            self.batch_size,
+            self.kv_num_head,
+            self.seq_len,
+            self.dim_head,
+        )
+        self.dtype = 'float16'
+        self.attention_mask = create_attn_mask(
+            self.dtype,
+            self.batch_size,
+            [
+                self.seq_len,
+            ]
+            * self.batch_size,
+        )
+        self.scale = 1.0 / np.sqrt(self.shape[-1])
+
+
+class TestMemEffAPIVariableDtypeFP16Static(unittest.TestCase):
+    def setUp(self):
+        self.name = "MemEffAPIVariableStatic_fp16"
+        self.place = paddle.XPUPlace(0)
+        self.batch_size = 3
+        self.num_head = 16
+        self.kv_num_head = 16
+        self.seq_len = 64
+        self.dim_head = 32
+        self.seq_lens = paddle.to_tensor(
+            [
+                self.seq_len,
+            ]
+            * self.batch_size,
+            "int32",
+        ).numpy()
+        self.shape = (
+            self.batch_size,
+            self.num_head,
+            self.seq_len,
+            self.dim_head,
+        )
+        self.shape_kv = (
+            self.batch_size,
+            self.kv_num_head,
+            self.seq_len,
+            self.dim_head,
+        )
+        self.dtype = 'float16'
+        self.attention_mask = create_attn_mask(
+            self.dtype,
+            self.batch_size,
+            [
+                self.seq_len,
+            ]
+            * self.batch_size,
+        ).numpy()
+        self.q = np.random.random(self.shape).astype(self.dtype)
+        self.k = np.random.random(self.shape_kv).astype(self.dtype)
+        self.v = np.random.random(self.shape_kv).astype(self.dtype)
+        self.scale = 1.0 / np.sqrt(self.shape[-1])
+
+        self.ref_out = naive_attention_impl(
+            paddle.to_tensor(self.q),
+            paddle.to_tensor(self.k),
+            paddle.to_tensor(self.v),
+            paddle.to_tensor(self.attention_mask),
+            self.scale,
+        )
+
+    def test_all(self):
+        paddle.enable_static()
+        with paddle.static.program_guard(
+            paddle.static.Program(), paddle.static.Program()
+        ):
+            q = paddle.static.data(
+                name="query", shape=self.shape, dtype=self.dtype
+            )
+            k = paddle.static.data(
+                name="key", shape=self.shape_kv, dtype=self.dtype
+            )
+            v = paddle.static.data(
+                name="value", shape=self.shape_kv, dtype=self.dtype
+            )
+            mask = paddle.static.data(
+                name="mask",
+                shape=[self.batch_size, 1, self.seq_len, self.seq_len],
+                dtype=self.dtype,
+            )
+            seq_lens = paddle.static.data(
+                name="seq_lens", shape=[self.batch_size, 1], dtype="int32"
+            )
+            out = variable_length_memory_efficient_attention(
+                q, k, v, seq_lens, seq_lens, mask, self.scale
+            )
+            exe = base.Executor()
+            res = exe.run(
+                feed={
+                    "query": self.q,
+                    "key": self.k,
+                    "value": self.v,
+                    "seq_lens": self.seq_lens,
+                    "mask": self.attention_mask,
+                },
+                fetch_list=[out],
+            )
+        paddle.disable_static()
+        np.testing.assert_allclose(res[0], self.ref_out, rtol=5e-03, atol=1e-03)
+
+
+if __name__ == '__main__':
+    unittest.main()

From 82a97765ba6ffe0c44208cc16c9bd26e6e7ba78f Mon Sep 17 00:00:00 2001
From: lj970926 <1783973490@qq.com>
Date: Mon, 6 Jan 2025 06:44:49 +0000
Subject: [PATCH 2/2] fix typo

---
 paddle/phi/kernels/xpu/group_norm_kernel.cc                      | 2 +-
 test/xpu/test_variable_length_memory_efficient_attention_xpu.py  | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/paddle/phi/kernels/xpu/group_norm_kernel.cc b/paddle/phi/kernels/xpu/group_norm_kernel.cc
index 037bb8c14b76b..d0ae626705027 100644
--- a/paddle/phi/kernels/xpu/group_norm_kernel.cc
+++ b/paddle/phi/kernels/xpu/group_norm_kernel.cc
@@ -106,7 +106,7 @@ void GroupNormKernel(const Context& dev_ctx,
                                    L,
                                    1,
                                    groups,
-                                   static_cast<float>(epsilon),
+                                   epsilon,
                                    scale_data,
                                    bias_data,
                                    mean_data,
diff --git a/test/xpu/test_variable_length_memory_efficient_attention_xpu.py b/test/xpu/test_variable_length_memory_efficient_attention_xpu.py
index 027a35add46f3..96a9deab53e75 100644
--- a/test/xpu/test_variable_length_memory_efficient_attention_xpu.py
+++ b/test/xpu/test_variable_length_memory_efficient_attention_xpu.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
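
Usage sketch (illustrative only, not part of the patch series): assuming an XPU build of PaddlePaddle that carries the kernels above, the two newly enabled float16 paths can be exercised roughly as follows. The attention call mirrors the new unit test; the group_norm call additionally assumes paddle.nn.functional.group_norm is available in this Paddle version.

import numpy as np
import paddle
from paddle.incubate.nn.functional import (
    variable_length_memory_efficient_attention,
)

paddle.set_device("xpu")  # assumes an XPU build of PaddlePaddle

# Variable-length memory-efficient attention in float16, now listed for XPU.
batch, heads, seq_len, head_dim = 1, 8, 64, 16
q = paddle.rand([batch, heads, seq_len, head_dim]).astype("float16")
k = paddle.rand([batch, heads, seq_len, head_dim]).astype("float16")
v = paddle.rand([batch, heads, seq_len, head_dim]).astype("float16")
seq_lens = paddle.full([batch], seq_len, dtype="int32")
mask = paddle.zeros([batch, 1, seq_len, seq_len], dtype="float16")
out = variable_length_memory_efficient_attention(
    q, k, v, seq_lens, seq_lens, mask, 1.0 / np.sqrt(head_dim)
)

# float16 group_norm now dispatches to the XPU kernel; per the kernel change,
# scale/bias are cast to float32 on device and mean/var outputs stay float32.
# (The functional group_norm API is assumed to exist in this Paddle version.)
x = paddle.rand([2, 8, 16, 16]).astype("float16")
y = paddle.nn.functional.group_norm(x, num_groups=4)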