1 /********************* */
2 /*! \file arith_rewriter.cpp
4 ** Original author: Tim King
5 ** Major contributors: none
6 ** Minor contributors (to current version): Morgan Deters, Dejan Jovanovic
7 ** This file is part of the CVC4 project.
8 ** Copyright (c) 2009-2013 New York University and The University of Iowa
9 ** See the file COPYING in the top-level source directory for licensing
10 ** information.\endverbatim
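12 ** \brief Pre- and post-rewriting of arithmetic terms and atoms.
14 ** Implements the ArithRewriter methods that rewrite arithmetic terms (constants, variables, sums, products, subtraction, unary minus, division, integer division/modulus) and relational atoms.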
15 ** \todo document this file
18 #include "theory/theory.h"
19 #include "theory/arith/normal_form.h"
20 #include "theory/arith/arith_rewriter.h"
21 #include "theory/arith/arith_utilities.h"
/**
 * Tests whether a node is treated as an atom by this rewriter.
 *
 * @param n the node to test
 * @return the result of arith::isRelationOperator() applied to n's kind
 */
31 bool ArithRewriter::isAtom(TNode n
) {
32 return arith::isRelationOperator(n
.getKind());
/**
 * Rewrite step for a rational constant.
 *
 * @param t the node to rewrite; asserted to have kind CONST_RATIONAL
 * @return t itself, unchanged, with status REWRITE_DONE
 */
35 RewriteResponse
ArithRewriter::rewriteConstant(TNode t
){
37 Assert(t
.getKind() == kind::CONST_RATIONAL
);
// A constant needs no further rewriting, so it is handed back as-is.
39 return RewriteResponse(REWRITE_DONE
, t
);
/**
 * Rewrite step for a variable.
 *
 * @param t the node to rewrite
 * @return t itself, unchanged, with status REWRITE_DONE
 */
42 RewriteResponse
ArithRewriter::rewriteVariable(TNode t
){
45 return RewriteResponse(REWRITE_DONE
, t
);
/**
 * Rewrites a binary subtraction node.
 *
 * Three outcomes are visible below, each returned with REWRITE_DONE:
 *  - a rational constant node built from `zero`,
 *  - the result of makeSubtractionNode(t[0], t[1]),
 *  - the node of the Polynomial difference of the two parsed children.
 *
 * NOTE(review): the conditions selecting between these outcomes (and the
 * declaration of `zero`) are not visible in this listing; presumably `pre`
 * and a t[0] == t[1] check decide -- confirm against the full source.
 *
 * @param t the node to rewrite; asserted to have kind MINUS
 * @param pre presumably true when called from pre-rewriting (callers in this
 *        file pass true from preRewriteTerm and false from postRewriteTerm)
 * @return the rewritten node with status REWRITE_DONE
 */
48 RewriteResponse
ArithRewriter::rewriteMinus(TNode t
, bool pre
){
49 Assert(t
.getKind()== kind::MINUS
);
// Outcome 1: the subtraction collapses to a constant node built from `zero`.
54 Node zeroNode
= mkRationalNode(zero
);
55 return RewriteResponse(REWRITE_DONE
, zeroNode
);
// Outcome 2: build the subtraction through makeSubtractionNode.
57 Node noMinus
= makeSubtractionNode(t
[0],t
[1]);
58 return RewriteResponse(REWRITE_DONE
, noMinus
);
// Outcome 3: parse both children as polynomials and subtract them.
61 Polynomial minuend
= Polynomial::parsePolynomial(t
[0]);
62 Polynomial subtrahend
= Polynomial::parsePolynomial(t
[1]);
63 Polynomial diff
= minuend
- subtrahend
;
64 return RewriteResponse(REWRITE_DONE
, diff
.getNode());
/**
 * Rewrites a unary minus node.
 *
 * If the child is a rational constant, the negated constant is returned with
 * REWRITE_DONE. Otherwise the node produced by makeUnaryMinusNode(t[0]) is
 * returned with either REWRITE_DONE or REWRITE_AGAIN.
 *
 * NOTE(review): the condition choosing between REWRITE_DONE and
 * REWRITE_AGAIN is not visible in this listing; presumably it is `pre` --
 * confirm against the full source.
 *
 * @param t the node to rewrite; asserted to have kind UMINUS
 * @param pre presumably true when called from pre-rewriting (callers in this
 *        file pass true from preRewriteTerm and false from postRewriteTerm)
 * @return the rewritten node and its rewrite status
 */
68 RewriteResponse
ArithRewriter::rewriteUMinus(TNode t
, bool pre
){
69 Assert(t
.getKind()== kind::UMINUS
);
// Constant child: negate the constant directly.
71 if(t
[0].getKind() == kind::CONST_RATIONAL
){
72 Rational neg
= -(t
[0].getConst
<Rational
>());
73 return RewriteResponse(REWRITE_DONE
, mkRationalNode(neg
));
// Non-constant child: re-express the negation via makeUnaryMinusNode.
76 Node noUminus
= makeUnaryMinusNode(t
[0]);
78 return RewriteResponse(REWRITE_DONE
, noUminus
);
80 return RewriteResponse(REWRITE_AGAIN
, noUminus
);
/**
 * Pre-rewrite dispatcher for terms: forwards t to the matching rewrite
 * helper and returns that helper's response.
 *
 * Helpers taking a `pre` flag (rewriteMinus, rewriteUMinus, rewriteDiv,
 * rewriteIntsDivModTotal) are called with true.
 *
 * NOTE(review): the guards in front of rewriteConstant/rewriteVariable and
 * several case labels are not visible in this listing; only the
 * DIVISION_TOTAL, INTS_DIVISION_TOTAL and INTS_MODULUS_TOTAL labels are.
 *
 * @param t the term to pre-rewrite
 * @return the response of the selected helper
 */
83 RewriteResponse
ArithRewriter::preRewriteTerm(TNode t
){
85 return rewriteConstant(t
);
87 return rewriteVariable(t
);
// Remaining terms are dispatched on their kind.
89 switch(Kind k
= t
.getKind()){
91 return rewriteMinus(t
, true);
93 return rewriteUMinus(t
, true);
95 case kind::DIVISION_TOTAL
:
96 return rewriteDiv(t
,true);
98 return preRewritePlus(t
);
100 return preRewriteMult(t
);
// The non-total integer division/modulus kinds are deliberately disabled.
101 //case kind::INTS_DIVISION:
102 //case kind::INTS_MODULUS:
103 case kind::INTS_DIVISION_TOTAL
:
104 case kind::INTS_MODULUS_TOTAL
:
105 return rewriteIntsDivModTotal(t
,true);
/**
 * Post-rewrite dispatcher for terms: forwards t to the matching rewrite
 * helper and returns that helper's response.
 *
 * Helpers taking a `pre` flag (rewriteMinus, rewriteUMinus, rewriteDiv,
 * rewriteIntsDivModTotal) are called with false.
 *
 * NOTE(review): the guards in front of rewriteConstant/rewriteVariable, the
 * switch header and several case labels are not visible in this listing.
 *
 * @param t the term to post-rewrite
 * @return the response of the selected helper
 */
111 RewriteResponse
ArithRewriter::postRewriteTerm(TNode t
){
113 return rewriteConstant(t
);
115 return rewriteVariable(t
);
119 return rewriteMinus(t
, false);
121 return rewriteUMinus(t
, false);
123 case kind::DIVISION_TOTAL
:
124 return rewriteDiv(t
, false);
126 return postRewritePlus(t
);
128 return postRewriteMult(t
);
// The non-total integer division/modulus kinds are deliberately disabled.
129 //case kind::INTS_DIVISION:
130 //case kind::INTS_MODULUS:
131 case kind::INTS_DIVISION_TOTAL
:
132 case kind::INTS_MODULUS_TOTAL
:
133 return rewriteIntsDivModTotal(t
, false);
/**
 * Pre-rewrite step for a multiplication node.
 *
 * Scans the children of t; as soon as one is a rational constant equal to
 * qZero, a constant node built from qZero is returned. If no such child is
 * found, t is returned unchanged. Both results carry REWRITE_DONE.
 *
 * NOTE(review): the declaration of `qZero` is not visible in this listing;
 * presumably it is the rational 0 -- confirm.
 *
 * @param t the node to rewrite; asserted to have kind MULT
 * @return the constant node or t itself, with status REWRITE_DONE
 */
141 RewriteResponse
ArithRewriter::preRewriteMult(TNode t
){
142 Assert(t
.getKind()== kind::MULT
);
144 // Rewrite a multiplication that has a 0 argument to 0
147 for(TNode::iterator i
= t
.begin(); i
!= t
.end(); ++i
) {
148 if((*i
).getKind() == kind::CONST_RATIONAL
) {
149 if((*i
).getConst
<Rational
>() == qZero
) {
150 return RewriteResponse(REWRITE_DONE
, mkRationalNode(qZero
));
// No child matched qZero: leave the multiplication untouched.
154 return RewriteResponse(REWRITE_DONE
, t
);
/**
 * Pre-rewrite step for an addition node; performs no rewriting.
 *
 * @param t the node to rewrite; asserted to have kind PLUS
 * @return t itself, unchanged, with status REWRITE_DONE
 */
156 RewriteResponse
ArithRewriter::preRewritePlus(TNode t
){
157 Assert(t
.getKind()== kind::PLUS
);
159 return RewriteResponse(REWRITE_DONE
, t
);
/**
 * Post-rewrite step for an addition node.
 *
 * Starts from Polynomial::mkZero(), and on each loop iteration parses `curr`
 * as a Polynomial and adds it to the running sum; the node of the final
 * Polynomial is returned.
 *
 * NOTE(review): the declaration of `curr` is not visible in this listing;
 * presumably it is the current child *i -- confirm.
 *
 * @param t the node to rewrite; asserted to have kind PLUS
 * @return the node of the summed Polynomial, with status REWRITE_DONE
 */
162 RewriteResponse
ArithRewriter::postRewritePlus(TNode t
){
163 Assert(t
.getKind()== kind::PLUS
);
// Running sum, seeded with the zero polynomial.
165 Polynomial res
= Polynomial::mkZero();
167 for(TNode::iterator i
= t
.begin(), end
= t
.end(); i
!= end
; ++i
){
169 Polynomial currPoly
= Polynomial::parsePolynomial(curr
);
171 res
= res
+ currPoly
;
174 return RewriteResponse(REWRITE_DONE
, res
.getNode());
/**
 * Post-rewrite step for a multiplication node.
 *
 * Starts from Polynomial::mkOne(), and on each loop iteration parses `curr`
 * as a Polynomial and multiplies it into the running product; the node of
 * the final Polynomial is returned.
 *
 * NOTE(review): the declaration of `curr` is not visible in this listing;
 * presumably it is the current child *i -- confirm.
 *
 * @param t the node to rewrite; asserted to have kind MULT
 * @return the node of the product Polynomial, with status REWRITE_DONE
 */
177 RewriteResponse
ArithRewriter::postRewriteMult(TNode t
){
178 Assert(t
.getKind()== kind::MULT
);
// Running product, seeded with the polynomial one.
180 Polynomial res
= Polynomial::mkOne();
182 for(TNode::iterator i
= t
.begin(), end
= t
.end(); i
!= end
; ++i
){
184 Polynomial currPoly
= Polynomial::parsePolynomial(curr
);
186 res
= res
* currPoly
;
189 return RewriteResponse(REWRITE_DONE
, res
.getNode());
/**
 * Post-rewrite step for a binary relational atom.
 *
 * Parses both sides of the atom as Polynomials, builds a Comparison from the
 * atom's kind and the two polynomials, asserts that the Comparison is in
 * normal form, and returns its node.
 *
 * @param atom the atom to rewrite; its children 0 and 1 are read as the two
 *        sides of the relation
 * @return the node of the Comparison, with status REWRITE_DONE
 */
192 RewriteResponse
ArithRewriter::postRewriteAtom(TNode atom
){
194 TNode left
= atom
[0];
195 TNode right
= atom
[1];
197 Polynomial pleft
= Polynomial::parsePolynomial(left
);
198 Polynomial pright
= Polynomial::parsePolynomial(right
);
200 Comparison cmp
= Comparison::mkComparison(atom
.getKind(), pleft
, pright
);
201 Assert(cmp
.isNormalForm());
202 return RewriteResponse(REWRITE_DONE
, cmp
.getNode());
/**
 * Pre-rewrite step for a relational atom.
 *
 * Visible rewrites, all returned with REWRITE_DONE:
 *  - EQUAL whose two children are the same node  ->  the constant true
 *  - (GT a b)  ->  (NOT (LEQ a b))
 *  - (LT a b)  ->  (NOT (GEQ a b))
 * The final return hands the atom back unchanged.
 *
 * @param atom the atom to rewrite; asserted to satisfy isAtom()
 * @return the rewritten (or unchanged) atom with status REWRITE_DONE
 */
205 RewriteResponse
ArithRewriter::preRewriteAtom(TNode atom
){
206 Assert(isAtom(atom
));
208 NodeManager
* currNM
= NodeManager::currentNM();
210 if(atom
.getKind() == kind::EQUAL
) {
// An equality between identical nodes is trivially true.
211 if(atom
[0] == atom
[1]) {
212 return RewriteResponse(REWRITE_DONE
, currNM
->mkConst(true));
214 }else if(atom
.getKind() == kind::GT
){
// a > b is expressed as not (a <= b).
215 Node leq
= currNM
->mkNode(kind::LEQ
, atom
[0], atom
[1]);
216 return RewriteResponse(REWRITE_DONE
, currNM
->mkNode(kind::NOT
, leq
));
217 }else if(atom
.getKind() == kind::LT
){
// a < b is expressed as not (a >= b).
218 Node geq
= currNM
->mkNode(kind::GEQ
, atom
[0], atom
[1]);
219 return RewriteResponse(REWRITE_DONE
, currNM
->mkNode(kind::NOT
, geq
));
// No rewrite applied above: return the atom as-is.
222 return RewriteResponse(REWRITE_DONE
, atom
);
/**
 * Post-rewrite entry point: runs postRewriteTerm() or postRewriteAtom() on t.
 *
 * When the debug tag "arith::rewriter" is on and the response status is
 * REWRITE_DONE, the resulting node is additionally passed to
 * Polynomial::parsePolynomial (term case) or Comparison::parseNormalForm
 * (atom case); the result of that call is not used here, so it presumably
 * serves as a sanity check on the rewritten node.
 *
 * NOTE(review): the conditions selecting the term/atom branches and the
 * statements returning `response` are not visible in this listing.
 *
 * @param t the node to post-rewrite
 * @return the helper's response; the final visible return yields a null Node
 *         with REWRITE_DONE (presumably an unreachable fall-through)
 */
225 RewriteResponse
ArithRewriter::postRewrite(TNode t
){
227 RewriteResponse response
= postRewriteTerm(t
);
228 if(Debug
.isOn("arith::rewriter") && response
.status
== REWRITE_DONE
) {
229 Polynomial::parsePolynomial(response
.node
);
233 RewriteResponse response
= postRewriteAtom(t
);
234 if(Debug
.isOn("arith::rewriter") && response
.status
== REWRITE_DONE
) {
235 Comparison::parseNormalForm(response
.node
);
240 return RewriteResponse(REWRITE_DONE
, Node::null());
/**
 * Pre-rewrite entry point: forwards t to preRewriteTerm() or
 * preRewriteAtom() and returns that response.
 *
 * NOTE(review): the conditions selecting between the two calls are not
 * visible in this listing.
 *
 * @param t the node to pre-rewrite
 * @return the helper's response; the final visible return yields a null Node
 *         with REWRITE_DONE (presumably an unreachable fall-through)
 */
244 RewriteResponse
ArithRewriter::preRewrite(TNode t
){
246 return preRewriteTerm(t
);
248 return preRewriteAtom(t
);
251 return RewriteResponse(REWRITE_DONE
, Node::null());
/**
 * Builds the negation of n as a multiplication by the constant -1.
 *
 * @param n the node to negate
 * @return a MULT node whose children are the rational constant -1 and n
 */
255 Node
ArithRewriter::makeUnaryMinusNode(TNode n
){
256 Rational
qNegOne(-1);
257 return NodeManager::currentNM()->mkNode(kind::MULT
, mkRationalNode(qNegOne
),n
);
/**
 * Builds l - r as the addition of l and the negation of r, where the
 * negation comes from makeUnaryMinusNode(r).
 *
 * NOTE(review): the return statement is not visible in this listing;
 * presumably `diff` is returned -- confirm.
 *
 * @param l the left operand (minuend)
 * @param r the right operand (subtrahend)
 * @return presumably the PLUS node built below
 */
260 Node
ArithRewriter::makeSubtractionNode(TNode l
, TNode r
){
261 Node negR
= makeUnaryMinusNode(r
);
262 Node diff
= NodeManager::currentNM()->mkNode(kind::PLUS
, l
, negR
);
/**
 * Rewrites a division node.
 *
 * Visible behaviour when `right` is a rational constant `den`:
 *  - a branch that returns the constant 0 for DIVISION_TOTAL and otherwise
 *    returns t unchanged;
 *  - past that branch den is asserted to be non-zero;
 *  - if `left` is also a rational constant, the constant left/den is
 *    returned with REWRITE_DONE;
 *  - otherwise the division becomes (MULT left (1/den)), returned with
 *    either REWRITE_DONE or REWRITE_AGAIN.
 * The final return hands t back unchanged with REWRITE_DONE.
 *
 * NOTE(review): the declarations of `left` and `right` (presumably t[0] and
 * t[1]), the guard of the first branch (presumably den == 0) and the
 * condition choosing DONE vs AGAIN (presumably `pre`) are not visible in
 * this listing -- confirm against the full source.
 *
 * @param t the node to rewrite; asserted to have kind DIVISION_TOTAL or
 *        DIVISION
 * @param pre presumably true when called from pre-rewriting (callers in this
 *        file pass true from preRewriteTerm and false from postRewriteTerm)
 * @return the rewritten (or unchanged) node and its rewrite status
 */
267 RewriteResponse
ArithRewriter::rewriteDiv(TNode t
, bool pre
){
268 Assert(t
.getKind() == kind::DIVISION_TOTAL
|| t
.getKind()== kind::DIVISION
);
273 if(right
.getKind() == kind::CONST_RATIONAL
){
274 const Rational
& den
= right
.getConst
<Rational
>();
// Total division yields the constant 0 in this branch.
277 if(t
.getKind() == kind::DIVISION_TOTAL
){
278 return RewriteResponse(REWRITE_DONE
, mkRationalNode(0));
280 // This is unsupported, but this is not a good place to complain
281 return RewriteResponse(REWRITE_DONE
, t
);
284 Assert(den
!= Rational(0));
// Both operands constant: fold the quotient into a single constant.
286 if(left
.getKind() == kind::CONST_RATIONAL
){
287 const Rational
& num
= left
.getConst
<Rational
>();
288 Rational div
= num
/ den
;
289 Node result
= mkRationalNode(div
);
290 return RewriteResponse(REWRITE_DONE
, result
);
// Only the denominator is constant: multiply by its inverse instead.
293 Rational div
= den
.inverse();
295 Node result
= mkRationalNode(div
);
297 Node mult
= NodeManager::currentNM()->mkNode(kind::MULT
,left
,result
);
299 return RewriteResponse(REWRITE_DONE
, mult
);
301 return RewriteResponse(REWRITE_AGAIN
, mult
);
// Remaining case: leave the division untouched.
304 return RewriteResponse(REWRITE_DONE
, t
);
/**
 * Rewrites a total integer division or modulus node (n op d).
 *
 * Visible cases, with d = t[1] and n = t[0]:
 *  - d is the constant 0: returns the constant 0 (REWRITE_DONE);
 *  - d is the constant 1: modulus returns the constant 0 (REWRITE_DONE),
 *    division returns n with REWRITE_AGAIN;
 *  - n and d are both constants: both are asserted integral and d non-zero,
 *    and the result is the constant given by euclidianDivideQuotient (for
 *    division) or euclidianDivideRemainder (for modulus) of their numerators;
 *  - otherwise t is returned unchanged (REWRITE_DONE).
 *
 * The body still tests for the non-total kinds INTS_MODULUS/INTS_DIVISION,
 * but the assertion below restricts t to the *_TOTAL kinds.
 *
 * @param t the node to rewrite; asserted to have kind INTS_MODULUS_TOTAL or
 *        INTS_DIVISION_TOTAL
 * @param pre not referenced in the visible lines of this function
 * @return the rewritten (or unchanged) node and its rewrite status
 */
308 RewriteResponse
ArithRewriter::rewriteIntsDivModTotal(TNode t
, bool pre
){
309 Kind k
= t
.getKind();
310 // Assert(k == kind::INTS_MODULUS || k == kind::INTS_MODULUS_TOTAL ||
311 // k == kind::INTS_DIVISION || k == kind::INTS_DIVISION_TOTAL);
313 //Leaving the function as before (INTS_MODULUS can be handled),
314 // but restricting its use here
315 Assert(k
== kind::INTS_MODULUS_TOTAL
|| k
== kind::INTS_DIVISION_TOTAL
);
// n is the dividend, d the divisor.
316 TNode n
= t
[0], d
= t
[1];
317 bool dIsConstant
= d
.getKind() == kind::CONST_RATIONAL
;
// Case: constant divisor equal to zero.
318 if(dIsConstant
&& d
.getConst
<Rational
>().isZero()){
319 if(k
== kind::INTS_MODULUS_TOTAL
|| k
== kind::INTS_DIVISION_TOTAL
){
320 return RewriteResponse(REWRITE_DONE
, mkRationalNode(0));
322 // Do nothing for k == INTS_MODULUS
323 return RewriteResponse(REWRITE_DONE
, t
);
// Case: constant divisor equal to one.
325 }else if(dIsConstant
&& d
.getConst
<Rational
>().isOne()){
326 if(k
== kind::INTS_MODULUS
|| k
== kind::INTS_MODULUS_TOTAL
){
327 return RewriteResponse(REWRITE_DONE
, mkRationalNode(0));
329 Assert(k
== kind::INTS_DIVISION
|| k
== kind::INTS_DIVISION_TOTAL
);
// Division by one yields the dividend, which is rewritten again.
330 return RewriteResponse(REWRITE_AGAIN
, n
);
// Case: both operands constant; evaluate the operation directly.
332 }else if(dIsConstant
&& n
.getKind() == kind::CONST_RATIONAL
){
333 Assert(d
.getConst
<Rational
>().isIntegral());
334 Assert(n
.getConst
<Rational
>().isIntegral());
335 Assert(!d
.getConst
<Rational
>().isZero());
336 Integer di
= d
.getConst
<Rational
>().getNumerator();
337 Integer ni
= n
.getConst
<Rational
>().getNumerator();
339 bool isDiv
= (k
== kind::INTS_DIVISION
|| k
== kind::INTS_DIVISION_TOTAL
);
// Quotient for division kinds, remainder for modulus kinds.
341 Integer result
= isDiv
? ni
.euclidianDivideQuotient(di
) : ni
.euclidianDivideRemainder(di
);
343 Node resultNode
= mkRationalNode(Rational(result
));
344 return RewriteResponse(REWRITE_DONE
, resultNode
);
// No case applied: leave the node untouched.
346 return RewriteResponse(REWRITE_DONE
, t
);
350 }/* CVC4::theory::arith namespace */
351 }/* CVC4::theory namespace */
352 }/* CVC4 namespace */