Improve arithmetic proofs (#6106)
authorGereon Kremer <gereon.kremer@cs.rwth-aachen.de>
Wed, 10 Mar 2021 20:48:13 +0000 (21:48 +0100)
committerGitHub <noreply@github.com>
Wed, 10 Mar 2021 20:48:13 +0000 (20:48 +0000)
The proof rules for ARITH_MULT_POS and ARITH_MULT_NEG were complex than necessary in that they incorporated a rewriting step. This PR removes rewriting from these rules, making them cleaner and easier to understand.
The proof now applies these simpler rule and uses MACRO_SR_PRED_TRANSFORM to prove the lemma that is actually added.

src/expr/proof_rule.h
src/theory/arith/nl/ext/monomial_bounds_check.cpp
src/theory/arith/nl/ext/proof_checker.cpp

index 909f7b7cdcb810f1e90966a8eff8f2ec1f55e244..2759a3c9e6516c5f23e4967c324a2bc691a6c213 100644 (file)
@@ -1108,19 +1108,17 @@ enum class PfRule : uint32_t
   ARITH_MULT_SIGN,
   //======== Multiplication with positive factor
   // Children: none
-  // Arguments: (m, orig, lhs, rel, rhs)
+  // Arguments: (m, (rel lhs rhs))
   // ---------------------
   // Conclusion: (=> (and (> m 0) (rel lhs rhs)) (rel (* m lhs) (* m rhs)))
-  // Where orig is the origin that implies (rel lhs rhs) and rel is a relation
-  // symbol.
+  // Where rel is a relation symbol.
   ARITH_MULT_POS,
   //======== Multiplication with negative factor
   // Children: none
-  // Arguments: (m, orig, (rel lhs rhs))
+  // Arguments: (m, (rel lhs rhs))
   // ---------------------
   // Conclusion: (=> (and (< m 0) (rel lhs rhs)) (rel_inv (* m lhs) (* m rhs)))
-  // Where orig is the origin that implies (rel lhs rhs) and rel is a relation
-  // symbol and rel_inv the inverted relation symbol.
+  // Where rel is a relation symbol and rel_inv the inverted relation symbol.
   ARITH_MULT_NEG,
   //======== Multiplication tangent plane
   // Children: none
index 47cb5daecf0783e41cd8ae84e3fddc682e5c0050..f1a2f45b9145675347aa2a20380a093ffbba81da 100644 (file)
@@ -299,11 +299,7 @@ void MonomialBoundsCheck::checkBounds(const std::vector<Node>& asserts,
           Node infer_rhs = nm->mkNode(Kind::MULT, mult, rhs);
           Node infer = nm->mkNode(infer_type, infer_lhs, infer_rhs);
           Trace("nl-ext-bound-debug") << "     " << infer << std::endl;
-          infer = Rewriter::rewrite(infer);
-          Trace("nl-ext-bound-debug2")
-              << "     ...rewritten : " << infer << std::endl;
-          // check whether it is false in model for abstraction
-          Node infer_mv = d_data->d_model.computeAbstractModelValue(infer);
+          Node infer_mv = d_data->d_model.computeAbstractModelValue(Rewriter::rewrite(infer));
           Trace("nl-ext-bound-debug")
               << "       ...infer model value is " << infer_mv << std::endl;
           if (infer_mv == d_data->d_false)
@@ -314,22 +310,35 @@ void MonomialBoundsCheck::checkBounds(const std::vector<Node>& asserts,
                     mmv_sign == 1 ? Kind::GT : Kind::LT, mult, d_data->d_zero),
                 d_ci_exp[x][coeff][rhs]);
             Node iblem = nm->mkNode(Kind::IMPLIES, exp, infer);
-            Node pr_iblem = iblem;
-            iblem = Rewriter::rewrite(iblem);
-            bool introNewTerms = hasNewMonomials(iblem, d_data->d_ms);
+            Node iblem_rw = Rewriter::rewrite(iblem);
+            bool introNewTerms = hasNewMonomials(iblem_rw, d_data->d_ms);
             Trace("nl-ext-bound-lemma")
-                << "*** Bound inference lemma : " << iblem
-                << " (pre-rewrite : " << pr_iblem << ")" << std::endl;
+                << "*** Bound inference lemma : " << iblem_rw
+                << " (pre-rewrite : " << iblem << ")" << std::endl;
             CDProof* proof = nullptr;
+            Node orig = d_ci_exp[x][coeff][rhs];
             if (d_data->isProofEnabled())
             {
               proof = d_data->getProof();
+              // this is iblem, but uses (type t rhs) instead of the original
+              // variant (which is identical under rewriting)
+              // we first infer the "clean" version of the lemma and then
+              // use MACRO_SR_PRED_TRANSFORM to rewrite
+              Node tmplem = nm->mkNode(
+                  Kind::IMPLIES,
+                  nm->mkNode(Kind::AND,
+                             nm->mkNode(mmv_sign == 1 ? Kind::GT : Kind::LT,
+                                        mult,
+                                        d_data->d_zero),
+                             nm->mkNode(type, t, rhs)),
+                  infer);
+              proof->addStep(tmplem,
+                             mmv_sign == 1 ? PfRule::ARITH_MULT_POS
+                                           : PfRule::ARITH_MULT_NEG,
+                             {},
+                             {mult, nm->mkNode(type, t, rhs)});
               proof->addStep(
-                  iblem,
-                  mmv_sign == 1 ? PfRule::ARITH_MULT_POS
-                                : PfRule::ARITH_MULT_NEG,
-                  {},
-                  {mult, d_ci_exp[x][coeff][rhs], nm->mkNode(type, t, rhs)});
+                  iblem, PfRule::MACRO_SR_PRED_TRANSFORM, {tmplem}, {iblem});
             }
             d_data->d_im.addPendingLemma(iblem,
                                          InferenceId::ARITH_NL_INFER_BOUNDS_NT,
index e88e08aaf892217eb2ad632866beed8fbcfdf01c..6d027fd16f21ff9e3adea95271d82daa5a5f30f6 100644 (file)
@@ -122,39 +122,37 @@ Node ExtProofRuleChecker::checkInternal(PfRule id,
   else if (id == PfRule::ARITH_MULT_POS)
   {
     Assert(children.empty());
-    Assert(args.size() == 3);
+    Assert(args.size() == 2);
     Node mult = args[0];
-    Node orig = args[1];
-    Kind rel = args[2].getKind();
+    Kind rel = args[1].getKind();
     Assert(rel == Kind::EQUAL || rel == Kind::DISTINCT || rel == Kind::LT
            || rel == Kind::LEQ || rel == Kind::GT || rel == Kind::GEQ);
-    Node lhs = args[2][0];
-    Node rhs = args[2][1];
-    return Rewriter::rewrite(nm->mkNode(
+    Node lhs = args[1][0];
+    Node rhs = args[1][1];
+    return nm->mkNode(
         Kind::IMPLIES,
-        nm->mkAnd(std::vector<Node>{nm->mkNode(Kind::GT, mult, zero), orig}),
+        nm->mkAnd(std::vector<Node>{nm->mkNode(Kind::GT, mult, zero), args[1]}),
         nm->mkNode(rel,
                    nm->mkNode(Kind::MULT, mult, lhs),
-                   nm->mkNode(Kind::MULT, mult, rhs))));
+                   nm->mkNode(Kind::MULT, mult, rhs)));
   }
   else if (id == PfRule::ARITH_MULT_NEG)
   {
     Assert(children.empty());
-    Assert(args.size() == 3);
+    Assert(args.size() == 2);
     Node mult = args[0];
-    Node orig = args[1];
-    Kind rel = args[2].getKind();
+    Kind rel = args[1].getKind();
     Assert(rel == Kind::EQUAL || rel == Kind::DISTINCT || rel == Kind::LT
            || rel == Kind::LEQ || rel == Kind::GT || rel == Kind::GEQ);
     Kind rel_inv = (rel == Kind::DISTINCT ? rel : reverseRelationKind(rel));
-    Node lhs = args[2][0];
-    Node rhs = args[2][1];
-    return Rewriter::rewrite(nm->mkNode(
+    Node lhs = args[1][0];
+    Node rhs = args[1][1];
+    return nm->mkNode(
         Kind::IMPLIES,
-        nm->mkAnd(std::vector<Node>{nm->mkNode(Kind::LT, mult, zero), orig}),
+        nm->mkAnd(std::vector<Node>{nm->mkNode(Kind::LT, mult, zero), args[1]}),
         nm->mkNode(rel_inv,
                    nm->mkNode(Kind::MULT, mult, lhs),
-                   nm->mkNode(Kind::MULT, mult, rhs))));
+                   nm->mkNode(Kind::MULT, mult, rhs)));
   }
   else if (id == PfRule::ARITH_MULT_TANGENT)
   {