Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mlir][ArmSVE] Add convert_to/from_svbool ops #68586

Merged
merged 6 commits into from
Oct 12, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 84 additions & 0 deletions mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ def ArmSVE_Dialect : Dialect {
This dialect contains the definitions necessary to target specific Arm SVE
scalable vector operations.
}];

let dependentDialects = ["vector::VectorDialect"];
}

//===----------------------------------------------------------------------===//
Expand All @@ -40,6 +42,13 @@ def SVBool : ScalableVectorOfRankAndLengthAndType<
def SVEPredicate : ScalableVectorOfRankAndLengthAndType<
[1], [16, 8, 4, 2, 1], [I1]>;

// Generalizations of SVBool and SVEPredicate to ranks >= 1.
// These are masks with a single trailing scalable dimension.
def SVBoolMask : VectorWithTrailingDimScalableOfSizeAndType<
[16], [I1]>;
def SVEPredicateMask : VectorWithTrailingDimScalableOfSizeAndType<
[16, 8, 4, 2, 1], [I1]>;

//===----------------------------------------------------------------------===//
// ArmSVE op definitions
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -236,6 +245,81 @@ def UmmlaOp : ArmSVE_Op<"ummla",
"$acc `,` $src1 `,` $src2 attr-dict `:` type($src1) `to` type($dst)";
}

class SvboolTypeConstraint<string lhsArg, string rhsArg> : TypesMatchWith<
"expected corresponding svbool type widened to [16]xi1",
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It should be possible to test for this diagnostic, right?

Copy link
Member Author

@MacDue MacDue Oct 10, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is already implicitly tested. This is used to infer the the svbool type (from the result or argument), so there's no textual representation of the op that fails this constraint.

lhsArg, rhsArg,
"VectorType(VectorType::Builder(::llvm::cast<VectorType>($_self)).setDim(::llvm::cast<VectorType>($_self).getRank() - 1, 16))">;
MacDue marked this conversation as resolved.
Show resolved Hide resolved

def ConvertFromSvboolOp : ArmSVE_Op<"convert_from_svbool",
[Pure, SvboolTypeConstraint<"result", "source">]>
{
let summary = "Convert a svbool type to a SVE predicate type";
let description = [{
Converts svbool types (`vector<[16]xi1>` or vectors of that type, e.g.
`vector<2x3x[16]xi1>`) to SVE predicate types. Note: Only the trailing
dimension can be scalable.

Example 1: Convert a 1-D svbool mask to a SVE predicate.
```mlir
%source = vector.load %memref[%c0] : memref<?xi1>, vector<[16]xi1>
%result = arm_sve.convert_from_svbool %source : vector<[4]xi1>
```

Example 2: Convert a 2-D svbool mask to a mask of SVE predicates.
```mlir
%source = vector.load %memref[%c0, %c0] : memref<2x?xi1>, vector<2x[16]xi1>
%result = arm_sve.convert_from_svbool %source : vector<2x[8]xi1>
```

---

A `svbool` is the smallest SVE predicate type that has a in-memory
representation (and maps to a full predicate register). In MLIR `svbool` is
represented as `vector<[16]xi1>`. Smaller SVE predicate types
(`vector<[1|2|4|8]xi1>`) must be stored as a `svbool` then converted back to
the original predicate type after loading.
}];
let arguments = (ins SVBoolMask:$source);
let results = (outs SVEPredicateMask:$result);
let assemblyFormat = "$source attr-dict `:` type($result)";
}

def ConvertToSvboolOp : ArmSVE_Op<"convert_to_svbool",
[Pure, SvboolTypeConstraint<"source", "result">]>
{
let summary = "Convert a SVE predicate type to a svbool type";
let description = [{
Converts SVE predicate types (or vectors of predicate types, e.g.
`vector<4x[4]xi1>`) to svbool types. Note: Only the trailing dimension can
be scalable.

Example 1: Convert a 1-D SVE predicate to a svbool mask.
```mlir
%source = vector.create_mask %dim_size : vector<[4]xi1>
%result = arm_sve.convert_to_svbool %source : vector<[4]xi1>
// => Results in vector<[16]xi1>
```

Example 2: Convert a 2-D mask of SVE predicates to a svbool mask.
```mlir
%source = vector.create_mask %c2, %dim_size : vector<2x[2]xi1>
%result = arm_sve.convert_to_svbool %source : vector<2x[2]xi1>
// => Results in vector<2x[16]xi1>
```

---

A `svbool` is the smallest SVE predicate type that has a in-memory
representation (and maps to a full predicate register). In MLIR `svbool` is
represented as `vector<[16]xi1>`. Smaller SVE predicate types
(`vector<[1|2|4|8]xi1>`) must be converted to a `svbool` before they can be
stored.
}];
let arguments = (ins SVEPredicateMask:$source);
let results = (outs SVBoolMask:$result);
let assemblyFormat = "$source attr-dict `:` type($source)";
}

def ScalableMaskedAddIOp : ScalableMaskedIOp<"masked.addi", "addition",
[Commutative]>;

Expand Down
74 changes: 74 additions & 0 deletions mlir/include/mlir/IR/CommonTypeConstraints.td
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,19 @@ def IsFixedVectorTypePred : CPred<[{::llvm::isa<::mlir::VectorType>($_self) &&
def IsScalableVectorTypePred : CPred<[{::llvm::isa<::mlir::VectorType>($_self) &&
::llvm::cast<VectorType>($_self).isScalable()}]>;

// Whether a type is a scalable VectorType, with a single trailing scalable dimension.
// Examples:
// Valid:
// - vector<[4]xf32>, vector<2x3x[2]xi64>, vector<32x[8]xi32>
// Invalid
// - vector<[4]x8xi32>, vector<[2]x[2]xf64>, vector<2x[8]x4xi32>
def IsVectorTypeWithOnlyTrailingDimScalablePred : And<[
CPred<"::llvm::isa<::mlir::VectorType>($_self)">,
CPred<"::llvm::cast<::mlir::VectorType>($_self).getRank() > 0">,
CPred<"::llvm::cast<::mlir::VectorType>($_self).getScalableDims().back()">,
CPred<"!llvm::is_contained(::llvm::cast<::mlir::VectorType>($_self).getScalableDims().drop_back(), true)">
]>;

// Whether a type is a VectorType and all dimensions are scalable.
def allDimsScalableVectorTypePred : And<[
IsVectorTypePred,
Expand Down Expand Up @@ -404,6 +417,15 @@ class ScalableVectorOf<list<Type> allowedTypes> :
ShapedContainerType<allowedTypes, IsScalableVectorTypePred,
"scalable vector", "::mlir::VectorType">;

// Any vector with a single trailing scalable dimension, with an element type in
// the `allowedTypes` list.
//
// Note: This Similar to ScalableVectorOf, with the extra requirement that only
// the trailing dim is scalable.
class VectorWithTrailingDimScalableOf<list<Type> allowedTypes> :
ShapedContainerType<allowedTypes, IsVectorTypeWithOnlyTrailingDimScalablePred,
"trailing scalable vector", "::mlir::VectorType">;

// Whether the number of elements of a vector is from the given
// `allowedRanks` list
class IsVectorOfRankPred<list<int> allowedRanks> :
Expand Down Expand Up @@ -481,6 +503,40 @@ class IsScalableVectorOfLengthPred<list<int> allowedLengths> :
== }]
# allowedlength>)>]>;

// Normalizes an index so the indices in both directions have the same value.
// For example, when indexing forwards index 2 is the third element. When
// indexing in reverse the third element is -3. This helper would map both of
// these to the "normalized" index of 3. This makes the bounds checking in
// IsNthDimSizeIsOneOfPred simpler (see first CPred).
class NormalizeIndex<int value> {
int ret = !if(!lt(value, 0),
!sub(0, value) /* -value if negative */,
!add(value, 1) /* value + 1 if positive*/);
}

// Whether the n-th dim of the shape is contained within `allowedSizes`.
// Negative values for `n` index in reverse.
//
// Examples:
// IsNthDimSizeIsOneOfPred<0, {2, 3, 4}>
// - Accepts any shape where the first dim is 2, 3, or 4.
// * This means shapes like: 2x8x9x5, 4, 3x1, 4x?, etc
// IsNthDimSizeIsOneOfPred<-1, {16}>
// - Accepts any shape where the last dim is 16.
// * This means shapes like 2x16, 16, 1x2x3x4x16, etc
// IsNthDimSizeIsOneOfPred<-2, {10, 5}>
// - Accepts any shape where the second to last dim is 10 or 5.
// * This means shapes like: 1x10x2, 2x1x4x5x6, 8x10x?, etc
class IsNthDimSizeIsOneOfPred<int n, list<int> allowedSizes>
: And<[
CPred<"::llvm::cast<::mlir::ShapedType>($_self).getRank() >= " # NormalizeIndex<n>.ret>,
CPred<"::llvm::is_contained(ArrayRef<int64_t>({" # !interleave(allowedSizes, ", ") # "}), "
# "::llvm::cast<::mlir::ShapedType>($_self).getDimSize("
# !if(!lt(n, 0),
"::llvm::cast<::mlir::ShapedType>($_self).getRank() + " # n,
"" # n)
# "))">]>;

// Whether the shape of a vector matches the given `shape` list.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wouldn't "mix" IsVectorOfShape (which is fairly basic) with the new definitions (which are quite complex)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've moved it to near the end of the list now

class IsVectorOfShape<list<int> shape>
: CPred<"::llvm::cast<::mlir::VectorType>($_self).getShape() == ArrayRef<int64_t>({" # !interleave(shape, ", ") # "})">;
Expand Down Expand Up @@ -546,6 +602,24 @@ class ScalableVectorOfRankAndLengthAndType<list<int> allowedRanks,
ScalableVectorOfLength<allowedLengths>.summary,
"::mlir::VectorType">;

// Any ShapedType where the size of the n-th dim is contained in `allowedSizes`.
// Negative values for `n` index in reverse.
class ShapedTypeWithNthDimOfSize<int n, list<int> allowedSizes> : Type<
IsNthDimSizeIsOneOfPred<n, allowedSizes>,
" with dim " # n # " having a size of {" # !interleave(allowedSizes, ", ") # "}",
"::mlir::ShapedType">;

// Any scalable vector with a single trailing scalable dimensions, where the
// size of the trailing dimension is in `allowedTrailingSizes` list, and the
// type is in the `allowedTypes` list.
class VectorWithTrailingDimScalableOfSizeAndType<list<int> allowedTrailingSizes,
list<Type> allowedTypes> : AllOfType<
[VectorWithTrailingDimScalableOf<allowedTypes>,
ShapedTypeWithNthDimOfSize<-1, allowedTrailingSizes>],
VectorWithTrailingDimScalableOf<allowedTypes>.summary #
ShapedTypeWithNthDimOfSize<-1, allowedTrailingSizes>.summary,
"::mlir::VectorType">;

def AnyVector : VectorOf<[AnyType]>;
// Temporary vector type clone that allows gradual transition to 0-D vectors.
def AnyVectorOfAnyRank : VectorOfAnyRankOf<[AnyType]>;
Expand Down
1 change: 1 addition & 0 deletions mlir/lib/Dialect/ArmSVE/IR/ArmSVEDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

#include "mlir/Dialect/ArmSVE/IR/ArmSVEDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/OpImplementation.h"
Expand Down
1 change: 1 addition & 0 deletions mlir/lib/Dialect/ArmSVE/IR/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,5 +10,6 @@ add_mlir_dialect_library(MLIRArmSVEDialect
LINK_LIBS PUBLIC
MLIRIR
MLIRLLVMDialect
MLIRVectorDialect
MacDue marked this conversation as resolved.
Show resolved Hide resolved
MLIRSideEffectInterfaces
)
1 change: 1 addition & 0 deletions mlir/lib/Dialect/ArmSVE/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ add_mlir_dialect_library(MLIRArmSVETransforms
LINK_LIBS PUBLIC
MLIRArmSVEDialect
MLIRFuncDialect
MLIRVectorDialect
MLIRIR
MLIRLLVMCommonConversion
MLIRLLVMDialect
Expand Down
85 changes: 82 additions & 3 deletions mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
#include "mlir/Dialect/ArmSVE/Transforms/Transforms.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/PatternMatch.h"

Expand Down Expand Up @@ -66,6 +68,77 @@ using ScalableMaskedDivFOpLowering =
OneToOneConvertToLLVMPattern<ScalableMaskedDivFOp,
ScalableMaskedDivFIntrOp>;

namespace {

/// Unrolls a conversion to/from equivalent vector types, to allow using a
/// conversion intrinsic that only supports 1-D vector types.
Comment on lines +73 to +74
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is doing two things and the op doesn't map 1-1 with intrinsics. Based on previous feedback I've received and general observations, I wonder if the unrolling should be done as a separate transform and this simply maps rank 1 arm_sve.convert_to_svbool / arm_sve.convert_from_svbool to intrinsics?

It wouldn't have to be done as part of this patch, but something to consider.

Copy link
Member Author

@MacDue MacDue Oct 11, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's the whole reason these op exists though? If this mapped 1-1 to the intrinsics there'd be no reason for these ops to exist.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah I understand that, but it doesn't mean both steps have to be done when lowering to intrinsics. At the Vector dialect level for example operations on higher rank vectors are typically broken down to rank 1 vectors first (VectorToSCF), before lowering to intrinsics.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That'll be quite a big rework as the ArmSVE dialect is pretty much just a skeleton, so currently has no -arm-sve-to-scf (or similar) passes right now.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think -arm-sve-to-scf would make sense since this isn't emitting SCF ops, but I get your point, it would need some to live somewhere that doesn't exist today. Like I said, I'm happy with this regardless.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I second Cullen here. It's not obvious to me what the right "pass" would be and where would it live, but we should add a TODO in the comments and in the commit message. Something along the lines:

Extract the lowering of convert_to_svbool into a series of vector.extract/insert operations into a dedicated lowering layer/pass pre "LegalizeForLLVMExport". The latter should be as simple as 1-1 mapping.

///
/// Example:
/// ```
/// %result = arm_sve.convert_to_svbool %source : vector<2x[4]xi1>
/// ```
/// is rewritten into:
/// ```
/// %cst = arith.constant dense<false> : vector<2x[16]xi1>
/// %1 = vector.extract %source[0] : vector<[4]xi1> from vector<2x[4]xi1>
/// %2 = "arm_sve.intr.convert.to.svbool"(%1)
/// : (vector<[4]xi1>) -> vector<[16]xi1>
/// %3 = vector.insert %2, %cst[0] : vector<[16]xi1> into vector<2x[16]xi1>
/// %4 = vector.extract %source[1] : vector<[4]xi1> from vector<2x[4]xi1>
/// %5 = "arm_sve.intr.convert.to.svbool"(%4)
/// : (vector<[4]xi1>) -> vector<[16]xi1>
/// %result = vector.insert %5, %3[1] : vector<[16]xi1> into vector<2x[16]xi1>
/// ```
template <typename Op, typename IntrOp>
struct SvboolConversionOpLowering : public ConvertOpToLLVMPattern<Op> {
using ConvertOpToLLVMPattern<Op>::ConvertOpToLLVMPattern;

LogicalResult
matchAndRewrite(Op convertOp, typename Op::Adaptor,
ConversionPatternRewriter &rewriter) const override {
auto loc = convertOp.getLoc();

auto source = convertOp.getSource();
VectorType sourceType = source.getType();
VectorType resultType = convertOp.getResult().getType();

Value result = rewriter.create<arith::ConstantOp>(
loc, resultType, rewriter.getZeroAttr(resultType));

// We want to iterate over the input vector in steps of the trailing
// dimension. So this creates tile shape where all leading dimensions are 1,
// and the trailing dimension step is the size of the dimension.
SmallVector<int64_t> tileShape(sourceType.getRank(), 1);
tileShape.back() = sourceType.getShape().back();

// Iterate over all scalable mask/predicate slices of the source vector.
for (SmallVector<int64_t> index :
StaticTileOffsetRange(sourceType.getShape(), tileShape)) {
auto extractOrInsertPosition = ArrayRef(index).drop_back();
auto sourceVector = rewriter.create<vector::ExtractOp>(
loc, source, extractOrInsertPosition);
auto convertedType =
VectorType::Builder(llvm::cast<VectorType>(sourceVector.getType()))
.setDim(0, resultType.getShape().back());
auto convertedVector =
rewriter.create<IntrOp>(loc, TypeRange{convertedType}, sourceVector);
result = rewriter.create<vector::InsertOp>(loc, convertedVector, result,
extractOrInsertPosition);
}

rewriter.replaceOp(convertOp, result);
return success();
}
};

using ConvertToSvboolOpLowering =
SvboolConversionOpLowering<ConvertToSvboolOp, ConvertToSvboolIntrOp>;

using ConvertFromSvboolOpLowering =
SvboolConversionOpLowering<ConvertFromSvboolOp, ConvertFromSvboolIntrOp>;

} // namespace

/// Populate the given list with patterns that convert from ArmSVE to LLVM.
void mlir::populateArmSVELegalizeForLLVMExportPatterns(
LLVMTypeConverter &converter, RewritePatternSet &patterns) {
Expand All @@ -88,7 +161,9 @@ void mlir::populateArmSVELegalizeForLLVMExportPatterns(
ScalableMaskedMulFOpLowering,
ScalableMaskedSDivIOpLowering,
ScalableMaskedUDivIOpLowering,
ScalableMaskedDivFOpLowering>(converter);
ScalableMaskedDivFOpLowering,
ConvertToSvboolOpLowering,
ConvertFromSvboolOpLowering>(converter);
// clang-format on
}

Expand All @@ -107,7 +182,9 @@ void mlir::configureArmSVELegalizeForExportTarget(
ScalableMaskedMulFIntrOp,
ScalableMaskedSDivIIntrOp,
ScalableMaskedUDivIIntrOp,
ScalableMaskedDivFIntrOp>();
ScalableMaskedDivFIntrOp,
ConvertToSvboolIntrOp,
ConvertFromSvboolIntrOp>();
target.addIllegalOp<SdotOp,
SmmlaOp,
UdotOp,
Expand All @@ -120,6 +197,8 @@ void mlir::configureArmSVELegalizeForExportTarget(
ScalableMaskedMulFOp,
ScalableMaskedSDivIOp,
ScalableMaskedUDivIOp,
ScalableMaskedDivFOp>();
ScalableMaskedDivFOp,
ConvertToSvboolOp,
ConvertFromSvboolOp>();
// clang-format on
}
Loading