This commit removes the CONST_INTEGER kind from nodes. This code comes from the branc...
[cvc5.git] / src / theory / arith / arith_rewriter.cpp
1 /********************* */
2 /*! \file arith_rewriter.cpp
3 ** \verbatim
4 ** Original author: taking
5 ** Major contributors: none
6 ** Minor contributors (to current version): mdeters, dejan
7 ** This file is part of the CVC4 prototype.
8 ** Copyright (c) 2009, 2010, 2011 The Analysis of Computer Systems Group (ACSys)
9 ** Courant Institute of Mathematical Sciences
10 ** New York University
11 ** See the file COPYING in the top-level source directory for licensing
12 ** information.\endverbatim
13 **
14 ** \brief [[ Add one-line brief description here ]]
15 **
16 ** [[ Add lengthier description here ]]
17 ** \todo document this file
18 **/
19
20 #include "theory/theory.h"
21 #include "theory/arith/normal_form.h"
22 #include "theory/arith/arith_rewriter.h"
23 #include "theory/arith/arith_utilities.h"
24
25 #include <vector>
26 #include <set>
27 #include <stack>
28
29 namespace CVC4 {
30 namespace theory {
31 namespace arith {
32
33 bool isVariable(TNode t){
34 return t.getMetaKind() == kind::metakind::VARIABLE;
35 }
36
37 bool ArithRewriter::isAtom(TNode n) {
38 return arith::isRelationOperator(n.getKind());
39 }
40
41 RewriteResponse ArithRewriter::rewriteConstant(TNode t){
42 Assert(t.getMetaKind() == kind::metakind::CONSTANT);
43 Assert(t.getKind() == kind::CONST_RATIONAL);
44
45 return RewriteResponse(REWRITE_DONE, t);
46 }
47
48 RewriteResponse ArithRewriter::rewriteVariable(TNode t){
49 Assert(isVariable(t));
50
51 return RewriteResponse(REWRITE_DONE, t);
52 }
53
54 RewriteResponse ArithRewriter::rewriteMinus(TNode t, bool pre){
55 Assert(t.getKind()== kind::MINUS);
56
57 if(pre){
58 if(t[0] == t[1]){
59 Rational zero(0);
60 Node zeroNode = mkRationalNode(zero);
61 return RewriteResponse(REWRITE_DONE, zeroNode);
62 }else{
63 Node noMinus = makeSubtractionNode(t[0],t[1]);
64 return RewriteResponse(REWRITE_DONE, noMinus);
65 }
66 }else{
67 Polynomial minuend = Polynomial::parsePolynomial(t[0]);
68 Polynomial subtrahend = Polynomial::parsePolynomial(t[0]);
69 Polynomial diff = minuend - subtrahend;
70 return RewriteResponse(REWRITE_DONE, diff.getNode());
71 }
72 }
73
74 RewriteResponse ArithRewriter::rewriteUMinus(TNode t, bool pre){
75 Assert(t.getKind()== kind::UMINUS);
76
77 Node noUminus = makeUnaryMinusNode(t[0]);
78 if(pre)
79 return RewriteResponse(REWRITE_DONE, noUminus);
80 else
81 return RewriteResponse(REWRITE_AGAIN, noUminus);
82 }
83
84 RewriteResponse ArithRewriter::preRewriteTerm(TNode t){
85 if(t.getMetaKind() == kind::metakind::CONSTANT){
86 return rewriteConstant(t);
87 }else if(isVariable(t)){
88 return rewriteVariable(t);
89 }else if(t.getKind() == kind::MINUS){
90 return rewriteMinus(t, true);
91 }else if(t.getKind() == kind::UMINUS){
92 return rewriteUMinus(t, true);
93 }else if(t.getKind() == kind::DIVISION){
94 return RewriteResponse(REWRITE_DONE, t); // wait until t[1] is rewritten
95 }else if(t.getKind() == kind::PLUS){
96 return preRewritePlus(t);
97 }else if(t.getKind() == kind::MULT){
98 return preRewriteMult(t);
99 }else if(t.getKind() == kind::INTS_DIVISION){
100 Rational intOne(1);
101 if(t[1].getKind()== kind::CONST_RATIONAL && t[1].getConst<Rational>() == intOne){
102 return RewriteResponse(REWRITE_AGAIN, t[0]);
103 }else{
104 return RewriteResponse(REWRITE_DONE, t);
105 }
106 }else if(t.getKind() == kind::INTS_MODULUS){
107 Rational intOne(1);
108 if(t[1].getKind()== kind::CONST_RATIONAL && t[1].getConst<Rational>() == intOne){
109 Rational intZero(0);
110 return RewriteResponse(REWRITE_AGAIN, mkRationalNode(intZero));
111 }else{
112 return RewriteResponse(REWRITE_DONE, t);
113 }
114 }else{
115 Unreachable();
116 }
117 }
118 RewriteResponse ArithRewriter::postRewriteTerm(TNode t){
119 if(t.getMetaKind() == kind::metakind::CONSTANT){
120 return rewriteConstant(t);
121 }else if(isVariable(t)){
122 return rewriteVariable(t);
123 }else if(t.getKind() == kind::MINUS){
124 return rewriteMinus(t, false);
125 }else if(t.getKind() == kind::UMINUS){
126 return rewriteUMinus(t, false);
127 }else if(t.getKind() == kind::DIVISION){
128 return rewriteDivByConstant(t, false);
129 }else if(t.getKind() == kind::PLUS){
130 return postRewritePlus(t);
131 }else if(t.getKind() == kind::MULT){
132 return postRewriteMult(t);
133 }else if(t.getKind() == kind::INTS_DIVISION){
134 return RewriteResponse(REWRITE_DONE, t);
135 }else if(t.getKind() == kind::INTS_MODULUS){
136 return RewriteResponse(REWRITE_DONE, t);
137 }else{
138 Unreachable();
139 }
140 }
141
142 RewriteResponse ArithRewriter::preRewriteMult(TNode t){
143 Assert(t.getKind()== kind::MULT);
144
145 // Rewrite multiplications with a 0 argument and to 0
146 Rational qZero(0);
147
148 for(TNode::iterator i = t.begin(); i != t.end(); ++i) {
149 if((*i).getKind() == kind::CONST_RATIONAL) {
150 if((*i).getConst<Rational>() == qZero) {
151 return RewriteResponse(REWRITE_DONE, mkRationalNode(qZero));
152 }
153 }
154 }
155 return RewriteResponse(REWRITE_DONE, t);
156 }
157 RewriteResponse ArithRewriter::preRewritePlus(TNode t){
158 Assert(t.getKind()== kind::PLUS);
159
160 return RewriteResponse(REWRITE_DONE, t);
161 }
162
163 RewriteResponse ArithRewriter::postRewritePlus(TNode t){
164 Assert(t.getKind()== kind::PLUS);
165
166 Polynomial res = Polynomial::mkZero();
167
168 for(TNode::iterator i = t.begin(), end = t.end(); i != end; ++i){
169 Node curr = *i;
170 Polynomial currPoly = Polynomial::parsePolynomial(curr);
171
172 res = res + currPoly;
173 }
174
175 return RewriteResponse(REWRITE_DONE, res.getNode());
176 }
177
178 RewriteResponse ArithRewriter::postRewriteMult(TNode t){
179 Assert(t.getKind()== kind::MULT);
180
181 Polynomial res = Polynomial::mkOne();
182
183 for(TNode::iterator i = t.begin(), end = t.end(); i != end; ++i){
184 Node curr = *i;
185 Polynomial currPoly = Polynomial::parsePolynomial(curr);
186
187 res = res * currPoly;
188 }
189
190 return RewriteResponse(REWRITE_DONE, res.getNode());
191 }
192
193 // RewriteResponse ArithRewriter::postRewriteAtomConstantRHS(TNode t){
194 // TNode left = t[0];
195 // TNode right = t[1];
196
197 // Polynomial pLeft = Polynomial::parsePolynomial(left);
198
199
200 // Comparison cmp = Comparison::mkComparison(t.getKind(), Polynomial::parsePolynomial(left), Constant(right));
201
202 // Assert(cmp.isNormalForm());
203 // return RewriteResponse(REWRITE_DONE, cmp.getNode());
204 // }
205
206 RewriteResponse ArithRewriter::postRewriteAtom(TNode atom){
207 // left |><| right
208 TNode left = atom[0];
209 TNode right = atom[1];
210
211 Polynomial pleft = Polynomial::parsePolynomial(left);
212 Polynomial pright = Polynomial::parsePolynomial(right);
213
214 Comparison cmp = Comparison::mkComparison(atom.getKind(), pleft, pright);
215 Assert(cmp.isNormalForm());
216 return RewriteResponse(REWRITE_DONE, cmp.getNode());
217 }
218
219 RewriteResponse ArithRewriter::preRewriteAtom(TNode atom){
220 Assert(isAtom(atom));
221
222 NodeManager* currNM = NodeManager::currentNM();
223
224 if(atom.getKind() == kind::EQUAL) {
225 if(atom[0] == atom[1]) {
226 return RewriteResponse(REWRITE_DONE, currNM->mkConst(true));
227 }
228 }else if(atom.getKind() == kind::GT){
229 Node leq = currNM->mkNode(kind::LEQ, atom[0], atom[1]);
230 return RewriteResponse(REWRITE_DONE, currNM->mkNode(kind::NOT, leq));
231 }else if(atom.getKind() == kind::LT){
232 Node geq = currNM->mkNode(kind::GEQ, atom[0], atom[1]);
233 return RewriteResponse(REWRITE_DONE, currNM->mkNode(kind::NOT, geq));
234 }
235
236 return RewriteResponse(REWRITE_DONE, atom);
237 }
238
239 RewriteResponse ArithRewriter::postRewrite(TNode t){
240 if(isTerm(t)){
241 RewriteResponse response = postRewriteTerm(t);
242 if(Debug.isOn("arith::rewriter") && response.status == REWRITE_DONE) {
243 Polynomial::parsePolynomial(response.node);
244 }
245 return response;
246 }else if(isAtom(t)){
247 RewriteResponse response = postRewriteAtom(t);
248 if(Debug.isOn("arith::rewriter") && response.status == REWRITE_DONE) {
249 Comparison::parseNormalForm(response.node);
250 }
251 return response;
252 }else{
253 Unreachable();
254 return RewriteResponse(REWRITE_DONE, Node::null());
255 }
256 }
257
258 RewriteResponse ArithRewriter::preRewrite(TNode t){
259 if(isTerm(t)){
260 return preRewriteTerm(t);
261 }else if(isAtom(t)){
262 return preRewriteAtom(t);
263 }else{
264 Unreachable();
265 return RewriteResponse(REWRITE_DONE, Node::null());
266 }
267 }
268
269 Node ArithRewriter::makeUnaryMinusNode(TNode n){
270 Rational qNegOne(-1);
271 return NodeManager::currentNM()->mkNode(kind::MULT, mkRationalNode(qNegOne),n);
272 }
273
274 Node ArithRewriter::makeSubtractionNode(TNode l, TNode r){
275 Node negR = makeUnaryMinusNode(r);
276 Node diff = NodeManager::currentNM()->mkNode(kind::PLUS, l, negR);
277
278 return diff;
279 }
280
281 RewriteResponse ArithRewriter::rewriteDivByConstant(TNode t, bool pre){
282 Assert(t.getKind()== kind::DIVISION);
283
284 Node left = t[0];
285 Node right = t[1];
286 Assert(right.getKind()== kind::CONST_RATIONAL);
287
288
289 const Rational& den = right.getConst<Rational>();
290
291 Assert(den != Rational(0));
292
293 Rational div = den.inverse();
294
295 Node result = mkRationalNode(div);
296
297 Node mult = NodeManager::currentNM()->mkNode(kind::MULT,left,result);
298 if(pre){
299 return RewriteResponse(REWRITE_DONE, mult);
300 }else{
301 return RewriteResponse(REWRITE_AGAIN, mult);
302 }
303 }
304
305 }/* CVC4::theory::arith namespace */
306 }/* CVC4::theory namespace */
307 }/* CVC4 namespace */