/******************************************************************************
 * Top contributors (to current version):
 *   Mudathir Mohamed, Aina Niemetz
 *
 * This file is part of the cvc5 project.
 *
 * Copyright (c) 2009-2021 by the authors listed in the file AUTHORS
 * in the top-level source directory and their institutional affiliations.
 * All rights reserved.  See the file COPYING in the top-level source
 * directory for licensing information.
 * ****************************************************************************
 *
 * Utility functions for bags.
 */
15 #include "bags_utils.h"
17 #include "expr/emptybag.h"
18 #include "smt/logic_exception.h"
19 #include "theory/sets/normal_form.h"
20 #include "theory/type_enumerator.h"
21 #include "util/rational.h"
23 using namespace cvc5::kind
;
29 Node
BagsUtils::computeDisjointUnion(TypeNode bagType
,
30 const std::vector
<Node
>& bags
)
32 NodeManager
* nm
= NodeManager::currentNM();
35 return nm
->mkConst(EmptyBag(bagType
));
41 Node unionDisjoint
= bags
[0];
42 for (size_t i
= 1; i
< bags
.size(); i
++)
44 if (bags
[i
].getKind() == BAG_EMPTY
)
48 unionDisjoint
= nm
->mkNode(BAG_UNION_DISJOINT
, unionDisjoint
, bags
[i
]);
53 bool BagsUtils::isConstant(TNode n
)
55 if (n
.getKind() == BAG_EMPTY
)
57 // empty bags are already normalized
60 if (n
.getKind() == BAG_MAKE
)
62 // see the implementation in MkBagTypeRule::computeIsConst
65 if (n
.getKind() == BAG_UNION_DISJOINT
)
67 if (!(n
[0].getKind() == kind::BAG_MAKE
&& n
[0].isConst()))
69 // the first child is not a constant
72 // store the previous element to check the ordering of elements
73 Node previousElement
= n
[0][0];
75 while (current
.getKind() == BAG_UNION_DISJOINT
)
77 if (!(current
[0].getKind() == kind::BAG_MAKE
&& current
[0].isConst()))
79 // the current element is not a constant
82 if (previousElement
>= current
[0][0])
84 // the ordering is violated
87 previousElement
= current
[0][0];
91 if (!(current
.getKind() == kind::BAG_MAKE
&& current
.isConst()))
93 // the last element is not a constant
96 if (previousElement
>= current
[0])
98 // the ordering is violated
104 // only nodes with kinds EMPTY_BAG, BAG_MAKE, and BAG_UNION_DISJOINT can be
109 bool BagsUtils::areChildrenConstants(TNode n
)
111 return std::all_of(n
.begin(), n
.end(), [](Node c
) { return c
.isConst(); });
114 Node
BagsUtils::evaluate(TNode n
)
116 Assert(areChildrenConstants(n
));
119 // a constant node is already in a normal form
124 case BAG_MAKE
: return evaluateMakeBag(n
);
125 case BAG_COUNT
: return evaluateBagCount(n
);
126 case BAG_DUPLICATE_REMOVAL
: return evaluateDuplicateRemoval(n
);
127 case BAG_UNION_DISJOINT
: return evaluateUnionDisjoint(n
);
128 case BAG_UNION_MAX
: return evaluateUnionMax(n
);
129 case BAG_INTER_MIN
: return evaluateIntersectionMin(n
);
130 case BAG_DIFFERENCE_SUBTRACT
: return evaluateDifferenceSubtract(n
);
131 case BAG_DIFFERENCE_REMOVE
: return evaluateDifferenceRemove(n
);
132 case BAG_CARD
: return evaluateCard(n
);
133 case BAG_IS_SINGLETON
: return evaluateIsSingleton(n
);
134 case BAG_FROM_SET
: return evaluateFromSet(n
);
135 case BAG_TO_SET
: return evaluateToSet(n
);
136 case BAG_MAP
: return evaluateBagMap(n
);
137 case BAG_FILTER
: return evaluateBagFilter(n
);
138 case BAG_FOLD
: return evaluateBagFold(n
);
141 Unhandled() << "Unexpected bag kind '" << n
.getKind() << "' in node " << n
145 template <typename T1
, typename T2
, typename T3
, typename T4
, typename T5
>
146 Node
BagsUtils::evaluateBinaryOperation(const TNode
& n
,
153 std::map
<Node
, Rational
> elementsA
= getBagElements(n
[0]);
154 std::map
<Node
, Rational
> elementsB
= getBagElements(n
[1]);
155 std::map
<Node
, Rational
> elements
;
157 std::map
<Node
, Rational
>::const_iterator itA
= elementsA
.begin();
158 std::map
<Node
, Rational
>::const_iterator itB
= elementsB
.begin();
160 Trace("bags-evaluate") << "[NormalForm::evaluateBinaryOperation "
161 << n
.getKind() << "] " << std::endl
162 << "elements A: " << elementsA
<< std::endl
163 << "elements B: " << elementsB
<< std::endl
;
165 while (itA
!= elementsA
.end() && itB
!= elementsB
.end())
167 if (itA
->first
== itB
->first
)
169 equal(elements
, itA
, itB
);
173 else if (itA
->first
< itB
->first
)
175 less(elements
, itA
, itB
);
180 greaterOrEqual(elements
, itA
, itB
);
185 // handle the remaining elements from A
186 remainderOfA(elements
, elementsA
, itA
);
187 // handle the remaining elements from B
188 remainderOfB(elements
, elementsB
, itB
);
190 Trace("bags-evaluate") << "elements: " << elements
<< std::endl
;
191 Node bag
= constructConstantBagFromElements(n
.getType(), elements
);
192 Trace("bags-evaluate") << "bag: " << bag
<< std::endl
;
196 std::map
<Node
, Rational
> BagsUtils::getBagElements(TNode n
)
198 std::map
<Node
, Rational
> elements
;
199 if (n
.getKind() == BAG_EMPTY
)
203 while (n
.getKind() == kind::BAG_UNION_DISJOINT
)
205 Assert(n
[0].getKind() == kind::BAG_MAKE
);
206 Node element
= n
[0][0];
207 Rational count
= n
[0][1].getConst
<Rational
>();
208 elements
[element
] = count
;
211 Assert(n
.getKind() == kind::BAG_MAKE
);
212 Node lastElement
= n
[0];
213 Rational lastCount
= n
[1].getConst
<Rational
>();
214 elements
[lastElement
] = lastCount
;
218 Node
BagsUtils::constructConstantBagFromElements(
219 TypeNode t
, const std::map
<Node
, Rational
>& elements
)
222 NodeManager
* nm
= NodeManager::currentNM();
223 if (elements
.empty())
225 return nm
->mkConst(EmptyBag(t
));
227 TypeNode elementType
= t
.getBagElementType();
228 std::map
<Node
, Rational
>::const_reverse_iterator it
= elements
.rbegin();
229 Node bag
= nm
->mkBag(elementType
, it
->first
, nm
->mkConstInt(it
->second
));
230 while (++it
!= elements
.rend())
232 Node n
= nm
->mkBag(elementType
, it
->first
, nm
->mkConstInt(it
->second
));
233 bag
= nm
->mkNode(BAG_UNION_DISJOINT
, n
, bag
);
238 Node
BagsUtils::constructBagFromElements(TypeNode t
,
239 const std::map
<Node
, Node
>& elements
)
242 NodeManager
* nm
= NodeManager::currentNM();
243 if (elements
.empty())
245 return nm
->mkConst(EmptyBag(t
));
247 TypeNode elementType
= t
.getBagElementType();
248 std::map
<Node
, Node
>::const_reverse_iterator it
= elements
.rbegin();
249 Node bag
= nm
->mkBag(elementType
, it
->first
, it
->second
);
250 while (++it
!= elements
.rend())
252 Node n
= nm
->mkBag(elementType
, it
->first
, it
->second
);
253 bag
= nm
->mkNode(BAG_UNION_DISJOINT
, n
, bag
);
258 Node
BagsUtils::evaluateMakeBag(TNode n
)
260 // the case where n is const should be handled earlier.
261 // here we handle the case where the multiplicity is zero or negative
262 Assert(n
.getKind() == BAG_MAKE
&& !n
.isConst()
263 && n
[1].getConst
<Rational
>().sgn() < 1);
264 Node emptybag
= NodeManager::currentNM()->mkConst(EmptyBag(n
.getType()));
268 Node
BagsUtils::evaluateBagCount(TNode n
)
270 Assert(n
.getKind() == BAG_COUNT
);
273 // - (bag.count "x" (as bag.empty (Bag String))) = 0
274 // - (bag.count "x" (bag "y" 5)) = 0
275 // - (bag.count "x" (bag "x" 4)) = 4
276 // - (bag.count "x" (bag.union_disjoint (bag "x" 4) (bag "y" 5)) = 4
277 // - (bag.count "x" (bag.union_disjoint (bag "y" 5) (bag "z" 5)) = 0
279 std::map
<Node
, Rational
> elements
= getBagElements(n
[1]);
280 std::map
<Node
, Rational
>::iterator it
= elements
.find(n
[0]);
282 NodeManager
* nm
= NodeManager::currentNM();
283 if (it
!= elements
.end())
285 Node count
= nm
->mkConstInt(it
->second
);
288 return nm
->mkConstInt(Rational(0));
291 Node
BagsUtils::evaluateDuplicateRemoval(TNode n
)
293 Assert(n
.getKind() == BAG_DUPLICATE_REMOVAL
);
297 // - (bag.duplicate_removal (as bag.empty (Bag String))) = (as bag.empty (Bag
299 // - (bag.duplicate_removal (bag "x" 4)) = (bag "x" 1)
300 // - (bag.duplicate_removal (bag.disjoint_union (bag "x" 3) (bag "y" 5)) =
301 // (bag.disjoint_union (bag "x" 1) (bag "y" 1)
303 std::map
<Node
, Rational
> oldElements
= getBagElements(n
[0]);
304 // copy elements from the old bag
305 std::map
<Node
, Rational
> newElements(oldElements
);
306 Rational one
= Rational(1);
307 std::map
<Node
, Rational
>::iterator it
;
308 for (it
= newElements
.begin(); it
!= newElements
.end(); it
++)
312 Node bag
= constructConstantBagFromElements(n
[0].getType(), newElements
);
316 Node
BagsUtils::evaluateUnionDisjoint(TNode n
)
318 Assert(n
.getKind() == BAG_UNION_DISJOINT
);
321 // input: (bag.union_disjoint A B)
322 // where A = (bag.union_disjoint (bag "x" 4) (bag "z" 2)))
323 // B = (bag.union_disjoint (bag "x" 3) (bag "y" 1)))
325 // (bag.union_disjoint A B)
326 // where A = (bag "x" 7)
327 // B = (bag.union_disjoint (bag "y" 1) (bag "z" 2)))
329 auto equal
= [](std::map
<Node
, Rational
>& elements
,
330 std::map
<Node
, Rational
>::const_iterator
& itA
,
331 std::map
<Node
, Rational
>::const_iterator
& itB
) {
332 // compute the sum of the multiplicities
333 elements
[itA
->first
] = itA
->second
+ itB
->second
;
336 auto less
= [](std::map
<Node
, Rational
>& elements
,
337 std::map
<Node
, Rational
>::const_iterator
& itA
,
338 std::map
<Node
, Rational
>::const_iterator
& itB
) {
339 // add the element to the result
340 elements
[itA
->first
] = itA
->second
;
343 auto greaterOrEqual
= [](std::map
<Node
, Rational
>& elements
,
344 std::map
<Node
, Rational
>::const_iterator
& itA
,
345 std::map
<Node
, Rational
>::const_iterator
& itB
) {
346 // add the element to the result
347 elements
[itB
->first
] = itB
->second
;
350 auto remainderOfA
= [](std::map
<Node
, Rational
>& elements
,
351 std::map
<Node
, Rational
>& elementsA
,
352 std::map
<Node
, Rational
>::const_iterator
& itA
) {
353 // append the remainder of A
354 while (itA
!= elementsA
.end())
356 elements
[itA
->first
] = itA
->second
;
361 auto remainderOfB
= [](std::map
<Node
, Rational
>& elements
,
362 std::map
<Node
, Rational
>& elementsB
,
363 std::map
<Node
, Rational
>::const_iterator
& itB
) {
364 // append the remainder of B
365 while (itB
!= elementsB
.end())
367 elements
[itB
->first
] = itB
->second
;
372 return evaluateBinaryOperation(
373 n
, equal
, less
, greaterOrEqual
, remainderOfA
, remainderOfB
);
376 Node
BagsUtils::evaluateUnionMax(TNode n
)
378 Assert(n
.getKind() == BAG_UNION_MAX
);
381 // input: (bag.union_max A B)
382 // where A = (bag.union_disjoint (bag "x" 4) (bag "z" 2)))
383 // B = (bag.union_disjoint (bag "x" 3) (bag "y" 1)))
385 // (bag.union_disjoint A B)
386 // where A = (bag "x" 4)
387 // B = (bag.union_disjoint (bag "y" 1) (bag "z" 2)))
389 auto equal
= [](std::map
<Node
, Rational
>& elements
,
390 std::map
<Node
, Rational
>::const_iterator
& itA
,
391 std::map
<Node
, Rational
>::const_iterator
& itB
) {
392 // compute the maximum multiplicity
393 elements
[itA
->first
] = std::max(itA
->second
, itB
->second
);
396 auto less
= [](std::map
<Node
, Rational
>& elements
,
397 std::map
<Node
, Rational
>::const_iterator
& itA
,
398 std::map
<Node
, Rational
>::const_iterator
& itB
) {
400 elements
[itA
->first
] = itA
->second
;
403 auto greaterOrEqual
= [](std::map
<Node
, Rational
>& elements
,
404 std::map
<Node
, Rational
>::const_iterator
& itA
,
405 std::map
<Node
, Rational
>::const_iterator
& itB
) {
407 elements
[itB
->first
] = itB
->second
;
410 auto remainderOfA
= [](std::map
<Node
, Rational
>& elements
,
411 std::map
<Node
, Rational
>& elementsA
,
412 std::map
<Node
, Rational
>::const_iterator
& itA
) {
413 // append the remainder of A
414 while (itA
!= elementsA
.end())
416 elements
[itA
->first
] = itA
->second
;
421 auto remainderOfB
= [](std::map
<Node
, Rational
>& elements
,
422 std::map
<Node
, Rational
>& elementsB
,
423 std::map
<Node
, Rational
>::const_iterator
& itB
) {
424 // append the remainder of B
425 while (itB
!= elementsB
.end())
427 elements
[itB
->first
] = itB
->second
;
432 return evaluateBinaryOperation(
433 n
, equal
, less
, greaterOrEqual
, remainderOfA
, remainderOfB
);
436 Node
BagsUtils::evaluateIntersectionMin(TNode n
)
438 Assert(n
.getKind() == BAG_INTER_MIN
);
441 // input: (bag.inter_min A B)
442 // where A = (bag.union_disjoint (bag "x" 4) (bag "z" 2)))
443 // B = (bag.union_disjoint (bag "x" 3) (bag "y" 1)))
447 auto equal
= [](std::map
<Node
, Rational
>& elements
,
448 std::map
<Node
, Rational
>::const_iterator
& itA
,
449 std::map
<Node
, Rational
>::const_iterator
& itB
) {
450 // compute the minimum multiplicity
451 elements
[itA
->first
] = std::min(itA
->second
, itB
->second
);
454 auto less
= [](std::map
<Node
, Rational
>& elements
,
455 std::map
<Node
, Rational
>::const_iterator
& itA
,
456 std::map
<Node
, Rational
>::const_iterator
& itB
) {
460 auto greaterOrEqual
= [](std::map
<Node
, Rational
>& elements
,
461 std::map
<Node
, Rational
>::const_iterator
& itA
,
462 std::map
<Node
, Rational
>::const_iterator
& itB
) {
466 auto remainderOfA
= [](std::map
<Node
, Rational
>& elements
,
467 std::map
<Node
, Rational
>& elementsA
,
468 std::map
<Node
, Rational
>::const_iterator
& itA
) {
472 auto remainderOfB
= [](std::map
<Node
, Rational
>& elements
,
473 std::map
<Node
, Rational
>& elementsB
,
474 std::map
<Node
, Rational
>::const_iterator
& itB
) {
478 return evaluateBinaryOperation(
479 n
, equal
, less
, greaterOrEqual
, remainderOfA
, remainderOfB
);
482 Node
BagsUtils::evaluateDifferenceSubtract(TNode n
)
484 Assert(n
.getKind() == BAG_DIFFERENCE_SUBTRACT
);
487 // input: (bag.difference_subtract A B)
488 // where A = (bag.union_disjoint (bag "x" 4) (bag "z" 2)))
489 // B = (bag.union_disjoint (bag "x" 3) (bag "y" 1)))
491 // (bag.union_disjoint (bag "x" 1) (bag "z" 2))
493 auto equal
= [](std::map
<Node
, Rational
>& elements
,
494 std::map
<Node
, Rational
>::const_iterator
& itA
,
495 std::map
<Node
, Rational
>::const_iterator
& itB
) {
496 // subtract the multiplicities
497 elements
[itA
->first
] = itA
->second
- itB
->second
;
500 auto less
= [](std::map
<Node
, Rational
>& elements
,
501 std::map
<Node
, Rational
>::const_iterator
& itA
,
502 std::map
<Node
, Rational
>::const_iterator
& itB
) {
503 // itA->first is not in B, so we add it to the difference subtract
504 elements
[itA
->first
] = itA
->second
;
507 auto greaterOrEqual
= [](std::map
<Node
, Rational
>& elements
,
508 std::map
<Node
, Rational
>::const_iterator
& itA
,
509 std::map
<Node
, Rational
>::const_iterator
& itB
) {
510 // itB->first is not in A, so we just skip it
513 auto remainderOfA
= [](std::map
<Node
, Rational
>& elements
,
514 std::map
<Node
, Rational
>& elementsA
,
515 std::map
<Node
, Rational
>::const_iterator
& itA
) {
516 // append the remainder of A
517 while (itA
!= elementsA
.end())
519 elements
[itA
->first
] = itA
->second
;
524 auto remainderOfB
= [](std::map
<Node
, Rational
>& elements
,
525 std::map
<Node
, Rational
>& elementsB
,
526 std::map
<Node
, Rational
>::const_iterator
& itB
) {
530 return evaluateBinaryOperation(
531 n
, equal
, less
, greaterOrEqual
, remainderOfA
, remainderOfB
);
534 Node
BagsUtils::evaluateDifferenceRemove(TNode n
)
536 Assert(n
.getKind() == BAG_DIFFERENCE_REMOVE
);
539 // input: (bag.difference_remove A B)
540 // where A = (bag.union_disjoint (bag "x" 4) (bag "z" 2)))
541 // B = (bag.union_disjoint (bag "x" 3) (bag "y" 1)))
545 auto equal
= [](std::map
<Node
, Rational
>& elements
,
546 std::map
<Node
, Rational
>::const_iterator
& itA
,
547 std::map
<Node
, Rational
>::const_iterator
& itB
) {
548 // skip the shared element by doing nothing
551 auto less
= [](std::map
<Node
, Rational
>& elements
,
552 std::map
<Node
, Rational
>::const_iterator
& itA
,
553 std::map
<Node
, Rational
>::const_iterator
& itB
) {
554 // itA->first is not in B, so we add it to the difference remove
555 elements
[itA
->first
] = itA
->second
;
558 auto greaterOrEqual
= [](std::map
<Node
, Rational
>& elements
,
559 std::map
<Node
, Rational
>::const_iterator
& itA
,
560 std::map
<Node
, Rational
>::const_iterator
& itB
) {
561 // itB->first is not in A, so we just skip it
564 auto remainderOfA
= [](std::map
<Node
, Rational
>& elements
,
565 std::map
<Node
, Rational
>& elementsA
,
566 std::map
<Node
, Rational
>::const_iterator
& itA
) {
567 // append the remainder of A
568 while (itA
!= elementsA
.end())
570 elements
[itA
->first
] = itA
->second
;
575 auto remainderOfB
= [](std::map
<Node
, Rational
>& elements
,
576 std::map
<Node
, Rational
>& elementsB
,
577 std::map
<Node
, Rational
>::const_iterator
& itB
) {
581 return evaluateBinaryOperation(
582 n
, equal
, less
, greaterOrEqual
, remainderOfA
, remainderOfB
);
585 Node
BagsUtils::evaluateChoose(TNode n
)
587 Assert(n
.getKind() == BAG_CHOOSE
);
590 // - (bag.choose (bag "x" 4)) = "x"
592 if (n
[0].getKind() == BAG_MAKE
)
596 throw LogicException("BAG_CHOOSE_TOTAL is not supported yet");
599 Node
BagsUtils::evaluateCard(TNode n
)
601 Assert(n
.getKind() == BAG_CARD
);
604 // - (card (as bag.empty (Bag String))) = 0
605 // - (bag.choose (bag "x" 4)) = 4
606 // - (bag.choose (bag.union_disjoint (bag "x" 4) (bag "y" 1))) = 5
608 std::map
<Node
, Rational
> elements
= getBagElements(n
[0]);
610 for (std::pair
<Node
, Rational
> element
: elements
)
612 sum
+= element
.second
;
615 NodeManager
* nm
= NodeManager::currentNM();
616 Node sumNode
= nm
->mkConstInt(sum
);
620 Node
BagsUtils::evaluateIsSingleton(TNode n
)
622 Assert(n
.getKind() == BAG_IS_SINGLETON
);
625 // - (bag.is_singleton (as bag.empty (Bag String))) = false
626 // - (bag.is_singleton (bag "x" 1)) = true
627 // - (bag.is_singleton (bag "x" 4)) = false
628 // - (bag.is_singleton (bag.union_disjoint (bag "x" 1) (bag "y" 1)))
631 if (n
[0].getKind() == BAG_MAKE
&& n
[0][1].getConst
<Rational
>().isOne())
633 return NodeManager::currentNM()->mkConst(true);
635 return NodeManager::currentNM()->mkConst(false);
638 Node
BagsUtils::evaluateFromSet(TNode n
)
640 Assert(n
.getKind() == BAG_FROM_SET
);
644 // - (bag.from_set (as set.empty (Set String))) = (as bag.empty (Bag String))
645 // - (bag.from_set (set.singleton "x")) = (bag "x" 1)
646 // - (bag.from_set (set.union (set.singleton "x") (set.singleton "y"))) =
647 // (bag.disjoint_union (bag "x" 1) (bag "y" 1))
649 NodeManager
* nm
= NodeManager::currentNM();
650 std::set
<Node
> setElements
=
651 sets::NormalForm::getElementsFromNormalConstant(n
[0]);
652 Rational one
= Rational(1);
653 std::map
<Node
, Rational
> bagElements
;
654 for (const Node
& element
: setElements
)
656 bagElements
[element
] = one
;
658 TypeNode bagType
= nm
->mkBagType(n
[0].getType().getSetElementType());
659 Node bag
= constructConstantBagFromElements(bagType
, bagElements
);
663 Node
BagsUtils::evaluateToSet(TNode n
)
665 Assert(n
.getKind() == BAG_TO_SET
);
669 // - (bag.to_set (as bag.empty (Bag String))) = (as set.empty (Set String))
670 // - (bag.to_set (bag "x" 4)) = (set.singleton "x")
671 // - (bag.to_set (bag.disjoint_union (bag "x" 3) (bag "y" 5)) =
672 // (set.union (set.singleton "x") (set.singleton "y")))
674 NodeManager
* nm
= NodeManager::currentNM();
675 std::map
<Node
, Rational
> bagElements
= getBagElements(n
[0]);
676 std::set
<Node
> setElements
;
677 std::map
<Node
, Rational
>::const_reverse_iterator it
;
678 for (it
= bagElements
.rbegin(); it
!= bagElements
.rend(); it
++)
680 setElements
.insert(it
->first
);
682 TypeNode setType
= nm
->mkSetType(n
[0].getType().getBagElementType());
683 Node set
= sets::NormalForm::elementsToSet(setElements
, setType
);
687 Node
BagsUtils::evaluateBagMap(TNode n
)
689 Assert(n
.getKind() == BAG_MAP
);
693 // - (bag.map ((lambda ((x String)) "z")
694 // (bag.union_disjoint (bag "a" 2) (bag "b" 3)) =
695 // (bag.union_disjoint
696 // (bag ((lambda ((x String)) "z") "a") 2)
697 // (bag ((lambda ((x String)) "z") "b") 3)) =
700 std::map
<Node
, Rational
> elements
= BagsUtils::getBagElements(n
[1]);
701 std::map
<Node
, Rational
> mappedElements
;
702 std::map
<Node
, Rational
>::iterator it
= elements
.begin();
703 NodeManager
* nm
= NodeManager::currentNM();
704 while (it
!= elements
.end())
706 Node mappedElement
= nm
->mkNode(APPLY_UF
, n
[0], it
->first
);
707 mappedElements
[mappedElement
] = it
->second
;
710 TypeNode t
= nm
->mkBagType(n
[0].getType().getRangeType());
711 Node ret
= BagsUtils::constructConstantBagFromElements(t
, mappedElements
);
715 Node
BagsUtils::evaluateBagFilter(TNode n
)
717 Assert(n
.getKind() == BAG_FILTER
);
719 // - (bag.filter p (as bag.empty (Bag T)) = (as bag.empty (Bag T))
720 // - (bag.filter p (bag.union_disjoint (bag "a" 3) (bag "b" 2))) =
721 // (bag.union_disjoint
722 // (ite (p "a") (bag "a" 3) (as bag.empty (Bag T)))
723 // (ite (p "b") (bag "b" 2) (as bag.empty (Bag T)))
727 TypeNode bagType
= A
.getType();
728 NodeManager
* nm
= NodeManager::currentNM();
729 Node empty
= nm
->mkConst(EmptyBag(bagType
));
731 std::map
<Node
, Rational
> elements
= getBagElements(n
[1]);
732 std::vector
<Node
> bags
;
734 for (const auto& [e
, count
] : elements
)
736 Node multiplicity
= nm
->mkConst(CONST_RATIONAL
, count
);
737 Node bag
= nm
->mkBag(bagType
.getBagElementType(), e
, multiplicity
);
738 Node pOfe
= nm
->mkNode(APPLY_UF
, P
, e
);
739 Node ite
= nm
->mkNode(ITE
, pOfe
, bag
, empty
);
742 Node ret
= computeDisjointUnion(bagType
, bags
);
746 Node
BagsUtils::evaluateBagFold(TNode n
)
748 Assert(n
.getKind() == BAG_FOLD
);
754 // ((lambda ((x String) (y String)) (ite (str.< x y) x y))
756 // (bag.union_disjoint (bag "a" 2) (bag "b" 3))
759 Node f
= n
[0]; // combining function
760 Node ret
= n
[1]; // initial value
761 Node A
= n
[2]; // bag
762 std::map
<Node
, Rational
> elements
= BagsUtils::getBagElements(A
);
764 std::map
<Node
, Rational
>::iterator it
= elements
.begin();
765 NodeManager
* nm
= NodeManager::currentNM();
766 while (it
!= elements
.end())
768 // apply the combination function n times, where n is the multiplicity
769 Rational count
= it
->second
;
770 Assert(count
.sgn() >= 0) << "negative multiplicity" << std::endl
;
771 while (!count
.isZero())
773 ret
= nm
->mkNode(APPLY_UF
, f
, it
->first
, ret
);
782 } // namespace theory