diff --git a/oneflow/core/job_rewriter/cudnn_fused_normalization_add_relu_pass.cpp b/oneflow/core/job_rewriter/cudnn_fused_normalization_add_relu_pass.cpp index 0a07f3a43c2..545dfb15414 100644 --- a/oneflow/core/job_rewriter/cudnn_fused_normalization_add_relu_pass.cpp +++ b/oneflow/core/job_rewriter/cudnn_fused_normalization_add_relu_pass.cpp @@ -84,7 +84,11 @@ Maybe CudnnFusedNormalizationAddReluPass::Apply(Job* job, JobPassCtx* ctx) OperatorConf new_op_conf = op_conf; auto mute_attrs = new_op_conf.mutable_user_conf()->mutable_attr(); auto training_it = mute_attrs->find("training"); - if (training_it != mute_attrs->end()) { mute_attrs->erase(training_it); } + if (training_it != mute_attrs->end()) { + const bool training = user_op_conf.attr("training"); + if (!training) { return; } + mute_attrs->erase(training_it); + } new_op_conf.mutable_user_conf()->set_op_type_name("cudnn_fused_" + op_type_name); job_builder.MutOpsOnlyOnce({new_op_conf}); });