39987ce9d5dd9fe4bdba2820d5f215a1e823c91f
[cvc5.git] / src / theory / bags / bags_utils.cpp
1 /******************************************************************************
2 * Top contributors (to current version):
3 * Mudathir Mohamed, Aina Niemetz
4 *
5 * This file is part of the cvc5 project.
6 *
7 * Copyright (c) 2009-2021 by the authors listed in the file AUTHORS
8 * in the top-level source directory and their institutional affiliations.
9 * All rights reserved. See the file COPYING in the top-level source
10 * directory for licensing information.
11 * ****************************************************************************
12 *
13 * Utility functions for bags.
14 */
15 #include "bags_utils.h"
16
17 #include "expr/emptybag.h"
18 #include "smt/logic_exception.h"
19 #include "theory/sets/normal_form.h"
20 #include "theory/type_enumerator.h"
21 #include "util/rational.h"
22
23 using namespace cvc5::kind;
24
25 namespace cvc5 {
26 namespace theory {
27 namespace bags {
28
29 Node BagsUtils::computeDisjointUnion(TypeNode bagType,
30 const std::vector<Node>& bags)
31 {
32 NodeManager* nm = NodeManager::currentNM();
33 if (bags.empty())
34 {
35 return nm->mkConst(EmptyBag(bagType));
36 }
37 if (bags.size() == 1)
38 {
39 return bags[0];
40 }
41 Node unionDisjoint = bags[0];
42 for (size_t i = 1; i < bags.size(); i++)
43 {
44 if (bags[i].getKind() == BAG_EMPTY)
45 {
46 continue;
47 }
48 unionDisjoint = nm->mkNode(BAG_UNION_DISJOINT, unionDisjoint, bags[i]);
49 }
50 return unionDisjoint;
51 }
52
53 bool BagsUtils::isConstant(TNode n)
54 {
55 if (n.getKind() == BAG_EMPTY)
56 {
57 // empty bags are already normalized
58 return true;
59 }
60 if (n.getKind() == BAG_MAKE)
61 {
62 // see the implementation in MkBagTypeRule::computeIsConst
63 return n.isConst();
64 }
65 if (n.getKind() == BAG_UNION_DISJOINT)
66 {
67 if (!(n[0].getKind() == kind::BAG_MAKE && n[0].isConst()))
68 {
69 // the first child is not a constant
70 return false;
71 }
72 // store the previous element to check the ordering of elements
73 Node previousElement = n[0][0];
74 Node current = n[1];
75 while (current.getKind() == BAG_UNION_DISJOINT)
76 {
77 if (!(current[0].getKind() == kind::BAG_MAKE && current[0].isConst()))
78 {
79 // the current element is not a constant
80 return false;
81 }
82 if (previousElement >= current[0][0])
83 {
84 // the ordering is violated
85 return false;
86 }
87 previousElement = current[0][0];
88 current = current[1];
89 }
90 // check last element
91 if (!(current.getKind() == kind::BAG_MAKE && current.isConst()))
92 {
93 // the last element is not a constant
94 return false;
95 }
96 if (previousElement >= current[0])
97 {
98 // the ordering is violated
99 return false;
100 }
101 return true;
102 }
103
104 // only nodes with kinds EMPTY_BAG, BAG_MAKE, and BAG_UNION_DISJOINT can be
105 // constants
106 return false;
107 }
108
109 bool BagsUtils::areChildrenConstants(TNode n)
110 {
111 return std::all_of(n.begin(), n.end(), [](Node c) { return c.isConst(); });
112 }
113
114 Node BagsUtils::evaluate(TNode n)
115 {
116 Assert(areChildrenConstants(n));
117 if (n.isConst())
118 {
119 // a constant node is already in a normal form
120 return n;
121 }
122 switch (n.getKind())
123 {
124 case BAG_MAKE: return evaluateMakeBag(n);
125 case BAG_COUNT: return evaluateBagCount(n);
126 case BAG_DUPLICATE_REMOVAL: return evaluateDuplicateRemoval(n);
127 case BAG_UNION_DISJOINT: return evaluateUnionDisjoint(n);
128 case BAG_UNION_MAX: return evaluateUnionMax(n);
129 case BAG_INTER_MIN: return evaluateIntersectionMin(n);
130 case BAG_DIFFERENCE_SUBTRACT: return evaluateDifferenceSubtract(n);
131 case BAG_DIFFERENCE_REMOVE: return evaluateDifferenceRemove(n);
132 case BAG_CARD: return evaluateCard(n);
133 case BAG_IS_SINGLETON: return evaluateIsSingleton(n);
134 case BAG_FROM_SET: return evaluateFromSet(n);
135 case BAG_TO_SET: return evaluateToSet(n);
136 case BAG_MAP: return evaluateBagMap(n);
137 case BAG_FILTER: return evaluateBagFilter(n);
138 case BAG_FOLD: return evaluateBagFold(n);
139 default: break;
140 }
141 Unhandled() << "Unexpected bag kind '" << n.getKind() << "' in node " << n
142 << std::endl;
143 }
144
145 template <typename T1, typename T2, typename T3, typename T4, typename T5>
146 Node BagsUtils::evaluateBinaryOperation(const TNode& n,
147 T1&& equal,
148 T2&& less,
149 T3&& greaterOrEqual,
150 T4&& remainderOfA,
151 T5&& remainderOfB)
152 {
153 std::map<Node, Rational> elementsA = getBagElements(n[0]);
154 std::map<Node, Rational> elementsB = getBagElements(n[1]);
155 std::map<Node, Rational> elements;
156
157 std::map<Node, Rational>::const_iterator itA = elementsA.begin();
158 std::map<Node, Rational>::const_iterator itB = elementsB.begin();
159
160 Trace("bags-evaluate") << "[NormalForm::evaluateBinaryOperation "
161 << n.getKind() << "] " << std::endl
162 << "elements A: " << elementsA << std::endl
163 << "elements B: " << elementsB << std::endl;
164
165 while (itA != elementsA.end() && itB != elementsB.end())
166 {
167 if (itA->first == itB->first)
168 {
169 equal(elements, itA, itB);
170 itA++;
171 itB++;
172 }
173 else if (itA->first < itB->first)
174 {
175 less(elements, itA, itB);
176 itA++;
177 }
178 else
179 {
180 greaterOrEqual(elements, itA, itB);
181 itB++;
182 }
183 }
184
185 // handle the remaining elements from A
186 remainderOfA(elements, elementsA, itA);
187 // handle the remaining elements from B
188 remainderOfB(elements, elementsB, itB);
189
190 Trace("bags-evaluate") << "elements: " << elements << std::endl;
191 Node bag = constructConstantBagFromElements(n.getType(), elements);
192 Trace("bags-evaluate") << "bag: " << bag << std::endl;
193 return bag;
194 }
195
196 std::map<Node, Rational> BagsUtils::getBagElements(TNode n)
197 {
198 std::map<Node, Rational> elements;
199 if (n.getKind() == BAG_EMPTY)
200 {
201 return elements;
202 }
203 while (n.getKind() == kind::BAG_UNION_DISJOINT)
204 {
205 Assert(n[0].getKind() == kind::BAG_MAKE);
206 Node element = n[0][0];
207 Rational count = n[0][1].getConst<Rational>();
208 elements[element] = count;
209 n = n[1];
210 }
211 Assert(n.getKind() == kind::BAG_MAKE);
212 Node lastElement = n[0];
213 Rational lastCount = n[1].getConst<Rational>();
214 elements[lastElement] = lastCount;
215 return elements;
216 }
217
218 Node BagsUtils::constructConstantBagFromElements(
219 TypeNode t, const std::map<Node, Rational>& elements)
220 {
221 Assert(t.isBag());
222 NodeManager* nm = NodeManager::currentNM();
223 if (elements.empty())
224 {
225 return nm->mkConst(EmptyBag(t));
226 }
227 TypeNode elementType = t.getBagElementType();
228 std::map<Node, Rational>::const_reverse_iterator it = elements.rbegin();
229 Node bag = nm->mkBag(elementType, it->first, nm->mkConstInt(it->second));
230 while (++it != elements.rend())
231 {
232 Node n = nm->mkBag(elementType, it->first, nm->mkConstInt(it->second));
233 bag = nm->mkNode(BAG_UNION_DISJOINT, n, bag);
234 }
235 return bag;
236 }
237
238 Node BagsUtils::constructBagFromElements(TypeNode t,
239 const std::map<Node, Node>& elements)
240 {
241 Assert(t.isBag());
242 NodeManager* nm = NodeManager::currentNM();
243 if (elements.empty())
244 {
245 return nm->mkConst(EmptyBag(t));
246 }
247 TypeNode elementType = t.getBagElementType();
248 std::map<Node, Node>::const_reverse_iterator it = elements.rbegin();
249 Node bag = nm->mkBag(elementType, it->first, it->second);
250 while (++it != elements.rend())
251 {
252 Node n = nm->mkBag(elementType, it->first, it->second);
253 bag = nm->mkNode(BAG_UNION_DISJOINT, n, bag);
254 }
255 return bag;
256 }
257
258 Node BagsUtils::evaluateMakeBag(TNode n)
259 {
260 // the case where n is const should be handled earlier.
261 // here we handle the case where the multiplicity is zero or negative
262 Assert(n.getKind() == BAG_MAKE && !n.isConst()
263 && n[1].getConst<Rational>().sgn() < 1);
264 Node emptybag = NodeManager::currentNM()->mkConst(EmptyBag(n.getType()));
265 return emptybag;
266 }
267
268 Node BagsUtils::evaluateBagCount(TNode n)
269 {
270 Assert(n.getKind() == BAG_COUNT);
271 // Examples
272 // --------
273 // - (bag.count "x" (as bag.empty (Bag String))) = 0
274 // - (bag.count "x" (bag "y" 5)) = 0
275 // - (bag.count "x" (bag "x" 4)) = 4
276 // - (bag.count "x" (bag.union_disjoint (bag "x" 4) (bag "y" 5)) = 4
277 // - (bag.count "x" (bag.union_disjoint (bag "y" 5) (bag "z" 5)) = 0
278
279 std::map<Node, Rational> elements = getBagElements(n[1]);
280 std::map<Node, Rational>::iterator it = elements.find(n[0]);
281
282 NodeManager* nm = NodeManager::currentNM();
283 if (it != elements.end())
284 {
285 Node count = nm->mkConstInt(it->second);
286 return count;
287 }
288 return nm->mkConstInt(Rational(0));
289 }
290
291 Node BagsUtils::evaluateDuplicateRemoval(TNode n)
292 {
293 Assert(n.getKind() == BAG_DUPLICATE_REMOVAL);
294
295 // Examples
296 // --------
297 // - (bag.duplicate_removal (as bag.empty (Bag String))) = (as bag.empty (Bag
298 // String))
299 // - (bag.duplicate_removal (bag "x" 4)) = (bag "x" 1)
300 // - (bag.duplicate_removal (bag.disjoint_union (bag "x" 3) (bag "y" 5)) =
301 // (bag.disjoint_union (bag "x" 1) (bag "y" 1)
302
303 std::map<Node, Rational> oldElements = getBagElements(n[0]);
304 // copy elements from the old bag
305 std::map<Node, Rational> newElements(oldElements);
306 Rational one = Rational(1);
307 std::map<Node, Rational>::iterator it;
308 for (it = newElements.begin(); it != newElements.end(); it++)
309 {
310 it->second = one;
311 }
312 Node bag = constructConstantBagFromElements(n[0].getType(), newElements);
313 return bag;
314 }
315
316 Node BagsUtils::evaluateUnionDisjoint(TNode n)
317 {
318 Assert(n.getKind() == BAG_UNION_DISJOINT);
319 // Example
320 // -------
321 // input: (bag.union_disjoint A B)
322 // where A = (bag.union_disjoint (bag "x" 4) (bag "z" 2)))
323 // B = (bag.union_disjoint (bag "x" 3) (bag "y" 1)))
324 // output:
325 // (bag.union_disjoint A B)
326 // where A = (bag "x" 7)
327 // B = (bag.union_disjoint (bag "y" 1) (bag "z" 2)))
328
329 auto equal = [](std::map<Node, Rational>& elements,
330 std::map<Node, Rational>::const_iterator& itA,
331 std::map<Node, Rational>::const_iterator& itB) {
332 // compute the sum of the multiplicities
333 elements[itA->first] = itA->second + itB->second;
334 };
335
336 auto less = [](std::map<Node, Rational>& elements,
337 std::map<Node, Rational>::const_iterator& itA,
338 std::map<Node, Rational>::const_iterator& itB) {
339 // add the element to the result
340 elements[itA->first] = itA->second;
341 };
342
343 auto greaterOrEqual = [](std::map<Node, Rational>& elements,
344 std::map<Node, Rational>::const_iterator& itA,
345 std::map<Node, Rational>::const_iterator& itB) {
346 // add the element to the result
347 elements[itB->first] = itB->second;
348 };
349
350 auto remainderOfA = [](std::map<Node, Rational>& elements,
351 std::map<Node, Rational>& elementsA,
352 std::map<Node, Rational>::const_iterator& itA) {
353 // append the remainder of A
354 while (itA != elementsA.end())
355 {
356 elements[itA->first] = itA->second;
357 itA++;
358 }
359 };
360
361 auto remainderOfB = [](std::map<Node, Rational>& elements,
362 std::map<Node, Rational>& elementsB,
363 std::map<Node, Rational>::const_iterator& itB) {
364 // append the remainder of B
365 while (itB != elementsB.end())
366 {
367 elements[itB->first] = itB->second;
368 itB++;
369 }
370 };
371
372 return evaluateBinaryOperation(
373 n, equal, less, greaterOrEqual, remainderOfA, remainderOfB);
374 }
375
376 Node BagsUtils::evaluateUnionMax(TNode n)
377 {
378 Assert(n.getKind() == BAG_UNION_MAX);
379 // Example
380 // -------
381 // input: (bag.union_max A B)
382 // where A = (bag.union_disjoint (bag "x" 4) (bag "z" 2)))
383 // B = (bag.union_disjoint (bag "x" 3) (bag "y" 1)))
384 // output:
385 // (bag.union_disjoint A B)
386 // where A = (bag "x" 4)
387 // B = (bag.union_disjoint (bag "y" 1) (bag "z" 2)))
388
389 auto equal = [](std::map<Node, Rational>& elements,
390 std::map<Node, Rational>::const_iterator& itA,
391 std::map<Node, Rational>::const_iterator& itB) {
392 // compute the maximum multiplicity
393 elements[itA->first] = std::max(itA->second, itB->second);
394 };
395
396 auto less = [](std::map<Node, Rational>& elements,
397 std::map<Node, Rational>::const_iterator& itA,
398 std::map<Node, Rational>::const_iterator& itB) {
399 // add to the result
400 elements[itA->first] = itA->second;
401 };
402
403 auto greaterOrEqual = [](std::map<Node, Rational>& elements,
404 std::map<Node, Rational>::const_iterator& itA,
405 std::map<Node, Rational>::const_iterator& itB) {
406 // add to the result
407 elements[itB->first] = itB->second;
408 };
409
410 auto remainderOfA = [](std::map<Node, Rational>& elements,
411 std::map<Node, Rational>& elementsA,
412 std::map<Node, Rational>::const_iterator& itA) {
413 // append the remainder of A
414 while (itA != elementsA.end())
415 {
416 elements[itA->first] = itA->second;
417 itA++;
418 }
419 };
420
421 auto remainderOfB = [](std::map<Node, Rational>& elements,
422 std::map<Node, Rational>& elementsB,
423 std::map<Node, Rational>::const_iterator& itB) {
424 // append the remainder of B
425 while (itB != elementsB.end())
426 {
427 elements[itB->first] = itB->second;
428 itB++;
429 }
430 };
431
432 return evaluateBinaryOperation(
433 n, equal, less, greaterOrEqual, remainderOfA, remainderOfB);
434 }
435
436 Node BagsUtils::evaluateIntersectionMin(TNode n)
437 {
438 Assert(n.getKind() == BAG_INTER_MIN);
439 // Example
440 // -------
441 // input: (bag.inter_min A B)
442 // where A = (bag.union_disjoint (bag "x" 4) (bag "z" 2)))
443 // B = (bag.union_disjoint (bag "x" 3) (bag "y" 1)))
444 // output:
445 // (bag "x" 3)
446
447 auto equal = [](std::map<Node, Rational>& elements,
448 std::map<Node, Rational>::const_iterator& itA,
449 std::map<Node, Rational>::const_iterator& itB) {
450 // compute the minimum multiplicity
451 elements[itA->first] = std::min(itA->second, itB->second);
452 };
453
454 auto less = [](std::map<Node, Rational>& elements,
455 std::map<Node, Rational>::const_iterator& itA,
456 std::map<Node, Rational>::const_iterator& itB) {
457 // do nothing
458 };
459
460 auto greaterOrEqual = [](std::map<Node, Rational>& elements,
461 std::map<Node, Rational>::const_iterator& itA,
462 std::map<Node, Rational>::const_iterator& itB) {
463 // do nothing
464 };
465
466 auto remainderOfA = [](std::map<Node, Rational>& elements,
467 std::map<Node, Rational>& elementsA,
468 std::map<Node, Rational>::const_iterator& itA) {
469 // do nothing
470 };
471
472 auto remainderOfB = [](std::map<Node, Rational>& elements,
473 std::map<Node, Rational>& elementsB,
474 std::map<Node, Rational>::const_iterator& itB) {
475 // do nothing
476 };
477
478 return evaluateBinaryOperation(
479 n, equal, less, greaterOrEqual, remainderOfA, remainderOfB);
480 }
481
482 Node BagsUtils::evaluateDifferenceSubtract(TNode n)
483 {
484 Assert(n.getKind() == BAG_DIFFERENCE_SUBTRACT);
485 // Example
486 // -------
487 // input: (bag.difference_subtract A B)
488 // where A = (bag.union_disjoint (bag "x" 4) (bag "z" 2)))
489 // B = (bag.union_disjoint (bag "x" 3) (bag "y" 1)))
490 // output:
491 // (bag.union_disjoint (bag "x" 1) (bag "z" 2))
492
493 auto equal = [](std::map<Node, Rational>& elements,
494 std::map<Node, Rational>::const_iterator& itA,
495 std::map<Node, Rational>::const_iterator& itB) {
496 // subtract the multiplicities
497 elements[itA->first] = itA->second - itB->second;
498 };
499
500 auto less = [](std::map<Node, Rational>& elements,
501 std::map<Node, Rational>::const_iterator& itA,
502 std::map<Node, Rational>::const_iterator& itB) {
503 // itA->first is not in B, so we add it to the difference subtract
504 elements[itA->first] = itA->second;
505 };
506
507 auto greaterOrEqual = [](std::map<Node, Rational>& elements,
508 std::map<Node, Rational>::const_iterator& itA,
509 std::map<Node, Rational>::const_iterator& itB) {
510 // itB->first is not in A, so we just skip it
511 };
512
513 auto remainderOfA = [](std::map<Node, Rational>& elements,
514 std::map<Node, Rational>& elementsA,
515 std::map<Node, Rational>::const_iterator& itA) {
516 // append the remainder of A
517 while (itA != elementsA.end())
518 {
519 elements[itA->first] = itA->second;
520 itA++;
521 }
522 };
523
524 auto remainderOfB = [](std::map<Node, Rational>& elements,
525 std::map<Node, Rational>& elementsB,
526 std::map<Node, Rational>::const_iterator& itB) {
527 // do nothing
528 };
529
530 return evaluateBinaryOperation(
531 n, equal, less, greaterOrEqual, remainderOfA, remainderOfB);
532 }
533
534 Node BagsUtils::evaluateDifferenceRemove(TNode n)
535 {
536 Assert(n.getKind() == BAG_DIFFERENCE_REMOVE);
537 // Example
538 // -------
539 // input: (bag.difference_remove A B)
540 // where A = (bag.union_disjoint (bag "x" 4) (bag "z" 2)))
541 // B = (bag.union_disjoint (bag "x" 3) (bag "y" 1)))
542 // output:
543 // (bag "z" 2)
544
545 auto equal = [](std::map<Node, Rational>& elements,
546 std::map<Node, Rational>::const_iterator& itA,
547 std::map<Node, Rational>::const_iterator& itB) {
548 // skip the shared element by doing nothing
549 };
550
551 auto less = [](std::map<Node, Rational>& elements,
552 std::map<Node, Rational>::const_iterator& itA,
553 std::map<Node, Rational>::const_iterator& itB) {
554 // itA->first is not in B, so we add it to the difference remove
555 elements[itA->first] = itA->second;
556 };
557
558 auto greaterOrEqual = [](std::map<Node, Rational>& elements,
559 std::map<Node, Rational>::const_iterator& itA,
560 std::map<Node, Rational>::const_iterator& itB) {
561 // itB->first is not in A, so we just skip it
562 };
563
564 auto remainderOfA = [](std::map<Node, Rational>& elements,
565 std::map<Node, Rational>& elementsA,
566 std::map<Node, Rational>::const_iterator& itA) {
567 // append the remainder of A
568 while (itA != elementsA.end())
569 {
570 elements[itA->first] = itA->second;
571 itA++;
572 }
573 };
574
575 auto remainderOfB = [](std::map<Node, Rational>& elements,
576 std::map<Node, Rational>& elementsB,
577 std::map<Node, Rational>::const_iterator& itB) {
578 // do nothing
579 };
580
581 return evaluateBinaryOperation(
582 n, equal, less, greaterOrEqual, remainderOfA, remainderOfB);
583 }
584
585 Node BagsUtils::evaluateChoose(TNode n)
586 {
587 Assert(n.getKind() == BAG_CHOOSE);
588 // Examples
589 // --------
590 // - (bag.choose (bag "x" 4)) = "x"
591
592 if (n[0].getKind() == BAG_MAKE)
593 {
594 return n[0][0];
595 }
596 throw LogicException("BAG_CHOOSE_TOTAL is not supported yet");
597 }
598
599 Node BagsUtils::evaluateCard(TNode n)
600 {
601 Assert(n.getKind() == BAG_CARD);
602 // Examples
603 // --------
604 // - (card (as bag.empty (Bag String))) = 0
605 // - (bag.choose (bag "x" 4)) = 4
606 // - (bag.choose (bag.union_disjoint (bag "x" 4) (bag "y" 1))) = 5
607
608 std::map<Node, Rational> elements = getBagElements(n[0]);
609 Rational sum(0);
610 for (std::pair<Node, Rational> element : elements)
611 {
612 sum += element.second;
613 }
614
615 NodeManager* nm = NodeManager::currentNM();
616 Node sumNode = nm->mkConstInt(sum);
617 return sumNode;
618 }
619
620 Node BagsUtils::evaluateIsSingleton(TNode n)
621 {
622 Assert(n.getKind() == BAG_IS_SINGLETON);
623 // Examples
624 // --------
625 // - (bag.is_singleton (as bag.empty (Bag String))) = false
626 // - (bag.is_singleton (bag "x" 1)) = true
627 // - (bag.is_singleton (bag "x" 4)) = false
628 // - (bag.is_singleton (bag.union_disjoint (bag "x" 1) (bag "y" 1)))
629 // = false
630
631 if (n[0].getKind() == BAG_MAKE && n[0][1].getConst<Rational>().isOne())
632 {
633 return NodeManager::currentNM()->mkConst(true);
634 }
635 return NodeManager::currentNM()->mkConst(false);
636 }
637
638 Node BagsUtils::evaluateFromSet(TNode n)
639 {
640 Assert(n.getKind() == BAG_FROM_SET);
641
642 // Examples
643 // --------
644 // - (bag.from_set (as set.empty (Set String))) = (as bag.empty (Bag String))
645 // - (bag.from_set (set.singleton "x")) = (bag "x" 1)
646 // - (bag.from_set (set.union (set.singleton "x") (set.singleton "y"))) =
647 // (bag.disjoint_union (bag "x" 1) (bag "y" 1))
648
649 NodeManager* nm = NodeManager::currentNM();
650 std::set<Node> setElements =
651 sets::NormalForm::getElementsFromNormalConstant(n[0]);
652 Rational one = Rational(1);
653 std::map<Node, Rational> bagElements;
654 for (const Node& element : setElements)
655 {
656 bagElements[element] = one;
657 }
658 TypeNode bagType = nm->mkBagType(n[0].getType().getSetElementType());
659 Node bag = constructConstantBagFromElements(bagType, bagElements);
660 return bag;
661 }
662
663 Node BagsUtils::evaluateToSet(TNode n)
664 {
665 Assert(n.getKind() == BAG_TO_SET);
666
667 // Examples
668 // --------
669 // - (bag.to_set (as bag.empty (Bag String))) = (as set.empty (Set String))
670 // - (bag.to_set (bag "x" 4)) = (set.singleton "x")
671 // - (bag.to_set (bag.disjoint_union (bag "x" 3) (bag "y" 5)) =
672 // (set.union (set.singleton "x") (set.singleton "y")))
673
674 NodeManager* nm = NodeManager::currentNM();
675 std::map<Node, Rational> bagElements = getBagElements(n[0]);
676 std::set<Node> setElements;
677 std::map<Node, Rational>::const_reverse_iterator it;
678 for (it = bagElements.rbegin(); it != bagElements.rend(); it++)
679 {
680 setElements.insert(it->first);
681 }
682 TypeNode setType = nm->mkSetType(n[0].getType().getBagElementType());
683 Node set = sets::NormalForm::elementsToSet(setElements, setType);
684 return set;
685 }
686
687 Node BagsUtils::evaluateBagMap(TNode n)
688 {
689 Assert(n.getKind() == BAG_MAP);
690
691 // Examples
692 // --------
693 // - (bag.map ((lambda ((x String)) "z")
694 // (bag.union_disjoint (bag "a" 2) (bag "b" 3)) =
695 // (bag.union_disjoint
696 // (bag ((lambda ((x String)) "z") "a") 2)
697 // (bag ((lambda ((x String)) "z") "b") 3)) =
698 // (bag "z" 5)
699
700 std::map<Node, Rational> elements = BagsUtils::getBagElements(n[1]);
701 std::map<Node, Rational> mappedElements;
702 std::map<Node, Rational>::iterator it = elements.begin();
703 NodeManager* nm = NodeManager::currentNM();
704 while (it != elements.end())
705 {
706 Node mappedElement = nm->mkNode(APPLY_UF, n[0], it->first);
707 mappedElements[mappedElement] = it->second;
708 ++it;
709 }
710 TypeNode t = nm->mkBagType(n[0].getType().getRangeType());
711 Node ret = BagsUtils::constructConstantBagFromElements(t, mappedElements);
712 return ret;
713 }
714
715 Node BagsUtils::evaluateBagFilter(TNode n)
716 {
717 Assert(n.getKind() == BAG_FILTER);
718
719 // - (bag.filter p (as bag.empty (Bag T)) = (as bag.empty (Bag T))
720 // - (bag.filter p (bag.union_disjoint (bag "a" 3) (bag "b" 2))) =
721 // (bag.union_disjoint
722 // (ite (p "a") (bag "a" 3) (as bag.empty (Bag T)))
723 // (ite (p "b") (bag "b" 2) (as bag.empty (Bag T)))
724
725 Node P = n[0];
726 Node A = n[1];
727 TypeNode bagType = A.getType();
728 NodeManager* nm = NodeManager::currentNM();
729 Node empty = nm->mkConst(EmptyBag(bagType));
730
731 std::map<Node, Rational> elements = getBagElements(n[1]);
732 std::vector<Node> bags;
733
734 for (const auto& [e, count] : elements)
735 {
736 Node multiplicity = nm->mkConst(CONST_RATIONAL, count);
737 Node bag = nm->mkBag(bagType.getBagElementType(), e, multiplicity);
738 Node pOfe = nm->mkNode(APPLY_UF, P, e);
739 Node ite = nm->mkNode(ITE, pOfe, bag, empty);
740 bags.push_back(ite);
741 }
742 Node ret = computeDisjointUnion(bagType, bags);
743 return ret;
744 }
745
746 Node BagsUtils::evaluateBagFold(TNode n)
747 {
748 Assert(n.getKind() == BAG_FOLD);
749
750 // Examples
751 // --------
752 // minimum string
753 // - (bag.fold
754 // ((lambda ((x String) (y String)) (ite (str.< x y) x y))
755 // ""
756 // (bag.union_disjoint (bag "a" 2) (bag "b" 3))
757 // = "a"
758
759 Node f = n[0]; // combining function
760 Node ret = n[1]; // initial value
761 Node A = n[2]; // bag
762 std::map<Node, Rational> elements = BagsUtils::getBagElements(A);
763
764 std::map<Node, Rational>::iterator it = elements.begin();
765 NodeManager* nm = NodeManager::currentNM();
766 while (it != elements.end())
767 {
768 // apply the combination function n times, where n is the multiplicity
769 Rational count = it->second;
770 Assert(count.sgn() >= 0) << "negative multiplicity" << std::endl;
771 while (!count.isZero())
772 {
773 ret = nm->mkNode(APPLY_UF, f, it->first, ret);
774 count = count - 1;
775 }
776 ++it;
777 }
778 return ret;
779 }
780
781 } // namespace bags
782 } // namespace theory
783 } // namespace cvc5