Skip to content

Commit

Permalink
[SandboxVec][Legality] Pack from different BBs (#124363)
Browse files Browse the repository at this point in the history
When the inputs of the pack come from different BBs we need to make sure
we emit the pack instructions at the correct place.
  • Loading branch information
vporpo authored Jan 24, 2025
1 parent 280c7d7 commit 6409799
Show file tree
Hide file tree
Showing 4 changed files with 62 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ enum class ResultReason {
DiffTypes,
DiffMathFlags,
DiffWrapFlags,
DiffBBs,
NotConsecutive,
CantSchedule,
Unimplemented,
Expand Down Expand Up @@ -127,6 +128,8 @@ struct ToStr {
return "DiffMathFlags";
case ResultReason::DiffWrapFlags:
return "DiffWrapFlags";
case ResultReason::DiffBBs:
return "DiffBBs";
case ResultReason::NotConsecutive:
return "NotConsecutive";
case ResultReason::CantSchedule:
Expand Down
5 changes: 5 additions & 0 deletions llvm/lib/Transforms/Vectorize/SandboxVectorizer/Legality.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,11 @@ const LegalityResult &LegalityAnalysis::canVectorize(ArrayRef<Value *> Bndl,
dumpBndl(Bndl););
return createLegalityResult<Pack>(ResultReason::NotInstructions);
}
// Pack if not in the same BB.
auto *BB = cast<Instruction>(Bndl[0])->getParent();
if (any_of(drop_begin(Bndl),
[BB](auto *V) { return cast<Instruction>(V)->getParent() != BB; }))
return createLegalityResult<Pack>(ResultReason::DiffBBs);

auto CollectDescrs = getHowToCollectValues(Bndl);
if (CollectDescrs.hasVectorInputs()) {
Expand Down
27 changes: 27 additions & 0 deletions llvm/test/Transforms/SandboxVectorizer/pack.ll
Original file line number Diff line number Diff line change
Expand Up @@ -88,3 +88,30 @@ loop:
exit:
ret void
}

define void @packFromDiffBBs(ptr %ptr, i8 %v) {
; CHECK-LABEL: define void @packFromDiffBBs(
; CHECK-SAME: ptr [[PTR:%.*]], i8 [[V:%.*]]) {
; CHECK-NEXT: [[ENTRY:.*:]]
; CHECK-NEXT: [[ADD0:%.*]] = add i8 [[V]], 1
; CHECK-NEXT: br label %[[BB:.*]]
; CHECK: [[BB]]:
; CHECK-NEXT: [[ADD1:%.*]] = add i8 [[V]], 2
; CHECK-NEXT: [[PACK:%.*]] = insertelement <2 x i8> poison, i8 [[ADD0]], i32 0
; CHECK-NEXT: [[PACK1:%.*]] = insertelement <2 x i8> [[PACK]], i8 [[ADD1]], i32 1
; CHECK-NEXT: [[GEP0:%.*]] = getelementptr i8, ptr [[PTR]], i64 0
; CHECK-NEXT: store <2 x i8> [[PACK1]], ptr [[GEP0]], align 1
; CHECK-NEXT: ret void
;
entry:
%add0 = add i8 %v, 1
br label %bb

bb:
%add1 = add i8 %v, 2
%gep0 = getelementptr i8, ptr %ptr, i64 0
%gep1 = getelementptr i8, ptr %ptr, i64 1
store i8 %add0, ptr %gep0
store i8 %add1, ptr %gep1
ret void
}
Original file line number Diff line number Diff line change
Expand Up @@ -57,11 +57,24 @@ struct LegalityTest : public testing::Test {
}
};

static sandboxir::BasicBlock *getBasicBlockByName(sandboxir::Function *F,
StringRef Name) {
for (sandboxir::BasicBlock &BB : *F)
if (BB.getName() == Name)
return &BB;
llvm_unreachable("Expected to find basic block!");
}

TEST_F(LegalityTest, LegalitySkipSchedule) {
parseIR(C, R"IR(
define void @foo(ptr %ptr, <2 x float> %vec2, <3 x float> %vec3, i8 %arg, float %farg0, float %farg1, i64 %v0, i64 %v1, i32 %v2) {
entry:
%gep0 = getelementptr float, ptr %ptr, i32 0
%gep1 = getelementptr float, ptr %ptr, i32 1
store float %farg0, ptr %gep1
br label %bb
bb:
%gep3 = getelementptr float, ptr %ptr, i32 3
%ld0 = load float, ptr %gep0
%ld0b = load float, ptr %gep0
Expand Down Expand Up @@ -89,10 +102,14 @@ define void @foo(ptr %ptr, <2 x float> %vec2, <3 x float> %vec3, i8 %arg, float

sandboxir::Context Ctx(C);
auto *F = Ctx.createFunction(LLVMF);
auto *BB = &*F->begin();
auto It = BB->begin();
auto *EntryBB = getBasicBlockByName(F, "entry");
auto It = EntryBB->begin();
[[maybe_unused]] auto *Gep0 = cast<sandboxir::GetElementPtrInst>(&*It++);
[[maybe_unused]] auto *Gep1 = cast<sandboxir::GetElementPtrInst>(&*It++);
auto *St1Entry = cast<sandboxir::StoreInst>(&*It++);

auto *BB = getBasicBlockByName(F, "bb");
It = BB->begin();
[[maybe_unused]] auto *Gep3 = cast<sandboxir::GetElementPtrInst>(&*It++);
auto *Ld0 = cast<sandboxir::LoadInst>(&*It++);
auto *Ld0b = cast<sandboxir::LoadInst>(&*It++);
Expand Down Expand Up @@ -162,6 +179,14 @@ define void @foo(ptr %ptr, <2 x float> %vec2, <3 x float> %vec3, i8 %arg, float
EXPECT_EQ(cast<sandboxir::Pack>(Result).getReason(),
sandboxir::ResultReason::DiffWrapFlags);
}
{
// Check DiffBBs
const auto &Result =
Legality.canVectorize({St0, St1Entry}, /*SkipScheduling=*/true);
EXPECT_TRUE(isa<sandboxir::Pack>(Result));
EXPECT_EQ(cast<sandboxir::Pack>(Result).getReason(),
sandboxir::ResultReason::DiffBBs);
}
{
// Check DiffTypes for unary operands that have a different type.
const auto &Result = Legality.canVectorize({Trunc64to8, Trunc32to8},
Expand Down

0 comments on commit 6409799

Please sign in to comment.