Internally remove redundant assertions and infer equalities in NonLinearExtension...
authorAndrew Reynolds <andrew.j.reynolds@gmail.com>
Tue, 20 Mar 2018 19:03:04 +0000 (14:03 -0500)
committerGitHub <noreply@github.com>
Tue, 20 Mar 2018 19:03:04 +0000 (14:03 -0500)
src/theory/arith/nonlinear_extension.cpp
src/theory/arith/nonlinear_extension.h
test/regress/regress1/nl/Makefile.am
test/regress/regress1/nl/nl-eq-infer.smt2 [new file with mode: 0644]

index 5694ce45157adec8bf5981e8dd49aa9a5907c154..07cf43a354bdb64b16c0ee6d63bbce91445a8756 100644 (file)
@@ -1112,6 +1112,124 @@ int NonlinearExtension::flushLemmas(std::vector<Node>& lemmas) {
   return sum;
 }
 
+void NonlinearExtension::getAssertions(std::vector<Node>& assertions)
+{
+  Trace("nl-ext") << "Getting assertions..." << std::endl;
+  NodeManager* nm = NodeManager::currentNM();
+  // get the assertions
+  std::map<Node, Rational> init_bounds[2];
+  std::map<Node, Node> init_bounds_lit[2];
+  unsigned nassertions = 0;
+  std::unordered_set<Node, NodeHashFunction> init_assertions;
+  for (Theory::assertions_iterator it = d_containing.facts_begin();
+       it != d_containing.facts_end();
+       ++it)
+  {
+    nassertions++;
+    const Assertion& assertion = *it;
+    Node lit = assertion.assertion;
+    init_assertions.insert(lit);
+    // check for concrete bounds
+    bool pol = lit.getKind() != NOT;
+    Node atom_orig = lit.getKind() == NOT ? lit[0] : lit;
+
+    std::vector<Node> atoms;
+    if (atom_orig.getKind() == EQUAL)
+    {
+      if (pol)
+      {
+        // t = s  is ( t >= s ^ t <= s )
+        for (unsigned i = 0; i < 2; i++)
+        {
+          Node atom_new = nm->mkNode(GEQ, atom_orig[i], atom_orig[1 - i]);
+          atom_new = Rewriter::rewrite(atom_new);
+          atoms.push_back(atom_new);
+        }
+      }
+    }
+    else
+    {
+      atoms.push_back(atom_orig);
+    }
+
+    for (const Node& atom : atoms)
+    {
+      // non-strict bounds only
+      if (atom.getKind() == GEQ || (!pol && atom.getKind() == GT))
+      {
+        Node p = atom[0];
+        Assert(atom[1].isConst());
+        Rational bound = atom[1].getConst<Rational>();
+        if (!pol)
+        {
+          if (atom[0].getType().isInteger())
+          {
+            // ~( p >= c ) ---> ( p <= c-1 )
+            bound = bound - Rational(1);
+          }
+        }
+        unsigned bindex = pol ? 0 : 1;
+        bool setBound = true;
+        std::map<Node, Rational>::iterator itb = init_bounds[bindex].find(p);
+        if (itb != init_bounds[bindex].end())
+        {
+          if (itb->second == bound)
+          {
+            setBound = atom_orig.getKind() == EQUAL;
+          }
+          else
+          {
+            setBound = pol ? itb->second < bound : itb->second > bound;
+          }
+          if (setBound)
+          {
+            // the bound is subsumed
+            init_assertions.erase(init_bounds_lit[bindex][p]);
+          }
+        }
+        if (setBound)
+        {
+          Trace("nl-ext-init") << (pol ? "Lower" : "Upper") << " bound for "
+                               << p << " : " << bound << std::endl;
+          init_bounds[bindex][p] = bound;
+          init_bounds_lit[bindex][p] = lit;
+        }
+      }
+    }
+  }
+  // for each bound that is the same, ensure we've inferred the equality
+  for (std::pair<const Node, Rational>& ib : init_bounds[0])
+  {
+    Node p = ib.first;
+    Node lit1 = init_bounds_lit[0][p];
+    if (lit1.getKind() != EQUAL)
+    {
+      std::map<Node, Rational>::iterator itb = init_bounds[1].find(p);
+      if (itb != init_bounds[1].end())
+      {
+        if (ib.second == itb->second)
+        {
+          Node eq = p.eqNode(nm->mkConst(ib.second));
+          eq = Rewriter::rewrite(eq);
+          Node lit2 = init_bounds_lit[1][p];
+          Assert(lit2.getKind() != EQUAL);
+          // use the equality instead, thus these are redundant
+          init_assertions.erase(lit1);
+          init_assertions.erase(lit2);
+          init_assertions.insert(eq);
+        }
+      }
+    }
+  }
+
+  for (const Node& a : init_assertions)
+  {
+    assertions.push_back(a);
+  }
+  Trace("nl-ext") << "...keep " << assertions.size() << " / " << nassertions
+                  << " assertions." << std::endl;
+}
+
 std::vector<Node> NonlinearExtension::checkModel(
     const std::vector<Node>& assertions)
 {
@@ -1551,7 +1669,7 @@ int NonlinearExtension::checkLastCall(const std::vector<Node>& assertions,
   //-----------------------------------inferred bounds lemmas
   //  e.g. x >= t => y*x >= y*t
   std::vector< Node > nt_lemmas;
-  lemmas = checkMonomialInferBounds( nt_lemmas, false_asserts );
+  lemmas = checkMonomialInferBounds(nt_lemmas, assertions, false_asserts);
   // Trace("nl-ext") << "Bound lemmas : " << lemmas.size() << ", " <<
   // nt_lemmas.size() << std::endl;  prioritize lemmas that do not
   // introduce new monomials
@@ -1573,7 +1691,7 @@ int NonlinearExtension::checkLastCall(const std::vector<Node>& assertions,
   //------------------------------------factoring lemmas
   //   x*y + x*z >= t => exists k. k = y + z ^ x*k >= t
   if( options::nlExtFactor() ){
-    lemmas = checkFactoring( false_asserts );
+    lemmas = checkFactoring(assertions, false_asserts);
     lemmas_proc = flushLemmas(lemmas);
     if (lemmas_proc > 0) {
       Trace("nl-ext") << "  ...finished with " << lemmas_proc << " new lemmas." << std::endl;
@@ -1634,26 +1752,22 @@ void NonlinearExtension::check(Theory::Effort e) {
       }
     }
   } else {
+    // get the assertions
+    std::vector<Node> assertions;
+    getAssertions(assertions);
+
     bool needsRecheck;
     do
     {
       needsRecheck = false;
       Assert(e == Theory::EFFORT_LAST_CALL);
-      Trace("nl-ext-mv") << "Getting model values... check for [model-false]"
-                         << std::endl;
+
       // reset cached information
       d_mv[0].clear();
       d_mv[1].clear();
 
-      // get the assertions
-      std::vector<Node> assertions;
-      for (Theory::assertions_iterator it = d_containing.facts_begin();
-           it != d_containing.facts_end();
-           ++it)
-      {
-        const Assertion& assertion = *it;
-        assertions.push_back(assertion.assertion);
-      }
+      Trace("nl-ext-mv") << "Getting model values... check for [model-false]"
+                         << std::endl;
       // get the assertions that are false in the model
       const std::vector<Node> false_asserts = checkModel(assertions);
 
@@ -2499,15 +2613,15 @@ std::vector<Node> NonlinearExtension::checkTangentPlanes() {
 }
 
 std::vector<Node> NonlinearExtension::checkMonomialInferBounds(
-    std::vector<Node>& nt_lemmas, const std::vector<Node>& false_asserts)
+    std::vector<Node>& nt_lemmas,
+    const std::vector<Node>& asserts,
+    const std::vector<Node>& false_asserts)
 {
   std::vector< Node > lemmas; 
   // register constraints
   Trace("nl-ext-debug") << "Register bound constraints..." << std::endl;
-  for (context::CDList<Assertion>::const_iterator it =
-           d_containing.facts_begin();
-       it != d_containing.facts_end(); ++it) {
-    Node lit = (*it).assertion;
+  for (const Node& lit : asserts)
+  {
     bool polarity = lit.getKind() != NOT;
     Node atom = lit.getKind() == NOT ? lit[0] : lit;
     registerConstraint(atom);
@@ -2764,14 +2878,12 @@ std::vector<Node> NonlinearExtension::checkMonomialInferBounds(
 }
 
 std::vector<Node> NonlinearExtension::checkFactoring(
-    const std::vector<Node>& false_asserts)
+    const std::vector<Node>& asserts, const std::vector<Node>& false_asserts)
 {
   std::vector< Node > lemmas; 
   Trace("nl-ext") << "Get factoring lemmas..." << std::endl;
-  for (context::CDList<Assertion>::const_iterator it =
-           d_containing.facts_begin();
-       it != d_containing.facts_end(); ++it) {
-    Node lit = (*it).assertion;
+  for (const Node& lit : asserts)
+  {
     bool polarity = lit.getKind() != NOT;
     Node atom = lit.getKind() == NOT ? lit[0] : lit;
     if (std::find(false_asserts.begin(), false_asserts.end(), lit)
index a37ef97f81fe5de7fac58dfb78e015ecaef23cec..96d37cbc276b1afe1b3c0159cf2260d62bffdebf 100644 (file)
@@ -245,6 +245,17 @@ class NonlinearExtension {
   void assignOrderIds(std::vector<Node>& vars, NodeMultiset& d_order,
                       unsigned orderType);
 
+  /** get assertions
+   *
+   * Let M be the set of assertions known by THEORY_ARITH. This function adds a
+   * set of literals M' to assertions such that M' and M are equivalent.
+   *
+   * Examples of how M' differs with M:
+   * (1) M' may not include t < c (in M) if t < c' is in M' for c' < c, where
+   * c and c' are constants,
+   * (2) M' may contain t = c if both t >= c and t <= c are in M.
+   */
+  void getAssertions(std::vector<Node>& assertions);
   /** check model
    *
    * Returns the subset of assertions whose concrete values we cannot show are
@@ -700,7 +711,9 @@ private:
   *      that occur in the current context.
   */
   std::vector<Node> checkMonomialInferBounds(
-      std::vector<Node>& nt_lemmas, const std::vector<Node>& false_asserts);
+      std::vector<Node>& nt_lemmas,
+      const std::vector<Node>& asserts,
+      const std::vector<Node>& false_asserts);
 
   /** check factoring
   *
@@ -714,7 +727,8 @@ private:
   *   ...where k is fresh and x*z + y*z > t is a
   *      constraint that occurs in the current context.
   */
-  std::vector<Node> checkFactoring(const std::vector<Node>& false_asserts);
+  std::vector<Node> checkFactoring(const std::vector<Node>& asserts,
+                                   const std::vector<Node>& false_asserts);
 
   /** check monomial infer resolution bounds
   *
index a9571525379ee9803838824f9b1241dee3f09a3c..a008e4df129f61e9968c3545bfb3cf2c673d6382 100644 (file)
@@ -64,7 +64,8 @@ TESTS =       \
        sin2-ub.smt2 \
        sugar-ident.smt2 \
        zero-subset.smt2 \
-       sin1-deq-sat.smt2
+       sin1-deq-sat.smt2 \
+       nl-eq-infer.smt2
 
 EXTRA_DIST = $(TESTS)
 
diff --git a/test/regress/regress1/nl/nl-eq-infer.smt2 b/test/regress/regress1/nl/nl-eq-infer.smt2
new file mode 100644 (file)
index 0000000..f0968a0
--- /dev/null
@@ -0,0 +1,14 @@
+(set-logic QF_NIA)
+(set-info :status unsat)
+(declare-fun i () Int)
+(declare-fun n () Int)
+(declare-fun s () Int)
+
+(assert (and 
+(= i (+ (* (- 2) s) (* i i))) 
+(>= (+ i (* (- 1) n)) 1) 
+(not (>= (+ i (* (- 1) n)) 2))
+))
+(assert (not (= n (+ (* 2 s) (* (- 1) (* n n))))))
+
+(check-sat)