Refactor non-linear extension for model-based refinement (#3452)
authorAndrew Reynolds <andrew.j.reynolds@gmail.com>
Wed, 13 Nov 2019 05:34:05 +0000 (23:34 -0600)
committerAhmed Irfan <43099566+ahmed-irfan@users.noreply.github.com>
Wed, 13 Nov 2019 05:34:05 +0000 (21:34 -0800)
* Refactor non-linear extension for model-based refinement

* Format

* Minor

* Address

src/theory/arith/nl_model.cpp
src/theory/arith/nl_model.h
src/theory/arith/nonlinear_extension.cpp
src/theory/arith/nonlinear_extension.h

index 571fbda6c3ca8a2f1008c922ed6618a655776aef..62bdf310bec42acffebd0a37934c8fbe91c2654d 100644 (file)
@@ -137,38 +137,16 @@ Node NlModel::computeModelValue(Node n, bool isConcrete)
 Node NlModel::getValueInternal(Node n) const
 {
   return d_model->getValue(n);
-  /*
-  std::map< Node, Node >::const_iterator it = d_arithVal.find(n);
-  if (it!=d_arithVal.end())
-  {
-    return it->second;
-  }
-  return Node::null();
-  */
 }
 
 bool NlModel::hasTerm(Node n) const
 {
   return d_model->hasTerm(n);
-  // return d_arithVal.find(n)!=d_arithVal.end();
 }
 
 Node NlModel::getRepresentative(Node n) const
 {
   return d_model->getRepresentative(n);
-  /*
-  std::map< Node, Node >::const_iterator it = d_arithVal.find(n);
-  if (it!=d_arithVal.end())
-  {
-    std::map< Node, Node >::const_iterator itr = d_valToRep.find(it->second);
-    if (itr != d_valToRep.end())
-    {
-      return itr->second;
-    }
-    Assert(false);
-  }
-  return Node::null();
-  */
 }
 
 int NlModel::compare(Node i, Node j, bool isConcrete, bool isAbsolute)
@@ -1225,12 +1203,37 @@ bool NlModel::getApproximateSqrt(Node c, Node& l, Node& u, unsigned iter) const
   return true;
 }
 
-void NlModel::recordApproximations()
+void NlModel::printModelValue(const char* c, Node n, unsigned prec) const
+{
+  if (Trace.isOn(c))
+  {
+    Trace(c) << "  " << n << " -> ";
+    for (int i = 1; i >= 0; --i)
+    {
+      std::map<Node, Node>::const_iterator it = d_mv[i].find(n);
+      Assert(it != d_mv[i].end());
+      if (it->second.isConst())
+      {
+        printRationalApprox(c, it->second, prec);
+      }
+      else
+      {
+        Trace(c) << "?";
+      }
+      Trace(c) << (i == 1 ? " [actual: " : " ]");
+    }
+    Trace(c) << std::endl;
+  }
+}
+
+void NlModel::getModelValueRepair(std::map<Node, Node>& arithModel,
+                                  std::map<Node, Node>& approximations)
 {
   // Record the approximations we used. This code calls the
   // recordApproximation method of the model, which overrides the model
   // values for variables that we solved for, using techniques specific to
   // this class.
+  Trace("nl-model") << "NlModel::getModelValueRepair:" << std::endl;
   NodeManager* nm = NodeManager::currentNM();
   for (const std::pair<const Node, std::pair<Node, Node> >& cb :
        d_check_model_bounds)
@@ -1241,17 +1244,15 @@ void NlModel::recordApproximations()
     Node v = cb.first;
     if (l != u)
     {
-      pred = nm->mkNode(AND, nm->mkNode(GEQ, v, l), nm->mkNode(GEQ, u, v));
-    }
-    else if (!d_model->areEqual(v, l))
-    {
-      // only record if value was not equal already
-      pred = v.eqNode(l);
+      Node pred = nm->mkNode(AND, nm->mkNode(GEQ, v, l), nm->mkNode(GEQ, u, v));
+      approximations[v] = pred;
+      Trace("nl-model") << v << " approximated as " << pred << std::endl;
     }
-    if (!pred.isNull())
+    else
     {
-      pred = Rewriter::rewrite(pred);
-      d_model->recordApproximation(v, pred);
+      // overwrite
+      arithModel[v] = l;
+      Trace("nl-model") << v << " exact approximation is " << l << std::endl;
     }
   }
   // Also record the exact values we used. An exact value can be seen as a
@@ -1262,36 +1263,12 @@ void NlModel::recordApproximations()
   {
     Node v = d_check_model_vars[i];
     Node s = d_check_model_subs[i];
-    if (!d_model->areEqual(v, s))
-    {
-      Node pred = v.eqNode(s);
-      pred = Rewriter::rewrite(pred);
-      d_model->recordApproximation(v, pred);
-    }
-  }
-}
-void NlModel::printModelValue(const char* c, Node n, unsigned prec) const
-{
-  if (Trace.isOn(c))
-  {
-    Trace(c) << "  " << n << " -> ";
-    for (unsigned i = 0; i < 2; i++)
-    {
-      std::map<Node, Node>::const_iterator it = d_mv[1 - i].find(n);
-      Assert(it != d_mv[1 - i].end());
-      if (it->second.isConst())
-      {
-        printRationalApprox(c, it->second, prec);
-      }
-      else
-      {
-        Trace(c) << "?";  // it->second;
-      }
-      Trace(c) << (i == 0 ? " [actual: " : " ]");
-    }
-    Trace(c) << std::endl;
+    // overwrite
+    arithModel[v] = s;
+    Trace("nl-model") << v << " solved is " << s << std::endl;
   }
 }
+
 }  // namespace arith
 }  // namespace theory
 }  // namespace CVC4
index 19341cc5f6f390eaaaab3f9d0bca40fcb545e33f..ed13327cc27f289bfe950f0aa800248722414999 100644 (file)
@@ -171,12 +171,19 @@ class NlModel
    * term n on Trace c with precision prec.
    */
   void printModelValue(const char* c, Node n, unsigned prec = 5) const;
-  /** record approximations in the current model
+  /** get model value repair
    *
-   * This makes necessary calls that notify the model of any approximations
-   * that were used by this solver.
+   * This gets mappings that indicate how to repair the model generated by the
+   * linear arithmetic solver. This method should be called after a successful
+   * call to checkModel above.
+   *
+   * The mapping arithModel is updated by this method to map arithmetic terms v
+   * to their (exact) value that was computed during checkModel; the mapping
+   * approximations is updated to store approximate values in the form of a
+   * predicate over v.
    */
-  void recordApproximations();
+  void getModelValueRepair(std::map<Node, Node>& arithModel,
+                           std::map<Node, Node>& approximations);
 
  private:
   /** The current model */
index e14f07a4fc3bf914d9f35e32e82224b44939db04..d76089541255ad72c54291810a7ea215b6c99a5f 100644 (file)
@@ -944,8 +944,7 @@ int NonlinearExtension::checkLastCall(const std::vector<Node>& assertions,
         std::vector<Node> repList;
         for (const Node& ac : a)
         {
-          Node r =
-              d_containing.getValuation().getModel()->getRepresentative(ac);
+          Node r = d_model.computeConcreteModelValue(ac);
           repList.push_back(r);
         }
         Node aa = argTrie[ak].add(a, repList);
@@ -1232,11 +1231,7 @@ void NonlinearExtension::check(Theory::Effort e) {
   Trace("nl-ext") << std::endl;
   Trace("nl-ext") << "NonlinearExtension::check, effort = " << e
                   << ", built model = " << d_builtModel.get() << std::endl;
-  if (d_builtModel.get())
-  {
-    // already built model, nothing to do
-  }
-  else if (e == Theory::EFFORT_FULL)
+  if (e == Theory::EFFORT_FULL)
   {
     d_containing.getExtTheory()->clearCache();
     d_needsLastCall = true;
@@ -1255,198 +1250,210 @@ void NonlinearExtension::check(Theory::Effort e) {
   }
   else
   {
-    // get the assertions
-    std::vector<Node> assertions;
-    getAssertions(assertions);
-
-    // reset cached information
+    std::map<Node, Node> approximations;
+    std::map<Node, Node> arithModel;
     TheoryModel* tm = d_containing.getValuation().getModel();
-    d_model.reset(tm);
-
-    Trace("nl-ext-mv-assert")
-        << "Getting model values... check for [model-false]" << std::endl;
-    // get the assertions that are false in the model
-    const std::vector<Node> false_asserts = checkModelEval(assertions);
-
-    // get the extended terms belonging to this theory
-    std::vector<Node> xts;
-    d_containing.getExtTheory()->getTerms(xts);
-
-    if (Trace.isOn("nl-ext-debug"))
-    {
-      Trace("nl-ext-debug")
-          << "  processing NonlinearExtension::check : " << std::endl;
-      Trace("nl-ext-debug") << "     " << false_asserts.size()
-                            << " false assertions" << std::endl;
-      Trace("nl-ext-debug")
-          << "     " << xts.size() << " extended terms: " << std::endl;
-      Trace("nl-ext-debug") << "       ";
-      for (unsigned j = 0; j < xts.size(); j++)
-      {
-        Trace("nl-ext-debug") << xts[j] << " ";
-      }
-      Trace("nl-ext-debug") << std::endl;
-    }
-
-    // compute whether shared terms have correct values
-    unsigned num_shared_wrong_value = 0;
-    std::vector<Node> shared_term_value_splits;
-    // must ensure that shared terms are equal to their concrete value
-    Trace("nl-ext-mv") << "Shared terms : " << std::endl;
-    for (context::CDList<TNode>::const_iterator its =
-             d_containing.shared_terms_begin();
-         its != d_containing.shared_terms_end();
-         ++its)
-    {
-      TNode shared_term = *its;
-      // compute its value in the model, and its evaluation in the model
-      Node stv0 = d_model.computeConcreteModelValue(shared_term);
-      Node stv1 = d_model.computeAbstractModelValue(shared_term);
-      d_model.printModelValue("nl-ext-mv", shared_term);
-      if (stv0 != stv1)
+    if (!d_builtModel.get())
+    {
+      // run model-based refinement
+      if (modelBasedRefinement())
       {
-        num_shared_wrong_value++;
-        Trace("nl-ext-mv") << "Bad shared term value : " << shared_term
-                           << std::endl;
-        if (shared_term != stv0)
-        {
-          // split on the value, this is non-terminating in general, TODO :
-          // improve this
-          Node eq = shared_term.eqNode(stv0);
-          shared_term_value_splits.push_back(eq);
-        }
-        else
-        {
-          // this can happen for transcendental functions
-          // the problem is that we cannot evaluate transcendental functions
-          // (they don't have a rewriter that returns constants)
-          // thus, the actual value in their model can be themselves, hence we
-          // have no reference point to rule out the current model.  In this
-          // case, we may set incomplete below.
-        }
+        return;
       }
     }
-    Trace("nl-ext-debug") << "     " << num_shared_wrong_value
-                          << " shared terms with wrong model value."
-                          << std::endl;
-    bool needsRecheck;
-    do
+    // get the values that should be replaced in the model
+    d_model.getModelValueRepair(arithModel, approximations);
+    // those that are exact are written as exact approximations to the model
+    for (std::pair<const Node, Node>& r : arithModel)
+    {
+      Node eq = r.first.eqNode(r.second);
+      eq = Rewriter::rewrite(eq);
+      tm->recordApproximation(r.first, eq);
+    }
+    // those that are approximate are recorded as approximations
+    for (std::pair<const Node, Node>& a : approximations)
     {
-      d_model.resetCheck();
-      needsRecheck = false;
-      Assert(e == Theory::EFFORT_LAST_CALL);
-      // complete_status:
-      //   1 : we may answer SAT, -1 : we may not answer SAT, 0 : unknown
-      int complete_status = 1;
-      int num_added_lemmas = 0;
-      // we require a check either if an assertion is false or a shared term has
-      // a wrong value
-      if (!false_asserts.empty() || num_shared_wrong_value > 0)
+      tm->recordApproximation(a.first, a.second);
+    }
+  }
+}
+
+bool NonlinearExtension::modelBasedRefinement()
+{
+  // reset the model object
+  d_model.reset(d_containing.getValuation().getModel());
+  // get the assertions
+  std::vector<Node> assertions;
+  getAssertions(assertions);
+
+  Trace("nl-ext-mv-assert")
+      << "Getting model values... check for [model-false]" << std::endl;
+  // get the assertions that are false in the model
+  const std::vector<Node> false_asserts = checkModelEval(assertions);
+
+  // get the extended terms belonging to this theory
+  std::vector<Node> xts;
+  d_containing.getExtTheory()->getTerms(xts);
+
+  if (Trace.isOn("nl-ext-debug"))
+  {
+    Trace("nl-ext-debug") << "  processing NonlinearExtension::check : "
+                          << std::endl;
+    Trace("nl-ext-debug") << "     " << false_asserts.size()
+                          << " false assertions" << std::endl;
+    Trace("nl-ext-debug") << "     " << xts.size()
+                          << " extended terms: " << std::endl;
+    Trace("nl-ext-debug") << "       ";
+    for (unsigned j = 0; j < xts.size(); j++)
+    {
+      Trace("nl-ext-debug") << xts[j] << " ";
+    }
+    Trace("nl-ext-debug") << std::endl;
+  }
+
+  // compute whether shared terms have correct values
+  unsigned num_shared_wrong_value = 0;
+  std::vector<Node> shared_term_value_splits;
+  // must ensure that shared terms are equal to their concrete value
+  Trace("nl-ext-mv") << "Shared terms : " << std::endl;
+  for (context::CDList<TNode>::const_iterator its =
+           d_containing.shared_terms_begin();
+       its != d_containing.shared_terms_end();
+       ++its)
+  {
+    TNode shared_term = *its;
+    // compute its value in the model, and its evaluation in the model
+    Node stv0 = d_model.computeConcreteModelValue(shared_term);
+    Node stv1 = d_model.computeAbstractModelValue(shared_term);
+    d_model.printModelValue("nl-ext-mv", shared_term);
+    if (stv0 != stv1)
+    {
+      num_shared_wrong_value++;
+      Trace("nl-ext-mv") << "Bad shared term value : " << shared_term
+                         << std::endl;
+      if (shared_term != stv0)
       {
-        complete_status = num_shared_wrong_value > 0 ? -1 : 0;
-        num_added_lemmas = checkLastCall(assertions, false_asserts, xts);
-        if (num_added_lemmas > 0)
-        {
-          return;
-        }
+        // split on the value, this is non-terminating in general, TODO :
+        // improve this
+        Node eq = shared_term.eqNode(stv0);
+        shared_term_value_splits.push_back(eq);
       }
-      Trace("nl-ext") << "Finished check with status : " << complete_status
-                      << std::endl;
+      else
+      {
+        // this can happen for transcendental functions
+        // the problem is that we cannot evaluate transcendental functions
+        // (they don't have a rewriter that returns constants)
+        // thus, the actual value in their model can be themselves, hence we
+        // have no reference point to rule out the current model.  In this
+        // case, we may set incomplete below.
+      }
+    }
+  }
+  Trace("nl-ext-debug") << "     " << num_shared_wrong_value
+                        << " shared terms with wrong model value." << std::endl;
+  bool needsRecheck;
+  do
+  {
+    d_model.resetCheck();
+    needsRecheck = false;
+    // complete_status:
+    //   1 : we may answer SAT, -1 : we may not answer SAT, 0 : unknown
+    int complete_status = 1;
+    int num_added_lemmas = 0;
+    // we require a check either if an assertion is false or a shared term has
+    // a wrong value
+    if (!false_asserts.empty() || num_shared_wrong_value > 0)
+    {
+      complete_status = num_shared_wrong_value > 0 ? -1 : 0;
+      num_added_lemmas = checkLastCall(assertions, false_asserts, xts);
+      if (num_added_lemmas > 0)
+      {
+        return true;
+      }
+    }
+    Trace("nl-ext") << "Finished check with status : " << complete_status
+                    << std::endl;
 
-      // if we did not add a lemma during check and there is a chance for SAT
-      if (complete_status == 0)
+    // if we did not add a lemma during check and there is a chance for SAT
+    if (complete_status == 0)
+    {
+      Trace("nl-ext")
+          << "Check model based on bounds for irrational-valued functions..."
+          << std::endl;
+      // check the model based on simple solving of equalities and using
+      // error bounds on the Taylor approximation of transcendental functions.
+      if (checkModel(assertions, false_asserts))
       {
-        Trace("nl-ext")
-            << "Check model based on bounds for irrational-valued functions..."
-            << std::endl;
-        // check the model based on simple solving of equalities and using
-        // error bounds on the Taylor approximation of transcendental functions.
-        if (checkModel(assertions, false_asserts))
-        {
-          complete_status = 1;
-        }
+        complete_status = 1;
       }
+    }
 
-      // if we have not concluded SAT
-      if (complete_status != 1)
+    // if we have not concluded SAT
+    if (complete_status != 1)
+    {
+      // flush the waiting lemmas
+      num_added_lemmas = flushLemmas(d_waiting_lemmas);
+      if (num_added_lemmas > 0)
       {
-        // flush the waiting lemmas
-        num_added_lemmas = flushLemmas(d_waiting_lemmas);
-        if (num_added_lemmas > 0)
-        {
-          Trace("nl-ext") << "...added " << num_added_lemmas
-                          << " waiting lemmas." << std::endl;
-          return;
-        }
-        // resort to splitting on shared terms with their model value
-        // if we did not add any lemmas
-        if (num_shared_wrong_value > 0)
+        Trace("nl-ext") << "...added " << num_added_lemmas << " waiting lemmas."
+                        << std::endl;
+        return true;
+      }
+      // resort to splitting on shared terms with their model value
+      // if we did not add any lemmas
+      if (num_shared_wrong_value > 0)
+      {
+        complete_status = -1;
+        if (!shared_term_value_splits.empty())
         {
-          complete_status = -1;
-          if (!shared_term_value_splits.empty())
+          std::vector<Node> shared_term_value_lemmas;
+          for (const Node& eq : shared_term_value_splits)
           {
-            std::vector<Node> shared_term_value_lemmas;
-            for (const Node& eq : shared_term_value_splits)
-            {
-              Node req = Rewriter::rewrite(eq);
-              Node literal = d_containing.getValuation().ensureLiteral(req);
-              d_containing.getOutputChannel().requirePhase(literal, true);
-              Trace("nl-ext-debug") << "Split on : " << literal << std::endl;
-              shared_term_value_lemmas.push_back(
-                  literal.orNode(literal.negate()));
-            }
-            num_added_lemmas = flushLemmas(shared_term_value_lemmas);
-            if (num_added_lemmas > 0)
-            {
-              Trace("nl-ext")
-                  << "...added " << num_added_lemmas
-                  << " shared term value split lemmas." << std::endl;
-              return;
-            }
+            Node req = Rewriter::rewrite(eq);
+            Node literal = d_containing.getValuation().ensureLiteral(req);
+            d_containing.getOutputChannel().requirePhase(literal, true);
+            Trace("nl-ext-debug") << "Split on : " << literal << std::endl;
+            shared_term_value_lemmas.push_back(
+                literal.orNode(literal.negate()));
           }
-          else
+          num_added_lemmas = flushLemmas(shared_term_value_lemmas);
+          if (num_added_lemmas > 0)
           {
-            // this can happen if we are trying to do theory combination with
-            // trancendental functions
-            // since their model value cannot even be computed exactly
+            Trace("nl-ext") << "...added " << num_added_lemmas
+                            << " shared term value split lemmas." << std::endl;
+            return true;
           }
         }
-
-        // we are incomplete
-        if (options::nlExtIncPrecision() && d_model.usedApproximate())
-        {
-          d_taylor_degree++;
-          needsRecheck = true;
-          // increase precision for PI?
-          // Difficult since Taylor series is very slow to converge
-          Trace("nl-ext") << "...increment Taylor degree to " << d_taylor_degree
-                          << std::endl;
-        }
         else
         {
-          Trace("nl-ext") << "...failed to send lemma in "
-                             "NonLinearExtension, set incomplete"
-                          << std::endl;
-          d_containing.getOutputChannel().setIncomplete();
+          // this can happen if we are trying to do theory combination with
+          // trancendental functions
+          // since their model value cannot even be computed exactly
         }
       }
 
-    } while (needsRecheck);
-  }
-
-  // Did we internally determine a model exists? If so, we need to record some
-  // information in the theory engine's model class.
-  if (d_builtModel.get())
-  {
-    if (e < Theory::EFFORT_LAST_CALL)
-    {
-      // don't need to build the model yet
-      return;
+      // we are incomplete
+      if (options::nlExtIncPrecision() && d_model.usedApproximate())
+      {
+        d_taylor_degree++;
+        needsRecheck = true;
+        // increase precision for PI?
+        // Difficult since Taylor series is very slow to converge
+        Trace("nl-ext") << "...increment Taylor degree to " << d_taylor_degree
+                        << std::endl;
+      }
+      else
+      {
+        Trace("nl-ext") << "...failed to send lemma in "
+                           "NonLinearExtension, set incomplete"
+                        << std::endl;
+        d_containing.getOutputChannel().setIncomplete();
+      }
     }
-    // record approximations in the model
-    d_model.recordApproximations();
-    return;
-  }
+  } while (needsRecheck);
+
+  // did not add lemmas
+  return false;
 }
 
 void NonlinearExtension::presolve()
index b76877414788a851846ccb1bc15423edf1ca0bdc..1690c933461fc8fe5c02ad3d26e8832688ebd7bf 100644 (file)
@@ -63,8 +63,7 @@ typedef std::map<Node, unsigned> NodeMultiset;
  *
  * It's main functionality is a check(...) method,
  * which is called by TheoryArithPrivate either:
- * (1) at full effort with no conflicts or lemmas emitted,
- * or
+ * (1) at full effort with no conflicts or lemmas emitted, or
  * (2) at last call effort.
  * In this method, this class calls d_out->lemma(...)
  * for valid arithmetic theory lemmas, based on the current set of assertions,
@@ -115,9 +114,8 @@ class NonlinearExtension {
                                       const std::vector<Node>& exp) const;
   /** Check at effort level e.
    *
-   * This call may result in (possibly multiple)
-   * calls to d_out->lemma(...) where d_out
-   * is the output channel of TheoryArith.
+   * This call may result in (possibly multiple) calls to d_out->lemma(...)
+   * where d_out is the output channel of TheoryArith.
    */
   void check(Theory::Effort e);
   /** Does this class need a call to check(...) at last call effort? */
@@ -131,6 +129,22 @@ class NonlinearExtension {
    */
   void presolve();
  private:
+  /** Model-based refinement
+   *
+   * This is the main entry point of this class for generating lemmas on the
+   * output channel of the theory of arithmetic.
+   *
+   * It is currently run at last call effort. It applies lemma schemas
+   * described in Reynolds et al. FroCoS 2017 that are based on ruling out
+   * the current candidate model.
+   *
+   * This function returns true if a lemma was sent out on the output
+   * channel of the theory of arithmetic. Otherwise, it returns false. In the
+   * latter case, the model object d_model may have information regarding
+   * how to construct a model, in the case that we determined the problem
+   * is satisfiable.
+   */
+  bool modelBasedRefinement();
   /** returns true if the multiset containing the
    * factors of monomial a is a subset of the multiset
    * containing the factors of monomial b.