diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp index 417f457c8910ca..786ebb23b457d5 100644 --- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp @@ -274,6 +274,13 @@ static void equivalenceAnalysis(func::FuncOp funcOp, }); } +/// Return "true" if the given function signature has tensor semantics. +static bool hasTensorSignature(func::FuncOp funcOp) { + auto isaTensor = [](Type t) { return isa(t); }; + return llvm::any_of(funcOp.getFunctionType().getInputs(), isaTensor) || + llvm::any_of(funcOp.getFunctionType().getResults(), isaTensor); +} + /// Store all functions of the `moduleOp` in `orderedFuncOps`, sorted by /// callee-caller order (i.e. callees without callers first). /// Store the map of FuncOp to all its callers in `callerMap`. @@ -297,10 +304,16 @@ getFuncOpsOrderedByCalls(ModuleOp moduleOp, "without a unique ReturnOp"; } + // Collect function calls and populate the caller map. numberCallOpsContainedInFuncOp[funcOp] = 0; return funcOp.walk([&](func::CallOp callOp) -> WalkResult { func::FuncOp calledFunction = getCalledFunction(callOp); assert(calledFunction && "could not retrieved called func::FuncOp"); + // If the called function does not have any tensors in its signature, then + // it is not necessary to bufferize the callee before the caller. + if (!hasTensorSignature(calledFunction)) + return WalkResult::skip(); + callerMap[calledFunction].insert(callOp); if (calledBy[calledFunction].insert(funcOp).second) { numberCallOpsContainedInFuncOp[funcOp]++; @@ -310,7 +323,7 @@ getFuncOpsOrderedByCalls(ModuleOp moduleOp, }); if (res.wasInterrupted()) return failure(); - // Iteratively remove function operation that do not call any of the + // Iteratively remove function operations that do not call any of the // functions remaining in the callCounter map and add them to the worklist. while (!numberCallOpsContainedInFuncOp.empty()) { auto it = llvm::find_if(numberCallOpsContainedInFuncOp, diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-invalid.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-invalid.mlir index fd74ae0b60dbbb..ee0f71f668dc74 100644 --- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-invalid.mlir +++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-invalid.mlir @@ -27,14 +27,14 @@ func.func @swappy(%cond1 : i1, %cond2 : i1, %t1 : tensor, %t2 : tensor // expected-error @-3 {{expected callgraph to be free of circular dependencies}} -func.func @foo() { - call @bar() : () -> () - return +func.func @foo(%t: tensor<5xf32>) -> tensor<5xf32> { + %0 = call @bar(%t) : (tensor<5xf32>) -> (tensor<5xf32>) + return %0 : tensor<5xf32> } -func.func @bar() { - call @foo() : () -> () - return +func.func @bar(%t: tensor<5xf32>) -> tensor<5xf32>{ + %0 = call @foo(%t) : (tensor<5xf32>) -> (tensor<5xf32>) + return %0 : tensor<5xf32> } // ----- diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir index b9de4ba34e0e6d..39f4835b28ffeb 100644 --- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir +++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir @@ -662,3 +662,24 @@ func.func @br_in_func(%t: tensor<5xf32>) -> tensor<5xf32> { ^bb1(%arg1 : tensor<5xf32>): func.return %arg1 : tensor<5xf32> } + +// ----- + +// Cyclic call graphs with tensors are not supported by One-Shot Bufferize. +// However, if a function signature does not have any tensor arguments or +// results, calls to that function are not seen as an "edge" in the fuction +// call graph. + +// CHECK-LABEL: func.func @foo(%{{.*}}: memref<5xf32>) -> memref<5xf32> +func.func @foo(%m: memref<5xf32>) -> memref<5xf32> { + %0 = tensor.empty() : tensor<5xf32> + %1 = func.call @bar(%0, %m) + : (tensor<5xf32>, memref<5xf32>) -> (memref<5xf32>) + return %1 : memref<5xf32> +} + +// CHECK: func.func @bar(%{{.*}}: memref<5xf32, strided<[?], offset: ?>>, %arg1: memref<5xf32>) -> memref<5xf32> +func.func @bar(%t: tensor<5xf32>, %m: memref<5xf32>) -> memref<5xf32> { + %0 = func.call @foo(%m) : (memref<5xf32>) -> (memref<5xf32>) + return %0 : memref<5xf32> +}