Move expand definition from Theory to TheoryRewriter (#6408)
[cvc5.git] / src / theory / arith / arith_rewriter.cpp
1 /******************************************************************************
2 * Top contributors (to current version):
3 * Andrew Reynolds, Tim King, Morgan Deters
4 *
5 * This file is part of the cvc5 project.
6 *
7 * Copyright (c) 2009-2021 by the authors listed in the file AUTHORS
8 * in the top-level source directory and their institutional affiliations.
9 * All rights reserved. See the file COPYING in the top-level source
10 * directory for licensing information.
11 * ****************************************************************************
12 *
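 * The rewriter for the theory of arithmetic.
 *
 * Pre- and post-rewrites arithmetic terms and atoms: terms are normalized to
 * polynomials, relational atoms to normal-form comparisons, and integer
 * division/modulus, IAND and transcendental applications are simplified.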
17 */
18
19 #include <set>
20 #include <sstream>
21 #include <stack>
22 #include <vector>
23
24 #include "smt/logic_exception.h"
25 #include "theory/arith/arith_msum.h"
26 #include "theory/arith/arith_rewriter.h"
27 #include "theory/arith/arith_utilities.h"
28 #include "theory/arith/normal_form.h"
29 #include "theory/arith/operator_elim.h"
30 #include "theory/theory.h"
31 #include "util/iand.h"
32
33 namespace cvc5 {
34 namespace theory {
35 namespace arith {
36
37 ArithRewriter::ArithRewriter(OperatorElim& oe) : d_opElim(oe) {}
38
39 bool ArithRewriter::isAtom(TNode n) {
40 Kind k = n.getKind();
41 return arith::isRelationOperator(k) || k == kind::IS_INTEGER
42 || k == kind::DIVISIBLE;
43 }
44
45 RewriteResponse ArithRewriter::rewriteConstant(TNode t){
46 Assert(t.isConst());
47 Assert(t.getKind() == kind::CONST_RATIONAL);
48
49 return RewriteResponse(REWRITE_DONE, t);
50 }
51
52 RewriteResponse ArithRewriter::rewriteVariable(TNode t){
53 Assert(t.isVar());
54
55 return RewriteResponse(REWRITE_DONE, t);
56 }
57
58 RewriteResponse ArithRewriter::rewriteMinus(TNode t, bool pre){
59 Assert(t.getKind() == kind::MINUS);
60
61 if(pre){
62 if(t[0] == t[1]){
63 Rational zero(0);
64 Node zeroNode = mkRationalNode(zero);
65 return RewriteResponse(REWRITE_DONE, zeroNode);
66 }else{
67 Node noMinus = makeSubtractionNode(t[0],t[1]);
68 return RewriteResponse(REWRITE_DONE, noMinus);
69 }
70 }else{
71 Polynomial minuend = Polynomial::parsePolynomial(t[0]);
72 Polynomial subtrahend = Polynomial::parsePolynomial(t[1]);
73 Polynomial diff = minuend - subtrahend;
74 return RewriteResponse(REWRITE_DONE, diff.getNode());
75 }
76 }
77
78 RewriteResponse ArithRewriter::rewriteUMinus(TNode t, bool pre){
79 Assert(t.getKind() == kind::UMINUS);
80
81 if(t[0].getKind() == kind::CONST_RATIONAL){
82 Rational neg = -(t[0].getConst<Rational>());
83 return RewriteResponse(REWRITE_DONE, mkRationalNode(neg));
84 }
85
86 Node noUminus = makeUnaryMinusNode(t[0]);
87 if(pre)
88 return RewriteResponse(REWRITE_DONE, noUminus);
89 else
90 return RewriteResponse(REWRITE_AGAIN, noUminus);
91 }
92
93 RewriteResponse ArithRewriter::preRewriteTerm(TNode t){
94 if(t.isConst()){
95 return rewriteConstant(t);
96 }else if(t.isVar()){
97 return rewriteVariable(t);
98 }else{
99 switch(Kind k = t.getKind()){
100 case kind::MINUS:
101 return rewriteMinus(t, true);
102 case kind::UMINUS:
103 return rewriteUMinus(t, true);
104 case kind::DIVISION:
105 case kind::DIVISION_TOTAL:
106 return rewriteDiv(t,true);
107 case kind::PLUS:
108 return preRewritePlus(t);
109 case kind::MULT:
110 case kind::NONLINEAR_MULT: return preRewriteMult(t);
111 case kind::IAND: return RewriteResponse(REWRITE_DONE, t);
112 case kind::EXPONENTIAL:
113 case kind::SINE:
114 case kind::COSINE:
115 case kind::TANGENT:
116 case kind::COSECANT:
117 case kind::SECANT:
118 case kind::COTANGENT:
119 case kind::ARCSINE:
120 case kind::ARCCOSINE:
121 case kind::ARCTANGENT:
122 case kind::ARCCOSECANT:
123 case kind::ARCSECANT:
124 case kind::ARCCOTANGENT:
125 case kind::SQRT: return preRewriteTranscendental(t);
126 case kind::INTS_DIVISION:
127 case kind::INTS_MODULUS: return rewriteIntsDivMod(t, true);
128 case kind::INTS_DIVISION_TOTAL:
129 case kind::INTS_MODULUS_TOTAL:
130 return rewriteIntsDivModTotal(t,true);
131 case kind::ABS:
132 if(t[0].isConst()) {
133 const Rational& rat = t[0].getConst<Rational>();
134 if(rat >= 0) {
135 return RewriteResponse(REWRITE_DONE, t[0]);
136 } else {
137 return RewriteResponse(REWRITE_DONE,
138 NodeManager::currentNM()->mkConst(-rat));
139 }
140 }
141 return RewriteResponse(REWRITE_DONE, t);
142 case kind::IS_INTEGER:
143 case kind::TO_INTEGER:
144 return RewriteResponse(REWRITE_DONE, t);
145 case kind::TO_REAL:
146 case kind::CAST_TO_REAL: return RewriteResponse(REWRITE_DONE, t[0]);
147 case kind::POW:
148 return RewriteResponse(REWRITE_DONE, t);
149 case kind::PI:
150 return RewriteResponse(REWRITE_DONE, t);
151 default: Unhandled() << k;
152 }
153 }
154 }
155
156 RewriteResponse ArithRewriter::postRewriteTerm(TNode t){
157 if(t.isConst()){
158 return rewriteConstant(t);
159 }else if(t.isVar()){
160 return rewriteVariable(t);
161 }else{
162 switch(t.getKind()){
163 case kind::MINUS:
164 return rewriteMinus(t, false);
165 case kind::UMINUS:
166 return rewriteUMinus(t, false);
167 case kind::DIVISION:
168 case kind::DIVISION_TOTAL:
169 return rewriteDiv(t, false);
170 case kind::PLUS:
171 return postRewritePlus(t);
172 case kind::MULT:
173 case kind::NONLINEAR_MULT: return postRewriteMult(t);
174 case kind::IAND: return postRewriteIAnd(t);
175 case kind::EXPONENTIAL:
176 case kind::SINE:
177 case kind::COSINE:
178 case kind::TANGENT:
179 case kind::COSECANT:
180 case kind::SECANT:
181 case kind::COTANGENT:
182 case kind::ARCSINE:
183 case kind::ARCCOSINE:
184 case kind::ARCTANGENT:
185 case kind::ARCCOSECANT:
186 case kind::ARCSECANT:
187 case kind::ARCCOTANGENT:
188 case kind::SQRT: return postRewriteTranscendental(t);
189 case kind::INTS_DIVISION:
190 case kind::INTS_MODULUS: return rewriteIntsDivMod(t, false);
191 case kind::INTS_DIVISION_TOTAL:
192 case kind::INTS_MODULUS_TOTAL:
193 return rewriteIntsDivModTotal(t, false);
194 case kind::ABS:
195 if(t[0].isConst()) {
196 const Rational& rat = t[0].getConst<Rational>();
197 if(rat >= 0) {
198 return RewriteResponse(REWRITE_DONE, t[0]);
199 } else {
200 return RewriteResponse(REWRITE_DONE,
201 NodeManager::currentNM()->mkConst(-rat));
202 }
203 }
204 return RewriteResponse(REWRITE_DONE, t);
205 case kind::TO_REAL:
206 case kind::CAST_TO_REAL: return RewriteResponse(REWRITE_DONE, t[0]);
207 case kind::TO_INTEGER:
208 if(t[0].isConst()) {
209 return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(Rational(t[0].getConst<Rational>().floor())));
210 }
211 if(t[0].getType().isInteger()) {
212 return RewriteResponse(REWRITE_DONE, t[0]);
213 }
214 //Unimplemented() << "TO_INTEGER, nonconstant";
215 //return rewriteToInteger(t);
216 return RewriteResponse(REWRITE_DONE, t);
217 case kind::IS_INTEGER:
218 if(t[0].isConst()) {
219 return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(t[0].getConst<Rational>().getDenominator() == 1));
220 }
221 if(t[0].getType().isInteger()) {
222 return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(true));
223 }
224 //Unimplemented() << "IS_INTEGER, nonconstant";
225 //return rewriteIsInteger(t);
226 return RewriteResponse(REWRITE_DONE, t);
227 case kind::POW:
228 {
229 if(t[1].getKind() == kind::CONST_RATIONAL){
230 const Rational& exp = t[1].getConst<Rational>();
231 TNode base = t[0];
232 if(exp.sgn() == 0){
233 return RewriteResponse(REWRITE_DONE, mkRationalNode(Rational(1)));
234 }else if(exp.sgn() > 0 && exp.isIntegral()){
235 cvc5::Rational r(expr::NodeValue::MAX_CHILDREN);
236 if (exp <= r)
237 {
238 unsigned num = exp.getNumerator().toUnsignedInt();
239 if( num==1 ){
240 return RewriteResponse(REWRITE_AGAIN, base);
241 }else{
242 NodeBuilder nb(kind::MULT);
243 for(unsigned i=0; i < num; ++i){
244 nb << base;
245 }
246 Assert(nb.getNumChildren() > 0);
247 Node mult = nb;
248 return RewriteResponse(REWRITE_AGAIN, mult);
249 }
250 }
251 }
252 }
253
254 // Todo improve the exception thrown
255 std::stringstream ss;
256 ss << "The exponent of the POW(^) operator can only be a positive "
257 "integral constant below "
258 << (expr::NodeValue::MAX_CHILDREN + 1) << ". ";
259 ss << "Exception occurred in:" << std::endl;
260 ss << " " << t;
261 throw LogicException(ss.str());
262 }
263 case kind::PI:
264 return RewriteResponse(REWRITE_DONE, t);
265 default:
266 Unreachable();
267 }
268 }
269 }
270
271
272 RewriteResponse ArithRewriter::preRewriteMult(TNode t){
273 Assert(t.getKind() == kind::MULT || t.getKind() == kind::NONLINEAR_MULT);
274
275 if(t.getNumChildren() == 2){
276 if(t[0].getKind() == kind::CONST_RATIONAL
277 && t[0].getConst<Rational>().isOne()){
278 return RewriteResponse(REWRITE_DONE, t[1]);
279 }
280 if(t[1].getKind() == kind::CONST_RATIONAL
281 && t[1].getConst<Rational>().isOne()){
282 return RewriteResponse(REWRITE_DONE, t[0]);
283 }
284 }
285
286 // Rewrite multiplications with a 0 argument and to 0
287 for(TNode::iterator i = t.begin(); i != t.end(); ++i) {
288 if((*i).getKind() == kind::CONST_RATIONAL) {
289 if((*i).getConst<Rational>().isZero()) {
290 TNode zero = (*i);
291 return RewriteResponse(REWRITE_DONE, zero);
292 }
293 }
294 }
295 return RewriteResponse(REWRITE_DONE, t);
296 }
297
298 static bool canFlatten(Kind k, TNode t){
299 for(TNode::iterator i = t.begin(); i != t.end(); ++i) {
300 TNode child = *i;
301 if(child.getKind() == k){
302 return true;
303 }
304 }
305 return false;
306 }
307
308 static void flatten(std::vector<TNode>& pb, Kind k, TNode t){
309 if(t.getKind() == k){
310 for(TNode::iterator i = t.begin(); i != t.end(); ++i) {
311 TNode child = *i;
312 if(child.getKind() == k){
313 flatten(pb, k, child);
314 }else{
315 pb.push_back(child);
316 }
317 }
318 }else{
319 pb.push_back(t);
320 }
321 }
322
323 static Node flatten(Kind k, TNode t){
324 std::vector<TNode> pb;
325 flatten(pb, k, t);
326 Assert(pb.size() >= 2);
327 return NodeManager::currentNM()->mkNode(k, pb);
328 }
329
330 RewriteResponse ArithRewriter::preRewritePlus(TNode t){
331 Assert(t.getKind() == kind::PLUS);
332
333 if(canFlatten(kind::PLUS, t)){
334 return RewriteResponse(REWRITE_DONE, flatten(kind::PLUS, t));
335 }else{
336 return RewriteResponse(REWRITE_DONE, t);
337 }
338 }
339
340 RewriteResponse ArithRewriter::postRewritePlus(TNode t){
341 Assert(t.getKind() == kind::PLUS);
342
343 std::vector<Monomial> monomials;
344 std::vector<Polynomial> polynomials;
345
346 for(TNode::iterator i = t.begin(), end = t.end(); i != end; ++i){
347 TNode curr = *i;
348 if(Monomial::isMember(curr)){
349 monomials.push_back(Monomial::parseMonomial(curr));
350 }else{
351 polynomials.push_back(Polynomial::parsePolynomial(curr));
352 }
353 }
354
355 if(!monomials.empty()){
356 Monomial::sort(monomials);
357 Monomial::combineAdjacentMonomials(monomials);
358 polynomials.push_back(Polynomial::mkPolynomial(monomials));
359 }
360
361 Polynomial res = Polynomial::sumPolynomials(polynomials);
362
363 return RewriteResponse(REWRITE_DONE, res.getNode());
364 }
365
366 RewriteResponse ArithRewriter::postRewriteMult(TNode t){
367 Assert(t.getKind() == kind::MULT || t.getKind() == kind::NONLINEAR_MULT);
368
369 Polynomial res = Polynomial::mkOne();
370
371 for(TNode::iterator i = t.begin(), end = t.end(); i != end; ++i){
372 Node curr = *i;
373 Polynomial currPoly = Polynomial::parsePolynomial(curr);
374
375 res = res * currPoly;
376 }
377
378 return RewriteResponse(REWRITE_DONE, res.getNode());
379 }
380
381 RewriteResponse ArithRewriter::postRewriteIAnd(TNode t)
382 {
383 Assert(t.getKind() == kind::IAND);
384 NodeManager* nm = NodeManager::currentNM();
385 // if constant, we eliminate
386 if (t[0].isConst() && t[1].isConst())
387 {
388 size_t bsize = t.getOperator().getConst<IntAnd>().d_size;
389 Node iToBvop = nm->mkConst(IntToBitVector(bsize));
390 Node arg1 = nm->mkNode(kind::INT_TO_BITVECTOR, iToBvop, t[0]);
391 Node arg2 = nm->mkNode(kind::INT_TO_BITVECTOR, iToBvop, t[1]);
392 Node bvand = nm->mkNode(kind::BITVECTOR_AND, arg1, arg2);
393 Node ret = nm->mkNode(kind::BITVECTOR_TO_NAT, bvand);
394 return RewriteResponse(REWRITE_AGAIN_FULL, ret);
395 }
396 else if (t[0] > t[1])
397 {
398 // ((_ iand k) x y) ---> ((_ iand k) y x) if x > y by node ordering
399 Node ret = nm->mkNode(kind::IAND, t.getOperator(), t[1], t[0]);
400 return RewriteResponse(REWRITE_AGAIN, ret);
401 }
402 else if (t[0] == t[1])
403 {
404 // ((_ iand k) x x) ---> x
405 return RewriteResponse(REWRITE_DONE, t[0]);
406 }
407 // simplifications involving constants
408 for (unsigned i = 0; i < 2; i++)
409 {
410 if (!t[i].isConst())
411 {
412 continue;
413 }
414 if (t[i].getConst<Rational>().sgn() == 0)
415 {
416 // ((_ iand k) 0 y) ---> 0
417 return RewriteResponse(REWRITE_DONE, t[i]);
418 }
419 }
420 return RewriteResponse(REWRITE_DONE, t);
421 }
422
423 RewriteResponse ArithRewriter::preRewriteTranscendental(TNode t) {
424 return RewriteResponse(REWRITE_DONE, t);
425 }
426
427 RewriteResponse ArithRewriter::postRewriteTranscendental(TNode t) {
428 Trace("arith-tf-rewrite") << "Rewrite transcendental function : " << t << std::endl;
429 NodeManager* nm = NodeManager::currentNM();
430 switch( t.getKind() ){
431 case kind::EXPONENTIAL: {
432 if(t[0].getKind() == kind::CONST_RATIONAL){
433 Node one = nm->mkConst(Rational(1));
434 if(t[0].getConst<Rational>().sgn()>=0 && t[0].getType().isInteger() && t[0]!=one){
435 return RewriteResponse(
436 REWRITE_AGAIN,
437 nm->mkNode(kind::POW, nm->mkNode(kind::EXPONENTIAL, one), t[0]));
438 }else{
439 return RewriteResponse(REWRITE_DONE, t);
440 }
441 }
442 else if (t[0].getKind() == kind::PLUS)
443 {
444 std::vector<Node> product;
445 for (const Node tc : t[0])
446 {
447 product.push_back(nm->mkNode(kind::EXPONENTIAL, tc));
448 }
449 // We need to do a full rewrite here, since we can get exponentials of
450 // constants, e.g. when we are rewriting exp(2 + x)
451 return RewriteResponse(REWRITE_AGAIN_FULL,
452 nm->mkNode(kind::MULT, product));
453 }
454 }
455 break;
456 case kind::SINE:
457 if(t[0].getKind() == kind::CONST_RATIONAL){
458 const Rational& rat = t[0].getConst<Rational>();
459 if(rat.sgn() == 0){
460 return RewriteResponse(REWRITE_DONE, nm->mkConst(Rational(0)));
461 }
462 else if (rat.sgn() == -1)
463 {
464 Node ret =
465 nm->mkNode(kind::UMINUS, nm->mkNode(kind::SINE, nm->mkConst(-rat)));
466 return RewriteResponse(REWRITE_AGAIN_FULL, ret);
467 }
468 }else{
469 // get the factor of PI in the argument
470 Node pi_factor;
471 Node pi;
472 Node rem;
473 std::map<Node, Node> msum;
474 if (ArithMSum::getMonomialSum(t[0], msum))
475 {
476 pi = mkPi();
477 std::map<Node, Node>::iterator itm = msum.find(pi);
478 if (itm != msum.end())
479 {
480 if (itm->second.isNull())
481 {
482 pi_factor = mkRationalNode(Rational(1));
483 }
484 else
485 {
486 pi_factor = itm->second;
487 }
488 msum.erase(pi);
489 if (!msum.empty())
490 {
491 rem = ArithMSum::mkNode(msum);
492 }
493 }
494 }
495 else
496 {
497 Assert(false);
498 }
499
500 // if there is a factor of PI
501 if( !pi_factor.isNull() ){
502 Trace("arith-tf-rewrite-debug") << "Process pi factor = " << pi_factor << std::endl;
503 Rational r = pi_factor.getConst<Rational>();
504 Rational r_abs = r.abs();
505 Rational rone = Rational(1);
506 Node ntwo = mkRationalNode(Rational(2));
507 if (r_abs > rone)
508 {
509 //add/substract 2*pi beyond scope
510 Node ra_div_two = nm->mkNode(
511 kind::INTS_DIVISION, mkRationalNode(r_abs + rone), ntwo);
512 Node new_pi_factor;
513 if( r.sgn()==1 ){
514 new_pi_factor =
515 nm->mkNode(kind::MINUS,
516 pi_factor,
517 nm->mkNode(kind::MULT, ntwo, ra_div_two));
518 }else{
519 Assert(r.sgn() == -1);
520 new_pi_factor =
521 nm->mkNode(kind::PLUS,
522 pi_factor,
523 nm->mkNode(kind::MULT, ntwo, ra_div_two));
524 }
525 Node new_arg = nm->mkNode(kind::MULT, new_pi_factor, pi);
526 if (!rem.isNull())
527 {
528 new_arg = nm->mkNode(kind::PLUS, new_arg, rem);
529 }
530 // sin( 2*n*PI + x ) = sin( x )
531 return RewriteResponse(REWRITE_AGAIN_FULL,
532 nm->mkNode(kind::SINE, new_arg));
533 }
534 else if (r_abs == rone)
535 {
536 // sin( PI + x ) = -sin( x )
537 if (rem.isNull())
538 {
539 return RewriteResponse(REWRITE_DONE, mkRationalNode(Rational(0)));
540 }
541 else
542 {
543 return RewriteResponse(
544 REWRITE_AGAIN_FULL,
545 nm->mkNode(kind::UMINUS, nm->mkNode(kind::SINE, rem)));
546 }
547 }
548 else if (rem.isNull())
549 {
550 // other rational cases based on Niven's theorem
551 // (https://en.wikipedia.org/wiki/Niven%27s_theorem)
552 Integer one = Integer(1);
553 Integer two = Integer(2);
554 Integer six = Integer(6);
555 if (r_abs.getDenominator() == two)
556 {
557 Assert(r_abs.getNumerator() == one);
558 return RewriteResponse(REWRITE_DONE,
559 mkRationalNode(Rational(r.sgn())));
560 }
561 else if (r_abs.getDenominator() == six)
562 {
563 Integer five = Integer(5);
564 if (r_abs.getNumerator() == one || r_abs.getNumerator() == five)
565 {
566 return RewriteResponse(
567 REWRITE_DONE,
568 mkRationalNode(Rational(r.sgn()) / Rational(2)));
569 }
570 }
571 }
572 }
573 }
574 break;
575 case kind::COSINE: {
576 return RewriteResponse(
577 REWRITE_AGAIN_FULL,
578 nm->mkNode(kind::SINE,
579 nm->mkNode(kind::MINUS,
580 nm->mkNode(kind::MULT,
581 nm->mkConst(Rational(1) / Rational(2)),
582 mkPi()),
583 t[0])));
584 }
585 break;
586 case kind::TANGENT:
587 {
588 return RewriteResponse(REWRITE_AGAIN_FULL,
589 nm->mkNode(kind::DIVISION,
590 nm->mkNode(kind::SINE, t[0]),
591 nm->mkNode(kind::COSINE, t[0])));
592 }
593 break;
594 case kind::COSECANT:
595 {
596 return RewriteResponse(REWRITE_AGAIN_FULL,
597 nm->mkNode(kind::DIVISION,
598 mkRationalNode(Rational(1)),
599 nm->mkNode(kind::SINE, t[0])));
600 }
601 break;
602 case kind::SECANT:
603 {
604 return RewriteResponse(REWRITE_AGAIN_FULL,
605 nm->mkNode(kind::DIVISION,
606 mkRationalNode(Rational(1)),
607 nm->mkNode(kind::COSINE, t[0])));
608 }
609 break;
610 case kind::COTANGENT:
611 {
612 return RewriteResponse(REWRITE_AGAIN_FULL,
613 nm->mkNode(kind::DIVISION,
614 nm->mkNode(kind::COSINE, t[0]),
615 nm->mkNode(kind::SINE, t[0])));
616 }
617 break;
618 default:
619 break;
620 }
621 return RewriteResponse(REWRITE_DONE, t);
622 }
623
624 RewriteResponse ArithRewriter::postRewriteAtom(TNode atom){
625 if(atom.getKind() == kind::IS_INTEGER) {
626 if(atom[0].isConst()) {
627 return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(atom[0].getConst<Rational>().isIntegral()));
628 }
629 if(atom[0].getType().isInteger()) {
630 return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(true));
631 }
632 // not supported, but this isn't the right place to complain
633 return RewriteResponse(REWRITE_DONE, atom);
634 } else if(atom.getKind() == kind::DIVISIBLE) {
635 if(atom[0].isConst()) {
636 return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(bool((atom[0].getConst<Rational>() / atom.getOperator().getConst<Divisible>().k).isIntegral())));
637 }
638 if(atom.getOperator().getConst<Divisible>().k.isOne()) {
639 return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(true));
640 }
641 return RewriteResponse(REWRITE_AGAIN, NodeManager::currentNM()->mkNode(kind::EQUAL, NodeManager::currentNM()->mkNode(kind::INTS_MODULUS_TOTAL, atom[0], NodeManager::currentNM()->mkConst(Rational(atom.getOperator().getConst<Divisible>().k))), NodeManager::currentNM()->mkConst(Rational(0))));
642 }
643
644 // left |><| right
645 TNode left = atom[0];
646 TNode right = atom[1];
647
648 Polynomial pleft = Polynomial::parsePolynomial(left);
649 Polynomial pright = Polynomial::parsePolynomial(right);
650
651 Debug("arith::rewriter") << "pleft " << pleft.getNode() << std::endl;
652 Debug("arith::rewriter") << "pright " << pright.getNode() << std::endl;
653
654 Comparison cmp = Comparison::mkComparison(atom.getKind(), pleft, pright);
655 Assert(cmp.isNormalForm());
656 return RewriteResponse(REWRITE_DONE, cmp.getNode());
657 }
658
659 RewriteResponse ArithRewriter::preRewriteAtom(TNode atom){
660 Assert(isAtom(atom));
661
662 NodeManager* currNM = NodeManager::currentNM();
663
664 if(atom.getKind() == kind::EQUAL) {
665 if(atom[0] == atom[1]) {
666 return RewriteResponse(REWRITE_DONE, currNM->mkConst(true));
667 }
668 }else if(atom.getKind() == kind::GT){
669 Node leq = currNM->mkNode(kind::LEQ, atom[0], atom[1]);
670 return RewriteResponse(REWRITE_DONE, currNM->mkNode(kind::NOT, leq));
671 }else if(atom.getKind() == kind::LT){
672 Node geq = currNM->mkNode(kind::GEQ, atom[0], atom[1]);
673 return RewriteResponse(REWRITE_DONE, currNM->mkNode(kind::NOT, geq));
674 }else if(atom.getKind() == kind::IS_INTEGER){
675 if(atom[0].getType().isInteger()){
676 return RewriteResponse(REWRITE_DONE, currNM->mkConst(true));
677 }
678 }else if(atom.getKind() == kind::DIVISIBLE){
679 if(atom.getOperator().getConst<Divisible>().k.isOne()){
680 return RewriteResponse(REWRITE_DONE, currNM->mkConst(true));
681 }
682 }
683
684 return RewriteResponse(REWRITE_DONE, atom);
685 }
686
687 RewriteResponse ArithRewriter::postRewrite(TNode t){
688 if(isTerm(t)){
689 RewriteResponse response = postRewriteTerm(t);
690 if (Debug.isOn("arith::rewriter") && response.d_status == REWRITE_DONE)
691 {
692 Polynomial::parsePolynomial(response.d_node);
693 }
694 return response;
695 }else if(isAtom(t)){
696 RewriteResponse response = postRewriteAtom(t);
697 if (Debug.isOn("arith::rewriter") && response.d_status == REWRITE_DONE)
698 {
699 Comparison::parseNormalForm(response.d_node);
700 }
701 return response;
702 }else{
703 Unreachable();
704 }
705 }
706
707 RewriteResponse ArithRewriter::preRewrite(TNode t){
708 if(isTerm(t)){
709 return preRewriteTerm(t);
710 }else if(isAtom(t)){
711 return preRewriteAtom(t);
712 }else{
713 Unreachable();
714 }
715 }
716
717 Node ArithRewriter::makeUnaryMinusNode(TNode n){
718 Rational qNegOne(-1);
719 return NodeManager::currentNM()->mkNode(kind::MULT, mkRationalNode(qNegOne),n);
720 }
721
722 Node ArithRewriter::makeSubtractionNode(TNode l, TNode r){
723 Node negR = makeUnaryMinusNode(r);
724 Node diff = NodeManager::currentNM()->mkNode(kind::PLUS, l, negR);
725
726 return diff;
727 }
728
729 RewriteResponse ArithRewriter::rewriteDiv(TNode t, bool pre){
730 Assert(t.getKind() == kind::DIVISION_TOTAL || t.getKind() == kind::DIVISION);
731
732 Node left = t[0];
733 Node right = t[1];
734 if(right.getKind() == kind::CONST_RATIONAL){
735 const Rational& den = right.getConst<Rational>();
736
737 if(den.isZero()){
738 if(t.getKind() == kind::DIVISION_TOTAL){
739 return RewriteResponse(REWRITE_DONE, mkRationalNode(0));
740 }else{
741 // This is unsupported, but this is not a good place to complain
742 return RewriteResponse(REWRITE_DONE, t);
743 }
744 }
745 Assert(den != Rational(0));
746
747 if(left.getKind() == kind::CONST_RATIONAL){
748 const Rational& num = left.getConst<Rational>();
749 Rational div = num / den;
750 Node result = mkRationalNode(div);
751 return RewriteResponse(REWRITE_DONE, result);
752 }
753
754 Rational div = den.inverse();
755
756 Node result = mkRationalNode(div);
757
758 Node mult = NodeManager::currentNM()->mkNode(kind::MULT,left,result);
759 if(pre){
760 return RewriteResponse(REWRITE_DONE, mult);
761 }else{
762 return RewriteResponse(REWRITE_AGAIN, mult);
763 }
764 }else{
765 return RewriteResponse(REWRITE_DONE, t);
766 }
767 }
768
769 RewriteResponse ArithRewriter::rewriteIntsDivMod(TNode t, bool pre)
770 {
771 NodeManager* nm = NodeManager::currentNM();
772 Kind k = t.getKind();
773 Node zero = nm->mkConst(Rational(0));
774 if (k == kind::INTS_MODULUS)
775 {
776 if (t[1].isConst() && !t[1].getConst<Rational>().isZero())
777 {
778 // can immediately replace by INTS_MODULUS_TOTAL
779 Node ret = nm->mkNode(kind::INTS_MODULUS_TOTAL, t[0], t[1]);
780 return returnRewrite(t, ret, Rewrite::MOD_TOTAL_BY_CONST);
781 }
782 }
783 if (k == kind::INTS_DIVISION)
784 {
785 if (t[1].isConst() && !t[1].getConst<Rational>().isZero())
786 {
787 // can immediately replace by INTS_DIVISION_TOTAL
788 Node ret = nm->mkNode(kind::INTS_DIVISION_TOTAL, t[0], t[1]);
789 return returnRewrite(t, ret, Rewrite::DIV_TOTAL_BY_CONST);
790 }
791 }
792 return RewriteResponse(REWRITE_DONE, t);
793 }
794
795 RewriteResponse ArithRewriter::rewriteIntsDivModTotal(TNode t, bool pre)
796 {
797 if (pre)
798 {
799 // do not rewrite at prewrite.
800 return RewriteResponse(REWRITE_DONE, t);
801 }
802 NodeManager* nm = NodeManager::currentNM();
803 Kind k = t.getKind();
804 Assert(k == kind::INTS_MODULUS_TOTAL || k == kind::INTS_DIVISION_TOTAL);
805 TNode n = t[0];
806 TNode d = t[1];
807 bool dIsConstant = d.getKind() == kind::CONST_RATIONAL;
808 if(dIsConstant && d.getConst<Rational>().isZero()){
809 // (div x 0) ---> 0 or (mod x 0) ---> 0
810 return returnRewrite(t, mkRationalNode(0), Rewrite::DIV_MOD_BY_ZERO);
811 }else if(dIsConstant && d.getConst<Rational>().isOne()){
812 if (k == kind::INTS_MODULUS_TOTAL)
813 {
814 // (mod x 1) --> 0
815 return returnRewrite(t, mkRationalNode(0), Rewrite::MOD_BY_ONE);
816 }
817 Assert(k == kind::INTS_DIVISION_TOTAL);
818 // (div x 1) --> x
819 return returnRewrite(t, n, Rewrite::DIV_BY_ONE);
820 }
821 else if (dIsConstant && d.getConst<Rational>().sgn() < 0)
822 {
823 // pull negation
824 // (div x (- c)) ---> (- (div x c))
825 // (mod x (- c)) ---> (mod x c)
826 Node nn = nm->mkNode(k, t[0], nm->mkConst(-t[1].getConst<Rational>()));
827 Node ret = (k == kind::INTS_DIVISION || k == kind::INTS_DIVISION_TOTAL)
828 ? nm->mkNode(kind::UMINUS, nn)
829 : nn;
830 return returnRewrite(t, ret, Rewrite::DIV_MOD_PULL_NEG_DEN);
831 }
832 else if (dIsConstant && n.getKind() == kind::CONST_RATIONAL)
833 {
834 Assert(d.getConst<Rational>().isIntegral());
835 Assert(n.getConst<Rational>().isIntegral());
836 Assert(!d.getConst<Rational>().isZero());
837 Integer di = d.getConst<Rational>().getNumerator();
838 Integer ni = n.getConst<Rational>().getNumerator();
839
840 bool isDiv = (k == kind::INTS_DIVISION || k == kind::INTS_DIVISION_TOTAL);
841
842 Integer result = isDiv ? ni.euclidianDivideQuotient(di) : ni.euclidianDivideRemainder(di);
843
844 // constant evaluation
845 // (mod c1 c2) ---> c3 or (div c1 c2) ---> c3
846 Node resultNode = mkRationalNode(Rational(result));
847 return returnRewrite(t, resultNode, Rewrite::CONST_EVAL);
848 }
849 if (k == kind::INTS_MODULUS_TOTAL)
850 {
851 // Note these rewrites do not need to account for modulus by zero as being
852 // a UF, which is handled by the reduction of INTS_MODULUS.
853 Kind k0 = t[0].getKind();
854 if (k0 == kind::INTS_MODULUS_TOTAL && t[0][1] == t[1])
855 {
856 // (mod (mod x c) c) --> (mod x c)
857 return returnRewrite(t, t[0], Rewrite::MOD_OVER_MOD);
858 }
859 else if (k0 == kind::NONLINEAR_MULT || k0 == kind::MULT || k0 == kind::PLUS)
860 {
861 // can drop all
862 std::vector<Node> newChildren;
863 bool childChanged = false;
864 for (const Node& tc : t[0])
865 {
866 if (tc.getKind() == kind::INTS_MODULUS_TOTAL && tc[1] == t[1])
867 {
868 newChildren.push_back(tc[0]);
869 childChanged = true;
870 continue;
871 }
872 newChildren.push_back(tc);
873 }
874 if (childChanged)
875 {
876 // (mod (op ... (mod x c) ...) c) ---> (mod (op ... x ...) c) where
877 // op is one of { NONLINEAR_MULT, MULT, PLUS }.
878 Node ret = nm->mkNode(k0, newChildren);
879 ret = nm->mkNode(kind::INTS_MODULUS_TOTAL, ret, t[1]);
880 return returnRewrite(t, ret, Rewrite::MOD_CHILD_MOD);
881 }
882 }
883 }
884 else
885 {
886 Assert(k == kind::INTS_DIVISION_TOTAL);
887 // Note these rewrites do not need to account for division by zero as being
888 // a UF, which is handled by the reduction of INTS_DIVISION.
889 if (t[0].getKind() == kind::INTS_MODULUS_TOTAL && t[0][1] == t[1])
890 {
891 // (div (mod x c) c) --> 0
892 Node ret = mkRationalNode(0);
893 return returnRewrite(t, ret, Rewrite::DIV_OVER_MOD);
894 }
895 }
896 return RewriteResponse(REWRITE_DONE, t);
897 }
898
899 TrustNode ArithRewriter::expandDefinition(Node node)
900 {
901 // call eliminate operators, to eliminate partial operators only
902 std::vector<SkolemLemma> lems;
903 TrustNode ret = d_opElim.eliminate(node, lems, true);
904 Assert(lems.empty());
905 return ret;
906 }
907
908 RewriteResponse ArithRewriter::returnRewrite(TNode t, Node ret, Rewrite r)
909 {
910 Trace("arith-rewrite") << "ArithRewriter : " << t << " == " << ret << " by "
911 << r << std::endl;
912 return RewriteResponse(REWRITE_AGAIN_FULL, ret);
913 }
914
915 } // namespace arith
916 } // namespace theory
917 } // namespace cvc5