Refactor rewriteMinus (#7932)
[cvc5.git] / src / theory / arith / arith_rewriter.cpp
1 /******************************************************************************
2 * Top contributors (to current version):
3 * Andrew Reynolds, Tim King, Morgan Deters
4 *
5 * This file is part of the cvc5 project.
6 *
7 * Copyright (c) 2009-2021 by the authors listed in the file AUTHORS
8 * in the top-level source directory and their institutional affiliations.
9 * All rights reserved. See the file COPYING in the top-level source
10 * directory for licensing information.
11 * ****************************************************************************
12 *
13 * Rewriter for the theory of arithmetic.
14 *
15 * Implements the pre- and post-rewrite rules for arithmetic terms (e.g.
16 * MINUS, PLUS, MULT, division, POW, transcendental functions) and atoms.
17 */
18
19 #include "theory/arith/arith_rewriter.h"
20
21 #include <set>
22 #include <sstream>
23 #include <stack>
24 #include <vector>
25
26 #include "smt/logic_exception.h"
27 #include "theory/arith/arith_msum.h"
28 #include "theory/arith/arith_utilities.h"
29 #include "theory/arith/normal_form.h"
30 #include "theory/arith/operator_elim.h"
31 #include "theory/theory.h"
32 #include "util/bitvector.h"
33 #include "util/divisible.h"
34 #include "util/iand.h"
35
36 using namespace cvc5::kind;
37
38 namespace cvc5 {
39 namespace theory {
40 namespace arith {
41
42 ArithRewriter::ArithRewriter(OperatorElim& oe) : d_opElim(oe) {}
43
44 bool ArithRewriter::isAtom(TNode n) {
45 Kind k = n.getKind();
46 return arith::isRelationOperator(k) || k == kind::IS_INTEGER
47 || k == kind::DIVISIBLE;
48 }
49
50 RewriteResponse ArithRewriter::rewriteConstant(TNode t){
51 Assert(t.isConst());
52 Assert(t.getKind() == CONST_RATIONAL || t.getKind() == CONST_INTEGER);
53
54 return RewriteResponse(REWRITE_DONE, t);
55 }
56
57 RewriteResponse ArithRewriter::rewriteVariable(TNode t){
58 Assert(t.isVar());
59
60 return RewriteResponse(REWRITE_DONE, t);
61 }
62
63 RewriteResponse ArithRewriter::rewriteMinus(TNode t)
64 {
65 Assert(t.getKind() == kind::MINUS);
66 Assert(t.getNumChildren() == 2);
67
68 auto* nm = NodeManager::currentNM();
69
70 if (t[0] == t[1])
71 {
72 return RewriteResponse(REWRITE_DONE,
73 nm->mkConstRealOrInt(t.getType(), Rational(0)));
74 }
75 return RewriteResponse(
76 REWRITE_AGAIN_FULL,
77 nm->mkNode(Kind::PLUS, t[0], makeUnaryMinusNode(t[1])));
78 }
79
80 RewriteResponse ArithRewriter::rewriteUMinus(TNode t, bool pre){
81 Assert(t.getKind() == kind::UMINUS);
82
83 if (t[0].isConst())
84 {
85 Rational neg = -(t[0].getConst<Rational>());
86 NodeManager* nm = NodeManager::currentNM();
87 return RewriteResponse(REWRITE_DONE,
88 nm->mkConstRealOrInt(t[0].getType(), neg));
89 }
90
91 Node noUminus = makeUnaryMinusNode(t[0]);
92 if(pre)
93 return RewriteResponse(REWRITE_DONE, noUminus);
94 else
95 return RewriteResponse(REWRITE_AGAIN, noUminus);
96 }
97
98 RewriteResponse ArithRewriter::preRewriteTerm(TNode t){
99 if(t.isConst()){
100 return rewriteConstant(t);
101 }else if(t.isVar()){
102 return rewriteVariable(t);
103 }else{
104 switch(Kind k = t.getKind()){
105 case kind::MINUS: return rewriteMinus(t);
106 case kind::UMINUS: return rewriteUMinus(t, true);
107 case kind::DIVISION:
108 case kind::DIVISION_TOTAL: return rewriteDiv(t, true);
109 case kind::PLUS: return preRewritePlus(t);
110 case kind::MULT:
111 case kind::NONLINEAR_MULT: return preRewriteMult(t);
112 case kind::IAND: return RewriteResponse(REWRITE_DONE, t);
113 case kind::POW2: return RewriteResponse(REWRITE_DONE, t);
114 case kind::EXPONENTIAL:
115 case kind::SINE:
116 case kind::COSINE:
117 case kind::TANGENT:
118 case kind::COSECANT:
119 case kind::SECANT:
120 case kind::COTANGENT:
121 case kind::ARCSINE:
122 case kind::ARCCOSINE:
123 case kind::ARCTANGENT:
124 case kind::ARCCOSECANT:
125 case kind::ARCSECANT:
126 case kind::ARCCOTANGENT:
127 case kind::SQRT: return preRewriteTranscendental(t);
128 case kind::INTS_DIVISION:
129 case kind::INTS_MODULUS: return rewriteIntsDivMod(t, true);
130 case kind::INTS_DIVISION_TOTAL:
131 case kind::INTS_MODULUS_TOTAL: return rewriteIntsDivModTotal(t, true);
132 case kind::ABS:
133 if (t[0].isConst())
134 {
135 const Rational& rat = t[0].getConst<Rational>();
136 if (rat >= 0)
137 {
138 return RewriteResponse(REWRITE_DONE, t[0]);
139 }
140 else
141 {
142 return RewriteResponse(REWRITE_DONE,
143 NodeManager::currentNM()->mkConstRealOrInt(
144 t[0].getType(), -rat));
145 }
146 }
147 return RewriteResponse(REWRITE_DONE, t);
148 case kind::IS_INTEGER:
149 case kind::TO_INTEGER: return RewriteResponse(REWRITE_DONE, t);
150 case kind::TO_REAL:
151 case kind::CAST_TO_REAL: return RewriteResponse(REWRITE_DONE, t[0]);
152 case kind::POW: return RewriteResponse(REWRITE_DONE, t);
153 case kind::PI: return RewriteResponse(REWRITE_DONE, t);
154 default: Unhandled() << k;
155 }
156 }
157 }
158
159 RewriteResponse ArithRewriter::postRewriteTerm(TNode t){
160 if(t.isConst()){
161 return rewriteConstant(t);
162 }else if(t.isVar()){
163 return rewriteVariable(t);
164 }else{
165 Trace("arith-rewriter") << "postRewriteTerm: " << t << std::endl;
166 switch(t.getKind()){
167 case kind::MINUS: return rewriteMinus(t);
168 case kind::UMINUS: return rewriteUMinus(t, false);
169 case kind::DIVISION:
170 case kind::DIVISION_TOTAL: return rewriteDiv(t, false);
171 case kind::PLUS: return postRewritePlus(t);
172 case kind::MULT:
173 case kind::NONLINEAR_MULT: return postRewriteMult(t);
174 case kind::IAND: return postRewriteIAnd(t);
175 case kind::POW2: return postRewritePow2(t);
176 case kind::EXPONENTIAL:
177 case kind::SINE:
178 case kind::COSINE:
179 case kind::TANGENT:
180 case kind::COSECANT:
181 case kind::SECANT:
182 case kind::COTANGENT:
183 case kind::ARCSINE:
184 case kind::ARCCOSINE:
185 case kind::ARCTANGENT:
186 case kind::ARCCOSECANT:
187 case kind::ARCSECANT:
188 case kind::ARCCOTANGENT:
189 case kind::SQRT: return postRewriteTranscendental(t);
190 case kind::INTS_DIVISION:
191 case kind::INTS_MODULUS: return rewriteIntsDivMod(t, false);
192 case kind::INTS_DIVISION_TOTAL:
193 case kind::INTS_MODULUS_TOTAL: return rewriteIntsDivModTotal(t, false);
194 case kind::ABS:
195 if (t[0].isConst())
196 {
197 const Rational& rat = t[0].getConst<Rational>();
198 if (rat >= 0)
199 {
200 return RewriteResponse(REWRITE_DONE, t[0]);
201 }
202 else
203 {
204 return RewriteResponse(REWRITE_DONE,
205 NodeManager::currentNM()->mkConstRealOrInt(
206 t[0].getType(), -rat));
207 }
208 }
209 return RewriteResponse(REWRITE_DONE, t);
210 case kind::TO_REAL:
211 case kind::CAST_TO_REAL: return RewriteResponse(REWRITE_DONE, t[0]);
212 case kind::TO_INTEGER: return rewriteExtIntegerOp(t);
213 case kind::POW:
214 {
215 if (t[1].isConst())
216 {
217 const Rational& exp = t[1].getConst<Rational>();
218 TNode base = t[0];
219 if(exp.sgn() == 0){
220 return RewriteResponse(REWRITE_DONE,
221 NodeManager::currentNM()->mkConstRealOrInt(
222 t.getType(), Rational(1)));
223 }else if(exp.sgn() > 0 && exp.isIntegral()){
224 cvc5::Rational r(expr::NodeValue::MAX_CHILDREN);
225 if (exp <= r)
226 {
227 unsigned num = exp.getNumerator().toUnsignedInt();
228 if( num==1 ){
229 return RewriteResponse(REWRITE_AGAIN, base);
230 }else{
231 NodeBuilder nb(kind::MULT);
232 for(unsigned i=0; i < num; ++i){
233 nb << base;
234 }
235 Assert(nb.getNumChildren() > 0);
236 Node mult = nb;
237 return RewriteResponse(REWRITE_AGAIN, mult);
238 }
239 }
240 }
241 }
242 else if (t[0].isConst()
243 && t[0].getConst<Rational>().getNumerator().toUnsignedInt()
244 == 2)
245 {
246 return RewriteResponse(
247 REWRITE_DONE, NodeManager::currentNM()->mkNode(kind::POW2, t[1]));
248 }
249
250 // Todo improve the exception thrown
251 std::stringstream ss;
252 ss << "The exponent of the POW(^) operator can only be a positive "
253 "integral constant below "
254 << (expr::NodeValue::MAX_CHILDREN + 1) << ". ";
255 ss << "Exception occurred in:" << std::endl;
256 ss << " " << t;
257 throw LogicException(ss.str());
258 }
259 case kind::PI:
260 return RewriteResponse(REWRITE_DONE, t);
261 default:
262 Unreachable();
263 }
264 }
265 }
266
267
268 RewriteResponse ArithRewriter::preRewriteMult(TNode t){
269 Assert(t.getKind() == kind::MULT || t.getKind() == kind::NONLINEAR_MULT);
270
271 if(t.getNumChildren() == 2){
272 if (t[0].isConst() && t[0].getConst<Rational>().isOne())
273 {
274 return RewriteResponse(REWRITE_DONE, t[1]);
275 }
276 if (t[1].isConst() && t[1].getConst<Rational>().isOne())
277 {
278 return RewriteResponse(REWRITE_DONE, t[0]);
279 }
280 }
281
282 // Rewrite multiplications with a 0 argument and to 0
283 for(TNode::iterator i = t.begin(); i != t.end(); ++i) {
284 if ((*i).isConst())
285 {
286 if((*i).getConst<Rational>().isZero()) {
287 TNode zero = (*i);
288 return RewriteResponse(REWRITE_DONE, zero);
289 }
290 }
291 }
292 return RewriteResponse(REWRITE_DONE, t);
293 }
294
295 static bool canFlatten(Kind k, TNode t){
296 for(TNode::iterator i = t.begin(); i != t.end(); ++i) {
297 TNode child = *i;
298 if(child.getKind() == k){
299 return true;
300 }
301 }
302 return false;
303 }
304
305 static void flatten(std::vector<TNode>& pb, Kind k, TNode t){
306 if(t.getKind() == k){
307 for(TNode::iterator i = t.begin(); i != t.end(); ++i) {
308 TNode child = *i;
309 if(child.getKind() == k){
310 flatten(pb, k, child);
311 }else{
312 pb.push_back(child);
313 }
314 }
315 }else{
316 pb.push_back(t);
317 }
318 }
319
320 static Node flatten(Kind k, TNode t){
321 std::vector<TNode> pb;
322 flatten(pb, k, t);
323 Assert(pb.size() >= 2);
324 return NodeManager::currentNM()->mkNode(k, pb);
325 }
326
327 RewriteResponse ArithRewriter::preRewritePlus(TNode t){
328 Assert(t.getKind() == kind::PLUS);
329
330 if(canFlatten(kind::PLUS, t)){
331 return RewriteResponse(REWRITE_DONE, flatten(kind::PLUS, t));
332 }else{
333 return RewriteResponse(REWRITE_DONE, t);
334 }
335 }
336
337 RewriteResponse ArithRewriter::postRewritePlus(TNode t){
338 Assert(t.getKind() == kind::PLUS);
339
340 std::vector<Monomial> monomials;
341 std::vector<Polynomial> polynomials;
342
343 for(TNode::iterator i = t.begin(), end = t.end(); i != end; ++i){
344 TNode curr = *i;
345 if(Monomial::isMember(curr)){
346 monomials.push_back(Monomial::parseMonomial(curr));
347 }else{
348 polynomials.push_back(Polynomial::parsePolynomial(curr));
349 }
350 }
351
352 if(!monomials.empty()){
353 Monomial::sort(monomials);
354 Monomial::combineAdjacentMonomials(monomials);
355 polynomials.push_back(Polynomial::mkPolynomial(monomials));
356 }
357
358 Polynomial res = Polynomial::sumPolynomials(polynomials);
359
360 return RewriteResponse(REWRITE_DONE, res.getNode());
361 }
362
363 RewriteResponse ArithRewriter::postRewriteMult(TNode t){
364 Assert(t.getKind() == kind::MULT || t.getKind() == kind::NONLINEAR_MULT);
365
366 Polynomial res = Polynomial::mkOne();
367
368 for(TNode::iterator i = t.begin(), end = t.end(); i != end; ++i){
369 Node curr = *i;
370 Polynomial currPoly = Polynomial::parsePolynomial(curr);
371
372 res = res * currPoly;
373 }
374
375 return RewriteResponse(REWRITE_DONE, res.getNode());
376 }
377
378 RewriteResponse ArithRewriter::postRewritePow2(TNode t)
379 {
380 Assert(t.getKind() == kind::POW2);
381 NodeManager* nm = NodeManager::currentNM();
382 // if constant, we eliminate
383 if (t[0].isConst())
384 {
385 // pow2 is only supported for integers
386 Assert(t[0].getType().isInteger());
387 Integer i = t[0].getConst<Rational>().getNumerator();
388 if (i < 0)
389 {
390 return RewriteResponse(REWRITE_DONE, nm->mkConstInt(Rational(0)));
391 }
392 // (pow2 t) ---> (pow 2 t) and continue rewriting to eliminate pow
393 Node two = nm->mkConstInt(Rational(Integer(2)));
394 Node ret = nm->mkNode(kind::POW, two, t[0]);
395 return RewriteResponse(REWRITE_AGAIN, ret);
396 }
397 return RewriteResponse(REWRITE_DONE, t);
398 }
399
400 RewriteResponse ArithRewriter::postRewriteIAnd(TNode t)
401 {
402 Assert(t.getKind() == kind::IAND);
403 size_t bsize = t.getOperator().getConst<IntAnd>().d_size;
404 NodeManager* nm = NodeManager::currentNM();
405 // if constant, we eliminate
406 if (t[0].isConst() && t[1].isConst())
407 {
408 Node iToBvop = nm->mkConst(IntToBitVector(bsize));
409 Node arg1 = nm->mkNode(kind::INT_TO_BITVECTOR, iToBvop, t[0]);
410 Node arg2 = nm->mkNode(kind::INT_TO_BITVECTOR, iToBvop, t[1]);
411 Node bvand = nm->mkNode(kind::BITVECTOR_AND, arg1, arg2);
412 Node ret = nm->mkNode(kind::BITVECTOR_TO_NAT, bvand);
413 return RewriteResponse(REWRITE_AGAIN_FULL, ret);
414 }
415 else if (t[0] > t[1])
416 {
417 // ((_ iand k) x y) ---> ((_ iand k) y x) if x > y by node ordering
418 Node ret = nm->mkNode(kind::IAND, t.getOperator(), t[1], t[0]);
419 return RewriteResponse(REWRITE_AGAIN, ret);
420 }
421 else if (t[0] == t[1])
422 {
423 // ((_ iand k) x x) ---> x
424 return RewriteResponse(REWRITE_DONE, t[0]);
425 }
426 // simplifications involving constants
427 for (unsigned i = 0; i < 2; i++)
428 {
429 if (!t[i].isConst())
430 {
431 continue;
432 }
433 if (t[i].getConst<Rational>().sgn() == 0)
434 {
435 // ((_ iand k) 0 y) ---> 0
436 return RewriteResponse(REWRITE_DONE, t[i]);
437 }
438 if (t[i].getConst<Rational>().getNumerator() == Integer(2).pow(bsize) - 1)
439 {
440 // ((_ iand k) 111...1 y) ---> y
441 return RewriteResponse(REWRITE_DONE, t[i == 0 ? 1 : 0]);
442 }
443 }
444 return RewriteResponse(REWRITE_DONE, t);
445 }
446
447 RewriteResponse ArithRewriter::preRewriteTranscendental(TNode t) {
448 return RewriteResponse(REWRITE_DONE, t);
449 }
450
451 RewriteResponse ArithRewriter::postRewriteTranscendental(TNode t) {
452 Trace("arith-tf-rewrite") << "Rewrite transcendental function : " << t << std::endl;
453 NodeManager* nm = NodeManager::currentNM();
454 switch( t.getKind() ){
455 case kind::EXPONENTIAL: {
456 if (t[0].isConst())
457 {
458 Node one = nm->mkConstReal(Rational(1));
459 if(t[0].getConst<Rational>().sgn()>=0 && t[0].getType().isInteger() && t[0]!=one){
460 return RewriteResponse(
461 REWRITE_AGAIN,
462 nm->mkNode(kind::POW, nm->mkNode(kind::EXPONENTIAL, one), t[0]));
463 }else{
464 return RewriteResponse(REWRITE_DONE, t);
465 }
466 }
467 else if (t[0].getKind() == kind::PLUS)
468 {
469 std::vector<Node> product;
470 for (const Node tc : t[0])
471 {
472 product.push_back(nm->mkNode(kind::EXPONENTIAL, tc));
473 }
474 // We need to do a full rewrite here, since we can get exponentials of
475 // constants, e.g. when we are rewriting exp(2 + x)
476 return RewriteResponse(REWRITE_AGAIN_FULL,
477 nm->mkNode(kind::MULT, product));
478 }
479 }
480 break;
481 case kind::SINE:
482 if (t[0].isConst())
483 {
484 const Rational& rat = t[0].getConst<Rational>();
485 if(rat.sgn() == 0){
486 return RewriteResponse(REWRITE_DONE, nm->mkConstReal(Rational(0)));
487 }
488 else if (rat.sgn() == -1)
489 {
490 Node ret = nm->mkNode(kind::UMINUS,
491 nm->mkNode(kind::SINE, nm->mkConstReal(-rat)));
492 return RewriteResponse(REWRITE_AGAIN_FULL, ret);
493 }
494 }else{
495 // get the factor of PI in the argument
496 Node pi_factor;
497 Node pi;
498 Node rem;
499 std::map<Node, Node> msum;
500 if (ArithMSum::getMonomialSum(t[0], msum))
501 {
502 pi = mkPi();
503 std::map<Node, Node>::iterator itm = msum.find(pi);
504 if (itm != msum.end())
505 {
506 if (itm->second.isNull())
507 {
508 pi_factor = nm->mkConstReal(Rational(1));
509 }
510 else
511 {
512 pi_factor = itm->second;
513 }
514 msum.erase(pi);
515 if (!msum.empty())
516 {
517 rem = ArithMSum::mkNode(t[0].getType(), msum);
518 }
519 }
520 }
521 else
522 {
523 Assert(false);
524 }
525
526 // if there is a factor of PI
527 if( !pi_factor.isNull() ){
528 Trace("arith-tf-rewrite-debug") << "Process pi factor = " << pi_factor << std::endl;
529 Rational r = pi_factor.getConst<Rational>();
530 Rational r_abs = r.abs();
531 Rational rone = Rational(1);
532 Node ntwo = nm->mkConstInt(Rational(2));
533 if (r_abs > rone)
534 {
535 //add/substract 2*pi beyond scope
536 Node ra_div_two = nm->mkNode(
537 kind::INTS_DIVISION, mkRationalNode(r_abs + rone), ntwo);
538 Node new_pi_factor;
539 if( r.sgn()==1 ){
540 new_pi_factor =
541 nm->mkNode(kind::MINUS,
542 pi_factor,
543 nm->mkNode(kind::MULT, ntwo, ra_div_two));
544 }else{
545 Assert(r.sgn() == -1);
546 new_pi_factor =
547 nm->mkNode(kind::PLUS,
548 pi_factor,
549 nm->mkNode(kind::MULT, ntwo, ra_div_two));
550 }
551 Node new_arg = nm->mkNode(kind::MULT, new_pi_factor, pi);
552 if (!rem.isNull())
553 {
554 new_arg = nm->mkNode(kind::PLUS, new_arg, rem);
555 }
556 // sin( 2*n*PI + x ) = sin( x )
557 return RewriteResponse(REWRITE_AGAIN_FULL,
558 nm->mkNode(kind::SINE, new_arg));
559 }
560 else if (r_abs == rone)
561 {
562 // sin( PI + x ) = -sin( x )
563 if (rem.isNull())
564 {
565 return RewriteResponse(REWRITE_DONE, nm->mkConstReal(Rational(0)));
566 }
567 else
568 {
569 return RewriteResponse(
570 REWRITE_AGAIN_FULL,
571 nm->mkNode(kind::UMINUS, nm->mkNode(kind::SINE, rem)));
572 }
573 }
574 else if (rem.isNull())
575 {
576 // other rational cases based on Niven's theorem
577 // (https://en.wikipedia.org/wiki/Niven%27s_theorem)
578 Integer one = Integer(1);
579 Integer two = Integer(2);
580 Integer six = Integer(6);
581 if (r_abs.getDenominator() == two)
582 {
583 Assert(r_abs.getNumerator() == one);
584 return RewriteResponse(REWRITE_DONE,
585 nm->mkConstReal(Rational(r.sgn())));
586 }
587 else if (r_abs.getDenominator() == six)
588 {
589 Integer five = Integer(5);
590 if (r_abs.getNumerator() == one || r_abs.getNumerator() == five)
591 {
592 return RewriteResponse(
593 REWRITE_DONE,
594 nm->mkConstReal(Rational(r.sgn()) / Rational(2)));
595 }
596 }
597 }
598 }
599 }
600 break;
601 case kind::COSINE: {
602 return RewriteResponse(
603 REWRITE_AGAIN_FULL,
604 nm->mkNode(
605 kind::SINE,
606 nm->mkNode(kind::MINUS,
607 nm->mkNode(kind::MULT,
608 nm->mkConstReal(Rational(1) / Rational(2)),
609 mkPi()),
610 t[0])));
611 }
612 break;
613 case kind::TANGENT:
614 {
615 return RewriteResponse(REWRITE_AGAIN_FULL,
616 nm->mkNode(kind::DIVISION,
617 nm->mkNode(kind::SINE, t[0]),
618 nm->mkNode(kind::COSINE, t[0])));
619 }
620 break;
621 case kind::COSECANT:
622 {
623 return RewriteResponse(REWRITE_AGAIN_FULL,
624 nm->mkNode(kind::DIVISION,
625 nm->mkConstReal(Rational(1)),
626 nm->mkNode(kind::SINE, t[0])));
627 }
628 break;
629 case kind::SECANT:
630 {
631 return RewriteResponse(REWRITE_AGAIN_FULL,
632 nm->mkNode(kind::DIVISION,
633 nm->mkConstReal(Rational(1)),
634 nm->mkNode(kind::COSINE, t[0])));
635 }
636 break;
637 case kind::COTANGENT:
638 {
639 return RewriteResponse(REWRITE_AGAIN_FULL,
640 nm->mkNode(kind::DIVISION,
641 nm->mkNode(kind::COSINE, t[0]),
642 nm->mkNode(kind::SINE, t[0])));
643 }
644 break;
645 default:
646 break;
647 }
648 return RewriteResponse(REWRITE_DONE, t);
649 }
650
651 RewriteResponse ArithRewriter::postRewriteAtom(TNode atom){
652 if(atom.getKind() == kind::IS_INTEGER) {
653 return rewriteExtIntegerOp(atom);
654 } else if(atom.getKind() == kind::DIVISIBLE) {
655 if(atom[0].isConst()) {
656 return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(bool((atom[0].getConst<Rational>() / atom.getOperator().getConst<Divisible>().k).isIntegral())));
657 }
658 if(atom.getOperator().getConst<Divisible>().k.isOne()) {
659 return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(true));
660 }
661 NodeManager* nm = NodeManager::currentNM();
662 return RewriteResponse(
663 REWRITE_AGAIN,
664 nm->mkNode(kind::EQUAL,
665 nm->mkNode(kind::INTS_MODULUS_TOTAL,
666 atom[0],
667 nm->mkConstInt(Rational(
668 atom.getOperator().getConst<Divisible>().k))),
669 nm->mkConstInt(Rational(0))));
670 }
671
672 // left |><| right
673 TNode left = atom[0];
674 TNode right = atom[1];
675
676 Polynomial pleft = Polynomial::parsePolynomial(left);
677 Polynomial pright = Polynomial::parsePolynomial(right);
678
679 Debug("arith::rewriter") << "pleft " << pleft.getNode() << std::endl;
680 Debug("arith::rewriter") << "pright " << pright.getNode() << std::endl;
681
682 Comparison cmp = Comparison::mkComparison(atom.getKind(), pleft, pright);
683 Assert(cmp.isNormalForm());
684 return RewriteResponse(REWRITE_DONE, cmp.getNode());
685 }
686
687 RewriteResponse ArithRewriter::preRewriteAtom(TNode atom){
688 Assert(isAtom(atom));
689
690 NodeManager* currNM = NodeManager::currentNM();
691
692 if(atom.getKind() == kind::EQUAL) {
693 if(atom[0] == atom[1]) {
694 return RewriteResponse(REWRITE_DONE, currNM->mkConst(true));
695 }
696 }else if(atom.getKind() == kind::GT){
697 Node leq = currNM->mkNode(kind::LEQ, atom[0], atom[1]);
698 return RewriteResponse(REWRITE_DONE, currNM->mkNode(kind::NOT, leq));
699 }else if(atom.getKind() == kind::LT){
700 Node geq = currNM->mkNode(kind::GEQ, atom[0], atom[1]);
701 return RewriteResponse(REWRITE_DONE, currNM->mkNode(kind::NOT, geq));
702 }else if(atom.getKind() == kind::IS_INTEGER){
703 if(atom[0].getType().isInteger()){
704 return RewriteResponse(REWRITE_DONE, currNM->mkConst(true));
705 }
706 }else if(atom.getKind() == kind::DIVISIBLE){
707 if(atom.getOperator().getConst<Divisible>().k.isOne()){
708 return RewriteResponse(REWRITE_DONE, currNM->mkConst(true));
709 }
710 }
711
712 return RewriteResponse(REWRITE_DONE, atom);
713 }
714
715 RewriteResponse ArithRewriter::postRewrite(TNode t){
716 if(isTerm(t)){
717 RewriteResponse response = postRewriteTerm(t);
718 if (Debug.isOn("arith::rewriter") && response.d_status == REWRITE_DONE)
719 {
720 Polynomial::parsePolynomial(response.d_node);
721 }
722 return response;
723 }else if(isAtom(t)){
724 RewriteResponse response = postRewriteAtom(t);
725 if (Debug.isOn("arith::rewriter") && response.d_status == REWRITE_DONE)
726 {
727 Comparison::parseNormalForm(response.d_node);
728 }
729 return response;
730 }else{
731 Unreachable();
732 }
733 }
734
735 RewriteResponse ArithRewriter::preRewrite(TNode t){
736 if(isTerm(t)){
737 return preRewriteTerm(t);
738 }else if(isAtom(t)){
739 return preRewriteAtom(t);
740 }else{
741 Unreachable();
742 }
743 }
744
745 Node ArithRewriter::makeUnaryMinusNode(TNode n){
746 NodeManager* nm = NodeManager::currentNM();
747 Rational qNegOne(-1);
748 return nm->mkNode(kind::MULT, nm->mkConstRealOrInt(n.getType(), qNegOne), n);
749 }
750
751 RewriteResponse ArithRewriter::rewriteDiv(TNode t, bool pre){
752 Assert(t.getKind() == kind::DIVISION_TOTAL || t.getKind() == kind::DIVISION);
753
754 Node left = t[0];
755 Node right = t[1];
756 if (right.isConst())
757 {
758 NodeManager* nm = NodeManager::currentNM();
759 const Rational& den = right.getConst<Rational>();
760
761 if(den.isZero()){
762 if(t.getKind() == kind::DIVISION_TOTAL){
763 return RewriteResponse(REWRITE_DONE, nm->mkConstReal(0));
764 }else{
765 // This is unsupported, but this is not a good place to complain
766 return RewriteResponse(REWRITE_DONE, t);
767 }
768 }
769 Assert(den != Rational(0));
770
771 if (left.isConst())
772 {
773 const Rational& num = left.getConst<Rational>();
774 Rational div = num / den;
775 Node result = nm->mkConstReal(div);
776 return RewriteResponse(REWRITE_DONE, result);
777 }
778
779 Rational div = den.inverse();
780
781 Node result = nm->mkConstReal(div);
782
783 Node mult = NodeManager::currentNM()->mkNode(kind::MULT,left,result);
784 if(pre){
785 return RewriteResponse(REWRITE_DONE, mult);
786 }else{
787 return RewriteResponse(REWRITE_AGAIN, mult);
788 }
789 }
790 return RewriteResponse(REWRITE_DONE, t);
791 }
792
793 RewriteResponse ArithRewriter::rewriteIntsDivMod(TNode t, bool pre)
794 {
795 NodeManager* nm = NodeManager::currentNM();
796 Kind k = t.getKind();
797 if (k == kind::INTS_MODULUS)
798 {
799 if (t[1].isConst() && !t[1].getConst<Rational>().isZero())
800 {
801 // can immediately replace by INTS_MODULUS_TOTAL
802 Node ret = nm->mkNode(kind::INTS_MODULUS_TOTAL, t[0], t[1]);
803 return returnRewrite(t, ret, Rewrite::MOD_TOTAL_BY_CONST);
804 }
805 }
806 if (k == kind::INTS_DIVISION)
807 {
808 if (t[1].isConst() && !t[1].getConst<Rational>().isZero())
809 {
810 // can immediately replace by INTS_DIVISION_TOTAL
811 Node ret = nm->mkNode(kind::INTS_DIVISION_TOTAL, t[0], t[1]);
812 return returnRewrite(t, ret, Rewrite::DIV_TOTAL_BY_CONST);
813 }
814 }
815 return RewriteResponse(REWRITE_DONE, t);
816 }
817
818 RewriteResponse ArithRewriter::rewriteExtIntegerOp(TNode t)
819 {
820 Assert(t.getKind() == kind::TO_INTEGER || t.getKind() == kind::IS_INTEGER);
821 bool isPred = t.getKind() == kind::IS_INTEGER;
822 NodeManager* nm = NodeManager::currentNM();
823 if (t[0].isConst())
824 {
825 Node ret;
826 if (isPred)
827 {
828 ret = nm->mkConst(t[0].getConst<Rational>().isIntegral());
829 }
830 else
831 {
832 ret = nm->mkConstInt(Rational(t[0].getConst<Rational>().floor()));
833 }
834 return returnRewrite(t, ret, Rewrite::INT_EXT_CONST);
835 }
836 if (t[0].getType().isInteger())
837 {
838 Node ret = isPred ? nm->mkConst(true) : Node(t[0]);
839 return returnRewrite(t, ret, Rewrite::INT_EXT_INT);
840 }
841 if (t[0].getKind() == kind::PI)
842 {
843 Node ret = isPred ? nm->mkConst(false) : nm->mkConstReal(Rational(3));
844 return returnRewrite(t, ret, Rewrite::INT_EXT_PI);
845 }
846 return RewriteResponse(REWRITE_DONE, t);
847 }
848
849 RewriteResponse ArithRewriter::rewriteIntsDivModTotal(TNode t, bool pre)
850 {
851 if (pre)
852 {
853 // do not rewrite at prewrite.
854 return RewriteResponse(REWRITE_DONE, t);
855 }
856 NodeManager* nm = NodeManager::currentNM();
857 Kind k = t.getKind();
858 Assert(k == kind::INTS_MODULUS_TOTAL || k == kind::INTS_DIVISION_TOTAL);
859 TNode n = t[0];
860 TNode d = t[1];
861 bool dIsConstant = d.isConst();
862 if(dIsConstant && d.getConst<Rational>().isZero()){
863 // (div x 0) ---> 0 or (mod x 0) ---> 0
864 return returnRewrite(t, nm->mkConstInt(0), Rewrite::DIV_MOD_BY_ZERO);
865 }else if(dIsConstant && d.getConst<Rational>().isOne()){
866 if (k == kind::INTS_MODULUS_TOTAL)
867 {
868 // (mod x 1) --> 0
869 return returnRewrite(t, nm->mkConstInt(0), Rewrite::MOD_BY_ONE);
870 }
871 Assert(k == kind::INTS_DIVISION_TOTAL);
872 // (div x 1) --> x
873 return returnRewrite(t, n, Rewrite::DIV_BY_ONE);
874 }
875 else if (dIsConstant && d.getConst<Rational>().sgn() < 0)
876 {
877 // pull negation
878 // (div x (- c)) ---> (- (div x c))
879 // (mod x (- c)) ---> (mod x c)
880 Node nn = nm->mkNode(k, t[0], nm->mkConstInt(-t[1].getConst<Rational>()));
881 Node ret = (k == kind::INTS_DIVISION || k == kind::INTS_DIVISION_TOTAL)
882 ? nm->mkNode(kind::UMINUS, nn)
883 : nn;
884 return returnRewrite(t, ret, Rewrite::DIV_MOD_PULL_NEG_DEN);
885 }
886 else if (dIsConstant && n.isConst())
887 {
888 Assert(d.getConst<Rational>().isIntegral());
889 Assert(n.getConst<Rational>().isIntegral());
890 Assert(!d.getConst<Rational>().isZero());
891 Integer di = d.getConst<Rational>().getNumerator();
892 Integer ni = n.getConst<Rational>().getNumerator();
893
894 bool isDiv = (k == kind::INTS_DIVISION || k == kind::INTS_DIVISION_TOTAL);
895
896 Integer result = isDiv ? ni.euclidianDivideQuotient(di) : ni.euclidianDivideRemainder(di);
897
898 // constant evaluation
899 // (mod c1 c2) ---> c3 or (div c1 c2) ---> c3
900 Node resultNode = nm->mkConstInt(Rational(result));
901 return returnRewrite(t, resultNode, Rewrite::CONST_EVAL);
902 }
903 if (k == kind::INTS_MODULUS_TOTAL)
904 {
905 // Note these rewrites do not need to account for modulus by zero as being
906 // a UF, which is handled by the reduction of INTS_MODULUS.
907 Kind k0 = t[0].getKind();
908 if (k0 == kind::INTS_MODULUS_TOTAL && t[0][1] == t[1])
909 {
910 // (mod (mod x c) c) --> (mod x c)
911 return returnRewrite(t, t[0], Rewrite::MOD_OVER_MOD);
912 }
913 else if (k0 == kind::NONLINEAR_MULT || k0 == kind::MULT || k0 == kind::PLUS)
914 {
915 // can drop all
916 std::vector<Node> newChildren;
917 bool childChanged = false;
918 for (const Node& tc : t[0])
919 {
920 if (tc.getKind() == kind::INTS_MODULUS_TOTAL && tc[1] == t[1])
921 {
922 newChildren.push_back(tc[0]);
923 childChanged = true;
924 continue;
925 }
926 newChildren.push_back(tc);
927 }
928 if (childChanged)
929 {
930 // (mod (op ... (mod x c) ...) c) ---> (mod (op ... x ...) c) where
931 // op is one of { NONLINEAR_MULT, MULT, PLUS }.
932 Node ret = nm->mkNode(k0, newChildren);
933 ret = nm->mkNode(kind::INTS_MODULUS_TOTAL, ret, t[1]);
934 return returnRewrite(t, ret, Rewrite::MOD_CHILD_MOD);
935 }
936 }
937 }
938 else
939 {
940 Assert(k == kind::INTS_DIVISION_TOTAL);
941 // Note these rewrites do not need to account for division by zero as being
942 // a UF, which is handled by the reduction of INTS_DIVISION.
943 if (t[0].getKind() == kind::INTS_MODULUS_TOTAL && t[0][1] == t[1])
944 {
945 // (div (mod x c) c) --> 0
946 Node ret = nm->mkConstInt(0);
947 return returnRewrite(t, ret, Rewrite::DIV_OVER_MOD);
948 }
949 }
950 return RewriteResponse(REWRITE_DONE, t);
951 }
952
953 TrustNode ArithRewriter::expandDefinition(Node node)
954 {
955 // call eliminate operators, to eliminate partial operators only
956 std::vector<SkolemLemma> lems;
957 TrustNode ret = d_opElim.eliminate(node, lems, true);
958 Assert(lems.empty());
959 return ret;
960 }
961
962 RewriteResponse ArithRewriter::returnRewrite(TNode t, Node ret, Rewrite r)
963 {
964 Trace("arith-rewrite") << "ArithRewriter : " << t << " == " << ret << " by "
965 << r << std::endl;
966 return RewriteResponse(REWRITE_AGAIN_FULL, ret);
967 }
968
969 } // namespace arith
970 } // namespace theory
971 } // namespace cvc5