Adding a model based axiom instantiation scheme for multiplication. Merge commit...
[cvc5.git] / src / theory / arith / arith_rewriter.cpp
1 /********************* */
2 /*! \file arith_rewriter.cpp
3 ** \verbatim
4 ** Top contributors (to current version):
5 ** Tim King, Morgan Deters, Dejan Jovanovic
6 ** This file is part of the CVC4 project.
7 ** Copyright (c) 2009-2016 by the authors listed in the file AUTHORS
8 ** in the top-level source directory) and their institutional affiliations.
9 ** All rights reserved. See the file COPYING in the top-level source
10 ** directory for licensing information.\endverbatim
11 **
12 ** \brief [[ Add one-line brief description here ]]
13 **
14 ** [[ Add lengthier description here ]]
15 ** \todo document this file
16 **/
17
18 #include <set>
19 #include <stack>
20 #include <vector>
21
22 #include "smt/logic_exception.h"
23 #include "theory/arith/arith_rewriter.h"
24 #include "theory/arith/arith_utilities.h"
25 #include "theory/arith/normal_form.h"
26 #include "theory/theory.h"
27
28 namespace CVC4 {
29 namespace theory {
30 namespace arith {
31
32 bool ArithRewriter::isAtom(TNode n) {
33 Kind k = n.getKind();
34 return arith::isRelationOperator(k) || k == kind::IS_INTEGER
35 || k == kind::DIVISIBLE;
36 }
37
38 RewriteResponse ArithRewriter::rewriteConstant(TNode t){
39 Assert(t.isConst());
40 Assert(t.getKind() == kind::CONST_RATIONAL);
41
42 return RewriteResponse(REWRITE_DONE, t);
43 }
44
45 RewriteResponse ArithRewriter::rewriteVariable(TNode t){
46 Assert(t.isVar());
47
48 return RewriteResponse(REWRITE_DONE, t);
49 }
50
51 RewriteResponse ArithRewriter::rewriteMinus(TNode t, bool pre){
52 Assert(t.getKind()== kind::MINUS);
53
54 if(pre){
55 if(t[0] == t[1]){
56 Rational zero(0);
57 Node zeroNode = mkRationalNode(zero);
58 return RewriteResponse(REWRITE_DONE, zeroNode);
59 }else{
60 Node noMinus = makeSubtractionNode(t[0],t[1]);
61 return RewriteResponse(REWRITE_DONE, noMinus);
62 }
63 }else{
64 Polynomial minuend = Polynomial::parsePolynomial(t[0]);
65 Polynomial subtrahend = Polynomial::parsePolynomial(t[1]);
66 Polynomial diff = minuend - subtrahend;
67 return RewriteResponse(REWRITE_DONE, diff.getNode());
68 }
69 }
70
71 RewriteResponse ArithRewriter::rewriteUMinus(TNode t, bool pre){
72 Assert(t.getKind()== kind::UMINUS);
73
74 if(t[0].getKind() == kind::CONST_RATIONAL){
75 Rational neg = -(t[0].getConst<Rational>());
76 return RewriteResponse(REWRITE_DONE, mkRationalNode(neg));
77 }
78
79 Node noUminus = makeUnaryMinusNode(t[0]);
80 if(pre)
81 return RewriteResponse(REWRITE_DONE, noUminus);
82 else
83 return RewriteResponse(REWRITE_AGAIN, noUminus);
84 }
85
86 RewriteResponse ArithRewriter::preRewriteTerm(TNode t){
87 if(t.isConst()){
88 return rewriteConstant(t);
89 }else if(t.isVar()){
90 return rewriteVariable(t);
91 }else{
92 switch(Kind k = t.getKind()){
93 case kind::MINUS:
94 return rewriteMinus(t, true);
95 case kind::UMINUS:
96 return rewriteUMinus(t, true);
97 case kind::DIVISION:
98 case kind::DIVISION_TOTAL:
99 return rewriteDiv(t,true);
100 case kind::PLUS:
101 return preRewritePlus(t);
102 case kind::MULT:
103 case kind::NONLINEAR_MULT:
104 return preRewriteMult(t);
105 case kind::INTS_DIVISION:
106 case kind::INTS_MODULUS:
107 return RewriteResponse(REWRITE_DONE, t);
108 case kind::INTS_DIVISION_TOTAL:
109 case kind::INTS_MODULUS_TOTAL:
110 return rewriteIntsDivModTotal(t,true);
111 case kind::ABS:
112 if(t[0].isConst()) {
113 const Rational& rat = t[0].getConst<Rational>();
114 if(rat >= 0) {
115 return RewriteResponse(REWRITE_DONE, t[0]);
116 } else {
117 return RewriteResponse(REWRITE_DONE,
118 NodeManager::currentNM()->mkConst(-rat));
119 }
120 }
121 return RewriteResponse(REWRITE_DONE, t);
122 case kind::IS_INTEGER:
123 case kind::TO_INTEGER:
124 return RewriteResponse(REWRITE_DONE, t);
125 case kind::TO_REAL:
126 return RewriteResponse(REWRITE_DONE, t[0]);
127 case kind::POW:
128 return RewriteResponse(REWRITE_DONE, t);
129 default:
130 Unhandled(k);
131 }
132 }
133 }
134
135 RewriteResponse ArithRewriter::postRewriteTerm(TNode t){
136 if(t.isConst()){
137 return rewriteConstant(t);
138 }else if(t.isVar()){
139 return rewriteVariable(t);
140 }else{
141 switch(t.getKind()){
142 case kind::MINUS:
143 return rewriteMinus(t, false);
144 case kind::UMINUS:
145 return rewriteUMinus(t, false);
146 case kind::DIVISION:
147 case kind::DIVISION_TOTAL:
148 return rewriteDiv(t, false);
149 case kind::PLUS:
150 return postRewritePlus(t);
151 case kind::MULT:
152 case kind::NONLINEAR_MULT:
153 return postRewriteMult(t);
154 case kind::INTS_DIVISION:
155 case kind::INTS_MODULUS:
156 return RewriteResponse(REWRITE_DONE, t);
157 case kind::INTS_DIVISION_TOTAL:
158 case kind::INTS_MODULUS_TOTAL:
159 return rewriteIntsDivModTotal(t, false);
160 case kind::ABS:
161 if(t[0].isConst()) {
162 const Rational& rat = t[0].getConst<Rational>();
163 if(rat >= 0) {
164 return RewriteResponse(REWRITE_DONE, t[0]);
165 } else {
166 return RewriteResponse(REWRITE_DONE,
167 NodeManager::currentNM()->mkConst(-rat));
168 }
169 }
170 case kind::TO_REAL:
171 return RewriteResponse(REWRITE_DONE, t[0]);
172 case kind::TO_INTEGER:
173 if(t[0].isConst()) {
174 return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(Rational(t[0].getConst<Rational>().floor())));
175 }
176 if(t[0].getType().isInteger()) {
177 return RewriteResponse(REWRITE_DONE, t[0]);
178 }
179 //Unimplemented("TO_INTEGER, nonconstant");
180 //return rewriteToInteger(t);
181 return RewriteResponse(REWRITE_DONE, t);
182 case kind::IS_INTEGER:
183 if(t[0].isConst()) {
184 return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(t[0].getConst<Rational>().getDenominator() == 1));
185 }
186 if(t[0].getType().isInteger()) {
187 return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(true));
188 }
189 //Unimplemented("IS_INTEGER, nonconstant");
190 //return rewriteIsInteger(t);
191 return RewriteResponse(REWRITE_DONE, t);
192 case kind::POW:
193 {
194 if(t[1].getKind() == kind::CONST_RATIONAL){
195 const Rational& exp = t[1].getConst<Rational>();
196 TNode base = t[0];
197 if(exp.sgn() == 0){
198 return RewriteResponse(REWRITE_DONE, mkRationalNode(Rational(1)));
199 }else if(exp.sgn() > 0 && exp.isIntegral()){
200 Integer num = exp.getNumerator();
201 NodeBuilder<> nb(kind::MULT);
202 Integer one(1);
203 for(Integer i(0); i < num; i = i + one){
204 nb << base;
205 }
206 Assert(nb.getNumChildren() > 0);
207 Node mult = nb;
208 return RewriteResponse(REWRITE_AGAIN, mult);
209 }
210 }
211
212 // Todo improve the exception thrown
213 std::stringstream ss;
214 ss << "The POW(^) operator can only be used with a natural number ";
215 ss << "in the exponent. Exception occured in:" << std::endl;
216 ss << " " << t;
217 throw LogicException(ss.str());
218 }
219 default:
220 Unreachable();
221 }
222 }
223 }
224
225
226 RewriteResponse ArithRewriter::preRewriteMult(TNode t){
227 Assert(t.getKind()== kind::MULT || t.getKind()== kind::NONLINEAR_MULT);
228
229 if(t.getNumChildren() == 2){
230 if(t[0].getKind() == kind::CONST_RATIONAL
231 && t[0].getConst<Rational>().isOne()){
232 return RewriteResponse(REWRITE_DONE, t[1]);
233 }
234 if(t[1].getKind() == kind::CONST_RATIONAL
235 && t[1].getConst<Rational>().isOne()){
236 return RewriteResponse(REWRITE_DONE, t[0]);
237 }
238 }
239
240 // Rewrite multiplications with a 0 argument and to 0
241 for(TNode::iterator i = t.begin(); i != t.end(); ++i) {
242 if((*i).getKind() == kind::CONST_RATIONAL) {
243 if((*i).getConst<Rational>().isZero()) {
244 TNode zero = (*i);
245 return RewriteResponse(REWRITE_DONE, zero);
246 }
247 }
248 }
249 return RewriteResponse(REWRITE_DONE, t);
250 }
251
252 static bool canFlatten(Kind k, TNode t){
253 for(TNode::iterator i = t.begin(); i != t.end(); ++i) {
254 TNode child = *i;
255 if(child.getKind() == k){
256 return true;
257 }
258 }
259 return false;
260 }
261
262 static void flatten(std::vector<TNode>& pb, Kind k, TNode t){
263 if(t.getKind() == k){
264 for(TNode::iterator i = t.begin(); i != t.end(); ++i) {
265 TNode child = *i;
266 if(child.getKind() == k){
267 flatten(pb, k, child);
268 }else{
269 pb.push_back(child);
270 }
271 }
272 }else{
273 pb.push_back(t);
274 }
275 }
276
277 static Node flatten(Kind k, TNode t){
278 std::vector<TNode> pb;
279 flatten(pb, k, t);
280 Assert(pb.size() >= 2);
281 return NodeManager::currentNM()->mkNode(k, pb);
282 }
283
284 RewriteResponse ArithRewriter::preRewritePlus(TNode t){
285 Assert(t.getKind()== kind::PLUS);
286
287 if(canFlatten(kind::PLUS, t)){
288 return RewriteResponse(REWRITE_DONE, flatten(kind::PLUS, t));
289 }else{
290 return RewriteResponse(REWRITE_DONE, t);
291 }
292 }
293
294 RewriteResponse ArithRewriter::postRewritePlus(TNode t){
295 Assert(t.getKind()== kind::PLUS);
296
297 std::vector<Monomial> monomials;
298 std::vector<Polynomial> polynomials;
299
300 for(TNode::iterator i = t.begin(), end = t.end(); i != end; ++i){
301 TNode curr = *i;
302 if(Monomial::isMember(curr)){
303 monomials.push_back(Monomial::parseMonomial(curr));
304 }else{
305 polynomials.push_back(Polynomial::parsePolynomial(curr));
306 }
307 }
308
309 if(!monomials.empty()){
310 Monomial::sort(monomials);
311 Monomial::combineAdjacentMonomials(monomials);
312 polynomials.push_back(Polynomial::mkPolynomial(monomials));
313 }
314
315 Polynomial res = Polynomial::sumPolynomials(polynomials);
316
317 return RewriteResponse(REWRITE_DONE, res.getNode());
318 }
319
320 RewriteResponse ArithRewriter::postRewriteMult(TNode t){
321 Assert(t.getKind()== kind::MULT || t.getKind()==kind::NONLINEAR_MULT);
322
323 Polynomial res = Polynomial::mkOne();
324
325 for(TNode::iterator i = t.begin(), end = t.end(); i != end; ++i){
326 Node curr = *i;
327 Polynomial currPoly = Polynomial::parsePolynomial(curr);
328
329 res = res * currPoly;
330 }
331
332 return RewriteResponse(REWRITE_DONE, res.getNode());
333 }
334
335 RewriteResponse ArithRewriter::postRewriteAtom(TNode atom){
336 if(atom.getKind() == kind::IS_INTEGER) {
337 if(atom[0].isConst()) {
338 return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(atom[0].getConst<Rational>().isIntegral()));
339 }
340 if(atom[0].getType().isInteger()) {
341 return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(true));
342 }
343 // not supported, but this isn't the right place to complain
344 return RewriteResponse(REWRITE_DONE, atom);
345 } else if(atom.getKind() == kind::DIVISIBLE) {
346 if(atom[0].isConst()) {
347 return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(bool((atom[0].getConst<Rational>() / atom.getOperator().getConst<Divisible>().k).isIntegral())));
348 }
349 if(atom.getOperator().getConst<Divisible>().k.isOne()) {
350 return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(true));
351 }
352 return RewriteResponse(REWRITE_AGAIN, NodeManager::currentNM()->mkNode(kind::EQUAL, NodeManager::currentNM()->mkNode(kind::INTS_MODULUS_TOTAL, atom[0], NodeManager::currentNM()->mkConst(Rational(atom.getOperator().getConst<Divisible>().k))), NodeManager::currentNM()->mkConst(Rational(0))));
353 }
354
355 // left |><| right
356 TNode left = atom[0];
357 TNode right = atom[1];
358
359 Polynomial pleft = Polynomial::parsePolynomial(left);
360 Polynomial pright = Polynomial::parsePolynomial(right);
361
362 Debug("arith::rewriter") << "pleft " << pleft.getNode() << std::endl;
363 Debug("arith::rewriter") << "pright " << pright.getNode() << std::endl;
364
365 Comparison cmp = Comparison::mkComparison(atom.getKind(), pleft, pright);
366 Assert(cmp.isNormalForm());
367 return RewriteResponse(REWRITE_DONE, cmp.getNode());
368 }
369
370 RewriteResponse ArithRewriter::preRewriteAtom(TNode atom){
371 Assert(isAtom(atom));
372
373 NodeManager* currNM = NodeManager::currentNM();
374
375 if(atom.getKind() == kind::EQUAL) {
376 if(atom[0] == atom[1]) {
377 return RewriteResponse(REWRITE_DONE, currNM->mkConst(true));
378 }
379 }else if(atom.getKind() == kind::GT){
380 Node leq = currNM->mkNode(kind::LEQ, atom[0], atom[1]);
381 return RewriteResponse(REWRITE_DONE, currNM->mkNode(kind::NOT, leq));
382 }else if(atom.getKind() == kind::LT){
383 Node geq = currNM->mkNode(kind::GEQ, atom[0], atom[1]);
384 return RewriteResponse(REWRITE_DONE, currNM->mkNode(kind::NOT, geq));
385 }else if(atom.getKind() == kind::IS_INTEGER){
386 if(atom[0].getType().isInteger()){
387 return RewriteResponse(REWRITE_DONE, currNM->mkConst(true));
388 }
389 }else if(atom.getKind() == kind::DIVISIBLE){
390 if(atom.getOperator().getConst<Divisible>().k.isOne()){
391 return RewriteResponse(REWRITE_DONE, currNM->mkConst(true));
392 }
393 }
394
395 return RewriteResponse(REWRITE_DONE, atom);
396 }
397
398 RewriteResponse ArithRewriter::postRewrite(TNode t){
399 if(isTerm(t)){
400 RewriteResponse response = postRewriteTerm(t);
401 if(Debug.isOn("arith::rewriter") && response.status == REWRITE_DONE) {
402 Polynomial::parsePolynomial(response.node);
403 }
404 return response;
405 }else if(isAtom(t)){
406 RewriteResponse response = postRewriteAtom(t);
407 if(Debug.isOn("arith::rewriter") && response.status == REWRITE_DONE) {
408 Comparison::parseNormalForm(response.node);
409 }
410 return response;
411 }else{
412 Unreachable();
413 return RewriteResponse(REWRITE_DONE, Node::null());
414 }
415 }
416
417 RewriteResponse ArithRewriter::preRewrite(TNode t){
418 if(isTerm(t)){
419 return preRewriteTerm(t);
420 }else if(isAtom(t)){
421 return preRewriteAtom(t);
422 }else{
423 Unreachable();
424 return RewriteResponse(REWRITE_DONE, Node::null());
425 }
426 }
427
428 Node ArithRewriter::makeUnaryMinusNode(TNode n){
429 Rational qNegOne(-1);
430 return NodeManager::currentNM()->mkNode(kind::MULT, mkRationalNode(qNegOne),n);
431 }
432
433 Node ArithRewriter::makeSubtractionNode(TNode l, TNode r){
434 Node negR = makeUnaryMinusNode(r);
435 Node diff = NodeManager::currentNM()->mkNode(kind::PLUS, l, negR);
436
437 return diff;
438 }
439
440 RewriteResponse ArithRewriter::rewriteDiv(TNode t, bool pre){
441 Assert(t.getKind() == kind::DIVISION_TOTAL || t.getKind()== kind::DIVISION);
442
443
444 Node left = t[0];
445 Node right = t[1];
446 if(right.getKind() == kind::CONST_RATIONAL){
447 const Rational& den = right.getConst<Rational>();
448
449 if(den.isZero()){
450 if(t.getKind() == kind::DIVISION_TOTAL){
451 return RewriteResponse(REWRITE_DONE, mkRationalNode(0));
452 }else{
453 // This is unsupported, but this is not a good place to complain
454 return RewriteResponse(REWRITE_DONE, t);
455 }
456 }
457 Assert(den != Rational(0));
458
459 if(left.getKind() == kind::CONST_RATIONAL){
460 const Rational& num = left.getConst<Rational>();
461 Rational div = num / den;
462 Node result = mkRationalNode(div);
463 return RewriteResponse(REWRITE_DONE, result);
464 }
465
466 Rational div = den.inverse();
467
468 Node result = mkRationalNode(div);
469
470 Node mult = NodeManager::currentNM()->mkNode(kind::MULT,left,result);
471 if(pre){
472 return RewriteResponse(REWRITE_DONE, mult);
473 }else{
474 return RewriteResponse(REWRITE_AGAIN, mult);
475 }
476 }else{
477 return RewriteResponse(REWRITE_DONE, t);
478 }
479 }
480
481 RewriteResponse ArithRewriter::rewriteIntsDivModTotal(TNode t, bool pre){
482 Kind k = t.getKind();
483 // Assert(k == kind::INTS_MODULUS || k == kind::INTS_MODULUS_TOTAL ||
484 // k == kind::INTS_DIVISION || k == kind::INTS_DIVISION_TOTAL);
485
486 //Leaving the function as before (INTS_MODULUS can be handled),
487 // but restricting its use here
488 Assert(k == kind::INTS_MODULUS_TOTAL || k == kind::INTS_DIVISION_TOTAL);
489 TNode n = t[0], d = t[1];
490 bool dIsConstant = d.getKind() == kind::CONST_RATIONAL;
491 if(dIsConstant && d.getConst<Rational>().isZero()){
492 if(k == kind::INTS_MODULUS_TOTAL || k == kind::INTS_DIVISION_TOTAL){
493 return RewriteResponse(REWRITE_DONE, mkRationalNode(0));
494 }else{
495 // Do nothing for k == INTS_MODULUS
496 return RewriteResponse(REWRITE_DONE, t);
497 }
498 }else if(dIsConstant && d.getConst<Rational>().isOne()){
499 if(k == kind::INTS_MODULUS || k == kind::INTS_MODULUS_TOTAL){
500 return RewriteResponse(REWRITE_DONE, mkRationalNode(0));
501 }else{
502 Assert(k == kind::INTS_DIVISION || k == kind::INTS_DIVISION_TOTAL);
503 return RewriteResponse(REWRITE_AGAIN, n);
504 }
505 }else if(dIsConstant && d.getConst<Rational>().isNegativeOne()){
506 if(k == kind::INTS_MODULUS || k == kind::INTS_MODULUS_TOTAL){
507 return RewriteResponse(REWRITE_DONE, mkRationalNode(0));
508 }else{
509 Assert(k == kind::INTS_DIVISION || k == kind::INTS_DIVISION_TOTAL);
510 return RewriteResponse(REWRITE_AGAIN, NodeManager::currentNM()->mkNode(kind::UMINUS, n));
511 }
512 }else if(dIsConstant && n.getKind() == kind::CONST_RATIONAL){
513 Assert(d.getConst<Rational>().isIntegral());
514 Assert(n.getConst<Rational>().isIntegral());
515 Assert(!d.getConst<Rational>().isZero());
516 Integer di = d.getConst<Rational>().getNumerator();
517 Integer ni = n.getConst<Rational>().getNumerator();
518
519 bool isDiv = (k == kind::INTS_DIVISION || k == kind::INTS_DIVISION_TOTAL);
520
521 Integer result = isDiv ? ni.euclidianDivideQuotient(di) : ni.euclidianDivideRemainder(di);
522
523 Node resultNode = mkRationalNode(Rational(result));
524 return RewriteResponse(REWRITE_DONE, resultNode);
525 }else{
526 return RewriteResponse(REWRITE_DONE, t);
527 }
528 }
529
530 }/* CVC4::theory::arith namespace */
531 }/* CVC4::theory namespace */
532 }/* CVC4 namespace */