From 77966f11fad2979a843722ba1bbc22dab2104ff1 Mon Sep 17 00:00:00 2001 From: Gereon Kremer Date: Mon, 28 Feb 2022 16:22:48 +0100 Subject: [PATCH] Refactor rewriting of arithmetic addition (#8180) This PR uses the new addition utilities to refactor rewriting of arithmetic addition. This properly handles real algebraic numbers now, eliminating a few more edge cases where the previous solution might allow for non-idempotent rewrites. --- src/theory/arith/arith_rewriter.cpp | 80 +++--------------------- src/theory/arith/arith_rewriter.h | 7 ++- src/theory/arith/rewriter/addition.cpp | 1 + src/theory/arith/rewriter/node_utils.cpp | 9 ++- 4 files changed, 23 insertions(+), 74 deletions(-) diff --git a/src/theory/arith/arith_rewriter.cpp b/src/theory/arith/arith_rewriter.cpp index aae34abb6..a251d9e08 100644 --- a/src/theory/arith/arith_rewriter.cpp +++ b/src/theory/arith/arith_rewriter.cpp @@ -459,86 +459,26 @@ RewriteResponse ArithRewriter::postRewriteTerm(TNode t){ } -RewriteResponse ArithRewriter::preRewritePlus(TNode t){ +RewriteResponse ArithRewriter::preRewritePlus(TNode t) +{ Assert(t.getKind() == kind::ADD); return RewriteResponse(REWRITE_DONE, expr::algorithm::flatten(t)); } -RewriteResponse ArithRewriter::postRewritePlus(TNode t){ +RewriteResponse ArithRewriter::postRewritePlus(TNode t) +{ Assert(t.getKind() == kind::ADD); Assert(t.getNumChildren() > 1); - { - Node flat = expr::algorithm::flatten(t); - if (flat != t) - { - return RewriteResponse(REWRITE_AGAIN, flat); - } - } - - Rational rational; - RealAlgebraicNumber ran; - std::vector monomials; - std::vector polynomials; - - for (const auto& child : t) - { - if (child.isConst()) - { - if (child.getConst().isZero()) - { - continue; - } - rational += child.getConst(); - } - else if (child.getKind() == Kind::REAL_ALGEBRAIC_NUMBER) - { - ran += child.getOperator().getConst(); - } - else if (Monomial::isMember(child)) - { - monomials.emplace_back(Monomial::parseMonomial(child)); - } - else - { - polynomials.emplace_back(Polynomial::parsePolynomial(child)); - } - } - - if(!monomials.empty()){ - Monomial::sort(monomials); - Monomial::combineAdjacentMonomials(monomials); - polynomials.emplace_back(Polynomial::mkPolynomial(monomials)); - } - if (!rational.isZero()) - { - polynomials.emplace_back( - Polynomial::mkPolynomial(Constant::mkConstant(rational))); - } - - Polynomial poly = Polynomial::sumPolynomials(polynomials); - - if (isZero(ran)) - { - return RewriteResponse(REWRITE_DONE, poly.getNode()); - } - if (poly.containsConstant()) - { - ran += RealAlgebraicNumber(poly.getHead().getConstant().getValue()); - if (!poly.isConstant()) - { - poly = poly.getTail(); - } - } + std::vector children; + expr::algorithm::flatten(t, children); - auto* nm = NodeManager::currentNM(); - if (poly.isConstant()) + rewriter::Sum sum; + for (const auto& child : children) { - return RewriteResponse(REWRITE_DONE, nm->mkRealAlgebraicNumber(ran)); + rewriter::addToSum(sum, child); } - return RewriteResponse( - REWRITE_DONE, - nm->mkNode(Kind::ADD, nm->mkRealAlgebraicNumber(ran), poly.getNode())); + return RewriteResponse(REWRITE_DONE, rewriter::collectSum(sum)); } RewriteResponse ArithRewriter::preRewriteMult(TNode node) diff --git a/src/theory/arith/arith_rewriter.h b/src/theory/arith/arith_rewriter.h index a1079e60f..7cf9a4f5b 100644 --- a/src/theory/arith/arith_rewriter.h +++ b/src/theory/arith/arith_rewriter.h @@ -57,6 +57,10 @@ class ArithRewriter : public TheoryRewriter static RewriteResponse rewriteNeg(TNode t, bool pre); /** rewrite binary minus */ static RewriteResponse rewriteSub(TNode t); + /** preRewrite addition */ + static RewriteResponse preRewritePlus(TNode t); + /** postRewrite addition */ + static RewriteResponse postRewritePlus(TNode t); static RewriteResponse rewriteDiv(TNode t, bool pre); static RewriteResponse rewriteAbs(TNode t); static RewriteResponse rewriteIntsDivMod(TNode t, bool pre); @@ -64,9 +68,6 @@ class ArithRewriter : public TheoryRewriter /** Entry for applications of to_int and is_int */ static RewriteResponse rewriteExtIntegerOp(TNode t); - static RewriteResponse preRewritePlus(TNode t); - static RewriteResponse postRewritePlus(TNode t); - static RewriteResponse preRewriteMult(TNode t); static RewriteResponse postRewriteMult(TNode t); diff --git a/src/theory/arith/rewriter/addition.cpp b/src/theory/arith/rewriter/addition.cpp index 16e288920..e8743a43a 100644 --- a/src/theory/arith/rewriter/addition.cpp +++ b/src/theory/arith/rewriter/addition.cpp @@ -162,6 +162,7 @@ void addToSum(Sum& sum, TNode n, bool negate) Node collectSum(const Sum& sum) { if (sum.empty()) return mkConst(Rational(0)); + Trace("arith-rewriter") << "Collecting sum " << sum << std::endl; // construct the sum as nodes. NodeBuilder nb(Kind::ADD); for (const auto& s : sum) diff --git a/src/theory/arith/rewriter/node_utils.cpp b/src/theory/arith/rewriter/node_utils.cpp index 34c3849a0..3d1856e0b 100644 --- a/src/theory/arith/rewriter/node_utils.cpp +++ b/src/theory/arith/rewriter/node_utils.cpp @@ -82,7 +82,14 @@ Node mkMultTerm(const RealAlgebraicNumber& multiplicity, TNode monomial) } std::vector prod; prod.emplace_back(mkConst(multiplicity)); - prod.insert(prod.end(), monomial.begin(), monomial.end()); + if (monomial.getKind() == Kind::MULT || monomial.getKind() == Kind::NONLINEAR_MULT) + { + prod.insert(prod.end(), monomial.begin(), monomial.end()); + } + else + { + prod.emplace_back(monomial); + } Assert(prod.size() >= 2); return NodeManager::currentNM()->mkNode(Kind::NONLINEAR_MULT, prod); } -- 2.30.2