1 /******************************************************************************
2 * Top contributors (to current version):
3 * Andrew Reynolds, Tim King, Morgan Deters
5 * This file is part of the cvc5 project.
7 * Copyright (c) 2009-2021 by the authors listed in the file AUTHORS
8 * in the top-level source directory and their institutional affiliations.
9 * All rights reserved. See the file COPYING in the top-level source
10 * directory for licensing information.
11 * ****************************************************************************
13 * [[ Add one-line brief description here ]]
15 * [[ Add lengthier description here ]]
16 * \todo document this file
19 #include "theory/arith/arith_rewriter.h"
26 #include "smt/logic_exception.h"
27 #include "theory/arith/arith_msum.h"
28 #include "theory/arith/arith_utilities.h"
29 #include "theory/arith/normal_form.h"
30 #include "theory/arith/operator_elim.h"
31 #include "theory/theory.h"
32 #include "util/bitvector.h"
33 #include "util/divisible.h"
34 #include "util/iand.h"
40 ArithRewriter::ArithRewriter(OperatorElim
& oe
) : d_opElim(oe
) {}
42 bool ArithRewriter::isAtom(TNode n
) {
44 return arith::isRelationOperator(k
) || k
== kind::IS_INTEGER
45 || k
== kind::DIVISIBLE
;
48 RewriteResponse
ArithRewriter::rewriteConstant(TNode t
){
50 Assert(t
.getKind() == kind::CONST_RATIONAL
);
52 return RewriteResponse(REWRITE_DONE
, t
);
55 RewriteResponse
ArithRewriter::rewriteVariable(TNode t
){
58 return RewriteResponse(REWRITE_DONE
, t
);
61 RewriteResponse
ArithRewriter::rewriteMinus(TNode t
, bool pre
){
62 Assert(t
.getKind() == kind::MINUS
);
67 Node zeroNode
= mkRationalNode(zero
);
68 return RewriteResponse(REWRITE_DONE
, zeroNode
);
70 Node noMinus
= makeSubtractionNode(t
[0],t
[1]);
71 return RewriteResponse(REWRITE_DONE
, noMinus
);
74 Polynomial minuend
= Polynomial::parsePolynomial(t
[0]);
75 Polynomial subtrahend
= Polynomial::parsePolynomial(t
[1]);
76 Polynomial diff
= minuend
- subtrahend
;
77 return RewriteResponse(REWRITE_DONE
, diff
.getNode());
81 RewriteResponse
ArithRewriter::rewriteUMinus(TNode t
, bool pre
){
82 Assert(t
.getKind() == kind::UMINUS
);
84 if(t
[0].getKind() == kind::CONST_RATIONAL
){
85 Rational neg
= -(t
[0].getConst
<Rational
>());
86 return RewriteResponse(REWRITE_DONE
, mkRationalNode(neg
));
89 Node noUminus
= makeUnaryMinusNode(t
[0]);
91 return RewriteResponse(REWRITE_DONE
, noUminus
);
93 return RewriteResponse(REWRITE_AGAIN
, noUminus
);
96 RewriteResponse
ArithRewriter::preRewriteTerm(TNode t
){
98 return rewriteConstant(t
);
100 return rewriteVariable(t
);
102 switch(Kind k
= t
.getKind()){
104 return rewriteMinus(t
, true);
106 return rewriteUMinus(t
, true);
108 case kind::DIVISION_TOTAL
:
109 return rewriteDiv(t
,true);
111 return preRewritePlus(t
);
113 case kind::NONLINEAR_MULT
: return preRewriteMult(t
);
114 case kind::IAND
: return RewriteResponse(REWRITE_DONE
, t
);
115 case kind::POW2
: return RewriteResponse(REWRITE_DONE
, t
);
116 case kind::EXPONENTIAL
:
122 case kind::COTANGENT
:
124 case kind::ARCCOSINE
:
125 case kind::ARCTANGENT
:
126 case kind::ARCCOSECANT
:
127 case kind::ARCSECANT
:
128 case kind::ARCCOTANGENT
:
129 case kind::SQRT
: return preRewriteTranscendental(t
);
130 case kind::INTS_DIVISION
:
131 case kind::INTS_MODULUS
: return rewriteIntsDivMod(t
, true);
132 case kind::INTS_DIVISION_TOTAL
:
133 case kind::INTS_MODULUS_TOTAL
:
134 return rewriteIntsDivModTotal(t
,true);
137 const Rational
& rat
= t
[0].getConst
<Rational
>();
139 return RewriteResponse(REWRITE_DONE
, t
[0]);
141 return RewriteResponse(REWRITE_DONE
,
142 NodeManager::currentNM()->mkConst(-rat
));
145 return RewriteResponse(REWRITE_DONE
, t
);
146 case kind::IS_INTEGER
:
147 case kind::TO_INTEGER
:
148 return RewriteResponse(REWRITE_DONE
, t
);
150 case kind::CAST_TO_REAL
: return RewriteResponse(REWRITE_DONE
, t
[0]);
152 return RewriteResponse(REWRITE_DONE
, t
);
154 return RewriteResponse(REWRITE_DONE
, t
);
155 default: Unhandled() << k
;
160 RewriteResponse
ArithRewriter::postRewriteTerm(TNode t
){
162 return rewriteConstant(t
);
164 return rewriteVariable(t
);
168 return rewriteMinus(t
, false);
170 return rewriteUMinus(t
, false);
172 case kind::DIVISION_TOTAL
:
173 return rewriteDiv(t
, false);
175 return postRewritePlus(t
);
177 case kind::NONLINEAR_MULT
: return postRewriteMult(t
);
178 case kind::IAND
: return postRewriteIAnd(t
);
179 case kind::POW2
: return postRewritePow2(t
);
180 case kind::EXPONENTIAL
:
186 case kind::COTANGENT
:
188 case kind::ARCCOSINE
:
189 case kind::ARCTANGENT
:
190 case kind::ARCCOSECANT
:
191 case kind::ARCSECANT
:
192 case kind::ARCCOTANGENT
:
193 case kind::SQRT
: return postRewriteTranscendental(t
);
194 case kind::INTS_DIVISION
:
195 case kind::INTS_MODULUS
: return rewriteIntsDivMod(t
, false);
196 case kind::INTS_DIVISION_TOTAL
:
197 case kind::INTS_MODULUS_TOTAL
:
198 return rewriteIntsDivModTotal(t
, false);
201 const Rational
& rat
= t
[0].getConst
<Rational
>();
203 return RewriteResponse(REWRITE_DONE
, t
[0]);
205 return RewriteResponse(REWRITE_DONE
,
206 NodeManager::currentNM()->mkConst(-rat
));
209 return RewriteResponse(REWRITE_DONE
, t
);
211 case kind::CAST_TO_REAL
: return RewriteResponse(REWRITE_DONE
, t
[0]);
212 case kind::TO_INTEGER
:
214 return RewriteResponse(REWRITE_DONE
, NodeManager::currentNM()->mkConst(Rational(t
[0].getConst
<Rational
>().floor())));
216 if(t
[0].getType().isInteger()) {
217 return RewriteResponse(REWRITE_DONE
, t
[0]);
219 //Unimplemented() << "TO_INTEGER, nonconstant";
220 //return rewriteToInteger(t);
221 return RewriteResponse(REWRITE_DONE
, t
);
222 case kind::IS_INTEGER
:
224 return RewriteResponse(REWRITE_DONE
, NodeManager::currentNM()->mkConst(t
[0].getConst
<Rational
>().getDenominator() == 1));
226 if(t
[0].getType().isInteger()) {
227 return RewriteResponse(REWRITE_DONE
, NodeManager::currentNM()->mkConst(true));
229 //Unimplemented() << "IS_INTEGER, nonconstant";
230 //return rewriteIsInteger(t);
231 return RewriteResponse(REWRITE_DONE
, t
);
234 if(t
[1].getKind() == kind::CONST_RATIONAL
){
235 const Rational
& exp
= t
[1].getConst
<Rational
>();
238 return RewriteResponse(REWRITE_DONE
, mkRationalNode(Rational(1)));
239 }else if(exp
.sgn() > 0 && exp
.isIntegral()){
240 cvc5::Rational
r(expr::NodeValue::MAX_CHILDREN
);
243 unsigned num
= exp
.getNumerator().toUnsignedInt();
245 return RewriteResponse(REWRITE_AGAIN
, base
);
247 NodeBuilder
nb(kind::MULT
);
248 for(unsigned i
=0; i
< num
; ++i
){
251 Assert(nb
.getNumChildren() > 0);
253 return RewriteResponse(REWRITE_AGAIN
, mult
);
259 // Todo improve the exception thrown
260 std::stringstream ss
;
261 ss
<< "The exponent of the POW(^) operator can only be a positive "
262 "integral constant below "
263 << (expr::NodeValue::MAX_CHILDREN
+ 1) << ". ";
264 ss
<< "Exception occurred in:" << std::endl
;
266 throw LogicException(ss
.str());
269 return RewriteResponse(REWRITE_DONE
, t
);
277 RewriteResponse
ArithRewriter::preRewriteMult(TNode t
){
278 Assert(t
.getKind() == kind::MULT
|| t
.getKind() == kind::NONLINEAR_MULT
);
280 if(t
.getNumChildren() == 2){
281 if(t
[0].getKind() == kind::CONST_RATIONAL
282 && t
[0].getConst
<Rational
>().isOne()){
283 return RewriteResponse(REWRITE_DONE
, t
[1]);
285 if(t
[1].getKind() == kind::CONST_RATIONAL
286 && t
[1].getConst
<Rational
>().isOne()){
287 return RewriteResponse(REWRITE_DONE
, t
[0]);
291 // Rewrite multiplications with a 0 argument and to 0
292 for(TNode::iterator i
= t
.begin(); i
!= t
.end(); ++i
) {
293 if((*i
).getKind() == kind::CONST_RATIONAL
) {
294 if((*i
).getConst
<Rational
>().isZero()) {
296 return RewriteResponse(REWRITE_DONE
, zero
);
300 return RewriteResponse(REWRITE_DONE
, t
);
303 static bool canFlatten(Kind k
, TNode t
){
304 for(TNode::iterator i
= t
.begin(); i
!= t
.end(); ++i
) {
306 if(child
.getKind() == k
){
313 static void flatten(std::vector
<TNode
>& pb
, Kind k
, TNode t
){
314 if(t
.getKind() == k
){
315 for(TNode::iterator i
= t
.begin(); i
!= t
.end(); ++i
) {
317 if(child
.getKind() == k
){
318 flatten(pb
, k
, child
);
328 static Node
flatten(Kind k
, TNode t
){
329 std::vector
<TNode
> pb
;
331 Assert(pb
.size() >= 2);
332 return NodeManager::currentNM()->mkNode(k
, pb
);
335 RewriteResponse
ArithRewriter::preRewritePlus(TNode t
){
336 Assert(t
.getKind() == kind::PLUS
);
338 if(canFlatten(kind::PLUS
, t
)){
339 return RewriteResponse(REWRITE_DONE
, flatten(kind::PLUS
, t
));
341 return RewriteResponse(REWRITE_DONE
, t
);
345 RewriteResponse
ArithRewriter::postRewritePlus(TNode t
){
346 Assert(t
.getKind() == kind::PLUS
);
348 std::vector
<Monomial
> monomials
;
349 std::vector
<Polynomial
> polynomials
;
351 for(TNode::iterator i
= t
.begin(), end
= t
.end(); i
!= end
; ++i
){
353 if(Monomial::isMember(curr
)){
354 monomials
.push_back(Monomial::parseMonomial(curr
));
356 polynomials
.push_back(Polynomial::parsePolynomial(curr
));
360 if(!monomials
.empty()){
361 Monomial::sort(monomials
);
362 Monomial::combineAdjacentMonomials(monomials
);
363 polynomials
.push_back(Polynomial::mkPolynomial(monomials
));
366 Polynomial res
= Polynomial::sumPolynomials(polynomials
);
368 return RewriteResponse(REWRITE_DONE
, res
.getNode());
371 RewriteResponse
ArithRewriter::postRewriteMult(TNode t
){
372 Assert(t
.getKind() == kind::MULT
|| t
.getKind() == kind::NONLINEAR_MULT
);
374 Polynomial res
= Polynomial::mkOne();
376 for(TNode::iterator i
= t
.begin(), end
= t
.end(); i
!= end
; ++i
){
378 Polynomial currPoly
= Polynomial::parsePolynomial(curr
);
380 res
= res
* currPoly
;
383 return RewriteResponse(REWRITE_DONE
, res
.getNode());
386 RewriteResponse
ArithRewriter::postRewritePow2(TNode t
)
388 Assert(t
.getKind() == kind::POW2
);
389 NodeManager
* nm
= NodeManager::currentNM();
390 // if constant, we eliminate
393 // pow2 is only supported for integers
394 Assert(t
[0].getType().isInteger());
395 Integer i
= t
[0].getConst
<Rational
>().getNumerator();
396 unsigned long k
= i
.getUnsignedLong();
397 Node ret
= nm
->mkConst
<Rational
>(Rational(Integer(2).pow(k
), Integer(1)));
398 return RewriteResponse(REWRITE_DONE
, ret
);
400 return RewriteResponse(REWRITE_DONE
, t
);
403 RewriteResponse
ArithRewriter::postRewriteIAnd(TNode t
)
405 Assert(t
.getKind() == kind::IAND
);
406 NodeManager
* nm
= NodeManager::currentNM();
407 // if constant, we eliminate
408 if (t
[0].isConst() && t
[1].isConst())
410 size_t bsize
= t
.getOperator().getConst
<IntAnd
>().d_size
;
411 Node iToBvop
= nm
->mkConst(IntToBitVector(bsize
));
412 Node arg1
= nm
->mkNode(kind::INT_TO_BITVECTOR
, iToBvop
, t
[0]);
413 Node arg2
= nm
->mkNode(kind::INT_TO_BITVECTOR
, iToBvop
, t
[1]);
414 Node bvand
= nm
->mkNode(kind::BITVECTOR_AND
, arg1
, arg2
);
415 Node ret
= nm
->mkNode(kind::BITVECTOR_TO_NAT
, bvand
);
416 return RewriteResponse(REWRITE_AGAIN_FULL
, ret
);
418 else if (t
[0] > t
[1])
420 // ((_ iand k) x y) ---> ((_ iand k) y x) if x > y by node ordering
421 Node ret
= nm
->mkNode(kind::IAND
, t
.getOperator(), t
[1], t
[0]);
422 return RewriteResponse(REWRITE_AGAIN
, ret
);
424 else if (t
[0] == t
[1])
426 // ((_ iand k) x x) ---> x
427 return RewriteResponse(REWRITE_DONE
, t
[0]);
429 // simplifications involving constants
430 for (unsigned i
= 0; i
< 2; i
++)
436 if (t
[i
].getConst
<Rational
>().sgn() == 0)
438 // ((_ iand k) 0 y) ---> 0
439 return RewriteResponse(REWRITE_DONE
, t
[i
]);
442 return RewriteResponse(REWRITE_DONE
, t
);
445 RewriteResponse
ArithRewriter::preRewriteTranscendental(TNode t
) {
446 return RewriteResponse(REWRITE_DONE
, t
);
449 RewriteResponse
ArithRewriter::postRewriteTranscendental(TNode t
) {
450 Trace("arith-tf-rewrite") << "Rewrite transcendental function : " << t
<< std::endl
;
451 NodeManager
* nm
= NodeManager::currentNM();
452 switch( t
.getKind() ){
453 case kind::EXPONENTIAL
: {
454 if(t
[0].getKind() == kind::CONST_RATIONAL
){
455 Node one
= nm
->mkConst(Rational(1));
456 if(t
[0].getConst
<Rational
>().sgn()>=0 && t
[0].getType().isInteger() && t
[0]!=one
){
457 return RewriteResponse(
459 nm
->mkNode(kind::POW
, nm
->mkNode(kind::EXPONENTIAL
, one
), t
[0]));
461 return RewriteResponse(REWRITE_DONE
, t
);
464 else if (t
[0].getKind() == kind::PLUS
)
466 std::vector
<Node
> product
;
467 for (const Node tc
: t
[0])
469 product
.push_back(nm
->mkNode(kind::EXPONENTIAL
, tc
));
471 // We need to do a full rewrite here, since we can get exponentials of
472 // constants, e.g. when we are rewriting exp(2 + x)
473 return RewriteResponse(REWRITE_AGAIN_FULL
,
474 nm
->mkNode(kind::MULT
, product
));
479 if(t
[0].getKind() == kind::CONST_RATIONAL
){
480 const Rational
& rat
= t
[0].getConst
<Rational
>();
482 return RewriteResponse(REWRITE_DONE
, nm
->mkConst(Rational(0)));
484 else if (rat
.sgn() == -1)
487 nm
->mkNode(kind::UMINUS
, nm
->mkNode(kind::SINE
, nm
->mkConst(-rat
)));
488 return RewriteResponse(REWRITE_AGAIN_FULL
, ret
);
491 // get the factor of PI in the argument
495 std::map
<Node
, Node
> msum
;
496 if (ArithMSum::getMonomialSum(t
[0], msum
))
499 std::map
<Node
, Node
>::iterator itm
= msum
.find(pi
);
500 if (itm
!= msum
.end())
502 if (itm
->second
.isNull())
504 pi_factor
= mkRationalNode(Rational(1));
508 pi_factor
= itm
->second
;
513 rem
= ArithMSum::mkNode(msum
);
522 // if there is a factor of PI
523 if( !pi_factor
.isNull() ){
524 Trace("arith-tf-rewrite-debug") << "Process pi factor = " << pi_factor
<< std::endl
;
525 Rational r
= pi_factor
.getConst
<Rational
>();
526 Rational r_abs
= r
.abs();
527 Rational rone
= Rational(1);
528 Node ntwo
= mkRationalNode(Rational(2));
531 //add/substract 2*pi beyond scope
532 Node ra_div_two
= nm
->mkNode(
533 kind::INTS_DIVISION
, mkRationalNode(r_abs
+ rone
), ntwo
);
537 nm
->mkNode(kind::MINUS
,
539 nm
->mkNode(kind::MULT
, ntwo
, ra_div_two
));
541 Assert(r
.sgn() == -1);
543 nm
->mkNode(kind::PLUS
,
545 nm
->mkNode(kind::MULT
, ntwo
, ra_div_two
));
547 Node new_arg
= nm
->mkNode(kind::MULT
, new_pi_factor
, pi
);
550 new_arg
= nm
->mkNode(kind::PLUS
, new_arg
, rem
);
552 // sin( 2*n*PI + x ) = sin( x )
553 return RewriteResponse(REWRITE_AGAIN_FULL
,
554 nm
->mkNode(kind::SINE
, new_arg
));
556 else if (r_abs
== rone
)
558 // sin( PI + x ) = -sin( x )
561 return RewriteResponse(REWRITE_DONE
, mkRationalNode(Rational(0)));
565 return RewriteResponse(
567 nm
->mkNode(kind::UMINUS
, nm
->mkNode(kind::SINE
, rem
)));
570 else if (rem
.isNull())
572 // other rational cases based on Niven's theorem
573 // (https://en.wikipedia.org/wiki/Niven%27s_theorem)
574 Integer one
= Integer(1);
575 Integer two
= Integer(2);
576 Integer six
= Integer(6);
577 if (r_abs
.getDenominator() == two
)
579 Assert(r_abs
.getNumerator() == one
);
580 return RewriteResponse(REWRITE_DONE
,
581 mkRationalNode(Rational(r
.sgn())));
583 else if (r_abs
.getDenominator() == six
)
585 Integer five
= Integer(5);
586 if (r_abs
.getNumerator() == one
|| r_abs
.getNumerator() == five
)
588 return RewriteResponse(
590 mkRationalNode(Rational(r
.sgn()) / Rational(2)));
598 return RewriteResponse(
600 nm
->mkNode(kind::SINE
,
601 nm
->mkNode(kind::MINUS
,
602 nm
->mkNode(kind::MULT
,
603 nm
->mkConst(Rational(1) / Rational(2)),
610 return RewriteResponse(REWRITE_AGAIN_FULL
,
611 nm
->mkNode(kind::DIVISION
,
612 nm
->mkNode(kind::SINE
, t
[0]),
613 nm
->mkNode(kind::COSINE
, t
[0])));
618 return RewriteResponse(REWRITE_AGAIN_FULL
,
619 nm
->mkNode(kind::DIVISION
,
620 mkRationalNode(Rational(1)),
621 nm
->mkNode(kind::SINE
, t
[0])));
626 return RewriteResponse(REWRITE_AGAIN_FULL
,
627 nm
->mkNode(kind::DIVISION
,
628 mkRationalNode(Rational(1)),
629 nm
->mkNode(kind::COSINE
, t
[0])));
632 case kind::COTANGENT
:
634 return RewriteResponse(REWRITE_AGAIN_FULL
,
635 nm
->mkNode(kind::DIVISION
,
636 nm
->mkNode(kind::COSINE
, t
[0]),
637 nm
->mkNode(kind::SINE
, t
[0])));
643 return RewriteResponse(REWRITE_DONE
, t
);
646 RewriteResponse
ArithRewriter::postRewriteAtom(TNode atom
){
647 if(atom
.getKind() == kind::IS_INTEGER
) {
648 if(atom
[0].isConst()) {
649 return RewriteResponse(REWRITE_DONE
, NodeManager::currentNM()->mkConst(atom
[0].getConst
<Rational
>().isIntegral()));
651 if(atom
[0].getType().isInteger()) {
652 return RewriteResponse(REWRITE_DONE
, NodeManager::currentNM()->mkConst(true));
654 // not supported, but this isn't the right place to complain
655 return RewriteResponse(REWRITE_DONE
, atom
);
656 } else if(atom
.getKind() == kind::DIVISIBLE
) {
657 if(atom
[0].isConst()) {
658 return RewriteResponse(REWRITE_DONE
, NodeManager::currentNM()->mkConst(bool((atom
[0].getConst
<Rational
>() / atom
.getOperator().getConst
<Divisible
>().k
).isIntegral())));
660 if(atom
.getOperator().getConst
<Divisible
>().k
.isOne()) {
661 return RewriteResponse(REWRITE_DONE
, NodeManager::currentNM()->mkConst(true));
663 return RewriteResponse(REWRITE_AGAIN
, NodeManager::currentNM()->mkNode(kind::EQUAL
, NodeManager::currentNM()->mkNode(kind::INTS_MODULUS_TOTAL
, atom
[0], NodeManager::currentNM()->mkConst(Rational(atom
.getOperator().getConst
<Divisible
>().k
))), NodeManager::currentNM()->mkConst(Rational(0))));
667 TNode left
= atom
[0];
668 TNode right
= atom
[1];
670 Polynomial pleft
= Polynomial::parsePolynomial(left
);
671 Polynomial pright
= Polynomial::parsePolynomial(right
);
673 Debug("arith::rewriter") << "pleft " << pleft
.getNode() << std::endl
;
674 Debug("arith::rewriter") << "pright " << pright
.getNode() << std::endl
;
676 Comparison cmp
= Comparison::mkComparison(atom
.getKind(), pleft
, pright
);
677 Assert(cmp
.isNormalForm());
678 return RewriteResponse(REWRITE_DONE
, cmp
.getNode());
681 RewriteResponse
ArithRewriter::preRewriteAtom(TNode atom
){
682 Assert(isAtom(atom
));
684 NodeManager
* currNM
= NodeManager::currentNM();
686 if(atom
.getKind() == kind::EQUAL
) {
687 if(atom
[0] == atom
[1]) {
688 return RewriteResponse(REWRITE_DONE
, currNM
->mkConst(true));
690 }else if(atom
.getKind() == kind::GT
){
691 Node leq
= currNM
->mkNode(kind::LEQ
, atom
[0], atom
[1]);
692 return RewriteResponse(REWRITE_DONE
, currNM
->mkNode(kind::NOT
, leq
));
693 }else if(atom
.getKind() == kind::LT
){
694 Node geq
= currNM
->mkNode(kind::GEQ
, atom
[0], atom
[1]);
695 return RewriteResponse(REWRITE_DONE
, currNM
->mkNode(kind::NOT
, geq
));
696 }else if(atom
.getKind() == kind::IS_INTEGER
){
697 if(atom
[0].getType().isInteger()){
698 return RewriteResponse(REWRITE_DONE
, currNM
->mkConst(true));
700 }else if(atom
.getKind() == kind::DIVISIBLE
){
701 if(atom
.getOperator().getConst
<Divisible
>().k
.isOne()){
702 return RewriteResponse(REWRITE_DONE
, currNM
->mkConst(true));
706 return RewriteResponse(REWRITE_DONE
, atom
);
709 RewriteResponse
ArithRewriter::postRewrite(TNode t
){
711 RewriteResponse response
= postRewriteTerm(t
);
712 if (Debug
.isOn("arith::rewriter") && response
.d_status
== REWRITE_DONE
)
714 Polynomial::parsePolynomial(response
.d_node
);
718 RewriteResponse response
= postRewriteAtom(t
);
719 if (Debug
.isOn("arith::rewriter") && response
.d_status
== REWRITE_DONE
)
721 Comparison::parseNormalForm(response
.d_node
);
729 RewriteResponse
ArithRewriter::preRewrite(TNode t
){
731 return preRewriteTerm(t
);
733 return preRewriteAtom(t
);
739 Node
ArithRewriter::makeUnaryMinusNode(TNode n
){
740 Rational
qNegOne(-1);
741 return NodeManager::currentNM()->mkNode(kind::MULT
, mkRationalNode(qNegOne
),n
);
744 Node
ArithRewriter::makeSubtractionNode(TNode l
, TNode r
){
745 Node negR
= makeUnaryMinusNode(r
);
746 Node diff
= NodeManager::currentNM()->mkNode(kind::PLUS
, l
, negR
);
751 RewriteResponse
ArithRewriter::rewriteDiv(TNode t
, bool pre
){
752 Assert(t
.getKind() == kind::DIVISION_TOTAL
|| t
.getKind() == kind::DIVISION
);
756 if(right
.getKind() == kind::CONST_RATIONAL
){
757 const Rational
& den
= right
.getConst
<Rational
>();
760 if(t
.getKind() == kind::DIVISION_TOTAL
){
761 return RewriteResponse(REWRITE_DONE
, mkRationalNode(0));
763 // This is unsupported, but this is not a good place to complain
764 return RewriteResponse(REWRITE_DONE
, t
);
767 Assert(den
!= Rational(0));
769 if(left
.getKind() == kind::CONST_RATIONAL
){
770 const Rational
& num
= left
.getConst
<Rational
>();
771 Rational div
= num
/ den
;
772 Node result
= mkRationalNode(div
);
773 return RewriteResponse(REWRITE_DONE
, result
);
776 Rational div
= den
.inverse();
778 Node result
= mkRationalNode(div
);
780 Node mult
= NodeManager::currentNM()->mkNode(kind::MULT
,left
,result
);
782 return RewriteResponse(REWRITE_DONE
, mult
);
784 return RewriteResponse(REWRITE_AGAIN
, mult
);
787 return RewriteResponse(REWRITE_DONE
, t
);
791 RewriteResponse
ArithRewriter::rewriteIntsDivMod(TNode t
, bool pre
)
793 NodeManager
* nm
= NodeManager::currentNM();
794 Kind k
= t
.getKind();
795 Node zero
= nm
->mkConst(Rational(0));
796 if (k
== kind::INTS_MODULUS
)
798 if (t
[1].isConst() && !t
[1].getConst
<Rational
>().isZero())
800 // can immediately replace by INTS_MODULUS_TOTAL
801 Node ret
= nm
->mkNode(kind::INTS_MODULUS_TOTAL
, t
[0], t
[1]);
802 return returnRewrite(t
, ret
, Rewrite::MOD_TOTAL_BY_CONST
);
805 if (k
== kind::INTS_DIVISION
)
807 if (t
[1].isConst() && !t
[1].getConst
<Rational
>().isZero())
809 // can immediately replace by INTS_DIVISION_TOTAL
810 Node ret
= nm
->mkNode(kind::INTS_DIVISION_TOTAL
, t
[0], t
[1]);
811 return returnRewrite(t
, ret
, Rewrite::DIV_TOTAL_BY_CONST
);
814 return RewriteResponse(REWRITE_DONE
, t
);
817 RewriteResponse
ArithRewriter::rewriteIntsDivModTotal(TNode t
, bool pre
)
821 // do not rewrite at prewrite.
822 return RewriteResponse(REWRITE_DONE
, t
);
824 NodeManager
* nm
= NodeManager::currentNM();
825 Kind k
= t
.getKind();
826 Assert(k
== kind::INTS_MODULUS_TOTAL
|| k
== kind::INTS_DIVISION_TOTAL
);
829 bool dIsConstant
= d
.getKind() == kind::CONST_RATIONAL
;
830 if(dIsConstant
&& d
.getConst
<Rational
>().isZero()){
831 // (div x 0) ---> 0 or (mod x 0) ---> 0
832 return returnRewrite(t
, mkRationalNode(0), Rewrite::DIV_MOD_BY_ZERO
);
833 }else if(dIsConstant
&& d
.getConst
<Rational
>().isOne()){
834 if (k
== kind::INTS_MODULUS_TOTAL
)
837 return returnRewrite(t
, mkRationalNode(0), Rewrite::MOD_BY_ONE
);
839 Assert(k
== kind::INTS_DIVISION_TOTAL
);
841 return returnRewrite(t
, n
, Rewrite::DIV_BY_ONE
);
843 else if (dIsConstant
&& d
.getConst
<Rational
>().sgn() < 0)
846 // (div x (- c)) ---> (- (div x c))
847 // (mod x (- c)) ---> (mod x c)
848 Node nn
= nm
->mkNode(k
, t
[0], nm
->mkConst(-t
[1].getConst
<Rational
>()));
849 Node ret
= (k
== kind::INTS_DIVISION
|| k
== kind::INTS_DIVISION_TOTAL
)
850 ? nm
->mkNode(kind::UMINUS
, nn
)
852 return returnRewrite(t
, ret
, Rewrite::DIV_MOD_PULL_NEG_DEN
);
854 else if (dIsConstant
&& n
.getKind() == kind::CONST_RATIONAL
)
856 Assert(d
.getConst
<Rational
>().isIntegral());
857 Assert(n
.getConst
<Rational
>().isIntegral());
858 Assert(!d
.getConst
<Rational
>().isZero());
859 Integer di
= d
.getConst
<Rational
>().getNumerator();
860 Integer ni
= n
.getConst
<Rational
>().getNumerator();
862 bool isDiv
= (k
== kind::INTS_DIVISION
|| k
== kind::INTS_DIVISION_TOTAL
);
864 Integer result
= isDiv
? ni
.euclidianDivideQuotient(di
) : ni
.euclidianDivideRemainder(di
);
866 // constant evaluation
867 // (mod c1 c2) ---> c3 or (div c1 c2) ---> c3
868 Node resultNode
= mkRationalNode(Rational(result
));
869 return returnRewrite(t
, resultNode
, Rewrite::CONST_EVAL
);
871 if (k
== kind::INTS_MODULUS_TOTAL
)
873 // Note these rewrites do not need to account for modulus by zero as being
874 // a UF, which is handled by the reduction of INTS_MODULUS.
875 Kind k0
= t
[0].getKind();
876 if (k0
== kind::INTS_MODULUS_TOTAL
&& t
[0][1] == t
[1])
878 // (mod (mod x c) c) --> (mod x c)
879 return returnRewrite(t
, t
[0], Rewrite::MOD_OVER_MOD
);
881 else if (k0
== kind::NONLINEAR_MULT
|| k0
== kind::MULT
|| k0
== kind::PLUS
)
884 std::vector
<Node
> newChildren
;
885 bool childChanged
= false;
886 for (const Node
& tc
: t
[0])
888 if (tc
.getKind() == kind::INTS_MODULUS_TOTAL
&& tc
[1] == t
[1])
890 newChildren
.push_back(tc
[0]);
894 newChildren
.push_back(tc
);
898 // (mod (op ... (mod x c) ...) c) ---> (mod (op ... x ...) c) where
899 // op is one of { NONLINEAR_MULT, MULT, PLUS }.
900 Node ret
= nm
->mkNode(k0
, newChildren
);
901 ret
= nm
->mkNode(kind::INTS_MODULUS_TOTAL
, ret
, t
[1]);
902 return returnRewrite(t
, ret
, Rewrite::MOD_CHILD_MOD
);
908 Assert(k
== kind::INTS_DIVISION_TOTAL
);
909 // Note these rewrites do not need to account for division by zero as being
910 // a UF, which is handled by the reduction of INTS_DIVISION.
911 if (t
[0].getKind() == kind::INTS_MODULUS_TOTAL
&& t
[0][1] == t
[1])
913 // (div (mod x c) c) --> 0
914 Node ret
= mkRationalNode(0);
915 return returnRewrite(t
, ret
, Rewrite::DIV_OVER_MOD
);
918 return RewriteResponse(REWRITE_DONE
, t
);
921 TrustNode
ArithRewriter::expandDefinition(Node node
)
923 // call eliminate operators, to eliminate partial operators only
924 std::vector
<SkolemLemma
> lems
;
925 TrustNode ret
= d_opElim
.eliminate(node
, lems
, true);
926 Assert(lems
.empty());
930 RewriteResponse
ArithRewriter::returnRewrite(TNode t
, Node ret
, Rewrite r
)
932 Trace("arith-rewrite") << "ArithRewriter : " << t
<< " == " << ret
<< " by "
934 return RewriteResponse(REWRITE_AGAIN_FULL
, ret
);
938 } // namespace theory