-
Notifications
You must be signed in to change notification settings - Fork 12.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][vector] Constrain patterns: vector.contract -> vector.outerproduct #68400
Conversation
…duct This patch constrains the patterns for converting `vector.contract` to `vector.outerproduct` so that * the reduction dimension is _not unrolled_ if the corresponding dimension is scalable. This is necessary as the current lowering is incorrect for scalable dims. Indeed, instead of the following unrolling that is currently being generated for `vector.contract` (K is the size f the reduction dimension): ``` // K times %lhs = vector.extract %LHS[0] %rhs = vector.extract %RHS[0] vector.outerproduct %lhs, %rhs %lhs = vector.extract %LHS[1] %rhs = vector.extract %RHS[1] vector.outerproduct %lhs, %rhs ... ``` we should be generating a `for` loop like the following: ``` scf.for %k = 0 to K step 1 %lhs = vector.extract LHS[%k] %rhs = vector.extract RHS[%k] vector.outerproduct %lhs, %rhs ``` However, the lowering of `vector.extract` of vector slices with dynamic indices is incomplete and hence the implementation above wouldn't work just yet. Instead, this patch effectively disables unrolling in case where the generated code would be functionally incorrect (i.e. when the reduction dimension is scalable). In order to document unsupported cases, a dedicated test file is added: * "vector-contract-to-outerproduct-transforms-unsupported.mlir" This is the first patch in a series of patches that strives to update these patterns (and to test them) for scalable vectors.
@llvm/pr-subscribers-mlir-vector @llvm/pr-subscribers-mlir ChangesThis patch constrains the patterns for converting
This is necessary as the current lowering is incorrect for scalable
we should be generating a
However, the lowering of In order to document unsupported cases, a dedicated test file is added:
This is the first patch in a series of patches that strives to update Full diff: https://github.com/llvm/llvm-project/pull/68400.diff 3 Files Affected:
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
index 04d9ddf2183f8c5..c1cc0d7c64de264 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
@@ -424,9 +424,14 @@ struct UnrolledOuterProductGenerator
return rewriter.create<arith::ExtSIOp>(loc, promotedType, v);
}
- FailureOr<Value> outerProd(Value lhs, Value rhs, Value res, int reductionSize,
+ FailureOr<Value> outerProd(Value lhs, Value rhs, Value res, VectorType lhsType,
+ int reductionDim,
std::optional<Value> maybeMask = std::nullopt) {
- assert(reductionSize > 0);
+ // Unrolling a scalable dimension would be incorrect - bail out.
+ if (lhsType.getScalableDims()[reductionDim])
+ return failure();
+
+ int reductionSize = lhsType.getDimSize(reductionDim);
// Incremental support for masking.
if (mask && !maybeMask.has_value())
return failure();
@@ -459,33 +464,39 @@ struct UnrolledOuterProductGenerator
Value transposedMask = t(mask, {2, 0, 1});
// Classical row-major matmul: Just permute the lhs.
if (layout({{m, k}, {k, n}, {m, n}}))
- return outerProd(t(lhs), rhs, res, lhsType.getDimSize(1), transposedMask);
+ return outerProd(t(lhs), rhs, res, lhsType, /*reductionDim=*/1,
+ transposedMask);
// TODO: may be better to fail and use some vector<k> -> scalar reduction.
if (layout({{m, k}, {n, k}, {m, n}})) {
Value tlhs = t(lhs);
- return outerProd(tlhs, t(rhs), res, lhsType.getDimSize(1),
+ return outerProd(tlhs, t(rhs), res, lhsType, /*reductionDim=*/1,
transposedMask);
}
// No need to permute anything.
if (layout({{k, m}, {k, n}, {m, n}}))
- return outerProd(lhs, rhs, res, lhsType.getDimSize(0), transposedMask);
+ return outerProd(lhs, rhs, res, lhsType, /*reductionDim=*/0,
+ transposedMask);
// Just permute the rhs.
if (layout({{k, m}, {n, k}, {m, n}}))
- return outerProd(lhs, t(rhs), res, lhsType.getDimSize(0), transposedMask);
+ return outerProd(lhs, t(rhs), res, lhsType, /*reductionDim=*/0,
+ transposedMask);
// Transposed output: swap RHS and LHS.
// Classical row-major matmul: permute the lhs.
if (layout({{m, k}, {k, n}, {n, m}}))
- return outerProd(rhs, t(lhs), res, lhsType.getDimSize(1), transposedMask);
+ return outerProd(rhs, t(lhs), res, lhsType, /*reductionDim=*/1,
+ transposedMask);
// TODO: may be better to fail and use some vector<k> -> scalar reduction.
if (layout({{m, k}, {n, k}, {n, m}})) {
Value trhs = t(rhs);
- return outerProd(trhs, t(lhs), res, lhsType.getDimSize(1),
+ return outerProd(trhs, t(lhs), res, lhsType, /*reductionDim=*/1,
transposedMask);
}
if (layout({{k, m}, {k, n}, {n, m}}))
- return outerProd(rhs, lhs, res, lhsType.getDimSize(0), transposedMask);
+ return outerProd(rhs, lhs, res, lhsType, /*reductionDim=*/0,
+ transposedMask);
if (layout({{k, m}, {n, k}, {n, m}}))
- return outerProd(t(rhs), lhs, res, lhsType.getDimSize(0), transposedMask);
+ return outerProd(t(rhs), lhs, res, lhsType, /*reductionDim=*/0,
+ transposedMask);
return failure();
}
@@ -503,16 +514,20 @@ struct UnrolledOuterProductGenerator
// Case mat-vec: transpose.
if (layout({{m, k}, {k}, {m}}))
- return outerProd(t(lhs), rhs, res, lhsType.getDimSize(1), transposedMask);
+ return outerProd(t(lhs), rhs, res, lhsType, /*reductionDim=*/1,
+ transposedMask);
// Case mat-trans-vec: ready to go.
if (layout({{k, m}, {k}, {m}}))
- return outerProd(lhs, rhs, res, lhsType.getDimSize(0), transposedMask);
+ return outerProd(lhs, rhs, res, lhsType, /*reductionDim=*/0,
+ transposedMask);
// Case vec-mat: swap and transpose.
if (layout({{k}, {m, k}, {m}}))
- return outerProd(t(rhs), lhs, res, lhsType.getDimSize(0), transposedMask);
+ return outerProd(t(rhs), lhs, res, lhsType, /*reductionDim=*/0,
+ transposedMask);
// Case vec-mat-trans: swap and ready to go.
if (layout({{k}, {k, m}, {m}}))
- return outerProd(rhs, lhs, res, lhsType.getDimSize(0), transposedMask);
+ return outerProd(rhs, lhs, res, lhsType, /*reductionDim=*/0,
+ transposedMask);
return failure();
}
@@ -528,16 +543,16 @@ struct UnrolledOuterProductGenerator
// Case mat-vec: transpose.
if (layout({{m, k}, {k}, {m}}))
- return outerProd(t(lhs), rhs, res, lhsType.getDimSize(1), mask);
+ return outerProd(t(lhs), rhs, res, lhsType, /*reductionDim=*/1, mask);
// Case mat-trans-vec: ready to go.
if (layout({{k, m}, {k}, {m}}))
- return outerProd(lhs, rhs, res, lhsType.getDimSize(0), mask);
+ return outerProd(lhs, rhs, res, lhsType, /*reductionDim=*/0, mask);
// Case vec-mat: swap and transpose.
if (layout({{k}, {m, k}, {m}}))
- return outerProd(t(rhs), lhs, res, lhsType.getDimSize(0), mask);
+ return outerProd(t(rhs), lhs, res, lhsType, /*reductionDim=*/0, mask);
// Case vec-mat-trans: swap and ready to go.
if (layout({{k}, {k, m}, {m}}))
- return outerProd(rhs, lhs, res, lhsType.getDimSize(0), mask);
+ return outerProd(rhs, lhs, res, lhsType, /*reductionDim=*/0, mask);
return failure();
}
@@ -980,9 +995,15 @@ FailureOr<Value> ContractionOpLowering::lowerParallel(PatternRewriter &rewriter,
diag << "expected lhsIndex=" << lhsIndex << " and rhsIndex=" << rhsIndex
<< " to map to the same dimension";
});
+ // Unrolling a scalable dimension would be incorrect - bail out.
+ if (lhsType.getScalableDims()[lhsIndex])
+ return failure();
dimSize = lhsType.getDimSize(lhsIndex);
} else if (rhsIndex >= 0) {
iterIndex = iMap[1].getDimPosition(rhsIndex);
+ // Unrolling a scalable dimension would be incorrect - bail out.
+ if (rhsType.getScalableDims()[rhsIndex])
+ return failure();
dimSize = rhsType.getDimSize(rhsIndex);
}
if (iterIndex < 0)
diff --git a/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-transforms-unsupported.mlir b/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-transforms-unsupported.mlir
new file mode 100644
index 000000000000000..a955250107d73d7
--- /dev/null
+++ b/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-transforms-unsupported.mlir
@@ -0,0 +1,33 @@
+// RUN: mlir-opt %s --test-transform-dialect-interpreter -allow-unregistered-dialect --split-input-file --verify-diagnostics
+
+#matvec_accesses = [
+ affine_map<(i, j) -> (i, j)>,
+ affine_map<(i, j) -> (j)>,
+ affine_map<(i, j) -> (i)>
+]
+#matvec_trait = {
+ indexing_maps = #matvec_accesses,
+ iterator_types = ["parallel", "reduction"]
+}
+
+// Unrolling scalable reduction dim is not supported - bail out
+
+// expected-error@below {{greedy pattern application failed}}
+func.func @masked_extract_contract2_scalable_reduction_dim(%arg0: vector<[2]x[3]xf32>,
+ %arg1: vector<[3]xf32>,
+ %arg2: vector<[2]xf32>,
+ %m: vector<[2]x[3]xi1>) -> vector<[2]xf32> {
+ %0 = vector.mask %m { vector.contract #matvec_trait %arg0, %arg1, %arg2
+ : vector<[2]x[3]xf32>, vector<[3]xf32> into vector<[2]xf32> } : vector<[2]x[3]xi1> -> vector<[2]xf32>
+ return %0 : vector<[2]xf32>
+}
+
+transform.sequence failures(propagate) {
+^bb1(%module_op: !transform.any_op):
+ %f = transform.structured.match ops{["func.func"]} in %module_op
+ : (!transform.any_op) -> !transform.any_op
+
+ transform.apply_patterns to %f {
+ transform.apply_patterns.vector.lower_contraction lowering_strategy = "outerproduct"
+ } : !transform.any_op
+}
diff --git a/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-transforms.mlir
index deea7747f36799c..8ee0a35717ce87a 100644
--- a/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-transforms.mlir
@@ -31,19 +31,19 @@
}
// CHECK-LABEL: func.func @masked_extract_contract2(
-// CHECK-SAME: %[[VAL_0:.*]]: vector<2x3xf32>,
-// CHECK-SAME: %[[VAL_1:.*]]: vector<3xf32>,
-// CHECK-SAME: %[[VAL_2:.*]]: vector<2xf32>,
-// CHECK-SAME: %[[IN_MASK:.*]]: vector<2x3xi1>) -> vector<2xf32>
+// CHECK-SAME: %{{.*}}: vector<2x3xf32>,
+// CHECK-SAME: %{{.*}}: vector<3xf32>,
+// CHECK-SAME: %{{.*}}: vector<2xf32>,
+// CHECK-SAME: %[[IN_MASK:.*]]: vector<2x3xi1>) -> vector<2xf32>
// CHECK: %[[T_MASK:.*]] = vector.transpose %[[IN_MASK]], [1, 0] : vector<2x3xi1> to vector<3x2xi1>
// CHECK: %[[MASK0:.*]] = vector.extract %[[T_MASK]][0] : vector<2xi1> from vector<3x2xi1>
-// CHECK: vector.mask %[[MASK0]] { vector.outerproduct
+// CHECK: vector.mask %[[MASK0]] { vector.outerproduct {{.*}} {kind = #vector.kind<add>} : vector<2xf32>, f32 } : vector<2xi1> -> vector<2xf32>
// CHECK: %[[MASK1:.*]] = vector.extract %[[T_MASK]][1] : vector<2xi1> from vector<3x2xi1>
-// CHECK: vector.mask %[[MASK1]] { vector.outerproduct
+// CHECK: vector.mask %[[MASK1]] { vector.outerproduct {{.*}} {kind = #vector.kind<add>} : vector<2xf32>, f32 } : vector<2xi1> -> vector<2xf32>
// CHECK: %[[MASK2:.*]] = vector.extract %[[T_MASK]][2] : vector<2xi1> from vector<3x2xi1>
-// CHECK: vector.mask %[[MASK2]] { vector.outerproduct
+// CHECK: vector.mask %[[MASK2]] { vector.outerproduct {{.*}} {kind = #vector.kind<add>} : vector<2xf32>, f32 } : vector<2xi1> -> vector<2xf32>
func.func @masked_extract_contract2(%arg0: vector<2x3xf32>,
%arg1: vector<3xf32>,
@@ -54,22 +54,29 @@ func.func @masked_extract_contract2(%arg0: vector<2x3xf32>,
return %0 : vector<2xf32>
}
-// CHECK-LABEL: func.func @masked_extract_contract4(
-// CHECK-SAME: %[[VAL_0:.*]]: vector<3x5xf32>,
-// CHECK-SAME: %[[VAL_1:.*]]: vector<5x7xf32>,
-// CHECK-SAME: %[[VAL_2:.*]]: vector<3x7xf32>,
-// CHECK-SAME: %[[VAL_3:.*]]: vector<3x7x5xi1>) -> vector<3x7xf32> {
-// CHECK: %[[VAL_5:.*]] = vector.transpose %[[VAL_3]], [2, 0, 1] : vector<3x7x5xi1> to vector<5x3x7xi1>
-// CHECK: %[[VAL_8:.*]] = vector.extract %[[VAL_5]][0] : vector<3x7xi1> from vector<5x3x7xi1>
-// CHECK: %[[VAL_9:.*]] = vector.mask %[[VAL_8]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32>
-// CHECK: %[[VAL_12:.*]] = vector.extract %[[VAL_5]][1] : vector<3x7xi1> from vector<5x3x7xi1>
-// CHECK: %[[VAL_13:.*]] = vector.mask %[[VAL_12]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32>
-// CHECK: %[[VAL_16:.*]] = vector.extract %[[VAL_5]][2] : vector<3x7xi1> from vector<5x3x7xi1>
-// CHECK: %[[VAL_17:.*]] = vector.mask %[[VAL_16]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32>
-// CHECK: %[[VAL_20:.*]] = vector.extract %[[VAL_5]][3] : vector<3x7xi1> from vector<5x3x7xi1>
-// CHECK: %[[VAL_21:.*]] = vector.mask %[[VAL_20]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32>
-// CHECK: %[[VAL_24:.*]] = vector.extract %[[VAL_5]][4] : vector<3x7xi1> from vector<5x3x7xi1>
-// CHECK: %[[VAL_25:.*]] = vector.mask %[[VAL_24]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32>
+
+// CHECK-LABEL: func.func @masked_extract_contract2_scalable_parallel_dim(
+// CHECK-SAME: %{{.*}}: vector<[2]x3xf32>,
+// CHECK-SAME: %{{.*}}: vector<3xf32>,
+// CHECK-SAME: %{{.*}}: vector<[2]xf32>,
+// CHECK-SAME: %[[IN_MASK:.*]]: vector<[2]x3xi1>) -> vector<[2]xf32>
+// CHECK: %[[T_MASK:.*]] = vector.transpose %[[IN_MASK]], [1, 0] : vector<[2]x3xi1> to vector<3x[2]xi1>
+// CHECK: %[[MASK0:.*]] = vector.extract %[[T_MASK]][0] : vector<[2]xi1> from vector<3x[2]xi1>
+// CHECK: vector.mask %[[MASK0]] { vector.outerproduct {{.*}} {kind = #vector.kind<add>} : vector<[2]xf32>, f32 } : vector<[2]xi1> -> vector<[2]xf32>
+
+// CHECK: %[[MASK1:.*]] = vector.extract %[[T_MASK]][1] : vector<[2]xi1> from vector<3x[2]xi1>
+// CHECK: vector.mask %[[MASK1]] { vector.outerproduct {{.*}} {kind = #vector.kind<add>} : vector<[2]xf32>, f32 } : vector<[2]xi1> -> vector<[2]xf32>
+
+// CHECK: %[[MASK2:.*]] = vector.extract %[[T_MASK]][2] : vector<[2]xi1> from vector<3x[2]xi1>
+// CHECK: vector.mask %[[MASK2]] { vector.outerproduct {{.*}} {kind = #vector.kind<add>} : vector<[2]xf32>, f32 } : vector<[2]xi1> -> vector<[2]xf32>
+func.func @masked_extract_contract2_scalable_parallel_dim(%arg0: vector<[2]x3xf32>,
+ %arg1: vector<3xf32>,
+ %arg2: vector<[2]xf32>,
+ %m: vector<[2]x3xi1>) -> vector<[2]xf32> {
+ %0 = vector.mask %m { vector.contract #matvec_trait %arg0, %arg1, %arg2
+ : vector<[2]x3xf32>, vector<3xf32> into vector<[2]xf32> } : vector<[2]x3xi1> -> vector<[2]xf32>
+ return %0 : vector<[2]xf32>
+}
func.func @masked_extract_contract4(%arg0: vector<3x5xf32>,
%arg1: vector<5x7xf32>,
|
You can test this locally with the following command:git-clang-format --diff ff843c00ce1df5af29d5dae671086b92dcabf94b bb27634cbea6803cb3d8e1b15e3a85204f9d40c4 -- mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp View the diff from clang-format here.diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
index f0bc7ec6f723..311e589547b8 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
@@ -424,8 +424,8 @@ struct UnrolledOuterProductGenerator
return rewriter.create<arith::ExtSIOp>(loc, promotedType, v);
}
- FailureOr<Value> outerProd(Value lhs, Value rhs, Value res, VectorType lhsType,
- int reductionDim,
+ FailureOr<Value> outerProd(Value lhs, Value rhs, Value res,
+ VectorType lhsType, int reductionDim,
std::optional<Value> maybeMask = std::nullopt) {
// Unrolling a scalable dimension would be incorrect - bail out.
if (lhsType.getScalableDims()[reductionDim])
|
@@ -980,9 +995,15 @@ FailureOr<Value> ContractionOpLowering::lowerParallel(PatternRewriter &rewriter, | |||
diag << "expected lhsIndex=" << lhsIndex << " and rhsIndex=" << rhsIndex | |||
<< " to map to the same dimension"; | |||
}); | |||
// Unrolling a scalable dimension would be incorrect - bail out. | |||
if (lhsType.getScalableDims()[lhsIndex]) | |||
return failure(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
can we notify with a message here and below ?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, will do :)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, just a small question:
@@ -54,22 +54,29 @@ func.func @masked_extract_contract2(%arg0: vector<2x3xf32>, | |||
return %0 : vector<2xf32> | |||
} | |||
|
|||
// CHECK-LABEL: func.func @masked_extract_contract4( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why was this fixed size test removed?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Copy and paste failure, sorry :( Thanks for pointing this out!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the reviews, I will send an update shortly and then merge it.
@@ -980,9 +995,15 @@ FailureOr<Value> ContractionOpLowering::lowerParallel(PatternRewriter &rewriter, | |||
diag << "expected lhsIndex=" << lhsIndex << " and rhsIndex=" << rhsIndex | |||
<< " to map to the same dimension"; | |||
}); | |||
// Unrolling a scalable dimension would be incorrect - bail out. | |||
if (lhsType.getScalableDims()[lhsIndex]) | |||
return failure(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, will do :)
@@ -54,22 +54,29 @@ func.func @masked_extract_contract2(%arg0: vector<2x3xf32>, | |||
return %0 : vector<2xf32> | |||
} | |||
|
|||
// CHECK-LABEL: func.func @masked_extract_contract4( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Copy and paste failure, sorry :( Thanks for pointing this out!
…uterproduct * Restore CHECK lines that were removed accidentally; * Add diag messages
std::optional<Value> maybeMask = std::nullopt) { | ||
assert(reductionSize > 0); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
why was this assert removed?
@@ -980,9 +995,19 @@ FailureOr<Value> ContractionOpLowering::lowerParallel(PatternRewriter &rewriter, | |||
diag << "expected lhsIndex=" << lhsIndex << " and rhsIndex=" << rhsIndex | |||
<< " to map to the same dimension"; | |||
}); | |||
if (lhsType.getScalableDims()[lhsIndex]) | |||
return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) { | |||
diag << "Unrolloing scalable dimension (lhsIndex=" << lhsIndex |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
s/Unrolloing/Unrolling
dimSize = lhsType.getDimSize(lhsIndex); | ||
} else if (rhsIndex >= 0) { | ||
iterIndex = iMap[1].getDimPosition(rhsIndex); | ||
if (rhsType.getScalableDims()[rhsIndex]) | ||
return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) { | ||
diag << "Unrolloing scalable dimension (lhsIndex=" << lhsIndex |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this should be rhsIndex
?
Follow-up for llvm#68400 - restoring an assert that was accidentally removed and fixed a typo in a diagnostic.
Follow-up for #68400 - restoring an assert that was accidentally removed and fixed a typo in a diagnostic.
This patch constrains the patterns for converting
vector.contract
tovector.outerproduct
so thatdimension is scalable.
This is necessary as the current lowering is incorrect for scalable
dims. Indeed, instead of the following unrolling that is currently being
generated for
vector.contract
(K is the size f the reductiondimension):
we should be generating a
for
loop like the following:However, the lowering of
vector.extract
of vector slices with dynamicindices is incomplete and hence the implementation above wouldn't work
just yet. Instead, this patch effectively disables unrolling in case
where the generated code would be functionally incorrect (i.e. when the
reduction dimension is scalable).
In order to document unsupported cases, a dedicated test file is added:
This is the first patch in a series of patches that strives to update
these patterns (and to test them) for scalable vectors.