From c0937f742479d8a5054e42597da9447d55e876c0 Mon Sep 17 00:00:00 2001 From: mudathirmahgoub Date: Mon, 1 Feb 2021 08:42:39 -0600 Subject: [PATCH] Fix BagsRewriter::rewriteUnionDisjoint (#5840) This PR fixes the implementation of (union_disjoint (union_max A B) (intersection_min A B)) =(union_disjoint A B). It also skips processed bags during model building. --- src/theory/bags/bags_rewriter.cpp | 2 +- src/theory/bags/theory_bags.cpp | 11 +++++++++++ test/unit/theory/theory_bags_rewriter_white.h | 14 +++++++++++++- 3 files changed, 25 insertions(+), 2 deletions(-) diff --git a/src/theory/bags/bags_rewriter.cpp b/src/theory/bags/bags_rewriter.cpp index 66886bfbf..9f53c29ca 100644 --- a/src/theory/bags/bags_rewriter.cpp +++ b/src/theory/bags/bags_rewriter.cpp @@ -246,7 +246,7 @@ BagsRewriteResponse BagsRewriter::rewriteUnionDisjoint(const TNode& n) const // (union_disjoint A B) // sum(a,b) = max(a,b) + min(a,b) // check if the operands of union_max and intersection_min are the same std::set left(n[0].begin(), n[0].end()); - std::set right(n[0].begin(), n[0].end()); + std::set right(n[1].begin(), n[1].end()); if (left == right) { Node rewritten = d_nm->mkNode(UNION_DISJOINT, n[0][0], n[0][1]); diff --git a/src/theory/bags/theory_bags.cpp b/src/theory/bags/theory_bags.cpp index 15e8e00e7..6df44295e 100644 --- a/src/theory/bags/theory_bags.cpp +++ b/src/theory/bags/theory_bags.cpp @@ -144,15 +144,26 @@ bool TheoryBags::collectModelValues(TheoryModel* m, Trace("bags-model") << "Term set: " << termSet << std::endl; + std::set processedBags; + // get the relevant bag equivalence classes for (const Node& n : termSet) { TypeNode tn = n.getType(); if (!tn.isBag()) { + // we are only concerned here about bag terms continue; } Node r = d_state.getRepresentative(n); + if (processedBags.find(r) != processedBags.end()) + { + // skip bags whose representatives are already processed + continue; + } + + processedBags.insert(r); + std::set solverElements = d_state.getElements(r); std::set elements; // only consider terms in termSet and ignore other elements in the solver diff --git a/test/unit/theory/theory_bags_rewriter_white.h b/test/unit/theory/theory_bags_rewriter_white.h index 98e3cf887..10a624238 100644 --- a/test/unit/theory/theory_bags_rewriter_white.h +++ b/test/unit/theory/theory_bags_rewriter_white.h @@ -280,16 +280,20 @@ class BagsTypeRuleWhite : public CxxTest::TestSuite void testUnionDisjoint() { int n = 3; - vector elements = getNStrings(2); + vector elements = getNStrings(3); Node emptyBag = d_nm->mkConst(EmptyBag(d_nm->mkBagType(d_nm->stringType()))); Node A = d_nm->mkBag( d_nm->stringType(), elements[0], d_nm->mkConst(Rational(n))); Node B = d_nm->mkBag( d_nm->stringType(), elements[1], d_nm->mkConst(Rational(n + 1))); + Node C = d_nm->mkBag( + d_nm->stringType(), elements[2], d_nm->mkConst(Rational(n + 2))); + Node unionDisjointAB = d_nm->mkNode(UNION_DISJOINT, A, B); Node unionDisjointBA = d_nm->mkNode(UNION_DISJOINT, B, A); Node unionMaxAB = d_nm->mkNode(UNION_MAX, A, B); + Node unionMaxAC = d_nm->mkNode(UNION_MAX, A, C); Node unionMaxBA = d_nm->mkNode(UNION_MAX, B, A); Node intersectionAB = d_nm->mkNode(INTERSECTION_MIN, A, B); Node intersectionBA = d_nm->mkNode(INTERSECTION_MIN, B, A); @@ -321,6 +325,14 @@ class BagsTypeRuleWhite : public CxxTest::TestSuite RewriteResponse response4 = d_rewriter->postRewrite(unionDisjoint4); TS_ASSERT(response4.d_node == unionDisjointBA && response4.d_status == REWRITE_AGAIN_FULL); + + // (union_disjoint (intersection_min B A)) (union_max A B) = + // (union_disjoint B A) // sum(a,b) = max(a,b) + min(a,b) + Node unionDisjoint5 = + d_nm->mkNode(UNION_DISJOINT, unionMaxAC, intersectionAB); + RewriteResponse response5 = d_rewriter->postRewrite(unionDisjoint5); + TS_ASSERT(response5.d_node == unionDisjoint5 + && response5.d_status == REWRITE_DONE); } void testIntersectionMin() -- 2.30.2