Do not substitute beneath arithmetic terms in the non-linear solver (#3324)
authorAndrew Reynolds <andrew.j.reynolds@gmail.com>
Wed, 11 Dec 2019 17:58:53 +0000 (11:58 -0600)
committerGitHub <noreply@github.com>
Wed, 11 Dec 2019 17:58:53 +0000 (11:58 -0600)
src/theory/arith/arith_utilities.cpp
src/theory/arith/arith_utilities.h
src/theory/arith/nl_model.cpp
src/theory/arith/nonlinear_extension.cpp
test/regress/CMakeLists.txt
test/regress/regress1/nl/issue3307.smt2 [new file with mode: 0644]

index 3d3078d995bf3e5c82ef216bef97b21be4d931f1..65aaceb809d54096e0abcf245187e1c3c1634571 100644 (file)
@@ -191,6 +191,84 @@ void printRationalApprox(const char* c, Node cr, unsigned prec)
   }
 }
 
+Node arithSubstitute(Node n, std::vector<Node>& vars, std::vector<Node>& subs)
+{
+  Assert(vars.size() == subs.size());
+  NodeManager* nm = NodeManager::currentNM();
+  std::unordered_map<TNode, Node, TNodeHashFunction> visited;
+  std::unordered_map<TNode, Node, TNodeHashFunction>::iterator it;
+  std::vector<Node>::iterator itv;
+  std::vector<TNode> visit;
+  TNode cur;
+  Kind ck;
+  visit.push_back(n);
+  do
+  {
+    cur = visit.back();
+    visit.pop_back();
+    it = visited.find(cur);
+
+    if (it == visited.end())
+    {
+      visited[cur] = Node::null();
+      ck = cur.getKind();
+      itv = std::find(vars.begin(), vars.end(), cur);
+      if (itv != vars.end())
+      {
+        visited[cur] = subs[std::distance(vars.begin(), itv)];
+      }
+      else if (cur.getNumChildren() == 0)
+      {
+        visited[cur] = cur;
+      }
+      else
+      {
+        TheoryId ctid = theory::kindToTheoryId(ck);
+        if (ctid != THEORY_ARITH && ctid != THEORY_BOOL
+            && ctid != THEORY_BUILTIN)
+        {
+          // do not traverse beneath applications that belong to another theory
+          visited[cur] = cur;
+        }
+        else
+        {
+          visit.push_back(cur);
+          for (const Node& cn : cur)
+          {
+            visit.push_back(cn);
+          }
+        }
+      }
+    }
+    else if (it->second.isNull())
+    {
+      Node ret = cur;
+      bool childChanged = false;
+      std::vector<Node> children;
+      if (cur.getMetaKind() == kind::metakind::PARAMETERIZED)
+      {
+        children.push_back(cur.getOperator());
+      }
+      for (const Node& cn : cur)
+      {
+        it = visited.find(cn);
+        Assert(it != visited.end());
+        Assert(!it->second.isNull());
+        childChanged = childChanged || cn != it->second;
+        children.push_back(it->second);
+      }
+      if (childChanged)
+      {
+        ret = nm->mkNode(cur.getKind(), children);
+      }
+      visited[cur] = ret;
+    }
+  } while (!visit.empty());
+  Assert(visited.find(n) != visited.end());
+  Assert(!visited.find(n)->second.isNull());
+  return visited[n];
+}
+
 }  // namespace arith
 }  // namespace theory
 }  // namespace CVC4
index d737fefeb007dbed2c2801b6a48771f8f4dea254..f87a908b4efdde6f597b5399c25ea60abe33298f 100644 (file)
@@ -325,6 +325,16 @@ Node getApproximateConstant(Node c, bool isLower, unsigned prec);
 /** print rational approximation of cr with precision prec on trace c */
 void printRationalApprox(const char* c, Node cr, unsigned prec = 5);
 
+/** Arithmetic substitute
+ *
+ * This computes the substitution n { vars -> subs }, but with the caveat
+ * that subterms of n that belong to a theory other than arithmetic are
+ * not traversed. In other words, terms that belong to other theories are
+ * treated as atomic variables. For example:
+ *   (5*f(x) + 7*x ){ x -> 3 } returns 5*f(x) + 7*3.
+ */
+Node arithSubstitute(Node n, std::vector<Node>& vars, std::vector<Node>& subs);
+
 }/* CVC4::theory::arith namespace */
 }/* CVC4::theory namespace */
 }/* CVC4 namespace */
index fe756e5f7a0c47c0bf4470c28b44d165b7412912..3274867bb7d6af72037140c0806f7e7f1aaaf478 100644 (file)
@@ -284,10 +284,7 @@ bool NlModel::checkModel(const std::vector<Node>& assertions,
       // apply the substitution to a
       if (!d_check_model_vars.empty())
       {
-        av = av.substitute(d_check_model_vars.begin(),
-                           d_check_model_vars.end(),
-                           d_check_model_subs.begin(),
-                           d_check_model_subs.end());
+        av = arithSubstitute(av, d_check_model_vars, d_check_model_subs);
         av = Rewriter::rewrite(av);
       }
       // simple check literal
@@ -360,10 +357,14 @@ bool NlModel::addCheckModelSubstitution(TNode v, TNode s)
       return false;
     }
   }
+  std::vector<Node> varsTmp;
+  varsTmp.push_back(v);
+  std::vector<Node> subsTmp;
+  subsTmp.push_back(s);
   for (unsigned i = 0, size = d_check_model_subs.size(); i < size; i++)
   {
     Node ms = d_check_model_subs[i];
-    Node mss = ms.substitute(v, s);
+    Node mss = arithSubstitute(ms, varsTmp, subsTmp);
     if (mss != ms)
     {
       mss = Rewriter::rewrite(mss);
@@ -430,10 +431,7 @@ bool NlModel::solveEqualitySimple(Node eq,
   Node seq = eq;
   if (!d_check_model_vars.empty())
   {
-    seq = eq.substitute(d_check_model_vars.begin(),
-                        d_check_model_vars.end(),
-                        d_check_model_subs.begin(),
-                        d_check_model_subs.end());
+    seq = arithSubstitute(eq, d_check_model_vars, d_check_model_subs);
     seq = Rewriter::rewrite(seq);
     if (seq.isConst())
     {
@@ -866,8 +864,7 @@ bool NlModel::simpleCheckModelLit(Node lit)
             for (unsigned r = 0; r < 2; r++)
             {
               qsubs.push_back(boundn[r]);
-              Node ts = t.substitute(
-                  qvars.begin(), qvars.end(), qsubs.begin(), qsubs.end());
+              Node ts = arithSubstitute(t, qvars, qsubs);
               tcmpn[r] = Rewriter::rewrite(ts);
               qsubs.pop_back();
             }
index 6e8e7623dce246eebb73120b28c4bfc155cc7dd5..ff2ec412bebc2066f305b83f4f4b4b1b2e766586 100644 (file)
@@ -772,8 +772,7 @@ bool NonlinearExtension::checkModel(const std::vector<Node>& assertions,
     Node pa = a;
     if (!pvars.empty())
     {
-      pa =
-          pa.substitute(pvars.begin(), pvars.end(), psubs.begin(), psubs.end());
+      pa = arithSubstitute(pa, pvars, psubs);
       pa = Rewriter::rewrite(pa);
     }
     if (!pa.isConst() || !pa.getConst<bool>())
index edc44c3d2537aaa1068a77919b6559656fada29d..814eaab497ead116b83c3cda156c8380bcb7c17a 100644 (file)
@@ -1264,6 +1264,7 @@ set(regress_1_tests
   regress1/nl/exp1-lb.smt2
   regress1/nl/exp_monotone.smt2
   regress1/nl/factor_agg_s.smt2
+  regress1/nl/issue3307.smt2
   regress1/nl/issue3441.smt2
   regress1/nl/metitarski-1025.smt2
   regress1/nl/metitarski-3-4.smt2
diff --git a/test/regress/regress1/nl/issue3307.smt2 b/test/regress/regress1/nl/issue3307.smt2
new file mode 100644 (file)
index 0000000..803bfeb
--- /dev/null
@@ -0,0 +1,14 @@
+; COMMAND-LINE: --no-check-models
+; EXPECT: sat
+(set-logic NRA)
+(set-info :status sat)
+(declare-fun a () Real)
+(declare-fun b () Real)
+(assert
+ (and
+  (> b 1)
+  (< a 0)
+  (>= (/ 0 (+ (* a b) (/ (- a) 0))) a)
+  )
+ )
+(check-sat)