Use new inference manager in transcendental solver (#5022)
authorGereon Kremer <gkremer@stanford.edu>
Thu, 17 Sep 2020 13:54:02 +0000 (15:54 +0200)
committerGitHub <noreply@github.com>
Thu, 17 Sep 2020 13:54:02 +0000 (08:54 -0500)
This refactors the transcendental solver to add lemmas to the new inference manager instead of using the old lemma collection scheme.

src/theory/arith/inference_manager.cpp
src/theory/arith/inference_manager.h
src/theory/arith/nl/nonlinear_extension.cpp
src/theory/arith/nl/transcendental_solver.cpp
src/theory/arith/nl/transcendental_solver.h

index d4c5d17c5ac0cf35fd9a501586e2964b6fd27e3d..5c1602a1a73ea09252bb110e28801edace2231f7 100644 (file)
@@ -32,6 +32,8 @@ InferenceManager::InferenceManager(TheoryArith& ta,
 void InferenceManager::addPendingArithLemma(std::unique_ptr<ArithLemma> lemma,
                                             bool isWaiting)
 {
+  Trace("arith::infman") << "Add " << lemma->d_inference << " " << lemma->d_node
+                         << (isWaiting ? " as waiting" : "") << std::endl;
   lemma->d_node = Rewriter::rewrite(lemma->d_node);
   if (hasCachedLemma(lemma->d_node, lemma->d_property))
   {
@@ -77,6 +79,9 @@ void InferenceManager::flushWaitingLemmas()
 {
   for (auto& lem : d_waitingLem)
   {
+    Trace("arith::infman") << "Flush waiting lemma to pending: "
+                           << lem->d_inference << " " << lem->d_node
+                           << std::endl;
     d_pendingLem.emplace_back(std::move(lem));
   }
   d_waitingLem.clear();
@@ -84,6 +89,8 @@ void InferenceManager::flushWaitingLemmas()
 
 void InferenceManager::addConflict(const Node& conf, InferenceId inftype)
 {
+  Trace("arith::infman") << "Adding conflict: " << inftype << " " << conf
+                         << std::endl;
   conflict(Rewriter::rewrite(conf));
 }
 
@@ -92,6 +99,11 @@ bool InferenceManager::hasUsed() const
   return hasSent() || hasPending();
 }
 
+bool InferenceManager::hasWaitingLemma() const
+{
+  return !d_waitingLem.empty();
+}
+
 std::size_t InferenceManager::numWaitingLemmas() const
 {
   return d_waitingLem.size();
index e1e386beca7182fc49a301bb99583a5471df7d6c..f4806cc9a0ff4453292d63bc95ca7486a9be2b29 100644 (file)
@@ -87,6 +87,9 @@ class InferenceManager : public InferenceManagerBuffered
    */
   bool hasUsed() const;
 
+  /** Checks whether we have waiting lemmas. */
+  bool hasWaitingLemma() const;
+
   /** Returns the number of pending lemmas. */
   std::size_t numWaitingLemmas() const;
 
index 3bf547cebc552fa3bc0330ea068f6d5c9d94758c..df3a304be709c0b01cfcf24c60c5a1d29630bbea 100644 (file)
@@ -44,7 +44,7 @@ NonlinearExtension::NonlinearExtension(TheoryArith& containing,
                   containing.getUserContext(),
                   containing.getOutputChannel()),
       d_model(containing.getSatContext()),
-      d_trSlv(d_model),
+      d_trSlv(d_im, d_model),
       d_nlSlv(containing, d_model),
       d_cadSlv(d_im, d_model),
       d_iandSlv(containing, d_model),
@@ -386,9 +386,7 @@ int NonlinearExtension::checkLastCall(const std::vector<Node>& assertions,
     // initialize the non-linear solver
     d_nlSlv.initLastCall(assertions, false_asserts, xts);
     // initialize the trancendental function solver
-    d_trSlv.initLastCall(assertions, false_asserts, xts, lemmas);
-    // process lemmas that may have been generated by the transcendental solver
-    filterLemmas(lemmas, lems);
+    d_trSlv.initLastCall(assertions, false_asserts, xts);
   }
   if (options::nlCad())
   {
@@ -398,11 +396,12 @@ int NonlinearExtension::checkLastCall(const std::vector<Node>& assertions,
   // init last call with IAND
   d_iandSlv.initLastCall(assertions, false_asserts, xts);
 
-  if (!lems.empty())
+  if (d_im.hasUsed() || !lems.empty())
   {
-    Trace("nl-ext") << "  ...finished with " << lems.size()
+    unsigned count = lems.size() + d_im.numPendingLemmas() + d_im.numSentLemmas();
+    Trace("nl-ext") << "  ...finished with " << count
                     << " new lemmas during registration." << std::endl;
-    return lems.size();
+    return count;
   }
 
   //----------------------------------- possibly split on zero
@@ -423,13 +422,13 @@ int NonlinearExtension::checkLastCall(const std::vector<Node>& assertions,
   if (options::nlExt())
   {
     // functions
-    lemmas = d_trSlv.checkTranscendentalInitialRefine();
-    filterLemmas(lemmas, lems);
-    if (!lems.empty())
+    d_trSlv.checkTranscendentalInitialRefine();
+    if (d_im.hasUsed())
     {
-      Trace("nl-ext") << "  ...finished with " << lems.size() << " new lemmas."
+      unsigned count = lems.size() + d_im.numPendingLemmas() + d_im.numSentLemmas();
+      Trace("nl-ext") << "  ...finished with " << count << " new lemmas."
                       << std::endl;
-      return lems.size();
+      return count;
     }
   }
   //-----------------------------------initial lemmas for iand
@@ -456,13 +455,13 @@ int NonlinearExtension::checkLastCall(const std::vector<Node>& assertions,
     }
 
     //-----------------------------------monotonicity of transdental functions
-    lemmas = d_trSlv.checkTranscendentalMonotonic();
-    filterLemmas(lemmas, lems);
-    if (!lems.empty())
+    d_trSlv.checkTranscendentalMonotonic();
+    if (d_im.hasUsed())
     {
-      Trace("nl-ext") << "  ...finished with " << lems.size() << " new lemmas."
+      unsigned count = lems.size() + d_im.numPendingLemmas() + d_im.numSentLemmas();
+      Trace("nl-ext") << "  ...finished with " << count << " new lemmas."
                       << std::endl;
-      return lems.size();
+      return count;
     }
 
     //------------------------lemmas based on magnitude of non-zero monomials
@@ -551,8 +550,7 @@ int NonlinearExtension::checkLastCall(const std::vector<Node>& assertions,
     }
     if (options::nlExtTfTangentPlanes())
     {
-      lemmas = d_trSlv.checkTranscendentalTangentPlanes();
-      filterLemmas(lemmas, wlems);
+      d_trSlv.checkTranscendentalTangentPlanes();
     }
   }
   if (options::nlCad())
@@ -572,8 +570,9 @@ int NonlinearExtension::checkLastCall(const std::vector<Node>& assertions,
   lemmas = d_iandSlv.checkFullRefine();
   filterLemmas(lemmas, wlems);
 
-  Trace("nl-ext") << "  ...finished with " << wlems.size() << " waiting lemmas."
-                  << std::endl;
+  Trace("nl-ext") << "  ...finished with "
+                  << (wlems.size() + d_im.numWaitingLemmas())
+                  << " waiting lemmas." << std::endl;
 
   return 0;
 }
@@ -614,6 +613,7 @@ void NonlinearExtension::check(Theory::Effort e)
       d_im.doPendingFacts();
       d_im.doPendingLemmas();
       d_im.doPendingPhaseRequirements();
+      d_im.reset();
       return;
     }
     // Otherwise, we will answer SAT. The values that we approximated are
@@ -728,7 +728,7 @@ bool NonlinearExtension::modelBasedRefinement(std::vector<NlLemma>& mlems)
     {
       complete_status = num_shared_wrong_value > 0 ? -1 : 0;
       checkLastCall(assertions, false_asserts, xts, mlems, wlems);
-      if (!mlems.empty())
+      if (!mlems.empty() || d_im.hasSentLemma() || d_im.hasPendingLemma())
       {
         return true;
       }
@@ -768,10 +768,12 @@ bool NonlinearExtension::modelBasedRefinement(std::vector<NlLemma>& mlems)
     if (complete_status != 1)
     {
       // flush the waiting lemmas
-      if (!wlems.empty())
+      if (!wlems.empty() || d_im.hasWaitingLemma())
       {
+        std::size_t count = wlems.size() + d_im.numWaitingLemmas();
         mlems.insert(mlems.end(), wlems.begin(), wlems.end());
-        Trace("nl-ext") << "...added " << wlems.size() << " waiting lemmas."
+        d_im.flushWaitingLemmas();
+        Trace("nl-ext") << "...added " << count << " waiting lemmas."
                         << std::endl;
         return true;
       }
index d075d5037395bca648fbe238d67e009e485e2be3..b22cf990ee173bfa501b6cdbd379c2d9099b5ed2 100644 (file)
@@ -31,7 +31,7 @@ namespace theory {
 namespace arith {
 namespace nl {
 
-TranscendentalSolver::TranscendentalSolver(NlModel& m) : d_model(m)
+TranscendentalSolver::TranscendentalSolver(InferenceManager& im, NlModel& m) : d_im(im), d_model(m)
 {
   NodeManager* nm = NodeManager::currentNM();
   d_true = nm->mkConst(true);
@@ -49,8 +49,7 @@ TranscendentalSolver::~TranscendentalSolver() {}
 
 void TranscendentalSolver::initLastCall(const std::vector<Node>& assertions,
                                         const std::vector<Node>& false_asserts,
-                                        const std::vector<Node>& xts,
-                                        std::vector<NlLemma>& lems)
+                                        const std::vector<Node>& xts)
 {
   d_funcCongClass.clear();
   d_funcMap.clear();
@@ -136,7 +135,7 @@ void TranscendentalSolver::initLastCall(const std::vector<Node>& assertions,
             }
             Node expn = exp.size() == 1 ? exp[0] : nm->mkNode(AND, exp);
             Node cong_lemma = nm->mkNode(OR, expn.negate(), a.eqNode(aa));
-            lems.emplace_back(cong_lemma, InferenceId::NL_CONGRUENCE);
+            d_im.addPendingArithLemma(cong_lemma, InferenceId::NL_CONGRUENCE);
           }
         }
         else
@@ -160,10 +159,10 @@ void TranscendentalSolver::initLastCall(const std::vector<Node>& assertions,
   if (needPi && d_pi.isNull())
   {
     mkPi();
-    getCurrentPiBounds(lems);
+    getCurrentPiBounds();
   }
 
-  if (!lems.empty())
+  if (d_im.hasUsed())
   {
     return;
   }
@@ -212,9 +211,8 @@ void TranscendentalSolver::initLastCall(const std::vector<Node>& assertions,
     // note we must do preprocess on this lemma
     Trace("nl-ext-lemma") << "NonlinearExtension::Lemma : purify : " << lem
                           << std::endl;
-    NlLemma nlem(
-        lem, LemmaProperty::PREPROCESS, nullptr, InferenceId::NL_T_PURIFY_ARG);
-    lems.emplace_back(nlem);
+    NlLemma nlem(lem, LemmaProperty::PREPROCESS, nullptr, InferenceId::NL_T_PURIFY_ARG);
+    d_im.addPendingArithLemma(nlem);
   }
 
   if (Trace.isOn("nl-ext-mv"))
@@ -363,19 +361,18 @@ void TranscendentalSolver::mkPi()
   }
 }
 
-void TranscendentalSolver::getCurrentPiBounds(std::vector<NlLemma>& lemmas)
+void TranscendentalSolver::getCurrentPiBounds()
 {
   NodeManager* nm = NodeManager::currentNM();
   Node pi_lem = nm->mkNode(AND,
                            nm->mkNode(GEQ, d_pi, d_pi_bound[0]),
                            nm->mkNode(LEQ, d_pi, d_pi_bound[1]));
-  lemmas.emplace_back(pi_lem, InferenceId::NL_T_PI_BOUND);
+  d_im.addPendingArithLemma(pi_lem, InferenceId::NL_T_PI_BOUND);
 }
 
-std::vector<NlLemma> TranscendentalSolver::checkTranscendentalInitialRefine()
+void TranscendentalSolver::checkTranscendentalInitialRefine()
 {
   NodeManager* nm = NodeManager::currentNM();
-  std::vector<NlLemma> lemmas;
   Trace("nl-ext")
       << "Get initial refinement lemmas for transcendental functions..."
       << std::endl;
@@ -454,18 +451,15 @@ std::vector<NlLemma> TranscendentalSolver::checkTranscendentalInitialRefine()
         }
         if (!lem.isNull())
         {
-          lemmas.emplace_back(lem, InferenceId::NL_T_INIT_REFINE);
+          d_im.addPendingArithLemma(lem, InferenceId::NL_T_INIT_REFINE);
         }
       }
     }
   }
-
-  return lemmas;
 }
 
-std::vector<NlLemma> TranscendentalSolver::checkTranscendentalMonotonic()
+void TranscendentalSolver::checkTranscendentalMonotonic()
 {
-  std::vector<NlLemma> lemmas;
   Trace("nl-ext") << "Get monotonicity lemmas for transcendental functions..."
                   << std::endl;
 
@@ -630,7 +624,8 @@ std::vector<NlLemma> TranscendentalSolver::checkTranscendentalMonotonic()
               }
               Trace("nl-ext-tf-mono")
                   << "Monotonicity lemma : " << mono_lem << std::endl;
-              lemmas.emplace_back(mono_lem, InferenceId::NL_T_MONOTONICITY);
+
+              d_im.addPendingArithLemma(mono_lem, InferenceId::NL_T_MONOTONICITY);
             }
           }
           // store the previous values
@@ -642,12 +637,10 @@ std::vector<NlLemma> TranscendentalSolver::checkTranscendentalMonotonic()
       }
     }
   }
-  return lemmas;
 }
 
-std::vector<NlLemma> TranscendentalSolver::checkTranscendentalTangentPlanes()
+void TranscendentalSolver::checkTranscendentalTangentPlanes()
 {
-  std::vector<NlLemma> lemmas;
   Trace("nl-ext") << "Get tangent plane lemmas for transcendental functions..."
                   << std::endl;
   // this implements Figure 3 of "Satisfiaility Modulo Transcendental Functions
@@ -682,11 +675,13 @@ std::vector<NlLemma> TranscendentalSolver::checkTranscendentalTangentPlanes()
       for (unsigned d = 1; d <= d_taylor_degree; d++)
       {
         Trace("nl-ext-tftp") << "- run at degree " << d << "..." << std::endl;
-        unsigned prev = lemmas.size();
-        if (checkTfTangentPlanesFun(tf, d, lemmas))
+        unsigned prev = d_im.numPendingLemmas() + d_im.numWaitingLemmas();
+        if (checkTfTangentPlanesFun(tf, d))
         {
           Trace("nl-ext-tftp")
-              << "...fail, #lemmas = " << (lemmas.size() - prev) << std::endl;
+              << "...fail, #lemmas = "
+              << (d_im.numPendingLemmas() + d_im.numWaitingLemmas() - prev)
+              << std::endl;
           break;
         }
         else
@@ -696,13 +691,10 @@ std::vector<NlLemma> TranscendentalSolver::checkTranscendentalTangentPlanes()
       }
     }
   }
-
-  return lemmas;
 }
 
 bool TranscendentalSolver::checkTfTangentPlanesFun(Node tf,
-                                                   unsigned d,
-                                                   std::vector<NlLemma>& lemmas)
+                                                   unsigned d)
 {
   NodeManager* nm = NodeManager::currentNM();
   Kind k = tf.getKind();
@@ -883,7 +875,7 @@ bool TranscendentalSolver::checkTfTangentPlanesFun(Node tf,
         << "*** Tangent plane lemma : " << lem << std::endl;
     Assert(d_model.computeAbstractModelValue(lem) == d_false);
     // Figure 3 : line 9
-    lemmas.emplace_back(lem, InferenceId::NL_T_TANGENT);
+    d_im.addPendingArithLemma(lem, InferenceId::NL_T_TANGENT, true);
   }
   else if (is_secant)
   {
@@ -1017,11 +1009,11 @@ bool TranscendentalSolver::checkTfTangentPlanesFun(Node tf,
     Assert(!lemmaConj.empty());
     Node lem =
         lemmaConj.size() == 1 ? lemmaConj[0] : nm->mkNode(AND, lemmaConj);
-    NlLemma nlem(lem, InferenceId::NL_T_SECANT);
+    NlLemma nlem(lem, LemmaProperty::NONE, nullptr, InferenceId::NL_T_SECANT);
     // The side effect says that if lem is added, then we should add the
     // secant point c for (tf,d).
     nlem.d_secantPoint.push_back(std::make_tuple(tf, d, c));
-    lemmas.emplace_back(nlem);
+    d_im.addPendingArithLemma(nlem, true);
   }
   return true;
 }
index c80fa99e68ed519bd977c09d63d8eef140f39a24..2ac2ae2f3388d3ce346a273df0edb62d3e1abead 100644 (file)
@@ -21,7 +21,7 @@
 #include <vector>
 
 #include "expr/node.h"
-#include "theory/arith/nl/nl_lemma_utils.h"
+#include "theory/arith/inference_manager.h"
 #include "theory/arith/nl/nl_model.h"
 
 namespace CVC4 {
@@ -44,7 +44,7 @@ namespace nl {
 class TranscendentalSolver
 {
  public:
-  TranscendentalSolver(NlModel& m);
+  TranscendentalSolver(InferenceManager& im, NlModel& m);
   ~TranscendentalSolver();
 
   /** init last call
@@ -60,8 +60,7 @@ class TranscendentalSolver
    */
   void initLastCall(const std::vector<Node>& assertions,
                     const std::vector<Node>& false_asserts,
-                    const std::vector<Node>& xts,
-                    std::vector<NlLemma>& lems);
+                    const std::vector<Node>& xts);
   /** increment taylor degree */
   void incrementTaylorDegree();
   /** get taylor degree */
@@ -80,7 +79,7 @@ class TranscendentalSolver
   //-------------------------------------------- lemma schemas
   /** check transcendental initial refine
    *
-   * Returns a set of valid theory lemmas, based on
+   * Constructs a set of valid theory lemmas, based on
    * simple facts about transcendental functions.
    * This mostly follows the initial axioms described in
    * Section 4 of "Satisfiability
@@ -94,11 +93,11 @@ class TranscendentalSolver
    * exp( x )>0
    * x<0 => exp( x )<1
    */
-  std::vector<NlLemma> checkTranscendentalInitialRefine();
+  void checkTranscendentalInitialRefine();
 
   /** check transcendental monotonic
    *
-   * Returns a set of valid theory lemmas, based on a
+   * Constructs a set of valid theory lemmas, based on a
    * lemma scheme that ensures that applications
    * of transcendental functions respect monotonicity.
    *
@@ -108,11 +107,11 @@ class TranscendentalSolver
    * PI/2 > x > y > 0 => sin( x ) > sin( y )
    * PI > x > y > PI/2 => sin( x ) < sin( y )
    */
-  std::vector<NlLemma> checkTranscendentalMonotonic();
+  void checkTranscendentalMonotonic();
 
   /** check transcendental tangent planes
    *
-   * Returns a set of valid theory lemmas, based on
+   * Constructs a set of valid theory lemmas, based on
    * computing an "incremental linearization" of
    * transcendental functions based on the model values
    * of transcendental functions and their arguments.
@@ -168,7 +167,8 @@ class TranscendentalSolver
    *     where c1, c2 are rationals (for brevity, omitted here)
    *     such that c1 ~= .277 and c2 ~= 2.032.
    */
-  std::vector<NlLemma> checkTranscendentalTangentPlanes();
+  void checkTranscendentalTangentPlanes();
+ private:
   /** check transcendental function refinement for tf
    *
    * This method is called by the above method for each "master"
@@ -186,9 +186,8 @@ class TranscendentalSolver
    * It returns false if the bounds are not precise enough to add a
    * secant or tangent plane lemma.
    */
-  bool checkTfTangentPlanesFun(Node tf, unsigned d, std::vector<NlLemma>& lems);
+  bool checkTfTangentPlanesFun(Node tf, unsigned d);
   //-------------------------------------------- end lemma schemas
- private:
   /** polynomial approximation bounds
    *
    * This adds P_l+[x], P_l-[x], P_u+[x], P_u-[x] to pbounds, where x is
@@ -268,10 +267,12 @@ class TranscendentalSolver
   Node getDerivative(Node n, Node x);
 
   void mkPi();
-  void getCurrentPiBounds(std::vector<NlLemma>& lemmas);
+  void getCurrentPiBounds();
   /** Make the node -pi <= a <= pi */
   static Node mkValidPhase(Node a, Node pi);
 
+  /** The inference manager that we push conflicts and lemmas to. */
+  InferenceManager& d_im;
   /** Reference to the non-linear model object */
   NlModel& d_model;
   /** commonly used terms */