From bd52deb7434b1b08a122db4513972644c11fc4aa Mon Sep 17 00:00:00 2001 From: Gereon Kremer Date: Wed, 10 Mar 2021 21:48:13 +0100 Subject: [PATCH] Improve arithmetic proofs (#6106) The proof rules for ARITH_MULT_POS and ARITH_MULT_NEG were complex than necessary in that they incorporated a rewriting step. This PR removes rewriting from these rules, making them cleaner and easier to understand. The proof now applies these simpler rule and uses MACRO_SR_PRED_TRANSFORM to prove the lemma that is actually added. --- src/expr/proof_rule.h | 10 ++--- .../arith/nl/ext/monomial_bounds_check.cpp | 39 ++++++++++++------- src/theory/arith/nl/ext/proof_checker.cpp | 30 +++++++------- 3 files changed, 42 insertions(+), 37 deletions(-) diff --git a/src/expr/proof_rule.h b/src/expr/proof_rule.h index 909f7b7cd..2759a3c9e 100644 --- a/src/expr/proof_rule.h +++ b/src/expr/proof_rule.h @@ -1108,19 +1108,17 @@ enum class PfRule : uint32_t ARITH_MULT_SIGN, //======== Multiplication with positive factor // Children: none - // Arguments: (m, orig, lhs, rel, rhs) + // Arguments: (m, (rel lhs rhs)) // --------------------- // Conclusion: (=> (and (> m 0) (rel lhs rhs)) (rel (* m lhs) (* m rhs))) - // Where orig is the origin that implies (rel lhs rhs) and rel is a relation - // symbol. + // Where rel is a relation symbol. ARITH_MULT_POS, //======== Multiplication with negative factor // Children: none - // Arguments: (m, orig, (rel lhs rhs)) + // Arguments: (m, (rel lhs rhs)) // --------------------- // Conclusion: (=> (and (< m 0) (rel lhs rhs)) (rel_inv (* m lhs) (* m rhs))) - // Where orig is the origin that implies (rel lhs rhs) and rel is a relation - // symbol and rel_inv the inverted relation symbol. + // Where rel is a relation symbol and rel_inv the inverted relation symbol. ARITH_MULT_NEG, //======== Multiplication tangent plane // Children: none diff --git a/src/theory/arith/nl/ext/monomial_bounds_check.cpp b/src/theory/arith/nl/ext/monomial_bounds_check.cpp index 47cb5daec..f1a2f45b9 100644 --- a/src/theory/arith/nl/ext/monomial_bounds_check.cpp +++ b/src/theory/arith/nl/ext/monomial_bounds_check.cpp @@ -299,11 +299,7 @@ void MonomialBoundsCheck::checkBounds(const std::vector& asserts, Node infer_rhs = nm->mkNode(Kind::MULT, mult, rhs); Node infer = nm->mkNode(infer_type, infer_lhs, infer_rhs); Trace("nl-ext-bound-debug") << " " << infer << std::endl; - infer = Rewriter::rewrite(infer); - Trace("nl-ext-bound-debug2") - << " ...rewritten : " << infer << std::endl; - // check whether it is false in model for abstraction - Node infer_mv = d_data->d_model.computeAbstractModelValue(infer); + Node infer_mv = d_data->d_model.computeAbstractModelValue(Rewriter::rewrite(infer)); Trace("nl-ext-bound-debug") << " ...infer model value is " << infer_mv << std::endl; if (infer_mv == d_data->d_false) @@ -314,22 +310,35 @@ void MonomialBoundsCheck::checkBounds(const std::vector& asserts, mmv_sign == 1 ? Kind::GT : Kind::LT, mult, d_data->d_zero), d_ci_exp[x][coeff][rhs]); Node iblem = nm->mkNode(Kind::IMPLIES, exp, infer); - Node pr_iblem = iblem; - iblem = Rewriter::rewrite(iblem); - bool introNewTerms = hasNewMonomials(iblem, d_data->d_ms); + Node iblem_rw = Rewriter::rewrite(iblem); + bool introNewTerms = hasNewMonomials(iblem_rw, d_data->d_ms); Trace("nl-ext-bound-lemma") - << "*** Bound inference lemma : " << iblem - << " (pre-rewrite : " << pr_iblem << ")" << std::endl; + << "*** Bound inference lemma : " << iblem_rw + << " (pre-rewrite : " << iblem << ")" << std::endl; CDProof* proof = nullptr; + Node orig = d_ci_exp[x][coeff][rhs]; if (d_data->isProofEnabled()) { proof = d_data->getProof(); + // this is iblem, but uses (type t rhs) instead of the original + // variant (which is identical under rewriting) + // we first infer the "clean" version of the lemma and then + // use MACRO_SR_PRED_TRANSFORM to rewrite + Node tmplem = nm->mkNode( + Kind::IMPLIES, + nm->mkNode(Kind::AND, + nm->mkNode(mmv_sign == 1 ? Kind::GT : Kind::LT, + mult, + d_data->d_zero), + nm->mkNode(type, t, rhs)), + infer); + proof->addStep(tmplem, + mmv_sign == 1 ? PfRule::ARITH_MULT_POS + : PfRule::ARITH_MULT_NEG, + {}, + {mult, nm->mkNode(type, t, rhs)}); proof->addStep( - iblem, - mmv_sign == 1 ? PfRule::ARITH_MULT_POS - : PfRule::ARITH_MULT_NEG, - {}, - {mult, d_ci_exp[x][coeff][rhs], nm->mkNode(type, t, rhs)}); + iblem, PfRule::MACRO_SR_PRED_TRANSFORM, {tmplem}, {iblem}); } d_data->d_im.addPendingLemma(iblem, InferenceId::ARITH_NL_INFER_BOUNDS_NT, diff --git a/src/theory/arith/nl/ext/proof_checker.cpp b/src/theory/arith/nl/ext/proof_checker.cpp index e88e08aaf..6d027fd16 100644 --- a/src/theory/arith/nl/ext/proof_checker.cpp +++ b/src/theory/arith/nl/ext/proof_checker.cpp @@ -122,39 +122,37 @@ Node ExtProofRuleChecker::checkInternal(PfRule id, else if (id == PfRule::ARITH_MULT_POS) { Assert(children.empty()); - Assert(args.size() == 3); + Assert(args.size() == 2); Node mult = args[0]; - Node orig = args[1]; - Kind rel = args[2].getKind(); + Kind rel = args[1].getKind(); Assert(rel == Kind::EQUAL || rel == Kind::DISTINCT || rel == Kind::LT || rel == Kind::LEQ || rel == Kind::GT || rel == Kind::GEQ); - Node lhs = args[2][0]; - Node rhs = args[2][1]; - return Rewriter::rewrite(nm->mkNode( + Node lhs = args[1][0]; + Node rhs = args[1][1]; + return nm->mkNode( Kind::IMPLIES, - nm->mkAnd(std::vector{nm->mkNode(Kind::GT, mult, zero), orig}), + nm->mkAnd(std::vector{nm->mkNode(Kind::GT, mult, zero), args[1]}), nm->mkNode(rel, nm->mkNode(Kind::MULT, mult, lhs), - nm->mkNode(Kind::MULT, mult, rhs)))); + nm->mkNode(Kind::MULT, mult, rhs))); } else if (id == PfRule::ARITH_MULT_NEG) { Assert(children.empty()); - Assert(args.size() == 3); + Assert(args.size() == 2); Node mult = args[0]; - Node orig = args[1]; - Kind rel = args[2].getKind(); + Kind rel = args[1].getKind(); Assert(rel == Kind::EQUAL || rel == Kind::DISTINCT || rel == Kind::LT || rel == Kind::LEQ || rel == Kind::GT || rel == Kind::GEQ); Kind rel_inv = (rel == Kind::DISTINCT ? rel : reverseRelationKind(rel)); - Node lhs = args[2][0]; - Node rhs = args[2][1]; - return Rewriter::rewrite(nm->mkNode( + Node lhs = args[1][0]; + Node rhs = args[1][1]; + return nm->mkNode( Kind::IMPLIES, - nm->mkAnd(std::vector{nm->mkNode(Kind::LT, mult, zero), orig}), + nm->mkAnd(std::vector{nm->mkNode(Kind::LT, mult, zero), args[1]}), nm->mkNode(rel_inv, nm->mkNode(Kind::MULT, mult, lhs), - nm->mkNode(Kind::MULT, mult, rhs)))); + nm->mkNode(Kind::MULT, mult, rhs))); } else if (id == PfRule::ARITH_MULT_TANGENT) { -- 2.30.2