From: Gereon Kremer Date: Wed, 12 Jan 2022 22:19:31 +0000 (-0800) Subject: Refactor rewriteMinus (#7932) X-Git-Tag: cvc5-1.0.0~556 X-Git-Url: https://git.libre-soc.org/?a=commitdiff_plain;h=d603925de77bc465fe6bfe5097e848224337acf9;p=cvc5.git Refactor rewriteMinus (#7932) Refactors rewriteMinus to be generally simpler and explicitly rely on rewriting addition. Note that rewriteMinus would almost never be called in post rewrite anyway, as it would rewrite to addition in pre rewrite. --- diff --git a/src/theory/arith/arith_rewriter.cpp b/src/theory/arith/arith_rewriter.cpp index 4e79515bd..2c3fcdf48 100644 --- a/src/theory/arith/arith_rewriter.cpp +++ b/src/theory/arith/arith_rewriter.cpp @@ -60,25 +60,21 @@ RewriteResponse ArithRewriter::rewriteVariable(TNode t){ return RewriteResponse(REWRITE_DONE, t); } -RewriteResponse ArithRewriter::rewriteMinus(TNode t, bool pre){ +RewriteResponse ArithRewriter::rewriteMinus(TNode t) +{ Assert(t.getKind() == kind::MINUS); + Assert(t.getNumChildren() == 2); - if(pre){ - if(t[0] == t[1]){ - Rational zero(0); - Node zeroNode = - NodeManager::currentNM()->mkConstRealOrInt(t.getType(), zero); - return RewriteResponse(REWRITE_DONE, zeroNode); - }else{ - Node noMinus = makeSubtractionNode(t[0],t[1]); - return RewriteResponse(REWRITE_DONE, noMinus); - } - }else{ - Polynomial minuend = Polynomial::parsePolynomial(t[0]); - Polynomial subtrahend = Polynomial::parsePolynomial(t[1]); - Polynomial diff = minuend - subtrahend; - return RewriteResponse(REWRITE_DONE, diff.getNode()); + auto* nm = NodeManager::currentNM(); + + if (t[0] == t[1]) + { + return RewriteResponse(REWRITE_DONE, + nm->mkConstRealOrInt(t.getType(), Rational(0))); } + return RewriteResponse( + REWRITE_AGAIN_FULL, + nm->mkNode(Kind::PLUS, t[0], makeUnaryMinusNode(t[1]))); } RewriteResponse ArithRewriter::rewriteUMinus(TNode t, bool pre){ @@ -106,60 +102,56 @@ RewriteResponse ArithRewriter::preRewriteTerm(TNode t){ return rewriteVariable(t); }else{ switch(Kind k = t.getKind()){ - case kind::MINUS: - return rewriteMinus(t, true); - case kind::UMINUS: - return rewriteUMinus(t, true); - case kind::DIVISION: - case kind::DIVISION_TOTAL: - return rewriteDiv(t,true); - case kind::PLUS: - return preRewritePlus(t); - case kind::MULT: - case kind::NONLINEAR_MULT: return preRewriteMult(t); - case kind::IAND: return RewriteResponse(REWRITE_DONE, t); - case kind::POW2: return RewriteResponse(REWRITE_DONE, t); - case kind::EXPONENTIAL: - case kind::SINE: - case kind::COSINE: - case kind::TANGENT: - case kind::COSECANT: - case kind::SECANT: - case kind::COTANGENT: - case kind::ARCSINE: - case kind::ARCCOSINE: - case kind::ARCTANGENT: - case kind::ARCCOSECANT: - case kind::ARCSECANT: - case kind::ARCCOTANGENT: - case kind::SQRT: return preRewriteTranscendental(t); - case kind::INTS_DIVISION: - case kind::INTS_MODULUS: return rewriteIntsDivMod(t, true); - case kind::INTS_DIVISION_TOTAL: - case kind::INTS_MODULUS_TOTAL: - return rewriteIntsDivModTotal(t,true); - case kind::ABS: - if(t[0].isConst()) { - const Rational& rat = t[0].getConst(); - if(rat >= 0) { - return RewriteResponse(REWRITE_DONE, t[0]); - } else { - return RewriteResponse( - REWRITE_DONE, - NodeManager::currentNM()->mkConstRealOrInt(t[0].getType(), -rat)); + case kind::MINUS: return rewriteMinus(t); + case kind::UMINUS: return rewriteUMinus(t, true); + case kind::DIVISION: + case kind::DIVISION_TOTAL: return rewriteDiv(t, true); + case kind::PLUS: return preRewritePlus(t); + case kind::MULT: + case kind::NONLINEAR_MULT: return preRewriteMult(t); + case kind::IAND: return RewriteResponse(REWRITE_DONE, t); + case kind::POW2: return RewriteResponse(REWRITE_DONE, t); + case kind::EXPONENTIAL: + case kind::SINE: + case kind::COSINE: + case kind::TANGENT: + case kind::COSECANT: + case kind::SECANT: + case kind::COTANGENT: + case kind::ARCSINE: + case kind::ARCCOSINE: + case kind::ARCTANGENT: + case kind::ARCCOSECANT: + case kind::ARCSECANT: + case kind::ARCCOTANGENT: + case kind::SQRT: return preRewriteTranscendental(t); + case kind::INTS_DIVISION: + case kind::INTS_MODULUS: return rewriteIntsDivMod(t, true); + case kind::INTS_DIVISION_TOTAL: + case kind::INTS_MODULUS_TOTAL: return rewriteIntsDivModTotal(t, true); + case kind::ABS: + if (t[0].isConst()) + { + const Rational& rat = t[0].getConst(); + if (rat >= 0) + { + return RewriteResponse(REWRITE_DONE, t[0]); + } + else + { + return RewriteResponse(REWRITE_DONE, + NodeManager::currentNM()->mkConstRealOrInt( + t[0].getType(), -rat)); + } } - } - return RewriteResponse(REWRITE_DONE, t); - case kind::IS_INTEGER: - case kind::TO_INTEGER: - return RewriteResponse(REWRITE_DONE, t); - case kind::TO_REAL: - case kind::CAST_TO_REAL: return RewriteResponse(REWRITE_DONE, t[0]); - case kind::POW: - return RewriteResponse(REWRITE_DONE, t); - case kind::PI: - return RewriteResponse(REWRITE_DONE, t); - default: Unhandled() << k; + return RewriteResponse(REWRITE_DONE, t); + case kind::IS_INTEGER: + case kind::TO_INTEGER: return RewriteResponse(REWRITE_DONE, t); + case kind::TO_REAL: + case kind::CAST_TO_REAL: return RewriteResponse(REWRITE_DONE, t[0]); + case kind::POW: return RewriteResponse(REWRITE_DONE, t); + case kind::PI: return RewriteResponse(REWRITE_DONE, t); + default: Unhandled() << k; } } } @@ -172,54 +164,53 @@ RewriteResponse ArithRewriter::postRewriteTerm(TNode t){ }else{ Trace("arith-rewriter") << "postRewriteTerm: " << t << std::endl; switch(t.getKind()){ - case kind::MINUS: - return rewriteMinus(t, false); - case kind::UMINUS: - return rewriteUMinus(t, false); - case kind::DIVISION: - case kind::DIVISION_TOTAL: - return rewriteDiv(t, false); - case kind::PLUS: - return postRewritePlus(t); - case kind::MULT: - case kind::NONLINEAR_MULT: return postRewriteMult(t); - case kind::IAND: return postRewriteIAnd(t); - case kind::POW2: return postRewritePow2(t); - case kind::EXPONENTIAL: - case kind::SINE: - case kind::COSINE: - case kind::TANGENT: - case kind::COSECANT: - case kind::SECANT: - case kind::COTANGENT: - case kind::ARCSINE: - case kind::ARCCOSINE: - case kind::ARCTANGENT: - case kind::ARCCOSECANT: - case kind::ARCSECANT: - case kind::ARCCOTANGENT: - case kind::SQRT: return postRewriteTranscendental(t); - case kind::INTS_DIVISION: - case kind::INTS_MODULUS: return rewriteIntsDivMod(t, false); - case kind::INTS_DIVISION_TOTAL: - case kind::INTS_MODULUS_TOTAL: - return rewriteIntsDivModTotal(t, false); - case kind::ABS: - if(t[0].isConst()) { - const Rational& rat = t[0].getConst(); - if(rat >= 0) { - return RewriteResponse(REWRITE_DONE, t[0]); - } else { - return RewriteResponse( - REWRITE_DONE, - NodeManager::currentNM()->mkConstRealOrInt(t[0].getType(), -rat)); + case kind::MINUS: return rewriteMinus(t); + case kind::UMINUS: return rewriteUMinus(t, false); + case kind::DIVISION: + case kind::DIVISION_TOTAL: return rewriteDiv(t, false); + case kind::PLUS: return postRewritePlus(t); + case kind::MULT: + case kind::NONLINEAR_MULT: return postRewriteMult(t); + case kind::IAND: return postRewriteIAnd(t); + case kind::POW2: return postRewritePow2(t); + case kind::EXPONENTIAL: + case kind::SINE: + case kind::COSINE: + case kind::TANGENT: + case kind::COSECANT: + case kind::SECANT: + case kind::COTANGENT: + case kind::ARCSINE: + case kind::ARCCOSINE: + case kind::ARCTANGENT: + case kind::ARCCOSECANT: + case kind::ARCSECANT: + case kind::ARCCOTANGENT: + case kind::SQRT: return postRewriteTranscendental(t); + case kind::INTS_DIVISION: + case kind::INTS_MODULUS: return rewriteIntsDivMod(t, false); + case kind::INTS_DIVISION_TOTAL: + case kind::INTS_MODULUS_TOTAL: return rewriteIntsDivModTotal(t, false); + case kind::ABS: + if (t[0].isConst()) + { + const Rational& rat = t[0].getConst(); + if (rat >= 0) + { + return RewriteResponse(REWRITE_DONE, t[0]); + } + else + { + return RewriteResponse(REWRITE_DONE, + NodeManager::currentNM()->mkConstRealOrInt( + t[0].getType(), -rat)); + } } - } - return RewriteResponse(REWRITE_DONE, t); - case kind::TO_REAL: - case kind::CAST_TO_REAL: return RewriteResponse(REWRITE_DONE, t[0]); - case kind::TO_INTEGER: return rewriteExtIntegerOp(t); - case kind::POW: + return RewriteResponse(REWRITE_DONE, t); + case kind::TO_REAL: + case kind::CAST_TO_REAL: return RewriteResponse(REWRITE_DONE, t[0]); + case kind::TO_INTEGER: return rewriteExtIntegerOp(t); + case kind::POW: { if (t[1].isConst()) { @@ -757,13 +748,6 @@ Node ArithRewriter::makeUnaryMinusNode(TNode n){ return nm->mkNode(kind::MULT, nm->mkConstRealOrInt(n.getType(), qNegOne), n); } -Node ArithRewriter::makeSubtractionNode(TNode l, TNode r){ - Node negR = makeUnaryMinusNode(r); - Node diff = NodeManager::currentNM()->mkNode(kind::PLUS, l, negR); - - return diff; -} - RewriteResponse ArithRewriter::rewriteDiv(TNode t, bool pre){ Assert(t.getKind() == kind::DIVISION_TOTAL || t.getKind() == kind::DIVISION); diff --git a/src/theory/arith/arith_rewriter.h b/src/theory/arith/arith_rewriter.h index 1f861b08a..5fcb628b9 100644 --- a/src/theory/arith/arith_rewriter.h +++ b/src/theory/arith/arith_rewriter.h @@ -51,7 +51,7 @@ class ArithRewriter : public TheoryRewriter static RewriteResponse rewriteVariable(TNode t); static RewriteResponse rewriteConstant(TNode t); - static RewriteResponse rewriteMinus(TNode t, bool pre); + static RewriteResponse rewriteMinus(TNode t); static RewriteResponse rewriteUMinus(TNode t, bool pre); static RewriteResponse rewriteDiv(TNode t, bool pre); static RewriteResponse rewriteIntsDivMod(TNode t, bool pre);