1 /********************* */
2 /*! \file arith_rewriter.cpp
4 ** Original author: Tim King
5 ** Major contributors: Morgan Deters
6 ** Minor contributors (to current version): Dejan Jovanovic
7 ** This file is part of the CVC4 project.
8 ** Copyright (c) 2009-2013 New York University and The University of Iowa
9 ** See the file COPYING in the top-level source directory for licensing
10 ** information.\endverbatim
12 ** \brief [[ Add one-line brief description here ]]
14 ** [[ Add lengthier description here ]]
15 ** \todo document this file
18 #include "theory/theory.h"
19 #include "theory/arith/normal_form.h"
20 #include "theory/arith/arith_rewriter.h"
21 #include "theory/arith/arith_utilities.h"
31 bool ArithRewriter::isAtom(TNode n
) {
33 return arith::isRelationOperator(k
) || k
== kind::IS_INTEGER
|| k
== kind::DIVISIBLE
;
36 RewriteResponse
ArithRewriter::rewriteConstant(TNode t
){
38 Assert(t
.getKind() == kind::CONST_RATIONAL
);
40 return RewriteResponse(REWRITE_DONE
, t
);
43 RewriteResponse
ArithRewriter::rewriteVariable(TNode t
){
46 return RewriteResponse(REWRITE_DONE
, t
);
49 RewriteResponse
ArithRewriter::rewriteMinus(TNode t
, bool pre
){
50 Assert(t
.getKind()== kind::MINUS
);
55 Node zeroNode
= mkRationalNode(zero
);
56 return RewriteResponse(REWRITE_DONE
, zeroNode
);
58 Node noMinus
= makeSubtractionNode(t
[0],t
[1]);
59 return RewriteResponse(REWRITE_DONE
, noMinus
);
62 Polynomial minuend
= Polynomial::parsePolynomial(t
[0]);
63 Polynomial subtrahend
= Polynomial::parsePolynomial(t
[1]);
64 Polynomial diff
= minuend
- subtrahend
;
65 return RewriteResponse(REWRITE_DONE
, diff
.getNode());
69 RewriteResponse
ArithRewriter::rewriteUMinus(TNode t
, bool pre
){
70 Assert(t
.getKind()== kind::UMINUS
);
72 if(t
[0].getKind() == kind::CONST_RATIONAL
){
73 Rational neg
= -(t
[0].getConst
<Rational
>());
74 return RewriteResponse(REWRITE_DONE
, mkRationalNode(neg
));
77 Node noUminus
= makeUnaryMinusNode(t
[0]);
79 return RewriteResponse(REWRITE_DONE
, noUminus
);
81 return RewriteResponse(REWRITE_AGAIN
, noUminus
);
84 RewriteResponse
ArithRewriter::preRewriteTerm(TNode t
){
86 return rewriteConstant(t
);
88 return rewriteVariable(t
);
90 switch(Kind k
= t
.getKind()){
92 return rewriteMinus(t
, true);
94 return rewriteUMinus(t
, true);
96 case kind::DIVISION_TOTAL
:
97 return rewriteDiv(t
,true);
99 return preRewritePlus(t
);
101 return preRewriteMult(t
);
102 case kind::INTS_DIVISION
:
103 case kind::INTS_MODULUS
:
104 return RewriteResponse(REWRITE_DONE
, t
);
105 case kind::INTS_DIVISION_TOTAL
:
106 case kind::INTS_MODULUS_TOTAL
:
107 return rewriteIntsDivModTotal(t
,true);
110 const Rational
& rat
= t
[0].getConst
<Rational
>();
112 return RewriteResponse(REWRITE_DONE
, t
[0]);
114 return RewriteResponse(REWRITE_DONE
,
115 NodeManager::currentNM()->mkConst(-rat
));
118 return RewriteResponse(REWRITE_DONE
, t
);
119 case kind::IS_INTEGER
:
120 case kind::TO_INTEGER
:
121 return RewriteResponse(REWRITE_DONE
, t
);
123 return RewriteResponse(REWRITE_DONE
, t
[0]);
125 return RewriteResponse(REWRITE_DONE
, t
);
132 RewriteResponse
ArithRewriter::postRewriteTerm(TNode t
){
134 return rewriteConstant(t
);
136 return rewriteVariable(t
);
140 return rewriteMinus(t
, false);
142 return rewriteUMinus(t
, false);
144 case kind::DIVISION_TOTAL
:
145 return rewriteDiv(t
, false);
147 return postRewritePlus(t
);
149 return postRewriteMult(t
);
150 case kind::INTS_DIVISION
:
151 case kind::INTS_MODULUS
:
152 return RewriteResponse(REWRITE_DONE
, t
);
153 case kind::INTS_DIVISION_TOTAL
:
154 case kind::INTS_MODULUS_TOTAL
:
155 return rewriteIntsDivModTotal(t
, false);
158 const Rational
& rat
= t
[0].getConst
<Rational
>();
160 return RewriteResponse(REWRITE_DONE
, t
[0]);
162 return RewriteResponse(REWRITE_DONE
,
163 NodeManager::currentNM()->mkConst(-rat
));
167 return RewriteResponse(REWRITE_DONE
, t
[0]);
168 case kind::TO_INTEGER
:
170 return RewriteResponse(REWRITE_DONE
, NodeManager::currentNM()->mkConst(Rational(t
[0].getConst
<Rational
>().floor())));
172 if(t
[0].getType().isInteger()) {
173 return RewriteResponse(REWRITE_DONE
, t
[0]);
175 //Unimplemented("TO_INTEGER, nonconstant");
176 //return rewriteToInteger(t);
177 return RewriteResponse(REWRITE_DONE
, t
);
178 case kind::IS_INTEGER
:
180 return RewriteResponse(REWRITE_DONE
, NodeManager::currentNM()->mkConst(t
[0].getConst
<Rational
>().getDenominator() == 1));
182 if(t
[0].getType().isInteger()) {
183 return RewriteResponse(REWRITE_DONE
, NodeManager::currentNM()->mkConst(true));
185 //Unimplemented("IS_INTEGER, nonconstant");
186 //return rewriteIsInteger(t);
187 return RewriteResponse(REWRITE_DONE
, t
);
190 if(t
[1].getKind() == kind::CONST_RATIONAL
){
191 const Rational
& exp
= t
[1].getConst
<Rational
>();
194 return RewriteResponse(REWRITE_DONE
, mkRationalNode(Rational(1)));
195 }else if(exp
.sgn() > 0 && exp
.isIntegral()){
196 Integer num
= exp
.getNumerator();
197 NodeBuilder
<> nb(kind::MULT
);
199 for(Integer
i(0); i
< num
; i
= i
+ one
){
202 Assert(nb
.getNumChildren() > 0);
204 return RewriteResponse(REWRITE_AGAIN
, mult
);
208 // Todo improve the exception thrown
209 std::stringstream ss
;
210 ss
<< "The POW(^) operator can only be used with a natural number ";
211 ss
<< "in the exponent. Exception occured in:" << std::endl
;
213 throw Exception(ss
.str());
222 RewriteResponse
ArithRewriter::preRewriteMult(TNode t
){
223 Assert(t
.getKind()== kind::MULT
);
225 if(t
.getNumChildren() == 2){
226 if(t
[0].getKind() == kind::CONST_RATIONAL
227 && t
[0].getConst
<Rational
>().isOne()){
228 return RewriteResponse(REWRITE_DONE
, t
[1]);
230 if(t
[1].getKind() == kind::CONST_RATIONAL
231 && t
[1].getConst
<Rational
>().isOne()){
232 return RewriteResponse(REWRITE_DONE
, t
[0]);
236 // Rewrite multiplications with a 0 argument and to 0
237 for(TNode::iterator i
= t
.begin(); i
!= t
.end(); ++i
) {
238 if((*i
).getKind() == kind::CONST_RATIONAL
) {
239 if((*i
).getConst
<Rational
>().isZero()) {
241 return RewriteResponse(REWRITE_DONE
, zero
);
245 return RewriteResponse(REWRITE_DONE
, t
);
248 static bool canFlatten(Kind k
, TNode t
){
249 for(TNode::iterator i
= t
.begin(); i
!= t
.end(); ++i
) {
251 if(child
.getKind() == k
){
258 static void flatten(std::vector
<TNode
>& pb
, Kind k
, TNode t
){
259 if(t
.getKind() == k
){
260 for(TNode::iterator i
= t
.begin(); i
!= t
.end(); ++i
) {
262 if(child
.getKind() == k
){
263 flatten(pb
, k
, child
);
273 static Node
flatten(Kind k
, TNode t
){
274 std::vector
<TNode
> pb
;
276 Assert(pb
.size() >= 2);
277 return NodeManager::currentNM()->mkNode(k
, pb
);
280 RewriteResponse
ArithRewriter::preRewritePlus(TNode t
){
281 Assert(t
.getKind()== kind::PLUS
);
283 if(canFlatten(kind::PLUS
, t
)){
284 return RewriteResponse(REWRITE_DONE
, flatten(kind::PLUS
, t
));
286 return RewriteResponse(REWRITE_DONE
, t
);
290 RewriteResponse
ArithRewriter::postRewritePlus(TNode t
){
291 Assert(t
.getKind()== kind::PLUS
);
293 std::vector
<Monomial
> monomials
;
294 std::vector
<Polynomial
> polynomials
;
296 for(TNode::iterator i
= t
.begin(), end
= t
.end(); i
!= end
; ++i
){
298 if(Monomial::isMember(curr
)){
299 monomials
.push_back(Monomial::parseMonomial(curr
));
301 polynomials
.push_back(Polynomial::parsePolynomial(curr
));
305 if(!monomials
.empty()){
306 Monomial::sort(monomials
);
307 Monomial::combineAdjacentMonomials(monomials
);
308 polynomials
.push_back(Polynomial::mkPolynomial(monomials
));
311 Polynomial res
= Polynomial::sumPolynomials(polynomials
);
313 return RewriteResponse(REWRITE_DONE
, res
.getNode());
316 RewriteResponse
ArithRewriter::postRewriteMult(TNode t
){
317 Assert(t
.getKind()== kind::MULT
);
319 Polynomial res
= Polynomial::mkOne();
321 for(TNode::iterator i
= t
.begin(), end
= t
.end(); i
!= end
; ++i
){
323 Polynomial currPoly
= Polynomial::parsePolynomial(curr
);
325 res
= res
* currPoly
;
328 return RewriteResponse(REWRITE_DONE
, res
.getNode());
331 RewriteResponse
ArithRewriter::postRewriteAtom(TNode atom
){
332 if(atom
.getKind() == kind::IS_INTEGER
) {
333 if(atom
[0].isConst()) {
334 return RewriteResponse(REWRITE_DONE
, NodeManager::currentNM()->mkConst(atom
[0].getConst
<Rational
>().isIntegral()));
336 if(atom
[0].getType().isInteger()) {
337 return RewriteResponse(REWRITE_DONE
, NodeManager::currentNM()->mkConst(true));
339 // not supported, but this isn't the right place to complain
340 return RewriteResponse(REWRITE_DONE
, atom
);
341 } else if(atom
.getKind() == kind::DIVISIBLE
) {
342 if(atom
[0].isConst()) {
343 return RewriteResponse(REWRITE_DONE
, NodeManager::currentNM()->mkConst(bool((atom
[0].getConst
<Rational
>() / atom
.getOperator().getConst
<Divisible
>().k
).isIntegral())));
345 if(atom
.getOperator().getConst
<Divisible
>().k
.isOne()) {
346 return RewriteResponse(REWRITE_DONE
, NodeManager::currentNM()->mkConst(true));
348 return RewriteResponse(REWRITE_AGAIN
, NodeManager::currentNM()->mkNode(kind::EQUAL
, NodeManager::currentNM()->mkNode(kind::INTS_MODULUS_TOTAL
, atom
[0], NodeManager::currentNM()->mkConst(Rational(atom
.getOperator().getConst
<Divisible
>().k
))), NodeManager::currentNM()->mkConst(Rational(0))));
352 TNode left
= atom
[0];
353 TNode right
= atom
[1];
355 Polynomial pleft
= Polynomial::parsePolynomial(left
);
356 Polynomial pright
= Polynomial::parsePolynomial(right
);
358 Debug("arith::rewriter") << "pleft " << pleft
.getNode() << std::endl
;
359 Debug("arith::rewriter") << "pright " << pright
.getNode() << std::endl
;
361 Comparison cmp
= Comparison::mkComparison(atom
.getKind(), pleft
, pright
);
362 Assert(cmp
.isNormalForm());
363 return RewriteResponse(REWRITE_DONE
, cmp
.getNode());
366 RewriteResponse
ArithRewriter::preRewriteAtom(TNode atom
){
367 Assert(isAtom(atom
));
369 NodeManager
* currNM
= NodeManager::currentNM();
371 if(atom
.getKind() == kind::EQUAL
) {
372 if(atom
[0] == atom
[1]) {
373 return RewriteResponse(REWRITE_DONE
, currNM
->mkConst(true));
375 }else if(atom
.getKind() == kind::GT
){
376 Node leq
= currNM
->mkNode(kind::LEQ
, atom
[0], atom
[1]);
377 return RewriteResponse(REWRITE_DONE
, currNM
->mkNode(kind::NOT
, leq
));
378 }else if(atom
.getKind() == kind::LT
){
379 Node geq
= currNM
->mkNode(kind::GEQ
, atom
[0], atom
[1]);
380 return RewriteResponse(REWRITE_DONE
, currNM
->mkNode(kind::NOT
, geq
));
381 }else if(atom
.getKind() == kind::IS_INTEGER
){
382 if(atom
[0].getType().isInteger()){
383 return RewriteResponse(REWRITE_DONE
, currNM
->mkConst(true));
385 }else if(atom
.getKind() == kind::DIVISIBLE
){
386 if(atom
.getOperator().getConst
<Divisible
>().k
.isOne()){
387 return RewriteResponse(REWRITE_DONE
, currNM
->mkConst(true));
391 return RewriteResponse(REWRITE_DONE
, atom
);
394 RewriteResponse
ArithRewriter::postRewrite(TNode t
){
396 RewriteResponse response
= postRewriteTerm(t
);
397 if(Debug
.isOn("arith::rewriter") && response
.status
== REWRITE_DONE
) {
398 Polynomial::parsePolynomial(response
.node
);
402 RewriteResponse response
= postRewriteAtom(t
);
403 if(Debug
.isOn("arith::rewriter") && response
.status
== REWRITE_DONE
) {
404 Comparison::parseNormalForm(response
.node
);
409 return RewriteResponse(REWRITE_DONE
, Node::null());
413 RewriteResponse
ArithRewriter::preRewrite(TNode t
){
415 return preRewriteTerm(t
);
417 return preRewriteAtom(t
);
420 return RewriteResponse(REWRITE_DONE
, Node::null());
424 Node
ArithRewriter::makeUnaryMinusNode(TNode n
){
425 Rational
qNegOne(-1);
426 return NodeManager::currentNM()->mkNode(kind::MULT
, mkRationalNode(qNegOne
),n
);
429 Node
ArithRewriter::makeSubtractionNode(TNode l
, TNode r
){
430 Node negR
= makeUnaryMinusNode(r
);
431 Node diff
= NodeManager::currentNM()->mkNode(kind::PLUS
, l
, negR
);
436 RewriteResponse
ArithRewriter::rewriteDiv(TNode t
, bool pre
){
437 Assert(t
.getKind() == kind::DIVISION_TOTAL
|| t
.getKind()== kind::DIVISION
);
442 if(right
.getKind() == kind::CONST_RATIONAL
){
443 const Rational
& den
= right
.getConst
<Rational
>();
446 if(t
.getKind() == kind::DIVISION_TOTAL
){
447 return RewriteResponse(REWRITE_DONE
, mkRationalNode(0));
449 // This is unsupported, but this is not a good place to complain
450 return RewriteResponse(REWRITE_DONE
, t
);
453 Assert(den
!= Rational(0));
455 if(left
.getKind() == kind::CONST_RATIONAL
){
456 const Rational
& num
= left
.getConst
<Rational
>();
457 Rational div
= num
/ den
;
458 Node result
= mkRationalNode(div
);
459 return RewriteResponse(REWRITE_DONE
, result
);
462 Rational div
= den
.inverse();
464 Node result
= mkRationalNode(div
);
466 Node mult
= NodeManager::currentNM()->mkNode(kind::MULT
,left
,result
);
468 return RewriteResponse(REWRITE_DONE
, mult
);
470 return RewriteResponse(REWRITE_AGAIN
, mult
);
473 return RewriteResponse(REWRITE_DONE
, t
);
477 RewriteResponse
ArithRewriter::rewriteIntsDivModTotal(TNode t
, bool pre
){
478 Kind k
= t
.getKind();
479 // Assert(k == kind::INTS_MODULUS || k == kind::INTS_MODULUS_TOTAL ||
480 // k == kind::INTS_DIVISION || k == kind::INTS_DIVISION_TOTAL);
482 //Leaving the function as before (INTS_MODULUS can be handled),
483 // but restricting its use here
484 Assert(k
== kind::INTS_MODULUS_TOTAL
|| k
== kind::INTS_DIVISION_TOTAL
);
485 TNode n
= t
[0], d
= t
[1];
486 bool dIsConstant
= d
.getKind() == kind::CONST_RATIONAL
;
487 if(dIsConstant
&& d
.getConst
<Rational
>().isZero()){
488 if(k
== kind::INTS_MODULUS_TOTAL
|| k
== kind::INTS_DIVISION_TOTAL
){
489 return RewriteResponse(REWRITE_DONE
, mkRationalNode(0));
491 // Do nothing for k == INTS_MODULUS
492 return RewriteResponse(REWRITE_DONE
, t
);
494 }else if(dIsConstant
&& d
.getConst
<Rational
>().isOne()){
495 if(k
== kind::INTS_MODULUS
|| k
== kind::INTS_MODULUS_TOTAL
){
496 return RewriteResponse(REWRITE_DONE
, mkRationalNode(0));
498 Assert(k
== kind::INTS_DIVISION
|| k
== kind::INTS_DIVISION_TOTAL
);
499 return RewriteResponse(REWRITE_AGAIN
, n
);
501 }else if(dIsConstant
&& d
.getConst
<Rational
>().isNegativeOne()){
502 if(k
== kind::INTS_MODULUS
|| k
== kind::INTS_MODULUS_TOTAL
){
503 return RewriteResponse(REWRITE_DONE
, mkRationalNode(0));
505 Assert(k
== kind::INTS_DIVISION
|| k
== kind::INTS_DIVISION_TOTAL
);
506 return RewriteResponse(REWRITE_AGAIN
, NodeManager::currentNM()->mkNode(kind::UMINUS
, n
));
508 }else if(dIsConstant
&& n
.getKind() == kind::CONST_RATIONAL
){
509 Assert(d
.getConst
<Rational
>().isIntegral());
510 Assert(n
.getConst
<Rational
>().isIntegral());
511 Assert(!d
.getConst
<Rational
>().isZero());
512 Integer di
= d
.getConst
<Rational
>().getNumerator();
513 Integer ni
= n
.getConst
<Rational
>().getNumerator();
515 bool isDiv
= (k
== kind::INTS_DIVISION
|| k
== kind::INTS_DIVISION_TOTAL
);
517 Integer result
= isDiv
? ni
.euclidianDivideQuotient(di
) : ni
.euclidianDivideRemainder(di
);
519 Node resultNode
= mkRationalNode(Rational(result
));
520 return RewriteResponse(REWRITE_DONE
, resultNode
);
522 return RewriteResponse(REWRITE_DONE
, t
);
526 }/* CVC4::theory::arith namespace */
527 }/* CVC4::theory namespace */
528 }/* CVC4 namespace */