[mlir][ArmSVE] Add convert_to/from_svbool ops #68586
Mostly makes sense, though I'd like to take another look at the new tablegen defs when I am more awake :)
// Negative values index in reverse.
class ShapedTypeWithNthDimOfSize<int n, list<int> allowedSizes> : Type<
IsNthDimSizeIsOneOfPred<n, allowedSizes>,
" with dim " # n # " having a size of {" # !interleave(allowedSizes, ", ") # "}",
Should this subtract 1 from `n` if `n > 0`?
I don't think so; the dims are indexed from 1, so reverse and forward indexing are symmetrical.
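To make the indexing convention in this exchange concrete, here is a minimal C++ sketch using only the standard library. It assumes the convention in the merged predicate: `n` counts from 1, and negative `n` counts back from the last dimension. The helper names `nthDimToIndex` and `isNthDimSizeOneOf` are hypothetical and are not part of MLIR.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>
#include <vector>

// Hypothetical helper: maps a 1-based dimension number `n` (negative values
// count from the end) to a zero-based index into a shape of the given rank.
// Returns std::nullopt when `n` is 0 or its magnitude exceeds the rank.
std::optional<int64_t> nthDimToIndex(int64_t n, int64_t rank) {
  const int64_t magnitude = n < 0 ? -n : n;
  if (n == 0 || magnitude > rank)
    return std::nullopt;
  return n < 0 ? rank + n : n - 1;
}

// Hypothetical helper: whether the n-th dim of `shape` has a size that
// appears in `allowedSizes`.
bool isNthDimSizeOneOf(const std::vector<int64_t> &shape, int64_t n,
                       const std::vector<int64_t> &allowedSizes) {
  const auto index = nthDimToIndex(n, static_cast<int64_t>(shape.size()));
  if (!index)
    return false;
  for (int64_t size : allowedSizes)
    if (shape[*index] == size)
      return true;
  return false;
}
```

With this convention `n = 1` and `n = -rank` both select the first dimension, and `n = rank` and `n = -1` both select the last, which is the symmetry the reply refers to.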
This adds slightly higher-level ops for converting masks between svbool and SVE predicate types. The main reason to use these over the intrinsics is these ops support vectors of masks (via unrolling).

E.g.

```
// Convert a svbool mask to a mask of SVE predicates:
%svbool = vector.load %memref[%c0, %c0] : memref<2x?xi1>, vector<2x[16]xi1>
%mask = arm_sve.convert_from_svbool %svbool : vector<2x[8]xi1>
// => Results in vector<2x[8]xi1>
```

Or:

```
// Convert a mask of SVE predicates to a svbool mask:
%mask = vector.create_mask %c2, %dim_size : vector<2x[2]xi1>
%svbool = arm_sve.convert_to_svbool %mask : vector<2x[2]xi1>
// => Results in vector<2x[16]xi1>
```
OK, have gone through a few more definitions. Will do more later.
@@ -236,6 +243,66 @@ def UmmlaOp : ArmSVE_Op<"ummla",
"$acc `,` $src1 `,` $src2 attr-dict `:` type($src1) `to` type($dst)";
}

class SvboolTypeContraint<string lhsArg, string rhsArg> : TypesMatchWith<
"expected corresponding svbool type widened to [16]xi1",
It should be possible to test for this diagnostic, right?
This is already implicitly tested. This constraint is used to infer the svbool type (from the result or argument), so there's no textual representation of the op that fails it.
let description = [{
Converts svbool types (`vector<[16]xi1>` or vectors of that type, e.g.
`vector<2x3x[16]xi1>`) to SVE predicate types. Note: Only the trailing
dimension can be scalable.

Example 1: Convert a 1-D svbool mask to a SVE predicate.
```mlir
%svbool = vector.load %memref[%c0] : memref<?xi1>, vector<[16]xi1>
%mask = arm_sve.convert_from_svbool %svbool : vector<[4]xi1>
```

Example 2: Convert a 2-D svbool mask to a mask of SVE predicates.
```mlir
%svbool = vector.load %memref[%c0, %c0] : memref<2x?xi1>, vector<2x[16]xi1>
%mask = arm_sve.convert_from_svbool %svbool : vector<2x[8]xi1>
```
}];
How would somebody reading this file know what `svbool` and "SVE predicate" types are? It would be good to define both. Note also that the input/output names are a bit inconsistent:
let arguments = (ins SVBoolMask:$source);
let results = (outs SVEMask:$result);
Given how subtle this is, I'd suggest either expanding this definition, adding links or adding some definition at the top of this file. Or perhaps these types are defined somewhere? (as in, other than as "TableGen" definitions).
Also, "a SVE" -> "an SVE"?
> Also, "a SVE" -> "an SVE"?

No idea! I think my brain hardcodes this for a few words, then does `rand()` for the rest 😛
I can't tell if "a" or "an" is correct here... Grammarly says "a" though.
I've added a little explainer at the bottom of each op, the online docs should also include auto-generated descriptions of the vector types.
@@ -37,6 +37,12 @@ def IsFixedVectorTypePred : CPred<[{::llvm::isa<::mlir::VectorType>($_self) &&
def IsScalableVectorTypePred : CPred<[{::llvm::isa<::mlir::VectorType>($_self) &&
::llvm::cast<VectorType>($_self).isScalable()}]>;

// Whether a type is a scalable VectorType, with a single trailing scalable dimension.
def IsTrailingScalableVectorTypePred : And<[CPred<"::llvm::isa<::mlir::VectorType>($_self)">,
Shouldn't this inherit from `IsScalableVectorTypePred`? Just looking at the comments:
// Whether a type is a scalable VectorType.
vs
// Whether a type is a scalable VectorType, with a single trailing scalable dimension.
So it sounds like `IsTrailingScalableVectorTypePred` is a refined version of `IsScalableVectorTypePred`? In fact, given how complex this is, wouldn't it make sense to simply have a dedicated predicate to check that only the trailing dim is scalable? I am also thinking about the name of this predicate, and IMHO it doesn't reveal what the predicate represents. I would try something like:
IsOnlyTrailingDimScalablePred
That would be shorter, but would capture the key aspect of this pred. Also, we know that only vectors can be "scalable", so can skip "scalable" from name. I would, perhaps, also add an example of what would be valid and what would not under this predicate.
You don't actually get anything from `IsScalableVectorTypePred`. That just checks if you have a scalable dim anywhere, so you still have to check that the trailing dim is scalable, and that all the other dims are not.
It's a refined version of `IsVectorTypePred`, but as that'll remove `getRank() > 0` at some point in the future, that can't be depended on.
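The three checks named in this reply (non-zero rank, trailing dim scalable, no other dim scalable) can be sketched in plain C++. This is an illustration under assumptions, not the TableGen predicate itself: the function name is hypothetical, and a `std::vector<bool>` of per-dimension flags stands in for the vector type's scalable-dims list.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

// Hypothetical stand-in for the predicate under discussion.
// `scalableDims[i]` is true when dimension i of the vector type is scalable.
bool isOnlyTrailingDimScalable(const std::vector<bool> &scalableDims) {
  // A rank-0 vector has no trailing dimension to inspect.
  if (scalableDims.empty())
    return false;
  // The trailing dimension must be scalable...
  if (!scalableDims.back())
    return false;
  // ...and none of the leading dimensions may be.
  return std::none_of(scalableDims.begin(), scalableDims.end() - 1,
                      [](bool isScalable) { return isScalable; });
}
```

For example, the flags for `vector<2x3x[2]xi64>` are `{false, false, true}` and pass, while those for `vector<[4]x8xi32>` (`{true, false}`) and `vector<[2]x[2]xf64>` (`{true, true}`) do not.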
Commits compared: ee738b9 to 9c656ca
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-sve

Author: Benjamin Maxwell (MacDue)

Changes

This adds slightly higher-level ops for converting masks between svbool and SVE predicate types (the examples are given in the pull request description above).

Depends on #68418

Patch is 26.88 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/68586.diff

9 Files Affected:
diff --git a/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td b/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td
index d4294b4dd9fd4e8..cae87b764fc67dd 100644
--- a/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td
+++ b/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td
@@ -28,6 +28,8 @@ def ArmSVE_Dialect : Dialect {
This dialect contains the definitions necessary to target specific Arm SVE
scalable vector operations.
}];
+
+ let dependentDialects = ["vector::VectorDialect"];
}
//===----------------------------------------------------------------------===//
@@ -40,6 +42,11 @@ def SVBool : ScalableVectorOfRankAndLengthAndType<
def SVEPredicate : ScalableVectorOfRankAndLengthAndType<
[1], [16, 8, 4, 2, 1], [I1]>;
+// Generalizations of SVBool and SVEPredicate to ranks >= 1.
+// These are masks with a single trailing scalable dimension.
+def SVBoolMask : TrailingScalableVectorOfSizeAndType<[16], [I1]>;
+def SVEPredicateMask : TrailingScalableVectorOfSizeAndType<[16, 8, 4, 2, 1], [I1]>;
+
//===----------------------------------------------------------------------===//
// ArmSVE op definitions
//===----------------------------------------------------------------------===//
@@ -236,6 +243,81 @@ def UmmlaOp : ArmSVE_Op<"ummla",
"$acc `,` $src1 `,` $src2 attr-dict `:` type($src1) `to` type($dst)";
}
+class SvboolTypeConstraint<string lhsArg, string rhsArg> : TypesMatchWith<
+ "expected corresponding svbool type widened to [16]xi1",
+ lhsArg, rhsArg,
+ "VectorType(VectorType::Builder(::llvm::cast<VectorType>($_self)).setDim(::llvm::cast<VectorType>($_self).getRank() - 1, 16))">;
+
+def ConvertFromSvboolOp : ArmSVE_Op<"convert_from_svbool",
+ [Pure, SvboolTypeConstraint<"result", "source">]>
+{
+ let summary = "Convert a svbool type to a SVE predicate type";
+ let description = [{
+ Converts svbool types (`vector<[16]xi1>` or vectors of that type, e.g.
+ `vector<2x3x[16]xi1>`) to SVE predicate types. Note: Only the trailing
+ dimension can be scalable.
+
+ Example 1: Convert a 1-D svbool mask to a SVE predicate.
+ ```mlir
+ %source = vector.load %memref[%c0] : memref<?xi1>, vector<[16]xi1>
+ %result = arm_sve.convert_from_svbool %source : vector<[4]xi1>
+ ```
+
+ Example 2: Convert a 2-D svbool mask to a mask of SVE predicates.
+ ```mlir
+ %source = vector.load %memref[%c0, %c0] : memref<2x?xi1>, vector<2x[16]xi1>
+ %result = arm_sve.convert_from_svbool %source : vector<2x[8]xi1>
+ ```
+
+ ---
+
+ A `svbool` is the smallest SVE predicate type that has a in-memory
+ representation (and maps to a full predicate register). In MLIR `svbool` is
+ represented as `vector<[16]xi1>`. Smaller SVE predicate types
+ (`vector<[1|2|4|8]xi1>`) must be stored as `svbool` then converted back to
+ a predicate after loading.
+ }];
+ let arguments = (ins SVBoolMask:$source);
+ let results = (outs SVEPredicateMask:$result);
+ let assemblyFormat = "$source attr-dict `:` type($result)";
+}
+
+def ConvertToSvboolOp : ArmSVE_Op<"convert_to_svbool",
+ [Pure, SvboolTypeConstraint<"source", "result">]>
+{
+ let summary = "Convert a SVE predicate type to a svbool type";
+ let description = [{
+ Converts SVE predicate types (or vectors of predicate types, e.g.
+ `vector<4x[4]xi1>`) to svbool types. Note: Only the trailing dimension can
+ be scalable.
+
+ Example 1: Convert a 1-D SVE predicate to a svbool mask.
+ ```mlir
+ %source = vector.create_mask %dim_size : vector<[4]xi1>
+ %result = arm_sve.convert_to_svbool %source : vector<[4]xi1>
+ // => Results in vector<[16]xi1>
+ ```
+
+ Example 2: Convert a 2-D mask of SVE predicates to a svbool mask.
+ ```mlir
+ %source = vector.create_mask %c2, %dim_size : vector<2x[2]xi1>
+ %result = arm_sve.convert_to_svbool %source : vector<2x[2]xi1>
+ // => Results in vector<2x[16]xi1>
+ ```
+
+ ---
+
+ A `svbool` is the smallest SVE predicate type that has a in-memory
+ representation (and maps to a full predicate register). In MLIR `svbool` is
+ represented as `vector<[16]xi1>`. Smaller SVE predicate types
+ (`vector<[1|2|4|8]xi1>`) must be converted to a `svbool` before they can be
+ stored.
+ }];
+ let arguments = (ins SVEPredicateMask:$source);
+ let results = (outs SVBoolMask:$result);
+ let assemblyFormat = "$source attr-dict `:` type($source)";
+}
+
def ScalableMaskedAddIOp : ScalableMaskedIOp<"masked.addi", "addition",
[Commutative]>;
diff --git a/mlir/include/mlir/IR/CommonTypeConstraints.td b/mlir/include/mlir/IR/CommonTypeConstraints.td
index 4fc14e30b8a10d0..a7970e59de8c27e 100644
--- a/mlir/include/mlir/IR/CommonTypeConstraints.td
+++ b/mlir/include/mlir/IR/CommonTypeConstraints.td
@@ -37,6 +37,19 @@ def IsFixedVectorTypePred : CPred<[{::llvm::isa<::mlir::VectorType>($_self) &&
def IsScalableVectorTypePred : CPred<[{::llvm::isa<::mlir::VectorType>($_self) &&
::llvm::cast<VectorType>($_self).isScalable()}]>;
+// Whether a type is a scalable VectorType, with a single trailing scalable dimension.
+// Examples:
+// Valid:
+// - vector<[4]xf32>, vector<2x3x[2]xi64>, vector<32x[8]xi32>
+// Invalid
+// - vector<[4]x8xi32>, vector<[2]x[2]xf64>, vector<2x[8]x4xi32>
+def IsOnlyTrailingDimScalablePred : And<[
+ CPred<"::llvm::isa<::mlir::VectorType>($_self)">,
+ CPred<"::llvm::cast<::mlir::VectorType>($_self).getRank() > 0">,
+ CPred<"::llvm::cast<::mlir::VectorType>($_self).getScalableDims().back()">,
+ CPred<"!llvm::is_contained(::llvm::cast<::mlir::VectorType>($_self).getScalableDims().drop_back(), true)">
+]>;
+
// Whether a type is a VectorType and all dimensions are scalable.
def allDimsScalableVectorTypePred : And<[
IsVectorTypePred,
@@ -404,6 +417,12 @@ class ScalableVectorOf<list<Type> allowedTypes> :
ShapedContainerType<allowedTypes, IsScalableVectorTypePred,
"scalable vector", "::mlir::VectorType">;
+// Any vector with a single trailing scalable dimension, with an element type in
+// the `allowedTypes` list.
+class TrailingScalableVectorOf<list<Type> allowedTypes> :
+ ShapedContainerType<allowedTypes, IsOnlyTrailingDimScalablePred,
+ "trailing scalable vector", "::mlir::VectorType">;
+
// Whether the number of elements of a vector is from the given
// `allowedRanks` list
class IsVectorOfRankPred<list<int> allowedRanks> :
@@ -481,10 +500,32 @@ class IsScalableVectorOfLengthPred<list<int> allowedLengths> :
== }]
# allowedlength>)>]>;
+class abs<int value> {
+ int ret = !if(!lt(value, 0), !sub(0, value), value);
+}
+
+// Whether the n-th (starting from 1) dim of the shape matches the given `size`.
+// Negative values index in reverse.
+class IsNthDimSizeIsOneOfPred<int n, list<int> allowedSizes>
+ : And<[CPred<"::llvm::cast<::mlir::ShapedType>($_self).getRank() >= " # abs<n>.ret>,
+ CPred<"::llvm::is_contained(ArrayRef<int64_t>({" # !interleave(allowedSizes, ", ") # "}), "
+ # "::llvm::cast<::mlir::ShapedType>($_self).getDimSize("
+ # !if(!lt(n, 0),
+ "::llvm::cast<::mlir::ShapedType>($_self).getRank() + " # n,
+ "" # !sub(n, 1))
+ # "))">]>;
+
// Whether the shape of a vector matches the given `shape` list.
class IsVectorOfShape<list<int> shape>
: CPred<"::llvm::cast<::mlir::VectorType>($_self).getShape() == ArrayRef<int64_t>({" # !interleave(shape, ", ") # "})">;
+// Any ShapedType where the size of the n-th dim is contained in `sizes`.
+// Negative values index in reverse.
+class ShapedTypeWithNthDimOfSize<int n, list<int> allowedSizes> : Type<
+ IsNthDimSizeIsOneOfPred<n, allowedSizes>,
+ " with dim " # n # " having a size of {" # !interleave(allowedSizes, ", ") # "}",
+ "::mlir::ShapedType">;
+
// Any vector where the number of elements is from the given
// `allowedLengths` list
class VectorOfLength<list<int> allowedLengths> : Type<
@@ -546,6 +587,17 @@ class ScalableVectorOfRankAndLengthAndType<list<int> allowedRanks,
ScalableVectorOfLength<allowedLengths>.summary,
"::mlir::VectorType">;
+// Any scalable vector with a single trailing scalable dimensions, where the
+// size of the trailing dimension is in `allowedTrailingSizes` list, and the
+// type is in the `allowedTypes` list.
+class TrailingScalableVectorOfSizeAndType<list<int> allowedTrailingSizes,
+ list<Type> allowedTypes> : AllOfType<
+ [TrailingScalableVectorOf<allowedTypes>,
+ ShapedTypeWithNthDimOfSize<-1, allowedTrailingSizes>],
+ TrailingScalableVectorOf<allowedTypes>.summary #
+ ShapedTypeWithNthDimOfSize<-1, allowedTrailingSizes>.summary,
+ "::mlir::VectorType">;
+
def AnyVector : VectorOf<[AnyType]>;
// Temporary vector type clone that allows gradual transition to 0-D vectors.
def AnyVectorOfAnyRank : VectorOfAnyRankOf<[AnyType]>;
diff --git a/mlir/lib/Dialect/ArmSVE/IR/ArmSVEDialect.cpp b/mlir/lib/Dialect/ArmSVE/IR/ArmSVEDialect.cpp
index b7f1020deba1e40..594c9b4c270f218 100644
--- a/mlir/lib/Dialect/ArmSVE/IR/ArmSVEDialect.cpp
+++ b/mlir/lib/Dialect/ArmSVE/IR/ArmSVEDialect.cpp
@@ -12,6 +12,7 @@
#include "mlir/Dialect/ArmSVE/IR/ArmSVEDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/OpImplementation.h"
diff --git a/mlir/lib/Dialect/ArmSVE/IR/CMakeLists.txt b/mlir/lib/Dialect/ArmSVE/IR/CMakeLists.txt
index fffc77245d12c93..9ef7384fc54925a 100644
--- a/mlir/lib/Dialect/ArmSVE/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/ArmSVE/IR/CMakeLists.txt
@@ -10,5 +10,6 @@ add_mlir_dialect_library(MLIRArmSVEDialect
LINK_LIBS PUBLIC
MLIRIR
MLIRLLVMDialect
+ MLIRVectorDialect
MLIRSideEffectInterfaces
)
diff --git a/mlir/lib/Dialect/ArmSVE/Transforms/CMakeLists.txt b/mlir/lib/Dialect/ArmSVE/Transforms/CMakeLists.txt
index 7031ab4f799c4d2..2f1c43fae240d51 100644
--- a/mlir/lib/Dialect/ArmSVE/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/ArmSVE/Transforms/CMakeLists.txt
@@ -7,6 +7,7 @@ add_mlir_dialect_library(MLIRArmSVETransforms
LINK_LIBS PUBLIC
MLIRArmSVEDialect
MLIRFuncDialect
+ MLIRVectorDialect
MLIRIR
MLIRLLVMCommonConversion
MLIRLLVMDialect
diff --git a/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp
index abbb978304068e2..ca9e280f510858c 100644
--- a/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp
+++ b/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp
@@ -12,6 +12,8 @@
#include "mlir/Dialect/ArmSVE/Transforms/Transforms.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/PatternMatch.h"
@@ -66,6 +68,77 @@ using ScalableMaskedDivFOpLowering =
OneToOneConvertToLLVMPattern<ScalableMaskedDivFOp,
ScalableMaskedDivFIntrOp>;
+namespace {
+
+/// Unrolls a conversion to/from equivalent vector types, to allow using a
+/// conversion intrinsic that only supports 1-D vector types.
+///
+/// Example:
+/// ```
+/// %result = arm_sve.convert_to_svbool %source : vector<2x[4]xi1>
+/// ```
+/// is rewritten into:
+/// ```
+/// %cst = arith.constant dense<false> : vector<2x[16]xi1>
+/// %1 = vector.extract %source[0] : vector<[4]xi1> from vector<2x[4]xi1>
+/// %2 = "arm_sve.intr.convert.to.svbool"(%1)
+/// : (vector<[4]xi1>) -> vector<[16]xi1>
+/// %3 = vector.insert %2, %cst [0] : vector<[16]xi1> into vector<2x[16]xi1>
+/// %4 = vector.extract %source[1] : vector<[4]xi1> from vector<2x[4]xi1>
+/// %5 = "arm_sve.intr.convert.to.svbool"(%4)
+/// : (vector<[4]xi1>) -> vector<[16]xi1>
+/// %result = vector.insert %5, %3 [1] : vector<[16]xi1> into vector<2x[16]xi1>
+/// ```
+template <typename Op, typename IntrOp>
+struct SvboolConversionOpLowering : public ConvertOpToLLVMPattern<Op> {
+ using ConvertOpToLLVMPattern<Op>::ConvertOpToLLVMPattern;
+
+ LogicalResult
+ matchAndRewrite(Op convertOp, typename Op::Adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ auto loc = convertOp.getLoc();
+
+ auto source = convertOp.getSource();
+ VectorType sourceType = source.getType();
+ VectorType resultType = convertOp.getResult().getType();
+
+ Value result = rewriter.create<arith::ConstantOp>(
+ loc, resultType, rewriter.getZeroAttr(resultType));
+
+ // We want to iterate over the input vector in steps of the trailing
+ // dimension. So this creates tile shape where all leading dimensions are 1,
+ // and the trailing dimension step is the size of the dimension.
+ SmallVector<int64_t> tileShape(sourceType.getRank(), 1);
+ tileShape.back() = sourceType.getShape().back();
+
+ // Iterate over all scalable mask/predicate slices of the source vector.
+ for (SmallVector<int64_t> index :
+ StaticTileOffsetRange(sourceType.getShape(), tileShape)) {
+ auto extractOrInsertPosition = ArrayRef(index).drop_back();
+ auto sourceVector = rewriter.create<vector::ExtractOp>(
+ loc, source, extractOrInsertPosition);
+ auto convertedType =
+ VectorType::Builder(llvm::cast<VectorType>(sourceVector.getType()))
+ .setDim(0, resultType.getShape().back());
+ auto convertedVector =
+ rewriter.create<IntrOp>(loc, TypeRange{convertedType}, sourceVector);
+ result = rewriter.create<vector::InsertOp>(loc, convertedVector, result,
+ extractOrInsertPosition);
+ }
+
+ rewriter.replaceOp(convertOp, result);
+ return success();
+ }
+};
+
+using ConvertToSvboolOpLowering =
+ SvboolConversionOpLowering<ConvertToSvboolOp, ConvertToSvboolIntrOp>;
+
+using ConvertFromSvboolOpLowering =
+ SvboolConversionOpLowering<ConvertFromSvboolOp, ConvertFromSvboolIntrOp>;
+
+} // namespace
+
/// Populate the given list with patterns that convert from ArmSVE to LLVM.
void mlir::populateArmSVELegalizeForLLVMExportPatterns(
LLVMTypeConverter &converter, RewritePatternSet &patterns) {
@@ -88,7 +161,9 @@ void mlir::populateArmSVELegalizeForLLVMExportPatterns(
ScalableMaskedMulFOpLowering,
ScalableMaskedSDivIOpLowering,
ScalableMaskedUDivIOpLowering,
- ScalableMaskedDivFOpLowering>(converter);
+ ScalableMaskedDivFOpLowering,
+ ConvertToSvboolOpLowering,
+ ConvertFromSvboolOpLowering>(converter);
// clang-format on
}
@@ -107,7 +182,9 @@ void mlir::configureArmSVELegalizeForExportTarget(
ScalableMaskedMulFIntrOp,
ScalableMaskedSDivIIntrOp,
ScalableMaskedUDivIIntrOp,
- ScalableMaskedDivFIntrOp>();
+ ScalableMaskedDivFIntrOp,
+ ConvertToSvboolIntrOp,
+ ConvertFromSvboolIntrOp>();
target.addIllegalOp<SdotOp,
SmmlaOp,
UdotOp,
@@ -120,6 +197,8 @@ void mlir::configureArmSVELegalizeForExportTarget(
ScalableMaskedMulFOp,
ScalableMaskedSDivIOp,
ScalableMaskedUDivIOp,
- ScalableMaskedDivFOp>();
+ ScalableMaskedDivFOp,
+ ConvertToSvboolOp,
+ ConvertFromSvboolOp>();
// clang-format on
}
diff --git a/mlir/test/Dialect/ArmSVE/invalid.mlir b/mlir/test/Dialect/ArmSVE/invalid.mlir
new file mode 100644
index 000000000000000..a1fa0d0292b7b76
--- /dev/null
+++ b/mlir/test/Dialect/ArmSVE/invalid.mlir
@@ -0,0 +1,51 @@
+// RUN: mlir-opt %s -split-input-file -verify-diagnostics
+
+// -----
+
+func.func @arm_sve_convert_from_svbool__bad_mask_type(%bool: vector<2x[16]xi1>) -> vector<2x[8]xi2> {
+ // expected-error@+1 {{'result' must be trailing scalable vector of 1-bit signless integer values with dim -1 having a size of {16, 8, 4, 2, 1}, but got 'vector<2x[8]xi2>'}}
+ %mask = arm_sve.convert_from_svbool %bool : vector<2x[8]xi2>
+ return %mask : vector<2x[8]xi2>
+}
+
+// -----
+
+func.func @arm_sve_convert_from_svbool__bad_mask_shape(%bool : vector<[16]xi1>) -> vector<[7]xi1> {
+ // expected-error@+1 {{'result' must be trailing scalable vector of 1-bit signless integer values with dim -1 having a size of {16, 8, 4, 2, 1}, but got 'vector<[7]xi1>'}}
+ %mask = arm_sve.convert_from_svbool %bool : vector<[7]xi1>
+ return %mask : vector<[7]xi1>
+}
+
+// -----
+
+func.func @arm_sve_convert_from_svbool__bad_mask_scalability(%bool : vector<[4]x[16]xi1>) -> vector<[4]x[8]xi1> {
+ // expected-error@+1 {{'result' must be trailing scalable vector of 1-bit signless integer values with dim -1 having a size of {16, 8, 4, 2, 1}, but got 'vector<[4]x[8]xi1>'}}
+ %mask = arm_sve.convert_from_svbool %bool : vector<[4]x[8]xi1>
+ return %mask : vector<[4]x[8]xi1>
+}
+
+// -----
+
+func.func @arm_sve_convert_to_svbool__bad_mask_type(%mask: vector<2x[8]xi2>) -> vector<2x[16]xi1> {
+ // expected-error@+1 {{'source' must be trailing scalable vector of 1-bit signless integer values with dim -1 having a size of {16, 8, 4, 2, 1}, but got 'vector<2x[8]xi2>'}}
+ %bool = arm_sve.convert_to_svbool %mask : vector<2x[8]xi2>
+ return %bool : vector<2x[16]xi1>
+}
+
+// -----
+
+func.func @arm_sve_convert_to_svbool__bad_mask_shape(%mask : vector<[7]xi1>) -> vector<[16]xi1> {
+ // expected-error@+1 {{'source' must be trailing scalable vector of 1-bit signless integer values with dim -1 having a size of {16, 8, 4, 2, 1}, but got 'vector<[7]xi1>'}}
+ %bool = arm_sve.convert_to_svbool %mask : vector<[7]xi1>
+ return
+}
+
+// -----
+
+func.func @arm_sve_convert_to_svbool__bad_mask_scalability(%mask : vector<[4]x[8]xi1>) -> vector<[4]x[16]xi1> {
+ // expected-error@+1 {{'source' must be trailing scalable vector of 1-bit signless integer values with dim -1 having a size of {16, 8, 4, 2, 1}, but got 'vector<[4]x[8]xi1>'}}
+ %bool = arm_sve.convert_to_svbool %mask : vector<[4]x[8]xi1>
+ return
+}
+
+
diff --git a/mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir b/mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir
index 2d980db981034dd..8e76fb7119b844e 100644
--- a/mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir
+++ b/mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -convert-vector-to-llvm="enable-arm-sve" -convert-func-to-llvm -reconcile-unrealized-casts | mlir-opt | FileCheck %s
+// RUN: mlir-opt -convert-vector-to-llvm="enable-arm-sve" -convert-func-to-llvm -reconcile-unrealized-casts -split-input-file %s | FileCheck %s
func.func @arm_sve_sdot(%a: vector<[16]xi8>,
%b: vector<[16]xi8>,
@@ -10,6 +10,8 @@ func.func @arm_sve_sdot(%a: vector<[16]xi8>,
return %0 : vector<[4]xi32>
}
+// -----
+
func.func @arm_sve_smmla(%a: vector<[16]xi8>,
%b: vector<[16]xi8>,
%c: vector<[4]xi32>)
@@ -20,6 +22,8 @@ func.func @arm_sve_smmla(%a: vector<[16]xi8>,
return %0 : vector<[4]xi32>
}
+// -----
+
func.func @arm_sve_udot(%a: vector<[16]xi8>,
%b: vector<[16]xi8>,
%c: vector<[4]xi32>)
@@ -30,6 +34,8 @@ func.func @arm_sve_udot(%a: vector<[16]xi8>,
return %0 : vector<[4]xi32>
}
+// -----
+
func.func @arm_sve_ummla(%a: vector<[16]xi8>,
%b: vector<[16]xi8>,
%c: vector<[4]xi32>)
@@ -40,6 +46,8 @@ func.func @arm_sve_ummla(%a: vector<[16]xi8>,
return %0 : vector<[4]xi32>
}
+// -----
+
func.func @arm_sve_arithi_masked(%a: vector<[4]xi32>,
%b: vector<[4]xi32>,
%c: vector<[4]xi32>,
@@ -65,6 +73,8 @@ func.func @arm_sve_arithi_masked(%a: vector<[4]xi32>,
return %4 : vector<[4]xi32>
}
+// -----
+
func.func @arm_sve_arithf_masked(%a: vector<[4]xf32>,
%b: vector<[4]xf32>,
%c: vector<[4]xf32>,
@@ -87,6 +97,8 @@ func.func @arm_sve_arithf_masked(%a: vector<[4]xf32>,
return %3 : vector<[4]xf32>
}
+// -----
+
func.func @arm_sve_ab...
[truncated]
/// Unrolls a conversion to/from equivalent vector types, to allow using a
/// conversion intrinsic that only supports 1-D vector types.
This is doing two things and the op doesn't map 1-1 with intrinsics. Based on previous feedback I've received and general observations, I wonder if the unrolling should be done as a separate transform and this simply maps rank 1 `arm_sve.convert_to_svbool` / `arm_sve.convert_from_svbool` to intrinsics?
It wouldn't have to be done as part of this patch, but something to consider.
That's the whole reason these ops exist though? If this mapped 1-1 to the intrinsics there'd be no reason for these ops to exist.
Yeah I understand that, but it doesn't mean both steps have to be done when lowering to intrinsics. At the Vector dialect level for example operations on higher rank vectors are typically broken down to rank 1 vectors first (VectorToSCF), before lowering to intrinsics.
That'll be quite a big rework as the ArmSVE dialect is pretty much just a skeleton, so it currently has no `-arm-sve-to-scf` (or similar) passes.
I don't think `-arm-sve-to-scf` would make sense since this isn't emitting SCF ops, but I get your point: it would need to live somewhere that doesn't exist today. Like I said, I'm happy with this regardless.
I second Cullen here. It's not obvious to me what the right "pass" would be and where it would live, but we should add a TODO in the comments and in the commit message. Something along the lines of:

> Extract the lowering of `convert_to_svbool` into a series of `vector.extract/insert` operations into a dedicated lowering layer/pass pre "LegalizeForLLVMExport". The latter should be as simple as 1-1 mapping.
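For readers weighing the unrolling step in this thread, here is a minimal C++ sketch (standard library only, not the MLIR implementation) of the slice positions such an unrolling visits: every combination of leading indices, in row-major order, with the trailing dimension left whole. The name `leadingPositions` and the `Position` alias are hypothetical.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

using Position = std::vector<int64_t>;

// Hypothetical helper: enumerates the index tuples of all 1-D trailing slices
// of `shape`, in row-major order. A 1-D shape yields one empty position.
std::vector<Position> leadingPositions(const std::vector<int64_t> &shape) {
  std::vector<Position> positions;
  if (shape.empty())
    return positions;
  // Drop the trailing dimension; each slice spans it entirely.
  const std::vector<int64_t> leading(shape.begin(), shape.end() - 1);
  for (int64_t dimSize : leading)
    if (dimSize <= 0)
      return positions; // No slices in an empty (or invalid) shape.
  Position current(leading.size(), 0);
  while (true) {
    positions.push_back(current);
    // Advance like an odometer: the last leading dim moves fastest.
    int64_t dim = static_cast<int64_t>(leading.size()) - 1;
    while (dim >= 0 && ++current[dim] == leading[dim]) {
      current[dim] = 0;
      --dim;
    }
    if (dim < 0)
      break; // Wrapped past the first dim: every position was visited.
  }
  return positions;
}
```

For a shape of `{2, 4}` this yields the positions `{0}` and `{1}`, matching the two extract/convert/insert steps in the example from the patch's doc comment. Whether those steps run inside the export legalization or in a separate earlier pass is the design question the reviewers raise here.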
LGTM, thanks for the updates!
Few small questions, few small suggestions. Nothing major, just still struggling a bit with the new constraints. A few extra comments should be sufficient.
// Whether the n-th dim of the shape matches the given `size`.
// Negative values index in reverse.
What is `size`? Or did you mean `allowedSizes[n]`?
In general, quite unsure what this does. Could you add an example of what would be accepted and what would be rejected?
I meant:
// Whether the n-th dim of the shape is contained within `allowedSizes`.
// Negative values for `n` index in reverse.
Examples:
IsNthDimSizeIsOneOfPred<0, {2, 3, 4}>
Accepts any shape where the first dim is 2, 3, or 4.
=> This means shapes like: `2x8x9x5`, `4`, `3x1`, `4x?`, etc.
IsNthDimSizeIsOneOfPred<-1, {16}>
Accepts any shape where the last dim is 16.
=> This means shapes like: `2x16`, `16`, `1x2x3x4x16`, etc.
IsNthDimSizeIsOneOfPred<-2, {10, 5}>
Accepts any shape where the second to last dim is 10 or 5.
=> This means shapes like: `1x10x2`, `2x1x4x5x6`, `8x10x?`, etc.
"::llvm::cast<::mlir::ShapedType>($_self).getRank() + " # n,
"" # n)
# "))">]>;

// Whether the shape of a vector matches the given `shape` list.
I wouldn't "mix" `IsVectorOfShape` (which is fairly basic) with the new definitions (which are quite complex).
I've moved it to near the end of the list now
/// Unrolls a conversion to/from equivalent vector types, to allow using a | ||
/// conversion intrinsic that only supports 1-D vector types. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I second Cullen here. It's not obvious to me what the right "pass" would be and where it would live, but we should add a TODO in the comments and in the commit message. Something along the lines of:
Extract the lowering of `convert_to_svbool` into a series of `vector.extract/insert` operations into a dedicated lowering layer/pass before "LegalizeForLLVMExport". The latter should be as simple as a 1-1 mapping.
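For readers unfamiliar with this kind of unrolling, the sketch below shows only the index bookkeeping: it lists, in row-major order, the positions over the leading dimensions at which a 1-D slice would be extracted, converted, and inserted back. It is an illustration in plain C++, not the patch's implementation; `enumerateLeadingPositions` is a hypothetical helper, and the last entry of `shape` stands in for the (scalable) innermost dimension, which is left untouched.

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Enumerates, in row-major order, every index tuple over the leading
// (all but the last) dimensions of `shape`. Each tuple names one 1-D slice
// that an unrolled lowering would extract, convert, and insert back.
std::vector<std::vector<int64_t>>
enumerateLeadingPositions(const std::vector<int64_t> &shape) {
  std::vector<std::vector<int64_t>> positions;
  if (shape.empty())
    return positions;
  const std::size_t leadingRank = shape.size() - 1;
  // A leading dimension of size 0 means there are no slices at all.
  for (std::size_t d = 0; d < leadingRank; ++d)
    if (shape[d] <= 0)
      return positions;
  std::vector<int64_t> index(leadingRank, 0);
  for (;;) {
    positions.push_back(index);
    // Advance like an odometer, starting from the innermost leading dim.
    std::size_t d = leadingRank;
    while (d > 0 && ++index[d - 1] == shape[d - 1]) {
      index[d - 1] = 0;
      --d;
    }
    // d == 0 means every digit wrapped around (or there are no leading dims).
    if (d == 0)
      return positions;
  }
}
```

For a shape such as 2x3x[16] this yields six positions, from (0, 0) to (1, 2), so the 1-D conversion would be applied six times. A 1-D shape yields a single empty position, meaning the conversion is applied once with no extract or insert needed.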
LGTM, thank you for addressing my comments!
This is quite a non-trivial change, great job designing and implementing this!
Currently seeing some UBSAN issues in |
This patch adds a pass that ensures that loads, stores, and allocations of SVE vector types will be legal in the LLVM backend. It does this at the memref level, so this pass must be applied before lowering all the way to LLVM.

This pass currently fixes two issues.

## Loading and storing predicate types

It is only legal to load/store predicate types equal to (or greater than) a full predicate register, which in MLIR is `vector<[16]xi1>`. Smaller predicate types (`vector<[1|2|4|8]xi1>`) must be converted to/from a full predicate type (referred to as a `svbool`) before and after storing and loading respectively. This pass does this by widening allocations and inserting conversion intrinsics.

For example:

```mlir
%alloca = memref.alloca() : memref<vector<[4]xi1>>
%mask = vector.constant_mask [4] : vector<[4]xi1>
memref.store %mask, %alloca[] : memref<vector<[4]xi1>>
%reload = memref.load %alloca[] : memref<vector<[4]xi1>>
```

Becomes:

```mlir
%alloca = memref.alloca() {alignment = 1 : i64} : memref<vector<[16]xi1>>
%mask = vector.constant_mask [4] : vector<[4]xi1>
%svbool = arm_sve.convert_to_svbool %mask : vector<[4]xi1>
memref.store %svbool, %alloca[] : memref<vector<[16]xi1>>
%reload_svbool = memref.load %alloca[] : memref<vector<[16]xi1>>
%reload = arm_sve.convert_from_svbool %reload_svbool : vector<[4]xi1>
```

## Relax alignments for SVE vector allocas

The storage for SVE vector types only needs to have an alignment that matches the element type (for example 4 byte alignment for `f32`s). However, the LLVM backend currently defaults to aligning to `base size x element size` bytes. For non-legal vector types like `vector<[8]xf32>` this results in 8 x 4 = 32-byte alignment, but the backend only supports up to 16-byte alignment for SVE vectors on the stack. Explicitly setting a smaller alignment prevents this issue.

Depends on: #68586 and #68695 (for testing)
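The alignment arithmetic in the commit message above can be checked with a small C++ sketch. The helper names (`defaultAlignment`, `exceedsStackLimit`) are hypothetical and this is not backend code; the 16-byte limit and the `base size x element size` rule are taken from the description.

```cpp
#include <cstdint>

// Upper bound on stack alignment for SVE vectors, as stated in the text.
constexpr uint64_t kMaxSveStackAlignment = 16;

// Default alignment described above: base size (the minimum number of
// elements in the scalable vector) times the element size in bytes.
constexpr uint64_t defaultAlignment(uint64_t baseSize, uint64_t elementBytes) {
  return baseSize * elementBytes;
}

// True if the default alignment is larger than the stack supports, which is
// the case where an explicit, smaller alignment is needed.
constexpr bool exceedsStackLimit(uint64_t baseSize, uint64_t elementBytes) {
  return defaultAlignment(baseSize, elementBytes) > kMaxSveStackAlignment;
}
```

Under these assumptions `vector<[8]xf32>` gives `defaultAlignment(8, 4) == 32`, which exceeds the limit, whereas `vector<[4]xf32>` gives 16 and does not.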
This adds slightly higher-level ops for converting masks between svbool
and SVE predicate types. The main reason to use these over the
intrinsics is these ops support vectors of masks (via unrolling).
E.g.
Or:
Depends on #68418