From 1a8786cf051cd7b70711e22ee1ca1e5c0d1b11e4 Mon Sep 17 00:00:00 2001 From: Li Min <11663212+limin2021@users.noreply.github.com> Date: Tue, 23 Nov 2021 20:19:16 +0800 Subject: [PATCH 1/4] Add support bias is none for fused_attention op. (#37411) Add support for bias is none for fused_attention op. --- .../operators/fused/fused_attention_op.cc | 77 ++++++++----- .../operators/fused/fused_attention_op.cu | 107 +++++++++++++----- .../unittests/test_fused_attention_op.py | 57 ++++++++-- .../nn/functional/fused_transformer.py | 3 + 4 files changed, 173 insertions(+), 71 deletions(-) diff --git a/paddle/fluid/operators/fused/fused_attention_op.cc b/paddle/fluid/operators/fused/fused_attention_op.cc index 11601a5ce40d5a..39c7e52cc465c7 100644 --- a/paddle/fluid/operators/fused/fused_attention_op.cc +++ b/paddle/fluid/operators/fused/fused_attention_op.cc @@ -28,12 +28,8 @@ class FusedAttentionOp : public framework::OperatorWithKernel { void InferShape(framework::InferShapeContext *ctx) const override { OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "FusedAttentionOp"); OP_INOUT_CHECK(ctx->HasInput("QKVW"), "Input", "QKVW", "FusedAttentionOp"); - OP_INOUT_CHECK(ctx->HasInput("QKVBias"), "Input", "QKVBias", - "FusedAttentionOp"); OP_INOUT_CHECK(ctx->HasInput("OutLinearW"), "Input", "OutLinearW", "FusedAttentionOp"); - OP_INOUT_CHECK(ctx->HasInput("OutLinearBias"), "Input", "OutLinearBias", - "FusedAttentionOp"); if (ctx->Attrs().Get("pre_layer_norm") == true) { OP_INOUT_CHECK(ctx->HasOutput("LnMean"), "Output", "LnMean", @@ -54,8 +50,10 @@ class FusedAttentionOp : public framework::OperatorWithKernel { // qkv_out: [batch_size, seq_len, 3, num_head, dim_head] OP_INOUT_CHECK(ctx->HasOutput("QKVOut"), "Output", "QKVOut", "FusedAttentionOp"); - OP_INOUT_CHECK(ctx->HasOutput("QKVBiasOut"), "Output", "QKVBiasOut", - "FusedAttentionOp"); + if (ctx->HasInput("QKVBias")) { + OP_INOUT_CHECK(ctx->HasOutput("QKVBiasOut"), "Output", "QKVBiasOut", + "FusedAttentionOp"); + } OP_INOUT_CHECK(ctx->HasOutput("TransposeOut2"), "Output", "TransposeOut2", "FusedAttentionOp"); OP_INOUT_CHECK(ctx->HasOutput("QKOut"), "Output", "QKOut", @@ -107,6 +105,13 @@ class FusedAttentionOp : public framework::OperatorWithKernel { "input qkv_weight = [%s]", x_dim, y_dim)); + PADDLE_ENFORCE_EQ(y_dim[1] * y_dim[2], y_dim[3], + platform::errors::InvalidArgument( + "The dimensions of qkv_weight must be 4" + "(3, num_head, dim_head, dim_embed)," + "and must satisfy the limitations: " + "(num_head * dim_head == dim_embed)")); + if (ctx->Attrs().Get("pre_layer_norm") == true) { ctx->SetOutputDim("LnMean", {x_dim[0] * x_dim[1]}); ctx->SetOutputDim("LnVariance", {x_dim[0] * x_dim[1]}); @@ -119,8 +124,11 @@ class FusedAttentionOp : public framework::OperatorWithKernel { // [batch_size, seq_len, 3, num_head, head_size] ctx->SetOutputDim("QKVOut", {x_dim[0], x_dim[1], y_dim[0], y_dim[1], y_dim[2]}); - ctx->SetOutputDim("QKVBiasOut", - {x_dim[0], x_dim[1], y_dim[0], y_dim[1], y_dim[2]}); + + if (ctx->HasInput("QKVBias")) { + ctx->SetOutputDim("QKVBiasOut", + {x_dim[0], x_dim[1], y_dim[0], y_dim[1], y_dim[2]}); + } // [3, batch_size, num_head, seq_len, head_size] ctx->SetOutputDim("TransposeOut2", {y_dim[0], x_dim[0], y_dim[1], x_dim[1], y_dim[2]}); @@ -173,11 +181,11 @@ class FusedAttentionOpMaker : public framework::OpProtoAndCheckerMaker { "H. Here, H represents the last dimension of its input tensor.") .AsDispensable(); AddInput("QKVW", "The qkv weight tensor."); - AddInput("QKVBias", "The qkv bias tensor."); + AddInput("QKVBias", "The qkv bias tensor.").AsDispensable(); AddInput("SrcMask", "(optional) The attention mask tensor in fmha.") .AsDispensable(); AddInput("OutLinearW", "The out_linear weight tensor."); - AddInput("OutLinearBias", "The out_linear bias tensor."); + AddInput("OutLinearBias", "The out_linear bias tensor.").AsDispensable(); AddInput("Ln2Scale", "(optional) Scale is a 1-dimensional tensor of size " "H. Here, H represents the last dimension of its input tensor.") @@ -379,12 +387,8 @@ class FusedAttentionGradOp : public framework::OperatorWithKernel { OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "FusedAttentionGrad"); OP_INOUT_CHECK(ctx->HasInput("QKVW"), "Input", "QKVW", "FusedAttentionGrad"); - OP_INOUT_CHECK(ctx->HasInput("QKVBias"), "Input", "QKVBias", - "FusedAttentionGrad"); OP_INOUT_CHECK(ctx->HasInput("OutLinearW"), "Input", "OutLinearW", "FusedAttentionGrad"); - OP_INOUT_CHECK(ctx->HasInput("OutLinearBias"), "Input", "OutLinearBias", - "FusedAttentionGrad"); if (ctx->Attrs().Get("pre_layer_norm") == true) { if (ctx->HasOutput(framework::GradVarName("LnScale"))) { @@ -399,14 +403,17 @@ class FusedAttentionGradOp : public framework::OperatorWithKernel { if (ctx->HasOutput(framework::GradVarName("X"))) { ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X")); } - - ctx->SetOutputDim(framework::GradVarName("OutLinearBias"), - ctx->GetInputDim("OutLinearBias")); + if (ctx->HasOutput(framework::GradVarName("OutLinearBias"))) { + ctx->SetOutputDim(framework::GradVarName("OutLinearBias"), + ctx->GetInputDim("OutLinearBias")); + } ctx->SetOutputDim(framework::GradVarName("OutLinearW"), ctx->GetInputDim("OutLinearW")); ctx->SetOutputDim(framework::GradVarName("QKVW"), ctx->GetInputDim("QKVW")); - ctx->SetOutputDim(framework::GradVarName("QKVBias"), - ctx->GetInputDim("QKVBias")); + if (ctx->HasOutput(framework::GradVarName("QKVBias"))) { + ctx->SetOutputDim(framework::GradVarName("QKVBias"), + ctx->GetInputDim("QKVBias")); + } if (ctx->Attrs().Get("pre_layer_norm") == true) { ctx->SetOutputDim(framework::GradVarName("LnOut"), @@ -434,8 +441,10 @@ class FusedAttentionGradOp : public framework::OperatorWithKernel { } ctx->SetOutputDim(framework::GradVarName("QKVOut"), ctx->GetInputDim("QKVOut")); - ctx->SetOutputDim(framework::GradVarName("QKVBiasOut"), - ctx->GetInputDim("QKVBiasOut")); + if (ctx->HasOutput(framework::GradVarName("QKVBiasOut"))) { + ctx->SetOutputDim(framework::GradVarName("QKVBiasOut"), + ctx->GetInputDim("QKVBiasOut")); + } ctx->SetOutputDim(framework::GradVarName("OutLinearOut"), ctx->GetInputDim("OutLinearOut")); } @@ -462,7 +471,15 @@ class FusedAttentionGradOpMaker : public framework::SingleGradOpMaker { // inputs x, parameters and their grad. op->SetInput("X", this->Input("X")); op->SetInput("QKVW", this->Input("QKVW")); - op->SetInput("QKVBias", this->Input("QKVBias")); + + if (this->HasInput("QKVBias")) { + op->SetInput("QKVBias", this->Input("QKVBias")); + op->SetOutput(framework::GradVarName("QKVBias"), + this->InputGrad("QKVBias")); + op->SetInput("QKVBiasOut", this->Output("QKVBiasOut")); + op->SetOutput(framework::GradVarName("QKVBiasOut"), + this->OutputGrad("QKVBiasOut")); + } if (this->HasInput("SrcMask")) { op->SetInput("SrcMask", this->Input("SrcMask")); @@ -472,7 +489,11 @@ class FusedAttentionGradOpMaker : public framework::SingleGradOpMaker { } op->SetInput("OutLinearW", this->Input("OutLinearW")); - op->SetInput("OutLinearBias", this->Input("OutLinearBias")); + if (this->HasInput("OutLinearBias")) { + op->SetInput("OutLinearBias", this->Input("OutLinearBias")); + op->SetOutput(framework::GradVarName("OutLinearBias"), + this->InputGrad("OutLinearBias")); + } op->SetAttrMap(this->Attrs()); bool is_pre_layer_norm = @@ -503,10 +524,7 @@ class FusedAttentionGradOpMaker : public framework::SingleGradOpMaker { op->SetOutput(framework::GradVarName("X"), this->InputGrad("X")); op->SetOutput(framework::GradVarName("QKVW"), this->InputGrad("QKVW")); - op->SetOutput(framework::GradVarName("QKVBias"), - this->InputGrad("QKVBias")); - op->SetOutput(framework::GradVarName("OutLinearBias"), - this->InputGrad("OutLinearBias")); + op->SetOutput(framework::GradVarName("OutLinearW"), this->InputGrad("OutLinearW")); @@ -528,7 +546,7 @@ class FusedAttentionGradOpMaker : public framework::SingleGradOpMaker { this->Output("BiasDropoutResidualOut")); } op->SetInput("QKVOut", this->Output("QKVOut")); - op->SetInput("QKVBiasOut", this->Output("QKVBiasOut")); + op->SetInput("TransposeOut2", this->Output("TransposeOut2")); op->SetInput("QKOut", this->Output("QKOut")); op->SetInput("QKTVOut", this->Output("QKTVOut")); @@ -553,8 +571,7 @@ class FusedAttentionGradOpMaker : public framework::SingleGradOpMaker { } op->SetOutput(framework::GradVarName("QKVOut"), this->OutputGrad("QKVOut")); - op->SetOutput(framework::GradVarName("QKVBiasOut"), - this->OutputGrad("QKVBiasOut")); + op->SetOutput(framework::GradVarName("QKTVOut"), this->OutputGrad("QKTVOut")); op->SetOutput(framework::GradVarName("TransposeOut2"), diff --git a/paddle/fluid/operators/fused/fused_attention_op.cu b/paddle/fluid/operators/fused/fused_attention_op.cu index 5bcf1285608369..9f6d6e2270673d 100644 --- a/paddle/fluid/operators/fused/fused_attention_op.cu +++ b/paddle/fluid/operators/fused/fused_attention_op.cu @@ -96,9 +96,11 @@ class FusedAttentionOpKernel : public framework::OpKernel { auto *x_data = input_x->data(); auto *qkv_weight_data = qkv_weight->data(); - auto *qkv_bias_data = qkv_bias->data(); + auto *qkv_bias_data = (qkv_bias == nullptr) ? nullptr : qkv_bias->data(); auto *qkv_out_data = qkv_out->mutable_data(ctx.GetPlace()); - auto *qkv_bias_out_data = qkv_bias_out->mutable_data(ctx.GetPlace()); + auto *qkv_bias_out_data = + (qkv_bias == nullptr) ? nullptr + : qkv_bias_out->mutable_data(ctx.GetPlace()); // get data ptr for FMHA. auto *transpose_out_2_data = @@ -117,7 +119,8 @@ class FusedAttentionOpKernel : public framework::OpKernel { // get data ptr for out_linear. auto *out_linear_weight_data = out_linear_weight->data(); - auto *out_linear_bias_data = out_linear_bias->data(); + auto *out_linear_bias_data = + (out_linear_bias == nullptr) ? nullptr : out_linear_bias->data(); auto *out_linear_out_data = out_linear_out->mutable_data(ctx.GetPlace()); // get data ptr for bias+dropout+residual+layernorm @@ -139,9 +142,15 @@ class FusedAttentionOpKernel : public framework::OpKernel { auto layer_norm_compute = AttnLayerNorm(ctx.cuda_device_context(), epsilon, bsz_seq, dim_embed); + + bool compute_bias = true; + if (qkv_bias == nullptr) { + compute_bias = false; + } // (transA, transB, compute_bias) = (false, true, true) - auto qkv_compute = AttnMatMul(ctx.cuda_device_context(), false, true, - bsz_seq, output_size, input_size, true); + auto qkv_compute = + AttnMatMul(ctx.cuda_device_context(), false, true, bsz_seq, + output_size, input_size, compute_bias); AttnDropoutParam attn_dropout_param( is_test_1, dropout_implementation_1, attn_dropout_rate, @@ -176,10 +185,17 @@ class FusedAttentionOpKernel : public framework::OpKernel { qkv_compute.ComputeForward(qkv_weight, input_x, qkv_bias, qkv_out, qkv_bias_out); } - fmha_ref_compute.ComputeForward(*qkv_bias_out, src_mask, transpose_out_2, - qk_out, src_mask_out, softmax_out, - attn_dropout_mask_out, attn_dropout_out, - qktv_out, fmha_out); + if (qkv_bias == nullptr) { + fmha_ref_compute.ComputeForward(*qkv_out, src_mask, transpose_out_2, + qk_out, src_mask_out, softmax_out, + attn_dropout_mask_out, attn_dropout_out, + qktv_out, fmha_out); + } else { + fmha_ref_compute.ComputeForward(*qkv_bias_out, src_mask, transpose_out_2, + qk_out, src_mask_out, softmax_out, + attn_dropout_mask_out, attn_dropout_out, + qktv_out, fmha_out); + } // fmha_out: [batch_size, seq_len, num_head, head_dim] // weight: [embed_dim, embed_dim] @@ -249,9 +265,10 @@ class FusedAttentionGradKernel : public framework::OpKernel { auto *out_linear_bias = ctx.Input("OutLinearBias"); auto *src_mask_data = (src_mask == nullptr ? nullptr : src_mask->data()); auto *qkv_weight_data = qkv_weight->data(); - auto *qkv_bias_data = qkv_bias->data(); + auto *qkv_bias_data = (qkv_bias == nullptr) ? nullptr : qkv_bias->data(); auto *out_linear_weight_data = out_linear_weight->data(); - auto *out_linear_bias_data = out_linear_bias->data(); + auto *out_linear_bias_data = + (out_linear_bias == nullptr) ? nullptr : out_linear_bias->data(); // fw output auto *fmha_out = ctx.Input("FMHAOut"); @@ -299,8 +316,15 @@ class FusedAttentionGradKernel : public framework::OpKernel { auto *d_bias_dropout_residual_out = ctx.Output(framework::GradVarName("BiasDropoutResidualOut")); auto *d_x_data = d_x->mutable_data(ctx.GetPlace()); - auto *d_qkv_out_data = d_qkv_out->mutable_data(ctx.GetPlace()); - auto *d_qkv_bias_out_data = d_qkv_bias_out->mutable_data(ctx.GetPlace()); + // when qkv_bias is not nullptr, d_qkv_out is equals to d_qkv_bias_out, the + // space can be reused. + auto *d_qkv_out_data = (d_qkv_bias_out != nullptr) + ? nullptr + : d_qkv_out->mutable_data(ctx.GetPlace()); + auto *d_qkv_bias_out_data = + (d_qkv_bias_out == nullptr) + ? nullptr + : d_qkv_bias_out->mutable_data(ctx.GetPlace()); auto *d_qktv_out_data = d_qktv_out->mutable_data(ctx.GetPlace()); auto *d_transpose_out_2_data = d_transpose_out_2->mutable_data(ctx.GetPlace()); @@ -326,11 +350,15 @@ class FusedAttentionGradKernel : public framework::OpKernel { auto *d_ln_2_bias = ctx.Output(framework::GradVarName("Ln2Bias")); auto *d_qkv_weight_data = d_qkv_weight->mutable_data(ctx.GetPlace()); - auto *d_qkv_bias_data = d_qkv_bias->mutable_data(ctx.GetPlace()); + auto *d_qkv_bias_data = (d_qkv_bias == nullptr) + ? nullptr + : d_qkv_bias->mutable_data(ctx.GetPlace()); auto *d_out_linear_weight_data = d_out_linear_weight->mutable_data(ctx.GetPlace()); auto *d_out_linear_bias_data = - d_out_linear_bias->mutable_data(ctx.GetPlace()); + (d_out_linear_bias == nullptr) + ? nullptr + : d_out_linear_bias->mutable_data(ctx.GetPlace()); const auto input_x_dims = input_x->dims(); const auto qkv_w_dims = qkv_weight->dims(); @@ -352,12 +380,15 @@ class FusedAttentionGradKernel : public framework::OpKernel { bool transA = false; bool transB = true; - bool compute_bias = true; + bool compute_qkv_bias = true; + if (qkv_bias == nullptr) { + compute_qkv_bias = false; + } auto layer_norm_compute = AttnLayerNorm(ctx.cuda_device_context(), epsilon, bsz_seq, dim_embed); auto qkv_compute = AttnMatMul(ctx.cuda_device_context(), transA, transB, bsz_seq, - output_size, input_size, compute_bias); + output_size, input_size, compute_qkv_bias); AttnDropoutParam attn_dropout_param( is_test_1, dropout_implementation_1, attn_dropout_prob, is_upscale_in_train_1, is_fix_seed_1, seed_val_1, seed_1); @@ -367,7 +398,7 @@ class FusedAttentionGradKernel : public framework::OpKernel { output_size = hidden_size; transA = false; transB = false; - compute_bias = false; + bool compute_bias = false; auto out_linear_compute = AttnMatMul(ctx.cuda_device_context(), transA, transB, bsz_seq, output_size, input_size, compute_bias); @@ -405,14 +436,19 @@ class FusedAttentionGradKernel : public framework::OpKernel { d_out_linear_out, d_fmha_out, d_out_linear_weight, nullptr); - fmha_ref_compute.ComputeBackward( - *transpose_out_2, src_mask, *softmax_out, *attn_dropout_mask_out, - *attn_dropout_out, *qk_out, *src_mask_out, *d_fmha_out, d_qktv_out, - d_attn_dropout_out, d_softmax_out, d_src_mask_out, d_qk_out, - d_transpose_out_2, nullptr, d_qkv_bias_out); - cudaMemcpyAsync(d_qkv_out_data, d_qkv_bias_out_data, - bsz_seq * 3 * num_head * dim_head * sizeof(T), - cudaMemcpyDeviceToDevice); + if (qkv_bias != nullptr) { + fmha_ref_compute.ComputeBackward( + *transpose_out_2, src_mask, *softmax_out, *attn_dropout_mask_out, + *attn_dropout_out, *qk_out, *src_mask_out, *d_fmha_out, d_qktv_out, + d_attn_dropout_out, d_softmax_out, d_src_mask_out, d_qk_out, + d_transpose_out_2, nullptr, d_qkv_bias_out); + } else { + fmha_ref_compute.ComputeBackward( + *transpose_out_2, src_mask, *softmax_out, *attn_dropout_mask_out, + *attn_dropout_out, *qk_out, *src_mask_out, *d_fmha_out, d_qktv_out, + d_attn_dropout_out, d_softmax_out, d_src_mask_out, d_qk_out, + d_transpose_out_2, nullptr, d_qkv_out); + } if (pre_layer_norm) { auto *ln_mean = ctx.Input("LnMean"); @@ -432,15 +468,24 @@ class FusedAttentionGradKernel : public framework::OpKernel { auto *d_ln_bias_data = (d_ln_bias == nullptr ? nullptr : d_ln_bias->mutable_data(ctx.GetPlace())); - - qkv_compute.ComputeBackward(ln_out, qkv_weight, d_qkv_bias_out, d_ln_out, - d_qkv_weight, d_qkv_bias); + if (qkv_bias != nullptr) { + qkv_compute.ComputeBackward(ln_out, qkv_weight, d_qkv_bias_out, + d_ln_out, d_qkv_weight, d_qkv_bias); + } else { + qkv_compute.ComputeBackward(ln_out, qkv_weight, d_qkv_out, d_ln_out, + d_qkv_weight, d_qkv_bias); + } layer_norm_compute.ComputeBackward(x_data, d_ln_out_data, ln_scale_data, ln_mean_data, ln_var_data, d_x_data, d_ln_scale_data, d_ln_bias_data); } else { - qkv_compute.ComputeBackward(input_x, qkv_weight, d_qkv_bias_out, d_x, - d_qkv_weight, d_qkv_bias); + if (qkv_bias != nullptr) { + qkv_compute.ComputeBackward(input_x, qkv_weight, d_qkv_bias_out, d_x, + d_qkv_weight, d_qkv_bias); + } else { + qkv_compute.ComputeBackward(input_x, qkv_weight, d_qkv_out, d_x, + d_qkv_weight, d_qkv_bias); + } } // gradient accumulation std::vector ins; diff --git a/python/paddle/fluid/tests/unittests/test_fused_attention_op.py b/python/paddle/fluid/tests/unittests/test_fused_attention_op.py index b2b5cac2bff965..443703aa937d8a 100644 --- a/python/paddle/fluid/tests/unittests/test_fused_attention_op.py +++ b/python/paddle/fluid/tests/unittests/test_fused_attention_op.py @@ -168,17 +168,29 @@ def GetFusedAttentionOut(self): paddle.disable_static(place=paddle.CUDAPlace(0)) q_proj_weight = paddle.to_tensor( self.q_proj.weight, stop_gradient=False) - q_proj_bias = paddle.to_tensor(self.q_proj.bias, stop_gradient=False) k_proj_weight = paddle.to_tensor( self.k_proj.weight, stop_gradient=False) - k_proj_bias = paddle.to_tensor(self.k_proj.bias, stop_gradient=False) v_proj_weight = paddle.to_tensor( self.v_proj.weight, stop_gradient=False) - v_proj_bias = paddle.to_tensor(self.v_proj.bias, stop_gradient=False) out_linear_weight = paddle.to_tensor( self.out_proj.weight, stop_gradient=False) - out_linear_bias = paddle.to_tensor( - self.out_proj.bias, stop_gradient=False) + + if self.bias_attr is False: + qkv_bias_tensor = None + out_linear_bias = None + else: + q_proj_bias = paddle.to_tensor( + self.q_proj.bias, stop_gradient=False) + k_proj_bias = paddle.to_tensor( + self.k_proj.bias, stop_gradient=False) + v_proj_bias = paddle.to_tensor( + self.v_proj.bias, stop_gradient=False) + qkv_bias = np.concatenate( + (q_proj_bias.numpy(), k_proj_bias.numpy(), v_proj_bias.numpy())) + qkv_bias = qkv_bias.reshape((3, self.num_heads, self.head_dim)) + qkv_bias_tensor = paddle.to_tensor(qkv_bias, stop_gradient=False) + out_linear_bias = paddle.to_tensor( + self.out_proj.bias, stop_gradient=False) ln1_scale = paddle.to_tensor(self.norm1.weight, stop_gradient=False) ln1_bias = paddle.to_tensor(self.norm1.bias, stop_gradient=False) @@ -193,17 +205,12 @@ def GetFusedAttentionOut(self): qkv_weight = qkv_weight.reshape( (3, self.num_heads, self.head_dim, self.embed_dim)) - qkv_bias = np.concatenate( - (q_proj_bias.numpy(), k_proj_bias.numpy(), v_proj_bias.numpy())) - qkv_bias = qkv_bias.reshape((3, self.num_heads, self.head_dim)) - x = paddle.to_tensor(self.query, stop_gradient=False) if self.has_attn_mask: attn_mask = paddle.to_tensor(self.attn_mask, stop_gradient=False) else: attn_mask = None qkv_weight_tensor = paddle.to_tensor(qkv_weight, stop_gradient=False) - qkv_bias_tensor = paddle.to_tensor(qkv_bias, stop_gradient=False) epsilon = 1e-05 ln2_epsilon = 1e-05 @@ -227,6 +234,36 @@ def test_fused_attention_op(self): x_grad_ref, x_grad.numpy(), rtol=1e-5, atol=1e-4) +class TestFusedAttentionOpBiasIsNone(TestFusedAttentionOp): + def config(self): + self.x_type = np.float32 + self.attn_mask_type = np.float64 + self.pre_layer_norm = False + self.has_attn_mask = True + self.training = True + + self.batch_size = 8 + self.query_length = 128 + self.head_dim = 64 + self.num_heads = 16 + self.embed_dim = self.head_dim * self.num_heads + + self.dropout_prob = 0.0 + self.attn_dropout_prob = 0.0 + self.weight_attr = None + self.bias_attr = False + self.kdim, self.vdim = self.embed_dim, self.embed_dim + self.key_length, self.value_length = self.query_length, self.query_length + + def test_fused_attention_op(self): + final_out_ref, x_grad_ref = self.GetBaselineOut() + final_out, x_grad = self.GetFusedAttentionOut() + np.testing.assert_allclose( + final_out_ref, final_out.numpy(), rtol=1e-5, atol=1e-4) + np.testing.assert_allclose( + x_grad_ref, x_grad.numpy(), rtol=1e-5, atol=1e-4) + + class TestFusedAttentionOpPreLn(TestFusedAttentionOp): def config(self): self.x_type = np.float32 diff --git a/python/paddle/incubate/nn/functional/fused_transformer.py b/python/paddle/incubate/nn/functional/fused_transformer.py index df9cc68a02d8dc..eafefd98298f54 100644 --- a/python/paddle/incubate/nn/functional/fused_transformer.py +++ b/python/paddle/incubate/nn/functional/fused_transformer.py @@ -356,6 +356,9 @@ def fused_multi_head_attention(x, 0] == 3, "The shape of qkv_weight should be [3, num_head, head_dim, embed_dim]." assert qkv_weight.shape[3] == x.shape[ 2], "The 3rd dim of qkv_weight and 2nd dim of x should be the same, i.e., embed_dim." + assert qkv_weight.shape[1] * qkv_weight.shape[2] == qkv_weight.shape[ + 3], "embed_dim must be divisible by num_heads." + _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, final_out = _C_ops.fused_attention( x, pre_ln_scale, pre_ln_bias, qkv_weight, qkv_bias, attn_mask, linear_weight, linear_bias, ln_scale, ln_bias, 'pre_layer_norm', From 79800978d923643249fde02840708edc13a0f2a6 Mon Sep 17 00:00:00 2001 From: Qi Li Date: Tue, 23 Nov 2021 20:39:32 +0800 Subject: [PATCH 2/4] [XPU] Reorganize xpu device codes in platform, test=develop (#37428) * [XPU] Reorganize xpu device codes in platform, test=develop * fix xpu_header.h, test=develop --- paddle/fluid/framework/details/all_reduce_op_handle.h | 2 +- paddle/fluid/framework/details/bkcl_op_handle.h | 2 +- paddle/fluid/framework/details/broadcast_op_handle.h | 2 +- paddle/fluid/framework/details/build_strategy.h | 2 +- .../framework/details/fused_all_reduce_op_handle.h | 2 +- paddle/fluid/framework/details/reduce_op_handle.h | 2 +- paddle/fluid/framework/operator.cc | 4 ++-- paddle/fluid/framework/var_type_traits.cc | 2 +- paddle/fluid/framework/var_type_traits_test.cc | 2 +- paddle/fluid/imperative/bkcl_context.cc | 2 +- paddle/fluid/imperative/prepared_operator.cc | 2 +- paddle/fluid/memory/allocation/allocator_facade.cc | 2 +- .../memory/allocation/naive_best_fit_allocator.cc | 2 +- paddle/fluid/memory/memcpy.cc | 2 +- paddle/fluid/operators/activation_op_xpu.cc | 2 +- paddle/fluid/operators/collective/broadcast_op_xpu.cc | 2 +- paddle/fluid/operators/collective/c_allreduce_op.h | 2 +- paddle/fluid/operators/collective/c_reduce_op.h | 2 +- paddle/fluid/operators/collective/gen_bkcl_id_op.cc | 2 +- paddle/fluid/operators/concat_op_xpu.cc | 2 +- paddle/fluid/operators/deformable_conv_op_xpu.cc | 2 +- paddle/fluid/operators/dropout_op_xpu.cc | 2 +- paddle/fluid/operators/metrics/accuracy_op_xpu.cc | 2 +- paddle/fluid/operators/reduce_ops/logsumexp_op_xpu.cc | 2 +- .../fluid/operators/reduce_ops/reduce_max_op_xpu.cc | 2 +- paddle/fluid/operators/reduce_ops/reduce_op_xpu.h | 2 +- .../fluid/operators/reduce_ops/reduce_sum_op_xpu.cc | 2 +- paddle/fluid/operators/rnn_op_xpu.cc | 2 +- paddle/fluid/operators/scale_op_xpu.cc | 2 +- paddle/fluid/operators/sign_op_xpu.cc | 2 +- paddle/fluid/operators/sum_op_xpu.cc | 2 +- paddle/fluid/operators/transpose_op_xpu.cc | 2 +- paddle/fluid/platform/CMakeLists.txt | 11 +---------- paddle/fluid/platform/device/CMakeLists.txt | 4 ++++ paddle/fluid/platform/device/xpu/CMakeLists.txt | 8 ++++++++ paddle/fluid/platform/{ => device/xpu}/bkcl_helper.h | 0 paddle/fluid/platform/{ => device}/xpu/xpu1_op_list.h | 0 paddle/fluid/platform/{ => device}/xpu/xpu2_op_list.h | 0 paddle/fluid/platform/{ => device}/xpu/xpu_header.h | 1 + paddle/fluid/platform/{ => device}/xpu/xpu_info.cc | 4 ++-- paddle/fluid/platform/{ => device}/xpu/xpu_info.h | 0 paddle/fluid/platform/{ => device}/xpu/xpu_op_list.cc | 8 ++++---- paddle/fluid/platform/{ => device}/xpu/xpu_op_list.h | 0 paddle/fluid/platform/device_context.h | 4 ++-- paddle/fluid/platform/init.cc | 4 ++-- paddle/fluid/pybind/pybind.cc | 2 +- 46 files changed, 58 insertions(+), 54 deletions(-) create mode 100644 paddle/fluid/platform/device/CMakeLists.txt create mode 100644 paddle/fluid/platform/device/xpu/CMakeLists.txt rename paddle/fluid/platform/{ => device/xpu}/bkcl_helper.h (100%) rename paddle/fluid/platform/{ => device}/xpu/xpu1_op_list.h (100%) rename paddle/fluid/platform/{ => device}/xpu/xpu2_op_list.h (100%) rename paddle/fluid/platform/{ => device}/xpu/xpu_header.h (98%) rename paddle/fluid/platform/{ => device}/xpu/xpu_info.cc (97%) rename paddle/fluid/platform/{ => device}/xpu/xpu_info.h (100%) rename paddle/fluid/platform/{ => device}/xpu/xpu_op_list.cc (91%) rename paddle/fluid/platform/{ => device}/xpu/xpu_op_list.h (100%) diff --git a/paddle/fluid/framework/details/all_reduce_op_handle.h b/paddle/fluid/framework/details/all_reduce_op_handle.h index 39b923be9df84f..033d9396e9bf23 100644 --- a/paddle/fluid/framework/details/all_reduce_op_handle.h +++ b/paddle/fluid/framework/details/all_reduce_op_handle.h @@ -36,7 +36,7 @@ class NCCLCommunicator; #include "paddle/fluid/platform/nccl_helper.h" #elif defined(PADDLE_WITH_XPU_BKCL) #include "paddle/fluid/framework/details/bkcl_op_handle.h" -#include "paddle/fluid/platform/bkcl_helper.h" +#include "paddle/fluid/platform/device/xpu/bkcl_helper.h" #endif namespace paddle { diff --git a/paddle/fluid/framework/details/bkcl_op_handle.h b/paddle/fluid/framework/details/bkcl_op_handle.h index fe63153a309205..f863cb123a8afb 100644 --- a/paddle/fluid/framework/details/bkcl_op_handle.h +++ b/paddle/fluid/framework/details/bkcl_op_handle.h @@ -23,7 +23,7 @@ #include "paddle/fluid/framework/details/op_handle_base.h" #include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/scope.h" -#include "paddle/fluid/platform/bkcl_helper.h" +#include "paddle/fluid/platform/device/xpu/bkcl_helper.h" DECLARE_bool(sync_bkcl_allreduce); diff --git a/paddle/fluid/framework/details/broadcast_op_handle.h b/paddle/fluid/framework/details/broadcast_op_handle.h index 8ca20da97416c6..0b062b1a3f49a4 100644 --- a/paddle/fluid/framework/details/broadcast_op_handle.h +++ b/paddle/fluid/framework/details/broadcast_op_handle.h @@ -46,7 +46,7 @@ struct BKCLContextMap; #if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) #include "paddle/fluid/platform/nccl_helper.h" #elif defined(PADDLE_WITH_XPU_BKCL) -#include "paddle/fluid/platform/bkcl_helper.h" +#include "paddle/fluid/platform/device/xpu/bkcl_helper.h" #endif namespace paddle { diff --git a/paddle/fluid/framework/details/build_strategy.h b/paddle/fluid/framework/details/build_strategy.h index 25110fe24f5871..68c5daaac5d780 100644 --- a/paddle/fluid/framework/details/build_strategy.h +++ b/paddle/fluid/framework/details/build_strategy.h @@ -42,7 +42,7 @@ class NCCLCommunicator; #if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) #include "paddle/fluid/platform/nccl_helper.h" #elif defined(PADDLE_WITH_XPU) && defined(PADDLE_WITH_XPU_BKCL) -#include "paddle/fluid/platform/bkcl_helper.h" +#include "paddle/fluid/platform/device/xpu/bkcl_helper.h" #endif namespace paddle { diff --git a/paddle/fluid/framework/details/fused_all_reduce_op_handle.h b/paddle/fluid/framework/details/fused_all_reduce_op_handle.h index 8473700867ce32..31336b92c4dfb6 100644 --- a/paddle/fluid/framework/details/fused_all_reduce_op_handle.h +++ b/paddle/fluid/framework/details/fused_all_reduce_op_handle.h @@ -37,7 +37,7 @@ class NCCLCommunicator; #include "paddle/fluid/framework/details/nccl_op_handle.h" #include "paddle/fluid/platform/nccl_helper.h" #elif defined(PADDLE_WITH_XPU_BKCL) -#include "paddle/fluid/platform/bkcl_helper.h" +#include "paddle/fluid/platform/device/xpu/bkcl_helper.h" #endif namespace paddle { diff --git a/paddle/fluid/framework/details/reduce_op_handle.h b/paddle/fluid/framework/details/reduce_op_handle.h index 569699c19ccf50..d56b6b3663003c 100644 --- a/paddle/fluid/framework/details/reduce_op_handle.h +++ b/paddle/fluid/framework/details/reduce_op_handle.h @@ -43,7 +43,7 @@ struct NCCLContextMap; #if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) #include "paddle/fluid/platform/nccl_helper.h" #elif defined(PADDLE_WITH_XPU_BKCL) -#include "paddle/fluid/platform/bkcl_helper.h" +#include "paddle/fluid/platform/device/xpu/bkcl_helper.h" #endif namespace paddle { diff --git a/paddle/fluid/framework/operator.cc b/paddle/fluid/framework/operator.cc index f2615694cfbc83..cde5d1353d018c 100644 --- a/paddle/fluid/framework/operator.cc +++ b/paddle/fluid/framework/operator.cc @@ -37,8 +37,8 @@ class LoDTensor; } // namespace framework } // namespace paddle #ifdef PADDLE_WITH_XPU -#include "paddle/fluid/platform/xpu/xpu_info.h" -#include "paddle/fluid/platform/xpu/xpu_op_list.h" +#include "paddle/fluid/platform/device/xpu/xpu_info.h" +#include "paddle/fluid/platform/device/xpu/xpu_op_list.h" #endif #ifdef PADDLE_WITH_MKLDNN diff --git a/paddle/fluid/framework/var_type_traits.cc b/paddle/fluid/framework/var_type_traits.cc index 886d00e562bff1..1d5e638729361d 100644 --- a/paddle/fluid/framework/var_type_traits.cc +++ b/paddle/fluid/framework/var_type_traits.cc @@ -38,7 +38,7 @@ #endif #if defined(PADDLE_WITH_XPU_BKCL) -#include "paddle/fluid/platform/bkcl_helper.h" +#include "paddle/fluid/platform/device/xpu/bkcl_helper.h" #endif namespace paddle { diff --git a/paddle/fluid/framework/var_type_traits_test.cc b/paddle/fluid/framework/var_type_traits_test.cc index 2a6635c4b6050b..ae7ae85207d849 100644 --- a/paddle/fluid/framework/var_type_traits_test.cc +++ b/paddle/fluid/framework/var_type_traits_test.cc @@ -37,7 +37,7 @@ #include "paddle/fluid/operators/miopen_rnn_cache.h" #endif #if defined(PADDLE_WITH_XPU_BKCL) -#include "paddle/fluid/platform/bkcl_helper.h" +#include "paddle/fluid/platform/device/xpu/bkcl_helper.h" #endif namespace paddle { diff --git a/paddle/fluid/imperative/bkcl_context.cc b/paddle/fluid/imperative/bkcl_context.cc index ba9b70aea7b96c..8c6b840f60a591 100644 --- a/paddle/fluid/imperative/bkcl_context.cc +++ b/paddle/fluid/imperative/bkcl_context.cc @@ -20,8 +20,8 @@ #include #include "paddle/fluid/framework/variable.h" -#include "paddle/fluid/platform/bkcl_helper.h" #include "paddle/fluid/platform/collective_helper.h" +#include "paddle/fluid/platform/device/xpu/bkcl_helper.h" #include "paddle/fluid/platform/device_context.h" #include "paddle/fluid/platform/gen_comm_id_helper.h" #include "paddle/fluid/platform/place.h" diff --git a/paddle/fluid/imperative/prepared_operator.cc b/paddle/fluid/imperative/prepared_operator.cc index 521f85d9429ba7..167afdfb1966d7 100644 --- a/paddle/fluid/imperative/prepared_operator.cc +++ b/paddle/fluid/imperative/prepared_operator.cc @@ -21,7 +21,7 @@ #include "paddle/pten/common/scalar.h" #include "paddle/utils/small_vector.h" #ifdef PADDLE_WITH_XPU -#include "paddle/fluid/platform/xpu/xpu_op_list.h" +#include "paddle/fluid/platform/device/xpu/xpu_op_list.h" #endif DECLARE_bool(check_nan_inf); DECLARE_bool(run_pten_kernel); diff --git a/paddle/fluid/memory/allocation/allocator_facade.cc b/paddle/fluid/memory/allocation/allocator_facade.cc index 9da735636fc00f..ca7f5655f08c37 100644 --- a/paddle/fluid/memory/allocation/allocator_facade.cc +++ b/paddle/fluid/memory/allocation/allocator_facade.cc @@ -42,7 +42,7 @@ #include "paddle/fluid/platform/cuda_graph.h" #endif #ifdef PADDLE_WITH_XPU -#include "paddle/fluid/platform/xpu/xpu_info.h" +#include "paddle/fluid/platform/device/xpu/xpu_info.h" #endif #include "paddle/fluid/platform/npu_info.h" diff --git a/paddle/fluid/memory/allocation/naive_best_fit_allocator.cc b/paddle/fluid/memory/allocation/naive_best_fit_allocator.cc index 2c00b34dd1353b..a9ef83db1704e7 100644 --- a/paddle/fluid/memory/allocation/naive_best_fit_allocator.cc +++ b/paddle/fluid/memory/allocation/naive_best_fit_allocator.cc @@ -31,7 +31,7 @@ #include "paddle/fluid/platform/cuda_device_guard.h" #endif #ifdef PADDLE_WITH_XPU -#include "paddle/fluid/platform/xpu/xpu_header.h" +#include "paddle/fluid/platform/device/xpu/xpu_header.h" #endif PADDLE_DEFINE_EXPORTED_bool( diff --git a/paddle/fluid/memory/memcpy.cc b/paddle/fluid/memory/memcpy.cc index 3b3be9776c4c54..574b1520543993 100644 --- a/paddle/fluid/memory/memcpy.cc +++ b/paddle/fluid/memory/memcpy.cc @@ -19,7 +19,7 @@ limitations under the License. */ #include "paddle/fluid/platform/profiler.h" #ifdef PADDLE_WITH_XPU -#include "paddle/fluid/platform/xpu/xpu_header.h" +#include "paddle/fluid/platform/device/xpu/xpu_header.h" #endif namespace paddle { diff --git a/paddle/fluid/operators/activation_op_xpu.cc b/paddle/fluid/operators/activation_op_xpu.cc index 2c3d9697366cad..fe85eb26705d1f 100644 --- a/paddle/fluid/operators/activation_op_xpu.cc +++ b/paddle/fluid/operators/activation_op_xpu.cc @@ -16,7 +16,7 @@ limitations under the License. */ #include "paddle/fluid/operators/activation_op.h" #include -#include "paddle/fluid/platform/xpu/xpu_header.h" +#include "paddle/fluid/platform/device/xpu/xpu_header.h" namespace paddle { namespace operators { diff --git a/paddle/fluid/operators/collective/broadcast_op_xpu.cc b/paddle/fluid/operators/collective/broadcast_op_xpu.cc index 2bfd77b8c2a090..9cd5c5fd22cc38 100644 --- a/paddle/fluid/operators/collective/broadcast_op_xpu.cc +++ b/paddle/fluid/operators/collective/broadcast_op_xpu.cc @@ -21,8 +21,8 @@ limitations under the License. */ #include "paddle/fluid/framework/op_registry.h" #if defined(PADDLE_WITH_XPU_BKCL) -#include "paddle/fluid/platform/bkcl_helper.h" #include "paddle/fluid/platform/collective_helper.h" +#include "paddle/fluid/platform/device/xpu/bkcl_helper.h" #endif namespace ops = paddle::operators; diff --git a/paddle/fluid/operators/collective/c_allreduce_op.h b/paddle/fluid/operators/collective/c_allreduce_op.h index a19d603ada8257..938ded9d017632 100644 --- a/paddle/fluid/operators/collective/c_allreduce_op.h +++ b/paddle/fluid/operators/collective/c_allreduce_op.h @@ -33,7 +33,7 @@ limitations under the License. */ #endif #if defined(PADDLE_WITH_XPU_BKCL) -#include "paddle/fluid/platform/bkcl_helper.h" +#include "paddle/fluid/platform/device/xpu/bkcl_helper.h" #endif #if defined(PADDLE_WITH_GLOO) diff --git a/paddle/fluid/operators/collective/c_reduce_op.h b/paddle/fluid/operators/collective/c_reduce_op.h index fa9fd079d8e48b..06023dc8ca4618 100644 --- a/paddle/fluid/operators/collective/c_reduce_op.h +++ b/paddle/fluid/operators/collective/c_reduce_op.h @@ -34,7 +34,7 @@ limitations under the License. */ #endif #if defined(PADDLE_WITH_XPU_BKCL) -#include "paddle/fluid/platform/bkcl_helper.h" +#include "paddle/fluid/platform/device/xpu/bkcl_helper.h" #endif #if defined(PADDLE_WITH_GLOO) diff --git a/paddle/fluid/operators/collective/gen_bkcl_id_op.cc b/paddle/fluid/operators/collective/gen_bkcl_id_op.cc index 7067bfb314485e..1ce89383568959 100644 --- a/paddle/fluid/operators/collective/gen_bkcl_id_op.cc +++ b/paddle/fluid/operators/collective/gen_bkcl_id_op.cc @@ -21,7 +21,7 @@ limitations under the License. */ #include "paddle/fluid/framework/operator.h" #include "paddle/fluid/framework/scope.h" #include "paddle/fluid/framework/var_type_traits.h" -#include "paddle/fluid/platform/bkcl_helper.h" +#include "paddle/fluid/platform/device/xpu/bkcl_helper.h" #include "paddle/fluid/platform/device_context.h" #include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/place.h" diff --git a/paddle/fluid/operators/concat_op_xpu.cc b/paddle/fluid/operators/concat_op_xpu.cc index dc9359ecf5c3d1..0ff11e11165f06 100644 --- a/paddle/fluid/operators/concat_op_xpu.cc +++ b/paddle/fluid/operators/concat_op_xpu.cc @@ -16,7 +16,7 @@ limitations under the License. */ #include #include #include -#include "paddle/fluid/platform/xpu/xpu_header.h" +#include "paddle/fluid/platform/device/xpu/xpu_header.h" namespace paddle { namespace operators { diff --git a/paddle/fluid/operators/deformable_conv_op_xpu.cc b/paddle/fluid/operators/deformable_conv_op_xpu.cc index 457616756215c2..ebdaf0acced13d 100644 --- a/paddle/fluid/operators/deformable_conv_op_xpu.cc +++ b/paddle/fluid/operators/deformable_conv_op_xpu.cc @@ -16,7 +16,7 @@ limitations under the License. */ #include #include #include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/platform/xpu/xpu_header.h" +#include "paddle/fluid/platform/device/xpu/xpu_header.h" namespace paddle { namespace operators { diff --git a/paddle/fluid/operators/dropout_op_xpu.cc b/paddle/fluid/operators/dropout_op_xpu.cc index 0b0b7095bd5d16..3335c0de429e4b 100644 --- a/paddle/fluid/operators/dropout_op_xpu.cc +++ b/paddle/fluid/operators/dropout_op_xpu.cc @@ -11,7 +11,7 @@ limitations under the License. */ #include "paddle/fluid/operators/dropout_op.h" #include #include -#include "paddle/fluid/platform/xpu/xpu_header.h" +#include "paddle/fluid/platform/device/xpu/xpu_header.h" namespace paddle { namespace operators { diff --git a/paddle/fluid/operators/metrics/accuracy_op_xpu.cc b/paddle/fluid/operators/metrics/accuracy_op_xpu.cc index cb75616221bc4d..7031b96a50b9e4 100644 --- a/paddle/fluid/operators/metrics/accuracy_op_xpu.cc +++ b/paddle/fluid/operators/metrics/accuracy_op_xpu.cc @@ -15,7 +15,7 @@ limitations under the License. */ #ifdef PADDLE_WITH_XPU #include "paddle/fluid/operators/metrics/accuracy_op.h" -#include "paddle/fluid/platform/xpu/xpu_header.h" +#include "paddle/fluid/platform/device/xpu/xpu_header.h" namespace paddle { namespace operators { diff --git a/paddle/fluid/operators/reduce_ops/logsumexp_op_xpu.cc b/paddle/fluid/operators/reduce_ops/logsumexp_op_xpu.cc index 4f98dde210f7a7..dcb849de0991bc 100644 --- a/paddle/fluid/operators/reduce_ops/logsumexp_op_xpu.cc +++ b/paddle/fluid/operators/reduce_ops/logsumexp_op_xpu.cc @@ -15,8 +15,8 @@ #ifdef PADDLE_WITH_XPU #include "paddle/fluid/operators/reduce_ops/logsumexp_op.h" +#include "paddle/fluid/platform/device/xpu/xpu_header.h" #include "paddle/fluid/platform/device_context.h" -#include "paddle/fluid/platform/xpu/xpu_header.h" namespace paddle { namespace operators { diff --git a/paddle/fluid/operators/reduce_ops/reduce_max_op_xpu.cc b/paddle/fluid/operators/reduce_ops/reduce_max_op_xpu.cc index ae27a5d7df4734..15d672da04bec5 100644 --- a/paddle/fluid/operators/reduce_ops/reduce_max_op_xpu.cc +++ b/paddle/fluid/operators/reduce_ops/reduce_max_op_xpu.cc @@ -16,7 +16,7 @@ #include #include #include "paddle/fluid/operators/reduce_ops/reduce_op_xpu.h" -#include "paddle/fluid/platform/xpu/xpu_header.h" +#include "paddle/fluid/platform/device/xpu/xpu_header.h" namespace paddle { namespace operators { diff --git a/paddle/fluid/operators/reduce_ops/reduce_op_xpu.h b/paddle/fluid/operators/reduce_ops/reduce_op_xpu.h index 5ae60713bc912b..324fd369e82b59 100644 --- a/paddle/fluid/operators/reduce_ops/reduce_op_xpu.h +++ b/paddle/fluid/operators/reduce_ops/reduce_op_xpu.h @@ -21,7 +21,7 @@ #include #include #include "paddle/fluid/operators/reduce_ops/reduce_op.h" -#include "paddle/fluid/platform/xpu/xpu_header.h" +#include "paddle/fluid/platform/device/xpu/xpu_header.h" namespace paddle { namespace operators { diff --git a/paddle/fluid/operators/reduce_ops/reduce_sum_op_xpu.cc b/paddle/fluid/operators/reduce_ops/reduce_sum_op_xpu.cc index f759b104d01d18..7a5c86c35c6a2a 100644 --- a/paddle/fluid/operators/reduce_ops/reduce_sum_op_xpu.cc +++ b/paddle/fluid/operators/reduce_ops/reduce_sum_op_xpu.cc @@ -16,7 +16,7 @@ #include #include #include "paddle/fluid/operators/reduce_ops/reduce_op_xpu.h" -#include "paddle/fluid/platform/xpu/xpu_header.h" +#include "paddle/fluid/platform/device/xpu/xpu_header.h" namespace paddle { namespace operators { diff --git a/paddle/fluid/operators/rnn_op_xpu.cc b/paddle/fluid/operators/rnn_op_xpu.cc index 9d637e1cee1176..183f83dbae7c3e 100644 --- a/paddle/fluid/operators/rnn_op_xpu.cc +++ b/paddle/fluid/operators/rnn_op_xpu.cc @@ -13,8 +13,8 @@ limitations under the License. */ #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/operators/utils.h" +#include "paddle/fluid/platform/device/xpu/xpu_header.h" #include "paddle/fluid/platform/device_context.h" -#include "paddle/fluid/platform/xpu/xpu_header.h" namespace paddle { namespace operators { diff --git a/paddle/fluid/operators/scale_op_xpu.cc b/paddle/fluid/operators/scale_op_xpu.cc index d3943e09b6d0b1..4960f720ee39aa 100644 --- a/paddle/fluid/operators/scale_op_xpu.cc +++ b/paddle/fluid/operators/scale_op_xpu.cc @@ -16,7 +16,7 @@ limitations under the License. */ #include "paddle/fluid/operators/scale_op.h" #include -#include "paddle/fluid/platform/xpu/xpu_header.h" +#include "paddle/fluid/platform/device/xpu/xpu_header.h" namespace paddle { namespace operators { diff --git a/paddle/fluid/operators/sign_op_xpu.cc b/paddle/fluid/operators/sign_op_xpu.cc index a164a9b056677a..8b3beb2fb397b0 100644 --- a/paddle/fluid/operators/sign_op_xpu.cc +++ b/paddle/fluid/operators/sign_op_xpu.cc @@ -15,7 +15,7 @@ limitations under the License. */ #ifdef PADDLE_WITH_XPU #include "paddle/fluid/operators/sign_op.h" -#include "paddle/fluid/platform/xpu/xpu_header.h" +#include "paddle/fluid/platform/device/xpu/xpu_header.h" namespace paddle { namespace operators { diff --git a/paddle/fluid/operators/sum_op_xpu.cc b/paddle/fluid/operators/sum_op_xpu.cc index 89a48a2144ef7b..5899591549eacc 100644 --- a/paddle/fluid/operators/sum_op_xpu.cc +++ b/paddle/fluid/operators/sum_op_xpu.cc @@ -13,7 +13,7 @@ limitations under the License. */ #include "paddle/fluid/operators/sum_op.h" #include -#include "paddle/fluid/platform/xpu/xpu_header.h" +#include "paddle/fluid/platform/device/xpu/xpu_header.h" namespace paddle { namespace operators { diff --git a/paddle/fluid/operators/transpose_op_xpu.cc b/paddle/fluid/operators/transpose_op_xpu.cc index 0e25a69f87c4a5..00a43c74d87366 100644 --- a/paddle/fluid/operators/transpose_op_xpu.cc +++ b/paddle/fluid/operators/transpose_op_xpu.cc @@ -17,7 +17,7 @@ limitations under the License. */ #include #include #include -#include "paddle/fluid/platform/xpu/xpu_header.h" +#include "paddle/fluid/platform/device/xpu/xpu_header.h" namespace paddle { namespace operators { diff --git a/paddle/fluid/platform/CMakeLists.txt b/paddle/fluid/platform/CMakeLists.txt index 54e73c5c1d9fa2..b6a2de06962f8d 100644 --- a/paddle/fluid/platform/CMakeLists.txt +++ b/paddle/fluid/platform/CMakeLists.txt @@ -4,12 +4,6 @@ if(WITH_GPU) proto_library(external_error_proto SRCS external_error.proto) endif(WITH_GPU) -if(WITH_XPU) - set(XPU_CTX_DEPS xpulib ssl crypto rt z resolv dl) -ELSE() - set(XPU_CTX_DEPS) -endif(WITH_XPU) - if(WITH_ASCEND) set(ASCEND_DEPS xpulib) ELSE() @@ -74,10 +68,7 @@ ENDIF() cc_library(place SRCS place.cc DEPS enforce boost) cc_test(place_test SRCS place_test.cc DEPS place glog gflags) -if(WITH_XPU) -cc_library(xpu_info SRCS xpu/xpu_info.cc DEPS gflags glog enforce xpulib) -cc_library(xpu_op_list SRCS xpu/xpu_op_list.cc DEPS gflags glog enforce xpulib device_context) -endif() +add_subdirectory(device) if(WITH_ASCEND) cc_library(ascend_npu_info SRCS ascend_npu_info.cc DEPS gflags glog enforce atlas_acl) diff --git a/paddle/fluid/platform/device/CMakeLists.txt b/paddle/fluid/platform/device/CMakeLists.txt new file mode 100644 index 00000000000000..0d8e5a7784ce4a --- /dev/null +++ b/paddle/fluid/platform/device/CMakeLists.txt @@ -0,0 +1,4 @@ +# XPU +IF(WITH_XPU) + add_subdirectory(xpu) +ENDIF() diff --git a/paddle/fluid/platform/device/xpu/CMakeLists.txt b/paddle/fluid/platform/device/xpu/CMakeLists.txt new file mode 100644 index 00000000000000..17f492f93e5346 --- /dev/null +++ b/paddle/fluid/platform/device/xpu/CMakeLists.txt @@ -0,0 +1,8 @@ +if(NOT WITH_XPU) + return() +endif() + +set(XPU_CTX_DEPS xpulib ssl crypto rt z resolv dl) + +cc_library(xpu_info SRCS xpu_info.cc DEPS gflags glog enforce xpulib) +cc_library(xpu_op_list SRCS xpu_op_list.cc DEPS gflags glog enforce xpulib device_context) diff --git a/paddle/fluid/platform/bkcl_helper.h b/paddle/fluid/platform/device/xpu/bkcl_helper.h similarity index 100% rename from paddle/fluid/platform/bkcl_helper.h rename to paddle/fluid/platform/device/xpu/bkcl_helper.h diff --git a/paddle/fluid/platform/xpu/xpu1_op_list.h b/paddle/fluid/platform/device/xpu/xpu1_op_list.h similarity index 100% rename from paddle/fluid/platform/xpu/xpu1_op_list.h rename to paddle/fluid/platform/device/xpu/xpu1_op_list.h diff --git a/paddle/fluid/platform/xpu/xpu2_op_list.h b/paddle/fluid/platform/device/xpu/xpu2_op_list.h similarity index 100% rename from paddle/fluid/platform/xpu/xpu2_op_list.h rename to paddle/fluid/platform/device/xpu/xpu2_op_list.h diff --git a/paddle/fluid/platform/xpu/xpu_header.h b/paddle/fluid/platform/device/xpu/xpu_header.h similarity index 98% rename from paddle/fluid/platform/xpu/xpu_header.h rename to paddle/fluid/platform/device/xpu/xpu_header.h index a72fbd65e24622..fe75290c252dfd 100644 --- a/paddle/fluid/platform/xpu/xpu_header.h +++ b/paddle/fluid/platform/device/xpu/xpu_header.h @@ -20,6 +20,7 @@ #include #include "paddle/fluid/platform/bfloat16.h" +#include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/errors.h" #include "paddle/fluid/platform/float16.h" #include "xpu/runtime.h" diff --git a/paddle/fluid/platform/xpu/xpu_info.cc b/paddle/fluid/platform/device/xpu/xpu_info.cc similarity index 97% rename from paddle/fluid/platform/xpu/xpu_info.cc rename to paddle/fluid/platform/device/xpu/xpu_info.cc index 3f45286d8f2020..adc8bcc22da98b 100644 --- a/paddle/fluid/platform/xpu/xpu_info.cc +++ b/paddle/fluid/platform/device/xpu/xpu_info.cc @@ -8,14 +8,14 @@ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/platform/xpu/xpu_info.h" +#include "paddle/fluid/platform/device/xpu/xpu_info.h" #include #include #include #include "gflags/gflags.h" +#include "paddle/fluid/platform/device/xpu/xpu_header.h" #include "paddle/fluid/platform/enforce.h" -#include "paddle/fluid/platform/xpu/xpu_header.h" #include "paddle/fluid/string/split.h" PADDLE_DEFINE_EXPORTED_string( diff --git a/paddle/fluid/platform/xpu/xpu_info.h b/paddle/fluid/platform/device/xpu/xpu_info.h similarity index 100% rename from paddle/fluid/platform/xpu/xpu_info.h rename to paddle/fluid/platform/device/xpu/xpu_info.h diff --git a/paddle/fluid/platform/xpu/xpu_op_list.cc b/paddle/fluid/platform/device/xpu/xpu_op_list.cc similarity index 91% rename from paddle/fluid/platform/xpu/xpu_op_list.cc rename to paddle/fluid/platform/device/xpu/xpu_op_list.cc index 0c10436f397898..7561830fc76c11 100644 --- a/paddle/fluid/platform/xpu/xpu_op_list.cc +++ b/paddle/fluid/platform/device/xpu/xpu_op_list.cc @@ -13,10 +13,10 @@ limitations under the License. */ #include #include -#include "paddle/fluid/platform/xpu/xpu1_op_list.h" -#include "paddle/fluid/platform/xpu/xpu2_op_list.h" -#include "paddle/fluid/platform/xpu/xpu_info.h" -#include "paddle/fluid/platform/xpu/xpu_op_list.h" +#include "paddle/fluid/platform/device/xpu/xpu1_op_list.h" +#include "paddle/fluid/platform/device/xpu/xpu2_op_list.h" +#include "paddle/fluid/platform/device/xpu/xpu_info.h" +#include "paddle/fluid/platform/device/xpu/xpu_op_list.h" namespace paddle { namespace platform { diff --git a/paddle/fluid/platform/xpu/xpu_op_list.h b/paddle/fluid/platform/device/xpu/xpu_op_list.h similarity index 100% rename from paddle/fluid/platform/xpu/xpu_op_list.h rename to paddle/fluid/platform/device/xpu/xpu_op_list.h diff --git a/paddle/fluid/platform/device_context.h b/paddle/fluid/platform/device_context.h index 13a1040dd19df2..6ffc3bef7431bf 100644 --- a/paddle/fluid/platform/device_context.h +++ b/paddle/fluid/platform/device_context.h @@ -69,8 +69,8 @@ struct GpuDevice; } // namespace Eigen #ifdef PADDLE_WITH_XPU -#include "paddle/fluid/platform/xpu/xpu_header.h" -#include "paddle/fluid/platform/xpu/xpu_info.h" +#include "paddle/fluid/platform/device/xpu/xpu_header.h" +#include "paddle/fluid/platform/device/xpu/xpu_info.h" #endif #ifdef PADDLE_WITH_ASCEND_CL diff --git a/paddle/fluid/platform/init.cc b/paddle/fluid/platform/init.cc index 1109ecd52824a9..2516b4a3b1d4a0 100644 --- a/paddle/fluid/platform/init.cc +++ b/paddle/fluid/platform/init.cc @@ -30,8 +30,8 @@ limitations under the License. */ #include "paddle/fluid/platform/place.h" #ifdef PADDLE_WITH_XPU -#include "paddle/fluid/platform/xpu/xpu_header.h" -#include "paddle/fluid/platform/xpu/xpu_info.h" +#include "paddle/fluid/platform/device/xpu/xpu_header.h" +#include "paddle/fluid/platform/device/xpu/xpu_info.h" #endif #ifdef WITH_WIN_DUMP_DBG diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc index ec3a4ba78c7cca..b85ebdeab55429 100644 --- a/paddle/fluid/pybind/pybind.cc +++ b/paddle/fluid/pybind/pybind.cc @@ -125,7 +125,7 @@ limitations under the License. */ #endif #ifdef PADDLE_WITH_XPU -#include "paddle/fluid/platform/xpu/xpu_info.h" +#include "paddle/fluid/platform/device/xpu/xpu_info.h" #endif #include "paddle/fluid/platform/cuda_graph_with_memory_pool.h" From ee1e16429b58792379f57116cd00550b502df018 Mon Sep 17 00:00:00 2001 From: pangyoki Date: Tue, 23 Nov 2021 21:34:20 +0800 Subject: [PATCH 3/4] fix inplace bug when the first grad_var(loss_grad) is inplace var (#37420) * fix inplace bug * fix custom grad input error * add unittest * fix inplace bug --- paddle/fluid/imperative/basic_engine.cc | 15 ++++++----- .../fluid/tests/unittests/test_inplace.py | 25 +++++++++++++++++++ 2 files changed, 34 insertions(+), 6 deletions(-) diff --git a/paddle/fluid/imperative/basic_engine.cc b/paddle/fluid/imperative/basic_engine.cc index a31fd436c71644..014fa8a2ee3dff 100644 --- a/paddle/fluid/imperative/basic_engine.cc +++ b/paddle/fluid/imperative/basic_engine.cc @@ -53,6 +53,10 @@ void BasicEngine::Init( platform::errors::AlreadyExists( "Accumulators are not empty before preparing it for " "backward network execution.")); + PADDLE_ENFORCE_EQ(accumulators_with_grad_node_.empty(), true, + platform::errors::AlreadyExists( + "Accumulators with grad_node as the key are not empty " + "before preparing it for backward network execution.")); for (size_t i = 0; i < tensors.size(); ++i) { auto var = tensors[i]; @@ -73,7 +77,6 @@ void BasicEngine::Init( VLOG(5) << "Clear the auto-grad graph from grad var " << var->Name() << " because of retain_graph=False when calling backward"; var->GradVarBase()->SetGraphIsFreed(true); - var->GradVarBase()->ClearGradNode(); } if (init_node == nullptr || var->OverridedStopGradient()) { @@ -108,7 +111,9 @@ void BasicEngine::Init( } VariableWrapper* init_grad_var = var->GradVarBase()->SharedVar().get(); - auto& accumulator = accumulators_[init_grad_var]; + auto& accumulator = + accumulators_with_grad_node_[init_grad_var->GetGradNode()] + [init_grad_var]; if (!accumulator) { if (FLAGS_sort_sum_gradient) { accumulator.reset(new SortedGradientAccumulator(init_grad_var)); @@ -116,6 +121,8 @@ void BasicEngine::Init( accumulator.reset(new EagerGradientAccumulator(init_grad_var)); } } + accumulator->IncreaseRefCnt(); + accumulator->IncreaseCurCnt(); init_nodes_.push_back(init_node); } @@ -253,10 +260,6 @@ void BasicEngine::PrepareDeps() { node_deps_.empty(), true, platform::errors::AlreadyExists("Op deps are not empty before preparing " "it for backward network execution.")); - PADDLE_ENFORCE_EQ(accumulators_with_grad_node_.empty(), true, - platform::errors::AlreadyExists( - "Accumulators with grad_node as the key are not empty " - "before preparing it for backward network execution.")); std::queue q; std::unordered_set visited; diff --git a/python/paddle/fluid/tests/unittests/test_inplace.py b/python/paddle/fluid/tests/unittests/test_inplace.py index 3d158763527e71..98e2d2367fd5ed 100644 --- a/python/paddle/fluid/tests/unittests/test_inplace.py +++ b/python/paddle/fluid/tests/unittests/test_inplace.py @@ -409,5 +409,30 @@ def inplace_api_processing(self, var): return var.subtract_(self.input_var_2) +class TestLossIsInplaceVar(unittest.TestCase): + def test_loss_is_inplace_var(self): + with paddle.fluid.dygraph.guard(): + var_a = paddle.ones((2, 2)) + var_a.stop_gradient = False + + var_b = var_a * 2 + loss = var_b.tanh_() + + loss.backward() + inplace_grad_var_a = var_a.grad.numpy() + + with paddle.fluid.dygraph.guard(): + var_a = paddle.ones((2, 2)) + var_a.stop_gradient = False + + var_b = var_a * 2 + loss = var_b.tanh() + + loss.backward() + grad_var_a = var_a.grad.numpy() + + self.assertTrue(np.array_equal(inplace_grad_var_a, grad_var_a)) + + if __name__ == '__main__': unittest.main() From 1799c0322795d76dfae6d6cf6d62996d92974dbd Mon Sep 17 00:00:00 2001 From: Jiabin Yang Date: Wed, 24 Nov 2021 10:38:23 +0800 Subject: [PATCH 4/4] Refactor dygraph to eager -- TensorWrapper, EagerUtils, GlobalUtils (#37466) * Add EagerTensor and tests * remove useless enforce * remove comment in cmake * support autograd meta * support grad node info test * support grad_node_info * add more edge test * remove Python.h * add tensor wrapper with tests * support compute require grad and stop gradient * support sync methods and global utils * support pure cpu test * refine error msg * refine error msg * refine error info * fix npu error --- paddle/fluid/eager/CMakeLists.txt | 2 + paddle/fluid/eager/api/CMakeLists.txt | 3 + paddle/fluid/eager/api/all.cc | 18 ++ paddle/fluid/eager/api/all.h | 17 ++ paddle/fluid/eager/api/utils/CMakeLists.txt | 1 + paddle/fluid/eager/api/utils/global_utils.cc | 22 ++ paddle/fluid/eager/api/utils/global_utils.h | 62 ++++++ paddle/fluid/eager/eager_tensor.h | 1 - paddle/fluid/eager/tensor_wrapper.h | 91 ++++++++ paddle/fluid/eager/tests/CMakeLists.txt | 1 + .../tests/data_structure_tests/CMakeLists.txt | 1 + .../data_structure_tests/grad_node_test.h | 3 + .../tensor_wrapper_test.cc | 80 +++++++ .../eager/tests/task_tests/CMakeLists.txt | 1 + .../tests/task_tests/eager_utils_test.cc | 202 ++++++++++++++++++ paddle/fluid/eager/tests/test_utils.h | 175 +++++++++++++++ paddle/fluid/eager/utils.cc | 120 +++++++++++ paddle/fluid/eager/utils.h | 126 +++++++++++ 18 files changed, 925 insertions(+), 1 deletion(-) create mode 100644 paddle/fluid/eager/api/CMakeLists.txt create mode 100644 paddle/fluid/eager/api/all.cc create mode 100644 paddle/fluid/eager/api/all.h create mode 100644 paddle/fluid/eager/api/utils/CMakeLists.txt create mode 100644 paddle/fluid/eager/api/utils/global_utils.cc create mode 100644 paddle/fluid/eager/api/utils/global_utils.h create mode 100644 paddle/fluid/eager/tensor_wrapper.h create mode 100644 paddle/fluid/eager/tests/data_structure_tests/tensor_wrapper_test.cc create mode 100644 paddle/fluid/eager/tests/task_tests/CMakeLists.txt create mode 100644 paddle/fluid/eager/tests/task_tests/eager_utils_test.cc create mode 100644 paddle/fluid/eager/tests/test_utils.h create mode 100644 paddle/fluid/eager/utils.cc create mode 100644 paddle/fluid/eager/utils.h diff --git a/paddle/fluid/eager/CMakeLists.txt b/paddle/fluid/eager/CMakeLists.txt index a79b451b54431b..fe3fe9760858ed 100644 --- a/paddle/fluid/eager/CMakeLists.txt +++ b/paddle/fluid/eager/CMakeLists.txt @@ -1,3 +1,5 @@ add_subdirectory(tests) +add_subdirectory(api) cc_library(grad_node_info SRCS grad_node_info.cc DEPS pten pten_api) cc_library(autograd_meta SRCS autograd_meta.cc DEPS pten pten_api) +cc_library(utils SRCS utils.cc DEPS pten pten_api autograd_meta eager_api) diff --git a/paddle/fluid/eager/api/CMakeLists.txt b/paddle/fluid/eager/api/CMakeLists.txt new file mode 100644 index 00000000000000..92c1c81bb8cd56 --- /dev/null +++ b/paddle/fluid/eager/api/CMakeLists.txt @@ -0,0 +1,3 @@ +add_subdirectory(utils) + +cc_library(eager_api SRCS all.cc DEPS global_utils) diff --git a/paddle/fluid/eager/api/all.cc b/paddle/fluid/eager/api/all.cc new file mode 100644 index 00000000000000..2308e9341772c0 --- /dev/null +++ b/paddle/fluid/eager/api/all.cc @@ -0,0 +1,18 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +#include "paddle/fluid/eager/api/all.h" + +namespace egr {} // namespace egr diff --git a/paddle/fluid/eager/api/all.h b/paddle/fluid/eager/api/all.h new file mode 100644 index 00000000000000..4d873ad95a4225 --- /dev/null +++ b/paddle/fluid/eager/api/all.h @@ -0,0 +1,17 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +#pragma once + +#include "paddle/fluid/eager/api/utils/global_utils.h" diff --git a/paddle/fluid/eager/api/utils/CMakeLists.txt b/paddle/fluid/eager/api/utils/CMakeLists.txt new file mode 100644 index 00000000000000..5168f1fc02489c --- /dev/null +++ b/paddle/fluid/eager/api/utils/CMakeLists.txt @@ -0,0 +1 @@ +cc_library(global_utils SRCS global_utils.cc DEPS enforce) diff --git a/paddle/fluid/eager/api/utils/global_utils.cc b/paddle/fluid/eager/api/utils/global_utils.cc new file mode 100644 index 00000000000000..3a6a05eb1bfc8b --- /dev/null +++ b/paddle/fluid/eager/api/utils/global_utils.cc @@ -0,0 +1,22 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +#include "paddle/fluid/eager/api/utils/global_utils.h" + +namespace egr { + +Controller* Controller::controller_ = new Controller(); + +} // namespace egr diff --git a/paddle/fluid/eager/api/utils/global_utils.h b/paddle/fluid/eager/api/utils/global_utils.h new file mode 100644 index 00000000000000..16e7ef8a58e666 --- /dev/null +++ b/paddle/fluid/eager/api/utils/global_utils.h @@ -0,0 +1,62 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +#pragma once + +#include "paddle/fluid/eager/eager_tensor.h" +#include "paddle/fluid/platform/enforce.h" + +namespace egr { + +class UniqueNameGenerator { + public: + explicit UniqueNameGenerator(std::string prefix = "") : prefix_(prefix) {} + std::string Generate(std::string key = "eager_tmp") { + return prefix_ + key + "_" + std::to_string(id_++); + } + + private: + std::atomic id_{0}; + std::string prefix_; +}; + +// Global +class Controller { + public: + static Controller& Instance() { return *controller_; } + const paddle::platform::Place& GetExpectedPlace() const { + return *expected_place_.get(); + } + void SetExpectedPlace(const paddle::platform::Place& place) { + expected_place_ = std::make_shared(place); + } + void SetAMPLevel(int level) { amp_level_ = level; } + int GetAMPLevel() const { return amp_level_; } + bool HasGrad() const { return has_grad_; } + std::string GenerateUniqueName(std::string key = "eager_tmp") { + return generator_->Generate(key); + } + + private: + Controller() = default; + static Controller* controller_; + std::shared_ptr expected_place_ = nullptr; + int amp_level_ = 0; + bool has_grad_ = true; + std::unique_ptr generator_{new UniqueNameGenerator()}; + DISABLE_COPY_AND_ASSIGN(Controller); +}; + +} // namespace egr diff --git a/paddle/fluid/eager/eager_tensor.h b/paddle/fluid/eager/eager_tensor.h index d871a84dc224f6..753040a2623f9b 100644 --- a/paddle/fluid/eager/eager_tensor.h +++ b/paddle/fluid/eager/eager_tensor.h @@ -14,7 +14,6 @@ #pragma once // framework deps -#include "paddle/fluid/framework/data_layout_transform.h" #include "paddle/fluid/framework/pten_utils.h" #include "paddle/fluid/framework/tensor.h" #include "paddle/fluid/framework/variable.h" diff --git a/paddle/fluid/eager/tensor_wrapper.h b/paddle/fluid/eager/tensor_wrapper.h new file mode 100644 index 00000000000000..d760a76ec68229 --- /dev/null +++ b/paddle/fluid/eager/tensor_wrapper.h @@ -0,0 +1,91 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +/** + * We now still need TensorWrapper and it is designed to Copy + * tensor in autograd mode. + * + * Since in autograd usage, we need to pass autograd_meta to + * backward computation however in tensor interface add to much + * autograd_related method is not a good choice. + * + * In TensorWrapper we will keep autograd info to backward, only + * for input var, but for output var it will only copy autograd + * with no grad **/ + +#pragma once +#include "paddle/fluid/eager/autograd_meta.h" +#include "paddle/fluid/eager/grad_node_info.h" +#include "paddle/fluid/eager/utils.h" + +namespace egr { +class TensorWrapper { + public: + TensorWrapper() = default; + explicit TensorWrapper(const egr::EagerTensor& tensor, + bool full_reserved = false) { + /** + * Normally, we should fully reserved all non-output or non-leaf fwd tensor + * here. And for fwd output tensor, we should not reserve its autogradmeta, + * to avoid recursive depends on GradNodeBase + * **/ + full_reserved_ = full_reserved; + if (full_reserved_) { + VLOG(6) << "Fully reserved tensor: " << tensor.name(); + intermidiate_tensor_ = tensor; + return; + } + + // shallow copy tensor_impl here + intermidiate_tensor_.set_impl(tensor.impl()); + intermidiate_tensor_.ResetVar(tensor.Var()); + intermidiate_tensor_.set_name(tensor.name() + "@Saved"); + PADDLE_ENFORCE_NOT_NULL( + EagerUtils::unsafe_autograd_meta(tensor), + paddle::platform::errors::Fatal( + "Full reserved Tensor should not have null autograd meta, since " + "tensor_wrapper is used to build backward info. There is no way " + "for us to build it with null autograd_meta.")); + // copy output_rank + out_rank_info_ = EagerUtils::OutRankInfo(tensor); + } + + egr::EagerTensor recover(const std::shared_ptr& grad_node) { + VLOG(6) << "Recover tensor for wrapper"; + if ((!intermidiate_tensor_.defined()) && + (!intermidiate_tensor_.Var().IsInitialized())) { + VLOG(6) << "Return NULL tensor Here. "; + return egr::EagerTensor(); + } + + // if it's full_reserved just return the full copy of tensor + if (full_reserved_) { + return intermidiate_tensor_; + } else { + std::shared_ptr new_grad_node = grad_node; + auto p_ab_autograd_meta = + std::make_shared(Edge(new_grad_node, out_rank_info_)); + intermidiate_tensor_.set_autograd_meta( + std::static_pointer_cast( + p_ab_autograd_meta)); + return intermidiate_tensor_; + } + } + + private: + bool full_reserved_ = false; + std::pair out_rank_info_; + egr::EagerTensor intermidiate_tensor_; +}; +} // namespace egr diff --git a/paddle/fluid/eager/tests/CMakeLists.txt b/paddle/fluid/eager/tests/CMakeLists.txt index 572740e03c66c1..bdf542f20e07d1 100644 --- a/paddle/fluid/eager/tests/CMakeLists.txt +++ b/paddle/fluid/eager/tests/CMakeLists.txt @@ -1,2 +1,3 @@ set(eager_deps pten pten_api) add_subdirectory(data_structure_tests) +add_subdirectory(task_tests) diff --git a/paddle/fluid/eager/tests/data_structure_tests/CMakeLists.txt b/paddle/fluid/eager/tests/data_structure_tests/CMakeLists.txt index 21e63b6480c733..2989330efa8aac 100644 --- a/paddle/fluid/eager/tests/data_structure_tests/CMakeLists.txt +++ b/paddle/fluid/eager/tests/data_structure_tests/CMakeLists.txt @@ -1,3 +1,4 @@ cc_test(test_egr_ds_eager_tensor SRCS eager_tensor_test.cc DEPS ${eager_deps} ) cc_test(test_egr_ds_auotgrad_meta SRCS autograd_meta_test.cc DEPS ${eager_deps} grad_node_info) cc_test(test_egr_ds_grad_node_info SRCS grad_node_info_test.cc DEPS ${eager_deps} grad_node_info) +cc_test(test_egr_ds_tensor_wrapper SRCS tensor_wrapper_test.cc DEPS ${eager_deps} grad_node_info utils) diff --git a/paddle/fluid/eager/tests/data_structure_tests/grad_node_test.h b/paddle/fluid/eager/tests/data_structure_tests/grad_node_test.h index ddea70da791a05..2870bfa8b0c943 100644 --- a/paddle/fluid/eager/tests/data_structure_tests/grad_node_test.h +++ b/paddle/fluid/eager/tests/data_structure_tests/grad_node_test.h @@ -19,6 +19,9 @@ #include "paddle/fluid/eager/eager_tensor.h" #include "paddle/fluid/eager/grad_node_info.h" #include "paddle/pten/api/lib/utils/allocator.h" +namespace egr { +class TensorWrapper; +} namespace eager_test { class GradTestNode : public egr::GradNodeBase { diff --git a/paddle/fluid/eager/tests/data_structure_tests/tensor_wrapper_test.cc b/paddle/fluid/eager/tests/data_structure_tests/tensor_wrapper_test.cc new file mode 100644 index 00000000000000..6d78cf42d0c48a --- /dev/null +++ b/paddle/fluid/eager/tests/data_structure_tests/tensor_wrapper_test.cc @@ -0,0 +1,80 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "glog/logging.h" +#include "gtest/gtest.h" + +#include "paddle/fluid/eager/tensor_wrapper.h" +#include "paddle/fluid/eager/tests/data_structure_tests/grad_node_test.h" +#include "paddle/fluid/eager/utils.h" + +TEST(TensorWrapper, Basic) { + VLOG(6) << "Test Full reserved"; + egr::EagerTensor et1; + pten::DenseTensorMeta meta = pten::DenseTensorMeta( + pten::DataType::FLOAT32, paddle::framework::make_ddim({1, 2})); + std::shared_ptr dt = std::make_shared( + std::make_shared( + paddle::platform::CPUPlace()), + meta); + auto* dt_ptr = dt->mutable_data(); + dt_ptr[0] = 5.0f; + dt_ptr[1] = 10.0f; + et1.set_impl(dt); + // Create grad node; + auto grad_test_node0 = std::make_shared( + /* val */ 5.0, /* in_num */ 2, /* out_num */ 2); + egr::Edge edge0(grad_test_node0, 1, 2); + auto auto_grad0 = std::make_shared(edge0); + et1.set_autograd_meta(auto_grad0); + et1.set_name("et1"); + auto tw0 = egr::TensorWrapper(et1, true); + auto recover_et1 = tw0.recover(std::make_shared()); + CHECK_EQ(recover_et1.name(), std::string("et1")); + CHECK_EQ(egr::EagerUtils::OutRankInfo(recover_et1).first, + egr::EagerUtils::OutRankInfo(et1).first); + CHECK_EQ(egr::EagerUtils::OutRankInfo(recover_et1).second, + egr::EagerUtils::OutRankInfo(et1).second); + VLOG(6) << "Test reconstruct"; + egr::EagerTensor et2; + pten::DenseTensorMeta meta2 = pten::DenseTensorMeta( + pten::DataType::FLOAT32, paddle::framework::make_ddim({1, 2})); + std::shared_ptr dt2 = std::make_shared( + std::make_shared( + paddle::platform::CPUPlace()), + meta2); + auto* dt_ptr2 = dt->mutable_data(); + dt_ptr2[0] = 6.0f; + dt_ptr2[1] = 11.0f; + et2.set_impl(dt2); + et2.set_name("et2"); + auto grad_test_node1 = + std::make_shared(/* val */ 5.0, 2, 2); + egr::Edge edge1(grad_test_node1, 1, 2); + auto auto_grad1 = std::make_shared(edge1); + et2.set_autograd_meta(auto_grad1); + auto tw1 = egr::TensorWrapper(et2, false); + auto recover_et2 = tw1.recover(grad_test_node1); + CHECK_EQ(recover_et2.name(), std::string("et2@Saved")); + CHECK_EQ(egr::EagerUtils::OutRankInfo(recover_et2).first, + egr::EagerUtils::OutRankInfo(et2).first); + CHECK_EQ(egr::EagerUtils::OutRankInfo(recover_et2).second, + egr::EagerUtils::OutRankInfo(et2).second); + // Test Raw recover + egr::EagerTensor et3; + auto tw2 = egr::TensorWrapper(et3, true); + CHECK( + tw2.recover(std::make_shared()).initialized() == + false); +} diff --git a/paddle/fluid/eager/tests/task_tests/CMakeLists.txt b/paddle/fluid/eager/tests/task_tests/CMakeLists.txt new file mode 100644 index 00000000000000..61f89d945251a5 --- /dev/null +++ b/paddle/fluid/eager/tests/task_tests/CMakeLists.txt @@ -0,0 +1 @@ +cc_test(test_egr_task_eager_utils SRCS eager_utils_test.cc DEPS ${eager_deps} grad_node_info autograd_meta utils) diff --git a/paddle/fluid/eager/tests/task_tests/eager_utils_test.cc b/paddle/fluid/eager/tests/task_tests/eager_utils_test.cc new file mode 100644 index 00000000000000..3df0a77aed0752 --- /dev/null +++ b/paddle/fluid/eager/tests/task_tests/eager_utils_test.cc @@ -0,0 +1,202 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Eager Dygraph + +#include "gtest/gtest.h" + +#include "paddle/fluid/eager/autograd_meta.h" +#include "paddle/fluid/eager/grad_node_info.h" +#include "paddle/fluid/eager/tests/data_structure_tests/grad_node_test.h" +#include "paddle/fluid/eager/tests/test_utils.h" +#include "paddle/fluid/eager/utils.h" + +namespace eager_test { +template +egr::EagerTensor CreateTestCPUTensor(T val, + const paddle::framework::DDim& ddim) { + pten::DenseTensorMeta meta = + pten::DenseTensorMeta(pten::DataType::FLOAT32, ddim); + egr::EagerTensor tensor; + std::shared_ptr dt = std::make_shared( + std::make_shared( + paddle::platform::CPUPlace()), + meta); + auto* dt_ptr = dt->mutable_data(); + for (int64_t i = 0; i < dt->numel(); i++) { + dt_ptr[i] = val; + } + tensor.set_impl(dt); + return tensor; +} +} // namespace eager_test +TEST(EagerUtils, ComputeRequireGrad) { + auto auto_grad0 = std::make_shared(); + auto auto_grad1 = std::make_shared(); + auto auto_grad2 = std::make_shared(); + auto auto_grad3 = std::make_shared(); + CHECK_EQ(auto_grad0->NumericStopGradient(), -1); + VLOG(6) << "Single Test ComputeRequireGrad"; + auto_grad0->SetStopGradient(true); + CHECK(egr::EagerUtils::ComputeRequireGrad(true, auto_grad0.get()) == false); + CHECK(egr::EagerUtils::ComputeRequireGrad(false, auto_grad0.get()) == false); + auto_grad0->SetStopGradient(false); + CHECK(egr::EagerUtils::ComputeRequireGrad(false, auto_grad0.get()) == false); + CHECK(egr::EagerUtils::ComputeRequireGrad(true, auto_grad0.get()) == true); + + VLOG(6) << "Multi Test ComputeRequireGrad"; + auto_grad0->SetStopGradient(false); + auto_grad1->SetStopGradient(true); + CHECK(egr::EagerUtils::ComputeRequireGrad(true, auto_grad0.get(), + auto_grad1.get()) == true); + CHECK(egr::EagerUtils::ComputeRequireGrad(false, auto_grad0.get(), + auto_grad1.get()) == false); + auto_grad0->SetStopGradient(true); + CHECK(egr::EagerUtils::ComputeRequireGrad(true, auto_grad0.get(), + auto_grad1.get()) == false); + CHECK(egr::EagerUtils::ComputeRequireGrad(false, auto_grad0.get(), + auto_grad1.get()) == false); +} + +TEST(EagerUtils, PassStopGradient) { + auto auto_grad0 = std::make_shared(); + auto auto_grad1 = std::make_shared(); + auto auto_grad2 = std::make_shared(); + auto auto_grad3 = std::make_shared(); + CHECK_EQ(auto_grad0->NumericStopGradient(), -1); + VLOG(6) << "Test PassStopGradient"; + egr::EagerUtils::PassStopGradient(false, auto_grad0.get()); + CHECK(auto_grad0->StopGradient() == false); + egr::EagerUtils::PassStopGradient(true, auto_grad0.get(), auto_grad1.get(), + auto_grad2.get(), auto_grad3.get()); + CHECK(auto_grad0->StopGradient() == true); + CHECK(auto_grad1->StopGradient() == true); + CHECK(auto_grad2->StopGradient() == true); + CHECK(auto_grad3->StopGradient() == true); +} + +TEST(EagerUtils, SyncToVarsSingle) { + paddle::framework::DDim ddim = paddle::framework::make_ddim({2, 4, 4, 4}); + auto tensor = eager_test::CreateTestCPUTensor(5.0f, ddim); + std::vector> var_bases = + egr::EagerUtils::SyncToVars(tensor); + + paddle::framework::Variable* var = var_bases[0]->MutableVar(); + const auto& framework_tensor = var->Get(); + + const float* ptr = framework_tensor.data(); + VLOG(6) << "Check Value for SyncToVarsSingle"; + CHECK_EQ(framework_tensor.numel(), tensor.numel()); + + for (int i = 0; i < framework_tensor.numel(); i++) { + CHECK_EQ(ptr[i], 5.0f); + } +} + +TEST(EagerUtils, SyncToVarsMultiple) { + paddle::framework::DDim ddim = paddle::framework::make_ddim({2, 4, 4, 4}); + std::vector tensors = { + eager_test::CreateTestCPUTensor(1.0f, ddim), + eager_test::CreateTestCPUTensor(2.0f, ddim)}; + + std::vector> var_bases = + egr::EagerUtils::SyncToVars(tensors); + + { + paddle::framework::Variable* var = var_bases[0]->MutableVar(); + const auto& framework_tensor = var->Get(); + + const float* ptr = framework_tensor.data(); + CHECK_EQ(framework_tensor.numel(), tensors[0].numel()); + + for (int i = 0; i < framework_tensor.numel(); i++) { + CHECK_EQ(ptr[i], 1.0); + } + } + + { + paddle::framework::Variable* var = var_bases[1]->MutableVar(); + const auto& framework_tensor = var->Get(); + + const float* ptr = framework_tensor.data(); + VLOG(6) << "Check Value for SyncToVarsMultiple"; + CHECK_EQ(framework_tensor.numel(), tensors[0].numel()); + + for (int i = 0; i < framework_tensor.numel(); i++) { + CHECK_EQ(ptr[i], 2.0); + } + } +} + +TEST(EagerUtils, SyncToTensorSingle) { + std::shared_ptr X(new egr::EagerTensor()); + std::vector src_data(128, 5.0); + std::vector dims = {2, 4, 4, 4}; + paddle::platform::CPUPlace place; + + auto* x_tensor = X->MutableVar()->GetMutable(); + x_tensor->Resize(paddle::framework::make_ddim(dims)); + auto* mutable_x = x_tensor->mutable_data(place); + paddle::memory::Copy(place, mutable_x, place, src_data.data(), + sizeof(float) * src_data.size()); + auto X_ = egr::EagerUtils::SyncToTensors(*(X.get())); + egr::EagerTensor tensor = egr::EagerUtils::GetOutput(X_[0]); + VLOG(6) << "Check Value for SyncToTensorSingle"; + CHECK(eager_test::CompareTensorWithValue(tensor, 5.0)); +} + +TEST(EagerUtils, SyncToTensorMultiple) { + eager_test::InitEnv(paddle::platform::CPUPlace()); + std::vector dims = {2, 4, 4, 4}; + paddle::platform::CPUPlace place; + + std::vector egr_tensors; + { + auto egr_tensor = egr::EagerTensor(); + std::vector src_data(128, 1.0); + auto* x_tensor = + egr_tensor.MutableVar()->GetMutable(); + x_tensor->Resize(paddle::framework::make_ddim(dims)); + auto* mutable_x = x_tensor->mutable_data(place); + paddle::memory::Copy(place, mutable_x, place, src_data.data(), + sizeof(float) * src_data.size()); + egr_tensors.emplace_back(egr_tensor); + } + { + auto egr_tensor = egr::EagerTensor(); + std::vector src_data(128, 2.0); + auto* x_tensor = + egr_tensor.MutableVar()->GetMutable(); + x_tensor->Resize(paddle::framework::make_ddim(dims)); + auto* mutable_x = x_tensor->mutable_data(place); + paddle::memory::Copy(place, mutable_x, place, src_data.data(), + sizeof(float) * src_data.size()); + egr_tensors.emplace_back(std::move(egr_tensor)); + } + std::vector tensors = + egr::EagerUtils::GetOutputs(egr::EagerUtils::SyncToTensors(egr_tensors)); + + VLOG(6) << "Check Value for SyncToTensorMultiple"; + CHECK(eager_test::CompareTensorWithValue(tensors[0], 1.0) == true); + CHECK(eager_test::CompareTensorWithValue(tensors[1], 2.0) == true); +} + +TEST(EagerUtils, ConstructDuplicableOutput) { + VLOG(6) << "Check ConstructDuplicableOutput"; + std::vector> outs = + egr::EagerUtils::ConstructDuplicableOutput(2); + CHECK_EQ(outs.size(), size_t(2)); + CHECK(outs[0]->defined() == false); + CHECK(outs[0]->initialized() == false); +} diff --git a/paddle/fluid/eager/tests/test_utils.h b/paddle/fluid/eager/tests/test_utils.h new file mode 100644 index 00000000000000..b98ff72f0f0568 --- /dev/null +++ b/paddle/fluid/eager/tests/test_utils.h @@ -0,0 +1,175 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/fluid/eager/api/utils/global_utils.h" +#include "paddle/fluid/eager/autograd_meta.h" +#include "paddle/fluid/eager/eager_tensor.h" +#include "paddle/fluid/eager/utils.h" + +#include "paddle/pten/api/all.h" +#include "paddle/pten/core/dense_tensor.h" +#include "paddle/pten/core/tensor_meta.h" + +#include "paddle/fluid/memory/memcpy.h" +#include "paddle/fluid/platform/device_context.h" +#include "paddle/fluid/platform/init.h" + +namespace eager_test { + +template +bool CompareGradTensorWithValue(const egr::EagerTensor& target, T value) { + egr::AutogradMeta* meta = egr::EagerUtils::unsafe_autograd_meta(target); + auto grad_dense = + std::dynamic_pointer_cast(meta->Grad().impl()); + T* ptr = grad_dense->mutable_data(); + + std::vector host_data(grad_dense->numel()); + if (paddle::platform::is_gpu_place(grad_dense->place())) { +#ifdef PADDLE_WITH_CUDA + paddle::platform::DeviceContextPool& pool = + paddle::platform::DeviceContextPool::Instance(); + auto* dev_ctx = dynamic_cast( + pool.Get(paddle::platform::CUDAPlace())); + auto stream = dev_ctx->stream(); + + paddle::memory::Copy(paddle::platform::CPUPlace(), host_data.data(), + paddle::platform::CUDAPlace(), ptr, + sizeof(T) * grad_dense->numel(), stream); + ptr = host_data.data(); +#endif + } + VLOG(6) << "CompareGradTensorWithValue"; + for (int i = 0; i < grad_dense->numel(); i++) { + PADDLE_ENFORCE(value == ptr[i], + paddle::platform::errors::PreconditionNotMet( + "Numerical Error in Compare Grad Variable With Value of " + "%d, we expected got value: %f, but got: %f instead. " + "Please check it later.", + i, value, ptr[i])); + } + return true; +} + +template +bool CompareTensorWithValue(const egr::EagerTensor& target, T value) { + // TODO(jiabin): Support Selected Rows later + auto dense_t = std::dynamic_pointer_cast(target.impl()); + T* ptr = dense_t->mutable_data(); + + std::vector host_data(dense_t->numel()); + if (paddle::platform::is_gpu_place(dense_t->place())) { +#ifdef PADDLE_WITH_CUDA + paddle::platform::DeviceContextPool& pool = + paddle::platform::DeviceContextPool::Instance(); + auto* dev_ctx = dynamic_cast( + pool.Get(paddle::platform::CUDAPlace())); + auto stream = dev_ctx->stream(); + + paddle::memory::Copy(paddle::platform::CPUPlace(), host_data.data(), + paddle::platform::CUDAPlace(), ptr, + sizeof(T) * dense_t->numel(), stream); + ptr = host_data.data(); +#endif + } + + VLOG(6) << "CompareTensorWithValue"; + for (int i = 0; i < dense_t->numel(); i++) { + PADDLE_ENFORCE(value == ptr[i], + paddle::platform::errors::PreconditionNotMet( + "Numerical Error in Compare Grad Variable With Value of " + "%d, we expected got value: %f, but got: %f instead. " + "Please check it later.", + i, value, ptr[i])); + } + return true; +} + +template +bool CompareVariableWithValue(const egr::EagerTensor& target, T value) { + // TODO(jiabin): Support Selected Rows later + auto lod_tensor = target.Var().Get(); + T* ptr = lod_tensor.data(); + + std::vector host_data(lod_tensor.numel()); + if (paddle::platform::is_gpu_place(lod_tensor.place())) { +#ifdef PADDLE_WITH_CUDA + paddle::platform::DeviceContextPool& pool = + paddle::platform::DeviceContextPool::Instance(); + auto* dev_ctx = dynamic_cast( + pool.Get(paddle::platform::CUDAPlace())); + auto stream = dev_ctx->stream(); + + paddle::memory::Copy(paddle::platform::CPUPlace(), host_data.data(), + paddle::platform::CUDAPlace(), ptr, + sizeof(T) * lod_tensor.numel(), stream); + ptr = host_data.data(); +#endif + } + VLOG(6) << "CompareVariableWithValue"; + for (int i = 0; i < lod_tensor.numel(); i++) { + PADDLE_ENFORCE(value == ptr[i], + paddle::platform::errors::PreconditionNotMet( + "Numerical Error in Compare Grad Variable With Value of " + "%d, we expected got value: %f, but got: %f instead. " + "Please check it later.", + i, value, ptr[i])); + } + return true; +} + +template +bool CompareGradVariableWithValue(const egr::EagerTensor& target, T value) { + // TODO(jiabin): Support Selected Rows later + egr::AutogradMeta* meta = egr::EagerUtils::unsafe_autograd_meta(target); + auto lod_tensor = meta->Grad().Var().Get(); + T* ptr = lod_tensor.data(); + + std::vector host_data(lod_tensor.numel()); + if (paddle::platform::is_gpu_place(lod_tensor.place())) { +#ifdef PADDLE_WITH_CUDA + paddle::platform::DeviceContextPool& pool = + paddle::platform::DeviceContextPool::Instance(); + auto* dev_ctx = dynamic_cast( + pool.Get(paddle::platform::CUDAPlace())); + auto stream = dev_ctx->stream(); + + paddle::memory::Copy(paddle::platform::CPUPlace(), host_data.data(), + paddle::platform::CUDAPlace(), ptr, + sizeof(T) * lod_tensor.numel(), stream); + ptr = host_data.data(); +#endif + } + VLOG(6) << "CompareGradVariableWithValue"; + for (int i = 0; i < lod_tensor.numel(); i++) { + PADDLE_ENFORCE(value == ptr[i], + paddle::platform::errors::PreconditionNotMet( + "Numerical Error in Compare Grad Variable With Value of " + "%d, we expected got value: %f, but got: %f instead. " + "Please check it later.", + i, value, ptr[i])); + } + return true; +} + +inline void InitEnv(paddle::platform::Place place) { + // Prepare Device Contexts + // Init DeviceContextPool + paddle::framework::InitDevices(); + + // Init Tracer Place + egr::Controller::Instance().SetExpectedPlace(place); +} +} // namespace eager_test diff --git a/paddle/fluid/eager/utils.cc b/paddle/fluid/eager/utils.cc new file mode 100644 index 00000000000000..c5bda181d405b5 --- /dev/null +++ b/paddle/fluid/eager/utils.cc @@ -0,0 +1,120 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/eager/utils.h" +#include "paddle/fluid/eager/api/all.h" + +namespace egr { +/* ---- Tensor -> Var ---- */ +std::vector> EagerUtils::SyncToVars( + const egr::EagerTensor& tensor) { + // TODO(jiabin): No const cast here. We should call SyncToVar in Python_C + // wrapper + const_cast(&tensor)->SyncToVar( + paddle::framework::proto::VarType_Type_LOD_TENSOR); + return {std::make_shared(tensor)}; +} + +std::vector> EagerUtils::SyncToVars( + const std::vector& tensors) { + // TODO(jiabin): No const cast here. We should call SyncToVar in Python_C + // wrapper + std::vector> res; + size_t num = tensors.size(); + res.reserve(num); + for (size_t i = 0; i < num; i++) { + const_cast(&(tensors[i])) + ->SyncToVar(paddle::framework::proto::VarType_Type_LOD_TENSOR); + res.emplace_back(new EagerTensor(tensors[i])); + } + return res; +} + +/* ---- VarBase -> Tensor ---- */ +std::vector> EagerUtils::SyncToTensors( + const egr::EagerTensor& tensor) { + // TODO(jiabin): No const cast here. We should call SyncToTensor in Python_C + // wrapper + const_cast(&tensor)->SyncToTensor(); + return {std::make_shared(tensor)}; +} + +std::vector> EagerUtils::SyncToTensors( + const std::vector& tensors) { + // TODO(jiabin): No const cast here. We should call SyncToTensor in Python_C + // wrapper + std::vector> res; + size_t num = tensors.size(); + res.reserve(num); + for (size_t i = 0; i < num; i++) { + const_cast(&(tensors[i]))->SyncToTensor(); + res.emplace_back(new EagerTensor(tensors[i])); + } + return res; +} + +std::vector> EagerUtils::ConstructDuplicableOutput( + const size_t num) { + std::vector> res; + res.reserve(num); + for (size_t i = 0; i < num; i++) { + res.emplace_back( + new EagerTensor(egr::Controller::Instance().GenerateUniqueName())); + } + return res; +} + +std::vector EagerUtils::GetOutputs( + const std::vector>& outs) { + std::vector res; + res.reserve(outs.size()); + for (const auto& out : outs) { + PADDLE_ENFORCE_NOT_NULL( + out.get(), paddle::platform::errors::Fatal( + "Eager Tensor %s is null and cannot be copied. " + "We are tring to Get Output tensor from its " + "shared_ptr, this error may indicate some outputs " + "are nullptr", + out->name())); + res.emplace_back((*(out.get()))); + } + return res; +} + +egr::EagerTensor EagerUtils::GetOutput( + const std::shared_ptr& out) { + PADDLE_ENFORCE_NOT_NULL( + out.get(), paddle::platform::errors::Fatal( + "Eager Tensor %s is null and cannot be copied. We " + "are tring to Get Output tensor from its shared_ptr, " + "this error may indicate output is nullptr", + out->name())); + return EagerTensor((*(out.get()))); +} + +AutogradMeta* EagerUtils::unsafe_autograd_meta(const egr::EagerTensor& target) { + auto* p_autograd_meta = target.get_autograd_meta(); + PADDLE_ENFORCE(p_autograd_meta, + paddle::platform::errors::Fatal( + "Null autograd_meta gotten from unsafe_autograd_meta(), " + "if you are using unsafe_autograd_meta, please make sure " + "your tensor's autograd_meta is set")); + return static_cast(p_autograd_meta); +} + +std::pair EagerUtils::OutRankInfo( + const egr::EagerTensor& target) { + return unsafe_autograd_meta(target)->OutRankInfo(); +} +} // namespace egr diff --git a/paddle/fluid/eager/utils.h b/paddle/fluid/eager/utils.h new file mode 100644 index 00000000000000..4e8461e6600b94 --- /dev/null +++ b/paddle/fluid/eager/utils.h @@ -0,0 +1,126 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/fluid/eager/autograd_meta.h" +#include "paddle/fluid/eager/eager_tensor.h" +#include "paddle/fluid/eager/grad_node_info.h" + +#include "paddle/pten/api/all.h" + +namespace egr { + +/** + * EagerUtils is utils used to do some static conversion or autograd + * members access, this class is desinged to be a full static functional + * utils class + * **/ + +template +class IterHelper { + virtual void visit(ElementType element) = 0; + + void visit(std::vector* elements) { + for (auto element : *elements) visit(element); + } + + template + void apply() {} + + public: + template + void apply(T&& arg, Args&&... args) { + visit(std::forward(arg)); + return apply(std::forward(args)...); + } + virtual ~IterHelper() = default; +}; + +class ComputeRequireGradIter : public IterHelper { + public: + bool RequireGrad() { return require_grad_; } + + private: + void visit(AutogradMeta* element) override { + bool stop_gradient = element->StopGradient(); + if (!stop_gradient) require_grad_ = true; + } + + bool require_grad_ = false; +}; + +class PassStopGradientIter : public IterHelper { + public: + void SetStopGradient(bool stop_gradient) { stop_gradient_ = stop_gradient; } + + private: + void visit(AutogradMeta* element) override { + if (!element) { + // TODO(jiabin): Add Tensor name here when we supported. + VLOG(2) << "Tensor is NULL"; + return; + } + element->SetStopGradient(stop_gradient_); + } + + bool stop_gradient_ = true; +}; + +class EagerUtils { + public: + /** + * We have to use autograd_meta and multi_autograd_meta to initialize + * autograd_meta for tensor, since we can't init it in + * egr::EagerTensor's + * constructor (it's abstract class there) + * + * **/ + template + static bool ComputeRequireGrad(T trace_backward, Args&&... args) { + if (!trace_backward) return false; + + auto iter = ComputeRequireGradIter(); + iter.apply(std::forward(args)...); + + return iter.RequireGrad(); + } + + template + static void PassStopGradient(T stop_gradient, Args&&... args) { + auto iter = PassStopGradientIter(); + iter.SetStopGradient(stop_gradient); + iter.apply(std::forward(args)...); + } + static std::pair OutRankInfo(const egr::EagerTensor& target); + // This method will return an AutogradMeta pointer unsafely. + static AutogradMeta* unsafe_autograd_meta(const egr::EagerTensor& target); + + // Intermidate needed remove this once we don't need legacy + static std::vector> SyncToVars( + const egr::EagerTensor& tensor); + static std::vector> SyncToVars( + const std::vector& tensors); + static std::vector> SyncToTensors( + const egr::EagerTensor& tensor); + static std::vector> SyncToTensors( + const std::vector& tensors); + static std::vector> ConstructDuplicableOutput( + const size_t num); + static std::vector GetOutputs( + const std::vector>& outs); + static egr::EagerTensor GetOutput(const std::shared_ptr& outs); +}; + +} // namespace egr