From b52dc978f2445c6765b806119d238ca81cb8fe90 Mon Sep 17 00:00:00 2001 From: Gereon Kremer Date: Sat, 26 Sep 2020 19:11:13 +0200 Subject: [PATCH] Use inference manager for nl_solver (#5125) This PR migrates the nl_solver part of the nonlinear extension to use the new inference manager as well. It adds a new method clearWaitingLemmas to the inference manager and uses ArithLemma (though NlLemma exists) as we don't need the additional functionality of NlLemma. --- src/theory/arith/inference_manager.cpp | 4 + src/theory/arith/inference_manager.h | 5 ++ src/theory/arith/nl/nl_solver.cpp | 94 ++++++++------------- src/theory/arith/nl/nl_solver.h | 32 +++---- src/theory/arith/nl/nonlinear_extension.cpp | 84 ++++++++---------- 5 files changed, 95 insertions(+), 124 deletions(-) diff --git a/src/theory/arith/inference_manager.cpp b/src/theory/arith/inference_manager.cpp index 3e33668b0..77f042d4d 100644 --- a/src/theory/arith/inference_manager.cpp +++ b/src/theory/arith/inference_manager.cpp @@ -86,6 +86,10 @@ void InferenceManager::flushWaitingLemmas() } d_waitingLem.clear(); } +void InferenceManager::clearWaitingLemmas() +{ + d_waitingLem.clear(); +} void InferenceManager::addConflict(const Node& conf, InferenceId inftype) { diff --git a/src/theory/arith/inference_manager.h b/src/theory/arith/inference_manager.h index 215d7382e..9228add19 100644 --- a/src/theory/arith/inference_manager.h +++ b/src/theory/arith/inference_manager.h @@ -78,6 +78,11 @@ class InferenceManager : public InferenceManagerBuffered */ void flushWaitingLemmas(); + /** + * Removes all waiting lemmas without sending them anywhere. + */ + void clearWaitingLemmas(); + /** Add a conflict to the this inference manager. */ void addConflict(const Node& conf, InferenceId inftype); diff --git a/src/theory/arith/nl/nl_solver.cpp b/src/theory/arith/nl/nl_solver.cpp index 6cf0306d8..5ffba7229 100644 --- a/src/theory/arith/nl/nl_solver.cpp +++ b/src/theory/arith/nl/nl_solver.cpp @@ -63,11 +63,12 @@ bool hasNewMonomials(Node n, const std::vector& existing) return false; } -NlSolver::NlSolver(TheoryArith& containing, NlModel& model) - : d_containing(containing), +NlSolver::NlSolver(InferenceManager& im, ArithState& astate, NlModel& model) + : d_im(im), + d_astate(astate), d_model(model), d_cdb(d_mdb), - d_zero_split(containing.getUserContext()) + d_zero_split(d_astate.getUserContext()) { NodeManager* nm = NodeManager::currentNM(); d_true = nm->mkConst(true); @@ -165,9 +166,8 @@ void NlSolver::setMonomialFactor(Node a, Node b, const NodeMultiset& common) } } -std::vector NlSolver::checkSplitZero() +void NlSolver::checkSplitZero() { - std::vector lemmas; for (unsigned i = 0; i < d_ms_vars.size(); i++) { Node v = d_ms_vars[i]; @@ -175,13 +175,11 @@ std::vector NlSolver::checkSplitZero() { Node eq = v.eqNode(d_zero); eq = Rewriter::rewrite(eq); - Node literal = d_containing.getValuation().ensureLiteral(eq); - d_containing.getOutputChannel().requirePhase(literal, true); - lemmas.emplace_back(literal.orNode(literal.negate()), - InferenceId::NL_SPLIT_ZERO); + d_im.addPendingPhaseRequirement(eq, true); + Node lem = eq.orNode(eq.negate()); + d_im.addPendingArithLemma(lem, InferenceId::NL_SPLIT_ZERO); } } - return lemmas; } void NlSolver::assignOrderIds(std::vector& vars, @@ -260,8 +258,7 @@ int NlSolver::compareSign(Node oa, Node a, unsigned a_index, int status, - std::vector& exp, - std::vector& lem) + std::vector& exp) { Trace("nl-ext-debug") << "Process " << a << " at index " << a_index << ", status is " << status << std::endl; @@ -274,7 +271,7 @@ int NlSolver::compareSign(Node oa, { Node lemma = safeConstructNary(AND, exp).impNode(mkLit(oa, d_zero, status * 2)); - lem.emplace_back(lemma, InferenceId::NL_SIGN); + d_im.addPendingArithLemma(lemma, InferenceId::NL_SIGN); } return status; } @@ -291,17 +288,17 @@ int NlSolver::compareSign(Node oa, if (mvaoa.getConst().sgn() != 0) { Node lemma = av.eqNode(d_zero).impNode(oa.eqNode(d_zero)); - lem.emplace_back(lemma, InferenceId::NL_SIGN); + d_im.addPendingArithLemma(lemma, InferenceId::NL_SIGN); } return 0; } if (aexp % 2 == 0) { exp.push_back(av.eqNode(d_zero).negate()); - return compareSign(oa, a, a_index + 1, status, exp, lem); + return compareSign(oa, a, a_index + 1, status, exp); } exp.push_back(nm->mkNode(sgn == 1 ? GT : LT, av, d_zero)); - return compareSign(oa, a, a_index + 1, status * sgn, exp, lem); + return compareSign(oa, a, a_index + 1, status * sgn, exp); } bool NlSolver::compareMonomial( @@ -312,7 +309,7 @@ bool NlSolver::compareMonomial( Node b, NodeMultiset& b_exp_proc, std::vector& exp, - std::vector& lem, + std::vector& lem, std::map > >& cmp_infers) { Trace("nl-ext-comp-debug") @@ -416,7 +413,7 @@ bool NlSolver::compareMonomial( NodeMultiset& b_exp_proc, int status, std::vector& exp, - std::vector& lem, + std::vector& lem, std::map > >& cmp_infers) { Trace("nl-ext-comp-debug") @@ -448,7 +445,7 @@ bool NlSolver::compareMonomial( Node clem = nm->mkNode( IMPLIES, safeConstructNary(AND, exp), mkLit(oa, ob, status, true)); Trace("nl-ext-comp-lemma") << "comparison lemma : " << clem << std::endl; - lem.emplace_back(clem, InferenceId::NL_COMPARISON); + lem.emplace_back(clem, LemmaProperty::NONE, nullptr, InferenceId::NL_COMPARISON); cmp_infers[status][oa][ob] = clem; } return true; @@ -629,9 +626,8 @@ bool NlSolver::compareMonomial( return false; } -std::vector NlSolver::checkMonomialSign() +void NlSolver::checkMonomialSign() { - std::vector lemmas; std::map signs; Trace("nl-ext") << "Get monomial sign lemmas..." << std::endl; for (unsigned j = 0; j < d_ms.size(); j++) @@ -648,7 +644,7 @@ std::vector NlSolver::checkMonomialSign() } if (d_m_nconst_factor.find(a) == d_m_nconst_factor.end()) { - signs[a] = compareSign(a, a, 0, 1, exp, lemmas); + signs[a] = compareSign(a, a, 0, 1, exp); if (signs[a] == 0) { d_ms_proc[a] = true; @@ -665,10 +661,9 @@ std::vector NlSolver::checkMonomialSign() } } } - return lemmas; } -std::vector NlSolver::checkMonomialMagnitude(unsigned c) +void NlSolver::checkMonomialMagnitude(unsigned c) { // ensure information is setup if (c == 0) @@ -682,7 +677,7 @@ std::vector NlSolver::checkMonomialMagnitude(unsigned c) } unsigned r = 1; - std::vector lemmas; + std::vector lemmas; // if (x,y,L) in cmp_infers, then x > y inferred as conclusion of L // in lemmas std::map > > cmp_infers; @@ -840,24 +835,18 @@ std::vector NlSolver::checkMonomialMagnitude(unsigned c) } } } - std::vector nr_lemmas; for (unsigned i = 0; i < lemmas.size(); i++) { if (r_lemmas.find(lemmas[i].d_node) == r_lemmas.end()) { - nr_lemmas.push_back(lemmas[i]); + d_im.addPendingArithLemma(lemmas[i]); } } // could only take maximal lower/minimial lower bounds? - - Trace("nl-ext-comp") << nr_lemmas.size() << " / " << lemmas.size() - << " were non-redundant." << std::endl; - return nr_lemmas; } -std::vector NlSolver::checkTangentPlanes() +void NlSolver::checkTangentPlanes(bool asWaitingLemmas) { - std::vector lemmas; Trace("nl-ext") << "Get monomial tangent plane lemmas..." << std::endl; NodeManager* nm = NodeManager::currentNM(); const std::map >& ccMap = @@ -1007,19 +996,16 @@ std::vector NlSolver::checkTangentPlanes() tplaneConj.push_back(lb_reverse2); Node tlem = nm->mkNode(AND, tplaneConj); - lemmas.emplace_back(tlem, InferenceId::NL_TANGENT_PLANE); + d_im.addPendingArithLemma( + tlem, InferenceId::NL_TANGENT_PLANE, asWaitingLemmas); } } } } } - Trace("nl-ext") << "...trying " << lemmas.size() << " tangent plane lemmas..." - << std::endl; - return lemmas; } -std::vector NlSolver::checkMonomialInferBounds( - std::vector& nt_lemmas, +void NlSolver::checkMonomialInferBounds( const std::vector& asserts, const std::vector& false_asserts) { @@ -1033,7 +1019,6 @@ std::vector NlSolver::checkMonomialInferBounds( const std::map >& cim = d_cdb.getConstraints(); - std::vector lemmas; NodeManager* nm = NodeManager::currentNM(); // register constraints Trace("nl-ext-debug") << "Register bound constraints..." << std::endl; @@ -1260,26 +1245,16 @@ std::vector NlSolver::checkMonomialInferBounds( << " (pre-rewrite : " << pr_iblem << ")" << std::endl; // Trace("nl-ext-bound-lemma") << " intro new // monomials = " << introNewTerms << std::endl; - if (!introNewTerms) - { - lemmas.emplace_back(iblem, InferenceId::NL_INFER_BOUNDS); - } - else - { - nt_lemmas.emplace_back(iblem, InferenceId::NL_INFER_BOUNDS_NT); - } + d_im.addPendingArithLemma(iblem, InferenceId::NL_INFER_BOUNDS_NT, introNewTerms); } } } } } - return lemmas; } -std::vector NlSolver::checkFactoring( - const std::vector& asserts, const std::vector& false_asserts) +void NlSolver::checkFactoring(const std::vector& asserts, const std::vector& false_asserts) { - std::vector lemmas; NodeManager* nm = NodeManager::currentNM(); Trace("nl-ext") << "Get factoring lemmas..." << std::endl; for (const Node& lit : asserts) @@ -1366,7 +1341,7 @@ std::vector NlSolver::checkFactoring( sum = Rewriter::rewrite(sum); Trace("nl-ext-factor") << "* Factored sum for " << x << " : " << sum << std::endl; - Node kf = getFactorSkolem(sum, lemmas); + Node kf = getFactorSkolem(sum); std::vector poly; poly.push_back(nm->mkNode(MULT, x, kf)); std::map >::iterator itfo = @@ -1398,15 +1373,14 @@ std::vector NlSolver::checkFactoring( lemma_disj.push_back(conc_lit); Node flem = nm->mkNode(OR, lemma_disj); Trace("nl-ext-factor") << "...lemma is " << flem << std::endl; - lemmas.emplace_back(flem, InferenceId::NL_FACTOR); + d_im.addPendingArithLemma(flem, InferenceId::NL_FACTOR); } } } } - return lemmas; } -Node NlSolver::getFactorSkolem(Node n, std::vector& lemmas) +Node NlSolver::getFactorSkolem(Node n) { std::map::iterator itf = d_factor_skolem.find(n); if (itf == d_factor_skolem.end()) @@ -1414,16 +1388,15 @@ Node NlSolver::getFactorSkolem(Node n, std::vector& lemmas) NodeManager* nm = NodeManager::currentNM(); Node k = nm->mkSkolem("kf", n.getType()); Node k_eq = Rewriter::rewrite(k.eqNode(n)); - lemmas.push_back(k_eq); + d_im.addPendingArithLemma(k_eq, InferenceId::NL_FACTOR); d_factor_skolem[n] = k; return k; } return itf->second; } -std::vector NlSolver::checkMonomialInferResBounds() +void NlSolver::checkMonomialInferResBounds() { - std::vector lemmas; NodeManager* nm = NodeManager::currentNM(); Trace("nl-ext") << "Get monomial resolution inferred bound lemmas..." << std::endl; @@ -1570,7 +1543,7 @@ std::vector NlSolver::checkMonomialInferResBounds() rblem = Rewriter::rewrite(rblem); Trace("nl-ext-rbound-lemma") << "Resolution bound lemma : " << rblem << std::endl; - lemmas.emplace_back(rblem, InferenceId::NL_RES_INFER_BOUNDS); + d_im.addPendingArithLemma(rblem, InferenceId::NL_RES_INFER_BOUNDS); } } exp.pop_back(); @@ -1583,7 +1556,6 @@ std::vector NlSolver::checkMonomialInferResBounds() } } } - return lemmas; } } // namespace nl diff --git a/src/theory/arith/nl/nl_solver.h b/src/theory/arith/nl/nl_solver.h index 050062234..9dd5b03c6 100644 --- a/src/theory/arith/nl/nl_solver.h +++ b/src/theory/arith/nl/nl_solver.h @@ -58,7 +58,7 @@ class NlSolver typedef context::CDHashSet NodeSet; public: - NlSolver(TheoryArith& containing, NlModel& model); + NlSolver(InferenceManager& im, ArithState& astate, NlModel& model); ~NlSolver(); /** init last call @@ -79,7 +79,7 @@ class NlSolver * t = 0 V t != 0 * where t is a term that exists in the current context. */ - std::vector checkSplitZero(); + void checkSplitZero(); /** check monomial sign * @@ -96,7 +96,7 @@ class NlSolver * x < 0 => x*y*y < 0 * x = 0 => x*y*z = 0 */ - std::vector checkMonomialSign(); + void checkMonomialSign(); /** check monomial magnitude * @@ -118,7 +118,7 @@ class NlSolver * against 1, 1 : compare non-linear monomials against variables, 2 : compare * non-linear monomials against other non-linear monomials. */ - std::vector checkMonomialMagnitude(unsigned c); + void checkMonomialMagnitude(unsigned c); /** check monomial inferred bounds * @@ -137,8 +137,7 @@ class NlSolver * ...where (y > z + w) and x*y are a constraint and term * that occur in the current context. */ - std::vector checkMonomialInferBounds( - std::vector& nt_lemmas, + void checkMonomialInferBounds( const std::vector& asserts, const std::vector& false_asserts); @@ -154,7 +153,7 @@ class NlSolver * ...where k is fresh and x*z + y*z > t is a * constraint that occurs in the current context. */ - std::vector checkFactoring(const std::vector& asserts, + void checkFactoring(const std::vector& asserts, const std::vector& false_asserts); /** check monomial infer resolution bounds @@ -171,7 +170,7 @@ class NlSolver * ...where s <= x*z and x*y <= t are constraints * that occur in the current context. */ - std::vector checkMonomialInferResBounds(); + void checkMonomialInferResBounds(); /** check tangent planes * @@ -197,12 +196,14 @@ class NlSolver * ( ( x>2 ^ y>5) ^ (x<2 ^ y<5) ) => x*y > 5*x + 2*y - 10 * ( ( x>2 ^ y<5) ^ (x<2 ^ y>5) ) => x*y < 5*x + 2*y - 10 */ - std::vector checkTangentPlanes(); + void checkTangentPlanes(bool asWaitingLemmas); //-------------------------------------------- end lemma schemas private: - // The theory of arithmetic containing this extension. - TheoryArith& d_containing; + /** The inference manager that we push conflicts and lemmas to. */ + InferenceManager& d_im; + /** Reference to the state. */ + ArithState& d_astate; /** Reference to the non-linear model object */ NlModel& d_model; /** commonly used terms */ @@ -294,8 +295,7 @@ class NlSolver Node a, unsigned a_index, int status, - std::vector& exp, - std::vector& lem); + std::vector& exp); /** compare monomials a and b * * Initially, a call to this function is such that : @@ -338,7 +338,7 @@ class NlSolver Node b, NodeMultiset& b_exp_proc, std::vector& exp, - std::vector& lem, + std::vector& lem, std::map > >& cmp_infers); /** helper function for above * @@ -356,10 +356,10 @@ class NlSolver NodeMultiset& b_exp_proc, int status, std::vector& exp, - std::vector& lem, + std::vector& lem, std::map > >& cmp_infers); /** Get factor skolem for n, add resulting lemmas to lemmas */ - Node getFactorSkolem(Node n, std::vector& lemmas); + Node getFactorSkolem(Node n); }; /* class NlSolver */ } // namespace nl diff --git a/src/theory/arith/nl/nonlinear_extension.cpp b/src/theory/arith/nl/nonlinear_extension.cpp index af1f536be..251294d37 100644 --- a/src/theory/arith/nl/nonlinear_extension.cpp +++ b/src/theory/arith/nl/nonlinear_extension.cpp @@ -47,7 +47,7 @@ NonlinearExtension::NonlinearExtension(TheoryArith& containing, containing.getOutputChannel()), d_model(containing.getSatContext()), d_trSlv(d_im, d_model), - d_nlSlv(containing, d_model), + d_nlSlv(d_im, state, d_model), d_cadSlv(d_im, d_model), d_icpSlv(d_im), d_iandSlv(d_im, state, d_model), @@ -426,13 +426,12 @@ int NonlinearExtension::checkLastCall(const std::vector& assertions, if (options::nlExt() && options::nlExtSplitZero()) { Trace("nl-ext") << "Get zero split lemmas..." << std::endl; - lemmas = d_nlSlv.checkSplitZero(); - filterLemmas(lemmas, lems); - if (!lems.empty()) + d_nlSlv.checkSplitZero(); + if (d_im.hasUsed()) { - Trace("nl-ext") << " ...finished with " << lems.size() << " new lemmas." + Trace("nl-ext") << " ...finished with " << d_im.numPendingLemmas() << " new lemmas." << std::endl; - return lems.size(); + return d_im.numPendingLemmas(); } } @@ -462,13 +461,12 @@ int NonlinearExtension::checkLastCall(const std::vector& assertions, if (options::nlExt()) { //---------------------------------lemmas based on sign (comparison to zero) - lemmas = d_nlSlv.checkMonomialSign(); - filterLemmas(lemmas, lems); - if (!lems.empty()) + d_nlSlv.checkMonomialSign(); + if (d_im.hasUsed()) { - Trace("nl-ext") << " ...finished with " << lems.size() << " new lemmas." + Trace("nl-ext") << " ...finished with " << d_im.numPendingLemmas() << " new lemmas." << std::endl; - return lems.size(); + return d_im.numPendingLemmas(); } //-----------------------------------monotonicity of transdental functions @@ -485,62 +483,53 @@ int NonlinearExtension::checkLastCall(const std::vector& assertions, for (unsigned c = 0; c < 3; c++) { // c is effort level - lemmas = d_nlSlv.checkMonomialMagnitude(c); - unsigned nlem = lemmas.size(); - filterLemmas(lemmas, lems); - if (!lems.empty()) + d_nlSlv.checkMonomialMagnitude(c); + if (d_im.hasUsed()) { - Trace("nl-ext") << " ...finished with " << lems.size() - << " new lemmas (out of possible " << nlem << ")." - << std::endl; - return lems.size(); + Trace("nl-ext") << " ...finished with " << d_im.numPendingLemmas() + << " new lemmas." << std::endl; + return d_im.numPendingLemmas(); } } //-----------------------------------inferred bounds lemmas // e.g. x >= t => y*x >= y*t - std::vector nt_lemmas; - lemmas = - d_nlSlv.checkMonomialInferBounds(nt_lemmas, assertions, false_asserts); - // Trace("nl-ext") << "Bound lemmas : " << lemmas.size() << ", " << - // nt_lemmas.size() << std::endl; prioritize lemmas that do not - // introduce new monomials - filterLemmas(lemmas, lems); + d_nlSlv.checkMonomialInferBounds(assertions, false_asserts); + Trace("nl-ext") << "Bound lemmas : " << d_im.numPendingLemmas() << ", " << d_im.numWaitingLemmas() << std::endl; + // prioritize lemmas that do not introduce new monomials if (options::nlExtTangentPlanes() && options::nlExtTangentPlanesInterleave()) { - lemmas = d_nlSlv.checkTangentPlanes(); - filterLemmas(lemmas, lems); + d_nlSlv.checkTangentPlanes(false); } - if (!lems.empty()) + if (d_im.hasUsed()) { - Trace("nl-ext") << " ...finished with " << lems.size() << " new lemmas." + Trace("nl-ext") << " ...finished with " << d_im.numPendingLemmas() << " new lemmas." << std::endl; - return lems.size(); + return d_im.numPendingLemmas(); } // from inferred bound inferences : now do ones that introduce new terms - filterLemmas(nt_lemmas, lems); - if (!lems.empty()) + d_im.flushWaitingLemmas(); + if (d_im.hasUsed()) { - Trace("nl-ext") << " ...finished with " << lems.size() + Trace("nl-ext") << " ...finished with " << d_im.numPendingLemmas() << " new (monomial-introducing) lemmas." << std::endl; - return lems.size(); + return d_im.numPendingLemmas(); } //------------------------------------factoring lemmas // x*y + x*z >= t => exists k. k = y + z ^ x*k >= t if (options::nlExtFactor()) { - lemmas = d_nlSlv.checkFactoring(assertions, false_asserts); - filterLemmas(lemmas, lems); - if (!lems.empty()) + d_nlSlv.checkFactoring(assertions, false_asserts); + if (d_im.hasUsed()) { - Trace("nl-ext") << " ...finished with " << lems.size() + Trace("nl-ext") << " ...finished with " << d_im.numPendingLemmas() << " new lemmas." << std::endl; - return lems.size(); + return d_im.numPendingLemmas(); } } @@ -548,13 +537,12 @@ int NonlinearExtension::checkLastCall(const std::vector& assertions, // e.g. ( y>=0 ^ s <= x*z ^ x*y <= t ) => y*s <= z*t if (options::nlExtResBound()) { - lemmas = d_nlSlv.checkMonomialInferResBounds(); - filterLemmas(lemmas, lems); - if (!lems.empty()) + d_nlSlv.checkMonomialInferResBounds(); + if (d_im.hasUsed()) { - Trace("nl-ext") << " ...finished with " << lems.size() + Trace("nl-ext") << " ...finished with " << d_im.numPendingLemmas() << " new lemmas." << std::endl; - return lems.size(); + return d_im.numPendingLemmas(); } } @@ -562,8 +550,7 @@ int NonlinearExtension::checkLastCall(const std::vector& assertions, if (options::nlExtTangentPlanes() && !options::nlExtTangentPlanesInterleave()) { - lemmas = d_nlSlv.checkTangentPlanes(); - filterLemmas(lemmas, wlems); + d_nlSlv.checkTangentPlanes(true); } if (options::nlExtTfTangentPlanes()) { @@ -746,6 +733,7 @@ bool NonlinearExtension::modelBasedRefinement(std::vector& mlems) checkLastCall(assertions, false_asserts, xts, mlems, wlems); if (!mlems.empty() || d_im.hasSentLemma() || d_im.hasPendingLemma()) { + d_im.clearWaitingLemmas(); return true; } } @@ -776,6 +764,7 @@ bool NonlinearExtension::modelBasedRefinement(std::vector& mlems) filterLemmas(lemmas, mlems); if (!mlems.empty() || d_im.hasPendingLemma()) { + d_im.clearWaitingLemmas(); return true; } } @@ -846,6 +835,7 @@ bool NonlinearExtension::modelBasedRefinement(std::vector& mlems) d_containing.getOutputChannel().setIncomplete(); } } + d_im.clearWaitingLemmas(); } while (needsRecheck); // did not add lemmas -- 2.30.2