From: Andrew Reynolds Date: Tue, 6 Jul 2021 22:44:09 +0000 (-0500) Subject: Integrate learned rewrite preprocessing pass (#6840) X-Git-Tag: cvc5-1.0.0~1522 X-Git-Url: https://git.libre-soc.org/?a=commitdiff_plain;h=4ac6c5179265ef9895bc9e939be0e47b3754137e;p=cvc5.git Integrate learned rewrite preprocessing pass (#6840) This adds the learned rewrite preprocessing pass, which rewrites the input formula based on (typically theory specific) reasoning about learned literals. The main motivation is for preprocessing ints division/modulus based on bounds. --- diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 74db7c941..3246df654 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -88,6 +88,8 @@ libcvc5_add_sources( preprocessing/passes/ite_removal.h preprocessing/passes/ite_simp.cpp preprocessing/passes/ite_simp.h + preprocessing/passes/learned_rewrite.cpp + preprocessing/passes/learned_rewrite.h preprocessing/passes/miplib_trick.cpp preprocessing/passes/miplib_trick.h preprocessing/passes/nl_ext_purify.cpp diff --git a/src/options/smt_options.toml b/src/options/smt_options.toml index 4d08aa672..9b5a93486 100644 --- a/src/options/smt_options.toml +++ b/src/options/smt_options.toml @@ -47,6 +47,14 @@ name = "SMT Layer" default = "true" help = "use static learning (on by default)" +[[option]] + name = "learnedRewrite" + category = "regular" + long = "learned-rewrite" + type = "bool" + default = "false" + help = "rewrite the input based on learned literals" + [[option]] name = "expandDefinitions" long = "expand-definitions" diff --git a/src/preprocessing/passes/learned_rewrite.cpp b/src/preprocessing/passes/learned_rewrite.cpp index 4e9aa7bb2..fd3cf832b 100644 --- a/src/preprocessing/passes/learned_rewrite.cpp +++ b/src/preprocessing/passes/learned_rewrite.cpp @@ -60,6 +60,7 @@ LearnedRewrite::LearnedRewrite(PreprocessingPassContext* preprocContext) PreprocessingPassResult LearnedRewrite::applyInternal( AssertionPipeline* assertionsToPreprocess) { + NodeManager* nm = NodeManager::currentNM(); arith::BoundInference binfer; std::vector learnedLits = d_preprocContext->getLearnedLiterals(); std::unordered_set llrw; @@ -72,14 +73,29 @@ PreprocessingPassResult LearnedRewrite::applyInternal( else { Trace("learned-rewrite-ll") << "Learned literals:" << std::endl; + std::map originLit; for (const Node& l : learnedLits) { // maybe use the literal for bound inference? - Kind k = l.getKind(); - Assert(k != LT && k != GT && k != LEQ); - if (k == EQUAL || k == GEQ) + bool pol = l.getKind()!=NOT; + TNode atom = pol ? l : l[0]; + Kind ak = atom.getKind(); + Assert(ak != LT && ak != GT && ak != LEQ); + if ((ak == EQUAL && pol) || ak == GEQ) { - binfer.add(l); + // provide as < if negated >= + Node atomu; + if (!pol) + { + atomu = nm->mkNode(LT, atom[0], atom[1]); + originLit[atomu] = l; + } + else + { + atomu = l; + originLit[l] = l; + } + binfer.add(atomu); } Trace("learned-rewrite-ll") << "- " << l << std::endl; } @@ -93,7 +109,8 @@ PreprocessingPassResult LearnedRewrite::applyInternal( Node origin = i == 0 ? b.second.lower_origin : b.second.upper_origin; if (!origin.isNull()) { - llrw.insert(origin); + Assert (originLit.find(origin)!=originLit.end()); + llrw.insert(originLit[origin]); } } } @@ -139,7 +156,6 @@ PreprocessingPassResult LearnedRewrite::applyInternal( // unchanged. if (!llrw.empty()) { - NodeManager* nm = NodeManager::currentNM(); std::vector llrvec(llrw.begin(), llrw.end()); Node llc = nm->mkAnd(llrvec); Trace("learned-rewrite-assert") @@ -165,7 +181,13 @@ Node LearnedRewrite::rewriteLearnedRec(Node n, cur = visit.back(); visit.pop_back(); it = visited.find(cur); - + if (lems.find(cur) != lems.end()) + { + // n is a learned literal: replace by true, not considered a rewrite + // for statistics + visited[cur] = nm->mkConst(true); + continue; + } if (it == visited.end()) { // mark pre-visited with null; will post-visit to construct final node @@ -210,12 +232,6 @@ Node LearnedRewrite::rewriteLearned(Node n, std::unordered_set& lems) { NodeManager* nm = NodeManager::currentNM(); - if (lems.find(n) != lems.end()) - { - // n is a learned literal: replace by true, not considered a rewrite - // for statistics - return nm->mkConst(true); - } Trace("learned-rewrite-rr-debug") << "Rewrite " << n << std::endl; Node nr = Rewriter::rewrite(n); Kind k = nr.getKind(); @@ -384,63 +400,6 @@ Node LearnedRewrite::rewriteLearned(Node n, return nr; } } - // inferences based on combining div terms - Node currDen; - Node currNum; - std::vector sum; - size_t divCount = 0; - bool divTotal = true; - for (const std::pair& m : msum) - { - if (m.first.isNull()) - { - sum.push_back(m.second); - continue; - } - Kind mk = m.first.getKind(); - if (mk == INTS_DIVISION || mk == INTS_DIVISION_TOTAL) - { - Node factor = ArithMSum::mkCoeffTerm(m.second, m.first[0]); - divTotal = divTotal && mk == INTS_DIVISION_TOTAL; - divCount++; - if (currDen.isNull()) - { - currNum = factor; - currDen = m.first[1]; - } - else - { - factor = nm->mkNode(MULT, factor, currDen); - currNum = nm->mkNode(MULT, currNum, m.first[1]); - currNum = nm->mkNode(PLUS, currNum, factor); - currDen = nm->mkNode(MULT, currDen, m.first[1]); - } - } - else - { - Node factor = ArithMSum::mkCoeffTerm(m.second, m.first); - sum.push_back(factor); - } - } - if (divCount >= 2) - { - SkolemManager* sm = nm->getSkolemManager(); - Node r = sm->mkDummySkolem("r", nm->integerType()); - Node d = nm->mkNode( - divTotal ? INTS_DIVISION_TOTAL : INTS_DIVISION, currNum, currDen); - sum.push_back(d); - sum.push_back(r); - Node bound = - nm->mkNode(AND, - nm->mkNode(LEQ, nm->mkConst(-Rational(divCount - 1)), r), - nm->mkNode(LEQ, r, nm->mkConst(Rational(divCount - 1)))); - Node sumn = nm->mkNode(PLUS, sum); - Node lit = nm->mkNode(k, sumn, nm->mkConst(Rational(0))); - Node lemma = nm->mkNode(IMPLIES, nr, nm->mkNode(AND, lit, bound)); - Trace("learned-rewrite-div") - << "Div collect lemma: " << lemma << std::endl; - lems.insert(lemma); - } } } return nr; diff --git a/src/preprocessing/preprocessing_pass_registry.cpp b/src/preprocessing/preprocessing_pass_registry.cpp index 6f846dc74..f0bd5af86 100644 --- a/src/preprocessing/preprocessing_pass_registry.cpp +++ b/src/preprocessing/preprocessing_pass_registry.cpp @@ -41,6 +41,7 @@ #include "preprocessing/passes/int_to_bv.h" #include "preprocessing/passes/ite_removal.h" #include "preprocessing/passes/ite_simp.h" +#include "preprocessing/passes/learned_rewrite.h" #include "preprocessing/passes/miplib_trick.h" #include "preprocessing/passes/nl_ext_purify.h" #include "preprocessing/passes/non_clausal_simp.h" @@ -126,6 +127,7 @@ PreprocessingPassRegistry::PreprocessingPassRegistry() registerPassInfo("global-negate", callCtor); registerPassInfo("int-to-bv", callCtor); registerPassInfo("bv-to-int", callCtor); + registerPassInfo("learned-rewrite", callCtor); registerPassInfo("foreign-theory-rewrite", callCtor); registerPassInfo("synth-rr", callCtor); registerPassInfo("real-to-int", callCtor); diff --git a/src/smt/process_assertions.cpp b/src/smt/process_assertions.cpp index cf747c360..a9426d5bd 100644 --- a/src/smt/process_assertions.cpp +++ b/src/smt/process_assertions.cpp @@ -288,6 +288,11 @@ bool ProcessAssertions::apply(Assertions& as) } Debug("smt") << " assertions : " << assertions.size() << endl; + if (options::learnedRewrite()) + { + d_passes["learned-rewrite"]->apply(&assertions); + } + if (options::earlyIteRemoval()) { d_smtStats.d_numAssertionsPre += assertions.size(); diff --git a/src/smt/set_defaults.cpp b/src/smt/set_defaults.cpp index ee3701d51..229fdeec5 100644 --- a/src/smt/set_defaults.cpp +++ b/src/smt/set_defaults.cpp @@ -489,15 +489,27 @@ void setDefaults(LogicInfo& logic, bool isInternalSubsolver) opts.smt.simplificationMode = options::SimplificationMode::NONE; } + if (options::learnedRewrite()) + { + if (opts.smt.learnedRewriteWasSetByUser) + { + throw OptionException( + "learned rewrites not supported with unsat cores"); + } + Notice() << "SmtEngine: turning off learned rewrites to support " + "unsat cores\n"; + opts.smt.learnedRewrite = false; + } + if (options::pbRewrites()) { if (opts.arith.pbRewritesWasSetByUser) { throw OptionException( - "pseudoboolean rewrites not supported with old unsat cores"); + "pseudoboolean rewrites not supported with unsat cores"); } Notice() << "SmtEngine: turning off pseudoboolean rewrites to support " - "old unsat cores\n"; + "unsat cores\n"; opts.arith.pbRewrites = false; } @@ -505,10 +517,9 @@ void setDefaults(LogicInfo& logic, bool isInternalSubsolver) { if (opts.smt.sortInferenceWasSetByUser) { - throw OptionException( - "sort inference not supported with old unsat cores"); + throw OptionException("sort inference not supported with unsat cores"); } - Notice() << "SmtEngine: turning off sort inference to support old unsat " + Notice() << "SmtEngine: turning off sort inference to support unsat " "cores\n"; opts.smt.sortInference = false; } @@ -518,9 +529,9 @@ void setDefaults(LogicInfo& logic, bool isInternalSubsolver) if (opts.quantifiers.preSkolemQuantWasSetByUser) { throw OptionException( - "pre-skolemization not supported with old unsat cores"); + "pre-skolemization not supported with unsat cores"); } - Notice() << "SmtEngine: turning off pre-skolemization to support old " + Notice() << "SmtEngine: turning off pre-skolemization to support " "unsat cores\n"; opts.quantifiers.preSkolemQuant = false; } @@ -529,9 +540,9 @@ void setDefaults(LogicInfo& logic, bool isInternalSubsolver) { if (opts.bv.bitvectorToBoolWasSetByUser) { - throw OptionException("bv-to-bool not supported with old unsat cores"); + throw OptionException("bv-to-bool not supported with unsat cores"); } - Notice() << "SmtEngine: turning off bitvector-to-bool to support old " + Notice() << "SmtEngine: turning off bitvector-to-bool to support " "unsat cores\n"; opts.bv.bitvectorToBool = false; } @@ -541,10 +552,9 @@ void setDefaults(LogicInfo& logic, bool isInternalSubsolver) if (opts.bv.boolToBitvectorWasSetByUser) { throw OptionException( - "bool-to-bv != off not supported with old unsat cores"); + "bool-to-bv != off not supported with unsat cores"); } - Notice() - << "SmtEngine: turning off bool-to-bv to support old unsat cores\n"; + Notice() << "SmtEngine: turning off bool-to-bv to support unsat cores\n"; opts.bv.boolToBitvector = options::BoolToBVMode::OFF; } @@ -552,11 +562,9 @@ void setDefaults(LogicInfo& logic, bool isInternalSubsolver) { if (opts.bv.bvIntroducePow2WasSetByUser) { - throw OptionException( - "bv-intro-pow2 not supported with old unsat cores"); + throw OptionException("bv-intro-pow2 not supported with unsat cores"); } - Notice() - << "SmtEngine: turning off bv-intro-pow2 to support old unsat cores"; + Notice() << "SmtEngine: turning off bv-intro-pow2 to support unsat cores"; opts.bv.bvIntroducePow2 = false; } @@ -564,10 +572,9 @@ void setDefaults(LogicInfo& logic, bool isInternalSubsolver) { if (opts.smt.repeatSimpWasSetByUser) { - throw OptionException("repeat-simp not supported with old unsat cores"); + throw OptionException("repeat-simp not supported with unsat cores"); } - Notice() - << "SmtEngine: turning off repeat-simp to support old unsat cores\n"; + Notice() << "SmtEngine: turning off repeat-simp to support unsat cores\n"; opts.smt.repeatSimp = false; } @@ -575,22 +582,21 @@ void setDefaults(LogicInfo& logic, bool isInternalSubsolver) { if (opts.quantifiers.globalNegateWasSetByUser) { - throw OptionException( - "global-negate not supported with old unsat cores"); + throw OptionException("global-negate not supported with unsat cores"); } - Notice() << "SmtEngine: turning off global-negate to support old unsat " + Notice() << "SmtEngine: turning off global-negate to support unsat " "cores\n"; opts.quantifiers.globalNegate = false; } if (options::bitvectorAig()) { - throw OptionException("bitblast-aig not supported with old unsat cores"); + throw OptionException("bitblast-aig not supported with unsat cores"); } if (options::doITESimp()) { - throw OptionException("ITE simp not supported with old unsat cores"); + throw OptionException("ITE simp not supported with unsat cores"); } } else