Refactor check interface of nonlinear extension (#7235)
authorGereon Kremer <nafur42@gmail.com>
Thu, 23 Sep 2021 21:20:42 +0000 (14:20 -0700)
committerGitHub <noreply@github.com>
Thu, 23 Sep 2021 21:20:42 +0000 (16:20 -0500)
This PR does a first step for refactoring the main check interface of the nonlinear extension. It does not change anything yet, but merely moves code around.

src/theory/arith/nl/nonlinear_extension.cpp
src/theory/arith/nl/nonlinear_extension.h
src/theory/arith/theory_arith.cpp

index b8170df457e9cd1fe8ca853d7757662680299bc1..f80717b578f442dca1f237b63fc1856a19369d5e 100644 (file)
@@ -44,7 +44,7 @@ NonlinearExtension::NonlinearExtension(Env& env,
       d_containing(containing),
       d_astate(state),
       d_im(containing.getInferenceManager()),
-      d_needsLastCall(false),
+      d_hasNlTerms(false),
       d_checkCounter(0),
       d_extTheoryCb(state.getEqualityEngine()),
       d_extTheory(d_extTheoryCb, context(), userContext(), d_im),
@@ -187,25 +187,20 @@ void NonlinearExtension::getAssertions(std::vector<Node>& assertions)
                   << std::endl;
 }
 
-std::vector<Node> NonlinearExtension::checkModelEval(
+std::vector<Node> NonlinearExtension::getUnsatisfiedAssertions(
     const std::vector<Node>& assertions)
 {
   std::vector<Node> false_asserts;
-  for (size_t i = 0; i < assertions.size(); ++i)
+  for (const auto& lit : assertions)
   {
-    Node lit = assertions[i];
-    Node atom = lit.getKind() == NOT ? lit[0] : lit;
     Node litv = d_model.computeConcreteModelValue(lit);
     Trace("nl-ext-mv-assert") << "M[[ " << lit << " ]] -> " << litv;
     if (litv != d_true)
     {
-      Trace("nl-ext-mv-assert") << " [model-false]" << std::endl;
+      Trace("nl-ext-mv-assert") << " [model-false]";
       false_asserts.push_back(lit);
     }
-    else
-    {
-      Trace("nl-ext-mv-assert") << std::endl;
-    }
+    Trace("nl-ext-mv-assert") << std::endl;
   }
   return false_asserts;
 }
@@ -245,60 +240,72 @@ bool NonlinearExtension::checkModel(const std::vector<Node>& assertions)
   return ret;
 }
 
-void NonlinearExtension::check(Theory::Effort e)
+void NonlinearExtension::checkFullEffort(std::map<Node, Node>& arithModel,
+                                         const std::set<Node>& termSet)
 {
-  Trace("nl-ext") << std::endl;
-  Trace("nl-ext") << "NonlinearExtension::check, effort = " << e << std::endl;
-  if (e == Theory::EFFORT_FULL)
+  Trace("nl-ext") << "NonlinearExtension::checkFullEffort" << std::endl;
+
+  d_hasNlTerms = true;
+  if (options().arith.nlExtRewrites)
   {
-    d_needsLastCall = true;
-    if (options().arith.nlExtRewrites)
+    std::vector<Node> nred;
+    if (!d_extTheory.doInferences(0, nred))
     {
-      std::vector<Node> nred;
-      if (!d_extTheory.doInferences(0, nred))
-      {
-        Trace("nl-ext") << "...sent no lemmas, # extf to reduce = "
-                        << nred.size() << std::endl;
-        if (nred.empty())
-        {
-          d_needsLastCall = false;
-        }
-      }
-      else
+      Trace("nl-ext") << "...sent no lemmas, # extf to reduce = " << nred.size()
+                      << std::endl;
+      if (nred.empty())
       {
-        Trace("nl-ext") << "...sent lemmas." << std::endl;
+        d_hasNlTerms = false;
       }
     }
-  }
-  else
-  {
-    // If we computed lemmas during collectModelInfo, send them now.
-    if (d_im.hasPendingLemma())
+    else
     {
-      d_im.doPendingFacts();
-      d_im.doPendingLemmas();
-      d_im.doPendingPhaseRequirements();
-      return;
+      Trace("nl-ext") << "...sent lemmas." << std::endl;
     }
-    // Otherwise, we will answer SAT. The values that we approximated are
-    // recorded as approximations here.
-    TheoryModel* tm = d_containing.getValuation().getModel();
-    for (std::pair<const Node, std::pair<Node, Node>>& a : d_approximations)
+  }
+
+  if (!hasNlTerms())
+  {
+    // no non-linear constraints, we are done
+    return;
+  }
+  Trace("nl-ext") << "NonlinearExtension::interceptModel begin" << std::endl;
+  d_model.reset(d_containing.getValuation().getModel(), arithModel);
+  // run a last call effort check
+  Trace("nl-ext") << "interceptModel: do model-based refinement" << std::endl;
+  Result::Sat res = modelBasedRefinement(termSet);
+  if (res == Result::Sat::SAT)
+  {
+    Trace("nl-ext") << "interceptModel: do model repair" << std::endl;
+    d_approximations.clear();
+    d_witnesses.clear();
+    // modify the model values
+    d_model.getModelValueRepair(arithModel,
+                                d_approximations,
+                                d_witnesses,
+                                options().smt.modelWitnessValue);
+  }
+}
+
+void NonlinearExtension::finalizeModel(TheoryModel* tm)
+{
+  Trace("nl-ext") << "NonlinearExtension::finalizeModel" << std::endl;
+
+  for (std::pair<const Node, std::pair<Node, Node>>& a : d_approximations)
+  {
+    if (a.second.second.isNull())
     {
-      if (a.second.second.isNull())
-      {
-        tm->recordApproximation(a.first, a.second.first);
-      }
-      else
-      {
-        tm->recordApproximation(a.first, a.second.first, a.second.second);
-      }
+      tm->recordApproximation(a.first, a.second.first);
     }
-    for (const auto& vw : d_witnesses)
+    else
     {
-      tm->recordApproximation(vw.first, vw.second);
+      tm->recordApproximation(a.first, a.second.first, a.second.second);
     }
   }
+  for (const auto& vw : d_witnesses)
+  {
+    tm->recordApproximation(vw.first, vw.second);
+  }
 }
 
 Result::Sat NonlinearExtension::modelBasedRefinement(const std::set<Node>& termSet)
@@ -313,7 +320,7 @@ Result::Sat NonlinearExtension::modelBasedRefinement(const std::set<Node>& termS
   Trace("nl-ext-mv-assert")
       << "Getting model values... check for [model-false]" << std::endl;
   // get the assertions that are false in the model
-  const std::vector<Node> false_asserts = checkModelEval(assertions);
+  const std::vector<Node> false_asserts = getUnsatisfiedAssertions(assertions);
   Trace("nl-ext") << "# false asserts = " << false_asserts.size() << std::endl;
 
   // get the extended terms belonging to this theory
@@ -502,37 +509,6 @@ Result::Sat NonlinearExtension::modelBasedRefinement(const std::set<Node>& termS
   return Result::Sat::SAT;
 }
 
-void NonlinearExtension::interceptModel(std::map<Node, Node>& arithModel,
-                                        const std::set<Node>& termSet)
-{
-  if (!needsCheckLastEffort())
-  {
-    // no non-linear constraints, we are done
-    return;
-  }
-  Trace("nl-ext") << "NonlinearExtension::interceptModel begin" << std::endl;
-  d_model.reset(d_containing.getValuation().getModel(), arithModel);
-  // run a last call effort check
-  Trace("nl-ext") << "interceptModel: do model-based refinement" << std::endl;
-  Result::Sat res = modelBasedRefinement(termSet);
-  if (res == Result::Sat::SAT)
-  {
-    Trace("nl-ext") << "interceptModel: do model repair" << std::endl;
-    d_approximations.clear();
-    d_witnesses.clear();
-    // modify the model values
-    d_model.getModelValueRepair(arithModel,
-                                d_approximations,
-                                d_witnesses,
-                                options().smt.modelWitnessValue);
-  }
-}
-
-void NonlinearExtension::presolve()
-{
-  Trace("nl-ext") << "NonlinearExtension::presolve" << std::endl;
-}
-
 void NonlinearExtension::runStrategy(Theory::Effort effort,
                                      const std::vector<Node>& assertions,
                                      const std::vector<Node>& false_asserts,
index f3d6522811a992c8fd4e25cbcdbb253a7de274bc..53e0db90efe94e002ce409dec050184f52929458 100644 (file)
@@ -91,54 +91,33 @@ class NonlinearExtension : EnvObj
    * Does non-context dependent setup for a node connected to a theory.
    */
   void preRegisterTerm(TNode n);
-  /** Check at effort level e.
-   *
-   * This call may result in (possibly multiple) calls to d_im.lemma(...)
-   * where d_im is the inference manager of TheoryArith.
-   *
-   * If e is FULL, then we add lemmas based on context-depedent
-   * simplification (see Reynolds et al FroCoS 2017).
-   *
-   * If e is LAST_CALL, we add lemmas based on model-based refinement
-   * (see additionally Cimatti et al., TACAS 2017). The lemmas added at this
-   * effort may be computed during a call to interceptModel as described below.
-   */
-  void check(Theory::Effort e);
-  /** intercept model
-   *
-   * This method is called during TheoryArith::collectModelInfo, which is
-   * invoked after the linear arithmetic solver passes a full effort check
-   * with no lemmas.
+
+  /**
+   * Performs the main checks for nonlinear arithmetic, based on the current
+   * (linear) arithmetic model from `arithModel`. This method may already send
+   * lemmas, but most lemmas are stored and only sent when finalizeModel
+   * is called.
    *
    * The argument arithModel is a map of the form { v1 -> c1, ..., vn -> cn }
    * which represents the linear arithmetic theory solver's contribution to the
-   * current candidate model. That is, its collectModelInfo method is requesting
-   * that equalities v1 = c1, ..., vn = cn be added to the current model, where
-   * v1, ..., vn are arithmetic variables and c1, ..., cn are constants. Notice
-   * arithmetic variables may be real-valued terms belonging to other theories,
-   * or abstractions of applications of multiplication (kind NONLINEAR_MULT).
+   * current candidate model where v1, ..., vn are arithmetic variables and
+   * c1, ..., cn are constants. Note, that arithmetic variables may be
+   * real-valued terms belonging to other theories, or abstractions of
+   * applications of multiplication (kind NONLINEAR_MULT).
    *
-   * This method requests that the non-linear solver inspect this model and
-   * do any number of the following:
-   * (1) Construct lemmas based on a model-based refinement procedure inspired
-   * by Cimatti et al., TACAS 2017.,
-   * (2) In the case that the nonlinear solver finds that the current
-   * constraints are satisfiable, it may "repair" the values in the argument
-   * arithModel so that it satisfies certain nonlinear constraints. This may
-   * involve e.g. solving for variables in nonlinear equations.
+   * The argument termSet is the set of terms that is currently appearing in the
+   * assertions.
    */
-  void interceptModel(std::map<Node, Node>& arithModel,
-                      const std::set<Node>& termSet);
-  /** Does this class need a call to check(...) at last call effort? */
-  bool needsCheckLastEffort() const { return d_needsLastCall; }
-  /** presolve
-   *
-   * This function is called during TheoryArith's presolve command.
-   * In this function, we send lemmas we accumulated during preprocessing,
-   * for instance, definitional lemmas from expandDefinitions are sent out
-   * on the output channel of TheoryArith in this function.
+  void checkFullEffort(std::map<Node, Node>& arithModel,
+                       const std::set<Node>& termSet);
+
+  /**
+   * Finalize the given model by adding approximations and witnesses.
    */
-  void presolve();
+  void finalizeModel(TheoryModel* tm);
+
+  /** Does this class need a call to check(...) at last call effort? */
+  bool hasNlTerms() const { return d_hasNlTerms; }
 
   /** Process side effect se */
   void processSideEffect(const NlLemma& se);
@@ -179,7 +158,8 @@ class NonlinearExtension : EnvObj
    * whose model value cannot be computed is included in the return value of
    * this function.
    */
-  std::vector<Node> checkModelEval(const std::vector<Node>& assertions);
+  std::vector<Node> getUnsatisfiedAssertions(
+      const std::vector<Node>& assertions);
 
   //---------------------------check model
   /** Check model
@@ -227,7 +207,7 @@ class NonlinearExtension : EnvObj
   /** The statistics class */
   NlStats d_stats;
   // needs last call effort
-  bool d_needsLastCall;
+  bool d_hasNlTerms;
   /**
    * The number of times we have the called main check method
    * (modelBasedRefinement). This counter is used for interleaving strategies.
index 03fb06a965e2b28fcc158e7fc7780840c5010e86..5a2d1a397827d905cfc694b0edc0fbd1332c45b9 100644 (file)
@@ -167,7 +167,15 @@ void TheoryArith::postCheck(Effort level)
   {
     if (d_nonlinearExtension != nullptr)
     {
-      d_nonlinearExtension->check(level);
+      // If we computed lemmas in the last FULL_EFFORT check, send them now.
+      if (d_im.hasPendingLemma())
+      {
+        d_im.doPendingFacts();
+        d_im.doPendingLemmas();
+        d_im.doPendingPhaseRequirements();
+        return;
+      }
+      d_nonlinearExtension->finalizeModel(getValuation().getModel());
     }
     return;
   }
@@ -189,8 +197,7 @@ void TheoryArith::postCheck(Effort level)
     {
       std::set<Node> termSet;
       updateModelCache(termSet);
-      d_nonlinearExtension->check(level);
-      d_nonlinearExtension->interceptModel(d_arithModelCache, termSet);
+      d_nonlinearExtension->checkFullEffort(d_arithModelCache, termSet);
     }
     else if (d_internal->foundNonlinear())
     {
@@ -223,7 +230,7 @@ bool TheoryArith::preNotifyFact(
 bool TheoryArith::needsCheckLastEffort() {
   if (d_nonlinearExtension != nullptr)
   {
-    return d_nonlinearExtension->needsCheckLastEffort();
+    return d_nonlinearExtension->hasNlTerms();
   }
   return false;
 }
@@ -313,10 +320,6 @@ void TheoryArith::notifyRestart(){
 
 void TheoryArith::presolve(){
   d_internal->presolve();
-  if (d_nonlinearExtension != nullptr)
-  {
-    d_nonlinearExtension->presolve();
-  }
 }
 
 EqualityStatus TheoryArith::getEqualityStatus(TNode a, TNode b) {