Merge pull request #28 from kbansal/sets
[cvc5.git] / src / theory / arith / arith_rewriter.cpp
1 /********************* */
2 /*! \file arith_rewriter.cpp
3 ** \verbatim
4 ** Original author: Tim King
5 ** Major contributors: Morgan Deters
6 ** Minor contributors (to current version): Dejan Jovanovic
7 ** This file is part of the CVC4 project.
8 ** Copyright (c) 2009-2013 New York University and The University of Iowa
9 ** See the file COPYING in the top-level source directory for licensing
10 ** information.\endverbatim
11 **
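** \brief Rewriter for the theory of arithmetic.
**
** Pre- and post-rewriting of arithmetic terms and atoms (relations,
** IS_INTEGER and DIVISIBLE) toward the Polynomial / Comparison normal
** form of normal_form.h.
** \todo document this file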
16 **/
17
18 #include "theory/theory.h"
19 #include "theory/arith/normal_form.h"
20 #include "theory/arith/arith_rewriter.h"
21 #include "theory/arith/arith_utilities.h"
22
23 #include <vector>
24 #include <set>
25 #include <stack>
26
27 namespace CVC4 {
28 namespace theory {
29 namespace arith {
30
31 bool ArithRewriter::isAtom(TNode n) {
32 Kind k = n.getKind();
33 return arith::isRelationOperator(k) || k == kind::IS_INTEGER || k == kind::DIVISIBLE;
34 }
35
36 RewriteResponse ArithRewriter::rewriteConstant(TNode t){
37 Assert(t.isConst());
38 Assert(t.getKind() == kind::CONST_RATIONAL);
39
40 return RewriteResponse(REWRITE_DONE, t);
41 }
42
43 RewriteResponse ArithRewriter::rewriteVariable(TNode t){
44 Assert(t.isVar());
45
46 return RewriteResponse(REWRITE_DONE, t);
47 }
48
49 RewriteResponse ArithRewriter::rewriteMinus(TNode t, bool pre){
50 Assert(t.getKind()== kind::MINUS);
51
52 if(pre){
53 if(t[0] == t[1]){
54 Rational zero(0);
55 Node zeroNode = mkRationalNode(zero);
56 return RewriteResponse(REWRITE_DONE, zeroNode);
57 }else{
58 Node noMinus = makeSubtractionNode(t[0],t[1]);
59 return RewriteResponse(REWRITE_DONE, noMinus);
60 }
61 }else{
62 Polynomial minuend = Polynomial::parsePolynomial(t[0]);
63 Polynomial subtrahend = Polynomial::parsePolynomial(t[1]);
64 Polynomial diff = minuend - subtrahend;
65 return RewriteResponse(REWRITE_DONE, diff.getNode());
66 }
67 }
68
69 RewriteResponse ArithRewriter::rewriteUMinus(TNode t, bool pre){
70 Assert(t.getKind()== kind::UMINUS);
71
72 if(t[0].getKind() == kind::CONST_RATIONAL){
73 Rational neg = -(t[0].getConst<Rational>());
74 return RewriteResponse(REWRITE_DONE, mkRationalNode(neg));
75 }
76
77 Node noUminus = makeUnaryMinusNode(t[0]);
78 if(pre)
79 return RewriteResponse(REWRITE_DONE, noUminus);
80 else
81 return RewriteResponse(REWRITE_AGAIN, noUminus);
82 }
83
84 RewriteResponse ArithRewriter::preRewriteTerm(TNode t){
85 if(t.isConst()){
86 return rewriteConstant(t);
87 }else if(t.isVar()){
88 return rewriteVariable(t);
89 }else{
90 switch(Kind k = t.getKind()){
91 case kind::MINUS:
92 return rewriteMinus(t, true);
93 case kind::UMINUS:
94 return rewriteUMinus(t, true);
95 case kind::DIVISION:
96 case kind::DIVISION_TOTAL:
97 return rewriteDiv(t,true);
98 case kind::PLUS:
99 return preRewritePlus(t);
100 case kind::MULT:
101 return preRewriteMult(t);
102 case kind::INTS_DIVISION:
103 case kind::INTS_MODULUS:
104 return RewriteResponse(REWRITE_DONE, t);
105 case kind::INTS_DIVISION_TOTAL:
106 case kind::INTS_MODULUS_TOTAL:
107 return rewriteIntsDivModTotal(t,true);
108 case kind::ABS:
109 if(t[0].isConst()) {
110 const Rational& rat = t[0].getConst<Rational>();
111 if(rat >= 0) {
112 return RewriteResponse(REWRITE_DONE, t[0]);
113 } else {
114 return RewriteResponse(REWRITE_DONE,
115 NodeManager::currentNM()->mkConst(-rat));
116 }
117 }
118 return RewriteResponse(REWRITE_DONE, t);
119 case kind::IS_INTEGER:
120 case kind::TO_INTEGER:
121 return RewriteResponse(REWRITE_DONE, t);
122 case kind::TO_REAL:
123 return RewriteResponse(REWRITE_DONE, t[0]);
124 case kind::POW:
125 return RewriteResponse(REWRITE_DONE, t);
126 default:
127 Unhandled(k);
128 }
129 }
130 }
131
132 RewriteResponse ArithRewriter::postRewriteTerm(TNode t){
133 if(t.isConst()){
134 return rewriteConstant(t);
135 }else if(t.isVar()){
136 return rewriteVariable(t);
137 }else{
138 switch(t.getKind()){
139 case kind::MINUS:
140 return rewriteMinus(t, false);
141 case kind::UMINUS:
142 return rewriteUMinus(t, false);
143 case kind::DIVISION:
144 case kind::DIVISION_TOTAL:
145 return rewriteDiv(t, false);
146 case kind::PLUS:
147 return postRewritePlus(t);
148 case kind::MULT:
149 return postRewriteMult(t);
150 case kind::INTS_DIVISION:
151 case kind::INTS_MODULUS:
152 return RewriteResponse(REWRITE_DONE, t);
153 case kind::INTS_DIVISION_TOTAL:
154 case kind::INTS_MODULUS_TOTAL:
155 return rewriteIntsDivModTotal(t, false);
156 case kind::ABS:
157 if(t[0].isConst()) {
158 const Rational& rat = t[0].getConst<Rational>();
159 if(rat >= 0) {
160 return RewriteResponse(REWRITE_DONE, t[0]);
161 } else {
162 return RewriteResponse(REWRITE_DONE,
163 NodeManager::currentNM()->mkConst(-rat));
164 }
165 }
166 case kind::TO_REAL:
167 return RewriteResponse(REWRITE_DONE, t[0]);
168 case kind::TO_INTEGER:
169 if(t[0].isConst()) {
170 return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(Rational(t[0].getConst<Rational>().floor())));
171 }
172 if(t[0].getType().isInteger()) {
173 return RewriteResponse(REWRITE_DONE, t[0]);
174 }
175 //Unimplemented("TO_INTEGER, nonconstant");
176 //return rewriteToInteger(t);
177 return RewriteResponse(REWRITE_DONE, t);
178 case kind::IS_INTEGER:
179 if(t[0].isConst()) {
180 return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(t[0].getConst<Rational>().getDenominator() == 1));
181 }
182 if(t[0].getType().isInteger()) {
183 return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(true));
184 }
185 //Unimplemented("IS_INTEGER, nonconstant");
186 //return rewriteIsInteger(t);
187 return RewriteResponse(REWRITE_DONE, t);
188 case kind::POW:
189 {
190 if(t[1].getKind() == kind::CONST_RATIONAL){
191 const Rational& exp = t[1].getConst<Rational>();
192 TNode base = t[0];
193 if(exp.sgn() == 0){
194 return RewriteResponse(REWRITE_DONE, mkRationalNode(Rational(1)));
195 }else if(exp.sgn() > 0 && exp.isIntegral()){
196 Integer num = exp.getNumerator();
197 NodeBuilder<> nb(kind::MULT);
198 Integer one(1);
199 for(Integer i(0); i < num; i = i + one){
200 nb << base;
201 }
202 Assert(nb.getNumChildren() > 0);
203 Node mult = nb;
204 return RewriteResponse(REWRITE_AGAIN, mult);
205 }
206 }
207
208 // Todo improve the exception thrown
209 std::stringstream ss;
210 ss << "The POW(^) operator can only be used with a natural number ";
211 ss << "in the exponent. Exception occured in:" << std::endl;
212 ss << " " << t;
213 throw Exception(ss.str());
214 }
215 default:
216 Unreachable();
217 }
218 }
219 }
220
221
222 RewriteResponse ArithRewriter::preRewriteMult(TNode t){
223 Assert(t.getKind()== kind::MULT);
224
225 if(t.getNumChildren() == 2){
226 if(t[0].getKind() == kind::CONST_RATIONAL
227 && t[0].getConst<Rational>().isOne()){
228 return RewriteResponse(REWRITE_DONE, t[1]);
229 }
230 if(t[1].getKind() == kind::CONST_RATIONAL
231 && t[1].getConst<Rational>().isOne()){
232 return RewriteResponse(REWRITE_DONE, t[0]);
233 }
234 }
235
236 // Rewrite multiplications with a 0 argument and to 0
237 for(TNode::iterator i = t.begin(); i != t.end(); ++i) {
238 if((*i).getKind() == kind::CONST_RATIONAL) {
239 if((*i).getConst<Rational>().isZero()) {
240 TNode zero = (*i);
241 return RewriteResponse(REWRITE_DONE, zero);
242 }
243 }
244 }
245 return RewriteResponse(REWRITE_DONE, t);
246 }
247
248 static bool canFlatten(Kind k, TNode t){
249 for(TNode::iterator i = t.begin(); i != t.end(); ++i) {
250 TNode child = *i;
251 if(child.getKind() == k){
252 return true;
253 }
254 }
255 return false;
256 }
257
258 static void flatten(std::vector<TNode>& pb, Kind k, TNode t){
259 if(t.getKind() == k){
260 for(TNode::iterator i = t.begin(); i != t.end(); ++i) {
261 TNode child = *i;
262 if(child.getKind() == k){
263 flatten(pb, k, child);
264 }else{
265 pb.push_back(child);
266 }
267 }
268 }else{
269 pb.push_back(t);
270 }
271 }
272
273 static Node flatten(Kind k, TNode t){
274 std::vector<TNode> pb;
275 flatten(pb, k, t);
276 Assert(pb.size() >= 2);
277 return NodeManager::currentNM()->mkNode(k, pb);
278 }
279
280 RewriteResponse ArithRewriter::preRewritePlus(TNode t){
281 Assert(t.getKind()== kind::PLUS);
282
283 if(canFlatten(kind::PLUS, t)){
284 return RewriteResponse(REWRITE_DONE, flatten(kind::PLUS, t));
285 }else{
286 return RewriteResponse(REWRITE_DONE, t);
287 }
288 }
289
290 RewriteResponse ArithRewriter::postRewritePlus(TNode t){
291 Assert(t.getKind()== kind::PLUS);
292
293 std::vector<Monomial> monomials;
294 std::vector<Polynomial> polynomials;
295
296 for(TNode::iterator i = t.begin(), end = t.end(); i != end; ++i){
297 TNode curr = *i;
298 if(Monomial::isMember(curr)){
299 monomials.push_back(Monomial::parseMonomial(curr));
300 }else{
301 polynomials.push_back(Polynomial::parsePolynomial(curr));
302 }
303 }
304
305 if(!monomials.empty()){
306 Monomial::sort(monomials);
307 Monomial::combineAdjacentMonomials(monomials);
308 polynomials.push_back(Polynomial::mkPolynomial(monomials));
309 }
310
311 Polynomial res = Polynomial::sumPolynomials(polynomials);
312
313 return RewriteResponse(REWRITE_DONE, res.getNode());
314 }
315
316 RewriteResponse ArithRewriter::postRewriteMult(TNode t){
317 Assert(t.getKind()== kind::MULT);
318
319 Polynomial res = Polynomial::mkOne();
320
321 for(TNode::iterator i = t.begin(), end = t.end(); i != end; ++i){
322 Node curr = *i;
323 Polynomial currPoly = Polynomial::parsePolynomial(curr);
324
325 res = res * currPoly;
326 }
327
328 return RewriteResponse(REWRITE_DONE, res.getNode());
329 }
330
331 RewriteResponse ArithRewriter::postRewriteAtom(TNode atom){
332 if(atom.getKind() == kind::IS_INTEGER) {
333 if(atom[0].isConst()) {
334 return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(atom[0].getConst<Rational>().isIntegral()));
335 }
336 if(atom[0].getType().isInteger()) {
337 return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(true));
338 }
339 // not supported, but this isn't the right place to complain
340 return RewriteResponse(REWRITE_DONE, atom);
341 } else if(atom.getKind() == kind::DIVISIBLE) {
342 if(atom[0].isConst()) {
343 return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(bool((atom[0].getConst<Rational>() / atom.getOperator().getConst<Divisible>().k).isIntegral())));
344 }
345 if(atom.getOperator().getConst<Divisible>().k.isOne()) {
346 return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(true));
347 }
348 return RewriteResponse(REWRITE_AGAIN, NodeManager::currentNM()->mkNode(kind::EQUAL, NodeManager::currentNM()->mkNode(kind::INTS_MODULUS_TOTAL, atom[0], NodeManager::currentNM()->mkConst(Rational(atom.getOperator().getConst<Divisible>().k))), NodeManager::currentNM()->mkConst(Rational(0))));
349 }
350
351 // left |><| right
352 TNode left = atom[0];
353 TNode right = atom[1];
354
355 Polynomial pleft = Polynomial::parsePolynomial(left);
356 Polynomial pright = Polynomial::parsePolynomial(right);
357
358 Debug("arith::rewriter") << "pleft " << pleft.getNode() << std::endl;
359 Debug("arith::rewriter") << "pright " << pright.getNode() << std::endl;
360
361 Comparison cmp = Comparison::mkComparison(atom.getKind(), pleft, pright);
362 Assert(cmp.isNormalForm());
363 return RewriteResponse(REWRITE_DONE, cmp.getNode());
364 }
365
366 RewriteResponse ArithRewriter::preRewriteAtom(TNode atom){
367 Assert(isAtom(atom));
368
369 NodeManager* currNM = NodeManager::currentNM();
370
371 if(atom.getKind() == kind::EQUAL) {
372 if(atom[0] == atom[1]) {
373 return RewriteResponse(REWRITE_DONE, currNM->mkConst(true));
374 }
375 }else if(atom.getKind() == kind::GT){
376 Node leq = currNM->mkNode(kind::LEQ, atom[0], atom[1]);
377 return RewriteResponse(REWRITE_DONE, currNM->mkNode(kind::NOT, leq));
378 }else if(atom.getKind() == kind::LT){
379 Node geq = currNM->mkNode(kind::GEQ, atom[0], atom[1]);
380 return RewriteResponse(REWRITE_DONE, currNM->mkNode(kind::NOT, geq));
381 }else if(atom.getKind() == kind::IS_INTEGER){
382 if(atom[0].getType().isInteger()){
383 return RewriteResponse(REWRITE_DONE, currNM->mkConst(true));
384 }
385 }else if(atom.getKind() == kind::DIVISIBLE){
386 if(atom.getOperator().getConst<Divisible>().k.isOne()){
387 return RewriteResponse(REWRITE_DONE, currNM->mkConst(true));
388 }
389 }
390
391 return RewriteResponse(REWRITE_DONE, atom);
392 }
393
394 RewriteResponse ArithRewriter::postRewrite(TNode t){
395 if(isTerm(t)){
396 RewriteResponse response = postRewriteTerm(t);
397 if(Debug.isOn("arith::rewriter") && response.status == REWRITE_DONE) {
398 Polynomial::parsePolynomial(response.node);
399 }
400 return response;
401 }else if(isAtom(t)){
402 RewriteResponse response = postRewriteAtom(t);
403 if(Debug.isOn("arith::rewriter") && response.status == REWRITE_DONE) {
404 Comparison::parseNormalForm(response.node);
405 }
406 return response;
407 }else{
408 Unreachable();
409 return RewriteResponse(REWRITE_DONE, Node::null());
410 }
411 }
412
413 RewriteResponse ArithRewriter::preRewrite(TNode t){
414 if(isTerm(t)){
415 return preRewriteTerm(t);
416 }else if(isAtom(t)){
417 return preRewriteAtom(t);
418 }else{
419 Unreachable();
420 return RewriteResponse(REWRITE_DONE, Node::null());
421 }
422 }
423
424 Node ArithRewriter::makeUnaryMinusNode(TNode n){
425 Rational qNegOne(-1);
426 return NodeManager::currentNM()->mkNode(kind::MULT, mkRationalNode(qNegOne),n);
427 }
428
429 Node ArithRewriter::makeSubtractionNode(TNode l, TNode r){
430 Node negR = makeUnaryMinusNode(r);
431 Node diff = NodeManager::currentNM()->mkNode(kind::PLUS, l, negR);
432
433 return diff;
434 }
435
436 RewriteResponse ArithRewriter::rewriteDiv(TNode t, bool pre){
437 Assert(t.getKind() == kind::DIVISION_TOTAL || t.getKind()== kind::DIVISION);
438
439
440 Node left = t[0];
441 Node right = t[1];
442 if(right.getKind() == kind::CONST_RATIONAL){
443 const Rational& den = right.getConst<Rational>();
444
445 if(den.isZero()){
446 if(t.getKind() == kind::DIVISION_TOTAL){
447 return RewriteResponse(REWRITE_DONE, mkRationalNode(0));
448 }else{
449 // This is unsupported, but this is not a good place to complain
450 return RewriteResponse(REWRITE_DONE, t);
451 }
452 }
453 Assert(den != Rational(0));
454
455 if(left.getKind() == kind::CONST_RATIONAL){
456 const Rational& num = left.getConst<Rational>();
457 Rational div = num / den;
458 Node result = mkRationalNode(div);
459 return RewriteResponse(REWRITE_DONE, result);
460 }
461
462 Rational div = den.inverse();
463
464 Node result = mkRationalNode(div);
465
466 Node mult = NodeManager::currentNM()->mkNode(kind::MULT,left,result);
467 if(pre){
468 return RewriteResponse(REWRITE_DONE, mult);
469 }else{
470 return RewriteResponse(REWRITE_AGAIN, mult);
471 }
472 }else{
473 return RewriteResponse(REWRITE_DONE, t);
474 }
475 }
476
477 RewriteResponse ArithRewriter::rewriteIntsDivModTotal(TNode t, bool pre){
478 Kind k = t.getKind();
479 // Assert(k == kind::INTS_MODULUS || k == kind::INTS_MODULUS_TOTAL ||
480 // k == kind::INTS_DIVISION || k == kind::INTS_DIVISION_TOTAL);
481
482 //Leaving the function as before (INTS_MODULUS can be handled),
483 // but restricting its use here
484 Assert(k == kind::INTS_MODULUS_TOTAL || k == kind::INTS_DIVISION_TOTAL);
485 TNode n = t[0], d = t[1];
486 bool dIsConstant = d.getKind() == kind::CONST_RATIONAL;
487 if(dIsConstant && d.getConst<Rational>().isZero()){
488 if(k == kind::INTS_MODULUS_TOTAL || k == kind::INTS_DIVISION_TOTAL){
489 return RewriteResponse(REWRITE_DONE, mkRationalNode(0));
490 }else{
491 // Do nothing for k == INTS_MODULUS
492 return RewriteResponse(REWRITE_DONE, t);
493 }
494 }else if(dIsConstant && d.getConst<Rational>().isOne()){
495 if(k == kind::INTS_MODULUS || k == kind::INTS_MODULUS_TOTAL){
496 return RewriteResponse(REWRITE_DONE, mkRationalNode(0));
497 }else{
498 Assert(k == kind::INTS_DIVISION || k == kind::INTS_DIVISION_TOTAL);
499 return RewriteResponse(REWRITE_AGAIN, n);
500 }
501 }else if(dIsConstant && d.getConst<Rational>().isNegativeOne()){
502 if(k == kind::INTS_MODULUS || k == kind::INTS_MODULUS_TOTAL){
503 return RewriteResponse(REWRITE_DONE, mkRationalNode(0));
504 }else{
505 Assert(k == kind::INTS_DIVISION || k == kind::INTS_DIVISION_TOTAL);
506 return RewriteResponse(REWRITE_AGAIN, NodeManager::currentNM()->mkNode(kind::UMINUS, n));
507 }
508 }else if(dIsConstant && n.getKind() == kind::CONST_RATIONAL){
509 Assert(d.getConst<Rational>().isIntegral());
510 Assert(n.getConst<Rational>().isIntegral());
511 Assert(!d.getConst<Rational>().isZero());
512 Integer di = d.getConst<Rational>().getNumerator();
513 Integer ni = n.getConst<Rational>().getNumerator();
514
515 bool isDiv = (k == kind::INTS_DIVISION || k == kind::INTS_DIVISION_TOTAL);
516
517 Integer result = isDiv ? ni.euclidianDivideQuotient(di) : ni.euclidianDivideRemainder(di);
518
519 Node resultNode = mkRationalNode(Rational(result));
520 return RewriteResponse(REWRITE_DONE, resultNode);
521 }else{
522 return RewriteResponse(REWRITE_DONE, t);
523 }
524 }
525
526 }/* CVC4::theory::arith namespace */
527 }/* CVC4::theory namespace */
528 }/* CVC4 namespace */