arith proof rules shuffle & add ARITH_SUM_UB (#6118)
authorAlex Ozdemir <aozdemir@hmc.edu>
Thu, 11 Mar 2021 22:13:57 +0000 (14:13 -0800)
committerGitHub <noreply@github.com>
Thu, 11 Mar 2021 22:13:57 +0000 (22:13 +0000)
Preparation for making ARITH_SCALE_SUM_UB a macro.

Adds a proof rule for summing upper bounds: ARITH_SUM_UB.
Moves ARITH_MULT_* rules from the non-linear extension to the main
arithmetic checker, since they will be needed for all of arith now.
Aligns the ARITH_SCALE_SUM_UB documentation with its checker.

src/expr/proof_rule.h
src/theory/arith/nl/ext/proof_checker.cpp
src/theory/arith/proof_checker.cpp

index e2933e012edb500017026965db846a9781641cab..efa673409101a6315d742c128c774e58476535f1 100644 (file)
@@ -1043,15 +1043,24 @@ enum class PfRule : uint32_t
   //
   // Arguments: (k1, ..., kn), non-zero reals
   // ---------------------
-  // Conclusion: (>< (* k t1) (* k t2))
+  // Conclusion: (>< t1 t2)
   //    where >< is the fusion of the combination of the ><i, (flipping each it
   //    its ki is negative). >< is always one of <, <=
   //    NB: this implies that lower bounds must have negative ki,
   //                      and upper bounds must have positive ki.
-  //    t1 is the sum of the polynomials.
-  //    t2 is the sum of the constants.
+  //    t1 is the sum of the scaled polynomials (k_1 * poly_1 + ... + k_n * poly_n)
+  //    t2 is the sum of the scaled constants (k_1 * const_1 + ... + k_n * const_n)
   ARITH_SCALE_SUM_UPPER_BOUNDS,
 
+  // ======== Sum Upper Bounds
+  // Children: (P1, ... , Pn)
+  //           where each Pi has form (><i, Li, Ri)
+  //           for ><i in {<, <=, ==}
+  // Conclusion: (>< L R)
+  //           where >< is < if any ><i is <, and <= otherwise.
+  //                 L is (+ L1 ... Ln)
+  //                 R is (+ R1 ... Rn)
+  ARITH_SUM_UB,
   // ======== Tightening Strict Integer Upper Bounds
   // Children: (P:(< i c))
   //         where i has integer type.
index 6d027fd16f21ff9e3adea95271d82daa5a5f30f6..ca600ad559f76ed384c741957fb3cdfc63f0a523 100644 (file)
@@ -28,8 +28,6 @@ namespace nl {
 void ExtProofRuleChecker::registerTo(ProofChecker* pc)
 {
   pc->registerChecker(PfRule::ARITH_MULT_SIGN, this);
-  pc->registerChecker(PfRule::ARITH_MULT_POS, this);
-  pc->registerChecker(PfRule::ARITH_MULT_NEG, this);
   pc->registerChecker(PfRule::ARITH_MULT_TANGENT, this);
 }
 
@@ -119,41 +117,6 @@ Node ExtProofRuleChecker::checkInternal(PfRule id,
       default: Assert(false); return Node();
     }
   }
-  else if (id == PfRule::ARITH_MULT_POS)
-  {
-    Assert(children.empty());
-    Assert(args.size() == 2);
-    Node mult = args[0];
-    Kind rel = args[1].getKind();
-    Assert(rel == Kind::EQUAL || rel == Kind::DISTINCT || rel == Kind::LT
-           || rel == Kind::LEQ || rel == Kind::GT || rel == Kind::GEQ);
-    Node lhs = args[1][0];
-    Node rhs = args[1][1];
-    return nm->mkNode(
-        Kind::IMPLIES,
-        nm->mkAnd(std::vector<Node>{nm->mkNode(Kind::GT, mult, zero), args[1]}),
-        nm->mkNode(rel,
-                   nm->mkNode(Kind::MULT, mult, lhs),
-                   nm->mkNode(Kind::MULT, mult, rhs)));
-  }
-  else if (id == PfRule::ARITH_MULT_NEG)
-  {
-    Assert(children.empty());
-    Assert(args.size() == 2);
-    Node mult = args[0];
-    Kind rel = args[1].getKind();
-    Assert(rel == Kind::EQUAL || rel == Kind::DISTINCT || rel == Kind::LT
-           || rel == Kind::LEQ || rel == Kind::GT || rel == Kind::GEQ);
-    Kind rel_inv = (rel == Kind::DISTINCT ? rel : reverseRelationKind(rel));
-    Node lhs = args[1][0];
-    Node rhs = args[1][1];
-    return nm->mkNode(
-        Kind::IMPLIES,
-        nm->mkAnd(std::vector<Node>{nm->mkNode(Kind::LT, mult, zero), args[1]}),
-        nm->mkNode(rel_inv,
-                   nm->mkNode(Kind::MULT, mult, lhs),
-                   nm->mkNode(Kind::MULT, mult, rhs)));
-  }
   else if (id == PfRule::ARITH_MULT_TANGENT)
   {
     Assert(children.empty());
index 80a8a888fd7e34cfe0ab886ff858e321b99cdde2..c595791ab625bf588da002b9c3b1b683f6bd6d5b 100644 (file)
@@ -30,10 +30,14 @@ namespace arith {
 void ArithProofRuleChecker::registerTo(ProofChecker* pc)
 {
   pc->registerChecker(PfRule::ARITH_SCALE_SUM_UPPER_BOUNDS, this);
+  pc->registerChecker(PfRule::ARITH_SUM_UB, this);
   pc->registerChecker(PfRule::ARITH_TRICHOTOMY, this);
   pc->registerChecker(PfRule::INT_TIGHT_UB, this);
   pc->registerChecker(PfRule::INT_TIGHT_LB, this);
   pc->registerChecker(PfRule::ARITH_OP_ELIM_AXIOM, this);
+
+  pc->registerChecker(PfRule::ARITH_MULT_POS, this);
+  pc->registerChecker(PfRule::ARITH_MULT_NEG, this);
   // trusted rules
   pc->registerTrustedChecker(PfRule::INT_TRUST, this, 2);
 }
@@ -42,6 +46,8 @@ Node ArithProofRuleChecker::checkInternal(PfRule id,
                                           const std::vector<Node>& children,
                                           const std::vector<Node>& args)
 {
+  NodeManager* nm = NodeManager::currentNM();
+  auto zero = nm->mkConst<Rational>(0);
   if (Debug.isOn("arith::pf::check"))
   {
     Debug("arith::pf::check") << "Arith PfRule:" << id << std::endl;
@@ -58,6 +64,82 @@ Node ArithProofRuleChecker::checkInternal(PfRule id,
   }
   switch (id)
   {
+    case PfRule::ARITH_MULT_POS:
+    {
+      Assert(children.empty());
+      Assert(args.size() == 2);
+      Node mult = args[0];
+      Kind rel = args[1].getKind();
+      Assert(rel == Kind::EQUAL || rel == Kind::DISTINCT || rel == Kind::LT
+             || rel == Kind::LEQ || rel == Kind::GT || rel == Kind::GEQ);
+      Node lhs = args[1][0];
+      Node rhs = args[1][1];
+      return nm->mkNode(Kind::IMPLIES,
+                        nm->mkAnd(std::vector<Node>{
+                            nm->mkNode(Kind::GT, mult, zero), args[1]}),
+                        nm->mkNode(rel,
+                                   nm->mkNode(Kind::MULT, mult, lhs),
+                                   nm->mkNode(Kind::MULT, mult, rhs)));
+    }
+    case PfRule::ARITH_MULT_NEG:
+    {
+      Assert(children.empty());
+      Assert(args.size() == 2);
+      Node mult = args[0];
+      Kind rel = args[1].getKind();
+      Assert(rel == Kind::EQUAL || rel == Kind::DISTINCT || rel == Kind::LT
+             || rel == Kind::LEQ || rel == Kind::GT || rel == Kind::GEQ);
+      Kind rel_inv = (rel == Kind::DISTINCT ? rel : reverseRelationKind(rel));
+      Node lhs = args[1][0];
+      Node rhs = args[1][1];
+      return nm->mkNode(Kind::IMPLIES,
+                        nm->mkAnd(std::vector<Node>{
+                            nm->mkNode(Kind::LT, mult, zero), args[1]}),
+                        nm->mkNode(rel_inv,
+                                   nm->mkNode(Kind::MULT, mult, lhs),
+                                   nm->mkNode(Kind::MULT, mult, rhs)));
+    }
+    case PfRule::ARITH_SUM_UB:
+    {
+      if (children.size() < 2)
+      {
+        return Node::null();
+      }
+
+      // Whether a strict inequality is in the sum.
+      bool strict = false;
+      NodeBuilder<> leftSum(Kind::PLUS);
+      NodeBuilder<> rightSum(Kind::PLUS);
+      for (size_t i = 0; i < children.size(); ++i)
+      {
+        // Adjust strictness
+        switch (children[i].getKind())
+        {
+          case Kind::LT:
+          {
+            strict = true;
+            break;
+          }
+          case Kind::LEQ:
+          case Kind::EQUAL:
+          {
+            break;
+          }
+          default:
+          {
+            Debug("arith::pf::check")
+                << "Bad kind: " << children[i].getKind() << std::endl;
+            return Node::null();
+          }
+        }
+        leftSum << children[i][0];
+        rightSum << children[i][1];
+      }
+      Node r = nm->mkNode(strict ? Kind::LT : Kind::LEQ,
+                          leftSum.constructNode(),
+                          rightSum.constructNode());
+      return r;
+    }
     case PfRule::ARITH_SCALE_SUM_UPPER_BOUNDS:
     {
       // Children: (P1:l1, ..., Pn:ln)
@@ -80,7 +162,6 @@ Node ArithProofRuleChecker::checkInternal(PfRule id,
       }
 
       // Whether a strict inequality is in the sum.
-      auto nm = NodeManager::currentNM();
       bool strict = false;
       NodeBuilder<> leftSum(Kind::PLUS);
       NodeBuilder<> rightSum(Kind::PLUS);
@@ -181,7 +262,6 @@ Node ArithProofRuleChecker::checkInternal(PfRule id,
       {
         Rational originalBound = children[0][1].getConst<Rational>();
         Rational newBound = leastIntGreaterThan(originalBound);
-        auto nm = NodeManager::currentNM();
         Node rational = nm->mkConst<Rational>(newBound);
         return nm->mkNode(kind::GEQ, children[0][0], rational);
       }
@@ -207,7 +287,6 @@ Node ArithProofRuleChecker::checkInternal(PfRule id,
       {
         Rational originalBound = children[0][1].getConst<Rational>();
         Rational newBound = greatestIntLessThan(originalBound);
-        auto nm = NodeManager::currentNM();
         Node rational = nm->mkConst<Rational>(newBound);
         return nm->mkNode(kind::LEQ, children[0][0], rational);
       }