-
Notifications
You must be signed in to change notification settings - Fork 12.4k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][SVE] Add an e2e test for vector.contract #69845
[mlir][SVE] Add an e2e test for vector.contract #69845
Conversation
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-llvm Author: Andrzej Warzyński (banach-space) Changes
Patch is 61.32 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/69845.diff 20 Files Affected:
diff --git a/mlir/include/mlir/Conversion/LLVMCommon/PrintCallHelper.h b/mlir/include/mlir/Conversion/LLVMCommon/PrintCallHelper.h
new file mode 100644
index 000000000000000..7e26858589f2756
--- /dev/null
+++ b/mlir/include/mlir/Conversion/LLVMCommon/PrintCallHelper.h
@@ -0,0 +1,36 @@
+
+//===- PrintCallHelper.h - LLVM Interfaces ----------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_LLVMIR_PRINTCALLHELPER_H_
+#define MLIR_DIALECT_LLVMIR_PRINTCALLHELPER_H_
+
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "llvm/ADT/StringRef.h"
+
+namespace mlir {
+
+class Location;
+class ModuleOp;
+class OpBuilder;
+class Operation;
+class Type;
+class ValueRange;
+class LLVMTypeConverter;
+
+namespace LLVM {
+
+/// Generate IR that prints the given string to stdout.
+void createPrintStrCall(OpBuilder &builder, Location loc, ModuleOp moduleOp,
+ StringRef symbolName, StringRef string,
+ const LLVMTypeConverter &typeConverter);
+} // namespace LLVM
+
+} // namespace mlir
+
+#endif
diff --git a/mlir/include/mlir/Dialect/ArmSVE/CMakeLists.txt b/mlir/include/mlir/Dialect/ArmSVE/CMakeLists.txt
index f33061b2d87cffc..9f57627c321fb0c 100644
--- a/mlir/include/mlir/Dialect/ArmSVE/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/ArmSVE/CMakeLists.txt
@@ -1 +1,2 @@
add_subdirectory(IR)
+add_subdirectory(Transforms)
diff --git a/mlir/include/mlir/Dialect/ArmSVE/Transforms/CMakeLists.txt b/mlir/include/mlir/Dialect/ArmSVE/Transforms/CMakeLists.txt
new file mode 100644
index 000000000000000..7226642daf86172
--- /dev/null
+++ b/mlir/include/mlir/Dialect/ArmSVE/Transforms/CMakeLists.txt
@@ -0,0 +1,5 @@
+set(LLVM_TARGET_DEFINITIONS Passes.td)
+mlir_tablegen(Passes.h.inc -gen-pass-decls -name ArmSVE)
+add_public_tablegen_target(MLIRArmSVEPassIncGen)
+
+add_mlir_doc(Passes ArmSVEPasses ./ -gen-pass-doc)
diff --git a/mlir/include/mlir/Dialect/ArmSVE/Transforms/Passes.h b/mlir/include/mlir/Dialect/ArmSVE/Transforms/Passes.h
new file mode 100644
index 000000000000000..317fb9021b3c577
--- /dev/null
+++ b/mlir/include/mlir/Dialect/ArmSVE/Transforms/Passes.h
@@ -0,0 +1,33 @@
+//===- Passes.h - Pass Entrypoints ------------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_ARMSVE_TRANSFORMS_PASSES_H
+#define MLIR_DIALECT_ARMSVE_TRANSFORMS_PASSES_H
+
+#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
+#include "mlir/Pass/Pass.h"
+
+namespace mlir::arm_sve {
+
+#define GEN_PASS_DECL
+#include "mlir/Dialect/ArmSVE/Transforms/Passes.h.inc"
+
+/// Pass to legalize the types of mask stores.
+std::unique_ptr<Pass> createLegalizeVectorStoragePass();
+
+//===----------------------------------------------------------------------===//
+// Registration
+//===----------------------------------------------------------------------===//
+
+/// Generate the code for registering passes.
+#define GEN_PASS_REGISTRATION
+#include "mlir/Dialect/ArmSVE/Transforms/Passes.h.inc"
+
+} // namespace mlir::arm_sve
+
+#endif // MLIR_DIALECT_ARMSVE_TRANSFORMS_PASSES_H
diff --git a/mlir/include/mlir/Dialect/ArmSVE/Transforms/Passes.td b/mlir/include/mlir/Dialect/ArmSVE/Transforms/Passes.td
new file mode 100644
index 000000000000000..35c49607181da0c
--- /dev/null
+++ b/mlir/include/mlir/Dialect/ArmSVE/Transforms/Passes.td
@@ -0,0 +1,67 @@
+//===-- Passes.td - ArmSVE pass definition file ------------*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_ARMSVE_TRANSFORMS_PASSES_TD
+#define MLIR_DIALECT_ARMSVE_TRANSFORMS_PASSES_TD
+
+include "mlir/Pass/PassBase.td"
+
+def LegalizeVectorStorage
+ : Pass<"arm-sve-legalize-vector-storage", "mlir::func::FuncOp"> {
+ let summary = "Ensures stores of SVE vector types will be legal";
+ let description = [{
+ This pass ensures that loads, stores, and allocations of SVE vector types
+ will be legal in the LLVM backend. It does this at the memref level, so this
+ pass must be applied before lowering all the way to LLVM.
+
+ This pass currently fixes two issues.
+
+ ## Loading and storing predicate types
+
+ It is only legal to load/store predicate types equal to (or greater than) a
+ full predicate register, which in MLIR is `vector<[16]xi1>`. Smaller
+ predicate types (`vector<[1|2|4|8]xi1>`) must be converted to/from a full
+ predicate type (referred to as a `svbool`) before and after storing and
+ loading respectively. This pass does this by widening allocations and
+ inserting conversion intrinsics.
+
+ For example:
+
+ ```mlir
+ %alloca = memref.alloca() : memref<vector<[4]xi1>>
+ %mask = vector.constant_mask [4] : vector<[4]xi1>
+ memref.store %mask, %alloca[] : memref<vector<[4]xi1>>
+ %reload = memref.load %alloca[] : memref<vector<[4]xi1>>
+ ```
+ Becomes:
+ ```mlir
+ %alloca = memref.alloca() {alignment = 1 : i64} : memref<vector<[16]xi1>>
+ %mask = vector.constant_mask [4] : vector<[4]xi1>
+ %svbool = arm_sve.convert_to_svbool %mask : vector<[4]xi1>
+ memref.store %svbool, %alloca[] : memref<vector<[16]xi1>>
+ %reload_svbool = memref.load %alloca[] : memref<vector<[16]xi1>>
+ %reload = arm_sve.convert_from_svbool %reload_svbool : vector<[4]xi1>
+ ```
+
+ ## Relax alignments for SVE vector allocas
+
+ The storage for SVE vector types only needs to have an alignment that
+ matches the element type (for example 4 byte alignment for `f32`s). However,
+ the LLVM backend currently defaults to aligning to `base size` x
+ `element size` bytes. For non-legal vector types like `vector<[8]xf32>` this
+ results in 8 x 4 = 32-byte alignment, but the backend only supports up to
+ 16-byte alignment for SVE vectors on the stack. Explicitly setting a smaller
+ alignment prevents this issue.
+ }];
+ let constructor = "mlir::arm_sve::createLegalizeVectorStoragePass()";
+ let dependentDialects = ["func::FuncDialect",
+ "memref::MemRefDialect", "vector::VectorDialect",
+ "arm_sve::ArmSVEDialect"];
+}
+
+#endif // MLIR_DIALECT_ARMSVE_TRANSFORMS_PASSES_TD
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 2df2fe4c5ce8e9c..2b60055ca9db94b 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -26,6 +26,7 @@ include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/Interfaces/VectorInterfaces.td"
include "mlir/Interfaces/ViewLikeInterface.td"
+include "mlir/IR/BuiltinAttributes.td"
// TODO: Add an attribute to specify a different algebra with operators other
// than the current set: {*, +}.
@@ -2476,12 +2477,18 @@ def Vector_TransposeOp :
}
def Vector_PrintOp :
- Vector_Op<"print", []>,
+ Vector_Op<"print", [
+ PredOpTrait<
+ "`source` or `punctuation` are not set printing strings",
+ CPred<"!getStringLiteral() || (!getSource() && getPunctuation() == PrintPunctuation::NewLine)">
+ >,
+ ]>,
Arguments<(ins Optional<Type<Or<[
AnyVectorOfAnyRank.predicate,
AnyInteger.predicate, Index.predicate, AnyFloat.predicate
]>>>:$source, DefaultValuedAttr<Vector_PrintPunctuation,
- "::mlir::vector::PrintPunctuation::NewLine">:$punctuation)
+ "::mlir::vector::PrintPunctuation::NewLine">:$punctuation,
+ OptionalAttr<Builtin_StringAttr>:$stringLiteral)
> {
let summary = "print operation (for testing and debugging)";
let description = [{
@@ -2520,6 +2527,13 @@ def Vector_PrintOp :
```mlir
vector.print punctuation <newline>
```
+
+ Additionally, to aid with debugging and testing `vector.print` can also
+ print constant strings:
+
+ ```mlir
+ vector.print str "Hello, World!"
+ ```
}];
let extraClassDeclaration = [{
Type getPrintType() {
@@ -2528,11 +2542,26 @@ def Vector_PrintOp :
}];
let builders = [
OpBuilder<(ins "PrintPunctuation":$punctuation), [{
- build($_builder, $_state, {}, punctuation);
+ build($_builder, $_state, {}, punctuation, {});
+ }]>,
+ OpBuilder<(ins "::mlir::Value":$source), [{
+ build($_builder, $_state, source, PrintPunctuation::NewLine);
+ }]>,
+ OpBuilder<(ins "::mlir::Value":$source, "PrintPunctuation":$punctuation), [{
+ build($_builder, $_state, source, punctuation, {});
+ }]>,
+ OpBuilder<(ins "::llvm::StringRef":$string), [{
+ build($_builder, $_state, {}, PrintPunctuation::NewLine, $_builder.getStringAttr(string));
}]>,
];
- let assemblyFormat = "($source^ `:` type($source))? (`punctuation` $punctuation^)? attr-dict";
+ let assemblyFormat = [{
+ ($source^ `:` type($source))?
+ oilist(
+ `str` $stringLiteral
+ | `punctuation` $punctuation)
+ attr-dict
+ }];
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/InitAllPasses.h b/mlir/include/mlir/InitAllPasses.h
index 5489a13a8040bdb..7301905954f56d8 100644
--- a/mlir/include/mlir/InitAllPasses.h
+++ b/mlir/include/mlir/InitAllPasses.h
@@ -19,6 +19,7 @@
#include "mlir/Dialect/Affine/Passes.h"
#include "mlir/Dialect/Arith/Transforms/Passes.h"
#include "mlir/Dialect/ArmSME/Transforms/Passes.h"
+#include "mlir/Dialect/ArmSVE/Transforms/Passes.h"
#include "mlir/Dialect/Async/Passes.h"
#include "mlir/Dialect/Bufferization/Pipelines/Passes.h"
#include "mlir/Dialect/Bufferization/Transforms/Passes.h"
@@ -82,6 +83,7 @@ inline void registerAllPasses() {
transform::registerTransformPasses();
vector::registerVectorPasses();
arm_sme::registerArmSMEPasses();
+ arm_sve::registerArmSVEPasses();
// Dialect pipelines
bufferization::registerBufferizationPipelines();
diff --git a/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp b/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp
index a4f146bbe475cc6..6b7647b038f1d94 100644
--- a/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp
+++ b/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp
@@ -16,6 +16,7 @@
#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
+#include "mlir/Conversion/LLVMCommon/PrintCallHelper.h"
#include "mlir/Conversion/LLVMCommon/VectorPattern.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
@@ -36,51 +37,6 @@ using namespace mlir;
#define PASS_NAME "convert-cf-to-llvm"
-static std::string generateGlobalMsgSymbolName(ModuleOp moduleOp) {
- std::string prefix = "assert_msg_";
- int counter = 0;
- while (moduleOp.lookupSymbol(prefix + std::to_string(counter)))
- ++counter;
- return prefix + std::to_string(counter);
-}
-
-/// Generate IR that prints the given string to stderr.
-static void createPrintMsg(OpBuilder &builder, Location loc, ModuleOp moduleOp,
- StringRef msg,
- const LLVMTypeConverter &typeConverter) {
- auto ip = builder.saveInsertionPoint();
- builder.setInsertionPointToStart(moduleOp.getBody());
- MLIRContext *ctx = builder.getContext();
-
- // Create a zero-terminated byte representation and allocate global symbol.
- SmallVector<uint8_t> elementVals;
- elementVals.append(msg.begin(), msg.end());
- elementVals.push_back(0);
- auto dataAttrType = RankedTensorType::get(
- {static_cast<int64_t>(elementVals.size())}, builder.getI8Type());
- auto dataAttr =
- DenseElementsAttr::get(dataAttrType, llvm::ArrayRef(elementVals));
- auto arrayTy =
- LLVM::LLVMArrayType::get(IntegerType::get(ctx, 8), elementVals.size());
- std::string symbolName = generateGlobalMsgSymbolName(moduleOp);
- auto globalOp = builder.create<LLVM::GlobalOp>(
- loc, arrayTy, /*constant=*/true, LLVM::Linkage::Private, symbolName,
- dataAttr);
-
- // Emit call to `printStr` in runtime library.
- builder.restoreInsertionPoint(ip);
- auto msgAddr = builder.create<LLVM::AddressOfOp>(
- loc, typeConverter.getPointerType(arrayTy), globalOp.getName());
- SmallVector<LLVM::GEPArg> indices(1, 0);
- Value gep = builder.create<LLVM::GEPOp>(
- loc, typeConverter.getPointerType(builder.getI8Type()), arrayTy, msgAddr,
- indices);
- Operation *printer = LLVM::lookupOrCreatePrintStrFn(
- moduleOp, typeConverter.useOpaquePointers());
- builder.create<LLVM::CallOp>(loc, TypeRange(), SymbolRefAttr::get(printer),
- gep);
-}
-
namespace {
/// Lower `cf.assert`. The default lowering calls the `abort` function if the
/// assertion is violated and has no effect otherwise. The failure message is
@@ -105,7 +61,8 @@ struct AssertOpLowering : public ConvertOpToLLVMPattern<cf::AssertOp> {
// Failed block: Generate IR to print the message and call `abort`.
Block *failureBlock = rewriter.createBlock(opBlock->getParent());
- createPrintMsg(rewriter, loc, module, op.getMsg(), *getTypeConverter());
+ LLVM::createPrintStrCall(rewriter, loc, module, "assert_msg", op.getMsg(),
+ *getTypeConverter());
if (abortOnFailedAssert) {
// Insert the `abort` declaration if necessary.
auto abortFunc = module.lookupSymbol<LLVM::LLVMFuncOp>("abort");
diff --git a/mlir/lib/Conversion/LLVMCommon/CMakeLists.txt b/mlir/lib/Conversion/LLVMCommon/CMakeLists.txt
index 091cd539f0ae014..568d9339aaabcb4 100644
--- a/mlir/lib/Conversion/LLVMCommon/CMakeLists.txt
+++ b/mlir/lib/Conversion/LLVMCommon/CMakeLists.txt
@@ -3,6 +3,7 @@ add_mlir_conversion_library(MLIRLLVMCommonConversion
LoweringOptions.cpp
MemRefBuilder.cpp
Pattern.cpp
+ PrintCallHelper.cpp
StructBuilder.cpp
TypeConverter.cpp
VectorPattern.cpp
diff --git a/mlir/lib/Conversion/LLVMCommon/PrintCallHelper.cpp b/mlir/lib/Conversion/LLVMCommon/PrintCallHelper.cpp
new file mode 100644
index 000000000000000..487abb435d10ad7
--- /dev/null
+++ b/mlir/lib/Conversion/LLVMCommon/PrintCallHelper.cpp
@@ -0,0 +1,66 @@
+
+//===- PrintCallHelper.cpp - LLVM Interfaces --------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/LLVMCommon/PrintCallHelper.h"
+#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
+#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
+#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "llvm/ADT/ArrayRef.h"
+
+using namespace mlir;
+using namespace llvm;
+
+static std::string ensureSymbolNameIsUnique(ModuleOp moduleOp,
+ StringRef symbolName) {
+ static int counter = 0;
+ std::string uniqueName = std::string(symbolName);
+ while (moduleOp.lookupSymbol(uniqueName)) {
+ uniqueName = std::string(symbolName) + "_" + std::to_string(counter++);
+ }
+ return uniqueName;
+}
+
+void mlir::LLVM::createPrintStrCall(OpBuilder &builder, Location loc,
+ ModuleOp moduleOp, StringRef symbolName,
+ StringRef string,
+ const LLVMTypeConverter &typeConverter) {
+ auto ip = builder.saveInsertionPoint();
+ builder.setInsertionPointToStart(moduleOp.getBody());
+ MLIRContext *ctx = builder.getContext();
+
+ // Create a zero-terminated byte representation and allocate global symbol.
+ SmallVector<uint8_t> elementVals;
+ elementVals.append(string.begin(), string.end());
+ elementVals.push_back(0);
+ auto dataAttrType = RankedTensorType::get(
+ {static_cast<int64_t>(elementVals.size())}, builder.getI8Type());
+ auto dataAttr =
+ DenseElementsAttr::get(dataAttrType, llvm::ArrayRef(elementVals));
+ auto arrayTy =
+ LLVM::LLVMArrayType::get(IntegerType::get(ctx, 8), elementVals.size());
+ auto globalOp = builder.create<LLVM::GlobalOp>(
+ loc, arrayTy, /*constant=*/true, LLVM::Linkage::Private,
+ ensureSymbolNameIsUnique(moduleOp, symbolName), dataAttr);
+
+ // Emit call to `printStr` in runtime library.
+ builder.restoreInsertionPoint(ip);
+ auto msgAddr = builder.create<LLVM::AddressOfOp>(
+ loc, typeConverter.getPointerType(arrayTy), globalOp.getName());
+ SmallVector<LLVM::GEPArg> indices(1, 0);
+ Value gep = builder.create<LLVM::GEPOp>(
+ loc, typeConverter.getPointerType(builder.getI8Type()), arrayTy, msgAddr,
+ indices);
+ Operation *printer = LLVM::lookupOrCreatePrintStrFn(
+ moduleOp, typeConverter.useOpaquePointers());
+ builder.create<LLVM::CallOp>(loc, TypeRange(), SymbolRefAttr::get(printer),
+ gep);
+}
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 8427d60f14c0bcc..4af58653c8227ae 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -9,6 +9,7 @@
#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
#include "mlir/Conversion/ArithCommon/AttrToLLVMConverter.h"
+#include "mlir/Conversion/LLVMCommon/PrintCallHelper.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Conversion/LLVMCommon/VectorPattern.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
@@ -1548,7 +1549,10 @@ class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {
}
auto punct = printOp.getPunctuation();
- if (punct != PrintPunctuation::NoPunctuation) {
+ if (auto stringLiteral = printOp.getStringLiteral()) {
+ LLVM::createPrintStrCall(rewriter, loc, parent, "vector_print_str",
+ *stringLiteral, *getTypeConverter());
+ } else if (punct != PrintPunctuation::NoPunctuation) {
emitCall(rewriter, printOp->getLoc(), [&] {
switch (punct) {
case PrintPunctuation::Close:
diff --git a/mlir/lib/Dialect/ArmSVE/Transforms/CMakeLists.txt b/mlir/lib/Dialect/ArmSVE/Transforms/CMakeLists.txt
index 2f1c43fae240d51..a70c489a51fea9a 100644
--- a/mlir/lib/Dialect/ArmSVE/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/ArmSVE/Transforms/CMakeLists.txt
@@ -1,8 +1,10 @@
add_mlir_dialect_library(MLIRArmSVETransforms
LegalizeForLLVMExport.cpp
+ LegalizeVectorStorage.cpp
DEPENDS
MLIRArmSVEConversionsIncGen
+ MLIRArmSVEPassIncGen
LINK_LIBS PUBLIC
MLIRArmSVEDialect
diff --git a/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeVectorStorage.cpp b/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeVectorStorage.cpp
new file mode 100644
index 000000000000000..bb5a36d70351831
--- /dev/null
+++ b/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeVectorStorage.cpp
@@ -0,0 +1,311 @@
+//===- LegalizeVectorStorage.cpp - Ensures SVE loads/stores are legal -----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/ArmSVE/IR/ArmSVEDialect.h"
+#include "mlir/Dialect/ArmSVE/Transforms/Passes.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#i...
[truncated]
|
Is this for review? |
Yes, but only the last commit - first 2 are just dependencies. Apologies for not updating the summary earlier (the default generated by GitHub when uploading multiple commits is always confusing). |
54387a9
to
d5985c7
Compare
Adds an end-to-end test for `vector.contract` that targets SVE (i.e. scalable vectors). Note that this requires lifting the restriction on `vector.outerproduct` (to which `vector.contract` is lowered to) that would deem the following as invalid by the Op verifier (*): ``` vector.outerproduct %27, %28, %26 {kind = #vector.kind<add>} : vector<3xf32>, vector<[2]xf32> ``` This is indeed valid as the end-to-end test demonstrates (at least when compiling for SVE). Depends on llvm#68794
d5985c7
to
74d8670
Compare
mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/test-contraction.mlir
Outdated
Show resolved
Hide resolved
mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/test-contraction.mlir
Outdated
Show resolved
Hide resolved
mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/test-contraction.mlir
Outdated
Show resolved
Hide resolved
Simplfiy following Ben's advice
mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/test-contraction.mlir
Outdated
Show resolved
Hide resolved
mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/test-contraction.mlir
Outdated
Show resolved
Hide resolved
mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/test-contraction.mlir
Outdated
Show resolved
Hide resolved
Remove masks and unused variables, make sure that the file is compiled only once, add comments
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM! Thanks for the changes :)
Adds basic integration tests for `vector.contract` for the dot product and matvec operations. These tests exercise scalable vectors. Depends on #69845
Adds an end-to-end test for
vector.contract
that targets SVE (i.e.scalable vectors). Note that this requires lifting the restriction on
vector.outerproduct
(to whichvector.contract
is lowered to) thatwould deem the following as invalid by the Op verifier (*):
This is indeed valid as the end-to-end test demonstrates (at least when
compiling for SVE).