Skip to content

Commit

Permalink
[mlir][MemRef] Add a pattern to simplify `extract_strided_metadata(ca… (
Browse files Browse the repository at this point in the history
#68291)

…st)`

`expand-strided-metadata` was missing a pattern to get rid of
`memref.cast`.
The pattern is straight foward:
Produce a new `extract_strided_metadata` with the source of the cast and
fold the static information (sizes, strides, offset) along the way.
  • Loading branch information
qcolombet authored Oct 5, 2023
1 parent 253ee85 commit 932dc9d
Show file tree
Hide file tree
Showing 2 changed files with 213 additions and 0 deletions.
88 changes: 88 additions & 0 deletions mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -870,6 +870,92 @@ class ExtractStridedMetadataOpReinterpretCastFolder
}
};

/// Replace `base, offset, sizes, strides =
/// extract_strided_metadata(
/// cast(src) to dstTy)`
/// With
/// ```
/// base, ... = extract_strided_metadata(src)
/// offset = !dstTy.srcOffset.isDynamic()
/// ? dstTy.srcOffset
/// : extract_strided_metadata(src).offset
/// sizes = for each srcSize in dstTy.srcSizes:
/// !srcSize.isDynamic()
/// ? srcSize
// : extract_strided_metadata(src).sizes[i]
/// strides = for each srcStride in dstTy.srcStrides:
/// !srcStrides.isDynamic()
/// ? srcStrides
/// : extract_strided_metadata(src).strides[i]
/// ```
///
/// In other words, consume the `cast` and apply its effects
/// on the offset, sizes, and strides or compute them directly from `src`.
class ExtractStridedMetadataOpCastFolder
: public OpRewritePattern<memref::ExtractStridedMetadataOp> {
using OpRewritePattern::OpRewritePattern;

LogicalResult
matchAndRewrite(memref::ExtractStridedMetadataOp extractStridedMetadataOp,
PatternRewriter &rewriter) const override {
Value source = extractStridedMetadataOp.getSource();
auto castOp = source.getDefiningOp<memref::CastOp>();
if (!castOp)
return failure();

Location loc = extractStridedMetadataOp.getLoc();
// Check if the source is suitable for extract_strided_metadata.
SmallVector<Type> inferredReturnTypes;
if (failed(extractStridedMetadataOp.inferReturnTypes(
rewriter.getContext(), loc, {castOp.getSource()},
/*attributes=*/{}, /*properties=*/nullptr, /*regions=*/{},
inferredReturnTypes)))
return rewriter.notifyMatchFailure(castOp,
"cast source's type is incompatible");

auto memrefType = cast<MemRefType>(source.getType());
unsigned rank = memrefType.getRank();
SmallVector<OpFoldResult> results;
results.resize_for_overwrite(rank * 2 + 2);

auto newExtractStridedMetadata =
rewriter.create<memref::ExtractStridedMetadataOp>(loc,
castOp.getSource());

// Register the base_buffer.
results[0] = newExtractStridedMetadata.getBaseBuffer();

auto getConstantOrValue = [&rewriter](int64_t constant,
OpFoldResult ofr) -> OpFoldResult {
return !ShapedType::isDynamic(constant)
? OpFoldResult(rewriter.getIndexAttr(constant))
: ofr;
};

auto [sourceStrides, sourceOffset] = getStridesAndOffset(memrefType);
assert(sourceStrides.size() == rank && "unexpected number of strides");

// Register the new offset.
results[1] =
getConstantOrValue(sourceOffset, newExtractStridedMetadata.getOffset());

const unsigned sizeStartIdx = 2;
const unsigned strideStartIdx = sizeStartIdx + rank;
ArrayRef<int64_t> sourceSizes = memrefType.getShape();

SmallVector<OpFoldResult> sizes = newExtractStridedMetadata.getSizes();
SmallVector<OpFoldResult> strides = newExtractStridedMetadata.getStrides();
for (unsigned i = 0; i < rank; ++i) {
results[sizeStartIdx + i] = getConstantOrValue(sourceSizes[i], sizes[i]);
results[strideStartIdx + i] =
getConstantOrValue(sourceStrides[i], strides[i]);
}
rewriter.replaceOp(extractStridedMetadataOp,
getValueOrCreateConstantIndexOp(rewriter, loc, results));
return success();
}
};

/// Replace `base, offset =
/// extract_strided_metadata(extract_strided_metadata(src)#0)`
/// With
Expand Down Expand Up @@ -911,6 +997,7 @@ void memref::populateExpandStridedMetadataPatterns(
ExtractStridedMetadataOpGetGlobalFolder,
RewriteExtractAlignedPointerAsIndexOfViewLikeOp,
ExtractStridedMetadataOpReinterpretCastFolder,
ExtractStridedMetadataOpCastFolder,
ExtractStridedMetadataOpExtractStridedMetadataFolder>(
patterns.getContext());
}
Expand All @@ -923,6 +1010,7 @@ void memref::populateResolveExtractStridedMetadataPatterns(
ExtractStridedMetadataOpSubviewFolder,
RewriteExtractAlignedPointerAsIndexOfViewLikeOp,
ExtractStridedMetadataOpReinterpretCastFolder,
ExtractStridedMetadataOpCastFolder,
ExtractStridedMetadataOpExtractStridedMetadataFolder>(
patterns.getContext());
}
Expand Down
125 changes: 125 additions & 0 deletions mlir/test/Dialect/MemRef/expand-strided-metadata.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1369,3 +1369,128 @@ func.func @extract_strided_metadata_of_get_global_with_offset()
return %base, %offset, %sizes#0, %sizes#1, %strides#0, %strides#1 :
memref<i32>, index, index, index, index, index
}

// -----

// Check that we simplify extract_strided_metadata of cast
// when the source of the cast is compatible with what
// `extract_strided_metadata`s accept.
//
// When we apply the transformation the resulting offset, sizes and strides
// should come straight from the inputs of the cast.
// Additionally the folder on extract_strided_metadata should propagate the
// static information.
//
// CHECK-LABEL: func @extract_strided_metadata_of_cast
// CHECK-SAME: %[[ARG:.*]]: memref<3x?xi32, strided<[4, ?], offset: ?>>)
//
// CHECK-DAG: %[[C3:.*]] = arith.constant 3 : index
// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
// CHECK: %[[BASE:.*]], %[[DYN_OFFSET:.*]], %[[DYN_SIZES:.*]]:2, %[[DYN_STRIDES:.*]]:2 = memref.extract_strided_metadata %[[ARG]]
//
// CHECK: return %[[BASE]], %[[DYN_OFFSET]], %[[C3]], %[[DYN_SIZES]]#1, %[[C4]], %[[DYN_STRIDES]]#1
func.func @extract_strided_metadata_of_cast(
%arg : memref<3x?xi32, strided<[4, ?], offset:?>>)
-> (memref<i32>, index,
index, index,
index, index) {

%cast =
memref.cast %arg :
memref<3x?xi32, strided<[4, ?], offset: ?>> to
memref<?x?xi32, strided<[?, ?], offset: ?>>

%base, %base_offset, %sizes:2, %strides:2 =
memref.extract_strided_metadata %cast:memref<?x?xi32, strided<[?, ?], offset: ?>>
-> memref<i32>, index,
index, index,
index, index

return %base, %base_offset,
%sizes#0, %sizes#1,
%strides#0, %strides#1 :
memref<i32>, index,
index, index,
index, index
}

// -----

// Check that we simplify extract_strided_metadata of cast
// when the source of the cast is compatible with what
// `extract_strided_metadata`s accept.
//
// Same as extract_strided_metadata_of_cast but with constant sizes and strides
// in the destination type.
//
// CHECK-LABEL: func @extract_strided_metadata_of_cast_w_csts
// CHECK-SAME: %[[ARG:.*]]: memref<?x?xi32, strided<[?, ?], offset: ?>>)
//
// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
// CHECK-DAG: %[[C18:.*]] = arith.constant 18 : index
// CHECK-DAG: %[[C25:.*]] = arith.constant 25 : index
// CHECK: %[[BASE:.*]], %[[DYN_OFFSET:.*]], %[[DYN_SIZES:.*]]:2, %[[DYN_STRIDES:.*]]:2 = memref.extract_strided_metadata %[[ARG]]
//
// CHECK: return %[[BASE]], %[[C25]], %[[C4]], %[[DYN_SIZES]]#1, %[[DYN_STRIDES]]#0, %[[C18]]
func.func @extract_strided_metadata_of_cast_w_csts(
%arg : memref<?x?xi32, strided<[?, ?], offset:?>>)
-> (memref<i32>, index,
index, index,
index, index) {

%cast =
memref.cast %arg :
memref<?x?xi32, strided<[?, ?], offset: ?>> to
memref<4x?xi32, strided<[?, 18], offset: 25>>

%base, %base_offset, %sizes:2, %strides:2 =
memref.extract_strided_metadata %cast:memref<4x?xi32, strided<[?, 18], offset: 25>>
-> memref<i32>, index,
index, index,
index, index

return %base, %base_offset,
%sizes#0, %sizes#1,
%strides#0, %strides#1 :
memref<i32>, index,
index, index,
index, index
}
// -----

// Check that we don't simplify extract_strided_metadata of
// cast when the source of the cast is unranked.
// Unranked memrefs cannot feed into extract_strided_metadata operations.
// Note: Technically we could still fold the sizes and strides.
//
// CHECK-LABEL: func @extract_strided_metadata_of_cast_unranked
// CHECK-SAME: %[[ARG:.*]]: memref<*xi32>)
//
// CHECK: %[[CAST:.*]] = memref.cast %[[ARG]] :
// CHECK: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[CAST]]
//
// CHECK: return %[[BASE]], %[[OFFSET]], %[[SIZES]]#0, %[[SIZES]]#1, %[[STRIDES]]#0, %[[STRIDES]]#1
func.func @extract_strided_metadata_of_cast_unranked(
%arg : memref<*xi32>)
-> (memref<i32>, index,
index, index,
index, index) {

%cast =
memref.cast %arg :
memref<*xi32> to
memref<?x?xi32, strided<[?, ?], offset: ?>>

%base, %base_offset, %sizes:2, %strides:2 =
memref.extract_strided_metadata %cast:memref<?x?xi32, strided<[?, ?], offset: ?>>
-> memref<i32>, index,
index, index,
index, index

return %base, %base_offset,
%sizes#0, %sizes#1,
%strides#0, %strides#1 :
memref<i32>, index,
index, index,
index, index
}

0 comments on commit 932dc9d

Please sign in to comment.