From d603925de77bc465fe6bfe5097e848224337acf9 Mon Sep 17 00:00:00 2001 From: Gereon Kremer Date: Wed, 12 Jan 2022 14:19:31 -0800 Subject: [PATCH] Refactor rewriteMinus (#7932) Refactors rewriteMinus to be generally simpler and explicitly rely on rewriting addition. Note that rewriteMinus would almost never be called in post rewrite anyway, as it would rewrite to addition in pre rewrite. --- src/theory/arith/arith_rewriter.cpp | 230 +++++++++++++--------------- src/theory/arith/arith_rewriter.h | 2 +- 2 files changed, 108 insertions(+), 124 deletions(-) diff --git a/src/theory/arith/arith_rewriter.cpp b/src/theory/arith/arith_rewriter.cpp index 4e79515bd..2c3fcdf48 100644 --- a/src/theory/arith/arith_rewriter.cpp +++ b/src/theory/arith/arith_rewriter.cpp @@ -60,25 +60,21 @@ RewriteResponse ArithRewriter::rewriteVariable(TNode t){ return RewriteResponse(REWRITE_DONE, t); } -RewriteResponse ArithRewriter::rewriteMinus(TNode t, bool pre){ +RewriteResponse ArithRewriter::rewriteMinus(TNode t) +{ Assert(t.getKind() == kind::MINUS); + Assert(t.getNumChildren() == 2); - if(pre){ - if(t[0] == t[1]){ - Rational zero(0); - Node zeroNode = - NodeManager::currentNM()->mkConstRealOrInt(t.getType(), zero); - return RewriteResponse(REWRITE_DONE, zeroNode); - }else{ - Node noMinus = makeSubtractionNode(t[0],t[1]); - return RewriteResponse(REWRITE_DONE, noMinus); - } - }else{ - Polynomial minuend = Polynomial::parsePolynomial(t[0]); - Polynomial subtrahend = Polynomial::parsePolynomial(t[1]); - Polynomial diff = minuend - subtrahend; - return RewriteResponse(REWRITE_DONE, diff.getNode()); + auto* nm = NodeManager::currentNM(); + + if (t[0] == t[1]) + { + return RewriteResponse(REWRITE_DONE, + nm->mkConstRealOrInt(t.getType(), Rational(0))); } + return RewriteResponse( + REWRITE_AGAIN_FULL, + nm->mkNode(Kind::PLUS, t[0], makeUnaryMinusNode(t[1]))); } RewriteResponse ArithRewriter::rewriteUMinus(TNode t, bool pre){ @@ -106,60 +102,56 @@ RewriteResponse ArithRewriter::preRewriteTerm(TNode t){ return rewriteVariable(t); }else{ switch(Kind k = t.getKind()){ - case kind::MINUS: - return rewriteMinus(t, true); - case kind::UMINUS: - return rewriteUMinus(t, true); - case kind::DIVISION: - case kind::DIVISION_TOTAL: - return rewriteDiv(t,true); - case kind::PLUS: - return preRewritePlus(t); - case kind::MULT: - case kind::NONLINEAR_MULT: return preRewriteMult(t); - case kind::IAND: return RewriteResponse(REWRITE_DONE, t); - case kind::POW2: return RewriteResponse(REWRITE_DONE, t); - case kind::EXPONENTIAL: - case kind::SINE: - case kind::COSINE: - case kind::TANGENT: - case kind::COSECANT: - case kind::SECANT: - case kind::COTANGENT: - case kind::ARCSINE: - case kind::ARCCOSINE: - case kind::ARCTANGENT: - case kind::ARCCOSECANT: - case kind::ARCSECANT: - case kind::ARCCOTANGENT: - case kind::SQRT: return preRewriteTranscendental(t); - case kind::INTS_DIVISION: - case kind::INTS_MODULUS: return rewriteIntsDivMod(t, true); - case kind::INTS_DIVISION_TOTAL: - case kind::INTS_MODULUS_TOTAL: - return rewriteIntsDivModTotal(t,true); - case kind::ABS: - if(t[0].isConst()) { - const Rational& rat = t[0].getConst(); - if(rat >= 0) { - return RewriteResponse(REWRITE_DONE, t[0]); - } else { - return RewriteResponse( - REWRITE_DONE, - NodeManager::currentNM()->mkConstRealOrInt(t[0].getType(), -rat)); + case kind::MINUS: return rewriteMinus(t); + case kind::UMINUS: return rewriteUMinus(t, true); + case kind::DIVISION: + case kind::DIVISION_TOTAL: return rewriteDiv(t, true); + case kind::PLUS: return preRewritePlus(t); + case kind::MULT: + case kind::NONLINEAR_MULT: return preRewriteMult(t); + case kind::IAND: return RewriteResponse(REWRITE_DONE, t); + case kind::POW2: return RewriteResponse(REWRITE_DONE, t); + case kind::EXPONENTIAL: + case kind::SINE: + case kind::COSINE: + case kind::TANGENT: + case kind::COSECANT: + case kind::SECANT: + case kind::COTANGENT: + case kind::ARCSINE: + case kind::ARCCOSINE: + case kind::ARCTANGENT: + case kind::ARCCOSECANT: + case kind::ARCSECANT: + case kind::ARCCOTANGENT: + case kind::SQRT: return preRewriteTranscendental(t); + case kind::INTS_DIVISION: + case kind::INTS_MODULUS: return rewriteIntsDivMod(t, true); + case kind::INTS_DIVISION_TOTAL: + case kind::INTS_MODULUS_TOTAL: return rewriteIntsDivModTotal(t, true); + case kind::ABS: + if (t[0].isConst()) + { + const Rational& rat = t[0].getConst(); + if (rat >= 0) + { + return RewriteResponse(REWRITE_DONE, t[0]); + } + else + { + return RewriteResponse(REWRITE_DONE, + NodeManager::currentNM()->mkConstRealOrInt( + t[0].getType(), -rat)); + } } - } - return RewriteResponse(REWRITE_DONE, t); - case kind::IS_INTEGER: - case kind::TO_INTEGER: - return RewriteResponse(REWRITE_DONE, t); - case kind::TO_REAL: - case kind::CAST_TO_REAL: return RewriteResponse(REWRITE_DONE, t[0]); - case kind::POW: - return RewriteResponse(REWRITE_DONE, t); - case kind::PI: - return RewriteResponse(REWRITE_DONE, t); - default: Unhandled() << k; + return RewriteResponse(REWRITE_DONE, t); + case kind::IS_INTEGER: + case kind::TO_INTEGER: return RewriteResponse(REWRITE_DONE, t); + case kind::TO_REAL: + case kind::CAST_TO_REAL: return RewriteResponse(REWRITE_DONE, t[0]); + case kind::POW: return RewriteResponse(REWRITE_DONE, t); + case kind::PI: return RewriteResponse(REWRITE_DONE, t); + default: Unhandled() << k; } } } @@ -172,54 +164,53 @@ RewriteResponse ArithRewriter::postRewriteTerm(TNode t){ }else{ Trace("arith-rewriter") << "postRewriteTerm: " << t << std::endl; switch(t.getKind()){ - case kind::MINUS: - return rewriteMinus(t, false); - case kind::UMINUS: - return rewriteUMinus(t, false); - case kind::DIVISION: - case kind::DIVISION_TOTAL: - return rewriteDiv(t, false); - case kind::PLUS: - return postRewritePlus(t); - case kind::MULT: - case kind::NONLINEAR_MULT: return postRewriteMult(t); - case kind::IAND: return postRewriteIAnd(t); - case kind::POW2: return postRewritePow2(t); - case kind::EXPONENTIAL: - case kind::SINE: - case kind::COSINE: - case kind::TANGENT: - case kind::COSECANT: - case kind::SECANT: - case kind::COTANGENT: - case kind::ARCSINE: - case kind::ARCCOSINE: - case kind::ARCTANGENT: - case kind::ARCCOSECANT: - case kind::ARCSECANT: - case kind::ARCCOTANGENT: - case kind::SQRT: return postRewriteTranscendental(t); - case kind::INTS_DIVISION: - case kind::INTS_MODULUS: return rewriteIntsDivMod(t, false); - case kind::INTS_DIVISION_TOTAL: - case kind::INTS_MODULUS_TOTAL: - return rewriteIntsDivModTotal(t, false); - case kind::ABS: - if(t[0].isConst()) { - const Rational& rat = t[0].getConst(); - if(rat >= 0) { - return RewriteResponse(REWRITE_DONE, t[0]); - } else { - return RewriteResponse( - REWRITE_DONE, - NodeManager::currentNM()->mkConstRealOrInt(t[0].getType(), -rat)); + case kind::MINUS: return rewriteMinus(t); + case kind::UMINUS: return rewriteUMinus(t, false); + case kind::DIVISION: + case kind::DIVISION_TOTAL: return rewriteDiv(t, false); + case kind::PLUS: return postRewritePlus(t); + case kind::MULT: + case kind::NONLINEAR_MULT: return postRewriteMult(t); + case kind::IAND: return postRewriteIAnd(t); + case kind::POW2: return postRewritePow2(t); + case kind::EXPONENTIAL: + case kind::SINE: + case kind::COSINE: + case kind::TANGENT: + case kind::COSECANT: + case kind::SECANT: + case kind::COTANGENT: + case kind::ARCSINE: + case kind::ARCCOSINE: + case kind::ARCTANGENT: + case kind::ARCCOSECANT: + case kind::ARCSECANT: + case kind::ARCCOTANGENT: + case kind::SQRT: return postRewriteTranscendental(t); + case kind::INTS_DIVISION: + case kind::INTS_MODULUS: return rewriteIntsDivMod(t, false); + case kind::INTS_DIVISION_TOTAL: + case kind::INTS_MODULUS_TOTAL: return rewriteIntsDivModTotal(t, false); + case kind::ABS: + if (t[0].isConst()) + { + const Rational& rat = t[0].getConst(); + if (rat >= 0) + { + return RewriteResponse(REWRITE_DONE, t[0]); + } + else + { + return RewriteResponse(REWRITE_DONE, + NodeManager::currentNM()->mkConstRealOrInt( + t[0].getType(), -rat)); + } } - } - return RewriteResponse(REWRITE_DONE, t); - case kind::TO_REAL: - case kind::CAST_TO_REAL: return RewriteResponse(REWRITE_DONE, t[0]); - case kind::TO_INTEGER: return rewriteExtIntegerOp(t); - case kind::POW: + return RewriteResponse(REWRITE_DONE, t); + case kind::TO_REAL: + case kind::CAST_TO_REAL: return RewriteResponse(REWRITE_DONE, t[0]); + case kind::TO_INTEGER: return rewriteExtIntegerOp(t); + case kind::POW: { if (t[1].isConst()) { @@ -757,13 +748,6 @@ Node ArithRewriter::makeUnaryMinusNode(TNode n){ return nm->mkNode(kind::MULT, nm->mkConstRealOrInt(n.getType(), qNegOne), n); } -Node ArithRewriter::makeSubtractionNode(TNode l, TNode r){ - Node negR = makeUnaryMinusNode(r); - Node diff = NodeManager::currentNM()->mkNode(kind::PLUS, l, negR); - - return diff; -} - RewriteResponse ArithRewriter::rewriteDiv(TNode t, bool pre){ Assert(t.getKind() == kind::DIVISION_TOTAL || t.getKind() == kind::DIVISION); diff --git a/src/theory/arith/arith_rewriter.h b/src/theory/arith/arith_rewriter.h index 1f861b08a..5fcb628b9 100644 --- a/src/theory/arith/arith_rewriter.h +++ b/src/theory/arith/arith_rewriter.h @@ -51,7 +51,7 @@ class ArithRewriter : public TheoryRewriter static RewriteResponse rewriteVariable(TNode t); static RewriteResponse rewriteConstant(TNode t); - static RewriteResponse rewriteMinus(TNode t, bool pre); + static RewriteResponse rewriteMinus(TNode t); static RewriteResponse rewriteUMinus(TNode t, bool pre); static RewriteResponse rewriteDiv(TNode t, bool pre); static RewriteResponse rewriteIntsDivMod(TNode t, bool pre); -- 2.30.2