Improved support for division by zero. This adds the *_TOTAL kinds and uninterpreted...
[cvc5.git] / src / theory / arith / arith_rewriter.cpp
1 /********************* */
2 /*! \file arith_rewriter.cpp
3 ** \verbatim
4 ** Original author: taking
5 ** Major contributors: mdeters
6 ** Minor contributors (to current version): dejan
7 ** This file is part of the CVC4 prototype.
8 ** Copyright (c) 2009-2012 New York University and The University of Iowa
9 ** See the file COPYING in the top-level source directory for licensing
10 ** information.\endverbatim
11 **
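12 ** \brief Rewriter for the theory of arithmetic
13 **
14 ** Pre- and post-rewrites arithmetic terms and atoms; post-rewriting brings them into the Polynomial / Comparison normal form of normal_form.h.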
15 ** \todo document this file
16 **/
17
18 #include "theory/theory.h"
19 #include "theory/arith/normal_form.h"
20 #include "theory/arith/arith_rewriter.h"
21 #include "theory/arith/arith_utilities.h"
22
23 #include <vector>
24 #include <set>
25 #include <stack>
26
27 namespace CVC4 {
28 namespace theory {
29 namespace arith {
30
31 bool ArithRewriter::isAtom(TNode n) {
32 return arith::isRelationOperator(n.getKind());
33 }
34
35 RewriteResponse ArithRewriter::rewriteConstant(TNode t){
36 Assert(t.isConst());
37 Assert(t.getKind() == kind::CONST_RATIONAL);
38
39 return RewriteResponse(REWRITE_DONE, t);
40 }
41
42 RewriteResponse ArithRewriter::rewriteVariable(TNode t){
43 Assert(t.isVar());
44
45 return RewriteResponse(REWRITE_DONE, t);
46 }
47
48 RewriteResponse ArithRewriter::rewriteMinus(TNode t, bool pre){
49 Assert(t.getKind()== kind::MINUS);
50
51 if(pre){
52 if(t[0] == t[1]){
53 Rational zero(0);
54 Node zeroNode = mkRationalNode(zero);
55 return RewriteResponse(REWRITE_DONE, zeroNode);
56 }else{
57 Node noMinus = makeSubtractionNode(t[0],t[1]);
58 return RewriteResponse(REWRITE_DONE, noMinus);
59 }
60 }else{
61 Polynomial minuend = Polynomial::parsePolynomial(t[0]);
62 Polynomial subtrahend = Polynomial::parsePolynomial(t[1]);
63 Polynomial diff = minuend - subtrahend;
64 return RewriteResponse(REWRITE_DONE, diff.getNode());
65 }
66 }
67
68 RewriteResponse ArithRewriter::rewriteUMinus(TNode t, bool pre){
69 Assert(t.getKind()== kind::UMINUS);
70
71 Node noUminus = makeUnaryMinusNode(t[0]);
72 if(pre)
73 return RewriteResponse(REWRITE_DONE, noUminus);
74 else
75 return RewriteResponse(REWRITE_AGAIN, noUminus);
76 }
77
78 RewriteResponse ArithRewriter::preRewriteTerm(TNode t){
79 if(t.isConst()){
80 return rewriteConstant(t);
81 }else if(t.isVar()){
82 return rewriteVariable(t);
83 }else if(t.getKind() == kind::MINUS){
84 return rewriteMinus(t, true);
85 }else if(t.getKind() == kind::UMINUS){
86 return rewriteUMinus(t, true);
87 }else if(t.getKind() == kind::DIVISION){
88 return RewriteResponse(REWRITE_DONE, t); // wait until t[1] is rewritten
89 }else if(t.getKind() == kind::DIVISION_TOTAL){
90 if(t[1].getKind()== kind::CONST_RATIONAL &&
91 t[1].getConst<Rational>().isZero()){
92 return RewriteResponse(REWRITE_DONE, mkRationalNode(0));
93 }else{
94 return RewriteResponse(REWRITE_DONE, t); // wait until t[1] is rewritten
95 }
96 }else if(t.getKind() == kind::PLUS){
97 return preRewritePlus(t);
98 }else if(t.getKind() == kind::MULT){
99 return preRewriteMult(t);
100 }else if(t.getKind() == kind::INTS_DIVISION){
101 Rational intOne(1);
102 if(t[1].getKind()== kind::CONST_RATIONAL &&
103 t[1].getConst<Rational>().isOne()){
104 return RewriteResponse(REWRITE_AGAIN, t[0]);
105 }else{
106 return RewriteResponse(REWRITE_DONE, t);
107 }
108 }else if(t.getKind() == kind::INTS_DIVISION_TOTAL){
109 if(t[1].getKind()== kind::CONST_RATIONAL){
110 Rational intOne(1), intZero(0);
111 if(t[1].getConst<Rational>().isOne()){
112 return RewriteResponse(REWRITE_AGAIN, t[0]);
113 } else if(t[1].getConst<Rational>().isZero()){
114 return RewriteResponse(REWRITE_DONE, mkRationalNode(0));
115 }
116 }
117 return RewriteResponse(REWRITE_DONE, t);
118 }else if(t.getKind() == kind::INTS_MODULUS){
119 Rational intOne(1);
120 if(t[1].getKind()== kind::CONST_RATIONAL &&
121 t[1].getConst<Rational>().isOne()){
122 return RewriteResponse(REWRITE_AGAIN, mkRationalNode(0));
123 }else{
124 return RewriteResponse(REWRITE_DONE, t);
125 }
126 }else if(t.getKind() == kind::INTS_MODULUS_TOTAL){
127 if(t[1].getKind()== kind::CONST_RATIONAL){
128 if(t[1].getConst<Rational>().isOne() || t[1].getConst<Rational>().isZero()){
129 return RewriteResponse(REWRITE_DONE, mkRationalNode(0));
130 }
131 }
132 return RewriteResponse(REWRITE_DONE, t);
133 }else{
134 Unreachable();
135 }
136 }
137 RewriteResponse ArithRewriter::postRewriteTerm(TNode t){
138 if(t.isConst()){
139 return rewriteConstant(t);
140 }else if(t.isVar()){
141 return rewriteVariable(t);
142 }else if(t.getKind() == kind::MINUS){
143 return rewriteMinus(t, false);
144 }else if(t.getKind() == kind::UMINUS){
145 return rewriteUMinus(t, false);
146 }else if(t.getKind() == kind::DIVISION ||
147 t.getKind() == kind::DIVISION_TOTAL){
148 return rewriteDiv(t, false);
149 }else if(t.getKind() == kind::PLUS){
150 return postRewritePlus(t);
151 }else if(t.getKind() == kind::MULT){
152 return postRewriteMult(t);
153 }else if(t.getKind() == kind::INTS_DIVISION ||
154 t.getKind() == kind::INTS_MODULUS){
155 return RewriteResponse(REWRITE_DONE, t);
156 }else if(t.getKind() == kind::INTS_DIVISION_TOTAL ||
157 t.getKind() == kind::INTS_MODULUS_TOTAL){
158 if(t[1].getKind() == kind::CONST_RATIONAL &&
159 t[1].getConst<Rational>().isZero()){
160 return RewriteResponse(REWRITE_DONE, mkRationalNode(0));
161 }else{
162 return RewriteResponse(REWRITE_DONE, t);
163 }
164 }else{
165 Unreachable();
166 }
167 }
168
169 RewriteResponse ArithRewriter::preRewriteMult(TNode t){
170 Assert(t.getKind()== kind::MULT);
171
172 // Rewrite multiplications with a 0 argument and to 0
173 Rational qZero(0);
174
175 for(TNode::iterator i = t.begin(); i != t.end(); ++i) {
176 if((*i).getKind() == kind::CONST_RATIONAL) {
177 if((*i).getConst<Rational>() == qZero) {
178 return RewriteResponse(REWRITE_DONE, mkRationalNode(qZero));
179 }
180 }
181 }
182 return RewriteResponse(REWRITE_DONE, t);
183 }
184 RewriteResponse ArithRewriter::preRewritePlus(TNode t){
185 Assert(t.getKind()== kind::PLUS);
186
187 return RewriteResponse(REWRITE_DONE, t);
188 }
189
190 RewriteResponse ArithRewriter::postRewritePlus(TNode t){
191 Assert(t.getKind()== kind::PLUS);
192
193 Polynomial res = Polynomial::mkZero();
194
195 for(TNode::iterator i = t.begin(), end = t.end(); i != end; ++i){
196 Node curr = *i;
197 Polynomial currPoly = Polynomial::parsePolynomial(curr);
198
199 res = res + currPoly;
200 }
201
202 return RewriteResponse(REWRITE_DONE, res.getNode());
203 }
204
205 RewriteResponse ArithRewriter::postRewriteMult(TNode t){
206 Assert(t.getKind()== kind::MULT);
207
208 Polynomial res = Polynomial::mkOne();
209
210 for(TNode::iterator i = t.begin(), end = t.end(); i != end; ++i){
211 Node curr = *i;
212 Polynomial currPoly = Polynomial::parsePolynomial(curr);
213
214 res = res * currPoly;
215 }
216
217 return RewriteResponse(REWRITE_DONE, res.getNode());
218 }
219
220 // RewriteResponse ArithRewriter::postRewriteAtomConstantRHS(TNode t){
221 // TNode left = t[0];
222 // TNode right = t[1];
223
224 // Polynomial pLeft = Polynomial::parsePolynomial(left);
225
226
227 // Comparison cmp = Comparison::mkComparison(t.getKind(), Polynomial::parsePolynomial(left), Constant(right));
228
229 // Assert(cmp.isNormalForm());
230 // return RewriteResponse(REWRITE_DONE, cmp.getNode());
231 // }
232
233 RewriteResponse ArithRewriter::postRewriteAtom(TNode atom){
234 // left |><| right
235 TNode left = atom[0];
236 TNode right = atom[1];
237
238 Polynomial pleft = Polynomial::parsePolynomial(left);
239 Polynomial pright = Polynomial::parsePolynomial(right);
240
241 Comparison cmp = Comparison::mkComparison(atom.getKind(), pleft, pright);
242 Assert(cmp.isNormalForm());
243 return RewriteResponse(REWRITE_DONE, cmp.getNode());
244 }
245
246 RewriteResponse ArithRewriter::preRewriteAtom(TNode atom){
247 Assert(isAtom(atom));
248
249 NodeManager* currNM = NodeManager::currentNM();
250
251 if(atom.getKind() == kind::EQUAL) {
252 if(atom[0] == atom[1]) {
253 return RewriteResponse(REWRITE_DONE, currNM->mkConst(true));
254 }
255 }else if(atom.getKind() == kind::GT){
256 Node leq = currNM->mkNode(kind::LEQ, atom[0], atom[1]);
257 return RewriteResponse(REWRITE_DONE, currNM->mkNode(kind::NOT, leq));
258 }else if(atom.getKind() == kind::LT){
259 Node geq = currNM->mkNode(kind::GEQ, atom[0], atom[1]);
260 return RewriteResponse(REWRITE_DONE, currNM->mkNode(kind::NOT, geq));
261 }
262
263 return RewriteResponse(REWRITE_DONE, atom);
264 }
265
266 RewriteResponse ArithRewriter::postRewrite(TNode t){
267 if(isTerm(t)){
268 RewriteResponse response = postRewriteTerm(t);
269 if(Debug.isOn("arith::rewriter") && response.status == REWRITE_DONE) {
270 Polynomial::parsePolynomial(response.node);
271 }
272 return response;
273 }else if(isAtom(t)){
274 RewriteResponse response = postRewriteAtom(t);
275 if(Debug.isOn("arith::rewriter") && response.status == REWRITE_DONE) {
276 Comparison::parseNormalForm(response.node);
277 }
278 return response;
279 }else{
280 Unreachable();
281 return RewriteResponse(REWRITE_DONE, Node::null());
282 }
283 }
284
285 RewriteResponse ArithRewriter::preRewrite(TNode t){
286 if(isTerm(t)){
287 return preRewriteTerm(t);
288 }else if(isAtom(t)){
289 return preRewriteAtom(t);
290 }else{
291 Unreachable();
292 return RewriteResponse(REWRITE_DONE, Node::null());
293 }
294 }
295
296 Node ArithRewriter::makeUnaryMinusNode(TNode n){
297 Rational qNegOne(-1);
298 return NodeManager::currentNM()->mkNode(kind::MULT, mkRationalNode(qNegOne),n);
299 }
300
301 Node ArithRewriter::makeSubtractionNode(TNode l, TNode r){
302 Node negR = makeUnaryMinusNode(r);
303 Node diff = NodeManager::currentNM()->mkNode(kind::PLUS, l, negR);
304
305 return diff;
306 }
307
308 RewriteResponse ArithRewriter::rewriteDiv(TNode t, bool pre){
309 Assert(t.getKind()== kind::DIVISION || t.getKind() == kind::DIVISION_TOTAL);
310
311 Node left = t[0];
312 Node right = t[1];
313 if(right.getKind() == kind::CONST_RATIONAL){
314 const Rational& den = right.getConst<Rational>();
315
316 if(den.isZero()){
317 if(t.getKind() == kind::DIVISION_TOTAL){
318 return RewriteResponse(REWRITE_DONE, mkRationalNode(0));
319 }else{
320 return RewriteResponse(REWRITE_DONE, t);
321 }
322 }
323 Assert(den != Rational(0));
324
325 Rational div = den.inverse();
326
327 Node result = mkRationalNode(div);
328
329 Node mult = NodeManager::currentNM()->mkNode(kind::MULT,left,result);
330 if(pre){
331 return RewriteResponse(REWRITE_DONE, mult);
332 }else{
333 return RewriteResponse(REWRITE_AGAIN, mult);
334 }
335 }else{
336 return RewriteResponse(REWRITE_DONE, t);
337 }
338 }
339
340 }/* CVC4::theory::arith namespace */
341 }/* CVC4::theory namespace */
342 }/* CVC4 namespace */