diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp index 04d9ddf2183f8c5..f0bc7ec6f723a97 100644 --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp @@ -424,9 +424,14 @@ struct UnrolledOuterProductGenerator return rewriter.create(loc, promotedType, v); } - FailureOr outerProd(Value lhs, Value rhs, Value res, int reductionSize, + FailureOr outerProd(Value lhs, Value rhs, Value res, VectorType lhsType, + int reductionDim, std::optional maybeMask = std::nullopt) { - assert(reductionSize > 0); + // Unrolling a scalable dimension would be incorrect - bail out. + if (lhsType.getScalableDims()[reductionDim]) + return failure(); + + int reductionSize = lhsType.getDimSize(reductionDim); // Incremental support for masking. if (mask && !maybeMask.has_value()) return failure(); @@ -459,33 +464,39 @@ struct UnrolledOuterProductGenerator Value transposedMask = t(mask, {2, 0, 1}); // Classical row-major matmul: Just permute the lhs. if (layout({{m, k}, {k, n}, {m, n}})) - return outerProd(t(lhs), rhs, res, lhsType.getDimSize(1), transposedMask); + return outerProd(t(lhs), rhs, res, lhsType, /*reductionDim=*/1, + transposedMask); // TODO: may be better to fail and use some vector -> scalar reduction. if (layout({{m, k}, {n, k}, {m, n}})) { Value tlhs = t(lhs); - return outerProd(tlhs, t(rhs), res, lhsType.getDimSize(1), + return outerProd(tlhs, t(rhs), res, lhsType, /*reductionDim=*/1, transposedMask); } // No need to permute anything. if (layout({{k, m}, {k, n}, {m, n}})) - return outerProd(lhs, rhs, res, lhsType.getDimSize(0), transposedMask); + return outerProd(lhs, rhs, res, lhsType, /*reductionDim=*/0, + transposedMask); // Just permute the rhs. 
if (layout({{k, m}, {n, k}, {m, n}})) - return outerProd(lhs, t(rhs), res, lhsType.getDimSize(0), transposedMask); + return outerProd(lhs, t(rhs), res, lhsType, /*reductionDim=*/0, + transposedMask); // Transposed output: swap RHS and LHS. // Classical row-major matmul: permute the lhs. if (layout({{m, k}, {k, n}, {n, m}})) - return outerProd(rhs, t(lhs), res, lhsType.getDimSize(1), transposedMask); + return outerProd(rhs, t(lhs), res, lhsType, /*reductionDim=*/1, + transposedMask); // TODO: may be better to fail and use some vector -> scalar reduction. if (layout({{m, k}, {n, k}, {n, m}})) { Value trhs = t(rhs); - return outerProd(trhs, t(lhs), res, lhsType.getDimSize(1), + return outerProd(trhs, t(lhs), res, lhsType, /*reductionDim=*/1, transposedMask); } if (layout({{k, m}, {k, n}, {n, m}})) - return outerProd(rhs, lhs, res, lhsType.getDimSize(0), transposedMask); + return outerProd(rhs, lhs, res, lhsType, /*reductionDim=*/0, + transposedMask); if (layout({{k, m}, {n, k}, {n, m}})) - return outerProd(t(rhs), lhs, res, lhsType.getDimSize(0), transposedMask); + return outerProd(t(rhs), lhs, res, lhsType, /*reductionDim=*/0, + transposedMask); return failure(); } @@ -503,16 +514,20 @@ struct UnrolledOuterProductGenerator // Case mat-vec: transpose. if (layout({{m, k}, {k}, {m}})) - return outerProd(t(lhs), rhs, res, lhsType.getDimSize(1), transposedMask); + return outerProd(t(lhs), rhs, res, lhsType, /*reductionDim=*/1, + transposedMask); // Case mat-trans-vec: ready to go. if (layout({{k, m}, {k}, {m}})) - return outerProd(lhs, rhs, res, lhsType.getDimSize(0), transposedMask); + return outerProd(lhs, rhs, res, lhsType, /*reductionDim=*/0, + transposedMask); // Case vec-mat: swap and transpose. if (layout({{k}, {m, k}, {m}})) - return outerProd(t(rhs), lhs, res, lhsType.getDimSize(0), transposedMask); + return outerProd(t(rhs), lhs, res, lhsType, /*reductionDim=*/0, + transposedMask); // Case vec-mat-trans: swap and ready to go. 
if (layout({{k}, {k, m}, {m}})) - return outerProd(rhs, lhs, res, lhsType.getDimSize(0), transposedMask); + return outerProd(rhs, lhs, res, lhsType, /*reductionDim=*/0, + transposedMask); return failure(); } @@ -528,16 +543,16 @@ struct UnrolledOuterProductGenerator // Case mat-vec: transpose. if (layout({{m, k}, {k}, {m}})) - return outerProd(t(lhs), rhs, res, lhsType.getDimSize(1), mask); + return outerProd(t(lhs), rhs, res, lhsType, /*reductionDim=*/1, mask); // Case mat-trans-vec: ready to go. if (layout({{k, m}, {k}, {m}})) - return outerProd(lhs, rhs, res, lhsType.getDimSize(0), mask); + return outerProd(lhs, rhs, res, lhsType, /*reductionDim=*/0, mask); // Case vec-mat: swap and transpose. if (layout({{k}, {m, k}, {m}})) - return outerProd(t(rhs), lhs, res, lhsType.getDimSize(0), mask); + return outerProd(t(rhs), lhs, res, lhsType, /*reductionDim=*/0, mask); // Case vec-mat-trans: swap and ready to go. if (layout({{k}, {k, m}, {m}})) - return outerProd(rhs, lhs, res, lhsType.getDimSize(0), mask); + return outerProd(rhs, lhs, res, lhsType, /*reductionDim=*/0, mask); return failure(); } @@ -980,9 +995,19 @@ FailureOr ContractionOpLowering::lowerParallel(PatternRewriter &rewriter, diag << "expected lhsIndex=" << lhsIndex << " and rhsIndex=" << rhsIndex << " to map to the same dimension"; }); + if (lhsType.getScalableDims()[lhsIndex]) + return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) { + diag << "Unrolling scalable dimension (lhsIndex=" << lhsIndex + << ") is not supported yet"; + }); dimSize = lhsType.getDimSize(lhsIndex); } else if (rhsIndex >= 0) { iterIndex = iMap[1].getDimPosition(rhsIndex); + if (rhsType.getScalableDims()[rhsIndex]) + return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) { + diag << "Unrolling scalable dimension (rhsIndex=" << rhsIndex + << ") is not supported yet"; + }); dimSize = rhsType.getDimSize(rhsIndex); } if (iterIndex < 0) diff --git 
a/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-transforms-unsupported.mlir b/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-transforms-unsupported.mlir new file mode 100644 index 000000000000000..a955250107d73d7 --- /dev/null +++ b/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-transforms-unsupported.mlir @@ -0,0 +1,33 @@ +// RUN: mlir-opt %s --test-transform-dialect-interpreter -allow-unregistered-dialect --split-input-file --verify-diagnostics + +#matvec_accesses = [ + affine_map<(i, j) -> (i, j)>, + affine_map<(i, j) -> (j)>, + affine_map<(i, j) -> (i)> +] +#matvec_trait = { + indexing_maps = #matvec_accesses, + iterator_types = ["parallel", "reduction"] +} + +// Unrolling scalable reduction dim is not supported - bail out + +// expected-error@below {{greedy pattern application failed}} +func.func @masked_extract_contract2_scalable_reduction_dim(%arg0: vector<[2]x[3]xf32>, + %arg1: vector<[3]xf32>, + %arg2: vector<[2]xf32>, + %m: vector<[2]x[3]xi1>) -> vector<[2]xf32> { + %0 = vector.mask %m { vector.contract #matvec_trait %arg0, %arg1, %arg2 + : vector<[2]x[3]xf32>, vector<[3]xf32> into vector<[2]xf32> } : vector<[2]x[3]xi1> -> vector<[2]xf32> + return %0 : vector<[2]xf32> +} + +transform.sequence failures(propagate) { +^bb1(%module_op: !transform.any_op): + %f = transform.structured.match ops{["func.func"]} in %module_op + : (!transform.any_op) -> !transform.any_op + + transform.apply_patterns to %f { + transform.apply_patterns.vector.lower_contraction lowering_strategy = "outerproduct" + } : !transform.any_op +} diff --git a/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-transforms.mlir index deea7747f36799c..1e92fcff64dea57 100644 --- a/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-transforms.mlir +++ b/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-transforms.mlir @@ -31,19 +31,19 @@ } // CHECK-LABEL: func.func 
@masked_extract_contract2( -// CHECK-SAME: %[[VAL_0:.*]]: vector<2x3xf32>, -// CHECK-SAME: %[[VAL_1:.*]]: vector<3xf32>, -// CHECK-SAME: %[[VAL_2:.*]]: vector<2xf32>, -// CHECK-SAME: %[[IN_MASK:.*]]: vector<2x3xi1>) -> vector<2xf32> +// CHECK-SAME: %{{.*}}: vector<2x3xf32>, +// CHECK-SAME: %{{.*}}: vector<3xf32>, +// CHECK-SAME: %{{.*}}: vector<2xf32>, +// CHECK-SAME: %[[IN_MASK:.*]]: vector<2x3xi1>) -> vector<2xf32> // CHECK: %[[T_MASK:.*]] = vector.transpose %[[IN_MASK]], [1, 0] : vector<2x3xi1> to vector<3x2xi1> // CHECK: %[[MASK0:.*]] = vector.extract %[[T_MASK]][0] : vector<2xi1> from vector<3x2xi1> -// CHECK: vector.mask %[[MASK0]] { vector.outerproduct +// CHECK: vector.mask %[[MASK0]] { vector.outerproduct {{.*}} {kind = #vector.kind} : vector<2xf32>, f32 } : vector<2xi1> -> vector<2xf32> // CHECK: %[[MASK1:.*]] = vector.extract %[[T_MASK]][1] : vector<2xi1> from vector<3x2xi1> -// CHECK: vector.mask %[[MASK1]] { vector.outerproduct +// CHECK: vector.mask %[[MASK1]] { vector.outerproduct {{.*}} {kind = #vector.kind} : vector<2xf32>, f32 } : vector<2xi1> -> vector<2xf32> // CHECK: %[[MASK2:.*]] = vector.extract %[[T_MASK]][2] : vector<2xi1> from vector<3x2xi1> -// CHECK: vector.mask %[[MASK2]] { vector.outerproduct +// CHECK: vector.mask %[[MASK2]] { vector.outerproduct {{.*}} {kind = #vector.kind} : vector<2xf32>, f32 } : vector<2xi1> -> vector<2xf32> func.func @masked_extract_contract2(%arg0: vector<2x3xf32>, %arg1: vector<3xf32>, @@ -54,6 +54,30 @@ func.func @masked_extract_contract2(%arg0: vector<2x3xf32>, return %0 : vector<2xf32> } + +// CHECK-LABEL: func.func @masked_extract_contract2_scalable_parallel_dim( +// CHECK-SAME: %{{.*}}: vector<[2]x3xf32>, +// CHECK-SAME: %{{.*}}: vector<3xf32>, +// CHECK-SAME: %{{.*}}: vector<[2]xf32>, +// CHECK-SAME: %[[IN_MASK:.*]]: vector<[2]x3xi1>) -> vector<[2]xf32> +// CHECK: %[[T_MASK:.*]] = vector.transpose %[[IN_MASK]], [1, 0] : vector<[2]x3xi1> to vector<3x[2]xi1> +// CHECK: %[[MASK0:.*]] = vector.extract 
%[[T_MASK]][0] : vector<[2]xi1> from vector<3x[2]xi1> +// CHECK: vector.mask %[[MASK0]] { vector.outerproduct {{.*}} {kind = #vector.kind} : vector<[2]xf32>, f32 } : vector<[2]xi1> -> vector<[2]xf32> + +// CHECK: %[[MASK1:.*]] = vector.extract %[[T_MASK]][1] : vector<[2]xi1> from vector<3x[2]xi1> +// CHECK: vector.mask %[[MASK1]] { vector.outerproduct {{.*}} {kind = #vector.kind} : vector<[2]xf32>, f32 } : vector<[2]xi1> -> vector<[2]xf32> + +// CHECK: %[[MASK2:.*]] = vector.extract %[[T_MASK]][2] : vector<[2]xi1> from vector<3x[2]xi1> +// CHECK: vector.mask %[[MASK2]] { vector.outerproduct {{.*}} {kind = #vector.kind} : vector<[2]xf32>, f32 } : vector<[2]xi1> -> vector<[2]xf32> +func.func @masked_extract_contract2_scalable_parallel_dim(%arg0: vector<[2]x3xf32>, + %arg1: vector<3xf32>, + %arg2: vector<[2]xf32>, + %m: vector<[2]x3xi1>) -> vector<[2]xf32> { + %0 = vector.mask %m { vector.contract #matvec_trait %arg0, %arg1, %arg2 + : vector<[2]x3xf32>, vector<3xf32> into vector<[2]xf32> } : vector<[2]x3xi1> -> vector<[2]xf32> + return %0 : vector<[2]xf32> +} + // CHECK-LABEL: func.func @masked_extract_contract4( // CHECK-SAME: %[[VAL_0:.*]]: vector<3x5xf32>, // CHECK-SAME: %[[VAL_1:.*]]: vector<5x7xf32>,