Add `isBatchVecmat` utilities for `linalg.batch_vecmat` #70284
Conversation
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-linalg

Author: None (bjacob)

Changes
Full diff: https://github.com/llvm/llvm-project/pull/70284.diff

4 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
index 44e82f452b3cef1..69ca888a8acdbe0 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
@@ -98,6 +98,17 @@ def LinalgContractionOpInterface : OpInterface<"ContractionOpInterface"> {
return mlir::isVecmat($_op.getIndexingMaps());
}]>,
InterfaceMethod<
+ /*desc=*/[{
+ Returns whether the given op has indexing maps that correspond to a
+ batched vector-matrix multiplication.
+ }],
+ /*retTy=*/"bool",
+ /*methodName=*/"isBatchVecmat",
+ /*args=*/(ins),
+ /*methodBody=*/[{
+ return mlir::isBatchVecmat($_op.getIndexingMaps());
+ }]>,
+ InterfaceMethod<
/*desc=*/[{
Returns whether the given op has indexing maps that correspond to a
matrix-vector multiplication.
diff --git a/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h b/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h
index 225b9f287d340db..134c5569fbb2f3e 100644
--- a/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h
@@ -55,6 +55,12 @@ bool isRowMajorBatchMatmul(ArrayAttr indexingMaps);
/// performed within the reduction.
bool isVecmat(ArrayAttr indexingMaps);
+/// Tests whether the given maps describe a batch vector matrix multiplication.
+/// The test is permutation-invariant. Note that this only checks the affine
+/// maps from an operation, so does not perform any checks on the math being
+/// performed within the reduction.
+bool isBatchVecmat(ArrayAttr indexingMaps);
+
/// Tests whether the given maps describe a matrix vector multiplication. The
/// test is permutation-invariant. Note that this only checks the affine maps
/// from an operation, so does not perform any checks on the math being
diff --git a/mlir/lib/Dialect/Utils/StructuredOpsUtils.cpp b/mlir/lib/Dialect/Utils/StructuredOpsUtils.cpp
index 641ddf3f91cb2d9..383ef1cea53fd30 100644
--- a/mlir/lib/Dialect/Utils/StructuredOpsUtils.cpp
+++ b/mlir/lib/Dialect/Utils/StructuredOpsUtils.cpp
@@ -120,6 +120,31 @@ bool mlir::isVecmat(ArrayAttr indexingMaps) {
return indexingMaps == maps;
}
+bool mlir::isBatchVecmat(ArrayAttr indexingMaps) {
+ if (indexingMaps.size() != 3)
+ return false;
+ AffineMap map0 = cast<AffineMapAttr>(indexingMaps[0]).getValue();
+ AffineMap map1 = cast<AffineMapAttr>(indexingMaps[1]).getValue();
+ AffineMap map2 = cast<AffineMapAttr>(indexingMaps[2]).getValue();
+
+ if (map0.getNumResults() != 2 || map1.getNumResults() != 3 ||
+ map2.getNumResults() != 2 || map0.getNumInputs() != 3 ||
+ map1.getNumInputs() != 3 || map2.getNumInputs() != 3) {
+ return false;
+ }
+
+ // Extract dimensions for B*K * B*K*N -> B*N
+ AffineExpr b = map0.getResult(0);
+ AffineExpr k = map0.getResult(1);
+ AffineExpr n = map2.getResult(1);
+ auto *context = indexingMaps.getContext();
+ auto mapA = AffineMapAttr::get(AffineMap::get(3, 0, {b, k}, context));
+ auto mapB = AffineMapAttr::get(AffineMap::get(3, 0, {b, k, n}, context));
+ auto mapC = AffineMapAttr::get(AffineMap::get(3, 0, {b, n}, context));
+ auto maps = ArrayAttr::get(context, {mapA, mapB, mapC});
+ return indexingMaps == maps;
+}
+
bool mlir::isMatvec(ArrayAttr indexingMaps) {
if (indexingMaps.size() != 3)
return false;
diff --git a/mlir/unittests/Dialect/Utils/StructuredOpsUtilsTest.cpp b/mlir/unittests/Dialect/Utils/StructuredOpsUtilsTest.cpp
index 3f576bacebf6aad..d257fc5d6e041d1 100644
--- a/mlir/unittests/Dialect/Utils/StructuredOpsUtilsTest.cpp
+++ b/mlir/unittests/Dialect/Utils/StructuredOpsUtilsTest.cpp
@@ -370,4 +370,56 @@ TEST(isBatchMatvec, WrongDimOrderMatrix) {
EXPECT_THAT(maps, Not(Truly(isBatchMatvec)));
}
+TEST(isBatchVecmat, Simple) {
+ MLIRContext context;
+
+ AffineExpr batch, k, n;
+ bindDims(&context, batch, k, n);
+ auto mapA = AffineMapAttr::get(AffineMap::get(3, 0, {batch, k}, &context));
+ auto mapB = AffineMapAttr::get(AffineMap::get(3, 0, {batch, k, n}, &context));
+ auto mapC = AffineMapAttr::get(AffineMap::get(3, 0, {batch, n}, &context));
+ auto maps = ArrayAttr::get(&context, {mapA, mapB, mapC});
+
+ EXPECT_THAT(maps, Truly(isBatchVecmat));
+}
+
+TEST(isBatchVecmat, BindingSwapped) {
+ MLIRContext context;
+
+ AffineExpr batch, k, n;
+ bindDims(&context, batch, n, k); // bind in different order
+ auto mapA = AffineMapAttr::get(AffineMap::get(3, 0, {batch, k}, &context));
+ auto mapB = AffineMapAttr::get(AffineMap::get(3, 0, {batch, k, n}, &context));
+ auto mapC = AffineMapAttr::get(AffineMap::get(3, 0, {batch, n}, &context));
+ auto maps = ArrayAttr::get(&context, {mapA, mapB, mapC});
+
+ EXPECT_THAT(maps, Truly(isBatchVecmat));
+}
+
+TEST(isBatchVecmat, Matmul) {
+ MLIRContext context;
+
+ AffineExpr m, n, k;
+ bindDims(&context, m, n, k);
+ auto mapA = AffineMapAttr::get(AffineMap::get(3, 0, {m, k}, &context));
+ auto mapB = AffineMapAttr::get(AffineMap::get(3, 0, {k, n}, &context));
+ auto mapC = AffineMapAttr::get(AffineMap::get(3, 0, {m, n}, &context));
+ auto maps = ArrayAttr::get(&context, {mapA, mapB, mapC});
+
+ EXPECT_THAT(maps, Not(Truly(isBatchVecmat)));
+}
+
+TEST(isBatchVecmat, WrongDimOrderMatrix) {
+ MLIRContext context;
+
+ AffineExpr batch, k, n;
+ bindDims(&context, batch, k, n);
+ auto mapA = AffineMapAttr::get(AffineMap::get(3, 0, {batch, k}, &context));
+ auto mapB = AffineMapAttr::get(AffineMap::get(3, 0, {batch, n, k}, &context));
+ auto mapC = AffineMapAttr::get(AffineMap::get(3, 0, {batch, n}, &context));
+ auto maps = ArrayAttr::get(&context, {mapA, mapB, mapC});
+
+ EXPECT_THAT(maps, Not(Truly(isBatchVecmat)));
+}
+
} // namespace
Review comment on the new `bool mlir::isBatchVecmat(ArrayAttr indexingMaps)` implementation in StructuredOpsUtils.cpp:
you could use the `infer` utils which I find nicer to work with:

if (maps == infer({{m, k}, {k, n}, {m, n}})) {
I see, nicer indeed, but I was copy-and-pasting other instances in this file. The move to `infer` looks like it would be a nice improvement, but it should be done for the whole file together, so, in a separate PR.
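For reference, a rough sketch of what that follow-up might look like; this is an assumption, not code from this PR or from the existing file, and the `inferMaps` helper below is hypothetical (a real refactor would presumably share one helper across all of these predicates):

```cpp
// Hypothetical sketch only (not part of this PR): isBatchVecmat rewritten
// around an infer-style helper, keeping the same permutation-invariant logic
// as the version added in StructuredOpsUtils.cpp above.
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BuiltinAttributes.h"

using namespace mlir;

// Assumed helper: builds the canonical ArrayAttr of indexing maps from lists
// of result expressions, all over `numDims` loop dimensions and no symbols.
static ArrayAttr inferMaps(MLIRContext *ctx, unsigned numDims,
                           ArrayRef<ArrayRef<AffineExpr>> resultLists) {
  SmallVector<Attribute> attrs;
  for (ArrayRef<AffineExpr> results : resultLists)
    attrs.push_back(AffineMapAttr::get(
        AffineMap::get(numDims, /*symbolCount=*/0, results, ctx)));
  return ArrayAttr::get(ctx, attrs);
}

static bool isBatchVecmatViaInfer(ArrayAttr indexingMaps) {
  if (indexingMaps.size() != 3)
    return false;
  AffineMap map0 = cast<AffineMapAttr>(indexingMaps[0]).getValue();
  AffineMap map2 = cast<AffineMapAttr>(indexingMaps[2]).getValue();
  if (map0.getNumResults() != 2 || map2.getNumResults() != 2 ||
      map0.getNumInputs() != 3 || map2.getNumInputs() != 3)
    return false;
  // Pull the batch (b), reduction (k) and parallel (n) dims out of the operand
  // maps so the check stays permutation-invariant, as in this PR's version.
  AffineExpr b = map0.getResult(0);
  AffineExpr k = map0.getResult(1);
  AffineExpr n = map2.getResult(1);
  // B*K vecmat against B*K*N, accumulating into B*N.
  return indexingMaps == inferMaps(indexingMaps.getContext(), /*numDims=*/3,
                                   {{b, k}, {b, k, n}, {b, n}});
}
```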
`linalg.batch_vecmat` was just added in #70218, but I forgot then to add the standard `isBatchVecmat` utilities.
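For illustration, here is a hedged sketch of how a pass might query the new predicate through the interface method added above; only the `isBatchVecmat` interface method and the `mlir::isBatchVecmat` free function come from this PR, while the surrounding scaffolding (`someOp`, the helper name, the exact header) is assumed:

```cpp
// Hypothetical usage sketch: dispatching on the new isBatchVecmat query.
#include "mlir/Dialect/Linalg/IR/Linalg.h"

using namespace mlir;

static bool isBatchVecmatLikeContraction(Operation *someOp) {
  // ContractionOpInterface is the C++ interface generated from
  // LinalgContractionOpInterface, which this PR extends with isBatchVecmat.
  if (auto contract = dyn_cast<linalg::ContractionOpInterface>(someOp))
    return contract.isBatchVecmat(); // B*K x B*K*N -> B*N indexing maps.
  return false;
}
```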