diff --git a/CMakeLists.txt b/CMakeLists.txt index e431055c6..c6f04a606 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -218,6 +218,12 @@ set(SOUPER_INFER_FILES include/souper/Infer/Interpreter.h lib/Infer/Preconditions.cpp include/souper/Infer/Preconditions.h + lib/Infer/Z3Expr.cpp + include/souper/Infer/Z3Expr.h + lib/Infer/Z3Driver.cpp + include/souper/Infer/Z3Driver.h + lib/Infer/Verification.cpp + include/souper/Infer/Verification.h ) add_library(souperInfer STATIC diff --git a/include/souper/Infer/Verification.h b/include/souper/Infer/Verification.h new file mode 100644 index 000000000..1f3883a5b --- /dev/null +++ b/include/souper/Infer/Verification.h @@ -0,0 +1,52 @@ +#ifndef SOUPER_INFER_VERIFICATION_H +#define SOUPER_INFER_VERIFICATION_H +#include "souper/Inst/Inst.h" + +namespace souper { + +struct RefinementProblem { + souper::Inst *LHS; + souper::Inst *RHS; + souper::Inst *Pre; + BlockPCs BPCs; + + RefinementProblem ReplacePhi(souper::InstContext &IC, std::map &Change); + + bool operator == (const RefinementProblem &P) const { + if (LHS == P.LHS && RHS == P.RHS && + Pre == P.Pre && BPCs.size() == P.BPCs.size()) { + for (size_t i = 0; i < BPCs.size(); ++i) { + if (BPCs[i].B != P.BPCs[i].B || + BPCs[i].PC.LHS != P.BPCs[i].PC.LHS || + BPCs[i].PC.RHS != P.BPCs[i].PC.RHS) { + return false; + } + } + return true; + } else { + return false; + } + } + struct Hash + { + std::size_t operator()(const RefinementProblem &P) const + { + return std::hash()(P.LHS) + ^ std::hash()(P.RHS) << 1 + ^ std::hash()(P.Pre) << 2 + ^ std::hash()(P.BPCs.size()); + } + }; + +}; + +void collectPhis(souper::Inst *I, + std::map> &Phis); + +std::unordered_set + explodePhis(InstContext &IC, RefinementProblem P); + +Inst *getDataflowConditions(const Inst *I, InstContext &IC); + +} +#endif diff --git a/include/souper/Infer/Z3Driver.h b/include/souper/Infer/Z3Driver.h new file mode 100644 index 000000000..4768f0991 --- /dev/null +++ b/include/souper/Infer/Z3Driver.h @@ -0,0 +1,240 @@ +#ifndef SOUPER_Z3_DRIVER_H +#define SOUPER_Z3_DRIVER_H +#include "souper/Infer/Verification.h" +#include "souper/Infer/Z3Expr.h" +#include "souper/Inst/Inst.h" + +namespace souper { + +class Z3Driver { +public: + Z3Driver(Inst *LHS_, Inst *PreCondition_, InstContext &IC_, BlockPCs BPCs_, + std::vector ExtraInputs = {}, unsigned Timeout = 10000) + : LHS(LHS_), Precondition(PreCondition_), IC(IC_), + TranslatedExprs(ctx), Solver(ctx){ + // TODO: Preprocessing, solver keepalive, variables in solver for reuse + +// z3::params p(ctx); +// p.set(":timeout", Timeout); +// Solver.set(p); + + InstNumbers = 201; + //201 is chosen arbitrarily. + + for (auto BPC: BPCs_) { + BPCs[BPC.B][BPC.PredIdx].push_back(BPC.PC); + } + + Translate(LHS); + if (Precondition) { + AddConstraint(Precondition); + } + } + bool verify(Inst *RHS, Inst *RHSAssumptions = nullptr) { + Translate(RHS); + if (RHSAssumptions) { + AddConstraint(RHSAssumptions); + } + + if (LHS->DemandedBits != 0) { + auto Mask = ctx.bv_val(LHS->DemandedBits.toString(10, false).c_str(), LHS->Width); + Put(LHS, Get(LHS) & Mask); + Put(RHS, Get(RHS) & Mask); + } + Solver.add(Get(LHS) != Get(RHS)); + + if (Solver.check() == z3::unsat) { + return true; + } else { + // TODO: Model + return false; + } + + } +private: + Inst *LHS; + Inst *Precondition; + InstContext &IC; + std::map>> BPCs; + std::map NamesCache; + std::map ExprCache; + z3::context ctx; + z3::expr_vector TranslatedExprs; + z3::solver Solver; + + bool isCached(Inst *I) { + return ExprCache.find(I) != ExprCache.end(); + } + + z3::expr Get(Inst *I) { + return TranslatedExprs[ExprCache[I]]; + } + void Put(Inst *I, z3::expr E) { + TranslatedExprs.push_back(E); + ExprCache[I] = TranslatedExprs.size() - 1; + } + + void AddConstraint(Inst *I) { + Translate(I); + Solver.add(Get(I) == ctx.bv_val(1, 1)); + } + + z3::expr getPhiArgConstraint(Block *B, unsigned idx) { + souper::Inst *Cond = nullptr; + for (auto Mapping : BPCs[B][idx]) { + auto Cur = IC.getInst(Inst::Kind::Eq, 1, {Mapping.LHS, Mapping.RHS}); + if (!Cond) { + Cond = Cur; + } else { + Cond = IC.getInst(Inst::Kind::And, 1, {Cond, Cur}); + } + } + Translate(Cond); + auto Expr = Get(Cond); + if (Expr.get_sort().is_bv()) { + // i1 to bool. TODO: investigate if there is a more efficient way of doing this. + auto &ctx = Expr.ctx(); + Expr = z3::ite(Expr == ctx.bv_val(1, 1), ctx.bool_val(true), ctx.bool_val(false)); + Put(Cond, Expr); + } + return Expr; + } + + int InstNumbers; + + void addExtraPreds(souper::Inst *I); + + bool Translate(souper::Inst *I) { + if (!I) { + return false; + } + // unused translation; this is souper's internal instruction to represent overflow instructions + if (souper::Inst::isOverflowIntrinsicSub(I->K)) { + return true; + } + + if (isCached(I)) { + return true; + } + + auto Ops = I->Ops; + if (souper::Inst::isOverflowIntrinsicMain(I->K)) { + Ops = Ops[0]->Ops; + } + + for (auto &&Op : Ops) { + if (!Translate(Op)) { + return false; + } + if (I->K == Inst::ExtractValue) { + break; // Break after translating main arg, idx is handled separately. + } + + } + + std::string Name; + if (NamesCache.find(I) != NamesCache.end()) { + Name = NamesCache[I]; + } else if (I->Name != "") { + if (I->SynthesisConstID != 0) { + Name = "%" + souper::ReservedConstPrefix + std::to_string(I->SynthesisConstID); + } else { + Name = "%var_" + I->Name; + } + } else { + Name = "%" + std::to_string(InstNumbers++); + } + NamesCache[I] = Name; + + auto W = I->Width; + addExtraPreds(I); + switch (I->K) { + case souper::Inst::Var: { + Put(I, ctx.bv_const(Name.c_str(), W)); + auto DFC = getDataflowConditions(I, IC); + if (DFC) { + AddConstraint(DFC); + } + return true; + } + case souper::Inst::Hole: { + llvm::report_fatal_error("Holes unimplemented in Z3Driver."); + } + case souper::Inst::Const: { + Put(I, ctx.bv_val(I->Val.toString(10, false).c_str(), W)); + // inefficient? + return true; + } + + case souper::Inst::Phi: { + auto Var = ctx.bv_const(Name.c_str(), I->Width); + auto Constraint = ((Var == Get(I->Ops[0])) && getPhiArgConstraint(I->B, 0)); + for (size_t i = 0; i < I->Ops.size(); ++i) { + Constraint = Constraint || + ((Var == Get(I->Ops[i])) && getPhiArgConstraint(I->B, i)); + } + Solver.add(Constraint); + Put(I, Var); + return true; + } + + case souper::Inst::ExtractValue: { + unsigned idx = I->Ops[1]->Val.getLimitedValue(); + assert(idx <= 1 && "Only extractvalue with overflow instructions are supported."); + Put(I, z3expr::ExtractValue(Get(I->Ops[0]), idx)); + return true; + } + + #define UNOP(SOUPER, Z3) case souper::Inst::SOUPER: { \ + Put(I, z3expr::Z3(Get(Ops[0]))); \ + return true; \ + } + #define UNOPC(SOUPER, Z3) case souper::Inst::SOUPER: { \ + Put(I, z3expr::Z3(Get(Ops[0]), I->Width)); \ + return true; \ + } + #define BINOP(SOUPER, Z3) case souper::Inst::SOUPER: { \ + Put(I, z3expr::Z3(Get(Ops[0]), Get(Ops[1]))); \ + return true; \ + } + #define TERNOP(SOUPER, Z3) case souper::Inst::SOUPER: { \ + Put(I, z3expr::Z3(Get(Ops[0]), Get(Ops[1]), Get(Ops[2])));\ + return true; \ + } + + UNOP(Freeze, Freeze); UNOP(CtPop, CtPop); UNOP(BSwap, BSwap); + UNOP(BitReverse, BitReverse); UNOP(Cttz, Cttz); UNOP(Ctlz, Ctlz); + + UNOPC(ZExt, ZExt); UNOPC(SExt, SExt); UNOPC(Trunc, Trunc); + + BINOP(Add, Add); BINOP(AddNSW, Add); BINOP(AddNUW, Add); BINOP(AddNW, Add); + BINOP(Sub, Sub); BINOP(SubNSW, Sub); BINOP(SubNUW, Sub); BINOP(SubNW, Sub); + BINOP(Mul, Mul); BINOP(MulNSW, Mul); BINOP(MulNUW, Mul); BINOP(MulNW, Mul); + BINOP(Shl, Shl); BINOP(ShlNSW, Shl); BINOP(ShlNUW, Shl); BINOP(ShlNW, Shl); + BINOP(And, And); BINOP(Or, Or); BINOP(Xor, Xor); + BINOP(LShr, LShr); BINOP(LShrExact, LShr); BINOP(AShr, AShr); BINOP(AShrExact, AShr); + BINOP(URem, URem); BINOP(SRem, SRem); BINOP(UDiv, UDiv); BINOP(UDivExact, UDiv); + BINOP(SDiv, SDiv); BINOP(SDivExact, SDiv); BINOP(SAddSat, SAddSat); + BINOP(UAddSat, SAddSat); BINOP(SSubSat, SSubSat); BINOP(USubSat, USubSat); + BINOP(SAddWithOverflow, SAddWithOverflow); BINOP(UAddWithOverflow, UAddWithOverflow); + BINOP(SSubWithOverflow, SSubWithOverflow); BINOP(USubWithOverflow, USubWithOverflow); + BINOP(SMulWithOverflow, SMulWithOverflow); BINOP(UMulWithOverflow, UMulWithOverflow); + BINOP(Eq, Eq); BINOP(Ne, Ne); BINOP(Ule, Ule); + BINOP(Ult, Ult); BINOP(Sle, Sle); BINOP(Slt, Slt); + + TERNOP(Select, Select); TERNOP(FShl, FShl); TERNOP(FShr, FShr); + + default: llvm::report_fatal_error("Unimplemented instruction."); + } + } +}; + + +bool isTransformationValidZ3(souper::Inst *LHS, souper::Inst *RHS, + const std::vector &PCs, + const souper::BlockPCs &BPCs, + InstContext &IC, unsigned Timeout = 10000); + +} + +#endif diff --git a/include/souper/Infer/Z3Expr.h b/include/souper/Infer/Z3Expr.h new file mode 100644 index 000000000..52459650f --- /dev/null +++ b/include/souper/Infer/Z3Expr.h @@ -0,0 +1,118 @@ +#ifndef Z3_EXPR_H +#define Z3_EXPR_H +#include "z3++.h" +namespace z3expr { + +z3::expr Add(z3::expr x, z3::expr y); + +z3::expr Sub(z3::expr x, z3::expr y); + +z3::expr Mul(z3::expr x, z3::expr y); + +z3::expr UDiv(z3::expr x, z3::expr y); + +z3::expr SDiv(z3::expr x, z3::expr y); + +z3::expr URem(z3::expr x, z3::expr y); + +z3::expr SRem(z3::expr x, z3::expr y); + +z3::expr And(z3::expr x, z3::expr y); + +z3::expr Or(z3::expr x, z3::expr y); + +z3::expr Xor(z3::expr x, z3::expr y); + +z3::expr Shl(z3::expr x, z3::expr y); + +z3::expr LShr(z3::expr x, z3::expr y); + +z3::expr AShr(z3::expr x, z3::expr y); + +z3::expr Select(z3::expr C, z3::expr T, z3::expr F); + +z3::expr ZExt(z3::expr x, size_t W); + +z3::expr SExt(z3::expr x, size_t W); + +z3::expr Trunc(z3::expr x, size_t W); + +z3::expr Eq(z3::expr x, z3::expr y); + +z3::expr Ne(z3::expr x, z3::expr y); + +z3::expr Ult(z3::expr x, z3::expr y); + +z3::expr Slt(z3::expr x, z3::expr y); + +z3::expr Ule(z3::expr x, z3::expr y); + +z3::expr Sle(z3::expr x, z3::expr y); + +z3::expr CtPop(z3::expr x); + +z3::expr Freeze(z3::expr x); + +z3::expr ExtractValue(z3::expr x, size_t idx); + +z3::expr SAddWithOverflow(z3::expr x, z3::expr y); + +z3::expr UAddWithOverflow(z3::expr x, z3::expr y); + +z3::expr SSubWithOverflow(z3::expr x, z3::expr y); + +z3::expr USubWithOverflow(z3::expr x, z3::expr y); + +z3::expr SMulWithOverflow(z3::expr x, z3::expr y); + +z3::expr UMulWithOverflow(z3::expr x, z3::expr y); + +z3::expr BSwap(z3::expr x); + +z3::expr Cttz(z3::expr x); + +z3::expr Ctlz(z3::expr x); + +z3::expr BitReverse(z3::expr x); + +z3::expr FShl(z3::expr a, z3::expr b, z3::expr c); + +z3::expr FShr(z3::expr a, z3::expr b, z3::expr c); + +z3::expr SAddSat(z3::expr x, z3::expr y); + +z3::expr UAddSat(z3::expr x, z3::expr y); + +z3::expr SSubSat(z3::expr x, z3::expr y); + +z3::expr USubSat(z3::expr x, z3::expr y); + +z3::expr add_no_soverflow(z3::expr x, z3::expr y); + +z3::expr add_no_uoverflow(z3::expr x, z3::expr y); + +z3::expr sub_no_soverflow(z3::expr x, z3::expr y); + +z3::expr sub_no_uoverflow(z3::expr x, z3::expr y); + +z3::expr mul_no_soverflow(z3::expr x, z3::expr y); + +z3::expr mul_no_uoverflow(z3::expr x, z3::expr y); + +z3::expr shl_no_soverflow(z3::expr x, z3::expr y); + +z3::expr shl_no_uoverflow(z3::expr x, z3::expr y); + +z3::expr sdiv_exact(z3::expr x, z3::expr y); + +z3::expr udiv_exact(z3::expr x, z3::expr y); + +z3::expr ashr_exact(z3::expr x, z3::expr y); + +z3::expr lshr_exact(z3::expr x, z3::expr y); + +z3::expr ToBV(z3::expr x); + +z3::expr ToIBV(z3::expr x); //inverted +} +#endif diff --git a/lib/Extractor/Solver.cpp b/lib/Extractor/Solver.cpp index 912d25dfd..3ff3a3cb8 100644 --- a/lib/Extractor/Solver.cpp +++ b/lib/Extractor/Solver.cpp @@ -24,6 +24,7 @@ #include "souper/Codegen/Codegen.h" #include "souper/Extractor/Solver.h" #include "souper/Infer/AliveDriver.h" +#include "souper/Infer/Z3Driver.h" #include "souper/Infer/ConstantSynthesis.h" #include "souper/Infer/EnumerativeSynthesis.h" #include "souper/Infer/InstSynthesis.h" @@ -49,6 +50,10 @@ namespace { static cl::opt NoInfer("souper-no-infer", cl::desc("Populate the external cache, but don't infer replacements (default=false)"), cl::init(false)); +static cl::opt UseZ3Driver("souper-in-process-z3", + cl::desc("Use Z3 C++ api instead of smtlib2 (default=false)"), + cl::init(false)); + static cl::opt UseCegis("souper-use-cegis", cl::desc("Infer instructions (default=false)"), cl::init(false)); @@ -487,6 +492,9 @@ class BaseSolver : public Solver { if (UseAlive) { IsValid = isTransformationValid(Mapping.LHS, Mapping.RHS, PCs, BPCs, IC); return std::error_code(); + } else if (UseZ3Driver) { + IsValid = isTransformationValidZ3(Mapping.LHS, Mapping.RHS, PCs, BPCs, IC, Timeout); + return std::error_code(); } std::string Query; if (Model) { diff --git a/lib/Infer/AliveDriver.cpp b/lib/Infer/AliveDriver.cpp index 73eb08119..526d37255 100644 --- a/lib/Infer/AliveDriver.cpp +++ b/lib/Infer/AliveDriver.cpp @@ -1,4 +1,4 @@ -#include "souper/Extractor/ExprBuilder.h" +#include "souper/Infer/Verification.h" #include "souper/Infer/AliveDriver.h" #include "souper/Inst/Inst.h" @@ -472,30 +472,7 @@ bool souper::AliveDriver::translateRoot(const souper::Inst *I, const Inst *PC, return true; } -// Dummy because it doesn't actually build expressions. -// It exists for the purpose of reusing parts of the abstract ExprBuilder here. -// FIXME: Allow creating objects of ExprBuilder -class DummyExprBuilder : public souper::ExprBuilder { -public: - DummyExprBuilder(souper::InstContext &IC) : souper::ExprBuilder(IC) {} - std::string BuildQuery(const souper::BlockPCs & BPCs, - const std::vector & PCs, - souper::InstMapping Mapping, - std::vector * ModelVars, - souper::Inst *Precondition, - bool Negate, bool DropUB) override { - llvm::report_fatal_error("Do not call"); - return ""; - } - std::string GetExprStr(const souper::BlockPCs & BPCs, - const std::vector & PCs, - souper::InstMapping Mapping, - std::vector * ModelVars, - bool Negate, bool DropUB) override { - llvm::report_fatal_error("Do not call"); - return ""; - } -}; + std::string getUniqueName() { static int N = 0; return "dummy_" + std::to_string(N++); @@ -722,9 +699,7 @@ bool souper::AliveDriver::translateDataflowFacts(const souper::Inst* I, IR::Function& F, souper::AliveDriver::Cache& ExprCache) { - DummyExprBuilder EB(IC); - auto DataFlowConstraints = EB.getDataflowConditions(const_cast(I)); - //FIXME: Get rid of the const_cast by making getDataflowConditions take const Inst * + auto DataFlowConstraints = getDataflowConditions(I, IC); if (DataFlowConstraints) { if (!translateAndCache(DataFlowConstraints, F, ExprCache)) { return false; @@ -775,146 +750,6 @@ IR::Type &souper::AliveDriver::getOverflowType(int Width) { return *TypeCache[n]; } -namespace souper { -void collectPhis(souper::Inst *I, std::map> &Phis) { - std::vector Stack{I}; - std::unordered_set Visited; - while (!Stack.empty()) { - auto Current = Stack.back(); - Stack.pop_back(); - if (Current->K == Inst::Phi) { - Phis[Current->B].insert(Current); - } - Visited.insert(Current); - for (auto Child : Current->Ops) { - if (Visited.find(Child) == Visited.end()) { - Stack.push_back(Child); - } - } - } -} - -struct RefinementProblem { - souper::Inst *LHS; - souper::Inst *RHS; - souper::Inst *Pre; - BlockPCs BPCs; - - RefinementProblem ReplacePhi(souper::InstContext &IC, std::map &Change) { - std::map> Phis; - collectPhis(LHS, Phis); - collectPhis(RHS, Phis); - collectPhis(Pre, Phis); - for (auto &BPC : BPCs) { - collectPhis(BPC.PC.LHS, Phis); - } - - if (Phis.empty()) { - return *this; // Base case, no more Phis - } - - std::map InstCache; - for (auto Pair : Phis) { - for (auto Phi : Pair.second) { - InstCache[Phi] = Phi->Ops[Change[Pair.first]]; - } - } - std::map BlockCache; - std::map ConstMap; - RefinementProblem Result; - Result.LHS = getInstCopy(LHS, IC, InstCache, BlockCache, &ConstMap, false); - Result.RHS = getInstCopy(RHS, IC, InstCache, BlockCache, &ConstMap, false); - Result.Pre = getInstCopy(Pre, IC, InstCache, BlockCache, &ConstMap, false); - Result.BPCs = BPCs; - for (auto &BPC : Result.BPCs) { - BPC.PC.LHS = getInstCopy(BPC.PC.LHS, IC, InstCache, BlockCache, - &ConstMap, false); - } - - // Recursively call ReplacePhi, because Result might have Phi`s - return Result.ReplacePhi(IC, Change); - } - bool operator == (const RefinementProblem &P) const { - if (LHS == P.LHS && RHS == P.RHS && - Pre == P.Pre && BPCs.size() == P.BPCs.size()) { - for (size_t i = 0; i < BPCs.size(); ++i) { - if (BPCs[i].B != P.BPCs[i].B || - BPCs[i].PC.LHS != P.BPCs[i].PC.LHS || - BPCs[i].PC.RHS != P.BPCs[i].PC.RHS) { - return false; - } - } - return true; - } else { - return false; - } - } - struct Hash - { - std::size_t operator()(const RefinementProblem &P) const - { - return std::hash()(P.LHS) - ^ std::hash()(P.RHS) << 1 - ^ std::hash()(P.Pre) << 2 - ^ std::hash()(P.BPCs.size()); - } - }; - -}; - -std::unordered_set - explodePhis(InstContext &IC, RefinementProblem P) { - std::map> Phis; - collectPhis(P.LHS, Phis); - collectPhis(P.Pre, Phis); - - if (Phis.empty()) { - return {P}; - } - - std::vector Blocks; - for (auto &&Pair : Phis) { - Blocks.push_back(Pair.first); - } - - std::vector> ChangeList; - - for (size_t i = 0; i < Blocks.size(); ++i) { // Each block - if (i == 0) { - for (size_t j = 0; j < Blocks[i]->Preds; ++j) { - ChangeList.push_back({{Blocks[i], j}}); - } - } else { - std::vector> NewChangeList; - for (size_t j = 0; j < Blocks[i]->Preds; ++j) { - for (auto Change : ChangeList) { - Change.insert({Blocks[i], j}); - NewChangeList.push_back(Change); - } - } - std::swap(ChangeList, NewChangeList); - } - } - - std::unordered_set Result; - - for (auto Change : ChangeList) { - auto Goal = P.ReplacePhi(IC, Change); - // Consider switching to better data structures for dealing with BPCs - for (auto &[Block, Pred] : Change) { - for (auto &BPC : Goal.BPCs) { - if (BPC.B == Block && BPC.PredIdx == Pred) { - auto Ante = IC.getInst(Inst::Eq, 1, {BPC.PC.LHS, BPC.PC.RHS}); - Goal.Pre = IC.getInst(Inst::And, 1, {Goal.Pre, Ante}); - } - } - } - Result.insert(Goal); - } - return Result; -} - -} bool souper::isTransformationValid(souper::Inst *LHS, souper::Inst *RHS, const std::vector &PCs, const souper::BlockPCs &BPCs, diff --git a/lib/Infer/EnumerativeSynthesis.cpp b/lib/Infer/EnumerativeSynthesis.cpp index 9ae534427..18a5ce1db 100644 --- a/lib/Infer/EnumerativeSynthesis.cpp +++ b/lib/Infer/EnumerativeSynthesis.cpp @@ -16,6 +16,7 @@ #include "llvm/ADT/SetVector.h" #include "llvm/Support/CommandLine.h" #include "souper/Infer/AliveDriver.h" +#include "souper/Infer/Z3Driver.h" #include "souper/Infer/ConstantSynthesis.h" #include "souper/Infer/EnumerativeSynthesis.h" #include "souper/Infer/Pruning.h" diff --git a/lib/Infer/Verification.cpp b/lib/Infer/Verification.cpp new file mode 100644 index 000000000..72ad66e28 --- /dev/null +++ b/lib/Infer/Verification.cpp @@ -0,0 +1,140 @@ +#include "souper/Infer/Verification.h" +#include "souper/Extractor/ExprBuilder.h" +namespace souper { +void collectPhis(souper::Inst *I, std::map> &Phis) { + std::vector Stack{I}; + std::unordered_set Visited; + while (!Stack.empty()) { + auto Current = Stack.back(); + Stack.pop_back(); + if (Current->K == Inst::Phi) { + Phis[Current->B].insert(Current); + } + Visited.insert(Current); + for (auto Child : Current->Ops) { + if (Visited.find(Child) == Visited.end()) { + Stack.push_back(Child); + } + } + } +} + +std::unordered_set + explodePhis(InstContext &IC, RefinementProblem P) { + std::map> Phis; + collectPhis(P.LHS, Phis); + collectPhis(P.Pre, Phis); + + if (Phis.empty()) { + return {P}; + } + + std::vector Blocks; + for (auto &&Pair : Phis) { + Blocks.push_back(Pair.first); + } + + std::vector> ChangeList; + + for (size_t i = 0; i < Blocks.size(); ++i) { // Each block + if (i == 0) { + for (size_t j = 0; j < Blocks[i]->Preds; ++j) { + ChangeList.push_back({{Blocks[i], j}}); + } + } else { + std::vector> NewChangeList; + for (size_t j = 0; j < Blocks[i]->Preds; ++j) { + for (auto Change : ChangeList) { + Change.insert({Blocks[i], j}); + NewChangeList.push_back(Change); + } + } + std::swap(ChangeList, NewChangeList); + } + } + + std::unordered_set Result; + + for (auto Change : ChangeList) { + auto Goal = P.ReplacePhi(IC, Change); + // Consider switching to better data structures for dealing with BPCs + for (auto &[Block, Pred] : Change) { + for (auto &BPC : Goal.BPCs) { + if (BPC.B == Block && BPC.PredIdx == Pred) { + auto Ante = IC.getInst(Inst::Eq, 1, {BPC.PC.LHS, BPC.PC.RHS}); + Goal.Pre = IC.getInst(Inst::And, 1, {Goal.Pre, Ante}); + } + } + } + Goal.LHS->DemandedBits = P.LHS->DemandedBits; + Result.insert(Goal); + } + return Result; +} + +Inst *getDataflowConditions(const Inst *I, InstContext &IC) { + // Dummy because it doesn't actually build expressions. + // It exists for the purpose of reusing parts of the abstract ExprBuilder here. + // FIXME: Allow creating objects of ExprBuilder + class DummyExprBuilder : public souper::ExprBuilder { + public: + DummyExprBuilder(souper::InstContext &IC) : souper::ExprBuilder(IC) {} + std::string BuildQuery(const souper::BlockPCs & BPCs, + const std::vector & PCs, + souper::InstMapping Mapping, + std::vector * ModelVars, + souper::Inst *Precondition, + bool Negate, bool DropUB) override { + llvm::report_fatal_error("Do not call"); + return ""; + } + std::string GetExprStr(const souper::BlockPCs & BPCs, + const std::vector & PCs, + souper::InstMapping Mapping, + std::vector * ModelVars, + bool Negate, bool DropUB) override { + llvm::report_fatal_error("Do not call"); + return ""; + } + }; + DummyExprBuilder EB(IC); + return EB.getDataflowConditions(const_cast(I)); + +} + +RefinementProblem RefinementProblem::ReplacePhi(souper::InstContext &IC, std::map &Change) { + std::map> Phis; + collectPhis(LHS, Phis); + collectPhis(RHS, Phis); + collectPhis(Pre, Phis); + for (auto &BPC : BPCs) { + collectPhis(BPC.PC.LHS, Phis); + } + + if (Phis.empty()) { + return *this; // Base case, no more Phis + } + + std::map InstCache; + for (auto Pair : Phis) { + for (auto Phi : Pair.second) { + InstCache[Phi] = Phi->Ops[Change[Pair.first]]; + } + } + std::map BlockCache; + std::map ConstMap; + RefinementProblem Result; + Result.LHS = getInstCopy(LHS, IC, InstCache, BlockCache, &ConstMap, false); + Result.RHS = getInstCopy(RHS, IC, InstCache, BlockCache, &ConstMap, false); + Result.Pre = getInstCopy(Pre, IC, InstCache, BlockCache, &ConstMap, false); + Result.BPCs = BPCs; + for (auto &BPC : Result.BPCs) { + BPC.PC.LHS = getInstCopy(BPC.PC.LHS, IC, InstCache, BlockCache, + &ConstMap, false); + } + + // Recursively call ReplacePhi, because Result might have Phi`s + return Result.ReplacePhi(IC, Change); +} + +} diff --git a/lib/Infer/Z3Driver.cpp b/lib/Infer/Z3Driver.cpp new file mode 100644 index 000000000..fcfd9eeec --- /dev/null +++ b/lib/Infer/Z3Driver.cpp @@ -0,0 +1,103 @@ +#include "souper/Infer/Z3Driver.h" +#include "souper/Infer/Z3Expr.h" + +extern unsigned DebugLevel; + +namespace souper { +bool isTransformationValidZ3(souper::Inst *LHS, souper::Inst *RHS, + const std::vector &PCs, + const souper::BlockPCs &BPCs, + InstContext &IC, unsigned Timeout) { + Inst *Ante = IC.getConst(llvm::APInt(1, true)); + for (auto PC : PCs ) { + Inst *Eq = IC.getInst(Inst::Eq, 1, {PC.LHS, PC.RHS}); + Ante = IC.getInst(Inst::And, 1, {Ante, Eq}); + } + +// auto Goals = explodePhis(IC, {LHS, RHS, Ante, BPCs}); +// // ^ Explanation in AliveDriver.cpp + +// if (DebugLevel > 3) +// llvm::errs() << "Number of sub-goals : " << Goals.size() << "\n"; +// for (const auto &Goal : Goals) { +// if (DebugLevel > 3) { +// llvm::errs() << "Goal:\n"; +// ReplacementContext RC; +// RC.printInst(Goal.LHS, llvm::errs(), true); +// llvm::errs() << "\n------\n"; +// } +// std::vector Vars; +// findVars(Goal.RHS, Vars); +// Z3Driver Verifier(Goal.LHS, Goal.Pre, IC, BPCs, Vars, Timeout); +// if (!Verifier.verify(Goal.RHS, Goal.Pre)) +// return false; +// } + + std::vector Vars; + findVars(RHS, Vars); + Z3Driver Verifier(LHS, Ante, IC, BPCs, Vars, Timeout); + if (!Verifier.verify(RHS, Ante)) { + return false; + } else { + return true; + } +} + +void Z3Driver::addExtraPreds(souper::Inst *I) { + if (I->K == Inst::Kind::UDiv || I->K == Inst::Kind::SDiv + || I->K == Inst::Kind::SDivExact || I->K == Inst::Kind::UDivExact + || I->K == Inst::Kind::URem || I->K == Inst::Kind::SRem) { + Solver.add(Get(I->Ops[1]) != ctx.bv_val(0, I->Width)); + } + + if (I->K == Inst::Kind::Shl || I->K == Inst::Kind::LShr + || I->K == Inst::Kind::AShr || I->K == Inst::Kind::AShrExact + || I->K == Inst::Kind::LShrExact) { + Solver.add(z3::ult(Get(I->Ops[1]), ctx.bv_val(I->Width, I->Width))); + } + + if (I->K == Inst::Kind::AddNSW || I->K == Inst::Kind::AddNW) { + Solver.add(z3expr::add_no_soverflow(Get(I->Ops[0]), Get(I->Ops[1]))); + } + + if (I->K == Inst::Kind::AddNUW || I->K == Inst::Kind::AddNW) { + Solver.add(z3expr::add_no_uoverflow(Get(I->Ops[0]), Get(I->Ops[1]))); + } + + if (I->K == Inst::Kind::SubNSW || I->K == Inst::Kind::SubNW) { + Solver.add(z3expr::sub_no_soverflow(Get(I->Ops[0]), Get(I->Ops[1]))); + } + if (I->K == Inst::Kind::SubNUW || I->K == Inst::Kind::SubNW) { + Solver.add(z3expr::sub_no_uoverflow(Get(I->Ops[0]), Get(I->Ops[1]))); + } + + if (I->K == Inst::Kind::MulNSW || I->K == Inst::Kind::MulNW) { + Solver.add(z3expr::mul_no_soverflow(Get(I->Ops[0]), Get(I->Ops[1]))); + } + if (I->K == Inst::Kind::MulNUW || I->K == Inst::Kind::MulNW) { + Solver.add(z3expr::mul_no_uoverflow(Get(I->Ops[0]), Get(I->Ops[1]))); + } + + if (I->K == Inst::Kind::ShlNSW || I->K == Inst::Kind::ShlNW) { + Solver.add(z3expr::shl_no_soverflow(Get(I->Ops[0]), Get(I->Ops[1]))); + } + if (I->K == Inst::Kind::ShlNUW || I->K == Inst::Kind::ShlNW) { + Solver.add(z3expr::shl_no_uoverflow(Get(I->Ops[0]), Get(I->Ops[1]))); + } + + if (I->K == Inst::Kind::UDivExact) { + Solver.add(z3expr::udiv_exact(Get(I->Ops[0]), Get(I->Ops[1]))); + } + if (I->K == Inst::Kind::SDivExact) { + Solver.add(z3expr::sdiv_exact(Get(I->Ops[0]), Get(I->Ops[1]))); + } + + if (I->K == Inst::Kind::AShrExact) { + Solver.add(z3expr::ashr_exact(Get(I->Ops[0]), Get(I->Ops[1]))); + } + if (I->K == Inst::Kind::LShrExact) { + Solver.add(z3expr::lshr_exact(Get(I->Ops[0]), Get(I->Ops[1]))); + } +} + +} diff --git a/lib/Infer/Z3Expr.cpp b/lib/Infer/Z3Expr.cpp new file mode 100644 index 000000000..e179e0e01 --- /dev/null +++ b/lib/Infer/Z3Expr.cpp @@ -0,0 +1,290 @@ +#include "souper/Infer/Z3Expr.h" + +namespace z3expr { +using E = z3::expr; +using namespace z3; +E Add(E x, E y) { + return x + y; +} +E Sub(E x, E y) { + return x - y; +} +E Mul(E x, E y) { + return x * y; +} +E UDiv(E x, E y) { + return z3::udiv(x, y); +} +E SDiv(E x, E y) { + return x / y; +} +E URem(E x, E y) { + return z3::urem(x, y); +} +E SRem(E x, E y) { + return z3::srem(x, y); +} +E And(E x, E y) { + return x & y; +} +E Or(E x, E y) { + return x | y; +} +E Xor(E x, E y) { + return x ^ y; +} +E Shl(E x, E y) { + return z3::shl(x, y); +} +E LShr(E x, E y) { + return z3::lshr(x, y); +} +E AShr(E x, E y) { + return z3::ashr(x, y); +} +E Select(E C, E T, E F) { + return z3::ite(C == C.ctx().bv_val(1, 1), T, F); +} +E ZExt(E x, size_t W) { + auto xW = x.get_sort().bv_size(); + return z3::zext(x, W - xW); +} +E SExt(E x, size_t W) { + auto xW = x.get_sort().bv_size(); + return z3::sext(x, W - xW); +} +E Trunc(E x, size_t W) { + return x.extract(W-1, 0); +} + +E ToBV(E x) { + auto &ctx = x.ctx(); + return z3::ite(x, ctx.bv_val(1, 1), ctx.bv_val(0, 1)); +} + +E Eq(E x, E y) { + return ToBV(x == y); +} +E Ne(E x, E y) { + return ToBV(x != y); +} +E Ult(E x, E y){ + return ToBV(z3::ult(x, y)); +} +E Slt(E x, E y){ + return ToBV(z3::slt(x, y)); +} +E Ule(E x, E y){ + return ToBV(z3::ule(x, y)); +} +E Sle(E x, E y){ + return ToBV(z3::sle(x, y)); +} +E CtPop(E x) { + auto W = x.get_sort().bv_size(); + auto &ctx = x.ctx(); + auto sum = ctx.bv_val(0, W); + for (size_t i = 0; i < W; ++i) { + sum = sum + z3::zext(x.extract(i, i), W - 1); + } + return sum; +} +E Freeze(E x) { + return x; +} + +E ExtractValue(E x, size_t idx) { + if (idx == 0) { // return value + return x.extract(x.get_sort().bv_size() - 1, 1); + } else { // return overflow flag + return x.extract(0 , 0); + } +} + +E add_no_soverflow(E x, E y) { + return sext(x, 1) + sext(y, 1) == sext(x + y, 1); +} + +E add_no_uoverflow(E x, E y) { + auto bw = x.get_sort().bv_size(); + return (zext(x, 1) + zext(y, 1)).extract(bw, bw) == 0; +} + +E sub_no_soverflow(E x, E y) { + return sext(x, 1) - sext(y, 1) == sext((x - y), 1); +} + +E sub_no_uoverflow(E x, E y) { + auto bw = x.get_sort().bv_size(); + return (zext(x, 1) - zext(y, 1)).extract(bw, bw) == 0; +} + +E mul_no_soverflow(E x, E y) { + auto bw = x.get_sort().bv_size(); + return sext(x, bw) * sext(y, bw) == sext((x * y), bw); +} + +E mul_no_uoverflow(E x, E y) { + auto bw = x.get_sort().bv_size(); + return (zext(x, bw) * zext(y, bw)).extract(2*bw - 1, bw) == 0; +} + +E shl_no_soverflow(E x, E y) { + return ashr(shl(x, y), y) == x; +} + +E shl_no_uoverflow(E x, E y) { + return lshr(shl(x, y), y) == x; +} + +E ToIBV(E x) { + auto &ctx = x.ctx(); + return z3::ite(x, ctx.bv_val(0, 1), ctx.bv_val(1, 1)); +} + +E SAddWithOverflow(E x, E y) { + return z3::concat(x + y, ToIBV(add_no_soverflow(x, y))); +} + +E UAddWithOverflow(E x, E y) { + return z3::concat(x + y, ToIBV(add_no_uoverflow(x, y))); +} + +E SSubWithOverflow(E x, E y) { + return z3::concat(x - y, ToIBV(sub_no_soverflow(x, y))); +} + +E USubWithOverflow(E x, E y) { + return z3::concat(x - y, ToIBV(sub_no_uoverflow(x, y))); +} + +E SMulWithOverflow(E x, E y) { + return z3::concat(x * y, ToIBV(mul_no_soverflow(x, y))); +} + +E UMulWithOverflow(E x, E y) { + return z3::concat(x * y, ToIBV(mul_no_uoverflow(x, y))); +} + +// following implementations were borrowed from alive2 codebase +E BSwap(E x) { + auto nbits = x.get_sort().bv_size(); + constexpr unsigned bytelen = 8; + assert(nbits % (bytelen * 2) == 0); + E res = x.extract(bytelen - 1, 0); + for (unsigned i = 1; i < nbits / bytelen; i++) { + res = z3::concat(res, x.extract((i + 1) * bytelen - 1, i * bytelen)); + } + return res; +} +E Cttz(E x) { + auto srt = x.get_sort(); + auto result = x.ctx().bv_val(0, srt.bv_size()); + for (int i = srt.bv_size() - 1; i >= 0; --i) { + result = z3::ite(x.extract(i, i) == x.ctx().bv_val(1, 1), + x.ctx().bv_val(i, srt.bv_size()), result); + } + return result; +} +E Ctlz(E x) { + auto nbits = x.get_sort().bv_size(); + auto result = x.ctx().bv_val(nbits, nbits); + for (unsigned i = 0; i < nbits; ++i) { + result = z3::ite(x.extract(i, i) == x.ctx().bv_val(1, 1), + x.ctx().bv_val(nbits - 1 - i, nbits), result); + } + return result; +} +E BitReverse(E x) { + auto nbits = x.get_sort().bv_size(); + + E res = x.extract(0, 0); + for (unsigned i = 1; i < nbits; ++i) { + res = concat(res, x.extract(i, i)); + } + + return res; +} +E FShl(E a, E b, E c) { + auto width = a.ctx().bv_val(a.get_sort().bv_size(), a.get_sort().bv_size()); + E c_mod_width = z3::urem(c, width); + return shl(a, c_mod_width) | z3::lshr(b, width - c_mod_width); +} + +E FShr(E a, E b, E c) { + auto width = a.ctx().bv_val(a.get_sort().bv_size(), a.get_sort().bv_size()); + E c_mod_width = z3::urem(c, width); + return shl(a , (width - c_mod_width)) | z3::lshr(b, c_mod_width); +} + +static E IntSMin(unsigned bits, z3::context &ctx) { + E v = ctx.bv_val(1, 1); + if (bits > 1) + v = z3::concat(v, ctx.bv_val(0, bits - 1)); + return v; +} + +static E IntSMax(unsigned bits, z3::context &ctx) { + E v = ctx.bv_val(0, 1); + if (bits > 1) + v = z3::concat(v, ctx.bv_val(-1, bits - 1)); + return v; +} + +static E IntUMax(unsigned bits, z3::context &ctx) { + return ctx.bv_val(-1, bits); +} + +E SAddSat(E x, E y) { + E add_ext = z3::sext(x, 1) + z3::sext(y, 1); + auto bw = x.get_sort().bv_size(); + auto min = IntSMin(bw, x.ctx()); + auto max = IntSMax(bw, x.ctx()); + return z3::ite(z3::sle(add_ext, z3::sext(min, 1)), + min, + z3::ite(z3::sge(add_ext, z3::sext(max, 1)), + max, + x + y)); +} + +E UAddSat(E x, E y) { + return z3::ite(z3::bvadd_no_overflow(x, y, false), + x + y, + IntUMax(x.get_sort().bv_size(), x.ctx())); +} + +E SSubSat(E x, E y){ + E sub_ext = z3::sext(x, 1) - sext(y, 1); + auto bw = x.get_sort().bv_size(); + auto min = IntSMin(bw, x.ctx()); + auto max = IntSMax(bw, x.ctx()); + return z3::ite(z3::sle(sub_ext, z3::sext(min, 1)), + min, + z3::ite(z3::sge(sub_ext, z3::sext(max, 1)), + max, + x - y)); +} + +E USubSat(E x, E y) { + return z3::ite(uge(y, x), + x.ctx().bv_val(0, x.get_sort().bv_size()), + x - y); +} + +E sdiv_exact(E x, E y) { + return x / y * y == x; +} + +E udiv_exact(E x, E y) { + return udiv(x, y) * y == x; +} + +E ashr_exact(E x, E y) { + return (shl(ashr(x, y), y)) == x; +} + +E lshr_exact(E x, E y) { + return (shl(lshr(x, y), y) ) == x; +} + +}