Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mlir][Interfaces][NFC] Better documentation for RegionBranchOpInterface #66920

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
140 changes: 94 additions & 46 deletions mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
Original file line number Diff line number Diff line change
Expand Up @@ -117,27 +117,58 @@ def BranchOpInterface : OpInterface<"BranchOpInterface"> {

def RegionBranchOpInterface : OpInterface<"RegionBranchOpInterface"> {
let description = [{
This interface provides information for region operations that contain
branching behavior between held regions, i.e. this interface allows for
This interface provides information for region operations that exhibit
branching behavior between held regions. I.e., this interface allows for
expressing control flow information for region holding operations.

This interface is meant to model well-defined cases of control-flow of
This interface is meant to model well-defined cases of control-flow and
value propagation, where what occurs along control-flow edges is assumed to
be side-effect free. For example, corresponding successor operands and
successor block arguments may have different types. In such cases,
`areTypesCompatible` can be implemented to compare types along control-flow
edges. By default, type equality is used.
be side-effect free.

A "region branch point" indicates a point from which a branch originates. It
can indicate either a region of this op or `RegionBranchPoint::parent()`. In
the latter case, the branch originates from outside of the op, i.e., when
first executing this op.

A "region successor" indicates the target of a branch. It can indicate
either a region of this op or this op. In the former case, the region
successor is a region pointer and a range of block arguments to which the
"successor operands" are forwarded to. In the latter case, the control flow
leaves this op and the region successor is a range of results of this op to
which the successor operands are forwarded to.
Comment on lines +133 to +138
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this section maybe be on the constructors of mlir::RegionSuccessor as well? That class has been annoyingly underdocumented imo and I am super happy to see its behaviour spelled out.

Personally speaking, the documentation of classes within C++ headers is also more accessible (at least in my workflow with the IDEs that I use) than the TableGen files, so having the documentation there is IMO preferrable.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

TableGen documentation is meant so that it is published on the website (and IIRC we have a TODO to publish the interfaces....)

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Something we could do is add all the doc from TableGen inside the generated .inc files as /// class comments. My IDE jumps to the .inc file all the time.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think there is value in having both! I love that the TableGen generated code for Interfaces does this.
In this particular case though, the class is fully written in C++ in the ControlFlowInterfaces.h header rather than generated by TableGen (talking about mlir::RegionSuccessor here)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The RegionSuccessor and RegionBranchPoint C++ classes already have some documentation. I described them in the interface description here, so that the interface can be understood standalone (without looking at other files).

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(The documentation is not on the class constructor, but the class directly.)


By default, successor operands and successor block arguments/successor
results must have the same type. `areTypesCompatible` can be implemented to
allow non-equal types.

Example:

```
%r = scf.for %iv = %lb to %ub step %step iter_args(%a = %b)
-> tensor<5xf32> {
...
scf.yield %c : tensor<5xf32>
}
```

`scf.for` has one region. The region has two region successors: the region
itself and the `scf.for` op. %b is an entry successor operand. %c is a
successor operand. %a is a successor block argument. %r is a successor
result.
}];
let cppNamespace = "::mlir";

let methods = [
InterfaceMethod<[{
Returns the operands of this operation used as the entry arguments when
branching from `point`, which was specified as a successor of
this operation by `getEntrySuccessorRegions`, or the operands forwarded
to the operation's results when it branches back to itself. These operands
should correspond 1-1 with the successor inputs specified in
`getEntrySuccessorRegions`.
Returns the operands of this operation that are forwarded to the region
successor's block arguments or this operation's results when branching
to `point`. `point` is guaranteed to be among the successors that are
returned by `getEntrySuccessorRegions`/`getSuccessorRegions(parent())`.

Example: In the above example, this method returns the operand %b of the
`scf.for` op, regardless of the value of `point`. I.e., this op always
forwards the same operands, regardless of whether the loop has 0 or more
iterations.
}],
"::mlir::OperandRange", "getEntrySuccessorOperands",
(ins "::mlir::RegionBranchPoint":$point), [{}],
Expand All @@ -147,32 +178,47 @@ def RegionBranchOpInterface : OpInterface<"RegionBranchOpInterface"> {
}]
>,
InterfaceMethod<[{
Returns the viable region successors that are branched to when first
executing the op.
Returns the potential region successors when first executing the op.

Unlike `getSuccessorRegions`, this method also passes along the
constant operands of this op. Based on these, different region
successors can be determined.
`operands` contains an entry for every operand of the implementing
op with a null attribute if the operand has no constant value or
the corresponding attribute if it is a constant.
constant operands of this op. Based on these, the implementation may
filter out certain successors. By default, simply dispatches to
`getSuccessorRegions`. `operands` contains an entry for every
operand of this op, with a null attribute if the operand has no constant
value.

Note: The control flow does not necessarily have to enter any region of
this op.

By default, simply dispatches to `getSuccessorRegions`.
Example: In the above example, this method may return two region
region successors: the single region of the `scf.for` op and the
`scf.for` operation (that implements this interface). If %lb, %ub, %step
are constants and it can be determined the loop does not have any
iterations, this method may choose to return only this operation.
Similarly, if it can be determined that the loop has at least one
iteration, this method may choose to return only the region of the loop.
}],
"void", "getEntrySuccessorRegions",
(ins "::llvm::ArrayRef<::mlir::Attribute>":$operands,
"::llvm::SmallVectorImpl<::mlir::RegionSuccessor> &":$regions),
[{}], [{
"::llvm::SmallVectorImpl<::mlir::RegionSuccessor> &":$regions), [{}],
/*defaultImplementation=*/[{
$_op.getSuccessorRegions(mlir::RegionBranchPoint::parent(), regions);
}]
>,
InterfaceMethod<[{
Returns the viable successors of `point`. These are the regions that may
be selected during the flow of control. The parent operation, may
specify itself as successor, which indicates that the control flow may
not enter any region at all. This method allows for describing which
regions may be executed when entering an operation, and which regions
are executed after having executed another region of the parent op. The
successor region must be non-empty.
Returns the potential region successors when branching from `point`.
These are the regions that may be selected during the flow of control.

When `point = RegionBranchPoint::parent()`, this method returns the
region successors when entering the operation. Otherwise, this method
returns the successor regions when branching from the region indicated
by `point`.

Example: In the above example, this method returns the region of the
`scf.for` and this operation for either region branch point (`parent`
and the region of the `scf.for`). An implementation may choose to filter
out region successors when it is statically known (e.g., by examining
the operands of this op) that those successors are not branched to.
}],
"void", "getSuccessorRegions",
(ins "::mlir::RegionBranchPoint":$point,
Expand All @@ -183,12 +229,12 @@ def RegionBranchOpInterface : OpInterface<"RegionBranchOpInterface"> {
times this operation will invoke the attached regions (assuming the
regions yield normally, i.e. do not abort or invoke an infinite loop).
The minimum number of invocations is at least 0. If the maximum number
of invocations cannot be statically determined, then it will not have a
value (i.e., it is set to `std::nullopt`).
of invocations cannot be statically determined, then it will be set to
`InvocationBounds::getUnknown()`.

`operands` is a set of optional attributes that either correspond to
constant values for each operand of this operation or null if that
operand is not a constant.
This method also passes along the constant operands of this op.
`operands` contains an entry for every operand of this op, with a null
attribute if the operand has no constant value.

This method may be called speculatively on operations where the provided
operands are not necessarily the same as the operation's current
Expand All @@ -199,16 +245,18 @@ def RegionBranchOpInterface : OpInterface<"RegionBranchOpInterface"> {
(ins "::llvm::ArrayRef<::mlir::Attribute>":$operands,
"::llvm::SmallVectorImpl<::mlir::InvocationBounds> &"
:$invocationBounds), [{}],
[{ invocationBounds.append($_op->getNumRegions(),
::mlir::InvocationBounds::getUnknown()); }]
/*defaultImplementation=*/[{
invocationBounds.append($_op->getNumRegions(),
::mlir::InvocationBounds::getUnknown());
}]
>,
InterfaceMethod<[{
This method is called to compare types along control-flow edges. By
default, the types are checked as equal.
}],
"bool", "areTypesCompatible",
(ins "::mlir::Type":$lhs, "::mlir::Type":$rhs), [{}],
[{ return lhs == rhs; }]
/*defaultImplementation=*/[{ return lhs == rhs; }]
>,
];

Expand All @@ -235,34 +283,34 @@ def RegionBranchTerminatorOpInterface :
OpInterface<"RegionBranchTerminatorOpInterface"> {
let description = [{
This interface provides information for branching terminator operations
in the presence of a parent RegionBranchOpInterface implementation. It
in the presence of a parent `RegionBranchOpInterface` implementation. It
specifies which operands are passed to which successor region.
}];
let cppNamespace = "::mlir";

let methods = [
InterfaceMethod<[{
Returns a mutable range of operands that are semantically "returned" by
passing them to the region successor given by `point`.
passing them to the region successor indicated by `point`.
}],
"::mlir::MutableOperandRange", "getMutableSuccessorOperands",
(ins "::mlir::RegionBranchPoint":$point)
>,
InterfaceMethod<[{
Returns the viable region successors that are branched to after this
Returns the potential region successors that are branched to after this
terminator based on the given constant operands.

`operands` contains an entry for every operand of the
implementing op with a null attribute if the operand has no constant
value or the corresponding attribute if it is a constant.
This method also passes along the constant operands of this op.
`operands` contains an entry for every operand of this op, with a null
attribute if the operand has no constant value.

Default implementation simply dispatches to the parent
The default implementation simply dispatches to the parent
`RegionBranchOpInterface`'s `getSuccessorRegions` implementation.
}],
"void", "getSuccessorRegions",
(ins "::llvm::ArrayRef<::mlir::Attribute>":$operands,
"::llvm::SmallVectorImpl<::mlir::RegionSuccessor> &":$regions), [{}],
[{
/*defaultImplementation=*/[{
::mlir::Operation *op = $_op;
::llvm::cast<::mlir::RegionBranchOpInterface>(op->getParentOp())
.getSuccessorRegions(op->getParentRegion(), regions);
Expand Down
9 changes: 0 additions & 9 deletions mlir/lib/Dialect/Affine/IR/AffineOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2375,10 +2375,6 @@ void AffineForOp::getCanonicalizationPatterns(RewritePatternSet &results,
results.add<AffineForEmptyLoopFolder>(context);
}

/// Return operands used when entering the region at 'index'. These operands
/// correspond to the loop iterator operands, i.e., those excluding the
/// induction variable. AffineForOp only has one region, so zero is the only
/// valid value for `index`.
OperandRange AffineForOp::getEntrySuccessorOperands(RegionBranchPoint point) {
assert((point.isParent() || point == getRegion()) && "invalid region point");

Expand All @@ -2387,11 +2383,6 @@ OperandRange AffineForOp::getEntrySuccessorOperands(RegionBranchPoint point) {
return getInits();
}

/// Given the region at `index`, or the parent operation if `index` is None,
/// return the successor regions. These are the regions that may be selected
/// during the flow of control. `operands` is a set of optional attributes that
/// correspond to a constant value for each operand, or null if that operand is
/// not a constant.
void AffineForOp::getSuccessorRegions(
RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
assert((point.isParent() || point == getRegion()) && "expected loop region");
Expand Down
31 changes: 6 additions & 25 deletions mlir/lib/Dialect/SCF/IR/SCF.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -260,11 +260,6 @@ void ExecuteRegionOp::getCanonicalizationPatterns(RewritePatternSet &results,
results.add<SingleBlockExecuteInliner, MultiBlockExecuteInliner>(context);
}

/// Given the region at `index`, or the parent operation if `index` is None,
/// return the successor regions. These are the regions that may be selected
/// during the flow of control. `operands` is a set of optional attributes that
/// correspond to a constant value for each operand, or null if that operand is
/// not a constant.
void ExecuteRegionOp::getSuccessorRegions(
RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
// If the predecessor is the ExecuteRegionOp, branch into the body.
Expand Down Expand Up @@ -543,18 +538,10 @@ ForOp mlir::scf::getForInductionVarOwner(Value val) {
return dyn_cast_or_null<ForOp>(containingOp);
}

/// Return operands used when entering the region at 'index'. These operands
/// correspond to the loop iterator operands, i.e., those excluding the
/// induction variable.
OperandRange ForOp::getEntrySuccessorOperands(RegionBranchPoint point) {
return getInitArgs();
}

/// Given the region at `index`, or the parent operation if `index` is None,
/// return the successor regions. These are the regions that may be selected
/// during the flow of control. `operands` is a set of optional attributes that
/// correspond to a constant value for each operand, or null if that operand is
/// not a constant.
void ForOp::getSuccessorRegions(RegionBranchPoint point,
SmallVectorImpl<RegionSuccessor> &regions) {
// Both the operation itself and the region may be branching into the body or
Expand Down Expand Up @@ -1999,11 +1986,6 @@ void IfOp::print(OpAsmPrinter &p) {
p.printOptionalAttrDict((*this)->getAttrs());
}

/// Given the region at `index`, or the parent operation if `index` is None,
/// return the successor regions. These are the regions that may be selected
/// during the flow of control. `operands` is a set of optional attributes that
/// correspond to a constant value for each operand, or null if that operand is
/// not a constant.
void IfOp::getSuccessorRegions(RegionBranchPoint point,
SmallVectorImpl<RegionSuccessor> &regions) {
// The `then` and the `else` region branch back to the parent operation.
Expand Down Expand Up @@ -3162,13 +3144,6 @@ void WhileOp::build(::mlir::OpBuilder &odsBuilder,
afterBuilder(odsBuilder, odsState.location, afterBlock->getArguments());
}

OperandRange WhileOp::getEntrySuccessorOperands(RegionBranchPoint point) {
assert(point == getBefore() &&
"WhileOp is expected to branch only to the first region");

return getInits();
}

ConditionOp WhileOp::getConditionOp() {
return cast<ConditionOp>(getBeforeBody()->getTerminator());
}
Expand All @@ -3189,6 +3164,12 @@ Block::BlockArgListType WhileOp::getRegionIterArgs() {
return getBeforeArguments();
}

OperandRange WhileOp::getEntrySuccessorOperands(RegionBranchPoint point) {
assert(point == getBefore() &&
"WhileOp is expected to branch only to the first region");
return getInits();
}

void WhileOp::getSuccessorRegions(RegionBranchPoint point,
SmallVectorImpl<RegionSuccessor> &regions) {
// The parent op always branches to the condition region.
Expand Down
18 changes: 7 additions & 11 deletions mlir/lib/Interfaces/ControlFlowInterfaces.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -102,11 +102,8 @@ static InFlightDiagnostic &printRegionEdgeName(InFlightDiagnostic &diag,
}

/// Verify that types match along all region control flow edges originating from
/// `sourceNo` (region # if source is a region, std::nullopt if source is parent
/// op). `getInputsTypesForRegion` is a function that returns the types of the
/// inputs that flow from `sourceIndex' to the given region, or std::nullopt if
/// the exact type match verification is not necessary (e.g., if the Op verifies
/// the match itself).
/// `sourcePoint`. `getInputsTypesForRegion` is a function that returns the
/// types of the inputs that flow to a successor region.
static LogicalResult
verifyTypesAlongAllEdges(Operation *op, RegionBranchPoint sourcePoint,
function_ref<FailureOr<TypeRange>(RegionBranchPoint)>
Expand Down Expand Up @@ -150,8 +147,8 @@ verifyTypesAlongAllEdges(Operation *op, RegionBranchPoint sourcePoint,
LogicalResult detail::verifyTypesAlongControlFlowEdges(Operation *op) {
auto regionInterface = cast<RegionBranchOpInterface>(op);

auto inputTypesFromParent = [&](RegionBranchPoint regionNo) -> TypeRange {
return regionInterface.getEntrySuccessorOperands(regionNo).getTypes();
auto inputTypesFromParent = [&](RegionBranchPoint point) -> TypeRange {
return regionInterface.getEntrySuccessorOperands(point).getTypes();
};

// Verify types along control flow edges originating from the parent.
Expand Down Expand Up @@ -190,11 +187,10 @@ LogicalResult detail::verifyTypesAlongControlFlowEdges(Operation *op) {
continue;

auto inputTypesForRegion =
[&](RegionBranchPoint succRegionNo) -> FailureOr<TypeRange> {
[&](RegionBranchPoint point) -> FailureOr<TypeRange> {
std::optional<OperandRange> regionReturnOperands;
for (RegionBranchTerminatorOpInterface regionReturnOp : regionReturnOps) {
auto terminatorOperands =
regionReturnOp.getSuccessorOperands(succRegionNo);
auto terminatorOperands = regionReturnOp.getSuccessorOperands(point);

if (!regionReturnOperands) {
regionReturnOperands = terminatorOperands;
Expand All @@ -206,7 +202,7 @@ LogicalResult detail::verifyTypesAlongControlFlowEdges(Operation *op) {
if (!areTypesCompatible(regionReturnOperands->getTypes(),
terminatorOperands.getTypes())) {
InFlightDiagnostic diag = op->emitOpError(" along control flow edge");
return printRegionEdgeName(diag, region, succRegionNo)
return printRegionEdgeName(diag, region, point)
<< " operands mismatch between return-like terminators";
}
}
Expand Down