/******************************************************************************
 * Top contributors (to current version):
 *   Andrew Reynolds, Tim King, Morgan Deters
 *
 * This file is part of the cvc5 project.
 *
 * Copyright (c) 2009-2021 by the authors listed in the file AUTHORS
 * in the top-level source directory and their institutional affiliations.
 * All rights reserved.  See the file COPYING in the top-level source
 * directory for licensing information.
 * ****************************************************************************
 *
 * The rewriter for the theory of arithmetic.
 *
 * Implements pre- and post-rewriting of arithmetic atoms (relations,
 * is_int, divisibility) and terms (sums, products, division, integer
 * div/mod, abs, iand, pow2 and transcendental functions).
 * \todo document this file
 */
#include "theory/arith/arith_rewriter.h"

#include <algorithm>
#include <map>
#include <optional>
#include <sstream>
#include <vector>

#include "expr/algorithm/flatten.h"
#include "smt/logic_exception.h"
#include "theory/arith/arith_msum.h"
#include "theory/arith/arith_utilities.h"
#include "theory/arith/normal_form.h"
#include "theory/arith/operator_elim.h"
#include "theory/arith/rewriter/addition.h"
#include "theory/arith/rewriter/node_utils.h"
#include "theory/arith/rewriter/ordering.h"
#include "theory/theory.h"
#include "util/bitvector.h"
#include "util/divisible.h"
#include "util/iand.h"
#include "util/real_algebraic_number.h"

using namespace cvc5::kind;
50 template <typename L
, typename R
>
51 bool evaluateRelation(Kind rel
, const L
& l
, const R
& r
)
55 case Kind::LT
: return l
< r
;
56 case Kind::LEQ
: return l
<= r
;
57 case Kind::EQUAL
: return l
== r
;
58 case Kind::GEQ
: return l
>= r
;
59 case Kind::GT
: return l
> r
;
60 default: Unreachable(); return false;
65 * Check whether the parent has a child that is a constant zero.
66 * If so, return this child. Otherwise, return std::nullopt.
68 template <typename Iterable
>
69 std::optional
<TNode
> getZeroChild(const Iterable
& parent
)
71 for (const auto& node
: parent
)
73 if (node
.isConst() && node
.template getConst
<Rational
>().isZero())
83 ArithRewriter::ArithRewriter(OperatorElim
& oe
) : d_opElim(oe
) {}
85 RewriteResponse
ArithRewriter::preRewrite(TNode t
)
87 Trace("arith-rewriter") << "preRewrite(" << t
<< ")" << std::endl
;
90 auto res
= preRewriteAtom(t
);
91 Trace("arith-rewriter")
92 << res
.d_status
<< " -> " << res
.d_node
<< std::endl
;
95 auto res
= preRewriteTerm(t
);
96 Trace("arith-rewriter") << res
.d_status
<< " -> " << res
.d_node
<< std::endl
;
100 RewriteResponse
ArithRewriter::postRewrite(TNode t
)
102 Trace("arith-rewriter") << "postRewrite(" << t
<< ")" << std::endl
;
105 auto res
= postRewriteAtom(t
);
106 Trace("arith-rewriter")
107 << res
.d_status
<< " -> " << res
.d_node
<< std::endl
;
110 auto res
= postRewriteTerm(t
);
111 Trace("arith-rewriter") << res
.d_status
<< " -> " << res
.d_node
<< std::endl
;
115 RewriteResponse
ArithRewriter::preRewriteAtom(TNode atom
)
117 Assert(isAtom(atom
));
119 NodeManager
* nm
= NodeManager::currentNM();
121 if (isRelationOperator(atom
.getKind()) && atom
[0] == atom
[1])
123 switch (atom
.getKind())
125 case Kind::LT
: return RewriteResponse(REWRITE_DONE
, nm
->mkConst(false));
126 case Kind::LEQ
: return RewriteResponse(REWRITE_DONE
, nm
->mkConst(true));
127 case Kind::EQUAL
: return RewriteResponse(REWRITE_DONE
, nm
->mkConst(true));
128 case Kind::GEQ
: return RewriteResponse(REWRITE_DONE
, nm
->mkConst(true));
129 case Kind::GT
: return RewriteResponse(REWRITE_DONE
, nm
->mkConst(false));
134 switch (atom
.getKind())
137 return RewriteResponse(REWRITE_DONE
,
138 nm
->mkNode(kind::LEQ
, atom
[0], atom
[1]).notNode());
140 return RewriteResponse(REWRITE_DONE
,
141 nm
->mkNode(kind::GEQ
, atom
[0], atom
[1]).notNode());
142 case Kind::IS_INTEGER
:
143 if (atom
[0].getType().isInteger())
145 return RewriteResponse(REWRITE_DONE
, nm
->mkConst(true));
148 case Kind::DIVISIBLE
:
149 if (atom
.getOperator().getConst
<Divisible
>().k
.isOne())
151 return RewriteResponse(REWRITE_DONE
, nm
->mkConst(true));
157 return RewriteResponse(REWRITE_DONE
, atom
);
160 RewriteResponse
ArithRewriter::postRewriteAtom(TNode atom
)
162 Assert(isAtom(atom
));
163 if (atom
.getKind() == kind::IS_INTEGER
)
165 return rewriteExtIntegerOp(atom
);
167 else if (atom
.getKind() == kind::DIVISIBLE
)
169 if (atom
[0].isConst())
171 return RewriteResponse(REWRITE_DONE
,
172 NodeManager::currentNM()->mkConst(bool(
173 (atom
[0].getConst
<Rational
>()
174 / atom
.getOperator().getConst
<Divisible
>().k
)
177 if (atom
.getOperator().getConst
<Divisible
>().k
.isOne())
179 return RewriteResponse(REWRITE_DONE
,
180 NodeManager::currentNM()->mkConst(true));
182 NodeManager
* nm
= NodeManager::currentNM();
183 return RewriteResponse(
185 nm
->mkNode(kind::EQUAL
,
186 nm
->mkNode(kind::INTS_MODULUS_TOTAL
,
188 nm
->mkConstInt(Rational(
189 atom
.getOperator().getConst
<Divisible
>().k
))),
190 nm
->mkConstInt(Rational(0))));
194 TNode left
= atom
[0];
195 TNode right
= atom
[1];
197 auto* nm
= NodeManager::currentNM();
200 const Rational
& l
= left
.getConst
<Rational
>();
203 const Rational
& r
= right
.getConst
<Rational
>();
204 return RewriteResponse(
205 REWRITE_DONE
, nm
->mkConst(evaluateRelation(atom
.getKind(), l
, r
)));
207 else if (right
.getKind() == Kind::REAL_ALGEBRAIC_NUMBER
)
209 const RealAlgebraicNumber
& r
=
210 right
.getOperator().getConst
<RealAlgebraicNumber
>();
211 return RewriteResponse(
212 REWRITE_DONE
, nm
->mkConst(evaluateRelation(atom
.getKind(), l
, r
)));
215 else if (left
.getKind() == Kind::REAL_ALGEBRAIC_NUMBER
)
217 const RealAlgebraicNumber
& l
=
218 left
.getOperator().getConst
<RealAlgebraicNumber
>();
221 const Rational
& r
= right
.getConst
<Rational
>();
222 return RewriteResponse(
223 REWRITE_DONE
, nm
->mkConst(evaluateRelation(atom
.getKind(), l
, r
)));
225 else if (right
.getKind() == Kind::REAL_ALGEBRAIC_NUMBER
)
227 const RealAlgebraicNumber
& r
=
228 right
.getOperator().getConst
<RealAlgebraicNumber
>();
229 return RewriteResponse(
230 REWRITE_DONE
, nm
->mkConst(evaluateRelation(atom
.getKind(), l
, r
)));
234 Polynomial pleft
= Polynomial::parsePolynomial(left
);
235 Polynomial pright
= Polynomial::parsePolynomial(right
);
237 Debug("arith::rewriter") << "pleft " << pleft
.getNode() << std::endl
;
238 Debug("arith::rewriter") << "pright " << pright
.getNode() << std::endl
;
240 Comparison cmp
= Comparison::mkComparison(atom
.getKind(), pleft
, pright
);
241 Assert(cmp
.isNormalForm());
242 return RewriteResponse(REWRITE_DONE
, cmp
.getNode());
245 bool ArithRewriter::isAtom(TNode n
) {
246 Kind k
= n
.getKind();
247 return arith::isRelationOperator(k
) || k
== kind::IS_INTEGER
248 || k
== kind::DIVISIBLE
;
251 RewriteResponse
ArithRewriter::rewriteConstant(TNode t
){
253 Assert(t
.getKind() == CONST_RATIONAL
|| t
.getKind() == CONST_INTEGER
);
255 return RewriteResponse(REWRITE_DONE
, t
);
258 RewriteResponse
ArithRewriter::rewriteRAN(TNode t
)
260 Assert(t
.getKind() == REAL_ALGEBRAIC_NUMBER
);
262 const RealAlgebraicNumber
& r
=
263 t
.getOperator().getConst
<RealAlgebraicNumber
>();
266 return RewriteResponse(
267 REWRITE_DONE
, NodeManager::currentNM()->mkConstReal(r
.toRational()));
270 return RewriteResponse(REWRITE_DONE
, t
);
273 RewriteResponse
ArithRewriter::rewriteVariable(TNode t
){
276 return RewriteResponse(REWRITE_DONE
, t
);
279 RewriteResponse
ArithRewriter::rewriteSub(TNode t
)
281 Assert(t
.getKind() == kind::SUB
);
282 Assert(t
.getNumChildren() == 2);
284 auto* nm
= NodeManager::currentNM();
288 return RewriteResponse(REWRITE_DONE
,
289 nm
->mkConstRealOrInt(t
.getType(), Rational(0)));
291 return RewriteResponse(REWRITE_AGAIN_FULL
,
292 nm
->mkNode(Kind::ADD
, t
[0], makeUnaryMinusNode(t
[1])));
295 RewriteResponse
ArithRewriter::rewriteNeg(TNode t
, bool pre
)
297 Assert(t
.getKind() == kind::NEG
);
301 Rational neg
= -(t
[0].getConst
<Rational
>());
302 NodeManager
* nm
= NodeManager::currentNM();
303 return RewriteResponse(REWRITE_DONE
,
304 nm
->mkConstRealOrInt(t
[0].getType(), neg
));
306 if (t
[0].getKind() == Kind::REAL_ALGEBRAIC_NUMBER
)
308 const RealAlgebraicNumber
& r
=
309 t
[0].getOperator().getConst
<RealAlgebraicNumber
>();
310 NodeManager
* nm
= NodeManager::currentNM();
311 return RewriteResponse(REWRITE_DONE
, nm
->mkRealAlgebraicNumber(-r
));
314 Node noUminus
= makeUnaryMinusNode(t
[0]);
316 return RewriteResponse(REWRITE_DONE
, noUminus
);
318 return RewriteResponse(REWRITE_AGAIN
, noUminus
);
321 RewriteResponse
ArithRewriter::preRewriteTerm(TNode t
){
323 return rewriteConstant(t
);
325 return rewriteVariable(t
);
327 switch(Kind k
= t
.getKind()){
328 case kind::REAL_ALGEBRAIC_NUMBER
: return rewriteRAN(t
);
329 case kind::SUB
: return rewriteSub(t
);
330 case kind::NEG
: return rewriteNeg(t
, true);
332 case kind::DIVISION_TOTAL
: return rewriteDiv(t
, true);
333 case kind::ADD
: return preRewritePlus(t
);
335 case kind::NONLINEAR_MULT
: return preRewriteMult(t
);
336 case kind::IAND
: return RewriteResponse(REWRITE_DONE
, t
);
337 case kind::POW2
: return RewriteResponse(REWRITE_DONE
, t
);
338 case kind::EXPONENTIAL
:
344 case kind::COTANGENT
:
346 case kind::ARCCOSINE
:
347 case kind::ARCTANGENT
:
348 case kind::ARCCOSECANT
:
349 case kind::ARCSECANT
:
350 case kind::ARCCOTANGENT
:
351 case kind::SQRT
: return preRewriteTranscendental(t
);
352 case kind::INTS_DIVISION
:
353 case kind::INTS_MODULUS
: return rewriteIntsDivMod(t
, true);
354 case kind::INTS_DIVISION_TOTAL
:
355 case kind::INTS_MODULUS_TOTAL
: return rewriteIntsDivModTotal(t
, true);
356 case kind::ABS
: return rewriteAbs(t
);
357 case kind::IS_INTEGER
:
358 case kind::TO_INTEGER
: return RewriteResponse(REWRITE_DONE
, t
);
360 case kind::CAST_TO_REAL
: return RewriteResponse(REWRITE_DONE
, t
[0]);
361 case kind::POW
: return RewriteResponse(REWRITE_DONE
, t
);
362 case kind::PI
: return RewriteResponse(REWRITE_DONE
, t
);
363 default: Unhandled() << k
;
368 RewriteResponse
ArithRewriter::postRewriteTerm(TNode t
){
370 return rewriteConstant(t
);
372 return rewriteVariable(t
);
374 Trace("arith-rewriter") << "postRewriteTerm: " << t
<< std::endl
;
376 case kind::REAL_ALGEBRAIC_NUMBER
: return rewriteRAN(t
);
377 case kind::SUB
: return rewriteSub(t
);
378 case kind::NEG
: return rewriteNeg(t
, false);
380 case kind::DIVISION_TOTAL
: return rewriteDiv(t
, false);
381 case kind::ADD
: return postRewritePlus(t
);
383 case kind::NONLINEAR_MULT
: return postRewriteMult(t
);
384 case kind::IAND
: return postRewriteIAnd(t
);
385 case kind::POW2
: return postRewritePow2(t
);
386 case kind::EXPONENTIAL
:
392 case kind::COTANGENT
:
394 case kind::ARCCOSINE
:
395 case kind::ARCTANGENT
:
396 case kind::ARCCOSECANT
:
397 case kind::ARCSECANT
:
398 case kind::ARCCOTANGENT
:
399 case kind::SQRT
: return postRewriteTranscendental(t
);
400 case kind::INTS_DIVISION
:
401 case kind::INTS_MODULUS
: return rewriteIntsDivMod(t
, false);
402 case kind::INTS_DIVISION_TOTAL
:
403 case kind::INTS_MODULUS_TOTAL
: return rewriteIntsDivModTotal(t
, false);
404 case kind::ABS
: return rewriteAbs(t
);
406 case kind::CAST_TO_REAL
: return RewriteResponse(REWRITE_DONE
, t
[0]);
407 case kind::TO_INTEGER
: return rewriteExtIntegerOp(t
);
412 const Rational
& exp
= t
[1].getConst
<Rational
>();
415 return RewriteResponse(REWRITE_DONE
,
416 NodeManager::currentNM()->mkConstRealOrInt(
417 t
.getType(), Rational(1)));
418 }else if(exp
.sgn() > 0 && exp
.isIntegral()){
419 cvc5::Rational
r(expr::NodeValue::MAX_CHILDREN
);
422 unsigned num
= exp
.getNumerator().toUnsignedInt();
424 return RewriteResponse(REWRITE_AGAIN
, base
);
426 NodeBuilder
nb(kind::MULT
);
427 for(unsigned i
=0; i
< num
; ++i
){
430 Assert(nb
.getNumChildren() > 0);
432 return RewriteResponse(REWRITE_AGAIN
, mult
);
437 else if (t
[0].isConst()
438 && t
[0].getConst
<Rational
>().getNumerator().toUnsignedInt()
441 return RewriteResponse(
442 REWRITE_DONE
, NodeManager::currentNM()->mkNode(kind::POW2
, t
[1]));
445 // Todo improve the exception thrown
446 std::stringstream ss
;
447 ss
<< "The exponent of the POW(^) operator can only be a positive "
448 "integral constant below "
449 << (expr::NodeValue::MAX_CHILDREN
+ 1) << ". ";
450 ss
<< "Exception occurred in:" << std::endl
;
452 throw LogicException(ss
.str());
455 return RewriteResponse(REWRITE_DONE
, t
);
463 RewriteResponse
ArithRewriter::preRewritePlus(TNode t
){
464 Assert(t
.getKind() == kind::ADD
);
465 return RewriteResponse(REWRITE_DONE
, expr::algorithm::flatten(t
));
468 RewriteResponse
ArithRewriter::postRewritePlus(TNode t
){
469 Assert(t
.getKind() == kind::ADD
);
470 Assert(t
.getNumChildren() > 1);
473 Node flat
= expr::algorithm::flatten(t
);
476 return RewriteResponse(REWRITE_AGAIN
, flat
);
481 RealAlgebraicNumber ran
;
482 std::vector
<Monomial
> monomials
;
483 std::vector
<Polynomial
> polynomials
;
485 for (const auto& child
: t
)
489 if (child
.getConst
<Rational
>().isZero())
493 rational
+= child
.getConst
<Rational
>();
495 else if (child
.getKind() == Kind::REAL_ALGEBRAIC_NUMBER
)
497 ran
+= child
.getOperator().getConst
<RealAlgebraicNumber
>();
499 else if (Monomial::isMember(child
))
501 monomials
.emplace_back(Monomial::parseMonomial(child
));
505 polynomials
.emplace_back(Polynomial::parsePolynomial(child
));
509 if(!monomials
.empty()){
510 Monomial::sort(monomials
);
511 Monomial::combineAdjacentMonomials(monomials
);
512 polynomials
.emplace_back(Polynomial::mkPolynomial(monomials
));
514 if (!rational
.isZero())
516 polynomials
.emplace_back(
517 Polynomial::mkPolynomial(Constant::mkConstant(rational
)));
520 Polynomial poly
= Polynomial::sumPolynomials(polynomials
);
524 return RewriteResponse(REWRITE_DONE
, poly
.getNode());
526 if (poly
.containsConstant())
528 ran
+= RealAlgebraicNumber(poly
.getHead().getConstant().getValue());
529 if (!poly
.isConstant())
531 poly
= poly
.getTail();
535 auto* nm
= NodeManager::currentNM();
536 if (poly
.isConstant())
538 return RewriteResponse(REWRITE_DONE
, nm
->mkRealAlgebraicNumber(ran
));
540 return RewriteResponse(
542 nm
->mkNode(Kind::ADD
, nm
->mkRealAlgebraicNumber(ran
), poly
.getNode()));
545 RewriteResponse
ArithRewriter::preRewriteMult(TNode node
)
547 Assert(node
.getKind() == kind::MULT
548 || node
.getKind() == kind::NONLINEAR_MULT
);
550 if (auto res
= rewriter::getZeroChild(node
); res
)
552 return RewriteResponse(REWRITE_DONE
, *res
);
554 return RewriteResponse(REWRITE_DONE
, node
);
557 RewriteResponse
ArithRewriter::postRewriteMult(TNode t
){
558 Assert(t
.getKind() == kind::MULT
|| t
.getKind() == kind::NONLINEAR_MULT
);
559 Assert(t
.getNumChildren() >= 2);
561 std::vector
<TNode
> children
;
562 expr::algorithm::flatten(t
, children
, Kind::MULT
, Kind::NONLINEAR_MULT
);
564 if (auto res
= rewriter::getZeroChild(children
); res
)
566 return RewriteResponse(REWRITE_DONE
, *res
);
569 // Distribute over addition
570 if (std::any_of(children
.begin(), children
.end(), [](TNode child
) {
571 return child
.getKind() == Kind::ADD
;
574 return RewriteResponse(REWRITE_DONE
,
575 rewriter::distributeMultiplication(children
));
578 RealAlgebraicNumber ran
= RealAlgebraicNumber(Integer(1));
579 std::vector
<Node
> leafs
;
581 for (const auto& child
: children
)
585 if (child
.getConst
<Rational
>().isZero())
587 return RewriteResponse(REWRITE_DONE
, child
);
589 ran
*= child
.getConst
<Rational
>();
591 else if (rewriter::isRAN(child
))
593 ran
*= rewriter::getRAN(child
);
597 leafs
.emplace_back(child
);
601 return RewriteResponse(REWRITE_DONE
,
602 rewriter::mkMultTerm(ran
, std::move(leafs
)));
605 Node
ArithRewriter::makeUnaryMinusNode(TNode n
)
607 NodeManager
* nm
= NodeManager::currentNM();
608 Rational
qNegOne(-1);
609 return nm
->mkNode(kind::MULT
, nm
->mkConstRealOrInt(n
.getType(), qNegOne
), n
);
612 RewriteResponse
ArithRewriter::rewriteDiv(TNode t
, bool pre
)
614 Assert(t
.getKind() == kind::DIVISION_TOTAL
|| t
.getKind() == kind::DIVISION
);
615 Assert(t
.getNumChildren() == 2);
621 NodeManager
* nm
= NodeManager::currentNM();
622 const Rational
& den
= right
.getConst
<Rational
>();
626 if (t
.getKind() == kind::DIVISION_TOTAL
)
628 return RewriteResponse(REWRITE_DONE
, nm
->mkConstReal(0));
632 // This is unsupported, but this is not a good place to complain
633 return RewriteResponse(REWRITE_DONE
, t
);
636 Assert(den
!= Rational(0));
640 const Rational
& num
= left
.getConst
<Rational
>();
641 return RewriteResponse(REWRITE_DONE
, nm
->mkConstReal(num
/ den
));
643 if (left
.getKind() == Kind::REAL_ALGEBRAIC_NUMBER
)
645 const RealAlgebraicNumber
& num
=
646 left
.getOperator().getConst
<RealAlgebraicNumber
>();
647 return RewriteResponse(
649 nm
->mkRealAlgebraicNumber(num
/ RealAlgebraicNumber(den
)));
652 Node result
= nm
->mkConstReal(den
.inverse());
653 Node mult
= NodeManager::currentNM()->mkNode(kind::MULT
, left
, result
);
656 return RewriteResponse(REWRITE_DONE
, mult
);
660 return RewriteResponse(REWRITE_AGAIN
, mult
);
663 if (right
.getKind() == Kind::REAL_ALGEBRAIC_NUMBER
)
665 NodeManager
* nm
= NodeManager::currentNM();
666 const RealAlgebraicNumber
& den
=
667 right
.getOperator().getConst
<RealAlgebraicNumber
>();
670 const Rational
& num
= left
.getConst
<Rational
>();
671 return RewriteResponse(
673 nm
->mkRealAlgebraicNumber(RealAlgebraicNumber(num
) / den
));
675 if (left
.getKind() == Kind::REAL_ALGEBRAIC_NUMBER
)
677 const RealAlgebraicNumber
& num
=
678 left
.getOperator().getConst
<RealAlgebraicNumber
>();
679 return RewriteResponse(REWRITE_DONE
,
680 nm
->mkRealAlgebraicNumber(num
/ den
));
683 Node result
= nm
->mkRealAlgebraicNumber(inverse(den
));
684 Node mult
= NodeManager::currentNM()->mkNode(kind::MULT
, left
, result
);
687 return RewriteResponse(REWRITE_DONE
, mult
);
691 return RewriteResponse(REWRITE_AGAIN
, mult
);
694 return RewriteResponse(REWRITE_DONE
, t
);
697 RewriteResponse
ArithRewriter::rewriteAbs(TNode t
)
699 Assert(t
.getKind() == Kind::ABS
);
700 Assert(t
.getNumChildren() == 1);
704 const Rational
& rat
= t
[0].getConst
<Rational
>();
707 return RewriteResponse(REWRITE_DONE
, t
[0]);
709 return RewriteResponse(
711 NodeManager::currentNM()->mkConstRealOrInt(t
[0].getType(), -rat
));
713 if (t
[0].getKind() == Kind::REAL_ALGEBRAIC_NUMBER
)
715 const RealAlgebraicNumber
& ran
=
716 t
[0].getOperator().getConst
<RealAlgebraicNumber
>();
717 if (ran
>= RealAlgebraicNumber())
719 return RewriteResponse(REWRITE_DONE
, t
[0]);
721 return RewriteResponse(
722 REWRITE_DONE
, NodeManager::currentNM()->mkRealAlgebraicNumber(-ran
));
724 return RewriteResponse(REWRITE_DONE
, t
);
727 RewriteResponse
ArithRewriter::rewriteIntsDivMod(TNode t
, bool pre
)
729 NodeManager
* nm
= NodeManager::currentNM();
730 Kind k
= t
.getKind();
731 if (k
== kind::INTS_MODULUS
)
733 if (t
[1].isConst() && !t
[1].getConst
<Rational
>().isZero())
735 // can immediately replace by INTS_MODULUS_TOTAL
736 Node ret
= nm
->mkNode(kind::INTS_MODULUS_TOTAL
, t
[0], t
[1]);
737 return returnRewrite(t
, ret
, Rewrite::MOD_TOTAL_BY_CONST
);
740 if (k
== kind::INTS_DIVISION
)
742 if (t
[1].isConst() && !t
[1].getConst
<Rational
>().isZero())
744 // can immediately replace by INTS_DIVISION_TOTAL
745 Node ret
= nm
->mkNode(kind::INTS_DIVISION_TOTAL
, t
[0], t
[1]);
746 return returnRewrite(t
, ret
, Rewrite::DIV_TOTAL_BY_CONST
);
749 return RewriteResponse(REWRITE_DONE
, t
);
752 RewriteResponse
ArithRewriter::rewriteExtIntegerOp(TNode t
)
754 Assert(t
.getKind() == kind::TO_INTEGER
|| t
.getKind() == kind::IS_INTEGER
);
755 bool isPred
= t
.getKind() == kind::IS_INTEGER
;
756 NodeManager
* nm
= NodeManager::currentNM();
762 ret
= nm
->mkConst(t
[0].getConst
<Rational
>().isIntegral());
766 ret
= nm
->mkConstInt(Rational(t
[0].getConst
<Rational
>().floor()));
768 return returnRewrite(t
, ret
, Rewrite::INT_EXT_CONST
);
770 if (t
[0].getType().isInteger())
772 Node ret
= isPred
? nm
->mkConst(true) : Node(t
[0]);
773 return returnRewrite(t
, ret
, Rewrite::INT_EXT_INT
);
775 if (t
[0].getKind() == kind::PI
)
777 Node ret
= isPred
? nm
->mkConst(false) : nm
->mkConstReal(Rational(3));
778 return returnRewrite(t
, ret
, Rewrite::INT_EXT_PI
);
780 return RewriteResponse(REWRITE_DONE
, t
);
783 RewriteResponse
ArithRewriter::rewriteIntsDivModTotal(TNode t
, bool pre
)
787 // do not rewrite at prewrite.
788 return RewriteResponse(REWRITE_DONE
, t
);
790 NodeManager
* nm
= NodeManager::currentNM();
791 Kind k
= t
.getKind();
792 Assert(k
== kind::INTS_MODULUS_TOTAL
|| k
== kind::INTS_DIVISION_TOTAL
);
795 bool dIsConstant
= d
.isConst();
796 if (dIsConstant
&& d
.getConst
<Rational
>().isZero())
798 // (div x 0) ---> 0 or (mod x 0) ---> 0
799 return returnRewrite(t
, nm
->mkConstInt(0), Rewrite::DIV_MOD_BY_ZERO
);
801 else if (dIsConstant
&& d
.getConst
<Rational
>().isOne())
803 if (k
== kind::INTS_MODULUS_TOTAL
)
806 return returnRewrite(t
, nm
->mkConstInt(0), Rewrite::MOD_BY_ONE
);
808 Assert(k
== kind::INTS_DIVISION_TOTAL
);
810 return returnRewrite(t
, n
, Rewrite::DIV_BY_ONE
);
812 else if (dIsConstant
&& d
.getConst
<Rational
>().sgn() < 0)
815 // (div x (- c)) ---> (- (div x c))
816 // (mod x (- c)) ---> (mod x c)
817 Node nn
= nm
->mkNode(k
, t
[0], nm
->mkConstInt(-t
[1].getConst
<Rational
>()));
818 Node ret
= (k
== kind::INTS_DIVISION
|| k
== kind::INTS_DIVISION_TOTAL
)
819 ? nm
->mkNode(kind::NEG
, nn
)
821 return returnRewrite(t
, ret
, Rewrite::DIV_MOD_PULL_NEG_DEN
);
823 else if (dIsConstant
&& n
.isConst())
825 Assert(d
.getConst
<Rational
>().isIntegral());
826 Assert(n
.getConst
<Rational
>().isIntegral());
827 Assert(!d
.getConst
<Rational
>().isZero());
828 Integer di
= d
.getConst
<Rational
>().getNumerator();
829 Integer ni
= n
.getConst
<Rational
>().getNumerator();
831 bool isDiv
= (k
== kind::INTS_DIVISION
|| k
== kind::INTS_DIVISION_TOTAL
);
833 Integer result
= isDiv
? ni
.euclidianDivideQuotient(di
)
834 : ni
.euclidianDivideRemainder(di
);
836 // constant evaluation
837 // (mod c1 c2) ---> c3 or (div c1 c2) ---> c3
838 Node resultNode
= nm
->mkConstInt(Rational(result
));
839 return returnRewrite(t
, resultNode
, Rewrite::CONST_EVAL
);
841 if (k
== kind::INTS_MODULUS_TOTAL
)
843 // Note these rewrites do not need to account for modulus by zero as being
844 // a UF, which is handled by the reduction of INTS_MODULUS.
845 Kind k0
= t
[0].getKind();
846 if (k0
== kind::INTS_MODULUS_TOTAL
&& t
[0][1] == t
[1])
848 // (mod (mod x c) c) --> (mod x c)
849 return returnRewrite(t
, t
[0], Rewrite::MOD_OVER_MOD
);
851 else if (k0
== kind::NONLINEAR_MULT
|| k0
== kind::MULT
|| k0
== kind::ADD
)
854 std::vector
<Node
> newChildren
;
855 bool childChanged
= false;
856 for (const Node
& tc
: t
[0])
858 if (tc
.getKind() == kind::INTS_MODULUS_TOTAL
&& tc
[1] == t
[1])
860 newChildren
.push_back(tc
[0]);
864 newChildren
.push_back(tc
);
868 // (mod (op ... (mod x c) ...) c) ---> (mod (op ... x ...) c) where
869 // op is one of { NONLINEAR_MULT, MULT, ADD }.
870 Node ret
= nm
->mkNode(k0
, newChildren
);
871 ret
= nm
->mkNode(kind::INTS_MODULUS_TOTAL
, ret
, t
[1]);
872 return returnRewrite(t
, ret
, Rewrite::MOD_CHILD_MOD
);
878 Assert(k
== kind::INTS_DIVISION_TOTAL
);
879 // Note these rewrites do not need to account for division by zero as being
880 // a UF, which is handled by the reduction of INTS_DIVISION.
881 if (t
[0].getKind() == kind::INTS_MODULUS_TOTAL
&& t
[0][1] == t
[1])
883 // (div (mod x c) c) --> 0
884 Node ret
= nm
->mkConstInt(0);
885 return returnRewrite(t
, ret
, Rewrite::DIV_OVER_MOD
);
888 return RewriteResponse(REWRITE_DONE
, t
);
891 RewriteResponse
ArithRewriter::postRewriteIAnd(TNode t
)
893 Assert(t
.getKind() == kind::IAND
);
894 size_t bsize
= t
.getOperator().getConst
<IntAnd
>().d_size
;
895 NodeManager
* nm
= NodeManager::currentNM();
896 // if constant, we eliminate
897 if (t
[0].isConst() && t
[1].isConst())
899 Node iToBvop
= nm
->mkConst(IntToBitVector(bsize
));
900 Node arg1
= nm
->mkNode(kind::INT_TO_BITVECTOR
, iToBvop
, t
[0]);
901 Node arg2
= nm
->mkNode(kind::INT_TO_BITVECTOR
, iToBvop
, t
[1]);
902 Node bvand
= nm
->mkNode(kind::BITVECTOR_AND
, arg1
, arg2
);
903 Node ret
= nm
->mkNode(kind::BITVECTOR_TO_NAT
, bvand
);
904 return RewriteResponse(REWRITE_AGAIN_FULL
, ret
);
906 else if (t
[0] > t
[1])
908 // ((_ iand k) x y) ---> ((_ iand k) y x) if x > y by node ordering
909 Node ret
= nm
->mkNode(kind::IAND
, t
.getOperator(), t
[1], t
[0]);
910 return RewriteResponse(REWRITE_AGAIN
, ret
);
912 else if (t
[0] == t
[1])
914 // ((_ iand k) x x) ---> (mod x 2^k)
915 Node twok
= nm
->mkConstInt(Rational(Integer(2).pow(bsize
)));
916 Node ret
= nm
->mkNode(kind::INTS_MODULUS
, t
[0], twok
);
917 return RewriteResponse(REWRITE_AGAIN
, ret
);
919 // simplifications involving constants
920 for (unsigned i
= 0; i
< 2; i
++)
926 if (t
[i
].getConst
<Rational
>().sgn() == 0)
928 // ((_ iand k) 0 y) ---> 0
929 return RewriteResponse(REWRITE_DONE
, t
[i
]);
931 if (t
[i
].getConst
<Rational
>().getNumerator() == Integer(2).pow(bsize
) - 1)
933 // ((_ iand k) 111...1 y) ---> (mod y 2^k)
934 Node twok
= nm
->mkConstInt(Rational(Integer(2).pow(bsize
)));
935 Node ret
= nm
->mkNode(kind::INTS_MODULUS
, t
[1 - i
], twok
);
936 return RewriteResponse(REWRITE_AGAIN
, ret
);
939 return RewriteResponse(REWRITE_DONE
, t
);
942 RewriteResponse
ArithRewriter::postRewritePow2(TNode t
)
944 Assert(t
.getKind() == kind::POW2
);
945 NodeManager
* nm
= NodeManager::currentNM();
946 // if constant, we eliminate
949 // pow2 is only supported for integers
950 Assert(t
[0].getType().isInteger());
951 Integer i
= t
[0].getConst
<Rational
>().getNumerator();
954 return RewriteResponse(REWRITE_DONE
, rewriter::mkConst(Integer(0)));
956 // (pow2 t) ---> (pow 2 t) and continue rewriting to eliminate pow
957 Node two
= rewriter::mkConst(Integer(2));
958 Node ret
= nm
->mkNode(kind::POW
, two
, t
[0]);
959 return RewriteResponse(REWRITE_AGAIN
, ret
);
961 return RewriteResponse(REWRITE_DONE
, t
);
964 RewriteResponse
ArithRewriter::preRewriteTranscendental(TNode t
)
966 return RewriteResponse(REWRITE_DONE
, t
);
969 RewriteResponse
ArithRewriter::postRewriteTranscendental(TNode t
)
971 Trace("arith-tf-rewrite")
972 << "Rewrite transcendental function : " << t
<< std::endl
;
973 NodeManager
* nm
= NodeManager::currentNM();
976 case kind::EXPONENTIAL
:
980 Node one
= rewriter::mkConst(Integer(1));
981 if (t
[0].getConst
<Rational
>().sgn() >= 0 && t
[0].getType().isInteger()
984 return RewriteResponse(
986 nm
->mkNode(kind::POW
, nm
->mkNode(kind::EXPONENTIAL
, one
), t
[0]));
990 return RewriteResponse(REWRITE_DONE
, t
);
993 else if (t
[0].getKind() == kind::ADD
)
995 std::vector
<Node
> product
;
996 for (const Node tc
: t
[0])
998 product
.push_back(nm
->mkNode(kind::EXPONENTIAL
, tc
));
1000 // We need to do a full rewrite here, since we can get exponentials of
1001 // constants, e.g. when we are rewriting exp(2 + x)
1002 return RewriteResponse(REWRITE_AGAIN_FULL
,
1003 nm
->mkNode(kind::MULT
, product
));
1010 const Rational
& rat
= t
[0].getConst
<Rational
>();
1013 return RewriteResponse(REWRITE_DONE
, rewriter::mkConst(Integer(0)));
1015 else if (rat
.sgn() == -1)
1017 Node ret
= nm
->mkNode(
1018 kind::NEG
, nm
->mkNode(kind::SINE
, rewriter::mkConst(-rat
)));
1019 return RewriteResponse(REWRITE_AGAIN_FULL
, ret
);
1022 else if ((t
[0].getKind() == MULT
|| t
[0].getKind() == NONLINEAR_MULT
)
1023 && t
[0][0].isConst() && t
[0][0].getConst
<Rational
>().sgn() == -1)
1025 // sin(-n*x) ---> -sin(n*x)
1026 std::vector
<Node
> mchildren(t
[0].begin(), t
[0].end());
1027 mchildren
[0] = nm
->mkConstReal(-t
[0][0].getConst
<Rational
>());
1028 Node ret
= nm
->mkNode(
1030 nm
->mkNode(kind::SINE
, nm
->mkNode(t
[0].getKind(), mchildren
)));
1031 return RewriteResponse(REWRITE_AGAIN_FULL
, ret
);
1035 // get the factor of PI in the argument
1039 std::map
<Node
, Node
> msum
;
1040 if (ArithMSum::getMonomialSum(t
[0], msum
))
1043 std::map
<Node
, Node
>::iterator itm
= msum
.find(pi
);
1044 if (itm
!= msum
.end())
1046 if (itm
->second
.isNull())
1048 pi_factor
= rewriter::mkConst(Integer(1));
1052 pi_factor
= itm
->second
;
1057 rem
= ArithMSum::mkNode(t
[0].getType(), msum
);
1066 // if there is a factor of PI
1067 if (!pi_factor
.isNull())
1069 Trace("arith-tf-rewrite-debug")
1070 << "Process pi factor = " << pi_factor
<< std::endl
;
1071 Rational r
= pi_factor
.getConst
<Rational
>();
1072 Rational r_abs
= r
.abs();
1073 Rational rone
= Rational(1);
1074 Rational rtwo
= Rational(2);
1077 // add/substract 2*pi beyond scope
1078 Rational ra_div_two
= (r_abs
+ rone
) / rtwo
;
1082 new_pi_factor
= nm
->mkConstReal(r
- rtwo
* ra_div_two
.floor());
1086 Assert(r
.sgn() == -1);
1087 new_pi_factor
= nm
->mkConstReal(r
+ rtwo
* ra_div_two
.floor());
1089 Node new_arg
= nm
->mkNode(kind::MULT
, new_pi_factor
, pi
);
1092 new_arg
= nm
->mkNode(kind::ADD
, new_arg
, rem
);
1094 // sin( 2*n*PI + x ) = sin( x )
1095 return RewriteResponse(REWRITE_AGAIN_FULL
,
1096 nm
->mkNode(kind::SINE
, new_arg
));
1098 else if (r_abs
== rone
)
1100 // sin( PI + x ) = -sin( x )
1103 return RewriteResponse(REWRITE_DONE
,
1104 nm
->mkConstReal(Rational(0)));
1108 return RewriteResponse(
1110 nm
->mkNode(kind::NEG
, nm
->mkNode(kind::SINE
, rem
)));
1113 else if (rem
.isNull())
1115 // other rational cases based on Niven's theorem
1116 // (https://en.wikipedia.org/wiki/Niven%27s_theorem)
1117 Integer one
= Integer(1);
1118 Integer two
= Integer(2);
1119 Integer six
= Integer(6);
1120 if (r_abs
.getDenominator() == two
)
1122 Assert(r_abs
.getNumerator() == one
);
1123 return RewriteResponse(REWRITE_DONE
,
1124 nm
->mkConstReal(Rational(r
.sgn())));
1126 else if (r_abs
.getDenominator() == six
)
1128 Integer five
= Integer(5);
1129 if (r_abs
.getNumerator() == one
|| r_abs
.getNumerator() == five
)
1131 return RewriteResponse(
1133 nm
->mkConstReal(Rational(r
.sgn()) / Rational(2)));
1142 return RewriteResponse(
1146 nm
->mkNode(kind::SUB
,
1147 nm
->mkNode(kind::MULT
,
1148 nm
->mkConstReal(Rational(1) / Rational(2)),
1155 return RewriteResponse(REWRITE_AGAIN_FULL
,
1156 nm
->mkNode(kind::DIVISION
,
1157 nm
->mkNode(kind::SINE
, t
[0]),
1158 nm
->mkNode(kind::COSINE
, t
[0])));
1161 case kind::COSECANT
:
1163 return RewriteResponse(REWRITE_AGAIN_FULL
,
1164 nm
->mkNode(kind::DIVISION
,
1165 nm
->mkConstReal(Rational(1)),
1166 nm
->mkNode(kind::SINE
, t
[0])));
1171 return RewriteResponse(REWRITE_AGAIN_FULL
,
1172 nm
->mkNode(kind::DIVISION
,
1173 nm
->mkConstReal(Rational(1)),
1174 nm
->mkNode(kind::COSINE
, t
[0])));
1177 case kind::COTANGENT
:
1179 return RewriteResponse(REWRITE_AGAIN_FULL
,
1180 nm
->mkNode(kind::DIVISION
,
1181 nm
->mkNode(kind::COSINE
, t
[0]),
1182 nm
->mkNode(kind::SINE
, t
[0])));
1187 return RewriteResponse(REWRITE_DONE
, t
);
1190 TrustNode
ArithRewriter::expandDefinition(Node node
)
1192 // call eliminate operators, to eliminate partial operators only
1193 std::vector
<SkolemLemma
> lems
;
1194 TrustNode ret
= d_opElim
.eliminate(node
, lems
, true);
1195 Assert(lems
.empty());
1199 RewriteResponse
ArithRewriter::returnRewrite(TNode t
, Node ret
, Rewrite r
)
1201 Trace("arith-rewriter") << "ArithRewriter : " << t
<< " == " << ret
<< " by "
1203 return RewriteResponse(REWRITE_AGAIN_FULL
, ret
);
}  // namespace arith
}  // namespace theory
}  // namespace cvc5