From 5b32209ea2dd947c4cea384b2166e1de401cb929 Mon Sep 17 00:00:00 2001 From: Gereon Kremer Date: Fri, 25 Feb 2022 21:56:21 +0100 Subject: [PATCH] Refactor rewriting of arithmetic negation and subtraction (#8170) Slightly refactor negation and subtraction, get rid of utility functions. --- src/theory/arith/arith_rewriter.cpp | 60 +++++++++++++---------------- src/theory/arith/arith_rewriter.h | 7 ++-- 2 files changed, 29 insertions(+), 38 deletions(-) diff --git a/src/theory/arith/arith_rewriter.cpp b/src/theory/arith/arith_rewriter.cpp index f8d8594eb..aae34abb6 100644 --- a/src/theory/arith/arith_rewriter.cpp +++ b/src/theory/arith/arith_rewriter.cpp @@ -270,28 +270,13 @@ RewriteResponse ArithRewriter::rewriteRAN(TNode t) return RewriteResponse(REWRITE_DONE, t); } -RewriteResponse ArithRewriter::rewriteVariable(TNode t){ +RewriteResponse ArithRewriter::rewriteVariable(TNode t) +{ Assert(t.isVar()); return RewriteResponse(REWRITE_DONE, t); } -RewriteResponse ArithRewriter::rewriteSub(TNode t) -{ - Assert(t.getKind() == kind::SUB); - Assert(t.getNumChildren() == 2); - - auto* nm = NodeManager::currentNM(); - - if (t[0] == t[1]) - { - return RewriteResponse(REWRITE_DONE, - nm->mkConstRealOrInt(t.getType(), Rational(0))); - } - return RewriteResponse(REWRITE_AGAIN_FULL, - nm->mkNode(Kind::ADD, t[0], makeUnaryMinusNode(t[1]))); -} - RewriteResponse ArithRewriter::rewriteNeg(TNode t, bool pre) { Assert(t.getKind() == kind::NEG); @@ -299,25 +284,39 @@ RewriteResponse ArithRewriter::rewriteNeg(TNode t, bool pre) if (t[0].isConst()) { Rational neg = -(t[0].getConst()); - NodeManager* nm = NodeManager::currentNM(); - return RewriteResponse(REWRITE_DONE, - nm->mkConstRealOrInt(t[0].getType(), neg)); + return RewriteResponse(REWRITE_DONE, rewriter::mkConst(neg)); } - if (t[0].getKind() == Kind::REAL_ALGEBRAIC_NUMBER) + if (rewriter::isRAN(t[0])) { - const RealAlgebraicNumber& r = - t[0].getOperator().getConst(); - NodeManager* nm = NodeManager::currentNM(); - return RewriteResponse(REWRITE_DONE, nm->mkRealAlgebraicNumber(-r)); + return RewriteResponse(REWRITE_DONE, + rewriter::mkConst(-rewriter::getRAN(t[0]))); } - Node noUminus = makeUnaryMinusNode(t[0]); - if(pre) + auto* nm = NodeManager::currentNM(); + Node noUminus = nm->mkNode(kind::MULT, rewriter::mkConst(Integer(-1)), t[0]); + if (pre) return RewriteResponse(REWRITE_DONE, noUminus); else return RewriteResponse(REWRITE_AGAIN, noUminus); } +RewriteResponse ArithRewriter::rewriteSub(TNode t) +{ + Assert(t.getKind() == kind::SUB); + Assert(t.getNumChildren() == 2); + + if (t[0] == t[1]) + { + return RewriteResponse(REWRITE_DONE, rewriter::mkConst(Integer(0))); + } + auto* nm = NodeManager::currentNM(); + return RewriteResponse( + REWRITE_AGAIN_FULL, + nm->mkNode(Kind::ADD, + t[0], + nm->mkNode(kind::MULT, rewriter::mkConst(Integer(-1)), t[1]))); +} + RewriteResponse ArithRewriter::preRewriteTerm(TNode t){ if(t.isConst()){ return rewriteConstant(t); @@ -602,13 +601,6 @@ RewriteResponse ArithRewriter::postRewriteMult(TNode t){ rewriter::mkMultTerm(ran, std::move(leafs))); } -Node ArithRewriter::makeUnaryMinusNode(TNode n) -{ - NodeManager* nm = NodeManager::currentNM(); - Rational qNegOne(-1); - return nm->mkNode(kind::MULT, nm->mkConstRealOrInt(n.getType(), qNegOne), n); -} - RewriteResponse ArithRewriter::rewriteDiv(TNode t, bool pre) { Assert(t.getKind() == kind::DIVISION_TOTAL || t.getKind() == kind::DIVISION); diff --git a/src/theory/arith/arith_rewriter.h b/src/theory/arith/arith_rewriter.h index 9e2a15c77..a1079e60f 100644 --- a/src/theory/arith/arith_rewriter.h +++ b/src/theory/arith/arith_rewriter.h @@ -46,9 +46,6 @@ class ArithRewriter : public TheoryRewriter static RewriteResponse preRewriteAtom(TNode t); static RewriteResponse postRewriteAtom(TNode t); - static Node makeSubtractionNode(TNode l, TNode r); - static Node makeUnaryMinusNode(TNode n); - static RewriteResponse preRewriteTerm(TNode t); static RewriteResponse postRewriteTerm(TNode t); @@ -56,8 +53,10 @@ class ArithRewriter : public TheoryRewriter static RewriteResponse rewriteRAN(TNode t); static RewriteResponse rewriteVariable(TNode t); - static RewriteResponse rewriteSub(TNode t); + /** rewrite unary minus */ static RewriteResponse rewriteNeg(TNode t, bool pre); + /** rewrite binary minus */ + static RewriteResponse rewriteSub(TNode t); static RewriteResponse rewriteDiv(TNode t, bool pre); static RewriteResponse rewriteAbs(TNode t); static RewriteResponse rewriteIntsDivMod(TNode t, bool pre); -- 2.30.2