Skip to content

Commit

Permalink
Reapply "Propagate reshapes through generics with reduction… (iree-or…
Browse files Browse the repository at this point in the history
…g#18968)

This reverts commit 8d3faf8.

Signed-off-by: Ian Wood <[email protected]>
  • Loading branch information
IanWood1 committed Nov 13, 2024
1 parent 2bfc639 commit 6dd0bcc
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 15 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/pkgci_regression_test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ jobs:
--goldentime-rocm-vae-ms 337.0 \
--goldendispatch-rocm-unet 1531 \
--goldendispatch-rocm-clip 1141 \
--goldendispatch-rocm-vae 246 \
--goldendispatch-rocm-vae 245 \
--goldensize-rocm-unet-bytes 2280000 \
--goldensize-rocm-clip-bytes 860000 \
--goldensize-rocm-vae-bytes 840000 \
Expand All @@ -243,7 +243,7 @@ jobs:
--goldentime-rocm-vae-ms 80.0 \
--goldendispatch-rocm-unet 1531 \
--goldendispatch-rocm-clip 1141 \
--goldendispatch-rocm-vae 246 \
--goldendispatch-rocm-vae 245 \
--goldensize-rocm-unet-bytes 2270000 \
--goldensize-rocm-clip-bytes 860000 \
--goldensize-rocm-vae-bytes 840000 \
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,13 +80,13 @@ util.func public @grouped_quantized_matmul(%arg0: tensor<4096x32x128xi4>, %arg1:
// CHECK: flow.executable private @[[EXECUTABLE0:[a-zA-Z0-9_]+]]
// CHECK: func.func @[[FUNC0:[a-zA-Z0-9_x]+]]
// CHECK: %[[GEN0:.+]] = linalg.generic
// CHECK-SAME: ["parallel", "parallel", "parallel"]
// CHECK-SAME: ["parallel", "parallel", "parallel", "parallel", "parallel"]
// CHECK: arith.extui
// CHECK: arith.uitofp
// CHECK: arith.subf
// CHECK: arith.mulf
// CHECK: %[[GEN1:.+]] = linalg.generic
// CHECK-SAME: ["parallel", "reduction", "reduction"]
// CHECK-SAME: ["parallel", "parallel", "parallel", "reduction", "reduction"]
// CHECK-SAME: ins(
// CHECK-SAME: %[[GEN0]]
// CHECK-SAME: outs(
Expand All @@ -95,5 +95,4 @@ util.func public @grouped_quantized_matmul(%arg0: tensor<4096x32x128xi4>, %arg1:
// CHECK: flow.dispatch.tensor.store %[[GEN1]]
// CHECK: util.func public @grouped_quantized_matmul(
// CHECK: %[[T0:.+]] = flow.dispatch @[[EXECUTABLE0]]::@[[FUNC0]]
// CHECK: %[[RS:.+]] = flow.tensor.reshape %[[T0]] : tensor<4096xf32> -> tensor<1x1x4096xf32>
// CHECK: util.return %[[RS]]
// CHECK: util.return %[[T0]]
Original file line number Diff line number Diff line change
Expand Up @@ -57,24 +57,18 @@ void BubbleUpExpandShapesPass::runOnOperation() {
return false;
}

// Do not fuse producer generic op if it has more than one user
// or any reduction iterators.
if (auto producerGenericOp = dyn_cast<linalg::GenericOp>(producer)) {
return producerGenericOp->hasOneUse() &&
llvm::all_of(producerGenericOp.getIteratorTypesArray(),
linalg::isParallelIterator);
return true;
}

// Do not fuse with any producer linalg named ops for now.
if (isa<linalg::LinalgOp>(producer)) {
return false;
}

// Do not fuse with consumer linalg named ops or reductions.
// Do not fuse with consumer linalg named ops.
if (auto consumerLinalgOp = dyn_cast<linalg::LinalgOp>(consumer)) {
return isa<linalg::GenericOp>(consumerLinalgOp) &&
llvm::all_of(consumerLinalgOp.getIteratorTypesArray(),
linalg::isParallelIterator);
return isa<linalg::GenericOp>(consumerLinalgOp);
}
// Fuse in all other cases.
return true;
Expand Down

0 comments on commit 6dd0bcc

Please sign in to comment.