From 5d0a5e5571044000fdaf0d908bace8ed7c1c536a Mon Sep 17 00:00:00 2001 From: Andrew Reynolds Date: Tue, 12 Nov 2019 23:34:05 -0600 Subject: [PATCH] Refactor non-linear extension for model-based refinement (#3452) * Refactor non-linear extension for model-based refinement * Format * Minor * Address --- src/theory/arith/nl_model.cpp | 97 +++--- src/theory/arith/nl_model.h | 15 +- src/theory/arith/nonlinear_extension.cpp | 357 ++++++++++++----------- src/theory/arith/nonlinear_extension.h | 24 +- 4 files changed, 249 insertions(+), 244 deletions(-) diff --git a/src/theory/arith/nl_model.cpp b/src/theory/arith/nl_model.cpp index 571fbda6c..62bdf310b 100644 --- a/src/theory/arith/nl_model.cpp +++ b/src/theory/arith/nl_model.cpp @@ -137,38 +137,16 @@ Node NlModel::computeModelValue(Node n, bool isConcrete) Node NlModel::getValueInternal(Node n) const { return d_model->getValue(n); - /* - std::map< Node, Node >::const_iterator it = d_arithVal.find(n); - if (it!=d_arithVal.end()) - { - return it->second; - } - return Node::null(); - */ } bool NlModel::hasTerm(Node n) const { return d_model->hasTerm(n); - // return d_arithVal.find(n)!=d_arithVal.end(); } Node NlModel::getRepresentative(Node n) const { return d_model->getRepresentative(n); - /* - std::map< Node, Node >::const_iterator it = d_arithVal.find(n); - if (it!=d_arithVal.end()) - { - std::map< Node, Node >::const_iterator itr = d_valToRep.find(it->second); - if (itr != d_valToRep.end()) - { - return itr->second; - } - Assert(false); - } - return Node::null(); - */ } int NlModel::compare(Node i, Node j, bool isConcrete, bool isAbsolute) @@ -1225,12 +1203,37 @@ bool NlModel::getApproximateSqrt(Node c, Node& l, Node& u, unsigned iter) const return true; } -void NlModel::recordApproximations() +void NlModel::printModelValue(const char* c, Node n, unsigned prec) const +{ + if (Trace.isOn(c)) + { + Trace(c) << " " << n << " -> "; + for (int i = 1; i >= 0; --i) + { + std::map::const_iterator it = d_mv[i].find(n); + Assert(it != d_mv[i].end()); + if (it->second.isConst()) + { + printRationalApprox(c, it->second, prec); + } + else + { + Trace(c) << "?"; + } + Trace(c) << (i == 1 ? " [actual: " : " ]"); + } + Trace(c) << std::endl; + } +} + +void NlModel::getModelValueRepair(std::map& arithModel, + std::map& approximations) { // Record the approximations we used. This code calls the // recordApproximation method of the model, which overrides the model // values for variables that we solved for, using techniques specific to // this class. + Trace("nl-model") << "NlModel::getModelValueRepair:" << std::endl; NodeManager* nm = NodeManager::currentNM(); for (const std::pair >& cb : d_check_model_bounds) @@ -1241,17 +1244,15 @@ void NlModel::recordApproximations() Node v = cb.first; if (l != u) { - pred = nm->mkNode(AND, nm->mkNode(GEQ, v, l), nm->mkNode(GEQ, u, v)); - } - else if (!d_model->areEqual(v, l)) - { - // only record if value was not equal already - pred = v.eqNode(l); + Node pred = nm->mkNode(AND, nm->mkNode(GEQ, v, l), nm->mkNode(GEQ, u, v)); + approximations[v] = pred; + Trace("nl-model") << v << " approximated as " << pred << std::endl; } - if (!pred.isNull()) + else { - pred = Rewriter::rewrite(pred); - d_model->recordApproximation(v, pred); + // overwrite + arithModel[v] = l; + Trace("nl-model") << v << " exact approximation is " << l << std::endl; } } // Also record the exact values we used. An exact value can be seen as a @@ -1262,36 +1263,12 @@ void NlModel::recordApproximations() { Node v = d_check_model_vars[i]; Node s = d_check_model_subs[i]; - if (!d_model->areEqual(v, s)) - { - Node pred = v.eqNode(s); - pred = Rewriter::rewrite(pred); - d_model->recordApproximation(v, pred); - } - } -} -void NlModel::printModelValue(const char* c, Node n, unsigned prec) const -{ - if (Trace.isOn(c)) - { - Trace(c) << " " << n << " -> "; - for (unsigned i = 0; i < 2; i++) - { - std::map::const_iterator it = d_mv[1 - i].find(n); - Assert(it != d_mv[1 - i].end()); - if (it->second.isConst()) - { - printRationalApprox(c, it->second, prec); - } - else - { - Trace(c) << "?"; // it->second; - } - Trace(c) << (i == 0 ? " [actual: " : " ]"); - } - Trace(c) << std::endl; + // overwrite + arithModel[v] = s; + Trace("nl-model") << v << " solved is " << s << std::endl; } } + } // namespace arith } // namespace theory } // namespace CVC4 diff --git a/src/theory/arith/nl_model.h b/src/theory/arith/nl_model.h index 19341cc5f..ed13327cc 100644 --- a/src/theory/arith/nl_model.h +++ b/src/theory/arith/nl_model.h @@ -171,12 +171,19 @@ class NlModel * term n on Trace c with precision prec. */ void printModelValue(const char* c, Node n, unsigned prec = 5) const; - /** record approximations in the current model + /** get model value repair * - * This makes necessary calls that notify the model of any approximations - * that were used by this solver. + * This gets mappings that indicate how to repair the model generated by the + * linear arithmetic solver. This method should be called after a successful + * call to checkModel above. + * + * The mapping arithModel is updated by this method to map arithmetic terms v + * to their (exact) value that was computed during checkModel; the mapping + * approximations is updated to store approximate values in the form of a + * predicate over v. */ - void recordApproximations(); + void getModelValueRepair(std::map& arithModel, + std::map& approximations); private: /** The current model */ diff --git a/src/theory/arith/nonlinear_extension.cpp b/src/theory/arith/nonlinear_extension.cpp index e14f07a4f..d76089541 100644 --- a/src/theory/arith/nonlinear_extension.cpp +++ b/src/theory/arith/nonlinear_extension.cpp @@ -944,8 +944,7 @@ int NonlinearExtension::checkLastCall(const std::vector& assertions, std::vector repList; for (const Node& ac : a) { - Node r = - d_containing.getValuation().getModel()->getRepresentative(ac); + Node r = d_model.computeConcreteModelValue(ac); repList.push_back(r); } Node aa = argTrie[ak].add(a, repList); @@ -1232,11 +1231,7 @@ void NonlinearExtension::check(Theory::Effort e) { Trace("nl-ext") << std::endl; Trace("nl-ext") << "NonlinearExtension::check, effort = " << e << ", built model = " << d_builtModel.get() << std::endl; - if (d_builtModel.get()) - { - // already built model, nothing to do - } - else if (e == Theory::EFFORT_FULL) + if (e == Theory::EFFORT_FULL) { d_containing.getExtTheory()->clearCache(); d_needsLastCall = true; @@ -1255,198 +1250,210 @@ void NonlinearExtension::check(Theory::Effort e) { } else { - // get the assertions - std::vector assertions; - getAssertions(assertions); - - // reset cached information + std::map approximations; + std::map arithModel; TheoryModel* tm = d_containing.getValuation().getModel(); - d_model.reset(tm); - - Trace("nl-ext-mv-assert") - << "Getting model values... check for [model-false]" << std::endl; - // get the assertions that are false in the model - const std::vector false_asserts = checkModelEval(assertions); - - // get the extended terms belonging to this theory - std::vector xts; - d_containing.getExtTheory()->getTerms(xts); - - if (Trace.isOn("nl-ext-debug")) - { - Trace("nl-ext-debug") - << " processing NonlinearExtension::check : " << std::endl; - Trace("nl-ext-debug") << " " << false_asserts.size() - << " false assertions" << std::endl; - Trace("nl-ext-debug") - << " " << xts.size() << " extended terms: " << std::endl; - Trace("nl-ext-debug") << " "; - for (unsigned j = 0; j < xts.size(); j++) - { - Trace("nl-ext-debug") << xts[j] << " "; - } - Trace("nl-ext-debug") << std::endl; - } - - // compute whether shared terms have correct values - unsigned num_shared_wrong_value = 0; - std::vector shared_term_value_splits; - // must ensure that shared terms are equal to their concrete value - Trace("nl-ext-mv") << "Shared terms : " << std::endl; - for (context::CDList::const_iterator its = - d_containing.shared_terms_begin(); - its != d_containing.shared_terms_end(); - ++its) - { - TNode shared_term = *its; - // compute its value in the model, and its evaluation in the model - Node stv0 = d_model.computeConcreteModelValue(shared_term); - Node stv1 = d_model.computeAbstractModelValue(shared_term); - d_model.printModelValue("nl-ext-mv", shared_term); - if (stv0 != stv1) + if (!d_builtModel.get()) + { + // run model-based refinement + if (modelBasedRefinement()) { - num_shared_wrong_value++; - Trace("nl-ext-mv") << "Bad shared term value : " << shared_term - << std::endl; - if (shared_term != stv0) - { - // split on the value, this is non-terminating in general, TODO : - // improve this - Node eq = shared_term.eqNode(stv0); - shared_term_value_splits.push_back(eq); - } - else - { - // this can happen for transcendental functions - // the problem is that we cannot evaluate transcendental functions - // (they don't have a rewriter that returns constants) - // thus, the actual value in their model can be themselves, hence we - // have no reference point to rule out the current model. In this - // case, we may set incomplete below. - } + return; } } - Trace("nl-ext-debug") << " " << num_shared_wrong_value - << " shared terms with wrong model value." - << std::endl; - bool needsRecheck; - do + // get the values that should be replaced in the model + d_model.getModelValueRepair(arithModel, approximations); + // those that are exact are written as exact approximations to the model + for (std::pair& r : arithModel) + { + Node eq = r.first.eqNode(r.second); + eq = Rewriter::rewrite(eq); + tm->recordApproximation(r.first, eq); + } + // those that are approximate are recorded as approximations + for (std::pair& a : approximations) { - d_model.resetCheck(); - needsRecheck = false; - Assert(e == Theory::EFFORT_LAST_CALL); - // complete_status: - // 1 : we may answer SAT, -1 : we may not answer SAT, 0 : unknown - int complete_status = 1; - int num_added_lemmas = 0; - // we require a check either if an assertion is false or a shared term has - // a wrong value - if (!false_asserts.empty() || num_shared_wrong_value > 0) + tm->recordApproximation(a.first, a.second); + } + } +} + +bool NonlinearExtension::modelBasedRefinement() +{ + // reset the model object + d_model.reset(d_containing.getValuation().getModel()); + // get the assertions + std::vector assertions; + getAssertions(assertions); + + Trace("nl-ext-mv-assert") + << "Getting model values... check for [model-false]" << std::endl; + // get the assertions that are false in the model + const std::vector false_asserts = checkModelEval(assertions); + + // get the extended terms belonging to this theory + std::vector xts; + d_containing.getExtTheory()->getTerms(xts); + + if (Trace.isOn("nl-ext-debug")) + { + Trace("nl-ext-debug") << " processing NonlinearExtension::check : " + << std::endl; + Trace("nl-ext-debug") << " " << false_asserts.size() + << " false assertions" << std::endl; + Trace("nl-ext-debug") << " " << xts.size() + << " extended terms: " << std::endl; + Trace("nl-ext-debug") << " "; + for (unsigned j = 0; j < xts.size(); j++) + { + Trace("nl-ext-debug") << xts[j] << " "; + } + Trace("nl-ext-debug") << std::endl; + } + + // compute whether shared terms have correct values + unsigned num_shared_wrong_value = 0; + std::vector shared_term_value_splits; + // must ensure that shared terms are equal to their concrete value + Trace("nl-ext-mv") << "Shared terms : " << std::endl; + for (context::CDList::const_iterator its = + d_containing.shared_terms_begin(); + its != d_containing.shared_terms_end(); + ++its) + { + TNode shared_term = *its; + // compute its value in the model, and its evaluation in the model + Node stv0 = d_model.computeConcreteModelValue(shared_term); + Node stv1 = d_model.computeAbstractModelValue(shared_term); + d_model.printModelValue("nl-ext-mv", shared_term); + if (stv0 != stv1) + { + num_shared_wrong_value++; + Trace("nl-ext-mv") << "Bad shared term value : " << shared_term + << std::endl; + if (shared_term != stv0) { - complete_status = num_shared_wrong_value > 0 ? -1 : 0; - num_added_lemmas = checkLastCall(assertions, false_asserts, xts); - if (num_added_lemmas > 0) - { - return; - } + // split on the value, this is non-terminating in general, TODO : + // improve this + Node eq = shared_term.eqNode(stv0); + shared_term_value_splits.push_back(eq); } - Trace("nl-ext") << "Finished check with status : " << complete_status - << std::endl; + else + { + // this can happen for transcendental functions + // the problem is that we cannot evaluate transcendental functions + // (they don't have a rewriter that returns constants) + // thus, the actual value in their model can be themselves, hence we + // have no reference point to rule out the current model. In this + // case, we may set incomplete below. + } + } + } + Trace("nl-ext-debug") << " " << num_shared_wrong_value + << " shared terms with wrong model value." << std::endl; + bool needsRecheck; + do + { + d_model.resetCheck(); + needsRecheck = false; + // complete_status: + // 1 : we may answer SAT, -1 : we may not answer SAT, 0 : unknown + int complete_status = 1; + int num_added_lemmas = 0; + // we require a check either if an assertion is false or a shared term has + // a wrong value + if (!false_asserts.empty() || num_shared_wrong_value > 0) + { + complete_status = num_shared_wrong_value > 0 ? -1 : 0; + num_added_lemmas = checkLastCall(assertions, false_asserts, xts); + if (num_added_lemmas > 0) + { + return true; + } + } + Trace("nl-ext") << "Finished check with status : " << complete_status + << std::endl; - // if we did not add a lemma during check and there is a chance for SAT - if (complete_status == 0) + // if we did not add a lemma during check and there is a chance for SAT + if (complete_status == 0) + { + Trace("nl-ext") + << "Check model based on bounds for irrational-valued functions..." + << std::endl; + // check the model based on simple solving of equalities and using + // error bounds on the Taylor approximation of transcendental functions. + if (checkModel(assertions, false_asserts)) { - Trace("nl-ext") - << "Check model based on bounds for irrational-valued functions..." - << std::endl; - // check the model based on simple solving of equalities and using - // error bounds on the Taylor approximation of transcendental functions. - if (checkModel(assertions, false_asserts)) - { - complete_status = 1; - } + complete_status = 1; } + } - // if we have not concluded SAT - if (complete_status != 1) + // if we have not concluded SAT + if (complete_status != 1) + { + // flush the waiting lemmas + num_added_lemmas = flushLemmas(d_waiting_lemmas); + if (num_added_lemmas > 0) { - // flush the waiting lemmas - num_added_lemmas = flushLemmas(d_waiting_lemmas); - if (num_added_lemmas > 0) - { - Trace("nl-ext") << "...added " << num_added_lemmas - << " waiting lemmas." << std::endl; - return; - } - // resort to splitting on shared terms with their model value - // if we did not add any lemmas - if (num_shared_wrong_value > 0) + Trace("nl-ext") << "...added " << num_added_lemmas << " waiting lemmas." + << std::endl; + return true; + } + // resort to splitting on shared terms with their model value + // if we did not add any lemmas + if (num_shared_wrong_value > 0) + { + complete_status = -1; + if (!shared_term_value_splits.empty()) { - complete_status = -1; - if (!shared_term_value_splits.empty()) + std::vector shared_term_value_lemmas; + for (const Node& eq : shared_term_value_splits) { - std::vector shared_term_value_lemmas; - for (const Node& eq : shared_term_value_splits) - { - Node req = Rewriter::rewrite(eq); - Node literal = d_containing.getValuation().ensureLiteral(req); - d_containing.getOutputChannel().requirePhase(literal, true); - Trace("nl-ext-debug") << "Split on : " << literal << std::endl; - shared_term_value_lemmas.push_back( - literal.orNode(literal.negate())); - } - num_added_lemmas = flushLemmas(shared_term_value_lemmas); - if (num_added_lemmas > 0) - { - Trace("nl-ext") - << "...added " << num_added_lemmas - << " shared term value split lemmas." << std::endl; - return; - } + Node req = Rewriter::rewrite(eq); + Node literal = d_containing.getValuation().ensureLiteral(req); + d_containing.getOutputChannel().requirePhase(literal, true); + Trace("nl-ext-debug") << "Split on : " << literal << std::endl; + shared_term_value_lemmas.push_back( + literal.orNode(literal.negate())); } - else + num_added_lemmas = flushLemmas(shared_term_value_lemmas); + if (num_added_lemmas > 0) { - // this can happen if we are trying to do theory combination with - // trancendental functions - // since their model value cannot even be computed exactly + Trace("nl-ext") << "...added " << num_added_lemmas + << " shared term value split lemmas." << std::endl; + return true; } } - - // we are incomplete - if (options::nlExtIncPrecision() && d_model.usedApproximate()) - { - d_taylor_degree++; - needsRecheck = true; - // increase precision for PI? - // Difficult since Taylor series is very slow to converge - Trace("nl-ext") << "...increment Taylor degree to " << d_taylor_degree - << std::endl; - } else { - Trace("nl-ext") << "...failed to send lemma in " - "NonLinearExtension, set incomplete" - << std::endl; - d_containing.getOutputChannel().setIncomplete(); + // this can happen if we are trying to do theory combination with + // trancendental functions + // since their model value cannot even be computed exactly } } - } while (needsRecheck); - } - - // Did we internally determine a model exists? If so, we need to record some - // information in the theory engine's model class. - if (d_builtModel.get()) - { - if (e < Theory::EFFORT_LAST_CALL) - { - // don't need to build the model yet - return; + // we are incomplete + if (options::nlExtIncPrecision() && d_model.usedApproximate()) + { + d_taylor_degree++; + needsRecheck = true; + // increase precision for PI? + // Difficult since Taylor series is very slow to converge + Trace("nl-ext") << "...increment Taylor degree to " << d_taylor_degree + << std::endl; + } + else + { + Trace("nl-ext") << "...failed to send lemma in " + "NonLinearExtension, set incomplete" + << std::endl; + d_containing.getOutputChannel().setIncomplete(); + } } - // record approximations in the model - d_model.recordApproximations(); - return; - } + } while (needsRecheck); + + // did not add lemmas + return false; } void NonlinearExtension::presolve() diff --git a/src/theory/arith/nonlinear_extension.h b/src/theory/arith/nonlinear_extension.h index b76877414..1690c9334 100644 --- a/src/theory/arith/nonlinear_extension.h +++ b/src/theory/arith/nonlinear_extension.h @@ -63,8 +63,7 @@ typedef std::map NodeMultiset; * * It's main functionality is a check(...) method, * which is called by TheoryArithPrivate either: - * (1) at full effort with no conflicts or lemmas emitted, - * or + * (1) at full effort with no conflicts or lemmas emitted, or * (2) at last call effort. * In this method, this class calls d_out->lemma(...) * for valid arithmetic theory lemmas, based on the current set of assertions, @@ -115,9 +114,8 @@ class NonlinearExtension { const std::vector& exp) const; /** Check at effort level e. * - * This call may result in (possibly multiple) - * calls to d_out->lemma(...) where d_out - * is the output channel of TheoryArith. + * This call may result in (possibly multiple) calls to d_out->lemma(...) + * where d_out is the output channel of TheoryArith. */ void check(Theory::Effort e); /** Does this class need a call to check(...) at last call effort? */ @@ -131,6 +129,22 @@ class NonlinearExtension { */ void presolve(); private: + /** Model-based refinement + * + * This is the main entry point of this class for generating lemmas on the + * output channel of the theory of arithmetic. + * + * It is currently run at last call effort. It applies lemma schemas + * described in Reynolds et al. FroCoS 2017 that are based on ruling out + * the current candidate model. + * + * This function returns true if a lemma was sent out on the output + * channel of the theory of arithmetic. Otherwise, it returns false. In the + * latter case, the model object d_model may have information regarding + * how to construct a model, in the case that we determined the problem + * is satisfiable. + */ + bool modelBasedRefinement(); /** returns true if the multiset containing the * factors of monomial a is a subset of the multiset * containing the factors of monomial b. -- 2.30.2