Added Rational constructors that only take a numerator. The const char* Rational...
[cvc5.git] / src / theory / arith / arith_rewriter.cpp
1
2 #include "theory/arith/arith_rewriter.h"
3 #include "theory/arith/arith_utilities.h"
4 #include "theory/arith/normal.h"
5
6 #include <vector>
7 #include <set>
8 #include <stack>
9
10
11 using namespace CVC4;
12 using namespace CVC4::theory;
13 using namespace CVC4::theory::arith;
14
15
16
17
18
// Forward declaration (defined later in this file): returns the relation kind
// to use after multiplying both sides of a relation of kind k by a value whose
// sign is sgn. A negative sgn mirrors the relation; otherwise k is returned.
Kind multKind(Kind k, int sgn);
20
21 /**
22 * Performs a quick check to see if it is easy to rewrite to
23 * this normal form
24 * v |><| b
25 * Also writes relations with constants on both sides to TRUE or FALSE.
26 * If it can, it returns true and sets res to this value.
27 *
28 * This is for optimizing rewriteAtom() to avoid the more compuationally
29 * expensive general rewriting procedure.
30 *
31 * If simplification is not done, it returns Node::null()
32 */
33 Node almostVarOrConstEqn(TNode atom, Kind k, TNode left, TNode right){
34 Assert(atom.getKind() == k);
35 Assert(isRelationOperator(k));
36 Assert(atom[0] == left);
37 Assert(atom[1] == right);
38 bool leftIsConst = left.getMetaKind() == kind::metakind::CONSTANT;
39 bool rightIsConst = right.getMetaKind() == kind::metakind::CONSTANT;
40
41 bool leftIsVar = left.getMetaKind() == kind::metakind::VARIABLE;
42 bool rightIsVar = right.getMetaKind() == kind::metakind::VARIABLE;
43
44 if(leftIsConst && rightIsConst){
45 Rational lc = coerceToRational(left);
46 Rational rc = coerceToRational(right);
47 bool res = evaluateConstantPredicate(k,lc, rc);
48 return mkBoolNode(res);
49 }else if(leftIsVar && rightIsConst){
50 if(right.getKind() == kind::CONST_RATIONAL){
51 return atom;
52 }else{
53 return NodeManager::currentNM()->mkNode(k,left,coerceToRationalNode(right));
54 }
55 }else if(leftIsConst && rightIsVar){
56 if(left.getKind() == kind::CONST_RATIONAL){
57 return NodeManager::currentNM()->mkNode(multKind(k,-1),right,left);
58 }else{
59 Node q_left = coerceToRationalNode(left);
60 return NodeManager::currentNM()->mkNode(multKind(k,-1),right,q_left);
61 }
62 }
63
64 return Node::null();
65 }
66
/**
 * Rewrites the relational atom (left |><| right), |><| being atom's kind k,
 * into one of the following shapes:
 *  - whatever almostVarOrConstEqn() returns, when it recognizes the atom;
 *  - a Boolean constant, when left - right rewrites to a constant c
 *    (c |><| 0 is evaluated);
 *  - (k v 0), when left - right rewrites to a single variable v;
 *  - otherwise, with left - right rewritten to (+ c p_1 ... p_N):
 *      (k' v b)                when N = 1 and p_1 = (* d v),
 *      (k' p b)                when N = 1 and p_1 has several variables,
 *      (k' (+ p_1 ... p_N) b)  when N > 1,
 *    where every p_i has been scaled by 1/d (d = p_1's coefficient), so p_1
 *    ends up with coefficient 1, b = -c/d, and k' = multKind(k, sgn(1/d)).
 */
Node ArithRewriter::rewriteAtomCore(TNode atom){

  Kind k = atom.getKind();
  Assert(isRelationOperator(k));

  // left |><| right
  TNode left = atom[0];
  TNode right = atom[1];

  // Fast path for constant/constant, variable/constant and constant/variable.
  Node nf = almostVarOrConstEqn(atom, k,left,right);
  if(nf != Node::null() ){
    return nf;
  }


  //Transform this to: (left- right) |><| 0
  Node diff = makeSubtractionNode(left, right);

  Node rewritten = rewrite(diff);
  // rewritten =_{Reals} left - right => rewritten |><| 0

  if(rewritten.getMetaKind() == kind::metakind::CONSTANT){
    // Case 1 rewritten : c
    Rational c = rewritten.getConst<Rational>();
    bool res = evaluateConstantPredicate(k, c, d_constants->d_ZERO);
    nf = mkBoolNode(res);
  }else if(rewritten.getMetaKind() == kind::metakind::VARIABLE){
    // Case 2 rewritten : v
    nf = NodeManager::currentNM()->mkNode(k, rewritten, d_constants->d_ZERO_NODE);
  }else{
    // Case 3 rewritten : (+ c p_1 p_2 ... p_N) | not(N=1 and c=0 and p_1.d=1)
    // c + sum(p_i) |><| 0  <=>  sum(p_i) |><| -c. Both sides are then
    // multiplied by 1/d, where d is p_1's coefficient.
    // NOTE(review): assumes rewritten is a PLUS whose children 1..N are pnfs
    // (* d v_1 ... v_M) and that p_1's coefficient is nonzero; d.inverse()
    // of zero is not meaningful. Confirm rewritePlus never emits a zero pnf.
    Rational c = rewritten[0].getConst<Rational>();
    c = -c;
    TNode p_1 = rewritten[1];
    Rational d = p_1[0].getConst<Rational>();
    d = d.inverse();
    c = c * d; // b = -c / (p_1's coefficient)
    Node newRight = mkRationalNode(c);
    // Multiplying by a negative factor reverses the direction of the relation.
    Kind newKind = multKind(k, d.sgn());
    int N = rewritten.getNumChildren() - 1;

    if(N==1){
      int M = p_1.getNumChildren()-1;
      if(M == 1){ // v |><| b
        // (* d v) / d is just v.
        TNode v = p_1[1];
        nf = NodeManager::currentNM()->mkNode(newKind, v, newRight);
      }else{ // p |><| b
        Node newLeft = multPnfByNonZero(p_1, d);
        nf = NodeManager::currentNM()->mkNode(newKind, newLeft, newRight);
      }
    }else{ //(+ p_1 .. p_N) |><| b
      NodeBuilder<> plus(kind::PLUS);
      for(int i=1; i<=N; ++i){
        TNode p_i = rewritten[i];
        plus << multPnfByNonZero(p_i, d);
      }
      Node newLeft = plus;
      nf = NodeManager::currentNM()->mkNode(newKind, newLeft, newRight);
    }
  }

  return nf;
}
130
131 Node ArithRewriter::rewriteAtom(TNode atom){
132 Node rewritten = rewriteAtomCore(atom);
133 if(rewritten.getKind() == kind::LT){
134 Node geq = NodeManager::currentNM()->mkNode(kind::GEQ, rewritten[0], rewritten[1]);
135 return NodeManager::currentNM()->mkNode(kind::NOT, geq);
136 }else if(rewritten.getKind() == kind::GT){
137 Node leq = NodeManager::currentNM()->mkNode(kind::LEQ, rewritten[0], rewritten[1]);
138 return NodeManager::currentNM()->mkNode(kind::NOT, leq);
139 }else{
140 return rewritten;
141 }
142 }
143
144
145 /* cmp( (* d v_1 v_2 ... v_M), (* d' v'_1 v'_2 ... v'_M'):
146 * if(M == M'):
147 * then tupleCompare(v_i, v'_i)
148 * else M -M'
149 */
150 struct pnfLessThan {
151 bool operator()(Node p0, Node p1) {
152 int p0_M = p0.getNumChildren() -1;
153 int p1_M = p1.getNumChildren() -1;
154 if(p0_M == p1_M){
155 for(int i=1; i<= p0_M; ++i){
156 if(p0[i] != p1[i]){
157 return p0[i] < p1[i];
158 }
159 }
160 return false; //p0 == p1 in this order
161 }else{
162 return p0_M < p1_M;
163 }
164 }
165 };
166
167 Node addPnfs(TNode p0, TNode p1){
168 //TODO asserts
169 Rational c0 = p0[0].getConst<Rational>();
170 Rational c1 = p1[0].getConst<Rational>();
171
172 int M = p0.getNumChildren();
173
174 Rational addedC = c0 + c1;
175 Node newC = mkRationalNode(addedC);
176 NodeBuilder<> nb(kind::PLUS);
177 nb << newC;
178 for(int i=1; i <= M; ++i){
179 nb << p0[i];
180 }
181 Node newPnf = nb;
182 return newPnf;
183 }
184
185 void sortAndCombineCoefficients(std::vector<Node>& pnfs){
186 using namespace std;
187
188 /* combined contains exactly 1 representative per for each pnf.
189 * This is maintained by combining the coefficients for pnfs.
190 * that is equal according to pnfLessThan.
191 */
192 typedef set<Node, pnfLessThan> PnfSet;
193 PnfSet combined;
194
195 for(vector<Node>::iterator i=pnfs.begin(); i != pnfs.end(); ++i){
196 Node pnf = *i;
197 PnfSet::iterator pos = combined.find(pnf);
198
199 if(pos == combined.end()){
200 combined.insert(pnf);
201 }else{
202 Node current = *pos;
203 Node sum = addPnfs(pnf, current);
204 combined.erase(pos);
205 combined.insert(sum);
206 }
207 }
208 pnfs.clear();
209 for(PnfSet::iterator i=combined.begin(); i != combined.end(); ++i){
210 Node pnf = *i;
211 pnfs.push_back(pnf);
212 }
213 }
214
215 Node ArithRewriter::var2pnf(TNode variable){
216 return NodeManager::currentNM()->mkNode(kind::MULT,d_constants->d_ONE_NODE,variable);
217 }
218
219 Node ArithRewriter::rewritePlus(TNode t){
220 using namespace std;
221
222 Rational accumulator;
223 vector<Node> pnfs;
224
225 for(TNode::iterator i = t.begin(); i!= t.end(); ++i){
226 TNode child = *i;
227 Node rewrittenChild = rewrite(child);
228
229 if(rewrittenChild.getMetaKind() == kind::metakind::CONSTANT){//c
230 Rational c = rewrittenChild.getConst<Rational>();
231 accumulator = accumulator + c;
232 }else if(rewrittenChild.getMetaKind() == kind::metakind::VARIABLE){ //v
233 Node pnf = var2pnf(rewrittenChild);
234 pnfs.push_back(pnf);
235 }else{ //(+ c p_1 p_2 ... p_N)
236 Rational c = rewrittenChild[0].getConst<Rational>();
237 accumulator = accumulator + c;
238 int N = rewrittenChild.getNumChildren() - 1;
239 for(int i=1; i<=N; ++i){
240 TNode pnf = rewrittenChild[i];
241 pnfs.push_back(pnf);
242 }
243 }
244 }
245 sortAndCombineCoefficients(pnfs);
246 if(pnfs.size() == 0){
247 return mkRationalNode(accumulator);
248 }
249
250 // pnfs.size() >= 1
251
252 //Enforce not(N=1 and c=0 and p_1.d=1)
253 if(pnfs.size() == 1){
254 Node p_1 = *(pnfs.begin());
255 if(p_1[0].getConst<Rational>() == d_constants->d_ONE){
256 if(accumulator == d_constants->d_ZERO){ // 0 + (* 1 var) |-> var
257 Node var = p_1[1];
258 return var;
259 }
260 }
261 }
262
263 //We must be in this case
264 //(+ c p_1 p_2 ... p_N) | not(N=1 and c=0 and p_1.d=1)
265
266 NodeBuilder<> nb(kind::PLUS);
267 nb << mkRationalNode(accumulator);
268 Debug("arithrewrite") << mkRationalNode(accumulator) << std::endl;
269 for(vector<Node>::iterator i = pnfs.begin(); i != pnfs.end(); ++i){
270 nb << *i;
271 Debug("arithrewrite") << (*i) << std::endl;
272
273 }
274
275 Node normalForm = nb;
276 return normalForm;
277 }
278
279 //Does not enforce
280 //5) v_i are of metakind VARIABLE,
281 //6) v_i are in increasing (not strict) nodeOrder,
282 Node toPnf(Rational& c, std::set<Node>& variables){
283 NodeBuilder<> nb(kind::MULT);
284 nb << mkRationalNode(c);
285
286 for(std::set<Node>::iterator i = variables.begin(); i != variables.end(); ++i){
287 nb << *i;
288 }
289 Node pnf = nb;
290 return pnf;
291 }
292
293 Node distribute(TNode n, TNode sum){
294 NodeBuilder<> nb(kind::PLUS);
295 for(TNode::iterator i=sum.begin(); i!=sum.end(); ++i){
296 Node prod = NodeManager::currentNM()->mkNode(kind::MULT, n, *i);
297 nb << prod;
298 }
299 return nb;
300 }
301 Node distributeSum(TNode sum, TNode distribSum){
302 NodeBuilder<> nb(kind::PLUS);
303 for(TNode::iterator i=sum.begin(); i!=sum.end(); ++i){
304 Node dist = distribute(*i, distribSum);
305 for(Node::iterator j=dist.begin(); j!=dist.end(); ++j){
306 nb << *j;
307 }
308 }
309 return nb;
310 }
311
312 Node ArithRewriter::rewriteMult(TNode t){
313
314 using namespace std;
315
316 Rational accumulator(1,1);
317 set<Node> variables;
318 vector<Node> sums;
319
320 //These stacks need to be kept in lock step
321 stack<TNode> mult_iterators_nodes;
322 stack<unsigned> mult_iterators_iters;
323
324 mult_iterators_nodes.push(t);
325 mult_iterators_iters.push(0);
326
327 while(!mult_iterators_nodes.empty()){
328 TNode mult = mult_iterators_nodes.top();
329 unsigned i = mult_iterators_iters.top();
330
331 mult_iterators_nodes.pop();
332 mult_iterators_iters.pop();
333
334
335 for(; i < mult.getNumChildren(); ++i){
336 TNode child = mult[i];
337 if(child.getKind() == kind::MULT){ //TODO add not rewritten already checks
338 ++i;
339 mult_iterators_nodes.push(mult);
340 mult_iterators_iters.push(i);
341
342 mult_iterators_nodes.push(child);
343 mult_iterators_iters.push(0);
344 break;
345 }
346 Node rewrittenChild = rewrite(child);
347
348 if(rewrittenChild.getMetaKind() == kind::metakind::CONSTANT){//c
349 Rational c = rewrittenChild.getConst<Rational>();
350 accumulator = accumulator * c;
351 if(accumulator == d_constants->d_ZERO){
352 return d_constants->d_ZERO_NODE;
353 }
354 }else if(rewrittenChild.getMetaKind() == kind::metakind::VARIABLE){ //v
355 variables.insert(rewrittenChild);
356 }else{ //(+ c p_1 p_2 ... p_N)
357 sums.push_back(rewrittenChild);
358 }
359 }
360 }
361 // accumulator * (\prod var_i) *(\prod sum_j)
362
363 if(sums.size() == 0){ //accumulator * (\prod var_i)
364 if(variables.size() == 0){ //accumulator
365 return mkRationalNode(accumulator);
366 }else if(variables.size() == 1 && accumulator == d_constants->d_ONE){ // var_1
367 Node var = *(variables.begin());
368 return var;
369 }else{
370 //We need to return (+ c p_1 p_2 ... p_N)
371 //To accomplish this:
372 // let pnf = pnf(accumulator * (\prod var_i)) in (+ 0 pnf)
373 Node pnf = toPnf(accumulator, variables);
374 Node normalForm = NodeManager::currentNM()->mkNode(kind::PLUS, d_constants->d_ZERO_NODE, pnf);
375 return normalForm;
376 }
377 }else{
378 vector<Node>::iterator sum_iter = sums.begin();
379 // \sum t
380 // t \in Q \cup A
381 // where A = lfp {\prod s | s \in Q \cup Variables \cup A}
382 Node distributed = *sum_iter;
383 ++sum_iter;
384 while(sum_iter != sums.end()){
385 Node curr = *sum_iter;
386 distributed = distributeSum(curr, distributed);
387 ++sum_iter;
388 }
389 if(variables.size() >= 1){
390 Node pnf = toPnf(accumulator, variables);
391 distributed = distribute(pnf, distributed);
392 }else{
393 Node constant = mkRationalNode(accumulator);
394 distributed = distribute(constant, distributed);
395 }
396
397 Node nf_distributed = rewrite(distributed);
398 return nf_distributed;
399 }
400 }
401
402 Node ArithRewriter::rewriteConstantDiv(TNode t){
403 Assert(t.getKind()== kind::DIVISION);
404
405 Node reLeft = rewrite(t[0]);
406 Node reRight = rewrite(t[1]);
407 Assert(reLeft.getKind()== kind::CONST_RATIONAL);
408 Assert(reRight.getKind()== kind::CONST_RATIONAL);
409
410 Rational num = reLeft.getConst<Rational>();
411 Rational den = reRight.getConst<Rational>();
412
413 Assert(den != d_constants->d_ZERO);
414
415 Rational div = num / den;
416
417 Node result = mkRationalNode(div);
418
419 return result;
420 }
421
422 Node ArithRewriter::rewriteTerm(TNode t){
423 if(t.getMetaKind() == kind::metakind::CONSTANT){
424 return coerceToRationalNode(t);
425 }else if(t.getMetaKind() == kind::metakind::VARIABLE){
426 return t;
427 }else if(t.getKind() == kind::MULT){
428 return rewriteMult(t);
429 }else if(t.getKind() == kind::PLUS){
430 return rewritePlus(t);
431 }else if(t.getKind() == kind::DIVISION){
432 return rewriteConstantDiv(t);
433 }else if(t.getKind() == kind::MINUS){
434 Node sub = makeSubtractionNode(t[0],t[1]);
435 return rewrite(sub);
436 }else{
437 Unreachable();
438 return Node::null();
439 }
440 }
441
442
443 /**
444 * Given a node in PNF pnf = (* d p_1 p_2 .. p_M) and a rational q != 0
445 * constuct a node equal to q * pnf that is in pnf.
446 *
447 * The claim is that this is always okay:
448 * If d' = q*d, p' = (* d' p_1 p_2 .. p_M) =_{Reals} q * pnf.
449 */
450 Node ArithRewriter::multPnfByNonZero(TNode pnf, Rational& q){
451 Assert(q != d_constants->d_ZERO);
452 //TODO Assert(isPNF(pnf) );
453
454 int M = pnf.getNumChildren()-1;
455 Rational d = pnf[0].getConst<Rational>();
456 Rational new_d = d*q;
457
458
459 NodeBuilder<> mult(kind::MULT);
460 mult << mkRationalNode(new_d);
461 for(int i=1; i<=M; ++i){
462 mult << pnf[i];
463 }
464
465 Node result = mult;
466 return result;
467 }
468
469
470
471 Node ArithRewriter::makeSubtractionNode(TNode l, TNode r){
472 using namespace CVC4::kind;
473 NodeManager* currentNM = NodeManager::currentNM();
474 Node negR = currentNM->mkNode(MULT, d_constants->d_NEGATIVE_ONE_NODE, r);
475 Node diff = currentNM->mkNode(PLUS, l, negR);
476
477 return diff;
478 }
479
480
481 Kind multKind(Kind k, int sgn){
482 using namespace kind;
483
484 if(sgn < 0){
485
486 switch(k){
487 case LT: return GT;
488 case LEQ: return GEQ;
489 case EQUAL: return EQUAL;
490 case GEQ: return LEQ;
491 case GT: return LT;
492 default:
493 Unhandled();
494 }
495 return NULL_EXPR;
496 }else{
497 return k;
498 }
499 }
500
501 Node ArithRewriter::rewrite(TNode n){
502 Debug("arithrewriter") << "Trace rewrite:" << n << std::endl;
503
504 if(n.getAttribute(IsNormal())){
505 return n;
506 }
507
508 Node res;
509
510 if(isRelationOperator(n.getKind())){
511 res = rewriteAtom(n);
512 }else{
513 res = rewriteTerm(n);
514 }
515
516 if(n == res){
517 n.setAttribute(NormalForm(), Node::null());
518 }else{
519 n.setAttribute(NormalForm(), res);
520 }
521
522 return res;
523 }