From e64a4bc87d2d98e04e8450d4ad9856bce3494c78 Mon Sep 17 00:00:00 2001 From: Gereon Kremer Date: Tue, 5 Oct 2021 13:06:53 -0700 Subject: [PATCH] First round of refactoring on NlModel (#7255) This PR performs a first refactoring on the NlModel class. It improves model value computation, comparison and stores the model substitutions in a map (instead of two vectors). --- src/expr/subs.cpp | 12 +- src/expr/subs.h | 4 + src/theory/arith/arith_utilities.cpp | 19 +- src/theory/arith/arith_utilities.h | 5 +- src/theory/arith/nl/cad_solver.cpp | 4 +- src/theory/arith/nl/ext/monomial_check.cpp | 4 +- src/theory/arith/nl/nl_lemma_utils.cpp | 2 +- src/theory/arith/nl/nl_model.cpp | 294 ++++++++---------- src/theory/arith/nl/nl_model.h | 99 +++--- src/theory/arith/nl/nonlinear_extension.cpp | 9 +- src/theory/arith/nl/poly_conversion.h | 4 +- .../transcendental/transcendental_solver.cpp | 16 +- 12 files changed, 226 insertions(+), 246 deletions(-) diff --git a/src/expr/subs.cpp b/src/expr/subs.cpp index b140a4190..f08cf18c1 100644 --- a/src/expr/subs.cpp +++ b/src/expr/subs.cpp @@ -44,6 +44,16 @@ Node Subs::getSubs(Node v) const return d_subs[i]; } +std::optional Subs::find(TNode v) const +{ + auto it = std::find(d_vars.begin(), d_vars.end(), v); + if (it == d_vars.end()) + { + return {}; + } + return d_subs[std::distance(d_vars.begin(), it)]; +} + void Subs::add(Node v) { SkolemManager* sm = NodeManager::currentNM()->getSkolemManager(); @@ -62,7 +72,7 @@ void Subs::add(const std::vector& vs) void Subs::add(Node v, Node s) { - Assert(v.getType().isComparableTo(s.getType())); + Assert(s.isNull() || v.getType().isComparableTo(s.getType())); d_vars.push_back(v); d_subs.push_back(s); } diff --git a/src/expr/subs.h b/src/expr/subs.h index afde63b6e..245c6d77a 100644 --- a/src/expr/subs.h +++ b/src/expr/subs.h @@ -17,7 +17,9 @@ #define CVC5__EXPR__SUBS_H #include +#include #include + #include "expr/node.h" namespace cvc5 { @@ -41,6 +43,8 @@ class Subs bool contains(Node v) const; /** Get the substitution for v if it exists, or null otherwise */ Node getSubs(Node v) const; + /** Find the substitution for v, or return std::nullopt */ + std::optional find(TNode v) const; /** Add v -> k for fresh skolem of the same type as v */ void add(Node v); /** Add v -> k for fresh skolem of the same type as v for each v in vs */ diff --git a/src/theory/arith/arith_utilities.cpp b/src/theory/arith/arith_utilities.cpp index 75edc49f5..5645542d0 100644 --- a/src/theory/arith/arith_utilities.cpp +++ b/src/theory/arith/arith_utilities.cpp @@ -203,31 +203,26 @@ void printRationalApprox(const char* c, Node cr, unsigned prec) } } -Node arithSubstitute(Node n, std::vector& vars, std::vector& subs) +Node arithSubstitute(Node n, const Subs& sub) { - Assert(vars.size() == subs.size()); NodeManager* nm = NodeManager::currentNM(); std::unordered_map visited; - std::unordered_map::iterator it; - std::vector::iterator itv; std::vector visit; - TNode cur; - Kind ck; visit.push_back(n); do { - cur = visit.back(); + TNode cur = visit.back(); visit.pop_back(); - it = visited.find(cur); + auto it = visited.find(cur); if (it == visited.end()) { visited[cur] = Node::null(); - ck = cur.getKind(); - itv = std::find(vars.begin(), vars.end(), cur); - if (itv != vars.end()) + Kind ck = cur.getKind(); + auto s = sub.find(cur); + if (s) { - visited[cur] = subs[std::distance(vars.begin(), itv)]; + visited[cur] = *s; } else if (cur.getNumChildren() == 0) { diff --git a/src/theory/arith/arith_utilities.h b/src/theory/arith/arith_utilities.h index b842ae58e..0d7f214d7 100644 --- a/src/theory/arith/arith_utilities.h +++ b/src/theory/arith/arith_utilities.h @@ -24,6 +24,7 @@ #include "context/cdhashset.h" #include "expr/node.h" +#include "expr/subs.h" #include "theory/arith/arithvar.h" #include "util/dense_map.h" #include "util/integer.h" @@ -313,13 +314,13 @@ void printRationalApprox(const char* c, Node cr, unsigned prec = 5); /** Arithmetic substitute * - * This computes the substitution n { vars -> subs }, but with the caveat + * This computes the substitution n { subs }, but with the caveat * that subterms of n that belong to a theory other than arithmetic are * not traversed. In other words, terms that belong to other theories are * treated as atomic variables. For example: * (5*f(x) + 7*x ){ x -> 3 } returns 5*f(x) + 7*3. */ -Node arithSubstitute(Node n, std::vector& vars, std::vector& subs); +Node arithSubstitute(Node n, const Subs& sub); /** Make the node u >= a ^ a >= l */ Node mkBounded(Node l, Node a, Node u); diff --git a/src/theory/arith/nl/cad_solver.cpp b/src/theory/arith/nl/cad_solver.cpp index ebaeb9d61..132cb9795 100644 --- a/src/theory/arith/nl/cad_solver.cpp +++ b/src/theory/arith/nl/cad_solver.cpp @@ -180,11 +180,11 @@ bool CadSolver::constructModelIfAvailable(std::vector& assertions) Node value = value_to_node(d_CAC.getModel().get(v), d_ranVariable); if (value.isConst()) { - d_model.addCheckModelSubstitution(variable, value); + d_model.addSubstitution(variable, value); } else { - d_model.addCheckModelWitness(variable, value); + d_model.addWitness(variable, value); } Trace("nl-cad") << "-> " << v << " = " << value << std::endl; } diff --git a/src/theory/arith/nl/ext/monomial_check.cpp b/src/theory/arith/nl/ext/monomial_check.cpp index 330cd57a3..b077dcfd0 100644 --- a/src/theory/arith/nl/ext/monomial_check.cpp +++ b/src/theory/arith/nl/ext/monomial_check.cpp @@ -402,7 +402,7 @@ bool MonomialCheck::compareMonomial( if (a_index == vla.size() && b_index == vlb.size()) { // finished, compare absolute value of abstract model values - int modelStatus = d_data->d_model.compare(oa, ob, false, true) * -2; + int modelStatus = d_data->d_model.compare(oa, ob, false, true) * 2; Trace("nl-ext-comp") << "...finished comparison with " << oa << " <" << status << "> " << ob << ", model status = " << modelStatus << std::endl; @@ -677,7 +677,7 @@ void MonomialCheck::assignOrderIds(std::vector& vars, { Node vv = d_data->d_model.computeModelValue( d_order_points[order_index], isConcrete); - if (d_data->d_model.compareValue(v, vv, isAbsolute) <= 0) + if (d_data->d_model.compareValue(v, vv, isAbsolute) >= 0) { counter++; Trace("nl-ext-mvo") << "O[" << d_order_points[order_index] diff --git a/src/theory/arith/nl/nl_lemma_utils.cpp b/src/theory/arith/nl/nl_lemma_utils.cpp index 3e2ebe87e..18e296da7 100644 --- a/src/theory/arith/nl/nl_lemma_utils.cpp +++ b/src/theory/arith/nl/nl_lemma_utils.cpp @@ -45,7 +45,7 @@ bool SortNlModel::operator()(Node i, Node j) { return i < j; } - return d_reverse_order ? cv < 0 : cv > 0; + return d_reverse_order ? cv > 0 : cv < 0; } bool SortNonlinearDegree::operator()(Node i, Node j) diff --git a/src/theory/arith/nl/nl_model.cpp b/src/theory/arith/nl/nl_model.cpp index ca75a1a06..427d203ea 100644 --- a/src/theory/arith/nl/nl_model.cpp +++ b/src/theory/arith/nl/nl_model.cpp @@ -43,18 +43,12 @@ NlModel::NlModel() : d_used_approx(false) NlModel::~NlModel() {} -void NlModel::reset(TheoryModel* m, std::map& arithModel) +void NlModel::reset(TheoryModel* m, const std::map& arithModel) { d_model = m; - d_mv[0].clear(); - d_mv[1].clear(); - d_arithVal.clear(); - // process arithModel - std::map::iterator it; - for (const std::pair& m2 : arithModel) - { - d_arithVal[m2.first] = m2.second; - } + d_concreteModelCache.clear(); + d_abstractModelCache.clear(); + d_arithVal = arithModel; } void NlModel::resetCheck() @@ -63,46 +57,42 @@ void NlModel::resetCheck() d_check_model_solved.clear(); d_check_model_bounds.clear(); d_check_model_witnesses.clear(); - d_check_model_vars.clear(); - d_check_model_subs.clear(); + d_substitutions.clear(); } -Node NlModel::computeConcreteModelValue(Node n) +Node NlModel::computeConcreteModelValue(TNode n) { return computeModelValue(n, true); } -Node NlModel::computeAbstractModelValue(Node n) +Node NlModel::computeAbstractModelValue(TNode n) { return computeModelValue(n, false); } -Node NlModel::computeModelValue(Node n, bool isConcrete) +Node NlModel::computeModelValue(TNode n, bool isConcrete) { - unsigned index = isConcrete ? 0 : 1; - std::map::iterator it = d_mv[index].find(n); - if (it != d_mv[index].end()) + auto& cache = isConcrete ? d_concreteModelCache : d_abstractModelCache; + if (auto it = cache.find(n); it != cache.end()) { return it->second; } - Trace("nl-ext-mv-debug") << "computeModelValue " << n << ", index=" << index - << std::endl; + Trace("nl-ext-mv-debug") << "computeModelValue " << n + << ", isConcrete=" << isConcrete << std::endl; Node ret; - Kind nk = n.getKind(); if (n.isConst()) { ret = n; } - else if (!isConcrete && hasTerm(n)) + else if (!isConcrete && hasLinearModelValue(n, ret)) { // use model value for abstraction - ret = getRepresentative(n); } else if (n.getNumChildren() == 0) { // we are interested in the exact value of PI, which cannot be computed. // hence, we return PI itself when asked for the concrete value. - if (nk == PI) + if (n.getKind() == PI) { ret = n; } @@ -114,7 +104,7 @@ Node NlModel::computeModelValue(Node n, bool isConcrete) else { // otherwise, compute true value - TheoryId ctid = theory::kindToTheoryId(nk); + TheoryId ctid = theory::kindToTheoryId(n.getKind()); if (ctid != THEORY_ARITH && ctid != THEORY_BOOL && ctid != THEORY_BUILTIN) { // we directly look up terms not belonging to arithmetic @@ -125,65 +115,28 @@ Node NlModel::computeModelValue(Node n, bool isConcrete) std::vector children; if (n.getMetaKind() == metakind::PARAMETERIZED) { - children.push_back(n.getOperator()); + children.emplace_back(n.getOperator()); } - for (unsigned i = 0, nchild = n.getNumChildren(); i < nchild; i++) + for (size_t i = 0, nchild = n.getNumChildren(); i < nchild; i++) { - Node mc = computeModelValue(n[i], isConcrete); - children.push_back(mc); + children.emplace_back(computeModelValue(n[i], isConcrete)); } - ret = NodeManager::currentNM()->mkNode(nk, children); + ret = NodeManager::currentNM()->mkNode(n.getKind(), children); ret = Rewriter::rewrite(ret); } } - Trace("nl-ext-mv-debug") << "computed " << (index == 0 ? "M" : "M_A") << "[" + Trace("nl-ext-mv-debug") << "computed " << (isConcrete ? "M" : "M_A") << "[" << n << "] = " << ret << std::endl; - d_mv[index][n] = ret; + cache[n] = ret; return ret; } -bool NlModel::hasTerm(Node n) const +int NlModel::compare(TNode i, TNode j, bool isConcrete, bool isAbsolute) { - return d_arithVal.find(n) != d_arithVal.end(); -} - -Node NlModel::getRepresentative(Node n) const -{ - if (n.isConst()) - { - return n; - } - std::map::const_iterator it = d_arithVal.find(n); - if (it != d_arithVal.end()) - { - AlwaysAssert(it->second.isConst()); - return it->second; - } - return d_model->getRepresentative(n); -} - -Node NlModel::getValueInternal(Node n) -{ - if (n.isConst()) + if (i == j) { - return n; + return 0; } - std::map::const_iterator it = d_arithVal.find(n); - if (it != d_arithVal.end()) - { - AlwaysAssert(it->second.isConst()); - return it->second; - } - // It is unconstrained in the model, return 0. We additionally add it - // to mapping from the linear solver. This ensures that if the nonlinear - // solver assumes that n = 0, then this assumption is recorded in the overall - // model. - d_arithVal[n] = d_zero; - return d_zero; -} - -int NlModel::compare(Node i, Node j, bool isConcrete, bool isAbsolute) -{ Node ci = computeModelValue(i, isConcrete); Node cj = computeModelValue(j, isConcrete); if (ci.isConst()) @@ -197,27 +150,24 @@ int NlModel::compare(Node i, Node j, bool isConcrete, bool isAbsolute) return cj.isConst() ? -1 : 0; } -int NlModel::compareValue(Node i, Node j, bool isAbsolute) const +int NlModel::compareValue(TNode i, TNode j, bool isAbsolute) const { Assert(i.isConst() && j.isConst()); - int ret; if (i == j) { - ret = 0; + return 0; } - else if (!isAbsolute) + if (!isAbsolute) { - ret = i.getConst() < j.getConst() ? 1 : -1; + return i.getConst() < j.getConst() ? -1 : 1; } - else + Rational iabs = i.getConst().abs(); + Rational jabs = j.getConst().abs(); + if (iabs == jabs) { - ret = (i.getConst().abs() == j.getConst().abs() - ? 0 - : (i.getConst().abs() < j.getConst().abs() - ? 1 - : -1)); + return 0; } - return ret; + return iabs < jabs ? -1 : 1; } bool NlModel::checkModel(const std::vector& assertions, @@ -262,7 +212,7 @@ bool NlModel::checkModel(const std::vector& assertions, && !isTranscendentalKind(k)) { // if we have not set an approximate bound for it - if (!hasCheckModelAssignment(cur)) + if (!hasAssignment(cur)) { // set its exact model value in the substitution Node curv = computeConcreteModelValue(cur); @@ -273,7 +223,7 @@ bool NlModel::checkModel(const std::vector& assertions, printRationalApprox("nl-ext-cm", curv); Trace("nl-ext-cm") << std::endl; } - bool ret = addCheckModelSubstitution(cur, curv); + bool ret = addSubstitution(cur, curv); AlwaysAssert(ret); } } @@ -294,10 +244,9 @@ bool NlModel::checkModel(const std::vector& assertions, { Node av = a; // apply the substitution to a - if (!d_check_model_vars.empty()) + if (!d_substitutions.empty()) { - av = arithSubstitute(av, d_check_model_vars, d_check_model_subs); - av = Rewriter::rewrite(av); + av = Rewriter::rewrite(arithSubstitute(av, d_substitutions)); } // simple check literal if (!simpleCheckModelLit(av)) @@ -321,14 +270,13 @@ bool NlModel::checkModel(const std::vector& assertions, return true; } -bool NlModel::addCheckModelSubstitution(TNode v, TNode s) +bool NlModel::addSubstitution(TNode v, TNode s) { // should not substitute the same variable twice Trace("nl-ext-model") << "* check model substitution : " << v << " -> " << s << std::endl; // should not set exact bound more than once - if (std::find(d_check_model_vars.begin(), d_check_model_vars.end(), v) - != d_check_model_vars.end()) + if (d_substitutions.contains(v)) { Trace("nl-ext-model") << "...ERROR: already has value." << std::endl; // this should never happen since substitutions should be applied eagerly @@ -352,37 +300,31 @@ bool NlModel::addCheckModelSubstitution(TNode v, TNode s) Assert(d_check_model_witnesses.find(v) == d_check_model_witnesses.end()) << "We tried to add a substitution where we already had a witness term." << std::endl; - std::vector varsTmp; - varsTmp.push_back(v); - std::vector subsTmp; - subsTmp.push_back(s); - for (unsigned i = 0, size = d_check_model_subs.size(); i < size; i++) + Subs tmp; + tmp.add(v, s); + for (auto& sub : d_substitutions.d_subs) { - Node ms = d_check_model_subs[i]; - Node mss = arithSubstitute(ms, varsTmp, subsTmp); - if (mss != ms) + Node ms = arithSubstitute(sub, tmp); + if (ms != sub) { - mss = Rewriter::rewrite(mss); + sub = Rewriter::rewrite(ms); } - d_check_model_subs[i] = mss; } - d_check_model_vars.push_back(v); - d_check_model_subs.push_back(s); + d_substitutions.add(v, s); return true; } -bool NlModel::addCheckModelBound(TNode v, TNode l, TNode u) +bool NlModel::addBound(TNode v, TNode l, TNode u) { Trace("nl-ext-model") << "* check model bound : " << v << " -> [" << l << " " << u << "]" << std::endl; if (l == u) { // bound is exact, can add as substitution - return addCheckModelSubstitution(v, l); + return addSubstitution(v, l); } // should not set a bound for a value that is exact - if (std::find(d_check_model_vars.begin(), d_check_model_vars.end(), v) - != d_check_model_vars.end()) + if (d_substitutions.contains(v)) { Trace("nl-ext-model") << "...ERROR: setting bound for variable that already has exact value." @@ -405,13 +347,12 @@ bool NlModel::addCheckModelBound(TNode v, TNode l, TNode u) return true; } -bool NlModel::addCheckModelWitness(TNode v, TNode w) +bool NlModel::addWitness(TNode v, TNode w) { Trace("nl-ext-model") << "* check model witness : " << v << " -> " << w << std::endl; // should not set a witness for a value that is already set - if (std::find(d_check_model_vars.begin(), d_check_model_vars.end(), v) - != d_check_model_vars.end()) + if (d_substitutions.contains(v)) { Trace("nl-ext-model") << "...ERROR: setting witness for variable that " "already has a constant value." @@ -423,20 +364,6 @@ bool NlModel::addCheckModelWitness(TNode v, TNode w) return true; } -bool NlModel::hasCheckModelAssignment(Node v) const -{ - if (d_check_model_bounds.find(v) != d_check_model_bounds.end()) - { - return true; - } - if (d_check_model_witnesses.find(v) != d_check_model_witnesses.end()) - { - return true; - } - return std::find(d_check_model_vars.begin(), d_check_model_vars.end(), v) - != d_check_model_vars.end(); -} - void NlModel::setUsedApproximate() { d_used_approx = true; } bool NlModel::usedApproximate() const { return d_used_approx; } @@ -446,9 +373,9 @@ bool NlModel::solveEqualitySimple(Node eq, std::vector& lemmas) { Node seq = eq; - if (!d_check_model_vars.empty()) + if (!d_substitutions.empty()) { - seq = arithSubstitute(eq, d_check_model_vars, d_check_model_subs); + seq = arithSubstitute(eq, d_substitutions); seq = Rewriter::rewrite(seq); if (seq.isConst()) { @@ -545,7 +472,7 @@ bool NlModel::solveEqualitySimple(Node eq, { Trace("nl-ext-cm-debug") << "check subs var : " << uv << std::endl; // cannot already have a bound - if (uv.isVar() && !hasCheckModelAssignment(uv)) + if (uv.isVar() && !hasAssignment(uv)) { Node slv; Node veqc; @@ -560,7 +487,7 @@ bool NlModel::solveEqualitySimple(Node eq, { Trace("nl-ext-cm") << "check-model-subs : " << uv << " -> " << slv << std::endl; - bool ret = addCheckModelSubstitution(uv, slv); + bool ret = addSubstitution(uv, slv); if (ret) { Trace("nl-ext-cms") << "...success, model substitution " << uv @@ -577,7 +504,7 @@ bool NlModel::solveEqualitySimple(Node eq, { Trace("nl-ext-cm-debug") << "check set var : " << uvf << std::endl; // cannot already have a bound - if (uvf.isVar() && !hasCheckModelAssignment(uvf)) + if (uvf.isVar() && !hasAssignment(uvf)) { Node uvfv = computeConcreteModelValue(uvf); if (Trace.isOn("nl-ext-cm")) @@ -586,7 +513,7 @@ bool NlModel::solveEqualitySimple(Node eq, printRationalApprox("nl-ext-cm", uvfv); Trace("nl-ext-cm") << std::endl; } - bool ret = addCheckModelSubstitution(uvf, uvfv); + bool ret = addSubstitution(uvf, uvfv); // recurse return ret ? solveEqualitySimple(eq, d, lemmas) : false; } @@ -618,7 +545,7 @@ bool NlModel::solveEqualitySimple(Node eq, printRationalApprox("nl-ext-cm", val); Trace("nl-ext-cm") << std::endl; } - bool ret = addCheckModelSubstitution(var, val); + bool ret = addSubstitution(var, val); if (ret) { Trace("nl-ext-cms") << "...success, solved linear." << std::endl; @@ -647,7 +574,7 @@ bool NlModel::solveEqualitySimple(Node eq, Trace("nl-ext-cms") << "...fail due to negative discriminant." << std::endl; return false; } - if (hasCheckModelAssignment(var)) + if (hasAssignment(var)) { Trace("nl-ext-cms") << "...fail due to bounds on variable to solve for." << std::endl; @@ -730,8 +657,7 @@ bool NlModel::solveEqualitySimple(Node eq, printRationalApprox("nl-ext-cm", bounds[r_use_index][1]); Trace("nl-ext-cm") << std::endl; } - bool ret = - addCheckModelBound(var, bounds[r_use_index][0], bounds[r_use_index][1]); + bool ret = addBound(var, bounds[r_use_index][0], bounds[r_use_index][1]); if (ret) { d_check_model_solved[eq] = var; @@ -829,8 +755,7 @@ bool NlModel::simpleCheckModelLit(Node lit) ? vs_invalid[0] : nm->mkNode(PLUS, vs_invalid)); // substitution to try - std::vector qvars; - std::vector qsubs; + Subs qsub; for (const Node& v : vs) { // is it a valid variable? @@ -882,7 +807,7 @@ bool NlModel::simpleCheckModelLit(Node lit) Assert(boundn[0].getConst() <= boundn[1].getConst()); Node s; - qvars.push_back(v); + qsub.add(v, Node()); if (cmp[0] != cmp[1]) { Assert(!cmp[0] && cmp[1]); @@ -899,10 +824,9 @@ bool NlModel::simpleCheckModelLit(Node lit) Node tcmpn[2]; for (unsigned r = 0; r < 2; r++) { - qsubs.push_back(boundn[r]); - Node ts = arithSubstitute(t, qvars, qsubs); + qsub.d_subs.back() = boundn[r]; + Node ts = arithSubstitute(t, qsub); tcmpn[r] = Rewriter::rewrite(ts); - qsubs.pop_back(); } Node tcmp = nm->mkNode(LT, tcmpn[0], tcmpn[1]); Trace("nl-ext-cms-debug") @@ -932,16 +856,15 @@ bool NlModel::simpleCheckModelLit(Node lit) s = boundn[bindex_use]; } Assert(!s.isNull()); - qsubs.push_back(s); + qsub.d_subs.back() = s; Trace("nl-ext-cms") << "* set bound based on quadratic : " << v << " -> " << s << std::endl; } } } - if (!qvars.empty()) + if (!qsub.empty()) { - Assert(qvars.size() == qsubs.size()); - Node slit = arithSubstitute(lit, qvars, qsubs); + Node slit = arithSubstitute(lit, qsub); slit = Rewriter::rewrite(slit); return simpleCheckModelLit(slit); } @@ -1242,21 +1165,26 @@ void NlModel::printModelValue(const char* c, Node n, unsigned prec) const if (Trace.isOn(c)) { Trace(c) << " " << n << " -> "; - for (int i = 1; i >= 0; --i) + const Node& aval = d_abstractModelCache.at(n); + if (aval.isConst()) { - std::map::const_iterator it = d_mv[i].find(n); - Assert(it != d_mv[i].end()); - if (it->second.isConst()) - { - printRationalApprox(c, it->second, prec); - } - else - { - Trace(c) << "?"; - } - Trace(c) << (i == 1 ? " [actual: " : " ]"); + printRationalApprox(c, aval, prec); + } + else + { + Trace(c) << "?"; + } + Trace(c) << " [actual: "; + const Node& cval = d_concreteModelCache.at(n); + if (cval.isConst()) + { + printRationalApprox(c, cval, prec); + } + else + { + Trace(c) << "?"; } - Trace(c) << std::endl; + Trace(c) << " ]" << std::endl; } } @@ -1316,13 +1244,12 @@ void NlModel::getModelValueRepair( // special kind approximation of the form (witness x. x = exact_value). // Notice that the above term gets rewritten such that the choice function // is eliminated. - for (size_t i = 0, num = d_check_model_vars.size(); i < num; i++) + for (size_t i = 0; i < d_substitutions.size(); ++i) { - Node v = d_check_model_vars[i]; - Node s = d_check_model_subs[i]; // overwrite - arithModel[v] = s; - Trace("nl-model") << v << " solved is " << s << std::endl; + arithModel[d_substitutions.d_vars[i]] = d_substitutions.d_subs[i]; + Trace("nl-model") << d_substitutions.d_vars[i] << " solved is " + << d_substitutions.d_subs[i] << std::endl; } // multiplication terms should not be given values; their values are @@ -1341,6 +1268,49 @@ void NlModel::getModelValueRepair( } } +Node NlModel::getValueInternal(TNode n) +{ + if (n.isConst()) + { + return n; + } + if (auto it = d_arithVal.find(n); it != d_arithVal.end()) + { + AlwaysAssert(it->second.isConst()); + return it->second; + } + // It is unconstrained in the model, return 0. We additionally add it + // to mapping from the linear solver. This ensures that if the nonlinear + // solver assumes that n = 0, then this assumption is recorded in the overall + // model. + d_arithVal[n] = d_zero; + return d_zero; +} + +bool NlModel::hasAssignment(Node v) const +{ + if (d_check_model_bounds.find(v) != d_check_model_bounds.end()) + { + return true; + } + if (d_check_model_witnesses.find(v) != d_check_model_witnesses.end()) + { + return true; + } + return (d_substitutions.contains(v)); +} + +bool NlModel::hasLinearModelValue(TNode v, Node& val) const +{ + auto it = d_arithVal.find(v); + if (it != d_arithVal.end()) + { + val = it->second; + return true; + } + return false; +} + } // namespace nl } // namespace arith } // namespace theory diff --git a/src/theory/arith/nl/nl_model.h b/src/theory/arith/nl/nl_model.h index 526a93934..b3b841eab 100644 --- a/src/theory/arith/nl/nl_model.h +++ b/src/theory/arith/nl/nl_model.h @@ -22,6 +22,7 @@ #include "expr/kind.h" #include "expr/node.h" +#include "expr/subs.h" namespace cvc5 { @@ -59,7 +60,7 @@ class NlModel * where m is the model of the theory of arithmetic. This method resets the * cache of computed model values. */ - void reset(TheoryModel* m, std::map& arithModel); + void reset(TheoryModel* m, const std::map& arithModel); /** * This method is called when the non-linear arithmetic solver restarts * its computation of lemmas and models during a last call effort check. @@ -87,9 +88,9 @@ class NlModel * whereas: * computeModelValue( a*b, false ) = 5 */ - Node computeConcreteModelValue(Node n); - Node computeAbstractModelValue(Node n); - Node computeModelValue(Node n, bool isConcrete); + Node computeConcreteModelValue(TNode n); + Node computeAbstractModelValue(TNode n); + Node computeModelValue(TNode n, bool isConcrete); /** * Compare arithmetic terms i and j based an ordering. @@ -101,10 +102,10 @@ class NlModel * otherwise, we consider their abstract model values. For definitions of * concrete vs abstract model values, see NlModel::computeModelValue. * - * If isAbsolute is true, we compare the absolute value of thee above + * If isAbsolute is true, we compare the absolute value of the above * values. */ - int compare(Node i, Node j, bool isConcrete, bool isAbsolute); + int compare(TNode i, TNode j, bool isConcrete, bool isAbsolute); /** * Compare arithmetic terms i and j based an ordering. * @@ -113,38 +114,31 @@ class NlModel * * If isAbsolute is true, we compare the absolute value of i and j */ - int compareValue(Node i, Node j, bool isAbsolute) const; + int compareValue(TNode i, TNode j, bool isAbsolute) const; //------------------------------ recording model substitutions and bounds /** * Adds the model substitution v -> s. This applies the substitution - * { v -> s } to each term in d_check_model_subs and adds v,s to - * d_check_model_vars and d_check_model_subs respectively. + * { v -> s } to each term in d_substitutions and then adds v,s to + * d_substitutions. * If this method returns false, then the substitution v -> s is inconsistent * with the current substitution and bounds. */ - bool addCheckModelSubstitution(TNode v, TNode s); + bool addSubstitution(TNode v, TNode s); /** * Adds the bound x -> < l, u > to the map above, and records the * approximation ( x, l <= x <= u ) in the model. This method returns false * if the bound is inconsistent with the current model substitution or * bounds. */ - bool addCheckModelBound(TNode v, TNode l, TNode u); + bool addBound(TNode v, TNode l, TNode u); /** * Adds a model witness v -> w to the underlying theory model. * The witness should only contain a single variable v and evaluate to true * for exactly one value of v. The variable v is then (implicitly, * declaratively) assigned to this single value that satisfies the witness w. */ - bool addCheckModelWitness(TNode v, TNode w); - /** - * Have we assigned v in the current checkModel(...) call? - * - * This method returns true if variable v is in the domain of - * d_check_model_bounds or if it occurs in d_check_model_vars. - */ - bool hasCheckModelAssignment(Node v) const; + bool addWitness(TNode v, TNode w); /** * Checks the current model based on solving for equalities, and using error * bounds on the Taylor approximation. @@ -198,21 +192,53 @@ class NlModel bool witnessToValue); private: + /** Cache for concrete model values */ + std::map d_concreteModelCache; + /** Cache for abstract model values */ + std::map d_abstractModelCache; + /** The current model */ TheoryModel* d_model; + + /** + * The values that the arithmetic theory solver assigned in the model. This + * corresponds to the set of equalities that linear solver (via TheoryArith) + * is currently sending to TheoryModel during collectModelValues, plus + * additional entries x -> 0 for variables that were unassigned by the linear + * solver. + */ + std::map d_arithVal; + + /** + * A substitution from variables that appear in assertions to a solved form + * term. + */ + Subs d_substitutions; + /** Get the model value of n from the model object above */ - Node getValueInternal(Node n); - /** Does the equality engine of the model have term n? */ - bool hasTerm(Node n) const; - /** Get the representative of n in the model */ - Node getRepresentative(Node n) const; + Node getValueInternal(TNode n); + + /** + * Have we assigned v in the current checkModel(...) call? + * + * This method returns true if variable v is in the domain of + * d_check_model_bounds or if it occurs in d_substitutions. + */ + bool hasAssignment(Node v) const; + + /** + * Checks whether we have a linear model value for v, i.e. whether v is + * contained in d_arithVal. If so, we also store the value that v is mapped + * to in val. + */ + bool hasLinearModelValue(TNode v, Node& val) const; //---------------------------check model /** * This method is used during checkModel(...). It takes as input an * equality eq. If it returns true, then eq is correct-by-construction based * on the information stored in our model representation (see - * d_check_model_vars, d_check_model_subs, d_check_model_bounds), and eq + * d_substitutions, d_check_model_bounds), and eq * is added to d_check_model_solved. The equality eq may involve any * number of variables, and monomials of arbitrary degree. If this method * returns false, then we did not show that the equality was true in the @@ -267,29 +293,6 @@ class NlModel Node d_true; Node d_false; Node d_null; - /** - * The values that the arithmetic theory solver assigned in the model. This - * corresponds to the set of equalities that linear solver (via TheoryArith) - * is currently sending to TheoryModel during collectModelValues, plus - * additional entries x -> 0 for variables that were unassigned by the linear - * solver. - */ - std::map d_arithVal; - /** - * cache of model values - * - * Stores the the concrete/abstract model values. This is a cache of the - * computeModelValue method. - */ - std::map d_mv[2]; - /** - * A substitution from variables that appear in assertions to a solved form - * term. These vectors are ordered in the form: - * x_1 -> t_1 ... x_n -> t_n - * where x_i is not in the free variables of t_j for j>=i. - */ - std::vector d_check_model_vars; - std::vector d_check_model_subs; /** * lower and upper bounds for check model * diff --git a/src/theory/arith/nl/nonlinear_extension.cpp b/src/theory/arith/nl/nonlinear_extension.cpp index f80717b57..207907fcc 100644 --- a/src/theory/arith/nl/nonlinear_extension.cpp +++ b/src/theory/arith/nl/nonlinear_extension.cpp @@ -96,17 +96,16 @@ void NonlinearExtension::processSideEffect(const NlLemma& se) void NonlinearExtension::computeRelevantAssertions( const std::vector& assertions, std::vector& keep) { - Trace("nl-ext-rlv") << "Compute relevant assertions..." << std::endl; - Valuation v = d_containing.getValuation(); + const Valuation& v = d_containing.getValuation(); for (const Node& a : assertions) { if (v.isRelevant(a)) { - keep.push_back(a); + keep.emplace_back(a); } } - Trace("nl-ext-rlv") << "...keep " << keep.size() << "/" << assertions.size() - << " assertions" << std::endl; + Trace("nl-ext-rlv") << "...relevant assertions: " << keep.size() << "/" + << assertions.size() << std::endl; } void NonlinearExtension::getAssertions(std::vector& assertions) diff --git a/src/theory/arith/nl/poly_conversion.h b/src/theory/arith/nl/poly_conversion.h index db64320d5..dddde3c0f 100644 --- a/src/theory/arith/nl/poly_conversion.h +++ b/src/theory/arith/nl/poly_conversion.h @@ -98,8 +98,8 @@ std::pair as_poly_constraint( /** * Transforms a real algebraic number to a node suitable for putting it into a * model. The resulting node can be either a constant (suitable for - * addCheckModelSubstitution) or a witness term (suitable for - * addCheckModelWitness). + * addSubstitution) or a witness term (suitable for + * addWitness). */ Node ran_to_node(const RealAlgebraicNumber& ran, const Node& ran_variable); diff --git a/src/theory/arith/nl/transcendental/transcendental_solver.cpp b/src/theory/arith/nl/transcendental/transcendental_solver.cpp index c7bb14b3f..978823a22 100644 --- a/src/theory/arith/nl/transcendental/transcendental_solver.cpp +++ b/src/theory/arith/nl/transcendental/transcendental_solver.cpp @@ -83,12 +83,10 @@ void TranscendentalSolver::initLastCall(const std::vector& xts) bool TranscendentalSolver::preprocessAssertionsCheckModel( std::vector& assertions) { - std::vector pvars; - std::vector psubs; - for (const std::pair& tb : d_tstate.d_trMaster) + Subs subs; + for (const auto& sub : d_tstate.d_trMaster) { - pvars.push_back(tb.first); - psubs.push_back(tb.second); + subs.add(sub.first, sub.second); } // initialize representation of assertions @@ -97,9 +95,9 @@ bool TranscendentalSolver::preprocessAssertionsCheckModel( { Node pa = a; - if (!pvars.empty()) + if (!subs.empty()) { - pa = arithSubstitute(pa, pvars, psubs); + pa = arithSubstitute(pa, subs); pa = Rewriter::rewrite(pa); } if (!pa.isConst() || !pa.getConst()) @@ -145,8 +143,8 @@ bool TranscendentalSolver::preprocessAssertionsCheckModel( Trace("nl-ext-cm") << "...bound for " << stf << " : [" << bounds.first << ", " << bounds.second << "]" << std::endl; - success = d_tstate.d_model.addCheckModelBound( - stf, bounds.first, bounds.second); + success = + d_tstate.d_model.addBound(stf, bounds.first, bounds.second); } } } -- 2.30.2