From 2d5ed47e4a7dc5baf824f4bc7cfdffbe4a8308f7 Mon Sep 17 00:00:00 2001 From: "Meng,Chen" Date: Thu, 25 May 2023 09:29:20 +0000 Subject: [PATCH] fuse vit attention for faster-rcnn on BML --- paddle/fluid/framework/ir/CMakeLists.txt | 7 + .../framework/ir/graph_pattern_detector.cc | 75 +++ .../framework/ir/graph_pattern_detector.h | 27 + .../ir/mkldnn/self_attention_fuse_pass.cc | 150 ++++++ .../ir/mkldnn/self_attention_fuse_pass.h | 41 ++ .../inference/api/paddle_pass_builder.cc | 1 + paddle/fluid/operators/fused/CMakeLists.txt | 9 + .../operators/fused/scaled_dp_attention.h | 466 ++++++++++++++++++ .../operators/fused/self_dp_attention_op.cc | 124 +++++ .../operators/fused/self_dp_attention_op.h | 41 ++ test/mkldnn/test_fused_vit_attention.py | 74 +++ 11 files changed, 1015 insertions(+) create mode 100644 paddle/fluid/framework/ir/mkldnn/self_attention_fuse_pass.cc create mode 100644 paddle/fluid/framework/ir/mkldnn/self_attention_fuse_pass.h create mode 100644 paddle/fluid/operators/fused/scaled_dp_attention.h create mode 100644 paddle/fluid/operators/fused/self_dp_attention_op.cc create mode 100644 paddle/fluid/operators/fused/self_dp_attention_op.h create mode 100644 test/mkldnn/test_fused_vit_attention.py diff --git a/paddle/fluid/framework/ir/CMakeLists.txt b/paddle/fluid/framework/ir/CMakeLists.txt index f1583f5312f4a..62533102b29ed 100755 --- a/paddle/fluid/framework/ir/CMakeLists.txt +++ b/paddle/fluid/framework/ir/CMakeLists.txt @@ -200,6 +200,13 @@ if(WITH_MKLDNN) pass_library(multi_gru_seq_fuse_pass inference DIR mkldnn) pass_library(quant_dequant_mkldnn_pass inference DIR mkldnn) pass_library(compute_propagate_scales_mkldnn_pass inference DIR mkldnn) + pass_library(self_attention_fuse_pass inference DIR mkldnn) + if(WITH_AVX + AND AVX512F_FOUND + AND AVX512F_FLAG) + set_target_properties(self_attention_fuse_pass + PROPERTIES COMPILE_FLAGS "-mfma ${AVX512F_FLAG}") + endif() endif() if(WITH_IPU) diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.cc b/paddle/fluid/framework/ir/graph_pattern_detector.cc index 03f1d5bc40498..319d3e4f027ae 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detector.cc +++ b/paddle/fluid/framework/ir/graph_pattern_detector.cc @@ -2615,6 +2615,81 @@ PDNode *patterns::VitAttention::operator()(PDNode *in) { return reshape2_out; } +PDNode *patterns::SelfAttention::operator()(PDNode *in) { + in->AsInput(); + + std::unordered_set matmul_ops{"matmul", "matmul_v2"}; + auto transpose2_0_op = + pattern->NewNode(transpose2_0_op_repr())->assert_is_op("transpose2"); + auto transpose2_0_out = pattern->NewNode(transpose2_0_out_repr()) + ->assert_is_op_output("transpose2", "Out") + ->assert_is_op_input("slice", "Input") + ->AsIntermediate(); + auto slice_0_op = pattern->NewNode(slice_0_op_repr())->assert_is_op("slice"); + auto slice_0_out = pattern->NewNode(slice_0_out_repr()) + ->assert_is_op_output("slice", "Out") + ->assert_is_ops_input(matmul_ops, "X") + ->AsIntermediate(); + auto slice_1_op = pattern->NewNode(slice_1_op_repr())->assert_is_op("slice"); + auto slice_1_out = pattern->NewNode(slice_1_out_repr()) + ->assert_is_op_output("slice", "Out") + ->assert_is_op_input("transpose2", "X") + ->AsIntermediate(); + auto slice_2_op = pattern->NewNode(slice_2_op_repr())->assert_is_op("slice"); + auto slice_2_out = pattern->NewNode(slice_2_out_repr()) + ->assert_is_op_output("slice", "Out") + ->assert_is_ops_input(matmul_ops, "Y") + ->AsIntermediate(); + auto matmul_0_op = + pattern->NewNode(matmul_0_op_repr())->assert_is_ops(matmul_ops); + auto 
matmul_0_out = pattern->NewNode(matmul_0_out_repr()) + ->assert_is_ops_output(matmul_ops, "Out") + ->assert_is_op_input("transpose2", "X") + ->AsIntermediate(); + auto matmul_1_op = + pattern->NewNode(matmul_1_op_repr())->assert_is_ops(matmul_ops); + auto matmul_1_out = pattern->NewNode(matmul_1_out_repr()) + ->assert_is_ops_output(matmul_ops, "Out") + ->assert_is_op_input("softmax", "X") + ->AsIntermediate(); + auto transpose2_1_op = + pattern->NewNode(transpose2_1_op_repr())->assert_is_op("transpose2"); + auto transpose2_1_out = pattern->NewNode(transpose2_1_out_repr()) + ->assert_is_op_output("transpose2", "Out") + ->assert_is_ops_input(matmul_ops, "Y") + ->AsIntermediate(); + auto softmax_op = + pattern->NewNode(softmax_op_repr())->assert_is_op("softmax"); + auto softmax_out = pattern->NewNode(softmax_out_repr()) + ->assert_is_op_output("softmax", "Out") + ->assert_is_ops_input(matmul_ops, "X") + ->AsIntermediate(); + auto transpose2_2_op = + pattern->NewNode(transpose2_2_op_repr())->assert_is_op("transpose2"); + auto transpose2_2_out = pattern->NewNode(transpose2_2_out_repr()) + ->assert_is_op_output("transpose2", "Out") + ->AsOutput(); + transpose2_0_op->LinksFrom({in}); + transpose2_0_out->LinksFrom({transpose2_0_op}); + slice_0_op->LinksFrom({transpose2_0_out}); + slice_0_out->LinksFrom({slice_0_op}); + slice_1_op->LinksFrom({transpose2_0_out}); + slice_1_out->LinksFrom({slice_1_op}); + slice_2_op->LinksFrom({transpose2_0_out}); + slice_2_out->LinksFrom({slice_2_op}); + transpose2_1_op->LinksFrom({slice_1_out}); + transpose2_1_out->LinksFrom({transpose2_1_op}); + matmul_1_op->LinksFrom({slice_0_out, transpose2_1_out}); + matmul_1_out->LinksFrom({matmul_1_op}); + softmax_op->LinksFrom({matmul_1_out}); + softmax_out->LinksFrom({softmax_op}); + matmul_0_op->LinksFrom({softmax_out, slice_2_out}); + matmul_0_out->LinksFrom({matmul_0_op}); + transpose2_2_op->LinksFrom({matmul_0_out}); + transpose2_2_out->LinksFrom({transpose2_2_op}); + return transpose2_2_out; +} + PDNode *patterns::ConvElementwiseadd2Act::operator()( PDNode *conv_in, const std::unordered_set &conv_act_set) { auto conv_op = pattern->NewNode(conv_op_repr())->assert_is_op("conv2d"); diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.h b/paddle/fluid/framework/ir/graph_pattern_detector.h index 1be8e13e2ec74..7e9a6ed6f4383 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detector.h +++ b/paddle/fluid/framework/ir/graph_pattern_detector.h @@ -1491,6 +1491,33 @@ struct VitAttention : public PatternBase { PATTERN_DECL_NODE(reshape2_out); }; +// self_attention in vit +struct SelfAttention : public PatternBase { + SelfAttention(PDPattern* pattern, const std::string& name_scope) + : PatternBase(pattern, name_scope, "vit_block") {} + + PDNode* operator()(PDNode* in); + + PATTERN_DECL_NODE(transpose2_0_op); + PATTERN_DECL_NODE(transpose2_0_out); + PATTERN_DECL_NODE(transpose2_1_op); + PATTERN_DECL_NODE(transpose2_1_out); + PATTERN_DECL_NODE(transpose2_2_op); + PATTERN_DECL_NODE(transpose2_2_out); + PATTERN_DECL_NODE(matmul_0_op); + PATTERN_DECL_NODE(matmul_0_out); + PATTERN_DECL_NODE(matmul_1_op); + PATTERN_DECL_NODE(matmul_1_out); + PATTERN_DECL_NODE(slice_0_op); + PATTERN_DECL_NODE(slice_0_out); + PATTERN_DECL_NODE(slice_1_op); + PATTERN_DECL_NODE(slice_1_out); + PATTERN_DECL_NODE(slice_2_op); + PATTERN_DECL_NODE(slice_2_out); + PATTERN_DECL_NODE(softmax_op); + PATTERN_DECL_NODE(softmax_out); +}; + // Conv + ElementwiseAdd + an activation // This pattern can further fuse the conv related ops after the conv+bn 
fusion. struct ConvElementwiseaddAct : public PatternBase { diff --git a/paddle/fluid/framework/ir/mkldnn/self_attention_fuse_pass.cc b/paddle/fluid/framework/ir/mkldnn/self_attention_fuse_pass.cc new file mode 100644 index 0000000000000..c9bb1ba8e4d9e --- /dev/null +++ b/paddle/fluid/framework/ir/mkldnn/self_attention_fuse_pass.cc @@ -0,0 +1,150 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/framework/ir/mkldnn/self_attention_fuse_pass.h" + +#include +#include "paddle/fluid/framework/lod_tensor.h" +#include "paddle/fluid/framework/op_version_registry.h" +#include "paddle/fluid/string/pretty_log.h" +#include "paddle/phi/kernels/funcs/blas/blas.h" + +#define GET_IR_NODE(node__) GET_IR_NODE_FROM_SUBGRAPH(node__, node__, pattern); +#define GET_NODES \ + GET_IR_NODE(transpose2_0_op); \ + GET_IR_NODE(transpose2_0_out); \ + GET_IR_NODE(slice_0_op); \ + GET_IR_NODE(slice_0_out); \ + GET_IR_NODE(slice_1_op); \ + GET_IR_NODE(slice_1_out); \ + GET_IR_NODE(slice_2_op); \ + GET_IR_NODE(slice_2_out); \ + GET_IR_NODE(matmul_0_op); \ + GET_IR_NODE(matmul_0_out); \ + GET_IR_NODE(matmul_1_op); \ + GET_IR_NODE(matmul_1_out); \ + GET_IR_NODE(transpose2_1_op); \ + GET_IR_NODE(transpose2_1_out); \ + GET_IR_NODE(softmax_op); \ + GET_IR_NODE(softmax_out); \ + GET_IR_NODE(transpose2_2_op); \ + GET_IR_NODE(transpose2_2_out); + +namespace paddle { +namespace framework { +namespace ir { + +using string::PrettyLogDetail; + +void SelfAttentionFusePass::ApplyImpl(ir::Graph* graph) const { +#if !defined(__AVX512F__) || !defined(PADDLE_WITH_MKLML) || \ + !defined(PADDLE_WITH_MKLDNN) + LOG(WARNING) << "No-avx512 or MKL supported!"; + return; +#endif + // do something; + GraphPatternDetector gpd; + const std::string pattern_name = "self_attention_fuse"; + FusePassBase::Init(pattern_name, graph); + + // pattern + PDNode* x = gpd.mutable_pattern() + ->NewNode("x") + ->assert_is_op_input("transpose2", "X") + ->AsInput(); + patterns::SelfAttention pattern(gpd.mutable_pattern(), pattern_name); + pattern(x); + + int fusion_count = 0; + auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, + Graph* g) { + GET_NODES; + // do something; + OpDesc desc(transpose2_0_op->Op()->Block()); + desc.SetType("self_dp_attention"); + desc.SetInput("X", {subgraph.at(x)->Name()}); + desc.SetOutput("Out", {transpose2_2_out->Name()}); + + std::vector in_shape = subgraph.at(x)->Var()->GetShape(); + std::vector shape = transpose2_0_out->Var()->GetShape(); + // in shape should be [batch_size, seq_len, 3, num_heads, head_size] + if (in_shape.size() != 5 || in_shape[2] != 3 || shape.size() != 5 || + shape[0] != 3 || shape[2] != in_shape[3]) { + LOG(WARNING) << "Self-attention shape mismatch!"; + return; + } + desc.SetAttr("head_number", static_cast(shape[2])); + float alpha = 1.0; + if (matmul_1_op->Op()->HasAttr("alpha")) + alpha = PADDLE_GET_CONST(float, matmul_1_op->Op()->GetAttr("alpha")); + desc.SetAttr("alpha", alpha); + + // 
Create a new node for the fused op. + auto self_attention_node = graph->CreateOpNode(&desc); + + // Link inputs and outputs. + PADDLE_ENFORCE_NE(subgraph.count(x), + 0, + platform::errors::NotFound( + "Detector did not find input x of self attention.")); + + IR_NODE_LINK_TO(subgraph.at(x), self_attention_node); // Input + IR_NODE_LINK_TO(self_attention_node, transpose2_2_out); // Output + + // Delete the unneeded nodes. + std::unordered_set marked_nodes({transpose2_0_op, + transpose2_0_out, + slice_0_op, + slice_0_out, + slice_1_op, + slice_1_out, + slice_2_op, + slice_2_out, + matmul_0_op, + matmul_0_out, + matmul_1_op, + matmul_1_out, + transpose2_1_op, + transpose2_1_out, + softmax_op, + softmax_out, + transpose2_2_op}); + + GraphSafeRemoveNodes(graph, marked_nodes); + ++fusion_count; + }; + gpd(graph, handler); + AddStatis(fusion_count); + if (!Has("disable_logs") || !Get("disable_logs")) { + PrettyLogDetail( + "--- fused %d self attention (of scaled_dp_attention) with %s", + fusion_count, + pattern_name); + } +} + +} // namespace ir +} // namespace framework +} // namespace paddle + +REGISTER_PASS(self_attention_fuse_pass, + paddle::framework::ir::SelfAttentionFusePass); +REGISTER_PASS_CAPABILITY(self_attention_fuse_pass) + .AddCombination( + paddle::framework::compatible::OpVersionComparatorCombination() + .EQ("transpose2", 0) + .EQ("slice", 0) + .EQ("scale", 0) + .EQ("softmax", 0) + .EQ("matmul_v2", 0)); diff --git a/paddle/fluid/framework/ir/mkldnn/self_attention_fuse_pass.h b/paddle/fluid/framework/ir/mkldnn/self_attention_fuse_pass.h new file mode 100644 index 0000000000000..ade48f398e3b6 --- /dev/null +++ b/paddle/fluid/framework/ir/mkldnn/self_attention_fuse_pass.h @@ -0,0 +1,41 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include + +#include "paddle/fluid/framework/ir/fuse_pass_base.h" +#include "paddle/fluid/framework/ir/graph.h" +#include "paddle/fluid/framework/ir/graph_pattern_detector.h" + +namespace paddle { +namespace framework { +namespace ir { + +// Fusing of self-attetion structure + +class Graph; + +class SelfAttentionFusePass : public FusePassBase { + public: + virtual ~SelfAttentionFusePass() {} + + protected: + void ApplyImpl(ir::Graph* graph) const override; +}; + +} // namespace ir +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/inference/api/paddle_pass_builder.cc b/paddle/fluid/inference/api/paddle_pass_builder.cc index 423256c7ec7d3..637256a693f42 100644 --- a/paddle/fluid/inference/api/paddle_pass_builder.cc +++ b/paddle/fluid/inference/api/paddle_pass_builder.cc @@ -367,6 +367,7 @@ void CpuPassStrategy::EnableMKLDNN() { "fc_mkldnn_pass", "fc_act_mkldnn_fuse_pass", "fc_elementwise_add_mkldnn_fuse_pass", // + "self_attention_fuse_pass", // "batch_norm_act_fuse_pass", // "softplus_activation_onednn_fuse_pass", // "shuffle_channel_mkldnn_detect_pass", // diff --git a/paddle/fluid/operators/fused/CMakeLists.txt b/paddle/fluid/operators/fused/CMakeLists.txt index ebb6e747f16be..bf1ae2edea1f5 100755 --- a/paddle/fluid/operators/fused/CMakeLists.txt +++ b/paddle/fluid/operators/fused/CMakeLists.txt @@ -11,6 +11,7 @@ register_operators( fusion_conv_inception_op fused_fc_elementwise_layernorm_op multihead_matmul_op + self_dp_attention_op skip_layernorm_op yolo_box_head_op yolo_box_post_op @@ -33,6 +34,14 @@ register_operators( # fusion_gru_op does not have CUDA kernel op_library(fusion_gru_op) op_library(fusion_lstm_op) +if(WITH_AVX + AND AVX512F_FOUND + AND AVX512F_FLAG + AND WITH_MKL) + op_library(self_dp_attention_op) + set_target_properties(self_dp_attention_op PROPERTIES COMPILE_FLAGS + "-mfma ${AVX512F_FLAG}") +endif() if(WITH_XPU) op_library(resnet_basic_block_op) diff --git a/paddle/fluid/operators/fused/scaled_dp_attention.h b/paddle/fluid/operators/fused/scaled_dp_attention.h new file mode 100644 index 0000000000000..bc103caf4bb9d --- /dev/null +++ b/paddle/fluid/operators/fused/scaled_dp_attention.h @@ -0,0 +1,466 @@ +/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
*/ + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#ifdef PADDLE_WITH_MKLDNN +#include "dnnl.hpp" //NOLINT +#endif + +namespace paddle { +namespace operators { + +template +void arraycpy(T* dst, const Tt* src, int n) { +#ifdef PADDLE_WITH_MKLML +#pragma omp simd +#endif + for (int i = 0; i < n; i++) { + dst[i] = static_cast(src[i]); + } +} + +// batchs x tokens x 3 x head x heads -> 3 x batchs x head x tokens x heads (2 +// 0 3 1 4) +template +void transpose_before_bmm1(const T* qkvBuffer, + Tt* qkvTransBuffer, + int batchSize, + int tokenSize, + int headNum, + int headSize) { + int hiddenSize = headNum * headSize; + int blocksize = tokenSize * hiddenSize; // dst buffer stride in each batch + + const T* qBuffer = qkvBuffer; + const T* kBuffer = qkvBuffer + hiddenSize; + const T* vBuffer = qkvBuffer + hiddenSize * 2; + + Tt* q_buffer = qkvTransBuffer; + Tt* k_buffer = qkvTransBuffer + batchSize * blocksize; + Tt* v_buffer = qkvTransBuffer + batchSize * blocksize * 2; + + int bmHead = headNum; + int cols_per_bmHead = hiddenSize / headNum; // 768/12 = 64 + +#ifdef PADDLE_WITH_MKLML +#pragma omp parallel for collapse(3) +#endif + for (int i = 0; i < batchSize; i++) { + for (int k = 0; k < bmHead; k++) { + for (int j = 0; j < tokenSize; j++) { + const T* q_src_each_batch = + reinterpret_cast(qBuffer) + blocksize * 3 * i; + const T* k_src_each_batch = + reinterpret_cast(kBuffer) + blocksize * 3 * i; + const T* v_src_each_batch = + reinterpret_cast(vBuffer) + blocksize * 3 * i; + + int dst_offset_each_bmHead = k * tokenSize * cols_per_bmHead; + int src_offset_each_line = k * cols_per_bmHead; + + int dst_offset_each_line = j * cols_per_bmHead; + int src_offset_each_bmHead = j * hiddenSize * 3; + + Tt* q_dst_each_line = q_buffer + i * blocksize + + dst_offset_each_bmHead + dst_offset_each_line; + const T* q_src_each_line = + q_src_each_batch + src_offset_each_bmHead + src_offset_each_line; + + Tt* k_dst_each_line = k_buffer + i * blocksize + + dst_offset_each_bmHead + dst_offset_each_line; + const T* k_src_each_line = + k_src_each_batch + src_offset_each_bmHead + src_offset_each_line; + + Tt* v_dst_each_line = v_buffer + i * blocksize + + dst_offset_each_bmHead + dst_offset_each_line; + const T* v_src_each_line = + v_src_each_batch + src_offset_each_bmHead + src_offset_each_line; + arraycpy(q_dst_each_line, q_src_each_line, cols_per_bmHead); + arraycpy(k_dst_each_line, k_src_each_line, cols_per_bmHead); + arraycpy(v_dst_each_line, v_src_each_line, cols_per_bmHead); + } + } + } +} + +// batchs x head x tokens x heads -> batchs x tokens x head x heads (0 2 1 3) +template +void transpose_after_bmm2(T* Buffer, + Tt* TransBuffer, + int batchSize, + int tokenSize, + int headNum, + int headSize) { + int hiddenSize = headNum * headSize; + int blocksize = tokenSize * hiddenSize; // dst buffer stride in each batch + + int bmHead = headNum; + int cols_per_bmHead = hiddenSize / headNum; // 768/12 = 64 + +#ifdef PADDLE_WITH_MKLML +#pragma omp parallel for collapse(2) +#endif + for (int i = 0; i < batchSize; i++) { + for (int k = 0; k < tokenSize; k++) { + int src_offset_each_head = k * cols_per_bmHead; + int dst_offset_each_line = k * hiddenSize; + + for (int j = 0; j < bmHead; j++) { + int src_offset_each_line = j * tokenSize * cols_per_bmHead; + int dst_offset_each_head = j * cols_per_bmHead; + + Tt* q_dst_each_line = TransBuffer + dst_offset_each_head + + dst_offset_each_line + i * blocksize; + const T* q_src_each_line = Buffer + 
src_offset_each_line + + src_offset_each_head + i * blocksize; + + arraycpy(q_dst_each_line, q_src_each_line, cols_per_bmHead); + } + } + } +} + +// C = A * B +// bTranspose: B need to be transposed or not +void sgemm(const float* A, + const float* B, + float* C, + int m, + int n, + int k, + bool transa, + bool transb) { +#ifdef PADDLE_WITH_MKLDNN + int lda = (transa ? m : k); + int ldb = (transb ? k : n); + int ldc = n; + float alpha = 1; + float beta = 0; + char ta[] = "N"; + char tb[] = "N"; + if (transa) ta[0] = 'T'; + if (transb) tb[0] = 'T'; + + dnnl_sgemm(ta[0], tb[0], m, n, k, alpha, A, lda, B, ldb, beta, C, ldc); +#else + LOG(ERROR) << "scaled_dp_atten not supported without WITH_MKL!"; +#endif +} + +#if defined(__AVX512F__) +// exp based-on jit code +static inline __m512 vexp(const __m512& _x) { + __m512 p16f_1 = _mm512_set1_ps(1.0f); + __m512 p16f_half = _mm512_set1_ps(0.5f); + __m512 p16f_127 = _mm512_set1_ps(127.f); + __m512 p16f_exp_hi = _mm512_set1_ps(88.3762626647950f); + __m512 p16f_exp_lo = _mm512_set1_ps(-88.3762626647949f); + + __m512 p16f_cephes_LOG2EF = _mm512_set1_ps(1.44269504088896341f); + + __m512 p16f_cephes_exp_p0 = _mm512_set1_ps(1.9875691500E-4f); + __m512 p16f_cephes_exp_p1 = _mm512_set1_ps(1.3981999507E-3f); + __m512 p16f_cephes_exp_p2 = _mm512_set1_ps(8.3334519073E-3f); + __m512 p16f_cephes_exp_p3 = _mm512_set1_ps(4.1665795894E-2f); + __m512 p16f_cephes_exp_p4 = _mm512_set1_ps(1.6666665459E-1f); + __m512 p16f_cephes_exp_p5 = _mm512_set1_ps(5.0000001201E-1f); + + // Clamp x. + __m512 x = _mm512_max_ps(_mm512_min_ps(_x, p16f_exp_hi), p16f_exp_lo); + + // Express exp(x) as exp(m*ln(2) + r), start by extracting + // m = floor(x/ln(2) + 0.5). + __m512 m = _mm512_floor_ps(_mm512_fmadd_ps(x, p16f_cephes_LOG2EF, p16f_half)); + + // Get r = x - m*ln(2). + __m512 p16f_nln2 = _mm512_set1_ps(-0.6931471805599453f); + __m512 r = _mm512_fmadd_ps(m, p16f_nln2, x); + + __m512 r2 = _mm512_mul_ps(r, r); + + __m512 y = p16f_cephes_exp_p0; + y = _mm512_fmadd_ps(y, r, p16f_cephes_exp_p1); + y = _mm512_fmadd_ps(y, r, p16f_cephes_exp_p2); + y = _mm512_fmadd_ps(y, r, p16f_cephes_exp_p3); + y = _mm512_fmadd_ps(y, r, p16f_cephes_exp_p4); + y = _mm512_fmadd_ps(y, r, p16f_cephes_exp_p5); + y = _mm512_fmadd_ps(y, r2, r); + y = _mm512_add_ps(y, p16f_1); + + // Build emm0 = 2^m. + __m512i emm0 = _mm512_cvttps_epi32(_mm512_add_ps(m, p16f_127)); + emm0 = _mm512_slli_epi32(emm0, 23); + + // Return 2^m * exp(r). + return _mm512_max_ps(_mm512_mul_ps(y, _mm512_castsi512_ps(emm0)), _x); +} + +// need to do for res. +void softmax_sum_max(float* AB, + float* sum, + float* max, + float* pre_sum, + float* pre_max, + float refac, + int m, + int k) { + assert(k % 16 == 0); + float max_val = std::numeric_limits::lowest(); + __m512 vrefac = _mm512_set1_ps(refac); + for (int i = 0; i < m; ++i) { + float* buf = AB + i * k; + // max val for avoiding inf and nan + __m512 vmax = _mm512_set1_ps(max_val); + for (int off = 0; off < k; off += 16) { + int remain = k - off; + __mmask16 mask = (remain >= 16 ? 0xffff : (1 << remain) - 1); + __m512 vx = _mm512_maskz_loadu_ps(mask, buf + off); + + vmax = _mm512_mask_max_ps(vmax, mask, vmax, vx); + } + float _max = _mm512_reduce_max_ps(vmax); + + _max *= refac; + _max = _max > max[i] ? _max : max[i]; + __m512 merr = _mm512_set1_ps(max[i] - _max); + merr = vexp(merr); + max[i] = _max; + + // exp and get sum + __m512 vsum = _mm512_set1_ps(0); + vmax = _mm512_set1_ps(_max); + for (int off = 0; off < k; off += 16) { + int remain = k - off; + __mmask16 mask = (remain >= 16 ? 
0xffff : (1 << remain) - 1); + + __m512 vx = _mm512_maskz_loadu_ps(mask, buf + off); + vx = vexp(vx * vrefac - vmax); + + _mm512_mask_storeu_ps(buf + off, mask, vx); + + vsum = _mm512_mask_add_ps(vsum, mask, vsum, vx); + } + float _sum = _mm512_reduce_add_ps(vsum); + float fac = _mm512_cvtss_f32(merr); + sum[i] = sum[i] * fac + _sum; + _sum = sum[i]; + + // Compute exp/sum(exp) and store + __m512 vrsum = _mm512_set1_ps(1.0f / _sum); + for (int off = 0; off < k; off += 16) { + int remain = k - off; + __mmask16 mask = (remain >= 16 ? 0xffff : (1 << remain) - 1); + + __m512 vx = _mm512_maskz_loadu_ps(mask, buf + off); + vx = vx * vrsum; + + _mm512_mask_storeu_ps(buf + off, mask, vx); + } + } +} + +void update_out_blk(float* output, + const float* exp_ABC, + float* pre_sum, + float* sum, + float* pre_max, + float* max, + int m, + int n) { + assert(n % 16 == 0); + for (int i = 0; i < m; ++i) { + const float* buf = exp_ABC + i * n; + float* outbuf = output + i * n; + __m512 merr = _mm512_set1_ps(pre_max[i] - max[i]); + merr = vexp(merr); + __m512 vfac = _mm512_set1_ps(pre_sum[i] / sum[i]); + for (int off = 0; off < n; off += 16) { + __m512 vout = _mm512_loadu_ps(outbuf + off); + __m512 vabc = _mm512_loadu_ps(buf + off); + __m512 vupt = vout * merr * vfac + vabc; + _mm512_storeu_ps(outbuf + off, vupt); + } + pre_sum[i] = sum[i]; + pre_max[i] = max[i]; + } +} +#endif + +// hard code: axis = 1 +// sum += sum(exp(A[i])) +// output = output * pre_sum / sum + (exp(A) / sum) x B +// pre_sum = sum +void incremental_tile_attention(const float* A, + const float* B, + const float* C, + int m, + int n, + int k, + float* pre_sum, + float* sum, + float* pre_max, + float* max, + float refac, + float* AB, + float* exp_ABC, + float* output) { + sgemm(A, B, AB, m, k, n, false, true); + softmax_sum_max(AB, sum, max, pre_sum, pre_max, refac, m, k); + sgemm(AB, C, exp_ABC, m, n, k, false, false); + update_out_blk(output, exp_ABC, pre_sum, sum, pre_max, max, m, n); +} + +// scaled dot-product attention: bmm1 + softmax + bmm2 +void scaled_dp_attention(const float* query, + const float* key, + const float* value, + float scale, + int batch_size, + int itsize, + int otsize, + int num_head, + int head_size, + float* output) { + // output = trans(softmax(query * trans(key)) * value) + int iblk = std::min(512, itsize / 1); + int oblk = std::min(512, otsize / 1); + float refac = scale; + assert(itsize % iblk == 0); + assert(otsize % oblk == 0); + +#ifdef PADDLE_WITH_MKLML + int nth = omp_get_max_threads(); +#else + int nth = 1; +#endif + + float** pre_sum; + float** sum; + float** pre_max; + float** max; + float** qk_arr; + float** exp_qkv_arr; + pre_sum = new float*[nth]; + sum = new float*[nth]; + pre_max = new float*[nth]; + max = new float*[nth]; + qk_arr = new float*[nth]; + exp_qkv_arr = new float*[nth]; + for (int i = 0; i < nth; ++i) { + pre_sum[i] = new float[iblk]; + sum[i] = new float[iblk]; + pre_max[i] = new float[iblk]; + max[i] = new float[iblk]; + qk_arr[i] = new float[iblk * oblk]; + exp_qkv_arr[i] = new float[iblk * head_size]; + } + +#ifdef PADDLE_WITH_MKLML +#pragma omp parallel for collapse(3) +#endif + for (int i = 0; i < batch_size; ++i) { + for (int j = 0; j < num_head; ++j) { + for (int m = 0; m < itsize; m += iblk) { +#ifdef PADDLE_WITH_MKLML + int tid = omp_get_thread_num(); +#else + int tid = 0; +#endif + int ooffset = + i * num_head * otsize * head_size + j * otsize * head_size; + const float* k = key + ooffset; + const float* v = value + ooffset; + + int q_rblk = std::min(iblk, itsize - m); + 
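+        // Each tile of q_rblk query rows keeps per-row running max/sum
+        // (online softmax): every key/value block rescales the partial
+        // output in update_out_blk, so the full [itsize x otsize] score
+        // matrix is never materialized (qk_arr only holds iblk x oblk).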
int ioffset = + i * num_head * otsize * head_size + j * otsize * head_size; + const float* q = query + ioffset + m * head_size; + float* out = output + ioffset + m * head_size; + + // reset out + for (int ii = 0; ii < q_rblk; ++ii) { +#ifdef PADDLE_WITH_MKLML +#pragma omp simd +#endif + for (int jj = 0; jj < head_size; ++jj) { + out[ii * head_size + jj] = 0; // reset output + } + } + // reset sum +#ifdef PADDLE_WITH_MKLML +#pragma omp simd +#endif + for (int ii = 0; ii < q_rblk; ++ii) { + pre_sum[tid][ii] = 0; + sum[tid][ii] = 0; + pre_max[tid][ii] = std::numeric_limits::lowest(); + max[tid][ii] = std::numeric_limits::lowest(); + } + // + for (int b = 0; b < otsize; b += oblk) { + int kv_rblk = std::min(oblk, otsize - b); + const float* blk_k = k + b * head_size; + const float* blk_v = v + b * head_size; + + incremental_tile_attention(q, + blk_k, + blk_v, + q_rblk, + head_size, + kv_rblk, + pre_sum[tid], + sum[tid], + pre_max[tid], + max[tid], + refac, + qk_arr[tid], + exp_qkv_arr[tid], + out); + } + } + } + } + + for (int i = 0; i < nth; ++i) { + delete[] pre_sum[i]; + delete[] sum[i]; + delete[] pre_max[i]; + delete[] max[i]; + delete[] qk_arr[i]; + delete[] exp_qkv_arr[i]; + } + delete[] pre_sum; + delete[] sum; + delete[] pre_max; + delete[] max; + delete[] qk_arr; + delete[] exp_qkv_arr; + + return; +} + +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/fused/self_dp_attention_op.cc b/paddle/fluid/operators/fused/self_dp_attention_op.cc new file mode 100644 index 0000000000000..04c7424a80dc5 --- /dev/null +++ b/paddle/fluid/operators/fused/self_dp_attention_op.cc @@ -0,0 +1,124 @@ +/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
*/ + +#include "paddle/fluid/operators/fused/self_dp_attention_op.h" +#include "paddle/fluid/operators/fused/scaled_dp_attention.h" + +namespace paddle { +namespace operators { + +void SelfDPAttenOp::InferShape(framework::InferShapeContext* ctx) const { + OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "SelfDPAtten"); + OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "SelfDPAtten"); + + auto dim_input = ctx->GetInputDim("X"); + PADDLE_ENFORCE_EQ(dim_input.size(), + 5, + platform::errors::InvalidArgument( + "The size of input X dims should be 5, " + "[batchsize, tokensize, 3, nhead, headsize] " + ", but now Input X dim is:[%s] ", + dim_input)); + PADDLE_ENFORCE_EQ(dim_input[4] % 16, + 0, + platform::errors::InvalidArgument( + "The last dim of input X should be a multiple of 16, " + ", but now the dim is:[%d] " + "Please remove self_attention_fuse_pass from the lists", + dim_input[4])); + framework::DDim out_dims( + {dim_input[0], dim_input[1], dim_input[3], dim_input[4]}); + ctx->SetOutputDim("Out", out_dims); + ctx->ShareLoD("X", /*->*/ "Out"); +} + +phi::KernelKey SelfDPAttenOp::GetExpectedKernelType( + const framework::ExecutionContext& ctx) const { + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.GetPlace()); +} + +void SelfDPAttenOpMaker::Make() { + AddInput("X", "(LoDTensor) Input tensors of this operator."); + AddOutput("Out", "(LoDTensor) Output tensor of this operator."); + AddAttr("alpha", "The scale of Out").SetDefault(1.0f); + AddAttr("head_number", "The number of heads of the matrix") + .SetDefault(1); + AddComment(R"DOC( + Multihead Self-scaled-dp-Attention Operator. +)DOC"); +} + +template +class SelfDPAttenKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + using DeviceContext = phi::CPUContext; + auto* in = ctx.Input("X"); + auto* out = ctx.Output("Out"); + auto place = ctx.GetPlace(); + auto* input_d = in->data(); + auto* output_d = out->mutable_data(place); + float scale = static_cast(ctx.Attr("alpha")); + int head_number = ctx.Attr("head_number"); + auto input_dims = in->dims(); + // in shouble be (batch * seq * 3 * head_num * head_size) + // out shouble be (batch * seq * head_num * head_size) + int batch_size = input_dims[0]; + int seq_len = input_dims[1]; + int head_size = input_dims[4]; + + auto& dev_ctx = ctx.template device_context(); + phi::DenseTensor temp1 = + ctx.AllocateTmpTensor(input_dims, dev_ctx); + float* trans_input = temp1.mutable_data(place); + phi::DenseTensor temp2 = + ctx.AllocateTmpTensor(input_dims, dev_ctx); + float* trans_output = temp2.mutable_data(place); + + transpose_before_bmm1( + input_d, trans_input, batch_size, seq_len, head_number, head_size); + float* query = trans_input; + float* key = trans_input + batch_size * head_number * seq_len * head_size; + float* value = + trans_input + batch_size * head_number * seq_len * head_size * 2; + + scaled_dp_attention(query, + key, + value, + scale, + batch_size, + seq_len, + seq_len, + head_number, + head_size, + trans_output); + transpose_after_bmm2( + trans_output, output_d, batch_size, seq_len, head_number, head_size); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OPERATOR(self_dp_attention, + ops::SelfDPAttenOp, + ops::SelfDPAttenOpMaker); + +REGISTER_OP_KERNEL(self_dp_attention, + CPU, + phi::CPUPlace, + ops::SelfDPAttenKernel, + ops::SelfDPAttenKernel); diff --git a/paddle/fluid/operators/fused/self_dp_attention_op.h 
b/paddle/fluid/operators/fused/self_dp_attention_op.h new file mode 100644 index 0000000000000..f9c81ec7f9c7b --- /dev/null +++ b/paddle/fluid/operators/fused/self_dp_attention_op.h @@ -0,0 +1,41 @@ +/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. */ + +#pragma once +#include "paddle/fluid/framework/op_registry.h" + +namespace paddle { +namespace operators { + +using LoDTensor = phi::DenseTensor; +using Tensor = phi::DenseTensor; + +class SelfDPAttenOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override; + + protected: + phi::KernelKey GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override; +}; + +class SelfDPAttenOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override; +}; + +} // namespace operators +} // namespace paddle diff --git a/test/mkldnn/test_fused_vit_attention.py b/test/mkldnn/test_fused_vit_attention.py new file mode 100644 index 0000000000000..c3718886544ef --- /dev/null +++ b/test/mkldnn/test_fused_vit_attention.py @@ -0,0 +1,74 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
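+# This test builds the unfused ViT attention subgraph (transpose2 + slice +
+# matmul + softmax), applies self_attention_fuse_pass to its graph, and
+# checks that the fused self_dp_attention op produces the same output.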
+
+
+import numpy as np
+
+import paddle
+import paddle.incubate
+from paddle.fluid import core
+
+paddle.enable_static()
+np.random.seed(0)
+
+
+def test_fuse_vit_attention():
+    place = paddle.CPUPlace()
+    program = paddle.static.Program()
+    startup_program = paddle.static.Program()
+    batch_size = 1
+    token_size = 4097
+    hidden_size = 768
+    num_heads = 12
+    dtype = np.float32
+    with paddle.static.program_guard(program, startup_program):
+        x = paddle.static.data(
+            "x", [batch_size, token_size, hidden_size * 3], dtype=dtype
+        )
+        qkv = x.reshape(
+            (batch_size, token_size, 3, num_heads, hidden_size // num_heads)
+        ).transpose((2, 0, 3, 1, 4))
+
+        q, k, v = qkv[0], qkv[1], qkv[2]
+
+        attn = q.matmul(k.transpose((0, 1, 3, 2)))
+
+        attn = paddle.nn.functional.softmax(attn, axis=-1)
+
+        out = (
+            (attn.matmul(v))
+            .transpose((0, 2, 1, 3))
+            .reshape((-1, token_size, hidden_size))
+        )
+
+    graph = core.Graph(program.desc)
+    core.get_pass("self_attention_fuse_pass").apply(graph)
+    after_program = paddle.fluid.framework.IrGraph(graph).to_program()
+    exe = paddle.static.Executor(place)
+    exe.run(startup_program)
+
+    feed = {
+        "x": np.random.randn(batch_size, token_size, hidden_size * 3).astype(
+            dtype
+        )
+    }
+    before_out = exe.run(program, feed=feed, fetch_list=[out.name])
+    after_out = exe.run(after_program, feed=feed, fetch_list=[out.name])
+    np.testing.assert_allclose(
+        before_out[0], after_out[0], rtol=1e-05, atol=0.005
+    )
+
+
+if __name__ == '__main__':
+    test_fuse_vit_attention()
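
For reference, the tiled softmax in scaled_dp_attention.h (incremental_tile_attention, softmax_sum_max and update_out_blk) is an online-softmax scheme: each key/value block updates per-row running max/sum statistics and rescales the partial output. The NumPy sketch below is illustrative only and not part of the patch; the names naive_attention and tiled_attention are invented here, and it assumes a single [m, d] query block against [n, d] keys/values (one batch, one head).

import numpy as np


def naive_attention(q, k, v, scale):
    # Reference: softmax(q @ k^T * scale) @ v over the full score matrix.
    s = (q @ k.T) * scale
    p = np.exp(s - s.max(axis=1, keepdims=True))
    return (p / p.sum(axis=1, keepdims=True)) @ v


def tiled_attention(q, k, v, scale, blk=4):
    # Same result, but keys/values are visited in tiles of `blk` rows while
    # per-row running max/sum rescale the partial output (online softmax),
    # loosely mirroring incremental_tile_attention in scaled_dp_attention.h.
    m, d = q.shape
    out = np.zeros((m, d), dtype=np.float32)
    run_max = np.full(m, -np.inf, dtype=np.float32)
    run_sum = np.zeros(m, dtype=np.float32)
    for b in range(0, k.shape[0], blk):
        kb, vb = k[b:b + blk], v[b:b + blk]
        s = (q @ kb.T) * scale                     # scores for this tile only
        new_max = np.maximum(run_max, s.max(axis=1))
        corr = np.exp(run_max - new_max)           # rescale previous partials
        p = np.exp(s - new_max[:, None])
        new_sum = run_sum * corr + p.sum(axis=1)
        out = (out * (run_sum * corr / new_sum)[:, None]
               + (p / new_sum[:, None]) @ vb)
        run_max, run_sum = new_max, new_sum
    return out


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    q, k, v = (rng.standard_normal((16, 8)).astype(np.float32) for _ in range(3))
    np.testing.assert_allclose(
        tiled_attention(q, k, v, 8 ** -0.5),
        naive_attention(q, k, v, 8 ** -0.5),
        rtol=1e-5,
        atol=1e-5,
    )

Keeping only per-row running statistics bounds the per-thread scratch buffers (qk_arr is iblk x oblk floats) and keeps exp() numerically stable, which is why the kernel rescales previous partial sums block by block instead of computing one softmax over the whole score matrix.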