Merge pull request #18 from timothy-king/master
[cvc5.git] / src / theory / arith / arith_rewriter.cpp
1 /********************* */
2 /*! \file arith_rewriter.cpp
3 ** \verbatim
4 ** Original author: Tim King
5 ** Major contributors: Morgan Deters
6 ** Minor contributors (to current version): Dejan Jovanovic
7 ** This file is part of the CVC4 project.
8 ** Copyright (c) 2009-2013 New York University and The University of Iowa
9 ** See the file COPYING in the top-level source directory for licensing
10 ** information.\endverbatim
11 **
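** \brief Rewriter for the theory of arithmetic.
**
** Pre- and post-rewrites arithmetic terms and atoms (relations, IS_INTEGER
** and DIVISIBLE) toward the Polynomial / Comparison normal forms (see
** theory/arith/normal_form.h).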
15 ** \todo document this file
16 **/
17
18 #include "theory/theory.h"
19 #include "theory/arith/normal_form.h"
20 #include "theory/arith/arith_rewriter.h"
21 #include "theory/arith/arith_utilities.h"
22
23 #include <vector>
24 #include <set>
25 #include <stack>
26
27 namespace CVC4 {
28 namespace theory {
29 namespace arith {
30
31 bool ArithRewriter::isAtom(TNode n) {
32 Kind k = n.getKind();
33 return arith::isRelationOperator(k) || k == kind::IS_INTEGER || k == kind::DIVISIBLE;
34 }
35
36 RewriteResponse ArithRewriter::rewriteConstant(TNode t){
37 Assert(t.isConst());
38 Assert(t.getKind() == kind::CONST_RATIONAL);
39
40 return RewriteResponse(REWRITE_DONE, t);
41 }
42
43 RewriteResponse ArithRewriter::rewriteVariable(TNode t){
44 Assert(t.isVar());
45
46 return RewriteResponse(REWRITE_DONE, t);
47 }
48
49 RewriteResponse ArithRewriter::rewriteMinus(TNode t, bool pre){
50 Assert(t.getKind()== kind::MINUS);
51
52 if(pre){
53 if(t[0] == t[1]){
54 Rational zero(0);
55 Node zeroNode = mkRationalNode(zero);
56 return RewriteResponse(REWRITE_DONE, zeroNode);
57 }else{
58 Node noMinus = makeSubtractionNode(t[0],t[1]);
59 return RewriteResponse(REWRITE_DONE, noMinus);
60 }
61 }else{
62 Polynomial minuend = Polynomial::parsePolynomial(t[0]);
63 Polynomial subtrahend = Polynomial::parsePolynomial(t[1]);
64 Polynomial diff = minuend - subtrahend;
65 return RewriteResponse(REWRITE_DONE, diff.getNode());
66 }
67 }
68
69 RewriteResponse ArithRewriter::rewriteUMinus(TNode t, bool pre){
70 Assert(t.getKind()== kind::UMINUS);
71
72 if(t[0].getKind() == kind::CONST_RATIONAL){
73 Rational neg = -(t[0].getConst<Rational>());
74 return RewriteResponse(REWRITE_DONE, mkRationalNode(neg));
75 }
76
77 Node noUminus = makeUnaryMinusNode(t[0]);
78 if(pre)
79 return RewriteResponse(REWRITE_DONE, noUminus);
80 else
81 return RewriteResponse(REWRITE_AGAIN, noUminus);
82 }
83
84 RewriteResponse ArithRewriter::preRewriteTerm(TNode t){
85 if(t.isConst()){
86 return rewriteConstant(t);
87 }else if(t.isVar()){
88 return rewriteVariable(t);
89 }else{
90 switch(Kind k = t.getKind()){
91 case kind::MINUS:
92 return rewriteMinus(t, true);
93 case kind::UMINUS:
94 return rewriteUMinus(t, true);
95 case kind::DIVISION:
96 case kind::DIVISION_TOTAL:
97 return rewriteDiv(t,true);
98 case kind::PLUS:
99 return preRewritePlus(t);
100 case kind::MULT:
101 return preRewriteMult(t);
102 case kind::INTS_DIVISION:
103 case kind::INTS_MODULUS:
104 return RewriteResponse(REWRITE_DONE, t);
105 case kind::INTS_DIVISION_TOTAL:
106 case kind::INTS_MODULUS_TOTAL:
107 return rewriteIntsDivModTotal(t,true);
108 case kind::ABS:
109 if(t[0].isConst()) {
110 const Rational& rat = t[0].getConst<Rational>();
111 if(rat >= 0) {
112 return RewriteResponse(REWRITE_DONE, t[0]);
113 } else {
114 return RewriteResponse(REWRITE_DONE,
115 NodeManager::currentNM()->mkConst(-rat));
116 }
117 }
118 return RewriteResponse(REWRITE_DONE, t);
119 case kind::IS_INTEGER:
120 case kind::TO_INTEGER:
121 return RewriteResponse(REWRITE_DONE, t);
122 case kind::TO_REAL:
123 return RewriteResponse(REWRITE_DONE, t[0]);
124 case kind::POW:
125 return RewriteResponse(REWRITE_DONE, t);
126 default:
127 Unhandled(k);
128 }
129 }
130 }
131
132 RewriteResponse ArithRewriter::postRewriteTerm(TNode t){
133 if(t.isConst()){
134 return rewriteConstant(t);
135 }else if(t.isVar()){
136 return rewriteVariable(t);
137 }else{
138 switch(t.getKind()){
139 case kind::MINUS:
140 return rewriteMinus(t, false);
141 case kind::UMINUS:
142 return rewriteUMinus(t, false);
143 case kind::DIVISION:
144 case kind::DIVISION_TOTAL:
145 return rewriteDiv(t, false);
146 case kind::PLUS:
147 return postRewritePlus(t);
148 case kind::MULT:
149 return postRewriteMult(t);
150 case kind::INTS_DIVISION:
151 case kind::INTS_MODULUS:
152 return RewriteResponse(REWRITE_DONE, t);
153 case kind::INTS_DIVISION_TOTAL:
154 case kind::INTS_MODULUS_TOTAL:
155 return rewriteIntsDivModTotal(t, false);
156 case kind::ABS:
157 if(t[0].isConst()) {
158 const Rational& rat = t[0].getConst<Rational>();
159 if(rat >= 0) {
160 return RewriteResponse(REWRITE_DONE, t[0]);
161 } else {
162 return RewriteResponse(REWRITE_DONE,
163 NodeManager::currentNM()->mkConst(-rat));
164 }
165 }
166 case kind::TO_REAL:
167 return RewriteResponse(REWRITE_DONE, t[0]);
168 case kind::TO_INTEGER:
169 if(t[0].isConst()) {
170 return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(Rational(t[0].getConst<Rational>().floor())));
171 }
172 if(t[0].getType().isInteger()) {
173 return RewriteResponse(REWRITE_DONE, t[0]);
174 }
175 //Unimplemented("TO_INTEGER, nonconstant");
176 //return rewriteToInteger(t);
177 return RewriteResponse(REWRITE_DONE, t);
178 case kind::IS_INTEGER:
179 if(t[0].isConst()) {
180 return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(t[0].getConst<Rational>().getDenominator() == 1));
181 }
182 if(t[0].getType().isInteger()) {
183 return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(true));
184 }
185 //Unimplemented("IS_INTEGER, nonconstant");
186 //return rewriteIsInteger(t);
187 return RewriteResponse(REWRITE_DONE, t);
188 case kind::POW:
189 {
190 if(t[1].getKind() == kind::CONST_RATIONAL){
191 const Rational& exp = t[1].getConst<Rational>();
192 TNode base = t[0];
193 if(exp.sgn() == 0){
194 return RewriteResponse(REWRITE_DONE, mkRationalNode(Rational(1)));
195 }else if(exp.sgn() > 0 && exp.isIntegral()){
196 Integer num = exp.getNumerator();
197 NodeBuilder<> nb(kind::MULT);
198 Integer one(1);
199 for(Integer i(0); i < num; i = i + one){
200 nb << base;
201 }
202 Assert(nb.getNumChildren() > 0);
203 Node mult = nb;
204 return RewriteResponse(REWRITE_AGAIN, mult);
205 }
206 }
207
208 // Todo improve the exception thrown
209 std::stringstream ss;
210 ss << "The POW(^) operator can only be used with a natural number ";
211 ss << "in the exponent. Exception occured in:" << std::endl;
212 ss << " " << t;
213 throw Exception(ss.str());
214 }
215 default:
216 Unreachable();
217 }
218 }
219 }
220
221
222 RewriteResponse ArithRewriter::preRewriteMult(TNode t){
223 Assert(t.getKind()== kind::MULT);
224
225 if(t.getNumChildren() == 2){
226 if(t[0].getKind() == kind::CONST_RATIONAL
227 && t[0].getConst<Rational>().isOne()){
228 return RewriteResponse(REWRITE_DONE, t[1]);
229 }
230 if(t[1].getKind() == kind::CONST_RATIONAL
231 && t[1].getConst<Rational>().isOne()){
232 return RewriteResponse(REWRITE_DONE, t[0]);
233 }
234 }
235
236 // Rewrite multiplications with a 0 argument and to 0
237 for(TNode::iterator i = t.begin(); i != t.end(); ++i) {
238 if((*i).getKind() == kind::CONST_RATIONAL) {
239 if((*i).getConst<Rational>().isZero()) {
240 TNode zero = (*i);
241 return RewriteResponse(REWRITE_DONE, zero);
242 }
243 }
244 }
245 return RewriteResponse(REWRITE_DONE, t);
246 }
247
248 static bool canFlatten(Kind k, TNode t){
249 for(TNode::iterator i = t.begin(); i != t.end(); ++i) {
250 TNode child = *i;
251 if(child.getKind() == k){
252 return true;
253 }
254 }
255 return false;
256 }
257
258 static void flatten(std::vector<TNode>& pb, Kind k, TNode t){
259 if(t.getKind() == k){
260 for(TNode::iterator i = t.begin(); i != t.end(); ++i) {
261 TNode child = *i;
262 if(child.getKind() == k){
263 flatten(pb, k, child);
264 }else{
265 pb.push_back(child);
266 }
267 }
268 }else{
269 pb.push_back(t);
270 }
271 }
272
273 static Node flatten(Kind k, TNode t){
274 std::vector<TNode> pb;
275 flatten(pb, k, t);
276 Assert(pb.size() >= 2);
277 return NodeManager::currentNM()->mkNode(k, pb);
278 }
279
280 RewriteResponse ArithRewriter::preRewritePlus(TNode t){
281 Assert(t.getKind()== kind::PLUS);
282
283 if(canFlatten(kind::PLUS, t)){
284 return RewriteResponse(REWRITE_DONE, flatten(kind::PLUS, t));
285 }else{
286 return RewriteResponse(REWRITE_DONE, t);
287 }
288 }
289
290 RewriteResponse ArithRewriter::postRewritePlus(TNode t){
291 Assert(t.getKind()== kind::PLUS);
292
293 std::vector<Monomial> monomials;
294 std::vector<Polynomial> polynomials;
295
296 for(TNode::iterator i = t.begin(), end = t.end(); i != end; ++i){
297 TNode curr = *i;
298 if(Monomial::isMember(curr)){
299 monomials.push_back(Monomial::parseMonomial(curr));
300 }else{
301 polynomials.push_back(Polynomial::parsePolynomial(curr));
302 }
303 }
304
305 if(!monomials.empty()){
306 Monomial::sort(monomials);
307 Monomial::combineAdjacentMonomials(monomials);
308 polynomials.push_back(Polynomial::mkPolynomial(monomials));
309 }
310
311 Polynomial res = Polynomial::sumPolynomials(polynomials);
312
313 return RewriteResponse(REWRITE_DONE, res.getNode());
314 }
315
316 RewriteResponse ArithRewriter::postRewriteMult(TNode t){
317 Assert(t.getKind()== kind::MULT);
318
319 Polynomial res = Polynomial::mkOne();
320
321 for(TNode::iterator i = t.begin(), end = t.end(); i != end; ++i){
322 Node curr = *i;
323 Polynomial currPoly = Polynomial::parsePolynomial(curr);
324
325 res = res * currPoly;
326 }
327
328 return RewriteResponse(REWRITE_DONE, res.getNode());
329 }
330
331 RewriteResponse ArithRewriter::postRewriteAtom(TNode atom){
332 if(atom.getKind() == kind::IS_INTEGER) {
333 if(atom[0].isConst()) {
334 return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(atom[0].getConst<Rational>().isIntegral()));
335 }
336 if(atom[0].getType().isInteger()) {
337 return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(true));
338 }
339 // not supported, but this isn't the right place to complain
340 return RewriteResponse(REWRITE_DONE, atom);
341 } else if(atom.getKind() == kind::DIVISIBLE) {
342 if(atom[0].isConst()) {
343 return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(bool((atom[0].getConst<Rational>() / atom.getOperator().getConst<Divisible>().k).isIntegral())));
344 }
345 if(atom.getOperator().getConst<Divisible>().k.isOne()) {
346 return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(true));
347 }
348 return RewriteResponse(REWRITE_AGAIN, NodeManager::currentNM()->mkNode(kind::EQUAL, NodeManager::currentNM()->mkNode(kind::INTS_MODULUS_TOTAL, atom[0], NodeManager::currentNM()->mkConst(Rational(atom.getOperator().getConst<Divisible>().k))), NodeManager::currentNM()->mkConst(Rational(0))));
349 }
350
351 // left |><| right
352 TNode left = atom[0];
353 TNode right = atom[1];
354
355 Polynomial pleft = Polynomial::parsePolynomial(left);
356 Polynomial pright = Polynomial::parsePolynomial(right);
357
358 Comparison cmp = Comparison::mkComparison(atom.getKind(), pleft, pright);
359 Assert(cmp.isNormalForm());
360 return RewriteResponse(REWRITE_DONE, cmp.getNode());
361 }
362
363 RewriteResponse ArithRewriter::preRewriteAtom(TNode atom){
364 Assert(isAtom(atom));
365
366 NodeManager* currNM = NodeManager::currentNM();
367
368 if(atom.getKind() == kind::EQUAL) {
369 if(atom[0] == atom[1]) {
370 return RewriteResponse(REWRITE_DONE, currNM->mkConst(true));
371 }
372 }else if(atom.getKind() == kind::GT){
373 Node leq = currNM->mkNode(kind::LEQ, atom[0], atom[1]);
374 return RewriteResponse(REWRITE_DONE, currNM->mkNode(kind::NOT, leq));
375 }else if(atom.getKind() == kind::LT){
376 Node geq = currNM->mkNode(kind::GEQ, atom[0], atom[1]);
377 return RewriteResponse(REWRITE_DONE, currNM->mkNode(kind::NOT, geq));
378 }else if(atom.getKind() == kind::IS_INTEGER){
379 if(atom[0].getType().isInteger()){
380 return RewriteResponse(REWRITE_DONE, currNM->mkConst(true));
381 }
382 }else if(atom.getKind() == kind::DIVISIBLE){
383 if(atom.getOperator().getConst<Divisible>().k.isOne()){
384 return RewriteResponse(REWRITE_DONE, currNM->mkConst(true));
385 }
386 }
387
388 return RewriteResponse(REWRITE_DONE, atom);
389 }
390
391 RewriteResponse ArithRewriter::postRewrite(TNode t){
392 if(isTerm(t)){
393 RewriteResponse response = postRewriteTerm(t);
394 if(Debug.isOn("arith::rewriter") && response.status == REWRITE_DONE) {
395 Polynomial::parsePolynomial(response.node);
396 }
397 return response;
398 }else if(isAtom(t)){
399 RewriteResponse response = postRewriteAtom(t);
400 if(Debug.isOn("arith::rewriter") && response.status == REWRITE_DONE) {
401 Comparison::parseNormalForm(response.node);
402 }
403 return response;
404 }else{
405 Unreachable();
406 return RewriteResponse(REWRITE_DONE, Node::null());
407 }
408 }
409
410 RewriteResponse ArithRewriter::preRewrite(TNode t){
411 if(isTerm(t)){
412 return preRewriteTerm(t);
413 }else if(isAtom(t)){
414 return preRewriteAtom(t);
415 }else{
416 Unreachable();
417 return RewriteResponse(REWRITE_DONE, Node::null());
418 }
419 }
420
421 Node ArithRewriter::makeUnaryMinusNode(TNode n){
422 Rational qNegOne(-1);
423 return NodeManager::currentNM()->mkNode(kind::MULT, mkRationalNode(qNegOne),n);
424 }
425
426 Node ArithRewriter::makeSubtractionNode(TNode l, TNode r){
427 Node negR = makeUnaryMinusNode(r);
428 Node diff = NodeManager::currentNM()->mkNode(kind::PLUS, l, negR);
429
430 return diff;
431 }
432
433 RewriteResponse ArithRewriter::rewriteDiv(TNode t, bool pre){
434 Assert(t.getKind() == kind::DIVISION_TOTAL || t.getKind()== kind::DIVISION);
435
436
437 Node left = t[0];
438 Node right = t[1];
439 if(right.getKind() == kind::CONST_RATIONAL){
440 const Rational& den = right.getConst<Rational>();
441
442 if(den.isZero()){
443 if(t.getKind() == kind::DIVISION_TOTAL){
444 return RewriteResponse(REWRITE_DONE, mkRationalNode(0));
445 }else{
446 // This is unsupported, but this is not a good place to complain
447 return RewriteResponse(REWRITE_DONE, t);
448 }
449 }
450 Assert(den != Rational(0));
451
452 if(left.getKind() == kind::CONST_RATIONAL){
453 const Rational& num = left.getConst<Rational>();
454 Rational div = num / den;
455 Node result = mkRationalNode(div);
456 return RewriteResponse(REWRITE_DONE, result);
457 }
458
459 Rational div = den.inverse();
460
461 Node result = mkRationalNode(div);
462
463 Node mult = NodeManager::currentNM()->mkNode(kind::MULT,left,result);
464 if(pre){
465 return RewriteResponse(REWRITE_DONE, mult);
466 }else{
467 return RewriteResponse(REWRITE_AGAIN, mult);
468 }
469 }else{
470 return RewriteResponse(REWRITE_DONE, t);
471 }
472 }
473
474 RewriteResponse ArithRewriter::rewriteIntsDivModTotal(TNode t, bool pre){
475 Kind k = t.getKind();
476 // Assert(k == kind::INTS_MODULUS || k == kind::INTS_MODULUS_TOTAL ||
477 // k == kind::INTS_DIVISION || k == kind::INTS_DIVISION_TOTAL);
478
479 //Leaving the function as before (INTS_MODULUS can be handled),
480 // but restricting its use here
481 Assert(k == kind::INTS_MODULUS_TOTAL || k == kind::INTS_DIVISION_TOTAL);
482 TNode n = t[0], d = t[1];
483 bool dIsConstant = d.getKind() == kind::CONST_RATIONAL;
484 if(dIsConstant && d.getConst<Rational>().isZero()){
485 if(k == kind::INTS_MODULUS_TOTAL || k == kind::INTS_DIVISION_TOTAL){
486 return RewriteResponse(REWRITE_DONE, mkRationalNode(0));
487 }else{
488 // Do nothing for k == INTS_MODULUS
489 return RewriteResponse(REWRITE_DONE, t);
490 }
491 }else if(dIsConstant && d.getConst<Rational>().isOne()){
492 if(k == kind::INTS_MODULUS || k == kind::INTS_MODULUS_TOTAL){
493 return RewriteResponse(REWRITE_DONE, mkRationalNode(0));
494 }else{
495 Assert(k == kind::INTS_DIVISION || k == kind::INTS_DIVISION_TOTAL);
496 return RewriteResponse(REWRITE_AGAIN, n);
497 }
498 }else if(dIsConstant && d.getConst<Rational>().isNegativeOne()){
499 if(k == kind::INTS_MODULUS || k == kind::INTS_MODULUS_TOTAL){
500 return RewriteResponse(REWRITE_DONE, mkRationalNode(0));
501 }else{
502 Assert(k == kind::INTS_DIVISION || k == kind::INTS_DIVISION_TOTAL);
503 return RewriteResponse(REWRITE_AGAIN, NodeManager::currentNM()->mkNode(kind::UMINUS, n));
504 }
505 }else if(dIsConstant && n.getKind() == kind::CONST_RATIONAL){
506 Assert(d.getConst<Rational>().isIntegral());
507 Assert(n.getConst<Rational>().isIntegral());
508 Assert(!d.getConst<Rational>().isZero());
509 Integer di = d.getConst<Rational>().getNumerator();
510 Integer ni = n.getConst<Rational>().getNumerator();
511
512 bool isDiv = (k == kind::INTS_DIVISION || k == kind::INTS_DIVISION_TOTAL);
513
514 Integer result = isDiv ? ni.euclidianDivideQuotient(di) : ni.euclidianDivideRemainder(di);
515
516 Node resultNode = mkRationalNode(Rational(result));
517 return RewriteResponse(REWRITE_DONE, resultNode);
518 }else{
519 return RewriteResponse(REWRITE_DONE, t);
520 }
521 }
522
523 }/* CVC4::theory::arith namespace */
524 }/* CVC4::theory namespace */
525 }/* CVC4 namespace */