Fixed TestTheoryWhiteBagsRewriter.map failure (#7103)
[cvc5.git] / test / unit / theory / theory_bags_rewriter_white.cpp
1 /******************************************************************************
2 * Top contributors (to current version):
3 * Aina Niemetz, Mudathir Mohamed, Andrew Reynolds
4 *
5 * This file is part of the cvc5 project.
6 *
7 * Copyright (c) 2009-2021 by the authors listed in the file AUTHORS
8 * in the top-level source directory and their institutional affiliations.
9 * All rights reserved. See the file COPYING in the top-level source
10 * directory for licensing information.
11 * ****************************************************************************
12 *
13 * White box testing of bags rewriter
14 */
15
16 #include "expr/dtype.h"
17 #include "expr/emptybag.h"
18 #include "test_smt.h"
19 #include "theory/bags/bags_rewriter.h"
20 #include "theory/strings/type_enumerator.h"
21 #include "util/rational.h"
22 #include "util/string.h"
23
24 namespace cvc5 {
25
26 using namespace theory;
27 using namespace kind;
28 using namespace theory::bags;
29
30 namespace test {
31
32 typedef expr::Attribute<Node, Node> attribute;
33
34 class TestTheoryWhiteBagsRewriter : public TestSmt
35 {
36 protected:
37 void SetUp() override
38 {
39 TestSmt::SetUp();
40 d_rewriter.reset(new BagsRewriter(nullptr));
41 }
42
43 std::vector<Node> getNStrings(size_t n)
44 {
45 std::vector<Node> elements(n);
46 for (size_t i = 0; i < n; i++)
47 {
48 elements[i] =
49 d_skolemManager->mkDummySkolem("x", d_nodeManager->stringType());
50 }
51 return elements;
52 }
53
54 std::unique_ptr<BagsRewriter> d_rewriter;
55 };
56
57 TEST_F(TestTheoryWhiteBagsRewriter, empty_bag_normal_form)
58 {
59 Node emptybag = d_nodeManager->mkConst(EmptyBag(d_nodeManager->stringType()));
60 // empty bags are in normal form
61 ASSERT_TRUE(emptybag.isConst());
62 RewriteResponse response = d_rewriter->postRewrite(emptybag);
63 ASSERT_TRUE(emptybag == response.d_node && response.d_status == REWRITE_DONE);
64 }
65
66 TEST_F(TestTheoryWhiteBagsRewriter, bag_equality)
67 {
68 std::vector<Node> elements = getNStrings(2);
69 Node x = elements[0];
70 Node y = elements[1];
71 Node c = d_skolemManager->mkDummySkolem("c", d_nodeManager->integerType());
72 Node d = d_skolemManager->mkDummySkolem("d", d_nodeManager->integerType());
73 Node A = d_nodeManager->mkBag(d_nodeManager->stringType(), x, c);
74 Node B = d_nodeManager->mkBag(d_nodeManager->stringType(), y, d);
75 Node emptyBag = d_nodeManager->mkConst(
76 EmptyBag(d_nodeManager->mkBagType(d_nodeManager->stringType())));
77 Node emptyString = d_nodeManager->mkConst(String(""));
78 Node constantBag = d_nodeManager->mkBag(d_nodeManager->stringType(),
79 emptyString,
80 d_nodeManager->mkConst(Rational(1)));
81
82 // (= A A) = true where A is a bag
83 Node n1 = A.eqNode(A);
84 RewriteResponse response1 = d_rewriter->preRewrite(n1);
85 ASSERT_TRUE(response1.d_node == d_nodeManager->mkConst(true)
86 && response1.d_status == REWRITE_AGAIN_FULL);
87
88 // (= A B) = false if A and B are different bag constants
89 Node n2 = constantBag.eqNode(emptyBag);
90 RewriteResponse response2 = d_rewriter->postRewrite(n2);
91 ASSERT_TRUE(response2.d_node == d_nodeManager->mkConst(false)
92 && response2.d_status == REWRITE_AGAIN_FULL);
93
94 // (= B A) = (= A B) if A < B and at least one of A or B is not a constant
95 Node n3 = B.eqNode(A);
96 RewriteResponse response3 = d_rewriter->postRewrite(n3);
97 ASSERT_TRUE(response3.d_node == A.eqNode(B)
98 && response3.d_status == REWRITE_AGAIN_FULL);
99
100 // (= A B) = (= A B) no rewrite
101 Node n4 = A.eqNode(B);
102 RewriteResponse response4 = d_rewriter->postRewrite(n4);
103 ASSERT_TRUE(response4.d_node == n4 && response4.d_status == REWRITE_DONE);
104 }
105
106 TEST_F(TestTheoryWhiteBagsRewriter, mkBag_constant_element)
107 {
108 std::vector<Node> elements = getNStrings(1);
109 Node negative = d_nodeManager->mkBag(d_nodeManager->stringType(),
110 elements[0],
111 d_nodeManager->mkConst(Rational(-1)));
112 Node zero = d_nodeManager->mkBag(d_nodeManager->stringType(),
113 elements[0],
114 d_nodeManager->mkConst(Rational(0)));
115 Node positive = d_nodeManager->mkBag(d_nodeManager->stringType(),
116 elements[0],
117 d_nodeManager->mkConst(Rational(1)));
118 Node emptybag = d_nodeManager->mkConst(
119 EmptyBag(d_nodeManager->mkBagType(d_nodeManager->stringType())));
120 RewriteResponse negativeResponse = d_rewriter->postRewrite(negative);
121 RewriteResponse zeroResponse = d_rewriter->postRewrite(zero);
122 RewriteResponse positiveResponse = d_rewriter->postRewrite(positive);
123
124 // bags with non-positive multiplicity are rewritten as empty bags
125 ASSERT_TRUE(negativeResponse.d_status == REWRITE_AGAIN_FULL
126 && negativeResponse.d_node == emptybag);
127 ASSERT_TRUE(zeroResponse.d_status == REWRITE_AGAIN_FULL
128 && zeroResponse.d_node == emptybag);
129
130 // no change for positive
131 ASSERT_TRUE(positiveResponse.d_status == REWRITE_DONE
132 && positive == positiveResponse.d_node);
133 }
134
135 TEST_F(TestTheoryWhiteBagsRewriter, mkBag_variable_element)
136 {
137 Node skolem =
138 d_skolemManager->mkDummySkolem("x", d_nodeManager->stringType());
139 Node variable = d_nodeManager->mkBag(d_nodeManager->stringType(),
140 skolem,
141 d_nodeManager->mkConst(Rational(-1)));
142 Node negative = d_nodeManager->mkBag(d_nodeManager->stringType(),
143 skolem,
144 d_nodeManager->mkConst(Rational(-1)));
145 Node zero = d_nodeManager->mkBag(
146 d_nodeManager->stringType(), skolem, d_nodeManager->mkConst(Rational(0)));
147 Node positive = d_nodeManager->mkBag(
148 d_nodeManager->stringType(), skolem, d_nodeManager->mkConst(Rational(1)));
149 Node emptybag = d_nodeManager->mkConst(
150 EmptyBag(d_nodeManager->mkBagType(d_nodeManager->stringType())));
151 RewriteResponse negativeResponse = d_rewriter->postRewrite(negative);
152 RewriteResponse zeroResponse = d_rewriter->postRewrite(zero);
153 RewriteResponse positiveResponse = d_rewriter->postRewrite(positive);
154
155 // bags with non-positive multiplicity are rewritten as empty bags
156 ASSERT_TRUE(negativeResponse.d_status == REWRITE_AGAIN_FULL
157 && negativeResponse.d_node == emptybag);
158 ASSERT_TRUE(zeroResponse.d_status == REWRITE_AGAIN_FULL
159 && zeroResponse.d_node == emptybag);
160
161 // no change for positive
162 ASSERT_TRUE(positiveResponse.d_status == REWRITE_DONE
163 && positive == positiveResponse.d_node);
164 }
165
166 TEST_F(TestTheoryWhiteBagsRewriter, bag_count)
167 {
168 int n = 3;
169 Node skolem =
170 d_skolemManager->mkDummySkolem("x", d_nodeManager->stringType());
171 Node emptyBag = d_nodeManager->mkConst(
172 EmptyBag(d_nodeManager->mkBagType(skolem.getType())));
173 Node bag = d_nodeManager->mkBag(
174 d_nodeManager->stringType(), skolem, d_nodeManager->mkConst(Rational(n)));
175
176 // (bag.count x emptybag) = 0
177 Node n1 = d_nodeManager->mkNode(BAG_COUNT, skolem, emptyBag);
178 RewriteResponse response1 = d_rewriter->postRewrite(n1);
179 ASSERT_TRUE(response1.d_status == REWRITE_AGAIN_FULL
180 && response1.d_node == d_nodeManager->mkConst(Rational(0)));
181
182 // (bag.count x (mkBag x c) = c where c > 0 is a constant
183 Node n2 = d_nodeManager->mkNode(BAG_COUNT, skolem, bag);
184 RewriteResponse response2 = d_rewriter->postRewrite(n2);
185 ASSERT_TRUE(response2.d_status == REWRITE_AGAIN_FULL
186 && response2.d_node == d_nodeManager->mkConst(Rational(n)));
187 }
188
189 TEST_F(TestTheoryWhiteBagsRewriter, duplicate_removal)
190 {
191 Node x = d_skolemManager->mkDummySkolem("x", d_nodeManager->stringType());
192 Node bag = d_nodeManager->mkBag(
193 d_nodeManager->stringType(), x, d_nodeManager->mkConst(Rational(5)));
194
195 // (duplicate_removal (mkBag x n)) = (mkBag x 1)
196 Node n = d_nodeManager->mkNode(DUPLICATE_REMOVAL, bag);
197 RewriteResponse response = d_rewriter->postRewrite(n);
198 Node noDuplicate = d_nodeManager->mkBag(
199 d_nodeManager->stringType(), x, d_nodeManager->mkConst(Rational(1)));
200 ASSERT_TRUE(response.d_node == noDuplicate
201 && response.d_status == REWRITE_AGAIN_FULL);
202 }
203
204 TEST_F(TestTheoryWhiteBagsRewriter, union_max)
205 {
206 int n = 3;
207 std::vector<Node> elements = getNStrings(2);
208 Node emptyBag = d_nodeManager->mkConst(
209 EmptyBag(d_nodeManager->mkBagType(d_nodeManager->stringType())));
210 Node A = d_nodeManager->mkBag(d_nodeManager->stringType(),
211 elements[0],
212 d_nodeManager->mkConst(Rational(n)));
213 Node B = d_nodeManager->mkBag(d_nodeManager->stringType(),
214 elements[1],
215 d_nodeManager->mkConst(Rational(n + 1)));
216 Node unionMaxAB = d_nodeManager->mkNode(UNION_MAX, A, B);
217 Node unionMaxBA = d_nodeManager->mkNode(UNION_MAX, B, A);
218 Node unionDisjointAB = d_nodeManager->mkNode(UNION_DISJOINT, A, B);
219 Node unionDisjointBA = d_nodeManager->mkNode(UNION_DISJOINT, B, A);
220
221 // (union_max A emptybag) = A
222 Node unionMax1 = d_nodeManager->mkNode(UNION_MAX, A, emptyBag);
223 RewriteResponse response1 = d_rewriter->postRewrite(unionMax1);
224 ASSERT_TRUE(response1.d_node == A
225 && response1.d_status == REWRITE_AGAIN_FULL);
226
227 // (union_max emptybag A) = A
228 Node unionMax2 = d_nodeManager->mkNode(UNION_MAX, emptyBag, A);
229 RewriteResponse response2 = d_rewriter->postRewrite(unionMax2);
230 ASSERT_TRUE(response2.d_node == A
231 && response2.d_status == REWRITE_AGAIN_FULL);
232
233 // (union_max A A) = A
234 Node unionMax3 = d_nodeManager->mkNode(UNION_MAX, A, A);
235 RewriteResponse response3 = d_rewriter->postRewrite(unionMax3);
236 ASSERT_TRUE(response3.d_node == A
237 && response3.d_status == REWRITE_AGAIN_FULL);
238
239 // (union_max A (union_max A B)) = (union_max A B)
240 Node unionMax4 = d_nodeManager->mkNode(UNION_MAX, A, unionMaxAB);
241 RewriteResponse response4 = d_rewriter->postRewrite(unionMax4);
242 ASSERT_TRUE(response4.d_node == unionMaxAB
243 && response4.d_status == REWRITE_AGAIN_FULL);
244
245 // (union_max A (union_max B A)) = (union_max B A)
246 Node unionMax5 = d_nodeManager->mkNode(UNION_MAX, A, unionMaxBA);
247 RewriteResponse response5 = d_rewriter->postRewrite(unionMax5);
248 ASSERT_TRUE(response5.d_node == unionMaxBA
249 && response4.d_status == REWRITE_AGAIN_FULL);
250
251 // (union_max (union_max A B) A) = (union_max A B)
252 Node unionMax6 = d_nodeManager->mkNode(UNION_MAX, unionMaxAB, A);
253 RewriteResponse response6 = d_rewriter->postRewrite(unionMax6);
254 ASSERT_TRUE(response6.d_node == unionMaxAB
255 && response6.d_status == REWRITE_AGAIN_FULL);
256
257 // (union_max (union_max B A) A) = (union_max B A)
258 Node unionMax7 = d_nodeManager->mkNode(UNION_MAX, unionMaxBA, A);
259 RewriteResponse response7 = d_rewriter->postRewrite(unionMax7);
260 ASSERT_TRUE(response7.d_node == unionMaxBA
261 && response7.d_status == REWRITE_AGAIN_FULL);
262
263 // (union_max A (union_disjoint A B)) = (union_disjoint A B)
264 Node unionMax8 = d_nodeManager->mkNode(UNION_MAX, A, unionDisjointAB);
265 RewriteResponse response8 = d_rewriter->postRewrite(unionMax8);
266 ASSERT_TRUE(response8.d_node == unionDisjointAB
267 && response8.d_status == REWRITE_AGAIN_FULL);
268
269 // (union_max A (union_disjoint B A)) = (union_disjoint B A)
270 Node unionMax9 = d_nodeManager->mkNode(UNION_MAX, A, unionDisjointBA);
271 RewriteResponse response9 = d_rewriter->postRewrite(unionMax9);
272 ASSERT_TRUE(response9.d_node == unionDisjointBA
273 && response9.d_status == REWRITE_AGAIN_FULL);
274
275 // (union_max (union_disjoint A B) A) = (union_disjoint A B)
276 Node unionMax10 = d_nodeManager->mkNode(UNION_MAX, unionDisjointAB, A);
277 RewriteResponse response10 = d_rewriter->postRewrite(unionMax10);
278 ASSERT_TRUE(response10.d_node == unionDisjointAB
279 && response10.d_status == REWRITE_AGAIN_FULL);
280
281 // (union_max (union_disjoint B A) A) = (union_disjoint B A)
282 Node unionMax11 = d_nodeManager->mkNode(UNION_MAX, unionDisjointBA, A);
283 RewriteResponse response11 = d_rewriter->postRewrite(unionMax11);
284 ASSERT_TRUE(response11.d_node == unionDisjointBA
285 && response11.d_status == REWRITE_AGAIN_FULL);
286 }
287
288 TEST_F(TestTheoryWhiteBagsRewriter, union_disjoint)
289 {
290 int n = 3;
291 std::vector<Node> elements = getNStrings(3);
292 Node emptyBag = d_nodeManager->mkConst(
293 EmptyBag(d_nodeManager->mkBagType(d_nodeManager->stringType())));
294 Node A = d_nodeManager->mkBag(d_nodeManager->stringType(),
295 elements[0],
296 d_nodeManager->mkConst(Rational(n)));
297 Node B = d_nodeManager->mkBag(d_nodeManager->stringType(),
298 elements[1],
299 d_nodeManager->mkConst(Rational(n + 1)));
300 Node C = d_nodeManager->mkBag(d_nodeManager->stringType(),
301 elements[2],
302 d_nodeManager->mkConst(Rational(n + 2)));
303
304 Node unionDisjointAB = d_nodeManager->mkNode(UNION_DISJOINT, A, B);
305 Node unionDisjointBA = d_nodeManager->mkNode(UNION_DISJOINT, B, A);
306 Node unionMaxAB = d_nodeManager->mkNode(UNION_MAX, A, B);
307 Node unionMaxAC = d_nodeManager->mkNode(UNION_MAX, A, C);
308 Node unionMaxBA = d_nodeManager->mkNode(UNION_MAX, B, A);
309 Node intersectionAB = d_nodeManager->mkNode(INTERSECTION_MIN, A, B);
310 Node intersectionBA = d_nodeManager->mkNode(INTERSECTION_MIN, B, A);
311
312 // (union_disjoint A emptybag) = A
313 Node unionDisjoint1 = d_nodeManager->mkNode(UNION_DISJOINT, A, emptyBag);
314 RewriteResponse response1 = d_rewriter->postRewrite(unionDisjoint1);
315 ASSERT_TRUE(response1.d_node == A
316 && response1.d_status == REWRITE_AGAIN_FULL);
317
318 // (union_disjoint emptybag A) = A
319 Node unionDisjoint2 = d_nodeManager->mkNode(UNION_DISJOINT, emptyBag, A);
320 RewriteResponse response2 = d_rewriter->postRewrite(unionDisjoint2);
321 ASSERT_TRUE(response2.d_node == A
322 && response2.d_status == REWRITE_AGAIN_FULL);
323
324 // (union_disjoint (union_max A B) (intersection_min B A)) =
325 // (union_disjoint A B) // sum(a,b) = max(a,b) + min(a,b)
326 Node unionDisjoint3 =
327 d_nodeManager->mkNode(UNION_DISJOINT, unionMaxAB, intersectionBA);
328 RewriteResponse response3 = d_rewriter->postRewrite(unionDisjoint3);
329 ASSERT_TRUE(response3.d_node == unionDisjointAB
330 && response3.d_status == REWRITE_AGAIN_FULL);
331
332 // (union_disjoint (intersection_min B A)) (union_max A B) =
333 // (union_disjoint B A) // sum(a,b) = max(a,b) + min(a,b)
334 Node unionDisjoint4 =
335 d_nodeManager->mkNode(UNION_DISJOINT, unionMaxBA, intersectionBA);
336 RewriteResponse response4 = d_rewriter->postRewrite(unionDisjoint4);
337 ASSERT_TRUE(response4.d_node == unionDisjointBA
338 && response4.d_status == REWRITE_AGAIN_FULL);
339
340 // (union_disjoint (intersection_min B A)) (union_max A B) =
341 // (union_disjoint B A) // sum(a,b) = max(a,b) + min(a,b)
342 Node unionDisjoint5 =
343 d_nodeManager->mkNode(UNION_DISJOINT, unionMaxAC, intersectionAB);
344 RewriteResponse response5 = d_rewriter->postRewrite(unionDisjoint5);
345 ASSERT_TRUE(response5.d_node == unionDisjoint5
346 && response5.d_status == REWRITE_DONE);
347 }
348
349 TEST_F(TestTheoryWhiteBagsRewriter, intersection_min)
350 {
351 int n = 3;
352 std::vector<Node> elements = getNStrings(2);
353 Node emptyBag = d_nodeManager->mkConst(
354 EmptyBag(d_nodeManager->mkBagType(d_nodeManager->stringType())));
355 Node A = d_nodeManager->mkBag(d_nodeManager->stringType(),
356 elements[0],
357 d_nodeManager->mkConst(Rational(n)));
358 Node B = d_nodeManager->mkBag(d_nodeManager->stringType(),
359 elements[1],
360 d_nodeManager->mkConst(Rational(n + 1)));
361 Node unionMaxAB = d_nodeManager->mkNode(UNION_MAX, A, B);
362 Node unionMaxBA = d_nodeManager->mkNode(UNION_MAX, B, A);
363 Node unionDisjointAB = d_nodeManager->mkNode(UNION_DISJOINT, A, B);
364 Node unionDisjointBA = d_nodeManager->mkNode(UNION_DISJOINT, B, A);
365
366 // (intersection_min A emptybag) = emptyBag
367 Node n1 = d_nodeManager->mkNode(INTERSECTION_MIN, A, emptyBag);
368 RewriteResponse response1 = d_rewriter->postRewrite(n1);
369 ASSERT_TRUE(response1.d_node == emptyBag
370 && response1.d_status == REWRITE_AGAIN_FULL);
371
372 // (intersection_min emptybag A) = emptyBag
373 Node n2 = d_nodeManager->mkNode(INTERSECTION_MIN, emptyBag, A);
374 RewriteResponse response2 = d_rewriter->postRewrite(n2);
375 ASSERT_TRUE(response2.d_node == emptyBag
376 && response2.d_status == REWRITE_AGAIN_FULL);
377
378 // (intersection_min A A) = A
379 Node n3 = d_nodeManager->mkNode(INTERSECTION_MIN, A, A);
380 RewriteResponse response3 = d_rewriter->postRewrite(n3);
381 ASSERT_TRUE(response3.d_node == A
382 && response3.d_status == REWRITE_AGAIN_FULL);
383
384 // (intersection_min A (union_max A B) = A
385 Node n4 = d_nodeManager->mkNode(INTERSECTION_MIN, A, unionMaxAB);
386 RewriteResponse response4 = d_rewriter->postRewrite(n4);
387 ASSERT_TRUE(response4.d_node == A
388 && response4.d_status == REWRITE_AGAIN_FULL);
389
390 // (intersection_min A (union_max B A) = A
391 Node n5 = d_nodeManager->mkNode(INTERSECTION_MIN, A, unionMaxBA);
392 RewriteResponse response5 = d_rewriter->postRewrite(n5);
393 ASSERT_TRUE(response5.d_node == A
394 && response4.d_status == REWRITE_AGAIN_FULL);
395
396 // (intersection_min (union_max A B) A) = A
397 Node n6 = d_nodeManager->mkNode(INTERSECTION_MIN, unionMaxAB, A);
398 RewriteResponse response6 = d_rewriter->postRewrite(n6);
399 ASSERT_TRUE(response6.d_node == A
400 && response6.d_status == REWRITE_AGAIN_FULL);
401
402 // (intersection_min (union_max B A) A) = A
403 Node n7 = d_nodeManager->mkNode(INTERSECTION_MIN, unionMaxBA, A);
404 RewriteResponse response7 = d_rewriter->postRewrite(n7);
405 ASSERT_TRUE(response7.d_node == A
406 && response7.d_status == REWRITE_AGAIN_FULL);
407
408 // (intersection_min A (union_disjoint A B) = A
409 Node n8 = d_nodeManager->mkNode(INTERSECTION_MIN, A, unionDisjointAB);
410 RewriteResponse response8 = d_rewriter->postRewrite(n8);
411 ASSERT_TRUE(response8.d_node == A
412 && response8.d_status == REWRITE_AGAIN_FULL);
413
414 // (intersection_min A (union_disjoint B A) = A
415 Node n9 = d_nodeManager->mkNode(INTERSECTION_MIN, A, unionDisjointBA);
416 RewriteResponse response9 = d_rewriter->postRewrite(n9);
417 ASSERT_TRUE(response9.d_node == A
418 && response9.d_status == REWRITE_AGAIN_FULL);
419
420 // (intersection_min (union_disjoint A B) A) = A
421 Node n10 = d_nodeManager->mkNode(INTERSECTION_MIN, unionDisjointAB, A);
422 RewriteResponse response10 = d_rewriter->postRewrite(n10);
423 ASSERT_TRUE(response10.d_node == A
424 && response10.d_status == REWRITE_AGAIN_FULL);
425
426 // (intersection_min (union_disjoint B A) A) = A
427 Node n11 = d_nodeManager->mkNode(INTERSECTION_MIN, unionDisjointBA, A);
428 RewriteResponse response11 = d_rewriter->postRewrite(n11);
429 ASSERT_TRUE(response11.d_node == A
430 && response11.d_status == REWRITE_AGAIN_FULL);
431 }
432
433 TEST_F(TestTheoryWhiteBagsRewriter, difference_subtract)
434 {
435 int n = 3;
436 std::vector<Node> elements = getNStrings(2);
437 Node emptyBag = d_nodeManager->mkConst(
438 EmptyBag(d_nodeManager->mkBagType(d_nodeManager->stringType())));
439 Node A = d_nodeManager->mkBag(d_nodeManager->stringType(),
440 elements[0],
441 d_nodeManager->mkConst(Rational(n)));
442 Node B = d_nodeManager->mkBag(d_nodeManager->stringType(),
443 elements[1],
444 d_nodeManager->mkConst(Rational(n + 1)));
445 Node unionMaxAB = d_nodeManager->mkNode(UNION_MAX, A, B);
446 Node unionMaxBA = d_nodeManager->mkNode(UNION_MAX, B, A);
447 Node unionDisjointAB = d_nodeManager->mkNode(UNION_DISJOINT, A, B);
448 Node unionDisjointBA = d_nodeManager->mkNode(UNION_DISJOINT, B, A);
449 Node intersectionAB = d_nodeManager->mkNode(INTERSECTION_MIN, A, B);
450 Node intersectionBA = d_nodeManager->mkNode(INTERSECTION_MIN, B, A);
451
452 // (difference_subtract A emptybag) = A
453 Node n1 = d_nodeManager->mkNode(DIFFERENCE_SUBTRACT, A, emptyBag);
454 RewriteResponse response1 = d_rewriter->postRewrite(n1);
455 ASSERT_TRUE(response1.d_node == A
456 && response1.d_status == REWRITE_AGAIN_FULL);
457
458 // (difference_subtract emptybag A) = emptyBag
459 Node n2 = d_nodeManager->mkNode(DIFFERENCE_SUBTRACT, emptyBag, A);
460 RewriteResponse response2 = d_rewriter->postRewrite(n2);
461 ASSERT_TRUE(response2.d_node == emptyBag
462 && response2.d_status == REWRITE_AGAIN_FULL);
463
464 // (difference_subtract A A) = emptybag
465 Node n3 = d_nodeManager->mkNode(DIFFERENCE_SUBTRACT, A, A);
466 RewriteResponse response3 = d_rewriter->postRewrite(n3);
467 ASSERT_TRUE(response3.d_node == emptyBag
468 && response3.d_status == REWRITE_AGAIN_FULL);
469
470 // (difference_subtract (union_disjoint A B) A) = B
471 Node n4 = d_nodeManager->mkNode(DIFFERENCE_SUBTRACT, unionDisjointAB, A);
472 RewriteResponse response4 = d_rewriter->postRewrite(n4);
473 ASSERT_TRUE(response4.d_node == B
474 && response4.d_status == REWRITE_AGAIN_FULL);
475
476 // (difference_subtract (union_disjoint B A) A) = B
477 Node n5 = d_nodeManager->mkNode(DIFFERENCE_SUBTRACT, unionDisjointBA, A);
478 RewriteResponse response5 = d_rewriter->postRewrite(n5);
479 ASSERT_TRUE(response5.d_node == B
480 && response4.d_status == REWRITE_AGAIN_FULL);
481
482 // (difference_subtract A (union_disjoint A B)) = emptybag
483 Node n6 = d_nodeManager->mkNode(DIFFERENCE_SUBTRACT, A, unionDisjointAB);
484 RewriteResponse response6 = d_rewriter->postRewrite(n6);
485 ASSERT_TRUE(response6.d_node == emptyBag
486 && response6.d_status == REWRITE_AGAIN_FULL);
487
488 // (difference_subtract A (union_disjoint B A)) = emptybag
489 Node n7 = d_nodeManager->mkNode(DIFFERENCE_SUBTRACT, A, unionDisjointBA);
490 RewriteResponse response7 = d_rewriter->postRewrite(n7);
491 ASSERT_TRUE(response7.d_node == emptyBag
492 && response7.d_status == REWRITE_AGAIN_FULL);
493
494 // (difference_subtract A (union_max A B)) = emptybag
495 Node n8 = d_nodeManager->mkNode(DIFFERENCE_SUBTRACT, A, unionMaxAB);
496 RewriteResponse response8 = d_rewriter->postRewrite(n8);
497 ASSERT_TRUE(response8.d_node == emptyBag
498 && response8.d_status == REWRITE_AGAIN_FULL);
499
500 // (difference_subtract A (union_max B A)) = emptybag
501 Node n9 = d_nodeManager->mkNode(DIFFERENCE_SUBTRACT, A, unionMaxBA);
502 RewriteResponse response9 = d_rewriter->postRewrite(n9);
503 ASSERT_TRUE(response9.d_node == emptyBag
504 && response9.d_status == REWRITE_AGAIN_FULL);
505
506 // (difference_subtract (intersection_min A B) A) = emptybag
507 Node n10 = d_nodeManager->mkNode(DIFFERENCE_SUBTRACT, intersectionAB, A);
508 RewriteResponse response10 = d_rewriter->postRewrite(n10);
509 ASSERT_TRUE(response10.d_node == emptyBag
510 && response10.d_status == REWRITE_AGAIN_FULL);
511
512 // (difference_subtract (intersection_min B A) A) = emptybag
513 Node n11 = d_nodeManager->mkNode(DIFFERENCE_SUBTRACT, intersectionBA, A);
514 RewriteResponse response11 = d_rewriter->postRewrite(n11);
515 ASSERT_TRUE(response11.d_node == emptyBag
516 && response11.d_status == REWRITE_AGAIN_FULL);
517 }
518
519 TEST_F(TestTheoryWhiteBagsRewriter, difference_remove)
520 {
521 int n = 3;
522 std::vector<Node> elements = getNStrings(2);
523 Node emptyBag = d_nodeManager->mkConst(
524 EmptyBag(d_nodeManager->mkBagType(d_nodeManager->stringType())));
525 Node A = d_nodeManager->mkBag(d_nodeManager->stringType(),
526 elements[0],
527 d_nodeManager->mkConst(Rational(n)));
528 Node B = d_nodeManager->mkBag(d_nodeManager->stringType(),
529 elements[1],
530 d_nodeManager->mkConst(Rational(n + 1)));
531 Node unionMaxAB = d_nodeManager->mkNode(UNION_MAX, A, B);
532 Node unionMaxBA = d_nodeManager->mkNode(UNION_MAX, B, A);
533 Node unionDisjointAB = d_nodeManager->mkNode(UNION_DISJOINT, A, B);
534 Node unionDisjointBA = d_nodeManager->mkNode(UNION_DISJOINT, B, A);
535 Node intersectionAB = d_nodeManager->mkNode(INTERSECTION_MIN, A, B);
536 Node intersectionBA = d_nodeManager->mkNode(INTERSECTION_MIN, B, A);
537
538 // (difference_remove A emptybag) = A
539 Node n1 = d_nodeManager->mkNode(DIFFERENCE_REMOVE, A, emptyBag);
540 RewriteResponse response1 = d_rewriter->postRewrite(n1);
541 ASSERT_TRUE(response1.d_node == A
542 && response1.d_status == REWRITE_AGAIN_FULL);
543
544 // (difference_remove emptybag A) = emptyBag
545 Node n2 = d_nodeManager->mkNode(DIFFERENCE_REMOVE, emptyBag, A);
546 RewriteResponse response2 = d_rewriter->postRewrite(n2);
547 ASSERT_TRUE(response2.d_node == emptyBag
548 && response2.d_status == REWRITE_AGAIN_FULL);
549
550 // (difference_remove A A) = emptybag
551 Node n3 = d_nodeManager->mkNode(DIFFERENCE_REMOVE, A, A);
552 RewriteResponse response3 = d_rewriter->postRewrite(n3);
553 ASSERT_TRUE(response3.d_node == emptyBag
554 && response3.d_status == REWRITE_AGAIN_FULL);
555
556 // (difference_remove A (union_disjoint A B)) = emptybag
557 Node n6 = d_nodeManager->mkNode(DIFFERENCE_REMOVE, A, unionDisjointAB);
558 RewriteResponse response6 = d_rewriter->postRewrite(n6);
559 ASSERT_TRUE(response6.d_node == emptyBag
560 && response6.d_status == REWRITE_AGAIN_FULL);
561
562 // (difference_remove A (union_disjoint B A)) = emptybag
563 Node n7 = d_nodeManager->mkNode(DIFFERENCE_REMOVE, A, unionDisjointBA);
564 RewriteResponse response7 = d_rewriter->postRewrite(n7);
565 ASSERT_TRUE(response7.d_node == emptyBag
566 && response7.d_status == REWRITE_AGAIN_FULL);
567
568 // (difference_remove A (union_max A B)) = emptybag
569 Node n8 = d_nodeManager->mkNode(DIFFERENCE_REMOVE, A, unionMaxAB);
570 RewriteResponse response8 = d_rewriter->postRewrite(n8);
571 ASSERT_TRUE(response8.d_node == emptyBag
572 && response8.d_status == REWRITE_AGAIN_FULL);
573
574 // (difference_remove A (union_max B A)) = emptybag
575 Node n9 = d_nodeManager->mkNode(DIFFERENCE_REMOVE, A, unionMaxBA);
576 RewriteResponse response9 = d_rewriter->postRewrite(n9);
577 ASSERT_TRUE(response9.d_node == emptyBag
578 && response9.d_status == REWRITE_AGAIN_FULL);
579
580 // (difference_remove (intersection_min A B) A) = emptybag
581 Node n10 = d_nodeManager->mkNode(DIFFERENCE_REMOVE, intersectionAB, A);
582 RewriteResponse response10 = d_rewriter->postRewrite(n10);
583 ASSERT_TRUE(response10.d_node == emptyBag
584 && response10.d_status == REWRITE_AGAIN_FULL);
585
586 // (difference_remove (intersection_min B A) A) = emptybag
587 Node n11 = d_nodeManager->mkNode(DIFFERENCE_REMOVE, intersectionBA, A);
588 RewriteResponse response11 = d_rewriter->postRewrite(n11);
589 ASSERT_TRUE(response11.d_node == emptyBag
590 && response11.d_status == REWRITE_AGAIN_FULL);
591 }
592
593 TEST_F(TestTheoryWhiteBagsRewriter, choose)
594 {
595 Node x = d_skolemManager->mkDummySkolem("x", d_nodeManager->stringType());
596 Node c = d_nodeManager->mkConst(Rational(3));
597 Node bag = d_nodeManager->mkBag(d_nodeManager->stringType(), x, c);
598
599 // (bag.choose (mkBag x c)) = x where c is a constant > 0
600 Node n1 = d_nodeManager->mkNode(BAG_CHOOSE, bag);
601 RewriteResponse response1 = d_rewriter->postRewrite(n1);
602 ASSERT_TRUE(response1.d_node == x
603 && response1.d_status == REWRITE_AGAIN_FULL);
604 }
605
606 TEST_F(TestTheoryWhiteBagsRewriter, bag_card)
607 {
608 Node x = d_skolemManager->mkDummySkolem("x", d_nodeManager->stringType());
609 Node emptyBag = d_nodeManager->mkConst(
610 EmptyBag(d_nodeManager->mkBagType(d_nodeManager->stringType())));
611 Node zero = d_nodeManager->mkConst(Rational(0));
612 Node c = d_nodeManager->mkConst(Rational(3));
613 Node bag = d_nodeManager->mkBag(d_nodeManager->stringType(), x, c);
614 std::vector<Node> elements = getNStrings(2);
615 Node A = d_nodeManager->mkBag(d_nodeManager->stringType(),
616 elements[0],
617 d_nodeManager->mkConst(Rational(4)));
618 Node B = d_nodeManager->mkBag(d_nodeManager->stringType(),
619 elements[1],
620 d_nodeManager->mkConst(Rational(5)));
621 Node unionDisjointAB = d_nodeManager->mkNode(UNION_DISJOINT, A, B);
622
623 // TODO(projects#223): enable this test after implementing bags normal form
624 // // (bag.card emptybag) = 0
625 // Node n1 = d_nodeManager->mkNode(BAG_CARD, emptyBag);
626 // RewriteResponse response1 = d_rewriter->postRewrite(n1);
627 // ASSERT_TRUE(response1.d_node == zero && response1.d_status ==
628 // REWRITE_AGAIN_FULL);
629
630 // (bag.card (mkBag x c)) = c where c is a constant > 0
631 Node n2 = d_nodeManager->mkNode(BAG_CARD, bag);
632 RewriteResponse response2 = d_rewriter->postRewrite(n2);
633 ASSERT_TRUE(response2.d_node == c
634 && response2.d_status == REWRITE_AGAIN_FULL);
635
636 // (bag.card (union-disjoint A B)) = (+ (bag.card A) (bag.card B))
637 Node n3 = d_nodeManager->mkNode(BAG_CARD, unionDisjointAB);
638 Node cardA = d_nodeManager->mkNode(BAG_CARD, A);
639 Node cardB = d_nodeManager->mkNode(BAG_CARD, B);
640 Node plus = d_nodeManager->mkNode(PLUS, cardA, cardB);
641 RewriteResponse response3 = d_rewriter->postRewrite(n3);
642 ASSERT_TRUE(response3.d_node == plus
643 && response3.d_status == REWRITE_AGAIN_FULL);
644 }
645
646 TEST_F(TestTheoryWhiteBagsRewriter, is_singleton)
647 {
648 Node emptybag = d_nodeManager->mkConst(
649 EmptyBag(d_nodeManager->mkBagType(d_nodeManager->stringType())));
650 Node x = d_skolemManager->mkDummySkolem("x", d_nodeManager->stringType());
651 Node c = d_skolemManager->mkDummySkolem("c", d_nodeManager->integerType());
652 Node bag = d_nodeManager->mkBag(d_nodeManager->stringType(), x, c);
653
654 // TODO(projects#223): complete this function
655 // (bag.is_singleton emptybag) = false
656 // Node n1 = d_nodeManager->mkNode(BAG_IS_SINGLETON, emptybag);
657 // RewriteResponse response1 = d_rewriter->postRewrite(n1);
658 // ASSERT_TRUE(response1.d_node == d_nodeManager->mkConst(false)
659 // && response1.d_status == REWRITE_AGAIN_FULL);
660
661 // (bag.is_singleton (mkBag x c) = (c == 1)
662 Node n2 = d_nodeManager->mkNode(BAG_IS_SINGLETON, bag);
663 RewriteResponse response2 = d_rewriter->postRewrite(n2);
664 Node one = d_nodeManager->mkConst(Rational(1));
665 Node equal = c.eqNode(one);
666 ASSERT_TRUE(response2.d_node == equal
667 && response2.d_status == REWRITE_AGAIN_FULL);
668 }
669
670 TEST_F(TestTheoryWhiteBagsRewriter, from_set)
671 {
672 Node x = d_skolemManager->mkDummySkolem("x", d_nodeManager->stringType());
673 Node singleton = d_nodeManager->mkSingleton(d_nodeManager->stringType(), x);
674
675 // (bag.from_set (singleton (singleton_op Int) x)) = (mkBag x 1)
676 Node n = d_nodeManager->mkNode(BAG_FROM_SET, singleton);
677 RewriteResponse response = d_rewriter->postRewrite(n);
678 Node one = d_nodeManager->mkConst(Rational(1));
679 Node bag = d_nodeManager->mkBag(d_nodeManager->stringType(), x, one);
680 ASSERT_TRUE(response.d_node == bag
681 && response.d_status == REWRITE_AGAIN_FULL);
682 }
683
684 TEST_F(TestTheoryWhiteBagsRewriter, to_set)
685 {
686 Node x = d_skolemManager->mkDummySkolem("x", d_nodeManager->stringType());
687 Node bag = d_nodeManager->mkBag(
688 d_nodeManager->stringType(), x, d_nodeManager->mkConst(Rational(5)));
689
690 // (bag.to_set (mkBag x n)) = (singleton (singleton_op T) x)
691 Node n = d_nodeManager->mkNode(BAG_TO_SET, bag);
692 RewriteResponse response = d_rewriter->postRewrite(n);
693 Node singleton = d_nodeManager->mkSingleton(d_nodeManager->stringType(), x);
694 ASSERT_TRUE(response.d_node == singleton
695 && response.d_status == REWRITE_AGAIN_FULL);
696 }
697
698 TEST_F(TestTheoryWhiteBagsRewriter, map)
699 {
700 TypeNode bagStringType =
701 d_nodeManager->mkBagType(d_nodeManager->stringType());
702 Node emptybagString = d_nodeManager->mkConst(EmptyBag(bagStringType));
703
704 Node one = d_nodeManager->mkConst(Rational(1));
705 Node x = d_nodeManager->mkBoundVar("x", d_nodeManager->integerType());
706 std::vector<Node> args;
707 args.push_back(x);
708 Node bound = d_nodeManager->mkNode(kind::BOUND_VAR_LIST, args);
709 Node lambda = d_nodeManager->mkNode(LAMBDA, bound, one);
710
711 // (bag.map (lambda ((x U)) t) emptybag) = emptybag
712 Node n1 = d_nodeManager->mkNode(BAG_MAP, lambda, emptybagString);
713 RewriteResponse response1 = d_rewriter->postRewrite(n1);
714 TypeNode bagIntType = d_nodeManager->mkBagType(d_nodeManager->integerType());
715 Node emptybagInteger = d_nodeManager->mkConst(EmptyBag(bagIntType));
716 ASSERT_TRUE(response1.d_node == emptybagInteger
717 && response1.d_status == REWRITE_AGAIN_FULL);
718
719 std::vector<Node> elements = getNStrings(2);
720 Node a = d_nodeManager->mkConst(String("a"));
721 Node b = d_nodeManager->mkConst(String("b"));
722 Node A = d_nodeManager->mkBag(d_nodeManager->stringType(),
723 a,
724 d_nodeManager->mkConst(Rational(3)));
725 Node B = d_nodeManager->mkBag(d_nodeManager->stringType(),
726 b,
727 d_nodeManager->mkConst(Rational(4)));
728 Node unionDisjointAB = d_nodeManager->mkNode(UNION_DISJOINT, A, B);
729
730 ASSERT_TRUE(unionDisjointAB.isConst());
731 // - (bag.map (lambda ((x Int)) 1) (union_disjoint (bag "a" 3) (bag "b" 4))) =
732 // (bag 1 7))
733 Node n2 = d_nodeManager->mkNode(BAG_MAP, lambda, unionDisjointAB);
734
735 std::cout << n2 << std::endl;
736
737 Node rewritten = Rewriter:: rewrite(n2);
738 std::cout << rewritten << std::endl;
739
740 Node bag = d_nodeManager->mkBag(d_nodeManager->integerType(),
741 one, d_nodeManager->mkConst(Rational(7)));
742 ASSERT_TRUE(rewritten == bag);
743 }
744
745 } // namespace test
746 } // namespace cvc5