Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mlir][vector] Make ReorderElementwiseOpsOnBroadcast support vector.splat #66596

Merged

Conversation

banach-space
Copy link
Contributor

@banach-space banach-space commented Sep 17, 2023

Extend ReorderElementwiseOpsOnBroadcast so that the broadcasting op
could be either vector.broadcast (already supported) as well as
vector.splat (support added in this patch).

…plat

Extend `ReorderElementwiseOpsOnBroadcast` so that the broadcastinvg op
could be either `vector.broadcast` (alrady supported) as well as
`vector.splat` (support added in this patch).
@llvmbot
Copy link
Member

llvmbot commented Sep 17, 2023

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-vector

Changes

Extend ReorderElementwiseOpsOnBroadcast so that the broadcasting op
could be either vector.broadcast (already supported) as well as
vector.splat (support added in this patch).


Full diff: https://github.com/llvm/llvm-project/pull/66596.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp (+25-15)
  • (modified) mlir/test/Dialect/Vector/sink-vector-broadcast.mlir (+34-5)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index 207df69929c1c9f..b2a5aef5ee62d0f 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -880,7 +880,7 @@ struct BreakDownVectorBitCast : public OpRewritePattern<vector::BitCastOp> {
   std::function<bool(BitCastOp)> controlFn;
 };
 
-/// Reorders elementwise(broadcast) to broadcast(elementwise). Ex:
+/// Reorders elementwise(broadcast/splat) to broadcast(elementwise). Ex:
 /// ```
 /// %a = vector.broadcast %arg1 : index to vector<1x4xindex>
 /// %b = vector.broadcast %arg2 : index to vector<1x4xindex>
@@ -891,6 +891,9 @@ struct BreakDownVectorBitCast : public OpRewritePattern<vector::BitCastOp> {
 /// %r = arith.addi %arg0, %arg1 : index
 /// %b = vector.broadcast %r : index to vector<1x4xindex>
 /// ```
+///
+/// Both `vector.broadcast` and `vector.splat` are supported as broadcasting
+/// ops.
 struct ReorderElementwiseOpsOnBroadcast final
     : public OpTraitRewritePattern<OpTrait::Elementwise> {
   using OpTraitRewritePattern::OpTraitRewritePattern;
@@ -903,35 +906,42 @@ struct ReorderElementwiseOpsOnBroadcast final
     if (!OpTrait::hasElementwiseMappableTraits(op))
       return failure();
 
-    // Get the type of the first operand
-    auto firstBcast = op->getOperand(0).getDefiningOp<vector::BroadcastOp>();
-    if (!firstBcast)
+    // Get the type of the lhs operand
+    auto *lhsBcastOrSplat = op->getOperand(0).getDefiningOp();
+    if (!lhsBcastOrSplat ||
+        !isa<vector::BroadcastOp, vector::SplatOp>(*lhsBcastOrSplat))
       return failure();
-    auto firstOpType = firstBcast.getOperand().getType();
+    auto lhsBcastOrSplatType = lhsBcastOrSplat->getOperand(0).getType();
 
-    // Make sure that operands are "broadcast"ed from identical (scalar or
-    // vector) types. That indicates that it's safe to skip the broadcasting of
-    // operands.
-    if (!llvm::all_of(op->getOperands(), [&firstOpType](Value val) {
+    // Make sure that all operands are broadcast from identical types:
+    //  * scalar (`vector.broadcast` + `vector.splat`), or
+    //  * vector (`vector.broadcast`).
+    // Otherwise the re-ordering wouldn't be safe.
+    if (!llvm::all_of(op->getOperands(), [&lhsBcastOrSplatType](Value val) {
           auto bcast = val.getDefiningOp<vector::BroadcastOp>();
-          return (bcast && (bcast.getOperand().getType() == firstOpType));
+          if (bcast)
+            return (bcast.getOperand().getType() == lhsBcastOrSplatType);
+          auto splat = val.getDefiningOp<vector::SplatOp>();
+          if (splat)
+            return (splat.getOperand().getType() == lhsBcastOrSplatType);
+          return false;
         })) {
       return failure();
     }
 
-    // Collect the source values
+    // Collect the source values before broadcasting
     SmallVector<Value> srcValues;
     srcValues.reserve(op->getNumOperands());
-
     for (Value operand : op->getOperands()) {
-      srcValues.push_back(
-          operand.getDefiningOp<vector::BroadcastOp>().getOperand());
+      srcValues.push_back(operand.getDefiningOp()->getOperand(0));
     }
 
+    // Create the "elementwise" Op
     Operation *elementwiseOp =
         rewriter.create(op->getLoc(), op->getName().getIdentifier(), srcValues,
-                        firstOpType, op->getAttrs());
+                        lhsBcastOrSplatType, op->getAttrs());
 
+    // Replace the original Op with the elementwise Op
     auto vectorType = op->getResultTypes()[0];
     rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
         op, vectorType, elementwiseOp->getResults());
diff --git a/mlir/test/Dialect/Vector/sink-vector-broadcast.mlir b/mlir/test/Dialect/Vector/sink-vector-broadcast.mlir
index fcf9815f6f6f1d1..d9d2f44e6f16c1f 100644
--- a/mlir/test/Dialect/Vector/sink-vector-broadcast.mlir
+++ b/mlir/test/Dialect/Vector/sink-vector-broadcast.mlir
@@ -1,13 +1,12 @@
 // RUN: mlir-opt %s -test-sink-vector-broadcast -split-input-file | FileCheck %s
 
-// CHECK-LABEL:   func.func @broadcast_scalar(
+// CHECK-LABEL:   func.func @broadcast_scalar_with_bcast(
 // CHECK-SAME:     %[[ARG_0:.*]]: index, %[[ARG_1:.*]]: index) -> vector<1x4xindex> {
 // CHECK:           %[[ADD:.*]] = arith.addi %[[ARG_0]], %[[ARG_1]] : index
 // CHECK:           %[[BCAST:.*]] = vector.broadcast %[[ADD]] : index to vector<1x4xindex>
 // CHECK:           return %[[BCAST]] : vector<1x4xindex>
-// CHECK:         }
 
-func.func @broadcast_scalar( %arg1: index, %arg2: index) -> vector<1x4xindex> {
+func.func @broadcast_scalar_with_bcast( %arg1: index, %arg2: index) -> vector<1x4xindex> {
   %0 = vector.broadcast %arg1 : index to vector<1x4xindex>
   %1 = vector.broadcast %arg2 : index to vector<1x4xindex>
   %2 = arith.addi %0, %1 : vector<1x4xindex>
@@ -16,13 +15,27 @@ func.func @broadcast_scalar( %arg1: index, %arg2: index) -> vector<1x4xindex> {
 
 // -----
 
+// CHECK-LABEL:   func.func @broadcast_scalar_with_bcast_and_splat(
+// CHECK-SAME:      %[[ARG1:.*]]: index,
+// CHECK-SAME:      %[[ARG2:.*]]: index) -> vector<1x4xindex> {
+// CHECK:           %[[ADD:.*]] = arith.addi %[[ARG1]], %[[ARG2]] : index
+// CHECK:           %[[BCAST:.*]] = vector.broadcast %[[ADD]] : index to vector<1x4xindex>
+// CHECK:           return %[[BCAST]] : vector<1x4xindex>
+func.func @broadcast_scalar_with_bcast_and_splat( %arg1: index, %arg2: index) -> vector<1x4xindex> {
+  %0 = vector.splat %arg1 : vector<1x4xindex>
+  %1 = vector.broadcast %arg2 : index to vector<1x4xindex>
+  %2 = arith.addi %0, %1 : vector<1x4xindex>
+  return %2 : vector<1x4xindex>
+}
+
+// -----
+
 // CHECK-LABEL:   func.func @broadcast_vector(
 // CHECK-SAME:      %[[ARG_0:.*]]: vector<4xf32>,
 // CHECK-SAME:      %[[ARG_1:.*]]: vector<4xf32>) -> vector<3x4xf32> {
 // CHECK:           %[[ADDF:.*]] = arith.addf %[[ARG_0]], %[[ARG_1]] : vector<4xf32>
 // CHECK:           %[[BCAST:.*]] = vector.broadcast %[[ADDF]] : vector<4xf32> to vector<3x4xf32>
 // CHECK:           return %[[BCAST]] : vector<3x4xf32>
-// CHECK:         }
 
 func.func @broadcast_vector( %arg1: vector<4xf32>, %arg2: vector<4xf32>) -> vector<3x4xf32> {
   %arg1_bcast = vector.broadcast %arg1 : vector<4xf32> to vector<3x4xf32>
@@ -30,6 +43,23 @@ func.func @broadcast_vector( %arg1: vector<4xf32>, %arg2: vector<4xf32>) -> vect
   %2 = arith.addf %arg1_bcast, %arg2_bcast : vector<3x4xf32>
   return %2 : vector<3x4xf32>
 }
+
+// -----
+
+// CHECK-LABEL:   func.func @broadcast_scalar_and_vec(
+// CHECK-SAME:       %[[ARG1:.*]]: index,
+// CHECK-SAME:       %[[ARG2:.*]]: vector<4xindex>) -> vector<1x4xindex> {
+// CHECK:            %[[SPLAT:.*]] = vector.splat %[[ARG1]] : vector<1x4xindex>
+// CHECK:            %[[BCAST:.*]] = vector.broadcast %[[ARG2]] : vector<4xindex> to vector<1x4xindex>
+// CHECK:            %[[ADD:.*]] = arith.addi %[[SPLAT]], %[[BCAST]] : vector<1x4xindex>
+// CHECK:            return %[[ADD]] : vector<1x4xindex>
+func.func @broadcast_scalar_and_vec( %arg1: index, %arg2: vector<4xindex>) -> vector<1x4xindex> {
+  %0 = vector.splat %arg1 : vector<1x4xindex>
+  %1 = vector.broadcast %arg2 : vector<4xindex> to vector<1x4xindex>
+  %2 = arith.addi %0, %1 : vector<1x4xindex>
+  return %2 : vector<1x4xindex>
+}
+
 // -----
 
 // CHECK-LABEL:   func.func @broadcast_vector_and_scalar(
@@ -38,7 +68,6 @@ func.func @broadcast_vector( %arg1: vector<4xf32>, %arg2: vector<4xf32>) -> vect
 // CHECK:           %[[BCAST:.*]] = vector.broadcast %[[ARG_0]] : i32 to vector<4xi32>
 // CHECK:           %[[ADD:.*]] = arith.addi %[[BCAST]], %[[ARG_1]] : vector<4xi32>
 // CHECK:           return %[[ADD]] : vector<4xi32>
-// CHECK:         }
 
 func.func @broadcast_vector_and_scalar( %arg1: i32, %arg2: vector<4xi32>) -> vector<4xi32> {
   %arg1_bcast = vector.broadcast %arg1 : i32 to vector<4xi32>

@dcaballe
Copy link
Contributor

I guess the first question that comes to mind is... why do we have a vector.broadcast and a vector.splat? :)
Isn't the former a superset of the latter? Is the latter just legacy?

@banach-space
Copy link
Contributor Author

I guess the first question that comes to mind is... why do we have a vector.broadcast and a vector.splat? :) Isn't the former a superset of the latter? Is the latter just legacy?

Yup, was thinking the same the other day. I can try merging vector.broadcast and vector.splat in a separate patch. Shall I land this anyway?

Copy link
Contributor

@dcaballe dcaballe left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

SG then! Let's land this if this is blocking your right now

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants