Update theory of arithmetic to standard check (#5178)
authorAndrew Reynolds <andrew.j.reynolds@gmail.com>
Thu, 1 Oct 2020 19:12:08 +0000 (14:12 -0500)
committerGitHub <noreply@github.com>
Thu, 1 Oct 2020 19:12:08 +0000 (14:12 -0500)
This updates the theory of arithmetic to use the standard check callbacks (preCheck, postCheck, preNotifyFact, notifyFact). It also refactors some use of the non-linear solver which will solely live in TheoryArith.

The longer term goal is to have TheoryArithPrivate be responsible for linear arithmetic solving only, and to be treated as a black box. Eventually, the non-linear solver will be removed from this class.

This PR is required for several things, including refactoring of preprocessing and expand definitions for arithmetic (div/mod will not be eliminated eagerly). It is also required for fixing issues related to div/mod appearing in instantiations.

This is the last class to have a custom check method. Followup PR will make Theory::check non-virtual.

src/theory/arith/nl/nl_solver.h
src/theory/arith/nl/nonlinear_extension.h
src/theory/arith/theory_arith.cpp
src/theory/arith/theory_arith.h
src/theory/arith/theory_arith_private.cpp
src/theory/arith/theory_arith_private.h

index 9dd5b03c66910eeddb6794dbfe9da0b25a4da14d..903220fb20e6307b6ec92d51164c472e49d7dd91 100644 (file)
@@ -31,7 +31,7 @@
 #include "theory/arith/nl/nl_lemma_utils.h"
 #include "theory/arith/nl/nl_model.h"
 #include "theory/arith/nl/nl_monomial.h"
-#include "theory/arith/theory_arith.h"
+#include "theory/arith/inference_manager.h"
 
 namespace CVC4 {
 namespace theory {
index 309caa519c5b1fdea0e2e6431a3fcc02a3a7a39a..cf45942c8cf914c3dde347335f1012dac74ffb93 100644 (file)
@@ -37,7 +37,6 @@
 #include "theory/arith/nl/stats.h"
 #include "theory/arith/nl/strategy.h"
 #include "theory/arith/nl/transcendental_solver.h"
-#include "theory/arith/theory_arith.h"
 #include "theory/ext_theory.h"
 #include "theory/uf/equality_engine.h"
 
index d84f666be9fcc20adcf265d41b6d49d73cf01260..93c8d87fc9b13f4d4cd091e51b92420b439733dd 100644 (file)
@@ -48,8 +48,9 @@ TheoryArith::TheoryArith(context::Context* c,
 {
   smtStatisticsRegistry()->registerStat(&d_ppRewriteTimer);
 
-  // indicate we are using the theory state object
+  // indicate we are using the theory state object and inference manager
   d_theoryState = &d_astate;
+  d_inferManager = &d_inferenceManager;
 }
 
 TheoryArith::~TheoryArith(){
@@ -89,14 +90,21 @@ void TheoryArith::finishInit()
   d_internal->finishInit();
 }
 
-void TheoryArith::preRegisterTerm(TNode n) { d_internal->preRegisterTerm(n); }
+void TheoryArith::preRegisterTerm(TNode n)
+{
+  if (d_nonlinearExtension != nullptr)
+  {
+    d_nonlinearExtension->preRegisterTerm(n);
+  }
+  d_internal->preRegisterTerm(n);
+}
 
 TrustNode TheoryArith::expandDefinition(Node node)
 {
   return d_internal->expandDefinition(node);
 }
 
-void TheoryArith::notifySharedTerm(TNode n) { d_internal->addSharedTerm(n); }
+void TheoryArith::notifySharedTerm(TNode n) { d_internal->notifySharedTerm(n); }
 
 TrustNode TheoryArith::ppRewrite(TNode atom)
 {
@@ -112,13 +120,38 @@ void TheoryArith::ppStaticLearn(TNode n, NodeBuilder<>& learned) {
   d_internal->ppStaticLearn(n, learned);
 }
 
-void TheoryArith::check(Effort effortLevel){
-  getOutputChannel().spendResource(ResourceManager::Resource::TheoryCheckStep);
-  d_internal->check(effortLevel);
+bool TheoryArith::preCheck(Effort level) { return d_internal->preCheck(level); }
+
+void TheoryArith::postCheck(Effort level)
+{
+  // check with the non-linear solver at last call
+  if (level == Theory::EFFORT_LAST_CALL)
+  {
+    if (d_nonlinearExtension != nullptr)
+    {
+      d_nonlinearExtension->check(level);
+    }
+    return;
+  }
+  // otherwise, check with the linear solver
+  d_internal->postCheck(level);
+}
+
+bool TheoryArith::preNotifyFact(
+    TNode atom, bool pol, TNode fact, bool isPrereg, bool isInternal)
+{
+  d_internal->preNotifyFact(atom, pol, fact);
+  // We do not assert to the equality engine of arithmetic in the standard way,
+  // hence we return "true" to indicate we are finished with this fact.
+  return true;
 }
 
 bool TheoryArith::needsCheckLastEffort() {
-  return d_internal->needsCheckLastEffort();
+  if (d_nonlinearExtension != nullptr)
+  {
+    return d_nonlinearExtension->needsCheckLastEffort();
+  }
+  return false;
 }
 
 TrustNode TheoryArith::explain(TNode n)
@@ -130,6 +163,7 @@ TrustNode TheoryArith::explain(TNode n)
 void TheoryArith::propagate(Effort e) {
   d_internal->propagate(e);
 }
+
 bool TheoryArith::collectModelInfo(TheoryModel* m)
 {
   std::set<Node> termSet;
@@ -186,6 +220,10 @@ void TheoryArith::notifyRestart(){
 
 void TheoryArith::presolve(){
   d_internal->presolve();
+  if (d_nonlinearExtension != nullptr)
+  {
+    d_nonlinearExtension->presolve();
+  }
 }
 
 EqualityStatus TheoryArith::getEqualityStatus(TNode a, TNode b) {
index e332646ffaf9d323c175365c1f10f98b368e33ec..f05dee38b3b2966177be7b210d5eb32ac84032dc 100644 (file)
@@ -20,6 +20,7 @@
 #include "expr/node.h"
 #include "theory/arith/arith_state.h"
 #include "theory/arith/inference_manager.h"
+#include "theory/arith/nl/nonlinear_extension.h"
 #include "theory/arith/theory_arith_private_forward.h"
 #include "theory/theory.h"
 
@@ -27,10 +28,6 @@ namespace CVC4 {
 namespace theory {
 namespace arith {
 
-namespace nl {
-class NonlinearExtension;
-}
-
 /**
  * Implementation of linear and non-linear integer and real arithmetic.
  * The linear arithmetic solver is based upon:
@@ -75,7 +72,18 @@ class TheoryArith : public Theory {
 
   TrustNode expandDefinition(Node node) override;
 
-  void check(Effort e) override;
+  //--------------------------------- standard check
+  /** Pre-check, called before the fact queue of the theory is processed. */
+  bool preCheck(Effort level) override;
+  /** Post-check, called after the fact queue of the theory is processed. */
+  void postCheck(Effort level) override;
+  /** Pre-notify fact, return true if processed. */
+  bool preNotifyFact(TNode atom,
+                     bool pol,
+                     TNode fact,
+                     bool isPrereg,
+                     bool isInternal) override;
+  //--------------------------------- end standard check
   bool needsCheckLastEffort() override;
   void propagate(Effort e) override;
   TrustNode explain(TNode n) override;
@@ -116,7 +124,6 @@ class TheoryArith : public Theory {
   eq::ProofEqEngine* getProofEqEngine();
   /** The state object wrapping TheoryArithPrivate  */
   ArithState d_astate;
-
   /** The arith::InferenceManager. */
   InferenceManager d_inferenceManager;
 
index dfc51a6daed4720e8a6ad9db598acc8e295459ce..d2293131f0678ca1d31e34545f1b3abbbbb7bae2 100644 (file)
@@ -169,6 +169,8 @@ TheoryArithPrivate::TheoryArithPrivate(TheoryArith& containing,
       d_dioSolveResources(0),
       d_solveIntMaybeHelp(0u),
       d_solveIntAttempts(0u),
+      d_newFacts(false),
+      d_previousStatus(Result::SAT_UNKNOWN),
       d_statistics(),
       d_opElim(pnm, logicInfo)
 {
@@ -544,6 +546,11 @@ void TheoryArithPrivate::raiseBlackBoxConflict(Node bb){
   }
 }
 
+bool TheoryArithPrivate::anyConflict() const
+{
+  return !conflictQueueEmpty() || !d_blackBoxConflict.get().isNull();
+}
+
 void TheoryArithPrivate::revertOutOfConflict(){
   d_partialModel.revertAssignmentChanges();
   clearUpdates();
@@ -1067,12 +1074,12 @@ bool TheoryArithPrivate::AssertDisequality(ConstraintP constraint){
   return false;
 }
 
-void TheoryArithPrivate::addSharedTerm(TNode n){
-  Debug("arith::addSharedTerm") << "addSharedTerm: " << n << endl;
+void TheoryArithPrivate::notifySharedTerm(TNode n)
+{
+  Debug("arith::notifySharedTerm") << "notifySharedTerm: " << n << endl;
   if(n.isConst()){
     d_partialModel.invalidateDelta();
   }
-
   d_congruenceManager.addSharedTerm(n);
   if(!n.isConst() && !isSetup(n)){
     Polynomial poly = Polynomial::parsePolynomial(n);
@@ -1472,11 +1479,6 @@ void TheoryArithPrivate::setupAtom(TNode atom) {
 void TheoryArithPrivate::preRegisterTerm(TNode n) {
   Debug("arith::preregister") <<"begin arith::preRegisterTerm("<< n <<")"<< endl;
 
-  if (d_nonlinearExtension != nullptr)
-  {
-    d_nonlinearExtension->preRegisterTerm(n);
-  }
-
   try {
     if(isRelationOperator(n.getKind())){
       if(!isSetup(n)){
@@ -1701,10 +1703,8 @@ Node TheoryArithPrivate::callDioSolver(){
   return d_diosolver.processEquationsForConflict();
 }
 
-ConstraintP TheoryArithPrivate::constraintFromFactQueue(){
-  Assert(!done());
-  TNode assertion = get();
-
+ConstraintP TheoryArithPrivate::constraintFromFactQueue(TNode assertion)
+{
   Kind simpleKind = Comparison::comparisonKind(assertion);
   ConstraintP constraint = d_constraintDatabase.lookup(assertion);
   if(constraint == NullConstraint){
@@ -1972,6 +1972,7 @@ void TheoryArithPrivate::outputConflict(TNode lit) {
 
 void TheoryArithPrivate::outputPropagate(TNode lit) {
   Debug("arith::channel") << "Arith propagation: " << lit << std::endl;
+  // call the propagate lit method of the
   (d_containing.d_out)->propagate(lit);
 }
 
@@ -1980,20 +1981,6 @@ void TheoryArithPrivate::outputRestart() {
   (d_containing.d_out)->demandRestart();
 }
 
-// void TheoryArithPrivate::branchVector(const std::vector<ArithVar>& lemmas){
-//   //output the lemmas
-//   for(vector<ArithVar>::const_iterator i = lemmas.begin(); i != lemmas.end();
-//   ++i){
-//     ArithVar v = *i;
-//     Assert(!d_cutInContext.contains(v));
-//     d_cutInContext.insert(v);
-//     d_cutCount = d_cutCount + 1;
-//     Node lem = branchIntegerVariable(v);
-//     outputLemma(lem);
-//     ++(d_statistics.d_externalBranchAndBounds);
-//   }
-// }
-
 bool TheoryArithPrivate::attemptSolveInteger(Theory::Effort effortLevel, bool emmmittedLemmaOrSplit){
   int level = getSatContext()->getLevel();
   Debug("approx")
@@ -3316,47 +3303,37 @@ bool TheoryArithPrivate::hasFreshArithLiteral(Node n) const{
   }
 }
 
-void TheoryArithPrivate::check(Theory::Effort effortLevel){
+bool TheoryArithPrivate::preCheck(Theory::Effort level)
+{
   Assert(d_currentPropagationList.empty());
-
-  if(done() && effortLevel < Theory::EFFORT_FULL && ( d_qflraStatus == Result::SAT) ){
-    return;
-  }
-
-  if(effortLevel == Theory::EFFORT_LAST_CALL){
-    if (d_nonlinearExtension != nullptr)
-    {
-      d_nonlinearExtension->check(effortLevel);
-    }
-    return;
-  }
-
-  TimerStat::CodeTimer checkTimer(d_containing.d_checkTime);
-  //cout << "TheoryArithPrivate::check " << effortLevel << std::endl;
-  Debug("effortlevel") << "TheoryArithPrivate::check " << effortLevel << std::endl;
-  Debug("arith") << "TheoryArithPrivate::check begun " << effortLevel << std::endl;
-
   if(Debug.isOn("arith::consistency")){
     Assert(unenqueuedVariablesAreConsistent());
   }
 
-  bool newFacts = !done();
-  //If previous == SAT, then reverts on conflicts are safe
-  //Otherwise, they are not and must be committed.
-  Result::Sat previous = d_qflraStatus;
-  if(newFacts){
+  d_newFacts = !done();
+  // If d_previousStatus == SAT, then reverts on conflicts are safe
+  // Otherwise, they are not and must be committed.
+  d_previousStatus = d_qflraStatus;
+  if (d_newFacts)
+  {
     d_qflraStatus = Result::SAT_UNKNOWN;
     d_hasDoneWorkSinceCut = true;
   }
+  return false;
+}
 
-  while(!done()){
-    ConstraintP curr = constraintFromFactQueue();
-    if(curr != NullConstraint){
-      bool res CVC4_UNUSED = assertionCases(curr);
-      Assert(!res || anyConflict());
-    }
-    if(anyConflict()){ break; }
+void TheoryArithPrivate::preNotifyFact(TNode atom, bool pol, TNode fact)
+{
+  ConstraintP curr = constraintFromFactQueue(fact);
+  if (curr != NullConstraint)
+  {
+    bool res CVC4_UNUSED = assertionCases(curr);
+    Assert(!res || anyConflict());
   }
+}
+
+void TheoryArithPrivate::postCheck(Theory::Effort effortLevel)
+{
   if(!anyConflict()){
     while(!d_learnedBounds.empty()){
       // we may attempt some constraints twice.  this is okay!
@@ -3373,14 +3350,19 @@ void TheoryArithPrivate::check(Theory::Effort effortLevel){
 
   if(anyConflict()){
     d_qflraStatus = Result::UNSAT;
-    if(options::revertArithModels() && previous == Result::SAT){
+    if (options::revertArithModels() && d_previousStatus == Result::SAT)
+    {
       ++d_statistics.d_revertsOnConflicts;
-      Debug("arith::bt") << "clearing here " << " " << newFacts << " " << previous << " " << d_qflraStatus  << endl;
+      Debug("arith::bt") << "clearing here "
+                         << " " << d_newFacts << " " << d_previousStatus << " "
+                         << d_qflraStatus << endl;
       revertOutOfConflict();
       d_errorSet.clear();
     }else{
       ++d_statistics.d_commitsOnConflicts;
-      Debug("arith::bt") << "committing here " << " " << newFacts << " " << previous << " " << d_qflraStatus  << endl;
+      Debug("arith::bt") << "committing here "
+                         << " " << d_newFacts << " " << d_previousStatus << " "
+                         << d_qflraStatus << endl;
       d_partialModel.commitAssignmentChanges();
       revertOutOfConflict();
     }
@@ -3415,7 +3397,9 @@ void TheoryArithPrivate::check(Theory::Effort effortLevel){
     solveInteger(effortLevel);
     if(anyConflict()){
       ++d_statistics.d_commitsOnConflicts;
-      Debug("arith::bt") << "committing here " << " " << newFacts << " " << previous << " " << d_qflraStatus  << endl;
+      Debug("arith::bt") << "committing here "
+                         << " " << d_newFacts << " " << d_previousStatus << " "
+                         << d_qflraStatus << endl;
       revertOutOfConflict();
       d_errorSet.clear();
       outputConflicts();
@@ -3428,11 +3412,14 @@ void TheoryArithPrivate::check(Theory::Effort effortLevel){
 
   switch(d_qflraStatus){
   case Result::SAT:
-    if(newFacts){
+    if (d_newFacts)
+    {
       ++d_statistics.d_nontrivialSatChecks;
     }
 
-    Debug("arith::bt") << "committing sap inConflit"  << " " << newFacts << " " << previous << " " << d_qflraStatus  << endl;
+    Debug("arith::bt") << "committing sap inConflit"
+                       << " " << d_newFacts << " " << d_previousStatus << " "
+                       << d_qflraStatus << endl;
     d_partialModel.commitAssignmentChanges();
     d_unknownsInARow = 0;
     if(Debug.isOn("arith::consistency")){
@@ -3450,7 +3437,9 @@ void TheoryArithPrivate::check(Theory::Effort effortLevel){
     ++d_unknownsInARow;
     ++(d_statistics.d_unknownChecks);
     Assert(!Theory::fullEffort(effortLevel));
-    Debug("arith::bt") << "committing unknown"  << " " << newFacts << " " << previous << " " << d_qflraStatus  << endl;
+    Debug("arith::bt") << "committing unknown"
+                       << " " << d_newFacts << " " << d_previousStatus << " "
+                       << d_qflraStatus << endl;
     d_partialModel.commitAssignmentChanges();
     d_statistics.d_maxUnknownsInARow.maxAssign(d_unknownsInARow);
 
@@ -3467,7 +3456,9 @@ void TheoryArithPrivate::check(Theory::Effort effortLevel){
 
     ++d_statistics.d_commitsOnConflicts;
 
-    Debug("arith::bt") << "committing on conflict" << " " << newFacts << " " << previous << " " << d_qflraStatus  << endl;
+    Debug("arith::bt") << "committing on conflict"
+                       << " " << d_newFacts << " " << d_previousStatus << " "
+                       << d_qflraStatus << endl;
     d_partialModel.commitAssignmentChanges();
     revertOutOfConflict();
 
@@ -3564,7 +3555,9 @@ void TheoryArithPrivate::check(Theory::Effort effortLevel){
       outputConflicts();
       emmittedConflictOrSplit = true;
       //cout << "unate conflict " << endl;
-      Debug("arith::bt") << "committing on unate conflict" << " " << newFacts << " " << previous << " " << d_qflraStatus  << endl;
+      Debug("arith::bt") << "committing on unate conflict"
+                         << " " << d_newFacts << " " << d_previousStatus << " "
+                         << d_qflraStatus << endl;
 
       Debug("arith::conflict") << "unate arith conflict" << endl;
     }
@@ -3852,15 +3845,6 @@ void TheoryArithPrivate::debugPrintModel(std::ostream& out) const{
   }
 }
 
-bool TheoryArithPrivate::needsCheckLastEffort() {
-  if (d_nonlinearExtension != nullptr)
-  {
-    return d_nonlinearExtension->needsCheckLastEffort();
-  }else{
-    return false;
-  }
-}
-
 Node TheoryArithPrivate::explain(TNode n)
 {
   Debug("arith::explain") << "explain @" << getSatContext()->getLevel() << ": " << n << endl;
@@ -4252,11 +4236,6 @@ void TheoryArithPrivate::presolve(){
     Debug("arith::oldprop") << " lemma lemma duck " <<lem << endl;
     outputLemma(lem);
   }
-
-  if (d_nonlinearExtension != nullptr)
-  {
-    d_nonlinearExtension->presolve();
-  }
 }
 
 EqualityStatus TheoryArithPrivate::getEqualityStatus(TNode a, TNode b) {
index 6f6ca7fdfc2a431bfa296eb48ed2df40409a99e3..71bada17da41005c138473af6c92f27622e8348f 100644 (file)
@@ -341,10 +341,7 @@ private:
    * Returns true iff a conflict has been raised. This method is public since
    * it is needed by the ArithState class to know whether we are in conflict.
    */
-  bool anyConflict() const
-  {
-    return !conflictQueueEmpty() || !d_blackBoxConflict.get().isNull();
-  }
+  bool anyConflict() const;
 
  private:
   inline bool conflictQueueEmpty() const {
@@ -461,8 +458,6 @@ private:
   void preRegisterTerm(TNode n);
   TrustNode expandDefinition(Node node);
 
-  void check(Theory::Effort e);
-  bool needsCheckLastEffort();
   void propagate(Theory::Effort e);
   Node explain(TNode n);
 
@@ -494,16 +489,23 @@ private:
 
   EqualityStatus getEqualityStatus(TNode a, TNode b);
 
-  void addSharedTerm(TNode n);
+  /** Called when n is notified as being a shared term with TheoryArith. */
+  void notifySharedTerm(TNode n);
 
   Node getModelValue(TNode var);
 
 
   std::pair<bool, Node> entailmentCheck(TNode lit, const ArithEntailmentCheckParameters& params, ArithEntailmentCheckSideEffects& out);
 
-
-private:
-
+  //--------------------------------- standard check
+  /** Pre-check, called before the fact queue of the theory is processed. */
+  bool preCheck(Theory::Effort level);
+  /** Pre-notify fact. */
+  void preNotifyFact(TNode atom, bool pol, TNode fact);
+  /** Post-check, called after the fact queue of the theory is processed. */
+  void postCheck(Theory::Effort level);
+  //--------------------------------- end standard check
+ private:
   /** The constant zero. */
   DeltaRational d_DELTA_ZERO;
 
@@ -648,8 +650,11 @@ private:
    * Handles the case splitting for check() for a new assertion.
    * Returns a conflict if one was found.
    * Returns Node::null if no conflict was found.
+   *
+   * @param assertion The assertion that was just popped from the fact queue
+   * of TheoryArith and given to this class via preNotifyFact.
    */
-  ConstraintP constraintFromFactQueue();
+  ConstraintP constraintFromFactQueue(TNode assertion);
   bool assertionCases(ConstraintP c);
 
   /**
@@ -763,6 +768,13 @@ private:
 
   RationalVector d_farkasBuffer;
 
+  //---------------- during check
+  /** Whether there were new facts during preCheck */
+  bool d_newFacts;
+  /** The previous status, computed during preCheck */
+  Result::Sat d_previousStatus;
+  //---------------- end during check
+
   /** These fields are designed to be accessible to TheoryArith methods. */
   class Statistics {
   public: