From 5998d7f5a9168b0dd1c26f3aa1b85e570fe72af8 Mon Sep 17 00:00:00 2001 From: Alex Ozdemir Date: Thu, 11 Mar 2021 14:13:57 -0800 Subject: [PATCH] arith proof rules shuffle & add ARITH_SUM_UB (#6118) Preparation for making ARITH_SCALE_SUM_UB a macro. Adds a proof rule for summing upper bounds: ARITH_SUM_UB. Moves ARITH_MULT_* rules from the non-linear extension to the main arithmetic checker, since they will be needed for all of arith now. Aligns the ARITH_SCALE_SUM_UB documentation with its checker. --- src/expr/proof_rule.h | 15 +++- src/theory/arith/nl/ext/proof_checker.cpp | 37 ---------- src/theory/arith/proof_checker.cpp | 85 ++++++++++++++++++++++- 3 files changed, 94 insertions(+), 43 deletions(-) diff --git a/src/expr/proof_rule.h b/src/expr/proof_rule.h index e2933e012..efa673409 100644 --- a/src/expr/proof_rule.h +++ b/src/expr/proof_rule.h @@ -1043,15 +1043,24 @@ enum class PfRule : uint32_t // // Arguments: (k1, ..., kn), non-zero reals // --------------------- - // Conclusion: (>< (* k t1) (* k t2)) + // Conclusion: (>< t1 t2) // where >< is the fusion of the combination of the >< is always one of <, <= // NB: this implies that lower bounds must have negative ki, // and upper bounds must have positive ki. - // t1 is the sum of the polynomials. - // t2 is the sum of the constants. + // t1 is the sum of the scaled polynomials (k_1 * poly_1 + ... + k_n * poly_n) + // t2 is the sum of the scaled constants (k_1 * const_1 + ... + k_n * const_n) ARITH_SCALE_SUM_UPPER_BOUNDS, + // ======== Sum Upper Bounds + // Children: (P1, ... , Pn) + // where each Pi has form (>< L R) + // where >< is < if any >registerChecker(PfRule::ARITH_MULT_SIGN, this); - pc->registerChecker(PfRule::ARITH_MULT_POS, this); - pc->registerChecker(PfRule::ARITH_MULT_NEG, this); pc->registerChecker(PfRule::ARITH_MULT_TANGENT, this); } @@ -119,41 +117,6 @@ Node ExtProofRuleChecker::checkInternal(PfRule id, default: Assert(false); return Node(); } } - else if (id == PfRule::ARITH_MULT_POS) - { - Assert(children.empty()); - Assert(args.size() == 2); - Node mult = args[0]; - Kind rel = args[1].getKind(); - Assert(rel == Kind::EQUAL || rel == Kind::DISTINCT || rel == Kind::LT - || rel == Kind::LEQ || rel == Kind::GT || rel == Kind::GEQ); - Node lhs = args[1][0]; - Node rhs = args[1][1]; - return nm->mkNode( - Kind::IMPLIES, - nm->mkAnd(std::vector{nm->mkNode(Kind::GT, mult, zero), args[1]}), - nm->mkNode(rel, - nm->mkNode(Kind::MULT, mult, lhs), - nm->mkNode(Kind::MULT, mult, rhs))); - } - else if (id == PfRule::ARITH_MULT_NEG) - { - Assert(children.empty()); - Assert(args.size() == 2); - Node mult = args[0]; - Kind rel = args[1].getKind(); - Assert(rel == Kind::EQUAL || rel == Kind::DISTINCT || rel == Kind::LT - || rel == Kind::LEQ || rel == Kind::GT || rel == Kind::GEQ); - Kind rel_inv = (rel == Kind::DISTINCT ? rel : reverseRelationKind(rel)); - Node lhs = args[1][0]; - Node rhs = args[1][1]; - return nm->mkNode( - Kind::IMPLIES, - nm->mkAnd(std::vector{nm->mkNode(Kind::LT, mult, zero), args[1]}), - nm->mkNode(rel_inv, - nm->mkNode(Kind::MULT, mult, lhs), - nm->mkNode(Kind::MULT, mult, rhs))); - } else if (id == PfRule::ARITH_MULT_TANGENT) { Assert(children.empty()); diff --git a/src/theory/arith/proof_checker.cpp b/src/theory/arith/proof_checker.cpp index 80a8a888f..c595791ab 100644 --- a/src/theory/arith/proof_checker.cpp +++ b/src/theory/arith/proof_checker.cpp @@ -30,10 +30,14 @@ namespace arith { void ArithProofRuleChecker::registerTo(ProofChecker* pc) { pc->registerChecker(PfRule::ARITH_SCALE_SUM_UPPER_BOUNDS, this); + pc->registerChecker(PfRule::ARITH_SUM_UB, this); pc->registerChecker(PfRule::ARITH_TRICHOTOMY, this); pc->registerChecker(PfRule::INT_TIGHT_UB, this); pc->registerChecker(PfRule::INT_TIGHT_LB, this); pc->registerChecker(PfRule::ARITH_OP_ELIM_AXIOM, this); + + pc->registerChecker(PfRule::ARITH_MULT_POS, this); + pc->registerChecker(PfRule::ARITH_MULT_NEG, this); // trusted rules pc->registerTrustedChecker(PfRule::INT_TRUST, this, 2); } @@ -42,6 +46,8 @@ Node ArithProofRuleChecker::checkInternal(PfRule id, const std::vector& children, const std::vector& args) { + NodeManager* nm = NodeManager::currentNM(); + auto zero = nm->mkConst(0); if (Debug.isOn("arith::pf::check")) { Debug("arith::pf::check") << "Arith PfRule:" << id << std::endl; @@ -58,6 +64,82 @@ Node ArithProofRuleChecker::checkInternal(PfRule id, } switch (id) { + case PfRule::ARITH_MULT_POS: + { + Assert(children.empty()); + Assert(args.size() == 2); + Node mult = args[0]; + Kind rel = args[1].getKind(); + Assert(rel == Kind::EQUAL || rel == Kind::DISTINCT || rel == Kind::LT + || rel == Kind::LEQ || rel == Kind::GT || rel == Kind::GEQ); + Node lhs = args[1][0]; + Node rhs = args[1][1]; + return nm->mkNode(Kind::IMPLIES, + nm->mkAnd(std::vector{ + nm->mkNode(Kind::GT, mult, zero), args[1]}), + nm->mkNode(rel, + nm->mkNode(Kind::MULT, mult, lhs), + nm->mkNode(Kind::MULT, mult, rhs))); + } + case PfRule::ARITH_MULT_NEG: + { + Assert(children.empty()); + Assert(args.size() == 2); + Node mult = args[0]; + Kind rel = args[1].getKind(); + Assert(rel == Kind::EQUAL || rel == Kind::DISTINCT || rel == Kind::LT + || rel == Kind::LEQ || rel == Kind::GT || rel == Kind::GEQ); + Kind rel_inv = (rel == Kind::DISTINCT ? rel : reverseRelationKind(rel)); + Node lhs = args[1][0]; + Node rhs = args[1][1]; + return nm->mkNode(Kind::IMPLIES, + nm->mkAnd(std::vector{ + nm->mkNode(Kind::LT, mult, zero), args[1]}), + nm->mkNode(rel_inv, + nm->mkNode(Kind::MULT, mult, lhs), + nm->mkNode(Kind::MULT, mult, rhs))); + } + case PfRule::ARITH_SUM_UB: + { + if (children.size() < 2) + { + return Node::null(); + } + + // Whether a strict inequality is in the sum. + bool strict = false; + NodeBuilder<> leftSum(Kind::PLUS); + NodeBuilder<> rightSum(Kind::PLUS); + for (size_t i = 0; i < children.size(); ++i) + { + // Adjust strictness + switch (children[i].getKind()) + { + case Kind::LT: + { + strict = true; + break; + } + case Kind::LEQ: + case Kind::EQUAL: + { + break; + } + default: + { + Debug("arith::pf::check") + << "Bad kind: " << children[i].getKind() << std::endl; + return Node::null(); + } + } + leftSum << children[i][0]; + rightSum << children[i][1]; + } + Node r = nm->mkNode(strict ? Kind::LT : Kind::LEQ, + leftSum.constructNode(), + rightSum.constructNode()); + return r; + } case PfRule::ARITH_SCALE_SUM_UPPER_BOUNDS: { // Children: (P1:l1, ..., Pn:ln) @@ -80,7 +162,6 @@ Node ArithProofRuleChecker::checkInternal(PfRule id, } // Whether a strict inequality is in the sum. - auto nm = NodeManager::currentNM(); bool strict = false; NodeBuilder<> leftSum(Kind::PLUS); NodeBuilder<> rightSum(Kind::PLUS); @@ -181,7 +262,6 @@ Node ArithProofRuleChecker::checkInternal(PfRule id, { Rational originalBound = children[0][1].getConst(); Rational newBound = leastIntGreaterThan(originalBound); - auto nm = NodeManager::currentNM(); Node rational = nm->mkConst(newBound); return nm->mkNode(kind::GEQ, children[0][0], rational); } @@ -207,7 +287,6 @@ Node ArithProofRuleChecker::checkInternal(PfRule id, { Rational originalBound = children[0][1].getConst(); Rational newBound = greatestIntLessThan(originalBound); - auto nm = NodeManager::currentNM(); Node rational = nm->mkConst(newBound); return nm->mkNode(kind::LEQ, children[0][0], rational); } -- 2.30.2