1 /******************************************************************************
2 * Top contributors (to current version):
3 * Andrew Reynolds, Tim King, Morgan Deters
5 * This file is part of the cvc5 project.
7 * Copyright (c) 2009-2021 by the authors listed in the file AUTHORS
8 * in the top-level source directory and their institutional affiliations.
9 * All rights reserved. See the file COPYING in the top-level source
10 * directory for licensing information.
11 * ****************************************************************************
13 * [[ Add one-line brief description here ]]
15 * [[ Add lengthier description here ]]
16 * \todo document this file
19 #include "theory/arith/arith_rewriter.h"
27 #include "expr/algorithm/flatten.h"
28 #include "smt/logic_exception.h"
29 #include "theory/arith/arith_msum.h"
30 #include "theory/arith/arith_utilities.h"
31 #include "theory/arith/normal_form.h"
32 #include "theory/arith/operator_elim.h"
33 #include "theory/theory.h"
34 #include "util/bitvector.h"
35 #include "util/divisible.h"
36 #include "util/iand.h"
37 #include "util/real_algebraic_number.h"
39 using namespace cvc5::kind
;
48 * Implements an ordering on arithmetic leaf nodes, excluding rationals. As this
49 * comparator is meant to be used on children of Kind::NONLINEAR_MULT, we expect
50 * rationals to be handled separately. Furthermore, we expect there to be only a
51 * single real algebraic number.
52 * It broadly categorizes leaf nodes into real algebraic numbers, integers,
53 * variables, and the rest. The ordering is built as follows:
54 * - real algebraic numbers come first
55 * - real terms come before integer terms
56 * - variables come before non-variable terms
57 * - finally, fall back to node ordering
59 struct LeafNodeComparator
61 /** Implements operator<(a, b) as described above */
62 bool operator()(TNode a
, TNode b
)
64 if (a
== b
) return false;
66 bool aIsRAN
= a
.getKind() == Kind::REAL_ALGEBRAIC_NUMBER
;
67 bool bIsRAN
= b
.getKind() == Kind::REAL_ALGEBRAIC_NUMBER
;
68 if (aIsRAN
!= bIsRAN
) return aIsRAN
;
69 Assert(!aIsRAN
&& !bIsRAN
) << "real algebraic numbers should be combined";
71 bool aIsInt
= a
.getType().isInteger();
72 bool bIsInt
= b
.getType().isInteger();
73 if (aIsInt
!= bIsInt
) return !aIsInt
;
75 bool aIsVar
= a
.isVar();
76 bool bIsVar
= b
.isVar();
77 if (aIsVar
!= bIsVar
) return aIsVar
;
84 * Implements an ordering on arithmetic nonlinear multiplications. As we assume
85 * rationals to be handled separately, we only consider Kind::NONLINEAR_MULT as
86 * multiplication terms. For individual factors of the product, we rely on the
87 * ordering from LeafNodeComparator. Furthermore, we expect products to be
88 * sorted according to LeafNodeComparator. The ordering is built as follows:
89 * - single factors come first (everything that is not NONLINEAR_MULT)
90 * - multiplications with less factors come first
91 * - multiplications are compared lexicographically
93 struct ProductNodeComparator
95 /** Implements operator<(a, b) as described above */
96 bool operator()(TNode a
, TNode b
)
98 if (a
== b
) return false;
100 Assert(a
.getKind() != Kind::MULT
);
101 Assert(b
.getKind() != Kind::MULT
);
103 bool aIsMult
= a
.getKind() == Kind::NONLINEAR_MULT
;
104 bool bIsMult
= b
.getKind() == Kind::NONLINEAR_MULT
;
105 if (aIsMult
!= bIsMult
) return !aIsMult
;
109 return LeafNodeComparator()(a
, b
);
112 size_t aLen
= a
.getNumChildren();
113 size_t bLen
= b
.getNumChildren();
114 if (aLen
!= bLen
) return aLen
< bLen
;
116 for (size_t i
= 0; i
< aLen
; ++i
)
120 return LeafNodeComparator()(a
[i
], b
[i
]);
123 Unreachable() << "Nodes are different, but have the same content";
129 template <typename L
, typename R
>
130 bool evaluateRelation(Kind rel
, const L
& l
, const R
& r
)
134 case Kind::LT
: return l
< r
;
135 case Kind::LEQ
: return l
<= r
;
136 case Kind::EQUAL
: return l
== r
;
137 case Kind::GEQ
: return l
>= r
;
138 case Kind::GT
: return l
> r
;
139 default: Unreachable(); return false;
144 * Check whether the parent has a child that is a constant zero.
145 * If so, return this child. Otherwise, return std::nullopt.
147 template <typename Iterable
>
148 std::optional
<TNode
> getZeroChild(const Iterable
& parent
)
150 for (const auto& node
: parent
)
152 if (node
.isConst() && node
.template getConst
<Rational
>().isZero())
162 ArithRewriter::ArithRewriter(OperatorElim
& oe
) : d_opElim(oe
) {}
164 RewriteResponse
ArithRewriter::preRewrite(TNode t
)
166 Trace("arith-rewriter") << "preRewrite(" << t
<< ")" << std::endl
;
169 auto res
= preRewriteAtom(t
);
170 Trace("arith-rewriter")
171 << res
.d_status
<< " -> " << res
.d_node
<< std::endl
;
174 auto res
= preRewriteTerm(t
);
175 Trace("arith-rewriter") << res
.d_status
<< " -> " << res
.d_node
<< std::endl
;
179 RewriteResponse
ArithRewriter::postRewrite(TNode t
)
181 Trace("arith-rewriter") << "postRewrite(" << t
<< ")" << std::endl
;
184 auto res
= postRewriteAtom(t
);
185 Trace("arith-rewriter")
186 << res
.d_status
<< " -> " << res
.d_node
<< std::endl
;
189 auto res
= postRewriteTerm(t
);
190 Trace("arith-rewriter") << res
.d_status
<< " -> " << res
.d_node
<< std::endl
;
194 RewriteResponse
ArithRewriter::preRewriteAtom(TNode atom
)
196 Assert(isAtom(atom
));
198 NodeManager
* nm
= NodeManager::currentNM();
200 if (isRelationOperator(atom
.getKind()) && atom
[0] == atom
[1])
202 switch (atom
.getKind())
204 case Kind::LT
: return RewriteResponse(REWRITE_DONE
, nm
->mkConst(false));
205 case Kind::LEQ
: return RewriteResponse(REWRITE_DONE
, nm
->mkConst(true));
206 case Kind::EQUAL
: return RewriteResponse(REWRITE_DONE
, nm
->mkConst(true));
207 case Kind::GEQ
: return RewriteResponse(REWRITE_DONE
, nm
->mkConst(true));
208 case Kind::GT
: return RewriteResponse(REWRITE_DONE
, nm
->mkConst(false));
213 switch (atom
.getKind())
216 return RewriteResponse(REWRITE_DONE
,
217 nm
->mkNode(kind::LEQ
, atom
[0], atom
[1]).notNode());
219 return RewriteResponse(REWRITE_DONE
,
220 nm
->mkNode(kind::GEQ
, atom
[0], atom
[1]).notNode());
221 case Kind::IS_INTEGER
:
222 if (atom
[0].getType().isInteger())
224 return RewriteResponse(REWRITE_DONE
, nm
->mkConst(true));
227 case Kind::DIVISIBLE
:
228 if (atom
.getOperator().getConst
<Divisible
>().k
.isOne())
230 return RewriteResponse(REWRITE_DONE
, nm
->mkConst(true));
236 return RewriteResponse(REWRITE_DONE
, atom
);
239 RewriteResponse
ArithRewriter::postRewriteAtom(TNode atom
)
241 Assert(isAtom(atom
));
242 if (atom
.getKind() == kind::IS_INTEGER
)
244 return rewriteExtIntegerOp(atom
);
246 else if (atom
.getKind() == kind::DIVISIBLE
)
248 if (atom
[0].isConst())
250 return RewriteResponse(REWRITE_DONE
,
251 NodeManager::currentNM()->mkConst(bool(
252 (atom
[0].getConst
<Rational
>()
253 / atom
.getOperator().getConst
<Divisible
>().k
)
256 if (atom
.getOperator().getConst
<Divisible
>().k
.isOne())
258 return RewriteResponse(REWRITE_DONE
,
259 NodeManager::currentNM()->mkConst(true));
261 NodeManager
* nm
= NodeManager::currentNM();
262 return RewriteResponse(
264 nm
->mkNode(kind::EQUAL
,
265 nm
->mkNode(kind::INTS_MODULUS_TOTAL
,
267 nm
->mkConstInt(Rational(
268 atom
.getOperator().getConst
<Divisible
>().k
))),
269 nm
->mkConstInt(Rational(0))));
273 TNode left
= atom
[0];
274 TNode right
= atom
[1];
276 auto* nm
= NodeManager::currentNM();
279 const Rational
& l
= left
.getConst
<Rational
>();
282 const Rational
& r
= right
.getConst
<Rational
>();
283 return RewriteResponse(
284 REWRITE_DONE
, nm
->mkConst(evaluateRelation(atom
.getKind(), l
, r
)));
286 else if (right
.getKind() == Kind::REAL_ALGEBRAIC_NUMBER
)
288 const RealAlgebraicNumber
& r
=
289 right
.getOperator().getConst
<RealAlgebraicNumber
>();
290 return RewriteResponse(
291 REWRITE_DONE
, nm
->mkConst(evaluateRelation(atom
.getKind(), l
, r
)));
294 else if (left
.getKind() == Kind::REAL_ALGEBRAIC_NUMBER
)
296 const RealAlgebraicNumber
& l
=
297 left
.getOperator().getConst
<RealAlgebraicNumber
>();
300 const Rational
& r
= right
.getConst
<Rational
>();
301 return RewriteResponse(
302 REWRITE_DONE
, nm
->mkConst(evaluateRelation(atom
.getKind(), l
, r
)));
304 else if (right
.getKind() == Kind::REAL_ALGEBRAIC_NUMBER
)
306 const RealAlgebraicNumber
& r
=
307 right
.getOperator().getConst
<RealAlgebraicNumber
>();
308 return RewriteResponse(
309 REWRITE_DONE
, nm
->mkConst(evaluateRelation(atom
.getKind(), l
, r
)));
313 Polynomial pleft
= Polynomial::parsePolynomial(left
);
314 Polynomial pright
= Polynomial::parsePolynomial(right
);
316 Debug("arith::rewriter") << "pleft " << pleft
.getNode() << std::endl
;
317 Debug("arith::rewriter") << "pright " << pright
.getNode() << std::endl
;
319 Comparison cmp
= Comparison::mkComparison(atom
.getKind(), pleft
, pright
);
320 Assert(cmp
.isNormalForm());
321 return RewriteResponse(REWRITE_DONE
, cmp
.getNode());
324 bool ArithRewriter::isAtom(TNode n
) {
325 Kind k
= n
.getKind();
326 return arith::isRelationOperator(k
) || k
== kind::IS_INTEGER
327 || k
== kind::DIVISIBLE
;
330 RewriteResponse
ArithRewriter::rewriteConstant(TNode t
){
332 Assert(t
.getKind() == CONST_RATIONAL
|| t
.getKind() == CONST_INTEGER
);
334 return RewriteResponse(REWRITE_DONE
, t
);
337 RewriteResponse
ArithRewriter::rewriteRAN(TNode t
)
339 Assert(t
.getKind() == REAL_ALGEBRAIC_NUMBER
);
341 const RealAlgebraicNumber
& r
=
342 t
.getOperator().getConst
<RealAlgebraicNumber
>();
345 return RewriteResponse(
346 REWRITE_DONE
, NodeManager::currentNM()->mkConstReal(r
.toRational()));
349 return RewriteResponse(REWRITE_DONE
, t
);
352 RewriteResponse
ArithRewriter::rewriteVariable(TNode t
){
355 return RewriteResponse(REWRITE_DONE
, t
);
358 RewriteResponse
ArithRewriter::rewriteSub(TNode t
)
360 Assert(t
.getKind() == kind::SUB
);
361 Assert(t
.getNumChildren() == 2);
363 auto* nm
= NodeManager::currentNM();
367 return RewriteResponse(REWRITE_DONE
,
368 nm
->mkConstRealOrInt(t
.getType(), Rational(0)));
370 return RewriteResponse(
372 nm
->mkNode(Kind::PLUS
, t
[0], makeUnaryMinusNode(t
[1])));
375 RewriteResponse
ArithRewriter::rewriteNeg(TNode t
, bool pre
)
377 Assert(t
.getKind() == kind::NEG
);
381 Rational neg
= -(t
[0].getConst
<Rational
>());
382 NodeManager
* nm
= NodeManager::currentNM();
383 return RewriteResponse(REWRITE_DONE
,
384 nm
->mkConstRealOrInt(t
[0].getType(), neg
));
386 if (t
[0].getKind() == Kind::REAL_ALGEBRAIC_NUMBER
)
388 const RealAlgebraicNumber
& r
=
389 t
[0].getOperator().getConst
<RealAlgebraicNumber
>();
390 NodeManager
* nm
= NodeManager::currentNM();
391 return RewriteResponse(REWRITE_DONE
, nm
->mkRealAlgebraicNumber(-r
));
394 Node noUminus
= makeUnaryMinusNode(t
[0]);
396 return RewriteResponse(REWRITE_DONE
, noUminus
);
398 return RewriteResponse(REWRITE_AGAIN
, noUminus
);
401 RewriteResponse
ArithRewriter::preRewriteTerm(TNode t
){
403 return rewriteConstant(t
);
405 return rewriteVariable(t
);
407 switch(Kind k
= t
.getKind()){
408 case kind::REAL_ALGEBRAIC_NUMBER
: return rewriteRAN(t
);
409 case kind::SUB
: return rewriteSub(t
);
410 case kind::NEG
: return rewriteNeg(t
, true);
412 case kind::DIVISION_TOTAL
: return rewriteDiv(t
, true);
413 case kind::PLUS
: return preRewritePlus(t
);
415 case kind::NONLINEAR_MULT
: return preRewriteMult(t
);
416 case kind::IAND
: return RewriteResponse(REWRITE_DONE
, t
);
417 case kind::POW2
: return RewriteResponse(REWRITE_DONE
, t
);
418 case kind::EXPONENTIAL
:
424 case kind::COTANGENT
:
426 case kind::ARCCOSINE
:
427 case kind::ARCTANGENT
:
428 case kind::ARCCOSECANT
:
429 case kind::ARCSECANT
:
430 case kind::ARCCOTANGENT
:
431 case kind::SQRT
: return preRewriteTranscendental(t
);
432 case kind::INTS_DIVISION
:
433 case kind::INTS_MODULUS
: return rewriteIntsDivMod(t
, true);
434 case kind::INTS_DIVISION_TOTAL
:
435 case kind::INTS_MODULUS_TOTAL
: return rewriteIntsDivModTotal(t
, true);
436 case kind::ABS
: return rewriteAbs(t
);
437 case kind::IS_INTEGER
:
438 case kind::TO_INTEGER
: return RewriteResponse(REWRITE_DONE
, t
);
440 case kind::CAST_TO_REAL
: return RewriteResponse(REWRITE_DONE
, t
[0]);
441 case kind::POW
: return RewriteResponse(REWRITE_DONE
, t
);
442 case kind::PI
: return RewriteResponse(REWRITE_DONE
, t
);
443 default: Unhandled() << k
;
448 RewriteResponse
ArithRewriter::postRewriteTerm(TNode t
){
450 return rewriteConstant(t
);
452 return rewriteVariable(t
);
454 Trace("arith-rewriter") << "postRewriteTerm: " << t
<< std::endl
;
456 case kind::REAL_ALGEBRAIC_NUMBER
: return rewriteRAN(t
);
457 case kind::SUB
: return rewriteSub(t
);
458 case kind::NEG
: return rewriteNeg(t
, false);
460 case kind::DIVISION_TOTAL
: return rewriteDiv(t
, false);
461 case kind::PLUS
: return postRewritePlus(t
);
463 case kind::NONLINEAR_MULT
: return postRewriteMult(t
);
464 case kind::IAND
: return postRewriteIAnd(t
);
465 case kind::POW2
: return postRewritePow2(t
);
466 case kind::EXPONENTIAL
:
472 case kind::COTANGENT
:
474 case kind::ARCCOSINE
:
475 case kind::ARCTANGENT
:
476 case kind::ARCCOSECANT
:
477 case kind::ARCSECANT
:
478 case kind::ARCCOTANGENT
:
479 case kind::SQRT
: return postRewriteTranscendental(t
);
480 case kind::INTS_DIVISION
:
481 case kind::INTS_MODULUS
: return rewriteIntsDivMod(t
, false);
482 case kind::INTS_DIVISION_TOTAL
:
483 case kind::INTS_MODULUS_TOTAL
: return rewriteIntsDivModTotal(t
, false);
484 case kind::ABS
: return rewriteAbs(t
);
486 case kind::CAST_TO_REAL
: return RewriteResponse(REWRITE_DONE
, t
[0]);
487 case kind::TO_INTEGER
: return rewriteExtIntegerOp(t
);
492 const Rational
& exp
= t
[1].getConst
<Rational
>();
495 return RewriteResponse(REWRITE_DONE
,
496 NodeManager::currentNM()->mkConstRealOrInt(
497 t
.getType(), Rational(1)));
498 }else if(exp
.sgn() > 0 && exp
.isIntegral()){
499 cvc5::Rational
r(expr::NodeValue::MAX_CHILDREN
);
502 unsigned num
= exp
.getNumerator().toUnsignedInt();
504 return RewriteResponse(REWRITE_AGAIN
, base
);
506 NodeBuilder
nb(kind::MULT
);
507 for(unsigned i
=0; i
< num
; ++i
){
510 Assert(nb
.getNumChildren() > 0);
512 return RewriteResponse(REWRITE_AGAIN
, mult
);
517 else if (t
[0].isConst()
518 && t
[0].getConst
<Rational
>().getNumerator().toUnsignedInt()
521 return RewriteResponse(
522 REWRITE_DONE
, NodeManager::currentNM()->mkNode(kind::POW2
, t
[1]));
525 // Todo improve the exception thrown
526 std::stringstream ss
;
527 ss
<< "The exponent of the POW(^) operator can only be a positive "
528 "integral constant below "
529 << (expr::NodeValue::MAX_CHILDREN
+ 1) << ". ";
530 ss
<< "Exception occurred in:" << std::endl
;
532 throw LogicException(ss
.str());
535 return RewriteResponse(REWRITE_DONE
, t
);
543 RewriteResponse
ArithRewriter::preRewritePlus(TNode t
){
544 Assert(t
.getKind() == kind::PLUS
);
545 return RewriteResponse(REWRITE_DONE
, expr::algorithm::flatten(t
));
548 RewriteResponse
ArithRewriter::postRewritePlus(TNode t
){
549 Assert(t
.getKind() == kind::PLUS
);
550 Assert(t
.getNumChildren() > 1);
553 Node flat
= expr::algorithm::flatten(t
);
556 return RewriteResponse(REWRITE_AGAIN
, flat
);
561 RealAlgebraicNumber ran
;
562 std::vector
<Monomial
> monomials
;
563 std::vector
<Polynomial
> polynomials
;
565 for (const auto& child
: t
)
569 if (child
.getConst
<Rational
>().isZero())
573 rational
+= child
.getConst
<Rational
>();
575 else if (child
.getKind() == Kind::REAL_ALGEBRAIC_NUMBER
)
577 ran
+= child
.getOperator().getConst
<RealAlgebraicNumber
>();
579 else if (Monomial::isMember(child
))
581 monomials
.emplace_back(Monomial::parseMonomial(child
));
585 polynomials
.emplace_back(Polynomial::parsePolynomial(child
));
589 if(!monomials
.empty()){
590 Monomial::sort(monomials
);
591 Monomial::combineAdjacentMonomials(monomials
);
592 polynomials
.emplace_back(Polynomial::mkPolynomial(monomials
));
594 if (!rational
.isZero())
596 polynomials
.emplace_back(
597 Polynomial::mkPolynomial(Constant::mkConstant(rational
)));
600 Polynomial poly
= Polynomial::sumPolynomials(polynomials
);
604 return RewriteResponse(REWRITE_DONE
, poly
.getNode());
606 if (poly
.containsConstant())
608 ran
+= RealAlgebraicNumber(poly
.getHead().getConstant().getValue());
609 if (!poly
.isConstant())
611 poly
= poly
.getTail();
615 auto* nm
= NodeManager::currentNM();
616 if (poly
.isConstant())
618 return RewriteResponse(REWRITE_DONE
, nm
->mkRealAlgebraicNumber(ran
));
620 return RewriteResponse(
622 nm
->mkNode(Kind::PLUS
, nm
->mkRealAlgebraicNumber(ran
), poly
.getNode()));
625 RewriteResponse
ArithRewriter::preRewriteMult(TNode node
)
627 Assert(node
.getKind() == kind::MULT
628 || node
.getKind() == kind::NONLINEAR_MULT
);
630 auto res
= getZeroChild(node
);
633 return RewriteResponse(REWRITE_DONE
, *res
);
635 return RewriteResponse(REWRITE_DONE
, node
);
638 RewriteResponse
ArithRewriter::postRewriteMult(TNode t
){
639 Assert(t
.getKind() == kind::MULT
|| t
.getKind() == kind::NONLINEAR_MULT
);
640 Assert(t
.getNumChildren() >= 2);
642 if (auto res
= getZeroChild(t
); res
)
644 return RewriteResponse(REWRITE_DONE
, *res
);
647 Rational rational
= Rational(1);
648 RealAlgebraicNumber ran
= RealAlgebraicNumber(Integer(1));
649 Polynomial poly
= Polynomial::mkOne();
651 for (const auto& child
: t
)
655 if (child
.getConst
<Rational
>().isZero())
657 return RewriteResponse(REWRITE_DONE
, child
);
659 rational
*= child
.getConst
<Rational
>();
661 else if (child
.getKind() == Kind::REAL_ALGEBRAIC_NUMBER
)
663 ran
*= child
.getOperator().getConst
<RealAlgebraicNumber
>();
667 poly
= poly
* Polynomial::parsePolynomial(child
);
671 if (!rational
.isOne())
673 poly
= poly
* rational
;
677 return RewriteResponse(REWRITE_DONE
, poly
.getNode());
679 auto* nm
= NodeManager::currentNM();
680 if (poly
.isConstant())
682 ran
*= RealAlgebraicNumber(poly
.getHead().getConstant().getValue());
683 return RewriteResponse(REWRITE_DONE
, nm
->mkRealAlgebraicNumber(ran
));
685 return RewriteResponse(
688 Kind::MULT
, nm
->mkRealAlgebraicNumber(ran
), poly
.getNode()));
691 RewriteResponse
ArithRewriter::postRewritePow2(TNode t
)
693 Assert(t
.getKind() == kind::POW2
);
694 NodeManager
* nm
= NodeManager::currentNM();
695 // if constant, we eliminate
698 // pow2 is only supported for integers
699 Assert(t
[0].getType().isInteger());
700 Integer i
= t
[0].getConst
<Rational
>().getNumerator();
703 return RewriteResponse(REWRITE_DONE
, nm
->mkConstInt(Rational(0)));
705 // (pow2 t) ---> (pow 2 t) and continue rewriting to eliminate pow
706 Node two
= nm
->mkConstInt(Rational(Integer(2)));
707 Node ret
= nm
->mkNode(kind::POW
, two
, t
[0]);
708 return RewriteResponse(REWRITE_AGAIN
, ret
);
710 return RewriteResponse(REWRITE_DONE
, t
);
713 RewriteResponse
ArithRewriter::postRewriteIAnd(TNode t
)
715 Assert(t
.getKind() == kind::IAND
);
716 size_t bsize
= t
.getOperator().getConst
<IntAnd
>().d_size
;
717 NodeManager
* nm
= NodeManager::currentNM();
718 // if constant, we eliminate
719 if (t
[0].isConst() && t
[1].isConst())
721 Node iToBvop
= nm
->mkConst(IntToBitVector(bsize
));
722 Node arg1
= nm
->mkNode(kind::INT_TO_BITVECTOR
, iToBvop
, t
[0]);
723 Node arg2
= nm
->mkNode(kind::INT_TO_BITVECTOR
, iToBvop
, t
[1]);
724 Node bvand
= nm
->mkNode(kind::BITVECTOR_AND
, arg1
, arg2
);
725 Node ret
= nm
->mkNode(kind::BITVECTOR_TO_NAT
, bvand
);
726 return RewriteResponse(REWRITE_AGAIN_FULL
, ret
);
728 else if (t
[0] > t
[1])
730 // ((_ iand k) x y) ---> ((_ iand k) y x) if x > y by node ordering
731 Node ret
= nm
->mkNode(kind::IAND
, t
.getOperator(), t
[1], t
[0]);
732 return RewriteResponse(REWRITE_AGAIN
, ret
);
734 else if (t
[0] == t
[1])
736 // ((_ iand k) x x) ---> x
737 return RewriteResponse(REWRITE_DONE
, t
[0]);
739 // simplifications involving constants
740 for (unsigned i
= 0; i
< 2; i
++)
746 if (t
[i
].getConst
<Rational
>().sgn() == 0)
748 // ((_ iand k) 0 y) ---> 0
749 return RewriteResponse(REWRITE_DONE
, t
[i
]);
751 if (t
[i
].getConst
<Rational
>().getNumerator() == Integer(2).pow(bsize
) - 1)
753 // ((_ iand k) 111...1 y) ---> (mod y 2^k)
754 Node twok
= nm
->mkConstInt(Rational(Integer(2).pow(bsize
)));
755 Node ret
= nm
->mkNode(kind::INTS_MODULUS
, t
[1-i
], twok
);
756 return RewriteResponse(REWRITE_AGAIN
, ret
);
759 return RewriteResponse(REWRITE_DONE
, t
);
762 RewriteResponse
ArithRewriter::preRewriteTranscendental(TNode t
) {
763 return RewriteResponse(REWRITE_DONE
, t
);
766 RewriteResponse
ArithRewriter::postRewriteTranscendental(TNode t
) {
767 Trace("arith-tf-rewrite") << "Rewrite transcendental function : " << t
<< std::endl
;
768 NodeManager
* nm
= NodeManager::currentNM();
769 switch( t
.getKind() ){
770 case kind::EXPONENTIAL
: {
773 Node one
= nm
->mkConstReal(Rational(1));
774 if(t
[0].getConst
<Rational
>().sgn()>=0 && t
[0].getType().isInteger() && t
[0]!=one
){
775 return RewriteResponse(
777 nm
->mkNode(kind::POW
, nm
->mkNode(kind::EXPONENTIAL
, one
), t
[0]));
779 return RewriteResponse(REWRITE_DONE
, t
);
782 else if (t
[0].getKind() == kind::PLUS
)
784 std::vector
<Node
> product
;
785 for (const Node tc
: t
[0])
787 product
.push_back(nm
->mkNode(kind::EXPONENTIAL
, tc
));
789 // We need to do a full rewrite here, since we can get exponentials of
790 // constants, e.g. when we are rewriting exp(2 + x)
791 return RewriteResponse(REWRITE_AGAIN_FULL
,
792 nm
->mkNode(kind::MULT
, product
));
799 const Rational
& rat
= t
[0].getConst
<Rational
>();
801 return RewriteResponse(REWRITE_DONE
, nm
->mkConstReal(Rational(0)));
803 else if (rat
.sgn() == -1)
805 Node ret
= nm
->mkNode(kind::NEG
,
806 nm
->mkNode(kind::SINE
, nm
->mkConstReal(-rat
)));
807 return RewriteResponse(REWRITE_AGAIN_FULL
, ret
);
810 // get the factor of PI in the argument
814 std::map
<Node
, Node
> msum
;
815 if (ArithMSum::getMonomialSum(t
[0], msum
))
818 std::map
<Node
, Node
>::iterator itm
= msum
.find(pi
);
819 if (itm
!= msum
.end())
821 if (itm
->second
.isNull())
823 pi_factor
= nm
->mkConstReal(Rational(1));
827 pi_factor
= itm
->second
;
832 rem
= ArithMSum::mkNode(t
[0].getType(), msum
);
841 // if there is a factor of PI
842 if( !pi_factor
.isNull() ){
843 Trace("arith-tf-rewrite-debug") << "Process pi factor = " << pi_factor
<< std::endl
;
844 Rational r
= pi_factor
.getConst
<Rational
>();
845 Rational r_abs
= r
.abs();
846 Rational rone
= Rational(1);
847 Rational rtwo
= Rational(2);
850 //add/substract 2*pi beyond scope
851 Rational ra_div_two
= (r_abs
+ rone
) / rtwo
;
854 new_pi_factor
= nm
->mkConstReal(r
- rtwo
* ra_div_two
.floor());
856 Assert(r
.sgn() == -1);
857 new_pi_factor
= nm
->mkConstReal(r
+ rtwo
* ra_div_two
.floor());
859 Node new_arg
= nm
->mkNode(kind::MULT
, new_pi_factor
, pi
);
862 new_arg
= nm
->mkNode(kind::PLUS
, new_arg
, rem
);
864 // sin( 2*n*PI + x ) = sin( x )
865 return RewriteResponse(REWRITE_AGAIN_FULL
,
866 nm
->mkNode(kind::SINE
, new_arg
));
868 else if (r_abs
== rone
)
870 // sin( PI + x ) = -sin( x )
873 return RewriteResponse(REWRITE_DONE
, nm
->mkConstReal(Rational(0)));
877 return RewriteResponse(
879 nm
->mkNode(kind::NEG
, nm
->mkNode(kind::SINE
, rem
)));
882 else if (rem
.isNull())
884 // other rational cases based on Niven's theorem
885 // (https://en.wikipedia.org/wiki/Niven%27s_theorem)
886 Integer one
= Integer(1);
887 Integer two
= Integer(2);
888 Integer six
= Integer(6);
889 if (r_abs
.getDenominator() == two
)
891 Assert(r_abs
.getNumerator() == one
);
892 return RewriteResponse(REWRITE_DONE
,
893 nm
->mkConstReal(Rational(r
.sgn())));
895 else if (r_abs
.getDenominator() == six
)
897 Integer five
= Integer(5);
898 if (r_abs
.getNumerator() == one
|| r_abs
.getNumerator() == five
)
900 return RewriteResponse(
902 nm
->mkConstReal(Rational(r
.sgn()) / Rational(2)));
910 return RewriteResponse(
914 nm
->mkNode(kind::SUB
,
915 nm
->mkNode(kind::MULT
,
916 nm
->mkConstReal(Rational(1) / Rational(2)),
923 return RewriteResponse(REWRITE_AGAIN_FULL
,
924 nm
->mkNode(kind::DIVISION
,
925 nm
->mkNode(kind::SINE
, t
[0]),
926 nm
->mkNode(kind::COSINE
, t
[0])));
931 return RewriteResponse(REWRITE_AGAIN_FULL
,
932 nm
->mkNode(kind::DIVISION
,
933 nm
->mkConstReal(Rational(1)),
934 nm
->mkNode(kind::SINE
, t
[0])));
939 return RewriteResponse(REWRITE_AGAIN_FULL
,
940 nm
->mkNode(kind::DIVISION
,
941 nm
->mkConstReal(Rational(1)),
942 nm
->mkNode(kind::COSINE
, t
[0])));
945 case kind::COTANGENT
:
947 return RewriteResponse(REWRITE_AGAIN_FULL
,
948 nm
->mkNode(kind::DIVISION
,
949 nm
->mkNode(kind::COSINE
, t
[0]),
950 nm
->mkNode(kind::SINE
, t
[0])));
956 return RewriteResponse(REWRITE_DONE
, t
);
959 Node
ArithRewriter::makeUnaryMinusNode(TNode n
){
960 NodeManager
* nm
= NodeManager::currentNM();
961 Rational
qNegOne(-1);
962 return nm
->mkNode(kind::MULT
, nm
->mkConstRealOrInt(n
.getType(), qNegOne
), n
);
965 RewriteResponse
ArithRewriter::rewriteDiv(TNode t
, bool pre
){
966 Assert(t
.getKind() == kind::DIVISION_TOTAL
|| t
.getKind() == kind::DIVISION
);
967 Assert(t
.getNumChildren() == 2);
973 NodeManager
* nm
= NodeManager::currentNM();
974 const Rational
& den
= right
.getConst
<Rational
>();
977 if(t
.getKind() == kind::DIVISION_TOTAL
){
978 return RewriteResponse(REWRITE_DONE
, nm
->mkConstReal(0));
980 // This is unsupported, but this is not a good place to complain
981 return RewriteResponse(REWRITE_DONE
, t
);
984 Assert(den
!= Rational(0));
988 const Rational
& num
= left
.getConst
<Rational
>();
989 return RewriteResponse(REWRITE_DONE
, nm
->mkConstReal(num
/ den
));
991 if (left
.getKind() == Kind::REAL_ALGEBRAIC_NUMBER
)
993 const RealAlgebraicNumber
& num
=
994 left
.getOperator().getConst
<RealAlgebraicNumber
>();
995 return RewriteResponse(
997 nm
->mkRealAlgebraicNumber(num
/ RealAlgebraicNumber(den
)));
1000 Node result
= nm
->mkConstReal(den
.inverse());
1001 Node mult
= NodeManager::currentNM()->mkNode(kind::MULT
, left
, result
);
1004 return RewriteResponse(REWRITE_DONE
, mult
);
1008 return RewriteResponse(REWRITE_AGAIN
, mult
);
1011 if (right
.getKind() == Kind::REAL_ALGEBRAIC_NUMBER
)
1013 NodeManager
* nm
= NodeManager::currentNM();
1014 const RealAlgebraicNumber
& den
=
1015 right
.getOperator().getConst
<RealAlgebraicNumber
>();
1018 const Rational
& num
= left
.getConst
<Rational
>();
1019 return RewriteResponse(
1021 nm
->mkRealAlgebraicNumber(RealAlgebraicNumber(num
) / den
));
1023 if (left
.getKind() == Kind::REAL_ALGEBRAIC_NUMBER
)
1025 const RealAlgebraicNumber
& num
=
1026 left
.getOperator().getConst
<RealAlgebraicNumber
>();
1027 return RewriteResponse(REWRITE_DONE
,
1028 nm
->mkRealAlgebraicNumber(num
/ den
));
1031 Node result
= nm
->mkRealAlgebraicNumber(inverse(den
));
1032 Node mult
= NodeManager::currentNM()->mkNode(kind::MULT
,left
,result
);
1034 return RewriteResponse(REWRITE_DONE
, mult
);
1036 return RewriteResponse(REWRITE_AGAIN
, mult
);
1039 return RewriteResponse(REWRITE_DONE
, t
);
1042 RewriteResponse
ArithRewriter::rewriteAbs(TNode t
)
1044 Assert(t
.getKind() == Kind::ABS
);
1045 Assert(t
.getNumChildren() == 1);
1049 const Rational
& rat
= t
[0].getConst
<Rational
>();
1052 return RewriteResponse(REWRITE_DONE
, t
[0]);
1054 return RewriteResponse(
1056 NodeManager::currentNM()->mkConstRealOrInt(t
[0].getType(), -rat
));
1058 if (t
[0].getKind() == Kind::REAL_ALGEBRAIC_NUMBER
)
1060 const RealAlgebraicNumber
& ran
=
1061 t
[0].getOperator().getConst
<RealAlgebraicNumber
>();
1062 if (ran
>= RealAlgebraicNumber())
1064 return RewriteResponse(REWRITE_DONE
, t
[0]);
1066 return RewriteResponse(
1067 REWRITE_DONE
, NodeManager::currentNM()->mkRealAlgebraicNumber(-ran
));
1069 return RewriteResponse(REWRITE_DONE
, t
);
1072 RewriteResponse
ArithRewriter::rewriteIntsDivMod(TNode t
, bool pre
)
1074 NodeManager
* nm
= NodeManager::currentNM();
1075 Kind k
= t
.getKind();
1076 if (k
== kind::INTS_MODULUS
)
1078 if (t
[1].isConst() && !t
[1].getConst
<Rational
>().isZero())
1080 // can immediately replace by INTS_MODULUS_TOTAL
1081 Node ret
= nm
->mkNode(kind::INTS_MODULUS_TOTAL
, t
[0], t
[1]);
1082 return returnRewrite(t
, ret
, Rewrite::MOD_TOTAL_BY_CONST
);
1085 if (k
== kind::INTS_DIVISION
)
1087 if (t
[1].isConst() && !t
[1].getConst
<Rational
>().isZero())
1089 // can immediately replace by INTS_DIVISION_TOTAL
1090 Node ret
= nm
->mkNode(kind::INTS_DIVISION_TOTAL
, t
[0], t
[1]);
1091 return returnRewrite(t
, ret
, Rewrite::DIV_TOTAL_BY_CONST
);
1094 return RewriteResponse(REWRITE_DONE
, t
);
1097 RewriteResponse
ArithRewriter::rewriteExtIntegerOp(TNode t
)
1099 Assert(t
.getKind() == kind::TO_INTEGER
|| t
.getKind() == kind::IS_INTEGER
);
1100 bool isPred
= t
.getKind() == kind::IS_INTEGER
;
1101 NodeManager
* nm
= NodeManager::currentNM();
1107 ret
= nm
->mkConst(t
[0].getConst
<Rational
>().isIntegral());
1111 ret
= nm
->mkConstInt(Rational(t
[0].getConst
<Rational
>().floor()));
1113 return returnRewrite(t
, ret
, Rewrite::INT_EXT_CONST
);
1115 if (t
[0].getType().isInteger())
1117 Node ret
= isPred
? nm
->mkConst(true) : Node(t
[0]);
1118 return returnRewrite(t
, ret
, Rewrite::INT_EXT_INT
);
1120 if (t
[0].getKind() == kind::PI
)
1122 Node ret
= isPred
? nm
->mkConst(false) : nm
->mkConstReal(Rational(3));
1123 return returnRewrite(t
, ret
, Rewrite::INT_EXT_PI
);
1125 return RewriteResponse(REWRITE_DONE
, t
);
1128 RewriteResponse
ArithRewriter::rewriteIntsDivModTotal(TNode t
, bool pre
)
1132 // do not rewrite at prewrite.
1133 return RewriteResponse(REWRITE_DONE
, t
);
1135 NodeManager
* nm
= NodeManager::currentNM();
1136 Kind k
= t
.getKind();
1137 Assert(k
== kind::INTS_MODULUS_TOTAL
|| k
== kind::INTS_DIVISION_TOTAL
);
1140 bool dIsConstant
= d
.isConst();
1141 if(dIsConstant
&& d
.getConst
<Rational
>().isZero()){
1142 // (div x 0) ---> 0 or (mod x 0) ---> 0
1143 return returnRewrite(t
, nm
->mkConstInt(0), Rewrite::DIV_MOD_BY_ZERO
);
1144 }else if(dIsConstant
&& d
.getConst
<Rational
>().isOne()){
1145 if (k
== kind::INTS_MODULUS_TOTAL
)
1148 return returnRewrite(t
, nm
->mkConstInt(0), Rewrite::MOD_BY_ONE
);
1150 Assert(k
== kind::INTS_DIVISION_TOTAL
);
1152 return returnRewrite(t
, n
, Rewrite::DIV_BY_ONE
);
1154 else if (dIsConstant
&& d
.getConst
<Rational
>().sgn() < 0)
1157 // (div x (- c)) ---> (- (div x c))
1158 // (mod x (- c)) ---> (mod x c)
1159 Node nn
= nm
->mkNode(k
, t
[0], nm
->mkConstInt(-t
[1].getConst
<Rational
>()));
1160 Node ret
= (k
== kind::INTS_DIVISION
|| k
== kind::INTS_DIVISION_TOTAL
)
1161 ? nm
->mkNode(kind::NEG
, nn
)
1163 return returnRewrite(t
, ret
, Rewrite::DIV_MOD_PULL_NEG_DEN
);
1165 else if (dIsConstant
&& n
.isConst())
1167 Assert(d
.getConst
<Rational
>().isIntegral());
1168 Assert(n
.getConst
<Rational
>().isIntegral());
1169 Assert(!d
.getConst
<Rational
>().isZero());
1170 Integer di
= d
.getConst
<Rational
>().getNumerator();
1171 Integer ni
= n
.getConst
<Rational
>().getNumerator();
1173 bool isDiv
= (k
== kind::INTS_DIVISION
|| k
== kind::INTS_DIVISION_TOTAL
);
1175 Integer result
= isDiv
? ni
.euclidianDivideQuotient(di
) : ni
.euclidianDivideRemainder(di
);
1177 // constant evaluation
1178 // (mod c1 c2) ---> c3 or (div c1 c2) ---> c3
1179 Node resultNode
= nm
->mkConstInt(Rational(result
));
1180 return returnRewrite(t
, resultNode
, Rewrite::CONST_EVAL
);
1182 if (k
== kind::INTS_MODULUS_TOTAL
)
1184 // Note these rewrites do not need to account for modulus by zero as being
1185 // a UF, which is handled by the reduction of INTS_MODULUS.
1186 Kind k0
= t
[0].getKind();
1187 if (k0
== kind::INTS_MODULUS_TOTAL
&& t
[0][1] == t
[1])
1189 // (mod (mod x c) c) --> (mod x c)
1190 return returnRewrite(t
, t
[0], Rewrite::MOD_OVER_MOD
);
1192 else if (k0
== kind::NONLINEAR_MULT
|| k0
== kind::MULT
|| k0
== kind::PLUS
)
1195 std::vector
<Node
> newChildren
;
1196 bool childChanged
= false;
1197 for (const Node
& tc
: t
[0])
1199 if (tc
.getKind() == kind::INTS_MODULUS_TOTAL
&& tc
[1] == t
[1])
1201 newChildren
.push_back(tc
[0]);
1202 childChanged
= true;
1205 newChildren
.push_back(tc
);
1209 // (mod (op ... (mod x c) ...) c) ---> (mod (op ... x ...) c) where
1210 // op is one of { NONLINEAR_MULT, MULT, PLUS }.
1211 Node ret
= nm
->mkNode(k0
, newChildren
);
1212 ret
= nm
->mkNode(kind::INTS_MODULUS_TOTAL
, ret
, t
[1]);
1213 return returnRewrite(t
, ret
, Rewrite::MOD_CHILD_MOD
);
1219 Assert(k
== kind::INTS_DIVISION_TOTAL
);
1220 // Note these rewrites do not need to account for division by zero as being
1221 // a UF, which is handled by the reduction of INTS_DIVISION.
1222 if (t
[0].getKind() == kind::INTS_MODULUS_TOTAL
&& t
[0][1] == t
[1])
1224 // (div (mod x c) c) --> 0
1225 Node ret
= nm
->mkConstInt(0);
1226 return returnRewrite(t
, ret
, Rewrite::DIV_OVER_MOD
);
1229 return RewriteResponse(REWRITE_DONE
, t
);
1232 TrustNode
ArithRewriter::expandDefinition(Node node
)
1234 // call eliminate operators, to eliminate partial operators only
1235 std::vector
<SkolemLemma
> lems
;
1236 TrustNode ret
= d_opElim
.eliminate(node
, lems
, true);
1237 Assert(lems
.empty());
1241 RewriteResponse
ArithRewriter::returnRewrite(TNode t
, Node ret
, Rewrite r
)
1243 Trace("arith-rewrite") << "ArithRewriter : " << t
<< " == " << ret
<< " by "
1245 return RewriteResponse(REWRITE_AGAIN_FULL
, ret
);
1248 } // namespace arith
1249 } // namespace theory