Use arith::InferenceManager for CAD lemmas (#5015)
authorGereon Kremer <gereon.kremer@cs.rwth-aachen.de>
Fri, 4 Sep 2020 14:59:49 +0000 (16:59 +0200)
committerGitHub <noreply@github.com>
Fri, 4 Sep 2020 14:59:49 +0000 (09:59 -0500)
This makes the CAD solver use the new arith::InferenceManager instead of the previously used lemma collection scheme.

src/theory/arith/inference_manager.cpp
src/theory/arith/inference_manager.h
src/theory/arith/nl/cad_solver.cpp
src/theory/arith/nl/cad_solver.h
src/theory/arith/nl/nonlinear_extension.cpp
src/theory/sets/cardinality_extension.cpp
src/theory/sets/theory_sets_private.cpp
src/theory/theory_inference_manager.cpp
src/theory/theory_inference_manager.h

index d03d2ba3778584012f1b0c8bf171322b0ad3280a..d4c5d17c5ac0cf35fd9a501586e2964b6fd27e3d 100644 (file)
@@ -87,6 +87,11 @@ void InferenceManager::addConflict(const Node& conf, InferenceId inftype)
   conflict(Rewriter::rewrite(conf));
 }
 
+bool InferenceManager::hasUsed() const
+{
+  return hasSent() || hasPending();
+}
+
 std::size_t InferenceManager::numWaitingLemmas() const
 {
   return d_waitingLem.size();
index 33e4f424bc3cd0aad1df75364aa98ee32f0aca43..e1e386beca7182fc49a301bb99583a5471df7d6c 100644 (file)
@@ -81,6 +81,12 @@ class InferenceManager : public InferenceManagerBuffered
   /** Add a conflict to the this inference manager. */
   void addConflict(const Node& conf, InferenceId inftype);
 
+  /**
+   * Checks whether we have made any progress, that is whether a conflict, lemma
+   * or fact was added or whether a lemma or fact is pending.
+   */
+  bool hasUsed() const;
+
   /** Returns the number of pending lemmas. */
   std::size_t numWaitingLemmas() const;
 
index 473e067b70ff492713b644d456d670fe0a841ec2..416de1c5ab9b1fbafb889249a0757dab066302aa 100644 (file)
@@ -28,8 +28,8 @@ namespace theory {
 namespace arith {
 namespace nl {
 
-CadSolver::CadSolver(TheoryArith& containing, NlModel& model)
-    : d_foundSatisfiability(false), d_containing(containing), d_model(model)
+CadSolver::CadSolver(InferenceManager& im, NlModel& model)
+    : d_foundSatisfiability(false), d_im(im), d_model(model)
 {
   d_ranVariable =
       NodeManager::currentNM()->mkSkolem("__z",
@@ -66,10 +66,9 @@ void CadSolver::initLastCall(const std::vector<Node>& assertions)
 #endif
 }
 
-std::vector<NlLemma> CadSolver::checkFull()
+void CadSolver::checkFull()
 {
 #ifdef CVC4_POLY_IMP
-  std::vector<NlLemma> lems;
   auto covering = d_CAC.getUnsatCover();
   if (covering.empty())
   {
@@ -81,23 +80,11 @@ std::vector<NlLemma> CadSolver::checkFull()
     d_foundSatisfiability = false;
     auto mis = collectConstraints(covering);
     Trace("nl-cad") << "Collected MIS: " << mis << std::endl;
-    auto* nm = NodeManager::currentNM();
-    for (auto& n : mis)
-    {
-      n = n.negate();
-    }
     Assert(!mis.empty()) << "Infeasible subset can not be empty";
-    if (mis.size() == 1)
-    {
-      lems.emplace_back(mis.front(), InferenceId::NL_CAD_CONFLICT);
-    }
-    else
-    {
-      lems.emplace_back(nm->mkNode(Kind::OR, mis), InferenceId::NL_CAD_CONFLICT);
-    }
-    Trace("nl-cad") << "UNSAT with MIS: " << lems.back().d_node << std::endl;
+    Trace("nl-cad") << "UNSAT with MIS: " << mis << std::endl;
+    d_im.addConflict(NodeManager::currentNM()->mkAnd(mis),
+                     InferenceId::NL_CAD_CONFLICT);
   }
-  return lems;
 #else
   Warning() << "Tried to use CadSolver but libpoly is not available. Compile "
                "with --poly."
@@ -106,10 +93,9 @@ std::vector<NlLemma> CadSolver::checkFull()
 #endif
 }
 
-std::vector<NlLemma> CadSolver::checkPartial()
+void CadSolver::checkPartial()
 {
 #ifdef CVC4_POLY_IMP
-  std::vector<NlLemma> lems;
   auto covering = d_CAC.getUnsatCover(0, true);
   if (covering.empty())
   {
@@ -135,14 +121,16 @@ std::vector<NlLemma> CadSolver::checkPartial()
       }
       Node conclusion =
           excluding_interval_to_lemma(first_var, interval.d_interval, false);
-      if (!conclusion.isNull()) {
+      if (!conclusion.isNull())
+      {
         Node lemma = nm->mkNode(Kind::IMPLIES, premise, conclusion);
-        Trace("nl-cad") << "Excluding " << first_var << " -> " << interval.d_interval << " using " << lemma << std::endl;
-        lems.emplace_back(lemma, InferenceId::NL_CAD_EXCLUDED_INTERVAL);
-       }
+        Trace("nl-cad") << "Excluding " << first_var << " -> "
+                        << interval.d_interval << " using " << lemma
+                        << std::endl;
+        d_im.addPendingArithLemma(lemma, InferenceId::NL_CAD_EXCLUDED_INTERVAL);
+      }
     }
   }
-  return lems;
 #else
   Warning() << "Tried to use CadSolver but libpoly is not available. Compile "
                "with --poly."
index 6f6c0d43cedd306320d1e16b1dfead4f91b0080e..615cdb03a90a3146c4b1fb558521b45d9cc8643f 100644 (file)
@@ -18,9 +18,9 @@
 #include <vector>
 
 #include "expr/node.h"
+#include "theory/arith/inference_manager.h"
 #include "theory/arith/nl/cad/cdcac.h"
 #include "theory/arith/nl/nl_model.h"
-#include "theory/arith/theory_arith.h"
 
 namespace CVC4 {
 namespace theory {
@@ -34,7 +34,7 @@ namespace nl {
 class CadSolver
 {
  public:
-  CadSolver(TheoryArith& containing, NlModel& model);
+  CadSolver(InferenceManager& im, NlModel& model);
   ~CadSolver();
 
   /**
@@ -52,7 +52,7 @@ class CadSolver
    * for construct_model_if_available. Otherwise, the single lemma can be used
    * as an infeasible subset.
    */
-  std::vector<NlLemma> checkFull();
+  void checkFull();
 
   /**
    * Perform a partial check, returning either {} or a list of lemmas.
@@ -60,7 +60,7 @@ class CadSolver
    * for construct_model_if_available. Otherwise, the lemmas exclude some part
    * of the search space.
    */
-  std::vector<NlLemma> checkPartial();
+  void checkPartial();
 
   /**
    * If a model is available (indicated by the last call to check_full() or
@@ -88,8 +88,8 @@ class CadSolver
    */
   bool d_foundSatisfiability;
 
-  /** The theory of arithmetic containing this extension.*/
-  TheoryArith& d_containing;
+  /** The inference manager we are pushing conflicts and lemmas to. */
+  InferenceManager& d_im;
   /** Reference to the non-linear model object */
   NlModel& d_model;
 }; /* class CadSolver */
index 537dd604c2dcc23b937b4bb930e6aaf9dc8d6b57..3bf547cebc552fa3bc0330ea068f6d5c9d94758c 100644 (file)
@@ -46,7 +46,7 @@ NonlinearExtension::NonlinearExtension(TheoryArith& containing,
       d_model(containing.getSatContext()),
       d_trSlv(d_model),
       d_nlSlv(containing, d_model),
-      d_cadSlv(containing, d_model),
+      d_cadSlv(d_im, d_model),
       d_iandSlv(containing, d_model),
       d_builtModel(containing.getSatContext(), false)
 {
@@ -557,12 +557,16 @@ int NonlinearExtension::checkLastCall(const std::vector<Node>& assertions,
   }
   if (options::nlCad())
   {
-    lemmas = d_cadSlv.checkFull();
-    if (lemmas.empty())
+    d_cadSlv.checkFull();
+    if (!d_im.hasUsed())
     {
       Trace("nl-cad") << "nl-cad found SAT!" << std::endl;
     }
-    filterLemmas(lemmas, wlems);
+    else
+    {
+      // checkFull() only adds a single conflict
+      return 1;
+    }
   }
   // run the full refinement in the IAND solver
   lemmas = d_iandSlv.checkFullRefine();
index a51cee2c33deb92f28332e1d23c4661aaa3bfafe..321559f5a29bff02490272b58722b4f556f46803 100644 (file)
@@ -183,17 +183,17 @@ void CardinalityExtension::check()
 {
   checkCardinalityExtended();
   checkRegister();
-  if (d_im.hasProcessed())
+  if (d_im.hasSent())
   {
     return;
   }
   checkMinCard();
-  if (d_im.hasProcessed())
+  if (d_im.hasSent())
   {
     return;
   }
   checkCardCycles();
-  if (d_im.hasProcessed())
+  if (d_im.hasSent())
   {
     return;
   }
@@ -300,7 +300,7 @@ void CardinalityExtension::checkCardCycles()
     std::vector<Node> curr;
     std::vector<Node> exp;
     checkCardCyclesRec(s, curr, exp);
-    if (d_im.hasProcessed())
+    if (d_im.hasSent())
     {
       return;
     }
@@ -414,7 +414,7 @@ void CardinalityExtension::checkCardCyclesRec(Node eqc,
       }
       d_im.assertInference(conc, n.eqNode(emp_set), "cg_emp");
       d_im.doPendingLemmas();
-      if (d_im.hasProcessed())
+      if (d_im.hasSent())
       {
         return;
       }
@@ -446,7 +446,7 @@ void CardinalityExtension::checkCardCyclesRec(Node eqc,
         Assert(!d_state.areEqual(n, emp_set));
         d_im.assertInference(n.eqNode(emp_set), p.eqNode(emp_set), "cg_emppar");
         d_im.doPendingLemmas();
-        if (d_im.hasProcessed())
+        if (d_im.hasSent())
         {
           return;
         }
@@ -493,7 +493,7 @@ void CardinalityExtension::checkCardCyclesRec(Node eqc,
             << "...derived " << conc.size() << " conclusions" << std::endl;
         d_im.assertInference(conc, n.eqNode(p), "cg_eqpar");
         d_im.doPendingLemmas();
-        if (d_im.hasProcessed())
+        if (d_im.hasSent())
         {
           return;
         }
@@ -552,7 +552,7 @@ void CardinalityExtension::checkCardCyclesRec(Node eqc,
           Trace("sets-nf") << "Split empty : " << n << std::endl;
           d_im.split(n.eqNode(emp_set), 1);
         }
-        Assert(d_im.hasProcessed());
+        Assert(d_im.hasSent());
         return;
       }
       else
@@ -600,7 +600,7 @@ void CardinalityExtension::checkCardCyclesRec(Node eqc,
           }
           d_im.assertInference(conc, cpk.eqNode(cpnl), "cg_pareq");
           d_im.doPendingLemmas();
-          if (d_im.hasProcessed())
+          if (d_im.hasSent())
           {
             return;
           }
@@ -619,7 +619,7 @@ void CardinalityExtension::checkCardCyclesRec(Node eqc,
       Trace("sets-cycle-debug")
           << "Traverse card parent " << eqc << " -> " << cpnc << std::endl;
       checkCardCyclesRec(cpnc, curr, exp);
-      if (d_im.hasProcessed())
+      if (d_im.hasSent())
       {
         return;
       }
@@ -642,7 +642,7 @@ void CardinalityExtension::checkNormalForms(std::vector<Node>& intro_sets)
   for (int i = (int)(d_oSetEqc.size() - 1); i >= 0; i--)
   {
     checkNormalForm(d_oSetEqc[i], intro_sets);
-    if (d_im.hasProcessed() || !intro_sets.empty())
+    if (d_im.hasSent() || !intro_sets.empty())
     {
       return;
     }
@@ -783,7 +783,7 @@ void CardinalityExtension::checkNormalForm(Node eqc,
               d_state.debugPrintSet(r, "sets-nf");
               Trace("sets-nf") << std::endl;
               d_im.split(r.eqNode(emp_set), 1);
-              Assert(d_im.hasProcessed());
+              Assert(d_im.hasSent());
               return;
             }
           }
@@ -867,7 +867,7 @@ void CardinalityExtension::checkNormalForm(Node eqc,
   }
   if (!success)
   {
-    Assert(d_im.hasProcessed());
+    Assert(d_im.hasSent());
     return;
   }
   // Send to parents (a parent is a set that contains a term in this equivalence
index 7d498a7981f4eb0371427dbbfc72bacfcac02f55..5e78e7ed5766b7ffb0763e94530c0d56b712e174 100644 (file)
@@ -286,7 +286,7 @@ void TheorySetsPrivate::fullEffortCheck()
   Trace("sets") << "----- Full effort check ------" << std::endl;
   do
   {
-    Assert(!d_im.hasPendingLemma() || d_im.hasProcessed());
+    Assert(!d_im.hasPendingLemma() || d_im.hasSent());
 
     Trace("sets") << "...iterate full effort check..." << std::endl;
     fullEffortReset();
@@ -391,7 +391,7 @@ void TheorySetsPrivate::fullEffortCheck()
 
     // We may have sent lemmas while registering the terms in the loop above,
     // e.g. the cardinality solver.
-    if (d_im.hasProcessed())
+    if (d_im.hasSent())
     {
       continue;
     }
@@ -421,35 +421,35 @@ void TheorySetsPrivate::fullEffortCheck()
     // check subtypes
     checkSubtypes();
     d_im.doPendingLemmas();
-    if (d_im.hasProcessed())
+    if (d_im.hasSent())
     {
       continue;
     }
     // check downwards closure
     checkDownwardsClosure();
     d_im.doPendingLemmas();
-    if (d_im.hasProcessed())
+    if (d_im.hasSent())
     {
       continue;
     }
     // check upwards closure
     checkUpwardsClosure();
     d_im.doPendingLemmas();
-    if (d_im.hasProcessed())
+    if (d_im.hasSent())
     {
       continue;
     }
     // check disequalities
     checkDisequalities();
     d_im.doPendingLemmas();
-    if (d_im.hasProcessed())
+    if (d_im.hasSent())
     {
       continue;
     }
     // check reduce comprehensions
     checkReduceComprehensions();
     d_im.doPendingLemmas();
-    if (d_im.hasProcessed())
+    if (d_im.hasSent())
     {
       continue;
     }
@@ -457,7 +457,7 @@ void TheorySetsPrivate::fullEffortCheck()
     {
       // call the check method of the cardinality solver
       d_cardSolver->check();
-      if (d_im.hasProcessed())
+      if (d_im.hasSent())
       {
         continue;
       }
@@ -469,7 +469,7 @@ void TheorySetsPrivate::fullEffortCheck()
     }
   } while (!d_im.hasSentLemma() && !d_state.isInConflict()
            && d_im.hasSentFact());
-  Assert(!d_im.hasPendingLemma() || d_im.hasProcessed());
+  Assert(!d_im.hasPendingLemma() || d_im.hasSent());
   Trace("sets") << "----- End full effort check, conflict="
                 << d_state.isInConflict() << ", lemma=" << d_im.hasSentLemma()
                 << std::endl;
@@ -720,7 +720,7 @@ void TheorySetsPrivate::checkUpwardsClosure()
       }
     }
   }
-  if (!d_im.hasProcessed())
+  if (!d_im.hasSent())
   {
     if (options::setsExt())
     {
@@ -827,7 +827,7 @@ void TheorySetsPrivate::checkDisequalities()
     lem = Rewriter::rewrite(lem);
     d_im.assertInference(lem, d_true, "diseq", 1);
     d_im.doPendingLemmas();
-    if (d_im.hasProcessed())
+    if (d_im.hasSent())
     {
       return;
     }
index 570b878b46c1e0ef08fec11290525aea35e40455..801d6a26659d24a4a60b56d8ce20fdff4d6c223a 100644 (file)
@@ -58,7 +58,7 @@ void TheoryInferenceManager::reset()
   d_numCurrentFacts = 0;
 }
 
-bool TheoryInferenceManager::hasProcessed() const
+bool TheoryInferenceManager::hasSent() const
 {
   return d_theoryState.isInConflict() || d_numCurrentLemmas > 0
          || d_numCurrentFacts > 0;
index 7fecacf82d094bc325939478bf4bd02579eb38cf..97baeaa408954502ece653412cd4b79538b04b49 100644 (file)
@@ -96,7 +96,7 @@ class TheoryInferenceManager
    * Returns true if we are in conflict, or if we have sent a lemma or fact
    * since the last call to reset.
    */
-  bool hasProcessed() const;
+  bool hasSent() const;
   //--------------------------------------- propagations
   /**
    * T-propagate literal lit, possibly encountered by equality engine,