/*********************                                                        */
/*! \file arith_rewriter.cpp
 ** \verbatim
 ** Original author: taking
 ** Major contributors: none
 ** Minor contributors (to current version): mdeters, dejan
 ** This file is part of the CVC4 prototype.
 ** Copyright (c) 2009, 2010, 2011 The Analysis of Computer Systems Group (ACSys)
 ** Courant Institute of Mathematical Sciences
 ** New York University
 ** See the file COPYING in the top-level source directory for licensing
 ** information.\endverbatim
 **
 ** \brief Rewriter for the theory of arithmetic.
 **
 ** ArithRewriter pre-rewrites arithmetic terms and atoms (expanding MINUS
 ** and UMINUS into PLUS/MULT form, replacing products that contain a literal
 ** 0 by 0, and turning GT/LT into negated LEQ/GEQ) and post-rewrites them
 ** into the Polynomial and Comparison forms.
 **/
20 #include "theory/theory.h"
21 #include "theory/arith/normal_form.h"
22 #include "theory/arith/arith_rewriter.h"
23 #include "theory/arith/arith_utilities.h"
33 bool isVariable(TNode t
){
34 return t
.getMetaKind() == kind::metakind::VARIABLE
;
37 bool ArithRewriter::isAtom(TNode n
) {
38 return arith::isRelationOperator(n
.getKind());
41 RewriteResponse
ArithRewriter::rewriteConstant(TNode t
){
42 Assert(t
.getMetaKind() == kind::metakind::CONSTANT
);
43 Assert(t
.getKind() == kind::CONST_RATIONAL
);
45 return RewriteResponse(REWRITE_DONE
, t
);
48 RewriteResponse
ArithRewriter::rewriteVariable(TNode t
){
49 Assert(isVariable(t
));
51 return RewriteResponse(REWRITE_DONE
, t
);
54 RewriteResponse
ArithRewriter::rewriteMinus(TNode t
, bool pre
){
55 Assert(t
.getKind()== kind::MINUS
);
60 Node zeroNode
= mkRationalNode(zero
);
61 return RewriteResponse(REWRITE_DONE
, zeroNode
);
63 Node noMinus
= makeSubtractionNode(t
[0],t
[1]);
64 return RewriteResponse(REWRITE_DONE
, noMinus
);
67 Polynomial minuend
= Polynomial::parsePolynomial(t
[0]);
68 Polynomial subtrahend
= Polynomial::parsePolynomial(t
[0]);
69 Polynomial diff
= minuend
- subtrahend
;
70 return RewriteResponse(REWRITE_DONE
, diff
.getNode());
74 RewriteResponse
ArithRewriter::rewriteUMinus(TNode t
, bool pre
){
75 Assert(t
.getKind()== kind::UMINUS
);
77 Node noUminus
= makeUnaryMinusNode(t
[0]);
79 return RewriteResponse(REWRITE_DONE
, noUminus
);
81 return RewriteResponse(REWRITE_AGAIN
, noUminus
);
84 RewriteResponse
ArithRewriter::preRewriteTerm(TNode t
){
85 if(t
.getMetaKind() == kind::metakind::CONSTANT
){
86 return rewriteConstant(t
);
87 }else if(isVariable(t
)){
88 return rewriteVariable(t
);
89 }else if(t
.getKind() == kind::MINUS
){
90 return rewriteMinus(t
, true);
91 }else if(t
.getKind() == kind::UMINUS
){
92 return rewriteUMinus(t
, true);
93 }else if(t
.getKind() == kind::DIVISION
){
94 return RewriteResponse(REWRITE_DONE
, t
); // wait until t[1] is rewritten
95 }else if(t
.getKind() == kind::PLUS
){
96 return preRewritePlus(t
);
97 }else if(t
.getKind() == kind::MULT
){
98 return preRewriteMult(t
);
99 }else if(t
.getKind() == kind::INTS_DIVISION
){
101 if(t
[1].getKind()== kind::CONST_RATIONAL
&& t
[1].getConst
<Rational
>() == intOne
){
102 return RewriteResponse(REWRITE_AGAIN
, t
[0]);
104 return RewriteResponse(REWRITE_DONE
, t
);
106 }else if(t
.getKind() == kind::INTS_MODULUS
){
108 if(t
[1].getKind()== kind::CONST_RATIONAL
&& t
[1].getConst
<Rational
>() == intOne
){
110 return RewriteResponse(REWRITE_AGAIN
, mkRationalNode(intZero
));
112 return RewriteResponse(REWRITE_DONE
, t
);
118 RewriteResponse
ArithRewriter::postRewriteTerm(TNode t
){
119 if(t
.getMetaKind() == kind::metakind::CONSTANT
){
120 return rewriteConstant(t
);
121 }else if(isVariable(t
)){
122 return rewriteVariable(t
);
123 }else if(t
.getKind() == kind::MINUS
){
124 return rewriteMinus(t
, false);
125 }else if(t
.getKind() == kind::UMINUS
){
126 return rewriteUMinus(t
, false);
127 }else if(t
.getKind() == kind::DIVISION
){
128 return rewriteDivByConstant(t
, false);
129 }else if(t
.getKind() == kind::PLUS
){
130 return postRewritePlus(t
);
131 }else if(t
.getKind() == kind::MULT
){
132 return postRewriteMult(t
);
133 }else if(t
.getKind() == kind::INTS_DIVISION
){
134 return RewriteResponse(REWRITE_DONE
, t
);
135 }else if(t
.getKind() == kind::INTS_MODULUS
){
136 return RewriteResponse(REWRITE_DONE
, t
);
142 RewriteResponse
ArithRewriter::preRewriteMult(TNode t
){
143 Assert(t
.getKind()== kind::MULT
);
145 // Rewrite multiplications with a 0 argument and to 0
148 for(TNode::iterator i
= t
.begin(); i
!= t
.end(); ++i
) {
149 if((*i
).getKind() == kind::CONST_RATIONAL
) {
150 if((*i
).getConst
<Rational
>() == qZero
) {
151 return RewriteResponse(REWRITE_DONE
, mkRationalNode(qZero
));
155 return RewriteResponse(REWRITE_DONE
, t
);
157 RewriteResponse
ArithRewriter::preRewritePlus(TNode t
){
158 Assert(t
.getKind()== kind::PLUS
);
160 return RewriteResponse(REWRITE_DONE
, t
);
163 RewriteResponse
ArithRewriter::postRewritePlus(TNode t
){
164 Assert(t
.getKind()== kind::PLUS
);
166 Polynomial res
= Polynomial::mkZero();
168 for(TNode::iterator i
= t
.begin(), end
= t
.end(); i
!= end
; ++i
){
170 Polynomial currPoly
= Polynomial::parsePolynomial(curr
);
172 res
= res
+ currPoly
;
175 return RewriteResponse(REWRITE_DONE
, res
.getNode());
178 RewriteResponse
ArithRewriter::postRewriteMult(TNode t
){
179 Assert(t
.getKind()== kind::MULT
);
181 Polynomial res
= Polynomial::mkOne();
183 for(TNode::iterator i
= t
.begin(), end
= t
.end(); i
!= end
; ++i
){
185 Polynomial currPoly
= Polynomial::parsePolynomial(curr
);
187 res
= res
* currPoly
;
190 return RewriteResponse(REWRITE_DONE
, res
.getNode());
// NOTE(review): unused, commented-out variant of postRewriteAtom() for atoms
// whose right-hand side is a constant; kept for reference only.
// RewriteResponse ArithRewriter::postRewriteAtomConstantRHS(TNode t){
//   TNode left = t[0];
//   TNode right = t[1];
//
//   Polynomial pLeft = Polynomial::parsePolynomial(left);
//
//   Comparison cmp = Comparison::mkComparison(t.getKind(), Polynomial::parsePolynomial(left), Constant(right));
//
//   Assert(cmp.isNormalForm());
//   return RewriteResponse(REWRITE_DONE, cmp.getNode());
// }
206 RewriteResponse
ArithRewriter::postRewriteAtom(TNode atom
){
208 TNode left
= atom
[0];
209 TNode right
= atom
[1];
211 Polynomial pleft
= Polynomial::parsePolynomial(left
);
212 Polynomial pright
= Polynomial::parsePolynomial(right
);
214 Comparison cmp
= Comparison::mkComparison(atom
.getKind(), pleft
, pright
);
215 Assert(cmp
.isNormalForm());
216 return RewriteResponse(REWRITE_DONE
, cmp
.getNode());
219 RewriteResponse
ArithRewriter::preRewriteAtom(TNode atom
){
220 Assert(isAtom(atom
));
222 NodeManager
* currNM
= NodeManager::currentNM();
224 if(atom
.getKind() == kind::EQUAL
) {
225 if(atom
[0] == atom
[1]) {
226 return RewriteResponse(REWRITE_DONE
, currNM
->mkConst(true));
228 }else if(atom
.getKind() == kind::GT
){
229 Node leq
= currNM
->mkNode(kind::LEQ
, atom
[0], atom
[1]);
230 return RewriteResponse(REWRITE_DONE
, currNM
->mkNode(kind::NOT
, leq
));
231 }else if(atom
.getKind() == kind::LT
){
232 Node geq
= currNM
->mkNode(kind::GEQ
, atom
[0], atom
[1]);
233 return RewriteResponse(REWRITE_DONE
, currNM
->mkNode(kind::NOT
, geq
));
236 return RewriteResponse(REWRITE_DONE
, atom
);
239 RewriteResponse
ArithRewriter::postRewrite(TNode t
){
241 RewriteResponse response
= postRewriteTerm(t
);
242 if(Debug
.isOn("arith::rewriter") && response
.status
== REWRITE_DONE
) {
243 Polynomial::parsePolynomial(response
.node
);
247 RewriteResponse response
= postRewriteAtom(t
);
248 if(Debug
.isOn("arith::rewriter") && response
.status
== REWRITE_DONE
) {
249 Comparison::parseNormalForm(response
.node
);
254 return RewriteResponse(REWRITE_DONE
, Node::null());
258 RewriteResponse
ArithRewriter::preRewrite(TNode t
){
260 return preRewriteTerm(t
);
262 return preRewriteAtom(t
);
265 return RewriteResponse(REWRITE_DONE
, Node::null());
269 Node
ArithRewriter::makeUnaryMinusNode(TNode n
){
270 Rational
qNegOne(-1);
271 return NodeManager::currentNM()->mkNode(kind::MULT
, mkRationalNode(qNegOne
),n
);
274 Node
ArithRewriter::makeSubtractionNode(TNode l
, TNode r
){
275 Node negR
= makeUnaryMinusNode(r
);
276 Node diff
= NodeManager::currentNM()->mkNode(kind::PLUS
, l
, negR
);
281 RewriteResponse
ArithRewriter::rewriteDivByConstant(TNode t
, bool pre
){
282 Assert(t
.getKind()== kind::DIVISION
);
286 Assert(right
.getKind()== kind::CONST_RATIONAL
);
289 const Rational
& den
= right
.getConst
<Rational
>();
291 Assert(den
!= Rational(0));
293 Rational div
= den
.inverse();
295 Node result
= mkRationalNode(div
);
297 Node mult
= NodeManager::currentNM()->mkNode(kind::MULT
,left
,result
);
299 return RewriteResponse(REWRITE_DONE
, mult
);
301 return RewriteResponse(REWRITE_AGAIN
, mult
);
305 }/* CVC4::theory::arith namespace */
306 }/* CVC4::theory namespace */
307 }/* CVC4 namespace */