Add missing linalg.batch_vecmat named op #70218

Conversation
@llvm/pr-subscribers-mlir-linalg @llvm/pr-subscribers-mlir

Author: None (bjacob)

Changes

Linalg currently has these named ops:

* `matmul`
* `matvec`
* `vecmat`
* `batch_matmul`
* `batch_matvec`

But it does not have:

* `batch_vecmat`

This PR adds it for consistency, and I have a short-term need for it ( iree-org/iree#15158 ), so not having it would cause some contortion on my end.

Full diff: https://github.com/llvm/llvm-project/pull/70218.diff

3 Files Affected:
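The semantics of the new op can be modeled in NumPy before reading the diff: for each batch `b`, the vector `A[b, :]` is multiplied from the left against the matrix `B[b, :, :]`, accumulating into `C[b, :]`. The sketch below is a hedged reference model (not part of the PR, and `batch_vecmat_ref` is a hypothetical name), assuming the operand promotion mirrors the `cast_signed` calls in the op definition:

```python
import numpy as np

def batch_vecmat_ref(A, B, C=None):
    """Reference model of linalg.batch_vecmat semantics:
    C[b, n] += sum over k of A[b, k] * B[b, k, n].
    Shapes: A is (batch, K), B is (batch, K, N), accumulator C is (batch, N).
    """
    batch, K = A.shape
    _, _, N = B.shape
    if C is None:
        C = np.zeros((batch, N), dtype=np.result_type(A, B))
    # Operands are promoted to the accumulator type before the multiply,
    # mirroring the cast_signed ops in the structured-op definition.
    return C + np.einsum("bk,bkn->bn", A.astype(C.dtype), B.astype(C.dtype))
```

Note how the iteration space matches the op's `(d0, d1, d2)` = (batch, n, k) with `k` as the reduction dimension.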
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
index cd64b813c11e532..12d520cd382413a 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
@@ -1796,6 +1796,74 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarExpression
scalar_arg: B
--- !LinalgOpConfig
+metadata: !LinalgOpMetadata
+ name: batch_vecmat
+ cpp_class_name: BatchVecmatOp
+ doc: |-
+ Performs a batched vector-matrix multiplication.
+
+ Numeric casting is performed on the operands to the inner multiply, promoting
+ them to the same data type as the accumulator/output.
+ implements:
+ - LinalgContractionOpInterface
+structured_op: !LinalgStructuredOpConfig
+ args:
+ - !LinalgOperandDefConfig
+ name: A
+ kind: input_tensor
+ type_var: T1
+ shape_map: affine_map<()[s0, s1, s2] -> (s0, s1)>
+ - !LinalgOperandDefConfig
+ name: B
+ kind: input_tensor
+ type_var: T2
+ shape_map: affine_map<()[s0, s1, s2] -> (s0, s1, s2)>
+ - !LinalgOperandDefConfig
+ name: C
+ kind: output_tensor
+ type_var: U
+ shape_map: affine_map<()[s0, s1, s2] -> (s0, s2)>
+ indexing_maps: !LinalgIndexingMapsConfig
+ static_indexing_maps:
+ - affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0, d2)>
+ - affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0, d2, d1)>
+ - affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0, d1)>
+ iterator_types:
+ - parallel
+ - parallel
+ - reduction
+ assignments:
+ - !ScalarAssign
+ arg: C
+ value: !ScalarExpression
+ scalar_fn:
+ kind: binary
+ fn_name: add
+ operands:
+ - !ScalarExpression
+ scalar_arg: C
+ - !ScalarExpression
+ scalar_fn:
+ kind: binary
+ fn_name: mul
+ operands:
+ - !ScalarExpression
+ scalar_fn:
+ kind: type
+ fn_name: cast_signed
+ type_var: U
+ operands:
+ - !ScalarExpression
+ scalar_arg: A
+ - !ScalarExpression
+ scalar_fn:
+ kind: type
+ fn_name: cast_signed
+ type_var: U
+ operands:
+ - !ScalarExpression
+ scalar_arg: B
+--- !LinalgOpConfig
metadata: !LinalgOpMetadata
name: dot
cpp_class_name: DotOp
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
index 19734a80a107bfe..5144c42480cbc75 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
@@ -516,6 +516,23 @@ def batch_matvec(
U, B[D.b, D.k]
)
+@linalg_structured_op
+def batch_vecmat(
+ A=TensorDef(T1, Batch, S.K),
+ B=TensorDef(T2, Batch, S.K, S.N),
+ C=TensorDef(U, Batch, S.N, output=True),
+):
+ """Performs a batched vector-matrix multiplication.
+
+ Numeric casting is performed on the operands to the inner multiply, promoting
+ them to the same data type as the accumulator/output.
+ """
+ domain(D.b, D.n, D.k)
+ implements(ContractionOpInterface)
+ C[D.b, D.n] += TypeFn.cast_signed(U, A[D.b, D.k]) * TypeFn.cast_signed(
+ U, B[D.b, D.k, D.n]
+ )
+
@linalg_structured_op
def dot(A=TensorDef(T1, S.M), B=TensorDef(T2, S.M), C=TensorDef(U, output=True)):
diff --git a/mlir/test/Dialect/Linalg/generalize-named-ops.mlir b/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
index 54cc0defc1f8cd8..2259d47eb2b2b0d 100644
--- a/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
@@ -251,6 +251,31 @@ func.func @generalize_batch_matm_vec(%lhs : memref<?x?x?xi8>, %rhs: memref<?x?xi
// -----
+func.func @generalize_batch_vecmat(%lhs : memref<?x?xi8>, %rhs: memref<?x?x?xi8>, %out: memref<?x?xf32>) {
+ linalg.batch_vecmat ins(%lhs, %rhs: memref<?x?xi8>, memref<?x?x?xi8>)
+ outs(%out: memref<?x?xf32>)
+ return
+}
+// CHECK: #[[MAP0:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+// CHECK: #[[MAP1:.+]] = affine_map<(d0, d1, d2) -> (d0, d2, d1)>
+// CHECK: #[[MAP2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+// CHECK: @generalize_batch_vecmat
+
+// CHECK: linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"]}
+// CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<?x?xi8>, memref<?x?x?xi8>)
+// CHECK-SAME: outs(%{{.+}} : memref<?x?xf32>)
+// CHECK: ^{{.+}}(%[[BBARG0:.+]]: i8, %[[BBARG1:.+]]: i8, %[[BBARG2:.+]]: f32)
+// CHECK: %[[BBARG0_F32:.+]] = arith.sitofp %[[BBARG0]] : i8 to f32
+// CHECK: %[[BBARG1_F32:.+]] = arith.sitofp %[[BBARG1]] : i8 to f32
+// CHECK: %[[MUL:.+]] = arith.mulf %[[BBARG0_F32]], %[[BBARG1_F32]]
+// CHECK: %[[ADD:.+]] = arith.addf %[[BBARG2]], %[[MUL]]
+// CHECK: linalg.yield %[[ADD]] : f32
+
+// -----
+
func.func @batch_reduce_gemm(%lhs: memref<7x8x9xf32>, %rhs: memref<7x9x8xf32>, %out: memref<8x8xf32>) {
linalg.batch_reduce_matmul ins(%lhs, %rhs: memref<7x8x9xf32>, memref<7x9x8xf32>)
outs(%out: memref<8x8xf32>)
✅ With the latest revision this PR passed the Python code formatter.
Linalg currently has these named ops: `matmul`, `matvec`, `vecmat`, `batch_matmul`, `batch_matvec`. But it does not have `batch_vecmat`. This PR adds it for consistency, and I have a short-term need for it ( iree-org/iree#15158 ), so not having it would cause some contortion on my end.
`linalg.batch_vecmat` was just added in #70218, but I then forgot to add the standard `isBatchVecmat` utilities.
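What an `isBatchVecmat`-style utility checks can be sketched structurally: a three-operand contraction is a batch-vecmat when its indexing maps follow the `(b, k)`, `(b, k, n)`, `(b, n)` pattern with `k` as the sole reduction iterator. The Python sketch below is hypothetical (it is not the MLIR C++ API; `is_batch_vecmat_maps` and the tuple encoding of maps are assumptions for illustration):

```python
def is_batch_vecmat_maps(map_a, map_b, map_c):
    """Hypothetical sketch, not the MLIR C++ utility: classify a
    3-operand contraction as batch_vecmat from its indexing maps,
    each given as a tuple of iteration-space dimension names.
    Expected pattern, with iterators (b, n, k) being
    (parallel, parallel, reduction):
      A: (b, k)    -- batched vector operand
      B: (b, k, n) -- batched matrix operand
      C: (b, n)    -- batched result vector
    """
    b, n, k = "b", "n", "k"
    return map_a == (b, k) and map_b == (b, k, n) and map_c == (b, n)
```

Any other arrangement of the same dimensions (e.g. `B` as `(b, n, k)`) corresponds to a different named op and is rejected.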
llvm/llvm-project#70218 just missed the last integrate, and cherry-picks are frowned upon. The good thing about just missing an integrate is that bumping the submodule still shouldn't be too hard. I just had to fix up one small thing in CollapseDimensions. ci-extra:build_test_all_windows,build_test_all_macos_arm64,build_test_all_macos_x86_64