add bag.fold operator (#7718)
[cvc5.git] / src / theory / bags / normal_form.cpp
1 /******************************************************************************
2 * Top contributors (to current version):
3 * Mudathir Mohamed, Aina Niemetz
4 *
5 * This file is part of the cvc5 project.
6 *
7 * Copyright (c) 2009-2021 by the authors listed in the file AUTHORS
8 * in the top-level source directory and their institutional affiliations.
9 * All rights reserved. See the file COPYING in the top-level source
10 * directory for licensing information.
11 * ****************************************************************************
12 *
13 * Normal form for bag constants.
14 */
15 #include "normal_form.h"
16
17 #include "expr/emptybag.h"
18 #include "smt/logic_exception.h"
19 #include "theory/sets/normal_form.h"
20 #include "theory/type_enumerator.h"
21 #include "util/rational.h"
22
23 using namespace cvc5::kind;
24
25 namespace cvc5 {
26 namespace theory {
27 namespace bags {
28
29 bool NormalForm::isConstant(TNode n)
30 {
31 if (n.getKind() == BAG_EMPTY)
32 {
33 // empty bags are already normalized
34 return true;
35 }
36 if (n.getKind() == BAG_MAKE)
37 {
38 // see the implementation in MkBagTypeRule::computeIsConst
39 return n.isConst();
40 }
41 if (n.getKind() == BAG_UNION_DISJOINT)
42 {
43 if (!(n[0].getKind() == kind::BAG_MAKE && n[0].isConst()))
44 {
45 // the first child is not a constant
46 return false;
47 }
48 // store the previous element to check the ordering of elements
49 Node previousElement = n[0][0];
50 Node current = n[1];
51 while (current.getKind() == BAG_UNION_DISJOINT)
52 {
53 if (!(current[0].getKind() == kind::BAG_MAKE && current[0].isConst()))
54 {
55 // the current element is not a constant
56 return false;
57 }
58 if (previousElement >= current[0][0])
59 {
60 // the ordering is violated
61 return false;
62 }
63 previousElement = current[0][0];
64 current = current[1];
65 }
66 // check last element
67 if (!(current.getKind() == kind::BAG_MAKE && current.isConst()))
68 {
69 // the last element is not a constant
70 return false;
71 }
72 if (previousElement >= current[0])
73 {
74 // the ordering is violated
75 return false;
76 }
77 return true;
78 }
79
80 // only nodes with kinds EMPTY_BAG, BAG_MAKE, and BAG_UNION_DISJOINT can be
81 // constants
82 return false;
83 }
84
85 bool NormalForm::areChildrenConstants(TNode n)
86 {
87 return std::all_of(n.begin(), n.end(), [](Node c) { return c.isConst(); });
88 }
89
90 Node NormalForm::evaluate(TNode n)
91 {
92 Assert(areChildrenConstants(n));
93 if (n.isConst())
94 {
95 // a constant node is already in a normal form
96 return n;
97 }
98 switch (n.getKind())
99 {
100 case BAG_MAKE: return evaluateMakeBag(n);
101 case BAG_COUNT: return evaluateBagCount(n);
102 case BAG_DUPLICATE_REMOVAL: return evaluateDuplicateRemoval(n);
103 case BAG_UNION_DISJOINT: return evaluateUnionDisjoint(n);
104 case BAG_UNION_MAX: return evaluateUnionMax(n);
105 case BAG_INTER_MIN: return evaluateIntersectionMin(n);
106 case BAG_DIFFERENCE_SUBTRACT: return evaluateDifferenceSubtract(n);
107 case BAG_DIFFERENCE_REMOVE: return evaluateDifferenceRemove(n);
108 case BAG_CARD: return evaluateCard(n);
109 case BAG_IS_SINGLETON: return evaluateIsSingleton(n);
110 case BAG_FROM_SET: return evaluateFromSet(n);
111 case BAG_TO_SET: return evaluateToSet(n);
112 case BAG_MAP: return evaluateBagMap(n);
113 case BAG_FOLD: return evaluateBagFold(n);
114 default: break;
115 }
116 Unhandled() << "Unexpected bag kind '" << n.getKind() << "' in node " << n
117 << std::endl;
118 }
119
120 template <typename T1, typename T2, typename T3, typename T4, typename T5>
121 Node NormalForm::evaluateBinaryOperation(const TNode& n,
122 T1&& equal,
123 T2&& less,
124 T3&& greaterOrEqual,
125 T4&& remainderOfA,
126 T5&& remainderOfB)
127 {
128 std::map<Node, Rational> elementsA = getBagElements(n[0]);
129 std::map<Node, Rational> elementsB = getBagElements(n[1]);
130 std::map<Node, Rational> elements;
131
132 std::map<Node, Rational>::const_iterator itA = elementsA.begin();
133 std::map<Node, Rational>::const_iterator itB = elementsB.begin();
134
135 Trace("bags-evaluate") << "[NormalForm::evaluateBinaryOperation "
136 << n.getKind() << "] " << std::endl
137 << "elements A: " << elementsA << std::endl
138 << "elements B: " << elementsB << std::endl;
139
140 while (itA != elementsA.end() && itB != elementsB.end())
141 {
142 if (itA->first == itB->first)
143 {
144 equal(elements, itA, itB);
145 itA++;
146 itB++;
147 }
148 else if (itA->first < itB->first)
149 {
150 less(elements, itA, itB);
151 itA++;
152 }
153 else
154 {
155 greaterOrEqual(elements, itA, itB);
156 itB++;
157 }
158 }
159
160 // handle the remaining elements from A
161 remainderOfA(elements, elementsA, itA);
162 // handle the remaining elements from B
163 remainderOfB(elements, elementsB, itB);
164
165 Trace("bags-evaluate") << "elements: " << elements << std::endl;
166 Node bag = constructConstantBagFromElements(n.getType(), elements);
167 Trace("bags-evaluate") << "bag: " << bag << std::endl;
168 return bag;
169 }
170
171 std::map<Node, Rational> NormalForm::getBagElements(TNode n)
172 {
173 std::map<Node, Rational> elements;
174 if (n.getKind() == BAG_EMPTY)
175 {
176 return elements;
177 }
178 while (n.getKind() == kind::BAG_UNION_DISJOINT)
179 {
180 Assert(n[0].getKind() == kind::BAG_MAKE);
181 Node element = n[0][0];
182 Rational count = n[0][1].getConst<Rational>();
183 elements[element] = count;
184 n = n[1];
185 }
186 Assert(n.getKind() == kind::BAG_MAKE);
187 Node lastElement = n[0];
188 Rational lastCount = n[1].getConst<Rational>();
189 elements[lastElement] = lastCount;
190 return elements;
191 }
192
193 Node NormalForm::constructConstantBagFromElements(
194 TypeNode t, const std::map<Node, Rational>& elements)
195 {
196 Assert(t.isBag());
197 NodeManager* nm = NodeManager::currentNM();
198 if (elements.empty())
199 {
200 return nm->mkConst(EmptyBag(t));
201 }
202 TypeNode elementType = t.getBagElementType();
203 std::map<Node, Rational>::const_reverse_iterator it = elements.rbegin();
204 Node bag = nm->mkBag(elementType,
205 it->first,
206 nm->mkConst<Rational>(CONST_RATIONAL, it->second));
207 while (++it != elements.rend())
208 {
209 Node n = nm->mkBag(elementType,
210 it->first,
211 nm->mkConst<Rational>(CONST_RATIONAL, it->second));
212 bag = nm->mkNode(BAG_UNION_DISJOINT, n, bag);
213 }
214 return bag;
215 }
216
217 Node NormalForm::constructBagFromElements(TypeNode t,
218 const std::map<Node, Node>& elements)
219 {
220 Assert(t.isBag());
221 NodeManager* nm = NodeManager::currentNM();
222 if (elements.empty())
223 {
224 return nm->mkConst(EmptyBag(t));
225 }
226 TypeNode elementType = t.getBagElementType();
227 std::map<Node, Node>::const_reverse_iterator it = elements.rbegin();
228 Node bag = nm->mkBag(elementType, it->first, it->second);
229 while (++it != elements.rend())
230 {
231 Node n = nm->mkBag(elementType, it->first, it->second);
232 bag = nm->mkNode(BAG_UNION_DISJOINT, n, bag);
233 }
234 return bag;
235 }
236
237 Node NormalForm::evaluateMakeBag(TNode n)
238 {
239 // the case where n is const should be handled earlier.
240 // here we handle the case where the multiplicity is zero or negative
241 Assert(n.getKind() == BAG_MAKE && !n.isConst()
242 && n[1].getConst<Rational>().sgn() < 1);
243 Node emptybag = NodeManager::currentNM()->mkConst(EmptyBag(n.getType()));
244 return emptybag;
245 }
246
247 Node NormalForm::evaluateBagCount(TNode n)
248 {
249 Assert(n.getKind() == BAG_COUNT);
250 // Examples
251 // --------
252 // - (bag.count "x" (as bag.empty (Bag String))) = 0
253 // - (bag.count "x" (bag "y" 5)) = 0
254 // - (bag.count "x" (bag "x" 4)) = 4
255 // - (bag.count "x" (bag.union_disjoint (bag "x" 4) (bag "y" 5)) = 4
256 // - (bag.count "x" (bag.union_disjoint (bag "y" 5) (bag "z" 5)) = 0
257
258 std::map<Node, Rational> elements = getBagElements(n[1]);
259 std::map<Node, Rational>::iterator it = elements.find(n[0]);
260
261 NodeManager* nm = NodeManager::currentNM();
262 if (it != elements.end())
263 {
264 Node count = nm->mkConst(CONST_RATIONAL, it->second);
265 return count;
266 }
267 return nm->mkConst(CONST_RATIONAL, Rational(0));
268 }
269
270 Node NormalForm::evaluateDuplicateRemoval(TNode n)
271 {
272 Assert(n.getKind() == BAG_DUPLICATE_REMOVAL);
273
274 // Examples
275 // --------
276 // - (bag.duplicate_removal (as bag.empty (Bag String))) = (as bag.empty (Bag
277 // String))
278 // - (bag.duplicate_removal (bag "x" 4)) = (bag "x" 1)
279 // - (bag.duplicate_removal (bag.disjoint_union (bag "x" 3) (bag "y" 5)) =
280 // (bag.disjoint_union (bag "x" 1) (bag "y" 1)
281
282 std::map<Node, Rational> oldElements = getBagElements(n[0]);
283 // copy elements from the old bag
284 std::map<Node, Rational> newElements(oldElements);
285 Rational one = Rational(1);
286 std::map<Node, Rational>::iterator it;
287 for (it = newElements.begin(); it != newElements.end(); it++)
288 {
289 it->second = one;
290 }
291 Node bag = constructConstantBagFromElements(n[0].getType(), newElements);
292 return bag;
293 }
294
295 Node NormalForm::evaluateUnionDisjoint(TNode n)
296 {
297 Assert(n.getKind() == BAG_UNION_DISJOINT);
298 // Example
299 // -------
300 // input: (bag.union_disjoint A B)
301 // where A = (bag.union_disjoint (bag "x" 4) (bag "z" 2)))
302 // B = (bag.union_disjoint (bag "x" 3) (bag "y" 1)))
303 // output:
304 // (bag.union_disjoint A B)
305 // where A = (bag "x" 7)
306 // B = (bag.union_disjoint (bag "y" 1) (bag "z" 2)))
307
308 auto equal = [](std::map<Node, Rational>& elements,
309 std::map<Node, Rational>::const_iterator& itA,
310 std::map<Node, Rational>::const_iterator& itB) {
311 // compute the sum of the multiplicities
312 elements[itA->first] = itA->second + itB->second;
313 };
314
315 auto less = [](std::map<Node, Rational>& elements,
316 std::map<Node, Rational>::const_iterator& itA,
317 std::map<Node, Rational>::const_iterator& itB) {
318 // add the element to the result
319 elements[itA->first] = itA->second;
320 };
321
322 auto greaterOrEqual = [](std::map<Node, Rational>& elements,
323 std::map<Node, Rational>::const_iterator& itA,
324 std::map<Node, Rational>::const_iterator& itB) {
325 // add the element to the result
326 elements[itB->first] = itB->second;
327 };
328
329 auto remainderOfA = [](std::map<Node, Rational>& elements,
330 std::map<Node, Rational>& elementsA,
331 std::map<Node, Rational>::const_iterator& itA) {
332 // append the remainder of A
333 while (itA != elementsA.end())
334 {
335 elements[itA->first] = itA->second;
336 itA++;
337 }
338 };
339
340 auto remainderOfB = [](std::map<Node, Rational>& elements,
341 std::map<Node, Rational>& elementsB,
342 std::map<Node, Rational>::const_iterator& itB) {
343 // append the remainder of B
344 while (itB != elementsB.end())
345 {
346 elements[itB->first] = itB->second;
347 itB++;
348 }
349 };
350
351 return evaluateBinaryOperation(
352 n, equal, less, greaterOrEqual, remainderOfA, remainderOfB);
353 }
354
355 Node NormalForm::evaluateUnionMax(TNode n)
356 {
357 Assert(n.getKind() == BAG_UNION_MAX);
358 // Example
359 // -------
360 // input: (bag.union_max A B)
361 // where A = (bag.union_disjoint (bag "x" 4) (bag "z" 2)))
362 // B = (bag.union_disjoint (bag "x" 3) (bag "y" 1)))
363 // output:
364 // (bag.union_disjoint A B)
365 // where A = (bag "x" 4)
366 // B = (bag.union_disjoint (bag "y" 1) (bag "z" 2)))
367
368 auto equal = [](std::map<Node, Rational>& elements,
369 std::map<Node, Rational>::const_iterator& itA,
370 std::map<Node, Rational>::const_iterator& itB) {
371 // compute the maximum multiplicity
372 elements[itA->first] = std::max(itA->second, itB->second);
373 };
374
375 auto less = [](std::map<Node, Rational>& elements,
376 std::map<Node, Rational>::const_iterator& itA,
377 std::map<Node, Rational>::const_iterator& itB) {
378 // add to the result
379 elements[itA->first] = itA->second;
380 };
381
382 auto greaterOrEqual = [](std::map<Node, Rational>& elements,
383 std::map<Node, Rational>::const_iterator& itA,
384 std::map<Node, Rational>::const_iterator& itB) {
385 // add to the result
386 elements[itB->first] = itB->second;
387 };
388
389 auto remainderOfA = [](std::map<Node, Rational>& elements,
390 std::map<Node, Rational>& elementsA,
391 std::map<Node, Rational>::const_iterator& itA) {
392 // append the remainder of A
393 while (itA != elementsA.end())
394 {
395 elements[itA->first] = itA->second;
396 itA++;
397 }
398 };
399
400 auto remainderOfB = [](std::map<Node, Rational>& elements,
401 std::map<Node, Rational>& elementsB,
402 std::map<Node, Rational>::const_iterator& itB) {
403 // append the remainder of B
404 while (itB != elementsB.end())
405 {
406 elements[itB->first] = itB->second;
407 itB++;
408 }
409 };
410
411 return evaluateBinaryOperation(
412 n, equal, less, greaterOrEqual, remainderOfA, remainderOfB);
413 }
414
415 Node NormalForm::evaluateIntersectionMin(TNode n)
416 {
417 Assert(n.getKind() == BAG_INTER_MIN);
418 // Example
419 // -------
420 // input: (bag.inter_min A B)
421 // where A = (bag.union_disjoint (bag "x" 4) (bag "z" 2)))
422 // B = (bag.union_disjoint (bag "x" 3) (bag "y" 1)))
423 // output:
424 // (bag "x" 3)
425
426 auto equal = [](std::map<Node, Rational>& elements,
427 std::map<Node, Rational>::const_iterator& itA,
428 std::map<Node, Rational>::const_iterator& itB) {
429 // compute the minimum multiplicity
430 elements[itA->first] = std::min(itA->second, itB->second);
431 };
432
433 auto less = [](std::map<Node, Rational>& elements,
434 std::map<Node, Rational>::const_iterator& itA,
435 std::map<Node, Rational>::const_iterator& itB) {
436 // do nothing
437 };
438
439 auto greaterOrEqual = [](std::map<Node, Rational>& elements,
440 std::map<Node, Rational>::const_iterator& itA,
441 std::map<Node, Rational>::const_iterator& itB) {
442 // do nothing
443 };
444
445 auto remainderOfA = [](std::map<Node, Rational>& elements,
446 std::map<Node, Rational>& elementsA,
447 std::map<Node, Rational>::const_iterator& itA) {
448 // do nothing
449 };
450
451 auto remainderOfB = [](std::map<Node, Rational>& elements,
452 std::map<Node, Rational>& elementsB,
453 std::map<Node, Rational>::const_iterator& itB) {
454 // do nothing
455 };
456
457 return evaluateBinaryOperation(
458 n, equal, less, greaterOrEqual, remainderOfA, remainderOfB);
459 }
460
461 Node NormalForm::evaluateDifferenceSubtract(TNode n)
462 {
463 Assert(n.getKind() == BAG_DIFFERENCE_SUBTRACT);
464 // Example
465 // -------
466 // input: (bag.difference_subtract A B)
467 // where A = (bag.union_disjoint (bag "x" 4) (bag "z" 2)))
468 // B = (bag.union_disjoint (bag "x" 3) (bag "y" 1)))
469 // output:
470 // (bag.union_disjoint (bag "x" 1) (bag "z" 2))
471
472 auto equal = [](std::map<Node, Rational>& elements,
473 std::map<Node, Rational>::const_iterator& itA,
474 std::map<Node, Rational>::const_iterator& itB) {
475 // subtract the multiplicities
476 elements[itA->first] = itA->second - itB->second;
477 };
478
479 auto less = [](std::map<Node, Rational>& elements,
480 std::map<Node, Rational>::const_iterator& itA,
481 std::map<Node, Rational>::const_iterator& itB) {
482 // itA->first is not in B, so we add it to the difference subtract
483 elements[itA->first] = itA->second;
484 };
485
486 auto greaterOrEqual = [](std::map<Node, Rational>& elements,
487 std::map<Node, Rational>::const_iterator& itA,
488 std::map<Node, Rational>::const_iterator& itB) {
489 // itB->first is not in A, so we just skip it
490 };
491
492 auto remainderOfA = [](std::map<Node, Rational>& elements,
493 std::map<Node, Rational>& elementsA,
494 std::map<Node, Rational>::const_iterator& itA) {
495 // append the remainder of A
496 while (itA != elementsA.end())
497 {
498 elements[itA->first] = itA->second;
499 itA++;
500 }
501 };
502
503 auto remainderOfB = [](std::map<Node, Rational>& elements,
504 std::map<Node, Rational>& elementsB,
505 std::map<Node, Rational>::const_iterator& itB) {
506 // do nothing
507 };
508
509 return evaluateBinaryOperation(
510 n, equal, less, greaterOrEqual, remainderOfA, remainderOfB);
511 }
512
513 Node NormalForm::evaluateDifferenceRemove(TNode n)
514 {
515 Assert(n.getKind() == BAG_DIFFERENCE_REMOVE);
516 // Example
517 // -------
518 // input: (bag.difference_remove A B)
519 // where A = (bag.union_disjoint (bag "x" 4) (bag "z" 2)))
520 // B = (bag.union_disjoint (bag "x" 3) (bag "y" 1)))
521 // output:
522 // (bag "z" 2)
523
524 auto equal = [](std::map<Node, Rational>& elements,
525 std::map<Node, Rational>::const_iterator& itA,
526 std::map<Node, Rational>::const_iterator& itB) {
527 // skip the shared element by doing nothing
528 };
529
530 auto less = [](std::map<Node, Rational>& elements,
531 std::map<Node, Rational>::const_iterator& itA,
532 std::map<Node, Rational>::const_iterator& itB) {
533 // itA->first is not in B, so we add it to the difference remove
534 elements[itA->first] = itA->second;
535 };
536
537 auto greaterOrEqual = [](std::map<Node, Rational>& elements,
538 std::map<Node, Rational>::const_iterator& itA,
539 std::map<Node, Rational>::const_iterator& itB) {
540 // itB->first is not in A, so we just skip it
541 };
542
543 auto remainderOfA = [](std::map<Node, Rational>& elements,
544 std::map<Node, Rational>& elementsA,
545 std::map<Node, Rational>::const_iterator& itA) {
546 // append the remainder of A
547 while (itA != elementsA.end())
548 {
549 elements[itA->first] = itA->second;
550 itA++;
551 }
552 };
553
554 auto remainderOfB = [](std::map<Node, Rational>& elements,
555 std::map<Node, Rational>& elementsB,
556 std::map<Node, Rational>::const_iterator& itB) {
557 // do nothing
558 };
559
560 return evaluateBinaryOperation(
561 n, equal, less, greaterOrEqual, remainderOfA, remainderOfB);
562 }
563
564 Node NormalForm::evaluateChoose(TNode n)
565 {
566 Assert(n.getKind() == BAG_CHOOSE);
567 // Examples
568 // --------
569 // - (bag.choose (bag "x" 4)) = "x"
570
571 if (n[0].getKind() == BAG_MAKE)
572 {
573 return n[0][0];
574 }
575 throw LogicException("BAG_CHOOSE_TOTAL is not supported yet");
576 }
577
578 Node NormalForm::evaluateCard(TNode n)
579 {
580 Assert(n.getKind() == BAG_CARD);
581 // Examples
582 // --------
583 // - (card (as bag.empty (Bag String))) = 0
584 // - (bag.choose (bag "x" 4)) = 4
585 // - (bag.choose (bag.union_disjoint (bag "x" 4) (bag "y" 1))) = 5
586
587 std::map<Node, Rational> elements = getBagElements(n[0]);
588 Rational sum(0);
589 for (std::pair<Node, Rational> element : elements)
590 {
591 sum += element.second;
592 }
593
594 NodeManager* nm = NodeManager::currentNM();
595 Node sumNode = nm->mkConst(CONST_RATIONAL, sum);
596 return sumNode;
597 }
598
599 Node NormalForm::evaluateIsSingleton(TNode n)
600 {
601 Assert(n.getKind() == BAG_IS_SINGLETON);
602 // Examples
603 // --------
604 // - (bag.is_singleton (as bag.empty (Bag String))) = false
605 // - (bag.is_singleton (bag "x" 1)) = true
606 // - (bag.is_singleton (bag "x" 4)) = false
607 // - (bag.is_singleton (bag.union_disjoint (bag "x" 1) (bag "y" 1)))
608 // = false
609
610 if (n[0].getKind() == BAG_MAKE && n[0][1].getConst<Rational>().isOne())
611 {
612 return NodeManager::currentNM()->mkConst(true);
613 }
614 return NodeManager::currentNM()->mkConst(false);
615 }
616
617 Node NormalForm::evaluateFromSet(TNode n)
618 {
619 Assert(n.getKind() == BAG_FROM_SET);
620
621 // Examples
622 // --------
623 // - (bag.from_set (as set.empty (Set String))) = (as bag.empty (Bag String))
624 // - (bag.from_set (set.singleton "x")) = (bag "x" 1)
625 // - (bag.from_set (set.union (set.singleton "x") (set.singleton "y"))) =
626 // (bag.disjoint_union (bag "x" 1) (bag "y" 1))
627
628 NodeManager* nm = NodeManager::currentNM();
629 std::set<Node> setElements =
630 sets::NormalForm::getElementsFromNormalConstant(n[0]);
631 Rational one = Rational(1);
632 std::map<Node, Rational> bagElements;
633 for (const Node& element : setElements)
634 {
635 bagElements[element] = one;
636 }
637 TypeNode bagType = nm->mkBagType(n[0].getType().getSetElementType());
638 Node bag = constructConstantBagFromElements(bagType, bagElements);
639 return bag;
640 }
641
642 Node NormalForm::evaluateToSet(TNode n)
643 {
644 Assert(n.getKind() == BAG_TO_SET);
645
646 // Examples
647 // --------
648 // - (bag.to_set (as bag.empty (Bag String))) = (as set.empty (Set String))
649 // - (bag.to_set (bag "x" 4)) = (set.singleton "x")
650 // - (bag.to_set (bag.disjoint_union (bag "x" 3) (bag "y" 5)) =
651 // (set.union (set.singleton "x") (set.singleton "y")))
652
653 NodeManager* nm = NodeManager::currentNM();
654 std::map<Node, Rational> bagElements = getBagElements(n[0]);
655 std::set<Node> setElements;
656 std::map<Node, Rational>::const_reverse_iterator it;
657 for (it = bagElements.rbegin(); it != bagElements.rend(); it++)
658 {
659 setElements.insert(it->first);
660 }
661 TypeNode setType = nm->mkSetType(n[0].getType().getBagElementType());
662 Node set = sets::NormalForm::elementsToSet(setElements, setType);
663 return set;
664 }
665
666 Node NormalForm::evaluateBagMap(TNode n)
667 {
668 Assert(n.getKind() == BAG_MAP);
669
670 // Examples
671 // --------
672 // - (bag.map ((lambda ((x String)) "z")
673 // (bag.union_disjoint (bag "a" 2) (bag "b" 3)) =
674 // (bag.union_disjoint
675 // (bag ((lambda ((x String)) "z") "a") 2)
676 // (bag ((lambda ((x String)) "z") "b") 3)) =
677 // (bag "z" 5)
678
679 std::map<Node, Rational> elements = NormalForm::getBagElements(n[1]);
680 std::map<Node, Rational> mappedElements;
681 std::map<Node, Rational>::iterator it = elements.begin();
682 NodeManager* nm = NodeManager::currentNM();
683 while (it != elements.end())
684 {
685 Node mappedElement = nm->mkNode(APPLY_UF, n[0], it->first);
686 mappedElements[mappedElement] = it->second;
687 ++it;
688 }
689 TypeNode t = nm->mkBagType(n[0].getType().getRangeType());
690 Node ret = NormalForm::constructConstantBagFromElements(t, mappedElements);
691 return ret;
692 }
693
694 Node NormalForm::evaluateBagFold(TNode n)
695 {
696 Assert(n.getKind() == BAG_FOLD);
697
698 // Examples
699 // --------
700 // minimum string
701 // - (bag.fold
702 // ((lambda ((x String) (y String)) (ite (str.< x y) x y))
703 // ""
704 // (bag.union_disjoint (bag "a" 2) (bag "b" 3))
705 // = "a"
706
707 Node f = n[0]; // combining function
708 Node ret = n[1]; // initial value
709 Node A = n[2]; // bag
710 std::map<Node, Rational> elements = NormalForm::getBagElements(A);
711
712 std::map<Node, Rational>::iterator it = elements.begin();
713 NodeManager* nm = NodeManager::currentNM();
714 while (it != elements.end())
715 {
716 // apply the combination function n times, where n is the multiplicity
717 Rational count = it->second;
718 Assert(count.sgn() >= 0) << "negative multiplicity" << std::endl;
719 while (!count.isZero())
720 {
721 ret = nm->mkNode(APPLY_UF, f, it->first, ret);
722 count = count - 1;
723 }
724 ++it;
725 }
726 return ret;
727 }
728
729 } // namespace bags
730 } // namespace theory
731 } // namespace cvc5