1 /******************************************************************************
2 * Top contributors (to current version):
3 * Andrew Reynolds, Tim King, Morgan Deters
5 * This file is part of the cvc5 project.
7 * Copyright (c) 2009-2021 by the authors listed in the file AUTHORS
8 * in the top-level source directory and their institutional affiliations.
9 * All rights reserved. See the file COPYING in the top-level source
10 * directory for licensing information.
11 * ****************************************************************************
13 * [[ Add one-line brief description here ]]
15 * [[ Add lengthier description here ]]
16 * \todo document this file
24 #include "smt/logic_exception.h"
25 #include "theory/arith/arith_msum.h"
26 #include "theory/arith/arith_rewriter.h"
27 #include "theory/arith/arith_utilities.h"
28 #include "theory/arith/normal_form.h"
29 #include "theory/arith/operator_elim.h"
30 #include "theory/theory.h"
31 #include "util/iand.h"
37 ArithRewriter::ArithRewriter(OperatorElim
& oe
) : d_opElim(oe
) {}
39 bool ArithRewriter::isAtom(TNode n
) {
41 return arith::isRelationOperator(k
) || k
== kind::IS_INTEGER
42 || k
== kind::DIVISIBLE
;
45 RewriteResponse
ArithRewriter::rewriteConstant(TNode t
){
47 Assert(t
.getKind() == kind::CONST_RATIONAL
);
49 return RewriteResponse(REWRITE_DONE
, t
);
52 RewriteResponse
ArithRewriter::rewriteVariable(TNode t
){
55 return RewriteResponse(REWRITE_DONE
, t
);
58 RewriteResponse
ArithRewriter::rewriteMinus(TNode t
, bool pre
){
59 Assert(t
.getKind() == kind::MINUS
);
64 Node zeroNode
= mkRationalNode(zero
);
65 return RewriteResponse(REWRITE_DONE
, zeroNode
);
67 Node noMinus
= makeSubtractionNode(t
[0],t
[1]);
68 return RewriteResponse(REWRITE_DONE
, noMinus
);
71 Polynomial minuend
= Polynomial::parsePolynomial(t
[0]);
72 Polynomial subtrahend
= Polynomial::parsePolynomial(t
[1]);
73 Polynomial diff
= minuend
- subtrahend
;
74 return RewriteResponse(REWRITE_DONE
, diff
.getNode());
78 RewriteResponse
ArithRewriter::rewriteUMinus(TNode t
, bool pre
){
79 Assert(t
.getKind() == kind::UMINUS
);
81 if(t
[0].getKind() == kind::CONST_RATIONAL
){
82 Rational neg
= -(t
[0].getConst
<Rational
>());
83 return RewriteResponse(REWRITE_DONE
, mkRationalNode(neg
));
86 Node noUminus
= makeUnaryMinusNode(t
[0]);
88 return RewriteResponse(REWRITE_DONE
, noUminus
);
90 return RewriteResponse(REWRITE_AGAIN
, noUminus
);
93 RewriteResponse
ArithRewriter::preRewriteTerm(TNode t
){
95 return rewriteConstant(t
);
97 return rewriteVariable(t
);
99 switch(Kind k
= t
.getKind()){
101 return rewriteMinus(t
, true);
103 return rewriteUMinus(t
, true);
105 case kind::DIVISION_TOTAL
:
106 return rewriteDiv(t
,true);
108 return preRewritePlus(t
);
110 case kind::NONLINEAR_MULT
: return preRewriteMult(t
);
111 case kind::IAND
: return RewriteResponse(REWRITE_DONE
, t
);
112 case kind::EXPONENTIAL
:
118 case kind::COTANGENT
:
120 case kind::ARCCOSINE
:
121 case kind::ARCTANGENT
:
122 case kind::ARCCOSECANT
:
123 case kind::ARCSECANT
:
124 case kind::ARCCOTANGENT
:
125 case kind::SQRT
: return preRewriteTranscendental(t
);
126 case kind::INTS_DIVISION
:
127 case kind::INTS_MODULUS
: return rewriteIntsDivMod(t
, true);
128 case kind::INTS_DIVISION_TOTAL
:
129 case kind::INTS_MODULUS_TOTAL
:
130 return rewriteIntsDivModTotal(t
,true);
133 const Rational
& rat
= t
[0].getConst
<Rational
>();
135 return RewriteResponse(REWRITE_DONE
, t
[0]);
137 return RewriteResponse(REWRITE_DONE
,
138 NodeManager::currentNM()->mkConst(-rat
));
141 return RewriteResponse(REWRITE_DONE
, t
);
142 case kind::IS_INTEGER
:
143 case kind::TO_INTEGER
:
144 return RewriteResponse(REWRITE_DONE
, t
);
146 case kind::CAST_TO_REAL
: return RewriteResponse(REWRITE_DONE
, t
[0]);
148 return RewriteResponse(REWRITE_DONE
, t
);
150 return RewriteResponse(REWRITE_DONE
, t
);
151 default: Unhandled() << k
;
156 RewriteResponse
ArithRewriter::postRewriteTerm(TNode t
){
158 return rewriteConstant(t
);
160 return rewriteVariable(t
);
164 return rewriteMinus(t
, false);
166 return rewriteUMinus(t
, false);
168 case kind::DIVISION_TOTAL
:
169 return rewriteDiv(t
, false);
171 return postRewritePlus(t
);
173 case kind::NONLINEAR_MULT
: return postRewriteMult(t
);
174 case kind::IAND
: return postRewriteIAnd(t
);
175 case kind::EXPONENTIAL
:
181 case kind::COTANGENT
:
183 case kind::ARCCOSINE
:
184 case kind::ARCTANGENT
:
185 case kind::ARCCOSECANT
:
186 case kind::ARCSECANT
:
187 case kind::ARCCOTANGENT
:
188 case kind::SQRT
: return postRewriteTranscendental(t
);
189 case kind::INTS_DIVISION
:
190 case kind::INTS_MODULUS
: return rewriteIntsDivMod(t
, false);
191 case kind::INTS_DIVISION_TOTAL
:
192 case kind::INTS_MODULUS_TOTAL
:
193 return rewriteIntsDivModTotal(t
, false);
196 const Rational
& rat
= t
[0].getConst
<Rational
>();
198 return RewriteResponse(REWRITE_DONE
, t
[0]);
200 return RewriteResponse(REWRITE_DONE
,
201 NodeManager::currentNM()->mkConst(-rat
));
204 return RewriteResponse(REWRITE_DONE
, t
);
206 case kind::CAST_TO_REAL
: return RewriteResponse(REWRITE_DONE
, t
[0]);
207 case kind::TO_INTEGER
:
209 return RewriteResponse(REWRITE_DONE
, NodeManager::currentNM()->mkConst(Rational(t
[0].getConst
<Rational
>().floor())));
211 if(t
[0].getType().isInteger()) {
212 return RewriteResponse(REWRITE_DONE
, t
[0]);
214 //Unimplemented() << "TO_INTEGER, nonconstant";
215 //return rewriteToInteger(t);
216 return RewriteResponse(REWRITE_DONE
, t
);
217 case kind::IS_INTEGER
:
219 return RewriteResponse(REWRITE_DONE
, NodeManager::currentNM()->mkConst(t
[0].getConst
<Rational
>().getDenominator() == 1));
221 if(t
[0].getType().isInteger()) {
222 return RewriteResponse(REWRITE_DONE
, NodeManager::currentNM()->mkConst(true));
224 //Unimplemented() << "IS_INTEGER, nonconstant";
225 //return rewriteIsInteger(t);
226 return RewriteResponse(REWRITE_DONE
, t
);
229 if(t
[1].getKind() == kind::CONST_RATIONAL
){
230 const Rational
& exp
= t
[1].getConst
<Rational
>();
233 return RewriteResponse(REWRITE_DONE
, mkRationalNode(Rational(1)));
234 }else if(exp
.sgn() > 0 && exp
.isIntegral()){
235 cvc5::Rational
r(expr::NodeValue::MAX_CHILDREN
);
238 unsigned num
= exp
.getNumerator().toUnsignedInt();
240 return RewriteResponse(REWRITE_AGAIN
, base
);
242 NodeBuilder
nb(kind::MULT
);
243 for(unsigned i
=0; i
< num
; ++i
){
246 Assert(nb
.getNumChildren() > 0);
248 return RewriteResponse(REWRITE_AGAIN
, mult
);
254 // Todo improve the exception thrown
255 std::stringstream ss
;
256 ss
<< "The exponent of the POW(^) operator can only be a positive "
257 "integral constant below "
258 << (expr::NodeValue::MAX_CHILDREN
+ 1) << ". ";
259 ss
<< "Exception occurred in:" << std::endl
;
261 throw LogicException(ss
.str());
264 return RewriteResponse(REWRITE_DONE
, t
);
272 RewriteResponse
ArithRewriter::preRewriteMult(TNode t
){
273 Assert(t
.getKind() == kind::MULT
|| t
.getKind() == kind::NONLINEAR_MULT
);
275 if(t
.getNumChildren() == 2){
276 if(t
[0].getKind() == kind::CONST_RATIONAL
277 && t
[0].getConst
<Rational
>().isOne()){
278 return RewriteResponse(REWRITE_DONE
, t
[1]);
280 if(t
[1].getKind() == kind::CONST_RATIONAL
281 && t
[1].getConst
<Rational
>().isOne()){
282 return RewriteResponse(REWRITE_DONE
, t
[0]);
286 // Rewrite multiplications with a 0 argument and to 0
287 for(TNode::iterator i
= t
.begin(); i
!= t
.end(); ++i
) {
288 if((*i
).getKind() == kind::CONST_RATIONAL
) {
289 if((*i
).getConst
<Rational
>().isZero()) {
291 return RewriteResponse(REWRITE_DONE
, zero
);
295 return RewriteResponse(REWRITE_DONE
, t
);
298 static bool canFlatten(Kind k
, TNode t
){
299 for(TNode::iterator i
= t
.begin(); i
!= t
.end(); ++i
) {
301 if(child
.getKind() == k
){
308 static void flatten(std::vector
<TNode
>& pb
, Kind k
, TNode t
){
309 if(t
.getKind() == k
){
310 for(TNode::iterator i
= t
.begin(); i
!= t
.end(); ++i
) {
312 if(child
.getKind() == k
){
313 flatten(pb
, k
, child
);
323 static Node
flatten(Kind k
, TNode t
){
324 std::vector
<TNode
> pb
;
326 Assert(pb
.size() >= 2);
327 return NodeManager::currentNM()->mkNode(k
, pb
);
330 RewriteResponse
ArithRewriter::preRewritePlus(TNode t
){
331 Assert(t
.getKind() == kind::PLUS
);
333 if(canFlatten(kind::PLUS
, t
)){
334 return RewriteResponse(REWRITE_DONE
, flatten(kind::PLUS
, t
));
336 return RewriteResponse(REWRITE_DONE
, t
);
340 RewriteResponse
ArithRewriter::postRewritePlus(TNode t
){
341 Assert(t
.getKind() == kind::PLUS
);
343 std::vector
<Monomial
> monomials
;
344 std::vector
<Polynomial
> polynomials
;
346 for(TNode::iterator i
= t
.begin(), end
= t
.end(); i
!= end
; ++i
){
348 if(Monomial::isMember(curr
)){
349 monomials
.push_back(Monomial::parseMonomial(curr
));
351 polynomials
.push_back(Polynomial::parsePolynomial(curr
));
355 if(!monomials
.empty()){
356 Monomial::sort(monomials
);
357 Monomial::combineAdjacentMonomials(monomials
);
358 polynomials
.push_back(Polynomial::mkPolynomial(monomials
));
361 Polynomial res
= Polynomial::sumPolynomials(polynomials
);
363 return RewriteResponse(REWRITE_DONE
, res
.getNode());
366 RewriteResponse
ArithRewriter::postRewriteMult(TNode t
){
367 Assert(t
.getKind() == kind::MULT
|| t
.getKind() == kind::NONLINEAR_MULT
);
369 Polynomial res
= Polynomial::mkOne();
371 for(TNode::iterator i
= t
.begin(), end
= t
.end(); i
!= end
; ++i
){
373 Polynomial currPoly
= Polynomial::parsePolynomial(curr
);
375 res
= res
* currPoly
;
378 return RewriteResponse(REWRITE_DONE
, res
.getNode());
381 RewriteResponse
ArithRewriter::postRewriteIAnd(TNode t
)
383 Assert(t
.getKind() == kind::IAND
);
384 NodeManager
* nm
= NodeManager::currentNM();
385 // if constant, we eliminate
386 if (t
[0].isConst() && t
[1].isConst())
388 size_t bsize
= t
.getOperator().getConst
<IntAnd
>().d_size
;
389 Node iToBvop
= nm
->mkConst(IntToBitVector(bsize
));
390 Node arg1
= nm
->mkNode(kind::INT_TO_BITVECTOR
, iToBvop
, t
[0]);
391 Node arg2
= nm
->mkNode(kind::INT_TO_BITVECTOR
, iToBvop
, t
[1]);
392 Node bvand
= nm
->mkNode(kind::BITVECTOR_AND
, arg1
, arg2
);
393 Node ret
= nm
->mkNode(kind::BITVECTOR_TO_NAT
, bvand
);
394 return RewriteResponse(REWRITE_AGAIN_FULL
, ret
);
396 else if (t
[0] > t
[1])
398 // ((_ iand k) x y) ---> ((_ iand k) y x) if x > y by node ordering
399 Node ret
= nm
->mkNode(kind::IAND
, t
.getOperator(), t
[1], t
[0]);
400 return RewriteResponse(REWRITE_AGAIN
, ret
);
402 else if (t
[0] == t
[1])
404 // ((_ iand k) x x) ---> x
405 return RewriteResponse(REWRITE_DONE
, t
[0]);
407 // simplifications involving constants
408 for (unsigned i
= 0; i
< 2; i
++)
414 if (t
[i
].getConst
<Rational
>().sgn() == 0)
416 // ((_ iand k) 0 y) ---> 0
417 return RewriteResponse(REWRITE_DONE
, t
[i
]);
420 return RewriteResponse(REWRITE_DONE
, t
);
423 RewriteResponse
ArithRewriter::preRewriteTranscendental(TNode t
) {
424 return RewriteResponse(REWRITE_DONE
, t
);
427 RewriteResponse
ArithRewriter::postRewriteTranscendental(TNode t
) {
428 Trace("arith-tf-rewrite") << "Rewrite transcendental function : " << t
<< std::endl
;
429 NodeManager
* nm
= NodeManager::currentNM();
430 switch( t
.getKind() ){
431 case kind::EXPONENTIAL
: {
432 if(t
[0].getKind() == kind::CONST_RATIONAL
){
433 Node one
= nm
->mkConst(Rational(1));
434 if(t
[0].getConst
<Rational
>().sgn()>=0 && t
[0].getType().isInteger() && t
[0]!=one
){
435 return RewriteResponse(
437 nm
->mkNode(kind::POW
, nm
->mkNode(kind::EXPONENTIAL
, one
), t
[0]));
439 return RewriteResponse(REWRITE_DONE
, t
);
442 else if (t
[0].getKind() == kind::PLUS
)
444 std::vector
<Node
> product
;
445 for (const Node tc
: t
[0])
447 product
.push_back(nm
->mkNode(kind::EXPONENTIAL
, tc
));
449 // We need to do a full rewrite here, since we can get exponentials of
450 // constants, e.g. when we are rewriting exp(2 + x)
451 return RewriteResponse(REWRITE_AGAIN_FULL
,
452 nm
->mkNode(kind::MULT
, product
));
457 if(t
[0].getKind() == kind::CONST_RATIONAL
){
458 const Rational
& rat
= t
[0].getConst
<Rational
>();
460 return RewriteResponse(REWRITE_DONE
, nm
->mkConst(Rational(0)));
462 else if (rat
.sgn() == -1)
465 nm
->mkNode(kind::UMINUS
, nm
->mkNode(kind::SINE
, nm
->mkConst(-rat
)));
466 return RewriteResponse(REWRITE_AGAIN_FULL
, ret
);
469 // get the factor of PI in the argument
473 std::map
<Node
, Node
> msum
;
474 if (ArithMSum::getMonomialSum(t
[0], msum
))
477 std::map
<Node
, Node
>::iterator itm
= msum
.find(pi
);
478 if (itm
!= msum
.end())
480 if (itm
->second
.isNull())
482 pi_factor
= mkRationalNode(Rational(1));
486 pi_factor
= itm
->second
;
491 rem
= ArithMSum::mkNode(msum
);
500 // if there is a factor of PI
501 if( !pi_factor
.isNull() ){
502 Trace("arith-tf-rewrite-debug") << "Process pi factor = " << pi_factor
<< std::endl
;
503 Rational r
= pi_factor
.getConst
<Rational
>();
504 Rational r_abs
= r
.abs();
505 Rational rone
= Rational(1);
506 Node ntwo
= mkRationalNode(Rational(2));
509 //add/substract 2*pi beyond scope
510 Node ra_div_two
= nm
->mkNode(
511 kind::INTS_DIVISION
, mkRationalNode(r_abs
+ rone
), ntwo
);
515 nm
->mkNode(kind::MINUS
,
517 nm
->mkNode(kind::MULT
, ntwo
, ra_div_two
));
519 Assert(r
.sgn() == -1);
521 nm
->mkNode(kind::PLUS
,
523 nm
->mkNode(kind::MULT
, ntwo
, ra_div_two
));
525 Node new_arg
= nm
->mkNode(kind::MULT
, new_pi_factor
, pi
);
528 new_arg
= nm
->mkNode(kind::PLUS
, new_arg
, rem
);
530 // sin( 2*n*PI + x ) = sin( x )
531 return RewriteResponse(REWRITE_AGAIN_FULL
,
532 nm
->mkNode(kind::SINE
, new_arg
));
534 else if (r_abs
== rone
)
536 // sin( PI + x ) = -sin( x )
539 return RewriteResponse(REWRITE_DONE
, mkRationalNode(Rational(0)));
543 return RewriteResponse(
545 nm
->mkNode(kind::UMINUS
, nm
->mkNode(kind::SINE
, rem
)));
548 else if (rem
.isNull())
550 // other rational cases based on Niven's theorem
551 // (https://en.wikipedia.org/wiki/Niven%27s_theorem)
552 Integer one
= Integer(1);
553 Integer two
= Integer(2);
554 Integer six
= Integer(6);
555 if (r_abs
.getDenominator() == two
)
557 Assert(r_abs
.getNumerator() == one
);
558 return RewriteResponse(REWRITE_DONE
,
559 mkRationalNode(Rational(r
.sgn())));
561 else if (r_abs
.getDenominator() == six
)
563 Integer five
= Integer(5);
564 if (r_abs
.getNumerator() == one
|| r_abs
.getNumerator() == five
)
566 return RewriteResponse(
568 mkRationalNode(Rational(r
.sgn()) / Rational(2)));
576 return RewriteResponse(
578 nm
->mkNode(kind::SINE
,
579 nm
->mkNode(kind::MINUS
,
580 nm
->mkNode(kind::MULT
,
581 nm
->mkConst(Rational(1) / Rational(2)),
588 return RewriteResponse(REWRITE_AGAIN_FULL
,
589 nm
->mkNode(kind::DIVISION
,
590 nm
->mkNode(kind::SINE
, t
[0]),
591 nm
->mkNode(kind::COSINE
, t
[0])));
596 return RewriteResponse(REWRITE_AGAIN_FULL
,
597 nm
->mkNode(kind::DIVISION
,
598 mkRationalNode(Rational(1)),
599 nm
->mkNode(kind::SINE
, t
[0])));
604 return RewriteResponse(REWRITE_AGAIN_FULL
,
605 nm
->mkNode(kind::DIVISION
,
606 mkRationalNode(Rational(1)),
607 nm
->mkNode(kind::COSINE
, t
[0])));
610 case kind::COTANGENT
:
612 return RewriteResponse(REWRITE_AGAIN_FULL
,
613 nm
->mkNode(kind::DIVISION
,
614 nm
->mkNode(kind::COSINE
, t
[0]),
615 nm
->mkNode(kind::SINE
, t
[0])));
621 return RewriteResponse(REWRITE_DONE
, t
);
624 RewriteResponse
ArithRewriter::postRewriteAtom(TNode atom
){
625 if(atom
.getKind() == kind::IS_INTEGER
) {
626 if(atom
[0].isConst()) {
627 return RewriteResponse(REWRITE_DONE
, NodeManager::currentNM()->mkConst(atom
[0].getConst
<Rational
>().isIntegral()));
629 if(atom
[0].getType().isInteger()) {
630 return RewriteResponse(REWRITE_DONE
, NodeManager::currentNM()->mkConst(true));
632 // not supported, but this isn't the right place to complain
633 return RewriteResponse(REWRITE_DONE
, atom
);
634 } else if(atom
.getKind() == kind::DIVISIBLE
) {
635 if(atom
[0].isConst()) {
636 return RewriteResponse(REWRITE_DONE
, NodeManager::currentNM()->mkConst(bool((atom
[0].getConst
<Rational
>() / atom
.getOperator().getConst
<Divisible
>().k
).isIntegral())));
638 if(atom
.getOperator().getConst
<Divisible
>().k
.isOne()) {
639 return RewriteResponse(REWRITE_DONE
, NodeManager::currentNM()->mkConst(true));
641 return RewriteResponse(REWRITE_AGAIN
, NodeManager::currentNM()->mkNode(kind::EQUAL
, NodeManager::currentNM()->mkNode(kind::INTS_MODULUS_TOTAL
, atom
[0], NodeManager::currentNM()->mkConst(Rational(atom
.getOperator().getConst
<Divisible
>().k
))), NodeManager::currentNM()->mkConst(Rational(0))));
645 TNode left
= atom
[0];
646 TNode right
= atom
[1];
648 Polynomial pleft
= Polynomial::parsePolynomial(left
);
649 Polynomial pright
= Polynomial::parsePolynomial(right
);
651 Debug("arith::rewriter") << "pleft " << pleft
.getNode() << std::endl
;
652 Debug("arith::rewriter") << "pright " << pright
.getNode() << std::endl
;
654 Comparison cmp
= Comparison::mkComparison(atom
.getKind(), pleft
, pright
);
655 Assert(cmp
.isNormalForm());
656 return RewriteResponse(REWRITE_DONE
, cmp
.getNode());
659 RewriteResponse
ArithRewriter::preRewriteAtom(TNode atom
){
660 Assert(isAtom(atom
));
662 NodeManager
* currNM
= NodeManager::currentNM();
664 if(atom
.getKind() == kind::EQUAL
) {
665 if(atom
[0] == atom
[1]) {
666 return RewriteResponse(REWRITE_DONE
, currNM
->mkConst(true));
668 }else if(atom
.getKind() == kind::GT
){
669 Node leq
= currNM
->mkNode(kind::LEQ
, atom
[0], atom
[1]);
670 return RewriteResponse(REWRITE_DONE
, currNM
->mkNode(kind::NOT
, leq
));
671 }else if(atom
.getKind() == kind::LT
){
672 Node geq
= currNM
->mkNode(kind::GEQ
, atom
[0], atom
[1]);
673 return RewriteResponse(REWRITE_DONE
, currNM
->mkNode(kind::NOT
, geq
));
674 }else if(atom
.getKind() == kind::IS_INTEGER
){
675 if(atom
[0].getType().isInteger()){
676 return RewriteResponse(REWRITE_DONE
, currNM
->mkConst(true));
678 }else if(atom
.getKind() == kind::DIVISIBLE
){
679 if(atom
.getOperator().getConst
<Divisible
>().k
.isOne()){
680 return RewriteResponse(REWRITE_DONE
, currNM
->mkConst(true));
684 return RewriteResponse(REWRITE_DONE
, atom
);
687 RewriteResponse
ArithRewriter::postRewrite(TNode t
){
689 RewriteResponse response
= postRewriteTerm(t
);
690 if (Debug
.isOn("arith::rewriter") && response
.d_status
== REWRITE_DONE
)
692 Polynomial::parsePolynomial(response
.d_node
);
696 RewriteResponse response
= postRewriteAtom(t
);
697 if (Debug
.isOn("arith::rewriter") && response
.d_status
== REWRITE_DONE
)
699 Comparison::parseNormalForm(response
.d_node
);
707 RewriteResponse
ArithRewriter::preRewrite(TNode t
){
709 return preRewriteTerm(t
);
711 return preRewriteAtom(t
);
717 Node
ArithRewriter::makeUnaryMinusNode(TNode n
){
718 Rational
qNegOne(-1);
719 return NodeManager::currentNM()->mkNode(kind::MULT
, mkRationalNode(qNegOne
),n
);
722 Node
ArithRewriter::makeSubtractionNode(TNode l
, TNode r
){
723 Node negR
= makeUnaryMinusNode(r
);
724 Node diff
= NodeManager::currentNM()->mkNode(kind::PLUS
, l
, negR
);
729 RewriteResponse
ArithRewriter::rewriteDiv(TNode t
, bool pre
){
730 Assert(t
.getKind() == kind::DIVISION_TOTAL
|| t
.getKind() == kind::DIVISION
);
734 if(right
.getKind() == kind::CONST_RATIONAL
){
735 const Rational
& den
= right
.getConst
<Rational
>();
738 if(t
.getKind() == kind::DIVISION_TOTAL
){
739 return RewriteResponse(REWRITE_DONE
, mkRationalNode(0));
741 // This is unsupported, but this is not a good place to complain
742 return RewriteResponse(REWRITE_DONE
, t
);
745 Assert(den
!= Rational(0));
747 if(left
.getKind() == kind::CONST_RATIONAL
){
748 const Rational
& num
= left
.getConst
<Rational
>();
749 Rational div
= num
/ den
;
750 Node result
= mkRationalNode(div
);
751 return RewriteResponse(REWRITE_DONE
, result
);
754 Rational div
= den
.inverse();
756 Node result
= mkRationalNode(div
);
758 Node mult
= NodeManager::currentNM()->mkNode(kind::MULT
,left
,result
);
760 return RewriteResponse(REWRITE_DONE
, mult
);
762 return RewriteResponse(REWRITE_AGAIN
, mult
);
765 return RewriteResponse(REWRITE_DONE
, t
);
769 RewriteResponse
ArithRewriter::rewriteIntsDivMod(TNode t
, bool pre
)
771 NodeManager
* nm
= NodeManager::currentNM();
772 Kind k
= t
.getKind();
773 Node zero
= nm
->mkConst(Rational(0));
774 if (k
== kind::INTS_MODULUS
)
776 if (t
[1].isConst() && !t
[1].getConst
<Rational
>().isZero())
778 // can immediately replace by INTS_MODULUS_TOTAL
779 Node ret
= nm
->mkNode(kind::INTS_MODULUS_TOTAL
, t
[0], t
[1]);
780 return returnRewrite(t
, ret
, Rewrite::MOD_TOTAL_BY_CONST
);
783 if (k
== kind::INTS_DIVISION
)
785 if (t
[1].isConst() && !t
[1].getConst
<Rational
>().isZero())
787 // can immediately replace by INTS_DIVISION_TOTAL
788 Node ret
= nm
->mkNode(kind::INTS_DIVISION_TOTAL
, t
[0], t
[1]);
789 return returnRewrite(t
, ret
, Rewrite::DIV_TOTAL_BY_CONST
);
792 return RewriteResponse(REWRITE_DONE
, t
);
795 RewriteResponse
ArithRewriter::rewriteIntsDivModTotal(TNode t
, bool pre
)
799 // do not rewrite at prewrite.
800 return RewriteResponse(REWRITE_DONE
, t
);
802 NodeManager
* nm
= NodeManager::currentNM();
803 Kind k
= t
.getKind();
804 Assert(k
== kind::INTS_MODULUS_TOTAL
|| k
== kind::INTS_DIVISION_TOTAL
);
807 bool dIsConstant
= d
.getKind() == kind::CONST_RATIONAL
;
808 if(dIsConstant
&& d
.getConst
<Rational
>().isZero()){
809 // (div x 0) ---> 0 or (mod x 0) ---> 0
810 return returnRewrite(t
, mkRationalNode(0), Rewrite::DIV_MOD_BY_ZERO
);
811 }else if(dIsConstant
&& d
.getConst
<Rational
>().isOne()){
812 if (k
== kind::INTS_MODULUS_TOTAL
)
815 return returnRewrite(t
, mkRationalNode(0), Rewrite::MOD_BY_ONE
);
817 Assert(k
== kind::INTS_DIVISION_TOTAL
);
819 return returnRewrite(t
, n
, Rewrite::DIV_BY_ONE
);
821 else if (dIsConstant
&& d
.getConst
<Rational
>().sgn() < 0)
824 // (div x (- c)) ---> (- (div x c))
825 // (mod x (- c)) ---> (mod x c)
826 Node nn
= nm
->mkNode(k
, t
[0], nm
->mkConst(-t
[1].getConst
<Rational
>()));
827 Node ret
= (k
== kind::INTS_DIVISION
|| k
== kind::INTS_DIVISION_TOTAL
)
828 ? nm
->mkNode(kind::UMINUS
, nn
)
830 return returnRewrite(t
, ret
, Rewrite::DIV_MOD_PULL_NEG_DEN
);
832 else if (dIsConstant
&& n
.getKind() == kind::CONST_RATIONAL
)
834 Assert(d
.getConst
<Rational
>().isIntegral());
835 Assert(n
.getConst
<Rational
>().isIntegral());
836 Assert(!d
.getConst
<Rational
>().isZero());
837 Integer di
= d
.getConst
<Rational
>().getNumerator();
838 Integer ni
= n
.getConst
<Rational
>().getNumerator();
840 bool isDiv
= (k
== kind::INTS_DIVISION
|| k
== kind::INTS_DIVISION_TOTAL
);
842 Integer result
= isDiv
? ni
.euclidianDivideQuotient(di
) : ni
.euclidianDivideRemainder(di
);
844 // constant evaluation
845 // (mod c1 c2) ---> c3 or (div c1 c2) ---> c3
846 Node resultNode
= mkRationalNode(Rational(result
));
847 return returnRewrite(t
, resultNode
, Rewrite::CONST_EVAL
);
849 if (k
== kind::INTS_MODULUS_TOTAL
)
851 // Note these rewrites do not need to account for modulus by zero as being
852 // a UF, which is handled by the reduction of INTS_MODULUS.
853 Kind k0
= t
[0].getKind();
854 if (k0
== kind::INTS_MODULUS_TOTAL
&& t
[0][1] == t
[1])
856 // (mod (mod x c) c) --> (mod x c)
857 return returnRewrite(t
, t
[0], Rewrite::MOD_OVER_MOD
);
859 else if (k0
== kind::NONLINEAR_MULT
|| k0
== kind::MULT
|| k0
== kind::PLUS
)
862 std::vector
<Node
> newChildren
;
863 bool childChanged
= false;
864 for (const Node
& tc
: t
[0])
866 if (tc
.getKind() == kind::INTS_MODULUS_TOTAL
&& tc
[1] == t
[1])
868 newChildren
.push_back(tc
[0]);
872 newChildren
.push_back(tc
);
876 // (mod (op ... (mod x c) ...) c) ---> (mod (op ... x ...) c) where
877 // op is one of { NONLINEAR_MULT, MULT, PLUS }.
878 Node ret
= nm
->mkNode(k0
, newChildren
);
879 ret
= nm
->mkNode(kind::INTS_MODULUS_TOTAL
, ret
, t
[1]);
880 return returnRewrite(t
, ret
, Rewrite::MOD_CHILD_MOD
);
886 Assert(k
== kind::INTS_DIVISION_TOTAL
);
887 // Note these rewrites do not need to account for division by zero as being
888 // a UF, which is handled by the reduction of INTS_DIVISION.
889 if (t
[0].getKind() == kind::INTS_MODULUS_TOTAL
&& t
[0][1] == t
[1])
891 // (div (mod x c) c) --> 0
892 Node ret
= mkRationalNode(0);
893 return returnRewrite(t
, ret
, Rewrite::DIV_OVER_MOD
);
896 return RewriteResponse(REWRITE_DONE
, t
);
899 TrustNode
ArithRewriter::expandDefinition(Node node
)
901 // call eliminate operators, to eliminate partial operators only
902 std::vector
<SkolemLemma
> lems
;
903 TrustNode ret
= d_opElim
.eliminate(node
, lems
, true);
904 Assert(lems
.empty());
908 RewriteResponse
ArithRewriter::returnRewrite(TNode t
, Node ret
, Rewrite r
)
910 Trace("arith-rewrite") << "ArithRewriter : " << t
<< " == " << ret
<< " by "
912 return RewriteResponse(REWRITE_AGAIN_FULL
, ret
);
916 } // namespace theory