-
Notifications
You must be signed in to change notification settings - Fork 12.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][ArmSVE] Add convert_to/from_svbool ops #68586
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -37,6 +37,19 @@ def IsFixedVectorTypePred : CPred<[{::llvm::isa<::mlir::VectorType>($_self) && | |
def IsScalableVectorTypePred : CPred<[{::llvm::isa<::mlir::VectorType>($_self) && | ||
::llvm::cast<VectorType>($_self).isScalable()}]>; | ||
|
||
// Whether a type is a scalable VectorType, with a single trailing scalable dimension. | ||
// Examples: | ||
// Valid: | ||
// - vector<[4]xf32>, vector<2x3x[2]xi64>, vector<32x[8]xi32> | ||
// Invalid | ||
// - vector<[4]x8xi32>, vector<[2]x[2]xf64>, vector<2x[8]x4xi32> | ||
def IsVectorTypeWithOnlyTrailingDimScalablePred : And<[ | ||
CPred<"::llvm::isa<::mlir::VectorType>($_self)">, | ||
CPred<"::llvm::cast<::mlir::VectorType>($_self).getRank() > 0">, | ||
CPred<"::llvm::cast<::mlir::VectorType>($_self).getScalableDims().back()">, | ||
CPred<"!llvm::is_contained(::llvm::cast<::mlir::VectorType>($_self).getScalableDims().drop_back(), true)"> | ||
]>; | ||
|
||
// Whether a type is a VectorType and all dimensions are scalable. | ||
def allDimsScalableVectorTypePred : And<[ | ||
IsVectorTypePred, | ||
|
@@ -404,6 +417,15 @@ class ScalableVectorOf<list<Type> allowedTypes> : | |
ShapedContainerType<allowedTypes, IsScalableVectorTypePred, | ||
"scalable vector", "::mlir::VectorType">; | ||
|
||
// Any vector with a single trailing scalable dimension, with an element type in | ||
// the `allowedTypes` list. | ||
// | ||
// Note: This Similar to ScalableVectorOf, with the extra requirement that only | ||
// the trailing dim is scalable. | ||
class VectorWithTrailingDimScalableOf<list<Type> allowedTypes> : | ||
ShapedContainerType<allowedTypes, IsVectorTypeWithOnlyTrailingDimScalablePred, | ||
"trailing scalable vector", "::mlir::VectorType">; | ||
|
||
// Whether the number of elements of a vector is from the given | ||
// `allowedRanks` list | ||
class IsVectorOfRankPred<list<int> allowedRanks> : | ||
|
@@ -481,6 +503,40 @@ class IsScalableVectorOfLengthPred<list<int> allowedLengths> : | |
== }] | ||
# allowedlength>)>]>; | ||
|
||
// Normalizes an index so the indices in both directions have the same value. | ||
// For example, when indexing forwards index 2 is the third element. When | ||
// indexing in reverse the third element is -3. This helper would map both of | ||
// these to the "normalized" index of 3. This makes the bounds checking in | ||
// IsNthDimSizeIsOneOfPred simpler (see first CPred). | ||
class NormalizeIndex<int value> { | ||
int ret = !if(!lt(value, 0), | ||
!sub(0, value) /* -value if negative */, | ||
!add(value, 1) /* value + 1 if positive*/); | ||
} | ||
|
||
// Whether the n-th dim of the shape is contained within `allowedSizes`. | ||
// Negative values for `n` index in reverse. | ||
// | ||
// Examples: | ||
// IsNthDimSizeIsOneOfPred<0, {2, 3, 4}> | ||
// - Accepts any shape where the first dim is 2, 3, or 4. | ||
// * This means shapes like: 2x8x9x5, 4, 3x1, 4x?, etc | ||
// IsNthDimSizeIsOneOfPred<-1, {16}> | ||
// - Accepts any shape where the last dim is 16. | ||
// * This means shapes like 2x16, 16, 1x2x3x4x16, etc | ||
// IsNthDimSizeIsOneOfPred<-2, {10, 5}> | ||
// - Accepts any shape where the second to last dim is 10 or 5. | ||
// * This means shapes like: 1x10x2, 2x1x4x5x6, 8x10x?, etc | ||
class IsNthDimSizeIsOneOfPred<int n, list<int> allowedSizes> | ||
: And<[ | ||
CPred<"::llvm::cast<::mlir::ShapedType>($_self).getRank() >= " # NormalizeIndex<n>.ret>, | ||
CPred<"::llvm::is_contained(ArrayRef<int64_t>({" # !interleave(allowedSizes, ", ") # "}), " | ||
# "::llvm::cast<::mlir::ShapedType>($_self).getDimSize(" | ||
# !if(!lt(n, 0), | ||
"::llvm::cast<::mlir::ShapedType>($_self).getRank() + " # n, | ||
"" # n) | ||
# "))">]>; | ||
|
||
// Whether the shape of a vector matches the given `shape` list. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I wouldn't "mix" There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I've moved it to near the end of the list now |
||
class IsVectorOfShape<list<int> shape> | ||
: CPred<"::llvm::cast<::mlir::VectorType>($_self).getShape() == ArrayRef<int64_t>({" # !interleave(shape, ", ") # "})">; | ||
|
@@ -546,6 +602,24 @@ class ScalableVectorOfRankAndLengthAndType<list<int> allowedRanks, | |
ScalableVectorOfLength<allowedLengths>.summary, | ||
"::mlir::VectorType">; | ||
|
||
// Any ShapedType where the size of the n-th dim is contained in `allowedSizes`. | ||
// Negative values for `n` index in reverse. | ||
class ShapedTypeWithNthDimOfSize<int n, list<int> allowedSizes> : Type< | ||
IsNthDimSizeIsOneOfPred<n, allowedSizes>, | ||
" with dim " # n # " having a size of {" # !interleave(allowedSizes, ", ") # "}", | ||
"::mlir::ShapedType">; | ||
|
||
// Any scalable vector with a single trailing scalable dimensions, where the | ||
// size of the trailing dimension is in `allowedTrailingSizes` list, and the | ||
// type is in the `allowedTypes` list. | ||
class VectorWithTrailingDimScalableOfSizeAndType<list<int> allowedTrailingSizes, | ||
list<Type> allowedTypes> : AllOfType< | ||
[VectorWithTrailingDimScalableOf<allowedTypes>, | ||
ShapedTypeWithNthDimOfSize<-1, allowedTrailingSizes>], | ||
VectorWithTrailingDimScalableOf<allowedTypes>.summary # | ||
ShapedTypeWithNthDimOfSize<-1, allowedTrailingSizes>.summary, | ||
"::mlir::VectorType">; | ||
|
||
def AnyVector : VectorOf<[AnyType]>; | ||
// Temporary vector type clone that allows gradual transition to 0-D vectors. | ||
def AnyVectorOfAnyRank : VectorOfAnyRankOf<[AnyType]>; | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -12,6 +12,8 @@ | |
#include "mlir/Dialect/ArmSVE/Transforms/Transforms.h" | ||
#include "mlir/Dialect/Func/IR/FuncOps.h" | ||
#include "mlir/Dialect/LLVMIR/LLVMDialect.h" | ||
#include "mlir/Dialect/Utils/IndexingUtils.h" | ||
#include "mlir/Dialect/Vector/IR/VectorOps.h" | ||
#include "mlir/IR/BuiltinOps.h" | ||
#include "mlir/IR/PatternMatch.h" | ||
|
||
|
@@ -66,6 +68,77 @@ using ScalableMaskedDivFOpLowering = | |
OneToOneConvertToLLVMPattern<ScalableMaskedDivFOp, | ||
ScalableMaskedDivFIntrOp>; | ||
|
||
namespace { | ||
|
||
/// Unrolls a conversion to/from equivalent vector types, to allow using a | ||
/// conversion intrinsic that only supports 1-D vector types. | ||
Comment on lines
+73
to
+74
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this is doing two things and the op doesn't map 1-1 with intrinsics. Based on previous feedback I've received and general observations, I wonder if the unrolling should be done as a separate transform and this simply maps rank 1 It wouldn't have to be done as part of this patch, but something to consider. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. That's the whole reason these op exists though? If this mapped 1-1 to the intrinsics there'd be no reason for these ops to exist. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yeah I understand that, but it doesn't mean both steps have to be done when lowering to intrinsics. At the Vector dialect level for example operations on higher rank vectors are typically broken down to rank 1 vectors first (VectorToSCF), before lowering to intrinsics. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. That'll be quite a big rework as the ArmSVE dialect is pretty much just a skeleton, so currently has no There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't think There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I second Cullen here. It's not obvious to me what the right "pass" would be and where would it live, but we should add a TODO in the comments and in the commit message. Something along the lines:
|
||
/// | ||
/// Example: | ||
/// ``` | ||
/// %result = arm_sve.convert_to_svbool %source : vector<2x[4]xi1> | ||
/// ``` | ||
/// is rewritten into: | ||
/// ``` | ||
/// %cst = arith.constant dense<false> : vector<2x[16]xi1> | ||
/// %1 = vector.extract %source[0] : vector<[4]xi1> from vector<2x[4]xi1> | ||
/// %2 = "arm_sve.intr.convert.to.svbool"(%1) | ||
/// : (vector<[4]xi1>) -> vector<[16]xi1> | ||
/// %3 = vector.insert %2, %cst[0] : vector<[16]xi1> into vector<2x[16]xi1> | ||
/// %4 = vector.extract %source[1] : vector<[4]xi1> from vector<2x[4]xi1> | ||
/// %5 = "arm_sve.intr.convert.to.svbool"(%4) | ||
/// : (vector<[4]xi1>) -> vector<[16]xi1> | ||
/// %result = vector.insert %5, %3[1] : vector<[16]xi1> into vector<2x[16]xi1> | ||
/// ``` | ||
template <typename Op, typename IntrOp> | ||
struct SvboolConversionOpLowering : public ConvertOpToLLVMPattern<Op> { | ||
using ConvertOpToLLVMPattern<Op>::ConvertOpToLLVMPattern; | ||
|
||
LogicalResult | ||
matchAndRewrite(Op convertOp, typename Op::Adaptor, | ||
ConversionPatternRewriter &rewriter) const override { | ||
auto loc = convertOp.getLoc(); | ||
|
||
auto source = convertOp.getSource(); | ||
VectorType sourceType = source.getType(); | ||
VectorType resultType = convertOp.getResult().getType(); | ||
|
||
Value result = rewriter.create<arith::ConstantOp>( | ||
loc, resultType, rewriter.getZeroAttr(resultType)); | ||
|
||
// We want to iterate over the input vector in steps of the trailing | ||
// dimension. So this creates tile shape where all leading dimensions are 1, | ||
// and the trailing dimension step is the size of the dimension. | ||
SmallVector<int64_t> tileShape(sourceType.getRank(), 1); | ||
tileShape.back() = sourceType.getShape().back(); | ||
|
||
// Iterate over all scalable mask/predicate slices of the source vector. | ||
for (SmallVector<int64_t> index : | ||
StaticTileOffsetRange(sourceType.getShape(), tileShape)) { | ||
auto extractOrInsertPosition = ArrayRef(index).drop_back(); | ||
auto sourceVector = rewriter.create<vector::ExtractOp>( | ||
loc, source, extractOrInsertPosition); | ||
auto convertedType = | ||
VectorType::Builder(llvm::cast<VectorType>(sourceVector.getType())) | ||
.setDim(0, resultType.getShape().back()); | ||
auto convertedVector = | ||
rewriter.create<IntrOp>(loc, TypeRange{convertedType}, sourceVector); | ||
result = rewriter.create<vector::InsertOp>(loc, convertedVector, result, | ||
extractOrInsertPosition); | ||
} | ||
|
||
rewriter.replaceOp(convertOp, result); | ||
return success(); | ||
} | ||
}; | ||
|
||
using ConvertToSvboolOpLowering = | ||
SvboolConversionOpLowering<ConvertToSvboolOp, ConvertToSvboolIntrOp>; | ||
|
||
using ConvertFromSvboolOpLowering = | ||
SvboolConversionOpLowering<ConvertFromSvboolOp, ConvertFromSvboolIntrOp>; | ||
|
||
} // namespace | ||
|
||
/// Populate the given list with patterns that convert from ArmSVE to LLVM. | ||
void mlir::populateArmSVELegalizeForLLVMExportPatterns( | ||
LLVMTypeConverter &converter, RewritePatternSet &patterns) { | ||
|
@@ -88,7 +161,9 @@ void mlir::populateArmSVELegalizeForLLVMExportPatterns( | |
ScalableMaskedMulFOpLowering, | ||
ScalableMaskedSDivIOpLowering, | ||
ScalableMaskedUDivIOpLowering, | ||
ScalableMaskedDivFOpLowering>(converter); | ||
ScalableMaskedDivFOpLowering, | ||
ConvertToSvboolOpLowering, | ||
ConvertFromSvboolOpLowering>(converter); | ||
// clang-format on | ||
} | ||
|
||
|
@@ -107,7 +182,9 @@ void mlir::configureArmSVELegalizeForExportTarget( | |
ScalableMaskedMulFIntrOp, | ||
ScalableMaskedSDivIIntrOp, | ||
ScalableMaskedUDivIIntrOp, | ||
ScalableMaskedDivFIntrOp>(); | ||
ScalableMaskedDivFIntrOp, | ||
ConvertToSvboolIntrOp, | ||
ConvertFromSvboolIntrOp>(); | ||
target.addIllegalOp<SdotOp, | ||
SmmlaOp, | ||
UdotOp, | ||
|
@@ -120,6 +197,8 @@ void mlir::configureArmSVELegalizeForExportTarget( | |
ScalableMaskedMulFOp, | ||
ScalableMaskedSDivIOp, | ||
ScalableMaskedUDivIOp, | ||
ScalableMaskedDivFOp>(); | ||
ScalableMaskedDivFOp, | ||
ConvertToSvboolOp, | ||
ConvertFromSvboolOp>(); | ||
// clang-format on | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It should be possible to test for this diagnostic, right?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is already implicitly tested. This is used to infer the the svbool type (from the result or argument), so there's no textual representation of the op that fails this constraint.