Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mlir][vector] Constrain patterns: vector.contract -> vector.outerproduct #68400

Closed

Conversation

banach-space
Copy link
Contributor

This patch constrains the patterns for converting vector.contract to
vector.outerproduct so that

  • the reduction dimension is not unrolled if the corresponding
    dimension is scalable.

This is necessary as the current lowering is incorrect for scalable
dims. Indeed, instead of the following unrolling that is currently being
generated for vector.contract (K is the size f the reduction
dimension):

  // K times
  %lhs = vector.extract %LHS[0]
  %rhs = vector.extract %RHS[0]
  vector.outerproduct %lhs, %rhs

  %lhs = vector.extract %LHS[1]
  %rhs = vector.extract %RHS[1]
  vector.outerproduct %lhs, %rhs

  ...

we should be generating a for loop like the following:

scf.for %k = 0 to K step 1
  %lhs = vector.extract LHS[%k]
  %rhs = vector.extract RHS[%k]
  vector.outerproduct %lhs, %rhs

However, the lowering of vector.extract of vector slices with dynamic
indices is incomplete and hence the implementation above wouldn't work
just yet. Instead, this patch effectively disables unrolling in case
where the generated code would be functionally incorrect (i.e. when the
reduction dimension is scalable).

In order to document unsupported cases, a dedicated test file is added:

  • "vector-contract-to-outerproduct-transforms-unsupported.mlir"

This is the first patch in a series of patches that strives to update
these patterns (and to test them) for scalable vectors.

…duct

This patch constrains the patterns for converting `vector.contract` to
`vector.outerproduct` so that

  * the reduction dimension is _not unrolled_ if the corresponding
    dimension is scalable.

This is necessary as the current lowering is incorrect for scalable
dims. Indeed, instead of the following unrolling that is currently being
generated for `vector.contract` (K is the size  f the reduction
dimension):

```
  // K times
  %lhs = vector.extract %LHS[0]
  %rhs = vector.extract %RHS[0]
  vector.outerproduct %lhs, %rhs

  %lhs = vector.extract %LHS[1]
  %rhs = vector.extract %RHS[1]
  vector.outerproduct %lhs, %rhs

  ...
```

we should be generating a `for` loop like the following:
```
scf.for %k = 0 to K step 1
  %lhs = vector.extract LHS[%k]
  %rhs = vector.extract RHS[%k]
  vector.outerproduct %lhs, %rhs
```

However, the lowering of `vector.extract` of vector slices with dynamic
indices is incomplete and hence the implementation above wouldn't work
just yet. Instead, this patch effectively disables unrolling in case
where the generated code would be functionally incorrect (i.e. when the
reduction dimension is scalable).

In order to document unsupported cases, a dedicated test file is added:

* "vector-contract-to-outerproduct-transforms-unsupported.mlir"

This is the first patch in a series of patches that strives to update
these patterns (and to test them) for scalable vectors.
@llvmbot
Copy link
Member

llvmbot commented Oct 6, 2023

@llvm/pr-subscribers-mlir-vector

@llvm/pr-subscribers-mlir

Changes

This patch constrains the patterns for converting vector.contract to
vector.outerproduct so that

  • the reduction dimension is not unrolled if the corresponding
    dimension is scalable.

This is necessary as the current lowering is incorrect for scalable
dims. Indeed, instead of the following unrolling that is currently being
generated for vector.contract (K is the size f the reduction
dimension):

  // K times
  %lhs = vector.extract %LHS[0]
  %rhs = vector.extract %RHS[0]
  vector.outerproduct %lhs, %rhs

  %lhs = vector.extract %LHS[1]
  %rhs = vector.extract %RHS[1]
  vector.outerproduct %lhs, %rhs

  ...

we should be generating a for loop like the following:

scf.for %k = 0 to K step 1
  %lhs = vector.extract LHS[%k]
  %rhs = vector.extract RHS[%k]
  vector.outerproduct %lhs, %rhs

However, the lowering of vector.extract of vector slices with dynamic
indices is incomplete and hence the implementation above wouldn't work
just yet. Instead, this patch effectively disables unrolling in case
where the generated code would be functionally incorrect (i.e. when the
reduction dimension is scalable).

In order to document unsupported cases, a dedicated test file is added:

  • "vector-contract-to-outerproduct-transforms-unsupported.mlir"

This is the first patch in a series of patches that strives to update
these patterns (and to test them) for scalable vectors.


Full diff: https://github.com/llvm/llvm-project/pull/68400.diff

3 Files Affected:

  • (modified) mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp (+39-18)
  • (added) mlir/test/Dialect/Vector/vector-contract-to-outerproduct-transforms-unsupported.mlir (+33)
  • (modified) mlir/test/Dialect/Vector/vector-contract-to-outerproduct-transforms.mlir (+30-23)
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
index 04d9ddf2183f8c5..c1cc0d7c64de264 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
@@ -424,9 +424,14 @@ struct UnrolledOuterProductGenerator
     return rewriter.create<arith::ExtSIOp>(loc, promotedType, v);
   }
 
-  FailureOr<Value> outerProd(Value lhs, Value rhs, Value res, int reductionSize,
+  FailureOr<Value> outerProd(Value lhs, Value rhs, Value res, VectorType lhsType,
+                             int reductionDim,
                              std::optional<Value> maybeMask = std::nullopt) {
-    assert(reductionSize > 0);
+    // Unrolling a scalable dimension would be incorrect - bail out.
+    if (lhsType.getScalableDims()[reductionDim])
+      return failure();
+
+    int reductionSize = lhsType.getDimSize(reductionDim);
     // Incremental support for masking.
     if (mask && !maybeMask.has_value())
       return failure();
@@ -459,33 +464,39 @@ struct UnrolledOuterProductGenerator
     Value transposedMask = t(mask, {2, 0, 1});
     // Classical row-major matmul:  Just permute the lhs.
     if (layout({{m, k}, {k, n}, {m, n}}))
-      return outerProd(t(lhs), rhs, res, lhsType.getDimSize(1), transposedMask);
+      return outerProd(t(lhs), rhs, res, lhsType, /*reductionDim=*/1,
+                       transposedMask);
     // TODO: may be better to fail and use some vector<k> -> scalar reduction.
     if (layout({{m, k}, {n, k}, {m, n}})) {
       Value tlhs = t(lhs);
-      return outerProd(tlhs, t(rhs), res, lhsType.getDimSize(1),
+      return outerProd(tlhs, t(rhs), res, lhsType, /*reductionDim=*/1,
                        transposedMask);
     }
     // No need to permute anything.
     if (layout({{k, m}, {k, n}, {m, n}}))
-      return outerProd(lhs, rhs, res, lhsType.getDimSize(0), transposedMask);
+      return outerProd(lhs, rhs, res, lhsType, /*reductionDim=*/0,
+                       transposedMask);
     // Just permute the rhs.
     if (layout({{k, m}, {n, k}, {m, n}}))
-      return outerProd(lhs, t(rhs), res, lhsType.getDimSize(0), transposedMask);
+      return outerProd(lhs, t(rhs), res, lhsType, /*reductionDim=*/0,
+                       transposedMask);
     // Transposed output: swap RHS and LHS.
     // Classical row-major matmul: permute the lhs.
     if (layout({{m, k}, {k, n}, {n, m}}))
-      return outerProd(rhs, t(lhs), res, lhsType.getDimSize(1), transposedMask);
+      return outerProd(rhs, t(lhs), res, lhsType, /*reductionDim=*/1,
+                       transposedMask);
     // TODO: may be better to fail and use some vector<k> -> scalar reduction.
     if (layout({{m, k}, {n, k}, {n, m}})) {
       Value trhs = t(rhs);
-      return outerProd(trhs, t(lhs), res, lhsType.getDimSize(1),
+      return outerProd(trhs, t(lhs), res, lhsType, /*reductionDim=*/1,
                        transposedMask);
     }
     if (layout({{k, m}, {k, n}, {n, m}}))
-      return outerProd(rhs, lhs, res, lhsType.getDimSize(0), transposedMask);
+      return outerProd(rhs, lhs, res, lhsType, /*reductionDim=*/0,
+                       transposedMask);
     if (layout({{k, m}, {n, k}, {n, m}}))
-      return outerProd(t(rhs), lhs, res, lhsType.getDimSize(0), transposedMask);
+      return outerProd(t(rhs), lhs, res, lhsType, /*reductionDim=*/0,
+                       transposedMask);
     return failure();
   }
 
@@ -503,16 +514,20 @@ struct UnrolledOuterProductGenerator
 
     // Case mat-vec: transpose.
     if (layout({{m, k}, {k}, {m}}))
-      return outerProd(t(lhs), rhs, res, lhsType.getDimSize(1), transposedMask);
+      return outerProd(t(lhs), rhs, res, lhsType, /*reductionDim=*/1,
+                       transposedMask);
     // Case mat-trans-vec: ready to go.
     if (layout({{k, m}, {k}, {m}}))
-      return outerProd(lhs, rhs, res, lhsType.getDimSize(0), transposedMask);
+      return outerProd(lhs, rhs, res, lhsType, /*reductionDim=*/0,
+                       transposedMask);
     // Case vec-mat: swap and transpose.
     if (layout({{k}, {m, k}, {m}}))
-      return outerProd(t(rhs), lhs, res, lhsType.getDimSize(0), transposedMask);
+      return outerProd(t(rhs), lhs, res, lhsType, /*reductionDim=*/0,
+                       transposedMask);
     // Case vec-mat-trans: swap and ready to go.
     if (layout({{k}, {k, m}, {m}}))
-      return outerProd(rhs, lhs, res, lhsType.getDimSize(0), transposedMask);
+      return outerProd(rhs, lhs, res, lhsType, /*reductionDim=*/0,
+                       transposedMask);
     return failure();
   }
 
@@ -528,16 +543,16 @@ struct UnrolledOuterProductGenerator
 
     // Case mat-vec: transpose.
     if (layout({{m, k}, {k}, {m}}))
-      return outerProd(t(lhs), rhs, res, lhsType.getDimSize(1), mask);
+      return outerProd(t(lhs), rhs, res, lhsType, /*reductionDim=*/1, mask);
     // Case mat-trans-vec: ready to go.
     if (layout({{k, m}, {k}, {m}}))
-      return outerProd(lhs, rhs, res, lhsType.getDimSize(0), mask);
+      return outerProd(lhs, rhs, res, lhsType, /*reductionDim=*/0, mask);
     // Case vec-mat: swap and transpose.
     if (layout({{k}, {m, k}, {m}}))
-      return outerProd(t(rhs), lhs, res, lhsType.getDimSize(0), mask);
+      return outerProd(t(rhs), lhs, res, lhsType, /*reductionDim=*/0, mask);
     // Case vec-mat-trans: swap and ready to go.
     if (layout({{k}, {k, m}, {m}}))
-      return outerProd(rhs, lhs, res, lhsType.getDimSize(0), mask);
+      return outerProd(rhs, lhs, res, lhsType, /*reductionDim=*/0, mask);
     return failure();
   }
 
@@ -980,9 +995,15 @@ FailureOr<Value> ContractionOpLowering::lowerParallel(PatternRewriter &rewriter,
         diag << "expected lhsIndex=" << lhsIndex << " and rhsIndex=" << rhsIndex
              << " to map to the same dimension";
       });
+    // Unrolling a scalable dimension would be incorrect - bail out.
+    if (lhsType.getScalableDims()[lhsIndex])
+      return failure();
     dimSize = lhsType.getDimSize(lhsIndex);
   } else if (rhsIndex >= 0) {
     iterIndex = iMap[1].getDimPosition(rhsIndex);
+    // Unrolling a scalable dimension would be incorrect - bail out.
+    if (rhsType.getScalableDims()[rhsIndex])
+      return failure();
     dimSize = rhsType.getDimSize(rhsIndex);
   }
   if (iterIndex < 0)
diff --git a/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-transforms-unsupported.mlir b/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-transforms-unsupported.mlir
new file mode 100644
index 000000000000000..a955250107d73d7
--- /dev/null
+++ b/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-transforms-unsupported.mlir
@@ -0,0 +1,33 @@
+// RUN: mlir-opt %s --test-transform-dialect-interpreter -allow-unregistered-dialect --split-input-file --verify-diagnostics
+
+#matvec_accesses = [
+  affine_map<(i, j) -> (i, j)>,
+  affine_map<(i, j) -> (j)>,
+  affine_map<(i, j) -> (i)>
+]
+#matvec_trait = {
+  indexing_maps = #matvec_accesses,
+  iterator_types = ["parallel", "reduction"]
+}
+
+// Unrolling scalable reduction dim is not supported - bail out
+
+// expected-error@below {{greedy pattern application failed}}
+func.func @masked_extract_contract2_scalable_reduction_dim(%arg0: vector<[2]x[3]xf32>,
+                                    %arg1: vector<[3]xf32>,
+                                    %arg2: vector<[2]xf32>,
+                                    %m: vector<[2]x[3]xi1>) -> vector<[2]xf32> {
+  %0 = vector.mask %m { vector.contract #matvec_trait %arg0, %arg1, %arg2
+          : vector<[2]x[3]xf32>, vector<[3]xf32> into vector<[2]xf32> } : vector<[2]x[3]xi1> -> vector<[2]xf32>
+  return %0 : vector<[2]xf32>
+}
+
+transform.sequence failures(propagate) {
+^bb1(%module_op: !transform.any_op):
+  %f = transform.structured.match ops{["func.func"]} in %module_op 
+    : (!transform.any_op) -> !transform.any_op
+
+  transform.apply_patterns to %f {
+    transform.apply_patterns.vector.lower_contraction lowering_strategy = "outerproduct"
+  } : !transform.any_op
+}
diff --git a/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-transforms.mlir
index deea7747f36799c..8ee0a35717ce87a 100644
--- a/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-transforms.mlir
@@ -31,19 +31,19 @@
 }
 
 // CHECK-LABEL:   func.func @masked_extract_contract2(
-// CHECK-SAME:                                      %[[VAL_0:.*]]: vector<2x3xf32>,
-// CHECK-SAME:                                      %[[VAL_1:.*]]: vector<3xf32>,
-// CHECK-SAME:                                      %[[VAL_2:.*]]: vector<2xf32>,
-// CHECK-SAME:                                      %[[IN_MASK:.*]]: vector<2x3xi1>) -> vector<2xf32>
+// CHECK-SAME:      %{{.*}}: vector<2x3xf32>,
+// CHECK-SAME:      %{{.*}}: vector<3xf32>,
+// CHECK-SAME:      %{{.*}}: vector<2xf32>,
+// CHECK-SAME:      %[[IN_MASK:.*]]: vector<2x3xi1>) -> vector<2xf32>
 // CHECK:           %[[T_MASK:.*]] = vector.transpose %[[IN_MASK]], [1, 0] : vector<2x3xi1> to vector<3x2xi1>
 // CHECK:           %[[MASK0:.*]] = vector.extract %[[T_MASK]][0] : vector<2xi1> from vector<3x2xi1>
-// CHECK:           vector.mask %[[MASK0]] { vector.outerproduct
+// CHECK:           vector.mask %[[MASK0]] { vector.outerproduct {{.*}} {kind = #vector.kind<add>} : vector<2xf32>, f32 } : vector<2xi1> -> vector<2xf32>
 
 // CHECK:           %[[MASK1:.*]] = vector.extract %[[T_MASK]][1] : vector<2xi1> from vector<3x2xi1>
-// CHECK:           vector.mask %[[MASK1]] { vector.outerproduct
+// CHECK:           vector.mask %[[MASK1]] { vector.outerproduct {{.*}} {kind = #vector.kind<add>} : vector<2xf32>, f32 } : vector<2xi1> -> vector<2xf32>
 
 // CHECK:           %[[MASK2:.*]] = vector.extract %[[T_MASK]][2] : vector<2xi1> from vector<3x2xi1>
-// CHECK:           vector.mask %[[MASK2]] { vector.outerproduct
+// CHECK:           vector.mask %[[MASK2]] { vector.outerproduct {{.*}} {kind = #vector.kind<add>} : vector<2xf32>, f32 } : vector<2xi1> -> vector<2xf32>
 
 func.func @masked_extract_contract2(%arg0: vector<2x3xf32>,
                                     %arg1: vector<3xf32>,
@@ -54,22 +54,29 @@ func.func @masked_extract_contract2(%arg0: vector<2x3xf32>,
   return %0 : vector<2xf32>
 }
 
-// CHECK-LABEL: func.func @masked_extract_contract4(
-// CHECK-SAME:                                      %[[VAL_0:.*]]: vector<3x5xf32>,
-// CHECK-SAME:                                      %[[VAL_1:.*]]: vector<5x7xf32>,
-// CHECK-SAME:                                      %[[VAL_2:.*]]: vector<3x7xf32>,
-// CHECK-SAME:                                      %[[VAL_3:.*]]: vector<3x7x5xi1>) -> vector<3x7xf32> {
-// CHECK:         %[[VAL_5:.*]] = vector.transpose %[[VAL_3]], [2, 0, 1] : vector<3x7x5xi1> to vector<5x3x7xi1>
-// CHECK:         %[[VAL_8:.*]] = vector.extract %[[VAL_5]][0] : vector<3x7xi1> from vector<5x3x7xi1>
-// CHECK:         %[[VAL_9:.*]] = vector.mask %[[VAL_8]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32>
-// CHECK:         %[[VAL_12:.*]] = vector.extract %[[VAL_5]][1] : vector<3x7xi1> from vector<5x3x7xi1>
-// CHECK:         %[[VAL_13:.*]] = vector.mask %[[VAL_12]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32>
-// CHECK:         %[[VAL_16:.*]] = vector.extract %[[VAL_5]][2] : vector<3x7xi1> from vector<5x3x7xi1>
-// CHECK:         %[[VAL_17:.*]] = vector.mask %[[VAL_16]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32>
-// CHECK:         %[[VAL_20:.*]] = vector.extract %[[VAL_5]][3] : vector<3x7xi1> from vector<5x3x7xi1>
-// CHECK:         %[[VAL_21:.*]] = vector.mask %[[VAL_20]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32>
-// CHECK:         %[[VAL_24:.*]] = vector.extract %[[VAL_5]][4] : vector<3x7xi1> from vector<5x3x7xi1>
-// CHECK:         %[[VAL_25:.*]] = vector.mask %[[VAL_24]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32>
+
+// CHECK-LABEL:   func.func @masked_extract_contract2_scalable_parallel_dim(
+// CHECK-SAME:      %{{.*}}: vector<[2]x3xf32>,
+// CHECK-SAME:      %{{.*}}: vector<3xf32>,
+// CHECK-SAME:      %{{.*}}: vector<[2]xf32>,
+// CHECK-SAME:      %[[IN_MASK:.*]]: vector<[2]x3xi1>) -> vector<[2]xf32>
+// CHECK:           %[[T_MASK:.*]] = vector.transpose %[[IN_MASK]], [1, 0] : vector<[2]x3xi1> to vector<3x[2]xi1>
+// CHECK:           %[[MASK0:.*]] = vector.extract %[[T_MASK]][0] : vector<[2]xi1> from vector<3x[2]xi1>
+// CHECK:           vector.mask %[[MASK0]] { vector.outerproduct {{.*}} {kind = #vector.kind<add>} : vector<[2]xf32>, f32 } : vector<[2]xi1> -> vector<[2]xf32>
+
+// CHECK:           %[[MASK1:.*]] = vector.extract %[[T_MASK]][1] : vector<[2]xi1> from vector<3x[2]xi1>
+// CHECK:           vector.mask %[[MASK1]] { vector.outerproduct {{.*}} {kind = #vector.kind<add>} : vector<[2]xf32>, f32 } : vector<[2]xi1> -> vector<[2]xf32>
+
+// CHECK:           %[[MASK2:.*]] = vector.extract %[[T_MASK]][2] : vector<[2]xi1> from vector<3x[2]xi1>
+// CHECK:           vector.mask %[[MASK2]] { vector.outerproduct {{.*}} {kind = #vector.kind<add>} : vector<[2]xf32>, f32 } : vector<[2]xi1> -> vector<[2]xf32>
+func.func @masked_extract_contract2_scalable_parallel_dim(%arg0: vector<[2]x3xf32>,
+                                    %arg1: vector<3xf32>,
+                                    %arg2: vector<[2]xf32>,
+                                    %m: vector<[2]x3xi1>) -> vector<[2]xf32> {
+  %0 = vector.mask %m { vector.contract #matvec_trait %arg0, %arg1, %arg2
+          : vector<[2]x3xf32>, vector<3xf32> into vector<[2]xf32> } : vector<[2]x3xi1> -> vector<[2]xf32>
+  return %0 : vector<[2]xf32>
+}
 
 func.func @masked_extract_contract4(%arg0: vector<3x5xf32>,
                                     %arg1: vector<5x7xf32>,

@github-actions
Copy link

github-actions bot commented Oct 6, 2023

⚠️ C/C++ code formatter, clang-format found issues in your code. ⚠️

You can test this locally with the following command:
git-clang-format --diff ff843c00ce1df5af29d5dae671086b92dcabf94b bb27634cbea6803cb3d8e1b15e3a85204f9d40c4 -- mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
View the diff from clang-format here.
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
index f0bc7ec6f723..311e589547b8 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
@@ -424,8 +424,8 @@ struct UnrolledOuterProductGenerator
     return rewriter.create<arith::ExtSIOp>(loc, promotedType, v);
   }
 
-  FailureOr<Value> outerProd(Value lhs, Value rhs, Value res, VectorType lhsType,
-                             int reductionDim,
+  FailureOr<Value> outerProd(Value lhs, Value rhs, Value res,
+                             VectorType lhsType, int reductionDim,
                              std::optional<Value> maybeMask = std::nullopt) {
     // Unrolling a scalable dimension would be incorrect - bail out.
     if (lhsType.getScalableDims()[reductionDim])

@@ -980,9 +995,15 @@ FailureOr<Value> ContractionOpLowering::lowerParallel(PatternRewriter &rewriter,
diag << "expected lhsIndex=" << lhsIndex << " and rhsIndex=" << rhsIndex
<< " to map to the same dimension";
});
// Unrolling a scalable dimension would be incorrect - bail out.
if (lhsType.getScalableDims()[lhsIndex])
return failure();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we notify with a message here and below ?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, will do :)

Copy link
Member

@MacDue MacDue left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, just a small question:

@@ -54,22 +54,29 @@ func.func @masked_extract_contract2(%arg0: vector<2x3xf32>,
return %0 : vector<2xf32>
}

// CHECK-LABEL: func.func @masked_extract_contract4(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why was this fixed size test removed?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy and paste failure, sorry :( Thanks for pointing this out!

Copy link
Contributor Author

@banach-space banach-space left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the reviews, I will send an update shortly and then merge it.

@@ -980,9 +995,15 @@ FailureOr<Value> ContractionOpLowering::lowerParallel(PatternRewriter &rewriter,
diag << "expected lhsIndex=" << lhsIndex << " and rhsIndex=" << rhsIndex
<< " to map to the same dimension";
});
// Unrolling a scalable dimension would be incorrect - bail out.
if (lhsType.getScalableDims()[lhsIndex])
return failure();
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, will do :)

@@ -54,22 +54,29 @@ func.func @masked_extract_contract2(%arg0: vector<2x3xf32>,
return %0 : vector<2xf32>
}

// CHECK-LABEL: func.func @masked_extract_contract4(
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy and paste failure, sorry :( Thanks for pointing this out!

…uterproduct

* Restore CHECK lines that were removed accidentally;
* Add diag messages
@banach-space banach-space deleted the andrzej/constrain_contract branch October 6, 2023 16:09
std::optional<Value> maybeMask = std::nullopt) {
assert(reductionSize > 0);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why was this assert removed?

@@ -980,9 +995,19 @@ FailureOr<Value> ContractionOpLowering::lowerParallel(PatternRewriter &rewriter,
diag << "expected lhsIndex=" << lhsIndex << " and rhsIndex=" << rhsIndex
<< " to map to the same dimension";
});
if (lhsType.getScalableDims()[lhsIndex])
return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
diag << "Unrolloing scalable dimension (lhsIndex=" << lhsIndex
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

s/Unrolloing/Unrolling

dimSize = lhsType.getDimSize(lhsIndex);
} else if (rhsIndex >= 0) {
iterIndex = iMap[1].getDimPosition(rhsIndex);
if (rhsType.getScalableDims()[rhsIndex])
return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
diag << "Unrolloing scalable dimension (lhsIndex=" << lhsIndex
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this should be rhsIndex?

banach-space added a commit to banach-space/llvm-project that referenced this pull request Oct 9, 2023
Follow-up for llvm#68400 - restoring an assert that was accidentally removed
and fixed a typo in a diagnostic.
@banach-space
Copy link
Contributor Author

@c-rhodes Thanks for the review 🙏🏻 . As this has already been merged (i.e. before you left your comments), I've uploaded a follow-up: #68581.

banach-space added a commit that referenced this pull request Oct 9, 2023
Follow-up for #68400 - restoring an assert that was accidentally removed
and fixed a typo in a diagnostic.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants