More precise includes of `Node` constants (#6617)
[cvc5.git] / src / theory / arith / arith_rewriter.cpp
1 /******************************************************************************
2 * Top contributors (to current version):
3 * Andrew Reynolds, Tim King, Morgan Deters
4 *
5 * This file is part of the cvc5 project.
6 *
7 * Copyright (c) 2009-2021 by the authors listed in the file AUTHORS
8 * in the top-level source directory and their institutional affiliations.
9 * All rights reserved. See the file COPYING in the top-level source
10 * directory for licensing information.
11 * ****************************************************************************
12 *
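 * Rewriter for the theory of arithmetic.
 *
 * Pre- and post-rewrites arithmetic terms (sums, products, division, integer
 * div/mod, iand, transcendental functions, ...) and atoms (relations,
 * is_int, divisible) towards the polynomial/comparison normal form provided
 * by theory/arith/normal_form.h.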
17 */
18
19 #include "theory/arith/arith_rewriter.h"
20
21 #include <set>
22 #include <sstream>
23 #include <stack>
24 #include <vector>
25
26 #include "smt/logic_exception.h"
27 #include "theory/arith/arith_msum.h"
28 #include "theory/arith/arith_utilities.h"
29 #include "theory/arith/normal_form.h"
30 #include "theory/arith/operator_elim.h"
31 #include "theory/theory.h"
32 #include "util/bitvector.h"
33 #include "util/divisible.h"
34 #include "util/iand.h"
35
36 namespace cvc5 {
37 namespace theory {
38 namespace arith {
39
40 ArithRewriter::ArithRewriter(OperatorElim& oe) : d_opElim(oe) {}
41
42 bool ArithRewriter::isAtom(TNode n) {
43 Kind k = n.getKind();
44 return arith::isRelationOperator(k) || k == kind::IS_INTEGER
45 || k == kind::DIVISIBLE;
46 }
47
48 RewriteResponse ArithRewriter::rewriteConstant(TNode t){
49 Assert(t.isConst());
50 Assert(t.getKind() == kind::CONST_RATIONAL);
51
52 return RewriteResponse(REWRITE_DONE, t);
53 }
54
55 RewriteResponse ArithRewriter::rewriteVariable(TNode t){
56 Assert(t.isVar());
57
58 return RewriteResponse(REWRITE_DONE, t);
59 }
60
61 RewriteResponse ArithRewriter::rewriteMinus(TNode t, bool pre){
62 Assert(t.getKind() == kind::MINUS);
63
64 if(pre){
65 if(t[0] == t[1]){
66 Rational zero(0);
67 Node zeroNode = mkRationalNode(zero);
68 return RewriteResponse(REWRITE_DONE, zeroNode);
69 }else{
70 Node noMinus = makeSubtractionNode(t[0],t[1]);
71 return RewriteResponse(REWRITE_DONE, noMinus);
72 }
73 }else{
74 Polynomial minuend = Polynomial::parsePolynomial(t[0]);
75 Polynomial subtrahend = Polynomial::parsePolynomial(t[1]);
76 Polynomial diff = minuend - subtrahend;
77 return RewriteResponse(REWRITE_DONE, diff.getNode());
78 }
79 }
80
81 RewriteResponse ArithRewriter::rewriteUMinus(TNode t, bool pre){
82 Assert(t.getKind() == kind::UMINUS);
83
84 if(t[0].getKind() == kind::CONST_RATIONAL){
85 Rational neg = -(t[0].getConst<Rational>());
86 return RewriteResponse(REWRITE_DONE, mkRationalNode(neg));
87 }
88
89 Node noUminus = makeUnaryMinusNode(t[0]);
90 if(pre)
91 return RewriteResponse(REWRITE_DONE, noUminus);
92 else
93 return RewriteResponse(REWRITE_AGAIN, noUminus);
94 }
95
96 RewriteResponse ArithRewriter::preRewriteTerm(TNode t){
97 if(t.isConst()){
98 return rewriteConstant(t);
99 }else if(t.isVar()){
100 return rewriteVariable(t);
101 }else{
102 switch(Kind k = t.getKind()){
103 case kind::MINUS:
104 return rewriteMinus(t, true);
105 case kind::UMINUS:
106 return rewriteUMinus(t, true);
107 case kind::DIVISION:
108 case kind::DIVISION_TOTAL:
109 return rewriteDiv(t,true);
110 case kind::PLUS:
111 return preRewritePlus(t);
112 case kind::MULT:
113 case kind::NONLINEAR_MULT: return preRewriteMult(t);
114 case kind::IAND: return RewriteResponse(REWRITE_DONE, t);
115 case kind::EXPONENTIAL:
116 case kind::SINE:
117 case kind::COSINE:
118 case kind::TANGENT:
119 case kind::COSECANT:
120 case kind::SECANT:
121 case kind::COTANGENT:
122 case kind::ARCSINE:
123 case kind::ARCCOSINE:
124 case kind::ARCTANGENT:
125 case kind::ARCCOSECANT:
126 case kind::ARCSECANT:
127 case kind::ARCCOTANGENT:
128 case kind::SQRT: return preRewriteTranscendental(t);
129 case kind::INTS_DIVISION:
130 case kind::INTS_MODULUS: return rewriteIntsDivMod(t, true);
131 case kind::INTS_DIVISION_TOTAL:
132 case kind::INTS_MODULUS_TOTAL:
133 return rewriteIntsDivModTotal(t,true);
134 case kind::ABS:
135 if(t[0].isConst()) {
136 const Rational& rat = t[0].getConst<Rational>();
137 if(rat >= 0) {
138 return RewriteResponse(REWRITE_DONE, t[0]);
139 } else {
140 return RewriteResponse(REWRITE_DONE,
141 NodeManager::currentNM()->mkConst(-rat));
142 }
143 }
144 return RewriteResponse(REWRITE_DONE, t);
145 case kind::IS_INTEGER:
146 case kind::TO_INTEGER:
147 return RewriteResponse(REWRITE_DONE, t);
148 case kind::TO_REAL:
149 case kind::CAST_TO_REAL: return RewriteResponse(REWRITE_DONE, t[0]);
150 case kind::POW:
151 return RewriteResponse(REWRITE_DONE, t);
152 case kind::PI:
153 return RewriteResponse(REWRITE_DONE, t);
154 default: Unhandled() << k;
155 }
156 }
157 }
158
159 RewriteResponse ArithRewriter::postRewriteTerm(TNode t){
160 if(t.isConst()){
161 return rewriteConstant(t);
162 }else if(t.isVar()){
163 return rewriteVariable(t);
164 }else{
165 switch(t.getKind()){
166 case kind::MINUS:
167 return rewriteMinus(t, false);
168 case kind::UMINUS:
169 return rewriteUMinus(t, false);
170 case kind::DIVISION:
171 case kind::DIVISION_TOTAL:
172 return rewriteDiv(t, false);
173 case kind::PLUS:
174 return postRewritePlus(t);
175 case kind::MULT:
176 case kind::NONLINEAR_MULT: return postRewriteMult(t);
177 case kind::IAND: return postRewriteIAnd(t);
178 case kind::EXPONENTIAL:
179 case kind::SINE:
180 case kind::COSINE:
181 case kind::TANGENT:
182 case kind::COSECANT:
183 case kind::SECANT:
184 case kind::COTANGENT:
185 case kind::ARCSINE:
186 case kind::ARCCOSINE:
187 case kind::ARCTANGENT:
188 case kind::ARCCOSECANT:
189 case kind::ARCSECANT:
190 case kind::ARCCOTANGENT:
191 case kind::SQRT: return postRewriteTranscendental(t);
192 case kind::INTS_DIVISION:
193 case kind::INTS_MODULUS: return rewriteIntsDivMod(t, false);
194 case kind::INTS_DIVISION_TOTAL:
195 case kind::INTS_MODULUS_TOTAL:
196 return rewriteIntsDivModTotal(t, false);
197 case kind::ABS:
198 if(t[0].isConst()) {
199 const Rational& rat = t[0].getConst<Rational>();
200 if(rat >= 0) {
201 return RewriteResponse(REWRITE_DONE, t[0]);
202 } else {
203 return RewriteResponse(REWRITE_DONE,
204 NodeManager::currentNM()->mkConst(-rat));
205 }
206 }
207 return RewriteResponse(REWRITE_DONE, t);
208 case kind::TO_REAL:
209 case kind::CAST_TO_REAL: return RewriteResponse(REWRITE_DONE, t[0]);
210 case kind::TO_INTEGER:
211 if(t[0].isConst()) {
212 return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(Rational(t[0].getConst<Rational>().floor())));
213 }
214 if(t[0].getType().isInteger()) {
215 return RewriteResponse(REWRITE_DONE, t[0]);
216 }
217 //Unimplemented() << "TO_INTEGER, nonconstant";
218 //return rewriteToInteger(t);
219 return RewriteResponse(REWRITE_DONE, t);
220 case kind::IS_INTEGER:
221 if(t[0].isConst()) {
222 return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(t[0].getConst<Rational>().getDenominator() == 1));
223 }
224 if(t[0].getType().isInteger()) {
225 return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(true));
226 }
227 //Unimplemented() << "IS_INTEGER, nonconstant";
228 //return rewriteIsInteger(t);
229 return RewriteResponse(REWRITE_DONE, t);
230 case kind::POW:
231 {
232 if(t[1].getKind() == kind::CONST_RATIONAL){
233 const Rational& exp = t[1].getConst<Rational>();
234 TNode base = t[0];
235 if(exp.sgn() == 0){
236 return RewriteResponse(REWRITE_DONE, mkRationalNode(Rational(1)));
237 }else if(exp.sgn() > 0 && exp.isIntegral()){
238 cvc5::Rational r(expr::NodeValue::MAX_CHILDREN);
239 if (exp <= r)
240 {
241 unsigned num = exp.getNumerator().toUnsignedInt();
242 if( num==1 ){
243 return RewriteResponse(REWRITE_AGAIN, base);
244 }else{
245 NodeBuilder nb(kind::MULT);
246 for(unsigned i=0; i < num; ++i){
247 nb << base;
248 }
249 Assert(nb.getNumChildren() > 0);
250 Node mult = nb;
251 return RewriteResponse(REWRITE_AGAIN, mult);
252 }
253 }
254 }
255 }
256
257 // Todo improve the exception thrown
258 std::stringstream ss;
259 ss << "The exponent of the POW(^) operator can only be a positive "
260 "integral constant below "
261 << (expr::NodeValue::MAX_CHILDREN + 1) << ". ";
262 ss << "Exception occurred in:" << std::endl;
263 ss << " " << t;
264 throw LogicException(ss.str());
265 }
266 case kind::PI:
267 return RewriteResponse(REWRITE_DONE, t);
268 default:
269 Unreachable();
270 }
271 }
272 }
273
274
275 RewriteResponse ArithRewriter::preRewriteMult(TNode t){
276 Assert(t.getKind() == kind::MULT || t.getKind() == kind::NONLINEAR_MULT);
277
278 if(t.getNumChildren() == 2){
279 if(t[0].getKind() == kind::CONST_RATIONAL
280 && t[0].getConst<Rational>().isOne()){
281 return RewriteResponse(REWRITE_DONE, t[1]);
282 }
283 if(t[1].getKind() == kind::CONST_RATIONAL
284 && t[1].getConst<Rational>().isOne()){
285 return RewriteResponse(REWRITE_DONE, t[0]);
286 }
287 }
288
289 // Rewrite multiplications with a 0 argument and to 0
290 for(TNode::iterator i = t.begin(); i != t.end(); ++i) {
291 if((*i).getKind() == kind::CONST_RATIONAL) {
292 if((*i).getConst<Rational>().isZero()) {
293 TNode zero = (*i);
294 return RewriteResponse(REWRITE_DONE, zero);
295 }
296 }
297 }
298 return RewriteResponse(REWRITE_DONE, t);
299 }
300
301 static bool canFlatten(Kind k, TNode t){
302 for(TNode::iterator i = t.begin(); i != t.end(); ++i) {
303 TNode child = *i;
304 if(child.getKind() == k){
305 return true;
306 }
307 }
308 return false;
309 }
310
311 static void flatten(std::vector<TNode>& pb, Kind k, TNode t){
312 if(t.getKind() == k){
313 for(TNode::iterator i = t.begin(); i != t.end(); ++i) {
314 TNode child = *i;
315 if(child.getKind() == k){
316 flatten(pb, k, child);
317 }else{
318 pb.push_back(child);
319 }
320 }
321 }else{
322 pb.push_back(t);
323 }
324 }
325
326 static Node flatten(Kind k, TNode t){
327 std::vector<TNode> pb;
328 flatten(pb, k, t);
329 Assert(pb.size() >= 2);
330 return NodeManager::currentNM()->mkNode(k, pb);
331 }
332
333 RewriteResponse ArithRewriter::preRewritePlus(TNode t){
334 Assert(t.getKind() == kind::PLUS);
335
336 if(canFlatten(kind::PLUS, t)){
337 return RewriteResponse(REWRITE_DONE, flatten(kind::PLUS, t));
338 }else{
339 return RewriteResponse(REWRITE_DONE, t);
340 }
341 }
342
343 RewriteResponse ArithRewriter::postRewritePlus(TNode t){
344 Assert(t.getKind() == kind::PLUS);
345
346 std::vector<Monomial> monomials;
347 std::vector<Polynomial> polynomials;
348
349 for(TNode::iterator i = t.begin(), end = t.end(); i != end; ++i){
350 TNode curr = *i;
351 if(Monomial::isMember(curr)){
352 monomials.push_back(Monomial::parseMonomial(curr));
353 }else{
354 polynomials.push_back(Polynomial::parsePolynomial(curr));
355 }
356 }
357
358 if(!monomials.empty()){
359 Monomial::sort(monomials);
360 Monomial::combineAdjacentMonomials(monomials);
361 polynomials.push_back(Polynomial::mkPolynomial(monomials));
362 }
363
364 Polynomial res = Polynomial::sumPolynomials(polynomials);
365
366 return RewriteResponse(REWRITE_DONE, res.getNode());
367 }
368
369 RewriteResponse ArithRewriter::postRewriteMult(TNode t){
370 Assert(t.getKind() == kind::MULT || t.getKind() == kind::NONLINEAR_MULT);
371
372 Polynomial res = Polynomial::mkOne();
373
374 for(TNode::iterator i = t.begin(), end = t.end(); i != end; ++i){
375 Node curr = *i;
376 Polynomial currPoly = Polynomial::parsePolynomial(curr);
377
378 res = res * currPoly;
379 }
380
381 return RewriteResponse(REWRITE_DONE, res.getNode());
382 }
383
384 RewriteResponse ArithRewriter::postRewriteIAnd(TNode t)
385 {
386 Assert(t.getKind() == kind::IAND);
387 NodeManager* nm = NodeManager::currentNM();
388 // if constant, we eliminate
389 if (t[0].isConst() && t[1].isConst())
390 {
391 size_t bsize = t.getOperator().getConst<IntAnd>().d_size;
392 Node iToBvop = nm->mkConst(IntToBitVector(bsize));
393 Node arg1 = nm->mkNode(kind::INT_TO_BITVECTOR, iToBvop, t[0]);
394 Node arg2 = nm->mkNode(kind::INT_TO_BITVECTOR, iToBvop, t[1]);
395 Node bvand = nm->mkNode(kind::BITVECTOR_AND, arg1, arg2);
396 Node ret = nm->mkNode(kind::BITVECTOR_TO_NAT, bvand);
397 return RewriteResponse(REWRITE_AGAIN_FULL, ret);
398 }
399 else if (t[0] > t[1])
400 {
401 // ((_ iand k) x y) ---> ((_ iand k) y x) if x > y by node ordering
402 Node ret = nm->mkNode(kind::IAND, t.getOperator(), t[1], t[0]);
403 return RewriteResponse(REWRITE_AGAIN, ret);
404 }
405 else if (t[0] == t[1])
406 {
407 // ((_ iand k) x x) ---> x
408 return RewriteResponse(REWRITE_DONE, t[0]);
409 }
410 // simplifications involving constants
411 for (unsigned i = 0; i < 2; i++)
412 {
413 if (!t[i].isConst())
414 {
415 continue;
416 }
417 if (t[i].getConst<Rational>().sgn() == 0)
418 {
419 // ((_ iand k) 0 y) ---> 0
420 return RewriteResponse(REWRITE_DONE, t[i]);
421 }
422 }
423 return RewriteResponse(REWRITE_DONE, t);
424 }
425
426 RewriteResponse ArithRewriter::preRewriteTranscendental(TNode t) {
427 return RewriteResponse(REWRITE_DONE, t);
428 }
429
430 RewriteResponse ArithRewriter::postRewriteTranscendental(TNode t) {
431 Trace("arith-tf-rewrite") << "Rewrite transcendental function : " << t << std::endl;
432 NodeManager* nm = NodeManager::currentNM();
433 switch( t.getKind() ){
434 case kind::EXPONENTIAL: {
435 if(t[0].getKind() == kind::CONST_RATIONAL){
436 Node one = nm->mkConst(Rational(1));
437 if(t[0].getConst<Rational>().sgn()>=0 && t[0].getType().isInteger() && t[0]!=one){
438 return RewriteResponse(
439 REWRITE_AGAIN,
440 nm->mkNode(kind::POW, nm->mkNode(kind::EXPONENTIAL, one), t[0]));
441 }else{
442 return RewriteResponse(REWRITE_DONE, t);
443 }
444 }
445 else if (t[0].getKind() == kind::PLUS)
446 {
447 std::vector<Node> product;
448 for (const Node tc : t[0])
449 {
450 product.push_back(nm->mkNode(kind::EXPONENTIAL, tc));
451 }
452 // We need to do a full rewrite here, since we can get exponentials of
453 // constants, e.g. when we are rewriting exp(2 + x)
454 return RewriteResponse(REWRITE_AGAIN_FULL,
455 nm->mkNode(kind::MULT, product));
456 }
457 }
458 break;
459 case kind::SINE:
460 if(t[0].getKind() == kind::CONST_RATIONAL){
461 const Rational& rat = t[0].getConst<Rational>();
462 if(rat.sgn() == 0){
463 return RewriteResponse(REWRITE_DONE, nm->mkConst(Rational(0)));
464 }
465 else if (rat.sgn() == -1)
466 {
467 Node ret =
468 nm->mkNode(kind::UMINUS, nm->mkNode(kind::SINE, nm->mkConst(-rat)));
469 return RewriteResponse(REWRITE_AGAIN_FULL, ret);
470 }
471 }else{
472 // get the factor of PI in the argument
473 Node pi_factor;
474 Node pi;
475 Node rem;
476 std::map<Node, Node> msum;
477 if (ArithMSum::getMonomialSum(t[0], msum))
478 {
479 pi = mkPi();
480 std::map<Node, Node>::iterator itm = msum.find(pi);
481 if (itm != msum.end())
482 {
483 if (itm->second.isNull())
484 {
485 pi_factor = mkRationalNode(Rational(1));
486 }
487 else
488 {
489 pi_factor = itm->second;
490 }
491 msum.erase(pi);
492 if (!msum.empty())
493 {
494 rem = ArithMSum::mkNode(msum);
495 }
496 }
497 }
498 else
499 {
500 Assert(false);
501 }
502
503 // if there is a factor of PI
504 if( !pi_factor.isNull() ){
505 Trace("arith-tf-rewrite-debug") << "Process pi factor = " << pi_factor << std::endl;
506 Rational r = pi_factor.getConst<Rational>();
507 Rational r_abs = r.abs();
508 Rational rone = Rational(1);
509 Node ntwo = mkRationalNode(Rational(2));
510 if (r_abs > rone)
511 {
512 //add/substract 2*pi beyond scope
513 Node ra_div_two = nm->mkNode(
514 kind::INTS_DIVISION, mkRationalNode(r_abs + rone), ntwo);
515 Node new_pi_factor;
516 if( r.sgn()==1 ){
517 new_pi_factor =
518 nm->mkNode(kind::MINUS,
519 pi_factor,
520 nm->mkNode(kind::MULT, ntwo, ra_div_two));
521 }else{
522 Assert(r.sgn() == -1);
523 new_pi_factor =
524 nm->mkNode(kind::PLUS,
525 pi_factor,
526 nm->mkNode(kind::MULT, ntwo, ra_div_two));
527 }
528 Node new_arg = nm->mkNode(kind::MULT, new_pi_factor, pi);
529 if (!rem.isNull())
530 {
531 new_arg = nm->mkNode(kind::PLUS, new_arg, rem);
532 }
533 // sin( 2*n*PI + x ) = sin( x )
534 return RewriteResponse(REWRITE_AGAIN_FULL,
535 nm->mkNode(kind::SINE, new_arg));
536 }
537 else if (r_abs == rone)
538 {
539 // sin( PI + x ) = -sin( x )
540 if (rem.isNull())
541 {
542 return RewriteResponse(REWRITE_DONE, mkRationalNode(Rational(0)));
543 }
544 else
545 {
546 return RewriteResponse(
547 REWRITE_AGAIN_FULL,
548 nm->mkNode(kind::UMINUS, nm->mkNode(kind::SINE, rem)));
549 }
550 }
551 else if (rem.isNull())
552 {
553 // other rational cases based on Niven's theorem
554 // (https://en.wikipedia.org/wiki/Niven%27s_theorem)
555 Integer one = Integer(1);
556 Integer two = Integer(2);
557 Integer six = Integer(6);
558 if (r_abs.getDenominator() == two)
559 {
560 Assert(r_abs.getNumerator() == one);
561 return RewriteResponse(REWRITE_DONE,
562 mkRationalNode(Rational(r.sgn())));
563 }
564 else if (r_abs.getDenominator() == six)
565 {
566 Integer five = Integer(5);
567 if (r_abs.getNumerator() == one || r_abs.getNumerator() == five)
568 {
569 return RewriteResponse(
570 REWRITE_DONE,
571 mkRationalNode(Rational(r.sgn()) / Rational(2)));
572 }
573 }
574 }
575 }
576 }
577 break;
578 case kind::COSINE: {
579 return RewriteResponse(
580 REWRITE_AGAIN_FULL,
581 nm->mkNode(kind::SINE,
582 nm->mkNode(kind::MINUS,
583 nm->mkNode(kind::MULT,
584 nm->mkConst(Rational(1) / Rational(2)),
585 mkPi()),
586 t[0])));
587 }
588 break;
589 case kind::TANGENT:
590 {
591 return RewriteResponse(REWRITE_AGAIN_FULL,
592 nm->mkNode(kind::DIVISION,
593 nm->mkNode(kind::SINE, t[0]),
594 nm->mkNode(kind::COSINE, t[0])));
595 }
596 break;
597 case kind::COSECANT:
598 {
599 return RewriteResponse(REWRITE_AGAIN_FULL,
600 nm->mkNode(kind::DIVISION,
601 mkRationalNode(Rational(1)),
602 nm->mkNode(kind::SINE, t[0])));
603 }
604 break;
605 case kind::SECANT:
606 {
607 return RewriteResponse(REWRITE_AGAIN_FULL,
608 nm->mkNode(kind::DIVISION,
609 mkRationalNode(Rational(1)),
610 nm->mkNode(kind::COSINE, t[0])));
611 }
612 break;
613 case kind::COTANGENT:
614 {
615 return RewriteResponse(REWRITE_AGAIN_FULL,
616 nm->mkNode(kind::DIVISION,
617 nm->mkNode(kind::COSINE, t[0]),
618 nm->mkNode(kind::SINE, t[0])));
619 }
620 break;
621 default:
622 break;
623 }
624 return RewriteResponse(REWRITE_DONE, t);
625 }
626
627 RewriteResponse ArithRewriter::postRewriteAtom(TNode atom){
628 if(atom.getKind() == kind::IS_INTEGER) {
629 if(atom[0].isConst()) {
630 return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(atom[0].getConst<Rational>().isIntegral()));
631 }
632 if(atom[0].getType().isInteger()) {
633 return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(true));
634 }
635 // not supported, but this isn't the right place to complain
636 return RewriteResponse(REWRITE_DONE, atom);
637 } else if(atom.getKind() == kind::DIVISIBLE) {
638 if(atom[0].isConst()) {
639 return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(bool((atom[0].getConst<Rational>() / atom.getOperator().getConst<Divisible>().k).isIntegral())));
640 }
641 if(atom.getOperator().getConst<Divisible>().k.isOne()) {
642 return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(true));
643 }
644 return RewriteResponse(REWRITE_AGAIN, NodeManager::currentNM()->mkNode(kind::EQUAL, NodeManager::currentNM()->mkNode(kind::INTS_MODULUS_TOTAL, atom[0], NodeManager::currentNM()->mkConst(Rational(atom.getOperator().getConst<Divisible>().k))), NodeManager::currentNM()->mkConst(Rational(0))));
645 }
646
647 // left |><| right
648 TNode left = atom[0];
649 TNode right = atom[1];
650
651 Polynomial pleft = Polynomial::parsePolynomial(left);
652 Polynomial pright = Polynomial::parsePolynomial(right);
653
654 Debug("arith::rewriter") << "pleft " << pleft.getNode() << std::endl;
655 Debug("arith::rewriter") << "pright " << pright.getNode() << std::endl;
656
657 Comparison cmp = Comparison::mkComparison(atom.getKind(), pleft, pright);
658 Assert(cmp.isNormalForm());
659 return RewriteResponse(REWRITE_DONE, cmp.getNode());
660 }
661
662 RewriteResponse ArithRewriter::preRewriteAtom(TNode atom){
663 Assert(isAtom(atom));
664
665 NodeManager* currNM = NodeManager::currentNM();
666
667 if(atom.getKind() == kind::EQUAL) {
668 if(atom[0] == atom[1]) {
669 return RewriteResponse(REWRITE_DONE, currNM->mkConst(true));
670 }
671 }else if(atom.getKind() == kind::GT){
672 Node leq = currNM->mkNode(kind::LEQ, atom[0], atom[1]);
673 return RewriteResponse(REWRITE_DONE, currNM->mkNode(kind::NOT, leq));
674 }else if(atom.getKind() == kind::LT){
675 Node geq = currNM->mkNode(kind::GEQ, atom[0], atom[1]);
676 return RewriteResponse(REWRITE_DONE, currNM->mkNode(kind::NOT, geq));
677 }else if(atom.getKind() == kind::IS_INTEGER){
678 if(atom[0].getType().isInteger()){
679 return RewriteResponse(REWRITE_DONE, currNM->mkConst(true));
680 }
681 }else if(atom.getKind() == kind::DIVISIBLE){
682 if(atom.getOperator().getConst<Divisible>().k.isOne()){
683 return RewriteResponse(REWRITE_DONE, currNM->mkConst(true));
684 }
685 }
686
687 return RewriteResponse(REWRITE_DONE, atom);
688 }
689
690 RewriteResponse ArithRewriter::postRewrite(TNode t){
691 if(isTerm(t)){
692 RewriteResponse response = postRewriteTerm(t);
693 if (Debug.isOn("arith::rewriter") && response.d_status == REWRITE_DONE)
694 {
695 Polynomial::parsePolynomial(response.d_node);
696 }
697 return response;
698 }else if(isAtom(t)){
699 RewriteResponse response = postRewriteAtom(t);
700 if (Debug.isOn("arith::rewriter") && response.d_status == REWRITE_DONE)
701 {
702 Comparison::parseNormalForm(response.d_node);
703 }
704 return response;
705 }else{
706 Unreachable();
707 }
708 }
709
710 RewriteResponse ArithRewriter::preRewrite(TNode t){
711 if(isTerm(t)){
712 return preRewriteTerm(t);
713 }else if(isAtom(t)){
714 return preRewriteAtom(t);
715 }else{
716 Unreachable();
717 }
718 }
719
720 Node ArithRewriter::makeUnaryMinusNode(TNode n){
721 Rational qNegOne(-1);
722 return NodeManager::currentNM()->mkNode(kind::MULT, mkRationalNode(qNegOne),n);
723 }
724
725 Node ArithRewriter::makeSubtractionNode(TNode l, TNode r){
726 Node negR = makeUnaryMinusNode(r);
727 Node diff = NodeManager::currentNM()->mkNode(kind::PLUS, l, negR);
728
729 return diff;
730 }
731
732 RewriteResponse ArithRewriter::rewriteDiv(TNode t, bool pre){
733 Assert(t.getKind() == kind::DIVISION_TOTAL || t.getKind() == kind::DIVISION);
734
735 Node left = t[0];
736 Node right = t[1];
737 if(right.getKind() == kind::CONST_RATIONAL){
738 const Rational& den = right.getConst<Rational>();
739
740 if(den.isZero()){
741 if(t.getKind() == kind::DIVISION_TOTAL){
742 return RewriteResponse(REWRITE_DONE, mkRationalNode(0));
743 }else{
744 // This is unsupported, but this is not a good place to complain
745 return RewriteResponse(REWRITE_DONE, t);
746 }
747 }
748 Assert(den != Rational(0));
749
750 if(left.getKind() == kind::CONST_RATIONAL){
751 const Rational& num = left.getConst<Rational>();
752 Rational div = num / den;
753 Node result = mkRationalNode(div);
754 return RewriteResponse(REWRITE_DONE, result);
755 }
756
757 Rational div = den.inverse();
758
759 Node result = mkRationalNode(div);
760
761 Node mult = NodeManager::currentNM()->mkNode(kind::MULT,left,result);
762 if(pre){
763 return RewriteResponse(REWRITE_DONE, mult);
764 }else{
765 return RewriteResponse(REWRITE_AGAIN, mult);
766 }
767 }else{
768 return RewriteResponse(REWRITE_DONE, t);
769 }
770 }
771
772 RewriteResponse ArithRewriter::rewriteIntsDivMod(TNode t, bool pre)
773 {
774 NodeManager* nm = NodeManager::currentNM();
775 Kind k = t.getKind();
776 Node zero = nm->mkConst(Rational(0));
777 if (k == kind::INTS_MODULUS)
778 {
779 if (t[1].isConst() && !t[1].getConst<Rational>().isZero())
780 {
781 // can immediately replace by INTS_MODULUS_TOTAL
782 Node ret = nm->mkNode(kind::INTS_MODULUS_TOTAL, t[0], t[1]);
783 return returnRewrite(t, ret, Rewrite::MOD_TOTAL_BY_CONST);
784 }
785 }
786 if (k == kind::INTS_DIVISION)
787 {
788 if (t[1].isConst() && !t[1].getConst<Rational>().isZero())
789 {
790 // can immediately replace by INTS_DIVISION_TOTAL
791 Node ret = nm->mkNode(kind::INTS_DIVISION_TOTAL, t[0], t[1]);
792 return returnRewrite(t, ret, Rewrite::DIV_TOTAL_BY_CONST);
793 }
794 }
795 return RewriteResponse(REWRITE_DONE, t);
796 }
797
798 RewriteResponse ArithRewriter::rewriteIntsDivModTotal(TNode t, bool pre)
799 {
800 if (pre)
801 {
802 // do not rewrite at prewrite.
803 return RewriteResponse(REWRITE_DONE, t);
804 }
805 NodeManager* nm = NodeManager::currentNM();
806 Kind k = t.getKind();
807 Assert(k == kind::INTS_MODULUS_TOTAL || k == kind::INTS_DIVISION_TOTAL);
808 TNode n = t[0];
809 TNode d = t[1];
810 bool dIsConstant = d.getKind() == kind::CONST_RATIONAL;
811 if(dIsConstant && d.getConst<Rational>().isZero()){
812 // (div x 0) ---> 0 or (mod x 0) ---> 0
813 return returnRewrite(t, mkRationalNode(0), Rewrite::DIV_MOD_BY_ZERO);
814 }else if(dIsConstant && d.getConst<Rational>().isOne()){
815 if (k == kind::INTS_MODULUS_TOTAL)
816 {
817 // (mod x 1) --> 0
818 return returnRewrite(t, mkRationalNode(0), Rewrite::MOD_BY_ONE);
819 }
820 Assert(k == kind::INTS_DIVISION_TOTAL);
821 // (div x 1) --> x
822 return returnRewrite(t, n, Rewrite::DIV_BY_ONE);
823 }
824 else if (dIsConstant && d.getConst<Rational>().sgn() < 0)
825 {
826 // pull negation
827 // (div x (- c)) ---> (- (div x c))
828 // (mod x (- c)) ---> (mod x c)
829 Node nn = nm->mkNode(k, t[0], nm->mkConst(-t[1].getConst<Rational>()));
830 Node ret = (k == kind::INTS_DIVISION || k == kind::INTS_DIVISION_TOTAL)
831 ? nm->mkNode(kind::UMINUS, nn)
832 : nn;
833 return returnRewrite(t, ret, Rewrite::DIV_MOD_PULL_NEG_DEN);
834 }
835 else if (dIsConstant && n.getKind() == kind::CONST_RATIONAL)
836 {
837 Assert(d.getConst<Rational>().isIntegral());
838 Assert(n.getConst<Rational>().isIntegral());
839 Assert(!d.getConst<Rational>().isZero());
840 Integer di = d.getConst<Rational>().getNumerator();
841 Integer ni = n.getConst<Rational>().getNumerator();
842
843 bool isDiv = (k == kind::INTS_DIVISION || k == kind::INTS_DIVISION_TOTAL);
844
845 Integer result = isDiv ? ni.euclidianDivideQuotient(di) : ni.euclidianDivideRemainder(di);
846
847 // constant evaluation
848 // (mod c1 c2) ---> c3 or (div c1 c2) ---> c3
849 Node resultNode = mkRationalNode(Rational(result));
850 return returnRewrite(t, resultNode, Rewrite::CONST_EVAL);
851 }
852 if (k == kind::INTS_MODULUS_TOTAL)
853 {
854 // Note these rewrites do not need to account for modulus by zero as being
855 // a UF, which is handled by the reduction of INTS_MODULUS.
856 Kind k0 = t[0].getKind();
857 if (k0 == kind::INTS_MODULUS_TOTAL && t[0][1] == t[1])
858 {
859 // (mod (mod x c) c) --> (mod x c)
860 return returnRewrite(t, t[0], Rewrite::MOD_OVER_MOD);
861 }
862 else if (k0 == kind::NONLINEAR_MULT || k0 == kind::MULT || k0 == kind::PLUS)
863 {
864 // can drop all
865 std::vector<Node> newChildren;
866 bool childChanged = false;
867 for (const Node& tc : t[0])
868 {
869 if (tc.getKind() == kind::INTS_MODULUS_TOTAL && tc[1] == t[1])
870 {
871 newChildren.push_back(tc[0]);
872 childChanged = true;
873 continue;
874 }
875 newChildren.push_back(tc);
876 }
877 if (childChanged)
878 {
879 // (mod (op ... (mod x c) ...) c) ---> (mod (op ... x ...) c) where
880 // op is one of { NONLINEAR_MULT, MULT, PLUS }.
881 Node ret = nm->mkNode(k0, newChildren);
882 ret = nm->mkNode(kind::INTS_MODULUS_TOTAL, ret, t[1]);
883 return returnRewrite(t, ret, Rewrite::MOD_CHILD_MOD);
884 }
885 }
886 }
887 else
888 {
889 Assert(k == kind::INTS_DIVISION_TOTAL);
890 // Note these rewrites do not need to account for division by zero as being
891 // a UF, which is handled by the reduction of INTS_DIVISION.
892 if (t[0].getKind() == kind::INTS_MODULUS_TOTAL && t[0][1] == t[1])
893 {
894 // (div (mod x c) c) --> 0
895 Node ret = mkRationalNode(0);
896 return returnRewrite(t, ret, Rewrite::DIV_OVER_MOD);
897 }
898 }
899 return RewriteResponse(REWRITE_DONE, t);
900 }
901
902 TrustNode ArithRewriter::expandDefinition(Node node)
903 {
904 // call eliminate operators, to eliminate partial operators only
905 std::vector<SkolemLemma> lems;
906 TrustNode ret = d_opElim.eliminate(node, lems, true);
907 Assert(lems.empty());
908 return ret;
909 }
910
911 RewriteResponse ArithRewriter::returnRewrite(TNode t, Node ret, Rewrite r)
912 {
913 Trace("arith-rewrite") << "ArithRewriter : " << t << " == " << ret << " by "
914 << r << std::endl;
915 return RewriteResponse(REWRITE_AGAIN_FULL, ret);
916 }
917
918 } // namespace arith
919 } // namespace theory
920 } // namespace cvc5