* theory "tree" rewriting implemented and works
[cvc5.git] / test / unit / theory / theory_arith_white.h
1
2 #include <cxxtest/TestSuite.h>
3
4 #include "theory/theory.h"
5 #include "theory/arith/theory_arith.h"
6 #include "expr/node.h"
7 #include "expr/node_manager.h"
8 #include "context/context.h"
9 #include "util/rational.h"
10
11 #include "theory/theory_test_utils.h"
12
13 #include <vector>
14
15 using namespace CVC4;
16 using namespace CVC4::theory;
17 using namespace CVC4::theory::arith;
18 using namespace CVC4::expr;
19 using namespace CVC4::context;
20 using namespace CVC4::kind;
21
22 using namespace std;
23
24 class TheoryArithWhite : public CxxTest::TestSuite {
25
26 Context* d_ctxt;
27 NodeManager* d_nm;
28 NodeManagerScope* d_scope;
29
30 TestOutputChannel d_outputChannel;
31 Theory::Effort d_level;
32
33 TheoryArith* d_arith;
34
35 TypeNode* d_booleanType;
36 TypeNode* d_realType;
37
38 const Rational d_zero;
39 const Rational d_one;
40
41 std::set<Node>* preregistered;
42
43 bool debug;
44
45 public:
46
47 TheoryArithWhite() : d_level(Theory::FULL_EFFORT), d_zero(0), d_one(1), debug(false) {}
48
49 void setUp() {
50 d_ctxt = new Context;
51 d_nm = new NodeManager(d_ctxt);
52 d_scope = new NodeManagerScope(d_nm);
53 d_outputChannel.clear();
54 d_arith = new TheoryArith(d_ctxt, d_outputChannel);
55
56 preregistered = new std::set<Node>();
57
58 d_booleanType = new TypeNode(d_nm->booleanType());
59 d_realType = new TypeNode(d_nm->realType());
60
61 }
62
63 void tearDown() {
64 delete d_realType;
65 delete d_booleanType;
66
67 delete preregistered;
68
69 delete d_arith;
70 d_outputChannel.clear();
71 delete d_scope;
72 delete d_nm;
73 delete d_ctxt;
74 }
75
76 Node fakeTheoryEnginePreprocess(TNode inp){
77 Node rewrite = d_arith->rewrite(inp);
78
79 if(debug) cout << rewrite << inp << endl;
80
81 std::list<Node> toPreregister;
82
83 toPreregister.push_back(rewrite);
84 for(std::list<Node>::iterator i = toPreregister.begin(); i != toPreregister.end(); ++i){
85 Node n = *i;
86 preregistered->insert(n);
87
88 for(Node::iterator citer = n.begin(); citer != n.end(); ++citer){
89 Node c = *citer;
90 if(preregistered->find(c) == preregistered->end()){
91 toPreregister.push_back(c);
92 }
93 }
94 }
95 for(std::list<Node>::reverse_iterator i = toPreregister.rbegin(); i != toPreregister.rend(); ++i){
96 Node n = *i;
97 if(debug) cout << n.getId() << " "<< n << endl;
98 d_arith->preRegisterTerm(n);
99 }
100
101 return rewrite;
102 }
103
104 void testAssert() {
105 Node x = d_nm->mkVar(*d_realType);
106 Node c = d_nm->mkConst<Rational>(d_zero);
107
108 Node leq = d_nm->mkNode(LEQ, x, c);
109 Node rLeq = fakeTheoryEnginePreprocess(leq);
110
111 d_arith->assertFact(rLeq);
112
113 d_arith->check(d_level);
114
115 TS_ASSERT_EQUALS(d_outputChannel.getNumCalls(), 0u);
116 }
117
118 Node simulateSplit(TNode l, TNode r){
119 Node eq = d_nm->mkNode(EQUAL, l, r);
120 Node lt = d_nm->mkNode(LT, l, r);
121 Node gt = d_nm->mkNode(GT, l, r);
122
123 Node dis = d_nm->mkNode(OR, eq, lt, gt);
124 return dis;
125 }
126
127 void testAssertEqualityEagerSplit() {
128 Node x = d_nm->mkVar(*d_realType);
129 Node c = d_nm->mkConst<Rational>(d_zero);
130
131 Node eq = d_nm->mkNode(EQUAL, x, c);
132 Node expectedDisjunct = simulateSplit(x,c);
133
134 Node rEq = fakeTheoryEnginePreprocess(eq);
135
136 d_arith->assertFact(rEq);
137
138 d_arith->check(d_level);
139
140 TS_ASSERT_EQUALS(d_outputChannel.getNumCalls(), 1u);
141
142 TS_ASSERT_EQUALS(d_outputChannel.getIthNode(0), expectedDisjunct);
143 TS_ASSERT_EQUALS(d_outputChannel.getIthCallType(0), AUG_LEMMA);
144
145 }
146 void testLtRewrite() {
147 Node x = d_nm->mkVar(*d_realType);
148 Node c = d_nm->mkConst<Rational>(d_zero);
149
150 Node lt = d_nm->mkNode(LT, x, c);
151 Node geq = d_nm->mkNode(GEQ, x, c);
152 Node expectedRewrite = d_nm->mkNode(NOT, geq);
153
154 Node rewrite = d_arith->rewrite(lt);
155
156 TS_ASSERT_EQUALS(expectedRewrite, rewrite);
157 }
158
159 void testBasicConflict() {
160 Node x = d_nm->mkVar(*d_realType);
161 Node c = d_nm->mkConst<Rational>(d_zero);
162
163 Node eq = d_nm->mkNode(EQUAL, x, c);
164 Node lt = d_nm->mkNode(LT, x, c);
165 Node expectedDisjunct = simulateSplit(x,c);
166
167 Node rEq = fakeTheoryEnginePreprocess(eq);
168 Node rLt = fakeTheoryEnginePreprocess(lt);
169
170 d_arith->assertFact(rEq);
171 d_arith->assertFact(rLt);
172
173
174 d_arith->check(d_level);
175
176 TS_ASSERT_EQUALS(d_outputChannel.getNumCalls(), 2u);
177 TS_ASSERT_EQUALS(d_outputChannel.getIthNode(0), expectedDisjunct);
178 TS_ASSERT_EQUALS(d_outputChannel.getIthCallType(0), AUG_LEMMA);
179
180 TS_ASSERT_EQUALS(d_outputChannel.getIthCallType(1), CONFLICT);
181
182 Node expectedClonflict = d_nm->mkNode(AND, rEq, rLt);
183
184 TS_ASSERT_EQUALS(d_outputChannel.getIthNode(1), expectedClonflict);
185 }
186
187 void testBasicPropagate() {
188 Node x = d_nm->mkVar(*d_realType);
189 Node c = d_nm->mkConst<Rational>(d_zero);
190
191 Node eq = d_nm->mkNode(EQUAL, x, c);
192 Node lt = d_nm->mkNode(LT, x, c);
193 Node expectedDisjunct = simulateSplit(x,c);
194
195 Node rEq = fakeTheoryEnginePreprocess(eq);
196 Node rLt = fakeTheoryEnginePreprocess(lt);
197
198 d_arith->assertFact(rEq);
199
200
201 d_arith->check(d_level);
202 d_arith->propagate(d_level);
203
204 TS_ASSERT_EQUALS(d_outputChannel.getNumCalls(), 2u);
205 TS_ASSERT_EQUALS(d_outputChannel.getIthCallType(0), AUG_LEMMA);
206
207
208 Node expectedProp = d_nm->mkNode(GEQ, x, c);
209 TS_ASSERT_EQUALS(d_outputChannel.getIthCallType(1), PROPAGATE);
210 TS_ASSERT_EQUALS(d_outputChannel.getIthNode(1), expectedProp);
211
212 }
213 void testTPLt1() {
214 Node x = d_nm->mkVar(*d_realType);
215 Node c0 = d_nm->mkConst<Rational>(d_zero);
216 Node c1 = d_nm->mkConst<Rational>(d_one);
217
218 Node leq0 = d_nm->mkNode(LEQ, x, c0);
219 Node leq1 = d_nm->mkNode(LEQ, x, c1);
220 Node lt1 = d_nm->mkNode(LT, x, c1);
221
222 Node rLeq0 = fakeTheoryEnginePreprocess(leq0);
223 Node rLt1 = fakeTheoryEnginePreprocess(lt1);
224 Node rLeq1 = fakeTheoryEnginePreprocess(leq1);
225
226 d_arith->assertFact(rLt1);
227
228
229 d_arith->check(d_level);
230 d_arith->propagate(d_level);
231
232 #ifdef CVC4_ASSERTIONS
233 TS_ASSERT_THROWS( d_arith->explain(rLeq0, d_level), AssertionException );
234 TS_ASSERT_THROWS( d_arith->explain(rLt1, d_level), AssertionException );
235 #endif
236 d_arith->explain(rLeq1, d_level);
237
238 TS_ASSERT_EQUALS(d_outputChannel.getNumCalls(), 2u);
239 TS_ASSERT_EQUALS(d_outputChannel.getIthCallType(0), PROPAGATE);
240 TS_ASSERT_EQUALS(d_outputChannel.getIthCallType(1), EXPLANATION);
241 TS_ASSERT_EQUALS(d_outputChannel.getIthNode(0), leq1);
242 TS_ASSERT_EQUALS(d_outputChannel.getIthNode(1), rLt1);
243 }
244
245
246 void testTPLeq0() {
247 Node x = d_nm->mkVar(*d_realType);
248 Node c0 = d_nm->mkConst<Rational>(d_zero);
249 Node c1 = d_nm->mkConst<Rational>(d_one);
250
251 Node leq0 = d_nm->mkNode(LEQ, x, c0);
252 Node leq1 = d_nm->mkNode(LEQ, x, c1);
253 Node lt1 = d_nm->mkNode(LT, x, c1);
254
255 Node rLeq0 = fakeTheoryEnginePreprocess(leq0);
256 Node rLt1 = fakeTheoryEnginePreprocess(lt1);
257 Node rLeq1 = fakeTheoryEnginePreprocess(leq1);
258
259 d_arith->assertFact(rLeq0);
260
261
262 d_arith->check(d_level);
263 d_arith->propagate(d_level);
264
265
266 d_arith->explain(rLt1, d_level);
267 #ifdef CVC4_ASSERTIONS
268 TS_ASSERT_THROWS( d_arith->explain(rLeq0, d_level), AssertionException );
269 #endif
270 d_arith->explain(rLeq1, d_level);
271
272 TS_ASSERT_EQUALS(d_outputChannel.getNumCalls(), 4u);
273 TS_ASSERT_EQUALS(d_outputChannel.getIthCallType(0), PROPAGATE);
274 TS_ASSERT_EQUALS(d_outputChannel.getIthCallType(1), PROPAGATE);
275 TS_ASSERT_EQUALS(d_outputChannel.getIthCallType(2), EXPLANATION);
276 TS_ASSERT_EQUALS(d_outputChannel.getIthCallType(3), EXPLANATION);
277
278 TS_ASSERT_EQUALS(d_outputChannel.getIthNode(1), rLt1);
279 TS_ASSERT_EQUALS(d_outputChannel.getIthNode(0), rLeq1);
280
281
282 TS_ASSERT_EQUALS(d_outputChannel.getIthNode(2), rLeq0);
283 TS_ASSERT_EQUALS(d_outputChannel.getIthNode(3), rLeq0);
284 }
285 void testTPLeq1() {
286 Node x = d_nm->mkVar(*d_realType);
287 Node c0 = d_nm->mkConst<Rational>(d_zero);
288 Node c1 = d_nm->mkConst<Rational>(d_one);
289
290 Node leq0 = d_nm->mkNode(LEQ, x, c0);
291 Node leq1 = d_nm->mkNode(LEQ, x, c1);
292 Node lt1 = d_nm->mkNode(LT, x, c1);
293
294 Node rLeq0 = fakeTheoryEnginePreprocess(leq0);
295 Node rLt1 = fakeTheoryEnginePreprocess(lt1);
296 Node rLeq1 = fakeTheoryEnginePreprocess(leq1);
297
298 d_arith->assertFact(rLeq1);
299
300
301 d_arith->check(d_level);
302 d_arith->propagate(d_level);
303
304 #ifdef CVC4_ASSERTIONS
305 TS_ASSERT_THROWS( d_arith->explain(rLeq0, d_level), AssertionException );
306 TS_ASSERT_THROWS( d_arith->explain(rLeq1, d_level), AssertionException );
307 TS_ASSERT_THROWS( d_arith->explain(rLt1, d_level), AssertionException );
308 #endif
309
310 TS_ASSERT_EQUALS(d_outputChannel.getNumCalls(), 0u);
311 }
312 };