fix uses of getMetaKind() from outside the expr package. (they now use isConst(...
[cvc5.git] / src / theory / arith / arith_rewriter.cpp
1 /********************* */
2 /*! \file arith_rewriter.cpp
3 ** \verbatim
4 ** Original author: taking
5 ** Major contributors: none
6 ** Minor contributors (to current version): mdeters, dejan
7 ** This file is part of the CVC4 prototype.
8 ** Copyright (c) 2009, 2010, 2011 The Analysis of Computer Systems Group (ACSys)
9 ** Courant Institute of Mathematical Sciences
10 ** New York University
11 ** See the file COPYING in the top-level source directory for licensing
12 ** information.\endverbatim
13 **
14 ** \brief [[ Add one-line brief description here ]]
15 **
16 ** [[ Add lengthier description here ]]
17 ** \todo document this file
18 **/
19
20 #include "theory/theory.h"
21 #include "theory/arith/normal_form.h"
22 #include "theory/arith/arith_rewriter.h"
23 #include "theory/arith/arith_utilities.h"
24
25 #include <vector>
26 #include <set>
27 #include <stack>
28
29 namespace CVC4 {
30 namespace theory {
31 namespace arith {
32
33 bool ArithRewriter::isAtom(TNode n) {
34 return arith::isRelationOperator(n.getKind());
35 }
36
37 RewriteResponse ArithRewriter::rewriteConstant(TNode t){
38 Assert(t.isConst());
39 Assert(t.getKind() == kind::CONST_RATIONAL);
40
41 return RewriteResponse(REWRITE_DONE, t);
42 }
43
44 RewriteResponse ArithRewriter::rewriteVariable(TNode t){
45 Assert(t.isVar());
46
47 return RewriteResponse(REWRITE_DONE, t);
48 }
49
50 RewriteResponse ArithRewriter::rewriteMinus(TNode t, bool pre){
51 Assert(t.getKind()== kind::MINUS);
52
53 if(pre){
54 if(t[0] == t[1]){
55 Rational zero(0);
56 Node zeroNode = mkRationalNode(zero);
57 return RewriteResponse(REWRITE_DONE, zeroNode);
58 }else{
59 Node noMinus = makeSubtractionNode(t[0],t[1]);
60 return RewriteResponse(REWRITE_DONE, noMinus);
61 }
62 }else{
63 Polynomial minuend = Polynomial::parsePolynomial(t[0]);
64 Polynomial subtrahend = Polynomial::parsePolynomial(t[0]);
65 Polynomial diff = minuend - subtrahend;
66 return RewriteResponse(REWRITE_DONE, diff.getNode());
67 }
68 }
69
70 RewriteResponse ArithRewriter::rewriteUMinus(TNode t, bool pre){
71 Assert(t.getKind()== kind::UMINUS);
72
73 Node noUminus = makeUnaryMinusNode(t[0]);
74 if(pre)
75 return RewriteResponse(REWRITE_DONE, noUminus);
76 else
77 return RewriteResponse(REWRITE_AGAIN, noUminus);
78 }
79
80 RewriteResponse ArithRewriter::preRewriteTerm(TNode t){
81 if(t.isConst()){
82 return rewriteConstant(t);
83 }else if(t.isVar()){
84 return rewriteVariable(t);
85 }else if(t.getKind() == kind::MINUS){
86 return rewriteMinus(t, true);
87 }else if(t.getKind() == kind::UMINUS){
88 return rewriteUMinus(t, true);
89 }else if(t.getKind() == kind::DIVISION){
90 return RewriteResponse(REWRITE_DONE, t); // wait until t[1] is rewritten
91 }else if(t.getKind() == kind::PLUS){
92 return preRewritePlus(t);
93 }else if(t.getKind() == kind::MULT){
94 return preRewriteMult(t);
95 }else if(t.getKind() == kind::INTS_DIVISION){
96 Rational intOne(1);
97 if(t[1].getKind()== kind::CONST_RATIONAL && t[1].getConst<Rational>() == intOne){
98 return RewriteResponse(REWRITE_AGAIN, t[0]);
99 }else{
100 return RewriteResponse(REWRITE_DONE, t);
101 }
102 }else if(t.getKind() == kind::INTS_MODULUS){
103 Rational intOne(1);
104 if(t[1].getKind()== kind::CONST_RATIONAL && t[1].getConst<Rational>() == intOne){
105 Rational intZero(0);
106 return RewriteResponse(REWRITE_AGAIN, mkRationalNode(intZero));
107 }else{
108 return RewriteResponse(REWRITE_DONE, t);
109 }
110 }else{
111 Unreachable();
112 }
113 }
114 RewriteResponse ArithRewriter::postRewriteTerm(TNode t){
115 if(t.isConst()){
116 return rewriteConstant(t);
117 }else if(t.isVar()){
118 return rewriteVariable(t);
119 }else if(t.getKind() == kind::MINUS){
120 return rewriteMinus(t, false);
121 }else if(t.getKind() == kind::UMINUS){
122 return rewriteUMinus(t, false);
123 }else if(t.getKind() == kind::DIVISION){
124 return rewriteDivByConstant(t, false);
125 }else if(t.getKind() == kind::PLUS){
126 return postRewritePlus(t);
127 }else if(t.getKind() == kind::MULT){
128 return postRewriteMult(t);
129 }else if(t.getKind() == kind::INTS_DIVISION){
130 return RewriteResponse(REWRITE_DONE, t);
131 }else if(t.getKind() == kind::INTS_MODULUS){
132 return RewriteResponse(REWRITE_DONE, t);
133 }else{
134 Unreachable();
135 }
136 }
137
138 RewriteResponse ArithRewriter::preRewriteMult(TNode t){
139 Assert(t.getKind()== kind::MULT);
140
141 // Rewrite multiplications with a 0 argument and to 0
142 Rational qZero(0);
143
144 for(TNode::iterator i = t.begin(); i != t.end(); ++i) {
145 if((*i).getKind() == kind::CONST_RATIONAL) {
146 if((*i).getConst<Rational>() == qZero) {
147 return RewriteResponse(REWRITE_DONE, mkRationalNode(qZero));
148 }
149 }
150 }
151 return RewriteResponse(REWRITE_DONE, t);
152 }
153 RewriteResponse ArithRewriter::preRewritePlus(TNode t){
154 Assert(t.getKind()== kind::PLUS);
155
156 return RewriteResponse(REWRITE_DONE, t);
157 }
158
159 RewriteResponse ArithRewriter::postRewritePlus(TNode t){
160 Assert(t.getKind()== kind::PLUS);
161
162 Polynomial res = Polynomial::mkZero();
163
164 for(TNode::iterator i = t.begin(), end = t.end(); i != end; ++i){
165 Node curr = *i;
166 Polynomial currPoly = Polynomial::parsePolynomial(curr);
167
168 res = res + currPoly;
169 }
170
171 return RewriteResponse(REWRITE_DONE, res.getNode());
172 }
173
174 RewriteResponse ArithRewriter::postRewriteMult(TNode t){
175 Assert(t.getKind()== kind::MULT);
176
177 Polynomial res = Polynomial::mkOne();
178
179 for(TNode::iterator i = t.begin(), end = t.end(); i != end; ++i){
180 Node curr = *i;
181 Polynomial currPoly = Polynomial::parsePolynomial(curr);
182
183 res = res * currPoly;
184 }
185
186 return RewriteResponse(REWRITE_DONE, res.getNode());
187 }
188
189 // RewriteResponse ArithRewriter::postRewriteAtomConstantRHS(TNode t){
190 // TNode left = t[0];
191 // TNode right = t[1];
192
193 // Polynomial pLeft = Polynomial::parsePolynomial(left);
194
195
196 // Comparison cmp = Comparison::mkComparison(t.getKind(), Polynomial::parsePolynomial(left), Constant(right));
197
198 // Assert(cmp.isNormalForm());
199 // return RewriteResponse(REWRITE_DONE, cmp.getNode());
200 // }
201
202 RewriteResponse ArithRewriter::postRewriteAtom(TNode atom){
203 // left |><| right
204 TNode left = atom[0];
205 TNode right = atom[1];
206
207 Polynomial pleft = Polynomial::parsePolynomial(left);
208 Polynomial pright = Polynomial::parsePolynomial(right);
209
210 Comparison cmp = Comparison::mkComparison(atom.getKind(), pleft, pright);
211 Assert(cmp.isNormalForm());
212 return RewriteResponse(REWRITE_DONE, cmp.getNode());
213 }
214
215 RewriteResponse ArithRewriter::preRewriteAtom(TNode atom){
216 Assert(isAtom(atom));
217
218 NodeManager* currNM = NodeManager::currentNM();
219
220 if(atom.getKind() == kind::EQUAL) {
221 if(atom[0] == atom[1]) {
222 return RewriteResponse(REWRITE_DONE, currNM->mkConst(true));
223 }
224 }else if(atom.getKind() == kind::GT){
225 Node leq = currNM->mkNode(kind::LEQ, atom[0], atom[1]);
226 return RewriteResponse(REWRITE_DONE, currNM->mkNode(kind::NOT, leq));
227 }else if(atom.getKind() == kind::LT){
228 Node geq = currNM->mkNode(kind::GEQ, atom[0], atom[1]);
229 return RewriteResponse(REWRITE_DONE, currNM->mkNode(kind::NOT, geq));
230 }
231
232 return RewriteResponse(REWRITE_DONE, atom);
233 }
234
235 RewriteResponse ArithRewriter::postRewrite(TNode t){
236 if(isTerm(t)){
237 RewriteResponse response = postRewriteTerm(t);
238 if(Debug.isOn("arith::rewriter") && response.status == REWRITE_DONE) {
239 Polynomial::parsePolynomial(response.node);
240 }
241 return response;
242 }else if(isAtom(t)){
243 RewriteResponse response = postRewriteAtom(t);
244 if(Debug.isOn("arith::rewriter") && response.status == REWRITE_DONE) {
245 Comparison::parseNormalForm(response.node);
246 }
247 return response;
248 }else{
249 Unreachable();
250 return RewriteResponse(REWRITE_DONE, Node::null());
251 }
252 }
253
254 RewriteResponse ArithRewriter::preRewrite(TNode t){
255 if(isTerm(t)){
256 return preRewriteTerm(t);
257 }else if(isAtom(t)){
258 return preRewriteAtom(t);
259 }else{
260 Unreachable();
261 return RewriteResponse(REWRITE_DONE, Node::null());
262 }
263 }
264
265 Node ArithRewriter::makeUnaryMinusNode(TNode n){
266 Rational qNegOne(-1);
267 return NodeManager::currentNM()->mkNode(kind::MULT, mkRationalNode(qNegOne),n);
268 }
269
270 Node ArithRewriter::makeSubtractionNode(TNode l, TNode r){
271 Node negR = makeUnaryMinusNode(r);
272 Node diff = NodeManager::currentNM()->mkNode(kind::PLUS, l, negR);
273
274 return diff;
275 }
276
277 RewriteResponse ArithRewriter::rewriteDivByConstant(TNode t, bool pre){
278 Assert(t.getKind()== kind::DIVISION);
279
280 Node left = t[0];
281 Node right = t[1];
282 Assert(right.getKind()== kind::CONST_RATIONAL);
283
284
285 const Rational& den = right.getConst<Rational>();
286
287 Assert(den != Rational(0));
288
289 Rational div = den.inverse();
290
291 Node result = mkRationalNode(div);
292
293 Node mult = NodeManager::currentNM()->mkNode(kind::MULT,left,result);
294 if(pre){
295 return RewriteResponse(REWRITE_DONE, mult);
296 }else{
297 return RewriteResponse(REWRITE_AGAIN, mult);
298 }
299 }
300
301 }/* CVC4::theory::arith namespace */
302 }/* CVC4::theory namespace */
303 }/* CVC4 namespace */