Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add isBatchVecmat utilities for linalg.batch_vecmat #70284

Merged
merged 1 commit into from
Oct 26, 2023

Conversation

bjacob
Copy link
Contributor

@bjacob bjacob commented Oct 26, 2023

linalg.batch_vecmat was just added in #70218, but I forgot then to add the standard isBatchVecmat utilities

@bjacob bjacob marked this pull request as ready for review October 26, 2023 02:53
@llvmbot
Copy link
Member

llvmbot commented Oct 26, 2023

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-linalg

Author: None (bjacob)

Changes

linalg.batch_vecmat was just added in #70218, but I forgot then to add the standard isBatchVecmat utilities


Full diff: https://github.com/llvm/llvm-project/pull/70284.diff

4 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td (+11)
  • (modified) mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h (+6)
  • (modified) mlir/lib/Dialect/Utils/StructuredOpsUtils.cpp (+25)
  • (modified) mlir/unittests/Dialect/Utils/StructuredOpsUtilsTest.cpp (+52)
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
index 44e82f452b3cef1..69ca888a8acdbe0 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
@@ -98,6 +98,17 @@ def LinalgContractionOpInterface : OpInterface<"ContractionOpInterface"> {
         return mlir::isVecmat($_op.getIndexingMaps());
     }]>,
     InterfaceMethod<
+    /*desc=*/[{
+      Returns whether the given op has indexing maps that correspond to a
+      batched vector-matrix multiplication.
+    }],
+    /*retTy=*/"bool",
+    /*methodName=*/"isBatchVecmat",
+    /*args=*/(ins),
+    /*methodBody=*/[{
+        return mlir::isBatchVecmat($_op.getIndexingMaps());
+    }]>,
+    InterfaceMethod<
     /*desc=*/[{
       Returns whether the given op has indexing maps that correspond to a
       matrix-vector multiplication.
diff --git a/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h b/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h
index 225b9f287d340db..134c5569fbb2f3e 100644
--- a/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h
@@ -55,6 +55,12 @@ bool isRowMajorBatchMatmul(ArrayAttr indexingMaps);
 /// performed within the reduction.
 bool isVecmat(ArrayAttr indexingMaps);
 
+/// Tests whether the given maps describe a batch vector matrix multiplication.
+/// The test is permutation-invariant. Note that this only checks the affine
+/// maps from an operation, so does not perform any checks on the math being
+/// performed within the reduction.
+bool isBatchVecmat(ArrayAttr indexingMaps);
+
 /// Tests whether the given maps describe a matrix vector multiplication. The
 /// test is permutation-invariant. Note that this only checks the affine maps
 /// from an operation, so does not perform any checks on the math being
diff --git a/mlir/lib/Dialect/Utils/StructuredOpsUtils.cpp b/mlir/lib/Dialect/Utils/StructuredOpsUtils.cpp
index 641ddf3f91cb2d9..383ef1cea53fd30 100644
--- a/mlir/lib/Dialect/Utils/StructuredOpsUtils.cpp
+++ b/mlir/lib/Dialect/Utils/StructuredOpsUtils.cpp
@@ -120,6 +120,31 @@ bool mlir::isVecmat(ArrayAttr indexingMaps) {
   return indexingMaps == maps;
 }
 
+bool mlir::isBatchVecmat(ArrayAttr indexingMaps) {
+  if (indexingMaps.size() != 3)
+    return false;
+  AffineMap map0 = cast<AffineMapAttr>(indexingMaps[0]).getValue();
+  AffineMap map1 = cast<AffineMapAttr>(indexingMaps[1]).getValue();
+  AffineMap map2 = cast<AffineMapAttr>(indexingMaps[2]).getValue();
+
+  if (map0.getNumResults() != 2 || map1.getNumResults() != 3 ||
+      map2.getNumResults() != 2 || map0.getNumInputs() != 3 ||
+      map1.getNumInputs() != 3 || map2.getNumInputs() != 3) {
+    return false;
+  }
+
+  // Extract dimensions for B*K * B*K*N -> B*N
+  AffineExpr b = map0.getResult(0);
+  AffineExpr k = map0.getResult(1);
+  AffineExpr n = map2.getResult(1);
+  auto *context = indexingMaps.getContext();
+  auto mapA = AffineMapAttr::get(AffineMap::get(3, 0, {b, k}, context));
+  auto mapB = AffineMapAttr::get(AffineMap::get(3, 0, {b, k, n}, context));
+  auto mapC = AffineMapAttr::get(AffineMap::get(3, 0, {b, n}, context));
+  auto maps = ArrayAttr::get(context, {mapA, mapB, mapC});
+  return indexingMaps == maps;
+}
+
 bool mlir::isMatvec(ArrayAttr indexingMaps) {
   if (indexingMaps.size() != 3)
     return false;
diff --git a/mlir/unittests/Dialect/Utils/StructuredOpsUtilsTest.cpp b/mlir/unittests/Dialect/Utils/StructuredOpsUtilsTest.cpp
index 3f576bacebf6aad..d257fc5d6e041d1 100644
--- a/mlir/unittests/Dialect/Utils/StructuredOpsUtilsTest.cpp
+++ b/mlir/unittests/Dialect/Utils/StructuredOpsUtilsTest.cpp
@@ -370,4 +370,56 @@ TEST(isBatchMatvec, WrongDimOrderMatrix) {
   EXPECT_THAT(maps, Not(Truly(isBatchMatvec)));
 }
 
+TEST(isBatchVecmat, Simple) {
+  MLIRContext context;
+
+  AffineExpr batch, k, n;
+  bindDims(&context, batch, k, n);
+  auto mapA = AffineMapAttr::get(AffineMap::get(3, 0, {batch, k}, &context));
+  auto mapB = AffineMapAttr::get(AffineMap::get(3, 0, {batch, k, n}, &context));
+  auto mapC = AffineMapAttr::get(AffineMap::get(3, 0, {batch, n}, &context));
+  auto maps = ArrayAttr::get(&context, {mapA, mapB, mapC});
+
+  EXPECT_THAT(maps, Truly(isBatchVecmat));
+}
+
+TEST(isBatchVecmat, BindingSwapped) {
+  MLIRContext context;
+
+  AffineExpr batch, k, n;
+  bindDims(&context, batch, n, k); // bind in different order
+  auto mapA = AffineMapAttr::get(AffineMap::get(3, 0, {batch, k}, &context));
+  auto mapB = AffineMapAttr::get(AffineMap::get(3, 0, {batch, k, n}, &context));
+  auto mapC = AffineMapAttr::get(AffineMap::get(3, 0, {batch, n}, &context));
+  auto maps = ArrayAttr::get(&context, {mapA, mapB, mapC});
+
+  EXPECT_THAT(maps, Truly(isBatchVecmat));
+}
+
+TEST(isBatchVecmat, Matmul) {
+  MLIRContext context;
+
+  AffineExpr m, n, k;
+  bindDims(&context, m, n, k);
+  auto mapA = AffineMapAttr::get(AffineMap::get(3, 0, {m, k}, &context));
+  auto mapB = AffineMapAttr::get(AffineMap::get(3, 0, {k, n}, &context));
+  auto mapC = AffineMapAttr::get(AffineMap::get(3, 0, {m, n}, &context));
+  auto maps = ArrayAttr::get(&context, {mapA, mapB, mapC});
+
+  EXPECT_THAT(maps, Not(Truly(isBatchVecmat)));
+}
+
+TEST(isBatchVecmat, WrongDimOrderMatrix) {
+  MLIRContext context;
+
+  AffineExpr batch, k, n;
+  bindDims(&context, batch, k, n);
+  auto mapA = AffineMapAttr::get(AffineMap::get(3, 0, {batch, k}, &context));
+  auto mapB = AffineMapAttr::get(AffineMap::get(3, 0, {batch, n, k}, &context));
+  auto mapC = AffineMapAttr::get(AffineMap::get(3, 0, {batch, n}, &context));
+  auto maps = ArrayAttr::get(&context, {mapA, mapB, mapC});
+
+  EXPECT_THAT(maps, Not(Truly(isBatchVecmat)));
+}
+
 } // namespace

bool mlir::isBatchVecmat(ArrayAttr indexingMaps) {
if (indexingMaps.size() != 3)
return false;
AffineMap map0 = cast<AffineMapAttr>(indexingMaps[0]).getValue();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you could use the infer utils which I find nicer to work with:

if (maps == infer({{m, k}, {k, n}, {m, n}})) {
your call :)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i see. nicer indeed, but i was copy and pasting other instances in this file. The move to infer looks like it would be a nice improvement, but should be done for the whole file togeteher, so, in a separate PR.

@bjacob bjacob merged commit 8a80e33 into llvm:main Oct 26, 2023
zahiraam pushed a commit to zahiraam/llvm-project that referenced this pull request Oct 26, 2023
`linalg.batch_vecmat` was just added in
llvm#70218, but I forgot then to
add the standard `isBatchVecmat` utilities
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants