1 /******************************************************************************
2 * Top contributors (to current version):
3 * Andrew Reynolds, Tim King, Morgan Deters
5 * This file is part of the cvc5 project.
7 * Copyright (c) 2009-2021 by the authors listed in the file AUTHORS
8 * in the top-level source directory and their institutional affiliations.
9 * All rights reserved. See the file COPYING in the top-level source
10 * directory for licensing information.
11 * ****************************************************************************
13 * [[ Add one-line brief description here ]]
15 * [[ Add lengthier description here ]]
16 * \todo document this file
19 #include "theory/arith/arith_rewriter.h"
27 #include "expr/algorithm/flatten.h"
28 #include "smt/logic_exception.h"
29 #include "theory/arith/arith_msum.h"
30 #include "theory/arith/arith_utilities.h"
31 #include "theory/arith/normal_form.h"
32 #include "theory/arith/operator_elim.h"
33 #include "theory/theory.h"
34 #include "util/bitvector.h"
35 #include "util/divisible.h"
36 #include "util/iand.h"
37 #include "util/real_algebraic_number.h"
39 using namespace cvc5::kind
;
48 * Implements an ordering on arithmetic leaf nodes, excluding rationals. As this
49 * comparator is meant to be used on children of Kind::NONLINEAR_MULT, we expect
50 * rationals to be handled separately. Furthermore, we expect there to be only a
51 * single real algebraic number.
52 * It broadly categorizes leaf nodes into real algebraic numbers, integers,
53 * variables, and the rest. The ordering is built as follows:
54 * - real algebraic numbers come first
55 * - real terms come before integer terms
56 * - variables come before non-variable terms
57 * - finally, fall back to node ordering
59 struct LeafNodeComparator
61 /** Implements operator<(a, b) as described above */
62 bool operator()(TNode a
, TNode b
)
64 if (a
== b
) return false;
66 bool aIsRAN
= a
.getKind() == Kind::REAL_ALGEBRAIC_NUMBER
;
67 bool bIsRAN
= b
.getKind() == Kind::REAL_ALGEBRAIC_NUMBER
;
68 if (aIsRAN
!= bIsRAN
) return aIsRAN
;
69 Assert(!aIsRAN
&& !bIsRAN
) << "real algebraic numbers should be combined";
71 bool aIsInt
= a
.getType().isInteger();
72 bool bIsInt
= b
.getType().isInteger();
73 if (aIsInt
!= bIsInt
) return !aIsInt
;
75 bool aIsVar
= a
.isVar();
76 bool bIsVar
= b
.isVar();
77 if (aIsVar
!= bIsVar
) return aIsVar
;
84 * Implements an ordering on arithmetic nonlinear multiplications. As we assume
85 * rationals to be handled separately, we only consider Kind::NONLINEAR_MULT as
86 * multiplication terms. For individual factors of the product, we rely on the
87 * ordering from LeafNodeComparator. Furthermore, we expect products to be
88 * sorted according to LeafNodeComparator. The ordering is built as follows:
89 * - single factors come first (everything that is not NONLINEAR_MULT)
90 * - multiplications with less factors come first
91 * - multiplications are compared lexicographically
93 struct ProductNodeComparator
95 /** Implements operator<(a, b) as described above */
96 bool operator()(TNode a
, TNode b
)
98 if (a
== b
) return false;
100 Assert(a
.getKind() != Kind::MULT
);
101 Assert(b
.getKind() != Kind::MULT
);
103 bool aIsMult
= a
.getKind() == Kind::NONLINEAR_MULT
;
104 bool bIsMult
= b
.getKind() == Kind::NONLINEAR_MULT
;
105 if (aIsMult
!= bIsMult
) return !aIsMult
;
109 return LeafNodeComparator()(a
, b
);
112 size_t aLen
= a
.getNumChildren();
113 size_t bLen
= b
.getNumChildren();
114 if (aLen
!= bLen
) return aLen
< bLen
;
116 for (size_t i
= 0; i
< aLen
; ++i
)
120 return LeafNodeComparator()(a
[i
], b
[i
]);
123 Unreachable() << "Nodes are different, but have the same content";
129 template <typename L
, typename R
>
130 bool evaluateRelation(Kind rel
, const L
& l
, const R
& r
)
134 case Kind::LT
: return l
< r
;
135 case Kind::LEQ
: return l
<= r
;
136 case Kind::EQUAL
: return l
== r
;
137 case Kind::GEQ
: return l
>= r
;
138 case Kind::GT
: return l
> r
;
139 default: Unreachable(); return false;
144 * Check whether the parent has a child that is a constant zero.
145 * If so, return this child. Otherwise, return std::nullopt.
147 template <typename Iterable
>
148 std::optional
<TNode
> getZeroChild(const Iterable
& parent
)
150 for (const auto& node
: parent
)
152 if (node
.isConst() && node
.template getConst
<Rational
>().isZero())
162 ArithRewriter::ArithRewriter(OperatorElim
& oe
) : d_opElim(oe
) {}
164 RewriteResponse
ArithRewriter::preRewrite(TNode t
)
166 Trace("arith-rewriter") << "preRewrite(" << t
<< ")" << std::endl
;
169 auto res
= preRewriteAtom(t
);
170 Trace("arith-rewriter")
171 << res
.d_status
<< " -> " << res
.d_node
<< std::endl
;
174 auto res
= preRewriteTerm(t
);
175 Trace("arith-rewriter") << res
.d_status
<< " -> " << res
.d_node
<< std::endl
;
179 RewriteResponse
ArithRewriter::postRewrite(TNode t
)
181 Trace("arith-rewriter") << "postRewrite(" << t
<< ")" << std::endl
;
184 auto res
= postRewriteAtom(t
);
185 Trace("arith-rewriter")
186 << res
.d_status
<< " -> " << res
.d_node
<< std::endl
;
189 auto res
= postRewriteTerm(t
);
190 Trace("arith-rewriter") << res
.d_status
<< " -> " << res
.d_node
<< std::endl
;
194 RewriteResponse
ArithRewriter::preRewriteAtom(TNode atom
)
196 Assert(isAtom(atom
));
198 NodeManager
* nm
= NodeManager::currentNM();
200 if (isRelationOperator(atom
.getKind()) && atom
[0] == atom
[1])
202 switch (atom
.getKind())
204 case Kind::LT
: return RewriteResponse(REWRITE_DONE
, nm
->mkConst(false));
205 case Kind::LEQ
: return RewriteResponse(REWRITE_DONE
, nm
->mkConst(true));
206 case Kind::EQUAL
: return RewriteResponse(REWRITE_DONE
, nm
->mkConst(true));
207 case Kind::GEQ
: return RewriteResponse(REWRITE_DONE
, nm
->mkConst(true));
208 case Kind::GT
: return RewriteResponse(REWRITE_DONE
, nm
->mkConst(false));
213 switch (atom
.getKind())
216 return RewriteResponse(REWRITE_DONE
,
217 nm
->mkNode(kind::LEQ
, atom
[0], atom
[1]).notNode());
219 return RewriteResponse(REWRITE_DONE
,
220 nm
->mkNode(kind::GEQ
, atom
[0], atom
[1]).notNode());
221 case Kind::IS_INTEGER
:
222 if (atom
[0].getType().isInteger())
224 return RewriteResponse(REWRITE_DONE
, nm
->mkConst(true));
227 case Kind::DIVISIBLE
:
228 if (atom
.getOperator().getConst
<Divisible
>().k
.isOne())
230 return RewriteResponse(REWRITE_DONE
, nm
->mkConst(true));
236 return RewriteResponse(REWRITE_DONE
, atom
);
239 RewriteResponse
ArithRewriter::postRewriteAtom(TNode atom
)
241 Assert(isAtom(atom
));
242 if (atom
.getKind() == kind::IS_INTEGER
)
244 return rewriteExtIntegerOp(atom
);
246 else if (atom
.getKind() == kind::DIVISIBLE
)
248 if (atom
[0].isConst())
250 return RewriteResponse(REWRITE_DONE
,
251 NodeManager::currentNM()->mkConst(bool(
252 (atom
[0].getConst
<Rational
>()
253 / atom
.getOperator().getConst
<Divisible
>().k
)
256 if (atom
.getOperator().getConst
<Divisible
>().k
.isOne())
258 return RewriteResponse(REWRITE_DONE
,
259 NodeManager::currentNM()->mkConst(true));
261 NodeManager
* nm
= NodeManager::currentNM();
262 return RewriteResponse(
264 nm
->mkNode(kind::EQUAL
,
265 nm
->mkNode(kind::INTS_MODULUS_TOTAL
,
267 nm
->mkConstInt(Rational(
268 atom
.getOperator().getConst
<Divisible
>().k
))),
269 nm
->mkConstInt(Rational(0))));
273 TNode left
= atom
[0];
274 TNode right
= atom
[1];
276 auto* nm
= NodeManager::currentNM();
279 const Rational
& l
= left
.getConst
<Rational
>();
282 const Rational
& r
= right
.getConst
<Rational
>();
283 return RewriteResponse(
284 REWRITE_DONE
, nm
->mkConst(evaluateRelation(atom
.getKind(), l
, r
)));
286 else if (right
.getKind() == Kind::REAL_ALGEBRAIC_NUMBER
)
288 const RealAlgebraicNumber
& r
=
289 right
.getOperator().getConst
<RealAlgebraicNumber
>();
290 return RewriteResponse(
291 REWRITE_DONE
, nm
->mkConst(evaluateRelation(atom
.getKind(), l
, r
)));
294 else if (left
.getKind() == Kind::REAL_ALGEBRAIC_NUMBER
)
296 const RealAlgebraicNumber
& l
=
297 left
.getOperator().getConst
<RealAlgebraicNumber
>();
300 const Rational
& r
= right
.getConst
<Rational
>();
301 return RewriteResponse(
302 REWRITE_DONE
, nm
->mkConst(evaluateRelation(atom
.getKind(), l
, r
)));
304 else if (right
.getKind() == Kind::REAL_ALGEBRAIC_NUMBER
)
306 const RealAlgebraicNumber
& r
=
307 right
.getOperator().getConst
<RealAlgebraicNumber
>();
308 return RewriteResponse(
309 REWRITE_DONE
, nm
->mkConst(evaluateRelation(atom
.getKind(), l
, r
)));
313 Polynomial pleft
= Polynomial::parsePolynomial(left
);
314 Polynomial pright
= Polynomial::parsePolynomial(right
);
316 Debug("arith::rewriter") << "pleft " << pleft
.getNode() << std::endl
;
317 Debug("arith::rewriter") << "pright " << pright
.getNode() << std::endl
;
319 Comparison cmp
= Comparison::mkComparison(atom
.getKind(), pleft
, pright
);
320 Assert(cmp
.isNormalForm());
321 return RewriteResponse(REWRITE_DONE
, cmp
.getNode());
324 bool ArithRewriter::isAtom(TNode n
) {
325 Kind k
= n
.getKind();
326 return arith::isRelationOperator(k
) || k
== kind::IS_INTEGER
327 || k
== kind::DIVISIBLE
;
330 RewriteResponse
ArithRewriter::rewriteConstant(TNode t
){
332 Assert(t
.getKind() == CONST_RATIONAL
|| t
.getKind() == CONST_INTEGER
);
334 return RewriteResponse(REWRITE_DONE
, t
);
337 RewriteResponse
ArithRewriter::rewriteRAN(TNode t
)
339 Assert(t
.getKind() == REAL_ALGEBRAIC_NUMBER
);
341 const RealAlgebraicNumber
& r
=
342 t
.getOperator().getConst
<RealAlgebraicNumber
>();
345 return RewriteResponse(
346 REWRITE_DONE
, NodeManager::currentNM()->mkConstReal(r
.toRational()));
349 return RewriteResponse(REWRITE_DONE
, t
);
352 RewriteResponse
ArithRewriter::rewriteVariable(TNode t
){
355 return RewriteResponse(REWRITE_DONE
, t
);
358 RewriteResponse
ArithRewriter::rewriteSub(TNode t
)
360 Assert(t
.getKind() == kind::SUB
);
361 Assert(t
.getNumChildren() == 2);
363 auto* nm
= NodeManager::currentNM();
367 return RewriteResponse(REWRITE_DONE
,
368 nm
->mkConstRealOrInt(t
.getType(), Rational(0)));
370 return RewriteResponse(REWRITE_AGAIN_FULL
,
371 nm
->mkNode(Kind::ADD
, t
[0], makeUnaryMinusNode(t
[1])));
374 RewriteResponse
ArithRewriter::rewriteNeg(TNode t
, bool pre
)
376 Assert(t
.getKind() == kind::NEG
);
380 Rational neg
= -(t
[0].getConst
<Rational
>());
381 NodeManager
* nm
= NodeManager::currentNM();
382 return RewriteResponse(REWRITE_DONE
,
383 nm
->mkConstRealOrInt(t
[0].getType(), neg
));
385 if (t
[0].getKind() == Kind::REAL_ALGEBRAIC_NUMBER
)
387 const RealAlgebraicNumber
& r
=
388 t
[0].getOperator().getConst
<RealAlgebraicNumber
>();
389 NodeManager
* nm
= NodeManager::currentNM();
390 return RewriteResponse(REWRITE_DONE
, nm
->mkRealAlgebraicNumber(-r
));
393 Node noUminus
= makeUnaryMinusNode(t
[0]);
395 return RewriteResponse(REWRITE_DONE
, noUminus
);
397 return RewriteResponse(REWRITE_AGAIN
, noUminus
);
400 RewriteResponse
ArithRewriter::preRewriteTerm(TNode t
){
402 return rewriteConstant(t
);
404 return rewriteVariable(t
);
406 switch(Kind k
= t
.getKind()){
407 case kind::REAL_ALGEBRAIC_NUMBER
: return rewriteRAN(t
);
408 case kind::SUB
: return rewriteSub(t
);
409 case kind::NEG
: return rewriteNeg(t
, true);
411 case kind::DIVISION_TOTAL
: return rewriteDiv(t
, true);
412 case kind::ADD
: return preRewritePlus(t
);
414 case kind::NONLINEAR_MULT
: return preRewriteMult(t
);
415 case kind::IAND
: return RewriteResponse(REWRITE_DONE
, t
);
416 case kind::POW2
: return RewriteResponse(REWRITE_DONE
, t
);
417 case kind::EXPONENTIAL
:
423 case kind::COTANGENT
:
425 case kind::ARCCOSINE
:
426 case kind::ARCTANGENT
:
427 case kind::ARCCOSECANT
:
428 case kind::ARCSECANT
:
429 case kind::ARCCOTANGENT
:
430 case kind::SQRT
: return preRewriteTranscendental(t
);
431 case kind::INTS_DIVISION
:
432 case kind::INTS_MODULUS
: return rewriteIntsDivMod(t
, true);
433 case kind::INTS_DIVISION_TOTAL
:
434 case kind::INTS_MODULUS_TOTAL
: return rewriteIntsDivModTotal(t
, true);
435 case kind::ABS
: return rewriteAbs(t
);
436 case kind::IS_INTEGER
:
437 case kind::TO_INTEGER
: return RewriteResponse(REWRITE_DONE
, t
);
439 case kind::CAST_TO_REAL
: return RewriteResponse(REWRITE_DONE
, t
[0]);
440 case kind::POW
: return RewriteResponse(REWRITE_DONE
, t
);
441 case kind::PI
: return RewriteResponse(REWRITE_DONE
, t
);
442 default: Unhandled() << k
;
447 RewriteResponse
ArithRewriter::postRewriteTerm(TNode t
){
449 return rewriteConstant(t
);
451 return rewriteVariable(t
);
453 Trace("arith-rewriter") << "postRewriteTerm: " << t
<< std::endl
;
455 case kind::REAL_ALGEBRAIC_NUMBER
: return rewriteRAN(t
);
456 case kind::SUB
: return rewriteSub(t
);
457 case kind::NEG
: return rewriteNeg(t
, false);
459 case kind::DIVISION_TOTAL
: return rewriteDiv(t
, false);
460 case kind::ADD
: return postRewritePlus(t
);
462 case kind::NONLINEAR_MULT
: return postRewriteMult(t
);
463 case kind::IAND
: return postRewriteIAnd(t
);
464 case kind::POW2
: return postRewritePow2(t
);
465 case kind::EXPONENTIAL
:
471 case kind::COTANGENT
:
473 case kind::ARCCOSINE
:
474 case kind::ARCTANGENT
:
475 case kind::ARCCOSECANT
:
476 case kind::ARCSECANT
:
477 case kind::ARCCOTANGENT
:
478 case kind::SQRT
: return postRewriteTranscendental(t
);
479 case kind::INTS_DIVISION
:
480 case kind::INTS_MODULUS
: return rewriteIntsDivMod(t
, false);
481 case kind::INTS_DIVISION_TOTAL
:
482 case kind::INTS_MODULUS_TOTAL
: return rewriteIntsDivModTotal(t
, false);
483 case kind::ABS
: return rewriteAbs(t
);
485 case kind::CAST_TO_REAL
: return RewriteResponse(REWRITE_DONE
, t
[0]);
486 case kind::TO_INTEGER
: return rewriteExtIntegerOp(t
);
491 const Rational
& exp
= t
[1].getConst
<Rational
>();
494 return RewriteResponse(REWRITE_DONE
,
495 NodeManager::currentNM()->mkConstRealOrInt(
496 t
.getType(), Rational(1)));
497 }else if(exp
.sgn() > 0 && exp
.isIntegral()){
498 cvc5::Rational
r(expr::NodeValue::MAX_CHILDREN
);
501 unsigned num
= exp
.getNumerator().toUnsignedInt();
503 return RewriteResponse(REWRITE_AGAIN
, base
);
505 NodeBuilder
nb(kind::MULT
);
506 for(unsigned i
=0; i
< num
; ++i
){
509 Assert(nb
.getNumChildren() > 0);
511 return RewriteResponse(REWRITE_AGAIN
, mult
);
516 else if (t
[0].isConst()
517 && t
[0].getConst
<Rational
>().getNumerator().toUnsignedInt()
520 return RewriteResponse(
521 REWRITE_DONE
, NodeManager::currentNM()->mkNode(kind::POW2
, t
[1]));
524 // Todo improve the exception thrown
525 std::stringstream ss
;
526 ss
<< "The exponent of the POW(^) operator can only be a positive "
527 "integral constant below "
528 << (expr::NodeValue::MAX_CHILDREN
+ 1) << ". ";
529 ss
<< "Exception occurred in:" << std::endl
;
531 throw LogicException(ss
.str());
534 return RewriteResponse(REWRITE_DONE
, t
);
542 RewriteResponse
ArithRewriter::preRewritePlus(TNode t
){
543 Assert(t
.getKind() == kind::ADD
);
544 return RewriteResponse(REWRITE_DONE
, expr::algorithm::flatten(t
));
547 RewriteResponse
ArithRewriter::postRewritePlus(TNode t
){
548 Assert(t
.getKind() == kind::ADD
);
549 Assert(t
.getNumChildren() > 1);
552 Node flat
= expr::algorithm::flatten(t
);
555 return RewriteResponse(REWRITE_AGAIN
, flat
);
560 RealAlgebraicNumber ran
;
561 std::vector
<Monomial
> monomials
;
562 std::vector
<Polynomial
> polynomials
;
564 for (const auto& child
: t
)
568 if (child
.getConst
<Rational
>().isZero())
572 rational
+= child
.getConst
<Rational
>();
574 else if (child
.getKind() == Kind::REAL_ALGEBRAIC_NUMBER
)
576 ran
+= child
.getOperator().getConst
<RealAlgebraicNumber
>();
578 else if (Monomial::isMember(child
))
580 monomials
.emplace_back(Monomial::parseMonomial(child
));
584 polynomials
.emplace_back(Polynomial::parsePolynomial(child
));
588 if(!monomials
.empty()){
589 Monomial::sort(monomials
);
590 Monomial::combineAdjacentMonomials(monomials
);
591 polynomials
.emplace_back(Polynomial::mkPolynomial(monomials
));
593 if (!rational
.isZero())
595 polynomials
.emplace_back(
596 Polynomial::mkPolynomial(Constant::mkConstant(rational
)));
599 Polynomial poly
= Polynomial::sumPolynomials(polynomials
);
603 return RewriteResponse(REWRITE_DONE
, poly
.getNode());
605 if (poly
.containsConstant())
607 ran
+= RealAlgebraicNumber(poly
.getHead().getConstant().getValue());
608 if (!poly
.isConstant())
610 poly
= poly
.getTail();
614 auto* nm
= NodeManager::currentNM();
615 if (poly
.isConstant())
617 return RewriteResponse(REWRITE_DONE
, nm
->mkRealAlgebraicNumber(ran
));
619 return RewriteResponse(
621 nm
->mkNode(Kind::ADD
, nm
->mkRealAlgebraicNumber(ran
), poly
.getNode()));
624 RewriteResponse
ArithRewriter::preRewriteMult(TNode node
)
626 Assert(node
.getKind() == kind::MULT
627 || node
.getKind() == kind::NONLINEAR_MULT
);
629 auto res
= getZeroChild(node
);
632 return RewriteResponse(REWRITE_DONE
, *res
);
634 return RewriteResponse(REWRITE_DONE
, node
);
637 RewriteResponse
ArithRewriter::postRewriteMult(TNode t
){
638 Assert(t
.getKind() == kind::MULT
|| t
.getKind() == kind::NONLINEAR_MULT
);
639 Assert(t
.getNumChildren() >= 2);
641 if (auto res
= getZeroChild(t
); res
)
643 return RewriteResponse(REWRITE_DONE
, *res
);
646 Rational rational
= Rational(1);
647 RealAlgebraicNumber ran
= RealAlgebraicNumber(Integer(1));
648 Polynomial poly
= Polynomial::mkOne();
650 for (const auto& child
: t
)
654 if (child
.getConst
<Rational
>().isZero())
656 return RewriteResponse(REWRITE_DONE
, child
);
658 rational
*= child
.getConst
<Rational
>();
660 else if (child
.getKind() == Kind::REAL_ALGEBRAIC_NUMBER
)
662 ran
*= child
.getOperator().getConst
<RealAlgebraicNumber
>();
666 poly
= poly
* Polynomial::parsePolynomial(child
);
670 if (!rational
.isOne())
672 poly
= poly
* rational
;
676 return RewriteResponse(REWRITE_DONE
, poly
.getNode());
678 auto* nm
= NodeManager::currentNM();
679 if (poly
.isConstant())
681 ran
*= RealAlgebraicNumber(poly
.getHead().getConstant().getValue());
682 return RewriteResponse(REWRITE_DONE
, nm
->mkRealAlgebraicNumber(ran
));
684 return RewriteResponse(
687 Kind::MULT
, nm
->mkRealAlgebraicNumber(ran
), poly
.getNode()));
690 RewriteResponse
ArithRewriter::postRewritePow2(TNode t
)
692 Assert(t
.getKind() == kind::POW2
);
693 NodeManager
* nm
= NodeManager::currentNM();
694 // if constant, we eliminate
697 // pow2 is only supported for integers
698 Assert(t
[0].getType().isInteger());
699 Integer i
= t
[0].getConst
<Rational
>().getNumerator();
702 return RewriteResponse(REWRITE_DONE
, nm
->mkConstInt(Rational(0)));
704 // (pow2 t) ---> (pow 2 t) and continue rewriting to eliminate pow
705 Node two
= nm
->mkConstInt(Rational(Integer(2)));
706 Node ret
= nm
->mkNode(kind::POW
, two
, t
[0]);
707 return RewriteResponse(REWRITE_AGAIN
, ret
);
709 return RewriteResponse(REWRITE_DONE
, t
);
712 RewriteResponse
ArithRewriter::postRewriteIAnd(TNode t
)
714 Assert(t
.getKind() == kind::IAND
);
715 size_t bsize
= t
.getOperator().getConst
<IntAnd
>().d_size
;
716 NodeManager
* nm
= NodeManager::currentNM();
717 // if constant, we eliminate
718 if (t
[0].isConst() && t
[1].isConst())
720 Node iToBvop
= nm
->mkConst(IntToBitVector(bsize
));
721 Node arg1
= nm
->mkNode(kind::INT_TO_BITVECTOR
, iToBvop
, t
[0]);
722 Node arg2
= nm
->mkNode(kind::INT_TO_BITVECTOR
, iToBvop
, t
[1]);
723 Node bvand
= nm
->mkNode(kind::BITVECTOR_AND
, arg1
, arg2
);
724 Node ret
= nm
->mkNode(kind::BITVECTOR_TO_NAT
, bvand
);
725 return RewriteResponse(REWRITE_AGAIN_FULL
, ret
);
727 else if (t
[0] > t
[1])
729 // ((_ iand k) x y) ---> ((_ iand k) y x) if x > y by node ordering
730 Node ret
= nm
->mkNode(kind::IAND
, t
.getOperator(), t
[1], t
[0]);
731 return RewriteResponse(REWRITE_AGAIN
, ret
);
733 else if (t
[0] == t
[1])
735 // ((_ iand k) x x) ---> (mod x 2^k)
736 Node twok
= nm
->mkConstInt(Rational(Integer(2).pow(bsize
)));
737 Node ret
= nm
->mkNode(kind::INTS_MODULUS
, t
[0], twok
);
738 return RewriteResponse(REWRITE_AGAIN
, ret
);
740 // simplifications involving constants
741 for (unsigned i
= 0; i
< 2; i
++)
747 if (t
[i
].getConst
<Rational
>().sgn() == 0)
749 // ((_ iand k) 0 y) ---> 0
750 return RewriteResponse(REWRITE_DONE
, t
[i
]);
752 if (t
[i
].getConst
<Rational
>().getNumerator() == Integer(2).pow(bsize
) - 1)
754 // ((_ iand k) 111...1 y) ---> (mod y 2^k)
755 Node twok
= nm
->mkConstInt(Rational(Integer(2).pow(bsize
)));
756 Node ret
= nm
->mkNode(kind::INTS_MODULUS
, t
[1-i
], twok
);
757 return RewriteResponse(REWRITE_AGAIN
, ret
);
760 return RewriteResponse(REWRITE_DONE
, t
);
763 RewriteResponse
ArithRewriter::preRewriteTranscendental(TNode t
) {
764 return RewriteResponse(REWRITE_DONE
, t
);
767 RewriteResponse
ArithRewriter::postRewriteTranscendental(TNode t
) {
768 Trace("arith-tf-rewrite") << "Rewrite transcendental function : " << t
<< std::endl
;
769 NodeManager
* nm
= NodeManager::currentNM();
770 switch( t
.getKind() ){
771 case kind::EXPONENTIAL
: {
774 Node one
= nm
->mkConstReal(Rational(1));
775 if(t
[0].getConst
<Rational
>().sgn()>=0 && t
[0].getType().isInteger() && t
[0]!=one
){
776 return RewriteResponse(
778 nm
->mkNode(kind::POW
, nm
->mkNode(kind::EXPONENTIAL
, one
), t
[0]));
780 return RewriteResponse(REWRITE_DONE
, t
);
783 else if (t
[0].getKind() == kind::ADD
)
785 std::vector
<Node
> product
;
786 for (const Node tc
: t
[0])
788 product
.push_back(nm
->mkNode(kind::EXPONENTIAL
, tc
));
790 // We need to do a full rewrite here, since we can get exponentials of
791 // constants, e.g. when we are rewriting exp(2 + x)
792 return RewriteResponse(REWRITE_AGAIN_FULL
,
793 nm
->mkNode(kind::MULT
, product
));
800 const Rational
& rat
= t
[0].getConst
<Rational
>();
802 return RewriteResponse(REWRITE_DONE
, nm
->mkConstReal(Rational(0)));
804 else if (rat
.sgn() == -1)
806 Node ret
= nm
->mkNode(kind::NEG
,
807 nm
->mkNode(kind::SINE
, nm
->mkConstReal(-rat
)));
808 return RewriteResponse(REWRITE_AGAIN_FULL
, ret
);
811 else if ((t
[0].getKind() == MULT
|| t
[0].getKind() == NONLINEAR_MULT
)
812 && t
[0][0].isConst() && t
[0][0].getConst
<Rational
>().sgn() == -1)
814 // sin(-n*x) ---> -sin(n*x)
815 std::vector
<Node
> mchildren(t
[0].begin(), t
[0].end());
816 mchildren
[0] = nm
->mkConstReal(-t
[0][0].getConst
<Rational
>());
817 Node ret
= nm
->mkNode(
819 nm
->mkNode(kind::SINE
, nm
->mkNode(t
[0].getKind(), mchildren
)));
820 return RewriteResponse(REWRITE_AGAIN_FULL
, ret
);
824 // get the factor of PI in the argument
828 std::map
<Node
, Node
> msum
;
829 if (ArithMSum::getMonomialSum(t
[0], msum
))
832 std::map
<Node
, Node
>::iterator itm
= msum
.find(pi
);
833 if (itm
!= msum
.end())
835 if (itm
->second
.isNull())
837 pi_factor
= nm
->mkConstReal(Rational(1));
841 pi_factor
= itm
->second
;
846 rem
= ArithMSum::mkNode(t
[0].getType(), msum
);
855 // if there is a factor of PI
856 if( !pi_factor
.isNull() ){
857 Trace("arith-tf-rewrite-debug") << "Process pi factor = " << pi_factor
<< std::endl
;
858 Rational r
= pi_factor
.getConst
<Rational
>();
859 Rational r_abs
= r
.abs();
860 Rational rone
= Rational(1);
861 Rational rtwo
= Rational(2);
864 //add/substract 2*pi beyond scope
865 Rational ra_div_two
= (r_abs
+ rone
) / rtwo
;
869 new_pi_factor
= nm
->mkConstReal(r
- rtwo
* ra_div_two
.floor());
873 Assert(r
.sgn() == -1);
874 new_pi_factor
= nm
->mkConstReal(r
+ rtwo
* ra_div_two
.floor());
876 Node new_arg
= nm
->mkNode(kind::MULT
, new_pi_factor
, pi
);
879 new_arg
= nm
->mkNode(kind::ADD
, new_arg
, rem
);
881 // sin( 2*n*PI + x ) = sin( x )
882 return RewriteResponse(REWRITE_AGAIN_FULL
,
883 nm
->mkNode(kind::SINE
, new_arg
));
885 else if (r_abs
== rone
)
887 // sin( PI + x ) = -sin( x )
890 return RewriteResponse(REWRITE_DONE
, nm
->mkConstReal(Rational(0)));
894 return RewriteResponse(
896 nm
->mkNode(kind::NEG
, nm
->mkNode(kind::SINE
, rem
)));
899 else if (rem
.isNull())
901 // other rational cases based on Niven's theorem
902 // (https://en.wikipedia.org/wiki/Niven%27s_theorem)
903 Integer one
= Integer(1);
904 Integer two
= Integer(2);
905 Integer six
= Integer(6);
906 if (r_abs
.getDenominator() == two
)
908 Assert(r_abs
.getNumerator() == one
);
909 return RewriteResponse(REWRITE_DONE
,
910 nm
->mkConstReal(Rational(r
.sgn())));
912 else if (r_abs
.getDenominator() == six
)
914 Integer five
= Integer(5);
915 if (r_abs
.getNumerator() == one
|| r_abs
.getNumerator() == five
)
917 return RewriteResponse(
919 nm
->mkConstReal(Rational(r
.sgn()) / Rational(2)));
927 return RewriteResponse(
931 nm
->mkNode(kind::SUB
,
932 nm
->mkNode(kind::MULT
,
933 nm
->mkConstReal(Rational(1) / Rational(2)),
940 return RewriteResponse(REWRITE_AGAIN_FULL
,
941 nm
->mkNode(kind::DIVISION
,
942 nm
->mkNode(kind::SINE
, t
[0]),
943 nm
->mkNode(kind::COSINE
, t
[0])));
948 return RewriteResponse(REWRITE_AGAIN_FULL
,
949 nm
->mkNode(kind::DIVISION
,
950 nm
->mkConstReal(Rational(1)),
951 nm
->mkNode(kind::SINE
, t
[0])));
956 return RewriteResponse(REWRITE_AGAIN_FULL
,
957 nm
->mkNode(kind::DIVISION
,
958 nm
->mkConstReal(Rational(1)),
959 nm
->mkNode(kind::COSINE
, t
[0])));
962 case kind::COTANGENT
:
964 return RewriteResponse(REWRITE_AGAIN_FULL
,
965 nm
->mkNode(kind::DIVISION
,
966 nm
->mkNode(kind::COSINE
, t
[0]),
967 nm
->mkNode(kind::SINE
, t
[0])));
973 return RewriteResponse(REWRITE_DONE
, t
);
976 Node
ArithRewriter::makeUnaryMinusNode(TNode n
){
977 NodeManager
* nm
= NodeManager::currentNM();
978 Rational
qNegOne(-1);
979 return nm
->mkNode(kind::MULT
, nm
->mkConstRealOrInt(n
.getType(), qNegOne
), n
);
982 RewriteResponse
ArithRewriter::rewriteDiv(TNode t
, bool pre
){
983 Assert(t
.getKind() == kind::DIVISION_TOTAL
|| t
.getKind() == kind::DIVISION
);
984 Assert(t
.getNumChildren() == 2);
990 NodeManager
* nm
= NodeManager::currentNM();
991 const Rational
& den
= right
.getConst
<Rational
>();
994 if(t
.getKind() == kind::DIVISION_TOTAL
){
995 return RewriteResponse(REWRITE_DONE
, nm
->mkConstReal(0));
997 // This is unsupported, but this is not a good place to complain
998 return RewriteResponse(REWRITE_DONE
, t
);
1001 Assert(den
!= Rational(0));
1005 const Rational
& num
= left
.getConst
<Rational
>();
1006 return RewriteResponse(REWRITE_DONE
, nm
->mkConstReal(num
/ den
));
1008 if (left
.getKind() == Kind::REAL_ALGEBRAIC_NUMBER
)
1010 const RealAlgebraicNumber
& num
=
1011 left
.getOperator().getConst
<RealAlgebraicNumber
>();
1012 return RewriteResponse(
1014 nm
->mkRealAlgebraicNumber(num
/ RealAlgebraicNumber(den
)));
1017 Node result
= nm
->mkConstReal(den
.inverse());
1018 Node mult
= NodeManager::currentNM()->mkNode(kind::MULT
, left
, result
);
1021 return RewriteResponse(REWRITE_DONE
, mult
);
1025 return RewriteResponse(REWRITE_AGAIN
, mult
);
1028 if (right
.getKind() == Kind::REAL_ALGEBRAIC_NUMBER
)
1030 NodeManager
* nm
= NodeManager::currentNM();
1031 const RealAlgebraicNumber
& den
=
1032 right
.getOperator().getConst
<RealAlgebraicNumber
>();
1035 const Rational
& num
= left
.getConst
<Rational
>();
1036 return RewriteResponse(
1038 nm
->mkRealAlgebraicNumber(RealAlgebraicNumber(num
) / den
));
1040 if (left
.getKind() == Kind::REAL_ALGEBRAIC_NUMBER
)
1042 const RealAlgebraicNumber
& num
=
1043 left
.getOperator().getConst
<RealAlgebraicNumber
>();
1044 return RewriteResponse(REWRITE_DONE
,
1045 nm
->mkRealAlgebraicNumber(num
/ den
));
1048 Node result
= nm
->mkRealAlgebraicNumber(inverse(den
));
1049 Node mult
= NodeManager::currentNM()->mkNode(kind::MULT
,left
,result
);
1051 return RewriteResponse(REWRITE_DONE
, mult
);
1053 return RewriteResponse(REWRITE_AGAIN
, mult
);
1056 return RewriteResponse(REWRITE_DONE
, t
);
1059 RewriteResponse
ArithRewriter::rewriteAbs(TNode t
)
1061 Assert(t
.getKind() == Kind::ABS
);
1062 Assert(t
.getNumChildren() == 1);
1066 const Rational
& rat
= t
[0].getConst
<Rational
>();
1069 return RewriteResponse(REWRITE_DONE
, t
[0]);
1071 return RewriteResponse(
1073 NodeManager::currentNM()->mkConstRealOrInt(t
[0].getType(), -rat
));
1075 if (t
[0].getKind() == Kind::REAL_ALGEBRAIC_NUMBER
)
1077 const RealAlgebraicNumber
& ran
=
1078 t
[0].getOperator().getConst
<RealAlgebraicNumber
>();
1079 if (ran
>= RealAlgebraicNumber())
1081 return RewriteResponse(REWRITE_DONE
, t
[0]);
1083 return RewriteResponse(
1084 REWRITE_DONE
, NodeManager::currentNM()->mkRealAlgebraicNumber(-ran
));
1086 return RewriteResponse(REWRITE_DONE
, t
);
1089 RewriteResponse
ArithRewriter::rewriteIntsDivMod(TNode t
, bool pre
)
1091 NodeManager
* nm
= NodeManager::currentNM();
1092 Kind k
= t
.getKind();
1093 if (k
== kind::INTS_MODULUS
)
1095 if (t
[1].isConst() && !t
[1].getConst
<Rational
>().isZero())
1097 // can immediately replace by INTS_MODULUS_TOTAL
1098 Node ret
= nm
->mkNode(kind::INTS_MODULUS_TOTAL
, t
[0], t
[1]);
1099 return returnRewrite(t
, ret
, Rewrite::MOD_TOTAL_BY_CONST
);
1102 if (k
== kind::INTS_DIVISION
)
1104 if (t
[1].isConst() && !t
[1].getConst
<Rational
>().isZero())
1106 // can immediately replace by INTS_DIVISION_TOTAL
1107 Node ret
= nm
->mkNode(kind::INTS_DIVISION_TOTAL
, t
[0], t
[1]);
1108 return returnRewrite(t
, ret
, Rewrite::DIV_TOTAL_BY_CONST
);
1111 return RewriteResponse(REWRITE_DONE
, t
);
1114 RewriteResponse
ArithRewriter::rewriteExtIntegerOp(TNode t
)
1116 Assert(t
.getKind() == kind::TO_INTEGER
|| t
.getKind() == kind::IS_INTEGER
);
1117 bool isPred
= t
.getKind() == kind::IS_INTEGER
;
1118 NodeManager
* nm
= NodeManager::currentNM();
1124 ret
= nm
->mkConst(t
[0].getConst
<Rational
>().isIntegral());
1128 ret
= nm
->mkConstInt(Rational(t
[0].getConst
<Rational
>().floor()));
1130 return returnRewrite(t
, ret
, Rewrite::INT_EXT_CONST
);
1132 if (t
[0].getType().isInteger())
1134 Node ret
= isPred
? nm
->mkConst(true) : Node(t
[0]);
1135 return returnRewrite(t
, ret
, Rewrite::INT_EXT_INT
);
1137 if (t
[0].getKind() == kind::PI
)
1139 Node ret
= isPred
? nm
->mkConst(false) : nm
->mkConstReal(Rational(3));
1140 return returnRewrite(t
, ret
, Rewrite::INT_EXT_PI
);
1142 return RewriteResponse(REWRITE_DONE
, t
);
1145 RewriteResponse
ArithRewriter::rewriteIntsDivModTotal(TNode t
, bool pre
)
1149 // do not rewrite at prewrite.
1150 return RewriteResponse(REWRITE_DONE
, t
);
1152 NodeManager
* nm
= NodeManager::currentNM();
1153 Kind k
= t
.getKind();
1154 Assert(k
== kind::INTS_MODULUS_TOTAL
|| k
== kind::INTS_DIVISION_TOTAL
);
1157 bool dIsConstant
= d
.isConst();
1158 if(dIsConstant
&& d
.getConst
<Rational
>().isZero()){
1159 // (div x 0) ---> 0 or (mod x 0) ---> 0
1160 return returnRewrite(t
, nm
->mkConstInt(0), Rewrite::DIV_MOD_BY_ZERO
);
1161 }else if(dIsConstant
&& d
.getConst
<Rational
>().isOne()){
1162 if (k
== kind::INTS_MODULUS_TOTAL
)
1165 return returnRewrite(t
, nm
->mkConstInt(0), Rewrite::MOD_BY_ONE
);
1167 Assert(k
== kind::INTS_DIVISION_TOTAL
);
1169 return returnRewrite(t
, n
, Rewrite::DIV_BY_ONE
);
1171 else if (dIsConstant
&& d
.getConst
<Rational
>().sgn() < 0)
1174 // (div x (- c)) ---> (- (div x c))
1175 // (mod x (- c)) ---> (mod x c)
1176 Node nn
= nm
->mkNode(k
, t
[0], nm
->mkConstInt(-t
[1].getConst
<Rational
>()));
1177 Node ret
= (k
== kind::INTS_DIVISION
|| k
== kind::INTS_DIVISION_TOTAL
)
1178 ? nm
->mkNode(kind::NEG
, nn
)
1180 return returnRewrite(t
, ret
, Rewrite::DIV_MOD_PULL_NEG_DEN
);
1182 else if (dIsConstant
&& n
.isConst())
1184 Assert(d
.getConst
<Rational
>().isIntegral());
1185 Assert(n
.getConst
<Rational
>().isIntegral());
1186 Assert(!d
.getConst
<Rational
>().isZero());
1187 Integer di
= d
.getConst
<Rational
>().getNumerator();
1188 Integer ni
= n
.getConst
<Rational
>().getNumerator();
1190 bool isDiv
= (k
== kind::INTS_DIVISION
|| k
== kind::INTS_DIVISION_TOTAL
);
1192 Integer result
= isDiv
? ni
.euclidianDivideQuotient(di
) : ni
.euclidianDivideRemainder(di
);
1194 // constant evaluation
1195 // (mod c1 c2) ---> c3 or (div c1 c2) ---> c3
1196 Node resultNode
= nm
->mkConstInt(Rational(result
));
1197 return returnRewrite(t
, resultNode
, Rewrite::CONST_EVAL
);
1199 if (k
== kind::INTS_MODULUS_TOTAL
)
1201 // Note these rewrites do not need to account for modulus by zero as being
1202 // a UF, which is handled by the reduction of INTS_MODULUS.
1203 Kind k0
= t
[0].getKind();
1204 if (k0
== kind::INTS_MODULUS_TOTAL
&& t
[0][1] == t
[1])
1206 // (mod (mod x c) c) --> (mod x c)
1207 return returnRewrite(t
, t
[0], Rewrite::MOD_OVER_MOD
);
1209 else if (k0
== kind::NONLINEAR_MULT
|| k0
== kind::MULT
|| k0
== kind::ADD
)
1212 std::vector
<Node
> newChildren
;
1213 bool childChanged
= false;
1214 for (const Node
& tc
: t
[0])
1216 if (tc
.getKind() == kind::INTS_MODULUS_TOTAL
&& tc
[1] == t
[1])
1218 newChildren
.push_back(tc
[0]);
1219 childChanged
= true;
1222 newChildren
.push_back(tc
);
1226 // (mod (op ... (mod x c) ...) c) ---> (mod (op ... x ...) c) where
1227 // op is one of { NONLINEAR_MULT, MULT, ADD }.
1228 Node ret
= nm
->mkNode(k0
, newChildren
);
1229 ret
= nm
->mkNode(kind::INTS_MODULUS_TOTAL
, ret
, t
[1]);
1230 return returnRewrite(t
, ret
, Rewrite::MOD_CHILD_MOD
);
1236 Assert(k
== kind::INTS_DIVISION_TOTAL
);
1237 // Note these rewrites do not need to account for division by zero as being
1238 // a UF, which is handled by the reduction of INTS_DIVISION.
1239 if (t
[0].getKind() == kind::INTS_MODULUS_TOTAL
&& t
[0][1] == t
[1])
1241 // (div (mod x c) c) --> 0
1242 Node ret
= nm
->mkConstInt(0);
1243 return returnRewrite(t
, ret
, Rewrite::DIV_OVER_MOD
);
1246 return RewriteResponse(REWRITE_DONE
, t
);
1249 TrustNode
ArithRewriter::expandDefinition(Node node
)
1251 // call eliminate operators, to eliminate partial operators only
1252 std::vector
<SkolemLemma
> lems
;
1253 TrustNode ret
= d_opElim
.eliminate(node
, lems
, true);
1254 Assert(lems
.empty());
1258 RewriteResponse
ArithRewriter::returnRewrite(TNode t
, Node ret
, Rewrite r
)
1260 Trace("arith-rewrite") << "ArithRewriter : " << t
<< " == " << ret
<< " by "
1262 return RewriteResponse(REWRITE_AGAIN_FULL
, ret
);
1265 } // namespace arith
1266 } // namespace theory