Fix soundness issue for bags with negative multiplicity (#7539)
[cvc5.git] / test / unit / theory / theory_bags_rewriter_white.cpp
1 /******************************************************************************
2 * Top contributors (to current version):
3 * Aina Niemetz, Mudathir Mohamed, Andrew Reynolds
4 *
5 * This file is part of the cvc5 project.
6 *
7 * Copyright (c) 2009-2021 by the authors listed in the file AUTHORS
8 * in the top-level source directory and their institutional affiliations.
9 * All rights reserved. See the file COPYING in the top-level source
10 * directory for licensing information.
11 * ****************************************************************************
12 *
13 * White box testing of bags rewriter
14 */
15
16 #include "expr/dtype.h"
17 #include "expr/emptybag.h"
18 #include "test_smt.h"
19 #include "theory/bags/bags_rewriter.h"
20 #include "theory/strings/type_enumerator.h"
21 #include "util/rational.h"
22 #include "util/string.h"
23
24 namespace cvc5 {
25
26 using namespace theory;
27 using namespace kind;
28 using namespace theory::bags;
29
30 namespace test {
31
32 typedef expr::Attribute<Node, Node> attribute;
33
34 class TestTheoryWhiteBagsRewriter : public TestSmt
35 {
36 protected:
37 void SetUp() override
38 {
39 TestSmt::SetUp();
40 d_rewriter.reset(new BagsRewriter(nullptr));
41 }
42
43 std::vector<Node> getNStrings(size_t n)
44 {
45 std::vector<Node> elements(n);
46 for (size_t i = 0; i < n; i++)
47 {
48 elements[i] =
49 d_skolemManager->mkDummySkolem("x", d_nodeManager->stringType());
50 }
51 return elements;
52 }
53
54 std::unique_ptr<BagsRewriter> d_rewriter;
55 };
56
57 TEST_F(TestTheoryWhiteBagsRewriter, empty_bag_normal_form)
58 {
59 Node emptybag = d_nodeManager->mkConst(EmptyBag(d_nodeManager->stringType()));
60 // empty bags are in normal form
61 ASSERT_TRUE(emptybag.isConst());
62 RewriteResponse response = d_rewriter->postRewrite(emptybag);
63 ASSERT_TRUE(emptybag == response.d_node && response.d_status == REWRITE_DONE);
64 }
65
66 TEST_F(TestTheoryWhiteBagsRewriter, bag_equality)
67 {
68 std::vector<Node> elements = getNStrings(2);
69 Node x = elements[0];
70 Node y = elements[1];
71 Node c = d_skolemManager->mkDummySkolem("c", d_nodeManager->integerType());
72 Node d = d_skolemManager->mkDummySkolem("d", d_nodeManager->integerType());
73 Node A = d_nodeManager->mkBag(d_nodeManager->stringType(), x, c);
74 Node B = d_nodeManager->mkBag(d_nodeManager->stringType(), y, d);
75 Node emptyBag = d_nodeManager->mkConst(
76 EmptyBag(d_nodeManager->mkBagType(d_nodeManager->stringType())));
77 Node emptyString = d_nodeManager->mkConst(String(""));
78 Node constantBag = d_nodeManager->mkBag(d_nodeManager->stringType(),
79 emptyString,
80 d_nodeManager->mkConst(Rational(1)));
81
82 // (= A A) = true where A is a bag
83 Node n1 = A.eqNode(A);
84 RewriteResponse response1 = d_rewriter->preRewrite(n1);
85 ASSERT_TRUE(response1.d_node == d_nodeManager->mkConst(true)
86 && response1.d_status == REWRITE_AGAIN_FULL);
87
88 // (= A B) = false if A and B are different bag constants
89 Node n2 = constantBag.eqNode(emptyBag);
90 RewriteResponse response2 = d_rewriter->postRewrite(n2);
91 ASSERT_TRUE(response2.d_node == d_nodeManager->mkConst(false)
92 && response2.d_status == REWRITE_AGAIN_FULL);
93
94 // (= B A) = (= A B) if A < B and at least one of A or B is not a constant
95 Node n3 = B.eqNode(A);
96 RewriteResponse response3 = d_rewriter->postRewrite(n3);
97 ASSERT_TRUE(response3.d_node == A.eqNode(B)
98 && response3.d_status == REWRITE_AGAIN_FULL);
99
100 // (= A B) = (= A B) no rewrite
101 Node n4 = A.eqNode(B);
102 RewriteResponse response4 = d_rewriter->postRewrite(n4);
103 ASSERT_TRUE(response4.d_node == n4 && response4.d_status == REWRITE_DONE);
104 }
105
106 TEST_F(TestTheoryWhiteBagsRewriter, mkBag_constant_element)
107 {
108 std::vector<Node> elements = getNStrings(1);
109 Node negative = d_nodeManager->mkBag(d_nodeManager->stringType(),
110 elements[0],
111 d_nodeManager->mkConst(Rational(-1)));
112 Node zero = d_nodeManager->mkBag(d_nodeManager->stringType(),
113 elements[0],
114 d_nodeManager->mkConst(Rational(0)));
115 Node positive = d_nodeManager->mkBag(d_nodeManager->stringType(),
116 elements[0],
117 d_nodeManager->mkConst(Rational(1)));
118 Node emptybag = d_nodeManager->mkConst(
119 EmptyBag(d_nodeManager->mkBagType(d_nodeManager->stringType())));
120 RewriteResponse negativeResponse = d_rewriter->postRewrite(negative);
121 RewriteResponse zeroResponse = d_rewriter->postRewrite(zero);
122 RewriteResponse positiveResponse = d_rewriter->postRewrite(positive);
123
124 // bags with non-positive multiplicity are rewritten as empty bags
125 ASSERT_TRUE(negativeResponse.d_status == REWRITE_AGAIN_FULL
126 && negativeResponse.d_node == emptybag);
127 ASSERT_TRUE(zeroResponse.d_status == REWRITE_AGAIN_FULL
128 && zeroResponse.d_node == emptybag);
129
130 // no change for positive
131 ASSERT_TRUE(positiveResponse.d_status == REWRITE_DONE
132 && positive == positiveResponse.d_node);
133 }
134
135 TEST_F(TestTheoryWhiteBagsRewriter, mkBag_variable_element)
136 {
137 Node skolem =
138 d_skolemManager->mkDummySkolem("x", d_nodeManager->stringType());
139 Node variable = d_nodeManager->mkBag(d_nodeManager->stringType(),
140 skolem,
141 d_nodeManager->mkConst(Rational(-1)));
142 Node negative = d_nodeManager->mkBag(d_nodeManager->stringType(),
143 skolem,
144 d_nodeManager->mkConst(Rational(-1)));
145 Node zero = d_nodeManager->mkBag(
146 d_nodeManager->stringType(), skolem, d_nodeManager->mkConst(Rational(0)));
147 Node positive = d_nodeManager->mkBag(
148 d_nodeManager->stringType(), skolem, d_nodeManager->mkConst(Rational(1)));
149 Node emptybag = d_nodeManager->mkConst(
150 EmptyBag(d_nodeManager->mkBagType(d_nodeManager->stringType())));
151 RewriteResponse negativeResponse = d_rewriter->postRewrite(negative);
152 RewriteResponse zeroResponse = d_rewriter->postRewrite(zero);
153 RewriteResponse positiveResponse = d_rewriter->postRewrite(positive);
154
155 // bags with non-positive multiplicity are rewritten as empty bags
156 ASSERT_TRUE(negativeResponse.d_status == REWRITE_AGAIN_FULL
157 && negativeResponse.d_node == emptybag);
158 ASSERT_TRUE(zeroResponse.d_status == REWRITE_AGAIN_FULL
159 && zeroResponse.d_node == emptybag);
160
161 // no change for positive
162 ASSERT_TRUE(positiveResponse.d_status == REWRITE_DONE
163 && positive == positiveResponse.d_node);
164 }
165
166 TEST_F(TestTheoryWhiteBagsRewriter, bag_count)
167 {
168 Node zero = d_nodeManager->mkConst(Rational(0));
169 Node one = d_nodeManager->mkConst(Rational(1));
170 Node three = d_nodeManager->mkConst(Rational(3));
171 Node skolem =
172 d_skolemManager->mkDummySkolem("x", d_nodeManager->stringType());
173 Node emptyBag = d_nodeManager->mkConst(
174 EmptyBag(d_nodeManager->mkBagType(skolem.getType())));
175
176 // (bag.count x emptybag) = 0
177 Node n1 = d_nodeManager->mkNode(BAG_COUNT, skolem, emptyBag);
178 RewriteResponse response1 = d_rewriter->postRewrite(n1);
179 ASSERT_TRUE(response1.d_status == REWRITE_AGAIN_FULL
180 && response1.d_node == zero);
181
182 // (bag.count x (mkBag x c) = (ite (>= c 1) c 0)
183 Node bag = d_nodeManager->mkBag(d_nodeManager->stringType(), skolem, three);
184 Node n2 = d_nodeManager->mkNode(BAG_COUNT, skolem, bag);
185 RewriteResponse response2 = d_rewriter->postRewrite(n2);
186
187 Node geq = d_nodeManager->mkNode(GEQ, three, one);
188 Node ite = d_nodeManager->mkNode(ITE, geq, three, zero);
189 ASSERT_TRUE(response2.d_status == REWRITE_AGAIN_FULL
190 && response2.d_node == ite);
191 }
192
193 TEST_F(TestTheoryWhiteBagsRewriter, duplicate_removal)
194 {
195 Node x = d_skolemManager->mkDummySkolem("x", d_nodeManager->stringType());
196 Node bag = d_nodeManager->mkBag(
197 d_nodeManager->stringType(), x, d_nodeManager->mkConst(Rational(5)));
198
199 // (duplicate_removal (mkBag x n)) = (mkBag x 1)
200 Node n = d_nodeManager->mkNode(DUPLICATE_REMOVAL, bag);
201 RewriteResponse response = d_rewriter->postRewrite(n);
202 Node noDuplicate = d_nodeManager->mkBag(
203 d_nodeManager->stringType(), x, d_nodeManager->mkConst(Rational(1)));
204 ASSERT_TRUE(response.d_node == noDuplicate
205 && response.d_status == REWRITE_AGAIN_FULL);
206 }
207
208 TEST_F(TestTheoryWhiteBagsRewriter, union_max)
209 {
210 int n = 3;
211 std::vector<Node> elements = getNStrings(2);
212 Node emptyBag = d_nodeManager->mkConst(
213 EmptyBag(d_nodeManager->mkBagType(d_nodeManager->stringType())));
214 Node A = d_nodeManager->mkBag(d_nodeManager->stringType(),
215 elements[0],
216 d_nodeManager->mkConst(Rational(n)));
217 Node B = d_nodeManager->mkBag(d_nodeManager->stringType(),
218 elements[1],
219 d_nodeManager->mkConst(Rational(n + 1)));
220 Node unionMaxAB = d_nodeManager->mkNode(UNION_MAX, A, B);
221 Node unionMaxBA = d_nodeManager->mkNode(UNION_MAX, B, A);
222 Node unionDisjointAB = d_nodeManager->mkNode(UNION_DISJOINT, A, B);
223 Node unionDisjointBA = d_nodeManager->mkNode(UNION_DISJOINT, B, A);
224
225 // (union_max A emptybag) = A
226 Node unionMax1 = d_nodeManager->mkNode(UNION_MAX, A, emptyBag);
227 RewriteResponse response1 = d_rewriter->postRewrite(unionMax1);
228 ASSERT_TRUE(response1.d_node == A
229 && response1.d_status == REWRITE_AGAIN_FULL);
230
231 // (union_max emptybag A) = A
232 Node unionMax2 = d_nodeManager->mkNode(UNION_MAX, emptyBag, A);
233 RewriteResponse response2 = d_rewriter->postRewrite(unionMax2);
234 ASSERT_TRUE(response2.d_node == A
235 && response2.d_status == REWRITE_AGAIN_FULL);
236
237 // (union_max A A) = A
238 Node unionMax3 = d_nodeManager->mkNode(UNION_MAX, A, A);
239 RewriteResponse response3 = d_rewriter->postRewrite(unionMax3);
240 ASSERT_TRUE(response3.d_node == A
241 && response3.d_status == REWRITE_AGAIN_FULL);
242
243 // (union_max A (union_max A B)) = (union_max A B)
244 Node unionMax4 = d_nodeManager->mkNode(UNION_MAX, A, unionMaxAB);
245 RewriteResponse response4 = d_rewriter->postRewrite(unionMax4);
246 ASSERT_TRUE(response4.d_node == unionMaxAB
247 && response4.d_status == REWRITE_AGAIN_FULL);
248
249 // (union_max A (union_max B A)) = (union_max B A)
250 Node unionMax5 = d_nodeManager->mkNode(UNION_MAX, A, unionMaxBA);
251 RewriteResponse response5 = d_rewriter->postRewrite(unionMax5);
252 ASSERT_TRUE(response5.d_node == unionMaxBA
253 && response4.d_status == REWRITE_AGAIN_FULL);
254
255 // (union_max (union_max A B) A) = (union_max A B)
256 Node unionMax6 = d_nodeManager->mkNode(UNION_MAX, unionMaxAB, A);
257 RewriteResponse response6 = d_rewriter->postRewrite(unionMax6);
258 ASSERT_TRUE(response6.d_node == unionMaxAB
259 && response6.d_status == REWRITE_AGAIN_FULL);
260
261 // (union_max (union_max B A) A) = (union_max B A)
262 Node unionMax7 = d_nodeManager->mkNode(UNION_MAX, unionMaxBA, A);
263 RewriteResponse response7 = d_rewriter->postRewrite(unionMax7);
264 ASSERT_TRUE(response7.d_node == unionMaxBA
265 && response7.d_status == REWRITE_AGAIN_FULL);
266
267 // (union_max A (union_disjoint A B)) = (union_disjoint A B)
268 Node unionMax8 = d_nodeManager->mkNode(UNION_MAX, A, unionDisjointAB);
269 RewriteResponse response8 = d_rewriter->postRewrite(unionMax8);
270 ASSERT_TRUE(response8.d_node == unionDisjointAB
271 && response8.d_status == REWRITE_AGAIN_FULL);
272
273 // (union_max A (union_disjoint B A)) = (union_disjoint B A)
274 Node unionMax9 = d_nodeManager->mkNode(UNION_MAX, A, unionDisjointBA);
275 RewriteResponse response9 = d_rewriter->postRewrite(unionMax9);
276 ASSERT_TRUE(response9.d_node == unionDisjointBA
277 && response9.d_status == REWRITE_AGAIN_FULL);
278
279 // (union_max (union_disjoint A B) A) = (union_disjoint A B)
280 Node unionMax10 = d_nodeManager->mkNode(UNION_MAX, unionDisjointAB, A);
281 RewriteResponse response10 = d_rewriter->postRewrite(unionMax10);
282 ASSERT_TRUE(response10.d_node == unionDisjointAB
283 && response10.d_status == REWRITE_AGAIN_FULL);
284
285 // (union_max (union_disjoint B A) A) = (union_disjoint B A)
286 Node unionMax11 = d_nodeManager->mkNode(UNION_MAX, unionDisjointBA, A);
287 RewriteResponse response11 = d_rewriter->postRewrite(unionMax11);
288 ASSERT_TRUE(response11.d_node == unionDisjointBA
289 && response11.d_status == REWRITE_AGAIN_FULL);
290 }
291
292 TEST_F(TestTheoryWhiteBagsRewriter, union_disjoint)
293 {
294 int n = 3;
295 std::vector<Node> elements = getNStrings(3);
296 Node emptyBag = d_nodeManager->mkConst(
297 EmptyBag(d_nodeManager->mkBagType(d_nodeManager->stringType())));
298 Node A = d_nodeManager->mkBag(d_nodeManager->stringType(),
299 elements[0],
300 d_nodeManager->mkConst(Rational(n)));
301 Node B = d_nodeManager->mkBag(d_nodeManager->stringType(),
302 elements[1],
303 d_nodeManager->mkConst(Rational(n + 1)));
304 Node C = d_nodeManager->mkBag(d_nodeManager->stringType(),
305 elements[2],
306 d_nodeManager->mkConst(Rational(n + 2)));
307
308 Node unionDisjointAB = d_nodeManager->mkNode(UNION_DISJOINT, A, B);
309 Node unionDisjointBA = d_nodeManager->mkNode(UNION_DISJOINT, B, A);
310 Node unionMaxAB = d_nodeManager->mkNode(UNION_MAX, A, B);
311 Node unionMaxAC = d_nodeManager->mkNode(UNION_MAX, A, C);
312 Node unionMaxBA = d_nodeManager->mkNode(UNION_MAX, B, A);
313 Node intersectionAB = d_nodeManager->mkNode(INTERSECTION_MIN, A, B);
314 Node intersectionBA = d_nodeManager->mkNode(INTERSECTION_MIN, B, A);
315
316 // (union_disjoint A emptybag) = A
317 Node unionDisjoint1 = d_nodeManager->mkNode(UNION_DISJOINT, A, emptyBag);
318 RewriteResponse response1 = d_rewriter->postRewrite(unionDisjoint1);
319 ASSERT_TRUE(response1.d_node == A
320 && response1.d_status == REWRITE_AGAIN_FULL);
321
322 // (union_disjoint emptybag A) = A
323 Node unionDisjoint2 = d_nodeManager->mkNode(UNION_DISJOINT, emptyBag, A);
324 RewriteResponse response2 = d_rewriter->postRewrite(unionDisjoint2);
325 ASSERT_TRUE(response2.d_node == A
326 && response2.d_status == REWRITE_AGAIN_FULL);
327
328 // (union_disjoint (union_max A B) (intersection_min B A)) =
329 // (union_disjoint A B) // sum(a,b) = max(a,b) + min(a,b)
330 Node unionDisjoint3 =
331 d_nodeManager->mkNode(UNION_DISJOINT, unionMaxAB, intersectionBA);
332 RewriteResponse response3 = d_rewriter->postRewrite(unionDisjoint3);
333 ASSERT_TRUE(response3.d_node == unionDisjointAB
334 && response3.d_status == REWRITE_AGAIN_FULL);
335
336 // (union_disjoint (intersection_min B A)) (union_max A B) =
337 // (union_disjoint B A) // sum(a,b) = max(a,b) + min(a,b)
338 Node unionDisjoint4 =
339 d_nodeManager->mkNode(UNION_DISJOINT, unionMaxBA, intersectionBA);
340 RewriteResponse response4 = d_rewriter->postRewrite(unionDisjoint4);
341 ASSERT_TRUE(response4.d_node == unionDisjointBA
342 && response4.d_status == REWRITE_AGAIN_FULL);
343
344 // (union_disjoint (intersection_min B A)) (union_max A B) =
345 // (union_disjoint B A) // sum(a,b) = max(a,b) + min(a,b)
346 Node unionDisjoint5 =
347 d_nodeManager->mkNode(UNION_DISJOINT, unionMaxAC, intersectionAB);
348 RewriteResponse response5 = d_rewriter->postRewrite(unionDisjoint5);
349 ASSERT_TRUE(response5.d_node == unionDisjoint5
350 && response5.d_status == REWRITE_DONE);
351 }
352
353 TEST_F(TestTheoryWhiteBagsRewriter, intersection_min)
354 {
355 int n = 3;
356 std::vector<Node> elements = getNStrings(2);
357 Node emptyBag = d_nodeManager->mkConst(
358 EmptyBag(d_nodeManager->mkBagType(d_nodeManager->stringType())));
359 Node A = d_nodeManager->mkBag(d_nodeManager->stringType(),
360 elements[0],
361 d_nodeManager->mkConst(Rational(n)));
362 Node B = d_nodeManager->mkBag(d_nodeManager->stringType(),
363 elements[1],
364 d_nodeManager->mkConst(Rational(n + 1)));
365 Node unionMaxAB = d_nodeManager->mkNode(UNION_MAX, A, B);
366 Node unionMaxBA = d_nodeManager->mkNode(UNION_MAX, B, A);
367 Node unionDisjointAB = d_nodeManager->mkNode(UNION_DISJOINT, A, B);
368 Node unionDisjointBA = d_nodeManager->mkNode(UNION_DISJOINT, B, A);
369
370 // (intersection_min A emptybag) = emptyBag
371 Node n1 = d_nodeManager->mkNode(INTERSECTION_MIN, A, emptyBag);
372 RewriteResponse response1 = d_rewriter->postRewrite(n1);
373 ASSERT_TRUE(response1.d_node == emptyBag
374 && response1.d_status == REWRITE_AGAIN_FULL);
375
376 // (intersection_min emptybag A) = emptyBag
377 Node n2 = d_nodeManager->mkNode(INTERSECTION_MIN, emptyBag, A);
378 RewriteResponse response2 = d_rewriter->postRewrite(n2);
379 ASSERT_TRUE(response2.d_node == emptyBag
380 && response2.d_status == REWRITE_AGAIN_FULL);
381
382 // (intersection_min A A) = A
383 Node n3 = d_nodeManager->mkNode(INTERSECTION_MIN, A, A);
384 RewriteResponse response3 = d_rewriter->postRewrite(n3);
385 ASSERT_TRUE(response3.d_node == A
386 && response3.d_status == REWRITE_AGAIN_FULL);
387
388 // (intersection_min A (union_max A B) = A
389 Node n4 = d_nodeManager->mkNode(INTERSECTION_MIN, A, unionMaxAB);
390 RewriteResponse response4 = d_rewriter->postRewrite(n4);
391 ASSERT_TRUE(response4.d_node == A
392 && response4.d_status == REWRITE_AGAIN_FULL);
393
394 // (intersection_min A (union_max B A) = A
395 Node n5 = d_nodeManager->mkNode(INTERSECTION_MIN, A, unionMaxBA);
396 RewriteResponse response5 = d_rewriter->postRewrite(n5);
397 ASSERT_TRUE(response5.d_node == A
398 && response4.d_status == REWRITE_AGAIN_FULL);
399
400 // (intersection_min (union_max A B) A) = A
401 Node n6 = d_nodeManager->mkNode(INTERSECTION_MIN, unionMaxAB, A);
402 RewriteResponse response6 = d_rewriter->postRewrite(n6);
403 ASSERT_TRUE(response6.d_node == A
404 && response6.d_status == REWRITE_AGAIN_FULL);
405
406 // (intersection_min (union_max B A) A) = A
407 Node n7 = d_nodeManager->mkNode(INTERSECTION_MIN, unionMaxBA, A);
408 RewriteResponse response7 = d_rewriter->postRewrite(n7);
409 ASSERT_TRUE(response7.d_node == A
410 && response7.d_status == REWRITE_AGAIN_FULL);
411
412 // (intersection_min A (union_disjoint A B) = A
413 Node n8 = d_nodeManager->mkNode(INTERSECTION_MIN, A, unionDisjointAB);
414 RewriteResponse response8 = d_rewriter->postRewrite(n8);
415 ASSERT_TRUE(response8.d_node == A
416 && response8.d_status == REWRITE_AGAIN_FULL);
417
418 // (intersection_min A (union_disjoint B A) = A
419 Node n9 = d_nodeManager->mkNode(INTERSECTION_MIN, A, unionDisjointBA);
420 RewriteResponse response9 = d_rewriter->postRewrite(n9);
421 ASSERT_TRUE(response9.d_node == A
422 && response9.d_status == REWRITE_AGAIN_FULL);
423
424 // (intersection_min (union_disjoint A B) A) = A
425 Node n10 = d_nodeManager->mkNode(INTERSECTION_MIN, unionDisjointAB, A);
426 RewriteResponse response10 = d_rewriter->postRewrite(n10);
427 ASSERT_TRUE(response10.d_node == A
428 && response10.d_status == REWRITE_AGAIN_FULL);
429
430 // (intersection_min (union_disjoint B A) A) = A
431 Node n11 = d_nodeManager->mkNode(INTERSECTION_MIN, unionDisjointBA, A);
432 RewriteResponse response11 = d_rewriter->postRewrite(n11);
433 ASSERT_TRUE(response11.d_node == A
434 && response11.d_status == REWRITE_AGAIN_FULL);
435 }
436
437 TEST_F(TestTheoryWhiteBagsRewriter, difference_subtract)
438 {
439 int n = 3;
440 std::vector<Node> elements = getNStrings(2);
441 Node emptyBag = d_nodeManager->mkConst(
442 EmptyBag(d_nodeManager->mkBagType(d_nodeManager->stringType())));
443 Node A = d_nodeManager->mkBag(d_nodeManager->stringType(),
444 elements[0],
445 d_nodeManager->mkConst(Rational(n)));
446 Node B = d_nodeManager->mkBag(d_nodeManager->stringType(),
447 elements[1],
448 d_nodeManager->mkConst(Rational(n + 1)));
449 Node unionMaxAB = d_nodeManager->mkNode(UNION_MAX, A, B);
450 Node unionMaxBA = d_nodeManager->mkNode(UNION_MAX, B, A);
451 Node unionDisjointAB = d_nodeManager->mkNode(UNION_DISJOINT, A, B);
452 Node unionDisjointBA = d_nodeManager->mkNode(UNION_DISJOINT, B, A);
453 Node intersectionAB = d_nodeManager->mkNode(INTERSECTION_MIN, A, B);
454 Node intersectionBA = d_nodeManager->mkNode(INTERSECTION_MIN, B, A);
455
456 // (difference_subtract A emptybag) = A
457 Node n1 = d_nodeManager->mkNode(DIFFERENCE_SUBTRACT, A, emptyBag);
458 RewriteResponse response1 = d_rewriter->postRewrite(n1);
459 ASSERT_TRUE(response1.d_node == A
460 && response1.d_status == REWRITE_AGAIN_FULL);
461
462 // (difference_subtract emptybag A) = emptyBag
463 Node n2 = d_nodeManager->mkNode(DIFFERENCE_SUBTRACT, emptyBag, A);
464 RewriteResponse response2 = d_rewriter->postRewrite(n2);
465 ASSERT_TRUE(response2.d_node == emptyBag
466 && response2.d_status == REWRITE_AGAIN_FULL);
467
468 // (difference_subtract A A) = emptybag
469 Node n3 = d_nodeManager->mkNode(DIFFERENCE_SUBTRACT, A, A);
470 RewriteResponse response3 = d_rewriter->postRewrite(n3);
471 ASSERT_TRUE(response3.d_node == emptyBag
472 && response3.d_status == REWRITE_AGAIN_FULL);
473
474 // (difference_subtract (union_disjoint A B) A) = B
475 Node n4 = d_nodeManager->mkNode(DIFFERENCE_SUBTRACT, unionDisjointAB, A);
476 RewriteResponse response4 = d_rewriter->postRewrite(n4);
477 ASSERT_TRUE(response4.d_node == B
478 && response4.d_status == REWRITE_AGAIN_FULL);
479
480 // (difference_subtract (union_disjoint B A) A) = B
481 Node n5 = d_nodeManager->mkNode(DIFFERENCE_SUBTRACT, unionDisjointBA, A);
482 RewriteResponse response5 = d_rewriter->postRewrite(n5);
483 ASSERT_TRUE(response5.d_node == B
484 && response4.d_status == REWRITE_AGAIN_FULL);
485
486 // (difference_subtract A (union_disjoint A B)) = emptybag
487 Node n6 = d_nodeManager->mkNode(DIFFERENCE_SUBTRACT, A, unionDisjointAB);
488 RewriteResponse response6 = d_rewriter->postRewrite(n6);
489 ASSERT_TRUE(response6.d_node == emptyBag
490 && response6.d_status == REWRITE_AGAIN_FULL);
491
492 // (difference_subtract A (union_disjoint B A)) = emptybag
493 Node n7 = d_nodeManager->mkNode(DIFFERENCE_SUBTRACT, A, unionDisjointBA);
494 RewriteResponse response7 = d_rewriter->postRewrite(n7);
495 ASSERT_TRUE(response7.d_node == emptyBag
496 && response7.d_status == REWRITE_AGAIN_FULL);
497
498 // (difference_subtract A (union_max A B)) = emptybag
499 Node n8 = d_nodeManager->mkNode(DIFFERENCE_SUBTRACT, A, unionMaxAB);
500 RewriteResponse response8 = d_rewriter->postRewrite(n8);
501 ASSERT_TRUE(response8.d_node == emptyBag
502 && response8.d_status == REWRITE_AGAIN_FULL);
503
504 // (difference_subtract A (union_max B A)) = emptybag
505 Node n9 = d_nodeManager->mkNode(DIFFERENCE_SUBTRACT, A, unionMaxBA);
506 RewriteResponse response9 = d_rewriter->postRewrite(n9);
507 ASSERT_TRUE(response9.d_node == emptyBag
508 && response9.d_status == REWRITE_AGAIN_FULL);
509
510 // (difference_subtract (intersection_min A B) A) = emptybag
511 Node n10 = d_nodeManager->mkNode(DIFFERENCE_SUBTRACT, intersectionAB, A);
512 RewriteResponse response10 = d_rewriter->postRewrite(n10);
513 ASSERT_TRUE(response10.d_node == emptyBag
514 && response10.d_status == REWRITE_AGAIN_FULL);
515
516 // (difference_subtract (intersection_min B A) A) = emptybag
517 Node n11 = d_nodeManager->mkNode(DIFFERENCE_SUBTRACT, intersectionBA, A);
518 RewriteResponse response11 = d_rewriter->postRewrite(n11);
519 ASSERT_TRUE(response11.d_node == emptyBag
520 && response11.d_status == REWRITE_AGAIN_FULL);
521 }
522
523 TEST_F(TestTheoryWhiteBagsRewriter, difference_remove)
524 {
525 int n = 3;
526 std::vector<Node> elements = getNStrings(2);
527 Node emptyBag = d_nodeManager->mkConst(
528 EmptyBag(d_nodeManager->mkBagType(d_nodeManager->stringType())));
529 Node A = d_nodeManager->mkBag(d_nodeManager->stringType(),
530 elements[0],
531 d_nodeManager->mkConst(Rational(n)));
532 Node B = d_nodeManager->mkBag(d_nodeManager->stringType(),
533 elements[1],
534 d_nodeManager->mkConst(Rational(n + 1)));
535 Node unionMaxAB = d_nodeManager->mkNode(UNION_MAX, A, B);
536 Node unionMaxBA = d_nodeManager->mkNode(UNION_MAX, B, A);
537 Node unionDisjointAB = d_nodeManager->mkNode(UNION_DISJOINT, A, B);
538 Node unionDisjointBA = d_nodeManager->mkNode(UNION_DISJOINT, B, A);
539 Node intersectionAB = d_nodeManager->mkNode(INTERSECTION_MIN, A, B);
540 Node intersectionBA = d_nodeManager->mkNode(INTERSECTION_MIN, B, A);
541
542 // (difference_remove A emptybag) = A
543 Node n1 = d_nodeManager->mkNode(DIFFERENCE_REMOVE, A, emptyBag);
544 RewriteResponse response1 = d_rewriter->postRewrite(n1);
545 ASSERT_TRUE(response1.d_node == A
546 && response1.d_status == REWRITE_AGAIN_FULL);
547
548 // (difference_remove emptybag A) = emptyBag
549 Node n2 = d_nodeManager->mkNode(DIFFERENCE_REMOVE, emptyBag, A);
550 RewriteResponse response2 = d_rewriter->postRewrite(n2);
551 ASSERT_TRUE(response2.d_node == emptyBag
552 && response2.d_status == REWRITE_AGAIN_FULL);
553
554 // (difference_remove A A) = emptybag
555 Node n3 = d_nodeManager->mkNode(DIFFERENCE_REMOVE, A, A);
556 RewriteResponse response3 = d_rewriter->postRewrite(n3);
557 ASSERT_TRUE(response3.d_node == emptyBag
558 && response3.d_status == REWRITE_AGAIN_FULL);
559
560 // (difference_remove A (union_disjoint A B)) = emptybag
561 Node n6 = d_nodeManager->mkNode(DIFFERENCE_REMOVE, A, unionDisjointAB);
562 RewriteResponse response6 = d_rewriter->postRewrite(n6);
563 ASSERT_TRUE(response6.d_node == emptyBag
564 && response6.d_status == REWRITE_AGAIN_FULL);
565
566 // (difference_remove A (union_disjoint B A)) = emptybag
567 Node n7 = d_nodeManager->mkNode(DIFFERENCE_REMOVE, A, unionDisjointBA);
568 RewriteResponse response7 = d_rewriter->postRewrite(n7);
569 ASSERT_TRUE(response7.d_node == emptyBag
570 && response7.d_status == REWRITE_AGAIN_FULL);
571
572 // (difference_remove A (union_max A B)) = emptybag
573 Node n8 = d_nodeManager->mkNode(DIFFERENCE_REMOVE, A, unionMaxAB);
574 RewriteResponse response8 = d_rewriter->postRewrite(n8);
575 ASSERT_TRUE(response8.d_node == emptyBag
576 && response8.d_status == REWRITE_AGAIN_FULL);
577
578 // (difference_remove A (union_max B A)) = emptybag
579 Node n9 = d_nodeManager->mkNode(DIFFERENCE_REMOVE, A, unionMaxBA);
580 RewriteResponse response9 = d_rewriter->postRewrite(n9);
581 ASSERT_TRUE(response9.d_node == emptyBag
582 && response9.d_status == REWRITE_AGAIN_FULL);
583
584 // (difference_remove (intersection_min A B) A) = emptybag
585 Node n10 = d_nodeManager->mkNode(DIFFERENCE_REMOVE, intersectionAB, A);
586 RewriteResponse response10 = d_rewriter->postRewrite(n10);
587 ASSERT_TRUE(response10.d_node == emptyBag
588 && response10.d_status == REWRITE_AGAIN_FULL);
589
590 // (difference_remove (intersection_min B A) A) = emptybag
591 Node n11 = d_nodeManager->mkNode(DIFFERENCE_REMOVE, intersectionBA, A);
592 RewriteResponse response11 = d_rewriter->postRewrite(n11);
593 ASSERT_TRUE(response11.d_node == emptyBag
594 && response11.d_status == REWRITE_AGAIN_FULL);
595 }
596
597 TEST_F(TestTheoryWhiteBagsRewriter, choose)
598 {
599 Node x = d_skolemManager->mkDummySkolem("x", d_nodeManager->stringType());
600 Node c = d_nodeManager->mkConst(Rational(3));
601 Node bag = d_nodeManager->mkBag(d_nodeManager->stringType(), x, c);
602
603 // (bag.choose (mkBag x c)) = x where c is a constant > 0
604 Node n1 = d_nodeManager->mkNode(BAG_CHOOSE, bag);
605 RewriteResponse response1 = d_rewriter->postRewrite(n1);
606 ASSERT_TRUE(response1.d_node == x
607 && response1.d_status == REWRITE_AGAIN_FULL);
608 }
609
610 TEST_F(TestTheoryWhiteBagsRewriter, bag_card)
611 {
612 Node x = d_skolemManager->mkDummySkolem("x", d_nodeManager->stringType());
613 Node emptyBag = d_nodeManager->mkConst(
614 EmptyBag(d_nodeManager->mkBagType(d_nodeManager->stringType())));
615 Node zero = d_nodeManager->mkConst(Rational(0));
616 Node c = d_nodeManager->mkConst(Rational(3));
617 Node bag = d_nodeManager->mkBag(d_nodeManager->stringType(), x, c);
618 std::vector<Node> elements = getNStrings(2);
619 Node A = d_nodeManager->mkBag(d_nodeManager->stringType(),
620 elements[0],
621 d_nodeManager->mkConst(Rational(4)));
622 Node B = d_nodeManager->mkBag(d_nodeManager->stringType(),
623 elements[1],
624 d_nodeManager->mkConst(Rational(5)));
625 Node unionDisjointAB = d_nodeManager->mkNode(UNION_DISJOINT, A, B);
626
627 // TODO(projects#223): enable this test after implementing bags normal form
628 // // (bag.card emptybag) = 0
629 // Node n1 = d_nodeManager->mkNode(BAG_CARD, emptyBag);
630 // RewriteResponse response1 = d_rewriter->postRewrite(n1);
631 // ASSERT_TRUE(response1.d_node == zero && response1.d_status ==
632 // REWRITE_AGAIN_FULL);
633
634 // (bag.card (mkBag x c)) = c where c is a constant > 0
635 Node n2 = d_nodeManager->mkNode(BAG_CARD, bag);
636 RewriteResponse response2 = d_rewriter->postRewrite(n2);
637 ASSERT_TRUE(response2.d_node == c
638 && response2.d_status == REWRITE_AGAIN_FULL);
639
640 // (bag.card (union-disjoint A B)) = (+ (bag.card A) (bag.card B))
641 Node n3 = d_nodeManager->mkNode(BAG_CARD, unionDisjointAB);
642 Node cardA = d_nodeManager->mkNode(BAG_CARD, A);
643 Node cardB = d_nodeManager->mkNode(BAG_CARD, B);
644 Node plus = d_nodeManager->mkNode(PLUS, cardA, cardB);
645 RewriteResponse response3 = d_rewriter->postRewrite(n3);
646 ASSERT_TRUE(response3.d_node == plus
647 && response3.d_status == REWRITE_AGAIN_FULL);
648 }
649
650 TEST_F(TestTheoryWhiteBagsRewriter, is_singleton)
651 {
652 Node emptybag = d_nodeManager->mkConst(
653 EmptyBag(d_nodeManager->mkBagType(d_nodeManager->stringType())));
654 Node x = d_skolemManager->mkDummySkolem("x", d_nodeManager->stringType());
655 Node c = d_skolemManager->mkDummySkolem("c", d_nodeManager->integerType());
656 Node bag = d_nodeManager->mkBag(d_nodeManager->stringType(), x, c);
657
658 // TODO(projects#223): complete this function
659 // (bag.is_singleton emptybag) = false
660 // Node n1 = d_nodeManager->mkNode(BAG_IS_SINGLETON, emptybag);
661 // RewriteResponse response1 = d_rewriter->postRewrite(n1);
662 // ASSERT_TRUE(response1.d_node == d_nodeManager->mkConst(false)
663 // && response1.d_status == REWRITE_AGAIN_FULL);
664
665 // (bag.is_singleton (mkBag x c) = (c == 1)
666 Node n2 = d_nodeManager->mkNode(BAG_IS_SINGLETON, bag);
667 RewriteResponse response2 = d_rewriter->postRewrite(n2);
668 Node one = d_nodeManager->mkConst(Rational(1));
669 Node equal = c.eqNode(one);
670 ASSERT_TRUE(response2.d_node == equal
671 && response2.d_status == REWRITE_AGAIN_FULL);
672 }
673
674 TEST_F(TestTheoryWhiteBagsRewriter, from_set)
675 {
676 Node x = d_skolemManager->mkDummySkolem("x", d_nodeManager->stringType());
677 Node singleton = d_nodeManager->mkSingleton(d_nodeManager->stringType(), x);
678
679 // (bag.from_set (singleton (singleton_op Int) x)) = (mkBag x 1)
680 Node n = d_nodeManager->mkNode(BAG_FROM_SET, singleton);
681 RewriteResponse response = d_rewriter->postRewrite(n);
682 Node one = d_nodeManager->mkConst(Rational(1));
683 Node bag = d_nodeManager->mkBag(d_nodeManager->stringType(), x, one);
684 ASSERT_TRUE(response.d_node == bag
685 && response.d_status == REWRITE_AGAIN_FULL);
686 }
687
688 TEST_F(TestTheoryWhiteBagsRewriter, to_set)
689 {
690 Node x = d_skolemManager->mkDummySkolem("x", d_nodeManager->stringType());
691 Node bag = d_nodeManager->mkBag(
692 d_nodeManager->stringType(), x, d_nodeManager->mkConst(Rational(5)));
693
694 // (bag.to_set (mkBag x n)) = (singleton (singleton_op T) x)
695 Node n = d_nodeManager->mkNode(BAG_TO_SET, bag);
696 RewriteResponse response = d_rewriter->postRewrite(n);
697 Node singleton = d_nodeManager->mkSingleton(d_nodeManager->stringType(), x);
698 ASSERT_TRUE(response.d_node == singleton
699 && response.d_status == REWRITE_AGAIN_FULL);
700 }
701
702 TEST_F(TestTheoryWhiteBagsRewriter, map)
703 {
704 TypeNode bagStringType =
705 d_nodeManager->mkBagType(d_nodeManager->stringType());
706 Node emptybagString = d_nodeManager->mkConst(EmptyBag(bagStringType));
707
708 Node empty = d_nodeManager->mkConst(String(""));
709 Node xString = d_nodeManager->mkBoundVar("x", d_nodeManager->stringType());
710 Node bound = d_nodeManager->mkNode(kind::BOUND_VAR_LIST, xString);
711 Node lambda = d_nodeManager->mkNode(LAMBDA, bound, empty);
712
713 // (bag.map (lambda ((x U)) t) emptybag) = emptybag
714 Node n1 = d_nodeManager->mkNode(BAG_MAP, lambda, emptybagString);
715 RewriteResponse response1 = d_rewriter->postRewrite(n1);
716 ASSERT_TRUE(response1.d_node == emptybagString
717 && response1.d_status == REWRITE_AGAIN_FULL);
718
719 std::vector<Node> elements = getNStrings(2);
720 Node a = d_nodeManager->mkConst(String("a"));
721 Node b = d_nodeManager->mkConst(String("b"));
722 Node A = d_nodeManager->mkBag(
723 d_nodeManager->stringType(), a, d_nodeManager->mkConst(Rational(3)));
724 Node B = d_nodeManager->mkBag(
725 d_nodeManager->stringType(), b, d_nodeManager->mkConst(Rational(4)));
726 Node unionDisjointAB = d_nodeManager->mkNode(UNION_DISJOINT, A, B);
727
728 ASSERT_TRUE(unionDisjointAB.isConst());
729
730 // (bag.map (lambda ((x Int)) "") (union_disjoint (bag "a" 3) (bag "b" 4))) =
731 // (bag "" 7))
732 Node n2 = d_nodeManager->mkNode(BAG_MAP, lambda, unionDisjointAB);
733
734 Node rewritten = Rewriter::rewrite(n2);
735 Node bag = d_nodeManager->mkBag(
736 d_nodeManager->stringType(), empty, d_nodeManager->mkConst(Rational(7)));
737 ASSERT_TRUE(rewritten == bag);
738 }
739
740 } // namespace test
741 } // namespace cvc5