Fix rewrite for eliminating constant factors of PI from argument to sine (#8031)
[cvc5.git] / src / theory / arith / arith_rewriter.cpp
1 /******************************************************************************
2 * Top contributors (to current version):
3 * Andrew Reynolds, Tim King, Morgan Deters
4 *
5 * This file is part of the cvc5 project.
6 *
7 * Copyright (c) 2009-2021 by the authors listed in the file AUTHORS
8 * in the top-level source directory and their institutional affiliations.
9 * All rights reserved. See the file COPYING in the top-level source
10 * directory for licensing information.
11 * ****************************************************************************
12 *
13 * [[ Add one-line brief description here ]]
14 *
15 * [[ Add lengthier description here ]]
16 * \todo document this file
17 */
18
19 #include "theory/arith/arith_rewriter.h"
20
21 #include <optional>
22 #include <set>
23 #include <sstream>
24 #include <stack>
25 #include <vector>
26
27 #include "expr/algorithm/flatten.h"
28 #include "smt/logic_exception.h"
29 #include "theory/arith/arith_msum.h"
30 #include "theory/arith/arith_utilities.h"
31 #include "theory/arith/normal_form.h"
32 #include "theory/arith/operator_elim.h"
33 #include "theory/theory.h"
34 #include "util/bitvector.h"
35 #include "util/divisible.h"
36 #include "util/iand.h"
37 #include "util/real_algebraic_number.h"
38
39 using namespace cvc5::kind;
40
41 namespace cvc5 {
42 namespace theory {
43 namespace arith {
44
45 namespace {
46
47 /**
48 * Implements an ordering on arithmetic leaf nodes, excluding rationals. As this
49 * comparator is meant to be used on children of Kind::NONLINEAR_MULT, we expect
50 * rationals to be handled separately. Furthermore, we expect there to be only a
51 * single real algebraic number.
52 * It broadly categorizes leaf nodes into real algebraic numbers, integers,
53 * variables, and the rest. The ordering is built as follows:
54 * - real algebraic numbers come first
55 * - real terms come before integer terms
56 * - variables come before non-variable terms
57 * - finally, fall back to node ordering
58 */
59 struct LeafNodeComparator
60 {
61 /** Implements operator<(a, b) as described above */
62 bool operator()(TNode a, TNode b)
63 {
64 if (a == b) return false;
65
66 bool aIsRAN = a.getKind() == Kind::REAL_ALGEBRAIC_NUMBER;
67 bool bIsRAN = b.getKind() == Kind::REAL_ALGEBRAIC_NUMBER;
68 if (aIsRAN != bIsRAN) return aIsRAN;
69 Assert(!aIsRAN && !bIsRAN) << "real algebraic numbers should be combined";
70
71 bool aIsInt = a.getType().isInteger();
72 bool bIsInt = b.getType().isInteger();
73 if (aIsInt != bIsInt) return !aIsInt;
74
75 bool aIsVar = a.isVar();
76 bool bIsVar = b.isVar();
77 if (aIsVar != bIsVar) return aIsVar;
78
79 return a < b;
80 }
81 };
82
83 /**
84 * Implements an ordering on arithmetic nonlinear multiplications. As we assume
85 * rationals to be handled separately, we only consider Kind::NONLINEAR_MULT as
86 * multiplication terms. For individual factors of the product, we rely on the
87 * ordering from LeafNodeComparator. Furthermore, we expect products to be
88 * sorted according to LeafNodeComparator. The ordering is built as follows:
89 * - single factors come first (everything that is not NONLINEAR_MULT)
90 * - multiplications with less factors come first
91 * - multiplications are compared lexicographically
92 */
93 struct ProductNodeComparator
94 {
95 /** Implements operator<(a, b) as described above */
96 bool operator()(TNode a, TNode b)
97 {
98 if (a == b) return false;
99
100 Assert(a.getKind() != Kind::MULT);
101 Assert(b.getKind() != Kind::MULT);
102
103 bool aIsMult = a.getKind() == Kind::NONLINEAR_MULT;
104 bool bIsMult = b.getKind() == Kind::NONLINEAR_MULT;
105 if (aIsMult != bIsMult) return !aIsMult;
106
107 if (!aIsMult)
108 {
109 return LeafNodeComparator()(a, b);
110 }
111
112 size_t aLen = a.getNumChildren();
113 size_t bLen = b.getNumChildren();
114 if (aLen != bLen) return aLen < bLen;
115
116 for (size_t i = 0; i < aLen; ++i)
117 {
118 if (a[i] != b[i])
119 {
120 return LeafNodeComparator()(a[i], b[i]);
121 }
122 }
123 Unreachable() << "Nodes are different, but have the same content";
124 return false;
125 }
126 };
127
128
129 template <typename L, typename R>
130 bool evaluateRelation(Kind rel, const L& l, const R& r)
131 {
132 switch (rel)
133 {
134 case Kind::LT: return l < r;
135 case Kind::LEQ: return l <= r;
136 case Kind::EQUAL: return l == r;
137 case Kind::GEQ: return l >= r;
138 case Kind::GT: return l > r;
139 default: Unreachable(); return false;
140 }
141 }
142
143 /**
144 * Check whether the parent has a child that is a constant zero.
145 * If so, return this child. Otherwise, return std::nullopt.
146 */
147 template <typename Iterable>
148 std::optional<TNode> getZeroChild(const Iterable& parent)
149 {
150 for (const auto& node : parent)
151 {
152 if (node.isConst() && node.template getConst<Rational>().isZero())
153 {
154 return node;
155 }
156 }
157 return std::nullopt;
158 }
159
160 } // namespace
161
162 ArithRewriter::ArithRewriter(OperatorElim& oe) : d_opElim(oe) {}
163
164 RewriteResponse ArithRewriter::preRewrite(TNode t)
165 {
166 Trace("arith-rewriter") << "preRewrite(" << t << ")" << std::endl;
167 if (isAtom(t))
168 {
169 auto res = preRewriteAtom(t);
170 Trace("arith-rewriter")
171 << res.d_status << " -> " << res.d_node << std::endl;
172 return res;
173 }
174 auto res = preRewriteTerm(t);
175 Trace("arith-rewriter") << res.d_status << " -> " << res.d_node << std::endl;
176 return res;
177 }
178
179 RewriteResponse ArithRewriter::postRewrite(TNode t)
180 {
181 Trace("arith-rewriter") << "postRewrite(" << t << ")" << std::endl;
182 if (isAtom(t))
183 {
184 auto res = postRewriteAtom(t);
185 Trace("arith-rewriter")
186 << res.d_status << " -> " << res.d_node << std::endl;
187 return res;
188 }
189 auto res = postRewriteTerm(t);
190 Trace("arith-rewriter") << res.d_status << " -> " << res.d_node << std::endl;
191 return res;
192 }
193
194 RewriteResponse ArithRewriter::preRewriteAtom(TNode atom)
195 {
196 Assert(isAtom(atom));
197
198 NodeManager* nm = NodeManager::currentNM();
199
200 if (isRelationOperator(atom.getKind()) && atom[0] == atom[1])
201 {
202 switch (atom.getKind())
203 {
204 case Kind::LT: return RewriteResponse(REWRITE_DONE, nm->mkConst(false));
205 case Kind::LEQ: return RewriteResponse(REWRITE_DONE, nm->mkConst(true));
206 case Kind::EQUAL: return RewriteResponse(REWRITE_DONE, nm->mkConst(true));
207 case Kind::GEQ: return RewriteResponse(REWRITE_DONE, nm->mkConst(true));
208 case Kind::GT: return RewriteResponse(REWRITE_DONE, nm->mkConst(false));
209 default:;
210 }
211 }
212
213 switch (atom.getKind())
214 {
215 case Kind::GT:
216 return RewriteResponse(REWRITE_DONE,
217 nm->mkNode(kind::LEQ, atom[0], atom[1]).notNode());
218 case Kind::LT:
219 return RewriteResponse(REWRITE_DONE,
220 nm->mkNode(kind::GEQ, atom[0], atom[1]).notNode());
221 case Kind::IS_INTEGER:
222 if (atom[0].getType().isInteger())
223 {
224 return RewriteResponse(REWRITE_DONE, nm->mkConst(true));
225 }
226 break;
227 case Kind::DIVISIBLE:
228 if (atom.getOperator().getConst<Divisible>().k.isOne())
229 {
230 return RewriteResponse(REWRITE_DONE, nm->mkConst(true));
231 }
232 break;
233 default:;
234 }
235
236 return RewriteResponse(REWRITE_DONE, atom);
237 }
238
239 RewriteResponse ArithRewriter::postRewriteAtom(TNode atom)
240 {
241 Assert(isAtom(atom));
242 if (atom.getKind() == kind::IS_INTEGER)
243 {
244 return rewriteExtIntegerOp(atom);
245 }
246 else if (atom.getKind() == kind::DIVISIBLE)
247 {
248 if (atom[0].isConst())
249 {
250 return RewriteResponse(REWRITE_DONE,
251 NodeManager::currentNM()->mkConst(bool(
252 (atom[0].getConst<Rational>()
253 / atom.getOperator().getConst<Divisible>().k)
254 .isIntegral())));
255 }
256 if (atom.getOperator().getConst<Divisible>().k.isOne())
257 {
258 return RewriteResponse(REWRITE_DONE,
259 NodeManager::currentNM()->mkConst(true));
260 }
261 NodeManager* nm = NodeManager::currentNM();
262 return RewriteResponse(
263 REWRITE_AGAIN,
264 nm->mkNode(kind::EQUAL,
265 nm->mkNode(kind::INTS_MODULUS_TOTAL,
266 atom[0],
267 nm->mkConstInt(Rational(
268 atom.getOperator().getConst<Divisible>().k))),
269 nm->mkConstInt(Rational(0))));
270 }
271
272 // left |><| right
273 TNode left = atom[0];
274 TNode right = atom[1];
275
276 auto* nm = NodeManager::currentNM();
277 if (left.isConst())
278 {
279 const Rational& l = left.getConst<Rational>();
280 if (right.isConst())
281 {
282 const Rational& r = right.getConst<Rational>();
283 return RewriteResponse(
284 REWRITE_DONE, nm->mkConst(evaluateRelation(atom.getKind(), l, r)));
285 }
286 else if (right.getKind() == Kind::REAL_ALGEBRAIC_NUMBER)
287 {
288 const RealAlgebraicNumber& r =
289 right.getOperator().getConst<RealAlgebraicNumber>();
290 return RewriteResponse(
291 REWRITE_DONE, nm->mkConst(evaluateRelation(atom.getKind(), l, r)));
292 }
293 }
294 else if (left.getKind() == Kind::REAL_ALGEBRAIC_NUMBER)
295 {
296 const RealAlgebraicNumber& l =
297 left.getOperator().getConst<RealAlgebraicNumber>();
298 if (right.isConst())
299 {
300 const Rational& r = right.getConst<Rational>();
301 return RewriteResponse(
302 REWRITE_DONE, nm->mkConst(evaluateRelation(atom.getKind(), l, r)));
303 }
304 else if (right.getKind() == Kind::REAL_ALGEBRAIC_NUMBER)
305 {
306 const RealAlgebraicNumber& r =
307 right.getOperator().getConst<RealAlgebraicNumber>();
308 return RewriteResponse(
309 REWRITE_DONE, nm->mkConst(evaluateRelation(atom.getKind(), l, r)));
310 }
311 }
312
313 Polynomial pleft = Polynomial::parsePolynomial(left);
314 Polynomial pright = Polynomial::parsePolynomial(right);
315
316 Debug("arith::rewriter") << "pleft " << pleft.getNode() << std::endl;
317 Debug("arith::rewriter") << "pright " << pright.getNode() << std::endl;
318
319 Comparison cmp = Comparison::mkComparison(atom.getKind(), pleft, pright);
320 Assert(cmp.isNormalForm());
321 return RewriteResponse(REWRITE_DONE, cmp.getNode());
322 }
323
324 bool ArithRewriter::isAtom(TNode n) {
325 Kind k = n.getKind();
326 return arith::isRelationOperator(k) || k == kind::IS_INTEGER
327 || k == kind::DIVISIBLE;
328 }
329
330 RewriteResponse ArithRewriter::rewriteConstant(TNode t){
331 Assert(t.isConst());
332 Assert(t.getKind() == CONST_RATIONAL || t.getKind() == CONST_INTEGER);
333
334 return RewriteResponse(REWRITE_DONE, t);
335 }
336
337 RewriteResponse ArithRewriter::rewriteRAN(TNode t)
338 {
339 Assert(t.getKind() == REAL_ALGEBRAIC_NUMBER);
340
341 const RealAlgebraicNumber& r =
342 t.getOperator().getConst<RealAlgebraicNumber>();
343 if (r.isRational())
344 {
345 return RewriteResponse(
346 REWRITE_DONE, NodeManager::currentNM()->mkConstReal(r.toRational()));
347 }
348
349 return RewriteResponse(REWRITE_DONE, t);
350 }
351
352 RewriteResponse ArithRewriter::rewriteVariable(TNode t){
353 Assert(t.isVar());
354
355 return RewriteResponse(REWRITE_DONE, t);
356 }
357
358 RewriteResponse ArithRewriter::rewriteSub(TNode t)
359 {
360 Assert(t.getKind() == kind::SUB);
361 Assert(t.getNumChildren() == 2);
362
363 auto* nm = NodeManager::currentNM();
364
365 if (t[0] == t[1])
366 {
367 return RewriteResponse(REWRITE_DONE,
368 nm->mkConstRealOrInt(t.getType(), Rational(0)));
369 }
370 return RewriteResponse(
371 REWRITE_AGAIN_FULL,
372 nm->mkNode(Kind::PLUS, t[0], makeUnaryMinusNode(t[1])));
373 }
374
375 RewriteResponse ArithRewriter::rewriteNeg(TNode t, bool pre)
376 {
377 Assert(t.getKind() == kind::NEG);
378
379 if (t[0].isConst())
380 {
381 Rational neg = -(t[0].getConst<Rational>());
382 NodeManager* nm = NodeManager::currentNM();
383 return RewriteResponse(REWRITE_DONE,
384 nm->mkConstRealOrInt(t[0].getType(), neg));
385 }
386 if (t[0].getKind() == Kind::REAL_ALGEBRAIC_NUMBER)
387 {
388 const RealAlgebraicNumber& r =
389 t[0].getOperator().getConst<RealAlgebraicNumber>();
390 NodeManager* nm = NodeManager::currentNM();
391 return RewriteResponse(REWRITE_DONE, nm->mkRealAlgebraicNumber(-r));
392 }
393
394 Node noUminus = makeUnaryMinusNode(t[0]);
395 if(pre)
396 return RewriteResponse(REWRITE_DONE, noUminus);
397 else
398 return RewriteResponse(REWRITE_AGAIN, noUminus);
399 }
400
401 RewriteResponse ArithRewriter::preRewriteTerm(TNode t){
402 if(t.isConst()){
403 return rewriteConstant(t);
404 }else if(t.isVar()){
405 return rewriteVariable(t);
406 }else{
407 switch(Kind k = t.getKind()){
408 case kind::REAL_ALGEBRAIC_NUMBER: return rewriteRAN(t);
409 case kind::SUB: return rewriteSub(t);
410 case kind::NEG: return rewriteNeg(t, true);
411 case kind::DIVISION:
412 case kind::DIVISION_TOTAL: return rewriteDiv(t, true);
413 case kind::PLUS: return preRewritePlus(t);
414 case kind::MULT:
415 case kind::NONLINEAR_MULT: return preRewriteMult(t);
416 case kind::IAND: return RewriteResponse(REWRITE_DONE, t);
417 case kind::POW2: return RewriteResponse(REWRITE_DONE, t);
418 case kind::EXPONENTIAL:
419 case kind::SINE:
420 case kind::COSINE:
421 case kind::TANGENT:
422 case kind::COSECANT:
423 case kind::SECANT:
424 case kind::COTANGENT:
425 case kind::ARCSINE:
426 case kind::ARCCOSINE:
427 case kind::ARCTANGENT:
428 case kind::ARCCOSECANT:
429 case kind::ARCSECANT:
430 case kind::ARCCOTANGENT:
431 case kind::SQRT: return preRewriteTranscendental(t);
432 case kind::INTS_DIVISION:
433 case kind::INTS_MODULUS: return rewriteIntsDivMod(t, true);
434 case kind::INTS_DIVISION_TOTAL:
435 case kind::INTS_MODULUS_TOTAL: return rewriteIntsDivModTotal(t, true);
436 case kind::ABS: return rewriteAbs(t);
437 case kind::IS_INTEGER:
438 case kind::TO_INTEGER: return RewriteResponse(REWRITE_DONE, t);
439 case kind::TO_REAL:
440 case kind::CAST_TO_REAL: return RewriteResponse(REWRITE_DONE, t[0]);
441 case kind::POW: return RewriteResponse(REWRITE_DONE, t);
442 case kind::PI: return RewriteResponse(REWRITE_DONE, t);
443 default: Unhandled() << k;
444 }
445 }
446 }
447
448 RewriteResponse ArithRewriter::postRewriteTerm(TNode t){
449 if(t.isConst()){
450 return rewriteConstant(t);
451 }else if(t.isVar()){
452 return rewriteVariable(t);
453 }else{
454 Trace("arith-rewriter") << "postRewriteTerm: " << t << std::endl;
455 switch(t.getKind()){
456 case kind::REAL_ALGEBRAIC_NUMBER: return rewriteRAN(t);
457 case kind::SUB: return rewriteSub(t);
458 case kind::NEG: return rewriteNeg(t, false);
459 case kind::DIVISION:
460 case kind::DIVISION_TOTAL: return rewriteDiv(t, false);
461 case kind::PLUS: return postRewritePlus(t);
462 case kind::MULT:
463 case kind::NONLINEAR_MULT: return postRewriteMult(t);
464 case kind::IAND: return postRewriteIAnd(t);
465 case kind::POW2: return postRewritePow2(t);
466 case kind::EXPONENTIAL:
467 case kind::SINE:
468 case kind::COSINE:
469 case kind::TANGENT:
470 case kind::COSECANT:
471 case kind::SECANT:
472 case kind::COTANGENT:
473 case kind::ARCSINE:
474 case kind::ARCCOSINE:
475 case kind::ARCTANGENT:
476 case kind::ARCCOSECANT:
477 case kind::ARCSECANT:
478 case kind::ARCCOTANGENT:
479 case kind::SQRT: return postRewriteTranscendental(t);
480 case kind::INTS_DIVISION:
481 case kind::INTS_MODULUS: return rewriteIntsDivMod(t, false);
482 case kind::INTS_DIVISION_TOTAL:
483 case kind::INTS_MODULUS_TOTAL: return rewriteIntsDivModTotal(t, false);
484 case kind::ABS: return rewriteAbs(t);
485 case kind::TO_REAL:
486 case kind::CAST_TO_REAL: return RewriteResponse(REWRITE_DONE, t[0]);
487 case kind::TO_INTEGER: return rewriteExtIntegerOp(t);
488 case kind::POW:
489 {
490 if (t[1].isConst())
491 {
492 const Rational& exp = t[1].getConst<Rational>();
493 TNode base = t[0];
494 if(exp.sgn() == 0){
495 return RewriteResponse(REWRITE_DONE,
496 NodeManager::currentNM()->mkConstRealOrInt(
497 t.getType(), Rational(1)));
498 }else if(exp.sgn() > 0 && exp.isIntegral()){
499 cvc5::Rational r(expr::NodeValue::MAX_CHILDREN);
500 if (exp <= r)
501 {
502 unsigned num = exp.getNumerator().toUnsignedInt();
503 if( num==1 ){
504 return RewriteResponse(REWRITE_AGAIN, base);
505 }else{
506 NodeBuilder nb(kind::MULT);
507 for(unsigned i=0; i < num; ++i){
508 nb << base;
509 }
510 Assert(nb.getNumChildren() > 0);
511 Node mult = nb;
512 return RewriteResponse(REWRITE_AGAIN, mult);
513 }
514 }
515 }
516 }
517 else if (t[0].isConst()
518 && t[0].getConst<Rational>().getNumerator().toUnsignedInt()
519 == 2)
520 {
521 return RewriteResponse(
522 REWRITE_DONE, NodeManager::currentNM()->mkNode(kind::POW2, t[1]));
523 }
524
525 // Todo improve the exception thrown
526 std::stringstream ss;
527 ss << "The exponent of the POW(^) operator can only be a positive "
528 "integral constant below "
529 << (expr::NodeValue::MAX_CHILDREN + 1) << ". ";
530 ss << "Exception occurred in:" << std::endl;
531 ss << " " << t;
532 throw LogicException(ss.str());
533 }
534 case kind::PI:
535 return RewriteResponse(REWRITE_DONE, t);
536 default:
537 Unreachable();
538 }
539 }
540 }
541
542
543 RewriteResponse ArithRewriter::preRewritePlus(TNode t){
544 Assert(t.getKind() == kind::PLUS);
545 return RewriteResponse(REWRITE_DONE, expr::algorithm::flatten(t));
546 }
547
/**
 * Post-rewrites a sum (PLUS with at least two children) into normal form.
 *
 * If flattening changes the term, the flattened sum is returned to be
 * rewritten again. Otherwise the children are split into four groups:
 *  - rational constants, which are summed up (zeros are skipped);
 *  - real algebraic numbers, which are summed up;
 *  - monomials, which are sorted and combined;
 *  - all other children, parsed as polynomials.
 * All groups except the real algebraic numbers are summed into one
 * Polynomial. If the real algebraic part is zero, that polynomial is the
 * result. Otherwise any constant term of the polynomial is folded into the
 * real algebraic number. The result is then either that number alone or
 * (+ ran rest).
 */
RewriteResponse ArithRewriter::postRewritePlus(TNode t){
  Assert(t.getKind() == kind::PLUS);
  Assert(t.getNumChildren() > 1);

  // Flatten nested sums first; a changed term gets another rewrite round.
  {
    Node flat = expr::algorithm::flatten(t);
    if (flat != t)
    {
      return RewriteResponse(REWRITE_AGAIN, flat);
    }
  }

  // Accumulators for the constant parts. They are default-constructed and
  // presumably start at zero (NOTE(review): confirm for both types).
  // The vectors collect the non-constant children.
  Rational rational;
  RealAlgebraicNumber ran;
  std::vector<Monomial> monomials;
  std::vector<Polynomial> polynomials;

  for (const auto& child : t)
  {
    if (child.isConst())
    {
      // Zero summands contribute nothing.
      if (child.getConst<Rational>().isZero())
      {
        continue;
      }
      rational += child.getConst<Rational>();
    }
    else if (child.getKind() == Kind::REAL_ALGEBRAIC_NUMBER)
    {
      ran += child.getOperator().getConst<RealAlgebraicNumber>();
    }
    else if (Monomial::isMember(child))
    {
      monomials.emplace_back(Monomial::parseMonomial(child));
    }
    else
    {
      polynomials.emplace_back(Polynomial::parsePolynomial(child));
    }
  }

  // Sort the monomials and merge adjacent ones into a single polynomial.
  if(!monomials.empty()){
    Monomial::sort(monomials);
    Monomial::combineAdjacentMonomials(monomials);
    polynomials.emplace_back(Polynomial::mkPolynomial(monomials));
  }
  if (!rational.isZero())
  {
    polynomials.emplace_back(
        Polynomial::mkPolynomial(Constant::mkConstant(rational)));
  }

  Polynomial poly = Polynomial::sumPolynomials(polynomials);

  // Without a real algebraic part the polynomial is already the result.
  if (isZero(ran))
  {
    return RewriteResponse(REWRITE_DONE, poly.getNode());
  }
  // Fold the polynomial's constant term into the real algebraic number.
  // The code reads that constant from the head monomial. If the polynomial
  // is not just that constant, it continues with its tail.
  if (poly.containsConstant())
  {
    ran += RealAlgebraicNumber(poly.getHead().getConstant().getValue());
    if (!poly.isConstant())
    {
      poly = poly.getTail();
    }
  }

  auto* nm = NodeManager::currentNM();
  // Only a constant remains (already folded into ran): the sum is ran itself.
  if (poly.isConstant())
  {
    return RewriteResponse(REWRITE_DONE, nm->mkRealAlgebraicNumber(ran));
  }
  return RewriteResponse(
      REWRITE_DONE,
      nm->mkNode(Kind::PLUS, nm->mkRealAlgebraicNumber(ran), poly.getNode()));
}
624
625 RewriteResponse ArithRewriter::preRewriteMult(TNode node)
626 {
627 Assert(node.getKind() == kind::MULT
628 || node.getKind() == kind::NONLINEAR_MULT);
629
630 auto res = getZeroChild(node);
631 if (res)
632 {
633 return RewriteResponse(REWRITE_DONE, *res);
634 }
635 return RewriteResponse(REWRITE_DONE, node);
636 }
637
638 RewriteResponse ArithRewriter::postRewriteMult(TNode t){
639 Assert(t.getKind() == kind::MULT || t.getKind() == kind::NONLINEAR_MULT);
640 Assert(t.getNumChildren() >= 2);
641
642 if (auto res = getZeroChild(t); res)
643 {
644 return RewriteResponse(REWRITE_DONE, *res);
645 }
646
647 Rational rational = Rational(1);
648 RealAlgebraicNumber ran = RealAlgebraicNumber(Integer(1));
649 Polynomial poly = Polynomial::mkOne();
650
651 for (const auto& child : t)
652 {
653 if (child.isConst())
654 {
655 if (child.getConst<Rational>().isZero())
656 {
657 return RewriteResponse(REWRITE_DONE, child);
658 }
659 rational *= child.getConst<Rational>();
660 }
661 else if (child.getKind() == Kind::REAL_ALGEBRAIC_NUMBER)
662 {
663 ran *= child.getOperator().getConst<RealAlgebraicNumber>();
664 }
665 else
666 {
667 poly = poly * Polynomial::parsePolynomial(child);
668 }
669 }
670
671 if (!rational.isOne())
672 {
673 poly = poly * rational;
674 }
675 if (isOne(ran))
676 {
677 return RewriteResponse(REWRITE_DONE, poly.getNode());
678 }
679 auto* nm = NodeManager::currentNM();
680 if (poly.isConstant())
681 {
682 ran *= RealAlgebraicNumber(poly.getHead().getConstant().getValue());
683 return RewriteResponse(REWRITE_DONE, nm->mkRealAlgebraicNumber(ran));
684 }
685 return RewriteResponse(
686 REWRITE_DONE,
687 nm->mkNode(
688 Kind::MULT, nm->mkRealAlgebraicNumber(ran), poly.getNode()));
689 }
690
691 RewriteResponse ArithRewriter::postRewritePow2(TNode t)
692 {
693 Assert(t.getKind() == kind::POW2);
694 NodeManager* nm = NodeManager::currentNM();
695 // if constant, we eliminate
696 if (t[0].isConst())
697 {
698 // pow2 is only supported for integers
699 Assert(t[0].getType().isInteger());
700 Integer i = t[0].getConst<Rational>().getNumerator();
701 if (i < 0)
702 {
703 return RewriteResponse(REWRITE_DONE, nm->mkConstInt(Rational(0)));
704 }
705 // (pow2 t) ---> (pow 2 t) and continue rewriting to eliminate pow
706 Node two = nm->mkConstInt(Rational(Integer(2)));
707 Node ret = nm->mkNode(kind::POW, two, t[0]);
708 return RewriteResponse(REWRITE_AGAIN, ret);
709 }
710 return RewriteResponse(REWRITE_DONE, t);
711 }
712
713 RewriteResponse ArithRewriter::postRewriteIAnd(TNode t)
714 {
715 Assert(t.getKind() == kind::IAND);
716 size_t bsize = t.getOperator().getConst<IntAnd>().d_size;
717 NodeManager* nm = NodeManager::currentNM();
718 // if constant, we eliminate
719 if (t[0].isConst() && t[1].isConst())
720 {
721 Node iToBvop = nm->mkConst(IntToBitVector(bsize));
722 Node arg1 = nm->mkNode(kind::INT_TO_BITVECTOR, iToBvop, t[0]);
723 Node arg2 = nm->mkNode(kind::INT_TO_BITVECTOR, iToBvop, t[1]);
724 Node bvand = nm->mkNode(kind::BITVECTOR_AND, arg1, arg2);
725 Node ret = nm->mkNode(kind::BITVECTOR_TO_NAT, bvand);
726 return RewriteResponse(REWRITE_AGAIN_FULL, ret);
727 }
728 else if (t[0] > t[1])
729 {
730 // ((_ iand k) x y) ---> ((_ iand k) y x) if x > y by node ordering
731 Node ret = nm->mkNode(kind::IAND, t.getOperator(), t[1], t[0]);
732 return RewriteResponse(REWRITE_AGAIN, ret);
733 }
734 else if (t[0] == t[1])
735 {
736 // ((_ iand k) x x) ---> x
737 return RewriteResponse(REWRITE_DONE, t[0]);
738 }
739 // simplifications involving constants
740 for (unsigned i = 0; i < 2; i++)
741 {
742 if (!t[i].isConst())
743 {
744 continue;
745 }
746 if (t[i].getConst<Rational>().sgn() == 0)
747 {
748 // ((_ iand k) 0 y) ---> 0
749 return RewriteResponse(REWRITE_DONE, t[i]);
750 }
751 if (t[i].getConst<Rational>().getNumerator() == Integer(2).pow(bsize) - 1)
752 {
753 // ((_ iand k) 111...1 y) ---> (mod y 2^k)
754 Node twok = nm->mkConstInt(Rational(Integer(2).pow(bsize)));
755 Node ret = nm->mkNode(kind::INTS_MODULUS, t[1-i], twok);
756 return RewriteResponse(REWRITE_AGAIN, ret);
757 }
758 }
759 return RewriteResponse(REWRITE_DONE, t);
760 }
761
762 RewriteResponse ArithRewriter::preRewriteTranscendental(TNode t) {
763 return RewriteResponse(REWRITE_DONE, t);
764 }
765
/**
 * Post-rewrites transcendental terms.
 *
 *  - exp: (exp n), for a non-negative integer-typed constant n, becomes
 *    (^ (exp 1) n) with a real-typed 1. (exp (+ a b ...)) becomes
 *    (* (exp a) (exp b) ...). Other exp terms are unchanged.
 *  - sin of a constant: sin(0) = 0, and sin(-c) becomes -sin(c).
 *  - sin of a non-constant argument: the argument is split into
 *    r*PI + rem via ArithMSum.
 *    - If |r| > 1, a multiple of 2 is removed from r so that |r| <= 1.
 *    - If |r| = 1, the result is 0 (when rem is absent) or -sin(rem).
 *    - If rem is absent, the rational values of sin given by Niven's theorem
 *      are evaluated: sin(+-PI/2) and sin(+-PI/6), sin(+-5PI/6).
 *  - cos(x) becomes sin(PI/2 - x).
 *  - tan, csc, sec and cot are expressed as quotients of sin and cos.
 * All other kinds (and cases not matched above) are returned unchanged.
 */
RewriteResponse ArithRewriter::postRewriteTranscendental(TNode t) {
  Trace("arith-tf-rewrite") << "Rewrite transcendental function : " << t << std::endl;
  NodeManager* nm = NodeManager::currentNM();
  switch( t.getKind() ){
  case kind::EXPONENTIAL: {
    if (t[0].isConst())
    {
      // `one` is a real-typed constant, while the condition below requires
      // t[0] to be integer-typed. NOTE(review): if node equality
      // distinguishes integer- and real-typed constants, `t[0]!=one` always
      // holds here. Then exp(1 :Int) becomes (^ (exp 1.0) 1), which the POW
      // rewrite reduces to (exp 1.0). Confirm this is intended.
      Node one = nm->mkConstReal(Rational(1));
      if(t[0].getConst<Rational>().sgn()>=0 && t[0].getType().isInteger() && t[0]!=one){
        // exp(n) ---> exp(1)^n
        return RewriteResponse(
            REWRITE_AGAIN,
            nm->mkNode(kind::POW, nm->mkNode(kind::EXPONENTIAL, one), t[0]));
      }else{
        return RewriteResponse(REWRITE_DONE, t);
      }
    }
    else if (t[0].getKind() == kind::PLUS)
    {
      // exp(a + b + ...) ---> exp(a) * exp(b) * ...
      std::vector<Node> product;
      for (const Node tc : t[0])
      {
        product.push_back(nm->mkNode(kind::EXPONENTIAL, tc));
      }
      // We need to do a full rewrite here, since we can get exponentials of
      // constants, e.g. when we are rewriting exp(2 + x)
      return RewriteResponse(REWRITE_AGAIN_FULL,
                             nm->mkNode(kind::MULT, product));
    }
  }
    break;
  case kind::SINE:
    if (t[0].isConst())
    {
      const Rational& rat = t[0].getConst<Rational>();
      if(rat.sgn() == 0){
        // sin(0) = 0
        return RewriteResponse(REWRITE_DONE, nm->mkConstReal(Rational(0)));
      }
      else if (rat.sgn() == -1)
      {
        // sin(-c) = -sin(c): normalize to a positive constant argument.
        Node ret = nm->mkNode(kind::NEG,
                              nm->mkNode(kind::SINE, nm->mkConstReal(-rat)));
        return RewriteResponse(REWRITE_AGAIN_FULL, ret);
      }
    }else{
      // get the factor of PI in the argument
      // pi_factor: coefficient r of PI in the argument (null if PI does not
      // occur); rem: the rest of the argument (null if nothing remains).
      Node pi_factor;
      Node pi;
      Node rem;
      std::map<Node, Node> msum;
      if (ArithMSum::getMonomialSum(t[0], msum))
      {
        pi = mkPi();
        std::map<Node, Node>::iterator itm = msum.find(pi);
        if (itm != msum.end())
        {
          // A null coefficient in the monomial sum stands for 1.
          if (itm->second.isNull())
          {
            pi_factor = nm->mkConstReal(Rational(1));
          }
          else
          {
            pi_factor = itm->second;
          }
          msum.erase(pi);
          if (!msum.empty())
          {
            rem = ArithMSum::mkNode(t[0].getType(), msum);
          }
        }
      }
      else
      {
        // The argument of sine is expected to always have a monomial sum.
        Assert(false);
      }

      // if there is a factor of PI
      if( !pi_factor.isNull() ){
        Trace("arith-tf-rewrite-debug") << "Process pi factor = " << pi_factor << std::endl;
        Rational r = pi_factor.getConst<Rational>();
        Rational r_abs = r.abs();
        Rational rone = Rational(1);
        Rational rtwo = Rational(2);
        if (r_abs > rone)
        {
          // Add/subtract a multiple of 2*PI: with m = floor((|r| + 1) / 2),
          // the new factor is r - 2m (r > 0) or r + 2m (r < 0). Its absolute
          // value is at most 1, so the next round reaches one of the cases
          // below.
          Rational ra_div_two = (r_abs + rone) / rtwo;
          Node new_pi_factor;
          if( r.sgn()==1 ){
            new_pi_factor = nm->mkConstReal(r - rtwo * ra_div_two.floor());
          }else{
            Assert(r.sgn() == -1);
            new_pi_factor = nm->mkConstReal(r + rtwo * ra_div_two.floor());
          }
          Node new_arg = nm->mkNode(kind::MULT, new_pi_factor, pi);
          if (!rem.isNull())
          {
            new_arg = nm->mkNode(kind::PLUS, new_arg, rem);
          }
          // sin( 2*n*PI + x ) = sin( x )
          return RewriteResponse(REWRITE_AGAIN_FULL,
                                 nm->mkNode(kind::SINE, new_arg));
        }
        else if (r_abs == rone)
        {
          // sin( PI + x ) = -sin( x ), and likewise sin( -PI + x ) = -sin( x ).
          // Without a remainder the result is sin( +-PI ) = 0.
          if (rem.isNull())
          {
            return RewriteResponse(REWRITE_DONE, nm->mkConstReal(Rational(0)));
          }
          else
          {
            return RewriteResponse(
                REWRITE_AGAIN_FULL,
                nm->mkNode(kind::NEG, nm->mkNode(kind::SINE, rem)));
          }
        }
        else if (rem.isNull())
        {
          // other rational cases based on Niven's theorem
          // (https://en.wikipedia.org/wiki/Niven%27s_theorem)
          Integer one = Integer(1);
          Integer two = Integer(2);
          Integer six = Integer(6);
          if (r_abs.getDenominator() == two)
          {
            // Here |r| < 1, so |r| = 1/2 and sin(r*PI) = sign(r).
            Assert(r_abs.getNumerator() == one);
            return RewriteResponse(REWRITE_DONE,
                                   nm->mkConstReal(Rational(r.sgn())));
          }
          else if (r_abs.getDenominator() == six)
          {
            // |r| = 1/6 or 5/6: sin(r*PI) = sign(r) / 2.
            Integer five = Integer(5);
            if (r_abs.getNumerator() == one || r_abs.getNumerator() == five)
            {
              return RewriteResponse(
                  REWRITE_DONE,
                  nm->mkConstReal(Rational(r.sgn()) / Rational(2)));
            }
          }
        }
      }
    }
    break;
  case kind::COSINE: {
    // cos(x) = sin(PI/2 - x)
    return RewriteResponse(
        REWRITE_AGAIN_FULL,
        nm->mkNode(
            kind::SINE,
            nm->mkNode(kind::SUB,
                       nm->mkNode(kind::MULT,
                                  nm->mkConstReal(Rational(1) / Rational(2)),
                                  mkPi()),
                       t[0])));
  }
  break;
  case kind::TANGENT:
  {
    // tan(x) = sin(x) / cos(x)
    return RewriteResponse(REWRITE_AGAIN_FULL,
                           nm->mkNode(kind::DIVISION,
                                      nm->mkNode(kind::SINE, t[0]),
                                      nm->mkNode(kind::COSINE, t[0])));
  }
  break;
  case kind::COSECANT:
  {
    // csc(x) = 1 / sin(x)
    return RewriteResponse(REWRITE_AGAIN_FULL,
                           nm->mkNode(kind::DIVISION,
                                      nm->mkConstReal(Rational(1)),
                                      nm->mkNode(kind::SINE, t[0])));
  }
  break;
  case kind::SECANT:
  {
    // sec(x) = 1 / cos(x)
    return RewriteResponse(REWRITE_AGAIN_FULL,
                           nm->mkNode(kind::DIVISION,
                                      nm->mkConstReal(Rational(1)),
                                      nm->mkNode(kind::COSINE, t[0])));
  }
  break;
  case kind::COTANGENT:
  {
    // cot(x) = cos(x) / sin(x)
    return RewriteResponse(REWRITE_AGAIN_FULL,
                           nm->mkNode(kind::DIVISION,
                                      nm->mkNode(kind::COSINE, t[0]),
                                      nm->mkNode(kind::SINE, t[0])));
  }
  break;
  default:
    break;
  }
  return RewriteResponse(REWRITE_DONE, t);
}
958
959 Node ArithRewriter::makeUnaryMinusNode(TNode n){
960 NodeManager* nm = NodeManager::currentNM();
961 Rational qNegOne(-1);
962 return nm->mkNode(kind::MULT, nm->mkConstRealOrInt(n.getType(), qNegOne), n);
963 }
964
965 RewriteResponse ArithRewriter::rewriteDiv(TNode t, bool pre){
966 Assert(t.getKind() == kind::DIVISION_TOTAL || t.getKind() == kind::DIVISION);
967 Assert(t.getNumChildren() == 2);
968
969 Node left = t[0];
970 Node right = t[1];
971 if (right.isConst())
972 {
973 NodeManager* nm = NodeManager::currentNM();
974 const Rational& den = right.getConst<Rational>();
975
976 if(den.isZero()){
977 if(t.getKind() == kind::DIVISION_TOTAL){
978 return RewriteResponse(REWRITE_DONE, nm->mkConstReal(0));
979 }else{
980 // This is unsupported, but this is not a good place to complain
981 return RewriteResponse(REWRITE_DONE, t);
982 }
983 }
984 Assert(den != Rational(0));
985
986 if (left.isConst())
987 {
988 const Rational& num = left.getConst<Rational>();
989 return RewriteResponse(REWRITE_DONE, nm->mkConstReal(num / den));
990 }
991 if (left.getKind() == Kind::REAL_ALGEBRAIC_NUMBER)
992 {
993 const RealAlgebraicNumber& num =
994 left.getOperator().getConst<RealAlgebraicNumber>();
995 return RewriteResponse(
996 REWRITE_DONE,
997 nm->mkRealAlgebraicNumber(num / RealAlgebraicNumber(den)));
998 }
999
1000 Node result = nm->mkConstReal(den.inverse());
1001 Node mult = NodeManager::currentNM()->mkNode(kind::MULT, left, result);
1002 if (pre)
1003 {
1004 return RewriteResponse(REWRITE_DONE, mult);
1005 }
1006 else
1007 {
1008 return RewriteResponse(REWRITE_AGAIN, mult);
1009 }
1010 }
1011 if (right.getKind() == Kind::REAL_ALGEBRAIC_NUMBER)
1012 {
1013 NodeManager* nm = NodeManager::currentNM();
1014 const RealAlgebraicNumber& den =
1015 right.getOperator().getConst<RealAlgebraicNumber>();
1016 if (left.isConst())
1017 {
1018 const Rational& num = left.getConst<Rational>();
1019 return RewriteResponse(
1020 REWRITE_DONE,
1021 nm->mkRealAlgebraicNumber(RealAlgebraicNumber(num) / den));
1022 }
1023 if (left.getKind() == Kind::REAL_ALGEBRAIC_NUMBER)
1024 {
1025 const RealAlgebraicNumber& num =
1026 left.getOperator().getConst<RealAlgebraicNumber>();
1027 return RewriteResponse(REWRITE_DONE,
1028 nm->mkRealAlgebraicNumber(num / den));
1029 }
1030
1031 Node result = nm->mkRealAlgebraicNumber(inverse(den));
1032 Node mult = NodeManager::currentNM()->mkNode(kind::MULT,left,result);
1033 if(pre){
1034 return RewriteResponse(REWRITE_DONE, mult);
1035 }else{
1036 return RewriteResponse(REWRITE_AGAIN, mult);
1037 }
1038 }
1039 return RewriteResponse(REWRITE_DONE, t);
1040 }
1041
1042 RewriteResponse ArithRewriter::rewriteAbs(TNode t)
1043 {
1044 Assert(t.getKind() == Kind::ABS);
1045 Assert(t.getNumChildren() == 1);
1046
1047 if (t[0].isConst())
1048 {
1049 const Rational& rat = t[0].getConst<Rational>();
1050 if (rat >= 0)
1051 {
1052 return RewriteResponse(REWRITE_DONE, t[0]);
1053 }
1054 return RewriteResponse(
1055 REWRITE_DONE,
1056 NodeManager::currentNM()->mkConstRealOrInt(t[0].getType(), -rat));
1057 }
1058 if (t[0].getKind() == Kind::REAL_ALGEBRAIC_NUMBER)
1059 {
1060 const RealAlgebraicNumber& ran =
1061 t[0].getOperator().getConst<RealAlgebraicNumber>();
1062 if (ran >= RealAlgebraicNumber())
1063 {
1064 return RewriteResponse(REWRITE_DONE, t[0]);
1065 }
1066 return RewriteResponse(
1067 REWRITE_DONE, NodeManager::currentNM()->mkRealAlgebraicNumber(-ran));
1068 }
1069 return RewriteResponse(REWRITE_DONE, t);
1070 }
1071
1072 RewriteResponse ArithRewriter::rewriteIntsDivMod(TNode t, bool pre)
1073 {
1074 NodeManager* nm = NodeManager::currentNM();
1075 Kind k = t.getKind();
1076 if (k == kind::INTS_MODULUS)
1077 {
1078 if (t[1].isConst() && !t[1].getConst<Rational>().isZero())
1079 {
1080 // can immediately replace by INTS_MODULUS_TOTAL
1081 Node ret = nm->mkNode(kind::INTS_MODULUS_TOTAL, t[0], t[1]);
1082 return returnRewrite(t, ret, Rewrite::MOD_TOTAL_BY_CONST);
1083 }
1084 }
1085 if (k == kind::INTS_DIVISION)
1086 {
1087 if (t[1].isConst() && !t[1].getConst<Rational>().isZero())
1088 {
1089 // can immediately replace by INTS_DIVISION_TOTAL
1090 Node ret = nm->mkNode(kind::INTS_DIVISION_TOTAL, t[0], t[1]);
1091 return returnRewrite(t, ret, Rewrite::DIV_TOTAL_BY_CONST);
1092 }
1093 }
1094 return RewriteResponse(REWRITE_DONE, t);
1095 }
1096
1097 RewriteResponse ArithRewriter::rewriteExtIntegerOp(TNode t)
1098 {
1099 Assert(t.getKind() == kind::TO_INTEGER || t.getKind() == kind::IS_INTEGER);
1100 bool isPred = t.getKind() == kind::IS_INTEGER;
1101 NodeManager* nm = NodeManager::currentNM();
1102 if (t[0].isConst())
1103 {
1104 Node ret;
1105 if (isPred)
1106 {
1107 ret = nm->mkConst(t[0].getConst<Rational>().isIntegral());
1108 }
1109 else
1110 {
1111 ret = nm->mkConstInt(Rational(t[0].getConst<Rational>().floor()));
1112 }
1113 return returnRewrite(t, ret, Rewrite::INT_EXT_CONST);
1114 }
1115 if (t[0].getType().isInteger())
1116 {
1117 Node ret = isPred ? nm->mkConst(true) : Node(t[0]);
1118 return returnRewrite(t, ret, Rewrite::INT_EXT_INT);
1119 }
1120 if (t[0].getKind() == kind::PI)
1121 {
1122 Node ret = isPred ? nm->mkConst(false) : nm->mkConstReal(Rational(3));
1123 return returnRewrite(t, ret, Rewrite::INT_EXT_PI);
1124 }
1125 return RewriteResponse(REWRITE_DONE, t);
1126 }
1127
1128 RewriteResponse ArithRewriter::rewriteIntsDivModTotal(TNode t, bool pre)
1129 {
1130 if (pre)
1131 {
1132 // do not rewrite at prewrite.
1133 return RewriteResponse(REWRITE_DONE, t);
1134 }
1135 NodeManager* nm = NodeManager::currentNM();
1136 Kind k = t.getKind();
1137 Assert(k == kind::INTS_MODULUS_TOTAL || k == kind::INTS_DIVISION_TOTAL);
1138 TNode n = t[0];
1139 TNode d = t[1];
1140 bool dIsConstant = d.isConst();
1141 if(dIsConstant && d.getConst<Rational>().isZero()){
1142 // (div x 0) ---> 0 or (mod x 0) ---> 0
1143 return returnRewrite(t, nm->mkConstInt(0), Rewrite::DIV_MOD_BY_ZERO);
1144 }else if(dIsConstant && d.getConst<Rational>().isOne()){
1145 if (k == kind::INTS_MODULUS_TOTAL)
1146 {
1147 // (mod x 1) --> 0
1148 return returnRewrite(t, nm->mkConstInt(0), Rewrite::MOD_BY_ONE);
1149 }
1150 Assert(k == kind::INTS_DIVISION_TOTAL);
1151 // (div x 1) --> x
1152 return returnRewrite(t, n, Rewrite::DIV_BY_ONE);
1153 }
1154 else if (dIsConstant && d.getConst<Rational>().sgn() < 0)
1155 {
1156 // pull negation
1157 // (div x (- c)) ---> (- (div x c))
1158 // (mod x (- c)) ---> (mod x c)
1159 Node nn = nm->mkNode(k, t[0], nm->mkConstInt(-t[1].getConst<Rational>()));
1160 Node ret = (k == kind::INTS_DIVISION || k == kind::INTS_DIVISION_TOTAL)
1161 ? nm->mkNode(kind::NEG, nn)
1162 : nn;
1163 return returnRewrite(t, ret, Rewrite::DIV_MOD_PULL_NEG_DEN);
1164 }
1165 else if (dIsConstant && n.isConst())
1166 {
1167 Assert(d.getConst<Rational>().isIntegral());
1168 Assert(n.getConst<Rational>().isIntegral());
1169 Assert(!d.getConst<Rational>().isZero());
1170 Integer di = d.getConst<Rational>().getNumerator();
1171 Integer ni = n.getConst<Rational>().getNumerator();
1172
1173 bool isDiv = (k == kind::INTS_DIVISION || k == kind::INTS_DIVISION_TOTAL);
1174
1175 Integer result = isDiv ? ni.euclidianDivideQuotient(di) : ni.euclidianDivideRemainder(di);
1176
1177 // constant evaluation
1178 // (mod c1 c2) ---> c3 or (div c1 c2) ---> c3
1179 Node resultNode = nm->mkConstInt(Rational(result));
1180 return returnRewrite(t, resultNode, Rewrite::CONST_EVAL);
1181 }
1182 if (k == kind::INTS_MODULUS_TOTAL)
1183 {
1184 // Note these rewrites do not need to account for modulus by zero as being
1185 // a UF, which is handled by the reduction of INTS_MODULUS.
1186 Kind k0 = t[0].getKind();
1187 if (k0 == kind::INTS_MODULUS_TOTAL && t[0][1] == t[1])
1188 {
1189 // (mod (mod x c) c) --> (mod x c)
1190 return returnRewrite(t, t[0], Rewrite::MOD_OVER_MOD);
1191 }
1192 else if (k0 == kind::NONLINEAR_MULT || k0 == kind::MULT || k0 == kind::PLUS)
1193 {
1194 // can drop all
1195 std::vector<Node> newChildren;
1196 bool childChanged = false;
1197 for (const Node& tc : t[0])
1198 {
1199 if (tc.getKind() == kind::INTS_MODULUS_TOTAL && tc[1] == t[1])
1200 {
1201 newChildren.push_back(tc[0]);
1202 childChanged = true;
1203 continue;
1204 }
1205 newChildren.push_back(tc);
1206 }
1207 if (childChanged)
1208 {
1209 // (mod (op ... (mod x c) ...) c) ---> (mod (op ... x ...) c) where
1210 // op is one of { NONLINEAR_MULT, MULT, PLUS }.
1211 Node ret = nm->mkNode(k0, newChildren);
1212 ret = nm->mkNode(kind::INTS_MODULUS_TOTAL, ret, t[1]);
1213 return returnRewrite(t, ret, Rewrite::MOD_CHILD_MOD);
1214 }
1215 }
1216 }
1217 else
1218 {
1219 Assert(k == kind::INTS_DIVISION_TOTAL);
1220 // Note these rewrites do not need to account for division by zero as being
1221 // a UF, which is handled by the reduction of INTS_DIVISION.
1222 if (t[0].getKind() == kind::INTS_MODULUS_TOTAL && t[0][1] == t[1])
1223 {
1224 // (div (mod x c) c) --> 0
1225 Node ret = nm->mkConstInt(0);
1226 return returnRewrite(t, ret, Rewrite::DIV_OVER_MOD);
1227 }
1228 }
1229 return RewriteResponse(REWRITE_DONE, t);
1230 }
1231
1232 TrustNode ArithRewriter::expandDefinition(Node node)
1233 {
1234 // call eliminate operators, to eliminate partial operators only
1235 std::vector<SkolemLemma> lems;
1236 TrustNode ret = d_opElim.eliminate(node, lems, true);
1237 Assert(lems.empty());
1238 return ret;
1239 }
1240
1241 RewriteResponse ArithRewriter::returnRewrite(TNode t, Node ret, Rewrite r)
1242 {
1243 Trace("arith-rewrite") << "ArithRewriter : " << t << " == " << ret << " by "
1244 << r << std::endl;
1245 return RewriteResponse(REWRITE_AGAIN_FULL, ret);
1246 }
1247
1248 } // namespace arith
1249 } // namespace theory
1250 } // namespace cvc5