Skip to content

Commit

Permalink
[mlir][affine] Allow memref.cast in isDimOpValidSymbol
Browse files Browse the repository at this point in the history
`isDimOpValidSymbol` is used during the verification of `affine.for`
ops. It is used to check if LB/UB values are valid symbols. This change
adds support for `memref.cast`, which can be skipped over if it is a
ranked -> ranked cast.

This change fixes `mlir/test/Transforms.mlir`, which used to fail when
verifying the IR after each pattern application (llvm#74270). In this test
case, a pattern that folds dynamic offsets/sizes/strides to static ones
is applied. This pattern inserts a trivial `memref.cast` that can be
folded away. This folding happens after the pattern application, so the
IR fails to verify after applying the offsets/sizes/strides
canonicalization pattern.

Note: The verifier of `affine.for` violates MLIR guidelines. Only local
properties of an op should be verified. The verifier should not inspect
the defining ops of operands. (This would mean that constraints such as
"operand is a valid affine symbol" cannot be verified.)
  • Loading branch information
matthias-springer committed Dec 5, 2023
1 parent 192439d commit e39e09f
Showing 1 changed file with 12 additions and 1 deletion.
13 changes: 12 additions & 1 deletion mlir/lib/Dialect/Affine/IR/AffineOps.cpp
Original file line number Diff line number Diff line change
@@ -354,8 +354,19 @@ static bool isDimOpValidSymbol(ShapedDimOpInterface dimOp, Region *region) {
if (!index.has_value())
return false;

// Skip over all memref.cast ops (if any).
Operation *op = dimOp.getShapedValue().getDefiningOp();
while (auto castOp = dyn_cast<memref::CastOp>(op)) {
// Bail on unranked memrefs.
if (isa<UnrankedMemRefType>(castOp.getSource().getType()))
return false;
op = castOp.getSource().getDefiningOp();
if (!op)
return false;
}

int64_t i = index.value();
return TypeSwitch<Operation *, bool>(dimOp.getShapedValue().getDefiningOp())
return TypeSwitch<Operation *, bool>(op)
.Case<memref::ViewOp, memref::SubViewOp, memref::AllocOp>(
[&](auto op) { return isMemRefSizeValidSymbol(op, i, region); })
.Default([](Operation *) { return false; });

0 comments on commit e39e09f

Please sign in to comment.