/******************************************************************************
 * Top contributors (to current version):
 *   Andrew Reynolds, Tim King, Morgan Deters
 *
 * This file is part of the cvc5 project.
 *
 * Copyright (c) 2009-2021 by the authors listed in the file AUTHORS
 * in the top-level source directory and their institutional affiliations.
 * All rights reserved.  See the file COPYING in the top-level source
 * directory for licensing information.
 * ****************************************************************************
 *
 * Implementation of the rewriter for the theory of arithmetic.
 *
 * Defines the ArithRewriter pre-rewrite and post-rewrite steps for
 * arithmetic terms and atoms.
 * \todo document this file
 */
19 #include "theory/arith/arith_rewriter.h"
26 #include "smt/logic_exception.h"
27 #include "theory/arith/arith_msum.h"
28 #include "theory/arith/arith_utilities.h"
29 #include "theory/arith/normal_form.h"
30 #include "theory/arith/operator_elim.h"
31 #include "theory/theory.h"
32 #include "util/bitvector.h"
33 #include "util/divisible.h"
34 #include "util/iand.h"
40 ArithRewriter::ArithRewriter(OperatorElim
& oe
) : d_opElim(oe
) {}
42 bool ArithRewriter::isAtom(TNode n
) {
44 return arith::isRelationOperator(k
) || k
== kind::IS_INTEGER
45 || k
== kind::DIVISIBLE
;
48 RewriteResponse
ArithRewriter::rewriteConstant(TNode t
){
50 Assert(t
.getKind() == kind::CONST_RATIONAL
);
52 return RewriteResponse(REWRITE_DONE
, t
);
55 RewriteResponse
ArithRewriter::rewriteVariable(TNode t
){
58 return RewriteResponse(REWRITE_DONE
, t
);
61 RewriteResponse
ArithRewriter::rewriteMinus(TNode t
, bool pre
){
62 Assert(t
.getKind() == kind::MINUS
);
67 Node zeroNode
= mkRationalNode(zero
);
68 return RewriteResponse(REWRITE_DONE
, zeroNode
);
70 Node noMinus
= makeSubtractionNode(t
[0],t
[1]);
71 return RewriteResponse(REWRITE_DONE
, noMinus
);
74 Polynomial minuend
= Polynomial::parsePolynomial(t
[0]);
75 Polynomial subtrahend
= Polynomial::parsePolynomial(t
[1]);
76 Polynomial diff
= minuend
- subtrahend
;
77 return RewriteResponse(REWRITE_DONE
, diff
.getNode());
81 RewriteResponse
ArithRewriter::rewriteUMinus(TNode t
, bool pre
){
82 Assert(t
.getKind() == kind::UMINUS
);
84 if(t
[0].getKind() == kind::CONST_RATIONAL
){
85 Rational neg
= -(t
[0].getConst
<Rational
>());
86 return RewriteResponse(REWRITE_DONE
, mkRationalNode(neg
));
89 Node noUminus
= makeUnaryMinusNode(t
[0]);
91 return RewriteResponse(REWRITE_DONE
, noUminus
);
93 return RewriteResponse(REWRITE_AGAIN
, noUminus
);
96 RewriteResponse
ArithRewriter::preRewriteTerm(TNode t
){
98 return rewriteConstant(t
);
100 return rewriteVariable(t
);
102 switch(Kind k
= t
.getKind()){
104 return rewriteMinus(t
, true);
106 return rewriteUMinus(t
, true);
108 case kind::DIVISION_TOTAL
:
109 return rewriteDiv(t
,true);
111 return preRewritePlus(t
);
113 case kind::NONLINEAR_MULT
: return preRewriteMult(t
);
114 case kind::IAND
: return RewriteResponse(REWRITE_DONE
, t
);
115 case kind::EXPONENTIAL
:
121 case kind::COTANGENT
:
123 case kind::ARCCOSINE
:
124 case kind::ARCTANGENT
:
125 case kind::ARCCOSECANT
:
126 case kind::ARCSECANT
:
127 case kind::ARCCOTANGENT
:
128 case kind::SQRT
: return preRewriteTranscendental(t
);
129 case kind::INTS_DIVISION
:
130 case kind::INTS_MODULUS
: return rewriteIntsDivMod(t
, true);
131 case kind::INTS_DIVISION_TOTAL
:
132 case kind::INTS_MODULUS_TOTAL
:
133 return rewriteIntsDivModTotal(t
,true);
136 const Rational
& rat
= t
[0].getConst
<Rational
>();
138 return RewriteResponse(REWRITE_DONE
, t
[0]);
140 return RewriteResponse(REWRITE_DONE
,
141 NodeManager::currentNM()->mkConst(-rat
));
144 return RewriteResponse(REWRITE_DONE
, t
);
145 case kind::IS_INTEGER
:
146 case kind::TO_INTEGER
:
147 return RewriteResponse(REWRITE_DONE
, t
);
149 case kind::CAST_TO_REAL
: return RewriteResponse(REWRITE_DONE
, t
[0]);
151 return RewriteResponse(REWRITE_DONE
, t
);
153 return RewriteResponse(REWRITE_DONE
, t
);
154 default: Unhandled() << k
;
159 RewriteResponse
ArithRewriter::postRewriteTerm(TNode t
){
161 return rewriteConstant(t
);
163 return rewriteVariable(t
);
167 return rewriteMinus(t
, false);
169 return rewriteUMinus(t
, false);
171 case kind::DIVISION_TOTAL
:
172 return rewriteDiv(t
, false);
174 return postRewritePlus(t
);
176 case kind::NONLINEAR_MULT
: return postRewriteMult(t
);
177 case kind::IAND
: return postRewriteIAnd(t
);
178 case kind::EXPONENTIAL
:
184 case kind::COTANGENT
:
186 case kind::ARCCOSINE
:
187 case kind::ARCTANGENT
:
188 case kind::ARCCOSECANT
:
189 case kind::ARCSECANT
:
190 case kind::ARCCOTANGENT
:
191 case kind::SQRT
: return postRewriteTranscendental(t
);
192 case kind::INTS_DIVISION
:
193 case kind::INTS_MODULUS
: return rewriteIntsDivMod(t
, false);
194 case kind::INTS_DIVISION_TOTAL
:
195 case kind::INTS_MODULUS_TOTAL
:
196 return rewriteIntsDivModTotal(t
, false);
199 const Rational
& rat
= t
[0].getConst
<Rational
>();
201 return RewriteResponse(REWRITE_DONE
, t
[0]);
203 return RewriteResponse(REWRITE_DONE
,
204 NodeManager::currentNM()->mkConst(-rat
));
207 return RewriteResponse(REWRITE_DONE
, t
);
209 case kind::CAST_TO_REAL
: return RewriteResponse(REWRITE_DONE
, t
[0]);
210 case kind::TO_INTEGER
:
212 return RewriteResponse(REWRITE_DONE
, NodeManager::currentNM()->mkConst(Rational(t
[0].getConst
<Rational
>().floor())));
214 if(t
[0].getType().isInteger()) {
215 return RewriteResponse(REWRITE_DONE
, t
[0]);
217 //Unimplemented() << "TO_INTEGER, nonconstant";
218 //return rewriteToInteger(t);
219 return RewriteResponse(REWRITE_DONE
, t
);
220 case kind::IS_INTEGER
:
222 return RewriteResponse(REWRITE_DONE
, NodeManager::currentNM()->mkConst(t
[0].getConst
<Rational
>().getDenominator() == 1));
224 if(t
[0].getType().isInteger()) {
225 return RewriteResponse(REWRITE_DONE
, NodeManager::currentNM()->mkConst(true));
227 //Unimplemented() << "IS_INTEGER, nonconstant";
228 //return rewriteIsInteger(t);
229 return RewriteResponse(REWRITE_DONE
, t
);
232 if(t
[1].getKind() == kind::CONST_RATIONAL
){
233 const Rational
& exp
= t
[1].getConst
<Rational
>();
236 return RewriteResponse(REWRITE_DONE
, mkRationalNode(Rational(1)));
237 }else if(exp
.sgn() > 0 && exp
.isIntegral()){
238 cvc5::Rational
r(expr::NodeValue::MAX_CHILDREN
);
241 unsigned num
= exp
.getNumerator().toUnsignedInt();
243 return RewriteResponse(REWRITE_AGAIN
, base
);
245 NodeBuilder
nb(kind::MULT
);
246 for(unsigned i
=0; i
< num
; ++i
){
249 Assert(nb
.getNumChildren() > 0);
251 return RewriteResponse(REWRITE_AGAIN
, mult
);
257 // Todo improve the exception thrown
258 std::stringstream ss
;
259 ss
<< "The exponent of the POW(^) operator can only be a positive "
260 "integral constant below "
261 << (expr::NodeValue::MAX_CHILDREN
+ 1) << ". ";
262 ss
<< "Exception occurred in:" << std::endl
;
264 throw LogicException(ss
.str());
267 return RewriteResponse(REWRITE_DONE
, t
);
275 RewriteResponse
ArithRewriter::preRewriteMult(TNode t
){
276 Assert(t
.getKind() == kind::MULT
|| t
.getKind() == kind::NONLINEAR_MULT
);
278 if(t
.getNumChildren() == 2){
279 if(t
[0].getKind() == kind::CONST_RATIONAL
280 && t
[0].getConst
<Rational
>().isOne()){
281 return RewriteResponse(REWRITE_DONE
, t
[1]);
283 if(t
[1].getKind() == kind::CONST_RATIONAL
284 && t
[1].getConst
<Rational
>().isOne()){
285 return RewriteResponse(REWRITE_DONE
, t
[0]);
289 // Rewrite multiplications with a 0 argument and to 0
290 for(TNode::iterator i
= t
.begin(); i
!= t
.end(); ++i
) {
291 if((*i
).getKind() == kind::CONST_RATIONAL
) {
292 if((*i
).getConst
<Rational
>().isZero()) {
294 return RewriteResponse(REWRITE_DONE
, zero
);
298 return RewriteResponse(REWRITE_DONE
, t
);
301 static bool canFlatten(Kind k
, TNode t
){
302 for(TNode::iterator i
= t
.begin(); i
!= t
.end(); ++i
) {
304 if(child
.getKind() == k
){
311 static void flatten(std::vector
<TNode
>& pb
, Kind k
, TNode t
){
312 if(t
.getKind() == k
){
313 for(TNode::iterator i
= t
.begin(); i
!= t
.end(); ++i
) {
315 if(child
.getKind() == k
){
316 flatten(pb
, k
, child
);
326 static Node
flatten(Kind k
, TNode t
){
327 std::vector
<TNode
> pb
;
329 Assert(pb
.size() >= 2);
330 return NodeManager::currentNM()->mkNode(k
, pb
);
333 RewriteResponse
ArithRewriter::preRewritePlus(TNode t
){
334 Assert(t
.getKind() == kind::PLUS
);
336 if(canFlatten(kind::PLUS
, t
)){
337 return RewriteResponse(REWRITE_DONE
, flatten(kind::PLUS
, t
));
339 return RewriteResponse(REWRITE_DONE
, t
);
343 RewriteResponse
ArithRewriter::postRewritePlus(TNode t
){
344 Assert(t
.getKind() == kind::PLUS
);
346 std::vector
<Monomial
> monomials
;
347 std::vector
<Polynomial
> polynomials
;
349 for(TNode::iterator i
= t
.begin(), end
= t
.end(); i
!= end
; ++i
){
351 if(Monomial::isMember(curr
)){
352 monomials
.push_back(Monomial::parseMonomial(curr
));
354 polynomials
.push_back(Polynomial::parsePolynomial(curr
));
358 if(!monomials
.empty()){
359 Monomial::sort(monomials
);
360 Monomial::combineAdjacentMonomials(monomials
);
361 polynomials
.push_back(Polynomial::mkPolynomial(monomials
));
364 Polynomial res
= Polynomial::sumPolynomials(polynomials
);
366 return RewriteResponse(REWRITE_DONE
, res
.getNode());
369 RewriteResponse
ArithRewriter::postRewriteMult(TNode t
){
370 Assert(t
.getKind() == kind::MULT
|| t
.getKind() == kind::NONLINEAR_MULT
);
372 Polynomial res
= Polynomial::mkOne();
374 for(TNode::iterator i
= t
.begin(), end
= t
.end(); i
!= end
; ++i
){
376 Polynomial currPoly
= Polynomial::parsePolynomial(curr
);
378 res
= res
* currPoly
;
381 return RewriteResponse(REWRITE_DONE
, res
.getNode());
384 RewriteResponse
ArithRewriter::postRewriteIAnd(TNode t
)
386 Assert(t
.getKind() == kind::IAND
);
387 NodeManager
* nm
= NodeManager::currentNM();
388 // if constant, we eliminate
389 if (t
[0].isConst() && t
[1].isConst())
391 size_t bsize
= t
.getOperator().getConst
<IntAnd
>().d_size
;
392 Node iToBvop
= nm
->mkConst(IntToBitVector(bsize
));
393 Node arg1
= nm
->mkNode(kind::INT_TO_BITVECTOR
, iToBvop
, t
[0]);
394 Node arg2
= nm
->mkNode(kind::INT_TO_BITVECTOR
, iToBvop
, t
[1]);
395 Node bvand
= nm
->mkNode(kind::BITVECTOR_AND
, arg1
, arg2
);
396 Node ret
= nm
->mkNode(kind::BITVECTOR_TO_NAT
, bvand
);
397 return RewriteResponse(REWRITE_AGAIN_FULL
, ret
);
399 else if (t
[0] > t
[1])
401 // ((_ iand k) x y) ---> ((_ iand k) y x) if x > y by node ordering
402 Node ret
= nm
->mkNode(kind::IAND
, t
.getOperator(), t
[1], t
[0]);
403 return RewriteResponse(REWRITE_AGAIN
, ret
);
405 else if (t
[0] == t
[1])
407 // ((_ iand k) x x) ---> x
408 return RewriteResponse(REWRITE_DONE
, t
[0]);
410 // simplifications involving constants
411 for (unsigned i
= 0; i
< 2; i
++)
417 if (t
[i
].getConst
<Rational
>().sgn() == 0)
419 // ((_ iand k) 0 y) ---> 0
420 return RewriteResponse(REWRITE_DONE
, t
[i
]);
423 return RewriteResponse(REWRITE_DONE
, t
);
426 RewriteResponse
ArithRewriter::preRewriteTranscendental(TNode t
) {
427 return RewriteResponse(REWRITE_DONE
, t
);
430 RewriteResponse
ArithRewriter::postRewriteTranscendental(TNode t
) {
431 Trace("arith-tf-rewrite") << "Rewrite transcendental function : " << t
<< std::endl
;
432 NodeManager
* nm
= NodeManager::currentNM();
433 switch( t
.getKind() ){
434 case kind::EXPONENTIAL
: {
435 if(t
[0].getKind() == kind::CONST_RATIONAL
){
436 Node one
= nm
->mkConst(Rational(1));
437 if(t
[0].getConst
<Rational
>().sgn()>=0 && t
[0].getType().isInteger() && t
[0]!=one
){
438 return RewriteResponse(
440 nm
->mkNode(kind::POW
, nm
->mkNode(kind::EXPONENTIAL
, one
), t
[0]));
442 return RewriteResponse(REWRITE_DONE
, t
);
445 else if (t
[0].getKind() == kind::PLUS
)
447 std::vector
<Node
> product
;
448 for (const Node tc
: t
[0])
450 product
.push_back(nm
->mkNode(kind::EXPONENTIAL
, tc
));
452 // We need to do a full rewrite here, since we can get exponentials of
453 // constants, e.g. when we are rewriting exp(2 + x)
454 return RewriteResponse(REWRITE_AGAIN_FULL
,
455 nm
->mkNode(kind::MULT
, product
));
460 if(t
[0].getKind() == kind::CONST_RATIONAL
){
461 const Rational
& rat
= t
[0].getConst
<Rational
>();
463 return RewriteResponse(REWRITE_DONE
, nm
->mkConst(Rational(0)));
465 else if (rat
.sgn() == -1)
468 nm
->mkNode(kind::UMINUS
, nm
->mkNode(kind::SINE
, nm
->mkConst(-rat
)));
469 return RewriteResponse(REWRITE_AGAIN_FULL
, ret
);
472 // get the factor of PI in the argument
476 std::map
<Node
, Node
> msum
;
477 if (ArithMSum::getMonomialSum(t
[0], msum
))
480 std::map
<Node
, Node
>::iterator itm
= msum
.find(pi
);
481 if (itm
!= msum
.end())
483 if (itm
->second
.isNull())
485 pi_factor
= mkRationalNode(Rational(1));
489 pi_factor
= itm
->second
;
494 rem
= ArithMSum::mkNode(msum
);
503 // if there is a factor of PI
504 if( !pi_factor
.isNull() ){
505 Trace("arith-tf-rewrite-debug") << "Process pi factor = " << pi_factor
<< std::endl
;
506 Rational r
= pi_factor
.getConst
<Rational
>();
507 Rational r_abs
= r
.abs();
508 Rational rone
= Rational(1);
509 Node ntwo
= mkRationalNode(Rational(2));
512 // add/subtract 2*pi beyond scope
513 Node ra_div_two
= nm
->mkNode(
514 kind::INTS_DIVISION
, mkRationalNode(r_abs
+ rone
), ntwo
);
518 nm
->mkNode(kind::MINUS
,
520 nm
->mkNode(kind::MULT
, ntwo
, ra_div_two
));
522 Assert(r
.sgn() == -1);
524 nm
->mkNode(kind::PLUS
,
526 nm
->mkNode(kind::MULT
, ntwo
, ra_div_two
));
528 Node new_arg
= nm
->mkNode(kind::MULT
, new_pi_factor
, pi
);
531 new_arg
= nm
->mkNode(kind::PLUS
, new_arg
, rem
);
533 // sin( 2*n*PI + x ) = sin( x )
534 return RewriteResponse(REWRITE_AGAIN_FULL
,
535 nm
->mkNode(kind::SINE
, new_arg
));
537 else if (r_abs
== rone
)
539 // sin( PI + x ) = -sin( x )
542 return RewriteResponse(REWRITE_DONE
, mkRationalNode(Rational(0)));
546 return RewriteResponse(
548 nm
->mkNode(kind::UMINUS
, nm
->mkNode(kind::SINE
, rem
)));
551 else if (rem
.isNull())
553 // other rational cases based on Niven's theorem
554 // (https://en.wikipedia.org/wiki/Niven%27s_theorem)
555 Integer one
= Integer(1);
556 Integer two
= Integer(2);
557 Integer six
= Integer(6);
558 if (r_abs
.getDenominator() == two
)
560 Assert(r_abs
.getNumerator() == one
);
561 return RewriteResponse(REWRITE_DONE
,
562 mkRationalNode(Rational(r
.sgn())));
564 else if (r_abs
.getDenominator() == six
)
566 Integer five
= Integer(5);
567 if (r_abs
.getNumerator() == one
|| r_abs
.getNumerator() == five
)
569 return RewriteResponse(
571 mkRationalNode(Rational(r
.sgn()) / Rational(2)));
579 return RewriteResponse(
581 nm
->mkNode(kind::SINE
,
582 nm
->mkNode(kind::MINUS
,
583 nm
->mkNode(kind::MULT
,
584 nm
->mkConst(Rational(1) / Rational(2)),
591 return RewriteResponse(REWRITE_AGAIN_FULL
,
592 nm
->mkNode(kind::DIVISION
,
593 nm
->mkNode(kind::SINE
, t
[0]),
594 nm
->mkNode(kind::COSINE
, t
[0])));
599 return RewriteResponse(REWRITE_AGAIN_FULL
,
600 nm
->mkNode(kind::DIVISION
,
601 mkRationalNode(Rational(1)),
602 nm
->mkNode(kind::SINE
, t
[0])));
607 return RewriteResponse(REWRITE_AGAIN_FULL
,
608 nm
->mkNode(kind::DIVISION
,
609 mkRationalNode(Rational(1)),
610 nm
->mkNode(kind::COSINE
, t
[0])));
613 case kind::COTANGENT
:
615 return RewriteResponse(REWRITE_AGAIN_FULL
,
616 nm
->mkNode(kind::DIVISION
,
617 nm
->mkNode(kind::COSINE
, t
[0]),
618 nm
->mkNode(kind::SINE
, t
[0])));
624 return RewriteResponse(REWRITE_DONE
, t
);
627 RewriteResponse
ArithRewriter::postRewriteAtom(TNode atom
){
628 if(atom
.getKind() == kind::IS_INTEGER
) {
629 if(atom
[0].isConst()) {
630 return RewriteResponse(REWRITE_DONE
, NodeManager::currentNM()->mkConst(atom
[0].getConst
<Rational
>().isIntegral()));
632 if(atom
[0].getType().isInteger()) {
633 return RewriteResponse(REWRITE_DONE
, NodeManager::currentNM()->mkConst(true));
635 // not supported, but this isn't the right place to complain
636 return RewriteResponse(REWRITE_DONE
, atom
);
637 } else if(atom
.getKind() == kind::DIVISIBLE
) {
638 if(atom
[0].isConst()) {
639 return RewriteResponse(REWRITE_DONE
, NodeManager::currentNM()->mkConst(bool((atom
[0].getConst
<Rational
>() / atom
.getOperator().getConst
<Divisible
>().k
).isIntegral())));
641 if(atom
.getOperator().getConst
<Divisible
>().k
.isOne()) {
642 return RewriteResponse(REWRITE_DONE
, NodeManager::currentNM()->mkConst(true));
644 return RewriteResponse(REWRITE_AGAIN
, NodeManager::currentNM()->mkNode(kind::EQUAL
, NodeManager::currentNM()->mkNode(kind::INTS_MODULUS_TOTAL
, atom
[0], NodeManager::currentNM()->mkConst(Rational(atom
.getOperator().getConst
<Divisible
>().k
))), NodeManager::currentNM()->mkConst(Rational(0))));
648 TNode left
= atom
[0];
649 TNode right
= atom
[1];
651 Polynomial pleft
= Polynomial::parsePolynomial(left
);
652 Polynomial pright
= Polynomial::parsePolynomial(right
);
654 Debug("arith::rewriter") << "pleft " << pleft
.getNode() << std::endl
;
655 Debug("arith::rewriter") << "pright " << pright
.getNode() << std::endl
;
657 Comparison cmp
= Comparison::mkComparison(atom
.getKind(), pleft
, pright
);
658 Assert(cmp
.isNormalForm());
659 return RewriteResponse(REWRITE_DONE
, cmp
.getNode());
662 RewriteResponse
ArithRewriter::preRewriteAtom(TNode atom
){
663 Assert(isAtom(atom
));
665 NodeManager
* currNM
= NodeManager::currentNM();
667 if(atom
.getKind() == kind::EQUAL
) {
668 if(atom
[0] == atom
[1]) {
669 return RewriteResponse(REWRITE_DONE
, currNM
->mkConst(true));
671 }else if(atom
.getKind() == kind::GT
){
672 Node leq
= currNM
->mkNode(kind::LEQ
, atom
[0], atom
[1]);
673 return RewriteResponse(REWRITE_DONE
, currNM
->mkNode(kind::NOT
, leq
));
674 }else if(atom
.getKind() == kind::LT
){
675 Node geq
= currNM
->mkNode(kind::GEQ
, atom
[0], atom
[1]);
676 return RewriteResponse(REWRITE_DONE
, currNM
->mkNode(kind::NOT
, geq
));
677 }else if(atom
.getKind() == kind::IS_INTEGER
){
678 if(atom
[0].getType().isInteger()){
679 return RewriteResponse(REWRITE_DONE
, currNM
->mkConst(true));
681 }else if(atom
.getKind() == kind::DIVISIBLE
){
682 if(atom
.getOperator().getConst
<Divisible
>().k
.isOne()){
683 return RewriteResponse(REWRITE_DONE
, currNM
->mkConst(true));
687 return RewriteResponse(REWRITE_DONE
, atom
);
690 RewriteResponse
ArithRewriter::postRewrite(TNode t
){
692 RewriteResponse response
= postRewriteTerm(t
);
693 if (Debug
.isOn("arith::rewriter") && response
.d_status
== REWRITE_DONE
)
695 Polynomial::parsePolynomial(response
.d_node
);
699 RewriteResponse response
= postRewriteAtom(t
);
700 if (Debug
.isOn("arith::rewriter") && response
.d_status
== REWRITE_DONE
)
702 Comparison::parseNormalForm(response
.d_node
);
710 RewriteResponse
ArithRewriter::preRewrite(TNode t
){
712 return preRewriteTerm(t
);
714 return preRewriteAtom(t
);
720 Node
ArithRewriter::makeUnaryMinusNode(TNode n
){
721 Rational
qNegOne(-1);
722 return NodeManager::currentNM()->mkNode(kind::MULT
, mkRationalNode(qNegOne
),n
);
725 Node
ArithRewriter::makeSubtractionNode(TNode l
, TNode r
){
726 Node negR
= makeUnaryMinusNode(r
);
727 Node diff
= NodeManager::currentNM()->mkNode(kind::PLUS
, l
, negR
);
732 RewriteResponse
ArithRewriter::rewriteDiv(TNode t
, bool pre
){
733 Assert(t
.getKind() == kind::DIVISION_TOTAL
|| t
.getKind() == kind::DIVISION
);
737 if(right
.getKind() == kind::CONST_RATIONAL
){
738 const Rational
& den
= right
.getConst
<Rational
>();
741 if(t
.getKind() == kind::DIVISION_TOTAL
){
742 return RewriteResponse(REWRITE_DONE
, mkRationalNode(0));
744 // This is unsupported, but this is not a good place to complain
745 return RewriteResponse(REWRITE_DONE
, t
);
748 Assert(den
!= Rational(0));
750 if(left
.getKind() == kind::CONST_RATIONAL
){
751 const Rational
& num
= left
.getConst
<Rational
>();
752 Rational div
= num
/ den
;
753 Node result
= mkRationalNode(div
);
754 return RewriteResponse(REWRITE_DONE
, result
);
757 Rational div
= den
.inverse();
759 Node result
= mkRationalNode(div
);
761 Node mult
= NodeManager::currentNM()->mkNode(kind::MULT
,left
,result
);
763 return RewriteResponse(REWRITE_DONE
, mult
);
765 return RewriteResponse(REWRITE_AGAIN
, mult
);
768 return RewriteResponse(REWRITE_DONE
, t
);
772 RewriteResponse
ArithRewriter::rewriteIntsDivMod(TNode t
, bool pre
)
774 NodeManager
* nm
= NodeManager::currentNM();
775 Kind k
= t
.getKind();
776 Node zero
= nm
->mkConst(Rational(0));
777 if (k
== kind::INTS_MODULUS
)
779 if (t
[1].isConst() && !t
[1].getConst
<Rational
>().isZero())
781 // can immediately replace by INTS_MODULUS_TOTAL
782 Node ret
= nm
->mkNode(kind::INTS_MODULUS_TOTAL
, t
[0], t
[1]);
783 return returnRewrite(t
, ret
, Rewrite::MOD_TOTAL_BY_CONST
);
786 if (k
== kind::INTS_DIVISION
)
788 if (t
[1].isConst() && !t
[1].getConst
<Rational
>().isZero())
790 // can immediately replace by INTS_DIVISION_TOTAL
791 Node ret
= nm
->mkNode(kind::INTS_DIVISION_TOTAL
, t
[0], t
[1]);
792 return returnRewrite(t
, ret
, Rewrite::DIV_TOTAL_BY_CONST
);
795 return RewriteResponse(REWRITE_DONE
, t
);
798 RewriteResponse
ArithRewriter::rewriteIntsDivModTotal(TNode t
, bool pre
)
802 // do not rewrite at prewrite.
803 return RewriteResponse(REWRITE_DONE
, t
);
805 NodeManager
* nm
= NodeManager::currentNM();
806 Kind k
= t
.getKind();
807 Assert(k
== kind::INTS_MODULUS_TOTAL
|| k
== kind::INTS_DIVISION_TOTAL
);
810 bool dIsConstant
= d
.getKind() == kind::CONST_RATIONAL
;
811 if(dIsConstant
&& d
.getConst
<Rational
>().isZero()){
812 // (div x 0) ---> 0 or (mod x 0) ---> 0
813 return returnRewrite(t
, mkRationalNode(0), Rewrite::DIV_MOD_BY_ZERO
);
814 }else if(dIsConstant
&& d
.getConst
<Rational
>().isOne()){
815 if (k
== kind::INTS_MODULUS_TOTAL
)
818 return returnRewrite(t
, mkRationalNode(0), Rewrite::MOD_BY_ONE
);
820 Assert(k
== kind::INTS_DIVISION_TOTAL
);
822 return returnRewrite(t
, n
, Rewrite::DIV_BY_ONE
);
824 else if (dIsConstant
&& d
.getConst
<Rational
>().sgn() < 0)
827 // (div x (- c)) ---> (- (div x c))
828 // (mod x (- c)) ---> (mod x c)
829 Node nn
= nm
->mkNode(k
, t
[0], nm
->mkConst(-t
[1].getConst
<Rational
>()));
830 Node ret
= (k
== kind::INTS_DIVISION
|| k
== kind::INTS_DIVISION_TOTAL
)
831 ? nm
->mkNode(kind::UMINUS
, nn
)
833 return returnRewrite(t
, ret
, Rewrite::DIV_MOD_PULL_NEG_DEN
);
835 else if (dIsConstant
&& n
.getKind() == kind::CONST_RATIONAL
)
837 Assert(d
.getConst
<Rational
>().isIntegral());
838 Assert(n
.getConst
<Rational
>().isIntegral());
839 Assert(!d
.getConst
<Rational
>().isZero());
840 Integer di
= d
.getConst
<Rational
>().getNumerator();
841 Integer ni
= n
.getConst
<Rational
>().getNumerator();
843 bool isDiv
= (k
== kind::INTS_DIVISION
|| k
== kind::INTS_DIVISION_TOTAL
);
845 Integer result
= isDiv
? ni
.euclidianDivideQuotient(di
) : ni
.euclidianDivideRemainder(di
);
847 // constant evaluation
848 // (mod c1 c2) ---> c3 or (div c1 c2) ---> c3
849 Node resultNode
= mkRationalNode(Rational(result
));
850 return returnRewrite(t
, resultNode
, Rewrite::CONST_EVAL
);
852 if (k
== kind::INTS_MODULUS_TOTAL
)
854 // Note these rewrites do not need to account for modulus by zero as being
855 // a UF, which is handled by the reduction of INTS_MODULUS.
856 Kind k0
= t
[0].getKind();
857 if (k0
== kind::INTS_MODULUS_TOTAL
&& t
[0][1] == t
[1])
859 // (mod (mod x c) c) --> (mod x c)
860 return returnRewrite(t
, t
[0], Rewrite::MOD_OVER_MOD
);
862 else if (k0
== kind::NONLINEAR_MULT
|| k0
== kind::MULT
|| k0
== kind::PLUS
)
865 std::vector
<Node
> newChildren
;
866 bool childChanged
= false;
867 for (const Node
& tc
: t
[0])
869 if (tc
.getKind() == kind::INTS_MODULUS_TOTAL
&& tc
[1] == t
[1])
871 newChildren
.push_back(tc
[0]);
875 newChildren
.push_back(tc
);
879 // (mod (op ... (mod x c) ...) c) ---> (mod (op ... x ...) c) where
880 // op is one of { NONLINEAR_MULT, MULT, PLUS }.
881 Node ret
= nm
->mkNode(k0
, newChildren
);
882 ret
= nm
->mkNode(kind::INTS_MODULUS_TOTAL
, ret
, t
[1]);
883 return returnRewrite(t
, ret
, Rewrite::MOD_CHILD_MOD
);
889 Assert(k
== kind::INTS_DIVISION_TOTAL
);
890 // Note these rewrites do not need to account for division by zero as being
891 // a UF, which is handled by the reduction of INTS_DIVISION.
892 if (t
[0].getKind() == kind::INTS_MODULUS_TOTAL
&& t
[0][1] == t
[1])
894 // (div (mod x c) c) --> 0
895 Node ret
= mkRationalNode(0);
896 return returnRewrite(t
, ret
, Rewrite::DIV_OVER_MOD
);
899 return RewriteResponse(REWRITE_DONE
, t
);
902 TrustNode
ArithRewriter::expandDefinition(Node node
)
904 // call eliminate operators, to eliminate partial operators only
905 std::vector
<SkolemLemma
> lems
;
906 TrustNode ret
= d_opElim
.eliminate(node
, lems
, true);
907 Assert(lems
.empty());
911 RewriteResponse
ArithRewriter::returnRewrite(TNode t
, Node ret
, Rewrite r
)
913 Trace("arith-rewrite") << "ArithRewriter : " << t
<< " == " << ret
<< " by "
915 return RewriteResponse(REWRITE_AGAIN_FULL
, ret
);
919 } // namespace theory