Update to the ArithRewriter to remove REWRITE_AGAIN_FULL and limit REWRITE_AGAIN...
[cvc5.git] / src / theory / arith / arith_rewriter.cpp
1 /********************* */
2 /*! \file arith_rewriter.cpp
3 ** \verbatim
4 ** Original author: taking
5 ** Major contributors: none
6 ** Minor contributors (to current version): mdeters, dejan
7 ** This file is part of the CVC4 prototype.
8 ** Copyright (c) 2009, 2010, 2011 The Analysis of Computer Systems Group (ACSys)
9 ** Courant Institute of Mathematical Sciences
10 ** New York University
11 ** See the file COPYING in the top-level source directory for licensing
12 ** information.\endverbatim
13 **
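14 ** \brief Rewriter for arithmetic terms and atoms
15 **
16 ** Implements pre- and post-rewriting of arithmetic terms and relational atoms; post-rewriting goes through the Polynomial and Comparison normal forms.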
17 ** \todo document this file
18 **/
19
20 #include "theory/theory.h"
21 #include "theory/arith/normal_form.h"
22 #include "theory/arith/arith_rewriter.h"
23 #include "theory/arith/arith_utilities.h"
24
25 #include <vector>
26 #include <set>
27 #include <stack>
28
29 namespace CVC4 {
30 namespace theory {
31 namespace arith {
32
33 bool isVariable(TNode t){
34 return t.getMetaKind() == kind::metakind::VARIABLE;
35 }
36
37 bool ArithRewriter::isAtom(TNode n) {
38 return arith::isRelationOperator(n.getKind());
39 }
40
41 RewriteResponse ArithRewriter::rewriteConstant(TNode t){
42 Assert(t.getMetaKind() == kind::metakind::CONSTANT);
43 Node val = coerceToRationalNode(t);
44
45 return RewriteResponse(REWRITE_DONE, val);
46 }
47
48 RewriteResponse ArithRewriter::rewriteVariable(TNode t){
49 Assert(isVariable(t));
50
51 return RewriteResponse(REWRITE_DONE, t);
52 }
53
54 RewriteResponse ArithRewriter::rewriteMinus(TNode t, bool pre){
55 Assert(t.getKind()== kind::MINUS);
56
57 if(pre){
58 if(t[0] == t[1]){
59 Rational zero(0);
60 Node zeroNode = mkRationalNode(zero);
61 return RewriteResponse(REWRITE_DONE, zeroNode);
62 }else{
63 Node noMinus = makeSubtractionNode(t[0],t[1]);
64 return RewriteResponse(REWRITE_DONE, noMinus);
65 }
66 }else{
67 Polynomial minuend = Polynomial::parsePolynomial(t[0]);
68 Polynomial subtrahend = Polynomial::parsePolynomial(t[0]);
69 Polynomial diff = minuend - subtrahend;
70 return RewriteResponse(REWRITE_DONE, diff.getNode());
71 }
72 }
73
74 RewriteResponse ArithRewriter::rewriteUMinus(TNode t, bool pre){
75 Assert(t.getKind()== kind::UMINUS);
76
77 Node noUminus = makeUnaryMinusNode(t[0]);
78 if(pre)
79 return RewriteResponse(REWRITE_DONE, noUminus);
80 else
81 return RewriteResponse(REWRITE_AGAIN, noUminus);
82 }
83
84 RewriteResponse ArithRewriter::preRewriteTerm(TNode t){
85 if(t.getMetaKind() == kind::metakind::CONSTANT){
86 return rewriteConstant(t);
87 }else if(isVariable(t)){
88 return rewriteVariable(t);
89 }else if(t.getKind() == kind::MINUS){
90 return rewriteMinus(t, true);
91 }else if(t.getKind() == kind::UMINUS){
92 return rewriteUMinus(t, true);
93 }else if(t.getKind() == kind::DIVISION){
94 if(t[0].getKind()== kind::CONST_RATIONAL){
95 return rewriteDivByConstant(t, true);
96 }else{
97 return RewriteResponse(REWRITE_DONE, t);
98 }
99 }else if(t.getKind() == kind::PLUS){
100 return preRewritePlus(t);
101 }else if(t.getKind() == kind::MULT){
102 return preRewriteMult(t);
103 }else if(t.getKind() == kind::INTS_DIVISION){
104 Integer intOne(1);
105 if(t[1].getKind()== kind::CONST_INTEGER && t[1].getConst<Integer>() == intOne){
106 return RewriteResponse(REWRITE_AGAIN, t[0]);
107 }else{
108 return RewriteResponse(REWRITE_DONE, t);
109 }
110 }else if(t.getKind() == kind::INTS_MODULUS){
111 Integer intOne(1);
112 if(t[1].getKind()== kind::CONST_INTEGER && t[1].getConst<Integer>() == intOne){
113 Integer intZero(0);
114 return RewriteResponse(REWRITE_AGAIN, mkIntegerNode(intZero));
115 }else{
116 return RewriteResponse(REWRITE_DONE, t);
117 }
118 }else{
119 Unreachable();
120 }
121 }
122 RewriteResponse ArithRewriter::postRewriteTerm(TNode t){
123 if(t.getMetaKind() == kind::metakind::CONSTANT){
124 return rewriteConstant(t);
125 }else if(isVariable(t)){
126 return rewriteVariable(t);
127 }else if(t.getKind() == kind::MINUS){
128 return rewriteMinus(t, false);
129 }else if(t.getKind() == kind::UMINUS){
130 return rewriteUMinus(t, false);
131 }else if(t.getKind() == kind::DIVISION){
132 return rewriteDivByConstant(t, false);
133 }else if(t.getKind() == kind::PLUS){
134 return postRewritePlus(t);
135 }else if(t.getKind() == kind::MULT){
136 return postRewriteMult(t);
137 }else if(t.getKind() == kind::INTS_DIVISION){
138 return RewriteResponse(REWRITE_DONE, t);
139 }else if(t.getKind() == kind::INTS_MODULUS){
140 return RewriteResponse(REWRITE_DONE, t);
141 }else{
142 Unreachable();
143 }
144 }
145
146 RewriteResponse ArithRewriter::preRewriteMult(TNode t){
147 Assert(t.getKind()== kind::MULT);
148
149 // Rewrite multiplications with a 0 argument and to 0
150 Integer intZero;
151
152 Rational qZero(0);
153
154 for(TNode::iterator i = t.begin(); i != t.end(); ++i) {
155 if((*i).getKind() == kind::CONST_RATIONAL) {
156 if((*i).getConst<Rational>() == qZero) {
157 return RewriteResponse(REWRITE_DONE, mkRationalNode(qZero));
158 }
159 } else if((*i).getKind() == kind::CONST_INTEGER) {
160 if((*i).getConst<Integer>() == intZero) {
161 if(t.getType().isInteger()) {
162 return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(intZero));
163 } else {
164 return RewriteResponse(REWRITE_DONE, mkRationalNode(qZero));
165 }
166 }
167 }
168 }
169 return RewriteResponse(REWRITE_DONE, t);
170 }
171 RewriteResponse ArithRewriter::preRewritePlus(TNode t){
172 Assert(t.getKind()== kind::PLUS);
173
174 return RewriteResponse(REWRITE_DONE, t);
175 }
176
177 RewriteResponse ArithRewriter::postRewritePlus(TNode t){
178 Assert(t.getKind()== kind::PLUS);
179
180 Polynomial res = Polynomial::mkZero();
181
182 for(TNode::iterator i = t.begin(), end = t.end(); i != end; ++i){
183 Node curr = *i;
184 Polynomial currPoly = Polynomial::parsePolynomial(curr);
185
186 res = res + currPoly;
187 }
188
189 return RewriteResponse(REWRITE_DONE, res.getNode());
190 }
191
192 RewriteResponse ArithRewriter::postRewriteMult(TNode t){
193 Assert(t.getKind()== kind::MULT);
194
195 Polynomial res = Polynomial::mkOne();
196
197 for(TNode::iterator i = t.begin(), end = t.end(); i != end; ++i){
198 Node curr = *i;
199 Polynomial currPoly = Polynomial::parsePolynomial(curr);
200
201 res = res * currPoly;
202 }
203
204 return RewriteResponse(REWRITE_DONE, res.getNode());
205 }
206
207 RewriteResponse ArithRewriter::postRewriteAtomConstantRHS(TNode t){
208 TNode left = t[0];
209 TNode right = t[1];
210
211 Comparison cmp = Comparison::mkNormalComparison(t.getKind(), Polynomial::parsePolynomial(left), Constant(right));
212
213 Assert(cmp.isNormalForm());
214 return RewriteResponse(REWRITE_DONE, cmp.getNode());
215 }
216
217 RewriteResponse ArithRewriter::postRewriteAtom(TNode atom){
218 // left |><| right
219 TNode left = atom[0];
220 TNode right = atom[1];
221
222 if(right.getMetaKind() == kind::metakind::CONSTANT){
223 return postRewriteAtomConstantRHS(atom);
224 }else{
225 Polynomial pleft = Polynomial::parsePolynomial(left);
226 Polynomial pright = Polynomial::parsePolynomial(right);
227
228 Polynomial diff = pleft - pright;
229
230 Constant cZero = Constant::mkConstant(Rational(0));
231 Node reduction = NodeManager::currentNM()->mkNode(atom.getKind(), diff.getNode(), cZero.getNode());
232
233 return postRewriteAtomConstantRHS(reduction);
234 }
235 }
236
237 RewriteResponse ArithRewriter::preRewriteAtom(TNode atom){
238 Assert(isAtom(atom));
239
240 NodeManager* currNM = NodeManager::currentNM();
241
242 if(atom.getKind() == kind::EQUAL) {
243 if(atom[0] == atom[1]) {
244 return RewriteResponse(REWRITE_DONE, currNM->mkConst(true));
245 }
246 }else if(atom.getKind() == kind::GT){
247 Node leq = currNM->mkNode(kind::LEQ, atom[0], atom[1]);
248 return RewriteResponse(REWRITE_DONE, currNM->mkNode(kind::NOT, leq));
249 }else if(atom.getKind() == kind::LT){
250 Node geq = currNM->mkNode(kind::GEQ, atom[0], atom[1]);
251 return RewriteResponse(REWRITE_DONE, currNM->mkNode(kind::NOT, geq));
252 }
253
254 return RewriteResponse(REWRITE_DONE, atom);
255 }
256
257 RewriteResponse ArithRewriter::postRewrite(TNode t){
258 if(isTerm(t)){
259 RewriteResponse response = postRewriteTerm(t);
260 if(Debug.isOn("arith::rewriter") && response.status == REWRITE_DONE) {
261 Polynomial::parsePolynomial(response.node);
262 }
263 return response;
264 }else if(isAtom(t)){
265 RewriteResponse response = postRewriteAtom(t);
266 if(Debug.isOn("arith::rewriter") && response.status == REWRITE_DONE) {
267 Comparison::parseNormalForm(response.node);
268 }
269 return response;
270 }else{
271 Unreachable();
272 return RewriteResponse(REWRITE_DONE, Node::null());
273 }
274 }
275
276 RewriteResponse ArithRewriter::preRewrite(TNode t){
277 if(isTerm(t)){
278 return preRewriteTerm(t);
279 }else if(isAtom(t)){
280 return preRewriteAtom(t);
281 }else{
282 Unreachable();
283 return RewriteResponse(REWRITE_DONE, Node::null());
284 }
285 }
286
287 Node ArithRewriter::makeUnaryMinusNode(TNode n){
288 Rational qNegOne(-1);
289 return NodeManager::currentNM()->mkNode(kind::MULT, mkRationalNode(qNegOne),n);
290 }
291
292 Node ArithRewriter::makeSubtractionNode(TNode l, TNode r){
293 Node negR = makeUnaryMinusNode(r);
294 Node diff = NodeManager::currentNM()->mkNode(kind::PLUS, l, negR);
295
296 return diff;
297 }
298
299 RewriteResponse ArithRewriter::rewriteDivByConstant(TNode t, bool pre){
300 Assert(t.getKind()== kind::DIVISION);
301
302 Node left = t[0];
303 Node right = t[1];
304 Assert(right.getKind()== kind::CONST_RATIONAL);
305
306
307 const Rational& den = right.getConst<Rational>();
308
309 Assert(den != Rational(0));
310
311 Rational div = den.inverse();
312
313 Node result = mkRationalNode(div);
314
315 Node mult = NodeManager::currentNM()->mkNode(kind::MULT,left,result);
316 if(pre){
317 return RewriteResponse(REWRITE_DONE, mult);
318 }else{
319 return RewriteResponse(REWRITE_AGAIN, mult);
320 }
321 }
322
323 }/* CVC4::theory::arith namespace */
324 }/* CVC4::theory namespace */
325 }/* CVC4 namespace */