Refactor and update copyright headers. (#6316)
[cvc5.git] / src / theory / arith / arith_rewriter.cpp
1 /******************************************************************************
2 * Top contributors (to current version):
3 * Andrew Reynolds, Tim King, Morgan Deters
4 *
5 * This file is part of the cvc5 project.
6 *
7 * Copyright (c) 2009-2021 by the authors listed in the file AUTHORS
8 * in the top-level source directory and their institutional affiliations.
9 * All rights reserved. See the file COPYING in the top-level source
10 * directory for licensing information.
11 * ****************************************************************************
12 *
 * Rewriter for the theory of arithmetic.
 *
 * Implements the pre- and post-rewriting of arithmetic terms and atoms
 * (relations, IS_INTEGER and DIVISIBLE) towards the normal form of
 * normal_form.h.
17 */
18
19 #include <set>
20 #include <sstream>
21 #include <stack>
22 #include <vector>
23
24 #include "smt/logic_exception.h"
25 #include "theory/arith/arith_msum.h"
26 #include "theory/arith/arith_rewriter.h"
27 #include "theory/arith/arith_utilities.h"
28 #include "theory/arith/normal_form.h"
29 #include "theory/theory.h"
30 #include "util/iand.h"
31
32 namespace cvc5 {
33 namespace theory {
34 namespace arith {
35
36 bool ArithRewriter::isAtom(TNode n) {
37 Kind k = n.getKind();
38 return arith::isRelationOperator(k) || k == kind::IS_INTEGER
39 || k == kind::DIVISIBLE;
40 }
41
42 RewriteResponse ArithRewriter::rewriteConstant(TNode t){
43 Assert(t.isConst());
44 Assert(t.getKind() == kind::CONST_RATIONAL);
45
46 return RewriteResponse(REWRITE_DONE, t);
47 }
48
49 RewriteResponse ArithRewriter::rewriteVariable(TNode t){
50 Assert(t.isVar());
51
52 return RewriteResponse(REWRITE_DONE, t);
53 }
54
55 RewriteResponse ArithRewriter::rewriteMinus(TNode t, bool pre){
56 Assert(t.getKind() == kind::MINUS);
57
58 if(pre){
59 if(t[0] == t[1]){
60 Rational zero(0);
61 Node zeroNode = mkRationalNode(zero);
62 return RewriteResponse(REWRITE_DONE, zeroNode);
63 }else{
64 Node noMinus = makeSubtractionNode(t[0],t[1]);
65 return RewriteResponse(REWRITE_DONE, noMinus);
66 }
67 }else{
68 Polynomial minuend = Polynomial::parsePolynomial(t[0]);
69 Polynomial subtrahend = Polynomial::parsePolynomial(t[1]);
70 Polynomial diff = minuend - subtrahend;
71 return RewriteResponse(REWRITE_DONE, diff.getNode());
72 }
73 }
74
75 RewriteResponse ArithRewriter::rewriteUMinus(TNode t, bool pre){
76 Assert(t.getKind() == kind::UMINUS);
77
78 if(t[0].getKind() == kind::CONST_RATIONAL){
79 Rational neg = -(t[0].getConst<Rational>());
80 return RewriteResponse(REWRITE_DONE, mkRationalNode(neg));
81 }
82
83 Node noUminus = makeUnaryMinusNode(t[0]);
84 if(pre)
85 return RewriteResponse(REWRITE_DONE, noUminus);
86 else
87 return RewriteResponse(REWRITE_AGAIN, noUminus);
88 }
89
90 RewriteResponse ArithRewriter::preRewriteTerm(TNode t){
91 if(t.isConst()){
92 return rewriteConstant(t);
93 }else if(t.isVar()){
94 return rewriteVariable(t);
95 }else{
96 switch(Kind k = t.getKind()){
97 case kind::MINUS:
98 return rewriteMinus(t, true);
99 case kind::UMINUS:
100 return rewriteUMinus(t, true);
101 case kind::DIVISION:
102 case kind::DIVISION_TOTAL:
103 return rewriteDiv(t,true);
104 case kind::PLUS:
105 return preRewritePlus(t);
106 case kind::MULT:
107 case kind::NONLINEAR_MULT: return preRewriteMult(t);
108 case kind::IAND: return RewriteResponse(REWRITE_DONE, t);
109 case kind::EXPONENTIAL:
110 case kind::SINE:
111 case kind::COSINE:
112 case kind::TANGENT:
113 case kind::COSECANT:
114 case kind::SECANT:
115 case kind::COTANGENT:
116 case kind::ARCSINE:
117 case kind::ARCCOSINE:
118 case kind::ARCTANGENT:
119 case kind::ARCCOSECANT:
120 case kind::ARCSECANT:
121 case kind::ARCCOTANGENT:
122 case kind::SQRT: return preRewriteTranscendental(t);
123 case kind::INTS_DIVISION:
124 case kind::INTS_MODULUS: return rewriteIntsDivMod(t, true);
125 case kind::INTS_DIVISION_TOTAL:
126 case kind::INTS_MODULUS_TOTAL:
127 return rewriteIntsDivModTotal(t,true);
128 case kind::ABS:
129 if(t[0].isConst()) {
130 const Rational& rat = t[0].getConst<Rational>();
131 if(rat >= 0) {
132 return RewriteResponse(REWRITE_DONE, t[0]);
133 } else {
134 return RewriteResponse(REWRITE_DONE,
135 NodeManager::currentNM()->mkConst(-rat));
136 }
137 }
138 return RewriteResponse(REWRITE_DONE, t);
139 case kind::IS_INTEGER:
140 case kind::TO_INTEGER:
141 return RewriteResponse(REWRITE_DONE, t);
142 case kind::TO_REAL:
143 case kind::CAST_TO_REAL: return RewriteResponse(REWRITE_DONE, t[0]);
144 case kind::POW:
145 return RewriteResponse(REWRITE_DONE, t);
146 case kind::PI:
147 return RewriteResponse(REWRITE_DONE, t);
148 default: Unhandled() << k;
149 }
150 }
151 }
152
153 RewriteResponse ArithRewriter::postRewriteTerm(TNode t){
154 if(t.isConst()){
155 return rewriteConstant(t);
156 }else if(t.isVar()){
157 return rewriteVariable(t);
158 }else{
159 switch(t.getKind()){
160 case kind::MINUS:
161 return rewriteMinus(t, false);
162 case kind::UMINUS:
163 return rewriteUMinus(t, false);
164 case kind::DIVISION:
165 case kind::DIVISION_TOTAL:
166 return rewriteDiv(t, false);
167 case kind::PLUS:
168 return postRewritePlus(t);
169 case kind::MULT:
170 case kind::NONLINEAR_MULT: return postRewriteMult(t);
171 case kind::IAND: return postRewriteIAnd(t);
172 case kind::EXPONENTIAL:
173 case kind::SINE:
174 case kind::COSINE:
175 case kind::TANGENT:
176 case kind::COSECANT:
177 case kind::SECANT:
178 case kind::COTANGENT:
179 case kind::ARCSINE:
180 case kind::ARCCOSINE:
181 case kind::ARCTANGENT:
182 case kind::ARCCOSECANT:
183 case kind::ARCSECANT:
184 case kind::ARCCOTANGENT:
185 case kind::SQRT: return postRewriteTranscendental(t);
186 case kind::INTS_DIVISION:
187 case kind::INTS_MODULUS: return rewriteIntsDivMod(t, false);
188 case kind::INTS_DIVISION_TOTAL:
189 case kind::INTS_MODULUS_TOTAL:
190 return rewriteIntsDivModTotal(t, false);
191 case kind::ABS:
192 if(t[0].isConst()) {
193 const Rational& rat = t[0].getConst<Rational>();
194 if(rat >= 0) {
195 return RewriteResponse(REWRITE_DONE, t[0]);
196 } else {
197 return RewriteResponse(REWRITE_DONE,
198 NodeManager::currentNM()->mkConst(-rat));
199 }
200 }
201 return RewriteResponse(REWRITE_DONE, t);
202 case kind::TO_REAL:
203 case kind::CAST_TO_REAL: return RewriteResponse(REWRITE_DONE, t[0]);
204 case kind::TO_INTEGER:
205 if(t[0].isConst()) {
206 return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(Rational(t[0].getConst<Rational>().floor())));
207 }
208 if(t[0].getType().isInteger()) {
209 return RewriteResponse(REWRITE_DONE, t[0]);
210 }
211 //Unimplemented() << "TO_INTEGER, nonconstant";
212 //return rewriteToInteger(t);
213 return RewriteResponse(REWRITE_DONE, t);
214 case kind::IS_INTEGER:
215 if(t[0].isConst()) {
216 return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(t[0].getConst<Rational>().getDenominator() == 1));
217 }
218 if(t[0].getType().isInteger()) {
219 return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(true));
220 }
221 //Unimplemented() << "IS_INTEGER, nonconstant";
222 //return rewriteIsInteger(t);
223 return RewriteResponse(REWRITE_DONE, t);
224 case kind::POW:
225 {
226 if(t[1].getKind() == kind::CONST_RATIONAL){
227 const Rational& exp = t[1].getConst<Rational>();
228 TNode base = t[0];
229 if(exp.sgn() == 0){
230 return RewriteResponse(REWRITE_DONE, mkRationalNode(Rational(1)));
231 }else if(exp.sgn() > 0 && exp.isIntegral()){
232 cvc5::Rational r(expr::NodeValue::MAX_CHILDREN);
233 if (exp <= r)
234 {
235 unsigned num = exp.getNumerator().toUnsignedInt();
236 if( num==1 ){
237 return RewriteResponse(REWRITE_AGAIN, base);
238 }else{
239 NodeBuilder nb(kind::MULT);
240 for(unsigned i=0; i < num; ++i){
241 nb << base;
242 }
243 Assert(nb.getNumChildren() > 0);
244 Node mult = nb;
245 return RewriteResponse(REWRITE_AGAIN, mult);
246 }
247 }
248 }
249 }
250
251 // Todo improve the exception thrown
252 std::stringstream ss;
253 ss << "The exponent of the POW(^) operator can only be a positive "
254 "integral constant below "
255 << (expr::NodeValue::MAX_CHILDREN + 1) << ". ";
256 ss << "Exception occurred in:" << std::endl;
257 ss << " " << t;
258 throw LogicException(ss.str());
259 }
260 case kind::PI:
261 return RewriteResponse(REWRITE_DONE, t);
262 default:
263 Unreachable();
264 }
265 }
266 }
267
268
269 RewriteResponse ArithRewriter::preRewriteMult(TNode t){
270 Assert(t.getKind() == kind::MULT || t.getKind() == kind::NONLINEAR_MULT);
271
272 if(t.getNumChildren() == 2){
273 if(t[0].getKind() == kind::CONST_RATIONAL
274 && t[0].getConst<Rational>().isOne()){
275 return RewriteResponse(REWRITE_DONE, t[1]);
276 }
277 if(t[1].getKind() == kind::CONST_RATIONAL
278 && t[1].getConst<Rational>().isOne()){
279 return RewriteResponse(REWRITE_DONE, t[0]);
280 }
281 }
282
283 // Rewrite multiplications with a 0 argument and to 0
284 for(TNode::iterator i = t.begin(); i != t.end(); ++i) {
285 if((*i).getKind() == kind::CONST_RATIONAL) {
286 if((*i).getConst<Rational>().isZero()) {
287 TNode zero = (*i);
288 return RewriteResponse(REWRITE_DONE, zero);
289 }
290 }
291 }
292 return RewriteResponse(REWRITE_DONE, t);
293 }
294
295 static bool canFlatten(Kind k, TNode t){
296 for(TNode::iterator i = t.begin(); i != t.end(); ++i) {
297 TNode child = *i;
298 if(child.getKind() == k){
299 return true;
300 }
301 }
302 return false;
303 }
304
305 static void flatten(std::vector<TNode>& pb, Kind k, TNode t){
306 if(t.getKind() == k){
307 for(TNode::iterator i = t.begin(); i != t.end(); ++i) {
308 TNode child = *i;
309 if(child.getKind() == k){
310 flatten(pb, k, child);
311 }else{
312 pb.push_back(child);
313 }
314 }
315 }else{
316 pb.push_back(t);
317 }
318 }
319
320 static Node flatten(Kind k, TNode t){
321 std::vector<TNode> pb;
322 flatten(pb, k, t);
323 Assert(pb.size() >= 2);
324 return NodeManager::currentNM()->mkNode(k, pb);
325 }
326
327 RewriteResponse ArithRewriter::preRewritePlus(TNode t){
328 Assert(t.getKind() == kind::PLUS);
329
330 if(canFlatten(kind::PLUS, t)){
331 return RewriteResponse(REWRITE_DONE, flatten(kind::PLUS, t));
332 }else{
333 return RewriteResponse(REWRITE_DONE, t);
334 }
335 }
336
337 RewriteResponse ArithRewriter::postRewritePlus(TNode t){
338 Assert(t.getKind() == kind::PLUS);
339
340 std::vector<Monomial> monomials;
341 std::vector<Polynomial> polynomials;
342
343 for(TNode::iterator i = t.begin(), end = t.end(); i != end; ++i){
344 TNode curr = *i;
345 if(Monomial::isMember(curr)){
346 monomials.push_back(Monomial::parseMonomial(curr));
347 }else{
348 polynomials.push_back(Polynomial::parsePolynomial(curr));
349 }
350 }
351
352 if(!monomials.empty()){
353 Monomial::sort(monomials);
354 Monomial::combineAdjacentMonomials(monomials);
355 polynomials.push_back(Polynomial::mkPolynomial(monomials));
356 }
357
358 Polynomial res = Polynomial::sumPolynomials(polynomials);
359
360 return RewriteResponse(REWRITE_DONE, res.getNode());
361 }
362
363 RewriteResponse ArithRewriter::postRewriteMult(TNode t){
364 Assert(t.getKind() == kind::MULT || t.getKind() == kind::NONLINEAR_MULT);
365
366 Polynomial res = Polynomial::mkOne();
367
368 for(TNode::iterator i = t.begin(), end = t.end(); i != end; ++i){
369 Node curr = *i;
370 Polynomial currPoly = Polynomial::parsePolynomial(curr);
371
372 res = res * currPoly;
373 }
374
375 return RewriteResponse(REWRITE_DONE, res.getNode());
376 }
377
378 RewriteResponse ArithRewriter::postRewriteIAnd(TNode t)
379 {
380 Assert(t.getKind() == kind::IAND);
381 NodeManager* nm = NodeManager::currentNM();
382 // if constant, we eliminate
383 if (t[0].isConst() && t[1].isConst())
384 {
385 size_t bsize = t.getOperator().getConst<IntAnd>().d_size;
386 Node iToBvop = nm->mkConst(IntToBitVector(bsize));
387 Node arg1 = nm->mkNode(kind::INT_TO_BITVECTOR, iToBvop, t[0]);
388 Node arg2 = nm->mkNode(kind::INT_TO_BITVECTOR, iToBvop, t[1]);
389 Node bvand = nm->mkNode(kind::BITVECTOR_AND, arg1, arg2);
390 Node ret = nm->mkNode(kind::BITVECTOR_TO_NAT, bvand);
391 return RewriteResponse(REWRITE_AGAIN_FULL, ret);
392 }
393 else if (t[0] > t[1])
394 {
395 // ((_ iand k) x y) ---> ((_ iand k) y x) if x > y by node ordering
396 Node ret = nm->mkNode(kind::IAND, t.getOperator(), t[1], t[0]);
397 return RewriteResponse(REWRITE_AGAIN, ret);
398 }
399 else if (t[0] == t[1])
400 {
401 // ((_ iand k) x x) ---> x
402 return RewriteResponse(REWRITE_DONE, t[0]);
403 }
404 // simplifications involving constants
405 for (unsigned i = 0; i < 2; i++)
406 {
407 if (!t[i].isConst())
408 {
409 continue;
410 }
411 if (t[i].getConst<Rational>().sgn() == 0)
412 {
413 // ((_ iand k) 0 y) ---> 0
414 return RewriteResponse(REWRITE_DONE, t[i]);
415 }
416 }
417 return RewriteResponse(REWRITE_DONE, t);
418 }
419
420 RewriteResponse ArithRewriter::preRewriteTranscendental(TNode t) {
421 return RewriteResponse(REWRITE_DONE, t);
422 }
423
424 RewriteResponse ArithRewriter::postRewriteTranscendental(TNode t) {
425 Trace("arith-tf-rewrite") << "Rewrite transcendental function : " << t << std::endl;
426 NodeManager* nm = NodeManager::currentNM();
427 switch( t.getKind() ){
428 case kind::EXPONENTIAL: {
429 if(t[0].getKind() == kind::CONST_RATIONAL){
430 Node one = nm->mkConst(Rational(1));
431 if(t[0].getConst<Rational>().sgn()>=0 && t[0].getType().isInteger() && t[0]!=one){
432 return RewriteResponse(
433 REWRITE_AGAIN,
434 nm->mkNode(kind::POW, nm->mkNode(kind::EXPONENTIAL, one), t[0]));
435 }else{
436 return RewriteResponse(REWRITE_DONE, t);
437 }
438 }
439 else if (t[0].getKind() == kind::PLUS)
440 {
441 std::vector<Node> product;
442 for (const Node tc : t[0])
443 {
444 product.push_back(nm->mkNode(kind::EXPONENTIAL, tc));
445 }
446 // We need to do a full rewrite here, since we can get exponentials of
447 // constants, e.g. when we are rewriting exp(2 + x)
448 return RewriteResponse(REWRITE_AGAIN_FULL,
449 nm->mkNode(kind::MULT, product));
450 }
451 }
452 break;
453 case kind::SINE:
454 if(t[0].getKind() == kind::CONST_RATIONAL){
455 const Rational& rat = t[0].getConst<Rational>();
456 if(rat.sgn() == 0){
457 return RewriteResponse(REWRITE_DONE, nm->mkConst(Rational(0)));
458 }
459 else if (rat.sgn() == -1)
460 {
461 Node ret =
462 nm->mkNode(kind::UMINUS, nm->mkNode(kind::SINE, nm->mkConst(-rat)));
463 return RewriteResponse(REWRITE_AGAIN_FULL, ret);
464 }
465 }else{
466 // get the factor of PI in the argument
467 Node pi_factor;
468 Node pi;
469 Node rem;
470 std::map<Node, Node> msum;
471 if (ArithMSum::getMonomialSum(t[0], msum))
472 {
473 pi = mkPi();
474 std::map<Node, Node>::iterator itm = msum.find(pi);
475 if (itm != msum.end())
476 {
477 if (itm->second.isNull())
478 {
479 pi_factor = mkRationalNode(Rational(1));
480 }
481 else
482 {
483 pi_factor = itm->second;
484 }
485 msum.erase(pi);
486 if (!msum.empty())
487 {
488 rem = ArithMSum::mkNode(msum);
489 }
490 }
491 }
492 else
493 {
494 Assert(false);
495 }
496
497 // if there is a factor of PI
498 if( !pi_factor.isNull() ){
499 Trace("arith-tf-rewrite-debug") << "Process pi factor = " << pi_factor << std::endl;
500 Rational r = pi_factor.getConst<Rational>();
501 Rational r_abs = r.abs();
502 Rational rone = Rational(1);
503 Node ntwo = mkRationalNode(Rational(2));
504 if (r_abs > rone)
505 {
506 //add/substract 2*pi beyond scope
507 Node ra_div_two = nm->mkNode(
508 kind::INTS_DIVISION, mkRationalNode(r_abs + rone), ntwo);
509 Node new_pi_factor;
510 if( r.sgn()==1 ){
511 new_pi_factor =
512 nm->mkNode(kind::MINUS,
513 pi_factor,
514 nm->mkNode(kind::MULT, ntwo, ra_div_two));
515 }else{
516 Assert(r.sgn() == -1);
517 new_pi_factor =
518 nm->mkNode(kind::PLUS,
519 pi_factor,
520 nm->mkNode(kind::MULT, ntwo, ra_div_two));
521 }
522 Node new_arg = nm->mkNode(kind::MULT, new_pi_factor, pi);
523 if (!rem.isNull())
524 {
525 new_arg = nm->mkNode(kind::PLUS, new_arg, rem);
526 }
527 // sin( 2*n*PI + x ) = sin( x )
528 return RewriteResponse(REWRITE_AGAIN_FULL,
529 nm->mkNode(kind::SINE, new_arg));
530 }
531 else if (r_abs == rone)
532 {
533 // sin( PI + x ) = -sin( x )
534 if (rem.isNull())
535 {
536 return RewriteResponse(REWRITE_DONE, mkRationalNode(Rational(0)));
537 }
538 else
539 {
540 return RewriteResponse(
541 REWRITE_AGAIN_FULL,
542 nm->mkNode(kind::UMINUS, nm->mkNode(kind::SINE, rem)));
543 }
544 }
545 else if (rem.isNull())
546 {
547 // other rational cases based on Niven's theorem
548 // (https://en.wikipedia.org/wiki/Niven%27s_theorem)
549 Integer one = Integer(1);
550 Integer two = Integer(2);
551 Integer six = Integer(6);
552 if (r_abs.getDenominator() == two)
553 {
554 Assert(r_abs.getNumerator() == one);
555 return RewriteResponse(REWRITE_DONE,
556 mkRationalNode(Rational(r.sgn())));
557 }
558 else if (r_abs.getDenominator() == six)
559 {
560 Integer five = Integer(5);
561 if (r_abs.getNumerator() == one || r_abs.getNumerator() == five)
562 {
563 return RewriteResponse(
564 REWRITE_DONE,
565 mkRationalNode(Rational(r.sgn()) / Rational(2)));
566 }
567 }
568 }
569 }
570 }
571 break;
572 case kind::COSINE: {
573 return RewriteResponse(
574 REWRITE_AGAIN_FULL,
575 nm->mkNode(kind::SINE,
576 nm->mkNode(kind::MINUS,
577 nm->mkNode(kind::MULT,
578 nm->mkConst(Rational(1) / Rational(2)),
579 mkPi()),
580 t[0])));
581 }
582 break;
583 case kind::TANGENT:
584 {
585 return RewriteResponse(REWRITE_AGAIN_FULL,
586 nm->mkNode(kind::DIVISION,
587 nm->mkNode(kind::SINE, t[0]),
588 nm->mkNode(kind::COSINE, t[0])));
589 }
590 break;
591 case kind::COSECANT:
592 {
593 return RewriteResponse(REWRITE_AGAIN_FULL,
594 nm->mkNode(kind::DIVISION,
595 mkRationalNode(Rational(1)),
596 nm->mkNode(kind::SINE, t[0])));
597 }
598 break;
599 case kind::SECANT:
600 {
601 return RewriteResponse(REWRITE_AGAIN_FULL,
602 nm->mkNode(kind::DIVISION,
603 mkRationalNode(Rational(1)),
604 nm->mkNode(kind::COSINE, t[0])));
605 }
606 break;
607 case kind::COTANGENT:
608 {
609 return RewriteResponse(REWRITE_AGAIN_FULL,
610 nm->mkNode(kind::DIVISION,
611 nm->mkNode(kind::COSINE, t[0]),
612 nm->mkNode(kind::SINE, t[0])));
613 }
614 break;
615 default:
616 break;
617 }
618 return RewriteResponse(REWRITE_DONE, t);
619 }
620
621 RewriteResponse ArithRewriter::postRewriteAtom(TNode atom){
622 if(atom.getKind() == kind::IS_INTEGER) {
623 if(atom[0].isConst()) {
624 return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(atom[0].getConst<Rational>().isIntegral()));
625 }
626 if(atom[0].getType().isInteger()) {
627 return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(true));
628 }
629 // not supported, but this isn't the right place to complain
630 return RewriteResponse(REWRITE_DONE, atom);
631 } else if(atom.getKind() == kind::DIVISIBLE) {
632 if(atom[0].isConst()) {
633 return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(bool((atom[0].getConst<Rational>() / atom.getOperator().getConst<Divisible>().k).isIntegral())));
634 }
635 if(atom.getOperator().getConst<Divisible>().k.isOne()) {
636 return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(true));
637 }
638 return RewriteResponse(REWRITE_AGAIN, NodeManager::currentNM()->mkNode(kind::EQUAL, NodeManager::currentNM()->mkNode(kind::INTS_MODULUS_TOTAL, atom[0], NodeManager::currentNM()->mkConst(Rational(atom.getOperator().getConst<Divisible>().k))), NodeManager::currentNM()->mkConst(Rational(0))));
639 }
640
641 // left |><| right
642 TNode left = atom[0];
643 TNode right = atom[1];
644
645 Polynomial pleft = Polynomial::parsePolynomial(left);
646 Polynomial pright = Polynomial::parsePolynomial(right);
647
648 Debug("arith::rewriter") << "pleft " << pleft.getNode() << std::endl;
649 Debug("arith::rewriter") << "pright " << pright.getNode() << std::endl;
650
651 Comparison cmp = Comparison::mkComparison(atom.getKind(), pleft, pright);
652 Assert(cmp.isNormalForm());
653 return RewriteResponse(REWRITE_DONE, cmp.getNode());
654 }
655
656 RewriteResponse ArithRewriter::preRewriteAtom(TNode atom){
657 Assert(isAtom(atom));
658
659 NodeManager* currNM = NodeManager::currentNM();
660
661 if(atom.getKind() == kind::EQUAL) {
662 if(atom[0] == atom[1]) {
663 return RewriteResponse(REWRITE_DONE, currNM->mkConst(true));
664 }
665 }else if(atom.getKind() == kind::GT){
666 Node leq = currNM->mkNode(kind::LEQ, atom[0], atom[1]);
667 return RewriteResponse(REWRITE_DONE, currNM->mkNode(kind::NOT, leq));
668 }else if(atom.getKind() == kind::LT){
669 Node geq = currNM->mkNode(kind::GEQ, atom[0], atom[1]);
670 return RewriteResponse(REWRITE_DONE, currNM->mkNode(kind::NOT, geq));
671 }else if(atom.getKind() == kind::IS_INTEGER){
672 if(atom[0].getType().isInteger()){
673 return RewriteResponse(REWRITE_DONE, currNM->mkConst(true));
674 }
675 }else if(atom.getKind() == kind::DIVISIBLE){
676 if(atom.getOperator().getConst<Divisible>().k.isOne()){
677 return RewriteResponse(REWRITE_DONE, currNM->mkConst(true));
678 }
679 }
680
681 return RewriteResponse(REWRITE_DONE, atom);
682 }
683
684 RewriteResponse ArithRewriter::postRewrite(TNode t){
685 if(isTerm(t)){
686 RewriteResponse response = postRewriteTerm(t);
687 if (Debug.isOn("arith::rewriter") && response.d_status == REWRITE_DONE)
688 {
689 Polynomial::parsePolynomial(response.d_node);
690 }
691 return response;
692 }else if(isAtom(t)){
693 RewriteResponse response = postRewriteAtom(t);
694 if (Debug.isOn("arith::rewriter") && response.d_status == REWRITE_DONE)
695 {
696 Comparison::parseNormalForm(response.d_node);
697 }
698 return response;
699 }else{
700 Unreachable();
701 }
702 }
703
704 RewriteResponse ArithRewriter::preRewrite(TNode t){
705 if(isTerm(t)){
706 return preRewriteTerm(t);
707 }else if(isAtom(t)){
708 return preRewriteAtom(t);
709 }else{
710 Unreachable();
711 }
712 }
713
714 Node ArithRewriter::makeUnaryMinusNode(TNode n){
715 Rational qNegOne(-1);
716 return NodeManager::currentNM()->mkNode(kind::MULT, mkRationalNode(qNegOne),n);
717 }
718
719 Node ArithRewriter::makeSubtractionNode(TNode l, TNode r){
720 Node negR = makeUnaryMinusNode(r);
721 Node diff = NodeManager::currentNM()->mkNode(kind::PLUS, l, negR);
722
723 return diff;
724 }
725
726 RewriteResponse ArithRewriter::rewriteDiv(TNode t, bool pre){
727 Assert(t.getKind() == kind::DIVISION_TOTAL || t.getKind() == kind::DIVISION);
728
729 Node left = t[0];
730 Node right = t[1];
731 if(right.getKind() == kind::CONST_RATIONAL){
732 const Rational& den = right.getConst<Rational>();
733
734 if(den.isZero()){
735 if(t.getKind() == kind::DIVISION_TOTAL){
736 return RewriteResponse(REWRITE_DONE, mkRationalNode(0));
737 }else{
738 // This is unsupported, but this is not a good place to complain
739 return RewriteResponse(REWRITE_DONE, t);
740 }
741 }
742 Assert(den != Rational(0));
743
744 if(left.getKind() == kind::CONST_RATIONAL){
745 const Rational& num = left.getConst<Rational>();
746 Rational div = num / den;
747 Node result = mkRationalNode(div);
748 return RewriteResponse(REWRITE_DONE, result);
749 }
750
751 Rational div = den.inverse();
752
753 Node result = mkRationalNode(div);
754
755 Node mult = NodeManager::currentNM()->mkNode(kind::MULT,left,result);
756 if(pre){
757 return RewriteResponse(REWRITE_DONE, mult);
758 }else{
759 return RewriteResponse(REWRITE_AGAIN, mult);
760 }
761 }else{
762 return RewriteResponse(REWRITE_DONE, t);
763 }
764 }
765
766 RewriteResponse ArithRewriter::rewriteIntsDivMod(TNode t, bool pre)
767 {
768 NodeManager* nm = NodeManager::currentNM();
769 Kind k = t.getKind();
770 Node zero = nm->mkConst(Rational(0));
771 if (k == kind::INTS_MODULUS)
772 {
773 if (t[1].isConst() && !t[1].getConst<Rational>().isZero())
774 {
775 // can immediately replace by INTS_MODULUS_TOTAL
776 Node ret = nm->mkNode(kind::INTS_MODULUS_TOTAL, t[0], t[1]);
777 return returnRewrite(t, ret, Rewrite::MOD_TOTAL_BY_CONST);
778 }
779 }
780 if (k == kind::INTS_DIVISION)
781 {
782 if (t[1].isConst() && !t[1].getConst<Rational>().isZero())
783 {
784 // can immediately replace by INTS_DIVISION_TOTAL
785 Node ret = nm->mkNode(kind::INTS_DIVISION_TOTAL, t[0], t[1]);
786 return returnRewrite(t, ret, Rewrite::DIV_TOTAL_BY_CONST);
787 }
788 }
789 return RewriteResponse(REWRITE_DONE, t);
790 }
791
792 RewriteResponse ArithRewriter::rewriteIntsDivModTotal(TNode t, bool pre)
793 {
794 if (pre)
795 {
796 // do not rewrite at prewrite.
797 return RewriteResponse(REWRITE_DONE, t);
798 }
799 NodeManager* nm = NodeManager::currentNM();
800 Kind k = t.getKind();
801 Assert(k == kind::INTS_MODULUS_TOTAL || k == kind::INTS_DIVISION_TOTAL);
802 TNode n = t[0];
803 TNode d = t[1];
804 bool dIsConstant = d.getKind() == kind::CONST_RATIONAL;
805 if(dIsConstant && d.getConst<Rational>().isZero()){
806 // (div x 0) ---> 0 or (mod x 0) ---> 0
807 return returnRewrite(t, mkRationalNode(0), Rewrite::DIV_MOD_BY_ZERO);
808 }else if(dIsConstant && d.getConst<Rational>().isOne()){
809 if (k == kind::INTS_MODULUS_TOTAL)
810 {
811 // (mod x 1) --> 0
812 return returnRewrite(t, mkRationalNode(0), Rewrite::MOD_BY_ONE);
813 }
814 Assert(k == kind::INTS_DIVISION_TOTAL);
815 // (div x 1) --> x
816 return returnRewrite(t, n, Rewrite::DIV_BY_ONE);
817 }
818 else if (dIsConstant && d.getConst<Rational>().sgn() < 0)
819 {
820 // pull negation
821 // (div x (- c)) ---> (- (div x c))
822 // (mod x (- c)) ---> (mod x c)
823 Node nn = nm->mkNode(k, t[0], nm->mkConst(-t[1].getConst<Rational>()));
824 Node ret = (k == kind::INTS_DIVISION || k == kind::INTS_DIVISION_TOTAL)
825 ? nm->mkNode(kind::UMINUS, nn)
826 : nn;
827 return returnRewrite(t, ret, Rewrite::DIV_MOD_PULL_NEG_DEN);
828 }
829 else if (dIsConstant && n.getKind() == kind::CONST_RATIONAL)
830 {
831 Assert(d.getConst<Rational>().isIntegral());
832 Assert(n.getConst<Rational>().isIntegral());
833 Assert(!d.getConst<Rational>().isZero());
834 Integer di = d.getConst<Rational>().getNumerator();
835 Integer ni = n.getConst<Rational>().getNumerator();
836
837 bool isDiv = (k == kind::INTS_DIVISION || k == kind::INTS_DIVISION_TOTAL);
838
839 Integer result = isDiv ? ni.euclidianDivideQuotient(di) : ni.euclidianDivideRemainder(di);
840
841 // constant evaluation
842 // (mod c1 c2) ---> c3 or (div c1 c2) ---> c3
843 Node resultNode = mkRationalNode(Rational(result));
844 return returnRewrite(t, resultNode, Rewrite::CONST_EVAL);
845 }
846 if (k == kind::INTS_MODULUS_TOTAL)
847 {
848 // Note these rewrites do not need to account for modulus by zero as being
849 // a UF, which is handled by the reduction of INTS_MODULUS.
850 Kind k0 = t[0].getKind();
851 if (k0 == kind::INTS_MODULUS_TOTAL && t[0][1] == t[1])
852 {
853 // (mod (mod x c) c) --> (mod x c)
854 return returnRewrite(t, t[0], Rewrite::MOD_OVER_MOD);
855 }
856 else if (k0 == kind::NONLINEAR_MULT || k0 == kind::MULT || k0 == kind::PLUS)
857 {
858 // can drop all
859 std::vector<Node> newChildren;
860 bool childChanged = false;
861 for (const Node& tc : t[0])
862 {
863 if (tc.getKind() == kind::INTS_MODULUS_TOTAL && tc[1] == t[1])
864 {
865 newChildren.push_back(tc[0]);
866 childChanged = true;
867 continue;
868 }
869 newChildren.push_back(tc);
870 }
871 if (childChanged)
872 {
873 // (mod (op ... (mod x c) ...) c) ---> (mod (op ... x ...) c) where
874 // op is one of { NONLINEAR_MULT, MULT, PLUS }.
875 Node ret = nm->mkNode(k0, newChildren);
876 ret = nm->mkNode(kind::INTS_MODULUS_TOTAL, ret, t[1]);
877 return returnRewrite(t, ret, Rewrite::MOD_CHILD_MOD);
878 }
879 }
880 }
881 else
882 {
883 Assert(k == kind::INTS_DIVISION_TOTAL);
884 // Note these rewrites do not need to account for division by zero as being
885 // a UF, which is handled by the reduction of INTS_DIVISION.
886 if (t[0].getKind() == kind::INTS_MODULUS_TOTAL && t[0][1] == t[1])
887 {
888 // (div (mod x c) c) --> 0
889 Node ret = mkRationalNode(0);
890 return returnRewrite(t, ret, Rewrite::DIV_OVER_MOD);
891 }
892 }
893 return RewriteResponse(REWRITE_DONE, t);
894 }
895
896 RewriteResponse ArithRewriter::returnRewrite(TNode t, Node ret, Rewrite r)
897 {
898 Trace("arith-rewrite") << "ArithRewriter : " << t << " == " << ret << " by "
899 << r << std::endl;
900 return RewriteResponse(REWRITE_AGAIN_FULL, ret);
901 }
902
903 } // namespace arith
904 } // namespace theory
905 } // namespace cvc5