Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mlir][ArmSVE] Add convert_to/from_svbool ops #68586

Merged
merged 6 commits into from
Oct 12, 2023

Conversation

MacDue
Copy link
Member

@MacDue MacDue commented Oct 9, 2023

This adds slightly higher-level ops for converting masks between svbool
and SVE predicate types. The main reason to use these over the
intrinsics is these ops support vectors of masks (via unrolling).

E.g.

// Convert a svbool mask to a mask of SVE predicates:
%svbool = vector.load %memref[%c0, %c0]
                       : memref<2x?xi1>, vector<2x[16]xi1>
%mask = arm_sve.convert_from_svbool %svbool : vector<2x[8]xi1>
// => Results in vector<2x[8]xi1>

Or:

// Convert a mask of SVE predicates to a svbool mask:
%mask = vector.create_mask %c2, %dim_size : vector<2x[2]xi1>
%svbool = arm_sve.convert_to_svbool %mask : vector<2x[2]xi1>
// => Results in vector<2x[16]xi1>

Depends on #68418

Copy link
Contributor

@banach-space banach-space left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Mostly makes sense, though I'd like to take another look at the new tablegen defs when I am more awake :)

mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir Outdated Show resolved Hide resolved
mlir/lib/Dialect/ArmSVE/IR/CMakeLists.txt Show resolved Hide resolved
mlir/include/mlir/IR/CommonTypeConstraints.td Outdated Show resolved Hide resolved
mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td Outdated Show resolved Hide resolved
mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td Outdated Show resolved Hide resolved
mlir/include/mlir/IR/CommonTypeConstraints.td Outdated Show resolved Hide resolved
// Negative values index in reverse.
class ShapedTypeWithNthDimOfSize<int n, list<int> allowedSizes> : Type<
IsNthDimSizeIsOneOfPred<n, allowedSizes>,
" with dim " # n # " having a size of {" # !interleave(allowedSizes, ", ") # "}",
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should this subtract 1 from n if n > 0?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think so, the dims are indexed from 1 (so both reverse and forward indexing is symmetrical).

This adds slightly higher-level ops for converting masks between svbool
and SVE predicate types. The main reason to use these over the
intrinsics is these ops support vectors of masks (via unrolling).

E.g.

```
// Convert a svbool mask to a mask of SVE predicates:
%svbool = vector.load %memref[%c0, %c0]
                       : memref<2x?xi1>, vector<2x[16]xi1>
%mask = arm_sve.convert_from_svbool %svbool : vector<2x[8]xi1>
// => Results in vector<2x[8]xi1>
```
Or:
```
// Convert a mask of SVE predicates to a svbool mask:
%mask = vector.create_mask %c2, %dim_size : vector<2x[2]xi1>
%svbool = arm_sve.convert_to_svbool %mask : vector<2x[2]xi1>
// => Results in vector<2x[16]xi1>
```
Copy link
Contributor

@banach-space banach-space left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK, have gone through a few more definitions. Will do more later.

mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td Outdated Show resolved Hide resolved
@@ -236,6 +243,66 @@ def UmmlaOp : ArmSVE_Op<"ummla",
"$acc `,` $src1 `,` $src2 attr-dict `:` type($src1) `to` type($dst)";
}


class SvboolTypeContraint<string lhsArg, string rhsArg> : TypesMatchWith<
"expected corresponding svbool type widened to [16]xi1",
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It should be possible to test for this diagnostic, right?

Copy link
Member Author

@MacDue MacDue Oct 10, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is already implicitly tested. This is used to infer the the svbool type (from the result or argument), so there's no textual representation of the op that fails this constraint.

mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td Show resolved Hide resolved
Comment on lines 256 to 279
let description = [{
Converts svbool types (`vector<[16]xi1>` or vectors of that type, e.g.
`vector<2x3x[16]xi1>`) to SVE predicate types. Note: Only the trailing
dimension can be scalable.

Example 1: Convert a 1-D svbool mask to a SVE predicate.
```mlir
%svbool = vector.load %memref[%c0] : memref<?xi1>, vector<[16]xi1>
%mask = arm_sve.convert_from_svbool %svbool : vector<[4]xi1>
```

Example 2: Convert a 2-D svbool mask to a mask of SVE predicates.
```mlir
%svbool = vector.load %memref[%c0, %c0] : memref<2x?xi1>, vector<2x[16]xi1>
%mask = arm_sve.convert_from_svbool %svbool : vector<2x[8]xi1>
```
}];
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How would somebody reading this file know what svbool and "SVE predicate" types are? It would be good to define both. Note that also input/output names are a bit inconsistent:

  let arguments = (ins SVBoolMask:$source);
  let results = (outs SVEMask:$result);

Given how subtle this is, I'd suggest either expanding this definition, adding links or adding some definition at the top of this file. Or perhaps these types are defined somewhere? (as in, other than as "TableGen" definitions).

Also, "a SVE" -> "an SVE"?

Copy link
Member Author

@MacDue MacDue Oct 10, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also, "a SVE" -> "an SVE"?

No idea! I think my brain hardcodes this for a few words, then does rand() for the rest 😛

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I can't tell if "a" or "an" is correct here... Grammarly says "a" though.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've added a little explainer at the bottom of each op, the online docs should also include auto-generated descriptions of the vector types.

@@ -37,6 +37,12 @@ def IsFixedVectorTypePred : CPred<[{::llvm::isa<::mlir::VectorType>($_self) &&
def IsScalableVectorTypePred : CPred<[{::llvm::isa<::mlir::VectorType>($_self) &&
::llvm::cast<VectorType>($_self).isScalable()}]>;

// Whether a type is a scalable VectorType, with a single trailing scalable dimension.
def IsTrailingScalableVectorTypePred : And<[CPred<"::llvm::isa<::mlir::VectorType>($_self)">,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't this inherit from IsScalableVectorTypePred? Just looking at the comments:

// Whether a type is a scalable VectorType.

vs

// Whether a type is a scalable VectorType, with a single trailing scalable dimension.

So it sounds like IsTrailingScalableVectorTypePred is a refined version of IsScalableVectorTypePred? In fact, given how complex this is - wouldn't it make sense to simply have a dedicated predicate to check that only the trailing dim is scalable? I am also thinking about the name of this predicate and IMHO it doesn't reveal what the predicate represents. I would try something like:

  • IsOnlyTrailingDImScalablePred

That would be shorter, but would capture the key aspect of this pred. Also, we know that only vectors can be "scalable", so can skip "scalable" from name. I would, perhaps, also add an example of what would be valid and what would not under this predicate.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You don't actually get anything from IsScalableVectorTypePred. That just checks if you have a scalable dim anywhere, so you still have to check that the trailing dim is scalable, and that all the other dims are not.

It's a refined version of IsVectorTypePred, but as that'll remove getRank() > 0 at some point in the future, that can't be depended on.

@MacDue MacDue force-pushed the arm_sve_add_svbool_ops branch from ee738b9 to 9c656ca Compare October 10, 2023 14:41
@MacDue MacDue marked this pull request as ready for review October 10, 2023 14:41
@llvmbot
Copy link
Member

llvmbot commented Oct 10, 2023

@llvm/pr-subscribers-mlir
@llvm/pr-subscribers-mlir-ods
@llvm/pr-subscribers-mlir-core

@llvm/pr-subscribers-mlir-sve

Author: Benjamin Maxwell (MacDue)

Changes

This adds slightly higher-level ops for converting masks between svbool
and SVE predicate types. The main reason to use these over the
intrinsics is these ops support vectors of masks (via unrolling).

E.g.

// Convert a svbool mask to a mask of SVE predicates:
%svbool = vector.load %memref[%c0, %c0]
                       : memref&lt;2x?xi1&gt;, vector&lt;2x[16]xi1&gt;
%mask = arm_sve.convert_from_svbool %svbool : vector&lt;2x[8]xi1&gt;
// =&gt; Results in vector&lt;2x[8]xi1&gt;

Or:

// Convert a mask of SVE predicates to a svbool mask:
%mask = vector.create_mask %c2, %dim_size : vector&lt;2x[2]xi1&gt;
%svbool = arm_sve.convert_to_svbool %mask : vector&lt;2x[2]xi1&gt;
// =&gt; Results in vector&lt;2x[16]xi1&gt;

Depends on #68418


Patch is 26.88 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/68586.diff

9 Files Affected:

  • (modified) mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td (+82)
  • (modified) mlir/include/mlir/IR/CommonTypeConstraints.td (+52)
  • (modified) mlir/lib/Dialect/ArmSVE/IR/ArmSVEDialect.cpp (+1)
  • (modified) mlir/lib/Dialect/ArmSVE/IR/CMakeLists.txt (+1)
  • (modified) mlir/lib/Dialect/ArmSVE/Transforms/CMakeLists.txt (+1)
  • (modified) mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp (+82-3)
  • (added) mlir/test/Dialect/ArmSVE/invalid.mlir (+51)
  • (modified) mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir (+72-1)
  • (modified) mlir/test/Dialect/ArmSVE/roundtrip.mlir (+48-1)
diff --git a/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td b/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td
index d4294b4dd9fd4e8..cae87b764fc67dd 100644
--- a/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td
+++ b/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td
@@ -28,6 +28,8 @@ def ArmSVE_Dialect : Dialect {
     This dialect contains the definitions necessary to target specific Arm SVE
     scalable vector operations.
   }];
+
+  let dependentDialects = ["vector::VectorDialect"];
 }
 
 //===----------------------------------------------------------------------===//
@@ -40,6 +42,11 @@ def SVBool : ScalableVectorOfRankAndLengthAndType<
 def SVEPredicate : ScalableVectorOfRankAndLengthAndType<
   [1], [16, 8, 4, 2, 1], [I1]>;
 
+// Generalizations of SVBool and SVEPredicate to ranks >= 1.
+// These are masks with a single trailing scalable dimension.
+def SVBoolMask : TrailingScalableVectorOfSizeAndType<[16], [I1]>;
+def SVEPredicateMask : TrailingScalableVectorOfSizeAndType<[16, 8, 4, 2, 1], [I1]>;
+
 //===----------------------------------------------------------------------===//
 // ArmSVE op definitions
 //===----------------------------------------------------------------------===//
@@ -236,6 +243,81 @@ def UmmlaOp : ArmSVE_Op<"ummla",
     "$acc `,` $src1 `,` $src2 attr-dict `:` type($src1) `to` type($dst)";
 }
 
+class SvboolTypeConstraint<string lhsArg, string rhsArg> : TypesMatchWith<
+      "expected corresponding svbool type widened to [16]xi1",
+      lhsArg, rhsArg,
+      "VectorType(VectorType::Builder(::llvm::cast<VectorType>($_self)).setDim(::llvm::cast<VectorType>($_self).getRank() - 1, 16))">;
+
+def ConvertFromSvboolOp : ArmSVE_Op<"convert_from_svbool",
+                            [Pure, SvboolTypeConstraint<"result", "source">]>
+{
+  let summary = "Convert a svbool type to a SVE predicate type";
+  let description = [{
+    Converts svbool types (`vector<[16]xi1>` or vectors of that type, e.g.
+    `vector<2x3x[16]xi1>`) to SVE predicate types. Note: Only the trailing
+    dimension can be scalable.
+
+    Example 1: Convert a 1-D svbool mask to a SVE predicate.
+    ```mlir
+    %source = vector.load %memref[%c0] : memref<?xi1>, vector<[16]xi1>
+    %result = arm_sve.convert_from_svbool %source : vector<[4]xi1>
+    ```
+
+    Example 2: Convert a 2-D svbool mask to a mask of SVE predicates.
+    ```mlir
+    %source = vector.load %memref[%c0, %c0] : memref<2x?xi1>, vector<2x[16]xi1>
+    %result = arm_sve.convert_from_svbool %source : vector<2x[8]xi1>
+    ```
+
+    ---
+
+    A `svbool` is the smallest SVE predicate type that has a in-memory
+    representation (and maps to a full predicate register). In MLIR `svbool` is
+    represented as `vector<[16]xi1>`. Smaller SVE predicate types
+    (`vector<[1|2|4|8]xi1>`) must be stored as `svbool` then converted back to
+    a predicate after loading.
+  }];
+  let arguments = (ins SVBoolMask:$source);
+  let results = (outs SVEPredicateMask:$result);
+  let assemblyFormat = "$source attr-dict `:` type($result)";
+}
+
+def ConvertToSvboolOp : ArmSVE_Op<"convert_to_svbool",
+                            [Pure, SvboolTypeConstraint<"source", "result">]>
+{
+  let summary = "Convert a SVE predicate type to a svbool type";
+  let description = [{
+    Converts SVE predicate types (or vectors of predicate types, e.g.
+    `vector<4x[4]xi1>`) to svbool types. Note: Only the trailing dimension can
+    be scalable.
+
+    Example 1: Convert a 1-D SVE predicate to a svbool mask.
+    ```mlir
+    %source = vector.create_mask %dim_size : vector<[4]xi1>
+    %result = arm_sve.convert_to_svbool %source : vector<[4]xi1>
+    // => Results in vector<[16]xi1>
+    ```
+
+    Example 2: Convert a 2-D mask of SVE predicates to a svbool mask.
+    ```mlir
+    %source = vector.create_mask %c2, %dim_size : vector<2x[2]xi1>
+    %result = arm_sve.convert_to_svbool %source : vector<2x[2]xi1>
+    // => Results in vector<2x[16]xi1>
+    ```
+
+    ---
+
+    A `svbool` is the smallest SVE predicate type that has a in-memory
+    representation (and maps to a full predicate register). In MLIR `svbool` is
+    represented as `vector<[16]xi1>`. Smaller SVE predicate types
+    (`vector<[1|2|4|8]xi1>`) must be converted to a `svbool` before they can be
+    stored.
+  }];
+  let arguments = (ins SVEPredicateMask:$source);
+  let results = (outs SVBoolMask:$result);
+  let assemblyFormat = "$source attr-dict `:` type($source)";
+}
+
 def ScalableMaskedAddIOp : ScalableMaskedIOp<"masked.addi", "addition",
                                              [Commutative]>;
 
diff --git a/mlir/include/mlir/IR/CommonTypeConstraints.td b/mlir/include/mlir/IR/CommonTypeConstraints.td
index 4fc14e30b8a10d0..a7970e59de8c27e 100644
--- a/mlir/include/mlir/IR/CommonTypeConstraints.td
+++ b/mlir/include/mlir/IR/CommonTypeConstraints.td
@@ -37,6 +37,19 @@ def IsFixedVectorTypePred : CPred<[{::llvm::isa<::mlir::VectorType>($_self) &&
 def IsScalableVectorTypePred : CPred<[{::llvm::isa<::mlir::VectorType>($_self) &&
                                    ::llvm::cast<VectorType>($_self).isScalable()}]>;
 
+// Whether a type is a scalable VectorType, with a single trailing scalable dimension.
+// Examples:
+// Valid:
+//   - vector<[4]xf32>, vector<2x3x[2]xi64>, vector<32x[8]xi32>
+// Invalid
+//   - vector<[4]x8xi32>, vector<[2]x[2]xf64>, vector<2x[8]x4xi32>
+def IsOnlyTrailingDimScalablePred : And<[
+  CPred<"::llvm::isa<::mlir::VectorType>($_self)">,
+  CPred<"::llvm::cast<::mlir::VectorType>($_self).getRank() > 0">,
+  CPred<"::llvm::cast<::mlir::VectorType>($_self).getScalableDims().back()">,
+  CPred<"!llvm::is_contained(::llvm::cast<::mlir::VectorType>($_self).getScalableDims().drop_back(), true)">
+]>;
+
 // Whether a type is a VectorType and all dimensions are scalable.
 def allDimsScalableVectorTypePred : And<[
   IsVectorTypePred,
@@ -404,6 +417,12 @@ class ScalableVectorOf<list<Type> allowedTypes> :
   ShapedContainerType<allowedTypes, IsScalableVectorTypePred,
           "scalable vector", "::mlir::VectorType">;
 
+// Any vector with a single trailing scalable dimension, with an element type in
+// the `allowedTypes` list.
+class TrailingScalableVectorOf<list<Type> allowedTypes> :
+  ShapedContainerType<allowedTypes, IsOnlyTrailingDimScalablePred,
+          "trailing scalable vector", "::mlir::VectorType">;
+
 // Whether the number of elements of a vector is from the given
 // `allowedRanks` list
 class IsVectorOfRankPred<list<int> allowedRanks> :
@@ -481,10 +500,32 @@ class IsScalableVectorOfLengthPred<list<int> allowedLengths> :
                            == }]
                          # allowedlength>)>]>;
 
+class abs<int value> {
+  int ret = !if(!lt(value, 0), !sub(0, value), value);
+}
+
+// Whether the n-th (starting from 1) dim of the shape matches the given `size`.
+// Negative values index in reverse.
+class IsNthDimSizeIsOneOfPred<int n, list<int> allowedSizes>
+  : And<[CPred<"::llvm::cast<::mlir::ShapedType>($_self).getRank() >= " # abs<n>.ret>,
+        CPred<"::llvm::is_contained(ArrayRef<int64_t>({" # !interleave(allowedSizes, ", ") # "}), "
+          # "::llvm::cast<::mlir::ShapedType>($_self).getDimSize("
+          #   !if(!lt(n, 0),
+                "::llvm::cast<::mlir::ShapedType>($_self).getRank() + " # n,
+                "" # !sub(n, 1))
+          # "))">]>;
+
 // Whether the shape of a vector matches the given `shape` list.
 class IsVectorOfShape<list<int> shape>
   : CPred<"::llvm::cast<::mlir::VectorType>($_self).getShape() == ArrayRef<int64_t>({" # !interleave(shape, ", ") # "})">;
 
+// Any ShapedType where the size of the n-th dim is contained in `sizes`.
+// Negative values index in reverse.
+class ShapedTypeWithNthDimOfSize<int n, list<int> allowedSizes> : Type<
+  IsNthDimSizeIsOneOfPred<n, allowedSizes>,
+  " with dim " # n # " having a size of {" # !interleave(allowedSizes, ", ") # "}",
+  "::mlir::ShapedType">;
+
 // Any vector where the number of elements is from the given
 // `allowedLengths` list
 class VectorOfLength<list<int> allowedLengths> : Type<
@@ -546,6 +587,17 @@ class ScalableVectorOfRankAndLengthAndType<list<int> allowedRanks,
   ScalableVectorOfLength<allowedLengths>.summary,
   "::mlir::VectorType">;
 
+// Any scalable vector with a single trailing scalable dimensions, where the
+// size of the trailing dimension is in `allowedTrailingSizes` list, and the
+// type is in the `allowedTypes` list.
+class TrailingScalableVectorOfSizeAndType<list<int> allowedTrailingSizes,
+                                           list<Type> allowedTypes> : AllOfType<
+  [TrailingScalableVectorOf<allowedTypes>,
+   ShapedTypeWithNthDimOfSize<-1, allowedTrailingSizes>],
+   TrailingScalableVectorOf<allowedTypes>.summary #
+   ShapedTypeWithNthDimOfSize<-1, allowedTrailingSizes>.summary,
+  "::mlir::VectorType">;
+
 def AnyVector : VectorOf<[AnyType]>;
 // Temporary vector type clone that allows gradual transition to 0-D vectors.
 def AnyVectorOfAnyRank : VectorOfAnyRankOf<[AnyType]>;
diff --git a/mlir/lib/Dialect/ArmSVE/IR/ArmSVEDialect.cpp b/mlir/lib/Dialect/ArmSVE/IR/ArmSVEDialect.cpp
index b7f1020deba1e40..594c9b4c270f218 100644
--- a/mlir/lib/Dialect/ArmSVE/IR/ArmSVEDialect.cpp
+++ b/mlir/lib/Dialect/ArmSVE/IR/ArmSVEDialect.cpp
@@ -12,6 +12,7 @@
 
 #include "mlir/Dialect/ArmSVE/IR/ArmSVEDialect.h"
 #include "mlir/Dialect/LLVMIR/LLVMTypes.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/DialectImplementation.h"
 #include "mlir/IR/OpImplementation.h"
diff --git a/mlir/lib/Dialect/ArmSVE/IR/CMakeLists.txt b/mlir/lib/Dialect/ArmSVE/IR/CMakeLists.txt
index fffc77245d12c93..9ef7384fc54925a 100644
--- a/mlir/lib/Dialect/ArmSVE/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/ArmSVE/IR/CMakeLists.txt
@@ -10,5 +10,6 @@ add_mlir_dialect_library(MLIRArmSVEDialect
   LINK_LIBS PUBLIC
   MLIRIR
   MLIRLLVMDialect
+  MLIRVectorDialect
   MLIRSideEffectInterfaces
   )
diff --git a/mlir/lib/Dialect/ArmSVE/Transforms/CMakeLists.txt b/mlir/lib/Dialect/ArmSVE/Transforms/CMakeLists.txt
index 7031ab4f799c4d2..2f1c43fae240d51 100644
--- a/mlir/lib/Dialect/ArmSVE/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/ArmSVE/Transforms/CMakeLists.txt
@@ -7,6 +7,7 @@ add_mlir_dialect_library(MLIRArmSVETransforms
   LINK_LIBS PUBLIC
   MLIRArmSVEDialect
   MLIRFuncDialect
+  MLIRVectorDialect
   MLIRIR
   MLIRLLVMCommonConversion
   MLIRLLVMDialect
diff --git a/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp
index abbb978304068e2..ca9e280f510858c 100644
--- a/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp
+++ b/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp
@@ -12,6 +12,8 @@
 #include "mlir/Dialect/ArmSVE/Transforms/Transforms.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/PatternMatch.h"
 
@@ -66,6 +68,77 @@ using ScalableMaskedDivFOpLowering =
     OneToOneConvertToLLVMPattern<ScalableMaskedDivFOp,
                                  ScalableMaskedDivFIntrOp>;
 
+namespace {
+
+/// Unrolls a conversion to/from equivalent vector types, to allow using a
+/// conversion intrinsic that only supports 1-D vector types.
+///
+/// Example:
+/// ```
+/// %result = arm_sve.convert_to_svbool %source : vector<2x[4]xi1>
+/// ```
+/// is rewritten into:
+/// ```
+/// %cst = arith.constant dense<false> : vector<2x[16]xi1>
+/// %1 = vector.extract %source[0] : vector<[4]xi1> from vector<2x[4]xi1>
+/// %2 = "arm_sve.intr.convert.to.svbool"(%1)
+///                : (vector<[4]xi1>) -> vector<[16]xi1>
+/// %3 = vector.insert %2, %cst [0] : vector<[16]xi1> into vector<2x[16]xi1>
+/// %4 = vector.extract %source[1] : vector<[4]xi1> from vector<2x[4]xi1>
+/// %5 = "arm_sve.intr.convert.to.svbool"(%4)
+///                : (vector<[4]xi1>) -> vector<[16]xi1>
+/// %result = vector.insert %5, %3 [1] : vector<[16]xi1> into vector<2x[16]xi1>
+/// ```
+template <typename Op, typename IntrOp>
+struct SvboolConversionOpLowering : public ConvertOpToLLVMPattern<Op> {
+  using ConvertOpToLLVMPattern<Op>::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(Op convertOp, typename Op::Adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto loc = convertOp.getLoc();
+
+    auto source = convertOp.getSource();
+    VectorType sourceType = source.getType();
+    VectorType resultType = convertOp.getResult().getType();
+
+    Value result = rewriter.create<arith::ConstantOp>(
+        loc, resultType, rewriter.getZeroAttr(resultType));
+
+    // We want to iterate over the input vector in steps of the trailing
+    // dimension. So this creates tile shape where all leading dimensions are 1,
+    // and the trailing dimension step is the size of the dimension.
+    SmallVector<int64_t> tileShape(sourceType.getRank(), 1);
+    tileShape.back() = sourceType.getShape().back();
+
+    // Iterate over all scalable mask/predicate slices of the source vector.
+    for (SmallVector<int64_t> index :
+         StaticTileOffsetRange(sourceType.getShape(), tileShape)) {
+      auto extractOrInsertPosition = ArrayRef(index).drop_back();
+      auto sourceVector = rewriter.create<vector::ExtractOp>(
+          loc, source, extractOrInsertPosition);
+      auto convertedType =
+          VectorType::Builder(llvm::cast<VectorType>(sourceVector.getType()))
+              .setDim(0, resultType.getShape().back());
+      auto convertedVector =
+          rewriter.create<IntrOp>(loc, TypeRange{convertedType}, sourceVector);
+      result = rewriter.create<vector::InsertOp>(loc, convertedVector, result,
+                                                 extractOrInsertPosition);
+    }
+
+    rewriter.replaceOp(convertOp, result);
+    return success();
+  }
+};
+
+using ConvertToSvboolOpLowering =
+    SvboolConversionOpLowering<ConvertToSvboolOp, ConvertToSvboolIntrOp>;
+
+using ConvertFromSvboolOpLowering =
+    SvboolConversionOpLowering<ConvertFromSvboolOp, ConvertFromSvboolIntrOp>;
+
+} // namespace
+
 /// Populate the given list with patterns that convert from ArmSVE to LLVM.
 void mlir::populateArmSVELegalizeForLLVMExportPatterns(
     LLVMTypeConverter &converter, RewritePatternSet &patterns) {
@@ -88,7 +161,9 @@ void mlir::populateArmSVELegalizeForLLVMExportPatterns(
                ScalableMaskedMulFOpLowering,
                ScalableMaskedSDivIOpLowering,
                ScalableMaskedUDivIOpLowering,
-               ScalableMaskedDivFOpLowering>(converter);
+               ScalableMaskedDivFOpLowering,
+               ConvertToSvboolOpLowering,
+               ConvertFromSvboolOpLowering>(converter);
   // clang-format on
 }
 
@@ -107,7 +182,9 @@ void mlir::configureArmSVELegalizeForExportTarget(
                     ScalableMaskedMulFIntrOp,
                     ScalableMaskedSDivIIntrOp,
                     ScalableMaskedUDivIIntrOp,
-                    ScalableMaskedDivFIntrOp>();
+                    ScalableMaskedDivFIntrOp,
+                    ConvertToSvboolIntrOp,
+                    ConvertFromSvboolIntrOp>();
   target.addIllegalOp<SdotOp,
                       SmmlaOp,
                       UdotOp,
@@ -120,6 +197,8 @@ void mlir::configureArmSVELegalizeForExportTarget(
                       ScalableMaskedMulFOp,
                       ScalableMaskedSDivIOp,
                       ScalableMaskedUDivIOp,
-                      ScalableMaskedDivFOp>();
+                      ScalableMaskedDivFOp,
+                      ConvertToSvboolOp,
+                      ConvertFromSvboolOp>();
   // clang-format on
 }
diff --git a/mlir/test/Dialect/ArmSVE/invalid.mlir b/mlir/test/Dialect/ArmSVE/invalid.mlir
new file mode 100644
index 000000000000000..a1fa0d0292b7b76
--- /dev/null
+++ b/mlir/test/Dialect/ArmSVE/invalid.mlir
@@ -0,0 +1,51 @@
+// RUN: mlir-opt %s -split-input-file -verify-diagnostics
+
+// -----
+
+func.func @arm_sve_convert_from_svbool__bad_mask_type(%bool: vector<2x[16]xi1>) -> vector<2x[8]xi2> {
+  // expected-error@+1 {{'result' must be trailing scalable vector of 1-bit signless integer values with dim -1 having a size of {16, 8, 4, 2, 1}, but got 'vector<2x[8]xi2>'}}
+  %mask = arm_sve.convert_from_svbool %bool : vector<2x[8]xi2>
+  return %mask : vector<2x[8]xi2>
+}
+
+// -----
+
+func.func @arm_sve_convert_from_svbool__bad_mask_shape(%bool : vector<[16]xi1>) -> vector<[7]xi1> {
+  // expected-error@+1 {{'result' must be trailing scalable vector of 1-bit signless integer values with dim -1 having a size of {16, 8, 4, 2, 1}, but got 'vector<[7]xi1>'}}
+  %mask = arm_sve.convert_from_svbool %bool : vector<[7]xi1>
+  return %mask : vector<[7]xi1>
+}
+
+// -----
+
+func.func @arm_sve_convert_from_svbool__bad_mask_scalability(%bool : vector<[4]x[16]xi1>) -> vector<[4]x[8]xi1> {
+  // expected-error@+1 {{'result' must be trailing scalable vector of 1-bit signless integer values with dim -1 having a size of {16, 8, 4, 2, 1}, but got 'vector<[4]x[8]xi1>'}}
+  %mask = arm_sve.convert_from_svbool %bool : vector<[4]x[8]xi1>
+  return %mask : vector<[4]x[8]xi1>
+}
+
+// -----
+
+func.func @arm_sve_convert_to_svbool__bad_mask_type(%mask: vector<2x[8]xi2>) -> vector<2x[16]xi1> {
+  // expected-error@+1 {{'source' must be trailing scalable vector of 1-bit signless integer values with dim -1 having a size of {16, 8, 4, 2, 1}, but got 'vector<2x[8]xi2>'}}
+  %bool = arm_sve.convert_to_svbool %mask : vector<2x[8]xi2>
+  return %bool : vector<2x[16]xi1>
+}
+
+// -----
+
+func.func @arm_sve_convert_to_svbool__bad_mask_shape(%mask : vector<[7]xi1>) -> vector<[16]xi1> {
+  // expected-error@+1 {{'source' must be trailing scalable vector of 1-bit signless integer values with dim -1 having a size of {16, 8, 4, 2, 1}, but got 'vector<[7]xi1>'}}
+  %bool = arm_sve.convert_to_svbool %mask : vector<[7]xi1>
+  return
+}
+
+// -----
+
+func.func @arm_sve_convert_to_svbool__bad_mask_scalability(%mask : vector<[4]x[8]xi1>) -> vector<[4]x[16]xi1> {
+  // expected-error@+1 {{'source' must be trailing scalable vector of 1-bit signless integer values with dim -1 having a size of {16, 8, 4, 2, 1}, but got 'vector<[4]x[8]xi1>'}}
+  %bool = arm_sve.convert_to_svbool %mask : vector<[4]x[8]xi1>
+  return
+}
+
+
diff --git a/mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir b/mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir
index 2d980db981034dd..8e76fb7119b844e 100644
--- a/mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir
+++ b/mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -convert-vector-to-llvm="enable-arm-sve" -convert-func-to-llvm -reconcile-unrealized-casts | mlir-opt | FileCheck %s
+// RUN: mlir-opt -convert-vector-to-llvm="enable-arm-sve" -convert-func-to-llvm -reconcile-unrealized-casts -split-input-file %s | FileCheck %s
 
 func.func @arm_sve_sdot(%a: vector<[16]xi8>,
                    %b: vector<[16]xi8>,
@@ -10,6 +10,8 @@ func.func @arm_sve_sdot(%a: vector<[16]xi8>,
   return %0 : vector<[4]xi32>
 }
 
+// -----
+
 func.func @arm_sve_smmla(%a: vector<[16]xi8>,
                     %b: vector<[16]xi8>,
                     %c: vector<[4]xi32>)
@@ -20,6 +22,8 @@ func.func @arm_sve_smmla(%a: vector<[16]xi8>,
   return %0 : vector<[4]xi32>
 }
 
+// -----
+
 func.func @arm_sve_udot(%a: vector<[16]xi8>,
                    %b: vector<[16]xi8>,
                    %c: vector<[4]xi32>)
@@ -30,6 +34,8 @@ func.func @arm_sve_udot(%a: vector<[16]xi8>,
   return %0 : vector<[4]xi32>
 }
 
+// -----
+
 func.func @arm_sve_ummla(%a: vector<[16]xi8>,
                     %b: vector<[16]xi8>,
                     %c: vector<[4]xi32>)
@@ -40,6 +46,8 @@ func.func @arm_sve_ummla(%a: vector<[16]xi8>,
   return %0 : vector<[4]xi32>
 }
 
+// -----
+
 func.func @arm_sve_arithi_masked(%a: vector<[4]xi32>,
                             %b: vector<[4]xi32>,
                             %c: vector<[4]xi32>,
@@ -65,6 +73,8 @@ func.func @arm_sve_arithi_masked(%a: vector<[4]xi32>,
   return %4 : vector<[4]xi32>
 }
 
+// -----
+
 func.func @arm_sve_arithf_masked(%a: vector<[4]xf32>,
                             %b: vector<[4]xf32>,
                             %c: vector<[4]xf32>,
@@ -87,6 +97,8 @@ func.func @arm_sve_arithf_masked(%a: vector<[4]xf32>,
   return %3 : vector<[4]xf32>
 }
 
+// -----
+
 func.func @arm_sve_ab...
[truncated]

mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td Outdated Show resolved Hide resolved
Comment on lines +73 to +74
/// Unrolls a conversion to/from equivalent vector types, to allow using a
/// conversion intrinsic that only supports 1-D vector types.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is doing two things and the op doesn't map 1-1 with intrinsics. Based on previous feedback I've received and general observations, I wonder if the unrolling should be done as a separate transform and this simply maps rank 1 arm_sve.convert_to_svbool / arm_sve.convert_from_svbool to intrinsics?

It wouldn't have to be done as part of this patch, but something to consider.

Copy link
Member Author

@MacDue MacDue Oct 11, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's the whole reason these op exists though? If this mapped 1-1 to the intrinsics there'd be no reason for these ops to exist.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah I understand that, but it doesn't mean both steps have to be done when lowering to intrinsics. At the Vector dialect level for example operations on higher rank vectors are typically broken down to rank 1 vectors first (VectorToSCF), before lowering to intrinsics.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That'll be quite a big rework as the ArmSVE dialect is pretty much just a skeleton, so currently has no -arm-sve-to-scf (or similar) passes right now.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think -arm-sve-to-scf would make sense since this isn't emitting SCF ops, but I get your point, it would need some to live somewhere that doesn't exist today. Like I said, I'm happy with this regardless.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I second Cullen here. It's not obvious to me what the right "pass" would be and where would it live, but we should add a TODO in the comments and in the commit message. Something along the lines:

Extract the lowering of convert_to_svbool into a series of vector.extract/insert operations into a dedicated lowering layer/pass pre "LegalizeForLLVMExport". The latter should be as simple as 1-1 mapping.

mlir/include/mlir/IR/CommonTypeConstraints.td Outdated Show resolved Hide resolved
Copy link
Collaborator

@c-rhodes c-rhodes left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, thanks for the updates!

Copy link
Contributor

@banach-space banach-space left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Few small questions, few small suggestions. Nothing major, just still struggling a bit with the new constraints. A few extra comments should be sufficient.

mlir/include/mlir/IR/CommonTypeConstraints.td Outdated Show resolved Hide resolved
mlir/include/mlir/IR/CommonTypeConstraints.td Outdated Show resolved Hide resolved
mlir/include/mlir/IR/CommonTypeConstraints.td Outdated Show resolved Hide resolved
Comment on lines 513 to 514
// Whether the n-th dim of the shape matches the given `size`.
// Negative values index in reverse.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is size? Or did you mean allowedSizes[n]?

In general, quite unsure what this does. Could you add an example of what would be accepted and what would be rejected?

Copy link
Member Author

@MacDue MacDue Oct 11, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I meant:

// Whether the n-th dim of the shape is contained within `allowedSizes`.
// Negative values for `n` index in reverse.

Examples:
IsNthDimSizeIsOneOfPred<0, {2, 3, 4}>
Accepts any shape where the first dim is 2, 3, or 4.
=> This means shapes like: 2x8x9x5, 4, 3x1, 4x?, etc

IsNthDimSizeIsOneOfPred<-1, {16}>
Accepts any shape where the last dim is 16.
=> This means shapes like 2x16, 16, 1x2x3x4x16, etc

IsNthDimSizeIsOneOfPred<-2, {10, 5}>
Accepts any shape where the second to last dim is 10 or 5.
=> This means shapes like: 1x10x2, 2x1x4x5x6, 8x10x?, etc

mlir/include/mlir/IR/CommonTypeConstraints.td Outdated Show resolved Hide resolved
"::llvm::cast<::mlir::ShapedType>($_self).getRank() + " # n,
"" # n)
# "))">]>;

// Whether the shape of a vector matches the given `shape` list.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wouldn't "mix" IsVectorOfShape (which is fairly basic) with the new definitions (which are quite complex)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've moved it to near the end of the list now

mlir/include/mlir/IR/CommonTypeConstraints.td Outdated Show resolved Hide resolved
Comment on lines +73 to +74
/// Unrolls a conversion to/from equivalent vector types, to allow using a
/// conversion intrinsic that only supports 1-D vector types.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I second Cullen here. It's not obvious to me what the right "pass" would be and where would it live, but we should add a TODO in the comments and in the commit message. Something along the lines:

Extract the lowering of convert_to_svbool into a series of vector.extract/insert operations into a dedicated lowering layer/pass pre "LegalizeForLLVMExport". The latter should be as simple as 1-1 mapping.

Copy link
Contributor

@banach-space banach-space left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, thank you for addressing my comments!

This is quite a non-trivial change, great job designing and implementing this!

@MacDue
Copy link
Member Author

MacDue commented Oct 12, 2023

Currently seeing some UBSAN issues in VectorType::Builder. These should be fixed by b44b349, the underlying bug in VectorType::Builder should be addressed in another patch.

MacDue added a commit that referenced this pull request Oct 26, 2023
This patch adds a pass that ensures that loads, stores, and allocations
of SVE vector types will be legal in the LLVM backend. It does this at
the memref level, so this pass must be applied before lowering all the
way to LLVM.

This pass currently fixes two issues.

## Loading and storing predicate types

It is only legal to load/store predicate types equal to (or greater
than) a full predicate register, which in MLIR is `vector<[16]xi1>`.
Smaller predicate types (`vector<[1|2|4|8]xi1>`) must be converted
to/from a full predicate type (referred to as a `svbool`) before and
after storing and loading respectively. This pass does this by widening
allocations and inserting conversion intrinsics.

For example:


```mlir
%alloca = memref.alloca() : memref<vector<[4]xi1>>
%mask = vector.constant_mask [4] : vector<[4]xi1>
memref.store %mask, %alloca[] : memref<vector<[4]xi1>>
%reload = memref.load %alloca[] : memref<vector<[4]xi1>>
```
Becomes:
```mlir
%alloca = memref.alloca() {alignment = 1 : i64} : memref<vector<[16]xi1>>
%mask = vector.constant_mask [4] : vector<[4]xi1>
%svbool = arm_sve.convert_to_svbool %mask : vector<[4]xi1>
memref.store %svbool, %alloca[] : memref<vector<[16]xi1>>
%reload_svbool = memref.load %alloca[] : memref<vector<[16]xi1>>
%reload = arm_sve.convert_from_svbool %reload_svbool : vector<[4]xi1>
```

## Relax alignments for SVE vector allocas

The storage for SVE vector types only needs to have an alignment that
matches the element type (for example 4 byte alignment for `f32`s).
However, the LLVM backend currently defaults to aligning to `base size x
element size` bytes. For non-legal vector types like `vector<[8]xf32>`
this results in 8 x 4 = 32-byte alignment, but the backend only supports
up to 16-byte alignment for SVE vectors on the stack. Explicitly setting
a smaller alignment prevents this issue.

Depends on: #68586 and #68695 (for testing)
zahiraam pushed a commit to zahiraam/llvm-project that referenced this pull request Oct 26, 2023
This patch adds a pass that ensures that loads, stores, and allocations
of SVE vector types will be legal in the LLVM backend. It does this at
the memref level, so this pass must be applied before lowering all the
way to LLVM.

This pass currently fixes two issues.

## Loading and storing predicate types

It is only legal to load/store predicate types equal to (or greater
than) a full predicate register, which in MLIR is `vector<[16]xi1>`.
Smaller predicate types (`vector<[1|2|4|8]xi1>`) must be converted
to/from a full predicate type (referred to as a `svbool`) before and
after storing and loading respectively. This pass does this by widening
allocations and inserting conversion intrinsics.

For example:


```mlir
%alloca = memref.alloca() : memref<vector<[4]xi1>>
%mask = vector.constant_mask [4] : vector<[4]xi1>
memref.store %mask, %alloca[] : memref<vector<[4]xi1>>
%reload = memref.load %alloca[] : memref<vector<[4]xi1>>
```
Becomes:
```mlir
%alloca = memref.alloca() {alignment = 1 : i64} : memref<vector<[16]xi1>>
%mask = vector.constant_mask [4] : vector<[4]xi1>
%svbool = arm_sve.convert_to_svbool %mask : vector<[4]xi1>
memref.store %svbool, %alloca[] : memref<vector<[16]xi1>>
%reload_svbool = memref.load %alloca[] : memref<vector<[16]xi1>>
%reload = arm_sve.convert_from_svbool %reload_svbool : vector<[4]xi1>
```

## Relax alignments for SVE vector allocas

The storage for SVE vector types only needs to have an alignment that
matches the element type (for example 4 byte alignment for `f32`s).
However, the LLVM backend currently defaults to aligning to `base size x
element size` bytes. For non-legal vector types like `vector<[8]xf32>`
this results in 8 x 4 = 32-byte alignment, but the backend only supports
up to 16-byte alignment for SVE vectors on the stack. Explicitly setting
a smaller alignment prevents this issue.

Depends on: llvm#68586 and llvm#68695 (for testing)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants