// (union_disjoint A B) // sum(a,b) = max(a,b) + min(a,b)
// check if the operands of union_max and intersection_min are the same
std::set<Node> left(n[0].begin(), n[0].end());
- std::set<Node> right(n[0].begin(), n[0].end());
+ std::set<Node> right(n[1].begin(), n[1].end());
if (left == right)
{
Node rewritten = d_nm->mkNode(UNION_DISJOINT, n[0][0], n[0][1]);
Trace("bags-model") << "Term set: " << termSet << std::endl;
+ std::set<Node> processedBags;
+
// get the relevant bag equivalence classes
for (const Node& n : termSet)
{
TypeNode tn = n.getType();
if (!tn.isBag())
{
+ // we are only concerned here about bag terms
continue;
}
Node r = d_state.getRepresentative(n);
+ if (processedBags.find(r) != processedBags.end())
+ {
+ // skip bags whose representatives are already processed
+ continue;
+ }
+
+ processedBags.insert(r);
+
std::set<Node> solverElements = d_state.getElements(r);
std::set<Node> elements;
// only consider terms in termSet and ignore other elements in the solver
void testUnionDisjoint()
{
int n = 3;
- vector<Node> elements = getNStrings(2);
+ vector<Node> elements = getNStrings(3);
Node emptyBag =
d_nm->mkConst(EmptyBag(d_nm->mkBagType(d_nm->stringType())));
Node A = d_nm->mkBag(
d_nm->stringType(), elements[0], d_nm->mkConst(Rational(n)));
Node B = d_nm->mkBag(
d_nm->stringType(), elements[1], d_nm->mkConst(Rational(n + 1)));
+ Node C = d_nm->mkBag(
+ d_nm->stringType(), elements[2], d_nm->mkConst(Rational(n + 2)));
+
Node unionDisjointAB = d_nm->mkNode(UNION_DISJOINT, A, B);
Node unionDisjointBA = d_nm->mkNode(UNION_DISJOINT, B, A);
Node unionMaxAB = d_nm->mkNode(UNION_MAX, A, B);
+ Node unionMaxAC = d_nm->mkNode(UNION_MAX, A, C);
Node unionMaxBA = d_nm->mkNode(UNION_MAX, B, A);
Node intersectionAB = d_nm->mkNode(INTERSECTION_MIN, A, B);
Node intersectionBA = d_nm->mkNode(INTERSECTION_MIN, B, A);
RewriteResponse response4 = d_rewriter->postRewrite(unionDisjoint4);
TS_ASSERT(response4.d_node == unionDisjointBA
&& response4.d_status == REWRITE_AGAIN_FULL);
+
+ // (union_disjoint (intersection_min B A)) (union_max A B) =
+ // (union_disjoint B A) // sum(a,b) = max(a,b) + min(a,b)
+ Node unionDisjoint5 =
+ d_nm->mkNode(UNION_DISJOINT, unionMaxAC, intersectionAB);
+ RewriteResponse response5 = d_rewriter->postRewrite(unionDisjoint5);
+ TS_ASSERT(response5.d_node == unionDisjoint5
+ && response5.d_status == REWRITE_DONE);
}
void testIntersectionMin()