Commit to fix bug 241 (improper "using namespace std" in a header). This caused...
[cvc5.git] / src / theory / arith / arith_rewriter.cpp
1 /********************* */
2 /*! \file arith_rewriter.cpp
3 ** \verbatim
4 ** Original author: taking
5 ** Major contributors: dejan
6 ** Minor contributors (to current version): mdeters
7 ** This file is part of the CVC4 prototype.
8 ** Copyright (c) 2009, 2010, 2011 The Analysis of Computer Systems Group (ACSys)
9 ** Courant Institute of Mathematical Sciences
10 ** New York University
11 ** See the file COPYING in the top-level source directory for licensing
12 ** information.\endverbatim
13 **
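14 ** \brief Rewriter for the theory of arithmetic
15 **
16 ** Pre- and post-rewrite rules for arithmetic terms and atoms: MINUS, UMINUS and DIVISION by a constant are expressed through PLUS and MULT, terms are post-rewritten through Polynomial, and atoms through Comparison.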
17 ** \todo document this file
18 **/
19
20 #include "theory/theory.h"
21 #include "theory/arith/normal_form.h"
22 #include "theory/arith/arith_rewriter.h"
23 #include "theory/arith/arith_utilities.h"
24
25 #include <vector>
26 #include <set>
27 #include <stack>
28
29 using namespace CVC4;
30 using namespace CVC4::theory;
31 using namespace CVC4::theory::arith;
32
33 arith::ArithConstants* ArithRewriter::s_constants = NULL;
34
35 bool isVariable(TNode t){
36 return t.getMetaKind() == kind::metakind::VARIABLE;
37 }
38
39 RewriteResponse ArithRewriter::rewriteConstant(TNode t){
40 Assert(t.getMetaKind() == kind::metakind::CONSTANT);
41 Node val = coerceToRationalNode(t);
42
43 return RewriteResponse(REWRITE_DONE, val);
44 }
45
46 RewriteResponse ArithRewriter::rewriteVariable(TNode t){
47 Assert(isVariable(t));
48
49 return RewriteResponse(REWRITE_DONE, t);
50 }
51
52 RewriteResponse ArithRewriter::rewriteMinus(TNode t, bool pre){
53 Assert(t.getKind()== kind::MINUS);
54
55 if(t[0] == t[1]) return RewriteResponse(REWRITE_DONE, s_constants->d_ZERO_NODE);
56
57 Node noMinus = makeSubtractionNode(t[0],t[1]);
58 if(pre){
59 return RewriteResponse(REWRITE_DONE, noMinus);
60 }else{
61 return RewriteResponse(REWRITE_AGAIN_FULL, noMinus);
62 }
63 }
64
65 RewriteResponse ArithRewriter::rewriteUMinus(TNode t, bool pre){
66 Assert(t.getKind()== kind::UMINUS);
67
68 Node noUminus = makeUnaryMinusNode(t[0]);
69 if(pre)
70 return RewriteResponse(REWRITE_DONE, noUminus);
71 else
72 return RewriteResponse(REWRITE_AGAIN, noUminus);
73 }
74
75 RewriteResponse ArithRewriter::preRewriteTerm(TNode t){
76 if(t.getMetaKind() == kind::metakind::CONSTANT){
77 return rewriteConstant(t);
78 }else if(isVariable(t)){
79 return rewriteVariable(t);
80 }else if(t.getKind() == kind::MINUS){
81 return rewriteMinus(t, true);
82 }else if(t.getKind() == kind::UMINUS){
83 return rewriteUMinus(t, true);
84 }else if(t.getKind() == kind::DIVISION){
85 if(t[0].getKind()== kind::CONST_RATIONAL){
86 return rewriteDivByConstant(t, true);
87 }else{
88 return RewriteResponse(REWRITE_DONE, t);
89 }
90 }else if(t.getKind() == kind::PLUS){
91 return preRewritePlus(t);
92 }else if(t.getKind() == kind::MULT){
93 return preRewriteMult(t);
94 }else{
95 Unreachable();
96 }
97 }
98 RewriteResponse ArithRewriter::postRewriteTerm(TNode t){
99 if(t.getMetaKind() == kind::metakind::CONSTANT){
100 return rewriteConstant(t);
101 }else if(isVariable(t)){
102 return rewriteVariable(t);
103 }else if(t.getKind() == kind::MINUS){
104 return rewriteMinus(t, false);
105 }else if(t.getKind() == kind::UMINUS){
106 return rewriteUMinus(t, false);
107 }else if(t.getKind() == kind::DIVISION){
108 return rewriteDivByConstant(t, false);
109 }else if(t.getKind() == kind::PLUS){
110 return postRewritePlus(t);
111 }else if(t.getKind() == kind::MULT){
112 return postRewriteMult(t);
113 }else{
114 Unreachable();
115 }
116 }
117
118 RewriteResponse ArithRewriter::preRewriteMult(TNode t){
119 Assert(t.getKind()== kind::MULT);
120
121 // Rewrite multiplications with a 0 argument and to 0
122 Integer intZero;
123
124 for(TNode::iterator i = t.begin(); i != t.end(); ++i) {
125 if((*i).getKind() == kind::CONST_RATIONAL) {
126 if((*i).getConst<Rational>() == s_constants->d_ZERO) {
127 return RewriteResponse(REWRITE_DONE, s_constants->d_ZERO_NODE);
128 }
129 } else if((*i).getKind() == kind::CONST_INTEGER) {
130 if((*i).getConst<Integer>() == intZero) {
131 if(t.getType().isInteger()) {
132 return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(intZero));
133 } else {
134 return RewriteResponse(REWRITE_DONE, s_constants->d_ZERO_NODE);
135 }
136 }
137 }
138 }
139 return RewriteResponse(REWRITE_DONE, t);
140 }
141 RewriteResponse ArithRewriter::preRewritePlus(TNode t){
142 Assert(t.getKind()== kind::PLUS);
143
144 return RewriteResponse(REWRITE_DONE, t);
145 }
146
147 RewriteResponse ArithRewriter::postRewritePlus(TNode t){
148 Assert(t.getKind()== kind::PLUS);
149
150 Polynomial res = Polynomial::mkZero();
151
152 for(TNode::iterator i = t.begin(), end = t.end(); i != end; ++i){
153 Node curr = *i;
154 Polynomial currPoly = Polynomial::parsePolynomial(curr);
155
156 res = res + currPoly;
157 }
158
159 return RewriteResponse(REWRITE_DONE, res.getNode());
160 }
161
162 RewriteResponse ArithRewriter::postRewriteMult(TNode t){
163 Assert(t.getKind()== kind::MULT);
164
165 Polynomial res = Polynomial::mkOne();
166
167 for(TNode::iterator i = t.begin(), end = t.end(); i != end; ++i){
168 Node curr = *i;
169 Polynomial currPoly = Polynomial::parsePolynomial(curr);
170
171 res = res * currPoly;
172 }
173
174 return RewriteResponse(REWRITE_DONE, res.getNode());
175 }
176
177 RewriteResponse ArithRewriter::postRewriteAtomConstantRHS(TNode t){
178 TNode left = t[0];
179 TNode right = t[1];
180
181
182 Comparison cmp = Comparison::mkComparison(t.getKind(), Polynomial::parsePolynomial(left), Constant(right));
183
184 if(cmp.isBoolean()){
185 return RewriteResponse(REWRITE_DONE, cmp.getNode());
186 }
187
188 if(cmp.getLeft().containsConstant()){
189 Monomial constantHead = cmp.getLeft().getHead();
190 Assert(constantHead.isConstant());
191
192 Constant constant = constantHead.getConstant();
193
194 Constant negativeConstantHead = -constant;
195
196 cmp = cmp.addConstant(negativeConstantHead);
197 }
198 Assert(!cmp.getLeft().containsConstant());
199
200 if(!cmp.getLeft().getHead().coefficientIsOne()){
201 Monomial constantHead = cmp.getLeft().getHead();
202 Assert(!constantHead.isConstant());
203 Constant constant = constantHead.getConstant();
204
205 Constant inverse = Constant::mkConstant(constant.getValue().inverse());
206
207 cmp = cmp.multiplyConstant(inverse);
208 }
209 Assert(cmp.getLeft().getHead().coefficientIsOne());
210
211 Assert(cmp.isBoolean() || cmp.isNormalForm());
212 return RewriteResponse(REWRITE_DONE, cmp.getNode());
213 }
214
215 RewriteResponse ArithRewriter::postRewriteAtom(TNode atom){
216 // left |><| right
217 TNode left = atom[0];
218 TNode right = atom[1];
219
220 if(right.getMetaKind() == kind::metakind::CONSTANT){
221 return postRewriteAtomConstantRHS(atom);
222 }else{
223 //Transform this to: (left - right) |><| 0
224 Node diff = makeSubtractionNode(left, right);
225 Node reduction = NodeManager::currentNM()->mkNode(atom.getKind(), diff, s_constants->d_ZERO_NODE);
226 return RewriteResponse(REWRITE_AGAIN_FULL, reduction);
227 }
228 }
229
230 RewriteResponse ArithRewriter::preRewriteAtom(TNode atom){
231 Assert(isAtom(atom));
232 NodeManager* currNM = NodeManager::currentNM();
233
234 if(atom.getKind() == kind::EQUAL) {
235 if(atom[0] == atom[1]) {
236 return RewriteResponse(REWRITE_DONE, currNM->mkConst(true));
237 }
238 }
239
240 Node reduction = atom;
241
242 if(atom[1].getMetaKind() != kind::metakind::CONSTANT){
243 // left |><| right
244 TNode left = atom[0];
245 TNode right = atom[1];
246
247 //Transform this to: (left - right) |><| 0
248 Node diff = makeSubtractionNode(left, right);
249 reduction = currNM->mkNode(atom.getKind(), diff, s_constants->d_ZERO_NODE);
250 }
251
252 if(reduction.getKind() == kind::GT){
253 Node leq = currNM->mkNode(kind::LEQ, reduction[0], reduction[1]);
254 reduction = currNM->mkNode(kind::NOT, leq);
255 }else if(reduction.getKind() == kind::LT){
256 Node geq = currNM->mkNode(kind::GEQ, reduction[0], reduction[1]);
257 reduction = currNM->mkNode(kind::NOT, geq);
258 }
259
260 return RewriteResponse(REWRITE_DONE, reduction);
261 }
262
263 RewriteResponse ArithRewriter::postRewrite(TNode t){
264 if(isTerm(t)){
265 RewriteResponse response = postRewriteTerm(t);
266 if(Debug.isOn("arith::rewriter") && response.status == REWRITE_DONE) {
267 Polynomial::parsePolynomial(response.node);
268 }
269 return response;
270 }else if(isAtom(t)){
271 RewriteResponse response = postRewriteAtom(t);
272 if(Debug.isOn("arith::rewriter") && response.status == REWRITE_DONE) {
273 Comparison::parseNormalForm(response.node);
274 }
275 return response;
276 }else{
277 Unreachable();
278 return RewriteResponse(REWRITE_DONE, Node::null());
279 }
280 }
281
282 RewriteResponse ArithRewriter::preRewrite(TNode t){
283 if(isTerm(t)){
284 return preRewriteTerm(t);
285 }else if(isAtom(t)){
286 return preRewriteAtom(t);
287 }else{
288 Unreachable();
289 return RewriteResponse(REWRITE_DONE, Node::null());
290 }
291 }
292
293 Node ArithRewriter::makeUnaryMinusNode(TNode n){
294 return NodeManager::currentNM()->mkNode(kind::MULT,s_constants->d_NEGATIVE_ONE_NODE,n);
295 }
296
297 Node ArithRewriter::makeSubtractionNode(TNode l, TNode r){
298 Node negR = makeUnaryMinusNode(r);
299 Node diff = NodeManager::currentNM()->mkNode(kind::PLUS, l, negR);
300
301 return diff;
302 }
303
304 RewriteResponse ArithRewriter::rewriteDivByConstant(TNode t, bool pre){
305 Assert(t.getKind()== kind::DIVISION);
306
307 Node left = t[0];
308 Node right = t[1];
309 Assert(right.getKind()== kind::CONST_RATIONAL);
310
311
312 const Rational& den = right.getConst<Rational>();
313
314 Assert(den != s_constants->d_ZERO);
315
316 Rational div = den.inverse();
317
318 Node result = mkRationalNode(div);
319
320 Node mult = NodeManager::currentNM()->mkNode(kind::MULT,left,result);
321 if(pre){
322 return RewriteResponse(REWRITE_DONE, mult);
323 }else{
324 return RewriteResponse(REWRITE_AGAIN, mult);
325 }
326 }