Improve integration of nonlinear arithmetic into the arithmetic theory. (#6956)
authorGereon Kremer <nafur42@gmail.com>
Thu, 26 Aug 2021 14:19:56 +0000 (07:19 -0700)
committerGitHub <noreply@github.com>
Thu, 26 Aug 2021 14:19:56 +0000 (14:19 +0000)
The nonlinear extension used to be called only during model construction, a rather atypical place which lead to a series of subtle issues. This PR moves it into the postCheck method. To do that, a few other changes are necessary, most notably collectAssertedTerms and collectTerms are moved back from the model manager into the theory class.

src/theory/arith/nl/nonlinear_extension.cpp
src/theory/arith/theory_arith.cpp
src/theory/arith/theory_arith.h
src/theory/arith/theory_arith_private.cpp
test/regress/regress2/bv_to_int_mask_array_1.smt2

index c0ea3195ada0616faa0774f13338ba7e3e814625..c28292ad346fd29ad5905e18a6ac0681f3b06c02 100644 (file)
@@ -252,7 +252,6 @@ bool NonlinearExtension::checkModel(const std::vector<Node>& assertions)
 
 void NonlinearExtension::check(Theory::Effort e)
 {
-  d_im.reset();
   Trace("nl-ext") << std::endl;
   Trace("nl-ext") << "NonlinearExtension::check, effort = " << e << std::endl;
   if (e == Theory::EFFORT_FULL)
index cb3f21da03d4d1e82a47d863af4683f3ffcbd935..ccdf5a90a4fd4b608639b3ce4a7ded8366fd9065 100644 (file)
@@ -160,6 +160,7 @@ bool TheoryArith::preCheck(Effort level)
 
 void TheoryArith::postCheck(Effort level)
 {
+  d_im.reset();
   Trace("arith-check") << "TheoryArith::postCheck " << level << std::endl;
   // check with the non-linear solver at last call
   if (level == Theory::EFFORT_LAST_CALL)
@@ -176,12 +177,20 @@ void TheoryArith::postCheck(Effort level)
     // linear solver emitted a conflict or lemma, return
     return;
   }
+  if (d_im.hasSent())
+  {
+    return;
+  }
 
   if (Theory::fullEffort(level))
   {
+    d_arithModelCache.clear();
     if (d_nonlinearExtension != nullptr)
     {
+      std::set<Node> termSet;
+      updateModelCache(termSet);
       d_nonlinearExtension->check(level);
+      d_nonlinearExtension->interceptModel(d_arithModelCache, termSet);
     }
     else if (d_internal->foundNonlinear())
     {
@@ -247,59 +256,38 @@ bool TheoryArith::collectModelInfo(TheoryModel* m,
 bool TheoryArith::collectModelValues(TheoryModel* m,
                                      const std::set<Node>& termSet)
 {
-  // get the model from the linear solver
-  std::map<Node, Node> arithModel;
-  d_internal->collectModelValues(termSet, arithModel);
-  // Double check that the model from the linear solver respects integer types,
-  // if it does not, add a branch and bound lemma. This typically should never
-  // be necessary, but is needed in rare cases.
-  bool addedLemma = false;
-  bool badAssignment = false;
-  for (const std::pair<const Node, Node>& p : arithModel)
+  if (Trace.isOn("arith::model"))
   {
-    if (p.first.getType().isInteger() && !p.second.getType().isInteger())
+    Trace("arith::model") << "arithmetic model after pruning" << std::endl;
+    for (const auto& p : d_arithModelCache)
     {
-      Assert(false) << "TheoryArithPrivate generated a bad model value for "
-                       "integer variable "
-                    << p.first << " : " << p.second;
-      // must branch and bound
-      TrustNode lem =
-          d_bab.branchIntegerVariable(p.first, p.second.getConst<Rational>());
-      if (d_im.trustedLemma(lem, InferenceId::ARITH_BB_LEMMA))
-      {
-        addedLemma = true;
-      }
-      badAssignment = true;
+      Trace("arith::model") << "\t" << p.first << " -> " << p.second << std::endl;
     }
   }
-  if (addedLemma)
+
+  updateModelCache(termSet);
+
+  if (sanityCheckIntegerModel())
   {
-    // we had to add a branch and bound lemma since the linear solver assigned
-    // a non-integer value to an integer variable.
+    // We added a lemma
     return false;
   }
-  // this would imply that linear arithmetic's model failed to satisfy a branch
-  // and bound lemma
-  AlwaysAssert(!badAssignment)
-      << "Bad assignment from TheoryArithPrivate::collectModelValues, and no "
-         "branching lemma was sent";
 
-  // if non-linear is enabled, intercept the model, which may repair its values
-  if (d_nonlinearExtension != nullptr)
-  {
-    // Non-linear may repair values to satisfy non-linear constraints (see
-    // documentation for NonlinearExtension::interceptModel).
-    d_nonlinearExtension->interceptModel(arithModel, termSet);
-  }
   // We are now ready to assert the model.
-  for (const std::pair<const Node, Node>& p : arithModel)
+  for (const std::pair<const Node, Node>& p : d_arithModelCache)
   {
+    if (termSet.find(p.first) == termSet.end())
+    {
+      continue;
+    }
     // maps to constant of comparable type
     Assert(p.first.getType().isComparableTo(p.second.getType()));
     if (m->assertEquality(p.first, p.second, true))
     {
       continue;
     }
+    Assert(false) << "A model equality could not be asserted: " << p.first
+                        << " == " << p.second << std::endl;
     // If we failed to assert an equality, it is likely due to theory
     // combination, namely the repaired model for non-linear changed
     // an equality status that was agreed upon by both (linear) arithmetic
@@ -332,7 +320,18 @@ void TheoryArith::presolve(){
 }
 
 EqualityStatus TheoryArith::getEqualityStatus(TNode a, TNode b) {
-  return d_internal->getEqualityStatus(a,b);
+  Debug("arith") << "TheoryArith::getEqualityStatus(" << a << ", " << b << ")" << std::endl;
+  if (d_arithModelCache.empty())
+  {
+    return d_internal->getEqualityStatus(a,b);
+  }
+  Node aval = Rewriter::rewrite(a.substitute(d_arithModelCache.begin(), d_arithModelCache.end()));
+  Node bval = Rewriter::rewrite(b.substitute(d_arithModelCache.begin(), d_arithModelCache.end()));
+  if (aval == bval)
+  {
+    return EQUALITY_TRUE_IN_MODEL;
+  }
+  return EQUALITY_FALSE_IN_MODEL;
 }
 
 Node TheoryArith::getModelValue(TNode var) {
@@ -352,6 +351,62 @@ eq::ProofEqEngine* TheoryArith::getProofEqEngine()
   return d_im.getProofEqEngine();
 }
 
+void TheoryArith::updateModelCache(std::set<Node>& termSet)
+{
+  if (d_arithModelCache.empty())
+  {
+    collectAssertedTerms(termSet);
+    d_internal->collectModelValues(termSet, d_arithModelCache);
+  }
+}
+void TheoryArith::updateModelCache(const std::set<Node>& termSet)
+{
+  if (d_arithModelCache.empty())
+  {
+    d_internal->collectModelValues(termSet, d_arithModelCache);
+  }
+}
+bool TheoryArith::sanityCheckIntegerModel()
+{
+
+    // Double check that the model from the linear solver respects integer types,
+    // if it does not, add a branch and bound lemma. This typically should never
+    // be necessary, but is needed in rare cases.
+    bool addedLemma = false;
+    bool badAssignment = false;
+    Trace("arith-check") << "model:" << std::endl;
+    for (const auto& p : d_arithModelCache)
+    {
+      Trace("arith-check") << p.first << " -> " << p.second << std::endl;
+      if (p.first.getType().isInteger() && !p.second.getType().isInteger())
+      {
+        Assert(false) << "TheoryArithPrivate generated a bad model value for "
+                        "integer variable "
+                      << p.first << " : " << p.second;
+        // must branch and bound
+        TrustNode lem =
+            d_bab.branchIntegerVariable(p.first, p.second.getConst<Rational>());
+        if (d_im.trustedLemma(lem, InferenceId::ARITH_BB_LEMMA))
+        {
+          addedLemma = true;
+        }
+        badAssignment = true;
+      }
+    }
+    if (addedLemma)
+    {
+      // we had to add a branch and bound lemma since the linear solver assigned
+      // a non-integer value to an integer variable.
+      return true;
+    }
+    // this would imply that linear arithmetic's model failed to satisfy a branch
+    // and bound lemma
+    AlwaysAssert(!badAssignment)
+        << "Bad assignment from TheoryArithPrivate::collectModelValues, and no "
+          "branching lemma was sent";
+    return false;
+}
+
 }  // namespace arith
 }  // namespace theory
 }  // namespace cvc5
index 4b0c88fd241e2e97d7c9a80e89ddde214e14c8ef..80e351466981984bdab32a89a1d1d486d4544713 100644 (file)
@@ -125,6 +125,25 @@ class TheoryArith : public Theory {
   }
 
  private:
+  /**
+   * Update d_arithModelCache (if it is empty right now) and compute the termSet
+   * by calling collectAssertedTerms.
+   */
+  void updateModelCache(std::set<Node>& termSet);
+  /**
+   * Update d_arithModelCache (if it is empty right now) using the given
+   * termSet.
+   */
+  void updateModelCache(const std::set<Node>& termSet);
+  /**
+   * Perform a sanity check on the model that all integer variables are assigned
+   * to integer values. If an integer variables is assigned to a non-integer
+   * value, but the respective lemma can not be added (i.e. it has already been
+   * added) an assertion triggers. Otherwise teturns true if a lemma was added,
+   * false otherwise.
+   */
+  bool sanityCheckIntegerModel();
+
   /** Get the proof equality engine */
   eq::ProofEqEngine* getProofEqEngine();
   /** Timer for ppRewrite */
@@ -153,6 +172,17 @@ class TheoryArith : public Theory {
   ArithPreprocess d_arithPreproc;
   /** The theory rewriter for this theory. */
   ArithRewriter d_rewriter;
+
+  /**
+   * Caches the current arithmetic model with the following life cycle:
+   * postCheck retrieves the model from arith_private and puts it into the
+   * cache. If nonlinear reasoning is enabled, the cache is used for (and
+   * possibly updated by) model-based refinement in postCheck.
+   * In collectModelValues, the cache is filtered for the termSet and then
+   * used to augment the TheoryModel.
+   */
+  std::map<Node, Node> d_arithModelCache;
+
 };/* class TheoryArith */
 
 }  // namespace arith
index ea2887c441762f66b7f1fc3ebc56f5048a76ce42..4eff69ef8ce7a821296361662d55de4f462e356b 100644 (file)
@@ -114,7 +114,7 @@ TheoryArithPrivate::TheoryArithPrivate(TheoryArith& containing,
       d_diseqQueue(d_env.getContext(), false),
       d_currentPropagationList(),
       d_learnedBounds(d_env.getContext()),
-      d_preregisteredNodes(d_env.getUserContext()),
+      d_preregisteredNodes(d_env.getContext()),
       d_partialModel(d_env.getContext(), DeltaComputeCallback(*this)),
       d_errorSet(
           d_partialModel, TableauSizes(&d_tableau), BoundCountingLookup(*this)),
index 165e39e7a0d166d13d2c23fe0a226c54b084ab73..3b55c035d47a17258f48d12d1304f13257b544f5 100644 (file)
@@ -1,6 +1,6 @@
 ; COMMAND-LINE: --solve-bv-as-int=sum --bvand-integer-granularity=1
 ; COMMAND-LINE: --solve-bv-as-int=iand --iand-mode=value
-; COMMAND-LINE: --solve-bv-as-int=iand --iand-mode=sum
+; COMMAND-LINE: --solve-bv-as-int=iand --iand-mode=sum --no-check-unsat-cores
 ; COMMAND-LINE: --solve-bv-as-int=bv --no-check-unsat-cores
 ; EXPECT: unsat
 (set-logic ALL)