Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mlir][ArmSME] Propagate pad and mask in vector.transfer_read lowering #70814

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 35 additions & 22 deletions mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,15 +60,30 @@ getSMETileAndCastToVector(PatternRewriter &rewriter, Location loc,

namespace {

/// Conversion pattern for vector.transfer_read op with transpose permutation
/// map to vertical arm_sme.tile_load (in-flight transpose).
/// Conversion pattern for vector.transfer_read.
///
/// ---
///
/// Example 1: op with identity permutation map to horizontal
/// arm_sme.tile_load:
///
/// vector.transfer_read ... permutation_map: (d0, d1) -> (d0, d1)
///
/// is converted to:
///
/// arm_sme.tile_load ...
///
/// ---
///
/// Example 2: op with transpose permutation map to vertical arm_sme.tile_load
/// (in-flight transpose):
///
/// vector.transfer_read ... permutation_map: (d0, d1) -> (d1, d0)
///
/// is converted to:
///
/// arm_sme.tile_load ... layout<vertical>
struct TransferReadPermutationToArmSMELowering
struct TransferReadToArmSMELowering
: public OpRewritePattern<vector::TransferReadOp> {
using OpRewritePattern<vector::TransferReadOp>::OpRewritePattern;

Expand All @@ -79,15 +94,6 @@ struct TransferReadPermutationToArmSMELowering
return rewriter.notifyMatchFailure(transferReadOp,
"not a 2 result permutation map");

AffineMap map = transferReadOp.getPermutationMap();

// Permutation map doesn't perform permutation, can be lowered to
// vector.load by TransferReadToVectorLoadLowering and then
// arm_sme.tile_load by VectorLoadToArmSMELowering.
if (map.isIdentity())
return rewriter.notifyMatchFailure(
transferReadOp, "map is an identity, apply another pattern");

auto vectorType = transferReadOp.getVectorType();
if (!arm_sme::isValidSMETileVectorType(vectorType))
return rewriter.notifyMatchFailure(transferReadOp,
Expand All @@ -96,26 +102,33 @@ struct TransferReadPermutationToArmSMELowering
if (!llvm::isa<MemRefType>(transferReadOp.getSource().getType()))
return rewriter.notifyMatchFailure(transferReadOp, "not a memref source");

if (transferReadOp.getMask())
// TODO: support masking.
return rewriter.notifyMatchFailure(transferReadOp,
"masking not yet supported");

// Out-of-bounds dims are not supported.
if (transferReadOp.hasOutOfBoundsDim())
return rewriter.notifyMatchFailure(transferReadOp,
"not inbounds transfer read");

arm_sme::TileSliceLayout layout;

AffineExpr d0, d1;
bindDims(transferReadOp.getContext(), d0, d1);
if (map != AffineMap::get(map.getNumDims(), 0, {d1, d0},
transferReadOp.getContext()))
AffineMap map = transferReadOp.getPermutationMap();
if (map.isIdentity())
layout = arm_sme::TileSliceLayout::Horizontal;
else if (map == AffineMap::get(map.getNumDims(), 0, {d1, d0},
transferReadOp.getContext()))
layout = arm_sme::TileSliceLayout::Vertical;
else
return rewriter.notifyMatchFailure(transferReadOp,
"not true 2-D matrix transpose");
"unsupported permutation map");

// Padding isn't optional for transfer_read, but is only used in the case
// of out-of-bounds accesses (not supported here) and/or masking. Mask is
// optional, if it's not present don't pass padding.
auto mask = transferReadOp.getMask();
auto padding = mask ? transferReadOp.getPadding() : nullptr;
rewriter.replaceOpWithNewOp<arm_sme::TileLoadOp>(
transferReadOp, vectorType, transferReadOp.getSource(),
transferReadOp.getIndices(), arm_sme::TileSliceLayout::Vertical);
transferReadOp.getIndices(), padding, mask, layout);

return success();
}
Expand Down Expand Up @@ -531,7 +544,7 @@ struct VectorOuterProductToArmSMELowering
void mlir::populateVectorToArmSMEPatterns(RewritePatternSet &patterns,
MLIRContext &ctx) {
patterns.add<BroadcastOpToArmSMELowering, ConstantOpToArmSMELowering,
SplatOpToArmSMELowering, TransferReadPermutationToArmSMELowering,
SplatOpToArmSMELowering, TransferReadToArmSMELowering,
TransferWriteToArmSMELowering, TransposeOpToArmSMELowering,
VectorLoadToArmSMELowering, VectorStoreToArmSMELowering,
VectorOuterProductToArmSMELowering>(&ctx);
Expand Down
140 changes: 74 additions & 66 deletions mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir
Original file line number Diff line number Diff line change
@@ -1,181 +1,189 @@
// RUN: mlir-opt %s -convert-vector-to-arm-sme -split-input-file -allow-unregistered-dialect | FileCheck %s

//===----------------------------------------------------------------------===//
// vector.transfer_read (with in-flight transpose)
// vector.transfer_read
//===----------------------------------------------------------------------===//

// CHECK-LABEL: @transfer_read_2d_transpose_i8
// CHECK: arm_sme.tile_load {{.*}} layout<vertical> : memref<?x?xi8>, vector<[16]x[16]xi8>
func.func @transfer_read_2d_transpose_i8(%src : memref<?x?xi8>) {
// CHECK-LABEL: @transfer_read_2d_i8
// CHECK: arm_sme.tile_load %{{.*}}[{{.*}}] : memref<?x?xi8>, vector<[16]x[16]xi8>
func.func @transfer_read_2d_i8(%src : memref<?x?xi8>) {
%c0 = arith.constant 0 : index
%pad = arith.constant 0 : i8
%0 = vector.transfer_read %src[%c0, %c0], %pad {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds = [true, true]} : memref<?x?xi8>, vector<[16]x[16]xi8>
%0 = vector.transfer_read %src[%c0, %c0], %pad {in_bounds = [true, true]} : memref<?x?xi8>, vector<[16]x[16]xi8>
"prevent.dce"(%0) : (vector<[16]x[16]xi8>) -> ()
return
}

// -----

// CHECK-LABEL: @transfer_read_2d_transpose_i16
// CHECK: arm_sme.tile_load {{.*}} layout<vertical> : memref<?x?xi16>, vector<[8]x[8]xi16>
func.func @transfer_read_2d_transpose_i16(%src : memref<?x?xi16>) {
// CHECK-LABEL: @transfer_read_2d_i16
// CHECK: arm_sme.tile_load %{{.*}}[{{.*}}] : memref<?x?xi16>, vector<[8]x[8]xi16>
func.func @transfer_read_2d_i16(%src : memref<?x?xi16>) {
%c0 = arith.constant 0 : index
%pad = arith.constant 0 : i16
%0 = vector.transfer_read %src[%c0, %c0], %pad {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds = [true, true]} : memref<?x?xi16>, vector<[8]x[8]xi16>
%0 = vector.transfer_read %src[%c0, %c0], %pad {in_bounds = [true, true]} : memref<?x?xi16>, vector<[8]x[8]xi16>
"prevent.dce"(%0) : (vector<[8]x[8]xi16>) -> ()
return
}

// -----

// CHECK-LABEL: @transfer_read_2d_transpose_i32
// CHECK: arm_sme.tile_load {{.*}} layout<vertical> : memref<?x?xi32>, vector<[4]x[4]xi32>
func.func @transfer_read_2d_transpose_i32(%src : memref<?x?xi32>) {
// CHECK-LABEL: @transfer_read_2d_i32
// CHECK: arm_sme.tile_load %{{.*}}[{{.*}}] : memref<?x?xi32>, vector<[4]x[4]xi32>
func.func @transfer_read_2d_i32(%src : memref<?x?xi32>) {
%c0 = arith.constant 0 : index
%pad = arith.constant 0 : i32
%0 = vector.transfer_read %src[%c0, %c0], %pad {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds = [true, true]} : memref<?x?xi32>, vector<[4]x[4]xi32>
%0 = vector.transfer_read %src[%c0, %c0], %pad {in_bounds = [true, true]} : memref<?x?xi32>, vector<[4]x[4]xi32>
"prevent.dce"(%0) : (vector<[4]x[4]xi32>) -> ()
return
}

// -----

// CHECK-LABEL: @transfer_read_2d_transpose_i64
// CHECK: arm_sme.tile_load {{.*}} layout<vertical> : memref<?x?xi64>, vector<[2]x[2]xi64>
func.func @transfer_read_2d_transpose_i64(%src : memref<?x?xi64>) {
// CHECK-LABEL: @transfer_read_2d_i64
// CHECK: arm_sme.tile_load %{{.*}}[{{.*}}] : memref<?x?xi64>, vector<[2]x[2]xi64>
func.func @transfer_read_2d_i64(%src : memref<?x?xi64>) {
%c0 = arith.constant 0 : index
%pad = arith.constant 0 : i64
%0 = vector.transfer_read %src[%c0, %c0], %pad {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds = [true, true]} : memref<?x?xi64>, vector<[2]x[2]xi64>
%0 = vector.transfer_read %src[%c0, %c0], %pad {in_bounds = [true, true]} : memref<?x?xi64>, vector<[2]x[2]xi64>
"prevent.dce"(%0) : (vector<[2]x[2]xi64>) -> ()
return
}

// -----

// CHECK-LABEL: @transfer_read_2d_transpose_i128
// CHECK: arm_sme.tile_load {{.*}} layout<vertical> : memref<?x?xi128>, vector<[1]x[1]xi128>
func.func @transfer_read_2d_transpose_i128(%src : memref<?x?xi128>) {
// CHECK-LABEL: @transfer_read_2d_i128
// CHECK: arm_sme.tile_load %{{.*}}[{{.*}}] : memref<?x?xi128>, vector<[1]x[1]xi128>
func.func @transfer_read_2d_i128(%src : memref<?x?xi128>) {
%c0 = arith.constant 0 : index
%pad = arith.constant 0 : i128
%0 = vector.transfer_read %src[%c0, %c0], %pad {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds = [true, true]} : memref<?x?xi128>, vector<[1]x[1]xi128>
%0 = vector.transfer_read %src[%c0, %c0], %pad {in_bounds = [true, true]} : memref<?x?xi128>, vector<[1]x[1]xi128>
"prevent.dce"(%0) : (vector<[1]x[1]xi128>) -> ()
return
}

// -----

// CHECK-LABEL: @transfer_read_2d_transpose_f16
// CHECK: arm_sme.tile_load {{.*}} layout<vertical> : memref<?x?xf16>, vector<[8]x[8]xf16>
func.func @transfer_read_2d_transpose_f16(%src : memref<?x?xf16>) {
// CHECK-LABEL: @transfer_read_2d_f16
// CHECK: arm_sme.tile_load %{{.*}}[{{.*}}] : memref<?x?xf16>, vector<[8]x[8]xf16>
func.func @transfer_read_2d_f16(%src : memref<?x?xf16>) {
%c0 = arith.constant 0 : index
%pad = arith.constant 0.0 : f16
%0 = vector.transfer_read %src[%c0, %c0], %pad {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds = [true, true]} : memref<?x?xf16>, vector<[8]x[8]xf16>
%0 = vector.transfer_read %src[%c0, %c0], %pad {in_bounds = [true, true]} : memref<?x?xf16>, vector<[8]x[8]xf16>
"prevent.dce"(%0) : (vector<[8]x[8]xf16>) -> ()
return
}

// -----

// CHECK-LABEL: @transfer_read_2d_transpose_bf16
// CHECK: arm_sme.tile_load {{.*}} layout<vertical> : memref<?x?xbf16>, vector<[8]x[8]xbf16>
func.func @transfer_read_2d_transpose_bf16(%src : memref<?x?xbf16>) {
// CHECK-LABEL: @transfer_read_2d_bf16
// CHECK: arm_sme.tile_load %{{.*}}[{{.*}}] : memref<?x?xbf16>, vector<[8]x[8]xbf16>
func.func @transfer_read_2d_bf16(%src : memref<?x?xbf16>) {
%c0 = arith.constant 0 : index
%pad = arith.constant 0.0 : bf16
%0 = vector.transfer_read %src[%c0, %c0], %pad {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds = [true, true]} : memref<?x?xbf16>, vector<[8]x[8]xbf16>
%0 = vector.transfer_read %src[%c0, %c0], %pad {in_bounds = [true, true]} : memref<?x?xbf16>, vector<[8]x[8]xbf16>
"prevent.dce"(%0) : (vector<[8]x[8]xbf16>) -> ()
return
}

// -----

// CHECK-LABEL: @transfer_read_2d_transpose_f32
// CHECK: arm_sme.tile_load {{.*}} layout<vertical> : memref<?x?xf32>, vector<[4]x[4]xf32>
func.func @transfer_read_2d_transpose_f32(%src : memref<?x?xf32>) {
// CHECK-LABEL: @transfer_read_2d_f32
// CHECK: arm_sme.tile_load %{{.*}}[{{.*}}] : memref<?x?xf32>, vector<[4]x[4]xf32>
func.func @transfer_read_2d_f32(%src : memref<?x?xf32>) {
%c0 = arith.constant 0 : index
%pad = arith.constant 0.0 : f32
%0 = vector.transfer_read %src[%c0, %c0], %pad {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds = [true, true]} : memref<?x?xf32>, vector<[4]x[4]xf32>
%0 = vector.transfer_read %src[%c0, %c0], %pad {in_bounds = [true, true]} : memref<?x?xf32>, vector<[4]x[4]xf32>
"prevent.dce"(%0) : (vector<[4]x[4]xf32>) -> ()
return
}

// -----

// CHECK-LABEL: @transfer_read_2d_transpose_f64
// CHECK: arm_sme.tile_load {{.*}} layout<vertical> : memref<?x?xf64>, vector<[2]x[2]xf64>
func.func @transfer_read_2d_transpose_f64(%src : memref<?x?xf64>) {
// CHECK-LABEL: @transfer_read_2d_f64
// CHECK: arm_sme.tile_load %{{.*}}[{{.*}}] : memref<?x?xf64>, vector<[2]x[2]xf64>
func.func @transfer_read_2d_f64(%src : memref<?x?xf64>) {
%c0 = arith.constant 0 : index
%pad = arith.constant 0.0 : f64
%0 = vector.transfer_read %src[%c0, %c0], %pad {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds = [true, true]} : memref<?x?xf64>, vector<[2]x[2]xf64>
%0 = vector.transfer_read %src[%c0, %c0], %pad {in_bounds = [true, true]} : memref<?x?xf64>, vector<[2]x[2]xf64>
"prevent.dce"(%0) : (vector<[2]x[2]xf64>) -> ()
return
}

// -----

// CHECK-LABEL: @transfer_read_2d__bad_type
// CHECK-NOT: arm_sme.tile_load
// CHECK: vector.transfer_read
func.func @transfer_read_2d__bad_type(%src : memref<?x?xf64>) {
// CHECK-LABEL: @transfer_read_2d_with_mask_i16
// CHECK: arm_sme.tile_load %{{.*}}[{{.*}}], {{.*}}, {{.*}} : memref<?x?xi16>, vector<[8]x[8]xi16>
func.func @transfer_read_2d_with_mask_i16(%src : memref<?x?xi16>, %mask : vector<[8]x[8]xi1>) {
%c0 = arith.constant 0 : index
%pad = arith.constant 0.0 : f64
%0 = vector.transfer_read %src[%c0, %c0], %pad {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds = [false, false]} : memref<?x?xf64>, vector<[4]x[4]xf64>
"prevent.dce"(%0) : (vector<[4]x[4]xf64>) -> ()
%pad = arith.constant 0 : i16
%0 = vector.transfer_read %src[%c0, %c0], %pad, %mask {in_bounds = [true, true]} : memref<?x?xi16>, vector<[8]x[8]xi16>
"prevent.dce"(%0) : (vector<[8]x[8]xi16>) -> ()
return
}

// -----

// CHECK-LABEL: @transfer_read_2d__non_memref_type
// CHECK-NOT: arm_sme.tile_load
// CHECK: vector.transfer_read
func.func @transfer_read_2d__non_memref_type(%src : tensor<?x?xf64>) {
/// in-flight transpose

// CHECK-LABEL: @transfer_read_2d_transpose_i8
// CHECK: arm_sme.tile_load {{.*}} layout<vertical> : memref<?x?xi8>, vector<[16]x[16]xi8>
func.func @transfer_read_2d_transpose_i8(%src : memref<?x?xi8>) {
%c0 = arith.constant 0 : index
%pad = arith.constant 0.0 : f64
%0 = vector.transfer_read %src[%c0, %c0], %pad {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds = [true, true]} : tensor<?x?xf64>, vector<[2]x[2]xf64>
"prevent.dce"(%0) : (vector<[2]x[2]xf64>) -> ()
%pad = arith.constant 0 : i8
%0 = vector.transfer_read %src[%c0, %c0], %pad {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds = [true, true]} : memref<?x?xi8>, vector<[16]x[16]xi8>
"prevent.dce"(%0) : (vector<[16]x[16]xi8>) -> ()
return
}

// -----

// CHECK-LABEL: @transfer_read_2d__bad_transfer_rank
// CHECK-LABEL: @transfer_read_2d_transpose_with_mask_f32
// CHECK: arm_sme.tile_load {{.*}} layout<vertical> : memref<?x?xf32>, vector<[4]x[4]xf32>
func.func @transfer_read_2d_transpose_with_mask_f32(%src : memref<?x?xf32>, %mask : vector<[4]x[4]xi1>) {
%c0 = arith.constant 0 : index
%pad = arith.constant 0.0 : f32
%0 = vector.transfer_read %src[%c0, %c0], %pad, %mask {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds = [true, true]} : memref<?x?xf32>, vector<[4]x[4]xf32>
"prevent.dce"(%0) : (vector<[4]x[4]xf32>) -> ()
return
}

// -----

// CHECK-LABEL: @transfer_read_2d__bad_type
// CHECK-NOT: arm_sme.tile_load
// CHECK: vector.transfer_read
func.func @transfer_read_2d__bad_transfer_rank(%src : memref<?x?xf64>) {
func.func @transfer_read_2d__bad_type(%src : memref<?x?xf64>) {
%c0 = arith.constant 0 : index
%pad = arith.constant 0.0 : f64
%0 = vector.transfer_read %src[%c0, %c0], %pad {permutation_map = affine_map<(d0, d1) -> (d0)>, in_bounds = [true]} : memref<?x?xf64>, vector<[2]xf64>
"prevent.dce"(%0) : (vector<[2]xf64>) -> ()
%0 = vector.transfer_read %src[%c0, %c0], %pad {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds = [false, false]} : memref<?x?xf64>, vector<[4]x[4]xf64>
"prevent.dce"(%0) : (vector<[4]x[4]xf64>) -> ()
return
}

// -----

// CHECK-LABEL: @transfer_read_2d__unsupported_mask
// CHECK-LABEL: @transfer_read_2d__non_memref_type
// CHECK-NOT: arm_sme.tile_load
// CHECK: vector.transfer_read
func.func @transfer_read_2d__unsupported_mask(%src : memref<?x?xf64>, %mask : vector<[2]x[2]xi1>) {
func.func @transfer_read_2d__non_memref_type(%src : tensor<?x?xf64>) {
%c0 = arith.constant 0 : index
%pad = arith.constant 0.0 : f64
%0 = vector.transfer_read %src[%c0, %c0], %pad, %mask {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds = [true, true]} : memref<?x?xf64>, vector<[2]x[2]xf64>
%0 = vector.transfer_read %src[%c0, %c0], %pad {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds = [true, true]} : tensor<?x?xf64>, vector<[2]x[2]xf64>
"prevent.dce"(%0) : (vector<[2]x[2]xf64>) -> ()
return
}

// -----

/// transfer_read with identity map should be lowered to vector.load by
/// TransferReadToVectorLoadLowering and then arm_sme.tile_load by
/// VectorLoadToArmSMELowering.

// CHECK-LABEL: @transfer_read_2d__non_permuting_map
// CHECK-LABEL: @transfer_read_2d__bad_transfer_rank
// CHECK-NOT: arm_sme.tile_load
// CHECK: vector.transfer_read
func.func @transfer_read_2d__non_permuting_map(%src : memref<?x?xf64>) {
func.func @transfer_read_2d__bad_transfer_rank(%src : memref<?x?xf64>) {
%c0 = arith.constant 0 : index
%pad = arith.constant 0.0 : f64
%0 = vector.transfer_read %src[%c0, %c0], %pad {permutation_map = affine_map<(d0, d1) -> (d0, d1)>, in_bounds = [true, true]} : memref<?x?xf64>, vector<[2]x[2]xf64>
"prevent.dce"(%0) : (vector<[2]x[2]xf64>) -> ()
%0 = vector.transfer_read %src[%c0, %c0], %pad {permutation_map = affine_map<(d0, d1) -> (d0)>, in_bounds = [true]} : memref<?x?xf64>, vector<[2]xf64>
"prevent.dce"(%0) : (vector<[2]xf64>) -> ()
return
}

Expand Down