Merges branches/arithmetic/atom-database r2979 through 3247 into trunk. Below is...
[cvc5.git] / src / theory / arith / arith_rewriter.cpp
1 /********************* */
2 /*! \file arith_rewriter.cpp
3 ** \verbatim
4 ** Original author: taking
5 ** Major contributors: none
6 ** Minor contributors (to current version): mdeters, dejan
7 ** This file is part of the CVC4 prototype.
8 ** Copyright (c) 2009, 2010, 2011 The Analysis of Computer Systems Group (ACSys)
9 ** Courant Institute of Mathematical Sciences
10 ** New York University
11 ** See the file COPYING in the top-level source directory for licensing
12 ** information.\endverbatim
13 **
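14 ** \brief Rewriter for the theory of arithmetic.
15 **
16 ** Pre- and post-rewrites arithmetic terms (constants, variables, +, -, unary -, *, /, div, mod) and relational atoms; post-rewriting produces polynomial and comparison normal forms.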
17 ** \todo document this file
18 **/
19
20 #include "theory/theory.h"
21 #include "theory/arith/normal_form.h"
22 #include "theory/arith/arith_rewriter.h"
23 #include "theory/arith/arith_utilities.h"
24
25 #include <vector>
26 #include <set>
27 #include <stack>
28
29 namespace CVC4 {
30 namespace theory {
31 namespace arith {
32
33 bool isVariable(TNode t){
34 return t.getMetaKind() == kind::metakind::VARIABLE;
35 }
36
37 bool ArithRewriter::isAtom(TNode n) {
38 return arith::isRelationOperator(n.getKind());
39 }
40
41 RewriteResponse ArithRewriter::rewriteConstant(TNode t){
42 Assert(t.getMetaKind() == kind::metakind::CONSTANT);
43 Node val = coerceToRationalNode(t);
44
45 return RewriteResponse(REWRITE_DONE, val);
46 }
47
48 RewriteResponse ArithRewriter::rewriteVariable(TNode t){
49 Assert(isVariable(t));
50
51 return RewriteResponse(REWRITE_DONE, t);
52 }
53
54 RewriteResponse ArithRewriter::rewriteMinus(TNode t, bool pre){
55 Assert(t.getKind()== kind::MINUS);
56
57 if(pre){
58 if(t[0] == t[1]){
59 Rational zero(0);
60 Node zeroNode = mkRationalNode(zero);
61 return RewriteResponse(REWRITE_DONE, zeroNode);
62 }else{
63 Node noMinus = makeSubtractionNode(t[0],t[1]);
64 return RewriteResponse(REWRITE_DONE, noMinus);
65 }
66 }else{
67 Polynomial minuend = Polynomial::parsePolynomial(t[0]);
68 Polynomial subtrahend = Polynomial::parsePolynomial(t[0]);
69 Polynomial diff = minuend - subtrahend;
70 return RewriteResponse(REWRITE_DONE, diff.getNode());
71 }
72 }
73
74 RewriteResponse ArithRewriter::rewriteUMinus(TNode t, bool pre){
75 Assert(t.getKind()== kind::UMINUS);
76
77 Node noUminus = makeUnaryMinusNode(t[0]);
78 if(pre)
79 return RewriteResponse(REWRITE_DONE, noUminus);
80 else
81 return RewriteResponse(REWRITE_AGAIN, noUminus);
82 }
83
84 RewriteResponse ArithRewriter::preRewriteTerm(TNode t){
85 if(t.getMetaKind() == kind::metakind::CONSTANT){
86 return rewriteConstant(t);
87 }else if(isVariable(t)){
88 return rewriteVariable(t);
89 }else if(t.getKind() == kind::MINUS){
90 return rewriteMinus(t, true);
91 }else if(t.getKind() == kind::UMINUS){
92 return rewriteUMinus(t, true);
93 }else if(t.getKind() == kind::DIVISION){
94 if(t[0].getKind()== kind::CONST_RATIONAL){
95 return rewriteDivByConstant(t, true);
96 }else{
97 return RewriteResponse(REWRITE_DONE, t);
98 }
99 }else if(t.getKind() == kind::PLUS){
100 return preRewritePlus(t);
101 }else if(t.getKind() == kind::MULT){
102 return preRewriteMult(t);
103 }else if(t.getKind() == kind::INTS_DIVISION){
104 Integer intOne(1);
105 if(t[1].getKind()== kind::CONST_INTEGER && t[1].getConst<Integer>() == intOne){
106 return RewriteResponse(REWRITE_AGAIN, t[0]);
107 }else{
108 return RewriteResponse(REWRITE_DONE, t);
109 }
110 }else if(t.getKind() == kind::INTS_MODULUS){
111 Integer intOne(1);
112 if(t[1].getKind()== kind::CONST_INTEGER && t[1].getConst<Integer>() == intOne){
113 Integer intZero(0);
114 return RewriteResponse(REWRITE_AGAIN, mkIntegerNode(intZero));
115 }else{
116 return RewriteResponse(REWRITE_DONE, t);
117 }
118 }else{
119 Unreachable();
120 }
121 }
122 RewriteResponse ArithRewriter::postRewriteTerm(TNode t){
123 if(t.getMetaKind() == kind::metakind::CONSTANT){
124 return rewriteConstant(t);
125 }else if(isVariable(t)){
126 return rewriteVariable(t);
127 }else if(t.getKind() == kind::MINUS){
128 return rewriteMinus(t, false);
129 }else if(t.getKind() == kind::UMINUS){
130 return rewriteUMinus(t, false);
131 }else if(t.getKind() == kind::DIVISION){
132 return rewriteDivByConstant(t, false);
133 }else if(t.getKind() == kind::PLUS){
134 return postRewritePlus(t);
135 }else if(t.getKind() == kind::MULT){
136 return postRewriteMult(t);
137 }else if(t.getKind() == kind::INTS_DIVISION){
138 return RewriteResponse(REWRITE_DONE, t);
139 }else if(t.getKind() == kind::INTS_MODULUS){
140 return RewriteResponse(REWRITE_DONE, t);
141 }else{
142 Unreachable();
143 }
144 }
145
146 RewriteResponse ArithRewriter::preRewriteMult(TNode t){
147 Assert(t.getKind()== kind::MULT);
148
149 // Rewrite multiplications with a 0 argument and to 0
150 Integer intZero;
151
152 Rational qZero(0);
153
154 for(TNode::iterator i = t.begin(); i != t.end(); ++i) {
155 if((*i).getKind() == kind::CONST_RATIONAL) {
156 if((*i).getConst<Rational>() == qZero) {
157 return RewriteResponse(REWRITE_DONE, mkRationalNode(qZero));
158 }
159 } else if((*i).getKind() == kind::CONST_INTEGER) {
160 if((*i).getConst<Integer>() == intZero) {
161 if(t.getType().isInteger()) {
162 return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(intZero));
163 } else {
164 return RewriteResponse(REWRITE_DONE, mkRationalNode(qZero));
165 }
166 }
167 }
168 }
169 return RewriteResponse(REWRITE_DONE, t);
170 }
171 RewriteResponse ArithRewriter::preRewritePlus(TNode t){
172 Assert(t.getKind()== kind::PLUS);
173
174 return RewriteResponse(REWRITE_DONE, t);
175 }
176
177 RewriteResponse ArithRewriter::postRewritePlus(TNode t){
178 Assert(t.getKind()== kind::PLUS);
179
180 Polynomial res = Polynomial::mkZero();
181
182 for(TNode::iterator i = t.begin(), end = t.end(); i != end; ++i){
183 Node curr = *i;
184 Polynomial currPoly = Polynomial::parsePolynomial(curr);
185
186 res = res + currPoly;
187 }
188
189 return RewriteResponse(REWRITE_DONE, res.getNode());
190 }
191
192 RewriteResponse ArithRewriter::postRewriteMult(TNode t){
193 Assert(t.getKind()== kind::MULT);
194
195 Polynomial res = Polynomial::mkOne();
196
197 for(TNode::iterator i = t.begin(), end = t.end(); i != end; ++i){
198 Node curr = *i;
199 Polynomial currPoly = Polynomial::parsePolynomial(curr);
200
201 res = res * currPoly;
202 }
203
204 return RewriteResponse(REWRITE_DONE, res.getNode());
205 }
206
207 // RewriteResponse ArithRewriter::postRewriteAtomConstantRHS(TNode t){
208 // TNode left = t[0];
209 // TNode right = t[1];
210
211 // Polynomial pLeft = Polynomial::parsePolynomial(left);
212
213
214 // Comparison cmp = Comparison::mkComparison(t.getKind(), Polynomial::parsePolynomial(left), Constant(right));
215
216 // Assert(cmp.isNormalForm());
217 // return RewriteResponse(REWRITE_DONE, cmp.getNode());
218 // }
219
220 RewriteResponse ArithRewriter::postRewriteAtom(TNode atom){
221 // left |><| right
222 TNode left = atom[0];
223 TNode right = atom[1];
224
225 Polynomial pleft = Polynomial::parsePolynomial(left);
226 Polynomial pright = Polynomial::parsePolynomial(right);
227
228 Comparison cmp = Comparison::mkComparison(atom.getKind(), pleft, pright);
229 Assert(cmp.isNormalForm());
230 return RewriteResponse(REWRITE_DONE, cmp.getNode());
231 }
232
233 RewriteResponse ArithRewriter::preRewriteAtom(TNode atom){
234 Assert(isAtom(atom));
235
236 NodeManager* currNM = NodeManager::currentNM();
237
238 if(atom.getKind() == kind::EQUAL) {
239 if(atom[0] == atom[1]) {
240 return RewriteResponse(REWRITE_DONE, currNM->mkConst(true));
241 }
242 }else if(atom.getKind() == kind::GT){
243 Node leq = currNM->mkNode(kind::LEQ, atom[0], atom[1]);
244 return RewriteResponse(REWRITE_DONE, currNM->mkNode(kind::NOT, leq));
245 }else if(atom.getKind() == kind::LT){
246 Node geq = currNM->mkNode(kind::GEQ, atom[0], atom[1]);
247 return RewriteResponse(REWRITE_DONE, currNM->mkNode(kind::NOT, geq));
248 }
249
250 return RewriteResponse(REWRITE_DONE, atom);
251 }
252
253 RewriteResponse ArithRewriter::postRewrite(TNode t){
254 if(isTerm(t)){
255 RewriteResponse response = postRewriteTerm(t);
256 if(Debug.isOn("arith::rewriter") && response.status == REWRITE_DONE) {
257 Polynomial::parsePolynomial(response.node);
258 }
259 return response;
260 }else if(isAtom(t)){
261 RewriteResponse response = postRewriteAtom(t);
262 if(Debug.isOn("arith::rewriter") && response.status == REWRITE_DONE) {
263 Comparison::parseNormalForm(response.node);
264 }
265 return response;
266 }else{
267 Unreachable();
268 return RewriteResponse(REWRITE_DONE, Node::null());
269 }
270 }
271
272 RewriteResponse ArithRewriter::preRewrite(TNode t){
273 if(isTerm(t)){
274 return preRewriteTerm(t);
275 }else if(isAtom(t)){
276 return preRewriteAtom(t);
277 }else{
278 Unreachable();
279 return RewriteResponse(REWRITE_DONE, Node::null());
280 }
281 }
282
283 Node ArithRewriter::makeUnaryMinusNode(TNode n){
284 Rational qNegOne(-1);
285 return NodeManager::currentNM()->mkNode(kind::MULT, mkRationalNode(qNegOne),n);
286 }
287
288 Node ArithRewriter::makeSubtractionNode(TNode l, TNode r){
289 Node negR = makeUnaryMinusNode(r);
290 Node diff = NodeManager::currentNM()->mkNode(kind::PLUS, l, negR);
291
292 return diff;
293 }
294
295 RewriteResponse ArithRewriter::rewriteDivByConstant(TNode t, bool pre){
296 Assert(t.getKind()== kind::DIVISION);
297
298 Node left = t[0];
299 Node right = t[1];
300 Assert(right.getKind()== kind::CONST_RATIONAL);
301
302
303 const Rational& den = right.getConst<Rational>();
304
305 Assert(den != Rational(0));
306
307 Rational div = den.inverse();
308
309 Node result = mkRationalNode(div);
310
311 Node mult = NodeManager::currentNM()->mkNode(kind::MULT,left,result);
312 if(pre){
313 return RewriteResponse(REWRITE_DONE, mult);
314 }else{
315 return RewriteResponse(REWRITE_AGAIN, mult);
316 }
317 }
318
319 }/* CVC4::theory::arith namespace */
320 }/* CVC4::theory namespace */
321 }/* CVC4 namespace */