1 /******************************************************************************
2 * Top contributors (to current version):
3 * Andrew Reynolds, Tim King, Morgan Deters
5 * This file is part of the cvc5 project.
7 * Copyright (c) 2009-2021 by the authors listed in the file AUTHORS
8 * in the top-level source directory and their institutional affiliations.
9 * All rights reserved. See the file COPYING in the top-level source
10 * directory for licensing information.
11 * ****************************************************************************
13 * [[ Add one-line brief description here ]]
15 * [[ Add lengthier description here ]]
16 * \todo document this file
24 #include "smt/logic_exception.h"
25 #include "theory/arith/arith_msum.h"
26 #include "theory/arith/arith_rewriter.h"
27 #include "theory/arith/arith_utilities.h"
28 #include "theory/arith/normal_form.h"
29 #include "theory/theory.h"
30 #include "util/iand.h"
36 bool ArithRewriter::isAtom(TNode n
) {
38 return arith::isRelationOperator(k
) || k
== kind::IS_INTEGER
39 || k
== kind::DIVISIBLE
;
42 RewriteResponse
ArithRewriter::rewriteConstant(TNode t
){
44 Assert(t
.getKind() == kind::CONST_RATIONAL
);
46 return RewriteResponse(REWRITE_DONE
, t
);
49 RewriteResponse
ArithRewriter::rewriteVariable(TNode t
){
52 return RewriteResponse(REWRITE_DONE
, t
);
55 RewriteResponse
ArithRewriter::rewriteMinus(TNode t
, bool pre
){
56 Assert(t
.getKind() == kind::MINUS
);
61 Node zeroNode
= mkRationalNode(zero
);
62 return RewriteResponse(REWRITE_DONE
, zeroNode
);
64 Node noMinus
= makeSubtractionNode(t
[0],t
[1]);
65 return RewriteResponse(REWRITE_DONE
, noMinus
);
68 Polynomial minuend
= Polynomial::parsePolynomial(t
[0]);
69 Polynomial subtrahend
= Polynomial::parsePolynomial(t
[1]);
70 Polynomial diff
= minuend
- subtrahend
;
71 return RewriteResponse(REWRITE_DONE
, diff
.getNode());
75 RewriteResponse
ArithRewriter::rewriteUMinus(TNode t
, bool pre
){
76 Assert(t
.getKind() == kind::UMINUS
);
78 if(t
[0].getKind() == kind::CONST_RATIONAL
){
79 Rational neg
= -(t
[0].getConst
<Rational
>());
80 return RewriteResponse(REWRITE_DONE
, mkRationalNode(neg
));
83 Node noUminus
= makeUnaryMinusNode(t
[0]);
85 return RewriteResponse(REWRITE_DONE
, noUminus
);
87 return RewriteResponse(REWRITE_AGAIN
, noUminus
);
90 RewriteResponse
ArithRewriter::preRewriteTerm(TNode t
){
92 return rewriteConstant(t
);
94 return rewriteVariable(t
);
96 switch(Kind k
= t
.getKind()){
98 return rewriteMinus(t
, true);
100 return rewriteUMinus(t
, true);
102 case kind::DIVISION_TOTAL
:
103 return rewriteDiv(t
,true);
105 return preRewritePlus(t
);
107 case kind::NONLINEAR_MULT
: return preRewriteMult(t
);
108 case kind::IAND
: return RewriteResponse(REWRITE_DONE
, t
);
109 case kind::EXPONENTIAL
:
115 case kind::COTANGENT
:
117 case kind::ARCCOSINE
:
118 case kind::ARCTANGENT
:
119 case kind::ARCCOSECANT
:
120 case kind::ARCSECANT
:
121 case kind::ARCCOTANGENT
:
122 case kind::SQRT
: return preRewriteTranscendental(t
);
123 case kind::INTS_DIVISION
:
124 case kind::INTS_MODULUS
: return rewriteIntsDivMod(t
, true);
125 case kind::INTS_DIVISION_TOTAL
:
126 case kind::INTS_MODULUS_TOTAL
:
127 return rewriteIntsDivModTotal(t
,true);
130 const Rational
& rat
= t
[0].getConst
<Rational
>();
132 return RewriteResponse(REWRITE_DONE
, t
[0]);
134 return RewriteResponse(REWRITE_DONE
,
135 NodeManager::currentNM()->mkConst(-rat
));
138 return RewriteResponse(REWRITE_DONE
, t
);
139 case kind::IS_INTEGER
:
140 case kind::TO_INTEGER
:
141 return RewriteResponse(REWRITE_DONE
, t
);
143 case kind::CAST_TO_REAL
: return RewriteResponse(REWRITE_DONE
, t
[0]);
145 return RewriteResponse(REWRITE_DONE
, t
);
147 return RewriteResponse(REWRITE_DONE
, t
);
148 default: Unhandled() << k
;
153 RewriteResponse
ArithRewriter::postRewriteTerm(TNode t
){
155 return rewriteConstant(t
);
157 return rewriteVariable(t
);
161 return rewriteMinus(t
, false);
163 return rewriteUMinus(t
, false);
165 case kind::DIVISION_TOTAL
:
166 return rewriteDiv(t
, false);
168 return postRewritePlus(t
);
170 case kind::NONLINEAR_MULT
: return postRewriteMult(t
);
171 case kind::IAND
: return postRewriteIAnd(t
);
172 case kind::EXPONENTIAL
:
178 case kind::COTANGENT
:
180 case kind::ARCCOSINE
:
181 case kind::ARCTANGENT
:
182 case kind::ARCCOSECANT
:
183 case kind::ARCSECANT
:
184 case kind::ARCCOTANGENT
:
185 case kind::SQRT
: return postRewriteTranscendental(t
);
186 case kind::INTS_DIVISION
:
187 case kind::INTS_MODULUS
: return rewriteIntsDivMod(t
, false);
188 case kind::INTS_DIVISION_TOTAL
:
189 case kind::INTS_MODULUS_TOTAL
:
190 return rewriteIntsDivModTotal(t
, false);
193 const Rational
& rat
= t
[0].getConst
<Rational
>();
195 return RewriteResponse(REWRITE_DONE
, t
[0]);
197 return RewriteResponse(REWRITE_DONE
,
198 NodeManager::currentNM()->mkConst(-rat
));
201 return RewriteResponse(REWRITE_DONE
, t
);
203 case kind::CAST_TO_REAL
: return RewriteResponse(REWRITE_DONE
, t
[0]);
204 case kind::TO_INTEGER
:
206 return RewriteResponse(REWRITE_DONE
, NodeManager::currentNM()->mkConst(Rational(t
[0].getConst
<Rational
>().floor())));
208 if(t
[0].getType().isInteger()) {
209 return RewriteResponse(REWRITE_DONE
, t
[0]);
211 //Unimplemented() << "TO_INTEGER, nonconstant";
212 //return rewriteToInteger(t);
213 return RewriteResponse(REWRITE_DONE
, t
);
214 case kind::IS_INTEGER
:
216 return RewriteResponse(REWRITE_DONE
, NodeManager::currentNM()->mkConst(t
[0].getConst
<Rational
>().getDenominator() == 1));
218 if(t
[0].getType().isInteger()) {
219 return RewriteResponse(REWRITE_DONE
, NodeManager::currentNM()->mkConst(true));
221 //Unimplemented() << "IS_INTEGER, nonconstant";
222 //return rewriteIsInteger(t);
223 return RewriteResponse(REWRITE_DONE
, t
);
226 if(t
[1].getKind() == kind::CONST_RATIONAL
){
227 const Rational
& exp
= t
[1].getConst
<Rational
>();
230 return RewriteResponse(REWRITE_DONE
, mkRationalNode(Rational(1)));
231 }else if(exp
.sgn() > 0 && exp
.isIntegral()){
232 cvc5::Rational
r(expr::NodeValue::MAX_CHILDREN
);
235 unsigned num
= exp
.getNumerator().toUnsignedInt();
237 return RewriteResponse(REWRITE_AGAIN
, base
);
239 NodeBuilder
nb(kind::MULT
);
240 for(unsigned i
=0; i
< num
; ++i
){
243 Assert(nb
.getNumChildren() > 0);
245 return RewriteResponse(REWRITE_AGAIN
, mult
);
251 // Todo improve the exception thrown
252 std::stringstream ss
;
253 ss
<< "The exponent of the POW(^) operator can only be a positive "
254 "integral constant below "
255 << (expr::NodeValue::MAX_CHILDREN
+ 1) << ". ";
256 ss
<< "Exception occurred in:" << std::endl
;
258 throw LogicException(ss
.str());
261 return RewriteResponse(REWRITE_DONE
, t
);
269 RewriteResponse
ArithRewriter::preRewriteMult(TNode t
){
270 Assert(t
.getKind() == kind::MULT
|| t
.getKind() == kind::NONLINEAR_MULT
);
272 if(t
.getNumChildren() == 2){
273 if(t
[0].getKind() == kind::CONST_RATIONAL
274 && t
[0].getConst
<Rational
>().isOne()){
275 return RewriteResponse(REWRITE_DONE
, t
[1]);
277 if(t
[1].getKind() == kind::CONST_RATIONAL
278 && t
[1].getConst
<Rational
>().isOne()){
279 return RewriteResponse(REWRITE_DONE
, t
[0]);
283 // Rewrite multiplications with a 0 argument and to 0
284 for(TNode::iterator i
= t
.begin(); i
!= t
.end(); ++i
) {
285 if((*i
).getKind() == kind::CONST_RATIONAL
) {
286 if((*i
).getConst
<Rational
>().isZero()) {
288 return RewriteResponse(REWRITE_DONE
, zero
);
292 return RewriteResponse(REWRITE_DONE
, t
);
295 static bool canFlatten(Kind k
, TNode t
){
296 for(TNode::iterator i
= t
.begin(); i
!= t
.end(); ++i
) {
298 if(child
.getKind() == k
){
305 static void flatten(std::vector
<TNode
>& pb
, Kind k
, TNode t
){
306 if(t
.getKind() == k
){
307 for(TNode::iterator i
= t
.begin(); i
!= t
.end(); ++i
) {
309 if(child
.getKind() == k
){
310 flatten(pb
, k
, child
);
320 static Node
flatten(Kind k
, TNode t
){
321 std::vector
<TNode
> pb
;
323 Assert(pb
.size() >= 2);
324 return NodeManager::currentNM()->mkNode(k
, pb
);
327 RewriteResponse
ArithRewriter::preRewritePlus(TNode t
){
328 Assert(t
.getKind() == kind::PLUS
);
330 if(canFlatten(kind::PLUS
, t
)){
331 return RewriteResponse(REWRITE_DONE
, flatten(kind::PLUS
, t
));
333 return RewriteResponse(REWRITE_DONE
, t
);
337 RewriteResponse
ArithRewriter::postRewritePlus(TNode t
){
338 Assert(t
.getKind() == kind::PLUS
);
340 std::vector
<Monomial
> monomials
;
341 std::vector
<Polynomial
> polynomials
;
343 for(TNode::iterator i
= t
.begin(), end
= t
.end(); i
!= end
; ++i
){
345 if(Monomial::isMember(curr
)){
346 monomials
.push_back(Monomial::parseMonomial(curr
));
348 polynomials
.push_back(Polynomial::parsePolynomial(curr
));
352 if(!monomials
.empty()){
353 Monomial::sort(monomials
);
354 Monomial::combineAdjacentMonomials(monomials
);
355 polynomials
.push_back(Polynomial::mkPolynomial(monomials
));
358 Polynomial res
= Polynomial::sumPolynomials(polynomials
);
360 return RewriteResponse(REWRITE_DONE
, res
.getNode());
363 RewriteResponse
ArithRewriter::postRewriteMult(TNode t
){
364 Assert(t
.getKind() == kind::MULT
|| t
.getKind() == kind::NONLINEAR_MULT
);
366 Polynomial res
= Polynomial::mkOne();
368 for(TNode::iterator i
= t
.begin(), end
= t
.end(); i
!= end
; ++i
){
370 Polynomial currPoly
= Polynomial::parsePolynomial(curr
);
372 res
= res
* currPoly
;
375 return RewriteResponse(REWRITE_DONE
, res
.getNode());
378 RewriteResponse
ArithRewriter::postRewriteIAnd(TNode t
)
380 Assert(t
.getKind() == kind::IAND
);
381 NodeManager
* nm
= NodeManager::currentNM();
382 // if constant, we eliminate
383 if (t
[0].isConst() && t
[1].isConst())
385 size_t bsize
= t
.getOperator().getConst
<IntAnd
>().d_size
;
386 Node iToBvop
= nm
->mkConst(IntToBitVector(bsize
));
387 Node arg1
= nm
->mkNode(kind::INT_TO_BITVECTOR
, iToBvop
, t
[0]);
388 Node arg2
= nm
->mkNode(kind::INT_TO_BITVECTOR
, iToBvop
, t
[1]);
389 Node bvand
= nm
->mkNode(kind::BITVECTOR_AND
, arg1
, arg2
);
390 Node ret
= nm
->mkNode(kind::BITVECTOR_TO_NAT
, bvand
);
391 return RewriteResponse(REWRITE_AGAIN_FULL
, ret
);
393 else if (t
[0] > t
[1])
395 // ((_ iand k) x y) ---> ((_ iand k) y x) if x > y by node ordering
396 Node ret
= nm
->mkNode(kind::IAND
, t
.getOperator(), t
[1], t
[0]);
397 return RewriteResponse(REWRITE_AGAIN
, ret
);
399 else if (t
[0] == t
[1])
401 // ((_ iand k) x x) ---> x
402 return RewriteResponse(REWRITE_DONE
, t
[0]);
404 // simplifications involving constants
405 for (unsigned i
= 0; i
< 2; i
++)
411 if (t
[i
].getConst
<Rational
>().sgn() == 0)
413 // ((_ iand k) 0 y) ---> 0
414 return RewriteResponse(REWRITE_DONE
, t
[i
]);
417 return RewriteResponse(REWRITE_DONE
, t
);
420 RewriteResponse
ArithRewriter::preRewriteTranscendental(TNode t
) {
421 return RewriteResponse(REWRITE_DONE
, t
);
424 RewriteResponse
ArithRewriter::postRewriteTranscendental(TNode t
) {
425 Trace("arith-tf-rewrite") << "Rewrite transcendental function : " << t
<< std::endl
;
426 NodeManager
* nm
= NodeManager::currentNM();
427 switch( t
.getKind() ){
428 case kind::EXPONENTIAL
: {
429 if(t
[0].getKind() == kind::CONST_RATIONAL
){
430 Node one
= nm
->mkConst(Rational(1));
431 if(t
[0].getConst
<Rational
>().sgn()>=0 && t
[0].getType().isInteger() && t
[0]!=one
){
432 return RewriteResponse(
434 nm
->mkNode(kind::POW
, nm
->mkNode(kind::EXPONENTIAL
, one
), t
[0]));
436 return RewriteResponse(REWRITE_DONE
, t
);
439 else if (t
[0].getKind() == kind::PLUS
)
441 std::vector
<Node
> product
;
442 for (const Node tc
: t
[0])
444 product
.push_back(nm
->mkNode(kind::EXPONENTIAL
, tc
));
446 // We need to do a full rewrite here, since we can get exponentials of
447 // constants, e.g. when we are rewriting exp(2 + x)
448 return RewriteResponse(REWRITE_AGAIN_FULL
,
449 nm
->mkNode(kind::MULT
, product
));
454 if(t
[0].getKind() == kind::CONST_RATIONAL
){
455 const Rational
& rat
= t
[0].getConst
<Rational
>();
457 return RewriteResponse(REWRITE_DONE
, nm
->mkConst(Rational(0)));
459 else if (rat
.sgn() == -1)
462 nm
->mkNode(kind::UMINUS
, nm
->mkNode(kind::SINE
, nm
->mkConst(-rat
)));
463 return RewriteResponse(REWRITE_AGAIN_FULL
, ret
);
466 // get the factor of PI in the argument
470 std::map
<Node
, Node
> msum
;
471 if (ArithMSum::getMonomialSum(t
[0], msum
))
474 std::map
<Node
, Node
>::iterator itm
= msum
.find(pi
);
475 if (itm
!= msum
.end())
477 if (itm
->second
.isNull())
479 pi_factor
= mkRationalNode(Rational(1));
483 pi_factor
= itm
->second
;
488 rem
= ArithMSum::mkNode(msum
);
497 // if there is a factor of PI
498 if( !pi_factor
.isNull() ){
499 Trace("arith-tf-rewrite-debug") << "Process pi factor = " << pi_factor
<< std::endl
;
500 Rational r
= pi_factor
.getConst
<Rational
>();
501 Rational r_abs
= r
.abs();
502 Rational rone
= Rational(1);
503 Node ntwo
= mkRationalNode(Rational(2));
506 //add/substract 2*pi beyond scope
507 Node ra_div_two
= nm
->mkNode(
508 kind::INTS_DIVISION
, mkRationalNode(r_abs
+ rone
), ntwo
);
512 nm
->mkNode(kind::MINUS
,
514 nm
->mkNode(kind::MULT
, ntwo
, ra_div_two
));
516 Assert(r
.sgn() == -1);
518 nm
->mkNode(kind::PLUS
,
520 nm
->mkNode(kind::MULT
, ntwo
, ra_div_two
));
522 Node new_arg
= nm
->mkNode(kind::MULT
, new_pi_factor
, pi
);
525 new_arg
= nm
->mkNode(kind::PLUS
, new_arg
, rem
);
527 // sin( 2*n*PI + x ) = sin( x )
528 return RewriteResponse(REWRITE_AGAIN_FULL
,
529 nm
->mkNode(kind::SINE
, new_arg
));
531 else if (r_abs
== rone
)
533 // sin( PI + x ) = -sin( x )
536 return RewriteResponse(REWRITE_DONE
, mkRationalNode(Rational(0)));
540 return RewriteResponse(
542 nm
->mkNode(kind::UMINUS
, nm
->mkNode(kind::SINE
, rem
)));
545 else if (rem
.isNull())
547 // other rational cases based on Niven's theorem
548 // (https://en.wikipedia.org/wiki/Niven%27s_theorem)
549 Integer one
= Integer(1);
550 Integer two
= Integer(2);
551 Integer six
= Integer(6);
552 if (r_abs
.getDenominator() == two
)
554 Assert(r_abs
.getNumerator() == one
);
555 return RewriteResponse(REWRITE_DONE
,
556 mkRationalNode(Rational(r
.sgn())));
558 else if (r_abs
.getDenominator() == six
)
560 Integer five
= Integer(5);
561 if (r_abs
.getNumerator() == one
|| r_abs
.getNumerator() == five
)
563 return RewriteResponse(
565 mkRationalNode(Rational(r
.sgn()) / Rational(2)));
573 return RewriteResponse(
575 nm
->mkNode(kind::SINE
,
576 nm
->mkNode(kind::MINUS
,
577 nm
->mkNode(kind::MULT
,
578 nm
->mkConst(Rational(1) / Rational(2)),
585 return RewriteResponse(REWRITE_AGAIN_FULL
,
586 nm
->mkNode(kind::DIVISION
,
587 nm
->mkNode(kind::SINE
, t
[0]),
588 nm
->mkNode(kind::COSINE
, t
[0])));
593 return RewriteResponse(REWRITE_AGAIN_FULL
,
594 nm
->mkNode(kind::DIVISION
,
595 mkRationalNode(Rational(1)),
596 nm
->mkNode(kind::SINE
, t
[0])));
601 return RewriteResponse(REWRITE_AGAIN_FULL
,
602 nm
->mkNode(kind::DIVISION
,
603 mkRationalNode(Rational(1)),
604 nm
->mkNode(kind::COSINE
, t
[0])));
607 case kind::COTANGENT
:
609 return RewriteResponse(REWRITE_AGAIN_FULL
,
610 nm
->mkNode(kind::DIVISION
,
611 nm
->mkNode(kind::COSINE
, t
[0]),
612 nm
->mkNode(kind::SINE
, t
[0])));
618 return RewriteResponse(REWRITE_DONE
, t
);
621 RewriteResponse
ArithRewriter::postRewriteAtom(TNode atom
){
622 if(atom
.getKind() == kind::IS_INTEGER
) {
623 if(atom
[0].isConst()) {
624 return RewriteResponse(REWRITE_DONE
, NodeManager::currentNM()->mkConst(atom
[0].getConst
<Rational
>().isIntegral()));
626 if(atom
[0].getType().isInteger()) {
627 return RewriteResponse(REWRITE_DONE
, NodeManager::currentNM()->mkConst(true));
629 // not supported, but this isn't the right place to complain
630 return RewriteResponse(REWRITE_DONE
, atom
);
631 } else if(atom
.getKind() == kind::DIVISIBLE
) {
632 if(atom
[0].isConst()) {
633 return RewriteResponse(REWRITE_DONE
, NodeManager::currentNM()->mkConst(bool((atom
[0].getConst
<Rational
>() / atom
.getOperator().getConst
<Divisible
>().k
).isIntegral())));
635 if(atom
.getOperator().getConst
<Divisible
>().k
.isOne()) {
636 return RewriteResponse(REWRITE_DONE
, NodeManager::currentNM()->mkConst(true));
638 return RewriteResponse(REWRITE_AGAIN
, NodeManager::currentNM()->mkNode(kind::EQUAL
, NodeManager::currentNM()->mkNode(kind::INTS_MODULUS_TOTAL
, atom
[0], NodeManager::currentNM()->mkConst(Rational(atom
.getOperator().getConst
<Divisible
>().k
))), NodeManager::currentNM()->mkConst(Rational(0))));
642 TNode left
= atom
[0];
643 TNode right
= atom
[1];
645 Polynomial pleft
= Polynomial::parsePolynomial(left
);
646 Polynomial pright
= Polynomial::parsePolynomial(right
);
648 Debug("arith::rewriter") << "pleft " << pleft
.getNode() << std::endl
;
649 Debug("arith::rewriter") << "pright " << pright
.getNode() << std::endl
;
651 Comparison cmp
= Comparison::mkComparison(atom
.getKind(), pleft
, pright
);
652 Assert(cmp
.isNormalForm());
653 return RewriteResponse(REWRITE_DONE
, cmp
.getNode());
656 RewriteResponse
ArithRewriter::preRewriteAtom(TNode atom
){
657 Assert(isAtom(atom
));
659 NodeManager
* currNM
= NodeManager::currentNM();
661 if(atom
.getKind() == kind::EQUAL
) {
662 if(atom
[0] == atom
[1]) {
663 return RewriteResponse(REWRITE_DONE
, currNM
->mkConst(true));
665 }else if(atom
.getKind() == kind::GT
){
666 Node leq
= currNM
->mkNode(kind::LEQ
, atom
[0], atom
[1]);
667 return RewriteResponse(REWRITE_DONE
, currNM
->mkNode(kind::NOT
, leq
));
668 }else if(atom
.getKind() == kind::LT
){
669 Node geq
= currNM
->mkNode(kind::GEQ
, atom
[0], atom
[1]);
670 return RewriteResponse(REWRITE_DONE
, currNM
->mkNode(kind::NOT
, geq
));
671 }else if(atom
.getKind() == kind::IS_INTEGER
){
672 if(atom
[0].getType().isInteger()){
673 return RewriteResponse(REWRITE_DONE
, currNM
->mkConst(true));
675 }else if(atom
.getKind() == kind::DIVISIBLE
){
676 if(atom
.getOperator().getConst
<Divisible
>().k
.isOne()){
677 return RewriteResponse(REWRITE_DONE
, currNM
->mkConst(true));
681 return RewriteResponse(REWRITE_DONE
, atom
);
684 RewriteResponse
ArithRewriter::postRewrite(TNode t
){
686 RewriteResponse response
= postRewriteTerm(t
);
687 if (Debug
.isOn("arith::rewriter") && response
.d_status
== REWRITE_DONE
)
689 Polynomial::parsePolynomial(response
.d_node
);
693 RewriteResponse response
= postRewriteAtom(t
);
694 if (Debug
.isOn("arith::rewriter") && response
.d_status
== REWRITE_DONE
)
696 Comparison::parseNormalForm(response
.d_node
);
704 RewriteResponse
ArithRewriter::preRewrite(TNode t
){
706 return preRewriteTerm(t
);
708 return preRewriteAtom(t
);
714 Node
ArithRewriter::makeUnaryMinusNode(TNode n
){
715 Rational
qNegOne(-1);
716 return NodeManager::currentNM()->mkNode(kind::MULT
, mkRationalNode(qNegOne
),n
);
719 Node
ArithRewriter::makeSubtractionNode(TNode l
, TNode r
){
720 Node negR
= makeUnaryMinusNode(r
);
721 Node diff
= NodeManager::currentNM()->mkNode(kind::PLUS
, l
, negR
);
726 RewriteResponse
ArithRewriter::rewriteDiv(TNode t
, bool pre
){
727 Assert(t
.getKind() == kind::DIVISION_TOTAL
|| t
.getKind() == kind::DIVISION
);
731 if(right
.getKind() == kind::CONST_RATIONAL
){
732 const Rational
& den
= right
.getConst
<Rational
>();
735 if(t
.getKind() == kind::DIVISION_TOTAL
){
736 return RewriteResponse(REWRITE_DONE
, mkRationalNode(0));
738 // This is unsupported, but this is not a good place to complain
739 return RewriteResponse(REWRITE_DONE
, t
);
742 Assert(den
!= Rational(0));
744 if(left
.getKind() == kind::CONST_RATIONAL
){
745 const Rational
& num
= left
.getConst
<Rational
>();
746 Rational div
= num
/ den
;
747 Node result
= mkRationalNode(div
);
748 return RewriteResponse(REWRITE_DONE
, result
);
751 Rational div
= den
.inverse();
753 Node result
= mkRationalNode(div
);
755 Node mult
= NodeManager::currentNM()->mkNode(kind::MULT
,left
,result
);
757 return RewriteResponse(REWRITE_DONE
, mult
);
759 return RewriteResponse(REWRITE_AGAIN
, mult
);
762 return RewriteResponse(REWRITE_DONE
, t
);
766 RewriteResponse
ArithRewriter::rewriteIntsDivMod(TNode t
, bool pre
)
768 NodeManager
* nm
= NodeManager::currentNM();
769 Kind k
= t
.getKind();
770 Node zero
= nm
->mkConst(Rational(0));
771 if (k
== kind::INTS_MODULUS
)
773 if (t
[1].isConst() && !t
[1].getConst
<Rational
>().isZero())
775 // can immediately replace by INTS_MODULUS_TOTAL
776 Node ret
= nm
->mkNode(kind::INTS_MODULUS_TOTAL
, t
[0], t
[1]);
777 return returnRewrite(t
, ret
, Rewrite::MOD_TOTAL_BY_CONST
);
780 if (k
== kind::INTS_DIVISION
)
782 if (t
[1].isConst() && !t
[1].getConst
<Rational
>().isZero())
784 // can immediately replace by INTS_DIVISION_TOTAL
785 Node ret
= nm
->mkNode(kind::INTS_DIVISION_TOTAL
, t
[0], t
[1]);
786 return returnRewrite(t
, ret
, Rewrite::DIV_TOTAL_BY_CONST
);
789 return RewriteResponse(REWRITE_DONE
, t
);
792 RewriteResponse
ArithRewriter::rewriteIntsDivModTotal(TNode t
, bool pre
)
796 // do not rewrite at prewrite.
797 return RewriteResponse(REWRITE_DONE
, t
);
799 NodeManager
* nm
= NodeManager::currentNM();
800 Kind k
= t
.getKind();
801 Assert(k
== kind::INTS_MODULUS_TOTAL
|| k
== kind::INTS_DIVISION_TOTAL
);
804 bool dIsConstant
= d
.getKind() == kind::CONST_RATIONAL
;
805 if(dIsConstant
&& d
.getConst
<Rational
>().isZero()){
806 // (div x 0) ---> 0 or (mod x 0) ---> 0
807 return returnRewrite(t
, mkRationalNode(0), Rewrite::DIV_MOD_BY_ZERO
);
808 }else if(dIsConstant
&& d
.getConst
<Rational
>().isOne()){
809 if (k
== kind::INTS_MODULUS_TOTAL
)
812 return returnRewrite(t
, mkRationalNode(0), Rewrite::MOD_BY_ONE
);
814 Assert(k
== kind::INTS_DIVISION_TOTAL
);
816 return returnRewrite(t
, n
, Rewrite::DIV_BY_ONE
);
818 else if (dIsConstant
&& d
.getConst
<Rational
>().sgn() < 0)
821 // (div x (- c)) ---> (- (div x c))
822 // (mod x (- c)) ---> (mod x c)
823 Node nn
= nm
->mkNode(k
, t
[0], nm
->mkConst(-t
[1].getConst
<Rational
>()));
824 Node ret
= (k
== kind::INTS_DIVISION
|| k
== kind::INTS_DIVISION_TOTAL
)
825 ? nm
->mkNode(kind::UMINUS
, nn
)
827 return returnRewrite(t
, ret
, Rewrite::DIV_MOD_PULL_NEG_DEN
);
829 else if (dIsConstant
&& n
.getKind() == kind::CONST_RATIONAL
)
831 Assert(d
.getConst
<Rational
>().isIntegral());
832 Assert(n
.getConst
<Rational
>().isIntegral());
833 Assert(!d
.getConst
<Rational
>().isZero());
834 Integer di
= d
.getConst
<Rational
>().getNumerator();
835 Integer ni
= n
.getConst
<Rational
>().getNumerator();
837 bool isDiv
= (k
== kind::INTS_DIVISION
|| k
== kind::INTS_DIVISION_TOTAL
);
839 Integer result
= isDiv
? ni
.euclidianDivideQuotient(di
) : ni
.euclidianDivideRemainder(di
);
841 // constant evaluation
842 // (mod c1 c2) ---> c3 or (div c1 c2) ---> c3
843 Node resultNode
= mkRationalNode(Rational(result
));
844 return returnRewrite(t
, resultNode
, Rewrite::CONST_EVAL
);
846 if (k
== kind::INTS_MODULUS_TOTAL
)
848 // Note these rewrites do not need to account for modulus by zero as being
849 // a UF, which is handled by the reduction of INTS_MODULUS.
850 Kind k0
= t
[0].getKind();
851 if (k0
== kind::INTS_MODULUS_TOTAL
&& t
[0][1] == t
[1])
853 // (mod (mod x c) c) --> (mod x c)
854 return returnRewrite(t
, t
[0], Rewrite::MOD_OVER_MOD
);
856 else if (k0
== kind::NONLINEAR_MULT
|| k0
== kind::MULT
|| k0
== kind::PLUS
)
859 std::vector
<Node
> newChildren
;
860 bool childChanged
= false;
861 for (const Node
& tc
: t
[0])
863 if (tc
.getKind() == kind::INTS_MODULUS_TOTAL
&& tc
[1] == t
[1])
865 newChildren
.push_back(tc
[0]);
869 newChildren
.push_back(tc
);
873 // (mod (op ... (mod x c) ...) c) ---> (mod (op ... x ...) c) where
874 // op is one of { NONLINEAR_MULT, MULT, PLUS }.
875 Node ret
= nm
->mkNode(k0
, newChildren
);
876 ret
= nm
->mkNode(kind::INTS_MODULUS_TOTAL
, ret
, t
[1]);
877 return returnRewrite(t
, ret
, Rewrite::MOD_CHILD_MOD
);
883 Assert(k
== kind::INTS_DIVISION_TOTAL
);
884 // Note these rewrites do not need to account for division by zero as being
885 // a UF, which is handled by the reduction of INTS_DIVISION.
886 if (t
[0].getKind() == kind::INTS_MODULUS_TOTAL
&& t
[0][1] == t
[1])
888 // (div (mod x c) c) --> 0
889 Node ret
= mkRationalNode(0);
890 return returnRewrite(t
, ret
, Rewrite::DIV_OVER_MOD
);
893 return RewriteResponse(REWRITE_DONE
, t
);
896 RewriteResponse
ArithRewriter::returnRewrite(TNode t
, Node ret
, Rewrite r
)
898 Trace("arith-rewrite") << "ArithRewriter : " << t
<< " == " << ret
<< " by "
900 return RewriteResponse(REWRITE_AGAIN_FULL
, ret
);
904 } // namespace theory