Skip to content

Commit

Permalink
[mlir][bufferization] Allow cyclic function graphs without tensors (#…
Browse files Browse the repository at this point in the history
…68632)

Cyclic function call graphs are generally not supported by One-Shot
Bufferize. However, they can be allowed when a function does not have
tensor arguments or results. This is because it is then no longer
necessary that the callee will be bufferized before the caller.
  • Loading branch information
matthias-springer authored Oct 10, 2023
1 parent c8b5f4c commit 3d0ca2c
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -274,6 +274,13 @@ static void equivalenceAnalysis(func::FuncOp funcOp,
});
}

/// Return "true" if the given function signature has tensor semantics.
static bool hasTensorSignature(func::FuncOp funcOp) {
auto isaTensor = [](Type t) { return isa<TensorType>(t); };
return llvm::any_of(funcOp.getFunctionType().getInputs(), isaTensor) ||
llvm::any_of(funcOp.getFunctionType().getResults(), isaTensor);
}

/// Store all functions of the `moduleOp` in `orderedFuncOps`, sorted by
/// callee-caller order (i.e. callees without callers first).
/// Store the map of FuncOp to all its callers in `callerMap`.
Expand All @@ -297,10 +304,16 @@ getFuncOpsOrderedByCalls(ModuleOp moduleOp,
"without a unique ReturnOp";
}

// Collect function calls and populate the caller map.
numberCallOpsContainedInFuncOp[funcOp] = 0;
return funcOp.walk([&](func::CallOp callOp) -> WalkResult {
func::FuncOp calledFunction = getCalledFunction(callOp);
assert(calledFunction && "could not retrieved called func::FuncOp");
// If the called function does not have any tensors in its signature, then
// it is not necessary to bufferize the callee before the caller.
if (!hasTensorSignature(calledFunction))
return WalkResult::skip();

callerMap[calledFunction].insert(callOp);
if (calledBy[calledFunction].insert(funcOp).second) {
numberCallOpsContainedInFuncOp[funcOp]++;
Expand All @@ -310,7 +323,7 @@ getFuncOpsOrderedByCalls(ModuleOp moduleOp,
});
if (res.wasInterrupted())
return failure();
// Iteratively remove function operation that do not call any of the
// Iteratively remove function operations that do not call any of the
// functions remaining in the callCounter map and add them to the worklist.
while (!numberCallOpsContainedInFuncOp.empty()) {
auto it = llvm::find_if(numberCallOpsContainedInFuncOp,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,14 +27,14 @@ func.func @swappy(%cond1 : i1, %cond2 : i1, %t1 : tensor<f32>, %t2 : tensor<f32>

// expected-error @-3 {{expected callgraph to be free of circular dependencies}}

func.func @foo() {
call @bar() : () -> ()
return
func.func @foo(%t: tensor<5xf32>) -> tensor<5xf32> {
%0 = call @bar(%t) : (tensor<5xf32>) -> (tensor<5xf32>)
return %0 : tensor<5xf32>
}

func.func @bar() {
call @foo() : () -> ()
return
func.func @bar(%t: tensor<5xf32>) -> tensor<5xf32>{
%0 = call @foo(%t) : (tensor<5xf32>) -> (tensor<5xf32>)
return %0 : tensor<5xf32>
}

// -----
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -662,3 +662,24 @@ func.func @br_in_func(%t: tensor<5xf32>) -> tensor<5xf32> {
^bb1(%arg1 : tensor<5xf32>):
func.return %arg1 : tensor<5xf32>
}

// -----

// Cyclic call graphs with tensors are not supported by One-Shot Bufferize.
// However, if a function signature does not have any tensor arguments or
// results, calls to that function are not seen as an "edge" in the fuction
// call graph.

// CHECK-LABEL: func.func @foo(%{{.*}}: memref<5xf32>) -> memref<5xf32>
func.func @foo(%m: memref<5xf32>) -> memref<5xf32> {
%0 = tensor.empty() : tensor<5xf32>
%1 = func.call @bar(%0, %m)
: (tensor<5xf32>, memref<5xf32>) -> (memref<5xf32>)
return %1 : memref<5xf32>
}

// CHECK: func.func @bar(%{{.*}}: memref<5xf32, strided<[?], offset: ?>>, %arg1: memref<5xf32>) -> memref<5xf32>
func.func @bar(%t: tensor<5xf32>, %m: memref<5xf32>) -> memref<5xf32> {
%0 = func.call @foo(%m) : (memref<5xf32>) -> (memref<5xf32>)
return %0 : memref<5xf32>
}

0 comments on commit 3d0ca2c

Please sign in to comment.