Partial merge of integers work; this is simple B&B and some pseudoboolean
[cvc5.git] / src / theory / arith / arith_rewriter.cpp
1 /********************* */
2 /*! \file arith_rewriter.cpp
3 ** \verbatim
4 ** Original author: taking
5 ** Major contributors: dejan
6 ** Minor contributors (to current version): mdeters
7 ** This file is part of the CVC4 prototype.
8 ** Copyright (c) 2009, 2010, 2011 The Analysis of Computer Systems Group (ACSys)
9 ** Courant Institute of Mathematical Sciences
10 ** New York University
11 ** See the file COPYING in the top-level source directory for licensing
12 ** information.\endverbatim
13 **
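14 ** \brief Rewriter for arithmetic terms and atoms.
15 **
16 ** Implements ArithRewriter's pre- and post-rewrite hooks: terms are normalized to Polynomial form and atoms to Comparison normal form (see normal_form.h).
17 ** \todo document this file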
18 **/
19
20 #include "theory/theory.h"
21 #include "theory/arith/normal_form.h"
22 #include "theory/arith/arith_rewriter.h"
23 #include "theory/arith/arith_utilities.h"
24
25 #include <vector>
26 #include <set>
27 #include <stack>
28
29 using namespace CVC4;
30 using namespace CVC4::theory;
31 using namespace CVC4::theory::arith;
32
33 bool isVariable(TNode t){
34 return t.getMetaKind() == kind::metakind::VARIABLE;
35 }
36
37 RewriteResponse ArithRewriter::rewriteConstant(TNode t){
38 Assert(t.getMetaKind() == kind::metakind::CONSTANT);
39 Node val = coerceToRationalNode(t);
40
41 return RewriteResponse(REWRITE_DONE, val);
42 }
43
44 RewriteResponse ArithRewriter::rewriteVariable(TNode t){
45 Assert(isVariable(t));
46
47 return RewriteResponse(REWRITE_DONE, t);
48 }
49
50 RewriteResponse ArithRewriter::rewriteMinus(TNode t, bool pre){
51 Assert(t.getKind()== kind::MINUS);
52
53 if(t[0] == t[1]){
54 Rational zero(0);
55 Node zeroNode = mkRationalNode(zero);
56 return RewriteResponse(REWRITE_DONE, zeroNode);
57 }
58
59 Node noMinus = makeSubtractionNode(t[0],t[1]);
60 if(pre){
61 return RewriteResponse(REWRITE_DONE, noMinus);
62 }else{
63 return RewriteResponse(REWRITE_AGAIN_FULL, noMinus);
64 }
65 }
66
67 RewriteResponse ArithRewriter::rewriteUMinus(TNode t, bool pre){
68 Assert(t.getKind()== kind::UMINUS);
69
70 Node noUminus = makeUnaryMinusNode(t[0]);
71 if(pre)
72 return RewriteResponse(REWRITE_DONE, noUminus);
73 else
74 return RewriteResponse(REWRITE_AGAIN, noUminus);
75 }
76
77 RewriteResponse ArithRewriter::preRewriteTerm(TNode t){
78 if(t.getMetaKind() == kind::metakind::CONSTANT){
79 return rewriteConstant(t);
80 }else if(isVariable(t)){
81 return rewriteVariable(t);
82 }else if(t.getKind() == kind::MINUS){
83 return rewriteMinus(t, true);
84 }else if(t.getKind() == kind::UMINUS){
85 return rewriteUMinus(t, true);
86 }else if(t.getKind() == kind::DIVISION){
87 if(t[0].getKind()== kind::CONST_RATIONAL){
88 return rewriteDivByConstant(t, true);
89 }else{
90 return RewriteResponse(REWRITE_DONE, t);
91 }
92 }else if(t.getKind() == kind::PLUS){
93 return preRewritePlus(t);
94 }else if(t.getKind() == kind::MULT){
95 return preRewriteMult(t);
96 }else{
97 Unreachable();
98 }
99 }
100 RewriteResponse ArithRewriter::postRewriteTerm(TNode t){
101 if(t.getMetaKind() == kind::metakind::CONSTANT){
102 return rewriteConstant(t);
103 }else if(isVariable(t)){
104 return rewriteVariable(t);
105 }else if(t.getKind() == kind::MINUS){
106 return rewriteMinus(t, false);
107 }else if(t.getKind() == kind::UMINUS){
108 return rewriteUMinus(t, false);
109 }else if(t.getKind() == kind::DIVISION){
110 return rewriteDivByConstant(t, false);
111 }else if(t.getKind() == kind::PLUS){
112 return postRewritePlus(t);
113 }else if(t.getKind() == kind::MULT){
114 return postRewriteMult(t);
115 }else{
116 Unreachable();
117 }
118 }
119
120 RewriteResponse ArithRewriter::preRewriteMult(TNode t){
121 Assert(t.getKind()== kind::MULT);
122
123 // Rewrite multiplications with a 0 argument and to 0
124 Integer intZero;
125
126 Rational qZero(0);
127
128 for(TNode::iterator i = t.begin(); i != t.end(); ++i) {
129 if((*i).getKind() == kind::CONST_RATIONAL) {
130 if((*i).getConst<Rational>() == qZero) {
131 return RewriteResponse(REWRITE_DONE, mkRationalNode(qZero));
132 }
133 } else if((*i).getKind() == kind::CONST_INTEGER) {
134 if((*i).getConst<Integer>() == intZero) {
135 if(t.getType().isInteger()) {
136 return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(intZero));
137 } else {
138 return RewriteResponse(REWRITE_DONE, mkRationalNode(qZero));
139 }
140 }
141 }
142 }
143 return RewriteResponse(REWRITE_DONE, t);
144 }
145 RewriteResponse ArithRewriter::preRewritePlus(TNode t){
146 Assert(t.getKind()== kind::PLUS);
147
148 return RewriteResponse(REWRITE_DONE, t);
149 }
150
151 RewriteResponse ArithRewriter::postRewritePlus(TNode t){
152 Assert(t.getKind()== kind::PLUS);
153
154 Polynomial res = Polynomial::mkZero();
155
156 for(TNode::iterator i = t.begin(), end = t.end(); i != end; ++i){
157 Node curr = *i;
158 Polynomial currPoly = Polynomial::parsePolynomial(curr);
159
160 res = res + currPoly;
161 }
162
163 return RewriteResponse(REWRITE_DONE, res.getNode());
164 }
165
166 RewriteResponse ArithRewriter::postRewriteMult(TNode t){
167 Assert(t.getKind()== kind::MULT);
168
169 Polynomial res = Polynomial::mkOne();
170
171 for(TNode::iterator i = t.begin(), end = t.end(); i != end; ++i){
172 Node curr = *i;
173 Polynomial currPoly = Polynomial::parsePolynomial(curr);
174
175 res = res * currPoly;
176 }
177
178 return RewriteResponse(REWRITE_DONE, res.getNode());
179 }
180
181 RewriteResponse ArithRewriter::postRewriteAtomConstantRHS(TNode t){
182 TNode left = t[0];
183 TNode right = t[1];
184
185 Comparison cmp = Comparison::mkComparison(t.getKind(), Polynomial::parsePolynomial(left), Constant(right));
186
187 if(cmp.isBoolean()){
188 return RewriteResponse(REWRITE_DONE, cmp.getNode());
189 }
190
191 if(cmp.getLeft().containsConstant()){
192 Monomial constantHead = cmp.getLeft().getHead();
193 Assert(constantHead.isConstant());
194
195 Constant constant = constantHead.getConstant();
196
197 Constant negativeConstantHead = -constant;
198
199 cmp = cmp.addConstant(negativeConstantHead);
200 }
201 Assert(!cmp.getLeft().containsConstant());
202
203 if(!cmp.getLeft().getHead().coefficientIsOne()){
204 Monomial constantHead = cmp.getLeft().getHead();
205 Assert(!constantHead.isConstant());
206 Constant constant = constantHead.getConstant();
207
208 Constant inverse = Constant::mkConstant(constant.getValue().inverse());
209
210 cmp = cmp.multiplyConstant(inverse);
211 }
212 Assert(cmp.getLeft().getHead().coefficientIsOne());
213
214 Assert(cmp.isBoolean() || cmp.isNormalForm());
215 return RewriteResponse(REWRITE_DONE, cmp.getNode());
216 }
217
218 RewriteResponse ArithRewriter::postRewriteAtom(TNode atom){
219 // left |><| right
220 TNode left = atom[0];
221 TNode right = atom[1];
222
223 if(right.getMetaKind() == kind::metakind::CONSTANT){
224 return postRewriteAtomConstantRHS(atom);
225 }else{
226 //Transform this to: (left - right) |><| 0
227 Node diff = makeSubtractionNode(left, right);
228 Rational qZero(0);
229 Node reduction = NodeManager::currentNM()->mkNode(atom.getKind(), diff, mkRationalNode(qZero));
230 return RewriteResponse(REWRITE_AGAIN_FULL, reduction);
231 }
232 }
233
234 RewriteResponse ArithRewriter::preRewriteAtom(TNode atom){
235 Assert(isAtom(atom));
236
237 NodeManager* currNM = NodeManager::currentNM();
238
239 if(atom.getKind() == kind::EQUAL) {
240 if(atom[0] == atom[1]) {
241 return RewriteResponse(REWRITE_DONE, currNM->mkConst(true));
242 }
243 }
244
245 Node reduction = atom;
246
247 if(atom[1].getMetaKind() != kind::metakind::CONSTANT) {
248 // left |><| right
249 TNode left = atom[0];
250 TNode right = atom[1];
251
252 //Transform this to: (left - right) |><| 0
253 Node diff = makeSubtractionNode(left, right);
254 Rational qZero(0);
255 reduction = currNM->mkNode(atom.getKind(), diff, mkRationalNode(qZero));
256 }
257
258 if(reduction.getKind() == kind::GT){
259 Node leq = currNM->mkNode(kind::LEQ, reduction[0], reduction[1]);
260 reduction = currNM->mkNode(kind::NOT, leq);
261 }else if(reduction.getKind() == kind::LT){
262 Node geq = currNM->mkNode(kind::GEQ, reduction[0], reduction[1]);
263 reduction = currNM->mkNode(kind::NOT, geq);
264 }
265 /* BREADCRUMB : Move this rewrite into preprocessing
266 else if( Options::current()->rewriteArithEqualities && reduction.getKind() == kind::EQUAL){
267 Node geq = currNM->mkNode(kind::GEQ, reduction[0], reduction[1]);
268 Node leq = currNM->mkNode(kind::LEQ, reduction[0], reduction[1]);
269 reduction = currNM->mkNode(kind::AND, geq, leq);
270 }
271 */
272
273
274 return RewriteResponse(REWRITE_DONE, reduction);
275 }
276
277 RewriteResponse ArithRewriter::postRewrite(TNode t){
278 if(isTerm(t)){
279 RewriteResponse response = postRewriteTerm(t);
280 if(Debug.isOn("arith::rewriter") && response.status == REWRITE_DONE) {
281 Polynomial::parsePolynomial(response.node);
282 }
283 return response;
284 }else if(isAtom(t)){
285 RewriteResponse response = postRewriteAtom(t);
286 if(Debug.isOn("arith::rewriter") && response.status == REWRITE_DONE) {
287 Comparison::parseNormalForm(response.node);
288 }
289 return response;
290 }else{
291 Unreachable();
292 return RewriteResponse(REWRITE_DONE, Node::null());
293 }
294 }
295
296 RewriteResponse ArithRewriter::preRewrite(TNode t){
297 if(isTerm(t)){
298 return preRewriteTerm(t);
299 }else if(isAtom(t)){
300 return preRewriteAtom(t);
301 }else{
302 Unreachable();
303 return RewriteResponse(REWRITE_DONE, Node::null());
304 }
305 }
306
307 Node ArithRewriter::makeUnaryMinusNode(TNode n){
308 Rational qNegOne(-1);
309 return NodeManager::currentNM()->mkNode(kind::MULT, mkRationalNode(qNegOne),n);
310 }
311
312 Node ArithRewriter::makeSubtractionNode(TNode l, TNode r){
313 Node negR = makeUnaryMinusNode(r);
314 Node diff = NodeManager::currentNM()->mkNode(kind::PLUS, l, negR);
315
316 return diff;
317 }
318
319 RewriteResponse ArithRewriter::rewriteDivByConstant(TNode t, bool pre){
320 Assert(t.getKind()== kind::DIVISION);
321
322 Node left = t[0];
323 Node right = t[1];
324 Assert(right.getKind()== kind::CONST_RATIONAL);
325
326
327 const Rational& den = right.getConst<Rational>();
328
329 Assert(den != Rational(0));
330
331 Rational div = den.inverse();
332
333 Node result = mkRationalNode(div);
334
335 Node mult = NodeManager::currentNM()->mkNode(kind::MULT,left,result);
336 if(pre){
337 return RewriteResponse(REWRITE_DONE, mult);
338 }else{
339 return RewriteResponse(REWRITE_AGAIN, mult);
340 }
341 }