/********************* */
/*! \file arith_rewriter.cpp
 ** \verbatim
 ** Top contributors (to current version):
 **   Tim King, Morgan Deters, Dejan Jovanovic
 ** This file is part of the CVC4 project.
 ** Copyright (c) 2009-2016 by the authors listed in the file AUTHORS
 ** (in the top-level source directory) and their institutional affiliations.
 ** All rights reserved. See the file COPYING in the top-level source
 ** directory for licensing information.\endverbatim
 **
 ** \brief Rewriter for the theory of arithmetic.
 **
 ** Implements ArithRewriter's pre- and post-rewriting of arithmetic terms
 ** (normalized via Polynomial) and arithmetic atoms (normalized via
 ** Comparison).
 **/
#include <sstream>
#include <vector>

#include "smt/logic_exception.h"
#include "theory/arith/arith_rewriter.h"
#include "theory/arith/arith_utilities.h"
#include "theory/arith/normal_form.h"
#include "theory/theory.h"

namespace CVC4 {
namespace theory {
namespace arith {
32 bool ArithRewriter::isAtom(TNode n
) {
34 return arith::isRelationOperator(k
) || k
== kind::IS_INTEGER
35 || k
== kind::DIVISIBLE
;
38 RewriteResponse
ArithRewriter::rewriteConstant(TNode t
){
40 Assert(t
.getKind() == kind::CONST_RATIONAL
);
42 return RewriteResponse(REWRITE_DONE
, t
);
45 RewriteResponse
ArithRewriter::rewriteVariable(TNode t
){
48 return RewriteResponse(REWRITE_DONE
, t
);
51 RewriteResponse
ArithRewriter::rewriteMinus(TNode t
, bool pre
){
52 Assert(t
.getKind()== kind::MINUS
);
57 Node zeroNode
= mkRationalNode(zero
);
58 return RewriteResponse(REWRITE_DONE
, zeroNode
);
60 Node noMinus
= makeSubtractionNode(t
[0],t
[1]);
61 return RewriteResponse(REWRITE_DONE
, noMinus
);
64 Polynomial minuend
= Polynomial::parsePolynomial(t
[0]);
65 Polynomial subtrahend
= Polynomial::parsePolynomial(t
[1]);
66 Polynomial diff
= minuend
- subtrahend
;
67 return RewriteResponse(REWRITE_DONE
, diff
.getNode());
71 RewriteResponse
ArithRewriter::rewriteUMinus(TNode t
, bool pre
){
72 Assert(t
.getKind()== kind::UMINUS
);
74 if(t
[0].getKind() == kind::CONST_RATIONAL
){
75 Rational neg
= -(t
[0].getConst
<Rational
>());
76 return RewriteResponse(REWRITE_DONE
, mkRationalNode(neg
));
79 Node noUminus
= makeUnaryMinusNode(t
[0]);
81 return RewriteResponse(REWRITE_DONE
, noUminus
);
83 return RewriteResponse(REWRITE_AGAIN
, noUminus
);
86 RewriteResponse
ArithRewriter::preRewriteTerm(TNode t
){
88 return rewriteConstant(t
);
90 return rewriteVariable(t
);
92 switch(Kind k
= t
.getKind()){
94 return rewriteMinus(t
, true);
96 return rewriteUMinus(t
, true);
98 case kind::DIVISION_TOTAL
:
99 return rewriteDiv(t
,true);
101 return preRewritePlus(t
);
103 case kind::NONLINEAR_MULT
:
104 return preRewriteMult(t
);
105 case kind::INTS_DIVISION
:
106 case kind::INTS_MODULUS
:
107 return RewriteResponse(REWRITE_DONE
, t
);
108 case kind::INTS_DIVISION_TOTAL
:
109 case kind::INTS_MODULUS_TOTAL
:
110 return rewriteIntsDivModTotal(t
,true);
113 const Rational
& rat
= t
[0].getConst
<Rational
>();
115 return RewriteResponse(REWRITE_DONE
, t
[0]);
117 return RewriteResponse(REWRITE_DONE
,
118 NodeManager::currentNM()->mkConst(-rat
));
121 return RewriteResponse(REWRITE_DONE
, t
);
122 case kind::IS_INTEGER
:
123 case kind::TO_INTEGER
:
124 return RewriteResponse(REWRITE_DONE
, t
);
126 return RewriteResponse(REWRITE_DONE
, t
[0]);
128 return RewriteResponse(REWRITE_DONE
, t
);
135 RewriteResponse
ArithRewriter::postRewriteTerm(TNode t
){
137 return rewriteConstant(t
);
139 return rewriteVariable(t
);
143 return rewriteMinus(t
, false);
145 return rewriteUMinus(t
, false);
147 case kind::DIVISION_TOTAL
:
148 return rewriteDiv(t
, false);
150 return postRewritePlus(t
);
152 case kind::NONLINEAR_MULT
:
153 return postRewriteMult(t
);
154 case kind::INTS_DIVISION
:
155 case kind::INTS_MODULUS
:
156 return RewriteResponse(REWRITE_DONE
, t
);
157 case kind::INTS_DIVISION_TOTAL
:
158 case kind::INTS_MODULUS_TOTAL
:
159 return rewriteIntsDivModTotal(t
, false);
162 const Rational
& rat
= t
[0].getConst
<Rational
>();
164 return RewriteResponse(REWRITE_DONE
, t
[0]);
166 return RewriteResponse(REWRITE_DONE
,
167 NodeManager::currentNM()->mkConst(-rat
));
171 return RewriteResponse(REWRITE_DONE
, t
[0]);
172 case kind::TO_INTEGER
:
174 return RewriteResponse(REWRITE_DONE
, NodeManager::currentNM()->mkConst(Rational(t
[0].getConst
<Rational
>().floor())));
176 if(t
[0].getType().isInteger()) {
177 return RewriteResponse(REWRITE_DONE
, t
[0]);
179 //Unimplemented("TO_INTEGER, nonconstant");
180 //return rewriteToInteger(t);
181 return RewriteResponse(REWRITE_DONE
, t
);
182 case kind::IS_INTEGER
:
184 return RewriteResponse(REWRITE_DONE
, NodeManager::currentNM()->mkConst(t
[0].getConst
<Rational
>().getDenominator() == 1));
186 if(t
[0].getType().isInteger()) {
187 return RewriteResponse(REWRITE_DONE
, NodeManager::currentNM()->mkConst(true));
189 //Unimplemented("IS_INTEGER, nonconstant");
190 //return rewriteIsInteger(t);
191 return RewriteResponse(REWRITE_DONE
, t
);
194 if(t
[1].getKind() == kind::CONST_RATIONAL
){
195 const Rational
& exp
= t
[1].getConst
<Rational
>();
198 return RewriteResponse(REWRITE_DONE
, mkRationalNode(Rational(1)));
199 }else if(exp
.sgn() > 0 && exp
.isIntegral()){
200 Integer num
= exp
.getNumerator();
201 NodeBuilder
<> nb(kind::MULT
);
203 for(Integer
i(0); i
< num
; i
= i
+ one
){
206 Assert(nb
.getNumChildren() > 0);
208 return RewriteResponse(REWRITE_AGAIN
, mult
);
212 // Todo improve the exception thrown
213 std::stringstream ss
;
214 ss
<< "The POW(^) operator can only be used with a natural number ";
215 ss
<< "in the exponent. Exception occured in:" << std::endl
;
217 throw LogicException(ss
.str());
226 RewriteResponse
ArithRewriter::preRewriteMult(TNode t
){
227 Assert(t
.getKind()== kind::MULT
|| t
.getKind()== kind::NONLINEAR_MULT
);
229 if(t
.getNumChildren() == 2){
230 if(t
[0].getKind() == kind::CONST_RATIONAL
231 && t
[0].getConst
<Rational
>().isOne()){
232 return RewriteResponse(REWRITE_DONE
, t
[1]);
234 if(t
[1].getKind() == kind::CONST_RATIONAL
235 && t
[1].getConst
<Rational
>().isOne()){
236 return RewriteResponse(REWRITE_DONE
, t
[0]);
240 // Rewrite multiplications with a 0 argument and to 0
241 for(TNode::iterator i
= t
.begin(); i
!= t
.end(); ++i
) {
242 if((*i
).getKind() == kind::CONST_RATIONAL
) {
243 if((*i
).getConst
<Rational
>().isZero()) {
245 return RewriteResponse(REWRITE_DONE
, zero
);
249 return RewriteResponse(REWRITE_DONE
, t
);
252 static bool canFlatten(Kind k
, TNode t
){
253 for(TNode::iterator i
= t
.begin(); i
!= t
.end(); ++i
) {
255 if(child
.getKind() == k
){
262 static void flatten(std::vector
<TNode
>& pb
, Kind k
, TNode t
){
263 if(t
.getKind() == k
){
264 for(TNode::iterator i
= t
.begin(); i
!= t
.end(); ++i
) {
266 if(child
.getKind() == k
){
267 flatten(pb
, k
, child
);
277 static Node
flatten(Kind k
, TNode t
){
278 std::vector
<TNode
> pb
;
280 Assert(pb
.size() >= 2);
281 return NodeManager::currentNM()->mkNode(k
, pb
);
284 RewriteResponse
ArithRewriter::preRewritePlus(TNode t
){
285 Assert(t
.getKind()== kind::PLUS
);
287 if(canFlatten(kind::PLUS
, t
)){
288 return RewriteResponse(REWRITE_DONE
, flatten(kind::PLUS
, t
));
290 return RewriteResponse(REWRITE_DONE
, t
);
294 RewriteResponse
ArithRewriter::postRewritePlus(TNode t
){
295 Assert(t
.getKind()== kind::PLUS
);
297 std::vector
<Monomial
> monomials
;
298 std::vector
<Polynomial
> polynomials
;
300 for(TNode::iterator i
= t
.begin(), end
= t
.end(); i
!= end
; ++i
){
302 if(Monomial::isMember(curr
)){
303 monomials
.push_back(Monomial::parseMonomial(curr
));
305 polynomials
.push_back(Polynomial::parsePolynomial(curr
));
309 if(!monomials
.empty()){
310 Monomial::sort(monomials
);
311 Monomial::combineAdjacentMonomials(monomials
);
312 polynomials
.push_back(Polynomial::mkPolynomial(monomials
));
315 Polynomial res
= Polynomial::sumPolynomials(polynomials
);
317 return RewriteResponse(REWRITE_DONE
, res
.getNode());
320 RewriteResponse
ArithRewriter::postRewriteMult(TNode t
){
321 Assert(t
.getKind()== kind::MULT
|| t
.getKind()==kind::NONLINEAR_MULT
);
323 Polynomial res
= Polynomial::mkOne();
325 for(TNode::iterator i
= t
.begin(), end
= t
.end(); i
!= end
; ++i
){
327 Polynomial currPoly
= Polynomial::parsePolynomial(curr
);
329 res
= res
* currPoly
;
332 return RewriteResponse(REWRITE_DONE
, res
.getNode());
335 RewriteResponse
ArithRewriter::postRewriteAtom(TNode atom
){
336 if(atom
.getKind() == kind::IS_INTEGER
) {
337 if(atom
[0].isConst()) {
338 return RewriteResponse(REWRITE_DONE
, NodeManager::currentNM()->mkConst(atom
[0].getConst
<Rational
>().isIntegral()));
340 if(atom
[0].getType().isInteger()) {
341 return RewriteResponse(REWRITE_DONE
, NodeManager::currentNM()->mkConst(true));
343 // not supported, but this isn't the right place to complain
344 return RewriteResponse(REWRITE_DONE
, atom
);
345 } else if(atom
.getKind() == kind::DIVISIBLE
) {
346 if(atom
[0].isConst()) {
347 return RewriteResponse(REWRITE_DONE
, NodeManager::currentNM()->mkConst(bool((atom
[0].getConst
<Rational
>() / atom
.getOperator().getConst
<Divisible
>().k
).isIntegral())));
349 if(atom
.getOperator().getConst
<Divisible
>().k
.isOne()) {
350 return RewriteResponse(REWRITE_DONE
, NodeManager::currentNM()->mkConst(true));
352 return RewriteResponse(REWRITE_AGAIN
, NodeManager::currentNM()->mkNode(kind::EQUAL
, NodeManager::currentNM()->mkNode(kind::INTS_MODULUS_TOTAL
, atom
[0], NodeManager::currentNM()->mkConst(Rational(atom
.getOperator().getConst
<Divisible
>().k
))), NodeManager::currentNM()->mkConst(Rational(0))));
356 TNode left
= atom
[0];
357 TNode right
= atom
[1];
359 Polynomial pleft
= Polynomial::parsePolynomial(left
);
360 Polynomial pright
= Polynomial::parsePolynomial(right
);
362 Debug("arith::rewriter") << "pleft " << pleft
.getNode() << std::endl
;
363 Debug("arith::rewriter") << "pright " << pright
.getNode() << std::endl
;
365 Comparison cmp
= Comparison::mkComparison(atom
.getKind(), pleft
, pright
);
366 Assert(cmp
.isNormalForm());
367 return RewriteResponse(REWRITE_DONE
, cmp
.getNode());
370 RewriteResponse
ArithRewriter::preRewriteAtom(TNode atom
){
371 Assert(isAtom(atom
));
373 NodeManager
* currNM
= NodeManager::currentNM();
375 if(atom
.getKind() == kind::EQUAL
) {
376 if(atom
[0] == atom
[1]) {
377 return RewriteResponse(REWRITE_DONE
, currNM
->mkConst(true));
379 }else if(atom
.getKind() == kind::GT
){
380 Node leq
= currNM
->mkNode(kind::LEQ
, atom
[0], atom
[1]);
381 return RewriteResponse(REWRITE_DONE
, currNM
->mkNode(kind::NOT
, leq
));
382 }else if(atom
.getKind() == kind::LT
){
383 Node geq
= currNM
->mkNode(kind::GEQ
, atom
[0], atom
[1]);
384 return RewriteResponse(REWRITE_DONE
, currNM
->mkNode(kind::NOT
, geq
));
385 }else if(atom
.getKind() == kind::IS_INTEGER
){
386 if(atom
[0].getType().isInteger()){
387 return RewriteResponse(REWRITE_DONE
, currNM
->mkConst(true));
389 }else if(atom
.getKind() == kind::DIVISIBLE
){
390 if(atom
.getOperator().getConst
<Divisible
>().k
.isOne()){
391 return RewriteResponse(REWRITE_DONE
, currNM
->mkConst(true));
395 return RewriteResponse(REWRITE_DONE
, atom
);
398 RewriteResponse
ArithRewriter::postRewrite(TNode t
){
400 RewriteResponse response
= postRewriteTerm(t
);
401 if(Debug
.isOn("arith::rewriter") && response
.status
== REWRITE_DONE
) {
402 Polynomial::parsePolynomial(response
.node
);
406 RewriteResponse response
= postRewriteAtom(t
);
407 if(Debug
.isOn("arith::rewriter") && response
.status
== REWRITE_DONE
) {
408 Comparison::parseNormalForm(response
.node
);
413 return RewriteResponse(REWRITE_DONE
, Node::null());
417 RewriteResponse
ArithRewriter::preRewrite(TNode t
){
419 return preRewriteTerm(t
);
421 return preRewriteAtom(t
);
424 return RewriteResponse(REWRITE_DONE
, Node::null());
428 Node
ArithRewriter::makeUnaryMinusNode(TNode n
){
429 Rational
qNegOne(-1);
430 return NodeManager::currentNM()->mkNode(kind::MULT
, mkRationalNode(qNegOne
),n
);
433 Node
ArithRewriter::makeSubtractionNode(TNode l
, TNode r
){
434 Node negR
= makeUnaryMinusNode(r
);
435 Node diff
= NodeManager::currentNM()->mkNode(kind::PLUS
, l
, negR
);
440 RewriteResponse
ArithRewriter::rewriteDiv(TNode t
, bool pre
){
441 Assert(t
.getKind() == kind::DIVISION_TOTAL
|| t
.getKind()== kind::DIVISION
);
446 if(right
.getKind() == kind::CONST_RATIONAL
){
447 const Rational
& den
= right
.getConst
<Rational
>();
450 if(t
.getKind() == kind::DIVISION_TOTAL
){
451 return RewriteResponse(REWRITE_DONE
, mkRationalNode(0));
453 // This is unsupported, but this is not a good place to complain
454 return RewriteResponse(REWRITE_DONE
, t
);
457 Assert(den
!= Rational(0));
459 if(left
.getKind() == kind::CONST_RATIONAL
){
460 const Rational
& num
= left
.getConst
<Rational
>();
461 Rational div
= num
/ den
;
462 Node result
= mkRationalNode(div
);
463 return RewriteResponse(REWRITE_DONE
, result
);
466 Rational div
= den
.inverse();
468 Node result
= mkRationalNode(div
);
470 Node mult
= NodeManager::currentNM()->mkNode(kind::MULT
,left
,result
);
472 return RewriteResponse(REWRITE_DONE
, mult
);
474 return RewriteResponse(REWRITE_AGAIN
, mult
);
477 return RewriteResponse(REWRITE_DONE
, t
);
481 RewriteResponse
ArithRewriter::rewriteIntsDivModTotal(TNode t
, bool pre
){
482 Kind k
= t
.getKind();
483 // Assert(k == kind::INTS_MODULUS || k == kind::INTS_MODULUS_TOTAL ||
484 // k == kind::INTS_DIVISION || k == kind::INTS_DIVISION_TOTAL);
486 //Leaving the function as before (INTS_MODULUS can be handled),
487 // but restricting its use here
488 Assert(k
== kind::INTS_MODULUS_TOTAL
|| k
== kind::INTS_DIVISION_TOTAL
);
489 TNode n
= t
[0], d
= t
[1];
490 bool dIsConstant
= d
.getKind() == kind::CONST_RATIONAL
;
491 if(dIsConstant
&& d
.getConst
<Rational
>().isZero()){
492 if(k
== kind::INTS_MODULUS_TOTAL
|| k
== kind::INTS_DIVISION_TOTAL
){
493 return RewriteResponse(REWRITE_DONE
, mkRationalNode(0));
495 // Do nothing for k == INTS_MODULUS
496 return RewriteResponse(REWRITE_DONE
, t
);
498 }else if(dIsConstant
&& d
.getConst
<Rational
>().isOne()){
499 if(k
== kind::INTS_MODULUS
|| k
== kind::INTS_MODULUS_TOTAL
){
500 return RewriteResponse(REWRITE_DONE
, mkRationalNode(0));
502 Assert(k
== kind::INTS_DIVISION
|| k
== kind::INTS_DIVISION_TOTAL
);
503 return RewriteResponse(REWRITE_AGAIN
, n
);
505 }else if(dIsConstant
&& d
.getConst
<Rational
>().isNegativeOne()){
506 if(k
== kind::INTS_MODULUS
|| k
== kind::INTS_MODULUS_TOTAL
){
507 return RewriteResponse(REWRITE_DONE
, mkRationalNode(0));
509 Assert(k
== kind::INTS_DIVISION
|| k
== kind::INTS_DIVISION_TOTAL
);
510 return RewriteResponse(REWRITE_AGAIN
, NodeManager::currentNM()->mkNode(kind::UMINUS
, n
));
512 }else if(dIsConstant
&& n
.getKind() == kind::CONST_RATIONAL
){
513 Assert(d
.getConst
<Rational
>().isIntegral());
514 Assert(n
.getConst
<Rational
>().isIntegral());
515 Assert(!d
.getConst
<Rational
>().isZero());
516 Integer di
= d
.getConst
<Rational
>().getNumerator();
517 Integer ni
= n
.getConst
<Rational
>().getNumerator();
519 bool isDiv
= (k
== kind::INTS_DIVISION
|| k
== kind::INTS_DIVISION_TOTAL
);
521 Integer result
= isDiv
? ni
.euclidianDivideQuotient(di
) : ni
.euclidianDivideRemainder(di
);
523 Node resultNode
= mkRationalNode(Rational(result
));
524 return RewriteResponse(REWRITE_DONE
, resultNode
);
526 return RewriteResponse(REWRITE_DONE
, t
);
530 }/* CVC4::theory::arith namespace */
531 }/* CVC4::theory namespace */
532 }/* CVC4 namespace */