Merge branch '1.2.x'
[cvc5.git] / src / theory / arith / arith_rewriter.cpp
1 /********************* */
2 /*! \file arith_rewriter.cpp
3 ** \verbatim
4 ** Original author: Tim King
5 ** Major contributors: none
6 ** Minor contributors (to current version): Morgan Deters, Dejan Jovanovic
7 ** This file is part of the CVC4 project.
8 ** Copyright (c) 2009-2013 New York University and The University of Iowa
9 ** See the file COPYING in the top-level source directory for licensing
10 ** information.\endverbatim
11 **
12 ** \brief [[ Add one-line brief description here ]]
13 **
14 ** [[ Add lengthier description here ]]
15 ** \todo document this file
16 **/
17
18 #include "theory/theory.h"
19 #include "theory/arith/normal_form.h"
20 #include "theory/arith/arith_rewriter.h"
21 #include "theory/arith/arith_utilities.h"
22
23 #include <vector>
24 #include <set>
25 #include <stack>
26
27 namespace CVC4 {
28 namespace theory {
29 namespace arith {
30
31 bool ArithRewriter::isAtom(TNode n) {
32 Kind k = n.getKind();
33 return arith::isRelationOperator(k) || k == kind::IS_INTEGER || k == kind::DIVISIBLE;
34 }
35
36 RewriteResponse ArithRewriter::rewriteConstant(TNode t){
37 Assert(t.isConst());
38 Assert(t.getKind() == kind::CONST_RATIONAL);
39
40 return RewriteResponse(REWRITE_DONE, t);
41 }
42
43 RewriteResponse ArithRewriter::rewriteVariable(TNode t){
44 Assert(t.isVar());
45
46 return RewriteResponse(REWRITE_DONE, t);
47 }
48
49 RewriteResponse ArithRewriter::rewriteMinus(TNode t, bool pre){
50 Assert(t.getKind()== kind::MINUS);
51
52 if(pre){
53 if(t[0] == t[1]){
54 Rational zero(0);
55 Node zeroNode = mkRationalNode(zero);
56 return RewriteResponse(REWRITE_DONE, zeroNode);
57 }else{
58 Node noMinus = makeSubtractionNode(t[0],t[1]);
59 return RewriteResponse(REWRITE_DONE, noMinus);
60 }
61 }else{
62 Polynomial minuend = Polynomial::parsePolynomial(t[0]);
63 Polynomial subtrahend = Polynomial::parsePolynomial(t[1]);
64 Polynomial diff = minuend - subtrahend;
65 return RewriteResponse(REWRITE_DONE, diff.getNode());
66 }
67 }
68
69 RewriteResponse ArithRewriter::rewriteUMinus(TNode t, bool pre){
70 Assert(t.getKind()== kind::UMINUS);
71
72 if(t[0].getKind() == kind::CONST_RATIONAL){
73 Rational neg = -(t[0].getConst<Rational>());
74 return RewriteResponse(REWRITE_DONE, mkRationalNode(neg));
75 }
76
77 Node noUminus = makeUnaryMinusNode(t[0]);
78 if(pre)
79 return RewriteResponse(REWRITE_DONE, noUminus);
80 else
81 return RewriteResponse(REWRITE_AGAIN, noUminus);
82 }
83
84 RewriteResponse ArithRewriter::preRewriteTerm(TNode t){
85 if(t.isConst()){
86 return rewriteConstant(t);
87 }else if(t.isVar()){
88 return rewriteVariable(t);
89 }else{
90 switch(Kind k = t.getKind()){
91 case kind::MINUS:
92 return rewriteMinus(t, true);
93 case kind::UMINUS:
94 return rewriteUMinus(t, true);
95 case kind::DIVISION:
96 case kind::DIVISION_TOTAL:
97 return rewriteDiv(t,true);
98 case kind::PLUS:
99 return preRewritePlus(t);
100 case kind::MULT:
101 return preRewriteMult(t);
102 case kind::INTS_DIVISION:
103 case kind::INTS_MODULUS:
104 return RewriteResponse(REWRITE_DONE, t);
105 case kind::INTS_DIVISION_TOTAL:
106 case kind::INTS_MODULUS_TOTAL:
107 return rewriteIntsDivModTotal(t,true);
108 case kind::ABS:
109 if(t[0].isConst()) {
110 const Rational& rat = t[0].getConst<Rational>();
111 if(rat >= 0) {
112 return RewriteResponse(REWRITE_DONE, t[0]);
113 } else {
114 return RewriteResponse(REWRITE_DONE,
115 NodeManager::currentNM()->mkConst(-rat));
116 }
117 }
118 return RewriteResponse(REWRITE_DONE, t);
119 case kind::IS_INTEGER:
120 case kind::TO_INTEGER:
121 return RewriteResponse(REWRITE_DONE, t);
122 case kind::TO_REAL:
123 return RewriteResponse(REWRITE_DONE, t[0]);
124 default:
125 Unhandled(k);
126 }
127 }
128 }
129 RewriteResponse ArithRewriter::postRewriteTerm(TNode t){
130 if(t.isConst()){
131 return rewriteConstant(t);
132 }else if(t.isVar()){
133 return rewriteVariable(t);
134 }else{
135 switch(t.getKind()){
136 case kind::MINUS:
137 return rewriteMinus(t, false);
138 case kind::UMINUS:
139 return rewriteUMinus(t, false);
140 case kind::DIVISION:
141 case kind::DIVISION_TOTAL:
142 return rewriteDiv(t, false);
143 case kind::PLUS:
144 return postRewritePlus(t);
145 case kind::MULT:
146 return postRewriteMult(t);
147 case kind::INTS_DIVISION:
148 case kind::INTS_MODULUS:
149 return RewriteResponse(REWRITE_DONE, t);
150 case kind::INTS_DIVISION_TOTAL:
151 case kind::INTS_MODULUS_TOTAL:
152 return rewriteIntsDivModTotal(t, false);
153 case kind::ABS:
154 if(t[0].isConst()) {
155 const Rational& rat = t[0].getConst<Rational>();
156 if(rat >= 0) {
157 return RewriteResponse(REWRITE_DONE, t[0]);
158 } else {
159 return RewriteResponse(REWRITE_DONE,
160 NodeManager::currentNM()->mkConst(-rat));
161 }
162 }
163 case kind::TO_REAL:
164 return RewriteResponse(REWRITE_DONE, t[0]);
165 case kind::TO_INTEGER:
166 if(t[0].isConst()) {
167 return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(Rational(t[0].getConst<Rational>().floor())));
168 }
169 if(t[0].getType().isInteger()) {
170 return RewriteResponse(REWRITE_DONE, t[0]);
171 }
172 //Unimplemented("TO_INTEGER, nonconstant");
173 //return rewriteToInteger(t);
174 return RewriteResponse(REWRITE_DONE, t);
175 case kind::IS_INTEGER:
176 if(t[0].isConst()) {
177 return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(t[0].getConst<Rational>().getDenominator() == 1));
178 }
179 if(t[0].getType().isInteger()) {
180 return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(true));
181 }
182 //Unimplemented("IS_INTEGER, nonconstant");
183 //return rewriteIsInteger(t);
184 return RewriteResponse(REWRITE_DONE, t);
185 default:
186 Unreachable();
187 }
188 }
189 }
190
191
192 RewriteResponse ArithRewriter::preRewriteMult(TNode t){
193 Assert(t.getKind()== kind::MULT);
194
195 // Rewrite multiplications with a 0 argument and to 0
196 Rational qZero(0);
197
198 for(TNode::iterator i = t.begin(); i != t.end(); ++i) {
199 if((*i).getKind() == kind::CONST_RATIONAL) {
200 if((*i).getConst<Rational>() == qZero) {
201 return RewriteResponse(REWRITE_DONE, mkRationalNode(qZero));
202 }
203 }
204 }
205 return RewriteResponse(REWRITE_DONE, t);
206 }
207 RewriteResponse ArithRewriter::preRewritePlus(TNode t){
208 Assert(t.getKind()== kind::PLUS);
209
210 return RewriteResponse(REWRITE_DONE, t);
211 }
212
213 RewriteResponse ArithRewriter::postRewritePlus(TNode t){
214 Assert(t.getKind()== kind::PLUS);
215
216 Polynomial res = Polynomial::mkZero();
217
218 for(TNode::iterator i = t.begin(), end = t.end(); i != end; ++i){
219 Node curr = *i;
220 Polynomial currPoly = Polynomial::parsePolynomial(curr);
221
222 res = res + currPoly;
223 }
224
225 return RewriteResponse(REWRITE_DONE, res.getNode());
226 }
227
228 RewriteResponse ArithRewriter::postRewriteMult(TNode t){
229 Assert(t.getKind()== kind::MULT);
230
231 Polynomial res = Polynomial::mkOne();
232
233 for(TNode::iterator i = t.begin(), end = t.end(); i != end; ++i){
234 Node curr = *i;
235 Polynomial currPoly = Polynomial::parsePolynomial(curr);
236
237 res = res * currPoly;
238 }
239
240 return RewriteResponse(REWRITE_DONE, res.getNode());
241 }
242
243 RewriteResponse ArithRewriter::postRewriteAtom(TNode atom){
244 if(atom.getKind() == kind::IS_INTEGER) {
245 if(atom[0].isConst()) {
246 return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(atom[0].getConst<Rational>().isIntegral()));
247 }
248 if(atom[0].getType().isInteger()) {
249 return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(true));
250 }
251 // not supported, but this isn't the right place to complain
252 return RewriteResponse(REWRITE_DONE, atom);
253 } else if(atom.getKind() == kind::DIVISIBLE) {
254 if(atom[0].isConst()) {
255 return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(bool((atom[0].getConst<Rational>() / atom.getOperator().getConst<Divisible>().k).isIntegral())));
256 }
257 if(atom.getOperator().getConst<Divisible>().k.isOne()) {
258 return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(true));
259 }
260 return RewriteResponse(REWRITE_AGAIN, NodeManager::currentNM()->mkNode(kind::EQUAL, NodeManager::currentNM()->mkNode(kind::INTS_MODULUS_TOTAL, atom[0], NodeManager::currentNM()->mkConst(Rational(atom.getOperator().getConst<Divisible>().k))), NodeManager::currentNM()->mkConst(Rational(0))));
261 }
262
263 // left |><| right
264 TNode left = atom[0];
265 TNode right = atom[1];
266
267 Polynomial pleft = Polynomial::parsePolynomial(left);
268 Polynomial pright = Polynomial::parsePolynomial(right);
269
270 Comparison cmp = Comparison::mkComparison(atom.getKind(), pleft, pright);
271 Assert(cmp.isNormalForm());
272 return RewriteResponse(REWRITE_DONE, cmp.getNode());
273 }
274
275 RewriteResponse ArithRewriter::preRewriteAtom(TNode atom){
276 Assert(isAtom(atom));
277
278 NodeManager* currNM = NodeManager::currentNM();
279
280 if(atom.getKind() == kind::EQUAL) {
281 if(atom[0] == atom[1]) {
282 return RewriteResponse(REWRITE_DONE, currNM->mkConst(true));
283 }
284 }else if(atom.getKind() == kind::GT){
285 Node leq = currNM->mkNode(kind::LEQ, atom[0], atom[1]);
286 return RewriteResponse(REWRITE_DONE, currNM->mkNode(kind::NOT, leq));
287 }else if(atom.getKind() == kind::LT){
288 Node geq = currNM->mkNode(kind::GEQ, atom[0], atom[1]);
289 return RewriteResponse(REWRITE_DONE, currNM->mkNode(kind::NOT, geq));
290 }else if(atom.getKind() == kind::IS_INTEGER){
291 if(atom[0].getType().isInteger()){
292 return RewriteResponse(REWRITE_DONE, currNM->mkConst(true));
293 }
294 }else if(atom.getKind() == kind::DIVISIBLE){
295 if(atom.getOperator().getConst<Divisible>().k.isOne()){
296 return RewriteResponse(REWRITE_DONE, currNM->mkConst(true));
297 }
298 }
299
300 return RewriteResponse(REWRITE_DONE, atom);
301 }
302
303 RewriteResponse ArithRewriter::postRewrite(TNode t){
304 if(isTerm(t)){
305 RewriteResponse response = postRewriteTerm(t);
306 if(Debug.isOn("arith::rewriter") && response.status == REWRITE_DONE) {
307 Polynomial::parsePolynomial(response.node);
308 }
309 return response;
310 }else if(isAtom(t)){
311 RewriteResponse response = postRewriteAtom(t);
312 if(Debug.isOn("arith::rewriter") && response.status == REWRITE_DONE) {
313 Comparison::parseNormalForm(response.node);
314 }
315 return response;
316 }else{
317 Unreachable();
318 return RewriteResponse(REWRITE_DONE, Node::null());
319 }
320 }
321
322 RewriteResponse ArithRewriter::preRewrite(TNode t){
323 if(isTerm(t)){
324 return preRewriteTerm(t);
325 }else if(isAtom(t)){
326 return preRewriteAtom(t);
327 }else{
328 Unreachable();
329 return RewriteResponse(REWRITE_DONE, Node::null());
330 }
331 }
332
333 Node ArithRewriter::makeUnaryMinusNode(TNode n){
334 Rational qNegOne(-1);
335 return NodeManager::currentNM()->mkNode(kind::MULT, mkRationalNode(qNegOne),n);
336 }
337
338 Node ArithRewriter::makeSubtractionNode(TNode l, TNode r){
339 Node negR = makeUnaryMinusNode(r);
340 Node diff = NodeManager::currentNM()->mkNode(kind::PLUS, l, negR);
341
342 return diff;
343 }
344
345 RewriteResponse ArithRewriter::rewriteDiv(TNode t, bool pre){
346 Assert(t.getKind() == kind::DIVISION_TOTAL || t.getKind()== kind::DIVISION);
347
348
349 Node left = t[0];
350 Node right = t[1];
351 if(right.getKind() == kind::CONST_RATIONAL){
352 const Rational& den = right.getConst<Rational>();
353
354 if(den.isZero()){
355 if(t.getKind() == kind::DIVISION_TOTAL){
356 return RewriteResponse(REWRITE_DONE, mkRationalNode(0));
357 }else{
358 // This is unsupported, but this is not a good place to complain
359 return RewriteResponse(REWRITE_DONE, t);
360 }
361 }
362 Assert(den != Rational(0));
363
364 if(left.getKind() == kind::CONST_RATIONAL){
365 const Rational& num = left.getConst<Rational>();
366 Rational div = num / den;
367 Node result = mkRationalNode(div);
368 return RewriteResponse(REWRITE_DONE, result);
369 }
370
371 Rational div = den.inverse();
372
373 Node result = mkRationalNode(div);
374
375 Node mult = NodeManager::currentNM()->mkNode(kind::MULT,left,result);
376 if(pre){
377 return RewriteResponse(REWRITE_DONE, mult);
378 }else{
379 return RewriteResponse(REWRITE_AGAIN, mult);
380 }
381 }else{
382 return RewriteResponse(REWRITE_DONE, t);
383 }
384 }
385
386 RewriteResponse ArithRewriter::rewriteIntsDivModTotal(TNode t, bool pre){
387 Kind k = t.getKind();
388 // Assert(k == kind::INTS_MODULUS || k == kind::INTS_MODULUS_TOTAL ||
389 // k == kind::INTS_DIVISION || k == kind::INTS_DIVISION_TOTAL);
390
391 //Leaving the function as before (INTS_MODULUS can be handled),
392 // but restricting its use here
393 Assert(k == kind::INTS_MODULUS_TOTAL || k == kind::INTS_DIVISION_TOTAL);
394 TNode n = t[0], d = t[1];
395 bool dIsConstant = d.getKind() == kind::CONST_RATIONAL;
396 if(dIsConstant && d.getConst<Rational>().isZero()){
397 if(k == kind::INTS_MODULUS_TOTAL || k == kind::INTS_DIVISION_TOTAL){
398 return RewriteResponse(REWRITE_DONE, mkRationalNode(0));
399 }else{
400 // Do nothing for k == INTS_MODULUS
401 return RewriteResponse(REWRITE_DONE, t);
402 }
403 }else if(dIsConstant && d.getConst<Rational>().isOne()){
404 if(k == kind::INTS_MODULUS || k == kind::INTS_MODULUS_TOTAL){
405 return RewriteResponse(REWRITE_DONE, mkRationalNode(0));
406 }else{
407 Assert(k == kind::INTS_DIVISION || k == kind::INTS_DIVISION_TOTAL);
408 return RewriteResponse(REWRITE_AGAIN, n);
409 }
410 }else if(dIsConstant && d.getConst<Rational>().isNegativeOne()){
411 if(k == kind::INTS_MODULUS || k == kind::INTS_MODULUS_TOTAL){
412 return RewriteResponse(REWRITE_DONE, mkRationalNode(0));
413 }else{
414 Assert(k == kind::INTS_DIVISION || k == kind::INTS_DIVISION_TOTAL);
415 return RewriteResponse(REWRITE_AGAIN, NodeManager::currentNM()->mkNode(kind::UMINUS, n));
416 }
417 }else if(dIsConstant && n.getKind() == kind::CONST_RATIONAL){
418 Assert(d.getConst<Rational>().isIntegral());
419 Assert(n.getConst<Rational>().isIntegral());
420 Assert(!d.getConst<Rational>().isZero());
421 Integer di = d.getConst<Rational>().getNumerator();
422 Integer ni = n.getConst<Rational>().getNumerator();
423
424 bool isDiv = (k == kind::INTS_DIVISION || k == kind::INTS_DIVISION_TOTAL);
425
426 Integer result = isDiv ? ni.euclidianDivideQuotient(di) : ni.euclidianDivideRemainder(di);
427
428 Node resultNode = mkRationalNode(Rational(result));
429 return RewriteResponse(REWRITE_DONE, resultNode);
430 }else{
431 return RewriteResponse(REWRITE_DONE, t);
432 }
433 }
434
435 }/* CVC4::theory::arith namespace */
436 }/* CVC4::theory namespace */
437 }/* CVC4 namespace */