Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mlir][SVE] Add an e2e test for vector.contract #69845

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions mlir/lib/Dialect/Vector/IR/VectorOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3067,9 +3067,12 @@ LogicalResult OuterProductOp::verify() {
return emitOpError("expected #1 operand dim to match result dim #1");
if (vRHS.getDimSize(0) != vRES.getDimSize(1))
return emitOpError("expected #2 operand dim to match result dim #2");
if (vRHS.isScalable() != vLHS.isScalable())
return emitOpError("expected either all or none of vector operands #1 "
"and #2 to be scalable");
if (vLHS.isScalable() && !vRHS.isScalable()) {
// This restriction reflects what's currently supported in terms of
// scalable vectors. However, we could relax this if there's a use case.
return emitOpError(
"expected either both or only #2 operand dim to be scalable");
}
} else {
// An AXPY operation.
if (vRES.getRank() != 1)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,21 +79,21 @@ func.func @masked_extract_contract2_scalable_parallel_dim(%arg0: vector<[2]x3xf3
}

// CHECK-LABEL: func.func @masked_extract_contract4(
// CHECK-SAME: %[[VAL_0:.*]]: vector<3x5xf32>,
// CHECK-SAME: %[[VAL_1:.*]]: vector<5x7xf32>,
// CHECK-SAME: %[[VAL_2:.*]]: vector<3x7xf32>,
// CHECK-SAME: %[[VAL_3:.*]]: vector<3x7x5xi1>) -> vector<3x7xf32> {
// CHECK: %[[VAL_5:.*]] = vector.transpose %[[VAL_3]], [2, 0, 1] : vector<3x7x5xi1> to vector<5x3x7xi1>
// CHECK: %[[VAL_8:.*]] = vector.extract %[[VAL_5]][0] : vector<3x7xi1> from vector<5x3x7xi1>
// CHECK: %[[VAL_9:.*]] = vector.mask %[[VAL_8]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32>
// CHECK: %[[VAL_12:.*]] = vector.extract %[[VAL_5]][1] : vector<3x7xi1> from vector<5x3x7xi1>
// CHECK: %[[VAL_13:.*]] = vector.mask %[[VAL_12]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32>
// CHECK: %[[VAL_16:.*]] = vector.extract %[[VAL_5]][2] : vector<3x7xi1> from vector<5x3x7xi1>
// CHECK: %[[VAL_17:.*]] = vector.mask %[[VAL_16]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32>
// CHECK: %[[VAL_20:.*]] = vector.extract %[[VAL_5]][3] : vector<3x7xi1> from vector<5x3x7xi1>
// CHECK: %[[VAL_21:.*]] = vector.mask %[[VAL_20]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32>
// CHECK: %[[VAL_24:.*]] = vector.extract %[[VAL_5]][4] : vector<3x7xi1> from vector<5x3x7xi1>
// CHECK: %[[VAL_25:.*]] = vector.mask %[[VAL_24]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32>
// CHECK-SAME: %{{.*}}: vector<3x5xf32>,
// CHECK-SAME: %{{.*}}: vector<5x7xf32>,
// CHECK-SAME: %{{.*}}: vector<3x7xf32>,
// CHECK-SAME: %[[IN_MASK:.*]]: vector<3x7x5xi1>) -> vector<3x7xf32> {
// CHECK: %[[T_MASK:.*]] = vector.transpose %[[IN_MASK]], [2, 0, 1] : vector<3x7x5xi1> to vector<5x3x7xi1>
// CHECK: %[[T_MASK_R0:.*]] = vector.extract %[[T_MASK]][0] : vector<3x7xi1> from vector<5x3x7xi1>
// CHECK: %{{.*}} = vector.mask %[[T_MASK_R0]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32>
// CHECK: %[[T_MASK_R1:.*]] = vector.extract %[[T_MASK]][1] : vector<3x7xi1> from vector<5x3x7xi1>
// CHECK: %{{.*}} = vector.mask %[[T_MASK_R1]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32>
// CHECK: %[[T_MASK_R2:.*]] = vector.extract %[[T_MASK]][2] : vector<3x7xi1> from vector<5x3x7xi1>
// CHECK: %{{.*}} = vector.mask %[[T_MASK_R2]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32>
// CHECK: %[[T_MASK_R3:.*]] = vector.extract %[[T_MASK]][3] : vector<3x7xi1> from vector<5x3x7xi1>
// CHECK: %{{.*}} = vector.mask %[[T_MASK_R3]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32>
// CHECK: %[[T_MASK_R4:.*]] = vector.extract %[[T_MASK]][4] : vector<3x7xi1> from vector<5x3x7xi1>
// CHECK: %{{.*}} = vector.mask %[[T_MASK_R4]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32>

func.func @masked_extract_contract4(%arg0: vector<3x5xf32>,
%arg1: vector<5x7xf32>,
Expand All @@ -104,6 +104,35 @@ func.func @masked_extract_contract4(%arg0: vector<3x5xf32>,
return %0 : vector<3x7xf32>
}

// CHECK-LABEL: func.func @masked_extract_contract4_scalable_J_dim(
// CHECK-SAME: %{{.*}}: vector<3x5xf32>,
// CHECK-SAME: %{{.*}}: vector<5x[7]xf32>,
// CHECK-SAME: %{{.*}}: vector<3x[7]xf32>,
// CHECK-SAME: %[[IN_MASK:.*]]: vector<3x[7]x5xi1>) -> vector<3x[7]xf32> {
// CHECK: %[[T_MASK:.*]] = vector.transpose %[[IN_MASK]], [2, 0, 1] : vector<3x[7]x5xi1> to vector<5x3x[7]xi1>
// CHECK: %[[T_MASK_R0:.*]] = vector.extract %[[T_MASK]][0] : vector<3x[7]xi1> from vector<5x3x[7]xi1>
// CHECK: %{{.*}} = vector.mask %[[T_MASK_R0]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<[7]xf32> } : vector<3x[7]xi1> -> vector<3x[7]xf32>
// CHECK: %[[T_MASK_R1:.*]] = vector.extract %[[T_MASK]][1] : vector<3x[7]xi1> from vector<5x3x[7]xi1>
// CHECK: %[[VAL_13:.*]] = vector.mask %[[T_MASK_R1]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<[7]xf32> } : vector<3x[7]xi1> -> vector<3x[7]xf32>
// CHECK: %[[T_MASK_R2:.*]] = vector.extract %[[T_MASK]][2] : vector<3x[7]xi1> from vector<5x3x[7]xi1>
// CHECK: %{{.*}} = vector.mask %[[T_MASK_R2]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<[7]xf32> } : vector<3x[7]xi1> -> vector<3x[7]xf32>
// CHECK: %[[T_MASK_R3:.*]] = vector.extract %[[T_MASK]][3] : vector<3x[7]xi1> from vector<5x3x[7]xi1>
// CHECK: %{{.*}} = vector.mask %[[T_MASK_R3]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<[7]xf32> } : vector<3x[7]xi1> -> vector<3x[7]xf32>
// CHECK: %[[T_MASK_R4:.*]] = vector.extract %[[T_MASK]][4] : vector<3x[7]xi1> from vector<5x3x[7]xi1>
// CHECK: %{{.*}} = vector.mask %[[T_MASK_R4]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<[7]xf32> } : vector<3x[7]xi1> -> vector<3x[7]xf32>

// Note that only the J dimension is scalable in this example. In theory, all
// dimensions could be be scalable, but there is no target yet for which this
// would make sense.
func.func @masked_extract_contract4_scalable_J_dim(%arg0: vector<3x5xf32>,
%arg1: vector<5x[7]xf32>,
%arg2: vector<3x[7]xf32>,
%m : vector<3x[7]x5xi1>) -> vector<3x[7]xf32> {
%0 = vector.mask %m { vector.contract #matmat_trait %arg0, %arg1, %arg2
: vector<3x5xf32>, vector<5x[7]xf32> into vector<3x[7]xf32> } : vector<3x[7]x5xi1> -> vector<3x[7]xf32>
return %0 : vector<3x[7]xf32>
}

// CHECK-LABEL: func @matmul
// CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x4xf32>,
// CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<4x3xf32>,
Expand Down
5 changes: 4 additions & 1 deletion mlir/test/Dialect/Vector/vector-scalable-outerproduct.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,12 @@ func.func @invalid_outerproduct(%src : memref<?xf32>) {
%0 = vector.load %src[%idx] : memref<?xf32>, vector<[4]xf32>
%1 = vector.load %src[%idx] : memref<?xf32>, vector<4xf32>

// expected-error @+1 {{expected either all or none of vector operands #1 and #2 to be scalable}}
// expected-error @+1 {{expected either both or only #2 operand dim to be scalable}}
%op = vector.outerproduct %0, %1 : vector<[4]xf32>, vector<4xf32>

return
}

// -----

func.func @invalid_outerproduct1(%src : memref<?xf32>) {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
// DEFINE: %{compile} = mlir-opt %s -test-transform-dialect-interpreter -test-transform-dialect-erase-schedule\
// DEFINE: -cse -canonicalize -convert-vector-to-scf -arm-sve-legalize-vector-storage\
// DEFINE: -convert-vector-to-llvm="enable-arm-sve" -test-lower-to-llvm -o %t
// DEFINE: %{entry} =
// DEFINE: %{run} = %mcr_aarch64_cmd %t -e=%{entry} -entry-point-result=void --march=aarch64 --mattr="+sve" -shared-libs=%mlir_lib_dir/libmlir_c_runner_utils%shlibext

// This check whether the files compiles and generates a temporary that will be executed further down.
// RUN: %{compile}

// REDEFINE: %{entry} = matmul_i32
// RUN: %{run} | FileCheck %s --check-prefix=I32

// REDEFINE: %{entry} = matmul_f32
// RUN: %{run} | FileCheck %s --check-prefix=F32

// NOTE: These tests are meant to complement the integration tests from:
// * ../test-contraction.mlir
// (tests with fixed width vectors). Rather than duplicating those tests, this
// file focuses on excercissing scalable vectors in a few most common cases.

// TODO: Masks + matvec + dot product

#matmat_accesses = [
affine_map<(i, j, k) -> (i, k)>,
affine_map<(i, j, k) -> (k, j)>,
affine_map<(i, j, k) -> (i, j)>
]
#matmat_trait = {
indexing_maps = #matmat_accesses,
iterator_types = ["parallel", "parallel", "reduction"]
}

func.func @matmul_i32() {
// Setup vector A:
%vector_a = arith.constant dense<123> : vector<3x5xi32>

// Setup vector B:
%vector_b = arith.constant dense<123> : vector<5x[2]xi32>

// Setup vector C:
%vector_c = arith.constant dense<314> : vector<3x[2]xi32>

// Matmul
%0 = vector.contract #matmat_trait %vector_a, %vector_b, %vector_c
: vector<3x5xi32>, vector<5x[2]xi32> into vector<3x[2]xi32>

// Print the output
%slice1 = vector.extract %0[0] : vector<[2]xi32> from vector<3x[2]xi32>
// I32: ( 75959, 75959
vector.print %slice1 : vector<[2]xi32>
%slice2 = vector.extract %0[1] : vector<[2]xi32> from vector<3x[2]xi32>
// I32-NEXT: ( 75959, 75959
vector.print %slice2 : vector<[2]xi32>
%slice3 = vector.extract %0[2] : vector<[2]xi32> from vector<3x[2]xi32>
// I32-NEXT: ( 75959, 75959
vector.print %slice3 : vector<[2]xi32>

// CHECK: SVE: END OF TEST OUTPUT
vector.print str "SVE: END OF TEST OUTPUT"

return
}

func.func @matmul_f32() {
// Setup vector A:
%vector_a = arith.constant dense<1.23> : vector<3x5xf32>

// Setup vector B:
%vector_b = arith.constant dense<1.23> : vector<5x[2]xf32>

// Setup vector C:
%vector_c = arith.constant dense<3.14> : vector<3x[2]xf32>

// Matmul
%0 = vector.contract #matmat_trait %vector_a, %vector_b, %vector_c
: vector<3x5xf32>, vector<5x[2]xf32> into vector<3x[2]xf32>

// Print the output
%slice1 = vector.extract %0[0] : vector<[2]xf32> from vector<3x[2]xf32>
// F32: ( 10.7045, 10.7045
vector.print %slice1 : vector<[2]xf32>
%slice2 = vector.extract %0[1] : vector<[2]xf32> from vector<3x[2]xf32>
// F32-NEXT: ( 10.7045, 10.7045
vector.print %slice2 : vector<[2]xf32>
%slice3 = vector.extract %0[2] : vector<[2]xf32> from vector<3x[2]xf32>
// F32-NEXT: ( 10.7045, 10.7045
vector.print %slice3 : vector<[2]xf32>

// CHECK: SVE: END OF TEST OUTPUT
vector.print str "SVE: END OF TEST OUTPUT"

return
}

transform.sequence failures(propagate) {
^bb1(%module_op: !transform.any_op):
%f = transform.structured.match ops{["func.func"]} in %module_op
: (!transform.any_op) -> !transform.any_op

transform.apply_patterns to %f {
transform.apply_patterns.vector.lower_contraction lowering_strategy = "outerproduct"
} : !transform.any_op
}