Fix for bug 131. Added some additional debugging assertions for the arith rewriter.
[cvc5.git] / src / theory / arith / arith_rewriter.cpp
1
2 #include "theory/arith/arith_rewriter.h"
3 #include "theory/arith/arith_utilities.h"
4 #include "theory/arith/normal.h"
5
6 #include <vector>
7 #include <set>
8 #include <stack>
9
10
11 using namespace CVC4;
12 using namespace CVC4::theory;
13 using namespace CVC4::theory::arith;
14
15
16
17
18
19 Kind multKind(Kind k, int sgn);
20
21 /**
22 * Performs a quick check to see if it is easy to rewrite to
23 * this normal form
24 * v |><| b
25 * Also writes relations with constants on both sides to TRUE or FALSE.
26 * If it can, it returns true and sets res to this value.
27 *
28 * This is for optimizing rewriteAtom() to avoid the more compuationally
29 * expensive general rewriting procedure.
30 *
31 * If simplification is not done, it returns Node::null()
32 */
33 Node almostVarOrConstEqn(TNode atom, Kind k, TNode left, TNode right){
34 Assert(atom.getKind() == k);
35 Assert(isRelationOperator(k));
36 Assert(atom[0] == left);
37 Assert(atom[1] == right);
38 bool leftIsConst = left.getMetaKind() == kind::metakind::CONSTANT;
39 bool rightIsConst = right.getMetaKind() == kind::metakind::CONSTANT;
40
41 bool leftIsVar = left.getMetaKind() == kind::metakind::VARIABLE;
42 bool rightIsVar = right.getMetaKind() == kind::metakind::VARIABLE;
43
44 if(leftIsConst && rightIsConst){
45 Rational lc = coerceToRational(left);
46 Rational rc = coerceToRational(right);
47 bool res = evaluateConstantPredicate(k,lc, rc);
48 return mkBoolNode(res);
49 }else if(leftIsVar && rightIsConst){
50 if(right.getKind() == kind::CONST_RATIONAL){
51 return atom;
52 }else{
53 return NodeManager::currentNM()->mkNode(k,left,coerceToRationalNode(right));
54 }
55 }else if(leftIsConst && rightIsVar){
56 if(left.getKind() == kind::CONST_RATIONAL){
57 return NodeManager::currentNM()->mkNode(multKind(k,-1),right,left);
58 }else{
59 Node q_left = coerceToRationalNode(left);
60 return NodeManager::currentNM()->mkNode(multKind(k,-1),right,q_left);
61 }
62 }
63
64 return Node::null();
65 }
66
67 Node ArithRewriter::rewriteAtomCore(TNode atom){
68
69 Kind k = atom.getKind();
70 Assert(isRelationOperator(k));
71
72 // left |><| right
73 TNode left = atom[0];
74 TNode right = atom[1];
75
76 Node nf = almostVarOrConstEqn(atom, k,left,right);
77 if(nf != Node::null() ){
78 return nf;
79 }
80
81
82 //Transform this to: (left- right) |><| 0
83 Node diff = makeSubtractionNode(left, right);
84
85 Node rewritten = rewrite(diff);
86 // rewritten =_{Reals} left - right => rewritten |><| 0
87
88 if(rewritten.getMetaKind() == kind::metakind::CONSTANT){
89 // Case 1 rewritten : c
90 Rational c = rewritten.getConst<Rational>();
91 bool res = evaluateConstantPredicate(k, c, d_constants->d_ZERO);
92 nf = mkBoolNode(res);
93 }else if(rewritten.getMetaKind() == kind::metakind::VARIABLE){
94 // Case 2 rewritten : v
95 nf = NodeManager::currentNM()->mkNode(k, rewritten, d_constants->d_ZERO_NODE);
96 }else{
97 // Case 3 rewritten : (+ c p_1 p_2 ... p_N) | not(N=1 and c=0 and p_1.d=1)
98 Rational c = rewritten[0].getConst<Rational>();
99 c = -c;
100 TNode p_1 = rewritten[1];
101 Rational d = p_1[0].getConst<Rational>();
102 d = d.inverse();
103 c = c * d;
104 Node newRight = mkRationalNode(c);
105 Kind newKind = multKind(k, d.sgn());
106 int N = rewritten.getNumChildren() - 1;
107
108 if(N==1){
109 int M = p_1.getNumChildren()-1;
110 if(M == 1){ // v |><| b
111 TNode v = p_1[1];
112 nf = NodeManager::currentNM()->mkNode(newKind, v, newRight);
113 }else{ // p |><| b
114 Node newLeft = multPnfByNonZero(p_1, d);
115 nf = NodeManager::currentNM()->mkNode(newKind, newLeft, newRight);
116 }
117 }else{ //(+ p_1 .. p_N) |><| b
118 NodeBuilder<> plus(kind::PLUS);
119 for(int i=1; i<=N; ++i){
120 TNode p_i = rewritten[i];
121 plus << multPnfByNonZero(p_i, d);
122 }
123 Node newLeft = plus;
124 nf = NodeManager::currentNM()->mkNode(newKind, newLeft, newRight);
125 }
126 }
127
128 return nf;
129 }
130
131 Node ArithRewriter::rewriteAtom(TNode atom){
132 Node rewritten = rewriteAtomCore(atom);
133 if(rewritten.getKind() == kind::LT){
134 Node geq = NodeManager::currentNM()->mkNode(kind::GEQ, rewritten[0], rewritten[1]);
135 return NodeManager::currentNM()->mkNode(kind::NOT, geq);
136 }else if(rewritten.getKind() == kind::GT){
137 Node leq = NodeManager::currentNM()->mkNode(kind::LEQ, rewritten[0], rewritten[1]);
138 return NodeManager::currentNM()->mkNode(kind::NOT, leq);
139 }else{
140 return rewritten;
141 }
142 }
143
144
145 /* cmp( (* d v_1 v_2 ... v_M), (* d' v'_1 v'_2 ... v'_M'):
146 * if(M == M'):
147 * then tupleCompare(v_i, v'_i)
148 * else M -M'
149 */
150 struct pnfLessThan {
151 bool operator()(Node p0, Node p1) {
152 int p0_M = p0.getNumChildren() -1;
153 int p1_M = p1.getNumChildren() -1;
154 if(p0_M == p1_M){
155 for(int i=1; i<= p0_M; ++i){
156 if(p0[i] != p1[i]){
157 return p0[i] < p1[i];
158 }
159 }
160 return false; //p0 == p1 in this order
161 }else{
162 return p0_M < p1_M;
163 }
164 }
165 };
166
167 //Two pnfs are equal up to their coefficients
168 bool pnfsMatch(TNode p0, TNode p1){
169
170 unsigned M = p0.getNumChildren()-1;
171 if (M+1 != p1.getNumChildren()){
172 return false;
173 }
174
175 for(unsigned i=1; i <= M; ++i){
176 if(p0[i] != p1[i])
177 return false;
178 }
179 return true;
180 }
181
182 Node addMatchingPnfs(TNode p0, TNode p1){
183 Assert(pnfsMatch(p0,p1));
184
185 unsigned M = p0.getNumChildren()-1;
186
187 Rational c0 = p0[0].getConst<Rational>();
188 Rational c1 = p1[0].getConst<Rational>();
189
190 Rational addedC = c0 + c1;
191 Node newC = mkRationalNode(addedC);
192 NodeBuilder<> nb(kind::PLUS);
193 nb << newC;
194 for(unsigned i=1; i <= M; ++i){
195 nb << p0[i];
196 }
197 Node newPnf = nb;
198 return newPnf;
199 }
200
201 void sortAndCombineCoefficients(std::vector<Node>& pnfs){
202 using namespace std;
203
204 /* combined contains exactly 1 representative per for each pnf.
205 * This is maintained by combining the coefficients for pnfs.
206 * that is equal according to pnfLessThan.
207 */
208 typedef set<Node, pnfLessThan> PnfSet;
209 PnfSet combined;
210
211 for(vector<Node>::iterator i=pnfs.begin(); i != pnfs.end(); ++i){
212 Node pnf = *i;
213 PnfSet::iterator pos = combined.find(pnf);
214
215 if(pos == combined.end()){
216 combined.insert(pnf);
217 }else{
218 Node current = *pos;
219 Node sum = addMatchingPnfs(pnf, current);
220 combined.erase(pos);
221 combined.insert(sum);
222 }
223 }
224 pnfs.clear();
225 for(PnfSet::iterator i=combined.begin(); i != combined.end(); ++i){
226 Node pnf = *i;
227 pnfs.push_back(pnf);
228 }
229 }
230
231 Node ArithRewriter::var2pnf(TNode variable){
232 return NodeManager::currentNM()->mkNode(kind::MULT,d_constants->d_ONE_NODE,variable);
233 }
234
235 Node ArithRewriter::rewritePlus(TNode t){
236 using namespace std;
237
238 Rational accumulator;
239 vector<Node> pnfs;
240
241 for(TNode::iterator i = t.begin(); i!= t.end(); ++i){
242 TNode child = *i;
243 Node rewrittenChild = rewrite(child);
244
245 if(rewrittenChild.getMetaKind() == kind::metakind::CONSTANT){//c
246 Rational c = rewrittenChild.getConst<Rational>();
247 accumulator = accumulator + c;
248 }else if(rewrittenChild.getMetaKind() == kind::metakind::VARIABLE){ //v
249 Node pnf = var2pnf(rewrittenChild);
250 pnfs.push_back(pnf);
251 }else{ //(+ c p_1 p_2 ... p_N)
252 Rational c = rewrittenChild[0].getConst<Rational>();
253 accumulator = accumulator + c;
254 int N = rewrittenChild.getNumChildren() - 1;
255 for(int i=1; i<=N; ++i){
256 TNode pnf = rewrittenChild[i];
257 pnfs.push_back(pnf);
258 }
259 }
260 }
261 sortAndCombineCoefficients(pnfs);
262 if(pnfs.size() == 0){
263 return mkRationalNode(accumulator);
264 }
265
266 // pnfs.size() >= 1
267
268 //Enforce not(N=1 and c=0 and p_1.d=1)
269 if(pnfs.size() == 1){
270 Node p_1 = *(pnfs.begin());
271 if(p_1[0].getConst<Rational>() == d_constants->d_ONE){
272 if(accumulator == d_constants->d_ZERO){ // 0 + (* 1 var) |-> var
273 Node var = p_1[1];
274 return var;
275 }
276 }
277 }
278
279 //We must be in this case
280 //(+ c p_1 p_2 ... p_N) | not(N=1 and c=0 and p_1.d=1)
281
282 NodeBuilder<> nb(kind::PLUS);
283 nb << mkRationalNode(accumulator);
284 Debug("arithrewrite") << mkRationalNode(accumulator) << std::endl;
285 for(vector<Node>::iterator i = pnfs.begin(); i != pnfs.end(); ++i){
286 nb << *i;
287 Debug("arithrewrite") << (*i) << std::endl;
288
289 }
290
291 Node normalForm = nb;
292 return normalForm;
293 }
294
295 //Does not enforce
296 //5) v_i are of metakind VARIABLE,
297 //6) v_i are in increasing (not strict) nodeOrder,
298 Node toPnf(Rational& c, std::set<Node>& variables){
299 NodeBuilder<> nb(kind::MULT);
300 nb << mkRationalNode(c);
301
302 for(std::set<Node>::iterator i = variables.begin(); i != variables.end(); ++i){
303 nb << *i;
304 }
305 Node pnf = nb;
306 return pnf;
307 }
308
309 Node distribute(TNode n, TNode sum){
310 NodeBuilder<> nb(kind::PLUS);
311 for(TNode::iterator i=sum.begin(); i!=sum.end(); ++i){
312 Node prod = NodeManager::currentNM()->mkNode(kind::MULT, n, *i);
313 nb << prod;
314 }
315 return nb;
316 }
317 Node distributeSum(TNode sum, TNode distribSum){
318 NodeBuilder<> nb(kind::PLUS);
319 for(TNode::iterator i=sum.begin(); i!=sum.end(); ++i){
320 Node dist = distribute(*i, distribSum);
321 for(Node::iterator j=dist.begin(); j!=dist.end(); ++j){
322 nb << *j;
323 }
324 }
325 return nb;
326 }
327
328 Node ArithRewriter::rewriteMult(TNode t){
329
330 using namespace std;
331
332 Rational accumulator(1,1);
333 set<Node> variables;
334 vector<Node> sums;
335
336 //These stacks need to be kept in lock step
337 stack<TNode> mult_iterators_nodes;
338 stack<TNode::const_iterator> mult_iterators_iters;
339
340 mult_iterators_nodes.push(t);
341 mult_iterators_iters.push(t.begin());
342
343 while(!mult_iterators_nodes.empty()){
344 TNode mult = mult_iterators_nodes.top();
345 TNode::const_iterator i = mult_iterators_iters.top();
346
347 mult_iterators_nodes.pop();
348 mult_iterators_iters.pop();
349
350 for(; i != mult.end(); ++i){
351 TNode child = *i;
352 if(child.getKind() == kind::MULT){ //TODO add not rewritten already checks
353 ++i;
354 mult_iterators_nodes.push(mult);
355 mult_iterators_iters.push(i);
356
357 mult_iterators_nodes.push(child);
358 mult_iterators_iters.push(child.begin());
359 break;
360 }
361 Node rewrittenChild = rewrite(child);
362
363 if(rewrittenChild.getMetaKind() == kind::metakind::CONSTANT){//c
364 Rational c = rewrittenChild.getConst<Rational>();
365 accumulator = accumulator * c;
366 if(accumulator == d_constants->d_ZERO){
367 return d_constants->d_ZERO_NODE;
368 }
369 }else if(rewrittenChild.getMetaKind() == kind::metakind::VARIABLE){ //v
370 variables.insert(rewrittenChild);
371 }else{ //(+ c p_1 p_2 ... p_N)
372 sums.push_back(rewrittenChild);
373 }
374 }
375 }
376 // accumulator * (\prod var_i) *(\prod sum_j)
377
378 if(sums.size() == 0){ //accumulator * (\prod var_i)
379 if(variables.size() == 0){ //accumulator
380 return mkRationalNode(accumulator);
381 }else if(variables.size() == 1 && accumulator == d_constants->d_ONE){ // var_1
382 Node var = *(variables.begin());
383 return var;
384 }else{
385 //We need to return (+ c p_1 p_2 ... p_N)
386 //To accomplish this:
387 // let pnf = pnf(accumulator * (\prod var_i)) in (+ 0 pnf)
388 Node pnf = toPnf(accumulator, variables);
389 Node normalForm = NodeManager::currentNM()->mkNode(kind::PLUS, d_constants->d_ZERO_NODE, pnf);
390 return normalForm;
391 }
392 }else{
393 vector<Node>::iterator sum_iter = sums.begin();
394 // \sum t
395 // t \in Q \cup A
396 // where A = lfp {\prod s | s \in Q \cup Variables \cup A}
397 Node distributed = *sum_iter;
398 ++sum_iter;
399 while(sum_iter != sums.end()){
400 Node curr = *sum_iter;
401 distributed = distributeSum(curr, distributed);
402 ++sum_iter;
403 }
404 if(variables.size() >= 1){
405 Node pnf = toPnf(accumulator, variables);
406 distributed = distribute(pnf, distributed);
407 }else{
408 Node constant = mkRationalNode(accumulator);
409 distributed = distribute(constant, distributed);
410 }
411
412 Node nf_distributed = rewrite(distributed);
413 return nf_distributed;
414 }
415 }
416
417 Node ArithRewriter::rewriteConstantDiv(TNode t){
418 Assert(t.getKind()== kind::DIVISION);
419
420 Node reLeft = rewrite(t[0]);
421 Node reRight = rewrite(t[1]);
422 Assert(reLeft.getKind()== kind::CONST_RATIONAL);
423 Assert(reRight.getKind()== kind::CONST_RATIONAL);
424
425 Rational num = reLeft.getConst<Rational>();
426 Rational den = reRight.getConst<Rational>();
427
428 Assert(den != d_constants->d_ZERO);
429
430 Rational div = num / den;
431
432 Node result = mkRationalNode(div);
433
434 return result;
435 }
436
437 Node ArithRewriter::rewriteTerm(TNode t){
438 if(t.getMetaKind() == kind::metakind::CONSTANT){
439 return coerceToRationalNode(t);
440 }else if(t.getMetaKind() == kind::metakind::VARIABLE){
441 return t;
442 }else if(t.getKind() == kind::MULT){
443 return rewriteMult(t);
444 }else if(t.getKind() == kind::PLUS){
445 return rewritePlus(t);
446 }else if(t.getKind() == kind::DIVISION){
447 return rewriteConstantDiv(t);
448 }else if(t.getKind() == kind::MINUS){
449 Node sub = makeSubtractionNode(t[0],t[1]);
450 return rewrite(sub);
451 }else{
452 Unreachable();
453 return Node::null();
454 }
455 }
456
457
458 /**
459 * Given a node in PNF pnf = (* d p_1 p_2 .. p_M) and a rational q != 0
460 * constuct a node equal to q * pnf that is in pnf.
461 *
462 * The claim is that this is always okay:
463 * If d' = q*d, p' = (* d' p_1 p_2 .. p_M) =_{Reals} q * pnf.
464 */
465 Node ArithRewriter::multPnfByNonZero(TNode pnf, Rational& q){
466 Assert(q != d_constants->d_ZERO);
467 //TODO Assert(isPNF(pnf) );
468
469 int M = pnf.getNumChildren()-1;
470 Rational d = pnf[0].getConst<Rational>();
471 Rational new_d = d*q;
472
473
474 NodeBuilder<> mult(kind::MULT);
475 mult << mkRationalNode(new_d);
476 for(int i=1; i<=M; ++i){
477 mult << pnf[i];
478 }
479
480 Node result = mult;
481 return result;
482 }
483
484
485
486 Node ArithRewriter::makeSubtractionNode(TNode l, TNode r){
487 using namespace CVC4::kind;
488 NodeManager* currentNM = NodeManager::currentNM();
489 Node negR = currentNM->mkNode(MULT, d_constants->d_NEGATIVE_ONE_NODE, r);
490 Node diff = currentNM->mkNode(PLUS, l, negR);
491
492 return diff;
493 }
494
495
496 Kind multKind(Kind k, int sgn){
497 using namespace kind;
498
499 if(sgn < 0){
500
501 switch(k){
502 case LT: return GT;
503 case LEQ: return GEQ;
504 case EQUAL: return EQUAL;
505 case GEQ: return LEQ;
506 case GT: return LT;
507 default:
508 Unhandled();
509 }
510 return NULL_EXPR;
511 }else{
512 return k;
513 }
514 }
515
516 Node ArithRewriter::rewrite(TNode n){
517 Debug("arithrewriter") << "Trace rewrite:" << n << std::endl;
518
519 if(n.getAttribute(IsNormal())){
520 return n;
521 }
522
523 Node res;
524
525 if(isRelationOperator(n.getKind())){
526 res = rewriteAtom(n);
527 }else{
528 res = rewriteTerm(n);
529 }
530
531 if(n == res){
532 n.setAttribute(NormalForm(), Node::null());
533 }else{
534 n.setAttribute(NormalForm(), res);
535 }
536
537 return res;
538 }