Changed how assignments are saved during check. These are now backed by an attribute...
[cvc5.git] / src / theory / arith / arith_rewriter.cpp
1 /********************* */
2 /*! \file arith_rewriter.cpp
3 ** \verbatim
4 ** Original author: taking
5 ** Major contributors: none
6 ** Minor contributors (to current version): none
7 ** This file is part of the CVC4 prototype.
8 ** Copyright (c) 2009, 2010 The Analysis of Computer Systems Group (ACSys)
9 ** Courant Institute of Mathematical Sciences
10 ** New York University
11 ** See the file COPYING in the top-level source directory for licensing
12 ** information.\endverbatim
13 **
14 ** \brief [[ Add one-line brief description here ]]
15 **
16 ** [[ Add lengthier description here ]]
17 ** \todo document this file
18 **/
19
20
21 #include "theory/arith/arith_rewriter.h"
22 #include "theory/arith/arith_utilities.h"
23
24 #include <vector>
25 #include <set>
26 #include <stack>
27
28
29 using namespace CVC4;
30 using namespace CVC4::theory;
31 using namespace CVC4::theory::arith;
32
33
34
35
36
37 Kind multKind(Kind k, int sgn);
38
39 /**
40 * Performs a quick check to see if it is easy to rewrite to
41 * this normal form
42 * v |><| b
43 * Also writes relations with constants on both sides to TRUE or FALSE.
44 * If it can, it returns true and sets res to this value.
45 *
46 * This is for optimizing rewriteAtom() to avoid the more compuationally
47 * expensive general rewriting procedure.
48 *
49 * If simplification is not done, it returns Node::null()
50 */
51 Node almostVarOrConstEqn(TNode atom, Kind k, TNode left, TNode right){
52 Assert(atom.getKind() == k);
53 Assert(isRelationOperator(k));
54 Assert(atom[0] == left);
55 Assert(atom[1] == right);
56 bool leftIsConst = left.getMetaKind() == kind::metakind::CONSTANT;
57 bool rightIsConst = right.getMetaKind() == kind::metakind::CONSTANT;
58
59 bool leftIsVar = left.getMetaKind() == kind::metakind::VARIABLE;
60 bool rightIsVar = right.getMetaKind() == kind::metakind::VARIABLE;
61
62 if(leftIsConst && rightIsConst){
63 Rational lc = coerceToRational(left);
64 Rational rc = coerceToRational(right);
65 bool res = evaluateConstantPredicate(k,lc, rc);
66 return mkBoolNode(res);
67 }else if(leftIsVar && rightIsConst){
68 if(right.getKind() == kind::CONST_RATIONAL){
69 return atom;
70 }else{
71 return NodeManager::currentNM()->mkNode(k,left,coerceToRationalNode(right));
72 }
73 }else if(leftIsConst && rightIsVar){
74 if(left.getKind() == kind::CONST_RATIONAL){
75 return NodeManager::currentNM()->mkNode(multKind(k,-1),right,left);
76 }else{
77 Node q_left = coerceToRationalNode(left);
78 return NodeManager::currentNM()->mkNode(multKind(k,-1),right,q_left);
79 }
80 }
81
82 return Node::null();
83 }
84
85 Node ArithRewriter::rewriteAtomCore(TNode atom){
86
87 Kind k = atom.getKind();
88 Assert(isRelationOperator(k));
89
90 // left |><| right
91 TNode left = atom[0];
92 TNode right = atom[1];
93
94 Node nf = almostVarOrConstEqn(atom, k,left,right);
95 if(nf != Node::null() ){
96 return nf;
97 }
98
99
100 //Transform this to: (left- right) |><| 0
101 Node diff = makeSubtractionNode(left, right);
102
103 Node rewritten = rewrite(diff);
104 // rewritten =_{Reals} left - right => rewritten |><| 0
105
106 if(rewritten.getMetaKind() == kind::metakind::CONSTANT){
107 // Case 1 rewritten : c
108 Rational c = rewritten.getConst<Rational>();
109 bool res = evaluateConstantPredicate(k, c, d_constants->d_ZERO);
110 nf = mkBoolNode(res);
111 }else if(rewritten.getMetaKind() == kind::metakind::VARIABLE){
112 // Case 2 rewritten : v
113 nf = NodeManager::currentNM()->mkNode(k, rewritten, d_constants->d_ZERO_NODE);
114 }else{
115 // Case 3 rewritten : (+ c p_1 p_2 ... p_N) | not(N=1 and c=0 and p_1.d=1)
116 Rational c = rewritten[0].getConst<Rational>();
117 c = -c;
118 TNode p_1 = rewritten[1];
119 Rational d = p_1[0].getConst<Rational>();
120 d = d.inverse();
121 c = c * d;
122 Node newRight = mkRationalNode(c);
123 Kind newKind = multKind(k, d.sgn());
124 int N = rewritten.getNumChildren() - 1;
125
126 if(N==1){
127 int M = p_1.getNumChildren()-1;
128 if(M == 1){ // v |><| b
129 TNode v = p_1[1];
130 nf = NodeManager::currentNM()->mkNode(newKind, v, newRight);
131 }else{ // p |><| b
132 Node newLeft = multPnfByNonZero(p_1, d);
133 nf = NodeManager::currentNM()->mkNode(newKind, newLeft, newRight);
134 }
135 }else{ //(+ p_1 .. p_N) |><| b
136 NodeBuilder<> plus(kind::PLUS);
137 for(int i=1; i<=N; ++i){
138 TNode p_i = rewritten[i];
139 plus << multPnfByNonZero(p_i, d);
140 }
141 Node newLeft = plus;
142 nf = NodeManager::currentNM()->mkNode(newKind, newLeft, newRight);
143 }
144 }
145
146 return nf;
147 }
148
149 Node ArithRewriter::rewriteAtom(TNode atom){
150 Node rewritten = rewriteAtomCore(atom);
151 if(rewritten.getKind() == kind::LT){
152 Node geq = NodeManager::currentNM()->mkNode(kind::GEQ, rewritten[0], rewritten[1]);
153 return NodeManager::currentNM()->mkNode(kind::NOT, geq);
154 }else if(rewritten.getKind() == kind::GT){
155 Node leq = NodeManager::currentNM()->mkNode(kind::LEQ, rewritten[0], rewritten[1]);
156 return NodeManager::currentNM()->mkNode(kind::NOT, leq);
157 }else{
158 return rewritten;
159 }
160 }
161
162
163 /* cmp( (* d v_1 v_2 ... v_M), (* d' v'_1 v'_2 ... v'_M'):
164 * if(M == M'):
165 * then tupleCompare(v_i, v'_i)
166 * else M -M'
167 */
168 struct pnfLessThan {
169 bool operator()(Node p0, Node p1) {
170 int p0_M = p0.getNumChildren() -1;
171 int p1_M = p1.getNumChildren() -1;
172 if(p0_M == p1_M){
173 for(int i=1; i<= p0_M; ++i){
174 if(p0[i] != p1[i]){
175 return p0[i] < p1[i];
176 }
177 }
178 return false; //p0 == p1 in this order
179 }else{
180 return p0_M < p1_M;
181 }
182 }
183 };
184
185 //Two pnfs are equal up to their coefficients
186 bool pnfsMatch(TNode p0, TNode p1){
187
188 unsigned M = p0.getNumChildren()-1;
189 if (M+1 != p1.getNumChildren()){
190 return false;
191 }
192
193 for(unsigned i=1; i <= M; ++i){
194 if(p0[i] != p1[i])
195 return false;
196 }
197 return true;
198 }
199
200 Node addMatchingPnfs(TNode p0, TNode p1){
201 Assert(pnfsMatch(p0,p1));
202
203 unsigned M = p0.getNumChildren()-1;
204
205 Rational c0 = p0[0].getConst<Rational>();
206 Rational c1 = p1[0].getConst<Rational>();
207
208 Rational addedC = c0 + c1;
209 Node newC = mkRationalNode(addedC);
210 NodeBuilder<> nb(kind::MULT);
211 nb << newC;
212 for(unsigned i=1; i <= M; ++i){
213 nb << p0[i];
214 }
215 Node newPnf = nb;
216 return newPnf;
217 }
218
219 void ArithRewriter::sortAndCombineCoefficients(std::vector<Node>& pnfs){
220 using namespace std;
221
222 /* combined contains exactly 1 representative per for each pnf.
223 * This is maintained by combining the coefficients for pnfs.
224 * that is equal according to pnfLessThan.
225 */
226 typedef set<Node, pnfLessThan> PnfSet;
227 PnfSet combined;
228
229 for(vector<Node>::iterator i=pnfs.begin(); i != pnfs.end(); ++i){
230 Node pnf = *i;
231 PnfSet::iterator pos = combined.find(pnf);
232
233 if(pos == combined.end()){
234 combined.insert(pnf);
235 }else{
236 Node current = *pos;
237 Node sum = addMatchingPnfs(pnf, current);
238 combined.erase(pos);
239 combined.insert(sum);
240 }
241 }
242 pnfs.clear();
243 for(PnfSet::iterator i=combined.begin(); i != combined.end(); ++i){
244 Node pnf = *i;
245 if(pnf[0].getConst<Rational>() != d_constants->d_ZERO){
246 //after combination the coefficient may be zero
247 pnfs.push_back(pnf);
248 }
249 }
250 }
251
252 Node ArithRewriter::var2pnf(TNode variable){
253 return NodeManager::currentNM()->mkNode(kind::MULT,d_constants->d_ONE_NODE,variable);
254 }
255
256 Node ArithRewriter::rewritePlus(TNode t){
257 using namespace std;
258
259 Rational accumulator;
260 vector<Node> pnfs;
261
262 for(TNode::iterator i = t.begin(); i!= t.end(); ++i){
263 TNode child = *i;
264 Node rewrittenChild = rewrite(child);
265
266 if(rewrittenChild.getMetaKind() == kind::metakind::CONSTANT){//c
267 Rational c = rewrittenChild.getConst<Rational>();
268 accumulator = accumulator + c;
269 }else if(rewrittenChild.getMetaKind() == kind::metakind::VARIABLE){ //v
270 Node pnf = var2pnf(rewrittenChild);
271 pnfs.push_back(pnf);
272 }else{ //(+ c p_1 p_2 ... p_N)
273 Rational c = rewrittenChild[0].getConst<Rational>();
274 accumulator = accumulator + c;
275 int N = rewrittenChild.getNumChildren() - 1;
276 for(int i=1; i<=N; ++i){
277 TNode pnf = rewrittenChild[i];
278 pnfs.push_back(pnf);
279 }
280 }
281 }
282 sortAndCombineCoefficients(pnfs);
283 if(pnfs.size() == 0){
284 return mkRationalNode(accumulator);
285 }
286
287 // pnfs.size() >= 1
288
289 //Enforce not(N=1 and c=0 and p_1.d=1)
290 if(pnfs.size() == 1){
291 Node p_1 = *(pnfs.begin());
292 if(p_1[0].getConst<Rational>() == d_constants->d_ONE){
293 if(accumulator == d_constants->d_ZERO){ // 0 + (* 1 var) |-> var
294 Node var = p_1[1];
295 return var;
296 }
297 }
298 }
299
300 //We must be in this case
301 //(+ c p_1 p_2 ... p_N) | not(N=1 and c=0 and p_1.d=1)
302
303 NodeBuilder<> nb(kind::PLUS);
304 nb << mkRationalNode(accumulator);
305 Debug("arithrewrite") << mkRationalNode(accumulator) << std::endl;
306 for(vector<Node>::iterator i = pnfs.begin(); i != pnfs.end(); ++i){
307 nb << *i;
308 Debug("arithrewrite") << (*i) << std::endl;
309
310 }
311
312 Node normalForm = nb;
313 return normalForm;
314 }
315
316 //Does not enforce
317 //5) v_i are of metakind VARIABLE,
318 //6) v_i are in increasing (not strict) nodeOrder,
319 Node toPnf(Rational& c, std::set<Node>& variables){
320 NodeBuilder<> nb(kind::MULT);
321 nb << mkRationalNode(c);
322
323 for(std::set<Node>::iterator i = variables.begin(); i != variables.end(); ++i){
324 nb << *i;
325 }
326 Node pnf = nb;
327 return pnf;
328 }
329
330 Node distribute(TNode n, TNode sum){
331 NodeBuilder<> nb(kind::PLUS);
332 for(TNode::iterator i=sum.begin(); i!=sum.end(); ++i){
333 Node prod = NodeManager::currentNM()->mkNode(kind::MULT, n, *i);
334 nb << prod;
335 }
336 return nb;
337 }
338 Node distributeSum(TNode sum, TNode distribSum){
339 NodeBuilder<> nb(kind::PLUS);
340 for(TNode::iterator i=sum.begin(); i!=sum.end(); ++i){
341 Node dist = distribute(*i, distribSum);
342 for(Node::iterator j=dist.begin(); j!=dist.end(); ++j){
343 nb << *j;
344 }
345 }
346 return nb;
347 }
348
349 Node ArithRewriter::rewriteMult(TNode t){
350
351 using namespace std;
352
353 Rational accumulator(1,1);
354 set<Node> variables;
355 vector<Node> sums;
356
357 //These stacks need to be kept in lock step
358 stack<TNode> mult_iterators_nodes;
359 stack<TNode::const_iterator> mult_iterators_iters;
360
361 mult_iterators_nodes.push(t);
362 mult_iterators_iters.push(t.begin());
363
364 while(!mult_iterators_nodes.empty()){
365 TNode mult = mult_iterators_nodes.top();
366 TNode::const_iterator i = mult_iterators_iters.top();
367
368 mult_iterators_nodes.pop();
369 mult_iterators_iters.pop();
370
371 for(; i != mult.end(); ++i){
372 TNode child = *i;
373 if(child.getKind() == kind::MULT){ //TODO add not rewritten already checks
374 ++i;
375 mult_iterators_nodes.push(mult);
376 mult_iterators_iters.push(i);
377
378 mult_iterators_nodes.push(child);
379 mult_iterators_iters.push(child.begin());
380 break;
381 }
382 Node rewrittenChild = rewrite(child);
383
384 if(rewrittenChild.getMetaKind() == kind::metakind::CONSTANT){//c
385 Rational c = rewrittenChild.getConst<Rational>();
386 accumulator = accumulator * c;
387 if(accumulator == d_constants->d_ZERO){
388 return d_constants->d_ZERO_NODE;
389 }
390 }else if(rewrittenChild.getMetaKind() == kind::metakind::VARIABLE){ //v
391 variables.insert(rewrittenChild);
392 }else{ //(+ c p_1 p_2 ... p_N)
393 sums.push_back(rewrittenChild);
394 }
395 }
396 }
397 // accumulator * (\prod var_i) *(\prod sum_j)
398
399 if(sums.size() == 0){ //accumulator * (\prod var_i)
400 if(variables.size() == 0){ //accumulator
401 return mkRationalNode(accumulator);
402 }else if(variables.size() == 1 && accumulator == d_constants->d_ONE){ // var_1
403 Node var = *(variables.begin());
404 return var;
405 }else{
406 //We need to return (+ c p_1 p_2 ... p_N)
407 //To accomplish this:
408 // let pnf = pnf(accumulator * (\prod var_i)) in (+ 0 pnf)
409 Node pnf = toPnf(accumulator, variables);
410 Node normalForm = NodeManager::currentNM()->mkNode(kind::PLUS, d_constants->d_ZERO_NODE, pnf);
411 return normalForm;
412 }
413 }else{
414 vector<Node>::iterator sum_iter = sums.begin();
415 // \sum t
416 // t \in Q \cup A
417 // where A = lfp {\prod s | s \in Q \cup Variables \cup A}
418 Node distributed = *sum_iter;
419 ++sum_iter;
420 while(sum_iter != sums.end()){
421 Node curr = *sum_iter;
422 distributed = distributeSum(curr, distributed);
423 ++sum_iter;
424 }
425 if(variables.size() >= 1){
426 Node pnf = toPnf(accumulator, variables);
427 distributed = distribute(pnf, distributed);
428 }else{
429 Node constant = mkRationalNode(accumulator);
430 distributed = distribute(constant, distributed);
431 }
432
433 Node nf_distributed = rewrite(distributed);
434 return nf_distributed;
435 }
436 }
437
438 Node ArithRewriter::rewriteConstantDiv(TNode t){
439 Assert(t.getKind()== kind::DIVISION);
440
441 Node reLeft = rewrite(t[0]);
442 Node reRight = rewrite(t[1]);
443 Assert(reLeft.getKind()== kind::CONST_RATIONAL);
444 Assert(reRight.getKind()== kind::CONST_RATIONAL);
445
446 Rational num = reLeft.getConst<Rational>();
447 Rational den = reRight.getConst<Rational>();
448
449 Assert(den != d_constants->d_ZERO);
450
451 Rational div = num / den;
452
453 Node result = mkRationalNode(div);
454
455 return result;
456 }
457
458 Node ArithRewriter::rewriteTerm(TNode t){
459 if(t.getMetaKind() == kind::metakind::CONSTANT){
460 return coerceToRationalNode(t);
461 }else if(t.getMetaKind() == kind::metakind::VARIABLE){
462 return t;
463 }else if(t.getKind() == kind::MULT){
464 return rewriteMult(t);
465 }else if(t.getKind() == kind::PLUS){
466 return rewritePlus(t);
467 }else if(t.getKind() == kind::DIVISION){
468 return rewriteConstantDiv(t);
469 }else if(t.getKind() == kind::MINUS){
470 Node sub = makeSubtractionNode(t[0],t[1]);
471 return rewrite(sub);
472 }else if(t.getKind() == kind::UMINUS){
473 Node sub = makeUnaryMinusNode(t[0]);
474 return rewrite(sub);
475 }else{
476 Unreachable();
477 return Node::null();
478 }
479 }
480
481
482 /**
483 * Given a node in PNF pnf = (* d p_1 p_2 .. p_M) and a rational q != 0
484 * constuct a node equal to q * pnf that is in pnf.
485 *
486 * The claim is that this is always okay:
487 * If d' = q*d, p' = (* d' p_1 p_2 .. p_M) =_{Reals} q * pnf.
488 */
489 Node ArithRewriter::multPnfByNonZero(TNode pnf, Rational& q){
490 Assert(q != d_constants->d_ZERO);
491 //TODO Assert(isPNF(pnf) );
492
493 int M = pnf.getNumChildren()-1;
494 Rational d = pnf[0].getConst<Rational>();
495 Rational new_d = d*q;
496
497
498 NodeBuilder<> mult(kind::MULT);
499 mult << mkRationalNode(new_d);
500 for(int i=1; i<=M; ++i){
501 mult << pnf[i];
502 }
503
504 Node result = mult;
505 return result;
506 }
507
508 Node ArithRewriter::makeUnaryMinusNode(TNode n){
509 Node tmp = NodeManager::currentNM()->mkNode(kind::MULT,d_constants->d_NEGATIVE_ONE_NODE,n);
510 return tmp;
511 }
512
513 Node ArithRewriter::makeSubtractionNode(TNode l, TNode r){
514 Node negR = makeUnaryMinusNode(r);
515 Node diff = NodeManager::currentNM()->mkNode(kind::PLUS, l, negR);
516
517 return diff;
518 }
519
520
521 Kind multKind(Kind k, int sgn){
522 using namespace kind;
523
524 if(sgn < 0){
525
526 switch(k){
527 case LT: return GT;
528 case LEQ: return GEQ;
529 case EQUAL: return EQUAL;
530 case GEQ: return LEQ;
531 case GT: return LT;
532 default:
533 Unhandled();
534 }
535 return NULL_EXPR;
536 }else{
537 return k;
538 }
539 }
540
541 Node ArithRewriter::rewrite(TNode n){
542 Debug("arithrewriter") << "Trace rewrite:" << n << std::endl;
543
544 Node res;
545
546 if(isRelationOperator(n.getKind())){
547 res = rewriteAtom(n);
548 }else{
549 res = rewriteTerm(n);
550 }
551
552 Debug("arithrewriter") << "Trace rewrite:" << n << "|->"<< res << std::endl;
553
554 return res;
555 }