Refactor collectModelInfo in TheoryArith (#5027)
authorAndrew Reynolds <andrew.j.reynolds@gmail.com>
Wed, 16 Sep 2020 15:21:40 +0000 (10:21 -0500)
committerGitHub <noreply@github.com>
Wed, 16 Sep 2020 15:21:40 +0000 (10:21 -0500)
This is work towards updating the arithmetic solver to the new standard, and in particular isolating TheoryArithPrivate as the "linear solver", and TheoryArith as the overall approach for arithmetic.

This transfers ownership of the non-linear extension from TheoryArithPrivate to TheoryArith. The former still has a pointer to the non-linear extension, which will be removed with further refactoring.

This PR additionally moves the code that handles the interplay of the non-linear extension in TheoryArithPrivate::collectModelInfo to TheoryArith, and simplifies the model interface for TheoryArithPrivate.

src/theory/arith/theory_arith.cpp
src/theory/arith/theory_arith.h
src/theory/arith/theory_arith_private.cpp
src/theory/arith/theory_arith_private.h

index 1436198a892e03854c584b904633e5bfe61e7178..4884d84842448ba1e8864bd59dc907cb50afad9c 100644 (file)
@@ -21,6 +21,7 @@
 #include "smt/smt_statistics_registry.h"
 #include "theory/arith/arith_rewriter.h"
 #include "theory/arith/infer_bounds.h"
+#include "theory/arith/nl/nonlinear_extension.h"
 #include "theory/arith/theory_arith_private.h"
 #include "theory/ext_theory.h"
 
@@ -42,7 +43,8 @@ TheoryArith::TheoryArith(context::Context* c,
           new TheoryArithPrivate(*this, c, u, out, valuation, logicInfo, pnm)),
       d_ppRewriteTimer("theory::arith::ppRewriteTimer"),
       d_astate(*d_internal, c, u, valuation),
-      d_inferenceManager(*this, d_astate, pnm)
+      d_inferenceManager(*this, d_astate, pnm),
+      d_nonlinearExtension(nullptr)
 {
   smtStatisticsRegistry()->registerStat(&d_ppRewriteTimer);
 
@@ -76,6 +78,13 @@ void TheoryArith::finishInit()
     d_valuation.setUnevaluatedKind(kind::SINE);
     d_valuation.setUnevaluatedKind(kind::PI);
   }
+  // only need to create nonlinear extension if non-linear logic
+  const LogicInfo& logicInfo = getLogicInfo();
+  if (logicInfo.isTheoryEnabled(THEORY_ARITH) && !logicInfo.isLinear())
+  {
+    d_nonlinearExtension.reset(
+        new nl::NonlinearExtension(*this, d_equalityEngine));
+  }
   // finish initialize internally
   d_internal->finishInit();
 }
@@ -123,7 +132,53 @@ void TheoryArith::propagate(Effort e) {
 }
 bool TheoryArith::collectModelInfo(TheoryModel* m)
 {
-  return d_internal->collectModelInfo(m);
+  std::set<Node> termSet;
+  // Work out which variables are needed
+  const std::set<Kind>& irrKinds = m->getIrrelevantKinds();
+  computeAssertedTerms(termSet, irrKinds);
+  // this overrides behavior to not assert equality engine
+  return collectModelValues(m, termSet);
+}
+
+bool TheoryArith::collectModelValues(TheoryModel* m,
+                                     const std::set<Node>& termSet)
+{
+  // get the model from the linear solver
+  std::map<Node, Node> arithModel;
+  d_internal->collectModelValues(termSet, arithModel);
+  // if non-linear is enabled, intercept the model, which may repair its values
+  if (d_nonlinearExtension != nullptr)
+  {
+    // Non-linear may repair values to satisfy non-linear constraints (see
+    // documentation for NonlinearExtension::interceptModel).
+    d_nonlinearExtension->interceptModel(arithModel);
+  }
+  // We are now ready to assert the model.
+  for (const std::pair<const Node, Node>& p : arithModel)
+  {
+    // maps to constant of comparable type
+    Assert(p.first.getType().isComparableTo(p.second.getType()));
+    Assert(p.second.isConst());
+    if (m->assertEquality(p.first, p.second, true))
+    {
+      continue;
+    }
+    // If we failed to assert an equality, it is likely due to theory
+    // combination, namely the repaired model for non-linear changed
+    // an equality status that was agreed upon by both (linear) arithmetic
+    // and another theory. In this case, we must add a lemma, or otherwise
+    // we would terminate with an invalid model. Thus, we add a splitting
+    // lemma of the form ( x = v V x != v ) where v is the model value
+    // assigned by the non-linear solver to x.
+    if (d_nonlinearExtension != nullptr)
+    {
+      Node eq = p.first.eqNode(p.second);
+      Node lem = NodeManager::currentNM()->mkNode(kind::OR, eq, eq.negate());
+      d_out->lemma(lem);
+    }
+    return false;
+  }
+  return true;
 }
 
 void TheoryArith::notifyRestart(){
index 4851f1c5d131b7a701786fd8ff17d0ac7fb1c420..30ad724cca4dff3ffd007d7bd393f5a64abdeb00 100644 (file)
 
 namespace CVC4 {
 namespace theory {
-
 namespace arith {
 
+namespace nl {
+class NonlinearExtension;
+}
+
 /**
- * Implementation of QF_LRA.
- * Based upon:
+ * Implementation of linear and non-linear integer and real arithmetic.
+ * The linear arithmetic solver is based upon:
  * http://research.microsoft.com/en-us/um/people/leonardo/cav06.pdf
  */
 class TheoryArith : public Theory {
@@ -78,6 +81,11 @@ class TheoryArith : public Theory {
   TrustNode explain(TNode n) override;
 
   bool collectModelInfo(TheoryModel* m) override;
+  /**
+   * Collect model values in m based on the relevant terms given by termSet.
+   */
+  bool collectModelValues(TheoryModel* m,
+                          const std::set<Node>& termSet) override;
 
   void shutdown() override {}
 
@@ -110,6 +118,11 @@ class TheoryArith : public Theory {
   /** The arith::InferenceManager. */
   InferenceManager d_inferenceManager;
 
+  /**
+   * The non-linear extension, responsible for all approaches for non-linear
+   * arithmetic.
+   */
+  std::unique_ptr<nl::NonlinearExtension> d_nonlinearExtension;
 };/* class TheoryArith */
 
 }/* CVC4::theory::arith namespace */
index 1b49b73508efa2e4d7c4a661aa157cbcb20ee8a8..8595e26b53fa7dcc1e6f4f375cd1108376b5f455 100644 (file)
@@ -164,7 +164,6 @@ TheoryArithPrivate::TheoryArithPrivate(TheoryArith& containing,
 TheoryArithPrivate::~TheoryArithPrivate(){
   if(d_treeLog != NULL){ delete d_treeLog; }
   if(d_approxStats != NULL) { delete d_approxStats; }
-  if(d_nonlinearExtension != NULL) { delete d_nonlinearExtension; }
 }
 
 TheoryRewriter* TheoryArithPrivate::getTheoryRewriter() { return &d_rewriter; }
@@ -177,12 +176,7 @@ void TheoryArithPrivate::finishInit()
   eq::EqualityEngine* ee = d_containing.getEqualityEngine();
   Assert(ee != nullptr);
   d_congruenceManager.finishInit(ee);
-  const LogicInfo& logicInfo = getLogicInfo();
-  // only need to create nonlinear extension if non-linear logic
-  if (logicInfo.isTheoryEnabled(THEORY_ARITH) && !logicInfo.isLinear())
-  {
-    d_nonlinearExtension = new nl::NonlinearExtension(d_containing, ee);
-  }
+  d_nonlinearExtension = d_containing.d_nonlinearExtension.get();
 }
 
 static bool contains(const ConstraintCPVec& v, ConstraintP con){
@@ -4074,7 +4068,8 @@ Rational TheoryArithPrivate::deltaValueForTotalOrder() const{
   return belowMin;
 }
 
-bool TheoryArithPrivate::collectModelInfo(TheoryModel* m)
+void TheoryArithPrivate::collectModelValues(const std::set<Node>& termSet,
+                                            std::map<Node, Node>& arithModel)
 {
   AlwaysAssert(d_qflraStatus == Result::SAT);
   //AlwaysAssert(!d_nlIncomplete, "Arithmetic solver cannot currently produce models for input with nonlinear arithmetic constraints");
@@ -4085,10 +4080,6 @@ bool TheoryArithPrivate::collectModelInfo(TheoryModel* m)
 
   Debug("arith::collectModelInfo") << "collectModelInfo() begin " << endl;
 
-  std::set<Node> termSet;
-  const std::set<Kind>& irrKinds = m->getIrrelevantKinds();
-  d_containing.computeAssertedTerms(termSet, irrKinds, true);
-
   // Delta lasts at least the duration of the function call
   const Rational& delta = d_partialModel.getDelta();
   std::unordered_set<TNode, TNodeHashFunction> shared = d_containing.currentlySharedTerms();
@@ -4096,8 +4087,6 @@ bool TheoryArithPrivate::collectModelInfo(TheoryModel* m)
   // TODO:
   // This is not very good for user push/pop....
   // Revisit when implementing push/pop
-  // Map of terms to values, constructed when non-linear arithmetic is active.
-  std::map<Node, Node> arithModel;
   for(var_iterator vi = var_begin(), vend = var_end(); vi != vend; ++vi){
     ArithVar v = *vi;
 
@@ -4112,56 +4101,20 @@ bool TheoryArithPrivate::collectModelInfo(TheoryModel* m)
 
         Node qNode = mkRationalNode(qmodel);
         Debug("arith::collectModelInfo") << "m->assertEquality(" << term << ", " << qmodel << ", true)" << endl;
-        if (d_nonlinearExtension != nullptr)
-        {
-          // Let non-linear extension inspect the values before they are sent
-          // to the theory model.
-          arithModel[term] = qNode;
-        }
-        else
-        {
-          if (!m->assertEquality(term, qNode, true))
-          {
-            return false;
-          }
-        }
+        // Add to the map
+        arithModel[term] = qNode;
       }else{
         Debug("arith::collectModelInfo") << "Skipping m->assertEquality(" << term << ", true)" << endl;
 
       }
     }
   }
-  if (d_nonlinearExtension != nullptr)
-  {
-    // Non-linear may repair values to satisfy non-linear constraints (see
-    // documentation for NonlinearExtension::interceptModel).
-    d_nonlinearExtension->interceptModel(arithModel);
-    // We are now ready to assert the model.
-    for (std::pair<const Node, Node>& p : arithModel)
-    {
-      if (!m->assertEquality(p.first, p.second, true))
-      {
-        // If we failed to assert an equality, it is likely due to theory
-        // combination, namely the repaired model for non-linear changed
-        // an equality status that was agreed upon by both (linear) arithmetic
-        // and another theory. In this case, we must add a lemma, or otherwise
-        // we would terminate with an invalid model. Thus, we add a splitting
-        // lemma of the form ( x = v V x != v ) where v is the model value
-        // assigned by the non-linear solver to x.
-        Node eq = p.first.eqNode(p.second);
-        Node lem = NodeManager::currentNM()->mkNode(kind::OR, eq, eq.negate());
-        d_containing.d_out->lemma(lem);
-        return false;
-      }
-    }
-  }
 
   // Iterate over equivalence classes in LinearEqualityModule
   // const eq::EqualityEngine& ee = d_congruenceManager.getEqualityEngine();
   // m->assertEqualityEngine(&ee);
 
   Debug("arith::collectModelInfo") << "collectModelInfo() end " << endl;
-  return true;
 }
 
 bool TheoryArithPrivate::safeToReset() const {
index d0428f2ef6447587ad66b1f684517f89eba96d4b..6d030dece8c46c5e635c9ee358d69f62e7fed994 100644 (file)
@@ -371,7 +371,7 @@ public:
   FCSimplexDecisionProcedure d_fcSimplex;
   SumOfInfeasibilitiesSPD d_soiSimplex;
   AttemptSolutionSDP d_attemptSolSimplex;
-  
+
   /** non-linear algebraic approach */
   nl::NonlinearExtension* d_nonlinearExtension;
 
@@ -456,6 +456,18 @@ public:
   Rational deltaValueForTotalOrder() const;
 
   bool collectModelInfo(TheoryModel* m);
+  /**
+   * Collect model values. This is the main method for extracting information
+   * about how to construct the model. This method relies on the caller for
+   * processing the map, which is done so that other modules (e.g. the
+   * non-linear extension) can modify arithModel before it is sent to the model.
+   *
+   * @param termSet The set of relevant terms
+   * @param arithModel Mapping from terms (of real type) to their values. The
+   * caller should assert equalities to the model for each entry in this map.
+   */
+  void collectModelValues(const std::set<Node>& termSet,
+                          std::map<Node, Node>& arithModel);
 
   void shutdown(){ }