diff --git a/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp b/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp index 672ef3eb4cd50f..101e099d2b644c 100644 --- a/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp @@ -870,6 +870,92 @@ class ExtractStridedMetadataOpReinterpretCastFolder } }; +/// Replace `base, offset, sizes, strides = +/// extract_strided_metadata( +/// cast(src) to dstTy)` +/// With +/// ``` +/// base, ... = extract_strided_metadata(src) +/// offset = !dstTy.srcOffset.isDynamic() +/// ? dstTy.srcOffset +/// : extract_strided_metadata(src).offset +/// sizes = for each srcSize in dstTy.srcSizes: +/// !srcSize.isDynamic() +/// ? srcSize +// : extract_strided_metadata(src).sizes[i] +/// strides = for each srcStride in dstTy.srcStrides: +/// !srcStrides.isDynamic() +/// ? srcStrides +/// : extract_strided_metadata(src).strides[i] +/// ``` +/// +/// In other words, consume the `cast` and apply its effects +/// on the offset, sizes, and strides or compute them directly from `src`. +class ExtractStridedMetadataOpCastFolder + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult + matchAndRewrite(memref::ExtractStridedMetadataOp extractStridedMetadataOp, + PatternRewriter &rewriter) const override { + Value source = extractStridedMetadataOp.getSource(); + auto castOp = source.getDefiningOp(); + if (!castOp) + return failure(); + + Location loc = extractStridedMetadataOp.getLoc(); + // Check if the source is suitable for extract_strided_metadata. + SmallVector inferredReturnTypes; + if (failed(extractStridedMetadataOp.inferReturnTypes( + rewriter.getContext(), loc, {castOp.getSource()}, + /*attributes=*/{}, /*properties=*/nullptr, /*regions=*/{}, + inferredReturnTypes))) + return rewriter.notifyMatchFailure(castOp, + "cast source's type is incompatible"); + + auto memrefType = cast(source.getType()); + unsigned rank = memrefType.getRank(); + SmallVector results; + results.resize_for_overwrite(rank * 2 + 2); + + auto newExtractStridedMetadata = + rewriter.create(loc, + castOp.getSource()); + + // Register the base_buffer. + results[0] = newExtractStridedMetadata.getBaseBuffer(); + + auto getConstantOrValue = [&rewriter](int64_t constant, + OpFoldResult ofr) -> OpFoldResult { + return !ShapedType::isDynamic(constant) + ? OpFoldResult(rewriter.getIndexAttr(constant)) + : ofr; + }; + + auto [sourceStrides, sourceOffset] = getStridesAndOffset(memrefType); + assert(sourceStrides.size() == rank && "unexpected number of strides"); + + // Register the new offset. + results[1] = + getConstantOrValue(sourceOffset, newExtractStridedMetadata.getOffset()); + + const unsigned sizeStartIdx = 2; + const unsigned strideStartIdx = sizeStartIdx + rank; + ArrayRef sourceSizes = memrefType.getShape(); + + SmallVector sizes = newExtractStridedMetadata.getSizes(); + SmallVector strides = newExtractStridedMetadata.getStrides(); + for (unsigned i = 0; i < rank; ++i) { + results[sizeStartIdx + i] = getConstantOrValue(sourceSizes[i], sizes[i]); + results[strideStartIdx + i] = + getConstantOrValue(sourceStrides[i], strides[i]); + } + rewriter.replaceOp(extractStridedMetadataOp, + getValueOrCreateConstantIndexOp(rewriter, loc, results)); + return success(); + } +}; + /// Replace `base, offset = /// extract_strided_metadata(extract_strided_metadata(src)#0)` /// With @@ -911,6 +997,7 @@ void memref::populateExpandStridedMetadataPatterns( ExtractStridedMetadataOpGetGlobalFolder, RewriteExtractAlignedPointerAsIndexOfViewLikeOp, ExtractStridedMetadataOpReinterpretCastFolder, + ExtractStridedMetadataOpCastFolder, ExtractStridedMetadataOpExtractStridedMetadataFolder>( patterns.getContext()); } @@ -923,6 +1010,7 @@ void memref::populateResolveExtractStridedMetadataPatterns( ExtractStridedMetadataOpSubviewFolder, RewriteExtractAlignedPointerAsIndexOfViewLikeOp, ExtractStridedMetadataOpReinterpretCastFolder, + ExtractStridedMetadataOpCastFolder, ExtractStridedMetadataOpExtractStridedMetadataFolder>( patterns.getContext()); } diff --git a/mlir/test/Dialect/MemRef/expand-strided-metadata.mlir b/mlir/test/Dialect/MemRef/expand-strided-metadata.mlir index a6303aa2d97110..ab0c78a8ba7669 100644 --- a/mlir/test/Dialect/MemRef/expand-strided-metadata.mlir +++ b/mlir/test/Dialect/MemRef/expand-strided-metadata.mlir @@ -1369,3 +1369,128 @@ func.func @extract_strided_metadata_of_get_global_with_offset() return %base, %offset, %sizes#0, %sizes#1, %strides#0, %strides#1 : memref, index, index, index, index, index } + +// ----- + +// Check that we simplify extract_strided_metadata of cast +// when the source of the cast is compatible with what +// `extract_strided_metadata`s accept. +// +// When we apply the transformation the resulting offset, sizes and strides +// should come straight from the inputs of the cast. +// Additionally the folder on extract_strided_metadata should propagate the +// static information. +// +// CHECK-LABEL: func @extract_strided_metadata_of_cast +// CHECK-SAME: %[[ARG:.*]]: memref<3x?xi32, strided<[4, ?], offset: ?>>) +// +// CHECK-DAG: %[[C3:.*]] = arith.constant 3 : index +// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index +// CHECK: %[[BASE:.*]], %[[DYN_OFFSET:.*]], %[[DYN_SIZES:.*]]:2, %[[DYN_STRIDES:.*]]:2 = memref.extract_strided_metadata %[[ARG]] +// +// CHECK: return %[[BASE]], %[[DYN_OFFSET]], %[[C3]], %[[DYN_SIZES]]#1, %[[C4]], %[[DYN_STRIDES]]#1 +func.func @extract_strided_metadata_of_cast( + %arg : memref<3x?xi32, strided<[4, ?], offset:?>>) + -> (memref, index, + index, index, + index, index) { + + %cast = + memref.cast %arg : + memref<3x?xi32, strided<[4, ?], offset: ?>> to + memref> + + %base, %base_offset, %sizes:2, %strides:2 = + memref.extract_strided_metadata %cast:memref> + -> memref, index, + index, index, + index, index + + return %base, %base_offset, + %sizes#0, %sizes#1, + %strides#0, %strides#1 : + memref, index, + index, index, + index, index +} + +// ----- + +// Check that we simplify extract_strided_metadata of cast +// when the source of the cast is compatible with what +// `extract_strided_metadata`s accept. +// +// Same as extract_strided_metadata_of_cast but with constant sizes and strides +// in the destination type. +// +// CHECK-LABEL: func @extract_strided_metadata_of_cast_w_csts +// CHECK-SAME: %[[ARG:.*]]: memref>) +// +// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index +// CHECK-DAG: %[[C18:.*]] = arith.constant 18 : index +// CHECK-DAG: %[[C25:.*]] = arith.constant 25 : index +// CHECK: %[[BASE:.*]], %[[DYN_OFFSET:.*]], %[[DYN_SIZES:.*]]:2, %[[DYN_STRIDES:.*]]:2 = memref.extract_strided_metadata %[[ARG]] +// +// CHECK: return %[[BASE]], %[[C25]], %[[C4]], %[[DYN_SIZES]]#1, %[[DYN_STRIDES]]#0, %[[C18]] +func.func @extract_strided_metadata_of_cast_w_csts( + %arg : memref>) + -> (memref, index, + index, index, + index, index) { + + %cast = + memref.cast %arg : + memref> to + memref<4x?xi32, strided<[?, 18], offset: 25>> + + %base, %base_offset, %sizes:2, %strides:2 = + memref.extract_strided_metadata %cast:memref<4x?xi32, strided<[?, 18], offset: 25>> + -> memref, index, + index, index, + index, index + + return %base, %base_offset, + %sizes#0, %sizes#1, + %strides#0, %strides#1 : + memref, index, + index, index, + index, index +} +// ----- + +// Check that we don't simplify extract_strided_metadata of +// cast when the source of the cast is unranked. +// Unranked memrefs cannot feed into extract_strided_metadata operations. +// Note: Technically we could still fold the sizes and strides. +// +// CHECK-LABEL: func @extract_strided_metadata_of_cast_unranked +// CHECK-SAME: %[[ARG:.*]]: memref<*xi32>) +// +// CHECK: %[[CAST:.*]] = memref.cast %[[ARG]] : +// CHECK: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[CAST]] +// +// CHECK: return %[[BASE]], %[[OFFSET]], %[[SIZES]]#0, %[[SIZES]]#1, %[[STRIDES]]#0, %[[STRIDES]]#1 +func.func @extract_strided_metadata_of_cast_unranked( + %arg : memref<*xi32>) + -> (memref, index, + index, index, + index, index) { + + %cast = + memref.cast %arg : + memref<*xi32> to + memref> + + %base, %base_offset, %sizes:2, %strides:2 = + memref.extract_strided_metadata %cast:memref> + -> memref, index, + index, index, + index, index + + return %base, %base_offset, + %sizes#0, %sizes#1, + %strides#0, %strides#1 : + memref, index, + index, index, + index, index +}