Fixes for the arithmetic normal form and rewriter to handle arbitrary constants for...
[cvc5.git] / src / theory / arith / arith_rewriter.cpp
1 /********************* */
2 /*! \file arith_rewriter.cpp
3 ** \verbatim
4 ** Original author: taking
5 ** Major contributors: mdeters
6 ** Minor contributors (to current version): dejan
7 ** This file is part of the CVC4 prototype.
8 ** Copyright (c) 2009-2012 New York University and The University of Iowa
9 ** See the file COPYING in the top-level source directory for licensing
10 ** information.\endverbatim
11 **
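12 ** \brief Rewriter for arithmetic terms and atoms (ArithRewriter).
13 **
14 ** Pre- and post-rewriting of arithmetic terms into Polynomial normal form and of arithmetic relations into Comparison normal form.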
15 ** \todo document this file
16 **/
17
18 #include "theory/theory.h"
19 #include "theory/arith/normal_form.h"
20 #include "theory/arith/arith_rewriter.h"
21 #include "theory/arith/arith_utilities.h"
22
23 #include <vector>
24 #include <set>
25 #include <stack>
26
27 namespace CVC4 {
28 namespace theory {
29 namespace arith {
30
31 bool ArithRewriter::isAtom(TNode n) {
32 return arith::isRelationOperator(n.getKind());
33 }
34
35 RewriteResponse ArithRewriter::rewriteConstant(TNode t){
36 Assert(t.isConst());
37 Assert(t.getKind() == kind::CONST_RATIONAL);
38
39 return RewriteResponse(REWRITE_DONE, t);
40 }
41
42 RewriteResponse ArithRewriter::rewriteVariable(TNode t){
43 Assert(t.isVar());
44
45 return RewriteResponse(REWRITE_DONE, t);
46 }
47
48 RewriteResponse ArithRewriter::rewriteMinus(TNode t, bool pre){
49 Assert(t.getKind()== kind::MINUS);
50
51 if(pre){
52 if(t[0] == t[1]){
53 Rational zero(0);
54 Node zeroNode = mkRationalNode(zero);
55 return RewriteResponse(REWRITE_DONE, zeroNode);
56 }else{
57 Node noMinus = makeSubtractionNode(t[0],t[1]);
58 return RewriteResponse(REWRITE_DONE, noMinus);
59 }
60 }else{
61 Polynomial minuend = Polynomial::parsePolynomial(t[0]);
62 Polynomial subtrahend = Polynomial::parsePolynomial(t[1]);
63 Polynomial diff = minuend - subtrahend;
64 return RewriteResponse(REWRITE_DONE, diff.getNode());
65 }
66 }
67
68 RewriteResponse ArithRewriter::rewriteUMinus(TNode t, bool pre){
69 Assert(t.getKind()== kind::UMINUS);
70
71 Node noUminus = makeUnaryMinusNode(t[0]);
72 if(pre)
73 return RewriteResponse(REWRITE_DONE, noUminus);
74 else
75 return RewriteResponse(REWRITE_AGAIN, noUminus);
76 }
77
78 RewriteResponse ArithRewriter::preRewriteTerm(TNode t){
79 if(t.isConst()){
80 return rewriteConstant(t);
81 }else if(t.isVar()){
82 return rewriteVariable(t);
83 }else{
84 switch(t.getKind()){
85 case kind::MINUS:
86 return rewriteMinus(t, true);
87 case kind::UMINUS:
88 return rewriteUMinus(t, true);
89 case kind::DIVISION:
90 return rewriteDiv(t,true);
91 case kind::DIVISION_TOTAL:
92 return rewriteDivTotal(t,true);
93 case kind::PLUS:
94 return preRewritePlus(t);
95 case kind::MULT:
96 return preRewriteMult(t);
97 //case kind::INTS_DIVISION:
98 //case kind::INTS_MODULUS:
99 case kind::INTS_DIVISION_TOTAL:
100 case kind::INTS_MODULUS_TOTAL:
101 return rewriteIntsDivModTotal(t,true);
102 default:
103 Unreachable();
104 }
105 }
106 }
107 RewriteResponse ArithRewriter::postRewriteTerm(TNode t){
108 if(t.isConst()){
109 return rewriteConstant(t);
110 }else if(t.isVar()){
111 return rewriteVariable(t);
112 }else{
113 switch(t.getKind()){
114 case kind::MINUS:
115 return rewriteMinus(t, false);
116 case kind::UMINUS:
117 return rewriteUMinus(t, false);
118 case kind::DIVISION:
119 return rewriteDiv(t, false);
120 case kind::DIVISION_TOTAL:
121 return rewriteDivTotal(t, false);
122 case kind::PLUS:
123 return postRewritePlus(t);
124 case kind::MULT:
125 return postRewriteMult(t);
126 //case kind::INTS_DIVISION:
127 //case kind::INTS_MODULUS:
128 case kind::INTS_DIVISION_TOTAL:
129 case kind::INTS_MODULUS_TOTAL:
130 return rewriteIntsDivModTotal(t, false);
131 default:
132 Unreachable();
133 }
134 }
135 }
136
137
138 RewriteResponse ArithRewriter::preRewriteMult(TNode t){
139 Assert(t.getKind()== kind::MULT);
140
141 // Rewrite multiplications with a 0 argument and to 0
142 Rational qZero(0);
143
144 for(TNode::iterator i = t.begin(); i != t.end(); ++i) {
145 if((*i).getKind() == kind::CONST_RATIONAL) {
146 if((*i).getConst<Rational>() == qZero) {
147 return RewriteResponse(REWRITE_DONE, mkRationalNode(qZero));
148 }
149 }
150 }
151 return RewriteResponse(REWRITE_DONE, t);
152 }
153 RewriteResponse ArithRewriter::preRewritePlus(TNode t){
154 Assert(t.getKind()== kind::PLUS);
155
156 return RewriteResponse(REWRITE_DONE, t);
157 }
158
159 RewriteResponse ArithRewriter::postRewritePlus(TNode t){
160 Assert(t.getKind()== kind::PLUS);
161
162 Polynomial res = Polynomial::mkZero();
163
164 for(TNode::iterator i = t.begin(), end = t.end(); i != end; ++i){
165 Node curr = *i;
166 Polynomial currPoly = Polynomial::parsePolynomial(curr);
167
168 res = res + currPoly;
169 }
170
171 return RewriteResponse(REWRITE_DONE, res.getNode());
172 }
173
174 RewriteResponse ArithRewriter::postRewriteMult(TNode t){
175 Assert(t.getKind()== kind::MULT);
176
177 Polynomial res = Polynomial::mkOne();
178
179 for(TNode::iterator i = t.begin(), end = t.end(); i != end; ++i){
180 Node curr = *i;
181 Polynomial currPoly = Polynomial::parsePolynomial(curr);
182
183 res = res * currPoly;
184 }
185
186 return RewriteResponse(REWRITE_DONE, res.getNode());
187 }
188
189 RewriteResponse ArithRewriter::postRewriteAtom(TNode atom){
190 // left |><| right
191 TNode left = atom[0];
192 TNode right = atom[1];
193
194 Polynomial pleft = Polynomial::parsePolynomial(left);
195 Polynomial pright = Polynomial::parsePolynomial(right);
196
197 Comparison cmp = Comparison::mkComparison(atom.getKind(), pleft, pright);
198 Assert(cmp.isNormalForm());
199 return RewriteResponse(REWRITE_DONE, cmp.getNode());
200 }
201
202 RewriteResponse ArithRewriter::preRewriteAtom(TNode atom){
203 Assert(isAtom(atom));
204
205 NodeManager* currNM = NodeManager::currentNM();
206
207 if(atom.getKind() == kind::EQUAL) {
208 if(atom[0] == atom[1]) {
209 return RewriteResponse(REWRITE_DONE, currNM->mkConst(true));
210 }
211 }else if(atom.getKind() == kind::GT){
212 Node leq = currNM->mkNode(kind::LEQ, atom[0], atom[1]);
213 return RewriteResponse(REWRITE_DONE, currNM->mkNode(kind::NOT, leq));
214 }else if(atom.getKind() == kind::LT){
215 Node geq = currNM->mkNode(kind::GEQ, atom[0], atom[1]);
216 return RewriteResponse(REWRITE_DONE, currNM->mkNode(kind::NOT, geq));
217 }
218
219 return RewriteResponse(REWRITE_DONE, atom);
220 }
221
222 RewriteResponse ArithRewriter::postRewrite(TNode t){
223 if(isTerm(t)){
224 RewriteResponse response = postRewriteTerm(t);
225 if(Debug.isOn("arith::rewriter") && response.status == REWRITE_DONE) {
226 Polynomial::parsePolynomial(response.node);
227 }
228 return response;
229 }else if(isAtom(t)){
230 RewriteResponse response = postRewriteAtom(t);
231 if(Debug.isOn("arith::rewriter") && response.status == REWRITE_DONE) {
232 Comparison::parseNormalForm(response.node);
233 }
234 return response;
235 }else{
236 Unreachable();
237 return RewriteResponse(REWRITE_DONE, Node::null());
238 }
239 }
240
241 RewriteResponse ArithRewriter::preRewrite(TNode t){
242 if(isTerm(t)){
243 return preRewriteTerm(t);
244 }else if(isAtom(t)){
245 return preRewriteAtom(t);
246 }else{
247 Unreachable();
248 return RewriteResponse(REWRITE_DONE, Node::null());
249 }
250 }
251
252 Node ArithRewriter::makeUnaryMinusNode(TNode n){
253 Rational qNegOne(-1);
254 return NodeManager::currentNM()->mkNode(kind::MULT, mkRationalNode(qNegOne),n);
255 }
256
257 Node ArithRewriter::makeSubtractionNode(TNode l, TNode r){
258 Node negR = makeUnaryMinusNode(r);
259 Node diff = NodeManager::currentNM()->mkNode(kind::PLUS, l, negR);
260
261 return diff;
262 }
263 RewriteResponse ArithRewriter::rewriteDiv(TNode t, bool pre){
264 Assert(t.getKind()== kind::DIVISION);
265
266 Node left = t[0];
267 Node right = t[1];
268
269 if(right.getKind() == kind::CONST_RATIONAL &&
270 left.getKind() != kind::CONST_RATIONAL){
271
272 const Rational& den = right.getConst<Rational>();
273
274 Assert(!den.isZero());
275
276 Rational div = den.inverse();
277 Node result = mkRationalNode(div);
278 Node mult = NodeManager::currentNM()->mkNode(kind::MULT,left,result);
279 if(pre){
280 return RewriteResponse(REWRITE_DONE, mult);
281 }else{
282 return RewriteResponse(REWRITE_AGAIN, mult);
283 }
284 }
285
286 if(pre){
287 if(right.getKind() != kind::CONST_RATIONAL ||
288 left.getKind() != kind::CONST_RATIONAL){
289 return RewriteResponse(REWRITE_DONE, t);
290 }
291 }
292
293 Assert(right.getKind() == kind::CONST_RATIONAL);
294 Assert(left.getKind() == kind::CONST_RATIONAL);
295
296 const Rational& den = right.getConst<Rational>();
297
298 Assert(!den.isZero());
299
300 const Rational& num = left.getConst<Rational>();
301 Rational div = num / den;
302 Node result = mkRationalNode(div);
303 return RewriteResponse(REWRITE_DONE, result);
304 }
305
306 RewriteResponse ArithRewriter::rewriteDivTotal(TNode t, bool pre){
307 Assert(t.getKind() == kind::DIVISION_TOTAL);
308
309
310 Node left = t[0];
311 Node right = t[1];
312 if(right.getKind() == kind::CONST_RATIONAL){
313 const Rational& den = right.getConst<Rational>();
314
315 if(den.isZero()){
316 return RewriteResponse(REWRITE_DONE, mkRationalNode(0));
317 }
318 Assert(den != Rational(0));
319
320 if(left.getKind() == kind::CONST_RATIONAL){
321 const Rational& num = left.getConst<Rational>();
322 Rational div = num / den;
323 Node result = mkRationalNode(div);
324 return RewriteResponse(REWRITE_DONE, result);
325 }
326
327 Rational div = den.inverse();
328
329 Node result = mkRationalNode(div);
330
331 Node mult = NodeManager::currentNM()->mkNode(kind::MULT,left,result);
332 if(pre){
333 return RewriteResponse(REWRITE_DONE, mult);
334 }else{
335 return RewriteResponse(REWRITE_AGAIN, mult);
336 }
337 }else{
338 return RewriteResponse(REWRITE_DONE, t);
339 }
340 }
341
342 RewriteResponse ArithRewriter::rewriteIntsDivModTotal(TNode t, bool pre){
343 Kind k = t.getKind();
344 // Assert(k == kind::INTS_MODULUS || k == kind::INTS_MODULUS_TOTAL ||
345 // k == kind::INTS_DIVISION || k == kind::INTS_DIVISION_TOTAL);
346
347 //Leaving the function as before (INTS_MODULUS can be handled),
348 // but restricting its use here
349 Assert(k == kind::INTS_MODULUS_TOTAL || k == kind::INTS_DIVISION_TOTAL);
350 TNode n = t[0], d = t[1];
351 bool dIsConstant = d.getKind() == kind::CONST_RATIONAL;
352 if(dIsConstant && d.getConst<Rational>().isZero()){
353 if(k == kind::INTS_MODULUS_TOTAL || k == kind::INTS_DIVISION_TOTAL){
354 return RewriteResponse(REWRITE_DONE, mkRationalNode(0));
355 }else{
356 // Do nothing for k == INTS_MODULUS
357 return RewriteResponse(REWRITE_DONE, t);
358 }
359 }else if(dIsConstant && d.getConst<Rational>().isOne()){
360 if(k == kind::INTS_MODULUS || k == kind::INTS_MODULUS_TOTAL){
361 return RewriteResponse(REWRITE_DONE, mkRationalNode(0));
362 }else{
363 Assert(k == kind::INTS_DIVISION || k == kind::INTS_DIVISION_TOTAL);
364 return RewriteResponse(REWRITE_AGAIN, n);
365 }
366 }else if(dIsConstant && n.getKind() == kind::CONST_RATIONAL){
367 Assert(d.getConst<Rational>().isIntegral());
368 Assert(n.getConst<Rational>().isIntegral());
369 Assert(!d.getConst<Rational>().isZero());
370 Integer di = d.getConst<Rational>().getNumerator();
371 Integer ni = n.getConst<Rational>().getNumerator();
372
373 bool isDiv = (k == kind::INTS_DIVISION || k == kind::INTS_DIVISION_TOTAL);
374
375 Integer result = isDiv ? ni.floorDivideQuotient(di) : ni.floorDivideRemainder(di);
376
377 Node resultNode = mkRationalNode(Rational(result));
378 return RewriteResponse(REWRITE_DONE, resultNode);
379 }else{
380 return RewriteResponse(REWRITE_DONE, t);
381 }
382 }
383
384 }/* CVC4::theory::arith namespace */
385 }/* CVC4::theory namespace */
386 }/* CVC4 namespace */