Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Generalize constants #806

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -295,6 +295,10 @@ add_executable(souper-check
tools/souper-check.cpp
)

add_executable(generalize
tools/generalize.cpp
)

add_executable(souper-interpret
tools/souper-interpret.cpp
)
Expand Down Expand Up @@ -362,7 +366,7 @@ configure_file(
)

foreach(target souper internal-solver-test lexer-test parser-test souper-check count-insts
souper2llvm souper-interpret
souper2llvm souper-interpret generalize
souperExtractor souperInfer souperInst souperKVStore souperParser
souperSMTLIB2 souperTool souperPass souperPassProfileAll kleeExpr
souperCodegen)
Expand Down Expand Up @@ -400,6 +404,7 @@ target_link_libraries(internal-solver-test souperSMTLIB2)
target_link_libraries(lexer-test souperParser)
target_link_libraries(parser-test souperParser)
target_link_libraries(souper-check souperTool souperExtractor souperKVStore souperSMTLIB2 souperParser ${HIREDIS_LIBRARY} ${ALIVE_LIBRARY} ${Z3_LIBRARY})
target_link_libraries(generalize souperTool souperExtractor souperKVStore souperSMTLIB2 souperParser ${HIREDIS_LIBRARY} ${ALIVE_LIBRARY} ${Z3_LIBRARY})
target_link_libraries(souper-interpret souperTool souperExtractor souperKVStore souperSMTLIB2 souperParser ${HIREDIS_LIBRARY} ${ALIVE_LIBRARY} ${Z3_LIBRARY})
target_link_libraries(clang-souper souperClangTool souperExtractor souperKVStore souperParser souperSMTLIB2 souperTool kleeExpr ${CLANG_LIBS} ${LLVM_LIBS} ${LLVM_LDFLAGS} ${HIREDIS_LIBRARY} ${ALIVE_LIBRARY} ${Z3_LIBRARY})
target_link_libraries(count-insts souperParser)
Expand Down
10 changes: 8 additions & 2 deletions include/souper/Extractor/Solver.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,12 @@ class Solver {
InstMapping Mapping, bool &IsValid,
std::vector<std::pair<Inst *, llvm::APInt>> *Model) = 0;

virtual std::error_code
isSatisfiable(llvm::StringRef Query, bool &Result,
unsigned NumModels,
std::vector<llvm::APInt> *Models,
unsigned Timeout = 0) = 0;

virtual std::string getName() = 0;

virtual
Expand Down Expand Up @@ -90,8 +96,8 @@ class Solver {
virtual
std::error_code abstractPrecondition(const BlockPCs &BPCs,
const std::vector<InstMapping> &PCs,
InstMapping &Mapping, InstContext &IC,
bool &FoundWeakest) = 0;
InstMapping &Mapping, InstContext &IC, bool &FoundWeakest,
std::vector<std::map<Inst *, llvm::KnownBits>> &Results) = 0;
};

std::unique_ptr<Solver> createBaseSolver(
Expand Down
6 changes: 6 additions & 0 deletions include/souper/Infer/EnumerativeSynthesis.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ namespace souper {

class EnumerativeSynthesis {
public:
EnumerativeSynthesis();

// Synthesize an instruction from the specification in LHS
std::error_code synthesize(SMTLIBSolver *SMTSolver,
const BlockPCs &BPCs,
Expand All @@ -38,6 +40,10 @@ class EnumerativeSynthesis {
bool CheckAllGuesses,
InstContext &IC, unsigned Timeout);

std::vector<Inst *>
generateExprs(InstContext &IC, size_t CountLimit,
std::vector<Inst *> Vars, size_t Width);

};
}

Expand Down
2 changes: 1 addition & 1 deletion include/souper/Infer/Preconditions.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ class SMTLIBSolver;
class Solver;
std::vector<std::map<Inst *, llvm::KnownBits>>
inferAbstractKBPreconditions(SynthesisContext &SC, Inst *RHS,
SMTLIBSolver *SMTSolver, Solver *S, bool &FoundWeakest);
Solver *S, bool &FoundWeakest);
}

#endif // SOUPER_PRECONDITIONS_H
66 changes: 30 additions & 36 deletions lib/Extractor/Solver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -265,39 +265,12 @@ class BaseSolver : public Solver {

std::error_code abstractPrecondition(const BlockPCs &BPCs,
const std::vector<InstMapping> &PCs,
InstMapping &Mapping, InstContext &IC,
bool &FoundWeakest) override {
InstMapping &Mapping, InstContext &IC, bool &FoundWeakest,
std::vector<std::map<Inst *, llvm::KnownBits>> &Results) override {
SynthesisContext SC{IC, SMTSolver.get(), Mapping.LHS, /*LHSUB*/nullptr, PCs,
BPCs, /*CheckAllGuesses=*/false, Timeout};

std::vector<std::map<Inst *, llvm::KnownBits>> Results =
inferAbstractKBPreconditions(SC, Mapping.RHS, SMTSolver.get(), this, FoundWeakest);

ReplacementContext RC;
auto LHSStr = RC.printInst(Mapping.LHS, llvm::outs(), true);
llvm::outs() << "infer " << LHSStr << "\n";
auto RHSStr = RC.printInst(Mapping.RHS, llvm::outs(), true);
llvm::outs() << "result " << RHSStr << "\n";
for (size_t i = 0; i < Results.size(); ++i) {
for (auto It = Results[i].begin(); It != Results[i].end(); ++It) {
auto &&P = *It;
std::string dummy;
llvm::raw_string_ostream str(dummy);
auto VarStr = RC.printInst(P.first, str, false);
llvm::outs() << VarStr << " -> " << Inst::getKnownBitsString(P.second.Zero, P.second.One);

auto Next = It;
Next++;
if (Next != Results[i].end()) {
llvm::outs() << " (and) ";
}
}
if (i == Results.size() - 1) {
llvm::outs() << "\n";
} else {
llvm::outs() << "\n(or)\n";
}
}
Results = inferAbstractKBPreconditions(SC, Mapping.RHS, this, FoundWeakest);
return {};
}

Expand Down Expand Up @@ -461,6 +434,13 @@ class BaseSolver : public Solver {
return EC;
}

std::error_code isSatisfiable(llvm::StringRef Query, bool &Result,
unsigned NumModels,
std::vector<llvm::APInt> *Models,
unsigned Timeout = 0) override {
return SMTSolver->isSatisfiable(Query, Result, NumModels, Models, Timeout);
}

std::error_code isValid(InstContext &IC, const BlockPCs &BPCs,
const std::vector<InstMapping> &PCs,
InstMapping Mapping, bool &IsValid,
Expand Down Expand Up @@ -717,6 +697,13 @@ class MemCachingSolver : public Solver {
}
}

std::error_code isSatisfiable(llvm::StringRef Query, bool &Result,
unsigned NumModels,
std::vector<llvm::APInt> *Models,
unsigned Timeout = 0) override {
return UnderlyingSolver->isSatisfiable(Query, Result, NumModels, Models, Timeout);
}

std::string getName() override {
return UnderlyingSolver->getName() + " + internal cache";
}
Expand Down Expand Up @@ -745,9 +732,9 @@ class MemCachingSolver : public Solver {

std::error_code abstractPrecondition(const BlockPCs &BPCs,
const std::vector<InstMapping> &PCs,
InstMapping &Mapping, InstContext &IC,
bool &FoundWeakest) override {
return UnderlyingSolver->abstractPrecondition(BPCs, PCs, Mapping, IC, FoundWeakest);
InstMapping &Mapping, InstContext &IC, bool &FoundWeakest,
std::vector<std::map<Inst *, llvm::KnownBits>> &Results) override {
return UnderlyingSolver->abstractPrecondition(BPCs, PCs, Mapping, IC, FoundWeakest, Results);
}

std::error_code knownBits(const BlockPCs &BPCs,
Expand Down Expand Up @@ -847,6 +834,13 @@ class ExternalCachingSolver : public Solver {
return UnderlyingSolver->constantRange(BPCs, PCs, LHS, IC);
}

std::error_code isSatisfiable(llvm::StringRef Query, bool &Result,
unsigned NumModels,
std::vector<llvm::APInt> *Models,
unsigned Timeout = 0) override {
return UnderlyingSolver->isSatisfiable(Query, Result, NumModels, Models, Timeout);
}

std::error_code isValid(InstContext &IC, const BlockPCs &BPCs,
const std::vector<InstMapping> &PCs,
InstMapping Mapping, bool &IsValid,
Expand Down Expand Up @@ -885,9 +879,9 @@ class ExternalCachingSolver : public Solver {

std::error_code abstractPrecondition(const BlockPCs &BPCs,
const std::vector<InstMapping> &PCs,
InstMapping &Mapping, InstContext &IC,
bool &FoundWeakest) override {
return UnderlyingSolver->abstractPrecondition(BPCs, PCs, Mapping, IC, FoundWeakest);
InstMapping &Mapping, InstContext &IC, bool &FoundWeakest,
std::vector<std::map<Inst *, llvm::KnownBits>> &Results) override {
return UnderlyingSolver->abstractPrecondition(BPCs, PCs, Mapping, IC, FoundWeakest, Results);
}

std::error_code knownBits(const BlockPCs &BPCs,
Expand Down
45 changes: 43 additions & 2 deletions lib/Infer/EnumerativeSynthesis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,8 @@ extern unsigned DebugLevel;
using namespace souper;
using namespace llvm;

static const std::vector<Inst::Kind> UnaryOperators = {
Inst::CtPop, Inst::BSwap, Inst::BitReverse, Inst::Cttz, Inst::Ctlz, Inst::Freeze
static std::vector<Inst::Kind> UnaryOperators = {
Inst::CtPop, Inst::BSwap, Inst::BitReverse, Inst::Cttz, Inst::Ctlz
};

static const std::vector<Inst::Kind> BinaryOperators = {
Expand Down Expand Up @@ -91,6 +91,9 @@ namespace {
static cl::opt<bool> IgnoreCost("souper-enumerative-synthesis-ignore-cost",
cl::desc("Ignore cost of RHSes -- just generate them. (default=false)"),
cl::init(false));
static cl::opt<bool> SynFreeze("souper-synthesize-freeze",
cl::desc("Generate Freeze (default=true)"),
cl::init(true));
static cl::opt<unsigned> MaxLHSCands("souper-max-lhs-cands",
cl::desc("Gather at most this many values from a LHS to use as synthesis inputs (default=8)"),
cl::init(8));
Expand Down Expand Up @@ -881,3 +884,41 @@ EnumerativeSynthesis::synthesize(SMTLIBSolver *SMTSolver,

return EC;
}

EnumerativeSynthesis::EnumerativeSynthesis() {
if (SynFreeze) {
UnaryOperators.push_back(Inst::Freeze);
}
}

std::vector<Inst *>
EnumerativeSynthesis::generateExprs(InstContext &IC, size_t CountLimit,
std::vector<Inst *> Vars, size_t Width) {
MaxNumInstructions = CountLimit;

std::set<Inst*> Visited;
std::vector<PruneFunc> PruneFuncs = { [&Visited](Inst *I, std::vector<Inst*> &ReservedInsts) {
return CountPrune(I, ReservedInsts, Visited);
}};
auto PruneCallback = MkPruneFunc(PruneFuncs);

std::vector<Inst *> Guesses;

int TooExpensive = CountLimit + 1;

for (auto I : Vars) {
if (I->Width == Width)
addGuess(I, Width, IC, TooExpensive, Guesses, TooExpensive);
}

auto Generate = [&Guesses](Inst *Guess) {
Guesses.push_back(Guess);
return true;
};

getGuesses(Vars, Width, TooExpensive, IC, nullptr,
nullptr, TooExpensive, PruneCallback, Generate);

return Guesses;
}

9 changes: 6 additions & 3 deletions lib/Infer/Preconditions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ using llvm::APInt;
namespace souper {
std::vector<std::map<Inst *, llvm::KnownBits>>
inferAbstractKBPreconditions(SynthesisContext &SC, Inst *RHS,
SMTLIBSolver *SMTSolver, Solver *S, bool &FoundWeakest) {
Solver *S, bool &FoundWeakest) {
InstMapping Mapping(SC.LHS, RHS);
bool Valid;
if (DebugLevel >= 3) {
Expand All @@ -20,7 +20,10 @@ std::vector<std::map<Inst *, llvm::KnownBits>>
}
std::vector<InstMapping> PCCopy = SC.PCs;
if (Valid) {
llvm::outs() << "Already valid.\n";
FoundWeakest = true;
if (DebugLevel > 1) {
llvm::errs() << "Already valid.\n";
}
return {};
}

Expand Down Expand Up @@ -97,7 +100,7 @@ std::vector<std::map<Inst *, llvm::KnownBits>>
&ModelInsts, Precondition, true);


SMTSolver->isSatisfiable(Query, FoundWeakest, ModelInsts.size(),
S->isSatisfiable(Query, FoundWeakest, ModelInsts.size(),
&ModelVals, SC.Timeout);

std::map<Inst *, llvm::KnownBits> Known;
Expand Down
42 changes: 42 additions & 0 deletions test/Generalize/fixit.opt
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
; REQUIRES: solver, synthesis
; RUN: %generalize -fixit %s | %souper-check > %t
; RUN: %FileCheck %s < %t

%x:i8 = var
%y:i8 = var
%z = add %x, %y
%t = add %z, 42
%u = sub %t, %y
infer %u
%v = add %x, 42
result %v
;CHECK: LGTM

%x:i8 = var
%y:i8 = var
%t = add %x, 42
%u = sub %t, %y
infer %u
%v = add %x, 42
result %v
;CHECK: LGTM

%x:i8 = var
%y:i8 = var
%t = and %x, 137
%u = xor %t, %y
infer %u
%v = or %x, %y
result %v
;CHECK: LGTM
;CHECK-NEXT: LGTM

%x:i8 = var
%y:i8 = var
%t = or %x, 42
%u = and %t, %y
infer %u
%v = and %x, %y
result %v
;CHECK: LGTM
;CHECK-NEXT: LGTM
22 changes: 22 additions & 0 deletions test/Generalize/leaf.opt
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
; REQUIRES: solver, synthesis
; RUN: %generalize -remove-leaf %s | %souper-check > %t
; RUN: %FileCheck %s < %t

%x:i8 = var
%y:i8 = var
%masked = and %x, 3
%and = and %masked, %y
%foo = lshr %and, 2
infer %and
result 0:i8
; CHECK: LGTM
; CHECK: LGTM

%x:i8 = var
%y:i8 = var
%a = and %x, 15
%b = and %y, 240
%foo = or %a, %b
infer %foo
result 0:i8
; CHECK: LGTM
11 changes: 11 additions & 0 deletions test/Generalize/symbolize.opt
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
; REQUIRES: solver, synthesis
; RUN: %generalize -symbolize --souper-synthesize-freeze=false --generalization-num-results=2 %s | %souper-check > %t
; RUN: %FileCheck %s < %t

%x:i8 = var
%foo = add %x, 2
%bar = sub %foo, %x
infer %bar
result 2:i8
;CHECK: LGTM
;CHECK: LGTM
Loading