First round of refactoring on NlModel (#7255)
authorGereon Kremer <nafur42@gmail.com>
Tue, 5 Oct 2021 20:06:53 +0000 (13:06 -0700)
committerGitHub <noreply@github.com>
Tue, 5 Oct 2021 20:06:53 +0000 (20:06 +0000)
This PR performs a first refactoring on the NlModel class. It improves model value computation, comparison and stores the model substitutions in a map (instead of two vectors).

12 files changed:
src/expr/subs.cpp
src/expr/subs.h
src/theory/arith/arith_utilities.cpp
src/theory/arith/arith_utilities.h
src/theory/arith/nl/cad_solver.cpp
src/theory/arith/nl/ext/monomial_check.cpp
src/theory/arith/nl/nl_lemma_utils.cpp
src/theory/arith/nl/nl_model.cpp
src/theory/arith/nl/nl_model.h
src/theory/arith/nl/nonlinear_extension.cpp
src/theory/arith/nl/poly_conversion.h
src/theory/arith/nl/transcendental/transcendental_solver.cpp

index b140a41909e5160567c76fd6caace8c011bf744f..f08cf18c1b7d2dcec5c2eb9c5ef20f792d418e7d 100644 (file)
@@ -44,6 +44,16 @@ Node Subs::getSubs(Node v) const
   return d_subs[i];
 }
 
+std::optional<Node> Subs::find(TNode v) const
+{
+  auto it = std::find(d_vars.begin(), d_vars.end(), v);
+  if (it == d_vars.end())
+  {
+    return {};
+  }
+  return d_subs[std::distance(d_vars.begin(), it)];
+}
+
 void Subs::add(Node v)
 {
   SkolemManager* sm = NodeManager::currentNM()->getSkolemManager();
@@ -62,7 +72,7 @@ void Subs::add(const std::vector<Node>& vs)
 
 void Subs::add(Node v, Node s)
 {
-  Assert(v.getType().isComparableTo(s.getType()));
+  Assert(s.isNull() || v.getType().isComparableTo(s.getType()));
   d_vars.push_back(v);
   d_subs.push_back(s);
 }
index afde63b6e8e4bbab4d775a6ca1529cb903ee5fa7..245c6d77a3eb131f849edd1f3aa370ebb330b093 100644 (file)
@@ -17,7 +17,9 @@
 #define CVC5__EXPR__SUBS_H
 
 #include <map>
+#include <optional>
 #include <vector>
+
 #include "expr/node.h"
 
 namespace cvc5 {
@@ -41,6 +43,8 @@ class Subs
   bool contains(Node v) const;
   /** Get the substitution for v if it exists, or null otherwise */
   Node getSubs(Node v) const;
+  /** Find the substitution for v, or return std::nullopt */
+  std::optional<Node> find(TNode v) const;
   /** Add v -> k for fresh skolem of the same type as v */
   void add(Node v);
   /** Add v -> k for fresh skolem of the same type as v for each v in vs */
index 75edc49f59eaf05dba6ca754b71dd78844ae1302..5645542d0323cc86bf69f48f28054c366626b8bc 100644 (file)
@@ -203,31 +203,26 @@ void printRationalApprox(const char* c, Node cr, unsigned prec)
   }
 }
 
-Node arithSubstitute(Node n, std::vector<Node>& vars, std::vector<Node>& subs)
+Node arithSubstitute(Node n, const Subs& sub)
 {
-  Assert(vars.size() == subs.size());
   NodeManager* nm = NodeManager::currentNM();
   std::unordered_map<TNode, Node> visited;
-  std::unordered_map<TNode, Node>::iterator it;
-  std::vector<Node>::iterator itv;
   std::vector<TNode> visit;
-  TNode cur;
-  Kind ck;
   visit.push_back(n);
   do
   {
-    cur = visit.back();
+    TNode cur = visit.back();
     visit.pop_back();
-    it = visited.find(cur);
+    auto it = visited.find(cur);
 
     if (it == visited.end())
     {
       visited[cur] = Node::null();
-      ck = cur.getKind();
-      itv = std::find(vars.begin(), vars.end(), cur);
-      if (itv != vars.end())
+      Kind ck = cur.getKind();
+      auto s = sub.find(cur);
+      if (s)
       {
-        visited[cur] = subs[std::distance(vars.begin(), itv)];
+        visited[cur] = *s;
       }
       else if (cur.getNumChildren() == 0)
       {
index b842ae58e2b29b6a156fc072278ab64023c45003..0d7f214d7ea2216cbe2cafc583b7160b91a15ba1 100644 (file)
@@ -24,6 +24,7 @@
 
 #include "context/cdhashset.h"
 #include "expr/node.h"
+#include "expr/subs.h"
 #include "theory/arith/arithvar.h"
 #include "util/dense_map.h"
 #include "util/integer.h"
@@ -313,13 +314,13 @@ void printRationalApprox(const char* c, Node cr, unsigned prec = 5);
 
 /** Arithmetic substitute
  *
- * This computes the substitution n { vars -> subs }, but with the caveat
+ * This computes the substitution n { subs }, but with the caveat
  * that subterms of n that belong to a theory other than arithmetic are
  * not traversed. In other words, terms that belong to other theories are
  * treated as atomic variables. For example:
  *   (5*f(x) + 7*x ){ x -> 3 } returns 5*f(x) + 7*3.
  */
-Node arithSubstitute(Node n, std::vector<Node>& vars, std::vector<Node>& subs);
+Node arithSubstitute(Node n, const Subs& sub);
 
 /** Make the node u >= a ^ a >= l */
 Node mkBounded(Node l, Node a, Node u);
index ebaeb9d61aa7b98a2aa6710a07739b6f25eddc88..132cb9795a257e56636ee5e6a41199fcee379427 100644 (file)
@@ -180,11 +180,11 @@ bool CadSolver::constructModelIfAvailable(std::vector<Node>& assertions)
     Node value = value_to_node(d_CAC.getModel().get(v), d_ranVariable);
     if (value.isConst())
     {
-      d_model.addCheckModelSubstitution(variable, value);
+      d_model.addSubstitution(variable, value);
     }
     else
     {
-      d_model.addCheckModelWitness(variable, value);
+      d_model.addWitness(variable, value);
     }
     Trace("nl-cad") << "-> " << v << " = " << value << std::endl;
   }
index 330cd57a30c054c66679f8ef82d5d4f20e13cef3..b077dcfd0d2893497f64fd82fc4c1e35805a0f8f 100644 (file)
@@ -402,7 +402,7 @@ bool MonomialCheck::compareMonomial(
   if (a_index == vla.size() && b_index == vlb.size())
   {
     // finished, compare absolute value of abstract model values
-    int modelStatus = d_data->d_model.compare(oa, ob, false, true) * -2;
+    int modelStatus = d_data->d_model.compare(oa, ob, false, true) * 2;
     Trace("nl-ext-comp") << "...finished comparison with " << oa << " <"
                          << status << "> " << ob
                          << ", model status = " << modelStatus << std::endl;
@@ -677,7 +677,7 @@ void MonomialCheck::assignOrderIds(std::vector<Node>& vars,
         {
           Node vv = d_data->d_model.computeModelValue(
               d_order_points[order_index], isConcrete);
-          if (d_data->d_model.compareValue(v, vv, isAbsolute) <= 0)
+          if (d_data->d_model.compareValue(v, vv, isAbsolute) >= 0)
           {
             counter++;
             Trace("nl-ext-mvo") << "O[" << d_order_points[order_index]
index 3e2ebe87e57eb2276acb342e89f94a846c9a8afe..18e296da75b63793d44c72e308bbb2b541befcf2 100644 (file)
@@ -45,7 +45,7 @@ bool SortNlModel::operator()(Node i, Node j)
   {
     return i < j;
   }
-  return d_reverse_order ? cv < 0 : cv > 0;
+  return d_reverse_order ? cv > 0 : cv < 0;
 }
 
 bool SortNonlinearDegree::operator()(Node i, Node j)
index ca75a1a064dd0cbd1892ec60780e578ae3f76814..427d203ea68d6e12c38a018e23f8ada99fdb6b31 100644 (file)
@@ -43,18 +43,12 @@ NlModel::NlModel() : d_used_approx(false)
 
 NlModel::~NlModel() {}
 
-void NlModel::reset(TheoryModel* m, std::map<Node, Node>& arithModel)
+void NlModel::reset(TheoryModel* m, const std::map<Node, Node>& arithModel)
 {
   d_model = m;
-  d_mv[0].clear();
-  d_mv[1].clear();
-  d_arithVal.clear();
-  // process arithModel
-  std::map<Node, Node>::iterator it;
-  for (const std::pair<const Node, Node>& m2 : arithModel)
-  {
-    d_arithVal[m2.first] = m2.second;
-  }
+  d_concreteModelCache.clear();
+  d_abstractModelCache.clear();
+  d_arithVal = arithModel;
 }
 
 void NlModel::resetCheck()
@@ -63,46 +57,42 @@ void NlModel::resetCheck()
   d_check_model_solved.clear();
   d_check_model_bounds.clear();
   d_check_model_witnesses.clear();
-  d_check_model_vars.clear();
-  d_check_model_subs.clear();
+  d_substitutions.clear();
 }
 
-Node NlModel::computeConcreteModelValue(Node n)
+Node NlModel::computeConcreteModelValue(TNode n)
 {
   return computeModelValue(n, true);
 }
 
-Node NlModel::computeAbstractModelValue(Node n)
+Node NlModel::computeAbstractModelValue(TNode n)
 {
   return computeModelValue(n, false);
 }
 
-Node NlModel::computeModelValue(Node n, bool isConcrete)
+Node NlModel::computeModelValue(TNode n, bool isConcrete)
 {
-  unsigned index = isConcrete ? 0 : 1;
-  std::map<Node, Node>::iterator it = d_mv[index].find(n);
-  if (it != d_mv[index].end())
+  auto& cache = isConcrete ? d_concreteModelCache : d_abstractModelCache;
+  if (auto it = cache.find(n); it != cache.end())
   {
     return it->second;
   }
-  Trace("nl-ext-mv-debug") << "computeModelValue " << n << ", index=" << index
-                           << std::endl;
+  Trace("nl-ext-mv-debug") << "computeModelValue " << n
+                           << ", isConcrete=" << isConcrete << std::endl;
   Node ret;
-  Kind nk = n.getKind();
   if (n.isConst())
   {
     ret = n;
   }
-  else if (!isConcrete && hasTerm(n))
+  else if (!isConcrete && hasLinearModelValue(n, ret))
   {
     // use model value for abstraction
-    ret = getRepresentative(n);
   }
   else if (n.getNumChildren() == 0)
   {
     // we are interested in the exact value of PI, which cannot be computed.
     // hence, we return PI itself when asked for the concrete value.
-    if (nk == PI)
+    if (n.getKind() == PI)
     {
       ret = n;
     }
@@ -114,7 +104,7 @@ Node NlModel::computeModelValue(Node n, bool isConcrete)
   else
   {
     // otherwise, compute true value
-    TheoryId ctid = theory::kindToTheoryId(nk);
+    TheoryId ctid = theory::kindToTheoryId(n.getKind());
     if (ctid != THEORY_ARITH && ctid != THEORY_BOOL && ctid != THEORY_BUILTIN)
     {
       // we directly look up terms not belonging to arithmetic
@@ -125,65 +115,28 @@ Node NlModel::computeModelValue(Node n, bool isConcrete)
       std::vector<Node> children;
       if (n.getMetaKind() == metakind::PARAMETERIZED)
       {
-        children.push_back(n.getOperator());
+        children.emplace_back(n.getOperator());
       }
-      for (unsigned i = 0, nchild = n.getNumChildren(); i < nchild; i++)
+      for (size_t i = 0, nchild = n.getNumChildren(); i < nchild; i++)
       {
-        Node mc = computeModelValue(n[i], isConcrete);
-        children.push_back(mc);
+        children.emplace_back(computeModelValue(n[i], isConcrete));
       }
-      ret = NodeManager::currentNM()->mkNode(nk, children);
+      ret = NodeManager::currentNM()->mkNode(n.getKind(), children);
       ret = Rewriter::rewrite(ret);
     }
   }
-  Trace("nl-ext-mv-debug") << "computed " << (index == 0 ? "M" : "M_A") << "["
+  Trace("nl-ext-mv-debug") << "computed " << (isConcrete ? "M" : "M_A") << "["
                            << n << "] = " << ret << std::endl;
-  d_mv[index][n] = ret;
+  cache[n] = ret;
   return ret;
 }
 
-bool NlModel::hasTerm(Node n) const
+int NlModel::compare(TNode i, TNode j, bool isConcrete, bool isAbsolute)
 {
-  return d_arithVal.find(n) != d_arithVal.end();
-}
-
-Node NlModel::getRepresentative(Node n) const
-{
-  if (n.isConst())
-  {
-    return n;
-  }
-  std::map<Node, Node>::const_iterator it = d_arithVal.find(n);
-  if (it != d_arithVal.end())
-  {
-    AlwaysAssert(it->second.isConst());
-    return it->second;
-  }
-  return d_model->getRepresentative(n);
-}
-
-Node NlModel::getValueInternal(Node n)
-{
-  if (n.isConst())
+  if (i == j)
   {
-    return n;
+    return 0;
   }
-  std::map<Node, Node>::const_iterator it = d_arithVal.find(n);
-  if (it != d_arithVal.end())
-  {
-    AlwaysAssert(it->second.isConst());
-    return it->second;
-  }
-  // It is unconstrained in the model, return 0. We additionally add it
-  // to mapping from the linear solver. This ensures that if the nonlinear
-  // solver assumes that n = 0, then this assumption is recorded in the overall
-  // model.
-  d_arithVal[n] = d_zero;
-  return d_zero;
-}
-
-int NlModel::compare(Node i, Node j, bool isConcrete, bool isAbsolute)
-{
   Node ci = computeModelValue(i, isConcrete);
   Node cj = computeModelValue(j, isConcrete);
   if (ci.isConst())
@@ -197,27 +150,24 @@ int NlModel::compare(Node i, Node j, bool isConcrete, bool isAbsolute)
   return cj.isConst() ? -1 : 0;
 }
 
-int NlModel::compareValue(Node i, Node j, bool isAbsolute) const
+int NlModel::compareValue(TNode i, TNode j, bool isAbsolute) const
 {
   Assert(i.isConst() && j.isConst());
-  int ret;
   if (i == j)
   {
-    ret = 0;
+    return 0;
   }
-  else if (!isAbsolute)
+  if (!isAbsolute)
   {
-    ret = i.getConst<Rational>() < j.getConst<Rational>() ? 1 : -1;
+    return i.getConst<Rational>() < j.getConst<Rational>() ? -1 : 1;
   }
-  else
+  Rational iabs = i.getConst<Rational>().abs();
+  Rational jabs = j.getConst<Rational>().abs();
+  if (iabs == jabs)
   {
-    ret = (i.getConst<Rational>().abs() == j.getConst<Rational>().abs()
-               ? 0
-               : (i.getConst<Rational>().abs() < j.getConst<Rational>().abs()
-                      ? 1
-                      : -1));
+    return 0;
   }
-  return ret;
+  return iabs < jabs ? -1 : 1;
 }
 
 bool NlModel::checkModel(const std::vector<Node>& assertions,
@@ -262,7 +212,7 @@ bool NlModel::checkModel(const std::vector<Node>& assertions,
               && !isTranscendentalKind(k))
           {
             // if we have not set an approximate bound for it
-            if (!hasCheckModelAssignment(cur))
+            if (!hasAssignment(cur))
             {
               // set its exact model value in the substitution
               Node curv = computeConcreteModelValue(cur);
@@ -273,7 +223,7 @@ bool NlModel::checkModel(const std::vector<Node>& assertions,
                 printRationalApprox("nl-ext-cm", curv);
                 Trace("nl-ext-cm") << std::endl;
               }
-              bool ret = addCheckModelSubstitution(cur, curv);
+              bool ret = addSubstitution(cur, curv);
               AlwaysAssert(ret);
             }
           }
@@ -294,10 +244,9 @@ bool NlModel::checkModel(const std::vector<Node>& assertions,
     {
       Node av = a;
       // apply the substitution to a
-      if (!d_check_model_vars.empty())
+      if (!d_substitutions.empty())
       {
-        av = arithSubstitute(av, d_check_model_vars, d_check_model_subs);
-        av = Rewriter::rewrite(av);
+        av = Rewriter::rewrite(arithSubstitute(av, d_substitutions));
       }
       // simple check literal
       if (!simpleCheckModelLit(av))
@@ -321,14 +270,13 @@ bool NlModel::checkModel(const std::vector<Node>& assertions,
   return true;
 }
 
-bool NlModel::addCheckModelSubstitution(TNode v, TNode s)
+bool NlModel::addSubstitution(TNode v, TNode s)
 {
   // should not substitute the same variable twice
   Trace("nl-ext-model") << "* check model substitution : " << v << " -> " << s
                         << std::endl;
   // should not set exact bound more than once
-  if (std::find(d_check_model_vars.begin(), d_check_model_vars.end(), v)
-      != d_check_model_vars.end())
+  if (d_substitutions.contains(v))
   {
     Trace("nl-ext-model") << "...ERROR: already has value." << std::endl;
     // this should never happen since substitutions should be applied eagerly
@@ -352,37 +300,31 @@ bool NlModel::addCheckModelSubstitution(TNode v, TNode s)
   Assert(d_check_model_witnesses.find(v) == d_check_model_witnesses.end())
       << "We tried to add a substitution where we already had a witness term."
       << std::endl;
-  std::vector<Node> varsTmp;
-  varsTmp.push_back(v);
-  std::vector<Node> subsTmp;
-  subsTmp.push_back(s);
-  for (unsigned i = 0, size = d_check_model_subs.size(); i < size; i++)
+  Subs tmp;
+  tmp.add(v, s);
+  for (auto& sub : d_substitutions.d_subs)
   {
-    Node ms = d_check_model_subs[i];
-    Node mss = arithSubstitute(ms, varsTmp, subsTmp);
-    if (mss != ms)
+    Node ms = arithSubstitute(sub, tmp);
+    if (ms != sub)
     {
-      mss = Rewriter::rewrite(mss);
+      sub = Rewriter::rewrite(ms);
     }
-    d_check_model_subs[i] = mss;
   }
-  d_check_model_vars.push_back(v);
-  d_check_model_subs.push_back(s);
+  d_substitutions.add(v, s);
   return true;
 }
 
-bool NlModel::addCheckModelBound(TNode v, TNode l, TNode u)
+bool NlModel::addBound(TNode v, TNode l, TNode u)
 {
   Trace("nl-ext-model") << "* check model bound : " << v << " -> [" << l << " "
                         << u << "]" << std::endl;
   if (l == u)
   {
     // bound is exact, can add as substitution
-    return addCheckModelSubstitution(v, l);
+    return addSubstitution(v, l);
   }
   // should not set a bound for a value that is exact
-  if (std::find(d_check_model_vars.begin(), d_check_model_vars.end(), v)
-      != d_check_model_vars.end())
+  if (d_substitutions.contains(v))
   {
     Trace("nl-ext-model")
         << "...ERROR: setting bound for variable that already has exact value."
@@ -405,13 +347,12 @@ bool NlModel::addCheckModelBound(TNode v, TNode l, TNode u)
   return true;
 }
 
-bool NlModel::addCheckModelWitness(TNode v, TNode w)
+bool NlModel::addWitness(TNode v, TNode w)
 {
   Trace("nl-ext-model") << "* check model witness : " << v << " -> " << w
                         << std::endl;
   // should not set a witness for a value that is already set
-  if (std::find(d_check_model_vars.begin(), d_check_model_vars.end(), v)
-      != d_check_model_vars.end())
+  if (d_substitutions.contains(v))
   {
     Trace("nl-ext-model") << "...ERROR: setting witness for variable that "
                              "already has a constant value."
@@ -423,20 +364,6 @@ bool NlModel::addCheckModelWitness(TNode v, TNode w)
   return true;
 }
 
-bool NlModel::hasCheckModelAssignment(Node v) const
-{
-  if (d_check_model_bounds.find(v) != d_check_model_bounds.end())
-  {
-    return true;
-  }
-  if (d_check_model_witnesses.find(v) != d_check_model_witnesses.end())
-  {
-    return true;
-  }
-  return std::find(d_check_model_vars.begin(), d_check_model_vars.end(), v)
-         != d_check_model_vars.end();
-}
-
 void NlModel::setUsedApproximate() { d_used_approx = true; }
 
 bool NlModel::usedApproximate() const { return d_used_approx; }
@@ -446,9 +373,9 @@ bool NlModel::solveEqualitySimple(Node eq,
                                   std::vector<NlLemma>& lemmas)
 {
   Node seq = eq;
-  if (!d_check_model_vars.empty())
+  if (!d_substitutions.empty())
   {
-    seq = arithSubstitute(eq, d_check_model_vars, d_check_model_subs);
+    seq = arithSubstitute(eq, d_substitutions);
     seq = Rewriter::rewrite(seq);
     if (seq.isConst())
     {
@@ -545,7 +472,7 @@ bool NlModel::solveEqualitySimple(Node eq,
     {
       Trace("nl-ext-cm-debug") << "check subs var : " << uv << std::endl;
       // cannot already have a bound
-      if (uv.isVar() && !hasCheckModelAssignment(uv))
+      if (uv.isVar() && !hasAssignment(uv))
       {
         Node slv;
         Node veqc;
@@ -560,7 +487,7 @@ bool NlModel::solveEqualitySimple(Node eq,
           {
             Trace("nl-ext-cm")
                 << "check-model-subs : " << uv << " -> " << slv << std::endl;
-            bool ret = addCheckModelSubstitution(uv, slv);
+            bool ret = addSubstitution(uv, slv);
             if (ret)
             {
               Trace("nl-ext-cms") << "...success, model substitution " << uv
@@ -577,7 +504,7 @@ bool NlModel::solveEqualitySimple(Node eq,
     {
       Trace("nl-ext-cm-debug") << "check set var : " << uvf << std::endl;
       // cannot already have a bound
-      if (uvf.isVar() && !hasCheckModelAssignment(uvf))
+      if (uvf.isVar() && !hasAssignment(uvf))
       {
         Node uvfv = computeConcreteModelValue(uvf);
         if (Trace.isOn("nl-ext-cm"))
@@ -586,7 +513,7 @@ bool NlModel::solveEqualitySimple(Node eq,
           printRationalApprox("nl-ext-cm", uvfv);
           Trace("nl-ext-cm") << std::endl;
         }
-        bool ret = addCheckModelSubstitution(uvf, uvfv);
+        bool ret = addSubstitution(uvf, uvfv);
         // recurse
         return ret ? solveEqualitySimple(eq, d, lemmas) : false;
       }
@@ -618,7 +545,7 @@ bool NlModel::solveEqualitySimple(Node eq,
       printRationalApprox("nl-ext-cm", val);
       Trace("nl-ext-cm") << std::endl;
     }
-    bool ret = addCheckModelSubstitution(var, val);
+    bool ret = addSubstitution(var, val);
     if (ret)
     {
       Trace("nl-ext-cms") << "...success, solved linear." << std::endl;
@@ -647,7 +574,7 @@ bool NlModel::solveEqualitySimple(Node eq,
     Trace("nl-ext-cms") << "...fail due to negative discriminant." << std::endl;
     return false;
   }
-  if (hasCheckModelAssignment(var))
+  if (hasAssignment(var))
   {
     Trace("nl-ext-cms") << "...fail due to bounds on variable to solve for."
                         << std::endl;
@@ -730,8 +657,7 @@ bool NlModel::solveEqualitySimple(Node eq,
     printRationalApprox("nl-ext-cm", bounds[r_use_index][1]);
     Trace("nl-ext-cm") << std::endl;
   }
-  bool ret =
-      addCheckModelBound(var, bounds[r_use_index][0], bounds[r_use_index][1]);
+  bool ret = addBound(var, bounds[r_use_index][0], bounds[r_use_index][1]);
   if (ret)
   {
     d_check_model_solved[eq] = var;
@@ -829,8 +755,7 @@ bool NlModel::simpleCheckModelLit(Node lit)
                                                 ? vs_invalid[0]
                                                 : nm->mkNode(PLUS, vs_invalid));
   // substitution to try
-  std::vector<Node> qvars;
-  std::vector<Node> qsubs;
+  Subs qsub;
   for (const Node& v : vs)
   {
     // is it a valid variable?
@@ -882,7 +807,7 @@ bool NlModel::simpleCheckModelLit(Node lit)
         Assert(boundn[0].getConst<Rational>()
                <= boundn[1].getConst<Rational>());
         Node s;
-        qvars.push_back(v);
+        qsub.add(v, Node());
         if (cmp[0] != cmp[1])
         {
           Assert(!cmp[0] && cmp[1]);
@@ -899,10 +824,9 @@ bool NlModel::simpleCheckModelLit(Node lit)
             Node tcmpn[2];
             for (unsigned r = 0; r < 2; r++)
             {
-              qsubs.push_back(boundn[r]);
-              Node ts = arithSubstitute(t, qvars, qsubs);
+              qsub.d_subs.back() = boundn[r];
+              Node ts = arithSubstitute(t, qsub);
               tcmpn[r] = Rewriter::rewrite(ts);
-              qsubs.pop_back();
             }
             Node tcmp = nm->mkNode(LT, tcmpn[0], tcmpn[1]);
             Trace("nl-ext-cms-debug")
@@ -932,16 +856,15 @@ bool NlModel::simpleCheckModelLit(Node lit)
           s = boundn[bindex_use];
         }
         Assert(!s.isNull());
-        qsubs.push_back(s);
+        qsub.d_subs.back() = s;
         Trace("nl-ext-cms") << "* set bound based on quadratic : " << v
                             << " -> " << s << std::endl;
       }
     }
   }
-  if (!qvars.empty())
+  if (!qsub.empty())
   {
-    Assert(qvars.size() == qsubs.size());
-    Node slit = arithSubstitute(lit, qvars, qsubs);
+    Node slit = arithSubstitute(lit, qsub);
     slit = Rewriter::rewrite(slit);
     return simpleCheckModelLit(slit);
   }
@@ -1242,21 +1165,26 @@ void NlModel::printModelValue(const char* c, Node n, unsigned prec) const
   if (Trace.isOn(c))
   {
     Trace(c) << "  " << n << " -> ";
-    for (int i = 1; i >= 0; --i)
+    const Node& aval = d_abstractModelCache.at(n);
+    if (aval.isConst())
     {
-      std::map<Node, Node>::const_iterator it = d_mv[i].find(n);
-      Assert(it != d_mv[i].end());
-      if (it->second.isConst())
-      {
-        printRationalApprox(c, it->second, prec);
-      }
-      else
-      {
-        Trace(c) << "?";
-      }
-      Trace(c) << (i == 1 ? " [actual: " : " ]");
+      printRationalApprox(c, aval, prec);
+    }
+    else
+    {
+      Trace(c) << "?";
+    }
+    Trace(c) << " [actual: ";
+    const Node& cval = d_concreteModelCache.at(n);
+    if (cval.isConst())
+    {
+      printRationalApprox(c, cval, prec);
+    }
+    else
+    {
+      Trace(c) << "?";
     }
-    Trace(c) << std::endl;
+    Trace(c) << " ]" << std::endl;
   }
 }
 
@@ -1316,13 +1244,12 @@ void NlModel::getModelValueRepair(
   // special kind approximation of the form (witness x. x = exact_value).
   // Notice that the above term gets rewritten such that the choice function
   // is eliminated.
-  for (size_t i = 0, num = d_check_model_vars.size(); i < num; i++)
+  for (size_t i = 0; i < d_substitutions.size(); ++i)
   {
-    Node v = d_check_model_vars[i];
-    Node s = d_check_model_subs[i];
     // overwrite
-    arithModel[v] = s;
-    Trace("nl-model") << v << " solved is " << s << std::endl;
+    arithModel[d_substitutions.d_vars[i]] = d_substitutions.d_subs[i];
+    Trace("nl-model") << d_substitutions.d_vars[i] << " solved is "
+                      << d_substitutions.d_subs[i] << std::endl;
   }
 
   // multiplication terms should not be given values; their values are
@@ -1341,6 +1268,49 @@ void NlModel::getModelValueRepair(
   }
 }
 
+Node NlModel::getValueInternal(TNode n)
+{
+  if (n.isConst())
+  {
+    return n;
+  }
+  if (auto it = d_arithVal.find(n); it != d_arithVal.end())
+  {
+    AlwaysAssert(it->second.isConst());
+    return it->second;
+  }
+  // It is unconstrained in the model, return 0. We additionally add it
+  // to mapping from the linear solver. This ensures that if the nonlinear
+  // solver assumes that n = 0, then this assumption is recorded in the overall
+  // model.
+  d_arithVal[n] = d_zero;
+  return d_zero;
+}
+
+bool NlModel::hasAssignment(Node v) const
+{
+  if (d_check_model_bounds.find(v) != d_check_model_bounds.end())
+  {
+    return true;
+  }
+  if (d_check_model_witnesses.find(v) != d_check_model_witnesses.end())
+  {
+    return true;
+  }
+  return (d_substitutions.contains(v));
+}
+
+bool NlModel::hasLinearModelValue(TNode v, Node& val) const
+{
+  auto it = d_arithVal.find(v);
+  if (it != d_arithVal.end())
+  {
+    val = it->second;
+    return true;
+  }
+  return false;
+}
+
 }  // namespace nl
 }  // namespace arith
 }  // namespace theory
index 526a9393497a6ffa6d8f30367fb879a3ae5b2ffa..b3b841eab717e08e17a38296045387a7326d4ba3 100644 (file)
@@ -22,6 +22,7 @@
 
 #include "expr/kind.h"
 #include "expr/node.h"
+#include "expr/subs.h"
 
 namespace cvc5 {
 
@@ -59,7 +60,7 @@ class NlModel
    * where m is the model of the theory of arithmetic. This method resets the
    * cache of computed model values.
    */
-  void reset(TheoryModel* m, std::map<Node, Node>& arithModel);
+  void reset(TheoryModel* m, const std::map<Node, Node>& arithModel);
   /**
    * This method is called when the non-linear arithmetic solver restarts
    * its computation of lemmas and models during a last call effort check.
@@ -87,9 +88,9 @@ class NlModel
    * whereas:
    *   computeModelValue( a*b, false ) = 5
    */
-  Node computeConcreteModelValue(Node n);
-  Node computeAbstractModelValue(Node n);
-  Node computeModelValue(Node n, bool isConcrete);
+  Node computeConcreteModelValue(TNode n);
+  Node computeAbstractModelValue(TNode n);
+  Node computeModelValue(TNode n, bool isConcrete);
 
   /**
    * Compare arithmetic terms i and j based an ordering.
@@ -101,10 +102,10 @@ class NlModel
    * otherwise, we consider their abstract model values. For definitions of
    * concrete vs abstract model values, see NlModel::computeModelValue.
    *
-   * If isAbsolute is true, we compare the absolute value of thee above
+   * If isAbsolute is true, we compare the absolute value of the above
    * values.
    */
-  int compare(Node i, Node j, bool isConcrete, bool isAbsolute);
+  int compare(TNode i, TNode j, bool isConcrete, bool isAbsolute);
   /**
    * Compare arithmetic terms i and j based an ordering.
    *
@@ -113,38 +114,31 @@ class NlModel
    *
    * If isAbsolute is true, we compare the absolute value of i and j
    */
-  int compareValue(Node i, Node j, bool isAbsolute) const;
+  int compareValue(TNode i, TNode j, bool isAbsolute) const;
 
   //------------------------------ recording model substitutions and bounds
   /**
    * Adds the model substitution v -> s. This applies the substitution
-   * { v -> s } to each term in d_check_model_subs and adds v,s to
-   * d_check_model_vars and d_check_model_subs respectively.
+   * { v -> s } to each term in d_substitutions and then adds v,s to
+   * d_substitutions.
    * If this method returns false, then the substitution v -> s is inconsistent
    * with the current substitution and bounds.
    */
-  bool addCheckModelSubstitution(TNode v, TNode s);
+  bool addSubstitution(TNode v, TNode s);
   /**
    * Adds the bound x -> < l, u > to the map above, and records the
    * approximation ( x, l <= x <= u ) in the model. This method returns false
    * if the bound is inconsistent with the current model substitution or
    * bounds.
    */
-  bool addCheckModelBound(TNode v, TNode l, TNode u);
+  bool addBound(TNode v, TNode l, TNode u);
   /**
    * Adds a model witness v -> w to the underlying theory model.
    * The witness should only contain a single variable v and evaluate to true
    * for exactly one value of v. The variable v is then (implicitly,
    * declaratively) assigned to this single value that satisfies the witness w.
    */
-  bool addCheckModelWitness(TNode v, TNode w);
-  /**
-   * Have we assigned v in the current checkModel(...) call?
-   *
-   * This method returns true if variable v is in the domain of
-   * d_check_model_bounds or if it occurs in d_check_model_vars.
-   */
-  bool hasCheckModelAssignment(Node v) const;
+  bool addWitness(TNode v, TNode w);
   /**
    * Checks the current model based on solving for equalities, and using error
    * bounds on the Taylor approximation.
@@ -198,21 +192,53 @@ class NlModel
       bool witnessToValue);
 
  private:
+  /** Cache for concrete model values */
+  std::map<Node, Node> d_concreteModelCache;
+  /** Cache for abstract model values */
+  std::map<Node, Node> d_abstractModelCache;
+
   /** The current model */
   TheoryModel* d_model;
+
+  /**
+   * The values that the arithmetic theory solver assigned in the model. This
+   * corresponds to the set of equalities that linear solver (via TheoryArith)
+   * is currently sending to TheoryModel during collectModelValues, plus
+   * additional entries x -> 0 for variables that were unassigned by the linear
+   * solver.
+   */
+  std::map<Node, Node> d_arithVal;
+
+  /**
+   * A substitution from variables that appear in assertions to a solved form
+   * term.
+   */
+  Subs d_substitutions;
+
   /** Get the model value of n from the model object above */
-  Node getValueInternal(Node n);
-  /** Does the equality engine of the model have term n? */
-  bool hasTerm(Node n) const;
-  /** Get the representative of n in the model */
-  Node getRepresentative(Node n) const;
+  Node getValueInternal(TNode n);
+
+  /**
+   * Have we assigned v in the current checkModel(...) call?
+   *
+   * This method returns true if variable v is in the domain of
+   * d_check_model_bounds or if it occurs in d_substitutions.
+   */
+  bool hasAssignment(Node v) const;
+
+  /**
+   * Checks whether we have a linear model value for v, i.e. whether v is
+   * contained in d_arithVal. If so, we also store the value that v is mapped
+   * to in val.
+   */
+  bool hasLinearModelValue(TNode v, Node& val) const;
 
   //---------------------------check model
   /**
    * This method is used during checkModel(...). It takes as input an
    * equality eq. If it returns true, then eq is correct-by-construction based
    * on the information stored in our model representation (see
-   * d_check_model_vars, d_check_model_subs, d_check_model_bounds), and eq
+   * d_substitutions, d_check_model_bounds), and eq
    * is added to d_check_model_solved. The equality eq may involve any
    * number of variables, and monomials of arbitrary degree. If this method
    * returns false, then we did not show that the equality was true in the
@@ -267,29 +293,6 @@ class NlModel
   Node d_true;
   Node d_false;
   Node d_null;
-  /**
-   * The values that the arithmetic theory solver assigned in the model. This
-   * corresponds to the set of equalities that linear solver (via TheoryArith)
-   * is currently sending to TheoryModel during collectModelValues, plus
-   * additional entries x -> 0 for variables that were unassigned by the linear
-   * solver.
-   */
-  std::map<Node, Node> d_arithVal;
-  /**
-   * cache of model values
-   *
-   * Stores the the concrete/abstract model values. This is a cache of the
-   * computeModelValue method.
-   */
-  std::map<Node, Node> d_mv[2];
-  /**
-   * A substitution from variables that appear in assertions to a solved form
-   * term. These vectors are ordered in the form:
-   *   x_1 -> t_1 ... x_n -> t_n
-   * where x_i is not in the free variables of t_j for j>=i.
-   */
-  std::vector<Node> d_check_model_vars;
-  std::vector<Node> d_check_model_subs;
   /**
    * lower and upper bounds for check model
    *
index f80717b578f442dca1f237b63fc1856a19369d5e..207907fcc651d2404059bfc20fae946ddcf98c97 100644 (file)
@@ -96,17 +96,16 @@ void NonlinearExtension::processSideEffect(const NlLemma& se)
 void NonlinearExtension::computeRelevantAssertions(
     const std::vector<Node>& assertions, std::vector<Node>& keep)
 {
-  Trace("nl-ext-rlv") << "Compute relevant assertions..." << std::endl;
-  Valuation v = d_containing.getValuation();
+  const Valuation& v = d_containing.getValuation();
   for (const Node& a : assertions)
   {
     if (v.isRelevant(a))
     {
-      keep.push_back(a);
+      keep.emplace_back(a);
     }
   }
-  Trace("nl-ext-rlv") << "...keep " << keep.size() << "/" << assertions.size()
-                      << " assertions" << std::endl;
+  Trace("nl-ext-rlv") << "...relevant assertions: " << keep.size() << "/"
+                      << assertions.size() << std::endl;
 }
 
 void NonlinearExtension::getAssertions(std::vector<Node>& assertions)
index db64320d5200ec342691214a2e24e7fc92cac022..dddde3c0fca7ad803800b574885179779b6882f9 100644 (file)
@@ -98,8 +98,8 @@ std::pair<poly::Polynomial, poly::SignCondition> as_poly_constraint(
 /**
  * Transforms a real algebraic number to a node suitable for putting it into a
  * model. The resulting node can be either a constant (suitable for
- * addCheckModelSubstitution) or a witness term (suitable for
- * addCheckModelWitness).
+ * addSubstitution) or a witness term (suitable for
+ * addWitness).
  */
 Node ran_to_node(const RealAlgebraicNumber& ran, const Node& ran_variable);
 
index c7bb14b3f7fe23500fcd6aafd7bdbe0a87d5c597..978823a22e0cbccd322616fed39846e36d7be135 100644 (file)
@@ -83,12 +83,10 @@ void TranscendentalSolver::initLastCall(const std::vector<Node>& xts)
 bool TranscendentalSolver::preprocessAssertionsCheckModel(
     std::vector<Node>& assertions)
 {
-  std::vector<Node> pvars;
-  std::vector<Node> psubs;
-  for (const std::pair<const Node, Node>& tb : d_tstate.d_trMaster)
+  Subs subs;
+  for (const auto& sub : d_tstate.d_trMaster)
   {
-    pvars.push_back(tb.first);
-    psubs.push_back(tb.second);
+    subs.add(sub.first, sub.second);
   }
 
   // initialize representation of assertions
@@ -97,9 +95,9 @@ bool TranscendentalSolver::preprocessAssertionsCheckModel(
 
   {
     Node pa = a;
-    if (!pvars.empty())
+    if (!subs.empty())
     {
-      pa = arithSubstitute(pa, pvars, psubs);
+      pa = arithSubstitute(pa, subs);
       pa = Rewriter::rewrite(pa);
     }
     if (!pa.isConst() || !pa.getConst<bool>())
@@ -145,8 +143,8 @@ bool TranscendentalSolver::preprocessAssertionsCheckModel(
             Trace("nl-ext-cm")
                 << "...bound for " << stf << " : [" << bounds.first << ", "
                 << bounds.second << "]" << std::endl;
-            success = d_tstate.d_model.addCheckModelBound(
-                stf, bounds.first, bounds.second);
+            success =
+                d_tstate.d_model.addBound(stf, bounds.first, bounds.second);
           }
         }
       }