/********************* */
/*! \file arith_rewriter.cpp
 ** \verbatim
 ** Original author: Tim King
 ** Major contributors: none
 ** Minor contributors (to current version): Morgan Deters, Dejan Jovanovic
 ** This file is part of the CVC4 project.
 ** Copyright (c) 2009-2013 New York University and The University of Iowa
 ** See the file COPYING in the top-level source directory for licensing
 ** information.\endverbatim
 **
 ** \brief Rewriter for the theory of arithmetic (ArithRewriter).
 **
 ** Implements the pre- and post-rewriting of arithmetic terms (constants,
 ** variables, subtraction, negation, sums, products, division and integer
 ** div/mod) and of arithmetic atoms (relations, IS_INTEGER, DIVISIBLE).
 **/
18 #include "theory/theory.h"
19 #include "theory/arith/normal_form.h"
20 #include "theory/arith/arith_rewriter.h"
21 #include "theory/arith/arith_utilities.h"
31 bool ArithRewriter::isAtom(TNode n
) {
33 return arith::isRelationOperator(k
) || k
== kind::IS_INTEGER
|| k
== kind::DIVISIBLE
;
36 RewriteResponse
ArithRewriter::rewriteConstant(TNode t
){
38 Assert(t
.getKind() == kind::CONST_RATIONAL
);
40 return RewriteResponse(REWRITE_DONE
, t
);
43 RewriteResponse
ArithRewriter::rewriteVariable(TNode t
){
46 return RewriteResponse(REWRITE_DONE
, t
);
49 RewriteResponse
ArithRewriter::rewriteMinus(TNode t
, bool pre
){
50 Assert(t
.getKind()== kind::MINUS
);
55 Node zeroNode
= mkRationalNode(zero
);
56 return RewriteResponse(REWRITE_DONE
, zeroNode
);
58 Node noMinus
= makeSubtractionNode(t
[0],t
[1]);
59 return RewriteResponse(REWRITE_DONE
, noMinus
);
62 Polynomial minuend
= Polynomial::parsePolynomial(t
[0]);
63 Polynomial subtrahend
= Polynomial::parsePolynomial(t
[1]);
64 Polynomial diff
= minuend
- subtrahend
;
65 return RewriteResponse(REWRITE_DONE
, diff
.getNode());
69 RewriteResponse
ArithRewriter::rewriteUMinus(TNode t
, bool pre
){
70 Assert(t
.getKind()== kind::UMINUS
);
72 if(t
[0].getKind() == kind::CONST_RATIONAL
){
73 Rational neg
= -(t
[0].getConst
<Rational
>());
74 return RewriteResponse(REWRITE_DONE
, mkRationalNode(neg
));
77 Node noUminus
= makeUnaryMinusNode(t
[0]);
79 return RewriteResponse(REWRITE_DONE
, noUminus
);
81 return RewriteResponse(REWRITE_AGAIN
, noUminus
);
84 RewriteResponse
ArithRewriter::preRewriteTerm(TNode t
){
86 return rewriteConstant(t
);
88 return rewriteVariable(t
);
90 switch(Kind k
= t
.getKind()){
92 return rewriteMinus(t
, true);
94 return rewriteUMinus(t
, true);
96 case kind::DIVISION_TOTAL
:
97 return rewriteDiv(t
,true);
99 return preRewritePlus(t
);
101 return preRewriteMult(t
);
102 case kind::INTS_DIVISION
:
103 case kind::INTS_MODULUS
:
104 return RewriteResponse(REWRITE_DONE
, t
);
105 case kind::INTS_DIVISION_TOTAL
:
106 case kind::INTS_MODULUS_TOTAL
:
107 return rewriteIntsDivModTotal(t
,true);
110 const Rational
& rat
= t
[0].getConst
<Rational
>();
112 return RewriteResponse(REWRITE_DONE
, t
[0]);
114 return RewriteResponse(REWRITE_DONE
,
115 NodeManager::currentNM()->mkConst(-rat
));
118 return RewriteResponse(REWRITE_DONE
, t
);
119 case kind::IS_INTEGER
:
120 case kind::TO_INTEGER
:
121 return RewriteResponse(REWRITE_DONE
, t
);
123 return RewriteResponse(REWRITE_DONE
, t
[0]);
129 RewriteResponse
ArithRewriter::postRewriteTerm(TNode t
){
131 return rewriteConstant(t
);
133 return rewriteVariable(t
);
137 return rewriteMinus(t
, false);
139 return rewriteUMinus(t
, false);
141 case kind::DIVISION_TOTAL
:
142 return rewriteDiv(t
, false);
144 return postRewritePlus(t
);
146 return postRewriteMult(t
);
147 case kind::INTS_DIVISION
:
148 case kind::INTS_MODULUS
:
149 return RewriteResponse(REWRITE_DONE
, t
);
150 case kind::INTS_DIVISION_TOTAL
:
151 case kind::INTS_MODULUS_TOTAL
:
152 return rewriteIntsDivModTotal(t
, false);
155 const Rational
& rat
= t
[0].getConst
<Rational
>();
157 return RewriteResponse(REWRITE_DONE
, t
[0]);
159 return RewriteResponse(REWRITE_DONE
,
160 NodeManager::currentNM()->mkConst(-rat
));
164 return RewriteResponse(REWRITE_DONE
, t
[0]);
165 case kind::TO_INTEGER
:
167 return RewriteResponse(REWRITE_DONE
, NodeManager::currentNM()->mkConst(Rational(t
[0].getConst
<Rational
>().floor())));
169 if(t
[0].getType().isInteger()) {
170 return RewriteResponse(REWRITE_DONE
, t
[0]);
172 //Unimplemented("TO_INTEGER, nonconstant");
173 //return rewriteToInteger(t);
174 return RewriteResponse(REWRITE_DONE
, t
);
175 case kind::IS_INTEGER
:
177 return RewriteResponse(REWRITE_DONE
, NodeManager::currentNM()->mkConst(t
[0].getConst
<Rational
>().getDenominator() == 1));
179 if(t
[0].getType().isInteger()) {
180 return RewriteResponse(REWRITE_DONE
, NodeManager::currentNM()->mkConst(true));
182 //Unimplemented("IS_INTEGER, nonconstant");
183 //return rewriteIsInteger(t);
184 return RewriteResponse(REWRITE_DONE
, t
);
192 RewriteResponse
ArithRewriter::preRewriteMult(TNode t
){
193 Assert(t
.getKind()== kind::MULT
);
195 // Rewrite multiplications with a 0 argument and to 0
198 for(TNode::iterator i
= t
.begin(); i
!= t
.end(); ++i
) {
199 if((*i
).getKind() == kind::CONST_RATIONAL
) {
200 if((*i
).getConst
<Rational
>() == qZero
) {
201 return RewriteResponse(REWRITE_DONE
, mkRationalNode(qZero
));
205 return RewriteResponse(REWRITE_DONE
, t
);
207 RewriteResponse
ArithRewriter::preRewritePlus(TNode t
){
208 Assert(t
.getKind()== kind::PLUS
);
210 return RewriteResponse(REWRITE_DONE
, t
);
213 RewriteResponse
ArithRewriter::postRewritePlus(TNode t
){
214 Assert(t
.getKind()== kind::PLUS
);
216 Polynomial res
= Polynomial::mkZero();
218 for(TNode::iterator i
= t
.begin(), end
= t
.end(); i
!= end
; ++i
){
220 Polynomial currPoly
= Polynomial::parsePolynomial(curr
);
222 res
= res
+ currPoly
;
225 return RewriteResponse(REWRITE_DONE
, res
.getNode());
228 RewriteResponse
ArithRewriter::postRewriteMult(TNode t
){
229 Assert(t
.getKind()== kind::MULT
);
231 Polynomial res
= Polynomial::mkOne();
233 for(TNode::iterator i
= t
.begin(), end
= t
.end(); i
!= end
; ++i
){
235 Polynomial currPoly
= Polynomial::parsePolynomial(curr
);
237 res
= res
* currPoly
;
240 return RewriteResponse(REWRITE_DONE
, res
.getNode());
243 RewriteResponse
ArithRewriter::postRewriteAtom(TNode atom
){
244 if(atom
.getKind() == kind::IS_INTEGER
) {
245 if(atom
[0].isConst()) {
246 return RewriteResponse(REWRITE_DONE
, NodeManager::currentNM()->mkConst(atom
[0].getConst
<Rational
>().isIntegral()));
248 if(atom
[0].getType().isInteger()) {
249 return RewriteResponse(REWRITE_DONE
, NodeManager::currentNM()->mkConst(true));
251 // not supported, but this isn't the right place to complain
252 return RewriteResponse(REWRITE_DONE
, atom
);
253 } else if(atom
.getKind() == kind::DIVISIBLE
) {
254 if(atom
[0].isConst()) {
255 return RewriteResponse(REWRITE_DONE
, NodeManager::currentNM()->mkConst(bool((atom
[0].getConst
<Rational
>() / atom
.getOperator().getConst
<Divisible
>().k
).isIntegral())));
257 if(atom
.getOperator().getConst
<Divisible
>().k
.isOne()) {
258 return RewriteResponse(REWRITE_DONE
, NodeManager::currentNM()->mkConst(true));
260 return RewriteResponse(REWRITE_AGAIN
, NodeManager::currentNM()->mkNode(kind::EQUAL
, NodeManager::currentNM()->mkNode(kind::INTS_MODULUS_TOTAL
, atom
[0], NodeManager::currentNM()->mkConst(Rational(atom
.getOperator().getConst
<Divisible
>().k
))), NodeManager::currentNM()->mkConst(Rational(0))));
264 TNode left
= atom
[0];
265 TNode right
= atom
[1];
267 Polynomial pleft
= Polynomial::parsePolynomial(left
);
268 Polynomial pright
= Polynomial::parsePolynomial(right
);
270 Comparison cmp
= Comparison::mkComparison(atom
.getKind(), pleft
, pright
);
271 Assert(cmp
.isNormalForm());
272 return RewriteResponse(REWRITE_DONE
, cmp
.getNode());
275 RewriteResponse
ArithRewriter::preRewriteAtom(TNode atom
){
276 Assert(isAtom(atom
));
278 NodeManager
* currNM
= NodeManager::currentNM();
280 if(atom
.getKind() == kind::EQUAL
) {
281 if(atom
[0] == atom
[1]) {
282 return RewriteResponse(REWRITE_DONE
, currNM
->mkConst(true));
284 }else if(atom
.getKind() == kind::GT
){
285 Node leq
= currNM
->mkNode(kind::LEQ
, atom
[0], atom
[1]);
286 return RewriteResponse(REWRITE_DONE
, currNM
->mkNode(kind::NOT
, leq
));
287 }else if(atom
.getKind() == kind::LT
){
288 Node geq
= currNM
->mkNode(kind::GEQ
, atom
[0], atom
[1]);
289 return RewriteResponse(REWRITE_DONE
, currNM
->mkNode(kind::NOT
, geq
));
290 }else if(atom
.getKind() == kind::IS_INTEGER
){
291 if(atom
[0].getType().isInteger()){
292 return RewriteResponse(REWRITE_DONE
, currNM
->mkConst(true));
294 }else if(atom
.getKind() == kind::DIVISIBLE
){
295 if(atom
.getOperator().getConst
<Divisible
>().k
.isOne()){
296 return RewriteResponse(REWRITE_DONE
, currNM
->mkConst(true));
300 return RewriteResponse(REWRITE_DONE
, atom
);
303 RewriteResponse
ArithRewriter::postRewrite(TNode t
){
305 RewriteResponse response
= postRewriteTerm(t
);
306 if(Debug
.isOn("arith::rewriter") && response
.status
== REWRITE_DONE
) {
307 Polynomial::parsePolynomial(response
.node
);
311 RewriteResponse response
= postRewriteAtom(t
);
312 if(Debug
.isOn("arith::rewriter") && response
.status
== REWRITE_DONE
) {
313 Comparison::parseNormalForm(response
.node
);
318 return RewriteResponse(REWRITE_DONE
, Node::null());
322 RewriteResponse
ArithRewriter::preRewrite(TNode t
){
324 return preRewriteTerm(t
);
326 return preRewriteAtom(t
);
329 return RewriteResponse(REWRITE_DONE
, Node::null());
333 Node
ArithRewriter::makeUnaryMinusNode(TNode n
){
334 Rational
qNegOne(-1);
335 return NodeManager::currentNM()->mkNode(kind::MULT
, mkRationalNode(qNegOne
),n
);
338 Node
ArithRewriter::makeSubtractionNode(TNode l
, TNode r
){
339 Node negR
= makeUnaryMinusNode(r
);
340 Node diff
= NodeManager::currentNM()->mkNode(kind::PLUS
, l
, negR
);
345 RewriteResponse
ArithRewriter::rewriteDiv(TNode t
, bool pre
){
346 Assert(t
.getKind() == kind::DIVISION_TOTAL
|| t
.getKind()== kind::DIVISION
);
351 if(right
.getKind() == kind::CONST_RATIONAL
){
352 const Rational
& den
= right
.getConst
<Rational
>();
355 if(t
.getKind() == kind::DIVISION_TOTAL
){
356 return RewriteResponse(REWRITE_DONE
, mkRationalNode(0));
358 // This is unsupported, but this is not a good place to complain
359 return RewriteResponse(REWRITE_DONE
, t
);
362 Assert(den
!= Rational(0));
364 if(left
.getKind() == kind::CONST_RATIONAL
){
365 const Rational
& num
= left
.getConst
<Rational
>();
366 Rational div
= num
/ den
;
367 Node result
= mkRationalNode(div
);
368 return RewriteResponse(REWRITE_DONE
, result
);
371 Rational div
= den
.inverse();
373 Node result
= mkRationalNode(div
);
375 Node mult
= NodeManager::currentNM()->mkNode(kind::MULT
,left
,result
);
377 return RewriteResponse(REWRITE_DONE
, mult
);
379 return RewriteResponse(REWRITE_AGAIN
, mult
);
382 return RewriteResponse(REWRITE_DONE
, t
);
386 RewriteResponse
ArithRewriter::rewriteIntsDivModTotal(TNode t
, bool pre
){
387 Kind k
= t
.getKind();
388 // Assert(k == kind::INTS_MODULUS || k == kind::INTS_MODULUS_TOTAL ||
389 // k == kind::INTS_DIVISION || k == kind::INTS_DIVISION_TOTAL);
391 //Leaving the function as before (INTS_MODULUS can be handled),
392 // but restricting its use here
393 Assert(k
== kind::INTS_MODULUS_TOTAL
|| k
== kind::INTS_DIVISION_TOTAL
);
394 TNode n
= t
[0], d
= t
[1];
395 bool dIsConstant
= d
.getKind() == kind::CONST_RATIONAL
;
396 if(dIsConstant
&& d
.getConst
<Rational
>().isZero()){
397 if(k
== kind::INTS_MODULUS_TOTAL
|| k
== kind::INTS_DIVISION_TOTAL
){
398 return RewriteResponse(REWRITE_DONE
, mkRationalNode(0));
400 // Do nothing for k == INTS_MODULUS
401 return RewriteResponse(REWRITE_DONE
, t
);
403 }else if(dIsConstant
&& d
.getConst
<Rational
>().isOne()){
404 if(k
== kind::INTS_MODULUS
|| k
== kind::INTS_MODULUS_TOTAL
){
405 return RewriteResponse(REWRITE_DONE
, mkRationalNode(0));
407 Assert(k
== kind::INTS_DIVISION
|| k
== kind::INTS_DIVISION_TOTAL
);
408 return RewriteResponse(REWRITE_AGAIN
, n
);
410 }else if(dIsConstant
&& d
.getConst
<Rational
>().isNegativeOne()){
411 if(k
== kind::INTS_MODULUS
|| k
== kind::INTS_MODULUS_TOTAL
){
412 return RewriteResponse(REWRITE_DONE
, mkRationalNode(0));
414 Assert(k
== kind::INTS_DIVISION
|| k
== kind::INTS_DIVISION_TOTAL
);
415 return RewriteResponse(REWRITE_AGAIN
, NodeManager::currentNM()->mkNode(kind::UMINUS
, n
));
417 }else if(dIsConstant
&& n
.getKind() == kind::CONST_RATIONAL
){
418 Assert(d
.getConst
<Rational
>().isIntegral());
419 Assert(n
.getConst
<Rational
>().isIntegral());
420 Assert(!d
.getConst
<Rational
>().isZero());
421 Integer di
= d
.getConst
<Rational
>().getNumerator();
422 Integer ni
= n
.getConst
<Rational
>().getNumerator();
424 bool isDiv
= (k
== kind::INTS_DIVISION
|| k
== kind::INTS_DIVISION_TOTAL
);
426 Integer result
= isDiv
? ni
.euclidianDivideQuotient(di
) : ni
.euclidianDivideRemainder(di
);
428 Node resultNode
= mkRationalNode(Rational(result
));
429 return RewriteResponse(REWRITE_DONE
, resultNode
);
431 return RewriteResponse(REWRITE_DONE
, t
);
435 }/* CVC4::theory::arith namespace */
436 }/* CVC4::theory namespace */
437 }/* CVC4 namespace */