Fix another rewrite involving iand (#8054)
[cvc5.git] / src / theory / arith / arith_rewriter.cpp
1 /******************************************************************************
2 * Top contributors (to current version):
3 * Andrew Reynolds, Tim King, Morgan Deters
4 *
5 * This file is part of the cvc5 project.
6 *
7 * Copyright (c) 2009-2021 by the authors listed in the file AUTHORS
8 * in the top-level source directory and their institutional affiliations.
9 * All rights reserved. See the file COPYING in the top-level source
10 * directory for licensing information.
11 * ****************************************************************************
12 *
13 * [[ Add one-line brief description here ]]
14 *
15 * [[ Add lengthier description here ]]
16 * \todo document this file
17 */
18
19 #include "theory/arith/arith_rewriter.h"
20
21 #include <optional>
22 #include <set>
23 #include <sstream>
24 #include <stack>
25 #include <vector>
26
27 #include "expr/algorithm/flatten.h"
28 #include "smt/logic_exception.h"
29 #include "theory/arith/arith_msum.h"
30 #include "theory/arith/arith_utilities.h"
31 #include "theory/arith/normal_form.h"
32 #include "theory/arith/operator_elim.h"
33 #include "theory/theory.h"
34 #include "util/bitvector.h"
35 #include "util/divisible.h"
36 #include "util/iand.h"
37 #include "util/real_algebraic_number.h"
38
39 using namespace cvc5::kind;
40
41 namespace cvc5 {
42 namespace theory {
43 namespace arith {
44
45 namespace {
46
47 /**
48 * Implements an ordering on arithmetic leaf nodes, excluding rationals. As this
49 * comparator is meant to be used on children of Kind::NONLINEAR_MULT, we expect
50 * rationals to be handled separately. Furthermore, we expect there to be only a
51 * single real algebraic number.
52 * It broadly categorizes leaf nodes into real algebraic numbers, integers,
53 * variables, and the rest. The ordering is built as follows:
54 * - real algebraic numbers come first
55 * - real terms come before integer terms
56 * - variables come before non-variable terms
57 * - finally, fall back to node ordering
58 */
59 struct LeafNodeComparator
60 {
61 /** Implements operator<(a, b) as described above */
62 bool operator()(TNode a, TNode b)
63 {
64 if (a == b) return false;
65
66 bool aIsRAN = a.getKind() == Kind::REAL_ALGEBRAIC_NUMBER;
67 bool bIsRAN = b.getKind() == Kind::REAL_ALGEBRAIC_NUMBER;
68 if (aIsRAN != bIsRAN) return aIsRAN;
69 Assert(!aIsRAN && !bIsRAN) << "real algebraic numbers should be combined";
70
71 bool aIsInt = a.getType().isInteger();
72 bool bIsInt = b.getType().isInteger();
73 if (aIsInt != bIsInt) return !aIsInt;
74
75 bool aIsVar = a.isVar();
76 bool bIsVar = b.isVar();
77 if (aIsVar != bIsVar) return aIsVar;
78
79 return a < b;
80 }
81 };
82
83 /**
84 * Implements an ordering on arithmetic nonlinear multiplications. As we assume
85 * rationals to be handled separately, we only consider Kind::NONLINEAR_MULT as
86 * multiplication terms. For individual factors of the product, we rely on the
87 * ordering from LeafNodeComparator. Furthermore, we expect products to be
88 * sorted according to LeafNodeComparator. The ordering is built as follows:
89 * - single factors come first (everything that is not NONLINEAR_MULT)
90 * - multiplications with less factors come first
91 * - multiplications are compared lexicographically
92 */
93 struct ProductNodeComparator
94 {
95 /** Implements operator<(a, b) as described above */
96 bool operator()(TNode a, TNode b)
97 {
98 if (a == b) return false;
99
100 Assert(a.getKind() != Kind::MULT);
101 Assert(b.getKind() != Kind::MULT);
102
103 bool aIsMult = a.getKind() == Kind::NONLINEAR_MULT;
104 bool bIsMult = b.getKind() == Kind::NONLINEAR_MULT;
105 if (aIsMult != bIsMult) return !aIsMult;
106
107 if (!aIsMult)
108 {
109 return LeafNodeComparator()(a, b);
110 }
111
112 size_t aLen = a.getNumChildren();
113 size_t bLen = b.getNumChildren();
114 if (aLen != bLen) return aLen < bLen;
115
116 for (size_t i = 0; i < aLen; ++i)
117 {
118 if (a[i] != b[i])
119 {
120 return LeafNodeComparator()(a[i], b[i]);
121 }
122 }
123 Unreachable() << "Nodes are different, but have the same content";
124 return false;
125 }
126 };
127
128
129 template <typename L, typename R>
130 bool evaluateRelation(Kind rel, const L& l, const R& r)
131 {
132 switch (rel)
133 {
134 case Kind::LT: return l < r;
135 case Kind::LEQ: return l <= r;
136 case Kind::EQUAL: return l == r;
137 case Kind::GEQ: return l >= r;
138 case Kind::GT: return l > r;
139 default: Unreachable(); return false;
140 }
141 }
142
143 /**
144 * Check whether the parent has a child that is a constant zero.
145 * If so, return this child. Otherwise, return std::nullopt.
146 */
147 template <typename Iterable>
148 std::optional<TNode> getZeroChild(const Iterable& parent)
149 {
150 for (const auto& node : parent)
151 {
152 if (node.isConst() && node.template getConst<Rational>().isZero())
153 {
154 return node;
155 }
156 }
157 return std::nullopt;
158 }
159
160 } // namespace
161
162 ArithRewriter::ArithRewriter(OperatorElim& oe) : d_opElim(oe) {}
163
164 RewriteResponse ArithRewriter::preRewrite(TNode t)
165 {
166 Trace("arith-rewriter") << "preRewrite(" << t << ")" << std::endl;
167 if (isAtom(t))
168 {
169 auto res = preRewriteAtom(t);
170 Trace("arith-rewriter")
171 << res.d_status << " -> " << res.d_node << std::endl;
172 return res;
173 }
174 auto res = preRewriteTerm(t);
175 Trace("arith-rewriter") << res.d_status << " -> " << res.d_node << std::endl;
176 return res;
177 }
178
179 RewriteResponse ArithRewriter::postRewrite(TNode t)
180 {
181 Trace("arith-rewriter") << "postRewrite(" << t << ")" << std::endl;
182 if (isAtom(t))
183 {
184 auto res = postRewriteAtom(t);
185 Trace("arith-rewriter")
186 << res.d_status << " -> " << res.d_node << std::endl;
187 return res;
188 }
189 auto res = postRewriteTerm(t);
190 Trace("arith-rewriter") << res.d_status << " -> " << res.d_node << std::endl;
191 return res;
192 }
193
194 RewriteResponse ArithRewriter::preRewriteAtom(TNode atom)
195 {
196 Assert(isAtom(atom));
197
198 NodeManager* nm = NodeManager::currentNM();
199
200 if (isRelationOperator(atom.getKind()) && atom[0] == atom[1])
201 {
202 switch (atom.getKind())
203 {
204 case Kind::LT: return RewriteResponse(REWRITE_DONE, nm->mkConst(false));
205 case Kind::LEQ: return RewriteResponse(REWRITE_DONE, nm->mkConst(true));
206 case Kind::EQUAL: return RewriteResponse(REWRITE_DONE, nm->mkConst(true));
207 case Kind::GEQ: return RewriteResponse(REWRITE_DONE, nm->mkConst(true));
208 case Kind::GT: return RewriteResponse(REWRITE_DONE, nm->mkConst(false));
209 default:;
210 }
211 }
212
213 switch (atom.getKind())
214 {
215 case Kind::GT:
216 return RewriteResponse(REWRITE_DONE,
217 nm->mkNode(kind::LEQ, atom[0], atom[1]).notNode());
218 case Kind::LT:
219 return RewriteResponse(REWRITE_DONE,
220 nm->mkNode(kind::GEQ, atom[0], atom[1]).notNode());
221 case Kind::IS_INTEGER:
222 if (atom[0].getType().isInteger())
223 {
224 return RewriteResponse(REWRITE_DONE, nm->mkConst(true));
225 }
226 break;
227 case Kind::DIVISIBLE:
228 if (atom.getOperator().getConst<Divisible>().k.isOne())
229 {
230 return RewriteResponse(REWRITE_DONE, nm->mkConst(true));
231 }
232 break;
233 default:;
234 }
235
236 return RewriteResponse(REWRITE_DONE, atom);
237 }
238
239 RewriteResponse ArithRewriter::postRewriteAtom(TNode atom)
240 {
241 Assert(isAtom(atom));
242 if (atom.getKind() == kind::IS_INTEGER)
243 {
244 return rewriteExtIntegerOp(atom);
245 }
246 else if (atom.getKind() == kind::DIVISIBLE)
247 {
248 if (atom[0].isConst())
249 {
250 return RewriteResponse(REWRITE_DONE,
251 NodeManager::currentNM()->mkConst(bool(
252 (atom[0].getConst<Rational>()
253 / atom.getOperator().getConst<Divisible>().k)
254 .isIntegral())));
255 }
256 if (atom.getOperator().getConst<Divisible>().k.isOne())
257 {
258 return RewriteResponse(REWRITE_DONE,
259 NodeManager::currentNM()->mkConst(true));
260 }
261 NodeManager* nm = NodeManager::currentNM();
262 return RewriteResponse(
263 REWRITE_AGAIN,
264 nm->mkNode(kind::EQUAL,
265 nm->mkNode(kind::INTS_MODULUS_TOTAL,
266 atom[0],
267 nm->mkConstInt(Rational(
268 atom.getOperator().getConst<Divisible>().k))),
269 nm->mkConstInt(Rational(0))));
270 }
271
272 // left |><| right
273 TNode left = atom[0];
274 TNode right = atom[1];
275
276 auto* nm = NodeManager::currentNM();
277 if (left.isConst())
278 {
279 const Rational& l = left.getConst<Rational>();
280 if (right.isConst())
281 {
282 const Rational& r = right.getConst<Rational>();
283 return RewriteResponse(
284 REWRITE_DONE, nm->mkConst(evaluateRelation(atom.getKind(), l, r)));
285 }
286 else if (right.getKind() == Kind::REAL_ALGEBRAIC_NUMBER)
287 {
288 const RealAlgebraicNumber& r =
289 right.getOperator().getConst<RealAlgebraicNumber>();
290 return RewriteResponse(
291 REWRITE_DONE, nm->mkConst(evaluateRelation(atom.getKind(), l, r)));
292 }
293 }
294 else if (left.getKind() == Kind::REAL_ALGEBRAIC_NUMBER)
295 {
296 const RealAlgebraicNumber& l =
297 left.getOperator().getConst<RealAlgebraicNumber>();
298 if (right.isConst())
299 {
300 const Rational& r = right.getConst<Rational>();
301 return RewriteResponse(
302 REWRITE_DONE, nm->mkConst(evaluateRelation(atom.getKind(), l, r)));
303 }
304 else if (right.getKind() == Kind::REAL_ALGEBRAIC_NUMBER)
305 {
306 const RealAlgebraicNumber& r =
307 right.getOperator().getConst<RealAlgebraicNumber>();
308 return RewriteResponse(
309 REWRITE_DONE, nm->mkConst(evaluateRelation(atom.getKind(), l, r)));
310 }
311 }
312
313 Polynomial pleft = Polynomial::parsePolynomial(left);
314 Polynomial pright = Polynomial::parsePolynomial(right);
315
316 Debug("arith::rewriter") << "pleft " << pleft.getNode() << std::endl;
317 Debug("arith::rewriter") << "pright " << pright.getNode() << std::endl;
318
319 Comparison cmp = Comparison::mkComparison(atom.getKind(), pleft, pright);
320 Assert(cmp.isNormalForm());
321 return RewriteResponse(REWRITE_DONE, cmp.getNode());
322 }
323
324 bool ArithRewriter::isAtom(TNode n) {
325 Kind k = n.getKind();
326 return arith::isRelationOperator(k) || k == kind::IS_INTEGER
327 || k == kind::DIVISIBLE;
328 }
329
330 RewriteResponse ArithRewriter::rewriteConstant(TNode t){
331 Assert(t.isConst());
332 Assert(t.getKind() == CONST_RATIONAL || t.getKind() == CONST_INTEGER);
333
334 return RewriteResponse(REWRITE_DONE, t);
335 }
336
337 RewriteResponse ArithRewriter::rewriteRAN(TNode t)
338 {
339 Assert(t.getKind() == REAL_ALGEBRAIC_NUMBER);
340
341 const RealAlgebraicNumber& r =
342 t.getOperator().getConst<RealAlgebraicNumber>();
343 if (r.isRational())
344 {
345 return RewriteResponse(
346 REWRITE_DONE, NodeManager::currentNM()->mkConstReal(r.toRational()));
347 }
348
349 return RewriteResponse(REWRITE_DONE, t);
350 }
351
352 RewriteResponse ArithRewriter::rewriteVariable(TNode t){
353 Assert(t.isVar());
354
355 return RewriteResponse(REWRITE_DONE, t);
356 }
357
358 RewriteResponse ArithRewriter::rewriteSub(TNode t)
359 {
360 Assert(t.getKind() == kind::SUB);
361 Assert(t.getNumChildren() == 2);
362
363 auto* nm = NodeManager::currentNM();
364
365 if (t[0] == t[1])
366 {
367 return RewriteResponse(REWRITE_DONE,
368 nm->mkConstRealOrInt(t.getType(), Rational(0)));
369 }
370 return RewriteResponse(REWRITE_AGAIN_FULL,
371 nm->mkNode(Kind::ADD, t[0], makeUnaryMinusNode(t[1])));
372 }
373
374 RewriteResponse ArithRewriter::rewriteNeg(TNode t, bool pre)
375 {
376 Assert(t.getKind() == kind::NEG);
377
378 if (t[0].isConst())
379 {
380 Rational neg = -(t[0].getConst<Rational>());
381 NodeManager* nm = NodeManager::currentNM();
382 return RewriteResponse(REWRITE_DONE,
383 nm->mkConstRealOrInt(t[0].getType(), neg));
384 }
385 if (t[0].getKind() == Kind::REAL_ALGEBRAIC_NUMBER)
386 {
387 const RealAlgebraicNumber& r =
388 t[0].getOperator().getConst<RealAlgebraicNumber>();
389 NodeManager* nm = NodeManager::currentNM();
390 return RewriteResponse(REWRITE_DONE, nm->mkRealAlgebraicNumber(-r));
391 }
392
393 Node noUminus = makeUnaryMinusNode(t[0]);
394 if(pre)
395 return RewriteResponse(REWRITE_DONE, noUminus);
396 else
397 return RewriteResponse(REWRITE_AGAIN, noUminus);
398 }
399
400 RewriteResponse ArithRewriter::preRewriteTerm(TNode t){
401 if(t.isConst()){
402 return rewriteConstant(t);
403 }else if(t.isVar()){
404 return rewriteVariable(t);
405 }else{
406 switch(Kind k = t.getKind()){
407 case kind::REAL_ALGEBRAIC_NUMBER: return rewriteRAN(t);
408 case kind::SUB: return rewriteSub(t);
409 case kind::NEG: return rewriteNeg(t, true);
410 case kind::DIVISION:
411 case kind::DIVISION_TOTAL: return rewriteDiv(t, true);
412 case kind::ADD: return preRewritePlus(t);
413 case kind::MULT:
414 case kind::NONLINEAR_MULT: return preRewriteMult(t);
415 case kind::IAND: return RewriteResponse(REWRITE_DONE, t);
416 case kind::POW2: return RewriteResponse(REWRITE_DONE, t);
417 case kind::EXPONENTIAL:
418 case kind::SINE:
419 case kind::COSINE:
420 case kind::TANGENT:
421 case kind::COSECANT:
422 case kind::SECANT:
423 case kind::COTANGENT:
424 case kind::ARCSINE:
425 case kind::ARCCOSINE:
426 case kind::ARCTANGENT:
427 case kind::ARCCOSECANT:
428 case kind::ARCSECANT:
429 case kind::ARCCOTANGENT:
430 case kind::SQRT: return preRewriteTranscendental(t);
431 case kind::INTS_DIVISION:
432 case kind::INTS_MODULUS: return rewriteIntsDivMod(t, true);
433 case kind::INTS_DIVISION_TOTAL:
434 case kind::INTS_MODULUS_TOTAL: return rewriteIntsDivModTotal(t, true);
435 case kind::ABS: return rewriteAbs(t);
436 case kind::IS_INTEGER:
437 case kind::TO_INTEGER: return RewriteResponse(REWRITE_DONE, t);
438 case kind::TO_REAL:
439 case kind::CAST_TO_REAL: return RewriteResponse(REWRITE_DONE, t[0]);
440 case kind::POW: return RewriteResponse(REWRITE_DONE, t);
441 case kind::PI: return RewriteResponse(REWRITE_DONE, t);
442 default: Unhandled() << k;
443 }
444 }
445 }
446
447 RewriteResponse ArithRewriter::postRewriteTerm(TNode t){
448 if(t.isConst()){
449 return rewriteConstant(t);
450 }else if(t.isVar()){
451 return rewriteVariable(t);
452 }else{
453 Trace("arith-rewriter") << "postRewriteTerm: " << t << std::endl;
454 switch(t.getKind()){
455 case kind::REAL_ALGEBRAIC_NUMBER: return rewriteRAN(t);
456 case kind::SUB: return rewriteSub(t);
457 case kind::NEG: return rewriteNeg(t, false);
458 case kind::DIVISION:
459 case kind::DIVISION_TOTAL: return rewriteDiv(t, false);
460 case kind::ADD: return postRewritePlus(t);
461 case kind::MULT:
462 case kind::NONLINEAR_MULT: return postRewriteMult(t);
463 case kind::IAND: return postRewriteIAnd(t);
464 case kind::POW2: return postRewritePow2(t);
465 case kind::EXPONENTIAL:
466 case kind::SINE:
467 case kind::COSINE:
468 case kind::TANGENT:
469 case kind::COSECANT:
470 case kind::SECANT:
471 case kind::COTANGENT:
472 case kind::ARCSINE:
473 case kind::ARCCOSINE:
474 case kind::ARCTANGENT:
475 case kind::ARCCOSECANT:
476 case kind::ARCSECANT:
477 case kind::ARCCOTANGENT:
478 case kind::SQRT: return postRewriteTranscendental(t);
479 case kind::INTS_DIVISION:
480 case kind::INTS_MODULUS: return rewriteIntsDivMod(t, false);
481 case kind::INTS_DIVISION_TOTAL:
482 case kind::INTS_MODULUS_TOTAL: return rewriteIntsDivModTotal(t, false);
483 case kind::ABS: return rewriteAbs(t);
484 case kind::TO_REAL:
485 case kind::CAST_TO_REAL: return RewriteResponse(REWRITE_DONE, t[0]);
486 case kind::TO_INTEGER: return rewriteExtIntegerOp(t);
487 case kind::POW:
488 {
489 if (t[1].isConst())
490 {
491 const Rational& exp = t[1].getConst<Rational>();
492 TNode base = t[0];
493 if(exp.sgn() == 0){
494 return RewriteResponse(REWRITE_DONE,
495 NodeManager::currentNM()->mkConstRealOrInt(
496 t.getType(), Rational(1)));
497 }else if(exp.sgn() > 0 && exp.isIntegral()){
498 cvc5::Rational r(expr::NodeValue::MAX_CHILDREN);
499 if (exp <= r)
500 {
501 unsigned num = exp.getNumerator().toUnsignedInt();
502 if( num==1 ){
503 return RewriteResponse(REWRITE_AGAIN, base);
504 }else{
505 NodeBuilder nb(kind::MULT);
506 for(unsigned i=0; i < num; ++i){
507 nb << base;
508 }
509 Assert(nb.getNumChildren() > 0);
510 Node mult = nb;
511 return RewriteResponse(REWRITE_AGAIN, mult);
512 }
513 }
514 }
515 }
516 else if (t[0].isConst()
517 && t[0].getConst<Rational>().getNumerator().toUnsignedInt()
518 == 2)
519 {
520 return RewriteResponse(
521 REWRITE_DONE, NodeManager::currentNM()->mkNode(kind::POW2, t[1]));
522 }
523
524 // Todo improve the exception thrown
525 std::stringstream ss;
526 ss << "The exponent of the POW(^) operator can only be a positive "
527 "integral constant below "
528 << (expr::NodeValue::MAX_CHILDREN + 1) << ". ";
529 ss << "Exception occurred in:" << std::endl;
530 ss << " " << t;
531 throw LogicException(ss.str());
532 }
533 case kind::PI:
534 return RewriteResponse(REWRITE_DONE, t);
535 default:
536 Unreachable();
537 }
538 }
539 }
540
541
542 RewriteResponse ArithRewriter::preRewritePlus(TNode t){
543 Assert(t.getKind() == kind::ADD);
544 return RewriteResponse(REWRITE_DONE, expr::algorithm::flatten(t));
545 }
546
547 RewriteResponse ArithRewriter::postRewritePlus(TNode t){
548 Assert(t.getKind() == kind::ADD);
549 Assert(t.getNumChildren() > 1);
550
551 {
552 Node flat = expr::algorithm::flatten(t);
553 if (flat != t)
554 {
555 return RewriteResponse(REWRITE_AGAIN, flat);
556 }
557 }
558
559 Rational rational;
560 RealAlgebraicNumber ran;
561 std::vector<Monomial> monomials;
562 std::vector<Polynomial> polynomials;
563
564 for (const auto& child : t)
565 {
566 if (child.isConst())
567 {
568 if (child.getConst<Rational>().isZero())
569 {
570 continue;
571 }
572 rational += child.getConst<Rational>();
573 }
574 else if (child.getKind() == Kind::REAL_ALGEBRAIC_NUMBER)
575 {
576 ran += child.getOperator().getConst<RealAlgebraicNumber>();
577 }
578 else if (Monomial::isMember(child))
579 {
580 monomials.emplace_back(Monomial::parseMonomial(child));
581 }
582 else
583 {
584 polynomials.emplace_back(Polynomial::parsePolynomial(child));
585 }
586 }
587
588 if(!monomials.empty()){
589 Monomial::sort(monomials);
590 Monomial::combineAdjacentMonomials(monomials);
591 polynomials.emplace_back(Polynomial::mkPolynomial(monomials));
592 }
593 if (!rational.isZero())
594 {
595 polynomials.emplace_back(
596 Polynomial::mkPolynomial(Constant::mkConstant(rational)));
597 }
598
599 Polynomial poly = Polynomial::sumPolynomials(polynomials);
600
601 if (isZero(ran))
602 {
603 return RewriteResponse(REWRITE_DONE, poly.getNode());
604 }
605 if (poly.containsConstant())
606 {
607 ran += RealAlgebraicNumber(poly.getHead().getConstant().getValue());
608 if (!poly.isConstant())
609 {
610 poly = poly.getTail();
611 }
612 }
613
614 auto* nm = NodeManager::currentNM();
615 if (poly.isConstant())
616 {
617 return RewriteResponse(REWRITE_DONE, nm->mkRealAlgebraicNumber(ran));
618 }
619 return RewriteResponse(
620 REWRITE_DONE,
621 nm->mkNode(Kind::ADD, nm->mkRealAlgebraicNumber(ran), poly.getNode()));
622 }
623
624 RewriteResponse ArithRewriter::preRewriteMult(TNode node)
625 {
626 Assert(node.getKind() == kind::MULT
627 || node.getKind() == kind::NONLINEAR_MULT);
628
629 auto res = getZeroChild(node);
630 if (res)
631 {
632 return RewriteResponse(REWRITE_DONE, *res);
633 }
634 return RewriteResponse(REWRITE_DONE, node);
635 }
636
637 RewriteResponse ArithRewriter::postRewriteMult(TNode t){
638 Assert(t.getKind() == kind::MULT || t.getKind() == kind::NONLINEAR_MULT);
639 Assert(t.getNumChildren() >= 2);
640
641 if (auto res = getZeroChild(t); res)
642 {
643 return RewriteResponse(REWRITE_DONE, *res);
644 }
645
646 Rational rational = Rational(1);
647 RealAlgebraicNumber ran = RealAlgebraicNumber(Integer(1));
648 Polynomial poly = Polynomial::mkOne();
649
650 for (const auto& child : t)
651 {
652 if (child.isConst())
653 {
654 if (child.getConst<Rational>().isZero())
655 {
656 return RewriteResponse(REWRITE_DONE, child);
657 }
658 rational *= child.getConst<Rational>();
659 }
660 else if (child.getKind() == Kind::REAL_ALGEBRAIC_NUMBER)
661 {
662 ran *= child.getOperator().getConst<RealAlgebraicNumber>();
663 }
664 else
665 {
666 poly = poly * Polynomial::parsePolynomial(child);
667 }
668 }
669
670 if (!rational.isOne())
671 {
672 poly = poly * rational;
673 }
674 if (isOne(ran))
675 {
676 return RewriteResponse(REWRITE_DONE, poly.getNode());
677 }
678 auto* nm = NodeManager::currentNM();
679 if (poly.isConstant())
680 {
681 ran *= RealAlgebraicNumber(poly.getHead().getConstant().getValue());
682 return RewriteResponse(REWRITE_DONE, nm->mkRealAlgebraicNumber(ran));
683 }
684 return RewriteResponse(
685 REWRITE_DONE,
686 nm->mkNode(
687 Kind::MULT, nm->mkRealAlgebraicNumber(ran), poly.getNode()));
688 }
689
690 RewriteResponse ArithRewriter::postRewritePow2(TNode t)
691 {
692 Assert(t.getKind() == kind::POW2);
693 NodeManager* nm = NodeManager::currentNM();
694 // if constant, we eliminate
695 if (t[0].isConst())
696 {
697 // pow2 is only supported for integers
698 Assert(t[0].getType().isInteger());
699 Integer i = t[0].getConst<Rational>().getNumerator();
700 if (i < 0)
701 {
702 return RewriteResponse(REWRITE_DONE, nm->mkConstInt(Rational(0)));
703 }
704 // (pow2 t) ---> (pow 2 t) and continue rewriting to eliminate pow
705 Node two = nm->mkConstInt(Rational(Integer(2)));
706 Node ret = nm->mkNode(kind::POW, two, t[0]);
707 return RewriteResponse(REWRITE_AGAIN, ret);
708 }
709 return RewriteResponse(REWRITE_DONE, t);
710 }
711
712 RewriteResponse ArithRewriter::postRewriteIAnd(TNode t)
713 {
714 Assert(t.getKind() == kind::IAND);
715 size_t bsize = t.getOperator().getConst<IntAnd>().d_size;
716 NodeManager* nm = NodeManager::currentNM();
717 // if constant, we eliminate
718 if (t[0].isConst() && t[1].isConst())
719 {
720 Node iToBvop = nm->mkConst(IntToBitVector(bsize));
721 Node arg1 = nm->mkNode(kind::INT_TO_BITVECTOR, iToBvop, t[0]);
722 Node arg2 = nm->mkNode(kind::INT_TO_BITVECTOR, iToBvop, t[1]);
723 Node bvand = nm->mkNode(kind::BITVECTOR_AND, arg1, arg2);
724 Node ret = nm->mkNode(kind::BITVECTOR_TO_NAT, bvand);
725 return RewriteResponse(REWRITE_AGAIN_FULL, ret);
726 }
727 else if (t[0] > t[1])
728 {
729 // ((_ iand k) x y) ---> ((_ iand k) y x) if x > y by node ordering
730 Node ret = nm->mkNode(kind::IAND, t.getOperator(), t[1], t[0]);
731 return RewriteResponse(REWRITE_AGAIN, ret);
732 }
733 else if (t[0] == t[1])
734 {
735 // ((_ iand k) x x) ---> (mod x 2^k)
736 Node twok = nm->mkConstInt(Rational(Integer(2).pow(bsize)));
737 Node ret = nm->mkNode(kind::INTS_MODULUS, t[0], twok);
738 return RewriteResponse(REWRITE_AGAIN, ret);
739 }
740 // simplifications involving constants
741 for (unsigned i = 0; i < 2; i++)
742 {
743 if (!t[i].isConst())
744 {
745 continue;
746 }
747 if (t[i].getConst<Rational>().sgn() == 0)
748 {
749 // ((_ iand k) 0 y) ---> 0
750 return RewriteResponse(REWRITE_DONE, t[i]);
751 }
752 if (t[i].getConst<Rational>().getNumerator() == Integer(2).pow(bsize) - 1)
753 {
754 // ((_ iand k) 111...1 y) ---> (mod y 2^k)
755 Node twok = nm->mkConstInt(Rational(Integer(2).pow(bsize)));
756 Node ret = nm->mkNode(kind::INTS_MODULUS, t[1-i], twok);
757 return RewriteResponse(REWRITE_AGAIN, ret);
758 }
759 }
760 return RewriteResponse(REWRITE_DONE, t);
761 }
762
763 RewriteResponse ArithRewriter::preRewriteTranscendental(TNode t) {
764 return RewriteResponse(REWRITE_DONE, t);
765 }
766
/**
 * Post-rewrites a transcendental term.
 *
 *  - EXPONENTIAL: exp(c) for a non-negative integer-typed constant c becomes
 *    (^ exp(1) c); exp(a1 + ... + an) becomes exp(a1) * ... * exp(an).
 *  - SINE: sin(0) = 0 and sin(-c) = -sin(c) for constants; for arguments with
 *    a rational multiple of pi, the multiple is reduced modulo 2*pi and a few
 *    exact values (Niven's theorem) are evaluated.
 *  - COSINE, TANGENT, COSECANT, SECANT, COTANGENT are expressed via sine and
 *    cosine (and the result is fully rewritten again).
 *  - Everything else (including the inverse functions and SQRT) is returned
 *    unchanged.
 */
RewriteResponse ArithRewriter::postRewriteTranscendental(TNode t) {
  Trace("arith-tf-rewrite") << "Rewrite transcendental function : " << t << std::endl;
  NodeManager* nm = NodeManager::currentNM();
  switch( t.getKind() ){
  case kind::EXPONENTIAL: {
    if (t[0].isConst())
    {
      Node one = nm->mkConstReal(Rational(1));
      // NOTE(review): `one` is a real-typed constant, while this condition
      // only admits integer-typed t[0], so `t[0]!=one` appears to always hold.
      // exp(1) (integer 1) thus becomes (^ exp(1.0) 1), which rewrites to
      // exp(1.0); exp(1.0) itself is not integer-typed and stops here.
      if(t[0].getConst<Rational>().sgn()>=0 && t[0].getType().isInteger() && t[0]!=one){
        // exp(n) ---> exp(1)^n for integer n >= 0
        return RewriteResponse(
            REWRITE_AGAIN,
            nm->mkNode(kind::POW, nm->mkNode(kind::EXPONENTIAL, one), t[0]));
      }else{
        return RewriteResponse(REWRITE_DONE, t);
      }
    }
    else if (t[0].getKind() == kind::ADD)
    {
      // exp(a1 + ... + an) ---> exp(a1) * ... * exp(an)
      std::vector<Node> product;
      for (const Node tc : t[0])
      {
        product.push_back(nm->mkNode(kind::EXPONENTIAL, tc));
      }
      // We need to do a full rewrite here, since we can get exponentials of
      // constants, e.g. when we are rewriting exp(2 + x)
      return RewriteResponse(REWRITE_AGAIN_FULL,
                             nm->mkNode(kind::MULT, product));
    }
  }
  break;
  case kind::SINE:
    if (t[0].isConst())
    {
      const Rational& rat = t[0].getConst<Rational>();
      if(rat.sgn() == 0){
        // sin(0) = 0
        return RewriteResponse(REWRITE_DONE, nm->mkConstReal(Rational(0)));
      }
      else if (rat.sgn() == -1)
      {
        // sin(-c) = -sin(c): keep constant arguments non-negative
        Node ret = nm->mkNode(kind::NEG,
                              nm->mkNode(kind::SINE, nm->mkConstReal(-rat)));
        return RewriteResponse(REWRITE_AGAIN_FULL, ret);
      }
    }else{
      // get the factor of PI in the argument
      Node pi_factor;
      Node pi;
      Node rem;
      std::map<Node, Node> msum;
      if (ArithMSum::getMonomialSum(t[0], msum))
      {
        pi = mkPi();
        std::map<Node, Node>::iterator itm = msum.find(pi);
        if (itm != msum.end())
        {
          // a null coefficient in the monomial sum stands for coefficient 1
          if (itm->second.isNull())
          {
            pi_factor = nm->mkConstReal(Rational(1));
          }
          else
          {
            pi_factor = itm->second;
          }
          msum.erase(pi);
          // rem is the argument without its pi term (null if nothing remains)
          if (!msum.empty())
          {
            rem = ArithMSum::mkNode(t[0].getType(), msum);
          }
        }
      }
      else
      {
        Assert(false);
      }

      // if there is a factor of PI
      if( !pi_factor.isNull() ){
        Trace("arith-tf-rewrite-debug") << "Process pi factor = " << pi_factor << std::endl;
        Rational r = pi_factor.getConst<Rational>();
        Rational r_abs = r.abs();
        Rational rone = Rational(1);
        Rational rtwo = Rational(2);
        if (r_abs > rone)
        {
          // add/subtract multiples of 2*pi to bring the factor into [-1, 1]
          Rational ra_div_two = (r_abs + rone) / rtwo;
          Node new_pi_factor;
          if (r.sgn() == 1)
          {
            new_pi_factor = nm->mkConstReal(r - rtwo * ra_div_two.floor());
          }
          else
          {
            Assert(r.sgn() == -1);
            new_pi_factor = nm->mkConstReal(r + rtwo * ra_div_two.floor());
          }
          Node new_arg = nm->mkNode(kind::MULT, new_pi_factor, pi);
          if (!rem.isNull())
          {
            new_arg = nm->mkNode(kind::ADD, new_arg, rem);
          }
          // sin( 2*n*PI + x ) = sin( x )
          return RewriteResponse(REWRITE_AGAIN_FULL,
                                 nm->mkNode(kind::SINE, new_arg));
        }
        else if (r_abs == rone)
        {
          // sin( PI + x ) = -sin( x ), and sin( PI ) = 0
          if (rem.isNull())
          {
            return RewriteResponse(REWRITE_DONE, nm->mkConstReal(Rational(0)));
          }
          else
          {
            return RewriteResponse(
                REWRITE_AGAIN_FULL,
                nm->mkNode(kind::NEG, nm->mkNode(kind::SINE, rem)));
          }
        }
        else if (rem.isNull())
        {
          // other rational cases based on Niven's theorem
          // (https://en.wikipedia.org/wiki/Niven%27s_theorem)
          Integer one = Integer(1);
          Integer two = Integer(2);
          Integer six = Integer(6);
          if (r_abs.getDenominator() == two)
          {
            // |r| < 1 here, so r = +-1/2 and sin(+-PI/2) = +-1
            Assert(r_abs.getNumerator() == one);
            return RewriteResponse(REWRITE_DONE,
                                   nm->mkConstReal(Rational(r.sgn())));
          }
          else if (r_abs.getDenominator() == six)
          {
            // sin(+-PI/6) = sin(+-5*PI/6) = +-1/2
            Integer five = Integer(5);
            if (r_abs.getNumerator() == one || r_abs.getNumerator() == five)
            {
              return RewriteResponse(
                  REWRITE_DONE,
                  nm->mkConstReal(Rational(r.sgn()) / Rational(2)));
            }
          }
        }
      }
    }
    break;
  case kind::COSINE: {
    // cos(x) = sin(PI/2 - x)
    return RewriteResponse(
        REWRITE_AGAIN_FULL,
        nm->mkNode(
            kind::SINE,
            nm->mkNode(kind::SUB,
                       nm->mkNode(kind::MULT,
                                  nm->mkConstReal(Rational(1) / Rational(2)),
                                  mkPi()),
                       t[0])));
  }
  break;
  case kind::TANGENT:
  {
    // tan(x) = sin(x) / cos(x)
    return RewriteResponse(REWRITE_AGAIN_FULL,
                           nm->mkNode(kind::DIVISION,
                                      nm->mkNode(kind::SINE, t[0]),
                                      nm->mkNode(kind::COSINE, t[0])));
  }
  break;
  case kind::COSECANT:
  {
    // csc(x) = 1 / sin(x)
    return RewriteResponse(REWRITE_AGAIN_FULL,
                           nm->mkNode(kind::DIVISION,
                                      nm->mkConstReal(Rational(1)),
                                      nm->mkNode(kind::SINE, t[0])));
  }
  break;
  case kind::SECANT:
  {
    // sec(x) = 1 / cos(x)
    return RewriteResponse(REWRITE_AGAIN_FULL,
                           nm->mkNode(kind::DIVISION,
                                      nm->mkConstReal(Rational(1)),
                                      nm->mkNode(kind::COSINE, t[0])));
  }
  break;
  case kind::COTANGENT:
  {
    // cot(x) = cos(x) / sin(x)
    return RewriteResponse(REWRITE_AGAIN_FULL,
                           nm->mkNode(kind::DIVISION,
                                      nm->mkNode(kind::COSINE, t[0]),
                                      nm->mkNode(kind::SINE, t[0])));
  }
  break;
  default:
    break;
  }
  return RewriteResponse(REWRITE_DONE, t);
}
962
963 Node ArithRewriter::makeUnaryMinusNode(TNode n){
964 NodeManager* nm = NodeManager::currentNM();
965 Rational qNegOne(-1);
966 return nm->mkNode(kind::MULT, nm->mkConstRealOrInt(n.getType(), qNegOne), n);
967 }
968
969 RewriteResponse ArithRewriter::rewriteDiv(TNode t, bool pre){
970 Assert(t.getKind() == kind::DIVISION_TOTAL || t.getKind() == kind::DIVISION);
971 Assert(t.getNumChildren() == 2);
972
973 Node left = t[0];
974 Node right = t[1];
975 if (right.isConst())
976 {
977 NodeManager* nm = NodeManager::currentNM();
978 const Rational& den = right.getConst<Rational>();
979
980 if(den.isZero()){
981 if(t.getKind() == kind::DIVISION_TOTAL){
982 return RewriteResponse(REWRITE_DONE, nm->mkConstReal(0));
983 }else{
984 // This is unsupported, but this is not a good place to complain
985 return RewriteResponse(REWRITE_DONE, t);
986 }
987 }
988 Assert(den != Rational(0));
989
990 if (left.isConst())
991 {
992 const Rational& num = left.getConst<Rational>();
993 return RewriteResponse(REWRITE_DONE, nm->mkConstReal(num / den));
994 }
995 if (left.getKind() == Kind::REAL_ALGEBRAIC_NUMBER)
996 {
997 const RealAlgebraicNumber& num =
998 left.getOperator().getConst<RealAlgebraicNumber>();
999 return RewriteResponse(
1000 REWRITE_DONE,
1001 nm->mkRealAlgebraicNumber(num / RealAlgebraicNumber(den)));
1002 }
1003
1004 Node result = nm->mkConstReal(den.inverse());
1005 Node mult = NodeManager::currentNM()->mkNode(kind::MULT, left, result);
1006 if (pre)
1007 {
1008 return RewriteResponse(REWRITE_DONE, mult);
1009 }
1010 else
1011 {
1012 return RewriteResponse(REWRITE_AGAIN, mult);
1013 }
1014 }
1015 if (right.getKind() == Kind::REAL_ALGEBRAIC_NUMBER)
1016 {
1017 NodeManager* nm = NodeManager::currentNM();
1018 const RealAlgebraicNumber& den =
1019 right.getOperator().getConst<RealAlgebraicNumber>();
1020 if (left.isConst())
1021 {
1022 const Rational& num = left.getConst<Rational>();
1023 return RewriteResponse(
1024 REWRITE_DONE,
1025 nm->mkRealAlgebraicNumber(RealAlgebraicNumber(num) / den));
1026 }
1027 if (left.getKind() == Kind::REAL_ALGEBRAIC_NUMBER)
1028 {
1029 const RealAlgebraicNumber& num =
1030 left.getOperator().getConst<RealAlgebraicNumber>();
1031 return RewriteResponse(REWRITE_DONE,
1032 nm->mkRealAlgebraicNumber(num / den));
1033 }
1034
1035 Node result = nm->mkRealAlgebraicNumber(inverse(den));
1036 Node mult = NodeManager::currentNM()->mkNode(kind::MULT,left,result);
1037 if(pre){
1038 return RewriteResponse(REWRITE_DONE, mult);
1039 }else{
1040 return RewriteResponse(REWRITE_AGAIN, mult);
1041 }
1042 }
1043 return RewriteResponse(REWRITE_DONE, t);
1044 }
1045
1046 RewriteResponse ArithRewriter::rewriteAbs(TNode t)
1047 {
1048 Assert(t.getKind() == Kind::ABS);
1049 Assert(t.getNumChildren() == 1);
1050
1051 if (t[0].isConst())
1052 {
1053 const Rational& rat = t[0].getConst<Rational>();
1054 if (rat >= 0)
1055 {
1056 return RewriteResponse(REWRITE_DONE, t[0]);
1057 }
1058 return RewriteResponse(
1059 REWRITE_DONE,
1060 NodeManager::currentNM()->mkConstRealOrInt(t[0].getType(), -rat));
1061 }
1062 if (t[0].getKind() == Kind::REAL_ALGEBRAIC_NUMBER)
1063 {
1064 const RealAlgebraicNumber& ran =
1065 t[0].getOperator().getConst<RealAlgebraicNumber>();
1066 if (ran >= RealAlgebraicNumber())
1067 {
1068 return RewriteResponse(REWRITE_DONE, t[0]);
1069 }
1070 return RewriteResponse(
1071 REWRITE_DONE, NodeManager::currentNM()->mkRealAlgebraicNumber(-ran));
1072 }
1073 return RewriteResponse(REWRITE_DONE, t);
1074 }
1075
1076 RewriteResponse ArithRewriter::rewriteIntsDivMod(TNode t, bool pre)
1077 {
1078 NodeManager* nm = NodeManager::currentNM();
1079 Kind k = t.getKind();
1080 if (k == kind::INTS_MODULUS)
1081 {
1082 if (t[1].isConst() && !t[1].getConst<Rational>().isZero())
1083 {
1084 // can immediately replace by INTS_MODULUS_TOTAL
1085 Node ret = nm->mkNode(kind::INTS_MODULUS_TOTAL, t[0], t[1]);
1086 return returnRewrite(t, ret, Rewrite::MOD_TOTAL_BY_CONST);
1087 }
1088 }
1089 if (k == kind::INTS_DIVISION)
1090 {
1091 if (t[1].isConst() && !t[1].getConst<Rational>().isZero())
1092 {
1093 // can immediately replace by INTS_DIVISION_TOTAL
1094 Node ret = nm->mkNode(kind::INTS_DIVISION_TOTAL, t[0], t[1]);
1095 return returnRewrite(t, ret, Rewrite::DIV_TOTAL_BY_CONST);
1096 }
1097 }
1098 return RewriteResponse(REWRITE_DONE, t);
1099 }
1100
1101 RewriteResponse ArithRewriter::rewriteExtIntegerOp(TNode t)
1102 {
1103 Assert(t.getKind() == kind::TO_INTEGER || t.getKind() == kind::IS_INTEGER);
1104 bool isPred = t.getKind() == kind::IS_INTEGER;
1105 NodeManager* nm = NodeManager::currentNM();
1106 if (t[0].isConst())
1107 {
1108 Node ret;
1109 if (isPred)
1110 {
1111 ret = nm->mkConst(t[0].getConst<Rational>().isIntegral());
1112 }
1113 else
1114 {
1115 ret = nm->mkConstInt(Rational(t[0].getConst<Rational>().floor()));
1116 }
1117 return returnRewrite(t, ret, Rewrite::INT_EXT_CONST);
1118 }
1119 if (t[0].getType().isInteger())
1120 {
1121 Node ret = isPred ? nm->mkConst(true) : Node(t[0]);
1122 return returnRewrite(t, ret, Rewrite::INT_EXT_INT);
1123 }
1124 if (t[0].getKind() == kind::PI)
1125 {
1126 Node ret = isPred ? nm->mkConst(false) : nm->mkConstReal(Rational(3));
1127 return returnRewrite(t, ret, Rewrite::INT_EXT_PI);
1128 }
1129 return RewriteResponse(REWRITE_DONE, t);
1130 }
1131
1132 RewriteResponse ArithRewriter::rewriteIntsDivModTotal(TNode t, bool pre)
1133 {
1134 if (pre)
1135 {
1136 // do not rewrite at prewrite.
1137 return RewriteResponse(REWRITE_DONE, t);
1138 }
1139 NodeManager* nm = NodeManager::currentNM();
1140 Kind k = t.getKind();
1141 Assert(k == kind::INTS_MODULUS_TOTAL || k == kind::INTS_DIVISION_TOTAL);
1142 TNode n = t[0];
1143 TNode d = t[1];
1144 bool dIsConstant = d.isConst();
1145 if(dIsConstant && d.getConst<Rational>().isZero()){
1146 // (div x 0) ---> 0 or (mod x 0) ---> 0
1147 return returnRewrite(t, nm->mkConstInt(0), Rewrite::DIV_MOD_BY_ZERO);
1148 }else if(dIsConstant && d.getConst<Rational>().isOne()){
1149 if (k == kind::INTS_MODULUS_TOTAL)
1150 {
1151 // (mod x 1) --> 0
1152 return returnRewrite(t, nm->mkConstInt(0), Rewrite::MOD_BY_ONE);
1153 }
1154 Assert(k == kind::INTS_DIVISION_TOTAL);
1155 // (div x 1) --> x
1156 return returnRewrite(t, n, Rewrite::DIV_BY_ONE);
1157 }
1158 else if (dIsConstant && d.getConst<Rational>().sgn() < 0)
1159 {
1160 // pull negation
1161 // (div x (- c)) ---> (- (div x c))
1162 // (mod x (- c)) ---> (mod x c)
1163 Node nn = nm->mkNode(k, t[0], nm->mkConstInt(-t[1].getConst<Rational>()));
1164 Node ret = (k == kind::INTS_DIVISION || k == kind::INTS_DIVISION_TOTAL)
1165 ? nm->mkNode(kind::NEG, nn)
1166 : nn;
1167 return returnRewrite(t, ret, Rewrite::DIV_MOD_PULL_NEG_DEN);
1168 }
1169 else if (dIsConstant && n.isConst())
1170 {
1171 Assert(d.getConst<Rational>().isIntegral());
1172 Assert(n.getConst<Rational>().isIntegral());
1173 Assert(!d.getConst<Rational>().isZero());
1174 Integer di = d.getConst<Rational>().getNumerator();
1175 Integer ni = n.getConst<Rational>().getNumerator();
1176
1177 bool isDiv = (k == kind::INTS_DIVISION || k == kind::INTS_DIVISION_TOTAL);
1178
1179 Integer result = isDiv ? ni.euclidianDivideQuotient(di) : ni.euclidianDivideRemainder(di);
1180
1181 // constant evaluation
1182 // (mod c1 c2) ---> c3 or (div c1 c2) ---> c3
1183 Node resultNode = nm->mkConstInt(Rational(result));
1184 return returnRewrite(t, resultNode, Rewrite::CONST_EVAL);
1185 }
1186 if (k == kind::INTS_MODULUS_TOTAL)
1187 {
1188 // Note these rewrites do not need to account for modulus by zero as being
1189 // a UF, which is handled by the reduction of INTS_MODULUS.
1190 Kind k0 = t[0].getKind();
1191 if (k0 == kind::INTS_MODULUS_TOTAL && t[0][1] == t[1])
1192 {
1193 // (mod (mod x c) c) --> (mod x c)
1194 return returnRewrite(t, t[0], Rewrite::MOD_OVER_MOD);
1195 }
1196 else if (k0 == kind::NONLINEAR_MULT || k0 == kind::MULT || k0 == kind::ADD)
1197 {
1198 // can drop all
1199 std::vector<Node> newChildren;
1200 bool childChanged = false;
1201 for (const Node& tc : t[0])
1202 {
1203 if (tc.getKind() == kind::INTS_MODULUS_TOTAL && tc[1] == t[1])
1204 {
1205 newChildren.push_back(tc[0]);
1206 childChanged = true;
1207 continue;
1208 }
1209 newChildren.push_back(tc);
1210 }
1211 if (childChanged)
1212 {
1213 // (mod (op ... (mod x c) ...) c) ---> (mod (op ... x ...) c) where
1214 // op is one of { NONLINEAR_MULT, MULT, ADD }.
1215 Node ret = nm->mkNode(k0, newChildren);
1216 ret = nm->mkNode(kind::INTS_MODULUS_TOTAL, ret, t[1]);
1217 return returnRewrite(t, ret, Rewrite::MOD_CHILD_MOD);
1218 }
1219 }
1220 }
1221 else
1222 {
1223 Assert(k == kind::INTS_DIVISION_TOTAL);
1224 // Note these rewrites do not need to account for division by zero as being
1225 // a UF, which is handled by the reduction of INTS_DIVISION.
1226 if (t[0].getKind() == kind::INTS_MODULUS_TOTAL && t[0][1] == t[1])
1227 {
1228 // (div (mod x c) c) --> 0
1229 Node ret = nm->mkConstInt(0);
1230 return returnRewrite(t, ret, Rewrite::DIV_OVER_MOD);
1231 }
1232 }
1233 return RewriteResponse(REWRITE_DONE, t);
1234 }
1235
1236 TrustNode ArithRewriter::expandDefinition(Node node)
1237 {
1238 // call eliminate operators, to eliminate partial operators only
1239 std::vector<SkolemLemma> lems;
1240 TrustNode ret = d_opElim.eliminate(node, lems, true);
1241 Assert(lems.empty());
1242 return ret;
1243 }
1244
1245 RewriteResponse ArithRewriter::returnRewrite(TNode t, Node ret, Rewrite r)
1246 {
1247 Trace("arith-rewrite") << "ArithRewriter : " << t << " == " << ret << " by "
1248 << r << std::endl;
1249 return RewriteResponse(REWRITE_AGAIN_FULL, ret);
1250 }
1251
1252 } // namespace arith
1253 } // namespace theory
1254 } // namespace cvc5