From 836102ec9bb7c8593dfd92b16a93da54a867d341 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=82=85=E5=89=91=E5=AF=92?= Date: Tue, 12 Sep 2023 11:24:15 +0800 Subject: [PATCH] =?UTF-8?q?=E3=80=90CINN=E3=80=91Delete=20remove=5Fnested?= =?UTF-8?q?=5Fblock=20(#56972)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * delete remove_nested_block * fix simplify blocks * add logic for dealing with scheduleblock * add logic for dealing with scheduleblock * fix logic about scheduleblock * fix scheduleblock logic * add logic for deal with block within ifthenelse --- paddle/cinn/backends/codegen_c.cc | 3 +- paddle/cinn/backends/codegen_cuda_dev.cc | 3 +- paddle/cinn/lang/lower_impl.cc | 4 +- paddle/cinn/lang/lower_impl.h | 1 - paddle/cinn/optim/CMakeLists.txt | 3 - paddle/cinn/optim/ir_simplify.cc | 44 +++++++ paddle/cinn/optim/optimize.cc | 5 +- paddle/cinn/optim/remove_nested_block.cc | 123 ------------------ paddle/cinn/optim/remove_nested_block.h | 33 ----- paddle/cinn/optim/remove_nested_block_test.cc | 58 --------- 10 files changed, 50 insertions(+), 227 deletions(-) delete mode 100644 paddle/cinn/optim/remove_nested_block.cc delete mode 100644 paddle/cinn/optim/remove_nested_block.h delete mode 100644 paddle/cinn/optim/remove_nested_block_test.cc diff --git a/paddle/cinn/backends/codegen_c.cc b/paddle/cinn/backends/codegen_c.cc index cffebdc1a67360..3352a458cecebf 100644 --- a/paddle/cinn/backends/codegen_c.cc +++ b/paddle/cinn/backends/codegen_c.cc @@ -23,7 +23,6 @@ #include "paddle/cinn/ir/op/ir_operators.h" #include "paddle/cinn/ir/utils/ir_verify.h" #include "paddle/cinn/optim/ir_simplify.h" -#include "paddle/cinn/optim/remove_nested_block.h" #include "paddle/cinn/runtime/cpu/thread_backend.h" #include "paddle/cinn/runtime/intrinsic.h" #include "paddle/cinn/utils/string.h" @@ -645,7 +644,7 @@ void CodeGenC::Visit(const ir::_LoweredFunc_ *op) { Expr func_body = ir::Block::Make(new_body); - optim::RemoveNestedBlock(&func_body); + optim::SimplifyBlocks(&func_body); IrPrinter::Visit(func_body); } diff --git a/paddle/cinn/backends/codegen_cuda_dev.cc b/paddle/cinn/backends/codegen_cuda_dev.cc index 018f935482c7f9..e33154f0c0129b 100644 --- a/paddle/cinn/backends/codegen_cuda_dev.cc +++ b/paddle/cinn/backends/codegen_cuda_dev.cc @@ -24,7 +24,6 @@ #include "paddle/cinn/ir/op/ir_operators.h" #include "paddle/cinn/ir/utils/ir_verify.h" #include "paddle/cinn/optim/ir_simplify.h" -#include "paddle/cinn/optim/remove_nested_block.h" namespace cinn { namespace backends { @@ -141,7 +140,7 @@ void CodeGenCUDA_Dev::Visit(const ir::_LoweredFunc_ *op) { Expr func_body = ir::Block::Make(new_body); - optim::RemoveNestedBlock(&func_body); + optim::SimplifyBlocks(&func_body); // Make sure that the function's body is wrapped by a block if (!func_body.As()) { func_body = ir::Block::Make({func_body}); diff --git a/paddle/cinn/lang/lower_impl.cc b/paddle/cinn/lang/lower_impl.cc index f313d52938a93a..03185077285c25 100644 --- a/paddle/cinn/lang/lower_impl.cc +++ b/paddle/cinn/lang/lower_impl.cc @@ -25,7 +25,7 @@ #include "paddle/cinn/ir/ir_base.h" #include "paddle/cinn/ir/tensor.h" #include "paddle/cinn/ir/utils/ir_printer.h" -#include "paddle/cinn/optim/remove_nested_block.h" +#include "paddle/cinn/optim/ir_simplify.h" #include "paddle/cinn/optim/replace_var_with_expr.h" #include "paddle/cinn/optim/transform_polyfor_to_for.h" #include "paddle/cinn/poly/stage.h" @@ -655,7 +655,7 @@ std::vector LowerImpl::operator()() { if (support_ir_schedule_) { optim::TransformPolyForToFor(&func->body); - optim::RemoveNestedBlock(&func->body); + optim::SimplifyBlocks(&func->body); func->body = ir::Block::Make({func->body}); result.push_back(ir::LoweredFunc(func.get())); num_func++; diff --git a/paddle/cinn/lang/lower_impl.h b/paddle/cinn/lang/lower_impl.h index 3e52279b19566a..4876888b926c69 100644 --- a/paddle/cinn/lang/lower_impl.h +++ b/paddle/cinn/lang/lower_impl.h @@ -32,7 +32,6 @@ #include "paddle/cinn/optim/compute_inline_expand.h" #include "paddle/cinn/optim/fold_cinn_call_arguments.h" #include "paddle/cinn/optim/optimize.h" -#include "paddle/cinn/optim/remove_nested_block.h" #include "paddle/cinn/optim/replace_call_with_expr.h" #include "paddle/cinn/optim/tensor_write_tell.h" #include "paddle/cinn/optim/transform_gpu_forloop.h" diff --git a/paddle/cinn/optim/CMakeLists.txt b/paddle/cinn/optim/CMakeLists.txt index 99ae9cf3bd3d64..37210affeb0663 100755 --- a/paddle/cinn/optim/CMakeLists.txt +++ b/paddle/cinn/optim/CMakeLists.txt @@ -3,7 +3,6 @@ core_gather_headers() gather_srcs( cinnapi_src SRCS - remove_nested_block.cc replace_call_with_expr.cc ir_replace.cc replace_var_with_expr.cc @@ -33,8 +32,6 @@ if(WITH_CUDA) gather_srcs(cinnapi_src SRCS transform_gpu_forloop.cc) endif() -cinn_cc_test(test_remove_nested_block SRCS remove_nested_block_test.cc DEPS - cinncore) cinn_cc_test(test_ir_simplify SRCS ir_simplify_test.cc DEPS cinncore) cinn_cc_test(test_replace_call_with_expr SRCS replace_call_with_expr_test.cc DEPS cinncore) diff --git a/paddle/cinn/optim/ir_simplify.cc b/paddle/cinn/optim/ir_simplify.cc index bfed498da521da..6cf3fcf4b7be8e 100644 --- a/paddle/cinn/optim/ir_simplify.cc +++ b/paddle/cinn/optim/ir_simplify.cc @@ -305,6 +305,50 @@ struct SimplifyBlocksMutator : public ir::IRMutator<> { expr->As()->stmts = stmts; } } + + void Visit(const IfThenElse* op, Expr* expr) override { + auto* node = expr->As(); + Visit(&node->condition, &node->condition); + if (node->true_case.As() && + (node->true_case.As()->stmts.size() == 1)) { + node->true_case = node->true_case.As()->stmts[0]; + } + Visit(&node->true_case, &node->true_case); + if (node->false_case.defined()) { + if (node->false_case.As() && + (node->false_case.As()->stmts.size() == 1)) { + node->false_case = node->false_case.As()->stmts[0]; + } + Visit(&node->false_case, &node->false_case); + } + } + + void Visit(const ScheduleBlock* op, Expr* expr) override { + auto* node = expr->As(); + CHECK(node); + for (auto& var : node->iter_vars) { + if (var->lower_bound.defined()) { + Visit(&var->lower_bound, &var->lower_bound); + } + if (var->upper_bound.defined()) { + Visit(&var->upper_bound, &var->upper_bound); + } + } + for (auto& buffer_region : node->read_buffers) { + Visit(&buffer_region, &buffer_region); + } + for (auto& buffer_region : node->write_buffers) { + Visit(&buffer_region, &buffer_region); + } + + if (node->body.As()) { + if (node->body.As()->stmts.size() == 1) { + node->body = node->body.As()->stmts[0]; + } + } + + Visit(&(node->body), &(node->body)); + } }; struct SimplifyForLoopsMutator : public ir::IRMutator<> { diff --git a/paddle/cinn/optim/optimize.cc b/paddle/cinn/optim/optimize.cc index b1e73e3c58a9b6..3764e1bd616e23 100644 --- a/paddle/cinn/optim/optimize.cc +++ b/paddle/cinn/optim/optimize.cc @@ -27,7 +27,6 @@ #include "paddle/cinn/optim/lower_function_call_bind_vars.h" #include "paddle/cinn/optim/lower_intrin.h" #include "paddle/cinn/optim/map_extern_call.h" -#include "paddle/cinn/optim/remove_nested_block.h" #include "paddle/cinn/optim/remove_schedule_block.h" #include "paddle/cinn/optim/replace_const_param_to_integer.h" #include "paddle/cinn/optim/transform_gpu_forloop.h" @@ -65,8 +64,8 @@ Expr Optimize(Expr e, CudaSyncThreadsDropIfThenElse(&copied); #endif - RemoveNestedBlock(&copied); - VLOG(4) << "After Optimize RemoveNestedBlock:" << copied; + SimplifyBlocks(&copied); + VLOG(4) << "After SimplifyBlocks:" << copied; MapExternCall(&copied, target); VLOG(10) << "After Optimize MapExternCall:" << copied; diff --git a/paddle/cinn/optim/remove_nested_block.cc b/paddle/cinn/optim/remove_nested_block.cc deleted file mode 100644 index 06050ec5b123cc..00000000000000 --- a/paddle/cinn/optim/remove_nested_block.cc +++ /dev/null @@ -1,123 +0,0 @@ -// Copyright (c) 2021 CINN Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "paddle/cinn/optim/remove_nested_block.h" - -#include "paddle/cinn/ir/utils/ir_mutator.h" -#include "paddle/cinn/ir/utils/ir_printer.h" - -namespace cinn { -namespace optim { - -Expr GetExprInsideBlock(Expr op) { - Expr node = op; - while (node.As()) { - auto& stmts = node.As()->stmts; - if (stmts.size() == 1) { - node = stmts.front(); - } else { - break; - } - } - return node; -} - -// This will remove the nested blocks, but it will also remove the block outside -// the forloop's body. -struct NestedBlockSimplifer : public ir::IRMutator { - void operator()(ir::Expr* expr) { Visit(expr); } - - private: - void Visit(ir::Expr* expr) { ir::IRMutator<>::Visit(expr, expr); } - - void Visit(const ir::Block* expr, Expr* op) override { - auto* node = op->As(); - if (node->stmts.size() == 1) { - *op = GetExprInsideBlock(*op); - IRMutator::Visit(op, op); - } else { - IRMutator::Visit(expr, op); - } - } -}; - -struct NestedBlockRemover : public ir::IRMutator { - void operator()(ir::Expr* expr) { Visit(expr); } - - private: - void Visit(ir::Expr* expr) { ir::IRMutator<>::Visit(expr, expr); } - - void Visit(const ir::Block* expr, Expr* op) override { - auto* node = op->As(); - - std::vector new_exprs; - - bool detect_nested = false; - for (auto it = node->stmts.begin(); it != node->stmts.end(); it++) { - auto* block = it->As(); - if (block) { - detect_nested = true; - new_exprs.insert( - std::end(new_exprs), block->stmts.begin(), block->stmts.end()); - } else { - new_exprs.push_back(*it); - } - } - - node->stmts = new_exprs; - - IRMutator::Visit(expr, op); - } -}; - -// add block outside forloop's body. -struct AddBlockToForloop : public ir::IRMutator<> { - void operator()(ir::Expr* expr) { ir::IRMutator<>::Visit(expr, expr); } - - void Visit(const ir::For* expr, Expr* op) override { - auto* node = op->As(); - if (!node->body.As()) { - node->body = ir::Block::Make({node->body}); - } - - ir::IRMutator<>::Visit(expr, op); - } - - void Visit(const ir::PolyFor* expr, Expr* op) override { - auto* node = op->As(); - if (!node->body.As()) { - node->body = ir::Block::Make({node->body}); - } - - ir::IRMutator<>::Visit(expr, op); - } - - void Visit(const ir::_LoweredFunc_* expr, Expr* op) override { - auto* node = op->As(); - if (!node->body.As()) { - node->body = ir::Block::Make({node->body}); - } - - ir::IRMutator<>::Visit(expr, op); - } -}; - -void RemoveNestedBlock(Expr* e) { - NestedBlockRemover()(e); - NestedBlockSimplifer()(e); - AddBlockToForloop()(e); -} - -} // namespace optim -} // namespace cinn diff --git a/paddle/cinn/optim/remove_nested_block.h b/paddle/cinn/optim/remove_nested_block.h deleted file mode 100644 index 41220c18b254a6..00000000000000 --- a/paddle/cinn/optim/remove_nested_block.h +++ /dev/null @@ -1,33 +0,0 @@ -// Copyright (c) 2021 CINN Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -/** - * This file implements the strategy to remove the unnecessary nested block. - */ -#pragma once -#include - -#include "paddle/cinn/common/common.h" -#include "paddle/cinn/ir/ir.h" - -namespace cinn { -namespace optim { - -/** - * Remove the unecessary nested block. - */ -void RemoveNestedBlock(Expr* e); - -} // namespace optim -} // namespace cinn diff --git a/paddle/cinn/optim/remove_nested_block_test.cc b/paddle/cinn/optim/remove_nested_block_test.cc deleted file mode 100644 index 27238329dfbd7e..00000000000000 --- a/paddle/cinn/optim/remove_nested_block_test.cc +++ /dev/null @@ -1,58 +0,0 @@ -// Copyright (c) 2021 CINN Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "paddle/cinn/optim/remove_nested_block.h" - -#include - -#include -#include - -#include "paddle/cinn/ir/utils/ir_printer.h" -#include "paddle/cinn/utils/string.h" - -namespace cinn { -namespace optim { - -TEST(RemoveNestedBlock, basic) { - auto block0 = ir::Block::Make({Expr(1.f), Expr(1.f)}); - auto block1 = ir::Block::Make({block0}); - auto e = Expr(block1); - - std::string origin = utils::GetStreamCnt(e); - EXPECT_EQ(origin, utils::Trim(R"ROC( -{ - { - 1.00000000f - 1.00000000f - } -} - )ROC")); - - std::cout << "origin:\n" << e << std::endl; - - RemoveNestedBlock(&e); - - std::cout << "e:\n" << e << std::endl; - - EXPECT_EQ(utils::GetStreamCnt(e), utils::Trim(R"ROC( -{ - 1.00000000f - 1.00000000f -} - )ROC")); -} - -} // namespace optim -} // namespace cinn