Partial merge from kind-backend branch, including Minisat and CNF work to
[cvc5.git] / src / theory / arith / arith_rewriter.cpp
1 /********************* */
2 /*! \file arith_rewriter.cpp
3 ** \verbatim
4 ** Original author: taking
5 ** Major contributors: none
6 ** Minor contributors (to current version): mdeters, dejan
7 ** This file is part of the CVC4 prototype.
8 ** Copyright (c) 2009, 2010, 2011 The Analysis of Computer Systems Group (ACSys)
9 ** Courant Institute of Mathematical Sciences
10 ** New York University
11 ** See the file COPYING in the top-level source directory for licensing
12 ** information.\endverbatim
13 **
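** \brief Rewriter for arithmetic terms and atoms.
**
** Implements the pre- and post-rewrite passes of ArithRewriter.
**
** - MINUS, UMINUS and division by a constant are reduced to PLUS/MULT form.
** - Sums and products are normalized through Polynomial.
** - Comparisons are brought into Comparison normal form with a constant
**   right-hand side.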
17 ** \todo document this file
18 **/
19
20 #include "theory/theory.h"
21 #include "theory/arith/normal_form.h"
22 #include "theory/arith/arith_rewriter.h"
23 #include "theory/arith/arith_utilities.h"
24
25 #include <vector>
26 #include <set>
27 #include <stack>
28
29 using namespace CVC4;
30 using namespace CVC4::theory;
31 using namespace CVC4::theory::arith;
32
33 bool isVariable(TNode t){
34 return t.getMetaKind() == kind::metakind::VARIABLE;
35 }
36
37 RewriteResponse ArithRewriter::rewriteConstant(TNode t){
38 Assert(t.getMetaKind() == kind::metakind::CONSTANT);
39 Node val = coerceToRationalNode(t);
40
41 return RewriteResponse(REWRITE_DONE, val);
42 }
43
44 RewriteResponse ArithRewriter::rewriteVariable(TNode t){
45 Assert(isVariable(t));
46
47 return RewriteResponse(REWRITE_DONE, t);
48 }
49
50 RewriteResponse ArithRewriter::rewriteMinus(TNode t, bool pre){
51 Assert(t.getKind()== kind::MINUS);
52
53 if(t[0] == t[1]){
54 Rational zero(0);
55 Node zeroNode = mkRationalNode(zero);
56 return RewriteResponse(REWRITE_DONE, zeroNode);
57 }
58
59 Node noMinus = makeSubtractionNode(t[0],t[1]);
60 if(pre){
61 return RewriteResponse(REWRITE_DONE, noMinus);
62 }else{
63 return RewriteResponse(REWRITE_AGAIN_FULL, noMinus);
64 }
65 }
66
67 RewriteResponse ArithRewriter::rewriteUMinus(TNode t, bool pre){
68 Assert(t.getKind()== kind::UMINUS);
69
70 Node noUminus = makeUnaryMinusNode(t[0]);
71 if(pre)
72 return RewriteResponse(REWRITE_DONE, noUminus);
73 else
74 return RewriteResponse(REWRITE_AGAIN, noUminus);
75 }
76
77 RewriteResponse ArithRewriter::preRewriteTerm(TNode t){
78 if(t.getMetaKind() == kind::metakind::CONSTANT){
79 return rewriteConstant(t);
80 }else if(isVariable(t)){
81 return rewriteVariable(t);
82 }else if(t.getKind() == kind::MINUS){
83 return rewriteMinus(t, true);
84 }else if(t.getKind() == kind::UMINUS){
85 return rewriteUMinus(t, true);
86 }else if(t.getKind() == kind::DIVISION){
87 if(t[0].getKind()== kind::CONST_RATIONAL){
88 return rewriteDivByConstant(t, true);
89 }else{
90 return RewriteResponse(REWRITE_DONE, t);
91 }
92 }else if(t.getKind() == kind::PLUS){
93 return preRewritePlus(t);
94 }else if(t.getKind() == kind::MULT){
95 return preRewriteMult(t);
96 }else if(t.getKind() == kind::INTS_DIVISION){
97 Integer intOne(1);
98 if(t[1].getKind()== kind::CONST_INTEGER && t[1].getConst<Integer>() == intOne){
99 return RewriteResponse(REWRITE_AGAIN, t[0]);
100 }else{
101 return RewriteResponse(REWRITE_DONE, t);
102 }
103 }else if(t.getKind() == kind::INTS_MODULUS){
104 Integer intOne(1);
105 if(t[1].getKind()== kind::CONST_INTEGER && t[1].getConst<Integer>() == intOne){
106 Integer intZero(0);
107 return RewriteResponse(REWRITE_AGAIN, mkIntegerNode(intZero));
108 }else{
109 return RewriteResponse(REWRITE_DONE, t);
110 }
111 }else{
112 Unreachable();
113 }
114 }
115 RewriteResponse ArithRewriter::postRewriteTerm(TNode t){
116 if(t.getMetaKind() == kind::metakind::CONSTANT){
117 return rewriteConstant(t);
118 }else if(isVariable(t)){
119 return rewriteVariable(t);
120 }else if(t.getKind() == kind::MINUS){
121 return rewriteMinus(t, false);
122 }else if(t.getKind() == kind::UMINUS){
123 return rewriteUMinus(t, false);
124 }else if(t.getKind() == kind::DIVISION){
125 return rewriteDivByConstant(t, false);
126 }else if(t.getKind() == kind::PLUS){
127 return postRewritePlus(t);
128 }else if(t.getKind() == kind::MULT){
129 return postRewriteMult(t);
130 }else if(t.getKind() == kind::INTS_DIVISION){
131 return RewriteResponse(REWRITE_DONE, t);
132 }else if(t.getKind() == kind::INTS_MODULUS){
133 return RewriteResponse(REWRITE_DONE, t);
134 }else{
135 Unreachable();
136 }
137 }
138
139 RewriteResponse ArithRewriter::preRewriteMult(TNode t){
140 Assert(t.getKind()== kind::MULT);
141
142 // Rewrite multiplications with a 0 argument and to 0
143 Integer intZero;
144
145 Rational qZero(0);
146
147 for(TNode::iterator i = t.begin(); i != t.end(); ++i) {
148 if((*i).getKind() == kind::CONST_RATIONAL) {
149 if((*i).getConst<Rational>() == qZero) {
150 return RewriteResponse(REWRITE_DONE, mkRationalNode(qZero));
151 }
152 } else if((*i).getKind() == kind::CONST_INTEGER) {
153 if((*i).getConst<Integer>() == intZero) {
154 if(t.getType().isInteger()) {
155 return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(intZero));
156 } else {
157 return RewriteResponse(REWRITE_DONE, mkRationalNode(qZero));
158 }
159 }
160 }
161 }
162 return RewriteResponse(REWRITE_DONE, t);
163 }
164 RewriteResponse ArithRewriter::preRewritePlus(TNode t){
165 Assert(t.getKind()== kind::PLUS);
166
167 return RewriteResponse(REWRITE_DONE, t);
168 }
169
170 RewriteResponse ArithRewriter::postRewritePlus(TNode t){
171 Assert(t.getKind()== kind::PLUS);
172
173 Polynomial res = Polynomial::mkZero();
174
175 for(TNode::iterator i = t.begin(), end = t.end(); i != end; ++i){
176 Node curr = *i;
177 Polynomial currPoly = Polynomial::parsePolynomial(curr);
178
179 res = res + currPoly;
180 }
181
182 return RewriteResponse(REWRITE_DONE, res.getNode());
183 }
184
185 RewriteResponse ArithRewriter::postRewriteMult(TNode t){
186 Assert(t.getKind()== kind::MULT);
187
188 Polynomial res = Polynomial::mkOne();
189
190 for(TNode::iterator i = t.begin(), end = t.end(); i != end; ++i){
191 Node curr = *i;
192 Polynomial currPoly = Polynomial::parsePolynomial(curr);
193
194 res = res * currPoly;
195 }
196
197 return RewriteResponse(REWRITE_DONE, res.getNode());
198 }
199
200 RewriteResponse ArithRewriter::postRewriteAtomConstantRHS(TNode t){
201 TNode left = t[0];
202 TNode right = t[1];
203
204 Comparison cmp = Comparison::mkNormalComparison(t.getKind(), Polynomial::parsePolynomial(left), Constant(right));
205
206 Assert(cmp.isNormalForm());
207 return RewriteResponse(REWRITE_DONE, cmp.getNode());
208
209
210 // Comparison cmp = Comparison::mkComparison(t.getKind(), Polynomial::parsePolynomial(left), Constant(right));
211
212 // if(cmp.isBoolean()){
213 // return RewriteResponse(REWRITE_DONE, cmp.getNode());
214 // }
215
216 // if(cmp.getLeft().containsConstant()){
217 // Monomial constantHead = cmp.getLeft().getHead();
218 // Assert(constantHead.isConstant());
219
220 // Constant constant = constantHead.getConstant();
221
222 // Constant negativeConstantHead = -constant;
223
224 // cmp = cmp.addConstant(negativeConstantHead);
225 // }
226 // Assert(!cmp.getLeft().containsConstant());
227
228 // if(!cmp.getLeft().getHead().coefficientIsOne()){
229 // Monomial constantHead = cmp.getLeft().getHead();
230 // Assert(!constantHead.isConstant());
231 // Constant constant = constantHead.getConstant();
232
233 // Constant inverse = Constant::mkConstant(constant.getValue().inverse());
234
235 // cmp = cmp.multiplyConstant(inverse);
236 // }
237 // Assert(cmp.getLeft().getHead().coefficientIsOne());
238
239 // Assert(cmp.isBoolean() || cmp.isNormalForm());
240 // return RewriteResponse(REWRITE_DONE, cmp.getNode());
241 }
242
243 RewriteResponse ArithRewriter::postRewriteAtom(TNode atom){
244 // left |><| right
245 TNode left = atom[0];
246 TNode right = atom[1];
247
248 if(right.getMetaKind() == kind::metakind::CONSTANT){
249 return postRewriteAtomConstantRHS(atom);
250 }else{
251 //Transform this to: (left - right) |><| 0
252 Node diff = makeSubtractionNode(left, right);
253 Rational qZero(0);
254 Node reduction = NodeManager::currentNM()->mkNode(atom.getKind(), diff, mkRationalNode(qZero));
255 return RewriteResponse(REWRITE_AGAIN_FULL, reduction);
256 }
257 }
258
259 RewriteResponse ArithRewriter::preRewriteAtom(TNode atom){
260 Assert(isAtom(atom));
261
262 NodeManager* currNM = NodeManager::currentNM();
263
264 if(atom.getKind() == kind::EQUAL) {
265 if(atom[0] == atom[1]) {
266 return RewriteResponse(REWRITE_DONE, currNM->mkConst(true));
267 }
268 }
269
270 Node reduction = atom;
271
272 if(atom[1].getMetaKind() != kind::metakind::CONSTANT) {
273 // left |><| right
274 TNode left = atom[0];
275 TNode right = atom[1];
276
277 //Transform this to: (left - right) |><| 0
278 Node diff = makeSubtractionNode(left, right);
279 Rational qZero(0);
280 reduction = currNM->mkNode(atom.getKind(), diff, mkRationalNode(qZero));
281 }
282
283 if(reduction.getKind() == kind::GT){
284 Node leq = currNM->mkNode(kind::LEQ, reduction[0], reduction[1]);
285 reduction = currNM->mkNode(kind::NOT, leq);
286 }else if(reduction.getKind() == kind::LT){
287 Node geq = currNM->mkNode(kind::GEQ, reduction[0], reduction[1]);
288 reduction = currNM->mkNode(kind::NOT, geq);
289 }
290 /* BREADCRUMB : Move this rewrite into preprocessing
291 else if( Options::current()->rewriteArithEqualities && reduction.getKind() == kind::EQUAL){
292 Node geq = currNM->mkNode(kind::GEQ, reduction[0], reduction[1]);
293 Node leq = currNM->mkNode(kind::LEQ, reduction[0], reduction[1]);
294 reduction = currNM->mkNode(kind::AND, geq, leq);
295 }
296 */
297
298
299 return RewriteResponse(REWRITE_DONE, reduction);
300 }
301
302 RewriteResponse ArithRewriter::postRewrite(TNode t){
303 if(isTerm(t)){
304 RewriteResponse response = postRewriteTerm(t);
305 if(Debug.isOn("arith::rewriter") && response.status == REWRITE_DONE) {
306 Polynomial::parsePolynomial(response.node);
307 }
308 return response;
309 }else if(isAtom(t)){
310 RewriteResponse response = postRewriteAtom(t);
311 if(Debug.isOn("arith::rewriter") && response.status == REWRITE_DONE) {
312 Comparison::parseNormalForm(response.node);
313 }
314 return response;
315 }else{
316 Unreachable();
317 return RewriteResponse(REWRITE_DONE, Node::null());
318 }
319 }
320
321 RewriteResponse ArithRewriter::preRewrite(TNode t){
322 if(isTerm(t)){
323 return preRewriteTerm(t);
324 }else if(isAtom(t)){
325 return preRewriteAtom(t);
326 }else{
327 Unreachable();
328 return RewriteResponse(REWRITE_DONE, Node::null());
329 }
330 }
331
332 Node ArithRewriter::makeUnaryMinusNode(TNode n){
333 Rational qNegOne(-1);
334 return NodeManager::currentNM()->mkNode(kind::MULT, mkRationalNode(qNegOne),n);
335 }
336
337 Node ArithRewriter::makeSubtractionNode(TNode l, TNode r){
338 Node negR = makeUnaryMinusNode(r);
339 Node diff = NodeManager::currentNM()->mkNode(kind::PLUS, l, negR);
340
341 return diff;
342 }
343
344 RewriteResponse ArithRewriter::rewriteDivByConstant(TNode t, bool pre){
345 Assert(t.getKind()== kind::DIVISION);
346
347 Node left = t[0];
348 Node right = t[1];
349 Assert(right.getKind()== kind::CONST_RATIONAL);
350
351
352 const Rational& den = right.getConst<Rational>();
353
354 Assert(den != Rational(0));
355
356 Rational div = den.inverse();
357
358 Node result = mkRationalNode(div);
359
360 Node mult = NodeManager::currentNM()->mkNode(kind::MULT,left,result);
361 if(pre){
362 return RewriteResponse(REWRITE_DONE, mult);
363 }else{
364 return RewriteResponse(REWRITE_AGAIN, mult);
365 }
366 }