(proof-new) Prove lemmas in Constraint (#5254)
authorAlex Ozdemir <aozdemir@hmc.edu>
Wed, 14 Oct 2020 04:15:25 +0000 (21:15 -0700)
committerGitHub <noreply@github.com>
Wed, 14 Oct 2020 04:15:25 +0000 (23:15 -0500)
Includes:

T/F splitting lemmas for any arith constraint
Unate lemmas produced early on
The hard part is proving the unate lemmas. In general, they are all implied by 2-constraint farkas proofs, so we ultimately map them all down to proveOr, which constructs that proof.

make check was happy with this change. Hopefully the CI agrees :).

src/theory/arith/constraint.cpp
src/theory/arith/constraint.h
src/theory/arith/theory_arith_private.cpp
src/theory/arith/theory_arith_private.h

index a081222952a9ea0f204bcc752eb11c33d7d04dd3..b0be108f72ee13780217f3ed3f072554f37feb0e 100644 (file)
@@ -1062,7 +1062,8 @@ bool Constraint::contextDependentDataIsSet() const{
   return hasProof() || isSplit() || canBePropagated() || assertedToTheTheory();
 }
 
-Node Constraint::split(){
+TrustNode Constraint::split()
+{
   Assert(isEquality() || isDisequality());
 
   bool isEq = isEquality();
@@ -1076,15 +1077,48 @@ Node Constraint::split(){
   TNode rhs = eqNode[1];
 
   Node leqNode = NodeBuilder<2>(kind::LEQ) << lhs << rhs;
+  Node ltNode = NodeBuilder<2>(kind::LT) << lhs << rhs;
+  Node gtNode = NodeBuilder<2>(kind::GT) << lhs << rhs;
   Node geqNode = NodeBuilder<2>(kind::GEQ) << lhs << rhs;
 
   Node lemma = NodeBuilder<3>(OR) << leqNode << geqNode;
 
+  TrustNode trustedLemma;
+  if (options::proofNew())
+  {
+    // Farkas proof that this works.
+    auto nm = NodeManager::currentNM();
+    auto nLeqPf = d_database->d_pnm->mkAssume(leqNode.negate());
+    auto gtPf = d_database->d_pnm->mkNode(
+        PfRule::MACRO_SR_PRED_TRANSFORM, {nLeqPf}, {gtNode});
+    auto nGeqPf = d_database->d_pnm->mkAssume(geqNode.negate());
+    auto ltPf = d_database->d_pnm->mkNode(
+        PfRule::MACRO_SR_PRED_TRANSFORM, {nGeqPf}, {ltNode});
+    auto sumPf = d_database->d_pnm->mkNode(
+        PfRule::ARITH_SCALE_SUM_UPPER_BOUNDS,
+        {gtPf, ltPf},
+        {nm->mkConst<Rational>(-1), nm->mkConst<Rational>(1)});
+    auto botPf = d_database->d_pnm->mkNode(
+        PfRule::MACRO_SR_PRED_TRANSFORM, {sumPf}, {nm->mkConst(false)});
+    std::vector<Node> a = {leqNode.negate(), geqNode.negate()};
+    auto notAndNotPf = d_database->d_pnm->mkScope(botPf, a);
+    // No need to ensure that the expected node aggrees with `a` because we are
+    // not providing an expected node.
+    auto orNotNotPf =
+        d_database->d_pnm->mkNode(PfRule::NOT_AND, {notAndNotPf}, {});
+    auto orPf = d_database->d_pnm->mkNode(
+        PfRule::MACRO_SR_PRED_TRANSFORM, {orNotNotPf}, {lemma});
+    trustedLemma = d_database->d_pfGen->mkTrustNode(lemma, orPf);
+  }
+  else
+  {
+    trustedLemma = TrustNode::mkTrustLemma(lemma);
+  }
 
   eq->d_database->pushSplitWatch(eq);
   diseq->d_database->pushSplitWatch(diseq);
 
-  return lemma;
+  return trustedLemma;
 }
 
 bool ConstraintDatabase::hasLiteral(TNode literal) const {
@@ -2026,30 +2060,83 @@ Node Constraint::getProofLiteral() const
   return neg ? posLit.negate() : posLit;
 }
 
-void implies(std::vector<Node>& out, ConstraintP a, ConstraintP b){
+void ConstraintDatabase::proveOr(std::vector<TrustNode>& out,
+                                 ConstraintP a,
+                                 ConstraintP b,
+                                 bool negateSecond) const
+{
+  Node la = a->getLiteral();
+  Node lb = b->getLiteral();
+  Node orN = (la < lb) ? la.orNode(lb) : lb.orNode(la);
+  if (options::proofNew())
+  {
+    Assert(b->getNegation()->getType() != ConstraintType::Disequality);
+    auto nm = NodeManager::currentNM();
+    auto pf_neg_la = d_pnm->mkNode(PfRule::MACRO_SR_PRED_TRANSFORM,
+                                   {d_pnm->mkAssume(la.negate())},
+                                   {a->getNegation()->getProofLiteral()});
+    auto pf_neg_lb = d_pnm->mkNode(PfRule::MACRO_SR_PRED_TRANSFORM,
+                                   {d_pnm->mkAssume(lb.negate())},
+                                   {b->getNegation()->getProofLiteral()});
+    int sndSign = negateSecond ? -1 : 1;
+    auto bot_pf =
+        d_pnm->mkNode(PfRule::MACRO_SR_PRED_TRANSFORM,
+                      {d_pnm->mkNode(PfRule::ARITH_SCALE_SUM_UPPER_BOUNDS,
+                                     {pf_neg_la, pf_neg_lb},
+                                     {nm->mkConst<Rational>(-1 * sndSign),
+                                      nm->mkConst<Rational>(sndSign)})},
+                      {nm->mkConst(false)});
+    std::vector<Node> as;
+    std::transform(orN.begin(), orN.end(), std::back_inserter(as), [](Node n) {
+      return n.negate();
+    });
+    // No need to ensure that the expected node aggrees with `as` because we
+    // are not providing an expected node.
+    auto pf = d_pnm->mkNode(
+        PfRule::MACRO_SR_PRED_TRANSFORM,
+        {d_pnm->mkNode(PfRule::NOT_AND, {d_pnm->mkScope(bot_pf, as)}, {})},
+        {orN});
+    out.push_back(d_pfGen->mkTrustNode(orN, pf));
+  }
+  else
+  {
+    out.push_back(TrustNode::mkTrustLemma(orN));
+  }
+}
+
+void ConstraintDatabase::implies(std::vector<TrustNode>& out,
+                                 ConstraintP a,
+                                 ConstraintP b) const
+{
   Node la = a->getLiteral();
   Node lb = b->getLiteral();
 
   Node neg_la = (la.getKind() == kind::NOT)? la[0] : la.notNode();
 
   Assert(lb != neg_la);
-  Node orderOr = (lb < neg_la) ? lb.orNode(neg_la) : neg_la.orNode(lb);
-  out.push_back(orderOr);
+  Assert(b->getNegation()->getType() == ConstraintType::LowerBound
+         || b->getNegation()->getType() == ConstraintType::UpperBound);
+  proveOr(out,
+          a->getNegation(),
+          b,
+          b->getNegation()->getType() == ConstraintType::LowerBound);
 }
 
-void mutuallyExclusive(std::vector<Node>& out, ConstraintP a, ConstraintP b){
+void ConstraintDatabase::mutuallyExclusive(std::vector<TrustNode>& out,
+                                           ConstraintP a,
+                                           ConstraintP b) const
+{
   Node la = a->getLiteral();
   Node lb = b->getLiteral();
 
-  Node neg_la = (la.getKind() == kind::NOT)? la[0] : la.notNode();
-  Node neg_lb = (lb.getKind() == kind::NOT)? lb[0] : lb.notNode();
-
-  Assert(neg_la != neg_lb);
-  Node orderOr = (neg_la < neg_lb) ? neg_la.orNode(neg_lb) : neg_lb.orNode(neg_la);
-  out.push_back(orderOr);
+  Node neg_la = la.negate();
+  Node neg_lb = lb.negate();
+  proveOr(out, a->getNegation(), b->getNegation(), true);
 }
 
-void ConstraintDatabase::outputUnateInequalityLemmas(std::vector<Node>& out, ArithVar v) const{
+void ConstraintDatabase::outputUnateInequalityLemmas(
+    std::vector<TrustNode>& out, ArithVar v) const
+{
   SortedConstraintMap& scm = getVariableSCM(v);
   SortedConstraintMapConstIterator scm_iter = scm.begin();
   SortedConstraintMapConstIterator scm_end = scm.end();
@@ -2070,8 +2157,9 @@ void ConstraintDatabase::outputUnateInequalityLemmas(std::vector<Node>& out, Ari
   }
 }
 
-void ConstraintDatabase::outputUnateEqualityLemmas(std::vector<Node>& out, ArithVar v) const{
-
+void ConstraintDatabase::outputUnateEqualityLemmas(std::vector<TrustNode>& out,
+                                                   ArithVar v) const
+{
   vector<ConstraintP> equalities;
 
   SortedConstraintMap& scm = getVariableSCM(v);
@@ -2123,13 +2211,17 @@ void ConstraintDatabase::outputUnateEqualityLemmas(std::vector<Node>& out, Arith
   }
 }
 
-void ConstraintDatabase::outputUnateEqualityLemmas(std::vector<Node>& lemmas) const{
+void ConstraintDatabase::outputUnateEqualityLemmas(
+    std::vector<TrustNode>& lemmas) const
+{
   for(ArithVar v = 0, N = d_varDatabases.size(); v < N; ++v){
     outputUnateEqualityLemmas(lemmas, v);
   }
 }
 
-void ConstraintDatabase::outputUnateInequalityLemmas(std::vector<Node>& lemmas) const{
+void ConstraintDatabase::outputUnateInequalityLemmas(
+    std::vector<TrustNode>& lemmas) const
+{
   for(ArithVar v = 0, N = d_varDatabases.size(); v < N; ++v){
     outputUnateInequalityLemmas(lemmas, v);
   }
index 02bc3c98869cf023cf352d627908ac1c05d2490d..952879182f372e96f1623817158ffa9be806b85a 100644 (file)
@@ -411,7 +411,7 @@ class Constraint {
    * Returns a lemma that is assumed to be true for the rest of the user context.
    * Constraint must be an equality or disequality.
    */
-  Node split();
+  TrustNode split();
 
   bool canBePropagated() const {
     return d_canBePropagated;
@@ -1191,14 +1191,39 @@ private:
 
   void deleteConstraintAndNegation(ConstraintP c);
 
+  /** Given constraints `a` and `b` such that `a OR b` by unate reasoning,
+   *  adds a TrustNode to `out` which proves `a OR b` as a lemma.
+   *
+   *  Example: `x <= 5` OR `5 <= x`.
+   */
+  void proveOr(std::vector<TrustNode>& out,
+               ConstraintP a,
+               ConstraintP b,
+               bool negateSecond) const;
+  /** Given constraints `a` and `b` such that `a` implies `b` by unate
+   * reasoning, adds a TrustNode to `out` which proves `-a OR b` as a lemma.
+   *
+   *  Example: `x >= 5` -> `x >= 4`.
+   */
+  void implies(std::vector<TrustNode>& out, ConstraintP a, ConstraintP b) const;
+  /** Given constraints `a` and `b` such that `not(a AND b)` by unate reasoning,
+   *  adds a TrustNode to `out` which proves `-a OR -b` as a lemma.
+   *
+   *  Example: `x >= 4` -> `x <= 3`.
+   */
+  void mutuallyExclusive(std::vector<TrustNode>& out,
+                         ConstraintP a,
+                         ConstraintP b) const;
+
   /**
    * Outputs a minimal set of unate implications onto the vector for the variable.
    * This outputs lemmas of the general forms
    *     (= p c) implies (<= p d) for c < d, or
    *     (= p c) implies (not (= p d)) for c != d.
    */
-  void outputUnateEqualityLemmas(std::vector<Node>& lemmas) const;
-  void outputUnateEqualityLemmas(std::vector<Node>& lemmas, ArithVar v) const;
+  void outputUnateEqualityLemmas(std::vector<TrustNode>& lemmas) const;
+  void outputUnateEqualityLemmas(std::vector<TrustNode>& lemmas,
+                                 ArithVar v) const;
 
   /**
    * Outputs a minimal set of unate implications onto the vector for the variable.
@@ -1206,9 +1231,9 @@ private:
    * If ineqs is true, this outputs lemmas of the general form
    *     (<= p c) implies (<= p d) for c < d.
    */
-  void outputUnateInequalityLemmas(std::vector<Node>& lemmas) const;
-  void outputUnateInequalityLemmas(std::vector<Node>& lemmas, ArithVar v) const;
-
+  void outputUnateInequalityLemmas(std::vector<TrustNode>& lemmas) const;
+  void outputUnateInequalityLemmas(std::vector<TrustNode>& lemmas,
+                                   ArithVar v) const;
 
   void unatePropLowerBound(ConstraintP curr, ConstraintP prev);
   void unatePropUpperBound(ConstraintP curr, ConstraintP prev);
index 3abd9495e2fb6fc4046f3ef3eae21bcf57520b7c..119be6307a06b755c102624e1a57f931c631314a 100644 (file)
@@ -1063,7 +1063,7 @@ bool TheoryArithPrivate::AssertDisequality(ConstraintP constraint){
 
   if(!split && c_i == d_partialModel.getAssignment(x_i)){
     Debug("arith::eq") << "lemma now! " << constraint << endl;
-    outputLemma(constraint->split());
+    outputTrustedLemma(constraint->split());
     return false;
   }else if(d_partialModel.strictlyLessThanLowerBound(x_i, c_i)){
     Debug("arith::eq") << "can drop as less than lb" << constraint << endl;
@@ -1918,6 +1918,12 @@ void TheoryArithPrivate::outputConflicts(){
   }
 }
 
+void TheoryArithPrivate::outputTrustedLemma(TrustNode lemma)
+{
+  Debug("arith::channel") << "Arith trusted lemma: " << lemma << std::endl;
+  (d_containing.d_out)->lemma(lemma.getNode());
+}
+
 void TheoryArithPrivate::outputLemma(TNode lem) {
   Debug("arith::channel") << "Arith lemma: " << lem << std::endl;
   (d_containing.d_out)->lemma(lem);
@@ -3728,11 +3734,12 @@ bool TheoryArithPrivate::splitDisequalities(){
         Debug("arith::lemma") << "Splitting on " << front << endl;
         Debug("arith::lemma") << "LHS value = " << lhsValue << endl;
         Debug("arith::lemma") << "RHS value = " << rhsValue << endl;
-        Node lemma = front->split();
+        TrustNode lemma = front->split();
         ++(d_statistics.d_statDisequalitySplits);
 
-        Debug("arith::lemma") << "Now " << Rewriter::rewrite(lemma) << endl;
-        outputLemma(lemma);
+        Debug("arith::lemma")
+            << "Now " << Rewriter::rewrite(lemma.getNode()) << endl;
+        outputTrustedLemma(lemma);
         //cout << "Now " << Rewriter::rewrite(lemma) << endl;
         splitSomething = true;
       }else if(d_partialModel.strictlyLessThanLowerBound(lhsVar, rhsValue)){
@@ -4158,7 +4165,7 @@ void TheoryArithPrivate::presolve(){
     callCount = callCount + 1;
   }
 
-  vector<Node> lemmas;
+  vector<TrustNode> lemmas;
   if(!options::incrementalSolving()) {
     switch(options::arithUnateLemmaMode()){
       case options::ArithUnateLemmaMode::NO: break;
@@ -4176,11 +4183,11 @@ void TheoryArithPrivate::presolve(){
     }
   }
 
-  vector<Node>::const_iterator i = lemmas.begin(), i_end = lemmas.end();
+  vector<TrustNode>::const_iterator i = lemmas.begin(), i_end = lemmas.end();
   for(; i != i_end; ++i){
-    Node lem = *i;
+    TrustNode lem = *i;
     Debug("arith::oldprop") << " lemma lemma duck " <<lem << endl;
-    outputLemma(lem);
+    outputTrustedLemma(lem);
   }
 }
 
index 012e45b2f3d43942bc321a91248fd42b92031d05..be192e8805cac896da7f65fc125faa3bfcfea56d 100644 (file)
@@ -689,6 +689,7 @@ private:
   inline TheoryId theoryOf(TNode x) const { return d_containing.theoryOf(x); }
   inline void debugPrintFacts() const { d_containing.debugPrintFacts(); }
   inline context::Context* getSatContext() const { return d_containing.getSatContext(); }
+  void outputTrustedLemma(TrustNode lem);
   void outputLemma(TNode lem);
   void outputConflict(TNode lit);
   void outputPropagate(TNode lit);