[Mlir] decompose generic by unfolding projected permutation crash fix #122449
base: main
Conversation
@llvm/pr-subscribers-mlir-linalg Author: None (GrumpyPigSkin)
Changes: Fixes #122094. @CoTinker, could you please review? I added the check in DecomposeGenericByUnfoldingPermutation.cpp, as adding it anywhere else was too general and would cause other valid test cases to fail.
Full diff: https://github.com/llvm/llvm-project/pull/122449.diff (2 files affected):
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DecomposeGenericByUnfoldingPermutation.cpp b/mlir/lib/Dialect/Linalg/Transforms/DecomposeGenericByUnfoldingPermutation.cpp
index 83c4b5bdf10976..ce1c21504f1dc7 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DecomposeGenericByUnfoldingPermutation.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DecomposeGenericByUnfoldingPermutation.cpp
@@ -159,6 +159,14 @@ LogicalResult DecomposeProjectedPermutation::matchAndRewrite(
auto map = op.getMatchingIndexingMap(&opOperand);
if (!map.isProjectedPermutation(false))
return failure();
+
+ // If we have any inputs that aren't memref or ranked tensor types, reject the pattern.
+ if (!dyn_cast<ShapedType>(opOperand.get().getType()))
+ return op->emitError("Expected operand #")
+ << opOperand.getOperandNumber()
+ << " to be memref of any type values or ranked tensor of any type "
+ "values, but got "
+ << opOperand.get().getType();
}
// Decomposing linalg.generic involves creating `tensor.empty`
diff --git a/mlir/test/Dialect/Linalg/decompose-generic-by-unfolding-projected-permutation-validate.mlir b/mlir/test/Dialect/Linalg/decompose-generic-by-unfolding-projected-permutation-validate.mlir
new file mode 100644
index 00000000000000..43fdd17e10078c
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/decompose-generic-by-unfolding-projected-permutation-validate.mlir
@@ -0,0 +1,28 @@
+// RUN: mlir-opt %s -linalg-specialize-generic-ops -verify-diagnostics
+
+// Fixes issue: 122094. Verify that the following code causes an error to be produced.
+
+func.func @test_broadcast_scalar_across_single_tensor() -> tensor<2x2xi32> {
+
+ %a = arith.constant dense<2> : tensor<2x2xi32>
+ %b = arith.constant 42 : i32
+ %c = tensor.empty() : tensor<2x2xi32>
+ // expected-error @+1 {{Expected operand #1 to be memref of any type values or ranked tensor of any type values, but got 'i32'}}
+ %res = linalg.generic
+ {
+ indexing_maps = [
+ affine_map<(i, j) -> (i, j)>,
+ affine_map<(i, j) -> ()>,
+ affine_map<(i, j) -> (i, j)>
+ ],
+ iterator_types = ["parallel", "parallel"]
+ }
+ ins(%a, %b : tensor<2x2xi32>, i32)
+ outs(%c : tensor<2x2xi32>) {
+ ^bb0(%x: i32, %scalar: i32, %out: i32):
+ %sum = arith.addi %x, %scalar : i32
+ linalg.yield %sum : i32
+ } -> tensor<2x2xi32>
+
+ return %res : tensor<2x2xi32>
+}
Looks pretty good. Just two comments about the test formatting.
Otherwise, we can just let some of the original issue reporters have a look too.
// If we have any inputs that aren't memref or ranked tensor types, reject
// the pattern.
if (!dyn_cast<ShapedType>(opOperand.get().getType()))
On more thought, what about a 0-d shape? Can it handle an input like %b = arith.constant dense<42> : tensor<i32>?
Yes, it matches 0-d shapes :)
//===-------------------------------------------===//
Processing operation : 'linalg.generic'(0x55556870a890) {
* Pattern mlir::linalg::LinalgSpecializationPattern : 'linalg.generic -> ()' {
Trying to match "mlir::linalg::LinalgSpecializationPattern"
"mlir::linalg::LinalgSpecializationPattern" result 0
} -> failure : pattern failed to match
* Pattern (anonymous namespace)::DecomposeProjectedPermutation : 'linalg.generic -> ()' {
Trying to match "(anonymous namespace)::DecomposeProjectedPermutation"
"(anonymous namespace)::DecomposeProjectedPermutation" result 1
} -> success : pattern applied successfully
// *** IR Dump After Pattern Application ***
func.func @test_broadcast_single_tensor() -> tensor<2x2xi32> {
%cst = arith.constant dense<2> : tensor<2x2xi32>
%cst_0 = arith.constant dense<42> : tensor<i32>
%0 = tensor.empty() : tensor<2x2xi32>
%1 = tensor.empty() : tensor<2x2xi32>
%broadcasted = linalg.broadcast ins(%cst_0 : tensor<i32>) outs(%1 : tensor<2x2xi32>) dimensions = [0, 1]
%2 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%cst, %broadcasted : tensor<2x2xi32>, tensor<2x2xi32>) outs(%0 : tensor<2x2xi32>) {
^bb0(%in: i32, %in_1: i32, %out: i32):
%3 = arith.addi %in, %in_1 : i32
linalg.yield %3 : i32
} -> tensor<2x2xi32>
return %2 : tensor<2x2xi32>
}
} -> success : pattern matched
//===-------------------------------------------===//
//===-------------------------------------------===//
Processing operation : 'arith.addi'(0x555568783370) {
%6 = "arith.addi"(%arg0, %arg1) <{overflowFlags = #arith.overflow<none>}> : (i32, i32) -> i32
} -> failure : pattern failed to match
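To double-check the 0-d case outside of the pass, here is a minimal standalone sketch (illustration only, not part of the PR) of why the guard accepts tensor<i32>: a 0-d tensor is still a ShapedType (and a RankedTensorType) of rank 0, so only true scalars such as a plain i32 fail the cast.

// Sketch only; assumes the MLIR builtin-type headers, and the names here are illustrative.
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"
#include <cassert>

int main() {
  mlir::MLIRContext ctx;
  mlir::Type i32 = mlir::IntegerType::get(&ctx, 32);
  // tensor<i32> is a ranked tensor of rank 0.
  auto t0d = mlir::RankedTensorType::get({}, i32);
  assert(mlir::isa<mlir::ShapedType>(t0d));  // a 0-d tensor passes the guard
  assert(t0d.getRank() == 0);
  assert(!mlir::isa<mlir::ShapedType>(i32)); // a plain scalar does not
}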
On a side note, after matching DecomposeProjectedPermutation, it then fails to match any other pattern. Is this dependent on the pass being applied, or could it indicate that something has gone wrong?
It looks like after the initial rewrite into a new linalg.generic, the pattern keeps hitting further matches. However, it does not perform any rewrites, and the pattern application as a whole fails to converge (essentially a failure).
It might be unrelated to this PR's change but could be worth investigating further.
After the rewrite of the linalg.generic, these functions in PatternApplicator.cpp all return 0 for the remaining operations (arith.addi, linalg.broadcast, etc.): https://github.com/GrumpyPigSkin/llvm-project/blob/981d5a2b81cc9fe0db90410b5071cb8362ebf4c9/mlir/lib/Rewrite/PatternApplicator.cpp#L151C1-L153C48. I guess that just means --linalg-specialize-generic-ops doesn't register any patterns for those ops?
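If it helps, here is a rough sketch (an assumption about how the pass wires things up, not taken from its actual source) of a specialization-style driver: only the patterns explicitly populated into the set are ever candidates, so operations such as arith.addi have nothing to try against, and the "pattern failed to match" lines in the -debug log are expected rather than a sign of a problem. The populate call below is the name I would expect; treat it as hypothetical.

// Sketch only; populateDecomposeProjectedPermutationPatterns is assumed to be
// the hook used by -linalg-specialize-generic-ops and may not match the real pass.
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

static mlir::LogicalResult runSpecializationLike(mlir::Operation *root) {
  mlir::RewritePatternSet patterns(root->getContext());
  // Only the patterns added here are ever tried; ops without a matching
  // pattern simply report "pattern failed to match" under -debug.
  mlir::linalg::populateDecomposeProjectedPermutationPatterns(patterns);
  return mlir::applyPatternsAndFoldGreedily(root, std::move(patterns));
}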
@@ -159,6 +160,16 @@ LogicalResult DecomposeProjectedPermutation::matchAndRewrite(
auto map = op.getMatchingIndexingMap(&opOperand);
if (!map.isProjectedPermutation(false))
return failure();

// If we have any inputs that aren't memref or ranked tensor types, reject
There is only pure tensor semantics.
Okay, is ShapedType appropriate for the check, or should I use RankedTensorType?
Maybe RankedTensorType is more appropriate.
Changed to RankedTensorType :)
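For what it's worth, a small standalone sketch (illustration only, not from the PR) of the practical difference between the two checks: ShapedType also admits memrefs, vectors, and unranked tensors, while RankedTensorType accepts only ranked tensors, which lines up with the pure tensor semantics of this decomposition.

// Sketch only; the types constructed here are purely illustrative.
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"
#include <cassert>

int main() {
  mlir::MLIRContext ctx;
  mlir::Type f32 = mlir::Float32Type::get(&ctx);
  mlir::Type tensorTy = mlir::RankedTensorType::get({2, 2}, f32);
  mlir::Type memrefTy = mlir::MemRefType::get({2, 2}, f32);
  // Both satisfy the broader ShapedType check...
  assert(mlir::isa<mlir::ShapedType>(tensorTy));
  assert(mlir::isa<mlir::ShapedType>(memrefTy));
  // ...but only the tensor passes the narrower RankedTensorType check.
  assert(mlir::isa<mlir::RankedTensorType>(tensorTy));
  assert(!mlir::isa<mlir::RankedTensorType>(memrefTy));
}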
Thanks for putting out this test case to show the current limitation.
@@ -69,3 +69,33 @@ func.func @broadcast_only(%x : tensor<2x16x32xf32>, %y: tensor<2x32xf32>, %z :
// CHECK: %[[X_bc:.+]] = linalg.broadcast ins(%[[Y]] : tensor<2x32xf32>) outs(%[[E0]] : tensor<2x16x32xf32>) dimensions = [1]
// CHECK: {{.*}} = linalg.div ins(%[[X]], %[[X_bc]] : tensor<2x16x32xf32>, tensor<2x16x32xf32>) outs(%arg2 : tensor<2x16x32xf32>) -> tensor<2x16x32xf32>
// CHECK-NOT: linalg.generic

// -----
Add a comment? e.g.:
// unsupported currently.
Added comment :)
Fixes #122094.
This PR adds the check in DecomposeGenericByUnfoldingPermutation.cpp, as adding the check anywhere else was too general and would cause other valid test cases to fail.