-
Notifications
You must be signed in to change notification settings - Fork 12.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][vector] Make ReorderElementwiseOpsOnBroadcast support vector.splat #66596
[mlir][vector] Make ReorderElementwiseOpsOnBroadcast support vector.splat #66596
Conversation
…plat Extend `ReorderElementwiseOpsOnBroadcast` so that the broadcastinvg op could be either `vector.broadcast` (alrady supported) as well as `vector.splat` (support added in this patch).
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-vector ChangesExtend Full diff: https://github.com/llvm/llvm-project/pull/66596.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index 207df69929c1c9f..b2a5aef5ee62d0f 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -880,7 +880,7 @@ struct BreakDownVectorBitCast : public OpRewritePattern<vector::BitCastOp> {
std::function<bool(BitCastOp)> controlFn;
};
-/// Reorders elementwise(broadcast) to broadcast(elementwise). Ex:
+/// Reorders elementwise(broadcast/splat) to broadcast(elementwise). Ex:
/// ```
/// %a = vector.broadcast %arg1 : index to vector<1x4xindex>
/// %b = vector.broadcast %arg2 : index to vector<1x4xindex>
@@ -891,6 +891,9 @@ struct BreakDownVectorBitCast : public OpRewritePattern<vector::BitCastOp> {
/// %r = arith.addi %arg0, %arg1 : index
/// %b = vector.broadcast %r : index to vector<1x4xindex>
/// ```
+///
+/// Both `vector.broadcast` and `vector.splat` are supported as broadcasting
+/// ops.
struct ReorderElementwiseOpsOnBroadcast final
: public OpTraitRewritePattern<OpTrait::Elementwise> {
using OpTraitRewritePattern::OpTraitRewritePattern;
@@ -903,35 +906,42 @@ struct ReorderElementwiseOpsOnBroadcast final
if (!OpTrait::hasElementwiseMappableTraits(op))
return failure();
- // Get the type of the first operand
- auto firstBcast = op->getOperand(0).getDefiningOp<vector::BroadcastOp>();
- if (!firstBcast)
+ // Get the type of the lhs operand
+ auto *lhsBcastOrSplat = op->getOperand(0).getDefiningOp();
+ if (!lhsBcastOrSplat ||
+ !isa<vector::BroadcastOp, vector::SplatOp>(*lhsBcastOrSplat))
return failure();
- auto firstOpType = firstBcast.getOperand().getType();
+ auto lhsBcastOrSplatType = lhsBcastOrSplat->getOperand(0).getType();
- // Make sure that operands are "broadcast"ed from identical (scalar or
- // vector) types. That indicates that it's safe to skip the broadcasting of
- // operands.
- if (!llvm::all_of(op->getOperands(), [&firstOpType](Value val) {
+ // Make sure that all operands are broadcast from identical types:
+ // * scalar (`vector.broadcast` + `vector.splat`), or
+ // * vector (`vector.broadcast`).
+ // Otherwise the re-ordering wouldn't be safe.
+ if (!llvm::all_of(op->getOperands(), [&lhsBcastOrSplatType](Value val) {
auto bcast = val.getDefiningOp<vector::BroadcastOp>();
- return (bcast && (bcast.getOperand().getType() == firstOpType));
+ if (bcast)
+ return (bcast.getOperand().getType() == lhsBcastOrSplatType);
+ auto splat = val.getDefiningOp<vector::SplatOp>();
+ if (splat)
+ return (splat.getOperand().getType() == lhsBcastOrSplatType);
+ return false;
})) {
return failure();
}
- // Collect the source values
+ // Collect the source values before broadcasting
SmallVector<Value> srcValues;
srcValues.reserve(op->getNumOperands());
-
for (Value operand : op->getOperands()) {
- srcValues.push_back(
- operand.getDefiningOp<vector::BroadcastOp>().getOperand());
+ srcValues.push_back(operand.getDefiningOp()->getOperand(0));
}
+ // Create the "elementwise" Op
Operation *elementwiseOp =
rewriter.create(op->getLoc(), op->getName().getIdentifier(), srcValues,
- firstOpType, op->getAttrs());
+ lhsBcastOrSplatType, op->getAttrs());
+ // Replace the original Op with the elementwise Op
auto vectorType = op->getResultTypes()[0];
rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
op, vectorType, elementwiseOp->getResults());
diff --git a/mlir/test/Dialect/Vector/sink-vector-broadcast.mlir b/mlir/test/Dialect/Vector/sink-vector-broadcast.mlir
index fcf9815f6f6f1d1..d9d2f44e6f16c1f 100644
--- a/mlir/test/Dialect/Vector/sink-vector-broadcast.mlir
+++ b/mlir/test/Dialect/Vector/sink-vector-broadcast.mlir
@@ -1,13 +1,12 @@
// RUN: mlir-opt %s -test-sink-vector-broadcast -split-input-file | FileCheck %s
-// CHECK-LABEL: func.func @broadcast_scalar(
+// CHECK-LABEL: func.func @broadcast_scalar_with_bcast(
// CHECK-SAME: %[[ARG_0:.*]]: index, %[[ARG_1:.*]]: index) -> vector<1x4xindex> {
// CHECK: %[[ADD:.*]] = arith.addi %[[ARG_0]], %[[ARG_1]] : index
// CHECK: %[[BCAST:.*]] = vector.broadcast %[[ADD]] : index to vector<1x4xindex>
// CHECK: return %[[BCAST]] : vector<1x4xindex>
-// CHECK: }
-func.func @broadcast_scalar( %arg1: index, %arg2: index) -> vector<1x4xindex> {
+func.func @broadcast_scalar_with_bcast( %arg1: index, %arg2: index) -> vector<1x4xindex> {
%0 = vector.broadcast %arg1 : index to vector<1x4xindex>
%1 = vector.broadcast %arg2 : index to vector<1x4xindex>
%2 = arith.addi %0, %1 : vector<1x4xindex>
@@ -16,13 +15,27 @@ func.func @broadcast_scalar( %arg1: index, %arg2: index) -> vector<1x4xindex> {
// -----
+// CHECK-LABEL: func.func @broadcast_scalar_with_bcast_and_splat(
+// CHECK-SAME: %[[ARG1:.*]]: index,
+// CHECK-SAME: %[[ARG2:.*]]: index) -> vector<1x4xindex> {
+// CHECK: %[[ADD:.*]] = arith.addi %[[ARG1]], %[[ARG2]] : index
+// CHECK: %[[BCAST:.*]] = vector.broadcast %[[ADD]] : index to vector<1x4xindex>
+// CHECK: return %[[BCAST]] : vector<1x4xindex>
+func.func @broadcast_scalar_with_bcast_and_splat( %arg1: index, %arg2: index) -> vector<1x4xindex> {
+ %0 = vector.splat %arg1 : vector<1x4xindex>
+ %1 = vector.broadcast %arg2 : index to vector<1x4xindex>
+ %2 = arith.addi %0, %1 : vector<1x4xindex>
+ return %2 : vector<1x4xindex>
+}
+
+// -----
+
// CHECK-LABEL: func.func @broadcast_vector(
// CHECK-SAME: %[[ARG_0:.*]]: vector<4xf32>,
// CHECK-SAME: %[[ARG_1:.*]]: vector<4xf32>) -> vector<3x4xf32> {
// CHECK: %[[ADDF:.*]] = arith.addf %[[ARG_0]], %[[ARG_1]] : vector<4xf32>
// CHECK: %[[BCAST:.*]] = vector.broadcast %[[ADDF]] : vector<4xf32> to vector<3x4xf32>
// CHECK: return %[[BCAST]] : vector<3x4xf32>
-// CHECK: }
func.func @broadcast_vector( %arg1: vector<4xf32>, %arg2: vector<4xf32>) -> vector<3x4xf32> {
%arg1_bcast = vector.broadcast %arg1 : vector<4xf32> to vector<3x4xf32>
@@ -30,6 +43,23 @@ func.func @broadcast_vector( %arg1: vector<4xf32>, %arg2: vector<4xf32>) -> vect
%2 = arith.addf %arg1_bcast, %arg2_bcast : vector<3x4xf32>
return %2 : vector<3x4xf32>
}
+
+// -----
+
+// CHECK-LABEL: func.func @broadcast_scalar_and_vec(
+// CHECK-SAME: %[[ARG1:.*]]: index,
+// CHECK-SAME: %[[ARG2:.*]]: vector<4xindex>) -> vector<1x4xindex> {
+// CHECK: %[[SPLAT:.*]] = vector.splat %[[ARG1]] : vector<1x4xindex>
+// CHECK: %[[BCAST:.*]] = vector.broadcast %[[ARG2]] : vector<4xindex> to vector<1x4xindex>
+// CHECK: %[[ADD:.*]] = arith.addi %[[SPLAT]], %[[BCAST]] : vector<1x4xindex>
+// CHECK: return %[[ADD]] : vector<1x4xindex>
+func.func @broadcast_scalar_and_vec( %arg1: index, %arg2: vector<4xindex>) -> vector<1x4xindex> {
+ %0 = vector.splat %arg1 : vector<1x4xindex>
+ %1 = vector.broadcast %arg2 : vector<4xindex> to vector<1x4xindex>
+ %2 = arith.addi %0, %1 : vector<1x4xindex>
+ return %2 : vector<1x4xindex>
+}
+
// -----
// CHECK-LABEL: func.func @broadcast_vector_and_scalar(
@@ -38,7 +68,6 @@ func.func @broadcast_vector( %arg1: vector<4xf32>, %arg2: vector<4xf32>) -> vect
// CHECK: %[[BCAST:.*]] = vector.broadcast %[[ARG_0]] : i32 to vector<4xi32>
// CHECK: %[[ADD:.*]] = arith.addi %[[BCAST]], %[[ARG_1]] : vector<4xi32>
// CHECK: return %[[ADD]] : vector<4xi32>
-// CHECK: }
func.func @broadcast_vector_and_scalar( %arg1: i32, %arg2: vector<4xi32>) -> vector<4xi32> {
%arg1_bcast = vector.broadcast %arg1 : i32 to vector<4xi32>
|
I guess the first question that comes to mind is... why do we have a |
Yup, was thinking the same the other day. I can try merging |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
SG then! Let's land this if this is blocking your right now
Extend
ReorderElementwiseOpsOnBroadcast
so that the broadcasting opcould be either
vector.broadcast
(already supported) as well asvector.splat
(support added in this patch).