Implement proofs for arith BRAB lemmas (#5784)
authorAlex Ozdemir <aozdemir@hmc.edu>
Tue, 19 Jan 2021 22:25:02 +0000 (14:25 -0800)
committerGitHub <noreply@github.com>
Tue, 19 Jan 2021 22:25:02 +0000 (16:25 -0600)
All changes:

Add a Pf type alias for std::shared_ptr to
expr/proof_rule.h
Add an eager proof generator to TheoryArith for preprocessing
rewrites. Right now those are proven with INT_TRUST. Will eventually
fix.
Generate proved lemmas in TheoryArithPrivate::branchIntegerVariable.
Same for TheoryArithPrivate::roundRobinBranch
Add EagerProofGenerator::mkTrustedRewrite.
Add some proofsEnabled methods.

src/expr/proof_node.h
src/theory/arith/theory_arith.cpp
src/theory/arith/theory_arith.h
src/theory/arith/theory_arith_private.cpp
src/theory/arith/theory_arith_private.h
src/theory/eager_proof_generator.cpp
src/theory/eager_proof_generator.h
src/theory/theory.cpp
src/theory/theory.h

index 0c36deada35e4e3952972057fb68923014fc3c0a..a505b1bc038c2ad79f8c2142614eec8b4cd8ce03 100644 (file)
@@ -27,6 +27,9 @@ namespace CVC4 {
 class ProofNodeManager;
 class ProofNode;
 
+// Alias for shared pointer to a proof node
+using Pf = std::shared_ptr<ProofNode>;
+
 struct ProofNodeHashFunction
 {
   inline size_t operator()(std::shared_ptr<ProofNode> pfn) const;
index 7437ce223f053459e2751b655c84cd8cbbecbe8b..6de5c72213c7e4ea268ef1b30b3d7b43fb450764 100644 (file)
@@ -17,6 +17,7 @@
 
 #include "theory/arith/theory_arith.h"
 
+#include "expr/proof_rule.h"
 #include "options/smt_options.h"
 #include "smt/smt_statistics_registry.h"
 #include "theory/arith/arith_rewriter.h"
@@ -42,6 +43,7 @@ TheoryArith::TheoryArith(context::Context* c,
       d_internal(
           new TheoryArithPrivate(*this, c, u, out, valuation, logicInfo, pnm)),
       d_ppRewriteTimer("theory::arith::ppRewriteTimer"),
+      d_ppPfGen(pnm, c, "Arith::ppRewrite"),
       d_astate(*d_internal, c, u, valuation),
       d_inferenceManager(*this, d_astate, pnm),
       d_nonlinearExtension(nullptr),
@@ -131,7 +133,17 @@ TrustNode TheoryArith::ppRewrite(TNode atom)
       Debug("arith::preprocess")
           << "arith::preprocess() : returning " << rewritten << endl;
       // don't need to rewrite terms since rewritten is not a non-standard op
-      return TrustNode::mkTrustRewrite(atom, rewritten, nullptr);
+      if (proofsEnabled())
+      {
+        return d_ppPfGen.mkTrustedRewrite(
+            atom,
+            rewritten,
+            d_pnm->mkNode(PfRule::INT_TRUST, {}, {atom.eqNode(rewritten)}));
+      }
+      else
+      {
+        return TrustNode::mkTrustRewrite(atom, rewritten, nullptr);
+      }
     }
   }
   return ppRewriteTerms(atom);
index e26ff51efb2cb93c93215e9888d44a0b68f906d1..eba84e339410cd922f12a48f32a9de92329201ca 100644 (file)
@@ -45,6 +45,9 @@ class TheoryArith : public Theory {
 
   TimerStat d_ppRewriteTimer;
 
+  /** Used to prove pp-rewrites */
+  EagerProofGenerator d_ppPfGen;
+
  public:
   TheoryArith(context::Context* c,
               context::UserContext* u,
@@ -152,6 +155,7 @@ class TheoryArith : public Theory {
   ArithPreprocess d_arithPreproc;
   /** The theory rewriter for this theory. */
   ArithRewriter d_rewriter;
+
 };/* class TheoryArith */
 
 }/* CVC4::theory::arith namespace */
index 7b0096f3079187a52b454e6e3f93736d63ed16a2..58e8741589f87eaeccb73d51f91d55d156474149 100644 (file)
@@ -34,6 +34,7 @@
 #include "expr/node_builder.h"
 #include "expr/proof_generator.h"
 #include "expr/proof_node_manager.h"
+#include "expr/proof_rule.h"
 #include "expr/skolem_manager.h"
 #include "options/arith_options.h"
 #include "options/smt_options.h"  // for incrementalSolving()
@@ -3033,11 +3034,11 @@ bool TheoryArithPrivate::solveRelaxationOrPanic(Theory::Effort effortLevel){
     ArithVar canBranch = nextIntegerViolatation(false);
     if(canBranch != ARITHVAR_SENTINEL){
       ++d_statistics.d_panicBranches;
-      Node branch = branchIntegerVariable(canBranch);
-      Assert(branch.getKind() == kind::OR);
-      Node rwbranch = Rewriter::rewrite(branch[0]);
+      TrustNode branch = branchIntegerVariable(canBranch);
+      Assert(branch.getNode().getKind() == kind::OR);
+      Node rwbranch = Rewriter::rewrite(branch.getNode()[0]);
       if(!isSatLiteral(rwbranch)){
-        d_approxCuts.push_back(branch);
+        d_approxCuts.push_back(branch.getNode());
         return true;
       }
     }
@@ -3619,15 +3620,15 @@ bool TheoryArithPrivate::postCheck(Theory::Effort effortLevel)
     }
 
     if(!emmittedConflictOrSplit) {
-      Node possibleLemma = roundRobinBranch();
-      if(!possibleLemma.isNull()){
+      TrustNode possibleLemma = roundRobinBranch();
+      if (!possibleLemma.getNode().isNull())
+      {
         ++(d_statistics.d_externalBranchAndBounds);
         d_cutCount = d_cutCount + 1;
         emmittedConflictOrSplit = true;
         Debug("arith::lemma") << "rrbranch lemma"
                               << possibleLemma << endl;
-        outputLemma(possibleLemma);
-
+        outputTrustedLemma(possibleLemma);
       }
     }
 
@@ -3662,7 +3663,8 @@ bool TheoryArithPrivate::postCheck(Theory::Effort effortLevel)
 
 bool TheoryArithPrivate::foundNonlinear() const { return d_foundNl; }
 
-Node TheoryArithPrivate::branchIntegerVariable(ArithVar x) const {
+TrustNode TheoryArithPrivate::branchIntegerVariable(ArithVar x) const
+{
   const DeltaRational& d = d_partialModel.getAssignment(x);
   Assert(!d.isIntegral());
   const Rational& r = d.getNoninfinitesimalPart();
@@ -3674,7 +3676,7 @@ Node TheoryArithPrivate::branchIntegerVariable(ArithVar x) const {
   TNode var = d_partialModel.asNode(x);
   Integer floor_d = d.floor();
 
-  Node lem;
+  TrustNode lem = TrustNode::null();
   NodeManager* nm = NodeManager::currentNM();
   if (options::brabTest())
   {
@@ -3691,38 +3693,102 @@ Node TheoryArithPrivate::branchIntegerVariable(ArithVar x) const {
         nm->mkNode(kind::LEQ, var, mkRationalNode(nearest - 1)));
     Node lb = Rewriter::rewrite(
         nm->mkNode(kind::GEQ, var, mkRationalNode(nearest + 1)));
-    lem = nm->mkNode(kind::OR, ub, lb);
-    Node eq = Rewriter::rewrite(
-        nm->mkNode(kind::EQUAL, var, mkRationalNode(nearest)));
+    Node right = nm->mkNode(kind::OR, ub, lb);
+    Node rawEq = nm->mkNode(kind::EQUAL, var, mkRationalNode(nearest));
+    Node eq = Rewriter::rewrite(rawEq);
     // Also preprocess it before we send it out. This is important since
     // arithmetic may prefer eliminating equalities.
+    TrustNode teq;
     if (Theory::theoryOf(eq) == THEORY_ARITH)
     {
-      TrustNode teq = d_containing.ppRewrite(eq);
+      teq = d_containing.ppRewrite(eq);
       eq = teq.isNull() ? eq : teq.getNode();
     }
     Node literal = d_containing.getValuation().ensureLiteral(eq);
+    Trace("integers") << "eq: " << eq << "\nto: " << literal << endl;
     d_containing.getOutputChannel().requirePhase(literal, true);
-    lem = nm->mkNode(kind::OR, literal, lem);
+    Node l = nm->mkNode(kind::OR, literal, right);
+    Trace("integers") << "l: " << l << endl;
+    if (proofsEnabled())
+    {
+      Node less = nm->mkNode(kind::LT, var, mkRationalNode(nearest));
+      Node greater = nm->mkNode(kind::GT, var, mkRationalNode(nearest));
+      // TODO (project #37): justify. Thread proofs through *ensureLiteral*.
+      Debug("integers::pf") << "less: " << less << endl;
+      Debug("integers::pf") << "greater: " << greater << endl;
+      Debug("integers::pf") << "literal: " << literal << endl;
+      Debug("integers::pf") << "eq: " << eq << endl;
+      Debug("integers::pf") << "rawEq: " << rawEq << endl;
+      Pf pfNotLit = d_pnm->mkAssume(literal.negate());
+      // rewrite notLiteral to notRawEq, using teq.
+      Pf pfNotRawEq =
+          literal == rawEq
+              ? pfNotLit
+              : d_pnm->mkNode(
+                  PfRule::MACRO_SR_PRED_TRANSFORM,
+                  {pfNotLit, teq.getGenerator()->getProofFor(teq.getProven())},
+                  {rawEq.negate()});
+      Pf pfBot =
+          d_pnm->mkNode(PfRule::CONTRA,
+                        {d_pnm->mkNode(PfRule::ARITH_TRICHOTOMY,
+                                       {d_pnm->mkAssume(less.negate()), pfNotRawEq},
+                                       {greater}),
+                         d_pnm->mkAssume(greater.negate())},
+                        {});
+      std::vector<Node> assumptions = {
+          literal.negate(), less.negate(), greater.negate()};
+      // Proof of (not (and (not (= v i)) (not (< v i)) (not (> v i))))
+      Pf pfNotAnd = d_pnm->mkScope(pfBot, assumptions);
+      Pf pfL = d_pnm->mkNode(PfRule::MACRO_SR_PRED_TRANSFORM,
+                             {d_pnm->mkNode(PfRule::NOT_AND, {pfNotAnd}, {})},
+                             {l});
+      lem = d_pfGen->mkTrustNode(l, pfL);
+    }
+    else
+    {
+      lem = TrustNode::mkTrustLemma(l, nullptr);
+    }
   }
   else
   {
     Node ub =
         Rewriter::rewrite(nm->mkNode(kind::LEQ, var, mkRationalNode(floor_d)));
     Node lb = ub.notNode();
-    lem = nm->mkNode(kind::OR, ub, lb);
+    if (proofsEnabled())
+    {
+      lem = d_pfGen->mkTrustNode(
+          nm->mkNode(kind::OR, ub, lb), PfRule::SPLIT, {}, {ub});
+    }
+    else
+    {
+      lem = TrustNode::mkTrustLemma(nm->mkNode(kind::OR, ub, lb), nullptr);
+    }
   }
 
   Trace("integers") << "integers: branch & bound: " << lem << endl;
-  if(isSatLiteral(lem[0])) {
-    Debug("integers") << "    " << lem[0] << " == " << getSatValue(lem[0]) << endl;
-  } else {
-    Debug("integers") << "    " << lem[0] << " is not assigned a SAT literal" << endl;
-  }
-  if(isSatLiteral(lem[1])) {
-    Debug("integers") << "    " << lem[1] << " == " << getSatValue(lem[1]) << endl;
-    } else {
-    Debug("integers") << "    " << lem[1] << " is not assigned a SAT literal" << endl;
+  if (Debug.isOn("integers"))
+  {
+    Node l = lem.getNode();
+    if (isSatLiteral(l[0]))
+    {
+      Debug("integers") << "    " << l[0] << " == " << getSatValue(l[0])
+                        << endl;
+    }
+    else
+    {
+      Debug("integers") << "    " << l[0] << " is not assigned a SAT literal"
+                        << endl;
+    }
+    if (isSatLiteral(l[1]))
+    {
+      Debug("integers") << "    " << l[1] << " == " << getSatValue(l[1])
+                        << endl;
+    }
+    else
+    {
+      Debug("integers") << "    " << l[1] << " is not assigned a SAT literal"
+                        << endl;
+    }
   }
   return lem;
 }
@@ -3748,9 +3814,10 @@ std::vector<ArithVar> TheoryArithPrivate::cutAllBounded() const{
 }
 
 /** Returns true if the roundRobinBranching() issues a lemma. */
-Node TheoryArithPrivate::roundRobinBranch(){
+TrustNode TheoryArithPrivate::roundRobinBranch()
+{
   if(hasIntegerModel()){
-    return Node::null();
+    return TrustNode::null();
   }else{
     ArithVar v = d_nextIntegerCheckVar;
 
index 31435221fb528a7c4dfeed01a23d0b6c01c8998a..e7f5d82b20584a318a4b424048531e0527db3804 100644 (file)
@@ -552,9 +552,11 @@ private:
    * Returns a cut for a lemma.
    * If there is an integer model, this returns Node::null().
    */
-  Node roundRobinBranch();
+  TrustNode roundRobinBranch();
 
-public:
+  bool proofsEnabled() const { return d_pnm; }
+
+ public:
   /**
    * This requests a new unique ArithVar value for x.
    * This also does initial (not context dependent) set up for a variable,
@@ -709,7 +711,7 @@ private:
   /** Counts the number of fullCheck calls to arithmetic. */
   uint32_t d_fullCheckCounter;
   std::vector<ArithVar> cutAllBounded() const;
-  Node branchIntegerVariable(ArithVar x) const;
+  TrustNode branchIntegerVariable(ArithVar x) const;
   void branchVector(const std::vector<ArithVar>& lemmas);
 
   context::CDO<unsigned> d_cutCount;
index c49c33790e65607d414133051697651361e944ce..a1c78fc7d68fa0e48654ecb3bcd9186829f617e5 100644 (file)
@@ -120,6 +120,18 @@ TrustNode EagerProofGenerator::mkTrustNode(Node conc,
   return mkTrustNode(pfs->getResult(), pfs, isConflict);
 }
 
+TrustNode EagerProofGenerator::mkTrustedRewrite(
+    Node a, Node b, std::shared_ptr<ProofNode> pf)
+{
+  if (pf == nullptr)
+  {
+    return TrustNode::null();
+  }
+  Node eq = a.eqNode(b);
+  setProofFor(eq, pf);
+  return TrustNode::mkTrustRewrite(a, b, this);
+}
+
 TrustNode EagerProofGenerator::mkTrustedPropagation(
     Node n, Node exp, std::shared_ptr<ProofNode> pf)
 {
index 29f916e00eba67fd651396193aa4f1f7ee6df481..4b94228c5c6b98d8e77ad7360b48b704669cc621 100644 (file)
@@ -146,6 +146,17 @@ class EagerProofGenerator : public ProofGenerator
   TrustNode mkTrustedPropagation(Node n,
                                  Node exp,
                                  std::shared_ptr<ProofNode> pf);
+  /**
+   * Make trust node: `a = b` as a Rewrite trust node
+   *
+   * @param a the original
+   * @param b what is rewrites to
+   * @param pf The proof of a = b,
+   * @return The trust node corresponding to the fact that this generator has
+   * a proof of a = b
+   */
+  TrustNode mkTrustedRewrite(
+      Node a, Node b, std::shared_ptr<ProofNode> pf);
   //--------------------------------------- common proofs
   /**
    * This returns the trust node corresponding to the splitting lemma
index 9ab20e6cb35c9eabbcda2951af68eadaf575e8fd..5bea454e8e3207c902867dd6ace54706b32055cd 100644 (file)
@@ -424,6 +424,11 @@ void Theory::getCareGraph(CareGraph* careGraph) {
   d_careGraph = NULL;
 }
 
+bool Theory::proofsEnabled() const
+{
+  return d_pnm != nullptr;
+}
+
 EqualityStatus Theory::getEqualityStatus(TNode a, TNode b)
 {
   // if not using an equality engine, then by default we don't know the status
index a9783c19c1f9a5c0e09b96b0e4f339f47f8272b5..0eb3d9a33f013cbfef77e736dbde3b25943c51cf 100644 (file)
@@ -237,6 +237,13 @@ class Theory {
   /** Pointer to proof node manager */
   ProofNodeManager* d_pnm;
 
+  /**
+   * Are proofs enabled?
+   *
+   * They are considered enabled if the ProofNodeManager is non-null.
+   */
+  bool proofsEnabled() const;
+
   /**
    * Returns the next assertion in the assertFact() queue.
    *