Slightly refactor arithmetic rewriting for extended operators (#8169)
[cvc5.git] / src / theory / arith / arith_rewriter.cpp
1 /******************************************************************************
2 * Top contributors (to current version):
3 * Andrew Reynolds, Tim King, Morgan Deters
4 *
5 * This file is part of the cvc5 project.
6 *
7 * Copyright (c) 2009-2021 by the authors listed in the file AUTHORS
8 * in the top-level source directory and their institutional affiliations.
9 * All rights reserved. See the file COPYING in the top-level source
10 * directory for licensing information.
11 * ****************************************************************************
12 *
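 * Implementation of the rewriter for the theory of arithmetic.
 *
 * Normalizes arithmetic atoms (relations, is_int, divisible) and terms
 * (sums, products, division, integer div/mod, iand, pow/pow2 and
 * transcendental functions).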
17 */
18
19 #include "theory/arith/arith_rewriter.h"
20
21 #include <optional>
22 #include <set>
23 #include <sstream>
24 #include <stack>
25 #include <vector>
26
27 #include "expr/algorithm/flatten.h"
28 #include "smt/logic_exception.h"
29 #include "theory/arith/arith_msum.h"
30 #include "theory/arith/arith_utilities.h"
31 #include "theory/arith/normal_form.h"
32 #include "theory/arith/operator_elim.h"
33 #include "theory/arith/rewriter/addition.h"
34 #include "theory/arith/rewriter/node_utils.h"
35 #include "theory/arith/rewriter/ordering.h"
36 #include "theory/theory.h"
37 #include "util/bitvector.h"
38 #include "util/divisible.h"
39 #include "util/iand.h"
40 #include "util/real_algebraic_number.h"
41
42 using namespace cvc5::kind;
43
44 namespace cvc5 {
45 namespace theory {
46 namespace arith {
47
48 namespace {
49
50 template <typename L, typename R>
51 bool evaluateRelation(Kind rel, const L& l, const R& r)
52 {
53 switch (rel)
54 {
55 case Kind::LT: return l < r;
56 case Kind::LEQ: return l <= r;
57 case Kind::EQUAL: return l == r;
58 case Kind::GEQ: return l >= r;
59 case Kind::GT: return l > r;
60 default: Unreachable(); return false;
61 }
62 }
63
64 /**
65 * Check whether the parent has a child that is a constant zero.
66 * If so, return this child. Otherwise, return std::nullopt.
67 */
68 template <typename Iterable>
69 std::optional<TNode> getZeroChild(const Iterable& parent)
70 {
71 for (const auto& node : parent)
72 {
73 if (node.isConst() && node.template getConst<Rational>().isZero())
74 {
75 return node;
76 }
77 }
78 return std::nullopt;
79 }
80
81 } // namespace
82
83 ArithRewriter::ArithRewriter(OperatorElim& oe) : d_opElim(oe) {}
84
85 RewriteResponse ArithRewriter::preRewrite(TNode t)
86 {
87 Trace("arith-rewriter") << "preRewrite(" << t << ")" << std::endl;
88 if (isAtom(t))
89 {
90 auto res = preRewriteAtom(t);
91 Trace("arith-rewriter")
92 << res.d_status << " -> " << res.d_node << std::endl;
93 return res;
94 }
95 auto res = preRewriteTerm(t);
96 Trace("arith-rewriter") << res.d_status << " -> " << res.d_node << std::endl;
97 return res;
98 }
99
100 RewriteResponse ArithRewriter::postRewrite(TNode t)
101 {
102 Trace("arith-rewriter") << "postRewrite(" << t << ")" << std::endl;
103 if (isAtom(t))
104 {
105 auto res = postRewriteAtom(t);
106 Trace("arith-rewriter")
107 << res.d_status << " -> " << res.d_node << std::endl;
108 return res;
109 }
110 auto res = postRewriteTerm(t);
111 Trace("arith-rewriter") << res.d_status << " -> " << res.d_node << std::endl;
112 return res;
113 }
114
115 RewriteResponse ArithRewriter::preRewriteAtom(TNode atom)
116 {
117 Assert(isAtom(atom));
118
119 NodeManager* nm = NodeManager::currentNM();
120
121 if (isRelationOperator(atom.getKind()) && atom[0] == atom[1])
122 {
123 switch (atom.getKind())
124 {
125 case Kind::LT: return RewriteResponse(REWRITE_DONE, nm->mkConst(false));
126 case Kind::LEQ: return RewriteResponse(REWRITE_DONE, nm->mkConst(true));
127 case Kind::EQUAL: return RewriteResponse(REWRITE_DONE, nm->mkConst(true));
128 case Kind::GEQ: return RewriteResponse(REWRITE_DONE, nm->mkConst(true));
129 case Kind::GT: return RewriteResponse(REWRITE_DONE, nm->mkConst(false));
130 default:;
131 }
132 }
133
134 switch (atom.getKind())
135 {
136 case Kind::GT:
137 return RewriteResponse(REWRITE_DONE,
138 nm->mkNode(kind::LEQ, atom[0], atom[1]).notNode());
139 case Kind::LT:
140 return RewriteResponse(REWRITE_DONE,
141 nm->mkNode(kind::GEQ, atom[0], atom[1]).notNode());
142 case Kind::IS_INTEGER:
143 if (atom[0].getType().isInteger())
144 {
145 return RewriteResponse(REWRITE_DONE, nm->mkConst(true));
146 }
147 break;
148 case Kind::DIVISIBLE:
149 if (atom.getOperator().getConst<Divisible>().k.isOne())
150 {
151 return RewriteResponse(REWRITE_DONE, nm->mkConst(true));
152 }
153 break;
154 default:;
155 }
156
157 return RewriteResponse(REWRITE_DONE, atom);
158 }
159
160 RewriteResponse ArithRewriter::postRewriteAtom(TNode atom)
161 {
162 Assert(isAtom(atom));
163 if (atom.getKind() == kind::IS_INTEGER)
164 {
165 return rewriteExtIntegerOp(atom);
166 }
167 else if (atom.getKind() == kind::DIVISIBLE)
168 {
169 if (atom[0].isConst())
170 {
171 return RewriteResponse(REWRITE_DONE,
172 NodeManager::currentNM()->mkConst(bool(
173 (atom[0].getConst<Rational>()
174 / atom.getOperator().getConst<Divisible>().k)
175 .isIntegral())));
176 }
177 if (atom.getOperator().getConst<Divisible>().k.isOne())
178 {
179 return RewriteResponse(REWRITE_DONE,
180 NodeManager::currentNM()->mkConst(true));
181 }
182 NodeManager* nm = NodeManager::currentNM();
183 return RewriteResponse(
184 REWRITE_AGAIN,
185 nm->mkNode(kind::EQUAL,
186 nm->mkNode(kind::INTS_MODULUS_TOTAL,
187 atom[0],
188 nm->mkConstInt(Rational(
189 atom.getOperator().getConst<Divisible>().k))),
190 nm->mkConstInt(Rational(0))));
191 }
192
193 // left |><| right
194 TNode left = atom[0];
195 TNode right = atom[1];
196
197 auto* nm = NodeManager::currentNM();
198 if (left.isConst())
199 {
200 const Rational& l = left.getConst<Rational>();
201 if (right.isConst())
202 {
203 const Rational& r = right.getConst<Rational>();
204 return RewriteResponse(
205 REWRITE_DONE, nm->mkConst(evaluateRelation(atom.getKind(), l, r)));
206 }
207 else if (right.getKind() == Kind::REAL_ALGEBRAIC_NUMBER)
208 {
209 const RealAlgebraicNumber& r =
210 right.getOperator().getConst<RealAlgebraicNumber>();
211 return RewriteResponse(
212 REWRITE_DONE, nm->mkConst(evaluateRelation(atom.getKind(), l, r)));
213 }
214 }
215 else if (left.getKind() == Kind::REAL_ALGEBRAIC_NUMBER)
216 {
217 const RealAlgebraicNumber& l =
218 left.getOperator().getConst<RealAlgebraicNumber>();
219 if (right.isConst())
220 {
221 const Rational& r = right.getConst<Rational>();
222 return RewriteResponse(
223 REWRITE_DONE, nm->mkConst(evaluateRelation(atom.getKind(), l, r)));
224 }
225 else if (right.getKind() == Kind::REAL_ALGEBRAIC_NUMBER)
226 {
227 const RealAlgebraicNumber& r =
228 right.getOperator().getConst<RealAlgebraicNumber>();
229 return RewriteResponse(
230 REWRITE_DONE, nm->mkConst(evaluateRelation(atom.getKind(), l, r)));
231 }
232 }
233
234 Polynomial pleft = Polynomial::parsePolynomial(left);
235 Polynomial pright = Polynomial::parsePolynomial(right);
236
237 Debug("arith::rewriter") << "pleft " << pleft.getNode() << std::endl;
238 Debug("arith::rewriter") << "pright " << pright.getNode() << std::endl;
239
240 Comparison cmp = Comparison::mkComparison(atom.getKind(), pleft, pright);
241 Assert(cmp.isNormalForm());
242 return RewriteResponse(REWRITE_DONE, cmp.getNode());
243 }
244
245 bool ArithRewriter::isAtom(TNode n) {
246 Kind k = n.getKind();
247 return arith::isRelationOperator(k) || k == kind::IS_INTEGER
248 || k == kind::DIVISIBLE;
249 }
250
251 RewriteResponse ArithRewriter::rewriteConstant(TNode t){
252 Assert(t.isConst());
253 Assert(t.getKind() == CONST_RATIONAL || t.getKind() == CONST_INTEGER);
254
255 return RewriteResponse(REWRITE_DONE, t);
256 }
257
258 RewriteResponse ArithRewriter::rewriteRAN(TNode t)
259 {
260 Assert(t.getKind() == REAL_ALGEBRAIC_NUMBER);
261
262 const RealAlgebraicNumber& r =
263 t.getOperator().getConst<RealAlgebraicNumber>();
264 if (r.isRational())
265 {
266 return RewriteResponse(
267 REWRITE_DONE, NodeManager::currentNM()->mkConstReal(r.toRational()));
268 }
269
270 return RewriteResponse(REWRITE_DONE, t);
271 }
272
273 RewriteResponse ArithRewriter::rewriteVariable(TNode t){
274 Assert(t.isVar());
275
276 return RewriteResponse(REWRITE_DONE, t);
277 }
278
279 RewriteResponse ArithRewriter::rewriteSub(TNode t)
280 {
281 Assert(t.getKind() == kind::SUB);
282 Assert(t.getNumChildren() == 2);
283
284 auto* nm = NodeManager::currentNM();
285
286 if (t[0] == t[1])
287 {
288 return RewriteResponse(REWRITE_DONE,
289 nm->mkConstRealOrInt(t.getType(), Rational(0)));
290 }
291 return RewriteResponse(REWRITE_AGAIN_FULL,
292 nm->mkNode(Kind::ADD, t[0], makeUnaryMinusNode(t[1])));
293 }
294
295 RewriteResponse ArithRewriter::rewriteNeg(TNode t, bool pre)
296 {
297 Assert(t.getKind() == kind::NEG);
298
299 if (t[0].isConst())
300 {
301 Rational neg = -(t[0].getConst<Rational>());
302 NodeManager* nm = NodeManager::currentNM();
303 return RewriteResponse(REWRITE_DONE,
304 nm->mkConstRealOrInt(t[0].getType(), neg));
305 }
306 if (t[0].getKind() == Kind::REAL_ALGEBRAIC_NUMBER)
307 {
308 const RealAlgebraicNumber& r =
309 t[0].getOperator().getConst<RealAlgebraicNumber>();
310 NodeManager* nm = NodeManager::currentNM();
311 return RewriteResponse(REWRITE_DONE, nm->mkRealAlgebraicNumber(-r));
312 }
313
314 Node noUminus = makeUnaryMinusNode(t[0]);
315 if(pre)
316 return RewriteResponse(REWRITE_DONE, noUminus);
317 else
318 return RewriteResponse(REWRITE_AGAIN, noUminus);
319 }
320
321 RewriteResponse ArithRewriter::preRewriteTerm(TNode t){
322 if(t.isConst()){
323 return rewriteConstant(t);
324 }else if(t.isVar()){
325 return rewriteVariable(t);
326 }else{
327 switch(Kind k = t.getKind()){
328 case kind::REAL_ALGEBRAIC_NUMBER: return rewriteRAN(t);
329 case kind::SUB: return rewriteSub(t);
330 case kind::NEG: return rewriteNeg(t, true);
331 case kind::DIVISION:
332 case kind::DIVISION_TOTAL: return rewriteDiv(t, true);
333 case kind::ADD: return preRewritePlus(t);
334 case kind::MULT:
335 case kind::NONLINEAR_MULT: return preRewriteMult(t);
336 case kind::IAND: return RewriteResponse(REWRITE_DONE, t);
337 case kind::POW2: return RewriteResponse(REWRITE_DONE, t);
338 case kind::EXPONENTIAL:
339 case kind::SINE:
340 case kind::COSINE:
341 case kind::TANGENT:
342 case kind::COSECANT:
343 case kind::SECANT:
344 case kind::COTANGENT:
345 case kind::ARCSINE:
346 case kind::ARCCOSINE:
347 case kind::ARCTANGENT:
348 case kind::ARCCOSECANT:
349 case kind::ARCSECANT:
350 case kind::ARCCOTANGENT:
351 case kind::SQRT: return preRewriteTranscendental(t);
352 case kind::INTS_DIVISION:
353 case kind::INTS_MODULUS: return rewriteIntsDivMod(t, true);
354 case kind::INTS_DIVISION_TOTAL:
355 case kind::INTS_MODULUS_TOTAL: return rewriteIntsDivModTotal(t, true);
356 case kind::ABS: return rewriteAbs(t);
357 case kind::IS_INTEGER:
358 case kind::TO_INTEGER: return RewriteResponse(REWRITE_DONE, t);
359 case kind::TO_REAL:
360 case kind::CAST_TO_REAL: return RewriteResponse(REWRITE_DONE, t[0]);
361 case kind::POW: return RewriteResponse(REWRITE_DONE, t);
362 case kind::PI: return RewriteResponse(REWRITE_DONE, t);
363 default: Unhandled() << k;
364 }
365 }
366 }
367
368 RewriteResponse ArithRewriter::postRewriteTerm(TNode t){
369 if(t.isConst()){
370 return rewriteConstant(t);
371 }else if(t.isVar()){
372 return rewriteVariable(t);
373 }else{
374 Trace("arith-rewriter") << "postRewriteTerm: " << t << std::endl;
375 switch(t.getKind()){
376 case kind::REAL_ALGEBRAIC_NUMBER: return rewriteRAN(t);
377 case kind::SUB: return rewriteSub(t);
378 case kind::NEG: return rewriteNeg(t, false);
379 case kind::DIVISION:
380 case kind::DIVISION_TOTAL: return rewriteDiv(t, false);
381 case kind::ADD: return postRewritePlus(t);
382 case kind::MULT:
383 case kind::NONLINEAR_MULT: return postRewriteMult(t);
384 case kind::IAND: return postRewriteIAnd(t);
385 case kind::POW2: return postRewritePow2(t);
386 case kind::EXPONENTIAL:
387 case kind::SINE:
388 case kind::COSINE:
389 case kind::TANGENT:
390 case kind::COSECANT:
391 case kind::SECANT:
392 case kind::COTANGENT:
393 case kind::ARCSINE:
394 case kind::ARCCOSINE:
395 case kind::ARCTANGENT:
396 case kind::ARCCOSECANT:
397 case kind::ARCSECANT:
398 case kind::ARCCOTANGENT:
399 case kind::SQRT: return postRewriteTranscendental(t);
400 case kind::INTS_DIVISION:
401 case kind::INTS_MODULUS: return rewriteIntsDivMod(t, false);
402 case kind::INTS_DIVISION_TOTAL:
403 case kind::INTS_MODULUS_TOTAL: return rewriteIntsDivModTotal(t, false);
404 case kind::ABS: return rewriteAbs(t);
405 case kind::TO_REAL:
406 case kind::CAST_TO_REAL: return RewriteResponse(REWRITE_DONE, t[0]);
407 case kind::TO_INTEGER: return rewriteExtIntegerOp(t);
408 case kind::POW:
409 {
410 if (t[1].isConst())
411 {
412 const Rational& exp = t[1].getConst<Rational>();
413 TNode base = t[0];
414 if(exp.sgn() == 0){
415 return RewriteResponse(REWRITE_DONE,
416 NodeManager::currentNM()->mkConstRealOrInt(
417 t.getType(), Rational(1)));
418 }else if(exp.sgn() > 0 && exp.isIntegral()){
419 cvc5::Rational r(expr::NodeValue::MAX_CHILDREN);
420 if (exp <= r)
421 {
422 unsigned num = exp.getNumerator().toUnsignedInt();
423 if( num==1 ){
424 return RewriteResponse(REWRITE_AGAIN, base);
425 }else{
426 NodeBuilder nb(kind::MULT);
427 for(unsigned i=0; i < num; ++i){
428 nb << base;
429 }
430 Assert(nb.getNumChildren() > 0);
431 Node mult = nb;
432 return RewriteResponse(REWRITE_AGAIN, mult);
433 }
434 }
435 }
436 }
437 else if (t[0].isConst()
438 && t[0].getConst<Rational>().getNumerator().toUnsignedInt()
439 == 2)
440 {
441 return RewriteResponse(
442 REWRITE_DONE, NodeManager::currentNM()->mkNode(kind::POW2, t[1]));
443 }
444
445 // Todo improve the exception thrown
446 std::stringstream ss;
447 ss << "The exponent of the POW(^) operator can only be a positive "
448 "integral constant below "
449 << (expr::NodeValue::MAX_CHILDREN + 1) << ". ";
450 ss << "Exception occurred in:" << std::endl;
451 ss << " " << t;
452 throw LogicException(ss.str());
453 }
454 case kind::PI:
455 return RewriteResponse(REWRITE_DONE, t);
456 default:
457 Unreachable();
458 }
459 }
460 }
461
462
463 RewriteResponse ArithRewriter::preRewritePlus(TNode t){
464 Assert(t.getKind() == kind::ADD);
465 return RewriteResponse(REWRITE_DONE, expr::algorithm::flatten(t));
466 }
467
468 RewriteResponse ArithRewriter::postRewritePlus(TNode t){
469 Assert(t.getKind() == kind::ADD);
470 Assert(t.getNumChildren() > 1);
471
472 {
473 Node flat = expr::algorithm::flatten(t);
474 if (flat != t)
475 {
476 return RewriteResponse(REWRITE_AGAIN, flat);
477 }
478 }
479
480 Rational rational;
481 RealAlgebraicNumber ran;
482 std::vector<Monomial> monomials;
483 std::vector<Polynomial> polynomials;
484
485 for (const auto& child : t)
486 {
487 if (child.isConst())
488 {
489 if (child.getConst<Rational>().isZero())
490 {
491 continue;
492 }
493 rational += child.getConst<Rational>();
494 }
495 else if (child.getKind() == Kind::REAL_ALGEBRAIC_NUMBER)
496 {
497 ran += child.getOperator().getConst<RealAlgebraicNumber>();
498 }
499 else if (Monomial::isMember(child))
500 {
501 monomials.emplace_back(Monomial::parseMonomial(child));
502 }
503 else
504 {
505 polynomials.emplace_back(Polynomial::parsePolynomial(child));
506 }
507 }
508
509 if(!monomials.empty()){
510 Monomial::sort(monomials);
511 Monomial::combineAdjacentMonomials(monomials);
512 polynomials.emplace_back(Polynomial::mkPolynomial(monomials));
513 }
514 if (!rational.isZero())
515 {
516 polynomials.emplace_back(
517 Polynomial::mkPolynomial(Constant::mkConstant(rational)));
518 }
519
520 Polynomial poly = Polynomial::sumPolynomials(polynomials);
521
522 if (isZero(ran))
523 {
524 return RewriteResponse(REWRITE_DONE, poly.getNode());
525 }
526 if (poly.containsConstant())
527 {
528 ran += RealAlgebraicNumber(poly.getHead().getConstant().getValue());
529 if (!poly.isConstant())
530 {
531 poly = poly.getTail();
532 }
533 }
534
535 auto* nm = NodeManager::currentNM();
536 if (poly.isConstant())
537 {
538 return RewriteResponse(REWRITE_DONE, nm->mkRealAlgebraicNumber(ran));
539 }
540 return RewriteResponse(
541 REWRITE_DONE,
542 nm->mkNode(Kind::ADD, nm->mkRealAlgebraicNumber(ran), poly.getNode()));
543 }
544
545 RewriteResponse ArithRewriter::preRewriteMult(TNode node)
546 {
547 Assert(node.getKind() == kind::MULT
548 || node.getKind() == kind::NONLINEAR_MULT);
549
550 if (auto res = rewriter::getZeroChild(node); res)
551 {
552 return RewriteResponse(REWRITE_DONE, *res);
553 }
554 return RewriteResponse(REWRITE_DONE, node);
555 }
556
557 RewriteResponse ArithRewriter::postRewriteMult(TNode t){
558 Assert(t.getKind() == kind::MULT || t.getKind() == kind::NONLINEAR_MULT);
559 Assert(t.getNumChildren() >= 2);
560
561 std::vector<TNode> children;
562 expr::algorithm::flatten(t, children, Kind::MULT, Kind::NONLINEAR_MULT);
563
564 if (auto res = rewriter::getZeroChild(children); res)
565 {
566 return RewriteResponse(REWRITE_DONE, *res);
567 }
568
569 // Distribute over addition
570 if (std::any_of(children.begin(), children.end(), [](TNode child) {
571 return child.getKind() == Kind::ADD;
572 }))
573 {
574 return RewriteResponse(REWRITE_DONE,
575 rewriter::distributeMultiplication(children));
576 }
577
578 RealAlgebraicNumber ran = RealAlgebraicNumber(Integer(1));
579 std::vector<Node> leafs;
580
581 for (const auto& child : children)
582 {
583 if (child.isConst())
584 {
585 if (child.getConst<Rational>().isZero())
586 {
587 return RewriteResponse(REWRITE_DONE, child);
588 }
589 ran *= child.getConst<Rational>();
590 }
591 else if (rewriter::isRAN(child))
592 {
593 ran *= rewriter::getRAN(child);
594 }
595 else
596 {
597 leafs.emplace_back(child);
598 }
599 }
600
601 return RewriteResponse(REWRITE_DONE,
602 rewriter::mkMultTerm(ran, std::move(leafs)));
603 }
604
605 Node ArithRewriter::makeUnaryMinusNode(TNode n)
606 {
607 NodeManager* nm = NodeManager::currentNM();
608 Rational qNegOne(-1);
609 return nm->mkNode(kind::MULT, nm->mkConstRealOrInt(n.getType(), qNegOne), n);
610 }
611
612 RewriteResponse ArithRewriter::rewriteDiv(TNode t, bool pre)
613 {
614 Assert(t.getKind() == kind::DIVISION_TOTAL || t.getKind() == kind::DIVISION);
615 Assert(t.getNumChildren() == 2);
616
617 Node left = t[0];
618 Node right = t[1];
619 if (right.isConst())
620 {
621 NodeManager* nm = NodeManager::currentNM();
622 const Rational& den = right.getConst<Rational>();
623
624 if (den.isZero())
625 {
626 if (t.getKind() == kind::DIVISION_TOTAL)
627 {
628 return RewriteResponse(REWRITE_DONE, nm->mkConstReal(0));
629 }
630 else
631 {
632 // This is unsupported, but this is not a good place to complain
633 return RewriteResponse(REWRITE_DONE, t);
634 }
635 }
636 Assert(den != Rational(0));
637
638 if (left.isConst())
639 {
640 const Rational& num = left.getConst<Rational>();
641 return RewriteResponse(REWRITE_DONE, nm->mkConstReal(num / den));
642 }
643 if (left.getKind() == Kind::REAL_ALGEBRAIC_NUMBER)
644 {
645 const RealAlgebraicNumber& num =
646 left.getOperator().getConst<RealAlgebraicNumber>();
647 return RewriteResponse(
648 REWRITE_DONE,
649 nm->mkRealAlgebraicNumber(num / RealAlgebraicNumber(den)));
650 }
651
652 Node result = nm->mkConstReal(den.inverse());
653 Node mult = NodeManager::currentNM()->mkNode(kind::MULT, left, result);
654 if (pre)
655 {
656 return RewriteResponse(REWRITE_DONE, mult);
657 }
658 else
659 {
660 return RewriteResponse(REWRITE_AGAIN, mult);
661 }
662 }
663 if (right.getKind() == Kind::REAL_ALGEBRAIC_NUMBER)
664 {
665 NodeManager* nm = NodeManager::currentNM();
666 const RealAlgebraicNumber& den =
667 right.getOperator().getConst<RealAlgebraicNumber>();
668 if (left.isConst())
669 {
670 const Rational& num = left.getConst<Rational>();
671 return RewriteResponse(
672 REWRITE_DONE,
673 nm->mkRealAlgebraicNumber(RealAlgebraicNumber(num) / den));
674 }
675 if (left.getKind() == Kind::REAL_ALGEBRAIC_NUMBER)
676 {
677 const RealAlgebraicNumber& num =
678 left.getOperator().getConst<RealAlgebraicNumber>();
679 return RewriteResponse(REWRITE_DONE,
680 nm->mkRealAlgebraicNumber(num / den));
681 }
682
683 Node result = nm->mkRealAlgebraicNumber(inverse(den));
684 Node mult = NodeManager::currentNM()->mkNode(kind::MULT, left, result);
685 if (pre)
686 {
687 return RewriteResponse(REWRITE_DONE, mult);
688 }
689 else
690 {
691 return RewriteResponse(REWRITE_AGAIN, mult);
692 }
693 }
694 return RewriteResponse(REWRITE_DONE, t);
695 }
696
697 RewriteResponse ArithRewriter::rewriteAbs(TNode t)
698 {
699 Assert(t.getKind() == Kind::ABS);
700 Assert(t.getNumChildren() == 1);
701
702 if (t[0].isConst())
703 {
704 const Rational& rat = t[0].getConst<Rational>();
705 if (rat >= 0)
706 {
707 return RewriteResponse(REWRITE_DONE, t[0]);
708 }
709 return RewriteResponse(
710 REWRITE_DONE,
711 NodeManager::currentNM()->mkConstRealOrInt(t[0].getType(), -rat));
712 }
713 if (t[0].getKind() == Kind::REAL_ALGEBRAIC_NUMBER)
714 {
715 const RealAlgebraicNumber& ran =
716 t[0].getOperator().getConst<RealAlgebraicNumber>();
717 if (ran >= RealAlgebraicNumber())
718 {
719 return RewriteResponse(REWRITE_DONE, t[0]);
720 }
721 return RewriteResponse(
722 REWRITE_DONE, NodeManager::currentNM()->mkRealAlgebraicNumber(-ran));
723 }
724 return RewriteResponse(REWRITE_DONE, t);
725 }
726
727 RewriteResponse ArithRewriter::rewriteIntsDivMod(TNode t, bool pre)
728 {
729 NodeManager* nm = NodeManager::currentNM();
730 Kind k = t.getKind();
731 if (k == kind::INTS_MODULUS)
732 {
733 if (t[1].isConst() && !t[1].getConst<Rational>().isZero())
734 {
735 // can immediately replace by INTS_MODULUS_TOTAL
736 Node ret = nm->mkNode(kind::INTS_MODULUS_TOTAL, t[0], t[1]);
737 return returnRewrite(t, ret, Rewrite::MOD_TOTAL_BY_CONST);
738 }
739 }
740 if (k == kind::INTS_DIVISION)
741 {
742 if (t[1].isConst() && !t[1].getConst<Rational>().isZero())
743 {
744 // can immediately replace by INTS_DIVISION_TOTAL
745 Node ret = nm->mkNode(kind::INTS_DIVISION_TOTAL, t[0], t[1]);
746 return returnRewrite(t, ret, Rewrite::DIV_TOTAL_BY_CONST);
747 }
748 }
749 return RewriteResponse(REWRITE_DONE, t);
750 }
751
752 RewriteResponse ArithRewriter::rewriteExtIntegerOp(TNode t)
753 {
754 Assert(t.getKind() == kind::TO_INTEGER || t.getKind() == kind::IS_INTEGER);
755 bool isPred = t.getKind() == kind::IS_INTEGER;
756 NodeManager* nm = NodeManager::currentNM();
757 if (t[0].isConst())
758 {
759 Node ret;
760 if (isPred)
761 {
762 ret = nm->mkConst(t[0].getConst<Rational>().isIntegral());
763 }
764 else
765 {
766 ret = nm->mkConstInt(Rational(t[0].getConst<Rational>().floor()));
767 }
768 return returnRewrite(t, ret, Rewrite::INT_EXT_CONST);
769 }
770 if (t[0].getType().isInteger())
771 {
772 Node ret = isPred ? nm->mkConst(true) : Node(t[0]);
773 return returnRewrite(t, ret, Rewrite::INT_EXT_INT);
774 }
775 if (t[0].getKind() == kind::PI)
776 {
777 Node ret = isPred ? nm->mkConst(false) : nm->mkConstReal(Rational(3));
778 return returnRewrite(t, ret, Rewrite::INT_EXT_PI);
779 }
780 return RewriteResponse(REWRITE_DONE, t);
781 }
782
783 RewriteResponse ArithRewriter::rewriteIntsDivModTotal(TNode t, bool pre)
784 {
785 if (pre)
786 {
787 // do not rewrite at prewrite.
788 return RewriteResponse(REWRITE_DONE, t);
789 }
790 NodeManager* nm = NodeManager::currentNM();
791 Kind k = t.getKind();
792 Assert(k == kind::INTS_MODULUS_TOTAL || k == kind::INTS_DIVISION_TOTAL);
793 TNode n = t[0];
794 TNode d = t[1];
795 bool dIsConstant = d.isConst();
796 if (dIsConstant && d.getConst<Rational>().isZero())
797 {
798 // (div x 0) ---> 0 or (mod x 0) ---> 0
799 return returnRewrite(t, nm->mkConstInt(0), Rewrite::DIV_MOD_BY_ZERO);
800 }
801 else if (dIsConstant && d.getConst<Rational>().isOne())
802 {
803 if (k == kind::INTS_MODULUS_TOTAL)
804 {
805 // (mod x 1) --> 0
806 return returnRewrite(t, nm->mkConstInt(0), Rewrite::MOD_BY_ONE);
807 }
808 Assert(k == kind::INTS_DIVISION_TOTAL);
809 // (div x 1) --> x
810 return returnRewrite(t, n, Rewrite::DIV_BY_ONE);
811 }
812 else if (dIsConstant && d.getConst<Rational>().sgn() < 0)
813 {
814 // pull negation
815 // (div x (- c)) ---> (- (div x c))
816 // (mod x (- c)) ---> (mod x c)
817 Node nn = nm->mkNode(k, t[0], nm->mkConstInt(-t[1].getConst<Rational>()));
818 Node ret = (k == kind::INTS_DIVISION || k == kind::INTS_DIVISION_TOTAL)
819 ? nm->mkNode(kind::NEG, nn)
820 : nn;
821 return returnRewrite(t, ret, Rewrite::DIV_MOD_PULL_NEG_DEN);
822 }
823 else if (dIsConstant && n.isConst())
824 {
825 Assert(d.getConst<Rational>().isIntegral());
826 Assert(n.getConst<Rational>().isIntegral());
827 Assert(!d.getConst<Rational>().isZero());
828 Integer di = d.getConst<Rational>().getNumerator();
829 Integer ni = n.getConst<Rational>().getNumerator();
830
831 bool isDiv = (k == kind::INTS_DIVISION || k == kind::INTS_DIVISION_TOTAL);
832
833 Integer result = isDiv ? ni.euclidianDivideQuotient(di)
834 : ni.euclidianDivideRemainder(di);
835
836 // constant evaluation
837 // (mod c1 c2) ---> c3 or (div c1 c2) ---> c3
838 Node resultNode = nm->mkConstInt(Rational(result));
839 return returnRewrite(t, resultNode, Rewrite::CONST_EVAL);
840 }
841 if (k == kind::INTS_MODULUS_TOTAL)
842 {
843 // Note these rewrites do not need to account for modulus by zero as being
844 // a UF, which is handled by the reduction of INTS_MODULUS.
845 Kind k0 = t[0].getKind();
846 if (k0 == kind::INTS_MODULUS_TOTAL && t[0][1] == t[1])
847 {
848 // (mod (mod x c) c) --> (mod x c)
849 return returnRewrite(t, t[0], Rewrite::MOD_OVER_MOD);
850 }
851 else if (k0 == kind::NONLINEAR_MULT || k0 == kind::MULT || k0 == kind::ADD)
852 {
853 // can drop all
854 std::vector<Node> newChildren;
855 bool childChanged = false;
856 for (const Node& tc : t[0])
857 {
858 if (tc.getKind() == kind::INTS_MODULUS_TOTAL && tc[1] == t[1])
859 {
860 newChildren.push_back(tc[0]);
861 childChanged = true;
862 continue;
863 }
864 newChildren.push_back(tc);
865 }
866 if (childChanged)
867 {
868 // (mod (op ... (mod x c) ...) c) ---> (mod (op ... x ...) c) where
869 // op is one of { NONLINEAR_MULT, MULT, ADD }.
870 Node ret = nm->mkNode(k0, newChildren);
871 ret = nm->mkNode(kind::INTS_MODULUS_TOTAL, ret, t[1]);
872 return returnRewrite(t, ret, Rewrite::MOD_CHILD_MOD);
873 }
874 }
875 }
876 else
877 {
878 Assert(k == kind::INTS_DIVISION_TOTAL);
879 // Note these rewrites do not need to account for division by zero as being
880 // a UF, which is handled by the reduction of INTS_DIVISION.
881 if (t[0].getKind() == kind::INTS_MODULUS_TOTAL && t[0][1] == t[1])
882 {
883 // (div (mod x c) c) --> 0
884 Node ret = nm->mkConstInt(0);
885 return returnRewrite(t, ret, Rewrite::DIV_OVER_MOD);
886 }
887 }
888 return RewriteResponse(REWRITE_DONE, t);
889 }
890
891 RewriteResponse ArithRewriter::postRewriteIAnd(TNode t)
892 {
893 Assert(t.getKind() == kind::IAND);
894 size_t bsize = t.getOperator().getConst<IntAnd>().d_size;
895 NodeManager* nm = NodeManager::currentNM();
896 // if constant, we eliminate
897 if (t[0].isConst() && t[1].isConst())
898 {
899 Node iToBvop = nm->mkConst(IntToBitVector(bsize));
900 Node arg1 = nm->mkNode(kind::INT_TO_BITVECTOR, iToBvop, t[0]);
901 Node arg2 = nm->mkNode(kind::INT_TO_BITVECTOR, iToBvop, t[1]);
902 Node bvand = nm->mkNode(kind::BITVECTOR_AND, arg1, arg2);
903 Node ret = nm->mkNode(kind::BITVECTOR_TO_NAT, bvand);
904 return RewriteResponse(REWRITE_AGAIN_FULL, ret);
905 }
906 else if (t[0] > t[1])
907 {
908 // ((_ iand k) x y) ---> ((_ iand k) y x) if x > y by node ordering
909 Node ret = nm->mkNode(kind::IAND, t.getOperator(), t[1], t[0]);
910 return RewriteResponse(REWRITE_AGAIN, ret);
911 }
912 else if (t[0] == t[1])
913 {
914 // ((_ iand k) x x) ---> (mod x 2^k)
915 Node twok = nm->mkConstInt(Rational(Integer(2).pow(bsize)));
916 Node ret = nm->mkNode(kind::INTS_MODULUS, t[0], twok);
917 return RewriteResponse(REWRITE_AGAIN, ret);
918 }
919 // simplifications involving constants
920 for (unsigned i = 0; i < 2; i++)
921 {
922 if (!t[i].isConst())
923 {
924 continue;
925 }
926 if (t[i].getConst<Rational>().sgn() == 0)
927 {
928 // ((_ iand k) 0 y) ---> 0
929 return RewriteResponse(REWRITE_DONE, t[i]);
930 }
931 if (t[i].getConst<Rational>().getNumerator() == Integer(2).pow(bsize) - 1)
932 {
933 // ((_ iand k) 111...1 y) ---> (mod y 2^k)
934 Node twok = nm->mkConstInt(Rational(Integer(2).pow(bsize)));
935 Node ret = nm->mkNode(kind::INTS_MODULUS, t[1 - i], twok);
936 return RewriteResponse(REWRITE_AGAIN, ret);
937 }
938 }
939 return RewriteResponse(REWRITE_DONE, t);
940 }
941
942 RewriteResponse ArithRewriter::postRewritePow2(TNode t)
943 {
944 Assert(t.getKind() == kind::POW2);
945 NodeManager* nm = NodeManager::currentNM();
946 // if constant, we eliminate
947 if (t[0].isConst())
948 {
949 // pow2 is only supported for integers
950 Assert(t[0].getType().isInteger());
951 Integer i = t[0].getConst<Rational>().getNumerator();
952 if (i < 0)
953 {
954 return RewriteResponse(REWRITE_DONE, rewriter::mkConst(Integer(0)));
955 }
956 // (pow2 t) ---> (pow 2 t) and continue rewriting to eliminate pow
957 Node two = rewriter::mkConst(Integer(2));
958 Node ret = nm->mkNode(kind::POW, two, t[0]);
959 return RewriteResponse(REWRITE_AGAIN, ret);
960 }
961 return RewriteResponse(REWRITE_DONE, t);
962 }
963
964 RewriteResponse ArithRewriter::preRewriteTranscendental(TNode t)
965 {
966 return RewriteResponse(REWRITE_DONE, t);
967 }
968
969 RewriteResponse ArithRewriter::postRewriteTranscendental(TNode t)
970 {
971 Trace("arith-tf-rewrite")
972 << "Rewrite transcendental function : " << t << std::endl;
973 NodeManager* nm = NodeManager::currentNM();
974 switch (t.getKind())
975 {
976 case kind::EXPONENTIAL:
977 {
978 if (t[0].isConst())
979 {
980 Node one = rewriter::mkConst(Integer(1));
981 if (t[0].getConst<Rational>().sgn() >= 0 && t[0].getType().isInteger()
982 && t[0] != one)
983 {
984 return RewriteResponse(
985 REWRITE_AGAIN,
986 nm->mkNode(kind::POW, nm->mkNode(kind::EXPONENTIAL, one), t[0]));
987 }
988 else
989 {
990 return RewriteResponse(REWRITE_DONE, t);
991 }
992 }
993 else if (t[0].getKind() == kind::ADD)
994 {
995 std::vector<Node> product;
996 for (const Node tc : t[0])
997 {
998 product.push_back(nm->mkNode(kind::EXPONENTIAL, tc));
999 }
1000 // We need to do a full rewrite here, since we can get exponentials of
1001 // constants, e.g. when we are rewriting exp(2 + x)
1002 return RewriteResponse(REWRITE_AGAIN_FULL,
1003 nm->mkNode(kind::MULT, product));
1004 }
1005 }
1006 break;
1007 case kind::SINE:
1008 if (t[0].isConst())
1009 {
1010 const Rational& rat = t[0].getConst<Rational>();
1011 if (rat.sgn() == 0)
1012 {
1013 return RewriteResponse(REWRITE_DONE, rewriter::mkConst(Integer(0)));
1014 }
1015 else if (rat.sgn() == -1)
1016 {
1017 Node ret = nm->mkNode(
1018 kind::NEG, nm->mkNode(kind::SINE, rewriter::mkConst(-rat)));
1019 return RewriteResponse(REWRITE_AGAIN_FULL, ret);
1020 }
1021 }
1022 else if ((t[0].getKind() == MULT || t[0].getKind() == NONLINEAR_MULT)
1023 && t[0][0].isConst() && t[0][0].getConst<Rational>().sgn() == -1)
1024 {
1025 // sin(-n*x) ---> -sin(n*x)
1026 std::vector<Node> mchildren(t[0].begin(), t[0].end());
1027 mchildren[0] = nm->mkConstReal(-t[0][0].getConst<Rational>());
1028 Node ret = nm->mkNode(
1029 kind::NEG,
1030 nm->mkNode(kind::SINE, nm->mkNode(t[0].getKind(), mchildren)));
1031 return RewriteResponse(REWRITE_AGAIN_FULL, ret);
1032 }
1033 else
1034 {
1035 // get the factor of PI in the argument
1036 Node pi_factor;
1037 Node pi;
1038 Node rem;
1039 std::map<Node, Node> msum;
1040 if (ArithMSum::getMonomialSum(t[0], msum))
1041 {
1042 pi = mkPi();
1043 std::map<Node, Node>::iterator itm = msum.find(pi);
1044 if (itm != msum.end())
1045 {
1046 if (itm->second.isNull())
1047 {
1048 pi_factor = rewriter::mkConst(Integer(1));
1049 }
1050 else
1051 {
1052 pi_factor = itm->second;
1053 }
1054 msum.erase(pi);
1055 if (!msum.empty())
1056 {
1057 rem = ArithMSum::mkNode(t[0].getType(), msum);
1058 }
1059 }
1060 }
1061 else
1062 {
1063 Assert(false);
1064 }
1065
1066 // if there is a factor of PI
1067 if (!pi_factor.isNull())
1068 {
1069 Trace("arith-tf-rewrite-debug")
1070 << "Process pi factor = " << pi_factor << std::endl;
1071 Rational r = pi_factor.getConst<Rational>();
1072 Rational r_abs = r.abs();
1073 Rational rone = Rational(1);
1074 Rational rtwo = Rational(2);
1075 if (r_abs > rone)
1076 {
1077 // add/substract 2*pi beyond scope
1078 Rational ra_div_two = (r_abs + rone) / rtwo;
1079 Node new_pi_factor;
1080 if (r.sgn() == 1)
1081 {
1082 new_pi_factor = nm->mkConstReal(r - rtwo * ra_div_two.floor());
1083 }
1084 else
1085 {
1086 Assert(r.sgn() == -1);
1087 new_pi_factor = nm->mkConstReal(r + rtwo * ra_div_two.floor());
1088 }
1089 Node new_arg = nm->mkNode(kind::MULT, new_pi_factor, pi);
1090 if (!rem.isNull())
1091 {
1092 new_arg = nm->mkNode(kind::ADD, new_arg, rem);
1093 }
1094 // sin( 2*n*PI + x ) = sin( x )
1095 return RewriteResponse(REWRITE_AGAIN_FULL,
1096 nm->mkNode(kind::SINE, new_arg));
1097 }
1098 else if (r_abs == rone)
1099 {
1100 // sin( PI + x ) = -sin( x )
1101 if (rem.isNull())
1102 {
1103 return RewriteResponse(REWRITE_DONE,
1104 nm->mkConstReal(Rational(0)));
1105 }
1106 else
1107 {
1108 return RewriteResponse(
1109 REWRITE_AGAIN_FULL,
1110 nm->mkNode(kind::NEG, nm->mkNode(kind::SINE, rem)));
1111 }
1112 }
1113 else if (rem.isNull())
1114 {
1115 // other rational cases based on Niven's theorem
1116 // (https://en.wikipedia.org/wiki/Niven%27s_theorem)
1117 Integer one = Integer(1);
1118 Integer two = Integer(2);
1119 Integer six = Integer(6);
1120 if (r_abs.getDenominator() == two)
1121 {
1122 Assert(r_abs.getNumerator() == one);
1123 return RewriteResponse(REWRITE_DONE,
1124 nm->mkConstReal(Rational(r.sgn())));
1125 }
1126 else if (r_abs.getDenominator() == six)
1127 {
1128 Integer five = Integer(5);
1129 if (r_abs.getNumerator() == one || r_abs.getNumerator() == five)
1130 {
1131 return RewriteResponse(
1132 REWRITE_DONE,
1133 nm->mkConstReal(Rational(r.sgn()) / Rational(2)));
1134 }
1135 }
1136 }
1137 }
1138 }
1139 break;
1140 case kind::COSINE:
1141 {
1142 return RewriteResponse(
1143 REWRITE_AGAIN_FULL,
1144 nm->mkNode(
1145 kind::SINE,
1146 nm->mkNode(kind::SUB,
1147 nm->mkNode(kind::MULT,
1148 nm->mkConstReal(Rational(1) / Rational(2)),
1149 mkPi()),
1150 t[0])));
1151 }
1152 break;
1153 case kind::TANGENT:
1154 {
1155 return RewriteResponse(REWRITE_AGAIN_FULL,
1156 nm->mkNode(kind::DIVISION,
1157 nm->mkNode(kind::SINE, t[0]),
1158 nm->mkNode(kind::COSINE, t[0])));
1159 }
1160 break;
1161 case kind::COSECANT:
1162 {
1163 return RewriteResponse(REWRITE_AGAIN_FULL,
1164 nm->mkNode(kind::DIVISION,
1165 nm->mkConstReal(Rational(1)),
1166 nm->mkNode(kind::SINE, t[0])));
1167 }
1168 break;
1169 case kind::SECANT:
1170 {
1171 return RewriteResponse(REWRITE_AGAIN_FULL,
1172 nm->mkNode(kind::DIVISION,
1173 nm->mkConstReal(Rational(1)),
1174 nm->mkNode(kind::COSINE, t[0])));
1175 }
1176 break;
1177 case kind::COTANGENT:
1178 {
1179 return RewriteResponse(REWRITE_AGAIN_FULL,
1180 nm->mkNode(kind::DIVISION,
1181 nm->mkNode(kind::COSINE, t[0]),
1182 nm->mkNode(kind::SINE, t[0])));
1183 }
1184 break;
1185 default: break;
1186 }
1187 return RewriteResponse(REWRITE_DONE, t);
1188 }
1189
1190 TrustNode ArithRewriter::expandDefinition(Node node)
1191 {
1192 // call eliminate operators, to eliminate partial operators only
1193 std::vector<SkolemLemma> lems;
1194 TrustNode ret = d_opElim.eliminate(node, lems, true);
1195 Assert(lems.empty());
1196 return ret;
1197 }
1198
1199 RewriteResponse ArithRewriter::returnRewrite(TNode t, Node ret, Rewrite r)
1200 {
1201 Trace("arith-rewriter") << "ArithRewriter : " << t << " == " << ret << " by "
1202 << r << std::endl;
1203 return RewriteResponse(REWRITE_AGAIN_FULL, ret);
1204 }
1205
1206 } // namespace arith
1207 } // namespace theory
1208 } // namespace cvc5