Skip to content

Commit

Permalink
[mlir][linalg] Fix invalid IR in Linalg op fusion (#74425)
Browse files Browse the repository at this point in the history
Linalg op fusion (`Linalg/Transforms/Fusion.cpp`) used to generate
invalid fused producer ops:
```
error: 'linalg.conv_2d_nhwc_hwcf' op expected type of operand #2 ('tensor<1x8x16x4xf32>') to match type of corresponding result ('tensor<?x?x?x?xf32>')
note: see current operation:
%24 = "linalg.conv_2d_nhwc_hwcf"(%21, %22, %23) <{dilations = dense<1> : tensor<2xi64>, operandSegmentSizes = array<i32: 2, 1>, strides = dense<2> : tensor<2xi64>}> ({
^bb0(%arg9: f32, %arg10: f32, %arg11: f32):
  %28 = "arith.mulf"(%arg9, %arg10) <{fastmath = #arith.fastmath<none>}> : (f32, f32) -> f32
  %29 = "arith.addf"(%arg11, %28) <{fastmath = #arith.fastmath<none>}> : (f32, f32) -> f32
  "linalg.yield"(%29) : (f32) -> ()
}) {linalg.memoized_indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 * 2 + d4, d2 * 2 + d5, d6)>, affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d4, d5, d6, d3)>, affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)>]} : (tensor<1x?x?x3xf32>, tensor<3x3x3x4xf32>, tensor<1x8x16x4xf32>) -> tensor<?x?x?x?xf32>
```

This is a problem because the input IR to greedy pattern rewriter during
`-test-linalg-greedy-fusion` is invalid. This commit fixes tests such as
`mlir/test/Dialect/Linalg/tile-and-fuse-tensors.mlir` when verifying the
IR after each pattern application (#74270).
  • Loading branch information
matthias-springer authored Dec 19, 2023
1 parent 6a7bbf7 commit 3a087c1
Showing 1 changed file with 7 additions and 17 deletions.
24 changes: 7 additions & 17 deletions mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -144,27 +144,17 @@ static LinalgOp fuse(OpBuilder &b, LinalgOp producer,
b, loc, producer, getTiledOperands(producer), ivs, tileSizes, sizeBounds,
/**omitPartialTileCheck=*/false));

// Iterate over the results in order.
// Extract the subtensor type from the linearized range.
// Since we do not enforce any canonicalizations on the fly, this is always
// fully dynamic at construction time.
// Take result types from the tiled init operands.
MutableOperandRange producerDpsInits = producer.getDpsInitsMutable();
SmallVector<Type, 4> resultTypes;
resultTypes.reserve(producer->getNumResults());
for (Value operand : producer.getDpsInits()) {
auto tensorType = dyn_cast<RankedTensorType>(operand.getType());
if (!tensorType)
continue;
unsigned rank = tensorType.getRank();
SmallVector<int64_t, 4> staticOffsetsVector(
rank, ShapedType::kDynamic);
SmallVector<int64_t, 4> staticSizesVector(rank, ShapedType::kDynamic);
SmallVector<int64_t, 4> staticStridesVector(
rank, ShapedType::kDynamic);
resultTypes.push_back(tensor::ExtractSliceOp::inferResultType(
tensorType, staticOffsetsVector, staticSizesVector,
staticStridesVector));
int64_t firstInitOperandIdx =
static_cast<OperandRange>(producerDpsInits).getBeginOperandIndex();
for (int64_t i = 0, e = producer->getNumResults(); i < e; ++i) {
resultTypes.push_back(clonedShapes[firstInitOperandIdx + i].getType());
}

// Clone the producer with new operands and result types.
LinalgOp clonedOp = clone(b, producer, resultTypes, clonedShapes);

// Shift all IndexOp results by the tile offset.
Expand Down

0 comments on commit 3a087c1

Please sign in to comment.