1 /******************************************************************************
2 * Top contributors (to current version):
3 * Andrew Reynolds, Tim King, Morgan Deters
5 * This file is part of the cvc5 project.
7 * Copyright (c) 2009-2021 by the authors listed in the file AUTHORS
8 * in the top-level source directory and their institutional affiliations.
9 * All rights reserved. See the file COPYING in the top-level source
10 * directory for licensing information.
11 * ****************************************************************************
13 * [[ Add one-line brief description here ]]
15 * [[ Add lengthier description here ]]
16 * \todo document this file
19 #include "theory/arith/arith_rewriter.h"
26 #include "smt/logic_exception.h"
27 #include "theory/arith/arith_msum.h"
28 #include "theory/arith/arith_utilities.h"
29 #include "theory/arith/normal_form.h"
30 #include "theory/arith/operator_elim.h"
31 #include "theory/theory.h"
32 #include "util/bitvector.h"
33 #include "util/divisible.h"
34 #include "util/iand.h"
36 using namespace cvc5::kind
;
42 ArithRewriter::ArithRewriter(OperatorElim
& oe
) : d_opElim(oe
) {}
44 bool ArithRewriter::isAtom(TNode n
) {
46 return arith::isRelationOperator(k
) || k
== kind::IS_INTEGER
47 || k
== kind::DIVISIBLE
;
50 RewriteResponse
ArithRewriter::rewriteConstant(TNode t
){
52 Assert(t
.getKind() == CONST_RATIONAL
|| t
.getKind() == CONST_INTEGER
);
54 return RewriteResponse(REWRITE_DONE
, t
);
57 RewriteResponse
ArithRewriter::rewriteVariable(TNode t
){
60 return RewriteResponse(REWRITE_DONE
, t
);
63 RewriteResponse
ArithRewriter::rewriteMinus(TNode t
)
65 Assert(t
.getKind() == kind::MINUS
);
66 Assert(t
.getNumChildren() == 2);
68 auto* nm
= NodeManager::currentNM();
72 return RewriteResponse(REWRITE_DONE
,
73 nm
->mkConstRealOrInt(t
.getType(), Rational(0)));
75 return RewriteResponse(
77 nm
->mkNode(Kind::PLUS
, t
[0], makeUnaryMinusNode(t
[1])));
80 RewriteResponse
ArithRewriter::rewriteUMinus(TNode t
, bool pre
){
81 Assert(t
.getKind() == kind::UMINUS
);
85 Rational neg
= -(t
[0].getConst
<Rational
>());
86 NodeManager
* nm
= NodeManager::currentNM();
87 return RewriteResponse(REWRITE_DONE
,
88 nm
->mkConstRealOrInt(t
[0].getType(), neg
));
91 Node noUminus
= makeUnaryMinusNode(t
[0]);
93 return RewriteResponse(REWRITE_DONE
, noUminus
);
95 return RewriteResponse(REWRITE_AGAIN
, noUminus
);
98 RewriteResponse
ArithRewriter::preRewriteTerm(TNode t
){
100 return rewriteConstant(t
);
102 return rewriteVariable(t
);
104 switch(Kind k
= t
.getKind()){
105 case kind::MINUS
: return rewriteMinus(t
);
106 case kind::UMINUS
: return rewriteUMinus(t
, true);
108 case kind::DIVISION_TOTAL
: return rewriteDiv(t
, true);
109 case kind::PLUS
: return preRewritePlus(t
);
111 case kind::NONLINEAR_MULT
: return preRewriteMult(t
);
112 case kind::IAND
: return RewriteResponse(REWRITE_DONE
, t
);
113 case kind::POW2
: return RewriteResponse(REWRITE_DONE
, t
);
114 case kind::EXPONENTIAL
:
120 case kind::COTANGENT
:
122 case kind::ARCCOSINE
:
123 case kind::ARCTANGENT
:
124 case kind::ARCCOSECANT
:
125 case kind::ARCSECANT
:
126 case kind::ARCCOTANGENT
:
127 case kind::SQRT
: return preRewriteTranscendental(t
);
128 case kind::INTS_DIVISION
:
129 case kind::INTS_MODULUS
: return rewriteIntsDivMod(t
, true);
130 case kind::INTS_DIVISION_TOTAL
:
131 case kind::INTS_MODULUS_TOTAL
: return rewriteIntsDivModTotal(t
, true);
135 const Rational
& rat
= t
[0].getConst
<Rational
>();
138 return RewriteResponse(REWRITE_DONE
, t
[0]);
142 return RewriteResponse(REWRITE_DONE
,
143 NodeManager::currentNM()->mkConstRealOrInt(
144 t
[0].getType(), -rat
));
147 return RewriteResponse(REWRITE_DONE
, t
);
148 case kind::IS_INTEGER
:
149 case kind::TO_INTEGER
: return RewriteResponse(REWRITE_DONE
, t
);
151 case kind::CAST_TO_REAL
: return RewriteResponse(REWRITE_DONE
, t
[0]);
152 case kind::POW
: return RewriteResponse(REWRITE_DONE
, t
);
153 case kind::PI
: return RewriteResponse(REWRITE_DONE
, t
);
154 default: Unhandled() << k
;
159 RewriteResponse
ArithRewriter::postRewriteTerm(TNode t
){
161 return rewriteConstant(t
);
163 return rewriteVariable(t
);
165 Trace("arith-rewriter") << "postRewriteTerm: " << t
<< std::endl
;
167 case kind::MINUS
: return rewriteMinus(t
);
168 case kind::UMINUS
: return rewriteUMinus(t
, false);
170 case kind::DIVISION_TOTAL
: return rewriteDiv(t
, false);
171 case kind::PLUS
: return postRewritePlus(t
);
173 case kind::NONLINEAR_MULT
: return postRewriteMult(t
);
174 case kind::IAND
: return postRewriteIAnd(t
);
175 case kind::POW2
: return postRewritePow2(t
);
176 case kind::EXPONENTIAL
:
182 case kind::COTANGENT
:
184 case kind::ARCCOSINE
:
185 case kind::ARCTANGENT
:
186 case kind::ARCCOSECANT
:
187 case kind::ARCSECANT
:
188 case kind::ARCCOTANGENT
:
189 case kind::SQRT
: return postRewriteTranscendental(t
);
190 case kind::INTS_DIVISION
:
191 case kind::INTS_MODULUS
: return rewriteIntsDivMod(t
, false);
192 case kind::INTS_DIVISION_TOTAL
:
193 case kind::INTS_MODULUS_TOTAL
: return rewriteIntsDivModTotal(t
, false);
197 const Rational
& rat
= t
[0].getConst
<Rational
>();
200 return RewriteResponse(REWRITE_DONE
, t
[0]);
204 return RewriteResponse(REWRITE_DONE
,
205 NodeManager::currentNM()->mkConstRealOrInt(
206 t
[0].getType(), -rat
));
209 return RewriteResponse(REWRITE_DONE
, t
);
211 case kind::CAST_TO_REAL
: return RewriteResponse(REWRITE_DONE
, t
[0]);
212 case kind::TO_INTEGER
: return rewriteExtIntegerOp(t
);
217 const Rational
& exp
= t
[1].getConst
<Rational
>();
220 return RewriteResponse(REWRITE_DONE
,
221 NodeManager::currentNM()->mkConstRealOrInt(
222 t
.getType(), Rational(1)));
223 }else if(exp
.sgn() > 0 && exp
.isIntegral()){
224 cvc5::Rational
r(expr::NodeValue::MAX_CHILDREN
);
227 unsigned num
= exp
.getNumerator().toUnsignedInt();
229 return RewriteResponse(REWRITE_AGAIN
, base
);
231 NodeBuilder
nb(kind::MULT
);
232 for(unsigned i
=0; i
< num
; ++i
){
235 Assert(nb
.getNumChildren() > 0);
237 return RewriteResponse(REWRITE_AGAIN
, mult
);
242 else if (t
[0].isConst()
243 && t
[0].getConst
<Rational
>().getNumerator().toUnsignedInt()
246 return RewriteResponse(
247 REWRITE_DONE
, NodeManager::currentNM()->mkNode(kind::POW2
, t
[1]));
250 // Todo improve the exception thrown
251 std::stringstream ss
;
252 ss
<< "The exponent of the POW(^) operator can only be a positive "
253 "integral constant below "
254 << (expr::NodeValue::MAX_CHILDREN
+ 1) << ". ";
255 ss
<< "Exception occurred in:" << std::endl
;
257 throw LogicException(ss
.str());
260 return RewriteResponse(REWRITE_DONE
, t
);
268 RewriteResponse
ArithRewriter::preRewriteMult(TNode t
){
269 Assert(t
.getKind() == kind::MULT
|| t
.getKind() == kind::NONLINEAR_MULT
);
271 if(t
.getNumChildren() == 2){
272 if (t
[0].isConst() && t
[0].getConst
<Rational
>().isOne())
274 return RewriteResponse(REWRITE_DONE
, t
[1]);
276 if (t
[1].isConst() && t
[1].getConst
<Rational
>().isOne())
278 return RewriteResponse(REWRITE_DONE
, t
[0]);
282 // Rewrite multiplications with a 0 argument and to 0
283 for(TNode::iterator i
= t
.begin(); i
!= t
.end(); ++i
) {
286 if((*i
).getConst
<Rational
>().isZero()) {
288 return RewriteResponse(REWRITE_DONE
, zero
);
292 return RewriteResponse(REWRITE_DONE
, t
);
295 static bool canFlatten(Kind k
, TNode t
){
296 for(TNode::iterator i
= t
.begin(); i
!= t
.end(); ++i
) {
298 if(child
.getKind() == k
){
305 static void flatten(std::vector
<TNode
>& pb
, Kind k
, TNode t
){
306 if(t
.getKind() == k
){
307 for(TNode::iterator i
= t
.begin(); i
!= t
.end(); ++i
) {
309 if(child
.getKind() == k
){
310 flatten(pb
, k
, child
);
320 static Node
flatten(Kind k
, TNode t
){
321 std::vector
<TNode
> pb
;
323 Assert(pb
.size() >= 2);
324 return NodeManager::currentNM()->mkNode(k
, pb
);
327 RewriteResponse
ArithRewriter::preRewritePlus(TNode t
){
328 Assert(t
.getKind() == kind::PLUS
);
330 if(canFlatten(kind::PLUS
, t
)){
331 return RewriteResponse(REWRITE_DONE
, flatten(kind::PLUS
, t
));
333 return RewriteResponse(REWRITE_DONE
, t
);
337 RewriteResponse
ArithRewriter::postRewritePlus(TNode t
){
338 Assert(t
.getKind() == kind::PLUS
);
340 std::vector
<Monomial
> monomials
;
341 std::vector
<Polynomial
> polynomials
;
343 for(TNode::iterator i
= t
.begin(), end
= t
.end(); i
!= end
; ++i
){
345 if(Monomial::isMember(curr
)){
346 monomials
.push_back(Monomial::parseMonomial(curr
));
348 polynomials
.push_back(Polynomial::parsePolynomial(curr
));
352 if(!monomials
.empty()){
353 Monomial::sort(monomials
);
354 Monomial::combineAdjacentMonomials(monomials
);
355 polynomials
.push_back(Polynomial::mkPolynomial(monomials
));
358 Polynomial res
= Polynomial::sumPolynomials(polynomials
);
360 return RewriteResponse(REWRITE_DONE
, res
.getNode());
363 RewriteResponse
ArithRewriter::postRewriteMult(TNode t
){
364 Assert(t
.getKind() == kind::MULT
|| t
.getKind() == kind::NONLINEAR_MULT
);
366 Polynomial res
= Polynomial::mkOne();
368 for(TNode::iterator i
= t
.begin(), end
= t
.end(); i
!= end
; ++i
){
370 Polynomial currPoly
= Polynomial::parsePolynomial(curr
);
372 res
= res
* currPoly
;
375 return RewriteResponse(REWRITE_DONE
, res
.getNode());
378 RewriteResponse
ArithRewriter::postRewritePow2(TNode t
)
380 Assert(t
.getKind() == kind::POW2
);
381 NodeManager
* nm
= NodeManager::currentNM();
382 // if constant, we eliminate
385 // pow2 is only supported for integers
386 Assert(t
[0].getType().isInteger());
387 Integer i
= t
[0].getConst
<Rational
>().getNumerator();
390 return RewriteResponse(REWRITE_DONE
, nm
->mkConstInt(Rational(0)));
392 // (pow2 t) ---> (pow 2 t) and continue rewriting to eliminate pow
393 Node two
= nm
->mkConstInt(Rational(Integer(2)));
394 Node ret
= nm
->mkNode(kind::POW
, two
, t
[0]);
395 return RewriteResponse(REWRITE_AGAIN
, ret
);
397 return RewriteResponse(REWRITE_DONE
, t
);
400 RewriteResponse
ArithRewriter::postRewriteIAnd(TNode t
)
402 Assert(t
.getKind() == kind::IAND
);
403 size_t bsize
= t
.getOperator().getConst
<IntAnd
>().d_size
;
404 NodeManager
* nm
= NodeManager::currentNM();
405 // if constant, we eliminate
406 if (t
[0].isConst() && t
[1].isConst())
408 Node iToBvop
= nm
->mkConst(IntToBitVector(bsize
));
409 Node arg1
= nm
->mkNode(kind::INT_TO_BITVECTOR
, iToBvop
, t
[0]);
410 Node arg2
= nm
->mkNode(kind::INT_TO_BITVECTOR
, iToBvop
, t
[1]);
411 Node bvand
= nm
->mkNode(kind::BITVECTOR_AND
, arg1
, arg2
);
412 Node ret
= nm
->mkNode(kind::BITVECTOR_TO_NAT
, bvand
);
413 return RewriteResponse(REWRITE_AGAIN_FULL
, ret
);
415 else if (t
[0] > t
[1])
417 // ((_ iand k) x y) ---> ((_ iand k) y x) if x > y by node ordering
418 Node ret
= nm
->mkNode(kind::IAND
, t
.getOperator(), t
[1], t
[0]);
419 return RewriteResponse(REWRITE_AGAIN
, ret
);
421 else if (t
[0] == t
[1])
423 // ((_ iand k) x x) ---> x
424 return RewriteResponse(REWRITE_DONE
, t
[0]);
426 // simplifications involving constants
427 for (unsigned i
= 0; i
< 2; i
++)
433 if (t
[i
].getConst
<Rational
>().sgn() == 0)
435 // ((_ iand k) 0 y) ---> 0
436 return RewriteResponse(REWRITE_DONE
, t
[i
]);
438 if (t
[i
].getConst
<Rational
>().getNumerator() == Integer(2).pow(bsize
) - 1)
440 // ((_ iand k) 111...1 y) ---> y
441 return RewriteResponse(REWRITE_DONE
, t
[i
== 0 ? 1 : 0]);
444 return RewriteResponse(REWRITE_DONE
, t
);
447 RewriteResponse
ArithRewriter::preRewriteTranscendental(TNode t
) {
448 return RewriteResponse(REWRITE_DONE
, t
);
451 RewriteResponse
ArithRewriter::postRewriteTranscendental(TNode t
) {
452 Trace("arith-tf-rewrite") << "Rewrite transcendental function : " << t
<< std::endl
;
453 NodeManager
* nm
= NodeManager::currentNM();
454 switch( t
.getKind() ){
455 case kind::EXPONENTIAL
: {
458 Node one
= nm
->mkConstReal(Rational(1));
459 if(t
[0].getConst
<Rational
>().sgn()>=0 && t
[0].getType().isInteger() && t
[0]!=one
){
460 return RewriteResponse(
462 nm
->mkNode(kind::POW
, nm
->mkNode(kind::EXPONENTIAL
, one
), t
[0]));
464 return RewriteResponse(REWRITE_DONE
, t
);
467 else if (t
[0].getKind() == kind::PLUS
)
469 std::vector
<Node
> product
;
470 for (const Node tc
: t
[0])
472 product
.push_back(nm
->mkNode(kind::EXPONENTIAL
, tc
));
474 // We need to do a full rewrite here, since we can get exponentials of
475 // constants, e.g. when we are rewriting exp(2 + x)
476 return RewriteResponse(REWRITE_AGAIN_FULL
,
477 nm
->mkNode(kind::MULT
, product
));
484 const Rational
& rat
= t
[0].getConst
<Rational
>();
486 return RewriteResponse(REWRITE_DONE
, nm
->mkConstReal(Rational(0)));
488 else if (rat
.sgn() == -1)
490 Node ret
= nm
->mkNode(kind::UMINUS
,
491 nm
->mkNode(kind::SINE
, nm
->mkConstReal(-rat
)));
492 return RewriteResponse(REWRITE_AGAIN_FULL
, ret
);
495 // get the factor of PI in the argument
499 std::map
<Node
, Node
> msum
;
500 if (ArithMSum::getMonomialSum(t
[0], msum
))
503 std::map
<Node
, Node
>::iterator itm
= msum
.find(pi
);
504 if (itm
!= msum
.end())
506 if (itm
->second
.isNull())
508 pi_factor
= nm
->mkConstReal(Rational(1));
512 pi_factor
= itm
->second
;
517 rem
= ArithMSum::mkNode(t
[0].getType(), msum
);
526 // if there is a factor of PI
527 if( !pi_factor
.isNull() ){
528 Trace("arith-tf-rewrite-debug") << "Process pi factor = " << pi_factor
<< std::endl
;
529 Rational r
= pi_factor
.getConst
<Rational
>();
530 Rational r_abs
= r
.abs();
531 Rational rone
= Rational(1);
532 Node ntwo
= nm
->mkConstInt(Rational(2));
535 //add/substract 2*pi beyond scope
536 Node ra_div_two
= nm
->mkNode(
537 kind::INTS_DIVISION
, mkRationalNode(r_abs
+ rone
), ntwo
);
541 nm
->mkNode(kind::MINUS
,
543 nm
->mkNode(kind::MULT
, ntwo
, ra_div_two
));
545 Assert(r
.sgn() == -1);
547 nm
->mkNode(kind::PLUS
,
549 nm
->mkNode(kind::MULT
, ntwo
, ra_div_two
));
551 Node new_arg
= nm
->mkNode(kind::MULT
, new_pi_factor
, pi
);
554 new_arg
= nm
->mkNode(kind::PLUS
, new_arg
, rem
);
556 // sin( 2*n*PI + x ) = sin( x )
557 return RewriteResponse(REWRITE_AGAIN_FULL
,
558 nm
->mkNode(kind::SINE
, new_arg
));
560 else if (r_abs
== rone
)
562 // sin( PI + x ) = -sin( x )
565 return RewriteResponse(REWRITE_DONE
, nm
->mkConstReal(Rational(0)));
569 return RewriteResponse(
571 nm
->mkNode(kind::UMINUS
, nm
->mkNode(kind::SINE
, rem
)));
574 else if (rem
.isNull())
576 // other rational cases based on Niven's theorem
577 // (https://en.wikipedia.org/wiki/Niven%27s_theorem)
578 Integer one
= Integer(1);
579 Integer two
= Integer(2);
580 Integer six
= Integer(6);
581 if (r_abs
.getDenominator() == two
)
583 Assert(r_abs
.getNumerator() == one
);
584 return RewriteResponse(REWRITE_DONE
,
585 nm
->mkConstReal(Rational(r
.sgn())));
587 else if (r_abs
.getDenominator() == six
)
589 Integer five
= Integer(5);
590 if (r_abs
.getNumerator() == one
|| r_abs
.getNumerator() == five
)
592 return RewriteResponse(
594 nm
->mkConstReal(Rational(r
.sgn()) / Rational(2)));
602 return RewriteResponse(
606 nm
->mkNode(kind::MINUS
,
607 nm
->mkNode(kind::MULT
,
608 nm
->mkConstReal(Rational(1) / Rational(2)),
615 return RewriteResponse(REWRITE_AGAIN_FULL
,
616 nm
->mkNode(kind::DIVISION
,
617 nm
->mkNode(kind::SINE
, t
[0]),
618 nm
->mkNode(kind::COSINE
, t
[0])));
623 return RewriteResponse(REWRITE_AGAIN_FULL
,
624 nm
->mkNode(kind::DIVISION
,
625 nm
->mkConstReal(Rational(1)),
626 nm
->mkNode(kind::SINE
, t
[0])));
631 return RewriteResponse(REWRITE_AGAIN_FULL
,
632 nm
->mkNode(kind::DIVISION
,
633 nm
->mkConstReal(Rational(1)),
634 nm
->mkNode(kind::COSINE
, t
[0])));
637 case kind::COTANGENT
:
639 return RewriteResponse(REWRITE_AGAIN_FULL
,
640 nm
->mkNode(kind::DIVISION
,
641 nm
->mkNode(kind::COSINE
, t
[0]),
642 nm
->mkNode(kind::SINE
, t
[0])));
648 return RewriteResponse(REWRITE_DONE
, t
);
651 RewriteResponse
ArithRewriter::postRewriteAtom(TNode atom
){
652 if(atom
.getKind() == kind::IS_INTEGER
) {
653 return rewriteExtIntegerOp(atom
);
654 } else if(atom
.getKind() == kind::DIVISIBLE
) {
655 if(atom
[0].isConst()) {
656 return RewriteResponse(REWRITE_DONE
, NodeManager::currentNM()->mkConst(bool((atom
[0].getConst
<Rational
>() / atom
.getOperator().getConst
<Divisible
>().k
).isIntegral())));
658 if(atom
.getOperator().getConst
<Divisible
>().k
.isOne()) {
659 return RewriteResponse(REWRITE_DONE
, NodeManager::currentNM()->mkConst(true));
661 NodeManager
* nm
= NodeManager::currentNM();
662 return RewriteResponse(
664 nm
->mkNode(kind::EQUAL
,
665 nm
->mkNode(kind::INTS_MODULUS_TOTAL
,
667 nm
->mkConstInt(Rational(
668 atom
.getOperator().getConst
<Divisible
>().k
))),
669 nm
->mkConstInt(Rational(0))));
673 TNode left
= atom
[0];
674 TNode right
= atom
[1];
676 Polynomial pleft
= Polynomial::parsePolynomial(left
);
677 Polynomial pright
= Polynomial::parsePolynomial(right
);
679 Debug("arith::rewriter") << "pleft " << pleft
.getNode() << std::endl
;
680 Debug("arith::rewriter") << "pright " << pright
.getNode() << std::endl
;
682 Comparison cmp
= Comparison::mkComparison(atom
.getKind(), pleft
, pright
);
683 Assert(cmp
.isNormalForm());
684 return RewriteResponse(REWRITE_DONE
, cmp
.getNode());
687 RewriteResponse
ArithRewriter::preRewriteAtom(TNode atom
){
688 Assert(isAtom(atom
));
690 NodeManager
* currNM
= NodeManager::currentNM();
692 if(atom
.getKind() == kind::EQUAL
) {
693 if(atom
[0] == atom
[1]) {
694 return RewriteResponse(REWRITE_DONE
, currNM
->mkConst(true));
696 }else if(atom
.getKind() == kind::GT
){
697 Node leq
= currNM
->mkNode(kind::LEQ
, atom
[0], atom
[1]);
698 return RewriteResponse(REWRITE_DONE
, currNM
->mkNode(kind::NOT
, leq
));
699 }else if(atom
.getKind() == kind::LT
){
700 Node geq
= currNM
->mkNode(kind::GEQ
, atom
[0], atom
[1]);
701 return RewriteResponse(REWRITE_DONE
, currNM
->mkNode(kind::NOT
, geq
));
702 }else if(atom
.getKind() == kind::IS_INTEGER
){
703 if(atom
[0].getType().isInteger()){
704 return RewriteResponse(REWRITE_DONE
, currNM
->mkConst(true));
706 }else if(atom
.getKind() == kind::DIVISIBLE
){
707 if(atom
.getOperator().getConst
<Divisible
>().k
.isOne()){
708 return RewriteResponse(REWRITE_DONE
, currNM
->mkConst(true));
712 return RewriteResponse(REWRITE_DONE
, atom
);
715 RewriteResponse
ArithRewriter::postRewrite(TNode t
){
717 RewriteResponse response
= postRewriteTerm(t
);
718 if (Debug
.isOn("arith::rewriter") && response
.d_status
== REWRITE_DONE
)
720 Polynomial::parsePolynomial(response
.d_node
);
724 RewriteResponse response
= postRewriteAtom(t
);
725 if (Debug
.isOn("arith::rewriter") && response
.d_status
== REWRITE_DONE
)
727 Comparison::parseNormalForm(response
.d_node
);
735 RewriteResponse
ArithRewriter::preRewrite(TNode t
){
737 return preRewriteTerm(t
);
739 return preRewriteAtom(t
);
745 Node
ArithRewriter::makeUnaryMinusNode(TNode n
){
746 NodeManager
* nm
= NodeManager::currentNM();
747 Rational
qNegOne(-1);
748 return nm
->mkNode(kind::MULT
, nm
->mkConstRealOrInt(n
.getType(), qNegOne
), n
);
751 RewriteResponse
ArithRewriter::rewriteDiv(TNode t
, bool pre
){
752 Assert(t
.getKind() == kind::DIVISION_TOTAL
|| t
.getKind() == kind::DIVISION
);
758 NodeManager
* nm
= NodeManager::currentNM();
759 const Rational
& den
= right
.getConst
<Rational
>();
762 if(t
.getKind() == kind::DIVISION_TOTAL
){
763 return RewriteResponse(REWRITE_DONE
, nm
->mkConstReal(0));
765 // This is unsupported, but this is not a good place to complain
766 return RewriteResponse(REWRITE_DONE
, t
);
769 Assert(den
!= Rational(0));
773 const Rational
& num
= left
.getConst
<Rational
>();
774 Rational div
= num
/ den
;
775 Node result
= nm
->mkConstReal(div
);
776 return RewriteResponse(REWRITE_DONE
, result
);
779 Rational div
= den
.inverse();
781 Node result
= nm
->mkConstReal(div
);
783 Node mult
= NodeManager::currentNM()->mkNode(kind::MULT
,left
,result
);
785 return RewriteResponse(REWRITE_DONE
, mult
);
787 return RewriteResponse(REWRITE_AGAIN
, mult
);
790 return RewriteResponse(REWRITE_DONE
, t
);
793 RewriteResponse
ArithRewriter::rewriteIntsDivMod(TNode t
, bool pre
)
795 NodeManager
* nm
= NodeManager::currentNM();
796 Kind k
= t
.getKind();
797 if (k
== kind::INTS_MODULUS
)
799 if (t
[1].isConst() && !t
[1].getConst
<Rational
>().isZero())
801 // can immediately replace by INTS_MODULUS_TOTAL
802 Node ret
= nm
->mkNode(kind::INTS_MODULUS_TOTAL
, t
[0], t
[1]);
803 return returnRewrite(t
, ret
, Rewrite::MOD_TOTAL_BY_CONST
);
806 if (k
== kind::INTS_DIVISION
)
808 if (t
[1].isConst() && !t
[1].getConst
<Rational
>().isZero())
810 // can immediately replace by INTS_DIVISION_TOTAL
811 Node ret
= nm
->mkNode(kind::INTS_DIVISION_TOTAL
, t
[0], t
[1]);
812 return returnRewrite(t
, ret
, Rewrite::DIV_TOTAL_BY_CONST
);
815 return RewriteResponse(REWRITE_DONE
, t
);
818 RewriteResponse
ArithRewriter::rewriteExtIntegerOp(TNode t
)
820 Assert(t
.getKind() == kind::TO_INTEGER
|| t
.getKind() == kind::IS_INTEGER
);
821 bool isPred
= t
.getKind() == kind::IS_INTEGER
;
822 NodeManager
* nm
= NodeManager::currentNM();
828 ret
= nm
->mkConst(t
[0].getConst
<Rational
>().isIntegral());
832 ret
= nm
->mkConstInt(Rational(t
[0].getConst
<Rational
>().floor()));
834 return returnRewrite(t
, ret
, Rewrite::INT_EXT_CONST
);
836 if (t
[0].getType().isInteger())
838 Node ret
= isPred
? nm
->mkConst(true) : Node(t
[0]);
839 return returnRewrite(t
, ret
, Rewrite::INT_EXT_INT
);
841 if (t
[0].getKind() == kind::PI
)
843 Node ret
= isPred
? nm
->mkConst(false) : nm
->mkConstReal(Rational(3));
844 return returnRewrite(t
, ret
, Rewrite::INT_EXT_PI
);
846 return RewriteResponse(REWRITE_DONE
, t
);
849 RewriteResponse
ArithRewriter::rewriteIntsDivModTotal(TNode t
, bool pre
)
853 // do not rewrite at prewrite.
854 return RewriteResponse(REWRITE_DONE
, t
);
856 NodeManager
* nm
= NodeManager::currentNM();
857 Kind k
= t
.getKind();
858 Assert(k
== kind::INTS_MODULUS_TOTAL
|| k
== kind::INTS_DIVISION_TOTAL
);
861 bool dIsConstant
= d
.isConst();
862 if(dIsConstant
&& d
.getConst
<Rational
>().isZero()){
863 // (div x 0) ---> 0 or (mod x 0) ---> 0
864 return returnRewrite(t
, nm
->mkConstInt(0), Rewrite::DIV_MOD_BY_ZERO
);
865 }else if(dIsConstant
&& d
.getConst
<Rational
>().isOne()){
866 if (k
== kind::INTS_MODULUS_TOTAL
)
869 return returnRewrite(t
, nm
->mkConstInt(0), Rewrite::MOD_BY_ONE
);
871 Assert(k
== kind::INTS_DIVISION_TOTAL
);
873 return returnRewrite(t
, n
, Rewrite::DIV_BY_ONE
);
875 else if (dIsConstant
&& d
.getConst
<Rational
>().sgn() < 0)
878 // (div x (- c)) ---> (- (div x c))
879 // (mod x (- c)) ---> (mod x c)
880 Node nn
= nm
->mkNode(k
, t
[0], nm
->mkConstInt(-t
[1].getConst
<Rational
>()));
881 Node ret
= (k
== kind::INTS_DIVISION
|| k
== kind::INTS_DIVISION_TOTAL
)
882 ? nm
->mkNode(kind::UMINUS
, nn
)
884 return returnRewrite(t
, ret
, Rewrite::DIV_MOD_PULL_NEG_DEN
);
886 else if (dIsConstant
&& n
.isConst())
888 Assert(d
.getConst
<Rational
>().isIntegral());
889 Assert(n
.getConst
<Rational
>().isIntegral());
890 Assert(!d
.getConst
<Rational
>().isZero());
891 Integer di
= d
.getConst
<Rational
>().getNumerator();
892 Integer ni
= n
.getConst
<Rational
>().getNumerator();
894 bool isDiv
= (k
== kind::INTS_DIVISION
|| k
== kind::INTS_DIVISION_TOTAL
);
896 Integer result
= isDiv
? ni
.euclidianDivideQuotient(di
) : ni
.euclidianDivideRemainder(di
);
898 // constant evaluation
899 // (mod c1 c2) ---> c3 or (div c1 c2) ---> c3
900 Node resultNode
= nm
->mkConstInt(Rational(result
));
901 return returnRewrite(t
, resultNode
, Rewrite::CONST_EVAL
);
903 if (k
== kind::INTS_MODULUS_TOTAL
)
905 // Note these rewrites do not need to account for modulus by zero as being
906 // a UF, which is handled by the reduction of INTS_MODULUS.
907 Kind k0
= t
[0].getKind();
908 if (k0
== kind::INTS_MODULUS_TOTAL
&& t
[0][1] == t
[1])
910 // (mod (mod x c) c) --> (mod x c)
911 return returnRewrite(t
, t
[0], Rewrite::MOD_OVER_MOD
);
913 else if (k0
== kind::NONLINEAR_MULT
|| k0
== kind::MULT
|| k0
== kind::PLUS
)
916 std::vector
<Node
> newChildren
;
917 bool childChanged
= false;
918 for (const Node
& tc
: t
[0])
920 if (tc
.getKind() == kind::INTS_MODULUS_TOTAL
&& tc
[1] == t
[1])
922 newChildren
.push_back(tc
[0]);
926 newChildren
.push_back(tc
);
930 // (mod (op ... (mod x c) ...) c) ---> (mod (op ... x ...) c) where
931 // op is one of { NONLINEAR_MULT, MULT, PLUS }.
932 Node ret
= nm
->mkNode(k0
, newChildren
);
933 ret
= nm
->mkNode(kind::INTS_MODULUS_TOTAL
, ret
, t
[1]);
934 return returnRewrite(t
, ret
, Rewrite::MOD_CHILD_MOD
);
940 Assert(k
== kind::INTS_DIVISION_TOTAL
);
941 // Note these rewrites do not need to account for division by zero as being
942 // a UF, which is handled by the reduction of INTS_DIVISION.
943 if (t
[0].getKind() == kind::INTS_MODULUS_TOTAL
&& t
[0][1] == t
[1])
945 // (div (mod x c) c) --> 0
946 Node ret
= nm
->mkConstInt(0);
947 return returnRewrite(t
, ret
, Rewrite::DIV_OVER_MOD
);
950 return RewriteResponse(REWRITE_DONE
, t
);
953 TrustNode
ArithRewriter::expandDefinition(Node node
)
955 // call eliminate operators, to eliminate partial operators only
956 std::vector
<SkolemLemma
> lems
;
957 TrustNode ret
= d_opElim
.eliminate(node
, lems
, true);
958 Assert(lems
.empty());
962 RewriteResponse
ArithRewriter::returnRewrite(TNode t
, Node ret
, Rewrite r
)
964 Trace("arith-rewrite") << "ArithRewriter : " << t
<< " == " << ret
<< " by "
966 return RewriteResponse(REWRITE_AGAIN_FULL
, ret
);
970 } // namespace theory