[Mlir] decompose generic by unfolding projected permutation crash fix #122449

Open
wants to merge 9 commits into
base: main

GrumpyPigSkin
Contributor

@GrumpyPigSkin GrumpyPigSkin commented Jan 10, 2025

Fixes #122094.
This PR adds the check in DecomposeGenericByUnfoldingPermutation.cpp as adding the check anywhere else was too general and would cause other valid test cases to fail.

@llvmbot
Member

llvmbot commented Jan 10, 2025

@llvm/pr-subscribers-mlir-linalg

Author: None (GrumpyPigSkin)

Changes

Fixes #122094.

@CoTinker could you please review?

I added the check in DecomposeGenericByUnfoldingPermutation.cpp as adding the check anywhere else was too general and would cause other valid test cases to fail.


Full diff: https://github.com/llvm/llvm-project/pull/122449.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Linalg/Transforms/DecomposeGenericByUnfoldingPermutation.cpp (+8)
  • (added) mlir/test/Dialect/Linalg/decompose-generic-by-unfolding-projected-permutation-validate.mlir (+28)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DecomposeGenericByUnfoldingPermutation.cpp b/mlir/lib/Dialect/Linalg/Transforms/DecomposeGenericByUnfoldingPermutation.cpp
index 83c4b5bdf10976..ce1c21504f1dc7 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DecomposeGenericByUnfoldingPermutation.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DecomposeGenericByUnfoldingPermutation.cpp
@@ -159,6 +159,14 @@ LogicalResult DecomposeProjectedPermutation::matchAndRewrite(
     auto map = op.getMatchingIndexingMap(&opOperand);
     if (!map.isProjectedPermutation(false))
       return failure();
+
+    // If we have any inputs that aren't memref or ranked tensor types, reject the pattern.
+    if (!dyn_cast<ShapedType>(opOperand.get().getType()))
+      return op->emitError("Expected operand #")
+             << opOperand.getOperandNumber()
+             << " to be memref of any type values or ranked tensor of any type "
+                "values, but got "
+             << opOperand.get().getType();
   }
 
   // Decomposing linalg.generic involves creating `tensor.empty`
diff --git a/mlir/test/Dialect/Linalg/decompose-generic-by-unfolding-projected-permutation-validate.mlir b/mlir/test/Dialect/Linalg/decompose-generic-by-unfolding-projected-permutation-validate.mlir
new file mode 100644
index 00000000000000..43fdd17e10078c
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/decompose-generic-by-unfolding-projected-permutation-validate.mlir
@@ -0,0 +1,28 @@
+// RUN: mlir-opt %s -linalg-specialize-generic-ops -verify-diagnostics
+
+// Fixes issue: 122094. Verify that the following code causes an error to be produced.
+
+func.func @test_broadcast_scalar_across_single_tensor() -> tensor<2x2xi32> {
+
+  %a = arith.constant dense<2> : tensor<2x2xi32>
+  %b = arith.constant 42 : i32
+  %c = tensor.empty() : tensor<2x2xi32>
+  // expected-error @+1 {{Expected operand #1 to be memref of any type values or ranked tensor of any type values, but got 'i32'}}
+  %res = linalg.generic
+    {
+      indexing_maps = [
+        affine_map<(i, j) -> (i, j)>, 
+        affine_map<(i, j) -> ()>,     
+        affine_map<(i, j) -> (i, j)>  
+      ],
+      iterator_types = ["parallel", "parallel"]
+    }
+    ins(%a, %b : tensor<2x2xi32>, i32)
+    outs(%c : tensor<2x2xi32>) {
+  ^bb0(%x: i32, %scalar: i32, %out: i32):
+    %sum = arith.addi %x, %scalar : i32
+    linalg.yield %sum : i32
+  } -> tensor<2x2xi32>
+
+  return %res : tensor<2x2xi32>
+}
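
As a rough illustration of the guard added in the diff above, here is a minimal, self-contained C++ sketch. It does not use the MLIR API: `Kind`, `Type`, `isShaped`, `getShape` and `firstUnshapedOperand` are hypothetical stand-ins. `isShaped` plays the role of the `dyn_cast<ShapedType>` test, and the assertion in `getShape` plays the role of an unchecked cast, which is assumed here to be the kind of failure reported in the linked issue.

```cpp
#include <cassert>
#include <vector>

// Hypothetical stand-ins for operand types; this is not the mlir::Type API.
enum class Kind { Scalar, RankedTensor, MemRef };

struct Type {
  Kind kind;
  std::vector<long> shape; // only meaningful for shaped kinds
};

// Plays the role of the dyn_cast<ShapedType> test: scalars have no shape.
bool isShaped(const Type &t) { return t.kind != Kind::Scalar; }

// Plays the role of an unchecked cast: calling it on a scalar trips the
// assertion instead of failing gracefully.
const std::vector<long> &getShape(const Type &t) {
  assert(isShaped(t) && "shape requested from a non-shaped type");
  return t.shape;
}

// The guard: report the index of the first operand that is not shaped,
// or -1 if every operand is safe to pass to getShape.
int firstUnshapedOperand(const std::vector<Type> &operands) {
  for (int i = 0; i < static_cast<int>(operands.size()); ++i)
    if (!isShaped(operands[i]))
      return i;
  return -1;
}
```

With operands modelled on the test case in the diff (a 2x2 tensor, a scalar, a 2x2 tensor), `firstUnshapedOperand` reports index 1, which corresponds to the "operand #1" in the expected diagnostic.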

@llvmbot
Member

llvmbot commented Jan 10, 2025

@llvm/pr-subscribers-mlir


Contributor

@adam-smnk adam-smnk left a comment

Looks pretty good. Just two comments about the test formatting.
Otherwise, we can just let some of the original issue reporters have a look too.


// If we have any inputs that aren't memref or ranked tensor types, reject
// the pattern.
if (!dyn_cast<ShapedType>(opOperand.get().getType()))
Contributor

On second thought, what about a 0D shape?
Can it handle input like %b = arith.constant dense<42> : tensor<i32>?

Contributor Author

Yes it matches 0D shapes :)

//===-------------------------------------------===//
Processing operation : 'linalg.generic'(0x55556870a890) {

  * Pattern mlir::linalg::LinalgSpecializationPattern : 'linalg.generic -> ()' {
Trying to match "mlir::linalg::LinalgSpecializationPattern"
"mlir::linalg::LinalgSpecializationPattern" result 0
  } -> failure : pattern failed to match

  * Pattern (anonymous namespace)::DecomposeProjectedPermutation : 'linalg.generic -> ()' {
Trying to match "(anonymous namespace)::DecomposeProjectedPermutation"
"(anonymous namespace)::DecomposeProjectedPermutation" result 1
  } -> success : pattern applied successfully
// *** IR Dump After Pattern Application ***
func.func @test_broadcast_single_tensor() -> tensor<2x2xi32> {
  %cst = arith.constant dense<2> : tensor<2x2xi32>
  %cst_0 = arith.constant dense<42> : tensor<i32>
  %0 = tensor.empty() : tensor<2x2xi32>
  %1 = tensor.empty() : tensor<2x2xi32>
  %broadcasted = linalg.broadcast ins(%cst_0 : tensor<i32>) outs(%1 : tensor<2x2xi32>) dimensions = [0, 1] 
  %2 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%cst, %broadcasted : tensor<2x2xi32>, tensor<2x2xi32>) outs(%0 : tensor<2x2xi32>) {
  ^bb0(%in: i32, %in_1: i32, %out: i32):
    %3 = arith.addi %in, %in_1 : i32
    linalg.yield %3 : i32
  } -> tensor<2x2xi32>
  return %2 : tensor<2x2xi32>
}


} -> success : pattern matched
//===-------------------------------------------===//

//===-------------------------------------------===//
Processing operation : 'arith.addi'(0x555568783370) {
  %6 = "arith.addi"(%arg0, %arg1) <{overflowFlags = #arith.overflow<none>}> : (i32, i32) -> i32

} -> failure : pattern failed to match

On a side note, after matching DecomposeProjectedPermutation, it then fails to match any other pattern. Is this dependent on the pass being applied, or could it indicate that something has gone wrong?

Contributor

It looks like after the initial rewrite into a new linalg.generic, the pattern keeps hitting further matches. However, it does not perform any rewrites and the whole pattern fails to converge (essentially a failure).

It might be unrelated to this PR's change but could be worth investigating further.
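
The failure mode described here (matches keep being reported but nothing is rewritten, so the run never settles) can be sketched with a toy fixpoint loop in plain C++. Everything below is hypothetical: `ToyIR`, `ToyPattern`, `runToFixpoint` and the two example patterns are illustrative names, and the iteration cap is an assumption of the sketch rather than a description of the MLIR greedy driver.

```cpp
#include <cassert>
#include <functional>
#include <string>
#include <vector>

// Toy IR: just a list of operation names. Not MLIR's data structures.
using ToyIR = std::vector<std::string>;

// A pattern returns true when it reports a successful match-and-rewrite.
using ToyPattern = std::function<bool(ToyIR &)>;

// Re-run the pattern until it stops reporting success, or give up after
// maxIterations attempts. Returns true only if a fixpoint was reached.
bool runToFixpoint(ToyIR &ir, const ToyPattern &pattern, int maxIterations) {
  assert(maxIterations > 0);
  for (int i = 0; i < maxIterations; ++i)
    if (!pattern(ir))
      return true; // nothing left to rewrite: converged
  return false;    // still reporting success at the cap: did not converge
}

// Well-behaved: rewrites one "generic" and reports success only when the
// IR actually changed.
bool rewriteGeneric(ToyIR &ir) {
  for (auto &op : ir)
    if (op == "generic") {
      op = "broadcast";
      return true;
    }
  return false;
}

// Misbehaving: reports success whenever it sees "generic" but never
// changes the IR, so the same match is found again on every attempt.
bool matchWithoutRewrite(ToyIR &ir) {
  for (const auto &op : ir)
    if (op == "generic")
      return true;
  return false;
}
```

In this sketch `rewriteGeneric` converges after one rewrite, while `matchWithoutRewrite` exhausts the cap because each attempt sees the same unchanged operation.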

Contributor Author

After linalg.generic is rewritten, for all the other operations (arith.addi, linalg.broadcast, etc.) these functions all return 0 in PatternApplicator.cpp: https://github.com/GrumpyPigSkin/llvm-project/blob/981d5a2b81cc9fe0db90410b5071cb8362ebf4c9/mlir/lib/Rewrite/PatternApplicator.cpp#L151C1-L153C48. I guess that just means --linalg-specialize-generic-ops doesn't load any of these passes?

@@ -159,6 +160,16 @@ LogicalResult DecomposeProjectedPermutation::matchAndRewrite(
     auto map = op.getMatchingIndexingMap(&opOperand);
     if (!map.isProjectedPermutation(false))
       return failure();

     // If we have any inputs that aren't memref or ranked tensor types, reject
Contributor

Only pure tensor semantics applies here.

Contributor Author

Okay, is ShapedType appropriate for the check, or should I use RankedTensorType?

Contributor

Maybe RankedTensorType is more appropriate.

Contributor Author

Changed to RankedTensorType :)
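
The difference between the two candidate checks can be sketched without MLIR. In MLIR, `ShapedType` is an interface implemented by ranked and unranked tensors, memrefs and vectors, whereas `RankedTensorType` is one concrete type. The `TypeKind` enum and the helper names below are hypothetical stand-ins for illustration only, not the code that landed in the PR.

```cpp
#include <cassert>

// Hypothetical stand-in for the kinds of operand types discussed in this thread.
enum class TypeKind { Integer, RankedTensor, UnrankedTensor, MemRef, Vector };

// Models a ShapedType check: every kind except the scalar integer passes.
constexpr bool isShapedType(TypeKind k) { return k != TypeKind::Integer; }

// Models a RankedTensorType check: only ranked tensors pass.
constexpr bool isRankedTensor(TypeKind k) {
  return k == TypeKind::RankedTensor;
}

// Spot-checks the cases raised in review.
inline void checkReviewCases() {
  // The i32 operand from the test case is rejected by both checks.
  assert(!isShapedType(TypeKind::Integer));
  assert(!isRankedTensor(TypeKind::Integer));
  // tensor<i32> (0-D) and tensor<2x2xi32> are both ranked tensors.
  assert(isRankedTensor(TypeKind::RankedTensor));
  // A memref passes the broader check but not the narrower one.
  assert(isShapedType(TypeKind::MemRef) && !isRankedTensor(TypeKind::MemRef));
}
```

The narrower check only changes the outcome for shaped types that are not ranked tensors, which fits the earlier remark that only tensor semantics is relevant for this pattern.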

Contributor

@javedabsar1 javedabsar1 left a comment

Thanks for putting together this test case to show the current limitation.

@@ -69,3 +69,33 @@ func.func @broadcast_only(%x : tensor<2x16x32xf32>, %y: tensor<2x32xf32>, %z :
// CHECK: %[[X_bc:.+]] = linalg.broadcast ins(%[[Y]] : tensor<2x32xf32>) outs(%[[E0]] : tensor<2x16x32xf32>) dimensions = [1]
// CHECK: {{.*}} = linalg.div ins(%[[X]], %[[X_bc]] : tensor<2x16x32xf32>, tensor<2x16x32xf32>) outs(%arg2 : tensor<2x16x32xf32>) -> tensor<2x16x32xf32>
// CHECK-NOT: linalg.generic

// -----

Contributor

Add a comment? For example:
// unsupported currently.

Contributor Author

Added comment :)

Successfully merging this pull request may close these issues.

[Mlir] --linalg-specialize-generic-ops crashes in Casting.h:566
5 participants