/*********************                                                        */
/*! \file arith_rewriter.cpp
 ** \verbatim
 ** Original author: Tim King
 ** Major contributors: Morgan Deters
 ** Minor contributors (to current version): Dejan Jovanovic
 ** This file is part of the CVC4 project.
 ** Copyright (c) 2009-2013  New York University and The University of Iowa
 ** See the file COPYING in the top-level source directory for licensing
 ** information.\endverbatim
 **
 ** \brief Pre- and post-rewriting of arithmetic terms and atoms
 **
 ** Implements the ArithRewriter routines that rewrite arithmetic terms
 ** (constants, variables, sums, products, divisions, ...) and arithmetic
 ** atoms (comparisons, IS_INTEGER, DIVISIBLE).
 ** \todo document this file
 **/
18 #include "theory/theory.h"
19 #include "theory/arith/normal_form.h"
20 #include "theory/arith/arith_rewriter.h"
21 #include "theory/arith/arith_utilities.h"
31 bool ArithRewriter::isAtom(TNode n
) {
33 return arith::isRelationOperator(k
) || k
== kind::IS_INTEGER
|| k
== kind::DIVISIBLE
;
36 RewriteResponse
ArithRewriter::rewriteConstant(TNode t
){
38 Assert(t
.getKind() == kind::CONST_RATIONAL
);
40 return RewriteResponse(REWRITE_DONE
, t
);
43 RewriteResponse
ArithRewriter::rewriteVariable(TNode t
){
46 return RewriteResponse(REWRITE_DONE
, t
);
49 RewriteResponse
ArithRewriter::rewriteMinus(TNode t
, bool pre
){
50 Assert(t
.getKind()== kind::MINUS
);
55 Node zeroNode
= mkRationalNode(zero
);
56 return RewriteResponse(REWRITE_DONE
, zeroNode
);
58 Node noMinus
= makeSubtractionNode(t
[0],t
[1]);
59 return RewriteResponse(REWRITE_DONE
, noMinus
);
62 Polynomial minuend
= Polynomial::parsePolynomial(t
[0]);
63 Polynomial subtrahend
= Polynomial::parsePolynomial(t
[1]);
64 Polynomial diff
= minuend
- subtrahend
;
65 return RewriteResponse(REWRITE_DONE
, diff
.getNode());
69 RewriteResponse
ArithRewriter::rewriteUMinus(TNode t
, bool pre
){
70 Assert(t
.getKind()== kind::UMINUS
);
72 if(t
[0].getKind() == kind::CONST_RATIONAL
){
73 Rational neg
= -(t
[0].getConst
<Rational
>());
74 return RewriteResponse(REWRITE_DONE
, mkRationalNode(neg
));
77 Node noUminus
= makeUnaryMinusNode(t
[0]);
79 return RewriteResponse(REWRITE_DONE
, noUminus
);
81 return RewriteResponse(REWRITE_AGAIN
, noUminus
);
84 RewriteResponse
ArithRewriter::preRewriteTerm(TNode t
){
86 return rewriteConstant(t
);
88 return rewriteVariable(t
);
90 switch(Kind k
= t
.getKind()){
92 return rewriteMinus(t
, true);
94 return rewriteUMinus(t
, true);
96 case kind::DIVISION_TOTAL
:
97 return rewriteDiv(t
,true);
99 return preRewritePlus(t
);
101 return preRewriteMult(t
);
102 case kind::INTS_DIVISION
:
103 case kind::INTS_MODULUS
:
104 return RewriteResponse(REWRITE_DONE
, t
);
105 case kind::INTS_DIVISION_TOTAL
:
106 case kind::INTS_MODULUS_TOTAL
:
107 return rewriteIntsDivModTotal(t
,true);
110 const Rational
& rat
= t
[0].getConst
<Rational
>();
112 return RewriteResponse(REWRITE_DONE
, t
[0]);
114 return RewriteResponse(REWRITE_DONE
,
115 NodeManager::currentNM()->mkConst(-rat
));
118 return RewriteResponse(REWRITE_DONE
, t
);
119 case kind::IS_INTEGER
:
120 case kind::TO_INTEGER
:
121 return RewriteResponse(REWRITE_DONE
, t
);
123 return RewriteResponse(REWRITE_DONE
, t
[0]);
125 return RewriteResponse(REWRITE_DONE
, t
);
132 RewriteResponse
ArithRewriter::postRewriteTerm(TNode t
){
134 return rewriteConstant(t
);
136 return rewriteVariable(t
);
140 return rewriteMinus(t
, false);
142 return rewriteUMinus(t
, false);
144 case kind::DIVISION_TOTAL
:
145 return rewriteDiv(t
, false);
147 return postRewritePlus(t
);
149 return postRewriteMult(t
);
150 case kind::INTS_DIVISION
:
151 case kind::INTS_MODULUS
:
152 return RewriteResponse(REWRITE_DONE
, t
);
153 case kind::INTS_DIVISION_TOTAL
:
154 case kind::INTS_MODULUS_TOTAL
:
155 return rewriteIntsDivModTotal(t
, false);
158 const Rational
& rat
= t
[0].getConst
<Rational
>();
160 return RewriteResponse(REWRITE_DONE
, t
[0]);
162 return RewriteResponse(REWRITE_DONE
,
163 NodeManager::currentNM()->mkConst(-rat
));
167 return RewriteResponse(REWRITE_DONE
, t
[0]);
168 case kind::TO_INTEGER
:
170 return RewriteResponse(REWRITE_DONE
, NodeManager::currentNM()->mkConst(Rational(t
[0].getConst
<Rational
>().floor())));
172 if(t
[0].getType().isInteger()) {
173 return RewriteResponse(REWRITE_DONE
, t
[0]);
175 //Unimplemented("TO_INTEGER, nonconstant");
176 //return rewriteToInteger(t);
177 return RewriteResponse(REWRITE_DONE
, t
);
178 case kind::IS_INTEGER
:
180 return RewriteResponse(REWRITE_DONE
, NodeManager::currentNM()->mkConst(t
[0].getConst
<Rational
>().getDenominator() == 1));
182 if(t
[0].getType().isInteger()) {
183 return RewriteResponse(REWRITE_DONE
, NodeManager::currentNM()->mkConst(true));
185 //Unimplemented("IS_INTEGER, nonconstant");
186 //return rewriteIsInteger(t);
187 return RewriteResponse(REWRITE_DONE
, t
);
190 if(t
[1].getKind() == kind::CONST_RATIONAL
){
191 const Rational
& exp
= t
[1].getConst
<Rational
>();
194 return RewriteResponse(REWRITE_DONE
, mkRationalNode(Rational(1)));
195 }else if(exp
.sgn() > 0 && exp
.isIntegral()){
196 Integer num
= exp
.getNumerator();
197 NodeBuilder
<> nb(kind::MULT
);
199 for(Integer
i(0); i
< num
; i
= i
+ one
){
202 Assert(nb
.getNumChildren() > 0);
204 return RewriteResponse(REWRITE_AGAIN
, mult
);
208 // Todo improve the exception thrown
209 std::stringstream ss
;
210 ss
<< "The POW(^) operator can only be used with a natural number ";
211 ss
<< "in the exponent. Exception occured in:" << std::endl
;
213 throw Exception(ss
.str());
222 RewriteResponse
ArithRewriter::preRewriteMult(TNode t
){
223 Assert(t
.getKind()== kind::MULT
);
225 if(t
.getNumChildren() == 2){
226 if(t
[0].getKind() == kind::CONST_RATIONAL
227 && t
[0].getConst
<Rational
>().isOne()){
228 return RewriteResponse(REWRITE_DONE
, t
[1]);
230 if(t
[1].getKind() == kind::CONST_RATIONAL
231 && t
[1].getConst
<Rational
>().isOne()){
232 return RewriteResponse(REWRITE_DONE
, t
[0]);
236 // Rewrite multiplications with a 0 argument and to 0
237 for(TNode::iterator i
= t
.begin(); i
!= t
.end(); ++i
) {
238 if((*i
).getKind() == kind::CONST_RATIONAL
) {
239 if((*i
).getConst
<Rational
>().isZero()) {
241 return RewriteResponse(REWRITE_DONE
, zero
);
245 return RewriteResponse(REWRITE_DONE
, t
);
248 static bool canFlatten(Kind k
, TNode t
){
249 for(TNode::iterator i
= t
.begin(); i
!= t
.end(); ++i
) {
251 if(child
.getKind() == k
){
258 static void flatten(std::vector
<TNode
>& pb
, Kind k
, TNode t
){
259 if(t
.getKind() == k
){
260 for(TNode::iterator i
= t
.begin(); i
!= t
.end(); ++i
) {
262 if(child
.getKind() == k
){
263 flatten(pb
, k
, child
);
273 static Node
flatten(Kind k
, TNode t
){
274 std::vector
<TNode
> pb
;
276 Assert(pb
.size() >= 2);
277 return NodeManager::currentNM()->mkNode(k
, pb
);
280 RewriteResponse
ArithRewriter::preRewritePlus(TNode t
){
281 Assert(t
.getKind()== kind::PLUS
);
283 if(canFlatten(kind::PLUS
, t
)){
284 return RewriteResponse(REWRITE_DONE
, flatten(kind::PLUS
, t
));
286 return RewriteResponse(REWRITE_DONE
, t
);
290 RewriteResponse
ArithRewriter::postRewritePlus(TNode t
){
291 Assert(t
.getKind()== kind::PLUS
);
293 std::vector
<Monomial
> monomials
;
294 std::vector
<Polynomial
> polynomials
;
296 for(TNode::iterator i
= t
.begin(), end
= t
.end(); i
!= end
; ++i
){
298 if(Monomial::isMember(curr
)){
299 monomials
.push_back(Monomial::parseMonomial(curr
));
301 polynomials
.push_back(Polynomial::parsePolynomial(curr
));
305 if(!monomials
.empty()){
306 Monomial::sort(monomials
);
307 Monomial::combineAdjacentMonomials(monomials
);
308 polynomials
.push_back(Polynomial::mkPolynomial(monomials
));
311 Polynomial res
= Polynomial::sumPolynomials(polynomials
);
313 return RewriteResponse(REWRITE_DONE
, res
.getNode());
316 RewriteResponse
ArithRewriter::postRewriteMult(TNode t
){
317 Assert(t
.getKind()== kind::MULT
);
319 Polynomial res
= Polynomial::mkOne();
321 for(TNode::iterator i
= t
.begin(), end
= t
.end(); i
!= end
; ++i
){
323 Polynomial currPoly
= Polynomial::parsePolynomial(curr
);
325 res
= res
* currPoly
;
328 return RewriteResponse(REWRITE_DONE
, res
.getNode());
331 RewriteResponse
ArithRewriter::postRewriteAtom(TNode atom
){
332 if(atom
.getKind() == kind::IS_INTEGER
) {
333 if(atom
[0].isConst()) {
334 return RewriteResponse(REWRITE_DONE
, NodeManager::currentNM()->mkConst(atom
[0].getConst
<Rational
>().isIntegral()));
336 if(atom
[0].getType().isInteger()) {
337 return RewriteResponse(REWRITE_DONE
, NodeManager::currentNM()->mkConst(true));
339 // not supported, but this isn't the right place to complain
340 return RewriteResponse(REWRITE_DONE
, atom
);
341 } else if(atom
.getKind() == kind::DIVISIBLE
) {
342 if(atom
[0].isConst()) {
343 return RewriteResponse(REWRITE_DONE
, NodeManager::currentNM()->mkConst(bool((atom
[0].getConst
<Rational
>() / atom
.getOperator().getConst
<Divisible
>().k
).isIntegral())));
345 if(atom
.getOperator().getConst
<Divisible
>().k
.isOne()) {
346 return RewriteResponse(REWRITE_DONE
, NodeManager::currentNM()->mkConst(true));
348 return RewriteResponse(REWRITE_AGAIN
, NodeManager::currentNM()->mkNode(kind::EQUAL
, NodeManager::currentNM()->mkNode(kind::INTS_MODULUS_TOTAL
, atom
[0], NodeManager::currentNM()->mkConst(Rational(atom
.getOperator().getConst
<Divisible
>().k
))), NodeManager::currentNM()->mkConst(Rational(0))));
352 TNode left
= atom
[0];
353 TNode right
= atom
[1];
355 Polynomial pleft
= Polynomial::parsePolynomial(left
);
356 Polynomial pright
= Polynomial::parsePolynomial(right
);
358 Comparison cmp
= Comparison::mkComparison(atom
.getKind(), pleft
, pright
);
359 Assert(cmp
.isNormalForm());
360 return RewriteResponse(REWRITE_DONE
, cmp
.getNode());
363 RewriteResponse
ArithRewriter::preRewriteAtom(TNode atom
){
364 Assert(isAtom(atom
));
366 NodeManager
* currNM
= NodeManager::currentNM();
368 if(atom
.getKind() == kind::EQUAL
) {
369 if(atom
[0] == atom
[1]) {
370 return RewriteResponse(REWRITE_DONE
, currNM
->mkConst(true));
372 }else if(atom
.getKind() == kind::GT
){
373 Node leq
= currNM
->mkNode(kind::LEQ
, atom
[0], atom
[1]);
374 return RewriteResponse(REWRITE_DONE
, currNM
->mkNode(kind::NOT
, leq
));
375 }else if(atom
.getKind() == kind::LT
){
376 Node geq
= currNM
->mkNode(kind::GEQ
, atom
[0], atom
[1]);
377 return RewriteResponse(REWRITE_DONE
, currNM
->mkNode(kind::NOT
, geq
));
378 }else if(atom
.getKind() == kind::IS_INTEGER
){
379 if(atom
[0].getType().isInteger()){
380 return RewriteResponse(REWRITE_DONE
, currNM
->mkConst(true));
382 }else if(atom
.getKind() == kind::DIVISIBLE
){
383 if(atom
.getOperator().getConst
<Divisible
>().k
.isOne()){
384 return RewriteResponse(REWRITE_DONE
, currNM
->mkConst(true));
388 return RewriteResponse(REWRITE_DONE
, atom
);
391 RewriteResponse
ArithRewriter::postRewrite(TNode t
){
393 RewriteResponse response
= postRewriteTerm(t
);
394 if(Debug
.isOn("arith::rewriter") && response
.status
== REWRITE_DONE
) {
395 Polynomial::parsePolynomial(response
.node
);
399 RewriteResponse response
= postRewriteAtom(t
);
400 if(Debug
.isOn("arith::rewriter") && response
.status
== REWRITE_DONE
) {
401 Comparison::parseNormalForm(response
.node
);
406 return RewriteResponse(REWRITE_DONE
, Node::null());
410 RewriteResponse
ArithRewriter::preRewrite(TNode t
){
412 return preRewriteTerm(t
);
414 return preRewriteAtom(t
);
417 return RewriteResponse(REWRITE_DONE
, Node::null());
421 Node
ArithRewriter::makeUnaryMinusNode(TNode n
){
422 Rational
qNegOne(-1);
423 return NodeManager::currentNM()->mkNode(kind::MULT
, mkRationalNode(qNegOne
),n
);
426 Node
ArithRewriter::makeSubtractionNode(TNode l
, TNode r
){
427 Node negR
= makeUnaryMinusNode(r
);
428 Node diff
= NodeManager::currentNM()->mkNode(kind::PLUS
, l
, negR
);
433 RewriteResponse
ArithRewriter::rewriteDiv(TNode t
, bool pre
){
434 Assert(t
.getKind() == kind::DIVISION_TOTAL
|| t
.getKind()== kind::DIVISION
);
439 if(right
.getKind() == kind::CONST_RATIONAL
){
440 const Rational
& den
= right
.getConst
<Rational
>();
443 if(t
.getKind() == kind::DIVISION_TOTAL
){
444 return RewriteResponse(REWRITE_DONE
, mkRationalNode(0));
446 // This is unsupported, but this is not a good place to complain
447 return RewriteResponse(REWRITE_DONE
, t
);
450 Assert(den
!= Rational(0));
452 if(left
.getKind() == kind::CONST_RATIONAL
){
453 const Rational
& num
= left
.getConst
<Rational
>();
454 Rational div
= num
/ den
;
455 Node result
= mkRationalNode(div
);
456 return RewriteResponse(REWRITE_DONE
, result
);
459 Rational div
= den
.inverse();
461 Node result
= mkRationalNode(div
);
463 Node mult
= NodeManager::currentNM()->mkNode(kind::MULT
,left
,result
);
465 return RewriteResponse(REWRITE_DONE
, mult
);
467 return RewriteResponse(REWRITE_AGAIN
, mult
);
470 return RewriteResponse(REWRITE_DONE
, t
);
474 RewriteResponse
ArithRewriter::rewriteIntsDivModTotal(TNode t
, bool pre
){
475 Kind k
= t
.getKind();
476 // Assert(k == kind::INTS_MODULUS || k == kind::INTS_MODULUS_TOTAL ||
477 // k == kind::INTS_DIVISION || k == kind::INTS_DIVISION_TOTAL);
479 //Leaving the function as before (INTS_MODULUS can be handled),
480 // but restricting its use here
481 Assert(k
== kind::INTS_MODULUS_TOTAL
|| k
== kind::INTS_DIVISION_TOTAL
);
482 TNode n
= t
[0], d
= t
[1];
483 bool dIsConstant
= d
.getKind() == kind::CONST_RATIONAL
;
484 if(dIsConstant
&& d
.getConst
<Rational
>().isZero()){
485 if(k
== kind::INTS_MODULUS_TOTAL
|| k
== kind::INTS_DIVISION_TOTAL
){
486 return RewriteResponse(REWRITE_DONE
, mkRationalNode(0));
488 // Do nothing for k == INTS_MODULUS
489 return RewriteResponse(REWRITE_DONE
, t
);
491 }else if(dIsConstant
&& d
.getConst
<Rational
>().isOne()){
492 if(k
== kind::INTS_MODULUS
|| k
== kind::INTS_MODULUS_TOTAL
){
493 return RewriteResponse(REWRITE_DONE
, mkRationalNode(0));
495 Assert(k
== kind::INTS_DIVISION
|| k
== kind::INTS_DIVISION_TOTAL
);
496 return RewriteResponse(REWRITE_AGAIN
, n
);
498 }else if(dIsConstant
&& d
.getConst
<Rational
>().isNegativeOne()){
499 if(k
== kind::INTS_MODULUS
|| k
== kind::INTS_MODULUS_TOTAL
){
500 return RewriteResponse(REWRITE_DONE
, mkRationalNode(0));
502 Assert(k
== kind::INTS_DIVISION
|| k
== kind::INTS_DIVISION_TOTAL
);
503 return RewriteResponse(REWRITE_AGAIN
, NodeManager::currentNM()->mkNode(kind::UMINUS
, n
));
505 }else if(dIsConstant
&& n
.getKind() == kind::CONST_RATIONAL
){
506 Assert(d
.getConst
<Rational
>().isIntegral());
507 Assert(n
.getConst
<Rational
>().isIntegral());
508 Assert(!d
.getConst
<Rational
>().isZero());
509 Integer di
= d
.getConst
<Rational
>().getNumerator();
510 Integer ni
= n
.getConst
<Rational
>().getNumerator();
512 bool isDiv
= (k
== kind::INTS_DIVISION
|| k
== kind::INTS_DIVISION_TOTAL
);
514 Integer result
= isDiv
? ni
.euclidianDivideQuotient(di
) : ni
.euclidianDivideRemainder(di
);
516 Node resultNode
= mkRationalNode(Rational(result
));
517 return RewriteResponse(REWRITE_DONE
, resultNode
);
519 return RewriteResponse(REWRITE_DONE
, t
);
523 }/* CVC4::theory::arith namespace */
524 }/* CVC4::theory namespace */
525 }/* CVC4 namespace */