From: Gereon Kremer Date: Wed, 2 Mar 2022 16:37:03 +0000 (+0100) Subject: Refactor rewriting of arithmetic division (#8195) X-Git-Tag: cvc5-1.0.0~346 X-Git-Url: https://git.libre-soc.org/?a=commitdiff_plain;h=bc37256a1a23ade3ebecc74247d0e69f02abc844;p=cvc5.git Refactor rewriting of arithmetic division (#8195) This PR does minor refactoring of how we rewrite division. Also, it moves some functions around to reconcile the order in header and source file. --- diff --git a/src/theory/arith/arith_rewriter.cpp b/src/theory/arith/arith_rewriter.cpp index 863eaee1d..91652b7ce 100644 --- a/src/theory/arith/arith_rewriter.cpp +++ b/src/theory/arith/arith_rewriter.cpp @@ -46,27 +46,6 @@ namespace cvc5 { namespace theory { namespace arith { -namespace { - -/** - * Check whether the parent has a child that is a constant zero. - * If so, return this child. Otherwise, return std::nullopt. - */ -template -std::optional getZeroChild(const Iterable& parent) -{ - for (const auto& node : parent) - { - if (node.isConst() && node.template getConst().isZero()) - { - return node; - } - } - return std::nullopt; -} - -} // namespace - ArithRewriter::ArithRewriter(OperatorElim& oe) : d_opElim(oe) {} RewriteResponse ArithRewriter::preRewrite(TNode t) @@ -225,46 +204,6 @@ RewriteResponse ArithRewriter::postRewriteAtom(TNode atom) } } -RewriteResponse ArithRewriter::rewriteNeg(TNode t, bool pre) -{ - Assert(t.getKind() == kind::NEG); - - if (t[0].isConst()) - { - Rational neg = -(t[0].getConst()); - return RewriteResponse(REWRITE_DONE, rewriter::mkConst(neg)); - } - if (rewriter::isRAN(t[0])) - { - return RewriteResponse(REWRITE_DONE, - rewriter::mkConst(-rewriter::getRAN(t[0]))); - } - - auto* nm = NodeManager::currentNM(); - Node noUminus = nm->mkNode(kind::MULT, rewriter::mkConst(Integer(-1)), t[0]); - if (pre) - return RewriteResponse(REWRITE_DONE, noUminus); - else - return RewriteResponse(REWRITE_AGAIN, noUminus); -} - -RewriteResponse ArithRewriter::rewriteSub(TNode t) -{ - Assert(t.getKind() == kind::SUB); - Assert(t.getNumChildren() == 2); - - if (t[0] == t[1]) - { - return RewriteResponse(REWRITE_DONE, rewriter::mkConst(Integer(0))); - } - auto* nm = NodeManager::currentNM(); - return RewriteResponse( - REWRITE_AGAIN_FULL, - nm->mkNode(Kind::ADD, - t[0], - nm->mkNode(kind::MULT, rewriter::mkConst(Integer(-1)), t[1]))); -} - RewriteResponse ArithRewriter::preRewriteTerm(TNode t){ if(t.isConst()){ return RewriteResponse(REWRITE_DONE, t); @@ -317,7 +256,9 @@ RewriteResponse ArithRewriter::postRewriteTerm(TNode t){ return RewriteResponse(REWRITE_DONE, t); }else if(t.isVar()){ return rewriteVariable(t); - }else{ + } + else + { Trace("arith-rewriter") << "postRewriteTerm: " << t << std::endl; switch(t.getKind()){ case kind::REAL_ALGEBRAIC_NUMBER: return rewriteRAN(t); @@ -424,6 +365,50 @@ RewriteResponse ArithRewriter::rewriteVariable(TNode t) return RewriteResponse(REWRITE_DONE, t); } +RewriteResponse ArithRewriter::rewriteNeg(TNode t, bool pre) +{ + Assert(t.getKind() == kind::NEG); + + if (t[0].isConst()) + { + Rational neg = -(t[0].getConst()); + return RewriteResponse(REWRITE_DONE, rewriter::mkConst(neg)); + } + if (rewriter::isRAN(t[0])) + { + return RewriteResponse(REWRITE_DONE, + rewriter::mkConst(-rewriter::getRAN(t[0]))); + } + + auto* nm = NodeManager::currentNM(); + Node noUminus = nm->mkNode(kind::MULT, rewriter::mkConst(Integer(-1)), t[0]); + if (pre) + { + return RewriteResponse(REWRITE_DONE, noUminus); + } + else + { + return RewriteResponse(REWRITE_AGAIN, noUminus); + } +} + +RewriteResponse ArithRewriter::rewriteSub(TNode t) +{ + Assert(t.getKind() == kind::SUB); + Assert(t.getNumChildren() == 2); + + if (t[0] == t[1]) + { + return RewriteResponse(REWRITE_DONE, rewriter::mkConst(Integer(0))); + } + auto* nm = NodeManager::currentNM(); + return RewriteResponse( + REWRITE_AGAIN_FULL, + nm->mkNode(Kind::ADD, + t[0], + nm->mkNode(kind::MULT, rewriter::mkConst(Integer(-1)), t[1]))); +} + RewriteResponse ArithRewriter::preRewritePlus(TNode t) { Assert(t.getKind() == kind::ADD); @@ -537,13 +522,11 @@ RewriteResponse ArithRewriter::rewriteDiv(TNode t, bool pre) const Rational& num = left.getConst(); return RewriteResponse(REWRITE_DONE, nm->mkConstReal(num / den)); } - if (left.getKind() == Kind::REAL_ALGEBRAIC_NUMBER) + if (rewriter::isRAN(left)) { - const RealAlgebraicNumber& num = - left.getOperator().getConst(); return RewriteResponse( REWRITE_DONE, - nm->mkRealAlgebraicNumber(num / RealAlgebraicNumber(den))); + nm->mkRealAlgebraicNumber(rewriter::getRAN(left) / den)); } Node result = nm->mkConstReal(den.inverse()); @@ -557,27 +540,22 @@ RewriteResponse ArithRewriter::rewriteDiv(TNode t, bool pre) return RewriteResponse(REWRITE_AGAIN, mult); } } - if (right.getKind() == Kind::REAL_ALGEBRAIC_NUMBER) + if (rewriter::isRAN(right)) { - NodeManager* nm = NodeManager::currentNM(); - const RealAlgebraicNumber& den = - right.getOperator().getConst(); + const RealAlgebraicNumber& den = rewriter::getRAN(right); + if (left.isConst()) { - const Rational& num = left.getConst(); return RewriteResponse( - REWRITE_DONE, - nm->mkRealAlgebraicNumber(RealAlgebraicNumber(num) / den)); + REWRITE_DONE, rewriter::mkConst(left.getConst() / den)); } - if (left.getKind() == Kind::REAL_ALGEBRAIC_NUMBER) + if (rewriter::isRAN(left)) { - const RealAlgebraicNumber& num = - left.getOperator().getConst(); return RewriteResponse(REWRITE_DONE, - nm->mkRealAlgebraicNumber(num / den)); + rewriter::mkConst(rewriter::getRAN(left) / den)); } - Node result = nm->mkRealAlgebraicNumber(inverse(den)); + Node result = rewriter::mkConst(inverse(den)); Node mult = NodeManager::currentNM()->mkNode(kind::MULT, left, result); if (pre) { @@ -607,10 +585,9 @@ RewriteResponse ArithRewriter::rewriteAbs(TNode t) REWRITE_DONE, NodeManager::currentNM()->mkConstRealOrInt(t[0].getType(), -rat)); } - if (t[0].getKind() == Kind::REAL_ALGEBRAIC_NUMBER) + if (rewriter::isRAN(t[0])) { - const RealAlgebraicNumber& ran = - t[0].getOperator().getConst(); + const RealAlgebraicNumber& ran = rewriter::getRAN(t[0]); if (ran >= RealAlgebraicNumber()) { return RewriteResponse(REWRITE_DONE, t[0]); @@ -646,37 +623,6 @@ RewriteResponse ArithRewriter::rewriteIntsDivMod(TNode t, bool pre) return RewriteResponse(REWRITE_DONE, t); } -RewriteResponse ArithRewriter::rewriteExtIntegerOp(TNode t) -{ - Assert(t.getKind() == kind::TO_INTEGER || t.getKind() == kind::IS_INTEGER); - bool isPred = t.getKind() == kind::IS_INTEGER; - NodeManager* nm = NodeManager::currentNM(); - if (t[0].isConst()) - { - Node ret; - if (isPred) - { - ret = nm->mkConst(t[0].getConst().isIntegral()); - } - else - { - ret = nm->mkConstInt(Rational(t[0].getConst().floor())); - } - return returnRewrite(t, ret, Rewrite::INT_EXT_CONST); - } - if (t[0].getType().isInteger()) - { - Node ret = isPred ? nm->mkConst(true) : Node(t[0]); - return returnRewrite(t, ret, Rewrite::INT_EXT_INT); - } - if (t[0].getKind() == kind::PI) - { - Node ret = isPred ? nm->mkConst(false) : nm->mkConstReal(Rational(3)); - return returnRewrite(t, ret, Rewrite::INT_EXT_PI); - } - return RewriteResponse(REWRITE_DONE, t); -} - RewriteResponse ArithRewriter::rewriteIntsDivModTotal(TNode t, bool pre) { if (pre) @@ -785,6 +731,37 @@ RewriteResponse ArithRewriter::rewriteIntsDivModTotal(TNode t, bool pre) return RewriteResponse(REWRITE_DONE, t); } +RewriteResponse ArithRewriter::rewriteExtIntegerOp(TNode t) +{ + Assert(t.getKind() == kind::TO_INTEGER || t.getKind() == kind::IS_INTEGER); + bool isPred = t.getKind() == kind::IS_INTEGER; + NodeManager* nm = NodeManager::currentNM(); + if (t[0].isConst()) + { + Node ret; + if (isPred) + { + ret = nm->mkConst(t[0].getConst().isIntegral()); + } + else + { + ret = nm->mkConstInt(Rational(t[0].getConst().floor())); + } + return returnRewrite(t, ret, Rewrite::INT_EXT_CONST); + } + if (t[0].getType().isInteger()) + { + Node ret = isPred ? nm->mkConst(true) : Node(t[0]); + return returnRewrite(t, ret, Rewrite::INT_EXT_INT); + } + if (t[0].getKind() == kind::PI) + { + Node ret = isPred ? nm->mkConst(false) : nm->mkConstReal(Rational(3)); + return returnRewrite(t, ret, Rewrite::INT_EXT_PI); + } + return RewriteResponse(REWRITE_DONE, t); +} + RewriteResponse ArithRewriter::postRewriteIAnd(TNode t) { Assert(t.getKind() == kind::IAND); diff --git a/src/theory/arith/arith_rewriter.h b/src/theory/arith/arith_rewriter.h index 3ffa90015..6f2e06ef1 100644 --- a/src/theory/arith/arith_rewriter.h +++ b/src/theory/arith/arith_rewriter.h @@ -48,7 +48,9 @@ class ArithRewriter : public TheoryRewriter /** postRewrite for atoms */ static RewriteResponse postRewriteAtom(TNode t); + /** preRewrite for terms */ static RewriteResponse preRewriteTerm(TNode t); + /** postRewrite for terms */ static RewriteResponse postRewriteTerm(TNode t); /** rewrite real algebraic numbers */ @@ -64,16 +66,22 @@ class ArithRewriter : public TheoryRewriter static RewriteResponse preRewritePlus(TNode t); /** postRewrite addition */ static RewriteResponse postRewritePlus(TNode t); + /** preRewrite multiplication */ + static RewriteResponse preRewriteMult(TNode t); + /** postRewrite multiplication */ + static RewriteResponse postRewriteMult(TNode t); + + /** rewrite division */ static RewriteResponse rewriteDiv(TNode t, bool pre); + /** rewrite absolute */ static RewriteResponse rewriteAbs(TNode t); + /** rewrite integer division and modulus */ static RewriteResponse rewriteIntsDivMod(TNode t, bool pre); + /** rewrite integer total division and total modulus */ static RewriteResponse rewriteIntsDivModTotal(TNode t, bool pre); - /** Entry for applications of to_int and is_int */ + /** rewrite to_int and is_int */ static RewriteResponse rewriteExtIntegerOp(TNode t); - static RewriteResponse preRewriteMult(TNode t); - static RewriteResponse postRewriteMult(TNode t); - /** postRewrite IAND */ static RewriteResponse postRewriteIAnd(TNode t); /** postRewrite POW2 */