From c0324db3ac7e5984c632f46690f58c333b9a42b2 Mon Sep 17 00:00:00 2001 From: Liana Hadarean Date: Thu, 22 Mar 2012 21:45:31 +0000 Subject: [PATCH] Merged updated version of the bitvector theory: * added simplification rewrites --- src/prop/bvminisat/core/Solver.cc | 5 +- src/prop/sat_module.cpp | 7 +- src/smt/smt_engine.cpp | 1 + src/theory/bv/Makefile.am | 5 +- src/theory/bv/bitblast_strategies.cpp | 17 +- src/theory/bv/bitblast_strategies.h | 3 +- src/theory/bv/bv_sat.cpp | 14 +- src/theory/bv/bv_sat.h | 2 +- src/theory/bv/theory_bv.cpp | 28 +- src/theory/bv/theory_bv.h | 3 +- src/theory/bv/theory_bv_rewrite_rules.h | 455 +++++++-- ...ory_bv_rewrite_rules_constant_evaluation.h | 414 ++++++++ src/theory/bv/theory_bv_rewrite_rules_core.h | 7 +- .../theory_bv_rewrite_rules_normalization.h | 147 +++ ...y_bv_rewrite_rules_operator_elimination.h} | 139 ++- .../theory_bv_rewrite_rules_simplification.h | 918 ++++++++++++++++++ src/theory/bv/theory_bv_rewriter.cpp | 611 +++++++++--- src/theory/bv/theory_bv_rewriter.h | 53 +- src/theory/bv/theory_bv_utils.h | 44 + src/util/bitvector.h | 283 +++++- src/util/integer_cln_imp.h | 51 + src/util/integer_gmp_imp.h | 84 +- 22 files changed, 2984 insertions(+), 307 deletions(-) create mode 100644 src/theory/bv/theory_bv_rewrite_rules_constant_evaluation.h create mode 100644 src/theory/bv/theory_bv_rewrite_rules_normalization.h rename src/theory/bv/{theory_bv_rewrite_rules_arith.h => theory_bv_rewrite_rules_operator_elimination.h} (69%) create mode 100644 src/theory/bv/theory_bv_rewrite_rules_simplification.h diff --git a/src/prop/bvminisat/core/Solver.cc b/src/prop/bvminisat/core/Solver.cc index af342dbbc..7ff7b50db 100644 --- a/src/prop/bvminisat/core/Solver.cc +++ b/src/prop/bvminisat/core/Solver.cc @@ -23,6 +23,7 @@ OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWA #include "mtl/Sort.h" #include "core/Solver.h" #include +#include "util/output.h" using namespace BVMinisat; @@ -766,6 +767,8 @@ static double luby(double y, int x){ // NOTE: assumptions passed in member-variable 'assumptions'. lbool Solver::solve_() { + Debug("bvminisat") <<"BVMinisat::Solving learned clauses " << learnts.size() <<"\n"; + Debug("bvminisat") <<"BVMinisat::Solving assumptions " << assumptions.size() <<"\n"; model.clear(); conflict.clear(); if (!ok) return l_False; @@ -930,7 +933,7 @@ void Solver::garbageCollect() // Initialize the next region to a size corresponding to the estimated utilization degree. This // is not precise but should avoid some unnecessary reallocations for the new region: ClauseAllocator to(ca.size() - ca.wasted()); - + Debug("bvminisat") << " BVMinisat::Garbage collection \n"; relocAll(to); if (verbosity >= 2) printf("| Garbage collection: %12d bytes => %12d bytes |\n", diff --git a/src/prop/sat_module.cpp b/src/prop/sat_module.cpp index db911f488..cda32a0e8 100644 --- a/src/prop/sat_module.cpp +++ b/src/prop/sat_module.cpp @@ -55,7 +55,10 @@ MinisatSatSolver::~MinisatSatSolver() { void MinisatSatSolver::addClause(SatClause& clause, bool removable) { Debug("sat::minisat") << "Add clause " << clause <<"\n"; BVMinisat::vec minisat_clause; - toMinisatClause(clause, minisat_clause); + toMinisatClause(clause, minisat_clause); + // for(unsigned i = 0; i < minisat_clause.size(); ++i) { + // d_minisat->setFrozen(BVMinisat::var(minisat_clause[i]), true); + // } d_minisat->addClause(minisat_clause); } @@ -102,7 +105,7 @@ SatLiteralValue MinisatSatSolver::solve(const context::CDList & assu } Debug("sat::minisat") <<"\n"; - SatLiteralValue result = toSatLiteralValue(d_minisat->solve(assump)); + SatLiteralValue result = toSatLiteralValue(d_minisat->solve(assump)); return result; } diff --git a/src/smt/smt_engine.cpp b/src/smt/smt_engine.cpp index 4f28225af..453817ec1 100644 --- a/src/smt/smt_engine.cpp +++ b/src/smt/smt_engine.cpp @@ -670,6 +670,7 @@ void SmtEnginePrivate::removeITEs() { for (unsigned i = 0; i < d_assertionsToCheck.size(); ++ i) { d_assertionsToCheck[i] = theory::Rewriter::rewrite(d_assertionsToCheck[i]); } + } void SmtEnginePrivate::staticLearning() { diff --git a/src/theory/bv/Makefile.am b/src/theory/bv/Makefile.am index 1b8e902e0..af760e520 100644 --- a/src/theory/bv/Makefile.am +++ b/src/theory/bv/Makefile.am @@ -17,7 +17,10 @@ libbv_la_SOURCES = \ theory_bv.cpp \ theory_bv_rewrite_rules.h \ theory_bv_rewrite_rules_core.h \ - theory_bv_rewrite_rules_arith.h \ + theory_bv_rewrite_rules_operator_elimination.h \ + theory_bv_rewrite_rules_constant_evaluation.h \ + theory_bv_rewrite_rules_normalization.h \ + theory_bv_rewrite_rules_simplification.h \ theory_bv_type_rules.h \ theory_bv_rewriter.h \ theory_bv_rewriter.cpp \ diff --git a/src/theory/bv/bitblast_strategies.cpp b/src/theory/bv/bitblast_strategies.cpp index dacd6a538..6cbec732c 100644 --- a/src/theory/bv/bitblast_strategies.cpp +++ b/src/theory/bv/bitblast_strategies.cpp @@ -175,7 +175,7 @@ Node inline sLessThanBB(const Bits&a, const Bits& b, bool orEqual) { Node UndefinedAtomBBStrategy(TNode node, Bitblaster* bb) { - Debug("bitvector") << "TheoryBV::Bitblaster Undefined bitblasting strategy for kind: " + Trace("bitvector") << "TheoryBV::Bitblaster Undefined bitblasting strategy for kind: " << node.getKind() << "\n"; Unreachable(); } @@ -261,6 +261,19 @@ Node DefaultUgeBB(TNode node, Bitblaster* bb){ Unimplemented(); } +// Node DefaultSltBB(TNode node, Bitblaster* bb){ +// Debug("bitvector-bb") << "Bitblasting node " << node << "\n"; +// // shoudl be rewritten in terms of ult +// Unimplemented(); +// } + +// Node DefaultSleBB(TNode node, Bitblaster* bb){ +// Debug("bitvector-bb") << "Bitblasting node " << node << "\n"; +// // shoudl be rewritten in terms of ule +// Unimplemented(); +// } + + Node DefaultSltBB(TNode node, Bitblaster* bb){ Debug("bitvector-bb") << "Bitblasting node " << node << "\n"; @@ -301,7 +314,7 @@ Node DefaultSgeBB(TNode node, Bitblaster* bb){ /// Term bitblasting strategies void UndefinedTermBBStrategy(TNode node, Bits& bits, Bitblaster* bb) { - Debug("bitvector") << "theory::bv:: Undefined bitblasting strategy for kind: " + Trace("bitvector") << "theory::bv:: Undefined bitblasting strategy for kind: " << node.getKind() << "\n"; Unreachable(); } diff --git a/src/theory/bv/bitblast_strategies.h b/src/theory/bv/bitblast_strategies.h index 504755e6c..c445af626 100644 --- a/src/theory/bv/bitblast_strategies.h +++ b/src/theory/bv/bitblast_strategies.h @@ -60,7 +60,8 @@ Node DefaultSgeBB(TNode node, Bitblaster* bb); /// other modes Node AdderUltBB(TNode node, Bitblaster* bb); - +Node SltBB(TNode node, Bitblaster* bb); +Node SleBB(TNode node, Bitblaster* bb); /** diff --git a/src/theory/bv/bv_sat.cpp b/src/theory/bv/bv_sat.cpp index 97e886db2..d386fd4db 100644 --- a/src/theory/bv/bv_sat.cpp +++ b/src/theory/bv/bv_sat.cpp @@ -81,6 +81,8 @@ void Bitblaster::bbAtom(TNode node) { return; } + BVDebug("bitvector-bitblast") << "Bitblasting node " << node <<"\n"; + // the bitblasted definition of the atom Node atom_bb = d_atomBBStrategies[node.getKind()](node, this); // asserting that the atom is true iff the definition holds @@ -97,7 +99,7 @@ void Bitblaster::bbTerm(TNode node, Bits& bits) { getBBTerm(node, bits); return; } - + BVDebug("bitvector-bitblast") << "Bitblasting node " << node <<"\n"; d_termBBStrategies[node.getKind()] (node, bits,this); Assert (bits.size() == utils::getSize(node)); @@ -115,7 +117,7 @@ void Bitblaster::bbTerm(TNode node, Bits& bits) { */ void Bitblaster::bitblast(TNode node) { TimerStat::CodeTimer codeTimer(d_statistics.d_bitblastTimer); - + /// strip the not if (node.getKind() == kind::NOT) { node = node[0]; @@ -123,9 +125,9 @@ void Bitblaster::bitblast(TNode node) { if (node.getKind() == kind::EQUAL || node.getKind() == kind::BITVECTOR_ULT || + node.getKind() == kind::BITVECTOR_ULE || node.getKind() == kind::BITVECTOR_SLT || - node.getKind() == kind::BITVECTOR_ULE || - node.getKind() == kind::BITVECTOR_SLE ) + node.getKind() == kind::BITVECTOR_SLE) { bbAtom(node); } @@ -159,7 +161,8 @@ void Bitblaster::assertToSat(TNode lit) { atom = lit; } - Assert (hasBBAtom(atom)); + Assert (hasBBAtom(atom)); + Node rewr_atom = Rewriter::rewrite(atom); SatLiteral markerLit = d_cnfStream->getLiteral(atom); if(lit.getKind() == kind::NOT) { @@ -180,6 +183,7 @@ void Bitblaster::assertToSat(TNode lit) { */ bool Bitblaster::solve() { + Trace("bitvector") << "Bitblaster::solve() asserted atoms " << d_assertedAtoms.size() <<"\n"; return SatValTrue == d_satSolver->solve(d_assertedAtoms); } diff --git a/src/theory/bv/bv_sat.h b/src/theory/bv/bv_sat.h index 773491fd0..3ffc79b7a 100644 --- a/src/theory/bv/bv_sat.h +++ b/src/theory/bv/bv_sat.h @@ -79,7 +79,7 @@ class Bitblaster { AtomSet d_bitblastedAtoms; context::CDList d_assertedAtoms; /**< context dependent list storing the atoms - currently asserted by the DPLL SAT solver. */ + currently asserted by the DPLL SAT solver. */ /// helper methods bool hasBBAtom(TNode node); diff --git a/src/theory/bv/theory_bv.cpp b/src/theory/bv/theory_bv.cpp index 79c065d7e..45d99f9c9 100644 --- a/src/theory/bv/theory_bv.cpp +++ b/src/theory/bv/theory_bv.cpp @@ -45,48 +45,48 @@ TheoryBV::~TheoryBV() { } TheoryBV::Statistics::Statistics(): d_avgConflictSize("theory::bv::AvgBVConflictSize"), - d_solveSubstitutions("theory::bv::NumberOfSolveSubstitutions", 0) + d_solveSubstitutions("theory::bv::NumberOfSolveSubstitutions", 0), + d_solveTimer("theory::bv::solveTimer") { StatisticsRegistry::registerStat(&d_avgConflictSize); StatisticsRegistry::registerStat(&d_solveSubstitutions); + StatisticsRegistry::registerStat(&d_solveTimer); } TheoryBV::Statistics::~Statistics() { StatisticsRegistry::unregisterStat(&d_avgConflictSize); StatisticsRegistry::unregisterStat(&d_solveSubstitutions); + StatisticsRegistry::unregisterStat(&d_solveTimer); } void TheoryBV::preRegisterTerm(TNode node) { - BVDebug("bitvector-preregister") << "TheoryBV::preRegister(" << node << ")" << std::endl; //marker literal: bitblast all terms before we start - d_bitblaster->bitblast(node); + //d_bitblaster->bitblast(node); } void TheoryBV::check(Effort e) { - BVDebug("bitvector") << "TheoryBV::check(" << e << ")" << std::endl; - if (fullEffort(e)) { - std::vector assertions; + if (fullEffort(e) && !done()) { + Trace("bitvector")<< "TheoryBV::check(" << e << ")" << std::endl; + std::vector assertions; + while (!done()) { TNode assertion = get(); - Debug("bitvector") << "assertion " << assertion << "\n"; - assertions.push_back(assertion); + Trace("bitvector-assertions") << "TheoryBV::check assertion " << assertion << "\n"; d_bitblaster->bitblast(assertion); + d_bitblaster->assertToSat(assertion); } - - std::vector::const_iterator it = assertions.begin(); - for (; it != assertions.end(); ++it) { - d_bitblaster->assertToSat(*it); - } + + TimerStat::CodeTimer codeTimer(d_statistics.d_solveTimer); bool res = d_bitblaster->solve(); if (res == false) { std::vector conflictAtoms; d_bitblaster->getConflict(conflictAtoms); - d_statistics.d_avgConflictSize.addEntry(conflictAtoms.size()); Node conflict = mkConjunction(conflictAtoms); d_out->conflict(conflict); + Trace("bitvector") << "TheoryBV::check returns conflict. \n "; return; } } diff --git a/src/theory/bv/theory_bv.h b/src/theory/bv/theory_bv.h index 5303b6595..36ba17b52 100644 --- a/src/theory/bv/theory_bv.h +++ b/src/theory/bv/theory_bv.h @@ -81,7 +81,8 @@ private: class Statistics { public: AverageStat d_avgConflictSize; - IntStat d_solveSubstitutions; + IntStat d_solveSubstitutions; + TimerStat d_solveTimer; Statistics(); ~Statistics(); }; diff --git a/src/theory/bv/theory_bv_rewrite_rules.h b/src/theory/bv/theory_bv_rewrite_rules.h index 2e2347d3c..b01a0646c 100644 --- a/src/theory/bv/theory_bv_rewrite_rules.h +++ b/src/theory/bv/theory_bv_rewrite_rules.h @@ -2,7 +2,7 @@ /*! \file theory_bv_rewrite_rules.h ** \verbatim ** Original author: dejan - ** Major contributors: none + ** Major contributors: lianah ** Minor contributors (to current version): mdeters ** This file is part of the CVC4 prototype. ** Copyright (c) 2009, 2010, 2011 The Analysis of Computer Systems Group (ACSys) @@ -31,7 +31,8 @@ namespace theory { namespace bv { enum RewriteRuleId { - /// core rewrite rules + + /// core normalization rules EmptyRule, ConcatFlatten, ConcatExtractMerge, @@ -43,23 +44,102 @@ enum RewriteRuleId { FailEq, SimplifyEq, ReflexivityEq, + /// operator elimination rules - UgtToUlt, - UgeToUle, - SgeToSle, - SgtToSlt, + UgtEliminate, + UgeEliminate, + SgeEliminate, + SgtEliminate, + SubEliminate, + SltEliminate, + SleEliminate, + CompEliminate, RepeatEliminate, RotateLeftEliminate, RotateRightEliminate, NandEliminate, NorEliminate, + XnorEliminate, SdivEliminate, UdivEliminate, SmodEliminate, SremEliminate, ZeroExtendEliminate, - // division by zero guards: rewrite a / b as b!=0 => a/b = ... - DivZeroGuard + SignExtendEliminate, + /// ground term evaluation + EvalEquals, + EvalConcat, + EvalAnd, + EvalOr, + EvalXor, + EvalNot, + EvalMult, + EvalPlus, + EvalUdiv, + EvalUrem, + EvalShl, + EvalLshr, + EvalAshr, + EvalUlt, + EvalUle, + EvalExtract, + EvalSignExtend, + EvalRotateLeft, + EvalRotateRight, + EvalNeg, + + /// simplification rules + /// all of these rules decrease formula size + ShlByConst, + LshrByConst, + AshrByConst, + BitwiseIdemp, + AndZero, + AndOne, + OrZero, + OrOne, + XorDuplicate, + XorOne, + XorZero, + BitwiseNotAnd, + BitwiseNotOr, + XorNot, + NotIdemp, + LtSelf, + LteSelf, + UltZero, + UltSelf, + UleZero, + UleSelf, + ZeroUle, + UleMax, + NotUlt, + NotUle, + MultOne, + MultZero, + MultPow2, + PlusZero, + PlusSelf, + PlusNegSelf, + NegIdemp, + UdivPow2, + UdivOne, + UdivSelf, + UremPow2, + UremOne, + UremSelf, + ShiftZero, + + /// normalization rules + ExtractBitwise, + ExtractNot, + ExtractArith, + ExtractArith2, + DoubleNeg, + NotConcat, + NotAnd, // not sure why this would help (not done) + NotOr, // not sure why this would help (not done) + NotXor // not sure why this would help (not done) }; inline std::ostream& operator << (std::ostream& out, RewriteRuleId ruleId) { @@ -75,10 +155,10 @@ inline std::ostream& operator << (std::ostream& out, RewriteRuleId ruleId) { case FailEq: out << "FailEq"; return out; case SimplifyEq: out << "SimplifyEq"; return out; case ReflexivityEq: out << "ReflexivityEq"; return out; - case UgtToUlt: out << "UgtToUlt"; return out; - case SgtToSlt: out << "SgtToSlt"; return out; - case UgeToUle: out << "UgeToUle"; return out; - case SgeToSle: out << "SgeToSle"; return out; + case UgtEliminate: out << "UgtEliminate"; return out; + case SgtEliminate: out << "SgtEliminate"; return out; + case UgeEliminate: out << "UgeEliminate"; return out; + case SgeEliminate: out << "SgeEliminate"; return out; case RepeatEliminate: out << "RepeatEliminate"; return out; case RotateLeftEliminate: out << "RotateLeftEliminate"; return out; case RotateRightEliminate:out << "RotateRightEliminate";return out; @@ -87,8 +167,80 @@ inline std::ostream& operator << (std::ostream& out, RewriteRuleId ruleId) { case SdivEliminate : out << "SdivEliminate"; return out; case SremEliminate : out << "SremEliminate"; return out; case SmodEliminate : out << "SmodEliminate"; return out; - case ZeroExtendEliminate :out << "ZeroExtendEliminate"; return out; - case DivZeroGuard : out << "DivZeroGuard"; return out; + case ZeroExtendEliminate :out << "ZeroExtendEliminate"; return out; + case EvalEquals : out << "EvalEquals"; return out; + case EvalConcat : out << "EvalConcat"; return out; + case EvalAnd : out << "EvalAnd"; return out; + case EvalOr : out << "EvalOr"; return out; + case EvalXor : out << "EvalXor"; return out; + case EvalNot : out << "EvalNot"; return out; + case EvalMult : out << "EvalMult"; return out; + case EvalPlus : out << "EvalPlus"; return out; + case EvalUdiv : out << "EvalUdiv"; return out; + case EvalUrem : out << "EvalUrem"; return out; + case EvalShl : out << "EvalShl"; return out; + case EvalLshr : out << "EvalLshr"; return out; + case EvalAshr : out << "EvalAshr"; return out; + case EvalUlt : out << "EvalUlt"; return out; + case EvalUle : out << "EvalUle"; return out; + case EvalExtract : out << "EvalExtract"; return out; + case EvalSignExtend : out << "EvalSignExtend"; return out; + case EvalRotateLeft : out << "EvalRotateLeft"; return out; + case EvalRotateRight : out << "EvalRotateRight"; return out; + case EvalNeg : out << "EvalNeg"; return out; + case ShlByConst : out << "ShlByConst"; return out; + case LshrByConst : out << "LshrByConst"; return out; + case AshrByConst : out << "AshrByConst"; return out; + case ExtractBitwise : out << "ExtractBitwise"; return out; + case ExtractNot : out << "ExtractNot"; return out; + case ExtractArith : out << "ExtractArith"; return out; + case ExtractArith2 : out << "ExtractArith2"; return out; + case DoubleNeg : out << "DoubleNeg"; return out; + case NotConcat : out << "NotConcat"; return out; + case NotAnd : out << "NotAnd"; return out; + case NotOr : out << "NotOr"; return out; + case NotXor : out << "NotXor"; return out; + case BitwiseIdemp : out << "BitwiseIdemp"; return out; + case XorDuplicate : out << "XorDuplicate"; return out; + case BitwiseNotAnd : out << "BitwiseNotAnd"; return out; + case BitwiseNotOr : out << "BitwiseNotOr"; return out; + case XorNot : out << "XorNot"; return out; + case LtSelf : out << "LtSelf"; return out; + case LteSelf : out << "LteSelf"; return out; + case UltZero : out << "UltZero"; return out; + case UleZero : out << "UleZero"; return out; + case ZeroUle : out << "ZeroUle"; return out; + case NotUlt : out << "NotUlt"; return out; + case NotUle : out << "NotUle"; return out; + case UleMax : out << "UleMax"; return out; + case SltEliminate : out << "SltEliminate"; return out; + case SleEliminate : out << "SleEliminate"; return out; + case AndZero : out << "AndZero"; return out; + case AndOne : out << "AndOne"; return out; + case OrZero : out << "OrZero"; return out; + case OrOne : out << "OrOne"; return out; + case XorOne : out << "XorOne"; return out; + case XorZero : out << "XorZero"; return out; + case MultOne : out << "MultOne"; return out; + case MultZero : out << "MultZero"; return out; + case MultPow2 : out << "MultPow2"; return out; + case PlusZero : out << "PlusZero"; return out; + case PlusSelf : out << "PlusSelf"; return out; + case PlusNegSelf : out << "PlusNegSelf"; return out; + case NegIdemp : out << "NegIdemp"; return out; + case UdivPow2 : out << "UdivPow2"; return out; + case UdivOne : out << "UdivOne"; return out; + case UdivSelf : out << "UdivSelf"; return out; + case UremPow2 : out << "UremPow2"; return out; + case UremOne : out << "UremOne"; return out; + case UremSelf : out << "UremSelf"; return out; + case ShiftZero : out << "ShiftZero"; return out; + case SubEliminate : out << "SubEliminate"; return out; + case CompEliminate : out << "CompEliminate"; return out; + case XnorEliminate : out << "XnorEliminate"; return out; + case SignExtendEliminate : out << "SignExtendEliminate"; return out; + case NotIdemp : out << "NotIdemp"; return out; + case UleSelf: out << "UleSelf"; return out; default: Unreachable(); } @@ -97,39 +249,39 @@ inline std::ostream& operator << (std::ostream& out, RewriteRuleId ruleId) { template class RewriteRule { - class RuleStatistics { + // class RuleStatistics { - /** The name of the rule prefixed with the prefix */ - static std::string getStatName(const char* prefix) { - std::stringstream statName; - statName << prefix << rule; - return statName.str(); - } + // /** The name of the rule prefixed with the prefix */ + // static std::string getStatName(const char* prefix) { + // std::stringstream statName; + // statName << prefix << rule; + // return statName.str(); + // } - public: + // public: - /** Number of applications of this rule */ - IntStat d_ruleApplications; + // /** Number of applications of this rule */ + // IntStat d_ruleApplications; - /** Constructor */ - RuleStatistics() - : d_ruleApplications(getStatName("theory::bv::RewriteRules::count"), 0) { - StatisticsRegistry::registerStat(&d_ruleApplications); - } + // /** Constructor */ + // RuleStatistics() + // : d_ruleApplications(getStatName("theory::bv::RewriteRules::count"), 0) { + // StatisticsRegistry::registerStat(&d_ruleApplications); + // } - /** Destructor */ - ~RuleStatistics() { - StatisticsRegistry::unregisterStat(&d_ruleApplications); - } - }; + // /** Destructor */ + // ~RuleStatistics() { + // StatisticsRegistry::unregisterStat(&d_ruleApplications); + // } + // }; - /* Statistics about the rule */ - // NOTE: Cannot have static fields like this, or else you can't have - // two SmtEngines in the process (the second-to-be-destroyed will - // have a dangling pointer and segfault). If this statistic is needed, - // fix the rewriter by making it an instance per-SmtEngine (instead of - // static). - //static RuleStatistics* s_statistics; + // /* Statistics about the rule */ + // // NOTE: Cannot have static fields like this, or else you can't have + // // two SmtEngines in the process (the second-to-be-destroyed will + // // have a dangling pointer and segfault). If this statistic is needed, + // // fix the rewriter by making it an instance per-SmtEngine (instead of + // // static). + // static RuleStatistics* s_statistics; /** Actually apply the rewrite rule */ static inline Node apply(Node node) { @@ -139,18 +291,18 @@ class RewriteRule { public: RewriteRule() { - /* - if (s_statistics == NULL) { - s_statistics = new RuleStatistics(); - } - */ + + // if (s_statistics == NULL) { + // s_statistics = new RuleStatistics(); + // } + } ~RewriteRule() { - /* - delete s_statistics; - s_statistics = NULL; - */ + + // delete s_statistics; + // s_statistics = NULL; + } static inline bool applies(Node node) { @@ -172,14 +324,13 @@ public: } }; -/* -template -typename RewriteRule::RuleStatistics* RewriteRule::s_statistics = NULL; -*/ + + // template + // typename RewriteRule::RuleStatistics* RewriteRule::s_statistics = NULL; + /** Have to list all the rewrite rules to get the statistics out */ struct AllRewriteRules { - RewriteRule rule00; RewriteRule rule01; RewriteRule rule02; @@ -191,10 +342,10 @@ struct AllRewriteRules { RewriteRule rule08; RewriteRule rule09; RewriteRule rule10; - RewriteRule rule11; - RewriteRule rule12; - RewriteRule rule13; - RewriteRule rule14; + RewriteRule rule11; + RewriteRule rule12; + RewriteRule rule13; + RewriteRule rule14; RewriteRule rule17; RewriteRule rule18; RewriteRule rule19; @@ -203,8 +354,80 @@ struct AllRewriteRules { RewriteRule rule22; RewriteRule rule23; RewriteRule rule24; - RewriteRule rule25; - + RewriteRule rule25; + RewriteRule rule26; + RewriteRule rule27; + RewriteRule rule28; + RewriteRule rule29; + RewriteRule rule31; + RewriteRule rule32; + RewriteRule rule34; + RewriteRule rule35; + RewriteRule rule36; + RewriteRule rule37; + RewriteRule rule38; + RewriteRule rule39; + RewriteRule rule40; + RewriteRule rule43; + RewriteRule rule44; + RewriteRule rule45; + RewriteRule rule46; + RewriteRule rule47; + RewriteRule rule48; + RewriteRule rule50; + RewriteRule rule51; + RewriteRule rule52; + RewriteRule rule53; + RewriteRule rule54; + RewriteRule rule55; + RewriteRule rule56; + RewriteRule rule57; + RewriteRule rule58; + RewriteRule rule59; + RewriteRule rule60; + RewriteRule rule61; + RewriteRule rule62; + RewriteRule rule63; + RewriteRule rule64; + RewriteRule rule65; + RewriteRule rule66; + RewriteRule rule67; + RewriteRule rule68; + RewriteRule rule69; + RewriteRule rule70; + RewriteRule rule71; + RewriteRule rule72; + RewriteRule rule73; + RewriteRule rule74; + RewriteRule rule75; + RewriteRule rule76; + RewriteRule rule77; + RewriteRule rule78; + RewriteRule rule79; + RewriteRule rule80; + RewriteRule rule81; + RewriteRule rule82; + RewriteRule rule83; + RewriteRule rule84; + RewriteRule rule85; + RewriteRule rule86; + RewriteRule rule87; + RewriteRule rule88; + RewriteRule rule89; + RewriteRule rule90; + RewriteRule rule91; + RewriteRule rule92; + RewriteRule rule93; + RewriteRule rule94; + RewriteRule rule95; + RewriteRule rule96; + RewriteRule rule97; + RewriteRule rule98; + RewriteRule rule99; + RewriteRule rule100; + RewriteRule rule101; + RewriteRule rule102; + RewriteRule rule103; }; template<> @@ -214,6 +437,7 @@ bool RewriteRule::applies(Node node) { template<> Node RewriteRule::apply(Node node) { + BVDebug("bv-rewrite") << "RewriteRule for " << node.getKind() <<"\n"; Unreachable(); return node; } @@ -249,29 +473,108 @@ struct ApplyRuleToChildren { template < typename R1, - typename R2 = RewriteRule, - typename R3 = RewriteRule, - typename R4 = RewriteRule, - typename R5 = RewriteRule, - typename R6 = RewriteRule, - typename R7 = RewriteRule, - typename R8 = RewriteRule + typename R2 = RewriteRule, + typename R3 = RewriteRule, + typename R4 = RewriteRule, + typename R5 = RewriteRule, + typename R6 = RewriteRule, + typename R7 = RewriteRule, + typename R8 = RewriteRule, + typename R9 = RewriteRule, + typename R10 = RewriteRule, + typename R11 = RewriteRule, + typename R12 = RewriteRule, + typename R13 = RewriteRule, + typename R14 = RewriteRule, + typename R15 = RewriteRule, + typename R16 = RewriteRule, + typename R17 = RewriteRule, + typename R18 = RewriteRule, + typename R19 = RewriteRule, + typename R20 = RewriteRule > struct LinearRewriteStrategy { static Node apply(Node node) { Node current = node; - if (R1::applies(current)) current = R1::template run(current); - if (R2::applies(current)) current = R2::template run(current); - if (R3::applies(current)) current = R3::template run(current); - if (R4::applies(current)) current = R4::template run(current); - if (R5::applies(current)) current = R5::template run(current); - if (R6::applies(current)) current = R6::template run(current); - if (R7::applies(current)) current = R7::template run(current); - if (R8::applies(current)) current = R8::template run(current); + if (R1::applies(current)) current = R1::template run(current); + if (R2::applies(current)) current = R2::template run(current); + if (R3::applies(current)) current = R3::template run(current); + if (R4::applies(current)) current = R4::template run(current); + if (R5::applies(current)) current = R5::template run(current); + if (R6::applies(current)) current = R6::template run(current); + if (R7::applies(current)) current = R7::template run(current); + if (R8::applies(current)) current = R8::template run(current); + if (R9::applies(current)) current = R9::template run(current); + if (R10::applies(current)) current = R10::template run(current); + if (R11::applies(current)) current = R11::template run(current); + if (R12::applies(current)) current = R12::template run(current); + if (R13::applies(current)) current = R13::template run(current); + if (R14::applies(current)) current = R14::template run(current); + if (R15::applies(current)) current = R15::template run(current); + if (R16::applies(current)) current = R16::template run(current); + if (R17::applies(current)) current = R17::template run(current); + if (R18::applies(current)) current = R18::template run(current); + if (R19::applies(current)) current = R19::template run(current); + if (R20::applies(current)) current = R20::template run(current); return current; } }; +template < + typename R1, + typename R2 = RewriteRule, + typename R3 = RewriteRule, + typename R4 = RewriteRule, + typename R5 = RewriteRule, + typename R6 = RewriteRule, + typename R7 = RewriteRule, + typename R8 = RewriteRule, + typename R9 = RewriteRule, + typename R10 = RewriteRule, + typename R11 = RewriteRule, + typename R12 = RewriteRule, + typename R13 = RewriteRule, + typename R14 = RewriteRule, + typename R15 = RewriteRule, + typename R16 = RewriteRule, + typename R17 = RewriteRule, + typename R18 = RewriteRule, + typename R19 = RewriteRule, + typename R20 = RewriteRule + > +struct FixpointRewriteStrategy { + static Node apply(Node node) { + Node previous = node; + Node current = node; + do { + previous = current; + if (R1::applies(current)) current = R1::template run(current); + if (R2::applies(current)) current = R2::template run(current); + if (R3::applies(current)) current = R3::template run(current); + if (R4::applies(current)) current = R4::template run(current); + if (R5::applies(current)) current = R5::template run(current); + if (R6::applies(current)) current = R6::template run(current); + if (R7::applies(current)) current = R7::template run(current); + if (R8::applies(current)) current = R8::template run(current); + if (R9::applies(current)) current = R9::template run(current); + if (R10::applies(current)) current = R10::template run(current); + if (R11::applies(current)) current = R11::template run(current); + if (R12::applies(current)) current = R12::template run(current); + if (R13::applies(current)) current = R13::template run(current); + if (R14::applies(current)) current = R14::template run(current); + if (R15::applies(current)) current = R15::template run(current); + if (R16::applies(current)) current = R16::template run(current); + if (R17::applies(current)) current = R17::template run(current); + if (R18::applies(current)) current = R18::template run(current); + if (R19::applies(current)) current = R19::template run(current); + if (R20::applies(current)) current = R20::template run(current); + } while (previous != current); + + return current; + } +}; + + } // End namespace bv } // End namespace theory } // End namespace CVC4 diff --git a/src/theory/bv/theory_bv_rewrite_rules_constant_evaluation.h b/src/theory/bv/theory_bv_rewrite_rules_constant_evaluation.h new file mode 100644 index 000000000..1dc053b5d --- /dev/null +++ b/src/theory/bv/theory_bv_rewrite_rules_constant_evaluation.h @@ -0,0 +1,414 @@ +/********************* */ +/*! \file theory_bv_rewrite_rules_core.h + ** \verbatim + ** Original author: lianah + ** Major contributors: none + ** Minor contributors (to current version): + ** This file is part of the CVC4 prototype. + ** Copyright (c) 2009, 2010, 2011 The Analysis of Computer Systems Group (ACSys) + ** Courant Institute of Mathematical Sciences + ** New York University + ** See the file COPYING in the top-level source directory for licensing + ** information.\endverbatim + ** + ** \brief [[ Add one-line brief description here ]] + ** + ** [[ Add lengthier description here ]] + ** \todo document this file + **/ + +#include "cvc4_private.h" + +#pragma once + +#include "theory/bv/theory_bv_rewrite_rules.h" +#include "theory/bv/theory_bv_utils.h" + +namespace CVC4 { +namespace theory { +namespace bv { + +template<> +bool RewriteRule::applies(Node node) { + return (node.getKind() == kind::BITVECTOR_AND && + utils::isBVGroundTerm(node)); +} + +template<> +Node RewriteRule::apply(Node node) { + BVDebug("bv-rewrite") << "RewriteRule(" << node << ")" << std::endl; + BitVector a = node[0].getConst(); + BitVector b = node[1].getConst(); + BitVector res = a & b; + + return utils::mkConst(res); +} + +template<> +bool RewriteRule::applies(Node node) { + return (node.getKind() == kind::BITVECTOR_OR && + utils::isBVGroundTerm(node)); +} + +template<> +Node RewriteRule::apply(Node node) { + BVDebug("bv-rewrite") << "RewriteRule(" << node << ")" << std::endl; + BitVector a = node[0].getConst(); + BitVector b = node[1].getConst(); + BitVector res = a | b; + + return utils::mkConst(res); +} + +template<> +bool RewriteRule::applies(Node node) { + return (node.getKind() == kind::BITVECTOR_XOR && + utils::isBVGroundTerm(node)); +} + +template<> +Node RewriteRule::apply(Node node) { + BVDebug("bv-rewrite") << "RewriteRule(" << node << ")" << std::endl; + BitVector a = node[0].getConst(); + BitVector b = node[1].getConst(); + BitVector res = a ^ b; + + return utils::mkConst(res); +} + +// template<> +// bool RewriteRule::applies(Node node) { +// return (node.getKind() == kind::BITVECTOR_XNOR && +// utils::isBVGroundTerm(node)); +// } + +// template<> +// Node RewriteRule::apply(Node node) { +// BVDebug("bv-rewrite") << "RewriteRule(" << node << ")" << std::endl; +// BitVector a = node[0].getConst(); +// BitVector b = node[1].getConst(); +// BitVector res = ~ (a ^ b); + +// return utils::mkConst(res); +// } +template<> +bool RewriteRule::applies(Node node) { + return (node.getKind() == kind::BITVECTOR_NOT && + utils::isBVGroundTerm(node)); +} + +template<> +Node RewriteRule::apply(Node node) { + BVDebug("bv-rewrite") << "RewriteRule(" << node << ")" << std::endl; + BitVector a = node[0].getConst(); + BitVector res = ~ a; + return utils::mkConst(res); +} + +// template<> +// bool RewriteRule::applies(Node node) { +// return (node.getKind() == kind::BITVECTOR_COMP && +// utils::isBVGroundTerm(node)); +// } + +// template<> +// Node RewriteRule::apply(Node node) { +// BVDebug("bv-rewrite") << "RewriteRule(" << node << ")" << std::endl; +// BitVector a = node[0].getConst(); +// BitVector b = node[1].getConst(); +// BitVector res; +// if (a == b) { +// res = BitVector(1, Integer(1)); +// } else { +// res = BitVector(1, Integer(0)); +// } + +// return utils::mkConst(res); +// } + +template<> +bool RewriteRule::applies(Node node) { + return (node.getKind() == kind::BITVECTOR_MULT && + utils::isBVGroundTerm(node)); +} + +template<> +Node RewriteRule::apply(Node node) { + BVDebug("bv-rewrite") << "RewriteRule(" << node << ")" << std::endl; + BitVector a = node[0].getConst(); + BitVector b = node[1].getConst(); + BitVector res = a * b; + + return utils::mkConst(res); +} + +template<> +bool RewriteRule::applies(Node node) { + return (node.getKind() == kind::BITVECTOR_PLUS && + utils::isBVGroundTerm(node)); +} + +template<> +Node RewriteRule::apply(Node node) { + BVDebug("bv-rewrite") << "RewriteRule(" << node << ")" << std::endl; + BitVector a = node[0].getConst(); + BitVector b = node[1].getConst(); + BitVector res = a + b; + + return utils::mkConst(res); +} + +// template<> +// bool RewriteRule::applies(Node node) { +// return (node.getKind() == kind::BITVECTOR_SUB && +// utils::isBVGroundTerm(node)); +// } + +// template<> +// Node RewriteRule::apply(Node node) { +// BVDebug("bv-rewrite") << "RewriteRule(" << node << ")" << std::endl; +// BitVector a = node[0].getConst(); +// BitVector b = node[1].getConst(); +// BitVector res = a - b; + +// return utils::mkConst(res); +// } +template<> +bool RewriteRule::applies(Node node) { + return (node.getKind() == kind::BITVECTOR_NEG && + utils::isBVGroundTerm(node)); +} + +template<> +Node RewriteRule::apply(Node node) { + BVDebug("bv-rewrite") << "RewriteRule(" << node << ")" << std::endl; + BitVector a = node[0].getConst(); + BitVector res = - a; + + return utils::mkConst(res); +} +template<> +bool RewriteRule::applies(Node node) { + return (node.getKind() == kind::BITVECTOR_UDIV && + utils::isBVGroundTerm(node)); +} + +template<> +Node RewriteRule::apply(Node node) { + BVDebug("bv-rewrite") << "RewriteRule(" << node << ")" << std::endl; + BitVector a = node[0].getConst(); + BitVector b = node[1].getConst(); + BitVector res = a.unsignedDiv(b); + + return utils::mkConst(res); +} +template<> +bool RewriteRule::applies(Node node) { + return (node.getKind() == kind::BITVECTOR_UREM && + utils::isBVGroundTerm(node)); +} + +template<> +Node RewriteRule::apply(Node node) { + BVDebug("bv-rewrite") << "RewriteRule(" << node << ")" << std::endl; + BitVector a = node[0].getConst(); + BitVector b = node[1].getConst(); + BitVector res = a.unsignedRem(b); + return utils::mkConst(res); +} + +template<> +bool RewriteRule::applies(Node node) { + return (node.getKind() == kind::BITVECTOR_SHL && + utils::isBVGroundTerm(node)); +} + +template<> +Node RewriteRule::apply(Node node) { + BVDebug("bv-rewrite") << "RewriteRule(" << node << ")" << std::endl; + BitVector a = node[0].getConst(); + BitVector b = node[1].getConst(); + + BitVector res = a.leftShift(b); + return utils::mkConst(res); +} + +template<> +bool RewriteRule::applies(Node node) { + return (node.getKind() == kind::BITVECTOR_LSHR && + utils::isBVGroundTerm(node)); +} + +template<> +Node RewriteRule::apply(Node node) { + BVDebug("bv-rewrite") << "RewriteRule(" << node << ")" << std::endl; + BitVector a = node[0].getConst(); + BitVector b = node[1].getConst(); + + BitVector res = a.logicalRightShift(b); + return utils::mkConst(res); +} + +template<> +bool RewriteRule::applies(Node node) { + return (node.getKind() == kind::BITVECTOR_ASHR && + utils::isBVGroundTerm(node)); +} + +template<> +Node RewriteRule::apply(Node node) { + BVDebug("bv-rewrite") << "RewriteRule(" << node << ")" << std::endl; + BitVector a = node[0].getConst(); + BitVector b = node[1].getConst(); + + BitVector res = a.arithRightShift(b); + return utils::mkConst(res); +} + +template<> +bool RewriteRule::applies(Node node) { + return (node.getKind() == kind::BITVECTOR_ULT && + utils::isBVGroundTerm(node)); +} + +template<> +Node RewriteRule::apply(Node node) { + BVDebug("bv-rewrite") << "RewriteRule(" << node << ")" << std::endl; + BitVector a = node[0].getConst(); + BitVector b = node[1].getConst(); + + if (a.unsignedLessThan(b)) { + return utils::mkTrue(); + } + return utils::mkFalse(); +} + +// template<> +// bool RewriteRule::applies(Node node) { +// return (node.getKind() == kind::BITVECTOR_SLT && +// utils::isBVGroundTerm(node)); +// } + +// template<> +// Node RewriteRule::apply(Node node) { +// BVDebug("bv-rewrite") << "RewriteRule(" << node << ")" << std::endl; +// BitVector a = node[0].getConst(); +// BitVector b = node[1].getConst(); + +// if (a.signedLessThan(b)) { +// return utils::mkTrue(); +// } +// return utils::mkFalse(); + +// } + +template<> +bool RewriteRule::applies(Node node) { + return (node.getKind() == kind::BITVECTOR_ULE && + utils::isBVGroundTerm(node)); +} + +template<> +Node RewriteRule::apply(Node node) { + BVDebug("bv-rewrite") << "RewriteRule(" << node << ")" << std::endl; + BitVector a = node[0].getConst(); + BitVector b = node[1].getConst(); + + if (a.unsignedLessThanEq(b)) { + return utils::mkTrue(); + } + return utils::mkFalse(); +} + +// template<> +// bool RewriteRule::applies(Node node) { +// return (node.getKind() == kind::BITVECTOR_SLE && +// utils::isBVGroundTerm(node)); +// } + +// template<> +// Node RewriteRule::apply(Node node) { +// BVDebug("bv-rewrite") << "RewriteRule(" << node << ")" << std::endl; +// BitVector a = node[0].getConst(); +// BitVector b = node[1].getConst(); + +// if (a.signedLessThanEq(b)) { +// return utils::mkTrue(); +// } +// return utils::mkFalse(); +// } + +template<> +bool RewriteRule::applies(Node node) { + return (node.getKind() == kind::BITVECTOR_EXTRACT && + utils::isBVGroundTerm(node)); +} + +template<> +Node RewriteRule::apply(Node node) { + BVDebug("bv-rewrite") << "RewriteRule(" << node << ")" << std::endl; + BitVector a = node[0].getConst(); + unsigned lo = utils::getExtractLow(node); + unsigned hi = utils::getExtractHigh(node); + + BitVector res = a.extract(hi, lo); + return utils::mkConst(res); +} + + +template<> +bool RewriteRule::applies(Node node) { + return (node.getKind() == kind::BITVECTOR_CONCAT && + utils::isBVGroundTerm(node)); +} + +template<> +Node RewriteRule::apply(Node node) { + BVDebug("bv-rewrite") << "RewriteRule(" << node << ")" << std::endl; + unsigned num = node.getNumChildren(); + BitVector res = node[0].getConst(); + for(unsigned i = 1; i < num; ++i ) { + BitVector a = node[i].getConst(); + res = res.concat(a); + } + return utils::mkConst(res); +} + +template<> +bool RewriteRule::applies(Node node) { + return (node.getKind() == kind::BITVECTOR_SIGN_EXTEND && + utils::isBVGroundTerm(node)); +} + +template<> +Node RewriteRule::apply(Node node) { + BVDebug("bv-rewrite") << "RewriteRule(" << node << ")" << std::endl; + BitVector a = node[0].getConst(); + unsigned amount = node.getOperator().getConst().signExtendAmount; + BitVector res = a.signExtend(amount); + + return utils::mkConst(res); +} + +template<> +bool RewriteRule::applies(Node node) { + return (node.getKind() == kind::EQUAL && + utils::isBVGroundTerm(node)); +} + +template<> +Node RewriteRule::apply(Node node) { + BVDebug("bv-rewrite") << "RewriteRule(" << node << ")" << std::endl; + BitVector a = node[0].getConst(); + BitVector b = node[1].getConst(); + if (a == b) { + return utils::mkTrue(); + } + return utils::mkFalse(); + +} + + +} +} +} diff --git a/src/theory/bv/theory_bv_rewrite_rules_core.h b/src/theory/bv/theory_bv_rewrite_rules_core.h index 851a6893c..3240ef5f3 100644 --- a/src/theory/bv/theory_bv_rewrite_rules_core.h +++ b/src/theory/bv/theory_bv_rewrite_rules_core.h @@ -50,6 +50,7 @@ Node RewriteRule::apply(Node node) { } } Node resultNode = result; + BVDebug("bv-rewrite") << "RewriteRule(" << resultNode << ")" << std::endl; return resultNode; } @@ -190,7 +191,7 @@ Node RewriteRule::apply(Node node) { template<> bool RewriteRule::applies(Node node) { - BVDebug("bv-rewrite") << "RewriteRule(" << node << ")" << std::endl; + //BVDebug("bv-rewrite") << "RewriteRule(" << node << ")" << std::endl; if (node.getKind() != kind::BITVECTOR_EXTRACT) return false; if (node[0].getKind() != kind::BITVECTOR_CONCAT) return false; return true; @@ -245,7 +246,7 @@ Node RewriteRule::apply(Node node) { template<> bool RewriteRule::applies(Node node) { - BVDebug("bv-rewrite") << "RewriteRule(" << node << ")" << std::endl; + //BVDebug("bv-rewrite") << "RewriteRule(" << node << ")" << std::endl; if (node.getKind() != kind::EQUAL) return false; if (node[0].getKind() != kind::CONST_BITVECTOR) return false; if (node[1].getKind() != kind::CONST_BITVECTOR) return false; @@ -254,7 +255,7 @@ bool RewriteRule::applies(Node node) { template<> Node RewriteRule::apply(Node node) { - return utils::mkFalse(); + return utils::mkFalse(); } template<> diff --git a/src/theory/bv/theory_bv_rewrite_rules_normalization.h b/src/theory/bv/theory_bv_rewrite_rules_normalization.h new file mode 100644 index 000000000..4e974881c --- /dev/null +++ b/src/theory/bv/theory_bv_rewrite_rules_normalization.h @@ -0,0 +1,147 @@ +/********************* */ +/*! \file theory_bv_rewrite_rules_normalization.h + ** \verbatim + ** Original author: lianah + ** Major contributors: none + ** Minor contributors (to current version): none + ** This file is part of the CVC4 prototype. + ** Copyright (c) 2009, 2010, 2011 The Analysis of Computer Systems Group (ACSys) + ** Courant Institute of Mathematical Sciences + ** New York University + ** See the file COPYING in the top-level source directory for licensing + ** information.\endverbatim + ** + ** \brief [[ Add one-line brief description here ]] + ** + ** [[ Add lengthier description here ]] + ** \todo document this file + **/ + +#include "cvc4_private.h" + +#pragma once + +#include "theory/bv/theory_bv_rewrite_rules.h" +#include "theory/bv/theory_bv_utils.h" + +namespace CVC4 { +namespace theory { +namespace bv { + +/** + * ExtractBitwise + * (x bvop y) [i:j] ==> x[i:j] bvop y[i:j] + * where bvop is bvand,bvor, bvxor + */ +template<> +bool RewriteRule::applies(Node node) { + return (node.getKind() == kind::BITVECTOR_EXTRACT && + (node[0].getKind() == kind::BITVECTOR_AND || + node[0].getKind() == kind::BITVECTOR_OR || + node[0].getKind() == kind::BITVECTOR_XOR )); +} + +template<> +Node RewriteRule::apply(Node node) { + BVDebug("bv-rewrite") << "RewriteRule(" << node << ")" << std::endl; + unsigned high = utils::getExtractHigh(node); + unsigned low = utils::getExtractLow(node); + Node a = utils::mkExtract(node[0][0], high, low); + Node b = utils::mkExtract(node[0][1], high, low); + Kind kind = node[0].getKind(); + return utils::mkNode(kind, a, b); +} + +/** + * ExtractNot + * + * (~ a) [i:j] ==> ~ (a[i:j]) + */ +template<> +bool RewriteRule::applies(Node node) { + return (node.getKind() == kind::BITVECTOR_EXTRACT && + node[0].getKind() == kind::BITVECTOR_NOT); +} + +template<> +Node RewriteRule::apply(Node node) { + BVDebug("bv-rewrite") << "RewriteRule(" << node << ")" << std::endl; + unsigned low = utils::getExtractLow(node); + unsigned high = utils::getExtractHigh(node); + Node a = utils::mkExtract(node[0][0], high, low); + return utils::mkNode(kind::BITVECTOR_NOT, a); +} + +/** + * ExtractArith + * + * (a bvop b) [k:0] ==> (a[k:0] bvop b[k:0]) + */ + +template<> +bool RewriteRule::applies(Node node) { + return (node.getKind() == kind::BITVECTOR_EXTRACT && + utils::getExtractLow(node) == 0 && + (node[0].getKind() == kind::BITVECTOR_PLUS || + node[0].getKind() == kind::BITVECTOR_MULT)); +} + +template<> +Node RewriteRule::apply(Node node) { + BVDebug("bv-rewrite") << "RewriteRule(" << node << ")" << std::endl; + unsigned low = utils::getExtractLow(node); + Assert (low == 0); + unsigned high = utils::getExtractHigh(node); + Node a = utils::mkExtract(node[0][0], high, low); + Node b = utils::mkExtract(node[0][1], high, low); + + Kind kind = node[0].getKind(); + return utils::mkNode(kind, a, b); + +} + +/** + * ExtractArith2 + * + * (a bvop b) [i:j] ==> (a[i:0] bvop b[i:0]) [i:j] + */ + +// careful not to apply in a loop +template<> +bool RewriteRule::applies(Node node) { + return (node.getKind() == kind::BITVECTOR_EXTRACT && + (node[0].getKind() == kind::BITVECTOR_PLUS || + node[0].getKind() == kind::BITVECTOR_MULT)); +} + +template<> +Node RewriteRule::apply(Node node) { + BVDebug("bv-rewrite") << "RewriteRule(" << node << ")" << std::endl; + unsigned low = utils::getExtractLow(node); + unsigned high = utils::getExtractHigh(node); + Node a = utils::mkExtract(node[0][0], high, 0); + Node b = utils::mkExtract(node[0][1], high, 0); + + Kind kind = node[0].getKind(); + Node a_op_b = utils::mkNode(kind, a, b); + + return utils::mkExtract(a_op_b, high, low); +} + + +// template<> +// bool RewriteRule<>::applies(Node node) { +// return (node.getKind() == kind::BITVECTOR_CONCAT); +// } + +// template<> +// Node RewriteRule<>::apply(Node node) { +// BVDebug("bv-rewrite") << "RewriteRule<>(" << node << ")" << std::endl; +// return resultNode; +// } + + + +} +} +} diff --git a/src/theory/bv/theory_bv_rewrite_rules_arith.h b/src/theory/bv/theory_bv_rewrite_rules_operator_elimination.h similarity index 69% rename from src/theory/bv/theory_bv_rewrite_rules_arith.h rename to src/theory/bv/theory_bv_rewrite_rules_operator_elimination.h index 08a571db3..f0460643b 100644 --- a/src/theory/bv/theory_bv_rewrite_rules_arith.h +++ b/src/theory/bv/theory_bv_rewrite_rules_operator_elimination.h @@ -1,9 +1,9 @@ /********************* */ /*! \file theory_bv_rewrite_rules_core.h ** \verbatim - ** Original author: dejan + ** Original author: lianah ** Major contributors: none - ** Minor contributors (to current version): mdeters + ** Minor contributors (to current version): ** This file is part of the CVC4 prototype. ** Copyright (c) 2009, 2010, 2011 The Analysis of Computer Systems Group (ACSys) ** Courant Institute of Mathematical Sciences @@ -29,13 +29,13 @@ namespace theory { namespace bv { template<> -bool RewriteRule::applies(Node node) { +bool RewriteRule::applies(Node node) { return (node.getKind() == kind::BITVECTOR_UGT); } template<> -Node RewriteRule::apply(Node node) { - BVDebug("bv-rewrite") << "RewriteRule(" << node << ")" << std::endl; +Node RewriteRule::apply(Node node) { + BVDebug("bv-rewrite") << "RewriteRule(" << node << ")" << std::endl; TNode a = node[0]; TNode b = node[1]; Node result = utils::mkNode(kind::BITVECTOR_ULT, b, a); @@ -44,13 +44,13 @@ Node RewriteRule::apply(Node node) { template<> -bool RewriteRule::applies(Node node) { +bool RewriteRule::applies(Node node) { return (node.getKind() == kind::BITVECTOR_UGE); } template<> -Node RewriteRule::apply(Node node) { - BVDebug("bv-rewrite") << "RewriteRule(" << node << ")" << std::endl; +Node RewriteRule::apply(Node node) { + BVDebug("bv-rewrite") << "RewriteRule(" << node << ")" << std::endl; TNode a = node[0]; TNode b = node[1]; Node result = utils::mkNode(kind::BITVECTOR_ULE, b, a); @@ -59,13 +59,13 @@ Node RewriteRule::apply(Node node) { template<> -bool RewriteRule::applies(Node node) { +bool RewriteRule::applies(Node node) { return (node.getKind() == kind::BITVECTOR_SGT); } template<> -Node RewriteRule::apply(Node node) { - BVDebug("bv-rewrite") << "RewriteRule(" << node << ")" << std::endl; +Node RewriteRule::apply(Node node) { + BVDebug("bv-rewrite") << "RgewriteRule(" << node << ")" << std::endl; TNode a = node[0]; TNode b = node[1]; Node result = utils::mkNode(kind::BITVECTOR_SLT, b, a); @@ -74,19 +74,85 @@ Node RewriteRule::apply(Node node) { template<> -bool RewriteRule::applies(Node node) { +bool RewriteRule::applies(Node node) { return (node.getKind() == kind::BITVECTOR_SGE); } template<> -Node RewriteRule::apply(Node node) { - BVDebug("bv-rewrite") << "RewriteRule(" << node << ")" << std::endl; +Node RewriteRule::apply(Node node) { + BVDebug("bv-rewrite") << "RewriteRule(" << node << ")" << std::endl; TNode a = node[0]; TNode b = node[1]; Node result = utils::mkNode(kind::BITVECTOR_SLE, b, a); return result; } +template <> +bool RewriteRule::applies(Node node) { + return (node.getKind() == kind::BITVECTOR_SLT); +} + +template <> +Node RewriteRule::apply(Node node) { + BVDebug("bv-rewrite") << "RewriteRule(" << node << ")" << std::endl; + + unsigned size = utils::getSize(node[0]); + Node pow_two = utils::mkConst(BitVector(size, Integer(1).multiplyByPow2(size - 1))); + Node a = utils::mkNode(kind::BITVECTOR_PLUS, node[0], pow_two); + Node b = utils::mkNode(kind::BITVECTOR_PLUS, node[1], pow_two); + + return utils::mkNode(kind::BITVECTOR_ULT, a, b); + +} + +template <> +bool RewriteRule::applies(Node node) { + return (node.getKind() == kind::BITVECTOR_SLE); +} + +template <> +Node RewriteRule::apply(Node node) { + BVDebug("bv-rewrite") << "RewriteRule(" << node << ")" << std::endl; + + unsigned size = utils::getSize(node[0]); + Node pow_two = utils::mkConst(BitVector(size, Integer(1).multiplyByPow2(size - 1))); + Node a = utils::mkNode(kind::BITVECTOR_PLUS, node[0], pow_two); + Node b = utils::mkNode(kind::BITVECTOR_PLUS, node[1], pow_two); + + return utils::mkNode(kind::BITVECTOR_ULE, a, b); + +} + +template <> +bool RewriteRule::applies(Node node) { + return (node.getKind() == kind::BITVECTOR_COMP); +} + +template <> +Node RewriteRule::apply(Node node) { + BVDebug("bv-rewrite") << "RewriteRule(" << node << ")" << std::endl; + Node comp = utils::mkNode(kind::EQUAL, node[0], node[1]); + Node one = utils::mkConst(1, 1); + Node zero = utils::mkConst(1, 0); + + return utils::mkNode(kind::ITE, comp, one, zero); +} + +template <> +bool RewriteRule::applies(Node node) { + return (node.getKind() == kind::BITVECTOR_SUB); +} + +template <> +Node RewriteRule::apply(Node node) { + BVDebug("bv-rewrite") << "RewriteRule(" << node << ")" << std::endl; + Node negb = utils::mkNode(kind::BITVECTOR_NEG, node[1]); + Node a = node[0]; + + return utils::mkNode(kind::BITVECTOR_PLUS, a, negb); +} + + template<> bool RewriteRule::applies(Node node) { return (node.getKind() == kind::BITVECTOR_REPEAT); @@ -183,6 +249,21 @@ Node RewriteRule::apply(Node node) { return result; } +template<> +bool RewriteRule::applies(Node node) { + return (node.getKind() == kind::BITVECTOR_XNOR); +} + +template<> +Node RewriteRule::apply(Node node) { + BVDebug("bv-rewrite") << "RewriteRule(" << node << ")" << std::endl; + TNode a = node[0]; + TNode b = node[1]; + Node xorNode = utils::mkNode(kind::BITVECTOR_XOR, a, b); + return utils::mkNode(kind::BITVECTOR_NOT, xorNode); +} + + template<> bool RewriteRule::applies(Node node) { return (node.getKind() == kind::BITVECTOR_SDIV); @@ -288,32 +369,34 @@ Node RewriteRule::apply(Node node) { TNode bv = node[0]; unsigned amount = node.getOperator().getConst().zeroExtendAmount; - + if (amount == 0) { + return node[0]; + } Node zero = utils::mkConst(amount, 0); Node result = utils::mkConcat(zero, node[0]); return result; } - template<> -bool RewriteRule::applies(Node node) { - return (node.getKind() == kind::BITVECTOR_UDIV || - node.getKind() == kind::BITVECTOR_UREM || - node.getKind() == kind::BITVECTOR_SDIV || - node.getKind() == kind::BITVECTOR_SREM || - node.getKind() == kind::BITVECTOR_SMOD); +bool RewriteRule::applies(Node node) { + return (node.getKind() == kind::BITVECTOR_SIGN_EXTEND); } template<> -Node RewriteRule::apply(Node node) { - BVDebug("bv-rewrite") << "RewriteRule(" << node << ")" << std::endl; - - Unimplemented(); - -} +Node RewriteRule::apply(Node node) { + BVDebug("bv-rewrite") << "RewriteRule(" << node << ")" << std::endl; + unsigned amount = node.getOperator().getConst().signExtendAmount; + if(amount == 0) { + return node[0]; + } + unsigned size = utils::getSize(node[0]); + Node sign_bit = utils::mkExtract(node[0], size-1, size-1); + Node extension = utils::mkConcat(sign_bit, amount); + return utils::mkConcat(extension, node[0]); +} } diff --git a/src/theory/bv/theory_bv_rewrite_rules_simplification.h b/src/theory/bv/theory_bv_rewrite_rules_simplification.h new file mode 100644 index 000000000..3b813d1fa --- /dev/null +++ b/src/theory/bv/theory_bv_rewrite_rules_simplification.h @@ -0,0 +1,918 @@ +/********************* */ +/*! \file theory_bv_rewrite_rules_simplification.h + ** \verbatim + ** Original author: lianah + ** Major contributors: none + ** Minor contributors (to current version): + ** This file is part of the CVC4 prototype. + ** Copyright (c) 2009, 2010, 2011 The Analysis of Computer Systems Group (ACSys) + ** Courant Institute of Mathematical Sciences + ** New York University + ** See the file COPYING in the top-level source directory for licensing + ** information.\endverbatim + ** + ** \brief [[ Add one-line brief description here ]] + ** + ** [[ Add lengthier description here ]] + ** \todo document this file + **/ + +#include "cvc4_private.h" + +#pragma once + +#include "theory/bv/theory_bv_rewrite_rules.h" +#include "theory/bv/theory_bv_utils.h" + +namespace CVC4 { +namespace theory { +namespace bv { + +// FIXME: this rules subsume the constant evaluation ones + + +/** + * ShlByConst + * + * Left Shift by constant amount + */ +template<> +bool RewriteRule::applies(Node node) { + // if the shift amount is constant + return (node.getKind() == kind::BITVECTOR_SHL && + node[1].getKind() == kind::CONST_BITVECTOR); +} + +template<> +Node RewriteRule::apply(Node node) { + BVDebug("bv-rewrite") << "RewriteRule(" << node << ")" << std::endl; + Integer amount = node[1].getConst().toInteger(); + + Node a = node[0]; + uint32_t size = utils::getSize(a); + + + if (amount >= Integer(size)) { + // if we are shifting more than the length of the bitvector return 0 + return utils::mkConst(BitVector(size, Integer(0))); + } + + // make sure we do not lose information casting + Assert(amount < Integer(1).multiplyByPow2(32)); + + uint32_t uint32_amount = amount.toUnsignedInt(); + Node left = utils::mkExtract(a, size - 1 - uint32_amount, 0); + Node right = utils::mkConst(BitVector(uint32_amount, Integer(0))); + return utils::mkConcat(left, right); +} + +/** + * LshrByConst + * + * Right Logical Shift by constant amount + */ + +template<> +bool RewriteRule::applies(Node node) { + // if the shift amount is constant + return (node.getKind() == kind::BITVECTOR_LSHR && + node[1].getKind() == kind::CONST_BITVECTOR); +} + +template<> +Node RewriteRule::apply(Node node) { + BVDebug("bv-rewrite") << "RewriteRule(" << node << ")" << std::endl; + Integer amount = node[1].getConst().toInteger(); + + Node a = node[0]; + uint32_t size = utils::getSize(a); + + + if (amount >= Integer(size)) { + // if we are shifting more than the length of the bitvector return 0 + return utils::mkConst(BitVector(size, Integer(0))); + } + + // make sure we do not lose information casting + Assert(amount < Integer(1).multiplyByPow2(32)); + + uint32_t uint32_amount = amount.toUnsignedInt(); + Node right = utils::mkExtract(a, size - 1, uint32_amount); + Node left = utils::mkConst(BitVector(uint32_amount, Integer(0))); + return utils::mkConcat(left, right); +} + +/** + * AshrByConst + * + * Right Arithmetic Shift by constant amount + */ + +template<> +bool RewriteRule::applies(Node node) { + // if the shift amount is constant + return (node.getKind() == kind::BITVECTOR_ASHR && + node[1].getKind() == kind::CONST_BITVECTOR); +} + +template<> +Node RewriteRule::apply(Node node) { + BVDebug("bv-rewrite") << "RewriteRule(" << node << ")" << std::endl; + Integer amount = node[1].getConst().toInteger(); + + Node a = node[0]; + uint32_t size = utils::getSize(a); + Node sign_bit = utils::mkExtract(a, size-1, size-1); + + if (amount >= Integer(size)) { + // if we are shifting more than the length of the bitvector return n repetitions + // of the first bit + return utils::mkConcat(sign_bit, size); + } + + // make sure we do not lose information casting + Assert(amount < Integer(1).multiplyByPow2(32)); + + uint32_t uint32_amount = amount.toUnsignedInt(); + if (uint32_amount == 0) { + return a; + } + + Node left = utils::mkConcat(sign_bit, uint32_amount); + Node right = utils::mkExtract(a, size - 1, uint32_amount); + return utils::mkConcat(left, right); +} + +/** + * BitwiseIdemp + * + * (a bvand a) ==> a + * (a bvor a) ==> a + */ + +template<> +bool RewriteRule::applies(Node node) { + return ((node.getKind() == kind::BITVECTOR_AND || + node.getKind() == kind::BITVECTOR_OR) && + node[0] == node[1]); +} + +template<> +Node RewriteRule::apply(Node node) { + BVDebug("bv-rewrite") << "RewriteRule(" << node << ")" << std::endl; + return node[0]; +} + +/** + * AndZero + * + * (a bvand 0) ==> 0 + */ + +template<> +bool RewriteRule::applies(Node node) { + unsigned size = utils::getSize(node); + return (node.getKind() == kind::BITVECTOR_AND && + (node[0] == utils::mkConst(size, 0) || + node[1] == utils::mkConst(size, 0))); +} + +template<> +Node RewriteRule::apply(Node node) { + BVDebug("bv-rewrite") << "RewriteRule(" << node << ")" << std::endl; + return utils::mkConst(utils::getSize(node), 0); +} + +/** + * AndOne + * + * (a bvand 1) ==> a + */ + +template<> +bool RewriteRule::applies(Node node) { + unsigned size = utils::getSize(node); + Node ones = utils::mkOnes(size); + return (node.getKind() == kind::BITVECTOR_AND && + (node[0] == ones || + node[1] == ones)); +} + +template<> +Node RewriteRule::apply(Node node) { + BVDebug("bv-rewrite") << "RewriteRule(" << node << ")" << std::endl; + unsigned size = utils::getSize(node); + + if (node[0] == utils::mkOnes(size)) { + return node[1]; + } else { + Assert (node[1] == utils::mkOnes(size)); + return node[0]; + } +} + +/** + * OrZero + * + * (a bvor 0) ==> a + */ + +template<> +bool RewriteRule::applies(Node node) { + unsigned size = utils::getSize(node); + return (node.getKind() == kind::BITVECTOR_OR && + (node[0] == utils::mkConst(size, 0) || + node[1] == utils::mkConst(size, 0))); +} + +template<> +Node RewriteRule::apply(Node node) { + BVDebug("bv-rewrite") << "RewriteRule(" << node << ")" << std::endl; + + unsigned size = utils::getSize(node); + if (node[0] == utils::mkConst(size, 0)) { + return node[1]; + } else { + Assert(node[1] == utils::mkConst(size, 0)); + return node[0]; + } +} + +/** + * OrOne + * + * (a bvor 1) ==> 1 + */ + +template<> +bool RewriteRule::applies(Node node) { + unsigned size = utils::getSize(node); + Node ones = utils::mkOnes(size); + return (node.getKind() == kind::BITVECTOR_OR && + (node[0] == ones || + node[1] == ones)); +} + +template<> +Node RewriteRule::apply(Node node) { + BVDebug("bv-rewrite") << "RewriteRule(" << node << ")" << std::endl; + return utils::mkOnes(utils::getSize(node)); +} + + +/** + * XorDuplicate + * + * (a bvxor a) ==> 0 + */ + +template<> +bool RewriteRule::applies(Node node) { + return (node.getKind() == kind::BITVECTOR_XOR && + node[0] == node[1]); +} + +template<> +Node RewriteRule::apply(Node node) { + BVDebug("bv-rewrite") << "RewriteRule(" << node << ")" << std::endl; + return utils::mkConst(BitVector(utils::getSize(node), Integer(0))); +} + +/** + * XorOne + * + * (a bvxor 1) ==> ~a + */ + +template<> +bool RewriteRule::applies(Node node) { + Node ones = utils::mkOnes(utils::getSize(node)); + return (node.getKind() == kind::BITVECTOR_XOR && + (node[0] == ones || + node[1] == ones)); +} + +template<> +Node RewriteRule::apply(Node node) { + BVDebug("bv-rewrite") << "RewriteRule(" << node << ")" << std::endl; + Node ones = utils::mkOnes(utils::getSize(node)); + Node a; + if (node[0] == ones) { + a = node[1]; + } else { + Assert(node[1] == ones); + a = node[0]; + } + + return utils::mkNode(kind::BITVECTOR_NOT, a); +} + + +/** + * XorZero + * + * (a bvxor 0) ==> a + */ + +template<> +bool RewriteRule::applies(Node node) { + Node zero = utils::mkConst(utils::getSize(node), 0); + return (node.getKind() == kind::BITVECTOR_XOR && + (node[0] == zero || + node[1] == zero)); +} + +template<> +Node RewriteRule::apply(Node node) { + BVDebug("bv-rewrite") << "RewriteRule(" << node << ")" << std::endl; + Node zero = utils::mkConst(utils::getSize(node), 0); + if (node[0] == zero) { + return node[1]; + } + + Assert(node[1] == zero); + return node[0]; +} + + +/** + * BitwiseNotAnd + * + * (a bvand (~ a)) ==> 0 + */ + +template<> +bool RewriteRule::applies(Node node) { + return (node.getKind() == kind::BITVECTOR_AND && + ((node[0].getKind() == kind::BITVECTOR_NOT && node[0][0] == node[1]) || + (node[1].getKind() == kind::BITVECTOR_NOT && node[1][0] == node[0]))); +} + +template<> +Node RewriteRule::apply(Node node) { + BVDebug("bv-rewrite") << "RewriteRule(" << node << ")" << std::endl; + return utils::mkConst(BitVector(utils::getSize(node), Integer(0))); +} + +/** + * BitwiseNegOr + * + * (a bvor (~ a)) ==> 1 + */ + +template<> +bool RewriteRule::applies(Node node) { + return (node.getKind() == kind::BITVECTOR_OR && + ((node[0].getKind() == kind::BITVECTOR_NOT && node[0][0] == node[1]) || + (node[1].getKind() == kind::BITVECTOR_NOT && node[1][0] == node[0]))); +} + +template<> +Node RewriteRule::apply(Node node) { + BVDebug("bv-rewrite") << "RewriteRule(" << node << ")" << std::endl; + uint32_t size = utils::getSize(node); + Integer ones = Integer(1).multiplyByPow2(size) - 1; + return utils::mkConst(BitVector(size, ones)); +} + +/** + * XorNot + * + * ((~ a) bvxor (~ b)) ==> (a bvxor b) + */ + +template<> +bool RewriteRule::applies(Node node) { + return (node.getKind() == kind::BITVECTOR_XOR && + node[0].getKind() == kind::BITVECTOR_NOT && + node[1].getKind() == kind::BITVECTOR_NOT); +} + +template<> +Node RewriteRule::apply(Node node) { + BVDebug("bv-rewrite") << "RewriteRule(" << node << ")" << std::endl; + Node a = node[0][0]; + Node b = node[1][0]; + return utils::mkNode(kind::BITVECTOR_XOR, a, b); +} + +/** + * NotXor + * + * ~(a bvxor b) ==> (~ a bvxor b) + */ + +template<> +bool RewriteRule::applies(Node node) { + return (node.getKind() == kind::BITVECTOR_NOT && + node[0].getKind() == kind::BITVECTOR_XOR); +} + +template<> +Node RewriteRule::apply(Node node) { + BVDebug("bv-rewrite") << "RewriteRule(" << node << ")" << std::endl; + Node a = node[0][0]; + Node b = node[0][1]; + Node nota = utils::mkNode(kind::BITVECTOR_NOT, a); + return utils::mkNode(kind::BITVECTOR_XOR, nota, b); +} + +/** + * NotIdemp + * + * ~ (~ a) ==> a + */ + +template<> +bool RewriteRule::applies(Node node) { + return (node.getKind() == kind::BITVECTOR_NOT && + node[0].getKind() == kind::BITVECTOR_NOT); +} + +template<> +Node RewriteRule::apply(Node node) { + BVDebug("bv-rewrite") << "RewriteRule(" << node << ")" << std::endl; + return node[0][0]; +} + + + +/** + * LtSelf + * + * a < a ==> false + */ + +template<> +bool RewriteRule::applies(Node node) { + return ((node.getKind() == kind::BITVECTOR_ULT || + node.getKind() == kind::BITVECTOR_SLT) && + node[0] == node[1]); +} + +template<> +Node RewriteRule::apply(Node node) { + BVDebug("bv-rewrite") << "RewriteRule(" << node << ")" << std::endl; + return utils::mkFalse(); +} + +/** + * LteSelf + * + * a <= a ==> true + */ + +template<> +bool RewriteRule::applies(Node node) { + return ((node.getKind() == kind::BITVECTOR_ULE || + node.getKind() == kind::BITVECTOR_SLE) && + node[0] == node[1]); +} + +template<> +Node RewriteRule::apply(Node node) { + BVDebug("bv-rewrite") << "RewriteRule(" << node << ")" << std::endl; + return utils::mkTrue(); +} + +/** + * UltZero + * + * a < 0 ==> false + */ + +template<> +bool RewriteRule::applies(Node node) { + return (node.getKind() == kind::BITVECTOR_ULT && + node[1] == utils::mkConst(BitVector(utils::getSize(node[0]), Integer(0)))); +} + +template<> +Node RewriteRule::apply(Node node) { + BVDebug("bv-rewrite") << "RewriteRule(" << node << ")" << std::endl; + return utils::mkFalse(); +} + +/** + * UltSelf + * + * a < a ==> false + */ + +template<> +bool RewriteRule::applies(Node node) { + return (node.getKind() == kind::BITVECTOR_ULT && + node[1] == node[0]); +} + +template<> +Node RewriteRule::apply(Node node) { + BVDebug("bv-rewrite") << "RewriteRule(" << node << ")" << std::endl; + return utils::mkFalse(); +} + + +/** + * UleZero + * + * a <= 0 ==> a = 0 + */ + +template<> +bool RewriteRule::applies(Node node) { + return (node.getKind() == kind::BITVECTOR_ULE && + node[1] == utils::mkConst(BitVector(utils::getSize(node[0]), Integer(0)))); +} + +template<> +Node RewriteRule::apply(Node node) { + BVDebug("bv-rewrite") << "RewriteRule(" << node << ")" << std::endl; + return utils::mkNode(kind::EQUAL, node[0], node[1]); +} + +/** + * UleSelf + * + * a <= a ==> true + */ + +template<> +bool RewriteRule::applies(Node node) { + return (node.getKind() == kind::BITVECTOR_ULE && + node[1] == node[0]); +} + +template<> +Node RewriteRule::apply(Node node) { + BVDebug("bv-rewrite") << "RewriteRule(" << node << ")" << std::endl; + return utils::mkTrue(); +} + + +/** + * ZeroUle + * + * 0 <= a ==> true + */ + +template<> +bool RewriteRule::applies(Node node) { + return (node.getKind() == kind::BITVECTOR_ULE && + node[1] == utils::mkConst(BitVector(utils::getSize(node[0]), Integer(0)))); +} + +template<> +Node RewriteRule::apply(Node node) { + BVDebug("bv-rewrite") << "RewriteRule(" << node << ")" << std::endl; + return utils::mkTrue(); +} + +/** + * UleMax + * + * a <= 11..1 ==> true + */ + +template<> +bool RewriteRule::applies(Node node) { + if (node.getKind()!= kind::BITVECTOR_ULE) { + return false; + } + uint32_t size = utils::getSize(node[0]); + Integer ones = Integer(1).multiplyByPow2(size) -1; + return (node.getKind() == kind::BITVECTOR_ULE && + node[1] == utils::mkConst(BitVector(size, ones))); +} + +template<> +Node RewriteRule::apply(Node node) { + BVDebug("bv-rewrite") << "RewriteRule(" << node << ")" << std::endl; + return utils::mkTrue(); +} + +/** + * NotUlt + * + * ~ ( a < b) ==> b <= a + */ + +template<> +bool RewriteRule::applies(Node node) { + return (node.getKind() == kind::NOT && + node[0].getKind() == kind::BITVECTOR_ULT); +} + +template<> +Node RewriteRule::apply(Node node) { + BVDebug("bv-rewrite") << "RewriteRule(" << node << ")" << std::endl; + Node ult = node[0]; + Node a = ult[0]; + Node b = ult[1]; + return utils::mkNode(kind::BITVECTOR_ULE, b, a); +} + +/** + * NotUle + * + * ~ ( a <= b) ==> b < a + */ + +template<> +bool RewriteRule::applies(Node node) { + return (node.getKind() == kind::NOT && + node[0].getKind() == kind::BITVECTOR_ULE); +} + +template<> +Node RewriteRule::apply(Node node) { + BVDebug("bv-rewrite") << "RewriteRule(" << node << ")" << std::endl; + Node ult = node[0]; + Node a = ult[0]; + Node b = ult[1]; + return utils::mkNode(kind::BITVECTOR_ULT, b, a); +} + +/** + * MultOne + * + * (a * 1) ==> a + */ + +template<> +bool RewriteRule::applies(Node node) { + unsigned size = utils::getSize(node); + return (node.getKind() == kind::BITVECTOR_MULT && + (node[0] == utils::mkConst(size, 1) || + node[1] == utils::mkConst(size, 1))); +} + +template<> +Node RewriteRule::apply(Node node) { + BVDebug("bv-rewrite") << "RewriteRule(" << node << ")" << std::endl; + unsigned size = utils::getSize(node); + if (node[0] == utils::mkConst(size, 1)) { + return node[1]; + } + Assert(node[1] == utils::mkConst(size, 1)); + return node[0]; +} + +/** + * MultZero + * + * (a * 0) ==> 0 + */ + +template<> +bool RewriteRule::applies(Node node) { + unsigned size = utils::getSize(node); + return (node.getKind() == kind::BITVECTOR_MULT && + (node[0] == utils::mkConst(size, 0) || + node[1] == utils::mkConst(size, 0))); +} + +template<> +Node RewriteRule::apply(Node node) { + BVDebug("bv-rewrite") << "RewriteRule(" << node << ")" << std::endl; + return utils::mkConst(utils::getSize(node), 0); +} + +/** + * MultPow2 + * + * (a * 2^k) ==> a[n-k-1:0] 0_k + */ + +template<> +bool RewriteRule::applies(Node node) { + return (node.getKind() == kind::BITVECTOR_MULT && + (utils::isPow2Const(node[0]) || + utils::isPow2Const(node[1]))); +} + +template<> +Node RewriteRule::apply(Node node) { + BVDebug("bv-rewrite") << "RewriteRule(" << node << ")" << std::endl; + Node a; + unsigned power; + power = utils::isPow2Const(node[0]); + + if (power != 0) { + a = node[1]; + // isPow2Const returns the power + 1 + --power; + } else { + power = utils::isPow2Const(node[1]); + Assert(power != 0); + a = node[0]; + power--; + } + + Node extract = utils::mkExtract(a, utils::getSize(node) - power - 1, 0); + Node zeros = utils::mkConst(power, 0); + return utils::mkConcat(extract, zeros); +} + +/** + * PlusZero + * + * (a + 0) ==> a + */ + +template<> +bool RewriteRule::applies(Node node) { + Node zero = utils::mkConst(utils::getSize(node), 0); + return (node.getKind() == kind::BITVECTOR_PLUS && + (node[0] == zero || + node[1] == zero)); +} + +template<> +Node RewriteRule::apply(Node node) { + BVDebug("bv-rewrite") << "RewriteRule(" << node << ")" << std::endl; + Node zero = utils::mkConst(utils::getSize(node), 0); + if (node[0] == zero) { + return node[1]; + } + + return node[0]; +} + +/** + * NegIdemp + * + * -(-a) ==> a + */ + +template<> +bool RewriteRule::applies(Node node) { + return (node.getKind() == kind::BITVECTOR_NEG && + node[0].getKind() == kind::BITVECTOR_NEG); +} + +template<> +Node RewriteRule::apply(Node node) { + BVDebug("bv-rewrite") << "RewriteRule(" << node << ")" << std::endl; + return node[0][0]; +} + +/** + * UdivPow2 + * + * (a udiv 2^k) ==> 0_k a[n-1: k] + */ + +template<> +bool RewriteRule::applies(Node node) { + return (node.getKind() == kind::BITVECTOR_UDIV && + utils::isPow2Const(node[1])); +} + +template<> +Node RewriteRule::apply(Node node) { + BVDebug("bv-rewrite") << "RewriteRule(" << node << ")" << std::endl; + Node a = node[0]; + unsigned power = utils::isPow2Const(node[1]) -1; + + Node extract = utils::mkExtract(a, utils::getSize(node) - 1, power); + Node zeros = utils::mkConst(power, 0); + + return utils::mkNode(kind::BITVECTOR_CONCAT, zeros, extract); +} + +/** + * UdivOne + * + * (a udiv 1) ==> a + */ + +template<> +bool RewriteRule::applies(Node node) { + return (node.getKind() == kind::BITVECTOR_UDIV && + node[1] == utils::mkConst(utils::getSize(node), 1)); +} + +template<> +Node RewriteRule::apply(Node node) { + BVDebug("bv-rewrite") << "RewriteRule(" << node << ")" << std::endl; + return node[0]; +} + +/** + * UdivSelf + * + * (a udiv a) ==> 1 + */ + +template<> +bool RewriteRule::applies(Node node) { + return (node.getKind() == kind::BITVECTOR_UDIV && + node[0] == node[1]); +} + +template<> +Node RewriteRule::apply(Node node) { + BVDebug("bv-rewrite") << "RewriteRule(" << node << ")" << std::endl; + return utils::mkConst(utils::getSize(node), 1); +} + +/** + * UremPow2 + * + * (a urem 2^k) ==> 0_(n-k) a[k-1:0] + */ + +template<> +bool RewriteRule::applies(Node node) { + return (node.getKind() == kind::BITVECTOR_UREM && + utils::isPow2Const(node[1])); +} + +template<> +Node RewriteRule::apply(Node node) { + BVDebug("bv-rewrite") << "RewriteRule(" << node << ")" << std::endl; + TNode a = node[0]; + unsigned power = utils::isPow2Const(node[1]) - 1; + Node extract = utils::mkExtract(a, power - 1, 0); + Node zeros = utils::mkConst(utils::getSize(node) - power, 0); + return utils::mkNode(kind::BITVECTOR_CONCAT, zeros, extract); +} + +/** + * UremOne + * + * (a urem 1) ==> 0 + */ + +template<> +bool RewriteRule::applies(Node node) { + return (node.getKind() == kind::BITVECTOR_UREM && + node[1] == utils::mkConst(utils::getSize(node), 1)); +} + +template<> +Node RewriteRule::apply(Node node) { + BVDebug("bv-rewrite") << "RewriteRule(" << node << ")" << std::endl; + return utils::mkConst(utils::getSize(node), 0); +} + +/** + * UremSelf + * + * (a urem a) ==> 0 + */ + +template<> +bool RewriteRule::applies(Node node) { + return (node.getKind() == kind::BITVECTOR_UREM && + node[0] == node[1]); +} + +template<> +Node RewriteRule::apply(Node node) { + BVDebug("bv-rewrite") << "RewriteRule(" << node << ")" << std::endl; + return utils::mkConst(utils::getSize(node), 0); +} + +/** + * ShiftZero + * + * (0_k >> a) ==> 0_k + */ + +template<> +bool RewriteRule::applies(Node node) { + return ((node.getKind() == kind::BITVECTOR_SHL || + node.getKind() == kind::BITVECTOR_LSHR || + node.getKind() == kind::BITVECTOR_ASHR) && + node[0] == utils::mkConst(utils::getSize(node), 0)); +} + +template<> +Node RewriteRule::apply(Node node) { + BVDebug("bv-rewrite") << "RewriteRule(" << node << ")" << std::endl; + return node[0]; +} + +// /** +// * +// * +// * +// */ + +// template<> +// bool RewriteRule<>::applies(Node node) { +// return (node.getKind() == ); +// } + +// template<> +// Node RewriteRule<>::apply(Node node) { +// BVDebug("bv-rewrite") << "RewriteRule<>(" << node << ")" << std::endl; +// return ; +// } + + + +} +} +} diff --git a/src/theory/bv/theory_bv_rewriter.cpp b/src/theory/bv/theory_bv_rewriter.cpp index 2f3538837..586f37073 100644 --- a/src/theory/bv/theory_bv_rewriter.cpp +++ b/src/theory/bv/theory_bv_rewriter.cpp @@ -21,143 +21,506 @@ #include "theory/bv/theory_bv_rewriter.h" #include "theory/bv/theory_bv_rewrite_rules.h" #include "theory/bv/theory_bv_rewrite_rules_core.h" -#include "theory/bv/theory_bv_rewrite_rules_arith.h" +#include "theory/bv/theory_bv_rewrite_rules_operator_elimination.h" +#include "theory/bv/theory_bv_rewrite_rules_constant_evaluation.h" +#include "theory/bv/theory_bv_rewrite_rules_simplification.h" +#include "theory/bv/theory_bv_rewrite_rules_normalization.h" using namespace CVC4; using namespace CVC4::theory; using namespace CVC4::theory::bv; + +// CVC4_THREADLOCAL(AllRewriteRules*) TheoryBVRewriter::s_allRules = NULL; +// CVC4_THREADLOCAL(TimerStat*) TheoryBVRewriter::d_rewriteTimer = NULL; +CVC4_THREADLOCAL(RewriteFunction) TheoryBVRewriter::d_rewriteTable[kind::LAST_KIND]; +void TheoryBVRewriter::init() { + // s_allRules = new AllRewriteRules; + // d_rewriteTimer = new TimerStat("theory::bv::rewriteTimer"); + // StatisticsRegistry::registerStat(d_rewriteTimer); + initializeRewrites(); + +} + +void TheoryBVRewriter::shutdown() { + // delete s_allRules; + // StatisticsRegistry::unregisterStat(d_rewriteTimer); + //delete d_rewriteTimer; +} + +RewriteResponse TheoryBVRewriter::preRewrite(TNode node) { + Debug("bitvector-rewrite") << "TheoryBV::preRewrite(" << node << ")" << std::endl; + //return d_rewriteTable[node.getKind()](node); + return RewriteResponse(REWRITE_DONE, node); +} + RewriteResponse TheoryBVRewriter::postRewrite(TNode node) { + //TimerStat::CodeTimer codeTimer(*d_rewriteTimer); + Debug("bitvector-rewrite") << "TheoryBV::postRewrite(" << node << ")" << std::endl; + RewriteResponse res = d_rewriteTable[node.getKind()](node); + return res; +} + +RewriteResponse TheoryBVRewriter::RewriteUlt(TNode node) { + Node resultNode = LinearRewriteStrategy + < RewriteRule, + // if both arguments are constants evaluates + RewriteRule + // a < 0 rewrites to false + >::apply(node); - BVDebug("bitvector") << "TheoryBV::postRewrite(" << node << ")" << std::endl; - - Node result = node; - if (node.getKind() == kind::CONST_BITVECTOR || (node.getKind() != kind::EQUAL && Theory::isLeafOf(node, THEORY_BV))) { - result = node; - } else { - switch (node.getKind()) { - case kind::BITVECTOR_CONCAT: - result = LinearRewriteStrategy< - // Flatten the top level concatenations - RewriteRule, - // Merge the adjacent extracts on non-constants - RewriteRule, - // Merge the adjacent extracts on constants - RewriteRule, - // At this point only Extract-Whole could apply, if the result is only one extract - // or at some sub-expression if the result is a concatenation. - ApplyRuleToChildren - >::apply(node); - break; - case kind::BITVECTOR_EXTRACT: - result = LinearRewriteStrategy< - // Extract over a concatenation is distributed to the appropriate concatenations - RewriteRule, - // Extract over a constant gives a constant - RewriteRule, - // We could get another extract over extract - RewriteRule, - // At this point only Extract-Whole could apply - RewriteRule - >::apply(node); - break; - case kind::EQUAL: - result = LinearRewriteStrategy< - // Two distinct values rewrite to false - RewriteRule, - // If both sides are equal equality is true - RewriteRule, - // Eliminate the equalities - RewriteRule - >::apply(node); - break; - case kind::BITVECTOR_UGT: - result = LinearRewriteStrategy < - RewriteRule - >::apply(node); - break; - - case kind::BITVECTOR_UGE: - result = LinearRewriteStrategy < - RewriteRule - >::apply(node); - break; - case kind::BITVECTOR_SGT: - result = LinearRewriteStrategy < - RewriteRule - >::apply(node); - break; - case kind::BITVECTOR_SGE: - result = LinearRewriteStrategy < - RewriteRule - >::apply(node); - break; - case kind::BITVECTOR_REPEAT: - result = LinearRewriteStrategy < - RewriteRule - >::apply(node); - break; - case kind::BITVECTOR_ROTATE_RIGHT: - result = LinearRewriteStrategy < - RewriteRule - >::apply(node); - break; - case kind::BITVECTOR_ROTATE_LEFT: - result = LinearRewriteStrategy < - RewriteRule - >::apply(node); - break; - case kind::BITVECTOR_NAND: - result = LinearRewriteStrategy < - RewriteRule - >::apply(node); - break; - case kind::BITVECTOR_NOR: - result = LinearRewriteStrategy < - RewriteRule - >::apply(node); - break; - - case kind::BITVECTOR_SDIV: - result = LinearRewriteStrategy < - RewriteRule - >::apply(node); - break; - case kind::BITVECTOR_SREM: - result = LinearRewriteStrategy < - RewriteRule - >::apply(node); - break; - case kind::BITVECTOR_SMOD: - result = LinearRewriteStrategy < - RewriteRule - >::apply(node); - break; - case kind::BITVECTOR_ZERO_EXTEND: - result = LinearRewriteStrategy < - RewriteRule - >::apply(node); - break; - - default: - // TODO: figure out when this is an operator - result = node; - break; - // Unhandled(); - } + return RewriteResponse(REWRITE_DONE, resultNode); +} + +RewriteResponse TheoryBVRewriter::RewriteSlt(TNode node){ + return RewriteResponse(REWRITE_DONE, node); + // Node resultNode = LinearRewriteStrategy + // < RewriteRule < SltEliminate > + // // a <_s b ==> a + 2^{n-1} <_u b + 2^{n-1} + // >::apply(node); + + // return RewriteResponse(REWRITE_AGAIN_FULL, resultNode); +} + +RewriteResponse TheoryBVRewriter::RewriteUle(TNode node){ + Node resultNode = LinearRewriteStrategy + < RewriteRule, + RewriteRule, + RewriteRule, + RewriteRule, + RewriteRule + >::apply(node); + return RewriteResponse(REWRITE_DONE, resultNode); +} + +RewriteResponse TheoryBVRewriter::RewriteSle(TNode node){ + return RewriteResponse(REWRITE_DONE, node); + // Node resultNode = LinearRewriteStrategy + // < RewriteRule + // >::apply(node); + + // return RewriteResponse(REWRITE_AGAIN_FULL, resultNode); +} + +RewriteResponse TheoryBVRewriter::RewriteUgt(TNode node){ + Node resultNode = LinearRewriteStrategy + < RewriteRule + >::apply(node); + + return RewriteResponse(REWRITE_AGAIN, resultNode); +} + +RewriteResponse TheoryBVRewriter::RewriteSgt(TNode node){ + Node resultNode = LinearRewriteStrategy + < RewriteRule + //RewriteRule + >::apply(node); + + return RewriteResponse(REWRITE_AGAIN_FULL, resultNode); +} + +RewriteResponse TheoryBVRewriter::RewriteUge(TNode node){ + Node resultNode = LinearRewriteStrategy + < RewriteRule + >::apply(node); + + return RewriteResponse(REWRITE_AGAIN, resultNode); +} + +RewriteResponse TheoryBVRewriter::RewriteSge(TNode node){ + Node resultNode = LinearRewriteStrategy + < RewriteRule + // RewriteRule + >::apply(node); + + return RewriteResponse(REWRITE_AGAIN_FULL, resultNode); +} + +RewriteResponse TheoryBVRewriter::RewriteNot(TNode node){ + Node resultNode = node; + if(RewriteRule::applies(node)) { + resultNode = RewriteRule::run(node); + return RewriteResponse(REWRITE_AGAIN_FULL, resultNode); } + resultNode = LinearRewriteStrategy + < RewriteRule, + RewriteRule + >::apply(node); + + return RewriteResponse(REWRITE_DONE, resultNode); +} - BVDebug("bitvector") << "TheoryBV::postRewrite(" << node << ") => " << result << std::endl; +RewriteResponse TheoryBVRewriter::RewriteExtract(TNode node) { + Node resultNode = node; + if (RewriteRule::applies(node)) { + resultNode = RewriteRule::run(node); + return RewriteResponse(REWRITE_AGAIN_FULL, resultNode); + } - return RewriteResponse(REWRITE_DONE, result); + if (RewriteRule::applies(node)) { + resultNode = RewriteRule::run(node); + return RewriteResponse(REWRITE_AGAIN_FULL, resultNode); + } + + if (RewriteRule::applies(node)) { + resultNode = RewriteRule::run(node); + return RewriteResponse(REWRITE_AGAIN_FULL, resultNode); + } + + if (RewriteRule::applies(node)) { + resultNode = RewriteRule::run(node); + return RewriteResponse(REWRITE_AGAIN_FULL, resultNode); + } + + resultNode = LinearRewriteStrategy + < RewriteRule, + RewriteRule, + // We could get another extract over extract + RewriteRule + // At this point only Extract-Whole could apply + >::apply(node); + + return RewriteResponse(REWRITE_DONE, resultNode); } -CVC4_THREADLOCAL(AllRewriteRules*) TheoryBVRewriter::s_allRules = NULL; -void TheoryBVRewriter::init() { - s_allRules = new AllRewriteRules; +RewriteResponse TheoryBVRewriter::RewriteConcat(TNode node) { + Node resultNode = LinearRewriteStrategy + < RewriteRule, + // Flatten the top level concatenations + RewriteRule, + // Merge the adjacent extracts on non-constants + RewriteRule, + // Merge the adjacent extracts on constants + ApplyRuleToChildren + >::apply(node); + return RewriteResponse(REWRITE_DONE, resultNode); } -void TheoryBVRewriter::shutdown() { - delete s_allRules; +RewriteResponse TheoryBVRewriter::RewriteAnd(TNode node){ + Node resultNode = LinearRewriteStrategy + < RewriteRule, + RewriteRule, + RewriteRule, + RewriteRule + >::apply(node); + + return RewriteResponse(REWRITE_DONE, resultNode); +} + +RewriteResponse TheoryBVRewriter::RewriteOr(TNode node){ + Node resultNode = LinearRewriteStrategy + < RewriteRule, + RewriteRule, + RewriteRule, + RewriteRule + >::apply(node); + + return RewriteResponse(REWRITE_DONE, resultNode); +} + +RewriteResponse TheoryBVRewriter::RewriteXor(TNode node) { + Node resultNode = node; + if (RewriteRule::applies(node)) { + resultNode = RewriteRule::run (node); + return RewriteResponse(REWRITE_AGAIN_FULL, resultNode); + } + resultNode = LinearRewriteStrategy + < RewriteRule, + RewriteRule, + RewriteRule, + RewriteRule + >::apply(node); + + return RewriteResponse(REWRITE_DONE, resultNode); +} + +RewriteResponse TheoryBVRewriter::RewriteXnor(TNode node) { + Node resultNode = LinearRewriteStrategy + < RewriteRule + >::apply(node); + // need to rewrite two levels in + return RewriteResponse(REWRITE_AGAIN_FULL, resultNode); } + +RewriteResponse TheoryBVRewriter::RewriteNand(TNode node) { + Node resultNode = LinearRewriteStrategy + < RewriteRule + >::apply(node); + + return RewriteResponse(REWRITE_AGAIN_FULL, resultNode); +} + +RewriteResponse TheoryBVRewriter::RewriteNor(TNode node) { + Node resultNode = LinearRewriteStrategy + < RewriteRule + >::apply(node); + + return RewriteResponse(REWRITE_AGAIN_FULL, resultNode); +} + +RewriteResponse TheoryBVRewriter::RewriteComp(TNode node) { + Node resultNode = LinearRewriteStrategy + < RewriteRule + >::apply(node); + + return RewriteResponse(REWRITE_AGAIN_FULL, resultNode); +} + +RewriteResponse TheoryBVRewriter::RewriteMult(TNode node) { + Node resultNode = node; + if(RewriteRule::applies(node)) { + resultNode = RewriteRule::run (node); + return RewriteResponse(REWRITE_AGAIN_FULL, resultNode); + } + + resultNode = LinearRewriteStrategy + < RewriteRule, + RewriteRule, + RewriteRule + >::apply(node); + + return RewriteResponse(REWRITE_DONE, resultNode); +} + +RewriteResponse TheoryBVRewriter::RewritePlus(TNode node) { + Node resultNode = LinearRewriteStrategy + < RewriteRule, + RewriteRule + // RewriteRule, + // RewriteRule + // RewriteRule + >::apply(node); + + return RewriteResponse(REWRITE_DONE, resultNode); +} + +RewriteResponse TheoryBVRewriter::RewriteSub(TNode node){ + return RewriteResponse(REWRITE_DONE, node); + // Node resultNode = LinearRewriteStrategy + // < RewriteRule + // >::apply(node); + + // return RewriteResponse(REWRITE_AGAIN_FULL, resultNode); +} + +RewriteResponse TheoryBVRewriter::RewriteNeg(TNode node) { + Node resultNode = LinearRewriteStrategy + < RewriteRule, + RewriteRule + >::apply(node); + + return RewriteResponse(REWRITE_DONE, resultNode); +} + +RewriteResponse TheoryBVRewriter::RewriteUdiv(TNode node){ + Node resultNode = node; + if(RewriteRule::applies(node)) { + resultNode = RewriteRule::run (node); + return RewriteResponse(REWRITE_AGAIN_FULL, resultNode); + } + + resultNode = LinearRewriteStrategy + < RewriteRule, + RewriteRule, + RewriteRule + >::apply(node); + + return RewriteResponse(REWRITE_DONE, resultNode); +} + +RewriteResponse TheoryBVRewriter::RewriteUrem(TNode node) { + Node resultNode = node; + if(RewriteRule::applies(node)) { + resultNode = RewriteRule::run (node); + return RewriteResponse(REWRITE_AGAIN_FULL, resultNode); + } + + resultNode = LinearRewriteStrategy + < RewriteRule, + RewriteRule, + RewriteRule + >::apply(node); + return RewriteResponse(REWRITE_DONE, resultNode); +} + +RewriteResponse TheoryBVRewriter::RewriteSmod(TNode node) { + Node resultNode = LinearRewriteStrategy + < RewriteRule + >::apply(node); + + return RewriteResponse(REWRITE_AGAIN_FULL, resultNode); +} + +RewriteResponse TheoryBVRewriter::RewriteSdiv(TNode node) { + Node resultNode = LinearRewriteStrategy + < RewriteRule + >::apply(node); + + return RewriteResponse(REWRITE_AGAIN_FULL, resultNode); +} + +RewriteResponse TheoryBVRewriter::RewriteSrem(TNode node) { + Node resultNode = LinearRewriteStrategy + < RewriteRule + >::apply(node); + return RewriteResponse(REWRITE_AGAIN_FULL, resultNode); +} + +RewriteResponse TheoryBVRewriter::RewriteShl(TNode node) { + Node resultNode = node; + if(RewriteRule::applies(node)) { + resultNode = RewriteRule::run (node); + return RewriteResponse(REWRITE_AGAIN_FULL, resultNode); + } + + resultNode = LinearRewriteStrategy + < RewriteRule, + RewriteRule + >::apply(node); + + return RewriteResponse(REWRITE_DONE, resultNode); +} + +RewriteResponse TheoryBVRewriter::RewriteLshr(TNode node) { + Node resultNode = node; + if(RewriteRule::applies(node)) { + resultNode = RewriteRule::run (node); + return RewriteResponse(REWRITE_AGAIN_FULL, resultNode); + } + + resultNode = LinearRewriteStrategy + < RewriteRule, + RewriteRule + >::apply(node); + + return RewriteResponse(REWRITE_DONE, resultNode); +} + +RewriteResponse TheoryBVRewriter::RewriteAshr(TNode node) { + Node resultNode = node; + if(RewriteRule::applies(node)) { + resultNode = RewriteRule::run (node); + return RewriteResponse(REWRITE_AGAIN_FULL, resultNode); + } + + resultNode = LinearRewriteStrategy + < RewriteRule, + RewriteRule + >::apply(node); + + return RewriteResponse(REWRITE_DONE, resultNode); +} + + +RewriteResponse TheoryBVRewriter::RewriteRepeat(TNode node) { + Node resultNode = LinearRewriteStrategy + < RewriteRule + >::apply(node); + + return RewriteResponse(REWRITE_AGAIN_FULL, resultNode); +} + +RewriteResponse TheoryBVRewriter::RewriteZeroExtend(TNode node){ + Node resultNode = LinearRewriteStrategy + < RewriteRule + >::apply(node); + + return RewriteResponse(REWRITE_AGAIN, resultNode); +} + +RewriteResponse TheoryBVRewriter::RewriteSignExtend(TNode node) { + // Node resultNode = LinearRewriteStrategy + // < RewriteRule + // >::apply(node); + + // return RewriteResponse(REWRITE_AGAIN_FULL, resultNode); + return RewriteResponse(REWRITE_DONE, node); +} + + +RewriteResponse TheoryBVRewriter::RewriteRotateRight(TNode node) { + Node resultNode = LinearRewriteStrategy + < RewriteRule + >::apply(node); + + return RewriteResponse(REWRITE_AGAIN_FULL, resultNode); +} + +RewriteResponse TheoryBVRewriter::RewriteRotateLeft(TNode node){ + Node resultNode = LinearRewriteStrategy + < RewriteRule + >::apply(node); + + return RewriteResponse(REWRITE_AGAIN_FULL, resultNode); +} + +RewriteResponse TheoryBVRewriter::RewriteEqual(TNode node) { + Node resultNode = LinearRewriteStrategy + < RewriteRule, + RewriteRule, + RewriteRule + >::apply(node); + + return RewriteResponse(REWRITE_DONE, resultNode); +} + + +RewriteResponse TheoryBVRewriter::IdentityRewrite(TNode node) { + return RewriteResponse(REWRITE_DONE, node); +} + +RewriteResponse TheoryBVRewriter::UndefinedRewrite(TNode node) { + Debug("bv-rewrite") << "TheoryBV::UndefinedRewrite for" << node; + Unimplemented(); +} + + + +void TheoryBVRewriter::initializeRewrites() { + + for(unsigned i = 0; i < kind::LAST_KIND; ++i) { + d_rewriteTable[i] = IdentityRewrite; //UndefinedRewrite; + } + + d_rewriteTable [ kind::EQUAL ] = RewriteEqual; + d_rewriteTable [ kind::BITVECTOR_ULT ] = RewriteUlt; + d_rewriteTable [ kind::BITVECTOR_SLT ] = RewriteSlt; + d_rewriteTable [ kind::BITVECTOR_ULE ] = RewriteUle; + d_rewriteTable [ kind::BITVECTOR_SLE ] = RewriteSle; + d_rewriteTable [ kind::BITVECTOR_UGT ] = RewriteUgt; + d_rewriteTable [ kind::BITVECTOR_SGT ] = RewriteSgt; + d_rewriteTable [ kind::BITVECTOR_UGE ] = RewriteUge; + d_rewriteTable [ kind::BITVECTOR_SGE ] = RewriteSge; + d_rewriteTable [ kind::BITVECTOR_NOT ] = RewriteNot; + d_rewriteTable [ kind::BITVECTOR_CONCAT ] = RewriteConcat; + d_rewriteTable [ kind::BITVECTOR_AND ] = RewriteAnd; + d_rewriteTable [ kind::BITVECTOR_OR ] = RewriteOr; + d_rewriteTable [ kind::BITVECTOR_XOR] = RewriteXor; + d_rewriteTable [ kind::BITVECTOR_XNOR ] = RewriteXnor; + d_rewriteTable [ kind::BITVECTOR_NAND ] = RewriteNand; + d_rewriteTable [ kind::BITVECTOR_NOR ] = RewriteNor; + d_rewriteTable [ kind::BITVECTOR_COMP ] = RewriteComp; + d_rewriteTable [ kind::BITVECTOR_MULT ] = RewriteMult; + d_rewriteTable [ kind::BITVECTOR_PLUS ] = RewritePlus; + d_rewriteTable [ kind::BITVECTOR_SUB ] = RewriteSub; + d_rewriteTable [ kind::BITVECTOR_NEG ] = RewriteNeg; + d_rewriteTable [ kind::BITVECTOR_UDIV ] = RewriteUdiv; + d_rewriteTable [ kind::BITVECTOR_UREM ] = RewriteUrem; + d_rewriteTable [ kind::BITVECTOR_SMOD ] = RewriteSmod; + d_rewriteTable [ kind::BITVECTOR_SDIV ] = RewriteSdiv; + d_rewriteTable [ kind::BITVECTOR_SREM ] = RewriteSrem; + d_rewriteTable [ kind::BITVECTOR_SHL ] = RewriteShl; + d_rewriteTable [ kind::BITVECTOR_LSHR ] = RewriteLshr; + d_rewriteTable [ kind::BITVECTOR_ASHR ] = RewriteAshr; + d_rewriteTable [ kind::BITVECTOR_EXTRACT ] = RewriteExtract; + d_rewriteTable [ kind::BITVECTOR_REPEAT ] = RewriteRepeat; + d_rewriteTable [ kind::BITVECTOR_ZERO_EXTEND ] = RewriteZeroExtend; + d_rewriteTable [ kind::BITVECTOR_SIGN_EXTEND ] = RewriteSignExtend; + d_rewriteTable [ kind::BITVECTOR_ROTATE_RIGHT ] = RewriteRotateRight; + d_rewriteTable [ kind::BITVECTOR_ROTATE_LEFT ] = RewriteRotateLeft; +} + + + + diff --git a/src/theory/bv/theory_bv_rewriter.h b/src/theory/bv/theory_bv_rewriter.h index 1893977a8..e72c1a910 100644 --- a/src/theory/bv/theory_bv_rewriter.h +++ b/src/theory/bv/theory_bv_rewriter.h @@ -23,31 +23,76 @@ #define __CVC4__THEORY__BV__THEORY_BV_REWRITER_H #include "theory/rewriter.h" +#include "util/stats.h" namespace CVC4 { namespace theory { namespace bv { struct AllRewriteRules; +typedef RewriteResponse (*RewriteFunction) (TNode); class TheoryBVRewriter { + // static CVC4_THREADLOCAL(AllRewriteRules*) s_allRules; + // static CVC4_THREADLOCAL(TimerStat*) d_rewriteTimer; + static CVC4_THREADLOCAL(RewriteFunction) d_rewriteTable[kind::LAST_KIND]; - static CVC4_THREADLOCAL(AllRewriteRules*) s_allRules; + static RewriteResponse IdentityRewrite(TNode node); + static RewriteResponse UndefinedRewrite(TNode node); + + static void initializeRewrites(); + + static RewriteResponse RewriteEqual(TNode node); + static RewriteResponse RewriteUlt(TNode node); + static RewriteResponse RewriteSlt(TNode node); + static RewriteResponse RewriteUle(TNode node); + static RewriteResponse RewriteSle(TNode node); + static RewriteResponse RewriteUgt(TNode node); + static RewriteResponse RewriteSgt(TNode node); + static RewriteResponse RewriteUge(TNode node); + static RewriteResponse RewriteSge(TNode node); + static RewriteResponse RewriteNot(TNode node); + static RewriteResponse RewriteConcat(TNode node); + static RewriteResponse RewriteAnd(TNode node); + static RewriteResponse RewriteOr(TNode node); + static RewriteResponse RewriteXnor(TNode node); + static RewriteResponse RewriteXor(TNode node); + static RewriteResponse RewriteNand(TNode node); + static RewriteResponse RewriteNor(TNode node); + static RewriteResponse RewriteComp(TNode node); + static RewriteResponse RewriteMult(TNode node); + static RewriteResponse RewritePlus(TNode node); + static RewriteResponse RewriteSub(TNode node); + static RewriteResponse RewriteNeg(TNode node); + static RewriteResponse RewriteUdiv(TNode node); + static RewriteResponse RewriteUrem(TNode node); + static RewriteResponse RewriteSmod(TNode node); + static RewriteResponse RewriteSdiv(TNode node); + static RewriteResponse RewriteSrem(TNode node); + static RewriteResponse RewriteShl(TNode node); + static RewriteResponse RewriteLshr(TNode node); + static RewriteResponse RewriteAshr(TNode node); + static RewriteResponse RewriteExtract(TNode node); + static RewriteResponse RewriteRepeat(TNode node); + static RewriteResponse RewriteZeroExtend(TNode node); + static RewriteResponse RewriteSignExtend(TNode node); + static RewriteResponse RewriteRotateRight(TNode node); + static RewriteResponse RewriteRotateLeft(TNode node); public: static RewriteResponse postRewrite(TNode node); - static inline RewriteResponse preRewrite(TNode node) { - return postRewrite(node); - } + static RewriteResponse preRewrite(TNode node); static inline Node rewriteEquality(TNode node) { + Debug("bitvector") << "TheoryBV::rewriteEquality(" << node << ")" << std::endl; return postRewrite(node).node; } static void init(); static void shutdown(); + };/* class TheoryBVRewriter */ }/* CVC4::theory::bv namespace */ diff --git a/src/theory/bv/theory_bv_utils.h b/src/theory/bv/theory_bv_utils.h index 6bd7368c9..60e459958 100644 --- a/src/theory/bv/theory_bv_utils.h +++ b/src/theory/bv/theory_bv_utils.h @@ -32,6 +32,8 @@ #define BVDebug(x) if (false) Debug(x) #endif + + namespace CVC4 { namespace theory { namespace bv { @@ -129,6 +131,20 @@ inline Node mkBitOf(TNode node, unsigned index) { } + +inline Node mkConcat(Node node, unsigned repeat) { + Assert (repeat); + if(repeat == 1) { + return node; + } + NodeBuilder<> result(kind::BITVECTOR_CONCAT); + for (unsigned i = 0; i < repeat; ++i) { + result << node; + } + Node resultNode = result; + return resultNode; +} + inline Node mkConcat(std::vector& children) { if (children.size() > 1) return NodeManager::currentNM()->mkNode(kind::BITVECTOR_CONCAT, children); @@ -140,6 +156,10 @@ inline Node mkConcat(TNode t1, TNode t2) { return NodeManager::currentNM()->mkNode(kind::BITVECTOR_CONCAT, t1, t2); } +inline Node mkOnes(unsigned size) { + BitVector val = BitVector(1, Integer(1)).signExtend(size-1); + return NodeManager::currentNM()->mkConst(val); +} inline Node mkConst(unsigned size, unsigned int value) { BitVector val(size, value); @@ -197,6 +217,30 @@ inline Node mkConjunction(const std::set nodes) { return conjunction; } + +inline unsigned isPow2Const(TNode node) { + if (node.getKind() != kind::CONST_BITVECTOR) { + return false; + } + + BitVector bv = node.getConst(); + return bv.isPow2(); +} + +// neeed a better name, this is not technically a ground term +inline bool isBVGroundTerm(TNode node) { + if (node.getNumChildren() == 0) { + return node.isConst(); + } + + for (size_t i = 0; i < node.getNumChildren(); ++i) { + if(! node[i].isConst()) { + return false; + } + } + return true; +} + inline bool isBVPredicate(TNode node) { if (node.getKind() == kind::EQUAL || node.getKind() == kind::BITVECTOR_ULT || diff --git a/src/util/bitvector.h b/src/util/bitvector.h index 8adb466cf..7429ac8c9 100644 --- a/src/util/bitvector.h +++ b/src/util/bitvector.h @@ -32,30 +32,55 @@ class CVC4_PUBLIC BitVector { private: + /* + Class invariants: + * no overflows: 2^d_size < d_value + * no negative numbers: d_value >= 0 + */ unsigned d_size; Integer d_value; + Integer toSignedInt() const { + // returns Integer corresponding to two's complement interpretation of bv + unsigned size = d_size; + Integer sign_bit = d_value.extractBitRange(1,size-1); + Integer val = d_value.extractBitRange(size-1, 0); + Integer res = Integer(-1) * sign_bit.multiplyByPow2(size - 1) + val; + return res; + } + + public: - BitVector(unsigned size, const Integer& val) - : d_size(size), d_value(val) {} - + BitVector(unsigned size, const Integer& val): + d_size(size), + d_value(val.modByPow2(size)) + {} + BitVector(unsigned size = 0) - : d_size(size), d_value(0) {} + : d_size(size), d_value(0) {} BitVector(unsigned size, unsigned int z) - : d_size(size), d_value(z) {} - + : d_size(size), d_value(z) { + d_value = d_value.modByPow2(size); + } + BitVector(unsigned size, unsigned long int z) - : d_size(size), d_value(z) {} + : d_size(size), d_value(z) { + d_value = d_value.modByPow2(size); + } BitVector(unsigned size, const BitVector& q) - : d_size(size), d_value(q.d_value) {} - + : d_size(size), d_value(q.d_value) {} + BitVector(const std::string& num, unsigned base = 2); ~BitVector() {} + Integer toInteger() const { + return d_value; + } + BitVector& operator =(const BitVector& x) { if(this == &x) return *this; @@ -65,46 +90,227 @@ public: } bool operator ==(const BitVector& y) const { - if (d_size != y.d_size) return false; + if (d_size != y.d_size) return false; return d_value == y.d_value; } bool operator !=(const BitVector& y) const { - if (d_size == y.d_size) return false; + if (d_size != y.d_size) return true; return d_value != y.d_value; } + BitVector equals(const BitVector& y) const { + Assert(d_size == y.d_size); + return d_value == y.d_value; + } + + BitVector concat (const BitVector& other) const { + return BitVector(d_size + other.d_size, (d_value.multiplyByPow2(other.d_size)) + other.d_value); + } + + BitVector extract(unsigned high, unsigned low) const { + return BitVector(high - low + 1, d_value.extractBitRange(high - low + 1, low)); + } + + /* + Bitwise operations on BitVectors + */ + + // xor + BitVector operator ^(const BitVector& y) const { + Assert (d_size == y.d_size); + return BitVector(d_size, d_value.bitwiseXor(y.d_value)); + } + + // or + BitVector operator |(const BitVector& y) const { + Assert (d_size == y.d_size); + return BitVector(d_size, d_value.bitwiseOr(y.d_value)); + } + + // and + BitVector operator &(const BitVector& y) const { + Assert (d_size == y.d_size); + return BitVector(d_size, d_value.bitwiseAnd(y.d_value)); + } + + // not + BitVector operator ~() const { + return BitVector(d_size, d_value.bitwiseNot()); + } + + /* + Arithmetic operations on BitVectors + */ + + + bool operator <(const BitVector& y) const { + return d_value < y.d_value; + } + + bool operator >(const BitVector& y) const { + return d_value > y.d_value ; + } + + bool operator <=(const BitVector& y) const { + return d_value <= y.d_value; + } + + bool operator >=(const BitVector& y) const { + return d_value >= y.d_value ; + } + + BitVector operator +(const BitVector& y) const { - return BitVector(std::max(d_size, y.d_size), d_value + y.d_value); + Assert (d_size == y.d_size); + Integer sum = d_value + y.d_value; + return BitVector(d_size, sum); } BitVector operator -(const BitVector& y) const { - return *this + ~y + 1; + Assert (d_size == y.d_size); + // to maintain the invariant that we are only adding BitVectors of the + // same size + BitVector one(d_size, Integer(1)); + return *this + ~y + one; } BitVector operator -() const { - return ~(*this) + 1; + BitVector one(d_size, Integer(1)); + return ~(*this) + one; } BitVector operator *(const BitVector& y) const { - return BitVector(d_size, d_value * y.d_value); + Assert (d_size == y.d_size); + Integer prod = d_value * y.d_value; + return BitVector(d_size, prod); + } + + BitVector unsignedDiv (const BitVector& y) const { + Assert (d_size == y.d_size); + Assert (d_value >= 0 && y.d_value > 0); + return BitVector(d_size, d_value.floorDivideQuotient(y.d_value)); } - BitVector operator ~() const { - //is this right? it looks like a no-op? - return BitVector(d_size, d_value); + BitVector unsignedRem(const BitVector& y) const { + Assert (d_size == y.d_size); + Assert (d_value >= 0 && y.d_value > 0); + return BitVector(d_size, d_value.floorDivideRemainder(y.d_value)); + } + + + bool signedLessThan(const BitVector& y) const { + Assert(d_size == y.d_size); + Assert(d_value >= 0 && y.d_value >= 0); + Integer a = (*this).toSignedInt(); + Integer b = y.toSignedInt(); + + return a < b; } - BitVector concat (const BitVector& other) const { - return BitVector(d_size + other.d_size, (d_value.multiplyByPow2(other.d_size)) + other.d_value); - //return BitVector(d_size + other.d_size, (d_value * Integer(2).pow(other.d_size)) + other.d_value); + bool unsignedLessThan(const BitVector& y) const { + Assert(d_size == y.d_size); + Assert(d_value >= 0 && y.d_value >= 0); + return d_value < y.d_value; } - BitVector extract(unsigned high, unsigned low) const { - return BitVector(high - low + 1, d_value.extractBitRange(high - low + 1, low)); - //return BitVector(high - low + 1, (d_value % (Integer(2).pow(high + 1))) / Integer(2).pow(low)); + bool signedLessThanEq(const BitVector& y) const { + Assert(d_size == y.d_size); + Assert(d_value >= 0 && y.d_value >= 0); + Integer a = (*this).toSignedInt(); + Integer b = y.toSignedInt(); + + return a <= b; + } + + bool unsignedLessThanEq(const BitVector& y) const { + Assert(d_size == y.d_size); + Assert(d_value >= 0 && y.d_value >= 0); + return d_value <= y.d_value; + } + + + /* + Extend operations + */ + + BitVector zeroExtend(unsigned amount) const { + return BitVector(d_size + amount, d_value); + } + + BitVector signExtend(unsigned amount) const { + Integer sign_bit = d_value.extractBitRange(1, d_size -1); + if(sign_bit == Integer(0)) { + return BitVector(d_size + amount, d_value); + } else { + Integer val = d_value.oneExtend(d_size, amount); + return BitVector(d_size+ amount, val); + } + } + + /* + Shifts on BitVectors + */ + + BitVector leftShift(const BitVector& y) const { + if (y.d_value > Integer(d_size)) { + return BitVector(d_size, Integer(0)); + } + if (y.d_value == 0) { + return *this; + } + + // making sure we don't lose information casting + Assert(y.d_value < Integer(1).multiplyByPow2(32)); + uint32_t amount = y.d_value.toUnsignedInt(); + Integer res = d_value.multiplyByPow2(amount); + return BitVector(d_size, res); } + BitVector logicalRightShift(const BitVector& y) const { + if(y.d_value > Integer(d_size)) { + return BitVector(d_size, Integer(0)); + } + + // making sure we don't lose information casting + Assert(y.d_value < Integer(1).multiplyByPow2(32)); + uint32_t amount = y.d_value.toUnsignedInt(); + Integer res = d_value.divByPow2(amount); + return BitVector(d_size, res); + } + + BitVector arithRightShift(const BitVector& y) const { + Integer sign_bit = d_value.extractBitRange(1, d_size - 1); + if(y.d_value > Integer(d_size)) { + if(sign_bit == Integer(0)) { + return BitVector(d_size, Integer(0)); + } else { + return BitVector(d_size, Integer(d_size).multiplyByPow2(d_size) -1 ); + } + } + + if (y.d_value == 0) { + return *this; + } + + // making sure we don't lose information casting + Assert(y.d_value < Integer(1).multiplyByPow2(32)); + + uint32_t amount = y.d_value.toUnsignedInt(); + Integer rest = d_value.divByPow2(amount); + + if(sign_bit == Integer(0)) { + return BitVector(d_size, rest); + } + Integer res = rest.oneExtend(d_size - amount, amount); + return BitVector(d_size, res); + } + + + /* + Convenience functions + */ + size_t hash() const { return d_value.hash() + d_size; } @@ -129,32 +335,27 @@ public: const Integer& getValue() const { return d_value; } + + /** + Returns k is the integer is equal to 2^k and zero + otherwise + @return k if the integer is equal to 2^k and zero otherwise + */ + unsigned isPow2() { + return d_value.isPow2(); + } + };/* class BitVector */ + + inline BitVector::BitVector(const std::string& num, unsigned base) { AlwaysAssert( base == 2 || base == 16 ); if( base == 2 ) { d_size = num.size(); -// d_value = Integer(num,2); -/* - for( string::const_iterator it = num.begin(); it != num.end(); ++it ) { - if( *it != '0' || *it != '1' ) { - IllegalArgument(num, "BitVector argument is not a binary string."); - } - z = (Integer(2) * z) + (*it == '1'); - d_value = mpz_class(z.get_mpz()); - } -*/ } else if( base == 16 ) { d_size = num.size() * 4; -// // Use a stream to decode the hex string -// stringstream ss; -// ss.setf(ios::hex, ios::basefield); -// ss << num; -// ss >> z; -// d_value = mpz_class(z); -// break; } else { Unreachable("Unsupported base in BitVector(std::string&, unsigned int)."); } diff --git a/src/util/integer_cln_imp.h b/src/util/integer_cln_imp.h index 06459e3e1..9d67e8fba 100644 --- a/src/util/integer_cln_imp.h +++ b/src/util/integer_cln_imp.h @@ -185,6 +185,24 @@ public: } */ + + Integer bitwiseOr(const Integer& y) const { + return Integer(cln::logior(d_value, y.d_value)); + } + + Integer bitwiseAnd(const Integer& y) const { + return Integer(cln::logand(d_value, y.d_value)); + } + + Integer bitwiseXor(const Integer& y) const { + return Integer(cln::logxor(d_value, y.d_value)); + } + + Integer bitwiseNot() const { + return Integer(cln::lognot(d_value)); + } + + /** * Return this*(2^pow). */ @@ -193,6 +211,20 @@ public: return Integer( d_value << ipow); } + Integer oneExtend(uint32_t size, uint32_t amount) const { + Assert((*this) < Integer(1).multiplyByPow2(size)); + cln::cl_byte range(amount, size); + cln::cl_I allones = (cln::cl_I(1) << (size + amount))- 1; // 2^size - 1 + Integer temp(allones); + + return Integer(cln::deposit_field(allones, d_value, range)); + } + + uint32_t toUnsignedInt() const { + return cln::cl_I_to_uint(d_value); + } + + /** See CLN Documentation. */ Integer extractBitRange(uint32_t bitCount, uint32_t low) const { cln::cl_byte range(bitCount, low); @@ -243,6 +275,15 @@ public: return Integer( cln::exquo(d_value, y.d_value) ); } + Integer modByPow2(uint32_t exp) const { + cln::cl_byte range(exp, 0); + return Integer(cln::ldb(d_value, range)); + } + + Integer divByPow2(uint32_t exp) const { + return d_value >> exp; + } + /** * Raise this Integer to the power exp. * @@ -363,6 +404,16 @@ public: return cln::logbitp(n, d_value); } + /** + * Returns k if the integer is equal to 2^(k-1) + * @return k if the integer is equal to 2^(k-1) and 0 otherwise + */ + unsigned isPow2() const { + if (d_value <= 0) return 0; + // power2p returns n such that d_value = 2^(n-1) + return cln::power2p(d_value); + } + /** * If x != 0, returns the unique n s.t. 2^{n-1} <= abs(x) < 2^{n}. * If x == 0, returns 1. diff --git a/src/util/integer_gmp_imp.h b/src/util/integer_gmp_imp.h index a02f5d2c1..ceb585852 100644 --- a/src/util/integer_gmp_imp.h +++ b/src/util/integer_gmp_imp.h @@ -135,6 +135,30 @@ public: return *this; } + Integer bitwiseOr(const Integer& y) const { + mpz_class result; + mpz_ior(result.get_mpz_t(), d_value.get_mpz_t(), y.d_value.get_mpz_t()); + return Integer(result); + } + + Integer bitwiseAnd(const Integer& y) const { + mpz_class result; + mpz_and(result.get_mpz_t(), d_value.get_mpz_t(), y.d_value.get_mpz_t()); + return Integer(result); + } + + Integer bitwiseXor(const Integer& y) const { + mpz_class result; + mpz_xor(result.get_mpz_t(), d_value.get_mpz_t(), y.d_value.get_mpz_t()); + return Integer(result); + } + + Integer bitwiseNot() const { + mpz_class result; + mpz_com(result.get_mpz_t(), d_value.get_mpz_t()); + return Integer(result); + } + /** * Return this*(2^pow). */ @@ -144,6 +168,26 @@ public: return Integer( result ); } + /** + * Returns the integer with the binary representation of size bits + * extended with amount 1's + */ + Integer oneExtend(uint32_t size, uint32_t amount) const { + // check that the size is accurate + Assert ((*this) < Integer(1).multiplyByPow2(size)); + mpz_class res = d_value; + + for (unsigned i = size; i < size + amount; ++i) { + mpz_setbit(res.get_mpz_t(), i); + } + + return Integer(res); + } + + uint32_t toUnsignedInt() const { + return mpz_get_ui(d_value.get_mpz_t()); + } + /** See GMP Documentation. */ Integer extractBitRange(uint32_t bitCount, uint32_t low) const { // bitCount = high-low+1 @@ -175,7 +219,7 @@ public: } /** - * Computes a floor quoient and remainder for x divided by y. + * Computes a floor quotient and remainder for x divided by y. */ static void floorQR(Integer& q, Integer& r, const Integer& x, const Integer& y) { mpz_fdiv_qr(q.d_value.get_mpz_t(), r.d_value.get_mpz_t(), x.d_value.get_mpz_t(), y.d_value.get_mpz_t()); @@ -209,6 +253,25 @@ public: return Integer( q ); } + /** + * Returns y mod 2^exp + */ + Integer modByPow2(uint32_t exp) const { + mpz_class res; + mpz_fdiv_r_2exp(res.get_mpz_t(), d_value.get_mpz_t(), exp); + return Integer(res); + } + + /** + * Returns y / 2^exp + */ + Integer divByPow2(uint32_t exp) const { + mpz_class res; + mpz_fdiv_q_2exp(res.get_mpz_t(), d_value.get_mpz_t(), exp); + return Integer(res); + } + + int sgn() const { return mpz_sgn(d_value.get_mpz_t()); } @@ -268,14 +331,14 @@ public: long si = d_value.get_si(); // ensure there wasn't overflow AlwaysAssert(mpz_cmp_si(d_value.get_mpz_t(), si) == 0, - "Overflow when extracting long from multiprecision integer"); + "Overflow detected in Integer::getLong()"); return si; } unsigned long getUnsignedLong() const { unsigned long ui = d_value.get_ui(); // ensure there wasn't overflow AlwaysAssert(mpz_cmp_ui(d_value.get_mpz_t(), ui) == 0, - "Overflow when extracting unsigned long from multiprecision integer"); + "Overflow detected in Integer::getUnsignedLong()"); return ui; } @@ -297,6 +360,21 @@ public: return mpz_tstbit(d_value.get_mpz_t(), n); } + /** + * Returns k if the integer is equal to 2^(k-1) + * @return k if the integer is equal to 2^(k-1) and 0 otherwise + */ + unsigned isPow2() const { + if (d_value <= 0) return 0; + // check that the number of ones in the binary represenation is 1 + if (mpz_popcount(d_value.get_mpz_t()) == 1) { + // return the index of the first one plus 1 + return mpz_scan1(d_value.get_mpz_t(), 0) + 1; + } + return 0; + } + + /** * If x != 0, returns the smallest n s.t. 2^{n-1} <= abs(x) < 2^{n}. * If x == 0, returns 1. -- 2.30.2