Merge pull request #28 from kbansal/sets
[cvc5.git] / src / theory / arith / arith_static_learner.cpp
1 /********************* */
2 /*! \file arith_static_learner.cpp
3 ** \verbatim
4 ** Original author: Tim King
5 ** Major contributors: Morgan Deters, Dejan Jovanovic
6 ** Minor contributors (to current version): none
7 ** This file is part of the CVC4 project.
8 ** Copyright (c) 2009-2013 New York University and The University of Iowa
9 ** See the file COPYING in the top-level source directory for licensing
10 ** information.\endverbatim
11 **
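12 ** \brief Static learner for arithmetic that derives bounds on ITE terms.
13 **
14 ** Recognizes min/max-shaped ITEs and ITEs whose branches have known constant bounds, and emits the implied inequalities as learned facts.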
15 ** \todo document this file
16 **/
17
18 #include "theory/rewriter.h"
19
20 #include "theory/arith/arith_utilities.h"
21 #include "theory/arith/arith_static_learner.h"
22 #include "theory/arith/options.h"
23
24 #include "theory/arith/normal_form.h"
25
26 #include "expr/expr.h"
27 #include "expr/convenience_node_builders.h"
28
29 #include <vector>
30
31 using namespace std;
32 using namespace CVC4::kind;
33
34 namespace CVC4 {
35 namespace theory {
36 namespace arith {
37
38
39 ArithStaticLearner::ArithStaticLearner(context::Context* userContext) :
40 d_minMap(userContext),
41 d_maxMap(userContext),
42 d_statistics()
43 {
44 }
45
46 ArithStaticLearner::~ArithStaticLearner(){
47 }
48
49 ArithStaticLearner::Statistics::Statistics():
50 d_iteMinMaxApplications("theory::arith::iteMinMaxApplications", 0),
51 d_iteConstantApplications("theory::arith::iteConstantApplications", 0)
52 {
53 StatisticsRegistry::registerStat(&d_iteMinMaxApplications);
54 StatisticsRegistry::registerStat(&d_iteConstantApplications);
55 }
56
57 ArithStaticLearner::Statistics::~Statistics(){
58 StatisticsRegistry::unregisterStat(&d_iteMinMaxApplications);
59 StatisticsRegistry::unregisterStat(&d_iteConstantApplications);
60 }
61
62 void ArithStaticLearner::staticLearning(TNode n, NodeBuilder<>& learned){
63
64 vector<TNode> workList;
65 workList.push_back(n);
66 TNodeSet processed;
67
68 //Contains an underapproximation of nodes that must hold.
69 TNodeSet defTrue;
70
71 defTrue.insert(n);
72
73 while(!workList.empty()) {
74 n = workList.back();
75
76 bool unprocessedChildren = false;
77 for(TNode::iterator i = n.begin(), iend = n.end(); i != iend; ++i) {
78 if(processed.find(*i) == processed.end()) {
79 // unprocessed child
80 workList.push_back(*i);
81 unprocessedChildren = true;
82 }
83 }
84 if(n.getKind() == AND && defTrue.find(n) != defTrue.end() ){
85 for(TNode::iterator i = n.begin(), iend = n.end(); i != iend; ++i) {
86 defTrue.insert(*i);
87 }
88 }
89
90 if(unprocessedChildren) {
91 continue;
92 }
93
94 workList.pop_back();
95 // has node n been processed in the meantime ?
96 if(processed.find(n) != processed.end()) {
97 continue;
98 }
99 processed.insert(n);
100
101 process(n,learned, defTrue);
102
103 }
104 }
105
106
107 void ArithStaticLearner::process(TNode n, NodeBuilder<>& learned, const TNodeSet& defTrue){
108 Debug("arith::static") << "===================== looking at " << n << endl;
109
110 switch(n.getKind()){
111 case ITE:
112 if(n.hasBoundVar()) {
113 // Unsafe with non-ground ITEs; do nothing
114 Debug("arith::static") << "(potentially) non-ground ITE, ignoring..." << endl;
115 break;
116 }
117
118 if(n[0].getKind() != EQUAL &&
119 isRelationOperator(n[0].getKind()) ){
120 iteMinMax(n, learned);
121 }
122
123 if((d_minMap.find(n[1]) != d_minMap.end() && d_minMap.find(n[2]) != d_minMap.end()) ||
124 (d_maxMap.find(n[1]) != d_maxMap.end() && d_maxMap.find(n[2]) != d_maxMap.end())) {
125 iteConstant(n, learned);
126 }
127 break;
128
129 case CONST_RATIONAL:
130 // Mark constants as minmax
131 d_minMap.insert(n, n.getConst<Rational>());
132 d_maxMap.insert(n, n.getConst<Rational>());
133 break;
134 default: // Do nothing
135 break;
136 }
137 }
138
139 void ArithStaticLearner::iteMinMax(TNode n, NodeBuilder<>& learned){
140 Assert(n.getKind() == kind::ITE);
141 Assert(n[0].getKind() != EQUAL);
142 Assert(isRelationOperator(n[0].getKind()));
143
144 TNode c = n[0];
145 Kind k = oldSimplifiedKind(c);
146 TNode t = n[1];
147 TNode e = n[2];
148 TNode cleft = (c.getKind() == NOT) ? c[0][0] : c[0];
149 TNode cright = (c.getKind() == NOT) ? c[0][1] : c[1];
150
151 if((t == cright) && (e == cleft)){
152 TNode tmp = t;
153 t = e;
154 e = tmp;
155 k = reverseRelationKind(k);
156 }
157 //(ite (< x y) x y)
158 //(ite (x < y) x y)
159 //(ite (x - y < 0) x y)
160 // ----------------
161 // (ite (x - y < -c) )
162
163 if(t == cleft && e == cright){
164 // t == cleft && e == cright
165 Assert( t == cleft );
166 Assert( e == cright );
167 switch(k){
168 case LT: // (ite (< x y) x y)
169 case LEQ: { // (ite (<= x y) x y)
170 Node nLeqX = NodeBuilder<2>(LEQ) << n << t;
171 Node nLeqY = NodeBuilder<2>(LEQ) << n << e;
172 Debug("arith::static") << n << "is a min =>" << nLeqX << nLeqY << endl;
173 learned << nLeqX << nLeqY;
174 ++(d_statistics.d_iteMinMaxApplications);
175 break;
176 }
177 case GT: // (ite (> x y) x y)
178 case GEQ: { // (ite (>= x y) x y)
179 Node nGeqX = NodeBuilder<2>(GEQ) << n << t;
180 Node nGeqY = NodeBuilder<2>(GEQ) << n << e;
181 Debug("arith::static") << n << "is a max =>" << nGeqX << nGeqY << endl;
182 learned << nGeqX << nGeqY;
183 ++(d_statistics.d_iteMinMaxApplications);
184 break;
185 }
186 default: Unreachable();
187 }
188 }
189 }
190
191 void ArithStaticLearner::iteConstant(TNode n, NodeBuilder<>& learned){
192 Assert(n.getKind() == ITE);
193
194 Debug("arith::static") << "iteConstant(" << n << ")" << endl;
195
196 if (d_minMap.find(n[1]) != d_minMap.end() && d_minMap.find(n[2]) != d_minMap.end()) {
197 const DeltaRational& first = d_minMap[n[1]];
198 const DeltaRational& second = d_minMap[n[2]];
199 DeltaRational min = std::min(first, second);
200 CDNodeToMinMaxMap::const_iterator minFind = d_minMap.find(n);
201 if (minFind == d_minMap.end() || (*minFind).second < min) {
202 d_minMap.insert(n, min);
203 Node nGeqMin;
204 if (min.getInfinitesimalPart() == 0) {
205 nGeqMin = NodeBuilder<2>(kind::GEQ) << n << mkRationalNode(min.getNoninfinitesimalPart());
206 } else {
207 nGeqMin = NodeBuilder<2>(kind::GT) << n << mkRationalNode(min.getNoninfinitesimalPart());
208 }
209 learned << nGeqMin;
210 Debug("arith::static") << n << " iteConstant" << nGeqMin << endl;
211 ++(d_statistics.d_iteConstantApplications);
212 }
213 }
214
215 if (d_maxMap.find(n[1]) != d_maxMap.end() && d_maxMap.find(n[2]) != d_maxMap.end()) {
216 const DeltaRational& first = d_maxMap[n[1]];
217 const DeltaRational& second = d_maxMap[n[2]];
218 DeltaRational max = std::max(first, second);
219 CDNodeToMinMaxMap::const_iterator maxFind = d_maxMap.find(n);
220 if (maxFind == d_maxMap.end() || (*maxFind).second > max) {
221 d_maxMap.insert(n, max);
222 Node nLeqMax;
223 if (max.getInfinitesimalPart() == 0) {
224 nLeqMax = NodeBuilder<2>(kind::LEQ) << n << mkRationalNode(max.getNoninfinitesimalPart());
225 } else {
226 nLeqMax = NodeBuilder<2>(kind::LT) << n << mkRationalNode(max.getNoninfinitesimalPart());
227 }
228 learned << nLeqMax;
229 Debug("arith::static") << n << " iteConstant" << nLeqMax << endl;
230 ++(d_statistics.d_iteConstantApplications);
231 }
232 }
233 }
234
235 std::set<Node> listToSet(TNode l){
236 std::set<Node> ret;
237 while(l.getKind() == OR){
238 Assert(l.getNumChildren() == 2);
239 ret.insert(l[0]);
240 l = l[1];
241 }
242 return ret;
243 }
244
245 void ArithStaticLearner::addBound(TNode n) {
246
247 CDNodeToMinMaxMap::const_iterator minFind = d_minMap.find(n[0]);
248 CDNodeToMinMaxMap::const_iterator maxFind = d_maxMap.find(n[0]);
249
250 Rational constant = n[1].getConst<Rational>();
251 DeltaRational bound = constant;
252
253 switch(Kind k = n.getKind()) {
254 case kind::LT:
255 bound = DeltaRational(constant, -1);
256 /* fall through */
257 case kind::LEQ:
258 if (maxFind == d_maxMap.end() || (*maxFind).second > bound) {
259 d_maxMap.insert(n[0], bound);
260 Debug("arith::static") << "adding bound " << n << endl;
261 }
262 break;
263 case kind::GT:
264 bound = DeltaRational(constant, 1);
265 /* fall through */
266 case kind::GEQ:
267 if (minFind == d_minMap.end() || (*minFind).second < bound) {
268 d_minMap.insert(n[0], bound);
269 Debug("arith::static") << "adding bound " << n << endl;
270 }
271 break;
272 default:
273 Unhandled(k);
274 break;
275 }
276 }
277
278 }/* CVC4::theory::arith namespace */
279 }/* CVC4::theory namespace */
280 }/* CVC4 namespace */