Change inference scheme in transcendentals to rewrite rule (#8115)
[cvc5.git] / src / theory / arith / arith_rewriter.cpp
1 /******************************************************************************
2 * Top contributors (to current version):
3 * Andrew Reynolds, Tim King, Morgan Deters
4 *
5 * This file is part of the cvc5 project.
6 *
7 * Copyright (c) 2009-2021 by the authors listed in the file AUTHORS
8 * in the top-level source directory and their institutional affiliations.
9 * All rights reserved. See the file COPYING in the top-level source
10 * directory for licensing information.
11 * ****************************************************************************
12 *
13 * [[ Add one-line brief description here ]]
14 *
15 * [[ Add lengthier description here ]]
16 * \todo document this file
17 */
18
19 #include "theory/arith/arith_rewriter.h"
20
21 #include <optional>
22 #include <set>
23 #include <sstream>
24 #include <stack>
25 #include <vector>
26
27 #include "expr/algorithm/flatten.h"
28 #include "smt/logic_exception.h"
29 #include "theory/arith/arith_msum.h"
30 #include "theory/arith/arith_utilities.h"
31 #include "theory/arith/normal_form.h"
32 #include "theory/arith/operator_elim.h"
33 #include "theory/theory.h"
34 #include "util/bitvector.h"
35 #include "util/divisible.h"
36 #include "util/iand.h"
37 #include "util/real_algebraic_number.h"
38
39 using namespace cvc5::kind;
40
41 namespace cvc5 {
42 namespace theory {
43 namespace arith {
44
45 namespace {
46
47 /**
48 * Implements an ordering on arithmetic leaf nodes, excluding rationals. As this
49 * comparator is meant to be used on children of Kind::NONLINEAR_MULT, we expect
50 * rationals to be handled separately. Furthermore, we expect there to be only a
51 * single real algebraic number.
52 * It broadly categorizes leaf nodes into real algebraic numbers, integers,
53 * variables, and the rest. The ordering is built as follows:
54 * - real algebraic numbers come first
55 * - real terms come before integer terms
56 * - variables come before non-variable terms
57 * - finally, fall back to node ordering
58 */
59 struct LeafNodeComparator
60 {
61 /** Implements operator<(a, b) as described above */
62 bool operator()(TNode a, TNode b)
63 {
64 if (a == b) return false;
65
66 bool aIsRAN = a.getKind() == Kind::REAL_ALGEBRAIC_NUMBER;
67 bool bIsRAN = b.getKind() == Kind::REAL_ALGEBRAIC_NUMBER;
68 if (aIsRAN != bIsRAN) return aIsRAN;
69 Assert(!aIsRAN && !bIsRAN) << "real algebraic numbers should be combined";
70
71 bool aIsInt = a.getType().isInteger();
72 bool bIsInt = b.getType().isInteger();
73 if (aIsInt != bIsInt) return !aIsInt;
74
75 bool aIsVar = a.isVar();
76 bool bIsVar = b.isVar();
77 if (aIsVar != bIsVar) return aIsVar;
78
79 return a < b;
80 }
81 };
82
83 /**
84 * Implements an ordering on arithmetic nonlinear multiplications. As we assume
85 * rationals to be handled separately, we only consider Kind::NONLINEAR_MULT as
86 * multiplication terms. For individual factors of the product, we rely on the
87 * ordering from LeafNodeComparator. Furthermore, we expect products to be
88 * sorted according to LeafNodeComparator. The ordering is built as follows:
89 * - single factors come first (everything that is not NONLINEAR_MULT)
90 * - multiplications with less factors come first
91 * - multiplications are compared lexicographically
92 */
93 struct ProductNodeComparator
94 {
95 /** Implements operator<(a, b) as described above */
96 bool operator()(TNode a, TNode b)
97 {
98 if (a == b) return false;
99
100 Assert(a.getKind() != Kind::MULT);
101 Assert(b.getKind() != Kind::MULT);
102
103 bool aIsMult = a.getKind() == Kind::NONLINEAR_MULT;
104 bool bIsMult = b.getKind() == Kind::NONLINEAR_MULT;
105 if (aIsMult != bIsMult) return !aIsMult;
106
107 if (!aIsMult)
108 {
109 return LeafNodeComparator()(a, b);
110 }
111
112 size_t aLen = a.getNumChildren();
113 size_t bLen = b.getNumChildren();
114 if (aLen != bLen) return aLen < bLen;
115
116 for (size_t i = 0; i < aLen; ++i)
117 {
118 if (a[i] != b[i])
119 {
120 return LeafNodeComparator()(a[i], b[i]);
121 }
122 }
123 Unreachable() << "Nodes are different, but have the same content";
124 return false;
125 }
126 };
127
128
129 template <typename L, typename R>
130 bool evaluateRelation(Kind rel, const L& l, const R& r)
131 {
132 switch (rel)
133 {
134 case Kind::LT: return l < r;
135 case Kind::LEQ: return l <= r;
136 case Kind::EQUAL: return l == r;
137 case Kind::GEQ: return l >= r;
138 case Kind::GT: return l > r;
139 default: Unreachable(); return false;
140 }
141 }
142
143 /**
144 * Check whether the parent has a child that is a constant zero.
145 * If so, return this child. Otherwise, return std::nullopt.
146 */
147 template <typename Iterable>
148 std::optional<TNode> getZeroChild(const Iterable& parent)
149 {
150 for (const auto& node : parent)
151 {
152 if (node.isConst() && node.template getConst<Rational>().isZero())
153 {
154 return node;
155 }
156 }
157 return std::nullopt;
158 }
159
160 } // namespace
161
162 ArithRewriter::ArithRewriter(OperatorElim& oe) : d_opElim(oe) {}
163
164 RewriteResponse ArithRewriter::preRewrite(TNode t)
165 {
166 Trace("arith-rewriter") << "preRewrite(" << t << ")" << std::endl;
167 if (isAtom(t))
168 {
169 auto res = preRewriteAtom(t);
170 Trace("arith-rewriter")
171 << res.d_status << " -> " << res.d_node << std::endl;
172 return res;
173 }
174 auto res = preRewriteTerm(t);
175 Trace("arith-rewriter") << res.d_status << " -> " << res.d_node << std::endl;
176 return res;
177 }
178
179 RewriteResponse ArithRewriter::postRewrite(TNode t)
180 {
181 Trace("arith-rewriter") << "postRewrite(" << t << ")" << std::endl;
182 if (isAtom(t))
183 {
184 auto res = postRewriteAtom(t);
185 Trace("arith-rewriter")
186 << res.d_status << " -> " << res.d_node << std::endl;
187 return res;
188 }
189 auto res = postRewriteTerm(t);
190 Trace("arith-rewriter") << res.d_status << " -> " << res.d_node << std::endl;
191 return res;
192 }
193
194 RewriteResponse ArithRewriter::preRewriteAtom(TNode atom)
195 {
196 Assert(isAtom(atom));
197
198 NodeManager* nm = NodeManager::currentNM();
199
200 if (isRelationOperator(atom.getKind()) && atom[0] == atom[1])
201 {
202 switch (atom.getKind())
203 {
204 case Kind::LT: return RewriteResponse(REWRITE_DONE, nm->mkConst(false));
205 case Kind::LEQ: return RewriteResponse(REWRITE_DONE, nm->mkConst(true));
206 case Kind::EQUAL: return RewriteResponse(REWRITE_DONE, nm->mkConst(true));
207 case Kind::GEQ: return RewriteResponse(REWRITE_DONE, nm->mkConst(true));
208 case Kind::GT: return RewriteResponse(REWRITE_DONE, nm->mkConst(false));
209 default:;
210 }
211 }
212
213 switch (atom.getKind())
214 {
215 case Kind::GT:
216 return RewriteResponse(REWRITE_DONE,
217 nm->mkNode(kind::LEQ, atom[0], atom[1]).notNode());
218 case Kind::LT:
219 return RewriteResponse(REWRITE_DONE,
220 nm->mkNode(kind::GEQ, atom[0], atom[1]).notNode());
221 case Kind::IS_INTEGER:
222 if (atom[0].getType().isInteger())
223 {
224 return RewriteResponse(REWRITE_DONE, nm->mkConst(true));
225 }
226 break;
227 case Kind::DIVISIBLE:
228 if (atom.getOperator().getConst<Divisible>().k.isOne())
229 {
230 return RewriteResponse(REWRITE_DONE, nm->mkConst(true));
231 }
232 break;
233 default:;
234 }
235
236 return RewriteResponse(REWRITE_DONE, atom);
237 }
238
239 RewriteResponse ArithRewriter::postRewriteAtom(TNode atom)
240 {
241 Assert(isAtom(atom));
242 if (atom.getKind() == kind::IS_INTEGER)
243 {
244 return rewriteExtIntegerOp(atom);
245 }
246 else if (atom.getKind() == kind::DIVISIBLE)
247 {
248 if (atom[0].isConst())
249 {
250 return RewriteResponse(REWRITE_DONE,
251 NodeManager::currentNM()->mkConst(bool(
252 (atom[0].getConst<Rational>()
253 / atom.getOperator().getConst<Divisible>().k)
254 .isIntegral())));
255 }
256 if (atom.getOperator().getConst<Divisible>().k.isOne())
257 {
258 return RewriteResponse(REWRITE_DONE,
259 NodeManager::currentNM()->mkConst(true));
260 }
261 NodeManager* nm = NodeManager::currentNM();
262 return RewriteResponse(
263 REWRITE_AGAIN,
264 nm->mkNode(kind::EQUAL,
265 nm->mkNode(kind::INTS_MODULUS_TOTAL,
266 atom[0],
267 nm->mkConstInt(Rational(
268 atom.getOperator().getConst<Divisible>().k))),
269 nm->mkConstInt(Rational(0))));
270 }
271
272 // left |><| right
273 TNode left = atom[0];
274 TNode right = atom[1];
275
276 auto* nm = NodeManager::currentNM();
277 if (left.isConst())
278 {
279 const Rational& l = left.getConst<Rational>();
280 if (right.isConst())
281 {
282 const Rational& r = right.getConst<Rational>();
283 return RewriteResponse(
284 REWRITE_DONE, nm->mkConst(evaluateRelation(atom.getKind(), l, r)));
285 }
286 else if (right.getKind() == Kind::REAL_ALGEBRAIC_NUMBER)
287 {
288 const RealAlgebraicNumber& r =
289 right.getOperator().getConst<RealAlgebraicNumber>();
290 return RewriteResponse(
291 REWRITE_DONE, nm->mkConst(evaluateRelation(atom.getKind(), l, r)));
292 }
293 }
294 else if (left.getKind() == Kind::REAL_ALGEBRAIC_NUMBER)
295 {
296 const RealAlgebraicNumber& l =
297 left.getOperator().getConst<RealAlgebraicNumber>();
298 if (right.isConst())
299 {
300 const Rational& r = right.getConst<Rational>();
301 return RewriteResponse(
302 REWRITE_DONE, nm->mkConst(evaluateRelation(atom.getKind(), l, r)));
303 }
304 else if (right.getKind() == Kind::REAL_ALGEBRAIC_NUMBER)
305 {
306 const RealAlgebraicNumber& r =
307 right.getOperator().getConst<RealAlgebraicNumber>();
308 return RewriteResponse(
309 REWRITE_DONE, nm->mkConst(evaluateRelation(atom.getKind(), l, r)));
310 }
311 }
312
313 Polynomial pleft = Polynomial::parsePolynomial(left);
314 Polynomial pright = Polynomial::parsePolynomial(right);
315
316 Debug("arith::rewriter") << "pleft " << pleft.getNode() << std::endl;
317 Debug("arith::rewriter") << "pright " << pright.getNode() << std::endl;
318
319 Comparison cmp = Comparison::mkComparison(atom.getKind(), pleft, pright);
320 Assert(cmp.isNormalForm());
321 return RewriteResponse(REWRITE_DONE, cmp.getNode());
322 }
323
324 bool ArithRewriter::isAtom(TNode n) {
325 Kind k = n.getKind();
326 return arith::isRelationOperator(k) || k == kind::IS_INTEGER
327 || k == kind::DIVISIBLE;
328 }
329
330 RewriteResponse ArithRewriter::rewriteConstant(TNode t){
331 Assert(t.isConst());
332 Assert(t.getKind() == CONST_RATIONAL || t.getKind() == CONST_INTEGER);
333
334 return RewriteResponse(REWRITE_DONE, t);
335 }
336
337 RewriteResponse ArithRewriter::rewriteRAN(TNode t)
338 {
339 Assert(t.getKind() == REAL_ALGEBRAIC_NUMBER);
340
341 const RealAlgebraicNumber& r =
342 t.getOperator().getConst<RealAlgebraicNumber>();
343 if (r.isRational())
344 {
345 return RewriteResponse(
346 REWRITE_DONE, NodeManager::currentNM()->mkConstReal(r.toRational()));
347 }
348
349 return RewriteResponse(REWRITE_DONE, t);
350 }
351
352 RewriteResponse ArithRewriter::rewriteVariable(TNode t){
353 Assert(t.isVar());
354
355 return RewriteResponse(REWRITE_DONE, t);
356 }
357
358 RewriteResponse ArithRewriter::rewriteSub(TNode t)
359 {
360 Assert(t.getKind() == kind::SUB);
361 Assert(t.getNumChildren() == 2);
362
363 auto* nm = NodeManager::currentNM();
364
365 if (t[0] == t[1])
366 {
367 return RewriteResponse(REWRITE_DONE,
368 nm->mkConstRealOrInt(t.getType(), Rational(0)));
369 }
370 return RewriteResponse(REWRITE_AGAIN_FULL,
371 nm->mkNode(Kind::ADD, t[0], makeUnaryMinusNode(t[1])));
372 }
373
374 RewriteResponse ArithRewriter::rewriteNeg(TNode t, bool pre)
375 {
376 Assert(t.getKind() == kind::NEG);
377
378 if (t[0].isConst())
379 {
380 Rational neg = -(t[0].getConst<Rational>());
381 NodeManager* nm = NodeManager::currentNM();
382 return RewriteResponse(REWRITE_DONE,
383 nm->mkConstRealOrInt(t[0].getType(), neg));
384 }
385 if (t[0].getKind() == Kind::REAL_ALGEBRAIC_NUMBER)
386 {
387 const RealAlgebraicNumber& r =
388 t[0].getOperator().getConst<RealAlgebraicNumber>();
389 NodeManager* nm = NodeManager::currentNM();
390 return RewriteResponse(REWRITE_DONE, nm->mkRealAlgebraicNumber(-r));
391 }
392
393 Node noUminus = makeUnaryMinusNode(t[0]);
394 if(pre)
395 return RewriteResponse(REWRITE_DONE, noUminus);
396 else
397 return RewriteResponse(REWRITE_AGAIN, noUminus);
398 }
399
400 RewriteResponse ArithRewriter::preRewriteTerm(TNode t){
401 if(t.isConst()){
402 return rewriteConstant(t);
403 }else if(t.isVar()){
404 return rewriteVariable(t);
405 }else{
406 switch(Kind k = t.getKind()){
407 case kind::REAL_ALGEBRAIC_NUMBER: return rewriteRAN(t);
408 case kind::SUB: return rewriteSub(t);
409 case kind::NEG: return rewriteNeg(t, true);
410 case kind::DIVISION:
411 case kind::DIVISION_TOTAL: return rewriteDiv(t, true);
412 case kind::ADD: return preRewritePlus(t);
413 case kind::MULT:
414 case kind::NONLINEAR_MULT: return preRewriteMult(t);
415 case kind::IAND: return RewriteResponse(REWRITE_DONE, t);
416 case kind::POW2: return RewriteResponse(REWRITE_DONE, t);
417 case kind::EXPONENTIAL:
418 case kind::SINE:
419 case kind::COSINE:
420 case kind::TANGENT:
421 case kind::COSECANT:
422 case kind::SECANT:
423 case kind::COTANGENT:
424 case kind::ARCSINE:
425 case kind::ARCCOSINE:
426 case kind::ARCTANGENT:
427 case kind::ARCCOSECANT:
428 case kind::ARCSECANT:
429 case kind::ARCCOTANGENT:
430 case kind::SQRT: return preRewriteTranscendental(t);
431 case kind::INTS_DIVISION:
432 case kind::INTS_MODULUS: return rewriteIntsDivMod(t, true);
433 case kind::INTS_DIVISION_TOTAL:
434 case kind::INTS_MODULUS_TOTAL: return rewriteIntsDivModTotal(t, true);
435 case kind::ABS: return rewriteAbs(t);
436 case kind::IS_INTEGER:
437 case kind::TO_INTEGER: return RewriteResponse(REWRITE_DONE, t);
438 case kind::TO_REAL:
439 case kind::CAST_TO_REAL: return RewriteResponse(REWRITE_DONE, t[0]);
440 case kind::POW: return RewriteResponse(REWRITE_DONE, t);
441 case kind::PI: return RewriteResponse(REWRITE_DONE, t);
442 default: Unhandled() << k;
443 }
444 }
445 }
446
447 RewriteResponse ArithRewriter::postRewriteTerm(TNode t){
448 if(t.isConst()){
449 return rewriteConstant(t);
450 }else if(t.isVar()){
451 return rewriteVariable(t);
452 }else{
453 Trace("arith-rewriter") << "postRewriteTerm: " << t << std::endl;
454 switch(t.getKind()){
455 case kind::REAL_ALGEBRAIC_NUMBER: return rewriteRAN(t);
456 case kind::SUB: return rewriteSub(t);
457 case kind::NEG: return rewriteNeg(t, false);
458 case kind::DIVISION:
459 case kind::DIVISION_TOTAL: return rewriteDiv(t, false);
460 case kind::ADD: return postRewritePlus(t);
461 case kind::MULT:
462 case kind::NONLINEAR_MULT: return postRewriteMult(t);
463 case kind::IAND: return postRewriteIAnd(t);
464 case kind::POW2: return postRewritePow2(t);
465 case kind::EXPONENTIAL:
466 case kind::SINE:
467 case kind::COSINE:
468 case kind::TANGENT:
469 case kind::COSECANT:
470 case kind::SECANT:
471 case kind::COTANGENT:
472 case kind::ARCSINE:
473 case kind::ARCCOSINE:
474 case kind::ARCTANGENT:
475 case kind::ARCCOSECANT:
476 case kind::ARCSECANT:
477 case kind::ARCCOTANGENT:
478 case kind::SQRT: return postRewriteTranscendental(t);
479 case kind::INTS_DIVISION:
480 case kind::INTS_MODULUS: return rewriteIntsDivMod(t, false);
481 case kind::INTS_DIVISION_TOTAL:
482 case kind::INTS_MODULUS_TOTAL: return rewriteIntsDivModTotal(t, false);
483 case kind::ABS: return rewriteAbs(t);
484 case kind::TO_REAL:
485 case kind::CAST_TO_REAL: return RewriteResponse(REWRITE_DONE, t[0]);
486 case kind::TO_INTEGER: return rewriteExtIntegerOp(t);
487 case kind::POW:
488 {
489 if (t[1].isConst())
490 {
491 const Rational& exp = t[1].getConst<Rational>();
492 TNode base = t[0];
493 if(exp.sgn() == 0){
494 return RewriteResponse(REWRITE_DONE,
495 NodeManager::currentNM()->mkConstRealOrInt(
496 t.getType(), Rational(1)));
497 }else if(exp.sgn() > 0 && exp.isIntegral()){
498 cvc5::Rational r(expr::NodeValue::MAX_CHILDREN);
499 if (exp <= r)
500 {
501 unsigned num = exp.getNumerator().toUnsignedInt();
502 if( num==1 ){
503 return RewriteResponse(REWRITE_AGAIN, base);
504 }else{
505 NodeBuilder nb(kind::MULT);
506 for(unsigned i=0; i < num; ++i){
507 nb << base;
508 }
509 Assert(nb.getNumChildren() > 0);
510 Node mult = nb;
511 return RewriteResponse(REWRITE_AGAIN, mult);
512 }
513 }
514 }
515 }
516 else if (t[0].isConst()
517 && t[0].getConst<Rational>().getNumerator().toUnsignedInt()
518 == 2)
519 {
520 return RewriteResponse(
521 REWRITE_DONE, NodeManager::currentNM()->mkNode(kind::POW2, t[1]));
522 }
523
524 // Todo improve the exception thrown
525 std::stringstream ss;
526 ss << "The exponent of the POW(^) operator can only be a positive "
527 "integral constant below "
528 << (expr::NodeValue::MAX_CHILDREN + 1) << ". ";
529 ss << "Exception occurred in:" << std::endl;
530 ss << " " << t;
531 throw LogicException(ss.str());
532 }
533 case kind::PI:
534 return RewriteResponse(REWRITE_DONE, t);
535 default:
536 Unreachable();
537 }
538 }
539 }
540
541
542 RewriteResponse ArithRewriter::preRewritePlus(TNode t){
543 Assert(t.getKind() == kind::ADD);
544 return RewriteResponse(REWRITE_DONE, expr::algorithm::flatten(t));
545 }
546
547 RewriteResponse ArithRewriter::postRewritePlus(TNode t){
548 Assert(t.getKind() == kind::ADD);
549 Assert(t.getNumChildren() > 1);
550
551 {
552 Node flat = expr::algorithm::flatten(t);
553 if (flat != t)
554 {
555 return RewriteResponse(REWRITE_AGAIN, flat);
556 }
557 }
558
559 Rational rational;
560 RealAlgebraicNumber ran;
561 std::vector<Monomial> monomials;
562 std::vector<Polynomial> polynomials;
563
564 for (const auto& child : t)
565 {
566 if (child.isConst())
567 {
568 if (child.getConst<Rational>().isZero())
569 {
570 continue;
571 }
572 rational += child.getConst<Rational>();
573 }
574 else if (child.getKind() == Kind::REAL_ALGEBRAIC_NUMBER)
575 {
576 ran += child.getOperator().getConst<RealAlgebraicNumber>();
577 }
578 else if (Monomial::isMember(child))
579 {
580 monomials.emplace_back(Monomial::parseMonomial(child));
581 }
582 else
583 {
584 polynomials.emplace_back(Polynomial::parsePolynomial(child));
585 }
586 }
587
588 if(!monomials.empty()){
589 Monomial::sort(monomials);
590 Monomial::combineAdjacentMonomials(monomials);
591 polynomials.emplace_back(Polynomial::mkPolynomial(monomials));
592 }
593 if (!rational.isZero())
594 {
595 polynomials.emplace_back(
596 Polynomial::mkPolynomial(Constant::mkConstant(rational)));
597 }
598
599 Polynomial poly = Polynomial::sumPolynomials(polynomials);
600
601 if (isZero(ran))
602 {
603 return RewriteResponse(REWRITE_DONE, poly.getNode());
604 }
605 if (poly.containsConstant())
606 {
607 ran += RealAlgebraicNumber(poly.getHead().getConstant().getValue());
608 if (!poly.isConstant())
609 {
610 poly = poly.getTail();
611 }
612 }
613
614 auto* nm = NodeManager::currentNM();
615 if (poly.isConstant())
616 {
617 return RewriteResponse(REWRITE_DONE, nm->mkRealAlgebraicNumber(ran));
618 }
619 return RewriteResponse(
620 REWRITE_DONE,
621 nm->mkNode(Kind::ADD, nm->mkRealAlgebraicNumber(ran), poly.getNode()));
622 }
623
624 RewriteResponse ArithRewriter::preRewriteMult(TNode node)
625 {
626 Assert(node.getKind() == kind::MULT
627 || node.getKind() == kind::NONLINEAR_MULT);
628
629 auto res = getZeroChild(node);
630 if (res)
631 {
632 return RewriteResponse(REWRITE_DONE, *res);
633 }
634 return RewriteResponse(REWRITE_DONE, node);
635 }
636
637 RewriteResponse ArithRewriter::postRewriteMult(TNode t){
638 Assert(t.getKind() == kind::MULT || t.getKind() == kind::NONLINEAR_MULT);
639 Assert(t.getNumChildren() >= 2);
640
641 if (auto res = getZeroChild(t); res)
642 {
643 return RewriteResponse(REWRITE_DONE, *res);
644 }
645
646 Rational rational = Rational(1);
647 RealAlgebraicNumber ran = RealAlgebraicNumber(Integer(1));
648 Polynomial poly = Polynomial::mkOne();
649
650 for (const auto& child : t)
651 {
652 if (child.isConst())
653 {
654 if (child.getConst<Rational>().isZero())
655 {
656 return RewriteResponse(REWRITE_DONE, child);
657 }
658 rational *= child.getConst<Rational>();
659 }
660 else if (child.getKind() == Kind::REAL_ALGEBRAIC_NUMBER)
661 {
662 ran *= child.getOperator().getConst<RealAlgebraicNumber>();
663 }
664 else
665 {
666 poly = poly * Polynomial::parsePolynomial(child);
667 }
668 }
669
670 if (!rational.isOne())
671 {
672 poly = poly * rational;
673 }
674 if (isOne(ran))
675 {
676 return RewriteResponse(REWRITE_DONE, poly.getNode());
677 }
678 auto* nm = NodeManager::currentNM();
679 if (poly.isConstant())
680 {
681 ran *= RealAlgebraicNumber(poly.getHead().getConstant().getValue());
682 return RewriteResponse(REWRITE_DONE, nm->mkRealAlgebraicNumber(ran));
683 }
684 return RewriteResponse(
685 REWRITE_DONE,
686 nm->mkNode(
687 Kind::MULT, nm->mkRealAlgebraicNumber(ran), poly.getNode()));
688 }
689
690 RewriteResponse ArithRewriter::postRewritePow2(TNode t)
691 {
692 Assert(t.getKind() == kind::POW2);
693 NodeManager* nm = NodeManager::currentNM();
694 // if constant, we eliminate
695 if (t[0].isConst())
696 {
697 // pow2 is only supported for integers
698 Assert(t[0].getType().isInteger());
699 Integer i = t[0].getConst<Rational>().getNumerator();
700 if (i < 0)
701 {
702 return RewriteResponse(REWRITE_DONE, nm->mkConstInt(Rational(0)));
703 }
704 // (pow2 t) ---> (pow 2 t) and continue rewriting to eliminate pow
705 Node two = nm->mkConstInt(Rational(Integer(2)));
706 Node ret = nm->mkNode(kind::POW, two, t[0]);
707 return RewriteResponse(REWRITE_AGAIN, ret);
708 }
709 return RewriteResponse(REWRITE_DONE, t);
710 }
711
712 RewriteResponse ArithRewriter::postRewriteIAnd(TNode t)
713 {
714 Assert(t.getKind() == kind::IAND);
715 size_t bsize = t.getOperator().getConst<IntAnd>().d_size;
716 NodeManager* nm = NodeManager::currentNM();
717 // if constant, we eliminate
718 if (t[0].isConst() && t[1].isConst())
719 {
720 Node iToBvop = nm->mkConst(IntToBitVector(bsize));
721 Node arg1 = nm->mkNode(kind::INT_TO_BITVECTOR, iToBvop, t[0]);
722 Node arg2 = nm->mkNode(kind::INT_TO_BITVECTOR, iToBvop, t[1]);
723 Node bvand = nm->mkNode(kind::BITVECTOR_AND, arg1, arg2);
724 Node ret = nm->mkNode(kind::BITVECTOR_TO_NAT, bvand);
725 return RewriteResponse(REWRITE_AGAIN_FULL, ret);
726 }
727 else if (t[0] > t[1])
728 {
729 // ((_ iand k) x y) ---> ((_ iand k) y x) if x > y by node ordering
730 Node ret = nm->mkNode(kind::IAND, t.getOperator(), t[1], t[0]);
731 return RewriteResponse(REWRITE_AGAIN, ret);
732 }
733 else if (t[0] == t[1])
734 {
735 // ((_ iand k) x x) ---> (mod x 2^k)
736 Node twok = nm->mkConstInt(Rational(Integer(2).pow(bsize)));
737 Node ret = nm->mkNode(kind::INTS_MODULUS, t[0], twok);
738 return RewriteResponse(REWRITE_AGAIN, ret);
739 }
740 // simplifications involving constants
741 for (unsigned i = 0; i < 2; i++)
742 {
743 if (!t[i].isConst())
744 {
745 continue;
746 }
747 if (t[i].getConst<Rational>().sgn() == 0)
748 {
749 // ((_ iand k) 0 y) ---> 0
750 return RewriteResponse(REWRITE_DONE, t[i]);
751 }
752 if (t[i].getConst<Rational>().getNumerator() == Integer(2).pow(bsize) - 1)
753 {
754 // ((_ iand k) 111...1 y) ---> (mod y 2^k)
755 Node twok = nm->mkConstInt(Rational(Integer(2).pow(bsize)));
756 Node ret = nm->mkNode(kind::INTS_MODULUS, t[1-i], twok);
757 return RewriteResponse(REWRITE_AGAIN, ret);
758 }
759 }
760 return RewriteResponse(REWRITE_DONE, t);
761 }
762
763 RewriteResponse ArithRewriter::preRewriteTranscendental(TNode t) {
764 return RewriteResponse(REWRITE_DONE, t);
765 }
766
767 RewriteResponse ArithRewriter::postRewriteTranscendental(TNode t) {
768 Trace("arith-tf-rewrite") << "Rewrite transcendental function : " << t << std::endl;
769 NodeManager* nm = NodeManager::currentNM();
770 switch( t.getKind() ){
771 case kind::EXPONENTIAL: {
772 if (t[0].isConst())
773 {
774 Node one = nm->mkConstReal(Rational(1));
775 if(t[0].getConst<Rational>().sgn()>=0 && t[0].getType().isInteger() && t[0]!=one){
776 return RewriteResponse(
777 REWRITE_AGAIN,
778 nm->mkNode(kind::POW, nm->mkNode(kind::EXPONENTIAL, one), t[0]));
779 }else{
780 return RewriteResponse(REWRITE_DONE, t);
781 }
782 }
783 else if (t[0].getKind() == kind::ADD)
784 {
785 std::vector<Node> product;
786 for (const Node tc : t[0])
787 {
788 product.push_back(nm->mkNode(kind::EXPONENTIAL, tc));
789 }
790 // We need to do a full rewrite here, since we can get exponentials of
791 // constants, e.g. when we are rewriting exp(2 + x)
792 return RewriteResponse(REWRITE_AGAIN_FULL,
793 nm->mkNode(kind::MULT, product));
794 }
795 }
796 break;
797 case kind::SINE:
798 if (t[0].isConst())
799 {
800 const Rational& rat = t[0].getConst<Rational>();
801 if(rat.sgn() == 0){
802 return RewriteResponse(REWRITE_DONE, nm->mkConstReal(Rational(0)));
803 }
804 else if (rat.sgn() == -1)
805 {
806 Node ret = nm->mkNode(kind::NEG,
807 nm->mkNode(kind::SINE, nm->mkConstReal(-rat)));
808 return RewriteResponse(REWRITE_AGAIN_FULL, ret);
809 }
810 }
811 else if ((t[0].getKind() == MULT || t[0].getKind() == NONLINEAR_MULT)
812 && t[0][0].isConst() && t[0][0].getConst<Rational>().sgn() == -1)
813 {
814 // sin(-n*x) ---> -sin(n*x)
815 std::vector<Node> mchildren(t[0].begin(), t[0].end());
816 mchildren[0] = nm->mkConstReal(-t[0][0].getConst<Rational>());
817 Node ret = nm->mkNode(
818 kind::NEG,
819 nm->mkNode(kind::SINE, nm->mkNode(t[0].getKind(), mchildren)));
820 return RewriteResponse(REWRITE_AGAIN_FULL, ret);
821 }
822 else
823 {
824 // get the factor of PI in the argument
825 Node pi_factor;
826 Node pi;
827 Node rem;
828 std::map<Node, Node> msum;
829 if (ArithMSum::getMonomialSum(t[0], msum))
830 {
831 pi = mkPi();
832 std::map<Node, Node>::iterator itm = msum.find(pi);
833 if (itm != msum.end())
834 {
835 if (itm->second.isNull())
836 {
837 pi_factor = nm->mkConstReal(Rational(1));
838 }
839 else
840 {
841 pi_factor = itm->second;
842 }
843 msum.erase(pi);
844 if (!msum.empty())
845 {
846 rem = ArithMSum::mkNode(t[0].getType(), msum);
847 }
848 }
849 }
850 else
851 {
852 Assert(false);
853 }
854
855 // if there is a factor of PI
856 if( !pi_factor.isNull() ){
857 Trace("arith-tf-rewrite-debug") << "Process pi factor = " << pi_factor << std::endl;
858 Rational r = pi_factor.getConst<Rational>();
859 Rational r_abs = r.abs();
860 Rational rone = Rational(1);
861 Rational rtwo = Rational(2);
862 if (r_abs > rone)
863 {
864 //add/substract 2*pi beyond scope
865 Rational ra_div_two = (r_abs + rone) / rtwo;
866 Node new_pi_factor;
867 if (r.sgn() == 1)
868 {
869 new_pi_factor = nm->mkConstReal(r - rtwo * ra_div_two.floor());
870 }
871 else
872 {
873 Assert(r.sgn() == -1);
874 new_pi_factor = nm->mkConstReal(r + rtwo * ra_div_two.floor());
875 }
876 Node new_arg = nm->mkNode(kind::MULT, new_pi_factor, pi);
877 if (!rem.isNull())
878 {
879 new_arg = nm->mkNode(kind::ADD, new_arg, rem);
880 }
881 // sin( 2*n*PI + x ) = sin( x )
882 return RewriteResponse(REWRITE_AGAIN_FULL,
883 nm->mkNode(kind::SINE, new_arg));
884 }
885 else if (r_abs == rone)
886 {
887 // sin( PI + x ) = -sin( x )
888 if (rem.isNull())
889 {
890 return RewriteResponse(REWRITE_DONE, nm->mkConstReal(Rational(0)));
891 }
892 else
893 {
894 return RewriteResponse(
895 REWRITE_AGAIN_FULL,
896 nm->mkNode(kind::NEG, nm->mkNode(kind::SINE, rem)));
897 }
898 }
899 else if (rem.isNull())
900 {
901 // other rational cases based on Niven's theorem
902 // (https://en.wikipedia.org/wiki/Niven%27s_theorem)
903 Integer one = Integer(1);
904 Integer two = Integer(2);
905 Integer six = Integer(6);
906 if (r_abs.getDenominator() == two)
907 {
908 Assert(r_abs.getNumerator() == one);
909 return RewriteResponse(REWRITE_DONE,
910 nm->mkConstReal(Rational(r.sgn())));
911 }
912 else if (r_abs.getDenominator() == six)
913 {
914 Integer five = Integer(5);
915 if (r_abs.getNumerator() == one || r_abs.getNumerator() == five)
916 {
917 return RewriteResponse(
918 REWRITE_DONE,
919 nm->mkConstReal(Rational(r.sgn()) / Rational(2)));
920 }
921 }
922 }
923 }
924 }
925 break;
926 case kind::COSINE: {
927 return RewriteResponse(
928 REWRITE_AGAIN_FULL,
929 nm->mkNode(
930 kind::SINE,
931 nm->mkNode(kind::SUB,
932 nm->mkNode(kind::MULT,
933 nm->mkConstReal(Rational(1) / Rational(2)),
934 mkPi()),
935 t[0])));
936 }
937 break;
938 case kind::TANGENT:
939 {
940 return RewriteResponse(REWRITE_AGAIN_FULL,
941 nm->mkNode(kind::DIVISION,
942 nm->mkNode(kind::SINE, t[0]),
943 nm->mkNode(kind::COSINE, t[0])));
944 }
945 break;
946 case kind::COSECANT:
947 {
948 return RewriteResponse(REWRITE_AGAIN_FULL,
949 nm->mkNode(kind::DIVISION,
950 nm->mkConstReal(Rational(1)),
951 nm->mkNode(kind::SINE, t[0])));
952 }
953 break;
954 case kind::SECANT:
955 {
956 return RewriteResponse(REWRITE_AGAIN_FULL,
957 nm->mkNode(kind::DIVISION,
958 nm->mkConstReal(Rational(1)),
959 nm->mkNode(kind::COSINE, t[0])));
960 }
961 break;
962 case kind::COTANGENT:
963 {
964 return RewriteResponse(REWRITE_AGAIN_FULL,
965 nm->mkNode(kind::DIVISION,
966 nm->mkNode(kind::COSINE, t[0]),
967 nm->mkNode(kind::SINE, t[0])));
968 }
969 break;
970 default:
971 break;
972 }
973 return RewriteResponse(REWRITE_DONE, t);
974 }
975
976 Node ArithRewriter::makeUnaryMinusNode(TNode n){
977 NodeManager* nm = NodeManager::currentNM();
978 Rational qNegOne(-1);
979 return nm->mkNode(kind::MULT, nm->mkConstRealOrInt(n.getType(), qNegOne), n);
980 }
981
982 RewriteResponse ArithRewriter::rewriteDiv(TNode t, bool pre){
983 Assert(t.getKind() == kind::DIVISION_TOTAL || t.getKind() == kind::DIVISION);
984 Assert(t.getNumChildren() == 2);
985
986 Node left = t[0];
987 Node right = t[1];
988 if (right.isConst())
989 {
990 NodeManager* nm = NodeManager::currentNM();
991 const Rational& den = right.getConst<Rational>();
992
993 if(den.isZero()){
994 if(t.getKind() == kind::DIVISION_TOTAL){
995 return RewriteResponse(REWRITE_DONE, nm->mkConstReal(0));
996 }else{
997 // This is unsupported, but this is not a good place to complain
998 return RewriteResponse(REWRITE_DONE, t);
999 }
1000 }
1001 Assert(den != Rational(0));
1002
1003 if (left.isConst())
1004 {
1005 const Rational& num = left.getConst<Rational>();
1006 return RewriteResponse(REWRITE_DONE, nm->mkConstReal(num / den));
1007 }
1008 if (left.getKind() == Kind::REAL_ALGEBRAIC_NUMBER)
1009 {
1010 const RealAlgebraicNumber& num =
1011 left.getOperator().getConst<RealAlgebraicNumber>();
1012 return RewriteResponse(
1013 REWRITE_DONE,
1014 nm->mkRealAlgebraicNumber(num / RealAlgebraicNumber(den)));
1015 }
1016
1017 Node result = nm->mkConstReal(den.inverse());
1018 Node mult = NodeManager::currentNM()->mkNode(kind::MULT, left, result);
1019 if (pre)
1020 {
1021 return RewriteResponse(REWRITE_DONE, mult);
1022 }
1023 else
1024 {
1025 return RewriteResponse(REWRITE_AGAIN, mult);
1026 }
1027 }
1028 if (right.getKind() == Kind::REAL_ALGEBRAIC_NUMBER)
1029 {
1030 NodeManager* nm = NodeManager::currentNM();
1031 const RealAlgebraicNumber& den =
1032 right.getOperator().getConst<RealAlgebraicNumber>();
1033 if (left.isConst())
1034 {
1035 const Rational& num = left.getConst<Rational>();
1036 return RewriteResponse(
1037 REWRITE_DONE,
1038 nm->mkRealAlgebraicNumber(RealAlgebraicNumber(num) / den));
1039 }
1040 if (left.getKind() == Kind::REAL_ALGEBRAIC_NUMBER)
1041 {
1042 const RealAlgebraicNumber& num =
1043 left.getOperator().getConst<RealAlgebraicNumber>();
1044 return RewriteResponse(REWRITE_DONE,
1045 nm->mkRealAlgebraicNumber(num / den));
1046 }
1047
1048 Node result = nm->mkRealAlgebraicNumber(inverse(den));
1049 Node mult = NodeManager::currentNM()->mkNode(kind::MULT,left,result);
1050 if(pre){
1051 return RewriteResponse(REWRITE_DONE, mult);
1052 }else{
1053 return RewriteResponse(REWRITE_AGAIN, mult);
1054 }
1055 }
1056 return RewriteResponse(REWRITE_DONE, t);
1057 }
1058
1059 RewriteResponse ArithRewriter::rewriteAbs(TNode t)
1060 {
1061 Assert(t.getKind() == Kind::ABS);
1062 Assert(t.getNumChildren() == 1);
1063
1064 if (t[0].isConst())
1065 {
1066 const Rational& rat = t[0].getConst<Rational>();
1067 if (rat >= 0)
1068 {
1069 return RewriteResponse(REWRITE_DONE, t[0]);
1070 }
1071 return RewriteResponse(
1072 REWRITE_DONE,
1073 NodeManager::currentNM()->mkConstRealOrInt(t[0].getType(), -rat));
1074 }
1075 if (t[0].getKind() == Kind::REAL_ALGEBRAIC_NUMBER)
1076 {
1077 const RealAlgebraicNumber& ran =
1078 t[0].getOperator().getConst<RealAlgebraicNumber>();
1079 if (ran >= RealAlgebraicNumber())
1080 {
1081 return RewriteResponse(REWRITE_DONE, t[0]);
1082 }
1083 return RewriteResponse(
1084 REWRITE_DONE, NodeManager::currentNM()->mkRealAlgebraicNumber(-ran));
1085 }
1086 return RewriteResponse(REWRITE_DONE, t);
1087 }
1088
1089 RewriteResponse ArithRewriter::rewriteIntsDivMod(TNode t, bool pre)
1090 {
1091 NodeManager* nm = NodeManager::currentNM();
1092 Kind k = t.getKind();
1093 if (k == kind::INTS_MODULUS)
1094 {
1095 if (t[1].isConst() && !t[1].getConst<Rational>().isZero())
1096 {
1097 // can immediately replace by INTS_MODULUS_TOTAL
1098 Node ret = nm->mkNode(kind::INTS_MODULUS_TOTAL, t[0], t[1]);
1099 return returnRewrite(t, ret, Rewrite::MOD_TOTAL_BY_CONST);
1100 }
1101 }
1102 if (k == kind::INTS_DIVISION)
1103 {
1104 if (t[1].isConst() && !t[1].getConst<Rational>().isZero())
1105 {
1106 // can immediately replace by INTS_DIVISION_TOTAL
1107 Node ret = nm->mkNode(kind::INTS_DIVISION_TOTAL, t[0], t[1]);
1108 return returnRewrite(t, ret, Rewrite::DIV_TOTAL_BY_CONST);
1109 }
1110 }
1111 return RewriteResponse(REWRITE_DONE, t);
1112 }
1113
1114 RewriteResponse ArithRewriter::rewriteExtIntegerOp(TNode t)
1115 {
1116 Assert(t.getKind() == kind::TO_INTEGER || t.getKind() == kind::IS_INTEGER);
1117 bool isPred = t.getKind() == kind::IS_INTEGER;
1118 NodeManager* nm = NodeManager::currentNM();
1119 if (t[0].isConst())
1120 {
1121 Node ret;
1122 if (isPred)
1123 {
1124 ret = nm->mkConst(t[0].getConst<Rational>().isIntegral());
1125 }
1126 else
1127 {
1128 ret = nm->mkConstInt(Rational(t[0].getConst<Rational>().floor()));
1129 }
1130 return returnRewrite(t, ret, Rewrite::INT_EXT_CONST);
1131 }
1132 if (t[0].getType().isInteger())
1133 {
1134 Node ret = isPred ? nm->mkConst(true) : Node(t[0]);
1135 return returnRewrite(t, ret, Rewrite::INT_EXT_INT);
1136 }
1137 if (t[0].getKind() == kind::PI)
1138 {
1139 Node ret = isPred ? nm->mkConst(false) : nm->mkConstReal(Rational(3));
1140 return returnRewrite(t, ret, Rewrite::INT_EXT_PI);
1141 }
1142 return RewriteResponse(REWRITE_DONE, t);
1143 }
1144
1145 RewriteResponse ArithRewriter::rewriteIntsDivModTotal(TNode t, bool pre)
1146 {
1147 if (pre)
1148 {
1149 // do not rewrite at prewrite.
1150 return RewriteResponse(REWRITE_DONE, t);
1151 }
1152 NodeManager* nm = NodeManager::currentNM();
1153 Kind k = t.getKind();
1154 Assert(k == kind::INTS_MODULUS_TOTAL || k == kind::INTS_DIVISION_TOTAL);
1155 TNode n = t[0];
1156 TNode d = t[1];
1157 bool dIsConstant = d.isConst();
1158 if(dIsConstant && d.getConst<Rational>().isZero()){
1159 // (div x 0) ---> 0 or (mod x 0) ---> 0
1160 return returnRewrite(t, nm->mkConstInt(0), Rewrite::DIV_MOD_BY_ZERO);
1161 }else if(dIsConstant && d.getConst<Rational>().isOne()){
1162 if (k == kind::INTS_MODULUS_TOTAL)
1163 {
1164 // (mod x 1) --> 0
1165 return returnRewrite(t, nm->mkConstInt(0), Rewrite::MOD_BY_ONE);
1166 }
1167 Assert(k == kind::INTS_DIVISION_TOTAL);
1168 // (div x 1) --> x
1169 return returnRewrite(t, n, Rewrite::DIV_BY_ONE);
1170 }
1171 else if (dIsConstant && d.getConst<Rational>().sgn() < 0)
1172 {
1173 // pull negation
1174 // (div x (- c)) ---> (- (div x c))
1175 // (mod x (- c)) ---> (mod x c)
1176 Node nn = nm->mkNode(k, t[0], nm->mkConstInt(-t[1].getConst<Rational>()));
1177 Node ret = (k == kind::INTS_DIVISION || k == kind::INTS_DIVISION_TOTAL)
1178 ? nm->mkNode(kind::NEG, nn)
1179 : nn;
1180 return returnRewrite(t, ret, Rewrite::DIV_MOD_PULL_NEG_DEN);
1181 }
1182 else if (dIsConstant && n.isConst())
1183 {
1184 Assert(d.getConst<Rational>().isIntegral());
1185 Assert(n.getConst<Rational>().isIntegral());
1186 Assert(!d.getConst<Rational>().isZero());
1187 Integer di = d.getConst<Rational>().getNumerator();
1188 Integer ni = n.getConst<Rational>().getNumerator();
1189
1190 bool isDiv = (k == kind::INTS_DIVISION || k == kind::INTS_DIVISION_TOTAL);
1191
1192 Integer result = isDiv ? ni.euclidianDivideQuotient(di) : ni.euclidianDivideRemainder(di);
1193
1194 // constant evaluation
1195 // (mod c1 c2) ---> c3 or (div c1 c2) ---> c3
1196 Node resultNode = nm->mkConstInt(Rational(result));
1197 return returnRewrite(t, resultNode, Rewrite::CONST_EVAL);
1198 }
1199 if (k == kind::INTS_MODULUS_TOTAL)
1200 {
1201 // Note these rewrites do not need to account for modulus by zero as being
1202 // a UF, which is handled by the reduction of INTS_MODULUS.
1203 Kind k0 = t[0].getKind();
1204 if (k0 == kind::INTS_MODULUS_TOTAL && t[0][1] == t[1])
1205 {
1206 // (mod (mod x c) c) --> (mod x c)
1207 return returnRewrite(t, t[0], Rewrite::MOD_OVER_MOD);
1208 }
1209 else if (k0 == kind::NONLINEAR_MULT || k0 == kind::MULT || k0 == kind::ADD)
1210 {
1211 // can drop all
1212 std::vector<Node> newChildren;
1213 bool childChanged = false;
1214 for (const Node& tc : t[0])
1215 {
1216 if (tc.getKind() == kind::INTS_MODULUS_TOTAL && tc[1] == t[1])
1217 {
1218 newChildren.push_back(tc[0]);
1219 childChanged = true;
1220 continue;
1221 }
1222 newChildren.push_back(tc);
1223 }
1224 if (childChanged)
1225 {
1226 // (mod (op ... (mod x c) ...) c) ---> (mod (op ... x ...) c) where
1227 // op is one of { NONLINEAR_MULT, MULT, ADD }.
1228 Node ret = nm->mkNode(k0, newChildren);
1229 ret = nm->mkNode(kind::INTS_MODULUS_TOTAL, ret, t[1]);
1230 return returnRewrite(t, ret, Rewrite::MOD_CHILD_MOD);
1231 }
1232 }
1233 }
1234 else
1235 {
1236 Assert(k == kind::INTS_DIVISION_TOTAL);
1237 // Note these rewrites do not need to account for division by zero as being
1238 // a UF, which is handled by the reduction of INTS_DIVISION.
1239 if (t[0].getKind() == kind::INTS_MODULUS_TOTAL && t[0][1] == t[1])
1240 {
1241 // (div (mod x c) c) --> 0
1242 Node ret = nm->mkConstInt(0);
1243 return returnRewrite(t, ret, Rewrite::DIV_OVER_MOD);
1244 }
1245 }
1246 return RewriteResponse(REWRITE_DONE, t);
1247 }
1248
1249 TrustNode ArithRewriter::expandDefinition(Node node)
1250 {
1251 // call eliminate operators, to eliminate partial operators only
1252 std::vector<SkolemLemma> lems;
1253 TrustNode ret = d_opElim.eliminate(node, lems, true);
1254 Assert(lems.empty());
1255 return ret;
1256 }
1257
1258 RewriteResponse ArithRewriter::returnRewrite(TNode t, Node ret, Rewrite r)
1259 {
1260 Trace("arith-rewrite") << "ArithRewriter : " << t << " == " << ret << " by "
1261 << r << std::endl;
1262 return RewriteResponse(REWRITE_AGAIN_FULL, ret);
1263 }
1264
1265 } // namespace arith
1266 } // namespace theory
1267 } // namespace cvc5