Skip to content

Commit

Permalink
【CINN】Delete remove_nested_block (#56972)
Browse files Browse the repository at this point in the history
* delete remove_nested_block

* fix simplify blocks

* add logic for dealing with scheduleblock

* add logic for dealing with scheduleblock

* fix logic about scheduleblock

* fix scheduleblock logic

* add logic for deal with block within ifthenelse
  • Loading branch information
Courtesy-Xs authored Sep 12, 2023
1 parent 2d17239 commit 836102e
Show file tree
Hide file tree
Showing 10 changed files with 50 additions and 227 deletions.
3 changes: 1 addition & 2 deletions paddle/cinn/backends/codegen_c.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
#include "paddle/cinn/ir/op/ir_operators.h"
#include "paddle/cinn/ir/utils/ir_verify.h"
#include "paddle/cinn/optim/ir_simplify.h"
#include "paddle/cinn/optim/remove_nested_block.h"
#include "paddle/cinn/runtime/cpu/thread_backend.h"
#include "paddle/cinn/runtime/intrinsic.h"
#include "paddle/cinn/utils/string.h"
Expand Down Expand Up @@ -645,7 +644,7 @@ void CodeGenC::Visit(const ir::_LoweredFunc_ *op) {

Expr func_body = ir::Block::Make(new_body);

optim::RemoveNestedBlock(&func_body);
optim::SimplifyBlocks(&func_body);

IrPrinter::Visit(func_body);
}
Expand Down
3 changes: 1 addition & 2 deletions paddle/cinn/backends/codegen_cuda_dev.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
#include "paddle/cinn/ir/op/ir_operators.h"
#include "paddle/cinn/ir/utils/ir_verify.h"
#include "paddle/cinn/optim/ir_simplify.h"
#include "paddle/cinn/optim/remove_nested_block.h"

namespace cinn {
namespace backends {
Expand Down Expand Up @@ -141,7 +140,7 @@ void CodeGenCUDA_Dev::Visit(const ir::_LoweredFunc_ *op) {

Expr func_body = ir::Block::Make(new_body);

optim::RemoveNestedBlock(&func_body);
optim::SimplifyBlocks(&func_body);
// Make sure that the function's body is wrapped by a block
if (!func_body.As<ir::Block>()) {
func_body = ir::Block::Make({func_body});
Expand Down
4 changes: 2 additions & 2 deletions paddle/cinn/lang/lower_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
#include "paddle/cinn/ir/ir_base.h"
#include "paddle/cinn/ir/tensor.h"
#include "paddle/cinn/ir/utils/ir_printer.h"
#include "paddle/cinn/optim/remove_nested_block.h"
#include "paddle/cinn/optim/ir_simplify.h"
#include "paddle/cinn/optim/replace_var_with_expr.h"
#include "paddle/cinn/optim/transform_polyfor_to_for.h"
#include "paddle/cinn/poly/stage.h"
Expand Down Expand Up @@ -655,7 +655,7 @@ std::vector<ir::LoweredFunc> LowerImpl::operator()() {

if (support_ir_schedule_) {
optim::TransformPolyForToFor(&func->body);
optim::RemoveNestedBlock(&func->body);
optim::SimplifyBlocks(&func->body);
func->body = ir::Block::Make({func->body});
result.push_back(ir::LoweredFunc(func.get()));
num_func++;
Expand Down
1 change: 0 additions & 1 deletion paddle/cinn/lang/lower_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@
#include "paddle/cinn/optim/compute_inline_expand.h"
#include "paddle/cinn/optim/fold_cinn_call_arguments.h"
#include "paddle/cinn/optim/optimize.h"
#include "paddle/cinn/optim/remove_nested_block.h"
#include "paddle/cinn/optim/replace_call_with_expr.h"
#include "paddle/cinn/optim/tensor_write_tell.h"
#include "paddle/cinn/optim/transform_gpu_forloop.h"
Expand Down
3 changes: 0 additions & 3 deletions paddle/cinn/optim/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ core_gather_headers()
gather_srcs(
cinnapi_src
SRCS
remove_nested_block.cc
replace_call_with_expr.cc
ir_replace.cc
replace_var_with_expr.cc
Expand Down Expand Up @@ -33,8 +32,6 @@ if(WITH_CUDA)
gather_srcs(cinnapi_src SRCS transform_gpu_forloop.cc)
endif()

cinn_cc_test(test_remove_nested_block SRCS remove_nested_block_test.cc DEPS
cinncore)
cinn_cc_test(test_ir_simplify SRCS ir_simplify_test.cc DEPS cinncore)
cinn_cc_test(test_replace_call_with_expr SRCS replace_call_with_expr_test.cc
DEPS cinncore)
Expand Down
44 changes: 44 additions & 0 deletions paddle/cinn/optim/ir_simplify.cc
Original file line number Diff line number Diff line change
Expand Up @@ -305,6 +305,50 @@ struct SimplifyBlocksMutator : public ir::IRMutator<> {
expr->As<ir::Block>()->stmts = stmts;
}
}

void Visit(const IfThenElse* op, Expr* expr) override {
auto* node = expr->As<IfThenElse>();
Visit(&node->condition, &node->condition);
if (node->true_case.As<Block>() &&
(node->true_case.As<Block>()->stmts.size() == 1)) {
node->true_case = node->true_case.As<Block>()->stmts[0];
}
Visit(&node->true_case, &node->true_case);
if (node->false_case.defined()) {
if (node->false_case.As<Block>() &&
(node->false_case.As<Block>()->stmts.size() == 1)) {
node->false_case = node->false_case.As<Block>()->stmts[0];
}
Visit(&node->false_case, &node->false_case);
}
}

void Visit(const ScheduleBlock* op, Expr* expr) override {
auto* node = expr->As<ScheduleBlock>();
CHECK(node);
for (auto& var : node->iter_vars) {
if (var->lower_bound.defined()) {
Visit(&var->lower_bound, &var->lower_bound);
}
if (var->upper_bound.defined()) {
Visit(&var->upper_bound, &var->upper_bound);
}
}
for (auto& buffer_region : node->read_buffers) {
Visit(&buffer_region, &buffer_region);
}
for (auto& buffer_region : node->write_buffers) {
Visit(&buffer_region, &buffer_region);
}

if (node->body.As<Block>()) {
if (node->body.As<Block>()->stmts.size() == 1) {
node->body = node->body.As<Block>()->stmts[0];
}
}

Visit(&(node->body), &(node->body));
}
};

struct SimplifyForLoopsMutator : public ir::IRMutator<> {
Expand Down
5 changes: 2 additions & 3 deletions paddle/cinn/optim/optimize.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
#include "paddle/cinn/optim/lower_function_call_bind_vars.h"
#include "paddle/cinn/optim/lower_intrin.h"
#include "paddle/cinn/optim/map_extern_call.h"
#include "paddle/cinn/optim/remove_nested_block.h"
#include "paddle/cinn/optim/remove_schedule_block.h"
#include "paddle/cinn/optim/replace_const_param_to_integer.h"
#include "paddle/cinn/optim/transform_gpu_forloop.h"
Expand Down Expand Up @@ -65,8 +64,8 @@ Expr Optimize(Expr e,
CudaSyncThreadsDropIfThenElse(&copied);
#endif

RemoveNestedBlock(&copied);
VLOG(4) << "After Optimize RemoveNestedBlock:" << copied;
SimplifyBlocks(&copied);
VLOG(4) << "After SimplifyBlocks:" << copied;

MapExternCall(&copied, target);
VLOG(10) << "After Optimize MapExternCall:" << copied;
Expand Down
123 changes: 0 additions & 123 deletions paddle/cinn/optim/remove_nested_block.cc

This file was deleted.

33 changes: 0 additions & 33 deletions paddle/cinn/optim/remove_nested_block.h

This file was deleted.

58 changes: 0 additions & 58 deletions paddle/cinn/optim/remove_nested_block_test.cc

This file was deleted.

0 comments on commit 836102e

Please sign in to comment.