pow2: more implementations (#6756)
[cvc5.git] / src / theory / arith / arith_rewriter.cpp
1 /******************************************************************************
2 * Top contributors (to current version):
3 * Andrew Reynolds, Tim King, Morgan Deters
4 *
5 * This file is part of the cvc5 project.
6 *
7 * Copyright (c) 2009-2021 by the authors listed in the file AUTHORS
8 * in the top-level source directory and their institutional affiliations.
9 * All rights reserved. See the file COPYING in the top-level source
10 * directory for licensing information.
11 * ****************************************************************************
12 *
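 * Rewriter for the theory of arithmetic.
 *
 * Implements pre- and post-rewriting of arithmetic terms and atoms
 * (relations, IS_INTEGER, DIVISIBLE) into the polynomial / comparison normal
 * form provided by theory/arith/normal_form.h.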
17 */
18
19 #include "theory/arith/arith_rewriter.h"
20
21 #include <set>
22 #include <sstream>
23 #include <stack>
24 #include <vector>
25
26 #include "smt/logic_exception.h"
27 #include "theory/arith/arith_msum.h"
28 #include "theory/arith/arith_utilities.h"
29 #include "theory/arith/normal_form.h"
30 #include "theory/arith/operator_elim.h"
31 #include "theory/theory.h"
32 #include "util/bitvector.h"
33 #include "util/divisible.h"
34 #include "util/iand.h"
35
36 namespace cvc5 {
37 namespace theory {
38 namespace arith {
39
40 ArithRewriter::ArithRewriter(OperatorElim& oe) : d_opElim(oe) {}
41
42 bool ArithRewriter::isAtom(TNode n) {
43 Kind k = n.getKind();
44 return arith::isRelationOperator(k) || k == kind::IS_INTEGER
45 || k == kind::DIVISIBLE;
46 }
47
48 RewriteResponse ArithRewriter::rewriteConstant(TNode t){
49 Assert(t.isConst());
50 Assert(t.getKind() == kind::CONST_RATIONAL);
51
52 return RewriteResponse(REWRITE_DONE, t);
53 }
54
55 RewriteResponse ArithRewriter::rewriteVariable(TNode t){
56 Assert(t.isVar());
57
58 return RewriteResponse(REWRITE_DONE, t);
59 }
60
61 RewriteResponse ArithRewriter::rewriteMinus(TNode t, bool pre){
62 Assert(t.getKind() == kind::MINUS);
63
64 if(pre){
65 if(t[0] == t[1]){
66 Rational zero(0);
67 Node zeroNode = mkRationalNode(zero);
68 return RewriteResponse(REWRITE_DONE, zeroNode);
69 }else{
70 Node noMinus = makeSubtractionNode(t[0],t[1]);
71 return RewriteResponse(REWRITE_DONE, noMinus);
72 }
73 }else{
74 Polynomial minuend = Polynomial::parsePolynomial(t[0]);
75 Polynomial subtrahend = Polynomial::parsePolynomial(t[1]);
76 Polynomial diff = minuend - subtrahend;
77 return RewriteResponse(REWRITE_DONE, diff.getNode());
78 }
79 }
80
81 RewriteResponse ArithRewriter::rewriteUMinus(TNode t, bool pre){
82 Assert(t.getKind() == kind::UMINUS);
83
84 if(t[0].getKind() == kind::CONST_RATIONAL){
85 Rational neg = -(t[0].getConst<Rational>());
86 return RewriteResponse(REWRITE_DONE, mkRationalNode(neg));
87 }
88
89 Node noUminus = makeUnaryMinusNode(t[0]);
90 if(pre)
91 return RewriteResponse(REWRITE_DONE, noUminus);
92 else
93 return RewriteResponse(REWRITE_AGAIN, noUminus);
94 }
95
96 RewriteResponse ArithRewriter::preRewriteTerm(TNode t){
97 if(t.isConst()){
98 return rewriteConstant(t);
99 }else if(t.isVar()){
100 return rewriteVariable(t);
101 }else{
102 switch(Kind k = t.getKind()){
103 case kind::MINUS:
104 return rewriteMinus(t, true);
105 case kind::UMINUS:
106 return rewriteUMinus(t, true);
107 case kind::DIVISION:
108 case kind::DIVISION_TOTAL:
109 return rewriteDiv(t,true);
110 case kind::PLUS:
111 return preRewritePlus(t);
112 case kind::MULT:
113 case kind::NONLINEAR_MULT: return preRewriteMult(t);
114 case kind::IAND: return RewriteResponse(REWRITE_DONE, t);
115 case kind::POW2: return RewriteResponse(REWRITE_DONE, t);
116 case kind::EXPONENTIAL:
117 case kind::SINE:
118 case kind::COSINE:
119 case kind::TANGENT:
120 case kind::COSECANT:
121 case kind::SECANT:
122 case kind::COTANGENT:
123 case kind::ARCSINE:
124 case kind::ARCCOSINE:
125 case kind::ARCTANGENT:
126 case kind::ARCCOSECANT:
127 case kind::ARCSECANT:
128 case kind::ARCCOTANGENT:
129 case kind::SQRT: return preRewriteTranscendental(t);
130 case kind::INTS_DIVISION:
131 case kind::INTS_MODULUS: return rewriteIntsDivMod(t, true);
132 case kind::INTS_DIVISION_TOTAL:
133 case kind::INTS_MODULUS_TOTAL:
134 return rewriteIntsDivModTotal(t,true);
135 case kind::ABS:
136 if(t[0].isConst()) {
137 const Rational& rat = t[0].getConst<Rational>();
138 if(rat >= 0) {
139 return RewriteResponse(REWRITE_DONE, t[0]);
140 } else {
141 return RewriteResponse(REWRITE_DONE,
142 NodeManager::currentNM()->mkConst(-rat));
143 }
144 }
145 return RewriteResponse(REWRITE_DONE, t);
146 case kind::IS_INTEGER:
147 case kind::TO_INTEGER:
148 return RewriteResponse(REWRITE_DONE, t);
149 case kind::TO_REAL:
150 case kind::CAST_TO_REAL: return RewriteResponse(REWRITE_DONE, t[0]);
151 case kind::POW:
152 return RewriteResponse(REWRITE_DONE, t);
153 case kind::PI:
154 return RewriteResponse(REWRITE_DONE, t);
155 default: Unhandled() << k;
156 }
157 }
158 }
159
/**
 * Post-rewrites an arithmetic term (a node that is not an atom).
 *
 * Constants and variables are handled by their identity rewrites. Most
 * kinds are delegated to dedicated helpers. The remaining cases are handled
 * inline:
 *  - ABS of a constant is evaluated;
 *  - TO_REAL / CAST_TO_REAL are dropped;
 *  - TO_INTEGER of a constant is its floor, and of an integer-typed term is
 *    the term itself;
 *  - IS_INTEGER of a constant or of an integer-typed term is evaluated;
 *  - POW with a constant exponent is expanded (see the POW case); any other
 *    POW throws a LogicException.
 * Reaching an unexpected kind is Unreachable().
 */
RewriteResponse ArithRewriter::postRewriteTerm(TNode t){
  if(t.isConst()){
    return rewriteConstant(t);
  }else if(t.isVar()){
    return rewriteVariable(t);
  }else{
    switch(t.getKind()){
    case kind::MINUS:
      return rewriteMinus(t, false);
    case kind::UMINUS:
      return rewriteUMinus(t, false);
    case kind::DIVISION:
    case kind::DIVISION_TOTAL:
      return rewriteDiv(t, false);
    case kind::PLUS:
      return postRewritePlus(t);
    case kind::MULT:
    case kind::NONLINEAR_MULT: return postRewriteMult(t);
    case kind::IAND: return postRewriteIAnd(t);
    case kind::POW2: return postRewritePow2(t);
    case kind::EXPONENTIAL:
    case kind::SINE:
    case kind::COSINE:
    case kind::TANGENT:
    case kind::COSECANT:
    case kind::SECANT:
    case kind::COTANGENT:
    case kind::ARCSINE:
    case kind::ARCCOSINE:
    case kind::ARCTANGENT:
    case kind::ARCCOSECANT:
    case kind::ARCSECANT:
    case kind::ARCCOTANGENT:
    case kind::SQRT: return postRewriteTranscendental(t);
    case kind::INTS_DIVISION:
    case kind::INTS_MODULUS: return rewriteIntsDivMod(t, false);
    case kind::INTS_DIVISION_TOTAL:
    case kind::INTS_MODULUS_TOTAL:
      return rewriteIntsDivModTotal(t, false);
    case kind::ABS:
      // |c| for a constant c is evaluated; non-constant ABS is kept.
      if(t[0].isConst()) {
        const Rational& rat = t[0].getConst<Rational>();
        if(rat >= 0) {
          return RewriteResponse(REWRITE_DONE, t[0]);
        } else {
          return RewriteResponse(REWRITE_DONE,
                                 NodeManager::currentNM()->mkConst(-rat));
        }
      }
      return RewriteResponse(REWRITE_DONE, t);
    case kind::TO_REAL:
    case kind::CAST_TO_REAL: return RewriteResponse(REWRITE_DONE, t[0]);
    case kind::TO_INTEGER:
      // (to_int c) ---> floor(c) for a constant c
      if(t[0].isConst()) {
        return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(Rational(t[0].getConst<Rational>().floor())));
      }
      // (to_int x) ---> x when x already has integer type
      if(t[0].getType().isInteger()) {
        return RewriteResponse(REWRITE_DONE, t[0]);
      }
      //Unimplemented() << "TO_INTEGER, nonconstant";
      //return rewriteToInteger(t);
      return RewriteResponse(REWRITE_DONE, t);
    case kind::IS_INTEGER:
      // (is_int c) ---> true iff the constant c has denominator 1
      if(t[0].isConst()) {
        return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(t[0].getConst<Rational>().getDenominator() == 1));
      }
      // (is_int x) ---> true when x has integer type
      if(t[0].getType().isInteger()) {
        return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(true));
      }
      //Unimplemented() << "IS_INTEGER, nonconstant";
      //return rewriteIsInteger(t);
      return RewriteResponse(REWRITE_DONE, t);
    case kind::POW:
      {
        // (^ b c) with a constant exponent c:
        //   c == 0                                   ---> 1
        //   c integral, 1 <= c <= MAX_CHILDREN, c==1 ---> b
        //   c integral, 1 <  c <= MAX_CHILDREN       ---> (* b b ... b), c copies
        // The bound is MAX_CHILDREN, presumably because the expansion builds
        // a single MULT node with c children. Every other exponent (negative,
        // non-integral, too large, or non-constant) falls out of this block
        // into the LogicException below.
        if(t[1].getKind() == kind::CONST_RATIONAL){
          const Rational& exp = t[1].getConst<Rational>();
          TNode base = t[0];
          if(exp.sgn() == 0){
            return RewriteResponse(REWRITE_DONE, mkRationalNode(Rational(1)));
          }else if(exp.sgn() > 0 && exp.isIntegral()){
            cvc5::Rational r(expr::NodeValue::MAX_CHILDREN);
            if (exp <= r)
            {
              unsigned num = exp.getNumerator().toUnsignedInt();
              if( num==1 ){
                return RewriteResponse(REWRITE_AGAIN, base);
              }else{
                NodeBuilder nb(kind::MULT);
                for(unsigned i=0; i < num; ++i){
                  nb << base;
                }
                Assert(nb.getNumChildren() > 0);
                Node mult = nb;
                return RewriteResponse(REWRITE_AGAIN, mult);
              }
            }
          }
        }

        // Todo improve the exception thrown
        std::stringstream ss;
        ss << "The exponent of the POW(^) operator can only be a positive "
              "integral constant below "
           << (expr::NodeValue::MAX_CHILDREN + 1) << ". ";
        ss << "Exception occurred in:" << std::endl;
        ss << "  " << t;
        throw LogicException(ss.str());
      }
    case kind::PI:
      return RewriteResponse(REWRITE_DONE, t);
    default:
      Unreachable();
    }
  }
}
275
276
277 RewriteResponse ArithRewriter::preRewriteMult(TNode t){
278 Assert(t.getKind() == kind::MULT || t.getKind() == kind::NONLINEAR_MULT);
279
280 if(t.getNumChildren() == 2){
281 if(t[0].getKind() == kind::CONST_RATIONAL
282 && t[0].getConst<Rational>().isOne()){
283 return RewriteResponse(REWRITE_DONE, t[1]);
284 }
285 if(t[1].getKind() == kind::CONST_RATIONAL
286 && t[1].getConst<Rational>().isOne()){
287 return RewriteResponse(REWRITE_DONE, t[0]);
288 }
289 }
290
291 // Rewrite multiplications with a 0 argument and to 0
292 for(TNode::iterator i = t.begin(); i != t.end(); ++i) {
293 if((*i).getKind() == kind::CONST_RATIONAL) {
294 if((*i).getConst<Rational>().isZero()) {
295 TNode zero = (*i);
296 return RewriteResponse(REWRITE_DONE, zero);
297 }
298 }
299 }
300 return RewriteResponse(REWRITE_DONE, t);
301 }
302
303 static bool canFlatten(Kind k, TNode t){
304 for(TNode::iterator i = t.begin(); i != t.end(); ++i) {
305 TNode child = *i;
306 if(child.getKind() == k){
307 return true;
308 }
309 }
310 return false;
311 }
312
313 static void flatten(std::vector<TNode>& pb, Kind k, TNode t){
314 if(t.getKind() == k){
315 for(TNode::iterator i = t.begin(); i != t.end(); ++i) {
316 TNode child = *i;
317 if(child.getKind() == k){
318 flatten(pb, k, child);
319 }else{
320 pb.push_back(child);
321 }
322 }
323 }else{
324 pb.push_back(t);
325 }
326 }
327
328 static Node flatten(Kind k, TNode t){
329 std::vector<TNode> pb;
330 flatten(pb, k, t);
331 Assert(pb.size() >= 2);
332 return NodeManager::currentNM()->mkNode(k, pb);
333 }
334
335 RewriteResponse ArithRewriter::preRewritePlus(TNode t){
336 Assert(t.getKind() == kind::PLUS);
337
338 if(canFlatten(kind::PLUS, t)){
339 return RewriteResponse(REWRITE_DONE, flatten(kind::PLUS, t));
340 }else{
341 return RewriteResponse(REWRITE_DONE, t);
342 }
343 }
344
345 RewriteResponse ArithRewriter::postRewritePlus(TNode t){
346 Assert(t.getKind() == kind::PLUS);
347
348 std::vector<Monomial> monomials;
349 std::vector<Polynomial> polynomials;
350
351 for(TNode::iterator i = t.begin(), end = t.end(); i != end; ++i){
352 TNode curr = *i;
353 if(Monomial::isMember(curr)){
354 monomials.push_back(Monomial::parseMonomial(curr));
355 }else{
356 polynomials.push_back(Polynomial::parsePolynomial(curr));
357 }
358 }
359
360 if(!monomials.empty()){
361 Monomial::sort(monomials);
362 Monomial::combineAdjacentMonomials(monomials);
363 polynomials.push_back(Polynomial::mkPolynomial(monomials));
364 }
365
366 Polynomial res = Polynomial::sumPolynomials(polynomials);
367
368 return RewriteResponse(REWRITE_DONE, res.getNode());
369 }
370
371 RewriteResponse ArithRewriter::postRewriteMult(TNode t){
372 Assert(t.getKind() == kind::MULT || t.getKind() == kind::NONLINEAR_MULT);
373
374 Polynomial res = Polynomial::mkOne();
375
376 for(TNode::iterator i = t.begin(), end = t.end(); i != end; ++i){
377 Node curr = *i;
378 Polynomial currPoly = Polynomial::parsePolynomial(curr);
379
380 res = res * currPoly;
381 }
382
383 return RewriteResponse(REWRITE_DONE, res.getNode());
384 }
385
386 RewriteResponse ArithRewriter::postRewritePow2(TNode t)
387 {
388 Assert(t.getKind() == kind::POW2);
389 NodeManager* nm = NodeManager::currentNM();
390 // if constant, we eliminate
391 if (t[0].isConst())
392 {
393 // pow2 is only supported for integers
394 Assert(t[0].getType().isInteger());
395 Integer i = t[0].getConst<Rational>().getNumerator();
396 unsigned long k = i.getUnsignedLong();
397 Node ret = nm->mkConst<Rational>(Rational(Integer(2).pow(k), Integer(1)));
398 return RewriteResponse(REWRITE_DONE, ret);
399 }
400 return RewriteResponse(REWRITE_DONE, t);
401 }
402
403 RewriteResponse ArithRewriter::postRewriteIAnd(TNode t)
404 {
405 Assert(t.getKind() == kind::IAND);
406 NodeManager* nm = NodeManager::currentNM();
407 // if constant, we eliminate
408 if (t[0].isConst() && t[1].isConst())
409 {
410 size_t bsize = t.getOperator().getConst<IntAnd>().d_size;
411 Node iToBvop = nm->mkConst(IntToBitVector(bsize));
412 Node arg1 = nm->mkNode(kind::INT_TO_BITVECTOR, iToBvop, t[0]);
413 Node arg2 = nm->mkNode(kind::INT_TO_BITVECTOR, iToBvop, t[1]);
414 Node bvand = nm->mkNode(kind::BITVECTOR_AND, arg1, arg2);
415 Node ret = nm->mkNode(kind::BITVECTOR_TO_NAT, bvand);
416 return RewriteResponse(REWRITE_AGAIN_FULL, ret);
417 }
418 else if (t[0] > t[1])
419 {
420 // ((_ iand k) x y) ---> ((_ iand k) y x) if x > y by node ordering
421 Node ret = nm->mkNode(kind::IAND, t.getOperator(), t[1], t[0]);
422 return RewriteResponse(REWRITE_AGAIN, ret);
423 }
424 else if (t[0] == t[1])
425 {
426 // ((_ iand k) x x) ---> x
427 return RewriteResponse(REWRITE_DONE, t[0]);
428 }
429 // simplifications involving constants
430 for (unsigned i = 0; i < 2; i++)
431 {
432 if (!t[i].isConst())
433 {
434 continue;
435 }
436 if (t[i].getConst<Rational>().sgn() == 0)
437 {
438 // ((_ iand k) 0 y) ---> 0
439 return RewriteResponse(REWRITE_DONE, t[i]);
440 }
441 }
442 return RewriteResponse(REWRITE_DONE, t);
443 }
444
445 RewriteResponse ArithRewriter::preRewriteTranscendental(TNode t) {
446 return RewriteResponse(REWRITE_DONE, t);
447 }
448
449 RewriteResponse ArithRewriter::postRewriteTranscendental(TNode t) {
450 Trace("arith-tf-rewrite") << "Rewrite transcendental function : " << t << std::endl;
451 NodeManager* nm = NodeManager::currentNM();
452 switch( t.getKind() ){
453 case kind::EXPONENTIAL: {
454 if(t[0].getKind() == kind::CONST_RATIONAL){
455 Node one = nm->mkConst(Rational(1));
456 if(t[0].getConst<Rational>().sgn()>=0 && t[0].getType().isInteger() && t[0]!=one){
457 return RewriteResponse(
458 REWRITE_AGAIN,
459 nm->mkNode(kind::POW, nm->mkNode(kind::EXPONENTIAL, one), t[0]));
460 }else{
461 return RewriteResponse(REWRITE_DONE, t);
462 }
463 }
464 else if (t[0].getKind() == kind::PLUS)
465 {
466 std::vector<Node> product;
467 for (const Node tc : t[0])
468 {
469 product.push_back(nm->mkNode(kind::EXPONENTIAL, tc));
470 }
471 // We need to do a full rewrite here, since we can get exponentials of
472 // constants, e.g. when we are rewriting exp(2 + x)
473 return RewriteResponse(REWRITE_AGAIN_FULL,
474 nm->mkNode(kind::MULT, product));
475 }
476 }
477 break;
478 case kind::SINE:
479 if(t[0].getKind() == kind::CONST_RATIONAL){
480 const Rational& rat = t[0].getConst<Rational>();
481 if(rat.sgn() == 0){
482 return RewriteResponse(REWRITE_DONE, nm->mkConst(Rational(0)));
483 }
484 else if (rat.sgn() == -1)
485 {
486 Node ret =
487 nm->mkNode(kind::UMINUS, nm->mkNode(kind::SINE, nm->mkConst(-rat)));
488 return RewriteResponse(REWRITE_AGAIN_FULL, ret);
489 }
490 }else{
491 // get the factor of PI in the argument
492 Node pi_factor;
493 Node pi;
494 Node rem;
495 std::map<Node, Node> msum;
496 if (ArithMSum::getMonomialSum(t[0], msum))
497 {
498 pi = mkPi();
499 std::map<Node, Node>::iterator itm = msum.find(pi);
500 if (itm != msum.end())
501 {
502 if (itm->second.isNull())
503 {
504 pi_factor = mkRationalNode(Rational(1));
505 }
506 else
507 {
508 pi_factor = itm->second;
509 }
510 msum.erase(pi);
511 if (!msum.empty())
512 {
513 rem = ArithMSum::mkNode(msum);
514 }
515 }
516 }
517 else
518 {
519 Assert(false);
520 }
521
522 // if there is a factor of PI
523 if( !pi_factor.isNull() ){
524 Trace("arith-tf-rewrite-debug") << "Process pi factor = " << pi_factor << std::endl;
525 Rational r = pi_factor.getConst<Rational>();
526 Rational r_abs = r.abs();
527 Rational rone = Rational(1);
528 Node ntwo = mkRationalNode(Rational(2));
529 if (r_abs > rone)
530 {
531 //add/substract 2*pi beyond scope
532 Node ra_div_two = nm->mkNode(
533 kind::INTS_DIVISION, mkRationalNode(r_abs + rone), ntwo);
534 Node new_pi_factor;
535 if( r.sgn()==1 ){
536 new_pi_factor =
537 nm->mkNode(kind::MINUS,
538 pi_factor,
539 nm->mkNode(kind::MULT, ntwo, ra_div_two));
540 }else{
541 Assert(r.sgn() == -1);
542 new_pi_factor =
543 nm->mkNode(kind::PLUS,
544 pi_factor,
545 nm->mkNode(kind::MULT, ntwo, ra_div_two));
546 }
547 Node new_arg = nm->mkNode(kind::MULT, new_pi_factor, pi);
548 if (!rem.isNull())
549 {
550 new_arg = nm->mkNode(kind::PLUS, new_arg, rem);
551 }
552 // sin( 2*n*PI + x ) = sin( x )
553 return RewriteResponse(REWRITE_AGAIN_FULL,
554 nm->mkNode(kind::SINE, new_arg));
555 }
556 else if (r_abs == rone)
557 {
558 // sin( PI + x ) = -sin( x )
559 if (rem.isNull())
560 {
561 return RewriteResponse(REWRITE_DONE, mkRationalNode(Rational(0)));
562 }
563 else
564 {
565 return RewriteResponse(
566 REWRITE_AGAIN_FULL,
567 nm->mkNode(kind::UMINUS, nm->mkNode(kind::SINE, rem)));
568 }
569 }
570 else if (rem.isNull())
571 {
572 // other rational cases based on Niven's theorem
573 // (https://en.wikipedia.org/wiki/Niven%27s_theorem)
574 Integer one = Integer(1);
575 Integer two = Integer(2);
576 Integer six = Integer(6);
577 if (r_abs.getDenominator() == two)
578 {
579 Assert(r_abs.getNumerator() == one);
580 return RewriteResponse(REWRITE_DONE,
581 mkRationalNode(Rational(r.sgn())));
582 }
583 else if (r_abs.getDenominator() == six)
584 {
585 Integer five = Integer(5);
586 if (r_abs.getNumerator() == one || r_abs.getNumerator() == five)
587 {
588 return RewriteResponse(
589 REWRITE_DONE,
590 mkRationalNode(Rational(r.sgn()) / Rational(2)));
591 }
592 }
593 }
594 }
595 }
596 break;
597 case kind::COSINE: {
598 return RewriteResponse(
599 REWRITE_AGAIN_FULL,
600 nm->mkNode(kind::SINE,
601 nm->mkNode(kind::MINUS,
602 nm->mkNode(kind::MULT,
603 nm->mkConst(Rational(1) / Rational(2)),
604 mkPi()),
605 t[0])));
606 }
607 break;
608 case kind::TANGENT:
609 {
610 return RewriteResponse(REWRITE_AGAIN_FULL,
611 nm->mkNode(kind::DIVISION,
612 nm->mkNode(kind::SINE, t[0]),
613 nm->mkNode(kind::COSINE, t[0])));
614 }
615 break;
616 case kind::COSECANT:
617 {
618 return RewriteResponse(REWRITE_AGAIN_FULL,
619 nm->mkNode(kind::DIVISION,
620 mkRationalNode(Rational(1)),
621 nm->mkNode(kind::SINE, t[0])));
622 }
623 break;
624 case kind::SECANT:
625 {
626 return RewriteResponse(REWRITE_AGAIN_FULL,
627 nm->mkNode(kind::DIVISION,
628 mkRationalNode(Rational(1)),
629 nm->mkNode(kind::COSINE, t[0])));
630 }
631 break;
632 case kind::COTANGENT:
633 {
634 return RewriteResponse(REWRITE_AGAIN_FULL,
635 nm->mkNode(kind::DIVISION,
636 nm->mkNode(kind::COSINE, t[0]),
637 nm->mkNode(kind::SINE, t[0])));
638 }
639 break;
640 default:
641 break;
642 }
643 return RewriteResponse(REWRITE_DONE, t);
644 }
645
646 RewriteResponse ArithRewriter::postRewriteAtom(TNode atom){
647 if(atom.getKind() == kind::IS_INTEGER) {
648 if(atom[0].isConst()) {
649 return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(atom[0].getConst<Rational>().isIntegral()));
650 }
651 if(atom[0].getType().isInteger()) {
652 return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(true));
653 }
654 // not supported, but this isn't the right place to complain
655 return RewriteResponse(REWRITE_DONE, atom);
656 } else if(atom.getKind() == kind::DIVISIBLE) {
657 if(atom[0].isConst()) {
658 return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(bool((atom[0].getConst<Rational>() / atom.getOperator().getConst<Divisible>().k).isIntegral())));
659 }
660 if(atom.getOperator().getConst<Divisible>().k.isOne()) {
661 return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(true));
662 }
663 return RewriteResponse(REWRITE_AGAIN, NodeManager::currentNM()->mkNode(kind::EQUAL, NodeManager::currentNM()->mkNode(kind::INTS_MODULUS_TOTAL, atom[0], NodeManager::currentNM()->mkConst(Rational(atom.getOperator().getConst<Divisible>().k))), NodeManager::currentNM()->mkConst(Rational(0))));
664 }
665
666 // left |><| right
667 TNode left = atom[0];
668 TNode right = atom[1];
669
670 Polynomial pleft = Polynomial::parsePolynomial(left);
671 Polynomial pright = Polynomial::parsePolynomial(right);
672
673 Debug("arith::rewriter") << "pleft " << pleft.getNode() << std::endl;
674 Debug("arith::rewriter") << "pright " << pright.getNode() << std::endl;
675
676 Comparison cmp = Comparison::mkComparison(atom.getKind(), pleft, pright);
677 Assert(cmp.isNormalForm());
678 return RewriteResponse(REWRITE_DONE, cmp.getNode());
679 }
680
681 RewriteResponse ArithRewriter::preRewriteAtom(TNode atom){
682 Assert(isAtom(atom));
683
684 NodeManager* currNM = NodeManager::currentNM();
685
686 if(atom.getKind() == kind::EQUAL) {
687 if(atom[0] == atom[1]) {
688 return RewriteResponse(REWRITE_DONE, currNM->mkConst(true));
689 }
690 }else if(atom.getKind() == kind::GT){
691 Node leq = currNM->mkNode(kind::LEQ, atom[0], atom[1]);
692 return RewriteResponse(REWRITE_DONE, currNM->mkNode(kind::NOT, leq));
693 }else if(atom.getKind() == kind::LT){
694 Node geq = currNM->mkNode(kind::GEQ, atom[0], atom[1]);
695 return RewriteResponse(REWRITE_DONE, currNM->mkNode(kind::NOT, geq));
696 }else if(atom.getKind() == kind::IS_INTEGER){
697 if(atom[0].getType().isInteger()){
698 return RewriteResponse(REWRITE_DONE, currNM->mkConst(true));
699 }
700 }else if(atom.getKind() == kind::DIVISIBLE){
701 if(atom.getOperator().getConst<Divisible>().k.isOne()){
702 return RewriteResponse(REWRITE_DONE, currNM->mkConst(true));
703 }
704 }
705
706 return RewriteResponse(REWRITE_DONE, atom);
707 }
708
709 RewriteResponse ArithRewriter::postRewrite(TNode t){
710 if(isTerm(t)){
711 RewriteResponse response = postRewriteTerm(t);
712 if (Debug.isOn("arith::rewriter") && response.d_status == REWRITE_DONE)
713 {
714 Polynomial::parsePolynomial(response.d_node);
715 }
716 return response;
717 }else if(isAtom(t)){
718 RewriteResponse response = postRewriteAtom(t);
719 if (Debug.isOn("arith::rewriter") && response.d_status == REWRITE_DONE)
720 {
721 Comparison::parseNormalForm(response.d_node);
722 }
723 return response;
724 }else{
725 Unreachable();
726 }
727 }
728
729 RewriteResponse ArithRewriter::preRewrite(TNode t){
730 if(isTerm(t)){
731 return preRewriteTerm(t);
732 }else if(isAtom(t)){
733 return preRewriteAtom(t);
734 }else{
735 Unreachable();
736 }
737 }
738
739 Node ArithRewriter::makeUnaryMinusNode(TNode n){
740 Rational qNegOne(-1);
741 return NodeManager::currentNM()->mkNode(kind::MULT, mkRationalNode(qNegOne),n);
742 }
743
744 Node ArithRewriter::makeSubtractionNode(TNode l, TNode r){
745 Node negR = makeUnaryMinusNode(r);
746 Node diff = NodeManager::currentNM()->mkNode(kind::PLUS, l, negR);
747
748 return diff;
749 }
750
751 RewriteResponse ArithRewriter::rewriteDiv(TNode t, bool pre){
752 Assert(t.getKind() == kind::DIVISION_TOTAL || t.getKind() == kind::DIVISION);
753
754 Node left = t[0];
755 Node right = t[1];
756 if(right.getKind() == kind::CONST_RATIONAL){
757 const Rational& den = right.getConst<Rational>();
758
759 if(den.isZero()){
760 if(t.getKind() == kind::DIVISION_TOTAL){
761 return RewriteResponse(REWRITE_DONE, mkRationalNode(0));
762 }else{
763 // This is unsupported, but this is not a good place to complain
764 return RewriteResponse(REWRITE_DONE, t);
765 }
766 }
767 Assert(den != Rational(0));
768
769 if(left.getKind() == kind::CONST_RATIONAL){
770 const Rational& num = left.getConst<Rational>();
771 Rational div = num / den;
772 Node result = mkRationalNode(div);
773 return RewriteResponse(REWRITE_DONE, result);
774 }
775
776 Rational div = den.inverse();
777
778 Node result = mkRationalNode(div);
779
780 Node mult = NodeManager::currentNM()->mkNode(kind::MULT,left,result);
781 if(pre){
782 return RewriteResponse(REWRITE_DONE, mult);
783 }else{
784 return RewriteResponse(REWRITE_AGAIN, mult);
785 }
786 }else{
787 return RewriteResponse(REWRITE_DONE, t);
788 }
789 }
790
791 RewriteResponse ArithRewriter::rewriteIntsDivMod(TNode t, bool pre)
792 {
793 NodeManager* nm = NodeManager::currentNM();
794 Kind k = t.getKind();
795 Node zero = nm->mkConst(Rational(0));
796 if (k == kind::INTS_MODULUS)
797 {
798 if (t[1].isConst() && !t[1].getConst<Rational>().isZero())
799 {
800 // can immediately replace by INTS_MODULUS_TOTAL
801 Node ret = nm->mkNode(kind::INTS_MODULUS_TOTAL, t[0], t[1]);
802 return returnRewrite(t, ret, Rewrite::MOD_TOTAL_BY_CONST);
803 }
804 }
805 if (k == kind::INTS_DIVISION)
806 {
807 if (t[1].isConst() && !t[1].getConst<Rational>().isZero())
808 {
809 // can immediately replace by INTS_DIVISION_TOTAL
810 Node ret = nm->mkNode(kind::INTS_DIVISION_TOTAL, t[0], t[1]);
811 return returnRewrite(t, ret, Rewrite::DIV_TOTAL_BY_CONST);
812 }
813 }
814 return RewriteResponse(REWRITE_DONE, t);
815 }
816
/**
 * Rewrites total integer division / modulus (INTS_DIVISION_TOTAL,
 * INTS_MODULUS_TOTAL).
 *
 * Nothing is done at pre-rewrite time. At post-rewrite time the rules below
 * are tried in order. Each successful rule goes through returnRewrite,
 * which requests REWRITE_AGAIN_FULL.
 *   (div x 0), (mod x 0)             ---> 0
 *   (mod x 1) ---> 0,  (div x 1)     ---> x
 *   (div x (- c)) ---> (- (div x c)),  (mod x (- c)) ---> (mod x c)
 *   (div c1 c2), (mod c1 c2)         ---> Euclidean quotient / remainder
 *   (mod (mod x d) d)                ---> (mod x d)
 *   (mod (op ... (mod x d) ...) d)   ---> (mod (op ... x ...) d),
 *                                        op in { NONLINEAR_MULT, MULT, PLUS }
 *   (div (mod x d) d)                ---> 0
 * In the last three rules d is the divisor term t[1] matched syntactically;
 * it need not be a constant. If no rule applies, t is returned unchanged.
 */
RewriteResponse ArithRewriter::rewriteIntsDivModTotal(TNode t, bool pre)
{
  if (pre)
  {
    // do not rewrite at prewrite.
    return RewriteResponse(REWRITE_DONE, t);
  }
  NodeManager* nm = NodeManager::currentNM();
  Kind k = t.getKind();
  Assert(k == kind::INTS_MODULUS_TOTAL || k == kind::INTS_DIVISION_TOTAL);
  TNode n = t[0];
  TNode d = t[1];
  bool dIsConstant = d.getKind() == kind::CONST_RATIONAL;
  if(dIsConstant && d.getConst<Rational>().isZero()){
    // (div x 0) ---> 0 or (mod x 0) ---> 0
    return returnRewrite(t, mkRationalNode(0), Rewrite::DIV_MOD_BY_ZERO);
  }else if(dIsConstant && d.getConst<Rational>().isOne()){
    if (k == kind::INTS_MODULUS_TOTAL)
    {
      // (mod x 1) --> 0
      return returnRewrite(t, mkRationalNode(0), Rewrite::MOD_BY_ONE);
    }
    Assert(k == kind::INTS_DIVISION_TOTAL);
    // (div x 1) --> x
    return returnRewrite(t, n, Rewrite::DIV_BY_ONE);
  }
  else if (dIsConstant && d.getConst<Rational>().sgn() < 0)
  {
    // pull negation
    // (div x (- c)) ---> (- (div x c))
    // (mod x (- c)) ---> (mod x c)
    // Only the _TOTAL kinds reach this point (see the Assert above); the
    // INTS_DIVISION test in the condition below is therefore redundant.
    Node nn = nm->mkNode(k, t[0], nm->mkConst(-t[1].getConst<Rational>()));
    Node ret = (k == kind::INTS_DIVISION || k == kind::INTS_DIVISION_TOTAL)
                   ? nm->mkNode(kind::UMINUS, nn)
                   : nn;
    return returnRewrite(t, ret, Rewrite::DIV_MOD_PULL_NEG_DEN);
  }
  else if (dIsConstant && n.getKind() == kind::CONST_RATIONAL)
  {
    // Zero and negative divisors were handled above, so d > 0 here.
    Assert(d.getConst<Rational>().isIntegral());
    Assert(n.getConst<Rational>().isIntegral());
    Assert(!d.getConst<Rational>().isZero());
    Integer di = d.getConst<Rational>().getNumerator();
    Integer ni = n.getConst<Rational>().getNumerator();

    bool isDiv = (k == kind::INTS_DIVISION || k == kind::INTS_DIVISION_TOTAL);

    Integer result = isDiv ? ni.euclidianDivideQuotient(di) : ni.euclidianDivideRemainder(di);

    // constant evaluation
    // (mod c1 c2) ---> c3 or (div c1 c2) ---> c3
    Node resultNode = mkRationalNode(Rational(result));
    return returnRewrite(t, resultNode, Rewrite::CONST_EVAL);
  }
  if (k == kind::INTS_MODULUS_TOTAL)
  {
    // Note these rewrites do not need to account for modulus by zero as being
    // a UF, which is handled by the reduction of INTS_MODULUS.
    Kind k0 = t[0].getKind();
    if (k0 == kind::INTS_MODULUS_TOTAL && t[0][1] == t[1])
    {
      // (mod (mod x c) c) --> (mod x c)
      return returnRewrite(t, t[0], Rewrite::MOD_OVER_MOD);
    }
    else if (k0 == kind::NONLINEAR_MULT || k0 == kind::MULT || k0 == kind::PLUS)
    {
      // can drop all
      // Strip every child of the form (mod x t[1]) down to x; the outer
      // modulus by the same divisor makes the inner one redundant.
      std::vector<Node> newChildren;
      bool childChanged = false;
      for (const Node& tc : t[0])
      {
        if (tc.getKind() == kind::INTS_MODULUS_TOTAL && tc[1] == t[1])
        {
          newChildren.push_back(tc[0]);
          childChanged = true;
          continue;
        }
        newChildren.push_back(tc);
      }
      if (childChanged)
      {
        // (mod (op ... (mod x c) ...) c) ---> (mod (op ... x ...) c) where
        // op is one of { NONLINEAR_MULT, MULT, PLUS }.
        Node ret = nm->mkNode(k0, newChildren);
        ret = nm->mkNode(kind::INTS_MODULUS_TOTAL, ret, t[1]);
        return returnRewrite(t, ret, Rewrite::MOD_CHILD_MOD);
      }
    }
  }
  else
  {
    Assert(k == kind::INTS_DIVISION_TOTAL);
    // Note these rewrites do not need to account for division by zero as being
    // a UF, which is handled by the reduction of INTS_DIVISION.
    if (t[0].getKind() == kind::INTS_MODULUS_TOTAL && t[0][1] == t[1])
    {
      // (div (mod x c) c) --> 0
      Node ret = mkRationalNode(0);
      return returnRewrite(t, ret, Rewrite::DIV_OVER_MOD);
    }
  }
  return RewriteResponse(REWRITE_DONE, t);
}
920
921 TrustNode ArithRewriter::expandDefinition(Node node)
922 {
923 // call eliminate operators, to eliminate partial operators only
924 std::vector<SkolemLemma> lems;
925 TrustNode ret = d_opElim.eliminate(node, lems, true);
926 Assert(lems.empty());
927 return ret;
928 }
929
930 RewriteResponse ArithRewriter::returnRewrite(TNode t, Node ret, Rewrite r)
931 {
932 Trace("arith-rewrite") << "ArithRewriter : " << t << " == " << ret << " by "
933 << r << std::endl;
934 return RewriteResponse(REWRITE_AGAIN_FULL, ret);
935 }
936
937 } // namespace arith
938 } // namespace theory
939 } // namespace cvc5