Regenerated copyrights: canonicalized names, no emails
[cvc5.git] / src / theory / arith / arith_rewriter.cpp
1 /********************* */
2 /*! \file arith_rewriter.cpp
3 ** \verbatim
4 ** Original author: Tim King
5 ** Major contributors: none
6 ** Minor contributors (to current version): Morgan Deters, Dejan Jovanovic
7 ** This file is part of the CVC4 project.
8 ** Copyright (c) 2009-2013 New York University and The University of Iowa
9 ** See the file COPYING in the top-level source directory for licensing
10 ** information.\endverbatim
11 **
12 ** \brief Rewriter for the theory of arithmetic.
13 **
14 ** Pre- and post-rewrite rules that bring arithmetic terms and atoms (relations) into the Polynomial / Comparison normal form.
15 ** \todo document this file
16 **/
17
18 #include "theory/theory.h"
19 #include "theory/arith/normal_form.h"
20 #include "theory/arith/arith_rewriter.h"
21 #include "theory/arith/arith_utilities.h"
22
23 #include <vector>
24 #include <set>
25 #include <stack>
26
27 namespace CVC4 {
28 namespace theory {
29 namespace arith {
30
31 bool ArithRewriter::isAtom(TNode n) {
32 return arith::isRelationOperator(n.getKind());
33 }
34
35 RewriteResponse ArithRewriter::rewriteConstant(TNode t){
36 Assert(t.isConst());
37 Assert(t.getKind() == kind::CONST_RATIONAL);
38
39 return RewriteResponse(REWRITE_DONE, t);
40 }
41
42 RewriteResponse ArithRewriter::rewriteVariable(TNode t){
43 Assert(t.isVar());
44
45 return RewriteResponse(REWRITE_DONE, t);
46 }
47
48 RewriteResponse ArithRewriter::rewriteMinus(TNode t, bool pre){
49 Assert(t.getKind()== kind::MINUS);
50
51 if(pre){
52 if(t[0] == t[1]){
53 Rational zero(0);
54 Node zeroNode = mkRationalNode(zero);
55 return RewriteResponse(REWRITE_DONE, zeroNode);
56 }else{
57 Node noMinus = makeSubtractionNode(t[0],t[1]);
58 return RewriteResponse(REWRITE_DONE, noMinus);
59 }
60 }else{
61 Polynomial minuend = Polynomial::parsePolynomial(t[0]);
62 Polynomial subtrahend = Polynomial::parsePolynomial(t[1]);
63 Polynomial diff = minuend - subtrahend;
64 return RewriteResponse(REWRITE_DONE, diff.getNode());
65 }
66 }
67
68 RewriteResponse ArithRewriter::rewriteUMinus(TNode t, bool pre){
69 Assert(t.getKind()== kind::UMINUS);
70
71 if(t[0].getKind() == kind::CONST_RATIONAL){
72 Rational neg = -(t[0].getConst<Rational>());
73 return RewriteResponse(REWRITE_DONE, mkRationalNode(neg));
74 }
75
76 Node noUminus = makeUnaryMinusNode(t[0]);
77 if(pre)
78 return RewriteResponse(REWRITE_DONE, noUminus);
79 else
80 return RewriteResponse(REWRITE_AGAIN, noUminus);
81 }
82
83 RewriteResponse ArithRewriter::preRewriteTerm(TNode t){
84 if(t.isConst()){
85 return rewriteConstant(t);
86 }else if(t.isVar()){
87 return rewriteVariable(t);
88 }else{
89 switch(Kind k = t.getKind()){
90 case kind::MINUS:
91 return rewriteMinus(t, true);
92 case kind::UMINUS:
93 return rewriteUMinus(t, true);
94 case kind::DIVISION:
95 case kind::DIVISION_TOTAL:
96 return rewriteDiv(t,true);
97 case kind::PLUS:
98 return preRewritePlus(t);
99 case kind::MULT:
100 return preRewriteMult(t);
101 //case kind::INTS_DIVISION:
102 //case kind::INTS_MODULUS:
103 case kind::INTS_DIVISION_TOTAL:
104 case kind::INTS_MODULUS_TOTAL:
105 return rewriteIntsDivModTotal(t,true);
106 default:
107 Unhandled(k);
108 }
109 }
110 }
111 RewriteResponse ArithRewriter::postRewriteTerm(TNode t){
112 if(t.isConst()){
113 return rewriteConstant(t);
114 }else if(t.isVar()){
115 return rewriteVariable(t);
116 }else{
117 switch(t.getKind()){
118 case kind::MINUS:
119 return rewriteMinus(t, false);
120 case kind::UMINUS:
121 return rewriteUMinus(t, false);
122 case kind::DIVISION:
123 case kind::DIVISION_TOTAL:
124 return rewriteDiv(t, false);
125 case kind::PLUS:
126 return postRewritePlus(t);
127 case kind::MULT:
128 return postRewriteMult(t);
129 //case kind::INTS_DIVISION:
130 //case kind::INTS_MODULUS:
131 case kind::INTS_DIVISION_TOTAL:
132 case kind::INTS_MODULUS_TOTAL:
133 return rewriteIntsDivModTotal(t, false);
134 default:
135 Unreachable();
136 }
137 }
138 }
139
140
141 RewriteResponse ArithRewriter::preRewriteMult(TNode t){
142 Assert(t.getKind()== kind::MULT);
143
144 // Rewrite multiplications with a 0 argument and to 0
145 Rational qZero(0);
146
147 for(TNode::iterator i = t.begin(); i != t.end(); ++i) {
148 if((*i).getKind() == kind::CONST_RATIONAL) {
149 if((*i).getConst<Rational>() == qZero) {
150 return RewriteResponse(REWRITE_DONE, mkRationalNode(qZero));
151 }
152 }
153 }
154 return RewriteResponse(REWRITE_DONE, t);
155 }
156 RewriteResponse ArithRewriter::preRewritePlus(TNode t){
157 Assert(t.getKind()== kind::PLUS);
158
159 return RewriteResponse(REWRITE_DONE, t);
160 }
161
162 RewriteResponse ArithRewriter::postRewritePlus(TNode t){
163 Assert(t.getKind()== kind::PLUS);
164
165 Polynomial res = Polynomial::mkZero();
166
167 for(TNode::iterator i = t.begin(), end = t.end(); i != end; ++i){
168 Node curr = *i;
169 Polynomial currPoly = Polynomial::parsePolynomial(curr);
170
171 res = res + currPoly;
172 }
173
174 return RewriteResponse(REWRITE_DONE, res.getNode());
175 }
176
177 RewriteResponse ArithRewriter::postRewriteMult(TNode t){
178 Assert(t.getKind()== kind::MULT);
179
180 Polynomial res = Polynomial::mkOne();
181
182 for(TNode::iterator i = t.begin(), end = t.end(); i != end; ++i){
183 Node curr = *i;
184 Polynomial currPoly = Polynomial::parsePolynomial(curr);
185
186 res = res * currPoly;
187 }
188
189 return RewriteResponse(REWRITE_DONE, res.getNode());
190 }
191
192 RewriteResponse ArithRewriter::postRewriteAtom(TNode atom){
193 // left |><| right
194 TNode left = atom[0];
195 TNode right = atom[1];
196
197 Polynomial pleft = Polynomial::parsePolynomial(left);
198 Polynomial pright = Polynomial::parsePolynomial(right);
199
200 Comparison cmp = Comparison::mkComparison(atom.getKind(), pleft, pright);
201 Assert(cmp.isNormalForm());
202 return RewriteResponse(REWRITE_DONE, cmp.getNode());
203 }
204
205 RewriteResponse ArithRewriter::preRewriteAtom(TNode atom){
206 Assert(isAtom(atom));
207
208 NodeManager* currNM = NodeManager::currentNM();
209
210 if(atom.getKind() == kind::EQUAL) {
211 if(atom[0] == atom[1]) {
212 return RewriteResponse(REWRITE_DONE, currNM->mkConst(true));
213 }
214 }else if(atom.getKind() == kind::GT){
215 Node leq = currNM->mkNode(kind::LEQ, atom[0], atom[1]);
216 return RewriteResponse(REWRITE_DONE, currNM->mkNode(kind::NOT, leq));
217 }else if(atom.getKind() == kind::LT){
218 Node geq = currNM->mkNode(kind::GEQ, atom[0], atom[1]);
219 return RewriteResponse(REWRITE_DONE, currNM->mkNode(kind::NOT, geq));
220 }
221
222 return RewriteResponse(REWRITE_DONE, atom);
223 }
224
225 RewriteResponse ArithRewriter::postRewrite(TNode t){
226 if(isTerm(t)){
227 RewriteResponse response = postRewriteTerm(t);
228 if(Debug.isOn("arith::rewriter") && response.status == REWRITE_DONE) {
229 Polynomial::parsePolynomial(response.node);
230 }
231 return response;
232 }else if(isAtom(t)){
233 RewriteResponse response = postRewriteAtom(t);
234 if(Debug.isOn("arith::rewriter") && response.status == REWRITE_DONE) {
235 Comparison::parseNormalForm(response.node);
236 }
237 return response;
238 }else{
239 Unreachable();
240 return RewriteResponse(REWRITE_DONE, Node::null());
241 }
242 }
243
244 RewriteResponse ArithRewriter::preRewrite(TNode t){
245 if(isTerm(t)){
246 return preRewriteTerm(t);
247 }else if(isAtom(t)){
248 return preRewriteAtom(t);
249 }else{
250 Unreachable();
251 return RewriteResponse(REWRITE_DONE, Node::null());
252 }
253 }
254
255 Node ArithRewriter::makeUnaryMinusNode(TNode n){
256 Rational qNegOne(-1);
257 return NodeManager::currentNM()->mkNode(kind::MULT, mkRationalNode(qNegOne),n);
258 }
259
260 Node ArithRewriter::makeSubtractionNode(TNode l, TNode r){
261 Node negR = makeUnaryMinusNode(r);
262 Node diff = NodeManager::currentNM()->mkNode(kind::PLUS, l, negR);
263
264 return diff;
265 }
266
267 RewriteResponse ArithRewriter::rewriteDiv(TNode t, bool pre){
268 Assert(t.getKind() == kind::DIVISION_TOTAL || t.getKind()== kind::DIVISION);
269
270
271 Node left = t[0];
272 Node right = t[1];
273 if(right.getKind() == kind::CONST_RATIONAL){
274 const Rational& den = right.getConst<Rational>();
275
276 if(den.isZero()){
277 if(t.getKind() == kind::DIVISION_TOTAL){
278 return RewriteResponse(REWRITE_DONE, mkRationalNode(0));
279 }else{
280 // This is unsupported, but this is not a good place to complain
281 return RewriteResponse(REWRITE_DONE, t);
282 }
283 }
284 Assert(den != Rational(0));
285
286 if(left.getKind() == kind::CONST_RATIONAL){
287 const Rational& num = left.getConst<Rational>();
288 Rational div = num / den;
289 Node result = mkRationalNode(div);
290 return RewriteResponse(REWRITE_DONE, result);
291 }
292
293 Rational div = den.inverse();
294
295 Node result = mkRationalNode(div);
296
297 Node mult = NodeManager::currentNM()->mkNode(kind::MULT,left,result);
298 if(pre){
299 return RewriteResponse(REWRITE_DONE, mult);
300 }else{
301 return RewriteResponse(REWRITE_AGAIN, mult);
302 }
303 }else{
304 return RewriteResponse(REWRITE_DONE, t);
305 }
306 }
307
308 RewriteResponse ArithRewriter::rewriteIntsDivModTotal(TNode t, bool pre){
309 Kind k = t.getKind();
310 // Assert(k == kind::INTS_MODULUS || k == kind::INTS_MODULUS_TOTAL ||
311 // k == kind::INTS_DIVISION || k == kind::INTS_DIVISION_TOTAL);
312
313 //Leaving the function as before (INTS_MODULUS can be handled),
314 // but restricting its use here
315 Assert(k == kind::INTS_MODULUS_TOTAL || k == kind::INTS_DIVISION_TOTAL);
316 TNode n = t[0], d = t[1];
317 bool dIsConstant = d.getKind() == kind::CONST_RATIONAL;
318 if(dIsConstant && d.getConst<Rational>().isZero()){
319 if(k == kind::INTS_MODULUS_TOTAL || k == kind::INTS_DIVISION_TOTAL){
320 return RewriteResponse(REWRITE_DONE, mkRationalNode(0));
321 }else{
322 // Do nothing for k == INTS_MODULUS
323 return RewriteResponse(REWRITE_DONE, t);
324 }
325 }else if(dIsConstant && d.getConst<Rational>().isOne()){
326 if(k == kind::INTS_MODULUS || k == kind::INTS_MODULUS_TOTAL){
327 return RewriteResponse(REWRITE_DONE, mkRationalNode(0));
328 }else{
329 Assert(k == kind::INTS_DIVISION || k == kind::INTS_DIVISION_TOTAL);
330 return RewriteResponse(REWRITE_AGAIN, n);
331 }
332 }else if(dIsConstant && n.getKind() == kind::CONST_RATIONAL){
333 Assert(d.getConst<Rational>().isIntegral());
334 Assert(n.getConst<Rational>().isIntegral());
335 Assert(!d.getConst<Rational>().isZero());
336 Integer di = d.getConst<Rational>().getNumerator();
337 Integer ni = n.getConst<Rational>().getNumerator();
338
339 bool isDiv = (k == kind::INTS_DIVISION || k == kind::INTS_DIVISION_TOTAL);
340
341 Integer result = isDiv ? ni.euclidianDivideQuotient(di) : ni.euclidianDivideRemainder(di);
342
343 Node resultNode = mkRationalNode(Rational(result));
344 return RewriteResponse(REWRITE_DONE, resultNode);
345 }else{
346 return RewriteResponse(REWRITE_DONE, t);
347 }
348 }
349
350 }/* CVC4::theory::arith namespace */
351 }/* CVC4::theory namespace */
352 }/* CVC4 namespace */